tensorflow识别自己手写数字
作者:juezhanangle 发布时间:2022-10-12 20:25:04
tensorflow作为google开源的项目,现在赶超了caffe,好像成为最受欢迎的深度学习框架。确实在编写的时候更能感受到代码的真实存在,这点和caffe不同,caffe通过编写配置文件进行网络的生成。环境tensorflow是0.10的版本,注意其他版本有的语句会有错误,这是tensorflow版本之间的兼容问题。
还需要安装PIL:pip install Pillow
图片的格式:
– 图像标准化,可安装在20×20像素的框内,同时保留其长宽比。
– 图片都集中在一个28×28的图像中。
– 像素以列为主进行排序。像素值0到255,0表示背景(白色),255表示前景(黑色)。
创建一个.png的文件,背景是白色的,手写的字体是黑色的,
下面是数据测试的代码,一个两层的卷积神经网,然后用save进行模型的保存。
# coding: UTF-8
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import input_data
'''''
得到数据
'''
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
training = mnist.train.images
trainlable = mnist.train.labels
testing = mnist.test.images
testlabel = mnist.test.labels
print ("MNIST loaded")
# 获取交互式的方式
sess = tf.InteractiveSession()
# 初始化变量
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
'''''
生成权重函数,其中shape是数据的形状
'''
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
'''''
生成偏执项 其中shape是数据形状
'''
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
strides=[1, 2, 2, 1], padding='SAME')
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder("float")
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
cross_entropy = -tf.reduce_sum(y_*tf.log(y_conv))
train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)
correct_prediction = tf.equal(tf.argmax(y_conv, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# 保存网络训练的参数
saver = tf.train.Saver()
sess.run(tf.initialize_all_variables())
for i in range(8000):
batch = mnist.train.next_batch(50)
if i%100 == 0:
train_accuracy = accuracy.eval(feed_dict={
x:batch[0], y_: batch[1], keep_prob: 1.0})
print "step %d, training accuracy %g"%(i, train_accuracy)
train_step.run(feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})
save_path = saver.save(sess, "model_mnist.ckpt")
print("Model saved in life:", save_path)
print "test accuracy %g"%accuracy.eval(feed_dict={
x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0})
其中input_data.py如下代码,是进行mnist数据集的下载的:代码是由mnist数据集提供的官方下载的版本。
# Copyright 2015 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions for downloading and reading MNIST data."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import gzip
import os
import tensorflow.python.platform
import numpy
from six.moves import urllib
from six.moves import xrange # pylint: disable=redefined-builtin
import tensorflow as tf
SOURCE_URL = 'http://yann.lecun.com/exdb/mnist/'
def maybe_download(filename, work_directory):
"""Download the data from Yann's website, unless it's already here."""
if not os.path.exists(work_directory):
os.mkdir(work_directory)
filepath = os.path.join(work_directory, filename)
if not os.path.exists(filepath):
filepath, _ = urllib.request.urlretrieve(SOURCE_URL + filename, filepath)
statinfo = os.stat(filepath)
print('Successfully downloaded', filename, statinfo.st_size, 'bytes.')
return filepath
def _read32(bytestream):
dt = numpy.dtype(numpy.uint32).newbyteorder('>')
return numpy.frombuffer(bytestream.read(4), dtype=dt)[0]
def extract_images(filename):
"""Extract the images into a 4D uint8 numpy array [index, y, x, depth]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2051:
raise ValueError(
'Invalid magic number %d in MNIST image file: %s' %
(magic, filename))
num_images = _read32(bytestream)
rows = _read32(bytestream)
cols = _read32(bytestream)
buf = bytestream.read(rows * cols * num_images)
data = numpy.frombuffer(buf, dtype=numpy.uint8)
data = data.reshape(num_images, rows, cols, 1)
return data
def dense_to_one_hot(labels_dense, num_classes=10):
"""Convert class labels from scalars to one-hot vectors."""
num_labels = labels_dense.shape[0]
index_offset = numpy.arange(num_labels) * num_classes
labels_one_hot = numpy.zeros((num_labels, num_classes))
labels_one_hot.flat[index_offset + labels_dense.ravel()] = 1
return labels_one_hot
def extract_labels(filename, one_hot=False):
"""Extract the labels into a 1D uint8 numpy array [index]."""
print('Extracting', filename)
with gzip.open(filename) as bytestream:
magic = _read32(bytestream)
if magic != 2049:
raise ValueError(
'Invalid magic number %d in MNIST label file: %s' %
(magic, filename))
num_items = _read32(bytestream)
buf = bytestream.read(num_items)
labels = numpy.frombuffer(buf, dtype=numpy.uint8)
if one_hot:
return dense_to_one_hot(labels)
return labels
class DataSet(object):
def __init__(self, images, labels, fake_data=False, one_hot=False,
dtype=tf.float32):
"""Construct a DataSet.
one_hot arg is used only if fake_data is true. `dtype` can be either
`uint8` to leave the input as `[0, 255]`, or `float32` to rescale into
`[0, 1]`.
"""
dtype = tf.as_dtype(dtype).base_dtype
if dtype not in (tf.uint8, tf.float32):
raise TypeError('Invalid image dtype %r, expected uint8 or float32' %
dtype)
if fake_data:
self._num_examples = 10000
self.one_hot = one_hot
else:
assert images.shape[0] == labels.shape[0], (
'images.shape: %s labels.shape: %s' % (images.shape,
labels.shape))
self._num_examples = images.shape[0]
# Convert shape from [num examples, rows, columns, depth]
# to [num examples, rows*columns] (assuming depth == 1)
assert images.shape[3] == 1
images = images.reshape(images.shape[0],
images.shape[1] * images.shape[2])
if dtype == tf.float32:
# Convert from [0, 255] -> [0.0, 1.0].
images = images.astype(numpy.float32)
images = numpy.multiply(images, 1.0 / 255.0)
self._images = images
self._labels = labels
self._epochs_completed = 0
self._index_in_epoch = 0
@property
def images(self):
return self._images
@property
def labels(self):
return self._labels
@property
def num_examples(self):
return self._num_examples
@property
def epochs_completed(self):
return self._epochs_completed
def next_batch(self, batch_size, fake_data=False):
"""Return the next `batch_size` examples from this data set."""
if fake_data:
fake_image = [1] * 784
if self.one_hot:
fake_label = [1] + [0] * 9
else:
fake_label = 0
return [fake_image for _ in xrange(batch_size)], [
fake_label for _ in xrange(batch_size)]
start = self._index_in_epoch
self._index_in_epoch += batch_size
if self._index_in_epoch > self._num_examples:
# Finished epoch
self._epochs_completed += 1
# Shuffle the data
perm = numpy.arange(self._num_examples)
numpy.random.shuffle(perm)
self._images = self._images[perm]
self._labels = self._labels[perm]
# Start next epoch
start = 0
self._index_in_epoch = batch_size
assert batch_size <= self._num_examples
end = self._index_in_epoch
return self._images[start:end], self._labels[start:end]
def read_data_sets(train_dir, fake_data=False, one_hot=False, dtype=tf.float32):
class DataSets(object):
pass
data_sets = DataSets()
if fake_data:
def fake():
return DataSet([], [], fake_data=True, one_hot=one_hot, dtype=dtype)
data_sets.train = fake()
data_sets.validation = fake()
data_sets.test = fake()
return data_sets
TRAIN_IMAGES = 'train-images-idx3-ubyte.gz'
TRAIN_LABELS = 'train-labels-idx1-ubyte.gz'
TEST_IMAGES = 't10k-images-idx3-ubyte.gz'
TEST_LABELS = 't10k-labels-idx1-ubyte.gz'
VALIDATION_SIZE = 5000
local_file = maybe_download(TRAIN_IMAGES, train_dir)
train_images = extract_images(local_file)
local_file = maybe_download(TRAIN_LABELS, train_dir)
train_labels = extract_labels(local_file, one_hot=one_hot)
local_file = maybe_download(TEST_IMAGES, train_dir)
test_images = extract_images(local_file)
local_file = maybe_download(TEST_LABELS, train_dir)
test_labels = extract_labels(local_file, one_hot=one_hot)
validation_images = train_images[:VALIDATION_SIZE]
validation_labels = train_labels[:VALIDATION_SIZE]
train_images = train_images[VALIDATION_SIZE:]
train_labels = train_labels[VALIDATION_SIZE:]
data_sets.train = DataSet(train_images, train_labels, dtype=dtype)
data_sets.validation = DataSet(validation_images, validation_labels,
dtype=dtype)
data_sets.test = DataSet(test_images, test_labels, dtype=dtype)
return data_sets
然后进行代码的测试:
# import modules
import sys
import tensorflow as tf
from PIL import Image, ImageFilter
def predictint(imvalue):
"""
This function returns the predicted integer.
The imput is the pixel values from the imageprepare() function.
"""
# Define the model (same as when creating the model file)
x = tf.placeholder(tf.float32, [None, 784])
W = tf.Variable(tf.zeros([784, 10]))
b = tf.Variable(tf.zeros([10]))
def weight_variable(shape):
initial = tf.truncated_normal(shape, stddev=0.1)
return tf.Variable(initial)
def bias_variable(shape):
initial = tf.constant(0.1, shape=shape)
return tf.Variable(initial)
def conv2d(x, W):
return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
return tf.nn.max_pool(x, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1], padding='SAME')
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x_image = tf.reshape(x, [-1, 28, 28, 1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7 * 7 * 64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y_conv = tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
init_op = tf.initialize_all_variables()
saver = tf.train.Saver()
"""
Load the model_mnist.ckpt file
file is stored in the same directory as this python script is started
Use the model to predict the integer. Integer is returend as list.
Based on the documentatoin at
https://www.tensorflow.org/versions/master/how_tos/variables/index.html
"""
with tf.Session() as sess:
sess.run(init_op)
saver.restore(sess, "model_mnist.ckpt")
# print ("Model restored.")
prediction = tf.argmax(y_conv, 1)
return prediction.eval(feed_dict={x: [imvalue], keep_prob: 1.0}, session=sess)
def imageprepare(argv):
"""
This function returns the pixel values.
The imput is a png file location.
"""
im = Image.open(argv).convert('L')
width = float(im.size[0])
height = float(im.size[1])
newImage = Image.new('L', (28, 28), (255)) # creates white canvas of 28x28 pixels
if width > height: # check which dimension is bigger
# Width is bigger. Width becomes 20 pixels.
nheight = int(round((20.0 / width * height), 0)) # resize height according to ratio width
if (nheight == 0): # rare case but minimum is 1 pixel
nheigth = 1
# resize and sharpen
img = im.resize((20, nheight), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
wtop = int(round(((28 - nheight) / 2), 0)) # caculate horizontal pozition
newImage.paste(img, (4, wtop)) # paste resized image on white canvas
else:
# Height is bigger. Heigth becomes 20 pixels.
nwidth = int(round((20.0 / height * width), 0)) # resize width according to ratio height
if (nwidth == 0): # rare case but minimum is 1 pixel
nwidth = 1
# resize and sharpen
img = im.resize((nwidth, 20), Image.ANTIALIAS).filter(ImageFilter.SHARPEN)
wleft = int(round(((28 - nwidth) / 2), 0)) # caculate vertical pozition
newImage.paste(img, (wleft, 4)) # paste resized image on white canvas
# newImage.save("sample.png")
tv = list(newImage.getdata()) # get pixel values
# normalize pixels to 0 and 1. 0 is pure white, 1 is pure black.
tva = [(255 - x) * 1.0 / 255.0 for x in tv]
return tva
# print(tva)
def main(argv):
"""
Main function.
"""
imvalue = imageprepare(argv)
predint = predictint(imvalue)
print (predint[0]) # first value in list
if __name__ == "__main__":
main('2.png')
其中我用于测试的代码如下:
可以将图片另存到路径下面,然后进行测试。
(1)载入我的手写数字的图像。
(2)将图像转换为黑白(模式“L”)
(3)确定原始图像的尺寸是最大的
(4)调整图像的大小,使得最大尺寸(醚的高度及宽度)为20像素,并且以相同的比例最小化尺寸刻度。
(5)锐化图像。这会极大地强化结果。
(6)把图像粘贴在28×28像素的白色画布上。在最大的尺寸上从顶部或侧面居中图像4个像素。最大尺寸始终是20个像素和4 + 20 + 4 = 28,最小尺寸被定位在28和缩放的图像的新的大小之间差的一半。
(7)获取新的图像(画布+居中的图像)的像素值。
(8)归一化像素值到0和1之间的一个值(这也在TensorFlow MNIST教程中完成)。其中0是白色的,1是纯黑色。从步骤7得到的像素值是与之相反的,其中255是白色的,0黑色,所以数值必须反转。下述公式包括反转和规格化(255-X)* 1.0 / 255.0
来源:http://blog.csdn.net/juezhanangle/article/details/73018584


猜你喜欢
- numpy多维数组的创建多维数组(矩阵ndarray)ndarray的基本属性shape维度的大小ndim维度的个数dtype数据类型1.1
- ZeroClipboard.js是一个支持复制和粘贴的JavaScript插件,目前官方已经到2.x的版本了,但不支持IE9以下的浏览器,而
- 前提:安装xhtml2pdf https://pypi.python.org/pypi/xhtml2pdf/下载字体:微软雅黑;待转换的文件
- 本文仅作为基本操作流程的记录,不进行细节描述一、环境安装1、安装Pycharm在官网上下载最新版本Pycharm安装即可2、安装pyQT5p
- Vue-validator 是Vue的表单验证插件,供大家参考,具体内容如下Vue版本: 1.0.24 Vue-validator版本: 2
- 我们很容易用numpy()和from_numpy()将Tensor和NumPy中的数组相互转换。但是需要注意的一点是: 这两个函数所产生的T
- Python 包含6种数据类型,其中Number(数字)、String(字符串)、Tuple(元组)、List(列表)、Dictionary
- 一直用pycharm写代码一直用anaconda管理python环境但是今天我居然发现我不会更改pycharm当前的运行环境到我新建的ana
- 1 通过官网下载MySQL5.6版本压缩包,mysql-5.6.36-winx64.zip;2 在D盘创建目录,比如D:\MySQL,将my
- 在线文本去重复工具第一种方法:<textarea id="list" class="toolarea&q
- 本文实例讲述了Python iter()函数用法。分享给大家供大家参考,具体如下:python中的迭代器用起来非常灵巧,不仅可以迭代序列,也
- 本文实例讲述了Python实现模拟分割大文件及多线程处理的方法。分享给大家供大家参考,具体如下:#!/usr/bin/env python#
- 1.新式类与经典类在Python 2及以前的版本中,由任意内置类型派生出的类(只要一个内置类型位于类树的某个位置),都属于“新式类”,都会获
- HTTP上传的文件的原理HTTP协议的文件上传是通过HTTP POST请求实现的,使用multipart/form-data格式将待上传的文
- 在多个文件或者不同语言协同的项目中,python脚本经常需要从命令行直接读取参数。万能的python就自带了argprase包使得这一工作变
- 前言实现一个帧动画,使用的一个图,根据不同的时间显示不同的图。使用的就是如下所示的一张图,宽度780 * 300 ,使用加载图片 260 *
- 前言今晚就是新年夜啦,为了 刷一波存在感 送出我的祝福,同时让它看起来不像群发消息,我们简单地用三步来实现定制QQ祝福~
- 关于python读取xml文章很多,但大多文章都是贴一个xml文件,然后再贴个处理文件的代码。这样并不利于初学者的学习,希望这篇文章可以更通
- 如下所示:import matplotlib.pyplot as pltimport numpy as npx = [11422,11360
- iframe标签在网页中可以创建一个内嵌框架,通过指定src属性来调用另一个网页文档的内容。和frameset一样,用它来对网页结构进行拆分