keras的get_value运行越来越慢的解决方案
作者:头狼586 发布时间:2023-10-22 14:53:49
keras 深度学习框架中get_value函数运行越来越慢,内存消耗越来越大问题
问题描述
如上图所示,经过时间和内存消耗跟踪测试,发现是keras.backend.get_value() 函数导致的程序越来越慢,而且严重的造成内存泄露;
查看该函数内部实现,发现一个主要核心是x.eval(session=get_session()),该语句可能是导致内存泄露和运行慢的核心语句; 根据查看一些博文得到了运行得越来越慢的
原因:该x.eval函数会添加新的节点到tf的图中;而这也导致了tf的图越来越大,内存泄露;
解决方法
import tensorflow.keras.backend as K
def get_my_session(gpu_fraction=0.1):
'''Assume that you have 6GB of GPU memory and want to allocate ~2GB'''
num_threads = os.environ.get('OMP_NUM_THREADS')
gpu_options = tf.GPUOptions(per_process_gpu_memory_fraction=gpu_fraction)
if num_threads:
return tf.Session(config=tf.ConfigProto(
gpu_options=gpu_options, intra_op_parallelism_threads=num_threads))
else:
return tf.Session(config=tf.ConfigProto(gpu_options=gpu_options))
K.set_session(get_my_session())
如上图所示, 我在使用tensorflow之前(也就是该工程文件前面),对session进行自定义,然后用自定义的session设定keras.backend.set_session();
然后删除get_value() 函数,直接用get_value()中所使用的执行语句x.eval(session=get_my_session());这样这个添加节点导致内存泄露的核心语句x.eval()就使用的是该工程统一自定义session,然后用tf.reset_default_graph() 对图重置就可以了
即上图问题代码修改为:
output = ctc_decode(y_pred,input_length=input_length,)
output = output[0][0]
out = output.eval(session=get_my_session())
# 删除 K.get_value(out[0][0])
tf.reset_default_graph() # 然后重置tf图,这句很关键
这样就解决了get_value()导致的越来越慢的问题;
个人认为:这样可能就不会总是添加新的节点,导致tf图不断地无限变大;而是重复使用这一个自定义的节点。
补充:tensorflow与keras之间版本问题引起get_session问题解决办法
1.产生报错原因
import tensorflow.keras.backend as K
def __init__(self, **kwargs):
self.__dict__.update(self._defaults) # set up default values
self.__dict__.update(kwargs) # and update with user overrides
self.class_names = self._get_class()
self.anchors = self._get_anchors()
self.sess = K.get_session()
报错如下:
get_session is not available when using TensorFlow 2.0.
意思是 tf2.0 没有 get_session
2.解决方案1
import tensorflow.python.keras.backend as K
sess = K.get_session()
3. 解决方案2
import tensorflow as tf
sess = tf.compat.v1.keras.backend.get_session()
之前一直采用方案1 解决,感觉比较方便;但是解决方案1 有其它属性会丢失问题
比如AttributeError: module ‘keras.backend' has no attribute image_dim_ordering
所以建议大家采用方案2
来源:https://blog.csdn.net/mingshili/article/details/81941677


猜你喜欢
- 1.MySQL官网下载压缩版文件,放至安装路径下载zip安装包MySQL :: Download MySQL Community Serve
- 前言相关一些检测工具挺多的,比如powertop、powerstat、s-tui等。但如何通过代码的方式来实时检测,是个麻烦的问题。通过许久
- 另外他们列出的这些区别有些是蛮有意义的,有些可能由于他们本人的MySQL DBA的身份,对Oracle的理解有些偏差,有些则有凑数的嫌疑.
- 本文实例为大家分享了js简单计算器的实现代码,供大家参考,具体内容如下1.html代码 <input type="text&
- ASP 错误代码 说明 ASP 0100 内存不足 ASP 0101 意外错误 ASP 0102 需要字符串输入 ASP 0103 需要数字
- 一. 什么是Selenium?网络爬虫是Python编程中一个非常有用的技巧,它可以让您自动获取网页上的数据。在本文中,我们将介绍如何使用S
- 1.Python3读取hdf文件最开始使用Python导入pyhdf包的时候是可以的,但是当导入pyhdf.SD的时候就出现了以下问题:我查
- 目录1、请求模块:urllib.requestdata参数:post请求urlopen()中的参数timeout:设置请求超时时间:响应类型
- 1. Single array iteration>>> a = np.arange(6).reshape(2,3)>
- 本篇概要1.线程与多线程2.进程与多进程3.多线程并发下载图片4.多进程并发提高数字运算关于并发在计算机编程领域,并发编程是一个很常见的名词
- 在用tensorflow做一维的卷积神经网络的时候会遇到tf.nn.conv1d和layers.conv1d这两个函数,但是这两个函数有什么
- 修改闭包内使用的外部变量错误示例:# 定义一个外部函数def func_out(num1): # 定义一个内部函数
- 一、注意你的Python版本Python官方网站为http://www.python.org/,当前最新稳定版本为3.6.5,在3.0版本时
- 在网站的一些应用中需要提供用户直接打印页面的功能,最明显的就是电子优惠券,商家根据网站提供的模板输入内容,然后生成优惠券页面,用户打印这个页
- 表单介绍说到表单,在HTML中表单的创建时通过<form>标签实现的,在<form>标签内部,字段通过使用<i
- 虽然说标题将的是首页的访问感受,但是同样适合于网站其它页面的用户体验设计,一个好的网站设计应当尽量做到首页和次页一视同仁。第一步(视觉设计)
- 在python中,普通的列表list和numpy中的数组array是不一样的,最大的不同是:一个列表中可以存放不同类型的数据,包括int、f
- startswith()方法Python startswith() 方法用于检查字符串是否是以指定子字符串开头如果是则返回 True,否则返
- 以a=[1,2,3] 为例,似乎使用del, remove, pop一个元素2 之后 a都是为 [1,3],如下:>>>
- 可及,通俗的说是“可以达到”,加上主语和宾语,在“交互设计”这个大的语境下,含义应该是“用户可以达到自己的操作目标”,这不是和“有效性—用户