
Comments (4)

endeleze commented on June 4, 2024

Wow, cool, thanks for your reply! I just tested it again and it sped up 5x. Thank you!


Phil26AT commented on June 4, 2024

Hi @endeleze, on our side batching does yield higher throughput, but early stopping / point pruning is currently not supported. Without seeing your code it is hard to identify your issue, but if you are willing to share your code I can have a look. :)
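
In practice that means constructing the matcher with both adaptive options disabled. A minimal sketch, mirroring the -1 configuration used later in this thread:

      from lightglue import LightGlue

      # sketch: turn off the adaptive options before running batched inference
      matcher = LightGlue(
          features='superpoint',
          depth_confidence=-1,  # early stopping off
          width_confidence=-1,  # point pruning off
      ).eval().to('cuda')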


endeleze commented on June 4, 2024

Sure, here's my testing code:

When I use batch=30:

      n = 30

      config = {'width_confidence': 0.99,
                'depth_confidence': 0.95}

      Model_lg = LightGlue(pretrained='superpoint', **config).eval()
      local_feature_matcher = Model_lg.to('cuda')

      # pred holds precomputed SuperPoint features for a single image pair
      keys = ['descriptors0', 'keypoints0', 'image_size0', 'keypoint_scores0',
              'descriptors1', 'keypoints1', 'image_size1', 'keypoint_scores1']

      # replicate the pair n times along the batch dimension
      for key in keys:
          pred[key] = torch.cat([pred[key]] * n)

      for i in range(10):
          last_time = time()
          with torch.inference_mode():
              pred1 = local_feature_matcher(pred)
          current_time = time()
          print(current_time - last_time)

The average time of each loop is around 0.53 seconds. Then I use a single input:

        n = 1

        keys = ['descriptors0', 'keypoints0', 'image_size0', 'keypoint_scores0',
                'descriptors1', 'keypoints1', 'image_size1', 'keypoint_scores1']

        for key in keys:
            pred[key] = torch.cat([pred[key]] * n)

        # time 30 sequential single-pair calls per outer iteration
        for i in range(10):
            last_time = time()
            for j in range(30):
                with torch.inference_mode():
                    pred1 = local_feature_matcher(pred)
            current_time = time()
            print(current_time - last_time)

And here each loop takes around 0.19 seconds.
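
(A caveat that applies to both timings above: CUDA kernels launch asynchronously, so wrapping the call in time() alone can under-measure; synchronizing the GPU before reading the clock, as the non-batched example later in the thread does, is more reliable. A sketch reusing the names above:)

      torch.cuda.synchronize()  # flush pending GPU work before starting the clock
      last_time = time()
      with torch.inference_mode():
          pred1 = local_feature_matcher(pred)
      torch.cuda.synchronize()  # make sure the matcher has actually finished
      print(time() - last_time)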

I have modified the code in lightglue.py:

      def apply_cached_rotary_emb(
              freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
          # move the batch dim aside so the cached frequencies broadcast over it
          t_rot = rotate_half(t).permute(1, 0, 2, 3)
          t = t.permute(1, 0, 2, 3)
          return ((t * freqs[0]) + (t_rot * freqs[1])).permute(1, 0, 2, 3)
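
(For context: rotate_half is LightGlue's helper for the rotary positional encoding. A standalone sketch of what it computes, matching my reading of the source:)

      import torch

      def rotate_half(x: torch.Tensor) -> torch.Tensor:
          # split the last dim into (x1, x2) pairs and rotate each by 90 degrees:
          # (x1, x2) -> (-x2, x1)
          x = x.unflatten(-1, (-1, 2))
          x1, x2 = x.unbind(dim=-1)
          return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2)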

And:

      def _forward(self, data: dict) -> dict:
          for key in self.required_data_keys:
              assert key in data, f'Missing key {key} in data'
          kpts0_, kpts1_ = data['keypoints0'], data['keypoints1']
          b, m, _ = kpts0_.shape
          b, n, _ = kpts1_.shape
  
          kpts0 = normalize_keypoints(
              kpts0_, size=data.get('image_size0'), shape=None)
          kpts1 = normalize_keypoints(
              kpts1_, size=data.get('image_size1'), shape=None)
  
          assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
          assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
  
          desc0 = data['descriptors0'].detach().permute(0, 2, 1)
          desc1 = data['descriptors1'].detach().permute(0, 2, 1)
  
          assert desc0.shape[-1] == self.conf.input_dim
          assert desc1.shape[-1] == self.conf.input_dim
  
          if torch.is_autocast_enabled():
              desc0 = desc0.half()
              desc1 = desc1.half()
  
          desc0 = self.input_proj(desc0)
          desc1 = self.input_proj(desc1)
  
          # cache positional embeddings
          encoding0 = self.posenc(kpts0)
          encoding1 = self.posenc(kpts1)
  
          # GNN + final_proj + assignment
          ind0 = torch.arange(0, m).to(device=kpts0.device).expand(b, -1)
          ind1 = torch.arange(0, n).to(device=kpts0.device).expand(b, -1)
          prune0 = torch.ones_like(ind0)  # store layer where pruning is detected
          prune1 = torch.ones_like(ind1)
          dec, wic = self.conf.depth_confidence, self.conf.width_confidence
          token0, token1 = None, None
          for i in range(self.conf.n_layers):
              # self+cross attention
              desc0, desc1 = self.self_attn[i](
                  desc0, desc1, encoding0, encoding1)
              desc0, desc1 = self.cross_attn[i](desc0, desc1)
              if i == self.conf.n_layers - 1:
                  continue  # no early stopping or adaptive width at last layer
              if dec > 0:  # early stopping
                  token0, token1 = self.token_confidence[i](desc0, desc1)
                  if self.stop(token0, token1, self.conf_th(i), dec, m+n):
                      break
              if wic > 0:  # point pruning
                  match0, match1 = self.log_assignment[i].scores(desc0, desc1)
                  mask0 = self.get_mask(token0, match0, self.conf_th(i), 1-wic)
                  mask1 = self.get_mask(token1, match1, self.conf_th(i), 1-wic)
                  ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
                  desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
                  if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
                      break
                  encoding0 = encoding0[:, :, mask0][:, None]
                  encoding1 = encoding1[:, :, mask1][:, None]
              prune0[:, ind0] += 1
              prune1[:, ind1] += 1
  
          if wic > 0:  # scatter with indices after pruning
              scores_, _ = self.log_assignment[i](desc0, desc1)
              dt, dev = scores_.dtype, scores_.device
              scores = torch.zeros(b, m+1, n+1, dtype=dt, device=dev)
              scores[:, :-1, :-1] = -torch.inf
              scores[:, ind0[0], -1] = scores_[:, :-1, -1]
              scores[:, -1, ind1[0]] = scores_[:, -1, :-1]
              x, y = torch.meshgrid(ind0[0], ind1[0], indexing='ij')
              scores[:, x, y] = scores_[:, :-1, :-1]
          else:
              scores, _ = self.log_assignment[i](desc0, desc1)
  
          m0, m1, mscores0, mscores1 = filter_matches(
              scores, self.conf.filter_threshold)
  
          return {
              'log_assignment': scores,
              'matches0': m0,
              'matches1': m1,
              'matching_scores0': mscores0,
              'matching_scores1': mscores1,
              'stop': i+1,
              'prune0': prune0,
              'prune1': prune1,
          }
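
(To make the returned fields concrete: matches0 holds, for every keypoint in image 0, the index of its match among the keypoints of image 1, or -1 if unmatched. A sketch of extracting matched coordinates for batch entry 0, assuming the variable names from the test script above:)

      matches0 = pred1['matches0'][0]  # (M,) indices into keypoints1, -1 = unmatched
      valid = matches0 > -1
      mkpts0 = pred['keypoints0'][0][valid]
      mkpts1 = pred['keypoints1'][0][matches0[valid]]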


Phil26AT commented on June 4, 2024

Hi, thanks for providing the code!

      config = {'width_confidence': 0.99,
                'depth_confidence': 0.95}

Neither of these options is supported with batched inference, so I set both to -1 (off).

When rerunning your code with N=30 and 2048 points per image, I get 0.45 s for batched inference (67 FPS, 100% GPU utilization) and 0.61 s for non-batched (49 FPS, 84% GPU utilization), which is a speedup of 36%.

Furthermore, if one decreases the number of keypoints per image to 512, batched inference yields a speedup of more than 500%. This is to be expected: GPU utilization is much lower with fewer keypoints, so the performance gain from batching is higher (the one-line change is sketched below).
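
(Reproducing the 512-keypoint setting only requires changing the extractor line in the script that follows:)

      # hypothetical variant: same experiment with fewer keypoints per image
      extractor = SuperPoint(max_num_keypoints=512, detection_threshold=0.0).eval().to(device)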

Maybe the problem is indeed with the configs you used. For completeness, here is the code I used for this experiment:

from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image
from time import time
from pathlib import Path
import torch
images = Path('assets')
device = torch.device('cuda')

extractor = SuperPoint(max_num_keypoints=2048, detection_threshold=0.0).eval().to(device)  # load the extractor
match_conf = {
    'width_confidence': -1,  # for point pruning
    'depth_confidence': -1,  # for early stopping,
    'flash': True,
}
matcher = LightGlue(features='superpoint', **match_conf).eval().to(device)

image0, scales0 = load_image(images / 'DSC_0411.JPG')
image1, scales1 = load_image(images / 'DSC_0410.JPG')

# batched
n = 30
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
for key in feats0.keys():
    feats0[key] = torch.cat([feats0[key]] * n)
    feats1[key] = torch.cat([feats1[key]] * n)
for i in range(10):
    last_time = time()
    with torch.inference_mode():
        pred1 = matcher({'image0': feats0, 'image1': feats1})
    current_time = time()
    print(current_time - last_time)


# non-batched
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
for i in range(10):
    torch.cuda.synchronize()
    last_time = time()
    preds = []
    for j in range(n):  # renamed from i to avoid shadowing the outer loop
        with torch.inference_mode():
            preds.append(matcher({'image0': feats0, 'image1': feats1}))
    torch.cuda.synchronize()
    current_time = time()
    print(current_time - last_time)
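
(The FPS figures quoted above follow directly from these timings; a trivial check using the reported numbers:)

      def fps(elapsed_s: float, n_pairs: int) -> float:
          # image pairs matched per second over the timed span
          return n_pairs / elapsed_s

      print(fps(0.45, 30))  # ~67 FPS, batched loop
      print(fps(0.61, 30))  # ~49 FPS, 30 sequential single-pair calls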

