SpTr: PyTorch Spatially Sparse Transformer Library

SparseTransformer (SpTr) provides a fast, memory-efficient, and easy-to-use implementation of sparse transformers in which the number of tokens varies per attention window (e.g., window transformers for 3D point clouds).

SpTr has been used by the following works:

  • Spherical Transformer for LiDAR-based 3D Recognition (CVPR 2023): [Paper] [Code]

  • Stratified Transformer for 3D Point Cloud Segmentation (CVPR 2022): [Paper] [Code]

Installation

Install Dependency

pip install torch==1.8.0+cu111 torchvision==0.9.0+cu111 torchaudio==0.8.0 -f https://download.pytorch.org/whl/torch_stable.html
pip install torch_scatter==2.0.9
pip install torch_geometric==1.7.2

Compile sptr

python3 setup.py install
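
After installing, it can help to confirm that the packages are importable before running any model code. The following is a minimal sketch using only the Python standard library; the helper name missing_modules is hypothetical and not part of SpTr, and the module names are assumed to match the import names of the packages installed above.

```python
import importlib.util


def missing_modules(names):
    """Return the subset of top-level module names that cannot be found."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# Import names assumed for the dependencies listed above, plus sptr itself.
required = ["torch", "torch_scatter", "torch_geometric", "sptr"]

missing = missing_modules(required)
if missing:
    print("Missing modules:", ", ".join(missing))
else:
    print("All required modules are importable.")
```

This only checks that the modules can be located; it does not verify versions or that the CUDA extension was built for your GPU.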

Usage

SpTr can be used in most current transformer-based 3D point cloud networks with only a few minor modifications. First, define the attention module sptr.VarLengthMultiheadSA. Then, wrap the input features and indices in an sptr.SparseTrTensor and pass it to the module. A simple example follows. For more complex usage, refer to the code of the works listed above (e.g., SphereFormer, StratifiedFormer).

Example

import numpy as np
import sptr

# Define module (e.g., inside the __init__ of your nn.Module)
dim = 48
num_heads = 3
indice_key = 'sptr_0'
window_size = np.array([0.4, 0.4, 0.4])  # can also be integers for voxel-based methods
shift_win = False  # whether to adopt shifted window
self.attn = sptr.VarLengthMultiheadSA(
    dim, 
    num_heads, 
    indice_key, 
    window_size, 
    shift_win
)

# Wrap the input features and indices into SparseTrTensor.
# Note: indices can be either integers for voxel-based methods or floats (i.e., xyz) for point-based methods
# feats: [N, C], indices: [N, 4] with batch indices in the 0-th column
input_tensor = sptr.SparseTrTensor(feats, indices, spatial_shape=None, batch_size=None)
output_tensor = self.attn(input_tensor)

# Extract features from output tensor
output_feats = output_tensor.query_feats
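
To make the expected input layout more concrete, here is a small pure-Python sketch (no SpTr or PyTorch required). It shows how per-sample point clouds can be flattened into [N, 4] rows with the batch index in the 0-th column, and why window attention leads to a varying number of tokens per window. The helpers build_indices and tokens_per_window are hypothetical names introduced for illustration; the grouping rule (floor of each coordinate divided by the window size) is an assumption, not SpTr's actual CUDA implementation, and shifted windows are not modeled.

```python
import math
from collections import Counter


def build_indices(clouds):
    """Flatten per-sample xyz lists into rows of [batch_idx, x, y, z]."""
    rows = []
    for batch_idx, cloud in enumerate(clouds):
        for x, y, z in cloud:
            rows.append([float(batch_idx), x, y, z])
    return rows


def tokens_per_window(indices, window_size):
    """Count how many points fall into each (batch, window cell) pair."""
    counts = Counter()
    for b, x, y, z in indices:
        # Assumed partition rule: the window cell is floor(coordinate / window size).
        cell = tuple(math.floor(c / w) for c, w in zip((x, y, z), window_size))
        counts[(int(b),) + cell] += 1
    return counts


clouds = [
    [(0.1, 0.1, 0.1), (0.3, 0.2, 0.1), (0.5, 0.1, 0.1)],  # sample 0: 3 points
    [(0.1, 0.1, 0.1), (0.9, 0.9, 0.9)],                    # sample 1: 2 points
]
indices = build_indices(clouds)  # 5 rows, batch index in column 0
counts = tokens_per_window(indices, window_size=(0.4, 0.4, 0.4))

for key, n in sorted(counts.items()):
    print(key, n)
```

In real code, rows like these would be converted to a tensor before being passed as the indices argument of sptr.SparseTrTensor. The uneven counts per window are what the variable-length attention module is designed to handle.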

Authors

Xin Lai (a Ph.D. student at CSE CUHK, [email protected]) - Initial CUDA implementation, maintenance.

Fanbin Lu (a Ph.D. student at CSE CUHK) - Improved CUDA implementation, maintenance.

Yukang Chen (a Ph.D. student at CSE CUHK) - Maintenance.

Cite

If you find this project useful, please consider citing:

@inproceedings{lai2023spherical,
  title={Spherical Transformer for LiDAR-based 3D Recognition},
  author={Lai, Xin and Chen, Yukang and Lu, Fanbin and Liu, Jianhui and Jia, Jiaya},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  year={2023}
}
@inproceedings{lai2022stratified,
  title={Stratified transformer for 3d point cloud segmentation},
  author={Lai, Xin and Liu, Jianhui and Jiang, Li and Wang, Liwei and Zhao, Hengshuang and Liu, Shu and Qi, Xiaojuan and Jia, Jiaya},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={8500--8509},
  year={2022}
}

License

This project is licensed under the Apache License 2.0.
