torch.utils.data.DataLoader与迭代器转换操作
作者:Orion's?Blog 发布时间:2021-01-18 11:02:34
标签:torch.utils.data.DataLoader,迭代器转
在做实验时,我们常常会使用用开源的数据集进行测试。而Pytorch中内置了许多数据集,这些数据集我们常常使用DataLoader
类进行加载。
如下面这个我们使用DataLoader
类加载torch.vision
中的FashionMNIST
数据集。
from torch.utils.data import DataLoader
from torchvision import datasets
from torchvision.transforms import ToTensor
import matplotlib.pyplot as plt
training_data = datasets.FashionMNIST(
? ? root="data",
? ? train=True,
? ? download=True,
? ? transform=ToTensor()
)
test_data = datasets.FashionMNIST(
? ? root="data",
? ? train=False,
? ? download=True,
? ? transform=ToTensor()
)
我们接下来定义Dataloader对象用于加载这两个数据集:
train_dataloader = DataLoader(training_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)
那么这个train_dataloader
究竟是什么类型呢?
print(type(train_dataloader)) ?# <class 'torch.utils.data.dataloader.DataLoader'>
我们可以将先其转换为迭代器类型。
print(type(iter(train_dataloader)))# <class 'torch.utils.data.dataloader._SingleProcessDataLoaderIter'>
然后再使用next(iter(train_dataloader))
从迭代器里取数据,如下所示:
train_features, train_labels = next(iter(train_dataloader))
print(f"Feature batch shape: {train_features.size()}")
print(f"Labels batch shape: {train_labels.size()}")
img = train_features[0].squeeze()
label = train_labels[0]
plt.imshow(img, cmap="gray")
plt.show()
print(f"Label: {label}")
可以看到我们成功获取了数据集中第一张图片的信息,控制台打印:
Feature batch shape: torch.Size([64, 1, 28, 28])
Labels batch shape: torch.Size([64])
Label: 2
图片可视化显示如下:
不过有读者可能就会产生疑问,很多时候我们并没有将DataLoader类型强制转换成迭代器类型呀,大多数时候我们会写如下代码:
for train_features, train_labels in train_dataloader:?
? ? print(train_features.shape) # torch.Size([64, 1, 28, 28])
? ? print(train_features[0].shape) # torch.Size([1, 28, 28])
? ? print(train_features[0].squeeze().shape) # torch.Size([28, 28])
? ??
? ? img = train_features[0].squeeze()
? ? label = train_labels[0]
? ? plt.imshow(img, cmap="gray")
? ? plt.show()
? ? print(f"Label: {label}")
可以看到,该代码也能够正常迭代训练数据,前三个样本的控制台打印输出为:
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 7
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 4
torch.Size([64, 1, 28, 28])
torch.Size([1, 28, 28])
torch.Size([28, 28])
Label: 1
那么为什么我们这里没有显式将Dataloader
转换为迭代器类型呢,其实是Python语言for循环的一种机制,一旦我们用for ... in ...句式来迭代一个对象,那么Python
解释器就会偷偷地自动帮我们创建好迭代器,也就是说
for train_features, train_labels in train_dataloader:
实际上等同于
for train_features, train_labels in iter(train_dataloader):
更进一步,这实际上等同于
train_iterator = iter(train_dataloader)
try:
? ? while True:
? ? ? ? train_features, train_labels = next(train_iterator)
except StopIteration:
? ? pass
推而广之,我们在用Python迭代直接迭代列表时:
for x in [1, 2, 3, 4]:
其实Python解释器已经为我们隐式转换为迭代器了:
list_iterator = iter([1, 2, 3, 4])
try:
? ? while True:
? ? ? ? x = next(list_iterator)
except StopIteration:
? ? pass
来源:https://www.cnblogs.com/orion-orion/p/15651037.html


猜你喜欢
- tensorflow里面给出了一个函数用来读取图像,不过得到的结果是最原始的图像,是咩有经过解码的图像,这个函数为tf.gfile.Fast
- 当你执行大型程序的时候,突然出现exception,会让程序直接停止,这种对服务器自动程序很不友好,而python有着较好的异常捕获机制,不
- (一).确认删除用法: 1. BtnDel.Attributes.Add("onclick","return
- 如果你使用的正是mysql数据库,那么你把密码或者其他敏感重要信息保存在应用程序里的机会就很大。保护这些数据免受黑客或者窥探者的获取是一个令
- 在数据库开发过程中,当你检索的数据只是一条记录时,你所编写的事务语句代码往往使用SELECT INSERT 语句。但是我们常常会遇到这样情况
- 众所周知tensorflow造势虽大却很难用,因此推荐使用Keras,它缺省是基于tensorflow的,但通过修改keras.json也可
- 一般来说,pytorch 的Parameter是一个tensor,但是跟通常意义上的tensor有些不一样1) 通常意义上的tensor 仅
- 现在大部分网站都使用asp+access构建,这样的话通过下载access数据库简单就可以对网站进行破坏! 而很多的网站都不太重
- 在使用ros的时候经常会用到rosbag来录制或者回放算法,是个非常有用的工具。rosbag 命令列表命令作用record录制一个包,并且指
- 这两天在做小程序调取地图的时候遇到一个问题,如果用户第一次拒绝了位置权限请求。那么就不会再次唤起授权弹出。需要我们引导用户去开启。具体做法如
- 对于字典,通过“键”获得“值”非常简单,但通过“值”获得“键”则需绕些弯子。一、通用:自行定义函数方式假设:输入:一个字典(dic)+要找的
- 最近找遍了python的各个函数发现无法直接生成随机的二维数组,其中包括random()相关的各种方法,都没有得到想要的结果。最后在一篇博客
- 前言使用python实现设计模式中的单例模式。单例模式是一种比较常用的设计模式,其实现和使用场景判定都是相对容易的。本文将简要介绍一下pyt
- 引言接口测试就是数据的测试,在测试之前,需要准备好测试数据,而测试数据可以用数据库、excel、txt和csv方式,当然还有一种方式,那就是
- 目录1. 常用的编码2.补充:计算机表示的单位:3.ASCII编码2.GBK和GB2312编码4.Unicode5.UTF-8编码6.编码和
- 本文实例讲述了JavaScript字符串对象(string)基本用法。分享给大家供大家参考,具体如下:1.获取字符串的长度:var s =
- 问题描述:ImportError: No module named ‘XXXX'解决方式一: 将XXXX包放在python的site
- 本文实例为大家分享了python实现图像边缘检测的具体代码,供大家参考,具体内容如下任务描述背景边缘检测是数字图像处理领域的一个常用技术,被
- 在 Python 整型对象所存储的位置是不同的, 有一些是一直存储在某个存储里面, 而其它的, 则在使用时开辟出空间.说这句话的理由, 可以
- 本文实例为大家分享了Python获取指定网页源码的具体代码,供大家参考,具体内容如下1、任务简介前段时间一直在学习Python基础知识,故未