关于tensorflow的几种参数初始化方法小结
作者:liushui94 发布时间:2023-10-15 12:05:26
在tensorflow中,经常会遇到参数初始化问题,比如在训练自己的词向量时,需要对原始的embeddigs矩阵进行初始化,更一般的,在全连接神经网络中,每层的权值w也需要进行初始化。
tensorlfow中应该有一下几种初始化方法
1. tf.constant_initializer() 常数初始化
2. tf.ones_initializer() 全1初始化
3. tf.zeros_initializer() 全0初始化
4. tf.random_uniform_initializer() 均匀分布初始化
5. tf.random_normal_initializer() 正态分布初始化
6. tf.truncated_normal_initializer() 截断正态分布初始化
7. tf.uniform_unit_scaling_initializer() 这种方法输入方差是常数
8. tf.variance_scaling_initializer() 自适应初始化
9. tf.orthogonal_initializer() 生成正交矩阵
具体的
1、tf.constant_initializer(),它的简写是tf.Constant()
#coding:utf-8
import numpy as np
import tensorflow as tf
train_inputs = [[1,2],[1,4],[3,2]]
with tf.variable_scope("embedding-layer"):
val = np.array([[1,2,3,4,5,6,7],[1,3,4,5,2,1,9],[0,12,3,4,5,7,8],[2,3,5,5,6,8,9],[3,1,6,1,2,3,5]])
const_init = tf.constant_initializer(val)
embeddings = tf.get_variable("embed",shape=[5,7],dtype=tf.float32,initializer=const_init)
embed = tf.nn.embedding_lookup(embeddings, train_inputs) #在embedding中查找train_input所对应的表示
print("embed",embed)
sum_embed = tf.reduce_mean(embed,1)
initall = tf.global_variables_initializer()
with tf.Session() as sess:
sess.run(initall)
print(sess.run(embed))
print(sess.run(tf.shape(embed)))
print(sess.run(sum_embed))
4、random_uniform_initializer = RandomUniform()
可简写为tf.RandomUniform()
生成均匀分布的随机数,参数有四个(minval=0, maxval=None, seed=None, dtype=dtypes.float32),分别用于指定最小值,最大值,随机数种子和类型。
6、tf.truncated_normal_initializer()
可简写tf.TruncatedNormal()
生成截断正态分布的随机数,这个初始化方法在tf中用得比较多。
它有四个参数(mean=0.0, stddev=1.0, seed=None, dtype=dtypes.float32),分别用于指定均值、标准差、随机数种子和随机数的数据类型,一般只需要设置stddev这一个参数就可以了。
8、tf.variance_scaling_initializer()
可简写为tf.VarianceScaling()
参数为(scale=1.0,mode="fan_in",distribution="normal",seed=None,dtype=dtypes.float32)
scale: 缩放尺度(正浮点数)
mode: "fan_in", "fan_out", "fan_avg"中的一个,用于计算标准差stddev的值。
distribution:分布类型,"normal"或“uniform"中的一个。
当 distribution="normal" 的时候,生成truncated normal distribution(截断正态分布) 的随机数,其中stddev = sqrt(scale / n) ,n的计算与mode参数有关。
如果mode = "fan_in", n为输入单元的结点数;
如果mode = "fan_out",n为输出单元的结点数;
如果mode = "fan_avg",n为输入和输出单元结点数的平均值。
当distribution="uniform”的时候 ,生成均匀分布的随机数,假设分布区间为[-limit, limit],则 limit = sqrt(3 * scale / n)
来源:https://blog.csdn.net/liushui94/article/details/78947956
猜你喜欢
- 本文实例讲述了JS 设计模式之:单例模式定义与实现方法。分享给大家供大家参考,具体如下:良好的设计模式可以显著提高代码的可读性,降低复杂度和
- Python 在 2.2 版本中引入了descriptor(描述符)功能,也正是基于这个功能实现了新式类(new-styel class)的
- 最近在用Pycharm学习Python的时候,总有两个地方感觉不是很舒服,比如调用方法的时候区分大小写(thread就不会出现Thread,
- MySQL之前有一个查询缓存Query Cache,从8.0开始,不再使用这个查询缓存,那么放弃它的原因是什么呢?在这一篇里将为您介绍。My
- 一、背景在kubernetes的世界中,很多组件仅仅需要一个实例在运行,比如controller-manager或第三方的controlle
- 使用profile来分析慢sqlmysql 的 sql 性能分析器主要用途是显示 sql 执行的整个过程中各项资源的使用情况。分析器可以更好
- 本文分享了php结合ajax实现无刷新上传图片的实例代码,分享给大家,希望大家可以和小编一起学习学习,共同进步。1.引入文件<!--图
- Perl利用函数rand()和srand()为随机数(更确切的说是"伪随机数")字符串的生成提供了基本的工具。这些函数不
- 本文将研究 ES6 的 for ... of 循环。旧方法在过去,有两种方法可以遍历 javascript。首先是经典的 for i 循环,
- 概述Python是个非常受欢迎的编程语言,随着近些年机器学习、云计算等技术的发展,Python的职位需求越来越高。下面我收集了10个Pyth
- 我们要生成二维码都需要借助一些类库来实现了,下面我介绍利用PHP QR Code生成二维码吧,生成方法很简单,下面我来介绍一下.利用php类
- 本文实例讲述了python抓取并保存html页面时乱码问题的解决方法。分享给大家供大家参考,具体如下:在用Python抓取html页面并保存
- 前言基于 Pythgo的 Django 框架,编程实现一个 WEB 程序,为用户提供 城市信息查询功能。用户可输入一个城市名,输出其所在省份
- 代码很简单,只需要2行代码就可实现想要的功能,虽然很短,但确实使用,主要使用了requests库。测试2XX, 3XX, 4XX, 5XX都
- 目录1 、一般同步下载2、 使用流式请求,requests.get方法的stream3 、异步下载文件4、 异步拆分下载文件5、注意1 、一
- 使用del和drop方法删除DataFrame中的列,使用drop方法一次删除多列数据准备:import pandas as pd 
- 本文实例讲述了Python创建xml的方法。分享给大家供大家参考。具体实现方法如下:from xml.dom.minidom import
- JavaScript获取Select当前值写法:var value = document.getElementById("sele
- 本文主要介绍的是MySQL慢查询分析方法,前一段日子,我曾经设置了一次记录在MySQL数据库中对慢于1秒钟的SQL语句进行查询。想起来有几个
- 利用python进行求解,求解的要求是不能使用python内部封装好的函数例如:maxway1:def findmax(data,n): i