pytorch masked_fill

Copyright notice: I'm 小仙女 ("Little Fairy"); please let Little Fairy know if you repost this. https://blog.csdn.net/qq_40210472/article/details/88826821
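import torch
import torch.nn.functional as F
import numpy as np

a = torch.Tensor([1, 2, 3, 4])
# Positions where the mask is True are replaced with value (-inf here).
# Use a bool mask: uint8 masks such as torch.ByteTensor([1,1,0,0]) are deprecated and may raise an error in recent PyTorch releases.
a = a.masked_fill(mask=torch.tensor([True, True, False, False]), value=-np.inf)
print(a)

# exp(-inf) = 0, so masked positions get zero probability after softmax.
b = F.softmax(a, dim=0)
print(b)

tensor([-inf, -inf, 3., 4.])
tensor([0.0000, 0.0000, 0.2689, 0.7311])

Calling F.softmax(a) without dim triggers "UserWarning: Implicit dimension choice for softmax has been deprecated. Change the call to include dim=X as an argument.", so the dimension is passed explicitly (dim=0 for this 1-D tensor).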
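The same masked_fill + softmax pattern is commonly used to hide padded tokens in attention scores. The following is a minimal sketch, assuming a hypothetical batch of 2 sequences with 3 query positions and 4 key positions; the names scores, lengths and pad_mask are illustrative and not from the original post. The mask has shape (batch, 1, key_len) and is broadcast over the query dimension by masked_fill.

```python
import torch
import torch.nn.functional as F

# Hypothetical attention scores: batch of 2, 3 query positions, 4 key positions.
torch.manual_seed(0)
scores = torch.randn(2, 3, 4)

# Hypothetical sequence lengths: sequence 0 has 4 real tokens, sequence 1 only 2.
lengths = torch.tensor([4, 2])

# Boolean padding mask of shape (batch, key_len): True marks padded key positions.
key_positions = torch.arange(scores.size(-1))                   # [0, 1, 2, 3]
pad_mask = key_positions.unsqueeze(0) >= lengths.unsqueeze(1)   # shape (2, 4)

# unsqueeze(1) gives shape (2, 1, 4), which broadcasts over the 3 query rows.
masked_scores = scores.masked_fill(pad_mask.unsqueeze(1), float("-inf"))

# Softmax over the key dimension: padded keys receive exactly zero weight.
weights = F.softmax(masked_scores, dim=-1)
print(weights)
```

One caveat with this approach: if every key in a row is masked, the whole row is -inf and softmax returns NaN for it, so each query should keep at least one unmasked key.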
