keras topN显示,自编写代码案例
作者:姚贤贤 发布时间:2021-03-19 03:15:13
标签:keras,topN显示,自编
对于使用已经训练好的模型,比如VGG,RESNET等,keras都自带了一个keras.applications.imagenet_utils.decode_predictions的方法,有很多限制:
def decode_predictions(preds, top=5):
"""Decodes the prediction of an ImageNet model.
# Arguments
preds: Numpy tensor encoding a batch of predictions.
top: Integer, how many top-guesses to return.
# Returns
A list of lists of top class prediction tuples
`(class_name, class_description, score)`.
One list of tuples per sample in batch input.
# Raises
ValueError: In case of invalid shape of the `pred` array
(must be 2D).
"""
global CLASS_INDEX
if len(preds.shape) != 2 or preds.shape[1] != 1000:
raise ValueError('`decode_predictions` expects '
'a batch of predictions '
'(i.e. a 2D array of shape (samples, 1000)). '
'Found array with shape: ' + str(preds.shape))
if CLASS_INDEX is None:
fpath = get_file('imagenet_class_index.json',
CLASS_INDEX_PATH,
cache_subdir='models',
file_hash='c2c37ea517e94d9795004a39431a14cb')
with open(fpath) as f:
CLASS_INDEX = json.load(f)
results = []
for pred in preds:
top_indices = pred.argsort()[-top:][::-1]
result = [tuple(CLASS_INDEX[str(i)]) + (pred[i],) for i in top_indices]
result.sort(key=lambda x: x[2], reverse=True)
results.append(result)
return results
把重要的东西挖出来,然后自己敲,这样就OK了,下例以MNIST数据集为例:
import keras
from keras.models import Sequential
from keras.layers import Dense
import numpy as np
import tflearn
import tflearn.datasets.mnist as mnist
def decode_predictions_custom(preds, top=5):
CLASS_CUSTOM = ["0","1","2","3","4","5","6","7","8","9"]
results = []
for pred in preds:
top_indices = pred.argsort()[-top:][::-1]
result = [tuple(CLASS_CUSTOM[i]) + (pred[i]*100,) for i in top_indices]
results.append(result)
return results
x_train, y_train, x_test, y_test = mnist.load_data(one_hot=True)
model = Sequential()
model.add(Dense(units=64, activation='relu', input_dim=784))
model.add(Dense(units=10, activation='softmax'))
model.compile(loss='categorical_crossentropy',
optimizer='sgd',
metrics=['accuracy'])
model.fit(x_train, y_train, epochs=10, batch_size=128)
# score = model.evaluate(x_test, y_test, batch_size=128)
# print(score)
preds = model.predict(x_test[0:1,:])
p = decode_predictions_custom(preds)
for (i,(label,prob)) in enumerate(p[0]):
print("{}. {}: {:.2f}%".format(i+1, label,prob))
# 1. 7: 99.43%
# 2. 9: 0.24%
# 3. 3: 0.23%
# 4. 0: 0.05%
# 5. 2: 0.03%
补充知识:keras简单的去噪自编码器代码和各种类型自编码器代码
我就废话不多说了,大家还是直接看代码吧~
start = time()
from keras.models import Sequential
from keras.layers import Dense, Dropout,Input
from keras.layers import Embedding
from keras.layers import Conv1D, GlobalAveragePooling1D, MaxPooling1D
from keras import layers
from keras.models import Model
# Parameters for denoising autoencoder
nb_visible = 120
nb_hidden = 64
batch_size = 16
# Build autoencoder model
input_img = Input(shape=(nb_visible,))
encoded = Dense(nb_hidden, activation='relu')(input_img)
decoded = Dense(nb_visible, activation='sigmoid')(encoded)
autoencoder = Model(input=input_img, output=decoded)
autoencoder.compile(loss='mean_squared_error',optimizer='adam',metrics=['mae'])
autoencoder.summary()
# Train
### 加一个early_stooping
import keras
early_stopping = keras.callbacks.EarlyStopping(
monitor='val_loss',
min_delta=0.0001,
patience=5,
verbose=0,
mode='auto'
)
autoencoder.fit(X_train_np, y_train_np, nb_epoch=50, batch_size=batch_size , shuffle=True,
callbacks = [early_stopping],verbose = 1,validation_data=(X_test_np, y_test_np))
# Evaluate
evaluation = autoencoder.evaluate(X_test_np, y_test_np, batch_size=batch_size , verbose=1)
print('val_loss: %.6f, val_mean_absolute_error: %.6f' % (evaluation[0], evaluation[1]))
end = time()
print('耗时:'+str((end-start)/60))
keras各种自编码代码
来源:https://blog.csdn.net/u011311291/article/details/79991716
0
投稿
猜你喜欢
- 爬取网站为:http://xiaohua.zol.com.cn/youmo/查看网页机构,爬取笑话内容时存在如下问题:1、每页需要进入“查看
- 需 求 分 析 1、读取指定目录下的所有文件2、读取指定文件,输出文件内容3、创建一个文件并保存到指定目录实 现 过 程Python写代码简
- 原则一:注意WHERE子句中的连接顺序: ORACLE采用自下而上的顺序解析WHERE子句,根据这个原理,表之间的连接必须写在其他WHERE
- 很久之前曾经总结过一篇博客“MySQL如何找出未提交事务信息”,现在看来,这篇文章中不少知识点或观点都略显肤浅,或者说不够深入,甚至部分结论
- Python学习第一篇。把之前学习的Python基础知识总结一下。一、认识Python首先我们得清楚这个:Python这个名字是从Monty
- 如下所示:# coding=gbkfrom PIL import Imageimport numpy as np# import scipy
- MongoDB简介MongoDB 是由C++语言编写的,是一个基于分布式文件存储的开源数据库系统。在高负载的情况下,添加更多的节点,可以保证
- 网上找了半天 不是dataframe转化成array的就是array转化dataframe,所以这里给汇总一下,相互转换的python代如下
- 一.values()1.values()结果是什么?官方文档说明:https://docs.djangoproject.com/en/2.1
- django在使用外键ForeignKey的时候,会自动给当前字段后面添加一个后缀_id。正常来说这样并不会影响使用。除非你要写原生sql,
- MySQL UNION 操作符本教程为大家介绍 MySQL UNION 操作符的语法和实例。描述MySQL UNION 操作符用于连接两个以
- Doug Bowman,Google的Visual Design Lead离职了,一封带有感 * 彩的离职信惹发了大家不少的讨论。甚至还有人用
- Go反射的实现和 interface 和 unsafe.Pointer 密切相关。如果对golang的 interface 底层实现还没有理
- 有时引用其它js时,其js却使用了window.onload事件,这样的话,引入的页面的onload事件就有可能执行不了,怎样才能两个都运行
- 本文实例讲述了Python实现利用最大公约数求三个正整数的最小公倍数。分享给大家供大家参考,具体如下:在求解两个数的小公倍数的方法时,假设两
- CacheControl 属性设置是否可缓存由 ASP 生成的输出。默认地,代理服务器不会保持缓存副本。语法:response.CacheC
- 最近切换到了Ubuntu的系统作为工作环境, 在使用Pycharm的时候, 出现了个奇怪的问题中文是无法正常输入的, 然后找遍了网上的解决办
- 只能是一些限定的东西运行代码框ENTER键可以让光标移到下一个输入框 <input onkeydown="if(event.
- 本文实例讲述了php中加密解密DES类的简单使用方法。分享给大家供大家参考,具体如下:在平时的开发工作中,我们经常会对关键字符进行加密,可能
- Python 使用 selenium 进行自动化测试 或者协助日常工作,内容如下所示:1、基础准备需要准备 Python 环境需要安装 se