Pytorch实现简单自定义网络层的方法
作者:ting_qifengl 发布时间:2021-01-13 16:02:55
前言
Pytorch、Tensoflow等许多深度学习框架集成了大量常见的网络层,为我们搭建神经网络提供了诸多便利。但在实际工作中,因为项目要求、研究需要或者 * 文需要等等,大家一般都会需要自己发明一个现在在深度学习框架中还不存在的层。 在这些情况下,就必须构建自定义层。
博主在学习了沐神的动手学深度学习这本书之后,学到了许多东西。这里记录一下书中基于Pytorch实现简单自定义网络层的方法,仅供参考。
一、不带参数的层
首先,我们构造一个没有任何参数的自定义层,要构建它,只需继承基础层类并实现前向传播功能。
import torch
import torch.nn.functional as F
from torch import nn
class CenteredLayer(nn.Module):
def __init__(self):
super().__init__()
def forward(self, X):
return X - X.mean()
输入一些数据,验证一下网络是否能正常工作:
layer = CenteredLayer()
print(layer(torch.FloatTensor([1, 2, 3, 4, 5])))
输出结果如下:
tensor([-2., -1., 0., 1., 2.])
运行正常,表明网络没有问题。
现在将我们自建的网络层作为组件合并到更复杂的模型中,并输入数据进行验证:
net = nn.Sequential(nn.Linear(8, 128), CenteredLayer())
Y = net(torch.rand(4, 8))
print(Y.mean()) # 因为模型参数较多,输出也较多,所以这里输出Y的均值,验证模型可运行即可
结果如下:
tensor(-5.5879e-09, grad_fn=<MeanBackward0>)
二、带参数的层
这里使用内置函数来创建参数,这些函数可以提供一些基本的管理功能,使用更加方便。
这里实现了一个简单的自定义的全连接层,大家可根据需要自行修改即可。
class MyLinear(nn.Module):
def __init__(self, in_units, units):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_units, units))
self.bias = nn.Parameter(torch.randn(units,))
def forward(self, X):
linear = torch.matmul(X, self.weight.data) + self.bias.data
return F.relu(linear)
接下来实例化类并访问其模型参数:
linear = MyLinear(5, 3)
print(linear.weight)
结果如下:
Parameter containing:
tensor([[-0.3708, 1.2196, 1.3658],
[ 0.4914, -0.2487, -0.9602],
[ 1.8458, 0.3016, -0.3956],
[ 0.0616, -0.3942, 1.6172],
[ 0.7839, 0.6693, -0.8890]], requires_grad=True)
而后输入一些数据,查看模型输出结果:
print(linear(torch.rand(2, 5)))
# 结果如下
tensor([[1.2394, 0.0000, 0.0000],
[1.3514, 0.0968, 0.6667]])
我们还可以使用自定义层构建模型,使用方法与使用内置的全连接层相同。
net = nn.Sequential(MyLinear(64, 8), MyLinear(8, 1))
print(net(torch.rand(2, 64)))
# 结果如下
tensor([[4.1416],
[0.2567]])
三、总结
我们可以通过基本层类设计自定义层。这允许我们定义灵活的新层,其行为与深度学习框架中的任何现有层不同。
在自定义层定义完成后,我们就可以在任意环境和网络架构中调用该自定义层。
层可以有局部参数,这些参数可以通过内置函数创建。
四、参考
《动手学深度学习》 — 动手学深度学习 2.0.0-beta0 documentation
https://zh-v2.d2l.ai/
附:pytorch获取网络的层数和每层的名字
#创建自己的网络
import models
model = models.__dict__["resnet50"](pretrained=True)
for index ,(name, param) in enumerate(model.named_parameters()):
? ? print( str(index) + " " +name)
结果如下:
0 conv1.weight
1 bn1.weight
2 bn1.bias
3 layer1.0.conv1.weight
4 layer1.0.bn1.weight
5 layer1.0.bn1.bias
6 layer1.0.conv2.weight
7 layer1.0.bn2.weight
8 layer1.0.bn2.bias
9 layer1.0.conv3.weight
来源:https://blog.csdn.net/ting_qifengl/article/details/124870577


猜你喜欢
- 一.图像金字塔原理上一篇文章讲解的图像采样处理可以降低图像的大小,本文将补充图像金字塔知识,了解专门用于图像向上采样和向下采样的pyrUp(
- 方法: 使用urlencode函数urllib.request.urlopen()import urllib.requestimport u
- 用途logging模块是Python的内置模块,主要用于输出运行日志,可以灵活配置输出日志的各项信息。基本使用方法logging.basic
- 前言本文主要给大家介绍了关于微信小程序自定义导航的相关内容,详细代码请见github,请点击地址 (本地下载),其中有原生小程序的
- Transact-SQL(又称T-SQL),是在Microsoft SQL Server和Sybase SQL
- 下面步骤展示的是如何经过VirtualBox管理器,使得pycharm和ubuntu中的项目环境连接对应起来!如果你有属于自己的服务器,核心
- 目录selenium模块selenium基本概念基本使用基于浏览器自动化的操作selenium处理iframe:selenium模拟登陆QQ
- python一直对中文支持的不好,最近老遇到编码问题,而且几乎没有通用的方案来解决这个问题,但是对常见的方法都试过之后,发现还是可以解决的,
- 在使用echarts的自定义饼图Customized Pie时,定义的动态数据会发生颜色无法渲染的问题,如下图所示:该图表的颜色是根据ite
- 1、去除一个数组中的重复元素:使用grep函数代码片段: 代码:my @array = ( 'a', 'b'
- 废话不多说,看代码吧!'''待完善。此代码实现了,根据标注文本的属性,数值,位置,及 容差,去判断 设计 和 实测两
- 示例标准线程多进程,生产者/消费者示例:Worker越多,问题越大# -*- coding: utf8 -*-import osimport
- 本文的爬虫教程分为四部: 1.从哪爬 where &nbs
- 我在程序中加入了分数显示,三种特殊食物,将贪吃蛇的游戏逻辑写到了SnakeGame的类中,而不是在Snake类中。特殊食物:1.绿色:普通,
- PyQt5是python中一个非常实用的GUI编程模块,功能十分强大。刚刚学完了Pyqt的编程,就迫不及待的写出了一个电子词典GUI程序。整
- 网上有很多方法能够过去到IP地址归属地的脚本,但是我发现淘宝IP地址库的信息更详细些,所以用shell写个脚本来处理日常工作中一些IP地址分
- 根据不同配置文件调用不同的验证函数检查输入。可以根据需求更改验证函数的逻辑。def VerifyData(func):
- 本文介绍了一系列安装教程,具体如下1.安装Python版本选择是3.5.1,因为网上有些深度学习实例用的就是这个版本,跟他们一样的话可以避免
- 使用Python实现了一下我们同事的C++高斯投影正反算,实际跑通,可用。#!/ usr/bin/python# -*- coding:ut
- HttpRequest.FILES表单上传的文件对象存储在类字典对象request.FILES中,表单格式需为multipart/form-