Building a three-dimensional one-hot tensor in PyTorch

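import torch

batch_size = 2
sequence_len = 3
hidden_dim = 5

# index has shape (batch_size, sequence_len, 1): one class id per position.
# scatter_ along the last dim writes `value` at x[i][j][index[i][j][0]].
x = torch.zeros(batch_size, sequence_len, hidden_dim).scatter_(dim=-1,
                               index=torch.tensor([[[2],[2],[1]],[[1],[2],[3]]], dtype=torch.long),
                               value=1)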

print(x)

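# Same call with value=2: the written value does not have to be 1.
# Row [0][2] now has index 0, so column 0 is filled there.
x = torch.zeros(batch_size, sequence_len, hidden_dim).scatter_(dim=-1,
                               index=torch.tensor([[[2],[2],[0]],[[1],[2],[3]]], dtype=torch.long),
                               value=2)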

print(x)

# Row [0][1] now has index 3, so column 3 is filled there.
x = torch.zeros(batch_size, sequence_len, hidden_dim).scatter_(dim=-1,
                               index=torch.tensor([[[2],[3],[1]],[[1],[2],[3]]], dtype=torch.long),
                               value=2)

print(x)
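For readers unfamiliar with scatter_, the sketch below (written for illustration, not part of the original post) spells out the rule it applies in these examples: for a 3-D tensor and dim=-1, each entry index[i][j][k] selects a column, and out[i][j][index[i][j][k]] is set to value. It rebuilds the first example with an explicit loop and checks it against the scatter_ result; the names index and ref are illustrative.

```python
import torch

# Same sizes and first index tensor as in the example above.
batch_size, sequence_len, hidden_dim = 2, 3, 5
index = torch.tensor([[[2], [2], [1]], [[1], [2], [3]]], dtype=torch.long)

# Vectorized version: writes 1 at position index[i][j][0] along the last dim.
x = torch.zeros(batch_size, sequence_len, hidden_dim).scatter_(dim=-1, index=index, value=1)

# Explicit-loop reference of the rule scatter_ applies for a 3-D tensor and dim=-1:
#   out[i][j][index[i][j][k]] = value
ref = torch.zeros(batch_size, sequence_len, hidden_dim)
for i in range(index.size(0)):
    for j in range(index.size(1)):
        for k in range(index.size(2)):
            ref[i, j, index[i, j, k].item()] = 1

print(torch.equal(x, ref))  # True
```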

Printed output:

tensor([[[0., 0., 1., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 1., 0., 0., 0.]],

        [[0., 1., 0., 0., 0.],
         [0., 0., 1., 0., 0.],
         [0., 0., 0., 1., 0.]]])
         
         
tensor([[[0., 0., 2., 0., 0.],
         [0., 0., 2., 0., 0.],
         [2., 0., 0., 0., 0.]],

        [[0., 2., 0., 0., 0.],
         [0., 0., 2., 0., 0.],
         [0., 0., 0., 2., 0.]]])


tensor([[[0., 0., 2., 0., 0.],
         [0., 0., 0., 2., 0.],
         [0., 2., 0., 0., 0.]],

        [[0., 2., 0., 0., 0.],
         [0., 0., 2., 0., 0.],
         [0., 0., 0., 2., 0.]]])
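An alternative worth knowing, not used in the original post: torch.nn.functional.one_hot builds the same tensor from a 2-D tensor of class ids with shape (batch_size, sequence_len). The sketch below assumes the first example's indices (the name ids is illustrative) and compares both approaches. one_hot returns int64 values, so .float() is applied to match the scatter_ output.

```python
import torch
import torch.nn.functional as F

batch_size, sequence_len, hidden_dim = 2, 3, 5
# Class ids with shape (batch_size, sequence_len); no trailing singleton dim needed.
ids = torch.tensor([[2, 2, 1], [1, 2, 3]], dtype=torch.long)

# one_hot appends a last dimension of size num_classes and returns int64 values.
x_onehot = F.one_hot(ids, num_classes=hidden_dim).float()

# Same tensor via scatter_: unsqueeze(-1) adds the trailing dim scatter_ needs.
x_scatter = torch.zeros(batch_size, sequence_len, hidden_dim).scatter_(
    dim=-1, index=ids.unsqueeze(-1), value=1)

print(x_onehot.shape)                    # torch.Size([2, 3, 5])
print(torch.equal(x_onehot, x_scatter))  # True

# A fill value other than 1 (as in the 2nd and 3rd examples) is a simple scaling.
x_twos = 2 * x_onehot
```

scatter_ writes an arbitrary fill value directly, while one_hot is shorter when a plain 0/1 encoding is all that is needed.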
Reposted from blog.csdn.net/guotong1988/article/details/102541546