# ott/src/ott/neural/flows/genot.py
333 if self.checkpoint_manager is not None:
334 states_to_save = {"state_velocity_field": self.state_velocity_field}
--> 335 if self.state_eta is not None:
336 states_to_save["state_eta"] = self.state_eta
337 if self.state_xi is not None:
AttributeError: 'GENOT' object has no attribute 'state_eta'
# Train GENOT for a few iterations with Orbax checkpointing enabled.
# NOTE(review): `neural_vf`, `source_dim`, `target_dim`, `condition_dim`, `dl`,
# `sample_uniformly`, `GENOT`, `ConstantNoiseFlow`, `Euclidean`,
# `gromov_wasserstein`, `optax`, `orbax` and `os` are assumed to be
# defined/imported earlier in the script; they are not visible here.
iterations = 3
# Alternative solver that was tried here: sinkhorn.Sinkhorn()
ot_solver = gromov_wasserstein.GromovWasserstein(epsilon=1e-3)
time_sampler = sample_uniformly
optimizer = optax.adam(learning_rate=1e-4)

# Keep at most two checkpoints; `create=True` lets the manager create the
# checkpoint directory if it is missing.
options = orbax.checkpoint.CheckpointManagerOptions(max_to_keep=2, create=True)
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
checkpoint_manager = orbax.checkpoint.CheckpointManager(
    os.path.join("./ckpts_genots", "genot_test_saving.ckpt"),
    orbax_checkpointer,
    options,
)

genot = GENOT(
    velocity_field=neural_vf,
    input_dim=source_dim,
    output_dim=target_dim,
    cond_dim=condition_dim,
    iterations=iterations,
    valid_freq=iterations - 1,
    ot_solver=ot_solver,
    epsilon=None,
    cost_fn=Euclidean(),
    scale_cost=1.0,
    optimizer=optimizer,
    flow=ConstantNoiseFlow(0.0),
    time_sampler=time_sampler,
    checkpoint_manager=checkpoint_manager,
    k_samples_per_x=1,
    solver_latent_to_data=None,
    kwargs_solver_latent_to_data={},
    fused_penalty=0.,
    tau_a=1.,
    tau_b=1.,
    rescaling_a=None,
    rescaling_b=None,
    unbalanced_kwargs={},
    callback_fn=None,
    rng=None,
)

# Workaround for "AttributeError: 'GENOT' object has no attribute
# 'state_eta'": the checkpoint-saving code reads `self.state_eta` and
# `self.state_xi` and skips them when they are None. In this balanced setup
# (tau_a = tau_b = 1, no rescaling) those attributes are apparently never
# created. Define any missing one as None so saving skips it. The guard keeps
# this a no-op on GENOT versions that already define them.
# TODO: fix upstream by initialising both attributes to None inside GENOT.
for _attr in ("state_eta", "state_xi"):
    if not hasattr(genot, _attr):
        setattr(genot, _attr, None)

genot(dl, dl)