pytorch神经网络之卷积层与全连接层参数的设置方法
作者:嘿芝麻 发布时间:2023-07-11 14:21:27
标签:pytorch,卷积层,全连接层,参数
当使用pytorch写网络结构的时候,本人发现在卷积层与第一个全连接层的全连接层的input_features不知道该写多少?一开始本人的做法是对着pytorch官网的公式推,但是总是算错。
后来发现,写完卷积层后可以根据模拟神经网络的前向传播得出这个。
全连接层的input_features是多少。首先来看一下这个简单的网络。这个卷积的Sequential本人就不再啰嗦了,现在看nn.Linear(???, 4096)这个全连接层的第一个参数该为多少呢?
请看下文详解。
class AlexNet(nn.Module):
def __init__(self):
super(AlexNet, self).__init__()
self.conv = nn.Sequential(
nn.Conv2d(3, 96, kernel_size=11, stride=4),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(96, 256, kernel_size=5, padding=2),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2),
nn.Conv2d(256, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 384, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.Conv2d(384, 256, kernel_size=3, padding=1),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2)
)
self.fc = nn.Sequential(
nn.Linear(???, 4096)
......
......
)
首先,我们先把forward写一下:
def forward(self, x):
x = self.conv(x)
print x.size()
就写到这里就可以了。其次,我们初始化一下网络,随机一个输入:
import torch
from Alexnet.AlexNet import *
from torch.autograd import Variable
if __name__ == '__main__':
net = AlexNet()
data_input = Variable(torch.randn([1, 3, 96, 96])) # 这里假设输入图片是96x96
print data_input.size()
net(data_input)
结果如下:
(1L, 3L, 96L, 96L)
(1L, 256L, 1L, 1L)
显而易见,咱们这个全连接层的input_features为256。
来源:https://blog.csdn.net/zw__chen/article/details/82839061


猜你喜欢
- 本文实例讲述了Python中super关键字用法。分享给大家供大家参考。具体分析如下:在Python类的方法(method)中,要调用父类的
- 对表误删或执行缺少条件的修改 SQL 导致修改了表内其他数据时,我们需要想办法将数据恢复回来。先创建两个测试表 table_1CREATE
- 当逐渐在用python开发项目或者日常使用时,一般需要大量使用别人提供的包,这些包能高效的帮助我们快速高效的完成指定任务或者需求,不过有时也
- APSchedulerAPScheduler 四个组件分别为:调度器(scheduler)、触发器(trigger),作业存储(job st
- 最近笔者学会了用FrontPage XP做网页,心理特高兴,非常想把我在制作主页过程中的一些经验和大家交流交流、切磋切磋,我们一起来看看吧。
- 我们一般使用爬虫看到的都是最后的数据结果,对于整个的获取过程没有过多了解过。对于初学python的小伙伴们来说,不光是代码的练习,还是原理的
- 0. 引言有如上一张图片,在以往的图像旋转处理中,往往得到如图所示的图片。然而,在进行一些其他图像处理或者图像展示时,黑边带来了一些不便。本
- 一、前言容器使用沙箱机制,互相隔离,优势在于让各个部署在容器的里的应用互不影响,独立运行,提供更高的安全性。本文主要介绍python应用(d
- 5.1.5 表单验证 表单作为 HTML 最重要的一个组成部分,几乎在每个网页上
- JS怎样知道Flash广告条被网友点击过? 1、Flash广告条不是我做的,它的链接是写在里面的。 2、我想统计这个Flash被网友点击了多
- 首先让我们来看看有关 Perl 面向对象编程的三个基本定义:1. 一个“对象”是指一个“有办法知道它是属于哪个类”的简单引用。(
- 如何利用pandas读取csv数据并绘图导包,常用的numpy和pandas,绘图模块matplotlib,import matplotli
- 1 基本信息- 模块主页:[github]- 类型:#第三方库2 安装方法pip install pythonping3 一般使用from
- Git 基本操作Git 的工作就是创建和保存你项目的快照及与之后的快照进行对比。本章将对有关创建与提交你的项目快照的命令作介绍。获取与创建项
- 总的来说:1、数据库设计和表创建时就要考虑性能2、sql的编写需要注意优化3、分区、分表、分库设计表的时候:1、字段避免null值出现,nu
- MySQL字符集出错的解决方法:错误案例: Illegal mix of collations (gbk_chinese_ci,I
- 异常的捕获与处理什么是错误简而言之:还没运行,在语法解析的时候,就发现语法存在问题,这个时候就是错误。什么是异常简而言之:代码写好之后,无明
- 本文实例讲述了python判断字符串是否包含子字符串的方法。分享给大家供大家参考。具体如下:python的string对象没有contain
- JS在firefox中的兼容性问题,自己也经常遇到.此文是网上资料,不过时间较久不记得原址了...1. document.form.item
- 函数的概念,函数是将具有独立功能的代码块组织成为一个整体,使其具有特殊功能的代码集函数的作用,使用函数可以加强代码的复用性,提高程序编写的效