
Comments (3)

mohamad-amin commented on July 17, 2024

Hey Michael, could you please explain what the 60,000 refers to in "my kernels can be of size (100, 60000)"? And by kernels, do you mean NTKs? As far as I know, when computing infinite-width NTKs (as is the case in the notebook you shared), the input width of the network affects neither the computational nor the memory cost of computing the kernels. For instance, you can see that kernel_fn doesn't take the network's parameters as input.
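For instance, here is a minimal sketch (the two-layer MLP and the input shapes are made up for illustration) showing that the analytic kernel_fn is called with data only:

```python
import jax.numpy as jnp
from neural_tangents import stax

# Infinite-width NTK of a toy two-layer ReLU MLP; the kernel depends
# only on the data, never on network parameters.
init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512), stax.Relu(), stax.Dense(1)
)

x1 = jnp.ones((100, 32))  # 100 hypothetical inputs of dimension 32
x2 = jnp.ones((200, 32))  # 200 more

ntk = kernel_fn(x1, x2, 'ntk')  # shape (100, 200); no params anywhere
```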

As far as I know, there's no easy way to reduce the memory footprint of computing infinite-width NTKs, except perhaps deriving the exact NTK formulas analytically and computing them directly (as was done in https://github.com/LeoYu/neural-tangent-kernel-UCI for convolutional NTKs), as opposed to this repo, which computes them compositionally for the sake of generality.

However, I would suggest using empirical NTKs instead of infinite-width NTKs. Regarding the work you mentioned in particular: if I understand things correctly, they treat the network as the fixed object and the data as the trainable parameters, rather than the data as fixed and the network's weights as trainable. In that case, I strongly suspect that an empirical NTK computed with the trained weights at the end of training would produce better results than the infinite-width NTK, since the generalization of a finite-width network at the end of (proper) training is often better than that of the corresponding infinite-width network.
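For example, a minimal sketch of computing the empirical NTK at the trained weights with nt.empirical_ntk_fn; apply_fn, trained_params and the data arrays are placeholders for whatever your setup provides:

```python
import neural_tangents as nt

# apply_fn(params, x) is your finite network's forward pass, and
# trained_params are its weights after training (placeholders here).
ntk_fn = nt.empirical_ntk_fn(apply_fn, trace_axes=(-1,))

k_train_train = ntk_fn(x_train, None, trained_params)    # (n_train, n_train)
k_test_train  = ntk_fn(x_test, x_train, trained_params)  # (n_test, n_train)
```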

If you decide to use empirical NTKs, I would further suggest using the pseudo-NTK (pNTK, https://proceedings.mlr.press/v202/mohamadi23a.html), which approximates the empirical NTK almost perfectly at the end of training and is orders of magnitude cheaper in both computational and memory complexity. The paper shows that the pNTK can be used to compute full 50,000 x 50,000 kernels on datasets like CIFAR-10 with a ResNet18 network on a machine reasonably available in academia.
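Roughly, the pNTK replaces the K logits with their sum scaled by 1/sqrt(K) and takes the kernel of the resulting per-example gradients, giving an n x n matrix instead of nK x nK. A hand-rolled sketch of that idea (not a library API; apply_fn, params and num_classes are assumptions):

```python
import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

def pntk(apply_fn, params, x1, x2, num_classes):
    """Pseudo-NTK sketch: kernel of gradients of the scaled, summed logits."""
    def scalar_fn(p, x):
        # Collapse the K logits of one example into a single scalar.
        return jnp.sum(apply_fn(p, x[None])) / jnp.sqrt(num_classes)

    per_example_grad = lambda x: ravel_pytree(jax.grad(scalar_fn)(params, x))[0]
    j1 = jax.vmap(per_example_grad)(x1)  # (n1, n_params)
    j2 = jax.vmap(per_example_grad)(x2)  # (n2, n_params)
    # For large models you would compute this block-wise rather than
    # materializing the full per-example Jacobians.
    return j1 @ j2.T                     # (n1, n2), one entry per example pair
```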

Let me know if it helps!


mohamad-amin commented on July 17, 2024

Hey Michael,

Unfortunately I'm not an expert on autograd, and I don't know many tricks in this regard. I just skimmed the code, and it seems that in loss_acc_fn they use sp.linalg.solve to compute the kernel-regression predictions. I'm not exactly sure how the gradient for this step is computed, but if it involves taking the gradient of iterative LU operations, that could require a lot of memory (see also google/jax#1747).
I'd suggest replacing the np.linalg.solve in that function with a Cholesky-based solve (see https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_solve.html and https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_factor.html) for possible improvements in both memory and speed.
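Something along these lines, where k_train_train and y_train are placeholders for whatever is currently passed to the solve:

```python
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# The small jitter keeps the kernel matrix positive definite, which the
# Cholesky factorization assumes.
chol = cho_factor(k_train_train + 1e-6 * jnp.eye(k_train_train.shape[0]))
preds = cho_solve(chol, y_train)
```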

And yes, nt.batch is for computing the NTK kernel in batches (see https://neural-tangents.readthedocs.io/en/latest/batching.html).
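For example (the batch_size and store_on_device values here are just illustrative choices):

```python
import neural_tangents as nt

# The batched function is called exactly like the original kernel_fn.
batched_kernel_fn = nt.batch(kernel_fn, batch_size=25, store_on_device=False)
k = batched_kernel_fn(x1, x2, 'ntk')  # same kernel, computed tile by tile
```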


MichaelMMeskhi commented on July 17, 2024

Hi @mohamad-amin, thank you for your feedback. I will definitely look into that, but at the moment I have to finalize the project as is.

Looking into the code more closely, I understand that the limitation isn't in computing k(x, x) but rather in doing backprop. If I understand correctly, nt.batch is mainly for the kernel computation (the forward pass). Is there anything to break up the gradient calculation within NTK? If not, I assume that is something to be done via JAX.
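For instance, one generic JAX option might be jax.checkpoint (rematerialization), which trades extra recomputation for less backprop memory; the sketch below is hypothetical (kernel_fn, the data arrays, and the ridge term are stand-ins) and not specific to neural-tangents:

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_factor, cho_solve

# jax.checkpoint recomputes intermediates in the backward pass instead of
# storing them; kernel_fn and the arrays are placeholders for the real code.
@jax.checkpoint
def kernel_regression_loss(support_x, support_y, target_x, target_y):
    k_ss = kernel_fn(support_x, support_x, 'ntk')
    k_ts = kernel_fn(target_x, support_x, 'ntk')
    chol = cho_factor(k_ss + 1e-6 * jnp.eye(k_ss.shape[0]))
    preds = k_ts @ cho_solve(chol, support_y)
    return jnp.mean((preds - target_y) ** 2)

# Gradients w.r.t. the support data, matching the data-as-parameters setup.
grads = jax.grad(kernel_regression_loss, argnums=(0, 1))(
    support_x, support_y, target_x, target_y)
```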

