RBF函数展开代码实现
1. 应用场景介绍
论文《Atomistic Line Graph Neural Network for improved materials property predictions》是一篇利用图神经网络预测晶体和分子性质的文章。在论文中,每个晶体图被表示成Node features、Edge features和triplet features。其中Node features根据其原子种类分配9个输入节点特征,这个类似CGCNN模型;Edge features是原子间键的距离,论文中使用径向基函数(RBF)展开为40维向量,晶体支持在0到8A之间,分子支持0-5A;triplet features是键角余弦的RBF展开,展开为80维向量。如下图所示:
2. RBF函数展开代码
RBF展开R是一个1维张量,包含在区间[0,8]上均匀间隔的40个点,通过pytorch广播机制,通过RBF展开函数计算之后,每一个 e i j e_{ij} eij将变成一个具有40个特征的一维张量。
from typing import Optional
import numpy as np
import torch
from torch import nn
class RBFExpansion(nn.Module):
"""Expand interatomic distances with radial basis functions."""
def __init__(
self,
vmin: float = 0,
vmax: float = 8,
bins: int = 40,
lengthscale: Optional[float] = None,
):
"""Register torch parameters for RBF expansion."""
super().__init__()
self.vmin = vmin
self.vmax = vmax
self.bins = bins
self.register_buffer(
"centers", torch.linspace(self.vmin, self.vmax, self.bins)
)
if lengthscale is None:
# SchNet-style
# set lengthscales relative to granularity of RBF expansion
self.lengthscale = np.diff(self.centers).mean() # np.diff()就是后一个元素减前一个元素的差值
self.gamma = 1 / self.lengthscale
else:
self.lengthscale = lengthscale
self.gamma = 1 / (lengthscale ** 2)
def forward(self, distance: torch.Tensor) -> torch.Tensor:
"""Apply RBF expansion to interatomic distance tensor."""
return torch.exp(
-self.gamma * (distance.unsqueeze(1) - self.centers) ** 2 # self.centers torch.Size([40])
)
if __name__ == "__main__":
distances = torch.rand(3)
print(distances)
rbf = RBFExpansion()
rbf_vec = rbf.forward(distances)
print(rbf_vec) # torch.Size([3, 40])
参考
[1] alignn代码