Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions bergson/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from .query.attributor import Attributor
from .query.faiss_index import FaissConfig
from .score.scorer import Scorer
from .sharding import ShardedMemmap, shard_status
from .utils.gradcheck import FiniteDiff
from .utils.load_from_optimizer import load_from_optimizer

Expand Down Expand Up @@ -54,5 +55,7 @@
"Scorer",
"ScoreConfig",
"QueryConfig",
"ShardedMemmap",
"shard_status",
"mix_autocorrelation_matrices",
]
2 changes: 2 additions & 0 deletions bergson/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
Query,
Reduce,
Score,
Status,
Test_Model_Configuration,
Trackstar,
Train,
Expand All @@ -38,6 +39,7 @@ class Main:
Query,
Reduce,
Score,
Status,
Trackstar,
Train,
Test_Model_Configuration,
Expand Down
12 changes: 11 additions & 1 deletion bergson/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
)
from bergson.utils.worker_utils import (
create_processor,
publish_shard,
setup_data_pipeline,
setup_model_and_peft,
)
Expand Down Expand Up @@ -156,6 +157,12 @@ def build(
if index_cfg.debug:
setup_reproducibility()

if index_cfg.sharded and preprocess_cfg.aggregation != "none":
raise ValueError(
"Sharded runs do not support gradient aggregation; per-shard "
"aggregates would be concatenated instead of summed."
)

index_cfg.partial_run_path.mkdir(parents=True, exist_ok=True)

ds, _ = setup_data_pipeline(index_cfg)
Expand All @@ -175,7 +182,10 @@ def build(
)

if dist_cfg.rank == 0:
shutil.move(index_cfg.partial_run_path, index_cfg.run_path)
if index_cfg.sharded:
publish_shard(index_cfg, num_items=len(ds))
else:
shutil.move(index_cfg.partial_run_path, index_cfg.run_path)

if dist_cfg.world_size < index_cfg.distributed.world_size:
parent_barrier(index_cfg.distributed)
78 changes: 74 additions & 4 deletions bergson/cli/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

from dataclasses import dataclass

from simple_parsing import Serializable
from simple_parsing import Serializable, field

from ..build import build
from ..config.config import (
Expand All @@ -31,7 +31,7 @@
from ..process_grads import mix_autocorrelation_matrices
from ..query.query_index import query
from ..score.score import score_dataset
from ..utils.worker_utils import validate_run_path
from ..utils.worker_utils import prepare_shard, validate_run_path


@dataclass
Expand All @@ -52,6 +52,12 @@ class ApproxUnrolling(Serializable):
def execute(self):
from ..approx_unrolling.pipeline import approx_unrolling_pipeline

if self.index_cfg.sharded:
raise ValueError(
"approx_unrolling does not support sharded runs yet; "
"shard the build and score steps individually."
)

save_run_config(self, self.index_cfg.run_path)
approx_unrolling_pipeline(
self.index_cfg,
Expand Down Expand Up @@ -92,7 +98,17 @@ def execute(self):
f"{self.hessian_cfg.method}."
)

validate_run_path(self.index_cfg)
if self.index_cfg.sharded:
if self.hessian_cfg is not None:
raise ValueError(
"Sharded builds do not support simultaneous Hessian "
"estimation; Hessian factors cannot be merged across "
"independent shards yet. Run `bergson hessian` separately."
)
if prepare_shard(self, self.index_cfg):
return
else:
validate_run_path(self.index_cfg)

save_run_config(self, self.index_cfg.partial_run_path)
build(self.index_cfg, self.preprocess_cfg, self.hessian_cfg)
Expand All @@ -115,6 +131,12 @@ class Ekfac(Serializable):
def execute(self):
from ..hessians.pipeline import hessian_pipeline

if self.index_cfg.sharded:
raise ValueError(
"ekfac does not support sharded runs yet; "
"shard the build and score steps individually."
)

save_run_config(self, self.index_cfg.run_path)
hessian_pipeline(
self.index_cfg,
Expand All @@ -134,6 +156,11 @@ class Hessian(Serializable):

def execute(self):
"""Compute Hessian approximation."""
if self.index_cfg.sharded:
raise ValueError(
"hessian does not support sharded runs; Hessian factors "
"cannot be merged across independent shards yet."
)

validate_run_path(self.index_cfg)

Expand Down Expand Up @@ -197,6 +224,12 @@ class Reduce(Serializable):

def execute(self):
"""Reduce a gradient index."""
if self.index_cfg.sharded:
raise ValueError(
"reduce does not support sharded runs; per-shard aggregates "
"would be concatenated instead of summed."
)

if self.index_cfg.projection_dim != 0:
print(f"Using a projection dimension of {self.index_cfg.projection_dim}. ")

Expand All @@ -222,7 +255,12 @@ def execute(self):
if self.index_cfg.projection_dim != 0:
print(f"Using a projection dimension of {self.index_cfg.projection_dim}. ")

validate_run_path(self.index_cfg)
if self.index_cfg.sharded:
if prepare_shard(self, self.index_cfg):
return
else:
validate_run_path(self.index_cfg)

save_run_config(self, self.index_cfg.partial_run_path)
score_dataset(self.index_cfg, self.score_cfg, self.preprocess_cfg)

Expand All @@ -238,6 +276,12 @@ class Trackstar(Serializable):
def execute(self):
from .trackstar import trackstar

if self.index_cfg.sharded:
raise ValueError(
"trackstar does not support sharded runs yet; "
"shard the build and score steps individually."
)

save_run_config(self, self.index_cfg.run_path)
trackstar(self.index_cfg, self.trackstar_cfg)

Expand All @@ -251,6 +295,32 @@ def execute(self):
run_magic(self)


@dataclass
class Status(Serializable):
    """Summarise a sharded run's shards as published, partial, or not started."""

    run_path: str = field(positional=True)

    def execute(self):
        """Look up the shard inventory for ``run_path`` and print a summary."""
        from ..sharding import shard_status

        done, in_flight, total = shard_status(self.run_path)

        # shard_status reports no shard count for a path that is not sharded.
        if num_shards_unknown := total is None:
            print(f"{self.run_path} is not a sharded run.")
            return

        # Shard ids that are neither published nor partially written.
        not_started = [
            shard
            for shard in range(total)
            if shard not in done and shard not in in_flight
        ]

        print(f"{self.run_path}: {len(done)}/{total} shards published")
        if in_flight:
            print(f" in progress or crashed: {sorted(in_flight)}")
        if not_started:
            print(f" not started: {not_started}")
        if not (in_flight or not_started):
            print(" run complete")


@dataclass
class Test_Model_Configuration:
"""Test gradient consistency across padding and batch composition.
Expand Down
70 changes: 70 additions & 0 deletions bergson/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@
import torch
from simple_parsing import Serializable, field

from bergson.sharding import SHARDS_DIRNAME, shard_dir_name


@dataclass
class DataConfig(Serializable):
Expand Down Expand Up @@ -451,9 +453,77 @@ class IndexConfig(AttributionConfig, Serializable):
modules: list[str] = field(default_factory=list)
"""Modules to use for the query. If empty, all modules will be used."""

num_shards: int = 1
"""Split the dataset into this many contiguous shards, processing only
`shard_id`'s slice. Each shard is an independent single-node run that
publishes into `run_path/shards/`; readers present the published shards
as one index. Incompatible with `nnode` > 1."""

shard_id: int | None = None
"""Which shard to process when `num_shards` > 1. If unset, inferred from
SLURM_ARRAY_TASK_ID or SLURM_PROCID."""

def __post_init__(self):
    """Run the parent's post-init, then validate the sharding options.

    Raises:
        ValueError: if ``num_shards`` is below 1, if sharding is combined
            with ``nnode`` > 1, or if ``shard_id`` is set without sharding
            or falls outside ``[0, num_shards)``.
    """
    super().__post_init__()

    shard_count = self.num_shards
    if shard_count < 1:
        raise ValueError(f"num_shards must be >= 1, got {self.num_shards}")

    if shard_count > 1 and self.distributed.nnode > 1:
        raise ValueError(
            "num_shards launches independent single-node runs and cannot "
            "be combined with nnode > 1. Use nnode for coordinated "
            "multi-node runs, or num_shards for embarrassingly parallel "
            "ones, not both."
        )

    # The remaining checks only apply when a shard id was given explicitly.
    if self.shard_id is None:
        return

    if not self.sharded:
        raise ValueError("shard_id requires num_shards > 1")

    if not 0 <= self.shard_id < self.num_shards:
        raise ValueError(
            f"shard_id must be in [0, {self.num_shards}), got {self.shard_id}"
        )

@property
def sharded(self) -> bool:
    """True when ``num_shards`` is greater than one."""
    return 1 < self.num_shards

@property
def resolved_shard_id(self) -> int:
    """Return the shard id to process.

    An explicit ``shard_id`` wins. Otherwise the first of
    ``SLURM_ARRAY_TASK_ID`` / ``SLURM_PROCID`` present in the environment
    is parsed as an integer.

    Raises:
        ValueError: if the environment value is outside
            ``[0, num_shards)``, or if no shard id can be found at all.
    """
    explicit = self.shard_id
    if explicit is not None:
        return explicit

    for env_name in ("SLURM_ARRAY_TASK_ID", "SLURM_PROCID"):
        raw_value = os.environ.get(env_name)
        if raw_value is None:
            continue

        candidate = int(raw_value)
        if 0 <= candidate < self.num_shards:
            return candidate

        # The first variable that is set decides; there is no fallback
        # to the next one when its value is out of range.
        raise ValueError(
            f"{env_name}={candidate} is out of range for "
            f"num_shards={self.num_shards}"
        )

    raise ValueError(
        "num_shards > 1 but no shard id found. Set it with --shard_id, "
        "or via SLURM_ARRAY_TASK_ID/SLURM_PROCID."
    )

@property
def final_run_path(self) -> Path:
    """Directory that receives this run's finished artifacts.

    Unsharded runs publish straight to ``run_path``; sharded runs publish
    to a per-shard directory under ``run_path / SHARDS_DIRNAME``.
    """
    if not self.sharded:
        return Path(self.run_path)

    shard_name = shard_dir_name(self.resolved_shard_id, self.num_shards)
    return Path(self.run_path).joinpath(SHARDS_DIRNAME, shard_name)

@property
def partial_run_path(self) -> Path:
    """Scratch location used while artifacts are still being written.

    This is the publish location with ``.part`` appended to its last
    component (for unsharded runs, appended to the ``run_path`` string).
    """
    if not self.sharded:
        return Path(self.run_path + ".part")

    published = self.final_run_path
    return published.with_name(f"{published.name}.part")


Expand Down
64 changes: 64 additions & 0 deletions bergson/config/config_io.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import subprocess
from datetime import datetime, timezone
from importlib.metadata import PackageNotFoundError
Expand All @@ -12,6 +13,11 @@

CONFIG_FILENAME = "config.yaml"

# Per-invocation identity, not run configuration: these index_cfg fields may
# differ between the shards of one sharded run and are stripped from the
# canonical config.yaml shared by all shards.
EPHEMERAL_INDEX_FIELDS = ("shard_id", "overwrite")


def _resolve(path: str | Path) -> Path:
"""Return the path to a ``config.yaml``, accepting either a dir or a file."""
Expand Down Expand Up @@ -78,6 +84,64 @@ def save_run_config(command: Any, run_path: str | Path):
_write([step], Path(run_path) / CONFIG_FILENAME)


def canonical_steps(command: Any) -> list[dict[str, Any]]:
    """Build the single-entry ``steps`` list for ``command``, minus shard identity.

    The step is keyed by the lower-cased class name of ``command`` and holds
    ``command.to_dict()``. From its ``index_cfg`` (when present and non-empty)
    every name in ``EPHEMERAL_INDEX_FIELDS`` is dropped, as is
    ``distributed.node_rank``, so that invocations differing only in those
    fields yield equal results.
    """
    step_key = type(command).__name__.lower()

    # Dump and reload so the values are the plain types a reread of the
    # YAML file would give, which keeps later equality checks meaningful.
    serialized = yaml.safe_dump([{step_key: command.to_dict()}])
    steps = yaml.safe_load(serialized)

    for step in steps:
        for payload in step.values():
            index_cfg = payload.get("index_cfg") if payload else None
            if not index_cfg:
                continue

            for name in EPHEMERAL_INDEX_FIELDS:
                index_cfg.pop(name, None)

            distributed = index_cfg.get("distributed")
            if distributed:
                distributed.pop("node_rank", None)

    return steps


def publish_canonical_config(command: Any, run_path: str | Path):
    """Write or verify the ``config.yaml`` shared by all shards of a run.

    If ``run_path/config.yaml`` already exists, its ``steps`` are compared
    with this command's canonical steps and a mismatch is rejected, so
    shards built from different configurations do not mix in one run_path.
    Otherwise the file is written to a uniquely named temp file and
    atomically moved into place.

    Args:
        command: The CLI command object; passed to ``canonical_steps``.
        run_path: Root directory of the sharded run.

    Raises:
        ValueError: if an existing config's steps differ from this
            command's canonical steps.
    """
    import uuid

    path = Path(run_path) / CONFIG_FILENAME
    steps = canonical_steps(command)

    if path.exists():
        existing = read_config(path)
        if existing["steps"] != steps:
            raise ValueError(
                f"{path} was written by a run with a different configuration. "
                f"Refusing to add shards to it; use a fresh run_path or "
                f"rerun with the original configuration."
            )
        return

    doc: dict[str, Any] = {"steps": steps, "metadata": make_metadata()}
    path.parent.mkdir(parents=True, exist_ok=True)

    # Concurrent shards may race to create the file. Each racer writes its
    # own temp file and moves it into place; racers with the same config
    # write identical steps, so last-writer-wins is safe.
    #
    # The temp name carries a random component in addition to the PID:
    # shards may run on different hosts sharing this directory, where PIDs
    # alone are not unique and two writers could clobber one temp file.
    tmp_path = path.with_name(
        f".{CONFIG_FILENAME}.{os.getpid()}.{uuid.uuid4().hex}.tmp"
    )
    try:
        with tmp_path.open("w") as f:
            yaml.safe_dump(doc, f, sort_keys=False)
        # os.replace overwrites an existing destination on every platform;
        # os.rename raises on Windows when another racer got there first.
        os.replace(tmp_path, path)
    finally:
        # No-op after a successful replace; removes the temp file when the
        # dump or the move failed.
        tmp_path.unlink(missing_ok=True)


def save_pipeline_config(steps: list[tuple[str, Any]], run_path: str | Path | None):
"""Write a multi-step ``config.yaml`` to ``run_path``.

Expand Down
Loading
Loading