图神经网络GNN算法基本原理详解
作者:Cyril_KI 发布时间:2023-08-08 23:53:53
前言
本文结合一个具体的无向图来对最简单的一种GNN进行推导。本文第一部分是数据介绍,第二部分为推导过程中需要用的变量的定义,第三部分是GNN的具体推导过程,最后一部分为自己对GNN的一些看法与总结。
1. 数据
利用networkx简单生成一个无向图:
# -*- coding: utf-8 -*-
"""
@Time : 2021/12/21 11:23
@Author :KI
@File :gnn_basic.py
@Motto:Hungry And Humble
"""
import networkx as nx
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
G = nx.Graph()
node_features = [[2, 3], [4, 7], [3, 7], [4, 5], [5, 5]]
edges = [(1, 2), (1, 3), (2, 4), (2, 5), (1, 3), (3, 5), (3, 4)]
edge_features = [[1, 3], [4, 1], [1, 5], [5, 3], [5, 6], [5, 4], [4, 3]]
colors = []
edge_colors = []
# add nodes
for i in range(1, len(node_features) + 1):
G.add_node(i, feature=str(i) + ':(' + str(node_features[i-1][0]) + ',' + str(node_features[i-1][1]) + ')')
colors.append('#DCBB8A')
# add edges
for i in range(1, len(edge_features) + 1):
G.add_edge(edges[i-1][0], edges[i-1][1], feature='(' + str(edge_features[i-1][0]) + ',' + str(edge_features[i-1][1]) + ')')
edge_colors.append('#3CA9C4')
# draw
fig, ax = plt.subplots()
pos = nx.spring_layout(G)
nx.draw(G, pos=pos, node_size=2000, node_color=colors, edge_color='black')
node_labels = nx.get_node_attributes(G, 'feature')
nx.draw_networkx_labels(G, pos=pos, labels=node_labels, node_size=2000, node_color=colors, font_color='r', font_size=14)
edge_labels = nx.get_edge_attributes(G, 'feature')
nx.draw_networkx_edge_labels(G, pos, edge_labels=edge_labels, font_size=14, font_color='#7E8877')
ax.set_facecolor('deepskyblue')
ax.axis('off')
fig.set_facecolor('deepskyblue')
plt.show()
如下所示:
其中,每一个节点都有自己的一些特征,比如在社交网络中,每个节点(用户)有性别以及年龄等特征。
5个节点的特征向量依次为:
[[2, 3], [4, 7], [3, 7], [4, 5], [5, 5]]
同样,6条边的特征向量为:
[[1, 3], [4, 1], [1, 5], [5, 3], [5, 6], [5, 4], [4, 3]]
2. 变量定义
特征向量实际上也就是节点或者边的标签,这个是图本身的属性,一直保持不变。
3. GNN算法
GNN算法的完整描述如下:Forward向前计算状态,Backward向后计算梯度,主函数通过向前和向后迭代调用来最小化损失。
主函数中:
上述描述只是一个总体的概述,可以略过先不看。
3.1 Forward
早期的GNN都是RecGNN,即循环GNN。这种类型的GNN基于信息传播机制: GNN通过不断交换邻域信息来更新节点状态,直到达到稳定均衡。节点的状态向量 x 由以下 f w 函数来进行周期性更新:
解析上述公式:对于节点 n ,假设为节点1,更新其状态需要以下数据参与:
这里的fw只是形式化的定义,不同的GNN有不同的定义,如随机稳态嵌入(SSE)中定义如下:
由更新公式可知,当所有节点的状态都趋于稳定状态时,此时所有节点的状态向量都包含了其邻居节点和相连边的信息。
这与图嵌入有些类似:如果是节点嵌入,我们最终得到的是一个节点的向量表示,而这些向量是根据随机游走序列得到的,随机游走序列中又包括了节点的邻居信息, 因此节点的向量表示中包含了连接信息。
证明上述更新过程能够收敛需要用到不动点理论,这里简单描述下:
如果我们有以下更新公式:
GNN的Foward描述如下:
解释:
3.2 Backward
在节点嵌入中,我们最终得到了每个节点的表征向量,此时我们就能利用这些向量来进行聚类、节点分类、链接预测等等。
GNN中类似,得到这些节点状态向量的最终形式不是我们的目的,我们的目的是利用这些节点状态向量来做一些实际的应用,比如节点标签预测。
因此,如果想要预测的话,我们就需要一个输出函数来对节点状态进行变换,得到我们要想要的东西:
最容易想到的就是将节点状态向量经过一个前馈神经网络得到输出,也就是说 g w g_w gw可以是一个FNN,同样的, f w f_w fw也可以是一个FNN:
我们利用 g w g_w gw函数对节点 n n n收敛后的状态向量 x n x_n xn以及其特征向量 l n l_n ln进行变换,就能得到我们想要的输出,比如某一类别,某一具体的数值等等。
在BP算法中,我们有了输出后,就能算出损失,然后利用损失反向传播算出梯度,最后再利用梯度下降法对神经网络的参数进行更新。
对于某一节点的损失(比如回归)我们可以简单定义如下:
有了z(t)后,我们就能求导了:
z(t)的求解方法在Backward中有描述:
因此,在Backward中需要计算以下导数:
4.总结与展望
本文所讲的GNN是最原始的GNN,此时的GNN存在着不少的问题,比如对不动点隐藏状态的更新比较低效。
由于CNN在CV领域的成功,许多重新定义图形数据卷积概念的方法被提了出来,图卷积神经网络ConvGNN也被提了出来,ConvGNN被分为两大类:频域方法(spectral-based method )和空间域方法(spatial-based method)。2009年,Micheli在继承了来自RecGNN的消息传递思想的同时,在架构上复合非递归层,首次解决了图的相互依赖问题。在过去的几年里还开发了许多替代GNN,包括GAE和STGNN。这些学习框架可以建立在RecGNN、ConvGNN或其他用于图形建模的神经架构上。
GNN是用于图数据的深度学习架构,它将端到端学习与归纳推理相结合,业界普遍认为其有望解决深度学习无法处理的因果推理、可解释性等一系列瓶颈问题,是未来3到5年的重点方向。
因此,不仅仅是GNN,图领域的相关研究都是比较有前景的,这方面的应用也十分广泛,比如推荐系统、计算机视觉、物理/化学(生命科学)、药物发现等等。
来源:https://blog.csdn.net/Cyril_KI/article/details/122058881
猜你喜欢
- 如何制作K线图?也不难,代码和说明见下:<%@ Language=VBScript %><%Respo
- “表情包”是现在非常流行的交流方式,通过一张图片就能把文字不能表达或不便于表达的情感给表示出来,表情包一经诞生,就统治了中国人的社交圈,尤其
- 一、join函数(一)参数使用说明描述Python join() 方法用于将序列中的元素以指定的字符连接生成一个新的字符串。语法join()
- 前言在Python中,import操作应该算是最为频繁和常见的,但同时也应该是最核心需要搞清楚其工作原理的地方,比如,python是如何找到
- 在Python中,对列表进行排序有两种方法。一种是调用 sort() 方法,该方法没有返回值,对列表本身进行升序排序。c
- 前段时间,接到一个需求,要求下载某一个网站的视频,然后自己从网上查阅了相关的资料,在这里做一个总结。1. m3u8文件m3u8是苹果公司推出
- 在web2.0的站中用户互动性是很强的,例如用户留言我们可能放开img标签,允许用户外链其他站点的图片,那么我们就需要解决图片尺寸过大所带来
- 英文的文档在这里,详细全面,本文仅为自己的学习笔记,只是试图通过转述加深自己的学习,不详细不全面。由于浏览器之间的差异,所以在JS中监听事件
- 1.安装插件,在非虚拟环境conda install nb_condaconda install ipykernel2、安装ipykerne
- 排序算法是《数据结构与算法》中最基本的算法之一,也是面试必背题,为方便技术交流,文末创建技术交流群。排序算法可以分为内部排序和外部排序,内部
- 1. 前言python除了丰富的第三方库外,本身也提供了一些内在的方法和底层的一些属性,大家比较常用的如dict、list、set、min、
- 在使用Django过程中需要开发一些API给其他系统使用,为了安全把Token等验证信息放在header头中。如何获取:使用request.
- 下面这个例子描述的是在Godaddy-Linux托管帐户上使用JSP连接到某个MySQL数据库。 <%@ page
- 如何用ASP发送HTML格式的邮件?HTML格式的邮件可以把网页上的所有元素,包括文字和图片集成保存在一个文件中,阅读和链接非常便捷,请问在
- 最近在看python脚本语言,脚本语言是一种解释性的语言,不需要编译,可以直接用,由解释器来负责解释。python语言很强大,而且写起来很简
- 问题:在安装SP4补丁的时候,老是报验证密码错误。上网查了一下资料,发现是一个小bug。按照一下操作,安装正常。SQL Server补丁安装
- 前言在golang当中,defer代码块会在函数调用链表中增加一个函数调用。这个函数调用不是普通的函数调用,而是会在函数正常返回,也就是re
- islower()方法判断检查字符串的所有的字符(字母)是否为小写。语法以下是islower()方法的语法:str.islowe
- 不知道工商银行帐号是否是这样的格式, 如果错了请大家见谅!<script language="javascript"
- 阅读上一章:Chapter 13 为文字指定样式Chapter 14 图片替换随着更多设计师与开发者开始使用标准(特别是CSS),每天都会有