pytorch 准备、训练和测试自己的图片数据的方法
作者:denny402 发布时间:2021-02-27 13:54:35
标签:pytorch,准备,训练,测试
大部分的pytorch入门教程,都是使用torchvision里面的数据进行训练和测试。如果我们是自己的图片数据,又该怎么做呢?
一、我的数据
我在学习的时候,使用的是fashion-mnist。这个数据比较小,我的电脑没有GPU,还能吃得消。关于fashion-mnist数据,可以百度,也可以点此 了解一下,数据就像这个样子:
下载地址:https://github.com/zalandoresearch/fashion-mnist
但是下载下来是一种二进制文件,并不是图片,因此我先转换成了图片。
我先解压gz文件到e:/fashion_mnist/文件夹
然后运行代码:
import os
from skimage import io
import torchvision.datasets.mnist as mnist
root="E:/fashion_mnist/"
train_set = (
mnist.read_image_file(os.path.join(root, 'train-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 'train-labels-idx1-ubyte'))
)
test_set = (
mnist.read_image_file(os.path.join(root, 't10k-images-idx3-ubyte')),
mnist.read_label_file(os.path.join(root, 't10k-labels-idx1-ubyte'))
)
print("training set :",train_set[0].size())
print("test set :",test_set[0].size())
def convert_to_img(train=True):
if(train):
f=open(root+'train.txt','w')
data_path=root+'/train/'
if(not os.path.exists(data_path)):
os.makedirs(data_path)
for i, (img,label) in enumerate(zip(train_set[0],train_set[1])):
img_path=data_path+str(i)+'.jpg'
io.imsave(img_path,img.numpy())
f.write(img_path+' '+str(label)+'\n')
f.close()
else:
f = open(root + 'test.txt', 'w')
data_path = root + '/test/'
if (not os.path.exists(data_path)):
os.makedirs(data_path)
for i, (img,label) in enumerate(zip(test_set[0],test_set[1])):
img_path = data_path+ str(i) + '.jpg'
io.imsave(img_path, img.numpy())
f.write(img_path + ' ' + str(label) + '\n')
f.close()
convert_to_img(True)
convert_to_img(False)
这样就会在e:/fashion_mnist/目录下分别生成train和test文件夹,用于存放图片。还在该目录下生成了标签文件train.txt和test.txt.
二、进行CNN分类训练和测试
先要将图片读取出来,准备成torch专用的dataset格式,再通过Dataloader进行分批次训练。
代码如下:
import torch
from torch.autograd import Variable
from torchvision import transforms
from torch.utils.data import Dataset, DataLoader
from PIL import Image
root="E:/fashion_mnist/"
# -----------------ready the dataset--------------------------
def default_loader(path):
return Image.open(path).convert('RGB')
class MyDataset(Dataset):
def __init__(self, txt, transform=None, target_transform=None, loader=default_loader):
fh = open(txt, 'r')
imgs = []
for line in fh:
line = line.strip('\n')
line = line.rstrip()
words = line.split()
imgs.append((words[0],int(words[1])))
self.imgs = imgs
self.transform = transform
self.target_transform = target_transform
self.loader = loader
def __getitem__(self, index):
fn, label = self.imgs[index]
img = self.loader(fn)
if self.transform is not None:
img = self.transform(img)
return img,label
def __len__(self):
return len(self.imgs)
train_data=MyDataset(txt=root+'train.txt', transform=transforms.ToTensor())
test_data=MyDataset(txt=root+'test.txt', transform=transforms.ToTensor())
train_loader = DataLoader(dataset=train_data, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_data, batch_size=64)
#-----------------create the Net and training------------------------
class Net(torch.nn.Module):
def __init__(self):
super(Net, self).__init__()
self.conv1 = torch.nn.Sequential(
torch.nn.Conv2d(3, 32, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2))
self.conv2 = torch.nn.Sequential(
torch.nn.Conv2d(32, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.conv3 = torch.nn.Sequential(
torch.nn.Conv2d(64, 64, 3, 1, 1),
torch.nn.ReLU(),
torch.nn.MaxPool2d(2)
)
self.dense = torch.nn.Sequential(
torch.nn.Linear(64 * 3 * 3, 128),
torch.nn.ReLU(),
torch.nn.Linear(128, 10)
)
def forward(self, x):
conv1_out = self.conv1(x)
conv2_out = self.conv2(conv1_out)
conv3_out = self.conv3(conv2_out)
res = conv3_out.view(conv3_out.size(0), -1)
out = self.dense(res)
return out
model = Net()
print(model)
optimizer = torch.optim.Adam(model.parameters())
loss_func = torch.nn.CrossEntropyLoss()
for epoch in range(10):
print('epoch {}'.format(epoch + 1))
# training-----------------------------
train_loss = 0.
train_acc = 0.
for batch_x, batch_y in train_loader:
batch_x, batch_y = Variable(batch_x), Variable(batch_y)
out = model(batch_x)
loss = loss_func(out, batch_y)
train_loss += loss.data[0]
pred = torch.max(out, 1)[1]
train_correct = (pred == batch_y).sum()
train_acc += train_correct.data[0]
optimizer.zero_grad()
loss.backward()
optimizer.step()
print('Train Loss: {:.6f}, Acc: {:.6f}'.format(train_loss / (len(
train_data)), train_acc / (len(train_data))))
# evaluation--------------------------------
model.eval()
eval_loss = 0.
eval_acc = 0.
for batch_x, batch_y in test_loader:
batch_x, batch_y = Variable(batch_x, volatile=True), Variable(batch_y, volatile=True)
out = model(batch_x)
loss = loss_func(out, batch_y)
eval_loss += loss.data[0]
pred = torch.max(out, 1)[1]
num_correct = (pred == batch_y).sum()
eval_acc += num_correct.data[0]
print('Test Loss: {:.6f}, Acc: {:.6f}'.format(eval_loss / (len(
test_data)), eval_acc / (len(test_data))))
打印出来的网络模型:
训练和测试结果:
来源:https://www.cnblogs.com/denny402/p/7520063.html


猜你喜欢
- Mysql查询以某"字符串"开头的查询查询不以某个或者某些字符串为开头的字符串1、使用left()函数select *
- Python实现八大排序算法,具体内容如下1、插入排序描述插入排序的基本操作就是将一个数据插入到已经排好序的有序数据中,从而得到一个新的、个
- 前言Exception类是常用的异常类,该类包括StandardError,StopIteration, GeneratorExit, Wa
- 一、注意你的Python版本Python官方网站为http://www.python.org/,当前最新稳定版本为3.6.5,在3.0版本时
- vue-i18n在单文件js中使用示例import Vue from 'vue'import VueI18n from
- 单一数据读取方式:第一种:slice_input_producer()# 返回值可以直接通过 Session.run([images, la
- 前言官方手册:https://dev.mysql.com/doc/refman/5.7/en/server-logs.html不管是哪个数据
- 模块的的作用主要是用于字符串和文本处理,查找,搜索,替换等复习一下基本的正则表达式吧 .:匹配除了换行符以为的任意单个字符&nbs
- 本文实例讲述了Python实现合并同一个文件夹下所有txt文件的方法。分享给大家供大家参考,具体如下:一、需求分析合并一个文件夹下所有txt
- 很多网站注册时都会要求输入电子邮箱,其应用场景是比较广的,例如注册账号接收验证码、注册成功通知、登录通知、找回密码验证通知等。本文将介绍如何
- 指令和程序计算机的硬件系统通常由五大部件构成,包括:运算器、控制器、存储器、输入设备和输出设备。其中,运算器和控制器放在一起就是我们通常所说
- 1. 从官网下载 mysql-5.7.13-linux-glibc2.5-x86_64.tar.gz经测试, 本文还适用于如下版本:MySQ
- 前言如果你急需一个简单的Web Server,但你又不想去下载并安装那些复杂的HTTP服务程序,比如:Apache,ISS等。那么, Pyt
- 一:车辆识别成果展示二:车辆识别超详细步骤解析步骤一:灰度化处理灰度处理目的 RGB三通道转灰度单通道 压缩到原图片三分之一大小效果展示:【
- em 和 strong 的区别,可以从三个层次上来谈。首先看 HTML 4.01 中的说明:EM: Indicates emphasis.S
- asp压缩access数据库(带密码)方法:以下是一个类文件,下面的注解是调用类的方法 注意:如果系统不支持建立Scripting
- 本文实例讲述了从ThinkPHP3.2.3过渡到ThinkPHP5.0学习笔记。分享给大家供大家参考,具体如下:用tp3.2.3做了不少项目
- 前言首先,一个常见的问题是,ECMAScript 和 JavaScript 到底是什么关系?ECMAScript是一个国际通过的标准化脚本语
- 有时候会需要通过从保存下来的ckpt文件来观察其保存下来的训练完成的变量值。ckpt文件名列表:(一般是三个文件)xxxxx.ckpt.da
- mysql数据通过data文件恢复mysql磁盘文件被损坏,无法启动,能看到data文件,在没有备份的话如何复原?情景1:知道数据库中的表结