markusschmitt / vmc_jax
Implementation of Variational Monte Carlo (VMC) for quantum many-body dynamics using JAX.
License: MIT License
Since the latest change in jVMC/util/util.py, examples 5 and 6 raise "KeyError: batch_size". Likewise, the test povm_t.py no longer runs without errors.
Allow custom definition of branching operators.
Among the examples, ex5_dissipative_Lindblad.py is not working because it relies on jax.ops.index_add() in vmc_jax/jVMC/operator/povm.py l.86, which was used to update elements of a jax array. However, it is deprecated (see https://jax.readthedocs.io/en/latest/jax.ops.html#module-jax.ops) and should be replaced with the .at[] indexed-update syntax, e.g. probs = probs.at[idx].set(value).
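For reference, a minimal sketch of the migration (variable names are illustrative; note that .at[idx].add() is the direct equivalent of index_add, while .at[idx].set() overwrites the element):

import jax.numpy as jnp

probs = jnp.zeros(4)
idx, value = 2, 0.5

# Deprecated and removed in recent JAX releases:
#   probs = jax.ops.index_add(probs, idx, value)

# Replacement: .at[] returns an updated copy, so reassign the result.
probs = probs.at[idx].add(value)  # direct equivalent of index_add
probs = probs.at[idx].set(value)  # overwrite instead of accumulate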
Pass MC sampler seed as integer instead of PRNGKey
JAX 0.2.18 works fine.
0.2.21 introduced changes, e.g. in partial(), that cause crashes.
Adding operator strings (here just Sz(i)) to the Hamiltonian without a prefactor (i.e. not using jvmc.operator.scal_opstr) can lead to an indexing error once a certain number of them are added (here L >= 50 to reproduce the error). Replacing in the following example the line h.add((jvmc.operator.Sz(i),))
by h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sz(i),)))
solves the issue for me.
It seems that if L is large enough, the compilation is distributed via MPI (see branch_free.py line 224 and following), which then leads to the error because the prefactors (if those are the same as the actual factors provided in scal_opstr) are taken into account and the pure operator string does not have one(?). A possible solution would be to point out in the documentation that one should always use the scal_opstr method (even if the prefactor is 1), or to automatically apply scal_opstr with a factor of 1 if it is not done explicitly (see the sketch after the error message below).
A minimal example to reproduce the error:
import jax.numpy as jnp
import jax
jax.config.update("jax_enable_x64", True)
import jVMC as jvmc
# define transverse field Ising Hamiltonian
L = 50 # works with L <= 49
h = jvmc.operator.BranchFreeOperator()
for i in range(L):
    h.add((jvmc.operator.Sz(i),))
    h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sx(i),)))
s = jnp.ones(shape=(1, 12, L), dtype=jnp.int32)
primes = h.get_s_primes(s)
The error message:
IndexError Traceback (most recent call last)
Cell In[1], line 16
13 h.add(jvmc.operator.scal_opstr(1., (jvmc.operator.Sx(i),)))
15 s = jnp.ones(shape=(1, 12, L), dtype=jnp.int32)
---> 16 primes = h.get_s_primes(s)
File ~/Programming/jvmc_playground/jvmc_cpu/lib/python3.12/site-packages/jVMC/operator/base.py:141, in Operator.get_s_primes(self, s, *args)
139 if type(fun) is tuple:
140 self.arg_fun = fun[1]
--> 141 args = self.arg_fun(*args)
142 fun = fun[0]
143 else:
File ~/Programming/jvmc_playground/jvmc_cpu/lib/python3.12/site-packages/jVMC/operator/branch_free.py:240, in BranchFreeOperator.compile.<locals>.arg_fun(prefactor, init, *args)
238 res = init[myStart:myEnd]
239 for i,f in prefactor[myStart:myEnd]:
--> 240 res[i-myStart] = f(*args)
242 res = np.concatenate(comm.allgather(res), axis=0)
244 return (jnp.array(res), )
IndexError: index 51 is out of bounds for axis 0 with size 50
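A minimal sketch of the second suggested fix, automatically applying a unit prefactor when none was given explicitly (hypothetical subclass for illustration; the actual change would go into BranchFreeOperator.add itself):

import jVMC.operator as op

class AutoPrefactorOperator(op.BranchFreeOperator):
    # Hypothetical wrapper: attach a unit prefactor to every operator
    # string, so that compilation always finds an explicit prefactor.
    def add(self, opstr):
        # scal_opstr(1., ...) leaves the physics unchanged
        return super().add(op.scal_opstr(1., opstr))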
Flax has changed its API for RNN cells. We need to adjust our network definitions accordingly.
Flax recently changed its concept for network definition, see https://flax.readthedocs.io/en/latest/flax.linen.html. We need to adapt to it.
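A minimal sketch of the current flax.linen conventions (illustrative, not an actual jVMC network): RNN cells now take their feature count at construction time, and initialize_carry expects an rng key and the input shape.

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyRNN(nn.Module):
    features: int = 8

    @nn.compact
    def __call__(self, x):  # x: (batch, input features)
        cell = nn.GRUCell(features=self.features)
        carry = cell.initialize_carry(jax.random.PRNGKey(0), x.shape)
        carry, y = cell(carry, x)  # one recurrent step
        return jnp.sum(y)

model = TinyRNN()
x = jnp.ones((1, 4))
params = model.init(jax.random.PRNGKey(1), x)
print(model.apply(params, x))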
We introduced too frequent recompilation with v1.2.0 by moving jit_my_stuff() to the global_defs module. This severely impedes performance.
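A generic illustration of the pitfall (not the actual jit_my_stuff() code): jax.jit caches per function object, so wrapping a freshly created function on every call defeats the cache and forces re-tracing.

import jax
import jax.numpy as jnp

# Bad: a new lambda is created on every call, so jax.jit sees a new
# function object each time and re-traces/re-compiles on every call.
def energy_bad(x):
    return jax.jit(lambda y: jnp.sum(y**2))(x)

# Good: jit once at module level and reuse the compiled callable.
_energy = jax.jit(lambda y: jnp.sum(y**2))

def energy_good(x):
    return _energy(x)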
The diagonalization on the GPU device in TDVP.transform_to_eigenbasis, done with jax.numpy.linalg.eigh, can sometimes raise a ValueError. In this case the calculation is cancelled and needs to be restarted with the diagonalizeOnDevice parameter set to False, which makes the function fall back to the numpy CPU version of eigh.
I suggest changing this behaviour so that, if diagonalizeOnDevice is True, the function first tries jax.numpy.linalg.eigh and, if that fails, falls back to the numpy version automatically, without the need to restart the entire calculation.
Error message:
ValueError: INTERNAL: CustomCall failed: jaxlib/cusolver_kernels.cc:444: operation cusolverDnSsyevd(handle.get(), jobz, d.uplo, d.n, a, d.n, w, static_cast<float*>(work), d.lwork, info) failed: cuSolver execution failed
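A minimal sketch of the proposed fallback, assuming the ValueError surfaces at the eigh call (the function name is hypothetical, not the actual transform_to_eigenbasis code):

import numpy as np
import jax.numpy as jnp

def eigh_with_fallback(S, diagonalizeOnDevice=True):
    if diagonalizeOnDevice:
        try:
            return jnp.linalg.eigh(S)  # device (GPU) diagonalization
        except ValueError:
            pass  # cuSolver failed; fall through to the CPU path
    ev, V = np.linalg.eigh(np.asarray(S))  # numpy CPU fallback
    return jnp.array(ev), jnp.array(V)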
Remove astype(np.int32) in lines 189 and 185 of sampler.py and enforce integer types in the nets directly.
Also see https://jax.readthedocs.io/en/latest/_autosummary/jax.nn.one_hot.html: one_hot now returns float64 as the new standard.
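A minimal sketch of what enforcing the type inside a network could look like (hypothetical toy module, not actual jVMC code):

import jax
import jax.numpy as jnp
import flax.linen as nn

class TinyNet(nn.Module):
    @nn.compact
    def __call__(self, s):
        s = s.astype(jnp.int32)  # enforce integer inputs inside the net
        # one_hot returns a float array; request the dtype explicitly
        # (float64 requires jax_enable_x64)
        x = jax.nn.one_hot(s, 2, dtype=jnp.float64)
        return jnp.sum(nn.Dense(1)(x))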
The get_O_loc_batched function of the Operator class calls the _alloc_Oloc_pmapd function, which always allocates zeros with dtype global_defs.tCpx. In the POVM case, the POVMOperator class returns real matrix elements when get_s_primes is called.
Inserting real-valued numbers into a complex jax array with jax.lax.dynamic_update_slice results in the following error:
TypeError: lax.dynamic_update_slice requires arguments to have the same dtypes, got complex128, float64.
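A minimal standalone reproduction of the dtype mismatch, independent of the jVMC classes:

import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

target = jnp.zeros(4, dtype=jnp.complex128)  # allocated like _alloc_Oloc_pmapd
update = jnp.ones(2, dtype=jnp.float64)      # real matrix elements

# Raises: TypeError: lax.dynamic_update_slice requires arguments to
# have the same dtypes, got complex128, float64.
# jax.lax.dynamic_update_slice(target, update, (0,))

# Casting the update to the target dtype avoids the error:
out = jax.lax.dynamic_update_slice(target, update.astype(target.dtype), (0,))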
The method get_network_checkpoint of the OutputManager class has the default value -1 for the time parameter, but raises an Exception for negative time values.
pmap devices are not properly assigned in the stats class.
The operator base class should be modified such that explicit dependence on external parameters is possible.
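For illustration, a hypothetical sketch of what explicit parameter dependence could look like (all names invented, not the current base-class interface): operator terms carry prefactors that are functions of an external argument such as time.

from typing import Callable, Tuple
import jax.numpy as jnp

class ParamTerm:
    # Hypothetical operator term whose prefactor is a function of
    # external parameters, evaluated when matrix elements are computed.
    def __init__(self, prefactor: Callable[..., float], sites: Tuple[int, ...]):
        self.prefactor = prefactor
        self.sites = sites

    def coefficient(self, *args):
        return self.prefactor(*args)  # e.g. evaluate at time t

term = ParamTerm(lambda t: jnp.cos(0.5 * t), sites=(0,))
print(term.coefficient(1.0))  # coupling strength at t = 1.0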
ex5_dissipative_Lindblad.py and ex6_dissipative_Lindblad_2D.py raise the error "AttributeError: 'NQS' object has no attribute '_param_unflatten_cpx'".
The _param_unflatten_cpx method was removed from the NQS class in the commit "Cleanup of the NQS class.".
Dear Devs,
I have noticed that when one creates a Hamiltonian based on a local Hilbert space with a dimension other than 2, one has to set the lDim variable in op.BranchFreeOperator to pad the operator strings with identities.
If one forgets to do that or does it incorrectly like so
import jVMC.operator as op

hamiltonian = op.BranchFreeOperator(lDim=1)
hamiltonian.add(op.scal_opstr(1., (op.Sz(0), op.Sz(0))))
hamiltonian.add(op.scal_opstr(1., (op.Sx(0),)))
hamiltonian.compile()
the error message is somewhat cryptic. Thus, I propose to change line 220 in branch_free.py
as follows:
try:
    self.mapC = jnp.array(self.map, dtype=np.int32)
except Exception as e:
    raise ValueError("Check that you have set <local_hilbert_dimension> in op.BranchFreeOperator(lDim=<local_hilbert_dimension>) correctly.") from e
Best, Jonas
We should set up tests that specifically check the functioning of our MPI wrappers using multiple processes.
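As a starting point, a hypothetical test sketch using mpi4py directly (run with, e.g., mpiexec -n 4 python test_mpi.py; the actual jVMC MPI wrapper API is not exercised here):

import numpy as np
from mpi4py import MPI

comm = MPI.COMM_WORLD

def test_allreduce_sum():
    # each rank contributes rank+1; the global sum must be N(N+1)/2
    local = np.array([comm.Get_rank() + 1.0])
    total = np.zeros_like(local)
    comm.Allreduce(local, total, op=MPI.SUM)
    N = comm.Get_size()
    assert total[0] == N * (N + 1) / 2

if __name__ == "__main__":
    test_allreduce_sum()
    if comm.Get_rank() == 0:
        print("MPI allreduce test passed")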
The function jVMC.util.util.init_net raises the error:
AttributeError: module 'jVMC.nets' has no attribute 'CpxRNN'
since the class CpxRNN has been removed but is still referenced in this function.
The TDVP class tries to get Eloc[:, 0::2] in its __call__ method if crossValidation is True, but the SampledObs object is not subscriptable.
Also, the call to solve still uses the old parameters.
I have come across an issue where complex weights in network checkpoints are not being saved properly. Currently, only the real parts of the complex numbers are saved and the imaginary parts are discarded. This behavior was encountered when using the write_network_checkpoint method of the OutputManager class.
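A minimal sketch of lossless round-tripping of complex weights (hypothetical file layout, assuming an HDF5-based checkpoint written via h5py, which stores complex dtypes natively):

import numpy as np
import h5py

weights = np.array([1.0 + 2.0j, 3.0 - 4.0j])
with h5py.File("checkpoint.h5", "w") as f:
    f.create_dataset("weights", data=weights)  # complex dtype preserved

with h5py.File("checkpoint.h5", "r") as f:
    restored = f["weights"][...]

assert np.allclose(restored, weights)  # imaginary parts survive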
The class should be initialized with the name of the network to use and the symmetries to average over.
If we define the class
class MatrixMultiplication_NonHolomorphic(nn.Module):
    holo: bool = False

    @nn.compact
    def __call__(self, s):
        layer1 = nn.Dense(1, use_bias=False, **init_fn_args(dtype=global_defs.tCpx))
        out = layer1(2 * s.ravel() - 1)
        if not self.holo:
            out = out + 1e-1 * jnp.real(out)
        return jnp.sum(out)
and let holo=True, the gradients computed by psi.gradients for the input s=[0, 0, 0, 0] are
[-1.+0.j -1.+0.j -1.+0.j -1.+0.j -0.-1.j -0.-1.j -0.-1.j -0.-1.j]
as expected.
However, if we let holo=False, the returned gradients are [-2.1+0.j -2.1+0.j -2.1+0.j -2.1+0.j]. This means that if we, for example, did time evolution with this setup, considering only the imaginary part of the S-matrix, we would get all zeroes.
If we add the gradient function
def flat_gradient_cpx_nonholo(fun, params, arg):
    gr = grad(lambda p, y: jnp.real(fun.apply(p, y)))(params, arg)["params"]
    gi = grad(lambda p, y: jnp.imag(fun.apply(p, y)))(params, arg)["params"]
    g = tree_flatten(tree_map(lambda x, y: [x.ravel(), -y.ravel()], gr, gi))[0]
    return jnp.concatenate(g)
the returned gradients are
[-1.1+0.j -1.1+0.j -1.1+0.j -1.1+0.j -0.-1.j -0.-1.j -0.-1.j -0.-1.j],
which is in line with the above case where holo=True.
When a parameters keyword argument is passed to the sample() function of a sampler object, samples w.r.t. the corresponding distribution are supposed to be returned. Also, the returned coefficients psi(s) should correspond to the given parameters. This is not the case.