
RobustSAM: Segment Anything Robustly on Degraded Images (CVPR 2024 Highlight)

Official repository for RobustSAM: Segment Anything Robustly on Degraded Images

✨ Training code and dataset will be released before August!

Project Page | Paper | Video | Dataset

Updates

  • July 2024: ✨ Checkpoints for different ViT backbones are released!
  • June 2024: ✨ Inference code has been released!
  • Feb 2024: ✨ RobustSAM was accepted to CVPR 2024!

Introduction

Segment Anything Model (SAM) has emerged as a transformative approach in image segmentation, acclaimed for its robust zero-shot segmentation capabilities and flexible prompting system. Nonetheless, its performance is challenged by images with degraded quality. Addressing this limitation, we propose the Robust Segment Anything Model (RobustSAM), which enhances SAM's performance on low-quality images while preserving its promptability and zero-shot generalization.

Our method leverages the pre-trained SAM model with only a marginal increase in parameters and computation. The additional parameters of RobustSAM can be optimized within 30 hours on eight GPUs, demonstrating its feasibility and practicality for typical research laboratories. We also introduce the Robust-Seg dataset, a collection of 688K image-mask pairs with different degradations, designed to train and evaluate our model optimally. Extensive experiments across various segmentation tasks and datasets confirm RobustSAM's superior performance, especially under zero-shot conditions, underscoring its potential for extensive real-world application. Additionally, our method has been shown to effectively improve the performance of SAM-based downstream tasks such as single image dehazing and deblurring.


Setup

  1. Create a conda environment and activate it.
conda create --name robustsam python=3.10 -y
conda activate robustsam
  2. Clone the repository and enter its directory.
git clone https://github.com/robustsam/RobustSAM
cd RobustSAM
  3. Check your CUDA version with the command below.
nvidia-smi
  4. Substitute your CUDA version into the command below (a quick sanity check is shown after this list).
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu[$YOUR_CUDA_VERSION]
# For example, for CUDA 11.7: pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu117
  5. Install the remaining dependencies.
pip install -r requirements.txt
  6. Download the pretrained RobustSAM checkpoints of different sizes and place them in the current directory.
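
To confirm that the installed PyTorch build matches your CUDA setup, you can run a quick sanity check using standard PyTorch attributes (nothing RobustSAM-specific):

python -c "import torch; print(torch.__version__, torch.version.cuda, torch.cuda.is_available())"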

Demo

We have prepared some images in the demo_images folder for demonstration purposes. In addition, two prompting modes are available: box prompts and point prompts.

  • For box prompt:
python eval.py --bbox --model_size l
  • For point prompt:
python eval.py --model_size l

By default, demo results are saved to demo_result/[$PROMPT_TYPE].
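
Beyond the demo script, the model can also be driven programmatically, mirroring the original SAM predictor API. The sketch below is illustrative only: the opt=None registry argument follows the usage shown in the Issues section below, while the checkpoint filename and the SamPredictor import path are assumptions you may need to adapt to your checkout.

import numpy as np
import torch
from robust_segment_anything import sam_model_registry, SamPredictor  # import path assumed

device = "cuda" if torch.cuda.is_available() else "cpu"

# Load a ViT-L RobustSAM checkpoint (filename is a placeholder).
sam = sam_model_registry["vit_l"](opt=None, checkpoint="robustsam_checkpoint_l.pth").to(device)
predictor = SamPredictor(sam)

# In practice, load one of the degraded images from demo_images instead.
image = np.zeros((1024, 1024, 3), dtype=np.uint8)
predictor.set_image(image)

# Single foreground point prompt at the image center.
masks, scores, _ = predictor.predict(
    point_coords=np.array([[512, 512]]),
    point_labels=np.array([1]),
    multimask_output=True,
)
print(masks.shape, scores)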

Comparison of computational requirements


Visual Comparison


Quantitative Comparison

Seen dataset with synthetic degradation


Unseen dataset with synthetic degradation


Unseen dataset with real degradation


Reference

If you find this work useful, please consider citing us!

@inproceedings{chen2024robustsam,
  title={RobustSAM: Segment Anything Robustly on Degraded Images},
  author={Chen, Wei-Ting and Vong, Yu-Jiet and Kuo, Sy-Yen and Ma, Sizhou and Wang, Jian},
  booktitle={CVPR},
  year={2024}
}

Acknowledgements

We thank the authors of SAM, on which our repository is based.


Issues

AutomaticMaskGenerator not working

Here is the driver code:

import torch
# Import path follows the repo layout visible in the traceback below.
from robust_segment_anything import sam_model_registry, SamAutomaticMaskGenerator

device = "cuda" if torch.cuda.is_available() else "cpu"
sam = sam_model_registry["vit_l"](opt=None, checkpoint="robustsam_checkpoint.pth").to(device=device)
mask_generator = SamAutomaticMaskGenerator(
    model=sam,
    points_per_side=32,
    pred_iou_thresh=0.86,
    stability_score_thresh=0.5,
    crop_n_layers=1,
    crop_n_points_downscale_factor=2,
    min_mask_region_area=100,
)

masks = mask_generator.generate(image) # image is a numpy array with shape (2048,2048,3)

Error:

  0%|          | 0/1 [00:01<?, ?it/s]
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/tmp/ipykernel_617/4150095890.py in <module>
     33 
     34 for (x1,y1,x2,y2) in tqdm.tqdm(calculate_slice_bboxes(image.shape[0], image.shape[1], 2048, 2048)):
---> 35     reconstructed_image[y1:y2, x1:x2] = process_window(image[y1:y2, x1:x2])

/tmp/ipykernel_617/1729186085.py in process_window(image)
      2     # generate masks
      3     final_mask = image.copy()
----> 4     masks = mask_generator.generate(image)
      5     masks = [x for x in masks if x['area'] > 1000]
      6     # masks = [x for x in masks if x['area'] > 100]

/opt/conda/lib/python3.8/site-packages/torch/utils/_contextlib.py in decorate_context(*args, **kwargs)
    113     def decorate_context(*args, **kwargs):
    114         with ctx_factory():
--> 115             return func(*args, **kwargs)
    116 
    117     return decorate_context

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in generate(self, image)
    161 
    162         # Generate masks
--> 163         mask_data = self._generate_masks(image)
    164 
    165         # Filter small disconnected regions and holes in masks

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in _generate_masks(self, image)
    205         data = MaskData()
    206         for crop_box, layer_idx in zip(crop_boxes, layer_idxs):
--> 207             crop_data = self._process_crop(image, crop_box, layer_idx, orig_size)
    208             data.cat(crop_data)
    209 

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in _process_crop(self, image, crop_box, crop_layer_idx, orig_size)
    246         count = 0
    247         for (points,) in batch_iterator(self.points_per_batch, points_for_image):
--> 248             batch_data = self._process_batch(points, cropped_im_size, crop_box, orig_size)
    249             # print('Second: ', mask_logits.shape)
    250 

/RobustSAM/robust_segment_anything/automatic_mask_generator.py in _process_batch(self, points, im_size, crop_box, orig_size)
    285         in_points = torch.as_tensor(transformed_points, device=self.predictor.device)
    286         in_labels = torch.ones(in_points.shape[0], dtype=torch.int, device=in_points.device)
--> 287         masks, iou_preds, _ = self.predictor.predict_torch(
    288             in_points[:, None, :],
    289             in_labels[:, None],

/RobustSAM/robust_segment_anything/predictor.py in predict_torch(self, point_coords, point_labels, boxes, mask_input, multimask_output, return_logits)
    228 
    229         # Predict masks
--> 230         low_res_masks, iou_predictions = self.model.mask_decoder(
    231             image_embeddings=self.features,
    232             image_pe=self.model.prompt_encoder.get_dense_pe(),

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *args, **kwargs)
   1499                 or _global_backward_pre_hooks or _global_backward_hooks
   1500                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1501             return forward_call(*args, **kwargs)
   1502         # Do not call functions when jit is used
   1503         full_backward_hooks, non_full_backward_hooks = [], []

TypeError: forward() missing 1 required positional argument: 'encoder_features'

On further investigation, I found that https://github.com/robustsam/RobustSAM/blob/main/robust_segment_anything/predictor.py#L230-L236 does not pass encoder_features, which is a required parameter of the mask decoder here: https://github.com/robustsam/RobustSAM/blob/main/robust_segment_anything/modeling/mask_decoder.py#L98C9-L117. Interestingly, a #TODO is noted for that parameter as well.
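
A minimal sketch of a possible fix, assuming set_image() also caches the intermediate encoder features; the self.encoder_features attribute is hypothetical, and the surrounding variable names follow the original SAM predictor:

# In robust_segment_anything/predictor.py, inside predict_torch():
low_res_masks, iou_predictions = self.model.mask_decoder(
    image_embeddings=self.features,
    image_pe=self.model.prompt_encoder.get_dense_pe(),
    sparse_prompt_embeddings=sparse_embeddings,
    dense_prompt_embeddings=dense_embeddings,
    multimask_output=multimask_output,
    encoder_features=self.encoder_features,  # hypothetical: forward the cached encoder features
)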

Where are the checkpoints?

I don't see the robustsam_checkpoint.pt file anywhere. Am I missing something, or have the trained checkpoints not been published yet?

About training code

Thank you very much for your excellent work on SAM-based segmentation. I was wondering if you plan to release your training code anytime soon.

Data

Will the training data be released? Thanks.

Do you have a ViT-H checkpoint?

I clicked the link to download the checkpoint, but there is only one file, and when I use that file for vit_h, I get an error. Do you have a ViT-H checkpoint?
