浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点
作者:青盏 发布时间:2023-12-10 09:54:31
batch很好理解,就是batch size。注意在一个epoch中最后一个batch大小可能小于等于batch size
dataset.repeat就是俗称epoch,但在tf中与dataset.shuffle的使用顺序可能会导致个epoch的混合
dataset.shuffle就是说维持一个buffer size 大小的 shuffle buffer,图中所需的每个样本从shuffle buffer中获取,取得一个样本后,就从源数据集中加入一个样本到shuffle buffer中。
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
#源数据集
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]
[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
# 通过shuffle batch后取得的样本
[[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
[ 0.5488135 0.71518937]]
[[ 0.96366276 0.38344152]
[ 0.56804456 0.92559664]
[ 0.0202184 0.83261985]
[ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
[ 0.97861834 0.79915856]
[ 0.77815675 0.87001215]] #最后一个batch样本个数为3
[[ 0.60276338 0.54488318]
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
[ 0.79172504 0.52889492]]
[[ 0.4236548 0.64589411]
[ 0.56804456 0.92559664]
[ 0.0202184 0.83261985]
[ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
[ 0.96366276 0.38344152]
[ 0.97861834 0.79915856]] #最后一个batch样本个数为3
1、按照shuffle中设置的buffer size,首先从源数据集取得三个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3、shuffle buffer不足三个样本,从源数据集提取一个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5、如此反复。这就意味中如果shuffle 的buffer size=1,数据集不打乱。如果shuffle 的buffer size=数据集样本数量,随机打乱整个数据集
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]
[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]]
[[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]]
[[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
注意如果repeat在shuffle之前使用:
官方说repeat在shuffle之前使用能提高性能,但模糊了数据样本的epoch关系
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]
[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
[[ 0.56804456 0.92559664]
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
[ 0.43758721 0.891773 ]
[ 0.43758721 0.891773 ]
[ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492] #出现相同样本出现在同一个batch中
[ 0.79172504 0.52889492]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]]
[[ 0.07103606 0.0871293 ]
[ 0.4236548 0.64589411]
[ 0.96366276 0.38344152]
[ 0.5488135 0.71518937]]
[[ 0.97861834 0.79915856]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.56804456 0.92559664]]
[[ 0.0202184 0.83261985]
[ 0.97861834 0.79915856]] #可以看到最后个batch为2,而前面都是4
使用案例:
def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
print('Parsing', filenames)
def decode_libsvm(line):
#columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
#features = dict(zip(CSV_COLUMNS, columns))
#labels = features.pop(LABEL_COLUMN)
columns = tf.string_split([line], ' ')
labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
splits = tf.string_split(columns.values[1:], ':')
id_vals = tf.reshape(splits.values,splits.dense_shape)
feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
#feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
#for i in range(splits.dense_shape.eval()[0]):
# feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
# feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
#return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
#return dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
#return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
return batch_features, batch_labels
来源:https://blog.csdn.net/qq_16234613/article/details/81703228


猜你喜欢
- 就像标题呈现的一样,SQL Server 2008中的MERGE语句能做很多事情,它的功能是根据源表对目标表执行插入、更新或删除操作。最典型
- 一、基础表单 <form > <div class="form-group"> <labe
- 交代背景作为一名合格的 Python 程序员,在工作中必然会用到二维码相关操作,那如何快速的用 Python 实现呢?别着急,咱们这篇博客就
- 02条件语句和while循环三目运算a = 6#原判断语句if a > 5:print(True)else:print(False)#
- 工作时同事间几mb小文件的传输,一般使用QQ或者微信就足够了,但当传输文件几百MB或者几十G时,这种方法的效率就显得不足了。本篇就是简单说明
- 本文提供一种方法,通过将字符串编码成Unicode格式,保证数据在展示和传输过程中万无一失。无论客户端浏览器如何改变编码,页面上的编码都不会
- 简介pandas中的DF数据类型可以像数据库表格一样进行groupby操作。通常来说groupby操作可以分为三部分:分割数据,应用变换和和
- 本文主要讲解如何使用python绘制三维的柱形图,如下图源代码如下:import numpy as npimport matplotlib.
- Python下一切皆对象,每个对象都有多个属性(attribute),Python对属性有一套统一的管理方案。__dict__与dir()的
- 在线音乐播放器,使用python的Tkinter库做了一个界面,感觉这个库使用起来还是挺方便的,音乐的数据来自网易云音乐的一个接口,通过ur
- asp 中处理文件上传以及删除时常用的自定义函数:删除文件,建立目录的程序,根据原文件名生成新的随机文件名,CMS替换函数,将所有开始,结束
- 一、python numpy + matplotlib 画股票k线图# -- coding: utf-8 --import requests
- 前言:加班原因是上线,解决线上数据库存在重复数据的问题,发现了程序的bug,很好解决,有点问题的是,修正线上的重复数据。线上库有6个表存在重
- 一、概述在一般的sql操作中,sql语句基本上都是固定的,如: SELECT t.empno,t.ename FROM scott
- 一、eval()函数是什么?Python的一个内置函数;返回传入字符串的表达式结果(官方)二、eval()函数语法解析三、eval()函数应
- 本文实例讲述了Python实现向服务器请求压缩数据及解压缩数据的方法。分享给大家供大家参考,具体如下:向服务器请求压缩数据格式,并解压缩数据
- 出现的问题: 在 vue-cli 创建的项目中,创建文件并命名后,会报 “Compone
- 本文实例讲述了Django自定义过滤器定义与用法。分享给大家供大家参考,具体如下:一、自定义过滤器的介绍前面我们就介绍过过滤器其实就是一个函
- string操作在编程中具有极高的频率,那么string中有哪些有用的方法呢?使用strings直接操作Comparefunc Compar
- ScrapyScrapy是纯python实现的一个为了爬取网站数据、提取结构性数据而编写的应用框架。Scrapy使用了Twisted异步网络