使用Pytorch实现two-head(多输出)模型的操作
作者:XJTU-Qidong 发布时间:2023-08-20 07:00:05
标签:Pytorch,two-head,多输出
如何使用Pytorch实现two-head(多输出)模型
1. two-head模型定义
先放一张我要实现的模型结构图:
如上图,就是一个two-head模型,也是一个但输入多输出模型。该模型的特点是输入一个x和一个t,h0和h1中只有一个会输出,所以可能这不算是一个典型的多输出模型。
2.实现所遇到的困难 一开始的想法:
这不是很简单嘛,做一个判断不就完了,t=0时模型为前半段加h0,t=1时模型为前半段加h1。但实现的时候傻眼了,发现在真正前向传播的时候t是一个tensor,有0有1,没法儿进行判断。
灵机一动,又生一法:把这个模型变为三个模型,前半段是一个模型(r),后面的h0和h1分别为另两个模型。把数据集按t=0和1分开,分别训练两个模型:r+h0和r+h1。
但是后来搜如何进行模型串联,发现极为麻烦。
3.解决方案
后来在pytorch的官方社区中看到一个极为简单的方法:
(1) 按照一般的多输出模型进行实现,代码如下:
def forward(self, x):
#三层的表示层
x = F.elu(self.fcR1(x))
x = F.elu(self.fcR2(x))
x = F.elu(self.fcR3(x))
#two-head,两个head分别进行输出
y0 = F.elu(self.fcH01(x))
y0 = F.elu(self.fcH02(y0))
y0 = F.elu(self.fcH03(y0))
y1 = F.elu(self.fcH11(x))
y1 = F.elu(self.fcH12(y1))
y1 = F.elu(self.fcH13(y1))
return y0, y1
这样就相当实现了一个多输出模型,一个x同时输出y0和y1.
训练的时候分别训练,也即分别建loss,代码如下:
f_out_y0, _ = net(x0)
_, f_out_y1 = net(x1)
#实例化损失函数
criterion0 = Loss()
criterion1 = Loss()
loss0 = criterion0(f_y0, f_out_y0, w0)
loss1 = criterion1(f_y1, f_out_y1, w1)
print(loss0.item(), loss1.item())
#对网络参数进行初始化
optimizer.zero_grad()
loss0.backward()
loss1.backward()
#对网络的参数进行更新
optimizer.step()
先把x按t=0和t=1分为x0和x1,然后分别送入进行训练。这样就实现了一个two-head模型。
4.后记
我自以为多输出模型可以分为以下两类:
多个输出不同时获得,如本文情况。
多个输出同时获得。
多输出不同时获得的解决方法上文已说明。多输出同时获得则可以通过把y0和y1拼接起来一起输出来实现。
补充:PyTorch 多输入多输出模型构建
本篇教程基于 PyTorch 1.5版本
直接上代码!
import torch
import torch.nn as nn
from torch.autograd import Variable
import torch.distributed as dist
import torch.utils.data as data_utils
class Net(nn.Module):
def __init__(self, n_input, n_hidden, n_output):
super(Net, self).__init__()
self.hidden1 = nn.Linear(n_input, n_hidden)
self.hidden2 = nn.Linear(n_hidden, n_hidden)
self.predict1 = nn.Linear(n_hidden*2, n_output)
self.predict2 = nn.Linear(n_hidden*2, n_output)
def forward(self, input1, input2): # 多输入!!!
out01 = self.hidden1(input1)
out02 = torch.relu(out01)
out03 = self.hidden2(out02)
out04 = torch.sigmoid(out03)
out11 = self.hidden1(input2)
out12 = torch.relu(out11)
out13 = self.hidden2(out12)
out14 = torch.sigmoid(out13)
out = torch.cat((out04, out14), dim=1) # 模型层拼合!!!当然你的模型中可能不需要~
out1 = self.predict1(out)
out2 = self.predict2(out)
return out1, out2 # 多输出!!!
net = Net(1, 20, 1)
x1 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1) # 请不要关心这里,随便弄一个数据,为了说明问题而已
y1 = x1.pow(3)+0.1*torch.randn(x1.size())
x2 = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y2 = x2.pow(3)+0.1*torch.randn(x2.size())
x1, y1 = (Variable(x1), Variable(y1))
x2, y2 = (Variable(x2), Variable(y2))
optimizer = torch.optim.SGD(net.parameters(), lr=0.1)
loss_func = torch.nn.MSELoss()
for t in range(5000):
prediction1, prediction2 = net(x1, x2)
loss1 = loss_func(prediction1, y1)
loss2 = loss_func(prediction2, y2)
loss = loss1 + loss2 # 重点!
optimizer.zero_grad()
loss.backward()
optimizer.step()
if t % 100 == 0:
print('Loss1 = %.4f' % loss1.data,'Loss2 = %.4f' % loss2.data,)
至此搞定!
来源:https://blog.csdn.net/dong_liuqi/article/details/104850408


猜你喜欢
- 当向 MySQL 数据库插入一条带有中文的数据形如 insert into employee values(null,'张
- BeautifulSoup简介Beautiful Soup是python的一个库,最主要的功能是从网页抓取数据。官方解释如下:Beautif
- 前言此文记录了我在进行 Anaconda 环境变量配置的做法 ,希望可以对有需要的朋友们有所帮助或者启发一、什么是环境变量环境变量一般是指操
- SQL Server四类数据仓库建模的方法主要分为以下四类。第一类是关系数据库的三范式建模,通常我们将三范式建模方法用于建立各种操作型数据库
- 某天,在需要抓取某个网页信息的时候,需要在header中增加一些信息,于是搜索了一下,如何在golang发起的http请求中设置header
- 目录什么是虚拟 dom?为什么需要虚拟dom?虚拟dom是如何转换为真实dom的?模板和虚拟dom的关系注入挂载完整流程总结什么是虚拟 do
- 在编程时你一定碰到过时间触发的事件,在VB中有timer控件,而asp中没有,假如你要不停地查询数据库来等待一个返回结果的话,我想你一定知道
- 本文实例为大家分享了原生js实现tab选项卡切换效果的代码,供大家参考,具体内容如下1.html部分<body> <div
- pytorch中的gather函数pytorch比tensorflow更加编程友好,所以准备用pytorch试着做最近要做的一些实验。立个f
- 在 Python 中,我们可以使用基本的索引操作来获取数组中的元素。然而,有时候我们需要获取一个数组的子数组,也就是只获取数组中的一部分元素
- 在web.config文件中添加<connectionStrings><add name="SQLConnect
- pytorch报错:RuntimeError: Expected object of type Variable[torch.LongTen
- 最终的目标是想这样的,在JavaScript里写一个swing来实现确定取消,来决定是否执行这个功能的,但是在执行的过程中,出现了一点问题,
- 我们先看一下相关数据结构的知识。 在学习线性表的时候,曾有这样一个例题。 已知一个存储整数的顺序表La,试构造顺序表Lb,要求顺序表Lb中只
- 1. 算法描述二分法是一种效率比较高的搜索方法回忆之前做过的猜数字的小游戏,预先给定一个小于100的正整数x,让你猜猜测过程中给予大小判断的
- 在最新版的pycharm中拥有类似jupyter的分段执行代码功能,其使用方法如下:1.在想要分段运行的段前一行(空白行)输入#%%2.选择
- Python文件: #parsexml.py #本例子参考自python联机文档,做了适当改动和添加 import xml.parsers.
- 导语在CSDN学习的过程中,遇到了爆火的文章是关于刮刮卡的!大家猜猜看是谁写的?我看这文章都特别火,我也感觉挺好玩的,那就寻思用 Pytho
- 1、处理包含数据的文件最近利用Python读取txt文件时遇到了一个小问题,就是在计算两个np.narray()类型的数组时,出现了以下错误
- 前言至今,ChatGPT 已经火了很多轮,我在第一轮的时候注册了账号,遗憾的是,没有彻头彻尾好好地体验过一次。最近这一次火爆,ChatGPT