class invPixelShuffle(nn.Module):
    """Inverse pixel shuffle (space-to-depth) module.

    Rearranges a ``(b, ch, y, x)`` tensor into
    ``(b, ch * ratio**2, y // ratio, x // ratio)`` by moving every
    non-overlapping ``ratio x ratio`` spatial block into the channel axis.
    Output channel ``c * ratio**2 + i * ratio + j`` holds, for input channel
    ``c``, the pixel at row offset ``i`` and column offset ``j`` inside each
    block. This is the same layout as ``torch.nn.functional.pixel_unshuffle``.

    Args:
        ratio: Downscaling factor applied to both spatial axes (default 2).
            Must be a positive integer.

    Raises:
        ValueError: If ``ratio`` is smaller than 1.
    """

    def __init__(self, ratio=2):
        super(invPixelShuffle, self).__init__()
        # Reject non-positive ratios up front; 0 would otherwise surface as a
        # ZeroDivisionError in forward, and negatives as an obscure view error.
        if ratio < 1:
            raise ValueError('ratio must be >= 1, got {}'.format(ratio))
        self.ratio = ratio

    def forward(self, tensor):
        """Apply space-to-depth to a 4D tensor.

        Args:
            tensor: Tensor indexed as ``(b, ch, y, x)``.

        Returns:
            Tensor of shape ``(b, ch * ratio**2, y // ratio, x // ratio)``.

        Raises:
            ValueError: If ``tensor`` is not 4D, or if its spatial sizes are
                not divisible by ``ratio``.
        """
        if tensor.dim() != 4:
            raise ValueError(
                'expected a 4D (b, ch, y, x) tensor, got {}D'.format(tensor.dim()))
        ratio = self.ratio
        b, ch, y, x = tensor.shape
        # Explicit check instead of `assert`, which is removed under `python -O`.
        if x % ratio != 0 or y % ratio != 0:
            raise ValueError('x, y, ratio : {}, {}, {}'.format(x, y, ratio))
        # Split each spatial axis into (block index, offset within block).
        # Splitting existing dims is always view-compatible, even for
        # non-contiguous inputs.
        blocks = tensor.view(b, ch, y // ratio, ratio, x // ratio, ratio)
        # Move the two offset axes next to the channel axis:
        # (b, ch, row_offset, col_offset, y // ratio, x // ratio).
        blocks = blocks.permute(0, 1, 3, 5, 2, 4).contiguous()
        # Merge (ch, row_offset, col_offset) into a single channel axis.
        return blocks.view(b, -1, y // ratio, x // ratio)