一小时学会TensorFlow2之自定义层
作者:我是小白呀 发布时间:2021-12-22 18:00:11
概述
通过自定义网络, 我们可以自己创建网络并和现有的网络串联起来, 从而实现各种各样的网络结构.
Sequential
Sequential 是 Keras 的一个网络容器. 可以帮助我们将多层网络封装在一起.
通过 Sequential 我们可以把现有的层已经我们自己的层实现结合, 一次前向传播就可以实现数据从第一层到最后一层的计算.
格式:
tf.keras.Sequential(
layers=None, name=None
)
例子:
# 5层网络模型
model = tf.keras.Sequential([
tf.keras.layers.Dense(256, activation=tf.nn.relu),
tf.keras.layers.Dense(128, activation=tf.nn.relu),
tf.keras.layers.Dense(64, activation=tf.nn.relu),
tf.keras.layers.Dense(32, activation=tf.nn.relu),
tf.keras.layers.Dense(10)
])
Model & Layer
通过 Model 和 Layer 的__init__
和call()
我们可以自定义层和模型.
Model:
class My_Model(tf.keras.Model): # 继承Model
def __init__(self):
"""
初始化
"""
super(My_Model, self).__init__()
self.fc1 = My_Dense(784, 256) # 第一层
self.fc2 = My_Dense(256, 128) # 第二层
self.fc3 = My_Dense(128, 64) # 第三层
self.fc4 = My_Dense(64, 32) # 第四层
self.fc5 = My_Dense(32, 10) # 第五层
def call(self, inputs, training=None):
"""
在Model被调用的时候执行
:param inputs: 输入
:param training: 默认为None
:return: 返回输出
"""
x = self.fc1(inputs)
x = tf.nn.relu(x)
x = self.fc2(x)
x = tf.nn.relu(x)
x = self.fc3(x)
x = tf.nn.relu(x)
x = self.fc4(x)
x = tf.nn.relu(x)
x = self.fc5(x)
return x
Layer:
class My_Dense(tf.keras.layers.Layer): # 继承Layer
def __init__(self, input_dim, output_dim):
"""
初始化
:param input_dim:
:param output_dim:
"""
super(My_Dense, self).__init__()
# 添加变量
self.kernel = self.add_variable("w", [input_dim, output_dim]) # 权重
self.bias = self.add_variable("b", [output_dim]) # 偏置
def call(self, inputs, training=None):
"""
在Layer被调用的时候执行, 计算结果
:param inputs: 输入
:param training: 默认为None
:return: 返回计算结果
"""
# y = w * x + b
out = inputs @ self.kernel + self.bias
return out
案例
数据集介绍
CIFAR-10 是由 10 类不同的物品组成的 6 万张彩 * 片的数据集. 其中 5 万张为训练集, 1 万张为测试集.
完整代码
import tensorflow as tf
def pre_process(x, y):
# 转换x
x = 2 * tf.cast(x, dtype=tf.float32) / 255 - 1 # 转换为-1~1的形式
x = tf.reshape(x, [-1, 32 * 32 * 3]) # 把x铺平
# 转换y
y = tf.convert_to_tensor(y) # 转换为0~1的形式
y = tf.one_hot(y, depth=10) # 转成one_hot编码
# 返回x, y
return x, y
def get_data():
"""
获取数据
:return:
"""
# 获取数据
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.cifar10.load_data()
# 调试输出维度
print(X_train.shape) # (50000, 32, 32, 3)
print(y_train.shape) # (50000, 1)
# squeeze
y_train = tf.squeeze(y_train) # (50000, 1) => (50000,)
y_test = tf.squeeze(y_test) # (10000, 1) => (10000,)
# 分割训练集
train_db = tf.data.Dataset.from_tensor_slices((X_train, y_train)).shuffle(10000, seed=0)
train_db = train_db.batch(batch_size).map(pre_process).repeat(iteration_num) # 迭代20次
# 分割测试集
test_db = tf.data.Dataset.from_tensor_slices((X_test, y_test)).shuffle(10000, seed=0)
test_db = test_db.batch(batch_size).map(pre_process)
return train_db, test_db
class My_Dense(tf.keras.layers.Layer): # 继承Layer
def __init__(self, input_dim, output_dim):
"""
初始化
:param input_dim:
:param output_dim:
"""
super(My_Dense, self).__init__()
# 添加变量
self.kernel = self.add_weight("w", [input_dim, output_dim]) # 权重
self.bias = self.add_weight("b", [output_dim]) # 偏置
def call(self, inputs, training=None):
"""
在Layer被调用的时候执行, 计算结果
:param inputs: 输入
:param training: 默认为None
:return: 返回计算结果
"""
# y = w * x + b
out = inputs @ self.kernel + self.bias
return out
class My_Model(tf.keras.Model): # 继承Model
def __init__(self):
"""
初始化
"""
super(My_Model, self).__init__()
self.fc1 = My_Dense(32 * 32 * 3, 256) # 第一层
self.fc2 = My_Dense(256, 128) # 第二层
self.fc3 = My_Dense(128, 64) # 第三层
self.fc4 = My_Dense(64, 32) # 第四层
self.fc5 = My_Dense(32, 10) # 第五层
def call(self, inputs, training=None):
"""
在Model被调用的时候执行
:param inputs: 输入
:param training: 默认为None
:return: 返回输出
"""
x = self.fc1(inputs)
x = tf.nn.relu(x)
x = self.fc2(x)
x = tf.nn.relu(x)
x = self.fc3(x)
x = tf.nn.relu(x)
x = self.fc4(x)
x = tf.nn.relu(x)
x = self.fc5(x)
return x
# 定义超参数
batch_size = 256 # 一次训练的样本数目
learning_rate = 0.001 # 学习率
iteration_num = 20 # 迭代次数
optimizer = tf.keras.optimizers.Adam(learning_rate=learning_rate) # 优化器
loss = tf.losses.CategoricalCrossentropy(from_logits=True) # 损失
network = My_Model() # 实例化网络
# 调试输出summary
network.build(input_shape=[None, 32 * 32 * 3])
print(network.summary())
# 组合
network.compile(optimizer=optimizer,
loss=loss,
metrics=["accuracy"])
if __name__ == "__main__":
# 获取分割的数据集
train_db, test_db = get_data()
# 拟合
network.fit(train_db, epochs=5, validation_data=test_db, validation_freq=1)
输出结果:
Model: "my__model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
my__dense (My_Dense) multiple 786688
_________________________________________________________________
my__dense_1 (My_Dense) multiple 32896
_________________________________________________________________
my__dense_2 (My_Dense) multiple 8256
_________________________________________________________________
my__dense_3 (My_Dense) multiple 2080
_________________________________________________________________
my__dense_4 (My_Dense) multiple 330
=================================================================
Total params: 830,250
Trainable params: 830,250
Non-trainable params: 0
_________________________________________________________________
None
(50000, 32, 32, 3)
(50000, 1)
2021-06-15 14:35:26.600766: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:176] None of the MLIR Optimization Passes are enabled (registered 2)
Epoch 1/5
3920/3920 [==============================] - 39s 10ms/step - loss: 0.9676 - accuracy: 0.6595 - val_loss: 1.8961 - val_accuracy: 0.5220
Epoch 2/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.3338 - accuracy: 0.8831 - val_loss: 3.3207 - val_accuracy: 0.5141
Epoch 3/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.1713 - accuracy: 0.9410 - val_loss: 4.2247 - val_accuracy: 0.5122
Epoch 4/5
3920/3920 [==============================] - 41s 10ms/step - loss: 0.1237 - accuracy: 0.9581 - val_loss: 4.9458 - val_accuracy: 0.5050
Epoch 5/5
3920/3920 [==============================] - 42s 11ms/step - loss: 0.1003 - accuracy: 0.9666 - val_loss: 5.2425 - val_accuracy: 0.5097
来源:https://blog.csdn.net/weixin_46274168/article/details/117919277
猜你喜欢
- myPhoneBook2.py#!/usr/bin/python# -*- coding: utf-8 -*-import reclass
- 本文实例讲述了django框架model orM使用字典作为参数,保存数据的方法。分享给大家供大家参考,具体如下:假设有一个字典,里面已经有
- 1.说明redis作为一个缓存数据库,在各方面都有很大作用,Python支持操作redis,如果你使用Django,有一个专为Django搭
- 参考资料:正则表达式语法–菜鸟教程Java正则表达式实现简单批量替换举例:将and 批量替换为&&Python实现impor
- 早上看了一个贴子,是一个哥们推广自己一个智能的数据库备份系统,他总结了数据库备份过程中所有可能出错的情况,可以借鉴。如果你做DBA时间不长,
- 使用Django的时候,我发现一个很神奇的装饰器: @login_required, 这是控制一个view的权限的,比如一个视图必须登录才可
- 导读:如何使用scrapy框架实现爬虫的4步曲?什么是CrawSpider模板?如何设置下载中间件?如何实现Scrapyd远程部署和监控?想
- 函数函数的英文单词是 Function,这个单词还有着功能的意思。在 Go 语言中,函数是实现某一特定功能的代码块。函数代表着某个功能,可以
- 方法一:>>> str1 = '''Le vent se lève, il faut tenter
- 前段时日微软(Microsoft)正式发布了.NET Core 2.0,在很多开发社区中反响不错。但还是有一些开发者发出了疑问,.NET C
- 一、常见模型分类1.1、循环服务器模型循环接收客户端请求,处理请求。同一时刻只能处理一个请求,处理完毕后再处理下一个。优点:实现简单,占用资
- 一、制作思路由于注册的时候常常会用到注册码来防止机器恶意注册,这里我发表一个产生png图片验证码的基本图像,简单的思路分析:1、产生一张pn
- 我们知道,做web开发,在调试时需要反复启动整个工程,那么上一个工程占用的端口,在下一次工程启动时就不能用了,因为占用的端口没有释放,但是手
- pyside2 >>> pip install pyside2 QT Designer>>
- 一. 数据的格式首先我们需要x,y,z三个数据进行画图。从本实验用到的数据集KITTI 00.txt中举例:1.000000e+00 9.0
- '-----------------------------------------------------------
- 本文实例讲述了微信小程序picker组件简单用法。分享给大家供大家参考,具体如下:picker滚动选择器,现支持三种选择器,通过mode来区
- 前言我们都知道时区,标准时区是UTC时区,django默认使用的就是UTC时区,所以我们存储在数据库中的时间是UTC的时间,但是当我们做的网
- 我们一般采用photoshop等做图工具制作电视扫描线效果图片:首先做一个黑白相间的图案,然后用这个图案进行填充,再调整图层的模式或者透明度
- 上回 说到“大屏幕浏览页面的良好体验,本就应该用户自己调整窗口。”根据屏幕不同大小,缩小窗口出横向滚动条在所难免,但理想情况下,页面应该能适