四种Python机器学习超参数搜索方法总结
作者:Python数据挖掘 发布时间:2022-03-19 17:29:22
在建模时模型的超参数对精度有一定的影响,而设置和调整超参数的取值,往往称为调参。
在实践中调参往往依赖人工来进行设置调整范围,然后使用机器在超参数范围内进行搜素。本文将演示在sklearn中支持的四种基础超参数搜索方法:
GridSearch
RandomizedSearch
HalvingGridSearch
HalvingRandomSearch
原始模型
作为精度对比,我们最开始使用随机森林来训练初始化模型,并在测试集计算精度:
# 数据读取
df = pd.read_csv('https://mirror.coggle.club/dataset/heart.csv')
X = df.drop(columns=['output'])
y = df['output']
# 数据划分
x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y)
# 模型训练与计算准确率
clf = RandomForestClassifier(random_state=0)
clf.fit(x_train, y_train)
clf.score(x_test, y_test)
模型最终在测试集精度为:0.802。
GridSearch
GridSearch是比较基础的超参数搜索方法,中文名字网格搜索。其原理是在计算的过程中遍历所有的超参数组合,然后搜索到最优的结果。
如下代码所示,我们对4个超参数进行搜索,搜索空间为 5 * 3 * 2 * 3 = 90组超参数。对于每组超参数还需要计算5折交叉验证,则需要训练450次。
parameters = {
'max_depth': [2,4,5,6,7],
'min_samples_leaf': [1,2,3],
'min_weight_fraction_leaf': [0, 0.1],
'min_impurity_decrease': [0, 0.1, 0.2]
}
# Fitting 5 folds for each of 90 candidates, totalling 450 fits
clf = GridSearchCV(
RandomForestClassifier(random_state=0),
parameters, refit=True, verbose=1,
)
clf.fit(x_train, y_train)
clf.best_estimator_.score(x_test, y_test)
模型最终在测试集精度为:0.815。
RandomizedSearch
RandomizedSearch是在一定范围内进行搜索,且需要设置搜索的次数,其默认不会对所有的组合进行搜索。
n_iter代表超参数组合的个数,默认会设置比所有组合次数少的取值,如下面设置的为10,则只进行50次训练。
parameters = {
'max_depth': [2,4,5,6,7],
'min_samples_leaf': [1,2,3],
'min_weight_fraction_leaf': [0, 0.1],
'min_impurity_decrease': [0, 0.1, 0.2]
}
clf = RandomizedSearchCV(
RandomForestClassifier(random_state=0),
parameters, refit=True, verbose=1, n_iter=10,
)
clf.fit(x_train, y_train)
clf.best_estimator_.score(x_test, y_test)
模型最终在测试集精度为:0.815。
HalvingGridSearch
HalvingGridSearch和GridSearch非常相似,但在迭代的过程中是有参数组合减半的操作。
最开始使用所有的超参数组合,但使用最少的数据,筛选其中最优的超参数,增加数据再进行筛选。
HalvingGridSearch的思路和hyperband的思路非常相似,但是最朴素的实现。先使用少量数据筛选超参数组合,然后使用更多的数据验证精度。
n_iterations: 3
n_required_iterations: 5
n_possible_iterations: 3
min_resources_: 20
max_resources_: 227
aggressive_elimination: False
factor: 3
----------
iter: 0
n_candidates: 90
n_resources: 20
Fitting 5 folds for each of 90 candidates, totalling 450 fits
----------
iter: 1
n_candidates: 30
n_resources: 60
Fitting 5 folds for each of 30 candidates, totalling 150 fits
----------
iter: 2
n_candidates: 10
n_resources: 180
Fitting 5 folds for each of 10 candidates, totalling 50 fits
----------
模型最终在测试集精度为:0.855。
HalvingRandomSearch
HalvingRandomSearch和HalvingGridSearch类似,都是逐步增加样本,减少超参数组合。但每次生成超参数组合,都是随机筛选的。
n_iterations: 3
n_required_iterations: 3
n_possible_iterations: 3
min_resources_: 20
max_resources_: 227
aggressive_elimination: False
factor: 3
----------
iter: 0
n_candidates: 11
n_resources: 20
Fitting 5 folds for each of 11 candidates, totalling 55 fits
----------
iter: 1
n_candidates: 4
n_resources: 60
Fitting 5 folds for each of 4 candidates, totalling 20 fits
----------
iter: 2
n_candidates: 2
n_resources: 180
Fitting 5 folds for each of 2 candidates, totalling 10 fits
模型最终在测试集精度为:0.828。
总结与对比
HalvingGridSearch和HalvingRandomSearch比较适合在数据量比较大的情况使用,可以提高训练速度。如果计算资源充足,GridSearch和HalvingGridSearch会得到更好的结果。
后续我们将分享其他的一些高阶调参库的实现,其中也会有数据量改变的思路。如在Optuna中,核心是参数组合的生成和剪枝、训练的样本增加等细节。
来源:https://blog.csdn.net/qq_34160248/article/details/127717971
![](https://www.aspxhome.com/images/zang.png)
![](https://www.aspxhome.com/images/jiucuo.png)
猜你喜欢
- 本文介绍了tf.truncated_normal与tf.random_normal的详细用法,分享给大家,具体如下:tf.truncated
- 一、前言在项目开发中,数据库应用必不可少。虽然数据库的种类有很多,如SQLite、MySQL、Oracle等,但是它们的功能基本是一样都是一
- 用golang来实现的webserver通常是是这样的//main.gopackage mainimport ("fmt"
- 环境配置系统:Windows10版本:python 3.8Turtle扫盲1.绘图窗体的设置turtle.setup(width, heig
- 注释注释就是对代码的解释和说明。目的是为了让别人和自己很容易看懂。为了让别人一看就知道这段代码是做什么用的。正确的程序注释一般包括序言性注释
- 本文实例讲述了Python上下文管理器类和上下文管理器装饰器contextmanager用法。分享给大家供大家参考,具体如下:一. 什么是上
- 利用networkx,numpy,matplotlib,将邻接矩阵输出为图形。1,自身确定一个邻接矩阵,然后通过循环的方式添加变,然后输出图
- 1. UDPUDP是一种无连接的、不可靠的传输协议,相比于TCP,UDP具有数据传输速度快、传输延迟小等优点,但是不保证数据的可靠传输,需要
- 整理了一下python 中文件的输入输出及主要介绍一些os模块中对文件系统的操作。文件输入输出1、内建函数open(file_name,文件
- 字符串的相似性比较应用场合很多,像拼写纠错、文本去重、上下文相似性等。评价字符串相似度最常见的办法就是:把一个字符串通过插入、删除或替换这样
- 目录:分析和设计组件编码实现和算法用 Ant 构建组件测试 JavaScript 组件我们走到哪儿了?前两期思考了太多东西,你是否已有倦意?
- --语 句 功 能 --数据操作 SELECT --从数据库表中检索数据行和列 INSERT --向数据库表添加新数据行 DELETE --
- 1. 反射简介1.1 反射是什么?Go语言提供了一种机制在运行时更新和检查变量的值、调用变量的方法和变量支持的内在操作,但是在编译时并不知道
- 一个非常繁琐粗暴的方法,python属于入门级水平,就酱先备份一下,如果有更好的方法再更新arrs=[[2,15,48,4,5],[6,7,
- 目录一、进程(Process)二、线程(Thread)三、并发编程解决方案:四、多线程实现 (两种)1、第一种 函数方法2、第二种 类方法包
- 用下面代码可实现:<%Dim writeDim fileSysObj, tf, readrea
- 本文实例讲述了Go语言计算两个经度和纬度之间距离的方法。分享给大家供大家参考。具体实现方法如下:package main &nbs
- Python的第一个主流打包格式是.egg文件,现在大家庭中又有了一个叫做Wheel(*.whl)的新成员。wheel“被设计成包含PEP
- 我们经常会要用到页面的包含这样东西. 在asp.net 我开始也还是习惯用asp中的include 用起来感觉很麻烦.
- 一、数据无量纲化处理 (热力图)1.数据无量纲化处理(仅介绍本文用到的方法):min-max归一化该方法是对原始数据进行线性变换,