浅谈keras中的后端backend及其相关函数(K.prod,K.cast)
作者:C小C 发布时间:2021-07-04 08:53:54
一、K.prod
prod
keras.backend.prod(x, axis=None, keepdims=False)
功能:在某一指定轴,计算张量中的值的乘积。
参数
x: 张量或变量。
axis: 一个整数需要计算乘积的轴。
keepdims: 布尔值,是否保留原尺寸。 如果 keepdims 为 False,则张量的秩减 1。 如果 keepdims 为 True,缩小的维度保留为长度 1。
返回
x 的元素的乘积的张量。
Numpy 实现
def prod(x, axis=None, keepdims=False):
if isinstance(axis, list):
axis = tuple(axis)
return np.prod(x, axis=axis, keepdims=keepdims)
具体例子:
import numpy as np
x=np.array([[2,4,6],[2,4,6]])
scaling = np.prod(x, axis=1, keepdims=False)
print(x)
print(scaling)
【运行结果】
二、K.cast
cast
keras.backend.cast(x, dtype)
功能:将张量转换到不同的 dtype 并返回。
你可以转换一个 Keras 变量,但它仍然返回一个 Keras 张量。
参数
x: Keras 张量(或变量)。
dtype: 字符串, ('float16', 'float32' 或 'float64')。
返回
Keras 张量,类型为 dtype。
例子
>>> from keras import backend as K
>>> input = K.placeholder((2, 3), dtype='float32')
>>> input
<tf.Tensor 'Placeholder_2:0' shape=(2, 3) dtype=float32>
# It doesn't work in-place as below.
>>> K.cast(input, dtype='float16')
<tf.Tensor 'Cast_1:0' shape=(2, 3) dtype=float16>
>>> input
<tf.Tensor 'Placeholder_2:0' shape=(2, 3) dtype=float32>
# you need to assign it.
>>> input = K.cast(input, dtype='float16')
>>> input
<tf.Tensor 'Cast_2:0' shape=(2, 3) dtype=float16>
补充知识:keras源码之backend库目录
backend库目录
先看common.py
一上来是一些说明
# the type of float to use throughout the session. 整个模块都是用浮点型数据
_FLOATX = 'float32' # 数据类型为32位浮点型
_EPSILON = 1e-7 # 很小的常数
_IMAGE_DATA_FORMAT = 'channels_last' # 图像数据格式 最后显示通道,tensorflow格式
接下来看里面的一些函数
def epsilon():
"""Returns the value of the fuzz factor used in numeric expressions.
返回数值表达式中使用的模糊因子的值
# Returns
A float.
# Example
```python
>>> keras.backend.epsilon()
1e-07
```
"""
return _EPSILON
该函数定义了一个常量,值为1e-07,在终端可以直接输出,如下:
def set_epsilon(e):
"""Sets the value of the fuzz factor used in numeric expressions.
# Arguments
e: float. New value of epsilon.
# Example
```python
>>> from keras import backend as K
>>> K.epsilon()
1e-07
>>> K.set_epsilon(1e-05)
>>> K.epsilon()
1e-05
```
"""
global _EPSILON
_EPSILON = e
该函数允许自定义值
以string的形式返回默认的浮点类型:
def floatx():
"""Returns the default float type, as a string.
(e.g. 'float16', 'float32', 'float64').
# Returns
String, the current default float type.
# Example
```python
>>> keras.backend.floatx()
'float32'
```
"""
return _FLOATX
把numpy数组投影到默认的浮点类型:
def cast_to_floatx(x):
"""Cast a Numpy array to the default Keras float type.把numpy数组投影到默认的浮点类型
# Arguments
x: Numpy array.
# Returns
The same Numpy array, cast to its new type.
# Example
```python
>>> from keras import backend as K
>>> K.floatx()
'float32'
>>> arr = numpy.array([1.0, 2.0], dtype='float64')
>>> arr.dtype
dtype('float64')
>>> new_arr = K.cast_to_floatx(arr)
>>> new_arr
array([ 1., 2.], dtype=float32)
>>> new_arr.dtype
dtype('float32')
```
"""
return np.asarray(x, dtype=_FLOATX)
默认数据格式、自定义数据格式和检查数据格式:
def image_data_format():
"""Returns the default image data format convention ('channels_first' or 'channels_last').
# Returns
A string, either `'channels_first'` or `'channels_last'`
# Example
```python
>>> keras.backend.image_data_format()
'channels_first'
```
"""
return _IMAGE_DATA_FORMAT
def set_image_data_format(data_format):
"""Sets the value of the data format convention.
# Arguments
data_format: string. `'channels_first'` or `'channels_last'`.
# Example
```python
>>> from keras import backend as K
>>> K.image_data_format()
'channels_first'
>>> K.set_image_data_format('channels_last')
>>> K.image_data_format()
'channels_last'
```
"""
global _IMAGE_DATA_FORMAT
if data_format not in {'channels_last', 'channels_first'}:
raise ValueError('Unknown data_format:', data_format)
_IMAGE_DATA_FORMAT = str(data_format)
def normalize_data_format(value):
"""Checks that the value correspond to a valid data format.
# Arguments
value: String or None. `'channels_first'` or `'channels_last'`.
# Returns
A string, either `'channels_first'` or `'channels_last'`
# Example
```python
>>> from keras import backend as K
>>> K.normalize_data_format(None)
'channels_first'
>>> K.normalize_data_format('channels_last')
'channels_last'
```
# Raises
ValueError: if `value` or the global `data_format` invalid.
"""
if value is None:
value = image_data_format()
data_format = value.lower()
if data_format not in {'channels_first', 'channels_last'}:
raise ValueError('The `data_format` argument must be one of '
'"channels_first", "channels_last". Received: ' +
str(value))
return data_format
剩余的关于维度顺序和数据格式的方法:
def set_image_dim_ordering(dim_ordering):
"""Legacy setter for `image_data_format`.
# Arguments
dim_ordering: string. `tf` or `th`.
# Example
```python
>>> from keras import backend as K
>>> K.image_data_format()
'channels_first'
>>> K.set_image_data_format('channels_last')
>>> K.image_data_format()
'channels_last'
```
# Raises
ValueError: if `dim_ordering` is invalid.
"""
global _IMAGE_DATA_FORMAT
if dim_ordering not in {'tf', 'th'}:
raise ValueError('Unknown dim_ordering:', dim_ordering)
if dim_ordering == 'th':
data_format = 'channels_first'
else:
data_format = 'channels_last'
_IMAGE_DATA_FORMAT = data_format
def image_dim_ordering():
"""Legacy getter for `image_data_format`.
# Returns
string, one of `'th'`, `'tf'`
"""
if _IMAGE_DATA_FORMAT == 'channels_first':
return 'th'
else:
return 'tf'
在common.py之后有三个backend,分别是cntk,tensorflow和theano。
__init__.py
首先从common.py中引入了所有需要的东西
from .common import epsilon
from .common import floatx
from .common import set_epsilon
from .common import set_floatx
from .common import cast_to_floatx
from .common import image_data_format
from .common import set_image_data_format
from .common import normalize_data_format
接下来是检查环境变量与配置文件,设置backend和format,默认的backend是tensorflow。
# Set Keras base dir path given KERAS_HOME env variable, if applicable.
# Otherwise either ~/.keras or /tmp.
if 'KERAS_HOME' in os.environ: # 环境变量
_keras_dir = os.environ.get('KERAS_HOME')
else:
_keras_base_dir = os.path.expanduser('~')
if not os.access(_keras_base_dir, os.W_OK):
_keras_base_dir = '/tmp'
_keras_dir = os.path.join(_keras_base_dir, '.keras')
# Default backend: TensorFlow. 默认后台是TensorFlow
_BACKEND = 'tensorflow'
# Attempt to read Keras config file.读取keras配置文件
_config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
if os.path.exists(_config_path):
try:
with open(_config_path) as f:
_config = json.load(f)
except ValueError:
_config = {}
_floatx = _config.get('floatx', floatx())
assert _floatx in {'float16', 'float32', 'float64'}
_epsilon = _config.get('epsilon', epsilon())
assert isinstance(_epsilon, float)
_backend = _config.get('backend', _BACKEND)
_image_data_format = _config.get('image_data_format',
image_data_format())
assert _image_data_format in {'channels_last', 'channels_first'}
set_floatx(_floatx)
set_epsilon(_epsilon)
set_image_data_format(_image_data_format)
_BACKEND = _backend
之后的tensorflow_backend.py文件是一些tensorflow中的函数说明,详细内容请参考tensorflow有关资料。
来源:https://blog.csdn.net/C_chuxin/article/details/87919432
猜你喜欢
- 层的八条定律当然,这些并非真正的定律,而只是一些有益的忠告,使你免陷于使用层时可能的困顿中。原来有九条定律的,我们精简掉一条,还有下面的八条
- 本文通过实例为大家分享了python实现批量提取指定文件夹下同类型文件,供大家参考,具体内容如下代码import osimport shut
- MySQL是一个小型关系型数据库管理系统,开发者为瑞典MySQLAB公司,在2008年1月16号被Sun公司收购。MySQL被广泛地应用在I
- JSON(JavaScript Object Notation)是一种轻量级的数据交换格式,它基于ECMAScript的一个子集。 JSON
- 1.引言本文是Python生态系统中一些有用技巧的分享。大多数技巧只是使用标准库中的包,但其他一些技巧会涉及一些第三方包。在开始阅读本文内容
- Ping服务ping 是基于 XML_RPC 标准协议的更新通告服务,用于Blog把内容更新快速通知给搜索引擎,以便搜索引擎及时进行抓取和更
- 请问如何使用CDONTS组件来发送电子邮件?我们可以在IIS4下使用CDONTS来完成。首先要确认是否安装了SMTP服务(OPTIONPAC
- 本次主要是使用selenium模拟登录网页端的TX新闻,本来最开始是模拟请求的,但是某一天突然发现,部分账号需要经过滑块验证才能正常登录,如
- 选择排序:选择排序(Selection sort)是一种简单直观的 排序算法 。它的工作原理如下。首先在未排序序列中找到最小(大)元素,存放
- 获取DataFrame虽然是一个比较简单的操作,但是有时候到手边就是写不出来,所以在这里总结记录一下:1.链表推倒式data =
- 有。试试下面这个程序:saveip.asp<%Server.Scripttimeout = 1000On 
- 1:创建用户 create temporary tablespace user_temp tempfile 'D:\app\topw
- 如果遇到与文件许可有关的问题,可能数启动mysqld时UMASK环境变量设置得不正确。例如,当你创建表时,MySQL可能会发出下述错误消息:
- 根据代码中运行的结果来看,主要由以下几种:1. sum():将array中每个元素相加的结果2. axis对应的是维度的相加。比如:1、ax
- 前言我们在写爬虫是遇到最多的应该就是js反爬了,今天分享一个比较常见的js反爬,这个我已经在多个网站上见到过了。我把js反爬分为参数由js加
- 因些朋友发来邮件讲根据文章修改后无效,懒羊再次检查后发现在工具栏中并无添加,所以还得做一下下面步骤,再此给大家造成的不便还请多多谅解!因FC
- 表的故障检测和修正的一般过程如下:◆ 检查出错的表。如果该表检查通过,则完成任务,否则必须修复出错的数据库表。◆ 在开始修复之前对表文件进行
- 在我们爬虫的时候经常会遇到验证码,新浪微博的验证码是四宫格形式。可以采用模板验证码的破解方式,也就是把所有验证码的情况全部列出来,然后拿验证
- 之前一直对于python类的继承机制认知的比较混乱,今天学习记录一下。(1)首先使用直接继承的方式class parent():  
- 大家有没有这种感觉,一到国庆、春节这种长假,抢火车票就非常困难?各大互联网公司都推出抢票服务,只要加钱给服务费就可以增加抢到票的几率。有些火