Pytorch之保存读取模型实例
作者:啧啧啧biubiu 发布时间:2023-04-03 02:15:11
标签:Pytorch,保存,读取,模型
pytorch保存数据
pytorch保存数据的格式为.t7文件或者.pth文件,t7文件是沿用torch7中读取模型权重的方式。而pth文件是python中存储文件的常用格式。而在keras中则是使用.h5文件。
# 保存模型示例代码
print('===> Saving models...')
state = {
'state': model.state_dict(),
'epoch': epoch # 将epoch一并保存
}
if not os.path.isdir('checkpoint'):
os.mkdir('checkpoint')
torch.save(state, './checkpoint/autoencoder.t7')
保存用到torch.save函数,注意该函数第一个参数可以是单个值也可以是字典,字典可以存更多你要保存的参数(不仅仅是权重数据)。
pytorch读取数据
pytorch读取数据使用的方法和我们平时使用预训练参数所用的方法是一样的,都是使用load_state_dict这个函数。
下方的代码和上方的保存代码可以搭配使用。
print('===> Try resume from checkpoint')
if os.path.isdir('checkpoint'):
try:
checkpoint = torch.load('./checkpoint/autoencoder.t7')
model.load_state_dict(checkpoint['state']) # 从字典中依次读取
start_epoch = checkpoint['epoch']
print('===> Load last checkpoint data')
except FileNotFoundError:
print('Can\'t found autoencoder.t7')
else:
start_epoch = 0
print('===> Start from scratch')
以上是pytorch读取的方法汇总,但是要注意,在使用官方的预处理模型进行读取时,一般使用的格式是pth,使用官方的模型读取命令会检查你模型的格式是否正确,如果不是使用官方提供模型通过下面的函数强行读取模型(将其他模型例如caffe模型转过来的模型放到指定目录下)会发生错误。
def vgg19(pretrained=False, **kwargs):
"""VGG 19-layer model (configuration "E")
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
"""
model = VGG(make_layers(cfg['E']), **kwargs)
if pretrained:
model.load_state_dict(model_zoo.load_url(model_urls['vgg19']))
return model
假如我们有从caffe模型转过来的pytorch模型([0-255,BGR]),我们可以使用:
model_dir = '自己的模型地址'
model = VGG()
model.load_state_dict(torch.load(model_dir + 'vgg_conv.pth'))
也就是pytorch的读取函数进行读取即可。
来源:https://blog.csdn.net/qq_37385726/article/details/81943980


猜你喜欢
- 前言图是一种抽象数据结构,本质和树结构是一样的。图与树相比较,图具有封闭性,可以把树结构看成是图结构的前生。在树结构中,如果把兄弟节点之间或
- 本来想等到IE8正式发布时再在blog中写段代码,用来提示IE6用户升级到IE8的,不过貌似IE 8已经RTM了,今天又正好看到这个“升级I
- 最近碰见太多次lambda函数了,那就来详细解释一下该函数。lambda函数我们先对lambda函数进行一个简单的介绍lambda函数是一种
- 把一些地域性比较明显的数据显示在一张地图上,远比给别人一个 Excel 文件好得多。Matplotlib 中也有画地图的函数,但是是静态图,
- HTML与CSS在Flash中的应用:不小心看到同事Den在弄个小东西:在Flash里使用HTML和CSS,代码是这样:var m
- 目录range函数的使用第一种创建方式第二种创建方式第三种创建方式判断指定的数有没有在当前序列中循环结构总结range函数的使用作为循环遍历
- 首先 下载 jedis.jar包然后再 工程设置里面找到Libraries,点击+。添加下载好的jedis.jar包。点击OK退出即可创建J
- getattr()函数是Python自省的核心函数,具体使用大体如下:获取对象引用getattrGetattr用于返回一个对象属性,或者方法
- v1.0.0完成基础框架、初始功能背景:为了提高日常工作效率、学习界面工具开发,可以将一些常用的功能集成到一个小的测试工具中,供大家使用。一
- 今天我们来学习字符串数据类型相关知识,将讨论如何声明字符串数据类型,字符串数据类型与 ASCII 表的关系,字符串数据类型的属性,以及一些重
- 本文实例为大家分享了树回归的具体代码,供大家参考,具体内容如下#-*- coding:utf-8 -*- #!/usr/bin/python
- juypter notebook中直接使用log_device_placement=True打印不出来device信息# Creates a
- 使用Vue实现简单的用户登录界面,登录成功以后查询账号用户类型进行相应的页面路由跳转,效果如下图所示:HTML部分:<div clas
- 主要使用IE各个阶段实现的一些方法,从中也可以看出IE的发展史。暂时提供到IE4的判定。var isIE = window.ActiveXO
- 一、KNN算法简介邻近算法,或者说K最近邻(kNN,k-NearestNeighbor)分类算法是数据挖掘分类技术中最简单的方法之一。所谓K
- 实现原理 把所有需要延时加载的图片改成如下的格式:<img lazy_src="图片路径" border
- python np.dot(a,b)运算规则解析首先我们知道dot运算时不满 * 换律的,np.dot(a, b)与np.dot(b, a)是
- 1、值为列表的构造实例dic = {}dic.setdefault(key,[]).append(value)*********示例如下**
- 问题如何设定matplotlib输出的图片大小?import matplotlib.pyplot as plt一、plt.figure(fi
- 如何实现让每句话的头一个字母都大写? <%dim txtFnametxtFName = &qu