
Comments (7)

renzhonglu11 avatar renzhonglu11 commented on June 16, 2024 1

+1. I have the same issue when I use the forward() function and loss() function separately in the trainer.


mberr avatar mberr commented on June 16, 2024 1

Hi @ferzcam (and @renzhonglu11 ),

this likely comes from DistMult using a regularizer by default, cf.

regularizer: HintOrType[Regularizer] = LpRegularizer,
regularizer_kwargs: OptionalKwargs = None,
that needs to accumulate the regularization term across micro-batches. PyKEEN's pipelines call collect_regularization_term, which not only collects the regularization terms from the different places where they are stored, but also releases their references. When this is not called, the term accumulates indefinitely and keeps holding references to buffers, so over time you will run out of memory.

To fix it, either

  • configure the model without a regularizer, e.g., for DistMult pass regularizer=None, or
  • if you want to make use of the regularizer, make sure to call collect_regularization_term on the model and include it in your loss term (see the sketch below this list).
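
For the second option, here is a minimal sketch of a manual training step. The dataset, optimizer, batch construction, and the placeholder loss are only assumptions for illustration; the point is where collect_regularization_term fits in:

import torch
from pykeen.datasets import get_dataset
from pykeen.models import DistMult

dataset = get_dataset(dataset="nations")
model = DistMult(triples_factory=dataset.training)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

# a small batch of positive triples, just for illustration
hrt_batch = dataset.training.mapped_triples[:32]

scores = model.score_hrt(hrt_batch)                 # the forward pass updates the regularizer's stored term
loss = -scores.mean()                               # placeholder loss; use your actual loss here
loss = loss + model.collect_regularization_term()   # collect *and* release the stored term
loss.backward()
optimizer.step()
optimizer.zero_grad()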


renzhonglu11 avatar renzhonglu11 commented on June 16, 2024 1

@mberr your solution works. Thanks a lot!😁


mberr avatar mberr commented on June 16, 2024 1

You can inspect the resulting structure with

from pykeen.datasets import get_dataset
from pykeen.models import DistMult

dataset = get_dataset(dataset="nations")
model = DistMult(triples_factory=dataset.training)
print(model)

which prints:

DistMult(
  (loss): MarginRankingLoss(
    (margin_activation): ReLU()
  )
  (interaction): DistMultInteraction()
  (entity_representations): ModuleList(
    (0): Embedding(
      (_embeddings): Embedding(14, 50)
    )
  )
  (relation_representations): ModuleList(
    (0): Embedding(
      (regularizer): LpRegularizer()
      (_embeddings): Embedding(55, 50)
    )
  )
  (weight_regularizers): ModuleList()
)

Notice how the relation_representations has a regularizer. Compare this to

from pykeen.datasets import get_dataset
from pykeen.models import DistMult

dataset = get_dataset(dataset="nations")
model = DistMult(triples_factory=dataset.training, regularizer=None)
print(model)

resulting in

DistMult(
  ...
  (relation_representations): ModuleList(
    (0): Embedding(
      (_embeddings): Embedding(55, 50)
    )
  )
  ...
)

The regularization term of the relation embedding is updated here

# regularize *after* repeating
if self.regularizer is not None:
    self.regularizer.update(x)

i.e., in the Embedding's forward call.
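
So every forward pass through the embedding adds to the regularizer's stored term. A small sketch of how one might observe this (dataset and batch slice are arbitrary assumptions for illustration):

from pykeen.datasets import get_dataset
from pykeen.models import DistMult

dataset = get_dataset(dataset="nations")
model = DistMult(triples_factory=dataset.training)

# several forward passes without collecting the term in between
for _ in range(3):
    model.score_hrt(dataset.training.mapped_triples[:16])

# collects the term accumulated over all three passes and releases its references;
# if this is never called, the references keep piling up
term = model.collect_regularization_term()
print(term)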


renzhonglu11 avatar renzhonglu11 commented on June 16, 2024 1

Nice, now it makes more sense. I also found where PyKEEN calls collect_regularization_term(). Thanks for your explanation.


renzhonglu11 avatar renzhonglu11 commented on June 16, 2024

Thanks a lot. But I am still not quite sure whether that is really the reason. I took a look at PyKEEN's source code; it seems PyKEEN calculates the loss within a single function in the trainer:

batch_loss = self._forward_pass(...)

What I did is first call forward() of a model (like DistMult) to compute the prediction values and then call loss to compute the loss value (similar to what @ferzcam did). Then the memory increases every epoch during training, whether on GPU or CPU.
However, when I put the forward and loss calculation together in one function, just like PyKEEN does, and call that function, the memory does not increase anymore. So I am confused about it. 🧐


In my KGE model:

class PykeenKGE:
    def training_step(self, batch):
        x_batch, y_batch = batch
        yhat_batch = self.forward(x_batch)
        loss_batch = self.loss(yhat_batch, y_batch)
        # add the accumulated regularization term (this also releases its references)
        return loss_batch + self.model.collect_regularization_term()

In Trainer:

 batch_loss = self.model.training_step(batch)


ferzcam avatar ferzcam commented on June 16, 2024

Hi @mberr. Thanks for the explanation. I was able to make it work now!

