
Comments (10)

songw-zju commented on May 26, 2024

Hi @drprojects, sorry for the late reply. I have successfully run your training and inference pipeline with your latest committed code and torchsparse 1.1.0. Thanks for your great work!

from deepviewagg.

drprojects commented on May 26, 2024

Hi, thanks for using this repo!

You are unfortunately running into an issue that arises on KITTI-360 and that I am still working on solving. Some types of multimodal batches seem to trigger CUDA errors. My guess is that it is related to batches with no images, which happen more often than one would think on KITTI-360 with perspective images. It is possible that setting batch_size=2 increases the chances of this error occurring. In my experiments, the error arose randomly, and some full trainings even went through without encountering it.

Have you, by any chance, identified the window and cylinder indices of the items making up the faulty batch? These can be found in mm_data.data.idx_window and mm_data.data.idx_center. If you do find a set of problematic cylinders, please let me know; this may greatly help the investigation.
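A minimal sketch of one way to capture these indices when a training step crashes. The helper names `cylinder_indices` and `run_step_with_index_log` are hypothetical, and `step` stands for whatever runs `set_input`, forward and backward in your loop; only the `data.idx_window` and `data.idx_center` attributes come from the comment above. The indices are read before the step on purpose: after a device-side assert the CUDA context is no longer usable, so tensors still on the GPU can no longer be inspected.

```python
from types import SimpleNamespace
from typing import Callable, List, Tuple


def cylinder_indices(batch) -> List[Tuple[int, int]]:
    """Return the (idx_window, idx_center) pair of every item in a multimodal batch."""
    windows = batch.data.idx_window
    centers = batch.data.idx_center
    # torch tensors expose .tolist(); plain Python sequences are copied as-is
    windows = windows.tolist() if hasattr(windows, "tolist") else list(windows)
    centers = centers.tolist() if hasattr(centers, "tolist") else list(centers)
    return list(zip(windows, centers))


def run_step_with_index_log(step: Callable[[object], None], batch) -> List[Tuple[int, int]]:
    """Run one training step; on failure, re-raise with the batch's cylinder indices attached."""
    # Read the indices while the batch is still on CPU: after a CUDA
    # device-side assert, further CUDA calls fail
    items = cylinder_indices(batch)
    try:
        step(batch)
    except RuntimeError as err:
        raise RuntimeError(f"step failed for (idx_window, idx_center) = {items}") from err
    return items


# Usage with a stand-in object; the printed MMBatch exposes the same data attributes
demo = SimpleNamespace(data=SimpleNamespace(idx_window=[0, 0], idx_center=[76930, 63780]))
print(cylinder_indices(demo))  # [(0, 76930), (0, 63780)]
```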

As a temporary workaround to this problem, I would simply resume training. To do so, you need to do the following:

  • add the path to your first training's checkpoint file to the model config conf/models/segmentation/multimodal/sparseconv3d.yaml:
```yaml
Res16UNet34-PointPyramid-early-cityscapes-interpolate:
    path_pretrained: /path/to/your/project/torch-points3d/outputs/yyyy-mm-dd/hh-mm-ss/Res16UNet34-PointPyramid-early-cityscapes-interpolate.pt
    class: sparseconv3d.APIModel
    ...
```
  • change your number of epochs to account for the epochs already completed in your previous training
  • change your learning rate scheduler in conf/lr_scheduler/multi_step_kitti360 to account for the epoch offset. You can either modify that file or create a new lr_scheduler config, in which case you must adapt train_kitti360.sh accordingly.
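The epoch-offset arithmetic for the last two steps can be sketched as below, assuming the scheduler follows the same rule as PyTorch's `torch.optim.lr_scheduler.MultiStepLR` (learning rate multiplied by `gamma` at each milestone). The function names are hypothetical, and the sketch does not know the actual keys of `conf/lr_scheduler/multi_step_kitti360`; it only computes the values to put there.

```python
from bisect import bisect_right
from typing import List, Tuple


def lr_at(base_lr: float, milestones: List[int], gamma: float, epoch: int) -> float:
    """Closed-form multi-step schedule: decay by gamma once per milestone <= epoch."""
    return base_lr * gamma ** bisect_right(sorted(milestones), epoch)


def offset_multistep(base_lr: float, milestones: List[int], gamma: float,
                     epochs_done: int) -> Tuple[float, List[int]]:
    """Rewrite a multi-step schedule so a resumed run starts at original epoch `epochs_done`."""
    ms = sorted(milestones)
    # Milestones already reached by the first run are folded into the new starting LR
    passed = bisect_right(ms, epochs_done)
    new_base_lr = base_lr * gamma ** passed
    # Remaining milestones are shifted so they count from the resumed run's epoch 0
    new_milestones = [m - epochs_done for m in ms if m > epochs_done]
    return new_base_lr, new_milestones


# Example: milestones at epochs 20 and 40, first run stopped after 25 epochs
new_lr, new_ms = offset_multistep(0.1, [20, 40], 0.1, 25)
print(new_ms)  # [15]
```

The number of epochs of the resumed run is then the original total minus `epochs_done`.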

Important

  • don't forget to remove path_pretrained from your conf/models/segmentation/multimodal/sparseconv3d.yaml afterwards, otherwise the model will always be instantiated with these partially-trained weights in the future
  • please git pull the project again: I just fixed an error in train_kitti360.sh, which did not point to any lr_scheduler.
  • the resume procedure I just gave you is a bit dirty. torch_points3d has a much better way of doing it, but it currently has issues reloading the optimizer for multimodal models, so I did not suggest it


songw-zju commented on May 26, 2024

Thanks for your detailed advice. I get the error on the first training iteration, with mm_data.data.idx_center=[76930, 63780] and mm_data.data.idx_window=[0, 0]. To make the code work, I modified lines 194-198 of torch_points3d/modules/multimodal/modules.py:

```python
in_coords = x_3d.cmaps[stride_in]
in_coords[:, :3] = ((in_coords[:, :3].float() / stride_out[0]).floor() * stride_out[0]).int()
out_coords = x_3d.cmaps[stride_out]
idx = sphashquery(sphash(in_coords), sphash(out_coords))
```

This causes idx to contain -1 values.
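To see where such -1 values come from, here is a pure-Python illustration of what the lookup in the snippet above computes: each input coordinate is floored onto the coarser grid, then looked up in a table of output coordinates, with -1 returned for misses. This is not torchsparse's implementation: the dictionary stands in for its hash table, the `[x, y, z, batch]` column layout is an assumption consistent with the `[:, :3]` slicing, and the function names are hypothetical.

```python
from typing import List, Sequence, Tuple


def floor_to_stride(coord: Sequence[int], stride: int, n_spatial: int = 3) -> Tuple[int, ...]:
    """Snap the spatial part of a voxel coordinate onto a grid with the given stride."""
    # Python's // floors toward -inf, like .floor() in the snippet above
    spatial = tuple((c // stride) * stride for c in coord[:n_spatial])
    return spatial + tuple(coord[n_spatial:])  # batch index left untouched


def hash_query(queries: Sequence[Sequence[int]], references: Sequence[Sequence[int]]) -> List[int]:
    """For each query coordinate, return its row in `references`, or -1 if absent."""
    table = {tuple(r): i for i, r in enumerate(references)}
    return [table.get(tuple(q), -1) for q in queries]


in_coords = [(0, 0, 0, 0), (1, 1, 0, 0), (2, 0, 0, 0), (5, 5, 5, 1)]
out_coords = [(0, 0, 0, 0), (2, 0, 0, 0)]  # (4, 4, 4, 1) is missing from the coarse map
idx = hash_query([floor_to_stride(c, 2) for c in in_coords], out_coords)
print(idx)  # [0, 0, 1, -1]
```

Any input voxel whose snapped coordinate has no match in the coarse map yields -1, which is not a valid index for the subsequent point selection.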


drprojects commented on May 26, 2024

Thanks for proposing this solution. Did this fix your problem? From what I understand, you found that idx contained -1 values, and changing x_3d.coord_maps to x_3d.cmaps solved this problem? I have not seen this before.

I have torchsparse==1.1.0 on my machine, but I assume you have torchsparse==1.4.0. In my version, SparseTensor.cmaps does not exist. This is my mistake: install.sh installs the newer version, but I did not check torchsparse backward compatibility.
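If the snippet has to run under both versions in the meantime, a small accessor can pick whichever attribute exists. This is a hypothetical helper based only on the attribute names mentioned in this thread (`cmaps` in torchsparse 1.4.0, `coord_maps` in 1.1.0); it only papers over the rename and does not address any other behavioral difference between the two versions.

```python
from types import SimpleNamespace


def coordinate_maps(sparse_tensor):
    """Return the stride-indexed coordinate maps of a torchsparse SparseTensor."""
    # Attribute names as reported in this thread: `cmaps` (1.4.0), `coord_maps` (1.1.0)
    if hasattr(sparse_tensor, "cmaps"):
        return sparse_tensor.cmaps
    if hasattr(sparse_tensor, "coord_maps"):
        return sparse_tensor.coord_maps
    raise AttributeError("SparseTensor exposes neither `cmaps` nor `coord_maps`")


# Stand-ins for the two versions
old_style = SimpleNamespace(coord_maps={1: "coords at stride 1"})
new_style = SimpleNamespace(cmaps={1: "coords at stride 1"})
print(coordinate_maps(old_style) == coordinate_maps(new_style))  # True
```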

I am not sure whether this was a backward-compatibility error or the no-image error I mentioned above. Please let me know if you manage to complete a full KITTI-360 training without any error and obtain satisfying performance. If so, I will integrate your change into the code. In the meantime, I will modify install.sh to explicitly set up v1.1.0 instead.


songw-zju commented on May 26, 2024

My modification was mainly due to a mismatch between torchsparse v1.4 and v1.1. After changing modules.py, I can run the code, but it produces the error above. I'll test your code with torchsparse v1.1 as soon as possible.


drprojects commented on May 26, 2024

I see. I am a bit concerned that the change of torchsparse version might have deeper implications than this. I have taken note of your suggestion for when I work on integrating 1.4.0; hopefully I will have time to look into that in the near future.

In any case, I don't think changing to torchsparse==1.1.0 will change much regarding your initial issue. I will try to look into this CUDA error soon.


drprojects commented on May 26, 2024

Hi songw-zju, just to let you know I have not forgotten about you! I will start working on this issue tomorrow and will get back to you when I find the cause.


drprojects commented on May 26, 2024

Hi songw-zju, I have been trying to reproduce this issue using the problematic cylinders you pointed out, but they do not cause any error on my end. Can you please check whether the following works for you?

```python
# Select your GPU
I_GPU = 0

# Uncomment to use autoreload
# %load_ext autoreload
# %autoreload 2

import os
import os.path as osp
import sys
import torch
import numpy as np
from time import time
from omegaconf import OmegaConf
start = time()
import warnings
warnings.filterwarnings('ignore')

torch.cuda.set_device(I_GPU)
DIR = os.path.dirname(os.getcwd())
ROOT = os.path.join(DIR, "..")
sys.path.insert(0, ROOT)
sys.path.insert(0, DIR)

from torch_points3d.utils.config import hydra_read
from torch_geometric.data import Data
from torch_points3d.core.multimodal.data import MMData, MMBatch
from torch_points3d.core.multimodal.image import SameSettingImageData, ImageData
from torch_points3d.datasets.segmentation.multimodal.kitti360 import KITTI360DatasetMM
from torch_points3d.models.model_factory import instantiate_model

# Set your dataset root directory, where the data was/will be downloaded
DATA_ROOT = '/path/to/your/dataset/root/directory'
mini = False                                                          # set to True to only load and play with a small portion of the KITTI-360 dataset
train_is_trainval = False                                             # set to True if you want the Train set to be Train+Val
sample_per_epoch = 12000                                              # number of cylinders sampled in the Train set. Corrects class imbalance. Set to 0 for regularly-sampled cylinders

dataset_config = 'segmentation/multimodal/kitti360-sparse'
models_config = 'segmentation/multimodal/sparseconv3d'                # model family
model_name = 'Res16UNet34-PointPyramid-early-cityscapes-interpolate'  # specific model

overrides = [
    'task=segmentation',
    f'data={dataset_config}',
    f'data.mini={mini}',
    f'models={models_config}',
    f'model_name={model_name}',
    f'data.dataroot={DATA_ROOT}',
    f'+train_is_trainval={train_is_trainval}',
    f'data.sample_per_epoch={sample_per_epoch}',
]

cfg = hydra_read(overrides)
# print(OmegaConf.to_yaml(cfg))

# Dataset instantiation
start = time()
dataset = KITTI360DatasetMM(cfg.data)
# print(dataset)
print(f"Time = {time() - start:0.1f} sec.")

# Create the model
print(f"Creating model: {cfg.model_name}")
model = instantiate_model(cfg, dataset)
# print(model)

# Prepare the model for inference
model = model.eval().cuda()
print('Model loaded')

# Specify the (idx_window, idx_center) of your problematic samples
problematic_cylinders = [(0, 76930), (0, 63780)]

# Convert (idx_window, idx_center) into global indices
cum_sizes = torch.cat((torch.LongTensor([0]), dataset.train_dataset.sampling_sizes.cumsum(0)))
problematic_cylinders_global = [cum_sizes[idx_window].item() + idx_center for idx_window, idx_center in problematic_cylinders]

# Create a batch with the samples
batch = MMBatch.from_mm_data_list([dataset.train_dataset[idx_global] for idx_global in problematic_cylinders_global])
print(batch)

# # Optionally, to see if the number of points matters
# # Only keep points within k * voxel size of the first point, to see if
# # single voxel subsampling matters
# k = 4
# pos = batch.data.pos
# # Squared distances are compared with the squared radius
# closeby = torch.where(((pos[0] - pos)**2).sum(dim=1) < (k * batch.data.grid_size.max())**2)[0]
# batch = batch[closeby]
# print(batch)

# # Optionally, to see if the number of views matters
# # Remove all images
# batch.modalities['image'] = batch.modalities['image'][None]
# print(batch)

# Run forward and backward
model = model.train().cuda()
model.set_input(batch, model.device)
model(batch)
model.backward()
```

This assumes that you used exactly the same dataset configuration as the one provided in the repo. If you changed resolution_3d or train_sample_res, your window and center indices will differ.
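The index conversion in the script is a plain offset by the cumulative number of cylinders in the preceding windows. Below is a pure-Python sketch of it together with the inverse mapping, which can help translate a failing global index back into (idx_window, idx_center). `sizes` is a hypothetical list standing in for `dataset.train_dataset.sampling_sizes`, and its values are made up.

```python
from bisect import bisect_right
from itertools import accumulate
from typing import List, Tuple


def to_global(sampling_sizes: List[int], idx_window: int, idx_center: int) -> int:
    """Global sample index = cylinders in all previous windows + index within the window."""
    offsets = [0] + list(accumulate(sampling_sizes))
    return offsets[idx_window] + idx_center


def to_local(sampling_sizes: List[int], idx_global: int) -> Tuple[int, int]:
    """Inverse of to_global: find the (last non-empty) window whose range contains idx_global."""
    offsets = [0] + list(accumulate(sampling_sizes))
    idx_window = bisect_right(offsets, idx_global) - 1
    return idx_window, idx_global - offsets[idx_window]


sizes = [100, 50, 70]  # made-up per-window cylinder counts
print(to_global(sizes, 2, 3))  # 153
print(to_local(sizes, 153))    # (2, 3)
```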


songw-zju commented on May 26, 2024

Hi @drprojects, sorry for the late reply. I tested the code you provided and got the following error:

```
Time = 14.7 sec.
Creating model: Res16UNet34-PointPyramid-early-cityscapes-interpolate
Model loaded
MMBatch(
    data = DataBatch(mapping_index=[75395], planarity=[75395], y=[75395], origin_id=[75395], linearity=[75395], num_raw_points=[2], norm=[75395, 3], scattering=[75395], pos=[75395, 3], grid_size=[2], idx_window=[2], idx_center=[2], x=[75395, 1], coords=[75395, 3], batch=[75395], ptr=[3])
    image = ImageBatch(num_settings=4, num_views=13, num_points=75395, device=cpu)
)
None
Traceback (most recent call last):
  File "/home/ws/project/lcfusion/DeepViewAgg/test.py", line 100, in <module>
    model(batch)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ws/project/lcfusion/DeepViewAgg/torch_points3d/models/segmentation/sparseconv3d.py", line 43, in forward
    features = self.backbone(self.input).x
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ws/project/lcfusion/DeepViewAgg/torch_points3d/applications/sparseconv3d.py", line 228, in forward
    data = self.down_modules[i](data)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ws/project/lcfusion/DeepViewAgg/torch_points3d/modules/multimodal/modules.py", line 85, in forward
    mm_data_dict, self.block_1)
  File "/home/ws/project/lcfusion/DeepViewAgg/torch_points3d/modules/multimodal/modules.py", line 221, in forward_3d_block_down
    mm_data_dict['modalities'][m].select_points(idx, mode=mode)
  File "/home/ws/project/lcfusion/DeepViewAgg/torch_points3d/core/multimodal/image.py", line 1493, in select_points
    print(idx)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/_tensor.py", line 249, in __repr__
    return torch._tensor_str._str(self)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/_tensor_str.py", line 415, in _str
    return _str_intern(self)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/_tensor_str.py", line 390, in _str_intern
    tensor_str = _tensor_str(self, indent)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/_tensor_str.py", line 251, in _tensor_str
    formatter = _Formatter(get_summarized_data(self) if summarize else self)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/_tensor_str.py", line 86, in __init__
    value_str = '{}'.format(value)
  File "/home/public/ws/semantic_kittenv/anaconda3/envs/segcontrast/lib/python3.7/site-packages/torch/_tensor.py", line 571, in __format__
    return self.item().__format__(format_spec)
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
```
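As the message says, the reported location may be misleading: here the failure surfaces in `print(idx)`, which calls `.item()`, simply the first call that waits for the GPU. Rerunning with synchronous kernel launches makes the traceback stop at the kernel that actually failed. A minimal way to do this from the script itself, assuming the variable is set before CUDA is initialized (the safest place is the very top of the script, before importing torch):

```python
import os

# Must be set before CUDA is initialized; put this above `import torch`
# and the rest of the script
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
```

Equivalently, prefix the command with `CUDA_LAUNCH_BLOCKING=1` when launching the script. Expect the run to be noticeably slower, so this is for debugging only.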


drprojects commented on May 26, 2024

Hi songw-zju, thanks for your feedback.

I have been investigating this issue, and I think I have found the source of the error. The assertion error occurs in select_points in DeepViewAgg/torch_points3d/core/multimodal/image.py. It is triggered by a security check I put there: assert (torch.arange(idx.max() + 1, device=self.device) == torch.unique(idx)).all(). This check ensures that no voxel is forgotten when aggregating mappings (merge mode) after voxel subsampling.
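The intent of that check can be written in pure Python as below (a hypothetical helper, not the code in image.py): after merging, the target indices must be exactly 0, 1, ..., max with none skipped. A -1 entry can never satisfy this, since it lies outside that range.

```python
from typing import Sequence


def covers_all_voxels(idx: Sequence[int]) -> bool:
    """True if the non-empty idx uses every integer in 0..max(idx) and nothing else."""
    # Same idea as comparing arange(idx.max() + 1) with unique(idx)
    return sorted(set(idx)) == list(range(max(idx) + 1))


print(covers_all_voxels([0, 0, 1, 2, 1]))  # True
print(covers_all_voxels([0, 2, 2]))        # False: voxel 1 is never assigned
print(covers_all_voxels([0, 0, 1, -1]))    # False: -1 is not a valid voxel index
```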

However, as you noticed, this assertion will crash if idx contains -1 values. This should normally never happen for idx = sphashquery(sphash(in_coords), sphash(out_coords)) in DeepViewAgg/torch_points3d/modules/multimodal/modules.py. However, I found that, on rare occasions, sphashquery produces -1 values here when running on GPU, while the CPU computation works fine. This is a very rare event on my end, and I have not been able to identify the exact reason for this behavior. So I set up a workaround to make sure idx is always non-negative, even if it means breaking GPU-CPU asynchronicity.
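Such a workaround could look roughly like the following. This is a hypothetical sketch, not the code from the commit: `fast_query` and `safe_query` stand for the GPU and CPU versions of the lookup, and testing the result for negative values is the point where a GPU implementation has to synchronize with the host, which is the asynchronicity cost mentioned above.

```python
from typing import Callable, List, Sequence

Query = Callable[[Sequence, Sequence], List[int]]


def query_non_negative(fast_query: Query, safe_query: Query,
                       queries: Sequence, references: Sequence) -> List[int]:
    """Run the fast lookup; fall back to the safe one if it returned -1 anywhere."""
    idx = fast_query(queries, references)
    # With GPU tensors, this test forces a device-to-host synchronization
    if any(i < 0 for i in idx):
        idx = safe_query(queries, references)
    # If even the safe lookup misses, the coordinates really are inconsistent
    if any(i < 0 for i in idx):
        raise ValueError("some query coordinates have no match in the reference map")
    return idx


def exact(q, r):
    table = {c: i for i, c in enumerate(r)}
    return [table.get(c, -1) for c in q]


def flaky(q, r):
    # Simulates the rare GPU miss: the last result is wrongly reported as -1
    return exact(q, r)[:-1] + [-1]


print(query_non_negative(flaky, exact, [(0, 0), (2, 0)], [(0, 0), (2, 0)]))  # [0, 1]
```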

The latest commit integrates these changes, and it works on my end.

This is what I suggest you do:

  • pull the latest commit
  • install torchsparse 1.1.0. I cannot guarantee 1.4.0 support for now, because changing it has larger implications for the project that I cannot work on right now (see torch-points3d/torch-points3d#635 for example)

Please let me know if this solves the issue on your end.

