ai-med / triplettraining
Official PyTorch Implementation for "From Barlow Twins to Triplet Training: Differentiating Dementia with Limited Data" (MIDL 2024)
License: GNU General Public License v3.0
Hello,
I would like to produce h5 files for UKB pre-training as done in your method, but I run into storage problems while generating them.
Could you share advice on the process you followed to pack your N = 39,560 MRI samples into HDF5 (I assume with a specific split to produce train_data.h5 and valid_data.h5)?
I am working on the DNAnexus platform, on an instance of type mem2_ssd1_gpu1_x32 with 129 GB total memory, 837 GB total storage, and 32 cores. I use T2-FLAIR MRIs in .nii.gz format, each about 2.3 MB on disk.
The detailed setup is as follows:
My Python script uses multiprocessing to generate temporary h5 files and combines them into a final train_data.h5:
```python
import logging
import multiprocessing
import os
import sys
import time
from concurrent.futures import ProcessPoolExecutor
from queue import Empty  # manager queues raise queue.Empty

import h5py
import nibabel as nib
import numpy as np
import pandas as pd
from tqdm import tqdm

# Set up logging
logging.basicConfig(filename='process.log', level=logging.INFO, format='%(asctime)s %(message)s')


def load_image(image_path):
    return nib.load(image_path).get_fdata()


def process_batches_to_h5(batch_eids, image_folder, temp_file_path, progress_position, progress_queue):
    """Write one batch of subjects into a temporary HDF5 file."""
    no_image_found = []
    with h5py.File(temp_file_path, 'w', libver='latest') as file:
        for eid in batch_eids:
            start_time = time.time()
            image_path = os.path.join(image_folder, f"{eid}/{eid}_T2_FLAIR_brain_to_MNI.nii.gz")
            if os.path.exists(image_path):
                img = load_image(image_path)
                group = file.create_group(str(eid))
                group.create_dataset('MRI/T2/data', data=img.astype(np.float32))
            else:
                no_image_found.append(eid)
                # Still report progress so the overall bar can reach 100%.
                progress_queue.put((progress_position, f"Patient {eid}: no image found", 0.0))
                continue
            process_time = time.time() - start_time
            log_message = f"Patient {eid} processed in {process_time:.2f} seconds"
            progress_queue.put((progress_position, log_message, process_time))
            logging.info(log_message)
    return no_image_found


def combine_hdf5_files(output_file_path, temp_file_paths):
    """Copy every subject group from the temporary files into the final file."""
    with h5py.File(output_file_path, 'w', libver='latest') as output_file:
        for temp_file_path in temp_file_paths:
            with h5py.File(temp_file_path, 'r') as temp_file:
                for eid in temp_file.keys():
                    temp_file.copy(eid, output_file)
            os.remove(temp_file_path)


def create_hdf5_from_eids(csv_file_path, image_folder, output_file_path, batch_size=1000, num_workers=32):
    df = pd.read_csv(csv_file_path)
    eids = df['eid'].values
    num_batches = -(-len(eids) // batch_size)  # ceil division, avoids an empty trailing batch
    batched_eids = np.array_split(eids, num_batches)

    manager = multiprocessing.Manager()
    progress_queue = manager.Queue()

    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = []
        temp_file_paths = []
        for i, batch_eids in enumerate(batched_eids):
            temp_file_path = f'temp_{i}.h5'
            temp_file_paths.append(temp_file_path)
            futures.append(executor.submit(
                process_batches_to_h5, batch_eids, image_folder, temp_file_path, i, progress_queue))

        progress_bar = tqdm(total=len(eids), desc="Overall Progress", leave=True, file=sys.stdout)
        # Poll until all futures are done (checking done() rather than running(),
        # since pending futures would otherwise let the loop exit early),
        # then drain any remaining queue entries.
        while not all(future.done() for future in futures):
            try:
                while True:
                    _, message, _ = progress_queue.get_nowait()
                    progress_bar.update(1)
                    tqdm.write(message)
            except Empty:
                time.sleep(0.1)  # avoid busy-waiting
        while not progress_queue.empty():
            progress_queue.get_nowait()
            progress_bar.update(1)
        progress_bar.close()

        no_image_found_all = []
        for future in futures:
            no_image_found_all.extend(future.result())

    combine_hdf5_files(output_file_path, temp_file_paths)

    if no_image_found_all:
        logging.info("Eids for which no image was found:")
        logging.info(no_image_found_all)
    logging.info(f"HDF5 file {output_file_path} created successfully.")


# Define paths
csv_file_path_train = "mri_eids_UKB_train.csv"
csv_file_path_val = "mri_eids_UKB_val.csv"
image_folder = '/mnt/project/Data/brain_MRI/T2_lesion_T1seg/'
train_output_file_path = 'train_data.h5'
val_output_file_path = 'valid_data.h5'

# Create HDF5 files for training and validation sets
create_hdf5_from_eids(csv_file_path_train, image_folder, train_output_file_path)
create_hdf5_from_eids(csv_file_path_val, image_folder, val_output_file_path)
print(f"Training data saved to {train_output_file_path}")
print(f"Validation data saved to {val_output_file_path}")
```
I face a time/space trade-off: generation is considerably slower when I add this line to reduce the size of each temporary h5 file:

```python
group.create_dataset('MRI/T2/data', data=img.astype(np.float32), chunks=True, compression="gzip")
```

but removing it produces the following error:

```
RuntimeError: Dirty entry flush destroy failed (file write failed: time = Wed May 22 14:59:05 2024, filename = 'temp_0.h5', file descriptor = 16, errno = 28, error message = 'No space left on device', buf = 0x56328dc1aeb0, total write size = 1891, bytes this sub-write = 1891, bytes actually written = 18446744073709551615, offset = 0)
```
For now, I have 32 temporary files (num processes = max num_workers = 32), each targeting 1000 patients. Currently (i.e. after the script terminated), each is 12 GB in size. On the DNAnexus platform, I read the data from the cloud-hosted project and write these files to local temporary storage, from which I can transfer the final files back to the cloud after creation.
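A possible middle ground I am considering (a sketch I have not benchmarked, using a zero-filled dummy volume and a hypothetical eid): lowering the gzip level from h5py's default (4) to 1 for faster writes, or switching to the lzf filter, and sizing chunks to one full volume so each image reads back in a single decompression pass:

```python
import numpy as np
import h5py

# Hypothetical example volume matching the MNI-registered shape of my MRIs.
img = np.zeros((182, 218, 182), dtype=np.float32)

with h5py.File("example.h5", "w") as f:
    grp = f.create_group("1000001")  # hypothetical eid
    # One chunk per volume; gzip level 1 trades some compression ratio
    # for much faster writes than the default level 4.
    grp.create_dataset(
        "MRI/T2/data",
        data=img,
        chunks=img.shape,
        compression="gzip",
        compression_opts=1,
    )
    # Alternative: compression="lzf" (faster still, weaker ratio,
    # but readable only where the LZF filter is available).
```

Does this match how you balanced write speed against file size, or did you store the data differently?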
My script was at this point when I ran out of storage:

```
Overall Progress: 49%|█████████████████████████████▊ | 15615/31913 [3:10:47<3:19:07, 1.36it/s]
```
My T2-FLAIR MRIs are registered to the MNI atlas and have shape (182, 218, 182).
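For reference, a quick back-of-the-envelope estimate (my own arithmetic, assuming all N = 39,560 volumes share this shape) shows why uncompressed float32 storage cannot fit on my 837 GB disk:

```python
# Per-volume size for a (182, 218, 182) float32 array.
per_image_bytes = 182 * 218 * 182 * 4  # float32 = 4 bytes/voxel
per_image_mib = per_image_bytes / 2**20

# Rough total for all N = 39,560 UKB volumes.
n_images = 39_560
total_gib = per_image_bytes * n_images / 2**30

print(f"{per_image_mib:.1f} MiB per volume, ~{total_gib:.0f} GiB total uncompressed")
# → 27.5 MiB per volume, ~1064 GiB total uncompressed
```

So each 2.3 MB .nii.gz inflates to ~27.5 MiB once decompressed to float32, and the full cohort exceeds a terabyte uncompressed, which is consistent with the errno 28 failure above.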
I would love your advice or input on faster and more storage-efficient file creation!
Best regards, and thank you in advance for your answer,