使用pytorch搭建AlexNet操作(微调预训练模型及手动搭建)
作者:sjtu_leexx 发布时间:2023-05-04 05:09:51
本文介绍了如何在pytorch下搭建AlexNet,使用了两种方法,一种是直接加载预训练模型,并根据自己的需要微调(将最后一层全连接层输出由1000改为10),另一种是手动搭建。
构建模型类的时候需要继承自torch.nn.Module类,要自己重写__ \_\___init__ \_\___方法和正向传递时的forward方法,这里我自己的理解是,搭建网络写在__ \_\___init__ \_\___中,每次正向传递需要计算的部分写在forward中,例如把矩阵压平之类的。
加载预训练alexnet之后,可以print出来查看模型的结构及信息:
model = models.alexnet(pretrained=True)
print(model)
分为两个部分,features及classifier,后续搭建模型时可以也写成这两部分,并且从打印出来的模型信息中也可以看出每一层的引用方式,便于修改,例如model.classifier[1]指的就是Linear(in_features=9216, out_features=4096, bias=True)这层。
下面放出完整的搭建代码:
import torch.nn as nn
from torchvision import models
class BuildAlexNet(nn.Module):
def __init__(self, model_type, n_output):
super(BuildAlexNet, self).__init__()
self.model_type = model_type
if model_type == 'pre':
model = models.alexnet(pretrained=True)
self.features = model.features
fc1 = nn.Linear(9216, 4096)
fc1.bias = model.classifier[1].bias
fc1.weight = model.classifier[1].weight
fc2 = nn.Linear(4096, 4096)
fc2.bias = model.classifier[4].bias
fc2.weight = model.classifier[4].weight
self.classifier = nn.Sequential(
nn.Dropout(),
fc1,
nn.ReLU(inplace=True),
nn.Dropout(),
fc2,
nn.ReLU(inplace=True),
nn.Linear(4096, n_output))
#或者直接修改为
# model.classifier[6]==nn.Linear(4096,n_output)
# self.classifier = model.classifier
if model_type == 'new':
self.features = nn.Sequential(
nn.Conv2d(3, 64, 11, 4, 2),
nn.ReLU(inplace = True),
nn.MaxPool2d(3, 2, 0),
nn.Conv2d(64, 192, 5, 1, 2),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 0),
nn.Conv2d(192, 384, 3, 1, 1),
nn.ReLU(inplace = True),
nn.Conv2d(384, 256, 3, 1, 1),
nn.ReLU(inplace=True),
nn.MaxPool2d(3, 2, 0))
self.classifier = nn.Sequential(
nn.Dropout(),
nn.Linear(9216, 4096),
nn.ReLU(inplace=True),
nn.Dropout(),
nn.Linear(4096, 4096),
nn.ReLU(inplace=True),
nn.Linear(4096, n_output))
def forward(self, x):
x = self.features(x)
x = x.view(x.size(0), -1)
out = self.classifier(x)
return out
微调预训练模型的思路为:直接保留原模型的features部分,重写classifier部分。在classifier部分中,我们实际需要修改的只有最后一层全连接层,之前的两个全连接层不需要修改,所以重写的时候需要把这两层的预训练权重和偏移保留下来,也可以像注释掉的两行代码里那样直接引用最后一层全连接层进行修改。
网络搭好之后可以小小的测试一下以检验维度是否正确。
import numpy as np
from torch.autograd import Variable
import torch
if __name__ == '__main__':
model_type = 'pre'
n_output = 10
alexnet = BuildAlexNet(model_type, n_output)
print(alexnet)
x = np.random.rand(1,3,224,224)
x = x.astype(np.float32)
x_ts = torch.from_numpy(x)
x_in = Variable(x_ts)
y = alexnet(x_in)
这里如果不加“x = x.astype(np.float32)”的话会报一个类型错误,感觉有点奇怪。
输出y.data.numpy()可得10维输出,表明网络搭建正确。
来源:https://blog.csdn.net/sjtuxx_lee/article/details/83048006
猜你喜欢
- 因为Python是自带文档,可以通过help函数来查询每一个系统函数的用法解释说明。一般来说,关键的使用方法和注意点在这个系统的文档中都说的
- 操作系统会为每一个创建的进程分配一个独立的地址空间,不同进程的地址空间是完全隔离的,因此如果不加其他的措施,他们完全感觉不到彼此的存在。那么
- 今天好不容易闲下来半天,所以和大家分享一下我之前总结的一套Web UI 设计命名规范,也就是网站用户界面设计(俗称网页设计)命名规范。这套规
- 在blueidea上看到movoin转的一个动态加载include文件代码,接着dnawo又修改了下,我用了dnawo修改后的版本,感觉挺好
- 本文实例讲述了python服务器与android客户端socket通信的方法。分享给大家供大家参考。具体实现方法如下:首先,服务器端使用py
- 要知道我们程序猿也是需要浪漫的,小博我之前在网上搜寻了很多代码,确发现好多都不是最新的,所以自己就整理了一下代码,现在与广大博友们分享下我们
- 从ASP初入门到PHP,感觉到PHP的强大之一就是内置函数的丰富,比如先前学习的PHP日期时间函数,读写文件的相关函数等都无不表明了PHP的
- 最近项目需要抓包功能,并且抓包后要对数据包进行存库并分析。抓包想使用tcpdump来完成,但是tcpdump抓包之后只能保存为文件,我需要将
- 近年来,广告已成为很多网站的主要收入来源。不久前,在线广告往往遭到访客的拒绝,广告客户也不确定它的价值和效力。今天,大多数访客期望在商业网站
- 安装pyinstallerpip install pyinstaller制作项目的.spec文件 进入django项目所在路径,
- 关于asp随机数的相关文章:asp生成一个不重复的随机数字 8个asp生成随机字符的函数 <html> <me
- 这里所谓的复杂表单,是指表单中包含多种不同的输入类型,比如下拉列表框、单行文本、多行文本、数值等。在经常需要更换这类表单的场合,需要有一个表
- 两年前在 B 站上看到了一个宝藏 up 主,名叫 "Jannchie见齐",专门做动态条形图样式的数据可视化。做出的效果
- 英文文档:setattr(object, name, value)This is the counterpart of getattr().
- PyCharm 具备一般 IDE 的功能,比如,调试、语法高亮、项目管理、代码跳转、智能提示、自动完成、单元测试、版本控制…另外,PyCha
- 本文实例讲述了php隐藏IP地址后两位显示为星号的方法。分享给大家供大家参考。具体实现方法如下:我们在很多的公共网站中都会有碰到显示用户的I
- MaxDB和MySQL是独立的数据库管理服务器。系统间的协同性是可能的,通过相应的方式,系统能够彼此交换数据。要想在MaxDB和MySQL之
- 严格控制Session可以将不需要Session的内容(比如帮助画面,访问者区域,等等)移动到关闭Session的独立ASP应用程序中。在基
- 1.函数对象前面我们学习了关于Python中的变量类型,例如int、str、bool、list等等…&hell
- windows环境下python2.7 脚本指定一个参数作为要检索的字符串例如: >find.py ./ hello# coding=