sklearn+python:线性回归案例
作者:yuanlulu 发布时间:2023-10-19 20:07:01
使用一阶线性方程预测波士顿房价
载入的数据是随sklearn一起发布的,来自boston 1993年之前收集的506个房屋的数据和价格。load_boston()用于载入数据。
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import time
from sklearn.linear_model import LinearRegression
boston = load_boston()
X = boston.data
y = boston.target
print("X.shape:{}. y.shape:{}".format(X.shape, y.shape))
print('boston.feature_name:{}'.format(boston.feature_names))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)
model = LinearRegression()
start = time.clock()
model.fit(X_train, y_train)
train_score = model.score(X_train, y_train)
cv_score = model.score(X_test, y_test)
print('time used:{0:.6f}; train_score:{1:.6f}, sv_score:{2:.6f}'.format((time.clock()-start),
train_score, cv_score))
输出内容为:
X.shape:(506, 13). y.shape:(506,)
boston.feature_name:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
'B' 'LSTAT']
time used:0.012403; train_score:0.723941, sv_score:0.794958
可以看到测试集上准确率并不高,应该是欠拟合。
使用多项式做线性回归
上面的例子是欠拟合的,说明模型太简单,无法拟合数据的情况。现在增加模型复杂度,引入多项式。
打个比方,如果原来的特征是[a, b]两个特征,
在degree为2的情况下, 多项式特征变为[1, a, b, a^2, ab, b^2]。degree为其它值的情况依次类推。
多项式特征相当于增加了数据和模型的复杂性,能够更好的拟合。
下面的代码使用Pipeline把多项式特征和线性回归特征连起来,最终测试degree在1、2、3的情况下的得分。
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import time
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
def polynomial_model(degree=1):
polynomial_features = PolynomialFeatures(degree=degree, include_bias=False)
linear_regression = LinearRegression(normalize=True)
pipeline = Pipeline([('polynomial_features', polynomial_features),
('linear_regression', linear_regression)])
return pipeline
boston = load_boston()
X = boston.data
y = boston.target
print("X.shape:{}. y.shape:{}".format(X.shape, y.shape))
print('boston.feature_name:{}'.format(boston.feature_names))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)
for i in range(1,4):
print( 'degree:{}'.format( i ) )
model = polynomial_model(degree=i)
start = time.clock()
model.fit(X_train, y_train)
train_score = model.score(X_train, y_train)
cv_score = model.score(X_test, y_test)
print('time used:{0:.6f}; train_score:{1:.6f}, sv_score:{2:.6f}'.format((time.clock()-start),
train_score, cv_score))
输出结果为:
X.shape:(506, 13). y.shape:(506,)
boston.feature_name:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
'B' 'LSTAT']
degree:1
time used:0.003576; train_score:0.723941, sv_score:0.794958
degree:2
time used:0.030123; train_score:0.930547, sv_score:0.860465
degree:3
time used:0.137346; train_score:1.000000, sv_score:-104.429619
可以看到degree为1和上面不使用多项式是一样的。degree为3在训练集上的得分为1,在测试集上得分是负数,明显过拟合了。
所以最终应该选择degree为2的模型。
二阶多项式比一阶多项式好的多,但是测试集和训练集上的得分仍有不少差距,这可能是数据不够的原因,需要更多的讯据才能进一步提高模型的准确度。
正规方程解法和梯度下降的比较
除了梯度下降法来逼近最优解,也可以使用正规的方程解法直接计算出最终的解来。
根据吴恩达的课程,线性回归最优解为:
theta = (X^T * X)^-1 * X^T * y
其实两种方法各有优缺点:
梯度下降法:
缺点:需要选择学习率,需要多次迭代
优点:特征值很多(1万以上)时仍然能以不错的速度工作
正规方程解法:
优点:不需要设置学习率,不需要多次迭代
缺点:需要计算X的转置和逆,复杂度O3;特征值很多(1万以上)时特变慢
在分类等非线性计算中,正规方程解法并不适用,所以梯度下降法适用范围更广。
来源:https://blog.csdn.net/yuanlulu/article/details/81068027
猜你喜欢
- Iterable – 可迭代对象能够逐一返回其成员项的对象。 可迭代对象的例子包括所有序列类型 (例如 list, str 和 tuple)
- 前言subprocess库提供了一个API创建子进程并与之通信。这对于运行生产或消费文本的程序尤其有好处,因为这个API支持通过新进行的标准
- 用python语言读取二进制图片文件,并提取非零数据统计信息(例如:max,min,skewness and kurtosis)python
- 前言:目前在研究易信公众号,想给公众号增加一个获取个人交通违章的查询菜单,通过点击返回查询数据。以下是实施过程。一、首先,用火狐浏览器打开X
- 本方法是基于文本密度的方法,最初的想法来源于哈工大的《基于行块分布函数的通用网页正文抽取算法》,本文基于此进行一些小修改。约定:
- MooTools 1.2的整理排序类Sortables原文地址:30 Days of Mootools 1.2 Tutorials - Da
- JSP 开发之 releaseSession的实例详解Hibernate可以实现分页查询,昨天试了一下,分页效果不错。但是发现了一个问题,就
- 以下示例显示如何在 XPath 查询中指定轴。这些示例中的 XPath 查询都在 SampleSchema1.xml 中所包含的映射架构上指
- vbscript脚本中,fso对象CreateTextFile方法调用时可能会报“无效的过程调用或参数”错误,在使用ASP生成静态页面时,如
- Pytorch 多分类模型绘制 ROC, PR 曲线(代码 亲测 可用)ROC曲线示例代码import torchimport torch.
- python——pip install xxx报错SyntaxError: invalid syntax在安装好python后,进入pyth
- 首先说说框架(Frameworks)这个词,框架就是为我们提供了一个平台一个运行环境,在如此统一的前提下我们做相关开发才能“有章可循”,要充
- 父层: <div class="col-xs-12"> <div class
- BrowserPlus 到底是什么,又能做什么?BrowserPlus 是 Yahoo! 最近刚发布一个 Web 扩展的平台:终端用户需安装
- SQL Server数据库快捷键:书签:清除所有书签。 CTRL-SHIFT-F2书签:插入或删除书签(切换)。 CTRL+F2书签:移动到
- 在keras中,数据是以张量的形式表示的,不考虑动态特性,仅考虑shape的时候,可以把张量用类似矩阵的方式来理解。例如[[1],[2],[
- 什么是DLL文件?DLL文件为动态链接库(英语: Dynamic-link library, 缩写为DLL)它是微软公司在微软视窗操作系统中
- 实这本是说明一个问题 : 每个人在提高自己能力这件事情上, 需要持续不断地努力。以最典型的例子来看,只有通过学习,程序员才能保证不断进步。
- 本文实例讲述了php简单获取复选框值的方法。分享给大家供大家参考,具体如下:html:<form id="form1&quo
- 本文实例为大家分享了Bootstrap实现渐变顶部固定自适应导航栏的具体代码,供大家参考,具体内容如下具体代码如下所示:<!DOCTY