import numpy as np
import torch
import torch.nn.functional as F

def entropy(self, logits, targets):
    # Cross-entropy between the soft target distribution `targets` and the
    # model distribution softmax(logits). The original file defined this
    # method twice; the second copy took log(logits) directly, which is
    # numerically wrong for unnormalized logits and shadowed the correct
    # version, so a single log_softmax-based definition is kept.
    log_q = F.log_softmax(logits, dim=-1)
    return -torch.mean(torch.sum(targets * log_q, dim=-1))
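# `indices_to_one_hot` is called below but not shown in this excerpt; this is
# a minimal sketch of a plausible implementation inferred from the call site
# (integer ids -> one-hot rows), not the original definition.

def indices_to_one_hot(self, id_, classes_):
    # Sketch: map a batch of integer class ids to a float one-hot matrix of
    # shape (batch, classes_), assuming `id_` is a LongTensor of class ids.
    id_ = id_.view(-1, 1)
    one_hot = torch.zeros(id_.size(0), classes_, device=id_.device)
    return one_hot.scatter_(1, id_, 1.0)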
def get_encoder_loss(self, id_, prob_, classes_, cat_lambda, kl_lambda, encoder_type):
    # One-hot target over the reference classes (e.g. speaker/style ids).
    cat_target = self.indices_to_one_hot(id_, classes_)
    if encoder_type in ('gst', 'x-vector') and cat_lambda != 0.0:
        # Categorical term only; np.log(0.1) is the log-probability of a
        # uniform 10-class distribution, used as a constant offset.
        loss = cat_lambda * (self.entropy(prob_, cat_target) - np.log(0.1))
    elif encoder_type in ('vae', 'gst_vae') and (cat_lambda != 0.0 or kl_lambda != 0.0):
        # Categorical term on the class logits (prob_[2]) plus the KL term on
        # the posterior statistics (prob_[0], prob_[1]).
        loss = cat_lambda * (self.entropy(prob_[2], cat_target) - np.log(0.1)) \
            + kl_lambda * self.KL_loss(prob_[0], prob_[1])
    elif encoder_type == 'gmvae' and (cat_lambda != 0.0 or kl_lambda != 0.0):
        # Gaussian-mixture KL term on prob_[0:5] plus the categorical term
        # on the class logits in prob_[5].
        loss = kl_lambda * self.gaussian_loss(prob_[0], prob_[1], prob_[2], prob_[3], prob_[4]) \
            + cat_lambda * (self.entropy(prob_[5], cat_target) - np.log(0.1))
    else:
        # Encoder types without an auxiliary loss, or both weights zero.
        loss = 0.0
    return loss
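# `KL_loss` and `gaussian_loss` are also defined elsewhere; the sketches
# below assume the standard VAE / GMVAE formulations and the argument orders
# implied by the calls in get_encoder_loss, and are not the original code.

def KL_loss(self, mu, logvar):
    # Sketch: closed-form KL divergence between the diagonal Gaussian
    # q(z|x) = N(mu, exp(logvar)) and the standard normal prior,
    # averaged over the batch.
    return torch.mean(-0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=-1))

def gaussian_loss(self, z, mu, logvar, mu_prior, logvar_prior):
    # Sketch: single-sample estimate of E_q[log q(z|x) - log p(z|c)] for a
    # GMVAE, where p(z|c) = N(mu_prior, exp(logvar_prior)) is the mixture
    # component selected by the categorical branch. The log(2*pi) constants
    # cancel in the difference and are omitted.
    log_q = -0.5 * torch.sum(logvar + (z - mu).pow(2) / logvar.exp(), dim=-1)
    log_p = -0.5 * torch.sum(logvar_prior + (z - mu_prior).pow(2) / logvar_prior.exp(), dim=-1)
    return torch.mean(log_q - log_p)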