基于pytorch的lstm参数使用详解
作者:hufei_neo 发布时间:2023-11-21 08:41:21
lstm(*input, **kwargs)
将多层长短时记忆(LSTM)神经网络应用于输入序列。
参数:
input_size:输入'x'中预期特性的数量
hidden_size:隐藏状态'h'中的特性数量
num_layers:循环层的数量。例如,设置' ' num_layers=2 ' '意味着将两个LSTM堆叠在一起,形成一个'堆叠的LSTM ',第二个LSTM接收第一个LSTM的输出并计算最终结果。默认值:1
bias:如果' False',则该层不使用偏置权重' b_ih '和' b_hh '。默认值:'True'
batch_first:如果' 'True ' ',则输入和输出张量作为(batch, seq, feature)提供。默认值: 'False'
dropout:如果非零,则在除最后一层外的每个LSTM层的输出上引入一个“dropout”层,相当于:attr:'dropout'。默认值:0
bidirectional:如果‘True',则成为双向LSTM。默认值:'False'
输入:input,(h_0, c_0)
**input**of shape (seq_len, batch, input_size):包含输入序列特征的张量。输入也可以是一个压缩的可变长度序列。
see:func:'torch.nn.utils.rnn.pack_padded_sequence' 或:func:'torch.nn.utils.rnn.pack_sequence' 的细节。
**h_0** of shape (num_layers * num_directions, batch, hidden_size):张量包含批处理中每个元素的初始隐藏状态。
如果RNN是双向的,num_directions应该是2,否则应该是1。
**c_0** of shape (num_layers * num_directions, batch, hidden_size):张量包含批处理中每个元素的初始单元格状态。
如果没有提供' (h_0, c_0) ',则**h_0**和**c_0**都默认为零。
输出:output,(h_n, c_n)
**output**of shape (seq_len, batch, num_directions * hidden_size) :包含LSTM最后一层输出特征' (h_t) '张量,
对于每个t. If a:class: 'torch.nn.utils.rnn.PackedSequence' 已经给出,输出也将是一个打包序列。
对于未打包的情况,可以使用'output.view(seq_len, batch, num_directions, hidden_size)',正向和反向分别为方向' 0 '和' 1 '。
同样,在包装的情况下,方向可以分开。
**h_n** of shape (num_layers * num_directions, batch, hidden_size):包含' t = seq_len '隐藏状态的张量。
与*output*类似, the layers可以使用以下命令分隔
h_n.view(num_layers, num_directions, batch, hidden_size) 对于'c_n'相似
**c_n** (num_layers * num_directions, batch, hidden_size):张量包含' t = seq_len '的单元状态
所有的权重和偏差都初始化自: where:
include:: cudnn_persistent_rnn.rst
import torch
import torch.nn as nn
# 双向rnn例子
# rnn = nn.RNN(10, 20, 2)
# input = torch.randn(5, 3, 10)
# h0 = torch.randn(2, 3, 20)
# output, hn = rnn(input, h0)
# print(output.shape,hn.shape)
# torch.Size([5, 3, 20]) torch.Size([2, 3, 20])
# 双向lstm例子
rnn = nn.LSTM(10, 20, 2) #(input_size,hidden_size,num_layers)
input = torch.randn(5, 3, 10) #(seq_len, batch, input_size)
h0 = torch.randn(2, 3, 20) #(num_layers * num_directions, batch, hidden_size)
c0 = torch.randn(2, 3, 20) #(num_layers * num_directions, batch, hidden_size)
# output:(seq_len, batch, num_directions * hidden_size)
# hn,cn(num_layers * num_directions, batch, hidden_size)
output, (hn, cn) = rnn(input, (h0, c0))
print(output.shape,hn.shape,cn.shape)
>>>torch.Size([5, 3, 20]) torch.Size([2, 3, 20]) torch.Size([2, 3, 20])
来源:https://blog.csdn.net/hufei_neo/article/details/94435294


猜你喜欢
- Navicat数据存放位置和备份数据库路径设置navicat的数据库存放位置在什么地方?带着这样的疑问,我们去解决问题,navicat是默认
- 一、安装1.从官网下载Linux版的Pycharm官网链接:https://www.jetbrains.com/pycharm/downlo
- 二元运算二元运算是指由两个元素形成第三个元素的一种规则,例如数的加法及乘法;更一般地,由两个集合形成第三个集合的产生方法或构成规则称为二次运
- <% pagenum=55'指定打印行数 %> <HTML> <HEAD> <
- 1、通常sql执行流程用户发起请求到业务服务器,执行sql语句时,先到连接池中获取连接,然后到mysql服务器执行查询。1.1 问题1:My
- Vue导航栏 用Vue写手机端的项目,经常会写底部导航栏,
- python提供了大量的库,可以非常方便的进行各种操作,现在把python中实现读写csv文件的方法使用程序的方式呈现出来。在编写pytho
- 通信信息包是发送至MySQL服务器的单个SQL语句,或发送至客户端的单一行。在MySQL 5.1服务器和客户端之间最大能发送的可能信息包为1
- 成品效果 <body> <div id="game" style="p
- count(*)实现1、MyISAM:将表的总行数存放在磁盘上,针对无过滤条件的查询可以直接返回如果有过滤条件的count(*),MyISA
- 本文实例讲述了微信扫码支付模式。分享给大家供大家参考,具体如下:背景:因为微信占据众多的用户群,作为程序开发,自然而然也成了研究的重点。毕竟
- 一、前言作为一个数据库爱好者,自己动手写过简单的SQL解析器以及存储引擎,但感觉还是不够过瘾。<<事务处理-概念与技术>&
- 绘制一个菱形四边形,边长为 200 像素。方法1和2绘制了内角为60和120度的菱形,方法3绘制了内角为90度的菱形。方法1
- 如何把[1, 5, 6, [2, 7, [3, [4, 5, 6]]]]变成[1, 5, 6, 2, 7, 3, 4, 5, 6]?思考:-
- 本文实现的原理很简单,优化方法是用的梯度下降。后面有测试结果。先来看看实现的示例代码:# coding=utf-8from math imp
- 如下所示:import pandas as pdfrom pandas import DataFrameseries = pd.read_c
- 当你要使用data URI scheme的时候,你会发现,虽然他可以使用在绝大多数浏览器上,但无法再IE6和IE7上工作。不过值得庆幸的这一
- 在实验中需要自己构造单独的HTTP数据报文,而使用SOCK_STREAM进行发送数据包,需要进行完整的TCP交互。因此想使用原始套接字进行编
- 1、Librosaimport librosa filepath = "/Users/birenjianmo/Desktop/le
- 本文实例为大家分享了js+css实现换肤效果的具体代码,供大家参考,具体内容如下效果图如下:需求:点击对应小圆点,下面内容颜色跟着改变主要思