Comments (4)
Hi @DarrenZhang01, I'm happy to take this one. The idea here is that during computation of the kernel, we would like to know the shape of the intermediate (pre-)activations in the finite version of the computation. To do this, we use JAX's abstract evaluation machinery to infer the shapes using the `init_fn` without actually instantiating any parameters. Here `akey` is an abstract version of `key` that retains only shape and dtype information. If you want to know more, you might want to check out one of the JAX talks (there is a great one by Skye that's recorded somewhere) where they explain how tracing works.
from neural-tangents.
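Here is a minimal sketch of the idea. It assumes a toy MLP: `init_params` and `apply_fn` below are hypothetical stand-ins for a layer's `init_fn`/`apply_fn`, not the neural-tangents code itself. It also uses the public `jax.eval_shape` rather than the library's internal `_propagate_shape`. The abstract key `akey` carries only shape and dtype, so tracing yields the parameter and pre-activation shapes without generating random numbers or allocating any weights.

```python
import jax
import jax.numpy as jnp

def init_params(key, in_dim=3, widths=(512, 512, 1)):
    """Hypothetical MLP init: returns a list of (W, b) pairs."""
    params = []
    for out_dim in widths:
        key, w_key, b_key = jax.random.split(key, 3)
        W = jax.random.normal(w_key, (in_dim, out_dim))
        b = jax.random.normal(b_key, (out_dim,))
        params.append((W, b))
        in_dim = out_dim
    return params

def apply_fn(params, x):
    """Hypothetical forward pass returning every layer's pre-activation."""
    preacts = []
    for W, b in params:
        z = x @ W + b           # pre-activation of this layer
        preacts.append(z)
        x = jax.nn.relu(z)      # post-activation feeds the next layer
    return preacts

key = jax.random.PRNGKey(0)
# Abstract key: keeps only the shape and dtype of `key`, no values.
akey = jax.ShapeDtypeStruct(key.shape, key.dtype)

# Trace the init with the abstract key: the results are ShapeDtypeStructs,
# so no random numbers are drawn and no weight arrays are allocated.
abstract_params = jax.eval_shape(init_params, akey)

# Likewise for the forward pass, using an abstract batch of 8 inputs.
ax = jax.ShapeDtypeStruct((8, 3), jnp.float32)
preact_shapes = jax.eval_shape(apply_fn, abstract_params, ax)

print([W.shape for W, b in abstract_params])  # [(3, 512), (512, 512), (512, 1)]
print([z.shape for z in preact_shapes])       # [(8, 512), (8, 512), (8, 1)]
```

Because `eval_shape` only propagates shapes and dtypes, this works even for widths far too large to actually instantiate.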
Thanks very much Sam @sschoenholz! I see that `akey` is an abstract-level (`ShapedArray`) representation of `key`. If I understand correctly, when `akey` serves as the input for `abstract_eval_fun` in `_propagate_shape`, it is only used as a `PartialVal` object in generating the `Jaxpr`?
I think that's basically correct. As a technical point, I believe JAX doesn't instantiate the jaxpr explicitly, but evaluates the shape while tracing the jaxpr.
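To illustrate what "evaluating the shape while tracing" means in practice, here is a small sketch. The functions are hypothetical, and it uses the public `jax.eval_shape`. During abstract evaluation, ordinary Python can read the key's shape and dtype, but anything that needs the key's actual values, such as branching on them, fails because no values exist.

```python
import jax
import jax.numpy as jnp

# Abstract key: shape and dtype only (threefry keys are uint32 with shape (2,)).
akey = jax.ShapeDtypeStruct((2,), jnp.uint32)

seen = {}

def shape_only(key):
    # Shape and dtype are static during tracing, so plain Python can use them.
    seen["shape"] = key.shape
    seen["dtype"] = key.dtype
    return jnp.zeros(key.shape[0] * 3)

def needs_values(key):
    # Branching on the *value* of the key requires concrete data,
    # which an abstract (shape/dtype-only) key does not have.
    if key[0] > 0:
        return jnp.ones(3)
    return jnp.zeros(3)

out = jax.eval_shape(shape_only, akey)
print(out.shape)  # (6,)

failed = False
try:
    jax.eval_shape(needs_values, akey)
except jax.errors.ConcretizationTypeError:
    failed = True
print("value-dependent code failed under abstract evaluation:", failed)  # True
```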
I see. Thanks, Sam!