Comments (3)
Hey Michael, could you please explain what the 60,000 is in "my kernels can be of size (100, 60000)"? And by kernels, do you mean NTKs? As far as I know, when computing infinite-width NTKs (as is the case in the notebook you shared), the width of the network affects neither the computational nor the memory cost of computing the kernels. For instance, you can see that `kernel_fn` doesn't take the network's parameters as input.
As far as I know, there's no easy way to reduce the memory footprint of computing infinite-width NTKs, except perhaps deriving the exact NTK formulas analytically and computing them directly (as done in https://github.com/LeoYu/neural-tangent-kernel-UCI for convolutional NTKs), as opposed to this repo, which computes them compositionally (for the sake of generality).
However, I would suggest using empirical NTKs instead of infinite-width NTKs. Regarding the work you referenced in particular: if I understand correctly, they treat the network as the fixed object and the data as the trainable parameters, rather than the data as fixed and the network's weights as trainable. In that case, I strongly suspect that an empirical NTK computed with the trained weights at the end of training would produce better results than the infinite-width NTK, since the generalization of a finite-width network at the end of (proper) training is often better than that of the corresponding infinite-width network.
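For concreteness, here is a minimal NumPy sketch of what "empirical NTK" means here: the kernel entry for two inputs is the inner product of the parameter gradients of the network output at those inputs, evaluated at a fixed set of weights. The tiny one-hidden-layer net and its hand-written gradients are purely illustrative, not code from either repo.

```python
import numpy as np

def param_grad(x, W, v):
    """Gradient of the scalar output f(x) = v @ relu(W @ x)
    with respect to all parameters, flattened into one vector."""
    h = W @ x
    mask = (h > 0).astype(float)
    dW = np.outer(v * mask, x)      # df/dW, shape (width, dim)
    dv = np.maximum(h, 0.0)         # df/dv, shape (width,)
    return np.concatenate([dW.ravel(), dv])

def empirical_ntk(X, W, v):
    """n x n empirical NTK: Gram matrix of parameter gradients."""
    J = np.stack([param_grad(x, W, v) for x in X])  # (n, n_params)
    return J @ J.T

rng = np.random.default_rng(0)
dim, width, n = 5, 16, 8
W = rng.standard_normal((width, dim)) / np.sqrt(dim)
v = rng.standard_normal(width) / np.sqrt(width)
X = rng.standard_normal((n, dim))

K = empirical_ntk(X, W, v)
assert K.shape == (n, n)
assert np.allclose(K, K.T)                      # symmetric
assert np.all(np.linalg.eigvalsh(K) >= -1e-6)   # PSD, as any Gram matrix
```

With trained weights, `W` and `v` would come from the end of the training run rather than from random initialization.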
If you decide to use empirical NTKs, I would again suggest using the pseudo-NTK (https://proceedings.mlr.press/v202/mohamadi23a.html), which approximates the empirical NTK almost perfectly at the end of training and is orders of magnitude cheaper in both compute and memory. The paper shows that you can use the pNTK to compute full 50,000 x 50,000 kernels on datasets like CIFAR-10 with a ResNet-18 network on a reasonable machine available in academia.
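As I understand the pNTK construction (worth double-checking against the paper's own code), the per-logit jacobians are replaced by the gradient of the sum of the logits scaled by 1/sqrt(K), which collapses the (n·K) x (n·K) empirical NTK to a single n x n kernel. A toy NumPy sketch with a linear multi-output model, where every jacobian can be written down exactly:

```python
import numpy as np

rng = np.random.default_rng(1)
n, dim, n_out = 6, 4, 10
n_params = dim * n_out              # linear model f(x) = W @ x

X = rng.standard_normal((n, dim))

def logit_jacobians(x):
    """Per-logit parameter jacobians for f(x) = W @ x:
    d f_k / d W = e_k outer x, flattened -> shape (n_out, n_params)."""
    J = np.zeros((n_out, n_params))
    for k in range(n_out):
        J[k, k * dim:(k + 1) * dim] = x
    return J

# Full empirical NTK: an (n * n_out) x (n * n_out) block matrix.
J_full = np.concatenate([logit_jacobians(x) for x in X])
K_full = J_full @ J_full.T

# pseudo-NTK: gradient of (1/sqrt(n_out)) * sum of logits -> n x n kernel.
g = np.stack([logit_jacobians(x).sum(axis=0) / np.sqrt(n_out) for x in X])
K_pntk = g @ g.T

assert K_full.shape == (n * n_out, n * n_out)
assert K_pntk.shape == (n, n)
# For this linear model the pNTK recovers the diagonal-block values exactly:
assert np.allclose(K_pntk, X @ X.T)
```

For a real network the match is approximate rather than exact, which is the content of the paper's approximation guarantees.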
Let me know if it helps!
from neural-tangents.
Hey Michael,
Unfortunately I'm not an expert on autograd, and I don't know many tricks in this regard. I just skimmed the code, and it seems like in the `loss_acc_fn` they use `sp.linalg.solve` to compute the kernel regression predictions. I'm not exactly sure how the gradient for this step is computed, but if it's taking gradients through iterative LU operations, that could require a lot of memory (see also google/jax#1747).
I'd suggest replacing the `np.linalg.solve` in that function with the Cholesky-solve alternative (see https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_solve.html and https://jax.readthedocs.io/en/latest/_autosummary/jax.scipy.linalg.cho_factor.html) for possible improvements in both memory and speed.
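A minimal sketch of the suggested swap, written with the SciPy equivalents (the `jax.scipy.linalg` versions linked above mirror these signatures); it assumes the kernel matrix is symmetric positive definite, possibly after adding a small jitter:

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(0)
n = 50
A = rng.standard_normal((n, n))
K = A @ A.T + 1e-6 * np.eye(n)   # SPD kernel matrix (jitter for stability)
y = rng.standard_normal(n)

x_lu = np.linalg.solve(K, y)         # general LU-based solve
c, low = cho_factor(K)               # one Cholesky factorization...
x_chol = cho_solve((c, low), y)      # ...reused for cheap triangular solves

assert np.allclose(x_lu, x_chol, rtol=1e-4, atol=1e-8)
```

The Cholesky factor can also be reused across multiple right-hand sides, and its backward pass avoids differentiating through a general LU solve; I haven't profiled this on the repo in question, though.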
And yes, `nt.batch` is for computing the NTK kernel in batches (see https://neural-tangents.readthedocs.io/en/latest/batching.html).
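Conceptually, that batching just tiles the kernel matrix so only one block is resident in memory at a time. A NumPy sketch of the idea, with a plain linear kernel standing in for the (much more expensive) NTK computation:

```python
import numpy as np

def kernel_fn(x1, x2):
    """Stand-in for an expensive kernel function (here: linear kernel)."""
    return x1 @ x2.T

def batched_kernel(kernel_fn, X1, X2, batch_size):
    """Fill the full kernel matrix tile by tile, so only one
    (batch_size x batch_size) block is materialized at a time."""
    n1, n2 = len(X1), len(X2)
    K = np.empty((n1, n2))
    for i in range(0, n1, batch_size):
        for j in range(0, n2, batch_size):
            K[i:i + batch_size, j:j + batch_size] = kernel_fn(
                X1[i:i + batch_size], X2[j:j + batch_size])
    return K

rng = np.random.default_rng(0)
X = rng.standard_normal((23, 7))  # deliberately not a multiple of batch_size
assert np.allclose(batched_kernel(kernel_fn, X, X, 8), kernel_fn(X, X))
```

`nt.batch` additionally handles device parallelism and the kernel's internal structure, so this is only the core idea.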
Hi @mohamad-amin, thank you for your feedback. I will definitely look into that, but at the moment I have to finalize the project as is.
So, looking into the code more closely, I understand that the limitation isn't in computing `k(x, x)` but rather in doing backprop. If I understand correctly, `nt.batch` is mainly for the kernel computation (the forward pass). Is there anything to break up the gradient calculation within NTK? If not, I assume that is something to be done via JAX.
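For the backward pass specifically, JAX's generic tool is gradient checkpointing (`jax.checkpoint`, a.k.a. `jax.remat`), which recomputes intermediates during backprop instead of storing them. A toy sketch, not specific to neural-tangents:

```python
import jax
import jax.numpy as jnp

def layer(x, w):
    return jnp.tanh(w @ x)

def loss_checkpointed(ws, x):
    for w in ws:
        # jax.checkpoint tells autodiff to recompute this layer's
        # intermediates on the backward pass instead of storing them,
        # trading extra compute for a smaller memory footprint.
        x = jax.checkpoint(layer)(x, w)
    return jnp.sum(x ** 2)

def loss_plain(ws, x):
    for w in ws:
        x = layer(x, w)
    return jnp.sum(x ** 2)

key = jax.random.PRNGKey(0)
ws = [jax.random.normal(jax.random.fold_in(key, i), (8, 8)) for i in range(4)]
x = jnp.ones(8)

g_ckpt = jax.grad(loss_checkpointed)(ws, x)
g_plain = jax.grad(loss_plain)(ws, x)

# Checkpointing changes memory use, not the gradient values themselves.
assert all(jnp.allclose(a, b, atol=1e-6) for a, b in zip(g_ckpt, g_plain))
```

Whether this helps in your setting depends on where the memory actually goes; it mainly pays off when the forward graph's saved intermediates dominate.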