pytorch深度神经网络入门准备自己的图片数据
作者:denny402 发布时间:2023-12-07 13:55:58
标签:pytorch,图片数据,数据准备,深度神经网络
图片数据一般有两种情况:
1、所有图片放在一个文件夹内,另外有一个txt文件显示标签。
2、不同类别的图片放在不同的文件夹内,文件夹就是图片的类别。
针对这两种不同的情况,数据集的准备也不相同,第一种情况可以自定义一个Dataset,第二种情况直接调用torchvision.datasets.ImageFolder来处理。下面分别进行说明:
一、所有图片放在一个文件夹内
这里以mnist数据集的10000个test为例, 我先把test集的10000个图片保存出来,并生着对应的txt标签文件。
先在当前目录创建一个空文件夹mnist_test, 用于保存10000张图片,接着运行代码:
import torch
import torchvision
import matplotlib.pyplot as plt
from skimage import io
mnist_test= torchvision.datasets.MNIST(
'./mnist', train=False, download=True
)
print('test set:', len(mnist_test))
f=open('mnist_test.txt','w')
for i,(img,label) in enumerate(mnist_test):
img_path="./mnist_test/"+str(i)+".jpg"
io.imsave(img_path,img)
f.write(img_path+' '+str(label)+'\n')
f.close()
经过上面的操作,10000张图片就保存在mnist_test文件夹里了,并在当前目录下生成了一个mnist_test.txt的文件,大致如下:
前期工作就装备好了,接着就进入正题了:
from torchvision import transforms, utils
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
from PIL import Image
def default_loader(path):
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
train_data=MyDataset(txt='mnist_test.txt', transform=transforms.ToTensor())
data_loader = DataLoader(train_data, batch_size=100,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
grid = utils.make_grid(imgs)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(),batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()
自定义了一个MyDataset, 继承自torch.utils.data.Dataset。然后利用torch.utils.data.DataLoader将整个数据集分成多个批次。
二、不同类别的图片放在不同的文件夹内
同样先准备数据,这里以flowers数据集为例
提取 链接: https://pan.baidu.com/s/1dcAsOOZpUfWNYR77JGXPHA?pwd=mwg6
花总共有五类,分别放在5个文件夹下。大致如下图:
我的路径是d:/flowers/.
数据准备好了,就开始准备Dataset吧,这里直接调用torchvision里面的ImageFolder
import torch
import torchvision
from torchvision import transforms, utils
import matplotlib.pyplot as plt
img_data = torchvision.datasets.ImageFolder('D:/bnu/database/flower',
transform=transforms.Compose([
transforms.Scale(256),
transforms.CenterCrop(224),
transforms.ToTensor()])
)
print(len(img_data))
data_loader = torch.utils.data.DataLoader(img_data, batch_size=20,shuffle=True)
print(len(data_loader))
def show_batch(imgs):
grid = utils.make_grid(imgs,nrow=5)
plt.imshow(grid.numpy().transpose((1, 2, 0)))
plt.title('Batch from dataloader')
for i, (batch_x, batch_y) in enumerate(data_loader):
if(i<4):
print(i, batch_x.size(), batch_y.size())
show_batch(batch_x)
plt.axis('off')
plt.show()
来源:https://www.cnblogs.com/denny402/p/7512516.html


猜你喜欢
- Birdseye是一个Python调试器,它在函数调用中记录表达式的值,并让你在函数退出后轻松查看它们,例如:无论你如何运行或编辑代码,都可
- 1.首先检查自己的环境变量是否配置正确点击setting 点击 Python Interpreter点击Add Interpret
- xlwt工具使用,生成excel栏位宽度可自适应内容长度import xlwtresult = [
- 1 监听启动activity 信息命令adb shell logcat | grep START 可以查看apk包名和Activity名字=
- 前言汉诺塔问题是一个经典的问题。汉诺塔(Hanoi Tower),又称河内塔,源于印度一个古老传说。大梵天创造世界的时候做了三根金刚石柱子,
- 从技术上来说,在ASP环境中,读入并管理XML文本的主要方法有三种: 创建MSXML对象,并且将XML文档载入DOM; 使用服务器端嵌入(S
- 尝试了几种方法,感觉过于复杂,于是自己写了一个方法。(1)首先在要绘图的页面传入从数据库中提取的参数,这一步通过views可以实现;(2)然
- 对于变量的访问和设置,我们可以使用get、set方法,如下:class student: def __init__(self,n
- 什么是运行时配置(Runtime Configuration,rc)Matplotlib使用matplotlibrc配置文件来自定义图形的各
- 使用zap接收gin框架默认的日志并配置日志归档我们在基于gin框架开发项目时通常都会选择使用专业的日志库来记录项目中的日志,go语言常用的
- 这只是自己练习的一个记录而已。因为某种原因,不想用yii自带的user表,想用自己建的admin数据库表,修改如下:1. 参考高级模板里里的
- radians()方法把角度转化为弧度角x。语法以下是radians()方法的语法:radians(x)注意:此函数是无法直接访
- 本文为大家分享了Python多线程聊天室,是一个Socket,两个线程,一个是服务器,一个是客户端。 最近公司培训,要写个大富翁的小程序,准
- Mysql存储过程1.创建存储过程语法(格式)DELIMITER $CREATE PROCEDURE 存储过程名A(IN 传入参数名a IN
- 当数据库数据量涨到一定数量时,性能就成为我们不能不关注的问题,如何优化呢? 常用的方式不外乎那么几种:1、分表,即把一个很大的表达数据分到几
- 目录Show Me The Code测试下效果效果PS另一种方法Show Me The CodeHTMLElement.prototype.
- 有关函数HashBytes请参考:http://msdn.microsoft.com/en-us/library/ms174415.aspx
- 我就废话不多说了,还是直接上代码吧! url = "http://%s:%s/api-token-auth/" % (i
- 在python中加背景音乐的方法:1、导入pygame资源包;2、修改音乐的file路径;3、使用init()方法进行初始化;4、使用loa
- 下一代的 web 已经开始上路了,就在这个星期,MySpace 集成了 Google Gears,雅虎发布了新的 BrowserPlus,G