Pytorch中DataLoader的使用方法详解
作者:生信小兔 发布时间:2023-07-19 04:45:39
在Pytorch中,torch.utils.data中的Dataset与DataLoader是处理数据集的两个函数,用来处理加载数据集。通常情况下,使用的关键在于构建dataset类。
一:dataset类构建。
在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。
class dataset:
def __init__(self,...):
...
def __len__(self,...):
return n
def __getitem__(self,item):
return data[item]
正常情况下,该数据集是要继承Pytorch中Dataset类的,但实际操作中,即使不继承,数据集类构建后仍可以用Dataloader()加载的。
在dataset类中,__len__(self)返回数据集中数据个数,__getitem__(self,item)表示每次返回第item条数据。
二:DataLoader使用
在构建dataset类后,即可使用DataLoader加载。DataLoader中常用参数如下:
1.dataset:需要载入的数据集,如前面构造的dataset类。
2.batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个batch进行训练。
3.shuffle:是否在打乱数据集样本顺序。True为打乱,False反之。
4.drop_last:是否舍去最后一个batch的数据(很多情况下数据总数N与batch size不整除,导致最后一个batch不为batch size)。True为舍去,False反之。
三:举例
兔兔以指标为1,数据个数为100的数据为例。
import torch
from torch.utils.data import DataLoader
class dataset:
def __init__(self):
self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
self.y=(torch.sin(self.x)+1)/2
def __len__(self):
return 100
def __getitem__(self, item):
return self.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
for batch in data:
print(batch)
当然,利用这个数据集可以进行简单的神经网络训练。
from torch import nn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
nn.Sigmoid(),
nn.Linear(5,1),
nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
for epoch in range(10):
print('the {} epoch'.format(epoch))
for batch in data:
yp=bp(batch[0])
loss=Loss(yp,batch[1])
optim.zero_grad()
loss.backward()
optim.step()
ps:下面再给大家补充介绍下Pytorch中DataLoader的使用。
前言
最近开始接触pytorch,从跑别人写好的代码开始,今天需要把输入数据根据每个batch的最长输入数据,填充到一样的长度(之前是将所有的数据直接填充到一样的长度再输入)。
刚开始是想偷懒,没有去认真了解输入的机制,结果一直报错…还是要认真学习呀!
加载数据
pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环dataloader对象,将data,label拿到模型中去训练
dataset
你需要自己定义一个class,里面至少包含3个函数:
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__:返回一条训练数据,并将其转换成tensor
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
def __init__(self):
a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)
b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)
d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)
c = np.load("D:/Python/nlp/NRE/c.npy")
self.x = list(zip(a,b,d,c))
def __getitem__(self, idx):
assert idx < len(self.x)
return self.x[idx]
def __len__(self):
return len(self.x)
dataloader
参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:使用这个参数可以自己操作每个batch的数据
dataset = Mydata()
dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)
下面是将每个batch的数据填充到该batch的最大长度
def mycollate(data):
a = []
b = []
c = []
d = []
max_len = len(data[0][0])
for i in data:
if len(i[0])>max_len:
max_len = len(i[0])
if len(i[1])>max_len:
max_len = len(i[1])
if len(i[2])>max_len:
max_len = len(i[2])
print(max_len)
# 填充
for i in data:
if len(i[0])<max_len:
i[0].extend([27] * (max_len-len(i[0])))
if len(i[1])<max_len:
i[1].extend([27] * (max_len-len(i[1])))
if len(i[2])<max_len:
i[2].extend([27] * (max_len-len(i[2])))
a.append(i[0])
b.append(i[1])
d.append(i[2])
c.extend(i[3])
# 这里要自己转成tensor
a = torch.Tensor(a)
b = torch.Tensor(b)
c = torch.Tensor(c)
d = torch.Tensor(d)
data1 = [a,b,d,c]
print("data1",data1)
return data1
结果:
最后循环该dataloader ,拿到数据放入模型进行训练:
for ii, data in enumerate(test_data_loader):
if opt.use_gpu:
data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data))
else:
data = list(map(lambda x: torch.LongTensor(x.long()), data))
out = model(data[:-1]) #数据data[:-1]
loss = F.cross_entropy(out, data[-1])# 最后一列是标签
写在最后:建议像我一样刚开始不太熟练的小伙伴,在处理数据输入的时候可以打印出来仔细查看。
来源:https://blog.csdn.net/weixin_60737527/article/details/126754254
猜你喜欢
- ul设置浮动后不能自适应高度,也就是不能撑开父容器,不能自适应内容的高度。解决方法是在ul结束标签前加个清除浮动。 &
- IE6绝对定位的bug及其解决办法。position:absolute定位在IE6下存在left和bottom的定位错误问题:<!–I
- 列表转化为字符串如下所示:>>> list1=['ak','uk',4]>>&
- asp获取application对象代码如下: <%application("new&qu
- 不知各位是否有手写代码的习惯。例如:要在一个单元格插入一段CSS代码,或者一段Javascript代码,怎么做才比较快捷方便呢?虽然Drea
- 游戏说明:一个考验您记忆力的游戏,只要两个方块的;图案能够凑成一对,最终翻开所有的图片,那么您就获胜,计算机将自动记录您的游戏时
- 通常我们在制作上图的时候,会分别给四个div加上不同的css属性,来实现中间间隔。但我们更希望的是不需要对html标签做标识,直接能通过cs
- 交待:使用的软硬件环境为Win XP SP2、SQL Server 2000 SP2个人版、普通双核台式机、1000M局域网,A机为已使用的
- 无论安装何版本的mysql,在管理工具的服务中启动mysql服务时都会在中途报错。内容为:在 本地计算机 无法启动mysql服务 错误106
- 首先打击我的就是rpm安装,它告诉我发现了Mysql版本冲突,安装无法继续。我用rpm -q 查询后,想通过rpm -e 来删除系统自带的版
- 描述:让Len,Left,Right函数识别中文;对中文识别为两个字符,ASCII码为一个;可用此函数代替Len,Left,Right函数。
- 背景:作为一个python小白,今天从菜鸟教程上看了一些python的教程,看到了python的一些语法,对比起来(有其他语言功底),感觉还
- 在asp中调用sql server的存储过程可以加快程序运行速度,本文介绍了asp使用存储过程的方法。1.调用存储过程的一般方法 先假设在s
- 下列语句部分是Mssql语句,不可以在access中使用。SQL分类:DDL—数据定义语言(CREATE,ALTER,DROP,DECLAR
- 先让我们看一个例子,了解什么是模式化窗口。以下是QQ秀商城在非登录时提示登录的一种状态。当我在非登录状态,通过保存形象的方式买一件衣服时,弹
- 下表列出 SQL Server 查询分析器提供的所有键盘快捷方式。活动 快捷方式 书签:清除所有书签。 CTRL-SHIFT-F2
- 每个产品诞生的背后都凝结着一位或是多位设计师的心血,在产品的诞生过程中文化、科技、环保、创意等这些方方面面的细节集结成一个绚丽的故事,因为有
- 译文原文:http://blog.benhuoer.com/2009/04/10-simple-and-impressive-design-
- 在开始聊我在阿里四个月的网页推广设计之前,我想先来说说我对平面设计和网页设计的认识。它们之间的交集。它们都是集艺术创作、电脑技术和数字技术于
- Dim iSet conn=Server.CreateObject("ADODB.Connecti