keras的ImageDataGenerator和flow()的用法说明
作者:o0程卓0o 发布时间:2021-12-12 08:54:57
标签:keras,ImageDataGenerator,flow
ImageDataGenerator的参数自己看文档
from keras.preprocessing import image
import numpy as np
X_train=np.ones((3,123,123,1))
Y_train=np.array([[1],[2],[2]])
generator=image.ImageDataGenerator(featurewise_center=False,
samplewise_center=False,
featurewise_std_normalization=False,
samplewise_std_normalization=False,
zca_whitening=False,
zca_epsilon=1e-6,
rotation_range=180,
width_shift_range=0.2,
height_shift_range=0.2,
shear_range=0,
zoom_range=0.001,
channel_shift_range=0,
fill_mode='nearest',
cval=0.,
horizontal_flip=True,
vertical_flip=True,
rescale=None,
preprocessing_function=None,
data_format='channels_last')
a=generator.flow(X_train,Y_train,batch_size=20)#生成的是一个迭代器,可直接用于for循环
'''
batch_size如果小于X的第一维m,next生成的多维矩阵的第一维是为batch_size,输出是从输入中随机选取batch_size个数据
batch_size如果大于X的第一维m,next生成的多维矩阵的第一维是m,输出是m个数据,不过顺序随机
,输出的X,Y是一一对对应的
如果要直接用于tf.placeholder(),要求生成的矩阵和要与tf.placeholder相匹配
'''
X,Y=next(a)
print(Y)
X,Y=next(a)
print(Y)
X,Y=next(a)
print(Y)
X,Y=next(a)
输出
[[2]
[1]
[2]]
[[2]
[2]
[1]]
[[2]
[2]
[1]]
[[2]
[2]
[1]]
补充知识:tensorflow 与keras 混用之坑
在使用tensorflow与keras混用是model.save 是正常的但是在load_model的时候报错了在这里mark 一下
其中错误为:TypeError: tuple indices must be integers, not list
再一一番百度后无结果,上谷歌后找到了类似的问题。但是是一对鸟文不知道什么东西(翻译后发现是俄文)。后来谷歌翻译了一下找到了解决方法。故将原始问题文章贴上来警示一下
原训练代码
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
#Каталог с данными для обучения
train_dir = 'train'
# Каталог с данными для проверки
val_dir = 'val'
# Каталог с данными для тестирования
test_dir = 'val'
# Размеры изображения
img_width, img_height = 800, 800
# Размерность тензора на основе изображения для входных данных в нейронную сеть
# backend Tensorflow, channels_last
input_shape = (img_width, img_height, 3)
# Количество эпох
epochs = 1
# Размер мини-выборки
batch_size = 4
# Количество изображений для обучения
nb_train_samples = 300
# Количество изображений для проверки
nb_validation_samples = 25
# Количество изображений для тестирования
nb_test_samples = 25
model = Sequential()
model.add(Conv2D(32, (7, 7), padding="same", input_shape=input_shape))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
model.add(Conv2D(64, (5, 5), padding="same"))
model.add(BatchNormalization())
model.add(Activation('tanh'))
model.add(MaxPooling2D(pool_size=(10, 10)))
model.add(Flatten())
model.add(Dense(512))
model.add(Activation('relu'))
model.add(Dropout(0.5))
model.add(Dense(10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer="Nadam",
metrics=['accuracy'])
print(model.summary())
datagen = ImageDataGenerator(rescale=1. / 255)
train_generator = datagen.flow_from_directory(
train_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical')
val_generator = datagen.flow_from_directory(
val_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical')
test_generator = datagen.flow_from_directory(
test_dir,
target_size=(img_width, img_height),
batch_size=batch_size,
class_mode='categorical')
model.fit_generator(
train_generator,
steps_per_epoch=nb_train_samples // batch_size,
epochs=epochs,
validation_data=val_generator,
validation_steps=nb_validation_samples // batch_size)
print('Сохраняем сеть')
model.save("grib.h5")
print("Сохранение завершено!")
模型载入
from tensorflow.python.keras.preprocessing.image import ImageDataGenerator
from tensorflow.python.keras.models import Sequential
from tensorflow.python.keras.layers import Conv2D, MaxPooling2D, BatchNormalization
from tensorflow.python.keras.layers import Activation, Dropout, Flatten, Dense
from keras.models import load_model
print("Загрузка сети")
model = load_model("grib.h5")
print("Загрузка завершена!")
报错
/usr/bin/python3.5 /home/disk2/py/neroset/do.py
/home/mama/.local/lib/python3.5/site-packages/h5py/__init__.py:36: FutureWarning: Conversion of the second argument of issubdtype from `float` to `np.floating` is deprecated. In future, it will be treated as `np.float64 == np.dtype(float).type`.
from ._conv import register_converters as _register_converters
Using TensorFlow backend.
Загрузка сети
Traceback (most recent call last):
File "/home/disk2/py/neroset/do.py", line 13, in <module>
model = load_model("grib.h5")
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 243, in load_model
model = model_from_config(model_config, custom_objects=custom_objects)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 317, in model_from_config
return layer_module.deserialize(config, custom_objects=custom_objects)
File "/usr/local/lib/python3.5/dist-packages/keras/layers/__init__.py", line 55, in deserialize
printable_module_name='layer')
File "/usr/local/lib/python3.5/dist-packages/keras/utils/generic_utils.py", line 144, in deserialize_keras_object
list(custom_objects.items())))
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 1350, in from_config
model.add(layer)
File "/usr/local/lib/python3.5/dist-packages/keras/models.py", line 492, in add
output_tensor = layer(self.outputs[0])
File "/usr/local/lib/python3.5/dist-packages/keras/engine/topology.py", line 590, in __call__
self.build(input_shapes[0])
File "/usr/local/lib/python3.5/dist-packages/keras/layers/normalization.py", line 92, in build
dim = input_shape[self.axis]
TypeError: tuple indices must be integers or slices, not list
Process finished with exit code 1
战斗种族解释
убераю BatchNormalization всё работает хорошо. Не подскажите в чём ошибка?Выяснил что сохранение keras и нормализация tensorflow не работают вместе нужно просто изменить строку импорта.(译文:整理BatchNormalization一切正常。 不要告诉我错误是什么?我发现保存keras和规范化tensorflow不能一起工作;只需更改导入字符串即可。)
强调文本 强调文本
keras.preprocessing.image import ImageDataGenerator
keras.models import Sequential
keras.layers import Conv2D, MaxPooling2D, BatchNormalization
keras.layers import Activation, Dropout, Flatten, Dense
##完美解决
##附上原文链接
https://qa-help.ru/questions/keras-batchnormalization
来源:https://blog.csdn.net/CZ505632696/article/details/79515782
0
投稿
猜你喜欢
- Python字符编码目前计算机内存的字符编码都是Unicode,目前国内的windows操作系统采用的是gbk。python2默认的字符编码
- 本文实例讲述了JavaScript实现的伸展收缩型菜单代码。分享给大家供大家参考。具体如下:运行效果截图如下:具体代码如下:<html
- CKeditor是目前最优秀的可见即可得网页编辑器之一,它采用JavaScript编写。具备功能强大、配置容易、跨浏览器、支持多种编程语言、
- 从2004年开始,我开始进入雅虎的异常表现小组。我们是一个很小的队伍,专门针对雅虎的产品进行质量检测和改进,我作为一个后端工程师,现在却开始
- 好久没有弄JS了,因为我烦里边的大小写。其实和vbs差不多的,不过我看vbs毕竟应用面不广了,呵呵。var w=WScript.create
- 1. 主要内容从本节开始介绍windows开发实现记事本程序的逻辑实现部分。本节的主要内容有以下3点:1. 主窗口定义 —— 主要介绍记事本
- 一.JavaScript基本介绍js诞生于1995年,是Javascript的缩写,其与java语言没有关系,当时的主要目的是验证表单的数据
- 项目意义如果你想在支付宝蚂蚁森林收集很多能量种树,为环境绿化出一份力量,又或者是想每天称霸微信运动排行榜装逼,却不想出门走路,那么该pyth
- 目前两个客户端扩展库连接超时可以设置选项来操作,比如mysqli: <?php //创建对象 $mysqli = mysqli_ini
- Cookie 对象是一种以文件(Cookie文件)的形式保存在客户端硬盘的Cookies文件夹中的数据信息(Cookie数据)。Cookie
- 八卦是种优良品质,特别是用在技术上时。来看几个Reset CSS的八卦问题吧:你知道世界上第一份reset.css在哪么?*&nb
- 前言大家好,说起动态条形图,之前推荐过两个 Python 库,比如Bar Chart Race、Pandas_Alive,都可以实现。今天就
- golang时间格式化科普 CST 含义CST: 中部标准时间 (Central Standard Time) 同时表示下面4个时区CST
- 看到有人在有汉字的字符串 前加一个 ‘ 或是任意半角符号,让bug将其除掉,不过这样做太麻烦了。最后呢,找来一个模拟fgetcsv功能的函数
- 你的主页或者你管理的网站有各种密码需要保护,把密码直接放在数据库或者文件中存在不少安全隐患,所以密码加密后存储是最常见的做法。在ASP.NE
- 事先说明哦,这不是一篇关于Python异常的全面介绍的文章,这只是在学习Python异常后的一篇笔记式的记录和小结性质的文章。什么?你还不知
- 基于python3基础课程,编写名片管理系统训练,有利于熟悉python基础代码的使用。cards_main.py#! /usr/bin/p
- 0、前言在python2.7及以上的版本,str.format()的方式为格式化提供了非常大的便利。与之前的%型格式化字符串相比,他显得更为
- 翻译说明:这是Solid State Group网站上的一篇很友好的文章,解决了我在设计中遇到的很多问题,故在此我翻译其文,并对原作者表示非
- 在1943年,沃伦麦卡洛可与沃尔特皮茨提出了第一个脑神经元的抽象模型,简称麦卡洛可-皮茨神经元(McCullock-Pitts neuron