网络编程
位置:首页>> 网络编程>> Python编程>> keras用auc做metrics以及早停实例

keras用auc做metrics以及早停实例

作者:ssswill  发布时间:2022-04-19 03:55:12 

标签:keras,auc,metrics,早停

我就废话不多说了,大家还是直接看代码吧~


import tensorflow as tf
from sklearn.metrics import roc_auc_score

def auroc(y_true, y_pred):
return tf.py_func(roc_auc_score, (y_true, y_pred), tf.double)
# Build Model...

model.compile(loss='categorical_crossentropy', optimizer='adam',metrics=['accuracy', auroc])

完整例子:


def auc(y_true, y_pred):
auc = tf.metrics.auc(y_true, y_pred)[1]
K.get_session().run(tf.local_variables_initializer())
return auc

def create_model_nn(in_dim,layer_size=200):
model = Sequential()
model.add(Dense(layer_size,input_dim=in_dim, kernel_initializer='normal'))
model.add(BatchNormalization())
model.add(Activation('relu'))
model.add(Dropout(0.3))
for i in range(2):
 model.add(Dense(layer_size))
 model.add(BatchNormalization())
 model.add(Activation('relu'))
 model.add(Dropout(0.3))
model.add(Dense(1, activation='sigmoid'))
adam = optimizers.Adam(lr=0.01)
model.compile(optimizer=adam,loss='binary_crossentropy',metrics = [auc])
return model
####cv train
folds = StratifiedKFold(n_splits=5, shuffle=False, random_state=15)
oof = np.zeros(len(df_train))
predictions = np.zeros(len(df_test))
for fold_, (trn_idx, val_idx) in enumerate(folds.split(df_train.values, target2.values)):
print("fold n°{}".format(fold_))
X_train = df_train.iloc[trn_idx][features]
y_train = target2.iloc[trn_idx]
X_valid = df_train.iloc[val_idx][features]
y_valid = target2.iloc[val_idx]
model_nn = create_model_nn(X_train.shape[1])
callback = EarlyStopping(monitor="val_auc", patience=50, verbose=0, mode='max')
history = model_nn.fit(X_train, y_train, validation_data = (X_valid ,y_valid),epochs=1000,batch_size=64,verbose=0,callbacks=[callback])
print('\n Validation Max score : {}'.format(np.max(history.history['val_auc'])))
predictions += model_nn.predict(df_test[features]).ravel()/folds.n_splits

补充知识:Keras可使用的评价函数

1:binary_accuracy(对二分类问题,计算在所有预测值上的平均正确率)

binary_accuracy(y_true, y_pred)

2:categorical_accuracy(对多分类问题,计算在所有预测值上的平均正确率)

categorical_accuracy(y_true, y_pred)

3:sparse_categorical_accuracy(与categorical_accuracy相同,在对稀疏的目标值预测时有用 )

sparse_categorical_accuracy(y_true, y_pred)

4:top_k_categorical_accuracy(计算top-k正确率,当预测值的前k个值中存在目标类别即认为预测正确 )

top_k_categorical_accuracy(y_true, y_pred, k=5)

5:sparse_top_k_categorical_accuracy(与top_k_categorical_accracy作用相同,但适用于稀疏情况)

sparse_top_k_categorical_accuracy(y_true, y_pred, k=5)

来源:https://blog.csdn.net/ssswill/article/details/95515314

0
投稿

猜你喜欢

  • 本文详细分析了Yii框架的登录流程。分享给大家供大家参考。具体分析如下:Yii对于新手来说上手有点难度,特别是关于session,cooki
  • 相信互联网的从业者都有同一个顾虑,那就是怎样将自己网站的用户牢牢抓住。如果以用户的角度来讲,任何网站其实都是一样的,都是我获取东西、获取服务
  • 有时候发微博时候,需要裁切图片为九宫格,但是ps或者其他工具都太麻烦,这里写一个python一键切割九宫格的工具,以供大家学习和使用!实现代
  • Swiper是纯javascript打造的滑动特效插件,面向手机、平板电脑等移动终端。Swiper能实现触屏焦点图、触屏Tab切换、触屏多图
  • 项目简介鉴于项目保密的需要,不便透露太多项目的信息,因此,简单介绍一下项目存在的难点:海量数据:项目是对CSV文件中的数据进行处理,而特点是
  • Ajax的出现让Web展现了更新的活力,基本所有的语言,都动态支持Ajax与起服务端进行通信,并在页面实现无刷新动态交互。 下面是散仙使用D
  • 数据库索引是一个数据结构,提高操作的速度,在一个表中可以使用一个或多个列,提供两个快速随机查找和高效的顺序访问记录的基础创建索引。在创建索引
  • Python中的五种特性:切片,迭代,列表生成式,生成器,迭代器。切片切片就相当于其他语言中的截断函数,取部分指定元素用的。L = list
  • 本文实例为大家分享了PHP文件打包下载zip的具体代码,供大家参考,具体内容如下<?php//获取文件列表function list_
  • 在pycharm中,可以通过venv来建立工程,运行等等。但是一旦把这个工程文件夹拿到其他地方运行,而且不是在venv环境中运行,就有可能遇
  • 比如有一个需求,通过sql语句,返回-5至5的随机整数.如果这一个放在PHP中,则非常简单直接用print rand(-5,5);?>
  • 了然于胸 - collectModules时序图经过loadConfig和applyConfigDefaults,我们已经将用户自定义信息和
  • 前言:macOS自带的Apache可以提供通过http://localhost:8081访问本地文件服务,那么python有没有类似功能的库
  • 在编写自动化测试用例的时候,每次登录都需要输入验证码,后来想把让python自己识别图片里的验证码,不需要自己手动登陆,所以查了一下识别功能
  • 示例:《电影类型分类》获取数据来源电影名称打斗次数接吻次数电影类型California Man3104RomanceHe's Not
  • 本文实例讲述了Python编程对列表中字典元素进行排序的方法。分享给大家供大家参考,具体如下:内容目录:1. 问题起源2. 对列表中的字典元
  • 问题你有一个数据序列,想利用一些规则从中提取出需要的值或者是缩短序列解决方案最简单的过滤序列元素的方法就是使用列表推导。比如:>>
  • 今天在慕课网上学习了有关于python操作MySQL的相关知识,在此做些总结。python操作数据库还是相对比较简单的,由于python统一
  • 关于浅拷贝和深拷贝想必大家在学习中遇到很多次,这也是面试中常常被问到的问题,借由这个时间,整理一下浅拷贝和深拷贝的关系先从一个简单的例子入手
  • 1.delete不能使自动编号返回为起始值。但是truncate能使自动增长的列的值返回为默认的种子 2.truncate只能一次清空,不能
手机版 网络编程 asp之家 www.aspxhome.com