最近在看一些论文的源码过程中,在循环窗口以为过程中,出现了 torch.roll() 这个函数,roll的英文含义为翻滚的意思,在torch中主要表示在对应维度上进行循环移位,其实类似于数据的扩增方式了,由于平时在其他的代码中不常见,这里记录下改代码的具体操作。
torch.roll()
PyTorch官网API点击前往
这里需要关注的是roll函数中的shifts 和 dims 两者的维度必须是一样的,即 shifts中主要是指定对应维度移动多少个像素,dims指定在哪个维度上进行移动。
实例代码
# -*- coding:utf-8 -*-
import torch
import numpy as np
import matplotlib.pyplot as plt
shift_size = 3
'''构造多维张量'''
x=np.arange(301056).reshape(1,56,56,96)
x=torch.from_numpy(x)
if shift_size > 0:
shifted_x = torch.roll(x, shifts=(-20, -20), dims=(1, 2))
#shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
print("---------经过循环位移了---------")
else:
shifted_x = x
'''可视化部分'''
plt.figure(figsize=(16,8))
plt.subplot(1,2,1)
plt.imshow(x[0,:,:,0])
plt.title("orgin_img")
plt.subplot(1,2,2)
plt.imshow(shifted_x[0,:,:,0])
if torch.equal(shifted_x, x):
plt.title("non_shifted")
else:
plt.title("shifted_img")
plt.show()
plt.pause(5)
plt.close()
这幅图是 shifts = (-20,-20) 在 dim(1,2)即 H,W两个维度上进行循环移动的结果,关注移位图的位置,我们不难发现,当 shifts 值为负时,相当于将图像向上,向左进行移动,即将原来最上方和最左端的像素,移动至底部和右端。大家可以进一步尝试将 shifts 设置为不同的值,进行观察。