网络编程
位置:首页>> 网络编程>> Python编程>> PyTorch实现卷积神经网络的搭建详解

PyTorch实现卷积神经网络的搭建详解

作者:Bubbliiiing  发布时间:2021-03-12 16:36:03 

标签:PyTorch,卷积神经网络,神经网络

PyTorch中实现卷积的重要基础函数

1、nn.Conv2d:

nn.Conv2d在pytorch中用于实现卷积。

nn.Conv2d(
   in_channels=32,
   out_channels=64,
   kernel_size=3,
   stride=1,
   padding=1,
)

1、in_channels为输入通道数。

2、out_channels为输出通道数。

3、kernel_size为卷积核大小。

4、stride为步数。

5、padding为padding情况。

6、dilation表示空洞卷积情况。

2、nn.MaxPool2d(kernel_size=2)

nn.MaxPool2d在pytorch中用于实现最大池化。

具体使用方式如下:

MaxPool2d(kernel_size,
stride=None,
padding=0,
dilation=1,
return_indices=False,
ceil_mode=False)

1、kernel_size为池化核的大小

2、stride为步长

3、padding为填充情况

3、nn.ReLU()

nn.ReLU()用来实现Relu函数,实现非线性。

4、x.view()

x.view用于reshape特征层的形状。

全部代码

这是一个简单的CNN模型,用于预测mnist手写体。

import os
import numpy as np
import torch
import torch.nn as nn
import torch.utils.data as Data
import torchvision
import matplotlib.pyplot as plt
# 循环世代
EPOCH = 20
BATCH_SIZE = 50
# 下载mnist数据集
train_data = torchvision.datasets.MNIST(root='./mnist/',train=True,transform=torchvision.transforms.ToTensor(),download=True,)
# (60000, 28, 28)
print(train_data.train_data.size())                
# (60000)
print(train_data.train_labels.size())              
train_loader = Data.DataLoader(dataset=train_data, batch_size=BATCH_SIZE, shuffle=True)
# 测试集
test_data = torchvision.datasets.MNIST(root='./mnist/', train=False)
# (2000, 1, 28, 28)
# 标准化
test_x = torch.unsqueeze(test_data.test_data, dim=1).type(torch.FloatTensor)[:2000]/255.
test_y = test_data.test_labels[:2000]
# 建立pytorch神经网络
class CNN(nn.Module):
   def __init__(self):
       super(CNN, self).__init__()
       #----------------------------#
       #   第一部分卷积
       #----------------------------#
       self.conv1 = nn.Sequential(
           nn.Conv2d(
               in_channels=1,
               out_channels=32,
               kernel_size=5,
               stride=1,
               padding=2,
               dilation=1
           ),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2),
       )
       #----------------------------#
       #   第二部分卷积
       #----------------------------#
       self.conv2 = nn.Sequential(
           nn.Conv2d(
               in_channels=32,
               out_channels=64,
               kernel_size=3,
               stride=1,
               padding=1,
               dilation=1
           ),
           nn.ReLU(),
           nn.MaxPool2d(kernel_size=2),
       )
       #----------------------------#
       #   全连接+池化+全连接
       #----------------------------#
       self.ful1 = nn.Linear(64 * 7 * 7, 512)
       self.drop = nn.Dropout(0.5)
       self.ful2 = nn.Sequential(nn.Linear(512, 10),nn.Softmax())
   #----------------------------#
   #   前向传播
   #----------------------------#  
   def forward(self, x):
       x = self.conv1(x)
       x = self.conv2(x)
       x = x.view(x.size(0), -1)
       x = self.ful1(x)
       x = self.drop(x)
       output = self.ful2(x)
       return output
cnn = CNN()
# 指定优化器
optimizer = torch.optim.Adam(cnn.parameters(), lr=1e-3)
# 指定loss函数
loss_func = nn.CrossEntropyLoss()
for epoch in range(EPOCH):
   for step, (b_x, b_y) in enumerate(train_loader):
       #----------------------------#
       #   计算loss并修正权值
       #----------------------------#  
       output = cnn(b_x)
       loss = loss_func(output, b_y)
       optimizer.zero_grad()
       loss.backward()
       optimizer.step()
       #----------------------------#
       #   打印
       #----------------------------#  
       if step % 50 == 0:
           test_output = cnn(test_x)
           pred_y = torch.max(test_output, 1)[1].data.numpy()
           accuracy = float((pred_y == test_y.data.numpy()).astype(int).sum()) / float(test_y.size(0))
           print('Epoch: %2d'% epoch, ', loss: %.4f' % loss.data.numpy(), ', accuracy: %.4f' % accuracy)

来源:https://blog.csdn.net/weixin_44791964/article/details/103658845

0
投稿

猜你喜欢

  • 创建一个名为templatetags的python module。新建一个名为verbose_name.py的文件。from django
  • 我就废话不多说了,大家还是直接看代码吧~import pymysqlfrom sshtunnel import SSHTunnelForwa
  • 记得很早以前看到过这样的一段介绍:想象你在逛街边的一家书店,如果最终你没有购买任何图书就直接离开了,店长并不会知道你来过。但是如果你买了书,
  • 和朋友讨论时,我提到过一个观点,所有框架层设计中,最核心的是导航设计。最近更看到有国外同行提出“80%的可用性是导航!”因为良好的导航可以保
  • 如果在子类中需要父类的构造方法就需要显式地调用父类的构造方法,或者不重写父类的构造方法。子类不重写 __init__,实例化子类时,会自动调
  • 今天做项目时,有一个这样的需求,需要动态删除的Tab,比如:可以删除某一个,可以删除多个。每一个Tab对应一个iframe。本来我的代码是这
  • 一,PHP脚本与动态页面。 PHP脚本是一种服务器端脚本程序,可通过嵌入等方法与HTML文件混合,也可以类,函数封装等形式,以模板的方式对用
  • 1、控制"纵打"、 横打”和“页面的边距。 (1)<script defer> function SetPr
  • 引言语音端点检测最早应用于电话传输和检测系统当中,用于通信信道的时间分配,提高传输线路的利用效率.端点检测属于语音处理系统的前端操作,在语音
  • 在本人看来,HTML 5是一个妥协方案,虽不激进,但更能推动技术的继续进步。没有命名空间,元素也不要求闭合(当然这并不是优点),浏览器也可以
  • 上周在去杭州betacafe的路上,有幸和绿人网梁宁和饭统网李耀东、千鸟一道,在出租车上聊起了地理和历史,其中有一个共同的观点是说,人们对事
  • 因为工作中需要,需要生成一个带表格的图片例如:直接在html中写一个table标签,然后单独把表格部分保存成图片或者是直接将excel中的内
  • 一直想了解Web编程的技术。PHP是进行Web编程重要的一种语言,书上总是说,PHP是用于服务器端的编程语言。但是,实在不能理解它是怎么用于
  • 安 * oostpython调用C/C++的方法有很多,本文使用boost.python。考虑到后期有好多在boost上的开发工作,所以boo
  •  MySQL是一个真正的多用户、多线程SQL数据库服务器。MySQL是以一个客户机/服务器结构的实现,它由一个服务器守护程序mys
  • 前段时间冷空气突袭的时候,据说郊区密云的雪积得挺厚,但北京城内除了飘了一点小雪粒,毫无动静。应该是气温过高所致,我在慈云寺桥附近拍下的照片可
  • 数据透视表(Pivot Table)是 Excel 中一个非常实用的分析功能,可以用于实现复杂的数据分类汇总和对比分析,是数据分析师和运营人
  • python / 和 % 和 //(地板除)用于对数据进行除法运算。python中 // 和 / 和 %简介python中与除法相关的三个运
  • 一、摘要Python使用被称为异常 的特殊对象来管理程序执行期间发生的错误。每当发生让Python不知所措的错误时,它都会创建一个异常对象。
  • 非常不错,大家可以自己应用下。<% '//数据处理部分 dim Content,Num,I,st
手机版 网络编程 asp之家 www.aspxhome.com