keras分类模型中的输入数据与标签的维度实例
作者:xytywh 发布时间:2022-01-30 02:12:43
在《python深度学习》这本书中。
一、21页mnist十分类
导入数据集
from keras.datasets import mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
初始数据维度:
>>> train_images.shape
(60000, 28, 28)
>>> len(train_labels)
60000
>>> train_labels
array([5, 0, 4, ..., 5, 6, 8], dtype=uint8)
数据预处理:
train_images = train_images.reshape((60000, 28 * 28))
train_images = train_images.astype('float32') / 255
train_labels = to_categorical(train_labels)
之后:
print(train_images, type(train_images), train_images.shape, train_images.dtype)
print(train_labels, type(train_labels), train_labels.shape, train_labels.dtype)
结果:
[[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]] <class 'numpy.ndarray'> (60000, 784) float32
[[0. 0. 0. ... 0. 0. 0.]
[1. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
...
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 0. 0.]
[0. 0. 0. ... 0. 1. 0.]] <class 'numpy.ndarray'> (60000, 10) float32
二、51页IMDB二分类
导入数据:
from keras.datasets import imdb (train_data, train_labels), (test_data, test_labels) = imdb.load_data(num_words=10000)
参数 num_words=10000 的意思是仅保留训练数据中前 10 000 个最常出现的单词。
train_data和test_data都是numpy.ndarray类型,都是一维的(共25000个元素,相当于25000个list),其中每个list代表一条评论,每个list中的每个元素的值范围在0-9999 ,代表10000个最常见单词的每个单词的索引,每个list长度不一,因为每条评论的长度不一,例如train_data中的list最短的为11,最长的为189。
train_labels和test_labels都是含25000个元素(元素的值要不0或者1,代表两类)的list。
数据预处理:
# 将整数序列编码为二进制矩阵
def vectorize_sequences(sequences, dimension=10000):
# Create an all-zero matrix of shape (len(sequences), dimension)
results = np.zeros((len(sequences), dimension))
for i, sequence in enumerate(sequences):
results[i, sequence] = 1. # set specific indices of results[i] to 1s
return results
x_train = vectorize_sequences(train_data)
x_test = vectorize_sequences(test_data)
第一种方式:shape为(25000,)
y_train = np.asarray(train_labels).astype('float32') #就用这种方式就行了
y_test = np.asarray(test_labels).astype('float32')
第二种方式:shape为(25000,1)
y_train = np.asarray(train_labels).astype('float32').reshape(25000, 1)
y_test = np.asarray(test_labels).astype('float32').reshape(25000, 1)
第三种方式:shape为(25000,2)
y_train = to_categorical(train_labels) #变成one-hot向量
y_test = to_categorical(test_labels)
第三种方式,相当于把二分类看成了多分类,所以网络的结构同时需要更改,
最后输出的维度:1->2
最后的激活函数:sigmoid->softmax
损失函数:binary_crossentropy->categorical_crossentropy
预处理之后,train_data和test_data变成了shape为(25000,10000),dtype为float32的ndarray(one-hot向量),train_labels和test_labels变成了shape为(25000,)的一维ndarray,或者(25000,1)的二维ndarray,或者shape为(25000,2)的one-hot向量。
注:
1.sigmoid对应binary_crossentropy,softmax对应categorical_crossentropy
2.网络的所有输入和目标都必须是浮点数张量
补充知识:keras输入数据的方法:model.fit和model.fit_generator
1.第一种,普通的不用数据增强的
from keras.datasets import mnist,cifar10,cifar100
(X_train, y_train), (X_valid, Y_valid) = cifar10.load_data()
model.fit(X_train, Y_train, batch_size=batch_size, nb_epoch=nb_epoch, shuffle=True,
verbose=1, validation_data=(X_valid, Y_valid), )
2.第二种,带数据增强的 ImageDataGenerator,可以旋转角度、平移等操作。
from keras.preprocessing.image import ImageDataGenerator
(trainX, trainY), (testX, testY) = cifar100.load_data()
trainX = trainX.astype('float32')
testX = testX.astype('float32')
trainX /= 255.
testX /= 255.
Y_train = np_utils.to_categorical(trainY, nb_classes)
Y_test = np_utils.to_categorical(testY, nb_classes)
generator = ImageDataGenerator(rotation_range=15,
width_shift_range=5./32,
height_shift_range=5./32)
generator.fit(trainX, seed=0)
model.fit_generator(generator.flow(trainX, Y_train, batch_size=batch_size),
steps_per_epoch=len(trainX) // batch_size, epochs=nb_epoch,
callbacks=callbacks,
validation_data=(testX, Y_test),
validation_steps=testX.shape[0] // batch_size, verbose=1)
来源:https://blog.csdn.net/xiaohuihui1994/article/details/83536752


猜你喜欢
- git checkout . #本地所有修改的。没有的提交的,都返回到原来的状态git stash #把所有没有提交的修改暂存到stash里
- 本文实例讲述了wxPython窗口的继承机制,分享给大家供大家参考。具体分析如下:示例代码如下:import wx class
- Quoted-printable 可译为“可打印字符引用编码”、“使用可打印字符的编码”,我们收邮件,查看信件原始信息,经常会看到这种类型的
- 方法一、线程池执行的循环代码为自己写的情况定义一个全局变量,默认为T,当QT界面关闭后,将该变量值改为F。线程执行的循环代码内增加一个判断方
- 1 , javascript字符集:javascript采用的是Unicode字符集编码。为什么要采用这个编码呢?原因很简单,16位的Uni
- flagflag 是Go 标准库提供的解析命令行参数的包。使用方式:flag.Type(name, defValue, usage)其中Ty
- 本文实例讲述了PHP实现的redis主从数据库状态检测功能。分享给大家供大家参考,具体如下:实例:<?php/** * 检测多个主从r
- @num=1; 把num类型转成nvarchar类型 cast(@num as nvarchar(10)) @str='123
- HTTP协议简介超文本传输协议(英文:HyperText Transfer Protocol,缩写:HTTP)是一种用于分布式、协作式和超媒
- 哥德巴赫猜想:大于8的偶数之和都可以被两个素数相加范围 8 - 10000思路:首先不要去管需要什么什么东西实现,所以我们如果知道如何去完成
- 本文实例讲述了Python smallseg分词用法。分享给大家供大家参考。具体分析如下:#encoding=utf-8 #import p
- 关于中大型开发b/s开发中的缓存(cache),我的一些看法,有不正确的或者是有笔误的地方,请指正。thanks首先,应该了解基本的,对于缓
- Vision Transformer(VIT)Vision Transformer(ViT)是一种新兴的图像分类模型,它使用了类似于自然语言
- 项目技术:webpack + vue + element + axois (vue-resource) + less-loader+ ...
- 作为一名网站开发WEB前端工程师,对自己开发的网站项目应该尽可能地对其性能进行优化,现在互联网上搜索到的网站性能优化多是翻译转载自 Yaho
- cooper谈到用户的视觉路径一般是:从上到下,从左到右。好的视觉设计路径应该是顺应这样的用户习惯,糟糕的设计会让用户无所适从,焦点到处都是
- 1.在HTML5中使用Geolocation.getCurrentPosition()方法来获取地理位置。语法:navigator.geol
- background-clip 和 background-origin 是 CSS3 中新加的 background module 属性,用
- 具体代码如下所示:import requestsimport jsonfrom pyecharts.charts import Map, G
- 最近在一个python工具中需要实现串口自动触发工作的功能,之前只在winform上面实现,今天使用python试试。这里简单记一下:首先用