用tensorflow构建线性回归模型的示例代码
作者:freedom098 发布时间:2022-04-12 03:41:47
用tensorflow构建简单的线性回归模型是tensorflow的一个基础样例,但是原有的样例存在一些问题,我在实际调试的过程中做了一点自己的改进,并且有一些体会。
首先总结一下tf构建模型的总体套路
1、先定义模型的整体图结构,未知的部分,比如输入就用placeholder来代替。
2、再定义最后与目标的误差函数。
3、最后选择优化方法。
另外几个值得注意的地方是:
1、tensorflow构建模型第一步是先用代码搭建图模型,此时图模型是静止的,是不产生任何运算结果的,必须使用Session来驱动。
2、第二步根据问题的不同要求构建不同的误差函数,这个函数就是要求优化的函数。
3、调用合适的优化器优化误差函数,注意,此时反向传播调整参数的过程隐藏在了图模型当中,并没有显式显现出来。
4、tensorflow的中文意思是张量流动,也就是说有两个意思,一个是参与运算的不仅仅是标量或是矩阵,甚至可以是具有很高维度的张量,第二个意思是这些数据在图模型中流动,不停地更新。
5、session的run函数中,按照传入的操作向上查找,凡是操作中涉及的无论是变量、常量都要参与运算,占位符则要在run过程中以字典形式传入。
以上时tensorflow的一点认识,下面是关于梯度下降的一点新认识。
1、梯度下降法分为批量梯度下降和随机梯度下降法,第一种是所有数据都参与运算后,计算误差函数,根据此误差函数来更新模型参数,实际调试发现,如果定义误差函数为平方误差函数,这个值很快就会飞掉,原因是,批量平方误差都加起来可能会很大,如果此时学习率比较高,那么调整就会过,造成模型参数向一个方向大幅调整,造成最终结果发散。所以这个时候要降低学习率,让参数变化不要太快。
2、随机梯度下降法,每次用一个数据计算误差函数,然后更新模型参数,这个方法有可能会造成结果出现震荡,而且麻烦的是由于要一个个取出数据参与运算,而不是像批量计算那样采用了广播或者向量化乘法的机制,收敛会慢一些。但是速度要比使用批量梯度下降要快,原因是不需要每次计算全部数据的梯度了。比较折中的办法是mini-batch,也就是每次选用一小部分数据做梯度下降,目前这也是最为常用的方法了。
3、epoch概念:所有样本集过完一轮,就是一个epoch,很明显,如果是严格的随机梯度下降法,一个epoch内更新了样本个数这么多次参数,而批量法只更新了一次。
以上是我个人的一点认识,希望大家看到有不对的地方及时批评指针,不胜感激!
#encoding=utf-8
__author__ = 'freedom'
import tensorflow as tf
import numpy as np
def createData(dataNum,w,b,sigma):
train_x = np.arange(dataNum)
train_y = w*train_x+b+np.random.randn()*sigma
#print train_x
#print train_y
return train_x,train_y
def linerRegression(train_x,train_y,epoch=100000,rate = 0.000001):
train_x = np.array(train_x)
train_y = np.array(train_y)
n = train_x.shape[0]
x = tf.placeholder("float")
y = tf.placeholder("float")
w = tf.Variable(tf.random_normal([1])) # 生成随机权重
b = tf.Variable(tf.random_normal([1]))
pred = tf.add(tf.mul(x,w),b)
loss = tf.reduce_sum(tf.pow(pred-y,2))
optimizer = tf.train.GradientDescentOptimizer(rate).minimize(loss)
init = tf.initialize_all_variables()
sess = tf.Session()
sess.run(init)
print 'w start is ',sess.run(w)
print 'b start is ',sess.run(b)
for index in range(epoch):
#for tx,ty in zip(train_x,train_y):
#sess.run(optimizer,{x:tx,y:ty})
sess.run(optimizer,{x:train_x,y:train_y})
# print 'w is ',sess.run(w)
# print 'b is ',sess.run(b)
# print 'pred is ',sess.run(pred,{x:train_x})
# print 'loss is ',sess.run(loss,{x:train_x,y:train_y})
#print '------------------'
print 'loss is ',sess.run(loss,{x:train_x,y:train_y})
w = sess.run(w)
b = sess.run(b)
return w,b
def predictionTest(test_x,test_y,w,b):
W = tf.placeholder(tf.float32)
B = tf.placeholder(tf.float32)
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
n = test_x.shape[0]
pred = tf.add(tf.mul(X,W),B)
loss = tf.reduce_mean(tf.pow(pred-Y,2))
sess = tf.Session()
loss = sess.run(loss,{X:test_x,Y:test_y,W:w,B:b})
return loss
if __name__ == "__main__":
train_x,train_y = createData(50,2.0,7.0,1.0)
test_x,test_y = createData(20,2.0,7.0,1.0)
w,b = linerRegression(train_x,train_y)
print 'weights',w
print 'bias',b
loss = predictionTest(test_x,test_y,w,b)
print loss
来源:http://blog.csdn.net/freedom098/article/details/52106931
猜你喜欢
- 线程和进程1、线程共享创建它的进程的地址空间,进程有自己的地址空间2、线程可以访问进程所有的数据,线程可以相互访问3、线程之间的数据是独立的
- mmh3安装方法哈希方法主要有MD、SHA、Murmur、CityHash、MAC等几种方法。mmh3全程murmurhash3,是一种非加
- 使用torchvision来进行图片的数据增广数据增强就是增强一个已有数据集,使得有更多的多样性。对于图片数据来说,就是改变图片的颜色和形状
- context 有什么作用context 主要用来在goroutine 之间传递上下文信息,包括:取消信号、超时时间、截止时间、k-v 等。
- 本文实例讲述了PHP查询快递信息的方法。分享给大家供大家参考。具体如下:这里使用快递100物流查询官方文档中只能返回html的接口也可以返回
- 看了下传统的方法,觉得不好,太麻烦。自己重写了个,思路比较新。这个函数的优点是html代码可以很简洁,使用图片也可以很少,只需要两张图片。事
- 区别:xx:公有变量,所有对象都可以访问;xxx:双下划线代表着是系统定义的名字。__xxx:双前置下划线,避免与子类中的属性命名冲突,无法
- 本文实例讲述了CentOS环境下安装Redis3.0及phpredis扩展测试。分享给大家供大家参考,具体如下:线上的统一聊天及推送系统re
- 目录一、线程基础以及守护进程二、线程锁(互斥锁)三、线程锁(递归锁)四、死锁五、队列六、相关面试题七、判断数据是否安全八、进程池 &
- 1. 截取GB2312中文字符串 <?php //截取中文字符串 function mysubstr($str, $star
- 内置数据类型在编程中,数据类型是一个重要的概念。变量可以存储不同类型的数据,并且不同类型可以执行不同的操作。在这些类别中,Python 默认
- 本文实例讲述了Python双链表原理与实现方法。分享给大家供大家参考,具体如下:Python实现双链表文章目录Python实现双链表定义链表
- 目录项目初始化选择 MQTT 客户端库Pip 安装 Paho MQTT 客户端Python MQTT 使用连接 MQTT 服务器导入 Pah
- 体系结构 Microsoft按照客户/服务器体系结构的分布进行操作。这种方法产生不必要的代价和复杂性。在Internet中,Oracle已经
- 在Windows环境下,经常遇到系统Over的情况,如果你在新装了系统和SQL Server 2005后,需要把SQL Server2000
- 最近在做项目的时候经常会用到定时任务,由于我的项目是使用Java来开发,用的是SpringBoot框架,因此要实现这个定时任务其实并不难。后
- 这次我们讨论的是,区分有单选框的选项和普通的选项~~乍听起来,可能不太理解我说了什么,下面举个例子先~~1、标签的单选~~例如QQ秀的支付流
- AddHeaderAddHeader 方法用指定的值添加 HTML 标题。该方法常常向响应添加新的 HTTP 标题。它并不替代现有的同名标题
- 在遥感应用中,我们经常需要对某一景遥感影像中的全部像元的像素值进行平均值求取——这一操作很好实现,基
- 在PHP界谈模板引擎,必不可免的要拿Smarty开刀, 这个无比傻帽的却又带有一点点官方色彩的模板引擎, 如果没有我这样人富有正义感又富有创