PyTorch实现卷积神经网络的搭建详解
作者:Bubbliiiing 发布时间:2021-03-12 16:36:03
标签:PyTorch,卷积神经网络,神经网络
PyTorch中实现卷积的重要基础函数
1、nn.Conv2d:
nn.Conv2d在pytorch中用于实现卷积。
nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
)
1、in_channels为输入通道数。
2、out_channels为输出通道数。
3、kernel_size为卷积核大小。
4、stride为步数。
5、padding为padding情况。
6、dilation表示空洞卷积情况。
2、nn.MaxPool2d(kernel_size=2)
nn.MaxPool2d在pytorch中用于实现最大池化。
具体使用方式如下:
MaxPool2d(kernel_size,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode=False)
1、kernel_size为池化核的大小
2、stride为步长
3、padding为填充情况
3、nn.ReLU()
nn.ReLU()用来实现Relu函数,实现非线性。
4、x.view()
x.view用于reshape特征层的形状。
全部代码
这是一个简单的CNN模型,用于预测mnist手写体。
import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循环世代
EPOCH = 20
BATCH_SIZE = 50
# 下载mnist数据集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())
# (60000)
print(train_data.train_labels.size())
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 标准化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神经网络
class CNN(nn.Module):
def __init__(self):
super(CNN, self).__init__()
#----------------------------#
# 第一部分卷积
#----------------------------#
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels=1,
out_channels=32,
kernel_size=5,
stride=1,
padding=2,
dilation=1
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
#----------------------------#
# 第二部分卷积
#----------------------------#
self.conv2 = nn.Sequential(
nn.Conv2d(
in_channels=32,
out_channels=64,
kernel_size=3,
stride=1,
padding=1,
dilation=1
),
nn.ReLU(),
nn.MaxPool2d(kernel_size=2),
)
#----------------------------#
# 全连接+池化+全连接
#----------------------------#
self.ful1 = nn.Linear(64 * 7 * 7, 512)
self.drop = nn.Dropout(0.5)
self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
#----------------------------#
# 前向传播
#----------------------------#
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
x = x.view(x.size(0), -1)
x = self.ful1(x)
x = self.drop(x)
output = self.ful2(x)
return output
cnn = CNN()
# 指定优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
# 指定loss函数
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
for step, (b_x, b_y) in enumerate(train_loader):
#----------------------------#
# 计算loss并修正权值
#----------------------------#
output = cnn(b_x)
loss = loss_func(output, b_y)
optimizer.zero_grad()
loss.backward()
optimizer.step()
#----------------------------#
# 打印
#----------------------------#
if step % 50 == 0:
test_output = cnn(test_x)
pred_y = torch.max(test_output, 1)[1].data.numpy()
accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)
来源:https://blog.csdn.net/weixin_44791964/article/details/103658845


猜你喜欢
- 一、python中对文件、文件夹操作时经常用到的os模块和shutil模块常用方法。1.得到当前工作目录,即当前Python脚本工作的目录路
- DRF中的Request在Django REST Framework中内置的Request类扩展了Django中的Request类,实现了很
- 目录前言什么是 websocketwebsocket 通信原理和机制websocket 的特点构建实时日志跟踪的小例子前言websocket
- Javascript中标签(label)是一个标识符。标签可以与变量重名,它是一个独立的语法元素(既不是变量,也不是类型),其作用是标识”标
- 1 打开cmd,不改变运行的目录:输入python 空格 调试好的python文件路径或者python 空格 将py
- CREATE OR REPLACE VIEW BLOG_V_ADMIN (ID,NICK
- 今天在测试php程序的时候,出现了一个错误提示:Cannot use a scalar value as an array,这个错误提示前几
- 题目描述给定n个字符串,请对n个字符串按照字典序排列。输入描述:输入第一行为一个正整数n(1≤n≤1000),下面n行为n个字符串(字符串长
- SQL的扩展的删除与恢复 删除 代码如下:use master exec spdropextendedproc “xpcmdshell“ e
- 1.前言版本:Python3.6.1 + PyQt5写一个程序的时候需要用到画板/手写板,只需要最简单的那种。原以为网上到处都是,结果找了好
- 实际使用Pool 是用于存放临时对象的集合,这些对象是为了后续的使用,以达到复用对象的效果。其目的是缓解频繁创建对象造成的gc压力。在许多开
- 前言FlashText 算法是由 Vikash Singh 于2017年发表的大规模关键词替换算法,这个算法的时间复杂度仅由文本长度(N)决
- 例如,有一个字典如下:>>> dic = {"name": "botoo",&qu
- 格式化是通过格式操作使任意类型的数据转换成一个字符串。例如下面这样<script>console.log(chopper.for
- 介绍 append()语法list.append( element )参数element:任何类型的元素列表「末尾」添加元素nam
- 迭代器即可以遍历诸如列表,字典及字符串等序列对象甚至自定义对象的对象,其本质就是记录迭代对象中每个元素的位置。迭代过程从第一个元素至最后一个
- 设计与开发之间本有一线界限,但当时代步入又一个十年,这个线变得更加模糊甚至感觉不到它的存在。使用PS设计网页版面,足矣?或许五年前是吧!现在
- 前言:最近在探索用Go来读取文件,读取文本时发现,对于单行超长的文本,我的Go代码无法处理。经过查阅才发现,Go提供的Scanner无法读取
- 前言:问题分析:在进行数据库查询的时候,我们都知道索引可以加快数据查询的效率。但是在实际的业务场景下,经常会遇到即使在表中增加了索引,但是同
- @StartIndex为当前页起始序号,@EndIndex为当前页结束记录序号,可以直接作为参数输入,也可以通过输入PageSize和Pag