PyTorch中torch.utils.data.DataLoader简单介绍与使用方法
作者:想变厉害的大白菜 发布时间:2023-10-30 07:12:00
一、torch.utils.data.DataLoader 简介
作用:torch.utils.data.DataLoader 主要是对数据进行 batch 的划分。
数据加载器,结合了数据集和取样器,并且可以提供多个线程处理数据集。
在训练模型时使用到此函数,用来 把训练数据分成多个小组 ,此函数 每次抛出一组数据 。直至把所有的数据都抛出。就是做一个数据的初始化。
好处:
使用DataLoader的好处是,可以快速的迭代数据。
用于生成迭代数据非常方便。
注意:
除此之外,特别要注意的是输入进函数的数据一定得是可迭代的。如果是自定的数据集的话可以在定义类中用def__len__、def__getitem__定义。
二、实例
BATCH_SIZE 刚好整除数据量
"""
批训练,把数据变成一小批一小批数据进行训练。
DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 5 # 批训练的数据个数
x = torch.linspace(1, 10, 10) # 训练数据
print(x)
y = torch.linspace(10, 1, 10) # 标签
print(y)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y) # 对给定的 tensor 数据,将他们包装成 dataset
loader = Data.DataLoader(
# 从数据库中每次抽出batch size个样本
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=True, # 要不要打乱数据 (打乱比较好)
num_workers=2, # 多线程来读数据
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
show_batch()
输出结果:
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
tensor([10., 9., 8., 7., 6., 5., 4., 3., 2., 1.])
steop:0, batch_x:tensor([10., 1., 3., 7., 6.]), batch_y:tensor([ 1., 10., 8., 4., 5.])
steop:1, batch_x:tensor([8., 5., 4., 9., 2.]), batch_y:tensor([3., 6., 7., 2., 9.])
steop:0, batch_x:tensor([ 9., 3., 10., 1., 5.]), batch_y:tensor([ 2., 8., 1., 10., 6.])
steop:1, batch_x:tensor([2., 6., 8., 4., 7.]), batch_y:tensor([9., 5., 3., 7., 4.])
steop:0, batch_x:tensor([ 2., 10., 9., 6., 1.]), batch_y:tensor([ 9., 1., 2., 5., 10.])
steop:1, batch_x:tensor([8., 3., 4., 7., 5.]), batch_y:tensor([3., 8., 7., 4., 6.])
说明:共有 10 条数据,设置 BATCH_SIZE 为 5 来进行划分,能划分为 2 组(steop 为 0 和 1)。这两组数据互斥。
BATCH_SIZE 不整除数据量:会输出余下所有数据
将上述代码中的 BATCH_SIZE 改为 4 :
"""
批训练,把数据变成一小批一小批数据进行训练。
DataLoader就是用来包装所使用的数据,每次抛出一批数据
"""
import torch
import torch.utils.data as Data
BATCH_SIZE = 4 # 批训练的数据个数
x = torch.linspace(1, 10, 10) # 训练数据
print(x)
y = torch.linspace(10, 1, 10) # 标签
print(y)
# 把数据放在数据库中
torch_dataset = Data.TensorDataset(x, y) # 对给定的 tensor 数据,将他们包装成 dataset
loader = Data.DataLoader(
# 从数据库中每次抽出batch size个样本
dataset=torch_dataset, # torch TensorDataset format
batch_size=BATCH_SIZE, # mini batch size
shuffle=True, # 要不要打乱数据 (打乱比较好)
num_workers=2, # 多线程来读数据
)
def show_batch():
for epoch in range(3):
for step, (batch_x, batch_y) in enumerate(loader):
# training
print("steop:{}, batch_x:{}, batch_y:{}".format(step, batch_x, batch_y))
show_batch()
输出结果:
tensor([ 1., 2., 3., 4., 5., 6., 7., 8., 9., 10.])
tensor([10., 9., 8., 7., 6., 5., 4., 3., 2., 1.])
steop:0, batch_x:tensor([1., 5., 3., 2.]), batch_y:tensor([10., 6., 8., 9.])
steop:1, batch_x:tensor([7., 8., 4., 6.]), batch_y:tensor([4., 3., 7., 5.])
steop:2, batch_x:tensor([10., 9.]), batch_y:tensor([1., 2.])
steop:0, batch_x:tensor([ 7., 10., 5., 2.]), batch_y:tensor([4., 1., 6., 9.])
steop:1, batch_x:tensor([9., 1., 6., 4.]), batch_y:tensor([ 2., 10., 5., 7.])
steop:2, batch_x:tensor([8., 3.]), batch_y:tensor([3., 8.])
steop:0, batch_x:tensor([10., 3., 2., 8.]), batch_y:tensor([1., 8., 9., 3.])
steop:1, batch_x:tensor([1., 7., 5., 9.]), batch_y:tensor([10., 4., 6., 2.])
steop:2, batch_x:tensor([4., 6.]), batch_y:tensor([7., 5.])
说明:共有 10 条数据,设置 BATCH_SIZE 为 4 来进行划分,能划分为 3 组(steop 为 0 、1、2)。分别有 4、4、2 条数据。
参考链接
torch.utils.data.DataLoader使用方法
【Pytorch基础】torch.utils.data.DataLoader方法的使用
来源:https://blog.csdn.net/weixin_44211968/article/details/123744513


猜你喜欢
- 数据分组使用 groupby() 方法进行分组group.size()查看分组后每组的数量group.groups 查看分组情况group.
- 环境Laravel 5.4原理在Laravel中,门面为应用服务容器中绑定的类提供了一个“静态”接口
- 本文实例讲述了Python实现二维有序数组查找的方法。分享给大家供大家参考,具体如下:题目:在一个二维数组中,每一行都按照从左到右递增的顺序
- 接口测试中,上传文件的测试场景非常常见。例如:上传头像(图片)、上传文件、上传视频等。下面以一个上传图片的例子为大家讲解如何通过 pytho
- 在平常的项目中,经常会碰到这样的问题:我需要在一张标中同时更新和查询出来的数据。例如:有如下图一张表数据,现在需要更新操作为:把status
- 自从SQL Server 2005推出后,因为有了更好的性能,所以有很多与SQL Server 2000相关的应用程序需要升级到这个版本。但
- 引言做接口测试的时候,避免不了操作数据库。因为数据校验需要,测试数据初始化需要、一些参数化场景需要等。数据库操作框架设计这里主要操作mysq
- 目录业务需求:方案一:vuex-persistedstate方案二:vuex-persist总结业务需求:在基于vue开发SPA项目时,为了
- 本文实例讲述了python概率计算器实现方法。分享给大家供大家参考。具体实现方法如下:from random import randrang
- Windows Server 2003系统是现在很流行的服务器操作系统,许多网站都用它来做。但是如何保证服务器的相对安全,这个只要进行一些简
- 生成Fiboncci Fn数有Θ(1),Θ(n)甚至指数级的算法,不过有Θ(log n)的吗?告诉你,有。首先,关于Fibonacci数,有
- Vue 能监听数组的情况Vue 监听数组和对象的变化(vue2.x)vue 实际上可以监听数组变化,比如:直接 = 赋值data () {
- 前言在算face_track_id map有感:开始验证data={'state':[1,1,2,2,1,2,2,2],
- input()作用:让用户从控制台输入一串字符,按下回车后结束输入,并返回字符串注意:很多初学者以为它可以返回数字,其实是错的!>&g
- 在将自定义的网络权重加载到网络中时,报错:AttributeError: 'dict' object has no attr
- 首先,在数据库中创建一个表,用于存放图片:CREATE TABLE Images(Id INT PRIMARY KEY AUTO_INCRE
- 本文实例讲述了Python3.5常见内置方法参数用法。分享给大家供大家参考,具体如下:Python的内置方法参数详解网站为:https://
- 前言现在系统的各种业务是如此的复杂,数据都存在数据库中的各种表中,这个主键啊,那个外键啊,而表与表之间就依靠着这些主键和外键联系在一起。而我
- 问:我最近升级了一个应用程序,使其可以在 SQL Server 2005 上运行。我利用了允许行长度超出 8,060 个字节这项功能,以便用
- 制作友好的模板Context你也许已经注意到范例中的出版商列表模板在变量 object_list 里保存所有的书籍。这个方法工作的很好,只是