nlp计数法应用于PTB数据集示例详解
作者:jym蒟蒻 发布时间:2023-10-26 17:24:07
标签:nlp,计数法,PTB,数据集
PTB数据集
内容如下:
一行保存一个句子;将稀有单词替换成特殊字符 < unk > ;将具体的数字替换 成“N”
we 're talking about years ago before anyone heard of asbestos having any questionable properties
there is no asbestos in our products now
neither <unk> nor the researchers who studied the workers were aware of any research on smokers of the kent cigarettes
we have no useful information on whether users are at risk said james a. <unk> of boston 's <unk> cancer institute
dr. <unk> led a team of researchers from the national cancer institute and the medical schools of harvard university and boston university
ptb.py
使用PTB数据集:
由下面这句话,可知用PTB数据集时候,是把所有句子首尾连接了。
words = open(file_path).read().replace('\n', '<eos>').strip().split()
ptb.py起到了下载PTB数据集,把数据集存到文件夹某个位置,然后对数据集进行提取的功能,提取出corpus, word_to_id, id_to_word。
import sys
import os
sys.path.append('..')
try:
import urllib.request
except ImportError:
raise ImportError('Use Python3!')
import pickle
import numpy as np
url_base = 'https://raw.githubusercontent.com/tomsercu/lstm/master/data/'
key_file = {
'train':'ptb.train.txt',
'test':'ptb.test.txt',
'valid':'ptb.valid.txt'
}
save_file = {
'train':'ptb.train.npy',
'test':'ptb.test.npy',
'valid':'ptb.valid.npy'
}
vocab_file = 'ptb.vocab.pkl'
dataset_dir = os.path.dirname(os.path.abspath(__file__))
def _download(file_name):
file_path = dataset_dir + '/' + file_name
if os.path.exists(file_path):
return
print('Downloading ' + file_name + ' ... ')
try:
urllib.request.urlretrieve(url_base + file_name, file_path)
except urllib.error.URLError:
import ssl
ssl._create_default_https_context = ssl._create_unverified_context
urllib.request.urlretrieve(url_base + file_name, file_path)
print('Done')
def load_vocab():
vocab_path = dataset_dir + '/' + vocab_file
if os.path.exists(vocab_path):
with open(vocab_path, 'rb') as f:
word_to_id, id_to_word = pickle.load(f)
return word_to_id, id_to_word
word_to_id = {}
id_to_word = {}
data_type = 'train'
file_name = key_file[data_type]
file_path = dataset_dir + '/' + file_name
_download(file_name)
words = open(file_path).read().replace('\n', '<eos>').strip().split()
for i, word in enumerate(words):
if word not in word_to_id:
tmp_id = len(word_to_id)
word_to_id[word] = tmp_id
id_to_word[tmp_id] = word
with open(vocab_path, 'wb') as f:
pickle.dump((word_to_id, id_to_word), f)
return word_to_id, id_to_word
def load_data(data_type='train'):
'''
:param data_type: 数据的种类:'train' or 'test' or 'valid (val)'
:return:
'''
if data_type == 'val': data_type = 'valid'
save_path = dataset_dir + '/' + save_file[data_type]
word_to_id, id_to_word = load_vocab()
if os.path.exists(save_path):
corpus = np.load(save_path)
return corpus, word_to_id, id_to_word
file_name = key_file[data_type]
file_path = dataset_dir + '/' + file_name
_download(file_name)
words = open(file_path).read().replace('\n', '<eos>').strip().split()
corpus = np.array([word_to_id[w] for w in words])
np.save(save_path, corpus)
return corpus, word_to_id, id_to_word
if __name__ == '__main__':
for data_type in ('train', 'val', 'test'):
load_data(data_type)
使用ptb.py
corpus保存了单词ID列表,id_to_word 是将单词ID转化为单词的字典,word_to_id 是将单词转化为单词ID的字典。
使用ptb.load_data()加载数据。里面的参数 ‘train’、‘test’、‘valid’ 分别对应训练用数据、测试用数据、验证用数据。
import sys
sys.path.append('..')
from dataset import ptb
corpus, word_to_id, id_to_word = ptb.load_data('train')
print('corpus size:', len(corpus))
print('corpus[:30]:', corpus[:30])
print()
print('id_to_word[0]:', id_to_word[0])
print('id_to_word[1]:', id_to_word[1])
print('id_to_word[2]:', id_to_word[2])
print()
print("word_to_id['car']:", word_to_id['car'])
print("word_to_id['happy']:", word_to_id['happy'])
print("word_to_id['lexus']:", word_to_id['lexus'])
结果:
corpus size: 929589
corpus[:30]: [ 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23
24 25 26 27 28 29]
id_to_word[0]: aer
id_to_word[1]: banknote
id_to_word[2]: berlitz
word_to_id['car']: 3856
word_to_id['happy']: 4428
word_to_id['lexus']: 7426
Process finished with exit code 0
计数方法应用于PTB数据集
其实和不用PTB数据集的区别就在于这句话。
corpus, word_to_id, id_to_word = ptb.load_data('train')
下面这句话起降维的效果
word_vecs = U[:, :wordvec_size]
整个代码其实耗时最大的是在下面这个函数上:
W = ppmi(C, verbose=True)
完整代码:
import sys
sys.path.append('..')
import numpy as np
from common.util import most_similar, create_co_matrix, ppmi
from dataset import ptb
window_size = 2
wordvec_size = 100
corpus, word_to_id, id_to_word = ptb.load_data('train')
vocab_size = len(word_to_id)
print('counting co-occurrence ...')
C = create_co_matrix(corpus, vocab_size, window_size)
print('calculating PPMI ...')
W = ppmi(C, verbose=True)
print('calculating SVD ...')
#try:
# truncated SVD (fast!)
print("ok")
from sklearn.utils.extmath import randomized_svd
U, S, V = randomized_svd(W, n_components=wordvec_size, n_iter=5,
random_state=None)
#except ImportError:
# SVD (slow)
# U, S, V = np.linalg.svd(W)
word_vecs = U[:, :wordvec_size]
querys = ['you', 'year', 'car', 'toyota']
for query in querys:
most_similar(query, word_to_id, id_to_word, word_vecs, top=5)
下面这个是用普通的np.linalg.svd(W)做出的结果。
[query] you
i: 0.7016294002532959
we: 0.6388039588928223
anybody: 0.5868048667907715
do: 0.5612815618515015
'll: 0.512611985206604
[query] year
month: 0.6957005262374878
quarter: 0.691483736038208
earlier: 0.6661213636398315
last: 0.6327787041664124
third: 0.6230476498603821
[query] car
luxury: 0.6767407655715942
auto: 0.6339930295944214
vehicle: 0.5972712635993958
cars: 0.5888376235961914
truck: 0.5693157315254211
[query] toyota
motor: 0.7481387853622437
nissan: 0.7147319316864014
motors: 0.6946366429328918
lexus: 0.6553674340248108
honda: 0.6343469619750977
下面结果,是用了sklearn模块里面的randomized_svd方法,使用了随机数的 Truncated SVD,仅对奇异值较大的部分进行计算,计算速度比常规的 SVD 快。
calculating SVD ...
ok
[query] you
i: 0.6678948998451233
we: 0.6213737726211548
something: 0.560122013092041
do: 0.5594725608825684
someone: 0.5490139126777649
[query] year
month: 0.6444296836853027
quarter: 0.6192560791969299
next: 0.6152222156524658
fiscal: 0.5712860226631165
earlier: 0.5641934871673584
[query] car
luxury: 0.6612467765808105
auto: 0.6166062355041504
corsica: 0.5270425081253052
cars: 0.5142025947570801
truck: 0.5030257105827332
[query] toyota
motor: 0.7747215628623962
motors: 0.6871038675308228
lexus: 0.6786072850227356
nissan: 0.6618651151657104
mazda: 0.6237337589263916
Process finished with exit code 0
来源:https://www.cnblogs.com/jiangyiming/p/16102323.html
0
投稿
猜你喜欢
- 什么是LSTM1、LSTM的结构我们可以看出,在n时刻,LSTM的输入有三个:当前时刻网络的输入值Xt;上一时刻LSTM的输出值ht-1;上
- 不知道工商银行帐号是否是这样的格式, 如果错了请大家见谅!<script language="javascript"
- 大家在用asp开发程序的时候,有时候会碰到以下的错误提示:Active Server Pages 错误 'ASP 0141'
- Oracle :NvlNVL函数:NVL函数是将NULL值的字段转换成默认字段输出。NVL(expr1,expr2)expr1,需要转换的字
- 对所有数据进行整合与管理当你使用SQL Server 2008企业级的数据仓库平台时,你可以高效的操纵所有数据,并对其进行统一管理存储。◆合
- 参考: Smashing magzine翻译+整理: Demix当完成一项前端的工作之后,许多人都会忘记该项目的结构与细节。然而代码并不是马
- 使用python编写了共六种图像增强算法:1)基于直方图均衡化2)基于拉普拉斯算子3)基于对数变换4)基于伽马变换5)限制对比度自适应直方图
- 可以查看mysql文件目录my.ini文件,可以找到类似于 datadir="D:/beeagle/Program Files/M
- 1 基本用法把序列乘以一个整数,就会产生一个新序列。这个新序列是原始序列复制了整数份,然后再拼接起来的结果。l=[1,2,3]l2=l *
- 首先介绍下比较简单但必不可少且实用的知识,可以当手册查询,适合像我一样的新手看。PHP常用库函数介绍一、PHP字符串操作常用函数1.确定字符
- 上文我们总结过了Python多继承的相关知识,没看过的小伙伴们也可以去看看,今天给大家介绍Python类的单继承相关知识。一、类的继承面向对
- 本文转自微信公众号:算法与编程之美一、引言在具备一定的Python编程基础以后,我们可以结合for循环进行多角星的编写,只要简单的几次循环,
- 使用MySQL进行数据库备份,有很正规的数据库备份方法,同其他的数据库服务器有相同的概念,但有没有想过,MySQL会有更简捷的使用文件目录的
- 其实就是利用文件“global.asa”!许多ASP编程新手都想知道这东西是什么?事实上,global.asa就是一个事件驱动程序,其中共包
- 我今天晚上,做一个快印公司的网站布局,在Div镶套布局中,父标签DIV的高度不变。在IE下没有问题,但是在FIREFOX下就有问题了。如图:
- 今天在群里,熊猫君提议整理一个帖子,一方面为初学者提供一个入门指南,另一方面也象借此和已经在从事这个行业进行一点交流。下面是我从事这个行当多
- 1、创建表的同时创建主键约束(1)无命名create table student ( studentid int primary key n
- numpy.nan的数据类型是float类型import numpy as nptype(np.nan) # float任何数字和numpy
- 为何使用函数最大化代码的重用和最小化代码冗余流程的分解编写函数>>def语句在Python中创建一个函数是通过def关键字进行的
- psutil是什么psutil是一个能够获取系统信息(包括进程、CPU、内存、磁盘、网络等)的Python模块。主要用来做系统监控,性能分析