tensorflow中next_batch的具体使用
作者:小妖精Fsky 发布时间:2023-04-21 05:34:02
标签:tensorflow,nextbatch
本文介绍了tensorflow中next_batch的具体使用,分享给大家,具体如下:
此处给出了几种不同的next_batch方法,该文章只是做出代码片段的解释,以备以后查看:
def next_batch(self, batch_size, fake_data=False):
"""Return the next `batch_size` examples from this data set."""
if fake_data:
fake_image = [1] * 784
if self.one_hot:
fake_label = [1] + [0] * 9
else:
fake_label = 0
return [fake_image for _ in xrange(batch_size)], [
fake_label for _ in xrange(batch_size)
]
start = self._index_in_epoch
self._index_in_epoch += batch_size
if self._index_in_epoch > self._num_examples: # epoch中的句子下标是否大于所有语料的个数,如果为True,开始新一轮的遍历
# Finished epoch
self._epochs_completed += 1
# Shuffle the data
perm = numpy.arange(self._num_examples) # arange函数用于创建等差数组
numpy.random.shuffle(perm) # 打乱
self._images = self._images[perm]
self._labels = self._labels[perm]
# Start next epoch
start = 0
self._index_in_epoch = batch_size
assert batch_size <= self._num_examples
end = self._index_in_epoch
return self._images[start:end], self._labels[start:end]
该段代码摘自mnist.py文件,从代码第12行start = self._index_in_epoch开始解释,_index_in_epoch-1是上一次batch个图片中最后一张图片的下边,这次epoch第一张图片的下标是从 _index_in_epoch开始,最后一张图片的下标是_index_in_epoch+batch, 如果 _index_in_epoch 大于语料中图片的个数,表示这个epoch是不合适的,就算是完成了语料的一遍的遍历,所以应该对图片洗牌然后开始新一轮的语料组成batch开始
def ptb_iterator(raw_data, batch_size, num_steps):
"""Iterate on the raw PTB data.
This generates batch_size pointers into the raw PTB data, and allows
minibatch iteration along these pointers.
Args:
raw_data: one of the raw data outputs from ptb_raw_data.
batch_size: int, the batch size.
num_steps: int, the number of unrolls.
Yields:
Pairs of the batched data, each a matrix of shape [batch_size, num_steps].
The second element of the tuple is the same data time-shifted to the
right by one.
Raises:
ValueError: if batch_size or num_steps are too high.
"""
raw_data = np.array(raw_data, dtype=np.int32)
data_len = len(raw_data)
batch_len = data_len // batch_size #有多少个batch
data = np.zeros([batch_size, batch_len], dtype=np.int32) # batch_len 有多少个单词
for i in range(batch_size): # batch_size 有多少个batch
data[i] = raw_data[batch_len * i:batch_len * (i + 1)]
epoch_size = (batch_len - 1) // num_steps # batch_len 是指一个batch中有多少个句子
#epoch_size = ((len(data) // model.batch_size) - 1) // model.num_steps # // 表示整数除法
if epoch_size == 0:
raise ValueError("epoch_size == 0, decrease batch_size or num_steps")
for i in range(epoch_size):
x = data[:, i*num_steps:(i+1)*num_steps]
y = data[:, i*num_steps+1:(i+1)*num_steps+1]
yield (x, y)
第三种方式:
def next(self, batch_size):
""" Return a batch of data. When dataset end is reached, start over.
"""
if self.batch_id == len(self.data):
self.batch_id = 0
batch_data = (self.data[self.batch_id:min(self.batch_id +
batch_size, len(self.data))])
batch_labels = (self.labels[self.batch_id:min(self.batch_id +
batch_size, len(self.data))])
batch_seqlen = (self.seqlen[self.batch_id:min(self.batch_id +
batch_size, len(self.data))])
self.batch_id = min(self.batch_id + batch_size, len(self.data))
return batch_data, batch_labels, batch_seqlen
第四种方式:
def batch_iter(sourceData, batch_size, num_epochs, shuffle=True):
data = np.array(sourceData) # 将sourceData转换为array存储
data_size = len(sourceData)
num_batches_per_epoch = int(len(sourceData) / batch_size) + 1
for epoch in range(num_epochs):
# Shuffle the data at each epoch
if shuffle:
shuffle_indices = np.random.permutation(np.arange(data_size))
shuffled_data = sourceData[shuffle_indices]
else:
shuffled_data = sourceData
for batch_num in range(num_batches_per_epoch):
start_index = batch_num * batch_size
end_index = min((batch_num + 1) * batch_size, data_size)
yield shuffled_data[start_index:end_index]
迭代器的用法,具体学习Python迭代器的用法
另外需要注意的是,前三种方式只是所有语料遍历一次,而最后一种方法是,所有语料遍历了num_epochs次
来源:http://blog.csdn.net/appleml/article/details/57413615


猜你喜欢
- 先上网卡数据采集脚本,这个基本上是最大的坑,因为一些数据的类型不正确会导致no datapoint的错误,真是令人抓狂,注意其中几个key的
- 这篇文章主要介绍了简单了解为什么python函数后有多个括号,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需
- 1.requests库简介requests 是 Python 中比较常用的网页请求库,主要用来发送 HTTP 请求,在使用爬虫或测试服务器响
- ipad的goodreader对JS文件支持不太好,虽然可以读取它但总是无法退出,回不了goodreader的主界面,因此我需要把js文件批
- 一、背景俗话说,工欲善其事,必先利其器。go 作为一个对基础功能封装非常好的语言,对编码体验,如何更高效地写出高性能代码,都是考虑非常好的。
- 本来而言,这个问题网上很多资料,但是网上资料都是复制来复制去,很多话大家其实都不是很明白的,或者拿着官方文档翻译过来的,让人看的非常迷糊。今
- 前言开始几天,我是使用很原始的方法,自己去获取天气预报截图,再手动发送给小姐姐。连续几天之后我一想:不对呀,我怎么说也是一个程序猿,怎么能用
- 前言其实就是个小问题,但是爆出来的时候也很莫名其妙。因为之前都跑得好好的,只是换了不同的文件去跑才出的问题,关键是不同的文件要处理的内容和格
- 配置好virtualenv 和virtualenvwrapper后,使用pycharm创建新项目。之后要面临的问题就来了,之前一直使用的是s
- 我的读者知道我是一个喜欢痛骂Python3 unicode的人。这次也不例外。我将会告诉你用unicode有多痛苦和为什么我不能闭嘴。我花了
- 1.nginx使用哪种网络协议? nginx是应用层 我觉得从下往上的话 传输层用的是tcp/ip 应用层用的是http fastcgi负责
- 为什么能实现在线编辑呢? 首先需要ie 的支持,在 ie 5.5以后就有一个编辑状态,就是利用这个编辑状态,然后用javascript 来控
- Python的版本有很多,很多第三方库也有很多不同的版本,不同的版本也可能是互不兼容的,在本机运行不同的项目,可能需要不同的环境。为了不和本
- 由于个人能力有限,文章中难免会出现错误或遗漏的地方,敬请谅解!同时欢迎你指出,以便我能及时修改,以免误导下一个看官。最后希望本文能给你带来一
- 主题众所周知,django.forms极其强大,不少的框架也借鉴了这个模式,如Scrapy。在表单验证时,django.forms是一绝,也
- 一、http请求1、http请求方式:get和postget一般用于获取/查询资源信息,在浏览器中直接输入url+请求参数点击enter之后
- 方式一:叠加文字水印最简单的一种方式是,在图片上绘制半透明文本来实现水印效果。主要用到Figure.text函数参数类型说明x, yfloa
- 实现逻辑1、Golang 版本 1.32、实现原理:1、主进程建立TCP监听服务,并且初始化一个变量 talkChan := m
- 1.1.1 摘要 Join是关系型数据库系统的重要操作之一,SQL Server中包含的常用Join:内联接、外联接和交叉联接等。如果我们想
- 一、什么要备份数据库 ?在现实IT世界里,我们使用的服务器硬件可能因为使用时间过长,而发生故障;Windows系列服务器有可能蓝屏或者感染病