Pytorch结合PyG实现MLP过程详解
作者:实力 发布时间:2022-05-01 21:51:55
导入库和数据
首先,我们需要导入PyTorch和PyG库,然后准备好我们的数据。例如,我们可以使用以下方式生成一个简单的随机数据集:
from torch.utils.data import random_split
from torch_geometric.datasets import TUDataset
dataset = TUDataset(root='/tmp/ENZYMES', name='ENZYMES')
train_dataset, test_dataset = random_split(dataset, [len(dataset) - 1000, 1000])
其中,TUDataset
是PyG提供的图形数据集之一。这里我们选择了 ENZYMES
数据集并存储到 /tmp/ENZYMES
文件夹中。然后我们将该数据集分成训练集和测试集,其中训练集包含所有数据减去最后1000个数据,测试集则为最后1000个数据。
定义模型结构
接下来,我们需要定义MLP模型的结构。通过PyTorch和PyG,我们可以自己定义完整的MLP模型或者利用现有的库函数快速构建模型。在这里,我们将使用 torch.nn.Sequential
函数逐层堆叠多个线性层来实现MLP模型。以下是MLP模型定义的示例代码:
import torch.nn as nn
from torch_geometric.nn import MLP
class Net(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=3):
super(Net, self).__init__()
self.num_layers = num_layers
self.mlp = MLP([in_channels] + [hidden_channels] * (num_layers-1) + [out_channels])
def forward(self, x):
return self.mlp(x)
上述代码中,我们定义了一个 Net
类用于构建MLP网络,接收输入通道数、隐藏层节点数、输出通道数以及MLP层数作为参数。例如,我们可以按照以下方式创建一个拥有 4 层、128 个隐藏节点、并将度为图结构作为输入的MLP模型:
model = Net(in_channels=dataset.num_node_features, hidden_channels=128, out_channels=dataset.num_classes, num_layers=4)
定义训练函数
然后,我们需要定义训练函数来训练我们的MLP神经网络。在这里,我们将使用交叉熵损失和Adam优化器进行训练,并在每一个epoch结束时计算准确率并打印出来。以下是训练函数的示例代码:
import torch.optim as optim
from torch_geometric.data import DataLoader
from tqdm import tqdm
def train(model, loader, optimizer, loss_fn):
model.train()
correct = 0
total_loss = 0
for data in tqdm(loader, desc='Training'):
optimizer.zero_grad()
out = model(data.x)
pred = out.argmax(dim=1)
loss = loss_fn(out, data.y)
loss.backward()
optimizer.step()
total_loss += loss.item() * data.num_graphs
correct += pred.eq(data.y).sum().item()
return total_loss / len(loader.dataset), correct / len(loader.dataset)
在上述代码中,我们遍历加载器中的每个数据批次,并对模型进行培训。对于每个图数据批次,我们计算网络输出、预测和损失,然后通过反向传播来更新权重。最后,我们将总损失和正确率记录下来并返回。
定义测试函数
接下来,我们还需要定义测试函数来测试我们的MLP神经网络性能表现。我们将利用与训练函数相同的输出参数进行测试,并打印出最终的测试准确率。以下是测试函数的示例代码:
def test(model, loader, loss_fn):
model.eval()
correct = 0
total_loss = 0
with torch.no_grad():
for data in tqdm(loader, desc='Testing'):
out = model(data.x)
pred = out.argmax(dim=1)
loss = loss_fn(out, data.y)
total_loss += loss.item() * data.num_graphs
correct += pred.eq(data.y).sum().item()
return total_loss / len(loader.dataset), correct / len(loader.dataset)
在上述代码中,我们对测试数据集中的所有数据进行了循环,并计算网络的输出和预测。我们记录下总损失和正确分类的数据量,并返回损失和准确率之间的比率(我们使用该比率而不是精度来反映测试表现通常较小)。
训练模型并评估训练结果
最后,我们可以使用前面定义过的函数来定义主函数,从而完成MLP神经网络的训练和测试。以下是主函数的示例代码:
if __name__ == '__main__':
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = Net(in_channels=dataset.num_node_features, hidden_channels=128, out_channels=dataset.num_classes, num_layers=4).to(device)
loader = DataLoader(train_dataset, batch_size=128, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=128)
optimizer = optim.Adam(model.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()
for epoch in range(1, 201):
train_loss, train_acc = train(model, loader, optimizer, loss_fn)
test_loss, test_acc = test(model, test_loader, loss_fn)
print(f'Epoch {epoch:03d}, Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, '
f'Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}')
通过上述代码,我们就可以完成MLP神经网络的训练和测试。我们使用 DataLoader
函数进行数据加载,设置学习率、损失函数、训练轮数等超参数。最后,我们可以在屏幕上看到每个时代的准确率和损失值,并通过它们评估模型的训练表现。
来源:https://juejin.cn/post/7223995981234061372


猜你喜欢
- 前言密码安全是非常重要的,因此我们在代码中往往需要对密码进行加密,以此保证密码的安全加依赖<!-- jasypt --><
- 最近碰见太多次lambda函数了,那就来详细解释一下该函数。lambda函数我们先对lambda函数进行一个简单的介绍lambda函数是一种
- 在使用django restframework serializer 序列化在django中定义的model时,有时候我们需要额外在seri
- 本文实例讲述了javascript+HTML5 canvas绘制时钟功能。分享给大家供大家参考,具体如下:效果如下:代码:<!DOCT
- 因为有把python程序打包成exe的需求,所以,有了如下的代码import timeclass LoopOver(Exception):
- 本文实例讲述了python3实现的zip格式压缩文件夹操作。分享给大家供大家参考,具体如下:思路:先把第一级目录中的文件进行遍历,如果是文件
- 先给大家介绍下sqlserver给表添加新字段、给表和字段添加备注、更新备注及查询备注,代码如下所示:-- 添加新字段及字段备注的语法USE
- Python3安装第三方爬虫库BeautifulSoup4,供大家参考,具体内容如下在做Python3爬虫练习时,从网上找到了一段代码如下:
- 如果你已经理解了block formatting contexts那么请继续,否则请先看看这篇文章。Overflow能够做一些很牛掰的事情,
- 本文实例讲述了Python只用40行代码编写的计算器。分享给大家供大家参考,具体如下:效果图:代码:from tkinter import
- 本文主要介绍了Python利用numpy实现三层神经网络的示例代码,分享给大家,具体如下:其实神经网络很好实现,稍微有点基础的基本都可以实现
- python3.x已经不支持mysqldb了,支持的是pymysql使用pandas读取MySQL数据时,使用sqlalchemy,出现No
- 本文实例讲述了Python设置默认编码为utf8的方法。分享给大家供大家参考,具体如下:这是Python的编码问题,设置python的默认编
- 随着互联网的快速发展和数据交换的广泛应用,各种数据格式的处理成为软件开发中的关键问题。JSON 作为一种通用的数据交换格式,在各种应用场景中
- 大部分语言,例如c语言,交换两个变量的值需要使用中间变量。例如交换a,b伪代码:tmp = aa = bb = tmppython里面可以实
- PS: 我的检索是在文章模块下 forum/article第一步:先安装需要的包:pip install django-haystackpi
- 以下是IE7中新支持的属性:min-height,max-height,min-width,max-width这个hack还可以使最大高度兼
- 本文实例讲述了python访问mysql数据库的实现方法。分享给大家供大家参考,具体如下:首先安装与Python版本匹配的MySQLdb示例
- 不知道大家在做网站时有没有给目录名或者文件名添加”( )”的习惯,有则改之,无则加勉。因为他有潜在的危险,起码就被我遇到了。要使页面能够使用
- pandas的DataFrame对象,本质上是二维矩阵,跟常规二维矩阵的差别在于前者额外指定了每一行和每一列的名称。这样内部数据抽取既可以用