chex's Issues

No attribute 'KeyArray' when importing chex

Hi,
When I try to import chex, I get the following error:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/asserts_internal.py", line 32, in <module>
    from chex._src import pytypes
  File "/home/chuaraym/cvpr/lib/python3.8/site-packages/chex/_src/pytypes.py", line 36, in <module>
    PRNGKey = jax.random.KeyArray
AttributeError: module 'jax.random' has no attribute 'KeyArray'

I am using jax version 0.2.14 and jaxlib 0.1.68. Did they cause the error?

Error with Pydantic

Hello!
I'm interested in using pydantic's recursive constructor / asdict functionality, but jax.jit-ed functions give the following error:

Argument '_Pydantic_OptimConfig_93971134241088(.. SOMETHING HERE...)' of type <class 'pydantic.dataclasses._Pydantic_OptimConfig_93971134241088'> is not a valid JAX type.

TypeError: non-default argument 'value' follows default argument

I'm trying to inherit from a dataclass whose parent has an optional id field, but I get TypeError: non-default argument 'value' follows default argument. This is the minimal code to reproduce it:

from typing import Optional
import chex

@chex.dataclass
class Base:
    idx: Optional[int] = None

@chex.dataclass
class Derived(Base):
    value: int

This is an issue with the dataclasses library rather than with chex.dataclass. However, according to https://stackoverflow.com/questions/51575931/class-inheritance-in-python-3-7-dataclasses#answer-69822584, it can be fixed by setting kw_only=True.

without_jit=True for already jitted functions

In most JAX-based implementations, jit is almost always used: if there is no reason not to use it, people take advantage of its speedup.

I noticed that @chex.variants(with_jit=True, without_jit=True) is a great way to assert the same behavior for both execution paths, as long as the variant is derived from a non-jitted function.

In the following example, I would expect to see "Tracing fn" four times in total: three times for the non-jitted variant and once for the initial jit compilation. In reality, test_variant_pre_jitted() is executed twice with the jitted fn, resulting in two tracer outputs.

import chex
from jax import jit


class VariantsTest(chex.TestCase):  # illustrative test class name

  @chex.variants(with_jit=True, without_jit=True)
  def test_variant_pre_jitted(self):
    @jit
    def fn(x, y):
      print("Tracing fn")
      return x + y

    var_fn = self.variant(fn)
    self.assertEqual(var_fn(1, 2), 3)
    self.assertEqual(var_fn(3, 4), 7)
    self.assertEqual(var_fn(5, 6), 11)

Of course, omitting @jit will lead to the expected behavior. However, when more complex implementations already make use of jit, variants do not make sense anymore, sadly.

My case is the latter and I only see the option of implementing a model-wide use_jit flag so that I can derive variants from non-jitted code. However, this makes the whole idea of variants rather obsolete altogether.

I'm aware this could well be a limitation of JAX and jit itself rather than chex. In that case, I think an error when jitted code is passed to variant() would make this more transparent.

`isinstance(None, (int, float, chex.Array))` raises error since `chex==0.1.7`

Hello,

Thanks for the useful package. I am hitting an error when using the latest chex. See reproduction instructions below.

pip install chex==0.1.5 && python -c "import chex; print(isinstance(None, (int, float, chex.Array)))"
False
pip install chex==0.1.7 && python -c "import chex; print(isinstance(None, (int, float, chex.Array)))"
Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/usr/lib/python3.8/typing.py", line 769, in __instancecheck__
    return self.__subclasscheck__(type(obj))
  File "/usr/lib/python3.8/typing.py", line 777, in __subclasscheck__
    raise TypeError("Subscripted generics cannot be used with"
TypeError: Subscripted generics cannot be used with class and instance checks

This causes an error in optax when using inject_hyperparams, which essentially performs the same isinstance(..., (int, float, chex.Array)) check, at:

https://github.com/deepmind/optax/blob/04768d252911d6af4d4d36361930ccd0a54f9160/optax/_src/schedule.py#L589
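Until the type alias itself is usable with isinstance, one possible caller-side workaround is to expand a typing.Union into the tuple of plain classes it contains. The sketch below uses a hypothetical ArrayLike alias built from NumPy types as a stand-in for chex.Array, whose exact definition varies between chex versions, and a hypothetical helper, union_members.

```python
import typing
from typing import Union

import numpy as np

# Hypothetical stand-in for chex.Array; the real alias differs by version.
ArrayLike = Union[np.ndarray, np.number, np.bool_]


def union_members(tp):
    """Return the plain classes inside a typing.Union, usable with isinstance."""
    if typing.get_origin(tp) is Union:
        # Keep only real classes; subscripted generics cannot be used with isinstance.
        return tuple(arg for arg in typing.get_args(tp) if isinstance(arg, type))
    return (tp,) if isinstance(tp, type) else ()


NUMERIC_TYPES = (int, float) + union_members(ArrayLike)
print(isinstance(None, NUMERIC_TYPES))  # False
```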

post_init error in inherited dataclass

When one dataclass inherits from another, Chex's dataclass does not allow a super() call in __post_init__. This works with Python's standard dataclasses module.

A minimal reproducible example:

from chex import dataclass


@dataclass
class ChexBase:
    a: int

    def __post_init__(self):
        self.b = self.a + 1


@dataclass
class ChexSub(ChexBase):
    a: int

    def __post_init__(self):
        super().__post_init__()
        self.c = self.a + 2


temp = ChexSub(a=1)
print(temp.b)

Importing dataclass from dataclasses runs without error and returns 2, as expected.

Environment

  • Chex version 0.1.5
  • Ubuntu 20.04
  • Python 3.9

Using variants with pytest

Hi,

First of all, thank you for this very useful library!

I have a JAX project in which I have already implemented my tests using pytest. However, the possibilities that chex.variants offers are too nice to ignore. At the same time, I would prefer not to rewrite all my tests.

Is there a way to reconcile pytest and chex ?

Thank you again for all the work!
Best,
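chex.variants is normally used with test classes derived from chex.TestCase. If the tests must stay plain pytest functions, one possible substitute is a parametrized pytest fixture that returns the transform to apply, mimicking with_jit/without_jit. This is a sketch, not a chex feature, and the names VARIANTS, variant, and test_add are hypothetical.

```python
import jax
import pytest

# Hypothetical mapping that mimics chex.variants(with_jit=True, without_jit=True).
VARIANTS = {
    "with_jit": jax.jit,
    "without_jit": lambda fn: fn,
}


@pytest.fixture(params=sorted(VARIANTS))
def variant(request):
    """Return each transform in turn; pytest runs dependent tests once per param."""
    return VARIANTS[request.param]


def test_add(variant):
    fn = variant(lambda x, y: x + y)
    assert fn(1, 2) == 3
    assert fn(3, 4) == 7
```

Existing test functions only need an extra `variant` argument and a call such as `variant(fn)`. Pytest then reports each test twice, once per parameter id.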

Strange gap in version

@hbq1 The previous release version was 0.1.8 and the current one is 0.1.81, not 0.1.9 as one would expect. Why is there a gap between versions? Might it be a typo/mistake by any chance?

Analog to `flax.struct.dataclass`

Hi chex team,

Is there any potential for something like flax.struct.dataclass in chex?

Basically a kind of dataclass that can mark static arguments.

Two other implementations are jax_dataclasses and simple-pytree, though an official chex version would be great, given the benefits of being part of chex.

Thank you!
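For illustration, a dataclass with static fields can be sketched on top of jax.tree_util.register_pytree_node. Dynamic fields become pytree children, while fields marked static are stored in the tree's auxiliary data, so they remain plain Python values under jit. The decorator pytree_dataclass and the helper static_field are hypothetical. They mimic the idea behind flax.struct.dataclass rather than any chex API.

```python
import dataclasses

import jax


def static_field(**kwargs):
    """Hypothetical helper: mark a dataclass field as static (not a pytree leaf)."""
    return dataclasses.field(metadata={"static": True}, **kwargs)


def pytree_dataclass(cls):
    """Hypothetical decorator: frozen dataclass registered as a JAX pytree."""
    cls = dataclasses.dataclass(frozen=True)(cls)
    fields = dataclasses.fields(cls)
    dynamic = [f.name for f in fields if not f.metadata.get("static", False)]
    static = [f.name for f in fields if f.metadata.get("static", False)]

    def flatten(obj):
        children = [getattr(obj, name) for name in dynamic]
        aux = tuple(getattr(obj, name) for name in static)  # must be hashable
        return children, aux

    def unflatten(aux, children):
        return cls(**dict(zip(dynamic, children)), **dict(zip(static, aux)))

    jax.tree_util.register_pytree_node(cls, flatten, unflatten)
    return cls


@pytree_dataclass
class C:
    x: float
    y: int = static_field()  # stays a Python int under jit


@jax.jit
def do_scan(c):
    # `c.y` is static, so it can be used as the scan length.
    final, _ = jax.lax.scan(lambda carry, _: (carry + 1.0, None), c.x, None, length=c.y)
    return final


print(do_scan(C(x=1.0, y=10)))  # 11.0
```

Because static values are part of the tree structure, changing `y` triggers a retrace, just as a change in a static argument would.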

Breaking for jax 0.4.24

Could you update chex to work with the newest version of jax?

The following were removed from jax.random: PRNGKeyArray, KeyArray, default_prng_impl, threefry_2x32, threefry2x32_key, threefry2x32_p, rbg_key, and unsafe_rbg_key.

This completely breaks importing chex and makes many packages unusable.
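On the library side, a version-tolerant alias avoids referencing the removed names. The minimal sketch below assumes the alias is only needed for annotations and isinstance checks. Recent JAX exposes jax.Array, which PRNG keys are instances of, and the fallback chain only reaches jax.random.KeyArray or jnp.ndarray on releases without jax.Array. This is not necessarily how chex itself resolved the issue.

```python
import jax
import jax.numpy as jnp

# Prefer jax.Array (available in recent JAX); fall back to the legacy alias,
# and finally to jnp.ndarray on very old releases.
PRNGKey = (
    getattr(jax, "Array", None)
    or getattr(jax.random, "KeyArray", None)
    or jnp.ndarray
)

key = jax.random.PRNGKey(0)
print(isinstance(key, PRNGKey))  # True
```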

Typing issue with chex.dataclass

Static type checkers such as pyright and mypy infer that a chex.dataclass-decorated class has a constructor with no parameters.

Example:

@chex.dataclass(frozen=True)
class Foo:
    a: int
    b: int

However, Foo() is not a valid call; a legitimate call would be something like Foo(a=1, b=2), so the static type checkers' analysis is wrong.

Compare this with the built-in dataclass, for which the checkers infer the constructor parameters correctly.

chex.disable_asserts() is ignored by chex.assert_max_traces

The assert_max_traces decorator still raises an AssertionError when chex is configured to disable assertions with chex.disable_asserts().

Code to reproduce:

import jax
import jax.numpy as jnp
import chex
chex.disable_asserts()

@jax.jit
@chex.assert_max_traces(n=1)
def f(x):
    return x

chex.assert_equal_shape(jnp.zeros((1,)), jnp.zeros((2,)))  # correctly ignored

f(jnp.zeros((1,)))
f(jnp.zeros((2,))) # AssertionError: [Chex] Function 'f' is traced > 1 times!

`chex.dataclass` wrapper causes type error: Expected no arguments to dataclass constructor

Applying the chex.dataclass decorator to a class makes pyright report the following error:

import chex

@chex.dataclass(frozen=True)
class Class:
    x: int

Class(4)
$ pyright --version
pyright 1.1.337
$ python3 --version
Python 3.11.6
$ python3 -c "import chex; print(chex.__version__)"
0.1.85
$ pyright test.py
/Users/user/Desktop/test.py
  /Users/user/Desktop/test.py:7:1 - error: Expected no arguments to "Class" constructor (reportGeneralTypeIssues)
1 error, 0 warnings, 0 informations 

microsoft/pyright#6536 (comment)

This is a bug in the chex library. The chex.dataclass decorator has no type annotations despite the fact that the package contains a "py.typed" marker file.

microsoft/pyright#6536 (comment)

I recommend looking at the stdlib dataclass class in the typeshed dataclass.pyi stub.

https://github.com/python/typeshed/blob/main/stdlib/dataclasses.pyi
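Following that suggestion, one possible direction is PEP 681's dataclass_transform. It tells static checkers that a decorator produces dataclass-like classes, so checkers that implement PEP 681 synthesize the __init__ signature from the annotated fields. The sketch below uses a hypothetical typed_dataclass wrapper that delegates to the standard dataclasses.dataclass purely as a stand-in. Wrapping chex's real decorator the same way is an assumption, not its current code.

```python
from __future__ import annotations

import dataclasses
from typing import Callable, Optional, TypeVar, overload

try:
    from typing import dataclass_transform  # Python 3.11+
except ImportError:  # older Pythons
    from typing_extensions import dataclass_transform

_T = TypeVar("_T")


@overload
def typed_dataclass(cls: type[_T]) -> type[_T]: ...
@overload
def typed_dataclass(*, frozen: bool = False) -> Callable[[type[_T]], type[_T]]: ...


@dataclass_transform()
def typed_dataclass(cls: Optional[type] = None, *, frozen: bool = False):
    """Hypothetical wrapper; the body stands in for chex.dataclass."""

    def wrap(c):
        return dataclasses.dataclass(frozen=frozen)(c)

    return wrap if cls is None else wrap(cls)


@typed_dataclass(frozen=True)
class Foo:
    a: int
    b: int


foo = Foo(a=1, b=2)  # PEP 681-aware checkers accept this call and flag Foo()
```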

fake_pmap_and_jit has a confusing interface

I spent quite some time figuring out why the code in a large codebase was so slow, only to find out that jit was disabled throughout the entire project. This was because the main function was called as follows:

with chex.fake_pmap_and_jit(FLAGS.debug):
  main()

At first sight, this appears to disable both pmap and jit when the debug flag is set, but in fact the flag only controls pmap, and jit is always disabled!

The reason is that fake_pmap_and_jit takes two positional arguments that control pmap and jit respectively, and both are True by default. The argument names, enable_pmap_patching and enable_jit_patching, are also somewhat cryptic to me, since enabling the patching actually disables these JAX transformations.

Given these observations, I think the situation would improve if the signature were:

def fake_pmap_and_jit(*, disable_pmap: bool = True, disable_jit: bool = True)

My code above would then look like this:

with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug):
  main()

Which shows clearly we are not setting disable_jit, so we would rewrite this to:

with chex.fake_pmap_and_jit(disable_pmap=FLAGS.debug, disable_jit=FLAGS.debug):
  main()

Dill pickling chex.dataclass blows the stack

Chex dataclasses can be pickled with pickle, but not with dill:

import dill
import pickle

import chex


@chex.dataclass
class Point:
  x: float
  y: float


# Works fine.
pickle.dumps(Point(x=1.0, y=2.0))

# Generates `RecursionError: maximum recursion depth exceeded`
dill.dumps(Point(x=1.0, y=2.0))

Chex dataclass defaulting mappable_dataclass=True

To start with, thanks for open sourcing your work on Chex, it's a great tooling library for building robust Jax applications!

As I was upgrading to the latest release, 0.0.3, I noticed quite a few of my tests breaking. It turns out that the default option mappable_dataclass=True in chex.dataclass breaks the usual dataclass interface (which is clearly intended, judging from the code documentation!).

I guess that from the perspective of DeepMind's usage, it makes sense to enable this option by default. But from an external user's point of view, it is rather surprising for a dataclass decorator not to behave like a dataclass. I think it would be great to state clearly in the library README that this option needs to be turned off to get the full dataclass behaviour (or to turn it off by default).

Tuple not recognized as a valid chex.ArrayTree

import chex


def print_pytree(pytree: chex.ArrayTree):
  print(pytree)


def main(argv):
  # does not work
  # tree_x = (-1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0)

  # works
  tree_x = [-1.0, {"k1": 1.0, "k2": (1.0, 1.0)}, 1.0]
  print_pytree(tree_x)

ModuleNotFoundError: No module named 'jax.numpy'

On installing the latest version of Chex directly from GitHub, the following error pops up when importing chex:

File ".../anaconda3/lib/python3.8/site-packages/chex/__init__.py", line 17, in <module>
    from chex._src.asserts import assert_axis_dimension
  File "...anaconda3/lib/python3.8/site-packages/chex/_src/asserts.py", line 26, in <module>
    from chex._src import asserts_internal as _ai
  File ".../anaconda3/lib/python3.8/site-packages/chex/_src/asserts_internal.py", line 34, in <module>
    from chex._src import pytypes
  File ".../lib/python3.8/site-packages/chex/_src/pytypes.py", line 19, in <module>
    import jax.numpy as jnp
ModuleNotFoundError: No module named 'jax.numpy'

Jax: v0.3.34
jaxlib: v0.1.69

`assert_tree_shape_prefix` requires tuple instead of `Sequence`

Hi,

Thanks for sharing this nice package! I especially like all the assertions for pytrees.
However, I came across the following inconsistency in the documentation:
Current situation
According to the docs, the shape_prefix argument of assert_tree_shape_prefix is of type Sequence[int].
However, when I pass another sequence type, such as a list, instead of a tuple:

import chex
import jax.numpy as jnp

mytree = {'a': jnp.array([[[1], [2]]])}
chex.assert_tree_shape_prefix(mytree, shape_prefix=[1, 2])  # AssertionError!

The assertion raises an exception:

AssertionError: [Chex] Assertion assert_tree_shape_prefix failed: Tree leaf 'a' has a shape prefix different from expected: (1, 2) != [1, 2].

The error can simply be fixed by using a tuple instead:

chex.assert_tree_shape_prefix(mytree, shape_prefix=(1, 2))  # OK!

I think this is not the only inconsistent function, but I did not check for others.

Desired situation
Ideally, I would like the function to behave as documented, so that I can also pass a list. Why? I think using square brackets (i.e., a list) right after a closing parenthesis helps readability.

I would be interested to hear your opinion and I am happy to contribute a pull request.

Keep up the good work!

Hylke
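The underlying fix is small: normalize the expected prefix to a tuple before comparing, because (1, 2) == [1, 2] is False in Python. The sketch below uses a hypothetical helper, has_shape_prefix, which is not chex's code. It only illustrates the comparison the assertion would need.

```python
from typing import Sequence

import numpy as np


def has_shape_prefix(x, prefix: Sequence[int]) -> bool:
    """Hypothetical check: does x.shape start with `prefix` (list or tuple)?"""
    prefix = tuple(prefix)  # (1, 2) == [1, 2] is False, so normalize first
    return tuple(np.shape(x)[: len(prefix)]) == prefix


leaf = np.array([[[1], [2]]])  # shape (1, 2, 1)
print(has_shape_prefix(leaf, [1, 2]), has_shape_prefix(leaf, (1, 2)))  # True True
```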

Mypy index type error with `chex.dataclass`

According to the docs, by default a class wrapped with chex.dataclass can be indexed, because the dataclass becomes compatible with collections.abc.Mapping (because mappable_dataclass=True).
However, mypy doesn't seem to understand this. For example:

import chex

@chex.dataclass
class Container:
    foo: float

c = Container(foo=1.)
d = c.foo  # OK.
e = c['foo']  # error: Value of type "Container" is not indexable  [index]

Looking at the code, this seems related to methods such as __getitem__ being added dynamically with setattr, which mypy doesn't recognise.
Any ideas on how to work around this, apart from (i) explicitly silencing the error or (ii) using a different way of accessing the fields?

Keep up the good work!
Hylke
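Besides the two options listed, another possible workaround is to declare the dynamically added method for the type checker only, inside an `if TYPE_CHECKING:` block, so nothing changes at runtime. The sketch uses a hypothetical mappable decorator that attaches __getitem__ with setattr as a stand-in for chex.dataclass. That the same declaration works unchanged inside a chex.dataclass body is an assumption and has not been checked against every chex version.

```python
import dataclasses
from typing import TYPE_CHECKING, Any


def mappable(cls):
    """Hypothetical stand-in: adds __getitem__ via setattr, which mypy cannot see."""
    cls = dataclasses.dataclass(cls)
    setattr(cls, "__getitem__", lambda self, key: getattr(self, key))
    return cls


@mappable
class Container:
    foo: float

    if TYPE_CHECKING:
        # Visible only to static checkers; never executed at runtime.
        def __getitem__(self, key: str) -> Any: ...


c = Container(foo=1.0)
e = c["foo"]  # mypy sees the declared __getitem__; runtime uses the injected one
print(e)  # 1.0
```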

conda package dependencies

Apologies if this is the wrong place to post this.

The problem is that the dependencies of the chex package on conda-forge are incorrect: there is a typo in the required jax version.

conda search chex==0.1.7 --channel conda-forge --info

returns

chex 0.1.7 pyhd8ed1ab_0
-----------------------
file name   : chex-0.1.7-pyhd8ed1ab_0.conda
name        : chex
version     : 0.1.7
build       : pyhd8ed1ab_0
build number: 0
size        : 70 KB
license     : Apache-2.0
subdir      : noarch
url         : https://conda.anaconda.org/conda-forge/noarch/chex-0.1.7-pyhd8ed1ab_0.conda
md5         : 7d643a09cac375aab18872f92db3b78c
timestamp   : 2023-03-27 14:01:47 UTC
dependencies: 
  - absl-py >=0.9.0
  - dm-tree >=0.1.5
  - jax >=0.1.55
  - jaxlib >=0.1.37
  - numpy >=1.18.0
  - python >=3.6
  - toolz >=0.9.0
  - typing_extensions >=4.2.0

But the jax version should be >=0.4.6

Consider supporting static attributes in chex.dataclass

from jax import jit
from jax.lax import scan
from tjax import IntegralNumeric, RealNumeric
from tjax.dataclasses import dataclass, field
import chex

def f(carry, _):
  return carry + 1.0, None

@jit
def do_scan(c):
  final, _ = scan(f, c.x, None, c.y)
  return final

@dataclass
class C:
  x: RealNumeric
  y: IntegralNumeric = field(static=True)

print(do_scan(C(1.0, 10)))  # works

@chex.dataclass
class D:
  x: RealNumeric
  y: IntegralNumeric

print(do_scan(D(x=1.0, y=10)))  # fails
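The call fails because, under jit, every chex.dataclass field becomes a tracer, but scan needs a concrete length. Until chex.dataclass supports static fields, one workaround is to keep the scan length out of the traced arguments and mark it static at the jit boundary with static_argnames. This is a different technique from a static dataclass field. Here is a minimal, self-contained sketch:

```python
from functools import partial

import jax
from jax.lax import scan


def f(carry, _):
    return carry + 1.0, None


@partial(jax.jit, static_argnames="length")
def do_scan(x, length):
    # `length` is a static Python int, so scan can use it as the trip count.
    final, _ = scan(f, x, None, length=length)
    return final


print(do_scan(1.0, length=10))  # 11.0
```

With the chex dataclass D from the report, the call could become `do_scan(d.x, length=d.y)`, since `d.y` is an ordinary Python int outside jit.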

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)

/tmp/ipykernel_34/2874194604.py:15: DeprecationWarning: Importing display from IPython.core.display is deprecated since IPython 7.14, please import from IPython display
from IPython.core.display import display, HTML

ImportError Traceback (most recent call last)
Cell In[27], line 20
18 # Import model definition from big_vision
19 from big_vision.models.proj.paligemma import paligemma
---> 20 from big_vision.trainers.proj.paligemma import predict_fns
22 # Import big vision utilities
23 import big_vision.datasets.jsonl

File /kaggle/working/big_vision_repo/big_vision/trainers/proj/paligemma/predict_fns.py:20
17 import functools
19 from big_vision.pp import registry
---> 20 import big_vision.utils as u
21 import einops
22 import jax

File /kaggle/working/big_vision_repo/big_vision/utils.py:38
36 import flax.jax_utils as flax_utils
37 import jax
---> 38 from jax.experimental.array_serialization import serialization as array_serial
39 import jax.numpy as jnp
40 import ml_collections as mlc

File /opt/conda/lib/python3.10/site-packages/jax/experimental/array_serialization/serialization.py:36
34 from jax._src import sharding
35 from jax._src import sharding_impls
---> 36 from jax._src.layout import Layout, DeviceLocalLayout as DLL
37 from jax._src import typing
38 from jax._src import util

ImportError: cannot import name 'DeviceLocalLayout' from 'jax._src.layout' (/opt/conda/lib/python3.10/site-packages/jax/_src/layout.py)

Specify non-pytree node dataclass fields

Hi,

Thanks for making this awesome library!

Is it possible to specify fields in the chex.dataclass definitions to not include certain fields? This is a feature supported in flax https://flax.readthedocs.io/en/latest/_modules/flax/struct.html#dataclass
which I found quite useful when defining dataclasses with fields (such as JAX functions) that shouldn't be mapped over with dm-tree or jax.tree_map. I am not sure whether chex supports this out of the box at the moment, but I hope it will become part of chex.

AttributeError: module 'jax' has no attribute '_src'

I am trying to import optax and get the error AttributeError: module 'jax' has no attribute '_src' with jax versions > 0.3.17.

optax version == 0.1.3
chex version == 0.1.3

In [1]: import optax
/home/penn/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:37: FutureWarning: jax.tree_structure is deprecated, and will be removed in a future release. Use jax.tree_util.tree_structure instead.
  PyTreeDef = type(jax.tree_structure(None))
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Input In [1], in <cell line: 1>()
----> 1 import optax

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/__init__.py:17, in <module>
      1 # Copyright 2019 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Optax: composable gradient processing and optimization, in JAX."""
---> 17 from optax import experimental
     18 from optax._src.alias import adabelief
     19 from optax._src.alias import adafactor

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/experimental/__init__.py:20, in <module>
      1 # Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Experimental features in Optax.
     16 
     17 Features may be removed or modified at any time.
     18 """
---> 20 from optax._src.experimental.complex_valued import split_real_and_imaginary
     21 from optax._src.experimental.complex_valued import SplitRealAndImaginaryState

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/optax/_src/experimental/complex_valued.py:32, in <module>
     15 """Complex-valued optimization.
     16 
     17 When using `split_real_and_imaginary` to wrap an optimizer, we split the complex
   (...)
     27 See details at https://github.com/deepmind/optax/issues/196
     28 """
     30 from typing import NamedTuple, Union
---> 32 import chex
     33 import jax
     34 import jax.numpy as jnp

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/__init__.py:17, in <module>
      1 # Copyright 2020 DeepMind Technologies Limited. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
   (...)
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Chex: Testing made fun, in JAX!"""
---> 17 from chex._src.asserts import assert_axis_dimension
     18 from chex._src.asserts import assert_axis_dimension_comparator
     19 from chex._src.asserts import assert_axis_dimension_gt

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts.py:26, in <module>
     23 import unittest
     24 from unittest import mock
---> 26 from chex._src import asserts_internal as _ai
     27 from chex._src import pytypes
     28 import jax

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/asserts_internal.py:32, in <module>
     29 from typing import Any, Sequence, Union, Callable, Optional, Set, Tuple, Type
     31 from absl import logging
---> 32 from chex._src import pytypes
     33 import jax
     34 import jax.numpy as jnp

File ~/anaconda3/envs/jax/lib/python3.10/site-packages/chex/_src/pytypes.py:44, in <module>
     40 Device = jax.lib.xla_extension.Device
     42 ArrayTree = Union[Array, Iterable['ArrayTree'], Mapping[Any, 'ArrayTree']]
---> 44 ArrayDType = jax._src.numpy.lax_numpy._ScalarMeta

AttributeError: module 'jax' has no attribute '_src'

Dataclass breaks is_leaf function for jax.tree_map

is_leaf is handled correctly with NamedTuple:

from typing import Any, NamedTuple

import jax
import jax.numpy as jnp


class NT(NamedTuple):
  a: Any
  b: Any

class Histogram(NamedTuple):
  hist: jnp.ndarray
  bins: jnp.ndarray

def is_leaf(n):
  print(n)
  return isinstance(n, Histogram)

jax.tree_util.tree_map(lambda x: x, NT(a=Histogram(1, 2), b=Histogram(4, 3)), is_leaf=is_leaf)

Output (as expected)

NT(a=Histogram(hist=1, bins=2), b=Histogram(hist=4, bins=3))
Histogram(hist=1, bins=2)
Histogram(hist=4, bins=3)

It does not work with chex.dataclass:

import chex

@chex.dataclass
class DC:
  a: Any
  b: Any

jax.tree_util.tree_map(lambda x: x, DC(a=Histogram(1, 2), b=Histogram(4, 3)), is_leaf=is_leaf)

Actual output

DC(a=Histogram(hist=1, bins=2), b=Histogram(hist=4, bins=3))
1
2
4
3

Allow for nested chex.chexify

Hello, I have a dilemma with chexify. Consider the following code:

import chex
import jax
import jax.numpy as jnp
import pytest


# If this is not commented out, the second test will fail
# If this is commented out, the first test will fail
@chex.chexify
@jax.jit
def log_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x > 0, jnp.ones_like(x, dtype=bool))
    return jnp.log(x)

@chex.chexify
@jax.jit
def combo_safe(x: jax.Array) -> jax.Array:
    chex.assert_trees_all_equal(x != 1, jnp.ones_like(x, dtype=bool))
    return log_safe(x) / (x - 1)


def test_log_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, -1.0])
    with pytest.raises(Exception):
        log_safe(x)
        log_safe.wait_checks()

    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    assert jnp.array_equal(log_safe(x), jnp.log(x))
    log_safe.wait_checks()

def test_combo_safe() -> None:
    x = jnp.array([1.0, 2.0, 3.0, 4.0])
    with pytest.raises(Exception):
        combo_safe(x)
        combo_safe.wait_checks()

    x = jnp.array([2.0, 3.0, 4.0, 5.0])
    assert jnp.array_equal(combo_safe(x), jnp.log(x) / (x - 1))
    combo_safe.wait_checks()

If I comment out the first chexify, the test_log_safe test fails with RuntimeError: Value assertions can only be called from functions wrapped with @chex.chexify. See the docs. That makes sense to me. However, once I add the decorator back in, the second test fails with RuntimeError: Nested @chexify wrapping is disallowed. Make sure that you only wrap the function at the outermost level.

A hack for this simple scenario would be to make two versions of the function: log_safe without the chexify decorator, and log_safe_test = chex.chexify(log_safe), calling only log_safe_test in my tests. However, that solution is pretty clumsy, especially if I have many of these scenarios. In a codebase that is fully end-to-end JAX, all but the outermost function would require this hack. Would it be possible to allow nested chex.chexify, where subsequent applications of the decorator simply do nothing or just raise a warning?

Support wrapping functools.partial() objects

Example:

chex.chexify(functools.partial(fn, foo='bar'))

Error:

AttributeError: 'functools.partial' object has no attribute '__name__'
WARNING:absl:[Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
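The error suggests that chexify reads the wrapped callable's __name__, which functools.partial objects lack. One possible workaround is to copy the original function's metadata onto the partial with functools.update_wrapper. This only addresses the missing __name__; chexify may rely on other attributes, which is not verified here. The function fn below is a hypothetical example.

```python
import functools


def fn(x, foo):
    return f"{foo}:{x}"


# partial objects have no __name__; update_wrapper copies it (and __doc__, etc.) from fn.
bound = functools.update_wrapper(functools.partial(fn, foo="bar"), fn)

print(bound.__name__)  # fn
print(bound(1))        # bar:1
# chex.chexify(bound) would then find a __name__ attribute.
```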

`AssertsChexifyTest.test_uninspected_checks` test failure

I'm seeing the following test failure when running the test suite:

============================= test session starts ==============================
platform linux -- Python 3.10.7, pytest-7.1.3, pluggy-1.0.0
rootdir: /build/source
collected 548 items                                                            

chex/chex_test.py .                                                      [  0%]
chex/_src/asserts_chexify_test.py ......F.....                           [  2%]
chex/_src/asserts_internal_test.py .s.s.........                         [  4%]
chex/_src/asserts_test.py ..s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s. [ 13%]
s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s.s..................... [ 26%]
........................................................................ [ 39%]
........................................................................ [ 52%]
.................................                                        [ 58%]
chex/_src/dataclass_test.py ...........................................  [ 66%]
chex/_src/dimensions_test.py .................                           [ 69%]
chex/_src/fake_set_n_cpu_devices_test.py s                               [ 69%]
chex/_src/fake_test.py ................................                  [ 75%]
chex/_src/restrict_backends_test.py ssssssssss                           [ 77%]
chex/_src/variants_test.py .....................s....s............s....s [ 85%]
..........................................................ssssssssssssss [ 98%]
sssssss                                                                  [100%]

=================================== FAILURES ===================================
__________________ AssertsChexifyTest.test_uninspected_checks __________________

self = <chex._src.asserts_chexify_test.AssertsChexifyTest testMethod=test_uninspected_checks>

    def test_uninspected_checks(self):
    
      @jax.jit
      def _pos_sum(x):
        chex_value_assert_positive(x, custom_message='err_label')
        return x.sum()
    
      invalid_x = -jnp.ones(3)
      chexify_async(_pos_sum)(invalid_x)  # async error
    
>     with self.assertRaisesRegex(AssertionError, 'err_label'):
E     AssertionError: AssertionError not raised

chex/_src/asserts_chexify_test.py:179: AssertionError
------------------------------ Captured log call -------------------------------
WARNING  absl:asserts_chexify.py:57 [Chex] Some of chexify assetion statuses were not inspected due to async exec (https://jax.readthedocs.io/en/latest/async_dispatch.html). Consider calling `chex.block_until_chexify_assertions_complete()` at the end of computations that rely on jitted chex assetions.
=============================== warnings summary ===============================
chex/_src/asserts_chexify_test.py: 12 warnings
  /build/source/chex/_src/asserts_chexify_test.py:58: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
    return jnp.all(jnp.array([(x > 0).all() for x in jax.tree_leaves(tree)]))

chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__with_jit
chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_static_assertion__without_jit
  /build/source/chex/_src/asserts_chexify_test.py:86: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
    return sum(x.sum() for x in jax.tree_leaves(tree))

chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
  /nix/store/4y9j6xdkgqwkdx5ki508l175smcjgs9l-python3.10-pytest-7.1.3/lib/python3.10/site-packages/_pytest/unraisableexception.py:78: PytestUnraisableExceptionWarning: Exception ignored in atexit callback: <function _check_if_hanging_assertions at 0x7ffddfe66d40>
  
  Traceback (most recent call last):
    File "/build/source/chex/_src/asserts_chexify.py", line 32, in _check_error
      checkify.check_error(err)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 476, in check_error
      return assert_p.bind(err, code, payload, msgs=error.msgs)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 328, in bind
      return self.bind_with_trace(find_top_trace(args), args, params)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 331, in bind_with_trace
      out = trace.process_primitive(self, map(trace.full_raise, args), params)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/core.py", line 698, in process_primitive
      return primitive.impl(*tracers, **params)
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 483, in assert_impl
      raise_error(Error(err, code, msgs, payload))
    File "/nix/store/2fcsbc07baqm4mfmibzr1qlh8bfvb6mc-python3.10-jax-0.3.23/lib/python3.10/site-packages/jax/_src/checkify.py", line 123, in raise_error
      raise ValueError(err)
  ValueError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] (check failed at /build/source/chex/_src/asserts_internal.py:229 (_chex_assert_fn))
  
  During handling of the above exception, another exception occurred:
  
  Traceback (most recent call last):
    File "/build/source/chex/_src/asserts_chexify.py", line 62, in _check_if_hanging_assertions
      block_until_chexify_assertions_complete()
    File "/build/source/chex/_src/asserts_chexify.py", line 51, in block_until_chexify_assertions_complete
      wait_fn()
    File "/build/source/chex/_src/asserts_chexify.py", line 180, in _wait_checks
      _check_error(async_check_futures.popleft().result(async_timeout))
    File "/build/source/chex/_src/asserts_chexify.py", line 40, in _check_error
      raise AssertionError(msg)  # pylint:disable=raise-missing-from
  AssertionError: [Chex] chexify assertion failed [err_label] [failed at /build/source/chex/_src/asserts_chexify_test.py:173] 
  
    warnings.warn(pytest.PytestUnraisableExceptionWarning(msg))

chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
  /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
    if not all((x > 0).all() for x in jax.tree_leaves(tree)):

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
=========================== short test summary info ============================
FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======
error: builder for '/nix/store/f9icjsb9pbz4p8qpsyhp9gq1fvjvwwhz-python3.10-chex-0.1.5.drv' failed with exit code 1;
       last 10 log lines:
       > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted
       > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
       > chex/_src/asserts_chexify_test.py::AssertsChexifyTestSuite::test_log_abs_fn_jitted_vmapped
       >   /build/source/chex/_src/asserts_chexify_test.py:52: FutureWarning: jax.tree_leaves is deprecated, and will be removed in a future release. Use jax.tree_util.tree_leaves instead.
       >     if not all((x > 0).all() for x in jax.tree_leaves(tree)):
       >
       > -- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
       > =========================== short test summary info ============================
       > FAILED chex/_src/asserts_chexify_test.py::AssertsChexifyTest::test_uninspected_checks
       > ====== 1 failed, 461 passed, 86 skipped, 20 warnings in 84.47s (0:01:24) =======

I'm using

chex.Dimensions API enhancement

I would like to propose an API enhancement that allows the use of chex.Dimensions inside function annotations. If there is interest, I'd like to contribute. Example below:

dims = chex.Dimensions(B=batch_size, T=sequence_len, E=embedding_dim)
...
def foo(arr: chex.Array):
    chex.assert_shape(arr, dims['BTE'])
    # fn logic

### turns into ###

def foo(arr: chex.Array(dims['BTE'])):  # behind the scenes assert on function call
    # fn logic

This is particularly useful for dataclasses, e.g.:

dims = chex.Dimensions(B=batch_size, T=rollout_len)

# asserts are run on instantiation
class TimeStep:
    q_values: chex.Array(dims['BT'])
    discounts: chex.Array(dims['BT'])
    rewards: chex.Array(dims['BT'])

Pros:

  • reduces the clutter that asserts can add
  • allows the user to view the shape expected by a function or class in the editor (not sure what you call the VS Code popup)
    • example: using RLax, to know what shape is expected for each arg in a loss fn, you need to either look at the source code or wait for the fn call to raise an assert

Cons:

  • increased API complexity
  • ...?
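Part of this can be approximated with today's tools. The sketch below is illustrative only and is not chex API: a hypothetical `check_shapes` decorator reads expected shapes from `typing.Annotated` metadata (Python 3.9+) and asserts them on every call. The tuple `BTE` stands in for the shape that `dims['BTE']` would produce from `chex.Dimensions`; in chex-based code the comparison would be `chex.assert_shape`, but plain NumPy keeps the example self-contained.

```python
import functools
import inspect
from typing import Annotated, get_type_hints

import numpy as np


def check_shapes(fn):
    """Hypothetical decorator: assert shapes declared via typing.Annotated.

    An argument annotated as Annotated[np.ndarray, (2, 3)] must have
    shape (2, 3) at call time. Arguments without Annotated metadata are ignored.
    """
    hints = get_type_hints(fn, include_extras=True)
    expected = {
        name: hint.__metadata__[0]
        for name, hint in hints.items()
        if name != "return" and hasattr(hint, "__metadata__")
    }
    sig = inspect.signature(fn)

    @functools.wraps(fn)
    def wrapper(*args, **kwargs):
        bound = sig.bind(*args, **kwargs)
        for name, shape in expected.items():
            if name in bound.arguments:
                actual = np.shape(bound.arguments[name])
                if actual != tuple(shape):
                    raise AssertionError(
                        f"{fn.__name__}: argument `{name}` has shape {actual}, "
                        f"expected {tuple(shape)}")
        return fn(*args, **kwargs)

    return wrapper


# Stand-in for dims = chex.Dimensions(B=2, T=3, E=4); dims['BTE'].
BTE = (2, 3, 4)


@check_shapes
def foo(arr: Annotated[np.ndarray, BTE]) -> float:
    # fn logic
    return float(arr.sum())
```

Because the shape lives in the annotation, editors that display signatures can show it, which is one of the pros listed above; `TimeStep`-style dataclass fields could be checked the same way from a `__post_init__`, which is left out here to keep the sketch short.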

Test files are in binary package distribution

Test files are not filtered properly. The issue is that setuptools.find_packages finds packages, not modules, while the tests are organized as separate modules. To mitigate the issue, test files should be filtered out manually, as follows. This patch was created and tested on chex v0.1.86.

--- a/setup.py	2024-03-19 12:58:11.000000000 +0300
+++ b/setup.py	2024-04-17 15:19:33.593293889 +0300
@@ -15,8 +15,11 @@
 """Install script for setuptools."""
 
 import os
-from setuptools import find_packages
-from setuptools import setup
+from pathlib import Path
+
+from setuptools import find_packages, setup
+from setuptools.command.build_py import build_py as _build_py
+
 
 _CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
 
@@ -40,6 +43,15 @@
     ]
 
 
+class build_py(_build_py):
+
+    def find_package_modules(self, package, package_dir):
+        modules = super().find_package_modules(package, package_dir)
+        return [(pkg, mod, file)
+                for pkg, mod, file in modules
+                if not Path(file).match('**/*_test.py')]
+
+
 setup(
     name='chex',
     version=_get_version(),
@@ -51,7 +63,7 @@
     long_description_content_type='text/markdown',
     author_email='[email protected]',
     keywords='jax testing debugging python machine learning',
-    packages=find_packages(exclude=['*_test.py']),
+    packages=find_packages(),
     install_requires=_parse_requirements(
         os.path.join(_CURRENT_DIR, 'requirements', 'requirements.txt')),
     tests_require=_parse_requirements(
@@ -73,4 +85,5 @@
         'Topic :: Software Development :: Testing :: Unit',
         'Topic :: Software Development :: Libraries :: Python Modules',
     ],
+    cmdclass={'build_py': build_py},
 )
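The filter in the patch can be checked in isolation. Below is a stdlib-only sketch that applies the same `Path(file).match('**/*_test.py')` predicate to `(package, module, file)` tuples, the form the patch's `find_package_modules` override works with; the module list is made-up sample data, not real build output.

```python
from pathlib import Path


def drop_test_modules(modules):
    """Keep only (package, module, file) tuples whose file is not a *_test.py module.

    Mirrors the list comprehension in the patched build_py.find_package_modules.
    """
    return [(pkg, mod, file)
            for pkg, mod, file in modules
            if not Path(file).match('**/*_test.py')]


# Made-up sample of what find_package_modules might return for chex.
sample = [
    ('chex._src', 'asserts', 'chex/_src/asserts.py'),
    ('chex._src', 'asserts_test', 'chex/_src/asserts_test.py'),
    ('chex', '__init__', 'chex/__init__.py'),
]
print(drop_test_modules(sample))
```

This also shows why the original `find_packages(exclude=['*_test.py'])` had no effect: `exclude` patterns are matched against dotted package names such as `chex._src`, never against module file names.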

Fake contexts by calling .start() not working

Hi, I tried both the latest GitHub version and the latest PyPI version, but in neither does using fake contexts by calling .start() work (it did work as a context manager!).
Here are some screenshots of my problem with fake_pmap and fake_jit:

[Screenshot 2020-09-28 at 18.27.11]
[Screenshot 2020-09-28 at 18.27.20]
What could be the cause of my problem? Thanks

DeprecationWarning for importing toolz

  /home/neil/.pyenv/versions/3.8.5/lib/python3.8/site-packages/toolz/compatibility.py:2: DeprecationWarning: The toolz.compatibility module is no longer needed in Python 3 and has been deprecated. Please import these utilities directly from the standard library. This module will be removed in a future release.
    warnings.warn("The toolz.compatibility module is no longer "
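Until the code that imports `toolz.compatibility` is updated, the warning can be silenced narrowly. A minimal sketch with the standard `warnings` module follows; it assumes, as the log shows, that the warning is issued from the `toolz.compatibility` module, and `silence_toolz_compat_warning` is a hypothetical helper name.

```python
import warnings


def silence_toolz_compat_warning():
    """Ignore only the toolz.compatibility DeprecationWarning.

    The filter is narrowed by message, category and originating module,
    so other DeprecationWarnings remain visible.
    """
    warnings.filterwarnings(
        "ignore",
        message=r"The toolz\.compatibility module is no longer needed",
        category=DeprecationWarning,
        module=r"toolz\.compatibility",
    )


# Call this before importing the package that pulls in toolz.
silence_toolz_compat_warning()
```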

Chex dataclass throws an exception in Python 3.9

$ python --version
Python 3.9.1
In [1]: import chex

In [2]: @chex.dataclass
   ...: class Parameters:
   ...:   x: chex.ArrayDevice
   ...:   y: chex.ArrayDevice
   ...:
   ...: parameters = Parameters(
   ...:     x=jnp.ones((2, 2)),
   ...:     y=jnp.ones((1, 2)),
   ...: )
   ...:
   ...: # Dataclasses can be treated as JAX pytrees
   ...: jax.tree_map(lambda x: 2.0 * x, parameters)
   ...:
   ...: # and as mappings by dm-tree
   ...: tree.flatten(parameters)
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-3-3461a2700932> in <module>
      1 @chex.dataclass
----> 2 class Parameters:
      3   x: chex.ArrayDevice
      4   y: chex.ArrayDevice
      5

~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in dataclass(cls, init, repr, eq, order, unsafe_hash, frozen, mappable_dataclass, restricted_inheritance)
    104   if cls is None:
    105     return dcls
--> 106   return dcls(cls)
    107
    108

~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in __call__(self, cls)
    147
    148     if self.mappable_dataclass:
--> 149       dcls = mappable_dataclass(dcls, self.restricted_inheritance)
    150
    151     def _from_tuple(args):

~/.virtualenvs/science/lib/python3.9/site-packages/chex/_src/dataclass.py in mappable_dataclass(cls, restricted_inheritance)
     81   if cls.__bases__ == (object,):
     82     # `collections.Mapping` is incompatible with `object`
---> 83     cls.__bases__ = (collections.Mapping,)
     84   else:
     85     cls.__bases__ += (collections.Mapping,)

TypeError: __bases__ assignment: 'Mapping' deallocator differs from 'object'
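The traceback shows the failing step: `mappable_dataclass` mutates `cls.__bases__` after the class has been created, and in this Python 3.9.1 session the interpreter rejects replacing `object` with `Mapping` that way. One way to avoid mutating the bases, sketched below under the assumption that returning a subclass is acceptable (this is not necessarily how chex fixed it), is to define a new class that inherits from both the dataclass and `collections.abc.Mapping`. `make_mappable` is a hypothetical name, and the sketch uses only the standard library.

```python
import collections.abc
import dataclasses


def make_mappable(cls):
    """Hypothetical helper: return a Mapping subclass of dataclass `cls`.

    Instead of assigning to cls.__bases__, define a new class with bases
    (cls, Mapping), so no existing class object is mutated.
    """
    names = tuple(f.name for f in dataclasses.fields(cls))

    class Mappable(cls, collections.abc.Mapping):
        def __getitem__(self, key):
            if key not in names:
                raise KeyError(key)
            return getattr(self, key)

        def __iter__(self):
            return iter(names)

        def __len__(self):
            return len(names)

    # Keep the original class name so reprs and error messages stay readable.
    Mappable.__name__ = cls.__name__
    Mappable.__qualname__ = cls.__qualname__
    Mappable.__module__ = cls.__module__
    return Mappable


@make_mappable
@dataclasses.dataclass
class Parameters:
    x: float
    y: float


p = Parameters(x=1.0, y=2.0)
print(dict(p))
```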

[REQ] Conda recipe

Hi,
I'm the lead developer of NetKet, an established machine learning / quantum physics package.

We recently finished rewriting our core on top of JAX (and Flax) and have released a beta version.
Since many physicists seem to use anaconda, we would also like to update our conda recipe.
However, since we depend on optax (and therefore on Chex), we would need Chex to have a Conda recipe.

Is that something you'd consider? I am willing to volunteer some work to help you.

I tried creating a recipe starting from your pypi source distribution, but that is problematic because you don't bundle your requirements.txt file, which is required to run setup.py.
I could create a recipe from the tag tarballs on GitHub, but that sometimes prevents the recipe from being auto-updated for later releases.

Better error report for max traces exceeded

In my experience, when chex reports that the maximum number of traces was exceeded, it's usually because I passed arguments with different shapes or dtypes to the function. Could chex report such inconsistencies?

e.g.,

AssertionError: [Chex] Function '_wrapper' is traced > 1 times!
Difference in input shapes. Last time variable `x` traced with shape "(10, 1)", this time traced with shape "(9, 1)".
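The information the requested message needs can be computed from each call's arguments. The sketch below is hypothetical and is not chex's `assert_max_traces` implementation: `abstract_signature` records each argument's shape and dtype (assuming array-like arguments), and `describe_signature_change` produces text in the format proposed above. Shape and dtype changes are common retrace triggers, but not the only ones, and hooking this into chex's trace counting is not shown.

```python
import inspect

import numpy as np


def abstract_signature(fn, *args, **kwargs):
    """Map each argument name of `fn` to (shape, dtype); assumes array-like args."""
    bound = inspect.signature(fn).bind(*args, **kwargs)
    return {name: (np.shape(value), str(np.asarray(value).dtype))
            for name, value in bound.arguments.items()}


def describe_signature_change(old, new):
    """Explain, per argument, how shape or dtype changed between two calls."""
    lines = []
    for name in (n for n in new if n in old):
        (old_shape, old_dtype), (shape, dtype) = old[name], new[name]
        if old_shape != shape:
            lines.append(f'Last time variable `{name}` traced with shape '
                         f'"{old_shape}", this time traced with shape "{shape}".')
        if old_dtype != dtype:
            lines.append(f'Last time variable `{name}` traced with dtype '
                         f'"{old_dtype}", this time traced with dtype "{dtype}".')
    return "\n".join(lines)


def add(x, y):
    return x + y


first = abstract_signature(add, np.zeros((10, 1)), 1.0)
second = abstract_signature(add, np.zeros((9, 1)), 1.0)
print(describe_signature_change(first, second))
```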

Missing package dependency typing-extensions

Hi,

The latest version (chex==0.1.7) is missing the typing-extensions dependency.
Problem
For example, the following python code

from chex import dataclass

raises the error

  File "<string>", line 1, in <module>
  File "/usr/local/lib/python3.11/dist-packages/chex/__init__.py", line 72, in <module>
    from chex._src.dataclass import dataclass
  File "/usr/local/lib/python3.11/dist-packages/chex/_src/dataclass.py", line 23, in <module>
    from typing_extensions import dataclass_transform  # pytype: disable=not-supported-yet
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
ModuleNotFoundError: No module named 'typing_extensions'

on Ubuntu 23.04 (with Python 3.11).

Solution
I found that installing typing-extensions resolves the error:

pip3 install typing-extensions

Reproduce
The following code reproduces the error:

docker run -it ubuntu:23.04 \
  bash -c \
  "apt update && apt install --assume-yes python3-pip && pip3 install --break-system-packages chex && python3 -c 'from chex import dataclass'"

How to version?

Hi there!
Please tell me how to get the version of chex installed on my Ubuntu (18.0) machine.
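Two standard options: chex defines `chex.__version__` in its `__init__.py`, so `python3 -c "import chex; print(chex.__version__)"` prints it, and `pip show chex` reports the installed version from a shell. The sketch below uses the standard library's `importlib.metadata` (Python 3.8+), which reads the installed distribution's metadata without importing the package; `installed_version` is a hypothetical helper name.

```python
from importlib import metadata


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return metadata.version(dist_name)
    except metadata.PackageNotFoundError:
        return None


print(installed_version("chex"))
```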

Numpy Conflict

ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
chex 0.1.82 requires numpy>=1.25.0, but you have numpy 1.24.3 which is incompatible.
Successfully installed numpy-1.24.3

I hit this while trying to get Diffusers set up: https://huggingface.co/docs/transformers/installation

CpuDevice no longer in jax

Hello,

It seems that the newest version of jax (0.3.7) removed some classes that chex uses. Should chex put an upper bound on the jax version? I see this conflicting code is not currently on the main branch; alternatively, maybe a new release can be made?

google/jax#10326

`jax.random.key` tree comparison results in a `ZeroDivision` warning

See below:

import jax
import chex

chex.assert_trees_all_equal(jax.random.key(0), jax.random.key(0))
> RuntimeWarning: divide by zero encountered in equal val = comparison(x, y)

# Runs fine
chex.assert_trees_all_equal(jax.random.PRNGKey(0), jax.random.PRNGKey(0))

The problem is not jax per se, since key comparison works:

jax.random.key(0) == jax.random.key(1)
> Array(False, dtype=bool)

jax.random.key(0) == jax.random.key(0)
> Array(True, dtype=bool)

chex.variants(with_pmap=True) ignores `static_argnames`

The _with_pmap function accepts static_argnums as a parameter, but not static_argnames. This is inconsistent with other variants, such as with_jit and with_device. Crucially, this prevents testing methods that require arguments to be passed by name (e.g., Distrax's Distribution.sample()).

More generally, it would be best if all variants accepted the same parameters where possible (i.e., where not specific to a single variant). I would also suggest checking all keys in **unused_kwargs against a list of allowed parameters (i.e., the union of the parameters of all variant functions) to prevent silent errors due to, e.g., misspellings.
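The suggested check can be sketched in a few lines. The allowed set below is an assumption for illustration: it contains only the two parameters named in this issue, not the real union read from chex, and `validate_variant_kwargs` is a hypothetical helper that rejects unknown keywords and uses `difflib` to suggest a likely intended name.

```python
import difflib

# Hypothetical allow-list; a real version would hold the union of the
# parameters of all variant functions, not just these two.
ALLOWED_VARIANT_KWARGS = frozenset({"static_argnums", "static_argnames"})


def validate_variant_kwargs(kwargs, allowed=ALLOWED_VARIANT_KWARGS):
    """Raise TypeError for unknown keywords instead of silently ignoring them."""
    unknown = sorted(set(kwargs) - allowed)
    if unknown:
        hints = []
        for name in unknown:
            # Suggest the closest allowed name, if any is similar enough.
            close = difflib.get_close_matches(name, sorted(allowed), n=1)
            hint = f"{name!r}"
            if close:
                hint += f" (did you mean {close[0]!r}?)"
            hints.append(hint)
        raise TypeError("Unknown variant argument(s): " + ", ".join(hints))
```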

Next release

chex v0.1.83 is failing with the latest jax (0.4.19) as jax.core.Shape.
This issue has been fixed on the master branch of chex in this commit.

Could you please make a release to ship this fix?

Thank you very much !
