
GaitForeMer: Self-Supervised Pre-Training of Transformers via Human Motion Forecasting for Few-Shot Gait Impairment Severity Estimation

GaitForeMer (Gait Forecasting and impairment estimation transforMer) predicts MDS-UPDRS gait impairment severity scores using motion features learned during a pre-training task of human motion forecasting. This work will be presented at MICCAI 2022 later this year!

GaitForeMer architecture figure

Requirements

  • PyTorch >= 1.7.
  • NumPy.
  • TensorBoard for PyTorch.

Pretext Task

The pre-trained weights are available for download here.

Alternatively, you can run the pre-training yourself. First download the NTU RGB+D dataset, place the nturgbd_skeletons_s001_to_s017.zip and action_labels.txt files in the data/nturgb+d_data directory, and unzip the zip file there. Then process the data by running python data/NTURGDDataset.py, which creates all the files needed for pre-training. Once these steps are complete, run the following command to start pre-training.
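Before running the pre-training command, the preparation steps described above can be scripted. The following is a minimal Python sketch, not code from the repository: `missing_inputs` and `prepare_ntu_data` are hypothetical helpers, and the sketch assumes it is run from the repository root so that the relative path data/NTURGDDataset.py resolves, as in the README's own command.

```python
import subprocess
import sys
import zipfile
from pathlib import Path
from typing import List

# File names taken from the README; these helpers are hypothetical, not part of the repo.
REQUIRED_FILES = ("nturgbd_skeletons_s001_to_s017.zip", "action_labels.txt")


def missing_inputs(data_dir: Path) -> List[str]:
    """Return the required NTU RGB+D files that are not present in data_dir."""
    return [name for name in REQUIRED_FILES if not (data_dir / name).is_file()]


def prepare_ntu_data(data_dir: Path, run_preprocessing: bool = True) -> None:
    """Unzip the skeleton archive in place, then optionally run the preprocessing script."""
    missing = missing_inputs(data_dir)
    if missing:
        raise FileNotFoundError(f"missing in {data_dir}: {', '.join(missing)}")
    # Extract the skeleton files into the same directory as the archive.
    with zipfile.ZipFile(data_dir / REQUIRED_FILES[0]) as archive:
        archive.extractall(data_dir)
    if run_preprocessing:
        # Equivalent to `python data/NTURGDDataset.py`; assumes the current
        # working directory is the repository root.
        subprocess.run([sys.executable, "data/NTURGDDataset.py"], check=True)
```

For example, calling `prepare_ntu_data(Path("data/nturgb+d_data"))` from the repository root checks for both files, extracts the archive into that directory, and then runs the preprocessing script; `check=True` makes a failed preprocessing run raise `subprocess.CalledProcessError` instead of failing silently.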

python training/transformer_model_fn.py \
  --task=pretext \
  --model_prefix=nturgb+d_out \
  --batch_size=16 \
  --data_path=data/nturgb+d_data \
  --learning_rate=0.0001 \
  --max_epochs=200 \
  --steps_per_epoch=200 \
  --loss_fn=l1 \
  --model_dim=128 \
  --num_encoder_layers=4 \
  --num_decoder_layers=4 \
  --num_heads=4 \
  --dim_ffn=2048 \
  --dropout=0.3 \
  --lr_step_size=400 \
  --learning_rate_fn=step \
  --warmup_epochs=100 \
  --pose_format=None \
  --pose_embedding_type=gcn_enc \
  --dataset=ntu_rgbd \
  --pre_normalization \
  --pad_decoder_inputs \
  --non_autoregressive \
  --pos_enc_alpha=10 \
  --pos_enc_beta=500 \
  --predict_activity \
  --action=all \
  --source_seq_len=40 \
  --target_seq_len=20
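The pre-training command passes --source_seq_len=40 and --target_seq_len=20, which suggests that the forecasting task observes 40 frames of skeleton motion and predicts the 20 frames that follow. The NumPy sketch below illustrates that split under stated assumptions: each clip is a (frames, joints, 3) array of 3D joint positions, windows are cut with a fixed stride, and `forecasting_windows` is a hypothetical helper. The repository's own data loader may sample, pad, or normalize windows differently.

```python
import numpy as np


def forecasting_windows(poses: np.ndarray, source_len: int = 40, target_len: int = 20, stride: int = 10):
    """Yield (source, target) pairs for motion forecasting (hypothetical helper).

    poses:  array of shape (frames, joints, 3) with 3D joint positions.
    source: `source_len` consecutive frames given to the model as input.
    target: the `target_len` frames that immediately follow the source.
    """
    window = source_len + target_len
    # Slide a window of source_len + target_len frames over the clip.
    for start in range(0, poses.shape[0] - window + 1, stride):
        source = poses[start : start + source_len]
        target = poses[start + source_len : start + window]
        yield source, target


# Example: a synthetic 100-frame clip with 25 joints (the NTU RGB+D skeleton size).
clip = np.random.default_rng(0).normal(size=(100, 25, 3))
pairs = list(forecasting_windows(clip))
print(len(pairs), pairs[0][0].shape, pairs[0][1].shape)  # prints: 5 (40, 25, 3) (20, 25, 3)
```

A stride smaller than the 60-frame window yields overlapping samples, which increases the number of training pairs drawn from each clip; clips shorter than 60 frames yield no pairs at all under this sketch.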

Downstream Task

Once you have pre-trained weights, you can run the downstream task of gait impairment severity estimation.

Data

The PD dataset used in the paper is not publicly available, but the CASIA Gait Database is available upon request. This repo describes the steps for extracting poses and preprocessing the data; we follow this setup in our work.

Training and Evaluating

To train and evaluate GaitForeMer, run this command:

python3 training/transformer_model_fn.py \
  --downstream_strategy=both_then_class \
  --model_prefix=output_models/finetune_both_branches_then_class_branch \
  --batch_size=16 \
  --data_path=<PATH_TO_PD_DATA> \
  --learning_rate=0.0001 \
  --max_epochs=100 \
  --steps_per_epoch=200 \
  --loss_fn=l1 \
  --model_dim=128 \
  --num_encoder_layers=4 \
  --num_decoder_layers=4 \
  --num_heads=4 \
  --dim_ffn=2048 \
  --dropout=0.3 \
  --lr_step_size=400 \
  --learning_rate_fn=step \
  --warmup_epochs=10 \
  --pose_format=None \
  --pose_embedding_type=gcn_enc \
  --dataset=pd_gait \
  --pad_decoder_inputs \
  --non_autoregressive \
  --pos_enc_alpha=10 \
  --pos_enc_beta=500 \
  --predict_activity \
  --action=all \
  --source_seq_len=40 \
  --target_seq_len=20 \
  --finetuning_ckpt=<PATH_TO_PRETRAINED_MODEL_CHECKPOINT e.g. pre-trained_ntu_ckpt_epoch_0099>

You can change the fine-tuning strategy by altering the --downstream_strategy flag.
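When comparing fine-tuning strategies, it can help to build the downstream command programmatically so that only --downstream_strategy and the output prefix change between runs. This is a minimal sketch, assuming it is launched from the repository root; `BASE_FLAGS`, `build_finetune_argv`, and `run_finetune` are hypothetical names, the flag values are copied from the command above, and both_then_class is the only strategy value documented in this README.

```python
import subprocess
import sys
from typing import Dict, List, Optional

# Flags copied from the downstream command above (hypothetical helper constant).
# None marks a boolean switch that takes no value.
# Replace the two <...> placeholders with your own paths before running.
BASE_FLAGS: Dict[str, Optional[str]] = {
    "batch_size": "16",
    "data_path": "<PATH_TO_PD_DATA>",
    "learning_rate": "0.0001",
    "max_epochs": "100",
    "steps_per_epoch": "200",
    "loss_fn": "l1",
    "model_dim": "128",
    "num_encoder_layers": "4",
    "num_decoder_layers": "4",
    "num_heads": "4",
    "dim_ffn": "2048",
    "dropout": "0.3",
    "lr_step_size": "400",
    "learning_rate_fn": "step",
    "warmup_epochs": "10",
    "pose_format": "None",
    "pose_embedding_type": "gcn_enc",
    "dataset": "pd_gait",
    "pad_decoder_inputs": None,
    "non_autoregressive": None,
    "pos_enc_alpha": "10",
    "pos_enc_beta": "500",
    "predict_activity": None,
    "action": "all",
    "source_seq_len": "40",
    "target_seq_len": "20",
    "finetuning_ckpt": "<PATH_TO_PRETRAINED_MODEL_CHECKPOINT>",
}


def build_finetune_argv(strategy: str, model_prefix: str) -> List[str]:
    """Return the argv for one fine-tuning run with the given --downstream_strategy."""
    flags = {"downstream_strategy": strategy, "model_prefix": model_prefix, **BASE_FLAGS}
    argv = [sys.executable, "training/transformer_model_fn.py"]
    for name, value in flags.items():
        argv.append(f"--{name}" if value is None else f"--{name}={value}")
    return argv


def run_finetune(strategy: str, model_prefix: str) -> None:
    """Launch one run; check=True raises if the training script exits with an error."""
    subprocess.run(build_finetune_argv(strategy, model_prefix), check=True)
```

For example, `run_finetune("both_then_class", "output_models/finetune_both_branches_then_class_branch")` reproduces the command above once the placeholders are filled in. Other strategy values should be taken from the argument parser in training/transformer_model_fn.py rather than guessed.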

Contributors

markendo
