Pytorch统计参数网络参数数量方式
作者:qq_34535410 发布时间:2021-03-13 03:09:04
标签:Pytorch,统计参数,网络参数,数量
Pytorch统计参数网络参数数量
def get_parameter_number(net):
total_num = sum(p.numel() for p in net.parameters())
trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
return {'Total': total_num, 'Trainable': trainable_num}
Pytorch如何计算网络的参数量
本文以 Dense Block 为例,Pytorch 为 DL 框架,最终计算模块参数量方法如下:
import torch
import torch.nn as nn
class Norm_Conv(nn.Module):
def __init__(self,in_channel):
super(Norm_Conv,self).__init__()
self.layers = nn.Sequential(
nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel),
nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel),
nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
def forward(self,input):
out = self.layers(input)
return out
class DenseBlock_Norm(nn.Module):
def __init__(self,in_channel):
super(DenseBlock_Norm,self).__init__()
self.first_layer = nn.Sequential(nn.Conv2d(in_channel,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
self.second_layer = nn.Sequential(nn.Conv2d(in_channel*2,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
self.third_layer = nn.Sequential(
nn.Conv2d(in_channel*3,in_channel,3,1,1),
nn.ReLU(True),
nn.BatchNorm2d(in_channel))
def forward(self,input):
output1 = self.first_layer(input)
output2 = self.second_layer(torch.cat((output1,input),dim=1))
output3 = self.third_layer(torch.cat((input,output1,output2),dim=1))
return output3
def count_param(model):
param_count = 0
for param in model.parameters():
param_count += param.view(-1).size()[0]
return param_count
# Get Parameter number of Network
in_channel = 128
net1 = Norm_Conv(in_channel)
print('Norm Conv parameter count is {}'.format(count_param(net1)))
net2 = DenseBlock_Norm(in_channel)
print('DenseBlock Norm parameter count is {}'.format(count_param(net2)))
最终结果如下
Norm Conv parameter count is 443520
DenseBlock Norm parameter count is 885888
来源:https://blog.csdn.net/qq_34535410/article/details/89715192
0
投稿
猜你喜欢
- 导语承载童年的纸飞机你还会叠嘛?如果你是个80后或者90后,那你应该记得小时候玩的纸飞机。叠好后,哈口仙气,飞出去,感觉棒棒哒。虽然是一个极
- 以前我一直用os.system()处理一些系统管理任务,因为我认为那是运行linux命令最简单的方式.我们能从Python官方文档里读到应该
- 今天的这篇文章是讲XHTML中的细节部分的,这篇续述的主题就是ID与CLASS怎么用,在标题中有提及使用原则与技巧,这里的使用原则与技巧是我
- 有时一些网页对源码进行了加密,我们很难找到类似像“onselectstart="return false"”这样的代码,
- 本文实例讲述了Laravel框架执行原生SQL语句及使用paginate分页的方法。分享给大家供大家参考,具体如下:1、运行原生sqlpub
- 表单的验证是开发WEB应用程序中常遇到的一关。有时候我们必须保证表单的某些项必须填写、必须为数字、必须是指定的位数等等,这时候就要用到表单验
- 正三角形九九乘法表#正三角形九九乘法表for i in range(1,10): for j in range(1
- 1. 实例描述在平时编程的过程中,会经常在网上翻译一些单词,本文使用Python制作一款翻译小工具,不仅可以自己用,还可以嵌入到程序当中。运
- import threadingfrom time import sleepdef test_func(id): &n
- 深入理解python try异常处理机制#python的try语句有两种风格#一:种是处理异常(try/except/else)#二:种是无
- 本文研究的主要是python模块之paramiko的相关用法,具体实现代码如下,一起来看看。paramiko模块提供了ssh及sft进行远程
- 原理就是先声明常量,包括列数,行数,各列的属性,然后在程序的其它过程用这些常量来控制Cells。非常方便,便于修改和移植! 以下为窗体整体代
- 在windows下安装配置Ulipad今天推荐一款轻便的文本编辑器Ulipad,用来写一些小的Python脚本非常方便。Ulipad下载地址
- 如下所示:>>> import pandas as pd>>> import numpy as np#
- 语音识别是人工智能中的一个领域,它允许计算机理解人类语音并将其转换为文本。该技术用于 Alexa 和各种聊天机器人应用程序等设备。而我们最常
- Windows 10 x64macOS Sierra 10.12.4Python 2.7准备好装哔~了么,来吧,做个真正意义上的绿色小软件W
- 如何在SQL中启用全文检索功能?本文将通过实例向你剖折这个问题。这是一个全文索引的一个例子,首先在查询分析器中使用:use pubsgo--
- 1、可以使用"+"号完成操作输出为:[1, 2, 3, 8, 'google', 'com
- var InterestKeywordListString = $("#userInterestKeywordLabel"
- Python tcp socket编程详解初学脚本语言Python,测试可用的tcp通讯程序:服务器:#!/usr/bin/env pyth