pytorch 数据处理:定义自己的数据集合实例
作者:蓝鲸123 发布时间:2021-09-11 06:10:21
标签:pytorch,定义,数据,集合
数据处理
版本1
#数据处理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
#定义自己的数据集合
class DogCat(data.Dataset):
def __init__(self,root):
#所有图片的绝对路径
imgs=os.listdir(root)
self.imgs=[os.path.join(root,k) for k in imgs]
def __getitem__(self, index):
img_path=self.imgs[index]
#dog-> 1 cat ->0
label=1 if 'dog' in img_path.split('/')[-1] else 0
pil_img=Image.open(img_path)
array=np.asarray(pil_img)
data=torch.from_numpy(array)
return data,label
def __len__(self):
return len(self.imgs)
dataSet=DogCat('./data/dogcat')
print(dataSet[0])
输出:
(
( 0 ,.,.) =
215 203 191
206 194 182
211 199 187
⋮
200 191 186
201 192 187
201 192 187
( 1 ,.,.) =
215 203 191
208 196 184
213 201 189
⋮
198 189 184
200 191 186
201 192 187
( 2 ,.,.) =
215 201 188
209 195 182
214 200 187
⋮
200 191 186
202 193 188
204 195 190
…
(399,.,.) =
72 90 32
88 106 48
38 56 0
⋮
158 161 106
87 85 36
105 98 52
[torch.ByteTensor of size 400x300x3]
, 1)
上面的数据处理有下面的问题:
1.返回的样本的形状大小不一致,每一张图片的大小不一样。这对于需要batch训练的神经网络来说很不友好。
2. 返回的数据样本数值很大,没有归一化【-1,1】
对于上面的问题,pytorch torchvision 是一个视觉化的工具包,提供了很多的图像处理的工具,其中transforms模块提供了对PIL image对象和Tensor对象的常用操作。
对PIL Image常见的操作如下;
Resize 调整图片的尺寸,长宽比保持不变
CentorCrop ,RandomCrop,RandomSizeCrop 裁剪图片
Pad 填充
ToTensor 将PIL Image 转换为Tensor,会自动将[0,255] 归一化至[0,1]
对Tensor 的操作如下:
Normalize 标准化,即减均值,除以标准差
ToPILImage 将Tensor转换为 PIL Image对象
版本2
#数据处理
import os
import torch
from torch.utils import data
from PIL import Image
import numpy as np
from torchvision import transforms
transform=transforms.Compose([
transforms.Resize(224), #缩放图片,保持长宽比不变,最短边的长为224像素,
transforms.CenterCrop(224), #从中间切出 224*224的图片
transforms.ToTensor(), #将图片转换为Tensor,归一化至[0,1]
transforms.Normalize(mean=[.5,.5,.5],std=[.5,.5,.5]) #标准化至[-1,1]
])
#定义自己的数据集合
class DogCat(data.Dataset):
def __init__(self,root):
#所有图片的绝对路径
imgs=os.listdir(root)
self.imgs=[os.path.join(root,k) for k in imgs]
self.transforms=transform
def __getitem__(self, index):
img_path=self.imgs[index]
#dog-> 1 cat ->0
label=1 if 'dog' in img_path.split('/')[-1] else 0
pil_img=Image.open(img_path)
if self.transforms:
data=self.transforms(pil_img)
else:
pil_img=np.asarray(pil_img)
data=torch.from_numpy(pil_img)
return data,label
def __len__(self):
return len(self.imgs)
dataSet=DogCat('./data/dogcat')
print(dataSet[0])
输出:
(
( 0 ,.,.) =
-0.1765 -0.2627 -0.1686 … -0.0824 -0.2000 -0.2627
-0.2392 -0.3098 -0.3176 … -0.2863 -0.2078 -0.1765
-0.3176 -0.2392 -0.2784 … -0.2941 -0.1137 -0.0118
… ⋱ …
-0.7569 -0.5922 -0.1529 … -0.8510 -0.8196 -0.8353
-0.8353 -0.7255 -0.3255 … -0.8275 -0.8196 -0.8588
-0.9373 -0.7647 -0.4510 … -0.8196 -0.8353 -0.8824
( 1 ,.,.) =
-0.0431 -0.1373 -0.0431 … 0.0118 -0.0980 -0.1529
-0.0980 -0.1686 -0.1765 … -0.1608 -0.0745 -0.0431
-0.1686 -0.0902 -0.1373 … -0.1451 0.0431 0.1529
… ⋱ …
-0.5529 -0.3804 0.0667 … -0.7961 -0.7725 -0.7961
-0.6314 -0.5137 -0.1137 … -0.7804 -0.7882 -0.8275
-0.7490 -0.5608 -0.2392 … -0.7725 -0.8039 -0.8588
…
[torch.FloatTensor of size 3x224x224]
, 1)
项目的github地址:https://github.com/WebLearning17/CommonTool
来源:https://blog.csdn.net/TH_NUM/article/details/80877196
0
投稿
猜你喜欢
- strip_tags定义和用法strip_tags() 函数剥去字符串中的 HTML、XML 以及 PHP 的标签。注释:该函数始终会剥离
- 本文实例讲述了python实现定时同步本机与北京时间的方法。分享给大家供大家参考。具体如下:这段python代码首先从www.beijing
- 刚开始,根据我的想法,这个很简单嘛,上sql语句delete from zqzrdp where tel in (select min(dp
- 目前广泛使用的图像分类数据集之一是MNIST数据集。如今,MNIST数据集更像是一个健全的检查,而不是一个基准。为了提高难度,我们将在接下来
- 本文主要包括三大方面,大家仔细学习。1、导航栏中的表单导航栏中的表单不是使用 Bootstrap 表单 章节中所讲到的默认的 class,它
- 正在看的ORACLE教程是:Access2000迁移到Oracle9i要点。 &nb
- 谷歌的potobuf不说了,它很牛B,但是对客户端对象不支持,比如JavaScript就读取不了。Jil很牛,比Newtonsoft.Jso
- 本文实例讲述了Python使用Pandas库常见操作。分享给大家供大家参考,具体如下:1、概述Pandas 是Python的核心数据分析支持
- 说起来惭愧,总是犯一些小错误,纠结半天,这不应为一个分号的玩意折腾了好半天! 错误时在执行SQL语句的时候发出的,信息如下: Java代码
- 本文实例讲述了Python微信企业号开发之回调模式接收微信端客户端发送消息及被动返回消息。分享给大家供大家参考,具体如下:说明:此代码用于接
- 英文文档:locals()Update and return a dictionary representing the current l
- 一、什么是匿名函数?在Javascript定义一个函数一般有如下三种方式:函数关键字(function)语句:function f
- Python2的字符串有两种:str和Unicode,Python3的字符串也有两种:str和Bytes。Python2的str相当于Pyt
- Idea 2020 发布之后,官方终于支持了中文语言包但是,我下载后在插件市场无法找到官方的汉化包那要怎么解决这个问题呢?首先,查看你当前I
- 本文实例讲述了python获取从命令行输入数字的方法。分享给大家供大家参考。具体如下:#--------------------------
- console.log,作为一个前端开发者,可能每天都会用它来分析调试,但这个简单函数背后不简单那一面,你未必知道……基础首先,简单科普这个
- 数据库相关错误的解决办法错误一:数据库连接池超过限制SqlAlchemy QueuePool limit overflow造成连接数超过数据
- 1.今天网上下载一个博客项目,发现本地访问,js,css加载不了.我想应该是项目上线的安全措施,但是我想调试项目.找到方法如下在settin
- 网页中使用flash可以增强页面的动态交互效果,特别是用flash来制作广告,效果更好。经常使用flash的人,可能就碰到了flash会遮住
- Altova 公司的 XMLSPY 是个不可多得的好东西,它几乎可以开发所有的 XML 产品。最近用它来做 Schema