Commit fa58c7e

perf: short-circuit COO.reshape when -1 resolves to self.shape (#935)
* perf: short-circuit COO.reshape when -1 resolves to self.shape

  COO.reshape returns self when self.shape equals the requested shape, but only checks before resolving any -1 in the target. sparse.tensordot passes shapes like (-1, N) even for a 2D x 2D matmul that doesn't actually change shape, so the short-circuit never fires and a full reshape runs (linear_loc + coord rebuild + new COO allocation).

  Moving the -1 resolution before the equality check avoids that work for callers that pass a -1 factorization of the current shape.

  Behavior is a strict subset ("fewer copies"): any reshape that already returned self before still does, and reshape((-1, ...)) that resolves to the current shape now also returns self, matching the documented contract.

  Measured ~16% speedup on a warm conservative-regrid loop (xarray-regrid ConservativeRegridder.regrid) whose tensordot call sits on the hot path; bit-identical output. All 6050 existing numba-backend tests pass.

* perf: use `-1 in shape` for presence check in COO.reshape

  Replace `any(d == -1 for d in shape)` with `-1 in shape`. The latter is a C-level tuple containment, the former a Python-level generator.

  On this machine (micro): 221 ns -> 45 ns per check.

  End-to-end on the PR's repro (median of 3):

    reshape(a.shape): 0.4 us -> 0.2 us (matches main; erases the regression introduced by running the -1 check unconditionally)
    reshape((-1, K)): 2.7 us -> 2.3 us (small incremental win)

  Pure readability / perf nit; semantics are identical since shape is already a tuple of ints by the line above.
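The reordering described in the first bullet can be illustrated without the `sparse` library. The sketch below is not the library's code: `FakeArray`, `resolve_shape`, `reshape_old`, and `reshape_new` are hypothetical names, the class tracks only a shape, and integer division stands in for the `int(self.size / np.prod(...))` expression used in the real method.

```python
import math


def resolve_shape(size, shape):
    """Replace a single -1 in `shape` with the extent implied by `size`."""
    shape = tuple(shape)
    if -1 in shape:
        known = math.prod(d for d in shape if d != -1)
        shape = tuple(d if d != -1 else size // known for d in shape)
    return shape


class FakeArray:
    """Hypothetical stand-in for a COO array; it only tracks its shape."""

    def __init__(self, shape):
        self.shape = tuple(shape)
        self.size = math.prod(self.shape)

    def _build(self, shape):
        # Stands in for the expensive path (coordinate rebuild + allocation).
        if math.prod(shape) != self.size:
            raise ValueError(f"cannot reshape array of size {self.size} into shape {shape}")
        return FakeArray(shape)

    def reshape_old(self, shape):
        # Old order: compare first, resolve -1 afterwards.
        shape = tuple(shape)
        if self.shape == shape:
            return self
        return self._build(resolve_shape(self.size, shape))

    def reshape_new(self, shape):
        # New order: resolve -1 first, then compare.
        shape = resolve_shape(self.size, shape)
        if self.shape == shape:
            return self
        return self._build(shape)


a = FakeArray((4, 6))
print(a.reshape_old((-1, 6)) is a)  # False: the short-circuit is missed
print(a.reshape_new((-1, 6)) is a)  # True: -1 resolves to 4, shapes match
print(a.reshape_new((4, 6)) is a)   # True: the existing fast path is kept
```

With the old order, `(-1, 6)` never compares equal to `(4, 6)`, so a new object is always built. With the new order the same request returns `self`.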
1 parent f65a764 commit fa58c7e

1 file changed

Lines changed: 3 additions & 3 deletions

sparse/numba_backend/_coo/core.py

@@ -1070,11 +1070,11 @@ def reshape(self, shape, order="C"):
         if order not in {"C", None}:
             raise NotImplementedError("The `order` parameter is not supported")
 
-        if self.shape == shape:
-            return self
-        if any(d == -1 for d in shape):
+        if -1 in shape:
             extra = int(self.size / np.prod([d for d in shape if d != -1]))
             shape = tuple([d if d != -1 else extra for d in shape])
+        if self.shape == shape:
+            return self
 
         if self.size != reduce(operator.mul, shape, 1):
             raise ValueError(f"cannot reshape array of size {self.size} into shape {shape}")
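The switch from `any(d == -1 for d in shape)` to `-1 in shape` can be checked in isolation with a rough micro-benchmark such as the one below. The helper names `has_minus_one_gen` and `has_minus_one_in` and the sample shapes are made up for illustration, and absolute timings will vary by machine and Python version, so none are shown here.

```python
import timeit


def has_minus_one_gen(shape):
    # Python-level generator: one frame resumption per element.
    return any(d == -1 for d in shape)


def has_minus_one_in(shape):
    # Tuple containment: the element loop runs inside the interpreter's C code.
    return -1 in shape


# Both forms agree on tuples of ints, including the empty tuple.
for candidate in [(512, 64), (-1, 64), (8, -1, 4), ()]:
    assert has_minus_one_gen(candidate) == has_minus_one_in(candidate)

shape = (512, 64)  # no -1 present, so both checks scan the whole tuple
n = 200_000
t_gen = timeit.timeit(lambda: has_minus_one_gen(shape), number=n)
t_in = timeit.timeit(lambda: has_minus_one_in(shape), number=n)
print(f"any(...) : {t_gen / n * 1e9:.0f} ns per call")
print(f"-1 in ...: {t_in / n * 1e9:.0f} ns per call")
```

The lambda wrapper adds the same call overhead to both measurements, so the difference between the two numbers is more informative than either value on its own.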
