使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
作者:sjtu_leexx 发布时间:2023-05-04 05:09:51
本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。
构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。
加载预训练alexnet之后,可以print出来查看模型的结构及信息:
model = models.alexnet(pretrained=True)
print(model)
分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。
下面放出完整的搭建代码:
import torch.nn as nn
from torchvision import models
class BuildAlexNet(nn.Module):
def __init__(self, model_type, n_output):
super(BuildAlexNet, self).__init__()
self.model_type = model_type
if model_type == 'pre':
model = models.alexnet(pretrained=True)
self.features = model.features
fc1 = nn.Linear(9216, 4096)
fc1.bias = model.classifier[1].bias
fc1.weight = model.classifier[1].weight
fc2 = nn.Linear(4096, 4096)
fc2.bias = model.classifier[4].bias
fc2.weight = model.classifier[4].weight
self.classifier = nn.Sequential(
nn.Dropout(),
fc1,
nn.ReLU(inplace=True),
nn.Dropout(),
fc2,
nn.ReLU(inplace=True),
nn.Linear(4096, n_output))
#或者直接修改为
# model.classifier[6]==nn.Linear(4096,n_output)
# self.classifier = model.classifier
if model_type == 'new':
self.features = nn.Sequential(
nn.Conv2d(3, 64, 11, 4, 2),
nn.ReLU(inplace = True),
nn.MaxPool2d(3, 2, 0),
nn.Conv2d(64, 192, 5, 1, 2),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 0),
nn.Conv2d(192, 384, 3, 1, 1),
nn.ReLU(inplace = True),
nn.Conv2d(384, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 0))
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(9216, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, n_output))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
out = self.classifier(x)
return out
微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。
网络搭好之后可以小小的测试一下以检验维度是否正确。
import numpy as np
from torch.autograd import Variable
import torch
if __name__ == '__main__':
model_type = 'pre'
n_output = 10
alexnet = BuildAlexNet(model_type, n_output)
print(alexnet)
x = np.random.rand(1,3,224,224)
x = x.astype(np.float32)
x_ts = torch.from_numpy(x)
x_in = Variable(x_ts)
y = alexnet(x_in)
这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。
输出y.data.numpy()可得10维输出,表明网络搭建正确。
来源:https://blog.csdn.net/sjtuxx_lee/article/details/83048006


猜你喜欢
- 废话不多说,直接开始拉~~~我们总共有 6 只海龟,颜色不同,它们以随机长度移动。首先,我们应该通过输入乌龟的颜色来押注乌龟。第一个越线的乌
- 1.跨域原理1. 首先浏览器安全策略限制js ajax跨域访问服务器2. 如果服务器返回的头部信息中有当前域:// 允许 http://lo
- 1.在线定制下载echartshttps://echarts.apache.org/zh/builder.html2.创建一个django项
- 本文实例为大家分享了TensorFlow实现Logistic回归的具体代码,供大家参考,具体内容如下1.导入模块import numpy a
- python中return的用法1、return语句就是把执行结果返回到调用的地方,并把程序的控制权一起返回程序运行到所遇到的第一个retu
- 这次主要记录python-Parser的用法,以及可能遇到的系列操作。1 前言if __name__ == "__main__&q
- 前 言在开发高并发系统时,我们可能会遇到接口访问频次过高,为了保证系统的高可用和稳定性,这时候就需要做流量限制,你可能是用的 Ng
- 前言:vue调用本地摄像头实现拍照功能,由于调用摄像头有使用权限,只能在本地运行,线上需用https域名才可以使用。实现效果:1、摄像头效果
- 大家知道,mailto是网页设计制作中的一个非常实用的html标签,许多拥有个人网页的朋友都喜欢在网站的醒目位置处写上自己的电子邮件地址,这
- 在项目开发过程中,遇到如下用户体验提升需求:需要实现错误提示时根据后台返回错误列表信息,换行展示。实现方式如下:通过F12元素查看,在对应的
- 前言这篇文章主要介绍了使用Python画了一棵圣诞树的实例代码,本文通过实例代码给大家介绍的非常详细,对大家的学习或工作具有一定的参考借鉴价
- 日常运维工作中,通常是邮件报警机制,但邮件可能不被及时查看,导致问题出现得不到及时有效处理。所以想到用Python实现发短信功能,当监控到问
- 本文实例讲述了Python切片工具pillow用法。分享给大家供大家参考,具体如下:切片:使用切片将源图像分成许多的功能区域因为要对图片进行
- 一,未使用 git add 缓存代码时。可以使用 git checkout -- filepathname (比如: git checkou
- JDBC之C3P0数据库连接池,供大家参考,具体内容如下1 首先在src中创建c3p0-config.xml 配置文件,文件中内容如下(首先
- 需要注意的是:更改完源程序.c文件,需要对整个项目重新编译、make install,对已经生成的文件进行更新,类似于之前VS中在一个类中增
- 在一群里有朋友发问,有时间,也就看看了,不多说了,看图了:用一般的 select .... order 排序出来,就如下图了,是
- 前言本人曾对 Vuex 作过详细介绍,但是今天去回顾的时候发现文章思路有些繁琐,不容易找到重点。于是,在下班前几分钟,我对其重新梳理了一遍。
- 由于内容过多,大家可以通过ctrl+F搜索即可IE浏览器id 后缀名 php识别出的文件类型0 gif image/gif1 jpg ima
- SQL查询某字段的值为空sql中字段的默认有NULL和另一种空白的形式如何取查询这两种存在的记录呢?空白值查询:SELECT * FROM