python神经网络编程实现手写数字识别
作者:wenmiao_ 发布时间:2021-08-31 16:08:14
标签:python,数字识别
本文实例为大家分享了python实现手写数字识别的具体代码,供大家参考,具体内容如下
import numpy
import scipy.special
#import matplotlib.pyplot
class neuralNetwork:
def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate):
self.inodes=inputnodes
self.hnodes=hiddennodes
self.onodes=outputnodes
self.lr=learningrate
self.wih=numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes))
self.who=numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes))
self.activation_function=lambda x: scipy.special.expit(x)
pass
def train(self,inputs_list,targets_list):
inputs=numpy.array(inputs_list,ndmin=2).T
targets=numpy.array(targets_list,ndmin=2).T
hidden_inputs=numpy.dot(self.wih,inputs)
hidden_outputs=self.activation_function(hidden_inputs)
final_inputs=numpy.dot(self.who,hidden_outputs)
final_outputs=self.activation_function(final_inputs)
output_errors=targets-final_outputs
hidden_errors=numpy.dot(self.who.T,output_errors)
self.who+=self.lr*numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),numpy.transpose(hidden_outputs))
self.wih+=self.lr*numpy.dot((hidden_errors*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs))
pass
def query(self,input_list):
inputs=numpy.array(input_list,ndmin=2).T
hidden_inputs=numpy.dot(self.wih,inputs)
hidden_outputs=self.activation_function(hidden_inputs)
final_inputs=numpy.dot(self.who,hidden_outputs)
final_outputs=self.activation_function(final_inputs)
return final_outputs
input_nodes=784
hidden_nodes=100
output_nodes=10
learning_rate=0.1
n=neuralNetwork(input_nodes,hidden_nodes,output_nodes,learning_rate)
training_data_file=open(r"C:\Users\lsy\Desktop\nn\mnist_train.csv","r")
training_data_list=training_data_file.readlines()
training_data_file.close()
#print(n.wih)
#print("")
epochs=2
for e in range(epochs):
for record in training_data_list:
all_values=record.split(",")
inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
targets=numpy.zeros(output_nodes)+0.01
targets[int(all_values[0])]=0.99
n.train(inputs,targets)
#print(n.wih)
#print(len(training_data_list))
#for i in training_data_list:
# print(i)
test_data_file=open(r"C:\Users\lsy\Desktop\nn\mnist_test.csv","r")
test_data_list=test_data_file.readlines()
test_data_file.close()
scorecard=[]
for record in test_data_list:
all_values=record.split(",")
correct_lable=int(all_values[0])
inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
outputs=n.query(inputs)
label=numpy.argmax(outputs)
if(label==correct_lable):
scorecard.append(1)
else:
scorecard.append(0)
scorecard_array=numpy.asarray(scorecard)
print(scorecard_array)
print("")
print(scorecard_array.sum()/scorecard_array.size)
#all_value=test_data_list[0].split(",")
#input=(numpy.asfarray(all_value[1:])/255.0*0.99)+0.01
#print(all_value[0])
#image_array=numpy.asfarray(all_value[1:]).reshape((28,28))
#matplotlib.pyplot.imshow(image_array,cmap="Greys",interpolation="None")
#matplotlib.pyplot.show()
#nn=n.query((numpy.asfarray(all_value[1:])/255.0*0.99)+0.01)
#for i in nn :
# print(i)
《python神经网络编程》中代码,仅做记录,以备后用。
image_file_name=r"*.JPG"
img_array=scipy.misc.imread(image_file_name,flatten=True)
img_data=255.0-img_array.reshape(784)
image_data=(img_data/255.0*0.99)+0.01
图片对应像素的读取。因训练集灰度值与实际相反,故用255减取反。
import numpy
import scipy.special
#import matplotlib.pyplot
import scipy.misc
from PIL import Image
class neuralNetwork:
def __init__(self,inputnodes,hiddennodes,outputnodes,learningrate):
self.inodes=inputnodes
self.hnodes=hiddennodes
self.onodes=outputnodes
self.lr=learningrate
self.wih=numpy.random.normal(0.0,pow(self.hnodes,-0.5),(self.hnodes,self.inodes))
self.who=numpy.random.normal(0.0,pow(self.onodes,-0.5),(self.onodes,self.hnodes))
self.activation_function=lambda x: scipy.special.expit(x)
pass
def train(self,inputs_list,targets_list):
inputs=numpy.array(inputs_list,ndmin=2).T
targets=numpy.array(targets_list,ndmin=2).T
hidden_inputs=numpy.dot(self.wih,inputs)
hidden_outputs=self.activation_function(hidden_inputs)
final_inputs=numpy.dot(self.who,hidden_outputs)
final_outputs=self.activation_function(final_inputs)
output_errors=targets-final_outputs
hidden_errors=numpy.dot(self.who.T,output_errors)
self.who+=self.lr*numpy.dot((output_errors*final_outputs*(1.0-final_outputs)),numpy.transpose(hidden_outputs))
self.wih+=self.lr*numpy.dot((hidden_errors*hidden_outputs*(1.0-hidden_outputs)),numpy.transpose(inputs))
pass
def query(self,input_list):
inputs=numpy.array(input_list,ndmin=2).T
hidden_inputs=numpy.dot(self.wih,inputs)
hidden_outputs=self.activation_function(hidden_inputs)
final_inputs=numpy.dot(self.who,hidden_outputs)
final_outputs=self.activation_function(final_inputs)
return final_outputs
input_nodes=784
hidden_nodes=100
output_nodes=10
learning_rate=0.1
n=neuralNetwork(input_nodes,hidden_nodes,output_nodes,learning_rate)
training_data_file=open(r"C:\Users\lsy\Desktop\nn\mnist_train.csv","r")
training_data_list=training_data_file.readlines()
training_data_file.close()
#print(n.wih)
#print("")
#epochs=2
#for e in range(epochs):
for record in training_data_list:
all_values=record.split(",")
inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
targets=numpy.zeros(output_nodes)+0.01
targets[int(all_values[0])]=0.99
n.train(inputs,targets)
#image_file_name=r"C:\Users\lsy\Desktop\nn\1000-1.JPG"
'''
img_array=scipy.misc.imread(image_file_name,flatten=True)
img_data=255.0-img_array.reshape(784)
image_data=(img_data/255.0*0.99)+0.01
#inputs=(numpy.asfarray(image_data)/255.0*0.99)+0.01
outputs=n.query(image_data)
label=numpy.argmax(outputs)
print(label)
'''
#print(n.wih)
#print(len(training_data_list))
#for i in training_data_list:
# print(i)
test_data_file=open(r"C:\Users\lsy\Desktop\nn\mnist_test.csv","r")
test_data_list=test_data_file.readlines()
test_data_file.close()
scorecard=[]
total=[0,0,0,0,0,0,0,0,0,0]
rightsum=[0,0,0,0,0,0,0,0,0,0]
for record in test_data_list:
all_values=record.split(",")
correct_lable=int(all_values[0])
inputs=(numpy.asfarray(all_values[1:])/255.0*0.99)+0.01
outputs=n.query(inputs)
label=numpy.argmax(outputs)
total[correct_lable]+=1
if(label==correct_lable):
scorecard.append(1)
rightsum[correct_lable]+=1
else:
scorecard.append(0)
scorecard_array=numpy.asarray(scorecard)
print(scorecard_array)
print("")
print(scorecard_array.sum()/scorecard_array.size)
print("")
print(total)
print(rightsum)
for i in range(10):
print((rightsum[i]*1.0)/total[i])
#all_value=test_data_list[0].split(",")
#input=(numpy.asfarray(all_value[1:])/255.0*0.99)+0.01
#print(all_value[0])
#image_array=numpy.asfarray(all_value[1:]).reshape((28,28))
#matplotlib.pyplot.imshow(image_array,cmap="Greys",interpolation="None")
#matplotlib.pyplot.show()
#nn=n.query((numpy.asfarray(all_value[1:])/255.0*0.99)+0.01)
#for i in nn :
# print(i)
尝试统计了对于各个数据测试数量及正确率。
原本想验证书后向后查询中数字‘9'识别模糊是因为训练数量不足或错误率过高而产生,然最终结果并不支持此猜想。
另书中只能针对特定像素的图片进行学习,真正手写的图片并不能满足训练条件,实际用处仍需今后有时间改进。
来源:https://blog.csdn.net/wenmiao_/article/details/88191457
0
投稿
猜你喜欢
- 工具版本python版本:3.8 django版本:2.0.0 mysql版本: 5.5.53 pip3创建工程djangostartDja
- 在使用ORACLE的过程过,我们会经常遇到一些ORACLE产生的错误,对于初学者而言,这些错误可能有点模糊,而且可能一时不知怎么去处理产生的
- //CLASS@Mr.Think*****getElementsByTagName function tag(name,elem){ if(
- 我今天晚上,做一个快印公司的网站布局,在Div镶套布局中,父标签DIV的高度不变。在IE下没有问题,但是在FIREFOX下就有问题了。如图:
- 听名字就知道这个函数是用来求tensor中某个dim的前k大或者前k小的值以及对应的index。用法torch.topk(input, k,
- AES加密方式有五种:ECB, CBC, CTR, CFB, OFB从安全性角度推荐CBC加密方法,本文介绍了CBC,ECB两种加密方法的p
- 最初打算使用scroll-view实现,效果好、流畅、有惯性滑动,但由于滚动条没法去掉、无法实现上下层的帧布局,最终放弃了。还是自己写个吧,
- 使用tensorflow训练模型的时候,模型持久化对我们来说非常重要。如果我们的模型比较复杂,需要的数据比较多,那么在模型的训练时间会耗时很
- 用python操作ms sqlserver,有好几种方法:(1)利用pymssql (2)利用pyodbc这里讲import&nb
- 本文实例讲述了python实现分析apache和nginx日志文件并输出访客ip列表的方法。分享给大家供大家参考。具体如下:这里使用pyth
- 这是一个获取字符串中两个子串之间的子串,如从字符串www.aspxhome.com中获取coderbolg子串,就让这个PHP函数来实现吧,
- 核心代码--下面我演示下MySQL中的排序列的实现--测试数据CREATE TABLE tb(score INT);INSERT tb SE
- SQLite是一款轻型的数据库,是遵守ACID的关系型数据库管理系统。不像常见的客户-服务器范例,SQLite引擎不是个程序与之通信的独立进
- 有时候我们可能不知道一个用户的密码,但是又需要以这个用户做一些操作,又不能去修改掉这个用户的密码,这个时候,就可以利用一些小窍门,来完成操作
- 本文为大家分享了python爱心表白的具体代码,供大家参考,具体内容如下import turtleimport time# 画爱心的顶部de
- Bootstrap 的响应式 CSS 能够自适应于台式机、平板电脑和手机下面是手机端显示的样式代码如下:1.test.php<html
- 关于SQL查询效率,100w数据,查询只要1秒,与您分享:机器情况:p4: 2.4内存: 1 Gos: windows 2003数据库:SQ
- 从4年之前什么都不知道,到现在对代码的一网情深,感谢无忧的兄弟姐妹的帮助,感谢无忧给我们提供了这么好的交流平台。现将最近几天捣鼓的asp封装
- 一:Fancy Indexingimport numpy as np#Fancy Indexingx = np.arange(16)np.r
- 概要在列表,元组,实例,类,字典和函数中存在循环引用问题。有 __del__ 方法的实例会以健全的方式被处理。给新类型添加GC支持是很容易的