pytorch 图像中的数据预处理和批标准化实例
作者:xckkcxxck 发布时间:2023-07-16 15:08:12
目前数据预处理最常见的方法就是中心化和标准化。
中心化相当于修正数据的中心位置,实现方法非常简单,就是在每个特征维度上减去对应的均值,最后得到 0 均值的特征。
标准化也非常简单,在数据变成 0 均值之后,为了使得不同的特征维度有着相同的规模,可以除以标准差近似为一个标准正态分布,也可以依据最大值和最小值将其转化为 -1 ~ 1 之间
批标准化:BN
在数据预处理的时候,我们尽量输入特征不相关且满足一个标准的正态分布,这样模型的表现一般也较好。但是对于很深的网路结构,网路的非线性层会使得输出的结果变得相关,且不再满足一个标准的 N(0, 1) 的分布,甚至输出的中心已经发生了偏移,这对于模型的训练,特别是深层的模型训练非常的困难。
所以在 2015 年一篇论文提出了这个方法,批标准化,简而言之,就是对于每一层网络的输出,对其做一个归一化,使其服从标准的正态分布,这样后一层网络的输入也是一个标准的正态分布,所以能够比较好的进行训练,加快收敛速度。
batch normalization 的实现非常简单,接下来写一下对应的python代码:
import sys
sys.path.append('..')
import torch
def simple_batch_norm_1d(x, gamma, beta):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
x = torch.arange(15).view(5, 3)
gamma = torch.ones(x.shape[1])
beta = torch.zeros(x.shape[1])
print('before bn: ')
print(x)
y = simple_batch_norm_1d(x, gamma, beta)
print('after bn: ')
print(y)
测试的时候该使用批标准化吗?
答案是肯定的,因为训练的时候使用了,而测试的时候不使用肯定会导致结果出现偏差,但是测试的时候如果只有一个数据集,那么均值不就是这个值,方差为 0 吗?这显然是随机的,所以测试的时候不能用测试的数据集去算均值和方差,而是用训练的时候算出的移动平均均值和方差去代替
下面我们实现以下能够区分训练状态和测试状态的批标准化方法
def batch_norm_1d(x, gamma, beta, is_training, moving_mean, moving_var, moving_momentum=0.1):
eps = 1e-5
x_mean = torch.mean(x, dim=0, keepdim=True) # 保留维度进行 broadcast
x_var = torch.mean((x - x_mean) ** 2, dim=0, keepdim=True)
if is_training:
x_hat = (x - x_mean) / torch.sqrt(x_var + eps)
moving_mean[:] = moving_momentum * moving_mean + (1. - moving_momentum) * x_mean
moving_var[:] = moving_momentum * moving_var + (1. - moving_momentum) * x_var
else:
x_hat = (x - moving_mean) / torch.sqrt(moving_var + eps)
return gamma.view_as(x_mean) * x_hat + beta.view_as(x_mean)
下面我们在卷积网络下试用一下批标准化看看效果
def data_tf(x):
x = np.array(x, dtype='float32') / 255
x = (x - 0.5) / 0.5 # 数据预处理,标准化
x = torch.from_numpy(x)
x = x.unsqueeze(0)
return x
train_set = mnist.MNIST('./data', train=True, transform=data_tf, download=True) # 重新载入数据集,申明定义的数据变换
test_set = mnist.MNIST('./data', train=False, transform=data_tf, download=True)
train_data = DataLoader(train_set, batch_size=64, shuffle=True)
test_data = DataLoader(test_set, batch_size=128, shuffle=False)
# 使用批标准化
class conv_bn_net(nn.Module):
def __init__(self):
super(conv_bn_net, self).__init__()
self.stage1 = nn.Sequential(
nn.Conv2d(1, 6, 3, padding=1),
nn.BatchNorm2d(6),
nn.ReLU(True),
nn.MaxPool2d(2, 2),
nn.Conv2d(6, 16, 5),
nn.BatchNorm2d(16),
nn.ReLU(True),
nn.MaxPool2d(2, 2)
)
self.classfy = nn.Linear(400, 10)
def forward(self, x):
x = self.stage1(x)
x = x.view(x.shape[0], -1)
x = self.classfy(x)
return x
net = conv_bn_net()
optimizer = torch.optim.SGD(net.parameters(), 1e-1) # 使用随机梯度下降,学习率 0.1
train(net, train_data, test_data, 5, optimizer, criterion)
来源:https://blog.csdn.net/xckkcxxck/article/details/82348789
![](https://www.aspxhome.com/images/zang.png)
![](https://www.aspxhome.com/images/jiucuo.png)
猜你喜欢
- 一看,C盘只有不到2M可用空间,一查原因,sqlserver安装路径下的log目录文件占了好大,5G多, 于是上网搜了下,解决了: 把与sq
- 这个话题是应腾讯ISD同仁之邀在WebReBuild三周年交流会上做的主题分享。由于临场等原因有些问题当时没有讲明白,回来后按原有思路形成了
- 概述在使用keras中的keras.backend.batch_dot和tf.matmul实现功能其实是一样的智能矩阵乘法,比如A,B,C,
- 本文实例讲述了python实现获取序列中最小的几个元素。分享给大家供大家参考。具体方法如下:import heapq import rand
- 背景我们经常调侃程序员每天都在写bug,这确实是事实,没有测出bug不代表程序就真的不存在问题。传统的代码review、静态分析、人工测试和
- Tornado 4.0 已经发布了很长一段时间了, 新版本广泛的应用了协程(Future)特性. 我们目前已经将 Tornado 升级到最新
- 前言.net core来势已不可阻挡。既然挡不了,那我们就顺应它。了解它并学习它。今天我们就来看看和之前.net版本的配置文件读取方式有何异
- javascript版 俄罗斯方块(Russian box)小游戏,喜欢的朋友可以玩玩。对源代码感兴趣的朋友也可以研究一下。玩法介绍:可以输
- PDO::queryPDO::query — 执行 SQL 语句,返回PDOStatement对象,可以理解为结果集(PHP 5 >=
- 本文实例为大家分享了JS实现拖动模糊框特效的具体代码,供大家参考,具体内容如下需求:在图片上拖动按钮,图片蒙层慢慢覆盖,当蒙层边缘碰到左右下
- PHP的header函数 可以很少代码就能实现HTML代码中META 标签这里只说用 header函数来做页面的跳转1. HTML代码中页面
- 从开始认识CSS(DW4)那时起,我就知道了CSS的强大,但从未用CSS排版过,因为我曾经尝试过学习,但感觉太难了而且用DW的表格,所见及所
- javascript实现翻页效果:<html> <head> <title>上下翻页看 - aspxho
- 疫情数据程序源码// An highlighted blockimport requestsimport jsonclass epidemi
- 假设有2个有序列表l1、l2,如何效率比较高的将2个list合并并保持有序状态,这里默认排序是正序。思路是比较简单的,无非是依次比较l1和l
- 本文实例讲述了python函数enumerate,operator和Counter使用技巧。分享给大家供大家参考,具体如下:最近看人家的代码
- 一、装饰器decorator decorator设计模式允许动态地对现有的对象或函数包装以至于修改现有的职责和行为,简单地讲用来动态地扩展现
- 关于权限管理的思考最近用laravel设计后台,后台需要有个权限管理。权限管理实质上分为两个部分,首先是认证,然后是权限。认证部分非常好做,
- 在PHP中,有两种包含外部文件的方式,分别是include和require。他们之间有什么不同呢?如果文件不存在或发生了错误,require
- WordPress 的插件机制实际上只的就是这个 Hook 了,它中文被翻译成钩子,允许你参与 WordPress 核心的运行,是一个非常棒