ziming-liu
2/23/2020 - 5:41 AM

Usage of PyTorch index_copy_

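import torch
a = torch.randn(3, 5)    # destination tensor, shape (3, 5)
# source must have the same number of dimensions as a: shape (3, len(index))
c = torch.zeros(3, 1)
# overwrite column 3 of a with c, in place
a.index_copy_(dim=1, index=torch.tensor([3]), source=c)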
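
To make the shape rule easier to see, here is a small sketch that copies two columns at once. The tensor values and the names `src`, `index`, and `b` are made up for illustration and are not from the original post. The source is expected to match the destination in every dimension except `dim`, where its size should equal `len(index)`.

```python
import torch

# Destination filled with ones so the copied columns stand out.
a = torch.ones(3, 5)

# Illustrative source: shape (3, 2), because index has 2 entries along dim=1.
src = torch.tensor([[10., 20.],
                    [30., 40.],
                    [50., 60.]])
index = torch.tensor([3, 0])

# Column j of src is written into column index[j] of a, in place:
# src[:, 0] goes to a[:, 3], src[:, 1] goes to a[:, 0].
a.index_copy_(1, index, src)

# Out-of-place variant: returns a new tensor and leaves its input untouched.
b = torch.ones(3, 5).index_copy(1, index, src)

print(a)
```

The trailing underscore marks the in-place version; `index_copy` without it is handy when the original tensor still needs to be kept.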