深度学习TextRNN的tensorflow1.14实现示例
作者:我是王大你是谁 发布时间:2023-12-31 18:59:23
标签:tensorflow,深度学习,TextRNN
实现对下一个单词的预测
RNN 原理自己找,这里只给出简单例子的实现代码
import tensorflow as tf
import numpy as np
tf.reset_default_graph()
sentences = ['i love damao','i like mengjun','we love all']
words = list(set(" ".join(sentences).split()))
word2idx = {v:k for k,v in enumerate(words)}
idx2word = {k:v for k,v in enumerate(words)}
V = len(words) # 词典大小
step = 2 # 时间序列长度
hidden = 5 # 隐层大小
dim = 50 # 词向量维度
# 制作输入和标签
def make_batch(sentences):
input_batch = []
target_batch = []
for sentence in sentences:
words = sentence.split()
input = [word2idx[word] for word in words[:-1]]
target = word2idx[words[-1]]
input_batch.append(input)
target_batch.append(np.eye(V)[target]) # 这里将标签改为 one-hot 编码,之后计算交叉熵的时候会用到
return input_batch, target_batch
# 初始化词向量
embedding = tf.get_variable(shape=[V, dim], initializer=tf.random_normal_initializer(), name="embedding")
X = tf.placeholder(tf.int32, [None, step])
XX = tf.nn.embedding_lookup(embedding, X)
Y = tf.placeholder(tf.int32, [None, V])
# 定义 cell
cell = tf.nn.rnn_cell.BasicRNNCell(hidden)
# 计算各个时间点的输出和隐层输出的结果
outputs, hiddens = tf.nn.dynamic_rnn(cell, XX, dtype=tf.float32) # outputs: [batch_size, step, hidden] hiddens: [batch_size, hidden]
# 这里将所有时间点的状态向量都作为了后续分类器的输入(也可以只将最后时间节点的状态向量作为后续分类器的输入)
W = tf.Variable(tf.random_normal([step*hidden, V]))
b = tf.Variable(tf.random_normal([V]))
L = tf.matmul(tf.reshape(outputs,[-1, step*hidden]), W) + b
# 计算损失并进行优化
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(labels=Y, logits=L))
optimizer = tf.train.AdamOptimizer(0.001).minimize(cost)
# 预测
prediction = tf.argmax(L, 1)
# 初始化 tf
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
# 喂训练数据
input_batch, target_batch = make_batch(sentences)
for epoch in range(5000):
_, loss = sess.run([optimizer, cost], feed_dict={X:input_batch, Y:target_batch})
if (epoch+1)%1000 == 0:
print("epoch: ", '%04d'%(epoch+1), 'cost= ', '%04f'%(loss))
# 预测数据
predict = sess.run([prediction], feed_dict={X: input_batch})
print([sentence.split()[:2] for sentence in sentences], '->', [idx2word[n] for n in predict[0]])
结果打印
epoch: 1000 cost= 0.008979
epoch: 2000 cost= 0.002754
epoch: 3000 cost= 0.001283
epoch: 4000 cost= 0.000697
epoch: 5000 cost= 0.000406
[['i', 'love'], ['i', 'like'], ['we', 'love']] -> ['damao', 'mengjun', 'all']
来源:https://juejin.cn/post/6949412624215834638
0
投稿
猜你喜欢
- 一、安装相关的模块首先第一步的话我们需要安装相关的模块,通过pip命令来安装pip install gif另外由于gif模块之后会被当做是装
- 一、 yaml1、 准备支持的数据类型:字典、列表、字符串、布尔值、整数、浮点数、Null、时间等基本语法规则:大小写敏感使用缩进表示层级关
- APScheduler就是定时进行周期性的运行某些程序,在语言程序编写中,一直会遇到些定时服务,有时是根据时间定时,有时在固定的位置上进行定
- from urllib.request import urlopen  
- 在多线程中使用lock可以让多个线程在共享资源的时候不会“乱”,例如,创建多个线程,每个线程都往空列
- This is a {t}. {name}是一个很强大的字符串模板解析方法。它接受三个参数,分别是{args.text},{args.obj
- 本文实例讲述了python命令行参数解析OptionParser类的用法,分享给大家供大家参考。具体代码如下:from optparse i
- Golang Goroutine和线程的区别 Golang,轻松学习一、Golang Goroutine?当使用者分配足够多的任务,系统能自
- 前言在Python中已经内置了一个smtp邮件发送模块,Django在此基础上进行了简单地封装,让我们在Django环境中可以更方便更灵活的
- 现如今,各个国家交流密切,通过翻译使我们打破了语言壁垒,而翻译在互联网上的存在也尤为普遍。python中执行翻译操作的包是translate
- 如下所示:L = ['adam', 'Lisa', 'bart', 'Paul
- 1、官网下载地址在官网找到你想安装的版本 官网地址:https://www.python.org/并且选择下载windows版本目前最新的版
- 方法一使用Python中的内置函数isupper()和islower()来判断一个字母是否为大写或小写字母。# 获取用户输入letter =
- 对象Python 中,一切皆对象。每个对象由:标识(identity)、类型(type)、value(值)组成。1. 标识用于唯一标识对象,
- 我看见朋友可以把数据库的记录输出到页面表格上去,觉得很有用。这是怎么做的啊?见下:dbtable.asp<html><he
- 这篇文章主要介绍了python构造函数init实例方法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要
- 代码如下,U我认为对于新手来说最重要的是学会rnn读取数据的格式。# -*- coding: utf-8 -*-""&q
- 楔子我们知道python的执行效率不是很高,而且由于GIL的原因,导致python不能充分利用多核CPU。一般的解决方式是使用多进程,但是多
- 1.尽量不要对列名进行函数处理。而是针对后面的值进行处理例如where col1 = -5的效率比where -col1=5的效率要高因为后
- 0. 学习目标线性表在计算机中的表示可以采用多种方法,采用不同存储方法的线性表也有着不同的名称和特点。线性表有两种基本的存储结构:顺序存储结