keras分类模型中的输入数据与标签的维度实例
作者:xytywh 发布时间:2022-01-30 02:12:43
在《python深度学习》这本书中。
一、21页mnist十分类
导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32
二、51页IMDB二分类
导入数据:
from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。
train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。
train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。
数据预处理:
# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
# Create an all-zero matrix of shape (len(sequences), dimension)
results = np.zeros((len(sequences), dimension))
for i, sequence in enumerate(sequences):
results[i, sequence] = 1. # set specific indices of results[i] to 1s
return results
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)
第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,
最后输出的维度:1->2
最后的激活函数:sigmoid->softmax
损失函数:binary_crossentropy->categorical_crossentropy
预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。
注:
1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy
2.网络的所有输入和目标都必须是浮点数张量
补充知识:keras输入数据的方法:model.fit和model.fit_generator
1.第一种,普通的不用数据增强的
from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data()
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
verbose=1, validation_data=(X_valid, Y_valid), )
2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。
from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
width_shift_range=5./32,
height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
callbacks=callbacks,
validation_data=(testX, Y_test),
validation_steps=testX.shape[0] // batch_size, verbose=1)
来源:https://blog.csdn.net/xiaohuihui1994/article/details/83536752
猜你喜欢
- 所有文件都包含在各个不同的目录下,不过Python也能轻松处理。os模块有许多方法能帮你创建,删除和更改目录。mkdir()方法可以使用os
- 什么是性能分析?性能分析是衡量应用程序在代码级别的相对性能。性能分析将捕捉的事件包括:CPU的使用,内存的使用,函数的调用时长和次数,以及调
- uuid str int 之间的转换import uudi#str 转 uuiduuid.UUID('123456781234567
- 前言本文主要给大家介绍了关于Django自定义过滤器的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍:过滤器与函数d
- 目录项目地址:1) 启动方法2) web查看方法3) 功能说明:4) 展示:代码项目地址:https://github.com/guodon
- 正文开始if name == "main":可以看成是python程序的入口,就像java中的main()方法,但不完全
- php cookie中不能使用点号(句号),实际上不是很严格,应该说可以使用点号的cookie名,但会被转换,你命名一个cookie:$_C
- 如何使用Iframe实现本页提交?例:chunfeng.html< html>< head>&n
- 《色彩解答》系列之一 色彩层次《色彩解答》系列之二 色彩比例我们知道在设计中有很多对比,大小的对比,形状的对比,长短的对比,多少的对比,这些
- 一、桌面应用软件桌面应用软件是基于GUI(Graphical User Interface,图形用户界面)交互式程序,需要实现GUI库实现前
- 1.使用jobsName.ini文件保存要创建job的名字jobs1jobs2jobs32.使用Jenkins创建job时自动生成的conf
- 在不久前的一天,当我为了解决一个语法问题来翻阅VBscript文档时,偶然间发现在了下面的一句话: &nb
- 公网与私有网络的判断其实十分简单,只要记住私有网络的三个网段。不过,对于记性不好的人或者学识不是很高的机器来说,有一种判断方法还是有必要的。
- 如何正确理解和使用Command、Connection和 Recordset三个对象?我知道它们都是连接数据库的“好手”,但在编程的具体应用
- 直接使用Python来实现向量的相加# -*-coding:utf-8-*-#向量相加def pythonsum(n): a = range
- 一个可能你似曾相识的场景阅读内容包含大量英文的 PPT、Word、Excel 或者记事本时,由于英语不熟悉,为了流利地阅读,需要打开浏览器进
- 一、TCP1、tcp服务器创建#创建服务器from socket import *from time import ctime #导入cti
- 前言哈希 又称作 “散列”,它接收任何一组任意长度的输入信息,通过 哈希 算法变换成固定长度的数据指
- 本文实例讲述了Python打包文件夹的方法。分享给大家供大家参考,具体如下:一、zipimport os, zipfile#打包目录为zip
- 最近做的一个B/S项目,在打印时采用了在IE中嵌入.net winform控件和XML结合的方式(参见http://www.yesky.co