介绍
信息传递网络(Message Passing Neural Networks, MPNNs)是由Gilmer等人提出的一种图神经网络通用计算框架。原文以量子化学为例,根据原子的性质(对应节点特征)和分子的结构(对应边特征)预测了13种物理化学性质。查看论文原文请点击这里。
机制
理论
MPNN的前向传播包括两个阶段,第一个阶段称为message passing(信息传递)
阶段,第二个阶段称为readout(读取)
阶段。定义一张图
,其中
是所有节点,
是所有边。
信息传递阶段
message passing
阶段会执行多次信息传递过程。对于一个特定的节点v
,我先给出公式。
其中,在公式
中,
是结点v
在t+1
时间步所接收到的信息,
是结点v
的所有邻结点,
是结点v
在t
时间步的特征向量,
是结点v
和w
的边特征,
是消息函数。该公式的意义是节点v
收到的信息来源于节点v
本身状态(
),周围的节点状态(
)和与之相连的边特征(
)。生成信息后,就需要对结点进行更新。
在公式 中, 是结点更新函数,该函数把原节点状态 和信息 作为输入,得到新的节点状态 。熟悉RNN的同学可能会眼熟这个公式,这个更新函数和RNN里的更新函数是一样的。后面我们也可以看到,我们可以用GRU或LSTM来表示 。
最后再强调一下时间步的概念。计算完一次 和 算一个时间步,因此如果时间步设为 ,上述两个公式会各运行 次,最终得到的结果是 。
读取阶段
readout
阶段使用读取函数
计算基于整张图的特征向量,可以表示为
其中,
是最终的输出向量,
是读取函数,这个函数有两个要求:1、要可以求导。2、要满足置换不变性(结点的输入顺序不改变最终结果,这也是为了保证MPNN对图的同构有不变性)
实际案例
在MPNN的框架下,我们可以自定义消息函数、更新函数和读取函数,下面我举一个实际的案例,也是这篇文章所提及的门控图神经网络(Gated Graph Neural Networks, GG-NN)。这里,信息函数、结点更新函数和读取函数被定义为
消息函数
中,矩阵
决定了图中的结点是如何与其他结点进行相互作用的,一条边对应一个矩阵。但是这个函数描述得有些笼统。GGNN文章中的公式更清晰一些,如下所示
其中,
是结点v
在t
时刻接收到的信息向量,和我们之前定义的
是一样的,只是换了些字母。
表示节点在t-1
个时间步的状态,因此
把每个结点的状态拼接在一个维度上,维度大小为
,
是偏置项,至于
,我们先看下面这张图
这里有一个边的特征矩阵
,矩阵
考虑了边的方向,因此它是由out
和in
两个部分拼接而成,图中的不同字母代表了不同的相互作用类型(也可以视为每条边的特征,注意每一条边的特征维度都是
,而不是我们常见的一维向量,在实际应用中,如果边的初始特征维度不是
,可以进行embedding
或线性变换到
维,再reshape
到
),最终的维度是
,其中
是结点个数。有了矩阵
之后,我们需要针对某一个结点选出“两列”
(并非真正意义上的两列)。以2号结点作为v
结点为例,我们在Outgoing Edges
和Incoming Edges
中分别找到2号结点,再把这两列拼接起来,得到一个维度是
的矩阵
。将该矩阵的转置与所有节点的状态拼接成的列向量相乘,最终得到一个维度为
的信息向量
。而对于无向图而言,只需要考虑一半的情况就行了。
结点更新函数
是GRU
,对GRU
不熟悉的同学可以看一下这方面的知识,在此就不再多做解释了。
读取函数
看起来是较为复杂的,我们可以拆开来看。首先
表示逐元素相乘,
和
分别表示一个全连接神经网络,并且在
的外面又套了一层sigmoid
函数,用符号
表示。对于神经网络
而言,输入是结点的初始状态和最终状态,因此输入维度是2 * in_dim
,而对于神经网络
而言,输入只有结点的最终状态,因此输入维度是in_dim
。但是这两个神经网络的输出维度是一样的,这样才能逐元素相乘。再往深入一点讲,这里包含了self attention
机制,就是在读取阶段要注意
该节点最初的特征。
代码
我分别找到了Pytorch和Tensorflow的实现,以后有时间我会分析一下Pytorch版的实现过程。
Pytorch版
Tensorflow版(原作者)