详解使用Pytorch Geometric实现GraphSAGE模型
作者:实力 发布时间:2021-09-30 21:30:18
GraphSAGE是一种用于图神经网络中的节点嵌入学习方法。它通过聚合节点邻居的信息来生成节点的低维表示,使节点表示能够更好地应用于各种下游任务,如节点分类、链路预测等。
图构建
在使用GraphSAGE对节点进行嵌入学习之前,我们需要先将原始数据转换为图结构,并将其存储为Pytorch Tensor格式。例如,我们可以使用networkx库来构建一个简单的图:
import networkx as nx
G = nx.karate_club_graph()
然后,我们可以使用Pytorch Geometric库将NetworkX图转换为Pytorch Tensor格式。首先,我们需要安装Pytorch Geometric并导入所需的类:
!pip install torch-geometric
from torch_geometric.datasets import Planetoid
from torch_geometric.transforms import NormalizeFeatures
from torch_geometric.utils.convert import from_networkx
接着,我们可以使用from_networkx
函数将NetworkX图转换为Pytorch Tensor格式:
data = from_networkx(G)
此时,data
对象包含了关于节点、边及其属性的信息,例如:
data.edge_index: 2x(#edges)的长整型张量,表示边的起点和终点
data.x
: n×dn \times dn×d 的浮点型张量,表示每个节点的特征向量(其中nnn是节点数量,ddd是特征维度)
注意,此时的data
对象并未包含邻居信息。接下来,我们将介绍如何使用Sampler方法采样节点邻居。
Sampler方法
GraphSAGE使用Sampler方法来聚合邻居信息。在Pytorch Geometric中,可以使用Various Sampling方法来实现Sampler。例如,使用ClusterData方法将图分成多个子图,然后对每个子图进行采样操作。
以下是ClusterData
的使用示例:
from torch_geometric.utils import degree, to_undirected
from torch_geometric.transforms import ClusterData
# Convert the graph to an undirected graph, so we can aggregate neighbors in both directions.
G = to_undirected(G)
# Compute the degree of each node.
deg = degree(data.edge_index[0], num_nodes=data.num_nodes)
# Use METIS algorithm to partition the graph into multiple subgraphs.
cluster_data = ClusterData(data, num_parts=2, recursive=False, transform=NormalizeFeatures(),
degree=deg)
这里我们将原始图分成两个子图,并对每个子图进行规范化特征转换。注意,在使用ClusterData方法之前,需要将原始图转换为无向图。
另一个常用的Sampler方法是在随机游动时对邻居进行采样,这种方法被称为随机游走采样(Random Walk Sampling)。以下是随机游走采样的示例代码:
from torch_geometric.utils import random_walk
# Perform random walk sampling to obtain node neighbor samples.
walk_length = 20 # The length of random walk trail.
num_steps = 4 # The number of nodes to sample from each step.
data.batch = None
data.edge_index = to_undirected(data.edge_index) # Use undirected edge for random walk.
rw_data = random_walk(data.edge_index, walk_length=walk_length, num_steps=num_steps)
这里我们将使用一个长度为20、每个步骤采样4个邻居的随机游走方法。注意,在使用随机游走方法进行采样之前,需要使用无向边。
GraphSAGE模型定义
GraphSAGE模型包含3个部分:1)图卷积层;2)聚合器(Aggregator);3)输出层。我们将在本节中介绍如何使用Pytorch实现这些组件。
首先,让我们定义一个图卷积层。图卷积层的输入是节点特征矩阵、邻接矩阵和聚合器,输出是新的节点特征矩阵。以下是图卷积层的代码实现:
import torch.nn.functional as F
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.nn import global_mean_pool
class GraphSageConv(MessagePassing):
def __init__(self, in_channels, out_channels, aggr='mean'):
super(GraphSageConv, self).__init__(aggr=aggr)
self.lin = nn.Linear(in_channels, out_channels)
def forward(self, x, edge_index):
return self.propagate(edge_index, x=x)
def message(self, x_j):
return x_j
def update(self, aggr_out, x):
return F.relu(self.lin(torch.cat([x, aggr_out], dim=1)))
这里我们继承了MessagePassing
类,并在__init__
函数中定义了一个全连接层,用于将输入特征矩阵x
从 dind_{in}din 维映射到 doutd_{out}dout 维。在forward
函数中,我们使用propagate
方法来实现消息传递操作;在message
函数中,我们仅向下游节点发送原始特征数据;在update
函数中,我们首先对聚合结果进行ReLU非线性变换,然后再通过全连接层进行节点特征的更新。
接下来,让我们定义一个聚合器。聚合器的输入是采样得到的邻居特征矩阵,输出是新的节点嵌入向量。以下是聚合器的代码实现:
class MeanAggregator(nn.Module):
def __init__(self, input_dim, output_dim):
super(MeanAggregator, self).__init__()
self.input_dim = input_dim
self.output_dim = output_dim
self.lin = nn.Linear(input_dim, output_dim)
def forward(self, neigh_mean):
out = F.relu(self.lin(neigh_mean))
return out
这里我们定义了一个简单的均值聚合器,其将邻居特征矩阵中每列的均值作为节点嵌入向量,并使用全连接层进行维度变换。
最后,让我们定义整个GraphSage模型。GraphSage模型包含2个图卷积层和1个输出层。以下是模型的代码实现:
class GraphSAGE(nn.Module):
def __init__(self, in_channels, hidden_channels, out_channels, num_layers=2):
super(GraphSAGE, self).__init__()
self.conv1 = GraphSageConv(in_channels, hidden_channels)
self.aggreg1 = MeanAggregator(hidden_channels, hidden_channels)
self.conv2 = GraphSageConv(hidden_channels, out_channels)
def forward(self, x, edge_index):
x = self.conv1(x, edge_index)
x = global_mean_pool(x, edge_index) # Compute global mean over nodes.
x = self.aggreg1(x)
x = self.conv2(x, edge_index)
return x
这里我们定义了一个包含2层GraphSAGE Conv层的神经网络。在最后一层GraphSAGE Conv层之后,我们使用global_mean_pool
函数来计算节点嵌入的全局平均值。注意,在本示例中,我们仅保留了一个输出节点,因此输出矩阵的大小为1。如果需要输出多个节点,则需要设置global_mean_pool
函数中的参数。
模型训练与测试
在定义好模型后,我们可以使用Pytorch进行模型训练和测试。首先,让我们定义一个损失函数和优化器:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
这里我们使用交叉熵作为损失函数,并使用Adam优化器来更新模型参数。
接着,我们可以开始训练模型。以下是训练过程的代码实现:
num_epochs = 100
for epoch in range(num_epochs):
model.train()
optimizer.zero_grad()
out = model(data.x, data.edge_index)
loss = criterion(out[data.train_mask], data.y[data.train_mask])
loss.backward()
optimizer.step()
print('Epoch {:03d}, Loss: {:.4f}'.format(epoch, loss.item()))
这里我们遍历所有数据样本,计算预测结果和真实标签之间的交叉熵损失,并使用反向传播来更新权重。我们在每个epoch结束后打印出当前损失值。
最后,我们可以对模型进行测试。以下是测试过程的代码实现:
model.eval()
with torch.no_grad():
pred = model(data.x, data.edge_index)
pred = pred.argmax(dim=1)
acc = (pred[data.test_mask] == data.y[data.test_mask]).sum().item() / data.test_mask.sum().item()
print('Test accuracy: {:.4f}'.format(acc))
这里我们使用测试集来计算模型的准确率。注意,在执行model.eval()
后,我们需要使用torch.no_grad()
包装代码块,以禁止梯度计算。
来源:https://juejin.cn/post/7225053568152518714


猜你喜欢
- 为了得到更加清晰的图像我们需要通过技术对图像进行处理,比如使用对比度增强的方法来处理图像,对比度增强就是对图像输出的灰度级放大到指定的程度,
- 利用 vue+canvas 实现拼图小游戏,供大家参考,具体内容如下思路步骤一个拼图拼盘和一个原图参照对原图的切割以及随机排序通过W/A/D
- python下os模块强大的重命名方法renames详解 在python中有很多强大的模块,其中我们经常要使用的就是OS模块,OS
- 现在很多以内容为核心的网站上都在文章底部添加了社会化分享按钮,能让浏览用户在发现一篇有价值的文章时,可以通过社会化网络快速分享给自己的好友,
- 描述exp() 方法返回x的指数,ex。语法以下是 exp() 方法的语法:import mathmath.exp( x )注意:exp()
- 动态生成的IFRAME,设置SRC时的,不同位置带来的影响。以下所说的是在IE7下运行的。IE6下也是同样。在这个blog中,直接点击运行代
- Prometheus是什么Prometheus是一套开源监控系统和告警为一体,由go语言(golang)开发,是监控+报警+时间序列数据库的
- 一般说到组件,我首先想到的是弹窗,其他就大脑空白了。因为觉得这个是在项目中最常用的功能,提取出来方便复用的才是组件~然而我才发现这个想法是有
- 我就废话不多说了,还是直接上代码吧! url = "http://%s:%s/api-token-auth/" % (i
- 排序这个词,我的第一感觉是几乎所有App都有排序的地方,淘宝商品有按照购买时间的排序、B站的评论有按照热度排序的...对于MySQL,一说到
- 问题:在Jupyter Notebook中使用args传递参数时出现错误:原始代码:args = parser.parse_args()us
- 前言PC Server发展到今天,在性能方面有着长足的进步。64位的CPU在数年前都已经进入到寻常的家用PC之中,更别说是更高端的PC Se
- 方法一:def dict_to_numpy_method1(dict): dict_sorted=sorted(dict.iteritems
- 用最新版本(2.1.0)的pyshp解析shp文件的records时:records = sf.records()如果records里面含有
- 前言使用的pyecharts是v1.0这里需要注意,pyecharts0.5的版本和v1.0以上的版本完全不一样,可以说是两个包该包能够方便
- 前言在ORACLE数据库应用调优中,一个SQL的执行次数/频率也是常常需要关注的,因为某个SQL执行太频繁,要么是由于应用设计有缺陷,需要在
- 一、相同点dump 和 dumps 都实现了序列化load 和 loads 都实现反序列化变量从内存中变成可存储或传输的过程称之为序列化序列
- 本文实例讲述了Python基于hashlib模块的文件MD5一致性加密验证。分享给大家供大家参考,具体如下:使用hashlib模块,可对文件
- 平时工作过程中,git在push代码的时候有时会遇到如下的错误错误原因文件冲突,本地的代码和远程Repository中的文件个数不一致(即远
- 环境系统: Mac 工具: Alfred, git, homebrew, pngpaste. 语言: perl 其他: Gitee工具下载g