python神经网络使用Keras进行模型的保存与读取
作者:Bubbliiiing 发布时间:2023-04-15 06:53:34
标签:python,神经网络,Keras,模型,保存读取
学习前言
开始做项目的话,有些时候会用到别人训练好的模型,这个时候要学会load噢。
Keras中保存与读取的重要函数
1、model.save
model.save用于保存模型,在保存模型前,首先要利用pip install安装h5py的模块,这个模块在Keras的模型保存与读取中常常被使用,用于定义保存格式。
pip install h5py
完成安装后,可以通过如下函数保存模型。
model.save("./model.hdf5")
其中,model是已经训练完成的模型,save函数传入的参数就是保存后的位置+名字。
2、load_model
load_model用于载入模型。
具体使用方式如下:
model = load_model("./model.hdf5")
其中,load_model函数传入的参数就是已经完成保存的模型的位置+名字。./表示保存在当前目录。
全部代码
这是一个简单的手写体识别例子,在之前也讲解过如何构建
python神经网络学习使用Keras进行简单分类,在最后我添加上了模型的保存与读取函数。
import numpy as np
from keras.models import Sequential,load_model,save_model
from keras.layers import Dense,Activation ## 全连接层
from keras.datasets import mnist
from keras.utils import np_utils
from keras.optimizers import RMSprop
# 获取训练集
(X_train,Y_train),(X_test,Y_test) = mnist.load_data()
# 首先进行标准化
X_train = X_train.reshape(X_train.shape[0],-1)/255
X_test = X_test.reshape(X_test.shape[0],-1)/255
# 计算categorical_crossentropy需要对分类结果进行categorical
# 即需要将标签转化为形如(nb_samples, nb_classes)的二值序列
Y_train = np_utils.to_categorical(Y_train,num_classes= 10)
Y_test = np_utils.to_categorical(Y_test,num_classes= 10)
# 构建模型
model = Sequential([
Dense(32,input_dim = 784),
Activation("relu"),
Dense(10),
Activation("softmax")
]
)
rmsprop = RMSprop(lr = 0.001,rho = 0.9,epsilon = 1e-08,decay = 0)
## compile
model.compile(loss = 'categorical_crossentropy',optimizer = rmsprop,metrics=['accuracy'])
print("\ntraining")
cost = model.fit(X_train,Y_train,nb_epoch = 2,batch_size = 100)
print("\nTest")
# 测试
cost,accuracy = model.evaluate(X_test,Y_test)
print("accuracy:",accuracy)
# 保存模型
model.save("./model.hdf5")
# 删除现有模型
del model
print("model had been del")
# 再次载入模型
model = load_model("./model.hdf5")
# 预测
cost,accuracy = model.evaluate(X_test,Y_test)
print("accuracy:",accuracy)
实验结果为:
Epoch 1/2
60000/60000 [==============================] - 6s 104us/step - loss: 0.4217 - acc: 0.8888
Epoch 2/2
60000/60000 [==============================] - 6s 99us/step - loss: 0.2240 - acc: 0.9366
Test
10000/10000 [==============================] - 1s 149us/step
accuracy: 0.9419
model had been del
10000/10000 [==============================] - 1s 117us/step
accuracy: 0.9419
来源:https://blog.csdn.net/weixin_44791964/article/details/101613118


猜你喜欢
- sys模块:全称system,指的是解释器。常用操作,用于接收系统操作系统调用解释器传入的参数1、 sys.argv获取脚本传递的所有参数,
- 实际上在python中用列表就可以实现动态变量名的管理,python中的列表中可以存储任何类型的元素:listA = [0,"st
- 本文为大家分享了python tkinter图形界面代码统计工具,供大家参考,具体内容如下#encoding=utf-8import os,
- django中form表单设置action后,点提交按钮是跳转到action页面的,比如设置action为login,网址为192.168.
- 关于带权随机数为了帮助理解,先来看三类随机问题的对比:1.已有n条记录,从中选取m条记录,选取出来的记录前后顺序不管。实现思路:按行遍历所有
- np.newaxisnp.newaxis 的功能是增加新的维度,但是要注意 np.newaxis 放的位置不同,产生的矩阵形状也不同。通常按
- 本文实例讲述了JS实现不规则TAB选项卡效果代码。分享给大家供大家参考。具体如下:这是一款不规则TAB选项卡效果,将中规中矩的方角换成了不规
- 目录前言环境依赖代码前言本文主要分享一个可以将图片或者视频模糊化的工具代码。技术路线主要是使用ffmpeg滤镜。环境依赖ffmpeg环境部署
- python中列表的常见操作列表元组的简单操作前面我们已经学过了关于len()函数、赋值运算符及身份运算符的使用,下面简单回顾一下这些在列表
- 一 什么是读写分离虽然知道处理大数据量时,数据库为什么要做读写分离,原因很简单:读写分离是MySQL优化的一方面,它可以提高性能,缓解数据库
- 前言:之前,我写笔记的工具一直都是 notion,而且没有写博客的习惯。但是一是由于 notion 的服务器在
- 问题描述时间在我们日常的代码编写中会是经常出现的筛选或排序条件,尤其是一些特殊时间节点的时间显得尤为突出,例如昨天,当前日期,当前月份,当前
- python的线程有一个类叫Timer可以,用来创建定时任务,但是它的问题是只能运行一次,如果要重复执行,则只能在任务中再调用一次timer
- 要用django的orm表达sql的exists子查询,是个比较麻烦的事情,需要做两部来完成from django.db.models im
- 想必每个DBA都喜欢挑战数据导入时间,用时越短工作效率越高,也充分的能够证明自己的实力。实际工作中有时候需要把大量数据导入数据库,然后用于各
- 对于二维数组,img_mask[[ 0 0 0 ..., 7 7 7] [ 0 0 0 ..., 7 7 7] [ 0 0 0 ..., 7
- 开发工具Python版本:3.6.4相关模块:cv2模块;以及一些Python自带的模块。环境搭建安装Python并添加到环境变量,pip安
- python实现两个文本合并employee文件中记录了工号和姓名cat employee.txt:100 Jason Smith200 J
- 经常使用python检测服务器是否能ping通, 程序是否正常运行(检测对应的端口是否正常)以前使用shell脚本的写法如下:PINGRET
- 使用自带的函数就可以实现:lineEdit.setEchoMode(QLineEdit.Password)import structfrom