TensorFlow实现简单线性回归
作者:kylinxjd 发布时间:2023-09-18 13:23:45
标签:TensorFlow,线性回归
本文实例为大家分享了TensorFlow实现简单线性回归的具体代码,供大家参考,具体内容如下
简单的一元线性回归
一元线性回归公式:
其中x是特征:[x1,x2,x3,…,xn,]T
w是权重,b是偏置值
代码实现
导入必须的包
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
# 屏蔽warning以下的日志信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
产生模拟数据
def generate_data():
x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
return x, y
x是100行1列的数据,tf.matmul是矩阵相乘,所以权值设置成二维的。
设置的w是1.3, b是1
实现回归
def myregression():
"""
自实现线性回归
:return:
"""
x, y = generate_data()
# 建立模型 y = x * w + b
# w 1x1的二维数据
w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
b = tf.Variable(0.0, name='bias_b')
y_predict = tf.matmul(x, a) + b
# 建立损失函数
loss = tf.reduce_mean(tf.square(y_predict - y))
# 训练
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss=loss)
# 初始化全局变量
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print('初始的权重:%f偏置值:%f' % (a.eval(), b.eval()))
# 训练优化
for i in range(1, 100):
sess.run(train_op)
print('第%d次优化的权重:%f偏置值:%f' % (i, a.eval(), b.eval()))
# 显示回归效果
show_img(x.eval(), y.eval(), y_predict.eval())
使用matplotlib查看回归效果
def show_img(x, y, y_pre):
plt.scatter(x, y)
plt.plot(x, y_pre)
plt.show()
完整代码
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def generate_data():
x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
return x, y
def myregression():
"""
自实现线性回归
:return:
"""
x, y = generate_data()
# 建立模型 y = x * w + b
w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
b = tf.Variable(0.0, name='bias_b')
y_predict = tf.matmul(x, w) + b
# 建立损失函数
loss = tf.reduce_mean(tf.square(y_predict - y))
# 训练
train_op = tf.train.GradientDescentOptimizer(0.0001).minimize(loss=loss)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print('初始的权重:%f偏置值:%f' % (w.eval(), b.eval()))
# 训练优化
for i in range(1, 35000):
sess.run(train_op)
print('第%d次优化的权重:%f偏置值:%f' % (i, w.eval(), b.eval()))
show_img(x.eval(), y.eval(), y_predict.eval())
def show_img(x, y, y_pre):
plt.scatter(x, y)
plt.plot(x, y_pre)
plt.show()
if __name__ == '__main__':
myregression()
看看训练的结果(因为数据是随机产生的,每次的训练结果都会不同,可适当调节梯度下降的学习率和训练步数)
35000次的训练结果
来源:https://blog.csdn.net/kylinxjd/article/details/105557304
0
投稿
猜你喜欢
- 本文实例讲述了Python类装饰器。分享给大家供大家参考,具体如下:编写类装饰器类装饰器类似于函数装饰器的概念,但它应用于类,它们可以用于管
- 从Python字符串中删除最后一个分号或者逗号第一种方法使用 str.rstrip() 方法从字符串中删除最后一个逗号,例如 new_str
- IE下专属CSS:<![if !IE]><link rel="stylesheet" type=&qu
- 本文实例讲述了asp.net实现图片以二进制流输出的两种方法。分享给大家供大家参考,具体如下:方法一:System.IO.MemoryStr
- 上篇文章讲了js中的传值和传址 和 函数的作用域.这章我们来探讨js中的变量,表达式,和运算符 还有一些 js 语句。升级中……1, 表达式
- 这个javascript农历日历,万年历代码网上看到的,很不错,功能齐全,值得收藏!功能介绍:动态显示当前世界各国各时区时间,显示当前农历,
- SMTP协议首先了解SMTP(简单邮件传输协议),邮件传送代理程序使用SMTP协议来发送电邮到接收者的邮件服务器。SMTP协议只能用来发送邮
- 因此,在我接触那么多种语言当中,asp是最不严格的一种,是对程序员要求最低的一种。 昨天测试了asp.net、php和asp的运行速度比较,
- 本文实例讲述了php实现mysql事务处理的方法。分享给大家供大家参考。具体分析如下:要实现本功能的条件是环境 mysql 5.2 /php
- 这篇文章主要介绍了Python如何使用Gitlab API实现批量的合并分支,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的
- 使用Django中遇到这样一个需求,对一个表的几个字段做 联合唯一索引,例如学生表中 姓名和班级 2个字段在一起表示一个唯一记录。Djang
- 从开始认识CSS(DW4)那时起,我就知道了CSS的强大,但从未用CSS排版过,因为我曾经尝试过学习,但感觉太难了而且用DW的表格,所见及所
- 首先,大家先去下载一份dvbbs.php beta1的代码,解压后先抛开php代码,找出你的mysql手册,如果没有手册那么就直接看下面的实
- PDO::getAvailableDriversPDO::getAvailableDrivers — 返回一个可用驱动的数组(PHP 5 &
- 在实际处理数据时,因系统内存有限,我们不可能一次把所有数据都导出进行操作,所以需要批量导出依次操作。为了加快运行,我们会采用多线程的方法进行
- 1、唠唠叨叨最近项目中需要Python的打包,看到网上也没有很详细的资料,于是做了一些示例程序。小小的研究了一下,Python如何在Wind
- 作者: Alan Pearce原文: Multi-Column Layouts Climb Out of the Box地址: http:/
- 【摘 要】 我只是提供我几个我认为有助于提高写高性能的asp.net应用程序的技巧,本文提到的提高asp.net性能的技巧只是一个起步,更多
- 数据库,网站运营的基础,网站生存的要素,不管是个人用户还是企业用户都非常依赖网站数据库的支持,然而很多别有用心的攻击者也同样非常&l
- 本文实例讲述了php基于协程实现异步的方法。分享给大家供大家参考,具体如下:github上php的协程大部分是根据这篇文章实现的:http: