用tensorflow实现弹性网络回归算法
作者:xckkcxxck 发布时间:2023-07-21 16:52:18
标签:tensorflow,回归算法
本文实例为大家分享了tensorflow实现弹性网络回归算法,供大家参考,具体内容如下
python代码:
#用tensorflow实现弹性网络算法(多变量)
#使用鸢尾花数据集,后三个特征作为特征,用来预测第一个特征。
#1 导入必要的编程库,创建计算图,加载数据集
import matplotlib.pyplot as plt
import tensorflow as tf
import numpy as np
from sklearn import datasets
from tensorflow.python.framework import ops
ops.get_default_graph()
sess = tf.Session()
iris = datasets.load_iris()
x_vals = np.array([[x[1], x[2], x[3]] for x in iris.data])
y_vals = np.array([y[0] for y in iris.data])
#2 声明学习率,批量大小,占位符和模型变量,模型输出
learning_rate = 0.001
batch_size = 50
x_data = tf.placeholder(shape=[None, 3], dtype=tf.float32) #占位符大小为3
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
A = tf.Variable(tf.random_normal(shape=[3,1]))
b = tf.Variable(tf.random_normal(shape=[1,1]))
model_output = tf.add(tf.matmul(x_data, A), b)
#3 对于弹性网络回归算法,损失函数包括L1正则和L2正则
elastic_param1 = tf.constant(1.)
elastic_param2 = tf.constant(1.)
l1_a_loss = tf.reduce_mean(abs(A))
l2_a_loss = tf.reduce_mean(tf.square(A))
e1_term = tf.multiply(elastic_param1, l1_a_loss)
e2_term = tf.multiply(elastic_param2, l2_a_loss)
loss = tf.expand_dims(tf.add(tf.add(tf.reduce_mean(tf.square(y_target - model_output)), e1_term), e2_term), 0)
#4 初始化变量, 声明优化器, 然后遍历迭代运行, 训练拟合得到参数
init = tf.global_variables_initializer()
sess.run(init)
my_opt = tf.train.GradientDescentOptimizer(learning_rate)
train_step = my_opt.minimize(loss)
loss_vec = []
for i in range(1000):
rand_index = np.random.choice(len(x_vals), size=batch_size)
rand_x = x_vals[rand_index]
rand_y = np.transpose([y_vals[rand_index]])
sess.run(train_step, feed_dict={x_data:rand_x, y_target:rand_y})
temp_loss = sess.run(loss, feed_dict={x_data:rand_x, y_target:rand_y})
loss_vec.append(temp_loss)
if (i+1)%250 == 0:
print('Step#' + str(i+1) +'A = ' + str(sess.run(A)) + 'b=' + str(sess.run(b)))
print('Loss= ' +str(temp_loss))
#现在能观察到, 随着训练迭代后损失函数已收敛。
plt.plot(loss_vec, 'k--')
plt.title('Loss per Generation')
plt.xlabel('Generation')
plt.ylabel('Loss')
plt.show()
本文参考书《Tensorflow机器学习实战指南》
来源:http://blog.csdn.net/xckkcxxck/article/details/78992345


猜你喜欢
- 1、将一个字典输入:该字典必须满足:value是一个list类型的元素,且每一个key对应的value长度都相同:(以该字典的key为col
- 1.如何将Query String传送到另一个ASP文件去?Response.Redirect("second.asp? 
- <?php # 设置 $domain 为你的域名 (注意没有www) $domain = "aspxhome.com&quo
- 关于python的ssh库操作需要引入一个远程控制的模块——paramiko,可用于对远程服务器进行
- vue更新到2.0之后,作者就宣告不再对vue-resource更新,而是推荐使用axios。前段时间第一次在项目里用到vue,关于登陆问题
- 前言数学建模的介绍与作用全国大学生数学建模竞赛:全国大学生数学建模竞赛创办于1992年,每年一届,已成为全国高校规模最大的基础性学科竞赛,也
- 话不多说 直接上代码<el-upload :action="actionUrl&q
- 说明如果你的项目流量非常小,完全不用担心有并发的购买请求,那么做这样一个系统意义不大。但如果你的系统要像12306那样,接受高并发访问和下单
- 一组常用的弹出窗口用法,以下代码集合常用的弹出窗口用法。1、最基本的弹出窗口代码<SCRIPT LANGUAGE="
- 在SQL Server中Count(*)或者Count(1)或者Count([列])或许是最常用的聚合函数。很多人其实对这三者之间是区分不清
- 这篇文章主要介绍了python打包成so文件过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友
- 小渣渣复现大佬project发现GPU跑不动,出现如下报错:RuntimeError: CUDA out of memory.看下来最简单粗
- 引言通过前面的文章我们已经了解到OpenCV 是一个用于计算机视觉和机器学习的开源 python 库。它主要针对实时计算机视觉和图像处理。它
- 前言今晚就是新年夜啦,为了 刷一波存在感 送出我的祝福,同时让它看起来不像群发消息,我们简单地用三步来实现定制QQ祝福~
- 示例代码,用到了函数substr与iconv_substr,mb_substr<html><head><met
- 我们平常在网页上显示的字体最小一般是12PX,当小于10PX时,显示的效果就大打折扣了,因为中文默认的字体是宋体,当小于12PX时的效果如下
- 您是否记得关闭所有的XHTML元素,在HTML中一些元素没有必要被关闭。当下一个元素开始的时候,上一个元素就自动被关闭。XHTML中是不允许
- 工欲善其事必先利其器,Pycharm 是最受欢迎的Python开发工具,它提供的功能非常强大,是构建大型项目的理想工具之一,如果能挖掘出里面
- 本文实例为大家分享了js动态时间显示 的具体代码,供大家参考,具体内容如下<!doctype html><html>
- 上一章节我们学习了基础的定义 PPT 的方法以及每一页中的样式,这节课我们将真正的在 PPT 中添加内容,学习一下 pptx 的段落的使用。