keras的ImageDataGenerator和flow()的用法说明
作者:o0程卓0o 发布时间:2021-12-12 08:54:57
标签:keras,ImageDataGenerator,flow
ImageDataGenerator的参数自己看文档
from keras.preprocessing import image
import numpy as np
X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-6,
rotation_range=180,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0,
zoom_range=0.001,
channel_shift_range=0,
fill_mode='nearest',
cval=0.,
horizontal_flip=True,
vertical_flip=True,
rescale=None,
preprocessing_function=None,
data_format='channels_last')
a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配
'''
X,Y=next(a)
print(Y)
X,Y=next(a)
print(Y)
X,Y=next(a)
print(Y)
X,Y=next(a)
输出
[[2]
[1]
[2]]
[[2]
[2]
[1]]
[[2]
[2]
[1]]
[[2]
[2]
[1]]
补充知识:tensorflow 与keras 混用之坑
在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下
其中错误为:TypeError: tuple indices must be integers, not list
再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下
原训练代码
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
model = Sequential()
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer="Nadam",
metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = datagen.flow_from_directory(
train_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical')
val_generator = datagen.flow_from_directory(
val_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical')
test_generator = datagen.flow_from_directory(
test_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical')
model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples // batch_size,
epochs=epochs,
validation_data=val_generator,
validation_steps=nb_validation_samples // batch_size)
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")
模型载入
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")
报错
/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
File "/home/disk2/py/neroset/do.py", line 13, in <module>
model = load_model("grib.h5")
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
model = model_from_config(model_config, custom_objects=custom_objects)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
return layer_module.deserialize(config, custom_objects=custom_objects)
File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
printable_module_name='layer')
File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
list(custom_objects.items())))
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
model.add(layer)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
output_tensor = layer(self.outputs[0])
File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
self.build(input_shapes[0])
File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
Process finished with exit code 1
战斗种族解释
убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)
强调文本 强调文本
keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense
##完美解决
##附上原文链接
https://qa-help.ru/questions/keras-batchnormalization
来源:https://blog.csdn.net/CZ505632696/article/details/79515782
0
投稿
猜你喜欢
- 我有个需求就是抓取一些简单的书籍信息存储到mysql数据库,例如,封面图片,书名,类型,作者,简历,出版社,语种。我比较之后,决定在亚马逊来
- 一般调试程序的时候都比较倾向print,利用直接打印的方法作出判断,但是print只能打印出结果,对类型无法作出判断。例如:复制代码a =
- 现在网页的设计都讲究整体统一风格,无论是网页的文字、图像,还是浏览器的滚动条都要求颜色和风
- ctrl+Enter:重建ctrl+0:相当于点击当前行左方的加号或减号ctrl+E:打开新窗口预览ctrl+T:替换\t为两个空格tab:
- 本文实例讲述了Python事务操作实现方法。分享给大家供大家参考,具体如下:#coding=utf-8import sysimport My
- 语法: ROW_NUMBER() OVER([ <partition_by_clause>] <order_by_clau
- 用matplotlib.pyplot画的图,显示和保存的图片周围都会有白边,可以去掉。为了显示的更清楚,给图片加了红色的框代码“` impo
- 在Perfection kills上看到他去年写的一篇文章,关于HTML优化的,讲的很详细,姑且记录之,尽管里面有些东西并不能在目前的环境里
- 千图成像也就是用N张图片组成一张图片的效果。制作方法有很多的,最常见的如用ps、懒人图云、foto-mosaik-edda这些制作。千图成像
- 创建项目首先打开Pycharm勾选I confirm that I have read and accept the terms of th
- Python heapq 详解Python有一个内置的模块,heapq标准的封装了最小堆的算法实现。下面看两个不错的应用。小顶堆
- model:class Profile(models.Model): user = models.OneToOneField(User, o
- 网页链接:https://www.huya.com/g/4079 这里的主要步骤其实还是和我们之前分析的一样,如下图所示:这里再简单带大家看
- 简单的一个例子,是以前用Dephi写的,前不久刚实现了一个在Python中使用Delphi控件来编写界面程序,于是趁热写一个类似的的查询方案
- 本文实例为大家分享了opencv实现答题卡识别的具体代码,供大家参考,具体内容如下"""识别答题卡"
- OUTLINE 常见的时间字符串与timestamp之间的转换日期与timestamp之间的转换常见的时间字符串与timesta
- 1、新建DLL打开VB6-->文件-->新建工程-->选择ActiveX DLL-->确定2、将默认工程、类重命名工
- 在sql语句后使用 SCOPE_IDENTITY() 当然您也可以使用 SELECT @@IDENTITY 但是使用 SELECT @@ID
- ltp是哈工大出品的自然语言处理工具箱, pyltp是python下对ltp(c++)的封装.在linux下我们很容易的安装pyltp, 因
- 哪的资料都不如官方资料权威。今天总算从MSDN中择出了ASP编码问题的解决方案。下面是MSDN中的一段话。Setting @CODEPAGE