keras实现VGG16方式(预测一张图片)
作者:程序员孙大圣 发布时间:2024-01-04 00:36:53
标签:keras,VGG16,预测,图片
我就废话不多说了,大家还是直接看代码吧~
from keras.applications.vgg16 import VGG16#直接导入已经训练好的VGG16网络
from keras.preprocessing.image import load_img#load_image作用是载入图片
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.applications.vgg16 import decode_predictions
model = VGG16()
image = load_img('D:\\photo\\dog.jpg',target_size=(224,224))#参数target_size用于设置目标的大小,如此一来无论载入的原图像大小如何,都会被标准化成统一的大小,这样做是为了向神经网络中方便地输入数据所需的。
image = img_to_array(image)#函数img_to_array会把图像中的像素数据转化成NumPy中的array,这样数据才可以被Keras所使用。
#神经网络接收一张或多张图像作为输入,也就是说,输入的array需要有4个维度: samples, rows, columns, and channels。由于我们仅有一个 sample(即一张image),我们需要对这个array进行reshape操作。
image = image.reshape((1,image.shape[0],image.shape[1],image.shape[2]))
image = preprocess_input(image)#对图像进行预处理
y = model.predict(image)#预测图像的类别
label = decode_predictions(y)#Keras提供了一个函数decode_predictions(),用以对已经得到的预测向量进行解读。该函数返回一个类别列表,以及类别中每个类别的预测概率,
label = label[0][0]
print('%s(%.2f%%)'%(label[1],label[2]*100))
# print(model.summary())
from keras.models import Sequential
from keras.layers.core import Flatten,Dense,Dropout
from keras.layers.convolutional import Convolution2D,MaxPooling2D,ZeroPadding2D
from keras.optimizers import SGD
import numpy as np
from keras.preprocessing import image
from keras.applications.imagenet_utils import preprocess_input, decode_predictions
import time
from keras import backend as K
K.set_image_dim_ordering('th')
def VGG_16(weights_path=None):
model = Sequential()
model.add(ZeroPadding2D((1, 1), input_shape=(3, 224, 224)))
model.add(Convolution2D(64, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(64, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(128, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(256, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(ZeroPadding2D((1, 1)))
model.add(Convolution2D(512, (3, 3), activation='relu'))
model.add(MaxPooling2D((2, 2), strides=(2, 2)))
model.add(Flatten())
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(4096, activation='relu'))
model.add(Dropout(0.5))
model.add(Dense(1000, activation='softmax'))
if weights_path:
model.load_weights(weights_path,by_name=True)
return model
model = VGG_16(weights_path='F:\\Kaggle\\vgg16_weights.h5')
sgd = SGD(lr=0.1, decay=1e-6, momentum=0.9, nesterov=True)
model.compile(optimizer=sgd, loss='categorical_crossentropy')
t0 = time.time()
img = image.load_img('D:\\photo\\dog.jpg', target_size=(224, 224))
x = image.img_to_array(img) # 三维(224,224,3)
x = np.expand_dims(x, axis=0) # 四维(1,224,224,3)#因为keras要求的维度是这样的,所以要增加一个维度
x = preprocess_input(x) # 预处理
print(x.shape)
y_pred = model.predict(x) # 预测概率
t1 = time.time()
print("测试图:", decode_predictions(y_pred)) # 输出五个最高概率(类名, 语义概念, 预测概率)
print("耗时:", str((t1 - t0) * 1000), "ms")
这是两种不同的方式,第一种是直接使用vgg16的参数,需要在运行时下载,第二种是我们已经下载好的权重,直接在参数中输入我们的路径即可。
补充知识:keras加经典网络的预训练模型(以VGG16为例)
我就废话不多说了,大家还是直接看代码吧~
# 使用VGG16模型
from keras.applications.vgg16 import VGG16
print('Start build VGG16 -------')
# 获取vgg16的卷积部分,如果要获取整个vgg16网络需要设置:include_top=True
model_vgg16_conv = VGG16(weights='imagenet', include_top=False)
model_vgg16_conv.summary()
# 创建自己的输入格式
# if K.image_data_format() == 'channels_first':
# input_shape = (3, img_width, img_height)
# else:
# input_shape = (img_width, img_height, 3)
input = Input(input_shape, name = 'image_input') # 注意,Keras有个层就是Input层
# 将vgg16模型原始输入转换成自己的输入
output_vgg16_conv = model_vgg16_conv(input)
# output_vgg16_conv是包含了vgg16的卷积层,下面我需要做二分类任务,所以需要添加自己的全连接层
x = Flatten(name='flatten')(output_vgg16_conv)
x = Dense(4096, activation='relu', name='fc1')(x)
x = Dense(512, activation='relu', name='fc2')(x)
x = Dense(128, activation='relu', name='fc3')(x)
x = Dense(1, activation='softmax', name='predictions')(x)
# 最终创建出自己的vgg16模型
my_model = Model(input=input, output=x)
# 下面的模型输出中,vgg16的层和参数不会显示出,但是这些参数在训练的时候会更改
print('\nThis is my vgg16 model for the task')
my_model.summary()
来源:https://blog.csdn.net/sunshunli/article/details/81456566
0
投稿
猜你喜欢
- 治標不治本的就是將php.ini內的reporting部份修改,讓notice不顯示 error_reporting = E_ALL; di
- 1. 规范简介本规范主要规定ASP源程序在书写过程中所应遵循的规则及注意事项。编写该规范的目的是使项目开发人员的源代码书写习惯保持一致。这样
- 本文实例讲述了Python实现html转换为pdf报告(生成pdf报告)功能。分享给大家供大家参考,具体如下:1、先说下html转换为pdf
- 自Python3.1中,整数bit_length方法允许查询二进制的位数或长度。常规做法:>>> bin(256)'
- 本文实例讲述了使用symfony命令创建项目的方法。分享给大家供大家参考,具体如下:概况这一章节描述一个Symfony项目的合理结构框架,并
- 代码需要先导入pandasarr的数据类型为一维的np.arrayimport pandas as pdarr[~pd.isnull(arr
- 核心播放模块(pygame内核)import time import pygameimport easygui as guifile = r
- 1、Python 有两种类型可以表示字符序列bytes:实例包含的是 原始数据 ,即 8 位的无符号值(通常按照 ASCII 编码 标准来显
- 目录1. 递归函数与回溯深搜的基础知识2. 求子集 (LeetCode 78)3. 求子集2 (LeetCode 90)4. 组合数之和(L
- 之前看到很多人一直都问CSS 中DIV垂直居中的问题,看来对此的需求还不少。现在就把我经验拿出来分享一下,希望大家鼓鼓掌。因为在 CSS 中
- 1.过程蜘蛛纸牌大家玩过没有?之前的电脑上自带的游戏,用他来摸鱼过的举个手。但是现在的电脑上已经没有蜘蛛纸牌了。所以…
- 继续练手,根据之前获取汽油价格的方式获取了金价,暂时没钱投资,看看而已#!/usr/bin/env python# -*- coding:
- ASP实现即时显示当前页面浏览人数online.asp文件 <!--#include file="dbconn.a
- 在使用python爬取网站信息时,查看爬取完后的数据发现,数据并没有被爬取下来,这是因为网站中有UA这种请求载体的身份标识,如果不是基于某一
- tensorflow作为google开源的项目,现在赶超了caffe,好像成为最受欢迎的深度学习框架。确实在编写的时候更能感受到代码的真实存
- 最近需要用python打包一个单页面网页demo,于是准备用python包pyinstaller来打包程序。网上搜索了一下,大部分教程都是打
- 问: 如果数据表中有时间字段,现在要迁移到其他时区的服务器上,该如何处理呢?答:在高版本的mysqldump中,新增了一个选项:--tz-u
- 有时候在使用Python处理比较耗时操作的时候,为了便于观察处理进度,这时候就需要通过进度条将处理情况进行可视化展示,以便我们能够及时了解情
- pycharm sql语句警告产生原因为没有配置数据库,配置数据库,似乎没什么作用那么,直接去掉他的警告提示找到setting->ed
- input()作用:让用户从控制台输入一串字符,按下回车后结束输入,并返回字符串注意:很多初学者以为它可以返回数字,其实是错的!>&g