python神经网络学习利用PyTorch进行回归运算
作者:Bubbliiiing 发布时间:2023-02-24 13:30:47
学习前言
我发现不仅有很多的Keras模型,还有很多的PyTorch模型,还是学学Pytorch吧,我也想了解以下tensor到底是个啥。
PyTorch中的重要基础函数
1、class Net(torch.nn.Module)神经网络的构建:
PyTorch中神经网络的构建和Tensorflow的不一样,它需要用一个类来进行构建(后面还可以用与Keras类似的Sequential模型构建),当然基础还是用类构建,这个类需要继承PyTorch中的神经网络模型,torch.nn.Module,具体构建方式如下:
# 继承torch.nn.Module模型
class Net(torch.nn.Module):
# 重载初始化函数(我忘了这个是不是叫重载)
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全连接层,公式为y = xA^T + b
# 在初始化的同时构建两个全连接层(也就是一个隐含层)
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
# forward函数用于构建前向传递的过程
def forward(self, x):
# 隐含层的输出
hidden_layer = functional.relu(self.hidden(x))
# 实际的输出
output_layer = self.predict(hidden_layer)
return output_layer
该部分构建了一个含有一层隐含层的神经网络,隐含层神经元个数为n_hidden。
在建立了上述的类后,就可以通过如下函数建立神经网络:
net = Net(n_feature=1, n_hidden=10, n_output=1)
2、optimizer优化器
optimizer用于构建模型的优化器,与tensorflow中优化器的意义相同,PyTorch的优化器在前缀为torch.optim的库中。
优化器需要传入net网络的参数。
具体使用方式如下:
# torch.optim是优化器模块
# Adam可以改成其它优化器,如SGD、RMSprop等
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
3、loss损失函数定义
loss用于定义神经网络训练的损失函数,常用的损失函数是均方差损失函数(回归)和交叉熵损失函数(分类)。
具体使用方式如下:
# 均方差lossloss_func = torch.nn.MSELoss()
4、训练过程
训练过程分为三个步骤:
1、利用网络预测结果。
prediction = net(x)
2、利用预测的结果与真实值对比生成loss。
loss = loss_func(prediction, y)
3、进行反向传递(该部分有三步)。
# 均方差loss
# 反向传递步骤
# 1、初始化梯度
optimizer.zero_grad()
# 2、计算梯度
loss.backward()
# 3、进行optimizer优化
optimizer.step()
全部代码
这是一个简单的回归预测模型。
import torch
from torch.autograd import Variable
import torch.nn.functional as functional
import matplotlib.pyplot as plt
import numpy as np
# x的shape为(100,1)
x = torch.from_numpy(np.linspace(-1,1,100).reshape([100,1])).type(torch.FloatTensor)
# y的shape为(100,1)
y = torch.sin(x) + 0.2*torch.rand(x.size())
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全连接层,公式为y = xA^T + b
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
# 隐含层的输出
hidden_layer = functional.relu(self.hidden(x))
output_layer = self.predict(hidden_layer)
return output_layer
# 类的建立
net = Net(n_feature=1, n_hidden=10, n_output=1)
# torch.optim是优化器模块
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# 均方差loss
loss_func = torch.nn.MSELoss()
for t in range(1000):
prediction = net(x)
loss = loss_func(prediction, y)
# 反向传递步骤
# 1、初始化梯度
optimizer.zero_grad()
# 2、计算梯度
loss.backward()
# 3、进行optimizer优化
optimizer.step()
if t & 50 == 0:
print("The loss is",loss.data.numpy())
运行结果为:
The loss is 0.27913737
The loss is 0.2773982
The loss is 0.27224126
…………
The loss is 0.0035993527
The loss is 0.0035974088
The loss is 0.0035967692
来源:https://blog.csdn.net/weixin_44791964/article/details/101790418


猜你喜欢
- 本文实例讲述了python检测某个变量是否有定义的方法。分享给大家供大家参考。具体如下:第一种方法使用内置函数locals():'t
- 数据库优化有很多可以讲,按照支撑的数据量来分可以分为两个阶段:单机数据库和分库分表,前者一般可以支撑500W或者10G以内的数据,超过这个值
- 你一定想下载一下感兴趣的网页,以便慢慢欣赏吧!利用FrontPage能够轻松做到这一点,甚至可以下载整个站点,当然这里只能下载静态的页面。启
- 本文实例讲述了python网络编程:socketserver的基本使用方法。分享给大家供大家参考,具体如下:本文内容:socketserve
- 背景描述:Pycharm作为python专业开发工具,要比轻量级的vscode更加稳定,适合个人、团队的项目开发。同时pycharm来创建虚
- 前言当我们使用pandas处理数据的时候,经常会遇到数据重复的问题,如何找出重复数据进而分析重复原因,或者如何直接删除重复的数据是一个关键的
- 前言我们都知道 Node.js 是以单线程的模式运行的,但它使用的是事件驱动来处理并发,这样有助于我们在多核 cpu 的系统上创建多个子进程
- 本文实例讲述了微信小程序MUI导航栏透明渐变功能。分享给大家供大家参考,具体如下:导航栏透明渐变效果实现原理1. 利用position:ab
- 无论是在小得可怜的免费数据库空间或是大型电子商务网站,合理的设计表结构、充分利用空间是十分必要的。这就要求我们对数据库系统的常用数据类型有充
- 这篇文章主要介绍了Python实现序列化及csv文件读取,文中通过示例代码介绍的非常详细,对大家的学习或者工作具有一定的参考学习价值,需要的
- 1 自动微分我们在《数值分析》课程中已经学过许多经典的数值微分方法。许多经典的数值微分算法非常快,因为它们只需要计算差商。然而,他们的主要缺
- 上一一节我们讲了while循环,while循环主要用于重复程序的运行,for循环更加倾向于遍历一个项目,即将特定内容(比如一个列表、一个字符
- 函数可以有0或多个返回值,返回值需要指定数据类型,返回值通过return关键字来指定。return可以有参数,也可以没有参数,这些返回值可以
- 1、python-pptx模块简介使用python操作PPT,需要使用的模块就是python-pptx,下面来对该模块做一个简单的介绍。这里
- 1.首先读取Excel文件数据代表了各个城市店铺的装修和配置费用,要统计出装修和配置项的总费用并进行加和计算;2.pandas实现过程imp
- 背景事情是这样的,在公司内部新开发了一个功能还没有上线,目前部署在测试环境,Node服务会开启一个定时任务,每5分钟会处理好一部分数据写入到
- 问题背景周一上班,首先向同事了解了一下上周的测试情况,被告知在多实例场景下 MySQL Server hang 住,无法测试下去,原生版本不
- 一、使用python内置commands模块执行shellcommands对Python的os.popen()进行了封装,使用SHELL命令
- 目录前言yarn create 做了什么源码解析项目依赖模版配置工具函数copycopyDiremptyDir核心函数命令行交互并创建文件夹
- INSERT、DELETE、UPDATE 三种SQL语句是数据库技术的三大基本语句. 在通常的web开发中对它的处理可以说是无处不在. 如果