tensorflow pb to tflite 精度下降详解
作者:牛蛙爹爹 发布时间:2023-05-25 19:05:41
标签:tensorflow,tflite,精度下降
之前希望在手机端使用深度模型做OCR,于是尝试在手机端部署tensorflow模型,用于图像分类。
思路主要是想使用tflite部署到安卓端,但是在使用tflite的时候发现模型的精度大幅度下降,已经不能支持业务需求了,最后就把OCR模型调用写在服务端了,但是精度下降的原因目前也没有找到,现在这里记录一下。
工作思路:
1.训练图像分类模型;2.模型固化成pb;3.由pb转成tflite文件;
但是使用python 的tf interpreter 调用tflite文件就已经出现精度下降的问题,android端部署也是一样。
1.网络结构
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
slim = tf.contrib.slim
def ttnet(images, num_classes=10, is_training=False,
dropout_keep_prob=0.5,
prediction_fn=slim.softmax,
scope='TtNet'):
end_points = {}
with tf.variable_scope(scope, 'TtNet', [images, num_classes]):
net = slim.conv2d(images, 32, [3, 3], scope='conv1')
# net = slim.conv2d(images, 64, [3, 3], scope='conv1_2')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool1')
net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn1')
# net = slim.conv2d(net, 128, [3, 3], scope='conv2_1')
net = slim.conv2d(net, 64, [3, 3], scope='conv2')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool2')
net = slim.conv2d(net, 128, [3, 3], scope='conv3')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool3')
net = slim.conv2d(net, 256, [3, 3], scope='conv4')
net = slim.max_pool2d(net, [2, 2], 2, scope='pool4')
net = slim.batch_norm(net, activation_fn=tf.nn.relu, scope='bn2')
# net = slim.conv2d(net, 512, [3, 3], scope='conv5')
# net = slim.max_pool2d(net, [2, 2], 2, scope='pool5')
net = slim.flatten(net)
end_points['Flatten'] = net
# net = slim.fully_connected(net, 1024, scope='fc3')
net = slim.dropout(net, dropout_keep_prob, is_training=is_training,
scope='dropout3')
logits = slim.fully_connected(net, num_classes, activation_fn=None,
scope='fc4')
end_points['Logits'] = logits
end_points['Predictions'] = prediction_fn(logits, scope='Predictions')
return logits, end_points
ttnet.default_image_size = 28
def ttnet_arg_scope(weight_decay=0.0):
with slim.arg_scope(
[slim.conv2d, slim.fully_connected],
weights_regularizer=slim.l2_regularizer(weight_decay),
weights_initializer=tf.truncated_normal_initializer(stddev=0.1),
activation_fn=tf.nn.relu) as sc:
return sc
基于slim,由于是一个比较简单的分类问题,网络结构也很简单,几个卷积加池化。
测试效果是很棒的。真实样本测试集能达到99%+的准确率。
2.模型固化,生成pb文件
#coding:utf-8
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow as tf
from nets import nets_factory
import cv2
import os
import numpy as np
from datasets import dataset_factory
from preprocessing import preprocessing_factory
from tensorflow.python.platform import gfile
slim = tf.contrib.slim
#todo
#support arbitray image size and num_class
tf.app.flags.DEFINE_string(
'checkpoint_path', '/tmp/tfmodel/',
'The directory where the model was written to or an absolute path to a '
'checkpoint file.')
tf.app.flags.DEFINE_string(
'model_name', 'inception_v3', 'The name of the architecture to evaluate.')
tf.app.flags.DEFINE_string(
'preprocessing_name', None, 'The name of the preprocessing to use. If left '
'as `None`, then the model_name flag is used.')
FLAGS = tf.app.flags.FLAGS
tf.app.flags.DEFINE_integer(
'eval_image_size', None, 'Eval image size')
tf.app.flags.DEFINE_integer(
'eval_image_height', None, 'Eval image height')
tf.app.flags.DEFINE_integer(
'eval_image_width', None, 'Eval image width')
tf.app.flags.DEFINE_string(
'export_path', './ttnet_1.0_37_32.pb', 'the export path of the pd file')
FLAGS = tf.app.flags.FLAGS
NUM_CLASSES = 37
def main(_):
network_fn = nets_factory.get_network_fn(
FLAGS.model_name,
num_classes=NUM_CLASSES,
is_training=False)
# pre_image = tf.placeholder(tf.float32, [None, None, 3], name='input_data')
# preprocessing_name = FLAGS.preprocessing_name or FLAGS.model_name
# image_preprocessing_fn = preprocessing_factory.get_preprocessing(
# preprocessing_name,
# is_training=False)
# image = image_preprocessing_fn(pre_image, FLAGS.eval_image_height, FLAGS.eval_image_width)
# images2 = tf.expand_dims(image, 0)
images2 = tf.placeholder(tf.float32, (None,32, 32, 3),name='input_data')
logits, endpoints = network_fn(images2)
with tf.Session() as sess:
output = tf.identity(endpoints['Predictions'],name="output_data")
with gfile.GFile(FLAGS.export_path, 'wb') as f:
f.write(sess.graph_def.SerializeToString())
if __name__ == '__main__':
tf.app.run()
3.生成tflite文件
import tensorflow as tf
graph_def_file = "/datastore1/Colonist_Lord/Colonist_Lord/workspace/models/model_files/passport_model_with_tflite/ocr_frozen.pb"
input_arrays = ["input_data"]
output_arrays = ["output_data"]
converter = tf.lite.TFLiteConverter.from_frozen_graph(
graph_def_file, input_arrays, output_arrays)
tflite_model = converter.convert()
open("converted_model.tflite", "wb").write(tflite_model)
使用pb文件进行测试,效果正常;使用tflite文件进行测试,精度下降严重。下面附上pb与tflite测试代码。
pb测试代码
with tf.gfile.GFile(graph_filename, "rb") as f:
graph_def = tf.GraphDef()
graph_def.ParseFromString(f.read())
with tf.Graph().as_default() as graph:
tf.import_graph_def(graph_def)
input_node = graph.get_tensor_by_name('import/input_data:0')
output_node = graph.get_tensor_by_name('import/output_data:0')
with tf.Session() as sess:
for image_file in image_files:
abs_path = os.path.join(image_folder, image_file)
img = cv2.imread(abs_path).astype(np.float32)
img = cv2.resize(img, (int(input_node.shape[1]), int(input_node.shape[2])))
output_data = sess.run(output_node, feed_dict={input_node: [img]})
index = np.argmax(output_data)
label = dict_laebl[index]
dst_floder = os.path.join(result_folder, label)
if not os.path.exists(dst_floder):
os.mkdir(dst_floder)
cv2.imwrite(os.path.join(dst_floder, image_file), img)
count += 1
tflite测试代码
model_path = "converted_model.tflite" #"/datastore1/Colonist_Lord/Colonist_Lord/data/passport_char/ocr.tflite"
interpreter = tf.contrib.lite.Interpreter(model_path=model_path)
interpreter.allocate_tensors()
# Get input and output tensors.
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()
for image_file in image_files:
abs_path = os.path.join(image_folder,image_file)
img = cv2.imread(abs_path).astype(np.float32)
img = cv2.resize(img, tuple(input_details[0]['shape'][1:3]))
# input_data = np.array(np.random.random_sample(input_shape), dtype=np.float32)
interpreter.set_tensor(input_details[0]['index'], [img])
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
index = np.argmax(output_data)
label = dict_laebl[index]
dst_floder = os.path.join(result_folder,label)
if not os.path.exists(dst_floder):
os.mkdir(dst_floder)
cv2.imwrite(os.path.join(dst_floder,image_file),img)
count+=1
最后也算是绕过这个问题解决了业务需求,后面有空的话,还是会花时间研究一下这个问题。
如果有哪个大佬知道原因,希望不吝赐教。
补充知识:.pb 转tflite代码,使用量化,减小体积,converter.post_training_quantize = True
import tensorflow as tf
path = "/home/python/Downloads/a.pb" # pb文件位置和文件名
inputs = ["input_images"] # 模型文件的输入节点名称
classes = ['feature_fusion/Conv_7/Sigmoid','feature_fusion/concat_3'] # 模型文件的输出节点名称
# converter = tf.contrib.lite.TocoConverter.from_frozen_graph(path, inputs, classes, input_shapes={'input_images':[1, 320, 320, 3]})
converter = tf.lite.TFLiteConverter.from_frozen_graph(path, inputs, classes,
input_shapes={'input_images': [1, 320, 320, 3]})
converter.post_training_quantize = True
tflite_model = converter.convert()
open("/home/python/Downloads/aNew.tflite", "wb").write(tflite_model)
来源:https://blog.csdn.net/qian1122221/article/details/84590248


猜你喜欢
- 本文实例为大家分享了PHP文件打包下载zip的具体代码,供大家参考,具体内容如下<?php//获取文件列表function list_
- 用 xlrd 模块读取 Excelxlrd 安装cmd 中输入pip install xlrd 即可安装 xlrd 模块若失败请自行百度”p
- python每天在指定时间段运行程序及关闭程序场景程序需要在每天某一时间段内运行,然后在某一时间段内停止该程序。程序:from dateti
- 目录1. 理解 * 和 ** 2.Python函数的参数 3. 支持任意参数的函数
- 本文实例为大家分享了linux采用binary方式安装mysql的具体步骤,供大家参考,具体内容如下1、下载binary文件在官网上下载 m
- 阅读上一篇:FrontPage XP设计教程2——网页的编辑 制作一个漂亮的网页,离不开网页整体布局的设计,网页布局设计的合理与否,直接影响
- 本文探讨了提高MySQL数据库性能的思路,并从8个方面给出了具体的解决方法。1、选取最适用的字段属性MySQL可以很好的支持大数据量的存取,
- 请定义函数,将列表[10, 1, 2, 20, 10, 3, 2, 1, 15, 20, 44, 56, 3, 2, 1]中的重复元素除去,
- 本文实例为大家分享了python制作英文字典的具体代码,供大家参考,具体内容如下功能有添加单词,多次添加单词的意思,查询,退出,建立单词文件
- 本文实例讲述了Python输出PowerPoint(ppt)文件中全部文字信息的方法。分享给大家供大家参考。具体分析如下:下面的代码依赖于w
- 起步在我的印象中,python的机制会自动清理已经完成任务的子进程的。通过网友的提问,还真看到了僵尸进程。import multiproce
- 🥩数据采集🍖确定网址王者新赛季马上就要开始了,大家都开始冲榜了,准备拿一个小省标,那么,本文,就来练习获取各地最低战力的爬虫采集实战。确定好
- 环境 python3.0工具 pycharm谷歌插件chromedriver程序执行方法from selenium import webdr
- 什么多态:同一事物有多种形态为何要有多态=》多态会带来什么样的特性,多态性多态性指的是可以在不考虑对象具体类型的情况下而直接使用对象多态指的
- Oracle游标分为显示游标和隐式游标。 显示游标(Explicit Cursor):在PL/SQL程序中定义的、用于查询的游标称作显示游标
- 最近工作转型到数据开发领域,想在本地搭建一个数据开发环境。自己有三年python开发经验,马上想到使用numpy、scipy、sklearn
- 一般我们可以使用背景图的方式给图片添加阴影,但对于不固定尺寸的图片如何实现呢?我们可以采取“视觉欺骗 * ”——定义渐变边框来实现运行代码框&
- 在定义类的过程中,无论是显式创建类的构造方法,还是向类中添加实例方法,都要求将 self 参数作为方法的第一个参数。例如,定义一个 Pers
- 需求:Python实现三次密码验证,每次验证结果需要提示,三次验证不通过需要单独提示代码如下:user = '张无忌'pas
- 获取nc数据的相关信息from netCDF4 import Datasetimport numpy as npimport pandas