Comments (1)
Hi Xavier, thanks for your message and sorry it took so long to get back to you! I haven't included the experiment code in the repo because I haven't cleaned it up yet. At some point in the future, when I have more time, I'll probably clean it up properly and include it in the repo.
For now, here is the uncleaned code. If anything doesn't work, let me know and I'll try to help.
import numpy as np
import torch
def metric_model(model, data_loader, batch_size=800, M=500, L=100,
                 thresh=1.0, num_factors=5, use_cuda=False):
    """Computes disentanglement metric on model using data in data_loader.

    A random batch is encoded, latent dimensions whose KL to the prior is
    below ``thresh`` are dropped, and then ``M`` votes are cast: for each
    vote one generative factor is fixed to a value, up to ``L`` matching
    samples are drawn from the batch, and the latent dimension with the
    smallest normalised variance receives a vote for that factor. The score
    is the accuracy of a majority-vote classifier built from those votes.

    Parameters
    ----------
    model : object
        Must expose ``encode(batch)`` returning a dict with key ``'cont'``
        holding ``(mean, logvar)`` and key ``'disc'`` holding a sequence
        whose first element is the categorical probabilities ``alpha``.
    data_loader : object
        ``data_loader.dataset`` must support ``len()`` and integer indexing,
        each item being a sequence whose first element is a tensor.
    batch_size : int
        Number of data points sampled (with replacement) from the dataset.
    M : int
        Number of votes.
    L : int
        Maximum number of samples used per vote.
    thresh : float
        KL threshold below which a latent dimension is treated as inactive.
    num_factors : int
        Number of generative factors; must match the length of the tuple
        returned by ``latents_from_idx``.
    use_cuda : bool
        If True, tensors are moved to the GPU.

    Returns
    -------
    float
        Fraction of cast votes that agree with the majority vote of their
        latent dimension. Votes whose fixed-factor subset has fewer than
        two samples are skipped (no variance can be computed); ``0.0`` is
        returned if no vote could be cast.
    """
    # Sample a batch of data from dataset
    indices = np.random.randint(0, len(data_loader.dataset), size=batch_size)
    batch = torch.stack([data_loader.dataset[i][0] for i in indices])

    # Get the latent factors which generated the data
    factors = np.zeros((batch_size, num_factors))
    for row, idx in enumerate(indices):
        factors[row] = latents_from_idx(idx)
    factors = torch.Tensor(factors)

    if use_cuda:
        batch = batch.cuda()
        factors = factors.cuda()

    # Encode data to get latents. No gradients are needed for evaluation.
    with torch.no_grad():
        latent_dist = model.encode(batch)

    # Remove "inactive dimensions" (see Kim et al. paper appendix). These
    # inactive dimensions are defined as ones where the KL between the
    # posterior and the prior are below a certain threshold

    # Calculate KL divergence for continuous variables
    mean, logvar = latent_dist['cont']
    kl_values = -0.5 * (1 + logvar - mean.pow(2) - logvar.exp())
    kl_means = torch.mean(kl_values, dim=0)

    # Remove latents with low KL
    mask = kl_means > thresh
    print("Keep {} dimensions.".format(mask.sum().item()))
    latents_cont = mean.detach().clone()[:, mask]
    num_cont = latents_cont.size(1)

    # Calculate KL divergence for discrete variables
    alpha = latent_dist['disc'][0]
    _, latents_disc = alpha.max(dim=1)
    disc_dim = int(alpha.size()[-1])
    # Calculate negative entropy of each row
    neg_entropy = torch.sum(alpha * torch.log(alpha + 1e-12), dim=1)
    # Take mean of negative entropy across batch
    mean_neg_entropy = torch.mean(neg_entropy, dim=0)
    # KL loss of alpha with uniform categorical variable
    kl_disc = mean_neg_entropy + float(np.log(disc_dim))
    # Written as a negated "<" so a NaN KL keeps the dimension, as before
    keep_disc = not (kl_disc.item() < thresh)

    if keep_disc:
        latents = torch.cat(
            [latents_cont, latents_disc.float().unsqueeze(1)], dim=1)
    else:
        print("Removing discrete dimension.")
        latents = latents_cont.clone()
    if use_cuda:
        latents = latents.cuda()

    # Calculate empirical standard deviations over the whole batch
    latents_cont_std = torch.std(latents_cont, dim=0)
    if keep_disc:
        latents_disc_std = torch.sqrt(discrete_variance(latents_disc))
        latents_disc_std = latents_disc_std.to(latents_cont_std.device)
        latents_std = torch.cat(
            [latents_cont_std, latents_disc_std.float()], dim=0)
    else:
        latents_std = latents_cont_std.clone()

    # Initialize vote counts: rows are latent dimensions, columns factors
    _, num_latents = latents.size()
    votes = torch.zeros(num_latents, num_factors, device=latents_std.device)

    # Find unique factor values and pick which factor to fix for each vote
    factors_unique = [factor.unique()
                      for factor in factors.cpu().split(1, dim=1)]
    fixed_factor_indices = np.random.randint(0, num_factors, size=M)
    if use_cuda:
        factors_unique = [factor.cuda() for factor in factors_unique]

    # Iterate and calculate metric
    num_votes = 0
    for m in range(M):
        fixed_factor_idx = int(fixed_factor_indices[m])
        fixed_factor_vals = factors_unique[fixed_factor_idx]
        fixed_factor_val = fixed_factor_vals[
            np.random.choice(len(fixed_factor_vals))]

        # Select the latents of the samples sharing the fixed factor value
        factor_bool = factors[:, fixed_factor_idx] == fixed_factor_val
        latents_subset = latents[factor_bool]
        # Randomly choose L points as in paper
        latents_subset = latents_subset[
            torch.randperm(latents_subset.size(0))][:L]

        # A variance needs at least two samples; with a single sample the
        # continuous variance is NaN and the discrete one is undefined.
        if latents_subset.size(0) < 2:
            continue

        # Calculate variance of continuous dimensions, i.e. the first
        # num_cont columns of the latents tensor
        var_cont = torch.var(latents_subset[:, :num_cont], dim=0)
        if keep_disc:
            # The discrete dimension is the last column, but only when it
            # survived the KL threshold above. Previously the last column
            # was always treated as discrete, even after it was removed.
            var_disc = discrete_variance(latents_subset[:, -1])
            var_all = torch.cat(
                [var_cont, var_disc.to(var_cont.device)], dim=0)
        else:
            var_all = var_cont

        # Normalised variance per latent dimension.
        # NOTE(review): normalising the representation by its std before
        # taking the variance would correspond to var / std**2; this code
        # divides by std. Left unchanged -- confirm against the paper.
        dist = var_all / latents_std
        d_star = torch.argmin(dist, dim=0)
        # Increment "vote"
        votes[d_star, fixed_factor_idx] += 1
        num_votes += 1

    if num_votes == 0:
        return 0.0
    # Majority-vote classifier: each latent dimension predicts the factor it
    # voted for most often; the score is the share of votes it gets right.
    return torch.max(votes, dim=1)[0].sum().item() / num_votes
def latents_from_idx(idx):
    """Returns the latent factors which generated the data point at idx.

    The flat index is decomposed as a mixed-radix number in which posY
    varies fastest, followed by posX, orientation, scale and shape
    (presumably the dSprites ordering -- confirm against the dataset).
    Indices outside the range of the dataset wrap around modulo its size.

    Parameters
    ----------
    idx : int
        Flat index into the dataset; anything accepted by ``int()``.

    Returns
    -------
    tuple of int
        ``(shape, scale, orientation, posx, posy)``.
    """
    # Shapes, scale, orientation, posX, posY
    n_shape = 3
    n_scale = 6
    n_orientation = 40
    n_posx = 32
    n_posy = 32

    # Peel off one factor at a time with exact integer arithmetic
    rest, latent_posy = divmod(int(idx), n_posy)
    rest, latent_posx = divmod(rest, n_posx)
    rest, latent_orientation = divmod(rest, n_orientation)
    rest, latent_scale = divmod(rest, n_scale)
    # The shape factor has n_shape values (was mistakenly taken % n_scale)
    latent_shape = rest % n_shape

    return (latent_shape, latent_scale, latent_orientation,
            latent_posx, latent_posy)
def discrete_variance(disc_latents):
    """Gini's definition of empirical variance for discrete values.

    Computes ``sum_{i, j} [x_i != x_j] / (2 * n * (n - 1))`` over all
    ordered pairs of entries of ``disc_latents``.

    Parameters
    ----------
    disc_latents : torch.Tensor
        1-D tensor of discrete values.

    Returns
    -------
    torch.Tensor
        CPU float tensor of shape ``(1,)`` holding the variance. Inputs
        with fewer than two entries have no pairs and yield ``0``.
    """
    n = disc_latents.size(0)
    if n < 2:
        # No pairs to compare; also avoids dividing by n * (n - 1) == 0
        return torch.Tensor([0.])

    # The number of ordered pairs with equal values is the sum of squared
    # value counts, so the number of differing pairs is n^2 minus that.
    _, counts = torch.unique(disc_latents, return_counts=True)
    num_equal_pairs = (counts * counts).sum().item()
    dist = float(n * n - num_equal_pairs)
    return torch.Tensor([dist / (2 * n * (n - 1))])
from joint-vae.
Related Issues (3)
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Some thing interesting about web. New door for the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Some thing interesting about visualization, use data art
-
Game
Some thing interesting about game, make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from joint-vae.