
Communicative Multiagent Deep Reinforcement Learning for Anatomical Landmark Detection using PyTorch.

Home Page: https://arxiv.org/abs/2008.08055

License: Apache License 2.0

Topics: reinforcement-learning, deep-reinforcement-learning, deep-learning, landmark-detection, machine-learning, healthcare, multiagent-reinforcement-learning


RL-Medical

Multiagent Deep Reinforcement Learning for Anatomical Landmark Detection using PyTorch. This is the code for the paper Communicative Reinforcement Learning Agents for Landmark Detection in Brain Images.

Introduction

Accurate detection of anatomical landmarks is an essential step in several medical imaging tasks. This repository implements a novel communicative multi-agent reinforcement learning (C-MARL) system to automatically detect landmarks in 3D medical images. C-MARL enables the agents to learn explicit communication channels, as well as implicit communication signals by sharing certain weights of the architecture among all the agents.
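
As a rough illustration of the explicit-communication idea, here is a minimal PyTorch sketch of a CommNet-style communication step (the names, dimensions, and averaging scheme are assumptions for illustration; the repository's actual CommNet model differs in detail):

import torch
import torch.nn as nn

class CommStep(nn.Module):
    # One communication step: each agent's hidden state is combined with
    # the mean of the other agents' hidden states. The linear layers are
    # shared across all agents, which is the implicit-communication part.
    def __init__(self, hidden_dim):
        super().__init__()
        self.w_h = nn.Linear(hidden_dim, hidden_dim)  # transform of the agent's own state
        self.w_c = nn.Linear(hidden_dim, hidden_dim)  # transform of the communication signal

    def forward(self, h):  # h: (num_agents, hidden_dim)
        n = h.shape[0]
        # mean of the *other* agents' hidden states, for each agent
        comm = (h.sum(dim=0, keepdim=True) - h) / max(n - 1, 1)
        return torch.tanh(self.w_h(h) + self.w_c(comm))

h = torch.randn(5, 64)   # 5 agents with 64-dimensional hidden states
h = CommStep(64)(h)      # one communication round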

In addition to C-MARL, the code also supports single agents and multiple agents without a communication channel (named Network3d in the code). This code was originally forked from Amir Alansary's repository.

For convenience, 10 brain MRI scans from the ADNI dataset, each annotated with 20 landmarks, are included in the data folder.
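
To quickly inspect one of the included scans and its annotations outside the provided tools, something like the following works (a hypothetical snippet: nibabel is not necessarily a dependency of this repository, the paths are placeholders, and the landmark files are assumed to contain one comma-separated voxel coordinate per line, as in the example discussed in the issues below):

import numpy as np
import nibabel as nib  # assumed helper library, not necessarily used by the repo

img = nib.load("data/images/example.nii.gz").get_fdata()              # hypothetical path
landmarks = np.loadtxt("data/landmarks/example.txt", delimiter=",", dtype=int)
print(img.shape, landmarks[:3])  # volume dimensions and the first few landmarks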

Results

Here are a few examples of the learned agents on unseen data:

  • An example of our proposed C-MARL system consisting of 5 agents looking for 5 different landmarks in a brain MRI scan. Each agent's ROI is shown as a yellow box centered on a blue point, while the red point is the target landmark. The ROI is sampled at 3 mm spacing at the beginning of every episode. The size of the red disk indicates the distance between the current and target landmarks along the z-axis.

  • Similarly, 5 C-MARL agents in fetal ultrasound scans.

  • Detecting the apex point in short-axis cardiac MRI (HQ video)

  • Detecting the anterior commissure (AC) point in adult brain MRI (HQ video)

  • Detecting the cavum septum pellucidum (CSP) point in fetal head ultrasound (HQ video)

Running the code

The main file is src/DQN.py. It offers two main modes of use, training and evaluation, described below, plus a play mode for inference without ground truth (see further down). For convenience, a Conda environment is provided (note: on my machine the environment takes 3.7 GB, mostly because of PyTorch and the CUDA toolkit). There is no need to use it if the code already runs for you.

conda env create -f environment.yml
conda activate rl-medical

All other commands are run from the src folder.

cd src

Train

Example command to train 5 C-MARL agents (the model is named CommNet in the code):

python DQN.py --task train --files data/filenames/image_files.txt data/filenames/landmark_files.txt --model_name CommNet --file_type brain --landmarks 13 14 0 1 2 --multiscale --viz 0 --train_freq 50 --write

The command above is the one used to train the models presented in the paper. The default replay buffer size is very large; consider passing lower values to the --memory_size and --init_memory_size flags to reduce memory usage. With the --write flag, training writes logs and TensorBoard summaries to the --logDir directory (runs by default).
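
For example, a lower-memory training run might look like this (the buffer sizes below are illustrative, not the values used in the paper):

python DQN.py --task train --files data/filenames/image_files.txt data/filenames/landmark_files.txt --model_name CommNet --file_type brain --landmarks 13 14 0 1 2 --multiscale --viz 0 --train_freq 50 --write --memory_size 30000 --init_memory_size 10000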

The --landmarks flag specifies the number of agents and their target landmarks. For example, --landmarks 0 1 1 creates 3 agents: one looks for landmark 0, while two look for landmark 1. All 3 agents communicate with each other.
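
For reference, the two files passed to --files are plain text with one entry per line: each line of image_files.txt is a full path to an image scan, and landmark_files.txt is assumed to be line-aligned with it, pointing to the corresponding annotation files (the paths below are hypothetical placeholders):

# image_files.txt
/path/to/scans/subject_001.nii.gz
/path/to/scans/subject_002.nii.gz

# landmark_files.txt (same order as image_files.txt)
/path/to/landmarks/subject_001.txt
/path/to/landmarks/subject_002.txt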

Evaluate

  • 8 C-MARL agents
python DQN.py --task eval --load 'data/models/BrainMRI/CommNet8agents.pt' --files 'data/filenames/image_files.txt' 'data/filenames/landmark_files.txt' --file_type brain --landmarks 13 14 0 1 2 3 4 5 --model_name "CommNet"
  • 5 C-MARL agents
python DQN.py --task eval --load 'data/models/BrainMRI/CommNet5agents.pt' --files 'data/filenames/image_files.txt' 'data/filenames/landmark_files.txt' --file_type brain --landmarks 13 14 0 1 2 --model_name "CommNet"
  • 8 Network3d agents
python DQN.py --task eval --load 'data/models/BrainMRI/Network3d8agents.pt' --files 'data/filenames/image_files.txt' 'data/filenames/landmark_files.txt' --file_type brain --landmarks 13 14 0 1 2 3 4 5 --model_name "Network3d"
  • Single agent
python DQN.py --task eval --load 'data/models/BrainMRI/SingleAgent.pt' --files 'data/filenames/image_files.txt' 'data/filenames/landmark_files.txt' --file_type brain --landmarks 13 --model_name "Network3d"

Inference without ground truth

The argument --task play runs inference without any ground-truth landmarks. In this case, --files should be given only image_files.txt; no landmark_files.txt should be passed. For example, with 8 C-MARL agents:

python DQN.py --task play --load 'data/models/BrainMRI/CommNet8agents.pt' --files 'data/filenames/image_files.txt' --file_type brain --landmarks 13 14 0 1 2 3 4 5 --model_name "CommNet"

Please note that when using the play task, the evaluation file and logs contain "N/A" instead of the landmark xyz positions and distances, since no ground truth is available.

Usage

usage: DQN.py [-h] [--load LOAD] [--task {play,eval,train}]
              [--file_type {brain,cardiac,fetal}] [--files FILES [FILES ...]]
              [--val_files VAL_FILES [VAL_FILES ...]] [--saveGif]
              [--saveVideo] [--logDir LOGDIR]
              [--landmarks [LANDMARKS [LANDMARKS ...]]]
              [--model_name {CommNet,Network3d}] [--batch_size BATCH_SIZE]
              [--memory_size MEMORY_SIZE]
              [--init_memory_size INIT_MEMORY_SIZE]
              [--max_episodes MAX_EPISODES]
              [--steps_per_episode STEPS_PER_EPISODE]
              [--target_update_freq TARGET_UPDATE_FREQ]
              [--save_freq SAVE_FREQ] [--delta DELTA] [--viz VIZ]
              [--multiscale] [--write] [--train_freq TRAIN_FREQ] [--seed SEED]

optional arguments:
  -h, --help            show this help message and exit
  --load LOAD           Path to the model to load (default: None)
  --task {play,eval,train}
                        task to perform, must load a pretrained model if task
                        is "play" or "eval" (default: train)
  --file_type {brain,cardiac,fetal}
                        Type of the training and validation files (default:
                        train)
  --files FILES [FILES ...]
                        Filepath to the text file that contains list of
                        images. Each line of this file is a full path to an
                        image scan. For (task == train or eval) there should
                        be two input files ['images', 'landmarks'] (default:
                        None)
  --val_files VAL_FILES [VAL_FILES ...]
                        Filepath to the text file that contains list of
                        validation images. Each line of this file is a full
                        path to an image scan. For (task == train or eval)
                        there should be two input files ['images',
                        'landmarks'] (default: None)
  --saveGif             Save gif image of the game (default: False)
  --saveVideo           Save video of the game (default: False)
  --logDir LOGDIR       Store logs in this directory during training (default:
                        runs)
  --landmarks [LANDMARKS [LANDMARKS ...]]
                        Landmarks to use in the images (default: [1])
  --model_name {CommNet,Network3d}
                        Models implemented are: Network3d, CommNet (default:
                        CommNet)
  --batch_size BATCH_SIZE
                        Size of each batch (default: 64)
  --memory_size MEMORY_SIZE
                        Number of transitions stored in exp replay buffer. If
                        too much is allocated training may abruptly stop.
                        (default: 100000.0)
  --init_memory_size INIT_MEMORY_SIZE
                        Number of transitions stored in exp replay before
                        training (default: 30000.0)
  --max_episodes MAX_EPISODES
                        "Number of episodes to train for" (default: 100000.0)
  --steps_per_episode STEPS_PER_EPISODE
                        Maximum steps per episode (default: 200)
  --target_update_freq TARGET_UPDATE_FREQ
                        Number of epochs between each target network update
                        (default: 10)
  --save_freq SAVE_FREQ
                        Saves network every save_freq steps (default: 1000)
  --delta DELTA         Amount to decreases epsilon each episode, for the
                        epsilon-greedy policy (default: 0.0001)
  --viz VIZ             Size of the window, None for no visualisation
                        (default: 0.01)
  --multiscale          Reduces size of voxel around the agent when it
                        oscillates (default: False)
  --write               Saves the training logs (default: False)
  --train_freq TRAIN_FREQ
                        Number of agent steps between each training step on
                        one mini-batch (default: 1)
  --seed SEED           Random seed for both training and evaluating. If none
                        is provided, no seed will be set (default: None)
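
For reference, the --delta flag suggests a linear epsilon-greedy decay along these lines (a minimal sketch; the initial rate and the floor below are assumed values, and the repository's exact schedule may differ):

num_episodes = 100000   # cf. --max_episodes
epsilon = 1.0           # initial exploration rate (assumed)
min_epsilon = 0.1       # exploration floor (assumed)
delta = 0.0001          # --delta: amount epsilon decreases each episode
for episode in range(num_episodes):
    # ... act epsilon-greedily and train here ...
    epsilon = max(epsilon - delta, min_epsilon)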

Visualiser

To help debug/visualise the images with their landmarks, you can use the visualiser.py script. It shows each image with the agents positioned at the landmarks' locations. Pressing Enter moves to the next image in the dataset.

Example usage:

python visualiser.py --files data/filenames/image_files.txt data/filenames/landmark_files.txt --file_type brain --landmarks 0 1 2 3

Contributing

Issues and pull requests are very welcome.

Citation

If you use this code in your research, please cite this paper:

@article{leroy2020communicative,
  title={Communicative Reinforcement Learning Agents for Landmark Detection in Brain Images},
  author={Leroy, Guy and Rueckert, Daniel and Alansary, Amir},
  journal={arXiv preprint arXiv:2008.08055},
  year={2020}
}


Contributors

amiralansary, brdav, crypdick, ghisvail, gml16, nikolaosbouas


Issues

Landmark annotation

May I ask what tool doctors use to manually annotate your dataset?
I tried Dolphin and ITK-SNAP, but the results were not very good.
This question is crucial to me, thank you.

set seed & reproducible result

Hi,
May I ask how you generate reproducible results, given that the environment randomly resets the state in each game? I can't find where the random seeds can be manually set for each experiment.
Thank you in advance!

RuntimeError: Unable to find a valid cuDNN algorithm to run convolution

Hi @gml16, I'm trying to learn something from this project.
I just cloned it to my PC and followed the instructions for evaluating with 8 C-MARL agents.
However, I got an error saying 'RuntimeError: Unable to find a valid cuDNN algorithm to run convolution'.
I'm sure that I'm using the same environment as the one the 'environment.yml' file provides.
Btw, I reinstalled cuDNN 7.6.5 but it didn't work.
I've also found several suggestions that decreasing the batch size could help, so I changed it from 64 to 32, but the same problem occurred.
So where could the problem be?
Maybe it's a very basic question, since this is my first time running a project on my own, and I'm really eager to get training working.
Looking forward to your reply.

Image Testing

Hi @gml16, may I know what the criteria are for using the 'play' task during testing? Could you show an example of how to run the 'play' command for testing?

dicom size

hi, thanks for this code, it's amazing,
I'm having trouble training the model with my own data. I have some CT scans with different sizes, for example 512×512×250, 512×256×128...
So it seems that it is not possible to use this type of data? Or what should I do?

thanks for your help

Oscar

about ADNI dataset

After I applied for the ADNI dataset, I did not find anatomical landmark labels for the corresponding MRI scans. Could you please tell me whether the dataset you used for training was annotated by yourselves?

Data Augmentation

Hey @gml16, I am working on adding data augmentation. Currently, I augment the data beforehand and place it inside the images folder. Is there a place in the code where I could simply add the data augmentation logic, so that it is taken care of at runtime? Thanks

What is the recommended python version?

I am trying to train a model. Due to some version mismatch I run into import errors. I am using Python 3.8.2. Is there a specific Python version that I should be using?

Traceback (most recent call last):
  File "DQN.py", line 8, in <module>
    from logger import Logger
  File "/scratch/sshanmug/rl-medical/src/logger.py", line 6, in <module>
    from torch.utils.tensorboard import SummaryWriter
  File "/scratch/sshanmug/rl-medical/env/lib/python3.8/site-packages/torch/utils/tensorboard/__init__.py", line 12, in <module>
    from .writer import FileWriter, SummaryWriter  # noqa: F401
  File "/scratch/sshanmug/rl-medical/env/lib/python3.8/site-packages/torch/utils/tensorboard/writer.py", line 9, in <module>
    from tensorboard.compat.proto.event_pb2 import SessionLog
  File "/scratch/sshanmug/rl-medical/env/lib/python3.8/site-packages/tensorboard/compat/proto/event_pb2.py", line 6, in <module>
    from google.protobuf import descriptor as _descriptor
  File "/scratch/sshanmug/rl-medical/env/lib/python3.8/site-packages/google/protobuf/descriptor.py", line 51, in <module>
    from google.protobuf.pyext import _message
ImportError: /scratch/sshanmug/rl-medical/env/lib/python3.8/site-packages/google/protobuf/pyext/_message.cpython-38-x86_64-linux-gnu.so: undefined symbol: _ZN4absl12lts_2023012512log_internal9kCharNullE

Large dataset training issue.

I'm sorry to disturb you.
I want to use the project on a large dataset, containing more than 3,000 samples for training. The training log shows that the distance does not decrease but fluctuates a lot.
I notice that the dataset you used in your article has fewer samples.
Do I need to make special modifications to the code to make it suitable for large-dataset training, such as using longer episodes?
Looking forward to your reply.

Hyperparameter “seed”

Hi!
Thanks for sharing your code! RL is a super interesting topic in which I'm still new, so my apologies in advance if these questions are too obvious.

When running the program, there is a hyperparameter "seed" (default: None). I found that when it is not set, the result of each evaluation is different. How does the seed parameter affect the evaluation results? How should I set it when training and testing on an unseen dataset?

Thanks for the help!

about the evaluate command

When I run your command:

python DQN.py --task eval --load 'data/models/BrainMRI/Network3d8agents.pt' --files 'data/filenames/image_files.txt' 'data/filenames/landmark_files.txt' --file_type brain --landmarks 13 14 0 1 2 3 4 5 --model_name "Network3d"

I get the following error:

Traceback (most recent call last):
  File "DQN.py", line 234, in <module>
    evaluator.play_n_episodes(fixed_spawn=args.fixed_spawn)
  File "C:\Users\86159\Desktop\rl-medical-master\src\evaluator.py", line 44, in play_n_episodes
    score, start_dists, q_values, info = self.play_one_episode(render, fixed_spawn=fixed_spawn[j])
TypeError: 'NoneType' object is not subscriptable

Do you know why? This question may be very basic, because I am a novice.

How to generate landmarks to txt files?

Could you please tell me how to generate the landmark txt files?
For example, in ADNI_002_S_0816_MR_MPR__GradWarp__B1_Correction__N3__Scaled_Br_20070217005829488_S18402_I40731.txt, how should the first line, 86,86,84, be interpreted? Does it correspond to physical coordinates in the CT image?

Infer without landmarks

Hi, thank you for your great contribution.
I want to use the model to predict landmarks for images that have no landmark ground-truth file. However, it seems both the train and eval modes require the landmark file.
What should I do?
Looking forward to your reply.
