Pytorch 抽取vgg各层并进行定制化处理的方法
作者:xiaoxifei 发布时间:2023-01-28 16:30:15
标签:Pytorch,vgg,定制化,处理
工作中有时候需要对vgg进行定制化处理,比如有些时候需要借助于vgg的层结构,但是需要使用的是2 channels输入,等等需求,这时候可以使用vgg的原始结构用class重写一遍,但是这样的方式比较慢,并且容易出错,下面给出一种比较简单的方式
def define_vgg(vgg,input_channels,endlayer,use_maxpool=False):
vgg_ad = copy.deepcopy(vgg)
model = nn.Sequential()
i = 0
for layer in list(vgg_ad.features):
if i > endlayer:
break
if isinstance(layer, nn.Conv2d) and i is 0:
name = "conv_" + str(i)
layer = nn.Conv2d(input_channels,
layer.out_channels,
layer.kernel_size,
stride = layer.stride,
padding=layer.padding)
model.add_module(name, layer)
if isinstance(layer, nn.Conv2d):
name = "conv_" + str(i)
model.add_module(name, layer)
if isinstance(layer, nn.ReLU):
name = "leakyrelu_" + str(i)
layer = nn.LeakyReLU(inplace=True)
model.add_module(name, layer)
if isinstance(layer, nn.MaxPool2d):
name = "pool_" + str(i)
if use_maxpool:
model.add_module(name, layer)
else:
avgpool = nn.AvgPool2d(kernel_size=layer.kernel_size, stride=layer.stride, padding=layer.padding)
model.add_module(name, avgpool)
i += 1
return model
函数输入项中的vgg 是直接使用的import torchvision.models.vgg16 传入的是vgg16 非预训练版本。end_layer 是需要提取的层数,这里使用了vgg.features 是指仅仅在vgg.features 上进行层的提取;也可以根据定制在classifier上进行提取。
下面是我的一个提取前7层的示例,可以使用pyCharm evaluate 上面函数返回的model,可以看到这个示例的情况,这里我的定制条件是输入通道为2 ,需要提取前7层,并且将ReLu更换为LeakyRelu。
Sequential(
(conv_0): Conv2d(2, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(leakyrelu_1): LeakyReLU(negative_slope=0.01, inplace)
(conv_2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(leakyrelu_3): LeakyReLU(negative_slope=0.01, inplace)
(pool_4): AvgPool2d(kernel_size=2, stride=2, padding=0)
(conv_5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(leakyrelu_6): LeakyReLU(negative_slope=0.01, inplace)
(conv_7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
)
来源:https://blog.csdn.net/xiaoxifei/article/details/86489948


猜你喜欢
- 前言在机器学习中,我们会经常和矩阵打交道。在矩阵的运算中,python默认的输出是浮点数,但是如果我们想要矩阵的元素以分数的形式显示,可以通
- 本文是对《Python Qt GUI快速编程》的第10章的例子剪贴板用Python3+PyQt5进行改写,分别对文本,图片和html文本的复
- 前言因近期进行时间序列分析时遇到了数据预处理中的缺失值处理问题,其中日期缺失和填充在网上没有找到较好较全资料,耗费了我一晚上工作时间,所以下
- php文件 <?php class xpathExtension{ public static function getNodes($
- 基本思路是使用opencv来把随机生成的字符,和随机生成的线段,放到一个随机生成的图像中去。虽然没有加复杂的形态学处理,但是目前看起来效果还
- 1、同级目录下调用若在程序 testone.py 中导入模块 testtwo.py , 则直接使用【import testtwo 或 fro
- xlsxwriter 简介用于以 Excel 2007+ XLSX 文件格式编写文件,相较之下 PhpSpreadsheet 支持更多的格式
- Pytorch把Tensor转化成图像可视化在调试程序的时候经常想把tensor可视化成来看看,可以这样操作:from torchvisio
- 我的机器不知为何,安装MySQL的时候,一到配置那一步就无休止的等待,只好结束任务,然而启动MySQL的时候出现1067错误提示
- 前言为了体现不加索引和添加索引的区别,需要使用百万级的数据,但是百万数据的表,如果使用一条条添加,特别繁琐又麻烦,这里使用存储过程快速添加数
- 网上找了很多资料,都不理想。其实ubuntu20以后的版本,很多功能都预装好了,安装django也没有以前的版本那么复杂。很简单,只需要几步
- 数据描述每条数据项储存在列表中,最后一列储存结果多条数据项形成数据集data=[[d1,d2,d3...dn,result],
- 动态 web 应用也会需要静态文件,通常是 CSS 和 JavaScript 文件。理想状况下, 我们已经配置好 Web 服务器来提供静态文
- 一、导入所需的库import turtleimport randomfrom math import *二、生成斐波那契数列斐波那契数列是指
- UDPUDP是面向无连接的通讯协议,UDP数据包括目的端口号和源端口号信息,由于通讯不需要连接,所以可以实现广播发送。 UDP传输数据时有大
- tensorflow中的conv2有padding=‘SAME'这个参数。吴恩达讲课中说到当padding=(f-1)/2(f为卷积
- 据了解绝大多数开发人员对于索引的理解都是一知半解,局限于大多数日常工作没有机会、也什么没有必要去关心、了解索引,实在哪天某个查询太慢了找到查
- 在python中我们可以使用openCV给图片添加水印,这里注意openCV无法添加汉字水印,添加汉字水印上可使用PIL库给图片添加水印一:
- 很多朋友使用Dreamweaver一段时间后,开始热衷于寻找各式各样的插件,追求各种各样的特效,而对于Dreamweaver中的基本功能反而
- pycharm的pytest功能在新建一个python文件时,比如名称是test_test_test.py,由于含有test,pycharm