PyTorch实现AlexNet示例
作者:mingo_敏 发布时间:2021-08-31 20:15:44
标签:PyTorch,AlexNet
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks
import torch
import torch.nn as nn
import torchvision
class AlexNet(nn.Module):
def __init__(self,num_classes=1000):
super(AlexNet,self).__init__()
self.feature_extraction = nn.Sequential(
nn.Conv2d(in_channels=3,out_channels=96,kernel_size=11,stride=4,padding=2,bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
nn.Conv2d(in_channels=96,out_channels=192,kernel_size=5,stride=1,padding=2,bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3,stride=2,padding=0),
nn.Conv2d(in_channels=192,out_channels=384,kernel_size=3,stride=1,padding=1,bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=384,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
nn.ReLU(inplace=True),
nn.Conv2d(in_channels=256,out_channels=256,kernel_size=3,stride=1,padding=1,bias=False),
nn.ReLU(inplace=True),
nn.MaxPool2d(kernel_size=3, stride=2, padding=0),
)
self.classifier = nn.Sequential(
nn.Dropout(p=0.5),
nn.Linear(in_features=256*6*6,out_features=4096),
nn.ReLU(inplace=True),
nn.Dropout(p=0.5),
nn.Linear(in_features=4096, out_features=4096),
nn.ReLU(inplace=True),
nn.Linear(in_features=4096, out_features=num_classes),
)
def forward(self,x):
x = self.feature_extraction(x)
x = x.view(x.size(0),256*6*6)
x = self.classifier(x)
return x
if __name__ =='__main__':
# model = torchvision.models.AlexNet()
model = AlexNet()
print(model)
input = torch.randn(8,3,224,224)
out = model(input)
print(out.shape)
来源:https://blog.csdn.net/shanglianlm/article/details/86424857
0
投稿
猜你喜欢
- 1.批量处理所谓的批处理就是批量处理cmd里面的命令。python要想实现批处理功能需要导入os库,然后利用批处理的命令为os.system
- 一、记事本源码#python简易记事本from tkinter import *from tkinter import messagebox
- 概述Python是个非常受欢迎的编程语言,随着近些年机器学习、云计算等技术的发展,Python的职位需求越来越高。下面我收集了10个Pyth
- 一、读写excel数据利用pandas可以很方便的读写excel数据1.1 读:data_in = pd.read_excel('M
- 键盘事件废话不多说直接上包from selenium.webdriver.common.keys import Keys1、删除键 BACK
- 前言:文章里用的Python环境是Anaconda3 2019.7这里测试的程序是找出所有1000以内的勾股数。a∈[1,
- 实例化对象名._类名__私有属性名 class Flylove:price = 123 def __init__(self):s
- 线性逻辑回归本文用代码实现怎么利用sklearn来进行线性逻辑回归的计算,下面先来看看用到的数据。这是有两行特征的数据,然后第三行是数据的标
- import numpy as npimport pandas as pdimport matplotlib.pylab as pltif
- 本文实例讲述了Python序列对象与String类型内置方法。分享给大家供大家参考,具体如下:前言在Python数据结构篇中介绍了Pytho
- 今天 大白 问了一个关于CSS权重的问题:关于选择器权重的问题 。class的权重是10 标签权重是 1 。比如说 p span{} 权重是
- 触发器权限和所有权CREATE TRIGGER 权限默认授予定义触发器的表所有者、sysadmin 固定服务器角色成员以及 db_owner
- 使用Matplotlib绘制的图表的默认坐标轴是在左下角的,这样对于一些函数的显示不是非常方便,要改变坐标轴的默认显示方式主要要使用gca(
- 如下所示:#coding=utf8import csv import logginglogging.basicConfig(level=lo
- 定义流的作用是使用统一的方式处理文件、网络和数据压缩等共用同一套函数和用法的操作。简单而言,流是具有流式行为的资源对象。因此,流可以线性读写
- Python下有许多款不同的 Web 框架。Django是重量级选手中最有代表性的一位。许多成功的网站和APP都基于Django。Djang
- Demo里的三种方法:方法1是两层div,兼容FF3.1a+, Safari 3+, Chrome, IE6/7方法2是两层div,除了IE
- 最近发现session的知识有点脱节了,默认设置愣是搞半天,看来忘了不少。今天把一些通用设置贴上来,以备随时回顾。配置文件中设置默认操作(通
- scipy.interpolate插值方法1 一维插值from scipy.interpolate import interp1d1维插值算
- 例如:我们在百度中搜索 词典网,则网址后面的参数就是http://www.baidu.com/s?cl=3&wd=%B4%CA%B5