sklearn+python:线性回归案例
作者:yuanlulu 发布时间:2023-10-19 20:07:01
使用一阶线性方程预测波士顿房价
载入的数据是随sklearn一起发布的,来自boston 1993年之前收集的506个房屋的数据和价格。load_boston()用于载入数据。
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import time
from sklearn.linear_model import LinearRegression
boston = load_boston()
X = boston.data
y = boston.target
print("X.shape:{}. y.shape:{}".format(X.shape, y.shape))
print('boston.feature_name:{}'.format(boston.feature_names))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)
model = LinearRegression()
start = time.clock()
model.fit(X_train, y_train)
train_score = model.score(X_train, y_train)
cv_score = model.score(X_test, y_test)
print('time used:{0:.6f}; train_score:{1:.6f}, sv_score:{2:.6f}'.format((time.clock()-start),
train_score, cv_score))
输出内容为:
X.shape:(506, 13). y.shape:(506,)
boston.feature_name:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
'B' 'LSTAT']
time used:0.012403; train_score:0.723941, sv_score:0.794958
可以看到测试集上准确率并不高,应该是欠拟合。
使用多项式做线性回归
上面的例子是欠拟合的,说明模型太简单,无法拟合数据的情况。现在增加模型复杂度,引入多项式。
打个比方,如果原来的特征是[a, b]两个特征,
在degree为2的情况下, 多项式特征变为[1, a, b, a^2, ab, b^2]。degree为其它值的情况依次类推。
多项式特征相当于增加了数据和模型的复杂性,能够更好的拟合。
下面的代码使用Pipeline把多项式特征和线性回归特征连起来,最终测试degree在1、2、3的情况下的得分。
from sklearn.datasets import load_boston
from sklearn.model_selection import train_test_split
import time
from sklearn.linear_model import LinearRegression
from sklearn.preprocessing import PolynomialFeatures
from sklearn.pipeline import Pipeline
def polynomial_model(degree=1):
polynomial_features = PolynomialFeatures(degree=degree, include_bias=False)
linear_regression = LinearRegression(normalize=True)
pipeline = Pipeline([('polynomial_features', polynomial_features),
('linear_regression', linear_regression)])
return pipeline
boston = load_boston()
X = boston.data
y = boston.target
print("X.shape:{}. y.shape:{}".format(X.shape, y.shape))
print('boston.feature_name:{}'.format(boston.feature_names))
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=3)
for i in range(1,4):
print( 'degree:{}'.format( i ) )
model = polynomial_model(degree=i)
start = time.clock()
model.fit(X_train, y_train)
train_score = model.score(X_train, y_train)
cv_score = model.score(X_test, y_test)
print('time used:{0:.6f}; train_score:{1:.6f}, sv_score:{2:.6f}'.format((time.clock()-start),
train_score, cv_score))
输出结果为:
X.shape:(506, 13). y.shape:(506,)
boston.feature_name:['CRIM' 'ZN' 'INDUS' 'CHAS' 'NOX' 'RM' 'AGE' 'DIS' 'RAD' 'TAX' 'PTRATIO'
'B' 'LSTAT']
degree:1
time used:0.003576; train_score:0.723941, sv_score:0.794958
degree:2
time used:0.030123; train_score:0.930547, sv_score:0.860465
degree:3
time used:0.137346; train_score:1.000000, sv_score:-104.429619
可以看到degree为1和上面不使用多项式是一样的。degree为3在训练集上的得分为1,在测试集上得分是负数,明显过拟合了。
所以最终应该选择degree为2的模型。
二阶多项式比一阶多项式好的多,但是测试集和训练集上的得分仍有不少差距,这可能是数据不够的原因,需要更多的讯据才能进一步提高模型的准确度。
正规方程解法和梯度下降的比较
除了梯度下降法来逼近最优解,也可以使用正规的方程解法直接计算出最终的解来。
根据吴恩达的课程,线性回归最优解为:
theta = (X^T * X)^-1 * X^T * y
其实两种方法各有优缺点:
梯度下降法:
缺点:需要选择学习率,需要多次迭代
优点:特征值很多(1万以上)时仍然能以不错的速度工作
正规方程解法:
优点:不需要设置学习率,不需要多次迭代
缺点:需要计算X的转置和逆,复杂度O3;特征值很多(1万以上)时特变慢
在分类等非线性计算中,正规方程解法并不适用,所以梯度下降法适用范围更广。
来源:https://blog.csdn.net/yuanlulu/article/details/81068027


猜你喜欢
- 概述传入条件的不同,会执行不同的语句每一个case分支都是唯一的,从上到下逐一测试,直到匹配为止。语法第一种【switch 带上表达式】sw
- Selenium爬虫遇到 数据是以 JSON 字符串的形式包裹在 Script 标签中,假设Script标签下代码如下:<script
- 后台数据库: [Microsoft Access] 与 [Microsoft Sql Server] 更换之后,ASP代码应注意要修改的一些
- python在mysql中插入null空值sql = “INSERT INTO MROdata (MmeUeS1apId) VALUES (
- 维护是什么,维护就是修改,不断的修改,但是要保证你的html和css有清晰的版本界定,有扩展性,不要因为做的太死而重新去做这个页面。一个赚钱
- 在Windows系统中用“Ctrl+C”和“Ctrl+V”就可以完成复制、粘贴工作,是不是很爽?其实使用a标签的accesskey属性也可以
- MySQL 5.7安装、升级笔记分享:卸载当前的 MySQL查看当前 MySQL 版本:[root@coderknock ~]# mysql
- 前言随着Python 3.8的发布,赋值表达式运算符(也称为海象运算符)也发布了。运算符使值的赋值可以传递到表达式中。这通常会使语句数减少一
- yaml文件内容apiVersion: policy/v1beta1kind: PodSecurityPolicymetadata: &nb
- 鉴于ASP脚本语言是在服务器端IIS或PWS中解释和运行,并可动态生成普通的HTML网页,然后再传送到客户端供浏览的这一特点。我们要在本机上
- 说起模板引擎,很多人会认为这是后台的东西(如PHP的Smarty、Java的Velocity),跟前端没有关系。然而,随着前端的逻辑变得越来
- 目录项目地址安装导入使用1 创建连接2 执行sql语句3 select 方法4 insert_into 方法5 merge_in
- 一、下载instant client1.附链接:http://www.oracle.com/technetwork/topics/winx6
- 前言通常在项目中,一般都需要一种编程语言来操作数据库,使用Python来操作数据库有着天然的优势,因为Python的字典和MongoDB的文
- 一、概述spark 有三大引擎,spark core、sparkSQL、sparkStreaming,spark core 的关键抽象是 S
- 前言在学习操作系统的时候,我们应该都学习过临界区、互斥锁这些概念,用于在并发环境下保证状态的正确性。比如在秒杀时,100 个用户同时抢 10
- 1.在Server端添加Silverlight-enabled WCF service [ServiceContract(Namespace
- ASP中判断字符串中是否包含字母和数字的两个函数function isnaw(str) for
- 一、 迪杰斯特拉算法思想Dijkstra算法主要针对的是有向图的单元最短路径问题,且不能出现权值为负的情况!Dijkstra算法类似于贪心算
- 目录前言一、常用命令二、嗅探数据包三、构造数据包四、各个协议用法五、发包,收包六、SYN半开式扫描七、数据包序列化,反序列化八、数据包与字符