使用已经得到的keras模型识别自己手写的数字方式
作者:游学者冬夜 发布时间:2021-04-03 17:41:11
环境:Python+keras,后端为Tensorflow
训练集:MNIST
对于如何训练一个识别手写数字的神经网络,网上资源十分丰富,并且能达到相当高的精度。但是很少有人涉及到如何将图片输入到网络中并让已经训练好的模型惊醒识别,下面来说说实现方法及注意事项。
首先import相关库,这里就不说了。
然后需要将训练好的模型导入,可通过该语句实现:
model = load_model('cnn_model_2.h5') (cnn_model_2.h5替换为你的模型名)
之后是导入图片,需要的格式为28*28。可用opencv导入:
img = cv2.imread('temp3.png', 0) (temp3.png替换为你手写的图片)
然后reshape一下以符合模型的输入要求:
img = (img.reshape(1,1,28,28)).astype("float32")/255
之后就可以用模型识别了:
predict = model.predict_classes(img)
最后print一下predict即可。
下面划重点:因为MNIST使用的是黑底白字的图片,所以你自己手写数字的时候一定要注意把得到的图片也改成黑底白字的,否则会识别错(至少我得到的结论是这样的 ,之前用白底黑字的图总是识别出错)
源码一览:
import cv2
import numpy as np
from keras.models import load_model
model = load_model('cnn_model_2.h5')
image = cv2.imread('temp3.png', 0)
img = cv2.imread('temp3.png', 0)
img = (img.reshape(1,1,28,28)).astype("float32")/255
predict = model.predict_classes(img)
print ('识别为:')
print (predict)
cv2.imshow("Image1", image)
cv2.waitKey(0)
效果图:
补充知识:keras编写自定义的层
写在前面的话
keras已经有很多封装好的库供我们调用,但是有些时候我们需要的操作keras并没有,这时就需要学会自定义keras层了
1.Lambda
这个东西很方便,但是只能完成简单、无状态的自定义操作,而不能建立含有可训练权重的自定义层。
from keras.layers import Input,Lambda
from keras import Model
import tensorflow as tf
input=Input(shape=(224,224,3))
input.shape #Input第一个维度为batchsize维度
output=Lambda(lambda x: x[...,1])(input) #取最后一个维度的数据,...表示前面所有的维度
Model=Model(inputs=input,outputs=output)
Model.output
2.keras_custom
学习自keras中文文档
2.自定义keras层(带有可训练权重)
① build:定义权重,且self.build=True,可以通过迪奥哟经super([layer],self).build()完成
② call:功能逻辑实现
③ compute_output_shape:计算输出张量的shape
import keras.backend as K
from keras.engine.topology import Layer #这里的Layer是一个父类,下面的MyLayer将会继承Layer
class MyLayer(Layer): #自定义一个keras层类
def __init__(self,output_dim,**kwargs): #初始化方法
self.output_dim=output_dim
super(MyLayer,self).__init__(**kwargs) #必须要的初始化自定义层
def build(self,input_shape): #为Mylayer建立一个可训练的权重
#通过add_weight的形式来为Mylayer创建权重矩阵
self.kernel=self.add_weight(name='kernel',
shape=(input_shape[1],self.output_dim), #这里就是建立一个shape大小的权重矩阵
initializer='uniform',
trainable=True)
super(MyLayer,self).build(input_shape) #一定要用,也可以用下面一行
#self.build=True
def call(self,x): #call函数里就是定义了对x张量的计算图,且x只是一个形式,所以不能被事先定义
return K.dot(x,self.kernel) #矩阵乘法
def compute_output_shape(self,input_shape):
return (input_shape[0],self.output_dim) #这里是自己手动计算出来的output_shape
--------------------------------------------------------------------------------
class Mylayer(Layer):
def __init__(self,output_dim,**kwargs):
self.output_dim=output_dim
super(MyLayer,self).__init__(**kwargs)
def build(self,input_shape):
assert isinstance(input_shape,list) #判断input_shape是否是list类型的
self.kernel=self.add_weight(name='kernel',
shape=(input_shape[0][1],self.output_dim), #input_shape应该长得像[(2,2),(3,3)]
initializer='uniform',
trainable=True)
super(MyLayer,self).build(input_shape)
def call(self,x):
assert isinstance(x,list)
a,b=x #从这里可以看出x应该是一个类似[(2,2),(3,3)]的list,a=(2,2),b=(3,3)
return [K.dot(a,self.kernel)+b,K.mean(b,axis=-1)]
来源:https://blog.csdn.net/baidu_35113561/article/details/79371716
猜你喜欢
- 二叉树中和为某一值的路径:输入一颗二叉树的跟节点和一个整数,打印出二叉树中结点值的和为输入整数的所有路径。路径定义为从树的根结点开始往下一直
- reduce() 函数在 python 2 是内置函数, 从python 3 开始移到了 functools 模块。官方文档是这样介绍的re
- 方法: 使用urlencode函数urllib.request.urlopen()import urllib.requestimport u
- 1、es的批量插入这是为了方便后期配置的更改,把配置信息放在logging.conf中用elasticsearch来实现批量操作,先安装依赖
- 发现很多朋友对 CSS 的优先权不甚了解,规则很简单。需要说明的一点,如果你的样式管理需要深层判断 CSS 的优先权,更应反思自己的 CSS
- 一、表单验证form1、创建一个新的表单:<form id="id是唯一的,不可重复" name=“可重复”,me
- 经常使用到有关数据库的操作。包括连接代码、SQL命令等等,又不曾刻意去记忆它们(我本人是不愿意去记这东东),所以常常在用到的时候又去查书本,
- 全选、全不选、反选这几个功能我们经常会用到,如我们可以用在文章列表管理页面,也可以用在音乐播放页面,使用全选我们可以很方便的进行批量操作,如
- 前言列表(list)同字符串一样都是有序的,因为他们都可以通过切片和索引进行数据访问,且列表是可变的。创建列表的几种方法第一种name_li
- 一般来说,我们判断 iframe 是否加载完成其实与 判断 JavaScript 文件是否加载完成 采用的方法很类似:var&nb
- 下列语句部分是Mssql语句,不可以在access中使用。SQL分类:DDL—数据定义语言(CREATE,ALTER,DROP,DECLAR
- 前后端分离前后端分离的好处最大的好处就是前端JS可以做很大部分的数据处理工作,对服务器的压力减小到最小。后台错误不会直接反映到前台,错误接秒
- 一、概述在ubuntu环境下进行嵌入式开发,我们在进行不同的项目开发时,可能会遇到python环境不统一的情况。这时,我们可以通过updat
- 在已有的shapefile文件的基础上增加字段: # -*- coding:gb2312 -*-import shapefiler=shap
- SQLServer中建立与服务器的连接时出错的解决方案如下:步骤1:在SQLServer 实例上启用远程连接1.指向“开始->程序-&
- 如果遇到下述错误,表示当启动mysqld时或重新加载授权表时,在用户表中发现具有非法密码的账户。发现用户'some_user'
- 本文介绍机器学习中的Logistic回归算法,我们使用这个算法来给数据进行分类。Logistic回归算法同样是需要通过样本空间学习的监督学习
- 前言数据驱动是一种思想,让数据和代码进行分离,比如爬虫时,我们需要分页爬取数据时,我们往往把页数 page 参数化,放在 for 循环 ra
- 网站域名一般都会选简短易记的,因为这对于网站宣传来说也可以省不少力。而被很多网站忽视的站内Url结构则在一定程度上反映出网站的整体架构。当设
- 一、写在前面作为一名测试,有时候经常会遇到需要录屏记录自己操作,方便后续开发同学定位。以前都是用ScreenToGif来录屏制作成动态图,偶