Pytorch实现简单自定义网络层的方法
作者:ting_qifengl 发布时间:2021-01-13 16:02:55
前言
Pytorch、Tensoflow等许多深度学习框架集成了大量常见的网络层,为我们搭建神经网络提供了诸多便利。但在实际工作中,因为项目要求、研究需要或者 * 文需要等等,大家一般都会需要自己发明一个现在在深度学习框架中还不存在的层。 在这些情况下,就必须构建自定义层。
博主在学习了沐神的动手学深度学习这本书之后,学到了许多东西。这里记录一下书中基于Pytorch实现简单自定义网络层的方法,仅供参考。
一、不带参数的层
首先,我们构造一个没有任何参数的自定义层,要构建它,只需继承基础层类并实现前向传播功能。
import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X):
return X - X.mean()
输入一些数据,验证一下网络是否能正常工作:
layer = CenteredLayer()
print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))
输出结果如下:
tensor([-2., -1., 0., 1., 2.])
运行正常,表明网络没有问题。
现在将我们自建的网络层作为组件合并到更复杂的模型中,并输入数据进行验证:
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
print(Y.mean()) # 因为模型参数较多,输出也较多,所以这里输出Y的均值,验证模型可运行即可
结果如下:
tensor(-5.5879e-09, grad_fn=<MeanBackward0>)
二、带参数的层
这里使用内置函数来创建参数,这些函数可以提供一些基本的管理功能,使用更加方便。
这里实现了一个简单的自定义的全连接层,大家可根据需要自行修改即可。
class MyLinear(nn.Module):
def __init__(self, in_units, units):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_units, units))
self.bias = nn.Parameter(torch.randn(units,))
def forward(self, X):
linear = torch.matmul(X, self.weight.data) + self.bias.data
return F.relu(linear)
接下来实例化类并访问其模型参数:
linear = MyLinear(5, 3)
print(linear.weight)
结果如下:
Parameter containing:
tensor([[-0.3708, 1.2196, 1.3658],
[ 0.4914, -0.2487, -0.9602],
[ 1.8458, 0.3016, -0.3956],
[ 0.0616, -0.3942, 1.6172],
[ 0.7839, 0.6693, -0.8890]], requires_grad=True)
而后输入一些数据,查看模型输出结果:
print(linear(torch.rand(2, 5)))
# 结果如下
tensor([[1.2394, 0.0000, 0.0000],
[1.3514, 0.0968, 0.6667]])
我们还可以使用自定义层构建模型,使用方法与使用内置的全连接层相同。
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
print(net(torch.rand(2, 64)))
# 结果如下
tensor([[4.1416],
[0.2567]])
三、总结
我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。
在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。
层可以有局部参数,这些参数可以通过内置函数创建。
四、参考
《动手学深度学习》 — 动手学深度学习 2.0.0-beta0 documentation
https://zh-v2.d2l.ai/
附:pytorch获取网络的层数和每层的名字
#创建自己的网络
import models
model = models.__dict__["resnet50"](pretrained=True)
for index ,(name, param) in enumerate(model.named_parameters()):
? ? print( str(index) + " " +name)
结果如下:
0 conv1.weight
1 bn1.weight
2 bn1.bias
3 layer1.0.conv1.weight
4 layer1.0.bn1.weight
5 layer1.0.bn1.bias
6 layer1.0.conv2.weight
7 layer1.0.bn2.weight
8 layer1.0.bn2.bias
9 layer1.0.conv3.weight
来源:https://blog.csdn.net/ting_qifengl/article/details/124870577
猜你喜欢
- 使用 Python 进行数据处理的时候,常常会遇到判断一个数是否在一个区间内的操作。我们可以使用 if else 进行判断,但是,既然使用了
- 安装好mysql后,在终端输入 mysql -u root -p 按回车,输入密码后提示access denied......ues pas
- * 表的建立关系数据库的主要特点之一就是用表的方式组织数据。表是SQL语言存放数据、查找数据以及更新数据的基本数据结构。在SQL语言中,表有
- python里使用正则表达式的组嵌套实例详解由于组本身是一个完整的正则表达式,所以可以将组嵌套在其他组中,以构建更复杂的表达式。下面的例子,
- sys模块 与 os包一样,也是对系统资源进行调用。功能同样也是非常丰富,接下来我们会对 sys模块的一些简单且常用的函数进行介绍,主要针对
- 用VBS语言实现的一个简单网页计算器,功能:可以进行加法、减法、乘法、除法、取反、开根号、及指数运算。虽然简单但是比起windows xp自
- LRU:least recently used,最近最少使用算法。它的使用场景是:在有限的空间中存储对象时,当空间满时,会按一定的原则删除原
- 用到了两个库,xlrd和xlwtxlrd是读excel,xlwt是写excel的库[/code]1)xlwd用到的方法:xlwt.Workb
- 代码如下,另存为asp文件,请传到你的服务器上就可以了马上测一下<%Response.Expires = 0Response.Expi
- 你是否对获得MySQL数据库与表的最基本命令的实际操作感到十分头疼?如果是这样子的话,以下的文章将会给你相应的解决方案,以下的文
- 遵循Web标准的思想,网页要表现出一种亲和力。那么,针对残障用户来说,其“阅读”器可不能读取图像上传递的信息的。所以我们会采用一种Using
- 图片人脸识别import cv2filepath = "img/xingye-1.png"img = cv2.imrea
- 1.安装Python-LDAP(python_ldap-2.4.25-cp27-none-win_amd64.whl)pip install
- 1.什么是并发编程并发编程是实现多任务协同处理,改善系统性能的方式。Python中实现并发编程主要依靠进程(Process):进程是计算机中
- 图像标注在计算机视觉中很重要,计算机视觉是一种技术,它允许计算机从数字图像或视频中获得高水平的理解力,并以人类的方式观察和解释视觉信息。注释
- 每个进行过较大型的ASP-Web应用程序设计的开发人员大概都有如下的经历:ASP代码与页面HTML混淆难分,业务逻辑与显示方式绞合,使得代码
- Cookie用于服务器实现会话,用户登录及相关功能时进行状态管理。要在用户浏览器上安装cookie,HTTP服务器向HTTP响应添加类似以下
- 原文地址:30 Days of Mootools 1.2 Tutorials - Day 21 - Classes part
- urllib是python的一个获取url(Uniform Resource Locators,统一资源定址符)了,可以利用它来抓取远程的数
- 本文实例为大家分享了python实现学生成绩测评系统的具体代码,供大家参考,具体内容如下1、问题描述(功能要求): 根据实验指导书