COAT: Measuring Object Compositionality in Emergent Representations

The official code repository for "COAT: Measuring Object Compositionality in Emergent Representations"

Sirui Xie, Ari Morcos, Song-Chun Zhu, Ramakrishna Vedantam
Presented at ICML 2022.

[Paper] [Code] [Data]


Requirements

  • Python >= 3.8
  • PyTorch >= 1.7.1
  • PyTorch Lightning == 1.1.4
  • hydra-core == 1.2.0
  • tqdm
  • A CUDA-enabled computing device

Usage

This repository contains

  • Code for generating the COAT test corpus, adapted from the CLEVR generation code
  • Code for generating Correlated CLEVR with colorful backgrounds, also adapted from the CLEVR generation code
  • PyTorch implementations of Slot Attention and beta-TC-VAE, adapted from the Untitled-AI and AntixK repositories, respectively. Our changes to Slot Attention mainly concern post-processing: deduplicating slots (controlled by dup_threshold) and removing invisible slots, i.e. slots with close-to-zero mask weight (controlled by rm_invisible).
  • The method validation_epoch_end in method.py, which applies the COAT metric to both slot-based and slot-free representations.
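The slot post-processing described above can be pictured with a minimal, pure-Python sketch. It is not the repository's code: the helper postprocess_slots, the invisible_eps cutoff, and the cosine-similarity duplicate test are all assumptions chosen for illustration; only the parameter names dup_threshold and rm_invisible come from this README.

```python
import math
from typing import List, Sequence


def _cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity between two vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    nu = math.sqrt(sum(a * a for a in u))
    nv = math.sqrt(sum(b * b for b in v))
    return dot / (nu * nv) if nu > 0 and nv > 0 else 0.0


def postprocess_slots(
    slots: List[List[float]],
    mask_weights: List[float],
    dup_threshold: float = 0.9,
    rm_invisible: bool = True,
    invisible_eps: float = 1e-3,  # hypothetical cutoff for "close-to-zero" mask weight
) -> List[int]:
    """Return the indices of slots kept after post-processing (hypothetical sketch).

    ASSUMPTIONS: a slot is invisible if its total mask weight is below
    invisible_eps, and two slots are duplicates if the cosine similarity of
    their feature vectors exceeds dup_threshold. The repository may differ.
    """
    candidates = list(range(len(slots)))
    if rm_invisible:
        candidates = [i for i in candidates if mask_weights[i] >= invisible_eps]
    # Visit slots from most to least visible, so the most visible copy of a duplicate is kept.
    ordered = sorted(candidates, key=lambda i: mask_weights[i], reverse=True)
    kept: List[int] = []
    for i in ordered:
        if all(_cosine(slots[i], slots[j]) <= dup_threshold for j in kept):
            kept.append(i)
    return sorted(kept)
```

For example, with slots [[1, 0], [0.99, 0.05], [0, 1], [0.5, 0.5]] and mask weights [0.4, 0.3, 0.3, 0.0], the second slot is dropped as a near-copy of the first and the last is dropped as invisible, leaving indices [0, 2].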

Data

The generated training and test data are available here. Before running, change the following data paths in the configuration files in ./hydra_cfg/:

data_mix_idx: 1 # index of the training-data mixture; see data_mix.csv for details
data_mix_csv: /your_data_root/data_mix.csv # meta file describing the composition of each training set
data_root: /your_data_root/clevr_corr/ # training data for both i.i.d. and correlated CLEVR with colorful backgrounds
val_root: /your_data_root/clevr_with_masks/ # evaluation data from the original CLEVR, used for the mask ARI metric
test_root: /your_data_root/coat_test/ # test data for our COAT metric
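Because a wrong path may only surface once training or evaluation reaches it, it can help to check the paths up front. The sketch below is hypothetical: missing_data_paths is not part of the repository, and it treats the config as a plain dict (an OmegaConf DictConfig also supports cfg[key] lookup).

```python
import os
from typing import Dict, List

# Path-valued keys from the configuration snippet above.
PATH_KEYS = ("data_mix_csv", "data_root", "val_root", "test_root")


def missing_data_paths(cfg: Dict[str, object]) -> List[str]:
    """Return "key: path" entries for configured data paths that do not exist on disk."""
    return [f"{key}: {cfg[key]}" for key in PATH_KEYS if not os.path.exists(str(cfg[key]))]
```

An empty list means every configured path exists; otherwise the returned entries name the keys to fix.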

In the training data we provide, data_mix.csv is a meta file describing how each training set is composed, with different levels of correlation. Set data_mix_idx to 1, 2, 3, 4, or 5 for the i.i.d. datasets, and to 13, 14, 15, 16, or 17 for the correlated datasets used in our paper.
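To train on every mixture of one setting, the indices above can be swept with Hydra's standard --multirun flag. This is a sketch under two assumptions: data_mix_idx is a top-level key that can be overridden from the command line, as the configuration snippet suggests, and the entry point is hydra_train.py. The helper data_mix_sweep_argv is hypothetical.

```python
from typing import Dict, List

# data_mix_idx values quoted in this README; data_mix.csv defines what each index means.
DATA_MIX_IDX: Dict[str, List[int]] = {
    "iid": [1, 2, 3, 4, 5],
    "correlated": [13, 14, 15, 16, 17],
}


def data_mix_sweep_argv(setting: str, script: str = "hydra_train.py") -> List[str]:
    """Build a Hydra multirun command that trains once per data mixture of `setting`."""
    if setting not in DATA_MIX_IDX:
        raise ValueError(f"unknown setting {setting!r}; expected one of {sorted(DATA_MIX_IDX)}")
    values = ",".join(str(i) for i in DATA_MIX_IDX[setting])
    # With --multirun, Hydra expands a comma-separated override into one job per value.
    return ["python", script, "--multirun", f"data_mix_idx={values}"]
```

For example, subprocess.run(data_mix_sweep_argv("correlated"), check=True) would launch one job per correlated mixture; Hydra's default launcher runs them one after another.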

To generate the test or the training data, check ./coat_generation/.

Our COAT measure can be extended to domains other than CLEVR: the Dataset class CLEVRAlgebraTestset in data.py is reusable. It takes as input test_cases: List[List[Optional[str]]], a list of tuples of images. In hydra_train.py, this list is loaded from /test_root/obj_test_final/CLEVR_test_cases_hard.csv, which is included in our released data, with the following code:

import os
from csv import reader

# Excerpt: `cfg` is the Hydra config and `self` the enclosing training object.
test_cases_csv = os.path.join(cfg.test_root, "obj_test_final", "CLEVR_test_cases_hard.csv")
if os.path.exists(test_cases_csv):
    with open(test_cases_csv, "r") as f:
        csv_reader = reader(f)
        self.obj_algebra_test_cases = list(csv_reader)  # one list of image entries per test case
else:
    self.obj_algebra_test_cases = None
    print(test_cases_csv + " does not exist.")
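For a new domain, the test cases only need to end up as a List[List[Optional[str]]] of image paths. A hypothetical parser is sketched below. Unlike the excerpt above, which keeps the raw CSV strings, it maps empty fields to None; that encoding of missing images is an assumption and should be checked against the released CSV.

```python
import csv
from typing import Iterable, List, Optional


def parse_test_cases(lines: Iterable[str]) -> List[List[Optional[str]]]:
    """Parse CSV lines into one list of image entries per test case.

    ASSUMPTION (hypothetical helper): an empty field denotes a missing image
    and becomes None.
    """
    return [[field if field else None for field in row] for row in csv.reader(lines)]


# Usage with a file on disk (the path is illustrative):
# with open("my_domain_test_cases.csv", newline="") as f:
#     test_cases = parse_test_cases(f)
```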

Training

Configuration files for models and training are in ./hydra_cfg/. Select one in hydra_train.py through the config_name argument of the decorator:

import hydra
from omegaconf import DictConfig

# 'cfg_file' can be one of:
#   - 'bvae' for beta-TC-VAE
#   - 'slot-attn' for the original Slot Attention model
#   - 'slot-attn-no-dup' for Slot Attention with duplicate slots removed
#   - 'slot-attn-no-dup-no-inv' for Slot Attention with duplicate and invisible slots removed
# Details of the different post-processing steps are given in the paper.
@hydra.main(config_path='hydra_cfg', config_name='cfg_file')
def main(cfg: DictConfig) -> None:  # function name illustrative
    ...  # build the model and trainer from cfg, then train (body omitted here)

To train a model from scratch, with the COAT test run at every epoch, run

python hydra_train.py
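Instead of editing config_name in the decorator, Hydra also accepts the config on the command line through its standard --config-name flag. The helper below is hypothetical and not part of the repository: it assembles such a command for one of the four configs listed above, plus optional key=value overrides that are assumed to be valid config keys.

```python
from typing import Dict, List, Optional

# Config names listed in the decorator comment of hydra_train.py.
CONFIG_NAMES = ("bvae", "slot-attn", "slot-attn-no-dup", "slot-attn-no-dup-no-inv")


def train_argv(config_name: str, overrides: Optional[Dict[str, object]] = None) -> List[str]:
    """Build a training command that selects the config via Hydra's --config-name flag."""
    if config_name not in CONFIG_NAMES:
        raise ValueError(f"unknown config {config_name!r}; expected one of {CONFIG_NAMES}")
    argv = ["python", "hydra_train.py", "--config-name", config_name]
    # Each override becomes a Hydra key=value argument.
    argv += [f"{key}={value}" for key, value in (overrides or {}).items()]
    return argv
```

For example, subprocess.run(train_argv("slot-attn-no-dup-no-inv"), check=True) would train the variant with duplicate and invisible slots removed.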

Logging

We use wandb for logging. The logs contain the COAT metrics and test visualizations.

The COAT metrics include the COAT-l2 and COAT-acos scores, which are normalized and corrected for chance, as well as the empirical probability P(Loss(A, B, C, D) < Loss(A, B, C, D')), where D' is the hard negative. Here are some example training curves.

[Figure: example training curves]
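To make the quantities above concrete, here is a simplified sketch for slot-free (single-vector) representations. It assumes the test compares the composition A - B + C against D, and it uses an L2 distance and an arccos angle as stand-ins for the losses behind COAT-l2 and COAT-acos. It omits the normalization, the chance correction, and the slot matching needed for slot-based models; see the paper for the exact definitions. All function names are hypothetical.

```python
import math
from typing import List, Sequence


def _combine(a: Sequence[float], b: Sequence[float], c: Sequence[float]) -> List[float]:
    """Element-wise a - b + c (ASSUMED form of the compositional test)."""
    return [ai - bi + ci for ai, bi, ci in zip(a, b, c)]


def loss_l2(a, b, c, d) -> float:
    """Euclidean distance between a - b + c and d."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(_combine(a, b, c), d)))


def loss_acos(a, b, c, d) -> float:
    """Angle in radians between a - b + c and d (arccos of cosine similarity)."""
    u = _combine(a, b, c)
    norm = math.sqrt(sum(x * x for x in u)) * math.sqrt(sum(y * y for y in d))
    if norm == 0:
        raise ValueError("angle is undefined for a zero vector")
    cos = sum(x * y for x, y in zip(u, d)) / norm
    return math.acos(max(-1.0, min(1.0, cos)))  # clamp against floating-point drift


def prob_beats_hard_negative(losses_pos: Sequence[float], losses_neg: Sequence[float]) -> float:
    """Empirical P(Loss(A, B, C, D) < Loss(A, B, C, D')) over paired test cases."""
    if len(losses_pos) != len(losses_neg) or not losses_pos:
        raise ValueError("need the same, nonzero number of positive and hard-negative losses")
    wins = sum(lp < ln for lp, ln in zip(losses_pos, losses_neg))
    return wins / len(losses_pos)
```

Each test case contributes one loss with the true D and one with the hard negative D'; the probability is the fraction of cases where the true D wins.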

The visualizations show how well the models reconstruct the images and, for Slot Attention, how well the slots are matched.

Here are some examples.

[Figure: visualization for bvae]

[Figure: visualization for slot-attn]

[Figure: visualization for slot-attn-no-dup]

[Figure: visualization for slot-attn-no-dup-no-inv]
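Comparing slot-based representations requires pairing up slots, because slot order is arbitrary. This README does not spell out the matching rule used in the repository, so the sketch below is only one illustrative choice: it picks, by brute force, the permutation that minimizes the total L2 distance between matched slots. match_slots is a hypothetical name.

```python
import itertools
import math
from typing import Sequence, Tuple


def match_slots(x: Sequence[Sequence[float]], y: Sequence[Sequence[float]]) -> Tuple[Tuple[int, ...], float]:
    """Find the permutation of y's slots that minimizes total L2 distance to x's slots.

    Returns (perm, cost), where slot i of x is matched to slot perm[i] of y.
    """
    if len(x) != len(y):
        raise ValueError("this sketch assumes both sets have the same number of slots")

    def dist(u: Sequence[float], v: Sequence[float]) -> float:
        return math.sqrt(sum((p - q) ** 2 for p, q in zip(u, v)))

    best_perm, best_cost = tuple(range(len(y))), math.inf
    for perm in itertools.permutations(range(len(y))):
        cost = sum(dist(x[i], y[j]) for i, j in enumerate(perm))
        if cost < best_cost:
            best_perm, best_cost = perm, cost
    return best_perm, best_cost
```

Brute force costs O(K!) for K slots. For more than a handful of slots, the Hungarian algorithm (for example scipy.optimize.linear_sum_assignment on a K×K cost matrix) solves the same assignment problem in polynomial time.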

Citation

@inproceedings{xie2022coat,
  title={COAT: Measuring Object Compositionality in Emergent Representations},
  author={Xie, Sirui and Morcos, Ari S and Zhu, Song-Chun and Vedantam, Ramakrishna},
  booktitle={International Conference on Machine Learning},
  pages={24388--24413},
  year={2022},
  organization={PMLR}
}
