tensorflow中Dense函数的具体使用
作者:一穷二白到年薪百万 发布时间:2021-04-26 17:01:49
1 作用
注意此处Tensorflow版本是2.0+。
由于本人是Pytorch用户,对Tensorflow不是很熟悉,在读到用tf写的代码时就很是麻烦。如图所示,遇到了如下代码:
h = Dense(units=adj_dim, activation=None)(dec_in)
Dense层就是全连接层,对于层方式的初始化的时候,layers.Dense(units,activation)函数一般只需要指定输出节点数Units和激活函数类型即可。输入节点数将根据第一次运算时输入的shape确定,同时输入、输出节点自动创建并初始化权值w和偏置向量b。
下面是Dense的接口
Dense(units,
activation=None,
use_bias=True,
kernel_initializer='glorot_uniform',
bias_initializer='zeros',
kernel_regularizer=None,
bias_regularizer=None,
activity_regularizer=None,
kernel_constraint=None, bias_constraint=None)
units, 代表该层的输出维度
activation=None, 激活函数.但是默认 liner
use_bias=True, 是否使用b 直线 y=ax+b 中的 b
此处没有写 iuput 的情况, 通常会有两种写法:
1 : Dense(units,input_shape())
2 : Dense(units)(x) #这里的 x 是以张量.
Dense( n )( x ) : = ReLU ( W x + b )
W 是权重函数, Dense() 会随机给 W 一个初始值。所以这里跟Pytorch的nn.linear()一样。
2 例子
# 使用第一种方法进行初始化
# 作为 Sequential 模型的第一层,需要指定输入维度。可以为 input_shape=(16,) 或者 input_dim=16,这两者是等价的。
model = Sequential()
model.add(Dense(32, input_shape=(16,)))
# 现在模型就会以尺寸为 (*, 16) 的数组作为输入,
# 其输出数组的尺寸为 (*, 32)
# 在第一层之后,就不再需要指定输入的尺寸了:
model.add(Dense(32))
3 与torch.nn.Linear的区别
# Pytorch实现
trd = torch.nn.Linear(in_features = 3, out_features = 30)
y = trd(torch.ones(5, 3))
print(y.size())
# torch.Size([5, 30])
# Tensorflow实现
model = tf.keras.models.Sequential()
model.add(tf.keras.layers.Dense(30, input_shape=(5,), activation=None))
————————————————————————————————————
tfd = tf.keras.layers.Dense(30, input_shape=(3,), activation=None)
x = tfd(tf.ones(shape=(5, 3)))
print(x.shape)
# (5, 30)
上面Tensorflow的实现方式相同,但是我存在疑惑
4 参考文献
[1]dense层、激活函数、输出层设计
[2]Dense(units, activation=None,)初步
[3]深入理解 keras 中 Dense 层参数
[4]tensorflow - Tensorflow 的 tf.keras.layers.Dense 和 PyTorch 的 torch.nn.Linear 的区别?
来源:https://blog.csdn.net/zfhsfdhdfajhsr/article/details/128950106


猜你喜欢
- 0. 前言深度学习已经成为机器学习中最受欢迎和发展最快的领域。自 2012 年深度学习性能超越机器学习等传统方法以来,深度学习架构开始快速应
- 一、连接MYSQL:格式: mysql -h主机地址 -u用户名 -p用户密码1、例1:连接到本机上的MYSQL。首先在打开DOS窗口,然后
- 1)忘记在 if , elif , else , for , while , class ,def 声明末尾添加 :(导致 “SyntaxE
- DataAccess.csusing System;using System.Collections.Generic;using Syst
- 上拉加载以及下拉刷新都是移动端很常见的功能,在搜索或者一些分类列表页面常常会用到。跟横向滚动一样,我们还是采用better-scroll这个
- 一、理论知识准备1.确定假设函数 如:y=2x+7 其中,(x,y)是一组数据,设共有m个2.误差cost 用平方误差代价函数 3.减小误差
- 前言正则表达式是文本处理领域中的一个强大的工具,它可以让文本处理的能力呈指数级的提升,如果一款文本编辑器不支持正则表达式,那么它就算不上是一
- Msg 102, Level 15, State 1, Line 3 Incorrect syntax near '+'.
- from http://www.devshed.com/c/a/MySQL/Error-Handling-Examples/ Error H
- 初学python ,研究了几天,写了一个python 调用 有道api接口程序效果看下图:申明:代码仅供和我一样的初学者学习交流有道api申
- 1. 计算给出两个时间之间的时间差import datetime as dt# current timecur_time = dt.date
- java 中JDBC连接数据库代码和步骤详解JDBC连接数据库 •创建一个以JDBC连接数据库的程序,包含7个步骤:
- 自定义模板标签,过滤器。英文翻译是Customtemplatetagsandfilters。customfilter自定义过滤器今天不在我的
- 最近正好在寻求一种Python的数据库ORM (Object Relational Mapper),SQLAlchemy (项目主页)这个开
- 目录一、 什么是自定义指令二、 如何自定义指令钩子函数三、应用场景输入框防抖图片懒加载一键 Copy的功能拖拽总结一、 什么是自定义指令我们
- 目录一、路由配置二、vue页面嵌套三、嵌套联系一、路由配置const routes = [ { pat
- HTMLParser是python用来解析html的模块。它可以分析出html里面的标签、数据等等,是一种处理html的简便途径。HTMLP
- Python自带一个轻量级的关系型数据库SQLite。这一数据库使用SQL语言。SQLite作为后端数据库,可以搭配Python建网站,或者
- 下面先给大家介绍下python获取酷狗音乐top500的下载地址 MP3格式,具体代码如下所示:# -*- coding: utf-8 -*
- 需求:需求简单:但是感觉最后那部分遍历有意思:S型数组赋值,考虑到下标,简单题先实现个差不多的m = 5cols = 9rows = 4nu