
Comments (4)

yihongL1U avatar yihongL1U commented on August 15, 2024 1

New update:

With this fix, the training time is reduced! The new estimated training times for the different model sizes are as follows:

  • ofa-768: 4.73 hours / 10K updates
  • ofa-400: 4.28 hours / 10K updates
  • ofa-200: 3.94 hours / 10K updates
  • ofa-100: 3.78 hours / 10K updates

from ofa.

fdschmidt93 avatar fdschmidt93 commented on August 15, 2024 1

Kudos, that was fast @yihongL1U :) Nice runtime improvements!

However, I would like to point out that the forward pass (basically matrix multiplications) should be quite efficient in PyTorch, and, as far as I know, the major time consumed in training is the backward pass (gradient computation + updates).

I held off on replying because, while I agree that this is true, I was inclined to say "let's wait for the numbers" :) 62BN too many inner products is still a lot.

As I suspected, the difference in hours / 10K updates compared to the original table is quite staggering (roughly −9 hours for ofa-{400,768}), and it significantly diminishes the differences between the variants {100, ..., 768}.

That said, the runtimes as originally reported in Table 3 of course remain much more realistic for causal multilingual language models.

Good luck with your submission :)


yihongL1U avatar yihongL1U commented on August 15, 2024 1

Indeed, I didn't expect the difference to be this large. We appreciate your pointing out this problem, which we hadn't noticed! Credit to you, Fabian!

Thanks a lot!


yihongL1U avatar yihongL1U commented on August 15, 2024

Hi Fabian,

Thank you very much for your interest and your comment!

You are right that the original Hugging Face implementation of the masked language modeling pipeline is not as efficient as I thought, since it computes prediction scores for many tokens that never contribute to the loss! We didn't notice it (Hugging Face should consider this :( ). Thanks for pointing this out. Here is your suggested fix (it should be put into the forward pass of, e.g., XLMRobertaAssembledForMaskedLM):

original:

sequence_output = outputs[0]
prediction_scores = self.lm_head(sequence_output)

fix:

sequence_output = outputs[0]

# keep only the positions that contribute to the loss (labels != -100)
valid_tokens = labels != -100
# broadcast the mask over the hidden dimension so it matches sequence_output
valid_tokens = valid_tokens.unsqueeze(-1).expand_as(sequence_output)
# masked_select flattens, so reshape back to (num_valid_tokens, hidden_size)
filtered_output = torch.masked_select(sequence_output, valid_tokens).view(-1, sequence_output.size(-1))
filtered_labels = torch.masked_select(labels, labels != -100)
sequence_output = filtered_output
labels = filtered_labels

prediction_scores = self.lm_head(sequence_output)
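The effect of the filter can be sketched in plain Python (a minimal, torch-free sketch; the toy values and the helper `filter_positions` are hypothetical, while -100 is the standard ignore index for MLM labels in Hugging Face data collators):

```python
# Illustrative sketch of the token-filtering fix, without torch.
# -100 marks positions that do not contribute to the MLM loss.
IGNORE_INDEX = -100

def filter_positions(sequence_output, labels):
    """Keep only hidden states whose label is not IGNORE_INDEX."""
    kept_hidden = [h for h, y in zip(sequence_output, labels) if y != IGNORE_INDEX]
    kept_labels = [y for y in labels if y != IGNORE_INDEX]
    return kept_hidden, kept_labels

# Toy sequence: 8 positions, only 2 of them masked (labels >= 0).
hidden = [[0.1] * 4 for _ in range(8)]   # 8 positions, hidden size 4
labels = [-100, -100, 7, -100, -100, 42, -100, -100]

kept_hidden, kept_labels = filter_positions(hidden, labels)
print(len(kept_hidden), kept_labels)     # only 2 positions reach the LM head
```

With typical 15% masking, roughly six out of seven positions never need to pass through the LM head at all.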

However, I would like to point out that the forward pass (basically matrix multiplications) should be quite efficient in PyTorch, and, as far as I know, the major time consumed in training is the backward pass (gradient computation + updates). Therefore, lower-dimensional embeddings will naturally be more efficient. Nevertheless, I will re-estimate the training time based on the updated code. Thank you very much for your input!
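To make the forward-pass cost concrete, here is a rough back-of-the-envelope sketch of the LM-head FLOPs per sequence; the dimensions (XLM-R-style vocabulary of 250,002, hidden size 768, sequence length 512, 15% masking) are assumptions for illustration only:

```python
# Back-of-the-envelope FLOP count for the MLM head projection.
# All sizes below are assumed, XLM-R-like values for illustration.
hidden_size = 768
vocab_size = 250_002   # XLM-R-style vocabulary (assumption)
seq_len = 512
mask_rate = 0.15       # standard MLM masking probability

# One position through the head: a (hidden_size x vocab_size) matmul row.
flops_per_position = 2 * hidden_size * vocab_size

full_head = seq_len * flops_per_position                  # all positions
filtered_head = int(seq_len * mask_rate) * flops_per_position  # masked only

print(f"full head:     {full_head:.3e} FLOPs per sequence")
print(f"filtered head: {filtered_head:.3e} FLOPs per sequence")
print(f"savings factor: {full_head / filtered_head:.1f}x")
```

Because the vocabulary projection dominates the head, skipping the ~85% of positions with label -100 shrinks this part of the forward (and the corresponding backward) pass by roughly the same factor.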

