Pytorch 实现自定义参数层的例子
作者:青盏 发布时间:2023-01-27 22:00:06
标签:Pytorch,自定义,参数层
注意,一般官方接口都带有可导功能,如果你实现的层不具有可导功能,就需要自己实现梯度的反向传递。
官方Linear层:
class Linear(Module):
def __init__(self, in_features, out_features, bias=True):
super(Linear, self).__init__()
self.in_features = in_features
self.out_features = out_features
self.weight = Parameter(torch.Tensor(out_features, in_features))
if bias:
self.bias = Parameter(torch.Tensor(out_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input):
return F.linear(input, self.weight, self.bias)
def extra_repr(self):
return 'in_features={}, out_features={}, bias={}'.format(
self.in_features, self.out_features, self.bias is not None
)
实现view层
class Reshape(nn.Module):
def __init__(self, *args):
super(Reshape, self).__init__()
self.shape = args
def forward(self, x):
return x.view((x.size(0),)+self.shape)
实现LinearWise层
class LinearWise(nn.Module):
def __init__(self, in_features, bias=True):
super(LinearWise, self).__init__()
self.in_features = in_features
self.weight = nn.Parameter(torch.Tensor(self.in_features))
if bias:
self.bias = nn.Parameter(torch.Tensor(self.in_features))
else:
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
stdv = 1. / math.sqrt(self.weight.size(0))
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, input):
x = input * self.weight
if self.bias is not None:
x = x + self.bias
return x
来源:https://blog.csdn.net/qq_16234613/article/details/81604081


猜你喜欢
- 在开发Windows应用程序时,我们经常需要进行测试来确保程序的质量和稳定性。手动测试是一种常见的方法,但是它非常耗时和繁琐,特别是对于大型
- 网上找的协程安全的map都是用互斥锁或者读写锁实现的,这里用单个协程来实现下,即所有的增删查改操作都集成到一个goroutine中,这样肯定
- TensorFlow从txt文件中读取数据的方法很多有种,我比较常用的是下面两种:【1】np.loadtxtimport numpy as
- DBA_2PC_PENDING Oracle会自动处理分布事务,保证分布事务的一致性,所有站点全部提交或全部回滚。一般情况下,处理过程在很短
- 见下表:序号列类型需要的存储量1TINYINT1 字节2SMALLINT2 个字节3MEDIUMINT3 个字节4INT4 个字节5INTE
- 项目总览创建虚拟环境mkvirtualenv meiduo_malls创建项目django-admin startproject meidu
- 很多朋友问过我absolute与relative怎么区分,怎么用?我们都知道absolute是绝对定位,relative是相对定位,但是这个
- blob对象介绍一个 Blob对象表示一个不可变的, 原始数据的类似文件对象。Blob表示的数据不一定是一个JavaScript原生格式 b
- 我遇到的情况是:把数据按一定的时间段提出。比如提出每天6:00-8:00的每个数据,可以这样做:# -*-coding: utf-8 -*-
- 数组/对象数组删除其中某一项由于日常工作中经常需要对数组进行操作,最经常使用到的就是对数组进行的删除操作对于我们前端来说,数组有两种区别1、
- 1.figure语法及操作(1)figure语法说明figure(num=None, figsize=None, dpi=None, fac
- 有时需要根据项目的实际需求向spider传递参数以控制spider的行为,比如说,根据用户提交的url来控制spider爬取的网站。在这种情
- 一、线程基础以及守护进程线程是CPU调度的最小单位全局解释器锁全局解释器锁GIL(global interpreter lock)全局解释器
- 看代码:Vue提供了强大的前端开发架构,很多时候我们需要判断数据对象是否为空,使用typeof判断是个不错选择,具体代码见图。补充知识:vu
- 一、条件简化我们编写的查询语句的搜索条件本质上是一个表达式,这些表达式可能比较繁杂,或者不能高效的执行,MySQL的查询优化器会为我们简化这
- Mysql的分页的两个参数select * from user limit 1,21表示从第几条数据开始查(默认索引是0,如果写1,从第二条
- 在自动化测试过程中,有时后会遇到元素定位方式没有问题,但是依旧抛出无法找到元素的异常的问题,通常情况下,如果元素定位没有问题,但还是无法找到
- 本文是基于上一篇(python项目:学生信息管理系统(初版) )进行了完善,并添加了新的功能。主要包括有:完善部分:输入错误;无数据查询等异
- 一个简单的JS显示日期代码,可以显示星期几<script type="text/javascript">fu
- 用Python编写关于计算图形面积的代码实现,供大家参考,具体内容如下#寒假打卡28天第7天import mathclass Round()