使用keras内置的模型进行图片预测实例
作者:lucky404 发布时间:2021-12-27 17:54:29
keras 模块里面为我们提供了一个预训练好的模型,也就是开箱即可使用的图像识别模型
趁着国庆假期有时间我们就来看看这个预训练模型如何使用吧
可用的模型有哪些?
根据官方文档目前可用的模型大概有如下几个
1、VGG16
2、VGG19
3、ResNet50
4、InceptionResNetV2
5、InceptionV3
它们都被集成到了keras.applications 中
模型文件从哪来
当我们使用了这几个模型时,keras就会去自动下载这些已经训练好的模型保存到我们本机上面
模型文件会被下载到 ~/.keras/models/并在载入模型时自动载入
各个模型的信息:
如何使用预训练模型
使用大致分为三个步骤
1、导入所需模块
2、找一张你想预测的图像将图像转为矩阵
3、将图像矩阵放到模型中进行预测
关于图像矩阵的大小
VGG16,VGG19,ResNet50 默认输入尺寸是224x224
InceptionV3, InceptionResNetV2 模型的默认输入尺寸是299x299
代码demo
假设我现在有一张图片
我需要使用预训练模型来识别它
那我们就按照上面的步骤
第一步导入模块
from keras.applications import VGG16
from keras.applications import VGG19
from keras.applications import ResNet50
from keras.applications import InceptionV3
from keras.applications import InceptionResNetV2
第二步将图像转为矩阵
这里我们需要使用 keras.preprocessing.image 里面 img_to_array 来帮我们转
image = cv2.imread(img)
image = cv2.resize(image, self.dim)
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
第三步 将图像矩阵丢到模型中进行预测
predict = model.predict(preprocess)
decode_predict = decode_predictions(predict)
完整代码如下
1、配置文件
2、获取配置文件的模块
3、图像预测模块
配置文件
[image]
image_path=/home/fantasy/Pictures/cat.jpg
[model]
model=vgg16
[weights]
weight=imagenet
获取配置文件的模块
import configparser
cf = configparser.ConfigParser()
cf.read("configs.cnf")
def getOption(section, key):
return cf.get(section, key)
图像预测模块以及主要实现
# keras 提供了一些预训练模型,也就是开箱即用的 已经训练好的模型
# 我们可以使用这些预训练模型来进行图像识别,目前的预训练模型大概可以识别2.2w种类型的东西
# 可用的模型:
# VGG16
# VGG19
# ResNet50
# InceptionResNetV2
# InceptionV3
# 这些模型被集成到 keras.applications 中
# 当我们使用了这些内置的预训练模型时,模型文件会被下载到 ~/.keras/models/并在载入模型时自动载入
# VGG16,VGG19,ResNet50 默认输入尺寸是224x224
# InceptionV3, InceptionResNetV2 模型的默认输入尺寸是299x299
# 使用内置的预训练模型的步骤
# step1 导入需要的模型
# step2 将需要识别的图像数据转换为矩阵(矩阵的大小需要根据模型的不同而定)
# step3 将图像矩阵丢到模型里面进行预测
# -------------------------------------------------------
# step1
import cv2
import numpy as np
from getConfig import getOption
from keras.applications import VGG16
from keras.applications import VGG19
from keras.applications import ResNet50
from keras.applications import InceptionV3
from keras.applications import InceptionResNetV2
from keras.applications import imagenet_utils
from keras.applications.imagenet_utils import decode_predictions
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.inception_v3 import preprocess_input
class ImageTools(object):
"""
使用keras预训练模型进行图像识别
"""
def __init__(self, img, model, w):
self.image = img
self.model = model
self.weight = w
# step2
def image2matrix(self, img):
"""
将图像转为矩阵
"""
image = cv2.imread(img)
image = cv2.resize(image, self.dim)
image = img_to_array(image)
image = np.expand_dims(image, axis=0)
return image
@property
def dim(self):
"""
图像矩阵的维度
"""
if self.model in ["inceptionv3", "inceptionresnetv2"]:
shape = (299, 299)
else:
shape = (224, 224)
return shape
@property
def Model(self):
"""
模型
"""
models = {
"vgg16": VGG16,
"vgg19": VGG19,
"resnet50": ResNet50,
"inceptionv3": InceptionV3,
"inceptionresnetv2": InceptionResNetV2
}
return models[self.model]
# step3
def prediction(self):
"""
预测
"""
model = self.Model(weights=self.weight)
if self.model in ["inceptionv3", "inceptionresnetv2"]:
preprocess = preprocess_input(self.image2matrix(self.image))
else:
preprocess = imagenet_utils.preprocess_input(self.image2matrix(self.image))
predict = model.predict(preprocess)
decode_predict = decode_predictions(predict)
for (item, (imgId, imgLabel, proba)) in enumerate(decode_predict[0]):
print("{}, {}, {:.2f}%".format(item + 1, imgLabel, proba * 100))
if __name__ == "__main__":
image = getOption("image", "image_path")
model = getOption("model", "model")
weight = getOption("weights", "weight")
tools = ImageTools(image, model, weight)
tools.prediction()
运行起来时会将模型文件下载到本机,因此第一次运行会比较久(有可能出现的情况就是下载不了,被墙了)
我们来看看使用VGG16的模型预测输出的效果如何
最后如果大家需要使用其他模型时修改 配置文件的model 即可
来源:https://blog.csdn.net/lucky404/article/details/82931322


猜你喜欢
- 修改list中所有元素类型:方法一:new = list()a = ['1', '2', '3
- import urllib.parse,os.path,time,sys,re,urllib.requestfrom http.client
- AlexNet是2012年ImageNet比赛的冠军,虽然过去了很长时间,但是作为深度学习中的经典模型,AlexNet不但有助于我们理解其中
- 环境:win10+phpstorm2022+phpstudy8+lnmp1、phpinfo(); 查看是否安装xdebug,没有
- var request = require('request')var url = 'http://www.baid
- 前言随着科技的发展,人脸识别技术在许多领域得到的非常广泛的应用,手机支付、银行身份验证、手机人脸解锁等等。识别废话少说,这里我们使用 ope
- 前言:情人节、三八女神节、520、七夕节、圣诞节、元旦、生日、新年、各种纪念日……这些节日,对于
- 1.mysql复制概念 指将主数据库的DDL和DML操作通过二进制日志传到复制服务器上,然后在复制服务器上将这些日志文件重新执行
- 1、Motivation:I wanna modify the value of some param;I wanna check the
- 一、下载镜像docker Hub官网URL:https://hub.docker.com/_/mysql/下载最新版本:docker pul
- 刚开始使用django,在创建第一个app时被提示不知道命令runserver,百度得出是环境变量的问题。1、配置python变量环境,C:
- PHP 过滤器PHP 过滤器用于验证和过滤来自非安全来源的数据,比如用户的输入。什么是 PHP 过滤器PHP 过滤器用于验证和过滤来自非安全
- df.fillna主要用来对缺失值进行填充,可以选择填充具体的数字,或者选择临近填充。官方文档DataFrame.fillna(self,
- 本文实例为大家分享了python3实现弹弹球小游戏的具体代码,供大家参考,具体内容如下from tkinter import *from t
- 在使用Django过程中需要开发一些API给其他系统使用,为了安全把Token等验证信息放在header头中。如何获取:使用request.
- enumerate首先介绍的是enumerate函数。在我们日常编程的过程当中,经常会遇到一个问题。在C语言以及一些古老的语言当中是没有迭代
- 有些事情始终是需要坚持下去的。。。今天复习一下之前用到的连续相同数据的统计。首先,创建一个简单的测试表,这里过程就略过了,直接上表(真的是以
- 本文实例讲述了Python面向对象程序设计之继承、多态原理与用法。分享给大家供大家参考,具体如下:相关内容:继承:多继承、super、__i
- 字符串就是一个话题中心。给字符串编号在很多很多情况下,我们都要对字符串中的每个字符进行操作(具体看后面的内容),要准确进行操作,必须做的一个
- 1. 上下文管理器是什么?举个例子,你在写Python代码的时候经常将一系列操作放在一个语句块中:(1)当某条件为真 – 执行这个语句块(2