人工智能Text Generation文本生成原理示例详解
作者:我是王大你是谁 发布时间:2022-01-16 22:45:28
承上启下
上一篇文章我们介绍了 RNN 相关的基础知识,现在我们介绍文本生成的基本原理,主要是为了能够灵活运用 RNN 的相关知识,真实的文本生成项目在实操方面比这个要复杂,但是基本的原理是不变的,这里就是抛砖引玉了。
RNN 基础知识回顾链接:https://www.jb51.net/article/228994.htm
原理
我们这里用到了 RNN 来进行文本生成,其他的可以对时序数据进行建模的模型都可以拿来使用,如 LSTM 等。这里假如已经训练好一个 RNN 模型来预测下一个字符,假如我们限定了输入的长度为为 21 ,这里举例说明:
input:“the cat sat on the ma”
把 21 个字符的文本分割成字符级别的输入,输进到模型中,RNN 来积累输入的信息,最终输出的状态向量 h ,然后经过全连接层转换和 Softmax 分类器的分类,最终输出是一个候选字符的概率分布。
在上面的例子中,输入“the cat sat on the ma”,最后会输出 26 个英文字母和其他若干用到的字符(如可能还有标点,空格等)的概率分布。
"a" --> 0.05
"b" --> 0.03
"c" --> 0.054
...
"t" --> 0.06
...
"," --> 0.01
"。" --> 0.04
此时预测的下一个字符“t”概率值最大,所以选择“t”作为下一个字符,我们之后将“t”拼接到“the cat sat on the ma”之后得到“the cat sat on the mat”,然后我们取后 21 个字符“he cat sat on the mat”,输入到模型中
input:“he cat sat on the mat”
此时加入预测下一个字符的概率分布中“。”的概率最大,我们就取“。”拼接到“the cat sat on the mat”之后,得到“the cat sat on the mat。”,如果还需要继续进行下去,则不断重复上面的过程。如果我们的文本生成要求到此结束,则最终得到了文本
the cat sat on the mat。
通常我们要用和目标相同的数据进行训练。如想生成诗词,就用唐诗宋词去训练模型,像生成歌词,就用周杰伦的歌词去训练。
选取预测的下一个字符的三种方式
一般在得到概率分布,然后去预测下一个字符的时候,会有三种方法。
第一种方法就是像上面提到的,选择概率分布中概率最大的字符即可。这种方法虽然最简单,但是效果并不是最好的,因为几乎预测字符都是确定的,但是不能达到多元化的有意思的字符结果。公式如下:
next_index = np.argmax(pred)
第二种方法会从多项分布中随机抽样,预测成某个字符的概率为多少,则它被选取当作下一个字符的概率就是多少。在实际情况中往往概率分布中的值都很小,而且很多候选项的概率相差不大,这样大家被选择的概率都差不多,下一个字符的预测随机性就很强。假如我们得到预测成某个正确字符的概率为 0.1 ,而预测成其他几个字符的概率也就只是稍微低于 0.1 ,那么这几个字符被选取当作下一个字符的概率都很相近。这种方式过于随机,生成的文本的语法和拼写错误往往很多。公式如下:
next_onehot = np.random.multimomial(1, pred, 1)
next_index = np.argmax(next_onehot)
第三种方法是介于前两种方法之间的一种,生成的下一个字符具有一定的随机性,但是随机性并不大,这要靠 temperature 参数进行调节, temperature 是在 0 到 1 之间的小数,如果为 1 则和第一种方法相同,如果为其他值则可以将概率进行不同程度的放大,这表示概率大的字符越大概率被选取到,概率小的字符越小概率被选择到,这样就可以有明显的概率区分度,这样就不会出现第二种方法中的情况。公式如下所示:
pred = pred ** (1/temperature)
pred = pred / np.sum(pred)
训练
假如我们有一句话作为训练数据,如下:
Machine learning is a subset of artificial intelligence.
我们设置两个参数 len = 5 和 stride = 3 ,len 是输入长度,stride 是步长,我们将输入 5 个字符作为输入,然后输入下一个字符作为标签,如下
input:“Machi”
target:“n”
然后因为我们设置了 stride 为 3 ,所以我们在文本中向右平移 3 位,然后又选择 5 个字符作为输入,之后的一个字符作为标签,如下:
input:“hine ”
target:“l”
如此往复,不断向右平移 3 个字符,将新得到的 5 个字符和接下来的 1 个字符作为标签作为训练数据输入到模型中,让模型学习文本内部的特征。其实训练数据就是(字符串,下一个字符)的键值对。此时得到的所有训练数据为:
input:'Machi'
target:'n'
input:'hine '
target:'l'
input:'e lea'
target:'r'
...
input:'ligen'
target:'c'
然后用这些训练数据进行大量的训练得到的模型,就可以用来生成新的文本啦!。
来源:https://juejin.cn/post/6973567782113771551


猜你喜欢
- 安装使用pip install XlsxWriter来安装,Xlsxwriter用来创建excel表格,功能很强大,下面具体介绍:1.简单使
- Yolov5如何更换BiFPN?第一步:修改common.py将如下代码添加到common.py文件中# BiFPN # 两个特征图add操
- 在python中,总的来说有三种大的模式打开文件,分别是:a, w, r当以a模式打开时,只能写文件,而且是在文件末尾添加内容。当以a+模式
- Djangos 内置的模板加载器(在先前的模板加载内幕章节有叙述)通常会满足你的所有的模板加载需求,但是如果你有特殊的加载需求的话,编写自己
- FSO,正如UFO般令人激动、令人神往,当然更多的亦是让人欢喜让人忧。君不见某空间服务商广告:100MB空间只要60RMB/年,支持数据库,
- 1.认识数组数组就是某类数据的集合,数据类型可以是整型、字符串、甚至是对象Javascript不支持多维数组,但是因为数组里面可以包含对象(
- 1.Jinja21.简介Jinja2是Python下一个被广泛应用的模版引擎,他的设计思想来源于Django的模板引擎,并扩展了其语法和一系
- ⭐️requests的使用(二)上一篇我们说了requests的简单用法,知道了如何发送请求,今天我们更深层次的来学习requests。我们
- 如果遇到与文件许可有关的问题,可能数启动mysqld时UMASK环境变量设置得不正确。例如,当你创建表时,MySQL可能会发出下述错误消息:
- 线性代数线性代数,矩阵计算,优化与内存;比如矩阵乘法,分解,行列式等数学知识,是所有数组类库的重要组成部分。和MATLAB等其他语言相比,n
- 王者荣耀的火爆就不用说了,但是一局中总会有那么几个挂机的,总能看到有些人在骂人,我们发现,当你输入一些常见的辱骂性词汇时,系统会自动将该词变
- 本文主要介绍了vscode插件听网易云的实现,具体如下:当真正的听到了我本人的我喜欢的歌单里的歌时,惊呆了老铁,所以我此时此刻用激动的心颤抖
- ▲ SHOW执行下面这个命令可以了解服务器的运行状态mysql >show status;该命令将显示出一长列状态
- pydev debugger: process 10341 is connecting无法debu今天在Pycharm中debug时无法正常
- 前言plt.show()展示图片的时候,截图进行保存,图片不是多么清晰如何保存高清图也是一知识点函数包名:import matplotlib
- 一、下载1.mysql官网下载地址:https://downloads.mysql.com/archives/community/2.下载完
- 语法:ROW_NUMBER() OVER(PARTITION BY COLUMN ORDER BY COLUMN) <BR> 例
- 微信小程序canvas写字板效果及实例写字板效果:书写文字,画板重置,导出图片,导出图片前判断是否书写内容app.json:添加一个路由:&
- asin()方法返回x的反正弦,以弧度表示。语法以下是asin()方法语法:asin(x)注意:此函数是无法直接访问的,所以我们
- 背景说明:10 * time.Second //正常数字相乘没错但是package mainimport "time"f