
Keras: Using Custom Loss Functions

Author: 鹊踏枝  Published: 2023-09-24 12:12:15

Tags: Keras, custom, loss

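Keras lets you define your own loss function. The one thing to watch is the function signature, which Keras fixes: the function must take the true labels and the predictions, in that order, as follows: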


def my_loss(y_true, y_pred):
    # y_true: true labels, a TensorFlow/Theano tensor
    # y_pred: predictions, a TensorFlow/Theano tensor of the same shape as y_true
    ...
    return loss  # a scalar, or one loss value per sample (Keras averages these into a scalar)
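To make the return value concrete, the sketch below uses NumPy as a stand-in for the backend tensor operations (it is not Keras code). It mirrors the pattern of Keras's built-in losses, which reduce over the last axis and so yield one loss value per sample. The function name mae_per_sample and the sample arrays are hypothetical.

```python
import numpy as np


def mae_per_sample(y_true, y_pred):
    # Reduce over the last (feature) axis only, giving one loss value per sample.
    return np.mean(np.abs(y_pred - y_true), axis=-1)


# Two samples with two outputs each (made-up values).
y_true = np.array([[1.0, 2.0], [3.0, 4.0]])
y_pred = np.array([[1.5, 2.5], [3.0, 2.0]])

per_sample = mae_per_sample(y_true, y_pred)  # shape (2,): one value per sample
batch_loss = per_sample.mean()               # averaging the batch gives a single scalar

print(per_sample, batch_loss)
```

In a real loss function the same computation would be written with backend tensor operations rather than NumPy, so that gradients can flow through it.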

Then pass the function to model.compile, for example:

model.compile(loss=my_loss, optimizer='sgd')
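As a fuller illustration, here is a minimal sketch of a custom loss wired into a small model. It assumes TensorFlow 2.x with its bundled Keras (tf.keras). The loss name weighted_mse, its doubled penalty for under-prediction, the toy model, and the random data are all hypothetical choices made for this example.

```python
import numpy as np
import tensorflow as tf


def weighted_mse(y_true, y_pred):
    """Hypothetical loss: squared error, doubled where the model under-predicts."""
    y_true = tf.cast(y_true, y_pred.dtype)
    error = y_true - y_pred
    # Weight is 2.0 where y_pred < y_true, otherwise 1.0.
    weight = 1.0 + tf.cast(error > 0, error.dtype)
    # Reduce over the last axis: one loss value per sample.
    return tf.reduce_mean(weight * tf.square(error), axis=-1)


# Toy regression model (assumed architecture, for illustration only).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(1),
])

# The custom function is passed exactly like a built-in loss.
model.compile(loss=weighted_mse, optimizer="sgd")

# Random data, just to show that training runs end to end.
x = np.random.rand(32, 4).astype("float32")
y = np.random.rand(32, 1).astype("float32")
history = model.fit(x, y, epochs=2, batch_size=8, verbose=0)
print(history.history["loss"])
```

The function receives tensors rather than NumPy arrays, so its body should use only tensor operations.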

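For reference, see how Keras itself defines its built-in metrics in keras/metrics.py; metric functions take the same (y_true, y_pred) arguments as a loss function: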


"""Built-in metrics.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six
from . import backend as K
from .losses import mean_squared_error
from .losses import mean_absolute_error
from .losses import mean_absolute_percentage_error
from .losses import mean_squared_logarithmic_error
from .losses import hinge
from .losses import logcosh
from .losses import squared_hinge
from .losses import categorical_crossentropy
from .losses import sparse_categorical_crossentropy
from .losses import binary_crossentropy
from .losses import kullback_leibler_divergence
from .losses import poisson
from .losses import cosine_proximity
from .utils.generic_utils import deserialize_keras_object
from .utils.generic_utils import serialize_keras_object

def binary_accuracy(y_true, y_pred):
    return K.mean(K.equal(y_true, K.round(y_pred)), axis=-1)

def categorical_accuracy(y_true, y_pred):
    return K.cast(K.equal(K.argmax(y_true, axis=-1),
                          K.argmax(y_pred, axis=-1)),
                  K.floatx())

def sparse_categorical_accuracy(y_true, y_pred):
    # reshape in case it's in shape (num_samples, 1) instead of (num_samples,)
    if K.ndim(y_true) == K.ndim(y_pred):
        y_true = K.squeeze(y_true, -1)
    # convert dense predictions to labels
    y_pred_labels = K.argmax(y_pred, axis=-1)
    y_pred_labels = K.cast(y_pred_labels, K.floatx())
    return K.cast(K.equal(y_true, y_pred_labels), K.floatx())

def top_k_categorical_accuracy(y_true, y_pred, k=5):
    return K.mean(K.in_top_k(y_pred, K.argmax(y_true, axis=-1), k), axis=-1)

def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
    # If the shape of y_true is (num_samples, 1), flatten to (num_samples,)
    return K.mean(K.in_top_k(y_pred, K.cast(K.flatten(y_true), 'int32'), k),
                  axis=-1)

# Aliases

mse = MSE = mean_squared_error
mae = MAE = mean_absolute_error
mape = MAPE = mean_absolute_percentage_error
msle = MSLE = mean_squared_logarithmic_error
cosine = cosine_proximity

def serialize(metric):
    return serialize_keras_object(metric)

def deserialize(config, custom_objects=None):
    return deserialize_keras_object(config,
                                    module_objects=globals(),
                                    custom_objects=custom_objects,
                                    printable_module_name='metric function')

def get(identifier):
    if isinstance(identifier, dict):
        config = {'class_name': str(identifier), 'config': {}}
        return deserialize(config)
    elif isinstance(identifier, six.string_types):
        return deserialize(str(identifier))
    elif callable(identifier):
        return identifier
    else:
        raise ValueError('Could not interpret '
                         'metric function identifier:', identifier)
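To see what two of the metrics in the listing compute, the sketch below re-expresses binary_accuracy and categorical_accuracy in NumPy. These are stand-ins for illustration, not the Keras implementations; the _np function names and the sample arrays are hypothetical.

```python
import numpy as np


def binary_accuracy_np(y_true, y_pred):
    # Round probabilities to 0/1, compare with the labels, average over the last axis.
    return np.mean(np.equal(y_true, np.round(y_pred)), axis=-1)


def categorical_accuracy_np(y_true, y_pred):
    # Compare the class index of the one-hot label with the highest-scoring prediction.
    matches = np.argmax(y_true, axis=-1) == np.argmax(y_pred, axis=-1)
    return matches.astype("float32")


# One sample with four binary outputs; three of the four rounded predictions match.
bin_true = np.array([[1.0, 0.0, 1.0, 1.0]])
bin_pred = np.array([[0.9, 0.2, 0.4, 0.7]])
bin_acc = binary_accuracy_np(bin_true, bin_pred)

# Two samples, three classes; only the first sample is classified correctly.
cat_true = np.array([[0.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
cat_pred = np.array([[0.1, 0.8, 0.1], [0.2, 0.5, 0.3]])
cat_acc = categorical_accuracy_np(cat_true, cat_pred)

print(bin_acc, cat_acc)
```

Both functions follow the same (y_true, y_pred) convention as a loss function, which is why the metrics file is a useful template when writing a custom loss.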

Source: https://blog.csdn.net/u011501388/article/details/84030578
