Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions .github/workflows/lint.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
name: Lint

on:
push:
branches: [main]
pull_request:

jobs:
lint:
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v5
- run: uv python install 3.10
- run: uv pip install ruff
- run: uv run ruff check .
- run: uv run ruff format --check .
7 changes: 7 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
hooks:
- id: ruff
args: [--fix]
- id: ruff-format
65 changes: 10 additions & 55 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,58 +1,13 @@
[build-system]
requires = ["setuptools", "wheel"]
build-backend = "setuptools.build_meta"
[tool.ruff]
line-length = 120
extend-exclude = ["*.ipynb"]

[project]
name = "simplefold"
version = "0.1.0"
description = "Folding proteins with SimpleFold."
readme = "README.md"
requires-python = ">=3.10"
license = { file = "LICENSE" }
authors = [
{ name = "Yuyang Wang", email = "yuyangw@apple.com" },
{ name = "Miguel Angel Bautista Martin", email = "mbautistamartin@apple.com" }
]
dependencies = [
"lightning==2.5.2",
"biopython==1.85",
"hydra-colorlog==1.2.0",
"mediapy==1.0.2",
"ipykernel==6.22.0",
"ipyvtklink==0.2.3",
"gemmi==0.7.3",
"biotite==1.2.0",
"rdkit==2025.3.5",
"gemmi==0.7.3",
"p-tqdm==1.4.2",
"einops==0.8.1",
"mashumaro==3.16",
"ihm==2.7",
"modelcif==1.4",
"dm-tree==0.1.9",
"click==8.2.1",
"timm==1.0.19",
"py3Dmol==2.5.2",
"scikit-learn==1.7.1",
"pandas==2.3.1",
"seaborn==0.13.2",
]
[tool.ruff.lint]
select = ["E", "F", "I"]
ignore = ["E501", "E402", "E731", "E722", "E741", "E721", "E701", "E702", "E711", "F841", "F403", "F405", "F821", "F601"]

[project.urls]
Homepage = "https://github.com/apple/ml-simplefold"
[tool.ruff.lint.per-file-ignores]
"__init__.py" = ["F401"]

# Optional: create a CLI command `simplefold` that calls simplefold/cli.py:main()
[project.scripts]
simplefold = "simplefold.cli:main"

# Tell hatchling where your packages live when using src layout:
[tool.hatch.build.targets.wheel]
packages = ["src/simplefold"]

# Tell setuptools to include YAML files from configs
[tool.setuptools.packages.find]
where = ["src"]
include = ["simplefold*"]

[tool.setuptools.package-data]
"simplefold.configs" = ["**/*.yaml"]
[tool.uv]
# Install with: uv pip install -r requirements.txt
14 changes: 8 additions & 6 deletions sample.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@
"outputs": [],
"source": [
"import sys\n",
"import numpy as np\n",
"from io import StringIO\n",
"from math import pow\n",
"import py3Dmol\n",
"from pathlib import Path\n",
"from io import StringIO\n",
"from Bio.PDB import PDBIO\n",
"from Bio.PDB import MMCIFParser, Superimposer\n",
"\n",
"import numpy as np\n",
"import py3Dmol\n",
"from Bio.PDB import PDBIO, MMCIFParser, Superimposer\n",
"\n",
"sys.path.append(str(Path(\"./src/simplefold\").resolve()))"
]
},
Expand Down Expand Up @@ -80,6 +81,7 @@
"source": [
"# set random seed for reproducibility\n",
"import lightning.pytorch as pl\n",
"\n",
"pl.seed_everything(42, workers=True)"
]
},
Expand All @@ -90,7 +92,7 @@
"metadata": {},
"outputs": [],
"source": [
"from src.simplefold.wrapper import ModelWrapper, InferenceWrapper\n",
"from src.simplefold.wrapper import InferenceWrapper, ModelWrapper\n",
"\n",
"# initialize the folding model and pLDDT model\n",
"model_wrapper = ModelWrapper(\n",
Expand Down
3 changes: 1 addition & 2 deletions src/simplefold/boltz_data_pipeline/crop/boltz.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,10 @@
from typing import Optional

import numpy as np
from scipy.spatial.distance import cdist

from boltz_data_pipeline import const
from boltz_data_pipeline.crop.cropper import Cropper
from boltz_data_pipeline.types import Tokenized
from scipy.spatial.distance import cdist


def pick_random_token(
Expand Down
1 change: 0 additions & 1 deletion src/simplefold/boltz_data_pipeline/crop/cropper.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
from typing import Optional

import numpy as np

from boltz_data_pipeline.types import Tokenized


Expand Down
2 changes: 0 additions & 2 deletions src/simplefold/boltz_data_pipeline/crop/slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@
from typing import Optional

import numpy as np
from scipy.spatial.distance import cdist

from boltz_data_pipeline import const
from boltz_data_pipeline.crop.cropper import Cropper
from boltz_data_pipeline.types import Tokenized
Expand Down
6 changes: 2 additions & 4 deletions src/simplefold/boltz_data_pipeline/feature/featurizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,10 @@
# Started from https://github.com/jwohlwend/boltz,
# licensed under MIT License, Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro.

import math
from typing import Optional

import numpy as np
import torch
from torch import Tensor, from_numpy
from torch.nn.functional import one_hot

from boltz_data_pipeline import const
from boltz_data_pipeline.feature.pad import pad_dim
from boltz_data_pipeline.feature.symmetry import (
Expand All @@ -22,6 +18,8 @@
get_ligand_symmetries,
)
from boltz_data_pipeline.types import Tokenized
from torch import Tensor, from_numpy
from torch.nn.functional import one_hot
from utils.boltz_utils import center_random_augmentation


Expand Down
6 changes: 2 additions & 4 deletions src/simplefold/boltz_data_pipeline/feature/symmetry.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,9 @@

import numpy as np
import torch

from boltz_data_pipeline import const
from boltz_data_pipeline.feature.pad import pad_dim
from utils.boltz_utils import lddt_dist
from utils.boltz_utils import weighted_minimum_rmsd_single
from utils.boltz_utils import lddt_dist, weighted_minimum_rmsd_single


def convert_atom_name(name: str) -> tuple[int, int, int, int]:
Expand Down Expand Up @@ -559,7 +557,7 @@ def get_ligand_symmetries(cropped, symmetries):
# for each molecule, get the symmetries
molecule_symmetries = []
for mol_name, start_mol, mol_id, mol_atom_names in index_mols:
if not mol_name in symmetries:
if mol_name not in symmetries:
continue
else:
swaps = []
Expand Down
2 changes: 1 addition & 1 deletion src/simplefold/boltz_data_pipeline/filter/dynamic/date.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,8 @@
from datetime import datetime
from typing import Literal

from boltz_data_pipeline.types import Record
from boltz_data_pipeline.filter.dynamic.filter import DynamicFilter
from boltz_data_pipeline.types import Record


class DateFilter(DynamicFilter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# Started from https://github.com/jwohlwend/boltz,
# licensed under MIT License, Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro.

from boltz_data_pipeline.types import Record
from boltz_data_pipeline.filter.dynamic.filter import DynamicFilter
from boltz_data_pipeline.types import Record


class MaxResiduesFilter(DynamicFilter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# Started from https://github.com/jwohlwend/boltz,
# licensed under MIT License, Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro.

from boltz_data_pipeline.types import Record
from boltz_data_pipeline.filter.dynamic.filter import DynamicFilter
from boltz_data_pipeline.types import Record


class ResolutionFilter(DynamicFilter):
Expand Down
2 changes: 1 addition & 1 deletion src/simplefold/boltz_data_pipeline/filter/dynamic/size.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
# Started from https://github.com/jwohlwend/boltz,
# licensed under MIT License, Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro.

from boltz_data_pipeline.types import Record
from boltz_data_pipeline.filter.dynamic.filter import DynamicFilter
from boltz_data_pipeline.types import Record


class SizeFilter(DynamicFilter):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@

from pathlib import Path

from boltz_data_pipeline.types import Record
from boltz_data_pipeline.filter.dynamic.filter import DynamicFilter
from boltz_data_pipeline.types import Record


class SubsetFilter(DynamicFilter):
Expand Down
1 change: 0 additions & 1 deletion src/simplefold/boltz_data_pipeline/filter/static/filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from abc import ABC, abstractmethod

import numpy as np

from boltz_data_pipeline.types import Structure


Expand Down
3 changes: 1 addition & 2 deletions src/simplefold/boltz_data_pipeline/filter/static/ligand.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,9 @@
# licensed under MIT License, Copyright (c) 2024 Jeremy Wohlwend, Gabriele Corso, Saro Passaro.

import numpy as np

from boltz_data_pipeline import const
from boltz_data_pipeline.types import Structure
from boltz_data_pipeline.filter.static.filter import StaticFilter
from boltz_data_pipeline.types import Structure

LIGAND_EXCLUSION = {
"144",
Expand Down
3 changes: 1 addition & 2 deletions src/simplefold/boltz_data_pipeline/filter/static/polymer.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,10 @@
from dataclasses import dataclass

import numpy as np
from scipy.spatial.distance import cdist

from boltz_data_pipeline import const
from boltz_data_pipeline.filter.static.filter import StaticFilter
from boltz_data_pipeline.types import Structure
from scipy.spatial.distance import cdist


class MinimumLengthFilter(StaticFilter):
Expand Down
1 change: 0 additions & 1 deletion src/simplefold/boltz_data_pipeline/parse/a3m.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from typing import Optional, TextIO

import numpy as np

from boltz_data_pipeline import const
from boltz_data_pipeline.types import MSA, MSADeletion, MSAResidue, MSASequence

Expand Down
1 change: 0 additions & 1 deletion src/simplefold/boltz_data_pipeline/parse/csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@

import numpy as np
import pandas as pd

from boltz_data_pipeline import const
from boltz_data_pipeline.types import MSA, MSADeletion, MSAResidue, MSASequence

Expand Down
3 changes: 1 addition & 2 deletions src/simplefold/boltz_data_pipeline/parse/fasta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,9 @@
from pathlib import Path

from Bio import SeqIO
from rdkit.Chem.rdchem import Mol

from boltz_data_pipeline.parse.yaml import parse_boltz_schema
from boltz_data_pipeline.types import Target
from rdkit.Chem.rdchem import Mol


def parse_fasta(path: Path, ccd: Mapping[str, Mol]) -> Target: # noqa: C901
Expand Down
15 changes: 7 additions & 8 deletions src/simplefold/boltz_data_pipeline/parse/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,25 +12,24 @@

import click
import numpy as np
from rdkit import rdBase, Chem
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import Conformer, Mol

from boltz_data_pipeline import const
from boltz_data_pipeline.types import (
Atom,
Bond,
Chain,
ChainInfo,
Connection,
Interface,
InferenceOptions,
Interface,
Record,
Residue,
Structure,
StructureInfo,
Target,
)
from rdkit import Chem, rdBase
from rdkit.Chem import AllChem
from rdkit.Chem.rdchem import Conformer, Mol

####################################################################################################
# DATACLASSES
Expand Down Expand Up @@ -795,7 +794,7 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912
for constraint in constraints:
if "bond" in constraint:
if "atom1" not in constraint["bond"] or "atom2" not in constraint["bond"]:
msg = f"Bond constraint was not properly specified"
msg = "Bond constraint was not properly specified"
raise ValueError(msg)

c1, r1, a1 = tuple(constraint["bond"]["atom1"])
Expand All @@ -805,15 +804,15 @@ def parse_boltz_schema( # noqa: C901, PLR0915, PLR0912
connections.append((c1, c2, r1, r2, a1, a2))
elif "pocket" in constraint:
if "binder" not in constraint["pocket"] or "contacts" not in constraint["pocket"]:
msg = f"Pocket constraint was not properly specified"
msg = "Pocket constraint was not properly specified"
raise ValueError(msg)

binder = constraint["pocket"]["binder"]
contacts = constraint["pocket"]["contacts"]

if len(pocket_binders) > 0:
if pocket_binders[-1] != chain_to_idx[binder]:
msg = f"Only one pocket binders is supported!"
msg = "Only one pocket binders is supported!"
raise ValueError(msg)
else:
pocket_residues[-1].extend([
Expand Down
3 changes: 1 addition & 2 deletions src/simplefold/boltz_data_pipeline/parse/yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
from pathlib import Path

import yaml
from rdkit.Chem.rdchem import Mol

from boltz_data_pipeline.parse.schema import parse_boltz_schema
from boltz_data_pipeline.types import Target
from rdkit.Chem.rdchem import Mol


def parse_yaml(path: Path, ccd: dict[str, Mol]) -> Target:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
from dataclasses import astuple, dataclass

import numpy as np

from boltz_data_pipeline import const
from boltz_data_pipeline.tokenize.tokenizer import Tokenizer
from boltz_data_pipeline.types import Input, Token, TokenBond, Tokenized
Expand Down
Loading