Comments (4)
Wow, cool! Thanks for your reply. I just tested it again and got a 5x speedup. Thank you!
from lightglue.
Hi @endeleze, on our side batching does yield higher throughput, but early stopping and point pruning are currently not supported with batched inputs. Without seeing your code it is hard to identify the issue, but if you are willing to share it, I can have a look. :)
Sure, here is my test code. With a batch size of 30:
import torch
from time import time
from lightglue import LightGlue

# `pred`: dict with the SuperPoint outputs for one image pair (defined elsewhere)
n = 30
config = {'width_confidence': 0.99,
          'depth_confidence': 0.95}
Model_lg = LightGlue(pretrained='superpoint', **config).eval()
local_feature_matcher = Model_lg.to('cuda')
keys = ['descriptors0', 'keypoints0', 'image_size0', 'keypoint_scores0',
        'descriptors1', 'keypoints1', 'image_size1', 'keypoint_scores1']
# replicate the single pair n times along the batch dimension
for key in keys:
    pred[key] = torch.cat([pred[key]] * n)
for i in range(10):
    last_time = time()
    # for i in range(30):
    with torch.inference_mode():
        pred1 = local_feature_matcher(pred)
    current_time = time()
    print(current_time - last_time)
Each iteration takes around 0.53 seconds on average. Then I use a single input pair and call the matcher 30 times per iteration:
n = 1
keys = ['descriptors0', 'keypoints0', 'image_size0', 'keypoint_scores0',
        'descriptors1', 'keypoints1', 'image_size1', 'keypoint_scores1']
for key in keys:
    pred[key] = torch.cat([pred[key]] * n)
for i in range(10):
    last_time = time()
    for i in range(30):  # 30 sequential calls on a single pair
        with torch.inference_mode():
            pred1 = local_feature_matcher(pred)
    current_time = time()
    print(current_time - last_time)
Here each iteration (30 sequential calls) takes around 0.19 seconds.
I have modified the following code in lightglue.py:
def apply_cached_rotary_emb(
        freqs: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    t_rot = rotate_half(t).permute(1, 0, 2, 3)
    t = t.permute(1, 0, 2, 3)
    return ((t * freqs[0]) + (t_rot * freqs[1])).permute(1, 0, 2, 3)
And:
def _forward(self, data: dict) -> dict:
    for key in self.required_data_keys:
        assert key in data, f'Missing key {key} in data'
    kpts0_, kpts1_ = data['keypoints0'], data['keypoints1']
    b, m, _ = kpts0_.shape
    b, n, _ = kpts1_.shape
    kpts0 = normalize_keypoints(
        kpts0_, size=data.get('image_size0'), shape=None)
    kpts1 = normalize_keypoints(
        kpts1_, size=data.get('image_size1'), shape=None)
    assert torch.all(kpts0 >= -1) and torch.all(kpts0 <= 1)
    assert torch.all(kpts1 >= -1) and torch.all(kpts1 <= 1)
    desc0 = data['descriptors0'].detach().permute(0,2,1)
    desc1 = data['descriptors1'].detach().permute(0,2,1)
    assert(desc0.shape[-1] == self.conf.input_dim)
    assert(desc1.shape[-1] == self.conf.input_dim)
    if torch.is_autocast_enabled():
        desc0 = desc0.half()
        desc1 = desc1.half()
    desc0 = self.input_proj(desc0)
    desc1 = self.input_proj(desc1)
    # cache positional embeddings
    encoding0 = self.posenc(kpts0)
    encoding1 = self.posenc(kpts1)
    # GNN + final_proj + assignment
    ind0 = torch.arange(0, m).to(device=kpts0.device).expand(b,-1)
    ind1 = torch.arange(0, n).to(device=kpts0.device).expand(b,-1)
    prune0 = torch.ones_like(ind0)  # store layer where pruning is detected
    prune1 = torch.ones_like(ind1)
    dec, wic = self.conf.depth_confidence, self.conf.width_confidence
    token0, token1 = None, None
    for i in range(self.conf.n_layers):
        # self+cross attention
        desc0, desc1 = self.self_attn[i](
            desc0, desc1, encoding0, encoding1)
        desc0, desc1 = self.cross_attn[i](desc0, desc1)
        if i == self.conf.n_layers - 1:
            continue  # no early stopping or adaptive width at last layer
        if dec > 0:  # early stopping
            token0, token1 = self.token_confidence[i](desc0, desc1)
            if self.stop(token0, token1, self.conf_th(i), dec, m+n):
                break
        if wic > 0:  # point pruning
            match0, match1 = self.log_assignment[i].scores(desc0, desc1)
            mask0 = self.get_mask(token0, match0, self.conf_th(i), 1-wic)
            mask1 = self.get_mask(token1, match1, self.conf_th(i), 1-wic)
            ind0, ind1 = ind0[mask0][None], ind1[mask1][None]
            desc0, desc1 = desc0[mask0][None], desc1[mask1][None]
            if desc0.shape[-2] == 0 or desc1.shape[-2] == 0:
                break
            encoding0 = encoding0[:, :, mask0][:, None]
            encoding1 = encoding1[:, :, mask1][:, None]
        prune0[:, ind0] += 1
        prune1[:, ind1] += 1
    if wic > 0:  # scatter with indices after pruning
        scores_, _ = self.log_assignment[i](desc0, desc1)
        dt, dev = scores_.dtype, scores_.device
        scores = torch.zeros(b, m+1, n+1, dtype=dt, device=dev)
        scores[:, :-1, :-1] = -torch.inf
        scores[:, ind0[0], -1] = scores_[:, :-1, -1]
        scores[:, -1, ind1[0]] = scores_[:, -1, :-1]
        x, y = torch.meshgrid(ind0[0], ind1[0], indexing='ij')
        scores[:, x, y] = scores_[:, :-1, :-1]
    else:
        scores, _ = self.log_assignment[i](desc0, desc1)
    m0, m1, mscores0, mscores1 = filter_matches(
        scores, self.conf.filter_threshold)
    return {
        'log_assignment': scores,
        'matches0': m0,
        'matches1': m1,
        'matching_scores0': mscores0,
        'matching_scores1': mscores1,
        'stop': i+1,
        'prune0': prune0,
        'prune1': prune1,
    }
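One plausible reason the pruning branch misbehaves with batches: boolean-mask indexing such as desc0[mask0], with a (b, m) mask on a (b, m, d) tensor, keeps only the selected rows and flattens them across the batch, giving shape (k, d); the trailing [None] then re-adds a batch axis of size 1. When pairs in the batch keep different numbers of points, the surviving points of all pairs end up merged into a single "image". The pure-Python sketch below only mimics that indexing on nested lists to show the shape behaviour; the helper mask_select is hypothetical and not part of LightGlue.

```python
from typing import List, Sequence


def mask_select(x: Sequence[Sequence[List[float]]],
                mask: Sequence[Sequence[bool]]) -> List[List[float]]:
    """Mimic x[mask] for x of shape (b, m, d) and mask of shape (b, m).

    The result has shape (k, d), where k is the total number of True
    entries over the whole batch: the batch axis is not preserved.
    """
    out: List[List[float]] = []
    for rows, keep in zip(x, mask):
        for row, k in zip(rows, keep):
            if k:
                out.append(list(row))
    return out


# Batch of 2 pairs, 3 points each, 1-D descriptors.
desc = [[[0.0], [1.0], [2.0]],
        [[10.0], [11.0], [12.0]]]
# Pair 0 keeps 2 points, pair 1 keeps 1 point: the per-pair counts differ.
mask = [[True, False, True],
        [False, True, False]]

selected = mask_select(desc, mask)
batched = [selected]  # equivalent of `[None]`: one batch element, not two
print(len(batched), len(selected))  # 1 3
```

Because the per-pair counts generally differ, the pruned points cannot simply be restacked into a (b, k, d) tensor without padding or per-pair bookkeeping.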
Hi, thanks for providing the code!
config = {'width_confidence': 0.99,
          'depth_confidence': 0.95}
Neither of these options is supported with batched inference, so I set both to -1 (off).
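If the same code path serves both single pairs and batches, a small guard can switch the adaptive options off whenever the batch size exceeds one. This is a hypothetical helper, not part of LightGlue; it only assumes, as stated above, that -1 disables width_confidence (point pruning) and depth_confidence (early stopping).

```python
from typing import Any, Dict


def batch_safe_conf(conf: Dict[str, Any], batch_size: int) -> Dict[str, Any]:
    """Return a copy of a matcher config that is safe for the given batch size.

    Hypothetical helper: point pruning (width_confidence) and early stopping
    (depth_confidence) are set to -1 (off) for batches larger than one;
    single pairs keep the adaptive settings unchanged.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1")
    out = dict(conf)  # copy, so the caller's dict is not mutated
    if batch_size > 1:
        out["width_confidence"] = -1
        out["depth_confidence"] = -1
    return out


config = {"width_confidence": 0.99, "depth_confidence": 0.95}
print(batch_safe_conf(config, batch_size=30))
print(batch_safe_conf(config, batch_size=1))
```

The result could then be unpacked into the matcher constructor in the same way as match_conf in the code below, e.g. LightGlue(features='superpoint', **batch_safe_conf(config, n)).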
When rerunning your code with n = 30 and 2048 keypoints per image, I get 0.45 s for batched inference (67 FPS, 100% GPU utilization) and 0.61 s for non-batched inference (49 FPS, 84% GPU utilization), a speedup of 36%.
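The FPS and speedup figures follow directly from the timings: 30 pairs / 0.45 s ≈ 67 pairs per second, 30 / 0.61 s ≈ 49, and 0.61 / 0.45 ≈ 1.36, i.e. about 36% faster. A tiny sketch reproducing that arithmetic (the function names are illustrative):

```python
def throughput(num_pairs: int, seconds: float) -> float:
    """Image pairs matched per second."""
    return num_pairs / seconds


def speedup_percent(baseline_s: float, faster_s: float) -> float:
    """Relative speedup of `faster_s` over `baseline_s`, in percent."""
    return (baseline_s / faster_s - 1.0) * 100.0


n = 30
batched_s, sequential_s = 0.45, 0.61  # timings reported above
print(f"batched:    {throughput(n, batched_s):.0f} FPS")      # 67 FPS
print(f"sequential: {throughput(n, sequential_s):.0f} FPS")   # 49 FPS
print(f"speedup:    {speedup_percent(sequential_s, batched_s):.0f}%")  # 36%
```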
Furthermore, if the number of keypoints per image is reduced to 512, batched inference yields a speedup of more than 500%. This is expected: with fewer keypoints a single pair uses a much smaller fraction of the GPU, so the gain from batching is higher.
So the problem is probably the configuration you used. For completeness, here is the code I used for this experiment:
from lightglue import LightGlue, SuperPoint
from lightglue.utils import load_image
from time import time
from pathlib import Path
import torch

images = Path('assets')
device = torch.device('cuda')
extractor = SuperPoint(max_num_keypoints=2048, detection_threshold=0.0).eval().to(device)  # load the extractor
match_conf = {
    'width_confidence': -1,  # for point pruning
    'depth_confidence': -1,  # for early stopping
    'flash': True,
}
matcher = LightGlue(features='superpoint', **match_conf).eval().to(device)
image0, scales0 = load_image(images / 'DSC_0411.JPG')
image1, scales1 = load_image(images / 'DSC_0410.JPG')

# batched
n = 30
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
for key in feats0.keys():
    feats0[key] = torch.cat([feats0[key]] * n)
    feats1[key] = torch.cat([feats1[key]] * n)
for i in range(10):
    last_time = time()
    with torch.inference_mode():
        pred1 = matcher({'image0': feats0, 'image1': feats1})
    current_time = time()
    print(current_time - last_time)

# non-batched
feats0 = extractor.extract(image0.to(device))
feats1 = extractor.extract(image1.to(device))
for i in range(10):
    torch.cuda.synchronize()
    last_time = time()
    preds = []
    for i in range(n):
        t = time()
        with torch.inference_mode():
            preds.append(matcher({'image0': feats0, 'image1': feats1}))
    torch.cuda.synchronize()
    current_time = time()
    print(current_time - last_time)
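CUDA kernels run asynchronously, so the clock should only be read once the GPU has finished its queued work; that is why the non-batched loop above calls torch.cuda.synchronize() before and after timing. Below is a generic sketch of that pattern using only the standard library. The helper name and signature are illustrative assumptions: the caller passes the synchronisation function (for example torch.cuda.synchronize on a GPU), and a few untimed warm-up runs are done first to exclude one-off startup costs.

```python
import time
from typing import Callable, List, Optional


def benchmark(fn: Callable[[], object], repeats: int = 10,
              sync: Optional[Callable[[], None]] = None,
              warmup: int = 1) -> List[float]:
    """Time `fn` `repeats` times and return the wall-clock durations in seconds.

    `sync` (e.g. torch.cuda.synchronize) is called before starting and before
    stopping the clock, so that queued asynchronous work is included.
    """
    for _ in range(warmup):  # untimed runs to exclude one-off startup costs
        fn()
    durations: List[float] = []
    for _ in range(repeats):
        if sync is not None:
            sync()
        start = time.perf_counter()
        fn()
        if sync is not None:
            sync()
        durations.append(time.perf_counter() - start)
    return durations


# CPU-only usage: no synchronisation needed.
times = benchmark(lambda: sum(range(100_000)), repeats=5)
print(f"mean: {sum(times) / len(times):.6f} s")
```

On a GPU one would call, for example, benchmark(lambda: matcher({'image0': feats0, 'image1': feats1}), sync=torch.cuda.synchronize) inside a torch.inference_mode() block, and divide the number of pairs by the mean duration to get pairs per second.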