google-deepmind / chex
Home Page: https://chex.readthedocs.io
License: Apache License 2.0
Hi,
When I try to import chex, I get the following error:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/__init__.py", line 17, in <module>
from chex._src.asserts import assert_axis_dimension
File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/asserts.py", line 26, in <module>
from chex._src import asserts_internal as _ai
File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/asserts_internal.py", line 32, in <module>
from chex._src import pytypes
File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/pytypes.py", line 36, in <module>
PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'
I am using jax version 0.2.14 and jaxlib 0.1.68. Did they cause the error?
Hello!
I'm interested in using pydantic's recursive constructor / asdict functionality, but jax.jit-ed functions give the following error:
Argument '_Pydantic_OptimConfig_93971134241088(.. SOMETHING HERE...)' of type <class 'pydantic.dataclasses._Pydantic_OptimConfig_93971134241088'> is not a valid JAX type.
I'm trying to inherit from a dataclass with an optional id in the parent; however, I am seeing TypeError: non-default argument 'value' follows default argument. This is the minimal code to reproduce:
from typing import Optional
import chex

@chex.dataclass
class Base:
    idx: Optional[int] = None

@chex.dataclass
class Derived(Base):
    value: int
This is an issue with the dataclasses library rather than chex.dataclass; however, according to https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses#answer-69822584, it can be fixed by setting kw_only=True (available since Python 3.10).
In most JAX-based implementations, jit is almost always used: if there is no reason not to use it, people will take advantage of its speedup.
I noticed that @chex.variants(with_jit=True, without_jit=True) is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.
In the following example, I would expect to see "Tracing fn" four times in total: three times for the non-jitted variant and once for the initial jit compilation. In reality, test_variant_pre_jitted() is executed twice with the jitted fn (once per variant), resulting in only two "Tracing fn" outputs.
from jax import jit

@chex.variants(with_jit=True, without_jit=True)
def test_variant_pre_jitted(self):
    @jit
    def fn(x, y):
        print("Tracing fn")
        return x + y

    var_fn = self.variant(fn)
    self.assertEqual(var_fn(1, 2), 3)
    self.assertEqual(var_fn(3, 4), 7)
    self.assertEqual(var_fn(5, 6), 11)
Of course, omitting @jit leads to the expected behavior. However, when more complex implementations already make use of jit, variants sadly no longer make sense.
My case is the latter, and the only option I see is implementing a model-wide use_jit flag so that I can derive variants from non-jitted code. However, this makes the whole idea of variants rather obsolete.
I'm aware this could well be a limitation of JAX and jit itself rather than chex. In that case, I think raising an error when jitted code is passed to variant() would make this more transparent.
Hello,
Thanks for the useful package. I am hitting an error when using the latest chex. See reproduction instructions below.
pip install chex==0.1.5 && python -c "import chex; print(isinstance(None, (int, float, chex.Array)))"
False
pip install chex==0.1.7 && python -c "import chex; print(isinstance(None, (int, float, chex.Array)))"
Traceback (most recent call last):
File "<string>", line 1, in <module>
File "/usr/lib/python3.8/typing.py", line 769, in __instancecheck__
return self.__subclasscheck__(type(obj))
File "/usr/lib/python3.8/typing.py", line 777, in __subclasscheck__
raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks
This causes an error in optax when using inject_hyperparams, which essentially runs isinstance(None, (int, float, chex.Array)).
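The traceback shows that the tuple passed to isinstance contains a typing construct with a subscripted generic inside it, and isinstance refuses that. Below is a minimal workaround sketch. It assumes checking only the unsubscripted origin classes is acceptable (for example, list for List[int]). The hypothetical safe_isinstance helper expands unions and replaces each subscripted generic by its origin before calling isinstance. ArrayLike is a made-up stand-in, not chex.Array.

```python
import types
from typing import List, Union, get_args, get_origin

# Origins that mark a union: typing.Union, plus types.UnionType (`int | str`)
# on Python 3.10+.
_UNION_ORIGINS = {Union, getattr(types, "UnionType", Union)}


def flatten_for_isinstance(tp) -> tuple:
    """Reduces a (possibly subscripted) type hint to plain classes.

    Unions are expanded recursively; subscripted generics such as List[int]
    are replaced by their origin class (list).
    """
    origin = get_origin(tp)
    if origin in _UNION_ORIGINS:
        return tuple(c for arg in get_args(tp) for c in flatten_for_isinstance(arg))
    if origin is not None:
        return (origin,)
    return (tp,)


def safe_isinstance(obj, types_) -> bool:
    """Hypothetical isinstance that tolerates unions of subscripted generics."""
    if not isinstance(types_, tuple):
        types_ = (types_,)
    classes = tuple(c for tp in types_ for c in flatten_for_isinstance(tp))
    return isinstance(obj, classes)


ArrayLike = Union[int, List[int]]  # stand-in containing a subscripted generic

try:
    isinstance(None, (int, float, ArrayLike))
except TypeError as err:
    print("plain isinstance fails:", err)

print(safe_isinstance(None, (int, float, ArrayLike)))  # False
print(safe_isinstance([1, 2], ArrayLike))  # True
```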
When inheriting one dataclass from another, Chex's dataclass does not allow a super() call to be made, even though this works with Python's built-in dataclasses module.
A minimal example:
from chex import dataclass

@dataclass
class ChexBase:
    a: int

    def __post_init__(self):
        self.b = self.a + 1

@dataclass
class ChexSub(ChexBase):
    a: int

    def __post_init__(self):
        super().__post_init__()
        self.c = self.a + 2

temp = ChexSub(a=1)
temp.b
Importing dataclass from dataclasses instead runs without error and returns 2, as expected.
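One plausible cause, not confirmed in this report, is that chex.dataclass returns a new class object instead of the class that was decorated (for example, one rebuilt with type(...) to add a Mapping base when mappable_dataclass=True). Zero-argument super() relies on a hidden __class__ cell that still points at the original class, so it breaks. The following stdlib-only sketch reproduces that failure with a hypothetical rebuild_class decorator and shows a workaround.

```python
def rebuild_class(cls):
    """Hypothetical decorator that returns a *new* class with the same body."""
    namespace = {k: v for k, v in cls.__dict__.items()
                 if k not in ("__dict__", "__weakref__")}
    return type(cls.__name__, cls.__bases__, namespace)


class Base:
    def __init__(self, a):
        self.a = a
        self.__post_init__()  # mimic the dataclass-generated __init__

    def __post_init__(self):
        self.b = self.a + 1


@rebuild_class
class BrokenSub(Base):
    def __post_init__(self):
        # Zero-argument super() uses the __class__ cell, which refers to the
        # class *before* rebuild_class replaced it, so this raises TypeError.
        super().__post_init__()
        self.c = self.a + 2


@rebuild_class
class FixedSub(Base):
    def __post_init__(self):
        # Workaround: name the class explicitly. The module-level name is
        # looked up at call time and refers to the rebuilt class.
        super(FixedSub, self).__post_init__()
        self.c = self.a + 2


try:
    BrokenSub(1)
except TypeError as err:
    print("BrokenSub:", err)

fixed = FixedSub(1)
print(fixed.b, fixed.c)  # 2 3
```

If this is indeed the cause, writing super(ChexSub, self).__post_init__() or ChexBase.__post_init__(self) in the original example should avoid the error. This has not been tested against chex itself.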
Environment
Hi,
First of all, thank you for this very useful library!
I have a project in JAX for which I have already implemented my tests using pytest. However, the possibilities that chex.variants offers are too nice to ignore. At the same time, I would prefer not to rewrite all my tests.
Is there a way to reconcile pytest and chex ?
Thank you again for all the work!
Best,
@hbq1 The previous release version was 0.1.8 and the current one is 0.1.81, not 0.1.9 as one would expect. Why is there a gap between versions? Might it be a typo/mistake by any chance?
Hi chex team,
Is there any potential for something like flax.struct.dataclass in chex? Basically, a kind of dataclass that can mark static arguments.
Two other variations are jax_dataclasses and simple-pytree, though an official chex one would be nice, with all the benefits of being part of chex.
Thank you!
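The mechanism such libraries rely on can be sketched with the standard library alone. A field is tagged through dataclasses.field metadata, and a flatten/unflatten pair moves the tagged fields into the auxiliary (static) data. Everything here (static_field, flatten, unflatten, ScanConfig) is hypothetical and not part of chex. The pair follows the (children, aux_data) contract of jax.tree_util.register_pytree_node, which is how it could be registered with JAX.

```python
import dataclasses
from typing import Any, Tuple


def static_field(**kwargs):
    """Hypothetical helper: marks a dataclass field as static (non-pytree)."""
    metadata = dict(kwargs.pop("metadata", None) or {})
    metadata["static"] = True
    return dataclasses.field(metadata=metadata, **kwargs)


def flatten(obj) -> Tuple[Tuple[Any, ...], Tuple[Any, ...]]:
    """Splits a dataclass instance into (children, aux_data).

    Children are the dynamic fields a pytree library would traverse; aux_data
    holds the static fields, which would become part of the tree structure.
    """
    children, aux = [], []
    for f in dataclasses.fields(obj):
        (aux if f.metadata.get("static") else children).append(getattr(obj, f.name))
    return tuple(children), tuple(aux)


def unflatten(cls, aux, children):
    """Inverse of `flatten` (assumes every field is an __init__ argument)."""
    dynamic, static = iter(children), iter(aux)
    kwargs = {f.name: next(static) if f.metadata.get("static") else next(dynamic)
              for f in dataclasses.fields(cls)}
    return cls(**kwargs)


@dataclasses.dataclass(frozen=True)
class ScanConfig:
    x: float
    length: int = static_field(default=10)


cfg = ScanConfig(x=1.0, length=10)
children, aux = flatten(cfg)
print(children, aux)  # (1.0,) (10,)
print(unflatten(ScanConfig, aux, children) == cfg)  # True
```

With JAX installed, one could register it via jax.tree_util.register_pytree_node(ScanConfig, flatten, functools.partial(unflatten, ScanConfig)). Static values should then be hashable, since they become part of the tree structure that jit caches on.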
Can you update chex to work with the newest version of jax? The newest jax release removed the following from jax.random: PRNGKeyArray, KeyArray, default_prng_impl, threefry_2x32, threefry2x32_key, threefry2x32_p, rbg_key, and unsafe_rbg_key.
This completely breaks the import of chex and makes many packages unusable.
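Until chex is updated, or jax is pinned to a version that still has these names, downstream code that referenced the removed attributes can resolve them defensively. The helper below is a hypothetical sketch, exercised on types.SimpleNamespace stand-ins rather than on jax itself. With real JAX, one might write resolve_attr(jax.random, ["KeyArray", "PRNGKeyArray"], jax.Array). That assumes jax.Array is the recommended type for PRNG keys, which you should check against your JAX version.

```python
import types


def resolve_attr(module, names, default):
    """Returns the first attribute in `names` that exists on `module`.

    Hypothetical compatibility helper: falls back to `default` when every
    name has been removed (e.g. after an upstream deprecation).
    """
    for name in names:
        if hasattr(module, name):
            return getattr(module, name)
    return default


# Stand-ins for an older and a newer version of a module.
old_random = types.SimpleNamespace(KeyArray="old KeyArray type")
new_random = types.SimpleNamespace()

print(resolve_attr(old_random, ["KeyArray", "PRNGKeyArray"], "fallback"))  # old KeyArray type
print(resolve_attr(new_random, ["KeyArray", "PRNGKeyArray"], "fallback"))  # fallback
```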
Static type checkers like pyright and mypy think that a chex.dataclass-decorated class has a constructor with no parameters.
Example:
@chex.dataclass(frozen=True)
class Foo:
    a: int
    b: int
However, Foo() is not a valid call; a legitimate call would be something like Foo(a=1, b=2). This does not agree with the static type checkers' analysis.
Compare the behavior with the built-in dataclass:
The chex.assert_max_traces decorator still raises an AssertionError when assertions have been disabled with chex.disable_asserts().
Code to reproduce:
import jax
import jax.numpy as jnp
import chex

chex.disable_asserts()

@jax.jit
@chex.assert_max_traces(n=1)
def f(x):
    return x

chex.assert_equal_shape(jnp.zeros((1,)), jnp.zeros((2,)))  # correctly ignored
f(jnp.zeros((1,)))
f(jnp.zeros((2,)))  # AssertionError: [Chex] Function 'f' is traced > 1 times!
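The intended behavior can be sketched in pure Python. The hypothetical assert_max_calls decorator below counts executions of the wrapped Python body. Under jax.jit the body only runs while tracing, so this approximates a trace count. Before failing, it consults a module-level flag that stands in for chex.disable_asserts(). It is not chex's implementation.

```python
import functools

# Hypothetical module-level switch, standing in for chex.disable_asserts().
_ASSERTS_ENABLED = True


def disable_asserts():
    global _ASSERTS_ENABLED
    _ASSERTS_ENABLED = False


def enable_asserts():
    global _ASSERTS_ENABLED
    _ASSERTS_ENABLED = True


def assert_max_calls(n):
    """Fails when the wrapped Python body runs more than n times.

    The enable flag is read on every call, so disabled assertions are
    honoured regardless of when the decorator was applied.
    """
    def decorator(fn):
        count = 0

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            nonlocal count
            count += 1
            if _ASSERTS_ENABLED and count > n:
                raise AssertionError(f"Function {fn.__name__!r} is traced > {n} times!")
            return fn(*args, **kwargs)
        return wrapper
    return decorator


disable_asserts()


@assert_max_calls(n=1)
def f(x):
    return x


f(1)
f(2)  # no error: assertions are disabled
```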
Applying the chex.dataclass decorator to a class makes pyright report the following error:
import chex

@chex.dataclass(frozen=True)
class Class:
    x: int

Class(4)
$ pyright --version
pyright 1.1.337
$ python3 --version
Python 3.11.6
$ python3 -c "import chex; print(chex.__version__)"
0.1.85
$ pyright test.py
/Users/user/Desktop/test.py
/Users/user/Desktop/test.py:7:1 - error: Expected no arguments to "Class" constructor (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations
microsoft/pyright#6536 (comment):
This is a bug in the chex library. The chex.dataclass decorator has no type annotations, despite the fact that the package contains a "py.typed" marker file.
microsoft/pyright#6536 (comment):
I recommend looking at the stdlib dataclass definition in the typeshed dataclasses.pyi stub.
https://github.com/python/typeshed/blob/main/stdlib/dataclasses.pyi
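The standard way to tell pyright and mypy that a decorator behaves like dataclasses.dataclass is typing.dataclass_transform (PEP 681; Python 3.11+, or typing_extensions on older versions). The sketch below applies it to a hypothetical typed_dataclass wrapper around the stdlib decorator. chex would need to do something similar for its own decorator, or ship a stub like the typeshed one referenced above. Whether a given chex version does so is not established here.

```python
import dataclasses
from typing import Callable, Optional, TypeVar, overload

try:  # Python 3.11+
    from typing import dataclass_transform
except ImportError:  # older Pythons: requires the typing_extensions package
    from typing_extensions import dataclass_transform

T = TypeVar("T")


@overload
def typed_dataclass(cls: type[T], /) -> type[T]: ...
@overload
def typed_dataclass(*, frozen: bool = False) -> Callable[[type[T]], type[T]]: ...


@dataclass_transform()
def typed_dataclass(cls: Optional[type] = None, /, *, frozen: bool = False):
    """Hypothetical dataclass-like decorator that type checkers understand.

    dataclass_transform tells pyright/mypy to synthesize __init__ from the
    annotated fields, as they already do for dataclasses.dataclass.
    """
    def wrap(c):
        return dataclasses.dataclass(c, frozen=frozen)
    return wrap if cls is None else wrap(cls)


@typed_dataclass(frozen=True)
class Foo:
    a: int
    b: int


foo = Foo(a=1, b=2)  # accepted by type checkers; Foo() would be flagged
print(foo)  # Foo(a=1, b=2)
```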
I spent quite some time figuring out why code in a large codebase was so slow, only to find out that jit was disabled throughout the entire project. This was because the main function was called as follows:
with chex.fake_pmap_and_jit(FLAGS.debug):
    main()
At first sight it appears as if this disables both pmap and jit when the debug flag is set, but in fact it disables pmap only when debug is set and always disables jit!
The reason is that fake_pmap_and_jit takes two positional arguments that disable pmap and jit respectively, and both are True by default. The names of these arguments are also somewhat cryptic to me: enable_pmap_patching and enable_jit_patching, which actually disable these JAX transformations.
Given these observations, I think the situation would improve if the signature were:
def fake_pmap_and_jit(*, disable_pmap: bool = True, disable_jit: bool = True)
My code above would then look like this:
with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug):
    main()
which clearly shows that we are not setting disable_jit, so we would rewrite it as:
with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug, disable_jit=FLAGS.debug):
    main()
Chex dataclasses can be pickled with pickle, but not with dill:
import dill
import pickle
import chex

@chex.dataclass
class Point:
    x: float
    y: float

# Works fine.
pickle.dumps(Point(x=1.0, y=2.0))
# Raises `RecursionError: maximum recursion depth exceeded`.
dill.dumps(Point(x=1.0, y=2.0))
To start with, thanks for open sourcing your work on Chex, it's a great tooling library for building robust Jax applications!
As I was upgrading to the latest release, 0.0.3, I noticed quite a few of my tests breaking. It turns out that the default option mappable_dataclass=True in chex.dataclass breaks the usual interface of dataclasses (which is clearly expected when reading the code documentation!).
I guess from the perspective of DeepMind's usage it makes sense to enable this option by default. But from an external user's point of view, it is rather surprising to have a dataclass decorator that does not behave like a dataclass. I think it would be great to state clearly in the library README that this option needs to be turned off to get the full dataclass behaviour (or to turn it off by default).
import chex

def print_pytree(pytree: chex.ArrayTree):
    print(pytree)

def main(argv):
    # does not work
    # tree_x = (-1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0)
    # works
    tree_x = [-1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0]
    print_pytree(tree_x)
On installing the latest version of Chex directly from GitHub, the following error pops up when importing chex:
File ".../anaconda3/lib/python3.8/site-packages/chex/__init__.py", line 17, in <module>
from chex._src.asserts import assert_axis_dimension
File "...anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py", line 26, in <module>
from chex._src import asserts_internal as _ai
File ".../anaconda3/lib/python3.8/site-packages/chex/_src/asserts_internal.py", line 34, in <module>
from chex._src import pytypes
File ".../lib/python3.8/site-packages/chex/_src/pytypes.py", line 19, in <module>
import jax.numpy as jnp
ModuleNotFoundError: No module named 'jax.numpy'
Jax: v0.3.34
jaxlib: v0.1.69
Hi,
Thanks for sharing this nice package! I especially like all the assertions for pytrees.
However, I came across the following inconsistency in the documentation:
Current situation
According to the docs, the shape_prefix argument of assert_tree_shape_prefix is of type Sequence[int].
However, when I pass a sequence such as a list (instead of a tuple):
import chex
import jax.numpy as jnp
mytree = {'a': jnp.array([[[1], [2]]])}
chex.assert_tree_shape_prefix(mytree, shape_prefix=[1, 2]) # AssertionError!
The assertion raises an exception:
AssertionError: [Chex] Assertion assert_tree_shape_prefix failed: Tree leaf 'a' has a shape prefix different from expected: (1, 2) != [1, 2].
The error can simply be fixed by using a tuple instead:
chex.assert_tree_shape_prefix(mytree, shape_prefix=(1, 2)) # OK!
I suspect this is not the only inconsistent function, but I did not check the others.
Desired situation
Ideally, I would like the function to behave as documented, so that I can also pass a list. Why? I think being able to use square brackets (i.e., a list) right before a closing parenthesis helps readability.
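The error message, (1, 2) != [1, 2], suggests that the shapes are compared with ==. In Python, a tuple never compares equal to a list with the same elements. Below is a minimal sketch of a list-tolerant check, assuming it is acceptable to normalize the prefix with tuple(). has_shape_prefix is a hypothetical helper, not chex's code.

```python
from typing import Sequence


def has_shape_prefix(shape: Sequence[int], prefix: Sequence[int]) -> bool:
    """Hypothetical check that accepts any sequence type for `prefix`.

    Both sides are converted to tuples before comparing, because
    (1, 2) == [1, 2] is False in Python.
    """
    prefix = tuple(prefix)
    return tuple(shape[:len(prefix)]) == prefix


shape = (1, 2, 1)  # shape of jnp.array([[[1], [2]]])
print(has_shape_prefix(shape, [1, 2]))  # True
print(has_shape_prefix(shape, (1, 2)))  # True
print((1, 2) == [1, 2])                 # False: the root of the mismatch
```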
I would be interested to hear your opinion and I am happy to contribute a pull request.
Keep up the good work!
Hylke
Hi,
I would like to flag that the citation link is broken. It is also broken on the general DeepMind Jax ecosystem page.
According to the docs, by default a class wrapped with chex.dataclass can be indexed, because the dataclass becomes compatible with collections.abc.Mapping (since mappable_dataclass=True by default).
However, mypy doesn't seem to understand this. For example:
import chex

@chex.dataclass
class Container:
    foo: float

c = Container(foo=1.)
d = c.foo  # OK.
e = c['foo']  # error: Value of type "Container" is not indexable [index]
Looking at the code, it seems that this is related to methods such as __getitem__ being added dynamically with setattr, which mypy doesn't recognise.
Any ideas how to work around this, apart from (i) explicitly silencing the error or (ii) using a different way of accessing the fields?
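One workaround is to declare the method under typing.TYPE_CHECKING inside the class body. It is sketched here with the standard library and a hypothetical add_getitem decorator that stands in for chex's dynamic setattr. mypy treats that block as always true and sees the signature. At runtime the block is skipped and the decorator's implementation is used. Whether this interacts cleanly with chex.dataclass has not been verified here.

```python
import dataclasses
from typing import TYPE_CHECKING, Any


def add_getitem(cls):
    """Hypothetical stand-in for a decorator that adds __getitem__ via setattr."""
    def __getitem__(self, key):
        return getattr(self, key)
    setattr(cls, "__getitem__", __getitem__)
    return cls


@add_getitem
@dataclasses.dataclass
class Container:
    foo: float

    if TYPE_CHECKING:
        # Seen only by static type checkers; at runtime the decorator
        # supplies the real implementation.
        def __getitem__(self, key: str) -> Any: ...


c = Container(foo=1.0)
e = c["foo"]  # mypy no longer reports "not indexable"
print(e)  # 1.0
```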
Keep up the good work!
Hylke
Apologies if this is the wrong place to post this.
The problem is that the dependencies for the chex package on conda-forge are incorrect: there is a typo in the required jax version.
conda search chex==0.1.7 --channel conda-forge --info
returns
chex 0.1.7 pyhd8ed1ab_0
-----------------------
file name : chex-0.1.7-pyhd8ed1ab_0.conda
name : chex
version : 0.1.7
build : pyhd8ed1ab_0
build number: 0
size : 70 KB
license : Apache-2.0
subdir : noarch
url : https://conda.anaconda.org/conda-forge/noarch/chex-0.1.7-pyhd8ed1ab_0.conda
md5 : 7d643a09cac375aab18872f92db3b78c
timestamp : 2023-03-27 14:01:47 UTC
dependencies:
- absl-py >=0.9.0
- dm-tree >=0.1.5
- jax >=0.1.55
- jaxlib >=0.1.37
- numpy >=1.18.0
- python >=3.6
- toolz >=0.9.0
- typing_extensions >=4.2.0
But the jax requirement should be >=0.4.6.
from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex

def f(carry, _):
    return carry + 1.0, None

@jit
def do_scan(c):
    final, _ = scan(f, c.x, None, c.y)
    return final

@dataclass
class C:
    x: RealNumeric
    y: IntegralNumeric = field(static=True)

print(do_scan(C(1.0, 10)))  # works

@chex.dataclass
class D:
    x: RealNumeric
    y: IntegralNumeric

print(do_scan(D(x=1.0, y=10)))  # fails
ImportError Traceback (most recent call last)
Cell In[27], line 20
18 # Import model definition from big_vision
19 from big_vision.models.proj.paligemma import paligemma
---> 20 from big_vision.trainers.proj.paligemma import predict_fns
22 # Import big vision utilities
23 import big_vision.datasets.jsonl
File /kaggle/working/big_vision_repo/big_vision/trainers/proj/paligemma/predict_fns.py:20
17 import functools
19 from big_vision.pp import registry
---> 20 import big_vision.utils as u
21 import einops
22 import jax
File /kaggle/working/big_vision_repo/big_vision/utils.py:38
36 import flax.jax_utils as flax_utils
37 import jax
---> 38 from jax.experimental.array_serialization import serialization as array_serial
39 import jax.numpy as jnp
40 import ml_collections as mlc
File /opt/conda/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py:36
34 from jax._src import sharding
35 from jax._src import sharding_impls
---> 36 from jax._src.layout import Layout, DeviceLocalLayout as DLL
37 from jax._src import typing
38 from jax._src import util
ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)
Hi,
Thanks for making this awesome library!
Is it possible to mark certain fields in a chex.dataclass definition so that they are excluded from the pytree? This feature is supported in flax (https://flax.readthedocs.io/en/latest/_modules/flax/struct.html#dataclass), which I found quite useful when defining dataclasses with fields (such as JAX functions) that shouldn't be mapped over with dm-tree or jax.tree_map. I am not sure whether chex supports this out of the box at the moment, but it is something I hope will become part of chex.
Trying to import optax raises AttributeError: module 'jax' has no attribute '_src' for jax versions > 0.3.17.
optax version == 0.1.3
chex version == 0.1.3
In [1]: import optax
/home/penn/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
PyTreeDef = type(jax.tree_structure(None))
---------------------------------------------------------------------------
AttributeError Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import optax
File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/__init__.py:17, in <module>
1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ==============================================================================
15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import experimental
18 from optax._src.alias import adabelief
19 from optax._src.alias import adafactor
File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/experimental/__init__.py:20, in <module>
1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ==============================================================================
15 """Experimental features in Optax.
16
17 Features may be removed or modified at any time.
18 """
---> 20 from optax._src.experimental.complex_valued import split_real_and_imaginary
21 from optax._src.experimental.complex_valued import SplitRealAndImaginaryState
File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/_src/experimental/complex_valued.py:32, in <module>
15 """Complex-valued optimization.
16
17 When using `split_real_and_imaginary` to wrap an optimizer, we split the complex
(...)
27 See details at https://github.com/deepmind/optax/issues/196
28 """
30 from typing import NamedTuple, Union
---> 32 import chex
33 import jax
34 import jax.numpy as jnp
File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/__init__.py:17, in <module>
1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
2 #
3 # Licensed under the Apache License, Version 2.0 (the "License");
(...)
13 # limitations under the License.
14 # ==============================================================================
15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
18 from chex._src.asserts import assert_axis_dimension_comparator
19 from chex._src.asserts import assert_axis_dimension_gt
File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts.py:26, in <module>
23 import unittest
24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
27 from chex._src import pytypes
28 import jax
File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts_internal.py:32, in <module>
29 from typing import Any, Sequence, Union, Callable, Optional, Set, Tuple, Type
31 from absl import logging
---> 32 from chex._src import pytypes
33 import jax
34 import jax.numpy as jnp
File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:44, in <module>
40 Device = jax.lib.xla_extension.Device
42 ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
---> 44 ArrayDType = jax._src.numpy.lax_numpy._ScalarMeta
AttributeError: module 'jax' has no attribute '_src'
is_leaf is handled correctly with NamedTuple:
from typing import Any, NamedTuple

import chex
import jax
import jax.numpy as jnp

class NT(NamedTuple):
    a: Any
    b: Any

class Histogram(NamedTuple):
    hist: jnp.ndarray
    bins: jnp.ndarray

def is_leaf(n):
    print(n)
    return isinstance(n, Histogram)

jax.tree_map(lambda x: x, NT(a=Histogram(1, 2), b=Histogram(4, 3)), is_leaf=is_leaf)
Output (as expected):
NT(a=Histogram(hist=1, bins=2), b=Histogram(hist=4, bins=3))
Histogram(hist=1, bins=2)
Histogram(hist=4, bins=3)
It does not work with chex.dataclass:
@chex.dataclass
class DC:
    a: Any
    b: Any

jax.tree_map(lambda x: x, DC(a=Histogram(1, 2), b=Histogram(4, 3)), is_leaf=is_leaf)
Actual output:
DC(a=Histogram(hist=1, bins=2), b=Histogram(hist=4, bins=3))
1
2
4
3
I guess either setup.py should be updated to reflect that, or a version check could be added to keep supporting older versions of JAX.
Hello, I have a dilemma with chexify. Consider the following code:
import chex
import jax
import jax.numpy as jnp
import pytest

# If this is not commented out, the second test will fail
# If this is commented out, the first test will fail
@chex.chexify
@jax.jit
def log_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x > 0, jnp.ones_like(x, dtype=bool))
    return jnp.log(x)

@chex.chexify
@jax.jit
def combo_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x != 1, jnp.ones_like(x, dtype=bool))
    return log_safe(x) / (x - 1)

def test_log_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, -1.0])
    with pytest.raises(Exception):
        log_safe(x)
        log_safe.wait_checks()
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    assert jnp.array_equal(log_safe(x), jnp.log(x))
    log_safe.wait_checks()

def test_combo_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    with pytest.raises(Exception):
        combo_safe(x)
        combo_safe.wait_checks()
    x = jnp.array([2.0, 3.0, 4.0, 5.0])
    assert jnp.array_equal(combo_safe(x), jnp.log(x) / (x - 1))
    combo_safe.wait_checks()
If I comment out the first chexify, the test_log_safe test fails with RuntimeError: "Value assertions can only be called from functions wrapped with @chex.chexify. See the docs.", which makes sense to me. However, once I add the decorator back, the second test fails with RuntimeError: "Nested @chexify wrapping is disallowed. Make sure that you only wrap the function at the outermost level."
A hack for this simple scenario would be to make two versions of the function: log_safe without the chexify decorator, and log_safe_test = chex.chexify(log_safe), calling only the log_safe_test version in my tests. However, that solution is pretty clumsy, especially if I have a lot of these scenarios. In a codebase that is fully end-to-end JAX, all but the outermost function would require this hack. Would it be possible to allow nested chex.chexify, where subsequent applications of the decorator simply do nothing, or just raise a warning?
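The requested semantics can be sketched without JAX: only the outermost decorator is active, and nested applications defer to it. In this pure-Python model, the hypothetical checked decorator and value_check function stand in for chex.chexify and a chex value assertion. Checks run eagerly rather than asynchronously, and failures are collected per outermost call in a contextvars.ContextVar. This illustrates the proposal only and says nothing about how chexify is implemented.

```python
import contextvars
import functools

# Error list of the outermost active @checked call, or None outside any.
_ERRORS = contextvars.ContextVar("_errors", default=None)


def value_check(cond, msg):
    """Hypothetical value assertion that must run inside a @checked call."""
    errors = _ERRORS.get()
    if errors is None:
        raise RuntimeError("value_check must run inside a @checked function")
    if not cond:
        errors.append(msg)


def checked(fn):
    """Hypothetical decorator: nested applications defer to the outermost one."""
    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        if _ERRORS.get() is not None:
            # Already inside a checked call: run directly, the outer wrapper
            # will report any failures (a warnings.warn could go here).
            return fn(*args, **kwargs)
        errors = []
        token = _ERRORS.set(errors)
        try:
            out = fn(*args, **kwargs)
        finally:
            _ERRORS.reset(token)
        if errors:
            raise AssertionError("; ".join(errors))
        return out
    return wrapper


@checked
def inner(x):
    value_check(x > 0, "inner: x must be positive")
    return x * 2


@checked
def outer(x):
    value_check(x != 1, "outer: x must not be 1")
    return inner(x) + 1  # nested @checked call is allowed


print(inner(3))  # 6
print(outer(2))  # 5
try:
    outer(-1)
except AssertionError as err:
    print("AssertionError:", err)  # AssertionError: inner: x must be positive
```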
Example:
chex.chexify(functools.partial(fn, foo='bar'))
Error:
AttributeError: 'functools.partial' object has no attribute '__name__'
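The AttributeError arises because functools.partial objects carry no __name__ attribute. Below is a minimal sketch of a more tolerant name lookup, assuming the wrapper only needs a display name. fn_name is a hypothetical helper, not chex's code. It unwraps partials through their .func attribute. Note that functools.wraps itself already skips attributes the wrapped object lacks.

```python
import functools


def fn_name(fn) -> str:
    """Best-effort display name for any callable, including partials."""
    while isinstance(fn, functools.partial):
        fn = fn.func
    return getattr(fn, "__name__", type(fn).__name__)


def fn(x, foo=None):
    return x, foo


p = functools.partial(fn, foo="bar")
print(hasattr(p, "__name__"))  # False
print(fn_name(p))              # fn


# functools.wraps does not raise on the missing __name__; it just skips it.
@functools.wraps(p)
def wrapper(*args, **kwargs):
    return p(*args, **kwargs)


print(wrapper(1))  # (1, 'bar')
```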
WARNING:absl:[Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
I'm seeing the following test failure when running the test suite:
============================= test session starts ==============================
platform linux -- Python 3.10.7, pytest-7.1.3, pluggy-1.0.0
rootdir: /build/source
collected 548 items
chex/chex_test.py . [ 0%]
chex/_src/asserts_chexify_test.py ......F..... [ 2%]
chex/_src/asserts_internal_test.py .s.s......... [ 4%]
chex/_src/asserts_test.py ..s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s. [ 13%]
s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s..................... [ 26%]
........................................................................ [ 39%]
........................................................................ [ 52%]
................................. [ 58%]
chex/_src/dataclass_test.py ........................................... [ 66%]
chex/_src/dimensions_test.py ................. [ 69%]
chex/_src/fake_set_n_cpu_devices_test.py s [ 69%]
chex/_src/fake_test.py ................................ [ 75%]
chex/_src/restrict_backends_test.py ssssssssss [ 77%]
chex/_src/variants_test.py .....................s....s............s....s [ 85%]
..........................................................ssssssssssssss [ 98%]
sssssss [100%]
=================================== FAILURES ===================================
__________________ AssertsChexifyTest.test_uninspected_checks __________________
self = <chex._src.asserts_chexify_test.AssertsChexifyTest testMethod=test_uninspected_checks>
def test_uninspected_checks(self):
@jax.jit
def _pos_sum(x):
chex_value_assert_positive(x, custom_message='err_label')
return x.sum()
invalid_x = -jnp.ones(3)
chexify_async(_pos_sum)(invalid_x) # async error
> with self.assertRaisesRegex(AssertionError, 'err_label'):
E AssertionError: AssertionError not raised
chex/_src/asserts_chexify_test.py:179: AssertionError
------------------------------ Captured log call -------------------------------
WARNING absl:asserts_chexify.py:57 [Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
=============================== warnings summary ===============================
chex/_src/asserts_chexify_test.py: 12 warnings
/build/source/chex/_src/asserts_chexify_test.py:58: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
return jnp.all(jnp.array([(x > 0).all() for x in jax.tree_leaves(tree)]))
chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__with_jit
chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__without_jit
/build/source/chex/_src/asserts_chexify_test.py:86: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
return sum(x.sum() for x in jax.tree_leaves(tree))
chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
/nix/store/4y9j6xdkgqwkdx5ki508l175smcjgs9l-python3.10-pytest-7.1.3/lib/python3.10/site-packages/_pytest/unraisableexception.py:78: PytestUnraisableExceptionWarning: Exception ignored in atexit callback: <function _check_if_hanging_assertions at 0x7ffddfe66d40>
Traceback (most recent call last):
File "/build/source/chex/_src/asserts_chexify.py", line 32, in _check_error
checkify.check_error(err)
File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 476, in check_error
return assert_p.bind(err, code, payload, msgs=error.msgs)
File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 328, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 331, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 698, in process_primitive
return primitive.impl(*tracers, **params)
File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 483, in assert_impl
raise_error(Error(err, code, msgs, payload))
File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 123, in raise_error
raise ValueError(err)
ValueError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] (check failed at /build/source/chex/_src/asserts_internal.py:229 (_chex_assert_fn))
During handling of the above exception, another exception occurred:
Traceback (most recent call last):
File "/build/source/chex/_src/asserts_chexify.py", line 62, in _check_if_hanging_assertions
block_until_chexify_assertions_complete()
File "/build/source/chex/_src/asserts_chexify.py", line 51, in block_until_chexify_assertions_complete
wait_fn()
File "/build/source/chex/_src/asserts_chexify.py", line 180, in _wait_checks
_check_error(async_check_futures.popleft().result(async_timeout))
File "/build/source/chex/_src/asserts_chexify.py", line 40, in _check_error
raise AssertionError(msg) # pylint:disable=raise-missing-from
AssertionError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173]
warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
/build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
if not all((x > 0).all() for x in jax.tree_leaves(tree)):
-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
error: builder for '/nix/store/f9icjsb9pbz4p8qpsyhp9gq1fvjvwwhz-python3.10-chex-0.1.5.drv' failed with exit code 1;
last 10 log lines:
> chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
> chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
> chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
> /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
> if not all((x > 0).all() for x in jax.tree_leaves(tree)):
>
> -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
> =========================== short test summary info ============================
> FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
> ====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
I'm using
I would like to propose an API enhancement that allows the use of chex.Dimensions inside function annotations. If there is interest, I'd like to contribute. Example below:
dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
def foo(arr: chex.Array):
    chex.assert_shape(arr, dims['BTE'])
    # fn logic

### turns into ###
def foo(arr: chex.Array(dims['BTE'])):  # behind the scenes assert on function call
    # fn logic
This is particularly useful for dataclasses, e.g.:
dims = chex.Dimensions(B=batch_size, T=rollout_len)

# asserts are run on instantiation
class TimeStep:
    q_values: chex.Array(dims['BT'])
    discounts: chex.Array(dims['BT'])
    rewards: chex.Array(dims['BT'])
Pros:
Cons:
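Something close to this proposal can be prototyped with the standard library's typing.Annotated instead of making chex.Array callable. That is a swapped-in technique, not an existing chex API. In the sketch below, Shaped and check_shapes are hypothetical, a plain dict stands in for chex.Dimensions, and types.SimpleNamespace objects with a shape attribute stand in for arrays.

```python
import functools
import inspect
import types
from typing import Annotated, Any, Dict, get_args, get_origin, get_type_hints


def Shaped(spec: str):
    """Hypothetical annotation: an array whose dims are named by `spec`."""
    return Annotated[Any, ("shape", spec)]


def check_shapes(dims: Dict[str, int]):
    """Hypothetical decorator: asserts annotated argument shapes per call."""
    def decorator(fn):
        hints = get_type_hints(fn, include_extras=True)
        sig = inspect.signature(fn)

        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            bound = sig.bind(*args, **kwargs)
            for name, value in bound.arguments.items():
                hint = hints.get(name)
                if get_origin(hint) is not Annotated:
                    continue  # unannotated or plain-typed argument
                for meta in get_args(hint)[1:]:
                    if isinstance(meta, tuple) and meta[0] == "shape":
                        expected = tuple(dims[d] for d in meta[1])
                        actual = tuple(value.shape)
                        if actual != expected:
                            raise AssertionError(
                                f"{name}: shape {actual} != {expected} ({meta[1]})")
            return fn(*args, **kwargs)
        return wrapper
    return decorator


dims = {"B": 2, "T": 3, "E": 4}  # stand-in for chex.Dimensions(B=2, T=3, E=4)


@check_shapes(dims)
def foo(arr: Shaped("BTE")):
    return arr


foo(types.SimpleNamespace(shape=(2, 3, 4)))  # passes
try:
    foo(types.SimpleNamespace(shape=(2, 3, 5)))
except AssertionError as err:
    print(err)  # arr: shape (2, 3, 5) != (2, 3, 4) (BTE)
```

For the dataclass case, the same annotation lookup could run in __post_init__ over the class's fields.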
Test files are not filtered out properly. The issue is that setuptools.find_packages finds packages, not modules, while the tests are organized as separate modules, so exclude=['*_test.py'] has no effect on them. To mitigate the issue, the test files can be filtered out manually, as follows. This patch was created and tested on chex v0.1.86.
--- a/setup.py 2024-03-19 12:58:11.000000000 +0300
+++ b/setup.py 2024-04-17 15:19:33.593293889 +0300
@@ -15,8 +15,11 @@
"""Install script for setuptools."""
import os
-from setuptools import find_packages
-from setuptools import setup
+from pathlib import Path
+
+from setuptools import find_packages, setup
+from setuptools.command.build_py import build_py as _build_py
+
_CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
@@ -40,6 +43,15 @@
]
+class build_py(_build_py):
+
+ def find_package_modules(self, package, package_dir):
+ modules = super().find_package_modules(package, package_dir)
+ return [(pkg, mod, file)
+ for pkg, mod, file in modules
+ if not Path(file).match('**/*_test.py')]
+
+
setup(
name='chex',
version=_get_version(),
@@ -51,7 +63,7 @@
long_description_content_type='text/markdown',
author_email='[email protected]',
keywords='jax testing debugging python machine learning',
- packages=find_packages(exclude=['*_test.py']),
+ packages=find_packages(),
install_requires=_parse_requirements(
os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')),
tests_require=_parse_requirements(
@@ -73,4 +85,5 @@
'Topic :: Software Development :: Testing :: Unit',
'Topic :: Software Development :: Libraries :: Python Modules',
],
+ cmdclass={'build_py': build_py},
)
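The filtering condition in the patch can be checked in isolation. The sketch below copies the list comprehension from the patched `find_package_modules` and runs it on hypothetical sample tuples in the `(package, module, file)` form that `build_py.find_package_modules` returns. It uses `PurePosixPath` instead of `Path` only so the result does not depend on the host OS. By contrast, `find_packages(exclude=[...])` matches its patterns against package names such as `chex._src`, which is why `'*_test.py'` never excluded anything there.

```python
from pathlib import PurePosixPath


def drop_test_modules(modules):
    """Keep only (package, module, file) tuples whose file is not a *_test.py module."""
    return [(pkg, mod, file)
            for pkg, mod, file in modules
            if not PurePosixPath(file).match('**/*_test.py')]


# Hypothetical sample of what find_package_modules might yield for chex._src.
modules = [
    ('chex._src', 'asserts', 'chex/_src/asserts.py'),
    ('chex._src', 'asserts_test', 'chex/_src/asserts_test.py'),
]
print(drop_test_modules(modules))  # only the asserts.py entry remains
```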
/home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/toolz/compatibility.py:2: DeprecationWarning: The toolz.compatibility module is no longer needed in Python 3 and has been deprecated. Please import these utilities directly from the standard library. This module will be removed in a future release.
warnings.warn("The toolz.compatibility module is no longer "
$ python --version
Python 3.9.1
In [1]: import chex
In [2]: @chex.dataclass
...: class Parameters:
...: x: chex.ArrayDevice
...: y: chex.ArrayDevice
...:
...: parameters = Parameters(
...: x=jnp.ones((2, 2)),
...: y=jnp.ones((1, 2)),
...: )
...:
...: # Dataclasses can be treated as JAX pytrees
...: jax.tree_map(lambda x: 2.0 * x, parameters)
...:
...: # and as mappings by dm-tree
...: tree.flatten(parameters)
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-3-3461a2700932> in <module>
1 @chex.dataclass
----> 2 class Parameters:
3 x: chex.ArrayDevice
4 y: chex.ArrayDevice
5
~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in dataclass(cls, init, repr, eq, order, unsafe_hash, frozen, mappable_dataclass, restricted_inheritance)
104 if cls is None:
105 return dcls
--> 106 return dcls(cls)
107
108
~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in __call__(self, cls)
147
148 if self.mappable_dataclass:
--> 149 dcls = mappable_dataclass(dcls, self.restricted_inheritance)
150
151 def _from_tuple(args):
~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in mappable_dataclass(cls, restricted_inheritance)
81 if cls.__bases__ == (object,):
82 # `collections.Mapping` is incompatible with `object`
---> 83 cls.__bases__ = (collections.Mapping,)
84 else:
85 cls.__bases__ += (collections.Mapping,)
TypeError: __bases__ assignment: 'Mapping' deallocator differs from 'object'
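The traceback shows that the failure comes from assigning `cls.__bases__ = (collections.Mapping,)` after the class already exists. (`collections.Mapping` is also a deprecated alias that Python 3.10 removed.) One alternative is to declare `collections.abc.Mapping` as a base when the class is created and to implement its three abstract methods from the dataclass fields. The sketch below uses only the standard library rather than chex, so it is an illustrative workaround, not chex's own fix. `Parameters` holds plain lists instead of JAX arrays so that it runs without JAX.

```python
import collections.abc
import dataclasses


@dataclasses.dataclass
class Parameters(collections.abc.Mapping):
    x: list
    y: list

    def __getitem__(self, key):
        # Look up a field by name; Mapping expects KeyError for unknown keys.
        if key not in [f.name for f in dataclasses.fields(self)]:
            raise KeyError(key)
        return getattr(self, key)

    def __iter__(self):
        # Iterate over field names in declaration order.
        return (f.name for f in dataclasses.fields(self))

    def __len__(self):
        return len(dataclasses.fields(self))


p = Parameters(x=[1.0, 1.0], y=[2.0])
print(dict(p))  # {'x': [1.0, 1.0], 'y': [2.0]}
```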
Hi,
I'm the lead developer of NetKet, an established machine learning / quantum physics package.
We have recently finished rewriting our core on top of JAX (and Flax), and we have just released a beta version.
Since many physicists seem to use anaconda, we would also like to update our conda recipe.
However, since we depend on optax (and therefore on Chex), we would need Chex to have a Conda recipe.
Is that something you'd consider? I am willing to volunteer some work to help you.
I tried creating a recipe starting from your PyPI source distribution, but that is problematic because you don't bundle your requirements.txt file, which is required to run setup.py.
I could create a recipe from the tag tarballs on GitHub, but that sometimes prevents the conda packages from auto-updating the recipe for later releases.
In my experience, when chex reports that the maximum number of traces was exceeded, it's usually because I passed parameters with different shapes or data types to the function. Is it possible for chex to report such inconsistencies?
e.g.,
AssertionError: [Chex] Function '_wrapper' is traced > 1 times!
Difference in input shapes. Last time variable `x` traced with shape "(10, 1)", this time traced with shape "(9, 1)".
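Until chex offers this, the idea can be sketched in plain Python. A hypothetical `report_retrace_causes` decorator records each named argument's shape and dtype, then lists which ones changed since the previous call, using a message modeled on the one requested above. It is not part of chex. The sketch assumes that, when wrapped inside `jax.jit` (for example `jax.jit(report_retrace_causes(g))`), the Python body only runs when JAX traces, and that tracers expose `.shape` and `.dtype`, so the report would describe what changed between traces. That pairing is untested here.

```python
import functools
import inspect
from types import SimpleNamespace


def _abstract_signature(value):
    """Summarise an argument by shape and dtype, the properties that force a retrace."""
    shape = getattr(value, "shape", None)
    if shape is None:
        # Non-array arguments: fall back to the Python type name.
        return f"type {type(value).__name__}"
    return f'shape "{tuple(shape)}", dtype {getattr(value, "dtype", None)}'


def report_retrace_causes(fn):
    """Hypothetical decorator: compare each named argument's abstract signature
    with the previous call and store the differences in `last_report`."""
    signature = inspect.signature(fn)
    previous = {}

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        current = {name: _abstract_signature(v) for name, v in bound.arguments.items()}
        wrapper.last_report = [
            f"Last time variable `{name}` traced with {previous[name]}, "
            f"this time traced with {current[name]}."
            for name in current
            if name in previous and previous[name] != current[name]
        ]
        previous.clear()
        previous.update(current)
        return fn(*args, **kwargs)

    wrapper.last_report = []
    return wrapper


@report_retrace_causes
def f(x):
    return x


# Stand-ins with .shape/.dtype so the sketch runs without JAX installed.
f(SimpleNamespace(shape=(10, 1), dtype="float32"))
f(SimpleNamespace(shape=(9, 1), dtype="float32"))
print(f.last_report)
```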
Hi,
The latest version (chex==0.1.7) is missing the typing-extensions dependency.
Problem
For example, the following python code
from chex import dataclass
raises the error
File "<string>", line 1, in <module>
File "/usr/local/lib/python3.11/dist-packages/chex/__init__.py", line 72, in <module>
from chex._src.dataclass import dataclass
File "/usr/local/lib/python3.11/dist-packages/chex/_src/dataclass.py", line 23, in <module>
from typing_extensions import dataclass_transform # pytype: disable=not-supported-yet
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'typing_extensions'
on Ubuntu 23.04 (with Python 3.11).
Solution
I found that installing typing-extensions resolves the error:
pip3 install typing-extensions
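To confirm which imports are actually missing before reinstalling anything, a small standard-library helper can query the import system directly. `missing_modules` is a hypothetical name. Note that the module is imported as `typing_extensions`, while the PyPI distribution is named `typing-extensions`.

```python
import importlib.util


def missing_modules(names):
    """Return the top-level module names that cannot be found on the import path."""
    return [name for name in names if importlib.util.find_spec(name) is None]


# The import that fails in the traceback above.
print(missing_modules(["typing_extensions"]))
```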
Reproduce
The following code reproduces the error:
docker run -it ubuntu:23.04 \
bash -c \
"apt update && apt install --assume-yes python3-pip && pip3 install --break-system-packages chex && python3 -c 'from chex import dataclass'"
Hi there!
Please tell me how to get the version of chex installed on my Ubuntu (18.0).
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.82 requires numpy>=1.25.0, but you have numpy 1.24.3 which is incompatible.
Successfully installed numpy-1.24.3
Trying to get Diffusers set up: https://huggingface.co/docs/transformers/installation
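One way to print installed versions from Python is `importlib.metadata`, which is in the standard library from Python 3.8 onward. Running `pip show chex` on the command line reports the same information. The helper name below is hypothetical. Checking numpy alongside chex helps with conflicts like the `numpy>=1.25.0` one above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(distribution):
    """Return the installed version string of `distribution`, or None if absent."""
    try:
        return version(distribution)
    except PackageNotFoundError:
        return None


print("chex:", installed_version("chex"))
print("numpy:", installed_version("numpy"))
```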
Hello,
Seems like the newest version of jax (0.3.7) removed some classes that are used here in chex. Should chex upper bound the jax version? I see this conflicting code is not currently on the main branch -- alternatively, maybe a new release can be made?
See below:
import jax
import chex
chex.assert_trees_all_equal(jax.random.key(0), jax.random.key(0))
> RuntimeWarning: divide by zero encountered in equal val = comparison(x, y)
# Runs fine
chex.assert_trees_all_equal(jax.random.PRNGKey(0), jax.random.PRNGKey(0))
The problem is not jax per se, since key comparison works:
jax.random.key(0) == jax.random.key(1)
> Array(False, dtype=bool)
jax.random.key(0) == jax.random.key(0)
> Array(True, dtype=bool)
The _with_pmap function accepts static_argnums as a parameter, but not static_argnames. This is inconsistent with other variants, such as with_jit and with_device. Crucially, this prevents testing methods that require arguments to be passed by name (e.g., Distrax's Distribution.sample()).
More generally, it would be best if all variants accepted the same parameters where possible (i.e., where not specific to a single variant). I would also suggest checking all keys in **unused_kwargs against a list of allowed parameters (i.e., the union of the parameters of all variant functions) to prevent silent errors caused by, e.g., misspellings.
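The suggested check on `**unused_kwargs` could look roughly like the sketch below. It rejects any keyword outside an allowed set and uses `difflib` to suggest the closest valid name. The function name and the allowed set are hypothetical placeholders, not chex's actual variant parameters.

```python
import difflib

# Hypothetical union of parameters accepted by all variants (illustrative only).
ALLOWED_VARIANT_KWARGS = frozenset({"static_argnums", "static_argnames", "donate_argnums"})


def check_variant_kwargs(kwargs, allowed=ALLOWED_VARIANT_KWARGS):
    """Raise TypeError for keywords outside `allowed`, suggesting close matches."""
    unknown = sorted(set(kwargs) - set(allowed))
    if not unknown:
        return
    messages = []
    for name in unknown:
        close = difflib.get_close_matches(name, sorted(allowed), n=1)
        hint = f" (did you mean {close[0]!r}?)" if close else ""
        messages.append(f"{name!r}{hint}")
    raise TypeError("Unexpected variant keyword argument(s): " + ", ".join(messages))


try:
    check_variant_kwargs({"static_argnmes": ("key",)})  # misspelled on purpose
except TypeError as err:
    print(err)
```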
chex v0.1.83 is failing with the latest jax (0.4.19) as jax.core.Shape.
This issue has been fixed on the master branch of chex in this commit.
Could you please make a release so as to ship this fix?
Thank you very much!