keras实现VGG16方式(预测一张图片)
作者:程序员孙大圣 发布时间:2024-01-04 00:36:53
标签:keras,VGG16,预测,图片
我就废话不多说了,大家还是直接看代码吧~
from keras.applications.vgg16 import VGG16#直接导入已经训练好的VGG16网络
from keras.preprocessing.image import load_img#load_image作用是载入图片
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions
model = VGG16()
image = load_img('D:\\photo\\dog.jpg',target_size=(224,224))#参数target_size用于设置目标的大小,如此一来无论载入的原图像大小如何,都会被标准化成统一的大小,这样做是为了向神经网络中方便地输入数据所需的。
image = img_to_array(image)#函数img_to_array会把图像中的像素数据转化成NumPy中的array,这样数据才可以被Keras所使用。
#神经网络接收一张或多张图像作为输入,也就是说,输入的array需要有4个维度: samples, rows, columns, and channels。由于我们仅有一个 sample(即一张image),我们需要对这个array进行reshape操作。
image = image.reshape((1,image.shape[0],image.shape[1],image.shape[2]))
image = preprocess_input(image)#对图像进行预处理
y = model.predict(image)#预测图像的类别
label = decode_predictions(y)#Keras提供了一个函数decode_predictions(),用以对已经得到的预测向量进行解读。该函数返回一个类别列表,以及类别中每个类别的预测概率,
label = label[0][0]
print('%s(%.2f%%)'%(label[1],label[2]*100))
# print(model.summary())
from keras.models import Sequential
from keras.layers.core import Flatten,Dense,Dropout
from keras.layers.convolutional import Convolution2D,MaxPooling2D,ZeroPadding2D
from keras.optimizers import SGD
import numpy as np
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
import time
from keras import backend as K
K.set_image_dim_ordering('th')
def VGG_16(weights_path=None):
model = Sequential()
model.add(ZeroPadding2D((1, 1), input_shape=(3, 224, 224)))
model.add(Convolution2D(64, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1000, activation='softmax'))
if weights_path:
model.load_weights(weights_path,by_name=True)
return model
model = VGG_16(weights_path='F:\\Kaggle\\vgg16_weights.h5')
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy')
t0 = time.time()
img = image.load_img('D:\\photo\\dog.jpg', target_size=(224, 224))
x = image.img_to_array(img) # 三维(224,224,3)
x = np.expand_dims(x, axis=0) # 四维(1,224,224,3)#因为keras要求的维度是这样的,所以要增加一个维度
x = preprocess_input(x) # 预处理
print(x.shape)
y_pred = model.predict(x) # 预测概率
t1 = time.time()
print("测试图:", decode_predictions(y_pred)) # 输出五个最高概率(类名, 语义概念, 预测概率)
print("耗时:", str((t1 - t0) * 1000), "ms")
这是两种不同的方式,第一种是直接使用vgg16的参数,需要在运行时下载,第二种是我们已经下载好的权重,直接在参数中输入我们的路径即可。
补充知识:keras加经典网络的预训练模型(以VGG16为例)
我就废话不多说了,大家还是直接看代码吧~
# 使用VGG16模型
from keras.applications.vgg16 import VGG16
print('Start build VGG16 -------')
# 获取vgg16的卷积部分,如果要获取整个vgg16网络需要设置:include_top=True
model_vgg16_conv = VGG16(weights='imagenet', include_top=False)
model_vgg16_conv.summary()
# 创建自己的输入格式
# if K.image_data_format() == 'channels_first':
# input_shape = (3, img_width, img_height)
# else:
# input_shape = (img_width, img_height, 3)
input = Input(input_shape, name = 'image_input') # 注意,Keras有个层就是Input层
# 将vgg16模型原始输入转换成自己的输入
output_vgg16_conv = model_vgg16_conv(input)
# output_vgg16_conv是包含了vgg16的卷积层,下面我需要做二分类任务,所以需要添加自己的全连接层
x = Flatten(name='flatten')(output_vgg16_conv)
x = Dense(4096, activation='relu', name='fc1')(x)
x = Dense(512, activation='relu', name='fc2')(x)
x = Dense(128, activation='relu', name='fc3')(x)
x = Dense(1, activation='softmax', name='predictions')(x)
# 最终创建出自己的vgg16模型
my_model = Model(input=input, output=x)
# 下面的模型输出中,vgg16的层和参数不会显示出,但是这些参数在训练的时候会更改
print('\nThis is my vgg16 model for the task')
my_model.summary()
来源:https://blog.csdn.net/sunshunli/article/details/81456566


猜你喜欢
- 每个存储过程都有默认的返回值,默认值为0。下面我们分别看看在management studio中如何查看输出参数,返回值以及结果集,然后我们
- 第一步:保存下列文件为:CALENDAR.ASP <%@ LANGUAGE = V
- 前言每种编程语言为了表现出色,并且实现卓越的性能,都需要有大量编译器级与解释器级的优化。由于字符串是任何编程语言中不可或缺的一个部分,因此,
- 内容摘要合理使用渐变留白网格布局提高字体应用明确而有效的导航设计漂亮、有用的页脚介绍优秀设计和卓越设计之间的区别是比较小的。一般人可能无法解
- 1.安装MySql目前MySQL有两种形式的文件,一个是msi格式,一个是zip格式的。msi格式的直接点击setup.exe就好,按照步骤
- 本文实例为大家分享了Python时间戳使用和相互转换的具体代码,供大家参考,具体内容如下1.将字符串的时间转换为时间戳方法: &n
- 最近对 Range 和 Selection 比较感兴趣。基本非 IE 的浏览器都支持 DOM Level2 中的 Range,而 IE 中仅
- 听到一些人说现在做产品设计很没有成就感。没有什么创造力,除了抄袭模仿(称之为竞争分析)、千篇一律(又称规范标准)还有复杂的流程、粗制滥造的表
- 1、创建存储过程 create or replace procedure test(var_name_1 in type,var_name_
- 本文实例讲述了python单元测试unittest用法。分享给大家供大家参考。具体分析如下:单元测试作为任何语言的开发者都应该是必要的,因为
- SQL 标准使用 CREATE TABLE 语句创建数据表;MySQL 则实现了三种创建表的方法,支持自定义表结构或者通过复制已有的表结构来
- 软件版本及平台:MySQL-5.7.17-winx64,win7家庭版一、下载安装包https://cdn.mysql.com//Downl
- 制作圆角导航其实跟制作圆角边框是一样的道理,有一种很常见的方法就是使用CSS绝对定位,切四个圆角的小图片,然后分别定位在四个角,这样就可以实
- 写项目时,发现 element 里的图标没有我需要的图标,两种情况:① 简单的替换小图标,没有选中变色等要求② 有选中变色等要求,稍微复杂的
- 一、前端工具vscode1.1、概述前端开发是创建Web页面或app等前端界面呈现给用户的过程,通过HTML,CSS及JavaScript以
- 本人在CentOS6.4上安装万mysql后,无法通过root进入,因为安装的时候,并没有设置root密码,似乎有个初始随机密码,但是不记得
- 导语昨天看到有留言竟然说我是月更博主,我明明更新地这么勤快(心虚.jpg)。看吧,昨天刚更新过,今天又来更新了。今天还是带大家写个小游戏吧,
- 但是如果在utf-8编码下,一个汉字是占3个字符长度的,比如字符串$str=”你好啊!!”; 如果你用strlen函数来判断,长度是11,正
- 实现一个树形表格的时候有多种方法:比如把 ztree 的树形直接拼接成表格,或者用强大的 jqgrid 实现,今天介绍一个比较轻量级的实现:
- 一 模板语法传值方式一:# urls.pypath('template', views.template)# views.p