使用pytorch完成kaggle猫狗图像识别方式
作者:charsenhz 发布时间:2023-04-10 08:39:01
kaggle是一个为开发商和数据科学家提供举办机器学习竞赛、托管数据库、编写和分享代码的平台,在这上面有非常多的好项目、好资源可供机器学习、深度学习爱好者学习之用。
碰巧最近入门了一门非常的深度学习框架:pytorch,所以今天我和大家一起用pytorch实现一个图像识别领域的入门项目:猫狗图像识别。
深度学习的基础就是数据,咱们先从数据谈起。此次使用的猫狗分类图像一共25000张,猫狗分别有12500张,我们先来简单的瞅瞅都是一些什么图片。
我们从下载文件里可以看到有两个文件夹:train和test,分别用于训练和测试。以train为例,打开文件夹可以看到非常多的小猫图片,图片名字从0.jpg一直编码到9999.jpg,一共有10000张图片用于训练。
而test中的小猫只有2500张。仔细看小猫,可以发现它们姿态不一,有的站着,有的眯着眼睛,有的甚至和其他可识别物体比如桶、人混在一起。
同时,小猫们的图片尺寸也不一致,有的是竖放的长方形,有的是横放的长方形,但我们最终需要是合理尺寸的正方形。小狗的图片也类似,在这里就不重复了。
紧接着我们了解一下特别适用于图像识别领域的神经网络:卷积神经网络。学习过神经网络的同学可能或多或少地听说过卷积神经网络。这是一种典型的多层神经网络,擅长处理图像特别是大图像的相关机器学习问题。
卷积神经网络通过一系列的方法,成功地将大数据量的图像识别问题不断降维,最终使其能够被训练。CNN最早由Yann LeCun提出并应用在手写体识别上。
一个典型的CNN网络架构如下:
这是一个典型的CNN架构,由卷基层、池化层、全连接层组合而成。其中卷基层与池化层配合,组成多个卷积组,逐层提取特征,最终完成分类。
听到上述一连串的术语如果你有点蒙了,也别怕,因为这些复杂、抽象的技术都已经在pytorch中一一实现,我们要做的不过是正确的调用相关函数,
我在粘贴代码后都会做更详细、易懂的解释。
import os
import shutil
import torch
import collections
from torchvision import transforms,datasets
from __future__ import print_function, division
import os
import torch
import pylab
import pandas as pd
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from skimage import io, transform
import numpy as np
import matplotlib.pyplot as plt
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms, utils
# Ignore warnings
import warnings
warnings.filterwarnings("ignore")
plt.ion() # interactive mode
一个正常的CNN项目所需要的库还是蛮多的。
import math
from PIL import Image
class Resize(object):
"""Resize the input PIL Image to the given size.
Args:
size (sequence or int): Desired output size. If size is a sequence like
(h, w), output size will be matched to this. If size is an int,
smaller edge of the image will be matched to this number.
i.e, if height > width, then image will be rescaled to
(size * height / width, size)
interpolation (int, optional): Desired interpolation. Default is
``PIL.Image.BILINEAR``
"""
def __init__(self, size, interpolation=Image.BILINEAR):
# assert isinstance(size, int) or (isinstance(size, collections.Iterable) and len(size) == 2)
self.size = size
self.interpolation = interpolation
def __call__(self, img):
w,h = img.size
min_edge = min(img.size)
rate = min_edge / self.size
new_w = math.ceil(w / rate)
new_h = math.ceil(h / rate)
return img.resize((new_w,new_h))
这个称为Resize的库用于给图像进行缩放操作,本来是不需要亲自定义的,因为transforms.Resize已经实现这个功能了,但是由于目前还未知的原因,我的库里没有提供这个函数,所以我需要亲自实现用来代替transforms.Resize。
如果你的torch里面已经有了这个Resize函数就不用像我这样了。
data_transform = transforms.Compose([
Resize(84),
transforms.CenterCrop(84),
transforms.ToTensor(),
transforms.Normalize(mean = [0.5,0.5,0.5],std = [0.5,0.5,0.5])
])
train_dataset = datasets.ImageFolder(root = 'train/',transform = data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset,batch_size = 4,shuffle = True,num_workers = 4)
test_dataset = datasets.ImageFolder(root = 'test/',transform = data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset,batch_size = 4,shuffle = True,num_workers = 4)
transforms是一个提供针对数据(这里指的是图像)进行转化的操作库,Resize就是上上段代码提供的那个类,主要用于把一张图片缩放到某个尺寸,在这里我们把需求暂定为要把图像缩放到84 x 84这个级别,这个就是可供调整的参数,大家为部署好项目以后可以试着修改这个参数,比如改成200 x 200,你就发现你可以去玩一盘游戏了~_~。
CenterCrop用于从中心裁剪图片,目标是一个长宽都为84的正方形,方便后续的计算。
ToTenser()就比较重要了,这个函数的目的就是读取图片像素并且转化为0-1的数字。
Normalize作为垫底的一步也很关键,主要用于把图片数据集的数值转化为标准差和均值都为0.5的数据集,这样数据值就从原来的0到1转变为-1到1。
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(3,6,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(6,16,5)
self.fc1 = nn.Linear(16 * 18 * 18,800)
self.fc2 = nn.Linear(800,120)
self.fc3 = nn.Linear(120,2)
def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = x.view(-1,16 * 18 * 18)
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = self.fc3(x)
return x
net = Net()
来源:https://blog.csdn.net/xcszjjh1991/article/details/79256576


猜你喜欢
- 0、背景今天看到了一个比较诡异的写法,for后直接跟了else语句,起初还以为是没有缩进好,查询后发现果然有这种语法,特此分享。之前写过c+
- 前言本文将深入研究 preg_replace /e 模式下的代码执行问题,其中包括 preg_replace 函数的执行过程分析、正则表达式
- 它为什么是有用的? 作为一名JavaScript开发者,你可能经常发现自己处于代码覆盖可能有用的情景。例如:对测试套件的质量感兴趣? 重构一
- 一、Sql Server中的日期与时间函数 1. 当前系统日期、时间 select getdate() 2. dateadd 在向指定日期加
- 本实例的实现逻辑是,应用selenium UI自动化登录百度盘,读取存储百度分享地址和提取码的txt文档,打开百度盘分享地址,填入提取码,然
- MYSQL的事务处理主要有两种方法。 1、用begin,rollback,commit来实现 begin 开始一个事务 rollback 事
- Python快捷键相关设置,具体内容如下1、主题毫无疑问Pycharm是一个具有强大快捷键系统的IDE,这就意味着你在Pycharm中的任何
- 前言在使用传统物理机或云服务器上部署项目都会存在一些痛点比如:项目部署速度慢、资源浪费、迁移难且扩展低而使用 Docker 部署项目的优势包
- 引言软件开发经历了许多阶段,如需求收集和分析、设计、软件开发、测试和发布。测试是 SDLC 不可或缺的一部分,单元测试是一种可靠的测试类型。
- 前言Python语言的turtle库是一个直观有趣的图形绘制函数库,是python语言标准库之一。turtle库也叫海龟库,是turtle绘
- 浏览器的出现互联网的出现是人类信息交流方式的一次划时代的革命,在这场革命中有两个技术对互联网的发展起到了决定性的作用:一个技术带来的人类信息
- php var_dump 函数作用是判断一个变量的类型与长度,并输出变量的数值,如果变量有值输的是变量的值并回返数据类型.来看看var_du
- 备份还原数据库备份数据库企业管理器--或用SQL语句(完全备份):backup database 数据库 to
- 前言最近使用PyTorch感觉妙不可言,有种当初使用Keras的快感,而且速度还不慢。各种设计直接简洁,方便研究,比tensorflow的臃
- 大部分新手刚学Django开发的时候默认用的都是SQLite数据库,上线部署的时候,大多用的却是Mysql。那么我们应该如何把数据库从SQL
- 1.前言我在进行DEM数据的裁剪时,发现各个省的数据量非常大,比如说四川省的30m的DEM数据的大小为2G。考虑到有限的电脑磁盘空间,我对T
- python自定义异常实例详解 本文通过两种
- 一 按时间创建文件源码# 截图方式二# coding=utf-8import osimport time# 当前年月日时分秒时间 2020-
- 对于Python而言代码缩进是一种语法,Python没有像其他语言一样采用{}或者begin...end分隔代码块,而是采用代码缩进和冒号来
- 1. 表单框类型<!DOCTYPE html><html lang="en"><head&