pytorch 禁止/允许计算局部梯度的操作
作者:Answerlzd 发布时间:2021-01-17 01:55:35
标签:pytorch,计算,梯度
一、禁止计算局部梯度
torch.autogard.no_grad: 禁用梯度计算的上下文管理器。
当确定不会调用Tensor.backward()计算梯度时,设置禁止计算梯度会减少内存消耗。如果需要计算梯度设置Tensor.requires_grad=True
两种禁用方法:
将不用计算梯度的变量放在with torch.no_grad()里
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... y = x * 2
>>> y.requires_grad
Out[12]:False
使用装饰器 @torch.no_gard()修饰的函数,在调用时不允许计算梯度
>>> @torch.no_grad()
... def doubler(x):
... return x * 2
>>> z = doubler(x)
>>> z.requires_grad
Out[13]:False
二、禁止后允许计算局部梯度
torch.autogard.enable_grad :允许计算梯度的上下文管理器
在一个no_grad上下文中使能梯度计算。在no_grad外部此上下文管理器无影响.
用法和上面类似:
使用with torch.enable_grad()允许计算梯度
>>> x = torch.tensor([1.], requires_grad=True)
>>> with torch.no_grad():
... with torch.enable_grad():
... y = x * 2
>>> y.requires_grad
Out[14]:True
>>> y.backward() # 计算梯度
>>> x.grad
Out[15]: tensor([2.])
在禁止计算梯度下调用被允许计算梯度的函数,结果可以计算梯度
>>> @torch.enable_grad()
... def doubler(x):
... return x * 2
>>> with torch.no_grad():
... z = doubler(x)
>>> z.requires_grad
Out[16]:True
三、是否计算梯度
torch.autograd.set_grad_enable()
可以作为一个函数使用:
>>> x = torch.tensor([1.], requires_grad=True)
>>> is_train = False
>>> with torch.set_grad_enabled(is_train):
... y = x * 2
>>> y.requires_grad
Out[17]:False
>>> torch.set_grad_enabled(True)
>>> y = x * 2
>>> y.requires_grad
Out[18]:True
>>> torch.set_grad_enabled(False)
>>> y = x * 2
>>> y.requires_grad
Out[19]:False
总结:
单独使用这三个函数时没有什么,但是若是嵌套,遵循就近原则。
x = torch.tensor([1.], requires_grad=True)
with torch.enable_grad():
torch.set_grad_enabled(False)
y = x * 2
print(y.requires_grad)
Out[20]: False
torch.set_grad_enabled(True)
with torch.no_grad():
z = x * 2
print(z.requires_grad)
Out[21]:False
补充:pytorch局部范围内禁用梯度计算,no_grad、enable_grad、set_grad_enabled使用举例
原文及翻译
Locally disabling gradient computation
在局部区域内关闭(禁用)梯度的计算.
The context managers torch.no_grad(), torch.enable_grad(),
and torch.set_grad_enabled() are helpful for locally disabling
and enabling gradient computation. See Locally disabling gradient
computation for more details on their usage. These context
managers are thread local, so they won't work if you send
work to another thread using the threading module, etc.
上下文管理器torch.no_grad()、torch.enable_grad()和
torch.set_grad_enabled()可以用来在局部范围内启用或禁用梯度计算.
在Locally disabling gradient computation章节中详细介绍了
局部禁用梯度计算的使用方式.这些上下文管理器具有线程局部性,
因此,如果你使用threading模块来将工作负载发送到另一个线程,
这些上下文管理器将不会起作用.
no_grad Context-manager that disabled gradient calculation.
no_grad 用于禁用梯度计算的上下文管理器.
enable_grad Context-manager that enables gradient calculation.
enable_grad 用于启用梯度计算的上下文管理器.
set_grad_enabled Context-manager that sets gradient calculation to on or off.
set_grad_enabled 用于设置梯度计算打开或关闭状态的上下文管理器.
例子1
Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。
C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102
(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000001A2E55A8870>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715, 0.9088, -1.7601],
[-0.1806, 2.0937, 1.0406, -1.7651],
[ 1.1216, 0.8440, 0.1783, 0.6859]], requires_grad=True)
>>> b = a * 2
>>> b
tensor([[ 0.5648, -0.7430, 1.8176, -3.5202],
[-0.3612, 4.1874, 2.0812, -3.5303],
[ 2.2433, 1.6879, 0.3567, 1.3718]], grad_fn=<MulBackward0>)
>>> b.requires_grad
True
>>> b.grad
__main__:1: UserWarning: The .grad attribute of a Tensor that is not a leaf Tensor is being accessed. Its .grad attribute won't be populated during autograd.backward(). If you indeed want the gradient for a non-leaf Tensor, use .retain_grad() on the non-leaf Tensor. If you access the non-leaf Tensor by mistake, make sure you access the leaf Tensor instead. See github.com/pytorch/pytorch/pull/30531 for more informations.
>>> print(b.grad)
None
>>> a.requires_grad
True
>>> a.grad
>>> print(a.grad)
None
>>>
>>> with torch.no_grad():
... c = a * 2
...
>>> c
tensor([[ 0.5648, -0.7430, 1.8176, -3.5202],
[-0.3612, 4.1874, 2.0812, -3.5303],
[ 2.2433, 1.6879, 0.3567, 1.3718]])
>>> c.requires_grad
False
>>> print(c.grad)
None
>>> a.grad
>>>
>>> print(a.grad)
None
>>> c.sum()
tensor(6.1559)
>>>
>>> c.sum().backward()
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "D:\Anaconda3\envs\pytorch_1.7.1_cu102\lib\site-packages\torch\tensor.py", line 221, in backward
torch.autograd.backward(self, gradient, retain_graph, create_graph)
File "D:\Anaconda3\envs\pytorch_1.7.1_cu102\lib\site-packages\torch\autograd\__init__.py", line 132, in backward
allow_unreachable=True) # allow_unreachable flag
RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
>>>
>>>
>>> b.sum()
tensor(6.1559, grad_fn=<SumBackward0>)
>>> b.sum().backward()
>>>
>>>
>>> a.grad
tensor([[2., 2., 2., 2.],
[2., 2., 2., 2.],
[2., 2., 2., 2.]])
>>> a.requires_grad
True
>>>
>>>
例子2
Microsoft Windows [版本 10.0.18363.1440]
(c) 2019 Microsoft Corporation。保留所有权利。
C:\Users\chenxuqi>conda activate pytorch_1.7.1_cu102
(pytorch_1.7.1_cu102) C:\Users\chenxuqi>python
Python 3.7.9 (default, Aug 31 2020, 17:10:11) [MSC v.1916 64 bit (AMD64)] :: Anaconda, Inc. on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.manual_seed(seed=20200910)
<torch._C.Generator object at 0x000002109ABC8870>
>>>
>>> a = torch.randn(3,4,requires_grad=True)
>>> a
tensor([[ 0.2824, -0.3715, 0.9088, -1.7601],
[-0.1806, 2.0937, 1.0406, -1.7651],
[ 1.1216, 0.8440, 0.1783, 0.6859]], requires_grad=True)
>>> a.requires_grad
True
>>>
>>> with torch.set_grad_enabled(False):
... b = a * 2
...
>>> b
tensor([[ 0.5648, -0.7430, 1.8176, -3.5202],
[-0.3612, 4.1874, 2.0812, -3.5303],
[ 2.2433, 1.6879, 0.3567, 1.3718]])
>>> b.requires_grad
False
>>>
>>> with torch.set_grad_enabled(True):
... c = a * 3
...
>>> c
tensor([[ 0.8472, -1.1145, 2.7263, -5.2804],
[-0.5418, 6.2810, 3.1219, -5.2954],
[ 3.3649, 2.5319, 0.5350, 2.0576]], grad_fn=<MulBackward0>)
>>> c.requires_grad
True
>>>
>>> d = a * 4
>>> d.requires_grad
True
>>>
>>> torch.set_grad_enabled(True) # this can also be used as a function
<torch.autograd.grad_mode.set_grad_enabled object at 0x00000210983982C8>
>>>
>>> # 以函数调用的方式来使用
>>>
>>> e = a * 5
>>> e
tensor([[ 1.4119, -1.8574, 4.5439, -8.8006],
[-0.9030, 10.4684, 5.2031, -8.8257],
[ 5.6082, 4.2198, 0.8917, 3.4294]], grad_fn=<MulBackward0>)
>>> e.requires_grad
True
>>>
>>> d
tensor([[ 1.1296, -1.4859, 3.6351, -7.0405],
[-0.7224, 8.3747, 4.1625, -7.0606],
[ 4.4866, 3.3759, 0.7133, 2.7435]], grad_fn=<MulBackward0>)
>>>
>>> torch.set_grad_enabled(False) # 以函数调用的方式来使用
<torch.autograd.grad_mode.set_grad_enabled object at 0x0000021098394C48>
>>>
>>> f = a * 6
>>> f
tensor([[ 1.6943, -2.2289, 5.4527, -10.5607],
[ -1.0836, 12.5621, 6.2437, -10.5908],
[ 6.7298, 5.0638, 1.0700, 4.1153]])
>>> f.requires_grad
False
>>>
>>>
>>>
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/Answer3664/article/details/99460175


猜你喜欢
- 在 HTML 中使用JavaScriptJavaScript能以两种方式嵌入HTML:作为语句和函数使用时,用 SCRIPT 标记作为事件处
- 本文实例讲述了Python pandas自定义函数的使用方法。分享给大家供大家参考,具体如下:自定义函数的使用import numpy as
- 如何使用migrations的使用非常简单: 修改model, 比如增加field, 然后运行python manager.py makem
- 前不久网上公开了一个MySQL Func的漏洞,讲的是使用MySQL创建一个自定义的函数,然后通过这个函数来攻击服务器。最早看到相关的报道是
- 1.提示窗口,当页面被打开时就弹出提示窗口。<style type="text/css"> body { b
- 在前文说过,如果想要更好的做接口测试,我们要利用自己的代码基础与代码优势,所以该章节不会再介绍商业化的、通用的接口测试工具,重点介绍如何通过
- 今天的问题是请问以下 alert 弹出值分别是什么?var f = function f2()&nb
- 本文实例讲述了php实现mysql事务处理的方法。分享给大家供大家参考。具体分析如下:要实现本功能的条件是环境 mysql 5.2 /php
- 最近我在Go Forum 中发现了String size of 20 character 的问题,“hollowaykeanho” 给出了相
- 1 Kmean图像分割按照Kmean原理,对图像像素进行聚类。优点:此方法原理简单,效果显著。缺点:实践发现对于前景和背景颜色相近或者颜色区
- Graphical User Interface,简称 GUI,又称图形化用户接口,所谓的GUI编程,指的是用户不需要输入代码指令,只通过图
- 本文实例为大家分享了答题辅助python具体代码,供大家参考,具体内容如下from screenshot import pull_scree
- 在编程时你一定碰到过时间触发的事件,在VB中有timer控件,而ASP中没有,假如你要不停地查询数据库来等待一个返回结果的话,我想你一定知道
- 我们怎样才能了解用户需求呢?大家都知道可用性测试、调查问卷之类与用户进行沟通的途径,这些方法各有各的利弊,如果逐一分析的话,恐怕至少要分成三
- 1、定义具名元组需要2个参数,第1个参数是类名,第2个参数是字段名,既可以是可迭代对象(如列表和元组),也可以是空格间隔的字符串:Card
- 本文实例讲述了Python实现基于socket的udp传输与接收功能。分享给大家供大家参考,具体如下:udp的传输与接收windows网络调
- 实例如下:from win32com.client import Dispatch import win32com.client
- 之前修改两张及以上表的时候,总得需要用几次语句才修改,万一其中一条没修改上,又没事务机制的话,处理很麻烦,于是想到能不能一条语句完成呢? 结
- 起由:前一阵子想要刷一刷国二Python的题库,千方百计找到题库之后,打开一个个word文档,发现一题一题阅读很麻烦,而且答案就在题目的下面
- 本文是从百度百科中摘录出来的,asp在it中还有Application Service Provider,也就是应用服务供应商的意思。概述A