tensorflow实现训练变量checkpoint的保存与读取
作者:imumu_xi 发布时间:2023-12-15 18:10:33
标签:tensorflow,checkpoint,保存,读取
1.保存变量
先创建(在tf.Session()之前)saver
saver = tf.train.Saver(tf.global_variables(),max_to_keep=1) #max_to_keep这个保证只保存最后一次training的训练数据
然后在训练的循环里面
checkpoint_path = os.path.join(Path, 'model.ckpt') saver.save(session, checkpoint_path, global_step=step) #这里的step是循环训练的次数,也就是第几次迭代
以下保存的变量文件
2.变量读取
1.若要直接恢复所有变量可以
saver = tf.train.Saver(tf.global_variables())
moudke_file=tf.train.latest_checkpoint('PATH')
saver.restore(sess,moudke_file)
PATH是存放保存变量的路径,会自动找到最近保存的变量文件
2 若想读取其中一部分变量值
def read_checkpoint():
w = []
checkpoint_path = '/home/ximao/models/resnet3/variable_logs/model.ckpt-17000'
reader = tf.train.NewCheckpointReader(checkpoint_path)
var = reader.get_variable_to_shape_map()
for key in var:
if 'weights' in key and 'conv' in key and 'Mo' not in key:
print('tensorname:', key)
# # print(reader.get_tensor(key))
3. 若想恢复其中一部分变量值到新网络
(1)首先你要先获取你想要赋值新网络变量的变量名,这里变量名不是一个字符串,而是<name,shape,dtype>这样的一个结构,
然后把你要赋值的元素转为张量,最后把值赋给你得到变量名 如下:
var=[v for v in weight_pruned if v.op.name=='WRN/conv1/weights']
conv1_temp=tf.convert_to_tensor(conv1,dtype=tf.float32)
sess.run(tf.assign(var[0],conv1_temp))
weight_pruned 存放的是你新网络中所有的变量
来源:https://blog.csdn.net/sinat_30372583/article/details/79763044


猜你喜欢
- Django中提供了一个类Paginator专门用来管理和处理分页数据,所以我们在使用之前先导入好相应的类,,另外这里我们也导入了待会会用到
- 如何提高Request集合的使用效率?以加快程序处理速度: strTitle=Request.Form("Title&q
- Golang与python线程详解及简单实例在GO中,开启15个线程,每个线程把全局变量遍历增加100000次,因此预测结果是 15*100
- 【尊重原创,转载请注明出处】https://blog.csdn.net/guyuealian/article/details/7967225
- 一、生成器1、生成器定义在Python中,一边循环一边计算的机制,称为生成器:generator2、生成器存在的意义列表所有数据都在内存中,
- 在数据处理过程中,经常会出现对某列批量做某些操作,比如dataframe df要对列名为“values”做大于等于30设置为1,小于30设置
- xbox series和ps5发售以来,国内黄牛价格一直居高不下。虽然海外amazon上ps5补货很少而且基本撑不过一分钟,但是xbox s
- 最近在重构公司以前产品的前端代码,摈弃了以前的session-cookie鉴权方式,采用token鉴权,忙里偷闲觉得有必要对几种常见的鉴权方
- Python中的字符串对象是不能更改的,也即直接修改字符串中的某一位或几位字符是实现不了的,即python中字符串对象不可更改,但字符串对象
- 1.substring_index函数的语法及其用法(1)语法:substring_index(string,sep,num)即substr
- 用户权限管理主要有以下作用: 1. 可以限制用户访问哪些库、哪些表 2. 可以限制用户对哪些表执行SELECT、CREATE、DELETE、
- 一、报错error connecting to master 'x@x.x.x.x:x' - retry-time: 60&
- GoroutineGoroutine 是 Golang 提供的一种轻量级线程,我们通常称之为「协程」,相比较线程,创建一个协程的成本是很低的
- MySQL是一个非常流行的小型关系型数据库管理系统,2008年1月16号被Sun公司收购。目前MySQL被广泛地应用在Internet上的中
- 一、SQLAlchemy简介1.1、SQLAlchemy是什么?sqlalchemy是一个python语言实现的的针对关系型数据库的orm库
- 我的数据库和报表服务的版本如下:数据库:SQL Server 2008 R2报表服务:SQL Server 2008 R2 Reportin
- 在Centos中安装完MySQL数据库以后,不知道密码,这可怎么办,下面给大家说一下怎么重置密码1、修改配置文件my.cnf 按i编辑[ro
- 我目标文件夹下有一大批图片,我要把它转变为指定尺寸大小的图片,用pthon和opencv实现的。以上为原图片。import cv2impor
- 绘制双变量联合分布图有时我们不仅需要查看单个变量的分 布,同时也需要查看变量之间的联系, 往往还需要进行预测等。这时就需要用到双变量联合分布
- 从内部架构和理念划分,目前JavaScript框架可以划分为5类。第一种是以命名空间为导向的类库或框架,如果创建一个数组用new Array