python机器学习GCN图卷积神经网络原理解析
作者:Mr.琛 发布时间:2022-09-29 06:36:14
1. 图信号处理知识
图卷积神经网络涉及到图信号处理的相关知识,也是由图信号处理领域的知识推导发展而来,了解图信号处理的知识是理解图卷积神经网络的基础。
1.1 图的拉普拉斯矩阵
拉普拉斯矩阵是体现图结构关联的一种重要矩阵,是图卷积神经网络的一个重要部分。
1.1.1 拉普拉斯矩阵的定义及示例
实例:
按照上述计算式子,可以得到拉普拉斯矩阵为:
1.1.2 正则化拉普拉斯矩阵
1.1.3 拉普拉斯矩阵的性质
1.2 图上的傅里叶变换
傅里叶变换是一种分析信号的方法,它可分析信号的成分,也可用这些成分合成信号。它将信号从时域转换到频域,从频域视角给出了信号处理的另一种解法。(1)对于图结构,可以定义图上的傅里叶变换(GFT),对于任意一个在图G上的信号x,其傅里叶变换表示为:
从线代角度,可以清晰的看出:v1,…, vn构成了N维特征空间中的一组完备基向量,G中任意一个图信号都可表示为这些基向量的线性加权求和,系数为图信号对应傅里叶基上的傅里叶系数。
回到之前提到的拉普拉斯矩阵刻画平滑度的总变差:
可以看成:刻画图平滑度的总变差是图中所有节点特征值的线性组合,权值为傅里叶系数的平方。总变差取最小值的条件是图信号与最小的特征值所对应的特征向量完全重合,结合其描述图信号整体平滑度的意义,可将特征值等价成频率:特征值越低,频率越低,对应的傅里叶基变化缓慢,即相近节点的信号值趋于一致。
把图信号所有的傅里叶系数结合称为频谱(spectrum),频域的视角从全局视角既考虑信号本身,也考虑到图的结构性质。
1.3 图信号滤波器
图滤波器(Graph Filter)为对图中的频率分量进行增强或衰减,图滤波算子核心为其频率响应矩阵,为滤波器带来不同的滤波效果。
故图滤波器根据滤波效果可分为低通,高通和带通。
低通滤波器:保留低频部分,关注信号的平滑部分;
高通滤波器:保留高频部分,关注信号的剧烈变化部分;
带通滤波器:保留特定频段部分;
而拉普拉斯矩阵多项式扩展可形成图滤波器H:
2. 图卷积神经网络
2.1 数学定义
图卷积运算的数学定义为:
上述公式存在一个较大问题:学习参数为N,这涉及到整个图的所有节点,对于大规模数据极易发生过拟合。
进一步的化简推导:将之前说到的拉普拉斯矩阵的多项式展开代替上述可训练参数矩阵。
此结构内容即定义为图卷积层(GCN layer),有图卷积层堆叠得到的网络模型即为图卷积网络GCN。
2.2 GCN的理解及时间复杂度
图卷积层是对频率响应矩阵的极大化简,将本要训练的图滤波器直接退化为重归一化拉普拉斯矩阵
2.3 GCN的优缺点
优点:GCN作为近年图神经网络的基础之作,对处理图数据非常有效,其对图结构的结构信息和节点的属性信息同时学习,共同得到最终的节点特征表示,考虑到了节点之间的结构关联性,这在图操作中是非常重要的。
缺点:过平滑问题(多层叠加之后,节点的表示向量趋向一致,节点难以区分),由于GCN具有一个低通滤波器的作用(j聚合特征时使得节点特征不断融合),多次迭代后特征会趋于相同。
3. Pytorch代码解析
GCN层的pytorch实现:
class GraphConvolutionLayer(nn.Module):
'''
图卷积层:Lsym*X*W
其中 Lsym表示正则化图拉普拉斯矩阵, X为输入特征, W为权重矩阵, X'表示输出特征;
*表示矩阵乘法
'''
def __init__(self, input_dim, output_dim, use_bias=True):
#初始化, parameters: input_dim-->输入维度, output_dim-->输出维度, use_bias-->是否使用偏置项, boolean
super(GraphConvolutionLayer,self).__init__()
self.input_dim=input_dim
self.output_dim=output_dim
self.use_bias=use_bias #是否加入偏置, 默认为True
self.weight=nn.Parameter(torch.Tensor(input_dim, output_dim))#权重矩阵为可训练参数
if self.use_bias==True: #加入偏置
self.bias=nn.Parameter(torch.Tensor(output_dim))
else: #设置偏置为空
self.register_parameter('bias', None)
self.reset_parameters()
def reset_parameters(self):
#初始化参数
stdv = 1. / math.sqrt(self.weight.size(1))
self.weight.data.uniform_(-stdv, stdv)#使用均匀分布U(-stdv,stdv)初始化权重Tensor
if self.bias is not None:
self.bias.data.uniform_(-stdv, stdv)
def forward(self, adj, input_feature):
#前向传播, parameters: adj-->邻接矩阵(输入为正则化拉普拉斯矩阵), input_future-->输入特征矩阵
temp=torch.mm(input_feature, self.weight)#矩阵乘法, 得到X*W
output_feature=torch.sparse.mm(adj, temp)#由于邻接矩阵adj为稀疏矩阵, 采用稀疏矩阵乘法提高计算效率, 得到Lsym*temp=Lsym*X*W
if self.use_bias==True: #若设置了偏置, 加入偏置项
output_feature+=self.bias
return output_feature
定义两层的GCN网络模型:
class GCN(nn.Module):
'''
定义两层GCN网络模型
'''
def __init__(self, input_dim, hidden_dim, output_dim):
#初始化, parameters: input_dim-->输入维度, hidden_dim-->隐藏层维度, output_dim-->输出维度
super.__init__(GCN, self).__init__()
#定义两层图卷积层
self.gcn1=GraphConvolutionLayer(input_dim, hidden_dim)
self.gcn2=GraphConvolutionLayer(hidden_dim, output_dim)
def forward(self, adj, feature):
#前向传播, parameters: adj-->邻接矩阵, feature-->输入特征
x=F.relu(self.gcn1(adj, feature))
x=self.gcn2(adj, x)
return F.log_softmax(x, dim=1)
来源:https://blog.csdn.net/weixin_44756457/article/details/107855072
猜你喜欢
- 2017年底,Tensorflow 推出Lite版本,可实现移动端的快速运行,其中,一个很关键的问题,如何把现有分类模型(.pb) 转换为(
- 静态页面由于其稳定性快速性,的确给SE、用户及站长带来了方便。但有时,需要记住用户的信息,如用户留下评论后,下一次再来,就要记住该用户的信息
- 多线程-共享全局变量#coding=utf-8from threading import Threadimport timeg_num =
- 一、前言在Python中,除了可以自定义模块外,还可以引用其他模块,主要包括使用标准库和第三方模块。下面分别进行介绍。二、导入和使用标准模块
- 1.SQL Server 2005中的存储过程并发问题问:我在SQL Server2005中遇到了并发问题。我持有车票的公共汽车上有一些空闲
- 官方文档:https://elasticsearch-py.readthedocs.io/en/master/1、介绍python提供了操作
- 最常用的遍历方式为for语句(也有递归、while方式)。当我们遍历一个数组的时候,我们一般会这么做:var arr = [1,2,3,4,
- django-mdeditorGithub地址:https://github.com/pylixm/django-mdeditor 欢迎试用
- 字符串打印打印函数echo: 打印值,用于单值print_r(): 人类可读方式打印,用于数组var_dump():打印结构和类型,一般用于
- 功能: 1、 允许/限制对表的修改 2、 自动生成派生列,比如自增字段 3、 强制数据一致性 4、 提供审计和日志记录 5、 防止无效的事务
- 1. 模块(Module)在计算机程序的开发过程中,随着程序代码越写越多,在一个文件里代码就会越来越长,越来越不容易维护。为了编写可维护的代
- pygame实现代码雨动画如视频所示 利用pygame库实现了一个代码呈雨状下落的视觉效果部分代码如下import sysimport ra
- 写在前面这次的爬虫是关于房价信息的抓取,目的在于练习10万以上的数据处理及整站式抓取。数据量的提升最直观的感觉便是对函数逻辑要求的提高,针对
- 购物车程序要求如下图代码# --*--coding:utf-8--*--# Author: 村雨import pprintproductLi
- JSONJSON 起源JSON 全称 JavaScript Object Notation 。是处理对象文字语法的 JavaScript 编
- 程序员的时间很宝贵,Python这门语言虽然足够简单、优雅,但并不是说你使用Python编程,效率就一定会高。要想节省时间、提高效率,还是需
- 场景针对园区停车信息,需要对各个公司提供的停车数据进行整合并录入自家公司的大数据平台数据的录入无外乎就是对数据的增删改查下面上一个常规的写法
- 第一招、mysql服务的启动和停止net stop mysqlnet start mysql第二招、登陆mysql语法如下: mysql -
- 使用 NetBox 可以方便的将 asp 应用编译成为独立运行的执行程序,完全摆脱 iis 的束缚,在几乎所有的 Windows 版本上面直
- 前言在上一篇文中,我们介绍了关于Python正则表达式的基础,那么在这一篇文章里,我们将总结一下正则表达式关于捕获的用法。下面话不多说,来看