对pytorch网络层结构的数组化详解
作者:库页 发布时间:2023-09-02 12:10:09
标签:pytorch,网络层,结构,数组化
最近再写openpose,它的网络结构是多阶段的网络,所以写网络的时候很想用列表的方式,但是直接使用列表不能将网络中相应的部分放入到cuda中去。
其实这个问题很简单的,使用moduleList就好了。
1 我先是定义了一个函数,用来根据超参数,建立一个基础网络结构
stage = [[3, 3, 3, 1, 1], [7, 7, 7, 7, 7, 1, 1]]
branches_cfg = [[[128, 128, 128, 512, 38], [128, 128, 128, 512, 19]],
[[128, 128, 128, 128, 128, 128, 38], [128, 128, 128, 128, 128, 128, 19]]]
# used for add two branches as well as adapt to certain stage
def add_extra(i, branches_cfg, stage):
"""
only add CNN of brancdes S & L in stage Ti at the end of net
:param in_channels:the input channels & out
:param stage: size of filter
:param branches_cfg: channels of image
:return:list of layers
"""
in_channels = i
layers = []
for k in range(len(stage)):
padding = stage[k] // 2
conv2d = nn.Conv2d(in_channels, branches_cfg[k], kernel_size=stage[k], padding=padding)
layers += [conv2d, nn.ReLU(inplace=True)]
in_channels = branches_cfg[k]
return layers
2 然后用普通列表装载他们
conf_bra_list = []
paf_bra_list = []
# param for branch network
in_channels = 128
for i in range(all_stage):
if i > 0:
branches = branches_cfg[1]
conv_sz = stage[1]
else:
branches = branches_cfg[0]
conv_sz = stage[0]
conf_bra_list.append(nn.Sequential(*add_extra(in_channels, branches[0], conv_sz)))
paf_bra_list.append(nn.Sequential(*add_extra(in_channels, branches[1], conv_sz)))
in_channels = 185
3 再然后,使用moduleList方法,把普通列表专成pytorch下的模块
# to list
self.conf_bra = nn.ModuleList(conf_bra_list)
self.paf_bra = nn.ModuleList(paf_bra_list)
4 最后,调用就好了
out_0 = x
# the base transform
for k in range(len(self.vgg)):
out_0 = self.vgg[k](out_0)
# local name space
name = locals()
confs = []
pafs = []
outs = []
length = len(self.conf_bra)
for i in range(length):
name['conf_%s' % (i + 1)] = self.conf_bra[i](name['out_%s' % i])
name['paf_%s' % (i + 1)] = self.paf_bra[i](name['out_%s' % i])
name['out_%s' % (i + 1)] = torch.cat([name['conf_%s' % (i + 1)], name['paf_%s' % (i + 1)], out_0], 1)
confs.append('conf_%s' % (i + 1))
pafs.append('paf_%s' % (i + 1))
outs.append('out_%s' % (i + 1))
5 顺便装了一下,使用了python局部变量命名空间,name = locals(),其实完全使用普通列表保存变量就好了,高兴就好。
来源:https://blog.csdn.net/daniaokuye/article/details/78827436


猜你喜欢
- 使用drop函数删除dataframe的某列或某行数据:drop(labels, axis=0, level=None, inplace=F
- 一、前言对于一个桌面应用来说,有时候单独一个窗口用户使用起来会不太方便,比方说写日报或者查看文件等,若是在同一窗口内,我只能做一件事,不能边
- 如何在约定时间显示特定的提示信息?<%Function Greeting()
- 我就废话不多说了,大家还是直接看代码吧~func ReadLine(fileName string) ([]string,error){f,
- 之前对bottle做过不少的介绍,也写过一些文章来说明bottle的缺点,最近发现其实之前有些地方说的不太公平,所以趁此机会也来更正一下。&
- 这篇文章主要介绍了python构造函数init实例方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- 我就废话不多说了,大家还是直接看代码吧~import tensorflow as tffrom sklearn.metrics import
- 1、建表语句:CREATE TABLE `employees` ( `emp_no` int(11) NOT NULL, `birth_da
- 使用 datetime 模块中的 timedelta() 方法将天数添加到日期中,例如 result_1 = date_1 + timede
- 应用场景:在进行多选的时候一般默认显示第一个。实现方法:纯vue实现例子:<span v-for="(one,index)
- 这篇文章主要介绍了Django多进程滚动日志问题解决方案,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 前言大家好,说起动态条形图,之前推荐过两个 Python 库,比如Bar Chart Race、Pandas_Alive,都可以实现。今天就
- SQLite3数据库的介绍和使用(面向业务编程-数据库)SQLite3介绍SQLite是一种用C语言实现的的SQL数据库它的特点有:轻量级、
- 毋庸置疑,Python越来越被认可为程序员新时代的风口语言,Python的应用能力是成为一代码农大神的必要项。首先告诉你的是,零基础学习开始
- 本文实例讲述了python 实现的发送邮件模板。分享给大家供大家参考,具体如下:##发送普通txt文件(与发送html邮件不同的是邮件内容设
- 实现效果图如下:当我点击 + 按钮时,会添加一行输入框组;当点击 - 按钮时,会删除这一行输入框组html代码如下:<div clas
- 前言今天就来理一理session、cookie、token这三者之间的关系!1.为什么会有它们?我们都知道 HTTP 协议是无状态的,所谓的
- 之前安装mysql 5.7.12时未做总结,换新电脑,补上安装记录,安装的时候,找了些网友的安装记录,发现好多坑(一)mysql 5.7.1
- 1.用户输入月份,判断这个月是哪个季节month = int(input('Month:'))if month in [3,
- 1. 安装pip3yum install python34-pip2. 安装python34develyum install python3