tensorflow2.0保存和恢复模型3种方法
作者:李宜君 发布时间:2023-03-07 01:06:03
方法1:只保存模型的权重和偏置
这种方法不会保存整个网络的结构,只是保存模型的权重和偏置,所以在后期恢复模型之前,必须手动创建和之前模型一模一样的模型,以保证权重和偏置的维度和保存之前的相同。
tf.keras.model类中的save_weights方法和load_weights方法,参数解释我就直接搬运官网的内容了。
save_weights(
filepath,
overwrite=True,
save_format=None
)
Arguments:
filepath: String, path to the file to save the weights to. When saving in TensorFlow format, this is the prefix used for checkpoint files (multiple files are generated). Note that the '.h5' suffix causes weights to be saved in HDF5 format.
overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
save_format: Either 'tf' or 'h5'. A filepath ending in '.h5' or '.keras' will default to HDF5 if save_format is None. Otherwise None defaults to 'tf'.
load_weights(
filepath,
by_name=False
)
实例1:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# step2 创建模型
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
y=y_train,
epochs=1,
)
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
# step6 保存模型的权重和偏置
model.save_weights('./save_weights/my_save_weights')
# step7 删除模型
del model
# step8 重新创建模型
model = create_model()
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# step9 恢复权重
model.load_weights('./save_weights/my_save_weights')
# step10 测试模型
loss, acc = model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))
train model, accuracy:96.55%
Restored model, accuracy:96.55%
可以看到在模型的权重和偏置恢复之后,在测试集合上同样达到了训练之前相同的准确率。
方法2:直接保存整个模型
这种方法会将网络的结构,权重和优化器的状态等参数全部保存下来,后期恢复的时候就没必要创建新的网络了。
tf.keras.model类中的save方法和load_model方法
save(
filepath,
overwrite=True,
include_optimizer=True,
save_format=None
)
Arguments:
filepath: String, path to SavedModel or H5 file to save the model.
overwrite: Whether to silently overwrite any existing file at the target location, or provide the user with a manual prompt.
include_optimizer: If True, save optimizer's state together.
save_format: Either 'tf' or 'h5', indicating whether to save the model to Tensorflow SavedModel or HDF5. The default is currently 'h5', but will switch to 'tf' in TensorFlow 2.0. The 'tf' option is currently disabled (use tf.keras.experimental.export_saved_model instead).
实例2:
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers
# step1 加载训练集和测试集合
mnist = tf.keras.datasets.mnist
(x_train, y_train),(x_test, y_test) = mnist.load_data()
x_train, x_test = x_train / 255.0, x_test / 255.0
# step2 创建模型
def create_model():
return tf.keras.models.Sequential([
tf.keras.layers.Flatten(input_shape=(28, 28)),
tf.keras.layers.Dense(512, activation='relu'),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation='softmax')
])
model = create_model()
# step3 编译模型 主要是确定优化方法,损失函数等
model.compile(optimizer='adam',
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
# step4 模型训练 训练一个epochs
model.fit(x=x_train,
y=y_train,
epochs=1,
)
# step5 模型测试
loss, acc = model.evaluate(x_test, y_test)
print("train model, accuracy:{:5.2f}%".format(100 * acc))
# step6 保存模型的权重和偏置
model.save('my_model.h5') # creates a HDF5 file 'my_model.h5'
# step7 删除模型
del model # deletes the existing model
# step8 恢复模型
# returns a compiled model
# identical to the previous one
restored_model = tf.keras.models.load_model('my_model.h5')
# step9 测试模型
loss, acc = restored_model.evaluate(x_test, y_test)
print("Restored model, accuracy:{:5.2f}%".format(100 * acc))
train model, accuracy:96.94%
Restored model, accuracy:96.94%
方法3:使用tf.keras.callbacks.ModelCheckpoint方法在训练过程中保存模型
该方法继承自tf.keras.callbacks类,一般配合mode.fit函数使用
来源:https://blog.csdn.net/wwwlyj123321/article/details/94291992


猜你喜欢
- 本文实例讲述了jQuery实现简单复制json对象和json对象集合操作。分享给大家供大家参考,具体如下:<!DOCTYPE html
- 我就废话不多说了,直接上代码吧!import pandas as pdimport numpy as npimport matplotlib
- 关于php的引用(就是在变量或者函数、对象等前面加上&符号)的作用,我们先看下面这个程序。<?php
- 本文实例为大家分享了pytorch实现手写数字图片识别的具体代码,供大家参考,具体内容如下数据集:MNIST数据集,代码中会自动下载,不用自
- 1 概述C/C++和Java(以及大多数的主流编程语言)都有自己成熟的单元测试框架,前者如Check,后者如JUnit,但这些编程框架本质上
- CSS重设就是由于各种浏览器解释CSS样式的初始值有所不同,导致设计师在没有定义某个CSS属性时,不同的浏览器会按照自己的默认值来为没有定义
- Dreamweaver中一直变色的超级链接,css+javascript实现超级链接变色,当鼠标移动到链接上时,链接的颜色不停闪烁变色。&l
- 一、去除空格strip()" xyz ".strip() &n
- 基于pytorch来讲MSELoss()多用于回归问题,也可以用于one_hotted编码形式,CrossEntropyLoss()名字为交
- 什么是下载?首先客户端会问服务器,有没有一个xxx的文件啊?服务器开始寻找,找到后对客户端说有,然后客户端在本地新建一个文件,客户端从服务器
- 带你了解CGO编程大学时最开始学的语言莫过于C/C++,C/C++经过几十年的发展,已经积累了庞大的软件资产,它们很多久经考验而且性能已经足
- 如何使用模板系统让我们深入研究模板系统,你将会明白它是如何工作的。但我们暂不打算将它与先前创建的视图结合在一起,因为我们现在的目的是了解它是
- 当设计一个产品,其中很多地方要把日期类型保存到数据库中,如果产品有兼容不同数据库产品的需求,那么,应当怎样设计呢?当然,首先想到的是,使用数
- 官方给出Vue.filters(id , [definition])//id {string}//definition {function}
- vue中代码的复用, 为我们提供了 mixnis. 模板的复用, 为我们提供了 插槽( slot )插槽的分类默认插槽具名插槽作用域插槽当我
- 最近在着手支付宝个人版改版的项目,正好在一些国内知名的SNS网站上分别注册了帐户进行体验。显然一点,国内的SNS都带有Facebook的影子
- 最近JETBRAINS发布了目前最受欢迎的python-web开发框架,可以看到最受欢迎的还是Django和Flask,那么本文就对上榜的1
- 这篇文章主要介绍了python基于event实现线程间通信控制,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,
- 前言Python是面向对象的程序设计(Object Oriented Programming)。面向对象的程序设计的一条基本原则是:计算机程
- 这10个asp处理网页编码转换的函数,不知何时收藏在我的电脑中,今天刚好看到了,拿出来与大家分享,这里各种常见的网页编码问题已经