【机器学习】基于图注意力网络(GAT)的Cora数据集论文主题预测

1. 引言

1.1. GAT概述

GAT是一种基于图神经网络的深度学习模型,专门用于处理图结构数据。与传统的神经网络不同,GAT能够直接对图结构数据进行学习和推理,通过捕捉和传递节点之间的关系和特征信息,实现对图结构数据的深度分析和挖掘。

1.1.1.核心原理
  1. 注意力机制:GAT的核心是引入了注意力机制,使得模型能够更好地捕捉节点之间的关系和特征信息。通过对邻居节点的特征信息进行加权求和,GAT为同一邻域的节点分配不同的权重,从而扩充了模型的尺度。
  2. 多头注意力:GAT使用多头注意力机制,即采用多个独立的注意力机制来计算节点的特征向量,然后将它们拼接起来作为最终的输出。这种机制有助于模型捕获不同方面的信息,并提高了模型的稳定性和表达能力。
  3. 节点特征变换:在GAT中,每个节点的特征向量首先经过线性变换,然后通过注意力机制计算得到邻居节点的加权特征向量,最后通过非线性激活函数得到节点的新特征向量。
1.1.2.GAT优点
  1. 灵活性:GAT能够处理具有复杂关系的节点数据,并在多个领域取得了显著的效果,如社交网络分析、推荐系统等。
  2. 高效性:由于采用了注意力机制,GAT能够专注于与当前节点最相关的部分,从而提高了计算效率。
  3. 可扩展性:GAT可以处理大规模的图数据,并支持并行计算,使得其在实际应用中具有更强的可扩展性。
1.1.3.GAT缺点
  1. 过平滑严重:GAT在处理高阶特征时可能会出现过平滑现象,即节点的特征向量逐渐变得相似,导致模型性能下降。
  2. 感受野受限:GAT的感受野大小受模型深度影响,较深的模型可能无法捕捉到全局的信息。
  3. 冗余计算:由于领域节点的高度重叠,GAT在计算过程中可能会出现冗余计算,增加了计算成本。
1.1.4.应用领域

GAT已被广泛应用于各种图分析任务中,如节点分类、链接预测、图嵌入等。在社交网络分析中,GAT可以用于识别相似的用户或群组;在推荐系统中,GAT可以基于用户-物品图预测用户的兴趣偏好;在生物信息学中,GAT可以帮助发现蛋白质、基因等生物实体之间的潜在关系。

GAT作为一种基于注意力机制的图神经网络模型,在处理图结构数据方面展现出了强大的能力和广泛的应用前景。

1.2. Cora数据集

Cora数据集是一个广泛用于文献学术论文分类的常用数据集,主要用于机器学习和自然语言处理研究。它包含了大量的科学出版物,这些出版物被分为七个不同的类别,并通过引用关系形成了一个复杂的网络结构。

1.2.1.数据结构
  1. 文献网络:Cora数据集构建了一个文献网络,其中每个节点代表一篇论文,而边则表示论文之间的引用关系。这个网络结构允许研究者在分类任务中考虑文献之间的引用关系。
  2. 节点信息:每个节点(论文)包含以下信息:
  • 论文ID:唯一标识每篇论文的ID。
  • 词袋模型表示:论文标题和摘要的词袋模型,用于表示文本信息。
  • 标签信息:每篇论文的类别标签,如人工智能、数据库、数据挖掘等。
1.2.2.数据集规模
  1. 论文数量:Cora数据集包含2708份科学出版物。
  2. 引文网络:由5429个链接组成,表示论文之间的引用关系。
  3. 词汇表:经过处理后,只剩下1433个独特的单词组成词汇表,用于表示论文的内容。
1.2.3.类别划分

论文根据其主题领域被分为七个不同的类别,这些类别通常是预定义的,反映了论文的研究方向。

1.2.4.数据处理

在使用Cora数据集时,通常需要进行一些预处理工作,例如文本的标记化(Tokenization)、词袋模型的构建、图网络的表示等。研究者可以选择将文本信息和引用网络结合起来,以便在模型训练中充分利用这两方面的信息。

1.2.5.应用

Cora数据集广泛用于研究文本分类、图神经网络(Graph Neural Networks, GNNs)等领域。研究者可以利用该数据集开发算法,探索如何更好地利用文本信息和引用网络结构来进行论文分类。

Cora数据集以其丰富的文献信息和复杂的网络结构,为机器学习和自然语言处理领域的研究者提供了一个理想的实验平台。通过利用这个数据集,研究者可以深入探索文本分类和图神经网络等技术的潜力和应用。

1.3.图节点分类概述

图节点分类是图机器学习中的一个重要任务,旨在根据图中节点的特征、节点之间的关系(边)以及图的整体结构,为图中的每个节点分配一个或多个标签。这一过程可以类比于传统机器学习中的分类问题,但不同的是,图节点分类需要考虑节点之间的连接关系。

1.3.1.应用场景

图节点分类在多个领域都有广泛的应用:

  1. 社交网络分析:在社交网络中,节点代表用户,边代表用户之间的关系(如朋友关系、关注关系等)。图节点分类可以用于识别用户的兴趣、职业、政治倾向等。
  2. 生物信息学:在生物网络中,节点代表蛋白质、基因等生物实体,边代表它们之间的相互作用或调控关系。图节点分类可以用于预测蛋白质的功能、识别疾病的致病基因等。
  3. 推荐系统:在推荐系统中,用户和产品可以分别作为图中的节点,用户和产品之间的交互(如购买、评分等)可以表示为边。图节点分类可以用于预测用户的兴趣偏好,从而实现个性化推荐。
1.3.2.方法概述

图节点分类的方法可以分为传统方法和基于深度学习的方法两类。

  1. 传统方法:传统方法通常基于图的统计特性或节点之间的相似性进行分类。例如,可以计算每个节点的邻居节点的标签分布,然后根据这些分布信息为每个节点分配标签。然而,传统方法通常难以处理大规模的图数据,并且难以捕获图中的非线性关系。
  2. 基于深度学习的方法:近年来,随着深度学习技术的发展,基于深度学习的图节点分类方法逐渐兴起。这些方法通常使用图神经网络(Graph Neural Networks, GNNs)来捕获图中的节点特征、边信息以及图的整体结构。其中,Node2Vec、Graph Convolutional Network(GCN)和Graph Attention Network(GAT)等是较为常见的图神经网络模型。这些方法通过迭代地聚合邻居节点的信息来更新节点的表示,并使用这些表示进行分类。
1.3.3.挑战与未来方向

尽管图节点分类已经取得了显著的进展,但仍面临一些挑战:

  1. 可解释性:图节点分类的结果通常难以解释,因为模型的决策过程涉及多个节点和边的交互。未来研究可以探索如何提高模型的可解释性,以便更好地理解模型的决策过程。
  2. 动态性:现实中的图数据往往是动态的,即节点和边会随着时间的推移而发生变化。未来研究可以关注如何设计能够处理动态图数据的图节点分类方法。
  3. 异质性:许多现实中的图数据是异质的,即节点和边具有不同的类型和属性。未来研究可以探索如何设计能够处理异质图数据的图节点分类方法。

2. GAT实现图节点分类过程

2.1. 导入库文件

# 导入 TensorFlow 库,通常用于深度学习模型的构建和训练  
import tensorflow as tf  
  
# 从 TensorFlow 中导入 Keras,Keras 是一个流行的深度学习库,TensorFlow 2.x 将其集成在内  
from tensorflow import keras  
  
# 从 Keras 中导入所需的层,如 Dense 层等  
from tensorflow.keras import layers  
  
# 导入 NumPy 库,用于数值计算,特别是数组和矩阵操作  
import numpy as np  
  
# 导入 Pandas 库,用于数据处理和分析,如数据框(DataFrame)操作  
import pandas as pd  
  
# 导入 os 库,用于与操作系统交互,比如读取文件等  
import os  
  
# 导入 warnings 库,用于处理 Python 运行时产生的警告信息  
import warnings  
  
# 忽略所有产生的警告信息,使输出更整洁  
warnings.filterwarnings("ignore")  
  
# 设置 Pandas 的显示选项,控制 DataFrame 输出时显示的列数和行数  
# 设置为 6 列和 6 行,以便在控制台中更清晰地查看数据  
pd.set_option("display.max_columns", 6)  
pd.set_option("display.max_rows", 6)  
  
# 设置 NumPy 的随机种子,以便在每次运行时都能得到相同的随机结果  
# 这在需要复现实验结果时非常有用  
np.random.seed(2)  
  
# 注意:此代码片段本身并不包含任何实际的数据处理或模型构建代码,  
# 它仅仅是导入库和设置了一些全局选项。  
  
# 在这里,你可以继续编写加载数据、构建模型、训练模型等代码...

2.2. 数据预处理

2.2.1.获取数据集

Cora数据集的准备过程与图神经网络中的节点分类教程相似。要了解数据集和探索性数据分析的详细信息,请参阅该教程。简而言之,Cora数据集由两个文件组成:cora.cites,它记录了论文之间的有向引用关系;以及cora.content,它包含了每篇论文的特征和对应的七个类别标签之一。

import os  
import tensorflow as tf  
from tensorflow import keras  
import pandas as pd  
  
# 下载并解压数据集  
zip_file = keras.utils.get_file(  
    fname="cora.tgz",  
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",  
    extract=True,  
)  
  
# 设置数据目录  
data_dir = os.path.join(os.path.dirname(zip_file), "cora")  
  
# 读取引文数据  
citations = pd.read_csv(  
    os.path.join(data_dir, "cora.cites"),  
    sep="\t",  
    header=None,  
    names=["target", "source"],  
)  
  
# 读取论文数据  
papers = pd.read_csv(  
    os.path.join(data_dir, "cora.content"),  
    sep="\t",  
    header=None,  
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],  
)  
  
# 获取唯一的类别值和论文ID,并创建类别和论文ID的索引字典  
class_values = sorted(papers["subject"].unique())  
class_idx = {name: idx for idx, name in enumerate(class_values)}  
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}  
  
# 应用索引字典到数据和引文DataFrame,将字符串转换为整数索引  
papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])  
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])  
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])  
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])  
  
# 打印引文和论文DataFrame以检查结果  
print(citations)  
print(papers)

上述代码主要是用于下载、解压、读取并预处理Cora数据集,这是一个公开的引文网络数据集。

  1. 下载并解压数据集:使用keras.utils.get_file函数下载Cora数据集,该函数会自动处理下载和缓存,并且如果数据集已经下载过,则不会重复下载。参数extract=True表示在下载后自动解压文件。

  2. 设置数据目录: 解压后的数据集被存放在一个特定的目录中。代码通过os.path.dirname(zip_file)获取到解压文件的目录,然后与数据集名称拼接,得到完整的数据目录路径。

  3. 读取引文数据:使用pd.read_csv函数读取引文数据文件cora.cites。该文件是一个制表符分隔的文件,没有表头,因此使用header=None并手动指定列名为["target", "source"]

  4. 读取论文数据:类似地,使用pd.read_csv函数读取论文数据文件cora.content。这个文件也是制表符分隔,没有表头,并且具有大量的特征列(1433个词项)和一个表示主题的列。列名被设置为["paper_id"]加上1433个以term_为前缀的索引名,以及一个["subject"]列。

  5. 创建索引字典:

  • 提取papers DataFrame中的唯一主题值,并对其进行排序,然后创建一个从主题名称到索引的映射字典class_idx
  • papers DataFrame中的paper_id进行排序并提取唯一值,然后创建一个从论文ID到索引的映射字典paper_idx
  1. 应用索引字典进行数据转换:
  • 使用apply方法和lambda函数,将papers DataFrame中的paper_idcitations DataFrame中的sourcetarget列从字符串ID转换为对应的整数索引。
  • 同样地,将papers DataFrame中的subject列也从字符串主题转换为整数索引。
  1. 打印结果:最后,打印转换后的citationspapers DataFrame,以便检查数据是否正确处理。

2.2.2.分割数据集

import numpy as np

# 假设 papers 是一个 Pandas DataFrame,它包含了所有的论文数据

# 生成随机索引
# np.random.permutation 会返回一个随机排列的数组,数组中的元素是 0 到 papers.shape[0] - 1 的整数
# range(papers.shape[0]) 生成一个从 0 到 papers.shape[0] - 1 的整数序列
random_indices = np.random.permutation(range(papers.shape[0]))

# 50/50 分割数据集
# 使用随机索引数组的前半部分作为训练集索引
train_indices = random_indices[: len(random_indices) // 2]
# 使用随机索引数组的后半部分作为测试集索引
test_indices = random_indices[len(random_indices) // 2 :]

# 使用 Pandas 的 iloc 方法根据索引数组选择数据
# train_data 包含了训练集的数据
train_data = papers.iloc[train_indices]
# test_data 包含了测试集的数据
test_data = papers.iloc[test_indices]

# 此时 train_data 和 test_data 分别是原始 papers DataFrame 的一个子集,
# 分别包含了随机选取的 50% 的数据,用于训练和测试

上述代码首先生成了一个随机索引数组random_indices,该数组包含了从 0 到papers.shape[0] - 1的所有整数的随机排列。然后,这个随机索引数组被分为两半,前半部分作为训练集的索引,后半部分作为测试集的索引。最后,使用 Pandas 的iloc方法根据这些索引从papers DataFrame 中选取相应的行,形成训练集和测试集。

2.2.3.准备图数据

import tensorflow as tf

# 从训练集和测试集中获取论文索引,这些索引稍后将用于从图中获取节点状态
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()

# 获取每个paper_id对应的真实标签
train_labels = train_data["subject"].map(class_idx).to_numpy()  # 假设class_idx是前面定义的字典
test_labels = test_data["subject"].map(class_idx).to_numpy()    # 转换为整数索引

# 定义图,即边张量和节点特征张量
# 首先,将引文数据转换为张量
edges = tf.convert_to_tensor(citations[["target", "source"]].values, dtype=tf.int32)

# 对论文数据进行排序(根据paper_id),并提取词袋特征作为节点状态
# 注意:这里只提取了特征列,没有包括paper_id和subject列
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1].values, dtype=tf.float32)

# 打印图的形状
print("边(Edges)的形状:\t\t", edges.shape)
print("节点特征(Node features)的形状:", node_states.shape)

# 注意:如果后续操作需要知道每个节点索引对应的原始paper_id,
# 可以在创建node_states之前先保存一个映射关系,或者将paper_id也作为特征的一部分。
# 但是,对于图神经网络,通常只需要节点特征和边关系,paper_id不是必要的输入。

上述代码主要进行了以下几个步骤的操作,针对的是Cora数据集的处理,特别是为图神经网络(Graph Neural Network, GNN)或类似模型准备数据。下面是对代码的详细解读:

  1. 获取论文索引
  • train_indices = train_data["paper_id"].to_numpy()

  • test_indices = test_data["paper_id"].to_numpy()

    这两行代码从训练集和测试集的DataFrame中提取了paper_id列,并将其转换为NumPy数组。这些索引稍后将用于在图的上下文中引用特定的节点(即论文)。

  1. 获取真实标签

    • train_labels = train_data["subject"].map(class_idx).to_numpy()
    • test_labels = test_data["subject"].map(class_idx).to_numpy()

    这里,代码从训练集和测试集的DataFrame中提取了subject列(即论文的主题标签)。然后,使用map函数和之前定义的class_idx字典(它应该是一个从主题名称到整数索引的映射)将字符串标签转换为整数索引。这些整数索引将作为机器学习任务的目标标签。

  2. 定义图结构

  • 边的张量:edges = tf.convert_to_tensor(citations[["target", "source"]].values, dtype=tf.int32)这里,代码从citations DataFrame中提取了targetsource列,表示论文之间的引文关系(即边)。然后,使用tf.convert_to_tensor函数将这些数据转换为TensorFlow张量,并指定了数据类型为tf.int32
  • 节点状态的张量:node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1].values, dtype=tf.float32)这行代码做了几个操作。首先,它使用sort_values函数按paper_idpapers DataFrame进行排序。然后,使用iloc选择器获取除了第一列(paper_id)和最后一列(subject)之外的所有列,这些列通常包含词袋特征或其他节点特征。最后,使用tf.convert_to_tensor将这些数据转换为TensorFlow张量,并指定数据类型为tf.float32
  1. 打印图的形状
  • 打印边的形状:print("边(Edges)的形状:\t\t", edges.shape)
  • 打印节点特征的形状:print("节点特征(Node features)的形状:", node_states.shape)这两行代码用于验证张量的形状是否符合预期,以便在后续的图神经网络模型中使用。

2.3. 构建模型

图注意力网络(GAT)的输入是一个图,包括边张量和节点特征张量,并输出[更新后的]节点状态。这些节点状态对于每个目标节点而言,是聚合了N跳(N由GAT的层数决定)邻域信息的结果。重要的是,与图卷积网络(GCN)不同,GAT利用注意力机制来从邻近节点(或源节点)中聚合信息。换句话说,GAT不是简单地将源节点(源论文)的节点状态平均或求和到目标节点(目标论文),而是首先对每个源节点状态应用归一化的注意力分数,然后再进行聚合。

2.3.1.定义图注意力层

GAT模型通过实现多头图注意力层来工作。MultiHeadGraphAttention层就是多个图注意力层(GraphAttention)的拼接(或平均),每个图注意力层都有其独立的可学习权重W。GraphAttention层的工作流程如下:

考虑输入节点状态h{l},它们通过W{l}进行线性变换,得到z^{l}。

对于每个目标节点:

  • 计算所有邻居节点j的成对注意力分数a{l}{T}(z{l}_{i}||z{l}{j}),得到e{ij}(对于所有j)。其中||表示拼接,{i}代表目标节点,{j}代表给定的1跳邻居/源节点。
  • 通过softmax对e_{ij}进行归一化,确保目标节点所有传入边的注意力分数之和为1。
  • 将归一化后的注意力分数e_{norm}{ij}应用于z{j},并将其累加到新的目标节点状态h^{l+1}_{i}上,对于所有j。
class GraphAttention(layers.Layer):
    # 初始化图注意力层
    def __init__(self, units, kernel_initializer="glorot_uniform", kernel_regularizer=None, **kwargs):
        super().__init__(**kwargs)  # 调用父类构造函数
        self.units = units  # 单元数
        self.kernel_initializer = keras.initializers.get(kernel_initializer)  # 核初始化器
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)  # 核正则化器

    # 构建层
    def build(self, input_shape):
        self.kernel = self.add_weight(  # 添加权重
            shape=(input_shape[0][-1], self.units),  # 权重形状
            trainable=True,  # 是否可训练
            initializer=self.kernel_initializer,  # 初始化器
            regularizer=self.kernel_regularizer,  # 正则化器
            name="kernel",  # 名称
        )
        self.kernel_attention = self.add_weight(  # 添加注意力权重
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True  # 标记为已构建

    # 调用层
    def call(self, inputs):
        node_states, edges = inputs  # 输入的节点状态和边

        # 线性变换节点状态
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) 计算成对的注意力分数
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(node_states_expanded, (tf.shape(edges)[0], -1))
        attention_scores = tf.nn.leaky_relu(tf.matmul(node_states_expanded, self.kernel_attention))
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) 归一化注意力分数
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
        attention_scores_sum = tf.repeat(
            attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
        )
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) 收集邻居节点状态,应用注意力分数并聚合
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out

class MultiHeadGraphAttention(layers.Layer):
    # 初始化多头图注意力层
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads  # 头数
        self.merge_type = merge_type  # 合并类型
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]  # 创建多个注意力层

    # 调用多头图注意力层
    def call(self, inputs):
        atom_features, pair_indices = inputs  # 输入的原子特征和配对索引

        # 从每个注意力头获取输出
        outputs = [
            attention_layer([atom_features, pair_indices])
            for attention_layer in self.attention_layers
        ]
        # 根据合并类型,连接或平均每个头的节点状态
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # 激活并返回节点状态
        return tf.nn.relu(outputs)

上述代码定义了两个类,GraphAttentionMultiHeadGraphAttention,它们都是用于图神经网络中的注意力机制的层。

  1. GraphAttention 类:
  • 这是一个继承自 layers.Layer 的自定义层,用于实现图注意力机制。
  • __init__ 方法初始化层的参数,包括单元数 units,核初始化器 kernel_initializer 和核正则化器 kernel_regularizer
  • build 方法用于创建层的权重。这里创建了两个权重:kernelkernel_attention,分别用于线性变换节点状态和计算注意力分数。
  • call 方法定义了前向传播的逻辑:
    • 首先,使用 tf.matmul 对节点状态进行线性变换。
    • 然后,通过扩展节点状态并使用 tf.matmultf.nn.leaky_relu 计算成对的注意力分数。
    • 接着,使用 tf.clip_by_valuetf.math.exp 对注意力分数进行归一化。
    • 最后,通过应用注意力分数并使用 tf.math.unsorted_segment_sum 聚合邻居节点的状态,得到最终的输出。
  1. MultiHeadGraphAttention 类:
  • 这个类同样继承自 layers.Layer,用于实现多头图注意力机制。
  • __init__ 方法除了初始化 GraphAttention 类的参数外,还添加了头数 num_heads 和合并类型 merge_type
  • call 方法定义了多头注意力的逻辑:
    • 对于每个头,创建一个 GraphAttention 实例并调用它来获取输出。
    • 根据 merge_type,可以选择将所有头的输出在最后一个维度上进行连接(concatenate)或平均(average)。
    • 最后,使用 tf.nn.relu 激活函数对结果进行非线性变换,并返回激活后的节点状态。

这种图注意力机制可以用于节点分类任务,其中节点的表示是通过考虑其邻居节点的加权和来更新的,权重由注意力分数决定。多头注意力允许模型同时学习多个不同的表示子空间,这有助于捕获更丰富的信息。

2.3.2. 自定义模型训练方法

在使用Keras框架实现GAT(图注意力网络)模型的训练逻辑时,我们可以自定义train_step、test_step和predict_step方法。由于GAT模型在整个训练、验证和测试阶段都操作整个图(即节点状态和边),因此,这些图数据(node_states和edges)将被传递给keras.Model的构造函数,并作为属性使用。不同阶段之间的区别在于如何获取和处理输出(例如,通过索引选择某些输出)。

class GraphAttentionNetwork(keras.Model):
    # 初始化图注意力网络模型
    def __init__(
        self,
        node_states,  # 节点状态特征
        edges,  # 图的边
        hidden_units,  # 隐藏单元数
        num_heads,  # 注意力头数
        num_layers,  # 层数
        output_dim,  # 输出维度
        **kwargs,
    ):
        super().__init__(**kwargs)  # 调用父类构造函数
        self.node_states = node_states  # 节点状态
        self.edges = edges  # 边信息
        # 预处理层,用于节点状态特征的转换
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        # 定义多层多头图注意力层
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        # 输出层,用于最终的分类任务
        self.output_layer = layers.Dense(output_dim)

    # 模型的前向传播
    def call(self, inputs):
        node_states, edges = inputs
        # 预处理节点状态
        x = self.preprocess(node_states)
        # 循环调用多头图注意力层,并进行残差连接
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        # 通过输出层得到最终的输出
        outputs = self.output_layer(x)
        return outputs

    # 训练步骤
    def train_step(self, data):
        indices, labels = data
        with tf.GradientTape() as tape:  # 创建梯度记录器
            # 前向传播
            outputs = self([self.node_states, self.edges])
            # 计算损失
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # 计算梯度
        grads = tape.gradient(loss, self.trainable_weights)
        # 应用梯度更新权重
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # 更新指标
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
        # 返回指标结果
        return {m.name: m.result() for m in self.metrics}

    # 预测步骤
    def predict_step(self, data):
        indices = data
        # 前向传播
        outputs = self([self.node_states, self.edges])
        # 计算概率分布
        return tf.nn.softmax(tf.gather(outputs, indices))

    # 测试步骤
    def test_step(self, data):
        indices, labels = data
        # 前向传播
        outputs = self([self.node_states, self.edges])
        # 计算损失
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # 更新指标
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))
        # 返回指标结果
        return {m.name: m.result() for m in self.metrics}

上述代码定义了一个名为 GraphAttentionNetwork 的图注意力网络模型,它继承自 keras.Model

  1. 初始化 (__init__ 方法):
  • 接收节点状态特征 (node_states)、边信息 (edges)、隐藏单元数 (hidden_units)、注意力头数 (num_heads)、层数 (num_layers) 和输出维度 (output_dim) 等参数。
  • 初始化节点状态和边信息。
  • 创建一个预处理层 (preprocess),使用密集连接 (Dense) 并激活 ReLU
  • 创建多个 MultiHeadGraphAttention 层,数量由 num_layers 决定,每个层的头数由 num_heads 决定。
  • 创建一个输出层 (output_layer),用于最终的分类任务,使用密集连接。
  1. 前向传播 (call 方法):
  • 输入节点状态和边信息。
  • 通过预处理层转换节点状态。
  • 对每个注意力层进行循环处理,将注意力层的输出与输入进行残差连接。
  • 通过输出层得到最终的输出结果。
  1. 训练步骤 (train_step 方法):
  • 接收数据,包括索引和标签。
  • 使用 tf.GradientTape 记录梯度。
  • 执行模型的前向传播,计算损失。
  • 计算梯度,并使用优化器更新模型权重。
  • 更新并返回模型的指标。
  1. 预测步骤 (predict_step 方法):
  • 接收数据,包括索引。
  • 执行模型的前向传播。
  • 使用 softmax 函数计算输出的概率分布。
  1. 测试步骤 (test_step 方法):
  • 接收数据,包括索引和标签。
  • 执行模型的前向传播,计算损失。
  • 更新并返回模型的指标。

模型的关键在于使用图注意力机制来更新节点状态,这有助于捕获节点之间的复杂关系,特别是在图结构数据上。通过多层多头注意力,模型能够学习到更丰富的特征表示,从而提高分类或其他任务的性能。

2.4. 训练和评估模型

# 定义超参数
HIDDEN_UNITS = 100  # 隐藏单元数
NUM_HEADS = 8  # 注意力头数
NUM_LAYERS = 3  # 层数
OUTPUT_DIM = len(class_values)  # 输出维度,类别数

NUM_EPOCHS = 100  # 训练周期数
BATCH_SIZE = 256  # 批量大小
VALIDATION_SPLIT = 0.1  # 验证集比例
LEARNING_RATE = 3e-1  # 学习率
MOMENTUM = 0.9  # 动量

# 定义损失函数,使用稀疏分类交叉熵,from_logits=True表示模型输出未经softmax的原始logits
loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
# 定义优化器,使用随机梯度下降(SGD)并设置动量
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
# 定义评估指标,使用稀疏分类准确率
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
# 定义早停法回调函数,用于在验证集准确率不再提升时停止训练
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_acc",  # 监控的指标
    min_delta=1e-5,  # 最小提升阈值
    patience=5,  # 等待周期数
    restore_best_weights=True  # 恢复最佳权重
)

# 构建图注意力网络模型
gat_model = GraphAttentionNetwork(
    node_states,  # 节点状态特征
    edges,  # 边信息
    HIDDEN_UNITS,  # 隐藏单元数
    NUM_HEADS,  # 注意力头数
    NUM_LAYERS,  # 层数
    OUTPUT_DIM  # 输出维度
)

# 编译模型,指定损失函数、优化器和评估指标
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

# 训练模型
# x=train_indices 表示训练数据的索引
# y=train_labels 表示训练数据的标签
# validation_split=VALIDATION_SPLIT 表示从训练数据中划分出验证集的比例
# batch_size=BATCH_SIZE 表示批量大小
# epochs=NUM_EPOCHS 表示训练周期数
# callbacks=[early_stopping] 表示训练过程中使用的回调函数
# verbose=2 表示训练过程中的输出信息级别
gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=VALIDATION_SPLIT,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)

# 评估模型在测试集上的性能
_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)

# 打印测试集上的准确率
print("--" * 38 + f"\n测试集准确率:{test_accuracy*100:.1f}%")

上述代码是一个使用图注意力网络(Graph Attention Network, GAT)进行图结构数据分类任务的完整流程,包括模型的构建、编译、训练和评估。

  1. 定义超参数:
  • 设置了网络的隐藏单元数 HIDDEN_UNITS、注意力头数 NUM_HEADS、层数 NUM_LAYERS 和输出维度 OUTPUT_DIM
  • 训练相关的超参数包括训练周期数 NUM_EPOCHS、批量大小 BATCH_SIZE、验证集比例 VALIDATION_SPLIT、学习率 LEARNING_RATE 和动量 MOMENTUM
  1. 配置训练组件:
  • 定义了损失函数 loss_fn,这里使用的是稀疏分类交叉熵,from_logits=True 表示模型输出的是未经softmax处理的原始logits。
  • 定义了优化器 optimizer,这里使用的是随机梯度下降(SGD),并设置了学习率和动量。
  • 定义了评估指标 accuracy_fn,使用的是稀疏分类准确率。
  • 设置了早停法 early_stopping 回调,用于在验证集准确率不再提升时停止训练,以防止过拟合。
  1. 构建模型:
  • 使用 GraphAttentionNetwork 类构建图注意力网络模型 gat_model,传入节点特征 node_states、边信息 edges 和之前定义的超参数。
  1. 编译模型:
  • 使用 compile 方法编译模型,指定了损失函数、优化器和评估指标。
  1. 训练模型:
  • 使用 fit 方法训练模型,传入训练数据的索引 train_indices 和标签 train_labels
  • 设置了验证集比例、批量大小、训练周期数和回调函数。
  • verbose=2 表示在训练过程中提供详细的输出信息。
  1. 评估模型:
  • 使用 evaluate 方法在测试集上评估模型性能,传入测试数据的索引 test_indices 和标签 test_labels
  • 打印测试集上的准确率,格式为百分比。

代码展示了如何使用图注意力网络进行图结构数据的分类任务,包括模型的构建、训练和评估过程。通过早停法等技术,可以有效地防止模型过拟合,提高模型的泛化能力。

2.5. 论文可能性预测

# 使用模型预测测试集的概率分布
test_probs = gat_model.predict(x=test_indices)

# 构建一个从类别索引到类别名称的映射,用于将数字索引转换为可读的类别名称
mapping = {v: k for (k, v) in class_idx.items()}

# 遍历测试集的前10个样本及其预测概率和实际标签
for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
    # 打印每个样本的类别名称
    print(f"示例 {i+1}: {mapping[label]}")
    # 遍历该样本预测为各个类别的概率
    for j, c in zip(probs, class_idx.keys()):
        # 打印每个类别的概率(转换为百分比形式)
        print(f"\t类别 {c: <24} 的概率 = {j*100:7.3f}%")
    # 打印分隔线,用于区分不同的样本
    print("---" * 20)

上述代码是一个模型预测和结果展示的流程,具体步骤如下:

  1. 模型预测:
  • test_probs = gat_model.predict(x=test_indices):使用训练好的图注意力网络模型 gat_model 对测试数据集 test_indices 进行预测,获取每个样本属于各个类别的概率分布 test_probs
  1. 类别映射:
  • mapping = {v: k for (k, v) in class_idx.items()}:创建一个从类别索引到类别名称的映射字典 mappingclass_idx 可能是一个索引到类别名称的原始映射,这里通过字典推导式将其反向映射,以便于将数字索引转换为人类可读的类别名称。
  1. 结果展示:
  • 通过一个 for 循环,遍历测试集的前10个样本及其预测概率 probs 和实际标签 label
  • print(f"示例 {i+1}: {mapping[label]}"):打印每个样本的索引(i+1)和实际的类别名称(通过 mapping 转换)。
  • 另一个 for 循环用于遍历每个样本的预测概率:
  • for j, c in zip(probs, class_idx.keys())zip 函数将概率 probs 和类别索引 class_idx.keys() 结合在一起进行遍历。
  • print(f"\t类别 {c: <24} 的概率 = {j*100:7.3f}%"):打印每个类别的名称和对应的预测概率(乘以100并保留三位小数,表示为百分比)。
  • print("---" * 20):在每个样本的结果之后打印一条分隔线,以提高输出的可读性。

代码的目的是展示模型对测试集中前10个样本的分类预测结果,包括每个样本的实际类别和各个类别的预测概率。通过这种方式,可以快速了解模型的预测性能和概率分布情况。

Example 1: Probabilistic_Methods
    Probability of Case_Based               =   0.919%
    Probability of Genetic_Algorithms       =   0.180%
    Probability of Neural_Networks          =  37.896%
    Probability of Probabilistic_Methods    =  59.801%
    Probability of Reinforcement_Learning   =   0.705%
    Probability of Rule_Learning            =   0.044%
    Probability of Theory                   =   0.454%
------------------------------------------------------------
Example 2: Genetic_Algorithms
    Probability of Case_Based               =   0.005%
    Probability of Genetic_Algorithms       =  99.993%
    Probability of Neural_Networks          =   0.001%
    Probability of Probabilistic_Methods    =   0.000%
    Probability of Reinforcement_Learning   =   0.000%
    Probability of Rule_Learning            =   0.000%
    Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 3: Theory
    Probability of Case_Based               =   8.151%
    Probability of Genetic_Algorithms       =   1.021%
    Probability of Neural_Networks          =   0.569%
    Probability of Probabilistic_Methods    =  40.220%
    Probability of Reinforcement_Learning   =   0.792%
    Probability of Rule_Learning            =   6.910%
    Probability of Theory                   =  42.337%
------------------------------------------------------------
Example 4: Neural_Networks
    Probability of Case_Based               =   0.097%
    Probability of Genetic_Algorithms       =   0.026%
    Probability of Neural_Networks          =  93.539%
    Probability of Probabilistic_Methods    =   6.206%
    Probability of Reinforcement_Learning   =   0.028%
    Probability of Rule_Learning            =   0.010%
    Probability of Theory                   =   0.094%
------------------------------------------------------------
Example 5: Theory
    Probability of Case_Based               =  25.259%
    Probability of Genetic_Algorithms       =   4.381%
    Probability of Neural_Networks          =  11.776%
    Probability of Probabilistic_Methods    =  15.053%
    Probability of Reinforcement_Learning   =   1.571%
    Probability of Rule_Learning            =  23.589%
    Probability of Theory                   =  18.370%
------------------------------------------------------------
Example 6: Genetic_Algorithms
    Probability of Case_Based               =   0.000%
    Probability of Genetic_Algorithms       = 100.000%
    Probability of Neural_Networks          =   0.000%
    Probability of Probabilistic_Methods    =   0.000%
    Probability of Reinforcement_Learning   =   0.000%
    Probability of Rule_Learning            =   0.000%
    Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 7: Neural_Networks
    Probability of Case_Based               =   0.296%
    Probability of Genetic_Algorithms       =   0.291%
    Probability of Neural_Networks          =  93.419%
    Probability of Probabilistic_Methods    =   5.696%
    Probability of Reinforcement_Learning   =   0.050%
    Probability of Rule_Learning            =   0.072%
    Probability of Theory                   =   0.177%
------------------------------------------------------------
Example 8: Genetic_Algorithms
    Probability of Case_Based               =   0.000%
    Probability of Genetic_Algorithms       = 100.000%
    Probability of Neural_Networks          =   0.000%
    Probability of Probabilistic_Methods    =   0.000%
    Probability of Reinforcement_Learning   =   0.000%
    Probability of Rule_Learning            =   0.000%
    Probability of Theory                   =   0.000%
------------------------------------------------------------
Example 9: Theory
    Probability of Case_Based               =   4.103%
    Probability of Genetic_Algorithms       =   5.217%
    Probability of Neural_Networks          =  14.532%
    Probability of Probabilistic_Methods    =  66.747%
    Probability of Reinforcement_Learning   =   3.008%
    Probability of Rule_Learning            =   1.782%
    Probability of Theory                   =   4.611%
------------------------------------------------------------
Example 10: Case_Based
    Probability of Case_Based               =  99.566%
    Probability of Genetic_Algorithms       =   0.017%
    Probability of Neural_Networks          =   0.016%
    Probability of Probabilistic_Methods    =   0.155%
    Probability of Reinforcement_Learning   =   0.026%
    Probability of Rule_Learning            =   0.192%
    Probability of Theory                   =   0.028%
------------------------------------------------------------

3. 总结和展望

3.1 内容总结

本文对于基于图注意力网络(GAT)的Cora数据集论文主题预测的讨论,主要涵盖了一下内容:

3.1.1. GAT模型简介
  • 模型原理:GAT是一种基于注意力机制的图神经网络,旨在通过为图中的每个节点分配不同的注意力权重来捕捉节点之间的关联和重要性。
  • 主要特点
    • 能够有效处理有向图和无向图。
    • 通过多头注意力机制学习多个独立的注意力权重,从而捕获节点之间复杂的依赖关系。
    • 可以应用于节点分类、链接预测等多种图相关任务。
3.1.2. GAT模型实现
  • 输入:节点特征矩阵(每个节点都有一个特征向量表示)和邻接矩阵(表示节点之间的连接关系)。
  • 输出:每个节点的低维向量表示,用于后续的节点分类等任务。
  • 核心组件
    • 注意力层(Attention Layer):计算节点之间的注意力分数,并通过softmax操作转化为注意力权重。
    • 聚合层(Aggregation Layer):利用注意力权重对邻居节点特征进行加权求和,得到聚合后的邻居特征表示。
    • 多头注意力机制(Multi-head Attention):允许模型学习多个独立的注意力权重,提高模型的表达能力和稳定性。
3.1.3. 性能与优化
  • 实验结果:根据提供的示例,GAT模型在节点分类任务上取得了约80%的准确率。
  • 优化策略
    • 调整模型参数:如层数、隐藏单元数量、学习率等,以优化模型性能。
    • 添加正则化:如dropout,以防止过拟合。
    • 预处理步骤:调整数据预处理步骤,如特征缩放、归一化等,以更好地适应模型。
    • 自环和边向性:考虑添加自环边或使图无向,以改进模型在处理特定任务时的性能。
3.1.4. 注意事项
  • 并行化:GAT模型的自注意力机制可以跨节点-邻居对进行并行化,从而提高训练效率。
  • 模型评估:在多个数据集上进行实验,以全面评估模型的性能。
  • 结构信息:GAT主要基于节点属性信息进行学习,但也可以结合图的结构信息进行改进。

GAT模型是一种强大的图神经网络架构,通过引入注意力机制,能够有效地处理图数据中的节点分类等任务。通过调整模型参数、添加正则化、优化预处理步骤等方法,可以进一步提高GAT模型的性能。此外,GAT模型还具有高度的灵活性和可扩展性,可以应用于各种图相关任务。

3.2 未来展望

展望未来,随着图神经网络和注意力机制研究的不断深入,GAT(Graph Attention Network)模型有着广阔的应用前景和潜力。我们可以期待以下几个方面的发展和改进:

首先,随着计算能力的提升和算法的优化,GAT模型将能够处理更大规模、更复杂的图数据。这将使得GAT模型在社交网络分析、生物信息学、推荐系统等领域的应用更加广泛和深入。

其次,未来的研究将更加注重GAT模型的可解释性和鲁棒性。通过引入更多的可视化技术和解释性方法,我们可以更好地理解GAT模型的工作原理和决策过程,从而发现潜在的问题并进行改进。同时,提高GAT模型的鲁棒性,使其能够应对噪声数据和异常值,也是未来研究的重要方向。

此外,随着多模态数据融合技术的发展,未来的GAT模型将能够结合文本、图像、音频等多种类型的信息,进行更加全面的图表示学习。这将进一步提高GAT模型在节点分类、链接预测等任务上的性能,并为其在更广泛的实际应用中提供支持。

最后,我们可以期待GAT模型与其他先进技术的结合,如强化学习、生成对抗网络等,以探索新的应用场景和解决方案。例如,通过结合强化学习,GAT模型可以学习如何在动态变化的图结构中进行有效的决策和规划;通过结合生成对抗网络,GAT模型可以生成具有特定属性和结构的图数据,为数据增强和隐私保护等领域提供新的思路和方法。

未来的GAT模型将在处理更大规模、更复杂的图数据、提高可解释性和鲁棒性、融合多模态数据以及与其他先进技术结合等方面取得更大的进展和突破。这将为图神经网络的研究和应用带来更加广阔的前景和机遇。

参考文献

【1】Keras官方示例. (2022). Graph attention network (GAT) for node classification. Retrieved from https://keras.io/examples/graph/gat_node_classification/

附录1 实验代码

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import numpy as np
import pandas as pd
import os
import warnings

warnings.filterwarnings("ignore")
pd.set_option("display.max_columns", 6)
pd.set_option("display.max_rows", 6)
np.random.seed(2)

"""
## Obtain the dataset

The preparation of the [Cora dataset](https://linqs.soe.ucsc.edu/data) follows that of the
[Node classification with Graph Neural Networks](https://keras.io/examples/graph/gnn_citations/)
tutorial. Refer to this tutorial for more details on the dataset and exploratory data analysis.
In brief, the Cora dataset consists of two files: `cora.cites` which contains *directed links* (citations) between
papers; and `cora.content` which contains *features* of the corresponding papers and one
of seven labels (the *subject* of the paper).
"""

zip_file = keras.utils.get_file(
    fname="cora.tgz",
    origin="https://linqs-data.soe.ucsc.edu/public/lbc/cora.tgz",
    extract=True,
)

data_dir = os.path.join(os.path.dirname(zip_file), "cora")

citations = pd.read_csv(
    os.path.join(data_dir, "cora.cites"),
    sep="\t",
    header=None,
    names=["target", "source"],
)

papers = pd.read_csv(
    os.path.join(data_dir, "cora.content"),
    sep="\t",
    header=None,
    names=["paper_id"] + [f"term_{idx}" for idx in range(1433)] + ["subject"],
)

class_values = sorted(papers["subject"].unique())
class_idx = {name: id for id, name in enumerate(class_values)}
paper_idx = {name: idx for idx, name in enumerate(sorted(papers["paper_id"].unique()))}

papers["paper_id"] = papers["paper_id"].apply(lambda name: paper_idx[name])
citations["source"] = citations["source"].apply(lambda name: paper_idx[name])
citations["target"] = citations["target"].apply(lambda name: paper_idx[name])
papers["subject"] = papers["subject"].apply(lambda value: class_idx[value])

print(citations)

print(papers)

"""
### Split the dataset
"""

# Obtain random indices
random_indices = np.random.permutation(range(papers.shape[0]))

# 50/50 split
train_data = papers.iloc[random_indices[: len(random_indices) // 2]]
test_data = papers.iloc[random_indices[len(random_indices) // 2 :]]

"""
### Prepare the graph data
"""

# Obtain paper indices which will be used to gather node states
# from the graph later on when training the model
train_indices = train_data["paper_id"].to_numpy()
test_indices = test_data["paper_id"].to_numpy()

# Obtain ground truth labels corresponding to each paper_id
train_labels = train_data["subject"].to_numpy()
test_labels = test_data["subject"].to_numpy()

# Define graph, namely an edge tensor and a node feature tensor
edges = tf.convert_to_tensor(citations[["target", "source"]])
node_states = tf.convert_to_tensor(papers.sort_values("paper_id").iloc[:, 1:-1])

# Print shapes of the graph
print("Edges shape:\t\t", edges.shape)
print("Node features shape:", node_states.shape)

"""
## Build the model

GAT takes as input a graph (namely an edge tensor and a node feature tensor) and
outputs \[updated\] node states. The node states are, for each target node, neighborhood
aggregated information of *N*-hops (where *N* is decided by the number of layers of the
GAT). Importantly, in contrast to the
[graph convolutional network](https://arxiv.org/abs/1609.02907) (GCN)
the GAT makes use of attention mechanisms
to aggregate information from neighboring nodes (or *source nodes*). In other words, instead of simply
averaging/summing node states from source nodes (*source papers*) to the target node (*target papers*),
GAT first applies normalized attention scores to each source node state and then sums.
"""

"""
### (Multi-head) graph attention layer

The GAT model implements multi-head graph attention layers. The `MultiHeadGraphAttention`
layer is simply a concatenation (or averaging) of multiple graph attention layers
(`GraphAttention`), each with separate learnable weights `W`. The `GraphAttention` layer
does the following:

Consider inputs node states `h^{l}` which are linearly transformed by `W^{l}`, resulting in `z^{l}`.

For each target node:

1. Computes pair-wise attention scores `a^{l}^{T}(z^{l}_{i}||z^{l}_{j})` for all `j`,
resulting in `e_{ij}` (for all `j`).
`||` denotes a concatenation, `_{i}` corresponds to the target node, and `_{j}`
corresponds to a given 1-hop neighbor/source node.
2. Normalizes `e_{ij}` via softmax, so as the sum of incoming edges' attention scores
to the target node (`sum_{k}{e_{norm}_{ik}}`) will add up to 1.
3. Applies attention scores `e_{norm}_{ij}` to `z_{j}`
and adds it to the new target node state `h^{l+1}_{i}`, for all `j`.
"""


class GraphAttention(layers.Layer):
    def __init__(
        self,
        units,
        kernel_initializer="glorot_uniform",
        kernel_regularizer=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.units = units
        self.kernel_initializer = keras.initializers.get(kernel_initializer)
        self.kernel_regularizer = keras.regularizers.get(kernel_regularizer)

    def build(self, input_shape):
        self.kernel = self.add_weight(
            shape=(input_shape[0][-1], self.units),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel",
        )
        self.kernel_attention = self.add_weight(
            shape=(self.units * 2, 1),
            trainable=True,
            initializer=self.kernel_initializer,
            regularizer=self.kernel_regularizer,
            name="kernel_attention",
        )
        self.built = True

    def call(self, inputs):
        node_states, edges = inputs

        # Linearly transform node states
        node_states_transformed = tf.matmul(node_states, self.kernel)

        # (1) Compute pair-wise attention scores
        node_states_expanded = tf.gather(node_states_transformed, edges)
        node_states_expanded = tf.reshape(
            node_states_expanded, (tf.shape(edges)[0], -1)
        )
        attention_scores = tf.nn.leaky_relu(
            tf.matmul(node_states_expanded, self.kernel_attention)
        )
        attention_scores = tf.squeeze(attention_scores, -1)

        # (2) Normalize attention scores
        attention_scores = tf.math.exp(tf.clip_by_value(attention_scores, -2, 2))
        attention_scores_sum = tf.math.unsorted_segment_sum(
            data=attention_scores,
            segment_ids=edges[:, 0],
            num_segments=tf.reduce_max(edges[:, 0]) + 1,
        )
        attention_scores_sum = tf.repeat(
            attention_scores_sum, tf.math.bincount(tf.cast(edges[:, 0], "int32"))
        )
        attention_scores_norm = attention_scores / attention_scores_sum

        # (3) Gather node states of neighbors, apply attention scores and aggregate
        node_states_neighbors = tf.gather(node_states_transformed, edges[:, 1])
        out = tf.math.unsorted_segment_sum(
            data=node_states_neighbors * attention_scores_norm[:, tf.newaxis],
            segment_ids=edges[:, 0],
            num_segments=tf.shape(node_states)[0],
        )
        return out


class MultiHeadGraphAttention(layers.Layer):
    def __init__(self, units, num_heads=8, merge_type="concat", **kwargs):
        super().__init__(**kwargs)
        self.num_heads = num_heads
        self.merge_type = merge_type
        self.attention_layers = [GraphAttention(units) for _ in range(num_heads)]

    def call(self, inputs):
        atom_features, pair_indices = inputs

        # Obtain outputs from each attention head
        outputs = [
            attention_layer([atom_features, pair_indices])
            for attention_layer in self.attention_layers
        ]
        # Concatenate or average the node states from each head
        if self.merge_type == "concat":
            outputs = tf.concat(outputs, axis=-1)
        else:
            outputs = tf.reduce_mean(tf.stack(outputs, axis=-1), axis=-1)
        # Activate and return node states
        return tf.nn.relu(outputs)


"""
### Implement training logic with custom `train_step`, `test_step`, and `predict_step` methods

Notice, the GAT model operates on the entire graph (namely, `node_states` and
`edges`) in all phases (training, validation and testing). Hence, `node_states` and
`edges` are passed to the constructor of the `keras.Model` and used as attributes.
The difference between the phases are the indices (and labels), which gathers
certain outputs (`tf.gather(outputs, indices)`).

"""


class GraphAttentionNetwork(keras.Model):
    def __init__(
        self,
        node_states,
        edges,
        hidden_units,
        num_heads,
        num_layers,
        output_dim,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.node_states = node_states
        self.edges = edges
        self.preprocess = layers.Dense(hidden_units * num_heads, activation="relu")
        self.attention_layers = [
            MultiHeadGraphAttention(hidden_units, num_heads) for _ in range(num_layers)
        ]
        self.output_layer = layers.Dense(output_dim)

    def call(self, inputs):
        node_states, edges = inputs
        x = self.preprocess(node_states)
        for attention_layer in self.attention_layers:
            x = attention_layer([x, edges]) + x
        outputs = self.output_layer(x)
        return outputs

    def train_step(self, data):
        indices, labels = data

        with tf.GradientTape() as tape:
            # Forward pass
            outputs = self([self.node_states, self.edges])
            # Compute loss
            loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Compute gradients
        grads = tape.gradient(loss, self.trainable_weights)
        # Apply gradients (update weights)
        optimizer.apply_gradients(zip(grads, self.trainable_weights))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}

    def predict_step(self, data):
        indices = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute probabilities
        return tf.nn.softmax(tf.gather(outputs, indices))

    def test_step(self, data):
        indices, labels = data
        # Forward pass
        outputs = self([self.node_states, self.edges])
        # Compute loss
        loss = self.compiled_loss(labels, tf.gather(outputs, indices))
        # Update metric(s)
        self.compiled_metrics.update_state(labels, tf.gather(outputs, indices))

        return {m.name: m.result() for m in self.metrics}


"""
### Train and evaluate
"""

# Define hyper-parameters
HIDDEN_UNITS = 100
NUM_HEADS = 8
NUM_LAYERS = 3
OUTPUT_DIM = len(class_values)

NUM_EPOCHS = 100
BATCH_SIZE = 256
VALIDATION_SPLIT = 0.1
LEARNING_RATE = 3e-1
MOMENTUM = 0.9

loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=True)
optimizer = keras.optimizers.SGD(LEARNING_RATE, momentum=MOMENTUM)
accuracy_fn = keras.metrics.SparseCategoricalAccuracy(name="acc")
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_acc", min_delta=1e-5, patience=5, restore_best_weights=True
)

# Build model
gat_model = GraphAttentionNetwork(
    node_states, edges, HIDDEN_UNITS, NUM_HEADS, NUM_LAYERS, OUTPUT_DIM
)

# Compile model
gat_model.compile(loss=loss_fn, optimizer=optimizer, metrics=[accuracy_fn])

gat_model.fit(
    x=train_indices,
    y=train_labels,
    validation_split=VALIDATION_SPLIT,
    batch_size=BATCH_SIZE,
    epochs=NUM_EPOCHS,
    callbacks=[early_stopping],
    verbose=2,
)

_, test_accuracy = gat_model.evaluate(x=test_indices, y=test_labels, verbose=0)

print("--" * 38 + f"\nTest Accuracy {test_accuracy*100:.1f}%")

"""
### Predict (probabilities)
"""
test_probs = gat_model.predict(x=test_indices)

mapping = {v: k for (k, v) in class_idx.items()}

for i, (probs, label) in enumerate(zip(test_probs[:10], test_labels[:10])):
    print(f"Example {i+1}: {mapping[label]}")
    for j, c in zip(probs, class_idx.keys()):
        print(f"\tProbability of {c: <24} = {j*100:7.3f}%")
    print("---" * 20)

相关推荐

最近更新

  1. TCP协议是安全的吗?

    2024-06-09 14:34:06       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-06-09 14:34:06       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-06-09 14:34:06       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-06-09 14:34:06       20 阅读

热门阅读

  1. 乘积最大子数组 - LeetCode 热题 88

    2024-06-09 14:34:06       9 阅读
  2. 3.组件间通信-mitt(任意组件间通信)

    2024-06-09 14:34:06       12 阅读
  3. spring boot集成pg

    2024-06-09 14:34:06       10 阅读
  4. !力扣70. 爬楼梯

    2024-06-09 14:34:06       9 阅读
  5. 微信小程序:实现音乐播放器的功能

    2024-06-09 14:34:06       6 阅读
  6. oracle10g的dataguard测试

    2024-06-09 14:34:06       12 阅读
  7. 电商系统中热库和冷库的使用与数据转换

    2024-06-09 14:34:06       8 阅读
  8. Python R用法:深度探索与实用技巧

    2024-06-09 14:34:06       9 阅读
  9. K-means聚类模型

    2024-06-09 14:34:06       10 阅读
  10. RAGFlow 学习笔记

    2024-06-09 14:34:06       10 阅读
  11. tcpdump 抓包

    2024-06-09 14:34:06       9 阅读
  12. TypeScript看这篇就够了

    2024-06-09 14:34:06       12 阅读
  13. 【分享】使用 Reducer 和 Context 拓展你的应用

    2024-06-09 14:34:06       13 阅读