every blog every motto: You can do more than you think.
https://blog.csdn.net/weixin_39190382?type=blog
0. 前言
总结torch常用的函数。
包括如下:
torch.FloatTensor torch.max torch.numel torch.sort torch.sort torch.clamp torch.nonzero torch.cat torch.statck
说明: 后续持续更新中…
1. 正文
1.1 torch.FloatTensor()
- 当传入的是数组形式时
里面的数据类型会转换成浮点型
a = torch.FloatTensor([1,2,3,4])
2. 当传入的是数字时
生成数字对应维度的数组(数组内数值为0)
b = torch.FloatTensor(4)
注意: 不同版本的torch生成的数字可能不同,有可能生成一个很大的数。
1.2 torch.max()
- 当只传入Tensor时,返回里面最大的数
a = torch.Tensor([5,6,9])
torch.max(a)
2. 当传入Tensor,和维度时(在哪一维比较大小)
返回最大的数和索引
a = torch.Tensor([[5,6,9],
[2,3,1]])
torch.max(a,dim=1)
如下图所示,返回的时最大值的9和3,当然这是在dim=1这个维度进行比较的,即,每一行里面比较选择一个最大的。
第一行的数据为 【5,6,9】,最大的9,索引为2
第二行数据为【2,3,1】,最大的为3,索引为1
可对数字和索引进行拆分,如下图所示:
1.3 torch.numel()
返回tensor中有多少个数
a = torch.Tensor([2,3,4,5,1,2,3])
a.numel()
a = torch.Tensor([[5,6,9],[2,3,1]])
a.numel()
1.4 torch.sort(0, descending=True)
对指定维度进行排序,True表示降序,False表示升序
返回排序后的tensor和数字在原tensor中对应的索引
a = torch.Tensor([2,3,4,5,1,2,3])
a.sort(0,True)
和上面的max类似,可进行拆分,如下图所示
1. 5 torch.clamp()
将tensor数值进行修改
- 当数值小于min时,数值改为min
- 当数值大于max时,数值改为max
a = torch.Tensor([2,3,4,5])
a.clamp(min=3)
a = torch.Tensor([2,3,4,5])
a.clamp(max=3)
1.6 torch.nonzero()
返回非零元素的索引
一维:
a = torch.Tensor([0, 1, 2, 3, 0, 5])
a.nonzero()
返回tensor中非零元素的索引:
二维:
a = torch.Tensor([
[0, 1, 2, 3, 0, 5],
[0, 0, 0, 0, 0, 10] ])
a.nonzero()
非零元素的索引:
如,a[0,1],a[0,2] …等,其值时非零的。
1.7 torch.cat()
沿着指定维度进行合并,默认维度dim=0
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.cat((a, b))
各tensor的shape:
1.8 torch.stack()
沿着指定维度合并,但是会多出一维
a = torch.rand((2, 3))
b = torch.rand((2, 3))
c = torch.stack((a, b))