Comments (16)
I haven't met this problem before. 60 GB of memory usage sounds impossible to me.
Can you try offline evaluation to see whether the problem still exists? e.g.: python3 train.py --load /path/to/ckpt/ --evaluate ...
from 3dunet-tensorflow-brats18.
I changed to offline_evaluate() by using:
model_path = 'train_log/unet3d/model-5'
pred = OfflinePredictor(PredictConfig(
    model=get_model(modelType="inference"),
    session_init=get_model_loader(model_path),
    input_names=['image'],
    output_names=get_model_output_names()))
Something goes wrong; the log shows:
[1119 19:40:30 @sessinit.py:117] Restoring checkpoint from train_log/unet3d/model-5 ...
INFO:tensorflow:Restoring parameters from train_log/unet3d/model-5
[1119 19:40:27 @sessinit.py:90] WRN The following variables are in the checkpoint, but not found in the graph: global_step:0, learning_rate:0
from 3dunet-tensorflow-brats18.
The warning is expected: the variables global_step:0 and learning_rate:0 are only used in training mode.
Is there any other exception?
from 3dunet-tensorflow-brats18.
I see. Thank you. Finally, I found a solution to avoid OOM:
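import threading
# Pass the bound method itself (no parentheses) so that it runs inside the new thread.
thread = threading.Thread(target=self._eval, name='self._eval')
thread.start()
thread.join()  # wait for the thread to finish so that it is closed, to save memory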
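A side note on `threading.Thread`: its `target` argument expects a callable. Writing `target=f()` calls `f` immediately in the calling thread and hands its return value to `Thread`, whereas `target=f` runs `f` in the new thread. The following is a minimal sketch of the difference; `work` and `run_and_report` are hypothetical names used only for illustration.

```python
import threading


def run_and_report(pass_callable):
    """Return the name of the thread in which `work` actually ran."""
    ran_in = []

    def work():
        ran_in.append(threading.current_thread().name)

    if pass_callable:
        # The callable is handed over; it runs inside the new thread.
        t = threading.Thread(target=work, name='worker')
    else:
        # work() runs right here in the calling thread; target becomes None,
        # so the new thread has nothing to do.
        t = threading.Thread(target=work(), name='worker')
    t.start()
    t.join()
    return ran_in[0]


print(run_and_report(True))   # name of the new thread
print(run_and_report(False))  # name of the calling thread
```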
from 3dunet-tensorflow-brats18.
Nice! It would be great if you submitted a pull request. Other people may be facing the same problem.
from 3dunet-tensorflow-brats18.
Hi @tkuanlun350, I'm now adapting your code to the LiTS (liver segmentation) challenge. The 3D CT volumes are much larger than those in BRATS, i.e. 512x512xn (n = 100~1000). In this case, I can't do online prediction with EvalCallBack(), even when the threading block above is applied. The training process takes about 30 GB of memory. When it comes to EvalCallBack() after several epochs (I used 10 epochs, each epoch = 250 steps), the memory soon increases to more than 60 GB and triggers an OOM issue. I checked the processes with htop, and several GB of memory can be released by killing the PIDs marked with D status during EvalCallBack(). It seems that this is a deadlock problem. However, I can't guess where a deadlock can happen in the code. Could you give any clues to the solution? Thank you very much.
My major revisions to your code:
- In def get_eval_dataflow(), I changed MapDataComponent() to MapData to adapt to my inputs here.
- In class EvalCallback(Callback), I wrapped _eval() in threading functions in def _trigger_epoch.
Below is my revised code:
import csv
import gc
import os
import threading

import numpy as np


def get_eval_dataflow(images_path, labels_path):
    # #if config.CROSS_VALIDATION:
    # imgs = SEG_loader.load_from_file(config.BASEDIR, config.VAL_MODE)
    # # no filter for training
    files = data_loader.load_files(images_path, labels_path)
    files = list(files)
    # Yields [f['id'], f['image_data'], f['gt'], f['preprocessed']] for each dict f in files
    # (i.e. each dict is split into a list).
    ds = DataFromListOfDict(files, ['id', 'image_data', 'gt', 'preprocessed'])
    ds.reset_state()

    def eval_preprocess(data):
        # Read gt and im before branching so that both branches can use gt.
        gt, im = data[2], data[1]
        if config.NO_CACHE:
            volume_list, label, weight, original_shape, bbox = crop_brain_region(im, gt)
            batch = sampler3d_whole(volume_list, label, weight, original_shape, bbox, gt)
            # logger.info('volume_list[0].shape:{}, original_shape:{}, batch_images shape:{}, batch_original shape:{}, batch_bbox shape:{}'.format(volume_list[0].shape, original_shape, batch['images'].shape, str(batch['original_shape']), str(batch['bbox'])))
            for key in batch.keys():
                if isinstance(batch[key], np.ndarray):
                    batch[key] = np.ascontiguousarray(batch[key])
        else:
            volume_list, label, weight, original_shape, bbox = data[3]
            batch = sampler3d_whole(volume_list, label, weight, original_shape, bbox, gt)
        # Drop the local references to the intermediate arrays.
        del volume_list
        del label
        del weight
        gc.collect()
        return [data[0], data[1], data[2], batch]

    ds = MapData(ds, eval_preprocess)  # should return yield list to pass to PrefetchDataZMQ()?
    ds = PrefetchDataZMQ(ds, 1)
    del files
    return ds


class EvalCallback(Callback):
    def __init__(self, images_path, labels_path):
        self.chief_only = True  # Only run this callback on chief training process.
        self.images_path = images_path
        self.labels_path = labels_path

    def _setup_graph(self):
        # ipdb.set_trace()
        self.pred = self.trainer.get_predictor(
            ['image'], get_model_output_names())
        self.df = get_eval_dataflow(self.images_path, self.labels_path)

    def _eval(self):
        logger.info('Evaluate after epoch {}'.format(self.epoch_num))
        scores = eval_brats(self.df, lambda img: segment_one_image(img, [self.pred], is_online=True), outdir=config.OUTDIR, epoch_num=self.epoch_num)
        # Append the scores to a CSV file; the with-block closes the file handle.
        with open(os.path.join(os.getcwd(), 'eval_res.csv'), mode='a+') as fo:
            wo = csv.writer(fo, delimiter=',')
            for k, v in scores.items():
                self.trainer.monitors.put_scalar(k, v)
                wo.writerow([config.TASK, self.epoch_num, config.STEP_PER_EPOCH, k, v, tinies.datestr()])
            fo.flush()

    def _trigger_epoch(self):
        if self.epoch_num > 0 and self.epoch_num % config.EVAL_EPOCH == 0:
            # self._eval()
            # Pass the bound method itself (no parentheses) so that it runs inside the new thread.
            thread = threading.Thread(target=self._eval, name='self._eval')
            thread.start()
            thread.join()  # wait for the thread to finish so that it is closed, to save memory
from 3dunet-tensorflow-brats18.
@huangmozhilv Hi, I'm working on online prediction too. Have you solved this problem? Could you please give me some advice on it? Thanks.
from 3dunet-tensorflow-brats18.
@mini-Shark Yes. I found that the cause lies in QueueInput(get_train_dataflow()). get_train_dataflow(), the process from 'loading data from hard disk' to 'get data preprocessed for queue', keeps running all the time. During online evaluation, when the queue of preprocessed training data is full, newly preprocessed training data keeps accumulating in memory, resulting in an out-of-memory problem.
Since we built our pipeline with PyTorch from scratch and borrowed some code from this repo, we solved this problem by writing our own QueueInput()-like function using the built-in Python module multiprocessing.
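To illustrate the general idea of keeping prefetched data bounded, here is a minimal sketch of a producer that blocks once a fixed-size queue is full, so preprocessed items cannot pile up while the consumer is busy (for example during a long evaluation). This is not the code used in the pipeline described above: `run_pipeline` and its parameters are hypothetical, and threads with `queue.Queue` are used for brevity. `multiprocessing.Queue` also accepts a `maxsize` argument and blocks on `put()` when the queue is full.

```python
import queue
import threading
import time


def run_pipeline(num_items, maxsize, consumer_delay=0.001):
    """Run a producer/consumer pair over a bounded queue.

    Returns the consumed items and the largest queue size observed.
    """
    q = queue.Queue(maxsize=maxsize)
    sentinel = object()  # marks the end of the stream

    def producer():
        for i in range(num_items):
            q.put(i)  # blocks while the queue already holds `maxsize` items
        q.put(sentinel)

    t = threading.Thread(target=producer)
    t.start()

    received, peak = [], 0
    while True:
        peak = max(peak, q.qsize())
        item = q.get()
        if item is sentinel:
            break
        received.append(item)
        time.sleep(consumer_delay)  # simulate a slow consumer
    t.join()
    return received, peak


items, peak = run_pipeline(num_items=20, maxsize=2)
print(len(items), peak)
```

Because `put()` blocks, the producer can never be more than `maxsize` items ahead of the consumer, regardless of how slow the consumer is.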
from 3dunet-tensorflow-brats18.
@huangmozhilv Saaad... Is there any way to avoid this situation? Maybe this question is stupid, but I don't have time to rewrite the whole pipeline :(
from 3dunet-tensorflow-brats18.
I have no idea how to do that with tensorpack.
from 3dunet-tensorflow-brats18.
@huangmozhilv Anyway, thanks for your reply
from 3dunet-tensorflow-brats18.
@mini-Shark Is your problem that you cannot run evaluation and training at the same time because of a memory bottleneck? You can try setting config.NO_CACHE = True to load data online.
When config.NO_CACHE = False, all images are loaded and preprocessed up front to accelerate training, but this consumes a lot more memory.
If the problem is not solved, I think we can open a new issue for better discussion.
from 3dunet-tensorflow-brats18.
@tkuanlun350 It's a different problem. I think @mini-Shark should also set 'NO_CACHE = True'. The problem is that if the online evaluation takes a long time (e.g. when half of the BRATS dataset is used for online evaluation), the training queue fills up, and preprocessed data from get_train_dataflow() is temporarily stored in memory instead of in the queue.
from 3dunet-tensorflow-brats18.
@huangmozhilv Thanks! I will look into the tensorpack source code to figure out a workaround.
The ugly solution is to discard the queue input and just use feed_dict.
from 3dunet-tensorflow-brats18.
@tkuanlun350 Thank you.
from 3dunet-tensorflow-brats18.
@huangmozhilv @tkuanlun350 Thanks for your help. I may have found a trade-off solution: pass an additional parameter to 'PrefetchDataZMQ' when defining 'get_train_dataflow()'. 'PrefetchDataZMQ(ds, nr_proc=1, hwm=50)' has a default parameter 'hwm=50', which controls the queue size of the dataflow. I changed it to 'hwm=2'. I also modified 'get_eval_dataflow()' so that it doesn't load all validation data at once.
I'm not sure this will work properly, but it no longer raises OOM (I have 64 GB of memory).
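As a rough back-of-the-envelope check on why a smaller 'hwm' can matter for large volumes, the sketch below estimates how much data a queue could hold at a given 'hwm'. The numbers are assumptions for illustration only: each buffered item is taken to be one whole 512x512x500 volume stored as float32 (the 512x512 size comes from the LiTS description earlier in the thread, n = 500 is picked from the stated 100~1000 range). The real dataflow may buffer patches, labels or several arrays per item, so the actual figures will differ. `buffered_bytes` is a hypothetical helper.

```python
def buffered_bytes(shape, bytes_per_voxel, hwm):
    """Bytes held if `hwm` items of the given shape are buffered at once."""
    voxels = 1
    for dim in shape:
        voxels *= dim
    return voxels * bytes_per_voxel * hwm


MIB = 1024 ** 2
shape = (512, 512, 500)  # assumed volume size, n = 500
float32_bytes = 4        # assumed dtype

print(buffered_bytes(shape, float32_bytes, hwm=50) / MIB)  # 25000.0 MiB
print(buffered_bytes(shape, float32_bytes, hwm=2) / MIB)   # 1000.0 MiB
```

Under these assumptions, a full queue at 'hwm=50' would hold about 24 GiB, while 'hwm=2' would hold about 1 GiB.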
from 3dunet-tensorflow-brats18.