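import torch

# Average over 4x4 windows in H and W, leaving the depth dimension untouched.
# AvgPool3d has no parameters, so .cuda() on the module is unnecessary here;
# the inputs below are CPU tensors.
average_pool = torch.nn.AvgPool3d(kernel_size=(1, 4, 4), stride=(1, 4, 4))

# A 4-D tensor is read as unbatched (C, D, H, W) = (1, 2, 7, 7).
a = torch.arange(98, dtype=torch.float32).reshape(1, 2, 7, 7)
img1 = average_pool(a)  # 7x7 -> 1x1: the last 3 rows/cols don't fill a window
print(img1)

b = torch.arange(128, dtype=torch.float32).reshape(1, 2, 8, 8)
img2 = average_pool(b)  # 8x8 -> 2x2
print(img2)

c = torch.arange(162, dtype=torch.float32).reshape(1, 2, 9, 9)
img3 = average_pool(c)  # 9x9 -> 2x2: the 9th row/col is dropped
print(img3)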
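The shrinking spatial size follows from the pooling arithmetic. With the defaults used above (padding=0, ceil_mode=False), each pooled dimension has length floor((L − k) / s) + 1. With k = s = 4, a side of 7 gives 1, while sides of 8 and 9 both give 2. Rows and columns that don't fill a whole window are discarded, and each output element is the plain mean of its 4x4 block. The sketch below checks both points against AvgPool3d. `pooled_len` and `block_mean` are hypothetical helpers written only for this illustration, and the sketch assumes a PyTorch version that accepts unbatched 4-D input to AvgPool3d.

```python
import torch

pool = torch.nn.AvgPool3d(kernel_size=(1, 4, 4), stride=(1, 4, 4))


def pooled_len(length: int, kernel: int = 4, stride: int = 4) -> int:
    # Hypothetical helper: output length of one dimension with padding=0,
    # dilation=1 and ceil_mode=False (floor division drops partial windows).
    return (length - kernel) // stride + 1


def block_mean(x: torch.Tensor, k: int = 4) -> torch.Tensor:
    # Hypothetical helper: x is (C, D, H, W). Crop H and W to a multiple of k
    # (the partial windows that pooling discards), then average each k x k block.
    C, D, H, W = x.shape
    ho, wo = H // k, W // k
    x = x[..., : ho * k, : wo * k]
    return x.reshape(C, D, ho, k, wo, k).mean(dim=(3, 5))


for side in (7, 8, 9):
    x = torch.arange(2 * side * side, dtype=torch.float32).reshape(1, 2, side, side)
    y = pool(x)
    # Shape matches the floor formula, values match the explicit block means.
    assert tuple(y.shape) == (1, 2, pooled_len(side), pooled_len(side))
    assert torch.allclose(y, block_mean(x))
    print(side, tuple(y.shape))
```

For the 9x9 input, for example, the first window averages rows and columns 0 to 3 of the values 9·r + c, giving 9·1.5 + 1.5 = 15.0. The ninth row and column never contribute to any output.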
Printed output: