Keras预训练的ImageNet模型实现分类操作
作者:cchangcs 公众号: hw_cch 发布时间:2021-12-14 01:33:11
标签:Keras,预训练,ImageNet,分类
本文主要介绍通过预训练的ImageNet模型实现图像分类,主要使用到的网络结构有:VGG16、InceptionV3、ResNet50、MobileNet。
代码:
import keras
import numpy as np
from keras.applications import vgg16, inception_v3, resnet50, mobilenet
# 加载模型
vgg_model = vgg16.VGG16(weights='imagenet')
inception_model = inception_v3.InceptionV3(weights='imagenet')
resnet_model = resnet50.ResNet50(weights='imagenet')
mobilenet_model = mobilenet.MobileNet(weights='imagenet')
# 导入所需的图像预处理模块
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.imagenet_utils import decode_predictions
import matplotlib.pyplot as plt
%matplotlib inline
filename= 'images/cat.jpg'
# 将图片输入到网络之前执行预处理
'''
1、加载图像,load_img
2、将图像从PIL格式转换为Numpy格式,image_to_array
3、将图像形成批次,Numpy的expand_dims
'''
# 以PIL格式加载图像
original = load_img(filename, target_size=(224, 224))
print('PIL image size', original.size)
plt.imshow(original)
plt.show()
# 将输入图像从PIL格式转换为Numpy格式
# In PIL-- 图像为(width, height, channel)
# In Numpy——图像为(height, width, channel)
numpy_image = img_to_array(original)
plt.imshow(np.uint8(numpy_image))
plt.show()
print('numpy array size', numpy_image.size)
# 将图像/图像转换为批量格式
# expand_dims将为特定轴上的数据添加额外的维度
# 网络的输入矩阵具有形式(批量大小,高度,宽度,通道)
# 因此,将额外的维度添加到轴0。
image_batch = np.expand_dims(numpy_image, axis=0)
print('image batch size', image_batch.shape)
plt.imshow(np.uint8(image_batch[0]))
# 使用各种网络进行预测
# 通过从批处理中的图像的每个通道中减去平均值来预处理输入。
# 平均值是通过从ImageNet获得的所有图像的R,G,B像素的平均值获得的三个元素的阵列
# 获得每个类的发生概率
# 将概率转换为人类可读的标签
# VGG16 网络模型
# 对输入到VGG模型的图像进行预处理
processed_image = vgg16.preprocess_input(image_batch.copy())
# 获取预测得到的属于各个类别的概率
predictions = vgg_model.predict(processed_image)
# 输出预测值
# 将预测概率转换为类别标签
# 缺省情况下将得到最有可能的五种类别
label_vgg = decode_predictions(predictions)
label_vgg
# ResNet50网络模型
# 对输入到ResNet50模型的图像进行预处理
processed_image = resnet50.preprocess_input(image_batch.copy())
# 获取预测得到的属于各个类别的概率
predictions = resnet_model.predict(processed_image)
# 将概率转换为类标签
# 如果要查看前3个预测,可以使用top参数指定它
label_resnet = decode_predictions(predictions, top=3)
label_resnet
# MobileNet网络结构
# 对输入到MobileNet模型的图像进行预处理
processed_image = mobilenet.preprocess_input(image_batch.copy())
# 获取预测得到属于各个类别的概率
predictions = mobilenet_model.predict(processed_image)
# 将概率转换为类标签
label_mobilnet = decode_predictions(predictions)
label_mobilnet
# InceptionV3网络结构
# 初始网络的输入大小与其他网络不同。 它接受大小的输入(299,299)。
# 因此,根据它加载具有目标尺寸的图像。
# 加载图像为PIL格式
original = load_img(filename, target_size=(299, 299))
# 将PIL格式的图像转换为Numpy数组
numpy_image = img_to_array(original)
# 根据批量大小重塑数据
image_batch = np.expand_dims(numpy_image, axis=0)
# 将输入图像转换为InceptionV3所能接受的格式
processed_image = inception_v3.preprocess_input(image_batch.copy())
# 获取预测得到的属于各个类别的概率
predictions = inception_model.predict(processed_image)
# 将概率转换为类标签
label_inception = decode_predictions(predictions)
label_inception
import cv2
numpy_image = np.uint8(img_to_array(original)).copy()
numpy_image = cv2.resize(numpy_image,(900,900))
cv2.putText(numpy_image, "VGG16: {}, {:.2f}".format(label_vgg[0][0][1], label_vgg[0][0][2]) , (350, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
cv2.putText(numpy_image, "MobileNet: {}, {:.2f}".format(label_mobilenet[0][0][1], label_mobilenet[0][0][2]) , (350, 75), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
cv2.putText(numpy_image, "Inception: {}, {:.2f}".format(label_inception[0][0][1], label_inception[0][0][2]) , (350, 110), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
cv2.putText(numpy_image, "ResNet50: {}, {:.2f}".format(label_resnet[0][0][1], label_resnet[0][0][2]) , (350, 145), cv2.FONT_HERSHEY_SIMPLEX, 1, (255, 0, 0), 3)
numpy_image = cv2.resize(numpy_image, (700,700))
cv2.imwrite("images/{}_output.jpg".format(filename.split('/')[-1].split('.')[0]),cv2.cvtColor(numpy_image, cv2.COLOR_RGB2BGR))
plt.figure(figsize=[10,10])
plt.imshow(numpy_image)
plt.axis('off')
训练数据:
运行结果:
来源:https://blog.csdn.net/github_39611196/article/details/88016880


猜你喜欢
- 开发人员有时候使用类似下面SQL将字符串转换为日期时间类型,乍一看,这样的SQL的写法是没有什么问题的。但是这样的SQL其实有时候就是一个定
- 前言还是之前说的项目,环境目前已经准备好了,项目准备验证阶段发现了一个问题,从上层应用输入鉴权访问应用,一直在等待状态,输入了正确的用户名及
- 最近有同学询问如何利用Python处理xml文件,特此整理一个比较简洁的操作手册,供大家参阅。首先准备一个xml文件,xml中的内容如下所示
- 目录一、什么样的备份是数据库逻辑备份呢?二、常用的逻辑备份①生成 INSERT 语句备份②生成特定格式的纯文本备份数据文件备份1.通过执行
- 环境准备python3.6PyCharm 2017.1.3Windows环境框架搭建selenium3.6安装方法:pip install
- 本人最近在当当网上购买了一本关于用户体验的书,在此把最实在的内容整理下发给大家分享下。第一步:表现层视觉设计,也就是我们说的网页设计师做的工
- 很多朋友对FrontPage2003中增加的网页布局功能很感兴趣,现在我们一起来深入了解这一实用功能。 用FrontPage200
- python中return的用法1、return语句就是把执行结果返回到调用的地方,并把程序的控制权一起返回程序运行到所遇到的第一个retu
- 针对现在大部分的网站都是使用js加密,js加载的,并不能直接抓取出来,这时候就不得不适用一些三方类库来执行js语句execjs,一个比较好用
- 节点类型主要有三种:元素节点,属性节点和文本节点。而对DOM的主要也就是围绕元素节点和属性节点的增删改查。下面就分别从对元素节点的操作和对属
- 先判断是jquery对象还是html对象, 如果是jquery对象, 可以直接用 jquery对象.attr("
- 实训课期间忙里偷闲的学习了python的selenium包,唯一一点不好是要自己去查英文文档,明摆着欺负我这种英语不好的,想着用谷歌翻译一下
- 最近做了一个项目其中有项目需求涉及到访问控制,在访问需要登录才能使用的页面或功能时,会弹出登录框:效果如下: 图 1-点击用户名时,如未登录
- 一、介绍Python:python代码解释器,用于编译.py代码,python可以单独安装,本次环境配置目的用于解决计算机视觉处理,因此选用
- var sss=(String.fromCharCode(127)); var xmlhttp =
- 本文实例讲述了C#编程实现连接ACCESS数据库的方法。分享给大家供大家参考,具体如下:一、建立FORM窗体,加一个按钮控件,加一个DATA
- 从字面意思看了一下是因为slave_pending_jobs_size_max默认值为16777216(16MB),但是slave接收到的s
- 适配器设计模式是懒得改动某些代码,或者某些接口不方便改动的时候,使用一个特定的封装,一些特定的编写办法,使不同的接口可以使用同种调用方式使用
- 前言最近在工作中遇到了这个需求,估计搞了一个多小时才把这个远程连接搞好。一台本地电脑,一台云服务器,都是linux系统。下面来看看详细的介绍
- 一.思路1.整体思路2.代码思路思路很简单,就是用python发送请求,提取响应体中的状态码加以判断,最后保存到本地txt文本中,以实现网站