
Deep Multi-Branch Aggregation Network for Real-Time Semantic Segmentation in Street Scenes

PyTorch Lightning | Config: Hydra | Template | Paper

[Figure: DMA-Net architecture]

This is an implementation of DMA-Net in PyTorch. I built it to explore PyTorch Lightning and Hydra and to improve my programming skills. DMA-Net is a real-time semantic segmentation network for street scenes, aimed at self-driving cars.

Added Features

  1. D-Adaptation Optimizers: learning-rate-free training for SGD, AdaGrad and Adam, provided by facebookresearch/dadaptation. Enable it with:

    model.auto_lr=True model.lr=1.0
    
  2. Hyperparameter Search: since the original authors' results are hard to reproduce, I added two variables, high_level_features and low_level_features, that set the feature sizes in the model.

    • high_level_features: the input size of the CBR block (upmid_cbr) that follows the addition of sub-network 3 and sub-network 4 in the upscaling pipeline.

    • low_level_features: the input size of the CBR block (uplow_cbr) that follows the addition of sub-network 2 and upmid_cbr in the upscaling pipeline.

    model.net.low_level_features=128 model.net.high_level_features=128
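The `model.auto_lr` switch in item 1 swaps a regular optimizer for its D-Adaptation counterpart. The sketch below shows one way such a switch could resolve the optimizer; the function `resolve_optimizer` and its mapping are hypothetical illustrations, not this repository's code. Only the class names DAdaptSGD, DAdaptAdaGrad and DAdaptAdam come from the facebookresearch/dadaptation package, whose documentation recommends lr=1.0 because the value scales the adapted step size rather than setting it directly.

```python
# Hypothetical sketch: map a base optimizer name to its D-Adaptation
# counterpart when auto_lr is enabled. The class names are those exported
# by the `dadaptation` package; the mapping function is illustrative only.

DADAPT_EQUIVALENTS = {
    "sgd": "DAdaptSGD",
    "adagrad": "DAdaptAdaGrad",
    "adam": "DAdaptAdam",
}


def resolve_optimizer(name: str, auto_lr: bool, lr: float) -> tuple[str, float]:
    """Return (optimizer class name, learning rate) for a config."""
    if not auto_lr:
        # Plain optimizer: keep the configured name and learning rate.
        return name, lr
    key = name.lower()
    if key not in DADAPT_EQUIVALENTS:
        raise ValueError(f"No D-Adaptation variant for optimizer {name!r}")
    if lr != 1.0:
        # D-Adaptation estimates the step size itself; lr multiplies that
        # estimate, so 1.0 is the usual setting.
        print(f"warning: lr={lr} rescales the D-Adaptation estimate")
    return DADAPT_EQUIVALENTS[key], lr


if __name__ == "__main__":
    print(resolve_optimizer("Adam", auto_lr=True, lr=1.0))   # ('DAdaptAdam', 1.0)
    print(resolve_optimizer("SGD", auto_lr=False, lr=0.01))  # ('SGD', 0.01)
```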
    
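For the hyperparameter search in item 2, the two feature sizes can be swept as a grid. The sketch below generates one override string per combination and, alternatively, a single command using Hydra's multirun flag `-m` with comma-separated values. The value grid and the helper names `sweep_overrides` and `multirun_command` are assumptions for illustration, not recommendations from the original authors.

```python
# Sketch: build Hydra overrides for a grid over the two decoder feature sizes.
from itertools import product

LOW_LEVEL = [64, 128]    # assumed candidate values
HIGH_LEVEL = [128, 256]  # assumed candidate values


def sweep_overrides(low_values, high_values):
    """Yield one override string per (low, high) combination."""
    for low, high in product(low_values, high_values):
        yield (
            f"model.net.low_level_features={low} "
            f"model.net.high_level_features={high}"
        )


def multirun_command(low_values, high_values):
    """Single Hydra multirun command that sweeps the same grid."""
    low = ",".join(map(str, low_values))
    high = ",".join(map(str, high_values))
    return (
        "python src/train.py -m "
        f"model.net.low_level_features={low} "
        f"model.net.high_level_features={high}"
    )


if __name__ == "__main__":
    for overrides in sweep_overrides(LOW_LEVEL, HIGH_LEVEL):
        print(f"python src/train.py {overrides}")
    print(multirun_command(LOW_LEVEL, HIGH_LEVEL))
```

The multirun form lets Hydra launch and track all four runs itself, while the per-run strings are handy when each run is submitted as a separate job.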

How to run

Install dependencies

# clone project
git clone https://github.com/haritsahm/pytorch-DMANet.git
cd pytorch-DMANet

# [OPTIONAL] create conda environment
conda create -n myenv python=3.10
conda activate myenv

# install pytorch according to instructions
# https://pytorch.org/get-started/

# install requirements
pip install -r requirements.txt

Prepare dataset

Run and follow the notebook to prepare and visualize the dataset with FiftyOne (FiftyOne Sample).

Train Commands

1. Train with default configurations

# train on CPU
python src/train.py trainer=cpu paths.data_dir=data/cityscape_fo_image_segmentation

# train on GPU
python src/train.py trainer=gpu paths.data_dir=data/cityscape_fo_image_segmentation

# train with DDP (Distributed Data Parallel) (4 GPUs)
python src/train.py trainer=ddp trainer.devices=4 paths.data_dir=data/cityscape_fo_image_segmentation

2. Train the model with an experiment configuration chosen from configs/experiment/

# train on the Cityscapes dataset
python src/train.py experiment=dmanet_cityscape paths.data_dir=data/cityscape_fo_image_segmentation

# train on the CamVid dataset
python src/train.py experiment=dmanet_camvid

3. Override any parameter

python src/train.py experiment=dmanet_cityscape paths.data_dir=data/cityscape_fo_image_segmentation trainer.max_epochs=20 datamodule.batch_size=64 model.net.low_level_features=128 model.net.high_level_features=256
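The default training commands in step 1 differ only in the trainer preset and the device count. A hypothetical helper (`build_train_command` is not part of the repository) that picks the preset from the number of available GPUs could look like this; the returned argument list can be passed to `subprocess.run`.

```python
# Hypothetical helper mirroring the cpu / gpu / ddp commands in step 1.
import shlex


def build_train_command(num_gpus: int, data_dir: str) -> list[str]:
    """Choose trainer=cpu, trainer=gpu or trainer=ddp from the GPU count."""
    if num_gpus < 0:
        raise ValueError("num_gpus must be non-negative")
    cmd = ["python", "src/train.py"]
    if num_gpus == 0:
        cmd.append("trainer=cpu")
    elif num_gpus == 1:
        cmd.append("trainer=gpu")
    else:
        # Distributed Data Parallel across all requested GPUs.
        cmd += ["trainer=ddp", f"trainer.devices={num_gpus}"]
    cmd.append(f"paths.data_dir={data_dir}")
    return cmd


if __name__ == "__main__":
    cmd = build_train_command(4, "data/cityscape_fo_image_segmentation")
    print(shlex.join(cmd))
    # python src/train.py trainer=ddp trainer.devices=4 paths.data_dir=data/cityscape_fo_image_segmentation
```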

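Each override in step 3 is a dotted key path into the composed config. The toy parser below illustrates only that key-path idea; it is not Hydra's override grammar, which also handles config-group selection (`experiment=` and `trainer=` pick config files rather than set plain values), `+`/`~` prefixes, sweeps and quoting. The names `parse_value` and `apply_overrides` are hypothetical.

```python
# Simplified illustration of dotted overrides; NOT Hydra's real parser.
import ast
from typing import Any


def parse_value(text: str) -> Any:
    """Interpret '20' as int, '1e-3' as float, etc.; fall back to str."""
    try:
        return ast.literal_eval(text)
    except (ValueError, SyntaxError, TypeError):
        return text  # e.g. 'dmanet_cityscape' or a data path


def apply_overrides(config: dict, overrides: list[str]) -> dict:
    """Set each dotted `key=value` override in a nested dict (in place)."""
    for item in overrides:
        key, sep, raw = item.partition("=")
        if not sep or not key:
            raise ValueError(f"Override must look like key=value, got {item!r}")
        *parents, leaf = key.split(".")
        node = config
        for part in parents:
            # Create intermediate levels such as 'model' -> 'net' on demand.
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw)
    return config


if __name__ == "__main__":
    cli = ("experiment=dmanet_cityscape trainer.max_epochs=20 "
           "datamodule.batch_size=64 model.net.low_level_features=128 "
           "model.net.high_level_features=256")
    cfg = apply_overrides({}, cli.split())
    print(cfg["model"]["net"])
    # {'low_level_features': 128, 'high_level_features': 256}
```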
Read the full documentation on how to use pytorch-lightning + hydra

TODO:

  • Train the model on cloud instances
  • Validate and compare model metrics (Cityscapes and CamVid)

pytorch-dmanet's People

Contributors: haritsahm

pytorch-dmanet's Issues

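Training results at 512×1024 resolution

Hello, I trained the model at 512×1024 resolution, but its accuracy is very low, only around 40%. I'm puzzled and haven't been able to find the cause. Do you know what might be causing this?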

Error when running in google colab

!python train.py experiment=camvid trainer.max_epochs=100 datamodule.batch_size=4 logger=neptune

  File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/strategies/strategy.py", line 358, in training_step
    return self.model.training_step(*args, **kwargs)
  File "/content/pytorch-DMANet/src/models/dmanet_module.py", line 82, in training_step
    images, logits, gt_masks = self.step(batch)
  File "/content/pytorch-DMANet/src/models/dmanet_module.py", line 78, in step
    logits = self.forward(images)
  File "/content/pytorch-DMANet/src/models/dmanet_module.py", line 74, in forward
    return self._net(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/pytorch-DMANet/src/models/components/dma_net.py", line 156, in forward
    dec_masks = self._decoder(enc_features)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/pytorch-DMANet/src/models/components/dma_net.py", line 83, in forward
    gcb_features = self._gcb(c5)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/content/pytorch-DMANet/src/models/functions/layers.py", line 179, in forward
    x = self._bn(x)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/batchnorm.py", line 179, in forward
    self.eps,
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2436, in batch_norm
    _verify_batch_size(input.size())
  File "/usr/local/lib/python3.7/dist-packages/torch/nn/functional.py", line 2404, in _verify_batch_size
    raise ValueError("Expected more than 1 value per channel when training, got input size {}".format(size))
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 256, 1, 1])
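The Colab error above comes from BatchNorm's training-mode check: the input `torch.Size([1, 256, 1, 1])` gives only 1 × 1 × 1 = 1 value per channel. The traceback passes through `self._gcb(c5)`, which appears (judging only from that shape) to reduce features to a 1×1 map, so a batch holding a single image is enough to trigger the error, for example the last incomplete batch of an epoch. The pure-Python sketch below (helper names are hypothetical) checks when that happens. Passing `drop_last=True` to the training `torch.utils.data.DataLoader` is a common workaround, though this repository's datamodule may expose the option differently.

```python
# Sketch: predict when a 1x1 BatchNorm input has only one value per channel.
from math import prod


def values_per_channel(shape: tuple[int, ...]) -> int:
    """BatchNorm normalizes each channel over N * H * W values."""
    n, _channels, *spatial = shape
    return n * prod(spatial)


def last_batch_size(num_samples: int, batch_size: int, drop_last: bool = False) -> int:
    """Size of the final batch a DataLoader-style iterator would yield."""
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")
    full, remainder = divmod(num_samples, batch_size)
    if remainder and not drop_last:
        return remainder
    return batch_size if full else 0


def batchnorm_would_fail(num_samples, batch_size, spatial=(1, 1), drop_last=False):
    """True if the smallest training batch leaves one value per channel."""
    # All other batches are full-size, so the last batch is the smallest.
    last = last_batch_size(num_samples, batch_size, drop_last)
    return last > 0 and values_per_channel((last, 1, *spatial)) <= 1


if __name__ == "__main__":
    print(values_per_channel((1, 256, 1, 1)))           # 1, as in the traceback
    print(batchnorm_would_fail(9, 4))                   # True: last batch has 1 sample
    print(batchnorm_would_fail(9, 4, drop_last=True))   # False
```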
