tensorflow创建变量以及根据名称查找变量
作者:lijiao 发布时间:2023-08-13 10:13:06
环境:Ubuntu14.04,tensorflow=1.4(bazel源码安装),Anaconda python=3.6
声明变量主要有两种方法:tf.Variable和 tf.get_variable,二者的最大区别是:
(1) tf.Variable是一个类,自带很多属性函数;而 tf.get_variable是一个函数;
(2) tf.Variable只能生成独一无二的变量,即如果给出的name已经存在,则会自动修改生成新的变量name;
(3) tf.get_variable可以用于生成共享变量。默认情况下,该函数会进行变量名检查,如果有重复则会报错。当在指定变量域中声明可
以变量共享时,可以重复使用该变量(例如RNN中的参数共享)。
下面给出简单的的示例程序:
import tensorflow as tf
with tf.variable_scope('scope1',reuse=tf.AUTO_REUSE) as scope1:
x1 = tf.Variable(tf.ones([1]),name='x1')
x2 = tf.Variable(tf.zeros([1]),name='x1')
y1 = tf.get_variable('y1',initializer=1.0)
y2 = tf.get_variable('y1',initializer=0.0)
init = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(init)
print(x1.name,x1.eval())
print(x2.name,x2.eval())
print(y1.name,y1.eval())
print(y2.name,y2.eval())
输出结果为:
scope1/x1:0 [ 1.]
scope1/x1_1:0 [ 0.]
scope1/y1:0 1.0
scope1/y1:0 1.0
1. tf.Variable(…)
tf.Variable(…)使用给定初始值来创建一个新变量,该变量会默认添加到 graph collections listed in collections, which defaults to [GraphKeys.GLOBAL_VARIABLES]。
如果trainable属性被设置为True,该变量同时也会被添加到graph collection GraphKeys.TRAINABLE_VARIABLES.
# tf.Variable
__init__(
initial_value=None,
trainable=True,
collections=None,
validate_shape=True,
caching_device=None,
name=None,
variable_def=None,
dtype=None,
expected_shape=None,
import_scope=None,
constraint=None
)
2. tf.get_variable(…)
tf.get_variable(…)的返回值有两种情形:
使用指定的initializer来创建一个新变量;
当变量重用时,根据变量名搜索返回一个由tf.get_variable创建的已经存在的变量;
get_variable(
name,
shape=None,
dtype=None,
initializer=None,
regularizer=None,
trainable=True,
collections=None,
caching_device=None,
partitioner=None,
validate_shape=True,
use_resource=None,
custom_getter=None,
constraint=None
)
3. 根据名称查找变量
在创建变量时,即使我们不指定变量名称,程序也会自动进行命名。于是,我们可以很方便的根据名称来查找变量,这在抓取参数、finetune模型等很多时候都很有用。
示例1:
通过在tf.global_variables()变量列表中,根据变量名进行匹配搜索查找。 该种搜索方式,可以同时找到由tf.Variable或者tf.get_variable创建的变量。
import tensorflow as tf
x = tf.Variable(1,name='x')
y = tf.get_variable(name='y',shape=[1,2])
for var in tf.global_variables():
if var.name == 'x:0':
print(var)
示例2:
利用get_tensor_by_name()同样可以获得由tf.Variable或者tf.get_variable创建的变量。
需要注意的是,此时获得的是Tensor, 而不是Variable,因此 x不等于x1.
import tensorflow as tf
x = tf.Variable(1,name='x')
y = tf.get_variable(name='y',shape=[1,2])
graph = tf.get_default_graph()
x1 = graph.get_tensor_by_name("x:0")
y1 = graph.get_tensor_by_name("y:0")
示例3:
针对tf.get_variable创建的变量,可以利用变量重用来直接获取已经存在的变量。
with tf.variable_scope("foo"):
bar1 = tf.get_variable("bar", (2,3)) # create
with tf.variable_scope("foo", reuse=True):
bar2 = tf.get_variable("bar") # reuse
with tf.variable_scope("", reuse=True): # root variable scope
bar3 = tf.get_variable("foo/bar") # reuse (equivalent to the above)
print((bar1 is bar2) and (bar2 is bar3))
猜你喜欢
- CSS网页布局开发中,会有很多小技巧,这里再扩展一下您所想要得到的知识,相信您会有很多收获!一、ul标签在Mozilla中默认是有paddi
- 如何制作K线图?也不难,代码和说明见下:<%@ Language=VBScript %><%Respo
- 我一直使用Microsoft的FrontPage 98来开发ASP/ADO之类的Internet数据库应用程序。现在我听说许多人都非常信奉采
- 由于最近在处理shp文件,想要跳出arcpy的限制,所以打算学习一下pyshp包的使用方法。在使用《Python地理空间分析指南(第2版)》
- 在oracle数据库的开发中,常因为时间的问题大费周章,所以特地将ORACLE数据的日期函数收藏致此。乃供他日所查也。 add_months
- 第一种方法: 代码如下:/* 创建链接服务器 */ exec sp_addlinkedserver 'srv_lnk
- 本程序有两文件test.asp 和tree.asp 还有一些图标文件 1。test.asp 调用类生成树 代码如下<%@
- SQL触发器实例1 定义: 何为触发器?在SQL Server里面也就是对某一个表的一定的操作,触发某种条件,从而执行的一段程序。触发器是一
- 第一章:基本的圆角框第二章:透明圆角化背景图片第三章:圆角化图片 第四章:CSS圆角框组件 V1.0在上面的案例中,我只给出最为原始的圆角框
- 一、定位 oracle分两大块,一块是开发,一块是管理。开发主要是写写存储过程、触发器什么的,还有就是用Oracle的Develop工具做f
- 在应用系统开发初期,由于开发数据库数据比较少,对于查询SQL语句,复杂视图的编写,刚开始不会体会出SQL语句各种写法的性能优劣,但是如果将应
- 我刚进入5gsns的时候,我真不知道怎么玩,我是通过白鸦的博客过去的,之前也没有怎么去玩过这类的网站。对于sns网站还算是陌生,不过还好网站
- 1.1.1 摘要 Join是关系型数据库系统的重要操作之一,SQL Server中包含的常用Join:内联接、外联接和交叉联接等。如果我们想
- 背景:我们有一个用go做的项目,其中用到了zmq4进行通信,一个简单的rpc过程,早期远端是使用一个map去做ip和具体socket的映射。
- 匹配中文字符的正则表达式: [\u4e00-\u9fa5]评注:匹配中文还真是个头疼的事,有了这个表达式就好办了匹配双字节字符(包括汉字在内
- CSS的学习和其他的学习一样,都需要特定的方法才能比较快的去掌握它.要想掌握CSS, 首先要学会HTML,我刚开始是从零开始学习的
- 我们可以利用Session对象来进行注册验证。Session对象会帮我们把某一用户的信息保留下来,让后续的网页读取。我们就可以在用户注册成功
- golang 字符串 int uint int64 uint64 互转字符串 转 intintNum, _ = strconv.Atoi(i
- sql2000安全很重要将有安全问题的SQL过程删除.比较全面.一切为了安全!删除了调用shell,注册表,COM组件的破坏权限use&nb
- cmake-2.8.3.tar.gzmysql-5.5.8.tar.gz一,cmake-2.8.3的安装:tar -zxf cmake-2.