PyTorch搭建一维线性回归模型(二)
作者:Liam Coder 发布时间:2023-02-16 19:27:15
PyTorch基础入门二:PyTorch搭建一维线性回归模型
1)一维线性回归模型的理论基础
给定数据集,线性回归希望能够优化出一个好的函数
,使得
能够和
尽可能接近。
如何才能学习到参数和
呢?很简单,只需要确定如何衡量
与
之间的差别,我们一般通过损失函数(Loss Funciton)来衡量:
。取平方是因为距离有正有负,我们于是将它们变为全是正的。这就是著名的均方误差。我们要做的事情就是希望能够找到
和
,使得:
均方差误差非常直观,也有着很好的几何意义,对应了常用的欧式距离。现在要求解这个连续函数的最小值,我们很自然想到的方法就是求它的偏导数,让它的偏导数等于0来估计它的参数,即:
求解以上两式,我们就可以得到最优解。
2)代码实现
首先,我们需要“制造”出一些数据集:
import torch
import matplotlib.pyplot as plt
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
# 画图
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()
我们想要拟合的一维回归模型是。上面制造的数据集也是比较接近这个模型的,但是为了达到学习效果,人为地加上了torch.rand()值增加一些干扰。
上面人为制造出来的数据集的分布如下:
有了数据,我们就要开始定义我们的模型,这里定义的是一个输入层和输出层都只有一维的模型,并且使用了“先判断后使用”的基本结构来合理使用GPU加速。
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
def forward(self, x):
out = self.linear(x)
return out
if torch.cuda.is_available():
model = LinearRegression().cuda()
else:
model = LinearRegression()
然后我们定义出损失函数和优化函数,这里使用均方误差作为损失函数,使用梯度下降进行优化:
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
接下来,开始进行模型的训练。
num_epochs = 1000
for epoch in range(num_epochs):
if torch.cuda.is_available():
inputs = Variable(x).cuda()
target = Variable(y).cuda()
else:
inputs = Variable(x)
target = Variable(y)
# 向前传播
out = model(inputs)
loss = criterion(out, target)
# 向后传播
optimizer.zero_grad() # 注意每次迭代都需要清零
loss.backward()
optimizer.step()
if (epoch+1) %20 == 0:
print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
首先定义了迭代的次数,这里为1000次,先向前传播计算出损失函数,然后向后传播计算梯度,这里需要注意的是,每次计算梯度前都要记得将梯度归零,不然梯度会累加到一起造成结果不收敛。为了便于看到结果,每隔一段时间输出当前的迭代轮数和损失函数。
接下来,我们通过model.eval()函数将模型变为测试模式,然后将数据放入模型中进行预测。最后,通过画图工具matplotlib看一下我们拟合的结果,代码如下:
model.eval()
if torch.cuda.is_available():
predict = model(Variable(x).cuda())
predict = predict.data.cpu().numpy()
else:
predict = model(Variable(x))
predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()
其拟合结果如下图:
附上完整代码:
# !/usr/bin/python
# coding: utf8
# @Time : 2018-07-28 18:40
# @Author : Liam
# @Email : luyu.real@qq.com
# @Software: PyCharm
# .::::.
# .::::::::.
# :::::::::::
# ..:::::::::::'
# '::::::::::::'
# .::::::::::
# '::::::::::::::..
# ..::::::::::::.
# ``::::::::::::::::
# ::::``:::::::::' .:::.
# ::::' ':::::' .::::::::.
# .::::' :::: .:::::::'::::.
# .:::' ::::: .:::::::::' ':::::.
# .::' :::::.:::::::::' ':::::.
# .::' ::::::::::::::' ``::::.
# ...::: ::::::::::::' ``::.
# ```` ':. ':::::::::' ::::..
# '.:::::' ':'````..
# 美女保佑 永无BUG
import torch
from torch.autograd import Variable
import numpy as np
import random
import matplotlib.pyplot as plt
from torch import nn
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)
y = 3*x + 10 + torch.rand(x.size())
# 上面这行代码是制造出接近y=3x+10的数据集,后面加上torch.rand()函数制造噪音
# 画图
# plt.scatter(x.data.numpy(), y.data.numpy())
# plt.show()
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(1, 1) # 输入和输出的维度都是1
def forward(self, x):
out = self.linear(x)
return out
if torch.cuda.is_available():
model = LinearRegression().cuda()
else:
model = LinearRegression()
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
num_epochs = 1000
for epoch in range(num_epochs):
if torch.cuda.is_available():
inputs = Variable(x).cuda()
target = Variable(y).cuda()
else:
inputs = Variable(x)
target = Variable(y)
# 向前传播
out = model(inputs)
loss = criterion(out, target)
# 向后传播
optimizer.zero_grad() # 注意每次迭代都需要清零
loss.backward()
optimizer.step()
if (epoch+1) %20 == 0:
print('Epoch[{}/{}], loss:{:.6f}'.format(epoch+1, num_epochs, loss.data[0]))
model.eval()
if torch.cuda.is_available():
predict = model(Variable(x).cuda())
predict = predict.data.cpu().numpy()
else:
predict = model(Variable(x))
predict = predict.data.numpy()
plt.plot(x.numpy(), y.numpy(), 'ro', label='Original Data')
plt.plot(x.numpy(), predict, label='Fitting Line')
plt.show()
来源:https://blog.csdn.net/out_of_memory_error/article/details/81262309


猜你喜欢
- 实战场景初学 Python 爬虫,十之八九大家采集的目标是网页,因此快速定位到网页内容,就成为我们面临的第一道障碍,本篇博客就为你详细说明最
- 前言Go 数组的长度不可改变,在特定场景中这样的集合就不太适用,Go中提供了一种灵活,功能强悍的内置类型切片("动态数组"
- 效果图实例代码今天我们要用微信小程序实现2048小游戏,效果图如上面所示。游戏的规则很简单,你需要控制所有方块向同一个方向运动,两个相同数字
- 基本思路就是,使用MIMEMultipart来标示这个邮件是多个部分组成的,然后attach各个部分。如果是附件,则add_header加入
- 今天真的被编码问题一直困扰着,午休都没进行。也真的见识到了各种编码。例如:gbk,unicode、utf-8、ansi、gb2312等。如果
- Insert 和 Update假设现在你要把下面的数据插入到数据库中.ID = 3TheDate=mktime(0,0,0,8,31,200
- 如何调用多个不同的ip接口灵感来源:项目的登录登出权限是调A的ip下面的接口,其他的功能调的接口是B的ip下面的接口思路:其实就是多写几个r
- 在设计中保持一致性(uniformity)是网页设计中一个重要的组成部分,它能使你的设计有效地传达信息而不会导致用户迷惑或焦虑。保证一致性的
- 我们都知道tensorflow框架可以使用tensorboard这一高级的可视化的工具,为了使用tensorboard这一套完美的可视化工具
- 在整个安装的过程中也遇到了很多的坑,故此做个记录,争取下次不再犯!我的整个基本配置如下:电脑环境如下:win10(64位)+CPU:E5-2
- SQLSRV驱动程序允许您创建一个结果集,其中包含可以根据游标类型以任何顺序访问的行。本主题将讨论客户端(缓冲)和服务器端(非缓冲)游标及其
- 装饰器作用decorator是当今最流行的设计模式之一,很多使用它的人并不知道它是一种设计模式。这种模式有什么特别之处? 有兴趣可以看看Py
- 我是一个初入互联网的视觉设计师,和以往做设计感受最大的不同就是:一个设计的最终定稿会受到多方面的挑战,有来自产品经理的,来自开发的,来自测试
- 1. 概述对于社区,没有一个明确的定义,有很多对社区的定义,如社区是指在一个网络中,有一组节点,它们彼此都相似,而组内的节点与网络中的其他节
- 前言MySQL 的权限表在数据库启动的时候就载入内存,当用户通过身份认证后,就在内存中进行相应权限的存取,这样,此用户就可以在数据库中做权限
- 前端版本更新检查,实现页面自动刷新使用vite对项目进行打包,对 js 和 css 文件使用了 chunkhash 进行了文件缓存控制,但是
- 1、公式推导 对幂律分布公式:对公式两边同时取以10为底的对数:所以对于幂律公式,对X,Y取对数后,在坐标轴上为线性方程。2、可视化 从图形
- 看到别人用td和table标签模拟的办法: 设置table的上、左padding
- 今天无意在坛子里看到这样一个求救帖(这里),看了一下,感觉问题比较好解决。但是问题背后的问题却引起了我的反思。把他的页面整理一下看看(为了便
- #!/usr/bin/env python#-*- coding: utf-8 -*-#==========================