在 PyTorch 中,使用 None
作为张量索引的目的是为张量增加一个新的维度。具体来说,None
索引会创建一个大小为1的新维度,而不会改变原始张量中的任何数据。
假设 hidden_states
是一个形状为 (A, B)
的二维张量。当你使用 hidden_states[None, :]
进行索引时,你将得到一个形状为 (1, A, B)
的三维张量,其中原始数据保持不变,并在最前面增加了一个新维度。
以下是一个具体示例:
import torch
# 创建一个形状为 (3, 4) 的二维张量
hidden_states = torch.randn(3, 4)
print("原始张量:")
print(hidden_states)
print("原始张量形状:", hidden_states.shape)
# 使用 None 索引增加一个维度
new_hidden_states = hidden_states[None, :]
print("增加维度后的张量:")
print(new_hidden_states)
print("增加维度后的张量形状:", new_hidden_states.shape)
输出:
原始张量:
tensor([[ 0.8860, 0.2456, 0.2150, -0.4519],
[ 0.8445, -0.5865, -0.0738, 0.1211],
[-0.4374, 0.3314, -0.0214, 0.4563]])
原始张量形状: torch.Size([3, 4])
增加维度后的张量:
tensor([[[ 0.8860, 0.2456, 0.2150, -0.4519],
[ 0.8445, -0.5865, -0.0738, 0.1211],
[-0.4374, 0.3314, -0.0214, 0.4563]]])
增加维度后的张量形状: torch.Size([1, 3, 4])
如你所见,原始张量的形状从 (3, 4)
变为 (1, 3, 4)
,在最前面增加了一个新维度。这在需要调整张量维度以匹配其他张量形状的情况下非常有用。例如,当你需要将张量输入到期望三维输入的神经网络层时。