PyTorch搭建多项式回归模型(三)
作者:Liam Coder 发布时间:2022-09-04 00:43:49
PyTorch基础入门三:PyTorch搭建多项式回归模型
1)理论简介
对于一般的线性回归模型,由于该函数拟合出来的是一条直线,所以精度欠佳,我们可以考虑多项式回归来拟合更多的模型。所谓多项式回归,其本质也是线性回归。也就是说,我们采取的方法是,提高每个属性的次数来增加维度数。比如,请看下面这样的例子:
如果我们想要拟合方程:
对于输入变量和输出值
,我们只需要增加其平方项、三次方项系数即可。所以,我们可以设置如下参数方程:
可以看到,上述方程与线性回归方程并没有本质区别。所以我们可以采用线性回归的方式来进行多项式的拟合。下面请看代码部分。
2)代码实现
当然最先要做的就是导包了,下面需要说明的只有一个:itertools中的count,这个是用来记数用的,其可以记数到无穷,第一个参数是记数的起始值,第二个参数是步长。其内部实现相当于如下代码:
def count(firstval=0, step=1):
x = firstval
while 1:
yield x
x += step
下面是导包部分代码,这里定义了一个常量POLY_DEGREE = 3用来指定多项式最高次数。
from itertools import count
import torch
import torch.autograd
import torch.nn.functional as F
POLY_DEGREE = 3
然后我们需要将数据处理成矩阵的形式:
在PyTorch里面使用torch.cat()函数来实现Tensor的拼接:
def make_features(x):
"""Builds features i.e. a matrix with columns [x, x^2, x^3, x^4]."""
x = x.unsqueeze(1)
return torch.cat([x ** i for i in range(1, POLY_DEGREE+1)], 1)
对于输入的个数据,我们将其扩展成上面矩阵所示的样子。
然后定义出我们需要拟合的多项式,可以随机抽取一个多项式来作为我们的目标多项式。当然,系数和偏置
确定了,多项式也就确定了:
W_target = torch.randn(POLY_DEGREE, 1)
b_target = torch.randn(1)
def f(x):
"""Approximated function."""
return x.mm(W_target) + b_target.item()
这里的权重已经定义好了,x.mm(W_target)表示做矩阵乘法,就是每次输入一个
得到一个
的真实函数。
在训练的时候我们需要采样一些点,可以随机生成一批数据来得到训练集。下面的函数可以让我们每次取batch_size这么多个数据,然后将其转化为矩阵形式,再把这个值通过函数之后的结果也返回作为真实的输出值:
def get_batch(batch_size=32):
"""Builds a batch i.e. (x, f(x)) pair."""
random = torch.randn(batch_size)
x = make_features(random)
y = f(x)
return x, y
接下来我们需要定义模型,这里采用一种简写的方式定义模型,torch.nn.Linear()表示定义一个线性模型,这里定义了是输入值和目标参数的行数一致(和POLY_DEGREE一致,本次实验中为3),输出值为1的模型。
# Define model
fc = torch.nn.Linear(W_target.size(0), 1)
下面开始训练模型,训练的过程让其不断优化,直到随机取出的batch_size个点中计算出来的均方误差小于0.001为止。
for batch_idx in count(1):
# Get data
batch_x, batch_y = get_batch()
# Reset gradients
fc.zero_grad()
# Forward pass
output = F.smooth_l1_loss(fc(batch_x), batch_y)
loss = output.item()
# Backward pass
output.backward()
# Apply gradients
for param in fc.parameters():
param.data.add_(-0.1 * param.grad.data)
# Stop criterion
if loss < 1e-3:
break
这样就已经训练出了我们的多项式回归模型,为了方便观察,定义了如下打印函数来打印出我们拟合的多项式表达式:
def poly_desc(W, b):
"""Creates a string description of a polynomial."""
result = 'y = '
for i, w in enumerate(W):
result += '{:+.2f} x^{} '.format(w, len(W) - i)
result += '{:+.2f}'.format(b[0])
return result
print('Loss: {:.6f} after {} batches'.format(loss, batch_idx))
print('==> Learned function:\t' + poly_desc(fc.weight.view(-1), fc.bias))
print('==> Actual function:\t' + poly_desc(W_target.view(-1), b_target))
程序运行结果如下图所示:
可以看出,真实的多项式表达式和我们拟合的多项式十分接近。现实世界中很多问题都不是简单的线性回归,涉及到很多复杂的非线性模型。但是我们可以在其特征量上进行研究,改变或者增加其特征,从而将非线性问题转化为线性问题来解决,这种处理问题的思路是我们从多项式回归的算法中应该汲取到的。
来源:https://blog.csdn.net/out_of_memory_error/article/details/81266231


猜你喜欢
- 这几天一直在看《Pro JavaScript Techniques》,书中有不少优美、健壮代码,让我不得不惊叹老外对语言这东西的研究程度之深
- 首先下载最新版本的python。www.python.org,目前版本为3.1。 接下来是安装,在windows下python的安装与其他应
- 目录一、列表求并集1. union_by二、列表求交集1. intersection_by三、列表求差集1. difference2. di
- 跟着节奏继续来探索fixtures的灵活性。一、一个测试函数/fixture一次请求多个fixture在测试函数和fixture函数中,每一
- 摘要: 三次握手,四次挥手意思是tcp建立连接时需要三次交互来完成,A发起连接A --- SYN --> BA
- 本文实例为大家分享了python+pyqt5编写md5生成器的具体代码,供大家参考,具体内容如下学了一下pyqt5,写一个小程序来实践一下。
- 看代码吧~def test(): return 1,2a, b = test()1 2a, _ = test()1
- 前言H2数据库是一个开源的关系型数据库。H2采用java语言编写,不受平台的限制,同时支持网络版和嵌入式版本,有比较好的兼容性,支持相当标准
- python实战,用户答题分享给大家。主要包含内容,文件的读取,更改,保存。不同文件夹引入模块。输入,输出操作。随机获取数据操作随机生成算数
- 前言在vue里,组件之间的作用域是独立的,父组件跟子组件之间的通讯可以通过prop属性来传参,但是在兄弟组件之间通讯就比较麻烦了。比如A组件
- 有时表或结果集包含重复的记录。有时它是允许的,但有时它需要停止重复的记录。有时它需要识别重复的记录从表中删除。本章将介绍如何防止发生在一个表
- 引言TypeScript 给 JavaScript 添加了一套类型系统,可以在编译期间检查出类型错误,这增加了代码的健壮性,但也多了一个编译
- 前言:Python 3最重要的新特性之一是对字符串和二进制数据流做了明确的区分。文本总是Unicode,由str类型表示,二进制数据则由by
- 用python实现的一个井字棋游戏,供大家参考,具体内容如下#Tic-Tac-Toe 井字棋游戏#全局常量X="X"O=
- (1)Flush的内容至少要有256字节经过反复的测试,我得出一个结论。就是flush的内容至少要有256字节。也就是只有编译产生了至少25
- 1、给定一个数据集noise-data-1.txt,该数据集中保护大量的缺失值(空格、不完整值等)。利用“全局常量”、“均值或者中位数”来填
- 1. 整体思路首先我们来梳理下整体上的实现思路,首先一点:整体思路和 vhr 一模一样。考虑到有的小伙伴可能已经忘记 vhr 中前端动态菜单
- 我认为在ASP中最好的办法是用编程实现定时刷新Cache,也就是说给Application中储存的设一个过期时间。当然,在ASP中Appli
- 使用ENUM代替字符串类型有时候, 可以通过使用ENUM来代理常规的字符串类型。一个ENUM列能够存储65535个不同的字符串值,MySQL
- 调试的定义:通过一定方法,在程序中找到并减少缺陷的数量,从而使其能正常工作。这里说一些如何调试PHP程序的经验。一、PHP自带的调试功能1、