Keras保存模型并载入模型继续训练的实现
作者:凌逆战 发布时间:2021-08-12 23:23:32
标签:Keras,保存模型,加载模型
我们以MNIST手写数字识别为例
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
# 创建模型,输入784个神经元,输出10个神经元
model = Sequential([
Dense(units=10,input_dim=784,bias_initializer='one',activation='softmax')
])
# 定义优化器
sgd = SGD(lr=0.2)
# 定义优化器,loss function,训练过程中计算准确率
model.compile(
optimizer = sgd,
loss = 'mse',
metrics=['accuracy'],
)
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=5)
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print('\ntest loss',loss)
print('accuracy',accuracy)
# 保存模型
model.save('model.h5') # HDF5文件,pip install h5py
载入初次训练的模型,再训练
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from keras.optimizers import SGD
from keras.models import load_model
# 载入数据
(x_train,y_train),(x_test,y_test) = mnist.load_data()
# (60000,28,28)
print('x_shape:',x_train.shape)
# (60000)
print('y_shape:',y_train.shape)
# (60000,28,28)->(60000,784)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
# 换one hot格式
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
# 载入模型
model = load_model('model.h5')
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print('\ntest loss',loss)
print('accuracy',accuracy)
# 训练模型
model.fit(x_train,y_train,batch_size=64,epochs=2)
# 评估模型
loss,accuracy = model.evaluate(x_test,y_test)
print('\ntest loss',loss)
print('accuracy',accuracy)
# 保存参数,载入参数
model.save_weights('my_model_weights.h5')
model.load_weights('my_model_weights.h5')
# 保存网络结构,载入网络结构
from keras.models import model_from_json
json_string = model.to_json()
model = model_from_json(json_string)
print(json_string)
关于compile和load_model()的使用顺序
这一段落主要是为了解决我们fit、evaluate、predict之前还是之后使用compile。想要弄明白,首先我们要清楚compile在程序中是做什么的?都做了什么?
compile做什么?
compile定义了loss function损失函数、optimizer优化器和metrics度量。它与权重无关,也就是说compile并不会影响权重,不会影响之前训练的问题。
如果我们要训练模型或者评估模型evaluate,则需要compile,因为训练要使用损失函数和优化器,评估要使用度量方法;如果我们要预测,则没有必要compile模型。
是否需要多次编译?
除非我们要更改其中之一:损失函数、优化器 / 学习率、度量
又或者我们加载了尚未编译的模型。或者您的加载/保存方法没有考虑以前的编译。
再次compile的后果?
如果再次编译模型,将会丢失优化器状态.
这意味着您的训练在开始时会受到一点影响,直到调整学习率,动量等为止。但是绝对不会对重量造成损害(除非您的初始学习率如此之大,以至于第一次训练步骤疯狂地更改微调的权重)。
来源:https://www.cnblogs.com/LXP-Never/p/11601404.html


猜你喜欢
- 前言 Tensorflow中可以使用tensorboard这个强大的工具对计算图、loss、网络参数等进行可视化。本文并不涉及对tensor
- 写这个的目地,主要是系统理下目前产品设计的流程,提醒自己尽量去避免一些常见的问题,也能让大家系统的了解天极网的产品设计流程。当然,不保证任何
- 写在前面和小伙伴们分享一些Python 网络编程的一些笔记,博文为《Python Cookbook》读书后笔记整理博文涉及内容包括:TCP/
- Python 三元运算符Python 三元运算符用于根据条件选择两个值之一。它是 if-else 语句的一个缩影,它将两个值之一分配给一个变
- 最近服务器升级到了win2008 r2,数据库也从sql2000升级到了sql2005,不过安装后发现sql server找不到服务器名这样
- 如下所示:screen.widthscreen.heightscreen.availHeight //获取去除状态栏后的屏幕高度screen
- 1.python实现对doc文档的读取#读取docx中的文本代码示例import docx#获取文档对象file=docx.Document
- 写在前面最近在更新我服务器上的python以及pip版本的时候,碰见了令人头痛的问题,就是我执行了升级指令之后,升级也正常的Successf
- 本文的完整代码在github.com/hdt3213/godis/redis/client通常 TCP 客户端的通信模式都是阻塞式的: 客户
- 题目描述原题链接 :66. 加一给定一个由 整数 组成的 非空 数组所表示的非负整数,在该数的基础上加一。最高位数字存放在数组的首位, 数组
- SpringSecurity? Spring Security是一个能够为基于Spring的企业应用系统提供声明式的安全访问控制解
- Python 中文编码Python 文件中如果未指定编码,在执行过程会出现报错:Python中默认的编码格式是 ASCII 格式,在没修改编
- 前言本系统是基于fabric.js实现的canvas版图片,文本编辑器,支持对图片的放大,缩小,旋转,镜面翻转,拖动,显示/隐藏图层,删除图
- 本文实例讲述了Python 字符串、列表、元组的截取与切片操作。分享给大家供大家参考,具体如下:demo.py(字符串、列表、元组的截取):
- 本文实例讲述了Python实现的计算器功能。分享给大家供大家参考,具体如下:源码:# -*- coding:utf-8 -*-#! pyth
- Oracle Tips, Tricks & Scripts1. Topic: Compiling Invalid Objects:O
- 如下所示:# -*- coding: utf-8 -*-# @Time :18-8-2 下午3:23import sysreload(sys
- 这篇文章主要介绍了python3 tcp的粘包现象和解决办法解析,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值
- 我从Stephen A. Goss那读到关于了《Python 3正在毁灭Python》。这篇文章有不少精彩的论点,但我却并不认为Python
- 前言本文主要给大家介绍了关于python图片添加半透明水印的相关资料,分享出来供大家参考学习,下面话不多说了,来一起看看详细的介绍吧示例代码