Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
660 changes: 660 additions & 0 deletions nemo_rl/algorithms/single_controller.py

Large diffs are not rendered by default.

201 changes: 201 additions & 0 deletions nemo_rl/algorithms/staleness_sampler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Prompt-group batch selection strategies for SingleController metadata."""

from __future__ import annotations

from dataclasses import dataclass
from typing import Optional

from nemo_rl.data_plane import KVBatchMeta


@dataclass(frozen=True)
class PromptGroup:
    """Sample positions and scheduling metadata belonging to one prompt group."""

    # Identifier shared by every sample of the group.
    group_id: str
    # Positions of the group's samples within the batch metadata.
    indices: list[int]
    # Weight version recorded for the group, or None when no version is known.
    weight_version: int | None
    # Whether the group has been marked as committed.
    committed: bool
    # Number of samples the group must hold to count as complete.
    expected_num_samples: int

    @property
    def is_complete(self) -> bool:
        """Return True when exactly ``expected_num_samples`` indices are present."""
        collected = len(self.indices)
        return collected == self.expected_num_samples


class StalenessSampler:
    """Select complete prompt groups inside a version staleness window.

    A group is *eligible* for training when it is committed, complete, has a
    known weight version that is not newer than the trainer's version, and
    lags the trainer by at most ``max_staleness_versions`` versions.
    """

    def __init__(self, max_staleness_versions: int):
        # Largest allowed ``trainer_version - weight_version`` gap.
        self.max_staleness_versions = max_staleness_versions

    def select_indices(
        self,
        meta: KVBatchMeta,
        *,
        trainer_version: int,
        min_prompt_groups: int,
        generations_per_prompt: int,
    ) -> Optional[list[int]]:
        """Return the sample indices of the ``min_prompt_groups`` freshest groups.

        Args:
            meta: Batch metadata whose samples are partitioned into groups.
            trainer_version: Current trainer weight version.
            min_prompt_groups: Number of eligible groups required (and taken).
            generations_per_prompt: Fallback expected group size when a tag
                does not carry one.

        Returns:
            Flattened indices of the selected groups, ordered by ascending lag
            and then by first sample index, or ``None`` when fewer than
            ``min_prompt_groups`` groups are eligible.
        """
        eligible = self._eligible_groups(
            meta,
            trainer_version=trainer_version,
            generations_per_prompt=generations_per_prompt,
        )
        if len(eligible) < min_prompt_groups:
            return None
        return _flatten_group_indices(eligible[:min_prompt_groups])

    def select_one_group(
        self,
        meta: KVBatchMeta,
        *,
        trainer_version: int,
        generations_per_prompt: int,
    ) -> Optional[list[int]]:
        """Return the indices of the single freshest eligible group.

        Returns:
            The indices of the eligible group with the smallest lag (ties
            broken by first sample index), or ``None`` when no group is
            eligible.
        """
        eligible = self._eligible_groups(
            meta,
            trainer_version=trainer_version,
            generations_per_prompt=generations_per_prompt,
        )
        if not eligible:
            return None
        return _flatten_group_indices(eligible[:1])

    def evictable_indices(
        self,
        meta: KVBatchMeta,
        *,
        trainer_version: int,
        generations_per_prompt: int,
    ) -> list[int]:
        """Return the indices of complete groups that fell out of the window.

        Only complete groups with a known weight version are considered; the
        ``committed`` flag is not consulted here. A group whose version is
        newer than ``trainer_version`` has a negative lag and is never
        returned.
        """
        stale = []
        for group in _prompt_groups(meta, generations_per_prompt):
            if group.weight_version is None or not group.is_complete:
                continue
            lag = trainer_version - group.weight_version
            if lag > self.max_staleness_versions:
                stale.append(group)
        return _flatten_group_indices(stale)

    def _eligible_groups(
        self,
        meta: KVBatchMeta,
        *,
        trainer_version: int,
        generations_per_prompt: int,
    ) -> list[PromptGroup]:
        """Collect trainable groups, ordered by (lag, first sample index)."""
        ranked: list[tuple[int, int, PromptGroup]] = []
        for group in _prompt_groups(meta, generations_per_prompt):
            if not group.committed or not group.is_complete:
                continue
            version = group.weight_version
            # Unknown versions and versions ahead of the trainer are skipped.
            if version is None or version > trainer_version:
                continue
            lag = trainer_version - version
            if lag > self.max_staleness_versions:
                continue
            ranked.append((lag, group.indices[0], group))
        # Freshest first; the first sample index makes the order deterministic.
        ranked.sort(key=lambda item: (item[0], item[1]))
        return [item[2] for item in ranked]


def count_prompt_groups(
    meta: KVBatchMeta,
    *,
    generations_per_prompt: int,
) -> int:
    """Return how many prompt groups found in ``meta`` are complete.

    ``generations_per_prompt`` is the expected group size used when a group's
    tag does not specify one.
    """
    complete = [
        group
        for group in _prompt_groups(meta, generations_per_prompt)
        if group.is_complete
    ]
    return len(complete)


def min_weight_version(meta: KVBatchMeta) -> int | None:
    """Return the lowest weight version found in the per-sample tags.

    Tags without a usable version are ignored; ``None`` is returned when
    ``meta`` has no tags or none of them carries a version.
    """
    lowest = None
    for tag in meta.tags or []:
        version = _weight_version(tag)
        if version is None:
            continue
        if lowest is None or version < lowest:
            lowest = version
    return lowest


def _prompt_groups(
    meta: KVBatchMeta,
    generations_per_prompt: int,
) -> list[PromptGroup]:
    """Partition the samples of ``meta`` into prompt groups.

    A sample's group id comes from its tag's ``group_id`` when that is truthy,
    otherwise it is derived from the sample id. Each group's metadata (weight
    version, committed flag, expected size) is read from the tag of the first
    sample seen for that group. Groups are returned ordered by their first
    sample index.
    """
    sample_ids = meta.sample_ids
    per_sample_tags = meta.tags or [{} for _ in sample_ids]
    members: dict[str, list[int]] = {}
    representative: dict[str, dict] = {}

    for position, sample_id in enumerate(sample_ids):
        # Samples beyond the end of a short tag list are treated as untagged.
        if position < len(per_sample_tags):
            tag = per_sample_tags[position]
        else:
            tag = {}
        key = str(tag.get("group_id") or _group_id_from_sample_id(sample_id))
        if key not in members:
            members[key] = []
            representative[key] = tag
        members[key].append(position)

    result: list[PromptGroup] = []
    for key, positions in members.items():
        tag = representative[key]
        # First tag key present wins; the argument is the final fallback.
        raw_expected = generations_per_prompt
        for field in (
            "expected_num_samples",
            "expected_num_keys",
            "generations_per_prompt",
        ):
            if field in tag:
                raw_expected = tag[field]
                break
        expected = _as_int(raw_expected)
        result.append(
            PromptGroup(
                group_id=key,
                indices=positions,
                weight_version=_weight_version(tag),
                committed=_as_bool(tag.get("committed", True)),
                # A missing, unparsable or zero size falls back to the default.
                expected_num_samples=expected or generations_per_prompt,
            )
        )
    return sorted(result, key=lambda group: group.indices[0])


def _flatten_group_indices(groups: list[PromptGroup]) -> list[int]:
return [idx for group in groups for idx in group.indices]


def _group_id_from_sample_id(sample_id: str) -> str:
prefix, sep, suffix = sample_id.rpartition("_g")
if sep and suffix.isdigit():
return prefix
return sample_id


def _weight_version(tag: dict) -> int | None:
    """Read a tag's weight version as an int.

    The ``weight_version`` key is used whenever it is present (even if its
    value is ``None``); otherwise the ``version`` key is consulted. ``None``
    is returned when neither yields an integer.
    """
    if "weight_version" in tag:
        raw = tag["weight_version"]
    else:
        raw = tag.get("version")
    return _as_int(raw)


def _as_int(value) -> int | None:
if value is None:
return None
try:
return int(value)
except (TypeError, ValueError):
return None


def _as_bool(value) -> bool:
if isinstance(value, bool):
return value
if isinstance(value, str):
return value.lower() in {"1", "true", "yes"}
return bool(value)
25 changes: 25 additions & 0 deletions nemo_rl/data_plane/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,31 @@ def concat(self, *others: "KVBatchMeta") -> "KVBatchMeta":
sample_ids=sample_ids, sequence_lengths=seq_lens, tags=tags
)

def drop(self, indices: "Sequence[int]") -> "KVBatchMeta | None":
    """Return the rows NOT listed in ``indices`` (the inverse of :meth:`subset`).

    ``None`` is returned when no row survives.
    """
    excluded = set(indices)
    survivors = []
    for row in range(self.size):
        if row not in excluded:
            survivors.append(row)
    if survivors:
        return self.subset(survivors)
    return None

def with_fields(self, field_names: "Sequence[str]") -> "KVBatchMeta":
    """Build a copy whose ``fields`` also contain ``field_names``.

    Existing fields keep their positions, new names are appended in the order
    given, and duplicates are dropped. Sample ids, sequence lengths, extra
    info and per-sample tags are copied rather than shared.
    """
    # A dict keeps first-seen insertion order, which gives ordered de-duplication.
    ordered = dict.fromkeys(self.fields or [])
    ordered.update(dict.fromkeys(field_names))

    if self.sequence_lengths is None:
        seq_lens = None
    else:
        seq_lens = list(self.sequence_lengths)

    if self.tags is None:
        tags = None
    else:
        tags = [dict(tag) for tag in self.tags]

    return KVBatchMeta(
        partition_id=self.partition_id,
        task_name=self.task_name,
        sample_ids=list(self.sample_ids),
        fields=list(ordered),
        sequence_lengths=seq_lens,
        extra_info=dict(self.extra_info or {}),
        tags=tags,
    )


class DataPlaneClient(ABC):
"""Stable, swappable data-plane boundary.
Expand Down
79 changes: 79 additions & 0 deletions nemo_rl/data_plane/worker_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -505,3 +505,82 @@ def get_reference_policy_logprobs_presharded(
tq_field="reference_policy_logprobs",
)
del result

# ── split-API entrypoints (SC async path) ──────────────────────────────
#
# The split path lets SingleController drive forward/backward per
# microbatch (or per pipeline-batch on Megatron) without stepping the
# optimizer until a full logical batch has accumulated. Backend
# methods (``begin_train_step``, ``train_microbatch``,
# ``finish_train_step``, ``abort_train_step``) own the train-step
# state machine; this mixin just gates them on TQ-presharded data.

@wrap_with_nvtx_name("policy_worker/begin_train_step_presharded")
def begin_train_step_presharded(
    self,
    step_id: str,
    loss_fn: Any,
    gbs: Optional[int] = None,
    mbs: Optional[int] = None,
) -> None:
    """Start a logical train step on the backend; nothing is fetched.

    All four arguments are forwarded unchanged, as keyword arguments, to the
    backend's ``begin_train_step``. The backend is expected to record
    ``step_id`` / ``loss_fn`` / ``gbs`` / ``mbs``, clear gradients and set up
    its accumulators (``local_valid_seqs`` / ``local_valid_toks`` and any
    per-microbatch metrics) while leaving optimizer state alone.
    """
    open_step = self.begin_train_step  # type: ignore[attr-defined]
    open_step(step_id=step_id, loss_fn=loss_fn, gbs=gbs, mbs=mbs)

@wrap_with_nvtx_name("policy_worker/train_microbatch_presharded")
def train_microbatch_presharded(
    self,
    step_id: str,
    meta: "KVBatchMeta",
) -> dict[str, Any]:
    """Run one microbatch on this rank: fetch, prepare packing, forward+backward.

    The rows described by ``meta`` are fetched, passed through
    ``_attach_or_repack_pack_metadata`` and handed to the backend's
    ``train_microbatch``, whose return value is returned as-is. Gradients are
    expected to accumulate across calls with no ``optimizer.step`` here; the
    returned per-microbatch metrics (loss, local_valid_*) are also folded
    into the step accumulator by the backend.
    """
    # One name is rebound so the raw fetched batch is not kept alive longer
    # than needed.
    batch = self._fetch(meta)
    batch = self._attach_or_repack_pack_metadata(batch, meta)
    run_microbatch = self.train_microbatch  # type: ignore[attr-defined]
    return run_microbatch(step_id=step_id, data=batch)

@wrap_with_nvtx_name("policy_worker/finish_train_step_presharded")
def finish_train_step_presharded(
    self,
    step_id: str,
) -> dict[str, Any]:
    """Finish a logical train step on the backend; nothing is fetched.

    Delegates to the backend's ``finish_train_step`` and returns its result.
    The backend is expected to all-reduce the accumulated
    ``local_valid_seqs/toks``, rescale gradients to the final global
    normalization, clip gradients, step the optimizer and scheduler, and
    zero gradients. The result is the aggregated step output (``loss``,
    ``grad_norm``, ``all_mb_metrics``, …).
    """
    close_step = self.finish_train_step  # type: ignore[attr-defined]
    step_result = close_step(step_id=step_id)
    return step_result

@wrap_with_nvtx_name("policy_worker/abort_train_step_presharded")
def abort_train_step_presharded(
    self,
    step_id: str,
) -> None:
    """Throw away a partially accumulated train step; the optimizer is not stepped.

    Meant for cases where SC decides the logical batch will never complete
    (for example a weight sync triggered mid-step). Delegates to the
    backend's ``abort_train_step``, which is expected to drop its
    accumulators and zero gradients.
    """
    cancel_step = self.abort_train_step  # type: ignore[attr-defined]
    cancel_step(step_id=step_id)
Loading
Loading