Python 实现3种回归模型(Linear Regression,Lasso,Ridge)的示例
作者:农大鲁迅 发布时间:2021-06-17 20:46:53
标签:python,回归模型,Linear,Regression,lasso,ridge
公共的抽象基类
import numpy as np
from abc import ABCMeta, abstractmethod
class LinearModel(metaclass=ABCMeta):
"""
Abstract base class of Linear Model.
"""
def __init__(self):
# Before fit or predict, please transform samples' mean to 0, var to 1.
self.scaler = StandardScaler()
@abstractmethod
def fit(self, X, y):
"""fit func"""
def predict(self, X):
# before predict, you must run fit func.
if not hasattr(self, 'coef_'):
raise Exception('Please run `fit` before predict')
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
# `x @ y` == `np.dot(x, y)`
return X @ self.coef_
Linear Regression
class LinearRegression(LinearModel):
"""
Linear Regression.
"""
def __init__(self):
super().__init__()
def fit(self, X, y):
"""
:param X_: shape = (n_samples + 1, n_features)
:param y: shape = (n_samples])
:return: self
"""
self.scaler.fit(X)
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
self.coef_ = np.linalg.inv(X.T @ X) @ X.T @ y
return self
Lasso
class Lasso(LinearModel):
"""
Lasso Regression, training by Coordinate Descent.
cost = ||X @ coef_||^2 + alpha * ||coef_||_1
"""
def __init__(self, alpha=1.0, n_iter=1000, e=0.1):
self.alpha = alpha
self.n_iter = n_iter
self.e = e
super().__init__()
def fit(self, X, y):
self.scaler.fit(X)
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
self.coef_ = np.zeros(X.shape[1])
for _ in range(self.n_iter):
z = np.sum(X * X, axis=0)
tmp = np.zeros(X.shape[1])
for k in range(X.shape[1]):
wk = self.coef_[k]
self.coef_[k] = 0
p_k = X[:, k] @ (y - X @ self.coef_)
if p_k < -self.alpha / 2:
w_k = (p_k + self.alpha / 2) / z[k]
elif p_k > self.alpha / 2:
w_k = (p_k - self.alpha / 2) / z[k]
else:
w_k = 0
tmp[k] = w_k
self.coef_[k] = wk
if np.linalg.norm(self.coef_ - tmp) < self.e:
break
self.coef_ = tmp
return self
Ridge
class Ridge(LinearModel):
"""
Ridge Regression.
"""
def __init__(self, alpha=1.0):
self.alpha = alpha
super().__init__()
def fit(self, X, y):
"""
:param X_: shape = (n_samples + 1, n_features)
:param y: shape = (n_samples])
:return: self
"""
self.scaler.fit(X)
X = self.scaler.transform(X)
X = np.c_[np.ones(X.shape[0]), X]
self.coef_ = np.linalg.inv(
X.T @ X + self.alpha * np.eye(X.shape[1])) @ X.T @ y
return self
测试代码
import matplotlib.pyplot as plt
import numpy as np
def gen_reg_data():
X = np.arange(0, 45, 0.1)
X = X + np.random.random(size=X.shape[0]) * 20
y = 2 * X + np.random.random(size=X.shape[0]) * 20 + 10
return X, y
def test_linear_regression():
clf = LinearRegression()
X, y = gen_reg_data()
clf.fit(X, y)
plt.plot(X, y, '.')
X_axis = np.arange(-5, 75, 0.1)
plt.plot(X_axis, clf.predict(X_axis))
plt.title("Linear Regression")
plt.show()
def test_lasso():
clf = Lasso()
X, y = gen_reg_data()
clf.fit(X, y)
plt.plot(X, y, '.')
X_axis = np.arange(-5, 75, 0.1)
plt.plot(X_axis, clf.predict(X_axis))
plt.title("Lasso")
plt.show()
def test_ridge():
clf = Ridge()
X, y = gen_reg_data()
clf.fit(X, y)
plt.plot(X, y, '.')
X_axis = np.arange(-5, 75, 0.1)
plt.plot(X_axis, clf.predict(X_axis))
plt.title("Ridge")
plt.show()
测试效果
更多机器学习代码,请访问 https://github.com/WiseDoge/plume
来源:https://www.jianshu.com/p/997e0ee1e010
0
投稿
猜你喜欢
- TNS简要介绍与应用 Oracle中TNS的完整定义:transparence Network Substrate透明网络底层,监听服务是它
- Windowns操作系统中安装Python,供大家参考,具体内容如下一.下载Python1.python 官网 下载安装包2.选择
- 当需要存储很多同类型的不通过数据时可能需要使用到嵌套,先用一个例子说明嵌套的使用1、在列表中存储字典#假设年级里有一群国际化的学生,有黄皮肤
- 简介scrapy 是一个 python 下面功能丰富、使用快捷方便的爬虫框架。用 scrapy 可以快速的开发一个简单的爬虫,官方给出的一个
- 下面列出了asp远程网页数据采集程序中经常用到的函数,很实用,特别是正则表达式过滤函数。包括了使用xmlhttp采集远程网页内容,使用ado
- 闭包内容:匿名函数:能够完成简单的功能,传递这个函数的引用,只有功能普通函数:能够完成复杂的功能,传递这个函数的引用,只有功能闭包:能够完成
- 1.效果2.环境1.pytorch2.visdom3.python3.53.用到的代码# coding:utf8import torchfr
- 一、前言:Thrift 是一种接口描述语言和二进制通信协议。以前也没接触过,最近有个项目需要建立自动化测试,这个项目之间的微服务都是通过 T
- 在使用Dreamweaver制作主页的时候往往需要改变表格的高度。然而有时当我们拖动表格的边框,无论怎样拖动,等到放下鼠标,表格却又恢复到原
- Ajax在网上已经叫喊了好几年了, 但是还是有很多像我这样的新手没掌握它, 像这样能改善交互体验的技术不会用真是很遗憾呢. 所以我就把我学到
- 一、爬取数据话不多说了,直接上代码( copy即可用 )import requestsimport pandas as pdclass Sp
- 利用numpy库(缺点:有缺失值就无法读取)读:import numpy my_matrix = numpy.loadtxt(open(&q
- 简介使用百度深度学习框架paddlepaddle对人像图片进行自动化抠图安装根据PaddlePaddle官网命令安装如pip install
- 本文实例讲述了python自动化之Ansible的安装。分享给大家供大家参考,具体如下:一 点睛Ansible只需在管理端部署环境即可,建议
- 做 Web 开发少不了要与模板引擎打交道。我陆续也接触了 Python 的不少模板引擎,感觉可以总结一下了。一、首先按照我的熟悉程度列一下:
- 本文实例讲述了PHP采集静态页面并把页面css,img,js保存的方法。分享给大家供大家参考。具体分析如下:这是一个可以获取网页的html代
- 突然发现自己对Web前端技术掌握得很少很少,就是自己最感兴趣的XHTML+CSS部分知道也不算多。在XHTML 1.1规定的诸多元素中,我平
- python的pickle模块实现了基本的数据序列和反序列化。通过pickle模块的序列化操作我们能够将程序中运行的对象信息保存到文件中去,
- 一、Selects检索表中的所有行$users = DB::table('users')->get();foreach
- filetype.pySmall and dependency free Python package to infer file type