使用Pytorch训练two-head网络的操作
作者:XJTU-Qidong 发布时间:2023-04-06 14:15:59
标签:Pytorch,训练,two-head,网络
之前有写过一篇如何使用Pytorch实现two-head(多输出)模型
在那篇文章里,基本把two-head网络以及构建讲清楚了(如果不清楚请先移步至那一篇博文)。
但是我后来发现之前的训练方法貌似有些问题。
以前的训练方法:
之前是把两个head分开进行训练的,因此每一轮训练先要对一个batch的数据进行划分,然后再分别训练两个头。代码如下:
f_out_y0, _ = net(x0)
_, f_out_y1 = net(x1)
#实例化损失函数
criterion0 = Loss()
criterion1 = Loss()
loss0 = criterion0(f_y0, f_out_y0, w0)
loss1 = criterion1(f_y1, f_out_y1, w1)
print(loss0.item(), loss1.item())
#对网络参数进行初始化
optimizer.zero_grad()
loss0.backward()
loss1.backward()
#对网络的参数进行更新
optimizer.step()
但是在实际操作中想到那这样的话岂不是每次都先使用t=0的数据训练公共的表示层,再使用t=1的数据去训练。这样会不会使表示层产生bias呢?且这样两步训练也很麻烦。
修改后的方法
使用之前训练方法其实还是对神经网络的训练的机理不清楚。事实上,在计算loss的时候每个数据点的梯度都是单独计算的。
因此完全可以把网络前向传播得到结果按之前的顺序拼接起来后再进行梯度的反向传播,这样就可以只进行一步训练,且不会出现训练先后的偏差。
代码如下:
f_out_y0, cf_out_y0 = net(x0)
cf_out_y1, f_out_y1 = net(x1)
#按照t=0和t=1的索引拼接向量
y_pred = torch.zeros([len(x), 1])
y_pred[index0] = f_out_y0
y_pred[index1] = f_out_y1
criterion = Loss()
loss = criterion(f_y, y_pred, w) + 0.01 * (l2_regularization0 + l2_regularization1)
#print(loss.item())
viz.line([float(loss)], [epoch], win='train_loss', update='append')
optimizer.zero_grad()
loss.backward()
#对网络的参数进行更新
optimizer.step()
import torch
from torch.autograd import Variable
#初步认识构建Tensor数据
def one():
print(torch.tensor([1,2,3],dtype=torch.float))#将一个列表强制转换为torch.Tensor类型
print(torch.randn(5,3))#生成torch.Tensor类型的5X3的随机数
print(torch.zeros((2,3)))#生成一个2X3的全零矩阵
print(torch.ones((2,3)))#生成一个2X3的全一矩阵
a = torch.randn((2,3))
b = a.numpy()#将一个torch.Tensor转换为numpy
c = torch.from_numpy(b)#将numpy转换为Tensor
print(a)
print(b)
print(c)
#使用Variable自动求导
def two():
# 构建Variable
x = Variable(torch.Tensor([1, 2, 3]), requires_grad=True)
w = Variable(torch.Tensor([4, 5, 6]), requires_grad=True)
b = Variable(torch.Tensor([7, 8, 9]), requires_grad=True)
# 函数等式
y = w * x ** 2 + b
# 使用梯度下降计算各变量的偏导数
y.backward(torch.Tensor([1, 1, 1]))
print(x.grad)
print(w.grad)
print(b.grad)
线性回归例子:
import torch
from torch.autograd import Variable
import numpy as np
import matplotlib.pyplot as plt
from torch import nn
x = torch.unsqueeze(torch.linspace(-1,1,100),dim=1)
y = 3*x+10+torch.rand(x.size())
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression,self).__init__()
self.Linear = nn.Linear(1,1)
def forward(self,x):
return self.Linear(x)
model = LinearRegression()
Loss = nn.MSELoss()
Opt = torch.optim.SGD(model.parameters(),lr=0.01)
for i in range(1000):
inputs = Variable(x)
targets = Variable(y)
outputs = model(inputs)
loss = Loss(outputs,targets)
Opt.zero_grad()
loss.backward()
Opt.step()
model.eval()
predict = model(Variable(x))
plt.plot(x.numpy(),y.numpy(),'ro')
plt.plot(x.numpy(),predict.data.numpy())
plt.show()
来源:https://blog.csdn.net/dong_liuqi/article/details/106461847
0
投稿
猜你喜欢
- 前言图片是Word的一种特殊内容,这篇文章主要介绍了关于Python操作word文档,向里面插入图片和表格的相关内容,下面话不多说了,来一起
- modelform是model衍生出来的form .modelform的用法非常死.首先在models.py里创建模型表.所有的form组件
- 有时候我们需要判断某一个IP地址是否属于一个网段,以决定该用户能否访问系统.比如用户登录的IP是218.6.7.7,而我们的程序必须判断他是
- 如何在生产上部署Django?Django的部署可以有很多方式,采用nginx+uwsgi的方式是其中比较常见的一种方式。uwsgi介绍uW
- 获取一组radio被选中项的值var item = $(’input[@name=items][@checke
- 官方文档:https://elasticsearch-py.readthedocs.io/en/master/1、介绍python提供了操作
- 首先,与其他语言不同,JS的效率很大程度是取决于JS engine的效率。除了引擎实现的优劣外,引擎自己也会为一些特殊的代码模式采取一些优化
- Python自带一个轻量级的关系型数据库SQLite。这一数据库使用SQL语言。SQLite作为后端数据库,可以搭配Python建网站,或者
- 在平常的一些的小规模的数据的过滤、清洗过程中使用最多的就是正则表达式,但是随着数据规模的增大,正则表达式就显得有些心有余力不足了。正则表达式
- 随着网络的发展,人们通过各种方式使用它。今天,网络购物,跟朋友或者不认识的人聊天,管理银行账户,以及一些日常应用,共享照片或视频,等等。事实
- 参考服务器安装的是Centos 系统。uwsgi是使用pip安装的。nginx是使用yum install nginx安装。python 2
- 调度和锁定在很多客户一起查询数据表时,如果使客户能最快地查询到数据就是调度和锁定做的工作了。在MySQL中,我们把select操作叫做读,把
- 持续更新一些常用的Tensor操作,比如List,Numpy,Tensor之间的转换,Tensor的拼接,维度的变换等操作。其它Tensor
- Juan Pablo De Gregorio 的 原文很多人都问我如何为一本杂志、一份报纸、一张海报、一份简报或是一份出版物选择
- 运行以下代码: Dim com As ADODB.Command Dim rst
- Oblog4.6 ACCESS版转换为UCenterHome1.5的全过程1、 说明:
- 一、引言Server端的脚本运行环境,它简单易用,不需要编译和连接,脚本可以在 Server端直接运行,并且它支持多用户、多线程,因为 AS
- Python中的闭包前几天又有人留言,关于其中一个闭包和re.sub的使用不太清楚。我在脚本之家搜索了下,发现没有写过闭包相关的东西,所以决
- 如何做一个自己的QQ?这不是什么新鲜的东西,看看代码:refresh.htm<HTML><HEAD><titl
- 最近要做个从 pdf 文件中抽取文本内容的工具,大概查了一下 python 里可以使用 pdfminer 来实现。下面就看看怎样使用吧。PD