Pytorch 抽取vgg各层并进行定制化处理的方法
作者:xiaoxifei 发布时间:2023-01-28 16:30:15
标签:Pytorch,vgg,定制化,处理
工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式
def define_vgg(vgg,input_channels,endlayer,use_maxpool=False):
vgg_ad = copy.deepcopy(vgg)
model = nn.Sequential()
i = 0
for layer in list(vgg_ad.features):
if i > endlayer:
break
if isinstance(layer, nn.Conv2d) and i is 0:
name = "conv_" + str(i)
layer = nn.Conv2d(input_channels,
layer.out_channels,
layer.kernel_size,
stride = layer.stride,
padding=layer.padding)
model.add_module(name, layer)
if isinstance(layer, nn.Conv2d):
name = "conv_" + str(i)
model.add_module(name, layer)
if isinstance(layer, nn.ReLU):
name = "leakyrelu_" + str(i)
layer = nn.LeakyReLU(inplace=True)
model.add_module(name, layer)
if isinstance(layer, nn.MaxPool2d):
name = "pool_" + str(i)
if use_maxpool:
model.add_module(name, layer)
else:
avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding)
model.add_module(name, avgpool)
i += 1
return model
函数输入项中的vgg 是直接使用的import torchvision.models.vgg16 传入的是vgg16 非预训练版本。end_layer 是需要提取的层数,这里使用了vgg.features 是指仅仅在vgg.features 上进行层的提取;也可以根据定制在classifier上进行提取。
下面是我的一个提取前7层的示例,可以使用pyCharm evaluate 上面函数返回的model,可以看到这个示例的情况,这里我的定制条件是输入通道为2 ,需要提取前7层,并且将ReLu更换为LeakyRelu。
Sequential(
(conv_0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(leakyrelu_1): LeakyReLU(negative_slope=0.01, inplace)
(conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(leakyrelu_3): LeakyReLU(negative_slope=0.01, inplace)
(pool_4): AvgPool2d(kernel_size=2, stride=2, padding=0)
(conv_5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(leakyrelu_6): LeakyReLU(negative_slope=0.01, inplace)
(conv_7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
来源:https://blog.csdn.net/xiaoxifei/article/details/86489948
0
投稿
猜你喜欢
- 本文实例讲述了Python编程之序列操作。分享给大家供大家参考,具体如下:#coding=utf8''''&
- Pythond 的函数是由一个新的语句编写,即def,def是可执行的语句--函数并不存在,直到Python运行了def后才存在。函数是通过
- 导语三月疫情原因,很多地方都封闭式管理了!在回家无聊的打酱油,小编今天给大伙带来了一波小游戏——全民
- CSS是制作网页效果必不可少的东西,字体的颜色定义、表格的样式定义、图片的特效等等都少不了它。但在Dr
- 1 什么是注释注释就是对代码的解释和说明,其目的是让人们能够更加轻松地了解代码。注释是编写程序时,写程序的人给一个语句、程序段、函数等的解释
- 接着上篇的内容,这里实现一个交易记录链,废话不多说,先看图:跟之前的逻辑类似,但也有少许不同,这里多了一个payloadhash,以及对pa
- Session每台电脑访问服务器,都有独立的session,key值都一样,内容不一样。1.session保存在服务器上。2.session
- 这篇文章主要介绍了微信小程序 云开发模糊查询实现详解,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友
- 开启Web服务1.基本方式Python中自带了简单的服务器程序,能较容易地打开服务。在python3中将原来的SimpleHTTPServe
- 测试环境为Windows 10 系统,Python3.7,转换需要提前安装pydub、ffmpeg,安装和加入环境变量配置方法自行解决,至于
- 一、总结apply —— 应用在 dataFrame 上,用于对 row 或者 column 进行计算applymap —— 应用在 dat
- 一、问题由来工作的局域网中,会接入很多设备,机器人上的网络设备就2个了,一个巨哥红外,一个海康可见光。机器人还有自身的ip。有时候机器人挂的
- 本文实例讲述了python统计文本文件内单词数量的方法。分享给大家供大家参考。具体实现方法如下:# count lines, sentenc
- 我就废话不多说了,直接上代码吧!>>> list1 = [1,2,3,4,4]>>> list2 = [
- LDAP(Light Directory Access Portocol)是轻量目录访问协议,基于X.500标准,支持TCP/IP。LDAP
- 本文实例讲述了Python高级编程之继承问题。分享给大家供大家参考,具体如下:多继承问题1.单独调用父类: 一个子类同时继承自多个父类,又称
- zip()的作用先看一下语法:zip(iter1 [,iter2 [...]]) —> zip objectPython的内置help
- 从两个优秀的世界各取所需,更高效的复用代码。想想就醉了,.NET和python融合了。“懒惰”的程序员们,还等什么?Jesse Smith为
- 内容摘要:本文介绍了使用js来实现下拉伸缩导航菜单的功能,并带有渐显的效果,值得收藏。正好这几天公司不忙,学校又没有事情,所以想抽空架一个个
- 介绍我们可以通过for循环来迭代list、tuple、dict、set、字符串,dict比较特殊dict的存储不是连续的,所以迭代(遍历)出