对Pytorch中nn.ModuleList 和 nn.Sequential详解
作者:ustc_lijia 发布时间:2023-07-04 06:54:46
简而言之就是,nn.Sequential类似于Keras中的贯序模型,它是Module的子类,在构建数个网络层之后会自动调用forward()方法,从而有网络模型生成。而nn.ModuleList仅仅类似于pytho中的list类型,只是将一系列层装入列表,并没有实现forward()方法,因此也不会有网络模型产生的副作用。
需要注意的是,nn.ModuleList接受的必须是subModule类型,例如:
nn.ModuleList(
[nn.ModuleList([Conv(inp_dim + j * increase, oup_dim, 1, relu=False, bn=False) for j in range(5)]) for i in
range(nstack)])
其中,二次嵌套的list内部也必须额外使用一个nn.ModuleList修饰实例化,否则会无法识别类型而报错!
摘录自
nn.ModuleList is just like a Python list. It was designed to store any desired number of nn.Module's. It may be useful, for instance, if you want to design a neural network whose number of layers is passed as input:
class LinearNet(nn.Module):
def __init__(self, input_size, num_layers, layers_size, output_size):
super(LinearNet, self).__init__()
self.linears = nn.ModuleList([nn.Linear(input_size, layers_size)])
self.linears.extend([nn.Linear(layers_size, layers_size) for i in range(1, self.num_layers-1)])
self.linears.append(nn.Linear(layers_size, output_size)
nn.Sequential allows you to build a neural net by specifying sequentially the building blocks (nn.Module's) of that net. Here's an example:
class Flatten(nn.Module):
def forward(self, x):
N, C, H, W = x.size() # read in N, C, H, W
return x.view(N, -1)
simple_cnn = nn.Sequential(
nn.Conv2d(3, 32, kernel_size=7, stride=2),
nn.ReLU(inplace=True),
Flatten(),
nn.Linear(5408, 10),
)
In nn.Sequential, the nn.Module's stored inside are connected in a cascaded way. For instance, in the example that I gave, I define a neural network that receives as input an image with 3 channels and outputs 10 neurons. That network is composed by the following blocks, in the following order: Conv2D -> ReLU -> Linear layer. Moreover, an object of type nn.Sequential has a forward() method, so if I have an input image x I can directly call y = simple_cnn(x) to obtain the scores for x. When you define an nn.Sequential you must be careful to make sure that the output size of a block matches the input size of the following block. Basically, it behaves just like a nn.Module
On the other hand, nn.ModuleList does not have a forward() method, because it does not define any neural network, that is, there is no connection between each of the nn.Module's that it stores. You may use it to store nn.Module's, just like you use Python lists to store other types of objects (integers, strings, etc). The advantage of using nn.ModuleList's instead of using conventional Python lists to store nn.Module's is that Pytorch is “aware” of the existence of the nn.Module's inside an nn.ModuleList, which is not the case for Python lists. If you want to understand exactly what I mean, just try to redefine my class LinearNet using a Python list instead of a nn.ModuleList and train it. When defining the optimizer() for that net, you'll get an error saying that your model has no parameters, because PyTorch does not see the parameters of the layers stored in a Python list. If you use a nn.ModuleList instead, you'll get no error.
来源:https://blog.csdn.net/xiaojiajia007/article/details/82118559


猜你喜欢
- 可怜我的C盘本来只有8.XG,所以不得不卸载掉它。卸载掉本身没啥问题,只是昨晚突然发现 Sql Server 2008 R2 Managem
- 一、PyTorch 检查模型梯度是否可导当我们构建复杂网络模型或在模型中加入复杂操作时,可能会需要验证该模型或操作是否可导,即模型是否能够优
- 本文实例讲述了Python保存最后N个元素的方法。分享给大家供大家参考,具体如下:问题:希望在迭代或是其他形式的处理过程中对最后几项记录做一
- 1、jsp前端<%-- Created by IntelliJ IDEA. User: Lenovo Date: 2020/6/19
- 流读写很多时候,数据读写不一定是文件,也可以在内存中读写。1、StringIO:在内存中读写str。要把str写入StringIO,我们需要
- 目录前言为什么要用numpy数组的创建生成均匀分布的array:生成特殊数组获取数组的属性数组索引,切片,赋值数组操作输出数组总结前言Num
- 目录前言连接管理额外连接管理端口总结前言下面这个报错,相信大多数童鞋都遇见过;那么碰到这个问题,我们应该怎么办呢?在MySQL 5.7及之前
- 首先给出展示结果,大体就是检测工业板子是否出现。采取检测的方法比较简单,用的OpenCV的模板检测。大体思路opencv读取视频将视频分割为
- 词云图from pyecharts.charts import WordClouddef word1(): words= [ &
- 简单的小练习,实现将一个指定列表中的数值进行转化,对于其中的非负数不作处理,对于负数需要转化为制定的数值,很简单就不多说了,下面是具体的实现
- 反射指的是运行时动态的获取变量的相关信息1. reflect 包类型是变量,类别是常量reflect.TypeOf,获取变量的类型,返回re
- 题目要求1.后台管理员只有一个用户:admin, 密码: admin2.当管理员登陆成功后,可以管理前台会员信息。3.会员信息管
- 语音识别是人工智能中的一个领域,它允许计算机理解人类语音并将其转换为文本。该技术用于 Alexa 和各种聊天机器人应用程序等设备。而我们最常
- 一、关于XML解析XML在Java应用程序里变得越来越重要, 广泛应用于数据存储和交换. 比如我们常见的配置文件,都是以XML方式存储的.
- 在 Java 中打印当前线程的方法栈,可以用 kill -3 命令向 JVM 发送一个 OS 信号,JVM 捕捉以后会自动 dump 出来;
- 本文实例讲述了python简单文本处理的方法。分享给大家供大家参考。具体如下:由于有多线程的影响,c++项目打印出来的时间顺序不一致,导致不
- 目录property属性property属性的定义和调用要注意一下几点:具体实例property属性的有两种方式装饰器方式旧式类新式类注意类
- 前面简单提到了 Python 模拟登录的程序,但是没写清楚,这里再补上一个带注释的 Python 模拟登录的示例程序。简单说一下流程:先用c
- Notebook 修改字体和大小原理很简单,就是更改CSS文件原本的字体很难看,尤其是 引号😡我推荐两款字体,Consolas 和 Fira
- 环境介绍系统环境:Windows 10Python版本:Python 3.5必备包:无 运行Python脚本:.bat文件在Win