Tensorflow的DataSet的使用详解
作者:月司 发布时间:2021-03-19 18:18:04
Dataset类是TensorFlow非常流行的存储数据的格式。常用来作为输入输出。data模块主要的用途就是通过这种方法创建Dataset。
Dataset使用过程中的一些心得:
经常将自变量X数据以及target数据以元组的形式包裹,如db_train=tf.data.Dataset.from_tensor_slices((x_train,y_train)),创建Dataset。模型的fit()方法可以自动的解包。
Dataset能够包括比较灵活的类型,比如db_train=tf.data.Dataset.from_tensor_slices(({"features":features_train,"biomass_start":biomass_start_trarin},y_train))。因为数据最外部依然是最外部包裹,所以model的fit()依然可以自动的对x以及target解包。但由于dataset保存component是以原始数据的形式保存的。所以,fit()里的inputs一般是这个样子:
{'features': <tf.Tensor 'my_rnn/Cast_1:0' shape=(None, 5, 4) dtype=float32>, 'biomass_start': <tf.Tensor 'my_rnn/Cast:0' shape=(None, 1) dtype=float32>}
对于字典内部部分,需要手动的自己解包。这样的好处是,给我们自定义模型的结构提供的很大的遍历,输入一部分导入A网络,一部分导入不同的B网络。
Dataset作为模型的输入,需要设定batch()。而不在模型内设定batch。更加方便。然而Dataset作为迭代器,迭代完成后再次迭代数据,生成数据的前后数据是不一样的。需要注意。
batch的drop_remainder=True参数比较重要,只有设定为True,input接下来的层还能正确的识别shape
Dataset的常用属性
Dataset.element_spec
这个属性可以检测每一个元素中的component的类型。返回的是一个tf.TypeSpec对象。这个对象的结构跟元素的结构是一致的。
dataset1 = tf.data.Dataset.from_tensor_slices(tf.random.uniform([4, 10]))
dataset1.element_spec
#TensorSpec(shape=(10,), dtype=tf.float32, name=None)
dataset2 = tf.data.Dataset.from_tensor_slices(
(tf.random.uniform([4]),
tf.random.uniform([4, 100], maxval=100, dtype=tf.int32)))
dataset2.element_spec
# 标量和向量
# (TensorSpec(shape=(), dtype=tf.float32, name=None),
#TensorSpec(shape=(100,), dtype=tf.int32, name=None))
dataset = tf.data.Dataset.from_tensor_slices(([1, 2], [3, 4], [5, 6]))
dataset.element_spec
#(TensorSpec(shape=(), dtype=tf.int32, name=None),
# TensorSpec(shape=(), dtype=tf.int32, name=None),
# TensorSpec(shape=(), dtype=tf.int32, name=None))
# 注意这里是字典类型
dataset = tf.data.Dataset.from_tensor_slices({"a": [1, 2], "b": [3, 4]})
dataset.element_spec
#{'a': TensorSpec(shape=(), dtype=tf.int32, name=None),
# 'b': TensorSpec(shape=(), dtype=tf.int32, name=None)}
Dataset的常用方法
apply方法
对dataset进行转换。
dataset = tf.data.Dataset.range(100)
def dataset_fn(ds):
return ds.filter(lambda x: x < 5)
dataset = dataset.apply(dataset_fn)
list(dataset.as_numpy_iterator())
as_numpy_iterator
dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3])
for element in dataset.as_numpy_iterator():
print(element)
这个在dataset比较常用。就是将dataset变成迭代器,将所有元素都变成numy对象输出
shuffle
shuffle(
buffer_size, seed=None, reshuffle_each_iteration=None, name=None
)
参数:
buffer_size:缓冲区大小
seed:随机种子
reshuffle_each_iteration:bool. 如果为真,表示每次迭代时数据集完成后都应该是进行伪随机重新洗牌的。控制每个epoch的洗牌顺序是否不同。
这个方法用来随机打乱数据集的元素顺序。数据集用buffer_size元素填充一个缓冲区,然后从这个缓冲区随机取样元素,用新元素替换选中的元素。例如,如果您的数据集包含10,000个元素,但是buffer_size被设置为1,000,那么shuffle将首先从缓冲区中的前1,000个元素中选择一个随机元素。一旦一个元素被选中,它在缓冲区中的空间就会被下一个(比如第1001个)元素替换,从而保持这个1,000元素缓冲区。为了实现完美的洗牌,需要一个大于或等于数据集完整大小的缓冲区。
dataset = tf.data.Dataset.range(3)
# 每个每个epoch重新洗牌
dataset = dataset.shuffle(3, reshuffle_each_iteration=True)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 2, 0]
dataset = tf.data.Dataset.range(3)
# 每个每个epoch不重新洗牌
dataset = dataset.shuffle(3, reshuffle_each_iteration=False)
list(dataset.as_numpy_iterator())
# [1, 0, 2]
list(dataset.as_numpy_iterator())
# [1, 0, 2]
batch
batch(
batch_size,
drop_remainder=False,
num_parallel_calls=None,
deterministic=None,
name=None
)
参数:
batch_size: 批处理大小
drop_remainder:是否删除最后一个短batch。==这个比较重要,只有设定为Ture,model才能正确的判断其输入的shape。==这也比较合理,指定为Falsel,因为谁也不知道后面是不是有一个比较短的batch,只有第一维是None,才能提高程序的稳定性。
num_parallel_calls:并行计算的数量。不指定会顺序执行。如果有 tf.data.AUTOTUNE,会自动动态的制定这个值。
deterministic:bool. 指定了num_parallel_calls,才有效。如果设置为False,则允许转换产生无序元素,以牺牲确定性来换取性能。如果不指定,tf.data.Options.deterministic控制这个行为(默认为True)
name: 标识符
这个方法经常使用,将dataset进行批处理化。因为数据集比较大的时候,一下子完全进行训练占用大量的内存。所以用分批处理。输出的元素增加了一个额外的维度,就是batch维,shape是batch的size.
batch支持一个drop_remainder=True关键字,为真意味着,最后一个batch的size如果小于我们指定值,就会被舍弃。
之所以要删掉最后一个短的batch,是因为如果我们的项目依赖这个batch的size,那最后一个batch不等长,可能会出错。
import tensorflow as tf
from tensorflow.python.data import Dataset
dataset = tf.data.Dataset.range(8)
dataset = dataset.batch(3)
print(list(dataset.as_numpy_iterator()))
# 通过这个看到这个elem也已经是分批了
for elem in dataset:
print(elem)
# tf.Tensor([0 1 2], shape=(3,), dtype=int64)
# tf.Tensor([3 4 5], shape=(3,), dtype=int64)
# tf.Tensor([6 7], shape=(2,), dtype=int64)
for elem in dataset.as_numpy_iterator():
print(elem)
# [0 1 2]
# [3 4 5]
# [6 7]
dataset = tf.data.Dataset.range(8)
# drop_remainder舍掉最后一个长度不够的batch
dataset = dataset.batch(3, drop_remainder=True)
list(dataset.as_numpy_iterator())
一般情况下,shuffle跟batch是连续使用的,实现随机读取并批量处理数据:dataset.shuffle(buffer_size).batch(batchsize)
不能对已经batch的dataset进行连续的batch操作,其batchsize不会改变,而是生成了新的异常数据
unbatch
unbatch(
name=None
)
这里是将Batchdataset这样的dataset分割为一个个元素,元素的格式跟定义时的格式是一样的。而且,这里固定的是对第1个维度进行split操作,且生成shape[0]个元素。
reduce方法
reduce(
initial_state, reduce_func, name=None
)
将输入数据集简化为一个元素。 reduce_func作用于dataset中每一个元素,输出其dataset的聚合信息。
参数initial_state代表进行reduce之前的初始状态。reduce_func要接收old_state, input_element两个参数,然后生成新的状态newstate。old_state和new_state的结构要一致。
dataset = tf.data.Dataset.from_tensor_slices([8, 3, 0, 8, 2, 1])
print(dataset.reduce(0, lambda state, value: state + value).numpy())
# 22
dataset不支持tf.split属性,也不能直接把dataset给切分为训练集和测试集。
来源:https://blog.csdn.net/yue81560/article/details/128691866


猜你喜欢
- 给一个例子 :# -*- coding: utf-8 -*-import matplotlib.pyplot as plt im
- 数据持久化vuex-persistedstate使用vuex是在中大型项目中必不可少的状态管理组件,刷新会重新更新状态,但是有时候我们并不希
- requests 是一个非常小巧全面的库,应用它可以很容易写出与服务器进行交互的程序,今天遇到了一个问题,与服务器交互时,url都是http
- 大概在九九年做游戏网站的时候,就对文章的发布感到麻烦,不过那会儿玩ASP不精。只是将就用着。在遇到长文件 10000 字时网页就是一大片长了
- 什么是pyecharts?pyecharts 是一个用于生成 Echarts 图表的类库。echarts是百度开源的一个数据可视化 JS 库
- 这个效果本身难度不大,主要在程序结构和扩展中下了些功夫,务求用起来更方便,能用在更多的地方。程序特点 1,同一个提示框用在多个触发元素时,只
- 起因事情是这样的,项目最近有个需求。服务器有个图片空间,说白了就是个文件夹。文件夹的结构大家都知道,一层一层的。然后需要在前端以树形展示。具
- 本文实例讲述了Python微信企业号文本消息推送功能。分享给大家供大家参考,具体如下:企业号的创建、企业号应用的创建、组、tag、part就
- 这边我是需要得到图片在Vgg的5个block里relu后的Feature Map (其余网络只需要替换就可以了)索引可以这样获得vgg =
- 前言最近有朋友在做投票的项目,里面有用到一个倒计时的组件,还想要个动画效果。cv * 浸染多年的我,首先想到的是直接找个现有的组件。通过一通搜
- 目录0 背景说明0.1 获取AccessToken0.2 数据库查询0.3 文件下载2. 简单的封装3. 简单测试4. 参考文档0 背景说明
- 1.lxml库简介lxml 是 Python 常用的文档解析库,能够高效地解析 HTML/XML 文档,常用于 Python 爬虫。lxml
- 在做网站产品展示页面时,一般会用到缩略图,好处当然是直观醒目让人一目了然。点击进入然后看到大图及具体的介绍。但是缩略图在实现上带来了两个问题
- 1、页签的表达。页签表达很清晰,当前页签突出,且层级包涵关系明确;看下图,一目了然的感觉,不用疑惑我在那部分里。不信?拿当当的对比一下,你感
- 前言在日常开发中,我们往往会将 JSON 解析成对应的结构体,反之也会将结构体转成 JSON。接下来本文会通过 JSON 包的两个函数,来介
- 如果你取相对路径不是在主文件里,可能就会有相对路径问题:"No such file or directory"。因为 p
- 在推广Web标准的今天,那些崇尚Web标准的人经常说XHTML比HTML更加严格,当然从某种意义上说是的,比如它要求所有的标签关闭并且所有的
- 与其他技术相比,Git应该拯救了更多开发人员的饭碗。只要你经常使用Git保存自己的工作,你就一直有机会可以将代码退回到之前的状态,因此就可以
- 本文实例讲述了php防止sql注入中过滤分页参数的方法。分享给大家供大家参考。具体分析如下:就网络安全而言,在网络上不要相信任何输入信息,对
- 1、Session的存储方式。 session其实分为客户端Session和服务器端Session。 当用户首次与Web服务器建立连接的时候