使用Pytorch训练two-head网络的操作
作者:XJTU-Qidong 发布时间:2023-04-06 14:15:59
标签:Pytorch,训练,two-head,网络
之前有写过一篇如何使用Pytorch实现two-head(多输出)模型
在那篇文章里,基本把two-head网络以及构建讲清楚了(如果不清楚请先移步至那一篇博文)。
但是我后来发现之前的训练方法貌似有些问题。
以前的训练方法:
之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:
f_out_y0, _ = net(x0)
_, f_out_y1 = net(x1)
#实例化损失函数
criterion0 = Loss()
criterion1 = Loss()
loss0 = criterion0(f_y0, f_out_y0, w0)
loss1 = criterion1(f_y1, f_out_y1, w1)
print(loss0.item(), loss1.item())
#对网络参数进行初始化
optimizer.zero_grad()
loss0.backward()
loss1.backward()
#对网络的参数进行更新
optimizer.step()
但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。
修改后的方法
使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。
因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。
代码如下:
f_out_y0, cf_out_y0 = net(x0)
cf_out_y1, f_out_y1 = net(x1)
#按照t=0和t=1的索引拼接向量
y_pred = torch.zeros([len(x), 1])
y_pred[index0] = f_out_y0
y_pred[index1] = f_out_y1
criterion = Loss()
loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
#print(loss.item())
viz.line([float(loss)], [epoch], win='train_loss', update='append')
optimizer.zero_grad()
loss.backward()
#对网络的参数进行更新
optimizer.step()
import torch
from torch.autograd import Variable
#初步认识构建Tensor数据
def one():
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
print(torch.ones((2,3)))#生成一个2X3的全一矩阵
a = torch.randn((2,3))
b = a.numpy()#将一个torch.Tensor转换为numpy
c = torch.from_numpy(b)#将numpy转换为Tensor
print(a)
print(b)
print(c)
#使用Variable自动求导
def two():
# 构建Variable
x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
# 函数等式
y = w * x ** 2 + b
# 使用梯度下降计算各变量的偏导数
y.backward(torch.Tensor([1, 1, 1]))
print(x.grad)
print(w.grad)
print(b.grad)
线性回归例子:
import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression,self).__init__()
self.Linear = nn.Linear(1,1)
def forward(self,x):
return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
inputs = Variable(x)
targets = Variable(y)
outputs = model(inputs)
loss = Loss(outputs,targets)
Opt.zero_grad()
loss.backward()
Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()
来源:https://blog.csdn.net/dong_liuqi/article/details/106461847


猜你喜欢
- 公司里很多部门,每个部门可以发多条信息,但每条信息只对应一个部门部门类:class Dep(models.Model): nam
- 本文介绍了prototype.js常用函数及其使用方法例子说明函数名
- 本文实例讲述了Python机器学习之scikit-learn库中KNN算法的封装与使用方法。分享给大家供大家参考,具体如下:1、工具准备,p
- 一、Selects检索表中的所有行$users = DB::table('users')->get();foreach
- ChatGPT 是 OpenAI 开发的 GPT(Generative Pre-trained Transformer)语言模型的变体。它是
- 打开一个Project在导航区带出多个Project将会影响PyCharm的运行速度,解决这个问题的方式只打开一个即可。有时候打开一个Pro
- 以下是涉及到插入表格的查询的5种改进方法:1)使用LOAD DATA INFILE从文本下载数据这将比使用插入语句快20倍。2)使用带有多个
- SQL1: --1、查看表空间的名称及大小 SELECT t.tablespace_name, round(SUM(bytes / (102
- 前言当我们编写任何程序时,都会遇到一些错误,会让我们有挫败感,所以我有一个解决方案给你。 今天在这篇文章中,我们将讨论错误类型error:
- 由于XML本身的诸多优点,XML技术已被广泛的使用,目前的好多软件技术同XML紧密相关,比如微软的.net 平台对xml提供了强大的支持,提
- 其实很简单,一般的数组去重可以直接用 new Set() 方法即可,但是数组对象的话,比较复杂,不能直接用,我们可以采取间接的方法来去重un
- 简介:type() 函数可以对数据的类型进行判定。isinstance() 与 type() 区别:type() 不会认为子类是一种父类类型
- 本文实例讲述了Laravel框架执行原生SQL语句及使用paginate分页的方法。分享给大家供大家参考,具体如下:1、运行原生sqlpub
- Django中获取text,password名字:<input type="text" name="na
- PHPMailer是一个封装好的PHP邮件发送类,支持发送HTML内容的电子邮件,以及可以添加附件发送,并不像PHP本身mail()函数需要
- 在当今企业环境中,保证数据安全不是可有可无的工作。频繁曝光的入侵和欺骗事件、萨班斯•奥克斯利法案、HIPAA法案规定和爱国
- 问题描述在深度学习相关任务的训练时,需要在训练的每个 epoch 记录当前 epoch 的准确率(如下图所示),那么在 python 中要怎
- 本文实例为大家分享了Django文件上传与下载的具体代码,供大家参考,具体内容如下Django1.4首先是上传:#settings.pyME
- 这个格式是我自创的,经常有人问我为什么,这里做个简单总结:1、分类,一个模块或者同类功能定义为一类定义,每类定义之间用段落隔开。2、分级,每
- 概述在数据库当中,索引就跟树的目录一样用来加快数据的查找速度,对于一个SQL查询操作,根据索引快速过滤掉不符合要求的数据并定位到符合要求的数