深度学习TextLSTM的tensorflow1.14实现示例
作者:我是王大你是谁 发布时间:2022-07-12 06:26:46
标签:tensorflow,TextLSTM,深度学习
对单词最后一个字母的预测
LSTM 的原理自己找,这里只给出简单的示例代码,就是对单词最后一个字母的预测。
# LSTM 的原理自己找,这里只给出简单的示例代码
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
# 预测最后一个字母
words = ['make','need','coal','word','love','hate','live','home','hash','star']
# 字典集
chars = [c for c in 'abcdefghijklmnopqrstuvwxyz']
# 生成字符索引字典
word2idx = {v:k for k,v in enumerate(chars)}
idx2word = {k:v for k,v in enumerate(chars)}
V = len(chars) # 字典大小
step = 3 # 时间步长大小
hidden = 50 # 隐藏层大小
dim = 32 # 词向量维度
def make_batch(words):
input_batch, target_batch = [], []
for word in words:
input = [word2idx[c] for c in word[:-1]] # 除最后一个字符的所有字符当作输入
target = word2idx[word[-1]] # 最后一个字符当作标签
input_batch.append(input)
target_batch.append(np.eye(V)[target]) # 这里将标签转换为 one-hot ,后面计算 softmax_cross_entropy_with_logits_v2 的时候会用到
return input_batch, target_batch
# 初始化词向量
embedding = tf.get_variable("embedding", shape=[V, dim], initializer=tf.random_normal_initializer)
X = tf.placeholder(tf.int32, [None, step])
# 将输入进行词嵌入转换
XX = tf.nn.embedding_lookup(embedding, X)
Y = tf.placeholder(tf.int32, [None, V])
# 定义 LSTM cell
cell = tf.nn.rnn_cell.BasicLSTMCell(hidden)
# 隐层计算结果
outputs, states = tf.nn.dynamic_rnn(cell, XX, dtype=tf.float32) # output: [batch_size, step, hidden] states: (c=[batch_size, hidden], h=[batch_size, hidden])
# 隐层连接分类器的权重和偏置参数
W = tf.Variable(tf.random_normal([hidden, V]))
b = tf.Variable(tf.random_normal([V]))
# 这里只用到了最后输出的 c 向量 states[0] (也可以用所有时间点的输出特征向量)
feature = tf.matmul(states[0], W) + b # [batch_size, n_class]
# 计算损失并进行迭代优化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=feature, labels=Y))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)
# 预测
prediction = tf.argmax(feature, 1)
# 初始化 tf
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 生产输入和标签
input_batch, target_batch = make_batch(words)
# 训练模型
for epoch in range(1000):
_, loss = sess.run([optimizer, cost], feed_dict={X:input_batch, Y:target_batch})
if (epoch+1)%100 == 0:
print('epoch: ', '%04d'%(epoch+1), 'cost=', '%04f'%(loss))
# 预测结果
predict = sess.run([prediction], feed_dict={X:input_batch})
print([words[i][:-1]+' '+idx2word[c] for i,c in enumerate(predict[0])])
结果打印
epoch: 0100 cost= 0.003784
epoch: 0200 cost= 0.001891
epoch: 0300 cost= 0.001122
epoch: 0400 cost= 0.000739
epoch: 0500 cost= 0.000522
epoch: 0600 cost= 0.000388
epoch: 0700 cost= 0.000300
epoch: 0800 cost= 0.000238
epoch: 0900 cost= 0.000193
epoch: 1000 cost= 0.000160
['mak e', 'nee d', 'coa l', 'wor d', 'lov e', 'hat e', 'liv e', 'hom e', 'has h', 'sta r']
来源:https://juejin.cn/post/6949412997903155230


猜你喜欢
- Python 是由吉多·范罗苏姆(Guido Van Rossum)在 90 年代早期设计。 它是如今最常用的编程语言之一。它的语法简洁且优
- Python文件输入输出本文以.txt文件为例,说明Python从.txt文件中读取内容和向.txt文件写入内容的方法。a.txt文件内容:
- 前几天图书馆说服务器(Ubuntu14.04)有安全漏洞,不按时修复会关停。看了一下漏洞清单,主要是ssh和mysql的版本问题。把mysq
- 本文实例为大家分享了Python实现井字棋小游戏的具体代码,供大家参考,具体内容如下import osdef print_board(boa
- Python自带一个轻量级的关系型数据库SQLite。这一数据库使用SQL语言。SQLite作为后端数据库,可以搭配Python建网站,或者
- 1.__new__(cls, *args, **kwargs) 创建对象时调用,返回当前对象的一个实例;注意:这里的第一个参数是
- spyder快捷键与python符号化输出spyder快捷键1、F5执行当前文件2、F9执行选中的部分3、Tab预加载以该字母为首的变量名例
- 我在代码里定义了两个通道,分别用于生产端口和限制连接数,如果不限制连接数,容易被对方检测到或导致对方服务器不能正常运行。// 生产端口var
- Python Json读写操作_JsonPath用法详解1. 介绍JSONPath是一种信息抽取类库,是从JSON文档中抽取指定信息的工具,
- 身份证校验码的计算方法1、将前面的身份证号码17位数分别乘以不同的系数。第i位对应的数为[2^(18-i)]mod11。从第一位到第十七位的
- 1. 简介在windows系统上,重复性的操作可以用Python脚本来完成,其中常用的模块是win32gui、win32con、win32a
- 服务器有多张显卡,一般是组里共用,分配好显卡和任务就体现公德了。除了在代码中指定使用的 GPU 编号,还可以直接设置可见 GPU 编号,使程
- 段时间作项目中,遇到使用视图的问题,以前的工作中很少遇到视图,认为直接用表就ok了,何须视图呢?下面我来讲述一下它的功用:以往当我们查询数据
- 这是我记得的问题,基本都没答上来,大家知道的教教小弟,咱不能再不会了 1.在js里类的继承一般是类抄写和原型继承混合使用,在extjs的ex
- 本文利用Python3启动简单的HTTP服务器,以实现在同一网络中共享本地文件。启动HTTP服务器打开终端,转入目标文件所在文件夹,键入以下
- 网站设计似乎朝着越来越复杂的方向发展。这部分源于显示器的逐步增大,随着宽屏显示器的增多,更有加剧网站页面复杂程度的趋势。但是我接触网站设计近
- 本文实例为大家分享了vue仿写下拉菜单功能,带有过渡效果(移动端),供大家参考,具体内容如下效果图clickOutside.js 点击目标之
- Python之成为图像处理任务的最佳选择,是因为这一科学编程语言日益普及,并且其自身免费提供许多最先进的图像处理工具。本文主要介绍了一些简单
- s={ x1,x2,x3.....};集合有自动去重的功能,而且可以进行交并补运算,而且集合是无序的,每次打印的结果不一样,故不可以用元素下
- 将json多行数据传入到mysql中使用python实现表需要提前创建,字符集utf8 如果不行换成utf8mb4import jsonim