
Doubly Abductive Counterfactual Inference for Text-based Image Editing

This repository contains the code for the CVPR 2024 paper Doubly Abductive Counterfactual Inference for Text-based Image Editing.

Setup

Dependency Installation

First, clone the repository:

git clone https://github.com/xuesong39/DAC

Then, in a new virtual environment, install diffusers v0.24.0 from source:

cd DAC
git clone https://github.com/huggingface/diffusers -b v0.24.0
cd diffusers
pip install -e .

Finally, return to the main DAC folder and install the remaining requirements:

cd ..
pip install -r requirements.txt
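The steps above assume an active virtual environment but do not show how to create one. A minimal sketch using Python's standard-library venv module follows; the directory name .venv-dac is an arbitrary, hypothetical choice, not something the repository prescribes.

```shell
# Hypothetical environment directory; the repository does not prescribe one.
ENV_DIR=.venv-dac

# Create an isolated environment with Python's standard-library venv module.
python3 -m venv "$ENV_DIR"

# Activate it in the current shell so that the pip installs for diffusers
# and requirements.txt go into this environment rather than the system one.
. "$ENV_DIR/bin/activate"

# Show which interpreter is now first on PATH.
command -v python
```

Run these lines from the DAC folder before the two pip install steps, and re-run the activation line in every new shell before training or inference.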

Data Preparation

The images and annotations used in the paper can be found here. Examples of the data format used in the experiments are provided in the folder DAC/data. For example, for the image DAC/data/cat/train/cat.jpeg, the folder with the source prompt is DAC/data/cat/ and the folder with the target prompt is DAC/data/cat-cap/.
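The source/target folder pair described above can be created for a new image with a small helper. This is a sketch under an assumption the README does not state: that the training scripts, like the diffusers text-to-image LoRA example they build on, load --train_data_dir with the Hugging Face Datasets imagefolder loader, which reads captions from a metadata.jsonl file keyed by file_name, with the caption stored under the column named by --caption_column (here "text"). The function write_prompt_folder and the dog example are hypothetical; compare the generated files with the examples in DAC/data before relying on them.

```shell
# write_prompt_folder IMAGE FOLDER_NAME PROMPT
# Hypothetical helper (not part of the repository). It assumes the
# imagefolder + metadata.jsonl layout used by the diffusers LoRA example:
# data/FOLDER_NAME/train/ holds the image and a one-line metadata.jsonl.
write_prompt_folder() {
  image=$1
  dir=data/$2/train
  prompt=$3
  mkdir -p "$dir" || return 1
  cp "$image" "$dir/" || return 1
  # One JSON object per image; the "text" key matches --caption_column="text".
  # The prompt is inserted verbatim, so it must not contain " or \ characters.
  printf '{"file_name": "%s", "text": "%s"}\n' "$(basename "$image")" "$prompt" \
    > "$dir/metadata.jsonl"
}

# Usage: the same image goes into both folders; only the caption differs.
#   write_prompt_folder dog.jpeg dog "A dog."
#   write_prompt_folder dog.jpeg dog-hat "A dog wearing a hat."
```

The first folder would then serve as TRAIN_DIR for Abduction-1 and the second as TRAIN_DIR for Abduction-2.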

Usage

Abduction-1

The fine-tuning script for abduction on U is train_text_to_image_lora.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="ORIGIN_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --rank=512 \
  --output_dir="U_PATH" \
  --validation_prompt="xxx" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat/"), --output_dir (e.g., "ckpt/cat"), and --validation_prompt (e.g., "A cat.").
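With a single training image and --train_batch_size=1, each epoch is one optimization step, so the command above runs 1,000 steps and --checkpointing_steps=1000 yields one intermediate checkpoint, at the final step. The arithmetic below makes this explicit. It assumes the scripts follow the diffusers example in computing steps per epoch as ceil(ceil(num_images / batch_size) / gradient_accumulation_steps) and that gradient accumulation stays at 1; neither is stated in this README.

```shell
# Inputs from the Abduction-1 command (one image, e.g. data/cat/train/cat.jpeg).
NUM_IMAGES=1
BATCH_SIZE=1
GRAD_ACCUM=1            # assumption: default value, not set in the command
NUM_EPOCHS=1000
CHECKPOINTING_STEPS=1000

# Batches per epoch, rounded up (a partial final batch still counts).
BATCHES_PER_EPOCH=$(( (NUM_IMAGES + BATCH_SIZE - 1) / BATCH_SIZE ))
# Optimizer updates per epoch, rounded up over gradient accumulation.
STEPS_PER_EPOCH=$(( (BATCHES_PER_EPOCH + GRAD_ACCUM - 1) / GRAD_ACCUM ))
TOTAL_STEPS=$(( NUM_EPOCHS * STEPS_PER_EPOCH ))
# A checkpoint is saved whenever the step count is a multiple of CHECKPOINTING_STEPS.
NUM_CHECKPOINTS=$(( TOTAL_STEPS / CHECKPOINTING_STEPS ))

echo "total steps: $TOTAL_STEPS, checkpoints: $NUM_CHECKPOINTS"
```

Training on more images scales the step count proportionally: with 10 images and the same settings, the formula gives 10,000 steps.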

Abduction-2

The fine-tuning script for abduction on Δ is train_text_to_image_lora_t.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="TARGET_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora_t.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --unet_lora_path="U_PATH" \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 --train_text_encoder \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --annealing=0.8 \
  --output_dir="DELTA_PATH" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat-cap/"), --unet_lora_path (e.g., "ckpt/cat"), and --output_dir (e.g., "ckpt/cat-cap-annealing0.8"). You can also change --annealing to control the hyperparameter η.
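To compare several values of η, one option is to loop over --annealing and give each run its own --output_dir, following the naming in the example above (ckpt/cat-cap-annealing0.8). The sketch below is hypothetical: the η values are illustrative only (not settings recommended by the authors), and the RUN variable is a dry-run switch introduced here that prints each command instead of executing it unless set to empty. Since the examples use the same --annealing value for training and inference, the inference run for each checkpoint presumably should use the matching η.

```shell
# Hypothetical sweep over eta for Abduction-2 (cat example).
# Dry run by default: RUN=echo prints each command; set RUN= to execute.
RUN=${RUN-echo}
MODEL_NAME=stabilityai/stable-diffusion-2-1-base
TRAIN_DIR=./data/cat-cap/
U_PATH=ckpt/cat

sweep_annealing() {
  for ETA in 0.7 0.8 0.9; do            # illustrative values only
    # One output directory per eta, following ckpt/cat-cap-annealing0.8.
    CUDA_VISIBLE_DEVICES=0 $RUN accelerate launch train_text_to_image_lora_t.py \
      --pretrained_model_name_or_path="$MODEL_NAME" \
      --unet_lora_path="$U_PATH" \
      --train_data_dir="$TRAIN_DIR" --caption_column="text" \
      --resolution=512 --train_text_encoder \
      --train_batch_size=1 \
      --num_train_epochs=1000 --checkpointing_steps=1000 \
      --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
      --seed=42 \
      --annealing="$ETA" \
      --output_dir="ckpt/cat-cap-annealing$ETA" \
      --report_to="wandb" \
      --validation_epochs=500 || return 1
  done
}

sweep_annealing
```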

Action & Prediction

The inference script is inference_t.sh as follows:

CUDA_VISIBLE_DEVICES=0 python inference_t.py \
 --annealing=0.8 \
 --unet_path="U_PATH" \
 --text_path="DELTA_PATH" \
 --target_prompt="xxx" \
 --save_path="./"

Please specify --unet_path (e.g., "ckpt/cat"), --text_path (e.g., "ckpt/cat-cap-annealing0.8"), and --target_prompt (e.g., "A cat wearing a wool cap.").
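Each stage consumes the previous stage's output: Abduction-2 reads U_PATH, and inference reads both U_PATH and DELTA_PATH. The hypothetical wrapper below, which is not part of the repository, chains the three commands above with the cat example values and derives every path from one η so they stay consistent. RUN is a dry-run switch introduced for this sketch: it prints the commands by default so they can be reviewed, and executes them when set to empty.

```shell
# Hypothetical driver chaining Abduction-1, Abduction-2 and inference.
# Dry run by default: RUN=echo prints each command; set RUN= to execute.
RUN=${RUN-echo}
MODEL_NAME=stabilityai/stable-diffusion-2-1-base
SRC_DIR=./data/cat/            # folder with the source prompt
TGT_DIR=./data/cat-cap/        # folder with the target prompt
ETA=0.8                        # passed to --annealing in stages 2 and 3
U_PATH=ckpt/cat
DELTA_PATH=ckpt/cat-cap-annealing$ETA

run_dac() {
  # Abduction-1: fine-tune U on the source prompt.
  CUDA_VISIBLE_DEVICES=0 $RUN accelerate launch train_text_to_image_lora.py \
    --pretrained_model_name_or_path="$MODEL_NAME" \
    --train_data_dir="$SRC_DIR" --caption_column="text" \
    --resolution=512 \
    --train_batch_size=1 \
    --num_train_epochs=1000 --checkpointing_steps=1000 \
    --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
    --seed=42 \
    --rank=512 \
    --output_dir="$U_PATH" \
    --validation_prompt="A cat." \
    --report_to="wandb" \
    --validation_epochs=500 || return 1

  # Abduction-2: fine-tune Delta on the target prompt, loading U.
  CUDA_VISIBLE_DEVICES=0 $RUN accelerate launch train_text_to_image_lora_t.py \
    --pretrained_model_name_or_path="$MODEL_NAME" \
    --unet_lora_path="$U_PATH" \
    --train_data_dir="$TGT_DIR" --caption_column="text" \
    --resolution=512 --train_text_encoder \
    --train_batch_size=1 \
    --num_train_epochs=1000 --checkpointing_steps=1000 \
    --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
    --seed=42 \
    --annealing="$ETA" \
    --output_dir="$DELTA_PATH" \
    --report_to="wandb" \
    --validation_epochs=500 || return 1

  # Action & Prediction: generate the edited image from U and Delta.
  CUDA_VISIBLE_DEVICES=0 $RUN python inference_t.py \
    --annealing="$ETA" \
    --unet_path="$U_PATH" \
    --text_path="$DELTA_PATH" \
    --target_prompt="A cat wearing a wool cap." \
    --save_path="./"
}

run_dac
```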

Optional Usage

This part implements the ablation on Abduction-1 described in the ablation analysis section of the paper: another exogenous variable T can be incorporated into Abduction-1 to further improve fidelity.

Abduction-1

The fine-tuning script for abduction on U is the same as above.

The fine-tuning script for abduction on T is train_text_to_image_lora_t.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="ORIGIN_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora_t.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --unet_lora_path="U_PATH" \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 --train_text_encoder \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --annealing=0.8 \
  --output_dir="T_PATH" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat/"), --unet_lora_path (e.g., "ckpt/cat"), and --output_dir (e.g., "ckpt/cat-annealing0.8").

Abduction-2

The fine-tuning script for abduction on Δ is train_text_to_image_lora_t2.sh as follows:

export MODEL_NAME="stabilityai/stable-diffusion-2-1-base"
export TRAIN_DIR="TARGET_DATA_PATH"

CUDA_VISIBLE_DEVICES=0 accelerate launch train_text_to_image_lora_t2.py \
  --pretrained_model_name_or_path=$MODEL_NAME \
  --unet_lora_path="U_PATH" \
  --text_lora1_path="T_PATH" \
  --train_data_dir=$TRAIN_DIR --caption_column="text" \
  --resolution=512 --train_text_encoder \
  --train_batch_size=1 \
  --num_train_epochs=1000 --checkpointing_steps=1000 \
  --learning_rate=1e-04 --lr_scheduler="constant" --lr_warmup_steps=0 \
  --seed=42 \
  --annealing=0.8 \
  --output_dir="DELTA_PATH" \
  --report_to="wandb" \
  --validation_epochs=500

Please specify TRAIN_DIR (e.g., "./data/cat-cap/"), --unet_lora_path (e.g., "ckpt/cat"), --text_lora1_path (e.g., "ckpt/cat-annealing0.8"), and --output_dir (e.g., "ckpt/cat-cap-annealing0.8-t2").

Action & Prediction

The inference script is inference_t2.sh as follows:

CUDA_VISIBLE_DEVICES=0 python inference_t2.py \
 --annealing=0.8 \
 --unet_path="U_PATH" \
 --text1_path="T_PATH" \
 --text2_path="DELTA_PATH" \
 --target_prompt="xxx" \
 --save_path="./"

Please specify --unet_path (e.g., "ckpt/cat"), --text1_path (e.g., "ckpt/cat-annealing0.8"), --text2_path (e.g., "ckpt/cat-cap-annealing0.8-t2"), and --target_prompt (e.g., "A cat wearing a wool cap.").
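Inference in this variant needs three checkpoints, and T (trained on the source folder) and Δ (trained on the target folder) are easy to mix up. The hypothetical pre-flight check below, which is not part of the repository, verifies that all three directories exist and that --text1_path and --text2_path differ before inference_t2.py is launched. The default paths are the examples given above; the function name check_optional_ckpts is introduced here for illustration.

```shell
# Hypothetical pre-flight check for inference_t2.py.
# Defaults follow the examples in this README; override via the environment.
U_PATH=${U_PATH:-ckpt/cat}
T_PATH=${T_PATH:-ckpt/cat-annealing0.8}
DELTA_PATH=${DELTA_PATH:-ckpt/cat-cap-annealing0.8-t2}

check_optional_ckpts() {
  # Every training stage must have produced its output directory.
  for dir in "$U_PATH" "$T_PATH" "$DELTA_PATH"; do
    if [ ! -d "$dir" ]; then
      echo "missing checkpoint directory: $dir" >&2
      return 1
    fi
  done
  # T (source prompt) and Delta (target prompt) are different checkpoints,
  # so passing the same directory twice is most likely a mistake.
  if [ "$T_PATH" = "$DELTA_PATH" ]; then
    echo "--text1_path and --text2_path point to the same directory" >&2
    return 1
  fi
}

# Example use (requires a GPU and the trained checkpoints):
#   check_optional_ckpts && CUDA_VISIBLE_DEVICES=0 python inference_t2.py \
#     --annealing=0.8 --unet_path="$U_PATH" --text1_path="$T_PATH" \
#     --text2_path="$DELTA_PATH" \
#     --target_prompt="A cat wearing a wool cap." --save_path="./"
```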

Checkpoints

We provide checkpoints for the following examples:

Image             Abduction-1   Abduction-2
DAC/data/cat      U             Δ
DAC/data/glass    U             Δ
DAC/data/black    U             Δ
DAC/data/cat      U, T          Δ
DAC/data/glass    U, T          Δ
DAC/data/black    U, T          Δ

Acknowledgments

This code draws on the following codebases: Diffusers and PEFT. Many thanks to their authors!
