TensorFlow教程Softmax逻辑回归识别手写数字MNIST数据集
作者:零尾 发布时间:2021-05-24 18:25:35
标签:TensorFlow,Softmax,MNIST,逻辑回归
基于MNIST数据集的逻辑回归模型做十分类任务
没有隐含层的Softmax Regression只能直接从图像的像素点推断是哪个数字,而没有特征抽象的过程。多层神经网络依靠隐含层,则可以组合出高阶特征,比如横线、竖线、圆圈等,之后可以将这些高阶特征或者说组件再组合成数字,就能实现精准的匹配和分类。
import tensorflow as tf
import numpy as np
import input_data
print('Download and Extract MNIST dataset')
mnist = input_data.read_data_sets('data/', one_hot=True) # one_hot=True意思是编码格式为01编码
print("tpye of 'mnist' is %s" % (type(mnist)))
print("number of train data is %d" % (mnist.train.num_examples))
print("number of test data is %d" % (mnist.test.num_examples))
trainimg = mnist.train.images
trainlabel = mnist.train.labels
testimg = mnist.test.images
testlabel = mnist.test.labels
print("MNIST loaded")
"""
print("type of 'trainimg' is %s" % (type(trainimg)))
print("type of 'trainlabel' is %s" % (type(trainlabel)))
print("type of 'testimg' is %s" % (type(testimg)))
print("type of 'testlabel' is %s" % (type(testlabel)))
print("------------------------------------------------")
print("shape of 'trainimg' is %s" % (trainimg.shape,))
print("shape of 'trainlabel' is %s" % (trainlabel.shape,))
print("shape of 'testimg' is %s" % (testimg.shape,))
print("shape of 'testlabel' is %s" % (testlabel.shape,))
"""
x = tf.placeholder(tf.float32, [None, 784])
y = tf.placeholder(tf.float32, [None, 10]) # None is for infinite
w = tf.Variable(tf.zeros([784, 10])) # 为了方便直接用0初始化,可以高斯初始化
b = tf.Variable(tf.zeros([10])) # 10分类的任务,10种label,所以只需要初始化10个b
pred = tf.nn.softmax(tf.matmul(x, w) + b) # 前向传播的预测值
cost = tf.reduce_mean(-tf.reduce_sum(y*tf.log(pred), reduction_indices=[1])) # 交叉熵损失函数
optm = tf.train.GradientDescentOptimizer(0.01).minimize(cost)
corr = tf.equal(tf.argmax(pred, 1), tf.argmax(y, 1)) # tf.equal()对比预测值的索引和真实label的索引是否一样,一样返回True,不一样返回False
accr = tf.reduce_mean(tf.cast(corr, tf.float32))
init = tf.global_variables_initializer() # 全局参数初始化器
training_epochs = 100 # 所有样本迭代100次
batch_size = 100 # 每进行一次迭代选择100个样本
display_step = 5
# SESSION
sess = tf.Session() # 定义一个Session
sess.run(init) # 在sess里run一下初始化操作
# MINI-BATCH LEARNING
for epoch in range(training_epochs): # 每一个epoch进行循环
avg_cost = 0. # 刚开始损失值定义为0
num_batch = int(mnist.train.num_examples/batch_size)
for i in range(num_batch): # 每一个batch进行选择
batch_xs, batch_ys = mnist.train.next_batch(batch_size) # 通过next_batch()就可以一个一个batch的拿数据,
sess.run(optm, feed_dict={x: batch_xs, y: batch_ys}) # run一下用梯度下降进行求解,通过placeholder把x,y传进来
avg_cost += sess.run(cost, feed_dict={x: batch_xs, y:batch_ys})/num_batch
# DISPLAY
if epoch % display_step == 0: # display_step之前定义为5,这里每5个epoch打印一下
train_acc = sess.run(accr, feed_dict={x: batch_xs, y:batch_ys})
test_acc = sess.run(accr, feed_dict={x: mnist.test.images, y: mnist.test.labels})
print("Epoch: %03d/%03d cost: %.9f TRAIN ACCURACY: %.3f TEST ACCURACY: %.3f"
% (epoch, training_epochs, avg_cost, train_acc, test_acc))
print("DONE")
迭代100次跑一下模型,最终,在测试集上可以达到92.2%的准确率,虽然还不错,但是还达不到实用的程度。手写数字的识别的主要应用场景是识别银行支票,如果准确率不够高,可能会引起严重的后果。
Epoch: 095/100 loss: 0.283259882 train_acc: 0.940 test_acc: 0.922
插一些知识点,关于tensorflow中一些函数的用法
sess = tf.InteractiveSession()
arr = np.array([[31, 23, 4, 24, 27, 34],
[18, 3, 25, 0, 6, 35],
[28, 14, 33, 22, 30, 8],
[13, 30, 21, 19, 7, 9],
[16, 1, 26, 32, 2, 29],
[17, 12, 5, 11, 10, 15]])
在tensorflow中打印要用.eval()
tf.rank(arr).eval() # 打印矩阵arr的维度
tf.shape(arr).eval() # 打印矩阵arr的大小
tf.argmax(arr, 0).eval() # 打印最大值的索引,参数0为按列求索引,1为按行求索引
来源:https://blog.csdn.net/lwplwf/article/details/60603746


猜你喜欢
- 最近写运维自动化平台,需要用python写很多的小功能模块。这里就分享一个用Python的paramiko来实现功能的一段代码:复制远程服务
- Python 的 httpx 包是一个复杂的 Web 客户端。当你安装它后,你就可以用它来从网站上获取数据。像往常一样,
- 如下所示:var myarr=new Array(); //先声明一维 for(var i=0;i<2;i++){ //一
- 本文实例讲述了Python基于回溯法子集树模板解决马踏棋盘问题。分享给大家供大家参考,具体如下:问题将马放到国际象棋的8*8棋盘board上
- 前言最近由于在寻找方向上迷失自我,准备了解更多的计算机视觉任务重的模型。看到语义分割任务重Unet一个有意思的模型,我准备来复现一下它。一、
- 无论安装何版本的mysql,在管理工具的服务中启动mysql服务时都会在中途报错。内容为:在 本地计算机 无法启动mysql服务 错误106
- 前言提到数据库,大家第一时间想到的可能是 sql 数据库,这种数据库非常好用,但是对于新手就不是很容易上手,需要熟悉一段时间才可以大概掌握。
- 本文主要介绍了pandas统计重复值次数的方法实现,分享给大家,具体如下:from pandas import DataFramedf =
- 前言最近做了几个简单的爬虫python程序,于是就想做个窗口看看效果。首先是,窗口的话,以前没怎么接触过,就先考虑用Qt制作简单的ui。这里
- 前言:随机数模块实现了各种分布的伪随机数生成器。对于整数,从范围中有统一的选择。 对于序列,存在随机元素的统一选择、用于生成列表的随机排列的
- 1.TCP是一种面向连接的可靠地协议,在一方发送数据之前,必须在双方之间建立一个连接,建立的过程需要经过三次握手,通信完成后要拆除连接,需要
- 本文实例讲述了Python使用PyCrypto实现AES加密功能。分享给大家供大家参考,具体如下:#!/usr/bin/env python
- 一. Go 切片和 Go 数组定义Go 切片:又称动态数组,它实际是基于数组类型做的一层封装。Go 数组:数组是内置(build-in)类型
- 一、输出指令ASP的输出指令<% =expression %>显示表达式的值。这个输出指令等同于使用Resp
- 本文转自公众号:"算法与编程之美"1、问题描述Python中数据类型有列表,元组,字典,队列,栈,树等等。像列表,元组这
- 这篇文章主要介绍了Python如何基于smtplib发不同格式的邮件,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习
- 前言最近在工作中碰到一个小的排序问题,需要按嵌套对象的多个属性来排序,于是发现了Python里的operator模块和sorted函数组合可
- 目录需求背景思路分析UI展示开始使用一 编写支付组件模板二 支付组件的JS相关代码和说明附:组件JS完整的源码需求背景市场报告列表展示的报告
- 我们在做深度学习的过程中,经常面临图片样本不足、不平衡的情况,在本文中,作者结合实际工作经验,通过图像的移动、缩放、旋转、增加噪声等图像变换
- 本文实例讲述了python使用自定义user-agent抓取网页的方法。分享给大家供大家参考。具体如下:下面python代码通过urllib