Training GPT on C4 with JAX and Equinox (and jaxamp).
View sample logs here
When running on the SCC, you may wish to run source scc_setup.sh first. Then you may want to run wandb init to set up wandb logging.
Note that the CUDA versions sometimes conflict. If that happens, you can try: (1) unloading any SCC CUDA modules (module unload cuda); (2) reinstalling JAX after PyTorch (follow instructions here).
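A quick way to tell whether a CUDA mismatch is affecting JAX is to ask JAX which backend it selected. The snippet below is a minimal sketch and not part of this repository; the helper name describe_jax_backend is hypothetical. It only uses jax.devices() and jax.default_backend().

```python
def describe_jax_backend():
    """Return a short description of the backend JAX selected, or why it failed."""
    try:
        import jax  # imported lazily so the check also reports a missing install
    except ImportError as err:
        return f"jax is not importable: {err}"
    try:
        devices = jax.devices()  # raises RuntimeError if no backend can initialize
    except RuntimeError as err:
        return f"jax could not initialize a backend: {err}"
    return f"backend={jax.default_backend()}, devices={[str(d) for d in devices]}"


# Hypothetical helper: prints which backend JAX picked on this machine.
print(describe_jax_backend())
```

If the reported backend is cpu on a GPU node, JAX most likely did not find a usable CUDA installation, and the two steps above are worth trying.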
Then run python trainer.py to train.
The various options are specified in conf/config_gpt2.yaml. You can override them on the command line like so: python trainer.py model.num_blocks=8 train.wandb_project=projectname
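To illustrate what a dotted override such as model.num_blocks=8 means, here is a rough sketch of how key=value pairs could be merged into a nested config. It is plain Python for illustration only and is not how trainer.py parses its arguments. The function names and default values are hypothetical, and it assumes each dotted key is a path into the nested YAML structure.

```python
import copy


def parse_value(text):
    """Interpret an override value as bool, int, or float, else keep the string."""
    lowered = text.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(text)
        except ValueError:
            pass
    return text


def apply_overrides(config, overrides):
    """Return a copy of config with each 'a.b.c=value' override applied."""
    result = copy.deepcopy(config)  # leave the caller's defaults untouched
    for item in overrides:
        dotted_key, _, raw = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for name in parents:
            node = node.setdefault(name, {})  # walk (or create) nested sections
        node[leaf] = parse_value(raw)
    return result


# Hypothetical defaults; the real values live in conf/config_gpt2.yaml.
defaults = {"model": {"num_blocks": 12}, "train": {"wandb_project": None}}
updated = apply_overrides(
    defaults, ["model.num_blocks=8", "train.wandb_project=projectname"]
)
print(updated)
```

Each override replaces one leaf value and leaves the rest of the config at its defaults, so only the options you want to change need to appear on the command line.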