pytorch下使用LSTM神经网络写诗实例
作者:ColdCabbage 发布时间:2022-03-12 13:21:43
标签:pytorch,LSTM,神经网络,写诗
在pytorch下,以数万首唐诗为素材,训练双层LSTM神经网络,使其能够以唐诗的方式写诗。
代码结构分为四部分,分别为
1.model.py,定义了双层LSTM模型
2.data.py,定义了从网上得到的唐诗数据的处理方法
3.utlis.py 定义了损失可视化的函数
4.main.py定义了模型参数,以及训练、唐诗生成函数。
参考:电子工业出版社的《深度学习框架PyTorch:入门与实践》第九章
main代码及注释如下
import sys, os
import torch as t
from data import get_data
from model import PoetryModel
from torch import nn
from torch.autograd import Variable
from utils import Visualizer
import tqdm
from torchnet import meter
import ipdb
class Config(object):
data_path = 'data/'
pickle_path = 'tang.npz'
author = None
constrain = None
category = 'poet.tang' #or poet.song
lr = 1e-3
weight_decay = 1e-4
use_gpu = True
epoch = 20
batch_size = 128
maxlen = 125
plot_every = 20
#use_env = True #是否使用visodm
env = 'poety'
#visdom env
max_gen_len = 200
debug_file = '/tmp/debugp'
model_path = None
prefix_words = '细雨鱼儿出,微风燕子斜。'
#不是诗歌组成部分,是意境
start_words = '闲云潭影日悠悠'
#诗歌开始
acrostic = False
#是否藏头
model_prefix = 'checkpoints/tang'
#模型保存路径
opt = Config()
def generate(model, start_words, ix2word, word2ix, prefix_words=None):
'''
给定几个词,根据这几个词接着生成一首完整的诗歌
'''
results = list(start_words)
start_word_len = len(start_words)
# 手动设置第一个词为<START>
# 这个地方有问题,最后需要再看一下
input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
if opt.use_gpu:input=input.cuda()
hidden = None
if prefix_words:
for word in prefix_words:
output,hidden = model(input,hidden)
# 下边这句话是为了把input变成1*1?
input = Variable(input.data.new([word2ix[word]])).view(1,1)
for i in range(opt.max_gen_len):
output,hidden = model(input,hidden)
if i<start_word_len:
w = results[i]
input = Variable(input.data.new([word2ix[w]])).view(1,1)
else:
top_index = output.data[0].topk(1)[1][0]
w = ix2word[top_index]
results.append(w)
input = Variable(input.data.new([top_index])).view(1,1)
if w=='<EOP>':
del results[-1] #-1的意思是倒数第一个
break
return results
def gen_acrostic(model,start_words,ix2word,word2ix, prefix_words = None):
'''
生成藏头诗
start_words : u'深度学习'
生成:
深木通中岳,青苔半日脂。
度山分地险,逆浪到南巴。
学道兵犹毒,当时燕不移。
习根通古岸,开镜出清羸。
'''
results = []
start_word_len = len(start_words)
input = Variable(t.Tensor([word2ix['<START>']]).view(1,1).long())
if opt.use_gpu:input=input.cuda()
hidden = None
index=0 # 用来指示已经生成了多少句藏头诗
# 上一个词
pre_word='<START>'
if prefix_words:
for word in prefix_words:
output,hidden = model(input,hidden)
input = Variable(input.data.new([word2ix[word]])).view(1,1)
for i in range(opt.max_gen_len):
output,hidden = model(input,hidden)
top_index = output.data[0].topk(1)[1][0]
w = ix2word[top_index]
if (pre_word in {u'。',u'!','<START>'} ):
# 如果遇到句号,藏头的词送进去生成
if index==start_word_len:
# 如果生成的诗歌已经包含全部藏头的词,则结束
break
else:
# 把藏头的词作为输入送入模型
w = start_words[index]
index+=1
input = Variable(input.data.new([word2ix[w]])).view(1,1)
else:
# 否则的话,把上一次预测是词作为下一个词输入
input = Variable(input.data.new([word2ix[w]])).view(1,1)
results.append(w)
pre_word = w
return results
def train(**kwargs):
for k,v in kwargs.items():
setattr(opt,k,v) #设置apt里属性的值
vis = Visualizer(env=opt.env)
#获取数据
data, word2ix, ix2word = get_data(opt) #get_data是data.py里的函数
data = t.from_numpy(data)
#这个地方出错了,是大写的L
dataloader = t.utils.data.DataLoader(data,
batch_size = opt.batch_size,
shuffle = True,
num_workers = 1) #在python里,这样写程序可以吗?
#模型定义
model = PoetryModel(len(word2ix), 128, 256)
optimizer = t.optim.Adam(model.parameters(), lr=opt.lr)
criterion = nn.CrossEntropyLoss()
if opt.model_path:
model.load_state_dict(t.load(opt.model_path))
if opt.use_gpu:
model.cuda()
criterion.cuda()
#The tnt.AverageValueMeter measures and returns the average value
#and the standard deviation of any collection of numbers that are
#added to it. It is useful, for instance, to measure the average
#loss over a collection of examples.
#The add() function expects as input a Lua number value, which
#is the value that needs to be added to the list of values to
#average. It also takes as input an optional parameter n that
#assigns a weight to value in the average, in order to facilitate
#computing weighted averages (default = 1).
#The tnt.AverageValueMeter has no parameters to be set at initialization time.
loss_meter = meter.AverageValueMeter()
for epoch in range(opt.epoch):
loss_meter.reset()
for ii,data_ in tqdm.tqdm(enumerate(dataloader)):
#tqdm是python中的进度条
#训练
data_ = data_.long().transpose(1,0).contiguous()
#上边一句话,把data_变成long类型,把1维和0维转置,把内存调成连续的
if opt.use_gpu: data_ = data_.cuda()
optimizer.zero_grad()
input_, target = Variable(data_[:-1,:]), Variable(data_[1:,:])
#上边一句,将输入的诗句错开一个字,形成训练和目标
output,_ = model(input_)
loss = criterion(output, target.view(-1))
loss.backward()
optimizer.step()
loss_meter.add(loss.data[0]) #为什么是data[0]?
#可视化用到的是utlis.py里的函数
if (1+ii)%opt.plot_every ==0:
if os.path.exists(opt.debug_file):
ipdb.set_trace()
vis.plot('loss',loss_meter.value()[0])
# 下面是对目前模型情况的测试,诗歌原文
poetrys = [[ix2word[_word] for _word in data_[:,_iii]]
for _iii in range(data_.size(1))][:16]
#上面句子嵌套了两个循环,主要是将诗歌索引的前十六个字变成原文
vis.text('</br>'.join([''.join(poetry) for poetry in
poetrys]),win = u'origin_poem')
gen_poetries = []
#分别以以下几个字作为诗歌的第一个字,生成8首诗
for word in list(u'春江花月夜凉如水'):
gen_poetry = ''.join(generate(model,word,ix2word,word2ix))
gen_poetries.append(gen_poetry)
vis.text('</br>'.join([''.join(poetry) for poetry in
gen_poetries]), win = u'gen_poem')
t.save(model.state_dict(), '%s_%s.pth' %(opt.model_prefix,epoch))
def gen(**kwargs):
'''
提供命令行接口,用以生成相应的诗
'''
for k,v in kwargs.items():
setattr(opt,k,v)
data, word2ix, ix2word = get_data(opt)
model = PoetryModel(len(word2ix), 128, 256)
map_location = lambda s,l:s
# 上边句子里的map_location是在load里用的,用以加载到指定的CPU或GPU,
# 上边句子的意思是将模型加载到默认的GPU上
state_dict = t.load(opt.model_path, map_location = map_location)
model.load_state_dict(state_dict)
if opt.use_gpu:
model.cuda()
if sys.version_info.major == 3:
if opt.start_words.insprintable():
start_words = opt.start_words
prefix_words = opt.prefix_words if opt.prefix_words else None
else:
start_words = opt.start_words.encode('ascii',\
'surrogateescape').decode('utf8')
prefix_words = opt.prefix_words.encode('ascii',\
'surrogateescape').decode('utf8') if opt.prefix_words else None
start_words = start_words.replace(',',u',')\
.replace('.',u'。')\
.replace('?',u'?')
gen_poetry = gen_acrostic if opt.acrostic else generate
result = gen_poetry(model,start_words,ix2word,word2ix,prefix_words)
print(''.join(result))
if __name__ == '__main__':
import fire
fire.Fire()
以上代码给我一些经验,
1. 了解python的编程方式,如空格、换行等;进一步了解python的各个基本模块;
2. 可能出的错误:函数名写错,大小写,变量名写错,括号不全。
3. 对cuda()的用法有了进一步认识;
4. 学会了调试程序(fire);
5. 学会了训练结果的可视化(visdom);
6. 进一步的了解了LSTM,对深度学习的架构、实现有了宏观把控。
这篇pytorch下使用LSTM神经网络写诗实例就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持脚本之家。
来源:https://blog.csdn.net/weixin_39845112/article/details/80045091


猜你喜欢
- http://www.cppcns.com/shujuku/mysql/283231.html 也可以参照这个8.0
- Updates(2019.8.14 19:53)吃饭前用这个方法实战了一下,吃完回来一看好像不太行:跑完一组参数之后,到跑下一组参数时好像没
- def report_hook(count, block_size, total_size):... &n
- 目录描述语法使用示例1. 所有参数都省略2. 指定key参数3. 指定reverse参数注意事项1. sort函数会改变原列表顺序2. 列表
- 由于工作需求,想实现一个多级联动选择器,但是网上现有的联动选择器都不是我想要的,我参照基于vue2.0的element-ui中的Cascad
- 前言:大家都知道python项目中需要导入各种包(这里的包引鉴于java中的),官话来讲就是Module。而什么又是Module呢,通俗来讲
- Mysql的存储过程是从版本5才开始支持的,所以目前一般使用的都可以用到存储过程。今天分享下自己对于Mysql存储过程的认识与了解。一些简单
- 在传统的递归中,典型的模式是,你执行第一个递归调用,然后接着调用下一个递归来计算结果。这种方式中途你是得不到计算结果,知道所有的递归调用都返
- 随机漫步是这样行走得到的途径:每次行走都是完全随机的,没有明确的方向,结果是由一系列随机决策决定的。random_walk.py#rando
- 本文实例讲述了thinkPHP实现MemCache分布式缓存功能。分享给大家供大家参考,具体如下:两天在研究MemCache分布式缓存的问题
- 前言本文主要跟大家介绍了关于Vue实例中生命周期created和mounted区别的相关内容,分享出来供大家参考学习,下面话不多说了,来一起
- 初学python,写一个小程序练习一下。主要功能就是增删改查的一些功能。主要用到的技术:字典的使用,pickle的使用,io文件操作。代码如
- 本文实例讲述了Python实现的随机森林算法。分享给大家供大家参考,具体如下:随机森林是数据挖掘中非常常用的分类预测算法,以分类或回归的决策
- 本文所述的Python实现冒泡,插入,选择排序简单实例比较适合Python初学者从基础开始学习数据结构和算法,示例简单易懂,具体代码如下:#
- 1.安装pymysql:pip install pymysql (在命令行窗口中执行)2.卸载pymysql:pip uninstall p
- 一、邮件发送示例邮件发送示例flask_email及smtplib原生邮件发送示例,适用于基于Flask框架开发,但是内部设置的定时任务发送
- Python有自己内置的标准GUI库--Tkinter,只要安装好Python就可以调用。今天学习到了图形界面设计的问题,刚开始就卡住了。为
- 一、导入re库python使用正则表达式要导入re库。import re在re库中。正则表达式通常被用来检索查找、替换那些符合某个模式(规则
- oracle mysql 中的“不等于“ <> != ^= is notoracleoracle中的
- 一、什么是数据库事务数据库事务( transaction)是访问并可能操作各种数据项的一个数据库操作序列,这些操作要么全部执行,要么全部不执