TensorFlow实现简单线性回归
作者:kylinxjd 发布时间:2023-09-18 13:23:45
标签:TensorFlow,线性回归
本文实例为大家分享了TensorFlow实现简单线性回归的具体代码,供大家参考,具体内容如下
简单的一元线性回归
一元线性回归公式:
其中x是特征:[x1,x2,x3,…,xn,]T
w是权重,b是偏置值
代码实现
导入必须的包
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
# 屏蔽warning以下的日志信息
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
产生模拟数据
def generate_data():
x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
return x, y
x是100行1列的数据,tf.matmul是矩阵相乘,所以权值设置成二维的。
设置的w是1.3, b是1
实现回归
def myregression():
"""
自实现线性回归
:return:
"""
x, y = generate_data()
# 建立模型 y = x * w + b
# w 1x1的二维数据
w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
b = tf.Variable(0.0, name='bias_b')
y_predict = tf.matmul(x, a) + b
# 建立损失函数
loss = tf.reduce_mean(tf.square(y_predict - y))
# 训练
train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss=loss)
# 初始化全局变量
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print('初始的权重:%f偏置值:%f' % (a.eval(), b.eval()))
# 训练优化
for i in range(1, 100):
sess.run(train_op)
print('第%d次优化的权重:%f偏置值:%f' % (i, a.eval(), b.eval()))
# 显示回归效果
show_img(x.eval(), y.eval(), y_predict.eval())
使用matplotlib查看回归效果
def show_img(x, y, y_pre):
plt.scatter(x, y)
plt.plot(x, y_pre)
plt.show()
完整代码
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
def generate_data():
x = tf.constant(np.array([i for i in range(0, 100, 5)]).reshape(-1, 1), tf.float32)
y = tf.add(tf.matmul(x, [[1.3]]) + 1, tf.random_normal([20, 1], stddev=30))
return x, y
def myregression():
"""
自实现线性回归
:return:
"""
x, y = generate_data()
# 建立模型 y = x * w + b
w = tf.Variable(tf.random_normal([1, 1], mean=0.0, stddev=1.0), name='weight_a')
b = tf.Variable(0.0, name='bias_b')
y_predict = tf.matmul(x, w) + b
# 建立损失函数
loss = tf.reduce_mean(tf.square(y_predict - y))
# 训练
train_op = tf.train.GradientDescentOptimizer(0.0001).minimize(loss=loss)
init_op = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init_op)
print('初始的权重:%f偏置值:%f' % (w.eval(), b.eval()))
# 训练优化
for i in range(1, 35000):
sess.run(train_op)
print('第%d次优化的权重:%f偏置值:%f' % (i, w.eval(), b.eval()))
show_img(x.eval(), y.eval(), y_predict.eval())
def show_img(x, y, y_pre):
plt.scatter(x, y)
plt.plot(x, y_pre)
plt.show()
if __name__ == '__main__':
myregression()
看看训练的结果(因为数据是随机产生的,每次的训练结果都会不同,可适当调节梯度下降的学习率和训练步数)
35000次的训练结果
来源:https://blog.csdn.net/kylinxjd/article/details/105557304


猜你喜欢
- 1、安装virtulenv、virtulenvwrapper包pip install virtualenv virtualenvwrappe
- 一、Ajax简介Ajax被认为是(Asynchronous JavaScript and XML)的缩写,允许浏览器与服务器通信而无需刷新当
- require 方法的加载规则优先从缓存中加载核心模块路径形式的模块第三方模块一、优先从缓存中加载main.js:执行加载a.js模块req
- 导读:这篇论坛文章主要介绍了使用SQL Server升级顾问的具体步骤,详细内容请参考下文。微软提供了SQL Server 2008升级顾问
- 本文实例讲述了PHP 对象继承原理与简单用法。分享给大家供大家参考,具体如下:对象继承继承已为大家所熟知的一个程序设计特性,PHP 的对象模
- 第一次写博客,实属心血来潮。为什么要写这篇博客呢?原因如下1、有一次我想配置数据库端口号时,找不到对应的解决方案2、是时候有个地方可以记录一
- 重试指的是当加载出错时,有能力重新发起加载组件的请求。异步组件加载失败后的重试机制,与请求服务端接口失败后的重试机制一样。所以,先来讨论接口
- mysqld_safe脚本执行的基本流程:1、查找basedir和ledir。2、查找datadir和my.cnf。3、对my.cnf做一些
- 昨日,女票拿了一个Excel文档,里面有上万条数据要进行分析,刚开始一个字段分析,Excel用的不错,还能搞定,到后来两个字段的分析,还有区
- 正则表达式的使用想要学习 Python 爬虫 , 首先需要了解一下正则表达式的使用,下面我们就来看看如何使用。. 的使用这个时候的点就相当于
- 前言上一篇博文我们讲到了节流函数的应用场景,我们知道了节流函数可以用在模糊查询、scroller、onresize等场景;今天这篇我们来讲防
- 在用python绘图的时候,经常由于数据的原因导致画出来的图折线分界过于明显,因此需要对原数据绘制的折线进行平滑处理,本文介绍利用插值法进行
- 普通MySQL运行,数据量和访问量不大的话,是足够快的,但是当数据量和访问量剧增的时候,那么就会明显发现MySQL很慢,甚至do
- 面对网络不稳定,页面更新等问题,很可能出现程序异常的问题,所以我们要对程序进行一些异常处理。大家可能觉得处理异常是一个比较麻烦的活,但在面对
- 代码在ext里的src\core\ext.js下 最新的ext3.0beat1的代码如下: ua = navigator.userAgent
- 遇到问题nohup python flush.py &这样运行,生成了nohup.out文件,但是内容始终是空的,试了半天也不行。浪
- python的scipy.stats模块是连续型随机变量的公共方法,可以产生随机数,通常是以正态分布作为scipy.stats的基本使用方法
- 实现1)有相同的数据,直接返回(返回值:0);2)有主键相同,但是数据不同的数据,进行更新处理(返回值:2);3)没有数据,进行插入数据处理
- 一、前期准备CREATE TABLE `t1` ( `id` int(11) NOT NULL AUTO_INCREMENT,
- 项目介绍背景:DC竞赛比赛项目,运用回归模型进 * 价预测。数据介绍:数据主要包括2014年5月至2015年5月美国King County的房