PyTorch中的参数类torch.nn.Parameter()详解
作者:Adenialzz 发布时间:2021-09-07 19:06:30
前言
今天来聊一下PyTorch中的torch.nn.Parameter()这个函数,笔者第一次见的时候也是大概能理解函数的用途,但是具体实现原理细节也是云里雾里,在参考了几篇博文,做过几个实验之后算是清晰了,本文在记录的同时希望给后来人一个参考,欢迎留言讨论。
分析
先看其名,parameter,中文意为参数。我们知道,使用PyTorch训练神经网络时,本质上就是训练一个函数,这个函数输入一个数据(如CV中输入一张图像),输出一个预测(如输出这张图像中的物体是属于什么类别)。而在我们给定这个函数的结构(如卷积、全连接等)之后,能学习的就是这个函数的参数了,我们设计一个损失函数,配合梯度下降法,使得我们学习到的函数(神经网络)能够尽量准确地完成预测任务。
通常,我们的参数都是一些常见的结构(卷积、全连接等)里面的计算参数。而当我们的网络有一些其他的设计时,会需要一些额外的参数同样很着整个网络的训练进行学习更新,最后得到最优的值,经典的例子有注意力机制中的权重参数、Vision Transformer中的class token和positional embedding等。
而这里的torch.nn.Parameter()就可以很好地适应这种应用场景。
下面是这篇博客的一个总结,笔者认为讲的比较明白,在这里引用一下:
首先可以把这个函数理解为类型转换函数,将一个不可训练的类型Tensor转换成可以训练的类型parameter并将这个parameter绑定到这个module里面(net.parameter()中就有这个绑定的parameter,所以在参数优化的时候可以进行优化的),所以经过类型转换这个self.v变成了模型的一部分,成为了模型中根据训练可以改动的参数了。使用这个函数的目的也是想让某些变量在学习的过程中不断的修改其值以达到最优化。
ViT中nn.Parameter()的实验
看过这个分析后,我们再看一下Vision Transformer中的用法:
...
self.pos_embedding = nn.Parameter(torch.randn(1, num_patches+1, dim))
self.cls_token = nn.Parameter(torch.randn(1, 1, dim))
...
我们知道在ViT中,positonal embedding和class token是两个需要随着网络训练学习的参数,但是它们又不属于FC、MLP、MSA等运算的参数,在这时,就可以用nn.Parameter()来将这个随机初始化的Tensor注册为可学习的参数Parameter。
为了确定这两个参数确实是被添加到了net.Parameters()内,笔者稍微改动源码,显式地指定这两个参数的初始数值为0.98,并打印迭代器net.Parameters()。
...
self.pos_embedding = nn.Parameter(torch.ones(1, num_patches+1, dim) * 0.98)
self.cls_token = nn.Parameter(torch.ones(1, 1, dim) * 0.98)
...
实例化一个ViT模型并打印net.Parameters():
net_vit = ViT(
image_size = 256,
patch_size = 32,
num_classes = 1000,
dim = 1024,
depth = 6,
heads = 16,
mlp_dim = 2048,
dropout = 0.1,
emb_dropout = 0.1
)
for para in net_vit.parameters():
print(para.data)
输出结果中可以看到,最前两行就是我们显式指定为0.98的两个参数pos_embedding和cls_token:
tensor([[[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],
[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],
[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],
...,
[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],
[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800],
[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800]]])
tensor([[[0.9800, 0.9800, 0.9800, ..., 0.9800, 0.9800, 0.9800]]])
tensor([[-0.0026, -0.0064, 0.0111, ..., 0.0091, -0.0041, -0.0060],
[ 0.0003, 0.0115, 0.0059, ..., -0.0052, -0.0056, 0.0010],
[ 0.0079, 0.0016, -0.0094, ..., 0.0174, 0.0065, 0.0001],
...,
[-0.0110, -0.0137, 0.0102, ..., 0.0145, -0.0105, -0.0167],
[-0.0116, -0.0147, 0.0030, ..., 0.0087, 0.0022, 0.0108],
[-0.0079, 0.0033, -0.0087, ..., -0.0174, 0.0103, 0.0021]])
...
...
这就可以确定nn.Parameter()添加的参数确实是被添加到了Parameters列表中,会被送入优化器中随训练一起学习更新。
from torch.optim import Adam
opt = Adam(net_vit.parameters(), learning_rate=0.001)
其他解释
以下是国外StackOverflow的一个大佬的解读,笔者自行翻译并放在这里供大家参考,想查看原文的同学请戳这里。
我们知道Tensor相当于是一个高维度的矩阵,它是Variable类的子类。Variable和Parameter之间的差异体现在与Module关联时。当Parameter作为model的属性与module相关联时,它会被自动添加到Parameters列表中,并且可以使用net.Parameters()迭代器进行访问。
最初在Torch中,一个Variable(例如可以是某个中间state)也会在赋值时被添加为模型的Parameter。在某些实例中,需要缓存变量,而不是将它们添加到Parameters列表中。
文档中提到的一种情况是RNN,在这种情况下,您需要保存最后一个hidden state,这样就不必一次又一次地传递它。需要缓存一个Variable,而不是让它自动注册为模型的Parameter,这就是为什么我们有一个显式的方法将参数注册到我们的模型,即nn.Parameter类。
举个例子:
import torch
import torch.nn as nn
from torch.optim import Adam
class NN_Network(nn.Module):
def __init__(self,in_dim,hid,out_dim):
super(NN_Network, self).__init__()
self.linear1 = nn.Linear(in_dim,hid)
self.linear2 = nn.Linear(hid,out_dim)
self.linear1.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
self.linear1.bias = torch.nn.Parameter(torch.ones(hid))
self.linear2.weight = torch.nn.Parameter(torch.zeros(in_dim,hid))
self.linear2.bias = torch.nn.Parameter(torch.ones(hid))
def forward(self, input_array):
h = self.linear1(input_array)
y_pred = self.linear2(h)
return y_pred
in_d = 5
hidn = 2
out_d = 3
net = NN_Network(in_d, hidn, out_d)
然后检查一下这个模型的Parameters列表:
for param in net.parameters():
print(type(param.data), param.size())
""" Output
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
<class 'torch.FloatTensor'> torch.Size([5, 2])
<class 'torch.FloatTensor'> torch.Size([2])
"""
可以轻易地送入到优化器中:
opt = Adam(net.parameters(), learning_rate=0.001)
另外,请注意Parameter的require_grad会自动设定。
各位读者有疑惑或异议的地方,欢迎留言讨论。
参考:
https://www.jb51.net/article/238632.htm
https://stackoverflow.com/questions/50935345/understanding-torch-nn-parameter
来源:https://blog.csdn.net/weixin_44966641/article/details/118730730
猜你喜欢
- 1,创建测试表CREATE TABLE `testsign` ( `userid` int(5) DEFAULT NULL, `user
- 概述在pytorch中有两种方式可以保存推理模型,第一种是只保存模型的参数,比如parameters和buffers;另外一种是保存整个模型
- 今天碰到一个很有意思的问题,需要将普通的 Unicode字符串转换为 Unicode编码的字符串,如下:将 \\u9500\\u552e 转
- 安装破解包:AWVS14.6.220117111破解Win&Linux&Mac.zip网盘链接:https://pan.ba
- 本文实例讲述了mysql重复索引与冗余索引。分享给大家供大家参考,具体如下:重复索引:表示一个列或者顺序相同的几个列上建立的多个索引。冗余索
- 问题:1. 访问 ASP 页面时,出现以下错误:Active Server Pages 错误 'ASP 0201'错误无效的
- 切片操作首先支持下标索引,通过[ N:M :P ]操作索引正向从0开始,逆向从-1开始N:切片开始位置M:切片结束位置(不包含)P:指定切片
- 本文实例讲述了Python定时任务sched模块用法。分享给大家供大家参考,具体如下:通过sched模块可以实现通过自定义时间,自定义函数,
- 本文实例为大家分享了Django实现文件上传下载的具体代码,供大家参考,具体内容如下一、django实现文件下载(1)、后台接口如果从服务器
- 开发微信小程序过程中,有个需求需要用到日期时间筛选器,查看微信官方文档后,发现官方文档的picker筛选器只能单独支持日期或者是时间,所以为
- 导语也许是为了和音,在立冬这一天的人间里北方多个城市,悄然降下冬天的第一场初雪,组成了一段旋律💨一天过两季,黄叶转飞花——从天而降落,昼夜不
- 不进行计算时,生成器和list空间占用import timefrom memory_profiler import profile@prof
- 笔者在今天的工作中,遇到了一个需求,那就是如何将Python字符串生成PDF。比如,需要把Python字符串‘这是测试文件'生成为P
- 本文向大家介绍了使用SQL语句提取数据库所有表的表名、字段名的实例代码,在SQLserver 中进行了测试,具体内容如下:--查询所有用户表
- 当perl脚本运行时,从命令行上传递给它的参数存储在内建数组@ARGV中,@ARGV是PERL默认用来接收参数的数组,可以有多个参数,$AR
- Vue baseurl配置最近的一个vue项目,没有config文件夹,配置baseurl废了很大劲,终于找到了方法,感天动地o(╥﹏╥)o
- 1. ASCII码我们知道,在计算机内部,所有的信息最终都表示为一个二进制的字符串。每一个二进制位(bit)有0和1两种状态,因此八个二进制
- 大致介绍好久没有写博客了,正好今天有时间把前几天写的利用python定时发送QQ邮件记录一下1、首先利用request库去请求数据,天气预报
- 故障状况:php网站连接mysql失败,但在命令行下通过mysql命令可登录并正常操作。解决方案:1、命令行下登录mysql,执行以下命令:
- 首先我们看公式:这个是要拟合的函数然后我们求出它的损失函数, 注意:这里的n和m均为数据集的长度,写的时候忘了注意,前面的theta0-th