网络编程
位置:首页>> 网络编程>> Python编程>> Tensorflow 自定义loss的情况下初始化部分变量方式

Tensorflow 自定义loss的情况下初始化部分变量方式

作者:I_will_____  发布时间:2023-02-26 22:43:39 

标签:Tensorflow,loss,初始化,变量

一般情况下,tensorflow里面变量初始化过程为:


 #variables ...........
 #.....................
 init = tf.initialize_all_variables()
 sess.run(init)

这里 tf.initialize_all_variables() 会初始化所有的变量。

实际过程中,假设有a, b, c三个变量,其中a已经被初始化了,只想单独初始化b,c,那么:


 #variables ...
 ...
 init = tf.variables_initializer([b,c])
 sess.run(init)

此外,如果自行修改了optimizer,如下代码就会报错:


 #definition of variables a, b, c ...
 ....
 my_optimizer = tf.train.RMSProp(learning_rate = 0.1).minimize(my_cost)
 init = tf.variables_initializer([b,c])
 sess.run(init)

这是因为自己定义的optimizer会生成新的variables,但是在init里面并没有初始化,所以无法访问,会报错。解决方法如下:


 a = tf.Variables(...)      #line N
 temp = set(tf.all_variables())
 b = tf.Variables(...)
 c = tf.Variables(...)
 #definition of my optimizer
 optimizer = tf.train.......
 init = tf.variables_initializer(set(tf.all_varialbles())-temp) # line M
 sess.run(init)

首先,temp = set(tf.all_variables()) 将该行(line N)代码之前的所有变量保存在temp中,接下来定义变量b, c,以及自定义的optimizer,然后 set(tf.all_varialbles()存储了改行(line M)之前的所有变量(包括optimizer生成的变量以及temp中所含的变量),set(tf.all_varialbles())-temp相减得到line N~M这几行定义的变量。

来源:https://blog.csdn.net/chenxicx1992/article/details/56483180

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com