python神经网络学习利用PyTorch进行回归运算
作者:Bubbliiiing 发布时间:2023-02-24 13:30:47
学习前言
我发现不仅有很多的Keras模型,还有很多的PyTorch模型,还是学学Pytorch吧,我也想了解以下tensor到底是个啥。
PyTorch中的重要基础函数
1、class Net(torch.nn.Module)神经网络的构建:
PyTorch中神经网络的构建和Tensorflow的不一样,它需要用一个类来进行构建(后面还可以用与Keras类似的Sequential模型构建),当然基础还是用类构建,这个类需要继承PyTorch中的神经网络模型,torch.nn.Module,具体构建方式如下:
# 继承torch.nn.Module模型
class Net(torch.nn.Module):
# 重载初始化函数(我忘了这个是不是叫重载)
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全连接层,公式为y = xA^T + b
# 在初始化的同时构建两个全连接层(也就是一个隐含层)
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
# forward函数用于构建前向传递的过程
def forward(self, x):
# 隐含层的输出
hidden_layer = functional.relu(self.hidden(x))
# 实际的输出
output_layer = self.predict(hidden_layer)
return output_layer
该部分构建了一个含有一层隐含层的神经网络,隐含层神经元个数为n_hidden。
在建立了上述的类后,就可以通过如下函数建立神经网络:
net = Net(n_feature=1, n_hidden=10, n_output=1)
2、optimizer优化器
optimizer用于构建模型的优化器,与tensorflow中优化器的意义相同,PyTorch的优化器在前缀为torch.optim的库中。
优化器需要传入net网络的参数。
具体使用方式如下:
# torch.optim是优化器模块
# Adam可以改成其它优化器,如SGD、RMSprop等
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
3、loss损失函数定义
loss用于定义神经网络训练的损失函数,常用的损失函数是均方差损失函数(回归)和交叉熵损失函数(分类)。
具体使用方式如下:
# 均方差lossloss_func = torch.nn.MSELoss()
4、训练过程
训练过程分为三个步骤:
1、利用网络预测结果。
prediction = net(x)
2、利用预测的结果与真实值对比生成loss。
loss = loss_func(prediction, y)
3、进行反向传递(该部分有三步)。
# 均方差loss
# 反向传递步骤
# 1、初始化梯度
optimizer.zero_grad()
# 2、计算梯度
loss.backward()
# 3、进行optimizer优化
optimizer.step()
全部代码
这是一个简单的回归预测模型。
import torch
from torch.autograd import Variable
import torch.nn.functional as functional
import matplotlib.pyplot as plt
import numpy as np
# x的shape为(100,1)
x = torch.from_numpy(np.linspace(-1,1,100).reshape([100,1])).type(torch.FloatTensor)
# y的shape为(100,1)
y = torch.sin(x) + 0.2*torch.rand(x.size())
class Net(torch.nn.Module):
def __init__(self, n_feature, n_hidden, n_output):
super(Net, self).__init__()
# Applies a linear transformation to the incoming data: :math:y = xA^T + b
# 全连接层,公式为y = xA^T + b
self.hidden = torch.nn.Linear(n_feature, n_hidden)
self.predict = torch.nn.Linear(n_hidden, n_output)
def forward(self, x):
# 隐含层的输出
hidden_layer = functional.relu(self.hidden(x))
output_layer = self.predict(hidden_layer)
return output_layer
# 类的建立
net = Net(n_feature=1, n_hidden=10, n_output=1)
# torch.optim是优化器模块
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
# 均方差loss
loss_func = torch.nn.MSELoss()
for t in range(1000):
prediction = net(x)
loss = loss_func(prediction, y)
# 反向传递步骤
# 1、初始化梯度
optimizer.zero_grad()
# 2、计算梯度
loss.backward()
# 3、进行optimizer优化
optimizer.step()
if t & 50 == 0:
print("The loss is",loss.data.numpy())
运行结果为:
The loss is 0.27913737
The loss is 0.2773982
The loss is 0.27224126
…………
The loss is 0.0035993527
The loss is 0.0035974088
The loss is 0.0035967692
来源:https://blog.csdn.net/weixin_44791964/article/details/101790418
猜你喜欢
- 本文实例讲述了PHP基于cookie与session统计网站访问量并输出显示的方法。分享给大家供大家参考,具体如下:<?php$f_o
- Some readers have asked to me what
- 由于数据文件平时在数据库运行的时候处于使用状态,故当数据库处于打开状态时,管理员是无法重命名数据文件名字的。那么一定要更改这个数据文件的名字
- 我的Windows 8.1 环境1.下载安装Python 2.7.6在Python官方网站中下载Python2.7.6的Windows安装包
- 什么是存储过程呢?定义:将常用的或很复杂的工作,预先用SQL语句写好并用一个指定的名称存储起来, 那么以后要叫数据库提供与已定义好的存储过程
- W3C发布了WCAG 2.0提案(Web Content Accessibility Guidelines 网页内容无障碍指南),大概为了实
- <% pagenum=55'指定打印行数 %> <HTML> <HEAD> <
- 第一次写技术博客,有不尽如人意的地方,还请见谅和指正。为什么想整理这方面的类容,我觉得就像油画家要了解他的颜料和画布、雕塑家要了解他的石材一
- 从XML中读取数据到内存的实例: public clsSi
- GoLang之使goroutine停止的5种方法1.goroutine停止介绍goroutine是Go语言实现并发编程的利器,简单的一个指令
- 一个更易读的网站意味着网站使用性的改良以及提供愉悦的阅读体验。我们希望浏览者们能或者这些好处不是吗?这篇文章我们将介绍5个简单的方法让你能提
- 本文实例为大家分享了python环境路径设置方法,以及命令行运行python脚本,供大家参考,具体内容如下找Python安装目录,设置环境路
- this指针是面向对象程序设计中的一项重要概念,它表示当前运行的对象。在实现对象的方法时,可以使用this指针来获得该对象自身的引用。和其他
- 1、为图片加入水印功能 Dim Jpeg Set Jpeg = Server.Create
- 制作网页可说是易学难精,因此,不断吸收经验可弥补不足,以下列出的50个制作主页的独门招数可帮助你尽快成为高手,哈哈!1、让读者有理由逗留。要
- 1、并双击新建工程窗口中ActiveX DLL图标,VB将自动为项目添加一个类模块,并将该项目类型设置为ActiveX DLL。2、在属性窗
- 本文实例讲述了PHP编程实现多维数组按照某个键值排序的方法。分享给大家供大家参考,具体如下:实现对多维数组按照某个键值排序的两种解决方法(a
- 方法 bindParam() 和 bindValue() 非常相似。 唯一的区别就是前者使用一个PHP变量绑定参数,而后者使用一个值。 所以
- 这个javascript农历日历,万年历代码网上看到的,很不错,功能齐全,值得收藏!功能介绍:动态显示当前世界各国各时区时间,显示当前农历,
- 在现在的项目里,不管是电商项目还是别的项目,在管理端都会有导出的功能,比方说订单表导出,用户表导出,业绩表导出。这些都需要提前生成excel