pytorch_detach 切断网络反传方式
作者:JJunQw 发布时间:2022-09-25 21:10:50
detach
官方文档中,对这个方法是这么介绍的。
detach = _add_docstr(_C._TensorBase.detach, r"""
Returns a new Tensor, detached from the current graph.
The result will never require gradient.
.. note::
Returned Tensor uses the same data tensor as the original one.
In-place modifications on either of them will be seen, and may trigger
errors in correctness checks.
""")
返回一个新的从当前图中分离的 Variable。
返回的 Variable 永远不会需要梯度
如果 被 detach 的Variable volatile=True, 那么 detach 出来的 volatile 也为 True
还有一个注意事项,即:返回的 Variable 和 被 detach 的Variable 指向同一个 tensor
import torch
from torch.nn import init
t1 = torch.tensor([1., 2.],requires_grad=True)
t2 = torch.tensor([2., 3.],requires_grad=True)
v3 = t1 + t2
v3_detached = v3.detach()
v3_detached.data.add_(t1) # 修改了 v3_detached Variable中 tensor 的值
print(v3, v3_detached) # v3 中tensor 的值也会改变
print(v3.requires_grad,v3_detached.requires_grad)
'''
tensor([4., 7.], grad_fn=<AddBackward0>) tensor([4., 7.])
True False
'''
在pytorch中通过拷贝需要切断位置前的tensor实现这个功能。tensor中拷贝的函数有两个,一个是clone(),另外一个是copy_(),clone()相当于完全复制了之前的tensor,他的梯度也会复制,而且在反向传播时,克隆的样本和结果是等价的,可以简单的理解为clone只是给了同一个tensor不同的代号,和‘='等价。所以如果想要生成一个新的分开的tensor,请使用copy_()。
不过对于这样的操作,pytorch中有专门的函数——detach()。
用户自己创建的节点是leaf_node(如图中的abc三个节点),不依赖于其他变量,对于leaf_node不能进行in_place操作.根节点是计算图的最终目标(如图y),通过链式法则可以计算出所有节点相对于根节点的梯度值.这一过程通过调用root.backward()就可以实现.
因此,detach所做的就是,重新声明一个变量,指向原变量的存放位置,但是requires_grad为false.更深入一点的理解是,计算图从detach过的变量这里就断了, 它变成了一个leaf_node.即使之后重新将它的requires_node置为true,它也不会具有梯度.
pytorch 梯度
(0.4之后),tensor和variable合并,tensor具有grad、grad_fn等属性;
默认创建的tensor,grad默认为False, 如果当前tensor_grad为None,则不会向前传播,如果有其它支路具有grad,则只传播其它支路的grad
# 默认创建requires_grad = False的Tensor
x = torch.ones(1) # create a tensor with requires_grad=False (default)
print(x.requires_grad)
# out: False
# 创建另一个Tensor,同样requires_grad = False
y = torch.ones(1) # another tensor with requires_grad=False
# both inputs have requires_grad=False. so does the output
z = x + y
# 因为两个Tensor x,y,requires_grad=False.都无法实现自动微分,
# 所以操作(operation)z=x+y后的z也是无法自动微分,requires_grad=False
print(z.requires_grad)
# out: False
# then autograd won't track this computation. let's verify!
# 因而无法autograd,程序报错
# z.backward()
# out:程序报错:RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
# now create a tensor with requires_grad=True
w = torch.ones(1, requires_grad=True)
print(w.requires_grad)
# out: True
# add to the previous result that has require_grad=False
# 因为total的操作中输入Tensor w的requires_grad=True,因而操作可以进行反向传播和自动求导。
total = w + z
# the total sum now requires grad!
total.requires_grad
# out: True
# autograd can compute the gradients as well
total.backward()
print(w.grad)
#out: tensor([ 1.])
# and no computation is wasted to compute gradients for x, y and z, which don't require grad
# 由于z,x,y的requires_grad=False,所以并没有计算三者的梯度
z.grad == x.grad == y.grad == None
# True
nn.Paramter
import torch.nn.functional as F
# With square kernels and equal stride
filters = torch.randn(8,4,3,3)
weiths = torch.nn.Parameter(torch.randn(8,4,3,3))
inputs = torch.randn(1,4,5,5)
out = F.conv2d(inputs, weiths, stride=2,padding=1)
print(out.shape)
con2d = torch.nn.Conv2d(4,8,3,stride=2,padding=1)
out_2 = con2d(inputs)
print(out_2.shape)
补充:Pytorch-detach()用法
目的:
神经网络的训练有时候可能希望保持一部分的网络参数不变,只对其中一部分的参数进行调整。
或者训练部分分支网络,并不让其梯度对主网络的梯度造成影响.这时候我们就需要使用detach()函数来切断一些分支的反向传播.
1 tensor.detach()
返回一个新的tensor,从当前计算图中分离下来。但是仍指向原变量的存放位置,不同之处只是requirse_grad为false.得到的这个tensir永远不需要计算器梯度,不具有grad.
即使之后重新将它的requires_grad置为true,它也不会具有梯度grad.这样我们就会继续使用这个新的tensor进行计算,后面当我们进行反向传播时,到该调用detach()的tensor就会停止,不能再继续向前进行传播.
注意:
使用detach返回的tensor和原始的tensor共同一个内存,即一个修改另一个也会跟着改变。
比如正常的例子是:
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a)
print(a.grad)
out = a.sigmoid()
out.sum().backward()
print(a.grad)
输出
tensor([1., 2., 3.], requires_grad=True)
None
tensor([0.1966, 0.1050, 0.0452])
1.1 当使用detach()分离tensor但是没有更改这个tensor时,并不会影响backward():
import torch
a = torch.tensor([1, 2, 3.], requires_grad=True)
print(a.grad)
out = a.sigmoid()
print(out)
#添加detach(),c的requires_grad为False
c = out.detach()
print(c)
#这时候没有对c进行更改,所以并不会影响backward()
out.sum().backward()
print(a.grad)
'''返回:
None
tensor([0.7311, 0.8808, 0.9526], grad_fn=<SigmoidBackward>)
tensor([0.7311, 0.8808, 0.9526])
tensor([0.1966, 0.1050, 0.0452])
'''
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/junqing_wu/article/details/93851909


猜你喜欢
- 1. 概述对于社区,没有一个明确的定义,有很多对社区的定义,如社区是指在一个网络中,有一组节点,它们彼此都相似,而组内的节点与网络中的其他节
- 有用的符号:| 竖杠后的字符会被原样输出 · 点表示下一级的所有字符都会被原样输出,不再被识别。(就是|的升级版,实现批量) include
- 快速测试创建项目与appdjango-admin startproject mysitedjango-admin startapp app1
- 前言:最近碰到业务需要根据PSD文件实现PSD文件解析图层功能,搜到了Python的一个解析PSD的库。这个库就是psd-tools,psd
- 有时候难免会要用到只读的文本框,可今天发现只读文本框有一个缺陷,当鼠标焦点在文本框里面的时候按回退键(backSpace), 会退回到前一个
- 前言:本文的主要内容是介绍Python中的变量命名规则和简单数据类型的应用,简单的数据类型包括字符串和数字等,文中还附有代码以及相应的运行结
- 一、前言实现名片管理系统,首先要创建两个python file ,分别是cards_main.py和cards_tool.py,前一个是主代
- 在深度学习中,模型的输入size通常是正方形尺寸的,比如300 x 300这样.直接resize的话,会把图像拉的变形.通常我们希望resi
- BULK INSERT以用户指定的格式复制一个数据文件至数据库表或视图中。 语法:BULK INSERT [ [ 'database
- Python 网页解析HTMLParse的实例详解使用python将网页抓取下来之后,下一步我们就应该解析网页,提取我们所需要的内容了,在p
- 前言:自增列可使用 auto_increment 来实现,当一个列被标识为 auto_increment 之后,在添加时如果不给此列设置任何
- 作用:用ASP程序将页面中的电话号码生成图片格式。 代码如下:<% Call Com_CreatValidCode
- Dreamweaver MX 2004新增加了表格宽度辅助线功能,让我们在编辑网页表格的时候能清楚地看到表格中各单元的宽度以及变化,很直观。
- 这篇文章讨论了Python的from <module> import *和from <package> import
- 本文实例讲述了Python中类的创建和实例化操作。分享给大家供大家参考,具体如下:python中同样使用关键字class创建一个类,类名称第
- CURLOPT_RETURNTRANSFER 选项:curl_setopt($ch, CURLOPT_RETURNTRANSFER,1);如
- 二维码的分类线性堆叠式二维码矩阵式二维码二维码的优缺点优点信息容量大编码范围广容错能力强译码可靠性高可引入加密措施成本低,易制作缺点二维码技
- 1.在列属性中加入事件 { &
- 自己写的一个自动完成效果,暂时没有ajax数据源,用静态数据代替。仅供喜欢JavaScript的同学们参考,代码如下<!DOCTYPE
- 微信小程序开发内测一个月.数据传递的方式很少.经常遇到页面销毁后回传参数的问题,小程序中并没有类似Android的startActivity