Ambiguous Datasets

Description

This repository contains ambiguous datasets generated with a conditional variational autoencoder (CVAE). Each image is ambiguous between a pair of classes. The class interpolation is done by conditional generation: the CVAE is conditioned on a class vector and a mixing factor chosen so that a readout classifier gives an ambiguous class probability for the pair (e.g., two digits). Currently only MNIST and EMNIST are supported, but the codebase is general enough that new datasets can be added easily.
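The description keeps the generation step at a high level. As an illustration only, the sketch below assumes the condition vector is a convex blend `alpha * e_i + (1 - alpha) * e_j` of the two one-hot class vectors, and that `alpha` is tuned by bisection until the readout's probability for class `i` reaches a target (0.5 for maximal ambiguity). The blend, the bisection search, and the names `mixed_condition`, `find_ambiguous_alpha` and `readout_prob` are hypothetical, not the repository's confirmed method; `readout_prob` stands in for "decode with the CVAE, then score with the readout".

```python
import math
from typing import Callable, List


def mixed_condition(i: int, j: int, alpha: float, n_classes: int = 10) -> List[float]:
    """Hypothetical condition vector: alpha * onehot(i) + (1 - alpha) * onehot(j)."""
    if not 0.0 <= alpha <= 1.0:
        raise ValueError("alpha must lie in [0, 1]")
    c = [0.0] * n_classes
    c[i] += alpha
    c[j] += 1.0 - alpha
    return c


def find_ambiguous_alpha(
    readout_prob: Callable[[float], float],
    target: float = 0.5,
    tol: float = 1e-4,
    max_iter: int = 60,
) -> float:
    """Bisection on alpha in [0, 1].

    Assumes readout_prob(alpha), the readout's probability for class i when the
    CVAE is conditioned with mixed_condition(i, j, alpha), increases with alpha.
    """
    lo, hi = 0.0, 1.0
    for _ in range(max_iter):
        mid = 0.5 * (lo + hi)
        if readout_prob(mid) < target:
            lo = mid  # class i is under-represented: increase alpha
        else:
            hi = mid  # class i already reaches the target: decrease alpha
        if hi - lo < tol:
            break
    return 0.5 * (lo + hi)


if __name__ == "__main__":
    # Toy stand-in for the CVAE + readout: a logistic curve crossing 0.5 at alpha = 0.6.
    def toy_prob(a: float) -> float:
        return 1.0 / (1.0 + math.exp(-10.0 * (a - 0.6)))

    alpha = find_ambiguous_alpha(toy_prob)
    print(round(alpha, 3))  # 0.6
```

Bisection is used here only because it needs nothing beyond a monotone response; any 1-D root finder would serve the same purpose in this sketch.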

How to run

First, clone the project and install its dependencies:

# clone project   
git clone https://github.com/ABL-Lab/ambiguous-dataset

# install project   
cd ambiguous-dataset 
pip install -e .   
pip install -r requirements.txt

Download the ambiguous datasets (A-MNIST and A-EMNIST)

A-MNIST (Google Drive): https://drive.google.com/file/d/1JlJVoymk-3GAf9GHbgVeTTzQzPtvLUAt/view?usp=sharing

A-EMNIST (Google Drive): https://drive.google.com/file/d/19SspjkL24DngyIdLH5C0MTIXK0gQ2BhI/view?usp=sharing

Sequential MNIST (addition task): https://drive.google.com/file/d/1NMKNNo2lMp4Pg8UsF_cL5nzHRM5lMNRK/view?usp=sharing

Sequential A-MNIST (MNIST digits that add up to an ambiguous digit): https://drive.google.com/file/d/1vKvGwH_hvzQQrg8g7AxGxetvFEwfW_fQ/view?usp=sharing

Sequential A-EMNIST (EMNIST characters forming a 3-letter word): https://drive.google.com/file/d/1TKN-B36AAvknmhS2L8Oknkr2HpXS62j-/view?usp=sharing
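The links above are Google Drive share links of the form `.../file/d/<ID>/view?usp=sharing`. If you want to script the downloads, a minimal sketch is to pull out the file ID and build a `https://drive.google.com/uc?id=<ID>` URL, which download tools generally accept. The helper names below are hypothetical and not part of this repository.

```python
import re
from typing import Optional

# "/file/d/<ID>/" is the form used by the share links in this README;
# "?id=<ID>" covers the "open?id=" / "uc?id=" forms of other Drive URLs.
_DRIVE_ID_PATTERNS = (
    re.compile(r"/file/d/([A-Za-z0-9_-]+)"),
    re.compile(r"[?&]id=([A-Za-z0-9_-]+)"),
)


def extract_drive_file_id(url: str) -> Optional[str]:
    """Return the Google Drive file ID embedded in a share URL, or None."""
    for pattern in _DRIVE_ID_PATTERNS:
        match = pattern.search(url)
        if match:
            return match.group(1)
    return None


def direct_download_url(url: str) -> str:
    """Build a 'uc?id=' style URL from a share link (raises if no ID is found)."""
    file_id = extract_drive_file_id(url)
    if file_id is None:
        raise ValueError(f"no Google Drive file ID found in {url!r}")
    return f"https://drive.google.com/uc?id={file_id}"


if __name__ == "__main__":
    share = "https://drive.google.com/file/d/1JlJVoymk-3GAf9GHbgVeTTzQzPtvLUAt/view?usp=sharing"
    print(direct_download_url(share))
```

For example, with the third-party `gdown` package installed (it is not listed as a dependency of this project), `gdown.download(direct_download_url(share), output_path)` should fetch the file to `output_path`.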

Download the trained model weights used in the experiments

Google Drive link: https://drive.google.com/file/d/1a49OW61ShyAVZQos6iSLmBZTrTybvOtf/view?usp=sharing

Unzip the archive into a folder, then load the models:

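import torch

from ambiguous.models.vae import MLPVAE
from ambiguous.models.readout import Readout
from ambiguous.models.cvae import Conv_CVAE

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# paths to the unzipped weight files
vae_path, readout_path, ccvae_path = ..., ..., ...

# MLP VAE
vae = MLPVAE(latent_dim=10, input_img_size=28).to(device)
vae.load_state_dict(torch.load(vae_path, map_location=device))

# readout network (10 output classes)
readout = Readout(latent_dim=10, h=512, n_classes=10).to(device)
readout.load_state_dict(torch.load(readout_path, map_location=device))

# convolutional CVAE
ccvae = Conv_CVAE(latent_dim=10, n_cls=10).to(device)
ccvae.load_state_dict(torch.load(ccvae_path, map_location=device))

Passing map_location=device to torch.load() remaps tensors that were saved on a GPU onto the current device, so the same code also runs on a CPU-only machine (equivalently, pass map_location='cpu' when running on CPU).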

Importing the ambiguous datasets into your own project

This project is set up as a Python package, so once it is installed (see How to run) you can import its modules from any other project, for example:

from torch.utils.data import DataLoader

from ambiguous.dataset.dataset import DatasetFromNPY, DatasetTriplet

root = 'path_to_ambiguous_dataset'  # folder containing the downloaded dataset

# older version
#trainset = DatasetFromNPY(root=root, download=False, train=True, transform=None)
#testset = DatasetFromNPY(root=root, download=False, train=False, transform=None)

# new version: each item is a (clean, ambiguous, clean) triplet plus its label
trainset = DatasetTriplet(root=root, train=True, transform=None)
testset = DatasetTriplet(root=root, train=False, transform=None)

trainloader = DataLoader(trainset, batch_size=64, shuffle=True)
testloader = DataLoader(testset, batch_size=64, shuffle=False)

# fetch one batch of triplets
(clean1, amb, clean2), label = next(iter(trainloader))
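A PyTorch DataLoader only needs `__len__` and `__getitem__` from a map-style dataset, so training code written against the triplet layout can be unit-tested without downloading A-MNIST. The class below is a hypothetical test double (not part of the `ambiguous` package) that mirrors only the `((clean1, amb, clean2), label)` unpacking; it makes no assumption about the real label format or image preprocessing.

```python
from typing import Any, Sequence, Tuple


class ToyTripletDataset:
    """Hypothetical in-memory stand-in with the item layout
    dataset[idx] -> ((clean1, amb, clean2), label)."""

    def __init__(self, clean1: Sequence[Any], amb: Sequence[Any],
                 clean2: Sequence[Any], labels: Sequence[Any]) -> None:
        if not len(clean1) == len(amb) == len(clean2) == len(labels):
            raise ValueError("clean1, amb, clean2 and labels must have equal length")
        self.clean1, self.amb, self.clean2, self.labels = clean1, amb, clean2, labels

    def __len__(self) -> int:
        return len(self.labels)

    def __getitem__(self, idx: int) -> Tuple[Tuple[Any, Any, Any], Any]:
        return (self.clean1[idx], self.amb[idx], self.clean2[idx]), self.labels[idx]


if __name__ == "__main__":
    ds = ToyTripletDataset(["a0", "a1"], ["m0", "m1"], ["b0", "b1"], [0, 1])
    (clean1, amb, clean2), label = ds[1]
    print(len(ds), clean1, amb, clean2, label)  # 2 a1 m1 b1 1
```

With tensor images and integer labels, wrapping this class in DataLoader should yield batches that unpack the same way as the real dataset.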

Here are some examples of the generated triplets (clean, ambiguous, clean):

[Figures: two plots of example triplets]


Contributors

nizarislah, etterguillaume, dependabot[bot], markovg, masht18