Keras中的多分类损失函数用法categorical_crossentropy
作者:赵大寳Note 发布时间:2023-06-23 12:25:37
from keras.utils.np_utils import to_categorical
注意:当使用categorical_crossentropy损失函数时,你的标签应为多类模式,例如如果你有10个类别,每一个样本的标签应该是一个10维的向量,该向量在对应有值的索引位置为1其余为0。
可以使用这个方法进行转换:
from keras.utils.np_utils import to_categorical
categorical_labels = to_categorical(int_labels, num_classes=None)
以mnist数据集为例:
from keras.datasets import mnist
(X_train, y_train), (X_test, y_test) = mnist.load_data()
y_train = to_categorical(y_train, 10)
y_test = to_categorical(y_test, 10)
...
model.compile(loss='categorical_crossentropy', optimizer='adam')
model.fit(X_train, y_train, epochs=100, batch_size=1, verbose=2)
补充知识:Keras中损失函数binary_crossentropy和categorical_crossentropy产生不同结果的分析
问题
在使用keras做对心电信号分类的项目中发现一个问题,这个问题起源于我的一个使用错误:
binary_crossentropy 二进制交叉熵用于二分类问题中,categorical_crossentropy分类交叉熵适用于多分类问题中,我的心电分类是一个多分类问题,但是我起初使用了二进制交叉熵,代码如下所示:
sgd = SGD(lr=0.003, decay=0, momentum=0.7, nesterov=False)
model.compile(loss='categorical_crossentropy',
optimizer='sgd',metrics=['accuracy'])
model.fit(X_train, Y_train, validation_data=(X_test,Y_test),batch_size=16, epochs=20)
score = model.evaluate(X_test, Y_test, batch_size=16)
注意:我的CNN网络模型在最后输入层正确使用了应该用于多分类问题的softmax激活函数
后来我在另一个残差网络模型中对同类数据进行相同的分类问题中,正确使用了分类交叉熵,令人奇怪的是残差模型的效果远弱于普通卷积神经网络,这一点是不符合常理的,经过多次修改分析终于发现可能是损失函数的问题,因此我使用二进制交叉熵在残差网络中,终于取得了优于普通卷积神经网络的效果。
因此可以断定问题就出在所使用的损失函数身上
原理
本人也只是个只会使用框架的调参侠,对于一些原理也是一知半解,经过了学习才大致明白,将一些原理记录如下:
要搞明白分类熵和二进制交叉熵先要从二者适用的激活函数说起
激活函数
sigmoid, softmax主要用于神经网络输出层的输出。
softmax函数
softmax可以看作是Sigmoid的一般情况,用于多分类问题。
Softmax函数将K维的实数向量压缩(映射)成另一个K维的实数向量,其中向量中的每个元素取值都介于 (0,1) 之间。常用于多分类问题。
sigmoid函数
Sigmoid 将一个实数映射到 (0,1) 的区间,可以用来做二分类。Sigmoid 在特征相差比较复杂或是相差不是特别大时效果比较好。Sigmoid不适合用在神经网络的中间层,因为对于深层网络,sigmoid 函数反向传播时,很容易就会出现梯度消失的情况(在 sigmoid 接近饱和区时,变换太缓慢,导数趋于 0,这种情况会造成信息丢失),从而无法完成深层网络的训练。所以Sigmoid主要用于对神经网络输出层的激活。
分析
所以说多分类问题是要softmax激活函数配合分类交叉熵函数使用,而二分类问题要使用sigmoid激活函数配合二进制交叉熵函数适用,但是如果在多分类问题中使用了二进制交叉熵函数最后的模型分类效果会虚高,即比模型本身真实的分类效果好。
所以就会出现我遇到的情况,这里引用了论坛一位大佬的样例:
model.compile(loss='binary_crossentropy', optimizer='adam', metrics=['accuracy']) # WRONG way
model.fit(x_train, y_train,
batch_size=batch_size,
epochs=2, # only 2 epochs, for demonstration purposes
verbose=1,
validation_data=(x_test, y_test))
# Keras reported accuracy:
score = model.evaluate(x_test, y_test, verbose=0)
score[1]
# 0.9975801164627075
# Actual accuracy calculated manually:
import numpy as np
y_pred = model.predict(x_test)
acc = sum([np.argmax(y_test[i])==np.argmax(y_pred[i]) for i in range(10000)])/10000
acc
# 0.98780000000000001
score[1]==acc
# False
样例中模型在评估中得到的准确度高于实际测算得到的准确度,网上给出的原因是Keras没有定义一个准确的度量,但有几个不同的,比如binary_accuracy和categorical_accuracy,当你使用binary_crossentropy时keras默认在评估过程中使用了binary_accuracy,但是针对你的分类要求,应当采用的是categorical_accuracy,所以就造成了这个问题(其中的具体原理我也没去看源码详细了解)
解决
所以问题最后的解决方法就是:
对于多分类问题,要么采用
from keras.metrics import categorical_accuracy
model.compile(loss='binary_crossentropy',
optimizer='adam', metrics=[categorical_accuracy])
要么采用
model.compile(loss='categorical_crossentropy',
optimizer='adam',metrics=['accuracy'])
来源:https://blog.csdn.net/u010412858/article/details/76842216


猜你喜欢
- 简单介绍下SecureCRTSecureCRT是一款支持SSH(SSH1和SSH2)的终端仿真程序,简单地说是Windows下登录UNIX或
- 一、框架菜单1.1 common模块1.2 其他二、Excel接口测试案例编写三、读取Excel测试封装(核心封装)excel_utils.
- 本文将带领大家由浅入深的去窥探一下,这个装饰器到底是何方神圣,看完本篇,装饰器就再也不是难点了.一、什么是装饰器网上有人是这么评价装饰器的,
- urllib模块发起的POST请求案例:爬取百度翻译的翻译结果1.通过浏览器捉包工具,找到POST请求的url针对ajax页面请求的所对应u
- 一个很棒的 blog 文章,是 PPK 两年前写的,文章中解释了 contains() 和 compareDocumentPosition(
- 什么是多态?多态(Polymorphism)按字面的意思就是“多种状态”。在面向对象语言中,接口的多种不同的实现方式即为多态。引用Charl
- 如题,我有一个模板,我想根据需求复制模板中间的某一页多次,比如复制第五页,然后复制3次,那么第六页,第七页,第八页都是和第五页一模一样的pp
- 夹角余弦(Cosine)也可以叫余弦相似度。 几何中夹角余弦可用来衡量两个向量方向的差异,机器学习中借用这一概念来衡量样本向量之间的差异。(
- 注:本文所说的视觉设计师专指网页视觉设计师。网页设计师与平面设计师都归类为设计师,其实这两个职业是跨行业的,虽然有很多设计师一直在跨行业工作
- 问题微信公众号获取code时的跳转链接,默认是获取当前页面的链接,代码如下:// 说明:获取当前页面的url地址function GetCu
- 本文实例讲述了Node Express用法。分享给大家供大家参考,具体如下:安装npm install --save express基本使用
- 本文实例讲述了php+mysql开发的最简单在线题库。分享给大家供大家参考,具体如下:题库,对于教育机构,学校,在线教育,是很有必要的,网上
- 编写Python代码,大家都需要遵循PEP8,因此在pycharm中,如何设置每行最大长度限制,成为了一个小的知识盲点,在这里做一下记录,方
- 写在之前命名空间,又名 namesapce,是在很多的编程语言中都会出现的术语,估计很多人都知道这个词,但是让你真的来说这是个什么,估计就歇
- 这篇文章主要介绍了如何通过python实现人脸识别验证,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋
- 编写tasks.pyfrom celery import Celeryfrom tornado.httpclient import HTTP
- 前言MySQL 的权限表在数据库启动的时候就载入内存,当用户通过身份认证后,就在内存中进行相应权限的存取,这样,此用户就可以在数据库中做权限
- 从句法上看,协程与生成器类似,都是定义体中包含 yield 关键字的函数。可是,在协程中, yield 通常出现在表达式的右边(例如, da
- [编者注:]提起数据库,第一个想到的公司,一般都会是Oracle(即甲骨文公司)。Oracle在数据库领域一直处于领先地位。Oracle关系
- 本文主要是利用Python的第三方库Pillow,实现单通道灰度图像的颜色翻转功能。# -*- encoding:utf-8 -*-impo