pytorch实现图像识别(实战)
作者:AI?AX?AT 发布时间:2022-10-03 01:19:03
标签:pytorch,实现,图像,识别
1. 代码讲解
1.1 导库
import os.path
from os import listdir
import numpy as np
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn import AdaptiveAvgPool2d
from torch.utils.data.sampler import SubsetRandomSampler
from torch.utils.data import Dataset
import torchvision.transforms as transforms
from sklearn.model_selection import train_test_split
1.2 标准化、transform、设置GPU
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
normalize = transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
transform = transforms.Compose([transforms.ToTensor(), normalize]) # 转换
1.3 预处理数据
class DogDataset(Dataset):
# 定义变量
def __init__(self, img_paths, img_labels, size_of_images):
self.img_paths = img_paths
self.img_labels = img_labels
self.size_of_images = size_of_images
# 多少长图片
def __len__(self):
return len(self.img_paths)
# 打开每组图片并处理每张图片
def __getitem__(self, index):
PIL_IMAGE = Image.open(self.img_paths[index]).resize(self.size_of_images)
TENSOR_IMAGE = transform(PIL_IMAGE)
label = self.img_labels[index]
return TENSOR_IMAGE, label
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train')))
print(len(pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')))
print(len(listdir(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\test')))
train_paths = []
test_paths = []
labels = []
# 训练集图片路径
train_paths_lir = r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\train'
for path in listdir(train_paths_lir):
train_paths.append(os.path.join(train_paths_lir, path))
# 测试集图片路径
labels_data = pd.read_csv(r'C:\Users\AIAXIT\Desktop\DeepLearningProject\Deep_Learning_Data\dog-breed-identification\labels.csv')
labels_data = pd.DataFrame(labels_data)
# 把字符标签离散化,因为数据有120种狗,不离散化后面把数据给模型时会报错:字符标签过多。把字符标签从0-119编号
size_mapping = {}
value = 0
size_mapping = dict(labels_data['breed'].value_counts())
for kay in size_mapping:
size_mapping[kay] = value
value += 1
# print(size_mapping)
labels = labels_data['breed'].map(size_mapping)
labels = list(labels)
# print(labels)
print(len(labels))
# 划分训练集和测试集
X_train, X_test, y_train, y_test = train_test_split(train_paths, labels, test_size=0.2)
train_set = DogDataset(X_train, y_train, (32, 32))
test_set = DogDataset(X_test, y_test, (32, 32))
train_loader = torch.utils.data.DataLoader(train_set, batch_size=64)
test_loader = torch.utils.data.DataLoader(test_set, batch_size=64)
1.4 建立模型
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.features = nn.Sequential(
nn.Conv2d(in_channels=3, out_channels=6, kernel_size=5),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
nn.ReLU(),
nn.AvgPool2d(kernel_size=2, stride=2)
)
self.classifier = nn.Sequential(
nn.Linear(16 * 5 * 5, 120),
nn.ReLU(),
nn.Linear(120, 84),
nn.ReLU(),
nn.Linear(84, 120)
)
def forward(self, x):
batch_size = x.shape[0]
x = self.features(x)
x = x.view(batch_size, -1)
x = self.classifier(x)
return x
model = LeNet().to(device)
criterion = nn.CrossEntropyLoss().to(device)
optimizer = optim.Adam(model.parameters())
TRAIN_LOSS = [] # 损失
TRAIN_ACCURACY = [] # 准确率
1.5 训练模型
def train(epoch):
model.train()
epoch_loss = 0.0 # 损失
correct = 0 # 精确率
for batch_index, (Data, Label) in enumerate(train_loader):
# 扔到GPU中
Data = Data.to(device)
Label = Label.to(device)
output_train = model(Data)
# 计算损失
loss_train = criterion(output_train, Label)
epoch_loss = epoch_loss + loss_train.item()
# 计算精确率
pred = torch.max(output_train, 1)[1]
train_correct = (pred == Label).sum()
correct = correct + train_correct.item()
# 梯度归零、反向传播、更新参数
optimizer.zero_grad()
loss_train.backward()
optimizer.step()
print('Epoch: ', epoch, 'Train_loss: ', epoch_loss / len(train_set), 'Train correct: ', correct / len(train_set))
1.6 测试模型
和训练集差不多。
def test():
model.eval()
correct = 0.0
test_loss = 0.0
with torch.no_grad():
for Data, Label in test_loader:
Data = Data.to(device)
Label = Label.to(device)
test_output = model(Data)
loss = criterion(test_output, Label)
pred = torch.max(test_output, 1)[1]
test_correct = (pred == Label).sum()
correct = correct + test_correct.item()
test_loss = test_loss + loss.item()
print('Test_loss: ', test_loss / len(test_set), 'Test correct: ', correct / len(test_set))
1.7结果
epoch = 10
for n_epoch in range(epoch):
train(n_epoch)
test()
来源:https://blog.csdn.net/weixin_45758642/article/details/119764959


猜你喜欢
- 一、前言很多网站提供视频转GIF的功能,但要么收费要么有广告实际上我们通过python,几行代码就能够实现视频转gif二、教程1. 安装必备
- 本文实例为大家分享了python3.4函数操作mysql数据库的具体代码,供大家参考,具体内容如下#!/usr/bin/env python
- 一、特效预览处理前处理后细节放大后二、程序原理将图片所在的 256 的灰度映射到相应的字符上面也就是 RGB 值转成相应的字符然后再将字符其
- Python中赋值的含义在C++中,变量就是对象本身,对变量赋值就改变了它代表的对象。而在Python中,赋值的含义却是关联变量名字和实际对
- 目录1. python内置方法(read、readline、readlines)2. 内置模块(csv)3. 使用numpy库(loadtx
- 本文实例为大家分享了python多线程同时接受和发的具体代码,供大家参考,具体内容如下'''模仿qq 同时可以发送信
- Python文件遍历os.walk()与os.listdir()在图片处理过程中,样本数据的组织是个常见的问题,样本组织好了,后面数据转换、
- SQL Server定位于中型的数据库应用,操作较Oracle和MySQL等要相对简便,SQL Server在处理海量数据的效率,后台开发的
- 首先去官网下载两个架包链接如下:官网链接第一步:将两个架包解压到同一个database目录下。如截图所示:第二步:打开setup应用程序打开
- pycharm2021激活码是一个可以轻松帮助用户免费激活pycharm2021.1软件的文件,虽然说pycharm现在只是推出了2021.
- 1.思路在网上查找了半天,基本都是提取word中文字的,没有找到可以把word中的图片提取出来的方法。一个巧合的情况下,发现将word的后缀
- 存储过程与编码MySQL 存储过程中, 表和数据的编码与数据库和存储过程默认的编码不同则可能出现 sql 不会使用索引的情况, 因为 MyS
- 声明本文章为个人拙见,仅仅提供参考,不一定正确,各位大佬可以发表自己的意见。题目描述考虑到在虚拟机部署中资源提供商通常希望自己的收益最大化,
- 前言:我目前使用的服务器为centos6.x 系统自带的python的版本为2.6.x,但是目前无论是学习还是使用python,python
- 代码如下所示:表landundertake结构如下所示:表appraiser结构如下所示:access代码:TRANSFORM First(
- 关于ref和$refs的用法及讲解,vue.js中文社区( https://cn.vuejs.org/v2/api/#ref )是这么讲解的
- 在上篇文章给大家介绍过Django 多环境配置详解,感兴趣的朋友可以点击查阅,今天继续给大家介绍django 多环境配置的相关内容,本文重点
- 我们可以用鼠标把Dreamweaver的层在页面内拖动,但要全屏拖动就困难了,下面是一种实现的方法:制作步骤:一、准备图片,取名/file/
- 这次讨论一下关于select元素的一个问题,其实很早以前我就碰到过关于select元素的问题,这次做网站又被问到同样的问题,就是:一般div
- 在Python 3.10发布之前,Python是没有类似于其他语言中switch语句的,要实现类似的功能最简单的方法就是通过if ... e