diff --git a/sparse/numba_backend/_coo/core.py b/sparse/numba_backend/_coo/core.py index f8d98b41..ef980fbc 100644 --- a/sparse/numba_backend/_coo/core.py +++ b/sparse/numba_backend/_coo/core.py @@ -1070,11 +1070,11 @@ def reshape(self, shape, order="C"): if order not in {"C", None}: raise NotImplementedError("The `order` parameter is not supported") - if self.shape == shape: - return self - if any(d == -1 for d in shape): + if -1 in shape: extra = int(self.size / np.prod([d for d in shape if d != -1])) shape = tuple([d if d != -1 else extra for d in shape]) + if self.shape == shape: + return self if self.size != reduce(operator.mul, shape, 1): raise ValueError(f"cannot reshape array of size {self.size} into shape {shape}")