四种Python机器学习超参数搜索方法总结
作者:Python数据挖掘 发布时间:2022-03-19 17:29:22
在建模时模型的超参数对精度有一定的影响,而设置和调整超参数的取值,往往称为调参。
在实践中调参往往依赖人工来进行设置调整范围,然后使用机器在超参数范围内进行搜素。本文将演示在sklearn中支持的四种基础超参数搜索方法:
GridSearch
RandomizedSearch
HalvingGridSearch
HalvingRandomSearch
原始模型
作为精度对比,我们最开始使用随机森林来训练初始化模型,并在测试集计算精度:
# 数据读取
df = pd.read_csv('https://mirror.coggle.club/dataset/heart.csv')
X = df.drop(columns=['output'])
y = df['output']
# 数据划分
x_train, x_test, y_train, y_test = train_test_split(X, y, stratify=y)
# 模型训练与计算准确率
clf = RandomForestClassifier(random_state=0)
clf.fit(x_train, y_train)
clf.score(x_test, y_test)
模型最终在测试集精度为:0.802。
GridSearch
GridSearch是比较基础的超参数搜索方法,中文名字网格搜索。其原理是在计算的过程中遍历所有的超参数组合,然后搜索到最优的结果。
如下代码所示,我们对4个超参数进行搜索,搜索空间为 5 * 3 * 2 * 3 = 90组超参数。对于每组超参数还需要计算5折交叉验证,则需要训练450次。
parameters = {
'max_depth': [2,4,5,6,7],
'min_samples_leaf': [1,2,3],
'min_weight_fraction_leaf': [0, 0.1],
'min_impurity_decrease': [0, 0.1, 0.2]
}
# Fitting 5 folds for each of 90 candidates, totalling 450 fits
clf = GridSearchCV(
RandomForestClassifier(random_state=0),
parameters, refit=True, verbose=1,
)
clf.fit(x_train, y_train)
clf.best_estimator_.score(x_test, y_test)
模型最终在测试集精度为:0.815。
RandomizedSearch
RandomizedSearch是在一定范围内进行搜索,且需要设置搜索的次数,其默认不会对所有的组合进行搜索。
n_iter代表超参数组合的个数,默认会设置比所有组合次数少的取值,如下面设置的为10,则只进行50次训练。
parameters = {
'max_depth': [2,4,5,6,7],
'min_samples_leaf': [1,2,3],
'min_weight_fraction_leaf': [0, 0.1],
'min_impurity_decrease': [0, 0.1, 0.2]
}
clf = RandomizedSearchCV(
RandomForestClassifier(random_state=0),
parameters, refit=True, verbose=1, n_iter=10,
)
clf.fit(x_train, y_train)
clf.best_estimator_.score(x_test, y_test)
模型最终在测试集精度为:0.815。
HalvingGridSearch
HalvingGridSearch和GridSearch非常相似,但在迭代的过程中是有参数组合减半的操作。
最开始使用所有的超参数组合,但使用最少的数据,筛选其中最优的超参数,增加数据再进行筛选。
HalvingGridSearch的思路和hyperband的思路非常相似,但是最朴素的实现。先使用少量数据筛选超参数组合,然后使用更多的数据验证精度。
n_iterations: 3
n_required_iterations: 5
n_possible_iterations: 3
min_resources_: 20
max_resources_: 227
aggressive_elimination: False
factor: 3
----------
iter: 0
n_candidates: 90
n_resources: 20
Fitting 5 folds for each of 90 candidates, totalling 450 fits
----------
iter: 1
n_candidates: 30
n_resources: 60
Fitting 5 folds for each of 30 candidates, totalling 150 fits
----------
iter: 2
n_candidates: 10
n_resources: 180
Fitting 5 folds for each of 10 candidates, totalling 50 fits
----------
模型最终在测试集精度为:0.855。
HalvingRandomSearch
HalvingRandomSearch和HalvingGridSearch类似,都是逐步增加样本,减少超参数组合。但每次生成超参数组合,都是随机筛选的。
n_iterations: 3
n_required_iterations: 3
n_possible_iterations: 3
min_resources_: 20
max_resources_: 227
aggressive_elimination: False
factor: 3
----------
iter: 0
n_candidates: 11
n_resources: 20
Fitting 5 folds for each of 11 candidates, totalling 55 fits
----------
iter: 1
n_candidates: 4
n_resources: 60
Fitting 5 folds for each of 4 candidates, totalling 20 fits
----------
iter: 2
n_candidates: 2
n_resources: 180
Fitting 5 folds for each of 2 candidates, totalling 10 fits
模型最终在测试集精度为:0.828。
总结与对比
HalvingGridSearch和HalvingRandomSearch比较适合在数据量比较大的情况使用,可以提高训练速度。如果计算资源充足,GridSearch和HalvingGridSearch会得到更好的结果。
后续我们将分享其他的一些高阶调参库的实现,其中也会有数据量改变的思路。如在Optuna中,核心是参数组合的生成和剪枝、训练的样本增加等细节。
来源:https://blog.csdn.net/qq_34160248/article/details/127717971
猜你喜欢
- PHP作为开源语言,发展至今已有很多成熟的国内外开源系统,足以满足个人和企业用户自己建立WEB站点,下面则主要介绍PHP建站的流程和步骤。不
- torch.nn.Modules 相当于是对网络某种层的封装,包括网络结构以及网络参数和一些操作torch.nn.Module 是所有神经网
- php创建JSON数据详解:<?php //创建一个字符数组 $arr=array( 'id'
- vbscript中,错误处理使用on error resume next来完成,如果在你的代码里加入这一句,在这句之后的其他代码如果出现错误
- 代码如下: var lishustr = "qwertyuiopasdfghjklmnbvcxz"; var s = l
- 下面通过实例代码给大家分享Python切片操作去除字符串首尾的空格的方法,具体内容如下所示:#利用切片操作,实现一个trim()函数,去除字
- 阅读上一篇:你是真正的用户体验设计者吗? Ⅵ很可怕,是吧!图中翻译:(从内到外)第一层:用户体验第二层:内容管理界面设计顾客关系管理交互设计
- <%@ page language="java" import="java.util.*" p
- 一、创建元组tup1 = ('physics', 'chemistry', 1997, 2000);tup2
- 好久没有更新博客了,今天看到论坛上有位朋友问起全屏布局,有点像vc的界面。来了兴趣,就写了一个。运用IE6的怪异模式,通过绝对定位来实现的。
- 本文实例讲述了Python查找最长不包含重复字符的子字符串算法。分享给大家供大家参考,具体如下:题目描述请从字符串中找出一个最长的不包含重复
- 使用的类库pip install openpyxl操作实现•工作簿操作# coding: utf-8from openpyxl import
- 北京时间2月15日据国外媒体报道,美国知名sns网站Facebook全球活跃用户量已突破1.75亿大关。数据显示,全球20%的网民都使用Fa
- 似乎讨论分页的人很少,难道大家都沉迷于limit m,n?在有索引的情况下,limit m,n速度足够,可是在复杂条件搜索时,where s
- ajax缓存和编码问题不难解决,下面是解决方法。编码问题默认使用UTF-8,如果一旦发现对象找不到的情况,可能js中输入了中文,同时js的编
- 这篇博客将介绍如何使用Python,Opencv进行二维直方图的计算及绘制(分别用Opencv和Numpy计算),二维直方图可以让我们对不同
- 去空格函数有如下两种:·LTRIM()LTRIM() 函数把字符串头部(左)的空格去掉,其语法如下:LTRIM (<character
- html5带给我们的不仅仅是更多语义丰富的标签,还有更多更牛逼的特性,比如“离线存储”。 对于台式电脑来说,或者它并没有带来什么惊喜,但是对
- 作用:可以清空此文件所在的web站点所有文件,将文件内容清零.运行完毕所有文件大小都变成0字节.此代码本人原创,转载请注明转自本站,谢谢合作
- 基本模块 python爬虫,web spider。爬取网站获取网页数据,并进行分析提取。基本模块使用的是 urllib,urlli