在pytorch中如何查看模型model参数parameters
作者:xiaoju233 发布时间:2021-12-04 22:43:29
标签:pytorch,查看模型,model,parameters
pytorch查看模型model参数parameters
示例1:pytorch自带的faster r-cnn模型
import torch
import torchvision
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
示例2:自定义网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.features = self._vgg_layers(cfg)
def _vgg_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)
]
in_channels = x
return nn.Sequential(*layers)
def forward(self, data):
out_map = self.features(data)
return out_map
Model = Net()
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
在自定义网络中,model.parameters()方法继承自nn.Module
pytorch查看模型参数总结
1:DNN_printer
其中(3, 32, 32)是输入的大小,其他方法中的参数同理
from DNN_printer import DNN_printer
batch_size = 512
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
// put the code here and you can get the result
DNN_printer(net, (3, 32, 32),batch_size)
结果
2:parameters
def cnn_paras_count(net):
"""cnn参数量统计, 使用方式cnn_paras_count(net)"""
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
return total_params, total_trainable_params
cnn_paras_count(net)
直接输出参数量,然后自己计算
需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。
例如:
44426个参数*4 / 1024 ≈ 174KB
3:get_model_complexity_info()
from ptflops import get_model_complexity_info
from torchvision import models
net = models.mobilenet_v2()
ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)
4:torchstat
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
输出
来源:https://blog.csdn.net/qq_38600065/article/details/105552816


猜你喜欢
- 数组都是从0开始。javascript是arrayname[i],而vbscript是arrayname(i) javascript的字符串
- 概述在开发中,可能会遇到当页面滚动停止之后执行某些操作的需求。在 scrollend 事件之前,并没有可靠的方法来检测页面滚动是否完成。这意
- 一、互联网人的焦虑互联网人是最焦虑的那批人,也是最爱学习的那批人。没办法,互联网行业的节奏实在太快了,每天都生活在信息 * 的环境里,“风口”
- 本文实例为大家分享了python实现屏幕中间倒计时的具体代码,供大家参考,具体内容如下先看下效果图:代码:import timefrom t
- 使用scipy.optimize模块的root和fsolve函数进行数值求解线性及非线性方程,下面直接贴上代码,代码很简单from scip
- Python 跟 Python3 完全就是两种语言1、 import caffe FAILED环境为 Ubuntu 16 cuda
- 1.对查询进行优化,应尽量避免全表扫描,首先应考虑在 where 及 order by 涉及的列上建立索引。2.应尽量避免在 where 子
- Django2.0 通过URL访问上传的文件(pdf、picture等)Django是一个成熟的web框架,基于python实现,有很多的优
- 活在当下的程序员应该都听过“面向对象编程”一词,也经常有人问能不能用一句话解释下什么是“面向对象编程”,我们先来看看比较正式的说法。把一组数
- 接口性能测试时,接口请求参数是根据一定的规则拼接后进行MD5加密后再进行传参,因此借助于python脚本实现,则可以有效提升测试效率。1.分
- python中xmltodict使用xml转换成OrderedDict代码 :import xmltodictfrom pprin
- 之前的文章介绍了python抓取网页数据并将数据保存到本地excel文件,后续可以将数据保存到数据库(SqlServer、mysql等)中,
- 今天给大家介绍一个可以获取当前系统信息的库——psutil利用psutil库可以获取系统的一些信息,如cpu,内存等使用率,从而可以查看当前
- 作为一个Oracle数据库开发者或者DBA,在实际工作中经常会遇到这样的问题:试图对库表中的某一列或几列创建唯一索引时,系统提示ORA-01
- PyCharm就是Python语言开发中一个很受欢迎的IDE,界面类似于visual studio,android studio,集成的功能
- 目录类空指向ES6 箭头函数vuetifyvue-cli异步和同步运行和部署TIPS排名不分先后最近好像都是只发了一些生活类,吐槽的一些 b
- 网上有很多关于Python+opencv人脸检测的例子,并大都附有源程序。但是在实际使用时依然会遇到这样或者那样的问题,在这里给出常见的两种
- 训练好了model后,可以通过python调用caffe的模型,然后进行模型测试的输出。本次测试主要依靠的模型是在caffe模型里面自带训练
- 哈喽!我的朋友们,最近有一个新项目。所以一直没更新!有没有想我啊!!今天咱们来说一下JS原生轮播图!话不多说:直接来代码吧:下面是CSS部分
- 如何将123456789转化成123,456,789这样的形式呢?很多流量大的站比如优酷都有这样的格式。也是设计程序最常用的算