浅谈tensorflow中dataset.shuffle和dataset.batch dataset.repeat注意点
作者:青盏 发布时间:2023-12-10 09:54:31
batch很好理解,就是batch size。注意在一个epoch中最后一个batch大小可能小于等于batch size
dataset.repeat就是俗称epoch,但在tf中与dataset.shuffle的使用顺序可能会导致个epoch的混合
dataset.shuffle就是说维持一个buffer size 大小的 shuffle buffer,图中所需的每个样本从shuffle buffer中获取,取得一个样本后,就从源数据集中加入一个样本到shuffle buffer中。
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
#源数据集
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]
[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
# 通过shuffle batch后取得的样本
[[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
[ 0.5488135 0.71518937]]
[[ 0.96366276 0.38344152]
[ 0.56804456 0.92559664]
[ 0.0202184 0.83261985]
[ 0.79172504 0.52889492]]
[[ 0.07103606 0.0871293 ]
[ 0.97861834 0.79915856]
[ 0.77815675 0.87001215]] #最后一个batch样本个数为3
[[ 0.60276338 0.54488318]
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
[ 0.79172504 0.52889492]]
[[ 0.4236548 0.64589411]
[ 0.56804456 0.92559664]
[ 0.0202184 0.83261985]
[ 0.07103606 0.0871293 ]]
[[ 0.77815675 0.87001215]
[ 0.96366276 0.38344152]
[ 0.97861834 0.79915856]] #最后一个batch样本个数为3
1、按照shuffle中设置的buffer size,首先从源数据集取得三个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3、shuffle buffer不足三个样本,从源数据集提取一个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4、从buffer中取一个样本到batch中得:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5、如此反复。这就意味中如果shuffle 的buffer size=1,数据集不打乱。如果shuffle 的buffer size=数据集样本数量,随机打乱整个数据集
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]
[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]]
[[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]]
[[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]]
[[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
注意如果repeat在shuffle之前使用:
官方说repeat在shuffle之前使用能提高性能,但模糊了数据样本的epoch关系
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ""
import numpy as np
import tensorflow as tf
np.random.seed(0)
x = np.random.sample((11,2))
# make a dataset from a numpy array
print(x)
print()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)
# create the iterator
iter = dataset.make_one_shot_iterator()
el = iter.get_next()
with tf.Session() as sess:
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
print(sess.run(el))
[[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
[ 0.43758721 0.891773 ]
[ 0.96366276 0.38344152]
[ 0.79172504 0.52889492]
[ 0.56804456 0.92559664]
[ 0.07103606 0.0871293 ]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.97861834 0.79915856]]
[[ 0.56804456 0.92559664]
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.07103606 0.0871293 ]]
[[ 0.96366276 0.38344152]
[ 0.43758721 0.891773 ]
[ 0.43758721 0.891773 ]
[ 0.77815675 0.87001215]]
[[ 0.79172504 0.52889492] #出现相同样本出现在同一个batch中
[ 0.79172504 0.52889492]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]]
[[ 0.07103606 0.0871293 ]
[ 0.4236548 0.64589411]
[ 0.96366276 0.38344152]
[ 0.5488135 0.71518937]]
[[ 0.97861834 0.79915856]
[ 0.0202184 0.83261985]
[ 0.77815675 0.87001215]
[ 0.56804456 0.92559664]]
[[ 0.0202184 0.83261985]
[ 0.97861834 0.79915856]] #可以看到最后个batch为2,而前面都是4
使用案例:
def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
print('Parsing', filenames)
def decode_libsvm(line):
#columns = tf.decode_csv(value, record_defaults=CSV_COLUMN_DEFAULTS)
#features = dict(zip(CSV_COLUMNS, columns))
#labels = features.pop(LABEL_COLUMN)
columns = tf.string_split([line], ' ')
labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
splits = tf.string_split(columns.values[1:], ':')
id_vals = tf.reshape(splits.values,splits.dense_shape)
feat_ids, feat_vals = tf.split(id_vals,num_or_size_splits=2,axis=1)
feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
#feat_ids = tf.reshape(feat_ids,shape=[-1,FLAGS.field_size])
#for i in range(splits.dense_shape.eval()[0]):
# feat_ids.append(tf.string_to_number(splits.values[2*i], out_type=tf.int32))
# feat_vals.append(tf.string_to_number(splits.values[2*i+1]))
#return tf.reshape(feat_ids,shape=[-1,field_size]), tf.reshape(feat_vals,shape=[-1,field_size]), labels
return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels
# Extract lines from input files using the Dataset API, can pass one filename or filename list
dataset = tf.data.TextLineDataset(filenames).map(decode_libsvm, num_parallel_calls=10).prefetch(500000) # multi-thread pre-process then prefetch
# Randomizes input using a window of 256 elements (read into memory)
if perform_shuffle:
dataset = dataset.shuffle(buffer_size=256)
# epochs from blending together.
dataset = dataset.repeat(num_epochs)
dataset = dataset.batch(batch_size) # Batch size to use
#return dataset.make_one_shot_iterator()
iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()
#return tf.reshape(batch_ids,shape=[-1,field_size]), tf.reshape(batch_vals,shape=[-1,field_size]), batch_labels
return batch_features, batch_labels
来源:https://blog.csdn.net/qq_16234613/article/details/81703228
猜你喜欢
- 分析数字中经常是3个数字一组,之后跟一个逗号,因此规律为:***,***,***正则式[a-z]+,[a-z]?import resen =
- 绘制双变量联合分布图有时我们不仅需要查看单个变量的分 布,同时也需要查看变量之间的联系, 往往还需要进行预测等。这时就需要用到双变量联合分布
- 使用cv2对视频进行切割import cv2def clip_video(source_video, target_video, start
- 1. 概念1.1 基本概念时间,对于我们来说很重要,什么时候做什么?什么时候发生什么?没有时间的概念,生活就乱了。在日常的运维当中,我们更关
- 人口普查人口数量变化图1 第七次人口普查不同省份总人口import pandas as pdfrom collections import
- Django2.0中编写models类下的ForeignKeybook = models.ForeignKey('BookInfo&
- 这是我使用python写的第一个类(也算是学习面向对象语言以来正式写的第一个解耦的类),记录下改进的过程。分析需求最初,因为使用time模块
- 一. 元组元组是Python中的一个内置的数据结构,它是一个不可变的序列,所谓的不可变序列就是不可以进行增删改的操作。1.1 元组的创建元组
- 概述从今天开始我们将开启一段自然语言处理 (NLP) 的旅程. 自然语言处理可以让来处理, 理解, 以及运用人类的语言, 实现机器语言和人类
- 1:为什么每个layout下都有个inlayout?我们将layout的宽/浮动等属性设置好之后,对于layout内的padding和mar
- blur事件在元素失去焦点时触发。在一些jquery的教程、api手册等上面对blur事件,提供了一个错误的例子,就是关于p标签失去焦点的问
- 一、绘制折线图import seaborn as snsimport numpy as npimport pandas as pdimpor
- 在自己的网站主页上增加社会化分享按钮,是有效提高自己网站流量的一种方法。今天我在无争围棋网上增加了社会化按钮,根据我个人的习惯,我选择了豆瓣
- 如何在约定时间显示特定的提示信息?<%Function Greeting()
- 百度有啊2009年情人节logo——大纸袋GG给大纸袋MM送了枝玫瑰花,大纸袋MM奖励了大纸袋GG一个吻,好可爱!淘宝网2009年情人节lo
- Python中的多线程其实并不是真正的多线程,如果想要充分地使用多核CPU的资源,在python中大部分情况需要使用多进程。Python提供
- PHP 向它运行的任何脚本提供了大量的预定义常量。魔术常量准确来说并不能算是常量,常量我们在之前的文章中我们介绍到,常量被定义之后是不能被改
- 在DreamWeaver中编写CSS,这种编写习惯本站(twocity.cn)并不提倡,不过由于"可视化"和操作简便,使
- 目录前言super的用法super的原理Python super()使用注意事项混用super与显式类调用不同种类的参数总结前言Python
- 这篇文章主要介绍了Python加密模块的hashlib,hmac模块使用解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的