keras 自定义loss层+接受输入实例
作者:lgy_keira 发布时间:2023-09-23 16:37:55
loss函数如何接受输入值
keras封装的比较厉害,官网给的例子写的云里雾里,
在stackoverflow找到了答案
You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).
def custom_loss_wrapper(input_tensor):
def custom_loss(y_true, y_pred):
return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')
You can verify that input_tensor and the loss value will change as different X is passed to the model.
X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466
fit_generator
fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.
Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)
### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)
补充知识:keras中自定义 loss损失函数和修改不同样本的loss权重(样本权重、类别权重)
首先辨析一下概念:
1. loss是整体网络进行优化的目标, 是需要参与到优化运算,更新权值W的过程的
2. metric只是作为评价网络表现的一种“指标”, 比如accuracy,是为了直观地了解算法的效果,充当view的作用,并不参与到优化过程
一、keras自定义损失函数
在keras中实现自定义loss, 可以有两种方式,一种自定义 loss function, 例如:
# 方式一
def vae_loss(x, x_decoded_mean):
xent_loss = objectives.binary_crossentropy(x, x_decoded_mean)
kl_loss = - 0.5 * K.mean(1 + z_log_sigma - K.square(z_mean) - K.exp(z_log_sigma), axis=-1)
return xent_loss + kl_loss
vae.compile(optimizer='rmsprop', loss=vae_loss)
或者通过自定义一个keras的层(layer)来达到目的, 作为model的最后一层,最后令model.compile中的loss=None:
# 方式二
# Custom loss layer
class CustomVariationalLayer(Layer):
def __init__(self, **kwargs):
self.is_placeholder = True
super(CustomVariationalLayer, self).__init__(**kwargs)
def vae_loss(self, x, x_decoded_mean_squash):
x = K.flatten(x)
x_decoded_mean_squash = K.flatten(x_decoded_mean_squash)
xent_loss = img_rows * img_cols * metrics.binary_crossentropy(x, x_decoded_mean_squash)
kl_loss = - 0.5 * K.mean(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
return K.mean(xent_loss + kl_loss)
def call(self, inputs):
x = inputs[0]
x_decoded_mean_squash = inputs[1]
loss = self.vae_loss(x, x_decoded_mean_squash)
self.add_loss(loss, inputs=inputs)
# We don't use this output.
return x
y = CustomVariationalLayer()([x, x_decoded_mean_squash])
vae = Model(x, y)
vae.compile(optimizer='rmsprop', loss=None)
在keras中自定义metric非常简单,需要用y_pred和y_true作为自定义metric函数的输入参数 点击查看metric的设置
注意事项:
1. keras中定义loss,返回的是batch_size长度的tensor, 而不是像tensorflow中那样是一个scalar
2. 为了能够将自定义的loss保存到model, 以及可以之后能够顺利load model, 需要把自定义的loss拷贝到keras.losses.py 源代码文件下,否则运行时找不到相关信息,keras会报错
有时需要不同的sample的loss施加不同的权重,这时需要用到sample_weight,例如
discriminator.train_on_batch(imgs, [valid, labels], class_weight=class_weights)
二、keras中的样本权重
# Import
import numpy as np
from sklearn.utils import class_weight
# Example model
model = Sequential()
model.add(Dense(32, activation='relu', input_dim=100))
model.add(Dense(1, activation='sigmoid'))
# Use binary crossentropy loss
model.compile(optimizer='rmsprop',
loss='binary_crossentropy',
metrics=['accuracy'])
# Calculate the weights for each class so that we can balance the data
weights = class_weight.compute_class_weight('balanced',
np.unique(y_train),
y_train)
# Add the class weights to the training
model.fit(x_train, y_train, epochs=10, batch_size=32, class_weight=weights)
Note that the output of the class_weight.compute_class_weight() is an numpy array like this: [2.57569845 0.68250928].
来源:https://blog.csdn.net/u013608336/article/details/82559469


猜你喜欢
- 在go语言中,byte其实是uint8的别名,byte 和 uint8 之间可以直接进行互转。目前来只能将0~255范围的int转成byte
- 本文实例讲述了微信小程序module.exports模块化操作。分享给大家供大家参考,具体如下:文件 目录如上图:看到网上写的模块化都比较复
- 1.获取所有数据库名: SELECT Name FROM Master..SysDatabases ORDER BY Name2.获取所有表
- 本文实例为大家分享了python实现记事本功能的具体代码,供大家参考,具体内容如下1. 案例介绍tkinter 是 Python下面向 tk
- 一、爬山法简介爬山法(climbing method)是一种优化算法,其一般从一个随机的解开始,然后逐步找到一个最优解(局部最优)。 假定所
- Matplotlib配置了配色方案和默认设置,主要用来准备用于发布的图片。有两种方式可以设置参数,即全局参数定制和rc设置方法。查看matp
- 本文主要给大家介绍了关于Python中字典(dict)合并的四种方法,分享出来供大家参考学习,话不多说了,来一起看看详细的介绍:字典是Pyt
- 在北美,人们对于 PostgreSQL 的热情不断升温。随着 PostgreSQL 的发展, PostgreSQL 8.x 已经从技术上超越
- 前言本文做的是基于三层神经网络实现手写数字分类,神经网络设计是设计复杂深度学习算法应用的基础,本文将介绍如何设计一个三层神经网络模型来实现手
- #!/bin/env python # -*- coding: utf-8 -*- #filename: peartes
- 在开发过程中我们需要将我们的数据通过图标的形式展现出来,接下来我为大家介绍一个有趣的框架:Echarts。这是一个使用JavaScript实
- 有时候需要在网页中某个div载入之后,动态引入一段javascript,IE下的解决方案: newjs. onreadystatechang
- 一、前言为方便描述教程例子,这里给出mysql表结构定义和golang结构体定义。下面是教程用到的foods表结构定义:CREATE TAB
- QSpinBox 是一个计数器控件,允许用户选择一个整数值,通过单击向上/向下按钮或按键盘上的上/下箭头来增加/减少当前显示的值,当然用户也
- 本文实例讲述了PHP共享内存使用与信号控制。分享给大家供大家参考,具体如下:共享内存共享内存的使用主要是为了能够在同一台机器不同的进程中共享
- 在这可以用join()函数'x'.join(y),x可以是任意分割字符,y是列表或元组。以列表为例,可以将列表中的每一个元素
- 对于有的vps,系统默认安装了mysql。我们需要从我们的服务器、vps上卸载(移除)默认的mysql。那么如何(怎样)在ubuntu\De
- 画之前肯定要知道规格图,我找了一个大致的图。参考图片:绘制大星的方法很简单,五角星的补角是144度。绘制小五角星有点麻烦,因为我国国旗上的小
- function click(e) { if (document.all) { if (event.button==1||event.but
- 本文实例讲述了Python3爬虫爬取英雄联盟高清桌面壁纸功能。分享给大家供大家参考,具体如下:使用Scrapy爬虫抓取英雄联盟高清桌面壁纸源