PyTorch 解决Dataset和Dataloader遇到的问题
作者:xgbm_k 发布时间:2023-10-14 04:37:50
标签:PyTorch,Dataset,Dataloader
今天在使用PyTorch中Dataset遇到了一个问题。先看代码
class psDataset(Dataset):
def __init__(self, x, y, transforms = None):
super(Dataset, self).__init__()
self.x = x
self.y = y
if transforms == None:
self.transforms = Compose([Resize((224, 224)), ToTensor()])
else:
self.transforms = transforms
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
img = Image.open(self.x[idx])
img = self.transforms(img)
return img, torch.tensor([[self.y[idx]]])
结果运行时报错:RuntimeError: invalid argument 0: Sizes of tensors must match except in dimension 0. Got 3 and 1 in dimension 1 at /opt/conda/conda-bld/pytorch_1522182087074/work/torch/lib/TH/generic/THTensorMath.c:2897
Google了一下发现是这样的:读入的图片有些是灰度图(1个通道),绝大多数是RGB图片(3通道),也有些是带透明度的(4通道)
。这导致在读入后最后一个维度(通道数)不一致(可能是1、3或者4)。
Dataloader在制作batch data时,tensor的shape必须一样,就报了这个错误。解决的方法是:img = img.convert(“RGB”)。完
整代码如下:
class psDataset(Dataset):
def __init__(self, x, y, transforms = None):
super(Dataset, self).__init__()
self.x = x
self.y = y
if transforms == None:
self.transforms = Compose([Resize((224, 224)), ToTensor()])
else:
self.transforms = transforms
def __len__(self):
return len(self.x)
def __getitem__(self, idx):
img = Image.open(self.x[idx])
img = img.convert("RGB")
img = self.transforms(img)
return img, torch.tensor([[self.y[idx]]])
来源:https://blog.csdn.net/xgbm_k/article/details/84067245


猜你喜欢
- 在接口测试学习过程中,遇到了利用requests库进行文件下载和上传的问题。同样,在真正的测试过程中,我们不可避免的会遇到上传和下载的测试。
- 本文实例为大家分享了python批量读取文件名并写入txt中的具体代码,供大家参考,具体内容如下先说下脚本使用的环境吧,在做项目的过程中需要
- 目录一,python介绍二.python的安装程序三、变量python基础部分学习一,python介绍python的创始人为吉多·范罗苏姆(
- asp之家注:长文章分页算是asp编程中一个比较经典单位问题,怎么分页,什么时候分页.方法挺多,有的是人为的加入分页标志,有的是程序自动加分
- 抢票脚本,python +splinter自动刷新抢票,可以成功抢到(依赖自己的网络环境太厉害,还有机器的好坏),但是感觉不是很完美。有大神
- 本文实例为大家分享了python比特币初始配置的具体代码,供大家参考,具体内容如下# -*- coding: utf-8 -*- "
- 我们知道一般的应用系统,读写比例在10:1左右,而且插入操作和一般的更新操作很少出现性能问题,遇到最多的,也是最容易出问题的,还是一些复杂的
- 前言注释可以起到一个备注的作用,团队合作的时候,个人编写的代码经常会被多人调用,为了让别人能更容易理解代码的通途,使用注释是非常有效的。Py
- 前两天看见有人问静态网页加密问题,就写了这个代码稍微有些长,解释一下思路:加密时:先把用户的密钥A用md5加密为B,然后用B异或源文件S0得
- import sysfrom PyQt5 import QtWidgetsfrom PyQt5.QtWidgets import QMain
- Blog Posts的提交让我们从简单的开始。首页上必须有一张用户提交新的post的表单。首先我们定义一个单域表单对象(fileapp/fo
- DreamWeaver 4的到来让大家兴奋吧?但是大家一定为DreamWeaver4里面的字体、文字大
- 实现一个不规则窗体这里我们实现一个圆形窗体,实现其他形状的窗体与这个方法类似。首先,把窗口的高度(height)和宽度(width)值修改为
- 做图像识别的时候需要在图片中画出特定大小和角度的矩形框,自己写了一个函数,给定的输入是图片名称,矩形框的位置坐标,长宽和角度,直接输出画好矩
- #sidebar div#live_chat a { background: url("scroll/live_chat1.jpg
- 检测这些圆,先找轮廓后通过轮廓点拟合椭圆import cv2import numpy as npimport matplotlib.pypl
- 1. 时间的表示Go 语言中时间的表示方式是通过 time.Time 结构体来表示的。time.Time 类型代表了一个时刻,它包含了年月日
- VSCode 必须安装以下插件:首先你必须安装 Golang 插件,然后再给 Go 安装工具包。在 VS Code 中,使用快捷键: com
- 在一个大型的项目中,不可避免会出现操作时间的业务,比如时间的格式化,比如时间的加减,我们一般会直接使用moment.js库来做,毕竟稳定可靠
- 相对于numpy、TensorFlow、pandas这些已经经过多年维护、迭代,对于大多数Python开发者耳熟能详的库不同。今天要给大家介