Pytorch搭建简单的卷积神经网络(CNN)实现MNIST数据集分类任务
作者:无知的吱屋 发布时间:2021-04-24 02:25:16
标签:Pytorch,神经网络,MNIST,数据集,分类
关于一些代码里的解释,可以看我上一篇发布的文章,里面有很详细的介绍!!!
可以依次把下面的代码段合在一起运行,也可以通过jupyter notebook分次运行
第一步:基本库的导入
import numpy as np
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
import time
np.random.seed(1234)
第二步:引用MNIST数据集,这里采用的是torchvision自带的MNIST数据集
#这里用的是torchvision已经封装好的MINST数据集
trainset=torchvision.datasets.MNIST(
root='MNIST', #root是下载MNIST数据集保存的路径,可以自行修改
train=True,
transform=torchvision.transforms.ToTensor(),
download=True
)
testset=torchvision.datasets.MNIST(
root='MNIST',
train=False,
transform=torchvision.transforms.ToTensor(),
download=True
)
trainloader = DataLoader(dataset=trainset, batch_size=100, shuffle=True) #DataLoader是一个很好地能够帮助整理数据集的类,可以用来分批次,打乱以及多线程等操作
testloader = DataLoader(dataset=testset, batch_size=100, shuffle=True)
下载之后利用DataLoader实例化为适合遍历的训练集和测试集,我们把其中的某一批数据进行可视化,下面是可视化的代码,其实就是利用subplot画了子图。
#可视化某一批数据
train_img,train_label=next(iter(trainloader)) #iter迭代器,可以用来便利trainloader里面每一个数据,这里只迭代一次来进行可视化
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
axes_list = []
#输入到网络的图像
for i in range(axes.shape[0]):
for j in range(axes.shape[1]):
axes[i, j].imshow(train_img[i*10+j,0,:,:],cmap="gray") #这里画出来的就是我们想输入到网络里训练的图像,与之对应的标签用来进行最后分类结果损失函数的计算
axes[i, j].axis("off")
#对应的标签
print(train_label)
第三步:用pytorch搭建简单的卷积神经网络(CNN)
这里把卷积模块单独拿出来作为一个类,看上去会舒服一点。
#卷积模块,由卷积核和激活函数组成
class conv_block(nn.Module):
def __init__(self,ks,ch_in,ch_out):
super(conv_block,self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(ch_in, ch_out, kernel_size=ks,stride=1,padding=1,bias=True), #二维卷积核,用于提取局部的图像信息
nn.ReLU(inplace=True), #这里用ReLU作为激活函数
nn.Conv2d(ch_out, ch_out, kernel_size=ks,stride=1,padding=1,bias=True),
nn.ReLU(inplace=True),
)
def forward(self,x):
return self.conv(x)
下面是CNN主体部分,由上面的卷积模块和全连接分类器组合而成。这里只用了简单的几个卷积块进行堆叠,没有采用池化以及dropout的操作。主要目的是给大家简单搭建一下以便学习。
#常规CNN模块(由几个卷积模块堆叠而成)
class CNN(nn.Module):
def __init__(self,kernel_size,in_ch,out_ch):
super(CNN, self).__init__()
feature_list = [16,32,64,128,256] #代表每一层网络的特征数,扩大特征空间有助于挖掘更多的局部信息
self.conv1 = conv_block(kernel_size,in_ch,feature_list[0])
self.conv2 = conv_block(kernel_size,feature_list[0],feature_list[1])
self.conv3 = conv_block(kernel_size,feature_list[1],feature_list[2])
self.conv4 = conv_block(kernel_size,feature_list[2],feature_list[3])
self.conv5 = conv_block(kernel_size,feature_list[3],feature_list[4])
self.fc = nn.Sequential( #全连接层主要用来进行分类,整合采集的局部信息以及全局信息
nn.Linear(feature_list[4] * 28 * 28, 1024), #此处28为MINST一张图片的维度
nn.ReLU(),
nn.Linear(1024, 512),
nn.ReLU(),
nn.Linear(512, 10)
)
def forward(self,x):
device = x.device
x1 = self.conv1(x )
x2 = self.conv2(x1)
x3 = self.conv3(x2)
x4 = self.conv4(x3)
x5 = self.conv5(x4)
x5 = x5.view(x5.size()[0], -1) #全连接层相当于做了矩阵乘法,所以这里需要将维度降维来实现矩阵的运算
out = self.fc(x5)
return out
第四步:训练以及模型保存
先是一些网络参数的定义,包括优化器,迭代轮数,学习率,运行硬件等等的确定。
#网络参数定义
device = torch.device("cuda:4") #此处根据电脑配置进行选择,如果没有cuda就用cpu
#device = torch.device("cpu")
net = CNN(3,1,1).to(device = device,dtype = torch.float32)
epochs = 50 #训练轮次
optimizer = torch.optim.Adam(net.parameters(), lr=1e-4, weight_decay=1e-8) #使用Adam优化器
criterion = nn.CrossEntropyLoss() #分类任务常用的交叉熵损失函数
train_loss = []
然后是每一轮训练的主体:
# Begin training
MinTrainLoss = 999
for epoch in range(1,epochs+1):
total_train_loss = []
net.train()
start = time.time()
for input_img,label in trainloader:
input_img = input_img.to(device = device,dtype=torch.float32) #我们同样地,需要将我们取出来的训练集数据进行torch能够运算的格式转换
label = label.to(device = device,dtype=torch.float32) #输入和输出的格式都保持一致才能进行运算
optimizer.zero_grad() #每一次算loss前需要将之前的梯度清零,这样才不会影响后面的更新
pred_img = net(input_img)
loss = criterion(pred_img,label.long())
loss.backward()
optimizer.step()
total_train_loss.append(loss.item())
train_loss.append(np.mean(total_train_loss)) #将一个minibatch里面的损失取平均作为这一轮的loss
end = time.time()
#打印当前的loss
print("epochs[%3d/%3d] current loss: %.5f, time: %.3f"%(epoch,epochs,train_loss[-1],(end-start))) #打印每一轮训练的结果
if train_loss[-1]<MinTrainLoss:
torch.save(net.state_dict(), "./model_min_train.pth") #保存loss最小的模型
MinTrainLoss = train_loss[-1]
以下是迭代过程:
第五步:导入网络模型,输入某一批测试数据,查看结果
我们先来看某一批测试数据
#测试机某一批数据
test_img,test_label=next(iter(testloader))
fig, axes = plt.subplots(10, 10, figsize=(10, 10))
axes_list = []
#输入到网络的图像
for i in range(axes.shape[0]):
for j in range(axes.shape[1]):
axes[i, j].imshow(test_img[i*10+j,0,:,:],cmap="gray")
axes[i, j].axis("off")
然后将其输入到训练好的模型进行预测
#预测我拿出来的那一批数据进行展示
cnn = CNN(3,1,1).to(device = device,dtype = torch.float32)
cnn.load_state_dict(torch.load("./model_min_train.pth", map_location=device)) #导入我们之前已经训练好的模型
cnn.eval() #评估模式
test_img = test_img.to(device = device,dtype = torch.float32)
test_label = test_label.to(device = device,dtype = torch.float32)
pred_test = cnn(test_img) #记住,输出的结果是一个长度为10的tensor
test_pred = np.argmax(pred_test.cpu().data.numpy(), axis=1) #所以我们需要对其进行最大值对应索引的处理,从而得到我们想要的预测结果
#预测结果以及标签
print("预测结果")
print(test_pred)
print("标签")
print(test_label.cpu().data.numpy())
从预测的结果我们可以看到,整体上这么一个简单的CNN搭配全连接分类器对MNIST这一批数据分类的效果还不错。当然,我这里只用了交叉熵损失函数,并且没有计算准确率,仅供大家对于CNN学习和参考。
来源:https://blog.csdn.net/qq_43397591/article/details/128960528


猜你喜欢
- 和之前C++执行Linux Bash命令的方法 一样,Python依然支持system调用和popen()函数来执行linux bash命令
- 一,写在前面的话最近公司需要按天,按小时查看数据,可以直观的看到时间段的数据峰值。接到需求,就开始疯狂百度搜索,但是搜索到的资料有很多都不清
- 前言Go语言作为一个由Google开发,号称互联网的C语言的语言,自然也对JSON格式支持很好。下面这篇文章主要介绍了关于golang自定义
- 切片:方便截取list、tuple、字符串部分索引的内容正序切片语法:dlist = doList[0:3]表示,从索引0开始取,直到索引3
- 我就废话不多说了,大家还是直接看代码吧!cmd.py# -*- coding: utf-8 -*-from PySide import Qt
- 旁站查询来源:http://dns.aizhan.comhttp://s.tool.chinaz.com/samehttp://i.link
- 马上就要过节了,想把自己的项目搞得酷炫一些,对整个网站的按钮添加图标、飘花效果、首屏大图展示、顶部导航背景图,于是就写了这一遍文字,如有兴趣
- 本文实例讲述了Python基于jieba库进行简单分词及词云功能实现方法。分享给大家供大家参考,具体如下:目标:1.导入一个文本文件2.使用
- <%'asp事务处理。'测试数据库为sql server,服务器为本机,数据库名为test,表名为a,两个字段id(i
- 步骤很简单,直接进入主题。第一步:创建一个python项目。解析器什么的自己选择,环境目录默认就好。第二步:下载scrapy,步骤file-
- flask之模板继承为什么要用模板继承?原因很简单,因为模板继承能让我们在实现效果的前提下少些很多代码!咱废话不多说,先来看个小例子,看完我
- Rect(rectangle)指的是矩形,或者长方形,在 Pygame 中我们使用 Rect() 方法来创建一个指定位置,大小的矩形区域。函
- 把 Oracle 数据库从 RAC 集群迁移到单机环境一、系统环境1、源数据库db_name:hisdb SID:hisdb1、
- 之前的文章介绍了python抓取网页数据并将数据保存到本地excel文件,后续可以将数据保存到数据库(SqlServer、mysql等)中,
- 摘要:本篇博客介绍了本教程的目标、适用人群、YOLOv5简介和车牌识别的意义和应用场景。为后续章节打下基础,帮助读者了解YOLOv5和车牌识
- 版本:平台:ubuntu 14 / I5 / 4G内存python版本:python2.7opencv版本:2.13.4依赖:如果系统没有p
- 序言小学妹说要毕业了,学了一学期Python等于没学,现在要做毕设做不出来,让我帮帮她,晚上去她家吃夜宵。当时我心想,这不是分分钟的事情,还
- python np.dot(a,b)运算规则解析首先我们知道dot运算时不满 * 换律的,np.dot(a, b)与np.dot(b, a)是
- 1、首先在本机安装ssh在cmd输入ssh,出现下面信息代表安装成功2、vscode安装 Remote - SSH 插件3、连接远程主机vs
- 1. 标签{% 标签 %}1.1 for循环标签<ul><!-- 可迭代对象都可以用循环 --><!-- 循环