PyTorch学习笔记之回归实战
作者:manong_wxd 发布时间:2023-09-17 10:26:19
本文主要是用PyTorch来实现一个简单的回归任务。
编辑器:spyder
1.引入相应的包及生成伪数据
import torch
import torch.nn.functional as F # 主要实现激活函数
import matplotlib.pyplot as plt # 绘图的工具
from torch.autograd import Variable
# 生成伪数据
x = torch.unsqueeze(torch.linspace(-1, 1, 100), dim = 1)
y = x.pow(2) + 0.2 * torch.rand(x.size())
# 变为Variable
x, y = Variable(x), Variable(y)
其中torch.linspace
是为了生成连续间断的数据,第一个参数表示起点,第二个参数表示终点,第三个参数表示将这个区间分成平均几份,即生成几个数据。因为torch只能处理二维的数据,所以我们用torch.unsqueeze
给伪数据添加一个维度,dim表示添加在第几维。torch.rand
返回的是[0,1)之间的均匀分布。
2.绘制数据图像
在上述代码后面加下面的代码,然后运行可得伪数据的图形化表示:
# 绘制数据图像
plt.scatter(x.data.numpy(), y.data.numpy())
plt.show()
3.建立神经网络
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
self.hidden = torch.nn.Linear(n_feature, n_hidden) # hidden layer
self.predict = torch.nn.Linear(n_hidden, n_output) # output layer
def forward(self, x):
x = F.relu(self.hidden(x)) # activation function for hidden layer
x = self.predict(x) # linear output
return x
net = Net(n_feature=1, n_hidden=10, n_output=1) # define the network
print(net) # net architecture
一般神经网络的类都继承自torch.nn.Module
,__init__()和forward()
两个函数是自定义类的主要函数。在__init__()
中都要添加一句super(Net, self).__init__(),
这是固定的标准写法,用于继承父类的初始化函数。__init__()
中只是对神经网络的模块进行了声明,真正的搭建是在forwad()
中实现。自定义类中的成员都通过self指针来进行访问,所以参数列表中都包含了self。
如果想查看网络结构,可以用print()
函数直接打印网络。本文的网络结构输出如下:
Net (
(hidden): Linear (1 -> 10)
(predict): Linear (10 -> 1)
)
4.训练网络
# 训练100次
for t in range(100):
prediction = net(x) # input x and predict based on x
loss = loss_func(prediction, y) # 一定要是输出在前,标签在后 (1. nn output, 2. target)
optimizer.zero_grad() # clear gradients for next train
loss.backward() # backpropagation, compute gradients
optimizer.step() # apply gradients
训练网络之前我们需要先定义优化器和损失函数。torch.optim
包中包括了各种优化器,这里我们选用最常见的SGD作为优化器。因为我们要对网络的参数进行优化,所以我们要把网络的参数net.parameters()
传入优化器中,并设置学习率(一般小于1)。
由于这里是回归任务,我们选择torch.nn.MSELoss()
作为损失函数。
由于优化器是基于梯度来优化参数的,并且梯度会保存在其中。所以在每次优化前要通过optimizer.zero_grad()
把梯度置零,然后再后向传播及更新。
5.可视化训练过程
plt.ion() # something about plotting
for t in range(100):
...
if t % 5 == 0:
# plot and show learning process
plt.cla()
plt.scatter(x.data.numpy(), y.data.numpy())
plt.plot(x.data.numpy(), prediction.data.numpy(), 'r-', lw=5)
plt.text(0.5, 0, 'Loss=%.4f' % loss.data[0], fontdict={'size': 20, 'color': 'red'})
plt.pause(0.1)
plt.ioff()
plt.show()
6.运行结果
来源:https://blog.csdn.net/manong_wxd/article/details/78585371


猜你喜欢
- 一 简单介绍wxpy基于itchat,使用了 Web 微信的通讯协议,,通过大量接口优化提升了模块的易用性,并进行丰富的功能扩展。实现了微信
- 自己试过很好用function zero_fill_hex(num, digits) { var s = num.toString(16);
- 为了区分选择与未选择区域,,将已选择区域的文本背景色设置为浅蓝色是个很做法。设置的路径在 Editor > Color Scheme
- Plotly Express是对 Plotly.py 的高级封装,内置了大量实用、现代的绘图模板,用户只需调用简单的API函数,即可快速生成
- 1、其中再语义分割比较常用的上采样:其实现方法为:def upconv2x2(in_channels, out_channels, mode
- 公布到网页上的Email经常会被一些工具自动提取,一些非法用户就会利用所提取的Email大肆发送垃圾邮件。这些工具大多都是查找链接中“mai
- 前言在业务迭代中,随着数据量的上升,会出现慢SQL情况,但是当我们去分析单条SQL的时候,发现其执行速度并没有那么慢,原因是什么呢,那么就可
- 之前呢,我一直对GUI不是很感兴趣,但是呢,最近由于某些特殊原因,导致不得不用tkinter,需要实现一个渐变色,但是当我翻阅文档的时候,却
- 条件语句主要有三种形式:分别为if语句、if...else语句和if...elif...else 语句1.if语句条件语句中常用的比较运算符
- python函数的闭包问题(内嵌函数)>>> def func1():... print ('fun
- 借助 org.springframework.ui.Model 对象或 Map 对象将信息传到 springmvc 的页面中需要:jstl
- 最近心血来潮加上有点闲情,动手写了第一个JavaScript版的俄罗斯方块Easy Tetris.先上Easy Tetris俄罗斯方块游戏截
- 前言 绝大多数的Oracle数据库性能问题都是由于数据库设计不合理造成的,只有少部分问题根植于Database Buffer、Share P
- 前言总是记不住字符串拼接,每次都要百度去搜索,所以在这里记录一下,好方便后续的查找,如有错误和问题可以提出,谢谢。字符串拼接分为几种方式,在
- 前言本文主要给大家介绍的是关于python对配置文件.ini增删改查操作的相关内容,分享出来供大家参考学习,下面话不多说了,来一起看看详细的
- django model的json字段的编码器不能有效编码诸如uuid,datetime等数据类型,当直接存储此类型的对象到json字段中为
- 前言CORS 即 Cross Origin Resource Sharing 跨域资源共享.跨域请求分两种:简单请求、复杂请求.简单请求简单
- 摘要:Oracle和微软都是数据库方面的大厂商,采用两家的产品的企业也不少。今天这篇文章为大家对比Oracle和SQLServer的镜像。标
- 本章节将一些Python3基础语法整理成手册,方便各位在日常使用和学习是查阅,包含了编码、标识符、保留字、注释、缩进、字符串等常用内容。编码
- 插值对于一些时间序列的问题可能比较有用。Show the code directly:import numpy as npfrom matp