
lightning-hydra-template's Introduction

Hi, I'm Lukas :octocat:

✨ I like doing open-source and curiosity-driven research.



lightning-hydra-template's People

Contributors

adizx12, amorehead, ashleve, atong01, binlee52, caplett, cauliyang, charlesbmi, charlesgaydon, colobas, dependabot[bot], dreaquil, elisim, eungbean, gscriva, gxinhu, hotthoughts, johnnynunez, luciennnnnnn, nils-werner, phimos, sirtris, steve-tod, tbazin, tesfaldet, yipliu, yongtae723, yu-xiang-wang, yucao16, zhengyu-yang


lightning-hydra-template's Issues

TF2.0 template

Hi!
Thanks for the awesome work.
I wanted to ask if there is anything similar to this template, but for a TF2.0 + Hydra bundle?
I didn't manage to find anything relevant on GitHub.

Best regards

[Bug] Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance

๐Ÿ› Bug

Description

I am creating a multilingual text classifier using PyTorch Lightning and Hydra, with this template as a reference for the project. However, when I try to run run.py, I keep getting an error.

Relevant Files

run.py

import dotenv
import hydra
from hydra.core.hydra_config import HydraConfig
from omegaconf import DictConfig
from pathlib import Path

dotenv.load_dotenv(override = True)

import sys
sys.path.append("/workspace/data/multilingual-text-classifier/src")

@hydra.main(config_path="/workspace/data/multilingual-text-classifier/configs/", config_name="config.yaml")
def main(config: DictConfig):

    # Imports should be nested inside @hydra.main to optimize tab completion
    # Read more here: https://github.com/facebookresearch/hydra/issues/934
    
    hydra_dir = Path(HydraConfig.get().run.dir)
    
    from train import train

    # Train model
    return train(config)

if __name__ == "__main__":
    main()

train.py

import hydra
from omegaconf import DictConfig
from pytorch_lightning import (
                        LightningDataModule,
                        LightningModule,
                        Trainer,
                        seed_everything
                    )

from datamodules.datamodule import JigsawDataModule

# Lightning
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping
from pytorch_lightning.loggers import TensorBoardLogger

from models.model import BERTModel

from termcolor import colored
import time

from typing import Optional

def train(config: DictConfig) -> Optional[float]:
       
    model: LightningModule =  hydra.utils.instantiate(config.model)

    logger = TensorBoardLogger("lightning_logs", name="toxic-comments")

    early_stopping_callback = EarlyStopping(monitor='val_loss', patience=2)

    trainer = pl.Trainer(
                logger=logger,
                callbacks=[early_stopping_callback],
                max_epochs=2,
                gpus=[0],
                progress_bar_refresh_rate=30,
                precision= 32
            )

    trainer.fit(model, datamodule = JigsawDataModule)

    trainer.save_checkpoint("/workspace/data/multilingual-text-classifier/model.ckpt")

config.yaml

model: model.yaml

model.yaml

_target_: src.models.model.BERTModel

TRAIN_BATCH_SIZE: 32
VALID_BATCH_SIZE: 64
EPOCHS: 2
LEARNING_RATE: 0.5 * 1e-5

Error

Error executing job with overrides: []
Traceback (most recent call last):
  File "run.py", line 41, in <module>
    main()
  File "/opt/conda/lib/python3.6/site-packages/hydra/main.py", line 53, in decorated_main
    config_name=config_name,
  File "/opt/conda/lib/python3.6/site-packages/hydra/_internal/utils.py", line 368, in _run_hydra
    lambda: hydra.run(
  File "/opt/conda/lib/python3.6/site-packages/hydra/_internal/utils.py", line 214, in run_and_report
    raise ex
  File "/opt/conda/lib/python3.6/site-packages/hydra/_internal/utils.py", line 211, in run_and_report
    return func()
  File "/opt/conda/lib/python3.6/site-packages/hydra/_internal/utils.py", line 371, in <lambda>
    overrides=args.overrides,
  File "/opt/conda/lib/python3.6/site-packages/hydra/_internal/hydra.py", line 110, in run
    _ = ret.return_value
  File "/opt/conda/lib/python3.6/site-packages/hydra/core/utils.py", line 233, in return_value
    raise self._return_value
  File "/opt/conda/lib/python3.6/site-packages/hydra/core/utils.py", line 160, in run_job
    ret.return_value = task_function(task_cfg)
  File "run.py", line 38, in main
    return train(config)
  File "/workspace/data/multilingual-text-classifier/src/train.py", line 41, in train
    model: LightningModule =  hydra.utils.instantiate(config.model)
  File "/opt/conda/lib/python3.6/site-packages/hydra/_internal/instantiate/_instantiate2.py", line 183, in instantiate
    "Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance"
hydra.errors.InstantiationException: Top level config has to be OmegaConf DictConfig, plain dict, or a Structured Config class or instance

System information

  • Hydra Version : 1.1.0
  • Python version : 3.6.10
  • Virtual environment type and version : conda 4.8.4
  • Operating system : Windows
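A hedged note on the likely cause: with config.yaml containing only `model: model.yaml`, the value of config.model is the plain string "model.yaml", so hydra.utils.instantiate(config.model) receives a string rather than a DictConfig, which is exactly what the exception complains about. A minimal sketch of the usual defaults-list setup, assuming model.yaml is moved under configs/model/ (note also that YAML does not evaluate arithmetic, so 0.5 * 1e-5 would stay a string):

# configs/config.yaml (sketch)
defaults:
  - model: model   # composes configs/model/model.yaml under the `model` key

# configs/model/model.yaml (sketch)
_target_: src.models.model.BERTModel
TRAIN_BATCH_SIZE: 32
VALID_BATCH_SIZE: 64
EPOCHS: 2
LEARNING_RATE: 5.0e-6   # write the literal value instead of 0.5 * 1e-5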

This repo is awesome!

Hey @hobogalaxy, this template/boilerplate is spectacular! It's exactly what I've been looking for over the past couple of years. No issues to report, just wanted to say keep up the good work :) Thanks so much for your useful contribution.

How to load configurations of the previous experiment?

Say I have run multiple experiments and each saves config.yaml, hydra.yaml, and overrides.yaml. Now I want to rerun the best one. The config path seems to be fixed in run.py, so I wonder whether there is an easy way to reload the specified configuration without needing to modify run.py?
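A hedged note: Hydra's --config-path / --config-name command-line flags can point run.py at the .hydra directory saved by a previous run, so the composed config can be reused without editing run.py. The log path below is illustrative:

python run.py --config-path /path/to/logs/runs/2021-01-01/12-00-00/.hydra --config-name config

One caveat: the saved config.yaml is already composed (it has no defaults list), so overriding config groups on top of it will not recompose things the same way as a fresh run.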

Remove WANDB_START_METHOD: thread

In a related Hydra issue I found out that setting WANDB_START_METHOD: thread is no longer required for Hydra sweeps, and that it also affects Hydra logging.
If you can confirm that the environment variable has become unnecessary, I propose removing it from the configuration as an enhancement. Also, thank you for the great template. 👍

Multi-GPU training on DDP mode will log multiple times

I noticed that when I use +trainer.accelerator="ddp", the log file has repeated entries (the number of repeats equals the number of GPUs); referring to the official docs, the "Lightning implementation of DDP calls your script under the hood multiple times with the correct environment variables".
So we need a solution to handle this problem, since DDP mode is much faster than ddp_spawn.
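A hedged sketch of one common workaround (not necessarily the fix the template will adopt): wrap Python-side logging calls in Lightning's rank_zero_only so only the rank-0 process writes to the log file:

from pytorch_lightning.utilities import rank_zero_only
import logging

log = logging.getLogger(__name__)

@rank_zero_only
def log_info(msg: str) -> None:
    # runs only on the global rank-0 process under DDP, so the file gets a single copy
    log.info(msg)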

How to disable hydra logging if run with debug=True?

Not sure how to achieve this -- in the utils.py file I see you are modifying the config object, but I am not sure how to control the Hydra configuration when running with the debug flag, since it is deleted from the config by Hydra automatically.
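A hedged guess at one option, assuming the built-in logging configs shipped with Hydra 1.1: the debug config could override Hydra's logging groups from its defaults list, e.g.:

# @package _global_
# sketch of a debug config; 'disabled' is assumed to be available in the installed Hydra version
defaults:
  - override /hydra/job_logging: disabled
  - override /hydra/hydra_logging: disabled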

access to neptune Run obj error

I find that the Neptune Run object has more logging methods that can be used, like logging pandas objects, so I modified neptune.yaml:

neptune:
#  _target_: pytorch_lightning.loggers.neptune.NeptuneLogger
  _target_: neptune.new.integrations.pytorch_lightning.NeptuneLogger
  api_key: 
  project:

And in your train.py, the logger is a list:

    # Init lightning loggers
    logger: List[LightningLoggerBase] = []
    if "logger" in config:
        for _, lg_conf in config.logger.items():
            if "_target_" in lg_conf:
                log.info(f"Instantiating logger <{lg_conf._target_}>")
                logger.append(hydra.utils.instantiate(lg_conf))

In my LightningModule, I access the Run object via self.logger.experiment[''], but I get an error:

TypeError: list indices must be integers or slices, not str

Because there is only one logger object in the list, I fixed the error with:

self.logger.experiment[0]['']

Although this fixes the error, other logging methods of the loggers in the list may not work. How can I solve this problem? Thanks.
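A hedged sketch of one way to centralize this lookup (names are illustrative, and it relies on the behaviour observed above, i.e. that experiment holds one entry per logger when several loggers are attached):

from neptune.new.integrations.pytorch_lightning import NeptuneLogger
from pytorch_lightning import LightningModule

def get_neptune_run(pl_module: LightningModule):
    """Return the neptune Run whether one or several loggers are attached."""
    logger = pl_module.logger
    if isinstance(logger, NeptuneLogger):
        return logger.experiment    # single logger: experiment is the Run itself
    return logger.experiment[0]     # several loggers: Neptune is the only one configured above

Calls like get_neptune_run(self)["train/some_field"].log(value) then work the same in both cases.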

val logs are interspersed in training logs

2021/08/16 14:12:09
2021/08/16 14:12:09 Validation sanity check: 0it [00:00, ?it/s]
2021/08/16 14:12:09 Validation sanity check: 0%| | 0/2 [00:00<?, ?it/s]
2021/08/16 14:12:09 Training: -1it [00:00, ?it/s]
2021/08/16 14:12:09 Training: 0%| | 0/74 [00:00<00:00, 49344.75it/s]
2021/08/16 14:12:10 Epoch 0: 0%| | 0/74 [00:00<00:00, 3792.32it/s]
2021/08/16 14:12:10 Epoch 0: 7%|##3 | 5/74 [00:00<00:10, 6.76it/s]
2021/08/16 14:12:11 Epoch 0: 7%|8 | 5/74 [00:00<00:10, 6.76it/s, loss=1.11, v_num=N-35]
2021/08/16 14:12:11 Epoch 0: 14%|#4 | 10/74 [00:01<00:11, 5.51it/s, loss=1.11, v_num=N-35]
2021/08/16 14:12:12 Epoch 0: 14%|#4 | 10/74 [00:01<00:11, 5.50it/s, loss=1.06, v_num=N-35]
2021/08/16 14:12:12 Epoch 0: 20%|##2 | 15/74 [00:02<00:10, 5.51it/s, loss=1.06, v_num=N-35]
2021/08/16 14:12:13 Epoch 0: 20%|##2 | 15/74 [00:02<00:10, 5.50it/s, loss=1.09, v_num=N-35]
2021/08/16 14:12:13 Epoch 0: 27%|##9 | 20/74 [00:03<00:10, 5.28it/s, loss=1.09, v_num=N-35]
2021/08/16 14:12:14 Epoch 0: 27%|##9 | 20/74 [00:03<00:10, 5.28it/s, loss=1.05, v_num=N-35]
2021/08/16 14:12:14 Epoch 0: 34%|###7 | 25/74 [00:04<00:09, 5.36it/s, loss=1.05, v_num=N-35]
2021/08/16 14:12:15 Epoch 0: 34%|###3 | 25/74 [00:04<00:09, 5.36it/s, loss=0.968, v_num=N-35]
2021/08/16 14:12:15 Epoch 0: 41%|#### | 30/74 [00:05<00:08, 5.25it/s, loss=0.968, v_num=N-35]
2021/08/16 14:12:16 Epoch 0: 41%|#### | 30/74 [00:05<00:08, 5.25it/s, loss=0.948, v_num=N-35]
2021/08/16 14:12:16 Epoch 0: 47%|####7 | 35/74 [00:06<00:07, 5.29it/s, loss=0.948, v_num=N-35]
2021/08/16 14:12:17 Epoch 0: 47%|####7 | 35/74 [00:06<00:07, 5.29it/s, loss=0.851, v_num=N-35]
2021/08/16 14:12:17 Epoch 0: 54%|#####4 | 40/74 [00:07<00:06, 5.26it/s, loss=0.851, v_num=N-35]
2021/08/16 14:12:18 Epoch 0: 54%|#####4 | 40/74 [00:07<00:06, 5.25it/s, loss=0.892, v_num=N-35]
2021/08/16 14:12:18 Epoch 0: 61%|###### | 45/74 [00:08<00:05, 5.22it/s, loss=0.892, v_num=N-35]
2021/08/16 14:12:19 Epoch 0: 61%|###### | 45/74 [00:08<00:05, 5.22it/s, loss=0.879, v_num=N-35]
2021/08/16 14:12:19 Epoch 0: 68%|######7 | 50/74 [00:09<00:04, 5.23it/s, loss=0.879, v_num=N-35]
2021/08/16 14:12:20 Epoch 0: 68%|#######4 | 50/74 [00:09<00:04, 5.23it/s, loss=0.83, v_num=N-35]
2021/08/16 14:12:20 Epoch 0: 74%|########1 | 55/74 [00:10<00:03, 5.19it/s, loss=0.83, v_num=N-35]
2021/08/16 14:12:20 Epoch 0: 74%|#######4 | 55/74 [00:10<00:03, 5.19it/s, loss=0.847, v_num=N-35]
2021/08/16 14:12:20 Epoch 0: 81%|########1 | 60/74 [00:11<00:02, 5.31it/s, loss=0.847, v_num=N-35]
2021/08/16 14:12:21 Validating: 0it [00:00, ?it/s]
2021/08/16 14:12:21 Validating: 0%| | 0/15 [00:00<?, ?it/s]
2021/08/16 14:12:21 Validating: 33%|##########6 | 5/15 [00:00<00:00, 36.64it/s]
2021/08/16 14:12:21 Epoch 0: 88%|########7 | 65/74 [00:11<00:01, 5.68it/s, loss=0.847, v_num=N-35]
2021/08/16 14:12:21 Validating: 67%|####################6 | 10/15 [00:00<00:00, 37.37it/s]
2021/08/16 14:12:21 Epoch 0: 95%|#########4| 70/74 [00:11<00:00, 6.05it/s, loss=0.847, v_num=N-35]
2021/08/16 14:12:23 Validating: 100%|###############################| 15/15 [00:00<00:00, 37.81it/s]
2021/08/16 14:12:23 Epoch 0: 100%|#| 74/74 [00:14<00:00, 5.24it/s, loss=0.827, v_num=N-35, val/loss
2021/08/16 14:12:24 Epoch 0: 0%| | 0/74 [00:00<00:00, 10565.00it/s, loss=0.827, v_num=N-35, val/lo
2021/08/16 14:12:25 Epoch 1: 0%| | 0/74 [00:00<00:00, 474.79it/s, loss=0.827, v_num=N-35, val/loss
2021/08/16 14:12:25 Epoch 1: 7%| | 5/74 [00:00<00:11, 6.04it/s, loss=0.827, v_num=N-35, val/loss=
2021/08/16 14:12:26 Epoch 1: 7%| | 5/74 [00:00<00:11, 6.04it/s, loss=0.828, v_num=N-35, val/loss=
2021/08/16 14:12:26 Epoch 1: 14%|1| 10/74 [00:02<00:12, 5.13it/s, loss=0.828, v_num=N-35, val/loss
2021/08/16 14:12:27 Epoch 1: 14%|1| 10/74 [00:02<00:12, 5.13it/s, loss=0.843, v_num=N-35, val/loss
2021/08/16 14:12:27 Epoch 1: 20%|2| 15/74 [00:03<00:11, 5.26it/s, loss=0.843, v_num=N-35, val/loss
2021/08/16 14:12:28 Epoch 1: 20%|2| 15/74 [00:03<00:11, 5.25it/s, loss=0.885, v_num=N-35, val/loss

get_wandb_logger fails to retrieve WandbLogger when debug=True

Firstly, I love this repo - really great job!!
Maybe this is not really a bug, but more of a "watch out for this" kind of thing, which could be documented somewhere in the README.

Error:
Exception: You are using wandb related callback, but WandbLogger was not found for some reason...
This error occurs when calling the function src.callbacks.wandb_callbacks.get_wandb_logger().

How to reproduce:

  1. Clone repo
  2. python run.py debug=True logger=wandb callbacks=wandb

Relevant information
pytorch_lightning==1.4.0

The problem
When setting debug=True, the Trainer is passed fast_dev_run=True.
In this mode, pytorch_lightning disables all loggers, which causes get_wandb_logger() to fail.
https://pytorch-lightning.readthedocs.io/en/latest/common/debugging.html
Upon inspection of the trainer object, trainer.logger is of type pytorch_lightning.loggers.base.DummyLogger, which I assume is the placeholder logger that lightning uses with fast_dev_run=True.
https://pytorch-lightning.readthedocs.io/en/latest/api/pytorch_lightning.loggers.base.html#pytorch_lightning.loggers.base.DummyLogger

I assume this is one of the reasons you have made a debug.yaml which uses fast_dev_run=False. My suggestion is to mention this as an option in the debugging section of the README.
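A hedged sketch of how such a guard could look (illustrative, not the template's actual implementation; multi-logger handling omitted):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

def get_wandb_logger(trainer: Trainer) -> WandbLogger:
    # fast_dev_run swaps all loggers for a DummyLogger, so fail early with a clear message
    if trainer.fast_dev_run:
        raise Exception(
            "Cannot use wandb callbacks since pytorch lightning disables loggers in fast_dev_run mode."
        )
    if isinstance(trainer.logger, WandbLogger):
        return trainer.logger
    raise Exception("You are using wandb related callback, but WandbLogger was not found for some reason...")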

Custom Callbacks

hi,
I am trying to write a custom callback.
Where and how does the callback object actually get instantiated?
Or more specifically how can I pass arguments to the callback?

My callback looks like this:

class ImagePredictionLogger(Callback):
    def __init__(self, val_samples, num_samples=32):
        super().__init__()
        self.num_samples = num_samples
        self.val_imgs, self.val_labels = val_samples

    def on_validation_epoch_end(self, trainer, pl_module):
        ...

When I run the code I get:
TypeError: Error instantiating 'src.callbacks.wandb_callbacks.ImageLogger' : __init__() missing 1 required positional argument: 'val_samples'

Which I guess makes sense but how can I pass the argument to the callback?
Thanks
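A hedged note on how instantiation usually works in this setup: each entry in the callbacks config group is passed to hydra.utils.instantiate, so constructor arguments are simply the keys listed next to _target_. A hypothetical config entry (file name and values are made up):

# configs/callbacks/custom.yaml (hypothetical)
image_prediction_logger:
  _target_: src.callbacks.wandb_callbacks.ImagePredictionLogger
  num_samples: 32
  # val_samples is a runtime object, so it can't be a static YAML value; it would have to be
  # supplied when instantiating in code, e.g. hydra.utils.instantiate(cfg, val_samples=samples)

Hydra 1.1's instantiate accepts extra keyword arguments, which is one way to inject such runtime-only dependencies.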

Resubmit on slurm environment

I run experiments on a SLURM environment with a time limit. I want to resubmit my job if it is not completed; PyTorch Lightning handles this by reloading the weights. I tested the code in this repo but found that resubmission on SLURM is not supported.

I wonder how to enable resubmission in a SLURM environment. This is a great repo and I hope it can be further improved.

Thanks in advance.
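A hedged pointer based on PyTorch Lightning's SLURM documentation (not verified against this template): auto-requeuing relies on SLURM sending SIGUSR1 shortly before the wall-time limit, e.g. in the sbatch script:

#!/bin/bash
#SBATCH --time=12:00:00
#SBATCH --signal=SIGUSR1@90   # give Lightning ~90s to checkpoint and requeue before the limit

srun python run.py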

adding a linter

Hi,

Thank you for the great work.
Would you consider adding pylint to this template?

It might be really useful.

Thank you

Conda setup script fails for CUDA 11.1

Installing cudatoolkit=11.1 requires adding '-c nvidia'. See https://pytorch.org/get-started/locally/

I suggest the following addition:

# Install pytorch
if [ "$cuda_version" == "none" ]; then
    conda install -y pytorch=$pytorch_version torchvision torchaudio cpuonly -c pytorch
elif [ "$cuda_version" == "10.2" ]; then
    conda install -y pytorch=$pytorch_version torchvision torchaudio cudatoolkit=$cuda_version -c pytorch
else
    conda install -y pytorch=$pytorch_version torchvision torchaudio cudatoolkit=$cuda_version -c pytorch -c nvidia
fi

BR Christian

ps. The template is awesome! ;)

Horovod error: TypeError: __init__() missing 2 required positional arguments: 'named_parameters' and 'compression'

Reproduce command

mpirun -np 2 python run.py +trainer.accelerator=horovod

The error is raised during trainer.test().

The full error info is below:

Error executing job with overrides: ['+trainer.accelerator=horovod']
Traceback (most recent call last):
  File "run.py", line 31, in main
    return train(config)
  File "/home/user/code/lightning-hydra-template/src/train.py", line 82, in train
    trainer.test()
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 579, in test
    results = self._run(model)
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 753, in _run
    self.pre_dispatch()
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 778, in pre_dispatch
    self.accelerator.pre_dispatch(self)
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 108, in pre_dispatch
    self.training_type_plugin.pre_dispatch()
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/horovod.py", line 93, in pre_dispatch
    optimizers = [
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/horovod.py", line 94, in <listcomp>
    hvd.DistributedOptimizer(
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/horovod/torch/optimizer.py", line 585, in DistributedOptimizer
    return cls(optimizer.param_groups, named_parameters, compression, backward_passes_per_step, op,
  File "/home/user/.conda/envs/torch18/lib/python3.8/site-packages/horovod/torch/optimizer.py", line 41, in __init__
    super(self.__class__, self).__init__(params)
TypeError: __init__() missing 2 required positional arguments: 'named_parameters' and 'compression'

Multi-GPU bugs, AttributeError: Can't pickle local object 'log_hyperparameters.<locals>.<lambda>'

When I use multi-GPU with 4x 3090, I ran into AttributeError: Can't pickle local object 'log_hyperparameters.<locals>.<lambda>'. It seems to be caused by the trainer.logger.log_hyperparams = lambda params: None trick in log_hyperparameters.

Traceback (most recent call last):
  File "/ghome/luoxin/projects/liif-lightning-hydra/run.py", line 34, in main
    return train(config)
  File "/ghome/luoxin/projects/liif-lightning-hydra/src/train.py", line 78, in train
    trainer.fit(model=model, datamodule=datamodule)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 499, in fit
    self.dispatch()
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 546, in dispatch
    self.accelerator.start_training(self)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/accelerators/accelerator.py", line 73, in start_training
    self.training_type_plugin.start_training(trainer)
  File "/opt/conda/lib/python3.8/site-packages/pytorch_lightning/plugins/training_type/ddp_spawn.py", line 108, in start_training
    mp.spawn(self.new_process, **self.mp_spawn_kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
    return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
  File "/opt/conda/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 179, in start_processes
    process.start()
  File "/opt/conda/lib/python3.8/multiprocessing/process.py", line 121, in start
    self._popen = self._Popen(self)
  File "/opt/conda/lib/python3.8/multiprocessing/context.py", line 284, in _Popen
    return Popen(process_obj)
  File "/opt/conda/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 32, in __init__
    super().__init__(process_obj)
  File "/opt/conda/lib/python3.8/multiprocessing/popen_fork.py", line 19, in __init__
    self._launch(process_obj)
  File "/opt/conda/lib/python3.8/multiprocessing/popen_spawn_posix.py", line 47, in _launch
    reduction.dump(process_obj, fp)
  File "/opt/conda/lib/python3.8/multiprocessing/reduction.py", line 60, in dump
    ForkingPickler(file, protocol).dump(obj)
AttributeError: Can't pickle local object 'log_hyperparameters.<locals>.<lambda>'
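A hedged sketch of a workaround, assuming the culprit is the lambda assigned in log_hyperparameters (as the traceback suggests): a module-level function is picklable by ddp_spawn, while a lambda is not:

# module level, e.g. in utils.py (illustrative)
def empty(*args, **kwargs):
    pass

# instead of: trainer.logger.log_hyperparams = lambda params: None
trainer.logger.log_hyperparams = empty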

is it possible to loop through wandb sweep?

Hi @ashleve ,

I'm wondering if it's possible to loop through sweeps.

For example, I have 2 datasets: A and B.
According to your solution here, I must sweep these 2 datasets individually by launching 2 wandb agents.

Is there any workaround where I can just call wandb sweep <config.yaml> once, and it will sweep these 2 datasets sequentially?

Thanks for your help!

"Release" versions for this template?

Great work on the updates! I'm wondering if you'd consider tagging stable releases. With all the recent commits, I'm not sure which features are still in development and which are ready to use. Thanks!

trainer `ddp.yaml` doesn't override `default.yaml`?

Hi there,

This template is awesome, and I'm running some examples with it right now.

In configs/config.yaml, I changed

...
# specify here default training configuration
defaults:
  - trainer: default.yaml
...

but when I call python run.py trainer.gpus=4 trainer=ddp, it only applies the ddp.yaml config without overriding the values from default.yaml.

Is this behavior expected? Because when I'm reading lightning-transformer, passing trainer=sharded actually overwrites the default.yaml, like:

trainer:
-  gpus: null
+  gpus: 1
   auto_select_gpus: false
   tpu_cores: null
   log_gpu_memory: null
   ...
   log_every_n_steps: 50
-  accelerator: null
+  accelerator: ddp
   sync_batchnorm: false
-  precision: 32
+  precision: 16
   weights_summary: top
   ....
   terminate_on_nan: false
   auto_scale_batch_size: false
   prepare_data_per_node: true
-  plugins: null
+  plugins:
+    _target_: pytorch_lightning.plugins.DDPShardedPlugin
   amp_backend: native
   amp_level: O2
   move_metrics_to_cpu: false
...
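A hedged note on the Hydra side: selecting trainer=ddp replaces the trainer group entry rather than merging it with default.yaml, so only the keys listed in ddp.yaml exist. One way to get the merge-style behaviour shown above is to have ddp.yaml extend default.yaml through its own defaults list (sketch, assuming Hydra 1.1; the overridden values are illustrative):

# configs/trainer/ddp.yaml (sketch)
defaults:
  - default.yaml   # start from the full default trainer config

gpus: 4
accelerator: ddp
sync_batchnorm: true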

Add ability to resume training from latest checkpoint without specifying path

Add some kind of method to recursively go over everything in logs/ and find the latest saved checkpoint (by date saved).
Add a config flag for resuming training from the latest checkpoint:

resume_latest: True

Useful when we want to quickly resume our latest run without specifying a ckpt path.

Should be added as an enhancement to utils.extras().

Could also automatically override the whole config with the correct one from .hydra folder.
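A minimal sketch of the lookup, assuming checkpoints end in .ckpt and live somewhere under logs/ (the function name and its wiring into utils.extras() are hypothetical):

from pathlib import Path

def find_latest_checkpoint(log_dir: str = "logs/") -> str:
    """Return the most recently modified *.ckpt file under `log_dir`."""
    ckpts = list(Path(log_dir).rglob("*.ckpt"))
    if not ckpts:
        raise FileNotFoundError(f"No checkpoints found under {log_dir}")
    return str(max(ckpts, key=lambda p: p.stat().st_mtime))

When resume_latest is true, the result could, for example, be written into config.trainer.resume_from_checkpoint before the Trainer is instantiated.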

Thoughts on making run.py more generic for calling other scripts?

Hi, thanks for this awesome template, it's super clean!

I was thinking about what might be a good way to add more tasks to the config beyond just training a PyTorch Lightning model, while keeping things clean (more on that below). Some tasks could be:

  • evaluating a trained model
  • training a tokenizer
  • building a dataset (downloading, preprocessing, etc.)

About keeping things clean:

In the project https://github.com/Erlemar/pytorch_tempest, the author uses multiple scripts for different tasks, e.g. train.py, train_ner.py, predict.py... but I don't like the idea of placing many scripts in the top-level directory. Also, placing them all in a scripts/ directory would make it awkward to work with code in src/ without pip-installing that code. I like that in your template everything can be run without packaging the code in src, and that there's only a single script in the top-level directory of the project. So I'm thinking I'd like to try to extend run.py to support more tasks.

I'm wondering if you have any thoughts on how to structure the config in a way such that run.py could be generalized for performing more tasks? Example usage might be something like:

python run.py task=train_pl_model <TAB>

and hydra would suggest all the valid options that apply to the "train_pl_model" task.

Would this mean all the existing config files would need to be grouped together in a new subdirectory corresponding to that task?

I'm new to hydra and learning as I go, so I just thought I'd ask for your thoughts to see if it makes sense, or if there is already a plan to take this template in a similar direction. Thanks!
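A hedged sketch of what a task-dispatching run.py could look like (the task key and the second module are made up for illustration):

import hydra
from omegaconf import DictConfig

@hydra.main(config_path="configs/", config_name="config.yaml")
def main(config: DictConfig):
    # imports stay nested to keep tab completion fast, as in the original run.py
    if config.task == "train_pl_model":
        from src.train import train
        return train(config)
    if config.task == "build_dataset":
        from src.build_dataset import build_dataset  # hypothetical module
        return build_dataset(config)
    raise ValueError(f"Unknown task: {config.task}")

if __name__ == "__main__":
    main()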

Optuna sweep stops hparams search study after some trials

My Optuna sweep always stops the hparams search study after some trials; can you help me?
I run my code in 5-fold cross-validation style:

def train():
    acc = []
    for fold in range(fold_num):
        # init model, logger...
        train_model()
        test_model()
        acc.append(val_acc)  # the "val/acc" metric of this fold

    return mean(acc)

And the sweep YAML file is the following:

defaults:
  - override /hydra/sweeper: optuna

optimized_metric: "val/mean_acc"

hydra:
  sweeper:
    _target_: hydra_plugins.hydra_optuna_sweeper.optuna_sweeper.OptunaSweeper
    storage: null
    study_name: null
    n_jobs: 1
    direction: maximize
    n_trials: 500

    sampler:
      _target_: optuna.samplers.TPESampler
      seed: 2021
      consider_prior: true
      prior_weight: 1.0
      consider_magic_clip: true
      consider_endpoints: false
      n_startup_trials: 10
      n_ei_candidates: 24
      multivariate: false
      warn_independent_sampling: true
