Keras构建神经网络踩坑(解决model.predict预测值全为0.0的问题)
作者:qq_42972774 发布时间:2023-03-28 04:05:22
标签:Keras,model,predict,预测值
终于构建出了第一个神经网络,Keras真的很方便。
之前不知道Keras这么方便,在构建神经网络的过程中绕了很多弯路,最开始学的TensorFlow,后来才知道Keras。
TensorFlow和Keras的关系,就像c语言和python的关系,所以Keras是真的好用。
搞不清楚数据的标准化和归一化的关系,想对原始数据做归一化,却误把数据做了标准化,导致用model.predict预测出来的值全是0.0,在网上搜了好久但是没搜到答案,后来自己又把程序读了一遍,突然灵光一现好像是数据归一化出了问题,于是把数据预处理部分的标准化改成了归一化,修改过来之后才能正常预测出来值,才得到应有的数据趋势。
标准化:
(x-mean(x))/std(x) 这是使用z-score方法规范化
归一化:
(x-min(x))/(max(x)-min(x)) 这是常用的最小最大规范化方法
补充知识:keras加载已经训练好的模型文件,进行预测时却发现预测结果几乎为同一类(本人预测时几乎均为为第0类)**
原因:在进行keras训练时候,使用了keras内置的数据读取方式,但是在进行预测时候,使用了自定义的数据读取方式,本人为图片读取。
解决办法查看如下代码:
##############训练:
train_gen = ImageDataGenerator(rotation_range=10,
width_shift_range=0.2,
shear_range=0.2,
zoom_range=0.2,
fill_mode='constant',
cval=0)
train_generator = train_gen.flow_from_directory(train_path,
target_size=(224, 224),
batch_size=16,
class_mode='categorical',
save_to_dir=train_g,
save_prefix='man',
save_format='jpg')
#############预测
img = cv2.imread(img_path)
img = cv2.resize(img, (row, col))
img = np.expands(img, axis=0)
out = model.predict(img)
# 上述方法是不行的,仔细查看keras内置读取方式,可以观察到内置了load_img方式
# 因此,我们在预测时候,将读取图片的方式改为
from keras.preprocessing.image import load_img, img_to_array
img = load_img(img_path)
img = img_to_array(img, target_size=(row, col))
img = np.expands(img, axis=0)
out = model.predict(img)
注:本文意在说明 对训练数据和预测数据的读取、预处理方式上应该在某种程度上保持一致,从而避免训练结果和真实预测结果相差过大的情况。
来源:https://blog.csdn.net/qq_42972774/article/details/105101935


猜你喜欢
- 论坛经常有人会问到用CSS如何美化Select标签,其实但凡你看到很酷的都是用javascript来实现的。昨天试着做了一下,基本实现的初级
- 如下所示:# -*- coding: utf-8 -*-from __future__ import unicode_literalsfro
- 自己从工艺品设计到平面设计到网络设计,虽然设计原则不离其宗,但经验下来的心得告诉自己,设计媒介的变化带来很多媒介自身的特殊性,下面总结下网站
- think-queue是ThinkPHP官方提供的一个消息队列服务,是专门支持队列服务的扩展包。think-queue消息队列适用于大并发或
- 今天给大家分享一个简单的python脚本,使用python进行http的接口测试,脚本很简单,逻辑是:读取excel写好的测试用例,然后根据
- 本文实例讲述了Python PyInstaller库基本使用方法。分享给大家供大家参考,具体如下:概述将.py源码转换成无需源代码的可执行文
- 内存管理:概述在Python中,内存管理涉及到一个包含所有Python对象和数据结构的私有堆(heap). 这个私有堆的管理由内部的Pyth
- python实现学生信息管理系统,供大家参考,具体内容如下#!/usr/bin/env python# -*- coding:utf-8 -
- 本篇阅读的代码片段来自于30-seconds-of-python。1. count_bydef count_by(arr, fn=lambd
- 前言最近学习了Fiddler抓包工具的简单使用,通过抓包,我们可以抓取到HTTP请求,并对其进行分析。现在我准备尝试着结合Python来模拟
- 1.概述最近项目需要使用程序实现数学微积分,最初想用java实现,后来发现可用文档太少,实现比较麻烦,后来尝试使用python实现,代码量较
- --新增表字段 ALTER procedure [dbo].[sp_Web_TableFiled_Insert] ( @TableName
- 本文实例讲述了Python使用Flask框架同时上传多个文件的方法,分享给大家供大家参考。具体如下:下面的演示代码带有详细的html页面和p
- 1. 信号与槽(Signals and slots)信号与槽机制是 PyQt 的核心机制,用于对象之间的通信,也就是实现函数之间的自动调用。
- 一、DSE算法背景介绍1. DES的采用1979年,美国银行协会批准使用1980年,美国国家标准局(ANSI)赞同DES作为私人使用的标准,
- 获取每一天的统计数据做项目的时候需要统对项目日志做分析,其中有一个需求是获取某个给定的时间段内,每一天的日志数据,比如说要获取从2018-0
- 中文繁体、简体的差异,在NPL中类似英文中的大小写,但又比大小写更为复杂,比如同样为繁体字,大陆、香港和台湾又不一样。先前写过一篇中文繁简转
- OAuth是一个关于授权(authorization)的开放网络标准,在全世界得到广泛应用,目前的版本是2.0版。本文对OAuth 2.0的
- 在页面中的链接除了常规的方式以外,如果使用javascript,还有很多种方式,下面是一些使用javascript,打开链接的几种方式:1.
- 下文通过图文并茂的方式给大家介绍mssqlserver数据库导出到另外一个数据库的方法,具体详情请看下文。1.准备源数据库,找到想要导出的数