Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 26 additions & 22 deletions chex/_src/asserts.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def disable_asserts() -> None:

Use wisely.
"""
_ai.DISABLE_ASSERTIONS = True
_ai.DISABLE_ASSERTIONS = True # pyrefly: ignore[bad-assignment]


def enable_asserts() -> None:
Expand Down Expand Up @@ -416,23 +416,24 @@ def assert_size(
"""
# Ensure inputs and expected sizes are sequences.
if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
inputs = [inputs] # pyrefly: ignore[bad-assignment]

if isinstance(expected_sizes, int):
expected_sizes = [expected_sizes] * len(inputs)
expected_sizes = [expected_sizes] * len(inputs) # pyrefly: ignore[bad-argument-type]

if not isinstance(expected_sizes, (list, tuple)):
raise AssertionError(
"Error in size compatibility check: expected sizes should be an int, "
f"list, or tuple of ints, got {expected_sizes}.")

if len(inputs) != len(expected_sizes):
if len(inputs) != len(expected_sizes): # pyrefly: ignore[bad-argument-type]
raise AssertionError(
# pyrefly: ignore[bad-argument-type]
"Length of `inputs` and `expected_sizes` must match: "
f"{len(inputs)} is not equal to {len(expected_sizes)}.")

errors = []
for idx, (x, expected) in enumerate(zip(inputs, expected_sizes)):
for idx, (x, expected) in enumerate(zip(inputs, expected_sizes)): # pyrefly: ignore[bad-argument-type]
size = getattr(x, "size", 1) # scalars have size 1 by definition.
# Allow any size for the ellipsis case and allow handling of integer
# expected sizes or collection of acceptable expected sizes.
Expand Down Expand Up @@ -575,20 +576,20 @@ def _shape_matches(actual_shape: Sequence[int],
# If there is no ellipsis, just compare to the full `actual_shape`.
if expected_suffix is None:
assert len(expected_prefix) == len(expected_shape)
return _unelided_shape_matches(actual_shape, expected_prefix)
return _unelided_shape_matches(actual_shape, expected_prefix) # pyrefly: ignore[bad-argument-type]

# Checks that the actual rank is least the number of non-elided dimensions.
if len(actual_shape) < len(expected_prefix) + len(expected_suffix):
return False

if expected_prefix:
actual_prefix = actual_shape[:len(expected_prefix)]
if not _unelided_shape_matches(actual_prefix, expected_prefix):
if not _unelided_shape_matches(actual_prefix, expected_prefix): # pyrefly: ignore[bad-argument-type]
return False

if expected_suffix:
actual_suffix = actual_shape[-len(expected_suffix):]
if not _unelided_shape_matches(actual_suffix, expected_suffix):
if not _unelided_shape_matches(actual_suffix, expected_suffix): # pyrefly: ignore[bad-argument-type]
return False

return True
Expand Down Expand Up @@ -632,18 +633,19 @@ def assert_shape(

# Ensure inputs and expected shapes are sequences.
if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
inputs = [inputs] # pyrefly: ignore[bad-assignment]

# Shapes are always lists or tuples, not scalars.
if (not expected_shapes or not isinstance(expected_shapes[0], (list, tuple))):
expected_shapes = [expected_shapes] * len(inputs)
if len(inputs) != len(expected_shapes):
expected_shapes = [expected_shapes] * len(inputs) # pyrefly: ignore[bad-argument-type, bad-assignment]
if len(inputs) != len(expected_shapes): # pyrefly: ignore[bad-argument-type]
raise AssertionError(
# pyrefly: ignore[bad-argument-type]
"Length of `inputs` and `expected_shapes` must match: "
f"{len(inputs)} is not equal to {len(expected_shapes)}.")

errors = []
for idx, (x, expected) in enumerate(zip(inputs, expected_shapes)):
for idx, (x, expected) in enumerate(zip(inputs, expected_shapes)): # pyrefly: ignore[bad-argument-type]
shape = getattr(x, "shape", ()) # scalars have shape () by definition.
if not _shape_matches(shape, expected):
errors.append((idx, shape, _ai.format_shape_matcher(expected)))
Expand Down Expand Up @@ -743,17 +745,18 @@ def assert_rank(

# Ensure inputs and expected ranks are sequences.
if not isinstance(inputs, collections.abc.Sequence):
inputs = [inputs]
inputs = [inputs] # pyrefly: ignore[bad-assignment]
if (not isinstance(expected_ranks, collections.abc.Sequence) or
isinstance(expected_ranks, collections.abc.Set)):
expected_ranks = [expected_ranks] * len(inputs)
if len(inputs) != len(expected_ranks):
expected_ranks = [expected_ranks] * len(inputs) # pyrefly: ignore[bad-argument-type]
if len(inputs) != len(expected_ranks): # pyrefly: ignore[bad-argument-type]
raise AssertionError(
# pyrefly: ignore[bad-argument-type]
"Length of inputs and expected_ranks must match: inputs has length "
f"{len(inputs)}, expected_ranks has length {len(expected_ranks)}.")

errors = []
for idx, (x, expected) in enumerate(zip(inputs, expected_ranks)):
for idx, (x, expected) in enumerate(zip(inputs, expected_ranks)): # pyrefly: ignore[bad-argument-type]
if hasattr(x, "shape"):
shape = x.shape
else:
Expand Down Expand Up @@ -819,17 +822,18 @@ def assert_type(
if the types of inputs do not match the expected types.
"""
if not isinstance(inputs, (list, tuple)):
inputs = [inputs]
inputs = [inputs] # pyrefly: ignore[bad-assignment]
if not isinstance(expected_types, (list, tuple)):
expected_types = [expected_types] * len(inputs)
expected_types = [expected_types] * len(inputs) # pyrefly: ignore[bad-argument-type]

errors = []
if len(inputs) != len(expected_types):
if len(inputs) != len(expected_types): # pyrefly: ignore[bad-argument-type]
raise AssertionError(
# pyrefly: ignore[bad-argument-type]
"Length of `inputs` and `expected_types` must match, "
f"got {len(inputs)} != {len(expected_types)}."
)
for idx, (x, expected) in enumerate(zip(inputs, expected_types)):
for idx, (x, expected) in enumerate(zip(inputs, expected_types)): # pyrefly: ignore[bad-argument-type]
dtype = x.dtype if hasattr(x, "dtype") else np.result_type(x)
if expected in {float, jnp.floating}:
if not jnp.issubdtype(dtype, jnp.floating):
Expand Down Expand Up @@ -1630,7 +1634,7 @@ def _assert_trees_all_equal_jittable(
err_msg_template = "Values not exactly equal: {arr_1} != {arr_2}."
cmp_fn = lambda x, y: jnp.array_equal(x, y, equal_nan=True)
return _ai.assert_trees_all_eq_comparator_jittable(
cmp_fn, err_msg_template, *trees
cmp_fn, err_msg_template, *trees # pyrefly: ignore[bad-argument-type]
)


Expand Down Expand Up @@ -1712,7 +1716,7 @@ def _assert_trees_all_close_jittable(
)
cmp_fn = lambda x, y: jnp.isclose(x, y, rtol=rtol, atol=atol).all()
return _ai.assert_trees_all_eq_comparator_jittable(
cmp_fn, err_msg_template, *trees
cmp_fn, err_msg_template, *trees # pyrefly: ignore[bad-argument-type]
)


Expand Down
2 changes: 1 addition & 1 deletion chex/_src/asserts_chexify.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _wait_checks():

# Add the callback to the chexified funtion's properties.
if not hasattr(_chexified_fn, 'wait_checks'):
_chexified_fn.wait_checks = _wait_checks
_chexified_fn.wait_checks = _wait_checks # pyrefly: ignore[missing-attribute]
else:
logging.warning(
"Function %s already defines 'wait_checks' method; "
Expand Down
4 changes: 2 additions & 2 deletions chex/_src/asserts_chexify_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,7 +361,7 @@ def run_test_suite(self, make_test_fn, all_valid_args, all_invalid_args,
if run_pure:
# Test all versions return the same outputs.
asserts.assert_trees_all_equal(
fn_no_assert(*valid_args), fn_static_assert(*valid_args))
fn_no_assert(*valid_args), fn_static_assert(*valid_args)) # pyrefly: ignore[unbound-name]
asserts.assert_trees_all_equal(
fn_no_assert(*valid_args), fn_value_assert(*valid_args))

Expand All @@ -384,7 +384,7 @@ def run_test_suite(self, make_test_fn, all_valid_args, all_invalid_args,
if run_pure:
# Static assertion fails on incorrect inputs (without transformations).
with self.assertRaisesRegex(AssertionError, re.escape(label)):
fn_static_assert(*invalid_args)
fn_static_assert(*invalid_args) # pyrefly: ignore[unbound-name]

# Value assertion fails on incorrect inputs (with transformations).
err_regex = get_chexify_err_regex('assert_tree_positive_test', label)
Expand Down
6 changes: 3 additions & 3 deletions chex/_src/asserts_internal.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,7 @@ def _check(pred, msg, *fmt_args, **fmt_kwargs):
# In particular, this check fails when `pred` is False and no
# `checkify.check` calls took place in `jittable_assert_fn`, which would
# be a bug in the assertion's implementation.
checkify.check(pred, "assertion failed!")
checkify.check(pred, "assertion failed!") # pyrefly: ignore[bad-argument-type]
else:
try:
host_assertion_fn(
Expand Down Expand Up @@ -316,7 +316,7 @@ def num_devices_available(devtype: str, backend: Optional[str] = None) -> int:

def get_tracers(tree: pytypes.ArrayTree) -> Tuple[jax.core.Tracer]:
"""Returns a tuple with tracers from a tree."""
return tuple(
return tuple( # pyrefly: ignore[bad-return]
x for x in jax.tree_util.tree_leaves(tree)
if isinstance(x, jax.core.Tracer))

Expand Down Expand Up @@ -476,4 +476,4 @@ def _convert_key_fn(key: JaxKeyType) -> Union[int, str, Hashable]:
raise ValueError(f"Jax tree key '{key}' of type '{type(key)}' not valid.")
# pytype:enable=attribute-error

return tuple(_convert_key_fn(key) for key in jax_tree_path)
return tuple(_convert_key_fn(key) for key in jax_tree_path) # pyrefly: ignore[bad-return]
4 changes: 2 additions & 2 deletions chex/_src/asserts_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -685,7 +685,7 @@ def test_type_should_fail_scalar(self, scalars, wrong_type):
('float32_array', [1, 2], jnp.float32, np.integer),
)
def test_type_should_fail_array(self, array, dtype, wrong_type):
array = self.variant(emplace)(array, dtype)
array = self.variant(emplace)(array, dtype) # pyrefly: ignore[missing-attribute]
with self.assertRaisesRegex(
AssertionError, _get_err_regex('input .+ has type .+ but expected .+')):
asserts.assert_type(array, wrong_type)
Expand All @@ -708,7 +708,7 @@ def test_type_should_pass_scalar(self, array, expected_type):
('one_bool_array', [True], bool, bool),
)
def test_type_should_pass_array(self, array, dtype, expected_type):
array = self.variant(emplace)(array, dtype)
array = self.variant(emplace)(array, dtype) # pyrefly: ignore[missing-attribute]
asserts.assert_type(array, expected_type)

def test_type_should_fail_mixed(self):
Expand Down
6 changes: 3 additions & 3 deletions chex/_src/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ def __call__(self, cls):
)
# pytype: enable=wrong-keyword-args

fields_names = set(f.name for f in dataclasses.fields(dcls))
fields_names = set(f.name for f in dataclasses.fields(dcls)) # pyrefly: ignore[bad-argument-type]
invalid_fields = fields_names.intersection(_RESERVED_DCLS_FIELD_NAMES)
if invalid_fields:
raise ValueError(f"The following dataclass fields are disallowed: "
Expand All @@ -203,7 +203,7 @@ def __call__(self, cls):
dcls = mappable_dataclass(dcls)

def _from_tuple(args):
return dcls(zip(dcls.__dataclass_fields__.keys(), args))
return dcls(zip(dcls.__dataclass_fields__.keys(), args)) # pyrefly: ignore[bad-argument-type, missing-attribute]

def _to_tuple(self):
return tuple(getattr(self, k) for k in self.__dataclass_fields__.keys())
Expand Down Expand Up @@ -247,7 +247,7 @@ def _setstate(self, state):
@functools.wraps(orig_init)
def _init(self, *args, **kwargs):
register_dataclass_type_with_jax_tree_util(dcls)
return orig_init(self, *args, **kwargs)
return orig_init(self, *args, **kwargs) # pyrefly: ignore[bad-argument-count]

setattr(dcls, "from_tuple", _from_tuple)
setattr(dcls, "to_tuple", _to_tuple)
Expand Down
2 changes: 1 addition & 1 deletion chex/_src/dimensions.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def size(self, key: str) -> int:
f"cannot take product of shape '{key}' = {shape}, "
'because it contains non-positive sized dimensions'
)
return math.prod(shape)
return math.prod(shape) # pyrefly: ignore[no-matching-overload]

def __getitem__(self, key: str) -> Shape:
self._validate_key(key)
Expand Down
6 changes: 3 additions & 3 deletions chex/_src/fake.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@ def wrapped_fn(*args, **kwargs):
else:
vmap_in_axes = list(in_axes)
for argnum in static_broadcasted_argnums:
vmap_in_axes[argnum] = None
vmap_in_axes[argnum] = None # pyrefly: ignore[unsupported-operation]

# To protect the arguments from `static_broadcasted_argnums`,
# from turning into tracers (because of vmap), we capture the original
Expand Down Expand Up @@ -398,14 +398,14 @@ def __init__(self, fn_transformation: str, callback_fn: Callable[..., Any]):
"""
self._fn_transformation = fn_transformation
self._callback_fn = callback_fn
self._patch: mock._patch[Callable[[Any], Any]] = None # pylint: disable=unsubscriptable-object
self._patch: mock._patch[Callable[[Any], Any]] = None # pylint: disable=unsubscriptable-object # pyrefly: ignore[bad-assignment]
self._original_fn_transformation = None

def __enter__(self):

def _new_fn_transformation(fn, *args, **kwargs):
"""Returns a transformed version of the given function."""
transformed_fn = self._original_fn_transformation(fn, *args, **kwargs)
transformed_fn = self._original_fn_transformation(fn, *args, **kwargs) # pyrefly: ignore[not-callable]

@functools.wraps(transformed_fn)
def _new_transformed_fn(*args, **kwargs):
Expand Down
8 changes: 4 additions & 4 deletions chex/_src/variants.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@ def generate():
if named:
name = "_".join(t[0] for t in combination)
args_tuples = (t[1:] for t in combination)
args = sum(args_tuples, ())
args = sum(args_tuples, ()) # pyrefly: ignore[no-matching-overload]
yield (name, *args)
else:
yield sum(combination, ())
yield sum(combination, ()) # pyrefly: ignore[no-matching-overload]

return list(generate())

Expand Down Expand Up @@ -557,7 +557,7 @@ def bcast_fn(x):
return x

if broadcast_args_to_devices:
args = [
args = [ # pyrefly: ignore[bad-assignment]
tree_map(bcast_fn, arg) if idx not in static_argnums else arg
for idx, arg in enumerate(args)
]
Expand Down Expand Up @@ -596,7 +596,7 @@ def bcast_fn(x):
pmapped_fn = jax.pmap(fn, **pmap_kwargs)

res = pmapped_fn(*args, **kwargs)
return reduce_fn(res)
return reduce_fn(res) # pyrefly: ignore[not-callable]

return wrapper

Expand Down
Loading
Loading