PyTorch中的Variable变量详解
作者:Wei Ji 发布时间:2023-02-19 18:48:47
一、了解Variable
顾名思义,Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。
具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。
【tensor 是一个多维矩阵】
用一个例子说明,Variable的定义:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True)
print(tensor)
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable)
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
注:tensor不能反向传播,variable可以反向传播。
二、Variable求梯度
Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。
v_out.backward() # 模拟 v_out 的误差反向传递
print(variable.grad) # 初始 Variable 的梯度
'''
0.5000 1.0000
1.5000 2.0000
'''
三、获取Variable里面的数据
直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。
print(variable) # Variable 形式
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data) # 将variable形式转为tensor 形式
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data.numpy()) # numpy 形式
"""
[[ 1. 2.]
[ 3. 4.]]
"""
扩展
在PyTorch中计算图的特点总结如下:
autograd根据用户对Variable的操作来构建其计算图。
1、requires_grad
variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。
2、volatile
variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。
3、retain_graph
多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存。
4、backward()
反向传播,求解Variable的梯度。放在中间缓存中。
来源:https://blog.csdn.net/qq_19329785/article/details/85029116
猜你喜欢
- 本文实例讲述了Python中random模块用法。分享给大家供大家参考。具体如下:import randomx = random.randi
- 1 算术运算add(other)比如进行数学运算加上具体的一个数字data['open'].add(1)2018-02-27
- 如图所示,要处理的数据是一个json数组,而且非常大下图为电脑配置,使用 json.load() 方法加载上述json文件电脑直接卡死解决思
- def getFibonacci(num): res=[0,1] a=0 b=1 for x in
- 1、首先在系统盘中查找scrrun.dll,如果存在这个文件,请跳到第三步,如果没有,请执行第二步。 2、在安装文件目录i386中找到scr
- 系统环境:64位win7企业版python2.7.102016.08.16修改内容:1)read_until()函数是可以设置timeout
- 如果你正在负责一个基于SQL Server的项目,或者你刚刚接触SQL Server,你都有可能要面临一些数据库性能的问题,这篇文章会为你提
- 实现网页的键盘输入操作from selenium.webdriver.common.keys import Keys * 页有时需要将鼠标
- 列表生成式即List Comprehensions,是Python内置的非常简单却强大的可以用来创建list的生成式。一个循环在C语言等其他
- 本文实例讲述了Python Django框架url反向解析实现动态生成对应的url链接。分享给大家供大家参考,具体如下:url反向解析:根据
- 在网站开发的时候经常要用chr(),但本人比较懒没时间记那么多。于是到用到的时候就查,这样麻烦。现在将它写出来方便以后用到查,也方便大家!c
- 这个是今年年初写的一篇,拿出来温习下。指针让程序结构变得混乱,也让程序执行效率提高,因此在oo的语言中不提倡指针的使用,使得程序结构清晰易读
- 在SQL Server 2008 中,新的FILESTREAM 数据类型,允许像文件和图片这种大型的二进制数据可以直接在NTFS文件系统中进
- Pygal可用来生成可缩放的矢量图形文件,对于需要在尺寸不同的屏幕上显示的图表,这很有用,可以自动缩放,自适应观看者的屏幕1、Pygal模块
- 前言 大家好,好男人就是我,我就是好男人,我就是-0nise。在各大漏洞举报平台,我们时常会
- 如下拉框的text是<input type=button value=ggg>,那么生成的combobox里
- 前言由于Django是 * 站,所有每次请求均会去数据进行相应的操作,当程序访问量大时,耗时必然会更加明显,最简单解决方式是使用:缓存,缓存
- 1 unittest框架unittest 是python 的单元测试框架,它主要有以下作用:提供用例组织与执行:当你的测试用例只有几条时,可
- 前言中位数是一个可将数值集合划分为相等的上下两部分的一个数值。如果列表数据的个数是奇数,则列表中间那个数据就是列表数据的中位数;如果列表数据
- CSS(叠层样式表)和XSL(可扩展样式语言)都可以定义XML文件的显示,这两种方式有哪些不同以及它们在使用中的具体方法,我们将在本文给予介