pytorch自定义初始化权重的方法
作者:goodxin_ie 发布时间:2023-12-25 07:55:06
在常见的pytorch代码中,我们见到的初始化方式都是调用init类对每层所有参数进行初始化。但是,有时我们有些特殊需求,比如用某一层的权重取优化其它层,或者手动指定某些权重的初始值。
核心思想就是构造和该层权重同一尺寸的矩阵去对该层权重赋值。但是,值得注意的是,pytorch中各层权重的数据类型是nn.Parameter,而不是Tensor或者Variable。
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
# 第一一个卷积层,我们可以看到它的权值是随机初始化的
w=torch.nn.Conv2d(2,2,3,padding=1)
print(w.weight)
# 第一种方法
print("1.使用另一个Conv层的权值")
q=torch.nn.Conv2d(2,2,3,padding=1) # 假设q代表一个训练好的卷积层
print(q.weight) # 可以看到q的权重和w是不同的
w.weight=q.weight # 把一个Conv层的权重赋值给另一个Conv层
print(w.weight)
# 第二种方法
print("2.使用来自Tensor的权值")
ones=torch.Tensor(np.ones([2,2,3,3])) # 先创建一个自定义权值的Tensor,这里为了方便将所有权值设为1
w.weight=torch.nn.Parameter(ones) # 把Tensor的值作为权值赋值给Conv层,这里需要先转为torch.nn.Parameter类型,否则将报错
print(w.weight)
附:Variable和Parameter的区别
Parameter 是torch.autograd.Variable的一个字类,常被用于Module的参数。例如权重和偏置。
Parameters和Modules一起使用的时候会有一些特殊的属性。parameters赋值给Module的属性的时候,它会被自动加到Module的参数列表中,即会出现在Parameter()迭代器中。将Varaible赋给Module的时候没有这样的属性。这可以在nn.Module的实现中详细看一下。这样做是为了保存模型的时候只保存权重偏置参数,不保存节点值。所以复写Variable加以区分。
另外一个不同是parameter不能设置volatile,而且require_grad默认设置为true。Varaible默认设置为False.
参数:
parameter.data 得到tensor数据
parameter.requires_grad 默认为True, BP过程中会求导
Parameter一般是在Modules中作为权重和偏置,自动加入参数列表,可以进行保存恢复。和Variable具有相同的运算。
我们可以这样简单区分,在计算图中,数据(包括输入数据和计算过程中产生的feature map等)时variable类型,该类型不会被保存到模型中。 网络的权重是parameter类型,在计算过程中会被更新,将会被保存到模型中。
来源:https://blog.csdn.net/goodxin_ie/article/details/84555805
猜你喜欢
- 说来惭愧,以前在去掉数组的空值是都是强写foreach或者while的,利用这两个语法结构来删除数组中的空元素,简单代码如下:<?ph
- Python作为一种功能强大的编程语言,因其简单易学而受到很多开发者的青睐。那么,Python 的应用领域有哪些呢?概括起来,Python的
- 多元正态分布(多元高斯分布)直接从多元正态分布讲起。多元正态分布公式如下:这就是多元正态分布的定义,均值好理解,就是高斯分布的概率分布值最大
- 本文实例为大家分享了python实现学生信息管理系统的具体代码,供大家参考,具体内容如下学生管理系统的开发步骤:1、显示学生管理系统的功能菜
- 第一步:创建一个表。 create table Test_Table ( ID number(11) primary key, Name v
- 复制是将主数据库的DDL和DML操作通过二进制日志传到从库上,然后再从库重做,从而使得从库和主库保持数据的同步。MySQL可以从一台主库同时
- 很多时候我们需要对数字进行格式化,比如位数不足前面加0补足。用PHP可以很轻易实现,因为PHP自带了相关功能的函数。<?php &nb
- 本文实例讲述了mysql存储过程之创建(CREATE PROCEDURE)和调用(CALL)及变量创建(DECLARE)和赋值(SET)操作
- 前言:Python内置对SMTP的支持,可以发送纯文本邮件、HTML邮件以及带附件的邮件。Python对SMTP支持有smtplib和ema
- 前言在遇到三维数据时,三维图像能给我们对数据带来更加深入地理解。python的matplotlib库就包含了丰富的三维绘图工具。1.创建三维
- 作为程序员,我们经常需要对时间进行处理。在 Go 中,标准库 time 提供了对应的能力。本文将介绍 time 库中一些重要的函数和方法,希
- 参考网址 https://www.jb51.net/article/29551.htmSELECT [StartDate] FROM [db
- 概述用爬虫时,大部分网站都有一定的反爬措施,有些网站会限制每个 IP 的访问速度或访问次数,超出了它的限制你的 IP 就会被封掉。对于访问速
- layer是一款近年来备受青睐的web弹层组件,官网地址是:http://layer.layui.com/可以从官网上下载最新版本.使用la
- 首先centos7 已经不支持mysql,因为收费了你懂得,所以内部集成了mariadb,而安装mysql的话会和mariadb的文件冲突,
- Python 多线程的实例详解一)线程基础1、创建线程:thread模块提供了start_new_thread函数,用以创建线程
- 这些日子,几乎每个人都在谈论XML (Extensible Markup Language),但是很少有人真正理解其含义。XML的推崇者认为
- 问题你有一个数据序列,想利用一些规则从中提取出需要的值或者是缩短序列解决方案最简单的过滤序列元素的方法就是使用列表推导。比如:>>
- 目录前言map 并发操作出现问题sync.Map 解决并发操作问题计算 map 长度计算 sync.Map 长度前言在 Golang 中 m
- 教程使用的版本是2019.1新版本安装激活可以参考此篇教程,通用版!一、go安装1、建议去go语言中文网下载,网址:https://stud