在pytorch中如何查看模型model参数parameters
作者:xiaoju233 发布时间:2021-12-04 22:43:29
标签:pytorch,查看模型,model,parameters
pytorch查看模型model参数parameters
示例1:pytorch自带的faster r-cnn模型
import torch
import torchvision
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(pretrained=True)
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
示例2:自定义网络模型
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
cfg = [64, 64, 'M', 128, 128, 'M', 256, 256, 256, 'M', 512, 512, 512, 'M', 512, 512, 512]
self.features = self._vgg_layers(cfg)
def _vgg_layers(self, cfg):
layers = []
in_channels = 3
for x in cfg:
if x == 'M':
layers += [nn.MaxPool2d(kernel_size=2, stride=2)]
else:
layers += [nn.Conv2d(in_channels, x ,kernel_size=3, padding=1),
nn.BatchNorm2d(x),
nn.ReLU(inplace=True)
]
in_channels = x
return nn.Sequential(*layers)
def forward(self, data):
out_map = self.features(data)
return out_map
Model = Net()
for name, p in model.named_parameters():
print(name)
print(p.requires_grad)
print(...)
#或者
for p in model.parameters():
print(p)
print(...)
在自定义网络中,model.parameters()方法继承自nn.Module
pytorch查看模型参数总结
1:DNN_printer
其中(3, 32, 32)是输入的大小,其他方法中的参数同理
from DNN_printer import DNN_printer
batch_size = 512
def train(epoch):
print('\nEpoch: %d' % epoch)
net.train()
train_loss = 0
correct = 0
total = 0
// put the code here and you can get the result
DNN_printer(net, (3, 32, 32),batch_size)
结果
2:parameters
def cnn_paras_count(net):
"""cnn参数量统计, 使用方式cnn_paras_count(net)"""
# Find total parameters and trainable parameters
total_params = sum(p.numel() for p in net.parameters())
print(f'{total_params:,} total parameters.')
total_trainable_params = sum(p.numel() for p in net.parameters() if p.requires_grad)
print(f'{total_trainable_params:,} training parameters.')
return total_params, total_trainable_params
cnn_paras_count(net)
直接输出参数量,然后自己计算
需要注意的是,一般模型中参数是以float32保存的,也就是一个参数由4个bytes表示,那么就可以将参数量转化为存储大小。
例如:
44426个参数*4 / 1024 ≈ 174KB
3:get_model_complexity_info()
from ptflops import get_model_complexity_info
from torchvision import models
net = models.mobilenet_v2()
ops, params = get_model_complexity_info(net, (3, 224, 224), as_strings=True,
print_per_layer_stat=True, verbose=True)
4:torchstat
from torchstat import stat
import torchvision.models as models
model = models.resnet152()
stat(model, (3, 224, 224))
输出
来源:https://blog.csdn.net/qq_38600065/article/details/105552816
0
投稿
猜你喜欢
- Windows 10 x64macOS Sierra 10.12.4Python 2.7准备好装哔~了么,来吧,做个真正意义上的绿色小软件W
- XML.DOM需求有一个表,里面数据量比较大,每天一更新,其字段可以通过xml配置文件进行配置,即,可能每次建表的字段不一样。上游跑时会根据
- 一、K.prodprodkeras.backend.prod(x, axis=None, keepdims=False)功能:在某一指定轴,
- datasets.ImageFolder是PyTorch提供的一个预定义数据集类,用于处理图像数据。它可以方便地将一组图像加载到内存中,并为
- ThinkPHP CURD方法的limit方法也是模型类的连贯操作方法之一,主要用于指定查询和操作的数量,特别在分页查询的时候使用较多。并且
- 目录连接池是什么?为什么需要连接池?连接池的原理是什么?使用python语言自制简易mysql连接池开始使用自定义配置文件名 & 配
- 如下所示:>> type(np.newaxis)NoneType>> np.newaxis == NoneTruen
- 背景:做任务领金币的过程很无聊,而且每天都是重复同样的工作,非常符合自动化的定义;工具:python,appium,Android 手机(我
- Oracle不像SQLServer那样在存储过程中用Select就可以返回结果集,而是通过Out型的参数进行结果集返回的。实际上是利用REF
- <?php /*============================文件说明===========================
- 如果你忘记了你的MYSQL的root口令的话,你可以通过下面的过程恢复。1. 向mysqld server 发
- 一、背景:在平时工作中有遇到端口检测,查看服务端特定端口是否对外开放,常用nmap,tcping,telnet等,同时也可以利用站长工具等w
- 调用pytorch内置的模型的方法import torchvisionmodel = torchvision.models.resnet50
- 大家好,我是辰哥~今天给大家分享两个制作二维码的Python库,可以生成普通的二维码、图片背景版二维码、动图GIF版二维。1.MyQR安装p
- 本文主要关于python的正则表达式的符号与方法。findall: 找寻所有匹配,返回所有组合的列表search: 找寻第一个匹配并返回su
- RabbitMQ 6种工作模式对RabbitMQ 6种工作模式(简单模式、工作模式、订阅模式、路由模式、主题模式、RPC模式)进行场景和参数
- 本文实例讲述了CI操作cookie的方法。分享给大家供大家参考,具体如下:CI 操作cookie 有三种方法,2中Ci自带的,其
- 用法:matplotlib.pyplot.stem(*args, linefmt=None, markerfmt=None, basefmt
- 一、系统简介实现一个学生信息的管理系统:主要功能有:添加学生信息删除学生信息修改学生信息查询学生信息显示学生信息退出当前系统二、步骤分析显示
- 一、新建项目,在主配置文件中,修改以下内容:ALLOWED_HOSTS = ['127.0.0.1','localh