TensorFlow tf.nn.softmax_cross_entropy_with_logits的用法
作者:xf__mao 发布时间:2021-03-20 19:55:17
在计算loss的时候,最常见的一句话就是tf.nn.softmax_cross_entropy_with_logits,那么它到底是怎么做的呢?
首先明确一点,loss是代价值,也就是我们要最小化的值
tf.nn.softmax_cross_entropy_with_logits(logits, labels, name=None)
除去name参数用以指定该操作的name,与方法有关的一共两个参数:
第一个参数logits:就是神经网络最后一层的输出,如果有batch的话,它的大小就是[batchsize,num_classes],单样本的话,大小就是num_classes
第二个参数labels:实际的标签,大小同上
具体的执行流程大概分为两步:
第一步是先对网络最后一层的输出做一个softmax,这一步通常是求取输出属于某一类的概率,对于单样本而言,输出就是一个num_classes大小的向量([Y1,Y2,Y3...]其中Y1,Y2,Y3...分别代表了是属于该类的概率)
softmax的公式是:
至于为什么是用的这个公式?这里不介绍了,涉及到比较多的理论证明
第二步是softmax的输出向量[Y1,Y2,Y3...]和样本的实际标签做一个交叉熵,公式如下:
其中指代实际的标签中第i个的值(用mnist数据举例,如果是3,那么标签是[0,0,0,1,0,0,0,0,0,0],除了第4个值为1,其他全为0)
就是softmax的输出向量[Y1,Y2,Y3...]中,第i个元素的值
显而易见,预测越准确,结果的值越小(别忘了前面还有负号),最后求一个平均,得到我们想要的loss
注意!!!这个函数的返回值并不是一个数,而是一个向量,如果要求交叉熵,我们要再做一步tf.reduce_sum操作,就是对向量里面所有元素求和,最后才得到,如果求loss,则要做一步tf.reduce_mean操作,对向量求均值!
理论讲完了,上代码
import tensorflow as tf
#our NN's output
logits=tf.constant([[1.0,2.0,3.0],[1.0,2.0,3.0],[1.0,2.0,3.0]])
#step1:do softmax
y=tf.nn.softmax(logits)
#true label
y_=tf.constant([[0.0,0.0,1.0],[0.0,0.0,1.0],[0.0,0.0,1.0]])
#step2:do cross_entropy
cross_entropy = -tf.reduce_sum(y_*tf.log(y))
#do cross_entropy just one step
cross_entropy2=tf.reduce_sum(tf.nn.softmax_cross_entropy_with_logits(logits, y_))#dont forget tf.reduce_sum()!!
with tf.Session() as sess:
softmax=sess.run(y)
c_e = sess.run(cross_entropy)
c_e2 = sess.run(cross_entropy2)
print("step1:softmax result=")
print(softmax)
print("step2:cross_entropy result=")
print(c_e)
print("Function(softmax_cross_entropy_with_logits) result=")
print(c_e2)
输出结果是:
step1:softmax result=
[[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]
[ 0.09003057 0.24472848 0.66524094]]
step2:cross_entropy result=
1.22282
Function(softmax_cross_entropy_with_logits) result=
1.2228
最后大家可以试试e^1/(e^1+e^2+e^3)是不是0.09003057,发现确实一样!!这也证明了我们的输出是符合公式逻辑的
来源:https://blog.csdn.net/mao_xiao_feng/article/details/53382790
猜你喜欢
- 本文实例讲述了Flask框架实现的前端RSA加密与后端Python解密功能。分享给大家供大家参考,具体如下:前言在使用 Flask 开发用户
- 如何使用数据绑定控件实现不换页提交数据?Chunfeng.html' 提交页面< html><
- 本程序将使用字典来构建有向无环图,然后遍历图将其转换为对应的Excel文件最终结果如下:代码:(py3) [root@7-o-1 py-da
- 前序There should be one - and preferably only one - obvious way to do it
- Python 使用 selenium 进行自动化测试 或者协助日常工作,内容如下所示:1、基础准备需要准备 Python 环境需要安装 se
- 现在看小说已经有了听书这个功能了,但是有时候你想看的书的听书功能收费,这时候可能大家就只能老老实实选择看或者付费听。(还能拿来练英语听力欸嘿
- pycharm sql语句警告产生原因为没有配置数据库,配置数据库,似乎没什么作用那么,直接去掉他的警告提示找到setting->ed
- 前言:大概一年前写的,前段时间跑了下,发现还能用,就分享出来了供大家学习,代码的很多细节不太记得了,也尽力做了优化。因为毕竟是微博,反爬技术
- 主要功能在copyFiles()函数里实现,如下:def copyFiles(src, dst): sr
- 一.Pygame程序基本搭建过程Pygame搭建游戏窗口主要为如下几步1.初始化化程序在使用Pygame编程之前,我们要对程序进行初始化,代
- 本文实例讲述了Python读取properties配置文件操作。分享给大家供大家参考,具体如下:工作需要将Java项目的逻辑改为python
- torch.nn 是专门为神经网络设计的模块化接口,nn构建于autgrad之上,可以用来定义和运行神经网络nn.Module 是nn中重要
- 绘制八个子图import matplotlib.pyplot as pltfig = plt.figure()shape=['.
- 一、怎么样取得最新版本的MySQL?要安装MySQL,首先要当然要取得它的最新版本,虽然大家都知道在FreeBSD的Packages中可以找
- 本文实例为大家分享了Python turtle实现贪吃蛇游戏的具体代码,供大家参考,具体内容如下# Simple Snake Game in
- php文件 <?php class xpathExtension{ public static function getNodes($
- 一、is_numberic函数简介国内一部分CMS程序里面有用到过is_numberic函数,我们先看看这个函数的结构bool is_num
- 引言最近再做图像处理相关的操作的时间优化,用到了OpenCV和Pillow两个库,两个库各有优缺点。各位小伙伴需要按照自己需求选用。本篇博客
- 在知乎上遇到一个问题,说:计算机中的「null」怎么读?null正确的发音是/n^l/,有点类似四声‘纳儿&rs
- 1、爬取网页分析爬取的目标网址为:https://www.gushiwen.cn/在登陆界面需要做的工作有,获取验证码图片,并识别该验证码,