python 还原梯度下降算法实现一维线性回归
作者:Mchael菜鸟 发布时间:2023-10-09 21:53:42
标签:python,一维,线性回归
首先我们看公式:
这个是要拟合的函数
然后我们求出它的损失函数, 注意:这里的n和m均为数据集的长度,写的时候忘了
注意,前面的theta0-theta1x是实际值,后面的y是期望值
接着我们求出损失函数的偏导数:
最终,梯度下降的算法:
学习率一般小于1,当损失函数是0时,我们输出theta0和theta1.
接下来上代码!
class LinearRegression():
def __init__(self, data, theta0, theta1, learning_rate):
self.data = data
self.theta0 = theta0
self.theta1 = theta1
self.learning_rate = learning_rate
self.length = len(data)
# hypothesis
def h_theta(self, x):
return self.theta0 + self.theta1 * x
# cost function
def J(self):
temp = 0
for i in range(self.length):
temp += pow(self.h_theta(self.data[i][0]) - self.data[i][1], 2)
return 1 / (2 * self.m) * temp
# partial derivative
def pd_theta0_J(self):
temp = 0
for i in range(self.length):
temp += self.h_theta(self.data[i][0]) - self.data[i][1]
return 1 / self.m * temp
def pd_theta1_J(self):
temp = 0
for i in range(self.length):
temp += (self.h_theta(data[i][0]) - self.data[i][1]) * self.data[i][0]
return 1 / self.m * temp
# gradient descent
def gd(self):
min_cost = 0.00001
round = 1
max_round = 10000
while min_cost < abs(self.J()) and round <= max_round:
self.theta0 = self.theta0 - self.learning_rate * self.pd_theta0_J()
self.theta1 = self.theta1 - self.learning_rate * self.pd_theta1_J()
print('round', round, ':\t theta0=%.16f' % self.theta0, '\t theta1=%.16f' % self.theta1)
round += 1
return self.theta0, self.theta1
def main():
data = [[1, 2], [2, 5], [4, 8], [5, 9], [8, 15]] # 这里换成你想拟合的数[x, y]
# plot scatter
x = []
y = []
for i in range(len(data)):
x.append(data[i][0])
y.append(data[i][1])
plt.scatter(x, y)
# gradient descent
linear_regression = LinearRegression(data, theta0, theta1, learning_rate)
theta0, theta1 = linear_regression.gd()
# plot returned linear
x = np.arange(0, 10, 0.01)
y = theta0 + theta1 * x
plt.plot(x, y)
plt.show()
来源:https://blog.csdn.net/weixin_46490003/article/details/109184418
0
投稿
猜你喜欢
- default-character-set=gbk #或gb2312,big5,utf8 然后重新启动mysql 运行->servic
- 这里要注意的是js的时间戳是13位,php的时间戳是10位,转换函数如下: var nowtime = (new Date).getTime
- 微软今天宣布正式发布SQL Server 2008服务器软件,这将帮助微软与Oracle 11g,IBM DB2 9.5数据库产品对抗.此前
- 最近开始在项目中使用Quickwork For Asp,虽然该框架是自己独立完成的,不过功能没做过详细的总结,所以很多参数总是会弄错,毕竟鱼
- 先想创意,再画草图,接着鼠绘,最后做成flas * 。这是我的习惯流程。 这是想到中秋时,我第一时间内能浮想出的图像:大意是嫦娥奔
- 简单计数器代码如下所示:<% Set fs = CreateObject("Scri
- 1.JOIN和UNION区别 join 是两张表做交连后里面条件相同的部分记录产生一个记录集, union是产生的两个记录集(字段要一样的)
- 2010新的架构工具可以让我们了解应用程序和功能设计,并帮助验证设计和执行不偏离。它除了支持一般系统分析设计流程(需求→实体)外,也支持另一
- 一、object类的源码class object: """ The most bas
- 目前,我们要在网页中使用圆角效果,总是通过切图然后嵌套很多div,用背景来实现圆角效果。对于前端开发工程师来说,圆角的确是一个让人又爱又恨的
- 前言ThinkPHP,是为了简化企业级应用开发和敏捷WEB应用开发而诞生的开源轻量级PHP框架。随着框架代码量的增加,一些潜在的威胁也逐渐暴
- 在一群里有朋友发问,有时间,也就看看了,不多说了,看图了:用一般的 select .... order 排序出来,就如下图了,是
- 1.DNS查询过程:以查询 www.baidu.com为例(1)电脑向本地域名服务器发送解析www.baidu.com的请求(2)本地域名服
- PHP mysqli_thread_id() 函数返回当前连接的线程 ID,然后杀死连接:<?php// 假定数据库用户名:root,
- php中主要用到的就是要用到fread()和fwirte()。而静态页面生成了之后,就会牵扯到修改的问题。这里可以用到正则匹配的方法来替换模
- 现代浏览器可以基于RFC 2397标准使用base64把图片进行编码,然后输出类似data:image/png;base64,iVBORw0
- 阅读上一章:Chapter 14 图片替换Chapter 15 为<body>指定样式把内容与显示效果分开设定的好处之一就是灵活
- think-queue是ThinkPHP官方提供的一个消息队列服务,是专门支持队列服务的扩展包。think-queue消息队列适用于大并发或
- JS代码:function showFlash(src,w,h){ html&nbs
- 上传问题可以说是网络编程中经常遇到的,也是一个很重要的问题,我们不仅要实现上传文件,图片等基本功能,还有考虑到上传程序的安全性,本文介绍了一