tensorflow之自定义神经网络层实例
作者:大雄没有叮当猫 发布时间:2022-05-26 07:05:27
标签:tensorflow,自定义,神经网络层
如下所示:
import tensorflow as tf
tfe = tf.contrib.eager
tf.enable_eager_execution()
大多数情况下,在为机器学习模型编写代码时,您希望在比单个操作和单个变量操作更高的抽象级别上操作。
1.关于图层的一些有用操作
许多机器学习模型可以表达为相对简单的图层的组合和堆叠,TensorFlow提供了一组许多常用图层,以及您从头开始或作为组合创建自己的应用程序特定图层的简单方法。TensorFlow在tf.keras包中包含完整的Keras API,而Keras层在构建自己的模型时非常有用。
#在tf.keras.layers包中,图层是对象。要构造一个图层,只需构造一个对象。大多数层将输出维度/通道的数量作为第一个参数。
layer=tf.keras.layers.Dense(100)
#输入维度的数量通常是不必要的,因为它可以在第一次使用图层时推断出来,但如果您想手动指定它,则可以提供它,这在某些复杂模型中很有用。
layer=tf.keras.layers.Dense(10,input_shape=(None,5))
#调用层
layer(tf.zeros([10,5]))
#图层有许多有用的方法。例如,您可以通过调用layer.variables来检查图层中的所有变量。在这种情况下,完全连接的层将具有权重和偏差的变量。
variable=layer.variables
# variable[0]
layer.kernel.numpy()
layer.bias
2.自定义图层
实现自己的层的最佳方法是扩展tf.keras.Layer类并实现:
__init__,您可以在其中执行所有与输入无关的初始化
build方法,您知道输入张量的形状,并可以进行其余的初始化
call方法,在这里进行正向传播计算
请注意,您不必等到调用build来创建变量,您也可以在__init__中创建它们。但是,在build中创建它们的优点是它可以根据图层将要操作的输入的形状启用后期变量创建。另一方面,在__init__中创建变量意味着需要明确指定创建变量所需的形状。
class MyDenseLayer(tf.keras.layers.Layer):
def __init__(self, num_outputs):
super(MyDenseLayer, self).__init__()
self.num_outputs = num_outputs
def build(self, input_shape):
self.kernel = self.add_variable("kernel",
shape=[input_shape[-1].value,
self.num_outputs])
def call(self, input):
return tf.matmul(input, self.kernel)
layer = MyDenseLayer(10)
print(layer(tf.zeros([10, 5])))
print(layer.variables)
3.搭建网络结构
机器学习模型中许多有趣的图层是通过组合现有层来实现的。例如,resnet中的每个residual块是卷积,批量标准化等的组合。
创建包含其他图层的类似图层的东西时使用的主类是tf.keras.Model。实现一个是通过继承自tf.keras.Model完成的。
class ResnetIdentityBlock(tf.keras.Model):
def __init__(self, kernel_size, filters):
super(ResnetIdentityBlock, self).__init__(name='')
filters1, filters2, filters3 = filters
self.conv2a = tf.keras.layers.Conv2D(filters1, (1, 1))
self.bn2a = tf.keras.layers.BatchNormalization()
self.conv2b = tf.keras.layers.Conv2D(filters2, kernel_size, padding='same')
self.bn2b = tf.keras.layers.BatchNormalization()
self.conv2c = tf.keras.layers.Conv2D(filters3, (1, 1))
self.bn2c = tf.keras.layers.BatchNormalization()
def call(self, input_tensor, training=False):
x = self.conv2a(input_tensor)
x = self.bn2a(x, training=training)
x = tf.nn.relu(x)
x = self.conv2b(x)
x = self.bn2b(x, training=training)
x = tf.nn.relu(x)
x = self.conv2c(x)
x = self.bn2c(x, training=training)
x += input_tensor
return tf.nn.relu(x)
block = ResnetIdentityBlock(1, [1, 2, 3])
print(block(tf.zeros([1, 2, 3, 3])))
print([x.name for x in block.variables])
来源:https://blog.csdn.net/u013230189/article/details/81742306


猜你喜欢
- 转发和重定向:转发:一次请求和响应,请求的地址没有发生变化,如果此时刷新页面,就会出现重做现象。重定向:一次以上的请求和响应,请求地址发生一
- 用于匹配的正则表达式为 :([1-9]\d*\.?\d*)|(0\.\d*[1-9])([1-9] :匹配1~9的数字;\d :匹配数字,包
- 前言在前几天的文章中我们讲解了如何从Word表格中提取指定数据并按照格式保存到Excel中,今天我们将再次以一位读者提出的真实需求来讲解如何
- 一、pyc文件我们开发一个python脚本,文件的后缀为.py。如果运行这个py文件,Python内部会先将源码文件(.py文件)编译成字节
- 目录del:根据索引值删除元素pop():根据索引值删除元素remove():根据元素值进行删除clear():删除列表所有元素在 Pyth
- 函数作用:该函数的作用即按字面意思理解,topk:取数组的前k个元素进行排序。通常该函数返回2个值,第一个值为排序的数组,第二个值为该数组中
- 本文实例讲述python调用Moxa PCOMM Lite通过串口Ymodem协议实现发送文件的方法,该程序采用python 2.7编写。主
- 一、使用logging.config.dictConfig()函数读取配置信息,参数是字典类型with open(file="./
- 无聊的人在无聊的时间做无聊的事打发自己,结果在无聊的事情中发现了IE对内联文字解释的一些疑惑。以下问题在FF2中没发现,而我也只
- <%'该函数作用:按指定参数格式化显示时间。'numformat=1:将时间转化为yyyy-mm-dd h
- 感谢人类方方面面的创新,今天Web开发已经不需要在如何设计网站上面浪费时间了。框架和库帮助web开发者得以专注于真正的开发工作上。下面的这些
- CREATE FUNCTION f_Convert( @str NV
- 这个函数用于储存图片,将数组保存为图像此功能仅在安装了Python Imaging Library(PIL)时可用。版本也比较老了,新的替代
- asp连接sql 第一种写法: 代码如下: MM_conn_STRING = "Driver={SQL Server};serv
- 概述从今天开始, 小白我将带领大家一起来补充一下 数据库的知识.自连接自连接 (Self Join) 是一种特殊的表连接. 自连接指相互连接
- PDOStatement::executePDOStatement::execute — 执行一条预处理语句(PHP 5 >= 5.1
- 前言点击视频讲解更加详细this.$route:当前激活的路由的信息对象。每个对象都是局部的,可以获取当前路由的 path, na
- 前言CRUD代表: 增加(create) ,查询(retrieve) ,更新(update) ,删除(delete) 单词首字母。一、新增数
- 我就废话不多说了,直接上代码!from enum import Enumclass Values(): values={'
- 一、泛型程序设计是一种编程风格或编程范式二、案例:传入的参数类型与返回的类型一样function identify<T>(arg