Python LeNet网络详解及pytorch实现
作者:Serins 发布时间:2021-11-15 01:19:17
1.LeNet介绍
LeNet神经网络由深度学习三巨头之一的Yan LeCun提出,他同时也是卷积神经网络 (CNN,Convolutional Neural Networks)之父。LeNet主要用来进行手写字符的识别与分类,并在美国的银行中投入了使用。LeNet的实现确立了CNN的结构,现在神经网络中的许多内容在LeNet的网络结构中都能看到,例如卷积层,Pooling层,ReLU层。虽然LeNet早在20世纪90年代就已经提出了,但由于当时缺乏大规模的训练数据,计算机硬件的性能也较低,因此LeNet神经网络在处理复杂问题时效果并不理想。虽然LeNet网络结构比较简单,但是刚好适合神经网络的入门学习。
2.LetNet网络模型
LeNet网络模型一般指LeNet-5,相信大家学习这个模型的时候一定都见过这张图片吧
这张图也是原论文中的一张模型图,这样子看可能会觉得有点不习惯,下面这张图是本人在drawio软件上制作的网络模型图,如下:
纠正一下,上图中第二个Conv2d层后面的计算结果应该为10,写成了5
相信学习了卷积神经网络基础的朋友们应该能很清晰的看懂这张图吧,对于右边的计算在图的左上角也给出了公式,上图中每一层的输入形状以及输出形状我都详细的为大家写出来了,对于计算公式和模型大致的结构,看下面这张图也可以(建议对应上下图一起看更容易理解)
LeNet-5网络模型简单的就包含了卷积层,最大池化层,全连接层以及relu,softmax激活函数,模型中的输入图片大小以及每一层的卷积核个数,步长都是模型制定好的,一般不要随意修改,能改的是最后的输出结果,即分类数量(num_classes)。flatten操作也叫扁平化操作,我们都知道输入到全连接层中的是一个个的特征,及一维向量,但是卷积网络特征提取出来的特征矩阵并非一维,要送入全连接层,所以需要flatten操作将它展平成一维。
3.pytorch实现LeNet
python代码如下
from torch import nn
import torch
import torch.nn.functional as F
'''
说明:
1.LeNet是5层网络
2.nn.ReLU(inplace=True) 参数为True是为了从上层网络Conv2d中传递下来的tensor直接进行修改,这样能够节省运算内存,不用多存储其他变量
3.本模型的维度注释均省略了N(batch_size)的大小,即input(3, 32, 32)-->input(N, 3, 32, 32)
4.nn.init.xavier_uniform_(m.weight)
用一个均匀分布生成值,填充输入的张量或变量,结果张量中的值采样自U(-a, a),
其中a = gain * sqrt( 2/(fan_in + fan_out))* sqrt(3),
gain是可选的缩放因子,默认为1
'fan_in'保留前向传播时权值方差的量级,'fan_out'保留反向传播时的量级
5.nn.init.constant_(m.bias, 0)
为所有维度tensor填充一个常量0
'''
class LeNet(nn.Module):
def __init__(self, num_classes=10, init_weights=False):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1)
self.maxpool1 = nn.MaxPool2d(kernel_size=2, stride=2)
self.relu = nn.ReLU(True)
self.conv2 = nn.Conv2d(in_channels=16, out_channels=32, kernel_size=5, stride=1)
self.maxpool2 = nn.MaxPool2d(kernel_size=2, stride=2)
self.fc1 = nn.Linear(32 * 5 * 5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, num_classes)
if init_weights:
self._initialize_weights()
def forward(self, x):
x = self.conv1(x) # input(3, 32, 32) output(16, 28, 28)
x = self.relu(x) # 激活函数
x = self.maxpool1(x) # output(16, 14, 14)
x = self.conv2(x) # output(32, 10, 10)
x = self.relu(x) # 激活函数
x = self.maxpool2(x) # output(32, 5, 5)
x = torch.flatten(x, start_dim=1) # output(32*5*5) N代表batch_size
x = self.fc1(x) # output(120)
x = self.relu(x) # 激活函数
x = self.fc2(x) # output(84)
x = self.relu(x) # 激活函数
x = self.fc3(x) # output(num_classes)
return x
def _initialize_weights(self):
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
nn.init.constant_(m.bias, 0)
或者
下面这一种没有自己初始化权重和偏置,就会使用默认的初始化方式
import torch.nn as nn
import torch.nn.functional as F
import torch
class LeNet(nn.Module):
def __init__(self):
super(LeNet, self).__init__()
self.conv1 = nn.Conv2d(3, 16, 5)
self.pool1 = nn.MaxPool2d(2, 2)
self.conv2 = nn.Conv2d(16, 32, 5)
self.pool2 = nn.MaxPool2d(2, 2)
self.fc1 = nn.Linear(32*5*5, 120)
self.fc2 = nn.Linear(120, 84)
self.fc3 = nn.Linear(84, 10)
def forward(self, x):
x = F.relu(self.conv1(x)) # output(16, 28, 28)
x = self.pool1(x) # output(16, 14, 14)
x = F.relu(self.conv2(x)) # output(32, 10, 10)
x = self.pool2(x) # output(32, 5, 5)
x = x.view(x.size(0), -1) # output(32*5*5)
x = F.relu(self.fc1(x)) # output(120)
x = F.relu(self.fc2(x)) # output(84)
x = self.fc3(x) # output(10)
return x
nn.Linear就是全连接层,除了最后一个全连接层,其它均需要relu激活,默认无padding操作
nn.Conv2d对应的参数顺序一定要记住:
1.in_channels:输入的通道数或者深度
2.out_channels:输出的通道数或者深度
3.kernel_size:卷积核的大小
4.stride:步长大小,默认1
5.padding:padding的大小,默认0
6.dilation:膨胀大小,默认1,暂时用不到
7.group:分组组数,默认1
8.bias:默认True,布尔值,是否用偏置值
9.padding_mode:默认用0填充
记不住参数顺序也没关系,但需要记住参数名称
参考文章:pytorch实现LeNet网络模型的训练及预测
来源:https://blog.csdn.net/Serins/article/details/121470862


猜你喜欢
- 事实上,当我们向文件导入某个模块时,导入的是
- 服务器现在同时输出json和xml两种数据,取决于服务程序和页面之间的约定。在程序遇到问题的时候会返回错误信息,也按照相同的约定会返回jso
- 近期有一个业务需求,多台机器需要同时从Mysql一个表里查询数据并做后续业务逻辑,为了防止多台机器同时拿到一样的数据,每台机器需要在获取时锁
- 前言随着 Kotlin 1.4 正式发布,关于 SAM 转换的一些问题就可以盖棺定论了。因为这里要讲的都是些旧的东西,所以这是一篇灌水文。K
- 内容简介展示如何给图像叠加不同等级的椒盐噪声和高斯噪声的代码,相应的叠加噪声的已编为对应的类,可实例化使用。以下主要展示自己编写的:加噪声的
- keras模型可视化:model:model = Sequential()# input: 100x100 images with 3 ch
- 1.获取所有数据库名: SELECT Name FROM Master..SysDatabases ORDER BY Name2.获取所有表
- 我们在学习Python的时候,除了用pip安装一些模块之外,有时候会从网站下载安装包下来安装,我也想要把我自己编写的模块做成这样的安装包,该
- <?php $url='test.php?1=1'; $contents="fjka;fjsa;#page#
- 一、前言Go程序像C/C++一样,如果开发编码考虑不当,会出现cpu负载过高的性能问题。如果程序是线上环境或者特定场景下出现负载过高,问题不
- Mint UI 是饿了么开源的,基于 Vue.js 的移动端组件库。关于Mint UI,有文档不够准确详尽,组件略显粗糙,功能不够完善等问题
- 有时候需要比较大的计算量,这个时候Python的效率就很让人捉急了,此时可以考虑使用numba 进行加速,效果提升明显~(numba 安装貌
- 我在网上找到了一篇文章,简直堪称神器。刚开始用brew search mysql ...能找到,按照提示一步一步安装,结果到最后就是启动不起
- 本文介绍了四种asp导出excel数据的方法:1.使用OWC ,2.用Excel的Application组件,3.直接在IE中打开,4.导出
- 最近,有读者微信上私聊我,想让我写一篇视频批量转换成音频的文章,我答应了,周末宅家里把这个小工具做出来了。 这样,对于有些视频学习
- 第一中方法:比较详细以下的文章主要介绍的是MySQL 数据库开启远程连接的时机操作流程,其实开启MySQL 数据库远程连接的实际操作步骤并不
- 从Web查询数据库:Web数据库架构的工作原理 一个用户的浏览器发出一个HTTP请求,请求特定的Web页面,在该页面中出发form表单提交到
- 最近正好在寻求一种Python的数据库ORM (Object Relational Mapper),SQLAlchemy (项目主页)这个开
- 创建项目django-admin startproject meiduo_mall添加工程完整结构包启动前端python -m http.s
- 一、背景主流被使用的地理坐标系并不统一,导致我们从不同平台下载的数据由于坐标系的差异往往对不齐。这个现象在多源数据处理的时候往往很常见,因此