
Python 中图神经网络的温和介绍
引言
图神经网络 (GNNs) 可以被看作是一类特殊的神经网络模型,其中数据被组织成图——无论是用于训练模型的训练数据,还是用于推理的真实世界数据——而不是固定大小的向量或像图像、序列或表格数据实例这样的网格。
虽然像前馈模型这样的传统神经网络架构在处理结构化、表格数据或图像的分类等预测性问题方面表现出色,但GNNs 被设计用来适应数据实体之间关系复杂且不规则的问题。例如社交网络、分子结构和知识图谱。与任何图一样,GNN 中用于训练和推理的输入数据被表示为一个图,节点代表实体(例如社交网络中的用户),边代表关系(例如用户之间的友谊或关注)。
想通过一个简单的 Python 实际示例更好地理解 GNN 的工作原理吗?请继续阅读。
在 Python 中定义图神经网络
在这个 GNN 构建的入门示例中,我们将考虑一个与社交媒体平台相关的小型图数据集,其中每个节点代表一个人,连接任何两个节点的每条边都是人之间的友谊。此外,每个节点(人)都有相关的特征,例如人的年龄、他们的兴趣等。
我们将构建的 GNN 的目标任务是根据他们是否在社交网络中有超过两个朋友或少于两个朋友来对人们进行分类(二分类),并考虑到
- 人的特征,例如他们的兴趣
- 人的与其他人的联系
因此,GNN 为预测任务增加了额外的复杂性,因为它们不仅根据目标实例的特征进行预测,还考虑了它与其他数据实例的关系,这与经典的分类和回归模型不同。
废话不多说,我们开始编码。我们将使用几个适合构建 GNN 的 PyTorch 组件,所以我们首先安装它们
1 2 3 4 |
pip install torch pip install ogb pip install torch_geometric pip install networkx |
现在是必要的导入
1 2 3 4 5 |
import os import torch import torch.nn.functional as F from torch_geometric.data import Data from torch_geometric.nn import GCNConv |
这是我们的“迷你社交网络”数据集或图
1 2 3 4 5 |
# 定义图数据集 edge_index = torch.tensor([ [0, 1, 0, 2, 0, 4, 2, 4], [1, 0, 2, 0, 4, 0, 4, 2], ], dtype=torch.long) |
基本上,edge_index
是用户之间的边或连接的矩阵。有 5 个用户,编号从 0 到 4。第一个连接是从用户 0 到用户 1,通过查看矩阵每行的第一个元素即可得知。第二个连接是上一个连接的倒数:用户 1 到用户 0。然后是用户 0 到用户 2,以此类推。用户 3 似乎还没有连接到任何人!
现在我们为每个人建模两个数值特征,存储在一个张量 node_features
中:人的年龄,以及他们对运动的兴趣,1 表示感兴趣,0 表示不感兴趣。
1 2 3 4 5 6 7 8 |
# 定义数据特征 node_features = torch.tensor([ [25, 1], # 人 0 (25 岁,喜欢运动) [30, 0], # 人 1 (30 岁,不喜欢运动) [22, 1], # 人 2 (22 岁,喜欢运动) [35, 0], # 人 3 (35 岁,不喜欢运动) [27, 1], # 人 4 (27 岁,喜欢运动) ], dtype=torch.float) |
在 Python 中可视化图神经网络
一种在 Python 中可视化我们的图神经网络的方法是使用 NetworkX 库。它将从边列表中创建一个图,并使用 Matplotlib 进行显示。下面是一个例子。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 |
import networkx as nx import matplotlib.pyplot as plt # 将 edge_index 张量转换为边元组列表 edge_list = edge_index.t().tolist() # 从边列表创建 NetworkX 图 G = nx.Graph() G.add_edges_from(edge_list) # 可选:包含可能孤立的节点(例如,人 3) G.add_nodes_from(range(node_features.size(0))) # 生成节点的布局 pos = nx.spring_layout(G, seed=42) # 固定种子以保证可重现性 # 绘制带有标签的图 plt.figure(figsize=(6, 6)) nx.draw_networkx(G, pos, with_labels=True, node_color='lightblue', edge_color='gray', node_size=800) plt.title("社交网络图可视化") plt.axis('off') plt.show() |

图 1:社交网络图可视化
在 Python 中构建图神经网络模型
现在我们为用户数据集定义标签,即一个人是否受欢迎,根据他们是否拥有超过 2 个朋友来判断。这个过程包括根据邻接矩阵计算每个人的朋友数量(地面真实值)。
1 2 3 |
# 定义数据集标签 num_friends = torch.tensor([3, 1, 2, 0, 3]) labels = (num_friends >= 2).long() |
使用以下掩码,我们将指明前三个人将用作训练数据来构建 GNN,另外两个人稍后将用于推理。最后,我们还将所有内容封装到一个 Data
对象中。
1 2 3 |
# 用于分离训练和测试数据的掩码 train_mask = torch.tensor([1, 1, 1, 0, 0], dtype=torch.bool) data = Data(x=node_features, edge_index=edge_index, y=labels, train_mask=train_mask) |
接下来的代码至关重要。它定义了GNN 架构并实例化了模型。在 PyTorch 中,GNN 模型可以通过使用图卷积层来构建,例如 torch_geometric.nn 中的 GCNConv
类实现的那些。图卷积层会聚合节点邻居的信息,帮助学习不仅捕获节点特征,还捕获图结构关系的表示。
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
# 定义模型 class GNN(torch.nn.Module): def __init__(self, input_dim, hidden_dim, output_dim): super(GNN, self).__init__() self.conv1 = GCNConv(input_dim, hidden_dim) self.conv2 = GCNConv(hidden_dim, output_dim) def forward(self, data): x, edge_index = data.x, data.edge_index x = self.conv1(x, edge_index) x = F.relu(x) # 激活函数 x = self.conv2(x, edge_index) return x # 实例化模型 model = GNN(input_dim=2, hidden_dim=4, output_dim=2) |
在 Python 中训练图神经网络
训练模型与在 PyTorch 中训练其他类型的神经网络模型非常相似
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 |
# 定义优化器 optimizer = torch.optim.Adam(model.parameters(), lr=0.01) # 训练模型 for epoch in range(100): model.train() optimizer.zero_grad() out = model(data) loss = F.cross_entropy(out[data.train_mask], data.y[data.train_mask]) loss.backward() optimizer.step() if epoch % 10 == 0: print(f"Epoch {epoch}, Loss: {loss.item():.4f}") |
示例训练输出
1 2 3 4 5 6 7 8 9 10 |
Epoch 0, Loss: 1.0987 Epoch 10, Loss: 0.8563 Epoch 20, Loss: 0.6542 Epoch 30, Loss: 0.5234 Epoch 40, Loss: 0.4231 Epoch 50, Loss: 0.3654 Epoch 60, Loss: 0.3120 Epoch 70, Loss: 0.2871 Epoch 80, Loss: 0.2654 Epoch 90, Loss: 0.2543 |
Python 中的图神经网络推理
一旦 GNN 训练完成,推理过程就很简单了。我们传入整个数据集来计算受欢迎程度预测,包括在训练期间未见过的两名用户,并打印结果。请注意,argmax
函数用于获取每个用户在两个可用类别(二元分类器如逻辑回归器)中概率最高的类别,这就是其本质。
1 2 3 4 5 6 |
# 测试模型 model.eval() with torch.no_grad(): predictions = model(data).argmax(dim=1) print("\n最终预测 (1=受欢迎, 0=不受欢迎):", predictions.tolist()) |
这是最终的预测列表
1 2 |
# 测试数据推理输出 Final Predictions (1=Popular, 0=Not Popular): [1, 1, 1, 0, 1] |
因此,我们可以看到除了用户 3(又称“孤独用户”)之外,所有用户都被认为是受欢迎的。
总结
总而言之,我们构建了一个非常简单的 GNN,它使用数据集的图表示来执行预测,这些预测不仅基于实例(由节点表示)的特征,还基于它们与其他实例的关系或连接。
暂无评论。