
DATUM: One-shot Unsupervised Domain Adaptation with Personalized Diffusion Models

Overview


Adapting a segmentation model from a labeled source domain to a target domain, where a single unlabeled datum is available, is one of the most challenging problems in domain adaptation and is otherwise known as one-shot unsupervised domain adaptation (OSUDA).

Most of the prior works have addressed the problem by relying on style transfer techniques, where the source images are stylized to have the appearance of the target domain. Departing from the common notion of transferring only the target “texture” information, we leverage text-to-image diffusion models (e.g., Stable Diffusion) to generate a synthetic target dataset with photo-realistic images that not only faithfully depict the style of the target domain, but are also characterized by novel scenes in diverse contexts.

The text interface in our method Data AugmenTation with diffUsion Models (DATUM) endows us with the possibility of guiding the generation of images towards desired semantic concepts while respecting the original spatial context of a single training image, which is not possible in existing OSUDA methods. Extensive experiments on standard benchmarks show that our DATUM surpasses the state-of-the-art OSUDA methods by up to +7.1%.


For more information on DATUM, please check our [Paper].

If you find this project useful in your research, please consider citing:

@inproceedings{benigmim2023one,
  title={One-shot Unsupervised Domain Adaptation with Personalized Diffusion Models},
  author={Benigmim, Yasser and Roy, Subhankar and Essid, Slim and Kalogeiton, Vicky and Lathuili{\`e}re, St{\'e}phane},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={698--708},
  year={2023}
}

Setup Environment

For this project, we used Python 3.8.5. We recommend setting up a new virtual environment:

python -m venv ~/venv/datum
source ~/venv/datum/bin/activate

In that environment, the requirements can be installed with:

pip install -r requirements.txt -f https://download.pytorch.org/whl/torch_stable.html
pip install mmcv-full==1.3.7  # requires the other packages to be installed first

Further, please download the MiT weights. If problems occur during the automatic download, please follow the instructions for a manual download within the script.

sh tools/download_checkpoints.sh

Setup Datasets

Cityscapes: Please download leftImg8bit_trainvaltest.zip and gtFine_trainvaltest.zip from here and extract them to data/cityscapes.

GTA: Please download all image and label packages from here and extract them to data/gta.

Synthia: Please download SYNTHIA-RAND-CITYSCAPES from here and extract it to data/synthia.

One-shot image: Please copy any image you want from data/cityscapes/leftImg8bit/train to data/one_shot_image.
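As a quick sketch, assuming Cityscapes has been extracted as described above (the aachen frame below is just an example; any training image works):

```shell
# Create the target folder and copy a single training frame into it.
mkdir -p data/one_shot_image
cp data/cityscapes/leftImg8bit/train/aachen/aachen_000000_000019_leftImg8bit.png \
   data/one_shot_image/
```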

The final folder structure should look like this:

DATUM
├── ...
├── data
│   ├── cityscapes
│   │   ├── leftImg8bit
│   │   │   ├── train
│   │   │   ├── val
│   │   ├── gtFine
│   │   │   ├── train
│   │   │   ├── val
│   ├── gta
│   │   ├── images
│   │   ├── labels
│   ├── synthia
│   │   ├── RGB
│   │   ├── GT
│   │   │   ├── LABELS
│   ├── one_shot_image
│   │   ├── x_leftImg8bit.png
├── ...

Data Preprocessing: Finally, please run the following scripts to convert the label IDs to the train IDs and to generate the class index for RCS:

python tools/convert_datasets/gta.py data/gta --nproc 8
python tools/convert_datasets/cityscapes.py data/cityscapes --nproc 8
python tools/convert_datasets/synthia.py data/synthia/ --nproc 8

Training

Personalization stage

To fine-tune Stable Diffusion with the DreamBooth method, first clone the diffusers (0.12.1) library.

Then copy the three Python scripts and the my_utils folder contained in DATUM/dreambooth into diffusers/examples/dreambooth, and fine-tune Stable Diffusion with DreamBooth using the following command:

python train_dreambooth.py --instance_data_dir data/one_shot_image --output_dir NAME_OF_EXPERIMENT

The checkpoints will be stored in logs/checkpoints/NAME_OF_EXPERIMENT.

Data generation stage

To convert all your trained checkpoints to inference pipelines:

python convert_dreambooth.py --filepath NAME_OF_EXPERIMENT

To generate the dataset using a specific checkpoint:

python generate_dreambooth.py --filepath NAME_OF_EXPERIMENT --ckpt NUM_STEPS

The generated dataset will be stored in logs/images/NAME_OF_EXPERIMENT.

Domain segmentation stage

Create a symlink in data that points to the generated dataset stored in logs/images/NAME_OF_EXPERIMENT.
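For example (a sketch; the link name data/NAME_OF_EXPERIMENT is an assumption — use whatever name your config expects):

```shell
# Link the generated dataset into data/ so the training scripts can find it.
# -s: symbolic link, -f: replace an existing link, -n: treat an existing
# symlink to a directory as a file rather than descending into it.
ln -sfn "$(pwd)/logs/images/NAME_OF_EXPERIMENT" data/NAME_OF_EXPERIMENT
```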

To train DAFormer+DATUM on GTA→Cityscapes with the MiT-B5 encoder, please use the following command:

python run_experiments.py --exp 1 --data-root logs/images/NAME_OF_EXPERIMENT

For DAFormer+DATUM on GTA→Cityscapes with the ResNet-101 encoder:

python run_experiments.py --exp 2 --data-root logs/images/NAME_OF_EXPERIMENT

To train DAFormer+DATUM on SYNTHIA→Cityscapes with the MiT-B5 encoder:

python run_experiments.py --exp 3 --data-root logs/images/NAME_OF_EXPERIMENT

To train DAFormer+DATUM on SYNTHIA→Cityscapes with the ResNet-101 encoder:

python run_experiments.py --exp 4 --data-root logs/images/NAME_OF_EXPERIMENT

The generated configs will be stored in configs/generated/ and the checkpoints in the work_dirs/ folder.

Testing & Predictions

The trained models can be tested with the following command:

sh test.sh work_dirs/CHECKPOINT_DIRECTORY

The segmentation maps will be stored in work_dirs/CHECKPOINT_DIRECTORY/preds.

When evaluating a model trained on Synthia→Cityscapes, please note that the evaluation script calculates the mIoU over all 19 Cityscapes classes. However, Synthia contains labels for only 16 of these classes. Therefore, it is common practice in UDA to report the mIoU for Synthia→Cityscapes on these 16 classes only. As the IoU for the 3 missing classes is 0, you can do the conversion mIoU16 = mIoU19 * 19 / 16.
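As a quick sketch of that conversion in plain shell/awk (the 56.0 value is just an illustrative number, not a reported result):

```shell
# Rescale a 19-class mIoU to the 16 valid Synthia classes:
# the three missing classes contribute an IoU of 0, so multiply by 19/16.
miou19=56.0  # example 19-class mIoU from the evaluation log
miou16=$(awk -v m="$miou19" 'BEGIN { printf "%.1f", m * 19 / 16 }')
echo "$miou16"  # prints 66.5
```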

Checkpoints

Below, we provide checkpoints of DAFormer+DATUM. Since the results in the paper are given as the mean over three random one-shot images, we provide the checkpoint with the median validation performance here:

We also provide the checkpoints of HRDA+DATUM:

The checkpoints come with the training logs. Please note that:

  • The logs report the mIoU for 19 classes. For Synthia→Cityscapes, it is necessary to convert the mIoU to the 16 valid classes; please see the section above for the conversion.
  • The logs provide the mIoU on the validation set.

Acknowledgements

This project (README as well) is heavily based on DAFormer. We thank their authors for making the source code publicly available.

License

This project is released under the Apache License 2.0, while some specific features in this repository are covered by other licenses. If you are using our code for commercial purposes, please check LICENSES.md carefully.
