Python如何加载模型并查看网络
作者:ShuqiaoS 发布时间:2021-11-01 15:53:22
加载模型并查看网络
加载模型,以vgg19为例。
打开终端
> python
Python 3.7.2 (tags/v3.7.2:9a3ffc0492, Dec 23 2018, 23:09:28) [MSC v.1916 64 bit
(AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> from torchvision import models
>>> model = models.vgg19(pretrained=True) #此时如果是第一次加载会开始下载模型的pth文件
>>> print(model.model)
结果:
VGG(
(features): Sequential(
(0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(1): ReLU(inplace)
(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(3): ReLU(inplace)
(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(6): ReLU(inplace)
(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(8): ReLU(inplace)
(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(11): ReLU(inplace)
(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(13): ReLU(inplace)
(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(15): ReLU(inplace)
(16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(17): ReLU(inplace)
(18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(20): ReLU(inplace)
(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(22): ReLU(inplace)
(23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(24): ReLU(inplace)
(25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(26): ReLU(inplace)
(27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(29): ReLU(inplace)
(30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(31): ReLU(inplace)
(32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(33): ReLU(inplace)
(34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
(35): ReLU(inplace)
(36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)
(avgpool): AdaptiveAvgPool2d(output_size=(7, 7))
(classifier): Sequential(
(0): Linear(in_features=25088, out_features=4096, bias=True)
(1): ReLU(inplace)
(2): Dropout(p=0.5)
(3): Linear(in_features=4096, out_features=4096, bias=True)
(4): ReLU(inplace)
(5): Dropout(p=0.5)
(6): Linear(in_features=4096, out_features=1000, bias=True)
)
)
注意,直接打印模型是没有办法看到模型结构的,只能看到带模型参数的pth文件内容;需要打印model.model才可以看到模型本身。
神经网络_模型的保存,模型的加载
模型的保存(torch.save)
方式1(模型结构+模型参数)
参数:保存位置
# 创建模型
vgg16 = torchvision.models.vgg16(pretrained=False)
# 保存方式1——模型结构+模型参数
torch.save(vgg16, "vgg16_method1.pth")
方式2(模型参数)
# 保存方式2 模型参数(官方推荐)。保存成字典,只保存网络模型中的一些参数
torch.save(vgg16.state_dict(), "vgg16_method2.pth")
模型的加载(torch.load)
对应保存方式1
参数:模型路径
# 方式1 --》 保存方式1
model1 = torch.load("vgg16_method1.pth")
对应保存方式2
vgg16.load_state_dict("vgg16_method2.pth")
输出为字典形式。若要回复网络,采用以下形式:
model2 = torch.load("vgg16_method2.pth") #输出是字典形式
# 恢复网络结构
vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(model2)
方式1存储,加载时需注意事项
新建自己的网络:
class test(nn.Module):
def __init__(self):
super(lh, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
保存自己的网络:
Test = test()
# 保存自己定义的网络
torch.save(Test, "Test_method1.pth")
加载自己的网络:
model3 = torch.load("Test_method1.pth")
会报错!!!!!!
解决办法(需要注意):
将定义的网络复制到加载的python文件中:
class test(nn.Module):
def __init__(self):
super(test, self).__init__()
self.conv1 = nn.Conv2d(3, 64, kernel_size=3)
def forward(self, x):
x = self.conv1(x)
return x
model3 = torch.load("Test_method1.pth")
来源:https://blog.csdn.net/ShuqiaoS/article/details/90766761


猜你喜欢
- 本文要实现的功能是:根据下拉列表的选项将数据库中对应的内容显示在页面,选定要排除的选项后,提交剩余的选项到数据库。为了方便前后台交互,利用了
- 说明本文根据https://github.com/liuchengxu/blockchain-tutorial的内容,用python实现的,
- 初学vue自己新建一个vue项目来做学习demo。不过在编写代码时一直出现空格不规范的警告。严重影响初学者的热情。错误如下图所示。(这样的错
- 下文分步骤给大家介绍的非常详细,具体详情请看下文吧。一、准备用两台服务器做测试:Master Server: 192.0.0.1/Linux
- 本文实例讲述了Python实现PS图像调整颜色梯度效果。分享给大家供大家参考,具体如下:这里用 Python 实现 PS 中的色彩图,可以看
- 一、排序的基本概念和分类所谓排序,就是使一串记录,按照其中的某个或某些关键字的大小,递增或递减的排列起来的操作。排序算法,就是如何使得记录按
- 本文实例讲述了Python定时任务sched模块用法。分享给大家供大家参考,具体如下:通过sched模块可以实现通过自定义时间,自定义函数,
- 最近需要用python打包一个单页面网页demo,于是准备用python包pyinstaller来打包程序。网上搜索了一下,大部分教程都是打
- 测试数据 http://grouplens.org/datasets/movielens/协同过滤推荐算法主要分为:1、基于用户。根据相邻用
- 一,粘包问题详情 1,只有TCP有粘包现象,UDP永远不会粘包你的程序实际上无权直接操作网卡的,你操作网卡都是通过操作系统给用户程序暴露出来
- 在ASP中,如何创建DSN? 见下:<HTML><HEAD><META&n
- 我们来看看MD5加密码的实现:注意看一下他数据库里的加密位数!先在通用处申明:Private Const BITS_TO
- 普通滑动验证以http://admin.emaotai.cn/login.aspx为例这类验证码只需要我们将滑块拖动指定位置,处理起来比较简
- 在正式编写爬虫案例前,先对 scrapy 进行一下系统的学习。scrapy 安装与简单运行使用命令 pip install scrapy 进
- 本篇,我们学习PyQt5界面中拖放(Drag 和Drop)控件。拖放动作在GUI中,拖放指的是点击一个对象,并将其拖动到另一个对象上的动作。
- 这篇文章主要介绍了Vue子组件内的props对象里的default参数是如何定义Array、Object、或Function默认值的正确写法
- Vue服务器部署刷新页面404问题描述在上线vue开发的前端网页部署在服务器上后,刷新页面显示404原因因为网页上显示的是静态绝对路径而实际
- mysql官方提供了很多种connector,其中包括python的connector。下载地址在:http://dev.mysql.com
- 本文实例讲述了Python多线程编程之多线程加锁操作。分享给大家供大家参考,具体如下:Python语言本身是支持多线程的,不像PHP语言。下
- 你可能在使用MySQL过程中,各种意外导致数据库表的损坏,而且这些数据往往是最新的数据,通常不可能在备份数据中找到。本章将讲述如何检测MyS