TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法
作者:xf__mao 发布时间:2021-03-20 19:55:17
在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?
首先明确一点,loss是代价值,也就是我们要最小化的值
tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)
除去name参数用以指定该操作的name,与方法有关的一共两个参数:
第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes
第二个参数labels:实际的标签,大小同上
具体的执行流程大概分为两步:
第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)
softmax的公式是:
至于为什么是用的这个公式?这里不介绍了,涉及到比较多的理论证明
第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:
其中指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)
就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值
显而易见,预测越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss
注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!
理论讲完了,上代码
import tensorflow as tf
#our NN's output
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y=tf.nn.softmax(logits)
#true label
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
#step2:do cross_entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#do cross_entropy just one step
cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!!
with tf.Session() as sess:
softmax=sess.run(y)
c_e = sess.run(cross_entropy)
c_e2 = sess.run(cross_entropy2)
print("step1:softmax result=")
print(softmax)
print("step2:cross_entropy result=")
print(c_e)
print("Function(softmax_cross_entropy_with_logits) result=")
print(c_e2)
输出结果是:
step1:softmax result=
[[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.2228
最后大家可以试试e^1/(e^1+e^2+e^3)是不是0.09003057,发现确实一样!!这也证明了我们的输出是符合公式逻辑的
来源:https://blog.csdn.net/mao_xiao_feng/article/details/53382790
![](https://www.aspxhome.com/images/zang.png)
![](https://www.aspxhome.com/images/jiucuo.png)
猜你喜欢
- mac下安装mysql8.0.11时 要求输入密码 之后想修改密码注意 此方法适用于mac下的mysql8.0.11 其他版本不一定相同1.
- 前言 BeautifulSoup是主要以解析web网页的Python模块,它会提供一些强大的解释器,以解
- 在做一些工作的时候,有时候会涉及到给图片加上水印,这个如果手动添加的话,效率太低了,通常选择代码完成。下面这个是给图像添加文字水印(图片水印
- 本文研究的主要是Python enumerate索引迭代的问题,具体介绍如下。索引迭代Python中,迭代永远是取出元素本身,而非元素的索引
- 第一种,也是我最常用的,第一帧里加上这个比较灵活,想要自定义加入菜单,只要定义drMenu这个对象就可以了var drMenu&n
- Python os.remove() 方法os.remove() 方法用于删除指定路径的文件。如果指定的路径是一个目录,将抛出OSError
- MSSQL SERVER 2005 数学函数 1.求绝对值 ABS() select FWeight-50,ABS(FWeight-50),
- 一般情况下编译安装python环境需要执行以下步骤:下载源码包解压源码包安装配置编译以及编译安装TALK IS CHEAP, SHOW YO
- 1.世界地图绘制演示先给大家看下效果图哈。① 世界地图数据准备地图数据如下:因为是世界地图,所以对标的国家,我设置了 2 组,里面的数据是随
- 1.过滤器的使用1.过滤器和测试器在Python中,如果需要对某个变量进行处理,我们可以通过函数来实现。在模板中,我们则是通过过滤器来实现的
- 本文实例讲述了Python在字典中将键映射到多个值上的方法。分享给大家供大家参考,具体如下:问题:一个能将键(key)映射到多个值的字典(即
- <?php $url="http://www.golden-book.com/booksinfo/12/264.html&q
- InstrRev描述:返回某字符串在另一个字符串中出现的从结尾计起的位置。语法:InstrRev(string1, string2
- 原来的语句是这样的: select sum(sl0000) from xstfxps2 where dhao00 in ( select d
- 字体设计是人类商业活动的需求,它随着时代和科学技术的进步而不断地变化着。被广泛应用于网络生活的各个方面。现代字体设计在计算机技术的应用中已经
- 停止MySQL服务Windows可以右键我的电脑--管理--服务和应用程序--服务--找到对应的服务停止掉免密登录切换到MySQL安装路径下
- Oracle数据库提供了几种不同的数据库启动和关闭方式,本文将详细介绍这些启动和关闭方式之间的区别以及它们各自不同的功能。 一、启动和关闭O
- 前言在迷宫问题中,给定入口和出口,要求找到路径。本文将讨论三种求解方法,递归求解、回溯求解和队列求解。在介绍具体算法之前,先考虑将迷宫数字化
- 游戏规则用pygame动画实现神庙逃亡类似的小游戏,当玩家移动的时候躲避 * ,如果 * 命中玩家或者名字龙都会减速,玩家躲避 * 使更多的 * 打
- 转换为字符串类型tips['sex_str'] = tips['sex'].astype(str)转换为数值