Python tensorflow与pytorch的浮点运算数如何计算
作者:浩哥依然 发布时间:2023-06-28 14:13:15
1. 引言
FLOPs 是 floating point operations 的缩写,指浮点运算数,可以用来衡量模型/算法的计算复杂度。本文主要讨论如何在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算对应模型的 FLOPs。
2. 模型结构
为了说明方便,先搭建一个简单的神经网络模型,其模型结构以及主要参数如表1 所示。
表 1 模型结构及主要参数
Layers | channels | Kernels | Strides | Units | Activation |
---|---|---|---|---|---|
Conv2D | 32 | (4,4) | (1,2) | \ | relu |
GRU | \ | \ | \ | 96 | \ |
Dense | \ | \ | \ | 256 | sigmoid |
用 tensorflow(实际使用 tensorflow 中的 keras 模块)实现该模型的代码为:
from tensorflow.keras.layers import *
from tensorflow.keras.models import load_model, Model
def test_model_tf(Input_shape):
# shape: [B, C, T, F]
main_input = Input(batch_shape=Input_shape, name='main_inputs')
conv = Conv2D(32, kernel_size=(4, 4), strides=(1, 2), activation='relu', data_format='channels_first', name='conv')(main_input)
# shape: [B, T, FC]
gru = Reshape((conv.shape[2], conv.shape[1] * conv.shape[3]))(conv)
gru = GRU(units=96, reset_after=True, return_sequences=True, name='gru')(gru)
output = Dense(256, activation='sigmoid', name='output')(gru)
model = Model(inputs=[main_input], outputs=[output])
return model
用 pytorch 实现该模型的代码为:
import torch
import torch.nn as nn
class test_model_torch(nn.Module):
def __init__(self):
super(test_model_torch, self).__init__()
self.conv2d = nn.Conv2d(in_channels=1, out_channels=32, kernel_size=(4,4), stride=(1,2))
self.relu = nn.ReLU()
self.gru = nn.GRU(input_size=4064, hidden_size=96)
self.fc = nn.Linear(96, 256)
self.sigmoid = nn.Sigmoid()
def forward(self, inputs):
# shape: [B, C, T, F]
out = self.conv2d(inputs)
out = self.relu(out)
# shape: [B, T, FC]
batch, channel, frame, freq = out.size()
out = torch.reshape(out, (batch, frame, freq*channel))
out, _ = self.gru(out)
out = self.fc(out)
out = self.sigmoid(out)
return out
3. 计算模型的 FLOPs
本节讨论的版本具体为:tensorflow 1.12.0, tensorflow 2.3.1 以及 pytorch 1.10.1+cu102。
3.1. tensorflow 1.12.0
在 tensorflow 1.12.0 环境中,可以使用以下代码计算模型的 FLOPs:
import tensorflow as tf
import tensorflow.keras.backend as K
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 1.12.0:', get_flops(model))
3.2. tensorflow 2.3.1
在 tensorflow 2.3.1 环境中,可以使用以下代码计算模型的 FLOPs :
import tensorflow.compat.v1 as tf
import tensorflow.compat.v1.keras.backend as K
tf.disable_eager_execution()
def get_flops(model):
run_meta = tf.RunMetadata()
opts = tf.profiler.ProfileOptionBuilder.float_operation()
flops = tf.profiler.profile(graph=K.get_session().graph,
run_meta=run_meta, cmd='op', options=opts)
return flops.total_float_ops
if __name__ == "__main__":
x = K.random_normal(shape=(1, 1, 100, 256))
model = test_model_tf(x.shape)
print('FLOPs of tensorflow 2.3.1:', get_flops(model))
3.3. pytorch 1.10.1+cu102
在 pytorch 1.10.1+cu102 环境中,可以使用以下代码计算模型的 FLOPs(需要安装 thop):
import thop
x = torch.randn(1, 1, 100, 256)
model = test_model_torch()
flops, _ = thop.profile(model, inputs=(x,))
print('FLOPs of pytorch 1.10.1:', flops * 2)
需要注意的是,thop 返回的是 MACs (Multiply–Accumulate Operations),其等于 2 2 2 倍的 FLOPs,所以上述代码有乘 2 2 2 操作。
3.4. 结果对比
三者计算出的 FLOPs 分别为:
tensorflow 1.12.0:
tensorflow 2.3.1:
pytorch 1.10.1:
可以看到 tensorflow 1.12.0 和 tensorflow 2.3.1 的结果基本在同一个量级,而与 pytorch 1.10.1 计算出来的相差甚远。但如果将上述模型结构改为只包含第一层 Conv2D,三者计算出来的 FLOPs 却又是一致的。所以推断差异主要来自于 GRU 的 FLOPs。如读者知道其中详情,还请不吝赐教。
4. 总结
本文给出了在 tensorflow 1.x, tensorflow 2.x 以及 pytorch 中利用相关工具计算模型 FLOPs 的方法,但从本文所使用的测试模型来看, tensorflow 与 pytorch 统计出的结果相差甚远。当然,也可以根据网络层的类型及其对应的参数,推导计算出每个网络层所需的 FLOPs。
来源:https://blog.csdn.net/wjrenxinlei/article/details/127973081
猜你喜欢
- 1.requests库简介requests 是 Python 中比较常用的网页请求库,主要用来发送 HTTP 请求,在使用爬虫或测试服务器响
- 一个网站的一个页面download.asp通过判断referer来确定是不是从他本站点过来的链接,使用这个功能我们可以用来防止下载盗链,当然
- 一、首先我们来填个坑支付验签失败这个问题折磨了我两天,官方文档比较含糊不清。各种百度下来的方法试过之后也不尽人意,最后发现问题是没有二次签名
- 在国外一博客看到的技巧,终于解决IE的这个老大难问题。我在IE的setAttribute bug也提到其解决方法,一是innerHTML,一
- python 中提供一种用于对函数固定属性的函数(与数学上的偏函数不一样)# 通常会返回10进制int('12345') &
- 本文实例讲述了python连接字符串的方法。分享给大家供大家参考。具体如下:方法1:直接通过加号操作符相加foobar = 'foo
- 代码和说明如下:<%Const ForReading = 1 &nbs
- 1 Kmean图像分割按照Kmean原理,对图像像素进行聚类。优点:此方法原理简单,效果显著。缺点:实践发现对于前景和背景颜色相近或者颜色区
- 本文实例讲述了Python网络编程之TCP与UDP协议套接字用法。分享给大家供大家参考,具体如下:TCP协议服务器端:#!/usr/bin/
- 抢票是并发执行多个进程可以访问同一个文件多个进程共享同一文件,我们可以把文件当数据库,用多个进程模拟多个人执行抢票任务db.tx
- 不论什么语言,我们都需要注意性能优化问题,提高执行效率。选择了脚本语言就要忍受其速度,这句话在某种程度上说明了Python作为脚本语言的不足
- 本文实例为大家分享了python实现单链表反转的具体代码,供大家参考,具体内容如下代码如下:class Node(object): 
- 【译者的话】 作为一家非盈利性的防止青少年 * 的机构, Five Alive 希望拥有一个独特的标志来配合机构的宣传。他们决定在网站上通过竞
- jxdawei的个人博客:http://www.iwcn.net本文目的:与您分享如何学习基于web标准的网页制作。适合人群:网页制作初学者
- 圆形的绘制 :OpenCV中使用circle(img,center,radius,color,thickness=None,lineType
- 不知道大家有没有这样一个烦恼,“自己的电脑总是被别人使用,又不好意思设置密码”,所以利用python设计了一个程序来实现自由管控。功能虽然简
- <script type="text/javascript"> // Close HTML Tags ---
- 目前绝大多数手机都支持WAP 2.0。WAP 2.0的页面设计具有更好的视觉效果,更接近网页。不过由于手机千差万别,手机浏览器的能力也各不相
- 本文实例讲述了Python实现更改图片尺寸大小的方法。分享给大家供大家参考,具体如下:1、PIL包推荐Pillow 。2、源码:#encod
- 字符串在Python内部的表示是Unicode编码,因此,在做编码转换时,通常需要以Unicode作为中间编码,即先将其他编码的字符串解码(