RBF函数展开代码实现

1. 应用场景介绍

  论文《Atomistic Line Graph Neural Network for improved materials property predictions》是一篇利用图神经网络预测晶体和分子性质的文章。在论文中,每个晶体图被表示成Node features、Edge features和triplet features。其中Node features根据其原子种类分配9个输入节点特征,这个类似CGCNN模型;Edge features是原子间键的距离,论文中使用径向基函数(RBF)展开为40维向量,晶体支持在0到8A之间,分子支持0-5A;triplet features是键角余弦的RBF展开,展开为80维向量。如下图所示:
在这里插入图片描述

2. RBF函数展开代码

   RBF展开R是一个1维张量,包含在区间[0,8]上均匀间隔的40个点,通过pytorch广播机制,通过RBF展开函数计算之后,每一个 e i j e_{ij} eij将变成一个具有40个特征的一维张量。

from typing import Optional

import numpy as np
import torch
from torch import nn


class RBFExpansion(nn.Module):
    """Expand interatomic distances with radial basis functions."""

    def __init__(
        self,
        vmin: float = 0,
        vmax: float = 8,
        bins: int = 40,
        lengthscale: Optional[float] = None,
    ):
        """Register torch parameters for RBF expansion."""
        super().__init__()
        self.vmin = vmin
        self.vmax = vmax
        self.bins = bins
        self.register_buffer(
            "centers", torch.linspace(self.vmin, self.vmax, self.bins)
        )

        if lengthscale is None:
            # SchNet-style
            # set lengthscales relative to granularity of RBF expansion
            self.lengthscale = np.diff(self.centers).mean()  # np.diff()就是后一个元素减前一个元素的差值
            self.gamma = 1 / self.lengthscale

        else:
            self.lengthscale = lengthscale
            self.gamma = 1 / (lengthscale ** 2)

    def forward(self, distance: torch.Tensor) -> torch.Tensor:
        """Apply RBF expansion to interatomic distance tensor."""
        return torch.exp(
            -self.gamma * (distance.unsqueeze(1) - self.centers) ** 2  # self.centers torch.Size([40])
        )

if __name__ == "__main__":
    distances = torch.rand(3)
    print(distances)
    rbf = RBFExpansion()
    rbf_vec = rbf.forward(distances)
    print(rbf_vec)  # torch.Size([3, 40])

参考

[1] alignn代码

猜你喜欢

转载自blog.csdn.net/qq_49323609/article/details/126424980
rbf