nonwestlit's People

Contributors

devrimcavusoglu, gokcengokceoglu

nonwestlit's Issues

Add evaluation metrics to be logged to experiment tracking

Currently, evaluation on the validation dataset computes only the loss. For classification we can log additional metrics, which would give better insight into model performance. Adding classification metrics for the prompt-tuning setup could be a bit tricky, though.
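As a rough sketch of what such metrics could look like: the snippet below computes accuracy and macro-F1 from per-class logits. It assumes the metrics are plugged in through the `compute_metrics` argument of `transformers.Trainer`, which hands the callable a `(logits, labels)` pair; the function names `classification_metrics` and `compute_metrics` are illustrative and not taken from the repository.

```python
def classification_metrics(predictions, labels):
    """Return accuracy and macro-averaged F1 for predicted vs. true class ids."""
    pairs = list(zip(predictions, labels))
    classes = sorted(set(labels) | set(predictions))
    accuracy = sum(p == y for p, y in pairs) / len(pairs)

    f1_scores = []
    for c in classes:
        tp = sum(p == c and y == c for p, y in pairs)
        fp = sum(p == c and y != c for p, y in pairs)
        fn = sum(p != c and y == c for p, y in pairs)
        denom = 2 * tp + fp + fn
        # F1 = 2TP / (2TP + FP + FN); defined as 0 when the class never occurs.
        f1_scores.append(2 * tp / denom if denom else 0.0)

    return {"accuracy": accuracy, "macro_f1": sum(f1_scores) / len(f1_scores)}


def compute_metrics(eval_pred):
    """Hypothetical Trainer hook: turn logits into class ids, then score them."""
    logits, labels = eval_pred
    predictions = [max(range(len(row)), key=row.__getitem__) for row in logits]
    return classification_metrics(predictions, list(labels))


if __name__ == "__main__":
    logits = [[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.2, 3.0], [1.0, 0.9, 0.2]]
    labels = [0, 1, 2, 1]
    print(compute_metrics((logits, labels)))
```

The returned dictionary keys are what would show up in the experiment tracker, so they can be renamed to match whatever naming scheme is already in use.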

Add LoRA or QLoRA fine-tuning setup.

Low-Rank Adaptation of Large Language Models (LoRA) is a training method that accelerates the training of large models while consuming less memory. It adds pairs of rank-decomposition weight matrices (called update matrices) to existing weights, and only trains those newly added weights.

paper: https://arxiv.org/pdf/2106.09685.pdf
peft: https://huggingface.co/docs/peft/conceptual_guides/lora
peft-finetuning: https://huggingface.co/docs/trl/main/en/lora_tuning_peft

For quantization, refer to bitsandbytes.
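To make the memory argument concrete: for a frozen weight matrix of shape d x k, LoRA trains two matrices B (d x r) and A (r x k), so the trainable parameter count drops from d*k to r*(d + k). The arithmetic below is only an illustration; the helper name `lora_param_counts` and the sizes (d = k = 4096, r = 8) are assumed values, not settings from this project.

```python
def lora_param_counts(d, k, r):
    """Return (full, lora) trainable-parameter counts for one d x k weight matrix.

    full: parameters updated by ordinary fine-tuning (the whole matrix).
    lora: parameters in the rank-r update matrices B (d x r) and A (r x k).
    """
    full = d * k
    lora = r * (d + k)
    return full, lora


if __name__ == "__main__":
    # Illustrative sizes only (assumption): a square 4096 x 4096 projection, rank 8.
    full, lora = lora_param_counts(d=4096, k=4096, r=8)
    print(f"full fine-tuning: {full:,} params")
    print(f"LoRA (r=8):       {lora:,} params ({lora / full:.2%} of full)")
```

The rank r is the main knob: doubling it doubles the number of trainable parameters per adapted matrix.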

New evaluation script (post-training)

Currently, evaluation during training scores every chunk individually, even when several chunks belong to the same article. At test time we should classify the article, not its chunks. The evaluation therefore needs some kind of aggregation or pooling, either over the chunk scores (which would be more appropriate) or over the predicted chunk classes.

Either way, we should produce a single class for an article, not multiple class predictions.
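One way the score-based option could look, as a sketch: convert each chunk's logits to probabilities, average them per article, and take the argmax. Mean pooling is just one possible choice here, and the names `aggregate_article_predictions`, `article_ids` and `chunk_logits` are hypothetical.

```python
import math


def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def aggregate_article_predictions(article_ids, chunk_logits):
    """Mean-pool chunk probabilities per article and return one class per article.

    article_ids:  one id per chunk, naming the article the chunk came from.
    chunk_logits: one list of per-class logits per chunk, in the same order.
    """
    sums, counts = {}, {}
    for article_id, logits in zip(article_ids, chunk_logits):
        probs = softmax(logits)
        if article_id not in sums:
            sums[article_id] = [0.0] * len(probs)
            counts[article_id] = 0
        sums[article_id] = [s + p for s, p in zip(sums[article_id], probs)]
        counts[article_id] += 1

    predictions = {}
    for article_id, total in sums.items():
        mean_probs = [s / counts[article_id] for s in total]
        predictions[article_id] = max(range(len(mean_probs)), key=mean_probs.__getitem__)
    return predictions


if __name__ == "__main__":
    ids = ["a", "a", "b"]
    logits = [[3.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 2.0]]
    print(aggregate_article_predictions(ids, logits))
```

Pooling over scores lets one confident chunk outweigh a weakly predicted one, which a majority vote over chunk classes cannot do.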

Llama-2 OOM problem.

The articles in the datasets are long, so the collators end up truncating the tokens to the model's maximum sequence length even when the setting is to pad/truncate with respect to the 'longest' sequence in the batch, probably because the articles themselves exceed that length. A solution could be to split the articles into chunks and process each chunk separately.

Apart from the OOM issue, chunking is also needed because the models currently cannot see the whole article during training.
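A minimal sketch of the chunking idea, operating on an already tokenized article: split the token ids into windows of at most `max_length` tokens, optionally overlapping by `stride` tokens. The function name and both parameters are assumptions for illustration; the repository may chunk differently (for example on raw text).

```python
def chunk_token_ids(token_ids, max_length, stride=0):
    """Split token ids into consecutive chunks of at most max_length tokens.

    stride is the number of tokens shared between neighbouring chunks,
    so each new chunk starts max_length - stride tokens after the previous one.
    """
    if max_length <= 0:
        raise ValueError("max_length must be positive")
    if not 0 <= stride < max_length:
        raise ValueError("stride must satisfy 0 <= stride < max_length")

    step = max_length - stride
    chunks = []
    for start in range(0, len(token_ids), step):
        chunks.append(token_ids[start:start + max_length])
        # Stop once a chunk has reached the end of the article.
        if start + max_length >= len(token_ids):
            break
    return chunks


if __name__ == "__main__":
    article = list(range(10))  # stand-in for real token ids
    for chunk in chunk_token_ids(article, max_length=4, stride=1):
        print(chunk)
```

With chunks of bounded length, the per-batch memory no longer depends on how long the original article is.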

Fancy thought: instead of chunkifying, another solution could be to sample random chunks from the articles during training, though I presume this would require many more training epochs. Still, it might be a good way to train 😅 wdyt? @gokcengokceoglu
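The random-chunk idea could be sketched roughly as follows: each time an article is drawn during training, pick a random window of `max_length` tokens from it. `sample_random_chunk` and its parameters are hypothetical names, and this is only one way to do the sampling.

```python
import random


def sample_random_chunk(token_ids, max_length, rng=random):
    """Return a random contiguous window of at most max_length tokens.

    Articles that already fit are returned whole. rng can be a seeded
    random.Random instance to make the sampling reproducible.
    """
    if len(token_ids) <= max_length:
        return list(token_ids)
    # Every valid start offset is equally likely.
    start = rng.randrange(len(token_ids) - max_length + 1)
    return list(token_ids[start:start + max_length])


if __name__ == "__main__":
    rng = random.Random(0)
    article = list(range(100))  # stand-in for real token ids
    for epoch in range(3):
        print(epoch, sample_random_chunk(article, max_length=8, rng=rng))
```

Since each epoch sees a different window, covering a long article takes several passes, which matches the expectation that more epochs would be needed.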

ValueError: Target size (torch.Size([4])) must be the same as input size (torch.Size([4, 3]))

Error while training on the first-level dataset.

Traceback

Traceback (most recent call last):
  File "<frozen runpy>", line 198, in _run_module_as_main
  File "<frozen runpy>", line 88, in _run_code
  File "/home/devrim/lab/gh/ms/nonwestlit/nonwestlit/__main__.py", line 8, in <module>
    fire.Fire({"train": train, "predict": predict, "evaluate": evaluate})
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/fire/core.py", line 141, in Fire
    component_trace = _Fire(component, args, parsed_flag_args, context, name)
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/fire/core.py", line 475, in _Fire
    component, remaining_args = _CallAndUpdateTrace(
                                ^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/fire/core.py", line 691, in _CallAndUpdateTrace
    component = fn(*varargs, **kwargs)
                ^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/lab/gh/ms/nonwestlit/nonwestlit/training.py", line 318, in train
    return trainer.train()
           ^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/transformers/trainer.py", line 1537, in train
    return inner_training_loop(
           ^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/transformers/trainer.py", line 1854, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/transformers/trainer.py", line 2728, in training_step
    loss = self.compute_loss(model, inputs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/transformers/trainer.py", line 2751, in compute_loss
    outputs = model(**inputs)
              ^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/accelerate/utils/operations.py", line 636, in forward
    return model_forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/accelerate/utils/operations.py", line 624, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
                           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/amp/autocast_mode.py", line 16, in decorate_autocast
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/peft/peft_model.py", line 886, in forward
    return self.base_model(
           ^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/peft/tuners/tuners_utils.py", line 103, in forward
    return self.model.forward(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/transformers/models/llama/modeling_llama.py", line 1401, in forward
    loss = loss_fct(pooled_logits, labels)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/modules/loss.py", line 725, in forward
    return F.binary_cross_entropy_with_logits(input, target,
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/devrim/miniconda3/envs/nonwestlit/lib/python3.11/site-packages/torch/nn/functional.py", line 3193, in binary_cross_entropy_with_logits
    raise ValueError(f"Target size ({target.size()}) must be the same as input size ({input.size()})")
ValueError: Target size (torch.Size([4])) must be the same as input size (torch.Size([4, 3]))
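A possible reading of the traceback, not confirmed in this issue: the loss being called is `BCEWithLogitsLoss`, which the `transformers` sequence-classification heads use for the `multi_label_classification` problem type. When `config.problem_type` is unset, the head infers it from `num_labels` and the label dtype, and non-integer labels with more than one class fall through to the multi-label branch. The pure-Python mirror below sketches that selection logic; `infer_problem_type` is a hypothetical helper, not a library function.

```python
# Loss class the classification head would use for each problem type.
LOSS_BY_PROBLEM_TYPE = {
    "regression": "MSELoss",
    "single_label_classification": "CrossEntropyLoss",
    "multi_label_classification": "BCEWithLogitsLoss",
}


def infer_problem_type(num_labels, labels_are_integer):
    """Mirror (as a sketch) of how the problem type is inferred when unset.

    labels_are_integer stands in for the check that the label tensor
    has an integer dtype such as torch.long.
    """
    if num_labels == 1:
        return "regression"
    if num_labels > 1 and labels_are_integer:
        return "single_label_classification"
    return "multi_label_classification"


if __name__ == "__main__":
    # Shapes from the error: logits [4, 3], labels [4], i.e. 3 classes.
    for is_int in (True, False):
        problem_type = infer_problem_type(num_labels=3, labels_are_integer=is_int)
        print(is_int, problem_type, LOSS_BY_PROBLEM_TYPE[problem_type])
```

If that is what happens here, casting the labels to an integer dtype or setting `problem_type="single_label_classification"` in the model config would be the things to try first.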
