This repository contains ambiguous datasets generated with a conditional variational autoencoder (CVAE). Each dataset contains images that are ambiguous between a pair of classes. The class interpolation is done by conditional generation: the CVAE is conditioned on a class vector and a mixing factor, chosen so that a readout classifier assigns ambiguous class probabilities to the class pair. Currently only MNIST and EMNIST are supported, but the codebase is general enough that new datasets could be added easily.
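As a rough illustration of the class-vector idea, the sketch below builds a soft class vector for a pair of classes and checks whether a set of readout probabilities counts as ambiguous. The function names, the linear mixing rule, and the 0.35 to 0.65 probability band are assumptions made for this example, not the repository's actual implementation.

```python
def mixed_class_vector(class_a, class_b, mix, n_classes=10):
    """Soft class vector with weight (1 - mix) on class_a and mix on class_b.

    Hypothetical helper: the linear mixing rule is an assumption.
    """
    if not 0.0 <= mix <= 1.0:
        raise ValueError("mix must lie in [0, 1]")
    if class_a == class_b:
        raise ValueError("class_a and class_b must differ")
    vec = [0.0] * n_classes
    vec[class_a] = 1.0 - mix
    vec[class_b] = mix
    return vec


def is_ambiguous(probs, class_a, class_b, low=0.35, high=0.65):
    """True if both classes of the pair receive a mid-range probability.

    The [low, high] band is an assumed criterion, chosen for illustration.
    """
    return low <= probs[class_a] <= high and low <= probs[class_b] <= high


# Conditioning vector halfway between the digits 3 and 8.
cond = mixed_class_vector(3, 8, mix=0.5)

# Made-up readout output for a generated image (not real model output).
example_probs = [0.0] * 10
example_probs[3], example_probs[8], example_probs[5] = 0.52, 0.44, 0.04
ambiguous = is_ambiguous(example_probs, 3, 8)
```

In practice the mixing factor would be tuned per class pair, since an even 0.5 mix does not guarantee that the readout sees the two classes as equally likely.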
First, install the dependencies:
# clone project
git clone https://github.com/ABL-Lab/ambiguous-dataset
# install project
cd ambiguous-dataset
pip install -e .
pip install -r requirements.txt
A-MNIST (Google Drive link): https://drive.google.com/file/d/1JlJVoymk-3GAf9GHbgVeTTzQzPtvLUAt/view?usp=sharing
A-EMNIST (Google Drive link): https://drive.google.com/file/d/19SspjkL24DngyIdLH5C0MTIXK0gQ2BhI/view?usp=sharing
Sequential MNIST (addition task): https://drive.google.com/file/d/1NMKNNo2lMp4Pg8UsF_cL5nzHRM5lMNRK/view?usp=sharing
Sequential A-MNIST (MNIST digits adding to ambiguous digit): https://drive.google.com/file/d/1vKvGwH_hvzQQrg8g7AxGxetvFEwfW_fQ/view?usp=sharing
Sequential A-EMNIST (EMNIST characters making a 3-letter word): https://drive.google.com/file/d/1TKN-B36AAvknmhS2L8Oknkr2HpXS62j-/view?usp=sharing
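For scripted downloads, the file ID can be pulled out of a share link like the ones above. The sketch below is a minimal example using only the standard library; the helper names are hypothetical and not part of this repository. Note that Google Drive may show a confirmation page for large files, so the direct URL is not guaranteed to return the file itself; a third-party tool such as gdown is one option for that case.

```python
import re


def drive_file_id(share_url):
    """Extract the file ID from a 'https://drive.google.com/file/d/<ID>/view' link."""
    match = re.search(r"/file/d/([A-Za-z0-9_-]+)", share_url)
    if match is None:
        raise ValueError(f"not a Google Drive file link: {share_url}")
    return match.group(1)


def direct_download_url(share_url):
    """Build a direct-download style URL from a share link (may still prompt for large files)."""
    return f"https://drive.google.com/uc?export=download&id={drive_file_id(share_url)}"


# A-MNIST share link from the list above.
amnist_url = "https://drive.google.com/file/d/1JlJVoymk-3GAf9GHbgVeTTzQzPtvLUAt/view?usp=sharing"
amnist_id = drive_file_id(amnist_url)
amnist_direct = direct_download_url(amnist_url)
```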
Pretrained model weights (Google Drive link): https://drive.google.com/file/d/1a49OW61ShyAVZQos6iSLmBZTrTybvOtf/view?usp=sharing
Unzip the archive into a folder, then load the models:
import torch

from ambiguous.models.vae import MLPVAE
from ambiguous.models.readout import Readout
from ambiguous.models.cvae import Conv_CVAE

device = 'cuda' if torch.cuda.is_available() else 'cpu'
# paths to the unzipped checkpoint files
vae_path, readout_path, ccvae_path = ..., ..., ...
vae = MLPVAE(latent_dim=10, input_img_size=28).to(device)
vae.load_state_dict(torch.load(vae_path, map_location=device))
readout = Readout(latent_dim=10, h=512, n_classes=10).to(device)
readout.load_state_dict(torch.load(readout_path, map_location=device))
ccvae = Conv_CVAE(latent_dim=10, n_cls=10).to(device)
ccvae.load_state_dict(torch.load(ccvae_path, map_location=device))
Passing map_location=device to torch.load() lets checkpoints saved on a GPU load on a CPU-only machine.
This project is set up as a package, so any module can be imported from any other file, for example:
from torch.utils.data import DataLoader
from ambiguous.dataset.dataset import DatasetFromNPY, DatasetTriplet
root = 'path_to_ambiguous_dataset'
# older version
#trainset = DatasetFromNPY(root=root, download=False, train=True, transform=None)
#testset = DatasetFromNPY(root=root, download=False, train=False, transform=None)
# new version
trainset = DatasetTriplet(root=root, train=True, transform=None)
testset = DatasetTriplet(root=root, train=False, transform=None)
trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)
# each batch is a (clean, ambiguous, clean) image triplet plus its label
(clean1, amb, clean2), label = next(iter(trainloader))
Here are some examples of the generated triplets (clean, ambiguous, clean):
@article{YourName,
title={Your Title},
author={Your team},
journal={Location},
year={Year}
}