论文笔记:详解GraphSAGE


对节点嵌入不明白的可以先看这篇: 论文笔记:DeepWalk与Node2vec

背景

  还是之前笔记里提到过的直推式(Transductive)学习与归纳(Inductive)学习:

Inductive learning,翻译成中文可以叫做 “归纳式学习”,就是从已有数据中归纳出模式来,应用于新的数据和任务。在图学习的训练过程中,看不到没有标注的节点,训练时只需要图的局部,不必一次性输入整张图,由于这个特性,归纳式学习是可以迁移的。即,在这个图上训练好的模型,可以迁移到另外一个图中使用。
Transductive learning,翻译成中文可以叫做 “直推式学习”,指的是由当前学习的知识直接推广到给定的数据上。就是训练期间可以看到没有标注的节点(训练需要整个图里面所有的节点参与),那么换一个图就需要重新训练。

  在GraphSAGE之前,生成图嵌入主流的方法都是Transductive learning的,无法处理模型没见过的节点,代表就是GCN,基于图的拉普拉斯矩阵特征值分解,虽有有方法可以将它修改为Inductive learning[1],但是也需要昂贵的计算代价。
  GraphSAGE提出了一个用于Inductive learning的归纳式生成节点嵌入的框架,利用节点特征(例如文本属性、节点概要信息、节点度)来学习一个嵌入函数,该函数可以推广到没见过的节点。通过在学习算法中加入节点特征,可以同时学习每个节点邻域的拓扑结构以及节点特征在邻域中的分布。GraphSAGE本身是基于特征的方法,但同时也可以学习到结构信息。可以通过设计loss(让邻居节点之间编码相似)来无监督训练图。
  与GCN不同,GraphSAGE没有为每个节点训练一个不同的嵌入向量,而是训练了一组聚合器函数,这些函数学习从节点的局部邻域聚合特征信息。每个聚合器函数从远离给定节点的不同跳数或搜索深度聚合信息。在推理时,通过应用学习到的聚合函数为完全不可见的节点生成嵌入。请添加图片描述
  图中是一个生成节点嵌入的例子,GraphSAGE首先对一跳和二跳邻居进行采样,然后用训练好的不同的聚合函数对信息进行聚合,生成红色节点的嵌入表示。

相关工作

1、基于随机游走和矩阵分解学习节点的embedding,代表有GCN、Node2vec、DeepWalk,但是这些方法大都是Transductive learning的。
2、Yang等人[2]引入的Planetoid-I算法,这是一种基于归纳嵌入的半监督学习方法,然而,Planetoid-I在推理过程中没有使用任何图形结构信息;相反,它在训练期间使用图结构作为一种正则化形式。

模型推导

前向传播

  首先介绍节点嵌入生成(前向传播)算法,假设模型已经经过训练并且参数是固定的,我们学习到了K个聚合器的参数,记为 A G G R E G A T E k   ∀ k ∈ { 1 , . . . , K } AGGREGATE_k\ \forall k\in \{1,...,K\} AGGREGATEk k{ 1,...,K},这些函数从邻接节点中聚合信息,以及一组参数矩阵 W k W^k Wk    ∀ k ∈ { 1 , . . . , K } \ \ \forall k\in \{1,...,K\}   k{ 1,...,K},用于在模型不同的layers或者不同的搜索深度之间传播信息。算法的伪代码:
请添加图片描述
  算法的输入包括:图 G ( V , E ) G(V,E) G(V,E),节点的特征 { x v   v ∈ V } \{x_v\ v\in V\} { xv vV},我们要进行GraphSAGE的深度K,通过遍历K层的GraphSAGE分别对各层的邻居信息进行聚合,最终生成深度为K的最终表示: z v = h v K , ∀ v ∈ V z_v=h_v^K, \forall v\in V zv=hvK,vV。用于聚合邻居信息的聚合器有很多种设置,会在下面分别介绍。
  在上面的算法中,每一层聚合所有邻居节点的信息,但是这样有很明显的缺陷:进行每次计算的时间复杂度高(最高0(|V|)),并且运行时间是不可预测的。GraphSAGE采用了这样一种设置:每次采样固定长度的邻居聚合信息,那么就要解决两个问题:邻居节点数比固定长度少怎么办?比它多怎么办?
  对于第一种情况,GraphSAGE会取所有邻居,再从所有邻居中随机选择补齐差的节点数量;对于第二种情况,则直接随机选择相应数量的节点。论文提出, K = 2 K=2 K=2 S 1 ⋅ S 2 < = 500 S_1\cdot S_2<=500 S1S2<=500 这样的设置就可以取得一个比较好的效果,其中 S 1 , S 2 S_1,S_2 S1,S2分别是一跳、二跳邻居节点的采样数量。

扩展GraphSAGE算法框架到minibatch

  和GCN使用全图方式方式不同,GraphSAGE采用聚合邻居节点,并且进行了采样。这样在minbatch下,可以不使用全图信息,使得在大规模图上训练变得可行。算法:
请添加图片描述  简单地说,就是从当前求embedding的节点一层一层采样推出计算这个embedding都需要哪些节点,再用这个得到的节点集合进行GraphSAGE。

模型训练

  在完全无监督的条件下,损失函数的思路为相近的节点它们的表示应该相似,强制不同节点的表示高度不同(不知道什么意思),损失函数为:
J G ( z u ) = − l o g ( σ ( z u T z v ) ) − Q ⋅ E v n ∼ P n ( v ) l o g ( σ ( − z u T z u ) ) J_\mathfrak G(z_u)=-log(\sigma(z_u^Tz_v))-Q\cdot \mathbb E_{v_n\sim P_{n(v)}}log(\sigma(-z_u^Tz_u)) JG(zu)=log(σ(zuTzv))QEvnPn(v)log(σ(zuTzu))  这个式子里 z u z_u zu就是经过GraphSAGE后得到的节点embedding,本着相邻节点表示相似的原则,越相似两个节点embedding的乘积应该越大,然后通过这个损失函数训练反向传播优化GraphSAGE的参数矩阵 W W W,这里就是GraphSAGE与其他方法不同的地方,它不直接优化出embedding,而是优化 W W W,优化完成后对一个没见过的节点就可以快速求出它的embedding。
  如果是有监督的任务,可以简单直接地把这个损失函数换成交叉熵等。

聚合器的设置

  聚合器的作用是把周围的邻居节点的特征向量聚合成一个向量,由于图的性质,我们的聚合器需要满足排列不变性,即生成的向量周围节点的顺序是无关的,因为图节点本身没有顺序。文章介绍了三种可用的聚合器:Mean Aggregator、LSTM Aggregator、Pooling Aggregator。

  1. Mean Aggregator 就是简单的把k-1层的邻居节点的表示求均值,如果要使用Mean Aggregator,就把算法中的4、5行改为 h v k ← σ ( W ⋅ M E A N ( { h v k − 1 } ∪ { h u k − 1 , ∀ u ∈ N ( v ) ) ) h_v^k\leftarrow\sigma(W\cdot MEAN(\{h_v^{k-1}\}\cup\{h_u^{k-1},\forall u\in N(v))) hvkσ(WMEAN({ hvk1}{ huk1,uN(v)))使用Mean Aggregator就不需要一个拼接的过程,比较简单,与GCN几乎等价。
  2. LSTM Aggregator LSTM本身是有序的,通过随机打乱节点顺序的方式就可以让它应用于无序的集合,LSTM模型本身容量更大,相比于Mean Aggregator,拥有更强的表达能力。
  3. Pooling Aggregator 这种聚合器对每一个节点的向量都经过一层全连接神经网络进行学习,转换后对每个位置上取最大值(最大池化),它的公式:
    A G G R E G A T E k p o o l = m a x ( { W p o o l h u i k + b ) , ∀ u i ∈ N ( v ) } ) AGGREGATE_k^{pool}=max(\{W_{pool}h_{u_i}^k+b),\forall u_i\in N(v)\}) AGGREGATEkpool=max({ Wpoolhuik+b),uiN(v)})需要注意的是,原则上,任何对称向量函数都可以用来代替max运算符(例如平均值)。在实验中,作者发现max-pooling和mean-pooling之间没有显著的差异,因此在剩下的实验中使用max-pooling。

实验

  使用的三个数据集:

  1. 使用科学网络引用数据集将学术文章分类成不同的主题(Citation)
  2. 将Reddit推文分类成不同的社区(Reddit)
  3. 将蛋白质功能分类成不同的生物蛋白质蛋白质交互图(PPI)

  在预测阶段对训练时不可见的结点进行预测,并且在PPI数据集上,在完全不可见的图上进行测试。
  对比的baseline:

  1. 随机分类器(a random classifer)
  2. 基于特征的逻辑回归,忽略图结构(a logistic regression feature-based classifier)
  3. DeepWalk
  4. 加入特征的DeepWalk
  5. GraphSAGE的四个变体(四种聚合器)

  本文设计实验的目标是:(i)验证GraphSAGE对于baseline(即原始特征和DeepWalk)上的改进;(ii)提供不同GraphSAGE聚合器架构的严格比较。为了提供公平的比较,所有模型共享其小批量迭代器、损失函数和邻域采样器(如果适用)的相同实现。实验结果:
请添加图片描述
  本文进行了Reddit数据上的定时实验,训练批次大小为512,在完整测试集(79,534个节点)上进行推理,这些方法的训练时间是相当的(GraphSAGE-LSTM是最慢的)。然而,由于DeepWalk需要对新的随机游走进行采样,并运行新一轮的SGD来嵌入看不见的节点,这使得DeepWalk在测试时的速度慢了100-500倍。
请添加图片描述  对于GraphSAGE变体,文章发现与K = 1相比,设置K = 2可以提供平均约10-15%的准确性提升;然而,将K增加到2以上会在性能上获得边际回报(0-5%),而运行时间则会增加10-100倍。
  本文还进行了模型对采样邻域大小的性能比较,其中“邻域样本量”是指当K = 2且S1 = S2时,在每个深度采样的邻域数量(引用数据使用GraphSAGE-mean),图B中可以看出采样数量增大时的收益递减,因此,尽管子采样邻域引起了较高的方差,但GraphSAGE仍然能够保持较强的预测精度,同时显著提高了效率。
请添加图片描述
  在本文使用的数据集中,Citation and Reddit data是不断发展的信息图,会不断出现新的unseen的数据,本文的实验表明,GraphSAGE的性能明显优于所有baseline,相比于GCN,可训练的神经网络聚合器提供了显著的收益,有趣的是,基于LSTM的聚合器表现出很强的性能,尽管它是为顺序数据而不是无序数据集设计的。最后,我们看到无监督GraphSAGE的性能与完全监督版本的性能相当,这表明我们的框架可以在没有特定于任务的微调的情况下实现强大的性能。
  ppi数据集上执行的是图泛化任务,这需要学习节点角色而不是社区结构。因为 PPI 数据集中的节点表示蛋白质分子,这些蛋白质分子在不同的生物学上下文中可能会扮演不同的角色,例如参与不同的生物过程或反应。在实验中,我们再次看到GraphSAGE显著优于所有的baseline,基于LSTM和池的聚合器比基于均值和gcn的聚合器效果更好,但是LSTM的时间成本更大。

对GraphSAGE表达能力的理论分析(讨论其如何学习图结构)

  GraphSAGE本质上是基于特征的,但是文章认为它仍能学习图结构,作为案例研究,考虑GraphSAGE是否能够学习预测节点的聚类系数,参考图论中的聚类系数(Clustering coefficient),聚类系数是衡量节点局部邻域聚类程度的常用方法,它是许多更复杂的结构基元的构建块。文章证明GraphSAGE能够将聚类系数近似到任意精度:
请添加图片描述
  定理1指出,对于任何图,存在一个算法1的参数设置,如果每个节点的特征是不同的(并且足够高维),那么它可以将该图中的聚类系数近似为任意精度。作为定理1的一个推论,GraphSAGE可以学习局部图结构,即使节点特征输入是从绝对连续的随机分布中采样。证明背后的基本思想是,如果每个节点都有唯一的特征表示,那么我们可以学习将节点映射到指示向量并识别节点邻域。定理1的证明依赖于池聚合器的一些属性,这也提供了graphsag -pool优于GCN和基于均值的聚合器的原因。

参考论文

[1]B. Perozzi, R. Al-Rfou, and S. Skiena. Deepwalk: Online learning of social representations. In
KDD, 2014.
[2]Z. Yang, W. Cohen, and R. Salakhutdinov. Revisiting semi-supervised learning with graph
embeddings. In ICML, 2016.

论文地址:GraphSAGE

欢迎点赞 关注 留言私信交流 📝 如有错误敬请指正!

相关推荐

  1. <span style='color:red;'>GraphSAGE</span>

    GraphSAGE

    2023-12-15 11:00:03      35 阅读
  2. 论文阅读笔记】清单

    2023-12-15 11:00:03       51 阅读

最近更新

  1. TCP协议是安全的吗?

    2023-12-15 11:00:03       16 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2023-12-15 11:00:03       16 阅读
  3. 【Python教程】压缩PDF文件大小

    2023-12-15 11:00:03       15 阅读
  4. 通过文章id递归查询所有评论(xml)

    2023-12-15 11:00:03       18 阅读

热门阅读

  1. gitk查看被删除的单个文件的所有历史记录

    2023-12-15 11:00:03       44 阅读
  2. Vue学习笔记-Vue3中的shallowReactive和shallowRef

    2023-12-15 11:00:03       44 阅读
  3. 英语六级作文好句

    2023-12-15 11:00:03       30 阅读
  4. 【antd】 Pagination.onChange获取不到pageSize值的原因

    2023-12-15 11:00:03       37 阅读
  5. Windows下ping IP+端口的方法

    2023-12-15 11:00:03       46 阅读
  6. 飞天使-docker知识点6-容器dockerfile各项名词解释

    2023-12-15 11:00:03       31 阅读
  7. 力扣labuladong——一刷day74

    2023-12-15 11:00:03       36 阅读
  8. filecmp --- 文件及目录的比较

    2023-12-15 11:00:03       36 阅读
  9. mysql binlog_ignore_db参数的效果详解

    2023-12-15 11:00:03       33 阅读
  10. 9月7日算法学习笔记(栈)

    2023-12-15 11:00:03       33 阅读
  11. 力扣面试150题 |有效的括号

    2023-12-15 11:00:03       49 阅读