diff --git a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py index da3f0e0f47dc..7b2e9e104ffe 100644 --- a/paimon-python/pypaimon/read/scanner/bucket_select_converter.py +++ b/paimon-python/pypaimon/read/scanner/bucket_select_converter.py @@ -51,38 +51,26 @@ * Total cartesian product capped at MAX_VALUES (1000), again matching Java; above that, fall back to a full scan. -Returns a callable ``selector(bucket: int, total_buckets: int) -> bool``. -The callable is cached per ``total_buckets`` to handle the rare case -where bucket count varies across snapshots (rescale). - -TODO: per-partition predicate pre-evaluation. - - Predicates of the form ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk - IN (3,4))`` currently fall through to "no pruning" because the top-level - OR mixes partition and bucket-key constraints. Java simplifies the - predicate per concrete partition value first (replacing partition - leaves with literal true/false and folding AND/OR), so each partition - gets a tighter bucket-key predicate and the corresponding bucket set. - - Implementing this here would need three pieces: - - * a Predicate-replace walker that substitutes a partition's actual - values into partition-column leaves (mirrors Java's - ``paimon-common/.../predicate/PartitionValuePredicateVisitor.java``). - * lifting ``_Selector`` to key its cache by - ``(partition, total_buckets)`` instead of just ``total_buckets``. - * threading the partition value into the early manifest filter - ``FileScanner._build_early_bucket_filter`` (currently sees only - ``(bucket, total_buckets)``). +Returns a callable ``selector(partition, bucket: int, total_buckets: int) +-> bool``. The callable is cached per ``(partition, total_buckets)`` to +handle (a) bucket count variation across snapshots (rescale) and (b) +per-partition predicate specialisation: predicates of the form +``(part='a' AND bk IN (1,2)) OR (part='b' AND bk IN (3,4))`` are +simplified per concrete partition value before bucket selection, so each +partition gets its own tight bucket set. + +When ``partition`` is ``None`` (early manifest filter that has not yet +deserialised the entry), the selector falls back to a partition-agnostic +result — sound but possibly wider than the per-partition tight set. """ from itertools import product -from typing import Any, Callable, Dict, FrozenSet, List, Optional +from typing import Any, Callable, Dict, FrozenSet, List, Optional, Set, Tuple, Union from pypaimon.common.predicate import Predicate from pypaimon.schema.data_types import DataField from pypaimon.table.row.generic_row import GenericRow, GenericRowSerializer -from pypaimon.table.row.internal_row import RowKind +from pypaimon.table.row.internal_row import InternalRow, RowKind from pypaimon.write.row_key_extractor import (_bucket_from_hash, _hash_bytes_by_words) @@ -177,78 +165,17 @@ def _extract_or_clause(or_pred: Predicate, return None if slot is None else [slot, values] -class _Selector: - """Callable bucket filter, lazy + cached per ``total_buckets``.""" - - __slots__ = ('_combinations', '_bucket_key_fields', '_cache') - - def __init__(self, combinations: List[List[Any]], - bucket_key_fields: List[DataField]): - self._combinations = combinations - self._bucket_key_fields = bucket_key_fields - self._cache: Dict[int, FrozenSet[int]] = {} - - def __call__(self, bucket: int, total_buckets: int) -> bool: - # ``total_buckets <= 0`` shows up for postpone / legacy / special - # entries and must NOT be pruned: returning False here would drop - # rows the writer hashed under a different convention. Fail open. - if total_buckets <= 0: - return True - try: - return bucket in self._compute(total_buckets) - except Exception: - # Fail open on any hashing/serialization error (e.g. a literal - # type that doesn't match the bucket-key column's atomic type: - # ``pb.equal('id_bigint', 'foo')`` — GenericRowSerializer raises - # struct.error trying to pack the string as int64). Crashing - # the entire scan here would be worse than skipping pruning; - # the soundness contract still forbids false-negatives. - return True - - def _compute(self, total_buckets: int) -> FrozenSet[int]: - cached = self._cache.get(total_buckets) - if cached is not None: - return cached - result = set() - for combo in self._combinations: - row = GenericRow(list(combo), self._bucket_key_fields, - RowKind.INSERT) - serialized = GenericRowSerializer.to_bytes(row) - # Skip the 4-byte length prefix — matches the writer's hash - # input exactly (see RowKeyExtractor._binary_row_hash_code). - h = _hash_bytes_by_words(serialized[4:]) - result.add(_bucket_from_hash(h, total_buckets)) - frozen = frozenset(result) - self._cache[total_buckets] = frozen - return frozen - - @property - def bucket_combinations(self) -> int: - """Number of (bucket-key) combinations used to compute the filter. - Exposed for tests / observability.""" - return len(self._combinations) - - -def create_bucket_selector( - predicate: Optional[Predicate], - bucket_key_fields: List[DataField]) -> Optional[Callable[[int, int], bool]]: - """Try to derive a bucket selector from ``predicate`` constrained to - ``bucket_key_fields``. +def _build_combinations( + predicate: Predicate, + bucket_key_fields: List[DataField]) -> Optional[List[List[Any]]]: + """Walk ``predicate`` for top-level AND clauses constraining bucket-key + columns by Equal/In, intersect repeated constraints, and return the + cartesian product of literal values (one row per combination). - Returns: - A callable ``(bucket, total_buckets) -> bool`` if the predicate - pins down all bucket keys to a finite Equal/In set; otherwise None - (caller must NOT prune by bucket). + Returns None when the predicate doesn't pin down every bucket-key + column or when the cartesian product exceeds ``MAX_VALUES`` — the + caller treats that as "no pruning, all buckets accept". """ - if predicate is None or not bucket_key_fields: - return None - - # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key - # column types are prone to writer/reader byte-level disagreement on - # equal logical values. Fail open rather than risk false-negatives. - if _has_unsafe_bucket_key_type(bucket_key_fields): - return None - bk_name_to_slot: Dict[str, int] = { f.name: i for i, f in enumerate(bucket_key_fields) } @@ -294,5 +221,274 @@ def create_bucket_selector( if total > MAX_VALUES: return None - combinations = [list(combo) for combo in product(*slot_values)] - return _Selector(combinations, bucket_key_fields) + return [list(combo) for combo in product(*slot_values)] + + +def _hash_combinations(combinations: List[List[Any]], + bucket_key_fields: List[DataField], + total_buckets: int) -> FrozenSet[int]: + result = set() + for combo in combinations: + row = GenericRow(list(combo), bucket_key_fields, RowKind.INSERT) + serialized = GenericRowSerializer.to_bytes(row) + # Skip the 4-byte length prefix — matches the writer's hash + # input exactly (see RowKeyExtractor._binary_row_hash_code). + h = _hash_bytes_by_words(serialized[4:]) + result.add(_bucket_from_hash(h, total_buckets)) + return frozenset(result) + + +def _predicate_touches_partition(predicate: Predicate, + partition_field_names: Set[str]) -> bool: + """True if ``predicate`` references any partition column directly or + inside an AND/OR/NOT subtree.""" + if predicate.method in ('and', 'or', 'not'): + return any(_predicate_touches_partition(c, partition_field_names) + for c in (predicate.literals or [])) + return predicate.field is not None and predicate.field in partition_field_names + + +def _evaluate_partition_leaf(predicate: Predicate, + partition_values: Dict[str, Any]) -> Optional[bool]: + """Evaluate ``predicate`` (a leaf on a partition column) against the + concrete partition values. Returns True / False, or ``None`` if the + leaf isn't safely evaluable here (caller should keep the leaf + unchanged — bucket selection stays sound as long as we don't fold + away an evaluable False). + """ + field_value = partition_values.get(predicate.field) + tester = Predicate.testers.get(predicate.method) + if tester is None: + return None + try: + return tester.test_by_value(field_value, predicate.literals) + except Exception: + return None + + +_AlwaysFalse = False # sentinel: predicate always evaluates to False +_AlwaysTrue = None # sentinel: predicate cleared (always True) + + +def replace_partition_predicate( + predicate: Predicate, + partition_field_names: Set[str], + partition_values: Dict[str, Any]) -> Optional[Union[bool, Predicate]]: + """Substitute partition-column leaves with their concrete values and + fold away always-true / always-false sub-expressions. + + Three-way return: + + * ``None`` — predicate is unconditionally True after substitution + (no constraint left for this partition). + * ``False`` — predicate is unconditionally False (this partition + cannot contain matching rows). + * ``Predicate`` — the simplified predicate; partition leaves are + gone. The caller continues bucket-key extraction on this. + """ + if predicate.method == 'and': + new_children: List[Predicate] = [] + for child in (predicate.literals or []): + simplified = replace_partition_predicate( + child, partition_field_names, partition_values) + if simplified is _AlwaysFalse: + return _AlwaysFalse + if simplified is _AlwaysTrue: + continue + new_children.append(simplified) + if not new_children: + return _AlwaysTrue + if len(new_children) == 1: + return new_children[0] + return Predicate(method='and', index=None, field=None, + literals=new_children) + + if predicate.method == 'or': + new_children = [] + for child in (predicate.literals or []): + simplified = replace_partition_predicate( + child, partition_field_names, partition_values) + if simplified is _AlwaysTrue: + return _AlwaysTrue + if simplified is _AlwaysFalse: + continue + new_children.append(simplified) + if not new_children: + return _AlwaysFalse + if len(new_children) == 1: + return new_children[0] + return Predicate(method='or', index=None, field=None, + literals=new_children) + + # Leaf predicate. + if predicate.field is not None and predicate.field in partition_field_names: + truth = _evaluate_partition_leaf(predicate, partition_values) + if truth is True: + return _AlwaysTrue + if truth is False: + return _AlwaysFalse + # Couldn't safely evaluate — keep the leaf. Bucket selection + # stays sound: the leaf still gets ANDed in, just doesn't help + # narrow buckets for this partition. + return predicate + + # Non-partition leaf: keep as-is. + return predicate + + +def _partition_to_dict(partition: Optional[InternalRow], + partition_fields: List[DataField]) -> Dict[str, Any]: + """Pull each partition column's value out of ``partition`` keyed by + field name. Returns an empty dict when ``partition`` is None.""" + if partition is None: + return {} + out: Dict[str, Any] = {} + for i, field in enumerate(partition_fields): + try: + out[field.name] = partition.get_field(i) + except Exception: + out[field.name] = None + return out + + +def _partition_to_cache_key(partition: Optional[InternalRow], + partition_fields: List[DataField] + ) -> Optional[Tuple[Any, ...]]: + if partition is None or not partition_fields: + return None + try: + return tuple(partition.get_field(i) for i in range(len(partition_fields))) + except Exception: + return None + + +class _Selector: + """Callable bucket filter, lazy + cached per ``(partition, total_buckets)``.""" + + __slots__ = ('_predicate', '_bucket_key_fields', '_partition_fields', + '_cache') + + def __init__(self, predicate: Predicate, + bucket_key_fields: List[DataField], + partition_fields: Optional[List[DataField]] = None): + self._predicate = predicate + self._bucket_key_fields = bucket_key_fields + self._partition_fields = list(partition_fields or []) + self._cache: Dict[Tuple[Optional[Tuple[Any, ...]], int], FrozenSet[int]] = {} + + def __call__(self, *args) -> bool: + # Accept ``(bucket, total_buckets)`` (early manifest filter that + # hasn't deserialised the entry yet — partition unknown) or + # ``(partition, bucket, total_buckets)`` (late filter on a fully + # decoded ``ManifestEntry``). The two-arg form is partition- + # agnostic; partition substitution is skipped. + if len(args) == 2: + partition = None + bucket, total_buckets = args + elif len(args) == 3: + partition, bucket, total_buckets = args + else: + raise TypeError( + "_Selector expects 2 or 3 positional args, got %d" % len(args)) + # ``total_buckets <= 0`` shows up for postpone / legacy / special + # entries and must NOT be pruned: returning False here would drop + # rows the writer hashed under a different convention. Fail open. + if total_buckets <= 0: + return True + try: + return bucket in self._compute(partition, total_buckets) + except Exception: + # Fail open on any hashing / serialization / specialisation + # error (e.g. a literal type that doesn't match the bucket-key + # column's atomic type). Crashing the entire scan here would + # be worse than skipping pruning; the soundness contract still + # forbids false-negatives. + return True + + def _compute(self, partition, total_buckets: int) -> FrozenSet[int]: + cache_key = (_partition_to_cache_key(partition, self._partition_fields), + total_buckets) + cached = self._cache.get(cache_key) + if cached is not None: + return cached + + effective_predicate: Optional[Union[bool, Predicate]] = self._predicate + if partition is not None and self._partition_fields: + partition_values = _partition_to_dict(partition, self._partition_fields) + partition_field_names = {f.name for f in self._partition_fields} + effective_predicate = replace_partition_predicate( + self._predicate, partition_field_names, partition_values) + + if effective_predicate is _AlwaysFalse: + # No row in this partition can match — empty bucket set. + frozen: FrozenSet[int] = frozenset() + self._cache[cache_key] = frozen + return frozen + + if effective_predicate is _AlwaysTrue: + # Predicate cleared after partition substitution — accept all + # buckets for this partition. + frozen = frozenset(range(total_buckets)) + self._cache[cache_key] = frozen + return frozen + + combinations = _build_combinations(effective_predicate, + self._bucket_key_fields) + if combinations is None: + # Couldn't pin down all bucket keys (or above MAX_VALUES) — + # fall back to "all buckets accept" for soundness. + frozen = frozenset(range(total_buckets)) + self._cache[cache_key] = frozen + return frozen + + frozen = _hash_combinations(combinations, self._bucket_key_fields, + total_buckets) + self._cache[cache_key] = frozen + return frozen + + +def create_bucket_selector( + predicate: Optional[Predicate], + bucket_key_fields: List[DataField], + partition_fields: Optional[List[DataField]] = None, +) -> Optional[Callable[[Any, int, int], bool]]: + """Try to derive a bucket selector from ``predicate`` constrained to + ``bucket_key_fields``. + + Returns: + A callable ``(partition, bucket, total_buckets) -> bool``. When + ``partition_fields`` is given and the predicate references those + partition columns, the selector specialises the predicate per + partition value before hashing — this catches mixed forms like + ``(part='a' AND bk IN (1,2)) OR (part='b' AND bk IN (3,4))`` that + would otherwise be unprunable. ``partition=None`` callsites + (early manifest filter that hasn't deserialised the entry yet) + simply get the partition-agnostic result. + + Returns None when the predicate carries no usable bucket-key + constraint at all (caller must NOT prune by bucket). + """ + if predicate is None or not bucket_key_fields: + return None + + # See ``_UNSAFE_BUCKET_KEY_TYPES``: refuse pruning when the bucket-key + # column types are prone to writer/reader byte-level disagreement on + # equal logical values. Fail open rather than risk false-negatives. + if _has_unsafe_bucket_key_type(bucket_key_fields): + return None + + # Sanity gate: if the predicate without any partition substitution + # already fails to pin down bucket keys AND it doesn't touch any + # partition columns, there's no point handing the caller a selector + # that always returns "all buckets" — preserve the original "return + # None for unprunable" contract so the caller can skip the wrap. + partition_names = {f.name for f in (partition_fields or [])} + touches_partition = ( + bool(partition_names) + and _predicate_touches_partition(predicate, partition_names) + ) + if not touches_partition: + if _build_combinations(predicate, bucket_key_fields) is None: + return None + + return _Selector(predicate, bucket_key_fields, partition_fields) diff --git a/paimon-python/pypaimon/read/scanner/file_scanner.py b/paimon-python/pypaimon/read/scanner/file_scanner.py index 70cfa6c978f2..ea0e8219ee4e 100755 --- a/paimon-python/pypaimon/read/scanner/file_scanner.py +++ b/paimon-python/pypaimon/read/scanner/file_scanner.py @@ -30,6 +30,7 @@ from pypaimon.manifest.schema.manifest_entry import ManifestEntry from pypaimon.manifest.schema.manifest_file_meta import ManifestFileMeta from pypaimon.manifest.simple_stats_evolutions import SimpleStatsEvolutions +from pypaimon.schema.data_types import DataField from pypaimon.read.plan import Plan from pypaimon.read.push_down_utils import (_get_all_fields, remove_row_id_filter, @@ -356,11 +357,12 @@ def _build_early_bucket_filter(self): """Compose the (bucket, total_buckets) -> bool used by the manifest reader to drop entries before deserialising ``_FILE`` / partition. - Mirrors the BucketFilter applied at Java's InternalRow stage in - ``ManifestEntryCache``. The signature is intentionally minimal: - per-partition predicate pre-evaluation would also need - ``(partition, bucket, total_buckets)``, but the current selector - is partition-agnostic. + The selector is partition-aware now, but at this early stage the + partition field has not been deserialised yet, so callers stick + with the two-arg form. The selector internally falls back to a + partition-agnostic over-approximation; per-partition tightening + still happens later in ``_filter_manifest_entry`` once the entry + is fully decoded. """ only_real = self.only_read_real_buckets selector = self._bucket_selector @@ -479,7 +481,21 @@ def _init_bucket_selector(self): return None if not bucket_key_fields: return None - return create_bucket_selector(self.predicate, bucket_key_fields) + # Partition fields are passed so the selector can specialise + # the predicate per partition value at the late filter stage, + # turning ``(part='a' AND bk=1) OR (part='b' AND bk=2)`` into a + # precise bucket pick per partition instead of an over-scan. + partition_fields: Optional[List[DataField]] = None + if self.table.partition_keys: + partition_fields = [ + self.table.field_dict[name] + for name in self.table.partition_keys + if name in self.table.field_dict + ] + return create_bucket_selector( + self.predicate, bucket_key_fields, + partition_fields=partition_fields, + ) def _filter_manifest_entry(self, entry: ManifestEntry) -> bool: # Redundant safety net: the early filter in the manifest reader @@ -489,7 +505,8 @@ def _filter_manifest_entry(self, entry: ManifestEntry) -> bool: return False if (self._bucket_selector is not None and entry.bucket >= 0 - and not self._bucket_selector(entry.bucket, entry.total_buckets)): + and not self._bucket_selector( + entry.partition, entry.bucket, entry.total_buckets)): return False if self.partition_key_predicate and not self.partition_key_predicate.test(entry.partition): return False diff --git a/paimon-python/pypaimon/tests/pushdown_bucket_test.py b/paimon-python/pypaimon/tests/pushdown_bucket_test.py index b83283200e8c..80d3065ca312 100644 --- a/paimon-python/pypaimon/tests/pushdown_bucket_test.py +++ b/paimon-python/pypaimon/tests/pushdown_bucket_test.py @@ -345,6 +345,217 @@ def test_type_mismatched_literal_fails_open_not_crash(self): "not crash (bucket={}, total={})".format(b, total)) +class PartitionAwareBucketSelectorUnitTest(unittest.TestCase): + """Unit tests for the per-partition predicate specialisation path. + + Covers ``replace_partition_predicate`` (the AND/OR fold walker) and + the partition-aware ``_Selector.__call__(partition, bucket, + total_buckets)`` 3-arg form that ``FileScanner._filter_manifest_entry`` + will use after wiring.""" + + @classmethod + def setUpClass(cls): + cls.id_field = _bigint_field(0, 'id') + cls.part_field = DataField(2, 'part', AtomicType('STRING')) + cls.pb = PredicateBuilder([cls.id_field, cls.part_field]) + + # ----- replace_partition_predicate -------------------------------- + + def test_replace_partition_leaf_to_true_drops_constraint(self): + from pypaimon.read.scanner.bucket_select_converter import \ + replace_partition_predicate + # ``part = 'a' AND id = 1`` against partition {part: 'a'} → + # part leaf becomes True → AND fold removes it → only ``id = 1`` + pred = PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]) + result = replace_partition_predicate( + pred, {'part'}, {'part': 'a'}) + self.assertTrue(isinstance(result, type(pred)), + "AND should fold to a remaining single leaf") + self.assertEqual(result.method, 'equal') + self.assertEqual(result.field, 'id') + + def test_replace_partition_leaf_to_false_collapses_and(self): + from pypaimon.read.scanner.bucket_select_converter import \ + replace_partition_predicate + # ``part = 'a' AND id = 1`` against partition {part: 'b'} → + # part leaf becomes False → AND collapses to AlwaysFalse (False). + pred = PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]) + result = replace_partition_predicate( + pred, {'part'}, {'part': 'b'}) + self.assertIs(result, False) + + def test_replace_partition_leaf_in_or_keeps_other_branch(self): + from pypaimon.read.scanner.bucket_select_converter import \ + replace_partition_predicate + # ``(part='a' AND id=1) OR (part='b' AND id=2)`` against + # partition {part: 'a'} → first OR child becomes ``id=1``, second + # collapses to AlwaysFalse and is dropped. Result is just ``id=1``. + pred = PredicateBuilder.or_predicates([ + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]), + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'b'), + self.pb.equal('id', 2), + ]), + ]) + result = replace_partition_predicate( + pred, {'part'}, {'part': 'a'}) + # OR with a single surviving child unwraps to that child. + self.assertEqual(result.method, 'equal') + self.assertEqual(result.field, 'id') + self.assertEqual(result.literals, [1]) + + def test_replace_partition_leaf_in_or_other_partition(self): + from pypaimon.read.scanner.bucket_select_converter import \ + replace_partition_predicate + # Same predicate, partition {part: 'b'} → second branch survives + # as ``id=2``. + pred = PredicateBuilder.or_predicates([ + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]), + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'b'), + self.pb.equal('id', 2), + ]), + ]) + result = replace_partition_predicate( + pred, {'part'}, {'part': 'b'}) + self.assertEqual(result.method, 'equal') + self.assertEqual(result.field, 'id') + self.assertEqual(result.literals, [2]) + + def test_replace_partition_leaf_unrelated_predicate_unchanged(self): + from pypaimon.read.scanner.bucket_select_converter import \ + replace_partition_predicate + # No partition leaf → predicate returned as-is. + pred = self.pb.equal('id', 42) + result = replace_partition_predicate( + pred, {'part'}, {'part': 'a'}) + self.assertIs(result, pred) + + # ----- _Selector partition-aware path ----------------------------- + + def test_selector_3arg_specialises_per_partition(self): + # ``(part='a' AND id=1) OR (part='b' AND id=2)`` should hit + # bucket(1) only when partition='a' and bucket(2) only when + # partition='b'. Master without this fix returns "all buckets". + pred = PredicateBuilder.or_predicates([ + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]), + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'b'), + self.pb.equal('id', 2), + ]), + ]) + sel = create_bucket_selector( + pred, [self.id_field], partition_fields=[self.part_field]) + self.assertIsNotNone(sel) + bucket_for_1 = _hash_bucket([1], [self.id_field], total=8) + bucket_for_2 = _hash_bucket([2], [self.id_field], total=8) + + part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT) + part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT) + + for b in range(8): + self.assertEqual(sel(part_a, b, 8), b == bucket_for_1, + "partition a only keeps bucket %d" % bucket_for_1) + self.assertEqual(sel(part_b, b, 8), b == bucket_for_2, + "partition b only keeps bucket %d" % bucket_for_2) + + def test_selector_falls_through_when_partition_unknown(self): + """Early manifest filter passes ``partition=None`` (or uses the + 2-arg form) — no specialisation runs, bucket set falls back to a + sound over-approximation: all buckets accept.""" + pred = PredicateBuilder.or_predicates([ + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]), + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'b'), + self.pb.equal('id', 2), + ]), + ]) + sel = create_bucket_selector( + pred, [self.id_field], partition_fields=[self.part_field]) + self.assertIsNotNone(sel) + # 2-arg form (legacy callsite) — partition unknown, all buckets keep. + for b in range(8): + self.assertTrue(sel(b, 8), + "partition-unknown call must accept all buckets") + # 3-arg form with partition=None has the same semantics. + for b in range(8): + self.assertTrue(sel(None, b, 8)) + + def test_selector_partition_not_matching_returns_empty_bucket_set(self): + # ``part = 'a' AND id = 1`` on partition {part: 'c'} simplifies to + # AlwaysFalse — the selector returns False for every bucket since + # no row in this partition can possibly match. Sound: dropping a + # partition that *can't* contain matches doesn't lose data. + pred = PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]) + sel = create_bucket_selector( + pred, [self.id_field], partition_fields=[self.part_field]) + self.assertIsNotNone(sel) + part_c = GenericRow(['c'], [self.part_field], RowKind.INSERT) + for b in range(8): + self.assertFalse(sel(part_c, b, 8), + "partition c can't satisfy part='a', " + "drop every bucket (b=%d)" % b) + + def test_selector_partition_only_constraint_drops_partition(self): + # ``part='a' AND id IN (1,2)`` — same partition value 'a' + # specialises ``part='a'`` to True, leaving ``id IN (1,2)``. + pred = PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.is_in('id', [1, 2]), + ]) + sel = create_bucket_selector( + pred, [self.id_field], partition_fields=[self.part_field]) + self.assertIsNotNone(sel) + part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT) + expected = {_hash_bucket([v], [self.id_field], 8) for v in (1, 2)} + for b in range(8): + self.assertEqual(sel(part_a, b, 8), b in expected) + + def test_selector_caches_per_partition(self): + pred = PredicateBuilder.or_predicates([ + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'a'), + self.pb.equal('id', 1), + ]), + PredicateBuilder.and_predicates([ + self.pb.equal('part', 'b'), + self.pb.equal('id', 2), + ]), + ]) + sel = create_bucket_selector( + pred, [self.id_field], partition_fields=[self.part_field]) + part_a = GenericRow(['a'], [self.part_field], RowKind.INSERT) + part_b = GenericRow(['b'], [self.part_field], RowKind.INSERT) + # Drive the cache. + for _ in range(5): + sel(part_a, 0, 8) + sel(part_b, 0, 8) + # Cache keyed by (partition tuple, total_buckets); two distinct + # partitions × one total → exactly two entries. + self.assertEqual(len(sel._cache), 2) + + # --------------------------------------------------------------------------- # Layer 2 — Integration: real tables, public API, assert correctness AND # that pruning actually fired (otherwise we're not testing the optimisation, @@ -606,6 +817,66 @@ def test_bucket_key_option_overrides_pk_for_pruning(self): self.assertEqual(self._split_buckets(splits), self._expected_buckets(table, [17])) + def test_per_partition_pruning_with_mixed_or(self): + """``(part='a' AND id=1) OR (part='b' AND id=2)``: each partition + sees only the bucket for its own ``id`` literal. Without + per-partition predicate specialisation this query falls through + to "all buckets in both partitions".""" + opts = {'bucket': '4', 'file.format': 'parquet'} + pa_schema = pa.schema([ + pa.field('part', pa.string(), nullable=False), + pa.field('id', pa.int64(), nullable=False), + ('val', pa.int64()), + ]) + schema = Schema.from_pyarrow_schema( + pa_schema, primary_keys=['part', 'id'], + partition_keys=['part'], options=opts) + identifier = 'default.per_part_mixed_or' + self.catalog.create_table(identifier, schema, False) + table = self.catalog.get_table(identifier) + # Two partitions × three id values each → up to 6 (part, bucket) + # combinations after the writer hashes. + rows = [] + for p in ('a', 'b'): + for i in (1, 2, 3): + rows.append({'part': p, 'id': i, 'val': i * 7}) + wb = table.new_batch_write_builder() + w = wb.new_write() + c = wb.new_commit() + try: + w.write_arrow(pa.Table.from_pylist(rows, schema=pa_schema)) + c.commit(w.prepare_commit()) + finally: + w.close() + c.close() + + pb = table.new_read_builder().new_predicate_builder() + from pypaimon.common.predicate_builder import PredicateBuilder + mixed = PredicateBuilder.or_predicates([ + PredicateBuilder.and_predicates([ + pb.equal('part', 'a'), + pb.equal('id', 1), + ]), + PredicateBuilder.and_predicates([ + pb.equal('part', 'b'), + pb.equal('id', 2), + ]), + ]) + got, splits = self._read_with(table, mixed) + # Correctness: only the two matching rows. + got_sorted = sorted(got, key=lambda r: (r['part'], r['id'])) + self.assertEqual( + got_sorted, + [{'part': 'a', 'id': 1, 'val': 7}, + {'part': 'b', 'id': 2, 'val': 14}]) + # Pruning effectiveness: across both partitions we should see at + # most two distinct (partition, bucket) splits — one per branch. + # Without per-partition pruning we'd see every (partition, bucket) + # combo that exists on disk for the predicate's id literals. + self.assertLessEqual(len(splits), 2, + "per-partition pruning should keep ≤ 2 splits, " + "got %d" % len(splits)) + # --------------------------------------------------------------------------- # Layer 3 — Property: random PK tables, random Equal/In predicates,