SlicerCBCTToothSegmentation

Slicer extension for automated segmentation of individual teeth in cone-beam CT (CBCT) dental scans using a deep learning-based approach.

Dependencies

This module requires the PyTorch extension, which can be installed via the Extensions Manager. Python dependencies such as MONAI are automatically installed by the module if they are not available. The module internally downloads pretrained model weights to initialize the segmentation model.
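
The dependency check is roughly equivalent to the following sketch, assuming Slicer's bundled Python (the module's actual version pins may differ):

```python
import slicer

# Install MONAI into Slicer's Python environment if it is not importable yet;
# slicer.util.pip_install wraps pip for the application's own interpreter.
try:
    import monai
except ModuleNotFoundError:
    slicer.util.pip_install("monai")
    import monai
```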

Tutorial

  1. Load the CBCT scan to be segmented into Slicer and select it as the 'Input volume'.
  2. Create a new 'Input ROI' and adjust the bounding box so that it surrounds the tooth of interest.
  3. Specify an 'Output Segmentation' and click the 'Apply' button to run segmentation inference.
  • Repeat steps 2 and 3 to segment additional teeth.
  • Edit the automated output segmentation using the Paint, Draw, and Erase tools in the segmentation editor built into the module.
  • Save the output segmentation directly to file in the 'Export to file' section.
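
The same workflow can be scripted from Slicer's Python console. A minimal sketch follows; the logic class and `process` signature are assumptions based on the standard Slicer scripted-module convention and should be checked against the extension's source:

```python
import slicer

# Load the CBCT volume and create the output segmentation node.
inputVolume = slicer.util.loadVolume("/path/to/cbct.nii.gz")
outputSegmentation = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLSegmentationNode")

# Create an ROI around one tooth (RAS coordinates; values are illustrative).
roiNode = slicer.mrmlScene.AddNewNodeByClass("vtkMRMLMarkupsROINode")
roiNode.SetCenter(10.0, -40.0, 25.0)
roiNode.SetSize(15.0, 15.0, 25.0)

# Hypothetical logic class and method names; verify against the module source.
from CBCTToothSegmentation import CBCTToothSegmentationLogic
logic = CBCTToothSegmentationLogic()
logic.process(inputVolume, roiNode, outputSegmentation)
```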

License

This extension is covered by the Apache License, Version 2.0:

https://www.apache.org/licenses/LICENSE-2.0

Contributors

jcfr, pzaffino, sadhana-r, sjh26

Issues

A Request for Collaboration and Code Optimization

Good morning, colleagues. Please share documentation on how to install the plugin in Slicer 5.6 and higher, as it cannot be installed through the "install from file" method.

I also request assistance in adapting the code for Blender, where I use it for scientific purposes: segmenting bones for planning orthognathic surgeries. My idea is to combine your project with another one to segment all teeth. First, I use MONAI with an AI model to place cephalometric landmarks, which also include the tooth crowns. Next, I want to run your model in a loop, where each ROI is derived from the landmark of a tooth crown (additionally expanded by 50), and in this way segment all teeth, tagging them by name.

I ask for help in optimizing the code to exclude Slicer and operate only on SimpleITK, VTK, and MONAI. I have managed to build the code below, but the segmentation results are incorrect. For simplicity, I currently assume a constant ROI, which will be assigned dynamically in the future.

```python
import subprocess
import sys
import time
import importlib.metadata

import numpy as np
import SimpleITK as sitk
import torch
import vtk
from vtk.util.numpy_support import numpy_to_vtk


def is_installed(package, version):
    """Return True if `package` is installed at exactly `version`."""
    try:
        return importlib.metadata.version(package) == version
    except importlib.metadata.PackageNotFoundError:
        return False


def brain_tooth_AI(inputVolume, outputSegmentation, modelPath, sphere_center, sphere_radius):
    """
    Run the processing algorithm. Can be used without a GUI widget.
    :param inputVolume: path to the volume to be segmented (.nii.gz)
    :param outputSegmentation: output directory for the segmentation result
    :param modelPath: path to the pretrained model weights
    :param sphere_center: ROI center relative to the image center (currently unused)
    :param sphere_radius: ROI half-width (currently unused)
    """
    if not inputVolume or not outputSegmentation:
        raise ValueError("Input or output volume is invalid")

    if not is_installed("monai", "1.3.0"):
        subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "monai", "-y"])
        subprocess.check_call([sys.executable, "-m", "pip", "install", "monai==1.3.0"])

    startTime = time.time()
    print("Processing started")

    ### ROI from Blender #########################################################

    def load_nii_gz_file(file_path):
        return sitk.ReadImage(file_path)

    def sitk_to_numpy(image):
        # GetArrayFromImage returns the voxels in (z, y, x) order.
        return sitk.GetArrayFromImage(image)

    def adjust_roi_for_simpleitk(input_image, sphere_center, sphere_radius):
        # Shift a center given relative to the image center into absolute
        # voxel indices and build an (x, y, z, width, height, depth) ROI.
        img_size = input_image.GetSize()
        img_center = (img_size[0] / 2, img_size[1] / 2, img_size[2] / 2)

        transformed_center = (
            sphere_center[0] + img_center[0],
            sphere_center[1] + img_center[1],
            sphere_center[2] + img_center[2],
        )

        roi = (
            transformed_center[0] - sphere_radius,  # start x
            transformed_center[1] - sphere_radius,  # start y
            transformed_center[2] - sphere_radius,  # start z
            2 * sphere_radius,  # width
            2 * sphere_radius,  # height
            2 * sphere_radius,  # depth
        )
        return roi

    def crop_image(input_image, roi):
        print(f"Image size: {input_image.GetSize()}")
        # SimpleITK indexing is (x, y, z).
        x, y, z, width, height, depth = roi
        roi_filter = sitk.RegionOfInterestImageFilter()
        roi_filter.SetSize([int(width), int(height), int(depth)])
        roi_filter.SetIndex([int(x), int(y), int(z)])
        return roi_filter.Execute(input_image)

    input_image = load_nii_gz_file(inputVolume)

    # roi = adjust_roi_for_simpleitk(input_image, sphere_center, sphere_radius)
    roi = (180, 250, 150, 55, 55, 100)  # temporary constant ROI

    cropped_image = crop_image(input_image, roi)

    inputImageArray = sitk_to_numpy(cropped_image)
    inputCrop_shape = inputImageArray.shape  # (z, y, x)
    print("ROI:", inputCrop_shape)

    ##############################################################################

    # MONAI is imported after the install check above.
    from monai.inferers import SlidingWindowInferer
    from monai.transforms import Compose, EnsureChannelFirst, SpatialPad, NormalizeIntensity
    from monai.networks.nets import UNet
    from monai.networks.layers.factories import Act
    from monai.networks.layers import Norm

    print("CUDA count: " + str(torch.cuda.device_count()))
    device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
    print("Using", device, "for compute")

    # Define the U-Net model
    model = UNet(
        spatial_dims=3,
        in_channels=1,
        out_channels=2,
        channels=(16, 32, 64, 128),
        strides=(2, 2, 2, 2),
        num_res_units=2,
        act=Act.RELU,
        norm=Norm.BATCH,
        dropout=0.2).to(device)

    # Load model weights; strict=True requires the checkpoint keys to match
    # the architecture exactly.
    model.load_state_dict(torch.load(modelPath, map_location=device), strict=True)
    model.eval()

    inputImageArray = torch.tensor(inputImageArray, dtype=torch.float)

    # Pre-transforms: the first EnsureChannelFirst adds the channel dimension,
    # the second adds a batch dimension, giving (N, C, D, H, W) for inference.
    pre_transforms = Compose([
        EnsureChannelFirst(channel_dim="no_channel"),
        NormalizeIntensity(),
        SpatialPad(spatial_size=[144, 144, 144], mode="reflect"),
        EnsureChannelFirst(channel_dim="no_channel"),
    ])

    # Run sliding-window inference
    inputProcessed = pre_transforms(inputImageArray).to(device)
    inferer = SlidingWindowInferer(roi_size=[96, 96, 96])
    with torch.no_grad():
        output = inferer(inputProcessed, model)
    output = torch.softmax(output, dim=1).cpu().numpy()
    output = np.argmax(output, 1).squeeze().astype(np.uint8)

    # Crop the prediction back to the pre-padding size
    lower = [0] * 3
    upper = [0] * 3
    for i in range(len(inputCrop_shape)):
        dim = inputCrop_shape[i]
        padding = 144 - dim
        if padding > 0:
            lower[i] = int(np.floor(padding / 2))
            upper[i] = -int(np.ceil(padding / 2))
        else:
            lower[i] = 0
            upper[i] = dim

    output_reshaped = output[lower[0]:upper[0], lower[1]:upper[1], lower[2]:upper[2]]

    # # Keep largest connected component
    # largest_comp_transform = KeepLargestConnectedComponent()
    # val_comp = largest_comp_transform(val_outputs)

    print("Inference done")

    # Take the cropped segmentation back into the space of the original image.
    # The numpy array is (z, y, x), so the VTK dimensions must be reversed;
    # spacing and origin are copied from the cropped image (the direction
    # matrix is ignored here).
    data_array = numpy_to_vtk(num_array=output_reshaped.ravel(), deep=True,
                              array_type=vtk.VTK_UNSIGNED_CHAR)

    image_data = vtk.vtkImageData()
    image_data.SetDimensions(output_reshaped.shape[::-1])
    image_data.SetSpacing(cropped_image.GetSpacing())
    image_data.SetOrigin(cropped_image.GetOrigin())
    image_data.GetPointData().SetScalars(data_array)

    # Extract the label surface and write it out as STL
    contour_filter = vtk.vtkMarchingCubes()
    contour_filter.SetInputData(image_data)
    contour_filter.SetValue(0, 0.5)
    contour_filter.Update()

    stl_writer = vtk.vtkSTLWriter()
    stl_writer.SetFileName(outputSegmentation + "/test.stl")
    stl_writer.SetInputData(contour_filter.GetOutput())
    stl_writer.Write()

    stopTime = time.time()
    print(f"Processing completed in {stopTime - startTime:.2f} seconds")
```
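
For the planned per-tooth loop, a minimal sketch of deriving a clamped ROI from each crown landmark is shown below. The `crown_landmarks` table, its voxel-index convention, and the interpretation of the "expanded by 50" value as a voxel margin are all assumptions for illustration; `brain_tooth_AI` would also need a `roi` parameter in place of the temporary constant:

```python
import SimpleITK as sitk

# Hypothetical landmark table: tooth name -> crown position in voxel indices
# (x, y, z), as produced by the cephalometric-landmark model mentioned above.
crown_landmarks = {
    "tooth_11": (182, 255, 160),
    "tooth_21": (210, 255, 158),
}

MARGIN = 50  # expansion around the crown point; unit assumed to be voxels


def roi_from_landmark(image, center, margin):
    """Build an (x, y, z, width, height, depth) ROI clamped to the image."""
    size = image.GetSize()  # SimpleITK size is (x, y, z)
    start = [max(0, int(center[i]) - margin) for i in range(3)]
    end = [min(size[i], int(center[i]) + margin) for i in range(3)]
    return (start[0], start[1], start[2],
            end[0] - start[0], end[1] - start[1], end[2] - start[2])


image = sitk.ReadImage("/path/to/cbct.nii.gz")
for name, center in crown_landmarks.items():
    roi = roi_from_landmark(image, center, MARGIN)
    print(name, roi)
    # brain_tooth_AI(inputVolume, outputDir, modelPath, roi=roi)  # hypothetical
```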
