pytorch 计算Parameter和FLOP的操作
作者:落地生根1314 发布时间:2023-03-01 04:15:55
标签:pytorch,Parameter,FLOP
深度学习中,模型训练完后,查看模型的参数量和浮点计算量,在此记录下:
1 THOP
在pytorch中有现成的包thop用于计算参数数量和FLOP,首先安装thop:
pip install thop
注意安装thop时可能出现如下错误:
解决方法:
pip install --upgrade git+https://github.com/Lyken17/pytorch-OpCounter.git # 下载源码安装
使用方法如下:
from torchvision.models import resnet50 # 引入ResNet50模型
from thop import profile
model = resnet50()
flops, params = profile(model, input_size=(1, 3, 224,224)) # profile(模型,输入数据)
对于自己构建的函数也一样,例如shuffleNetV2
from thop import profile
from utils.ShuffleNetV2 import shufflenetv2 # 导入shufflenet2 模块
import torch
model_shuffle = shufflenetv2(width_mult=0.5)
model = torch.nn.DataParallel(model_shuffle) # 调用shufflenet2 模型,该模型为自己定义的
flop, para = profile(model, input_size=(1, 3, 224, 224),)
print("%.2fM" % (flop/1e6), "%.2fM" % (para/1e6))
更多细节,可参考thop GitHub链接: https://github.com/Lyken17/pytorch-OpCounter
2 计算参数
pytorch本身带有计算参数的方法
from thop import profile
from utils.ShuffleNetV2 import shufflenetv2 # 导入shufflenet2 模块
import torch
model_shuffle = shufflenetv2(width_mult=0.5)
model = torch.nn.DataParallel(model_shuffle)
total = sum([param.nelement() for param in model.parameters()])
print("Number of parameter: %.2fM" % (total / 1e6))
补充:pytorch: 计算网络模型的计算量(FLOPs)和参数量(Params)
计算量:
FLOPs,FLOP时指浮点运算次数,s是指秒,即每秒浮点运算次数的意思,考量一个网络模型的计算量的标准。
参数量:
Params,是指网络模型中需要训练的参数总数。
第一步:安装模块(thop)
pip install thop
第二步:计算
import torch
from thop import profile
net = Model() # 定义好的网络模型
input = torch.randn(1, 3, 112, 112)
flops, params = profile(net, (inputs,))
print('flops: ', flops, 'params: ', params)
注意:
输入input的第一维度是批量(batch size),批量的大小不回影响参数量, 计算量是batch_size=1的倍数
profile(net, (inputs,))的 (inputs,)中必须加上逗号,否者会报错
以上为个人经验,希望能给大家一个参考,也希望大家多多支持脚本之家。如有错误或未考虑完全的地方,望不吝赐教。
来源:https://blog.csdn.net/qq_26369907/article/details/89857021
0
投稿
猜你喜欢
- python中return不返回值是因为你没有将返回的值取出来。解决方法:调用函数,将函数的返回值赋给一个变量,输出这个变量就可以看到函数的
- 用dicompyler软件打开dicom图像,头文件如图所示:当然也可以直接读取:ds = dicom.read_file('H:\
- 【pytorch官方文档】:https://pytorch.org/docs/stable/generated/torch.nn.AvgPo
- ORM 查询管理器对于 ORM 定义: 对象关系映射, Object Relational Mapping, ORM, 是一种程序设计技术,
- http-server是一个简单的命令行http服务器,基于nodejs,下载地址:https://nodejs.org/en/downlo
- Python中的列表基于PyListObject实现,列表支持元素的插入、删除、更新操作,因此PyListObject是一个变长对象(列表的
- 这几天看了篇叫"Penetration: from application down to OS (Oracle)"的文
- 有个文本文件,需要替换里面的一个词,用python来完成,我是这样写的:def modify_text(): with open('
- 本文为大家解析了python实现4名牌手洗牌发牌的问题,供大家参考,具体内容如下编写程序, 4名牌手打牌,计算机随机将52张牌(不含大小鬼)
- 在遥感应用中,我们经常需要对某一景遥感影像中的全部像元的像素值进行平均值求取——这一操作很好实现,基
- Python中打开文本使用的是with语句,比如打开一个文件并读取每一行with open(filename) as fp: f
- 最近在学爬虫时发现许多网站都有自己的反爬虫机制,这让我们没法直接对想要的数据进行爬取,于是了解这种反爬虫机制就会帮助我们找到解决方法。常见的
- Python自动化测试 Eclipse+Pydev 搭建开发环境C#之所以容易让人感兴趣,是因为安装完Visual Studio, 就可以很
- 问题:之前在学习list和dict相关的知识时,遇到了一个常见的问题:如何在遍历list或dict的时候正常删除?例如我们在遍历dict的时
- 什么是 BokehBokeh 是 Python 中的交互式可视化库。Bokeh提供的最佳功能是针对现代 Web 浏览器进行演示的高度交互式图
- 最近,某水果手机厂在万众期待中开了一场没有发布万众期待的手机产品的发布会,发布了除手机外的其他一些产品,也包括最新的水果14系统。几天后,更
- 流程:模拟登录→获取Html页面→正则解析所有符合条件的行→逐一将符合条件的行的所有列存入到CSVData[]临时变量中→写入到CSV文件中
- 引文: 长期以来,多媒体信息在计算机中都是以文件形式存放,由操作系统管理的,但是随着计算机网络,分布式计算的发展,对多媒体信息进行高效的管理
- 目录项目引入flask-sqlalchemyORM简介及模型定义表关系类型及编码实现一对多关系(多对一关系)一对一关系多对多关系数据库基本操
- 一个有点绕的例子,用PyScripter调试器步进跟踪可以看清楚对 象结构的具体细节。对原作改变了一下,在未定义子对象属性时__getite