pytorch加载语音类自定义数据集的方法教程
作者:凌逆战 发布时间:2021-07-15 20:38:07
前言
pytorch对一下常用的公开数据集有很方便的API接口,但是当我们需要使用自己的数据集训练神经网络时,就需要自定义数据集,在pytorch中,提供了一些类,方便我们定义自己的数据集合
torch.utils.data.Dataset:所有继承他的子类都应该重写 __len()__ , __getitem()__ 这两个方法
__len()__ :返回数据集中数据的数量
__getitem()__ :返回支持下标索引方式获取的一个数据
torch.utils.data.DataLoader:对数据集进行包装,可以设置batch_size、是否shuffle....
第一步
自定义的 Dataset 都需要继承 torch.utils.data.Dataset 类,并且重写它的两个成员方法:
__len()__:读取数据,返回数据和标签
__getitem()__:返回数据集的长度
from torch.utils.data import Dataset
class AudioDataset(Dataset):
def __init__(self, ...):
"""类的初始化"""
pass
def __getitem__(self, item):
"""每次怎么读数据,返回数据和标签"""
return data, label
def __len__(self):
"""返回整个数据集的长度"""
return total
注意事项:Dataset只负责数据的抽象,一次调用getiitem只返回一个样本
案例:
文件目录结构
p225
***.wav
***.wav
***.wav
...
dataset.py
目的:读取p225文件夹中的音频数据
class AudioDataset(Dataset):
def __init__(self, data_folder, sr=16000, dimension=8192):
self.data_folder = data_folder
self.sr = sr
self.dim = dimension
# 获取音频名列表
self.wav_list = []
for root, dirnames, filenames in os.walk(data_folder):
for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
self.wav_list.append(os.path.join(root, filename))
def __getitem__(self, item):
# 读取一个音频文件,返回每个音频数据
filename = self.wav_list[item]
wb_wav, _ = librosa.load(filename, sr=self.sr)
# 取 帧
if len(wb_wav) >= self.dim:
max_audio_start = len(wb_wav) - self.dim
audio_start = np.random.randint(0, max_audio_start)
wb_wav = wb_wav[audio_start: audio_start + self.dim]
else:
wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
return wb_wav, filename
def __len__(self):
# 音频文件的总数
return len(self.wav_list)
注意事项:19-24行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
第二步
实例化 Dataset 对象
Dataset= AudioDataset("./p225", sr=16000)
如果要通过batch读取数据的可直接跳到第三步,如果你想一个一个读取数据的可以看我接下来的操作
# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
for i, data in enumerate(train_set):
wb_wav, filname = data
print(i, wb_wav.shape, filname)
if i == 3:
break
# 0 (8192,) ./p225\p225_001.wav
# 1 (8192,) ./p225\p225_002.wav
# 2 (8192,) ./p225\p225_003.wav
# 3 (8192,) ./p225\p225_004.wav
第三步
如果想要通过batch读取数据,需要使用DataLoader进行包装
为何要使用DataLoader?
深度学习的输入是mini_batch形式
样本加载时候可能需要随机打乱顺序,shuffle操作
样本加载需要采用多线程
pytorch提供的 DataLoader 封装了上述的功能,这样使用起来更方便。
DataLoader(dataset, batch_size=1, shuffle=False, sampler=None, num_workers=0, collate_fn=default_collate, pin_memory=False, drop_last=False)
参数:
dataset:加载的数据集(Dataset对象)
batch_size:每个批次要加载多少个样本(默认值:1)
shuffle:每个epoch是否将数据打乱
sampler:定义从数据集中抽取样本的策略。如果指定,则不能指定洗牌。
batch_sampler:类似于sampler,但每次返回一批索引。与batch_size、shuffle、sampler和drop_last相互排斥。
num_workers:使用多进程加载的进程数,0代表不使用多线程
collate_fn:如何将多个样本数据拼接成一个batch,一般使用默认拼接方式
pin_memory:是否将数据保存在pin memory区,pin memory中的数据转到GPU会快一些
drop_last:dataset中的数据个数可能不是batch_size的整数倍,drop_last为True会将多出来不足一个batch的数据丢弃
返回:数据加载器
案例:
# 实例化AudioDataset对象
train_set = AudioDataset("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
for (i, data) in enumerate(train_loader):
wav_data, wav_name = data
print(wav_data.shape) # torch.Size([8, 8192])
print(i, wav_name)
# ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
# './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
我们来吃几个栗子消化一下:
栗子1
这个例子就是本文一直举例的,栗子1只是合并了一下而已
文件目录结构
p225
***.wav
***.wav
***.wav
...
dataset.py
目的:读取p225文件夹中的音频数据
import fnmatch
import os
import librosa
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
class Aduio_DataLoader(Dataset):
def __init__(self, data_folder, sr=16000, dimension=8192):
self.data_folder = data_folder
self.sr = sr
self.dim = dimension
# 获取音频名列表
self.wav_list = []
for root, dirnames, filenames in os.walk(data_folder):
for filename in fnmatch.filter(filenames, "*.wav"): # 实现列表特殊字符的过滤或筛选,返回符合匹配“.wav”字符列表
self.wav_list.append(os.path.join(root, filename))
def __getitem__(self, item):
# 读取一个音频文件,返回每个音频数据
filename = self.wav_list[item]
print(filename)
wb_wav, _ = librosa.load(filename, sr=self.sr)
# 取 帧
if len(wb_wav) >= self.dim:
max_audio_start = len(wb_wav) - self.dim
audio_start = np.random.randint(0, max_audio_start)
wb_wav = wb_wav[audio_start: audio_start + self.dim]
else:
wb_wav = np.pad(wb_wav, (0, self.dim - len(wb_wav)), "constant")
return wb_wav, filename
def __len__(self):
# 音频文件的总数
return len(self.wav_list)
train_set = Aduio_DataLoader("./p225", sr=16000)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
for (i, data) in enumerate(train_loader):
wav_data, wav_name = data
print(wav_data.shape) # torch.Size([8, 8192])
print(i, wav_name)
# ('./p225\\p225_293.wav', './p225\\p225_156.wav', './p225\\p225_277.wav', './p225\\p225_210.wav',
# './p225\\p225_126.wav', './p225\\p225_021.wav', './p225\\p225_257.wav', './p225\\p225_192.wav')
注意事项:
27-33行:每个音频的长度不一样,如果直接读取数据返回出来的话,会造成维度不匹配而报错,因此只能每次取一个音频文件读取一帧,这样显然并没有用到所有的语音数据,
48行:我们在__getitem__中并没有将numpy数组转换为tensor格式,可是第48行显示数据是tensor格式的。这里需要引起注意
栗子2
相比于案例1,案例二才是重点,因为我们不可能每次只从一音频文件中读取一帧,然后读取另一个音频文件,通常情况下,一段音频有很多帧,我们需要的是按顺序的读取一个batch_size的音频帧,先读取第一个音频文件,如果满足一个batch,则不用读取第二个batch,如果不足一个batch则读取第二个音频文件,来补充。
我给出一个建议,先按顺序读取每个音频文件,以窗长8192、帧移4096对语音进行分帧,然后拼接。得到(帧数,帧长,1)(frame_num, frame_len, 1)的数组保存到h5中。然后用上面讲到的 torch.utils.data.Dataset 和 torch.utils.data.DataLoader 读取数据。
具体实现代码:
第一步:创建一个H5_generation脚本用来将数据转换为h5格式文件:
第二步:通过Dataset从h5格式文件中读取数据
import numpy as np
from torch.utils.data import Dataset
from torch.utils.data import DataLoader
import h5py
def load_h5(h5_path):
# load training data
with h5py.File(h5_path, 'r') as hf:
print('List of arrays in input file:', hf.keys())
X = np.array(hf.get('data'), dtype=np.float32)
Y = np.array(hf.get('label'), dtype=np.float32)
return X, Y
class AudioDataset(Dataset):
"""数据加载器"""
def __init__(self, data_folder):
self.data_folder = data_folder
self.X, self.Y = load_h5(data_folder) # (3392, 8192, 1)
def __getitem__(self, item):
# 返回一个音频数据
X = self.X[item]
Y = self.Y[item]
return X, Y
def __len__(self):
return len(self.X)
train_set = AudioDataset("./speaker225_resample_train.h5")
train_loader = DataLoader(train_set, batch_size=64, shuffle=True, drop_last=True)
for (i, wav_data) in enumerate(train_loader):
X, Y = wav_data
print(i, X.shape)
# 0 torch.Size([64, 8192, 1])
# 1 torch.Size([64, 8192, 1])
# ...
我尝试在__init__中生成h5文件,但是会导致内存 * ,就很奇怪,因此我只好分开了,
参考
pytorch学习(四)—自定义数据集(讲的比较详细)
来源:https://www.cnblogs.com/LXP-Never/p/13816254.html


猜你喜欢
- 一、需求说明需要使用Python实现将内容转为base64编码,解码,方便后续的数据操作。二、base64简介Base64是一种二进制到文本
- 数据初始化import pandas as pdimport numpy as npa=np.array([['北京',
- 介绍go1.5+版本提供编译好的安装包,我们只需要解压到相应的目录,并添加一些环境变量的配置即可。Go语言的安装步骤
- Dreamweaver MX 2004的强大功能以及更加完善的人性化设置已经深受大家喜爱。在此笔者就谈
- 0.引言利用Dlib官方训练好的模型“shape_predictor_68_face_landmarks.dat”进行68点标定,利用Ope
- 先看map。map()函数接收两个参数,一个是函数,一个是序列,map将传入的函数依次作用到序列的每个元素,并把结果作为新的list返回。举
- 初学初用,随手记录以当作笔记使用,会慢慢再进行补充添加,错误之处烦请指正。(1)运行本地文件,在代码不加载的情况下可以直接显示结果% run
- 前言在我们对DataFrame对象进行处理时候,下意识的会想到对DataFrame进行遍历,然后将处理后的值再填入DataFrame中,这样
- 关于django celery的使用网上有很多文章,本文就不多做更多的说明。本文使用版本python==3.8.15Django==3.2.
- 我发现有的网站利用了SQL SERVER提供的通过EXCHANGE或OUTLOOK收发邮件的扩展存储过程来完成收发和自动处理邮件(这句话太长
- Python读写word文档有现成的库可以处理。我这里采用 python-docx。可以用pip install python-docx安装
- 爬取一些网站下指定的内容,一般来说可以用xpath来直接从网页上来获取,但是当我们获取的内容不唯一的时候我们无法选择,我们所需要的、所指定的
- 本文实例讲述了Python处理命令行参数模块optpars用法。分享给大家供大家参考,具体如下:optpars是python中用来处理命令行
- 同样是做表格,但是有些人的表格就做的很好看。融合了之前所学不同模块的知识,来讲讲Django中生成表格的特殊方法。这里只是mark一下导出的
- 1.如何统计序列中元素出现的频率并排序?统计序列中元素出现的频率的结果肯定是一个字典,Key 为序列中的元素而 Value 为元素出现的次数
- 目录 一、前言1.1 什么是 import 机制?1.2 import 是如何执行的?二、import 机制概览三、import
- 前言在用python处理表格数据中,这其中的工作重点就是对表格类型的数据进行梳理、计算和展示,本文重点介绍展示这个方面的工作。首先我们看一个
- 简介pyenv 是一个开源的 Python 版本管理工具,可以轻松地给系统安装任意 Python 版本,想玩哪个版本,瞬间就可以切换。有了
- 在类中每次实例化一个对象都会生产一个字典来保存一个对象的所有的实例属性,这样非常的有用处,可以使我们任意的去设置新的属性。每次实例化一个对象
- 文章简介本文介绍一种 Golang 程序在运行时加载 C 动态库的技术,跳过了 Golang 项目编译阶段需要链接 C 动态库的过程,提高了