
tcm's Introduction

Implementation of Transporting Causal Mechanisms (TCM)

This repository contains the implementation of our ICCV 2021 Oral paper Transporting Causal Mechanisms for Unsupervised Domain Adaptation, which gives a theoretically grounded solution to Unsupervised Domain Adaptation based on transportability theory. Transportability requires stratifying and representing the unobserved confounder; we provide a practical implementation that identifies the stratification by learning Disentangled Causal Mechanisms (DCMs) and the representation through the proxy theory.

Prerequisites

  • pytorch = 1.0.1
  • torchvision = 0.2.1
  • numpy = 1.17.2
  • python 3.6
  • cuda 10
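
A quick way to verify that the environment matches the versions above (a minimal sketch; it is not part of the repository):

    # Minimal environment check; the printed versions should match the ones listed above.
    import numpy
    import torch
    import torchvision

    print("pytorch     :", torch.__version__)        # expected 1.0.1
    print("torchvision :", torchvision.__version__)  # expected 0.2.1
    print("numpy       :", numpy.__version__)        # expected 1.17.2
    print("cuda usable :", torch.cuda.is_available())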

Preparation

  1. Download the datasets. For ImageCLEF-DA, please download the dataset using this link: https://drive.google.com/file/d/1_BXJlbalvW7I9xzHpMMy9k5SoCtQ3roJ/view?usp=sharing, where the images are organized similarly to how Office-Home is stored. Alternatively, you can download ImageCLEF-DA from the official sources and arrange it as shown below. The other datasets can be downloaded from their official sources.

    imageclef_m
      |-- c (domain name)
        |-- 0 (class id)
        	|-- ...(images)
        |-- 1
        	|-- ...(images)
        ...
      |-- i
      	...
      |-- p
      	...
    
  2. In the data folder, modify the file paths based on where the dataset is stored. Note that the number following each file path is the class ID and should not be modified. For example, for ImageCLEF-DA, change the file paths in data/ic/c.txt, data/ic/i.txt, and data/ic/p.txt. A small sanity-check sketch is given after this list.

  3. Under scripts/cyclegan, change a_root and b_root based on the dataset directory and domain name. Modify checkpoints_dir to where you want to store the trained DCMs.

  4. The configs folder stores the configurations of the trained DCMs. Modify these configuration files according to the dataset paths and checkpoint directories on your machine. Specifically, a_root and b_root are the paths to the stored dataset, checkpoints_dir is the location of the saved DCM networks, and cdm_path is where the cross-domain counterfactual images will be saved (as a pre-processing step for the second stage of TCM).
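
As a sanity check for step 2 above, the following sketch parses a data list file and verifies that every referenced image exists. It assumes each line has the "<image_path> <class_id>" format described in that step, and uses data/ic/c.txt only as an example:

    # Sanity-check a data list file: each line is assumed to be "<image_path> <class_id>".
    import os

    def check_list_file(list_path):
        missing = []
        with open(list_path) as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                path, class_id = line.rsplit(" ", 1)
                int(class_id)                    # the class ID must be left unchanged
                if not os.path.isfile(path):     # the path should point at an existing image
                    missing.append(path)
        print("%s: %d missing files" % (list_path, len(missing)))
        return missing

    check_list_file("data/ic/c.txt")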

Training and Testing

Step 1: DCMs training

This corresponds to Section 3.1 of our paper. The training script is stored in scripts/cyclegan; run the Python file there to start DCMs training.
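
For orientation only, below is a sketch of the cycle-consistency term that CycleGAN-style generator training (as suggested by the scripts/cyclegan folder) relies on. The generator names G_ab / G_ba and the loss weight lam are illustrative assumptions; the actual networks and objectives are defined in the repository's cyclegan code.

    # Illustrative cycle-consistency term for a pair of generators between domains A and B.
    import torch.nn.functional as F

    def cycle_consistency_loss(G_ab, G_ba, x_a, x_b, lam=10.0):
        rec_a = G_ba(G_ab(x_a))    # A -> B -> A reconstruction
        rec_b = G_ab(G_ba(x_b))    # B -> A -> B reconstruction
        return lam * (F.l1_loss(rec_a, x_a) + F.l1_loss(rec_b, x_b))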

Step 2: Generate Cross-Domain Counterfactual Images (\hat{X})

We pre-save the generated cross-domain counterfactual images to speed up TCM training. This is done by running generate_cdm.py. However, this step is already included in the automated scripts for the next step (it corresponds to the cdm field in the configuration files in the configs folder), so you do not need to run it manually.
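
For reference, here is a minimal sketch of what this pre-saving step amounts to. generate_cdm.py is the actual implementation; the generator name G_a2b, the data loader, and the file naming below are illustrative assumptions.

    # Translate every source image to the other domain and save the result under cdm_path.
    import os
    import torch
    from torchvision.utils import save_image

    def presave_counterfactuals(G_a2b, loader, cdm_path, device="cuda"):
        os.makedirs(cdm_path, exist_ok=True)
        G_a2b.eval()
        with torch.no_grad():
            for idx, (x, _) in enumerate(loader):
                x_hat = G_a2b(x.to(device))     # cross-domain counterfactual image \hat{X}
                save_image(x_hat, os.path.join(cdm_path, "%06d.png" % idx))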

Step 3: Learning h_y and Inference

This is done by running the Python files in scripts/tcm. Each script first generates the counterfactual images, then runs training and testing.

tcm's People

Contributors

yue-zhongqi


tcm's Issues

Reproduce the report results.

This is a very creative paper. I want to reproduce the results reported in your paper, but my results are not ideal. Can the default parameters reproduce the numbers in the paper? Thanks! :)

Question about `GVBD` and `GVBG`

    def get_alignment_loss(self, name, logits, x, align_logits):
        # Fetch the domain discriminator and the GVBG ("bridge") layer for this branch.
        discriminator = getattr(self, name + "_discriminator")
        gvbg = getattr(self, name + "_gvbg")
        coeff = calc_coeff(self.iteration)
        bridge = gvbg(x)
        # Subtract the bridge from the logits before alignment, if enabled.
        if self.opt.gvbg_weight > 0 and align_logits:
            logits = logits - bridge
        # ... (omitted here: the adversarial alignment with `discriminator` and the
        #      computation of `loss_gvbg` / `loss_gvbd`)
        setattr(self, "loss_%s_gvbg" % (name), loss_gvbg)
        setattr(self, "loss_%s_gvbd" % (name), loss_gvbd)

Hi~ I'm confused by these two linear layers. They seem to be a logit-related trick, but I could not find any description in the paper or the supplementary material.
I would appreciate an explanation of what they are used for.

Core question

Hello author, where exactly in this source code is the causal inference reflected? I could not find it, and it would be great if you could point it out.
