详解Python手写数字识别模型的构建与使用
作者:顾城沐心 发布时间:2023-10-21 18:34:12
标签:Python,手写,数字,识别
一:手写数字模型构建与保存
1 加载数据集
# 1加载数据
digits_data = load_digits()
可以先简单查看下 手写数字集,如下可以隐约看出数字为8
plt.imshow(digits_data.images[8])
plt.show()
2 特征数据 标签数据
# 数据划分
x_data = digits_data.data
y_data = digits_data.target
3 训练集 测试集
# 训练集 + 测试集
x_test = x_data[:40]
y_test = y_data[:40]
x_train = x_data[40:]
y_train = y_data[40:]
# 概率问题
y_train_2 = np.zeros(shape=(len(y_train), 10))
4 数据流图 输入层
input_size = digits_data.data.shape[1] # 输入的列数
# 数据流图的构建
# x:输入64个特征值--像素
x = tf.placeholder(np.float32, shape=[None, input_size])
# y:识别的数字 有几个类别[0-9]
y = tf.placeholder(np.float32, shape=[None, 10])
5 隐藏层
5.1 第一层
# 第一层隐藏层
# 参数1 输入维度 参数2:输出维度(神经元个数) 标准差是0.1的正态分布
w1 = tf.Variable(tf.random_normal([input_size, 80], stddev=0.1))
# b的个数就是隐藏层神经元的个数
b1 = tf.Variable(tf.constant(0.01), [80])
# 第一层计算
one = tf.matmul(x, w1) + b1
# 激活函数 和0比 大于0则激活
op1 = tf.nn.relu(one)
5.2 第二层
# 第二层隐藏层 上一层输出为下一层输入
# 参数1 输入维度 参数2:输出维度(神经元个数) 标准差是0.1的正态分布
w2 = tf.Variable(tf.random_normal([80, 10], stddev=0.1))
# b的个数就是隐藏层神经元的个数
b2 = tf.Variable(tf.constant(0.01), [10])
# 第一层计算
two = tf.matmul(op1, w2) + b2
# 激活函数 和0比 大于0则激活
op2 = tf.nn.relu(two)
6 损失函数
# 构建损失函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=op2))
7 梯度下降算法
# 梯度下降算法
Optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.005).minimize(loss)
8 输出损失值
# 变量初始化
init = tf.global_variables_initializer()
data_size = digits_data.data.shape[0]
# 开启会话
with tf.Session() as sess:
sess.run(init)
# 训练次数
for i in range(500):
# 数据分组
start = (i * 100) % data_size
end = min(start + 100, data_size)
batch_x = x_train[start:end]
batch_y = y_train_2[start:end]
sess.run(Optimizer, feed_dict={x: batch_x, y: batch_y})
# 输出损失值
train_loss = sess.run(loss, feed_dict={x: batch_x, y: batch_y})
print(train_loss)
9 模型 保存与使用
obj = tf.train.Saver()
# 模型保存
obj.save(sess, 'model-digits.ckpt')
10 完整源码分享
import tensorflow as tf
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# 1加载数据
digits_data = load_digits()
# 查看数据
# print(digits_data)
# 查看数据基本特征 (1797, 64) 64:8*8像素点
# print(digits_data.data.shape)
# plt.imshow(digits_data.images[8])
# plt.show()
# 数据划分
x_data = digits_data.data
y_data = digits_data.target
# 训练集 + 测试集
x_test = x_data[:40]
y_test = y_data[:40]
x_train = x_data[40:]
y_train = y_data[40:]
# 概率问题
y_train_2 = np.zeros(shape=(len(y_train), 10))
# 对应的分类 当前行对应列变成1
for index, row in enumerate(y_train_2):
# 当前行 对应的数字对应列
row[int(y_train[index])] = 1
# print(y_train_2[0])
input_size = digits_data.data.shape[1] # 输入的列数
# 数据流图的构建
# x:输入64个特征值--像素
x = tf.placeholder(np.float32, shape=[None, input_size])
# y:识别的数字 有几个类别[0-9]
y = tf.placeholder(np.float32, shape=[None, 10])
# 第一层隐藏层
# 参数1 输入维度 参数2:输出维度(神经元个数) 标准差是0.1的正态分布
w1 = tf.Variable(tf.random_normal([input_size, 80], stddev=0.1))
# b的个数就是隐藏层神经元的个数
b1 = tf.Variable(tf.constant(0.01), [80])
# 第一层计算
one = tf.matmul(x, w1) + b1
# 激活函数 和0比 大于0则激活
op1 = tf.nn.relu(one)
# 第二层隐藏层 上一层输出为下一层输入
# 参数1 输入维度 参数2:输出维度(神经元个数) 标准差是0.1的正态分布
w2 = tf.Variable(tf.random_normal([80, 10], stddev=0.1))
# b的个数就是隐藏层神经元的个数
b2 = tf.Variable(tf.constant(0.01), [10])
# 第一层计算
two = tf.matmul(op1, w2) + b2
# 激活函数 和0比 大于0则激活
op2 = tf.nn.relu(two)
# 构建损失函数
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(labels=y, logits=op2))
# 梯度下降算法 优化器 learning_rate学习率(步长)
Optimizer = tf.train.GradientDescentOptimizer(learning_rate=0.005).minimize(loss)
# 变量初始化
init = tf.global_variables_initializer()
data_size = digits_data.data.shape[0]
# 开启会话
with tf.Session() as sess:
sess.run(init)
# 训练次数
for i in range(500):
# 数据分组
start = (i * 100) % data_size
end = min(start + 100, data_size)
batch_x = x_train[start:end]
batch_y = y_train_2[start:end]
sess.run(Optimizer, feed_dict={x: batch_x, y: batch_y})
# 输出损失值
train_loss = sess.run(loss, feed_dict={x: batch_x, y: batch_y})
print(train_loss)
obj = tf.train.Saver()
# 模型保存
obj.save(sess, 'modelSave/model-digits.ckpt')
损失值在0.303左右,如下图所示
二:手写数字模型使用与测试
对上一步创建的模型,使用测试
import tensorflow as tf
from sklearn.datasets import load_digits
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
# 1加载数据
digits_data = load_digits()
# 数据划分
x_data = digits_data.data
y_data = digits_data.target
# 训练集 + 测试集
x_test = x_data[:40]
y_test = y_data[:40]
x_train = x_data[40:]
y_train = y_data[40:]
# 概率问题
y_train_2 = np.zeros(shape=(len(y_train), 10))
# 对应的分类 当前行对应列变成1
for index, row in enumerate(y_train_2):
# 当前行 对应的数字对应列
row[int(y_train[index])] = 1
# 网络搭建
num_class = 10 # 数字0-9
hidden_num = 80 # 神经元个数
input_size = digits_data.data.shape[1] # 输入的列数
# 数据流图的构建
# x:输入64个特征值--像素
x = tf.placeholder(np.float32, shape=[None, 64])
# y:识别的数字 有几个类别[0-9]
y = tf.placeholder(np.float32, shape=[None, 10])
# 第一层隐藏层
# 参数1 输入维度 参数2:输出维度(神经元个数) 标准差是0.1的正态分布
w1 = tf.Variable(tf.random_normal([input_size, 80], stddev=0.1))
# b的个数就是隐藏层神经元的个数
b1 = tf.Variable(tf.constant(0.01), [80])
# 第一层计算
one = tf.matmul(x, w1) + b1
# 激活函数 和0比 大于0则激活
op1 = tf.nn.relu(one)
# 第二层隐藏层 上一层输出为下一层输入
# 参数1 输入维度 参数2:输出维度(神经元个数) 标准差是0.1的正态分布
w2 = tf.Variable(tf.random_normal([80, 10], stddev=0.1))
# b的个数就是隐藏层神经元的个数
b2 = tf.Variable(tf.constant(0.01), [10])
# 第一层计算
two = tf.matmul(op1, w2) + b2
# 激活函数 和0比 大于0则激活
op2 = tf.nn.relu(two)
# 变量初始化
init = tf.global_variables_initializer()
train_count = 500
batch_size = 100
data_size = x_train.shape[0]
pre_max_index = tf.argmax(op2, 1)
plt.imshow(digits_data.images[13]) # 3
plt.show()
with tf.Session() as sess:
sess.run(init)
# 使用网络
obj = tf.train.Saver()
obj.restore(sess, 'modelSave/model-digits.ckpt')
print(sess.run(op2, feed_dict={x: [x_test[13], x_test[14]]}))
print(sess.run(pre_max_index, feed_dict={x: [x_test[13], x_test[14]]}))
想要测试的数据,如下图所示
使用模型测试出来的结果,如下图所示,模型基本能够使用
来源:https://blog.csdn.net/m0_56051805/article/details/128398291
0
投稿
猜你喜欢
- python自带日志管理模块logging,使用时可进行模块化配置,详细可参考博文Python日志采集(详细)。但logging配置起来比较
- 字典对象的核心是散列表。散列表是一个稀疏数组(总是有空白元素的数组),数组的每个单元叫做 bucket。每个 bucket 有两部分:一个是
- 一个程序要进行交互,就需要进行输入,进行输入→处理→输出的过程。所以就需要用到输入和输出功能。同样的,在Python中,怎么实现输入和输出?
- 从今天开始,我将全面的共享出我所能理解的所有WEB标准方面的知识放在这个“WEB标准能有多难?”的专栏里。当然由于振之的水平有限,所讲并非是
- 概述Object.freeze(obj)可以冻结一个对象。一个被冻结的对象再也不能被修改;冻结了一个对象则不能向这个对象添加新的属性,不能删
- 在为一个客户排除死锁问题时我遇到了一个有趣的包括InnoDB间隙锁的情形。对于一个WHERE子句不匹配任何行的非插入的写操作中,
- SQL2005增加了4个关于队计算的函数:分别是ROW_NUMBER,RANK,DENSE_RANK,NTILE. 注意:这些函数
- 本文为大家分享了mysql 8.0.12 安装详细教程,供大家参考,具体内容如下一、安装 1.从官网上下载MySQL8.0.12版本,下载链
- 目标站点分析本次要抓取的目标站点为:中介网,这个网站提供了网站排行榜、互联网网站排行榜、中文网站排行榜等数据。网站展示的样本数据量是 :58
- 前言我们知道当文件不存在的时候,open()方法的写模式与追加模式都会新建文件,但是对文件进行判断的场景还有很多,比如,在爬虫下载图片的时候
- 遇到的问题:在做爬虫的时候,爬取的url链接内还有转义字符,反斜杠 \,打算用正则的re.sub()替换掉的时候遇到了问题,这是要做替换的字
- 看代码吧~import pymongofrom dateutil import parserdateStr = "2019-05-
- Pycharm运行时总是跳出Python Console最近运行程序的时候发现,每次点击运行之后,都是出现的Python Console。最
- 1.思路在网上查找了半天,基本都是提取word中文字的,没有找到可以把word中的图片提取出来的方法。一个巧合的情况下,发现将word的后缀
- 目录前言limit深分页为什么会变慢?通过子查询优化回顾B+ 树结构把条件转移到主键索引树INNER JOIN 延迟关联标签记录法使用bet
- 或许你也经历过,很多人都说一个女人很漂亮,而你觉得很一般。有时候,我也尝试理解为什么会对某个女人情有独钟。通常,我用迷人来描述,但这个&qu
- 本文实例讲述了python写日志文件操作类与应用。分享给大家供大家参考,具体如下:项目的开发过程中,日志文件是少不了的,通过写日志文件,可以
- mysql数据库没有增量备份的机制,当数据量太大的时候备份是一个很大的问题。还好mysql数据库提供了一种主从备份的机制,其实就是把主数据库
- 前言本文参考了以下代码Windows系统环境下Python脚本实现全局“划词复制”功能from py
- Go语言在进行文件操作的时候,可以有多种方法。最常见的比如直接对文件本身进行Read和Write; 除此之外,还可以使用bufio库的流式处