keras实现theano和tensorflow训练的模型相互转换
作者:零落_World 发布时间:2023-04-18 05:49:26
标签:keras,theano,tensorflow
我就废话不多说了,大家还是直接看代码吧~
</pre><pre code_snippet_id="1947416" snippet_file_name="blog_20161025_1_3331239" name="code" class="python">
# coding:utf-8
"""
If you want to load pre-trained weights that include convolutions (layers Convolution2D or Convolution1D),
be mindful of this: Theano and TensorFlow implement convolution in different ways (TensorFlow actually implements correlation, much like Caffe),
and thus, convolution kernels trained with Theano (resp. TensorFlow) need to be converted before being with TensorFlow (resp. Theano).
"""
from keras import backend as K
from keras.utils.np_utils import convert_kernel
from text_classifier import keras_text_classifier
import sys
def th2tf( model):
import tensorflow as tf
ops = []
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
original_w = K.get_value(layer.W)
converted_w = convert_kernel(original_w)
ops.append(tf.assign(layer.W, converted_w).op)
K.get_session().run(ops)
return model
def tf2th(model):
for layer in model.layers:
if layer.__class__.__name__ in ['Convolution1D', 'Convolution2D']:
original_w = K.get_value(layer.W)
converted_w = convert_kernel(original_w)
K.set_value(layer.W, converted_w)
return model
def conv_layer_converted(tf_weights, th_weights, m = 0):
"""
:param tf_weights:
:param th_weights:
:param m: 0-tf2th, 1-th2tf
:return:
"""
if m == 0: # tf2th
tc = keras_text_classifier(weights_path=tf_weights)
model = tc.loadmodel()
model = tf2th(model)
model.save_weights(th_weights)
elif m == 1: # th2tf
tc = keras_text_classifier(weights_path=th_weights)
model = tc.loadmodel()
model = th2tf(model)
model.save_weights(tf_weights)
else:
print("0-tf2th, 1-th2tf")
return
if __name__ == '__main__':
if len(sys.argv) < 4:
print("python tf_weights th_weights <0|1>\n0-tensorflow to theano\n1-theano to tensorflow")
sys.exit(0)
tf_weights = sys.argv[1]
th_weights = sys.argv[2]
m = int(sys.argv[3])
conv_layer_converted(tf_weights, th_weights, m)
补充知识:keras学习之修改底层为TensorFlow还是theano
我们知道,keras的底层是TensorFlow或者theano
要知道我们是用的哪个为底层,只需要import keras即可显示
修改方法:
打开
修改
来源:https://blog.csdn.net/cdj0311/article/details/52918687
0
投稿
猜你喜欢
- 这篇文章主要介绍了django序列化serializers过程解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价
- 一、re.findall函数介绍它在re.py中有定义:def findall(pattern, string, flags=0): &nb
- 本文主要研究的是selenium python浏览器多窗口处理的相关内容,分享了操作实例代码,具体如下:#!/usr/bin/python#
- 大家好,我们的数据库已经介绍完了,这里给大家总结一下。我们这段主要是学习了SQL的增删改查语句,其中查询是我们的重点。我们是以SQL Ser
- 在 Python 中,列表是一种非常常见且强大的数据类型。但有时候,我们需要从一个列表中删除特定元素,尤其是当这个元素出现多次时。本文将介绍
- 当程序中出现错误时怎么解决?也就是我们所说的bug(缺陷),以及工作中如何对bug进行调试❤ 什么是bug(缺陷)软件缺陷就是通
- c shell perl php下的日期时间转换: 秒数与人类可读日期 scalar localtime 与 seconds since `
- 本文实例讲述了javascript设置页面背景色及背景图片的方法。分享给大家供大家参考,具体如下:<!DOCTYPE HTML PUB
- 编写兼容IE和FireFox的脚本确定的件很烦人的事,今日又经历了一次。一、正式表达式问题试图用以下表达式提取中括号“]”后面的内容,连接调
- 而Easp类中提供了大量实用的ASP通用过程及方法,可以简化大部分的ASP操作。目前只提供了VBScript版,JScript版将来可能会提
- Sql Server 中一个非常强大的日期格式化函数: 获得当前系统时间,GETDATE(): 2008年01月08日 星期二 14:59
- 有时候需要在终端显示彩色的字符,即根据需要显示不同颜色的字符串,比如我们要在终端打印一行错误提示信息,要把它弄成红色的。其实这个在Pytho
- <% '************************************************
- 当列表菜单项目特别多的时候,使用JavaScript手风琴菜单(Accordion Menus)是个不错的选择。手风琴折叠菜单利于组织菜单项
- Django将秒转换为xx天xx时xx分,具体代码如下所示:from django.utils.translation import nge
- 1. PHP入侵检测系统PHP IDS(即PHP-入侵检测系统)是一套易于使用、结构良好、速度出色且专门面向PHP类Web应用程序的先进安全
- 本文实例为大家分享了opencv实现双边滤波的具体代码,供大家参考,具体内容如下1、2D卷积#!/usr/bin/env python3#
- 谁在用这些导航google是个大公司,全世界都有google的脚印,韩国的google动画效果非常不错,蓝色理想论坛里已经有人挖过来了,可惜
- urllib的基本用法urllib库的基本组成利用最简单的urlopen方法爬取网页html利用Request方法构建headers模拟浏览
- 设计方法曾经是个很尴尬的话题,因为经常看上去很美。专业人士们动手动脚折腾一大圈,出来的结果令人大跌眼镜。也有些设计师总喜欢把方法、概念吹的特