网络编程
位置:首页>> 网络编程>> Python编程>> 关于Pytorch的MLP模块实现方式

关于Pytorch的MLP模块实现方式

作者:黄鑫huangxin  发布时间:2021-12-19 03:28:22 

标签:Pytorch,MLP模块

MLP分类效果一般好于线性分类器,即将特征输入MLP中再经过softmax来进行分类。

具体实现为将原先线性分类模块:


self.classifier = nn.Linear(config.hidden_size, num_labels)

替换为:


self.classifier = MLP(config.hidden_size, num_labels)

并且添加MLP模块:


 class MLP(nn.Module):
   def __init__(self, input_size, common_size):
     super(MLP, self).__init__()
     self.linear = nn.Sequential(
       nn.Linear(input_size, input_size // 2),
       nn.ReLU(inplace=True),
       nn.Linear(input_size // 2, input_size // 4),
       nn.ReLU(inplace=True),
       nn.Linear(input_size // 4, common_size)
     )

def forward(self, x):
     out = self.linear(x)
     return out

看一下模块结构:


mlp = MLP(1000,3)
print(mlp)

关于Pytorch的MLP模块实现方式

来源:https://blog.csdn.net/qq_33373858/article/details/88108153

0
投稿

猜你喜欢

手机版 网络编程 asp之家 www.aspxhome.com