Keras之自定义损失(loss)函数用法说明
作者:鹊踏枝 发布时间:2023-09-24 12:12:15
标签:Keras,自定义,损失,loss
在Keras中可以自定义损失函数,在自定义损失函数的过程中需要注意的一点是,损失函数的参数形式,这一点在Keras中是固定的,须如下形式:
def my_loss(y_true, y_pred):
# y_true: True labels. TensorFlow/Theano tensor
# y_pred: Predictions. TensorFlow/Theano tensor of the same shape as y_true
.
.
.
return scalar #返回一个标量值
然后在model.compile中指定即可,如:
model.compile(loss=my_loss, optimizer='sgd')
具体参考Keras官方metrics的定义keras/metrics.py:
"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object
def binary_accuracy(y_true, y_pred):
return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)
def categorical_accuracy(y_true, y_pred):
return K.cast(K.equal(K.argmax(y_true, axis=-1),
K.argmax(y_pred, axis=-1)),
K.floatx())
def sparse_categorical_accuracy(y_true, y_pred):
# reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
if K.ndim(y_true) == K.ndim(y_pred):
y_true = K.squeeze(y_true, -1)
# convert dense predictions to labels
y_pred_labels = K.argmax(y_pred, axis=-1)
y_pred_labels = K.cast(y_pred_labels, K.floatx())
return K.cast(K.equal(y_true, y_pred_labels), K.floatx())
def top_k_categorical_accuracy(y_true, y_pred, k=5):
return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
# If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
axis=-1)
# Aliases
mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity
def serialize(metric):
return serialize_keras_object(metric)
def deserialize(config, custom_objects=None):
return deserialize_keras_object(config,
module_objects=globals(),
custom_objects=custom_objects,
printable_module_name='metric function')
def get(identifier):
if isinstance(identifier, dict):
config = {'class_name': str(identifier), 'config': {}}
return deserialize(config)
elif isinstance(identifier, six.string_types):
return deserialize(str(identifier))
elif callable(identifier):
return identifier
else:
raise ValueError('Could not interpret '
'metric function identifier:', identifier)
来源:https://blog.csdn.net/u011501388/article/details/84030578
![](https://www.aspxhome.com/images/zang.png)
![](https://www.aspxhome.com/images/jiucuo.png)
猜你喜欢
- python实现12306余票查询我们说先在浏览器中打开开发者工具(F12),尝试一次余票的查询,通过开发者工具查看发出请求的包余票查询界面
- 首先,我们需要着重介绍一些概念,以给你提供一些使这个“奇迹”得以发生的组成部分。太轻易地泄露伏笔对于讲故事来说不是个好的形式,所以那些不愿意
- 在实际应用中,我们经常需要使用定时器去触发一些事件。Python中通过线程实现定时器timer,其使用非常简单。看示例:import thr
- 实现二维平面上散点的绘制,并可以给每个散点标记序号或者名称:import numpy as npimport matplotlib.pypl
- 本文实例讲述了python从sqlite读取并显示数据的方法。分享给大家供大家参考。具体实现方法如下:import cgi, os, sys
- Python 操作文件编程语言对文件系统的操作是一项必不可少的功能,各种编程语言基本上都有对文件系统的操作,最简洁的莫过于linux里面sh
- 一、'建立register.asp 代码如下:<%@ language=vbscript %>&nb
- 注意:这种方法十分受光线变化影响自己在家拿着手机瞎晃的成果图:源代码:# -*- coding: utf-8 -*- ""
- 使用微信获取地址信息是和微信支付一道申请的,微信支付申请通过,就可以使用该功能。微信商城中,使用微信支付获取用户的收货地址,可以省略用户输入
- 前面的python3入门系列基本上也对python入了门,从这章起就开始介绍下python的爬虫教程,拿出来给大家分享;爬虫说的简单,就是去
- 导语哈喽!我是木木子,又到了今日更新时刻!我们来看看写什么呢?小编有个好兄弟最近在追妹子,跟妹子打得火热!就差临门一脚了,这一jio我帮忙补
- 1、基本原理访问网站扫码登录页,网站给浏览器返回一个二维码和一个唯一标志KEY浏览器开启定时轮询服务器,确认KEY对应的扫码结果用户使用ap
- Pycharm本身并不带编译器,所以第一次用需要自己下载编译器插件。1、首先去 https://www.python.org/downloa
- 本文实例讲述了Python实现截屏的函数。分享给大家供大家参考。具体如下:1.可指定保存目录.2.截屏图片名字以时间为文件名3.截屏图片存为
- 下面给大家分享Python爬虫后获取重定向url的两种方法,具体内容如下所示;方法(一)# 获得重定向url from urllib imp
- pytorch中如何只让指定变量向后传播梯度?(或者说如何让指定变量不参与后向传播?)有以下公式,假如要让L对xvar求导:(1)中,L对x
- 为新项目写的一份规范文档, 分享给大家. 我想前端开发过程中, 无论是团队开发, 还是单兵做站, 有一份开发文档做规范, 对开发工作都是很有
- 作用:可以清空此文件所在的web站点所有文件,将文件内容清零.运行完毕所有文件大小都变成0字节.此代码本人原创,转载请注明转自本站,谢谢合作
- 获得某层tensor的输出维度代码如下所示:from keras import backend as K@wraps(Conv2D)def
- 一、破解原理其实原理很简单,一句话概括就是「大力出奇迹」,Python 有两个压缩文件库:zipfile 和 rarfile,这两个库提供的