pytorch+sklearn实现数据加载的流程
作者:梁小憨憨 发布时间:2022-05-15 14:44:27
之前在训练网络的时候加载数据都是稀里糊涂的放进去的,也没有理清楚里面的流程,今天整理一下,加深理解,也方便以后查阅。
pytorch+sklearn实现数据加载
epoch & batch_size & iteration
epoch
:1个epoch等于使用训练集中的全部样本训练一次,通俗的讲epoch的值就是整个数据集被轮几次。batch_size
:批大小。在深度学习中,一般采用SGD训练,即每次训练在训练集中取batchsize个样本训练;iteration
:1个iteration等于使用batch_size个样本训练一次;
优化算法——梯度下降
深度学习的优化算法,说白了就是梯度下降。每次的参数更新有两种方式。
Batch gradient descent
第一种,遍历全部数据集
算一次损失函数,然后算函数对各个参数的梯度,更新梯度,这称为批梯度下降(Batch gradient descent)
。
这样做至少有 2 个好处:其一,由全数据集确定的方向能够更好地代表样本总体,从而更准确地朝向极值所在的方向。其二,由于不同权重的梯度值差别巨大,因此选取一个全局的学习率很困难。 Full Batch Learning 可以使用 Rprop 只基于梯度符号并且针对性单独更新各权值。
对于更大的数据集,以上 2 个好处又变成了 2 个坏处:其一,随着数据集的海量增长和内存限制,一次性载入所有的数据进来变得越来越不可行。其二,以 Rprop 的方式迭代,会由于各个 Batch 之间的采样差异性,各次梯度修正值相互抵消,无法修正。这才有了后来 RMSProp 的妥协方案。
Stochastic gradient descent
另一种,每看一个数据就算一下损失函数
,然后求梯度更新参数,这个称为随机梯度下降(Stochastic gradient descent)
。这个方法速度比较快,但是收敛性能不太好,可能在最优点附近晃来晃去,达不到最优点。两次参数的更新也有可能互相抵消掉,造成目标函数震荡的比较剧烈。
Mini-batch gradient decent
为了克服两种方法的缺点,现在一般采用的是一种折中手段,mini-batch gradient decent
,小批的梯度下降,这种方法把数据分为若干个批,按批来更新参数,这样,一个批中的一组数据共同决定了本次梯度的方向,下降起来就不容易跑偏,减少了随机性。另一方面因为批的样本数与整个数据集相比小了很多,计算量也不是很大。
现在用的优化器SGD是stochastic gradient descent的缩写,但不代表是一个样本就更新一回,还是基于mini-batch的。
批量梯度下降:批量大小=训练集的大小
随机梯度下降:批量大小= 1
小批量梯度下降:1 <批量大小<训练集的大小
在小批量梯度下降的情况下,流行的批量大小包括32,64和128个样本。
再谈Batch_Size
在合理范围内,增大 Batch_Size 有何好处?
内存利用率提高了,大矩阵乘法的并行化效率提高。
跑完一次 epoch(全数据集)所需的迭代次数减少,对于相同数据量的处理速度进一步加快。
在一定范围内,一般来说 Batch_Size 越大,其确定的下降方向越准,引起训练震荡越小。
盲目增大 Batch_Size 有何坏处?
内存利用率提高了,但是内存容量可能撑不住了。
跑完一次 epoch(全数据集)所需的迭代次数减少,要想达到相同的精度,其所花费的时间大大增加了,从而对参数的修正也就显得更加缓慢。
Batch_Size 增大到一定程度,其确定的下降方向已经基本不再变化。
深度学习的第一项任务——数据加载
数据加载流程——重要
以BCICIV_2a数据为例
import mne
import numpy as np
import torch
import torch.nn as nn
class LoadData:
def __init__(self,eeg_file_path: str):
self.eeg_file_path = eeg_file_path
def load_raw_data_gdf(self,file_to_load):
self.raw_eeg_subject = mne.io.read_raw_gdf(self.eeg_file_path + '/' + file_to_load)
return self
def load_raw_data_mat(self,file_to_load):
import scipy.io as sio
self.raw_eeg_subject = sio.loadmat(self.eeg_file_path + '/' + file_to_load)
def get_all_files(self,file_path_extension: str =''):
if file_path_extension:
return glob.glob(self.eeg_file_path+'/'+file_path_extension)
return os.listdir(self.eeg_file_path)
class LoadBCIC(LoadData):
'''Subclass of LoadData for loading BCI Competition IV Dataset 2a'''
def __init__(self, file_to_load, *args):
self.stimcodes=('769','770','771','772')
# self.epoched_data={}
self.file_to_load = file_to_load
self.channels_to_remove = ['EOG-left', 'EOG-central', 'EOG-right']
super(LoadBCIC,self).__init__(*args)
def get_epochs(self, tmin=0,tmax=1,baseline=None):
self.load_raw_data_gdf(self.file_to_load)
raw_data = self.raw_eeg_subject
# raw_downsampled = raw_data.copy().resample(sfreq=128)
self.fs = raw_data.info.get('sfreq')
events, event_ids = mne.events_from_annotations(raw_data)
stims =[value for key, value in event_ids.items() if key in self.stimcodes]
epochs = mne.Epochs(raw_data, events, event_id=stims, tmin=tmin, tmax=tmax, event_repeated='drop',
baseline=baseline, preload=True, proj=False, reject_by_annotation=False)
epochs = epochs.drop_channels(self.channels_to_remove)
self.y_labels = epochs.events[:, -1] - min(epochs.events[:, -1])
self.x_data = epochs.get_data()*1e6
eeg_data={'x_data':self.x_data,
'y_labels':self.y_labels,
'fs':self.fs}
return eeg_data
data_path = "/home/pytorch/LiangXiaohan/MI_Dataverse/BCICIV_2a_gdf"
file_to_load = 'A01T.gdf'
'''for BCIC Dataset'''
bcic_data = LoadBCIC(file_to_load, data_path)
eeg_data = bcic_data.get_epochs() # {'x_data':, 'y_labels':, 'fs':}
X = eeg_data.get('x_data')
Y = eeg_data.get('y_labels')
Y.shape
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=0)
X_train.shape
from sklearn.model_selection import StratifiedKFold
train_idx = {}
eval_idx = {}
skf = StratifiedKFold(n_splits=4, shuffle=True)
i = 0
for train_indices, eval_indices in skf.split(X_train, y_train):
train_idx.update({i: train_indices})
eval_idx.update({i: eval_indices})
i += 1
train_idx.get(1).shape
def split_xdata(eeg_data, train_idx, eval_idx):
x_train=np.copy(eeg_data[train_idx,:,:])
x_eval=np.copy(eeg_data[eval_idx,:,:])
x_train = torch.from_numpy(x_train).to(torch.float32)
x_eval = torch.from_numpy(x_eval).to(torch.float32)
return x_train, x_eval
def split_ydata(y_true, train_idx, eval_idx):
y_train = np.copy(y_true[train_idx])
y_eval = np.copy(y_true[eval_idx])
y_train = torch.from_numpy(y_train)
y_eval = torch.from_numpy(y_eval)
return y_train, y_eval
x_train, x_eval = split_xdata(X_train, train_idx.get(1), eval_idx.get(1))
y_train, y_eval = split_ydata(Y_train, train_idx.get(1), eval_idx.get(1))
y_train.shape
from torch.utils.data import Dataset, DataLoader, TensorDataset
from tqdm import tqdm
def BCICDataLoader(x_train, y_train, batch_size=64, num_workers=2, shuffle=True):
data = TensorDataset(x_train, y_train)
train_data = DataLoader(dataset=data, batch_size=batch_size, shuffle=shuffle, num_workers=num_workers)
return train_data
train_data = BCICDataLoader(x_train, y_train, batch_size=32)
for inputs, target in tqdm(train_data):
print(target)
到此数据就读出来了!!!
相关API解释
sklearn.model_selection.train_test_split
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.train_test_split.html?highlight=train_test_split
sklearn.model_selection.StratifiedKFold
https://scikit-learn.org/stable/modules/generated/sklearn.model_selection.StratifiedKFold.html?highlight=stratifiedkfold#sklearn.model_selection.StratifiedKFold
torch.utils.data.TensorDataset
https://pytorch.org/docs/stable/data.html?highlight=tensordataset#torch.utils.data.TensorDataset
torch.utils.data.DataLoader
https://pytorch.org/docs/stable/data.html?highlight=dataloader#torch.utils.data.DataLoader
参考资料
深度学习中的batch、epoch、iteration的含义
神经网络中Batch和Epoch之间的区别是什么?
谈谈深度学习中的 Batch_Size
来源:https://blog.csdn.net/qq_41990294/article/details/127849876
猜你喜欢
- 前言我们知道在这个互联网时代,评论已经在我们的生活到处可见,评论区里面的信息是一个非常有趣和有争议的地方。我们今天,就来获取某技术平台的评论
- 前言np.argmax是用于取得数组中每一行或者每一列的的最大值。常用于机器学习中获取分类结果、计算精确度等。函数:numpy.argmax
- 【先锋缓存类】Ver2004作者:孙立宇、apollosun、ezhonghua官方网站:http://www.lkstar.com 技术支
- 前言当我们需要对列表(list)、元组(tuple)、字典(dictionary)和集合(set)的元素进行遍历时,其实Python内部都是
- 人脸检测方法有许多,比如opencv自带的人脸Haar特征分类器和dlib人脸检测方法等。对于opencv的人脸检测方法,有点是简单,快速;
- 接触Python时间不长,对有些知识点,掌握的不是很扎实,我个人比较崇尚不管学习什么东西,首先一定回去把基础打的非常扎实了,再往高处走。今天
- 1. 查找图像中出现的人脸代码示例:#导入face_recognition模块import face_recognition#将j
- 字体的处理在网页设计中无论怎么强调也不为过, 毕竟网页使用来传递信息的, 而最经典最直接的信息传递方式就是文字,&nbs
- 本文实例讲述了php替换字符串中间字符为省略号的方法。分享给大家供大家参考。具体分析如下:对于一个长字符串,如果你只希望用户看到头尾的部分内
- 在使用matplotlib模块时画坐标图时,往往需要对坐标轴设置很多参数,这些参数包括横纵坐标轴范围、坐标轴刻度大小、坐标轴名称等 在mat
- 现将几种主要情况进行小结: 一、如何输入NULL值 如果不输入null值,当时间为空时,会默认写入"1900-01-01"
- 下面就是我们的authenticate.asp页面,在这里,将用户的信息收集起来,连同最初的URL一起传到一个识别用户身份的页面中。我们可用
- 这个使用起来很简单,以前需要的时候在网上找的,用了感觉还不错,具体的看演示就明白了。,这个可以保留你文章中的html标记,需要你修改的就是下
- <%dim total(7,1) total(1,0)="中国经营报"
- 准确地讲,Python没有专门处理字节的数据类型。但由于str既是字符串,又可以表示字节,所以,字节数组=str。而在C语言中,我们可以很方
- 思路1.将姓名和单号填入excel表格里面2.读取excel表格,将所有姓名存到ExeclName这个list中,单号存到ExeclId3.
- 最近碰到一个mysql5数据库的问题。就是一个标准的servlet/tomcat网络应用,后台使用mysql数据库。问题是待机一晚上后,第二
- 例:公司员工采取三个轮班制度:凌晨0:00到早上8:00为第一班,早上8:00到下午4:00为第二班,下午4:00到晚上12:00为第三班。
- Python 的datetime模块 其实就是date和time 模块的结合,常见的属性方法都比较常用 比如: datetime.day,d
- 在上一篇Python接口自动化测试系列文章:Python接口自动化浅析logging日志原理及模块操作流程,主要介绍日志相关概念及loggi