python梯度下降算法的实现
作者:epleone 发布时间:2022-01-25 11:11:09
标签:python,梯度下降
本文实例为大家分享了python实现梯度下降算法的具体代码,供大家参考,具体内容如下
简介
本文使用python实现了梯度下降算法,支持y = Wx+b的线性回归
目前支持批量梯度算法和随机梯度下降算法(bs=1)
也支持输入特征向量的x维度小于3的图像可视化
代码要求python版本>3.4
代码
'''
梯度下降算法
Batch Gradient Descent
Stochastic Gradient Descent SGD
'''
__author__ = 'epleone'
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys
# 使用随机数种子, 让每次的随机数生成相同,方便调试
# np.random.seed(111111111)
class GradientDescent(object):
eps = 1.0e-8
max_iter = 1000000 # 暂时不需要
dim = 1
func_args = [2.1, 2.7] # [w_0, .., w_dim, b]
def __init__(self, func_arg=None, N=1000):
self.data_num = N
if func_arg is not None:
self.FuncArgs = func_arg
self._getData()
def _getData(self):
x = 20 * (np.random.rand(self.data_num, self.dim) - 0.5)
b_1 = np.ones((self.data_num, 1), dtype=np.float)
# x = np.concatenate((x, b_1), axis=1)
self.x = np.concatenate((x, b_1), axis=1)
def func(self, x):
# noise太大的话, 梯度下降法失去作用
noise = 0.01 * np.random.randn(self.data_num) + 0
w = np.array(self.func_args)
# y1 = w * self.x[0, ] # 直接相乘
y = np.dot(self.x, w) # 矩阵乘法
y += noise
return y
@property
def FuncArgs(self):
return self.func_args
@FuncArgs.setter
def FuncArgs(self, args):
if not isinstance(args, list):
raise Exception(
'args is not list, it should be like [w_0, ..., w_dim, b]')
if len(args) == 0:
raise Exception('args is empty list!!')
if len(args) == 1:
args.append(0.0)
self.func_args = args
self.dim = len(args) - 1
self._getData()
@property
def EPS(self):
return self.eps
@EPS.setter
def EPS(self, value):
if not isinstance(value, float) and not isinstance(value, int):
raise Exception("The type of eps should be an float number")
self.eps = value
def plotFunc(self):
# 一维画图
if self.dim == 1:
# x = np.sort(self.x, axis=0)
x = self.x
y = self.func(x)
fig, ax = plt.subplots()
ax.plot(x, y, 'o')
ax.set(xlabel='x ', ylabel='y', title='Loss Curve')
ax.grid()
plt.show()
# 二维画图
if self.dim == 2:
# x = np.sort(self.x, axis=0)
x = self.x
y = self.func(x)
xs = x[:, 0]
ys = x[:, 1]
zs = y
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.scatter(xs, ys, zs, c='r', marker='o')
ax.set_xlabel('X Label')
ax.set_ylabel('Y Label')
ax.set_zlabel('Z Label')
plt.show()
else:
# plt.axis('off')
plt.text(
0.5,
0.5,
"The dimension(x.dim > 2) \n is too high to draw",
size=17,
rotation=0.,
ha="center",
va="center",
bbox=dict(
boxstyle="round",
ec=(1., 0.5, 0.5),
fc=(1., 0.8, 0.8), ))
plt.draw()
plt.show()
# print('The dimension(x.dim > 2) is too high to draw')
# 梯度下降法只能求解凸函数
def _gradient_descent(self, bs, lr, epoch):
x = self.x
# shuffle数据集没有必要
# np.random.shuffle(x)
y = self.func(x)
w = np.ones((self.dim + 1, 1), dtype=float)
for e in range(epoch):
print('epoch:' + str(e), end=',')
# 批量梯度下降,bs为1时 等价单样本梯度下降
for i in range(0, self.data_num, bs):
y_ = np.dot(x[i:i + bs], w)
loss = y_ - y[i:i + bs].reshape(-1, 1)
d = loss * x[i:i + bs]
d = d.sum(axis=0) / bs
d = lr * d
d.shape = (-1, 1)
w = w - d
y_ = np.dot(self.x, w)
loss_ = abs((y_ - y).sum())
print('\tLoss = ' + str(loss_))
print('拟合的结果为:', end=',')
print(sum(w.tolist(), []))
print()
if loss_ < self.eps:
print('The Gradient Descent algorithm has converged!!\n')
break
pass
def __call__(self, bs=1, lr=0.1, epoch=10):
if sys.version_info < (3, 4):
raise RuntimeError('At least Python 3.4 is required')
if not isinstance(bs, int) or not isinstance(epoch, int):
raise Exception(
"The type of BatchSize/Epoch should be an integer number")
self._gradient_descent(bs, lr, epoch)
pass
pass
if __name__ == "__main__":
if sys.version_info < (3, 4):
raise RuntimeError('At least Python 3.4 is required')
gd = GradientDescent([1.2, 1.4, 2.1, 4.5, 2.1])
# gd = GradientDescent([1.2, 1.4, 2.1])
print("要拟合的参数结果是: ")
print(gd.FuncArgs)
print("===================\n\n")
# gd.EPS = 0.0
gd.plotFunc()
gd(10, 0.01)
print("Finished!")
来源:https://blog.csdn.net/epleone/article/details/78843595


猜你喜欢
- python如何为创建大量实例节省内存,具体内容如下案例:某网络游戏中,定义了玩家类Player(id, name, status,....
- 本文实例讲述了mysql索引基数概念与用法。分享给大家供大家参考,具体如下:Cardinality(索引基数)是mysql索引很重要的一个概
- php随机数生成一个给定范围的随机数,用 PHP 就太简单不过了,而且可以指定从负数到正整数的范围,如:<?phpecho mt_ra
- 在mac下载安装prometheus在https://prometheus.io/download/下载prometheus放到自定义的位置
- 前言这篇文章将为大家介绍:GoFrame 错误处理的常用方法&错误码的使用。如何自定义错误对象、如何忽略部分堆栈信息、如何自定义错误
- 安装去http://www.mysql.com/downloads/, 选择最下方的MySQL Community Edition,点击My
- 你知道SQL Server这么庞大的企业级数据库服务器产品是如何build出来的吗?这有些相关的数据:每个build 的大小在300GB左右
- 一、前言越来越多的网站和App开始为用户搭建签到系统,以此来吸引和留住用户。签到系统是一种轻量、互动性强的营销方式,通过用户签到获取免费权益
- javascript Date.getUTCDay()方法按照通用时间在指定日期返回星期几。通过getUTCDay返回的值是对应
- 写作思路1、简述实现原理2、部分代码解析3、位置同步解析(①上下两屏位置同步②编辑屏位置保持不变)效果图如下:版本1:这就是我们常见的预览窗
- 如下 <!DOCTYPE HTML> <html> <head> <meta charset=&q
- 1.自定义线程池import threadingimport Queueimport timequeue = Queue.Queue()de
- SQLAlchemy是Python编程语言下的一款开源软件,提供了SQL工具包及对象关系映射(ORM)工具,使用MIT许可证发行。SQLAl
- 这段时间服务器崩溃2次,一直没有找到原因,今天看到论坛发出的错误信息邮件,想起可能是mysql的默认连接数引起的问题,一查果然,老天,默认
- 模版基本介绍模板是一个文本,用于分离文档的表现形式和内容。 模板定义了占位符以及各种用于规范文档该如何显示的各部分基本逻辑(模板标签)。 模
- 前言嗨,彦祖们,不会过圣诞了还是一个人吧?今天我们来讲一下如何用python来画一个圣诞树,学会就快给那个她发过去吧,我的朋友圈已经让圣诞树
- 无意间碰到的一个大神整理的Python学习思维导图,感觉对初学者理清学习思路大有裨益,非常感谢他的分享。14 张思维导图基础知识数据类型序列
- 0. 前言周日在爬一个国外网站的时候,发现用协程并发请求,并且请求次数太快的时候,会出现对方把我的服务器IP封掉的情况。于是网上找了一下开源
- 问题描述vscode中跨目录的模块调用远不如pycharm中的来的简单,在pycharm中即使是不同库文件夹中子函数也可以进行互相调用。而在
- JavaScript lastIndexOf 方法lastIndexOf 方法用于计算指定的字符串在整个字符串中最后一次出现的位置,并返回该