Keras自定义IOU方式
作者:Nick Blog 发布时间:2022-12-24 07:48:27
我就废话不多说了,大家还是直接看代码吧!
def iou(y_true, y_pred, label: int):
"""
Return the Intersection over Union (IoU) for a given label.
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
label: the label to return the IoU for
Returns:
the IoU for the given label
"""
# extract the label values using the argmax operator then
# calculate equality of the predictions and truths to the label
y_true = K.cast(K.equal(K.argmax(y_true), label), K.floatx())
y_pred = K.cast(K.equal(K.argmax(y_pred), label), K.floatx())
# calculate the |intersection| (AND) of the labels
intersection = K.sum(y_true * y_pred)
# calculate the |union| (OR) of the labels
union = K.sum(y_true) + K.sum(y_pred) - intersection
# avoid divide by zero - if the union is zero, return 1
# otherwise, return the intersection over union
return K.switch(K.equal(union, 0), 1.0, intersection / union)
def mean_iou(y_true, y_pred):
"""
Return the Intersection over Union (IoU) score.
Args:
y_true: the expected y values as a one-hot
y_pred: the predicted y values as a one-hot or softmax output
Returns:
the scalar IoU value (mean over all labels)
"""
# get number of labels to calculate IoU for
num_labels = K.int_shape(y_pred)[-1] - 1
# initialize a variable to store total IoU in
mean_iou = K.variable(0)
# iterate over labels to calculate IoU for
for label in range(num_labels):
mean_iou = mean_iou + iou(y_true, y_pred, label)
# divide total IoU by number of labels to get mean IoU
return mean_iou / num_labels
补充知识:keras 自定义评估函数和损失函数loss训练模型后加载模型出现ValueError: Unknown metric function:fbeta_score
keras自定义评估函数
有时候训练模型,现有的评估函数并不足以科学的评估模型的好坏,这时候就需要自定义一些评估函数,比如样本分布不均衡是准确率accuracy评估无法判定一个模型的好坏,这时候需要引入精确度和召回率作为评估标准,不幸的是keras没有这些评估函数。
以下是参考别的文章摘取的两个自定义评估函数
召回率:
def recall(y_true, y_pred):
true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
possible_positives = K.sum(K.round(K.clip(y_true, 0, 1)))
recall = true_positives / (possible_positives + K.epsilon())
return recall
精确度:
def precision(y_true, y_pred):
true_positives = K.sum(K.round(K.clip(y_true * y_pred, 0, 1)))
predicted_positives = K.sum(K.round(K.clip(y_pred, 0, 1)))
precision = true_positives / (predicted_positives + K.epsilon())
return precision
自定义了评估函数,一般在编译模型阶段加入即可:
model.compile(optimizer='adam', loss='categorical_crossentropy', metrics=['accuracy', precision, recall])
自定义了损失函数focal_loss一般也在编译阶段加入:
model.compile(optimizer=Adam(lr=0.0001), loss=[focal_loss],
metrics=['accuracy',fbeta_score], )
其他的没有特别要注意的点,直接按照原来的思路训练一版模型出来就好了,关键的地方在于加载模型这里,自定义的函数需要特殊的加载方式,不然会出现加载没有自定义函数的问题:ValueError: Unknown loss function:focal_loss
解决方案:
model_name = 'test_calssification_model.h5'
model_dfcw = load_model(model_name,
custom_objects={'focal_loss': focal_loss,'fbeta_score':fbeta_score})
注意点:将自定义的损失函数和评估函数都加入到custom_objects里,以上就是在自定义一个损失函数从编译模型阶段到加载模型阶段出现的所有的问题。
来源:https://blog.csdn.net/xijuezhu8128/article/details/86608333
猜你喜欢
- 设计网站的同志背景主要有两种:学计算机、学艺术。基本上会写代码的不懂设计,会设计的不懂代码,这个格局似乎到今天还没变。某些学计算机的同学,有
- 前言今天学习Django框架,用ajax向后台发送post请求,直接报了403错误,说CSRF验证失败;先前用模板的话都是在里面加一个 {%
- 本文实例为大家分享了python爬取微信公众号文章的具体代码,供大家参考,具体内容如下# -*- coding: utf-8 -*-impo
- 1. 停应用层的各种程序。 2. 停oralce的监听进程: $lsnrctl stop 3. 在独占的系统用户下,备份控制文件: SQL&
- 品牌是我们一直挂在嘴边的词语,视觉设计师们经常说到,公司的品牌该如何如何去设计?这个违背了我们的公司品牌!等等。之前我有谈过关于 品牌灵魂的
- 目录1图像叠加2图像融合3按位操作1图像叠加可以通过OpenCV函数cv.add()或简单地通过numpy操作添加两个图像,res = im
- 支持CSS属性Safari和WebKit实施大子的CSS 2.1规格所界定的万维网联盟( W3C ) ,以及部分的CSS 3规格。 。这个C
- 1.实现效果2.实现代码# 导入所需库from tkinter import *import randomclass main:  
- 本文为大家分享了python2.7和NLTK安装教程,具体内容如下系统:Windows 7 Ultimate 64-bitsPython 2
- 引言周末我和小明又开始了疯狂的考证学习,昨晚通过合法的手段获取了一套学习资料,却遇到了一个问题:一套完整的资料,被机构拆分成了162个wor
- 问题:SQL Server 2005中如何利用xml拆分字符串序列?解答:下文中介绍的方法比替换为select union all方法更为见
- 前言:在motplotlib的学习过程中,我们使用最多的就是numpy模块。numpy 模块被称为 matplotlib 模块绘制图表伴侣。
- 1.环境mysql 8.0Django 3.2pycharm 2021.112. (No changes detected)及解决2.1 问
- 最近在看python脚本语言,脚本语言是一种解释性的语言,不需要编译,可以直接用,由解释器来负责解释。python语言很强大,而且写起来很简
- 现象:生产中心进行拷机任务下了300个任务,过了一阵时间后发现任务不再被调度起来,查看后台日志发现日志输出停在某个时间点。分析:1、首先确认
- 数据字典是Oracle存放有关数据库信息的地方,其用途是用来描述数据的。比如一个表的创建者信息,创建时间信息,所属表空间信息,用户访问权限信
- 新闻系统、blog系统等都可能用到将动态页面生成静态页面的技巧来提高页面的访问速度,从而减轻服务器的压力,本文为大家搜集整理了ASP编程中常
- 昨天给公司服务器重做了一下系统,遇到Asp附件无法上传,之前服务器上使用好好的,怎么重做了就不正常了,于是一番google,baidu,下面
- 以下是涉及到插入表格的查询的5种改进方法:1)使用LOAD DATA INFILE从文本下载数据这将比使用插入语句快20倍。2)使用带有多个
- 在用plt.imshow和cv2.imshow显示同一幅图时可能会出现颜色差别很大的现象。这是因为:opencv的接口使用BGR,而matp