Pytorch之Variable的用法
作者:啧啧啧biubiu 发布时间:2022-01-19 04:16:39
1.简介
torch.autograd.Variable是Autograd的核心类,它封装了Tensor,并整合了反向传播的相关实现
Variable和tensor的区别和联系
Variable是篮子,而tensor是鸡蛋,鸡蛋应该放在篮子里才能方便拿走(定义variable时一个参数就是tensor)
Variable这个篮子里除了装了tensor外还有requires_grad参数,表示是否需要对其求导,默认为False
Variable这个篮子呢,自身有一些属性
比如grad,梯度variable.grad是d(y)/d(variable)保存的是变量y对variable变量的梯度值,如果requires_grad参数为False,所以variable.grad返回值为None,如果为True,返回值就为对variable的梯度值
比如grad_fn,对于用户自己创建的变量(Variable())grad_fn是为none的,也就是不能调用backward函数,但对于由计算生成的变量,如果存在一个生成中间变量的requires_grad为true,那其的grad_fn不为none,反则为none
比如data,这个就很简单,这个属性就是装的鸡蛋(tensor)
Varibale包含三个属性:
data:存储了Tensor,是本体的数据 grad:保存了data的梯度,本事是个Variable而非Tensor,与data形状一致 grad_fn:指向Function对象,用于反向传播的梯度计算之用
代码1
import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
y = x + temp + 2
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)
输出1
none
(因为requires_grad=False)
代码2
import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
y = x + temp + 2
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(temp.grad) # d(y)/d(temp)
输出2
tensor([[0.2500, 0.2500],
[0.2500, 0.2500]])
代码3
import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
y = x + 2
y = y.mean() #求平均数
y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(x.grad) # d(y)/d(x)
输出3
Traceback (most recent call last):
File "path", line 12, in <module>
y.backward()
(报错了,因为生成变量y的中间变量只有x,而x的requires_grad是False,所以y的grad_fn是none)
代码4
import numpy as np
import torch
from torch.autograd import Variable
x = Variable(torch.ones(2,2),requires_grad = False)
temp = Variable(torch.zeros(2,2),requires_grad = True)
y = x + 2
y = y.mean() #求平均数
#y.backward() #反向传递函数,用于求y对前面的变量(x)的梯度
print(y.grad_fn) # d(y)/d(x)
输出4
none
2.grad属性
在每次backward后,grad值是会累加的,所以利用BP算法,每次迭代是需要将grad清零的。
x.grad.data.zero_()
(in-place操作需要加上_,即zero_)
来源:https://blog.csdn.net/qq_37385726/article/details/81706820


猜你喜欢
- mysql常用导出数据命令:1.mysql导出整个数据库 mysqldump -hhostname -uusername -pp
- 由于一些读者对于960 Grid System CSS Framework的原理和使用方法比较感兴趣,暴风彬彬今天将和大家一同分享这篇关于9
- python 包含子目录中的模块方法比较简单,关键是能够在sys.path里面找到通向模块文件的路径。下面将具体介绍几种常用情况: (1)主
- pip源配置文件可以放置的位置:Linux/Unix:/etc/pip.con~/.pip/pip.conf (每一个我都找了都没有,所以我
- 使用vue制作加载更多功能,通过ajax获取的数据往data里面push经常不成功,原因是push是往数组中追加数据内容的,而不能用作数组之
- 本文实例讲述了python排序方法。分享给大家供大家参考。具体如下:>>> def my_key1(x):... &nbs
- 脚本之家下载地址:https://www.jb51.net/softs/79861.html官网下载地址:http://www.micros
- 本文实例讲述了python采用getopt解析命令行输入参数的方法,分享给大家供大家参考。具体实例代码如下:import getopt im
- 前记在Python3.7后官方库出现了contextvars模块, 它的主要功能就是可以为多线程以及asyncio生态添加上下文功能,即使程
- 本文转自微信公众号:"算法与编程之美" Python用HBuilder创建交流社区APP 基于P
- 在Google Reader上看到网友分享的一个链接,真的发现自己已经out了。上面的这张图,是纯CSS实现的,没有背景图、没有Javasc
- 一、python-yml文件读写使用库 :import yaml安装:pip install pyyaml示例:文件config2.ymlg
- 引言近期做一些基于TCP协议的项目,跟其他接口方调试时经常出现不一致的问题,而程序日志又不能完成保证公正,就只能通过tcpdump抓包的方式
- 本文讲述了Symfony核心类。分享给大家供大家参考,具体如下:Symfony的核心类Symfony的MVC方式使用了一些你以后会经常碰到的
- 如果有一个多任务多loss的网络,那么在训练时,loss是如何工作的呢?比如下面:model = Model(inputs = input,
- python datetime 和时间戳互转import datetime, timenow = datetime.datetime.now
- 前言Pandas是Python下一个开源数据分析的库,它提供的数据结构DataFrame极大的简化了数据分析过程中一些繁琐操作。1. 基本使
- 一、服务器环境1、系统windows2003 中文企业版 sp22、mysql 5.1.553、php 5.2.174、IIS 6.0二、破
- os模块下有两个函数:os.walk()os.listdir()# -*- coding: utf-8 -*- &
- 一.假设有数据集dfdf.isnull()返回DateFrame,元素为空或者NA就显示True,否则就是False二.判断有空值的列df.