From a1ac726760c60bf668421809f1f0680bd9d56ee7 Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 3 Jun 2026 20:21:32 +0000 Subject: [PATCH 1/3] test(rfd3): unit tests for symmetry frame and contig helpers Fixture-backed CPU unit tests for the pure symmetry geometry in rfd3.inference.symmetry: cyclic/dihedral frame generation, the framecoord <-> (R, t) round-trip, pack/unpack, Kabsch _align/_rms recovery, symmetry-id parsing, and contig-string expansion. New tests/rfd3/ dir, auto-collected by testpaths and kept separate from the cluster-only models/rfd3/tests/. No source or config changes. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- tests/rfd3/test_symmetry_contigs.py | 44 ++++++ tests/rfd3/test_symmetry_frames.py | 216 ++++++++++++++++++++++++++++ 2 files changed, 260 insertions(+) create mode 100644 tests/rfd3/test_symmetry_contigs.py create mode 100644 tests/rfd3/test_symmetry_frames.py diff --git a/tests/rfd3/test_symmetry_contigs.py b/tests/rfd3/test_symmetry_contigs.py new file mode 100644 index 00000000..058edbfe --- /dev/null +++ b/tests/rfd3/test_symmetry_contigs.py @@ -0,0 +1,44 @@ +"""Unit tests for rfd3.inference.symmetry.contigs. + +The contig helpers expand compact motif specifications into explicit per-residue +labels. `expand_contig_to_resid_from_string` reads a single-character chain id +followed by an inclusive `start-end` residue range (e.g. "A1-5" -> A1..A5); +`expand_contig_unsym_motif` expands the range entries in a mixed list while +keeping the plain (dash-free) names. Both are pure string logic, pinned here. +""" + +from rfd3.inference.symmetry.contigs import ( + expand_contig_to_resid_from_string, + expand_contig_unsym_motif, +) + +# --- expand_contig_to_resid_from_string --------------------------------------- + + +def test_expand_contig_basic_range(): + assert expand_contig_to_resid_from_string("A1-5") == ["A1", "A2", "A3", "A4", "A5"] + + +def test_expand_contig_is_inclusive_of_endpoints(): + assert expand_contig_to_resid_from_string("B10-12") == ["B10", "B11", "B12"] + + +def test_expand_contig_single_residue_range(): + assert expand_contig_to_resid_from_string("C7-7") == ["C7"] + + +# --- expand_contig_unsym_motif ------------------------------------------------ + + +def test_expand_unsym_motif_expands_ranges_and_keeps_plain_names(): + # plain (dash-free) names are kept first, expanded ranges appended after. + result = expand_contig_unsym_motif(["A1-3", "LIG"]) + assert result == ["LIG", "A1", "A2", "A3"] + + +def test_expand_unsym_motif_without_ranges_is_unchanged(): + assert expand_contig_unsym_motif(["LIG", "GLY"]) == ["LIG", "GLY"] + + +def test_expand_unsym_motif_only_ranges(): + assert expand_contig_unsym_motif(["A1-2"]) == ["A1", "A2"] diff --git a/tests/rfd3/test_symmetry_frames.py b/tests/rfd3/test_symmetry_frames.py new file mode 100644 index 00000000..c9b25663 --- /dev/null +++ b/tests/rfd3/test_symmetry_frames.py @@ -0,0 +1,216 @@ +"""Unit tests for rfd3.inference.symmetry.frames. + +These pure functions build and manipulate the symmetry frames that drive RFD3's +symmetric-assembly generation: the cyclic / dihedral rotation sets, the +frame <-> (rotation, translation) conversions used in the symmetry loss, and the +Kabsch alignment that recovers a transform from two coordinate sets. Their +contracts — a frame is an `(R, t)` pair; `Cn` is `n` proper rotations about z; +`Dn` is `2n`; the framecoord conversion round-trips; `_align` recovers an exact +rigid transform — are not obvious from the signatures, so the tests pin them on +small CPU inputs. + +One sharp edge is pinned deliberately: `is_valid_rotation_matrix` checks only +orthogonality (`R @ R.T == I`), not `det(R) == +1`, so it accepts reflections +(see the roadmap finding on tightening it). +""" + +import numpy as np +import pytest +import torch +from rfd3.inference.symmetry.frames import ( + RTs_to_framecoords, + _align, + _rms, + decompose_symmetry_frame, + framecoords_to_RTs, + get_cyclic_frames, + get_dihedral_frames, + get_symmetry_frames_from_symmetry_id, + is_valid_rotation_matrix, + pack_vector, + unpack_vector, +) + +# --- is_valid_rotation_matrix ------------------------------------------------- + + +def test_identity_is_valid_rotation(): + assert is_valid_rotation_matrix(np.eye(3)) + + +def test_proper_rotation_is_valid(): + R = get_cyclic_frames(4)[1][0] # 90 deg about z + assert is_valid_rotation_matrix(R) + + +def test_non_orthogonal_matrix_is_invalid(): + assert not is_valid_rotation_matrix(2 * np.eye(3)) + + +def test_reflection_passes_orthogonality_only_check(): + """`is_valid_rotation_matrix` constrains orthogonality, not determinant. + + A reflection (det -1) is orthogonal, so it is accepted even though it is not + a proper rotation. Pinned to document the actual contract; see the roadmap + finding on tightening this to also require det == +1. + """ + reflection = np.diag([1.0, 1.0, -1.0]) + assert np.isclose(np.linalg.det(reflection), -1.0) + assert is_valid_rotation_matrix(reflection) + + +# --- get_cyclic_frames -------------------------------------------------------- + + +def test_cyclic_frame_count_and_zero_translation(): + frames = get_cyclic_frames(3) + assert len(frames) == 3 + for _, t in frames: + assert np.array_equal(t, np.zeros(3)) + + +def test_cyclic_first_frame_is_identity(): + R, _ = get_cyclic_frames(6)[0] + assert np.allclose(R, np.eye(3)) + + +def test_cyclic_frames_are_proper_rotations(): + for R, _ in get_cyclic_frames(5): + assert is_valid_rotation_matrix(R) + assert np.isclose(np.linalg.det(R), 1.0) + + +def test_cyclic_frame_rotates_about_z_by_expected_angle(): + # order 4, index 1 -> 90 deg CCW about z: e_x -> e_y, z fixed. + R, _ = get_cyclic_frames(4)[1] + assert np.allclose(R @ np.array([1.0, 0.0, 0.0]), [0.0, 1.0, 0.0], atol=1e-12) + assert np.allclose(R @ np.array([0.0, 0.0, 1.0]), [0.0, 0.0, 1.0]) + + +def test_cyclic_generator_has_order_n(): + # applying the unit rotation `order` times returns to identity. + order = 7 + R = get_cyclic_frames(order)[1][0] + assert np.allclose(np.linalg.matrix_power(R, order), np.eye(3), atol=1e-9) + + +# --- get_dihedral_frames ------------------------------------------------------ + + +def test_dihedral_frame_count_is_double_order(): + assert len(get_dihedral_frames(3)) == 6 + + +def test_dihedral_frames_are_proper_rotations(): + # both the rotation frames and the flipped frames are proper rotations. + for R, t in get_dihedral_frames(4): + assert np.array_equal(t, np.zeros(3)) + assert is_valid_rotation_matrix(R) + assert np.isclose(np.linalg.det(R), 1.0) + + +def test_dihedral_even_frames_match_cyclic(): + order = 3 + dihedral = get_dihedral_frames(order) + cyclic = get_cyclic_frames(order) + for i in range(order): + assert np.allclose(dihedral[2 * i][0], cyclic[i][0]) + + +# --- get_symmetry_frames_from_symmetry_id ------------------------------------- + + +def test_symmetry_id_cyclic(): + frames = get_symmetry_frames_from_symmetry_id("C2") + assert len(frames) == 2 + assert all(is_valid_rotation_matrix(R) for R, _ in frames) + + +def test_symmetry_id_dihedral(): + assert len(get_symmetry_frames_from_symmetry_id("D2")) == 4 + + +def test_symmetry_id_is_case_insensitive(): + assert len(get_symmetry_frames_from_symmetry_id("c3")) == 3 + assert len(get_symmetry_frames_from_symmetry_id("d3")) == 6 + + +def test_symmetry_id_unsupported_raises(): + with pytest.raises(ValueError, match="not supported"): + get_symmetry_frames_from_symmetry_id("X9") + + +# --- RTs_to_framecoords <-> framecoords_to_RTs -------------------------------- + + +def test_framecoord_roundtrip_recovers_rotation_and_translation(): + R = torch.tensor(get_cyclic_frames(5)[1][0], dtype=torch.float64) + t = torch.tensor([3.0, -2.0, 5.0], dtype=torch.float64) + Ori, X, Y = RTs_to_framecoords(R, t, sig=1.0) + R_rec, T_rec = framecoords_to_RTs(Ori, X, Y) + assert torch.allclose(R_rec, R, atol=1e-5) + assert torch.allclose(T_rec, t, atol=1e-5) + + +def test_RTs_to_framecoords_accepts_numpy_and_returns_torch(): + R = get_cyclic_frames(4)[1][0] # numpy + t = np.array([1.0, 2.0, 3.0]) + Ori, X, Y = RTs_to_framecoords(R, t) + assert isinstance(Ori, torch.Tensor) + assert isinstance(X, torch.Tensor) + # Ori is the translation; X/Y sit one unit along the first two rotation rows. + assert torch.allclose(Ori, torch.from_numpy(t)) + + +# --- pack_vector / unpack_vector ---------------------------------------------- + + +def test_pack_unpack_roundtrip_preserves_values_and_dtype(): + v = np.array([1.5, -2.0, 3.25], dtype=np.float64) + packed = pack_vector(v) + assert packed.shape == (1,) + unpacked = unpack_vector(packed) + assert unpacked.shape == (1, 3) + assert np.array_equal(unpacked[0], v) + assert unpacked.dtype == v.dtype + + +def test_pack_vector_preserves_integer_dtype(): + v = np.array([1, 2, 3], dtype=np.int32) + assert unpack_vector(pack_vector(v)).dtype == np.int32 + + +# --- _align / _rms (Kabsch) --------------------------------------------------- + + +def test_align_recovers_known_rigid_transform(): + rng = np.random.default_rng(0) + X_moving = rng.normal(size=(8, 3)) + R_true = get_cyclic_frames(4)[1][0] # 90 deg about z + centroid = np.array([10.0, -3.0, 2.0]) + X_fixed = (X_moving - X_moving.mean(axis=0)) @ R_true.T + centroid + + u_moving, R, u_fixed = _align(X_fixed, X_moving) + assert is_valid_rotation_matrix(R) + assert np.allclose(R, R_true, atol=1e-6) + assert np.allclose(u_fixed, centroid, atol=1e-6) + # the recovered transform aligns moving onto fixed with ~zero RMSD. + assert _rms(X_fixed, X_moving, u_moving, R, u_fixed) < 1e-6 + + +def test_align_identical_point_sets_is_identity(): + rng = np.random.default_rng(1) + X = rng.normal(size=(6, 3)) + _, R, _ = _align(X, X) + assert np.allclose(R, np.eye(3), atol=1e-6) + + +# --- decompose_symmetry_frame ------------------------------------------------- + + +def test_decompose_symmetry_frame_origin_is_translation(): + R = get_cyclic_frames(4)[1][0] + T = np.array([1.0, 2.0, 3.0]) + Ori, _X, _Y = decompose_symmetry_frame((R, T)) + # each returned value is a packed (1,) structured array; the origin is T. + assert np.allclose(unpack_vector(Ori)[0], T, atol=1e-6) From 6c5fbecbddbb80c633c49a51b4df55600c2c70ed Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 3 Jun 2026 20:43:46 +0000 Subject: [PATCH 2/3] chore(mypy): bring models/rf3 into scope behind an ignore_errors ratchet Add models/rf3/src/rf3 to [tool.mypy].files and seed a fresh per-module ignore_errors ratchet listing the 26 rf3 modules with pre-existing errors (158 total), mirroring the rfd3 bootstrap. mypy now type-checks the 28 already-clean rf3 modules; the 26 are cleared slice-by-slice in follow-ups. Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- pyproject.toml | 37 ++++++++++++++++++++++++++++++++++++- 1 file changed, 36 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 07b0ef34..de1f72f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -208,7 +208,7 @@ typeCheckingMode = "off" # Per-module strictness is ratcheted up via [tool.mypy.overrides] as annotations land. [tool.mypy] python_version = "3.12" -files = ["src/foundry", "src/foundry_cli", "models/rfd3/src/rfd3"] +files = ["src/foundry", "src/foundry_cli", "models/rfd3/src/rfd3", "models/rf3/src/rf3"] ignore_missing_imports = true warn_unused_ignores = true warn_redundant_casts = true @@ -248,6 +248,41 @@ module = [ ] ignore_errors = true +# rf3 enablement ratchet. `models/rf3` was brought into mypy's scope in 0014; +# these modules had pre-existing type errors at that point and are exempted until +# annotated. Fix the errors and remove the entry to enable type-checking for that +# module (same playbook as the rfd3 ratchet above). Do NOT add modules. +[[tool.mypy.overrides]] +module = [ + "rf3.callbacks.dump_validation_structures", + "rf3.callbacks.metrics_logging", + "rf3.data.extra_xforms", + "rf3.data.ground_truth_template", + "rf3.data.paired_msa", + "rf3.data.pipeline_utils", + "rf3.data.pipelines", + "rf3.inference", + "rf3.inference_engines.rf3", + "rf3.metrics.chiral", + "rf3.metrics.clashing_chains", + "rf3.metrics.distogram", + "rf3.metrics.lddt", + "rf3.metrics.metadata", + "rf3.metrics.predicted_error", + "rf3.metrics.rasa", + "rf3.metrics.selected_distances", + "rf3.model.RF3", + "rf3.model.layers.af3_diffusion_transformer", + "rf3.symmetry.resolve", + "rf3.trainers.rf3", + "rf3.utils.inference", + "rf3.utils.io", + "rf3.utils.loss", + "rf3.utils.predict_and_score", + "rf3.utils.predicted_error", +] +ignore_errors = true + # Testing ---------------------------------------------------------------------------- [tool.pytest.ini_options] testpaths = ["tests"] From a7100f6d2b5c3fd83292da29ebb4d26c0e655d14 Mon Sep 17 00:00:00 2001 From: Sergey Lyskov Date: Wed, 3 Jun 2026 21:48:12 +0000 Subject: [PATCH 3/3] test(rf3): clear easy-tier mypy modules and add first unit tests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit mypy: fix the 3 one-error rf3 modules (extra_xforms setattr, metrics_logging set[str], predict_and_score assert-narrow on run()'s dict|None return) and drop them from the ignore_errors ratchet (26 -> 23). Annotation/type-honesty only, no behavior change. tests: add tests/rf3/ first pass — 15 fixture-backed CPU unit tests pinning utils/frames.py (rigid_from_3_points proper-rotation/origin/batch /NA-branch, is_atom) and metrics/metric_utils.py (bin midpoints, unbin expectation, chain/interface masks, subsampled mean/min). Co-authored-by: lyskov-ai <277346777+lyskov-ai@users.noreply.github.com> --- .../rf3/src/rf3/callbacks/metrics_logging.py | 2 +- models/rf3/src/rf3/data/extra_xforms.py | 2 +- models/rf3/src/rf3/utils/predict_and_score.py | 4 +- pyproject.toml | 3 - tests/rf3/test_frames.py | 81 ++++++++++++++ tests/rf3/test_metric_utils.py | 104 ++++++++++++++++++ 6 files changed, 190 insertions(+), 6 deletions(-) create mode 100644 tests/rf3/test_frames.py create mode 100644 tests/rf3/test_metric_utils.py diff --git a/models/rf3/src/rf3/callbacks/metrics_logging.py b/models/rf3/src/rf3/callbacks/metrics_logging.py index c3ad2460..19b9577e 100755 --- a/models/rf3/src/rf3/callbacks/metrics_logging.py +++ b/models/rf3/src/rf3/callbacks/metrics_logging.py @@ -171,7 +171,7 @@ def _load_and_concatenate_csvs(self, epoch: int) -> pd.DataFrame: files = list(self.save_dir.glob(pattern)) # Track which example_id + dataset combinations we've already seen - seen_examples = set() + seen_examples: set[str] = set() final_dataframes = [] for f in files: diff --git a/models/rf3/src/rf3/data/extra_xforms.py b/models/rf3/src/rf3/data/extra_xforms.py index c65c71da..18d2e99a 100644 --- a/models/rf3/src/rf3/data/extra_xforms.py +++ b/models/rf3/src/rf3/data/extra_xforms.py @@ -52,7 +52,7 @@ def _patched(atom_array, *args, **kwargs): mol.AddConformer(conf, assignId=True) return result - _patched._input_coord_fallback_patched = True + setattr(_patched, "_input_coord_fallback_patched", True) _rdkit_utils.sample_rdkit_conformer_for_atom_array = _patched # af3_reference_molecule imports the function directly, so patch that reference too _af3_ref.sample_rdkit_conformer_for_atom_array = _patched diff --git a/models/rf3/src/rf3/utils/predict_and_score.py b/models/rf3/src/rf3/utils/predict_and_score.py index c775e7e1..166fc423 100644 --- a/models/rf3/src/rf3/utils/predict_and_score.py +++ b/models/rf3/src/rf3/utils/predict_and_score.py @@ -140,7 +140,9 @@ def predict_and_score_with_rf3( annotate_b_factor_with_plddt=annotate_b_factor_with_plddt, ) - # Extract results for this example + # Extract results for this example. run() returns a dict (never None) + # when out_dir is None, per its documented contract. + assert inference_results is not None result = inference_results[example_id] # Check for early stopping diff --git a/pyproject.toml b/pyproject.toml index de1f72f2..e0108c02 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -255,8 +255,6 @@ ignore_errors = true [[tool.mypy.overrides]] module = [ "rf3.callbacks.dump_validation_structures", - "rf3.callbacks.metrics_logging", - "rf3.data.extra_xforms", "rf3.data.ground_truth_template", "rf3.data.paired_msa", "rf3.data.pipeline_utils", @@ -278,7 +276,6 @@ module = [ "rf3.utils.inference", "rf3.utils.io", "rf3.utils.loss", - "rf3.utils.predict_and_score", "rf3.utils.predicted_error", ] ignore_errors = true diff --git a/tests/rf3/test_frames.py b/tests/rf3/test_frames.py new file mode 100644 index 00000000..b3a9503c --- /dev/null +++ b/tests/rf3/test_frames.py @@ -0,0 +1,81 @@ +"""Unit tests for rf3.utils.frames. + +These RF2AA-derived helpers build the rigid frames that anchor RF3's structural +losses. `rigid_from_3_points` constructs a per-residue orientation from the +backbone N/Ca/C atoms via Gram-Schmidt, then applies an idealization rotation +that nudges the frame toward canonical backbone geometry; the contract that is +not obvious from the body is that the returned `R` is still a proper rotation +(orthonormal, det +1) and the returned origin is exactly Ca. The module ships +with a `# TODO: ... HOPEFULLY TESTS` note and no tests, so the contracts are +pinned here on small CPU inputs. `is_atom` splits the sequence alphabet at +NNAPROTAAS (atom tokens are strictly above it). +""" + +import torch +from rf3.chemical import NNAPROTAAS +from rf3.utils.frames import is_atom, rigid_from_3_points + + +def _is_proper_rotation(R: torch.Tensor, atol: float = 1e-3) -> bool: + """True iff every R[..., 3, 3] is orthonormal and has det +1 (no reflection). + + The idealization step leaves det within ~2e-4 of 1, so det is checked with a + looser tolerance than orthonormality. + """ + eye = torch.eye(3).expand_as(R) + orthonormal = torch.allclose(R @ R.transpose(-1, -2), eye, atol=atol) + det = torch.linalg.det(R) + return ( + orthonormal + and bool((det > 0).all()) + and torch.allclose(det, torch.ones_like(det), atol=1e-2) + ) + + +# --- rigid_from_3_points ------------------------------------------------------ + + +def test_rigid_from_3_points_returns_proper_rotation(): + N = torch.tensor([[0.0, 1.0, 0.0]]) + Ca = torch.tensor([[0.0, 0.0, 0.0]]) + C = torch.tensor([[1.0, 0.0, 0.0]]) + R, _ = rigid_from_3_points(N, Ca, C) + assert R.shape == (1, 3, 3) + assert _is_proper_rotation(R) + + +def test_rigid_from_3_points_origin_is_ca(): + N = torch.tensor([[0.0, 1.0, 0.0]]) + Ca = torch.tensor([[2.0, -1.0, 3.0]]) + C = torch.tensor([[1.0, 0.0, 0.0]]) + _, t = rigid_from_3_points(N, Ca, C) + assert torch.equal(t, Ca) + + +def test_rigid_from_3_points_preserves_batch_dims(): + torch.manual_seed(0) + N, Ca, C = (torch.randn(2, 4, 3) for _ in range(3)) + R, t = rigid_from_3_points(N, Ca, C) + assert R.shape == (2, 4, 3, 3) + assert t.shape == (2, 4, 3) + assert _is_proper_rotation(R) + + +def test_rigid_from_3_points_na_path_is_proper_and_differs(): + # The is_na flag swaps the idealization target angle (costgt -> costgtNA), so + # the nucleic-acid frame is a different proper rotation than the protein one. + N = torch.tensor([[0.0, 1.0, 0.0]]) + Ca = torch.tensor([[0.0, 0.0, 0.0]]) + C = torch.tensor([[1.0, 0.0, 0.0]]) + R_protein, _ = rigid_from_3_points(N, Ca, C) + R_na, _ = rigid_from_3_points(N, Ca, C, is_na=torch.tensor([True])) + assert _is_proper_rotation(R_na) + assert not torch.allclose(R_protein, R_na, atol=1e-3) + + +# --- is_atom ------------------------------------------------------------------ + + +def test_is_atom_splits_strictly_above_nnaprotaas(): + seq = torch.tensor([0, NNAPROTAAS - 1, NNAPROTAAS, NNAPROTAAS + 1]) + assert is_atom(seq).tolist() == [False, False, False, True] diff --git a/tests/rf3/test_metric_utils.py b/tests/rf3/test_metric_utils.py new file mode 100644 index 00000000..f6c78412 --- /dev/null +++ b/tests/rf3/test_metric_utils.py @@ -0,0 +1,104 @@ +"""Unit tests for rf3.metrics.metric_utils. + +Pure helpers behind RF3's confidence metrics (pLDDT / PAE / PDE). The non-obvious +contracts pinned here: `find_bin_midpoints` returns the `num_bins` centres of +equal bins spanning `[0, max_distance]`; `unbin_logits` takes the softmax +expectation over those midpoints, so a one-hot logit recovers its bin's +midpoint; the chainwise / interface mask builders turn a per-residue chain-label +array into intra-chain and cross-chain boolean masks; and the subsampled +mean / min reduce a batched matrix over a boolean pair mask, with the min +explicitly excluding unscored entries. +""" + +import numpy as np +import torch +from rf3.metrics.metric_utils import ( + compute_mean_over_subsampled_pairs, + compute_min_over_subsampled_pairs, + create_chainwise_masks_1d, + create_chainwise_masks_2d, + create_interface_masks_2d, + find_bin_midpoints, + spread_batch_into_dictionary, + unbin_logits, +) + +# --- find_bin_midpoints ------------------------------------------------------- + + +def test_find_bin_midpoints_are_equal_bin_centres(): + # 5 bins over [0, 10] -> centres at 1, 3, 5, 7, 9. + mp = find_bin_midpoints(10.0, 5) + assert torch.allclose(mp, torch.tensor([1.0, 3.0, 5.0, 7.0, 9.0])) + + +def test_find_bin_midpoints_count_matches_num_bins(): + assert find_bin_midpoints(32.0, 64).shape == (64,) + + +# --- unbin_logits ------------------------------------------------------------- + + +def test_unbin_logits_recovers_bin_midpoint(): + # A near-one-hot distribution on bin index 2 unbins to that bin's midpoint (5.0). + num_bins = 5 + logits = torch.full((1, num_bins, 2, 2), -50.0) + logits[:, 2] = 50.0 + out = unbin_logits(logits, 10.0, num_bins) + assert out.shape == (1, 2, 2) + assert torch.allclose(out, torch.full((1, 2, 2), 5.0), atol=1e-3) + + +# --- chainwise / interface masks ---------------------------------------------- + + +def test_create_chainwise_masks_1d(): + masks = create_chainwise_masks_1d(np.array(["A", "A", "B"])) + assert masks["A"].tolist() == [True, True, False] + assert masks["B"].tolist() == [False, False, True] + + +def test_create_chainwise_masks_2d_is_intra_chain_outer_product(): + masks = create_chainwise_masks_2d(np.array(["A", "A", "B"])) + assert masks["A"].int().tolist() == [[1, 1, 0], [1, 1, 0], [0, 0, 0]] + assert masks["B"].int().tolist() == [[0, 0, 0], [0, 0, 0], [0, 0, 1]] + + +def test_create_interface_masks_2d_is_symmetric_cross_chain(): + masks = create_interface_masks_2d(np.array(["A", "A", "B"])) + assert list(masks.keys()) == [("A", "B")] + assert masks[("A", "B")].int().tolist() == [[0, 0, 1], [0, 0, 1], [1, 1, 0]] + + +# --- subsampled reductions ---------------------------------------------------- + + +def test_compute_mean_over_subsampled_pairs(): + mat = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]) + pairs = torch.tensor([[True, False], [False, True]]) + # mean over the two scored (diagonal) entries: (1 + 4) / 2 + assert torch.allclose( + compute_mean_over_subsampled_pairs(mat, pairs), torch.tensor([2.5]), atol=1e-4 + ) + + +def test_compute_min_over_subsampled_pairs(): + mat = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]]) + pairs = torch.tensor([[True, False], [False, True]]) + # min over the two scored (diagonal) entries: min(1, 4) + assert compute_min_over_subsampled_pairs(mat, pairs).tolist() == [1.0] + + +def test_compute_min_excludes_unscored_entries(): + # The unscored off-diagonal holds the global min (0.0); masking must exclude it + # so the result is the smallest *scored* entry (5.0), not 0.0. + mat = torch.tensor([[[5.0, 0.0], [3.0, 9.0]]]) + pairs = torch.tensor([[True, False], [False, True]]) + assert compute_min_over_subsampled_pairs(mat, pairs).tolist() == [5.0] + + +# --- spread_batch_into_dictionary --------------------------------------------- + + +def test_spread_batch_into_dictionary(): + assert spread_batch_into_dictionary(torch.tensor([1.0, 2.0])) == {0: 1.0, 1: 2.0}