Python 中图神经网络的温和介绍

A Gentle Introduction to Graph Neural Networks in Python

Python 中图神经网络的温和介绍

引言

图神经网络 (GNNs) 可以被看作是一类特殊的神经网络模型,其中数据被组织成图——无论是用于训练模型的训练数据,还是用于推理的真实世界数据——而不是固定大小的向量或像图像、序列或表格数据实例这样的网格。

虽然像前馈模型这样的传统神经网络架构在处理结构化、表格数据或图像的分类等预测性问题方面表现出色,但GNNs 被设计用来适应数据实体之间关系复杂且不规则的问题。例如社交网络、分子结构和知识图谱。与任何图一样,GNN 中用于训练和推理的输入数据被表示为一个图,节点代表实体(例如社交网络中的用户),边代表关系(例如用户之间的友谊或关注)。

想通过一个简单的 Python 实际示例更好地理解 GNN 的工作原理吗?请继续阅读。

在 Python 中定义图神经网络

在这个 GNN 构建的入门示例中,我们将考虑一个与社交媒体平台相关的小型图数据集,其中每个节点代表一个人,连接任何两个节点的每条边都是人之间的友谊。此外,每个节点(人)都有相关的特征,例如人的年龄、他们的兴趣等。

我们将构建的 GNN 的目标任务是根据他们是否在社交网络中有超过两个朋友或少于两个朋友来对人们进行分类(二分类),并考虑到

  1. 人的特征,例如他们的兴趣
  2. 人的与其他人的联系

因此,GNN 为预测任务增加了额外的复杂性,因为它们不仅根据目标实例的特征进行预测,还考虑了它与其他数据实例的关系,这与经典的分类和回归模型不同。

废话不多说,我们开始编码。我们将使用几个适合构建 GNN 的 PyTorch 组件,所以我们首先安装它们

现在是必要的导入

这是我们的“迷你社交网络”数据集或图

基本上,edge_index 是用户之间的边或连接的矩阵。有 5 个用户,编号从 0 到 4。第一个连接是从用户 0 到用户 1,通过查看矩阵每行的第一个元素即可得知。第二个连接是上一个连接的倒数:用户 1 到用户 0。然后是用户 0 到用户 2,以此类推。用户 3 似乎还没有连接到任何人!

现在我们为每个人建模两个数值特征,存储在一个张量 node_features 中:人的年龄,以及他们对运动的兴趣,1 表示感兴趣,0 表示不感兴趣。

在 Python 中可视化图神经网络

一种在 Python 中可视化我们的图神经网络的方法是使用 NetworkX 库。它将从边列表中创建一个图,并使用 Matplotlib 进行显示。下面是一个例子。

Visualization of the social network graph

图 1:社交网络图可视化

在 Python 中构建图神经网络模型

现在我们为用户数据集定义标签,即一个人是否受欢迎,根据他们是否拥有超过 2 个朋友来判断。这个过程包括根据邻接矩阵计算每个人的朋友数量(地面真实值)。

使用以下掩码,我们将指明前三个人将用作训练数据来构建 GNN,另外两个人稍后将用于推理。最后,我们还将所有内容封装到一个 Data 对象中。

接下来的代码至关重要。它定义了GNN 架构并实例化了模型。在 PyTorch 中,GNN 模型可以通过使用图卷积层来构建,例如 torch_geometric.nn 中的 GCNConv 类实现的那些。图卷积层会聚合节点邻居的信息,帮助学习不仅捕获节点特征,还捕获图结构关系的表示。

在 Python 中训练图神经网络

训练模型与在 PyTorch 中训练其他类型的神经网络模型非常相似

示例训练输出

Python 中的图神经网络推理

一旦 GNN 训练完成,推理过程就很简单了。我们传入整个数据集来计算受欢迎程度预测,包括在训练期间未见过的两名用户,并打印结果。请注意,argmax 函数用于获取每个用户在两个可用类别(二元分类器如逻辑回归器)中概率最高的类别,这就是其本质。

这是最终的预测列表

因此,我们可以看到除了用户 3(又称“孤独用户”)之外,所有用户都被认为是受欢迎的。

总结

总而言之,我们构建了一个非常简单的 GNN,它使用数据集的图表示来执行预测,这些预测不仅基于实例(由节点表示)的特征,还基于它们与其他实例的关系或连接。

暂无评论。

留下回复

Machine Learning Mastery 是 Guiding Tech Media 的一部分,Guiding Tech Media 是一家领先的数字媒体出版商,专注于帮助人们了解技术。访问我们的公司网站以了解更多关于我们的使命和团队的信息。