TensorFlow搭建神经网络最佳实践
作者:marsjhao 发布时间:2021-03-11 18:59:26
一、TensorFLow完整样例
在MNIST数据集上,搭建一个简单神经网络结构,一个包含ReLU单元的非线性化处理的两层神经网络。在训练神经网络的时候,使用带指数衰减的学习率设置、使用正则化来避免过拟合、使用滑动平均模型来使得最终的模型更加健壮。
程序将计算神经网络前向传播的部分单独定义一个函数inference,训练部分定义一个train函数,再定义一个主函数main。
完整程序:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Thu May 25 08:56:30 2017
@author: marsjhao
"""
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
INPUT_NODE = 784 # 输入节点数
OUTPUT_NODE = 10 # 输出节点数
LAYER1_NODE = 500 # 隐含层节点数
BATCH_SIZE = 100
LEARNING_RETE_BASE = 0.8 # 基学习率
LEARNING_RETE_DECAY = 0.99 # 学习率的衰减率
REGULARIZATION_RATE = 0.0001 # 正则化项的权重系数
TRAINING_STEPS = 10000 # 迭代训练次数
MOVING_AVERAGE_DECAY = 0.99 # 滑动平均的衰减系数
# 传入神经网络的权重和偏置,计算神经网络前向传播的结果
def inference(input_tensor, avg_class, weights1, biases1, weights2, biases2):
# 判断是否传入ExponentialMovingAverage类对象
if avg_class == None:
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights1) + biases1)
return tf.matmul(layer1, weights2) + biases2
else:
layer1 = tf.nn.relu(tf.matmul(input_tensor, avg_class.average(weights1))
+ avg_class.average(biases1))
return tf.matmul(layer1, avg_class.average(weights2))\
+ avg_class.average(biases2)
# 神经网络模型的训练过程
def train(mnist):
x = tf.placeholder(tf.float32, [None,INPUT_NODE], name='x-input')
y_ = tf.placeholder(tf.float32, [None, OUTPUT_NODE], name='y-input')
# 定义神经网络结构的参数
weights1 = tf.Variable(tf.truncated_normal([INPUT_NODE, LAYER1_NODE],
stddev=0.1))
biases1 = tf.Variable(tf.constant(0.1, shape=[LAYER1_NODE]))
weights2 = tf.Variable(tf.truncated_normal([LAYER1_NODE, OUTPUT_NODE],
stddev=0.1))
biases2 = tf.Variable(tf.constant(0.1, shape=[OUTPUT_NODE]))
# 计算非滑动平均模型下的参数的前向传播的结果
y = inference(x, None, weights1, biases1, weights2, biases2)
global_step = tf.Variable(0, trainable=False) # 定义存储当前迭代训练轮数的变量
# 定义ExponentialMovingAverage类对象
variable_averages = tf.train.ExponentialMovingAverage(
MOVING_AVERAGE_DECAY, global_step) # 传入当前迭代轮数参数
# 定义对所有可训练变量trainable_variables进行更新滑动平均值的操作op
variables_averages_op = variable_averages.apply(tf.trainable_variables())
# 计算滑动模型下的参数的前向传播的结果
average_y = inference(x, variable_averages, weights1, biases1, weights2, biases2)
# 定义交叉熵损失值
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(
logits=y, labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
# 定义L2正则化器并对weights1和weights2正则化
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
regularization = regularizer(weights1) + regularizer(weights2)
loss = cross_entropy_mean + regularization # 总损失值
# 定义指数衰减学习率
learning_rate = tf.train.exponential_decay(LEARNING_RETE_BASE, global_step,
mnist.train.num_examples / BATCH_SIZE, LEARNING_RETE_DECAY)
# 定义梯度下降操作op,global_step参数可实现自加1运算
train_step = tf.train.GradientDescentOptimizer(learning_rate)\
.minimize(loss, global_step=global_step)
# 组合两个操作op
train_op = tf.group(train_step, variables_averages_op)
'''''
# 与tf.group()等价的语句
with tf.control_dependencies([train_step, variables_averages_op]):
train_op = tf.no_op(name='train')
'''
# 定义准确率
# 在最终预测的时候,神经网络的输出采用的是经过滑动平均的前向传播计算结果
correct_prediction = tf.equal(tf.argmax(average_y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 初始化回话sess并开始迭代训练
with tf.Session() as sess:
sess.run(tf.global_variables_initializer())
# 验证集待喂入数据
validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
# 测试集待喂入数据
test_feed = {x: mnist.test.images, y_: mnist.test.labels}
for i in range(TRAINING_STEPS):
if i % 1000 == 0:
validate_acc = sess.run(accuracy, feed_dict=validate_feed)
print('After %d training steps, validation accuracy'
' using average model is %f' % (i, validate_acc))
xs, ys = mnist.train.next_batch(BATCH_SIZE)
sess.run(train_op, feed_dict={x: xs, y_:ys})
test_acc = sess.run(accuracy, feed_dict=test_feed)
print('After %d training steps, test accuracy'
' using average model is %f' % (TRAINING_STEPS, test_acc))
# 主函数
def main(argv=None):
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
train(mnist)
# 当前的python文件是shell文件执行的入口文件,而非当做import的python module。
if __name__ == '__main__': # 在模块内部执行
tf.app.run() # 调用main函数并传入所需的参数list
二、分析与改进设计
1. 程序分析改进
第一,计算前向传播的函数inference中需要将所有的变量以参数的形式传入函数,当神经网络结构变得更加复杂、参数更多的时候,程序的可读性将变得非常差。
第二,在程序退出时,训练好的模型就无法再利用,且大型神经网络的训练时间都比较长,在训练过程中需要每隔一段时间保存一次模型训练的中间结果,这样如果在训练过程中程序死机,死机前的最新的模型参数仍能保留,杜绝了时间和资源的浪费。
第三,将训练和测试分成两个独立的程序,将训练和测试都会用到的前向传播的过程抽象成单独的库函数。这样就保证了在训练和预测两个过程中所调用的前向传播计算程序是一致的。
2. 改进后程序设计
mnist_inference.py
该文件中定义了神经网络的前向传播过程,其中的多次用到的weights定义过程又单独定义成函数。
通过tf.get_variable函数来获取变量,在神经网络训练时创建这些变量,在测试时会通过保存的模型加载这些变量的取值,而且可以在变量加载时将滑动平均值重命名。所以可以直接通过同样的名字在训练时使用变量自身,在测试时使用变量的滑动平均值。
mnist_train.py
该程序给出了神经网络的完整训练过程。
mnist_eval.py
在滑动平均模型上做测试。
通过tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)获取最新模型的文件名,实际是获取checkpoint文件的所有内容。
三、TensorFlow最佳实践样例
mnist_inference.py
import tensorflow as tf
INPUT_NODE = 784
OUTPUT_NODE = 10
LAYER1_NODE = 500
def get_weight_variable(shape, regularizer):
weights = tf.get_variable("weights", shape,
initializer=tf.truncated_normal_initializer(stddev=0.1))
if regularizer != None:
# 将权重参数的正则化项加入至损失集合
tf.add_to_collection('losses', regularizer(weights))
return weights
def inference(input_tensor, regularizer):
with tf.variable_scope('layer1'):
weights = get_weight_variable([INPUT_NODE, LAYER1_NODE], regularizer)
biases = tf.get_variable("biases", [LAYER1_NODE],
initializer=tf.constant_initializer(0.0))
layer1 = tf.nn.relu(tf.matmul(input_tensor, weights) + biases)
with tf.variable_scope('layer2'):
weights = get_weight_variable([LAYER1_NODE, OUTPUT_NODE], regularizer)
biases = tf.get_variable("biases", [OUTPUT_NODE],
initializer=tf.constant_initializer(0.0))
layer2 = tf.matmul(layer1, weights) + biases
return layer2
mnist_train.py
import os
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
BATCH_SIZE = 100
LEARNING_RATE_BASE = 0.8
LEARNING_RATE_DECAY = 0.99
REGULARIZATION_RATE = 0.0001
TRAINING_STEPS = 10000
MOVING_AVERAGE_DECAY = 0.99
MODEL_SAVE_PATH = "Model_Folder/"
MODEL_NAME = "model.ckpt"
def train(mnist):
# 定义输入placeholder
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE],
name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE],
name='y-input')
# 定义正则化器及计算前向过程输出
regularizer = tf.contrib.layers.l2_regularizer(REGULARIZATION_RATE)
y = mnist_inference.inference(x, regularizer)
# 定义当前训练轮数及滑动平均模型
global_step = tf.Variable(0, trainable=False)
variable_averages = tf.train.ExponentialMovingAverage(MOVING_AVERAGE_DECAY,
global_step)
variables_averages_op = variable_averages.apply(tf.trainable_variables())
# 定义损失函数
cross_entropy = tf.nn.sparse_softmax_cross_entropy_with_logits(logits=y,
labels=tf.argmax(y_, 1))
cross_entropy_mean = tf.reduce_mean(cross_entropy)
loss = cross_entropy_mean + tf.add_n(tf.get_collection('losses'))
# 定义指数衰减学习率
learning_rate = tf.train.exponential_decay(LEARNING_RATE_BASE, global_step,
mnist.train.num_examples / BATCH_SIZE, LEARNING_RATE_DECAY)
# 定义训练操作,包括模型训练及滑动模型操作
train_step = tf.train.GradientDescentOptimizer(learning_rate)\
.minimize(loss, global_step=global_step)
train_op = tf.group(train_step, variables_averages_op)
# 定义Saver类对象,保存模型,TensorFlow持久化类
saver = tf.train.Saver()
# 定义会话,启动训练过程
with tf.Session() as sess:
tf.global_variables_initializer().run()
for i in range(TRAINING_STEPS):
xs, ys = mnist.train.next_batch(BATCH_SIZE)
_, loss_value, step = sess.run([train_op, loss, global_step],
feed_dict={x: xs, y_: ys})
if i % 1000 == 0:
print("After %d training step(s), loss on training batch is %g."\
% (step, loss_value))
# save方法的global_step参数可以让每个被保存的模型的文件名末尾加上当前训练轮数
saver.save(sess, os.path.join(MODEL_SAVE_PATH, MODEL_NAME),
global_step=global_step)
def main(argv=None):
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
train(mnist)
if __name__ == '__main__':
tf.app.run()
mnist_eval.py
import time
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import mnist_inference
import mnist_train
EVAL_INTERVAL_SECS = 10
def evaluate(mnist):
with tf.Graph().as_default() as g:
# 定义输入placeholder
x = tf.placeholder(tf.float32, [None, mnist_inference.INPUT_NODE],
name='x-input')
y_ = tf.placeholder(tf.float32, [None, mnist_inference.OUTPUT_NODE],
name='y-input')
# 定义feed字典
validate_feed = {x: mnist.validation.images, y_: mnist.validation.labels}
# 测试时不加参数正则化损失
y = mnist_inference.inference(x, None)
# 计算正确率
correct_prediction = tf.equal(tf.argmax(y, 1), tf.argmax(y_, 1))
accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
# 加载滑动平均模型下的参数值
variable_averages = tf.train.ExponentialMovingAverage(
mnist_train.MOVING_AVERAGE_DECAY)
saver = tf.train.Saver(variable_averages.variables_to_restore())
# 每隔EVAL_INTERVAL_SECS秒启动一次会话
while True:
with tf.Session() as sess:
ckpt = tf.train.get_checkpoint_state(mnist_train.MODEL_SAVE_PATH)
if ckpt and ckpt.model_checkpoint_path:
saver.restore(sess, ckpt.model_checkpoint_path)
# 取checkpoint文件中的当前迭代轮数global_step
global_step = ckpt.model_checkpoint_path\
.split('/')[-1].split('-')[-1]
accuracy_score = sess.run(accuracy, feed_dict=validate_feed)
print("After %s training step(s), validation accuracy = %g"\
% (global_step, accuracy_score))
else:
print('No checkpoint file found')
return
time.sleep(EVAL_INTERVAL_SECS)
def main(argv=None):
mnist = input_data.read_data_sets("MNIST_data", one_hot=True)
evaluate(mnist)
if __name__ == '__main__':
tf.app.run()
来源:http://blog.csdn.net/marsjhao/article/details/72831021
猜你喜欢
- 前言本文的操作环境:ubuntu,Python2.7,采用的是Pycharm进行代码编辑,个人很喜欢它的代码自动补齐功能。示例图如上图,我们
- AJAX应用因为它们的表现力的丰富、更加互动和更加迅速的响应得到了赞扬声;这些优点都是通过使用XMLHttpRequest对象来动态的载入数
- 炫酷地图前期我们介绍了很多的地图模板,不管是全球的还是中国的,其实我感觉都十分的炫酷,哈哈哈,可是还有更加神奇的,更加炫酷的地图模板,下面让
- 本文分析了python3新特性函数注释Function Annotations用法。分享给大家供大家参考,具体如下:Python 3.X新增
- 本文实例讲述了wxPython定时器wx.Timer简单应用。分享给大家供大家参考。具体如下:# -*- coding: utf-8 -*-
- 今天来填坑, 昨天说playwright未必一定要使用pytest-playwright包。 它也可以和pyunit一起使用。那么今天,田辛
- openpyxl 的用法实例1.1 Openpyxl 库的安装使用openpyxl 模块是一个读写 Excel 2010 文档的 Pytho
- 本文实例讲述了Python数据结构与算法之图的广度优先与深度优先搜索算法。分享给大家供大家参考,具体如下:根据 * 的伪代码实现:广度优先
- Python装饰器(decorator)是在程序开发中经常使用到的功能,合理使用装饰器,能让我们的程序如虎添翼。装饰器引入初期及问题诞生假如
- 文字向下滾動,逐渐隐藏效果~ 挺好的 <!DOCTYPE html PUBLIC "-//W3C//DTD XHT
- 使用T_SQL创建数据库 TestSchool 创建一个学生表 TblStudent 创建学生成绩表 TblScore q tScoreId
- 前言笔者用的是mac开发,但是mac自带的php功能安装十分不方便,并且和线上的linux开发环境不一致。在没有用docker之前一直用va
- 本文使用pygame实现播放mp3,文中用到pygame及mutagen库,安装:pip install pygamepip install
- mysql 配置白名单访问的步骤1.登录mysql -uroot -pmysql2.切换至mysql库use mysql;3.查看有白名单权
- 第一个问题是重命名数据库问题:在企业管理器中是无法直接对数据库重命名的,只能在查询分析器中操作create proc killspid (@
- 话不多说,直接附上源码,仅供参考封装了一下,要用的话直接调用下面getEvent函数即可function getEvent() { if (
- Python之所以这么流行,是因为它不仅能够应用于科技领域,还能用来做许多其他学科的研究工具,绘制地图便是其功能之一。今天我们用matplo
- 二分查找法(Binary Search)是一种在有序数组中查找某一特定元素的算法,它的思想是将数组从中间分成两部分,判断目标元素在哪一部分中
- 为了更好的说明问题,首先引出下面的题目//请说明下面变量 a-d 的值 var a = [[1][1]]; var b = [['a
- Python脚本常见参数获取和处理平常写 python 脚本时会有一些从命令行获取参数的需求,这篇文章记录下常见的参数获取和处理方式。1.