tensorflow 恢复指定层与不同层指定不同学习率的方法
作者:跬步达千里 发布时间:2023-08-27 23:09:14
标签:tensorflow,指定层
如下所示:
#tensorflow 中从ckpt文件中恢复指定的层或将指定的层不进行恢复:
#tensorflow 中不同的layer指定不同的学习率
with tf.Graph().as_default():
#存放的是需要恢复的层参数
variables_to_restore = []
#存放的是需要训练的层参数名,这里是没恢复的需要进行重新训练,实际上恢复了的参数也可以训练
variables_to_train = []
for var in slim.get_model_variables():
excluded = False
for exclusion in fine_tune_layers:
#比如fine tune layer中包含logits,bottleneck
if var.op.name.startswith(exclusion):
excluded = True
break
if not excluded:
variables_to_restore.append(var)
#print('var to restore :',var)
else:
variables_to_train.append(var)
#print('var to train: ',var)
#这里省略掉一些步骤,进入训练步骤:
#将variables_to_train,需要训练的参数给optimizer 的compute_gradients函数
grads = opt.compute_gradients(total_loss, variables_to_train)
#这个函数将只计算variables_to_train中的梯度
#然后将梯度进行应用:
apply_gradient_op = opt.apply_gradients(grads, global_step=global_step)
#也可以直接调用opt.minimize(total_loss,variables_to_train)
#minimize只是将compute_gradients与apply_gradients封装成了一个函数,实际上还是调用的这两个函数
#如果在梯度里面不同的参数需要不同的学习率,那么可以:
capped_grads_and_vars = []#[(MyCapper(gv[0]), gv[1]) for gv in grads_and_vars]
#update_gradient_vars是需要更新的参数,使用的是全局学习率
#对于不是update_gradient_vars的参数,将其梯度更新乘以0.0001,使用基本上不动
for grad in grads:
for update_vars in update_gradient_vars:
if grad[1]==update_vars:
capped_grads_and_vars.append((grad[0],grad[1]))
else:
capped_grads_and_vars.append((0.0001*grad[0],grad[1]))
apply_gradient_op = opt.apply_gradients(capped_grads_and_vars, global_step=global_step)
#在恢复模型时:
with sess.as_default():
if pretrained_model:
print('Restoring pretrained model: %s' % pretrained_model)
init_fn = slim.assign_from_checkpoint_fn(
pretrained_model,
variables_to_restore)
init_fn(sess)
#这样就将指定的层参数没有恢复
来源:https://blog.csdn.net/LIYUAN123ZHOUHUI/article/details/69569493


猜你喜欢
- 本文实例为大家分享了js实现QQ邮箱邮件拖拽删除的具体代码,供大家参考,具体内容如下步骤分析:根据数据结构生成HTML结构全选和单选功能的实
- 我的页面上有一个下拉菜单,页面上有一个文本输入框,一个图像上传框,文本输入框默认是显示的,而图片上传框是隐藏的.假设下拉菜单有两项A和B,我
- 本文实例讲述了Python图像处理之颜色的定义与使用。分享给大家供大家参考,具体如下:python中的颜色相关的定义在matplotlib模
- 解决window.open后返回object的错误 <a href="javascript:void(window.open
- 关于 TensorFlowTensorFlow™ 是一个采用数据流图(data flow graphs),用于数值计算的开源软件库。节点(N
- 本文实例讲述了Python模块的制作方法。分享给大家供大家参考,具体如下:1 目的利用setup.py将框架安装到python环境中,作为第
- 通常程序会被编写为一个顺序执行并完成一个独立任务的代码。如果没有特别的需求,最好总是这样写代码,因为这种类型的程序通常很容易写,也很容易维护
- 本文实例为大家分享了python实现图像识别的具体代码,供大家参考,具体内容如下#! /usr/bin/env python from PI
- decode()方法使用注册编码的编解码器的字符串进行解码。它默认为默认的字符串编码。语法以下是decode()方法的语法:st
- 这篇文章主要介绍了python如何基于redis实现ip代理池,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,
- 前段时间看到letcode上的元音字母字符串反转的题目,今天来研究一下字符串反转的内容。主要有三种方法:1.切片法(最简洁的一种)#切片法d
- 本文实例为大家分享了python实现矩阵打印的具体代码,供大家参考,具体内容如下之前面试嵌入式软件的一道题,用c实现矩阵打印,考场上并没有写
- Geany中配置python的方法:一、文件下载并安装1、下载Python下载地址:https://www.python.org/downl
- 本文实例讲述了python迭代器的简单用法,分享给大家供大家参考。具体分析如下:生成器表达式是用来生成函数调用时序列参数的一种迭代器写法生成
- 安装redis服务1 下载redis cd /usr/local/ 进入安装目录 wget http://downl
- 正则表达式处理花括号内容替换赋值@Test public void replaceStr() { &
- 场景产品中有一张图片表pics,数据量将近100万条,有一条相关的查询语句,由于执行频次较高,想针对此语句进行优化表结构很简单,主要字段:u
- 假设我们有一个非常简单的Post模型,它将是一个图像及其描述,from django.db import modelsclass Post(
- 提高MySQL 查询效率的三个技巧小结MySQL由于它本身的小巧和操作的高效, 在数据库应用中越来越多的被采用.我在开发
- Template无疑是一个好东西,可以将字符串的格式固定下来,重复利用。同时Template也可以让开发人员可以分别考虑字符串的格式和其内容