keras处理欠拟合和过拟合的实例讲解
作者:Lzj000lzj 发布时间:2022-06-23 05:14:38
标签:keras,欠拟合,过拟合
baseline
import tensorflow.keras.layers as layers
baseline_model = keras.Sequential(
[
layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(16, activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
baseline_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
baseline_model.summary()
baseline_history = baseline_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
小模型
small_model = keras.Sequential(
[
layers.Dense(4, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(4, activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
small_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
small_model.summary()
small_history = small_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
大模型
big_model = keras.Sequential(
[
layers.Dense(512, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(512, activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
big_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
big_model.summary()
big_history = big_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
绘图比较上述三个模型
def plot_history(histories, key='binary_crossentropy'):
plt.figure(figsize=(16,10))
for name, history in histories:
val = plt.plot(history.epoch, history.history['val_'+key],
'--', label=name.title()+' Val')
plt.plot(history.epoch, history.history[key], color=val[0].get_color(),
label=name.title()+' Train')
plt.xlabel('Epochs')
plt.ylabel(key.replace('_',' ').title())
plt.legend()
plt.xlim([0,max(history.epoch)])
plot_history([('baseline', baseline_history),
('small', small_history),
('big', big_history)])
三个模型在迭代过程中在训练集的表现都会越来越好,并且都会出现过拟合的现象
大模型在训练集上表现更好,过拟合的速度更快
l2正则减少过拟合
l2_model = keras.Sequential(
[
layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
activation='relu', input_shape=(NUM_WORDS,)),
layers.Dense(16, kernel_regularizer=keras.regularizers.l2(0.001),
activation='relu'),
layers.Dense(1, activation='sigmoid')
]
)
l2_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
l2_model.summary()
l2_history = l2_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
plot_history([('baseline', baseline_history),
('l2', l2_history)])
可以发现正则化之后的模型在验证集上的过拟合程度减少
添加dropout减少过拟合
dpt_model = keras.Sequential(
[
layers.Dense(16, activation='relu', input_shape=(NUM_WORDS,)),
layers.Dropout(0.5),
layers.Dense(16, activation='relu'),
layers.Dropout(0.5),
layers.Dense(1, activation='sigmoid')
]
)
dpt_model.compile(optimizer='adam',
loss='binary_crossentropy',
metrics=['accuracy', 'binary_crossentropy'])
dpt_model.summary()
dpt_history = dpt_model.fit(train_data, train_labels,
epochs=20, batch_size=512,
validation_data=(test_data, test_labels),
verbose=2)
plot_history([('baseline', baseline_history),
('dropout', dpt_history)])
批正则化
model = keras.Sequential([
layers.Dense(64, activation='relu', input_shape=(784,)),
layers.BatchNormalization(),
layers.Dense(64, activation='relu'),
layers.BatchNormalization(),
layers.Dense(64, activation='relu'),
layers.BatchNormalization(),
layers.Dense(10, activation='softmax')
])
model.compile(optimizer=keras.optimizers.SGD(),
loss=keras.losses.SparseCategoricalCrossentropy(),
metrics=['accuracy'])
model.summary()
history = model.fit(x_train, y_train, batch_size=256, epochs=100, validation_split=0.3, verbose=0)
plt.plot(history.history['accuracy'])
plt.plot(history.history['val_accuracy'])
plt.legend(['training', 'validation'], loc='upper left')
plt.show()
总结
防止神经网络中过度拟合的最常用方法:
获取更多训练数据。
减少网络容量。
添加权重正规化。
添加dropout。
来源:https://blog.csdn.net/Lzj000lzj/article/details/94132842
0
投稿
猜你喜欢
- 1.视频分解图片我们使用cv2.VideoCapture来读取视频import cv2cap = cv2.VideoCapture('
- 运算符重载意味着赋予超出其预定义的操作含义的扩展含义。例如运算符 + 用于添加两个整数以及连接两个字符串和合并两个列表。这是可以实现的,因为
- 这几天转了几个内容包含日语的贴,结果发现搜索数据库时出现“内存溢出”错误。上网搜索寻求答案未果。最后才发现这就是传说中的“日文 26 个片假
- 准备工作本文用到的表格内容如下:先来看一下原始情形:import pandas as pddf = pd.read_excel(r'
- 简要pyinstaller模块主要用于python代码打包成exe程序直接使用,这样在其它电脑上即使没有python环境也是可以运行的。用法
- 方法1:自定义异常# -*- coding:utf-8 -*-"""功能:python跳出循环"&q
- 1 引言如果你想对图像进行校准,那么透视变换是非常有效的变换手段。透视变换的定义为将图像投影到一个新的视平面,通常也被称之为投影映射。2 公
- 属性在运行时的动态替换,叫做猴子补丁(Monkey Patch)。为什么叫猴子补丁属性的运行时替换和猴子也没什么关系,关于猴子补丁的由来网上
- python2.7中 集成了json的处理(simplejson),但在实际应用中,从mysql查询出来的数据,通常有日期格式,这时候,会报
- 1、汉语分词的由来使用ASP开发的中小企业的网站,对于站内搜索,往往只是简单的通过SQL语句匹配数据库。对于比较短的词语搜索,这个方法是有效
- 将Python数据类型转换为其他代码格式叫做(序列化),而json就是在各个代码实现转换的中间件。序列化要求:1. 只能有int,str,b
- 8是典型的七段数码管的例子,因为刚好七段都有经过,这里我写的代码是从1开始右转。这是看Mooc视频写的一个关于用七段数码管显示当前时间# -
- 参考Tensorflow Machine Leanrning Cookbooktf.ConfigProto()主要的作用是配置tf.Sess
- 下面示例代码是防止用网页刷新过快,如果多个页面使用,最好将<%...%>代码存为一个asp文件,在需要的页面最前面include
- 问题:在安装SP4补丁的时候,老是报验证密码错误。上网查了一下资料,发现是一个小bug。按照一下操作,安装正常。SQL Server补丁安装
- Python 中的运算符什么是运算符?举个简单的例子 4 +5 = 9 。 例子中,4 和 5 被称为操作数,"+" 称
- 本文实例讲述了Python实现的爬取豆瓣电影信息功能。分享给大家供大家参考,具体如下:本案例的任务为,爬取豆瓣电影top250的电影信息(包
- 什么是自省?在日常生活中,自省(introspection)是一种自我检查行为。在计算机编程中,自省是指这种能力:检查某些事物以确定它是什么
- 内容摘要:图片随机显示是一个应用非常广泛的技巧。比如随机banner的显示,当你进入一个网站时它的banner总是不同的,或者总有内容不同的
- scikit-learn是python的第三方机器学习库,里面集成了大量机器学习的常用方法。例如:贝叶斯,svm,knn等。scikit-l