a-rios / longmbart
License: Apache License 2.0
I am trying to fine-tune my own longmbart model for text simplification, but I am a bit stuck. The conversion worked, but I got an error when starting to fine-tune. I would really appreciate any hints on how to fix the problem.
pip install -q -r requirements.txt
python ./scripts/convert_mbart_to_longformerencoderdecoder.py \
--save_model_to ./output/converted-longmbart \
--attention_window 512 \
--cache_dir ./output/mbart-large-cc25 \
--base_model facebook/mbart-large-cc25 \
--tokenizer_name_or_path facebook/mbart-large-cc25 \
--add_language_tags de_OR de_SI \
--initialize_tags de_DE de_DE \
--max_pos 1024 \
--verbose 1
python -m longformer.simplification \
--from_pretrained ./output/converted-longmbart \
--tokenizer ./output/converted-longmbart \
--save_dir ./output/longmbart-fine-tuned \
--save_prefix "w512" \
--train_source ./data/train-source.txt \
--train_target ./data/train-target.txt \
--val_source ./data/val-source.txt \
--val_target ./data/val-target.txt \
--test_source ./data/test-source.txt \
--test_target ./data/test-target.txt \
--max_output_len 1024 \
--max_input_len 1024 \
--batch_size 1 \
--grad_accum 60 \
--num_workers 5 \
--gpus 1 \
--seed 222 \
--attention_dropout 0.1 \
--dropout 0.3 \
--attention_mode sliding_chunks \
--attention_window 512 \
--label_smoothing 0.2 \
--lr 0.00003 \
--val_every 1.0 \
--val_percent_check 1.0 \
--test_percent_check 1.0 \
--early_stopping_metric 'rougeL' \
--patience 10 \
--lr_reduce_patience 8 \
--lr_reduce_factor 0.5 \
--grad_ckpt \
--progress_bar_refresh_rate 10 \
--tags_included
This threw the following RuntimeError:
Epoch 0: 0%| | 0/2 [00:00<?, ?it/s]
Traceback (most recent call last):
File "/opt/conda/lib/python3.9/runpy.py", line 197, in _run_module_as_main
return _run_code(code, main_globals, None,
File "/opt/conda/lib/python3.9/runpy.py", line 87, in _run_code
exec(code, run_globals)
File "/home/jovyan/git/longmbart/longformer/simplification.py", line 527, in <module>
main(args)
File "/home/jovyan/git/longmbart/longformer/simplification.py", line 518, in main
trainer.fit(model)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 510, in fit
results = self.accelerator_backend.train()
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 158, in train
results = self.ddp_train(process_idx=self.task_idx, model=model)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 307, in ddp_train
results = self.train_or_test()
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/accelerator.py", line 74, in train_or_test
results = self.trainer.train()
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/trainer.py", line 561, in train
self.train_loop.run_training_epoch()
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 549, in run_training_epoch
batch_output = self.run_training_batch(batch, batch_idx, dataloader_idx)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 704, in run_training_batch
self.optimizer_step(optimizer, opt_idx, batch_idx, train_step_and_backward_closure)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 482, in optimizer_step
model_ref.optimizer_step(
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/core/lightning.py", line 1296, in optimizer_step
optimizer.step(closure=optimizer_closure)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 286, in step
self.__optimizer_step(*args, closure=closure, profiler_name=profiler_name, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/core/optimizer.py", line 140, in __optimizer_step
trainer.precision_connector.backend.optimizer_step(trainer, optimizer, closure)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/plugins/native_amp.py", line 75, in optimizer_step
closure()
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 694, in train_step_and_backward_closure
result = self.training_step_and_backward(
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 792, in training_step_and_backward
result = self.training_step(split_batch, batch_idx, opt_idx, hiddens)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/trainer/training_loop.py", line 316, in training_step
training_step_output = self.trainer.accelerator_backend.training_step(args)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 164, in training_step
return self._step(args)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/accelerators/ddp_accelerator.py", line 176, in _step
output = self.trainer.model(*args)
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/pytorch_lightning/overrides/data_parallel.py", line 179, in forward
output = self.module.training_step(*inputs[0], **kwargs[0])
File "/home/jovyan/git/longmbart/longformer/simplification.py", line 251, in training_step
output = self.forward(*batch)
File "/home/jovyan/git/longmbart/longformer/simplification.py", line 231, in forward
outputs = self.model(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py", line 1346, in forward
outputs = self.model(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py", line 1211, in forward
encoder_outputs = self.encoder(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py", line 840, in forward
layer_outputs = encoder_layer(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/opt/conda/lib/python3.9/site-packages/transformers/models/mbart/modeling_mbart.py", line 331, in forward
hidden_states, attn_weights, _ = self.self_attn(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/jovyan/git/longmbart/longformer/longformer_encoder_decoder.py", line 66, in forward
outputs = self.longformer_self_attn(
File "/opt/conda/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
return forward_call(*input, **kwargs)
File "/home/jovyan/git/longmbart/longformer/longformer.py", line 184, in forward
float_mask = float_mask.repeat(1, 1, repeat_size, 1)
RuntimeError: Number of dimensions of repeat dims can not be smaller than number of dimensions of tensor
I have checked float_mask: its size is torch.Size([1, 1, 1024, 1024, 1, 1]), which looks odd to me.
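To make the mismatch concrete, here is a minimal pure-Python sketch of the shape rule that torch.Tensor.repeat enforces: the call needs at least as many repeat values as the tensor has dimensions. The helper `repeated_shape`, the placeholder `REPEAT_SIZE`, and the four-dimensional comparison shape are assumptions for illustration only; they are not taken from longformer.py, and the value of `repeat_size` in the failing run is not known from the traceback.

```python
from typing import Sequence, Tuple

# Hypothetical placeholder; the real repeat_size is not shown in the traceback.
REPEAT_SIZE = 2


def repeated_shape(shape: Sequence[int], repeats: Sequence[int]) -> Tuple[int, ...]:
    """Return the shape that repeating a tensor of `shape` by `repeats` would give.

    Mirrors the rule of torch.Tensor.repeat: there must be at least as many
    repeat values as tensor dimensions, and surplus repeat values apply to
    new leading dimensions of size 1.
    """
    if len(repeats) < len(shape):
        raise RuntimeError(
            "Number of dimensions of repeat dims can not be smaller than "
            "number of dimensions of tensor"
        )
    # Left-pad the shape with ones so both sequences have the same length.
    padded = (1,) * (len(repeats) - len(shape)) + tuple(shape)
    return tuple(size * rep for size, rep in zip(padded, repeats))


def describe(shape: Sequence[int], repeats: Sequence[int]) -> str:
    """Report either the resulting shape or the error message."""
    try:
        return f"{tuple(shape)} -> {repeated_shape(shape, repeats)}"
    except RuntimeError as err:
        return f"{tuple(shape)} -> RuntimeError: {err}"


if __name__ == "__main__":
    repeats = (1, 1, REPEAT_SIZE, 1)  # four values, as in the failing call

    # Shape reported in this issue: six dimensions, so the call is rejected.
    print(describe((1, 1, 1024, 1024, 1, 1), repeats))

    # Assumed four-dimensional mask, for comparison only: the call is accepted.
    print(describe((1, 1024, 1, 1), repeats))
```

Under this rule, repeat(1, 1, repeat_size, 1) can only succeed when float_mask has at most four dimensions, so the two extra dimensions in the reported size are enough to explain the RuntimeError. Where they come from is not determined by the traceback alone.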