I'm testing MD runs with the AMBER protein force fields on a solvated system, as a basis for some future free energy calculations. Very small systems have generally worked, but when I try a larger solvated system (37266 atoms), neighbor list generation fails with an out-of-memory error.
import sys
import time
import jax
import jax.numpy as jnp
import numpy as np
import openmm.app as app
import openmm.unit as unit
from dmff import Hamiltonian, NeighborList
from jax_md import space, smap, energy, minimize, quantity, simulate
from jax.config import config
config.update("jax_enable_x64", True)
prmtop = app.AmberPrmtopFile('../RAMP1_ion.prmtop')
inpcrd = app.AmberInpcrdFile('../RAMP1_ion.inpcrd')
ff = Hamiltonian("amber14/protein.ff14SB.xml", "amber14/tip3p.xml")
def hhbond(bond):
    # True only for the H-H "bond" inside a water (HOH) residue
    if bond[0].residue.name == 'HOH':
        if bond[0].element._symbol == 'H' and bond[1].element._symbol == 'H':
            return True
    return False
#remove extra H-H bonds found in AMBER format
prmtop.topology._bonds = [bond for bond in prmtop.topology._bonds if not hhbond(bond)]
potentials = ff.createPotential(prmtop.topology, nonbondedMethod=app.PME, nonbondedCutoff=8*unit.angstrom, prm=prmtop)
params = ff.getParameters()
positions = jnp.array(inpcrd.getPositions(asNumpy=True).value_in_unit(unit.nanometer))
positions = positions - jnp.min(positions, axis=0)
#positions range from 0 to ~9.6 in any given direction
box = jnp.array([
    [10.0, 0.0, 0.0],
    [0.0, 10.0, 0.0],
    [0.0, 0.0, 10.0]
])
#8 angstrom cutoff
nbList = NeighborList(box, .8, potentials.meta["cov_map"])
nbList.allocate(positions)
Traceback (most recent call last):
File "/mnt/ufs18/home-094/betanc18/DMFF/examples/classical/forces_bench_ramp/jax/jaxrampdebug.py", line 85, in <module>
nbList.allocate(positions)
File "/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py", line 44, in allocate
self.nblist = self.neighborlist_fn.allocate(positions)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax_md/partition.py", line 816, in allocate_fn
return neighbor_list_fn(position, extra_capacity=extra_capacity, **kwargs)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax_md/partition.py", line 803, in neighbor_list_fn
return neighbor_fn((position, False))
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax_md/partition.py", line 772, in neighbor_fn
idx, occupancy = prune_neighbor_list_sparse(position, idx, **kwargs)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/api.py", line 528, in cache_miss
out_flat = xla.xla_call(
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/core.py", line 1963, in bind
return call_bind(self, fun, *args, **params)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/core.py", line 1979, in call_bind
outs = top_trace.process_call(primitive, fun_, tracers, params)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/core.py", line 689, in process_call
return primitive.impl(f, *tracers, **params)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/dispatch.py", line 236, in _xla_call_impl
return compiled_fun(*args)
File "/mnt/home/betanc18/anaconda3/envs/dmff/lib/python3.9/site-packages/jax/_src/dispatch.py", line 837, in _execute_compiled
out_flat = compiled.execute(in_flat)
jax._src.traceback_util.UnfilteredStackTrace: jaxlib.xla_extension.XlaRuntimeError: RESOURCE_EXHAUSTED: Failed to allocate request for 93.12GiB (99990344464B) on device ordinal 0
BufferAssignment OOM Debugging.
BufferAssignment stats:
parameter allocation: 10.35GiB
constant allocation: 144B
maybe_live_out allocation: 10.35GiB
preallocated temp allocation: 93.12GiB
preallocated temp fragmentation: 64B (0.00%)
total allocation: 113.82GiB
total fragmentation: 10.35GiB (9.09%)
Peak buffers:
Buffer 1:
Size: 31.04GiB
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/vmap(jit(_einsum))/dot_general[dimension_numbers=(((1,), (1,)), ((), ())) precision=None preferred_element_type=None]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: custom-call
Shape: f64[3,1388754756]
==========================
Buffer 2:
Size: 31.04GiB
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 3) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: f64[1388754756,3]
==========================
Buffer 3:
Size: 31.04GiB
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/gather[dimension_numbers=GatherDimensionNumbers(offset_dims=(1,), collapsed_slice_dims=(0,), start_index_map=(0,)) slice_sizes=(1, 3) unique_indices=False indices_are_sorted=False mode=GatherScatterMode.PROMISE_IN_BOUNDS fill_value=None]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: f64[1388754756,3]
==========================
Buffer 4:
Size: 10.35GiB
Entry Parameter Subshape: s64[37266,37266]
==========================
Buffer 5:
Size: 10.35GiB
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/concatenate[dimension=0]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: s32[2,1388754756]
==========================
Buffer 6:
Size: 873.4KiB
Entry Parameter Subshape: f64[37266,3]
==========================
Buffer 7:
Size: 72B
XLA Label: constant
Shape: f64[3,3]
==========================
Buffer 8:
Size: 72B
XLA Label: constant
Shape: f64[3,3]
==========================
Buffer 9:
Size: 16B
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 1, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: (s64[10595], s64[10595])
==========================
Buffer 10:
Size: 16B
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: (s64[694377378], s64[694377378])
==========================
Buffer 11:
Size: 16B
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: (s64[173594344], s64[173594344])
==========================
Buffer 12:
Size: 16B
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: (s64[43398586], s64[43398586])
==========================
Buffer 13:
Size: 16B
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: (s64[10849646], s64[10849646])
==========================
Buffer 14:
Size: 16B
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 1, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: (s64[2712411], s64[2712411])
==========================
Buffer 15:
Size: 16B
Operator: op_name="jit(prune_neighbor_list_sparse)/jit(main)/jit(_cumulative_reduction)/pad[padding_config=((1, 0, 1),)]" source_file="/mnt/home/betanc18/.local/lib/python3.9/site-packages/dmff/common/nblist.py" source_line=44
XLA Label: fusion
Shape: (s64[678102], s64[678102])
==========================
Running the snippet above with the input files below reproduces the error. The largest buffers in the report have a dimension of 1388754756, which is exactly 37266^2, so the `prune_neighbor_list_sparse` step appears to be materializing every atom pair (N^2). I don't understand the neighbor/cell list generation code in JAX MD well enough to figure out why this is happening, but my understanding is that a cell list with an appropriately set cutoff should avoid this.
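As a rough sanity check on the numbers in the error output, here is a minimal back-of-the-envelope sketch in plain Python. It compares the memory for an all-pairs displacement buffer against an idealized cell list. The cell-list estimate is my own assumption (uniform density, cubic cells with edge at least the cutoff, a 27-cell stencil) and is not how JAX MD or DMFF actually sizes its buffers; real implementations pad capacities, so treat it as an order-of-magnitude figure only.

```python
# Back-of-the-envelope memory estimate; all names here are illustrative.
n_atoms = 37266        # atom count from the issue
box_length = 10.0      # nm, cubic box used in the snippet
cutoff = 0.8           # nm
bytes_f64 = 8
GiB = 2 ** 30

# Dense case: one 3-vector of float64 per (i, j) pair.
dense_pairs = n_atoms ** 2
dense_bytes = dense_pairs * 3 * bytes_f64

# Idealized cell list (assumption: uniform density, 27 neighboring cells).
cells_per_side = int(box_length // cutoff)
n_cells = cells_per_side ** 3
mean_occupancy = n_atoms / n_cells
cell_pairs = int(n_atoms * 27 * mean_occupancy)
cell_bytes = cell_pairs * 3 * bytes_f64

print(f"dense pairs:     {dense_pairs} ({dense_bytes / GiB:.2f} GiB per f64[N^2, 3] buffer)")
print(f"cell-list pairs: {cell_pairs} ({cell_bytes / GiB:.2f} GiB per buffer)")
print(f"reduction:       ~{dense_pairs / cell_pairs:.0f}x")
```

The dense pair count matches the `f64[3,1388754756]` and `f64[1388754756,3]` shapes in the peak-buffer list, and the per-buffer size matches the reported 31.04GiB, which is consistent with the pruning step operating on all N^2 candidates rather than on cell-list candidates.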
The two input files come from the protein system building examples in the AMBER tutorials: the RAMP1 protein solvated in a water box.
RAMP1_ion.zip