图神经网络通用框架信息传递网络(MPNNs)

图神经网络通用框架信息传递网络(MPNNs)

介绍

信息传递网络(Message Passing Neural Networks, MPNNs)是由Gilmer等人提出的一种图神经网络通用计算框架。原文以量子化学为例,根据原子的性质(对应节点特征)和分子的结构(对应边特征)预测了13种物理化学性质。查看论文原文请点击这里

机制

理论

MPNN的前向传播包括两个阶段,第一个阶段称为message passing(信息传递)阶段,第二个阶段称为readout(读取)阶段。定义一张图 G = ( V , E ) G=(V,E) ,其中 V V 是所有节点, E E 是所有边。

信息传递阶段

message passing阶段会执行多次信息传递过程。对于一个特定的节点v,我先给出公式。
m v t + 1 = w N ( v ) M t ( h v t , h w t , e v w ) (1) m_v^{t+1}=\sum_{w\in N(v)}M_t\left( h_v^{t},h_w^{t},e_{vw} \right)\tag{1} h v t + 1 = U t ( h v t , m v t + 1 ) (2) h_v^{t+1}=U_t\left(h_v^{t},m_v^{t+1}\right)\tag{2}
其中,在公式 ( 1 ) (1) 中, m v t + 1 m_v^{t+1} 是结点vt+1时间步所接收到的信息, N ( v ) N(v) 是结点v的所有邻结点, h v t h_v^{t} 是结点vt时间步的特征向量, e v w e_{vw} 是结点vw的边特征, M t M_t 是消息函数。该公式的意义是节点v收到的信息来源于节点v本身状态( h v t h_v^{t} ),周围的节点状态( h w t h_w^{t} )和与之相连的边特征( e v w e_{vw} )。生成信息后,就需要对结点进行更新。

在公式 ( 2 ) (2) 中, U t U_t 是结点更新函数,该函数把原节点状态 h v t h_v^{t} 和信息 m v t + 1 m_v^{t+1} 作为输入,得到新的节点状态 h v t + 1 h_v^{t+1} 。熟悉RNN的同学可能会眼熟这个公式,这个更新函数和RNN里的更新函数是一样的。后面我们也可以看到,我们可以用GRU或LSTM来表示 U t U_t

最后再强调一下时间步的概念。计算完一次 ( 1 ) (1) ( 2 ) (2) 算一个时间步,因此如果时间步设为 T T ,上述两个公式会各运行 T T 次,最终得到的结果是 h v T h_v^{T}

读取阶段

readout阶段使用读取函数 R R 计算基于整张图的特征向量,可以表示为
y ^ = R ( { h v T v G } ) (3) \hat{y}=R\left(\{h_v^T|v \in G \} \right)\tag{3}
其中, y ^ \hat{y} 是最终的输出向量, R R 是读取函数,这个函数有两个要求:1、要可以求导。2、要满足置换不变性(结点的输入顺序不改变最终结果,这也是为了保证MPNN对图的同构有不变性)

实际案例

在MPNN的框架下,我们可以自定义消息函数、更新函数和读取函数,下面我举一个实际的案例,也是这篇文章所提及的门控图神经网络(Gated Graph Neural Networks, GG-NN)。这里,信息函数、结点更新函数和读取函数被定义为
M t ( h v t , h w t , e v w ) = A e v w h w t (4) M_t\left( h_v^{t},h_w^{t},e_{vw} \right)=A_{e_{vw}}h_w^t\tag{4} U t ( h v t , m v t + 1 ) = G R U ( h v t , m v t + 1 ) (5) U_t\left(h_v^{t},m_v^{t+1}\right)=GRU\left(h_v^{t},m_v^{t+1}\right)\tag{5} R = v V σ ( i ( h v ( T ) , h v 0 ) ) ( j ( h v ( T ) ) ) (6) R=\sum_{v\in V}\sigma\left(i\left(h_v^{(T)},h_v^0\right)\right)\odot \left(j\left(h_v^{(T)}\right)\right)\tag{6}

消息函数 ( 4 ) (4) 中,矩阵 A e v w A_{e_{vw}} 决定了图中的结点是如何与其他结点进行相互作用的,一条边对应一个矩阵。但是这个函数描述得有些笼统。GGNN文章中的公式更清晰一些,如下所示
a v ( t ) = A v : T [ h 1 ( t 1 ) T , h 2 ( t 1 ) T , . . . , h V ( t 1 ) T ] T + b (7) a_v^{\left(t\right)}=A_{v:}^T\left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T+b\tag{7}
其中, a v ( t ) a_v^{\left(t\right)} 是结点vt时刻接收到的信息向量,和我们之前定义的 m v t + 1 m_v^{t+1} 是一样的,只是换了些字母。 h ( t 1 ) h^{(t-1)} 表示节点在t-1个时间步的状态,因此 [ h 1 ( t 1 ) T , h 2 ( t 1 ) T , . . . , h V ( t 1 ) T ] T \left[h_1^{(t-1)^T},h_2^{(t-1)^T},...,h_{|V|}^{(t-1)^T}\right]^T 把每个结点的状态拼接在一个维度上,维度大小为 D V D|V| b b 是偏置项,至于 A v : A_{v:} ,我们先看下面这张图
在这里插入图片描述
这里有一个边的特征矩阵 A A ,矩阵 A A 考虑了边的方向,因此它是由outin两个部分拼接而成,图中的不同字母代表了不同的相互作用类型(也可以视为每条边的特征,注意每一条边的特征维度都是 ( D , D ) (D, D) ,而不是我们常见的一维向量,在实际应用中,如果边的初始特征维度不是 D D ,可以进行embedding或线性变换到 D × D D\times D 维,再reshape ( D , D ) (D, D) ),最终的维度是 ( D V , 2 D V ) (D|V|,2D|V|) ,其中 V |V| 是结点个数。有了矩阵 A A 之后,我们需要针对某一个结点选出“两列”(并非真正意义上的两列)。以2号结点作为v结点为例,我们在Outgoing EdgesIncoming Edges中分别找到2号结点,再把这两列拼接起来,得到一个维度是 ( D V , 2 D ) (D|V|,2D) 的矩阵 A v : A_{v:} 。将该矩阵的转置与所有节点的状态拼接成的列向量相乘,最终得到一个维度为 2 D 2D 的信息向量 a v ( t ) a_v^{\left(t\right)} 。而对于无向图而言,只需要考虑一半的情况就行了。

结点更新函数 ( 5 ) (5) GRU,对GRU不熟悉的同学可以看一下这方面的知识,在此就不再多做解释了。

读取函数 ( 6 ) (6) 看起来是较为复杂的,我们可以拆开来看。首先 \odot 表示逐元素相乘, i i j j 分别表示一个全连接神经网络,并且在 i i 的外面又套了一层sigmoid函数,用符号 σ \sigma 表示。对于神经网络 i i 而言,输入是结点的初始状态和最终状态,因此输入维度是2 * in_dim,而对于神经网络 j j 而言,输入只有结点的最终状态,因此输入维度是in_dim。但是这两个神经网络的输出维度是一样的,这样才能逐元素相乘。再往深入一点讲,这里包含了self attention机制,就是在读取阶段要注意该节点最初的特征。

代码

我分别找到了Pytorch和Tensorflow的实现,以后有时间我会分析一下Pytorch版的实现过程。
Pytorch版
Tensorflow版(原作者)

第三方库

torch-geometric

发布了19 篇原创文章 · 获赞 29 · 访问量 3万+

猜你喜欢

转载自blog.csdn.net/qq_41987033/article/details/103532624