pytorch 带batch的tensor类型图像显示操作
作者:Xavier Jiezou 发布时间:2023-06-02 08:47:26
项目场景
pytorch训练时我们一般把数据集放到数据加载器里,然后分批拿出来训练。训练前我们一般还要看一下训练数据长啥样,也就是训练数据集可视化。
那么如何显示dataloader里面带batch的tensor类型的图像呢?
显示图像
绘图最常用的库就是matplotlib:
pip install matplotlib
显示图像会用到matplotlib.pyplot.imshow方法。查阅官方文档可知,该方法接收的图像的通道数要放到后面:
数据加载器中数据的维度是[B, C, H, W],我们每次只拿一个数据出来就是[C, H, W],而matplotlib.pyplot.imshow要求的输入维度是[H, W, C],所以我们需要交换一下数据维度,把通道数放到最后面,这里用到pytorch里面的permute方法(transpose方法也行,不过要交换两次,没这个方便,numpy中的transpose方法倒是可以一次交换完成)
用法示例如下:
>>> x = torch.randn(2, 3, 5)
>>> x.size()
torch.Size([2, 3, 5])
>>> x.permute(1, 2, 0).size()
torch.Size([3, 5, 2])
代码示例
#%% 导入模块
import torch
import matplotlib.pyplot as plt
from torchvision.utils import make_grid
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
#%% 下载数据集
train_file = datasets.MNIST(
root='./dataset/',
train=True,
transform=transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,))
]),
download=True
)
#%% 制作数据加载器
train_loader = DataLoader(
dataset=train_file,
batch_size=9,
shuffle=True
)
#%% 训练数据可视化
images, labels = next(iter(train_loader))
print(images.size()) # torch.Size([9, 1, 28, 28])
plt.figure(figsize=(9, 9))
for i in range(9):
plt.subplot(3, 3, i+1)
plt.title(labels[i].item())
plt.imshow(images[i].permute(1, 2, 0), cmap='gray')
plt.axis('off')
plt.show()
这里以mnist数据集为例,演示一下显示效果。我这个代码其实还有一点小问题。数据增强的时候我不是进行标准化了嘛,就是在第7行代码:Normalize((0.1307,), (0.3081,))。
所以,如果你想查看训练集的原始图像,还得反标准化。
标准化:image = (image-mean)/std
反标准化:image = image*std+mean
我拿imagenet中的一个蚂蚁和蜜蜂的子集做了一下实验,标准化前后的区别还是很明显的:
最终效果
补充:PIL,plt显示tensor类型的图像
该方法针对显示Dataloader读取的图像
PIL 与plt中对应操作不同,但原理是一样的,我试过用下方代码Image的方法在plt上show失败了,原因暂且不知。
# 方法1:Image.show()
# transforms.ToPILImage()中有一句
# npimg = np.transpose(pic.numpy(), (1, 2, 0))
# 因此pic只能是3-D Tensor,所以要用image[0]消去batch那一维
img = transforms.ToPILImage(image[0])
img.show()
# 方法2:plt.imshow(ndarray)
img = image[0] # plt.imshow()只能接受3-D Tensor,所以也要用image[0]消去batch那一维
img = img.numpy() # FloatTensor转为ndarray
img = np.transpose(img, (1,2,0)) # 把channel那一维放到最后
# 显示图片
plt.imshow(img)
plt.show()
cnt += 1
来源:https://blog.csdn.net/qq_42951560/article/details/109962828


猜你喜欢
- 0x01春节闲着没事(是有多闲),就写了个简单的程序,来爬点笑话看,顺带记录下写程序的过程。第一次接触爬虫是看了这么一个帖子,一个逗逼,爬取
- 先来说说实现方式: 1、我们来假定Table中有一个已经建立了索引的主键字段ID(整数型),我们将按照这个字段来取数据进行分页。 2、页的大
- 直接上代码图片就使用我家爽妹子的吧如果没有安装pil模块的话先cmd安装下输入:pip install pillow# -*- coding
- python正则模块re中findall和finditer两者相似,但却有很大区别。 两者都可以获取所有的匹配结果,这和searc
- MySQL报错:错误代码: 1293 Incorrect table definition; there can be only one T
- SQL Server导出表到EXCEL文件的存储过程:*--数据导出EXCEL导出表中的数据到Excel,包含字段名,文件为真正的Excel
- 一、logging日志模块等级常见log级别从高到低:CRITICAL 》ERROR 》WARNING 》INFO 》DEBUG,默认等级为
- 前言在laravel项目开发中,经常使用到公共函数,那如何在laravel配置全局公共函数呢??下面话不多说了,来一起看看详细的介绍吧方法如
- 本文实例讲述了Python实现查询某个目录下修改时间最新的文件。分享给大家供大家参考,具体如下:通过Python脚本,查询出某个目录下修改时
- 引言使用python进行接口测试时常常需要接口用例测试数据、断言接口功能、验证接口响应状态等,如果大量的接口测试用例脚本都将接口测试用例数据
- 一、SQL模式SQL的模式匹配允许你使用“_”匹配任何单个字符,而“%”匹配任意数目字符(包括零个字符)。在 MySQL中,SQL的模式缺省
- 1. 准备工作后台服务接口,对书本的增删改查操作2. 弹出窗口进入ElementUi官网, 找到Dialog对话框,可以参考&ldq
- Python中try块可以捕获测试代码块中的错误。except块可以处理错误。finally块可以执行代码,而不管try-和except块的
- 腐蚀在一些图像中,会有一些异常的部分,比如这样的毛刺:对于这样的情况,我们就可以应用复式操作了。需要注意的是,腐蚀操作只能处理二值图像,即像
- 准备工作首先是准备工作,导入需要使用的库,读取并创建数据表取名为loandata。import numpy as npimport pand
- 如果你有两条音频合成为一条音频(叠加,不是拼接)的需求,以下代码可以直接使用,需要修改的地方我已经标出来了,有三处需要修改你的本地音频的地址
- 环境:Pycharm ;其他环境:安装Anaconda最近在做一个小型项目练手,涉及到大量的IP和相关数据处理,所以选用了Python来处理
- 写出能用的代码很简单,写出好用的代码很难。好用的代码,也都会遵循一此原则,这就是设计原则,它们分别是:单一职责原则 (SRP)开闭原则 (O
- $str=preg_replace("/\s+/", " ", $str); //过滤多余回车 $s
- 我就废话不多说了,直接上代码吧!a=[[1,2,3],[4,5][6,7]["a","b""