
birefnet's Introduction

Bilateral Reference for High-Resolution Dichotomous Image Segmentation

DIS-Sample_1 DIS-Sample_2

This repo is the official implementation of "Bilateral Reference for High-Resolution Dichotomous Image Segmentation" (arXiv 2024).

Authors: Peng Zheng, Dehong Gao, Deng-Ping Fan, Li Liu, Jorma Laaksonen, Wanli Ouyang, & Nicu Sebe.

[arXiv] [code] [stuff] [Chinese version]

Our BiRefNet has achieved SOTA on many similar HR tasks:

DIS: figure of comparison on Papers with Code (as of the time of this work).

COD: figure of comparison on Papers with Code (as of the time of this work).

HRSOD: figure of comparison on Papers with Code (as of the time of this work).


Try our online demos for inference:

  • Inference and evaluation of your given weights: Open In Colab
  • Online Inference with GUI with adjustable resolutions: Hugging Face Spaces
  • Online Single Image Inference on Colab: Open In Colab

Model Zoo

For more general use of BiRefNet, I extended the original academic model to more general ones for better application in real life.

Datasets and backbone weights are best downloaded from their official pages, but you can also download the packaged ones: DIS, HRSOD, COD, Backbones.

Find performances (almost all metrics) of all models in the exp-TASK_SETTINGS folders in [stuff].

Models in the original paper, for comparison on benchmarks:

Task | Training Sets | Backbone | Download
DIS | DIS5K-TR | swin_v1_large | google-drive
COD | COD10K-TR, CAMO-TR | swin_v1_large | google-drive
HRSOD | DUTS-TR | swin_v1_large | google-drive
HRSOD | HRSOD-TR | swin_v1_large | google-drive
HRSOD | UHRSD-TR | swin_v1_large | google-drive
HRSOD | DUTS-TR, HRSOD-TR | swin_v1_large | google-drive
HRSOD | DUTS-TR, UHRSD-TR | swin_v1_large | google-drive
HRSOD | HRSOD-TR, UHRSD-TR | swin_v1_large | google-drive
HRSOD | DUTS-TR, HRSOD-TR, UHRSD-TR | swin_v1_large | google-drive

Models trained with custom data (massive, portrait), for general use in practical applications:

Task | Training Sets | Backbone | Test Set | Metric (S, wF[, HCE]) | Download
general use | DIS5K-TR, DIS-TEs, DUTS-TR_TE, HRSOD-TR_TE, UHRSD-TR_TE, HRS10K-TR_TE | swin_v1_large | DIS-VD | 0.889, 0.840, 1152 | google-drive
general use | DIS5K-TR, DIS-TEs, DUTS-TR_TE, HRSOD-TR_TE, UHRSD-TR_TE, HRS10K-TR_TE | swin_v1_tiny | DIS-VD | 0.867, 0.809, 1182 | google-drive
general use | DIS5K-TR, DIS-TEs | swin_v1_large | DIS-VD | 0.907, 0.865, 1059 | google-drive
portrait segmentation | P3M-10k | swin_v1_large | P3M-500-P | 0.982, 0.990 | google-drive

Segmentation with box guidance:

In progress...

Model efficiency:

Screenshot from the original paper. All tests are conducted on a single A100 GPU.

Third-Party Creations

For edge devices with less computing power, we provide a lightweight version with swin_v1_tiny as the backbone, which is 4x+ faster and 5x+ smaller. The details can be found in this issue and the links there.

There are already some third-party applications based on BiRefNet. Many thanks for their contributions to the community!
Choose the one you like and try it with clicks instead of code:

  1. Applications:

  2. More Visual Comparisons

    video-from_twitter_toyxyz3_2.mp4
    video-from_twitter_toyxyz3_1.mp4

Usage

Environment Setup

# PyTorch==2.0.1 is used for faster training with compilation.
conda create -n dis python=3.9 -y && conda activate dis
pip install -r requirements.txt

Dataset Preparation

Download the combined training / test sets I have organized from: DIS--COD--HRSOD, or the individual official ones in the single_ones folder, or from their official pages. You can also find the same ones on my BaiduDisk: DIS--COD--HRSOD.

Weights Preparation

Download backbone weights from my google-drive folder or their official pages.

Run

# Train & Test & Evaluation
./train_test.sh RUN_NAME GPU_NUMBERS_FOR_TRAINING GPU_NUMBERS_FOR_TEST
# See train.sh / test.sh for only training / test-evaluation.
# After the evaluation, run `gen_best_ep.py` to select the best ckpt by a specific metric (choose from Sm, wFm, and HCE (DIS only)).
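
The selection step mentioned in the last comment can be pictured roughly as follows. This is only a sketch and not the actual `gen_best_ep.py`: the `select_best_epoch` helper and the per-epoch score dictionary are hypothetical, and it assumes that Sm and wFm are better when higher while HCE is better when lower.

```python
from typing import Dict

# Assumption: higher is better for Sm and wFm, lower is better for HCE.
HIGHER_IS_BETTER = {"Sm": True, "wFm": True, "HCE": False}


def select_best_epoch(scores: Dict[int, Dict[str, float]], metric: str) -> int:
    """Return the epoch whose score is best for the chosen metric.

    `scores` maps an epoch number to its metric values (hypothetical layout).
    """
    if metric not in HIGHER_IS_BETTER:
        raise ValueError(f"Unknown metric: {metric}")
    # Keep only the epochs that actually report this metric.
    candidates = {ep: vals[metric] for ep, vals in scores.items() if metric in vals}
    if not candidates:
        raise ValueError(f"No epoch reports metric {metric}")
    pick = max if HIGHER_IS_BETTER[metric] else min
    return pick(candidates, key=candidates.get)
```

With made-up scores for two epochs, the helper would pick the higher Sm or the lower HCE, so the "best" checkpoint can differ depending on which metric you choose.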

Well-trained weights:

Download the BiRefNet-{TASK}-{EPOCH}.pth from [stuff]. Information for the corresponding weights (predicted_maps/performance/training_log) can also be found in folders like exp-BiRefNet-{TASK_SETTINGS} in the same directory.

You can also download the weights from the release of this repo.

The results may differ slightly from those in the original paper; you can find them in the eval_results-BiRefNet-{TASK_SETTINGS} folder in each exp-xx, and we will update them in the following days. The original training setup (A100-80G x 8) is too expensive for many people (including myself....), so I re-trained BiRefNet on a single A100-40G and achieved performance on the same level (even better). This means you can train the model directly on a single GPU with 36.5G+ memory. BTW, inference at 1024x1024 needs 5.5G of GPU memory. (I personally paid a lot to rent an A100-40G to re-train BiRefNet on the three tasks... T_T. Hope it helps you.)

If you have more, or more powerful, GPUs, you can set the GPU IDs and increase the batch size in config.py to accelerate training. The scripts adapt to this automatically, so you can switch seamlessly between single-card and multi-card training. Enjoy it :)

Some of my messages:

This project was originally built for DIS only. Update by update, I made it larger, with many functions embedded together. Now you can use it for any binary image segmentation task, such as DIS/COD/SOD, medical image segmentation, anomaly segmentation, etc. You can easily turn the following things on or off (usually in config.py):

  • Multi-GPU training: turn on/off with one variable.
  • Backbone choices: Swin_v1, PVT_v2, ConvNets, ...
  • Weighted losses: BCE, IoU, SSIM, MAE, Reg, ...
  • Adversarial loss for binary segmentation (proposed in my previous work MCCL).
  • Training tricks: multi-scale supervision, freezing backbone, multi-scale input...
  • Data collator: loading all in memory, smooth combination of different datasets for combined training and test.
  • ...

I really hope you enjoy this project and use it in more works to achieve new SOTAs.

Quantitative Results

Qualitative Results

Citation

@article{zheng2024birefnet,
  title={Bilateral Reference for High-Resolution Dichotomous Image Segmentation},
  author={Zheng, Peng and Gao, Dehong and Fan, Deng-Ping and Liu, Li and Laaksonen, Jorma and Ouyang, Wanli and Sebe, Nicu},
  journal={arXiv},
  year={2024}
}

Contact

For any questions, discussion, or even complaints, feel free to open an issue here or send me an e-mail ([email protected]).

birefnet's People

Contributors

alive1024, dengpingfan, zhengpeng7

birefnet's Issues

Guiding what to segment

Hi!
First of all, amazing work!
This is not really an issue, more of a question.

From reading the paper, I am not sure whether you can choose what to segment rather than just the foreground (a bit like the Segment Anything Model).
As I am just an artist, I do not understand everything.
Could you confirm or deny it?

Best regards,

inference speed extremely slow

Hello,

The inference speed is extremely slow.
I am running inference on the GPU, the same way I do with u2net, and the speed there is 12x faster.

Is there anything I can do to speed things up?

I have also tried to export to ONNX, but I get an error.

import torch
import torch.onnx
from models.birefnet import BiRefNet
from utils import check_state_dict
from torch.onnx import register_custom_op_symbolic


# Register custom symbolic function for deform_conv2d
def deform_conv2d_symbolic(g, input, weight, offset, bias, stride, padding, dilation, groups, deformable_groups, use_mask=False, mask=None):
    return g.op("DeformConv2d", input, weight, offset, bias,
                stride_i=stride, padding_i=padding, dilation_i=dilation,
                groups_i=groups, deformable_groups_i=deformable_groups)


register_custom_op_symbolic('torchvision::deform_conv2d', deform_conv2d_symbolic, 11)

# Load the model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = BiRefNet(bb_pretrained=False).to(device)
state_dict = torch.load("/root/BiRefNet-massive-epoch_240.pth", map_location=device)
state_dict = check_state_dict(state_dict)
model.load_state_dict(state_dict)
model.eval()

# Dummy input to trace the model
dummy_input = torch.randn(1, 3, 1024, 1024).to(device)

# Ensure to handle tensor-to-Python type conversions in your model
# Example modifications:
#     if W % self.patch_size[1] != 0:
# replace with
#     if (W % self.patch_size[1]).item() != 0:

# Export the model
onnx_model_path = "/root/BiRefNet.onnx"
torch.onnx.export(
    model,                     # model being run
    dummy_input,               # model input (or a tuple for multiple inputs)
    onnx_model_path,           # where to save the model (can be a file or file-like object)
    export_params=True,        # store the trained parameter weights inside the model file
    opset_version=11,          # the ONNX version to export the model to
    do_constant_folding=True,  # whether to execute constant folding for optimization
    input_names=['input'],     # the model's input names
    output_names=['output'],   # the model's output names
    dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}}  # variable length axes
)

print(f"Model has been converted to ONNX and saved at {onnx_model_path}")
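
When comparing speed against another model such as u2net, it helps to time both in exactly the same way. Below is a minimal, framework-agnostic sketch; the `mean_latency` helper and its parameters are hypothetical and not part of this repo. Because CUDA kernels are launched asynchronously, a synchronization callback (for example `torch.cuda.synchronize`) should be passed as `sync` when timing on a GPU, otherwise the measured time can be too short.

```python
import time
from typing import Callable, Optional


def mean_latency(fn: Callable[[], object],
                 warmup: int = 3,
                 iters: int = 10,
                 sync: Optional[Callable[[], None]] = None) -> float:
    """Return the mean wall-clock seconds per call of `fn`."""
    if iters <= 0:
        raise ValueError("iters must be positive")
    # Warm-up calls are excluded from the measurement.
    for _ in range(warmup):
        fn()
    if sync is not None:
        sync()  # make sure pending work is finished before starting the clock
    start = time.perf_counter()
    for _ in range(iters):
        fn()
    if sync is not None:
        sync()  # wait for the timed work to finish before stopping the clock
    return (time.perf_counter() - start) / iters
```

A call could look like `mean_latency(lambda: model(dummy_input), sync=torch.cuda.synchronize)`, run inside `torch.no_grad()` and with the same input resolution for both models.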

Missing supplementary material in the paper

I was not able to figure out where the supplementary material is in the paper. The pink color indicates that the reference is clickable, but that's not the case. The information is present neither at the end of the paper nor in the git repo.

RuntimeError: Error(s) in loading state_dict

Unexpected key(s) in state_dict: "squeeze_module.0.dec_att.aspp1.bn.weight", "squeeze_module.0.dec_att.aspp1.bn.bias", "squeeze_module.0.dec_att.aspp1.bn.running_mean", "squeeze_module.0.dec_att.aspp1.bn.running_var", "squeeze_module.0.dec_att.aspp1.bn.num_batches_tracked", "squeeze_module.0.dec_att.aspp_deforms.0.bn.weight", "squeeze_module.0.dec_att.aspp_deforms.0.bn.bias", "squeeze_module.0.dec_att.aspp_deforms.0.bn.running_mean", "squeeze_module.0.dec_att.aspp_deforms.0.bn.running_var", "squeeze_module.0.dec_att.aspp_deforms.0.bn.num_batches_tracked", "squeeze_module.0.dec_att.aspp_deforms.1.bn.weight", "squeeze_module.0.dec_att.aspp_deforms.1.bn.bias", "squeeze_module.0.dec_att.aspp_deforms.1.bn.running_mean", "squeeze_module.0.dec_att.aspp_deforms.1.bn.running_var", "squeeze_module.0.dec_att.aspp_deforms.1.bn.num_batches_tracked", "squeeze_module.0.dec_att.aspp_deforms.2.bn.weight", "squeeze_module.0.dec_att.aspp_deforms.2.bn.bias", "squeeze_module.0.dec_att.aspp_deforms.2.bn.running_mean", "squeeze_module.0.dec_att.aspp_deforms.2.bn.running_var", "squeeze_module.0.dec_att.aspp_deforms.2.bn.num_batches_tracked", "squeeze_module.0.dec_att.global_avg_pool.2.weight", "squeeze_module.0.dec_att.global_avg_pool.2.bias", "squeeze_module.0.dec_att.global_avg_pool.2.running_mean", "squeeze_module.0.dec_att.global_avg_pool.2.running_var", "squeeze_module.0.dec_att.global_avg_pool.2.num_batches_tracked", "squeeze_module.0.dec_att.bn1.weight", "squeeze_module.0.dec_att.bn1.bias", "squeeze_module.0.dec_att.bn1.running_mean", "squeeze_module.0.dec_att.bn1.running_var", "squeeze_module.0.dec_att.bn1.num_batches_tracked", "squeeze_module.0.bn_in.weight", "squeeze_module.0.bn_in.bias", "squeeze_module.0.bn_in.running_mean", "squeeze_module.0.bn_in.running_var", "squeeze_module.0.bn_in.num_batches_tracked", "squeeze_module.0.bn_out.weight", "squeeze_module.0.bn_out.bias", "squeeze_module.0.bn_out.running_mean", "squeeze_module.0.bn_out.running_var", "squeeze_module.0.bn_out.num_batches_tracked",
"decoder.decoder_block4.dec_att.aspp1.bn.weight", "decoder.decoder_block4.dec_att.aspp1.bn.bias", "decoder.decoder_block4.dec_att.aspp1.bn.running_mean", "decoder.decoder_block4.dec_att.aspp1.bn.running_var", "decoder.decoder_block4.dec_att.aspp1.bn.num_batches_tracked", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.weight", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.running_mean", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block4.dec_att.aspp_deforms.0.bn.num_batches_tracked", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block4.dec_att.aspp_deforms.1.bn.num_batches_tracked", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block4.dec_att.aspp_deforms.2.bn.num_batches_tracked", "decoder.decoder_block4.dec_att.global_avg_pool.2.weight", "decoder.decoder_block4.dec_att.global_avg_pool.2.bias", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block4.dec_att.global_avg_pool.2.num_batches_tracked", "decoder.decoder_block4.dec_att.bn1.weight", "decoder.decoder_block4.dec_att.bn1.bias", "decoder.decoder_block4.dec_att.bn1.running_mean", "decoder.decoder_block4.dec_att.bn1.running_var", "decoder.decoder_block4.dec_att.bn1.num_batches_tracked", "decoder.decoder_block4.bn_in.weight", "decoder.decoder_block4.bn_in.bias", "decoder.decoder_block4.bn_in.running_mean", "decoder.decoder_block4.bn_in.running_var", "decoder.decoder_block4.bn_in.num_batches_tracked", "decoder.decoder_block4.bn_out.weight", "decoder.decoder_block4.bn_out.bias", "decoder.decoder_block4.bn_out.running_mean", "decoder.decoder_block4.bn_out.running_var", "decoder.decoder_block4.bn_out.num_batches_tracked",
"decoder.decoder_block3.dec_att.aspp1.bn.weight", "decoder.decoder_block3.dec_att.aspp1.bn.bias", "decoder.decoder_block3.dec_att.aspp1.bn.running_mean", "decoder.decoder_block3.dec_att.aspp1.bn.running_var", "decoder.decoder_block3.dec_att.aspp1.bn.num_batches_tracked", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.weight", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.running_mean", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block3.dec_att.aspp_deforms.0.bn.num_batches_tracked", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block3.dec_att.aspp_deforms.1.bn.num_batches_tracked", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block3.dec_att.aspp_deforms.2.bn.num_batches_tracked", "decoder.decoder_block3.dec_att.global_avg_pool.2.weight", "decoder.decoder_block3.dec_att.global_avg_pool.2.bias", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block3.dec_att.global_avg_pool.2.num_batches_tracked", "decoder.decoder_block3.dec_att.bn1.weight", "decoder.decoder_block3.dec_att.bn1.bias", "decoder.decoder_block3.dec_att.bn1.running_mean", "decoder.decoder_block3.dec_att.bn1.running_var", "decoder.decoder_block3.dec_att.bn1.num_batches_tracked", "decoder.decoder_block3.bn_in.weight", "decoder.decoder_block3.bn_in.bias", "decoder.decoder_block3.bn_in.running_mean", "decoder.decoder_block3.bn_in.running_var", "decoder.decoder_block3.bn_in.num_batches_tracked", "decoder.decoder_block3.bn_out.weight", "decoder.decoder_block3.bn_out.bias", "decoder.decoder_block3.bn_out.running_mean", "decoder.decoder_block3.bn_out.running_var", "decoder.decoder_block3.bn_out.num_batches_tracked",
"decoder.decoder_block2.dec_att.aspp1.bn.weight", "decoder.decoder_block2.dec_att.aspp1.bn.bias", "decoder.decoder_block2.dec_att.aspp1.bn.running_mean", "decoder.decoder_block2.dec_att.aspp1.bn.running_var", "decoder.decoder_block2.dec_att.aspp1.bn.num_batches_tracked", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.weight", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.running_mean", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block2.dec_att.aspp_deforms.0.bn.num_batches_tracked", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block2.dec_att.aspp_deforms.1.bn.num_batches_tracked", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block2.dec_att.aspp_deforms.2.bn.num_batches_tracked", "decoder.decoder_block2.dec_att.global_avg_pool.2.weight", "decoder.decoder_block2.dec_att.global_avg_pool.2.bias", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block2.dec_att.global_avg_pool.2.num_batches_tracked", "decoder.decoder_block2.dec_att.bn1.weight", "decoder.decoder_block2.dec_att.bn1.bias", "decoder.decoder_block2.dec_att.bn1.running_mean", "decoder.decoder_block2.dec_att.bn1.running_var", "decoder.decoder_block2.dec_att.bn1.num_batches_tracked", "decoder.decoder_block2.bn_in.weight", "decoder.decoder_block2.bn_in.bias", "decoder.decoder_block2.bn_in.running_mean", "decoder.decoder_block2.bn_in.running_var", "decoder.decoder_block2.bn_in.num_batches_tracked", "decoder.decoder_block2.bn_out.weight", "decoder.decoder_block2.bn_out.bias", "decoder.decoder_block2.bn_out.running_mean", "decoder.decoder_block2.bn_out.running_var", "decoder.decoder_block2.bn_out.num_batches_tracked",
"decoder.decoder_block1.dec_att.aspp1.bn.weight", "decoder.decoder_block1.dec_att.aspp1.bn.bias", "decoder.decoder_block1.dec_att.aspp1.bn.running_mean", "decoder.decoder_block1.dec_att.aspp1.bn.running_var", "decoder.decoder_block1.dec_att.aspp1.bn.num_batches_tracked", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.weight", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.bias", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.running_mean", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.running_var", "decoder.decoder_block1.dec_att.aspp_deforms.0.bn.num_batches_tracked", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.weight", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.bias", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.running_mean", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.running_var", "decoder.decoder_block1.dec_att.aspp_deforms.1.bn.num_batches_tracked", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.weight", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.bias", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.running_mean", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.running_var", "decoder.decoder_block1.dec_att.aspp_deforms.2.bn.num_batches_tracked", "decoder.decoder_block1.dec_att.global_avg_pool.2.weight", "decoder.decoder_block1.dec_att.global_avg_pool.2.bias", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block1.dec_att.global_avg_pool.2.num_batches_tracked", "decoder.decoder_block1.dec_att.bn1.weight", "decoder.decoder_block1.dec_att.bn1.bias", "decoder.decoder_block1.dec_att.bn1.running_mean", "decoder.decoder_block1.dec_att.bn1.running_var", "decoder.decoder_block1.dec_att.bn1.num_batches_tracked", "decoder.decoder_block1.bn_in.weight", "decoder.decoder_block1.bn_in.bias", "decoder.decoder_block1.bn_in.running_mean", "decoder.decoder_block1.bn_in.running_var", "decoder.decoder_block1.bn_in.num_batches_tracked", "decoder.decoder_block1.bn_out.weight", "decoder.decoder_block1.bn_out.bias", "decoder.decoder_block1.bn_out.running_mean", "decoder.decoder_block1.bn_out.running_var", "decoder.decoder_block1.bn_out.num_batches_tracked",
"decoder.gdt_convs_4.1.weight", "decoder.gdt_convs_4.1.bias", "decoder.gdt_convs_4.1.running_mean", "decoder.gdt_convs_4.1.running_var", "decoder.gdt_convs_4.1.num_batches_tracked", "decoder.gdt_convs_3.1.weight", "decoder.gdt_convs_3.1.bias", "decoder.gdt_convs_3.1.running_mean", "decoder.gdt_convs_3.1.running_var", "decoder.gdt_convs_3.1.num_batches_tracked", "decoder.gdt_convs_2.1.weight", "decoder.gdt_convs_2.1.bias", "decoder.gdt_convs_2.1.running_mean", "decoder.gdt_convs_2.1.running_var", "decoder.gdt_convs_2.1.num_batches_tracked".

Training with custom dataset

Hello
I am amazed at the performance of the model you created.
So I want to train it with custom data, but I'm having some issues.

  1. When I resumed training after the interruption, the training loss increased significantly. (Is this because the model weights are saved but the optimizer information is not?)

  2. In the init_models_optimizers function in train.py, there is a variable epoch_st. I think epoch_st should be a global variable, but is there a reason why you have it set up like this?

  3. I currently have a custom dataset of about 9000 images. Due to the small number of data, I am adding the DIS dataset and HRSOD to run training. Is it okay to train them like this? Or should I just train it with custom data? (I use BiRefNet_ep580.pth)

I look forward to your response, thank you.
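
Regarding the first question, a checkpoint that stores only model weights does lose the optimizer state (for example momentum buffers) on resume. The sketch below shows one common PyTorch pattern for saving and restoring both; the `save_checkpoint` / `load_checkpoint` helpers and the dictionary keys are hypothetical and are not taken from this repo's train.py, so whether the repo already does this is not confirmed here.

```python
import torch


def save_checkpoint(path, model, optimizer, epoch):
    """Store model weights, optimizer state and the finished epoch together."""
    torch.save(
        {
            "model": model.state_dict(),
            "optimizer": optimizer.state_dict(),
            "epoch": epoch,
        },
        path,
    )


def load_checkpoint(path, model, optimizer):
    """Restore both states and return the epoch to resume from."""
    ckpt = torch.load(path, map_location="cpu")
    model.load_state_dict(ckpt["model"])
    # Restoring the optimizer brings back momentum buffers and similar state.
    optimizer.load_state_dict(ckpt["optimizer"])
    return ckpt["epoch"] + 1
```

If a learning-rate scheduler is used, its `state_dict()` would need to be saved and restored in the same way, otherwise the learning rate may restart from its initial value.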

How to remove black background of output image

After running the model to remove the background, the output has a black background. My question is how to get an output image without the black background, i.e. how to save the output image (image_preds[0]) as a transparent image.

The code:

def predict(self, image):
        images = [image]
        image_shapes = [image.shape[:2] for image in images]
        images = [array_to_pil_image(image, self.resolution) for image in images]
        image_preprocessor = ImagePreprocessor(resolution=self.resolution)
        images_proc = []
        for image in images:
            images_proc.append(image_preprocessor.proc(image))
        images_proc = torch.cat([image_proc.unsqueeze(0) for image_proc in images_proc])

        with torch.no_grad():
            scaled_preds_tensor = self.model(images_proc.to(self.device))[-1].sigmoid()  
        preds = []
        for image_shape, pred_tensor in zip(image_shapes, scaled_preds_tensor):
            if self.device == 'cuda':
                pred_tensor = pred_tensor.cpu()
            preds.append(torch.nn.functional.interpolate(pred_tensor.unsqueeze(0), size=image_shape, mode='bilinear', align_corners=True).squeeze().numpy())
        image_preds = []
        for image, pred in zip(images, preds):
            image = image.resize(pred.shape[::-1])
            pred = np.repeat(np.expand_dims(pred, axis=-1), 3, axis=-1)
            image_preds.append((pred * image).astype(np.uint8))

        return image, image_preds[0]
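
In the code above, the prediction is multiplied into the RGB pixels, which is what turns the background black. One possible alternative, sketched below with NumPy only, is to keep the RGB values untouched and attach the prediction as an alpha channel. The `to_rgba` helper is hypothetical (not part of this repo) and assumes an (H, W, 3) uint8 image and an (H, W) prediction with values in [0, 1].

```python
import numpy as np


def to_rgba(image_rgb: np.ndarray, pred: np.ndarray) -> np.ndarray:
    """Attach a [0, 1] mask as the alpha channel of an RGB image."""
    if image_rgb.ndim != 3 or image_rgb.shape[2] != 3:
        raise ValueError("image_rgb must have shape (H, W, 3)")
    if image_rgb.shape[:2] != pred.shape:
        raise ValueError("image and prediction must share height and width")
    # Scale the mask to 0..255; values are truncated when cast to uint8.
    alpha = (np.clip(pred, 0.0, 1.0) * 255).astype(np.uint8)
    # Result has shape (H, W, 4): original colors plus transparency.
    return np.dstack([image_rgb.astype(np.uint8), alpha])
```

The resulting array can be written with Pillow, for example `Image.fromarray(rgba).save("out.png")`. PNG keeps the alpha channel, whereas JPEG does not support transparency.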

RuntimeError: Error(s) in loading state_dict for BiRefNet

File "inference.py", line 82, in main
model.load_state_dict(state_dict)
File "/home/js/AI_run/BiRef_env/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1671, in load_state_dict
raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for BiRefNet:
Missing key(s) in state_dict: "squeeze_module.0.dec_att.global_avg_pool.2.weight", "squeeze_module.0.dec_att.global_avg_pool.2.bias", "squeeze_module.0.dec_att.global_avg_pool.2.running_mean", "squeeze_module.0.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block4.dec_att.global_avg_pool.2.weight", "decoder.decoder_block4.dec_att.global_avg_pool.2.bias", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block4.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block3.dec_att.global_avg_pool.2.weight", "decoder.decoder_block3.dec_att.global_avg_pool.2.bias", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block3.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block2.dec_att.global_avg_pool.2.weight", "decoder.decoder_block2.dec_att.global_avg_pool.2.bias", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block2.dec_att.global_avg_pool.2.running_var", "decoder.decoder_block1.dec_att.global_avg_pool.2.weight", "decoder.decoder_block1.dec_att.global_avg_pool.2.bias", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_mean", "decoder.decoder_block1.dec_att.global_avg_pool.2.running_var".

When attempting to perform inference after training, several variables cannot be loaded. The environment used for training was torch 1.12.1+cu113, and for inference, torch 1.13.1+cu116 was used. Could this difference be causing the issue?

If I use model.load_state_dict(state_dict, strict=False) for loading the model, will there be a significant difference in performance?
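
For context on `strict=False`: PyTorch then skips the missing keys, so those parameters and buffers keep their freshly initialized values, and it ignores unexpected keys. Before relying on it, it can help to see exactly which keys differ. The helper below is a hypothetical plain-Python sketch (not part of this repo) that compares two collections of key names.

```python
from typing import Iterable, List, Tuple


def diff_state_dict_keys(model_keys: Iterable[str],
                         ckpt_keys: Iterable[str]) -> Tuple[List[str], List[str]]:
    """Return (missing, unexpected) key names, both sorted.

    missing:    keys the model expects but the checkpoint lacks
    unexpected: keys the checkpoint has but the model does not know
    """
    model_set, ckpt_set = set(model_keys), set(ckpt_keys)
    missing = sorted(model_set - ckpt_set)
    unexpected = sorted(ckpt_set - model_set)
    return missing, unexpected
```

With a real model this could be called as `diff_state_dict_keys(model.state_dict().keys(), state_dict.keys())`. `load_state_dict` itself also returns an object with `missing_keys` and `unexpected_keys` fields, which gives the same information after loading with `strict=False`.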

Performance of ASPP module and ONNX conversion feasibility

I'm truly grateful and amazed by your model. However, I have a question. As mentioned in this issue (#26), the deform_conv2d is not working well.

According to the ONNX documentation, deform_conv2d seems to be supported from version 22 onwards.

So, my question is: In your code, there is a line that says 'config.dec_att == 'ASPP''. Do you know how well this ASPP module performs? I couldn't find any ablation or comparison experiments in the paper where something other than Deconv was used.
If the performance drop with ASPP is not significant, I would like to try converting the model to ONNX.

the training stalls

The training stalls, always in the first iteration of the first epoch. My GPU has 16G of memory, and my dataset has just 39 pictures.

Can you provide a script to run inference on a single image?

It's too hard for me to find the right weights and the right backbone, decoder, and other settings to load your provided pretrained weights. I just wanted to run inference on some images of mine T.T
So can you please provide a script to run inference on a single image? Thanks a lot!

No pth checkpoint for HRSOD

Hello, thank you for your work. I cannot find pretrained weights for Salient Object Detection; the gdrive contains only checkpoints for DIS and COD. Could you please provide well-trained weights for HRSOD too?

model load mismatch

In inference.py, how do I load models other than BiRefNet-massive-epoch_240.pth?
I keep getting RuntimeError: Error(s) in loading state_dict for BiRefNet, but I do not know where to set the model weights correctly.

  for weights in weights_lst:
      print(weights.strip('.pth').split('epoch_')[-1])
      if int(weights.strip('.pth').split('epoch_')[-1]) % 1 != 0:
          continue
      print('\tInferencing {}...'.format(weights))
      state_dict = torch.load(weights, map_location='cpu')
      state_dict = check_state_dict(state_dict)
      model.load_state_dict(state_dict)
      model = model.to(device)
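
A side note on the loop above: `str.strip('.pth')` removes any leading or trailing characters from the set `.`, `p`, `t`, `h`, not the literal suffix, so it can mangle some file names. A more explicit way to read the epoch number is sketched below. The `parse_epoch` helper is hypothetical, and the `epoch_<N>.pth` naming pattern is an assumption based on the file name quoted in this issue.

```python
import os
import re
from typing import Optional

# Assumption: checkpoint names end with "epoch_<number>.pth".
_EPOCH_RE = re.compile(r"epoch_(\d+)\.pth$")


def parse_epoch(weights_path: str) -> Optional[int]:
    """Return the epoch number encoded in the file name, or None if absent."""
    match = _EPOCH_RE.search(os.path.basename(weights_path))
    return int(match.group(1)) if match else None
```

Files that do not follow the pattern return `None`, so the caller can skip them instead of failing on `int(...)`.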

Remove background for comic drawings

Hello, thank you for the great model. The model works great most of the time, but it seems to have trouble with some "easy images" like the one below. The it and pants are getting erased. Is that a known issue? Thank you!
ComfyUI_temp_hslot_00001_
Output_temp_kiyby_00001_

Not getting the same quality as Hugging face demo

Dear Author,

Thank you for your amazing work.

I tried deploying and running your code locally with the same requirements as listed in the Hugging Face repo and with the same model, BiRefNet-massive-epoch_240.pth, but we still get a different quality.

We are using CUDA 12.3 and Python 3.10.

Can you please tell me what I am doing wrong, so that I can get the same quality as the Hugging Face demo?

Thanks in advance.

After the two commits yesterday, the mask accuracy has dipped

This was working fine until yesterday, but somehow it doesn't work with the new changes.

I also did a hard reset, but now it is asking for this file or directory: '/root/autodl-tmp/weights/swin_large_patch4_window12_384_22kto1k.pth'

I am downloading it and trying it out, but I don't know what happened. I looked at the commits and cannot see anything major.

Finetuning with small dataset

First of all, thank you for your great project. I would like to ask for some recommendations about finetuning with a small dataset (around 400 images). My problem is main-car segmentation (segment only one car even if the image contains multiple cars; the main car is the biggest one, in the middle of the image).

  • Which layers should I freeze?
  • What learning rate should I start with?
  • Because car segmentation is pretty simple, should I turn off any loss components?
  • Do you have any idea how many images are enough for training?

Due to limited resources I can't try all of these things, so your suggestions would be really meaningful to me.

issues with the training Code

Hey, I was starting my training.

I just noticed some issues with the code quality,

like having self.sys_home_dir = '/root/autodl-tmp' or self.data_root_dir = os.path.join(self.sys_home_dir, 'datasets/dis').

Things like these make it hard to start training quickly; one has to dig into the code to find the issues.
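
One possible way to avoid hard-coded paths like these is to read the home directory from an environment variable and fall back to the current default. The sketch below is hypothetical: the `resolve_data_root` helper and the `BIREFNET_HOME` variable name are made up for illustration and do not exist in this repo.

```python
import os


def resolve_data_root(default_home: str = "/root/autodl-tmp",
                      env_var: str = "BIREFNET_HOME") -> str:
    """Build the dataset root from an environment variable or a default."""
    # BIREFNET_HOME is a hypothetical variable name, not defined by the repo.
    home = os.environ.get(env_var, default_home)
    return os.path.join(home, "datasets/dis")
```

With this pattern, a user could run `BIREFNET_HOME=/data ./train.sh ...` without editing config.py, while existing setups keep working through the default.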

Convolution kernel size mismatch

When running images with tensor size torch.Size([3, 1365, 1024]):

/content/BiRefNet/models/baseline.py in get_patches_batch(self, x, p)
227 for column_x in columns_x:
228 patches_x += [p.unsqueeze(0) for p in torch.split(column_x, split_size_or_sections=_size_h, dim=-2)]
--> 229 patch_sample = torch.cat(patches_x, dim=1)
230 patches_batch.append(patch_sample)
231 return torch.cat(patches_batch, dim=0)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 86 but got size 75 for tensor number 15 in the list.

I am trying to add padding but am facing issues.
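
For the padding attempt, the amount to add can be computed so that each side becomes a multiple of some block size. The sketch below is plain arithmetic; the `pad_amounts` helper is hypothetical, and the multiple of 32 is only an assumption for illustration. The value the model actually needs depends on how it splits the input into patches and is not confirmed here.

```python
from typing import Tuple


def pad_amounts(height: int, width: int, multiple: int = 32) -> Tuple[int, int]:
    """Return (pad_h, pad_w) needed to reach the next multiple of `multiple`."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    # (-n) % m is the distance from n up to the next multiple of m (0 if exact).
    pad_h = (-height) % multiple
    pad_w = (-width) % multiple
    return pad_h, pad_w
```

For the 1365 x 1024 input from the traceback and a multiple of 32, this gives 11 extra rows and no extra columns. In PyTorch the padding itself could then be applied with `torch.nn.functional.pad(x, (0, pad_w, 0, pad_h))`, and the prediction cropped back to the original size afterwards.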
