python神经网络学习使用Keras进行简单分类
作者:Bubbliiiing 发布时间:2023-09-18 04:37:23
标签:python,神经网络,Keras,分类
学习前言
上一篇讲了如何构建回归算法,这一次将怎么进行简单分类。
Keras中分类的重要函数
1、np_utils.to_categorical
np_utils.to_categorical用于将标签转化为形如(nb_samples, nb_classes)的二值序列。
假设num_classes = 10。
如将[1,2,3,……4]转化成:
[[0,1,0,0,0,0,0,0]
[0,0,1,0,0,0,0,0]
[0,0,0,1,0,0,0,0]
……
[0,0,0,0,1,0,0,0]]
这样的形态。
如将Y_train转化为二值序列,可以用如下方式:
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
2、Activation
Activation是激活函数,一般在每一层的输出使用。
当我们使用Sequential模型构建函数的时候,只需要在每一层Dense后面添加Activation就可以了。
Sequential函数也支持直接在参数中完成所有层的构建,使用方法如下。
model = Sequential([
Dense(32,input_dim = 784),
Activation("relu"),
Dense(10),
Activation("softmax")
]
)
其中两次Activation分别使用了relu函数和softmax函数。
3、metrics=[‘accuracy’]
在model.compile中添加metrics=[‘accuracy’]表示需要计算分类精确度,具体使用方式如下:
model.compile(
loss = 'categorical_crossentropy',
optimizer = rmsprop,
metrics=['accuracy']
)
全部代码
这是一个简单的仅含有一个隐含层的神经网络,用于完成手写体识别。在本例中,使用的优化器是RMSprop,具体可以使用的优化器可以参照Keras中文文档。
import numpy as np
from keras.models import Sequential
from keras.layers import Dense,Activation ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 获取训练集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先进行标准化
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 计算categorical_crossentropy需要对分类结果进行categorical
# 即需要将标签转化为形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 构建模型
model = Sequential([
Dense(32,input_dim = 784),
Activation("relu"),
Dense(10),
Activation("softmax")
]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 32)
print("\nTest")
cost,accuracy = model.evaluate(X_test,Y_test)
## W,b = model.layers[0].get_weights()
print("accuracy:",accuracy)
实验结果为:
Epoch 1/2
60000/60000 [==============================] - 12s 202us/step - loss: 0.3512 - acc: 0.9022
Epoch 2/2
60000/60000 [==============================] - 11s 183us/step - loss: 0.2037 - acc: 0.9419
Test
10000/10000 [==============================] - 1s 108us/step
accuracy: 0.9464
来源:https://blog.csdn.net/weixin_44791964/article/details/101170430


猜你喜欢
- 本文实例讲述了JS实现TITLE悬停长久显示效果。分享给大家供大家参考,具体如下:<!DOCTYPE html PUBLIC &quo
- 大概在2004年初的时候,我第一次买了一本很厚的书,名字或许叫《Dreamweaver MX从入门到精通》,很认真看着书并实践操作大约三分之
- window环境安装mysql5.7.21,具体内容如下1. 从MySQL官网下载免安装的压缩包mysql-5.7.21-winx64.zi
- 1. 问题虽然scrapy能够完美且快速的抓取静态页面,但是在现实中,目前绝大多数网站的页面都是动态页面,动态页面中的部分内容是
- 前情提要:作为刚入门机器视觉的小伙伴,第一节课学到机器视觉语法时觉得很难理解,很多人家的经验,我发现都千篇一律,功能函数没解析,参数不讲解,
- 阅读作者上一篇文章:段正淳的css笔记(4)css代码的简写CSS未知图片垂直居中的方法:一天大家在团队中讨论“未知图片垂直居中”的问题,突
- 一 创建mappingPUT test{ "mappings": { "
- 不是炒冷饭,我添加了很多新的功能哦演示地址: xwinhtcdemo.htmCSS: global.cssHTC: xwin.htc特点:1
- 1,定义和注册中间件在注册的中间件中使用:from django.http import HttpResponseRedirect'
- 在MySQL中,慢查询的界定时间是由MySQL内置参数变量long_query_time来指定的,其默认值为10(单位:秒),我们可以通过s
- 普通爬虫正常流程:数据来源分析发送请求获取数据解析数据保存数据环境介绍python 3.8pycharm 2021专业版【付费VIP完整版】
- 文章介绍内容:操作MySQL数据库:创建MySQL数据表;向表中插入记录;其他数据库操作。面试题:如何创建MySQL数据表?如何向MySQL
- matlab中创建类似字典的数据结构Matlab中创建struct:d = struct('a','1',&
- 如下所示:BaseException +-- SystemExit +-- KeyboardInterrupt +-- GeneratorE
- 一. 访问WEB数据库的多种方案目前在WINDOWS环境下有多种访问WEB数据库的技术,主要有:1.公共网关接口CGI(Commo
- V5.0之后,我们总结了一些得失。首先要说的是改版的动力。产品设计或产品升级的驱动力只有两个:用户需求和网站目标。之前的我们的多次改版,其驱
- 先不说直接改后缀,直接可以用网快等工具直接下载,其实这样你已经是为入侵者打开了大门。入侵者可以利用asp/asa为后缀的数据库直接得到web
- 有一组4096长度的数据,需要找到一阶导数从正到负的点,和三阶导数从负到正的点,截取了一小段。394.0 388.0 389.0 388.0
- 类的私有属性和方法Python是个开放的语言,默认情况下所有的属性和方法都是公开的 或者叫公有方法,不像C++和 Java中有明确的publ
- 1. 排名函数与PARTITION BY --所有数据 SELECT * FROM dbo.student AS a INNER JOIN