网络编程
位置:首页>> 网络编程>> 网络编程>> Pytorch中如何调用forward()函数

Pytorch中如何调用forward()函数

作者:good  发布时间:2023-06-14 21:00:24 

标签:Pytorch,forward,函数

Pytorch调用forward()函数

Module类是nn模块里提供的一个模型构造类,是所有神经网络模块的基类,我们可以继承它来定义我们想要的模型。

下面继承Module类构造本节开头提到的多层感知机。

这里定义的MLP类重载了Module类的__init__函数和forward函数。

它们分别用于创建模型参数和定义前向计算。

前向计算也即正向传播。

import torch
from torch import nn
 
class MLP(nn.Module):
    # 声明带有模型参数的层,这里声明了两个全连接层
    def __init__(self, **kwargs):
        # 调用MLP父类Module的构造函数来进行必要的初始化。这样在构造实例时还可以指定其他函数
        # 参数,如“模型参数的访问、初始化和共享”一节将介绍的模型参数params
        super(MLP, self).__init__(**kwargs)
        self.hidden = nn.Linear(784, 256) # 隐藏层
        self.act = nn.ReLU()
        self.output = nn.Linear(256, 10)  # 输出层
 
 
    # 定义模型的前向计算,即如何根据输入x计算返回所需要的模型输出
    def forward(self, x):
        a = self.act(self.hidden(x))
        return self.output(a)
  
X = torch.rand(2, 784)
net = MLP()
print(net)
net(X)

输出:

MLP( (hidden): Linear(in_features=784, out_features=256, bias=True) (act): ReLU() (output): Linear(in_features=256, out_features=10, bias=True) ) tensor([[-0.1798, -0.2253, 0.0206, -0.1067, -0.0889, 0.1818, -0.1474, 0.1845, -0.1870, 0.1970], [-0.1843, -0.1562, -0.0090, 0.0351, -0.1538, 0.0992, -0.0883, 0.0911, -0.2293, 0.2360]], grad_fn=<ThAddmmBackward>)

为什么会调用forward()呢,是因为Module中定义了__call__()函数,该函数调用了forward()函数,当执行net(x)的时候,会自动调用__call__()函数

Pytorch函数调用的问题和源码解读

最近用到 softmax 函数,但是发现 softmax 的写法五花八门,记录如下

# torch._C._VariableFunctions
torch.softmax(x, dim=-1)
# class
softmax = torch.nn.Softmax(dim=-1)
x=softmax(x)
# function
x = torch.nn.functional.softmax(x, dim=-1)

简单测试了一下,用 torch.nn.Softmax 类是最慢的,另外两个差不多

torch.nn.Softmax 源码如下,可以看到这是个类,而他这里的 return F.softmax(input, self.dim, _stacklevel=5) 调用的是 torch.nn.functional.softmax

class Softmax(Module):
    r"""Applies the Softmax function to an n-dimensional input Tensor
    rescaling them so that the elements of the n-dimensional output Tensor
    lie in the range [0,1] and sum to 1.

    Softmax is defined as:

    .. math::
        \text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}

    When the input Tensor is a sparse tensor then the unspecifed
    values are treated as ``-inf``.

    Shape:
        - Input: :math:`(*)` where `*` means, any number of additional
          dimensions
        - Output: :math:`(*)`, same shape as the input

    Returns:
        a Tensor of the same dimension and shape as the input with
        values in the range [0, 1]

    Args:
        dim (int): A dimension along which Softmax will be computed (so every slice
            along dim will sum to 1).

    .. note::
        This module doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use `LogSoftmax` instead (it's faster and has better numerical properties).

    Examples::

        >>> m = nn.Softmax(dim=1)
        >>> input = torch.randn(2, 3)
        >>> output = m(input)

    """
    __constants__ = ['dim']
    dim: Optional[int]

    def __init__(self, dim: Optional[int] = None) -> None:
        super(Softmax, self).__init__()
        self.dim = dim

    def __setstate__(self, state):
        self.__dict__.update(state)
        if not hasattr(self, 'dim'):
            self.dim = None

    def forward(self, input: Tensor) -> Tensor:
        return F.softmax(input, self.dim, _stacklevel=5)

    def extra_repr(self) -> str:
        return 'dim={dim}'.format(dim=self.dim)

torch.nn.functional.softmax 函数源码如下,可以看到 ret = input.softmax(dim) 实际上调用了 torch._C._VariableFunctions 中的 softmax 函数

def softmax(input: Tensor, dim: Optional[int] = None, _stacklevel: int = 3, dtype: Optional[DType] = None) -> Tensor:
    r"""Applies a softmax function.

    Softmax is defined as:

    :math:`\text{Softmax}(x_{i}) = \frac{\exp(x_i)}{\sum_j \exp(x_j)}`

    It is applied to all slices along dim, and will re-scale them so that the elements
    lie in the range `[0, 1]` and sum to 1.

    See :class:`~torch.nn.Softmax` for more details.

    Args:
        input (Tensor): input
        dim (int): A dimension along which softmax will be computed.
        dtype (:class:`torch.dtype`, optional): the desired data type of returned tensor.
          If specified, the input tensor is casted to :attr:`dtype` before the operation
          is performed. This is useful for preventing data type overflows. Default: None.

    .. note::
        This function doesn't work directly with NLLLoss,
        which expects the Log to be computed between the Softmax and itself.
        Use log_softmax instead (it's faster and has better numerical properties).

    """
    if has_torch_function_unary(input):
        return handle_torch_function(softmax, (input,), input, dim=dim, _stacklevel=_stacklevel, dtype=dtype)
    if dim is None:
        dim = _get_softmax_dim("softmax", input.dim(), _stacklevel)
    if dtype is None:
        ret = input.softmax(dim)
    else:
        ret = input.softmax(dim, dtype=dtype)
    return ret

那么不如直接调用 built-in C 的函数?

但是有个博客 A selective excursion into the internals of PyTorch 里说

Note: That bilinear is exported as torch.bilinear is somewhat accidental. Do use the documented interfaces, here torch.nn.functional.bilinear whenever you can!

意思是说 built-in C 能被 torch.xxx 直接调用是意外的,强烈建议使用 torch.nn.functional.xxx 这样的接口

看到最新的 transformer 官方代码里也用的是 torch.nn.functional.softmax,还是和他们一致更好(虽然他们之前用的是类。。。)

来源:https://blog.csdn.net/weixin_39454351/article/details/106419293

0
投稿

猜你喜欢

  • 一、问题描述 SQL Plus WorkSheet是一个窗口图形界面的SQL语句编辑器,对于那些喜欢窗口界面而不喜欢字符界面的用户,该工具相
  • 在做一个在线交流的网站时,有个问题很令我头疼,就是关于实时统计在线用户的问题,客户要求:统计当前在线人数、游客人数、会员人数、在线用户列表,
  • /* 小弟刚刚接触ORACLE存储过程,有一个问题向各位同行求教,小弟写了一个存储过程,其目的是接收一个参数作为表名,然后查询该表中的全部记
  • 作者:Lachlan Hunt概要网络是不断的进化的. 新的和有创意的网站每天都在出现, 从各方面都在冲击着HTML的边界. HTML 4来
  • 您可以将SQL Server 数据库引擎升级到 SQL Server 2008。SQL Server 安装程序只需最少的用户干预就可升级 S
  • Mysql默认是不可以通过远程机器访问的,通过下面的配置可以开启远程访问.我的Mysql环境是ubuntu+mysql51.修改/etc/m
  • 内容摘要:图片切换效果在网页制作中经常被使用,好的切换效果不仅增加了网站的实用行也提升了网站的趣味性。而图片切换方法有的使用flash来实现
  • 应原书编辑要求,先在文章顶部给出链接:《Everything You Know About CSS Is Wrong》http://www.
  • 在许多情况下,当迁移至SQL Server 2008之前必须了解那些反对和放弃功能的具体情况。下文是几个主要功能在兼容性上的问题列表:1.S
  • String Types(字符串类型)字符串类型Mysql支持多种字符串类型的变体。 这些数据类型在4.1和5.0版本中有较大的变化, 这使
  • 阅读作者的上一篇相关文章:段正淳的css笔记(3)标题右侧“更多”的实现 段正淳的css笔记(4)1、css代码的简写css缩写的语法,对新
  • 对于更完整的代码可以参考,这个是支持数据库的版本。经过测试Asp+Ajax仿google搜索提示效果 数据库版google搜索提示.rar
  • 导言(Introduction)这个提案描述了如何在jQuery的核心库中增加模板支持。更为特别是,这个提案描述了一个新的jQuery方法-
  • SQL Server TEXT、NTEXT字段拆分的问题引用的内容:SET NOCOUNT ON CREATE 
  • 有过Web经验的人喜欢使用:<meta http-equiv="refresh" content="1;
  • 阅读上一篇:交互设计模式(二)-Pagination(分页,标记页数) Tagging(标签)问题摘要用户往往想通过流行或最详尽的主题来浏览
  • 计算分页,嘿嘿一次搞定不用判断intNumPage = Abs(Int(-(intNumRecord/intPerPage)))  
  • 把下面SQL里的SELECT单独执行,没有问题,但是用来CREATE VIEW 就报错了.CREATE OR REPLA
  • 前端代码要做到简洁易读、高效,还要考虑后端嵌套的方便性。前段时间做了一个导航,把整个制作过程重现,希望对大家有帮助。看到这样的导航,你会怎么
  •     大家在打开带有图片的网页时,有时会看到这样的情况:当鼠标指向图片的不同部位时,可以打开不同的超链接,这
手机版 网络编程 asp之家 www.aspxhome.com