
vmc_jax's People

Contributors

emergentspacetime, jonasrigo, lagrange2art, laurinbrunner, markusschmitt, rehmoritz, tszoldra


vmc_jax's Issues

Examples 5 and 6 raise "KeyError"

Since the latest change in jVMC/util/util.py, examples 5 and 6 raise "KeyError: batch_size". Likewise, the test povm_t.py no longer runs without errors.
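A minimal sketch of the failing pattern and a possible guard, assuming the exception comes from a plain dictionary lookup of "batch_size" in the parameter dict handled by util.py (the exact location and dict layout are assumptions):

# hypothetical parameter dict as it might be passed by the examples
params = {"gradient_batch_size": 1000}

# failing pattern: a direct lookup raises KeyError if the key is missing
# batch_size = params["batch_size"]

# defensive pattern: fall back to a default when "batch_size" is absent
batch_size = params.get("batch_size", 1000)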

adding operator_string to branchfree_operator can lead to indexing errors

Adding operator strings (here just Sz(i)) to the Hamiltonian without a prefactor (i.e. not using jvmc.operator.scal_opstr) can lead to an indexing error once a certain number of them have been added (here L >= 50 reproduces the error). Replacing the line h.add((jvmc.operator.Sz(i),)) in the example below by h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sz(i),))) solves the issue for me.

It seems that if L is large enough, compilation is distributed via MPI (see branch_free.py, line 224 and following), which then leads to the error: the prefactors (if those are the same as the factors provided to scal_opstr) are iterated over, and the plain operator string does not carry one(?). A possible solution would be to point out in the documentation that scal_opstr should always be used (even with a prefactor of 1), or to apply scal_opstr with a factor of 1 automatically whenever it is not used explicitly (see the sketch after the traceback below).

A minimal example to reproduce the error:

import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import jVMC as jvmc

# define transverse field Ising Hamiltonian
L = 50  # works with L <= 49

h = jvmc.operator.BranchFreeOperator()

for i in range(L):
    h.add((jvmc.operator.Sz(i),))  # operator string without an explicit prefactor
    h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sx(i),)))  # with an explicit prefactor of 1

s = jnp.ones(shape=(1, 12, L), dtype=jnp.int32)
primes = h.get_s_primes(s)  # raises IndexError for L >= 50

The error message:

IndexError                                Traceback (most recent call last)
Cell In[1], line 16
     13     h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sx(i),)))
     15 s = jnp.ones(shape=(1, 12, L), dtype=jnp.int32)
---> 16 primes = h.get_s_primes(s)

File ~/Programming/jvmc_playground/jvmc_cpu/lib/python3.12/site-packages/jVMC/operator/base.py:141, in Operator.get_s_primes(self, s, *args)
    139 if type(fun) is tuple:
    140     self.arg_fun = fun[1]
--> 141     args = self.arg_fun(*args)
    142     fun = fun[0]
    143 else:

File ~/Programming/jvmc_playground/jvmc_cpu/lib/python3.12/site-packages/jVMC/operator/branch_free.py:240, in BranchFreeOperator.compile.<locals>.arg_fun(prefactor, init, *args)
    238     res = init[myStart:myEnd]
    239     for i,f in prefactor[myStart:myEnd]:
--> 240         res[i-myStart] = f(*args)
    242     res = np.concatenate(comm.allgather(res), axis=0)
    244 return (jnp.array(res), )

IndexError: index 51 is out of bounds for axis 0 with size 50
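Until this is resolved, a small user-side helper can enforce the convention of always attaching an explicit prefactor when building the Hamiltonian from the example above. This is a minimal sketch; add_with_prefactor is a hypothetical helper name, not part of the jVMC API:

def add_with_prefactor(op, op_string, prefactor=1.0):
    # always wrap the operator string with an explicit scalar prefactor
    op.add(jvmc.operator.scal_opstr(prefactor, op_string))

for i in range(L):
    add_with_prefactor(h, (jvmc.operator.Sz(i),))
    add_with_prefactor(h, (jvmc.operator.Sx(i),))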

Frequent recompilation in v1.2.0

We introduced overly frequent recompilation in v1.2.0 by moving jit_my_stuff() to the global_defs module. This severely impedes performance.

Diagonalization on device can raise unhandled exception

The diagonalization on the GPU device in TDVP.transform_to_eigenbasis, done with jax.numpy.linalg.eigh, can sometimes raise a ValueError. In this case the calculation is aborted and needs to be restarted with the diagonalizeOnDevice parameter set to False, which makes the function fall back to the numpy CPU version of eigh.

I suggest changing this behaviour so that, if diagonalizeOnDevice is True, the function first tries jax.numpy.linalg.eigh and automatically falls back to the numpy version if that fails, without the need to restart the entire calculation (a sketch follows the error message below).

Error message:
ValueError: INTERNAL: CustomCall failed: jaxlib/cusolver_kernels.cc:444: operation cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast<float*>(work), d.lwork, info) failed: cuSolver execution failed
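A minimal sketch of the proposed fallback, assuming the relevant step in TDVP.transform_to_eigenbasis essentially amounts to an eigh call on the S-matrix (variable and function names here are illustrative):

import numpy as np
import jax.numpy as jnp

def safe_eigh(S, diagonalizeOnDevice=True):
    # try the on-device diagonalization first and fall back to the CPU routine
    if diagonalizeOnDevice:
        try:
            return jnp.linalg.eigh(S)
        except ValueError:
            # cuSolver occasionally fails; retry with numpy on the CPU instead of aborting
            pass
    ev, V = np.linalg.eigh(np.asarray(S))
    return jnp.array(ev), jnp.array(V)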

batched Oloc always allocates complex zeros

The get_O_loc_batched function of the Operator class calls the _alloc_Oloc_pmapd function, which always allocates zeros with dtype global_defs.tCpx. In the POVM case, however, the POVMOperator class returns real matrix elements when get_s_primes is called.
Inserting real-valued numbers into a complex jax array with jax.lax.dynamic_update_slice then results in the following error:

TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got complex128, float64.
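A minimal sketch of a possible fix, allocating the buffer with the dtype of the matrix elements instead of the fixed complex dtype (the function and argument names are illustrative, not the actual jVMC internals):

import jax.numpy as jnp

def alloc_Oloc_like(matEl, shape):
    # match the dtype returned by get_s_primes, so that real POVM matrix elements
    # can be written with jax.lax.dynamic_update_slice without a dtype clash
    return jnp.zeros(shape, dtype=matEl.dtype)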

Please improve error for omission of op.BranchFreeOperator(lDim=<local_hilbert_dimension>)

Dear Devs,
I have noticed that when one creates a Hamiltonian based on a local Hilbert space with a dimension other than 2, one has to set the lDim variable in op.BranchFreeOperator so that the operator strings are padded with identities.
If one forgets to do that, or does it incorrectly like so,

import jVMC.operator as op
hamiltonian = op.BranchFreeOperator(lDim=1)
hamiltonian.add(op.scal_opstr( 1., ( op.Sz(0), op.Sz(0) ) ) )
hamiltonian.add(op.scal_opstr( 1., ( op.Sx(0),) ) )
hamiltonian.compile()

the error message is somewhat cryptic. I therefore propose changing line 220 in branch_free.py as follows:

try:
    self.mapC = jnp.array(self.map, dtype=np.int32)
except Exception as e:
    raise ValueError("Check that you have set <local_hilbert_dimension> in op.BranchFreeOperator(lDim=<local_hilbert_dimension>) correctly.") from e
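For reference, the spin-1/2 operators used in this example correspond to a local Hilbert space of dimension 2, so the constructor call would read (assuming lDim denotes exactly this local dimension, as described above):

hamiltonian = op.BranchFreeOperator(lDim=2)  # lDim must equal the local Hilbert space dimension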

Best, Jonas

jVMC.util.util.init_net raises AttributeError

The function jVMC.util.util.init_net raises the error:

AttributeError: module 'jVMC.nets' has no attribute 'CpxRNN'

The class CpxRNN has been removed but is still referenced in this function.
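The failure can be reproduced directly, assuming the network type is looked up as an attribute of the jVMC.nets module (which the error message suggests):

import jVMC.nets

getattr(jVMC.nets, "CpxRNN")  # raises AttributeError, since the class no longer exists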

CrossValidation cannot select subset of Eloc

The TDVP class tries to access Eloc[:, 0::2] in its __call__ method when crossValidation is True, but the SampledObs object is not subscriptable.
In addition, the call to solve still uses the old parameters.

Complex Weights Not Properly Saved in Network Checkpoints

I have come across an issue where complex weights in network checkpoints are not being saved properly. Currently, only the real parts of the complex numbers are saved and the imaginary parts are discarded. This behavior was encountered when using the write_network_checkpoint method of the OutputManager class.
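As a workaround until this is fixed, the complex parameters can be split into real and imaginary parts before writing and recombined after loading. This is a minimal sketch using standard JAX tree utilities, not the OutputManager API:

import jax.numpy as jnp
from jax import tree_util

def split_complex(params):
    # store real and imaginary parts as two separate real-valued parameter trees
    re = tree_util.tree_map(jnp.real, params)
    im = tree_util.tree_map(jnp.imag, params)
    return re, im

def join_complex(re, im):
    # reconstruct the complex parameters from the two stored trees
    return tree_util.tree_map(lambda r, i: r + 1j * i, re, im)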

Non-holomorphic networks with complex parameters not treated correctly

If we define the class

import jax.numpy as jnp
import flax.linen as nn
import jVMC.global_defs as global_defs
from jVMC.nets.initializers import init_fn_args  # helper used throughout the jVMC net definitions

class MatrixMultiplication_NonHolomorphic(nn.Module):
    holo: bool = False

    @nn.compact
    def __call__(self, s):
        layer1 = nn.Dense(1, use_bias=False, **init_fn_args(dtype=global_defs.tCpx))
        out = layer1(2 * s.ravel() - 1)
        if not self.holo:
            out = out + 1e-1 * jnp.real(out)  # this term makes the network non-holomorphic
        return jnp.sum(out)

and let holo=True, the gradients computed by psi.gradients for the input s=[0, 0, 0, 0] are
[-1.+0.j -1.+0.j -1.+0.j -1.+0.j -0.-1.j -0.-1.j -0.-1.j -0.-1.j] as expected.

However, if we let holo=False, the returned gradients are [-2.1+0.j -2.1+0.j -2.1+0.j -2.1+0.j]. This means that if we, for example, did time evolution with this setup and considered only the imaginary part of the S-matrix, we would get all zeroes.

If we add the gradient function

from jax import grad
from jax.tree_util import tree_flatten, tree_map

def flat_gradient_cpx_nonholo(fun, params, arg):
    # differentiate the real and imaginary parts of the network output separately
    gr = grad(lambda p, y: jnp.real(fun.apply(p, y)))(params, arg)["params"]
    gi = grad(lambda p, y: jnp.imag(fun.apply(p, y)))(params, arg)["params"]
    g = tree_flatten(tree_map(lambda x, y: [x.ravel(), -y.ravel()], gr, gi))[0]
    return jnp.concatenate(g)

the returned gradients are [-1.1+0.j -1.1+0.j -1.1+0.j -1.1+0.j -0. -1.j -0. -1.j -0. -1.j -0. -1.j], which is in line with the above case where holo=True.
