chenwydj
3/2/2020 - 10:34 PM

MMD (Maximum Mean Discrepancy)

Code

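    import torch

    def mmd_rbf(x, y, alpha=1.0):
        # Unbiased MMD^2 estimate with Gaussian kernel k(a, b) = exp(-alpha * ||a - b||^2).
        # x and y must share the same batch size B >= 2; non-batch dims are flattened.
        B = x.size(0)
        x = x.reshape(B, -1)  # e.g. (B, C, H, W) -> (B, C*H*W)
        y = y.reshape(B, -1)

        # Gram matrices of inner products
        xx, yy, zz = torch.mm(x, x.t()), torch.mm(y, y.t()), torch.mm(x, y.t())

        # Squared norms broadcast along rows: rx[i, j] = ||x_j||^2
        rx = xx.diag().unsqueeze(0).expand_as(xx)
        ry = yy.diag().unsqueeze(0).expand_as(yy)

        # ||a - b||^2 = ||a||^2 + ||b||^2 - 2<a, b>
        K = torch.exp(-alpha * (rx.t() + rx - 2 * xx))
        L = torch.exp(-alpha * (ry.t() + ry - 2 * yy))
        P = torch.exp(-alpha * (rx.t() + ry - 2 * zz))

        beta = 1.0 / (B * (B - 1))
        gamma = 2.0 / (B * B)

        # Exclude the i == j terms (each is exp(0) = 1) from the within-sample sums
        return beta * (K.sum() - K.trace() + L.sum() - L.trace()) - gamma * P.sum()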
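As a cross-check, the unbiased estimator targeted by the matrix code can be written term by term with explicit loops: within-sample kernel sums over i != j scaled by 1/(B(B-1)), minus the cross-sample sum over all (i, j) scaled by 2/B^2. This is a minimal NumPy sketch; the names `gaussian_kernel` and `mmd2_unbiased_loops` and the `alpha` default are illustrative assumptions, and it is far slower than the matrix version.

```python
import numpy as np

def gaussian_kernel(a, b, alpha):
    # k(a, b) = exp(-alpha * ||a - b||^2)
    d = a - b
    return np.exp(-alpha * np.dot(d, d))

def mmd2_unbiased_loops(x, y, alpha=1.0):
    # x, y: arrays of shape (B, D) with the same B >= 2
    B = x.shape[0]
    # within-sample sums skip the diagonal (i == j)
    kxx = sum(gaussian_kernel(x[i], x[j], alpha)
              for i in range(B) for j in range(B) if i != j)
    kyy = sum(gaussian_kernel(y[i], y[j], alpha)
              for i in range(B) for j in range(B) if i != j)
    # cross-sample sum uses every pair
    kxy = sum(gaussian_kernel(x[i], y[j], alpha)
              for i in range(B) for j in range(B))
    return (kxx + kyy) / (B * (B - 1)) - 2.0 * kxy / (B * B)

# Two identical 1-D samples {0, 1}: the estimate equals exp(-1) - 1, i.e. it is negative
x = np.array([[0.0], [1.0]])
print(mmd2_unbiased_loops(x, x.copy()))
```

The example shows that the unbiased estimate can be negative when both samples come from the same distribution. The biased version (keeping the i == j terms and dividing every sum by B*B) equals the squared distance between the two empirical mean embeddings and is therefore never negative.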

Definition
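For reference, MMD is the distance between the mean embeddings of P and Q in a reproducing kernel Hilbert space with feature map \phi and kernel k. Expanding the square gives three kernel expectations, and replacing them with sample averages over B samples from each distribution gives the unbiased estimator; here the Gaussian kernel with parameter \alpha is assumed, matching the code, where K, L, P are the three kernel matrices and beta = 1/(B(B-1)), gamma = 2/B^2 are the coefficients.

$$
\mathrm{MMD}(P, Q) = \big\| \mathbb{E}_{x \sim P}[\phi(x)] - \mathbb{E}_{y \sim Q}[\phi(y)] \big\|_{\mathcal{H}}
$$

$$
\mathrm{MMD}^2(P, Q) = \mathbb{E}_{x, x' \sim P}[k(x, x')] + \mathbb{E}_{y, y' \sim Q}[k(y, y')] - 2\,\mathbb{E}_{x \sim P,\, y \sim Q}[k(x, y)]
$$

$$
\widehat{\mathrm{MMD}}^2_u = \frac{1}{B(B-1)} \sum_{i \neq j} \big[ k(x_i, x_j) + k(y_i, y_j) \big] - \frac{2}{B^2} \sum_{i, j} k(x_i, y_j),
\qquad k(a, b) = \exp\!\big(-\alpha \|a - b\|^2\big)
$$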

KL divergence does not measure a distance; it measures a kind of information loss: https://www.zhihu.com/question/265417875
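One quick way to see that KL divergence is not a distance is that it is not symmetric: KL(p || q) and KL(q || p) generally differ. The minimal sketch below assumes strictly positive discrete distributions; the function name `kl_divergence` and the two example distributions are illustrative.

```python
import numpy as np

def kl_divergence(p, q):
    # KL(p || q) = sum_i p_i * log(p_i / q_i); assumes every p_i and q_i is > 0
    p = np.asarray(p, dtype=float)
    q = np.asarray(q, dtype=float)
    return float(np.sum(p * np.log(p / q)))

p = [0.9, 0.1]
q = [0.5, 0.5]
# The two directions give different values, so KL is not symmetric
print(kl_divergence(p, q), kl_divergence(q, p))
```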

https://stats.stackexchange.com/questions/276497/maximum-mean-discrepancy-distance-distribution