【图神经网络——消息传递】

画图先:

导包:

import networkx as nx
import matplotlib.pyplot as plt
import torch
from torch_geometric.nn import MessagePassing

画图:

# 创建有向图
G = nx.DiGraph()

# 添加四个节点
nodes = [0,1,2,3]
G.add_nodes_from(nodes)

# 添加每个节点的属性
node_attributes = {0:'[1,2]', 1:'[2,3]', 2:'[8,3]', 3:'[2,4]'}
nx.set_node_attributes(G,node_attributes,'embeddings')

# 添加边(使用edge_index)
edge_index = [(0,0),(0,1),(1,2),(2,1),(2,3),(3,2)]
G.add_edges_from(edge_index)

# 获取节点标签
node_labels = nx.get_node_attributes(G,'embeddings')

pos = nx.spring_layout(G)

# 绘制图
nx.draw(G,pos,with_labels=False,node_size=900,node_color ='skyblue',font_size=15,font_color='black')

# 在节点旁边添加节点属性
nx.draw_networkx_labels(G,pos,font_color='black',labels={k:f'{k}:{v} ' for k,v in node_labels.items()})

在这里插入图片描述

实现消息传递:

创建与上面所创建的图一致的数据


x = torch.tensor([[1,2],[2,3],[8,3],[2,4]])
edge_index = torch.tensor([
    [0,0,1,2,2,3],
    [0,1,2,1,3,2]
])

举两个不同的消息传递例子,方便理解。

例子一:

class MessagePassingLayer(MessagePassing):
    def __init__(self):
        super(MessagePassingLayer,self).__init__(aggr='max')
    def forward(self,x,edge_index):
        return self.propagate(edge_index=edge_index,x=x)
    def message(self,x_i, x_j):
        # 中心节点特征,也就是向量
        print(x_i)
        # 邻居节点特征
        print(x_j)
        return x_j
  
messagePassingLayer = MessagePassingLayer()
output = messagePassingLayer(x,edge_index)
print(output)
plt.show()

输出如下:
tensor([[1, 2], 这个是中心节点的特征
[2, 3],
[8, 3],
[2, 3],
[2, 4],
[8, 3]])
tensor([[1, 2], 这个是邻居节点的特征
[1, 2],
[2, 3],
[8, 3],
[8, 3],
[2, 4]])
tensor([[1, 2], 这个是进行消息传递后的中心节点的特征。
[8, 3],
[2, 4],
[8, 3]])

例子二:

class MessagePassingLayer(MessagePassing):
    def __init__(self):
        super(MessagePassingLayer,self).__init__(aggr='add')
    def forward(self,x,edge_index):
        return self.propagate(edge_index=edge_index,x=x)
    def message(self,x_i, x_j):
        # 中心节点特征,也就是向量
        print(x_i)
        # 邻居节点特征
        print(x_j)
        return (x_i+x_j)


messagePassingLayer = MessagePassingLayer()
output = messagePassingLayer(x,edge_index)
print(output)
plt.show()

输出如下:
tensor([[1, 2],
[2, 3],
[8, 3],
[2, 3],
[2, 4],
[8, 3]])
tensor([[1, 2],
[1, 2],
[2, 3],
[8, 3],
[8, 3],
[2, 4]])
tensor([[ 2, 4],
[13, 11],
[20, 13],
[10, 7]])

相关推荐

  1. 神经网络 | Pytorch神经网络ST-GNN

    2024-05-16 01:12:17       11 阅读

最近更新

  1. TCP协议是安全的吗?

    2024-05-16 01:12:17       18 阅读
  2. 阿里云服务器执行yum,一直下载docker-ce-stable失败

    2024-05-16 01:12:17       19 阅读
  3. 【Python教程】压缩PDF文件大小

    2024-05-16 01:12:17       18 阅读
  4. 通过文章id递归查询所有评论(xml)

    2024-05-16 01:12:17       20 阅读

热门阅读

  1. 第十一周学习笔记DAY.1-MySQL

    2024-05-16 01:12:17       9 阅读
  2. mysql 索引失效的原因

    2024-05-16 01:12:17       12 阅读
  3. 设计模式:备忘录模式

    2024-05-16 01:12:17       12 阅读
  4. 数据特征降维 | 主成分分析(PCA)附Python代码

    2024-05-16 01:12:17       12 阅读
  5. sophgo sdk v23.03.01

    2024-05-16 01:12:17       10 阅读
  6. js遇到需要正则匹配来修改img标签+清除行内样式

    2024-05-16 01:12:17       13 阅读