Comments (3)
- This project can't convert PointPillars directly; the following modifications are needed:
- Change the RPN inputs and outputs from dicts to tuples (torch.jit can't handle dicts for now).
- modify PFN:
class PFNLayerTensorRT(nn.Module):
    """Pillar Feature Net layer laid out for TensorRT export.

    Features are kept as a 4D tensor
    ``[1, C, max_num_points, max_points_per_voxel]``. The per-point linear
    projection is therefore a 1x1 ``Conv2d``, followed by a norm layer, ReLU
    and a max over the last (points-per-voxel) axis.

    Args:
        in_channels: number of input feature channels.
        out_channels: number of output channels. For a non-last layer, half of
            them come from the per-point features and half from the
            per-pillar max that is broadcast back and concatenated.
        use_norm: if True, use a ``BatchNorm2d`` built via
            ``change_default_args(eps=1e-3, momentum=0.01)`` and a bias-free
            conv; otherwise use ``Empty`` as the norm and a conv with bias.
        last_layer: must be True; only the last-layer variant is supported
            for TensorRT (enforced by an assert).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 use_norm=True,
                 last_layer=False):
        super().__init__()
        self.name = 'PFNLayerTensorRT'
        self.last_vfe = last_layer
        # The non-last variant needs ``repeat``, which the TensorRT path does
        # not support. NOTE: this is an assert, so it is stripped under -O.
        assert self.last_vfe is True, "tensor rt don't support this."
        if not self.last_vfe:
            # Reserve half the output channels for the concatenated max.
            out_channels = out_channels // 2
        self.units = out_channels
        if use_norm:
            BatchNorm2d = change_default_args(
                eps=1e-3, momentum=0.01)(nn.BatchNorm2d)
            Conv2d = change_default_args(bias=False)(nn.Conv2d)
        else:
            BatchNorm2d = Empty
            Conv2d = change_default_args(bias=True)(nn.Conv2d)
        # A 1x1 conv acts as a per-point linear layer in this 4D layout.
        self.linear = Conv2d(in_channels, self.units, 1)
        self.norm = BatchNorm2d(self.units)

    def forward(self, inputs):
        """Run conv -> norm -> ReLU -> max over points in each pillar.

        Args:
            inputs: ``[1, in_channels, max_num_points, max_points_per_voxel]``
                (e.g. ``[1, 9, 12000, 100]`` for the first layer).

        Returns:
            If last layer: ``[1, units, max_num_points, 1]``, the per-pillar max.
            Otherwise: ``[1, 2 * units, max_num_points, max_points_per_voxel]``,
            i.e. the per-point features concatenated on the channel axis with
            the per-pillar max broadcast over every point slot.
        """
        x = self.linear(inputs)
        x = self.norm(x)
        x = F.relu(x)
        # x: [1, units, max_num_points, max_points_per_voxel]
        # max operation very slow in tensorrt; consider using maxpool instead:
        # x_max = F.max_pool2d(x, [1, x.shape[-1]], [1, x.shape[-1]])
        x_max = torch.max(x, dim=3, keepdim=True)[0]
        if self.last_vfe:
            return x_max
        # may need to use conv to implement repeat in tensorrt.
        # Broadcast the pillar max over the points-per-voxel axis (dim 3) and
        # append it along the channel axis (dim 1). The 4D repeat needs all
        # four repeat factors.
        x_repeat = x_max.repeat(1, 1, 1, inputs.shape[3])
        return torch.cat([x, x_repeat], dim=1)
@register_vfe
class PillarFeatureNetTensorRT(nn.Module):
    """Pillar Feature Net variant whose forward takes fixed 4D float inputs.

    Each point's raw features are decorated with five extra channels:
    * its x/y/z offset from the mean of the points in its pillar (3 channels);
    * its x and y offsets from the pillar center (2 channels).
    Entries are then zeroed with ``voxel_point_mask``, and the result runs
    through a single ``PFNLayerTensorRT``.

    Args:
        num_input_features: channels per raw point; 5 decoration channels
            are added on top.
        use_norm: forwarded to ``PFNLayerTensorRT``.
        num_filters: output channels of the PFN layers; exactly one entry
            is supported (enforced by an assert).
        with_distance: must be False (enforced by an assert).
        voxel_size: pillar size; only ``voxel_size[0]`` and ``voxel_size[1]``
            are used.
        pc_range: point-cloud range; only ``pc_range[0]`` and ``pc_range[1]``
            (presumably the x/y minimum) are used.
    """

    def __init__(self,
                 num_input_features=4,
                 use_norm=True,
                 num_filters=(64, ),
                 with_distance=False,
                 voxel_size=(0.2, 0.2, 4),
                 pc_range=(0, -40, -3, 70.4, 40, 1)):
        super().__init__()
        self.name = 'PillarFeatureNetTensorRT'
        assert len(num_filters) > 0
        # +3 offsets from the pillar mean, +2 offsets from the pillar center.
        num_input_features += 5
        if with_distance:
            num_input_features += 1
        self._with_distance = with_distance
        assert with_distance is False

        # Create the PFN layers.
        num_filters = [num_input_features] + list(num_filters)
        # Only one PFN layer is allowed: stacking would need the non-last
        # PFN variant, which relies on repeat.
        assert len(num_filters) == 2, "tensorrt don't support repeat"
        num_layers = len(num_filters) - 1
        self.pfn_layers = nn.ModuleList([
            PFNLayerTensorRT(in_filters, out_filters, use_norm,
                             last_layer=(i == num_layers - 1))
            for i, (in_filters, out_filters) in enumerate(
                zip(num_filters[:-1], num_filters[1:]))
        ])

        # Need pillar (voxel) size and x/y offset in order to calculate
        # pillar offset.
        self.vx = voxel_size[0]
        self.vy = voxel_size[1]
        self.x_offset = self.vx / 2 + pc_range[0]
        self.y_offset = self.vy / 2 + pc_range[1]

    def forward(self, features, num_voxels, coors, voxel_point_mask):
        """TensorRT PFE must use specified inputs. All tensorrt inputs must be float for now.

        Args:
            features: [1, num_point_features, max_num_points, max_points_per_voxel]
            num_voxels: [1, 1, max_num_points, 1]
            coors: [1, 4, max_num_points, 1]; channel 3 is used as the x
                index and channel 2 as the y index.
            voxel_point_mask: [1, 1, max_num_points, max_points_per_voxel]

        Inputs may be converted from other layouts beforehand, e.g.:
            features = features.permute(2, 0, 1).contiguous().unsqueeze(0)
            num_voxels = num_voxels.type_as(features).view(1, 1, -1, 1)
            coors = coors.type_as(features).t().contiguous().view(1, 4, -1, 1)
            voxel_point_mask = voxel_point_mask.squeeze().unsqueeze(0).unsqueeze(0)

        Returns:
            Output of the last PFN layer, e.g. [1, 64, 12000, 1].
        """
        # Find distance of x, y, and z from the mean of the points in the
        # pillar. The sum runs over every point slot, so padded slots are
        # assumed to be zero. NOTE(review): a pillar with num_voxels == 0
        # (e.g. padding) yields 0/0 = NaN here, and the mask multiply below
        # does not clear NaN -- confirm padded entries have num_voxels >= 1.
        points_mean = features[:, :3].sum(
            dim=3, keepdim=True) / num_voxels.view(1, 1, -1, 1)
        f_cluster = features[:, :3] - points_mean

        # Find distance of x and y from the pillar center.
        f_center_x = features[:, 0:1] - (coors[:, 3:4] * self.vx + self.x_offset)
        f_center_y = features[:, 1:2] - (coors[:, 2:3] * self.vy + self.y_offset)

        # Combine together feature decorations along the channel axis.
        features = torch.cat(
            [features, f_cluster, f_center_x, f_center_y], dim=1)
        # The decorations were computed without regard to whether a pillar
        # was empty; the mask zeroes those entries again.
        features *= voxel_point_mask

        # Forward pass through PFN layers.
        for pfn in self.pfn_layers:
            features = pfn(features)
        return features
- Write a custom TensorRT plugin for PointPillarsScatter.
- I can't provide a Docker image because the upload speed of my network is very slow and unstable.
from torch2trt.
@traveller59 Thanks for your reply, I will try it.
Best wishes.
from torch2trt.
@traveller59 Thanks for the PFN implementation. I've got the program running, but when I extract the PFN, it looks like this:
middle_class_name PointPillarsScatter
num_trainable parameters: 69
PillarFeatureNet(
(pfn_layers): ModuleList(
(0): PFNLayer(
(conv1): Conv2d(9, 64, kernel_size=(1, 1), stride=(1, 1))
(norm): DefaultArgLayer(64, eps=0.001, momentum=0.01, affine=True, track_running_stats=True)
(conv2): Conv2d(64, 64, kernel_size=(1, 34), stride=(1, 1), dilation=(1, 3))
)
)
)
This looks the same as before: still an empty model.
Could you tell me where I went wrong?
Many thanks
from torch2trt.
Related Issues (9)
Recommend Projects
-
React
A declarative, efficient, and flexible JavaScript library for building user interfaces.
-
Vue.js
🖖 Vue.js is a progressive, incrementally-adoptable JavaScript framework for building UI on the web.
-
Typescript
TypeScript is a superset of JavaScript that compiles to clean JavaScript output.
-
TensorFlow
An Open Source Machine Learning Framework for Everyone
-
Django
The Web framework for perfectionists with deadlines.
-
Laravel
A PHP framework for web artisans
-
D3
Bring data to life with SVG, Canvas and HTML. 📊📈🎉
-
Recommend Topics
-
javascript
JavaScript (JS) is a lightweight interpreted programming language with first-class functions.
-
web
Something interesting about the web. A new door to the world.
-
server
A server is a program made to process requests and deliver data to clients.
-
Machine learning
Machine learning is a way of modeling and interpreting data that allows a piece of software to respond intelligently.
-
Visualization
Something interesting about visualization: using data as art.
-
Game
Something interesting about games, to make everyone happy.
Recommend Org
-
Facebook
We are working to build community through open source technology. NB: members must have two-factor auth.
-
Microsoft
Open source projects and samples from Microsoft.
-
Google
Google ❤️ Open Source for everyone.
-
Alibaba
Alibaba Open Source for everyone
-
D3
Data-Driven Documents codes.
-
Tencent
China tencent open source team.
from torch2trt.