keras打印loss对权重的导数方式
作者:HackerTom 发布时间:2023-05-17 18:21:11
标签:keras,loss,权重,导数
Notes
怀疑模型梯度 * ,想打印模型 loss 对各权重的导数看看。如果如果fit来训练的话,可以用keras.callbacks.TensorBoard实现。
但此次使用train_on_batch来训练的,用K.gradients和K.function实现。
Codes
以一份 VAE 代码为例
# -*- coding: utf8 -*-
import keras
from keras.models import Model
from keras.layers import Input, Lambda, Conv2D, MaxPooling2D, Flatten, Dense, Reshape
from keras.losses import binary_crossentropy
from keras.datasets import mnist, fashion_mnist
import keras.backend as K
from scipy.stats import norm
import numpy as np
import matplotlib.pyplot as plt
BATCH = 128
N_CLASS = 10
EPOCH = 5
IN_DIM = 28 * 28
H_DIM = 128
Z_DIM = 2
(x_train, y_train), (x_test, y_test) = fashion_mnist.load_data()
x_train = x_train.reshape(len(x_train), -1).astype('float32') / 255.
x_test = x_test.reshape(len(x_test), -1).astype('float32') / 255.
def sampleing(args):
"""reparameterize"""
mu, logvar = args
eps = K.random_normal([K.shape(mu)[0], Z_DIM], mean=0.0, stddev=1.0)
return mu + eps * K.exp(logvar / 2.)
# encode
x_in = Input([IN_DIM])
h = Dense(H_DIM, activation='relu')(x_in)
z_mu = Dense(Z_DIM)(h) # mean,不用激活
z_logvar = Dense(Z_DIM)(h) # log variance,不用激活
z = Lambda(sampleing, output_shape=[Z_DIM])([z_mu, z_logvar]) # 只能有一个参数
encoder = Model(x_in, [z_mu, z_logvar, z], name='encoder')
# decode
z_in = Input([Z_DIM])
h_hat = Dense(H_DIM, activation='relu')(z_in)
x_hat = Dense(IN_DIM, activation='sigmoid')(h_hat)
decoder = Model(z_in, x_hat, name='decoder')
# VAE
x_in = Input([IN_DIM])
x = x_in
z_mu, z_logvar, z = encoder(x)
x = decoder(z)
out = x
vae = Model(x_in, [out, out], name='vae')
# loss_kl = 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# loss_recon = binary_crossentropy(K.reshape(vae_in, [-1, IN_DIM]), vae_out) * IN_DIM
# loss_vae = K.mean(loss_kl + loss_recon)
def loss_kl(y_true, y_pred):
return 0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1)
# vae.add_loss(loss_vae)
vae.compile(optimizer='rmsprop',
loss=[loss_kl, 'binary_crossentropy'],
loss_weights=[1, IN_DIM])
vae.summary()
# 获取模型权重 variable
w = vae.trainable_weights
print(w)
# 打印 KL 对权重的导数
# KL 要是 Tensor,不能是上面的函数 `loss_kl`
grad = K.gradients(0.5 * K.sum(K.square(z_mu) + K.exp(z_logvar) - 1. - z_logvar, axis=1),
w)
print(grad) # 有些是 None 的
grad = grad[grad is not None] # 去掉 None,不然报错
# 打印梯度的函数
# K.function 的输入和输出必要是 list!就算只有一个
show_grad = K.function([vae.input], [grad])
# vae.fit(x_train, # y_train, # 不能传 y_train
# batch_size=BATCH,
# epochs=EPOCH,
# verbose=1,
# validation_data=(x_test, None))
''' 以 train_on_batch 方式训练 '''
for epoch in range(EPOCH):
for b in range(x_train.shape[0] // BATCH):
idx = np.random.choice(x_train.shape[0], BATCH)
x = x_train[idx]
l = vae.train_on_batch([x], [x, x])
# 计算梯度
gd = show_grad([x])
# 打印梯度
print(gd)
# show manifold
PIXEL = 28
N_PICT = 30
grid_x = norm.ppf(np.linspace(0.05, 0.95, N_PICT))
grid_y = grid_x
figure = np.zeros([N_PICT * PIXEL, N_PICT * PIXEL])
for i, xi in enumerate(grid_x):
for j, yj in enumerate(grid_y):
noise = np.array([[xi, yj]]) # 必须秩为 2,两层中括号
x_gen = decoder.predict(noise)
# print('x_gen shape:', x_gen.shape)
x_gen = x_gen[0].reshape([PIXEL, PIXEL])
figure[i * PIXEL: (i+1) * PIXEL,
j * PIXEL: (j+1) * PIXEL] = x_gen
fig = plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
fig.savefig('./variational_autoencoder.png')
plt.show()
补充知识:keras 自定义损失 自动求导时出现None
问题记录,keras 自定义损失 自动求导时出现None,后来想到是因为传入的变量没有使用,所以keras无法求出偏导,修改后问题解决。就是不愿使用的变量×0,求导后还是0就可以了。
def my_complex_loss_graph(y_label, emb_uid, lstm_out,y_true_1,y_true_2,y_true_3,out_1,out_2,out_3):
mse_out_1 = mean_squared_error(y_true_1, out_1)
mse_out_2 = mean_squared_error(y_true_2, out_2)
mse_out_3 = mean_squared_error(y_true_3, out_3)
# emb_uid= K.reshape(emb_uid, [-1, 32])
cosine_sim = tf.reduce_sum(0.5*tf.square(emb_uid-lstm_out))
cost=0*cosine_sim+K.sum([0.5*mse_out_1 , 0.25*mse_out_2,0.25*mse_out_3],axis=1,keepdims=True)
# print(mse_out_1)
final_loss = cost
return K.mean(final_loss)
来源:https://blog.csdn.net/HackerTom/article/details/90177044
0
投稿
猜你喜欢
- 有的时候,为了对python文件进行加密,会把python模块编译成.pyd文件,供其他人调用。拿到一个.pyd文件,在没有文档说明的情况下
- 1 安装Djangopython -m pip install django2 新建项目 my_apidjango-admin startp
- // 添加function col_add() { var selObj = $("#mySelect");&
- 1.如何定位并优化慢查询sqla.根据慢日志定位慢查询sqlSHOW VARIABLES LIKE '%query%' &
- 简介主要介绍事件总线的定义和编写方法和Vue是如何实现消息的订阅与发布的。事件总线事件总线是组件间通信的一种方式,适用于任意组件间的通信,比
- 代码如下:--Begin Index(索引) 分析优化的相关 Sql -- 返回当前数据库所有碎片率大于25%的索引 -- 运行
- MySQL清空表数据清空表数据一共有三种方式1 、truncate (速度很快) 自增字段清空从1开始 全表清空首选2、drop 直接删表&
- 具体内容如下所示:一、常用函数1、ASCII()返回字符表达式最左端字符的ASCII 码值。在ASCII()函数中,纯数字的字符串可不用‘&
- 为了方便快捷开发,有些常用的代码块可以直接在IDE编辑器中保存为一个代码块,用简写的方式快捷调取,常用的方法:先建一个模板分组并命令为myT
- 下载IDEA、PyCharm、PhpStorm免费激活码本次更新:2020年11月13 (定期更新)推荐教程:IntelliJ IDEA 2
- 本文总结了python画图中使用各种特殊符号方式一、问题背景在论文中,如何使用特殊符号进行表示?这里给出效果图和代码完整代码:from ma
- 较基础的SVM,后续会加上多分类以及高斯核,供大家参考。Talk is cheap, show me the codeimport tens
- 普通滑动验证以http://admin.emaotai.cn/login.aspx为例这类验证码只需要我们将滑块拖动指定位置,处理起来比较简
- 一、特效预览处理前处理后细节放大后二、程序原理1.输入你想隐藏的文字2.然后写到另一张跟照片同等大小的空白纸张上3.将相同位置的文字的颜色用
- 代码如下:ALTER proc [dbo].[sp_common_paypal_AddInfo] ( @paypalsql va
- 看了两天 go 语言,是时候练练手了。go 的 routine(例程) 和 chan(通道) 简直是神器,实现多线程(在 go 里准确的来说
- 前言gif图就是动态图,它的原理和视频有点类似,也是通过很多静态图片合成的.本篇文章主要介绍,如何利用Python快速合成gif图,主要利用
- python字符串字符串是 Python 中最常用的数据类型。我们可以使用引号('或")来创建字符串。创建字符串很简单,只
- 本文实例讲述了python中list循环语句用法。分享给大家供大家参考。具体用法分析如下:Python 的强大特性之一就是其对 list 的
- 在这里给出是的WindowsXP操作系统下的安装过程一、下载安装文件到MySQL官方网站找到ZIP文件提示:有些是安装文件,安装时会有提示,