Keras自定义IOU方式
作者:Nick Blog 发布时间:2022-12-24 07:48:27
我就废话不多说了,大家还是直接看代码吧!
def iou(y_true, y_pred, label: int):
"""
Return the Intersection over Union (IoU) for a given label.
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
label: the label to return the IoU for
Returns:
the IoU for the given label
"""
# extract the label values using the argmax operator then
# calculate equality of the predictions and truths to the label
y_true = K.cast(K.equal(K.argmax(y_true), label), K.floatx())
y_pred = K.cast(K.equal(K.argmax(y_pred), label), K.floatx())
# calculate the |intersection| (AND) of the labels
intersection = K.sum(y_true * y_pred)
# calculate the |union| (OR) of the labels
union = K.sum(y_true) + K.sum(y_pred) - intersection
# avoid divide by zero - if the union is zero, return 1
# otherwise, return the intersection over union
return K.switch(K.equal(union, 0), 1.0, intersection / union)
def mean_iou(y_true, y_pred):
"""
Return the Intersection over Union (IoU) score.
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
Returns:
the scalar IoU value (mean over all labels)
"""
# get number of labels to calculate IoU for
num_labels = K.int_shape(y_pred)[-1] - 1
# initialize a variable to store total IoU in
mean_iou = K.variable(0)
# iterate over labels to calculate IoU for
for label in range(num_labels):
mean_iou = mean_iou + iou(y_true, y_pred, label)
# divide total IoU by number of labels to get mean IoU
return mean_iou / num_labels
补充知识:keras 自定义评估函数和损失函数loss训练模型后加载模型出现ValueError: Unknown metric function:fbeta_score
keras自定义评估函数
有时候训练模型,现有的评估函数并不足以科学的评估模型的好坏,这时候就需要自定义一些评估函数,比如样本分布不均衡是准确率accuracy评估无法判定一个模型的好坏,这时候需要引入精确度和召回率作为评估标准,不幸的是keras没有这些评估函数。
以下是参考别的文章摘取的两个自定义评估函数
召回率:
def recall(y_true, y_pred):
true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
recall = true_positives / (possible_positives + K.epsilon())
return recall
精确度:
def precision(y_true, y_pred):
true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
precision = true_positives / (predicted_positives + K.epsilon())
return precision
自定义了评估函数,一般在编译模型阶段加入即可:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy', precision, recall])
自定义了损失函数focal_loss一般也在编译阶段加入:
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],
metrics=['accuracy',fbeta_score], )
其他的没有特别要注意的点,直接按照原来的思路训练一版模型出来就好了,关键的地方在于加载模型这里,自定义的函数需要特殊的加载方式,不然会出现加载没有自定义函数的问题:ValueError: Unknown loss function:focal_loss
解决方案:
model_name = 'test_calssification_model.h5'
model_dfcw = load_model(model_name,
custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})
注意点:将自定义的损失函数和评估函数都加入到custom_objects里,以上就是在自定义一个损失函数从编译模型阶段到加载模型阶段出现的所有的问题。
来源:https://blog.csdn.net/xijuezhu8128/article/details/86608333


猜你喜欢
- 在我们做搜索的时候经常要用到模糊查询 (注:其中name1,name2,name3,name4为数据库字段) 1.方法 sql="
- 要实现的SQL查询很原始:要求从第一个表进行查询得到第二个表格式的数据,上网查询之后竟然能写出下面的SQL:select * from us
- 先看一张我绘制的原理图原理图setImmediate 也是宏任务,在 Node 环境下,微任务还有 process.nextTickJS 中
- 众所周知,我们可以通过索引值(或称下标)来查找序列类型(如字符串、列表、元组…)中的单个元素,那么,如果要获取一个索引区间的元素该怎么办呢?
- 写在前面这篇文章的诞生要感谢MIT 6.284课程。在其中一节课中,谈到了多线程的协同的一些问题,其中就涉及到了channel这个概念,并由
- 今天的第二个作品,哈哈哈哈,搞起来感觉还挺有意思的,不过代码里纸牌J,Q,K,A几个数字被我替换成了11,12,13,14......主要是
- 注:因为最近想用一下Python做一些简单小游戏的开发作为项目练手之用,而Pygame模块里面提供了大量的有用的方法和属性。今天我们就在之前
- 初步介绍 当然,我知道现在有成千上万个关于 用CSS处理圆角 的教程,但不管怎么说,我仍然想把这篇文章展示给您。也希望您会发现这篇文章会非常
- 一、Pycharm中安装Django此教程默认你已安装并配置了Python 3.7.6)1.File—>Settings二、搭建Dja
- 出现原因:缺失相应的whl文件。解决办法:下载并安装对应的whl文件。提供一个whl文件的下载网址:http://www.lfd.uci.e
- 如果按本文操作遇到一些问题报错,如C:\Users\milyyy\AppData\Roaming\npm-cache\_logs\2018-
- pyecharts显示数据为百分比的柱状图pyecharts是做数据分析的好帮手,柱状图比较简单,网站例子不够多,一般柱状图就是直接传两组数
- 目录1. 前言2. 实战一下2-1 进入虚拟环境,创建一个项目及 App2-2 创建模板目录并配置 set
- 本文实例讲述了scrapy自定义pipeline类实现将采集数据保存到mongodb的方法。分享给大家供大家参考。具体如下:# Standa
- Web Accessibility Initiative Accessible Rich Internet Applications认识AR
- 引言在上一篇文章中 GoFrame数据校验之校验结果 | Error接口对象 ,关于顺序与非顺序性校验没有做充分的介绍。这篇文章填上之前留的
- 前言本文主要给大家介绍关于python中__init__、__new__和__call__方法的相关内容,分享出来供大家参考学习,下面话不多
- 要说2017年什么技术最火爆,无疑是google领衔的深度学习开源框架Tensorflow。本文简述一下深度学习的入门例子MNIST。深度学
- 先来看一下最终的效果吧开始聊天,输入消息并点击发送消息就可以开始聊天了点击 “获取后端数据”开启实时推送先来简单了解一下 Django Ch
- 背景对接多个外部接口,需要保存请求参数以及返回参数,方便消息的补偿,因为多个外部接口,多个接口字段都不统一,整体使用一个大字段(longte