深度学习TextRNN的tensorflow1.14实现示例
作者:我是王大你是谁 发布时间:2023-12-31 18:59:23
标签:tensorflow,深度学习,TextRNN
实现对下一个单词的预测
RNN 原理自己找,这里只给出简单例子的实现代码
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
sentences = ['i love damao','i like mengjun','we love all']
words = list(set(" ".join(sentences).split()))
word2idx = {v:k for k,v in enumerate(words)}
idx2word = {k:v for k,v in enumerate(words)}
V = len(words) # 词典大小
step = 2 # 时间序列长度
hidden = 5 # 隐层大小
dim = 50 # 词向量维度
# 制作输入和标签
def make_batch(sentences):
input_batch = []
target_batch = []
for sentence in sentences:
words = sentence.split()
input = [word2idx[word] for word in words[:-1]]
target = word2idx[words[-1]]
input_batch.append(input)
target_batch.append(np.eye(V)[target]) # 这里将标签改为 one-hot 编码,之后计算交叉熵的时候会用到
return input_batch, target_batch
# 初始化词向量
embedding = tf.get_variable(shape=[V, dim], initializer=tf.random_normal_initializer(), name="embedding")
X = tf.placeholder(tf.int32, [None, step])
XX = tf.nn.embedding_lookup(embedding, X)
Y = tf.placeholder(tf.int32, [None, V])
# 定义 cell
cell = tf.nn.rnn_cell.BasicRNNCell(hidden)
# 计算各个时间点的输出和隐层输出的结果
outputs, hiddens = tf.nn.dynamic_rnn(cell, XX, dtype=tf.float32) # outputs: [batch_size, step, hidden] hiddens: [batch_size, hidden]
# 这里将所有时间点的状态向量都作为了后续分类器的输入(也可以只将最后时间节点的状态向量作为后续分类器的输入)
W = tf.Variable(tf.random_normal([step*hidden, V]))
b = tf.Variable(tf.random_normal([V]))
L = tf.matmul(tf.reshape(outputs,[-1, step*hidden]), W) + b
# 计算损失并进行优化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=L))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)
# 预测
prediction = tf.argmax(L, 1)
# 初始化 tf
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 喂训练数据
input_batch, target_batch = make_batch(sentences)
for epoch in range(5000):
_, loss = sess.run([optimizer, cost], feed_dict={X:input_batch, Y:target_batch})
if (epoch+1)%1000 == 0:
print("epoch: ", '%04d'%(epoch+1), 'cost= ', '%04f'%(loss))
# 预测数据
predict = sess.run([prediction], feed_dict={X: input_batch})
print([sentence.split()[:2] for sentence in sentences], '->', [idx2word[n] for n in predict[0]])
结果打印
epoch: 1000 cost= 0.008979
epoch: 2000 cost= 0.002754
epoch: 3000 cost= 0.001283
epoch: 4000 cost= 0.000697
epoch: 5000 cost= 0.000406
[['i', 'love'], ['i', 'like'], ['we', 'love']] -> ['damao', 'mengjun', 'all']
来源:https://juejin.cn/post/6949412624215834638


猜你喜欢
- 首先,我需要强调下,这篇主旨是揭示堆表的删除记录找回的原理,我所考虑的方面并不适用于每个人的每种情况,望大家见谅~ 很多朋友认为数据库在简单
- 本文我们讲述通过 array_unique()函数删除数组中重复元素。array_unique()函数,将数组元素的值作为字符串排序,然后对
- 看下面的例子就会明白了: print '|','*'.ljust(10),'|' print
- 查看安装的python版本号可以使用【python --version】命令。具体方法:首先按【win+r】组合键打开运行;然后输入cmd,
- 数据科学领域日常使用 Python 处理大规模数据集的时候经常需要使用到合并、链接的方式进行数据集的整合,其中应用的数据类型包括 Serie
- 由于代码比较短,这里就不进行注释了代码如下:<% '当目标页面的包含文件即#include的页面里边存在respon
- 简单试用了一下IE8后,今天相对有时间点,对IE8、IE7、IE6、Firefox2.0.0.12做了简单的一些CSS HACK测
- 方法一:在目前绝大部分数据库有分布式查询的需要。下面简单的介绍如何在oracle中配置实现跨库访问。比如现在有2个数据库服务器,安装了2个数
- Cookie用于服务器实现会话,用户登录及相关功能时进行状态管理。要在用户浏览器上安装cookie,HTTP服务器向HTTP响应添加类似以下
- * 上有个有意思的话题叫细胞自动机:https://en.wikipedia.org/wiki/Cellular_automaton在2
- 本文为大家分享了VMWare linux安装mysql 5.7.13的教程,供大家参考,具体内容如下1、基础环境说明虚拟机:VMWare操作
- 我就废话不多说了,大家还是直接看代码吧!print("thresh =",thresh)coords = np.colu
- 和其他语言一样,JavaScript也有条件语句对流程上进行判断。包括各种操作符合逻辑语句比较操作符常用的比较操作符有  
- 文章中有不正确的或者说辞不清的地方,麻烦大家指出了~~~与PHP字符串转义相关的配置和函数如下: 1.magic_quotes_runtim
- 它支持主流的浏览器,包含:Chrome、Firefox、Safari、Microsoft Edge 等,同时支持以无头模式、有头模式运行pl
- 在Python代码中指定GPUimport osos.environ["CUDA_VISIBLE_DEVICES"] =
- --1、为数据库启用SQL Server全文索引EXEC sp_fulltext_database 'enable'--2、
- 前几天,Opera宣布其用户已经超过1亿——桌面版和手机版均超过5000万。Opera Mini是一个很优秀的手机浏览器,对手机用户而言,O
- 前言相信各位一定有收到过这样的群发短信,据说还被归类为玩转微信的五大技巧之一╮(╯▽╰)╭但,其实,只要跑一下脚本,就轻松找出删除自己的好友
- hints是oracle提供的一种机制,用来告诉优化器按照我们的告诉它的方式生成执行计划。我们可以用hints来实现: