用tensorflow实现弹性网络回归算法
作者:xckkcxxck 发布时间:2023-07-21 16:52:18
标签:tensorflow,回归算法
本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下
python代码:
#用tensorflow实现弹性网络算法(多变量)
#使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征。
#1 导入必要的编程库,创建计算图,加载数据集
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn import datasets
from tensorflow.python.framework import ops
ops.get_default_graph()
sess = tf.Session()
iris = datasets.load_iris()
x_vals = np.array([[x[1], x[2], x[3]] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])
#2 声明学习率,批量大小,占位符和模型变量,模型输出
learning_rate = 0.001
batch_size = 50
x_data = tf.placeholder(shape=[None, 3], dtype=tf.float32) #占位符大小为3
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
A = tf.Variable(tf.random_normal(shape=[3,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))
model_output = tf.add(tf.matmul(x_data, A), b)
#3 对于弹性网络回归算法,损失函数包括L1正则和L2正则
elastic_param1 = tf.constant(1.)
elastic_param2 = tf.constant(1.)
l1_a_loss = tf.reduce_mean(abs(A))
l2_a_loss = tf.reduce_mean(tf.square(A))
e1_term = tf.multiply(elastic_param1, l1_a_loss)
e2_term = tf.multiply(elastic_param2, l2_a_loss)
loss = tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), e1_term), e2_term), 0)
#4 初始化变量, 声明优化器, 然后遍历迭代运行, 训练拟合得到参数
init = tf.global_variables_initializer()
sess.run(init)
my_opt = tf.train.GradientDescentOptimizer(learning_rate)
train_step = my_opt.minimize(loss)
loss_vec = []
for i in range(1000):
rand_index = np.random.choice(len(x_vals), size=batch_size)
rand_x = x_vals[rand_index]
rand_y = np.transpose([y_vals[rand_index]])
sess.run(train_step, feed_dict={x_data:rand_x, y_target:rand_y})
temp_loss = sess.run(loss, feed_dict={x_data:rand_x, y_target:rand_y})
loss_vec.append(temp_loss)
if (i+1)%250 == 0:
print('Step#' + str(i+1) +'A = ' + str(sess.run(A)) + 'b=' + str(sess.run(b)))
print('Loss= ' +str(temp_loss))
#现在能观察到, 随着训练迭代后损失函数已收敛。
plt.plot(loss_vec, 'k--')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()
本文参考书《Tensorflow机器学习实战指南》
来源:http://blog.csdn.net/xckkcxxck/article/details/78992345
0
投稿
猜你喜欢
- 0.前言Telnet协议属于TCP/IP协议族里的一种,对于我们这些网络攻城狮来说,再熟悉不过了,常用于远程登陆到网络设备进行操作,但是,它
- 原文地址:30 Days of Mootools 1.2 Tutorials - Day 10 - Using FX.TweenMooToo
- Python开发最牛逼的IDE——pycharm(其实其它的工具,例如eclipse也可以写,只不过比较麻烦,需要安装很多的插件,所以说py
- 前言 pycharm默认是没有为我们设置模板信息的,但为了更加方便的实现代码管理,以及能够一目
- Django中上传文件方式。如何实现文件上传功能?1创建项目uploadfile:创建app:front项目设置INSTALLED_APPS
- 1、pyecharts介绍 Echarts是一款由百度公司开发的开源数据可视化JS库,pyecharts是一款使用python调用echar
- 在Web标准中的页面布局是使用Div配合CSS来实现的。这其中最常用到的就是使整个页面水平居中的效果,这是在页面布局中基本,也是最应该首先掌
- 本文实例为大家分享了python实现网上购物系统的具体代码,供大家参考,具体内容如下1.购物商城的需求分析:1、输出欢迎界面还有登录注册菜单
- 步骤——1:定位在通过与客户,或与和客户接触的业务人员交流,做出一个准确的定位.定位的准确与否,虽然不能决定一定通过,但如果定位不准或相差太
- 1.requests库简介requests 是 Python 中比较常用的网页请求库,主要用来发送 HTTP 请求,在使用爬虫或测试服务器响
- 今天用numpy 的linalg.det()求矩阵的逆的过程中出现了一个错误:TypeError: No loop matching the
- 开发时,通常打开Debug模式会快速定位开发时的一些问题。项目开始部署时,关闭Debug模式,url.py路由静态文件和图片写法:# url
- Django中提供了一个关于URL的映射的解决方案,1.客户端的浏览器发起一个url请求,Django根据URL解析,把url中的参数捕获,
- JSON Schema是一个用于验证JSON数据结构的强大工具, 我查看并学习了JSON Schema的官方文档, 做了详细的记录, 分享一
- 前言发现自己学习python 的各种库老是容易忘记,所有想利用这个平台,记录和分享一下学习时候的知识点,以后也能及时的复习,最近学习pand
- PyCharm安装配置Qt Designer+PyUIC教程1、安装依赖命令形式pip install PyQt5pip install p
- 一、高斯滤波 高斯滤波是一种线性平滑滤波,适用于消除高斯噪声,广泛应用于图像处理的减噪过程。 [1] 通俗的讲,高斯滤波就是对整幅图像进
- 前言大家好,我是小张~记得小时候,家里只有一个钟表用来看时间(含有时针、分针、秒针的那种),挂在墙上哒哒哒响个不停,现在生活条件好了、基本人
- 本文实例讲述了Python 类的私有属性和私有方法。分享给大家供大家参考,具体如下:xx:公有变量_xx:公有变量或方法,不能通过impor
- 门限回归模型(Threshold Regressive Model,简称TR模型或TRM)的基本思想是通过门限变量的控制作用,当给出预报因子