Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
作者:农大鲁迅 发布时间:2021-06-17 20:46:53
标签:python,回归模型,Linear,Regression,lasso,ridge
公共的抽象基类
import numpy as np
from abc import ABCMeta, abstractmethod
class LinearModel(metaclass=ABCMeta):
"""
Abstract base class of Linear Model.
"""
def __init__(self):
# Before fit or predict, please transform samples' mean to 0, var to 1.
self.scaler = StandardScaler()
@abstractmethod
def fit(self, X, y):
"""fit func"""
def predict(self, X):
# before predict, you must run fit func.
if not hasattr(self, 'coef_'):
raise Exception('Please run `fit` before predict')
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
# `x @ y` == `np.dot(x, y)`
return X @ self.coef_
Linear Regression
class LinearRegression(LinearModel):
"""
Linear Regression.
"""
def __init__(self):
super().__init__()
def fit(self, X, y):
"""
:param X_: shape = (n_samples + 1, n_features)
:param y: shape = (n_samples])
:return: self
"""
self.scaler.fit(X)
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
self.coef_ = np.linalg.inv(X.T @ X) @ X.T @ y
return self
Lasso
class Lasso(LinearModel):
"""
Lasso Regression, training by Coordinate Descent.
cost = ||X @ coef_||^2 + alpha * ||coef_||_1
"""
def __init__(self, alpha=1.0, n_iter=1000, e=0.1):
self.alpha = alpha
self.n_iter = n_iter
self.e = e
super().__init__()
def fit(self, X, y):
self.scaler.fit(X)
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
self.coef_ = np.zeros(X.shape[1])
for _ in range(self.n_iter):
z = np.sum(X * X, axis=0)
tmp = np.zeros(X.shape[1])
for k in range(X.shape[1]):
wk = self.coef_[k]
self.coef_[k] = 0
p_k = X[:, k] @ (y - X @ self.coef_)
if p_k < -self.alpha / 2:
w_k = (p_k + self.alpha / 2) / z[k]
elif p_k > self.alpha / 2:
w_k = (p_k - self.alpha / 2) / z[k]
else:
w_k = 0
tmp[k] = w_k
self.coef_[k] = wk
if np.linalg.norm(self.coef_ - tmp) < self.e:
break
self.coef_ = tmp
return self
Ridge
class Ridge(LinearModel):
"""
Ridge Regression.
"""
def __init__(self, alpha=1.0):
self.alpha = alpha
super().__init__()
def fit(self, X, y):
"""
:param X_: shape = (n_samples + 1, n_features)
:param y: shape = (n_samples])
:return: self
"""
self.scaler.fit(X)
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
self.coef_ = np.linalg.inv(
X.T @ X + self.alpha * np.eye(X.shape[1])) @ X.T @ y
return self
测试代码
import matplotlib.pyplot as plt
import numpy as np
def gen_reg_data():
X = np.arange(0, 45, 0.1)
X = X + np.random.random(size=X.shape[0]) * 20
y = 2 * X + np.random.random(size=X.shape[0]) * 20 + 10
return X, y
def test_linear_regression():
clf = LinearRegression()
X, y = gen_reg_data()
clf.fit(X, y)
plt.plot(X, y, '.')
X_axis = np.arange(-5, 75, 0.1)
plt.plot(X_axis, clf.predict(X_axis))
plt.title("Linear Regression")
plt.show()
def test_lasso():
clf = Lasso()
X, y = gen_reg_data()
clf.fit(X, y)
plt.plot(X, y, '.')
X_axis = np.arange(-5, 75, 0.1)
plt.plot(X_axis, clf.predict(X_axis))
plt.title("Lasso")
plt.show()
def test_ridge():
clf = Ridge()
X, y = gen_reg_data()
clf.fit(X, y)
plt.plot(X, y, '.')
X_axis = np.arange(-5, 75, 0.1)
plt.plot(X_axis, clf.predict(X_axis))
plt.title("Ridge")
plt.show()
测试效果
更多机器学习代码,请访问 https://github.com/WiseDoge/plume
来源:https://www.jianshu.com/p/997e0ee1e010


猜你喜欢
- 在asp中调用sql server的存储过程可以加快程序运行速度,本文介绍了asp使用存储过程的方法。1.调用存储过程的一般方法 先假设在s
- 1 集合集合可以使用大括号({})或者set()函数进行创建,但是创建一个空集合必须使用set()函数,而不能用{},大括号是用来创建一个空
- 今天学到了如何使用Python的smtplib库发送邮件,中间也是遇到了各种各样的错误和困难,还好都一一的解决了。下面来谈一谈我的这段经历。
- Hypothesis是Python的一个高级测试库。它允许编写测试用例时参数化,然后生成使测试失败的简单易懂的测试数据。可以用更少的工作在代
- 初试牛刀假设你希望学习Python这门语言,却苦于找不到一个简短而全面的入门教程。那么本教程将花费十分钟的时间带你走入Python的大门。本
- 1. Express简介express是一个基于node.js平台的极简,灵活的web应用开发框架,它提供一系列强大的特征,帮助你创建各种w
- 本文实例讲述了Python数据分析之双色球统计两个红和蓝球哪组合比例高的方法。分享给大家供大家参考,具体如下:统计两个红球和蓝球,哪个组合最
- QUICKSORT(A, p, r)是快速排序的子程序,调用划分程序对数组进行划分,然后递归地调用QUICKSORT(A, p, r),以完
- 项目一开始的设计很重要,django中app的名称建议用小写我的博客由两个app组成,Blog和JiaBlog,总觉得不美观,想改成小写的o
- 1 为什么要分库分表物理服务机的CPU、内存、存储设备、连接数等资源有限,某个时段大量连接同时执行操作,会导致数据库在处理上遇到性能瓶颈。为
- 本文实例讲述了原生JavaScript实现的简单省市县 * 联动功能。分享给大家供大家参考,具体如下: * 联动是我们写表单时必不可少的,比如在
- 1、通过复制数据构造张量1.1 torch.tensor()torch.tensor([[0.1, 1.2], [2.2, 3.1], [4
- ①GET# -*- coding:utf-8 -*-import requestsdef get(url, datas=None): &nb
- 1. 引言当我们设计软件时,我们通常会花费大量精力来编写高质量的代码。但这往往还不够,一个好的软件还应该考虑其整个系统,如测试、部署、网络等
- 本文描述通过统计分析出医院信息系统需分区的表,对需分区的表选择分区键,即找出包括在你的分区键中的列(表的属性),对大型数据的管理比较有意义,
- Anaconda 实际上是一个软件发行版,它附带了conda、Python和150多个科学包及其依赖项。其中,conda是一个开源的软件包管
- pylint是一个不错的代码静态检查工具。将其配置在pycharm中,随时对代码进行分析,确保所有代码都符合pep8规范,以便于养成良好的习
- 列名用了中文的缘故,设置pandas的参数即可,代码如下: import pandas as pd #这两个参数的默认设置都是False p
- 如何实现在下拉菜单里输入文字? 用这个代码试试看,应该可以的:<script>function pp(){se.opt
- 我们知道,session是一种会话技术,用来实现跨脚本共享数据或者检测跟踪用户状态。session的工作原理(1)当一个session第一次