mathpluscode / imgx-diffseg

A JAX-based deep learning framework for image segmentation using diffusion models.

Home Page: https://melba-journal.org/2023:016

License: Apache License 2.0

Languages: Python 99.54%, Dockerfile 0.42%, Makefile 0.03%
Topics: abdominal-organ-segmentation, brain-tumor-segmentation, deep-learning, diffusion-models, flax, jax, muscle-ultrasound-segmentation, prostate-segmentation, segmentation

imgx-diffseg's Introduction

Hi, I'm Yunguan πŸ‘‹

@InstaDeep @BioNTech

Note

We have two papers accepted at the NeurIPS 2023 Machine Learning in Structural Biology Workshop!

  • LightMHC: A Light Model for pMHC Structure Prediction with Graph Neural Networks (paper, code)
  • FrameDiPT: SE(3) Diffusion Model for Protein Structure Inpainting (paper, code)

I am a senior research engineer in the Department of BioAI at InstaDeep (now part of BioNTech), currently managing multiple R&D teams working on bioinformatics projects.

@UCL

I am also a PhD student and Honorary Research Assistant in the Department of Medical Physics and Biomedical Engineering at University College London, supervised by Associate Professor Yipeng Hu. My research focuses on deep learning for medical imaging. More about my research can be found on Google Scholar.

imgx-diffseg's People

Contributors

mathpluscode, yipenghu


imgx-diffseg's Issues

how to run on a new dataset?

As the title says: I am not familiar with JAX and TensorFlow. Looking forward to your advice. By the way, my experiments will be run on Linux with a single GPU, using PyCharm.

Loss computation question.

Hi, thank you for your great work and for sharing your code.
I have a question about the loss computation.

if gd.model_out_type == DiffusionModelOutputType.EPSILON:
    # MSE between predicted and true noise.
    mse_loss_scalar = jnp.mean((model_out - noise) ** 2)
    scalars["mse_loss"] = mse_loss_scalar
    # Recover x_start from the predicted noise, then compute the
    # segmentation loss on the corresponding logits.
    x_start = gd.predict_xstart_from_epsilon_xt(
        x_t=x_t, epsilon=model_out, t=t
    )
    logits = gd.x_to_logits(x_start)
    seg_loss_scalar, seg_scalars = segmentation_loss_with_aux(
        logits=logits,
        mask_true=mask_true,
        loss_config=loss_config,
    )
    scalars = {**scalars, **seg_scalars}

I notice that the x_start computed from the predicted noise is converted to logits using the x_to_logits function.

def x_to_logits(self, x: jnp.ndarray) -> jnp.ndarray:
    """Map x to logits.

    Args:
        x: in the same space as x_start.

    Returns:
        Logits.
    """
    if self.x_space == DiffusionSpace.LOGITS:
        return x
    if self.x_space == DiffusionSpace.SCALED_PROBS:
        probs = (x + 1) / 2
        probs = jnp.clip(probs, EPS, 1.0)
        return jnp.log(probs)
    raise ValueError(f"Unknown x space {self.x_space}.")

I think x_start can already be regarded as logits in [-1, 1], since the input x_start is in that format, so I cannot figure out why it is converted to logits again. Maybe I am missing or misunderstanding something. Could you please help me understand this operation?
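For concreteness, here is a minimal numeric sketch of the mapping I am asking about, assuming x_space == SCALED_PROBS (the values are invented for illustration):

import jax.numpy as jnp

EPS = 1e-8  # assumed value; the repository defines its own EPS constant

# Hypothetical x_start in [-1, 1] (SCALED_PROBS space): one voxel, 3 classes.
x_start = jnp.array([0.8, -0.6, -1.0])

probs = (x_start + 1) / 2          # -> [0.9, 0.2, 0.0]
probs = jnp.clip(probs, EPS, 1.0)  # avoid log(0)
logits = jnp.log(probs)            # log-probabilities, no longer in [-1, 1]

print(logits)  # roughly [-0.105, -1.609, -18.421]

So the values passed to the segmentation loss are log-probabilities rather than the [-1, 1]-scaled values, and this conversion is the step I would like to understand.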

Looking forward to your reply.

failed to make build_dataset

When I run make build_dataset, the terminal just keeps repeating the output below:

tfds build imgx_datasets/male_pelvic_mr
INFO[build.py]: Loading dataset imgx_datasets/male_pelvic_mr from path: /data/huyao/imgx_diffseg/ImgX-DiffSeg-main/imgx_datasets/male_pelvic_mr/male_pelvic_mr_dataset_builder.py
2023-12-15 22:03:22.352001: W tensorflow/tsl/platform/cloud/google_auth_provider.cc:184] All attempts to get a Google authentication bearer token failed, returning an empty token. Retrieving token from files failed with "NOT_FOUND: Could not locate the credentials file.". Retrieving token from GCE failed with "FAILED_PRECONDITION: Error executing an HTTP request: libcurl code 6 meaning 'Couldn't resolve host name', error details: Could not resolve host: metadata.google.internal".

An error was reported while downloading the dataset.

Extraction completed...: 0 file [1:26:56, ? file/s]
Dl Size...:   6%|β–ˆβ–Ž                   | 189/3135 [1:26:56<22:35:04, 27.60s/ MiB]
Dl Completed...: 100%|β–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆβ–ˆ| 1/1 [1:26:56<00:00, 5216.08s/ url]
Traceback (most recent call last):  

  File "/home/mhb/anaconda3/envs/imgx/bin/tfds", line 8, in <module>
    sys.exit(launch_cli())
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/scripts/cli/main.py", line 109, in launch_cli
    app.run(main, flags_parser=_parse_flags)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/absl/app.py", line 308, in run
    _run_main(main, args)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/absl/app.py", line 254, in _run_main
    sys.exit(main(argv))
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/scripts/cli/main.py", line 104, in main
    args.subparser_fn(args)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/scripts/cli/build.py", line 274, in _build_datasets
    _download_and_prepare(args, builder)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/scripts/cli/build.py", line 510, in _download_and_prepare
    builder.download_and_prepare(
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/logging/__init__.py", line 169, in __call__
    return function(*args, **kwargs)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/dataset_builder.py", line 640, in download_and_prepare
    self._download_and_prepare(
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/dataset_builder.py", line 1448, in _download_and_prepare
    split_generators = self._split_generators(  # pylint: disable=unexpected-keyword-arg
  File "/media/mhb/jxp01/workspace/xq/ImgX-DiffSeg/imgx/datasets/male_pelvic_mr/male_pelvic_mr_dataset_builder.py", line 114, in _split_generators
    zip_dir = dl_manager.download_and_extract(
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/download/download_manager.py", line 686, in download_and_extract
    return _map_promise(self._download_extract, url_or_urls)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/download/download_manager.py", line 829, in _map_promise
    res = tree_utils.map_structure(
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tree/__init__.py", line 435, in map_structure
    [func(*args) for args in zip(*map(flatten, structures))])
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tree/__init__.py", line 435, in <listcomp>
    [func(*args) for args in zip(*map(flatten, structures))])
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/download/download_manager.py", line 830, in <lambda>
    lambda p: p.get(), all_promises
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/promise/promise.py", line 512, in get
    return self._target_settled_value(_raise=True)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/promise/promise.py", line 516, in _target_settled_value
    return self._target()._settled_value(_raise)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/promise/promise.py", line 226, in _settled_value
    reraise(type(raise_val), raise_val, self._traceback)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/six.py", line 719, in reraise
    raise value
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/promise/promise.py", line 87, in try_catch
    return (handler(*args, **kwargs), None)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/download/download_manager.py", line 406, in <lambda>
    lambda dl_result: self._register_or_validate_checksums(  # pylint: disable=g-long-lambda
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/download/download_manager.py", line 471, in _register_or_validate_checksums
    return self._rename_and_get_final_dl_path(
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/download/download_manager.py", line 510, in _rename_and_get_final_dl_path
    resource_lib.write_info_file(
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/utils/py_utils.py", line 496, in lock_decorated
    return fn(*args, **kwargs)
  File "/home/mhb/anaconda3/envs/imgx/lib/python3.8/site-packages/tensorflow_datasets/core/download/resource.py", line 267, in write_info_file
    raise ValueError(
ValueError: File info /home/mhb/tensorflow_datasets/downloads/zenodo.org_record_7013610_files_dataW0mCI6aH_V-TdeDbM4TdKelNcJ5ZxbAi5isebqCnMr0.zip.INFO contains a different checksum that the downloaded one: Stored: {'checksum': 'c26ec704090bb4a705e0903a05a3d9ac7d7702284eda0f27f19dc17909bd1ead', 'filename': 'data.zip', 'size': 5880838}; Expected: {'size': 189.95 MiB, 'checksum': 'd9270ea4881c4e1002bc2b181213f16ce3ad5f1e93aad17e160246cc0d717d53', 'filename': 'data.zip'}

An error was reported while downloading the dataset; it seems the download did not complete. My connection here is relatively slow, and I cannot resume after an interruption, so I can only start the download again.
What changes does TFDS make to the processed data? Knowing this would let me download and modify the original dataset myself; alternatively, would it be convenient for you to upload the processed dataset to a specific website?
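To confirm that the download itself is truncated, one can hash the file on disk and compare it against the expected checksum in the error above. A minimal sketch (the path is taken from the traceback and will differ on other machines):

import hashlib
from pathlib import Path

# Path from the traceback above; adjust to your machine.
path = Path(
    "/home/mhb/tensorflow_datasets/downloads/"
    "zenodo.org_record_7013610_files_dataW0mCI6aH_V-TdeDbM4TdKelNcJ5ZxbAi5isebqCnMr0.zip"
)

print(path.stat().st_size)  # stored size was 5880838 bytes, far below 189.95 MiB
print(hashlib.sha256(path.read_bytes()).hexdigest())
# Expected per the error message:
# d9270ea4881c4e1002bc2b181213f16ce3ad5f1e93aad17e160246cc0d717d53

A mismatch here indicates a truncated download; deleting the file and re-running make build_dataset should trigger a fresh download.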

Exception when running on my own dataset

By imitating the AMOS CT example, I wrote the dataset_builder.py and config files for my own 2D dataset and ran the algorithm in 2D mode without modifying the network structure. However, during training the loss was always negative.

The dataset consists of 2D images with 5 categories labeled 0-4, where each image contains only one or two of them. The labels of each image are not all 0, and the loss reflects this (losses for the different categories appear at random and are not always NaN, while the background-class loss is always present). At first I thought the loss did not handle NaN, but I found NaN values had already been replaced with 0. This is also my first contact with the JAX framework, and debugging under jit compilation is difficult, so after researching for a long time I still have no clue. I am hereby seeking advice.
Here are my dataset_builder.py and config files:
my_dataset.zip
Perhaps some configuration in these two files is written incorrectly, because my dataset is not medical imaging; it is 2D semantic segmentation with no z-axis information. I have done my best to modify every place I believe differs. I also adapted some of the medical image processing calls you use, mainly to add a missing dimension.
I mainly modified imgx/datasets/preprocess.py, for example:

    # Excerpt from my modified imgx/datasets/preprocess.py; assumes
    # `import numpy as np` and `import SimpleITK as sitk` at module level.
    image_volume = sitk.ReadImage(str(image_path), sitk.sitkInt8)
    label_volume = sitk.ReadImage(str(label_path), sitk.sitkInt8)
    image_array = sitk.GetArrayFromImage(image_volume)
    # For 2D images, add a singleton z-axis so the 3D pipeline still works.
    if len(image_array.shape) == 2:
        image_array = np.expand_dims(image_array, axis=0)
        image_volume = sitk.GetImageFromArray(image_array)

The program runs normally, but as shown in the figure below, there is a large amount of NaN in the loss. After training, only the background-class loss decreases normally; the other classes do not, which makes the result meaningless.
[figure: training loss curves, with NaN appearing for non-background classes]
This is an example of the annotated images; it contains only four of the five categories.
[figure: example annotated image]
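One small diagnostic (a hypothetical sketch, not code from the repository) is to count which classes are actually present in each batch, since per-class overlap losses such as Dice are undefined (0/0) when a class is absent:

import jax
import jax.numpy as jnp

def class_pixel_counts(mask_true: jnp.ndarray, num_classes: int = 5) -> jnp.ndarray:
    """Per-class pixel counts for an integer label map.

    mask_true: hypothetical integer labels of shape (batch, height, width).
    A count of 0 means the class is absent from the batch, so its
    per-class loss cannot be meaningful for that step.
    """
    one_hot = jax.nn.one_hot(mask_true, num_classes)  # (batch, h, w, num_classes)
    return jnp.sum(one_hot, axis=(0, 1, 2))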
If you could spare some time from your busy schedule to take a look, I would greatly appreciate it. The quality and speed of your code also impressed me; compared with PyTorch diffusion models, it trains very quickly.

About Encoders

In your code, you define a bottom encoder and an image encoder, but in your paper there is only one encoder. Why?

Add example of using non-TFDS data loader

Description

As mentioned in #18 and #20, the usage of TFDS may not be obvious. While the inference example demonstrates how to use a trained network on custom data, it remains challenging to use non-TFDS data loaders.

Following the release of v0.3.2, the data iterator has been moved out of Experiment, so it is now easier to use other datasets:

for i in range(max_num_steps):
    batch = next(train_iter)
    if i == 0:
        # TODO support reload checkpoint
        train_state, step_offset = run.train_init(batch)
    if i + step_offset > max_num_steps:
        # stop training
        break
    train_state, train_metrics = run.train_step(train_state, batch, key_train)
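Any Python iterator that yields batches in the expected format can then stand in for train_iter. A minimal NumPy sketch (the "image"/"label" keys and the shapes are assumptions for illustration, not the framework's actual schema):

import numpy as np

def numpy_train_iter(images: np.ndarray, labels: np.ndarray, batch_size: int):
    """Yield shuffled batches forever from in-memory arrays."""
    num_samples = images.shape[0]
    rng = np.random.default_rng(0)
    while True:
        idx = rng.permutation(num_samples)
        for start in range(0, num_samples - batch_size + 1, batch_size):
            sel = idx[start : start + batch_size]
            # Assumed batch format: a dict of arrays.
            yield {"image": images[sel], "label": labels[sel]}

# Dummy data for illustration: 8 single-channel 2D images with integer masks.
images = np.zeros((8, 1, 64, 64), dtype=np.float32)
labels = np.zeros((8, 64, 64), dtype=np.int32)
train_iter = numpy_train_iter(images, labels, batch_size=4)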

How could I run an inference on a single image?

If I trained a model on a custom dataset, how could I load the model and run inference on a single image/nii.gz?
It seems a bad idea to modify the test set in the dataset builder and rebuild it every time.
Say we have a new image that needs to be segmented with the trained 'amos_ct' model. What is the simplest way to run inference on it?
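For reference, the rough shape of such a script might look like the sketch below; predict_fn is a hypothetical stand-in for the repository's actual checkpoint loading and inference calls, and the preprocessing must match whatever the dataset builder does:

import numpy as np
import SimpleITK as sitk

def predict_fn(batch: dict) -> np.ndarray:
    # Hypothetical stand-in: load the trained checkpoint and run the model;
    # replace with the repository's actual inference entry point.
    return np.zeros(batch["image"].shape, dtype=np.uint8)

# Load a single image from disk.
volume = sitk.ReadImage("new_case.nii.gz")
image = sitk.GetArrayFromImage(volume).astype(np.float32)

# Dataset-specific preprocessing; intensity normalisation as an example.
image = (image - image.mean()) / (image.std() + 1e-8)
batch = {"image": image[None, ...]}  # assumed key name; add a batch dimension

mask_pred = predict_fn(batch)[0]

# Write the prediction back with the original geometry.
mask_volume = sitk.GetImageFromArray(mask_pred.astype(np.uint8))
mask_volume.CopyInformation(volume)
sitk.WriteImage(mask_volume, "new_case_mask.nii.gz")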
Thank you

ImportError: cannot import name 'IMAGE' from 'imgx' (/home/hyzhang/ImgX-DiffSeg/imgx/imgx/__init__.py)

I followed the "local with conda" instructions to build the environment on Linux.

conda install -y -n base conda-libmamba-solver
conda config --set solver libmamba
conda env update -f docker/environment.yml
conda activate imgx
make pip

I went through all these installation steps without error, but when I build the AMOS dataset, it shows: ImportError: cannot import name 'IMAGE' from 'imgx' (/home/hyzhang/ImgX-DiffSeg/imgx/imgx/__init__.py)

What should I do in this case?
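As a quick sanity check (a generic Python diagnostic, not repository-specific), it may help to verify which imgx package is actually being imported; the nested imgx/imgx path in the error suggests a directory-shadowing or import-path issue:

import imgx

# Shows which copy of the package Python resolved. If it points at an
# unexpected location, try re-running `make pip` from the repository root
# or adjusting PYTHONPATH / the working directory.
print(imgx.__file__)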
