Python Opencv使用ann神经网络识别手写数字功能
作者:Keras深度学习 发布时间:2023-11-03 02:44:52
标签:python,opencv,识别手写数字,ann,神经网络
opencv中也提供了一种类似于Keras的神经网络,即为ann,这种神经网络的使用方法与Keras的很接近。
关于mnist数据的解析,读者可以自己从网上下载相应压缩文件,用python自己编写解析代码,由于这里主要研究knn算法,为了图简单,直接使用Keras的mnist手写数字解析模块。
本次代码运行环境为:
python 3.6.8
opencv-python 4.4.0.46
opencv-contrib-python 4.4.0.46
下面的代码为使用ann进行模型的训练:
from keras.datasets import mnist
from keras import utils
import cv2
import numpy as np
#opencv中ANN定义神经网络层
def create_ANN():
ann=cv2.ml.ANN_MLP_create()
#设置神经网络层的结构 输入层为784 隐藏层为80 输出层为10
ann.setLayerSizes(np.array([784,64,10]))
#设置网络参数为误差反向传播法
ann.setTrainMethod(cv2.ml.ANN_MLP_BACKPROP)
#设置激活函数为sigmoid
ann.setActivationFunction(cv2.ml.ANN_MLP_SIGMOID_SYM)
#设置训练迭代条件
#结束条件为训练30次或者误差小于0.00001
ann.setTermCriteria((cv2.TermCriteria_EPS|cv2.TermCriteria_COUNT,100,0.0001))
return ann
#计算测试数据上的识别率
def evaluate_acc(ann,test_images,test_labels):
#采用的sigmoid激活函数,需要对结果进行置信度处理
#对于大于0.99的可以确定为1 对于小于0.01的可以确信为0
test_ret=ann.predict(test_images)
#预测结果是一个元组
test_pre=test_ret[1]
#可以直接最大值的下标 (10000,)
test_pre=test_pre.argmax(axis=1)
true_sum=(test_pre==test_labels)
return true_sum.mean()
if __name__=='__main__':
#直接使用Keras载入的训练数据(60000, 28, 28) (60000,)
(train_images,train_labels),(test_images,test_labels)=mnist.load_data()
#变换数据的形状并归一化
train_images=train_images.reshape(train_images.shape[0],-1)#(60000, 784)
train_images=train_images.astype('float32')/255
test_images=test_images.reshape(test_images.shape[0],-1)
test_images=test_images.astype('float32')/255
#将标签变为one-hot形状 (60000, 10) float32
train_labels=utils.to_categorical(train_labels)
#测试数据标签不用变为one-hot (10000,)
test_labels=test_labels.astype(np.int)
#定义神经网络模型结构
ann=create_ANN()
#开始训练
ann.train(train_images,cv2.ml.ROW_SAMPLE,train_labels)
#在测试数据上测试准确率
print(evaluate_acc(ann,test_images,test_labels))
#保存模型
ann.save('mnist_ann.xml')
#加载模型
myann=cv2.ml.ANN_MLP_load('mnist_ann.xml')
训练100次得到的准确率为0.9376,可以接着增加训练次数或者提高神经网络的层次结构深度来提高准确率。
使用ann神经网络的模型结构非常小,因为只是保存了权重参数。
可以看到整个模型文件的大小才1M,而svm的大小为十多兆,knn的为几百兆,因此使用ann神经网络更加适合部署在客户端上。
接下来使用ann进行图片的测试识别:
import cv2
import numpy as np
if __name__=='__main__':
#读取图片
img=cv2.imread('shuzi.jpg',0)
img_sw=img.copy()
#将数据类型由uint8转为float32
img=img.astype(np.float32)
#图片形状由(28,28)转为(784,)
img=img.reshape(-1,)
#增加一个维度变为(1,784)
img=img.reshape(1,-1)
#图片数据归一化
img=img/255
#载入ann模型
ann=cv2.ml.ANN_MLP_load('minist_ann.xml')
#进行预测
img_pre=ann.predict(img)
#因为激活函数sigmoid,因此要进行置信度处理
ret=img_pre[1]
ret[ret>0.9]=1
ret[ret<0.1]=0
print(ret)
cv2.imshow('test',img_sw)
cv2.waitKey(0)
运行程序,结果如下,可见该模型正确识别了数字0.
来源:https://keras-lx.blog.csdn.net/article/details/111694841
0
投稿
猜你喜欢
- 查看逻辑读前10的SQL:set linesize 300;set pagesize 300;set long 50000;SELECT *
- 新建图像文件后选Channels面板,新建Alpha1通道:输入文字; &nbs
- Access 连接字符串 strConnect = “Provider=Microsoft.Jet.OLEDB.4.0;
- 如果程序中没有设置session的过期时间,那么session过期时间就会按照IIS设置的过期时间来执行,IIS中session默认过期时间
- 在现在的项目里,不管是电商项目还是别的项目,在管理端都会有导出的功能,比方说订单表导出,用户表导出,业绩表导出。这些都需要提前生成excel
- 根据微软论坛作者的英文解释,.NET framework 4.0 安装失败回滚貌似是因为“msvcr100_clr0400.d
- AXObject可用来解决IE需要激活 ActiveX 控件和生成控件调用代码 AXObjec
- 1. 首先 进入cmd, 输入python,看python是否安装成功说明python安装,没有问题2. 修改注册表第一步window +
- 页面大小、窗口大小和滚动条位置这三个数值在不同的浏览器例如Firefox和IE中有着不同的实现。即使在同一种浏览器例如IE中,不同版本也有不
- 一个asp显示当前日期农历的代码函数,效果 今天是:农历丁亥年(猪)八月十三。调用方便!Function nl()'获取当前系统时间
- 秦歌这篇文章总结得很不错,俺挑刺来啦:1. 优先级的描述不严谨,有 !important 时,网页样式可以覆盖用户自定义样式。用户!impo
- 先想创意,再画草图,接着鼠绘,最后做成flas * 。这是我的习惯流程。 这是想到中秋时,我第一时间内能浮想出的图像:大意是嫦娥奔
- 抽象工厂模式抽象工厂模式是一种创建型设计模式, 它能创建一系列相关的对象, 而无需指定其具体类。抽象工厂定义了用于创建不同产品的接口, 但将
- 记一次在写cli脚本的时候,碰到的一个问题。问题自己是写服务端的,有时候会写一些cli脚本去跑测试。习惯main.go写主流程,其他子文件写
- Cookie简介首先,我们对Cookie做一个简单的介绍,说明如何利用ASP来维护cookie。Cookie是存储在客户端计算机中的一个小文
- 实体有五种预定义的XML实体,HTML编码者应该熟悉。XML文档中的字符&、<、>、"和'被分别表示为
- Django将秒转换为xx天xx时xx分,具体代码如下所示:from django.utils.translation import nge
- 此类技巧还有很多,欢迎继续分享解析 URL从 James Padolsey 的 Blog中看到的个小技巧,就是利用 a 标签的 DOM 属性
- 在默认情况下,Access 2000/2002数据库是以“共享”的方式打开的,这样可以保证多人能够同时使用同一个数据库。不过,在共享方式打开
- 今天,本文向大家推荐20佳国外的脚本下载网站。1- Hot Scripts2- Code Canyon3- User Scripts4- S