Pytorch 使用CNN图像分类的实现
作者:NULL 发布时间:2023-04-01 03:24:21
标签:Pytorch,CNN,图像分类
需求
在4*4的黑白图片中,比较外围(边缘一圈共12个像素)黑色像素点与内圈(中心2*2共4个像素)黑色像素点的个数,据此将图片分为两类
例如(见原文配图):若图片外围有5个黑色像素点,多于内圈的1个,则分为0类;反之(外围黑色像素点不多于内圈)分为1类
想法
通过numpy、PIL构造4*4的图像数据集
构造自己的数据集类
读取数据集,并对数据集进行筛选以减少类别偏斜
CNN设计:由于特征较少,可以直接使用1*1卷积层
或者在4*4图片外围填充一圈padding得到6*6,使用步长为2的2*2卷积核得到3*3的输出,再接上全连接层
代码
import torch
import torchvision
import torchvision.transforms as transforms
import numpy as np
from PIL import Image
构造数据集
import csv
import collections
import os
import shutil
def buildDataset(root, dataType, dataSize):
    """Generate a synthetic 4x4 black/white image dataset and write it to disk.

    Each image is a random 4x4 array of 0/1 values, rendered with 0 as black
    (0) and 1 as white (255). Its label is 0 when the 12 border pixels contain
    more black pixels than the central 2x2 block, otherwise 1.

    Images are written to ``{root}/{dataType}Data``; that directory is deleted
    and recreated on every call. A CSV of ``path,label`` rows is written to
    ``{root}/{dataType}DataInfo.csv``.

    Args:
        root: str
            Project directory.
        dataType: str
            'train' or 'test'; used as the directory/CSV name prefix.
        dataSize: int
            Number of images to generate.
    Returns:
        None
    """
    dataPath = f'{root}/{dataType}Data'
    # Start from an empty directory so images from a previous run cannot be
    # mixed with the freshly written CSV.
    if os.path.exists(dataPath):
        shutil.rmtree(dataPath)
    os.makedirs(dataPath)

    dataInfo = []
    for i in range(dataSize):
        # Random 0/1 pixels; 0 becomes a black pixel in the saved image.
        imageArray = np.random.randint(0, 2, (4, 4))
        allBlackNum = int((imageArray == 0).sum())
        innerBlackNum = int((imageArray[1:3, 1:3] == 0).sum())
        outerBlackNum = allBlackNum - innerBlackNum
        label = 0 if outerBlackNum > innerBlackNum else 1

        # PNG is lossless. JPEG compression can smear the hard 0/255 edges of
        # such a tiny image, and re-binarising the decoded pixels may then
        # flip some of them so the stored image no longer matches its label.
        path = f'{dataPath}/{i}.png'
        dataInfo.append([path, label])
        im = Image.fromarray(np.uint8(imageArray * 255)).convert('1')
        im.save(path)

    # newline='' is required by the csv module; without it, Windows gets an
    # extra blank line after every row.
    filePath = f'{root}/{dataType}DataInfo.csv'
    with open(filePath, 'w', newline='') as f:
        csv.writer(f).writerows(dataInfo)
# Project directory that holds the generated images and CSV files.
# NOTE(review): this is the author's machine-specific path; adjust before running.
root = '/Users/null/Documents/PythonProject/Classifier'
构造训练数据集
# 20,000 random training images.
buildDataset(root=root, dataType='train', dataSize=20_000)
构造测试数据集
# 10,000 random test images.
buildDataset(root=root, dataType='test', dataSize=10_000)
读取数据集
class MyDataset(torch.utils.data.Dataset):
    """Image dataset backed by a CSV of ``path,label`` rows.

    Args:
        root: directory containing the CSV file.
        datacsv: CSV file name, relative to ``root``.
        transform: optional callable applied to each loaded image; defaults
            to the identity.
    """

    def __init__(self, root, datacsv, transform=None):
        super().__init__()
        # Parse with the csv module so quoted fields (e.g. paths containing
        # commas) round-trip correctly, and skip blank rows such as those left
        # by a CSV written on Windows without newline=''.
        with open(f'{root}/{datacsv}', 'r', newline='') as f:
            rows = [row for row in csv.reader(f) if row]
        self.imgs = [(path, int(label)) for path, label in rows]
        self.transform = transform if transform is not None else lambda x: x

    def __getitem__(self, index):
        path, label = self.imgs[index]
        # Load as a 1-bit black/white image; the context manager closes the file.
        with Image.open(path) as im:
            img = self.transform(im.convert('1'))
        return img, label

    def __len__(self):
        return len(self.imgs)
# Build both splits from their CSV files, each with its own ToTensor transform.
trainData, testData = (
    MyDataset(root=root, datacsv=csvName, transform=transforms.ToTensor())
    for csvName in ('trainDataInfo.csv', 'testDataInfo.csv')
)
处理数据集使得数据集不偏斜
import itertools
def chooseData(dataset, scale):
    """Truncate ``dataset.imgs`` in place to reduce class skew.

    Assumes binary 0/1 labels. All label-1 samples are moved to the front and
    the list is cut to at most ``scale`` times the number of label-1 samples,
    so every label-1 sample is kept plus up to ``(scale - 1)`` label-0 samples
    per label-1 sample. Within each class the original order is preserved.
    If there are no label-1 samples the dataset becomes empty.

    Args:
        dataset: object with a mutable ``imgs`` list of ``(path, label)``
            pairs and a ``__len__`` (e.g. MyDataset).
        scale: int, maximum ratio of kept samples to label-1 samples.
    """
    # reverse=True keeps sort stability, so each class retains its order.
    dataset.imgs.sort(key=lambda item: item[1], reverse=True)
    # Count labels directly rather than flattening (path, label) tuples.
    trueNum = sum(1 for _, label in dataset.imgs if label == 1)
    end = min(trueNum * scale, len(dataset))
    dataset.imgs = dataset.imgs[:end]
# Keep at most `scale` samples per label-1 sample in each split.
scale = 4
chooseData(dataset=trainData, scale=scale)
chooseData(dataset=testData, scale=scale)
# Notebook-style display of the resulting split sizes.
(len(trainData), len(testData))
(2250, 1122)
import torch.utils.data as Data
# Hyperparameters
batchSize, lr, numEpochs = 50, 0.1, 20
# Training batches are reshuffled every epoch; test batches keep file order.
trainIter = Data.DataLoader(trainData, batch_size=batchSize, shuffle=True)
testIter = Data.DataLoader(testData, batch_size=batchSize)
定义模型
from torch import nn
from torch.autograd import Variable
from torch.nn import Module,Linear,Sequential,Conv2d,ReLU,ConstantPad2d
import torch.nn.functional as F
class Net(Module):
    """Padded 2x2 / stride-2 convolution followed by a linear classifier.

    Expects single-channel 4x4 inputs: a one-pixel border of constant 1
    makes them 6x6, a 2x2 kernel with stride 2 reduces that to 3x3, and the
    9 values are mapped to 2 class scores.
    """

    def __init__(self):
        super().__init__()
        # Layer order fixes the state_dict keys (cnnLayers.1.*, linearLayers.0.*).
        self.cnnLayers = Sequential(
            ConstantPad2d(1, 1),
            Conv2d(1, 1, kernel_size=2, stride=2, bias=True),
        )
        self.linearLayers = Sequential(Linear(9, 2))

    def forward(self, x):
        features = self.cnnLayers(x)
        flat = features.view(features.shape[0], -1)
        return self.linearLayers(flat)
class Net2(Module):
    """1x1 convolution, then ReLU and a linear classifier over all 16 pixels.

    Expects single-channel 4x4 inputs; the 1x1 convolution keeps the 4x4
    size, so the flattened 16 values (after ReLU) feed Linear(16, 2).
    """

    def __init__(self):
        super().__init__()
        # Layer order fixes the state_dict keys (cnnLayers.0.*, linearLayers.1.*).
        self.cnnLayers = Sequential(Conv2d(1, 1, kernel_size=1, stride=1, bias=True))
        self.linearLayers = Sequential(ReLU(), Linear(16, 2))

    def forward(self, x):
        scores = self.cnnLayers(x)
        return self.linearLayers(scores.view(scores.shape[0], -1))
定义损失函数
# Cross-entropy loss; each model gets its own instance.
loss, loss2 = nn.CrossEntropyLoss(), nn.CrossEntropyLoss()
定义优化算法
# Two models, each trained by its own plain SGD optimizer with the shared learning rate.
net, net2 = Net(), Net2()
optimizer = torch.optim.SGD(net.parameters(), lr=lr)
optimizer2 = torch.optim.SGD(net2.parameters(), lr=lr)
训练模型
# Accuracy helper used by train() below.
def evaluateAccuracy(dataIter, net):
    """Return the classification accuracy of ``net`` over ``dataIter``.

    Args:
        dataIter: iterable of ``(X, y)`` batches, ``y`` holding class indices.
        net: callable mapping ``X`` to per-class scores; the prediction is the
            argmax over dim 1.
    Returns:
        float: fraction of correct predictions, or 0.0 if ``dataIter`` yields
        no samples.
    """
    isModule = isinstance(net, torch.nn.Module)
    wasTraining = net.training if isModule else False
    # Evaluate in eval mode (matters for dropout/batch-norm layers) and
    # restore the caller's mode afterwards so training can continue.
    if isModule:
        net.eval()
    accSum, n = 0.0, 0
    try:
        with torch.no_grad():
            for X, y in dataIter:
                accSum += (net(X).argmax(dim=1) == y).float().sum().item()
                n += y.shape[0]
    finally:
        if isModule:
            net.train(wasTraining)
    return accSum / n if n else 0.0


def train(net, trainIter, testIter, loss, numEpochs, batchSize,
          optimizer):
    """Train ``net`` and print per-epoch loss and accuracy.

    For each epoch, prints the average per-sample training loss, the training
    accuracy (from predictions made before each optimizer step), and the test
    accuracy from evaluateAccuracy.

    Args:
        net: model to train.
        trainIter: iterable of ``(X, y)`` training batches.
        testIter: iterable of ``(X, y)`` test batches.
        loss: loss callable. If it has a ``reduction`` attribute of 'sum' or
            'none', its output is treated as per-sample terms; otherwise
            (e.g. CrossEntropyLoss's default 'mean') as a batch average.
        numEpochs: number of passes over ``trainIter``.
        batchSize: unused; kept for call-site compatibility (batching is done
            by ``trainIter``).
        optimizer: optimizer over ``net``'s parameters.
    """
    meanReduced = getattr(loss, 'reduction', 'mean') == 'mean'
    for epoch in range(numEpochs):
        if isinstance(net, torch.nn.Module):
            net.train()
        trainLossSum, trainAccSum, n = 0.0, 0.0, 0
        for X, y in trainIter:
            yHat = net(X)
            l = loss(yHat, y).sum()
            optimizer.zero_grad()
            l.backward()
            optimizer.step()
            batchLen = y.shape[0]
            # Accumulate the batch's *total* loss so dividing by n gives a
            # per-sample average; a mean-reduced loss is scaled back up.
            trainLossSum += (l.item() * batchLen) if meanReduced else l.item()
            trainAccSum += (yHat.argmax(dim=1) == y).sum().item()
            n += batchLen
        testAcc = evaluateAccuracy(testIter, net)
        denom = max(n, 1)  # avoid ZeroDivisionError on an empty trainIter
        print('epoch {:d}, loss {:.4f}, train acc {:.3f}, test acc {:.3f}'.format(epoch + 1, trainLossSum / denom, trainAccSum / denom, testAcc))
Net模型训练
# Train the padded-convolution model (Net).
train(net=net, trainIter=trainIter, testIter=testIter, loss=loss,
      numEpochs=numEpochs, batchSize=batchSize, optimizer=optimizer)
epoch 1, loss 0.0128, train acc 0.667, test acc 0.667
epoch 2, loss 0.0118, train acc 0.683, test acc 0.760
epoch 3, loss 0.0104, train acc 0.742, test acc 0.807
epoch 4, loss 0.0093, train acc 0.769, test acc 0.772
epoch 5, loss 0.0085, train acc 0.797, test acc 0.745
epoch 6, loss 0.0084, train acc 0.798, test acc 0.807
epoch 7, loss 0.0082, train acc 0.804, test acc 0.816
epoch 8, loss 0.0078, train acc 0.816, test acc 0.812
epoch 9, loss 0.0077, train acc 0.818, test acc 0.817
epoch 10, loss 0.0074, train acc 0.824, test acc 0.826
epoch 11, loss 0.0072, train acc 0.836, test acc 0.819
epoch 12, loss 0.0075, train acc 0.823, test acc 0.829
epoch 13, loss 0.0071, train acc 0.839, test acc 0.797
epoch 14, loss 0.0067, train acc 0.849, test acc 0.824
epoch 15, loss 0.0069, train acc 0.848, test acc 0.843
epoch 16, loss 0.0064, train acc 0.864, test acc 0.851
epoch 17, loss 0.0062, train acc 0.867, test acc 0.780
epoch 18, loss 0.0060, train acc 0.871, test acc 0.864
epoch 19, loss 0.0057, train acc 0.881, test acc 0.890
epoch 20, loss 0.0055, train acc 0.885, test acc 0.897
Net2模型训练
# The output shown below was recorded with batchSize = 50, lr = 0.1, numEpochs = 15.
train(net=net2, trainIter=trainIter, testIter=testIter, loss=loss2,
      numEpochs=numEpochs, batchSize=batchSize, optimizer=optimizer2)
epoch 1, loss 0.0119, train acc 0.638, test acc 0.676
epoch 2, loss 0.0079, train acc 0.823, test acc 0.986
epoch 3, loss 0.0046, train acc 0.987, test acc 0.977
epoch 4, loss 0.0030, train acc 0.983, test acc 0.973
epoch 5, loss 0.0023, train acc 0.981, test acc 0.976
epoch 6, loss 0.0019, train acc 0.980, test acc 0.988
epoch 7, loss 0.0016, train acc 0.984, test acc 0.984
epoch 8, loss 0.0014, train acc 0.985, test acc 0.986
epoch 9, loss 0.0013, train acc 0.987, test acc 0.992
epoch 10, loss 0.0011, train acc 0.989, test acc 0.993
epoch 11, loss 0.0010, train acc 0.989, test acc 0.996
epoch 12, loss 0.0010, train acc 0.992, test acc 0.994
epoch 13, loss 0.0009, train acc 0.993, test acc 0.994
epoch 14, loss 0.0008, train acc 0.995, test acc 0.996
epoch 15, loss 0.0008, train acc 0.994, test acc 0.998
测试
# Hand-made batch of seven 4x4 single-channel images, shape (7, 1, 4, 4),
# with 0 = black pixel and 1 = white pixel. Each expected label follows the
# task rule: 0 if the 12 border pixels hold more black pixels than the
# central 2x2 block, otherwise 1.
test = torch.tensor([
    [[[0, 0, 0, 0], [0, 1, 1, 0], [0, 1, 1, 0], [0, 0, 0, 0]]],
    [[[1, 1, 1, 1], [1, 0, 0, 1], [1, 0, 0, 1], [1, 1, 1, 1]]],
    [[[0, 1, 0, 1], [1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 1]]],
    [[[0, 1, 1, 1], [1, 0, 0, 1], [1, 0, 0, 1], [0, 0, 0, 1]]],
    [[[0, 0, 1, 1], [1, 0, 0, 1], [1, 0, 0, 1], [1, 0, 1, 0]]],
    [[[0, 0, 1, 0], [0, 1, 0, 1], [0, 0, 1, 1], [1, 0, 1, 0]]],
    [[[1, 1, 1, 0], [1, 0, 0, 1], [1, 0, 1, 1], [1, 0, 1, 1]]],
], dtype=torch.float32)
# Class indices as int64, the dtype argmax() predictions and CrossEntropyLoss use.
target = torch.tensor([0, 1, 0, 1, 1, 0, 1], dtype=torch.long)
# Notebook-style display of the hand-made test batch (its printed value follows).
test
tensor([[[[0., 0., 0., 0.],
[0., 1., 1., 0.],
[0., 1., 1., 0.],
[0., 0., 0., 0.]]],
[[[1., 1., 1., 1.],
[1., 0., 0., 1.],
[1., 0., 0., 1.],
[1., 1., 1., 1.]]],
[[[0., 1., 0., 1.],
[1., 0., 0., 1.],
[1., 0., 0., 1.],
[0., 0., 0., 1.]]],
[[[0., 1., 1., 1.],
[1., 0., 0., 1.],
[1., 0., 0., 1.],
[0., 0., 0., 1.]]],
[[[0., 0., 1., 1.],
[1., 0., 0., 1.],
[1., 0., 0., 1.],
[1., 0., 1., 0.]]],
[[[0., 0., 1., 0.],
[0., 1., 0., 1.],
[0., 0., 1., 1.],
[1., 0., 1., 0.]]],
[[[1., 1., 1., 0.],
[1., 0., 0., 1.],
[1., 0., 1., 1.],
[1., 0., 1., 1.]]]])
# Run both models on the hand-made batch without tracking gradients and take
# the highest-scoring class for each image.
with torch.no_grad():
    output, output2 = net(test), net2(test)
    predictions, predictions2 = output.argmax(dim=1), output2.argmax(dim=1)
# Element-wise comparison of each model's predictions with the expected labels.
print(f'Net测试结果{predictions.eq(target)}')
print(f'Net2测试结果{predictions2.eq(target)}')
Net测试结果tensor([ True, True, False, True, True, True, True])
Net2测试结果tensor([False, True, False, True, True, False, True])
来源:https://segmentfault.com/a/1190000022939415


猜你喜欢
- 1、字典中的键存在时,可以通过字典名+下标的方式访问字典中改键对应的值,若键不存在则会抛出异常。如果想直接向字典中添加元素可以直接用字典名+
- 在 InnoDB中更加快速的全表扫描 一般来讲,大多数应用查询的时候都会用索引,查找很少的几行数据(主键查找或百行内的
- #!/usr/bin/env python# name IsOpen.pyimport osimport socketdef IsOpen(
- YEAR() 函数返回一个整数值,它表示指定日期的年份,一般使用为:Year(时间),如:YEAR('2023-03-14
- 代码如下: EXEC sp_rename '表名.[原列名]', '新列名', 'column
- 下次用python画图的时候选色选点都可以直接参考这边,牛逼!分享给大家,也给自己留个笔记。参考网址:http://stackoverflo
- 这几天在做一个数据集,由于不是很熟悉Linux下的命令,所以特地用了强大的python来做。我之前有一个数据集但是我只要里面名称带有comp
- Python中实现socket通信的服务端比较复杂,而客户端非常简单,所以客户端基本上都是用sockct模块实现,而服务 端用有很多模块可以
- 考虑下述Python代码片段。对文件中的数据进行某些操作,然后将结果保存回文件中:with open(filename) as f:&nbs
- mysql win10 解压缩下载 解压Mysql :版本 5.7.13 下载链接 (通过这个官网zip链接可以直接下载,不用再注册 Ora
- 1.ROOT_URLCONF = '总路由所在路径(比如untitled.urls)'<===默认情况是这样根路由的路
- TF-IDFTF-IDF(Term Frequencey-Inverse Document Frequency)指词频-逆文档频率,它属于数
- 本文实例为大家分享了python实现学生成绩测评系统的具体代码,供大家参考,具体内容如下1、问题描述(功能要求): 根据实验指导书
- jquery基本入门 第一天:选择器相关 1.html()与.text() .html()取得第一个匹配元素的html内容。会带有标签,.t
- 1、基本原理访问网站扫码登录页,网站给浏览器返回一个二维码和一个唯一标志KEY浏览器开启定时轮询服务器,确认KEY对应的扫码结果用户使用ap
- 限流器是服务中非常重要的一个组件,在网关设计、微服务、以及普通的后台应用中都比较常见。它可以限制访问服务的频次和速率,防止服务过载,被刷爆。
- 介绍pandas数据聚合和重组的相关知识,仅供参考。1GroupBy技术1.1简介简介:根据一个或多个键进行分组,每一组应用函数,再进行合并
- 前言最近在学习python,发现了解线程信号量的基础知识,对深入理解python的线程会大有帮助。所以本文将给大家介绍Python3.X线程
- 了解如何 在sublime编辑器中安装python软件包,以 实现自动完成等功能,并在sublime编辑器本身中运行build。安装Subl
- asyncio在Python 2的时代,高性能的网络编程主要是使用Twisted、Tornado和Gevent这三个库,但是它们的异步代码相