pytorch 如何使用batch训练lstm网络
作者:king的江鸟 发布时间:2023-10-18 04:46:02
标签:pytorch,batch,lstm
batch的lstm
# 导入相应的包
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import torch.utils.data as Data
torch.manual_seed(1)
# 准备数据的阶段
def prepare_sequence(seq, to_ix):
idxs = [to_ix[w] for w in seq]
return torch.tensor(idxs, dtype=torch.long)
with open("/home/lstm_train.txt", encoding='utf8') as f:
train_data = []
word = []
label = []
data = f.readline().strip()
while data:
data = data.strip()
SP = data.split(' ')
if len(SP) == 2:
word.append(SP[0])
label.append(SP[1])
else:
if len(word) == 100 and 'I-PRO' in label:
train_data.append((word, label))
word = []
label = []
data = f.readline()
word_to_ix = {}
for sent, _ in train_data:
for word in sent:
if word not in word_to_ix:
word_to_ix[word] = len(word_to_ix)
tag_to_ix = {"O": 0, "I-PRO": 1}
for i in range(len(train_data)):
train_data[i] = ([word_to_ix[t] for t in train_data[i][0]], [tag_to_ix[t] for t in train_data[i][1]])
# 词向量的维度
EMBEDDING_DIM = 128
# 隐藏层的单元数
HIDDEN_DIM = 128
# 批大小
batch_size = 10
class LSTMTagger(nn.Module):
def __init__(self, embedding_dim, hidden_dim, vocab_size, tagset_size, batch_size):
super(LSTMTagger, self).__init__()
self.hidden_dim = hidden_dim
self.batch_size = batch_size
self.word_embeddings = nn.Embedding(vocab_size, embedding_dim)
# The LSTM takes word embeddings as inputs, and outputs hidden states
# with dimensionality hidden_dim.
self.lstm = nn.LSTM(embedding_dim, hidden_dim, batch_first=True)
# The linear layer that maps from hidden state space to tag space
self.hidden2tag = nn.Linear(hidden_dim, tagset_size)
def forward(self, sentence):
embeds = self.word_embeddings(sentence)
# input_tensor = embeds.view(self.batch_size, len(sentence) // self.batch_size, -1)
lstm_out, _ = self.lstm(embeds)
tag_space = self.hidden2tag(lstm_out)
scores = F.log_softmax(tag_space, dim=2)
return scores
def predict(self, sentence):
embeds = self.word_embeddings(sentence)
lstm_out, _ = self.lstm(embeds)
tag_space = self.hidden2tag(lstm_out)
scores = F.log_softmax(tag_space, dim=2)
return scores
loss_function = nn.NLLLoss()
model = LSTMTagger(EMBEDDING_DIM, HIDDEN_DIM, len(word_to_ix), len(tag_to_ix), batch_size)
optimizer = optim.SGD(model.parameters(), lr=0.1)
data_set_word = []
data_set_label = []
for data_tuple in train_data:
data_set_word.append(data_tuple[0])
data_set_label.append(data_tuple[1])
torch_dataset = Data.TensorDataset(torch.tensor(data_set_word, dtype=torch.long), torch.tensor(data_set_label, dtype=torch.long))
# 把 dataset 放入 DataLoader
loader = Data.DataLoader(
dataset=torch_dataset, # torch TensorDataset format
batch_size=batch_size, # mini batch size
shuffle=True, #
num_workers=2, # 多线程来读数据
)
# 训练过程
for epoch in range(200):
for step, (batch_x, batch_y) in enumerate(loader):
# 梯度清零
model.zero_grad()
tag_scores = model(batch_x)
# 计算损失
tag_scores = tag_scores.view(-1, tag_scores.shape[2])
batch_y = batch_y.view(batch_y.shape[0]*batch_y.shape[1])
loss = loss_function(tag_scores, batch_y)
print(loss)
# 后向传播
loss.backward()
# 更新参数
optimizer.step()
# 测试过程
with torch.no_grad():
inputs = torch.tensor([data_set_word[0]], dtype=torch.long)
print(inputs)
tag_scores = model.predict(inputs)
print(tag_scores.shape)
print(torch.argmax(tag_scores, dim=2))
补充:PyTorch基础-使用LSTM神经网络实现手写数据集识别
看代码吧~
import numpy as np
import torch
from torch import nn,optim
from torch.autograd import Variable
from torchvision import datasets,transforms
from torch.utils.data import DataLoader
# 训练集
train_data = datasets.MNIST(root="./", # 存放位置
train = True, # 载入训练集
transform=transforms.ToTensor(), # 把数据变成tensor类型
download = True # 下载
)
# 测试集
test_data = datasets.MNIST(root="./",
train = False,
transform=transforms.ToTensor(),
download = True
)
# 批次大小
batch_size = 64
# 装载训练集
train_loader = DataLoader(dataset=train_data,batch_size=batch_size,shuffle=True)
# 装载测试集
test_loader = DataLoader(dataset=test_data,batch_size=batch_size,shuffle=True)
for i,data in enumerate(train_loader):
inputs,labels = data
print(inputs.shape)
print(labels.shape)
break
# 定义网络结构
class LSTM(nn.Module):
def __init__(self):
super(LSTM,self).__init__()# 初始化
self.lstm = torch.nn.LSTM(
input_size = 28, # 表示输入特征的大小
hidden_size = 64, # 表示lstm模块的数量
num_layers = 1, # 表示lstm隐藏层的层数
batch_first = True # lstm默认格式input(seq_len,batch,feature)等于True表示input和output变成(batch,seq_len,feature)
)
self.out = torch.nn.Linear(in_features=64,out_features=10)
self.softmax = torch.nn.Softmax(dim=1)
def forward(self,x):
# (batch,seq_len,feature)
x = x.view(-1,28,28)
# output:(batch,seq_len,hidden_size)包含每个序列的输出结果
# 虽然lstm的batch_first为True,但是h_n,c_n的第0个维度还是num_layers
# h_n :[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
# c_n:[num_layers,batch,hidden_size]只包含最后一个序列的输出结果
output,(h_n,c_n) = self.lstm(x)
output_in_last_timestep = h_n[-1,:,:]
x = self.out(output_in_last_timestep)
x = self.softmax(x)
return x
# 定义模型
model = LSTM()
# 定义代价函数
mse_loss = nn.CrossEntropyLoss()# 交叉熵
# 定义优化器
optimizer = optim.Adam(model.parameters(),lr=0.001)# 随机梯度下降
# 定义模型训练和测试的方法
def train():
# 模型的训练状态
model.train()
for i,data in enumerate(train_loader):
# 获得一个批次的数据和标签
inputs,labels = data
# 获得模型预测结果(64,10)
out = model(inputs)
# 交叉熵代价函数out(batch,C:类别的数量),labels(batch)
loss = mse_loss(out,labels)
# 梯度清零
optimizer.zero_grad()
# 计算梯度
loss.backward()
# 修改权值
optimizer.step()
def test():
# 模型的测试状态
model.eval()
correct = 0 # 测试集准确率
for i,data in enumerate(test_loader):
# 获得一个批次的数据和标签
inputs,labels = data
# 获得模型预测结果(64,10)
out = model(inputs)
# 获得最大值,以及最大值所在的位置
_,predicted = torch.max(out,1)
# 预测正确的数量
correct += (predicted==labels).sum()
print("Test acc:{0}".format(correct.item()/len(test_data)))
correct = 0
for i,data in enumerate(train_loader): # 训练集准确率
# 获得一个批次的数据和标签
inputs,labels = data
# 获得模型预测结果(64,10)
out = model(inputs)
# 获得最大值,以及最大值所在的位置
_,predicted = torch.max(out,1)
# 预测正确的数量
correct += (predicted==labels).sum()
print("Train acc:{0}".format(correct.item()/len(train_data)))
# 训练
for epoch in range(10):
print("epoch:",epoch)
train()
test()
来源:https://blog.csdn.net/weixin_40939578/article/details/104462188
0
投稿
猜你喜欢
- 背景:使用python脚本传递参数在实际工作过程中还是比较常用,以下提供了好几种的实现方式:一、使用sys.argv的数组传入说明:使用sy
- python 列表和链表的区别python 中的 list 并不是我们传统意义上的列表,传统列表——通常也叫作链表(linked list)
- 前面介绍过vSQLAlchemy中的 Engine 和 Connection,这两个对象用在row SQL (原生的sql语句)上操作,而
- maketrans()方法返回的字符串intab每个字符映射到字符的字符串outtab相同位置的转换表。然后这个表被传递到tra
- 输入:[1.0000, -1.0000, 3.0000]课本中的标准差计算公式:按照上述公式计算:Numpy中的std计算:import n
- 1.官网语法pandas.read_csv(filepath_or_buffer, sep=NoDefault.no_default**,*
- eval()在print干事情之前,先看看这个东东。不是没有用,因为说不定某些时候要用到。>>> help(eval)&n
- 1.CUDA驱动和CUDA Toolkit对应版本表一:CUDA驱动及CUDA Toolkit最高对应版本最新可查阅官方文档注:驱动是向下兼
- 前言常见的通知方式有:邮件,电话,短信,微信。短信和电话:通常是收费的,较少使用;邮件:适合带文件类型的通知,较正式,存档使用;微信:适合告
- 新手,参考了以下链接:python opencv在图像上画矩形(已验证)本文可以实现在指定图片上动态绘制圆和矩形。import cv2imp
- 从Request对象中获取数据我们在第三章讲述View的函数时已经介绍过HttpRequest对象了,但当时并没有讲太多。 让我们回忆下:每
- 本文实例讲述了python创建临时文件夹的方法。分享给大家供大家参考。具体实现方法如下:import tempfile, os tempfd
- 现在我主要教大家如何去实战,做一个简易的知乎日报API 首先你要熟悉django的基本用法,会写模型,会写视图函数,会配置url。1.配置字
- 网页兼容测试,除了做不同浏览器的兼容测试,还要观察网页在不同分辨率下的表现情况。在页面中使用了CSS绝对定位,发现在宽屏下错位。随后测试非1
- 准备工作:python:https://www.python.org/downloads/Dev-C++:https://sourcefor
- 本文实例讲述了Python结合ImageMagick实现多张图片合并为一个pdf文件的方法。分享给大家供大家参考,具体如下:前段时间买了不少
- 快速入门In [1]: import time# 获取当前时间In [25]: time.strftime("%Y-%m-%d_%
- 概述从今天开始我们将开启一段自然语言处理 (NLP) 的旅程. 自然语言处理可以让来处理, 理解, 以及运用人类的语言, 实现机器语言和人类
- 前言:本篇基于Python3环境,Python2环境下的range会有所不同,但并不影响我们使用。1、range()函数是什么?range(
- 用采集程序的优点有:无须维护网站,因为采集程序中的数据来自其他网站,它将随着该网站的更新而更新;可以节省服务器资源,一般采集程序就几个文件,