如何计算 tensorflow 和 pytorch 模型的浮点运算数
作者:浩哥依然 发布时间:2023-07-17 04:20:58
本文主要讨论如何计算 tensorflow 和 pytorch 模型的 FLOPs。如有表述不当之处欢迎批评指正。欢迎任何形式的转载,但请务必注明出处。
1. 引言
FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。
2. 模型结构
为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。
表 1 模型结构及主要参数
Layers | channels | Kernels | Strides | Units | Activation |
---|---|---|---|---|---|
Conv2D | 32 | (4,4) | (1,2) | \ | relu |
GRU | \ | \ | \ | 96 | \ |
Dense | \ | \ | \ | 256 | sigmoid |
用 tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:
from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model
def test_model_tf(Input_shape):
# shape: [B, C, T, F]
main_input = Input(batch_shape=Input_shape, name='main_inputs')
conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)
# shape: [B, T, FC]
gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)
output = Dense(256, activation='sigmoid', name='output')(gru)
model = Model(inputs=[main_input], outputs=[output])
return model
用 pytorch 实现该模型的代码为:
import torch
import torch.nn as nn
class test_model_torch(nn.Module):
def __init__(self):
super(test_model_torch, self).__init__()
self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))
self.relu = nn.ReLU()
self.gru = nn.GRU(input_size=4064, hidden_size=96)
self.fc = nn.Linear(96, 256)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
# shape: [B, C, T, F]
out = self.conv2d(inputs)
out = self.relu(out)
# shape: [B, T, FC]
batch, channel, frame, freq = out.size()
out = torch.reshape(out, (batch, frame, freq*channel))
out, _ = self.gru(out)
out = self.fc(out)
out = self.sigmoid(out)
return out
3. 计算模型的 FLOPs
本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。
3.1. tensorflow 1.12.0
在 tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs:
import tensorflow as tf
import tensorflow.keras.backend as K
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 1.12.0:', get_flops(model))
3.2. tensorflow 2.3.1
在 tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs :
import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 2.3.1:', get_flops(model))
3.3. pytorch 1.10.1+cu102
在 pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):
import thop
x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print('FLOPs of pytorch 1.10.1:', flops * 2)
需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。
3.4. 结果对比
三者计算出的 FLOPs 分别为:
tensorflow 1.12.0:
tensorflow 2.3.1:
pytorch 1.10.1:
可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRU 的 FLOPs。如读者知道其中详情,还请不吝赐教。
4. 总结
本文给出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算模型 FLOPs 的方法,但从本文所使用的测试模型来看, tensorflow 与 pytorch 统计出的结果相差甚远。当然,也可以根据网络层的类型及其对应的参数,推导计算出每个网络层所需的 FLOPs。
来源:https://blog.csdn.net/wjrenxinlei/article/details/127973081
猜你喜欢
- 一定要注重代码规范,按照平时的代码管理,可以将Python代码规范检测分为两种:静态本地检测:可以借助静态检查工具,比如:Flake8,Py
- Python安装过程,供大家参考,具体内容如下1.下载安装程序我们安装Python的一个重要目的是为了用IAR编译CC2640 OAD文件时
- 楔子在 Python3.6 之前,格式化字符串一般会使用百分号占位符或者 format 函数,举个例子:name = &
- eval()函数eval() 函数用来执行一个字符串表达式,并返回表达式的值。语法eval(expression[, globals[, l
- 本文介绍基于Anaconda环境以及PyCharm软件结合,安装PyTorch深度学习框架。一、anaconda安装(一)下载官网下载链接:
- Python之Web框架Django项目搭建全过程IDE说明:Win7系统Python:3.5Django:1.10Pymysql:0.7.
- 大家好,之前分享过很多关于 Pandas 的文章,今天我给大家分享5个小而美的 Pandas 实战案例。内容主要分为:如何自行模拟数据多种数
- 我们可以把表里每一个横行的数据,看成是不同的元组。在理解了这个概念后,昨天我们学了不少的namedtuple类,是否也能把元组转换成name
- matplotlib官方文档:https://matplotlib.org/stable/users/index.htmlmatplotli
- K线数据提取依据原有数据集格式,按要求生成新表:1、每分钟的close数据的第一条、最后一条、最大值及最小值,2、每分钟vol数据的增长量(
- 一、前言设计应用程序时,有时不希望将一个不太相关的功能集成到程序中,或者是因为该功能与当前设计的应用程序联系不大,或者是因为该功能已经可以使
- 目录1. 前言2. 准备3. 实战1、获取目标应用的包名及初始化 Activity2、获取所有在线的设备3、群控打开目标应用4、封装执行步骤
- 0. 前言周日在爬一个国外网站的时候,发现用协程并发请求,并且请求次数太快的时候,会出现对方把我的服务器IP封掉的情况。于是网上找了一下开源
- 导语前段时间有小伙伴留言说想让我带大家写写桌面小挂件,今天就满足一下留过类似言的小伙伴的请求呗~不过感觉写桌面的挂历啥的没意思,就简单带大家
- 很多用户在网站上会糊弄填写一个电子信箱,请问有什么办法可以阻止这种行为?我们通常用两种方法来进行判断:第一种,设定只有形如aspxhome@
- 本文实例讲述了Django实现简单分页功能的方法。分享给大家供大家参考,具体如下:使用django的第三方模块django-pure-pag
- 准备篇:CentOS 6.6系统安装配置图解教程https://www.jb51.net/os/239738.html一、配置防火墙,开启8
- 需求:需求简单:但是感觉最后那部分遍历有意思:S型数组赋值,考虑到下标,简单题先实现个差不多的m = 5cols = 9rows = 4nu
- 具体内容如下:1 os.system例如 ipython中运行如下命令,返回运行状态statusos.system('cat /et
- 一、vim python自动补全插件:pydiction 可以实现下面python代码的自动补全:1.简单python关键词补全 2.pyt