关于tensorflow softmax函数用法解析
作者:ASR_THU 发布时间:2022-10-29 07:42:09
如下所示:
def softmax(logits, axis=None, name=None, dim=None):
"""Computes softmax activations.
This function performs the equivalent of
softmax = tf.exp(logits) / tf.reduce_sum(tf.exp(logits), axis)
Args:
logits: A non-empty `Tensor`. Must be one of the following types: `half`,
`float32`, `float64`.
axis: The dimension softmax would be performed on. The default is -1 which
indicates the last dimension.
name: A name for the operation (optional).
dim: Deprecated alias for `axis`.
Returns:
A `Tensor`. Has the same type and shape as `logits`.
Raises:
InvalidArgumentError: if `logits` is empty or `axis` is beyond the last
dimension of `logits`.
"""
axis = deprecation.deprecated_argument_lookup("axis", axis, "dim", dim)
if axis is None:
axis = -1
return _softmax(logits, gen_nn_ops.softmax, axis, name)
softmax函数的返回结果和输入的tensor有相同的shape,既然没有改变tensor的形状,那么softmax究竟对tensor做了什么?
答案就是softmax会以某一个轴的下标为索引,对这一轴上其他维度的值进行 激活 + 归一化处理。
一般来说,这个索引轴都是表示类别的那个维度(tf.nn.softmax中默认为axis=-1,也就是最后一个维度)
举例:
def softmax(X, theta = 1.0, axis = None):
"""
Compute the softmax of each element along an axis of X.
Parameters
----------
X: ND-Array. Probably should be floats.
theta (optional): float parameter, used as a multiplier
prior to exponentiation. Default = 1.0
axis (optional): axis to compute values along. Default is the
first non-singleton axis.
Returns an array the same size as X. The result will sum to 1
along the specified axis.
"""
# make X at least 2d
y = np.atleast_2d(X)
# find axis
if axis is None:
axis = next(j[0] for j in enumerate(y.shape) if j[1] > 1)
# multiply y against the theta parameter,
y = y * float(theta)
# subtract the max for numerical stability
y = y - np.expand_dims(np.max(y, axis = axis), axis)
# exponentiate y
y = np.exp(y)
# take the sum along the specified axis
ax_sum = np.expand_dims(np.sum(y, axis = axis), axis)
# finally: divide elementwise
p = y / ax_sum
# flatten if X was 1D
if len(X.shape) == 1: p = p.flatten()
return p
c = np.random.randn(2,3)
print(c)
# 假设第0维是类别,一共有里两种类别
cc = softmax(c,axis=0)
# 假设最后一维是类别,一共有3种类别
ccc = softmax(c,axis=-1)
print(cc)
print(ccc)
结果:
c:
[[-1.30022268 0.59127472 1.21384177]
[ 0.1981082 -0.83686108 -1.54785864]]
cc:
[[0.1826746 0.80661068 0.94057075]
[0.8173254 0.19338932 0.05942925]]
ccc:
[[0.0500392 0.33172426 0.61823654]
[0.65371718 0.23222472 0.1140581 ]]
可以看到,对axis=0的轴做softmax时,输出结果在axis=0轴上和为1(eg: 0.1826746+0.8173254),同理在axis=1轴上做的话结果的axis=1轴和也为1(eg: 0.0500392+0.33172426+0.61823654)。
这些值是怎么得到的呢?
以cc为例(沿着axis=0做softmax):
以ccc为例(沿着axis=1做softmax):
知道了计算方法,现在我们再来讨论一下这些值的实际意义:
cc[0,0]实际上表示这样一种概率: P( label = 0 | value = [-1.30022268 0.1981082] = c[*,0] ) = 0.1826746
cc[1,0]实际上表示这样一种概率: P( label = 1 | value = [-1.30022268 0.1981082] = c[*,0] ) = 0.8173254
ccc[0,0]实际上表示这样一种概率: P( label = 0 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.0500392
ccc[0,1]实际上表示这样一种概率: P( label = 1 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.33172426
ccc[0,2]实际上表示这样一种概率: P( label = 2 | value = [-1.30022268 0.59127472 1.21384177] = c[0]) = 0.61823654
将他们扩展到更多维的情况:假设c是一个[batch_size , timesteps, categories]的三维tensor
output = tf.nn.softmax(c,axis=-1)
那么 output[1, 2, 3] 则表示 P(label =3 | value = c[1,2] )
来源:https://blog.csdn.net/zongza/article/details/88016668


猜你喜欢
- 1. 事务介绍MVCC之前,先介绍下事务:事务是为了保证数据库中数据的完整性和一致性。事务的4个基本要素:原子性(Atomicity):要么
- 在讨论其返回值前,我们先来介绍以下calcHist()函数的用法:cv2.calcHist()函数cv2.calcHist()函数的作用通过
- 介绍Redis是一个开源的基于内存也可持久化的Key-Value数据库,采用ANSI C语言编写。它拥有丰富的数据结构,拥有事务功能,保证命
- 前端JS中使用XMLHttpRequest 2上传图片到服务器,PC端和大部分手机上都正常,但在少部分安卓手机上上传失败,服务器上查看图片,
- 我一般是不看别人写的代码的,为啥?累!而且这位同志给的还是经过压缩的!汗。。。考我是不是?还有,这位同志也不给个示例的代码,只说是代码没有问
- 本文实例讲述了Python3爬虫爬取英雄联盟高清桌面壁纸功能。分享给大家供大家参考,具体如下:使用Scrapy爬虫抓取英雄联盟高清桌面壁纸源
- --禁用 alter table tb disable trigger tir_name --啟用 alter table tb enabl
- 在最近的一次调试中,出现如下错误~·错误类型:ADODB.Recordset (0x800A0E7D)连接无法用于执行此操作。在此上下文中它
- 本文实例讲述了Python3实现的Mysql数据库操作封装类。分享给大家供大家参考,具体如下:#encoding:utf-8#name:mo
- 本文实例为大家分享了python openCV自制绘画板的具体代码,供大家参考,具体内容如下import numpy as npimport
- 跑模型和测试一些批量操作时,常常需要一个或多个文件中的文件的命名格式具有一定的规律。有时候获取的数据又是从一些网站爬取下来的,数据名具有一定
- 找了半天,以为numpy的where函数像matlab 的find函数一样好用,能够返回一个区间内的元素索引位置。结果没有。。(也可能是我没
- 数据库操作类的优点优点可以说是非常多了,常见的优点就是便于维护、复用、高效、安全、易扩展。例如PDO支持的数据库类型是非常多的,与mysql
- python有专门的神经网络库,但为了加深印象,我自己在numpy库的基础上,自己编写了一个简单的神经网络程序,是基于Rosenblatt感
- 本文实例讲述了Java实现从数据库导出大量数据记录并保存到文件的方法。分享给大家供大家参考,具体如下:数据库脚本:-- Table &quo
- 故障:数据库报错:“MSSQL Server 2000 附加数据库错误823”,附加数据库失败。故障
- 1、TransBigData简介TransBigData是一个为交通时空大数据处理、分析和可视化而开发的Python包。TransBigDa
- 在pycharm中,当调用( import / from … import… )其他文件夹下的函数或模块,会发现编辑器无法识别( can n
- 脚本运行环境python 3.6+edge浏览器(推荐使用,因为在edge浏览器中可以获得额外12分,当然chrome浏览器也可以)webd
- 1、封装的理解封装(Encapsulation):属性和方法的抽象属性的抽象:对类的属性(变量)进行定义、隔离和保护分为私有属性和公开属性: