
segvit's Introduction

Official PyTorch implementation of SegViT [ckpt]

SegViT: Semantic Segmentation with Plain Vision Transformers

Bowen Zhang, Zhi Tian, Quan Tang, Xiangxiang Chu, Xiaolin Wei, Chunhua Shen and Yifan Liu.

NeurIPS 2022. [paper]

SegViTv2: Exploring Efficient and Continual Semantic Segmentation with Plain Vision Transformers

Bowen Zhang, Liyang Liu, Minh Hieu Phan, Zhi Tian, Chunhua Shen and Yifan Liu.

IJCV 2023. [paper] [we are refactoring code for release ...]

This repository contains the official PyTorch training and evaluation code and the pretrained models for SegViT and its extended version, SegViT v2.

Highlights

  • Simple decoder: The Attention-to-Mask (ATM) decoder provides a simple segmentation head for plain Vision Transformers and is easy to extend to other downstream tasks.
  • Light structure: We propose the Shrunk structure, which saves up to 40% of the computational cost of a model with a ViT backbone.
  • Stronger performance: We achieve state-of-the-art mIoU of 55.2% (MS) on ADE20K, 50.3% (MS) on COCOStuff10K, and 65.3% (MS) on PASCAL-Context, with the lowest computational cost among counterparts that use a ViT backbone.
  • Scalability: SegViT v2 employs a more powerful backbone (BEiT-V2) and obtains state-of-the-art mIoU of 58.2% (MS) on ADE20K, 53.5% (MS) on COCOStuff10K, and 67.14% (MS) on PASCAL-Context, showcasing strong scalability.
  • Continual learning: We adapt SegViT v2 to continual semantic segmentation, where it shows nearly zero forgetting of previously learned knowledge.

As shown in the following figure, the similarity between the class queries and the image features is transferred to the segmentation mask.
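A minimal NumPy sketch of the attention-to-mask idea follows. It is an illustration, not the repository's ATM implementation: it assumes N class queries and L = H×W flattened image tokens that share a channel dimension C, uses a scaled dot product as the similarity, and applies a sigmoid so that each similarity becomes a per-pixel mask probability. The function name `attention_to_mask` and all shapes are hypothetical.

```python
import numpy as np


def attention_to_mask(queries: np.ndarray, features: np.ndarray, height: int, width: int) -> np.ndarray:
    """Map query-token similarity to per-class masks (illustrative sketch only).

    queries:  (N, C) array, one row per class query (assumed layout).
    features: (L, C) array of flattened image tokens, with L == height * width.
    Returns an (N, height, width) array of mask probabilities.
    """
    num_tokens, channels = features.shape
    if num_tokens != height * width:
        raise ValueError("features must contain height * width tokens")
    # Scaled dot-product similarity between every class query and every token: (N, L).
    similarity = queries @ features.T / np.sqrt(channels)
    # A sigmoid turns each similarity independently into a mask probability.
    masks = 1.0 / (1.0 + np.exp(-similarity))
    # Fold the token axis back into the H x W grid.
    return masks.reshape(queries.shape[0], height, width)


rng = np.random.default_rng(0)
masks = attention_to_mask(rng.standard_normal((3, 8)), rng.standard_normal((4 * 5, 8)), 4, 5)
print(masks.shape)  # (3, 4, 5)
```

Because each class gets its own sigmoid rather than a softmax over classes, several class masks can be active at the same pixel; how the repository combines them into the final prediction is not shown here.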

Getting started

  1. Install the mmsegmentation library and some required packages.
pip install mmcv-full==1.4.4 mmsegmentation==0.24.0
pip install scipy timm
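Because the install commands pin exact mmcv-full and mmsegmentation versions, a quick check with the standard-library `importlib.metadata` can catch a mismatched environment before training starts. The helper `find_version_mismatches` below is hypothetical (it is not part of the repository); it only compares exact version strings and does not check PyTorch or CUDA compatibility.

```python
from importlib.metadata import PackageNotFoundError, version
from typing import Dict, Optional

# Versions pinned by the pip command above.
PINNED = {"mmcv-full": "1.4.4", "mmsegmentation": "0.24.0"}


def find_version_mismatches(pinned: Dict[str, str]) -> Dict[str, Optional[str]]:
    """Return {package: installed version, or None if missing} for every unmet pin."""
    mismatches: Dict[str, Optional[str]] = {}
    for package, expected in pinned.items():
        try:
            installed: Optional[str] = version(package)
        except PackageNotFoundError:
            installed = None  # the distribution is not installed at all
        if installed != expected:
            mismatches[package] = installed
    return mismatches


if __name__ == "__main__":
    problems = find_version_mismatches(PINNED)
    if problems:
        for package, installed in problems.items():
            print(f"{package}: expected {PINNED[package]}, found {installed}")
    else:
        print("All pinned versions are installed.")
```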

Training

bash tools/dist_train.sh configs/segvit/segvit_vit-l_jax_640x640_160k_ade20k.py {num_gpus}

Evaluation

bash tools/dist_test.sh configs/segvit/segvit_vit-l_jax_640x640_160k_ade20k.py {path_to_ckpt} {num_gpus} --eval mIoU

Datasets

Please follow the mmsegmentation data preparation instructions.

Results

| Backbone | Dataset | mIoU (%) | mIoU (MS, %) | GFLOPs | Checkpoint |
|---|---|---|---|---|---|
| ViT-Base | ADE20K | 51.3 | 53.0 | 120.9 | model |
| ViT-Large (Shrunk) | ADE20K | 53.9 | 55.1 | 373.5 | model |
| ViT-Large | ADE20K | 54.6 | 55.2 | 637.9 | model |
| ViT-Large (Shrunk) | COCOStuff10K | 49.1 | 49.4 | 224.8 | model |
| ViT-Large | COCOStuff10K | 49.9 | 50.3 | 383.9 | model |
| ViT-Large (Shrunk) | PASCAL-Context (59 classes) | 62.3 | 63.7 | 186.9 | model |
| ViT-Large | PASCAL-Context (59 classes) | 64.1 | 65.3 | 321.6 | model |
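The Results table can be used to check the trade-off behind the Shrunk structure. This small script only re-derives numbers already in the table: for each dataset it computes how much cheaper the Shrunk ViT-Large is in GFLOPs and how much single-scale mIoU it gives up. The dictionary layout and the helper name `shrunk_tradeoff` are illustrative choices.

```python
# (single-scale mIoU, GFLOPs) pairs copied from the Results table.
RESULTS = {
    "ADE20K": {"full": (54.6, 637.9), "shrunk": (53.9, 373.5)},
    "COCOStuff10K": {"full": (49.9, 383.9), "shrunk": (49.1, 224.8)},
    "PASCAL-Context": {"full": (64.1, 321.6), "shrunk": (62.3, 186.9)},
}


def shrunk_tradeoff(full, shrunk):
    """Return (GFLOPs saving in percent, single-scale mIoU drop in points)."""
    (full_miou, full_gflops), (shrunk_miou, shrunk_gflops) = full, shrunk
    saving = 100.0 * (1.0 - shrunk_gflops / full_gflops)
    return saving, full_miou - shrunk_miou


for dataset, rows in RESULTS.items():
    saving, drop = shrunk_tradeoff(rows["full"], rows["shrunk"])
    print(f"{dataset}: {saving:.1f}% fewer GFLOPs, {drop:.1f} mIoU lower")
```

On these rows the saving is 41.4% (ADE20K), 41.4% (COCOStuff10K) and 41.9% (PASCAL-Context), for a single-scale mIoU drop of 0.7, 0.8 and 1.8 points respectively.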

License

For academic use, this project is licensed under the 2-clause BSD License - see the LICENSE file for details. For commercial use, please contact the authors.

Citation

@article{zhang2022segvit,
  title={SegViT: Semantic Segmentation with Plain Vision Transformers},
  author={Zhang, Bowen and Tian, Zhi and Tang, Quan and Chu, Xiangxiang and Wei, Xiaolin and Shen, Chunhua and Liu, Yifan},
  journal={NeurIPS},
  year={2022}
}

@article{zhang2023segvitv2,
  title={SegViTv2: Exploring Efficient and Continual Semantic Segmentation with Plain Vision Transformers},
  author={Zhang, Bowen and Liu, Liyang and Phan, Minh Hieu and Tian, Zhi and Shen, Chunhua and Liu, Yifan},
  journal={IJCV},
  year={2023}
}

segvit's People

Contributors

akideliu, cshen, irfanicmll, zbwxp


segvit's Issues

use checkpointing to save memory

Did you try checkpointing? It causes the error "Expected to mark a variable ready only once". Without checkpointing, I can't even train the large model on 8 RTX 3090 GPUs.

Additional loss term for continual learning

Hello,
I have some questions about SegViT v2.
When you used SegViT v2 for continual learning, did you add any extra loss term to handle forgetting or the background-shift problem?

> We trained SegViT Large on the PASCAL-Context dataset for around 8 hours on one node with 4 A100 GPUs.


Hi, I have a naive question: does the PASCAL-Context configuration default to 60 categories? If I want to change PASCAL-Context to 59 classes, do I set model=[...decode_head=dict(...num_classes=59...) in configs/segvit/segvit...480..80k_pc.py ..] and loss_decode=dict(...num_classes=59...)? Thanks again!

Originally posted by @1787648106 in #7 (comment)

Doubt About the Paper: Segmentation Mask

If the decoder has 8 heads (as stated in the config), how do you convert this to N×L, which is later reshaped into the segmentation mask? With 8 heads, wouldn't we have 8 possible masks?
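One common way to collapse multi-head similarities into a single N×L map is to average the per-head attention logits over the head axis before reshaping. This is an assumption for illustration, not a confirmed description of SegViT's decoder. The NumPy sketch below omits the learned query/key projections, so here the head average reduces to a rescaled single dot product; in a real decoder each head would see differently projected features. The name `heads_to_mask_logits` is hypothetical.

```python
import numpy as np


def heads_to_mask_logits(q: np.ndarray, k: np.ndarray, num_heads: int) -> np.ndarray:
    """Average per-head scaled dot-product logits into one (N, L) map (sketch).

    q: (N, C) class queries; k: (L, C) image tokens; C must be divisible by num_heads.
    """
    n, c = q.shape
    l = k.shape[0]
    if c % num_heads:
        raise ValueError("C must be divisible by num_heads")
    head_dim = c // num_heads
    # Split channels into heads: (num_heads, N, head_dim) and (num_heads, L, head_dim).
    qh = q.reshape(n, num_heads, head_dim).transpose(1, 0, 2)
    kh = k.reshape(l, num_heads, head_dim).transpose(1, 0, 2)
    # Per-head similarity logits: (num_heads, N, L).
    logits = qh @ kh.transpose(0, 2, 1) / np.sqrt(head_dim)
    # Collapse the head axis so each class query ends up with exactly one map: (N, L).
    return logits.mean(axis=0)


rng = np.random.default_rng(0)
h, w = 4, 6
mask_logits = heads_to_mask_logits(rng.standard_normal((5, 16)), rng.standard_normal((h * w, 16)), num_heads=8)
masks = 1.0 / (1.0 + np.exp(-mask_logits.reshape(5, h, w)))  # sigmoid, then N x H x W
print(masks.shape)  # (5, 4, 6)
```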

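Model weights can no longer be downloaded

CloudStor was decommissioned at 12 noon (Australian Eastern Time) on Friday, 15 December 2023.
FileSender will continue as a standalone service.
Could the authors please update the download links for the weights?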

random seed

Could you provide the seed you use for reproducing the results?

Pre-trained weight files missing

Hi, I think the server hosting the pre-trained weight file was decommissioned recently. Is it possible to host the models somewhere else?

Set semseg as logit or probability?

def semantic_inference(self, mask_cls, mask_pred):
    semseg = mask_cls.softmax(-1) @ mask_pred.sigmoid()

This results in 1 >= semseg >= 0.
When we compute loss_func, however, we treat semseg as logits and compute the probability with semseg.sigmoid().
If 1 >= semseg >= 0, every probability must be >= 0.5, but the probability should be able to take any value down to 0.
Maybe I am wrong or have missed part of the calculation; please feel free to correct me.
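The reporter's argument can be checked numerically with a small NumPy sketch. It assumes `mask_cls` has shape (queries, classes) and `mask_pred` has shape (queries, pixels); the transpose is added only so that these assumed shapes multiply, and the sizes are arbitrary. The argument needs only semseg >= 0 (not the upper bound): a product of a softmax weight and a sigmoid is non-negative, and sigmoid(x) >= 0.5 for every x >= 0.

```python
import numpy as np


def sigmoid(x: np.ndarray) -> np.ndarray:
    return 1.0 / (1.0 + np.exp(-x))


def softmax(x: np.ndarray, axis: int = -1) -> np.ndarray:
    shifted = x - x.max(axis=axis, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(shifted)
    return e / e.sum(axis=axis, keepdims=True)


rng = np.random.default_rng(0)
num_queries, num_classes, num_pixels = 4, 3, 10
mask_cls = rng.standard_normal((num_queries, num_classes))  # class logits per query
mask_pred = rng.standard_normal((num_queries, num_pixels))  # mask logits per query

# Mirrors the quoted line under the assumed shapes: (classes, queries) @ (queries, pixels).
semseg = softmax(mask_cls, axis=-1).T @ sigmoid(mask_pred)

# Every term is a product of non-negative numbers, so semseg is non-negative ...
print(bool((semseg >= 0).all()))  # True
# ... and a second sigmoid therefore never produces a value below 0.5.
print(bool((sigmoid(semseg) >= 0.5).all()))  # True
```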

Continual Learning part

Nice work! I am most interested in your recent work SegViTv2, which can easily be extended to semantic segmentation in the continual learning setting. Where is the corresponding code? I look forward to your reply; thanks very much.

GPU Configuration for Training?

Thank you so much for your great work.
Could you please provide information about the number and type of GPUs you used for training the models?
