Keras之自定义损失(loss)函数用法说明
作者:鹊踏枝 发布时间:2023-09-24 12:12:15
标签:Keras,自定义,损失,loss
在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:
def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
.
.
.
return scalar #返回一个标量值
然后在model.compile中指定即可,如:
model.compile(loss=my_loss, optimizer='sgd')
具体参考Keras官方metrics的定义keras/metrics.py:
"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
def binary_accuracy(y_true, y_pred):
return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
def categorical_accuracy(y_true, y_pred):
return K.cast(K.equal(K.argmax(y_true, axis=-1),
K.argmax(y_pred, axis=-1)),
K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
# reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
if K.ndim(y_true) == K.ndim(y_pred):
y_true = K.squeeze(y_true, -1)
# convert dense predictions to labels
y_pred_labels = K.argmax(y_pred, axis=-1)
y_pred_labels = K.cast(y_pred_labels, K.floatx())
return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
def top_k_categorical_accuracy(y_true, y_pred, k=5):
return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
# If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
axis=-1)
# Aliases
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
def serialize(metric):
return serialize_keras_object(metric)
def deserialize(config, custom_objects=None):
return deserialize_keras_object(config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='metric function')
def get(identifier):
if isinstance(identifier, dict):
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
elif isinstance(identifier, six.string_types):
return deserialize(str(identifier))
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'metric function identifier:', identifier)
来源:https://blog.csdn.net/u011501388/article/details/84030578


猜你喜欢
- 任务描述本次实践是一个多分类任务,需要将照片中的宝石分别进行识别,完成宝石的识别实践平台:百度AI实训平台-AI Studio、Paddle
- -- 任意的测试表 CREATE TABLE test_delete( name varchar(10), value INT ); go
- 我准备在ASP中连接MYSQL了,请问如何做?首先要正确安装MYSQLX,装好之后,可调用以下程序即可正常访问MYSQL:<%@&nb
- 在上篇文章给大家介绍了BootstrapTable与KnockoutJS相结合实现增删改查功能【一】,介绍了下knockout.js的一些基
- 一、什么是异常在python中,错误触发的异常如下二、异常的种类在python中不同的异常可以用不同的类型去标识,一个异常标识一种错误。1
- 本文实例讲述了Python使用中文正则表达式匹配指定中文字符串的方法。分享给大家供大家参考,具体如下:业务场景:从中文字句中匹配出指定的中文
- PyTorch之TensorDatasetTensorDataset 可以用来对 tensor 进行打包,就好像 pyt
- 小朋友你可能有很多问号~,上一小节不是已经一顿操作猛如虎搭建好 Python + PyCharm 可用开发环境了吗?为什么这节又来个项目运行
- Symfony是一个强大的基于PHP的Web开发框架,在这里我们用十分钟的时间来做一个简单的增删改查的程序, 任何不熟悉Symfony的人都
- 去年淘宝做了个“胖子”项目,就是把网页的默认宽度从780提升到了950。也就是说,基本放弃了800×600的用户(没有完全放弃,如果你仔细研
- 本文研究的主要是Django1.10文档的深入学习,Applications基础部分的相关内容,具体介绍如下。Applications应用D
- 前言对Python游戏有所了解的朋友都知道,在2D的游戏制作中,经常会用到一个模块pygame,他能帮助我们实现很多方便使用的功能,例如绘制
- 目录前言yarn create 做了什么源码解析项目依赖模版配置工具函数copycopyDiremptyDir核心函数命令行交互并创建文件夹
- --sql基本操作--创建数据库create database Studets--创建表create table student ( sno
- 本文实例讲述了js实现的星星评分功能函数。分享给大家供大家参考,具体如下:<!DOCTYPE html PUBLIC "-/
- 1. 新建项目在命令行窗口下输入scrapy startproject scrapytest, 如下然后就自动创建了相应的文件,如下2. 修
- (高手就不要笑话了^_^)。好了,其他的不说现在就开始:select 子句主要决定了从表中取出的列名,列数以及列的显示顺序等信息,"
- 一、写在前面作为一名测试,有时候经常会遇到需要录屏记录自己操作,方便后续开发同学定位。以前都是用ScreenToGif来录屏制作成动态图,偶
- 在python中enumerate的用法多用于在for循环中得到计数,本文即以实例形式向大家展现python中enumerate的用法。具体
- Pycharm创建的项目,使用了虚拟环境,对库的版本进行管理;有些项目的对第三方库的版本 要求不同,可使用虚拟环境进行管理直接想通过pip命