PyTorch中的Variable变量详解
作者:Wei Ji 发布时间:2023-02-19 18:48:47
一、了解Variable
顾名思义,Variable就是 变量 的意思。实质上也就是可以变化的量,区别于int变量,它是一种可以变化的变量,这正好就符合了反向传播,参数更新的属性。
具体来说,在pytorch中的Variable就是一个存放会变化值的地理位置,里面的值会不停发生片花,就像一个装鸡蛋的篮子,鸡蛋数会不断发生变化。那谁是里面的鸡蛋呢,自然就是pytorch中的tensor了。(也就是说,pytorch都是有tensor计算的,而tensor里面的参数都是Variable的形式)。如果用Variable计算的话,那返回的也是一个同类型的Variable。
【tensor 是一个多维矩阵】
用一个例子说明,Variable的定义:
import torch
from torch.autograd import Variable # torch 中 Variable 模块
tensor = torch.FloatTensor([[1,2],[3,4]])
# 把鸡蛋放到篮子里, requires_grad是参不参与误差反向传播, 要不要计算梯度
variable = Variable(tensor, requires_grad=True)
print(tensor)
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable)
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
注:tensor不能反向传播,variable可以反向传播。
二、Variable求梯度
Variable计算时,它会逐渐地生成计算图。这个图就是将所有的计算节点都连接起来,最后进行误差反向传递的时候,一次性将所有Variable里面的梯度都计算出来,而tensor就没有这个能力。
v_out.backward() # 模拟 v_out 的误差反向传递
print(variable.grad) # 初始 Variable 的梯度
'''
0.5000 1.0000
1.5000 2.0000
'''
三、获取Variable里面的数据
直接print(Variable) 只会输出Variable形式的数据,在很多时候是用不了的。所以需要转换一下,将其变成tensor形式。
print(variable) # Variable 形式
"""
Variable containing:
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data) # 将variable形式转为tensor 形式
"""
1 2
3 4
[torch.FloatTensor of size 2x2]
"""
print(variable.data.numpy()) # numpy 形式
"""
[[ 1. 2.]
[ 3. 4.]]
"""
扩展
在PyTorch中计算图的特点总结如下:
autograd根据用户对Variable的操作来构建其计算图。
1、requires_grad
variable默认是不需要被求导的,即requires_grad属性默认为False,如果某一个节点的requires_grad为True,那么所有依赖它的节点requires_grad都为True。
2、volatile
variable的volatile属性默认为False,如果某一个variable的volatile属性被设为True,那么所有依赖它的节点volatile属性都为True。volatile属性为True的节点不会求导,volatile的优先级比requires_grad高。
3、retain_graph
多次反向传播(多层监督)时,梯度是累加的。一般来说,单次反向传播后,计算图会free掉,也就是反向传播的中间缓存会被清空【这就是动态度的特点】。为进行多次反向传播需指定retain_graph=True来保存这些缓存。
4、backward()
反向传播,求解Variable的梯度。放在中间缓存中。
来源:https://blog.csdn.net/qq_19329785/article/details/85029116
猜你喜欢
- 十六进制(Hexadecimal)是计算机中数据的一种表示方法。同日常生活中的表示法不一样,它由0-9,A-F组成,字母不区分大小写。与10
- a. 如果欲使用gb2312编码,那么php要输出头:header(“Content-Type: text/html; charset=gb
- 文 | 李晓飞来源:Python 技术「ID: pythonall」爬虫程序想必大家都很熟悉了,随便写一个就可以获取网页上的信息,
- 我们将要来学习python的重要概念迭代和迭代器,通过简单实用的例子如列表迭代器和xrange。可迭代一个对象,物理或者虚拟存储的序列。li
- 接触 Python 不久,看到很多人写2048,自己也捣鼓了一个,主要是熟悉Python语法。程序使用Python3 写的,代码150行左右
- 本文实例讲述了php防止sql注入中过滤分页参数的方法。分享给大家供大家参考。具体分析如下:就网络安全而言,在网络上不要相信任何输入信息,对
- 一:简介由paramiko是用python语言写的一个模块,遵循SSH2协议,支持以加密和认证的方式,进行远程服务器的连接。由于使用的是py
- 本篇概要1.线程与多线程2.进程与多进程3.多线程并发下载图片4.多进程并发提高数字运算关于并发在计算机编程领域,并发编程是一个很常见的名词
- SQL Server Sa用户相信大家都有一定的理解,下面就为您介绍SQL Server 2000身份验证模式的修改方法及SQL Serve
- 前言本篇文章要使用OpenCV、Numpy 和Math这3个工具包实现一个简单的滤镜编辑器。在这个滤镜编辑器中,包含了3种滤镜效果,它们分别
- 1、安装setuptools命令如下:wget --no-check-certificate https://pypi.python.org
- MySQL常用的四种引擎的介绍(1):MyISAM存储引擎:不支持事务、也不支持外键,优势是访问速度快,对事务完整性没有 要求或者以sele
- 本篇阅读的代码实现了将输入的数字转化成一个列表,输入数字中的每一位按照从左到右的顺序成为列表中的一项。本篇阅读的代码片段来自于30-seco
- 使用步骤大致分为两步,就不多废话第一步、修改hosts文件将0.0.0.0 account.jetbrains.com添加到hosts文件最
- --相信大家肯定经常会把数据导入到数据库中,但是可能会有些记录行的所有列的数据是null,这为null的数据是我们不需要 --现在需要一个简
- prototype框架最早是出于方便Ruby开发人员进行JavaScript开发所构建的,从这个版本上更加体现的淋漓尽致。比起1.3.1版本
- Python3 解释器Linux/Unix的系统上,一般默认的 python 版本为 2.x,我们可以将 python3.x 安装在 /us
- 1 概述C/C++和Java(以及大多数的主流编程语言)都有自己成熟的单元测试框架,前者如Check,后者如JUnit,但这些编程框架本质上
- 在矩阵应用的过程中,经常需要使用随机数,那么怎么使用numpy 产生随机数呢 ,为此专门做一个总结。random模块用于生成随机数,下面是一
- 本文实例讲述了Python创建xml的方法。分享给大家供大家参考。具体实现方法如下:from xml.dom.minidom import