pytorch常用函数定义及resnet模型修改实例
作者:MapleTx's 发布时间:2022-09-18 08:19:19
模型定义常用函数
利用nn.Parameter()设计新的层
import torch
from torch import nn
class MyLinear(nn.Module):
def __init__(self, in_features, out_features):
super().__init__()
self.weight = nn.Parameter(torch.randn(in_features, out_features))
self.bias = nn.Parameter(torch.randn(out_features))
def forward(self, input):
return (input @ self.weight) + self.bias
nn.Sequential
一个有序的容器,神经网络模块将按照在传入构造器的顺序依次被添加到计算图中执行,同时以神经网络模块为元素的有序字典也可以作为传入参数。Sequential适用于快速验证结果,简单易读,但使用Sequential也会使得模型定义丧失灵活性,比如需要在模型中间加入一个外部输入时就不适合用Sequential的方式实现。
net = nn.Sequential(
('fc1',MyLinear(4, 3)),
('act',nn.ReLU()),
('fc2',MyLinear(3, 1))
)
nn.ModuleList()
ModuleList 接收一个子模块(或层,需属于nn.Module类)的列表作为输入,然后也可以类似List那样进行append和extend操作。同时,子模块或层的权重也会自动添加到网络中来。
net = nn.ModuleList([nn.Linear(784, 256), nn.ReLU()])
net.append(nn.Linear(256, 10)) # # 类似List的append操作
print(net[-1]) # 类似List的索引访问
print(net)
Linear(in_features=256, out_features=10, bias=True)
ModuleList(
(0): Linear(in_features=784, out_features=256, bias=True)
(1): ReLU()
(2): Linear(in_features=256, out_features=10, bias=True)
)
要特别注意的是,nn.ModuleList 并没有定义一个网络,它只是将不同的模块储存在一起。
ModuleList中元素的先后顺序并不代表其在网络中的真实位置顺序,需要经过forward函数指定各个层的先后顺序后才算完成了模型的定义。
具体实现时用for循环即可完成:
class model(nn.Module):
def __init__(self, ...):
super().__init__()
self.modulelist = ...
...
def forward(self, x):
for layer in self.modulelist:
x = layer(x)
return x
nn.ModuleDict()
ModuleDict和ModuleList的作用类似,只是ModuleDict能够更方便地为神经网络的层添加名称。
net = nn.ModuleDict({
'linear': nn.Linear(784, 256),
'act': nn.ReLU(),
})
net['output'] = nn.Linear(256, 10) # 添加
print(net['linear']) # 访问
print(net.output)
print(net)
Linear(in_features=784, out_features=256, bias=True)
Linear(in_features=256, out_features=10, bias=True)
ModuleDict(
(act): ReLU()
(linear): Linear(in_features=784, out_features=256, bias=True)
(output): Linear(in_features=256, out_features=10, bias=True)
)
ModuleList和ModuleDict在某个完全相同的层需要重复出现多次时,非常方便实现,可以”一行顶多行“;当我们需要之前层的信息的时候,比如 ResNets 中的残差计算,当前层的结果需要和之前层中的结果进行融合,一般使用 ModuleList/ModuleDict 比较方便。
nn.Flatten
展平输入的张量: 28x28 -> 784
input = torch.randn(32, 1, 5, 5)
m = nn.Sequential(
nn.Conv2d(1, 32, 5, 1, 1),
nn.Flatten()
)
output = m(input)
output.size()
模型修改案例
有了上面的一些常用方法,我们可以修改现有的一些开源模型,这里通过介绍修改模型层、添加额外输入的案例来帮助我们更好地理解。
修改模型层
以pytorch官方视觉库torchvision预定义好的模型ResNet50为例,探索如何修改模型的某一层或者某几层。
我们先看看模型的定义:
import torchvision.models as models
net = models.resnet50()
print(net)
ResNet(
(conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(relu): ReLU(inplace=True)
(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
(layer1): Sequential(
(0): Bottleneck(
(conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
..............
(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))
(fc): Linear(in_features=2048, out_features=1000, bias=True)
)
为了适配ImageNet,fc层输出是1000,若需要用这个resnet模型去做一个10分类的问题,就应该修改模型的fc层,将其输出节点数替换为10。另外,我们觉得一层全连接层可能太少了,想再加一层。
可以做如下修改:
from collections import OrderedDict
classifier = nn.Sequential(OrderedDict([('fc1', nn.Linear(2048, 128)),
('relu1', nn.ReLU()),
('dropout1',nn.Dropout(0.5)),
('fc2', nn.Linear(128, 10)),
('output', nn.Softmax(dim=1))
]))
net.fc = classifier # 将模型(net)最后名称为“fc”的层替换成了我们自己定义的名称为“classifier”的结构
添加外部输入
有时候在模型训练中,除了已有模型的输入之外,还需要输入额外的信息。比如在CNN网络中,我们除了输入图像,还需要同时输入图像对应的其他信息,这时候就需要在已有的CNN网络中添加额外的输入变量。
基本思路是:将原模型添加输入位置前的部分作为一个整体,同时在forward中定义好原模型不变的部分、添加的输入和后续层之间的连接关系,从而完成模型的修改。
我们以torchvision的resnet50模型为基础,任务还是10分类任务。不同点在于,我们希望利用已有的模型结构,在倒数第二层增加一个额外的输入变量add_variable来辅助预测。
具体实现如下:
class Model(nn.Module):
def __init__(self, net):
super(Model, self).__init__()
self.net = net
self.relu = nn.ReLU()
self.dropout = nn.Dropout(0.5)
self.fc_add = nn.Linear(1001, 10, bias=True)
self.output = nn.Softmax(dim=1)
def forward(self, x, add_variable):
x = self.net(x)
# add_variable (batch_size, )->(batch_size, 1)
x = torch.cat((self.dropout(self.relu(x)), add_variable.unsqueeze(1)),1)
x = self.fc_add(x)
x = self.output(x)
return x
修改好的模型结构进行实例化,就可以使用
import torchvision.models as models
net = models.resnet50()
model = Model(net).cuda()
# 使用时输入两个inputs
outputs = model(inputs, add_var)
参考资料:
Pytorch模型定义与深度学习自查手册
来源:https://www.cnblogs.com/qftie/p/16324068.html
猜你喜欢
- 使用python实现双向链表,供大家参考,具体内容如下双向链表: 指的是讲数据链接在一起,每个数据是一个节点,每一个节点都有一个数据区,两个
- 在ASP中,如何创建DSN? 见下:<HTML><HEAD><META&n
- Python内存管理一、对象池1.小整数池系统默认创建好的,等着你使用概述:整数在程序中的使用非常广泛,Python为了优化速度,使用了小整
- Python continue语句返回while循环的开始。Continue语句拒绝在该循环的当前迭代中的其余语句执行并移动控制
- 从2003年到现在,从ACCESS到SQL SERVER的使用。在ACCESS中没有存储过程的概念。在使用过程中,发现ACCESS与SQL
- TensorFLow能够识别的图像文件,可以通过numpy,使用tf.Variable或者tf.placeholder加载进tensorfl
- WordPress 的插件机制实际上只的就是这个 Hook 了,它中文被翻译成钩子,允许你参与 WordPress 核心的运行,是一个非常棒
- 日常工作中需要对比两个Excel工作表中的数据差异是很不方便的,使用python来做就比较简单了!我们的思路是通过读取两个Excel的数据,
- 我们在升级系统的时候,经常碰到需要更新服务器端数据结构等操作,之前的方式是通过手工编写alter sql脚本处理,经常会发现遗漏,导致程序发
- 背景基于现在微服务或者服务化的思想,我们大部分的业务逻辑处理函数都是长这样的:比如grpc服务端:func (s *Service) Get
- 环境使用Python 3.8–> 解释器 <执行python代码>Pycharm–
- 本文实例讲述了js+php实现静态页面实时调用用户登陆状态的方法。分享给大家供大家参考。具体分析如下:在程序开发中,经常会把页面做成html
- 我们知道django的orm想实现自增,可以直接使用AutoField字段既可以实现,但是这种情况必须要求此字段是主键,但是我们知道主键只能
- 使用tensorflow训练模型的时候,模型持久化对我们来说非常重要。如果我们的模型比较复杂,需要的数据比较多,那么在模型的训练时间会耗时很
- 如下所示:# -*- coding:utf-8 -*-import sysimport osfrom glob import globimp
- 模型训练时GPU利用率太低的原因最近在训练SSD模型时发现GPU的利用率只有8%,而CPU的利用率却非常高。后来了解到,一般使用CPU进行数
- Dreaweaver MX 2004 中增加了图片处理功能,如图片亮度和对比度的调节、图片的锐化效果等
- 1.介绍在 Golang 语言项目开发中,经常会遇到数据排序问题。Golang 语言标准库 sort 包,为我们提供了数据排序的功能,我们可
- 首先说说框架(Frameworks)这个词,框架就是为我们提供了一个平台一个运行环境,在如此统一的前提下我们做相关开发才能“有章可循”,要充
- virtualenv 是用来创建一个虚拟的python环境的第三方包,一个专属于项目的python环境。安装virtualenv(请确保py