Tensorflow加载模型实现图像分类识别流程详解
作者:技术老匠 发布时间:2023-12-22 02:31:13
前言
深度学习框架在市面上有很多。比如Theano、Caffe、CNTK、MXnet 、Tensorflow等。今天讲解的就是主角Tensorflow。Tensorflow的前身是Google大脑项目的一个分布式机器学习训练框架,它是一个十分基础且集成度很高的系统,它的目标就是为研究超大型规模的视觉项目,后面延申到各个领域。Tensorflow 在2015年正式开源,开源的一个月内就收获到1w多的starts,这足以说明Tensorflow的优越性以及Google的影响力。在Api方面Tensorflow为了满足绝大部分的开发者需求,这也是Google的一贯作风,集成了Java、Go、Python、C++等编程语言。
图像识别是一件很有趣的事,话不多说,咱们先了解下特征提取VGG in Tensorflow。官网地址:VGG in TensorFlow · Davi Frossard。
VGG 是牛津大学的 K. Simonyan 和 A. Zisserman 在论文“Very Deep Convolutional Networks for Large-Scale Image Recognition”中提出的卷积神经网络模型。该模型在 ImageNet 中实现了 92.7% 的 top-5 测试准确率,这是一个包含 1000 个类别的超过 1400 万张图像的数据集。 在这篇简短的文章中,我们提供了 VGG16 的实现以及从原始 Caffe 模型转换为 TensorFlow 的权重。这句话是VGGNet官方的介绍,直接从它提供的数字可以看出来,它的识别率是十分高的,是不是很激动,动起手来吧。
开发步骤分4步,如下所示:
a) 依赖加载
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os
import scipy.io
import scipy.misc
from imagenet_classes import class_names
b)定义卷积、池化等函数
def _conv_layer(input,weight,bias):
conv = tf.nn.conv2d(input,weight,strides=[1,1,1,1],padding="SAME")
return tf.nn.bias_add(conv,bias)
def _pool_layer(input):
return tf.nn.max_pool(input,ksize=[1,2,2,1],strides=[1,2,2,1],padding="SAME")
def preprocess(image,mean_pixel):
'''简单预处理,全部图片减去平均值'''
return image-mean_pixel
def unprocess(image,mean_pixel):
return image+mean_pixel
c)图像的读取以及保存
def imread(path):
return scipy.misc.imread(path)
def imsave(image,path):
img = np.clip(image,0,255).astype(np.int8)
scipy.misc.imsave(path,image)
d) 定义网络结构,这里使用的是VGG19
def net(data_path,input_image,sess=None):
"""
读取VGG模型参数,搭建VGG网络
:param data_path: VGG模型文件位置
:param input_image: 输入测试图像
:return:
"""
layers = (
'conv1_1', 'conv1_2', 'pool1',
'conv2_1', 'conv2_2', 'pool2',
'conv3_1', 'conv3_2', 'conv3_3','conv3_4', 'pool3',
'conv4_1', 'conv4_2', 'conv4_3','conv4_4', 'pool4',
'conv5_1', 'conv5_2', 'conv5_3','conv5_4', 'pool5',
'fc1' , 'fc2' , 'fc3' ,
'softmax'
)
data = scipy.io.loadmat(data_path)
mean = data["normalization"][0][0][0][0][0]
input_image = np.array([preprocess(input_image, mean)]).astype(np.float32)#去除平均值
net = {}
current = input_image
net["src_image"] = tf.constant(current) # 存储数据
count = 0 #计数存储
for i in range(43):
if str(data['layers'][0][i][0][0][0][0])[:4] == ("relu"):
continue
if str(data['layers'][0][i][0][0][0][0])[:4] == ("pool"):
current = _pool_layer(current)
elif str(data['layers'][0][i][0][0][0][0]) == ("softmax"):
current = tf.nn.softmax(current)
elif i == (37):
shape = int(np.prod(current.get_shape()[1:]))
current = tf.reshape(current, [-1, shape])
kernels, bias = data['layers'][0][i][0][0][0][0]
kernels = np.reshape(kernels,[-1,4096])
bias = bias.reshape(-1)
current = tf.nn.relu(tf.add(tf.matmul(current,kernels),bias))
elif i == (39):
kernels, bias = data['layers'][0][i][0][0][0][0]
kernels = np.reshape(kernels,[4096,4096])
bias = bias.reshape(-1)
current = tf.nn.relu(tf.add(tf.matmul(current,kernels),bias))
elif i == 41:
kernels, bias = data['layers'][0][i][0][0][0][0]
kernels = np.reshape(kernels, [4096, 1000])
bias = bias.reshape(-1)
current = tf.add(tf.matmul(current, kernels), bias)
else:
kernels,bias = data['layers'][0][i][0][0][0][0]
#注意VGG存储方式为[,]
#kernels = np.transpose(kernels,[1,0,2,3])
bias = bias.reshape(-1)#降低维度
current = tf.nn.relu(_conv_layer(current,kernels,bias))
net[layers[count]] = current #存储数据
count += 1
return net, mean
e)加载模型进行识别
if __name__ == '__main__':
VGG_PATH = "./one/imagenet-vgg-verydeep-19.mat"
IMG_PATH = './one/3.jpg'
input_image =imread(IMG_PATH)
shape = (1, input_image.shape[0], input_image.shape[1], input_image.shape[2])
with tf.Session() as sess:
image = tf.placeholder('float', shape=shape)
nets, mean_pixel, all_layers= net(VGG_PATH, image)
input_image_pre=np.array([preprocess(input_image,mean_pixel)])
layers = all_layers
for i , layer in enumerate(layers):
print("[%d/%d] %s" % (i+1,len(layers),layers))
features = nets[layer].eval(feed_dict={image:input_image_pre})
print("Type of 'feature' is ",type(features))
print("Shape of 'features' is %s" % (features.shape,))
if 1:
plt.figure(i+1,figsize=(10,5))
plt.matshow(features[0,:,:,0],cmap=plt.cm.gray,fignum=i+1)
plt.title(""+layer)
plt.colorbar()
plt.show()
VGG19网络介绍
VGG19 的宏观架构如图所示。我们在 TensorFlow 中的文件 vgg19.py 中对其进行编码。请注意,我们包含一个预处理层,它采用像素值在 0-255 范围内的 RGB 图像并减去平均图像值(在整个 ImageNet 训练集上计算)。
来源:https://blog.csdn.net/qq_33011831/article/details/126919632


猜你喜欢
- 开启Web服务1.基本方式Python中自带了简单的服务器程序,能较容易地打开服务。在python3中将原来的SimpleHTTPServe
- 最近开始学习Python,但只限于看理论,编几行代码,觉得没有意思,就想能不能用Python编写可视化的界面。遂查找了相关资料,发现了PyQ
- 好多同志对 iframe 是如何控制的,并不是十分了解,基本上还处于一个模糊的认识状态.注意两个事项,ifr 是一个以存在的 iframe
- 来看一个实例:<!DOCTYPE html><html lang="en"><head&g
- 目录一、Python 中的作用域规则和嵌套函数二、定义闭包函数三、何时使用闭包?四、总结一、Python 中的作用域规则和嵌套函数每当执行一
- 本文实例讲述了golang实现的文件上传与文件下载功能。分享给大家供大家参考,具体如下:upload.gopackage commonimp
- MySQL数据库具有跨平台性,不仅可以在Windows上运行,还可以在UNIX,Linux和Mac OS等操作系统上运行 1.先简
- 插件很多从事互联网行业或者开发的人员来不是很陌生,wordpress之所以为什么那么受欢迎,很大部分是因为他的强大的插件库,还要譬如就是大家
- 某些时候我们需要让类动态的添加属性或方法,比如我们在做插件时就可以采用这种方法。用一个配置文件指定需要加载的模块,可以根据业务扩展任意加入需
- 前言在启动 Django 项目时,Django 默认监听的端口号为 8000,设置的默认 IP 地址为 127.0.0.1 。如果需要修改默
- 组件间通信的概念开始之前,我们把组件间通信这个词进行拆分组件通信都知道组件是vue最强大的功能之一,vue中每一个.vue文件我们都可以视之
- django常见数据库配置错误出现报错代码为1045的这类几乎都是数据库配置出错报错1django.db.utils.Operational
- Mac下mysql安装配置方法图文教程记录如下使用安装包安装mysql双击pkg文件安装一路向下,记得保存最后弹出框中的密码(它是你mysq
- 简介这个模块处理python中常见类型数据和Python bytes之间转换。这可用于处理存储在文件或网络连接中的bytes数据以及其他来源
- auto-vue-fileauto create .vue file by shell command通过终端自动创建vue文件前言:1:
- 引言善于观察的朋友一定会敏锐地发现ChatGPT网页端是逐句给出问题答案的,同样,ChatGPT后台Api接口请求中,如果将Stream参数
- EXCEL的数值排序功能还是挺强大的,升序、降序,尤其自定义排序,能够对多个字段进行排序工作。那么,在Python * 中,有没有这样强大的排
- SELECT TABLE_SCHEMA,TABLE_NAME FROM information_schema.`COLUMNS` WHERE
- 在实际的工作中,尤其是在生产环境里边,SQL语句的优化问题十分的重要,它对数据库的性能的提升也起着显著的作用.我们总是在抱怨机器的性能问题,
- 最近基于selenium写了一个python小工具,记录下学习记录,自己运行的环境是Ubuntu 14.04.4, Python 2.7,C