前言
最近我在看一些开源项目代码的时候经常看见这样一个函数rearrange, 来进行维度转换,而不是使用permute。虽然有些时候permute 可以与rearrange替换, 但是可读性不如后者。
写这篇文章的时候看见这篇博文 非科班的他,凭借什么拿到了DeepMind的offer?, 这个故事告诉我们坚持不懈学习,并且多输出是有作用滴。
转换channel
import torch
from einops import rearrange
# H, W, C
a = torch.randn(2, 2, 3)
# H, W, C -> C, H, W
a_permute = a.permute(2, 0, 1)
print('a_permute.shape: ', a_permute.shape)
a_rearrange = rearrange(a, 'h w c -> c h w')
print('a_rearrange.shape:', a_rearrange.shape)
print('逐元素进行判断是否相等: ', a_permute.equal(a_rearrange))
a_permute.shape: torch.Size([3, 2, 2])
a_rearrange.shape: torch.Size([3, 2, 2])
逐元素进行判断是否相等: True
维度合并
import torch
from einops import rearrange
# B, C, H, W
a = torch.arange(9 * 2 * 2).view(1, 9, 2, 2)
# print(a)
b = rearrange(a, 'b c h w -> b c (h w)')
print(b.shape)
torch.Size([1, 9, 4])
高级用法
这里其实好像还和pixelshuffle结果不太一样,虽然维度是一样的。有空再去倒腾倒腾…(todo)
import torch
from einops import rearrange
# B, C, H, W
a = torch.arange(36).view(1, 9, 2, 2)
# print(a)
# 建议在 torch1.12.x 测试 PixleShuffle这个类
# ps = torch.nn.PixleShuffle(3)
b = rearrange(a, 'b (c h1 w2) h w -> b c (h1 h) (w2 w)', h1=3, w2=3)
# print(b)
# b_ps = ps(a)
# print('b.equal(b_ps): ', b.equal(b_ps))
c = rearrange(b, 'b c (h1 h) (w2 w) -> b (c h1 w2) h w', h1=3, w2=3)
print('a.equal(c):', a.equal(c))
a.equal©: True