Tensorflow之MNIST CNN实现并保存、加载模型
作者:uflswe 发布时间:2023-10-16 06:21:33
标签:Tensorflow,MNIST,CNN,模型
本文实例为大家分享了Tensorflow之MNIST CNN实现并保存、加载模型的具体代码,供大家参考,具体内容如下
废话不说,直接上代码
# TensorFlow and tf.keras
import tensorflow as tf
from tensorflow import keras
# Helper libraries
import numpy as np
import matplotlib.pyplot as plt
import os
#download the data
mnist = keras.datasets.mnist
(train_images, train_labels), (test_images, test_labels) = mnist.load_data()
class_names = ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9']
train_images = train_images / 255.0
test_images = test_images / 255.0
def create_model():
# It's necessary to give the input_shape,or it will fail when you load the model
# The error will be like : You are trying to load the 4 layer models to the 0 layer
model = keras.Sequential([
keras.layers.Conv2D(32,[5,5], activation=tf.nn.relu,input_shape = (28,28,1)),
keras.layers.MaxPool2D(),
keras.layers.Conv2D(64,[7,7], activation=tf.nn.relu),
keras.layers.MaxPool2D(),
keras.layers.Flatten(),
keras.layers.Dense(576, activation=tf.nn.relu),
keras.layers.Dense(10, activation=tf.nn.softmax)
])
model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
return model
#reshape the shape before using it, for that the input of cnn is 4 dimensions
train_images = np.reshape(train_images,[-1,28,28,1])
test_images = np.reshape(test_images,[-1,28,28,1])
#train
model = create_model()
model.fit(train_images, train_labels, epochs=4)
#save the model
model.save('my_model.h5')
#Evaluate
test_loss, test_acc = model.evaluate(test_images, test_labels,verbose = 0)
print('Test accuracy:', test_acc)
模型保存后,自己手写了几张图片,放在文件夹C:\pythonp\testdir2下,开始测试
#Load the model
new_model = keras.models.load_model('my_model.h5')
new_model.compile(optimizer=tf.train.AdamOptimizer(),
loss='sparse_categorical_crossentropy',
metrics=['accuracy'])
new_model.summary()
#Evaluate
# test_loss, test_acc = new_model.evaluate(test_images, test_labels)
# print('Test accuracy:', test_acc)
#Predicte
mypath = 'C:\\pythonp\\testdir2'
def getimg(mypath):
listdir = os.listdir(mypath)
imgs = []
for p in listdir:
img = plt.imread(mypath+'\\'+p)
# I save the picture that I draw myself under Windows, but the saved picture's
# encode style is just opposite with the experiment data, so I transfer it with
# this line.
img = np.abs(img/255-1)
imgs.append(img[:,:,0])
return np.array(imgs),len(imgs)
imgs = getimg(mypath)
test_images = np.reshape(imgs[0],[-1,28,28,1])
predictions = new_model.predict(test_images)
plt.figure()
for i in range(imgs[1]):
c = np.argmax(predictions[i])
plt.subplot(3,3,i+1)
plt.xticks([])
plt.yticks([])
plt.imshow(test_images[i,:,:,0])
plt.title(class_names[c])
plt.show()
测试结果
自己手写的图片截的时候要注意,空白部分尽量不要太大,否则测试结果就呵呵了


猜你喜欢
- 如果机房马上要关门了,或者你急着要和MM约会,请直接跳到第四个自然段。以下叙述的脚本包括服务器端脚本和客户端的脚本,服务器端脚本指在服务器上
- 1.视图的概述 视图其实就是一条查询sql语句,用于显示一个或多个表或其他视图中的相关数据。视图将一个查询的结果作为一个表来使用,因此视图可
- 前言本篇文章分享一下我在实际开发小程序时遇到的需要获取用户当前位置的问题,在小程序开发过程中经常使用到的获取定位功能。uniapp官方也提供
- php输出全部gb2312编码内的汉字,$area表示分区,$pos表示分区内所在位置。<?php$fp = fopen('t
- 有了Selenium,还可以轻松操作Cookies,比如获取、添加、删除Cookies。具体代码如下:from selenium impor
- 众所周知,随着数据库体积的日益庞大,其备份文件的大小也水涨船高。虽然说通过差异备份与完全备份配套策略,可以大大的减小SQL Server数据
- 比如说在1-3000之内生成随机永不重复数,点击运行代码的时候请注意,此代码比较占用资源,如果硬件配置比较菜请把count改小。俺的电脑配置
- 一、config.ini 配置文件[DATABASE]host = 192.1.1.1username = rootpassword = r
- Python 调用JS文件中的函数方法如下1、安装PyExecJS第三方库2、导入库:import execjs3、调用JS文件中的方法Pa
- 1、 string的定义Golang中的string的定义在reflect包下的value.go中,定义如下:StringHeader 是字
- 想要根据django中的模型和配置生成SQL语句,需要先进行一定的设置:首先需要在你的app文件夹中进入setting.py文件,里面有一个
- 作为互联网产品设计师,在和前端开发人员沟通时你是否常常会听到这样的声音: —— “大姐,给点专业精神好不好,这个表格是自适应的,你
- MAC 中mysql密码忘记解决办法最近项目用到MySQL,之前装过一个,可是忘记了当时设置的密码,然后走上了修改密码的坎坷道路。在百度,G
- 在Web 开发中,JavaScript的一个很重要的作用就是对DOM进行操作,可你知道么?对DOM的操作是非常昂贵的,因为这会导致浏览器执行
- 这篇文章主要介绍了Python如何计算语句执行时间,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的朋友可
- 先上代码:import tensorflow as tfx = tf.ones(shape=[100, 200], dtype=tf.int
- 上文:成为一个顶级设计师的第二准则英文原文成为一个顶级设计师的第三准则:对比,对比,对比在设计里面,好的对比和你对颜色选择是密切相关的。对比
- 一、包说明分析context包:这个包分析的是1.15context包定义了一个Context类型(接口类型),通过这个Context接口类
- asp之家注:为什么要防止访客频繁刷新页面呢?也许你会说他想刷新就让他刷新吧,没什么关系,而且还增加了网页的PV,呵呵。但是有的页面我们可能
- Python数据类型分为值类型和引用类型, 下面我们看下它们的区别:值类型:对象本身不允许修改,数值的修改实际上是让变量指向了一个新的对象包