
🕹️ Environments

git clone https://github.com/liming-ai/ControlNet_Plus_Plus.git
cd ControlNet_Plus_Plus
pip3 install -r requirements.txt
pip3 install -U openmim
mim install mmengine
mim install "mmcv==2.1.0"
pip3 install "mmsegmentation>=1.0.0"
pip3 install mmdet

🕹️ Data Preparation

All the organized data has been uploaded to Huggingface and will be downloaded automatically during training or evaluation. You can preview it in advance to check the data samples and the disk space required using the following links.

Task                    | Training Data 🤗 | Evaluation Data 🤗
LineArt, Hed, Canny     | Data, 1.14 TB    | Data, 2.25 GB
Depth                   | Data, 1.22 TB    | Data, 2.17 GB
Segmentation ADE20K     | Data, 7.04 GB    | Same path as training data
Segmentation COCOStuff  | Data, 61.9 GB    | Same path as training data

🕹️ Training

By default, training uses 8 A100-80G GPUs. If your computational resources are insufficient, reduce the batch size and increase the number of gradient accumulation steps by the same factor so that the effective batch size stays unchanged; we have not observed any performance degradation from doing so. Reducing the training resolution, however, does degrade performance.
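Keeping the effective batch size constant is simple arithmetic: effective batch = per-GPU batch × number of GPUs × gradient accumulation steps. The sketch below computes the accumulation steps for a smaller setup. The helper name and the batch sizes in the usage lines are hypothetical; check the training scripts for the actual defaults.

```python
def grad_accum_steps(target_effective_batch: int, per_gpu_batch: int, num_gpus: int) -> int:
    """Return the gradient accumulation steps that keep the effective batch size fixed.

    effective batch = per_gpu_batch * num_gpus * accumulation_steps
    (hypothetical helper for illustration, not part of the repository)
    """
    per_step = per_gpu_batch * num_gpus
    if per_step <= 0 or target_effective_batch % per_step != 0:
        raise ValueError("target batch must be a positive multiple of per_gpu_batch * num_gpus")
    return target_effective_batch // per_step


# Hypothetical reference run: 8 GPUs x 8 samples per GPU x 1 step = 64 samples per update.
target = 8 * 8 * 1
# Same effective batch on 4 GPUs with 4 samples per GPU:
print(grad_accum_steps(target, per_gpu_batch=4, num_gpus=4))  # 4
```

If the target is not divisible, adjust the per-GPU batch rather than rounding the step count, since rounding silently changes the effective batch size.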

For the segmentation task

ControlNet V1.1 Seg is trained on both ADE20K and COCOStuff, and the two datasets use different segmentation masks. We therefore first perform standard fine-tuning on each dataset, and then perform reward fine-tuning.

# Please refer to the reward script for details
bash train/reward_ade20k.sh
bash train/reward_cocostuff.sh

For the other tasks

Reward fine-tuning can be performed directly, without the preliminary fine-tuning stage.

bash train/reward_canny.sh
bash train/reward_depth.sh
bash train/reward_hed.sh
bash train/reward_linedrawing.sh

Core Code

Please refer to the core code here. In summary:

Step 1: Predict the single-step denoised RGB image with the noise scheduler:

# Predict the single-step denoised latents
pred_original_sample = [
    noise_scheduler.step(noise, t, noisy_latent).pred_original_sample.to(weight_dtype)
    for (noise, t, noisy_latent) in zip(model_pred, timesteps, noisy_latents)
]
pred_original_sample = torch.stack(pred_original_sample)

# Map the denoised latents into RGB images
pred_original_sample = 1 / vae.config.scaling_factor * pred_original_sample
image = vae.decode(pred_original_sample.to(weight_dtype)).sample
image = (image / 2 + 0.5).clamp(0, 1)
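With epsilon prediction, `pred_original_sample` is the closed-form inversion of the forward noising step, x̂₀ = (x_t − √(1 − ᾱ_t)·ε̂) / √ᾱ_t, where ᾱ_t is the scheduler's cumulative product of (1 − β). The scalar sketch below illustrates this and the final [−1, 1] → [0, 1] mapping. It assumes an epsilon-prediction scheduler, uses a plain linear β schedule purely for illustration (Stable Diffusion's actual schedule differs), skips the VAE decode that sits between the two steps in the real pipeline, and all helper names are hypothetical.

```python
import math
from typing import List


def linear_alphas_cumprod(num_steps: int = 1000, beta_start: float = 1e-4,
                          beta_end: float = 2e-2) -> List[float]:
    """alpha_bar_t = prod_{s<=t} (1 - beta_s) for an illustrative linear beta schedule."""
    alphas_cumprod, prod = [], 1.0
    for i in range(num_steps):
        beta = beta_start + (beta_end - beta_start) * i / (num_steps - 1)
        prod *= 1.0 - beta
        alphas_cumprod.append(prod)
    return alphas_cumprod


def add_noise(x0: float, eps: float, alpha_bar: float) -> float:
    """Forward process: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * eps."""
    return math.sqrt(alpha_bar) * x0 + math.sqrt(1.0 - alpha_bar) * eps


def predict_x0(x_t: float, eps_pred: float, alpha_bar: float) -> float:
    """Single-step denoised estimate from an epsilon prediction."""
    return (x_t - math.sqrt(1.0 - alpha_bar) * eps_pred) / math.sqrt(alpha_bar)


def to_unit_range(v: float) -> float:
    """Map a decoded value from [-1, 1] to [0, 1] and clamp, like (image / 2 + 0.5).clamp(0, 1)."""
    return min(max(v / 2 + 0.5, 0.0), 1.0)


alphas_cumprod = linear_alphas_cumprod()
t, x0, eps = 500, 0.7, -0.3
x_t = add_noise(x0, eps, alphas_cumprod[t])
# With a perfect noise prediction the clean value is recovered (up to rounding).
print(round(predict_x0(x_t, eps, alphas_cumprod[t]), 6))  # 0.7
print([to_unit_range(v) for v in (-1.5, -1.0, 0.0, 1.0)])  # [0.0, 0.0, 0.5, 1.0]
```

At large t, √ᾱ_t is small, so any error in ε̂ is amplified when dividing by it; the estimate is only as good as the noise prediction.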

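Step 2: Normalize the single-step denoised images according to the reward model used for each task:

import torchvision
from torchvision.transforms.functional import normalize  # assumed source of `normalize`

# The normalization depends on the reward model.
if args.task_name == 'depth':
    # Resize to 384x384, then mean = std = 0.5 maps [0, 1] to [-1, 1]
    image = torchvision.transforms.functional.resize(image, (384, 384))
    image = normalize(image, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
elif args.task_name in ['canny', 'lineart', 'hed']:
    # Edge-based reward models take the [0, 1] image without further normalization
    pass
else:
    # Remaining tasks (segmentation): ImageNet mean/std
    image = normalize(image, (0.485, 0.456, 0.406), (0.229, 0.224, 0.225))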
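Every branch of the normalization step applies the same per-channel formula, (x − mean) / std, and the branches differ only in the statistics: mean = std = 0.5 maps [0, 1] to [−1, 1], while (0.485, 0.456, 0.406) / (0.229, 0.224, 0.225) are the standard ImageNet statistics. A minimal per-pixel sketch mirroring that branch logic follows; `reward_input_stats` and `normalize_pixel` are hypothetical names, and the 384×384 resize for depth is omitted.

```python
from typing import List, Optional, Sequence, Tuple

Stats = Tuple[Tuple[float, float, float], Tuple[float, float, float]]

IMAGENET_STATS: Stats = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
HALF_STATS: Stats = ((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))


def reward_input_stats(task_name: str) -> Optional[Stats]:
    """Return (mean, std) for a task, or None when no normalization is applied."""
    if task_name == "depth":
        return HALF_STATS
    if task_name in ("canny", "lineart", "hed"):
        return None  # edge-based rewards consume the [0, 1] image as-is
    return IMAGENET_STATS  # any other task name, e.g. segmentation


def normalize_pixel(rgb: Sequence[float], task_name: str) -> List[float]:
    """Apply (x - mean) / std channel-wise, matching the branch for task_name."""
    stats = reward_input_stats(task_name)
    if stats is None:
        return list(rgb)
    mean, std = stats
    return [(x - m) / s for x, m, s in zip(rgb, mean, std)]


print(normalize_pixel([0.0, 0.5, 1.0], "depth"))  # [-1.0, 0.0, 1.0]
print(normalize_pixel([0.485, 0.456, 0.406], "segmentation"))  # [0.0, 0.0, 0.0]
```

Feeding a reward model inputs with the wrong statistics does not raise an error; it silently shifts the input distribution, which is why the statistics must match each reward model.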

Step 3: Combine the diffusion training loss with the reward loss:

# reward model inference
if args.task_name == 'canny':
    outputs = reward_model(image.to(accelerator.device), low_threshold, high_threshold)
else:
    outputs = reward_model(image.to(accelerator.device))

# Select the samples in the current batch whose timesteps fall within the rewarding window
timestep_mask = (
    (args.min_timestep_rewarding <= timesteps.reshape(-1, 1))
    & (timesteps.reshape(-1, 1) <= args.max_timestep_rewarding)
)

# Calculate reward loss
reward_loss = get_reward_loss(outputs, labels, args.task_name, reduction='none')

# Calculate final loss
reward_loss = reward_loss.reshape_as(timestep_mask)
reward_loss = (timestep_mask * reward_loss).sum() / (timestep_mask.sum() + 1e-10)
loss = pretrain_loss + reward_loss * args.grad_scale
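The mask turns the reward term into an average over only those samples whose timestep lies in [min_timestep_rewarding, max_timestep_rewarding]; samples outside the window contribute nothing, and the 1e-10 term avoids division by zero when no sample qualifies. A plain-Python sketch of the same arithmetic on a toy batch (function names, losses, and window bounds are hypothetical):

```python
from typing import Sequence


def masked_reward_loss(per_sample_loss: Sequence[float], timesteps: Sequence[int],
                       min_t: int, max_t: int, eps: float = 1e-10) -> float:
    """Average the per-sample reward loss over samples with min_t <= t <= max_t."""
    mask = [1.0 if min_t <= t <= max_t else 0.0 for t in timesteps]
    weighted = sum(m * loss for m, loss in zip(mask, per_sample_loss))
    return weighted / (sum(mask) + eps)


def total_loss(pretrain_loss: float, reward_loss: float, grad_scale: float) -> float:
    """Diffusion loss plus the reward loss scaled by grad_scale."""
    return pretrain_loss + reward_loss * grad_scale


losses = [0.8, 0.4, 0.6, 1.0]
timesteps = [950, 120, 300, 50]
r = masked_reward_loss(losses, timesteps, min_t=0, max_t=200)  # only t=120 and t=50 count
print(round(r, 6))                                   # 0.7
print(round(total_loss(0.1, r, grad_scale=0.5), 6))  # 0.45
```

Dividing by the mask count rather than the batch size keeps the reward term on the same scale regardless of how many samples in a batch happen to fall inside the window.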

🕹️ Evaluation

Please download the model weights and place each one in the corresponding subfolder of checkpoints:

Model                     | HF weights 🤗
LineArt                   | model
Depth                     | model
Hed (SoftEdge)            | model
Canny                     | model
Segmentation (ADE20K)     | UperNet-R50, FCN-R101
Segmentation (COCOStuff)  | model

Please make sure the folder structure matches the paths in the evaluation scripts; then you can evaluate each model with:

bash eval/eval_ade20k.sh
bash eval/eval_cocostuff.sh
bash eval/eval_canny.sh
bash eval/eval_depth.sh
bash eval/eval_hed.sh
bash eval/eval_linedrawing.sh

The segmentation mIoU results of ControlNet and ControlNet++ in the arXiv v1 version of the paper were computed using images and labels saved in .jpg format, which introduced errors because JPEG compression is lossy.

We re-evaluated using images and labels saved in .png format and reported the corrected results; please refer to our latest arXiv and ECCV camera-ready versions.

Other comparison methods (Gligen/T2I-Adapter/UniControl/UniControlNet) and other evaluation metrics (FID/CLIP-score) were not affected by this error.
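For reference, mIoU averages the per-class IoU = TP / (TP + FP + FN). Because JPEG compression is lossy, a label map saved as .jpg can come back with some pixel values changed, so those pixels are scored against the wrong class indices. The pure-Python sketch below computes mIoU and shows a simulated single-pixel label change lowering the score. It is an illustration, not the project's evaluation code; the ignore_index of 255 and skipping classes absent from both maps are assumed conventions.

```python
from typing import List, Sequence


def mean_iou(pred: Sequence[int], label: Sequence[int], num_classes: int,
             ignore_index: int = 255) -> float:
    """Mean over classes of TP / (TP + FP + FN); classes absent from both maps are skipped."""
    inter = [0] * num_classes
    union = [0] * num_classes
    for p, g in zip(pred, label):
        if g == ignore_index or not 0 <= g < num_classes:
            continue  # ignored or invalid ground-truth value
        union[g] += 1  # TP or FN for the ground-truth class
        if p == g:
            inter[g] += 1
        elif 0 <= p < num_classes:
            union[p] += 1  # false positive for the predicted class
    ious: List[float] = [i / u for i, u in zip(inter, union) if u > 0]
    return sum(ious) / len(ious) if ious else 0.0


pred = [0, 0, 1, 1, 2, 2]
label = [0, 0, 1, 1, 2, 2]
print(mean_iou(pred, label, num_classes=3))  # 1.0
# Simulated lossy round trip: one boundary pixel of the label changes from 0 to 1.
label_lossy = [0, 1, 1, 1, 2, 2]
print(round(mean_iou(pred, label_lossy, num_classes=3), 4))  # 0.7222
```

Because every class contributes equally to the mean, a few corrupted pixels in a small class can move mIoU far more than their share of the image.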

🕹️ Inference

Please refer to the Inference Branch or try our online Huggingface demo.

License

This project is licensed under the Apache License 2.0 - see the LICENSE file for details.

🙏 Acknowledgements

We sincerely thank the Huggingface, ControlNet and ImageReward communities for their open source code and contributions. Our project would not be possible without these amazing works.

Citation

If our work assists your research, feel free to give us a star ⭐ or cite us using:

@inproceedings{controlnet_plus_plus,
    author    = {Ming Li and Taojiannan Yang and Huafeng Kuang and Jie Wu and Zhaoning Wang and Xuefeng Xiao and Chen Chen},
    title     = {ControlNet++: Improving Conditional Controls with Efficient Consistency Feedback},
    booktitle = {European Conference on Computer Vision (ECCV)},
    year      = {2024},
}
