pytorch Variable与Tensor合并后 requires_grad()默认与修改方式
作者:西电小猪猪 发布时间:2021-08-05 09:11:59
pytorch更新完后合并了Variable与Tensor
torch.Tensor()能像Variable一样进行反向传播的更新,返回值为Tensor
Variable自动创建tensor,且返回值为Tensor,(所以以后不需要再用Variable)
Tensor创建后,默认requires_grad=Flase
可以通过xxx.requires_grad_()将默认的Flase修改为True
下面附代码及官方文档代码:
import torch
from torch.autograd import Variable #使用Variabl必须调用库
lis=torch.range(1,6).reshape((-1,3))#创建1~6 形状
#行不指定(-1意为由计算机自己计算)列为3的floattensor矩阵
print(lis)
print(lis.requires_grad) #查看默认的requires_grad是否是Flase
lis.requires_grad_() #使用.requires_grad_()修改默认requires_grad为true
print(lis.requires_grad)
结果如下:
tensor([[1., 2., 3.],
[4., 5., 6.]])
False
True
创建一个Variable,Variable必须接收Tensor数据 不能直接写为 a=Variable(range(6)).reshape((-1,3))
否则报错 Variable data has to be a tensor, but got range
正确如下:
import torch
from torch.autograd import Variable
tensor=torch.FloatTensor(range(8)).reshape((-1,4))
my_ten=Variable(tensor)
print(my_ten)
print(my_ten.requires_grad)
my_ten.requires_grad_()
print(my_ten.requires_grad)
结果:
tensor([[0., 1., 2., 3.],
[4., 5., 6., 7.]])
False
True
由上面可以看出,Tensor完全可以取代Variable。
下面给出官方文档:
# 默认创建requires_grad = False的Tensor
x = torch . ones ( 1 ) # create a tensor with requires_grad=False (default)
x . requires_grad
# out: False
# 创建另一个Tensor,同样requires_grad = False
y = torch . ones ( 1 ) # another tensor with requires_grad=False
# both inputs have requires_grad=False. so does the output
z = x + y
# 因为两个Tensor x,y,requires_grad=False.都无法实现自动微分,
# 所以操作(operation)z=x+y后的z也是无法自动微分,requires_grad=False
z . requires_grad
# out: False
# then autograd won't track this computation. let's verify!
# 因而无法autograd,程序报错
z . backward ( )
# out:程序报错:RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
# now create a tensor with requires_grad=True
w = torch . ones ( 1 , requires_grad = True )
w . requires_grad
# out: True
# add to the previous result that has require_grad=False
# 因为total的操作中输入Tensor w的requires_grad=True,因而操作可以进行反向传播和自动求导。
total = w + z
# the total sum now requires grad!
total . requires_grad
# out: True
# autograd can compute the gradients as well
total . backward ( )
w . grad
#out: tensor([ 1.])
# and no computation is wasted to compute gradients for x, y and z, which don't require grad
# 由于z,x,y的requires_grad=False,所以并没有计算三者的梯度
z . grad == x . grad == y . grad == None
# True
existing_tensor . requires_grad_ ( )
existing_tensor . requires_grad
# out:True
或者直接用Tensor创建时给定requires_grad=True
my_tensor = torch.zeros(3,4,requires_grad = True)
my_tensor.requires_grad
# out: True
lis=torch.range(1,6,requires_grad=True).reshape((-1,3))
print(lis)
print(lis.requires_grad)
lis.requires_grad_()
print(lis.requires_grad)
结果
tensor([[1., 2., 3.],
[4., 5., 6.]], requires_grad=True)
True
True
补充:volatile 和 requires_grad在pytorch中的意思
Backward过程中排除子图
pytorch的BP过程是由一个函数决定的,loss.backward(), 可以看到backward()函数里并没有传要求谁的梯度。那么我们可以大胆猜测,在BP的过程中,pytorch是将所有影响loss的Variable都求了一次梯度。
但是有时候,我们并不想求所有Variable的梯度。那就要考虑如何在Backward过程中排除子图(ie.排除没必要的梯度计算)。
如何BP过程中排除子图? Variable的两个参数(requires_grad和volatile)
requires_grad=True 要求梯度
requires_grad=False 不要求梯度
volatile=True相当于requires_grad=False。反之则反之。。。。。。。ok
注意:如果a是requires_grad=True,b是requires_grad=False。则c=a+b是requires_grad=True。同样的道理应用于volatile
为什么要排除子图
也许有人会问,梯度全部计算,不更新的话不就得了。
这样就涉及了效率的问题了,计算很多没用的梯度是浪费了很多资源的(时间,计算机内存)
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/weixin_43635550/article/details/100192797


猜你喜欢
- 往mysql数据库中插入数据。以前常用INSERT INTO 表名 (列名1,列名2…) VALUES(列值1,列值2);如果在PHP程序中
- 背景在实现图片转码的需求时,需要支持最大 500 个图片下载后转换格式;如果是一个一个下载后转码,耗时太长,需要使用 goroutine 实
- 本文实例讲述了Linux下安装Memcached服务器和客户端与php使用。分享给大家供大家参考,具体如下:Memcached是高性能的分布
- 目录简介图形加载和说明图形的灰度灰度图像的压缩原始图像的压缩总结简介本文将会以图表的形式为大家讲解怎么在NumPy中进行多维数据的线性代数运
- GIL(Global Interpreter Lock,即全局解释器锁)1.为什么有GIL设计者为了规避类似于内存管理这样的复杂的竞争风险问
- //1、运行到C盘根目录 //2、输入:SET ORACLE_SID = 你的SID名称 3、输入:sqlplus/nolog 4、输入:c
- 我们在选择一件商品的时候,会先了解一些相关的商品信息,根据自己的需求和情况再进行选择。这种现象也同样适用于找工作,筛选一个岗位的重要环节,就
- 在ASP与ASP.NET之间共享对话状态(1)ASP实现原来的ASP对话只能将对话数据保存在内存中。为了将对话数据保存到SQL Server
- 用户输入1、使用input来等待用户输入。如 username = input('username:') password
- 前言我们写好的gin项目想要部署在服务器上,我们应该怎么做呢,接下来我会详细的讲解一下部署教程。1.首先我们要有一台虚拟机,虚拟机上安装好g
- 前言前几天写了一篇MySQL高并发生成唯一订单号的方法,有人私信问有没有SQL server版本的,今天中午特地写了SQL server版本
- 本文实例讲述了python采集百度百科的方法。分享给大家供大家参考。具体如下:#!/usr/bin/python# -*- coding:
- 导读:有时候,为了开发项目,我们需要在一台服务器上部署MySql数据库服务器,然后使用本地电脑远程访问和管理MySql数据库,那么如何实现M
- 适合各种浏览器的js拖动层,ie,firefox等,调用方便!<!DOCTYPE HTML PUBLIC "-//W3C//
- 身体是革命的本钱,身体健康了我们才有更多精力做自己想做的事情,追求女神,追求梦想。然而程序员是一个苦比的职业,大部分时间都对着电脑,我现在颈
- 编解码器在字符与字节之间的转换过程称为编解码,Python自带了超过100种编解码器,比如:ascii(英文体系)gb2312(中文体系)u
- 最近被“模块化”缠身,又是文章又是PPT的,被逼着想了很多相关的东西。整理下我这段时间对于“模块化”的思考,大多都是我自己从事页面重构这份工
- 案例:如果我们起了一个协程,但这个协程出现了panic,但我们没有捕获这个协程,就会造成程序的崩溃,这时可以在goroutine中使用rec
- 前言复习试题时,发现一道复数问题问题关于 Python 的复数类型,以下选项中描述错误的是A复数的虚数部分通过后缀“J”或者“j”来表示B对
- 0、前言在python2.7及以上的版本,str.format()的方式为格式化提供了非常大的便利。与之前的%型格式化字符串相比,他显得更为