Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion models/rf3/src/rf3/callbacks/metrics_logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ def _load_and_concatenate_csvs(self, epoch: int) -> pd.DataFrame:
files = list(self.save_dir.glob(pattern))

# Track which example_id + dataset combinations we've already seen
seen_examples = set()
seen_examples: set[str] = set()
final_dataframes = []

for f in files:
Expand Down
2 changes: 1 addition & 1 deletion models/rf3/src/rf3/data/extra_xforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def _patched(atom_array, *args, **kwargs):
mol.AddConformer(conf, assignId=True)
return result

_patched._input_coord_fallback_patched = True
setattr(_patched, "_input_coord_fallback_patched", True)
_rdkit_utils.sample_rdkit_conformer_for_atom_array = _patched
# af3_reference_molecule imports the function directly, so patch that reference too
_af3_ref.sample_rdkit_conformer_for_atom_array = _patched
Expand Down
4 changes: 3 additions & 1 deletion models/rf3/src/rf3/utils/predict_and_score.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,9 @@ def predict_and_score_with_rf3(
annotate_b_factor_with_plddt=annotate_b_factor_with_plddt,
)

# Extract results for this example
# Extract results for this example. run() returns a dict (never None)
# when out_dir is None, per its documented contract.
assert inference_results is not None
result = inference_results[example_id]

# Check for early stopping
Expand Down
34 changes: 33 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ typeCheckingMode = "off"
# Per-module strictness is ratcheted up via [tool.mypy.overrides] as annotations land.
[tool.mypy]
python_version = "3.12"
files = ["src/foundry", "src/foundry_cli", "models/rfd3/src/rfd3"]
files = ["src/foundry", "src/foundry_cli", "models/rfd3/src/rfd3", "models/rf3/src/rf3"]
ignore_missing_imports = true
warn_unused_ignores = true
warn_redundant_casts = true
Expand Down Expand Up @@ -248,6 +248,38 @@ module = [
]
ignore_errors = true

# rf3 enablement ratchet. `models/rf3` was brought into mypy's scope in 0014;
# these modules had pre-existing type errors at that point and are exempted until
# annotated. Fix a module's errors and remove its entry to enable type-checking
# for that module (same playbook as the rfd3 ratchet above). Do NOT add new
# modules to this exemption list.
[[tool.mypy.overrides]]
module = [
"rf3.callbacks.dump_validation_structures",
"rf3.data.ground_truth_template",
"rf3.data.paired_msa",
"rf3.data.pipeline_utils",
"rf3.data.pipelines",
"rf3.inference",
"rf3.inference_engines.rf3",
"rf3.metrics.chiral",
"rf3.metrics.clashing_chains",
"rf3.metrics.distogram",
"rf3.metrics.lddt",
"rf3.metrics.metadata",
"rf3.metrics.predicted_error",
"rf3.metrics.rasa",
"rf3.metrics.selected_distances",
"rf3.model.RF3",
"rf3.model.layers.af3_diffusion_transformer",
"rf3.symmetry.resolve",
"rf3.trainers.rf3",
"rf3.utils.inference",
"rf3.utils.io",
"rf3.utils.loss",
"rf3.utils.predicted_error",
]
ignore_errors = true

# Testing ----------------------------------------------------------------------------
[tool.pytest.ini_options]
testpaths = ["tests"]
Expand Down
81 changes: 81 additions & 0 deletions tests/rf3/test_frames.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
"""Unit tests for rf3.utils.frames.

These RF2AA-derived helpers build the rigid frames that anchor RF3's structural
losses. `rigid_from_3_points` constructs a per-residue orientation from the
backbone N/Ca/C atoms via Gram-Schmidt, then applies an idealization rotation
that nudges the frame toward canonical backbone geometry; the contract that is
not obvious from the body is that the returned `R` is still a proper rotation
(orthonormal, det +1) and the returned origin is exactly Ca. The module ships
with a `# TODO: ... HOPEFULLY TESTS` note and no tests, so the contracts are
pinned here on small CPU inputs. `is_atom` splits the sequence alphabet at
NNAPROTAAS (atom tokens are strictly above it).
"""

import torch
from rf3.chemical import NNAPROTAAS
from rf3.utils.frames import is_atom, rigid_from_3_points


def _is_proper_rotation(R: torch.Tensor, atol: float = 1e-3) -> bool:
"""True iff every R[..., 3, 3] is orthonormal and has det +1 (no reflection).

The idealization step leaves det within ~2e-4 of 1, so det is checked with a
looser tolerance than orthonormality.
"""
eye = torch.eye(3).expand_as(R)
orthonormal = torch.allclose(R @ R.transpose(-1, -2), eye, atol=atol)
det = torch.linalg.det(R)
return (
orthonormal
and bool((det > 0).all())
and torch.allclose(det, torch.ones_like(det), atol=1e-2)
)


# --- rigid_from_3_points ------------------------------------------------------


def test_rigid_from_3_points_returns_proper_rotation():
    """One N/Ca/C triple (Ca at the origin) yields a single proper 3x3 rotation."""
    n_xyz = torch.tensor([[0.0, 1.0, 0.0]])
    ca_xyz = torch.tensor([[0.0, 0.0, 0.0]])
    c_xyz = torch.tensor([[1.0, 0.0, 0.0]])

    rotation, _origin = rigid_from_3_points(n_xyz, ca_xyz, c_xyz)

    assert rotation.shape == (1, 3, 3)
    assert _is_proper_rotation(rotation)


def test_rigid_from_3_points_origin_is_ca():
    """The second return value is exactly the Ca coordinates passed in."""
    ca_xyz = torch.tensor([[2.0, -1.0, 3.0]])
    n_xyz = torch.tensor([[0.0, 1.0, 0.0]])
    c_xyz = torch.tensor([[1.0, 0.0, 0.0]])

    _rotation, origin = rigid_from_3_points(n_xyz, ca_xyz, c_xyz)

    assert torch.equal(origin, ca_xyz)


def test_rigid_from_3_points_preserves_batch_dims():
    """Leading (2, 4) batch dims carry through to both the rotation and origin."""
    torch.manual_seed(0)
    batch_shape = (2, 4)
    # Drawn in N, Ca, C order so the seeded values match across runs.
    n_xyz = torch.randn(*batch_shape, 3)
    ca_xyz = torch.randn(*batch_shape, 3)
    c_xyz = torch.randn(*batch_shape, 3)

    rotation, origin = rigid_from_3_points(n_xyz, ca_xyz, c_xyz)

    assert rotation.shape == (*batch_shape, 3, 3)
    assert origin.shape == (*batch_shape, 3)
    assert _is_proper_rotation(rotation)


def test_rigid_from_3_points_na_path_is_proper_and_differs():
    """With is_na set, the frame is still a proper rotation but not the protein one."""
    # The is_na flag swaps the idealization target angle (costgt -> costgtNA), so
    # the nucleic-acid frame is a different proper rotation than the protein one.
    backbone = (
        torch.tensor([[0.0, 1.0, 0.0]]),  # N
        torch.tensor([[0.0, 0.0, 0.0]]),  # Ca
        torch.tensor([[1.0, 0.0, 0.0]]),  # C
    )

    rot_protein, _ = rigid_from_3_points(*backbone)
    rot_na, _ = rigid_from_3_points(*backbone, is_na=torch.tensor([True]))

    assert _is_proper_rotation(rot_na)
    assert not torch.allclose(rot_protein, rot_na, atol=1e-3)


# --- is_atom ------------------------------------------------------------------


def test_is_atom_splits_strictly_above_nnaprotaas():
    """Only tokens strictly greater than NNAPROTAAS are flagged as atoms."""
    tokens = [0, NNAPROTAAS - 1, NNAPROTAAS, NNAPROTAAS + 1]
    flags = is_atom(torch.tensor(tokens)).tolist()
    assert flags == [False, False, False, True]
104 changes: 104 additions & 0 deletions tests/rf3/test_metric_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
"""Unit tests for rf3.metrics.metric_utils.

Pure helpers behind RF3's confidence metrics (pLDDT / PAE / PDE). The non-obvious
contracts pinned here: `find_bin_midpoints` returns the `num_bins` centres of
equal bins spanning `[0, max_distance]`; `unbin_logits` takes the softmax
expectation over those midpoints, so a one-hot logit recovers its bin's
midpoint; the chainwise / interface mask builders turn a per-residue chain-label
array into intra-chain and cross-chain boolean masks; and the subsampled
mean / min reduce a batched matrix over a boolean pair mask, with the min
explicitly excluding unscored entries.
"""

import numpy as np
import torch
from rf3.metrics.metric_utils import (
compute_mean_over_subsampled_pairs,
compute_min_over_subsampled_pairs,
create_chainwise_masks_1d,
create_chainwise_masks_2d,
create_interface_masks_2d,
find_bin_midpoints,
spread_batch_into_dictionary,
unbin_logits,
)

# --- find_bin_midpoints -------------------------------------------------------


def test_find_bin_midpoints_are_equal_bin_centres():
    """Midpoints are the centres of equal-width bins spanning [0, max_distance]."""
    # 5 bins over [0, 10] -> centres at 1, 3, 5, 7, 9.
    expected = torch.tensor([1.0, 3.0, 5.0, 7.0, 9.0])
    midpoints = find_bin_midpoints(10.0, 5)
    assert torch.allclose(midpoints, expected)


def test_find_bin_midpoints_count_matches_num_bins():
    """The result is a 1-D tensor with one midpoint per requested bin."""
    num_bins = 64
    midpoints = find_bin_midpoints(32.0, num_bins)
    assert midpoints.shape == (num_bins,)


# --- unbin_logits -------------------------------------------------------------


def test_unbin_logits_recovers_bin_midpoint():
    """A near-one-hot logit on one bin unbins to that bin's midpoint."""
    # A near-one-hot distribution on bin index 2 unbins to that bin's midpoint (5.0).
    num_bins = 5
    hot_bin = 2
    logits = torch.full((1, num_bins, 2, 2), -50.0)
    logits[:, hot_bin] = 50.0

    unbinned = unbin_logits(logits, 10.0, num_bins)

    expected = torch.full((1, 2, 2), 5.0)
    assert unbinned.shape == (1, 2, 2)
    assert torch.allclose(unbinned, expected, atol=1e-3)


# --- chainwise / interface masks ----------------------------------------------


def test_create_chainwise_masks_1d():
    """Each chain label maps to a boolean mask over the positions carrying it."""
    chain_labels = np.array(["A", "A", "B"])
    per_chain = create_chainwise_masks_1d(chain_labels)
    assert per_chain["A"].tolist() == [True, True, False]
    assert per_chain["B"].tolist() == [False, False, True]


def test_create_chainwise_masks_2d_is_intra_chain_outer_product():
    """The 2-D mask for a chain is set only where both positions are in that chain."""
    chain_labels = np.array(["A", "A", "B"])
    expected_a = [[1, 1, 0], [1, 1, 0], [0, 0, 0]]
    expected_b = [[0, 0, 0], [0, 0, 0], [0, 0, 1]]

    per_chain = create_chainwise_masks_2d(chain_labels)

    assert per_chain["A"].int().tolist() == expected_a
    assert per_chain["B"].int().tolist() == expected_b


def test_create_interface_masks_2d_is_symmetric_cross_chain():
    """Two chains give one ("A", "B") entry whose mask covers both cross-chain blocks."""
    chain_labels = np.array(["A", "A", "B"])
    pair_key = ("A", "B")
    expected = [[0, 0, 1], [0, 0, 1], [1, 1, 0]]

    per_interface = create_interface_masks_2d(chain_labels)

    assert list(per_interface.keys()) == [pair_key]
    assert per_interface[pair_key].int().tolist() == expected


# --- subsampled reductions ----------------------------------------------------


def test_compute_mean_over_subsampled_pairs():
    """The mean is taken only over entries selected by the boolean pair mask."""
    values = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])
    scored = torch.tensor([[True, False], [False, True]])

    # mean over the two scored (diagonal) entries: (1 + 4) / 2
    expected = torch.tensor([2.5])
    result = compute_mean_over_subsampled_pairs(values, scored)

    assert torch.allclose(result, expected, atol=1e-4)


def test_compute_min_over_subsampled_pairs():
    """The min is taken only over entries selected by the boolean pair mask."""
    values = torch.tensor([[[1.0, 2.0], [3.0, 4.0]]])
    scored = torch.tensor([[True, False], [False, True]])

    # min over the two scored (diagonal) entries: min(1, 4)
    result = compute_min_over_subsampled_pairs(values, scored)

    assert result.tolist() == [1.0]


def test_compute_min_excludes_unscored_entries():
    """An unscored entry holding the global minimum must not win the reduction."""
    # The unscored off-diagonal holds the global min (0.0); masking must exclude it
    # so the result is the smallest *scored* entry (5.0), not 0.0.
    values = torch.tensor([[[5.0, 0.0], [3.0, 9.0]]])
    scored = torch.tensor([[True, False], [False, True]])

    result = compute_min_over_subsampled_pairs(values, scored)

    assert result.tolist() == [5.0]


# --- spread_batch_into_dictionary ---------------------------------------------


def test_spread_batch_into_dictionary():
    """A 1-D batch tensor becomes a dict keyed by batch index."""
    batch = torch.tensor([1.0, 2.0])
    expected = {0: 1.0, 1: 2.0}
    assert spread_batch_into_dictionary(batch) == expected
44 changes: 44 additions & 0 deletions tests/rfd3/test_symmetry_contigs.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
"""Unit tests for rfd3.inference.symmetry.contigs.

The contig helpers expand compact motif specifications into explicit per-residue
labels. `expand_contig_to_resid_from_string` reads a single-character chain id
followed by an inclusive `start-end` residue range (e.g. "A1-5" -> A1..A5);
`expand_contig_unsym_motif` expands the range entries in a mixed list while
keeping the plain (dash-free) names. Both are pure string logic, pinned here.
"""

from rfd3.inference.symmetry.contigs import (
expand_contig_to_resid_from_string,
expand_contig_unsym_motif,
)

# --- expand_contig_to_resid_from_string ---------------------------------------


def test_expand_contig_basic_range():
    """"A1-5" expands to one chain-prefixed label per residue, A1 through A5."""
    expected = ["A1", "A2", "A3", "A4", "A5"]
    assert expand_contig_to_resid_from_string("A1-5") == expected


def test_expand_contig_is_inclusive_of_endpoints():
    """Both the start and the end residue of the range appear in the output."""
    expanded = expand_contig_to_resid_from_string("B10-12")
    assert expanded == ["B10", "B11", "B12"]


def test_expand_contig_single_residue_range():
    """A range whose start equals its end expands to exactly one label."""
    expanded = expand_contig_to_resid_from_string("C7-7")
    assert expanded == ["C7"]


# --- expand_contig_unsym_motif ------------------------------------------------


def test_expand_unsym_motif_expands_ranges_and_keeps_plain_names():
    """Mixed input: dash-free names come first, then the expanded range labels."""
    # plain (dash-free) names are kept first, expanded ranges appended after.
    motif = ["A1-3", "LIG"]
    expected = ["LIG", "A1", "A2", "A3"]
    assert expand_contig_unsym_motif(motif) == expected


def test_expand_unsym_motif_without_ranges_is_unchanged():
    """A list containing only dash-free names comes back with the same contents."""
    motif = ["LIG", "GLY"]
    assert expand_contig_unsym_motif(motif) == ["LIG", "GLY"]


def test_expand_unsym_motif_only_ranges():
    """A list holding a single range yields just that range's residue labels."""
    motif = ["A1-2"]
    expected = ["A1", "A2"]
    assert expand_contig_unsym_motif(motif) == expected
Loading
Loading