基于PyTorch实现一个简单的CNN图像分类器
作者:SY_curry 发布时间:2021-08-16 21:52:31
目录
一. 加载数据
1. 继承Dataset类并重写关键方法
2. 使用Dataloader加载数据
二. 模型设计
三. 训练
四. 测试
结语
pytorch中文网:https://www.pytorchtutorial.com/
pytorch官方文档:https://pytorch.org/docs/stable/index.html
一. 加载数据
Pytorch的数据加载一般是用torch.utils.data.Dataset与torch.utils.data.Dataloader两个类联合进行。我们需要继承Dataset来定义自己的数据集类,然后在训练时用Dataloader加载自定义的数据集类。
1. 继承Dataset类并重写关键方法
pytorch的dataset类有两种:Map-style datasets和Iterable-style datasets。前者是我们常用的结构,而后者是当数据集难以(或不可能)进行随机读取时使用。在这里我们实现Map-style dataset。
继承torch.utils.data.Dataset后,需要重写的方法有:__len__与__getitem__方法,其中__len__方法需要返回所有数据的数量,而__getitem__则是要依照给出的数据索引获取对应的tensor类型的Sample,除了这两个方法以外,一般还需要实现__init__方法来初始化一些变量。话不多说,直接上代码。
'''
包括了各种数据集的读取处理,以及图像相关处理方法
'''
from torch.utils.data import Dataset
import torch
import os
import cv2
from Config import mycfg
import random
import numpy as np
class ImageClassifyDataset(Dataset):
def __init__(self, imagedir, labelfile, classify_num, train=True):
'''
这里进行一些初始化操作。
'''
self.imagedir = imagedir
self.labelfile = labelfile
self.classify_num = classify_num
self.img_list = []
# 读取标签
with open(self.labelfile, 'r') as fp:
lines = fp.readlines()
for line in lines:
filepath = os.path.join(self.imagedir, line.split(";")[0].replace('\\', '/'))
label = line.split(";")[1].strip('\n')
self.img_list.append((filepath, label))
if not train:
self.img_list = random.sample(self.img_list, 50)
def __len__(self):
return len(self.img_list)
def __getitem__(self, item):
'''
这个函数是关键,通过item(索引)来取数据集中的数据,
一般来说在这里才将图像数据加载入内存,之前存的是图像的保存路径
'''
_int_label = int(self.img_list[item][1])# label直接用0,1,2,3,4...表示不同类别
label = torch.tensor(_int_label,dtype=torch.long)
img = self.ProcessImgResize(self.img_list[item][0])
return img, label
def ProcessImgResize(self, filename):
'''
对图像进行一些预处理
'''
_img = cv2.imread(filename)
_img = cv2.resize(_img, (mycfg.IMG_WIDTH, mycfg.IMG_HEIGHT), interpolation=cv2.INTER_CUBIC)
_img = _img.transpose((2, 0, 1))
_img = _img / 255
_img = torch.from_numpy(_img)
_img = _img.to(torch.float32)
return _img
有一些的数据集类一般还会传入一个transforms函数来构造一个图像预处理序列,传入transforms函数的一个好处是作为参数传入的话可以对一些非本地数据集中的数据进行操作(比如直接通过torchvision获取的一些预存数据集CIFAR10等等),除此之外就是torchvision.transforms里面有一些预定义的图像操作函数,可以直接像拼积木一样拼成一个图像处理序列,很方便。我这里因为是用我自己下载到本地的数据集,而且比较简单就直接用自己的函数来操作了。
2. 使用Dataloader加载数据
实例化自定义的数据集类ImageClassifyDataset后,将其传给DataLoader作为参数,得到一个可遍历的数据加载器。可以通过参数batch_size控制批处理大小,shuffle控制是否乱序读取,num_workers控制用于读取数据的线程数量。
from torch.utils.data import DataLoader
from MyDataset import ImageClassifyDataset
dataset = ImageClassifyDataset(imagedir, labelfile, 10)
dataloader = DataLoader(dataset, batch_size=5, shuffle=True,num_workers=5)
for index, data in enumerate(dataloader):
print(index)# batch索引
print(data)# 一个batch的{img,label}
二. 模型设计
在这里只讨论深度学习模型的设计,pytorch中的网络结构是一层一层叠出来的,pytorch中预定义了许多可以通过参数控制的网络层结构,比如Linear、CNN、RNN、Transformer等等具体可以查阅官方文档中的torch.nn部分。
设计自己的模型结构需要继承torch.nn.Module这个类,然后实现其中的forward方法,一般在__init__中设定好网络模型的一些组件,然后在forward方法中依据输入输出顺序拼装组件。
'''
包括了各种模型、自定义的loss计算方法、optimizer
'''
import torch.nn as nn
class Simple_CNN(nn.Module):
def __init__(self, class_num):
super(Simple_CNN, self).__init__()
self.class_num = class_num
self.conv1 = nn.Sequential(
nn.Conv2d(# input: 3,400,600
in_channels=3,
out_channels=8,
kernel_size=5,
stride=1,
padding=2
),
nn.Conv2d(
in_channels=8,
out_channels=16,
kernel_size=5,
stride=1,
padding=2
),
nn.AvgPool2d(2), # 16,400,600 --> 16,200,300
nn.BatchNorm2d(16),
nn.LeakyReLU(),
nn.Conv2d(
in_channels=16,
out_channels=16,
kernel_size=5,
stride=1,
padding=2
),
nn.Conv2d(
in_channels=16,
out_channels=8,
kernel_size=5,
stride=1,
padding=2
),
nn.AvgPool2d(2), # 8,200,300 --> 8,100,150
nn.BatchNorm2d(8),
nn.LeakyReLU(),
nn.Conv2d(
in_channels=8,
out_channels=8,
kernel_size=3,
stride=1,
padding=1
),
nn.Conv2d(
in_channels=8,
out_channels=1,
kernel_size=3,
stride=1,
padding=1
),
nn.AvgPool2d(2), # 1,100,150 --> 1,50,75
nn.BatchNorm2d(1),
nn.LeakyReLU()
)
self.line = nn.Sequential(
nn.Linear(
in_features=50 * 75,
out_features=self.class_num
),
nn.Softmax()
)
def forward(self, x):
x = self.conv1(x)
x = x.view(-1, 50 * 75)
y = self.line(x)
return y
上面我定义的模型中包括卷积组件conv1和全连接组件line,卷积组件中包括了一些卷积层,一般是按照{卷积层、池化层、激活函数}的顺序拼接,其中我还在激活函数之前添加了一个BatchNorm2d层对上层的输出进行正则化以免传入激活函数的值过小(梯度消失)或过大(梯度 * )。
在拼接组件时,由于我全连接层的输入是一个一维向量,所以需要将卷积组件中最后的50 × 75 50\times 7550×75大小的矩阵展平成一维的再传入全连接层(x.view(-1,50*75))
三. 训练
实例化模型后,网络模型的训练需要定义损失函数与优化器,损失函数定义了网络输出与标签的差距,依据不同的任务需要定义不同的合适的损失函数,而优化器则定义了神经网络中的参数如何基于损失来更新,目前神经网络最常用的优化器就是SGD(随机梯度下降算法) 及其变种。
在我这个简单的分类器模型中,直接用的多分类任务最常用的损失函数CrossEntropyLoss()以及优化器SGD。
self.cnnmodel = Simple_CNN(mycfg.CLASS_NUM)
self.criterion = nn.CrossEntropyLoss()# 交叉熵,标签应该是0,1,2,3...的形式而不是独热的
self.optimizer = optim.SGD(self.cnnmodel.parameters(), lr=mycfg.LEARNING_RATE, momentum=0.9)
训练过程其实很简单,使用dataloader依照batch读出数据后,将input放入网络模型中计算得到网络的输出,然后基于标签通过损失函数计算Loss,并将Loss反向传播回神经网络(在此之前需要清理上一次循环时的梯度),最后通过优化器更新权重。训练部分代码如下:
for each_epoch in range(mycfg.MAX_EPOCH):
running_loss = 0.0
self.cnnmodel.train()
for index, data in enumerate(self.dataloader):
inputs, labels = data
outputs = self.cnnmodel(inputs)
loss = self.criterion(outputs, labels)
self.optimizer.zero_grad()# 清理上一次循环的梯度
loss.backward()# 反向传播
self.optimizer.step()# 更新参数
running_loss += loss.item()
if index % 200 == 199:
print("[{}] loss: {:.4f}".format(each_epoch, running_loss/200))
running_loss = 0.0
# 保存每一轮的模型
model_name = 'classify-{}-{}.pth'.format(each_epoch,round(all_loss/all_index,3))
torch.save(self.cnnmodel,model_name)# 保存全部模型
四. 测试
测试和训练的步骤差不多,也就是读取模型后通过dataloader获取数据然后将其输入网络获得输出,但是不需要进行反向传播的等操作了。比较值得注意的可能就是准确率计算方面有一些小技巧。
acc = 0.0
count = 0
self.cnnmodel = torch.load('mymodel.pth')
self.cnnmodel.eval()
for index, data in enumerate(dataloader_eval):
inputs, labels = data # 5,3,400,600 5,10
count += len(labels)
outputs = cnnmodel(inputs)
_,predict = torch.max(outputs, 1)
acc += (labels == predict).sum().item()
print("[{}] accurancy: {:.4f}".format(each_epoch, acc / count))
我这里采用的是保存全部模型并加载全部模型的方法,这种方法的好处是在使用模型时可以完全将其看作一个黑盒,但是在模型比较大时这种方法会很费事。此时可以采用只保存参数不保存网络结构的方法,在每一次使用模型时需要读取参数赋值给已经实例化的模型:
torch.save(cnnmodel.state_dict(), "my_resnet.pth")
cnnmodel = Simple_CNN()
cnnmodel.load_state_dict(torch.load("my_resnet.pth"))
结语
至此整个流程就说完了,是一个小白级的图像分类任务流程,因为前段时间一直在做android方面的事,所以有点生疏了,就写了这篇博客记录一下,之后应该还会写一下seq2seq以及image caption任务方面的模型构造与训练过程,完整代码之后也会统一放到github上给大家做参考。
来源:https://blog.csdn.net/qq_34392457/article/details/113748534
猜你喜欢
- 【名称】Abs【类别】数学函数【原形】Abs(number)【参数】必选的。Number参数是一个任何有效的数值型表达式【返回值】同numb
- 1、程序执行代码:#Author by Andy#_*_ coding:utf-8 _*_import os,sys,timeBase_di
- 我们知道,一般的关系数据库(如SQL Server、Oracle、Access等)中的查询操作是支持集合操作的,例如可以用“Update A
- IE的有条件注释是一种专有的(因此是非标准的)、对常规(X)HTML注释的Miscrosoft扩展。顾名思义,有条件注释使你能够根据条件(比
- 目录准备数据集导入所需的软件包将数据从文件加载到Python变量拆分数据进行训练和测试标记化并准备词汇预处理输出标签/类建立Keras模型并
- 1. 将Oracle 10g client安装包copy到本地才能安装:2. 双击setup 的到:3. 稍后进入安装界面:4. 选择下一步
- 本文实例为大家分享了python+logging+yaml实现日志分割的具体代码,供大家参考,具体内容如下1、建立log.yaml文件ver
- 我们知道map() 会根据提供的函数对指定序列做映射。 第一个参数 function 以参数序列中的每一个元素调用 function函数,返
- Numpy是什么很简单,Numpy是Python的一个科学计算的库,提供了矩阵运算的功能,其一般与Scipy、matplotlib一起使用。
- 采集中 或者 在线添加文章中 都可以用到此功能俺自己在baidu上搜索的保存远程图片到本地的代码 感觉比较难用点 而且没有现成的比较全的代码
- 一、概念1、模块化代码可以使代码易于维护和调试,并且提高代码的重用性;2、函数可以用来减少冗余的代码并提高代码的可重用性。函数也可以用来模块
- 一 描述561. 数组拆分 I - 力扣(LeetCode) (leetcode-cn.com)给定长度为 2n 的整数
- php从5.2.x升级到5.3.2.出来问题了。有些原来能用的程序报错了。报错内容是Deprecated: Function session
- WebSocket的作用WebSock其实在平常使用,我们是时常见到的,用于实时通讯,例如我们常用的实时聊天、服务端向客户端消息推送、也可以
- python一行输入n个数据有时会碰到一行输入多个数据,这是可以先用str类型存一组数据,然后再迭代的将每个数据追加到新的列表中。方法一先输
- PHP mysqli_sqlstate() 函数返回最后一个 MySQL 操作的 SQLSTATE 错误代码:<?php// 假定数据
- 阅读是在网站中的一个很重要的部分,可以说是网站的核心。网站最终要呈现给用户的就是内容。尤其是文本内容。豆瓣豆瓣前段时间小改了一下,页面拉宽,
- 去掉html中的table代码 Function OutTable(str) dim a,re&nb
- 一、简单介绍正则表达式是一种小型的、高度专业化的编程语言,并不是python * 有的,是许多编程语言中基础而又重要的一部分。在python中
- Python时间格式化的时候,去掉前导0的:dt = datetime.now() print dt.strftime('%-H