
EmilienDupont commented on July 16, 2024

Hi Xavier, thanks for your message and sorry it took so long to get back to you! I haven't included the experiment code in the repo because I haven't cleaned it up yet. At some point in the future when I have more time I'll probably clean it up properly and include it in the repo 😃
For now, here is the uncleaned code; if anything doesn't work, let me know and I'll try to help 😁 Hope that helps!

import numpy as np
import torch


def metric_model(model, data_loader, batch_size=800, M=500, L=100,
                 thresh=1.0, num_factors=5, use_cuda=False):
    """Computes disentanglement metric on model using data in data_loader."""
    # Sample a batch of data from dataset
    indices = np.random.randint(0, len(data_loader.dataset), size=batch_size)
    batch = torch.stack([data_loader.dataset[i][0] for i in indices])
    # Get the latent factors which generated the data
    factors = np.zeros((batch_size, num_factors))
    for i in range(batch_size):
        factors[i] = latents_from_idx(indices[i])
    factors = torch.Tensor(factors)
    if use_cuda:
        batch = batch.cuda()
        factors = factors.cuda()

    # Encode data to get latents
    latent_dist = model.encode(batch)

    # Remove "inactive dimensions" (see Kim et. al paper appendix). These
    # inactive dimensions are defined as ones where the KL between the
    # posterior and the prior are below a certain threshold

    # Calculate KL divergence for continuous variables
    mean, logvar = latent_dist['cont']
    kl_values = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())
    kl_means = torch.mean(kl_values, dim=0)
    # Remove latents with low KL
    mask = kl_means > thresh
    print("Keep {} dimensions.".format(mask.sum().item()))
    latents_cont = latent_dist['cont'][0].detach().clone()[:, mask]
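    # Discrete latent code: argmax over the categorical probabilities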
    _, latents_disc = latent_dist['disc'][0].max(dim=1)

    # Calculate KL divergence for discrete variables
    alpha = latent_dist['disc'][0]
    disc_dim = int(alpha.size()[-1])
    log_dim = torch.Tensor([np.log(disc_dim)])
    if use_cuda:
        log_dim = log_dim.cuda()
    # Calculate negative entropy of each row
    neg_entropy = torch.sum(alpha * torch.log(alpha + 1e-12), dim=1)
    # Take mean of negative entropy across batch
    mean_neg_entropy = torch.mean(neg_entropy, dim=0)
    # KL loss of alpha with uniform categorical variable
    kl_disc = log_dim + mean_neg_entropy
    # Keep the discrete dimension only if its KL is above the threshold
    keep_disc = kl_disc.item() >= thresh
    if keep_disc:
        latents = torch.cat([latents_cont, latents_disc.float().unsqueeze(1)], dim=1)
    else:
        print("Removing discrete dimension.")
        latents = latents_cont.clone()
    if use_cuda:
        latents = latents.cuda()

    # Calculate empirical standard deviations
    latents_cont_std = torch.std(latents_cont, dim=0)
    latents_disc_std = torch.sqrt(discrete_variance(latents_disc))
    if use_cuda:
        latents_disc_std = latents_disc_std.cuda()
    if keep_disc:
        latents_std = torch.cat([latents_cont_std, latents_disc_std.float()], dim=0)
    else:
        latents_std = latents_cont_std.clone()

    # Initialize variances
    _, num_latents = latents.size()
    variances = torch.zeros(num_latents, num_factors, device=latents_std.device)
    if use_cuda:
        variances = variances.cuda()
    # Find the unique values of each factor and fix a random sample of them
    factors_unique = [factor.unique() for factor in factors.cpu().split(1, dim=1)]
    fixed_factor_indices = np.random.randint(0, num_factors, size=M)
    if use_cuda:
        factors = factors.cuda()
        factors_unique = [factor.cuda() for factor in factors_unique]
    # Iterate and calculate metric
    for m in range(M):
        fixed_factor_idx = fixed_factor_indices[m]
        fixed_factor_vals = factors_unique[fixed_factor_idx]
        fixed_factor_val = fixed_factor_vals[np.random.choice(len(fixed_factor_vals))]
        # Choose random latents
        factor_bool = factors[:, fixed_factor_idx] == fixed_factor_val
        if use_cuda:
            factor_bool = factor_bool.cuda()
        latents_subset = latents[factor_bool]
        # Randomly choose L points as in paper
        latents_subset = latents_subset[torch.randperm(latents_subset.size(0))][:L]
        if keep_disc:
            # Variance of the continuous dimensions, i.e. all but the last
            # column of the latents tensor
            var_cont = torch.var(latents_subset[:, :-1], dim=0)
            # Variance of the discrete dimension, i.e. the last column
            var_disc = discrete_variance(latents_subset[:, -1])
            if use_cuda:
                var_disc = var_disc.cuda()
            var_all = torch.cat([var_cont, var_disc], dim=0)
        else:
            # The discrete dimension was dropped, so every column is continuous
            var_all = torch.var(latents_subset, dim=0)
        # Total distance (as defined in paper)
        dist = var_all / latents_std
        d_star = torch.argmin(dist, dim=0)
        # Increment "vote"
        variances[d_star, fixed_factor_idx] += 1
    return torch.max(variances, dim=1)[0].sum().item() / M


def latents_from_idx(idx):
    """Returns the latent variables which generated the data at idx (following)
    index order given in initial dataset."""
    # Shapes, scale, orientation, posX, posY
    n_shape = 3
    n_scale = 6
    n_orientation = 40
    n_posx = 32
    n_posy = 32

    latent_posy = idx % n_posy
    latent_posx = int(idx / n_posy) % n_posx
    latent_orientation = int(idx / (n_posy * n_posx)) % n_orientation
    latent_scale = int(idx / (n_posy * n_posx * n_orientation)) % n_scale
    latent_shape = int(idx / (n_posy * n_posx * n_orientation * n_scale)) % n_shape

    return (latent_shape, latent_scale, latent_orientation, latent_posx, latent_posy)
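
# Example decodings for illustration (posY varies fastest, shape slowest):
#   latents_from_idx(0)  -> (0, 0, 0, 0, 0)
#   latents_from_idx(33) -> (0, 0, 0, 1, 1)   since 33 = 1 * 32 + 1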


def discrete_variance(disc_latents):
    """Gini's definition empirical variance."""
    n = disc_latents.size(0)
    dist = 0.
    for i in range(n):
        for j in range(n):
            if disc_latents[i] != disc_latents[j]:
                dist += 1.
    return torch.Tensor([dist / (2 * n * (n - 1))])
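

In case it's useful, here is a minimal sketch of how you might call metric_model. The StubModel below is just a stand-in I made up to illustrate the interface the function expects (encode(batch) returning {'cont': (mean, logvar), 'disc': [alpha]}); in practice you would pass a trained JointVAE model and a dSprites data loader, since latents_from_idx assumes the dSprites index ordering. The dataset size and hyperparameters here are placeholder values, not the ones from the paper, so the resulting score is meaningless; the point is just the call signature.

import torch
from torch.utils.data import DataLoader, TensorDataset


class StubModel:
    """Stand-in for a trained model, illustrating the expected interface:
    encode(batch) -> {'cont': (mean, logvar), 'disc': [alpha]}."""
    def __init__(self, cont_dim=6, disc_dim=4):
        self.cont_dim = cont_dim
        self.disc_dim = disc_dim

    def encode(self, x):
        n = x.size(0)
        mean = torch.randn(n, self.cont_dim)
        # Small logvar keeps the per-dimension continuous KL above the
        # default threshold, so no dimension gets masked out
        logvar = -4 * torch.ones(n, self.cont_dim)
        # One-hot alphas keep the discrete KL above the default threshold
        alpha = torch.eye(self.disc_dim)[torch.randint(0, self.disc_dim, (n,))]
        return {'cont': (mean, logvar), 'disc': [alpha]}


# Tiny zero-filled dataset standing in for dSprites (which has 737280 images)
images = torch.zeros(5000, 1, 64, 64)
labels = torch.zeros(5000)
loader = DataLoader(TensorDataset(images, labels), batch_size=64)
# Small M and L keep the O(n^2) Python loops in discrete_variance fast
score = metric_model(StubModel(), loader, batch_size=800, M=50, L=20)
print("Disentanglement metric: {:.3f}".format(score))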

