Tensorflow 自定义loss的情况下初始化部分变量方式
作者:I_will_____ 发布时间:2023-02-26 22:43:39
标签:Tensorflow,loss,初始化,变量
一般情况下,tensorflow里面变量初始化过程为:
#variables ...........
#.....................
init = tf.initialize_all_variables()
sess.run(init)
这里 tf.initialize_all_variables() 会初始化所有的变量。
实际过程中,假设有a, b, c三个变量,其中a已经被初始化了,只想单独初始化b,c,那么:
#variables ...
...
init = tf.variables_initializer([b,c])
sess.run(init)
此外,如果自行修改了optimizer,如下代码就会报错:
#definition of variables a, b, c ...
....
my_optimizer = tf.train.RMSProp(learning_rate = 0.1).minimize(my_cost)
init = tf.variables_initializer([b,c])
sess.run(init)
这是因为自己定义的optimizer会生成新的variables,但是在init里面并没有初始化,所以无法访问,会报错。解决方法如下:
a = tf.Variables(...) #line N
temp = set(tf.all_variables())
b = tf.Variables(...)
c = tf.Variables(...)
#definition of my optimizer
optimizer = tf.train.......
init = tf.variables_initializer(set(tf.all_varialbles())-temp) # line M
sess.run(init)
首先,temp = set(tf.all_variables()) 将该行(line N)代码之前的所有变量保存在temp中,接下来定义变量b, c,以及自定义的optimizer,然后 set(tf.all_varialbles()存储了改行(line M)之前的所有变量(包括optimizer生成的变量以及temp中所含的变量),set(tf.all_varialbles())-temp相减得到line N~M这几行定义的变量。
来源:https://blog.csdn.net/chenxicx1992/article/details/56483180


猜你喜欢
- 1、我们使用正常的输出语句得到的是(输出结果:division by zero)虽然得到了错误的日志输出,但是不知道为什么出错,也不能定位具
- keras中的Reshapekeras自带from keras.layers import Reshapelayer_1 = Reshape
- API:statuses/public_timeline 返回最新的200条公共微博,返回结果非完全实时CODE:#!/usr/
- 像在下拉菜单中选择省、市这样的操作,我一直用ASP来创建生成列表函数,把它们保存在一个Include文件中,用的时候就加载。这样做确实有个不
- 1. Shared and Exclusive Locksshared lock (译:共享锁)exclusive lock (
- 在任何编程语言中,检查字符串是否包含子字符串都是常见的任务。例如,假设您正在构建在线游戏。您可能需要检查用户名是否包含禁止使用的短语,以确保
- 如下所示:import numpy as np arr = [1,2,3,4,5,6]#求均值arr_mean = np.mean(arr)
- js代码:$(".head").change(function() {var val = $(this).val();i
- 一,问题因为我想在我的服务器上部署两个vue项目,但是vue打包后默认的项目名是dist,这样子就跟我上一个vue项目冲突了。因此查了一下资
- mat数据格式是Matlab默认保存的数据格式。在Python中,我们可以使用h5py库来读取mat文件。>>> impo
- 因为要批量用某软件处理一批eps文件,所以要模拟鼠标及键盘动作,使其能够自动化操作。#-*-coding:utf-8-*-import os
- 设计原理从结构上来说,一个简单的图形界面,需要由界面组件、组件的事件 * (响应各类事件的逻辑)和具体的事件处理逻辑组成。界面实现的主要工作
- 介绍观察者模式:是一种行为型设计模式。主要关注的是对象的责任,允许你定义一种订阅机制,可在对象事件发生时通知多个"观察"
- 基于 Vue 技术栈的你如果需要选用一种移动端跨平台框架,是 Weex?React-Native?还是Flutter? 无疑,相对于后两者,
- 1. 普通装饰器 import logging1. foo = use_loggine(foo) def use_loggine(func)
- 一.MYSQL的命令行模式的设置桌面->我的电脑->属性->环境变量->新建->PATH=“;path\mys
- 1. 确认已经安装了NT/2000和SQL Server的最新补丁程序,不用说大家应该已经安装好了,但是我觉得最好还是在这里提醒一
- bisect是python内置模块,用于有序序列的插入和查找。查找: bisect(array, item)插入: insort(array
- FCKeditor至今已经到了2.3.1版本了,对于国内的WEB开发者来说,也基本上都已经“闻风知多少”了,很多人将其融放到自己的项目中,更
- Oracle数据库以其高可靠性、安全性、可兼容性,得到越来越多的企业的青睐。如何使Oracle数据库保持优良性能,这是许多数据库管理员关心的