python PyTorch预训练示例
作者:算法学习者 发布时间:2022-07-06 18:24:02
前言
最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢。各种设计直接简洁,方便研究,比tensorflow的臃肿好多了。今天让我们来谈谈PyTorch的预训练,主要是自己写代码的经验以及论坛PyTorch Forums上的一些回答的总结整理。
直接加载预训练模型
如果我们使用的模型和原模型完全一样,那么我们可以直接加载别人训练好的模型:
my_resnet = MyResNet(*args, **kwargs)
my_resnet.load_state_dict(torch.load("my_resnet.pth"))
当然这样的加载方法是基于PyTorch推荐的存储模型的方法:
torch.save(my_resnet.state_dict(), "my_resnet.pth")
还有第二种加载方法:
my_resnet = torch.load("my_resnet.pth")
加载部分预训练模型
其实大多数时候我们需要根据我们的任务调节我们的模型,所以很难保证模型和公开的模型完全一样,但是预训练模型的参数确实有助于提高训练的准确率,为了结合二者的优点,就需要我们加载部分预训练模型。
pretrained_dict = model_zoo.load_url(model_urls['resnet152'])
model_dict = model.state_dict()
# 将pretrained_dict里不属于model_dict的键剔除掉
pretrained_dict = {k: v for k, v in pretrained_dict.items() if k in model_dict}
# 更新现有的model_dict
model_dict.update(pretrained_dict)
# 加载我们真正需要的state_dict
model.load_state_dict(model_dict)
因为需要剔除原模型中不匹配的键,也就是层的名字,所以我们的新模型改变了的层需要和原模型对应层的名字不一样,比如:resnet最后一层的名字是fc(PyTorch中),那么我们修改过的resnet的最后一层就不能取这个名字,可以叫fc_
微改基础模型预训练
对于改动比较大的模型,我们可能需要自己实现一下再加载别人的预训练参数。但是,对于一些基本模型PyTorch中已经有了,而且我只想进行一些小的改动那么怎么办呢?难道我又去实现一遍吗?当然不是。
我们首先看看怎么进行微改模型。
微改基础模型
PyTorch中的torchvision里已经有很多常用的模型了,可以直接调用:
AlexNet
VGG
ResNet
SqueezeNet
DenseNet
import torchvision.models as models
resnet18 = models.resnet18()
alexnet = models.alexnet()
squeezenet = models.squeezenet1_0()
densenet = models.densenet_161()
但是对于我们的任务而言有些层并不是直接能用,需要我们微微改一下,比如,resnet最后的全连接层是分1000类,而我们只有21类;又比如,resnet第一层卷积接收的通道是3, 我们可能输入图片的通道是4,那么可以通过以下方法修改:
resnet.conv1 = nn.Conv2d(4, 64, kernel_size=7, stride=2, padding=3, bias=False)
resnet.fc = nn.Linear(2048, 21)
简单预训练
模型已经改完了,接下来我们就进行简单预训练吧。
我们先从torchvision中调用基本模型,加载预训练模型,然后,重点来了,将其中的层直接替换为我们需要的层即可:
resnet = torchvision.models.resnet152(pretrained=True)
# 原本为1000类,改为10类
resnet.fc = torch.nn.Linear(2048, 10)
其中使用了pretrained参数,会直接加载预训练模型,内部实现和前文提到的加载预训练的方法一样。因为是先加载的预训练参数,相当于模型中已经有参数了,所以替换掉最后一层即可。OK!
来源:http://blog.csdn.net/AMDS123/article/details/70144935
猜你喜欢
- 代码如下:SELECT [StartDate] FROM [dbo].[udf_Week](2012,2012) WHERE [
- PHP 是世界上最好的语言。经典的 LNMP(linux + nginx + php + mysql)环境有很多现成的部署脚本,但是在 Do
- jsp登陆验证,网页登陆验证带验证码校验,登录功能之添加验证码part_1:专门用于生成一个验证码图片的类:VerificationCode
- 看过数据库的备份与还原。大多数都是用组件来完成的。其实可通过sql语句来完成。 由于时间关系,未对参数进行验证和界面美化。代码
- 酝酿了将近一个春夏秋冬的腾讯网首页终于亮剑!反响热烈!让我们来分享它成功背后的酸甜苦辣吧。腾讯网首页改版终于开花结果。于2008年3月25日
- 1 为什么需要防抖和节流在前端开发当中,有些交互事件,会被频繁触发,这样会导致我们的页面渲染性能下降,如果频繁触发接口调用的话,会直接导致服
- 摘要:现代网站和web应用程序趋向于依赖客户端的大量的javascript来提供丰富的交互。特别是通过不刷新页面的异步请求来返回数据或从服务
- HP QR Code是一个PHP二维码生成类库,利用它可以轻松生成二维码,官网提供了下载和多个演示demo,查看地址:http://phpq
- 本文实例讲述了Python实现备份文件的方法,是一个非常实用的技巧。分享给大家供大家参考。具体方法如下:该实例主要实现读取一个任务文件, 根
- 细节汇总函数的形参列表可以是多个,返回值列表也可以是多个形参列表和返回值列表的数据类型,可以是值类型、也可以是引用类型函数的命名遵循标识符命
- 最近,我面试了一个有五年 Web 应用程序开发经验的软件开发人员。四年半来她一直在从事 JavaScript 相关的工作,她自认为 Java
- 开始制作符合标准的站点,第一件事情就是声明符合自己需要的DOCTYPE。查看本站首页原代码,可以看到第一行就是:<!DOCTYPE h
- 本文介绍了一些JavaScript常用到得表单验证函数,方便大家使用。 判断是否为整数,是则返回true,否则返回falsefun
- 每个电子商务数据分析师必须掌握的一项数据聚类技能如果你是一名在电子商务公司工作的数据分析师,从客户数据中挖掘潜在价值,来提高客户留存率很可能
- 硬件平台:SUN Ultra Enterprise 3000 操作系统:Solaris 2.5(中文简体) 磁盘:4.2GB 内存:256M
- 一、概述 对象是Oracle8i以上版本中的一个新的特性,对象实际是对一组数据和操作的封装,对象的抽象就是类。在面向对象技术中,对象涉及到以
- asp之家注:如果你学习过asp,并且在网络公司上过班,一定会接触到网购系统,网购系统可以说是一个典型的程序类型,而其中最重要,也是最关键的
- 内容摘要:有很多朋友虽然安装好了mysql但却不知如何使用它。在这篇文章中我们就从连接mysql、修改密码、增加用户等方面来学习一些mysq
- 内容摘要:通常的,ASP中表单提交的数据一般被写入数据库。然而,如果你想让发送数据更为简便易行,那么,可以将它书写为XML文件格式。这种方式
- 最近参与了将一个Sybase数据库移植到Microsoft SQL Server 2000上的项目,我在这一项目上获得的经验,将对Sybas