使用Keras画神经网络准确性图教程
作者:ZJE_ANDY 发布时间:2021-02-17 03:20:28
标签:Keras,神经网络,准确性图
1.在搭建网络开始时,会调用到 keras.models的Sequential()方法,返回一个model参数表示模型
2.model参数里面有个fit()方法,用于把训练集传进网络。fit()返回一个参数,该参数包含训练集和验证集的准确性acc和错误值loss,用这些数据画成图表即可。
如:
history=model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.25) #获取数据
#########画图
acc = history.history['acc'] #获取训练集准确性数据
val_acc = history.history['val_acc'] #获取验证集准确性数据
loss = history.history['loss'] #获取训练集错误值数据
val_loss = history.history['val_loss'] #获取验证集错误值数据
epochs = range(1,len(acc)+1)
plt.plot(epochs,acc,'bo',label='Trainning acc') #以epochs为横坐标,以训练集准确性为纵坐标
plt.plot(epochs,val_acc,'b',label='Vaildation acc') #以epochs为横坐标,以验证集准确性为纵坐标
plt.legend() #绘制图例,即标明图中的线段代表何种含义
plt.figure() #创建一个新的图表
plt.plot(epochs,loss,'bo',label='Trainning loss')
plt.plot(epochs,val_loss,'b',label='Vaildation loss')
plt.legend() ##绘制图例,即标明图中的线段代表何种含义
plt.show() #显示所有图表
得到效果:
完整代码:
import keras
from keras.datasets import mnist
from keras.layers import Conv2D, MaxPool2D, Dense, Flatten,Dropout
from keras.models import Sequential
import matplotlib.pyplot as plt
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, 28, 28, 1)
x_test = x_test.reshape(-1, 28, 28, 1)
x_train = x_train / 255.
x_test = x_test / 255.
y_train = keras.utils.to_categorical(y_train)
y_test = keras.utils.to_categorical(y_test)
model = Sequential()
model.add(Conv2D(20,(5,5),strides=(1,1),input_shape=(28,28,1),padding='valid',activation='relu',kernel_initializer='uniform'))
model.add(MaxPool2D(pool_size=(2,2),strides=(2,2)))
model.add(Conv2D(64,(5,5),strides=(1,1),padding='valid',activation='relu',kernel_initializer='uniform'))
model.add(MaxPool2D(pool_size=(2,2),strides=(2,2)))
model.add(Flatten())
model.add(Dense(500,activation='relu'))
model.add(Dropout(0.2))
model.add(Dense(10,activation='softmax'))
model.compile('sgd', loss='categorical_crossentropy', metrics=['accuracy']) #随机梯度下降
history=model.fit(x_train, y_train, batch_size=32, epochs=5, validation_split=0.25) #获取数据
#########画图
acc = history.history['acc'] #获取训练集准确性数据
val_acc = history.history['val_acc'] #获取验证集准确性数据
loss = history.history['loss'] #获取训练集错误值数据
val_loss = history.history['val_loss'] #获取验证集错误值数据
epochs = range(1,len(acc)+1)
plt.plot(epochs,acc,'bo',label='Trainning acc') #以epochs为横坐标,以训练集准确性为纵坐标
plt.plot(epochs,val_acc,'b',label='Vaildation acc') #以epochs为横坐标,以验证集准确性为纵坐标
plt.legend() #绘制图例,即标明图中的线段代表何种含义
plt.figure() #创建一个新的图表
plt.plot(epochs,loss,'bo',label='Trainning loss')
plt.plot(epochs,val_loss,'b',label='Vaildation loss')
plt.legend() ##绘制图例,即标明图中的线段代表何种含义
来源:https://blog.csdn.net/u014453898/article/details/89222503


猜你喜欢
- 小编最近由于工作原因要用到python,一门新的知识需要接触,对于我来说难度还是很大的。python工程目录结构每次创建一个python工程
- 完整系列教程详见:http://golang.iswbm.com在 Golang 中用于执行命令的库是 os/exec,exec.Comma
- MNIST是一个非常有名的手写体数字识别数据集,TensorFlow对MNIST数据集做了封装,可以直接调用。MNIST数据集包含了6000
- 前言:转眼距离上篇JS组件系列——又一款MVVM组件:Vue(一:30分钟搞定前端增删改查)已有好几个月了,今天打算将它捡起来,发现好久不用
- <?php $curDomain = $_SERVER['HTTP_HOST']; $strHTML = file_g
- 事件对象asyncio.Event是基于threading.Event来实现的。事件可以一个信号触发多个协程同步工作,例子如下:import
- 或许马上,或许几年之后,但是有迹象显示IE浏览器占统治地位的时代即将结束。在数据分析公司Net Applications的排名中,IE的市场
- %有哪几种含义?查找手册翻看《The Python Libary Reference》python库指南中附录index部分(P1899):
- 从照片里面获取GPS信息。可交换图像文件常被简称为EXIF(Exchangeable
- 在Python类中规定,函数的第一个参数是实例对象本身,并且约定俗成,把其名字写为self。其作用相当于java中的this,表示当前类的对
- 这篇文章主要介绍了python基于celery实现异步任务周期任务定时任务,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参
- 一、使用selenium前?1.安装seleniumpip install Selenium2.安装浏览器驱动Chrome驱动文件下载:点击
- 这篇文章主要介绍了Python如何计算语句执行时间,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可
- <input type=button value=刷新 onclick="window.location.reload()&
- 疫情肆虐,憋在家实在无聊,索性写点东西,于是就有了这个极极极极极简的音乐播放器。这个极极极简的音乐播放器类似于“阅后即焚”的软件,播放器可以
- <?php/*======================================事务处理==================
- shift:删除原数组第一项,并返回删除元素的值;如果数组为空则返回undefined var a = [1,2,3,4,5]; var b
- 遇到mysql ERROR 1045 这个问题搞了很久,自己记下来。方法是百度的,亲测有效。ERROR 1045 (28000): Acce
- 安装(fastcgi模式)的时候,常常有这样一句命令:/usr/local/webserver/php/bin/phpize一、phpize
- CAS 全称集中式认证服务(Central Authentication Service),是实现单点登录(SSO)的一中手段。CAS 的通