The commonly used data types in PyTorch are listed below:
Data type | dtype | CPU tensor | GPU tensor | Size (bytes) |
---|---|---|---|---|
32-bit floating point | torch.float32 or torch.float | torch.FloatTensor | torch.cuda.FloatTensor | 4 |
64-bit floating point | torch.float64 or torch.double | torch.DoubleTensor | torch.cuda.DoubleTensor | 8 |
16-bit floating point | torch.float16 or torch.half | torch.HalfTensor | torch.cuda.HalfTensor | 2 |
8-bit integer (unsigned) | torch.uint8 | torch.ByteTensor | torch.cuda.ByteTensor | 1 |
8-bit integer (signed) | torch.int8 | torch.CharTensor | torch.cuda.CharTensor | 1 |
16-bit integer (signed) | torch.int16 or torch.short | torch.ShortTensor | torch.cuda.ShortTensor | 2 |
32-bit integer (signed) | torch.int32 or torch.int | torch.IntTensor | torch.cuda.IntTensor | 4 |
64-bit integer (signed) | torch.int64 or torch.long | torch.LongTensor | torch.cuda.LongTensor | 8 |
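The table can be checked directly in code. The sketch below creates a one-element tensor of each dtype and reads back its legacy type name with `Tensor.type()` and its per-element size with `Tensor.element_size()`. The helper `describe` and the `TABLE` list are hypothetical names used only for illustration, and the GPU line runs only if CUDA is available.

```python
import torch

# (dtype, legacy CPU tensor type name, bytes per element), copied from the table
TABLE = [
    (torch.float32, "torch.FloatTensor", 4),
    (torch.float64, "torch.DoubleTensor", 8),
    (torch.float16, "torch.HalfTensor", 2),
    (torch.uint8, "torch.ByteTensor", 1),
    (torch.int8, "torch.CharTensor", 1),
    (torch.int16, "torch.ShortTensor", 2),
    (torch.int32, "torch.IntTensor", 4),
    (torch.int64, "torch.LongTensor", 8),
]


def describe(dtype, device="cpu"):
    """Hypothetical helper: return (legacy type string, bytes per element)."""
    t = torch.zeros(1, dtype=dtype, device=device)
    return t.type(), t.element_size()


# The aliases in the dtype column refer to the same dtypes
assert torch.float == torch.float32 and torch.double == torch.float64
assert torch.half == torch.float16 and torch.short == torch.int16
assert torch.int == torch.int32 and torch.long == torch.int64

for dtype, name, size in TABLE:
    print(dtype, *describe(dtype))

# GPU tensors report the torch.cuda.* type names (only checked when a GPU exists)
if torch.cuda.is_available():
    print(describe(torch.float32, "cuda"))  # ('torch.cuda.FloatTensor', 4)
```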
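Each of the PyTorch data types above corresponds to a NumPy dtype of the same name (for example, torch.float32 corresponds to numpy.float32), and both occupy the same number of bytes per element.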
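A minimal sketch of this correspondence, assuming NumPy is installed: `torch.from_numpy` keeps the array's dtype, so the element sizes can be compared side by side. The helper name `numpy_to_torch_dtype` is hypothetical.

```python
import numpy as np
import torch


def numpy_to_torch_dtype(np_dtype):
    """Hypothetical helper: wrap a NumPy array and report both dtypes' sizes."""
    arr = np.zeros(3, dtype=np_dtype)
    t = torch.from_numpy(arr)  # no copy: t shares memory with arr
    return t.dtype, arr.itemsize, t.element_size()


for np_dtype in [np.float32, np.float64, np.float16, np.uint8,
                 np.int8, np.int16, np.int32, np.int64]:
    torch_dtype, np_size, torch_size = numpy_to_torch_dtype(np_dtype)
    print(np.dtype(np_dtype).name, "->", torch_dtype, np_size, torch_size)

# Because memory is shared, changing the array also changes the tensor
arr = np.arange(3, dtype=np.int32)
t = torch.from_numpy(arr)
arr[0] = 100
print(t)  # first element is now 100
```

Going the other way, `t.numpy()` returns an array with the matching NumPy dtype (for CPU tensors).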
Data type conversion
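A short sketch of common ways to convert a tensor's dtype. It uses `Tensor.to(dtype)`, the shorthand methods such as `.int()` and `.double()`, and `Tensor.type(dtype)`. Each returns a new tensor and leaves the original unchanged. The sample values are arbitrary.

```python
import torch

x = torch.tensor([1.7, -1.7, 3.2])  # default floating dtype: torch.float32

a = x.to(torch.int32)      # general form
b = x.int()                # shorthand; also .float(), .double(), .half(),
                           # .long(), .short(), .byte(), .char()
c = x.type(torch.float64)  # .type() also accepts a dtype
d = x.to(torch.float16)

print(a, a.dtype)  # float -> int truncates toward zero: 1, -1, 3
print(c.dtype, d.dtype)

# If the tensor already has the requested dtype, .to() returns it unchanged
same = x.to(torch.float32)
```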