tensorflow建立一个简单的神经网络的方法
作者:Mr丶Caleb 发布时间:2022-09-27 17:01:51
标签:tensorflow,神经网络
本笔记目的是通过tensorflow实现一个两层的神经网络。目的是实现一个二次函数的拟合。
如何添加一层网络
代码如下:
def add_layer(inputs, in_size, out_size, activation_function=None):
# add one more layer and return the output of this layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
注意该函数中是xW+b,而不是Wx+b。所以要注意乘法的顺序。x应该定义为[类别数量, 数据数量], W定义为[数据类别,类别数量]。
创建一些数据
# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
numpy的linspace函数能够产生等差数列。start,stop决定等差数列的起止值。endpoint参数指定包不包括终点值。
numpy.linspace(start, stop, num=50, endpoint=True, retstep=False, dtype=None)[source]
Return evenly spaced numbers over a specified interval.
Returns num evenly spaced samples, calculated over the interval [start, stop].
noise函数为添加噪声所用,这样二次函数的点不会与二次函数曲线完全重合。
numpy的newaxis可以新增一个维度而不需要重新创建相应的shape在赋值,非常方便,如上面的例子中就将x_data从一维变成了二维。
添加占位符,用作输入
# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
添加隐藏层和输出层
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
计算误差,并用梯度下降使得误差最小
# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
完整代码如下:
from __future__ import print_function
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
def add_layer(inputs, in_size, out_size, activation_function=None):
# add one more layer and return the output of this layer
Weights = tf.Variable(tf.random_normal([in_size, out_size]))
biases = tf.Variable(tf.zeros([1, out_size]) + 0.1)
Wx_plus_b = tf.matmul(inputs, Weights) + biases
if activation_function is None:
outputs = Wx_plus_b
else:
outputs = activation_function(Wx_plus_b)
return outputs
# Make up some real data
x_data = np.linspace(-1,1,300)[:, np.newaxis]
noise = np.random.normal(0, 0.05, x_data.shape)
y_data = np.square(x_data) - 0.5 + noise
# define placeholder for inputs to network
xs = tf.placeholder(tf.float32, [None, 1])
ys = tf.placeholder(tf.float32, [None, 1])
# add hidden layer
l1 = add_layer(xs, 1, 10, activation_function=tf.nn.relu)
# add output layer
prediction = add_layer(l1, 10, 1, activation_function=None)
# the error between prediciton and real data
loss = tf.reduce_mean(tf.reduce_sum(tf.square(ys - prediction),
reduction_indices=[1]))
train_step = tf.train.GradientDescentOptimizer(0.1).minimize(loss)
# important step
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
# plot the real data
fig = plt.figure()
ax = fig.add_subplot(1,1,1)
ax.scatter(x_data, y_data)
plt.ion()
plt.show()
for i in range(1000):
# training
sess.run(train_step, feed_dict={xs: x_data, ys: y_data})
if i % 50 == 0:
# to visualize the result and improvement
try:
ax.lines.remove(lines[0])
except Exception:
pass
prediction_value = sess.run(prediction, feed_dict={xs: x_data})
# plot the prediction
lines = ax.plot(x_data, prediction_value, 'r-', lw=5)
plt.pause(0.1)
运行结果:
来源:http://blog.csdn.net/qq_30159351/article/details/52639291
0
投稿
猜你喜欢
- 如果只是想实现将jenkins的构建结果发送到企业微信进行通知,最简便的方式是安装Qy Wechat Notification Plugin
- 测试代码:输出简单的ul li1.asp代码如下:<% response.write "<ul>" r
- 人常常感受到色彩对自己心理的影响,这些影响总是在不知不觉中发挥作用,左右我们的情绪。色彩的心理效应发生在不同层次中。有些属直接的刺激,有些要
- 一起画图吧为什么突然想搞这个画图软件呢不瞒各位,是因为最近接到了一个很小很小很小小得不能再小的小项目就是基于Tkinter,做一个简易的画图
- 前几天要写一个东西里面有用到读文件的。 可是我不想用FSO,我怕有的空间不支持。 &nbs
- unittest是python的一个单元测试框架关于断言它是用于对一个确定结果和预测结果的一种判断,如果结果正确无任何返回效果,如果结果错误
- python多线程适合IO密集型场景,而在CPU密集型场景,并不能充分利用多核CPU,而协程本质基于线程,同样不能充分发挥多核的优势。针对计
- 1 发送文本信息'''加密发送文本邮件'''def sendEmail(from_addr,
- 本文实例讲述了PHP实现将科学计数法转换为原始数字字符串的方法,分享给大家供大家参考。具体实现代码如下:function NumToStr(
- 使用 WinHttpRequest 伪造 HTTP 头信息,伪造 Referer 等信息。由于微软封锁了 XmlHttp 对象,所以无法伪造
- windows下python安装pip 简易教程,具体内容如下1.前提你要已经安装了 某个 版本的 python, 下载地址)安装后,需要配
- 目标网址:https://www.baidu.com/要获取的内容:链接分析:从下图可以看出只需要获取关键字,再构建就可以了。完整代码:im
- 给定一个字符串,求它最长的回文子串长度,例如输入字符串'35534321',它的最长回文子串是'3553',
- 一、python中对文件、文件夹操作时经常用到的os模块和shutil模块常用方法。1.得到当前工作目录,即当前Python脚本工作的目录路
- 目录深度遍历递归用栈来遍历磁盘广度遍历磁盘用队列遍历磁盘深度遍历递归import osdef get_files(path):
- 配置指令如下:[opcache]zend_extension=opcache.soopcache.enable_cli=1;共享内存大小,
- 一、安装pip install apscheduler二、ApScheduler 简介1 APScheduler的组件triggers:触发
- 哪个Python版本?当我提及Python,所指的就是CPython 2(准确的是2.7).我会显式提醒那些相同的代码在CPython 3
- 表结构的修改1、表结构修改后,原来表中已存在的数据,就会出现结构混乱,makemigrations更新表的时候就会出错比如第一次建模型,漏了
- 我想把存在数据库里的每天24小时来访者数另放到一个Excel文件中去,可以吗?可以,其实就是将数据库里面的内容生成一个Excel文件:toe