Pytorch统计参数网络参数数量方式
作者:qq_34535410 发布时间:2021-03-13 03:09:04
标签:Pytorch,统计参数,网络参数,数量
Pytorch统计参数网络参数数量
def get_parameter_number(net):
total_num = sum(p.numel() for p in net.parameters())
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
Pytorch如何计算网络的参数量
本文以 Dense Block 为例,Pytorch 为 DL 框架,最终计算模块参数量方法如下:
import torch
import torch.nn as nn
class Norm_Conv(nn.Module):
def __init__(self,in_channel):
super(Norm_Conv,self).__init__()
self.layers = nn.Sequential(
nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel),
nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel),
nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
def forward(self,input):
out = self.layers(input)
return out
class DenseBlock_Norm(nn.Module):
def __init__(self,in_channel):
super(DenseBlock_Norm,self).__init__()
self.first_layer = nn.Sequential(nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
self.second_layer = nn.Sequential(nn.Conv2d(in_channel*2,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
self.third_layer = nn.Sequential(
nn.Conv2d(in_channel*3,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
def forward(self,input):
output1 = self.first_layer(input)
output2 = self.second_layer(torch.cat((output1,input),dim=1))
output3 = self.third_layer(torch.cat((input,output1,output2),dim=1))
return output3
def count_param(model):
param_count = 0
for param in model.parameters():
param_count += param.view(-1).size()[0]
return param_count
# Get Parameter number of Network
in_channel = 128
net1 = Norm_Conv(in_channel)
print('Norm Conv parameter count is {}'.format(count_param(net1)))
net2 = DenseBlock_Norm(in_channel)
print('DenseBlock Norm parameter count is {}'.format(count_param(net2)))
最终结果如下
Norm Conv parameter count is 443520
DenseBlock Norm parameter count is 885888
来源:https://blog.csdn.net/qq_34535410/article/details/89715192
0
投稿
猜你喜欢
- django的表单系统,分2种基于django.forms.Form的所有表单类的父类基于django.forms.ModelForm,可以
- python中在实现一元线性回归时会使用最小二乘法,那你知道最小二乘法是什么吗。其实最小二乘法为分类回归算法的基础,从求解线性透视图中的消失
- PHP str_split() 函数实例把字符串 "Hello" 分割到数组中:<?php print_r(str
- 本文为大家分享了Linux环境下mysql5.6.24自动安装脚本代码,供大家参考,具体内容如下说明:一、本脚本仅供测试使用,若正式环境想要
- 本文开篇第一句话,想引用鲁迅先生《祝福》里的一句话,那便是:“我真傻,真的,我单单知道后端整天都是CRUD,我没想到前端整天都是Form表单
- 目录1、基础理论1.1 事务1.2 分布式事务2、分布式事务的解决方案2.1 两阶段提交/XA2.2 SAGA2.3 TCC2.4 本地消息
- 先来看个用Python实现的二分查找算法实例import sys def search2(a,m): low = 0 high = le
- 简介本文主要介绍如何通过pyplot来绘制函数图。主要绘制函数如下: - 一元一次函数 - 一元二次函数 - 指数函数 - 自然对数函数 -
- 1.小猫运动游戏源码# @Author : 辣条'''多行注释本程序运行后会有一只小猫向前走安装模块 pip ins
- Jupyter Notebook内使用argparse报错在github上下载了代码来学习时,发现将其直接copy到jupyter note
- 今天在做python获取邮件时需要递归调用解析函数才可以解析邮件内容,最后想要将解析出的内容返回时发现返回的是None 可以内容却可以打印出
- 本文实例为大家分享SQL SERVER数据库备份的具体代码,供大家参考,具体内容如下/** 批量循环备份用户数据库,做为数据库迁
- 引言https://github.com/go-chassis/go-chassis是一个微服务开发框架,而微服务开发框架带来的其中一个课题
- Python中可以使用 pickle 模块将对象转化为文件保存在磁盘上,在需要的时候再读取并还原。具体用法如下:pickle是Python库
- 主要有以下步骤:1、人脸检测2、人脸预处理3、从收集的人脸训练机器学习算法4、人脸识别5、收尾工作人脸检测算法:基于Haar的脸部检测器的基
- 概要 简单介绍几种用于判断numpy数组是否全
- TRUNCATE TABLE (Transact-SQL)Removes all rows from a table without log
- 一.Jupyter介绍Jupyter Notebook是一个交互式笔记本,支持运行40多种编程语言。Jupyter Notebook 的本质
- Blender 并不是唯一一款允许你为场景编程和自动化任务的3D软件; 随着每一个新版本的推出,Blender 正逐渐成为一个可靠的 CG
- 标题比较麻烦,都有些叙述不清;昨天下午在调试接口框架的时候,遇到了一个问题是这样的:使用python 写了一个函数,return 了两个返回