浅谈keras中loss与val_loss的关系
作者:lgy_keira 发布时间:2021-12-12 08:41:22
loss函数如何接受输入值
keras封装的比较厉害,官网给的例子写的云里雾里,
在stackoverflow找到了答案
You can wrap the loss function as a inner function and pass your input tensor to it (as commonly done when passing additional arguments to the loss function).
def custom_loss_wrapper(input_tensor):
def custom_loss(y_true, y_pred):
return K.binary_crossentropy(y_true, y_pred) + K.mean(input_tensor)
return custom_loss
input_tensor = Input(shape=(10,))
hidden = Dense(100, activation='relu')(input_tensor)
out = Dense(1, activation='sigmoid')(hidden)
model = Model(input_tensor, out)
model.compile(loss=custom_loss_wrapper(input_tensor), optimizer='adam')
You can verify that input_tensor and the loss value will change as different X is passed to the model.
X = np.random.rand(1000, 10)
y = np.random.randint(2, size=1000)
model.test_on_batch(X, y) # => 1.1974642
X *= 1000
model.test_on_batch(X, y) # => 511.15466
fit_generator
fit_generator ultimately calls train_on_batch which allows for x to be a dictionary.
Also, it could be a list, in which casex is expected to map 1:1 to the inputs defined in Model(input=[in1, …], …)
### generator
yield [inputX_1,inputX_2],y
### model
model = Model(inputs=[inputX_1,inputX_2],outputs=...)
补充知识:学习keras时对loss函数不同的选择,则model.fit里的outputs可以是one_hot向量,也可以是整形标签
我就废话不多说了,大家还是直接看代码吧~
from __future__ import absolute_import, division, print_function, unicode_literals
import tensorflow as tf
from tensorflow import keras
import numpy as np
import matplotlib.pyplot as plt
print(tf.__version__)
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()
class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat',
'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
# plt.figure()
# plt.imshow(train_images[0])
# plt.colorbar()
# plt.grid(False)
# plt.show()
train_images = train_images / 255.0
test_images = test_images / 255.0
# plt.figure(figsize=(10,10))
# for i in range(25):
# plt.subplot(5,5,i+1)
# plt.xticks([])
# plt.yticks([])
# plt.grid(False)
# plt.imshow(train_images[i], cmap=plt.cm.binary)
# plt.xlabel(class_names[train_labels[i]])
# plt.show()
model = keras.Sequential([
keras.layers.Flatten(input_shape=(28, 28)),
keras.layers.Dense(128, activation='relu'),
keras.layers.Dense(10, activation='softmax')
])
model.compile(optimizer='adam',
loss='categorical_crossentropy',
#loss = 'sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
metrics=['accuracy'])
one_hot_train_labels = keras.utils.to_categorical(train_labels, num_classes=10)
model.fit(train_images, one_hot_train_labels, epochs=10)
one_hot_test_labels = keras.utils.to_categorical(test_labels, num_classes=10)
test_loss, test_acc = model.evaluate(test_images, one_hot_test_labels)
print('\nTest accuracy:', test_acc)
# predictions = model.predict(test_images)
# predictions[0]
# np.argmax(predictions[0])
# test_labels[0]
loss若为loss=‘categorical_crossentropy', 则fit中的第二个输出必须是一个one_hot类型,
而若loss为loss = ‘sparse_categorical_crossentropy' 则之后的label不需要变成one_hot向量,直接使用整形标签即可
来源:https://blog.csdn.net/u013608336/article/details/82559469
猜你喜欢
- 背景中秋的时候,一个朋友给我发了一封邮件,说他在爬链家的时候,发现网页返回的代码都是乱码,让我帮他参谋参谋(中秋加班,真是敬业= =!),其
- PyQt5动态(可拖动控件大小)布局控件QSplitter简介PyQt还提供了特殊的布局管理器QSplitter。它可以动态地拖动子控件之间
- 正则表达式是处理字符串的强大工具。作为一个概念而言,正则表达式对于Python来说并不是独有的。但是,Python中的正则表达式在实际使用过
- 介绍一下,如何在php程序中运行Python脚本,在php中python程序的运行,主要依靠 程序执行函数,这里说一下三个相关函数:exec
- 前言功能新增学生显示学生查找学生删除学生存到文档创建入口函数在入口函数中,可以先打印一个菜单,用菜单来进行交互。def menu(): &n
- 一切皆是对象在 Python 一切皆是对象,包括所有类型的常量与变量,整型,布尔型,甚至函数。 参见stackoverflow上的一个问题
- 关键路径计算是项目管理中关于进度管理的基本计算。 但是对于绝大多数同学来说, 关键路径计算都只是对一些简单情形的计算。今天,田老师根据以往的
- LearningjQuery.com 博客帖子列表的左边有一个很酷的日期,如图:从图中我们看到,“2009”垂直排列在右侧。用Firebug
- defer介绍defer是golang的一个特色功能,被称为“延迟调用函数”。当外部函数返回后执行defer。类似于其他语言的 try… c
- 如何设置list步长示例:range(a, b, step)>>> list(range(0,5,2)) [0,
- 因些朋友发来邮件讲根据文章修改后无效,懒羊再次检查后发现在工具栏中并无添加,所以还得做一下下面步骤,再此给大家造成的不便还请多多谅解!因FC
- 很久以前写过如何成为优秀的设计师,近半年来经常做设计评审,有很多感触,顺便写一点下来,我们的Blog也应该有更高的更新频率。言归正传,我认为
- 在很多企业会使用闲置的 Windows 机器作为临时服务器,有时候我们想远程调用里面的程序或查看日志文件Windows 内置的服务
- 需求描述我们需要登录考勤系统(网页端,非手机端)进行签到,如果不想每天都早早起来打卡签到,就可以通过写程序实现这一功能。业务梳理通过长时间的
- 阅读上一篇:javascript面向对象编程(一)[javascript模拟传统OOP]javascript是一种非常灵活的语言,它的灵活度
- 不知道大家有没发现DWMX中有一个和FW差不多的制作弹出菜单功能?这个功能允许用文字和图片做为主菜单,如果用文字的话要先做虚拟链接。下面简单
- python实现PSO算法优化二元函数,具体代码如下所示:import numpy as np import random import m
- 最近随着狂风计划的席卷,我也终于开始橱窗产品位列表展示的编码工作,这只是一个改进项目,因此有原代码可供参考。但是当我打开原代码模板的时候便愣
- 基本开发环境· Python 3.6· Pycharm相关模块使用import requestsimport timefrom tkinte
- 这段时间服务器崩溃2次,一直没有找到原因,今天看到论坛发出的错误信息邮件,想起可能是mysql的默认连接数引起的问题,一查果然,老天,默认