GraphSAGE项目练手

# 导包
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import NeighborLoader
from torch_geometric.utils import to_networkx
import numpy as np
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch_geometric.nn import SAGEConv
# 导入PubMed数据集
dataset = Planetoid(root='',name='Pubmed')
data = dataset[0]

# 邻居采样
# 使用NeighborLoader 来完成这一任务。
# 保留目的节点的10个邻居 和 其邻居的10个邻居, 对60个目的节点进行分组,每16个目的节点为一组

# 进行采样
train_loader = NeighborLoader(
    data,# 数据源
    num_neighbors=[5,10], # 每一层采样的邻居采样量,第一层5,第二层10
    batch_size=16,
    input_nodes=data.train_mask # 60个训练目的节点
)
# 遍历数据检验
# for i,subgraph in enumerate(train_loader):
#     print(f'Subgraph{i}:{subgraph}')

# 子图可视化
# fig = plt.figure(figsize=(16,16))
# for idx,(subdata,pos) in enumerate(zip(train_loader,[221,222,223,224])):
#     G = to_networkx(subdata,to_undirected=True)
#     ax = fig.add_subplot(pos)
#     ax.set_title(f'Subgraph{idx},fonts=24')
#     plt.axis('off')
#     nx.draw_networkx(G,pos=nx.spring_layout(G),with_labels=False,node_color=subdata.y)
# plt.show()

# 实现准确率评估模型
def  accuracy(pre_y,y):
    return ((pre_y==y).sum() / len(y)).item()

# 定义GraphSAGE
class GraphSAGE(torch.nn.Module):
    def __init__(self,dim_in,dim_h,dim_out):
        super().__init__()
        self.sage1 = SAGEConv(dim_in,dim_h)
        self.sage2= SAGEConv(dim_h,dim_out)

    def forward(self,x,edge_index):
        h = self.sage1(x,edge_index)
        h = torch.relu(h)
        h = F.dropout(h,p=0.5,training=self.training)
        h = self.sage2(h,edge_index)
        return h
# 使用小批量训练,Fit函数要修改为先循环epoch次,然后循环批数据,以在每个批数据上训练epoch次
    def fit(self,loader,epochs):
        criterion = torch.nn.CrossEntropyLoss()
        optimizer = torch.optim.Adam(self.parameters(),lr=0.01)
        self.train()
        for epoch in range(epochs+1):
            total_loss = 0
            acc = 0
            val_loss = 0
            val_acc = 0
            for batch in loader:
                optimizer.zero_grad()
                out = self(batch.x, batch.edge_index)
                loss = criterion(out[batch.train_mask], batch.y[batch.train_mask])
                total_loss += loss.item()
                acc += accuracy(out[batch.train_mask].argmax(dim=1), batch.y[batch.train_mask])
                loss.backward()
                optimizer.step()

                # Validation
                val_loss += criterion(out[batch.val_mask], batch.y[batch.val_mask])
                val_acc += accuracy(out[batch.val_mask].argmax(dim=1), batch.y[batch.val_mask])

                if epoch % 20 == 0:
                    print(f'Epoch {epoch:>3} | Train Loss: {loss/len(loader):.3f} | Train Acc: {acc/len(loader)*100:>6.2f}% | Val Loss: {val_loss/len(train_loader):.2f} | Val Acc: {val_acc/len(train_loader)*100:.2f}%')
@torch.no_grad()
def test(self, data):
    self.eval()
    out = self(data.x, data.edge_index)
    acc = accuracy(out.argmax(dim=1)[data.test_mask], data.y[data.test_mask])
    return acc

# Create GraphSAGE
graphsage = GraphSAGE(dataset.num_features, 64, dataset.num_classes)
print(graphsage)

# Train
graphsage.fit(train_loader, 200)

相关推荐

  1. GraphSAGE项目

    2024-07-11 19:00:03       11 阅读
  2. Rust 项目:猜数游戏

    2024-07-11 19:00:03       23 阅读
  3. <span style='color:red;'>GraphSAGE</span>

    GraphSAGE

    2024-07-11 19:00:03      40 阅读

最近更新

  1. docker php8.1+nginx base 镜像 dockerfile 配置

    2024-07-11 19:00:03       9 阅读
  2. Could not load dynamic library ‘cudart64_100.dll‘

    2024-07-11 19:00:03       8 阅读
  3. 在Django里面运行非项目文件

    2024-07-11 19:00:03       8 阅读
  4. Python语言-面向对象

    2024-07-11 19:00:03       11 阅读

热门阅读

  1. el-dialog弹框里面的组件第二次打开create不生效

    2024-07-11 19:00:03       6 阅读
  2. 测试类型介绍-功能测试入门指南

    2024-07-11 19:00:03       10 阅读
  3. 【ARMv8/v9 GIC 系列 1.8 -- PE 中断处理的前期评估】

    2024-07-11 19:00:03       7 阅读
  4. VUE与React的生命周期对比

    2024-07-11 19:00:03       9 阅读
  5. 设计模式:建造者模式

    2024-07-11 19:00:03       8 阅读
  6. Puppeteer 生成图片 生成 PDF

    2024-07-11 19:00:03       6 阅读
  7. iOS开发新手教程:Swift语言与Xcode工具链

    2024-07-11 19:00:03       7 阅读
  8. 详解Redis:什么是Redis?

    2024-07-11 19:00:03       10 阅读
  9. 设计模式六大原则

    2024-07-11 19:00:03       7 阅读
  10. PG延迟模拟和查看

    2024-07-11 19:00:03       9 阅读