keras自动编码器实现系列之卷积自动编码器操作
作者:xuyang2233 发布时间:2023-12-31 18:33:15
标签:keras,卷积,编码器
图片的自动编码很容易就想到用卷积神经网络做为编码-解码器。在实际的操作中,
也经常使用卷积自动编码器去解决图像编码问题,而且非常有效。
下面通过**keras**完成简单的卷积自动编码。 编码器有堆叠的卷积层和池化层(max pooling用于空间降采样)组成。 对应的解码器由卷积层和上采样层组成。
@requires_authorization
# -*- coding:utf-8 -*-
from keras.layers import Input, Dense, Conv2D, MaxPooling2D, UpSampling2D
from keras.models import Model
from keras import backend as K
import os
## 网络结构 ##
input_img = Input(shape=(28,28,1)) # Tensorflow后端, 注意要用channel_last
# 编码器部分
x = Conv2D(16, (3,3), activation='relu', padding='same')(input_img)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8,(3,3), activation='relu', padding='same')(x)
x = MaxPooling2D((2,2), padding='same')(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x)
encoded = MaxPooling2D((2,2), padding='same')(x)
# 解码器部分
x = Conv2D(8, (3,3), activation='relu', padding='same')(encoded)
x = UpSampling2D((2, 2))(x)
x = Conv2D(8, (3,3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
x = Conv2D(16, (3, 3), activation='relu', padding='same')(x)
x = UpSampling2D((2, 2))(x)
decoded = Conv2D(1, (3, 3), activation='sigmoid', padding='same')(x)
autoencoder = Model(input_img, decoded)
autoencoder.compile(optimizer='adam', loss='binary_crossentropy')
# 得到编码层的输出
encoder_model = Model(inputs=autoencoder.input, outputs=autoencoder.get_layer('encoder_out').output)
## 导入数据, 使用常用的手写识别数据集
def load_mnist(dataset_name):
'''
load the data
'''
data_dir = os.path.join("./data", dataset_name)
f = np.load(os.path.join(data_dir, 'mnist.npz'))
train_data = f['train'].T
trX = train_data.reshape((-1, 28, 28, 1)).astype(np.float32)
trY = f['train_labels'][-1].astype(np.float32)
test_data = f['test'].T
teX = test_data.reshape((-1, 28, 28, 1)).astype(np.float32)
teY = f['test_labels'][-1].astype(np.float32)
# one-hot
# y_vec = np.zeros((len(y), 10), dtype=np.float32)
# for i, label in enumerate(y):
# y_vec[i, y[i]] = 1
# keras.utils里带的有one-hot的函数, 就直接用那个了
return trX / 255., trY, teX/255., teY
# 开始导入数据
x_train, _ , x_test, _= load_mnist('mnist')
# 可视化训练结果, 我们打开终端, 使用tensorboard
# tensorboard --logdir=/tmp/autoencoder # 注意这里是打开一个终端, 在终端里运行
# 训练模型, 并且在callbacks中使用tensorBoard实例, 写入训练日志 http://0.0.0.0:6006
from keras.callbacks import TensorBoard
autoencoder.fit(x_train, x_train,
epochs=50,
batch_size=128,
shuffle=True,
validation_data=(x_test, x_test),
callbacks=[TensorBoard(log_dir='/tmp/autoencoder')])
# 重建图片
import matplotlib.pyplot as plt
decoded_imgs = autoencoder.predict(x_test)
encoded_imgs = encoder_model.predict(x_test)
n = 10
plt.figure(figsize=(20, 4))
for i in range(n):
k = i + 1
# 画原始图片
ax = plt.subplot(2, n, k)
plt.imshow(x_test[k].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
# 画重建图片
ax = plt.subplot(2, n, k + n)
plt.imshow(decoded_imgs[i].reshape(28, 28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
# 编码得到的特征
n = 10
plt.figure(figsize=(20, 8))
for i in range(n):
k = i + 1
ax = plt.subplot(1, n, k)
plt.imshow(encoded[k].reshape(4, 4 * 8).T)
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
补充知识:keras搬砖系列-单层卷积自编码器
考试成绩出来了,竟然有一门出奇的差,只是有点意外。
觉得应该不错的,竟然考差了,它估计写了个随机数吧。
头文件
from keras.layers import Input,Dense
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
导入数据
(X_train,_),(X_test,_) = mnist.load_data()
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))
这里的X_train和X_test的维度分别为(60000L,784L),(10000L,784L)
这里进行了归一化,将所有的数值除上255.
设定编码的维数与输入数据的维数
encoding_dim = 32
input_img = Input(shape=(784,))
构建模型
encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))
模型编译
autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')
模型训练
autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))
预测
encoded_imgs = encoder.predict(X_test)
decoded_imgs = deconder.predict(encoded_imgs)
数据可视化
n = 10
for i in range(n):
ax = plt.subplot(2,n,i+1)
plt.imshow(X_test[i].reshape(28,28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2,n,i+1+n)
plt.imshow(decoded_imgs[i].reshape(28,28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
完成代码
from keras.layers import Input,Dense
from keras.models import Model
from keras.datasets import mnist
import numpy as np
import matplotlib.pyplot as plt
(X_train,_),(X_test,_) = mnist.load_data()
X_train = X_train.astype('float32')/255.
X_test = X_test.astype('float32')/255.
X_train = X_train.reshape((len(X_train),-1))
X_test = X_test.reshape((len(X_test),-1))
encoding_dim = 32
input_img = Input(shape=(784,))
encoded = Dense(encoding_dim,activation='relu')(input_img)
decoded = Dense(784,activation='relu')(encoded)
autoencoder = Model(inputs = input_img,outputs=decoded)
encoder = Model(inputs=input_img,outputs=encoded)
encoded_input = Input(shape=(encoding_dim,))
decoder_layer = autoencoder.layers[-1]
deconder = Model(inputs=encoded_input,outputs = decoder_layer(encoded_input))
autoencoder.compile(optimizer='adadelta',loss='binary_crossentropy')
autoencoder.fit(X_train,X_train,epochs=50,batch_size=256,shuffle=True,validation_data=(X_test,X_test))
encoded_imgs = encoder.predict(X_test)
decoded_imgs = deconder.predict(encoded_imgs)
##via
n = 10
for i in range(n):
ax = plt.subplot(2,n,i+1)
plt.imshow(X_test[i].reshape(28,28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
ax = plt.subplot(2,n,i+1+n)
plt.imshow(decoded_imgs[i].reshape(28,28))
plt.gray()
ax.get_xaxis().set_visible(False)
ax.get_yaxis().set_visible(False)
plt.show()
来源:https://blog.csdn.net/xuyang508/article/details/74276181
0
投稿
猜你喜欢
- python中同样使用关键字class创建一个类,类名称第一个字母大写,可以带括号也可以不带括号;python中实例化类不需要使用关键字ne
- 人生苦短,我用python。看到这句话的时候,感觉可能确实是很深得人心,不过每每想学学,就又止步,年纪大了,感觉学什么东西都很慢,很难,精神
- SQL Server数据库查询速度慢的原因有很多,常见的有以下几种:1、没有索引或者没有用到索引(这是查询慢最常见的问题,是程序设计的缺陷)
- 本文实例为大家分享了wxPython实现计算器的具体代码,供大家参考,具体内容如下# -*- coding: utf-8 -*-######
- python字符串字符串是 Python 中最常用的数据类型。我们可以使用引号('或")来创建字符串。创建字符串很简单,只
- 一、安装插件要生成html类型的报告,需要使用pytest-html插件,可以在IDE中安装,也可以在命令行中安装。插件安装的位置涉及到不同
- 本文为大家分享了Eclipse开发python脚本的具体方法,供大家参考,具体内容如下一、安装python1.访问网址,可以看到如下图所示界
- 目录1. 理解 * 和 ** 2.Python函数的参数 3. 支持任意参数的函数
- 这些小东西是我在网上看到的就把它记下来了,可能以后会有用的: &nbs
- 1、由于 Python 列表的切片会在内存中创建新对象,因此需要注意的另一个重要函数是itertools.islice。2、通常需要遍历切片
- 一 异常处理1.什么是异常Error(错误)是系统中的错误,程序员是不能改变的和处理的,如系统崩溃,内存空间不足,方法调用栈溢等。遇到这样的
- PyCharm 是我用过的python编辑器中,比较顺手的一个。而且可以跨平台,在macos和windows下面都可以用,这点比较好。是py
- [前言] 我们经常会遇到多重查询问题,而长长的SQL语句往往让人丈二和尚摸不着头脑。特别是客户端部分填入
- 所使用python环境为最新的3.6版本Python中几种对文件的操作方法:将A文件复制到B文件中去(保持原来格式)读取文件中的内容,返回L
- 首先,我们需要着重介绍一些概念,以给你提供一些使这个“奇迹”得以发生的组成部分。太轻易地泄露伏笔对于讲故事来说不是个好的形式,所以那些不愿意
- Supervisor 是基于 Python 的进程管理工具,只能运行在 Unix-Like 的系统上,也就是无法运行在 Windows 上。
- PyTorch: https://github.com/shanglianlm0525/PyTorch-Networksimport tor
- DFA 算法是通过提前构造出一个 树状查找结构,之后根据输入在该树状结构中就可以进行非常高效的查找。设我们有一个敏感词库,词酷中的词汇为:我
- 前言在网上找了很多Python处理Excel的方法和代码,都不是很尽人意,所以自己综合网上各位大佬的方法,自己进行了优化,具体的代码如下。博
- 1.如何统计序列中元素出现的频率并排序?统计序列中元素出现的频率的结果肯定是一个字典,Key 为序列中的元素而 Value 为元素出现的次数