
ferminet's Introduction

FermiNet: Fermionic Neural Networks

FermiNet is a neural network for learning highly accurate ground state wavefunctions of atoms and molecules using a variational Monte Carlo approach.

This repository contains an implementation of the algorithm and experiments first described in "Ab-Initio Solution of the Many-Electron Schroedinger Equation with Deep Neural Networks", David Pfau, James S. Spencer, Alex G de G Matthews and W.M.C. Foulkes, Phys. Rev. Research 2, 033429 (2020), along with subsequent research and developments.

WARNING: This is a research-level release of a JAX implementation and is under active development. The original TensorFlow implementation can be found in the tf branch.

Installation

pip install -e . will install all required dependencies. This is best done inside a virtual environment.

virtualenv ~/venv/ferminet
source ~/venv/ferminet/bin/activate
pip install -e .

If you have a GPU available (highly recommended for fast training), then you can install JAX with CUDA support, using e.g.:

pip install --upgrade jax jaxlib==0.1.57+cuda110 -f https://storage.googleapis.com/jax-releases/jax_releases.html

Note that the jaxlib version must correspond to the existing CUDA installation you wish to use. Please see the JAX documentation for more details.
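
If you want to confirm that JAX can actually see the GPU before starting a long run, a quick check such as the following can help (a minimal sketch, not part of FermiNet itself):

import jax

# Should list GPU devices rather than only CPU, and report 'gpu' as the backend.
print(jax.devices())
print(jax.default_backend())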

The tests are easiest run using pytest:

pip install -e '.[testing]'
python -m pytest

Usage

ferminet uses the ConfigDict from ml_collections to configure the system. A few example scripts are included under ferminet/configs/. These are mostly for testing so may need additional settings for a production-level calculation.

ferminet --config ferminet/configs/atom.py --config.system.atom Li --config.batch_size 256 --config.pretrain.iterations 100

will train FermiNet to find the ground-state wavefunction of the Li atom using a batch size of 256 MCMC configurations ("walkers" in variational Monte Carlo language) and 100 iterations of pretraining (the default of 1000 is overkill for such a small system). The system and hyperparameters can be controlled by modifying the config file or (better, for one-off changes) using flags. See the ml_collections documentation for further details on the flag syntax. Details of all available config settings are in ferminet/base_config.py.
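
The same overrides can also be applied programmatically. A minimal sketch, assuming the standard ml_collections ConfigDict API (update_from_flattened_dict is an ml_collections method, not a FermiNet one):

from ferminet import base_config

cfg = base_config.default()
# Equivalent to passing --config.batch_size 256 --config.pretrain.iterations 100
# on the command line. update_from_flattened_dict is assumed to be available
# on ml_collections.ConfigDict.
cfg.update_from_flattened_dict({
    'batch_size': 256,
    'pretrain.iterations': 100,
})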

Other systems can easily be set up by creating a new config file or writing a custom training script. For example, to run on the H2 molecule, you can create a config file containing:

from ferminet import base_config
from ferminet.utils import system

# Settings in a config file are loaded by executing the get_config
# function.
def get_config():
  # Get default options.
  cfg = base_config.default()
  # Set up molecule
  cfg.system.electrons = (1,1)
  cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

  # Set training hyperparameters
  cfg.batch_size = 256
  cfg.pretrain.iterations = 100

  return cfg

and then run it using

ferminet --config /path/to/h2_config.py

or equivalently write the following script (or execute it interactively):

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

Alternatively, you can directly pass in a PySCF 'Molecule'. You can create PySCF Molecules with the following:

from pyscf import gto
mol = gto.Mole()
mol.build(
    atom = 'H  0 0 1; H 0 0 -1',
    basis = 'sto-3g', unit='bohr')

Once you have this molecule, you can pass it directly into the configuration by running

from ferminet import base_config
from ferminet import train

# Add H2 molecule
cfg = base_config.default()
cfg.system.pyscf_mol = mol

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

Note: to train on larger atoms and molecules with large batch sizes, multi-GPU parallelisation is essential. This is supported via JAX's pmap. Multiple GPUs will be automatically detected and used if available.
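
As a rough sanity check before launching a multi-GPU run, you can verify the device count and per-device batch size. This is a minimal sketch of our own; it assumes the batch of walkers is split evenly across local devices, with the actual sharding handled internally by FermiNet via pmap:

import jax

num_devices = jax.local_device_count()
batch_size = 4096  # illustrative value; use your own
# Assumes the batch must divide evenly across local devices.
assert batch_size % num_devices == 0
print(f'{num_devices} device(s), {batch_size // num_devices} walkers per device')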

Excited States

Excited state properties of systems can be calculated using the Natural Excited States for VMC (NES-VMC) algorithm. To enable the calculation of k states of a system, simply set cfg.system.states=k in the config file.
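
For example, building on the H2 config above, a minimal sketch of a config requesting several states might look like the following. Only cfg.system.states comes from the description above; the other settings mirror the earlier example:

from ferminet import base_config
from ferminet.utils import system

def get_config():
  cfg = base_config.default()
  cfg.system.electrons = (1, 1)
  cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]
  cfg.system.states = 3  # compute the lowest three states with NES-VMC
  return cfg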

Output

The results directory contains train_stats.csv, which records the local energy and MCMC acceptance probability for each iteration, and the checkpoints directory, which contains the checkpoints generated during training. When computing observables of excited states or the density matrix for the ground state, .npy files are also saved to the same folder; a single NumPy array is saved into the same file for every iteration of optimization. An example Colab notebook analyzing these outputs is given in notebooks/excited_states_analysis.ipynb.
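
For example, a quick way to inspect training progress from train_stats.csv is shown below (a minimal sketch; the exact column names may differ between versions, so check df.columns):

import pandas as pd

df = pd.read_csv('train_stats.csv')
print(df.columns)  # exact column names for your version
print(df.tail())   # recent local energies and MCMC acceptance probabilities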

Pretrained Models

A collection of pretrained models trained with KFAC can be found on Google Cloud here. These are all systems from the original Phys. Rev. Research paper: carbon and neon atoms, and nitrogen, ethene, methylamine, ethanol and bicyclobutane molecules. Each folder contains samples from the wavefunction in walkers.npy, parameters in params.npz and geometries for the molecule in geometry.npz. To load the models and evaluate the local energy, run:

import numpy as np
import jax
from functools import partial
from ferminet import networks, train

with open('params.npz', 'rb') as f:
  params = dict(np.load(f, allow_pickle=True))
  params = params['arr_0'].tolist()

with open('walkers.npy', 'rb') as f:
  data = np.load(f)

with open('geometry.npz', 'rb') as f:
  geometry = dict(np.load(f, allow_pickle=True))

signed_network = partial(networks.fermi_net, envelope_type='isotropic', full_det=False, **geometry)
# networks.fermi_net gives the sign/log of the wavefunction. We only care about the latter.
network = lambda p, x: signed_network(p, x)[1]
batch_network = jax.vmap(network, (None, 0), 0)
loss = train.make_loss(network, batch_network, geometry['atoms'], geometry['charges'], clip_local_energy=5.0)

print(loss(params, data)[0])  # For neon, should give -128.94165

Giving Credit

If you use this code in your work, please cite the associated papers. The initial paper details the architecture and results on a range of systems:

@article{pfau2020ferminet,
  title={Ab-Initio Solution of the Many-Electron Schr{\"o}dinger Equation with Deep Neural Networks},
  author={D. Pfau and J.S. Spencer and A.G. de G. Matthews and W.M.C. Foulkes},
  journal={Phys. Rev. Research},
  year={2020},
  volume={2},
  issue = {3},
  pages={033429},
  doi = {10.1103/PhysRevResearch.2.033429},
  url = {https://link.aps.org/doi/10.1103/PhysRevResearch.2.033429}
}

and a NeurIPS Workshop Machine Learning and Physics paper describes the JAX implementation:

@misc{spencer2020better,
  title={Better, Faster Fermionic Neural Networks},
  author={James S. Spencer and David Pfau and Aleksandar Botev and W. M.C. Foulkes},
  year={2020},
  eprint={2011.07125},
  archivePrefix={arXiv},
  primaryClass={physics.comp-ph},
  url={https://arxiv.org/abs/2011.07125}
}

The PsiFormer architecture is detailed in an ICLR 2023 paper, preprint reference:

@misc{vonglehn2022psiformer,
  title={A Self-Attention Ansatz for Ab-initio Quantum Chemistry},
  author={Ingrid von Glehn and James S Spencer and David Pfau},
  year={2022},
  eprint={2211.13672},
  archivePrefix={arXiv},
  primaryClass={physics.chem-ph},
  url={https://arxiv.org/abs/2211.13672},
}

Periodic boundary conditions were originally introduced in a Physical Review Letters article:

@article{cassella2023discovering,
  title={Discovering quantum phase transitions with fermionic neural networks},
  author={Cassella, Gino and Sutterud, Halvard and Azadi, Sam and Drummond, ND and Pfau, David and Spencer, James S and Foulkes, W Matthew C},
  journal={Physical review letters},
  volume={130},
  number={3},
  pages={036401},
  year={2023},
  publisher={APS}
}

Wasserstein QMC (thanks to Kirill Neklyudov) is described in a NeurIPS 2023 article:

@article{neklyudov2023wasserstein,
  title={Wasserstein Quantum Monte Carlo: A Novel Approach for Solving the Quantum Many-Body Schr{\"o}dinger Equation},
  author={Neklyudov, Kirill and Nys, Jannes and Thiede, Luca and Carrasquilla, Juan and Liu, Qiang and Welling, Max and Makhzani, Alireza},
  journal={NeurIPS},
  year={2023}
}

Natural excited states were introduced in the following article, which is also the first paper from our group to use pseudopotentials:

@article{pfau2023natural,
  title={Natural Quantum Monte Carlo Computation of Excited States},
  author={Pfau, David and Axelrod, Simon and Sutterud, Halvard and von Glehn, Ingrid and Spencer, James S},
  journal={arXiv preprint arXiv:2308.16848},
  year={2023}
}

This repository can be cited using:

@software{ferminet_github,
  author = {James S. Spencer and David Pfau and FermiNet Contributors},
  title = {{FermiNet}},
  url = {http://github.com/deepmind/ferminet},
  year = {2020},
}

Disclaimer

This is not an official Google product.

ferminet's People

Contributors

dpfau, gcassella, halvarsu, hawkinsp, james-martens, jsspencer, n-gao, necludov, rchen152, saran-t, shishaochen, weiluoren-bytedance, yilei


ferminet's Issues

Something went wrong in RepeatedDenseBlock.update_curvature_matrix_estimate

I think the current version of update_curvature_matrix_estimate has a problem: it ignores the input argument pmap_axis_name. I think the proper function should be as follows:

  def update_curvature_matrix_estimate(
      self,
      state: kfac_jax.TwoKroneckerFactored.State,
      estimation_data: Mapping[str, Sequence[Array]],
      ema_old: Numeric,
      ema_new: Numeric,
      batch_size: int,
      pmap_axis_name: Optional[str],
      sync: Array | bool = True,
  ) -> kfac_jax.TwoKroneckerFactored.State:
    estimation_data = dict(**estimation_data)
    x, = estimation_data["inputs"]
    dy, = estimation_data["outputs_tangent"]
    assert x.shape[0] == batch_size
    estimation_data["inputs"] = (x.reshape([-1, x.shape[-1]]),)
    estimation_data["outputs_tangent"] = (dy.reshape([-1, dy.shape[-1]]),)
    batch_size = x.size // x.shape[-1]
    return super().update_curvature_matrix_estimate(
        state=state,
        estimation_data=estimation_data,
        ema_old=ema_old,
        ema_new=ema_new,
        batch_size=batch_size,
        pmap_axis_name=pmap_axis_name,
    )

where I added pmap_axis_name.

Ground State Energies

Can you please point me to the list of ground-state energy values against which you compared the model's performance?

How does training time scale w.r.t. model size?

Hi all, I'm trying to get FermiNet up and running with GPU acceleration and so far I haven't been able to get TensorFlow 1.15 installed at the same time as CUDA 10.0. I'm wondering if this is even worth the effort.

Can anyone tell me how I should expect training time to scale relative to molecule size? Say, a hydrogen atom vs. a benzene ring vs. a caffeine molecule? And how much of an improvement (broadly) should I expect from GPU training support?

Incidentally, is there a Docker image available with CUDA support?

Pretraining with `full_det=True`

Hi,
I noticed that the default value of full_det is True. So, I would expect that during pretraining one also fits the dense Slater determinant obtained from Hartree-Fock. However, it looks like the code only retrieves two blocks from the Slater determinant and fits them to the diagonal blocks of FermiNet, while fitting the rest of FermiNet's orbitals to 0.
Is there a good reason to train like this?
Wouldn't a better approach be to fit FermiNet's orbitals to the product of the two (spin-up and spin-down) matrices obtained by Hartree Fock?

An issue regarding multi-node training with TF code

Hi there. I'm trying to run the TF-version ferminet on multiple nodes in a GPU cluster with the pretty naive idea of replacing MirroredStrategy with MultiWorkerMirroredStrategy (I have some experience with TF's distributed training via its estimator API, but not with the low-level training loop nor with Sonnet).

Unfortunately, the effort failed with an issue where the strategy tries to place tensors on a device named like /job:worker/replica:0/task:0/device:GPU:0 on a worker node. The /job:worker part is problematic since all available options start with /job:localhost instead (from my understanding TF should create a correctly named device, but somehow it didn't). In my setting I didn't use any PS node, and I am not sure whether a PS node is required when using MultiWorkerMirroredStrategy in TF 1.15 or whether that is related to this phenomenon.

So have you tried multi-node training with the TF-version code? If so, did you use MultiWorkerMirroredStrategy (and did you run into this issue or something similar)? Any comments are appreciated, thanks!

By the way, to my knowledge JAX does not yet support multi-node training, does it?

ValueError: Using default_file_mode other than 'r' is no longer supported. Pass the mode to h5py.File() instead. JAX branch.

Hi, I am trying to run (with the JAX branch):

ferminet --config ferminet/configs/atom.py --config.system.atom Li --config.batch_size 256 --config.pretrain.iterations 100

I am getting this error:

ValueError: Using default_file_mode other than 'r' is no longer supported. Pass the mode to h5py.File() instead.

Here is the full error:

Traceback (most recent call last):
  File "/home/ben/miniconda3/envs/sch_eqn/bin/ferminet", line 7, in <module>
    exec(compile(f.read(), __file__, 'exec'))
  File "/home/ben/Documents/sch_eqn/ferminet_jax/bin/ferminet", line 21, in <module>
    from ferminet import train
  File "/home/ben/Documents/sch_eqn/ferminet_jax/ferminet/train.py", line 25, in <module>
    from ferminet import hamiltonian
  File "/home/ben/Documents/sch_eqn/ferminet_jax/ferminet/hamiltonian.py", line 17, in <module>
    from ferminet import networks
  File "/home/ben/Documents/sch_eqn/ferminet_jax/ferminet/networks.py", line 21, in <module>
    from ferminet.utils import scf
  File "/home/ben/Documents/sch_eqn/ferminet_jax/ferminet/utils/scf.py", line 38, in <module>
    import pyscf
  File "/home/ben/miniconda3/envs/sch_eqn/lib/python3.7/site-packages/pyscf/__init__.py", line 71, in <module>
    from pyscf import lib
  File "/home/ben/miniconda3/envs/sch_eqn/lib/python3.7/site-packages/pyscf/lib/__init__.py", line 24, in <module>
    from pyscf.lib import numpy_helper
  File "/home/ben/miniconda3/envs/sch_eqn/lib/python3.7/site-packages/pyscf/lib/numpy_helper.py", line 27, in <module>
    from pyscf.lib import misc
  File "/home/ben/miniconda3/envs/sch_eqn/lib/python3.7/site-packages/pyscf/lib/misc.py", line 46, in <module>
    h5py.get_config().default_file_mode = 'a'
  File "h5py/h5.pyx", line 179, in h5py.h5.H5PYConfig.default_file_mode.__set__
ValueError: Using default_file_mode other than 'r' is no longer supported. Pass the mode to h5py.File() instead.

kfac_jax error when running H2 example script

Hi, I'm trying to run the example script for the H2 molecule on Colab, but I run into an attribute error on the first iteration of training. Below are the commands I'm using to install the relevant packages:

pip install git+https://github.com/deepmind/ferminet@main
pip install numpy==1.26.0 #Error in importing pyscf with version 1.26.1

the script I'm trying to run:

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

as well as the full printout of the error message:

INFO:absl:Starting QMC with 1 XLA devices per host across 1 hosts.
converged SCF energy = -1.05642988216974  <S^2> = -4.4408921e-16  2S+1 = 1
INFO:absl:No checkpoint found. Training new model.
INFO:absl:Pretrain iter 00000: 0.0987188
INFO:absl:Pretrain iter 00001: 0.0468968
INFO:absl:Pretrain iter 00002: 0.021173
INFO:absl:Pretrain iter 00003: 0.012926
INFO:absl:Pretrain iter 00004: 0.012574
INFO:absl:Pretrain iter 00005: 0.0137347
INFO:absl:Pretrain iter 00006: 0.0137845
INFO:absl:Pretrain iter 00007: 0.012732
INFO:absl:Pretrain iter 00008: 0.0110899
INFO:absl:Pretrain iter 00009: 0.00934778
INFO:absl:Pretrain iter 00010: 0.00778915
INFO:absl:Pretrain iter 00011: 0.00656673
INFO:absl:Pretrain iter 00012: 0.00565565
INFO:absl:Pretrain iter 00013: 0.00497928
INFO:absl:Pretrain iter 00014: 0.00446487
INFO:absl:Pretrain iter 00015: 0.00406806
INFO:absl:Pretrain iter 00016: 0.00370217
INFO:absl:Pretrain iter 00017: 0.00335341
INFO:absl:Pretrain iter 00018: 0.00301634
INFO:absl:Pretrain iter 00019: 0.00270557
INFO:absl:Pretrain iter 00020: 0.00244722
INFO:absl:Pretrain iter 00021: 0.00228125
INFO:absl:Pretrain iter 00022: 0.00220615
INFO:absl:Pretrain iter 00023: 0.00218246
INFO:absl:Pretrain iter 00024: 0.00216123
INFO:absl:Pretrain iter 00025: 0.00209267
INFO:absl:Pretrain iter 00026: 0.00195859
INFO:absl:Pretrain iter 00027: 0.0017916
INFO:absl:Pretrain iter 00028: 0.00160594
INFO:absl:Pretrain iter 00029: 0.00143077
INFO:absl:Pretrain iter 00030: 0.0012787
INFO:absl:Pretrain iter 00031: 0.00115612
INFO:absl:Pretrain iter 00032: 0.00107149
INFO:absl:Pretrain iter 00033: 0.00101318
INFO:absl:Pretrain iter 00034: 0.00097588
INFO:absl:Pretrain iter 00035: 0.000957538
INFO:absl:Pretrain iter 00036: 0.000952348
INFO:absl:Pretrain iter 00037: 0.000960992
INFO:absl:Pretrain iter 00038: 0.000957644
INFO:absl:Pretrain iter 00039: 0.00094491
INFO:absl:Pretrain iter 00040: 0.000913838
INFO:absl:Pretrain iter 00041: 0.000867143
INFO:absl:Pretrain iter 00042: 0.000815364
INFO:absl:Pretrain iter 00043: 0.000763372
INFO:absl:Pretrain iter 00044: 0.000719524
INFO:absl:Pretrain iter 00045: 0.00068491
INFO:absl:Pretrain iter 00046: 0.000651115
INFO:absl:Pretrain iter 00047: 0.000626584
INFO:absl:Pretrain iter 00048: 0.000609355
INFO:absl:Pretrain iter 00049: 0.000602466
INFO:absl:Pretrain iter 00050: 0.000594917
INFO:absl:Pretrain iter 00051: 0.000595118
INFO:absl:Pretrain iter 00052: 0.000592292
INFO:absl:Pretrain iter 00053: 0.000585884
INFO:absl:Pretrain iter 00054: 0.000574762
INFO:absl:Pretrain iter 00055: 0.000557475
INFO:absl:Pretrain iter 00056: 0.000534383
INFO:absl:Pretrain iter 00057: 0.000514163
INFO:absl:Pretrain iter 00058: 0.000497474
INFO:absl:Pretrain iter 00059: 0.000483247
INFO:absl:Pretrain iter 00060: 0.000474224
INFO:absl:Pretrain iter 00061: 0.000468507
INFO:absl:Pretrain iter 00062: 0.000463785
INFO:absl:Pretrain iter 00063: 0.000460483
INFO:absl:Pretrain iter 00064: 0.000456877
INFO:absl:Pretrain iter 00065: 0.000450266
INFO:absl:Pretrain iter 00066: 0.000445674
INFO:absl:Pretrain iter 00067: 0.00043909
INFO:absl:Pretrain iter 00068: 0.00043164
INFO:absl:Pretrain iter 00069: 0.000425373
INFO:absl:Pretrain iter 00070: 0.000416835
INFO:absl:Pretrain iter 00071: 0.000408702
INFO:absl:Pretrain iter 00072: 0.000404959
INFO:absl:Pretrain iter 00073: 0.000398033
INFO:absl:Pretrain iter 00074: 0.000392218
INFO:absl:Pretrain iter 00075: 0.000388479
INFO:absl:Pretrain iter 00076: 0.000385467
INFO:absl:Pretrain iter 00077: 0.000381409
INFO:absl:Pretrain iter 00078: 0.000377661
INFO:absl:Pretrain iter 00079: 0.00037244
INFO:absl:Pretrain iter 00080: 0.000368109
INFO:absl:Pretrain iter 00081: 0.000364462
INFO:absl:Pretrain iter 00082: 0.000360276
INFO:absl:Pretrain iter 00083: 0.0003561
INFO:absl:Pretrain iter 00084: 0.000352395
INFO:absl:Pretrain iter 00085: 0.000348523
INFO:absl:Pretrain iter 00086: 0.000344737
INFO:absl:Pretrain iter 00087: 0.000342277
INFO:absl:Pretrain iter 00088: 0.000339568
INFO:absl:Pretrain iter 00089: 0.000336274
INFO:absl:Pretrain iter 00090: 0.000333011
INFO:absl:Pretrain iter 00091: 0.000329609
INFO:absl:Pretrain iter 00092: 0.000326805
INFO:absl:Pretrain iter 00093: 0.000322345
INFO:absl:Pretrain iter 00094: 0.000320427
INFO:absl:Pretrain iter 00095: 0.000316732
INFO:absl:Pretrain iter 00096: 0.000314699
INFO:absl:Pretrain iter 00097: 0.000311589
INFO:absl:Pretrain iter 00098: 0.000308642
INFO:absl:Pretrain iter 00099: 0.000305587
INFO:absl:==================================================
INFO:absl:Graph parameter registrations:
INFO:absl:{'envelope': [{'pi': 'Auto[scale_and_shift_tag_1]',
               'sigma': 'Auto[scale_and_shift_tag_0]'},
              {'pi': 'Auto[scale_and_shift_tag_3]',
               'sigma': 'Auto[scale_and_shift_tag_2]'}],
 'layers': {'input': {},
            'streams': [{'double': {'b': 'Auto[repeated_dense_tag_1]',
                                    'w': 'Auto[repeated_dense_tag_1]'},
                         'single': {'b': 'Auto[repeated_dense_tag_0]',
                                    'w': 'Auto[repeated_dense_tag_0]'}},
                        {'double': {'b': 'Auto[repeated_dense_tag_3]',
                                    'w': 'Auto[repeated_dense_tag_3]'},
                         'single': {'b': 'Auto[repeated_dense_tag_2]',
                                    'w': 'Auto[repeated_dense_tag_2]'}},
                        {'double': {'b': 'Auto[repeated_dense_tag_5]',
                                    'w': 'Auto[repeated_dense_tag_5]'},
                         'single': {'b': 'Auto[repeated_dense_tag_4]',
                                    'w': 'Auto[repeated_dense_tag_4]'}},
                        {'single': {'b': 'Auto[repeated_dense_tag_6]',
                                    'w': 'Auto[repeated_dense_tag_6]'}}]},
 'orbital': [{'w': 'Auto[repeated_dense_tag_7]'},
             {'w': 'Auto[repeated_dense_tag_8]'}]}
INFO:absl:==================================================
INFO:absl:Burning in MCMC chain for 100 steps
INFO:absl:Completed burn-in MCMC steps
INFO:absl:Initial energy: -1.1759 E_h
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<ipython-input-1-df9aebf4a1f2> in <cell line: 22>()
     20 cfg.pretrain.iterations = 100
     21 
---> 22 train.train(cfg)

16 frames
/usr/local/lib/python3.10/dist-packages/ferminet/train.py in train(cfg, writer_manager)
    712     for t in range(t_init, cfg.optim.iterations):
    713       sharded_key, subkeys = kfac_jax.utils.p_split(sharded_key)
--> 714       data, params, opt_state, loss, unused_aux_data, pmove = step(
    715           data,
    716           params,

/usr/local/lib/python3.10/dist-packages/ferminet/train.py in step(data, params, state, key, mcmc_width)
    313 
    314     # Optimization step
--> 315     new_params, state, stats = optimizer.step(
    316         params=params,
    317         state=state,

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in step(self, params, state, rng, data_iterator, batch, func_state, learning_rate, momentum, damping, global_step_int)
   1214       batch = next(data_iterator)
   1215 
-> 1216     return self._step(params, state, rng, batch, func_state,
   1217                       learning_rate, momentum, damping)
   1218 

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/staging.py in decorated(instance, *args)
    246           pmap_funcs[key] = func
    247 
--> 248         outs = func(instance, *args)
    249 
    250       else:

    [... skipping hidden 12 frame]

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/misc.py in wrapped(instance, *args, **kwargs)
    283       method_name = method_name[1:]
    284     with jax.named_scope(f"{class_name}_{method_name}"):
--> 285       return method(instance, *args, **kwargs)
    286 
    287   return wrapped

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _step(self, params, state, rng, batch, func_state, learning_rate, momentum, damping)
   1020 
   1021     # Update curvature estimate
-> 1022     state = self._maybe_update_estimator_curvature(
   1023         state,
   1024         func_args,

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _maybe_update_estimator_curvature(self, state, func_args, rng, ema_old, ema_new, sync)
    723   ) -> "Optimizer.State":
    724     """Updates the curvature estimates if it is the right iteration."""
--> 725     return self._maybe_update_estimator_state(
    726         state,
    727         self.should_update_estimate_curvature(state),

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _maybe_update_estimator_state(self, state, should_update, update_func, **update_func_kwargs)
    678     state = state.copy()
    679 
--> 680     state.estimator_state = lax.cond(
    681         should_update,
    682         functools.partial(update_func, **update_func_kwargs),

    [... skipping hidden 13 frame]

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/optimizer.py in _update_estimator_curvature(self, estimator_state, func_args, rng, ema_old, ema_new, sync)
    696   ) -> curvature_estimator.BlockDiagonalCurvature.State:
    697     """Updates the curvature estimator state."""
--> 698     state = self.estimator.update_curvature_matrix_estimate(
    699         state=estimator_state,
    700         ema_old=ema_old,

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/misc.py in wrapped(instance, *args, **kwargs)
    283       method_name = method_name[1:]
    284     with jax.named_scope(f"{class_name}_{method_name}"):
--> 285       return method(instance, *args, **kwargs)
    286 
    287   return wrapped

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/curvature_estimator.py in update_curvature_matrix_estimate(self, state, ema_old, ema_new, batch_size, rng, func_args, estimation_mode)
   1239 
   1240     # Compute the losses and the VJP function from the function inputs
-> 1241     losses, losses_vjp = self._compute_losses_vjp(func_args)
   1242 
   1243     if "fisher" in estimation_mode:

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/utils/misc.py in wrapped(instance, *args, **kwargs)
    283       method_name = method_name[1:]
    284     with jax.named_scope(f"{class_name}_{method_name}"):
--> 285       return method(instance, *args, **kwargs)
    286 
    287   return wrapped

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/curvature_estimator.py in _compute_losses_vjp(self, func_args)
   1037   def _compute_losses_vjp(self, func_args: utils.FuncArgs):
   1038     """Computes all model statistics needed for estimating the curvature."""
-> 1039     return self._vjp(func_args)
   1040 
   1041   def params_vector_to_blocks_vectors(

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tracer.py in wrapped_transformation(func_args, return_only_jaxpr, *args)
    379       return jaxpr
    380     else:
--> 381       return f(func_args, *args)
    382 
    383   return wrapped_transformation

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tracer.py in _layer_tag_vjp(processed_jaxpr, primal_func_args)
    781 
    782   # First compute the primal values for the inputs to all layer tags
--> 783   layer_input_values = forward()
    784   primals_dict = dict(zip(layer_input_vars, layer_input_values))
    785 

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tracer.py in forward()
    698     for eqn in processed_jaxpr.jaxpr.eqns:
    699 
--> 700       write(eqn.outvars, tgm.eval_jaxpr_eqn(eqn, read(eqn.invars)))
    701 
    702       if isinstance(eqn.primitive, tags.LossTag):

/usr/local/lib/python3.10/dist-packages/kfac_jax/_src/tag_graph_matcher.py in eval_jaxpr_eqn(eqn, in_values)
     63 
     64   if jax_version > (0, 4, 11):
---> 65     user_context = jax_extend.source_info_util.user_context
     66   else:
     67     user_context = jax.core.source_info_util.user_context

AttributeError: module 'jax.extend' has no attribute 'source_info_util'

Install error

(plm) (ferminet) [liujinde@node02 ferminet-master]$ python -m pytest
=========================================================================== test session starts ============================================================================
platform linux -- Python 3.7.3, pytest-6.2.5, py-1.10.0, pluggy-1.0.0
rootdir: /home/liujinde/Dowload/ferminet-master
collected 135 items

ferminet/tests/hamiltonian_test.py ..s..s..s...s [ 9%]
ferminet/tests/mcmc_test.py ...s [ 12%]
ferminet/tests/networks_test.py ...Fatal Python error: Fatal Python error: Fatal Python error: Fatal Python error: Fatal Python error: Fatal Python error: Fatal Python error: Fatal Python error: Segmentation faultSegmentation faultSegmentation faultSegmentation faultSegmentation faultSegmentation faultSegmentation fault

Segmentation fault (core dumped)

Why is local_energy not jit'ed?

Hi there. Got a quick question on the JAX implementation:

Why is local_energy not jit'ed? To be more specific, I mean the local energy defined in https://github.com/deepmind/ferminet/blob/bf0d06eb05e3a17063551e8573a129568e99beac/ferminet/train.py#L111-L112

Actually, jit does not even show up in the train.py file at all.

Note that in tests, the local_energy is indeed jit'ed before comparison. See https://github.com/deepmind/ferminet/blob/bf0d06eb05e3a17063551e8573a129568e99beac/ferminet/tests/hamiltonian_test.py#L147-L149

This could be a JAX 101 question (sorry in advance!). It would be very helpful if you could share the rationale for not using jit here and/or some related performance tips. Thanks!

BTW, I am quite interested in how the Laplacian can be calculated in JAX, and I measured the performance of local_kinetic_energy. I did find that jit can significantly improve performance (which makes me wonder why you didn't jit local_energy). A related implementation-detail question:

  1. Why use a fori_loop to add up the second-order derivatives instead of first calculating the full Hessian, for instance with jit(jacfwd(jacrev(fun))), then summing the diagonal and the square of the gradient? Is the concern more about speed or memory consumption? (A minimal sketch of the two approaches is given below.)
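
For concreteness, here is a minimal, illustrative sketch (my own, not FermiNet code) of the two ways of computing the Laplacian of a scalar function f: R^n -> R that the question above compares:

import jax
import jax.numpy as jnp

def laplacian_via_hessian(f):
  # Build the full n x n Hessian and sum its diagonal: simple, but O(n^2) memory.
  return lambda x: jnp.trace(jax.jacfwd(jax.jacrev(f))(x))

def laplacian_via_loop(f):
  # Accumulate one diagonal Hessian entry at a time via Hessian-vector
  # products: only O(n) memory, at the cost of a sequential loop.
  def _lap(x):
    n = x.shape[0]
    eye = jnp.eye(n)
    grad_f = jax.grad(f)
    def body(i, acc):
      _, hvp = jax.jvp(grad_f, (x,), (eye[i],))  # H @ e_i
      return acc + hvp[i]
    return jax.lax.fori_loop(0, n, body, jnp.zeros((), dtype=x.dtype))
  return _lap

f = lambda x: jnp.sum(jnp.sin(x))
x = jnp.arange(4.0)
# Both should give -sum(sin(x)).
print(laplacian_via_hessian(f)(x), laplacian_via_loop(f)(x))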

Issues with KFAC on multiple GPUs

Hi there. Thanks again for the awesome open-sourcing work of the KFAC optimizer!

However, as I mentioned in #24 (comment) several weeks ago, we hit some issues when running optimization with KFAC.

As suggested by @jsspencer, it may be an issue with a noisy matrix being inverted. However, we tried quite a large batch size (40960 for Mg) and a large damping factor (1 as opposed to the default 0.001), but neither fixes the issue.

Recently, we noticed that the cusolver issue only shows up when we optimize ferminet with multiple GPUs (in our case 8 V100 cards). If we run the same command for Mg on a single GPU (even with a small batch size like 256), it does not hit the same cusolver issue.

Unfortunately, we failed to spot where the cusolver issue really happens. In fact, we tried to debug with the KFAC optimizer's debug option turned on (so that no jit or pmap happens), but the KFAC optimizer doesn't work at all with debug turned on and it seems non-trivial to fix (maybe we didn't try hard enough). We think the problematic inversion might happen at https://github.com/deepmind/deepmind-research/blob/master/kfac_ferminet_alpha/utils.py#L131, but even if we simply replace the matrix to be inverted with an identity matrix, the issue persists in the multi-GPU environment.

Since the optimization works in a single-GPU environment but fails in a multi-GPU one, we suspect something goes wrong when pmap meets cusolver, but we don't know how to dig deeper. Thoughts?

BTW, do you do all the development and testing on TPU instead of GPU? If so, we might also run our experiments on TPUs if KFAC works there. Any gotchas when running JAX on TPU? Thanks!

Issues (or typos?) when running JAX code with multiple GPUs

Hi there. Got two issues when running the JAX code with multiple GPUs:

  1. https://github.com/deepmind/ferminet/blob/b46077f1a4687b3ac03b583d9726f59ab4d914d9/ferminet/train.py#L293-L297
    It hits a 'too many values to unpack' error when num_devices is greater than 1.
    My understanding is that we should do
key, *subkeys = jax.random.split(key, num_devices+1)

instead (note the extra asterisk), in which case the following explicit broadcast is no longer necessary for the single-GPU case.

  1. https://github.com/deepmind/ferminet/blob/b46077f1a4687b3ac03b583d9726f59ab4d914d9/ferminet/train.py#L372-L373
    constants.pmap gives a tuple containing an array instead of just an array when num_devices is greater than 1 (not sure why, probably just JAX's API). This causes logging to complain. It's easy to fix though.

Let me know if this makes sense. Also, if you like, I can submit a tiny PR to fix them.

Extension of PBC code to 1D

Hi, I was wondering if there are any plans to extend the code for calculations in periodic boundary conditions to different choices of dimensionality, for example to the one-dimensional case. Thanks!

Unable to set up

On running pip install -e ., I got:

Obtaining file:///users/andrewho/Documents/ferminet-main
Collecting kfac_jax@ git+https://github.com/deepmind/kfac-jax
  Cloning https://github.com/deepmind/kfac-jax to /tmp/pip-install-qnrbw3gg/kfac-jax_f460d57ceba14857a5582fea1fed6cc3
  Running command git clone -q https://github.com/deepmind/kfac-jax /tmp/pip-install-qnrbw3gg/kfac-jax_f460d57ceba14857a5582fea1fed6cc3
  Resolved https://github.com/deepmind/kfac-jax to commit f8b6405a9da0fbb4b9dc957d1997e8eb24a96c18
Requirement already satisfied: absl-py in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (1.0.0)
Collecting attrs
  Using cached attrs-21.4.0-py2.py3-none-any.whl (60 kB)
Requirement already satisfied: chex in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (0.1.3)
Requirement already satisfied: h5py in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (3.6.0)
Requirement already satisfied: jax in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (0.3.8)
Requirement already satisfied: jaxlib in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (0.3.7)
Collecting ml-collections
  Using cached ml_collections-0.1.1-py3-none-any.whl
Collecting optax
  Using cached optax-0.1.2-py3-none-any.whl (140 kB)
Requirement already satisfied: numpy in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (1.22.3)
Collecting pandas
  Using cached pandas-1.4.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (11.7 MB)
Collecting pyscf
  Using cached pyscf-2.0.1.tar.gz (7.7 MB)
Collecting pyblock
  Using cached pyblock-0.6-py3-none-any.whl
Requirement already satisfied: scipy in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (1.8.0)
Requirement already satisfied: tables in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (3.7.0)
Requirement already satisfied: typing_extensions in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ferminet==0.2) (4.2.0)
Requirement already satisfied: six in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from absl-py->ferminet==0.2) (1.16.0)
Requirement already satisfied: dm-tree>=0.1.5 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from chex->ferminet==0.2) (0.1.7)
Requirement already satisfied: toolz>=0.9.0 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from chex->ferminet==0.2) (0.11.2)
Requirement already satisfied: opt-einsum in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from jax->ferminet==0.2) (3.3.0)
Requirement already satisfied: flatbuffers<3.0,>=1.12 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from jaxlib->ferminet==0.2) (2.0)
Requirement already satisfied: immutabledict>=2.2.1 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from kfac_jax@ git+https://github.com/deepmind/kfac-jax->ferminet==0.2) (2.2.1)
Requirement already satisfied: distrax>=0.1.1 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from kfac_jax@ git+https://github.com/deepmind/kfac-jax->ferminet==0.2) (0.1.2)
Requirement already satisfied: tensorflow-probability>=0.15.0 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from distrax>=0.1.1->kfac_jax@ git+https://github.com/deepmind/kfac-jax->ferminet==0.2) (0.16.0)
Requirement already satisfied: decorator in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from tensorflow-probability>=0.15.0->distrax>=0.1.1->kfac_jax@ git+https://github.com/deepmind/kfac-jax->ferminet==0.2) (5.1.1)
Requirement already satisfied: gast>=0.3.2 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from tensorflow-probability>=0.15.0->distrax>=0.1.1->kfac_jax@ git+https://github.com/deepmind/kfac-jax->ferminet==0.2) (0.5.3)
Requirement already satisfied: cloudpickle>=1.3 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from tensorflow-probability>=0.15.0->distrax>=0.1.1->kfac_jax@ git+https://github.com/deepmind/kfac-jax->ferminet==0.2) (2.0.0)
Requirement already satisfied: PyYAML in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ml-collections->ferminet==0.2) (6.0)
Requirement already satisfied: contextlib2 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from ml-collections->ferminet==0.2) (21.6.0)
Requirement already satisfied: python-dateutil>=2.8.1 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from pandas->ferminet==0.2) (2.8.2)
Requirement already satisfied: pytz>=2020.1 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from pandas->ferminet==0.2) (2022.1)
Requirement already satisfied: packaging in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from tables->ferminet==0.2) (21.3)
Requirement already satisfied: numexpr>=2.6.2 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from tables->ferminet==0.2) (2.8.1)
Requirement already satisfied: pyparsing!=3.0.5,>=2.0.2 in /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages (from packaging->tables->ferminet==0.2) (3.0.8)
Building wheels for collected packages: kfac-jax, pyscf
  Building wheel for kfac-jax (setup.py) ... done
  Created wheel for kfac-jax: filename=kfac_jax-0.0.1-py3-none-any.whl size=118136 sha256=3d01eb1af325d295ba97b5d7e623a649182882d556ec4d9da5c53f8db833fd60
  Stored in directory: /tmp/pip-ephem-wheel-cache-7s0_5zoa/wheels/d3/a2/bf/79a38a03091a1a334020b9eff654f78c3d51f59e588f897163
  Building wheel for pyscf (setup.py) ... error
  ERROR: Command errored out with exit status 1:
   command: /users/andrewho/.conda/envs/fmnet/bin/python -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/setup.py'"'"'; __file__='"'"'/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' bdist_wheel -d /tmp/pip-wheel-a180qok7
       cwd: /tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/
  Complete output (8 lines):
  running bdist_wheel
  running build
  running build_ext
  Configuring extensions
  cmake -S/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/pyscf/lib -Bbuild/temp.linux-x86_64-3.10
  CMake Error: The source directory "" does not exist.
  Specify --help for usage, or press the help button on the CMake GUI.
  error: command '/bin/cmake' failed with exit code 1
  ----------------------------------------
  ERROR: Failed building wheel for pyscf
  Running setup.py clean for pyscf
Successfully built kfac-jax
Failed to build pyscf
Installing collected packages: pyscf, pyblock, pandas, optax, ml-collections, kfac-jax, attrs, ferminet
    Running setup.py install for pyscf ... error
    ERROR: Command errored out with exit status 1:
     command: /users/andrewho/.conda/envs/fmnet/bin/python -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/setup.py'"'"'; __file__='"'"'/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-edw8qsmk/install-record.txt --single-version-externally-managed --compile --install-headers /users/andrewho/.conda/envs/fmnet/include/python3.10/pyscf
         cwd: /tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/
    Complete output (10 lines):
    running install
    /users/andrewho/.conda/envs/fmnet/lib/python3.10/site-packages/setuptools/command/install.py:34: SetuptoolsDeprecationWarning: setup.py install is deprecated. Use build and pip and other standards-based tools.
      warnings.warn(
    running build
    running build_ext
    Configuring extensions
    cmake -S/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/pyscf/lib -Bbuild/temp.linux-x86_64-3.10
    CMake Error: The source directory "" does not exist.
    Specify --help for usage, or press the help button on the CMake GUI.
    error: command '/bin/cmake' failed with exit code 1
    ----------------------------------------
ERROR: Command errored out with exit status 1: /users/andrewho/.conda/envs/fmnet/bin/python -u -c 'import io, os, sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/setup.py'"'"'; __file__='"'"'/tmp/pip-install-qnrbw3gg/pyscf_591825adafd5410e9a5da0cf1a9e00ab/setup.py'"'"';f = getattr(tokenize, '"'"'open'"'"', open)(__file__) if os.path.exists(__file__) else io.StringIO('"'"'from setuptools import setup; setup()'"'"');code = f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' install --record /tmp/pip-record-edw8qsmk/install-record.txt --single-version-externally-managed --compile --install-headers /users/andrewho/.conda/envs/fmnet/include/python3.10/pyscf Check the logs for full command output.

It seems to be a conflict between package versions.

How to make the JAX code run on a single GPU instead of a TPU?

I'm trying to run this example (JAX branch):

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = base_config.default()
cfg.system.electrons = (1,1)  # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

# Set training parameters
cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)

At train.train(cfg), the code seems to try to run on a TPU by default; how do I change it to run on a single GPU instead?

INFO:absl:Starting the local TPU driver.
INFO:absl:Unable to initialize backend 'tpu_driver': Not found: Unable to find driver in registry given worker: local://
INFO:absl:Unable to initialize backend 'tpu': Invalid argument: TpuPlatform is not available.
INFO:absl:Starting QMC with 1 XLA devices

KFAC Open sourcing

Hi,

I was wondering whether you are planning on releasing the KFAC optimizer used in both papers as well?
I know that the TensorFlow version is available on GitHub. Is the JAX version also going open-source?

Thank you!

Problem with the multi_gpu flag

Dear authors, it's really a great package and I very much enjoy using it to study molecular properties.
However, I encounter a problem when I try to run this package on multiple GPUs. I followed the tutorial and installed tensorflow-gpu as the README says: pip install -e '.[tensorflow-gpu]'. Then I run "ferminet --multi_gpu True" in my terminal. The program ends with an error: "ValueError: Variable KFAC/model/det_net/after_det_weights/replica_1/KFAC/ does not exist, or was not created with tf.get_variable(). Did you mean to set reuse=tf.AUTO_REUSE in VarScope?"
It seems the variables have scope problems when run on multiple GPUs, but it works well on one GPU. I also contacted my colleagues and they all have this problem.
Can you tell me how to fix it? Thanks!

Question About load Checkpoint

Hello, I have a question about the checkpoint-loading function.
To the best of my knowledge, the model is saved in /ferminet/train.py here:

      if time.time() - time_of_last_ckpt > cfg.log.save_frequency * 60:
        checkpoint.save(ckpt_save_path, t, data, params, opt_state, mcmc_width)
        time_of_last_ckpt = time.time()
        sys.exit(0)

and this function is implemented with np.savez.
However, when I attempt to load this checkpoint, it does not pass the check logic in checkpoint.restore, specifically:

  with open(restore_filename, 'rb') as f:
    ckpt_data = np.load(f, allow_pickle=True)
    # Retrieve data from npz file. Non-array variables need to be converted back
    # to natives types using .tolist().
    t = ckpt_data['t'].tolist() + 1  # Return the iterations completed.
    data = ckpt_data['data']
    params = ckpt_data['params'].tolist()
    opt_state = ckpt_data['opt_state'].tolist()
    mcmc_width = jnp.array(ckpt_data['mcmc_width'].tolist())
    if data.shape[0] != jax.device_count():
      raise ValueError(
          f'Incorrect number of devices found. Expected {data.shape[0]}, found '
          f'{jax.device_count()}.')

I attempted to investigate this issue and found that for the checkpoint.save function, data is a FermiNetData instance, which contains four arrays named position, spins, atoms and charges. However, when I load this NumPy checkpoint, data is merely an array of the strings [position, spins, atoms, charges]. It seems that this part may have a problem.
I'm wondering whether this part is correct. Many thanks for your time and effort in reading my issue.

Will log_determinant return inf for singular matrices?

In the last line of the following code

def log_determinant(x):
  with tf.name_scope('determinant'):
    with tf.device('/cpu:0'):
      s, u, v = tf.linalg.svd(x)

    sign = sign_det(x)

    def grad(dsign, dlog):
      del dsign
      # d ln|det(X)|/dX = transpose(X^{-1}). This uses pseudo-inverse via SVD
      # and so is not numerically stable...
      adj = tf.matmul(u,
                      tf.matmul(tf.linalg.diag(1.0 / s), v, transpose_b=True))
      return tf.expand_dims(tf.expand_dims(dlog, -1), -1) * adj

    return (sign, tf.reduce_sum(tf.log(s), axis=-1)), grad

it seems that, for singular matrices, s will have zero elements and tf.log(s) will return -inf?
And tf.linalg.diag(1.0 / s) may also generate inf results?

Error in set-up

When I run the setup of ferminet, I get the error below:

File "C:\Users\user\anaconda3\envs\FermiNet\lib\site-packages\setuptools\package_index.py", line 119, in distros_for_location
wheel = Wheel(basename)
File "C:\Users\user\anaconda3\envs\FermiNet\lib\site-packages\setuptools\wheel.py", line 61, in init
raise ValueError('invalid wheel name: %r' % filename)
ValueError: invalid wheel name: 'pyscf-1.4.3-cp34-macosx_10_6_intel.macosx_10_9_intel.macosx_10_9_x86_64.macosx_10_10_intel.macosx_10_10_x86_64.whl'

I am using Anaconda on 64-bit Windows and running the setup from the terminal.

Jax install - issue with correct version number

Hi,

Apologies for the basic question, but I am having issues running the provided test scripts and I believe it is an issue with the version of JAX installed by default by the setup script. Below are the versions that were automatically installed:

[screenshot of installed jax/jaxlib versions]

And here is the error message that is causing each test to fail.

[screenshot of the error message]

Currently I am running this off the main branch within a Google Colab notebook with GPU backend enabled.

What version of Jax/Jaxlib should I be using?

Thank you for the help!

Logdet Bug Similar to e9f8c64

I noticed that in psiformer.py, line 448, the logdet calculation for NES is still wrong: it multiplies all ndets determinants together:

      return batch_logdet_matmul(*orbitals)

A similar issue was fixed for FermiNet in e9f8c64.

KeyError raised after burn-in MCMC steps

Dear authors,
I am currently encountering a problem when running ferminet: once the burn-in MCMC steps complete, an error is raised in kfac_jax/_src/tag_graph_matcher.py, line 674. It is 'KeyError: a'; the details are in the attached file 'error.json', please check it.

error.json

We then tried to print 'env' and 'var', and found that the keys in the variable 'env' changed, which is the reason we hit the error; the details are in the attached file 'error-print.json', please check it.

error-print.json

Could you please give us any advice to solve this problem?

Here is our environment:

channels:

  • https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge
  • https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/conda-forge/
  • https://mirrors.tuna.tsinghua.edu.cn/anaconda/cloud/pytorch/
  • https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/main/
  • https://mirrors.tuna.tsinghua.edu.cn/anaconda/pkgs/free/
  • defaults
    dependencies:
  • _libgcc_mutex=0.1=conda_forge
  • _openmp_mutex=4.5=2_gnu
  • bzip2=1.0.8=hd590300_5
  • ca-certificates=2023.11.17=hbcca054_0
  • ld_impl_linux-64=2.40=h41732ed_0
  • libexpat=2.5.0=hcb278e6_1
  • libffi=3.4.2=h7f98852_5
  • libgcc-ng=13.2.0=h807b86a_4
  • libgomp=13.2.0=h807b86a_4
  • libnsl=2.0.1=hd590300_0
  • libsqlite=3.44.2=h2797004_0
  • libuuid=2.38.1=h0b41bf4_0
  • libxcrypt=4.4.36=hd590300_1
  • libzlib=1.2.13=hd590300_5
  • ncurses=6.4=h59595ed_2
  • openssl=3.2.0=hd590300_1
  • pip=23.3.2=pyhd8ed1ab_0
  • python=3.11.7=hab00c5b_1_cpython
  • readline=8.2=h8228510_1
  • tk=8.6.13=noxft_h4845f30_101
  • wheel=0.42.0=pyhd8ed1ab_0
  • xz=5.2.6=h166bdaf_0
  • pip:
    • absl-py==1.4.0
    • array-record==0.5.0
    • astunparse==1.6.3
    • attrs==23.2.0
    • blosc2==2.5.1
    • cachetools==5.3.2
    • certifi==2023.11.17
    • charset-normalizer==3.3.2
    • chex==0.1.6
    • click==8.1.7
    • cloudpickle==3.0.0
    • contextlib2==21.6.0
    • decorator==5.1.1
    • distrax==0.0.3
    • dm-haiku==0.0.11
    • dm-tree==0.1.8
    • etils==1.6.0
    • flatbuffers==23.5.26
    • flax==0.8.0
    • folx==0.2.2.post1
    • fsspec==2023.12.2
    • gast==0.5.4
    • google-auth==2.27.0
    • google-auth-oauthlib==1.2.0
    • google-pasta==0.2.0
    • googleapis-common-protos==1.62.0
    • grpcio==1.60.0
    • h5py==3.10.0
    • idna==3.6
    • immutabledict==4.1.0
    • importlib-resources==6.1.1
    • iniconfig==2.0.0
    • jax==0.4.23
    • jaxlib==0.3.25
    • jaxline==0.0.8
    • jaxtyping==0.2.25
    • jmp==0.0.4
    • keras==2.15.0
    • kfac-jax==0.0.2
    • libclang==16.0.6
    • markdown==3.5.2
    • markdown-it-py==3.0.0
    • markupsafe==2.1.4
    • mdurl==0.1.2
    • ml-collections==0.1.1
    • ml-dtypes==0.2.0
    • msgpack==1.0.7
    • ndindex==1.7
    • nest-asyncio==1.6.0
    • numexpr==2.9.0
    • numpy==1.26.3
    • oauthlib==3.2.2
    • opt-einsum==3.3.0
    • optax==0.0.5
    • orbax-checkpoint==0.4.4
    • packaging==23.2
    • pandas==2.2.0
    • pluggy==1.4.0
    • promise==2.3
    • protobuf==3.20.3
    • psutil==5.9.8
    • py-cpuinfo==9.0.0
    • pyasn1==0.5.1
    • pyasn1-modules==0.3.0
    • pyblock==0.6
    • pygments==2.17.2
    • pyscf==2.4.0
    • pytest==8.0.0
    • python-dateutil==2.8.2
    • pytz==2023.3.post1
    • pyyaml==6.0.1
    • requests==2.31.0
    • requests-oauthlib==1.3.1
    • rich==13.7.0
    • rsa==4.9
    • scipy==1.12.0
    • setuptools==69.0.3
    • six==1.16.0
    • tables==3.9.2
    • tabulate==0.9.0
    • tensorboard==2.15.1
    • tensorboard-data-server==0.7.2
    • tensorflow==2.15.0.post1
    • tensorflow-datasets==4.9.4
    • tensorflow-estimator==2.15.0
    • tensorflow-io-gcs-filesystem==0.35.0
    • tensorflow-metadata==1.14.0
    • tensorflow-probability==0.23.0
    • tensorstore==0.1.45
    • termcolor==2.4.0
    • toml==0.10.2
    • toolz==0.12.1
    • tqdm==4.66.1
    • typeguard==2.13.3
    • typing-extensions==4.9.0
    • tzdata==2023.4
    • urllib3==2.1.0
    • werkzeug==3.0.1
    • wrapt==1.14.1
    • zipp==3.17.0

Regards
Qinmeng

Evaluating logprob using batch_network in train

Hi all,

To evaluate the model performance, I need to compute the log probability of a new set of positions inside the main training loop. I call batch_network = constants.pmap(batch_network), where the input batch_network is the vmapped function defined in the train function, and then log_prob = 2.0 * batch_network(params, x, data.spins, data.atoms, data.charges), where x is a JAX array with the same dimensions as data.positions.

However, I encountered the following issue

log_prob = 2.0 * batch_network(params, data.positions, data.spins, data.atoms, data.charges)

ValueError: pmap got inconsistent sizes for array axes to be mapped:
  * most axes (6 of them) had size 256, e.g. axis 0 of argument args[0]['layers']['streams'][0]['single']['b'] of type float32[256];
  * some axes (5 of them) had size 32, e.g. axis 0 of argument args[0]['layers']['streams'][0]['double']['b'] of type float32[32];
  * some axes (4 of them) had size 6, e.g. axis 0 of argument args[0]['envelope'][0]['pi'] of type float32[6,256];
  * some axes (4 of them) had size 1024, e.g. axis 0 of argument args[1] of type float32[1024,48];
  * some axes (3 of them) had size 832, e.g. axis 0 of argument args[0]['layers']['streams'][1]['single']['w'] of type float32[832,256];
  * one axis had size 4: axis 0 of argument args[0]['layers']['streams'][0]['double']['w'] of type float32[4,32];
  * one axis had size 80: axis 0 of argument args[0]['layers']['streams'][0]['single']['w'] of type float32[80,256]

I tried to avoid pmapping the params input by setting in_axes=(None, 0, 0, 0, 0) and got:

File "/home/baiyu/ferminet/train.py", line 828, in train
    log_prob = 2.0 * batch_network(params, data.positions, data.spins, data.atoms, data.charges)
  File "/home/baiyu/ferminet/train.py", line 558, in <lambda>
    logabs_network = lambda *args, **kwargs: signed_network(*args, **kwargs)[1]
  File "/home/baiyu/ferminet/networks.py", line 1387, in apply
    orbitals = orbitals_apply(params, pos, spins, atoms, charges)
  File "/home/baiyu/ferminet/networks.py", line 1171, in apply
    h_to_orbitals = equivariant_layers_apply(
  File "/home/baiyu/ferminet/networks.py", line 1029, in apply
    h_one, h_two, h_elec_ion = apply_layer(
  File "/home/baiyu/ferminet/networks.py", line 937, in apply_layer
    h_one_in = construct_symmetric_features(
  File "/home/baiyu/ferminet/networks.py", line 552, in construct_symmetric_features
    return jnp.concatenate(features, axis=1)
  File "/home/baiyu/miniconda/envs/newenv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1837, in concatenate
    arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
  File "/home/baiyu/miniconda/envs/newenv/lib/python3.10/site-packages/jax/_src/numpy/lax_numpy.py", line 1837, in <listcomp>

arrays_out = [lax.concatenate(arrays_out[i:i+k], axis)
TypeError: Cannot concatenate arrays with shapes that differ in dimensions other than the one being concatenated: concatenating along dimension 2 for shapes (1024, 16, 1, 256), (1024, 1, 16, 256), (1024, 1, 16, 256), (1024, 16, 1, 32), (1024, 16, 1, 32).

Replacing x with the default data.positions gives rise to exactly the same issue.
Without these two lines of code everything else works fine.

How could I evaluate the value of log probability correctly in the training loop?

Thanks!
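(A minimal sketch of one possible way to do this, assuming the per-device sharding that train.py uses for data.positions; the reshape and the num_devices variable below are illustrative additions, not part of the original report, and batch_network, params, x and data are the names from the post above.)

import jax
import jax.numpy as jnp
from ferminet import constants

# Sketch: the pmapped network expects every mapped argument to carry a
# leading device axis, so params should be the replicated copy used inside
# the training loop and the new positions must be sharded the same way.
num_devices = jax.local_device_count()
pmapped_network = constants.pmap(batch_network)  # batch_network as vmapped in train()
x_sharded = jnp.reshape(x, (num_devices, -1) + x.shape[1:])
log_prob = 2.0 * pmapped_network(
    params, x_sharded, data.spins, data.atoms, data.charges)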

AttributeError: module 'jax.core' has no attribute 'extract_call_jaxpr'

I am running with jaxlib==0.3.0 on cuda11 and it starts on my two V100 GPUs but stops with this:

Traceback (most recent call last):
  File "/opt/conda/bin/ferminet", line 7, in <module>
    exec(compile(f.read(), __file__, 'exec'))
  File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 39, in <module>
    app.run(main)
  File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 312, in run
    _run_main(main, args)
  File "/opt/conda/lib/python3.7/site-packages/absl/app.py", line 258, in _run_main
    sys.exit(main(argv))
  File "/home/jaxelsen/aisecurity/Ferminet_google/bin/ferminet", line 35, in main
    train.train(cfg)
  File "/home/jaxelsen/aisecurity/Ferminet_google/ferminet/train.py", line 450, in train
    opt_state = optimizer.init(params, subkeys, data)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 498, in init
    self.finalize(params, rng, batch, func_state)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/optimizer.py", line 244, in finalize
    patterns_to_skip=self.patterns_to_skip)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 468, in auto_register_tags
    graph = function_to_jax_graph(func, func_args, params_index=params_index)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 437, in function_to_jax_graph
    typed_jaxpr = jax.make_jaxpr(func)(*args)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 256, in merged_func
    evaluate_eqn(eqn, jax_util.safe_map(read, eqn.invars), write)
  File "/opt/conda/lib/python3.7/site-packages/kfac_ferminet_alpha/tag_graph_matcher.py", line 172, in evaluate_eqn
    call_jaxpr, params = jax.core.extract_call_jaxpr(eqn.primitive, eqn.params)
AttributeError: module 'jax.core' has no attribute 'extract_call_jaxpr'

Variance of a converged model

Good day,

I played around with the ferminet implementation a bit and noticed that for small systems like H2 the network very quickly converges with a low variance on the local energy. However, for the hydrogen chain with 10 atoms I noticed the variance converges significantly more slowly. Could you please share the variance of the converged models for the hydrogen chain in the original FermiNet paper for reference?

Thanks!

minimize depends on batch size?

Hello,

This is Cunwei and I am interested in ferminet and its applications. When we try to read and run the code, we have a minor problem about the minimize procedure in the code. In qmc.py #264 the code looks like

optimize_step = functools.partial(
            optimizer.minimize,
            features,
            global_step=global_step,
            var_list=self.network.trainable_variables,
            grad_loss=grad_loss_clipped)

My question is whether this gradient computation depends on the batch size for the Adam optimizer. I checked the tensorflow source code (not sure it is the version required here) and found that the Adam optimizer takes the dot product of grad_loss and the gradient of features. It therefore seems that the gradients are extensive. If they are not, is this scaling dealt with somewhere else?

Thank you a lot for answering this question.

NaN during training

Hi,

I fiddled around with the jax code a bit and noticed that for small systems where either spin has only one electron, the network throws NaN after some time.

ferminet --config ferminet/configs/atom.py --config.system.atom H --config.batch_size 4096 --config.pretrain.iterations 0
I0215 05:54:52.148167 139716596184896 train.py:461] Step 00538: -0.4999 E_h, pmove=0.97
I0215 05:54:52.173480 139716596184896 train.py:461] Step 00539: -0.4999 E_h, pmove=0.97
I0215 05:54:52.199377 139716596184896 train.py:461] Step 00540: nan E_h, pmove=0.97
I0215 05:54:52.224862 139716596184896 train.py:461] Step 00541: nan E_h, pmove=0.00
I0215 05:54:52.250287 139716596184896 train.py:461] Step 00542: nan E_h, pmove=0.00

I traced the issue down and found that this happens at the log abs determinant of the Slater determinant (in this case a 1x1 matrix). There is a small probability for a sample to be chosen such that the (1x1) matrix is exactly 0. After that, the code just produces nan.
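(A minimal illustration, not ferminet code, of the failure mode described above:)

import jax.numpy as jnp

# For a 1x1 Slater matrix that is exactly zero, slogdet gives sign 0 and
# log|det| = -inf; any later arithmetic that combines this -inf with itself
# or with finite values (local energy, MCMC ratios, gradients) can yield NaN.
sign, logabsdet = jnp.linalg.slogdet(jnp.zeros((1, 1)))
print(sign, logabsdet)        # 0.0 -inf
print(logabsdet - logabsdet)  # nan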

Different results obtained from the paper for ch3nh2

Hi,
I ran ferminet for CH3NH2 with the atomic coordinates provided in the appendix of the FermiNet paper. The energy is around -91.5 Eh instead of -95.51 Eh.
I guess that might be because the coordinates provided in the appendix are not in Bohr?
The appendix also provides coordinates for O3, C4H6, C2H4, etc.; are they in Bohr or Angstrom? I cannot find information about the units of those structures in the paper.

Below is the input configs for ch3nh2:

from ferminet import base_config
from ferminet.utils import system
import ml_collections


def get_config() -> ml_collections.ConfigDict:
  """Returns config for running CH3NH2 with FermiNet."""
  cfg = base_config.default()
  # geometry in bohr.
  cfg.system.molecule = [
      system.Atom(symbol='C', coords=(0.0517, 0.7044, 0.0)),
      system.Atom(symbol='N', coords=(0.0517, -0.7596, 0.0)),
      system.Atom(symbol='H', coords=(-0.9417, 1.1762, 0.0)),
      system.Atom(symbol='H', coords=(-0.4582, -1.0994, 0.8124)),
      system.Atom(symbol='H', coords=(-0.4582, -1.0994, -0.8124)),
      system.Atom(symbol='H', coords=(0.5928, 1.0567, 0.8807)),
      system.Atom(symbol='H', coords=(0.5928, 1.0567, -0.8807)),
  ]
  cfg.system.electrons = (9, 9)
  return cfg
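(If the tabulated coordinates turn out to be in Angstrom rather than Bohr, one illustrative way to convert them before building the molecule; the conversion factor 1 Angstrom ≈ 1.889726 Bohr is standard, while the helper names below are assumptions for the sketch, not ferminet API.)

from ferminet.utils import system

# Sketch: convert Angstrom coordinates to Bohr before constructing Atoms,
# in case the published geometry is in Angstrom.
ANGSTROM_TO_BOHR = 1.8897259886

coords_angstrom = [
    ('C', (0.0517, 0.7044, 0.0)),
    ('N', (0.0517, -0.7596, 0.0)),
    # ... remaining atoms as in the config above
]
molecule = [
    system.Atom(symbol=s, coords=tuple(ANGSTROM_TO_BOHR * x for x in xyz))
    for s, xyz in coords_angstrom
]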

Question about exact_cusp function

Hello, there~
I have a question about exact_cusp function in your code, although it's actually turned off.

https://github.com/deepmind/ferminet/blob/c6c53bf96c1425750d6b7e9038eee68ff3de9d81/ferminet/networks.py#L411

  e_cusp = (jnp.sum(1. / (1. + r_ees[0][0])) +
            jnp.sum(1. / (1. + r_ees[1][1])) +
            jnp.sum(1. / (1. + r_ees[0][1])))
  return env + a_cusp - 0.5 * e_cusp

If I understand correctly, the cusp factor should be 1/4 for parallel-spin electrons and 1/2 for anti-parallel-spin electrons.
I am afraid it should be

  e_cusp = (jnp.sum(1. / (1. + r_ees[0][0])) / 4 +
            jnp.sum(1. / (1. + r_ees[1][1])) / 4 +
            jnp.sum(1. / (1. + r_ees[0][1])))
  return env + a_cusp - 0.5 * e_cusp

The 0.5 factor in the last line multiplying the r_ees[0][0] term fixes the double counting in jnp.sum(r_ees[0][0]); however, jnp.sum(r_ees[0][0]) still needs to be divided by 4, since these are parallel-spin electrons.
As for r_ees[0][1], there is no double-counting problem, and the 0.5 factor in the last line produces the correct cusp factor for anti-parallel electrons.
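(For reference, the standard electron-electron Kato cusp conditions for the spherically averaged wavefunction, which the 1/2 and 1/4 factors discussed above are meant to reproduce, stated here in LaTeX for context:)

% Electron-electron cusp conditions (spherically averaged wavefunction):
\left.\frac{\partial\langle\Psi\rangle}{\partial r_{ij}}\right|_{r_{ij}=0} = \frac{1}{2}\,\langle\Psi\rangle_{r_{ij}=0} \quad\text{(anti-parallel spins)}, \qquad
\left.\frac{\partial\langle\Psi\rangle}{\partial r_{ij}}\right|_{r_{ij}=0} = \frac{1}{4}\,\langle\Psi\rangle_{r_{ij}=0} \quad\text{(parallel spins)}.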

Sorry for bothering~

Installation Error

Hello, I am trying to install ferminet in Windows Subsystem for Linux using the command pip install -e '.[tensorflow-gpu]' and I am getting this error: ERROR: File "setup.py" not found. Directory cannot be installed in editable mode.

Please help to solve this error!

nan when training with 'adam'

I tried to obtain the ground state of BeH with BeH.py:

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)
cfg = base_config.default()
cfg.system.electrons = (3,2) # (alpha electrons, beta electrons)
cfg.system.molecule = [system.Atom('Be', (1.3269, 0.0, 0.0)), system.Atom('H', (0.0, 0.0, 0.0))]
cfg.batch_size = 256
cfg.pretrain.iterations = 100
cfg.optim.optimizer = 'adam'
train.train(cfg)

After 3454 steps, I got a nan as below:

INFO:absl:Step 03447: -14.9133 E_h, variance=0.0626 E_h^2, pmove=0.52
INFO:absl:Step 03448: -14.9394 E_h, variance=0.1356 E_h^2, pmove=0.54
INFO:absl:Step 03449: -14.9157 E_h, variance=0.0708 E_h^2, pmove=0.54
INFO:absl:Step 03450: -14.9189 E_h, variance=0.0347 E_h^2, pmove=0.53
INFO:absl:Step 03451: -14.9058 E_h, variance=0.0451 E_h^2, pmove=0.53
INFO:absl:Step 03452: -14.9168 E_h, variance=0.0460 E_h^2, pmove=0.52
INFO:absl:Step 03453: -14.9193 E_h, variance=0.1345 E_h^2, pmove=0.52
INFO:absl:Step 03454: -14.9189 E_h, variance=0.0443 E_h^2, pmove=0.51
INFO:absl:Step 03455: nan E_h, variance=nan E_h^2, pmove=0.51
INFO:absl:Step 03456: nan E_h, variance=nan E_h^2, pmove=0.00
INFO:absl:Step 03457: nan E_h, variance=nan E_h^2, pmove=0.00
INFO:absl:Step 03458: nan E_h, variance=nan E_h^2, pmove=0.00
INFO:absl:Step 03459: nan E_h, variance=nan E_h^2, pmove=0.00
INFO:absl:Step 03460: nan E_h, variance=nan E_h^2, pmove=0.00
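(A generic JAX debugging aid, not part of the original report, that can help localise where the first NaN appears:)

from jax import config

# Raise an error at the first operation that produces a NaN instead of
# silently propagating it through the rest of the training step.
config.update("jax_debug_nans", True)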

Question about pbc ewald part.

Dear contributors,
Glad to see open-source code for the electron gas system. While going through the pbc code, I ran into a question about the Ewald summation, which I cite below: https://github.com/deepmind/ferminet/blob/main/ferminet/pbc/hamiltonian.py#L142

phase_prim_ae = phase_ae % 1

If I understand correctly, this line is used to calculate the energy between electrons and their images, and the summation should follow the minimum-image convention.

Let's assume we are working in a simple cubic cell and two electrons are separated by 0.6 L, where L is the length of the cell in one dimension. The minimum-image separation between these two electrons should then be -0.4 L, while the code above will give 0.6 L instead, which may lead to problems.

Although this problem can easily be masked if a large enough truncation limit is used, it should still be corrected to avoid potential problems.

I'm afraid the correct code should be

phase_prim_ae = (phase_ae + 0.5) % 1 - 0.5

Then the minimum-image separation for the 0.6 L electron pair will be -0.4 L, which seems correct.
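(A quick numerical check of the two wrapping conventions, illustrative only:)

import numpy as np

# Fractional separations of 0.6, -0.6 and 1.3 under the two conventions.
phase_ae = np.array([0.6, -0.6, 1.3])
print(phase_ae % 1)                 # [0.6 0.4 0.3]  -> not minimum image for 0.6
print((phase_ae + 0.5) % 1 - 0.5)   # [-0.4 0.4 0.3] -> minimum image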

About configs

Hi all,

I am trying to see how FermiNet can be run by following some of the examples in the configs. One of the examples I tried is the following:

import sys
from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train
from ferminet.pbc import envelopes
import numpy as np
import jax.numpy as jnp
from jax.config import config


def _sc_lattice_vecs(rs: float, nelec: int) -> np.ndarray:
  """Returns simple cubic lattice vectors with Wigner-Seitz radius rs."""
  volume = (4 / 3) * np.pi * (rs**3) * nelec
  length = volume**(1 / 3)
  return length * np.eye(3)

logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

cfg = base_config.default()
cfg.system.electrons = (27, 0)
cfg.system.molecule = [system.Atom("X", (0., 0., 0.))]
cfg.pretrain.method = None
lattice = _sc_lattice_vecs(1.0, sum(cfg.system.electrons))
kpoints = envelopes.make_kpoints(lattice, cfg.system.electrons)
cfg.system.make_local_energy_fn = "ferminet.pbc.hamiltonian.local_energy"
cfg.system.make_local_energy_kwargs = {"lattice": lattice, "heg": True}
cfg.network.make_feature_layer_fn = (
    "ferminet.pbc.feature_layer.make_pbc_feature_layer")
cfg.network.make_feature_layer_kwargs = {
    "lattice": lattice,
    "include_r_ae": False
}
cfg.network.make_envelope_fn = (
    "ferminet.pbc.envelopes.make_multiwave_envelope")
cfg.network.make_envelope_kwargs = {"kpoints": kpoints}
cfg.network.full_det = False
cfg.batch_size = 512 
cfg.pretrain.iterations = 0
train.train(cfg)

However, it seems that to run this code one first needs to change "charges" to "charges.shape[0]" at line 396 in train.py to avoid an error.

Even though the code can be run after that, it reports an energy of 28.6515, which seems to be much lower than the number mentioned in arXiv:2202.05183 (which should be 1.2615 x 27 = 34.0605).

In addition, if one tries to run lattice = _sc_lattice_vecs(30, sum(cfg.system.electrons)), it prints out "Initial Energy: 5296.6323 E_h" and then after two steps just becomes nan.

May I know whether the above setup is proper for the related tasks?

Upstream breaking change in `kfac-jax`

The most recent commit to the kfac-jax repo (at the time of writing, f466559d86b07d6a2291cc699ac769c8e0931592) contains a breaking change for the ferminet repository. The last working commit is bacdf8eaf4f5bd1a467b7e9d9703e571ed37c897. Following the installation/usage instructions in README.md therefore results in a broken installation.

To reproduce, install as per usual instructions and run:

import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train

# Settings in a config file are loaded by executing the get_config
# function.
def get_config():
  # Get default options.
  cfg = base_config.default()
  # Set up molecule
  cfg.system.electrons = (1,1)
  cfg.system.molecule = [system.Atom('H', (0, 0, -1)), system.Atom('H', (0, 0, 1))]

  # Set training hyperparameters
  cfg.batch_size = 256
  cfg.pretrain.iterations = 100

  return cfg

# Optional, for also printing training progress to STDOUT.
# If running a script, you can also just use the --alsologtostderr flag.
logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Define H2 molecule
cfg = get_config()
train.train(cfg)

resulting in the stack trace:

Traceback (most recent call last):
  File "/home/ettore/ferminet/test.py", line 6, in <module>
    from ferminet import train
  File "/home/ettore/ferminet/ferminet/train.py", line 24, in <module>
    from ferminet import checkpoint
  File "/home/ettore/ferminet/ferminet/checkpoint.py", line 24, in <module>
    from ferminet import networks
  File "/home/ettore/ferminet/ferminet/networks.py", line 21, in <module>
    from ferminet import envelopes
  File "/home/ettore/ferminet/ferminet/envelopes.py", line 21, in <module>
    from ferminet import curvature_tags_and_blocks
  File "/home/ettore/ferminet/ferminet/curvature_tags_and_blocks.py", line 27, in <module>
    vmap_psd_inv_cholesky = jax.vmap(kfac_jax.utils.psd_inv_cholesky, (0, None), 0)
AttributeError: module 'kfac_jax._src.utils' has no attribute 'psd_inv_cholesky'

Jax error running on A100 GPU (everything is okay on CPU)

Hi,

I got an error on the train.py, line 229 new_params, state, stats = optimizer.step(......)

The error code is shown below:

2022-04-14 12:46:59.552761: E external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2141] Execution of replica 1 failed: INTERNAL: CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
2022-04-14 12:47:09.554416: F external/org_tensorflow/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:2288] Replicated computation launch failed, but not all replicas terminated. Aborting process to work around deadlock. Failure message (there may have been multiple failures, see the error log for all failures):

CustomCall failed: jaxlib/cusolver_kernels.cc:44: operation cusolverDnCreate(&handle) failed: cuSolver internal error
Fatal Python error: Aborted

I didn't get any error running on CPU. But on GPU I always get this error.
Could you help me to solve this problem? Thank you.

Questions about the convergence

Hi there~ Thanks to your kind guidance, I have successfully run the ferminet program and obtained some results for, for example, the H4 circle. However, the output energy at each iteration still oscillates by roughly 0.1 Eh even after 10^5 iterations.

So I just want to ask: how can I know the network has been well trained? That is, when can I stop updating the network parameters and use it to calculate the energy accurately?

I note the default number of iterations in your program is 10^6, and I wonder whether this oscillating behavior will be cured by then.
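(One common way to quote a converged VMC energy despite the per-iteration noise is to average the tail of the run; the file and column names below are assumptions about the local setup, not something guaranteed by ferminet.)

import numpy as np
import pandas as pd

# Average the last part of the run instead of reading a single iteration.
# The plain standard error below ignores autocorrelation between MCMC steps,
# so treat it as a lower bound (a blocking analysis, e.g. with pyblock,
# gives a more honest error bar).
energies = pd.read_csv('train_stats.csv')['energy'].to_numpy()
tail = energies[-20000:]
print(tail.mean(), tail.std(ddof=1) / np.sqrt(len(tail)))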

How to reproduce the results for Neon. JAX.

Hi, I am trying to reproduce the results for Neon. I am running the following code with the default base config, changing only batch_size = 256, pretrain iterations = 100 and optim iterations = 100_000 (for now; this will be increased if the results do not match):

Training Code
import sys

from absl import logging
from ferminet.utils import system
from ferminet import base_config
from ferminet import train
import numpy as np

logging.get_absl_handler().python_handler.stream = sys.stdout
logging.set_verbosity(logging.INFO)

# Start from the default config before overriding system settings.
cfg = base_config.default()
cfg.system.electrons = (5,5)
cfg.system.molecule = [system.Atom('Ne')]

cfg.batch_size = 256
cfg.pretrain.iterations = 100

train.train(cfg)
Loading Model
from functools import partial

import jax
import numpy as np

from ferminet import networks
from ferminet import train

with open('ferminet_2021_08_22_16:24:03/qmcjax_ckpt_099929.npz', 'rb') as f:
    params = dict(np.load(f, allow_pickle=True))['params'].tolist()

with open('ferminet_2021_08_22_16:24:03/qmcjax_ckpt_099929.npz', 'rb') as f:
    data = dict(np.load(f, allow_pickle=True))['data']

with open(path+'geometry.npz', 'rb') as f:
    geometry = dict(np.load(f, allow_pickle=True))

foo = partial(networks.fermi_net, envelope_type='isotropic', full_det=False, **geometry)
# networks.fermi_net gives the sign/log of the wavefunction. We only care about the latter.
network = lambda p, x: foo(p, x)[1]
batch_network = jax.vmap(network, (None, 0), 0)
loss = train.make_loss(network, batch_network, geometry['atoms'], geometry['charges'], clip_local_energy=5.0)
ploss = jax.pmap(loss, axis_name='qmc_pmap_axis')  # right now, the code only works if the loss is wrapped by pmap

loss_ = ploss(params, data)  # For neon, should give -128.94165
loss_[0]

At this step loss_ = ploss(params, data), I am getting this error:

ValueError: Incompatible shapes for broadcasting: ((1, 5, 1, 1), (3, 3, 1, 160))

I compared my params with the cloud files given and it seems my pi and sigma envelopes have different shapes.

Any help on how to reproduce the pretrained results would be appreciated.

The proper way to cite FermiNet repo

Thanks a lot for having this great project. Really appreciate it!

Recently we worked on a Diffusion Monte Carlo project based on FermiNet (FermiNet-DMC) and would like to cite this repository directly, besides the mentioned PRR and NeurIPS Workshop papers. I can come up with a bibtex entry like the one provided in the KFAC repo. One thing I am not sure about is how to fill in the author field. Any thoughts?

which distribution is used in the pretrain phase?

As mentioned in the FermiNet paper, the probability distribution used in the pretraining phase is the average of the HF distribution and the FermiNet output.

However, the implementation here seems to be doing something different:

In the master branch, even though we concatenate the walkers from both HF and FermiNet, it seems only the ones from HF are used.

(My understanding is that in the pretrain_hartree_fock function, only the first tf.distribute.get_replica_context().num_replicas_in_sync walkers are used, which are basically the ones from HF.)

On the other hand, in the JAX branch, it seems to me we only use the FermiNet output as the probability distribution during pretraining. (Sorry, I'm not that familiar with JAX; I just can't find anything related to the HF distribution around the pretraining code.)

Is my understanding correct? Or does it mean the distribution used in pretrain is not that critical? Thanks!

how to reconstruct a ferminet?

Dear authors, sorry to ask a naive question again.....

I am quite interested in the network and am trying to use it to study interesting physics phenomena. However, I am new to tensorflow, and I have encountered some problems with how to reconstruct the ferminet.

For example, after I run the H2 example in the Usage section, I think the program saves the trained parameters to the checkpoints. But I don't know how to use the checkpoint files to reconstruct the ferminet I have trained, so that I can input the coordinates of the electrons and get back the corresponding wavefunction value.

I think it's a somewhat annoying question; sorry to bother you.

JAX implementation of the derivatives of log determinents

To my knowledge, the Sonnet implementation (master branch) used a hand-rolled gradient function for the logdet_matmul function, corresponding to Appendix C of the FermiNet paper. However, in the JAX implementation, such a hand-rolled gradient function seems to be gone.

Could someone shed some light on the rationale behind the change? Is it because JAX is doing something smarter when differentiating a slogdet function? Thanks!
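(A rough sketch, for context only, of the kind of computation involved: evaluating log|sum of determinants| stably via slogdet plus a signed log-sum-exp, which JAX can differentiate directly. This is an assumption-laden illustration, not the actual logdet_matmul implementation.)

import jax.numpy as jnp

def log_abs_sum_det(matrices):
  """log|sum_k det(A_k)| for a stack of matrices of shape [k, n, n]."""
  signs, logdets = jnp.linalg.slogdet(matrices)
  max_logdet = jnp.max(logdets)
  # Signed log-sum-exp: factor out the largest determinant for stability.
  total = jnp.sum(signs * jnp.exp(logdets - max_logdet))
  return jnp.log(jnp.abs(total)) + max_logdet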

Issue on running pytest

Following the guidance, I have already installed ferminet, but I still have issues when running pytest.

The error is below:
========== short test summary info===============
ERROR ferminet/pbc/tests/features_test.py - TypeError: unsupported operand type(s) for |: 'ABCMeta' and 'type'
ERROR ferminet/pbc/tests/hamiltonian_test.py - TypeError: unsupported operand type(s) for |: 'ABCMeta' and 'type'
ERROR ferminet/tests/envelopes_test.py - TypeError: unsupported operand type(s) for |: 'ABCMeta' and 'type'
ERROR ferminet/tests/hamiltonian_test.py - TypeError: unsupported operand type(s) for |: 'ABCMeta' and 'type'
ERROR ferminet/tests/networks_test.py - TypeError: unsupported operand type(s) for |: 'ABCMeta' and 'type'
ERROR ferminet/tests/psiformer_test.py - TypeError: unsupported operand type(s) for |: 'ABCMeta' and 'type'
ERROR ferminet/tests/train_test.py - TypeError: unsupported operand type(s) for |: 'ABCMeta' and 'type'

My Python version is 3.9.13 and I am running the code on Ubuntu 22.04.2 LTS. Is there any reason that could lead to this problem?
