Pytorch 中retain_graph的用法详解
作者:DaneAI 发布时间:2021-01-20 21:23:45
标签:Pytorch,retain,graph
用法分析
在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用是什么?
############################
# (1) Update D network: maximize D(x)-1-D(G(z))
###########################
real_img = Variable(target)
if torch.cuda.is_available():
real_img = real_img.cuda()
z = Variable(data)
if torch.cuda.is_available():
z = z.cuda()
fake_img = netG(z)
netD.zero_grad()
real_out = netD(real_img).mean()
fake_out = netD(fake_img).mean()
d_loss = 1 - real_out + fake_out
d_loss.backward(retain_graph=True) #####
optimizerD.step()
############################
# (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss
###########################
netG.zero_grad()
g_loss = generator_criterion(fake_out, fake_img, real_img)
g_loss.backward()
optimizerG.step()
fake_img = netG(z)
fake_out = netD(fake_img).mean()
g_loss = generator_criterion(fake_out, fake_img, real_img)
running_results['g_loss'] += g_loss.data[0] * batch_size
d_loss = 1 - real_out + fake_out
running_results['d_loss'] += d_loss.data[0] * batch_size
running_results['d_score'] += real_out.data[0] * batch_size
running_results['g_score'] += fake_out.data[0] * batch_size
在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;
其实retain_graph这个参数在平常中我们是用不到的,但是在特殊的情况下我们会用到它,
如下代码:
import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward()
output2.backward()
输出如下错误信息:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-19-8ad6b0658906> in <module>()
----> 1 output1.backward()
2 output2.backward()
D:\ProgramData\Anaconda3\lib\site-packages\torch\tensor.py in backward(self, gradient, retain_graph, create_graph)
91 products. Defaults to ``False``.
92 """
---> 93 torch.autograd.backward(self, gradient, retain_graph, create_graph)
94
95 def register_hook(self, hook):
D:\ProgramData\Anaconda3\lib\site-packages\torch\autograd\__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
88 Variable._execution_engine.run_backward(
89 tensors, grad_tensors, retain_graph, create_graph,
---> 90 allow_unreachable=True) # allow_unreachable flag
91
92
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
修改成如下正确:
import torch
y=x**2
z=y*4
output1=z.mean()
output2=z.sum()
output1.backward(retain_graph=True)
output2.backward()
# 假如你有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True)
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数
Variable 类源代码
class Variable(_C._VariableBase):
"""
Attributes:
data: 任意类型的封装好的张量。
grad: 保存与data类型和位置相匹配的梯度,此属性难以分配并且不能重新分配。
requires_grad: 标记变量是否已经由一个需要调用到此变量的子图创建的bool值。只能在叶子变量上进行修改。
volatile: 标记变量是否能在推理模式下应用(如不保存历史记录)的bool值。只能在叶变量上更改。
is_leaf: 标记变量是否是图叶子(如由用户创建的变量)的bool值.
grad_fn: Gradient function graph trace.
Parameters:
data (any tensor class): 要包装的张量.
requires_grad (bool): bool型的标记值. **Keyword only.**
volatile (bool): bool型的标记值. **Keyword only.**
"""
def backward(self, gradient=None, retain_graph=None, create_graph=None, retain_variables=None):
"""计算关于当前图叶子变量的梯度,图使用链式法则导致分化
如果Variable是一个标量(例如它包含一个单元素数据),你无需对backward()指定任何参数
如果变量不是标量(包含多个元素数据的矢量)且需要梯度,函数需要额外的梯度;
需要指定一个和tensor的形状匹配的grad_output参数(y在指定方向投影对x的导数);
可以是一个类型和位置相匹配且包含与自身相关的不同函数梯度的张量。
函数在叶子上累积梯度,调用前需要对该叶子进行清零。
Arguments:
grad_variables (Tensor, Variable or None):
变量的梯度,如果是一个张量,除非“create_graph”是True,否则会自动转换成volatile型的变量。
可以为标量变量或不需要grad的值指定None值。如果None值可接受,则此参数可选。
retain_graph (bool, optional): 如果为False,用来计算梯度的图将被释放。
在几乎所有情况下,将此选项设置为True不是必需的,通常可以以更有效的方式解决。
默认值为create_graph的值。
create_graph (bool, optional): 为True时,会构造一个导数的图,用来计算出更高阶导数结果。
默认为False,除非``gradient``是一个volatile变量。
"""
torch.autograd.backward(self, gradient, retain_graph, create_graph, retain_variables)
def register_hook(self, hook):
"""Registers a backward hook.
每当与variable相关的梯度被计算时调用hook,hook的申明:hook(grad)->Variable or None
不能对hook的参数进行修改,但可以选择性地返回一个新的梯度以用在`grad`的相应位置。
函数返回一个handle,其``handle.remove()``方法用于将hook从模块中移除。
Example:
>>> v = Variable(torch.Tensor([0, 0, 0]), requires_grad=True)
>>> h = v.register_hook(lambda grad: grad * 2) # double the gradient
>>> v.backward(torch.Tensor([1, 1, 1]))
>>> v.grad.data
2
2
2
[torch.FloatTensor of size 3]
>>> h.remove() # removes the hook
"""
if self.volatile:
raise RuntimeError("cannot register a hook on a volatile variable")
if not self.requires_grad:
raise RuntimeError("cannot register a hook on a variable that "
"doesn't require gradient")
if self._backward_hooks is None:
self._backward_hooks = OrderedDict()
if self.grad_fn is not None:
self.grad_fn._register_hook_dict(self)
handle = hooks.RemovableHandle(self._backward_hooks)
self._backward_hooks[handle.id] = hook
return handle
def reinforce(self, reward):
"""Registers a reward obtained as a result of a stochastic process.
区分随机节点需要为他们提供reward值。如果图表中包含任何的随机操作,都应该在其输出上调用此函数,否则会出现错误。
Parameters:
reward(Tensor): 带有每个元素奖赏的张量,必须与Variable数据的设备位置和形状相匹配。
"""
if not isinstance(self.grad_fn, StochasticFunction):
raise RuntimeError("reinforce() can be only called on outputs "
"of stochastic functions")
self.grad_fn._reinforce(reward)
def detach(self):
"""返回一个从当前图分离出来的心变量。
结果不需要梯度,如果输入是volatile,则输出也是volatile。
.. 注意::
返回变量使用与原始变量相同的数据张量,并且可以看到其中任何一个的就地修改,并且可能会触发正确性检查中的错误。
"""
result = NoGrad()(self) # this is needed, because it merges version counters
result._grad_fn = None
return result
def detach_(self):
"""从创建它的图中分离出变量并作为该图的一个叶子"""
self._grad_fn = None
self.requires_grad = False
def retain_grad(self):
"""Enables .grad attribute for non-leaf Variables."""
if self.grad_fn is None: # no-op for leaves
return
if not self.requires_grad:
raise RuntimeError("can't retain_grad on Variable that has requires_grad=False")
if hasattr(self, 'retains_grad'):
return
weak_self = weakref.ref(self)
def retain_grad_hook(grad):
var = weak_self()
if var is None:
return
if var._grad is None:
var._grad = grad.clone()
else:
var._grad = var._grad + grad
self.register_hook(retain_grad_hook)
self.retains_grad = True
来源:https://blog.csdn.net/happyday_d/article/details/85554623


猜你喜欢
- 引言“ 这是MySQL系列笔记的第一篇,文章内容均为本人通过实践及查阅资料相关整理所得,可用作新手入门指南,或
- 在Python中,很多对象都是可以通过for语句来直接遍历的,例如list、string、dict等等,这些对象都可以被称为可迭代对象。至于
- 引言:本文是学习Turtle库时,发现两种方法都能改变画笔的方向,但二者又不是完全相同,故对其加以辨析总结到此,在本文你将收获:1.两种改变
- Purge死锁场景说明Purge死锁说明表中存在记录(unique key) 10,20,30,40 (且有 自增主键 ),现在删除记录 2
- 本文研究的主要内容是Python中装饰器相关学习总结,具体如下。装饰器(decorator)功能引入日志函数执行时间统计执行函数前预备处理执
- 本文详细讲述了Python使用MySQLdb for Python操作数据库的方法,分享给大家供大家参考。具体如下:一般来说网站就是要和数据
- Pytorch中的model.train() 和 model.eval() 原理与用法一、两种模式pytorch可以给我们提供两种方式来切换
- 问题描述今天在使用Numpy中的矩阵做相减操作时,出现了一些本应为负值的位置自动转换为了正值,观察发现转换后的正值为原本的负值加上256得到
- python安装完毕后,提示找不到ssl模块:[www@pythontab.com ~]$ pythonPython 2.7.15 (def
- this指针是面向对象程序设计中的一项重要概念,它表示当前运行的对象。在实现对象的方法时,可以使用this指针来获得该对象自身的引用。和其他
- 在 Class 块中,成员通过相应的声明语句被声明为 Private(私有成员,只能在类内部调用)
- JS数组遍历普通函数优点:支持流程控制(break、continue、return)forconst arr = ["A"
- 1. list查询个数:调用list.count(obj)函数,返回obj在list中的个数。输入:list_a = [2 for x in
- MySQL的ODBC接口实现是通过安装MyODBC驱动,这个驱动程序是跨平台的。如果在Linux等Unix体系操作系统下使用,需要先安装Io
- NumPy 支持的几类矩阵乘法也很重要。元素级乘法你已看过了一些元素级乘法。你可以使用 multiply 函数或 * 运算符来实现。回顾一下
- 本文实例讲述了Python中函数的参数定义和可变参数用法。分享给大家供大家参考。具体如下:刚学用Python的时候,特别是看一些库的源码时,
- 项目内容:用Python写的糗事百科的网络爬虫。使用方法:新建一个Bug.py文件,然后将代码复制到里面后,双击运行。程序功能:在命令提示行
- 参数数量及其作用该函数共有五个参数,分别是:被赋值的变量 ref要分配给变量的值 value、是否验证形状 validate_shape是否
- 代码如下:--CAST 和 CONVERT 函数 Percentage DECLARE @dec decimal(5,3), @var va
- 假如不使用INSTEAD OF触发器或可更新分区视图而是通过视图来修改数据,那么再修改之前,请考虑下列准则:◆如果在视图定义中使用了 WIT