Pytorch中DataLoader的使用方法详解
作者:生信小兔 发布时间:2023-07-19 04:45:39
在Pytorch中,torch.utils.data中的Dataset与DataLoader是处理数据集的两个函数,用来处理加载数据集。通常情况下,使用的关键在于构建dataset类。
一:dataset类构建。
在构建数据集类时,除了__init__(self),还要有__len__(self)与__getitem__(self,item)两个方法,这三个是必不可少的,至于其它用于数据处理的函数,可以任意定义。
class dataset:
def __init__(self,...):
...
def __len__(self,...):
return n
def __getitem__(self,item):
return data[item]
正常情况下,该数据集是要继承Pytorch中Dataset类的,但实际操作中,即使不继承,数据集类构建后仍可以用Dataloader()加载的。
在dataset类中,__len__(self)返回数据集中数据个数,__getitem__(self,item)表示每次返回第item条数据。
二:DataLoader使用
在构建dataset类后,即可使用DataLoader加载。DataLoader中常用参数如下:
1.dataset:需要载入的数据集,如前面构造的dataset类。
2.batch_size:批大小,在神经网络训练时我们很少逐条数据训练,而是几条数据作为一个batch进行训练。
3.shuffle:是否在打乱数据集样本顺序。True为打乱,False反之。
4.drop_last:是否舍去最后一个batch的数据(很多情况下数据总数N与batch size不整除,导致最后一个batch不为batch size)。True为舍去,False反之。
三:举例
兔兔以指标为1,数据个数为100的数据为例。
import torch
from torch.utils.data import DataLoader
class dataset:
def __init__(self):
self.x=torch.randint(0,20,size=(100,1),dtype=torch.float32)
self.y=(torch.sin(self.x)+1)/2
def __len__(self):
return 100
def __getitem__(self, item):
return self.x[item],self.y[item]
data=DataLoader(dataset(),batch_size=10,shuffle=True)
for batch in data:
print(batch)
当然,利用这个数据集可以进行简单的神经网络训练。
from torch import nn
data=DataLoader(dataset(),batch_size=10,shuffle=True)
bp=nn.Sequential(nn.Linear(1,5),
nn.Sigmoid(),
nn.Linear(5,1),
nn.Sigmoid())
optim=torch.optim.Adam(params=bp.parameters())
Loss=nn.MSELoss()
for epoch in range(10):
print('the {} epoch'.format(epoch))
for batch in data:
yp=bp(batch[0])
loss=Loss(yp,batch[1])
optim.zero_grad()
loss.backward()
optim.step()
ps:下面再给大家补充介绍下Pytorch中DataLoader的使用。
前言
最近开始接触pytorch,从跑别人写好的代码开始,今天需要把输入数据根据每个batch的最长输入数据,填充到一样的长度(之前是将所有的数据直接填充到一样的长度再输入)。
刚开始是想偷懒,没有去认真了解输入的机制,结果一直报错…还是要认真学习呀!
加载数据
pytorch中加载数据的顺序是:
①创建一个dataset对象
②创建一个dataloader对象
③循环dataloader对象,将data,label拿到模型中去训练
dataset
你需要自己定义一个class,里面至少包含3个函数:
①__init__:传入数据,或者像下面一样直接在函数里加载数据
②__len__:返回这个数据集一共有多少个item
③__getitem__:返回一条训练数据,并将其转换成tensor
import torch
from torch.utils.data import Dataset
class Mydata(Dataset):
def __init__(self):
a = np.load("D:/Python/nlp/NRE/a.npy",allow_pickle=True)
b = np.load("D:/Python/nlp/NRE/b.npy",allow_pickle=True)
d = np.load("D:/Python/nlp/NRE/d.npy",allow_pickle=True)
c = np.load("D:/Python/nlp/NRE/c.npy")
self.x = list(zip(a,b,d,c))
def __getitem__(self, idx):
assert idx < len(self.x)
return self.x[idx]
def __len__(self):
return len(self.x)
dataloader
参数:
dataset:传入的数据
shuffle = True:是否打乱数据
collate_fn:使用这个参数可以自己操作每个batch的数据
dataset = Mydata()
dataloader = DataLoader(dataset, batch_size = 2, shuffle=True,collate_fn = mycollate)
下面是将每个batch的数据填充到该batch的最大长度
def mycollate(data):
a = []
b = []
c = []
d = []
max_len = len(data[0][0])
for i in data:
if len(i[0])>max_len:
max_len = len(i[0])
if len(i[1])>max_len:
max_len = len(i[1])
if len(i[2])>max_len:
max_len = len(i[2])
print(max_len)
# 填充
for i in data:
if len(i[0])<max_len:
i[0].extend([27] * (max_len-len(i[0])))
if len(i[1])<max_len:
i[1].extend([27] * (max_len-len(i[1])))
if len(i[2])<max_len:
i[2].extend([27] * (max_len-len(i[2])))
a.append(i[0])
b.append(i[1])
d.append(i[2])
c.extend(i[3])
# 这里要自己转成tensor
a = torch.Tensor(a)
b = torch.Tensor(b)
c = torch.Tensor(c)
d = torch.Tensor(d)
data1 = [a,b,d,c]
print("data1",data1)
return data1
结果:
最后循环该dataloader ,拿到数据放入模型进行训练:
for ii, data in enumerate(test_data_loader):
if opt.use_gpu:
data = list(map(lambda x: torch.LongTensor(x.long()).cuda(), data))
else:
data = list(map(lambda x: torch.LongTensor(x.long()), data))
out = model(data[:-1]) #数据data[:-1]
loss = F.cross_entropy(out, data[-1])# 最后一列是标签
写在最后:建议像我一样刚开始不太熟练的小伙伴,在处理数据输入的时候可以打印出来仔细查看。
来源:https://blog.csdn.net/weixin_60737527/article/details/126754254


猜你喜欢
- 1.安装相应的库文件sudo apt-get install python-mysqldb2.数据库操作import MySQLdb db
- 一、前言写这篇文章的灵感来源于我玩游戏的时候(为了避免过不了审就不说是啥游戏了),看见一个大佬在游戏里面建造了“还原方阵
- 今天我们看看所有的类!由于工作的上的事有点忙!点图!以后讲解这是我编译好了的类的结构图,我们可以用很多软件可以从原板的DLL看到这些内容!当
- 一、环境配置安装 Python请确保您已经安装了 Python 3.x。可以在Python 官网下载并安装。安装所需库在命令提示符或终端中运
- 前言由于pycharm自带的pip源网站是国外网址,这就导致了许多国内用户在pycharm中下载其他软件包速度极慢,有时还会跳出下载失败的界
- openCV是一个开源的用C/C++开发的计算机图形图像库,非常强大,研究资料很齐全。本文重点是介绍如何使用php来调用其中的局部的功能。人
- Python 在命令行解析方面给出了类似的几个选择:自己解析, 自给自足(batteries-included)的方式,以及大量的
- Mcrypt扩展库可以实现加密解密功能,就是既能将明文加密,也可以密文还原。1.PHP加密扩展库Mcrypt安装在标准的PHP安装过程中并没
- 网上关于tensorflow模型文件ckpt格式转pb文件的帖子很多,本人几乎尝试了所有方法,最后终于成功了,现总结如下。方法无外乎下面两种
- vue实现一个分页组件vue-paginaitonvue使用了一段时间的感触就是,我再也不想直接操作DOM了。数据绑定式的编程体验真是好。实
- 介绍SUM()函数用于计算一组值或表达式的总和,SUM()函数的语法如下:SUM(DISTINCT expression)SUM()函数是如
- 今年年初,新一季的《最强大脑》开播了,第一集选拔的时候大家做了一个数字游戏,名叫《数字华容道》,当时何猷君以二十几秒的成绩夺得该项目的冠军,
- 简洁的隐藏垂直菜单在hover时将内容展开。这样的效果在JS里有很多个版本,但这个可以说是绝无仅有的CSS版本。此菜单可以在IE5.5,IE
- 目前由于phantomjs已经不维护了,而新版的Chrome(59+)推出了Headless模式,对爬虫来说尤其是定时任务的爬虫截屏之类的是
- 有时候难免需要直接调用Shell命令来完成一些比较简单的操作,比如mount一个文件系统之类的。那么我们使用Python如何调用Linux的
- eval()函数可以将字符串型的list、tuple、dict等等转换为原有的数据类型即使用eval可以实现从元组,列表,字典型的字符串到元
- 1.关系模型:用二维表格结构表示实体集,外键表示实体间联系的数据模型称为关系模型。关系模型是由若干个关系模式组成的集合。2.关系模式:关系模
- 写在最前最近在使用vue的时候,遇到一个需求,实现左右div可通过中间部分拖拽调整宽度,类似于这样这是我最终的实现效果还是老话,因为我不是专
- 本文将想大家简单介绍一下XML HttpRequst对象基础方法,希望通过本文能够使大家对其有一个初步的了解readyState一共有5个可
- 前言一个非常神秘的魔术方法。这个方法非常不起眼,用途狭窄,我几乎从未注意过它,然而,当发现它可能是上述“定律”的唯一例外情况时,我认为值得再