
Comments (16)

tkuanlun350 avatar tkuanlun350 commented on June 10, 2024

I haven't met this problem before. 60 GB of memory usage sounds impossible to me.

Can you try offline evaluation to see if the problem still exists? e.g.: python3 train.py --load /path/to/ckpt/ --evaluate ...


huangmozhilv avatar huangmozhilv commented on June 10, 2024

I changed to offline_evaluate() by using

from tensorpack.predict import OfflinePredictor, PredictConfig
from tensorpack.tfutils.sessinit import get_model_loader

model_path = 'train_log/unet3d/model-5'
pred = OfflinePredictor(PredictConfig(
    model=get_model(modelType="inference"),   # get_model() is the repo's own model factory
    session_init=get_model_loader(model_path),
    input_names=['image'],
    output_names=get_model_output_names()))

Something goes wrong; the log shows:

[1119 19:40:30 @sessinit.py:117] Restoring checkpoint from train_log/unet3d/model-5 ...
INFO:tensorflow:Restoring parameters from train_log/unet3d/model-5
[1119 19:40:27 @sessinit.py:90] WRN The following variables are in the checkpoint, but not found in the graph: global_step:0, learning_rate:0


tkuanlun350 avatar tkuanlun350 commented on June 10, 2024

The warning is expected: the variables global_step:0 and learning_rate:0 are only used in training mode.
Is there any other exception?


huangmozhilv avatar huangmozhilv commented on June 10, 2024

I see. Thank you. Finally, I found a solution to avoid OOM:

import threading

thread = threading.Thread(target=self._eval, name='self._eval')  # pass the method itself, not the result of calling it
thread.start()
thread.join()  # wait for the thread to finish so its memory can be released


tkuanlun350 avatar tkuanlun350 commented on June 10, 2024

Nice! It would be nice if you submitted a pull request! Maybe other people are facing the same problem.


huangmozhilv avatar huangmozhilv commented on June 10, 2024

Hi @tkuanlun350, I'm now adapting your code to the LiTS (liver segmentation) challenge. The 3D CT volumes are much larger than BRATS, i.e. 512x512xn (n = 100~1000). In this case, I can't do online prediction with EvalCallback() even when the threading block above is applied. The training process takes about 30 GB of memory. When it comes to EvalCallback() after several epochs (I used 10 epochs, each epoch = 250 steps), the memory soon increases to more than 60 GB and triggers the OOM issue. I checked the processes with htop, and several GB of memory can be released by killing the PIDs marked with D status during EvalCallback(). It seems to be a deadlock problem, but I can't guess where a deadlock could happen in the code. Could you give any clues to the solution? Thank you very much.

My major revisions to your code:

  1. In get_eval_dataflow(), I changed MapDataComponent() to MapData() to adapt to my inputs.
  2. In class EvalCallback(Callback), I wrapped _eval() in a thread inside _trigger_epoch().

Below is my revised code:


import os
import csv
import gc
import threading

import numpy as np
from tensorpack.dataflow import DataFromListOfDict, MapData, PrefetchDataZMQ
from tensorpack.callbacks import Callback

# config, data_loader, crop_brain_region, sampler3d_whole, eval_brats,
# segment_one_image, get_model_output_names, logger and tinies are the repo's own modules/helpers.

def get_eval_dataflow(images_path, labels_path):
    # #if config.CROSS_VALIDATION:
    # imgs = SEG_loader.load_from_file(config.BASEDIR, config.VAL_MODE)
    # # no filter for training
    files = data_loader.load_files(images_path, labels_path)
    files = list(files)

    ds = DataFromListOfDict(files, ['id', 'image_data', 'gt', 'preprocessed']) # return yield [files(index)['id'], files(index)['image_data'], file(index)['gt'], file(index)['preprocessed']] (i.e. split-join each dict to a list)
    ds.reset_state()
    def eval_preprocess(data):
        if config.NO_CACHE:
            gt, im = data[2], data[1]
            volume_list, label, weight, original_shape, bbox = crop_brain_region(im, gt)
            batch = sampler3d_whole(volume_list, label, weight, original_shape, bbox, gt)
            # logger.info('volume_list[0].shape:{}, original_shape:{}, batch_images shape:{}, batch_original shape:{}, batch_bbox shape:{}'.format(volume_list[0].shape, original_shape, batch['images'].shape, str(batch['original_shape']), str(batch['bbox'])))
            for key in batch.keys():
                if isinstance(batch[key], np.ndarray):
                    batch[key] = np.ascontiguousarray(batch[key])
        else:
            gt = data[2]  # ground truth is also needed by sampler3d_whole() in this branch
            volume_list, label, weight, original_shape, bbox = data[3]
            batch = sampler3d_whole(volume_list, label, weight, original_shape, bbox, gt)
         
        del volume_list
        del label
        del weight
        gc.collect()
        return [data[0], data[1], data[2], batch]

    ds = MapData(ds, eval_preprocess) # should return yield list to pass to PrefetchDataZMQ()?
    ds = PrefetchDataZMQ(ds, 1)

    del files
    return ds


class EvalCallback(Callback):
    def __init__(self, images_path, labels_path):
        self.chief_only = True # Only run this callback on chief training process.
        self.images_path = images_path
        self.labels_path = labels_path
        
    def _setup_graph(self):
        # ipdb.set_trace()
        self.pred = self.trainer.get_predictor(
            ['image'], get_model_output_names())
        self.df = get_eval_dataflow(self.images_path, self.labels_path)
    
    def _eval(self):
        logger.info('Evaluate after epoch {}'.format(self.epoch_num))
        scores = eval_brats(self.df, lambda img: segment_one_image(img, [self.pred], is_online=True), outdir=config.OUTDIR, epoch_num=self.epoch_num)
        fo = open(os.path.join(os.getcwd(),'eval_res.csv'), mode='a+')
        wo = csv.writer(fo, delimiter=',')
        for k, v in scores.items():
            self.trainer.monitors.put_scalar(k, v)
            wo.writerow([config.TASK, self.epoch_num, config.STEP_PER_EPOCH, k, v, tinies.datestr()])
        fo.flush()

    def _trigger_epoch(self):
        if self.epoch_num > 0 and self.epoch_num % config.EVAL_EPOCH == 0:
            # self._eval()
            thread = threading.Thread(target=self._eval, name='self._eval')  # pass the callable, not its return value
            thread.start()
            thread.join()  # wait for the thread to finish so its memory can be released
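For completeness, a hedged sketch of how this callback could be plugged into a tensorpack training run; the TrainConfig layout below is an assumption about the training script, and get_model(), get_train_dataflow(), images_path and labels_path stand in for the repo's own objects:

from tensorpack import TrainConfig, QueueInput, SimpleTrainer, launch_train_with_config
from tensorpack.callbacks import ModelSaver

train_cfg = TrainConfig(
    model=get_model(),
    data=QueueInput(get_train_dataflow()),
    callbacks=[
        ModelSaver(),
        EvalCallback(images_path, labels_path),  # runs _eval() every config.EVAL_EPOCH epochs
    ],
    steps_per_epoch=config.STEP_PER_EPOCH,
)
launch_train_with_config(train_cfg, SimpleTrainer())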


mini-Shark avatar mini-Shark commented on June 10, 2024

@huangmozhilv Hi, I'm working on online prediction too. Have you solved this problem? Could you please give me some advice on it? Thanks


huangmozhilv avatar huangmozhilv commented on June 10, 2024

@mini-Shark Yes. I found that the cause lies in QueueInput(get_train_dataflow()). get_train_dataflow(), the pipeline from loading data from the hard disk to preprocessing it for the queue, keeps running all the time. During online evaluation, when the queue of preprocessed training data is full, newly preprocessed training data keeps accumulating in memory, which leads to the out-of-memory problem.
Since we built our pipeline with PyTorch from scratch and only borrowed some code from this repo, we solved the problem by writing our own QueueInput()-like function with the built-in Python multiprocessing module.
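The gist is a bounded queue: the producer blocks on put() once the queue is full, instead of letting preprocessed samples pile up in memory while the trainer is busy evaluating. A rough sketch of such a QueueInput()-like helper using the standard multiprocessing module (make_samples is a placeholder for whatever generator yields preprocessed samples; this is not the exact code that was shipped):

import multiprocessing as mp

def _producer(queue, make_samples):
    # Keep preprocessing and pushing samples; put() blocks once the queue holds
    # maxsize items, so memory stays bounded even during a long evaluation.
    for sample in make_samples():
        queue.put(sample)

def bounded_queue_input(make_samples, maxsize=4):
    queue = mp.Queue(maxsize=maxsize)
    worker = mp.Process(target=_producer, args=(queue, make_samples), daemon=True)
    worker.start()
    while True:
        yield queue.get()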


mini-Shark avatar mini-Shark commented on June 10, 2024

@huangmozhilv Sad... Are there any methods to avoid this situation? Maybe this question is stupid, but I don't have time to rewrite the whole pipeline :(


huangmozhilv avatar huangmozhilv commented on June 10, 2024

I have no idea how to do it with tensorpack.


mini-Shark avatar mini-Shark commented on June 10, 2024

@huangmozhilv Anyway, thanks for your reply


tkuanlun350 avatar tkuanlun350 commented on June 10, 2024

@mini-Shark Is your problem that you cannot do evaluation and training at the same time because of a memory bottleneck? You can try setting config.NO_CACHE = True to load the data online.
When config.NO_CACHE = False, all the images are loaded and preprocessed up front to accelerate training, but that consumes a lot more memory.
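For illustration, the switch is just a flag in the repo's config module (assuming it is a plain module-level constant, as the usage config.NO_CACHE above suggests):

# config.py
NO_CACHE = True   # load and preprocess each volume on the fly; False caches everything in memory up front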

If the problem is not solved, I think we can open a new issue for better discussion.


huangmozhilv avatar huangmozhilv commented on June 10, 2024

@tkuanlun350 It's a different problem. I think @mini-Shark should also set NO_CACHE = True. The problem is that if the online evaluation takes a long time (e.g. we use half of the BRATS dataset for online evaluation), the training queue gets full, and preprocessed data from get_train_dataflow() is temporarily stored in memory instead of in the queue.


tkuanlun350 avatar tkuanlun350 commented on June 10, 2024

@huangmozhilv Thanks! I will investigate the tensorpack source code to figure out a workaround.
The ugly solution is to discard the queue input and just use feed_dict.
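In tensorpack terms, that would roughly mean swapping the queue-based input for FeedInput, so each batch is fed synchronously via feed_dict and nothing accumulates in a background queue. A minimal sketch under that assumption, reusing the hypothetical train_cfg layout from the earlier sketch:

from tensorpack import TrainConfig, FeedInput

# Instead of data=QueueInput(get_train_dataflow()):
train_cfg = TrainConfig(
    model=get_model(),
    data=FeedInput(get_train_dataflow()),  # synchronous feed_dict input, no prefetch queue
    callbacks=[EvalCallback(images_path, labels_path)],
    steps_per_epoch=config.STEP_PER_EPOCH,
)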


huangmozhilv avatar huangmozhilv commented on June 10, 2024

@tkuanlun350 Thank you.


mini-Shark avatar mini-Shark commented on June 10, 2024

@huangmozhilv @tkuanlun350 Thanks for your help, guys. Maybe I have found a trade-off solution: add an extra parameter to PrefetchDataZMQ when defining get_train_dataflow(). PrefetchDataZMQ(ds, nr_proc=1, hwm=50) has a default hwm=50 parameter, which controls the queue size of the dataflow. I changed it to hwm=2. I also modified get_eval_dataflow() so that it doesn't load all the validation data at once.

I'm not sure this works properly, but it doesn't raise OOM any more (I have 64 GB of memory).
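For reference, the change amounts to lowering the ZMQ high-water mark where the training dataflow is built; a small sketch (the wrapper name is made up; in practice the hwm argument is simply changed on the existing PrefetchDataZMQ call inside get_train_dataflow()):

from tensorpack.dataflow import PrefetchDataZMQ

def prefetch_with_small_buffer(ds, hwm=2):
    # hwm is the ZMQ high-water mark: the maximum number of preprocessed items
    # each worker process may buffer. A small value caps memory growth while
    # the trainer is busy with online evaluation, at the cost of some throughput.
    return PrefetchDataZMQ(ds, nr_proc=1, hwm=hwm)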

