PyTorch搭建一维线性回归模型(二)
作者:Liam Coder 发布时间:2023-02-16 19:27:15
PyTorch基础入门二:PyTorch搭建一维线性回归模型
1)一维线性回归模型的理论基础
给定数据集,线性回归希望能够优化出一个好的函数,使得能够和尽可能接近。
如何才能学习到参数和呢?很简单,只需要确定如何衡量与之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到和,使得:
均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:
求解以上两式,我们就可以得到最优解。
2)代码实现
首先,我们需要“制造”出一些数据集:
import torch
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()
我们想要拟合的一维回归模型是。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。
上面人为制造出来的数据集的分布如下:
有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
def forward(self, x):
out = self.linear(x)
return out
if torch.cuda.is_available():
model = LinearRegression().cuda()
else:
model = LinearRegression()
然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
接下来,开始进行模型的训练。
num_epochs = 1000
for epoch in range(num_epochs):
if torch.cuda.is_available():
inputs = Variable(x).cuda()
target = Variable(y).cuda()
else:
inputs = Variable(x)
target = Variable(y)
# 向前传播
out = model(inputs)
loss = criterion(out, target)
# 向后传播
optimizer.zero_grad() # 注意每次迭代都需要清零
loss.backward()
optimizer.step()
if (epoch+1) %20 == 0:
print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。
接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:
model.eval()
if torch.cuda.is_available():
predict = model(Variable(x).cuda())
predict = predict.data.cpu().numpy()
else:
predict = model(Variable(x))
predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()
其拟合结果如下图:
附上完整代码:
# !/usr/bin/python
# coding: utf8
# @Time : 2018-07-28 18:40
# @Author : Liam
# @Email : luyu.real@qq.com
# @Software: PyCharm
# .::::.
# .::::::::.
# :::::::::::
# ..:::::::::::'
# '::::::::::::'
# .::::::::::
# '::::::::::::::..
# ..::::::::::::.
# ``::::::::::::::::
# ::::``:::::::::' .:::.
# ::::' ':::::' .::::::::.
# .::::' :::: .:::::::'::::.
# .:::' ::::: .:::::::::' ':::::.
# .::' :::::.:::::::::' ':::::.
# .::' ::::::::::::::' ``::::.
# ...::: ::::::::::::' ``::.
# ```` ':. ':::::::::' ::::..
# '.:::::' ':'````..
# 美女保佑 永无BUG
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
def forward(self, x):
out = self.linear(x)
return out
if torch.cuda.is_available():
model = LinearRegression().cuda()
else:
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
num_epochs = 1000
for epoch in range(num_epochs):
if torch.cuda.is_available():
inputs = Variable(x).cuda()
target = Variable(y).cuda()
else:
inputs = Variable(x)
target = Variable(y)
# 向前传播
out = model(inputs)
loss = criterion(out, target)
# 向后传播
optimizer.zero_grad() # 注意每次迭代都需要清零
loss.backward()
optimizer.step()
if (epoch+1) %20 == 0:
print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
predict = model(Variable(x).cuda())
predict = predict.data.cpu().numpy()
else:
predict = model(Variable(x))
predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()
来源:https://blog.csdn.net/out_of_memory_error/article/details/81262309
猜你喜欢
- 引言使用 python 绘制网络训练过程中的的 loss 曲线以及准确率变化曲线,这里的主要思想就时先把想要的损失值以及准确率值保存下来,保
- Python有自己内置的标准GUI库--Tkinter,只要安装好Python就可以调用。今天学习到了图形界面设计的问题,刚开始就卡住了。为
- 一、数据类型在tf中,数据类型有整型(默认是int32),浮点型(默认是float32),以及布尔型,字符串。二、数据类型信息①.devic
- 问题:将文件夹a下任意命名的10个文件修改为如下图所示文件?代码:#coding:utf-8import ospath = "./
- 本文实例为大家分享了python读取Excel实例的具体代码,供大家参考,具体内容如下1.操作步骤:(1)安装python官方Excel库-
- 在本篇的开始之前,我必须阐明,我们对数组无论是索引还是切片,我是通过编号(或称为序列号)来进行操作,请记住:无论是 0轴(行)还是 1轴(列
- 之前爬美团外卖后台的时候出现的问题,各种方式拖动验证码都无法成功,包括直接控制拉动,模拟人工轨迹的随机拖动都失败了,最后发现只要用chrom
- python3 最常用的三种装饰器语法总结1.简述语法装饰器也叫函数装饰器,主要作用是在不修改原来函数的代码情况下(函数本身不会被修改,执行
- 你是否曾为表单设计感到过沮丧或不知所措呢?接下来三篇文章,希望能彻底改变你的看法,真正爱上Web表单设计。首先感谢Luke Wroblews
- 引用类型(Reference)在许多计算机语言中都被使用,而且是作为一个非常强大而实用的特性存在。它有类似指针(Pointer)的实现,却又
- 本文实例为大家分享了python爬取哈尔滨天气信息的具体代码,供大家参考,具体内容如下环境:windows7python3.4(pip in
- 本文实例为大家分享了python简单贪吃蛇的具体代码,供大家参考,具体内容如下import sysimport randomimport p
- 说明1、模型集成是指将一系列不同模型的预测结果集成在一起,从而获得更好的预测结果。2、对于模型集成来说,模型的多样性非常重要。Diversi
- 对python3下的requests使用并不是很熟练,今天稍微用了下,请求几次下来后发现出现连接超时的异常,上网查了下,找到了一个还算中肯的
- 需求描述最近在写一个图像标注小工具,其中需要用到一个缩略图列表,来查看文件夹内的图片文件。这里整理一个基于QListWidget实现的版本,
- 一、selenium截取验证码import jsonfrom io import BytesIOimport timefrom test.t
- 本文实例讲述了Python实现小数转化为百分数的格式化输出方法。分享给大家供大家参考,具体如下:比如将 0.1234 转化为 12.34%
- Swin TransformerSwin Transformer是一种用于图像处理的深度学习模型,它可以用于各种计算机视觉任务,如图像分类、
- 前言在前几篇博客中,分别就棋子的颜色识别、模板匹配等定位方式进行了介绍和实践,这一篇博客就来验证一下github中最热门的跳一跳 * 中采用的
- 1.迭代器当您创建一个列表时,你可以逐个读取它的项。逐项读取其项称为迭代:mylist是一个可迭代的对象。当你使用列表解析式时,你创建了一个