I tried converting a complex-valued SymPy expression to JAX with sympy2jax and got the error below.
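Here is a minimal working example. `I` is SymPy's imaginary unit and `1j` is Python's; using either one produces the same error. The conversion fails because `sympy.core.numbers.ImaginaryUnit` has no entry in sympy2jax's function lookup table:

```python
import sympy2jax
from sympy import I, symbols

x = symbols("x")
expr = x*I  # or x*1j
sympy2jax.SymbolicModule(expr)
```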
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
212 try:
--> 213 return memodict[expr]
214 except KeyError:
KeyError: I*x
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:213, in _sympy_to_node(expr, memodict, func_lookup, make_array)
212 try:
--> 213 return memodict[expr]
214 except KeyError:
KeyError: I
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:180, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
179 try:
--> 180 self._func = func_lookup[expr.func]
181 except KeyError as e:
KeyError: <class 'sympy.core.numbers.ImaginaryUnit'>
The above exception was the direct cause of the following exception:
KeyError Traceback (most recent call last)
/Users/thomas/Documents/vilde.ipynb Cell 6 in <cell line: 8>()
[4](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=3) x = symbols("x")
[6](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=5) expr = x*I # or x*1j
----> [8](vscode-notebook-cell:/Users/thomas/Documents/vilde.ipynb#Y103sZmlsZQ%3D%3D?line=7) sympy2jax.SymbolicModule(expr)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
129 object.__setattr__(self, "__class__", initable_cls)
130 try:
--> 131 cls.__init__(self, *args, **kwargs)
132 finally:
133 object.__setattr__(self, "__class__", cls)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:257, in SymbolicModule.__init__(self, expressions, extra_funcs, make_array, **kwargs)
250 self.has_extra_funcs = True
251 _convert = ft.partial(
252 _sympy_to_node,
253 memodict=dict(),
254 func_lookup=lookup,
255 make_array=make_array,
256 )
--> 257 self.nodes = jax.tree_map(_convert, expressions)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in tree_map(f, tree, is_leaf, *rest)
203 leaves, treedef = tree_flatten(tree, is_leaf)
204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/jax/_src/tree_util.py:205, in <genexpr>(.0)
203 leaves, treedef = tree_flatten(tree, is_leaf)
204 all_leaves = [leaves] + [treedef.flatten_up_to(r) for r in rest]
--> 205 return treedef.unflatten(f(*xs) for xs in zip(*all_leaves))
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
222 out = _Rational(expr, make_array)
223 else:
--> 224 out = _Func(expr, memodict, func_lookup, make_array)
225 memodict[expr] = out
226 return out
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
129 object.__setattr__(self, "__class__", initable_cls)
130 try:
--> 131 cls.__init__(self, *args, **kwargs)
132 finally:
133 object.__setattr__(self, "__class__", cls)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:183, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
181 except KeyError as e:
182 raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
--> 183 self._args = [
184 _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
185 ]
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:184, in <listcomp>(.0)
181 except KeyError as e:
182 raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
183 self._args = [
--> 184 _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
185 ]
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:224, in _sympy_to_node(expr, memodict, func_lookup, make_array)
222 out = _Rational(expr, make_array)
223 else:
--> 224 out = _Func(expr, memodict, func_lookup, make_array)
225 memodict[expr] = out
226 return out
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/equinox/module.py:131, in _ModuleMeta.__call__(cls, *args, **kwargs)
129 object.__setattr__(self, "__class__", initable_cls)
130 try:
--> 131 cls.__init__(self, *args, **kwargs)
132 finally:
133 object.__setattr__(self, "__class__", cls)
File ~/mambaforge/envs/kaggle/lib/python3.10/site-packages/sympy2jax/sympy_module.py:182, in _Func.__init__(self, expr, memodict, func_lookup, make_array)
180 self._func = func_lookup[expr.func]
181 except KeyError as e:
--> 182 raise KeyError(f"Unsupported Sympy type {type(expr)}") from e
183 self._args = [
184 _sympy_to_node(arg, memodict, func_lookup, make_array) for arg in expr.args
185 ]
KeyError: "Unsupported Sympy type <class 'sympy.core.numbers.ImaginaryUnit'>"