pytorch 图像中的数据预处理和批标准化实例
作者:xckkcxxck 发布时间:2023-07-16 15:08:12
目前数据预处理最常见的方法就是中心化和标准化。
中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。
标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间
批标准化:BN
在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。
所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。
batch normalization 的实现非常简单,接下来写一下对应的python代码:
import sys
sys.path.append('..')
import torch
def simple_batch_norm_1d(x, gamma, beta):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)
测试的时候该使用批标准化吗?
答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替
下面我们实现以下能够区分训练状态和测试状态的批标准化方法
def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
if is_training:
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
else:
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
下面我们在卷积网络下试用一下批标准化看看效果
def data_tf(x):
x = np.array(x, dtype='float32') / 255
x = (x - 0.5) / 0.5 # 数据预处理,标准化
x = torch.from_numpy(x)
x = x.unsqueeze(0)
return x
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
def __init__(self):
super(conv_bn_net, self).__init__()
self.stage1 = nn.Sequential(
nn.Conv2d(1, 6, 3, padding=1),
nn.BatchNorm2d(6),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.BatchNorm2d(16),
nn.ReLU(True),
nn.MaxPool2d(2, 2)
)
self.classfy = nn.Linear(400, 10)
def forward(self, x):
x = self.stage1(x)
x = x.view(x.shape[0], -1)
x = self.classfy(x)
return x
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
train(net, train_data, test_data, 5, optimizer, criterion)
来源:https://blog.csdn.net/xckkcxxck/article/details/82348789
猜你喜欢
- 首先来看看这个php字符串替换函数 strtr()的两种用法:strtr(string,from,to) 或者strtr(string,ar
- 前言Go大概2009年面世以来,已经8年了,也算是8年抗战。在这8年中,已经有很多公司开始使用Go语言开发自己的服务,甚至完全转向Go开发,
- 本意是为了和手写jdbc对照,不过不要和原来的手写连接重名。打开cmd,直接输入notepad就打开了记事本。jdk1.5之后不必配置cla
- 《色彩解答》系列之一 色彩层次这次我们将深入进去了解一下众多色彩在一起之后所存在的“比例”关系。我们在使用色彩的时候不可能把所有的色彩都做得
- 在MySQL中,事务就是一个逻辑工作单元的一系列步骤。事务是用来保证数据操作的安全性。事务的特征:1.Atomicity(原子性)2.Con
- k8s容器互联-flannel host-gw原理篇容器系列文章容器系列视频简析host-gw前面分析了flannel vxlan模式进行容
- 自定义数据库自动编号初始值和步进值问题: 如何定义数据库的自动编号字段的初始值和步进值?如何定义自动增加字段的初始值和步进值?如何使删除过数
- 一、媒体管道1.1、媒体管道的特性媒体管道实现了以下特性:避免重新下载最近下载的媒体指定存储位置(文件系统目录,Amazon S3 buck
- 前言数学建模的介绍与作用全国大学生数学建模竞赛:全国大学生数学建模竞赛创办于1992年,每年一届,已成为全国高校规模最大的基础性学科竞赛,也
- 表结构的修改1、表结构修改后,原来表中已存在的数据,就会出现结构混乱,makemigrations更新表的时候就会出错比如第一次建模型,漏了
- 需要使用到的文件wxapp.py, read_file.py, setup.py#!/usr/bin/env python# -*- cod
- 使用keras时,加入keras的lambda层以实现自己定义的操作。但是,发现操作结果的shape信息有问题。我的后端是theano,使用
- 变量的存储在高级语言中,变量是对内存及其地址的抽象。对于python而言,python的一切变量都是对象,变量的存储,采用了引用语义的方式,
- 构建网络ResNet由一系列堆叠的残差块组成,其主要作用是通过无限制地增加网络深度,从而使其更加强大。在建立ResNet模型之前,让我们先定
- 话说土匪老湿在他的大作 《交互设计之回归篇》 里曝光了上次有意思小组竞赛我们小组分享的话题 “瞬间的快感”,但这一极具噱
- Pandas是一个用于数据分析和操作的Python库。在pandas中几乎所有的操作都围绕着DataFrame。Dataframe是一个二维
- 我们在建立一个大型网站的时候会有很多副页面框架模式,甚至一些细节元素都是相同的。但令人困扰的是更新它们却要费些周折,要一遍遍地反复更新每个页
- 前言最近在学习python,对于python的print一直很恼火,老是不按照预期输出。在python2中print是一种输出语句,和if语
- 一、类型数组是值类型,将一个数组赋值给另一个数组时,传递的是一份拷贝。切片是引用类型,切片包装的数组称为该切片的底层数组。我们来看一段代码/
- 这篇文章主要介绍了python3 pathlib库Path类方法总结,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习