SampleNet: Differentiable Point Cloud Sampling

Created by Itai Lang, Asaf Manor and Shai Avidan from Tel-Aviv University.

teaser

Introduction

This work is based on our arXiv tech report. Please read it for more information.

There is a growing number of tasks that work directly on point clouds. As the size of the point cloud grows, so do the computational demands of these tasks. A possible solution is to sample the point cloud first. Classic sampling approaches, such as farthest point sampling (FPS), do not consider the downstream task. A recent work showed that learning a task-specific sampling can improve results significantly. However, the proposed technique did not deal with the non-differentiability of the sampling operation and offered a workaround instead.
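
For reference, farthest point sampling greedily adds the point that is farthest from the points selected so far. Below is a minimal, task-agnostic sketch of FPS in PyTorch; the function name and shapes are our own illustration and are not part of this repository:

import torch

def farthest_point_sampling(pc, k):
    """Greedy FPS over a point cloud of shape (n, 3). Illustrative sketch only."""
    n = pc.shape[0]
    selected = torch.zeros(k, dtype=torch.long)
    # Squared distance from every point to its closest selected point so far.
    dist = torch.full((n,), float("inf"))
    selected[0] = 0  # start from an arbitrary point
    for i in range(1, k):
        dist = torch.minimum(dist, ((pc - pc[selected[i - 1]]) ** 2).sum(dim=1))
        selected[i] = torch.argmax(dist)
    return pc[selected]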

We introduce a novel differentiable relaxation for point cloud sampling. Our approach employs a soft projection operation that approximates sampled points as a mixture of points in the primary input cloud. The approximation is controlled by a temperature parameter and converges to regular sampling when the temperature goes to zero. During training, we use a projection loss that encourages the temperature to drop, thereby driving every sample point to be close to one of the input points.
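
To make the idea concrete, the soft projection can be sketched as a softmax over the negative, temperature-scaled squared distances between each simplified point and its nearest neighbors in the input cloud. The sketch below is only an illustration of this idea (the actual implementation is in the registration/src/ folder); the function name and shapes are our own:

import torch

def soft_project(query, cloud, group_size=8, temperature=1.0):
    """Replace each query point with a convex combination of its nearest
    neighbors in the input cloud. Weights are a softmax over negative squared
    distances divided by the squared temperature. Illustrative sketch only."""
    d2 = torch.cdist(query, cloud) ** 2                          # (m, n) squared distances
    knn_d2, knn_idx = torch.topk(d2, k=group_size, dim=1, largest=False)
    weights = torch.softmax(-knn_d2 / temperature ** 2, dim=1)   # (m, k)
    neighbors = cloud[knn_idx]                                   # (m, k, 3)
    return (weights.unsqueeze(-1) * neighbors).sum(dim=1)        # (m, 3)

As the temperature approaches zero, the weights collapse onto the nearest input point, recovering hard (non-differentiable) sampling.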

This approximation scheme leads to consistently good results on various applications such as classification, retrieval, and geometric reconstruction. We also show that the proposed sampling network can be used as a front to a point cloud registration network. This is a challenging task since sampling must be consistent across two different point clouds. In all cases, our method works better than existing non-learned and learned sampling alternatives.

Citation

If you find our work useful in your research, please consider citing:

@article{lang2019samplenet,
  title={SampleNet: Differentiable Point Cloud Sampling},
  author={Lang, Itai and Manor, Asaf and Avidan, Shai},
  journal={arXiv preprint arXiv:1912.03663},
  year={2019}
}

Installation and usage

This project contains three sub-directories, each of which is a stand-alone project with its own instructions. Please see the README.md in the classification, registration, and reconstruction directories.

Usage of SampleNet in another project

The classification and reconstruction projects are implemented in TensorFlow, while the registration project is implemented in PyTorch. The example code snippet below shows how to use the PyTorch implementation of SampleNet with any task. The SampleNet source files for this example are in the registration/src/ folder.

import torch
from src import SampleNet, sputils

"""This code bit assumes a defined and pretrained task_model(),
data, and an optimizer."""

"""Get SampleNet parsing options and add your own."""
parser = sputils.get_parser()
args = parser.parse_args()

"""Create a data loader."""
trainloader = torch.utils.data.DataLoader(
    DATA, batch_size=32, shuffle=True
)

"""Create a SampleNet sampler instance."""
sampler = SampleNet(
    num_out_points=args.num_out_points,
    bottleneck_size=args.bottleneck_size,
    group_size=args.projection_group_size,
    initial_temperature=1.0,
    input_shape="bnc",
    output_shape="bnc",
)

"""For inference time behavior, set sampler.training = False."""
sampler.training = True

"""Training routine."""
for epoch in range(EPOCHS):
    for pc in trainloader:
        # Sample and predict
        simp_pc, proj_pc = sampler(pc)
        pred = task_model(proj_pc)

        # Compute losses
        simplification_loss = sampler.get_simplification_loss(
            pc, simp_pc, args.num_out_points
        )
        projection_loss = sampler.get_projection_loss()
        samplenet_loss = args.alpha * simplification_loss + args.lmbda * projection_loss

        task_loss = task_model.loss(pred)

        # Equation (1) in SampleNet paper
        loss = task_loss + samplenet_loss

        # Backward + Optimize
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
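
At inference time, a minimal sketch would look like the following, assuming the sampler keeps the same output signature when sampler.training is False and that a testloader is defined analogously to trainloader above:

"""Inference sketch: switch the sampler to its inference-time behavior
and feed the sampled points to the task model. testloader is assumed."""
sampler.training = False
with torch.no_grad():
    for pc in testloader:
        simp_pc, proj_pc = sampler(pc)
        pred = task_model(proj_pc)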

License

This project is licensed under the terms of the MIT license (see LICENSE for details).
