python 还原梯度下降算法实现一维线性回归
作者:Mchael菜鸟 发布时间:2023-10-09 21:53:42
标签:python,一维,线性回归
首先我们看公式:
这个是要拟合的函数
然后我们求出它的损失函数, 注意:这里的n和m均为数据集的长度,写的时候忘了
注意,前面的theta0-theta1x是实际值,后面的y是期望值
接着我们求出损失函数的偏导数:
最终,梯度下降的算法:
学习率一般小于1,当损失函数是0时,我们输出theta0和theta1.
接下来上代码!
class LinearRegression():
def __init__(self, data, theta0, theta1, learning_rate):
self.data = data
self.theta0 = theta0
self.theta1 = theta1
self.learning_rate = learning_rate
self.length = len(data)
# hypothesis
def h_theta(self, x):
return self.theta0 + self.theta1 * x
# cost function
def J(self):
temp = 0
for i in range(self.length):
temp += pow(self.h_theta(self.data[i][0]) - self.data[i][1], 2)
return 1 / (2 * self.m) * temp
# partial derivative
def pd_theta0_J(self):
temp = 0
for i in range(self.length):
temp += self.h_theta(self.data[i][0]) - self.data[i][1]
return 1 / self.m * temp
def pd_theta1_J(self):
temp = 0
for i in range(self.length):
temp += (self.h_theta(data[i][0]) - self.data[i][1]) * self.data[i][0]
return 1 / self.m * temp
# gradient descent
def gd(self):
min_cost = 0.00001
round = 1
max_round = 10000
while min_cost < abs(self.J()) and round <= max_round:
self.theta0 = self.theta0 - self.learning_rate * self.pd_theta0_J()
self.theta1 = self.theta1 - self.learning_rate * self.pd_theta1_J()
print('round', round, ':\t theta0=%.16f' % self.theta0, '\t theta1=%.16f' % self.theta1)
round += 1
return self.theta0, self.theta1
def main():
data = [[1, 2], [2, 5], [4, 8], [5, 9], [8, 15]] # 这里换成你想拟合的数[x, y]
# plot scatter
x = []
y = []
for i in range(len(data)):
x.append(data[i][0])
y.append(data[i][1])
plt.scatter(x, y)
# gradient descent
linear_regression = LinearRegression(data, theta0, theta1, learning_rate)
theta0, theta1 = linear_regression.gd()
# plot returned linear
x = np.arange(0, 10, 0.01)
y = theta0 + theta1 * x
plt.plot(x, y)
plt.show()
来源:https://blog.csdn.net/weixin_46490003/article/details/109184418


猜你喜欢
- 主要记录两种不同的beam search版本版本一使用类似层次遍历的方式进行搜索,用队列进行维护,每次循环对当前层的所有节点进行搜索,这些节
- 前言相信大家在日常的web开发中,作为前端经常会遇到处理图片拉伸问题的情况。例如banner、图文列表、头像等所有和用户或客户自主操作图片上
- 上下班打卡是程序员最讨厌的东西,更讨厌的是设置了连上指定wifi打卡。手机上有一些定时机器人之类的app,经过实际测试,全军覆没,没一个可以
- 装饰器这东西我看了一会儿才明白,在函数外面套了一层函数,感觉和java里的aop功能很像;写了2个装饰器日志的例子,第一个是不带参数的装饰器
- 前言: 上一篇讲了Python排序问题中比较经典的三个方法,(链接:关于Python排
- 本文实例讲述了wxPython框架类和面板类的使用方法,分享给大家供大家参考。具体分析如下:实现代码如下:import wx c
- 背景:pytest以特定规则搜索测试用例,所以测试用例文件、测试类以及类中的方法、测试函数这些命名都必须符合规则,才能被pytest搜索到并
- var header1 = document.getElementById("header"); var p = doc
- 某些时候我们需要让类动态的添加属性或方法,比如我们在做插件时就可以采用这种方法。用一个配置文件指定需要加载的模块,可以根据业务扩展任意加入需
- 现在的域名提供已经取消免费的url转发功能,而且我们一般主要用的是带www的域名,以前不带www的域名一般是做url转发跳转到带www的域名
- 无限循环如果条件判断语句永远为 true,循环将会无限的执行下去,如下实例:#!/usr/bin/python# -*- coding: U
- 一、mysqldump 简介mysqldump 是 MySQL 自带的逻辑备份工具。MySQLdump是一个数据库逻辑备份程序,
- 在日常生活中,随机数对于我们而言并不陌生,例如手机短信验证码就是一个随机的数字字符串;对于统计分析、机器学习等领域而言,通常也需要生成大量的
- 在分析python的参数传递是如何进行的之前,我们需要先来了解一下,python变量和赋值的基本原理,这样有助于我们更好的理解参数传递。py
- 本月的每月挑战会主题是NLP,我们会在本文帮你开启一种可能:使用pandas和python的自然语言工具包分析你Gmail邮箱中的内容。NL
- glob模块实例详解glob的应用场景是要寻找一系列(符合特定规则)文件名。glob模块是最简单的模块之一,内容非常少。用它可以查找符合特定
- 在用plt.imshow和cv2.imshow显示同一幅图时可能会出现颜色差别很大的现象。这是因为:opencv的接口使用BGR,而matp
- 本文实例讲述了Python实现复杂对象转JSON的方法。分享给大家供大家参考,具体如下:在Python对于简单的对象转json还是比较简单的
- <!doctype html><html><head><meta charset="ut
- 本文实例讲述了Python3.5内置模块之shelve模块、xml模块、configparser模块、hashlib、hmac模块用法。分享