# RGB2YUV
import torch
def rgb2yuv(rgb):
    """Convert a channel-first RGB tensor to YUV.

    Applies the linear RGB -> YUV transform whose coefficients come from
    Wikipedia's YUV article (Y uses the luma weights 0.299 / 0.587 / 0.114).

    Args:
        rgb: tensor of shape ``(3, ...)`` with the R, G, B channels along
            dim 0, e.g. ``(3, H, W)``. Any number of trailing dims is
            accepted. Non-floating inputs (e.g. ``uint8`` images) are first
            promoted to ``torch.get_default_dtype()``.

    Returns:
        Tensor of the same shape with the Y, U, V channels along dim 0, on the
        same device as ``rgb`` and, for floating inputs, with the same dtype.
        The transform is linear, so the output keeps the input's value scale.
    """
    if not rgb.is_floating_point():
        # tensordot requires matching dtypes; integer pixels would otherwise fail.
        rgb = rgb.to(torch.get_default_dtype())
    # A[i, j] is the weight of input channel i (R, G, B) in output channel j (Y, U, V).
    # Build it with the input's dtype/device so float64 or GPU tensors also work.
    A = torch.tensor(
        [
            [0.299, -0.14714119, 0.61497538],
            [0.587, -0.28886916, -0.51496512],
            [0.114, 0.43601035, -0.10001026],
        ],
        dtype=rgb.dtype,
        device=rgb.device,
    )  # from Wikipedia
    # Contract A's input-channel dim with rgb's channel dim:
    # out[j, ...] = sum_i A[i, j] * rgb[i, ...]; result keeps rgb's trailing shape.
    return torch.tensordot(A, rgb, dims=([0], [0]))
# YUV2RGB
import torch
def yuv2rgb(yuv):
    """Convert a channel-first YUV tensor back to RGB.

    Applies the linear YUV -> RGB transform whose coefficients come from
    Wikipedia's YUV article:
    R = Y + 1.13983 V, G = Y - 0.39465 U - 0.58060 V, B = Y + 2.03211 U.

    Args:
        yuv: tensor of shape ``(3, ...)`` with the Y, U, V channels along
            dim 0, e.g. ``(3, H, W)``. Any number of trailing dims is
            accepted. Non-floating inputs are first promoted to
            ``torch.get_default_dtype()``.

    Returns:
        Tensor of the same shape with the R, G, B channels along dim 0, on the
        same device as ``yuv`` and, for floating inputs, with the same dtype.
        No clamping is applied, so values may fall outside the original range.
    """
    if not yuv.is_floating_point():
        # tensordot requires matching dtypes; integer inputs would otherwise fail.
        yuv = yuv.to(torch.get_default_dtype())
    # A[i, j] is the weight of input channel i (Y, U, V) in output channel j (R, G, B).
    # Build it with the input's dtype/device so float64 or GPU tensors also work.
    A = torch.tensor(
        [
            [1.0, 1.0, 1.0],
            [0.0, -0.39465, 2.03211],
            [1.13983, -0.58060, 0.0],
        ],
        dtype=yuv.dtype,
        device=yuv.device,
    )  # from Wikipedia
    # Contract A's input-channel dim with yuv's channel dim:
    # out[j, ...] = sum_i A[i, j] * yuv[i, ...]; result keeps yuv's trailing shape.
    return torch.tensordot(A, yuv, dims=([0], [0]))