Comments (7)

zhangbububu commented on August 17, 2024

Hi, how can I enable float64 precision?

from neural-tangents.

romanngg commented on August 17, 2024

Sorry for the late reply!

@zhangbububu see https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html#double-64bit-precision
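For readers who do not follow the link, a minimal sketch of what that page describes: JAX defaults to 32-bit floats, and 64-bit mode is switched on through the `jax_enable_x64` config flag, which should be set at startup before any arrays are created. The array `x` below is only an illustrative placeholder.

```python
import jax

# Enable 64-bit mode before creating any arrays.
# (Setting the environment variable JAX_ENABLE_X64=True is an alternative.)
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Illustrative placeholder array; with x64 enabled the default float dtype is float64.
x = jnp.linspace(-1.0, 1.0, 5)
print(x.dtype)  # float64
```

Without this flag, requesting `jnp.float64` (for example via `.astype(jnp.float64)`) is silently downcast to float32 with a warning.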

@tengandreaxu could you try using Relu(do_stabilize=True)? https://neural-tangents.readthedocs.io/en/latest/_autosummary/neural_tangents.stax.Relu.html This parameter switches the computation of the nonlinearity kernel to a form that helps prevent numerical overflow.

tengandreaxu commented on August 17, 2024

Thank you so much, Roman. It's no problem at all!

import numpy as np
from neural_tangents import stax
from jax import jit

# W_std grows with depth: 1, 2, ..., 15 for the hidden layers.
W_stds = list(range(1, 17))
# W_stds.reverse()
layer_fn = []
for i in range(len(W_stds) - 1):
    layer_fn.append(stax.Dense(1, W_std=W_stds[i]))
    layer_fn.append(stax.Relu(do_stabilize=True))

# Readout layer with W_std=1.0 and b_std=0.0.
layer_fn.append(stax.Dense(1, W_std=1.0, b_std=0.0))
_, _, kernel_fn = stax.serial(*layer_fn)

# `get` is marked static because it selects which kernel is computed.
kernel_fn = jit(kernel_fn, static_argnames="get")

x = np.random.rand(100, 100)

print(kernel_fn(x, x, "ntk"))

results in

[[2.61008562e+20 1.12163820e+20 1.23732785e+20 ... 1.08229372e+20
  1.05533967e+20 1.10687273e+20]
 [1.12163820e+20 2.92078984e+20 1.31143308e+20 ... 1.16449180e+20
  1.15616286e+20 1.19062657e+20]
 [1.23732785e+20 1.31143308e+20 3.36093753e+20 ... 1.28641726e+20
  1.19473708e+20 1.28997387e+20]
 ...
 [1.08229363e+20 1.16449180e+20 1.28641726e+20 ... 2.74442324e+20
  1.07858132e+20 1.20695995e+20]
 [1.05533967e+20 1.15616286e+20 1.19473708e+20 ... 1.07858132e+20
  2.69344883e+20 1.11830439e+20]
 [1.10687273e+20 1.19062657e+20 1.28997387e+20 ... 1.20695995e+20
  1.11830439e+20 2.83645061e+20]]

Do you think that there is no sense in having weights drawn from a higher standard deviation as we go deeper into the neural net in the infinite width?

zhangbububu commented on August 17, 2024

@romanngg @tengandreaxu

Hi, I ran into a confusing problem:

import jax.numpy as jnp
import neural_tangents as nt
from neural_tangents import stax

init_fn, apply_fn, kernel_fn = stax.serial(
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(512, W_std=1.5, b_std=0.05), stax.Relu(do_stabilize=True),
    stax.Dense(1, W_std=1.5, b_std=0.05)
)

s = 10
l = jnp.pi * -s  # left edge of the input interval
r = jnp.pi * s   # right edge of the input interval
N_tr = 100       # number of training points
N_te = 5         # number of test points
train_xs = jnp.linspace(l, r, N_tr).reshape(-1, 1).astype(jnp.float64)
# Cast the whole sum, not just the second term, to float64.
train_ys = (jnp.sin(train_xs) + jnp.sin(2 * train_xs)).astype(jnp.float64)
test_xs = jnp.linspace(l, r, N_te).reshape(-1, 1).astype(jnp.float64)

predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_xs,
                                                      train_ys, diag_reg=1e-4)
ntk_mean, ntk_covariance = predict_fn(x_test=test_xs, get='ntk',
                                      compute_cov=True)
print(ntk_mean)

If I increase the number of training samples (N_tr), I get an all-NaN ntk_mean.

romanngg commented on August 17, 2024

@tengandreaxu

"Do you think that there is no sense in having weights drawn from a higher standard deviation as we go deeper into the neural net in the infinite width?"

I think so. Ideally, you would want the mean and variance of your network outputs to match the mean and variance of your training labels, as a sensible prior. But even if your training labels have a large variance, it's common practice to just standardize them (together with the test labels) to have mean 0 and variance 1 for best numerical stability.
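As a rough illustration of the standardization step, here is a minimal NumPy sketch. The label arrays are made-up placeholders, and computing the statistics from the training labels only (then applying them to the test labels) is an assumption of this sketch, not something specified in the thread; `unstandardize` is a hypothetical helper name.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical labels with a large mean and variance (placeholder data).
train_ys = 50.0 + 1000.0 * rng.standard_normal((100, 1))
test_ys = 50.0 + 1000.0 * rng.standard_normal((5, 1))

# Statistics taken from the training labels (an assumption of this sketch).
mean = train_ys.mean(axis=0, keepdims=True)
std = train_ys.std(axis=0, keepdims=True)

# Apply the same shift and scale to both training and test labels.
train_ys_std = (train_ys - mean) / std
test_ys_std = (test_ys - mean) / std


def unstandardize(predictions):
    """Map predictions made in standardized units back to the original scale."""
    return predictions * std + mean
```

Predictions made on the standardized labels can be mapped back with `unstandardize` before comparing them to the raw test labels.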

Then, in a Relu network, to get zero-mean, unit-variance outputs (given zero-mean, unit-variance inputs), you would want to set W_std=2**0.5 for all intermediate layers preceding Relus, and W_std=1 for the top layer.
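The reasoning behind these values can be sketched with a simplified variance recursion. The assumptions here are mine, not from the thread: zero biases, infinite width, and a width-normalized Dense layer, so that a Dense layer multiplies the pre-activation variance by W_std**2 and a ReLU halves the second moment (E[relu(z)^2] = q/2 for z ~ N(0, q)). The function name `output_variance` is hypothetical, and the recursion tracks the output (NNGP) variance only, not the NTK.

```python
import math


def output_variance(w_stds, input_variance=1.0):
    """Variance of the output of Dense -> ReLU -> ... -> Dense.

    Simplified sketch: zero biases, infinite width, width-normalized Dense
    layers. Every Dense layer except the last is followed by a ReLU.
    """
    q = input_variance
    for i, w_std in enumerate(w_stds):
        q = w_std ** 2 * q           # Dense layer scales the variance by W_std**2
        if i < len(w_stds) - 1:
            q = q / 2.0              # ReLU halves the second moment
    return q


# Suggested setting: W_std = sqrt(2) before each ReLU, W_std = 1 at the top.
recommended = [2 ** 0.5, 2 ** 0.5, 1.0]

# Setting from the earlier snippet: W_std = 1, 2, ..., 15, then a top layer with 1.0.
growing = [float(s) for s in range(1, 16)] + [1.0]

var_recommended = output_variance(recommended)
var_growing = output_variance(growing)

print(var_recommended)  # stays at about 1
print(var_growing)      # about 5e19
```

Under these assumptions, each hidden layer changes the variance by a factor of W_std**2 / 2, so W_std=2**0.5 keeps it constant, while standard deviations that grow with depth compound into very large kernel values.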

@zhangbububu I replied in your separate thread; let's continue there.

tengandreaxu commented on August 17, 2024

Thank you for your prompt help Roman!
