Pytorch数据读取之Dataset和DataLoader知识总结
作者:群星闪耀 发布时间:2023-11-02 22:57:37
一、前言
确保安装
scikit-image
numpy
二、Dataset
一个例子:
# 导入需要的包
import torch
import torch.utils.data.dataset as Dataset
import numpy as np
# 编造数据
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
# 数据[1,2],对应的标签是[0],数据[3,4],对应的标签是[1]
#创建子类
class subDataset(Dataset.Dataset):
#初始化,定义数据内容和标签
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label
#返回数据集大小
def __len__(self):
return len(self.Data)
#得到数据内容和标签
def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.IntTensor(self.Label[index])
return data, label
# 主函数
if __name__ == '__main__':
dataset = subDataset(Data, Label)
print(dataset)
print('dataset大小为:', dataset.__len__())
print(dataset.__getitem__(0))
print(dataset[0])
输出的结果
我们有了对Dataset的一个整体的把握,再来分析里面的细节:
#创建子类
class subDataset(Dataset.Dataset):
创建子类时,继承的时Dataset.Dataset,不是一个Dataset。因为Dataset是module模块,不是class类,所以需要调用module里的class才行,因此是Dataset.Dataset!
len和getitem这两个函数,前者给出数据集的大小**,后者是用于查找数据和标签。是最重要的两个函数,我们后续如果要对数据做一些操作基本上都是再这两个函数的基础上进行。
三、DatasetLoader
DataLoader(dataset,
batch_size=1,
shuffle=False,
sampler=None,
batch_sampler=None,
num_works=0,
clollate_fn=None,
pin_memory=False,
drop_last=False,
timeout=0,
worker_init_fn=None,
multiprocessing_context=None)
功能:构建可迭代的数据装载器;
dataset:Dataset类,决定数据从哪里读取及如何读取;数据集的路径
batchsize:批大小;
num_works:是否多进程读取数据;只对于CPU
shuffle:每个epoch是否打乱;
drop_last:当样本数不能被batchsize整除时,是否舍弃最后一批数据;
Epoch:所有训练样本都已输入到模型中,称为一个Epoch;
Iteration:一批样本输入到模型中,称之为一个Iteration;
Batchsize:批大小,决定一个Epoch中有多少个Iteration;
还是举一个实例:
import torch
import torch.utils.data.dataset as Dataset
import torch.utils.data.dataloader as DataLoader
import numpy as np
Data = np.asarray([[1, 2], [3, 4],[5, 6], [7, 8]])
Label = np.asarray([[0], [1], [0], [2]])
#创建子类
class subDataset(Dataset.Dataset):
#初始化,定义数据内容和标签
def __init__(self, Data, Label):
self.Data = Data
self.Label = Label
#返回数据集大小
def __len__(self):
return len(self.Data)
#得到数据内容和标签
def __getitem__(self, index):
data = torch.Tensor(self.Data[index])
label = torch.IntTensor(self.Label[index])
return data, label
if __name__ == '__main__':
dataset = subDataset(Data, Label)
print(dataset)
print('dataset大小为:', dataset.__len__())
print(dataset.__getitem__(0))
print(dataset[0])
#创建DataLoader迭代器,相当于我们要先定义好前面说的Dataset,然后再用Dataloader来对数据进行一些操作,比如是否需要打乱,则shuffle=True,是否需要多个进程读取数据num_workers=4,就是四个进程
dataloader = DataLoader.DataLoader(dataset,batch_size= 2, shuffle = False, num_workers= 4)
for i, item in enumerate(dataloader): #可以用enumerate来提取出里面的数据
print('i:', i)
data, label = item #数据是一个元组
print('data:', data)
print('label:', label)
四、将Dataset数据和标签放在GPU上(代码执行顺序出错则会有bug)
这部分可以直接去看博客:Dataset和DataLoader
总结下来时有两种方法解决
1.如果在创建Dataset的类时,定义__getitem__方法的时候,将数据转变为GPU类型。则需要将Dataloader里面的参数num_workers设置为0,因为这个参数是对于CPU而言的。如果数据改成了GPU,则只能单进程。如果是在Dataloader的部分,先多个子进程读取,再转变为GPU,则num_wokers不用修改。就是上述__getitem__部分的代码,移到Dataloader部分。
2.不过一般来讲,数据集和标签不会像我们上述编辑的那么简单。一般再kaggle上的标签都是存在CSV这种文件中。需要pandas的配合。
这个进阶可以看:WRITING CUSTOM DATASETS, DATALOADERS AND TRANSFORMS,他是用人脸图片作为数据和人脸特征点作为标签。
来源:https://blog.csdn.net/weixin_40244676/article/details/117043973


猜你喜欢
- 要实现的目标,简单示例:from functools import partialdef func1(f): re
- 目录背景方案一:老数据备份方案二:分表方案三:迁移至tidb重点说下同步老数据遇到的坑最终同步脚本方案总结背景由于历史业务数据采用mysql
- 看代码吧~# 加载库import pandas as pd# 데이터프레임을 만듭니다.dataframe = pd.DataFrame()
- Requests 是使用 Apache2 Licensed 许可证的 基于Python开发的HTTP 库,其在Python内置模块的基础上进
- 现在基于WEB页的HTML的编辑器在新闻系统,文章系统中用得越来越广,一个网页一粘就可以保持原来的样式,同时图片也可以在这个页中保持。但是在
- 《hadoop权威指南》的天气数据可以在ftp://ftp3.ncdc.noaa.gov/pub/data/noaa下载,在网上看到这个数据
- 一、随机数种子为什么要提出随机数种子呢?咱们前面提到过了,随机数均是模拟出来的, 想要模拟的比较真实,就需要变换种子函数内的数值,一般以时间
- 1.过程蜘蛛纸牌大家玩过没有?之前的电脑上自带的游戏,用他来摸鱼过的举个手。但是现在的电脑上已经没有蜘蛛纸牌了。所以…
- 在现在的项目里,不管是电商项目还是别的项目,在管理端都会有导出的功能,比方说订单表导出,用户表导出,业绩表导出。这些都需要提前生成excel
- 终于完成了偶的拖动窗口,花了近15个小时,庆祝一下(*^__^*);以前写了IE下的功能,于是又写了firefox下的功能,在firefox
- 一:编译器 编译器是一种特殊的程序,它可以把以特定编程语言写成的程序变为机器可以运行的机器码。我们把一个程序写好,这时我们利用的环境是文本编
- 版本选择因为MySql的版本越来越多,而作为中小网站者可能没有足够的经济去购买商业版本,所以一般选择免费版,而且功能也是足够使用的。有钱任性
- 目录1 简介2 Dash中的常用特殊功能部件2.1 用Store()来存储数据2.2 用Interval()实现周期性回调2.3 利用Col
- Bootstrap 通过一些简单的 HTML 标签和扩展的类即可创建出不同样式的表单。0x01 样式1一个登录界面:<!DOCTYPE
- 本文实例为大家分享了js拖拽实现图形伸缩效果的具体代码,供大家参考,具体内容如下点击矩形的四个角和四个边实现不同的效果<!DOCTYP
- 之前在豆瓣上听到有友邻在抱怨卓越的配送速度慢得跟蜗牛一样,超过配送时间期限几天还没送到,当时不太相信,因为此前在卓越网上购买的物品基本上是在
- 用于存储数据的csv文件有时候数据量是十分庞大的,然而我们有时候并不需要全部的数据,我们需要的可能仅仅是前面的几行。这样就可以通过panda
- 一、VScode下载官网Download Visual Studio Code - Mac, Linux, Windows点击64 bit会
- 1. 场景描述linux服务器下安装了Anaconda3,执行Pyhton的K-means算法,结果出现如下图的中文字符乱码。上次已经解决了
- Go语言实现互斥锁、随机数、time、Listimport ( "container/list"