Thanks for the great blog post and code. I tried running the model training with numpyro==0.7.2. Preprocessing runs, but model training fails with "NotImplementedError: This ELBO objective does not support mutable state." The failing call in the training notebook is train_handler.fit. The error seems to come from numpyro's newly introduced support for mutable state (I think from here, line 57, the loss_with_mutable_state method). Maybe this commit is related?
---------------------------------------------------------------------------
NotImplementedError Traceback (most recent call last)
<ipython-input-8-1349482c40ae> in <module>
----> 1 train_handler.fit(X_train, n_epochs=5_000, log_freq=1_000, lr=0.1)
~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/bhm_at_scale-1.1.post0.dev3+gea40dd8.dirty-py3.8.egg/bhm_at_scale/handler.py in fit(self, X, n_epochs, log_freq, lr, **kwargs)
130 self._fit(X, n_epochs)
131 else:
--> 132 loss = self.svi.evaluate(self.svi_state, X) / X.shape[0]
133
134 curr_epoch = 0
~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/numpyro/infer/svi.py in evaluate(self, svi_state, *args, **kwargs)
363 _, rng_key_eval = random.split(svi_state.rng_key)
364 params = self.get_params(svi_state)
--> 365 return self.loss.loss(
366 rng_key_eval,
367 params,
~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/numpyro/infer/elbo.py in loss(self, rng_key, param_map, model, guide, *args, **kwargs)
44 :return: negative of the Evidence Lower Bound (ELBO) to be minimized.
45 """
---> 46 return self.loss_with_mutable_state(
47 rng_key, param_map, model, guide, *args, **kwargs
48 )["loss"]
~/opt/anaconda3/envs/bhm-at-scale/lib/python3.8/site-packages/numpyro/infer/elbo.py in loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs)
66 :return: a tuple of ELBO loss and the mutable state
67 """
---> 68 raise NotImplementedError("This ELBO objective does not support mutable state.")
69
70
NotImplementedError: This ELBO objective does not support mutable state.
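The traceback shows how the error arises. In numpyro 0.7.2, `ELBO.loss` only returns `self.loss_with_mutable_state(...)["loss"]`, and the base-class `loss_with_mutable_state` unconditionally raises. So the objective handed to `SVI` appears to be an object that does not override that hook. The toy sketch below mirrors that dispatch in plain Python so it can be run without numpyro. The class and function names (`BaseELBO`, `ConcreteELBO`, `try_loss`) and the `"mutable_state"` key are hypothetical stand-ins, not numpyro's actual code.

```python
# Simplified, hypothetical mirror of the dispatch visible in the traceback
# (numpyro/infer/elbo.py in numpyro 0.7.2). This is NOT numpyro's real code.


class BaseELBO:
    """Stand-in for numpyro's base ELBO class."""

    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
        # The public loss() just unwraps the "loss" entry of the dict
        # returned by loss_with_mutable_state().
        return self.loss_with_mutable_state(
            rng_key, param_map, model, guide, *args, **kwargs
        )["loss"]

    def loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs):
        # The base class leaves this hook unimplemented, so calling loss()
        # on a bare BaseELBO fails exactly like the traceback above.
        raise NotImplementedError("This ELBO objective does not support mutable state.")


class ConcreteELBO(BaseELBO):
    """Stand-in for a concrete objective that implements the hook."""

    def loss_with_mutable_state(self, rng_key, param_map, model, guide, *args, **kwargs):
        # A real objective would trace the model and guide here; this toy
        # returns a fixed loss plus an empty (hypothetical) mutable state.
        return {"loss": 0.0, "mutable_state": None}


def try_loss(elbo):
    """Return the loss, or the error message if the objective raises."""
    try:
        return elbo.loss(None, {}, None, None)
    except NotImplementedError as exc:
        return str(exc)


print(try_loss(BaseELBO()))      # This ELBO objective does not support mutable state.
print(try_loss(ConcreteELBO()))  # 0.0
```

If bhm_at_scale constructs the bare `numpyro.infer.ELBO()` (an assumption I have not verified against the repo), switching to a concrete objective such as `numpyro.infer.Trace_ELBO()`, which as far as I can tell implements `loss_with_mutable_state` in 0.7.x, might resolve this. Pinning numpyro to the version the project was developed against would be the other workaround.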