diff --git a/docs/en/advanced/miles_server_args.md b/docs/en/advanced/miles_server_args.md index e8d70db59f..7190edcc48 100644 --- a/docs/en/advanced/miles_server_args.md +++ b/docs/en/advanced/miles_server_args.md @@ -136,7 +136,7 @@ Arguments for sampling strategies and data filtering during rollout and buffer m | `--partial-rollout` | Enable partial rollout for **dynamic sampling**: cache partially generated (aborted/unfinished) samples and resume generation in later rollout steps, reducing wasted compute for long responses. Cached samples are stored in the rollout buffer and can be prioritized/selected via `--buffer-filter-path` (default FIFO behavior). See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native | | `--mask-offpolicy-in-partial-rollout` | When using partial rollout, mask the previously generated (cached) response tokens so they do not contribute to the loss; only tokens generated after resuming are used for training. This helps avoid training on a cached prefix produced by an older policy version. See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native | | `--buffer-filter-path` | Path to the function to filter or sort samples in the rollout buffer before training. [Ref](../get_started/customization.md#5-buffer-filter---buffer-filter-path) | `None` | Type: str | Miles Native | -| `--rollout-sample-filter-path` | Path to the function that marks individual samples to be excluded from loss calculation. [Ref](../get_started/customization.md#6-rollout-sample-filter---rollout-sample-filter-path) | `None` | Type: str | Miles Native | +| `--rollout-sample-filter-path` | Path to the function that marks individual samples to be excluded from reward/advantage normalization and loss calculation. 
[Ref](../get_started/customization.md#6-rollout-sample-filter---rollout-sample-filter-path) | `None` | Type: str | Miles Native | | `--rollout-all-samples-process-path` | Path to the function to process all samples (including filtered ones) after rollout. [Ref](../get_started/customization.md#7-rollout-all-samples-process---rollout-all-samples-process-path) | `None` | Type: str | Miles Native | ## Data Arguments diff --git a/docs/en/get_started/customization.md b/docs/en/get_started/customization.md index 8aa63c23fb..fdc23352c8 100644 --- a/docs/en/get_started/customization.md +++ b/docs/en/get_started/customization.md @@ -13,7 +13,7 @@ Below is a summary of all available customization interfaces and their purposes. | [`--custom-rm-path`](#3-reward-model---custom-rm-path) | Implement custom reward computation logic. | | [`--dynamic-sampling-filter-path`](#4-dynamic-sampling-filter---dynamic-sampling-filter-path) | Filter samples during dynamic sampling (e.g., DAPO). | | [`--buffer-filter-path`](#5-buffer-filter---buffer-filter-path) | Filter samples in the rollout buffer before training. | -| [`--rollout-sample-filter-path`](#6-rollout-sample-filter---rollout-sample-filter-path) | Determine if individual samples participate in loss calculation. | +| [`--rollout-sample-filter-path`](#6-rollout-sample-filter---rollout-sample-filter-path) | Determine if individual samples participate in training. | | [`--rollout-all-samples-process-path`](#7-rollout-all-samples-process---rollout-all-samples-process-path) | Process all samples (including filtered ones) after rollout. | | [`--rollout-data-postprocess-path`](#8-rollout-data-postprocess---rollout-data-postprocess-path) | Post-process rollout data after log probs are computed. | | [`--custom-loss-function-path`](#9-custom-loss-function---custom-loss-function-path) | Implement custom training loss computation. 
| @@ -161,7 +161,7 @@ def buffer_filter(samples: list[list[Sample]]) -> list[list[Sample]] **Default**: `None` -**Purpose**: Determine whether individual samples participate in loss calculation. +**Purpose**: Determine whether individual samples participate in training. **Signature**: ```python @@ -169,6 +169,7 @@ def filter_function(args, samples: list[Sample]) -> None ``` **Note**: This function should directly modify the `remove_sample` attribute of each `Sample` object. +Removed samples stay available in rollout artifacts, but are excluded from reward/advantage normalization and loss calculation. **Use Cases**: - Filtering samples based on response quality @@ -438,5 +439,3 @@ For detailed explanation of R3 and MilesRouter, see [Miles Router](../advanced/m ```python def custom_model_provider(pre_process: bool, post_process: bool, vp_stage: int | None = None) -> GPTModel ``` - - diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index d0a0daec75..97c105cc47 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -24,6 +24,7 @@ from miles.utils.replay_base import all_replay_managers from miles.utils.timer import Timer, inverse_timer, timer from miles.utils.tracking_utils import init_tracking +from miles.utils.training_semantics import validate_loss_masks_for_removed_samples from miles.utils.types import RolloutBatch from ...utils.profile_utils import TrainProfiler @@ -334,6 +335,13 @@ def _basic_batch_summary(): if got != n: _add_error(f"{key!r} length mismatch: got {got}, expected {n}") + remove_samples = rollout_data.get("remove_samples") + if remove_samples is not None: + if not _is_seq(remove_samples): + _add_error(f"'remove_samples' must be list/tuple, got {type(remove_samples).__name__}") + elif len(remove_samples) != n: + _add_error(f"'remove_samples' length mismatch: got {len(remove_samples)}, expected {n}") + token_key = None if _present("tokens"): token_key = "tokens" 
@@ -416,8 +424,6 @@ def _basic_batch_summary(): continue mask_sum = _sum_float(mask) - if mask_sum <= 0: - _add_error(f"loss_masks[{i}] has no active tokens, sum={mask_sum}, response_len={resp}") if mask_sum > resp: # Warning-only: float/weighted masks can legitimately have sum > resp. _add_warning(f"loss_masks[{i}] sum={mask_sum} exceeds response_len={resp} (expected for float/weighted masks)") @@ -431,6 +437,11 @@ def _basic_batch_summary(): except Exception as e: _add_warning(f"binary check failed for loss_masks[{i}]: {type(e).__name__}: {e}") + try: + validate_loss_masks_for_removed_samples(rollout_data["loss_masks"], response_lengths, remove_samples) + except ValueError as e: + _add_error(str(e)) + # max_seq_lens if present. if _present("max_seq_lens"): xs = rollout_data["max_seq_lens"] diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index 317286f6e0..9ee860c6ce 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -18,6 +18,7 @@ get_reinforce_plus_plus_baseline_advantages, get_reinforce_plus_plus_returns, ) +from miles.utils.training_semantics import validate_loss_masks_for_removed_samples from miles.utils.types import RolloutBatch from .cp_utils import ( @@ -305,6 +306,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) values: None | list[torch.Tensor] = rollout_data.get("values") response_lengths: list[int] = rollout_data.get("response_lengths") loss_masks: list[torch.Tensor] = rollout_data.get("loss_masks") + remove_samples: list[bool] | None = rollout_data.get("remove_samples") total_lengths: list[int] = rollout_data.get("total_lengths") max_seq_lens: list[int] | None = rollout_data.get("max_seq_lens", None) @@ -312,6 +314,8 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) if log_probs is None and values is None: return + validate_loss_masks_for_removed_samples(loss_masks, response_lengths, 
remove_samples) + if args.kl_coef == 0 or not log_probs: # when kl_coef is 0, we won't compute ref_log_prob xs = log_probs if log_probs is not None else values @@ -352,6 +356,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) rewards=rewards, kl=kl, loss_masks=loss_masks, + remove_samples=remove_samples, response_lengths=response_lengths, total_lengths=total_lengths, kl_coef=args.kl_coef, @@ -429,7 +434,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) all_masks = torch.cat(mask_chunks) - if all_masks.numel() > 0: + if all_masks.numel() > 0 and all_masks.sum().item() > 0: assert ( all_advs.size() == all_masks.size() ), f"Shape mismatch before whitening: advantages {all_advs.size()}, masks {all_masks.size()}" @@ -444,6 +449,14 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch) chunk_lengths = [chunk.size(0) for chunk in advantages] advantages = list(torch.split(whitened_advs_flat, chunk_lengths)) + if remove_samples is not None: + if len(remove_samples) != len(advantages): + raise ValueError(f"remove_samples length {len(remove_samples)} != advantages length {len(advantages)}") + for i, remove_sample in enumerate(remove_samples): + if remove_sample: + advantages[i] = torch.zeros_like(advantages[i]) + returns[i] = torch.zeros_like(returns[i]) + rollout_data["advantages"] = advantages rollout_data["returns"] = returns diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 7d968f0ed8..09ab9cb4c5 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -48,6 +48,7 @@ from miles.utils.misc import load_function from miles.utils.ray_utils import Box from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions +from miles.utils.training_semantics import post_process_rewards_excluding_removed from miles.utils.tracking_utils import init_tracking from miles.utils.types import Sample @@ -678,28 +679,7 @@ def _post_process_rewards(self, samples: 
list[Sample] | list[list[Sample]]): if self.custom_reward_post_process_func is not None: return self.custom_reward_post_process_func(self.args, samples) - raw_rewards = [sample.get_reward_value(self.args) for sample in samples] - if ( - self.args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] - and self.args.rewards_normalization - ): - # group norm - rewards = torch.tensor(raw_rewards, dtype=torch.float) - if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size: - rewards = rewards.reshape(-1, self.args.n_samples_per_prompt) - else: - # when samples count are not equal in each group - rewards = rewards.view(-1, rewards.shape[-1]) - mean = rewards.mean(dim=-1, keepdim=True) - rewards = rewards - mean - - if self.args.advantage_estimator in ["grpo", "gspo"] and self.args.grpo_std_normalization: - std = rewards.std(dim=-1, keepdim=True) - rewards = rewards / (std + 1e-6) - - return raw_rewards, rewards.flatten().tolist() - - return raw_rewards, raw_rewards + return post_process_rewards_excluding_removed(self.args, samples) def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]): """ @@ -722,6 +702,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl "raw_reward": raw_rewards, "truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples], "sample_indices": [sample.index for sample in samples], + "remove_samples": [sample.remove_sample for sample in samples], } # loss mask @@ -873,6 +854,7 @@ def _stat(xs): "rewards", "truncated", "loss_masks", + "remove_samples", "round_number", "sample_indices", "rollout_log_probs", diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 1d1ada930a..b52e8e35b8 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1495,10 +1495,10 @@ def add_rollout_buffer_arguments(parser): default=None, help=( "Path to the rollout sample filter function. 
" - "This function determines whether a sample will participate in loss calculation. " + "This function determines whether a sample will participate in training. " "The function should take args and samples (list[Sample]) as input, and return None. " "Please directly modify the remove_sample attribute of Sample. " - "Note: This attribute does not determine whether the sample participates in advantage normalization." + "Removed samples are excluded from reward/advantage normalization and loss calculation." ), ) parser.add_argument( diff --git a/miles/utils/ppo_utils.py b/miles/utils/ppo_utils.py index 634c9d430d..38e3d53505 100644 --- a/miles/utils/ppo_utils.py +++ b/miles/utils/ppo_utils.py @@ -212,6 +212,7 @@ def get_reinforce_plus_plus_returns( rewards: torch.Tensor, kl: list[torch.Tensor], loss_masks: list[torch.Tensor], + remove_samples: list[bool] | None, response_lengths: list[int], total_lengths: list[int], kl_coef: float, @@ -224,6 +225,7 @@ def get_reinforce_plus_plus_returns( rewards (Tensor): A tensor of scalar rewards for each sequence. kl (List[Tensor]): List of per-token KL divergence tensors for sequence chunks. loss_masks (List[Tensor]): List of response-only loss masks for each full sequence. + remove_samples (List[bool] | None): Whether each sequence is intentionally removed. response_lengths (List[int]): The full length of each response sequence. total_lengths (List[int]): The full length of each sequence (prompt + response). kl_coef (float): Coefficient for the KL penalty. 
@@ -241,6 +243,7 @@ def get_reinforce_plus_plus_returns( for i in range(len(rewards)): local_kl_chunk = kl[i] total_len, response_len = total_lengths[i], response_lengths[i] + remove_sample = remove_samples is not None and remove_samples[i] if cp_size > 1: # Step 1,2:Gather all chunks and token_offsets from all ranks and reconstruct the full response tensor by splitting and placing each part @@ -251,6 +254,15 @@ def get_reinforce_plus_plus_returns( full_kl_response = local_kl_chunk # Step 3: Compute returns on full response kl tensor. + if remove_sample: + returns_for_seq = torch.zeros_like(full_kl_response) + if cp_size > 1: + from miles.backends.training_utils.cp_utils import slice_log_prob_with_cp + + returns_for_seq = slice_log_prob_with_cp(returns_for_seq, total_len, response_len) + final_returns_chunks.append(returns_for_seq) + continue + token_level_rewards = -kl_coef * full_kl_response full_mask = loss_masks[i] assert full_mask.sum().item() > 0, f"Sequence at index {i} is fully masked." 
diff --git a/miles/utils/training_semantics.py b/miles/utils/training_semantics.py new file mode 100644 index 0000000000..f8375a9c25 --- /dev/null +++ b/miles/utils/training_semantics.py @@ -0,0 +1,82 @@ +import math +from typing import Any + +import torch + +from miles.utils.types import Sample + + +def _reward_value(args: Any, sample: Sample) -> Any: + return sample.get_reward_value(args) + + +def _finite_reward(value: Any, sample_index: int) -> float: + if not isinstance(value, (int, float)) or not math.isfinite(float(value)): + raise ValueError(f"sample {sample_index} has non-finite reward for training: {value!r}") + return float(value) + + +def _group_ranges(args: Any, sample_count: int) -> list[range]: + samples_per_prompt = args.n_samples_per_prompt + if sample_count == samples_per_prompt * args.rollout_batch_size: + return [range(start, start + samples_per_prompt) for start in range(0, sample_count, samples_per_prompt)] + return [range(0, sample_count)] + + +def post_process_rewards_excluding_removed(args: Any, samples: list[Sample]) -> tuple[list[Any], list[float]]: + raw_rewards = [_reward_value(args, sample) for sample in samples] + remove_samples = [sample.remove_sample for sample in samples] + + processed_rewards = [0.0] * len(samples) + if not ( + args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"] + and args.rewards_normalization + ): + for i, raw_reward in enumerate(raw_rewards): + if not remove_samples[i]: + processed_rewards[i] = _finite_reward(raw_reward, i) + return raw_rewards, processed_rewards + + for group in _group_ranges(args, len(samples)): + kept_indices = [i for i in group if not remove_samples[i]] + if not kept_indices: + continue + + kept_rewards = torch.tensor( + [_finite_reward(raw_rewards[i], i) for i in kept_indices], + dtype=torch.float, + ) + kept_rewards = kept_rewards - kept_rewards.mean() + + if args.advantage_estimator in ["grpo", "gspo"] and args.grpo_std_normalization: + if kept_rewards.numel() > 
1: + std = kept_rewards.std() + else: + std = torch.tensor(0.0, dtype=kept_rewards.dtype) + kept_rewards = kept_rewards / (std + 1e-6) + + for i, reward in zip(kept_indices, kept_rewards.tolist(), strict=True): + processed_rewards[i] = reward + + return raw_rewards, processed_rewards + + +def validate_loss_masks_for_removed_samples( + loss_masks: list[Any], + response_lengths: list[int], + remove_samples: list[bool] | None, +) -> None: + if remove_samples is None: + remove_samples = [False] * len(loss_masks) + if len(remove_samples) != len(loss_masks): + raise ValueError(f"remove_samples length {len(remove_samples)} != loss_masks length {len(loss_masks)}") + + for i, (loss_mask, response_length, remove_sample) in enumerate( + zip(loss_masks, response_lengths, remove_samples, strict=True) + ): + mask_sum = loss_mask.sum().item() if torch.is_tensor(loss_mask) else sum(loss_mask) + if mask_sum <= 0 and not remove_sample: + raise ValueError( + f"loss_masks[{i}] has no active tokens, sum={mask_sum}, " + f"response_len={response_length}, remove_samples[{i}] is false" + ) diff --git a/miles/utils/types.py b/miles/utils/types.py index 4d7d6ef9b2..c2034d5f08 100644 --- a/miles/utils/types.py +++ b/miles/utils/types.py @@ -254,7 +254,7 @@ class ParamInfo: # A dict-based batch produced along the rollout -> training path # In Megatron backend, several fields are converted to torch.Tensor lists on GPU # before being consumed by data iterators (see megatron_utils.actor._get_rollout_data). 
from argparse import Namespace

import pytest
import torch

from miles.utils.training_semantics import (
    post_process_rewards_excluding_removed,
    validate_loss_masks_for_removed_samples,
)
from miles.utils.types import Sample


def _args(**overrides):
    """Build the argument namespace used by these tests.

    Keyword arguments replace the defaults below. The defaults are merged with
    the overrides first, because passing the same key twice to ``Namespace``
    raises ``TypeError``.
    """
    defaults = {
        "advantage_estimator": "grpo",
        "rewards_normalization": True,
        "grpo_std_normalization": False,
        "n_samples_per_prompt": 2,
        "rollout_batch_size": 2,
        "reward_key": None,
    }
    return Namespace(**{**defaults, **overrides})


def _sample(reward, *, remove_sample=False):
    """Create a ``Sample`` with the given reward and removal flag."""
    return Sample(reward=reward, remove_sample=remove_sample)


def test_removed_samples_are_excluded_from_group_reward_normalization():
    samples = [
        _sample(1.0),
        _sample(None, remove_sample=True),
        _sample(3.0),
        _sample(5.0),
    ]

    raw_rewards, rewards = post_process_rewards_excluding_removed(_args(), samples)

    assert raw_rewards == [1.0, None, 3.0, 5.0]
    assert rewards == [0.0, 0.0, -1.0, 1.0]


def test_std_normalization_with_single_kept_sample_yields_zero():
    samples = [
        _sample(1.0),
        _sample(None, remove_sample=True),
        _sample(3.0),
        _sample(5.0),
    ]

    _, rewards = post_process_rewards_excluding_removed(_args(grpo_std_normalization=True), samples)

    # The first group has one kept sample: centered to 0 and divided by (0 + eps).
    assert rewards[0] == 0.0
    assert rewards[1] == 0.0
    # The second group is [-1, 1] with sample std sqrt(2).
    assert rewards[2] == pytest.approx(-0.70710, abs=1e-4)
    assert rewards[3] == pytest.approx(0.70710, abs=1e-4)


def test_non_removed_sample_with_missing_reward_raises():
    with pytest.raises(ValueError, match="non-finite reward"):
        post_process_rewards_excluding_removed(_args(), [_sample(None)])


def test_zero_loss_mask_requires_explicit_removed_sample():
    validate_loss_masks_for_removed_samples(
        [torch.tensor([0, 0]), torch.tensor([1, 0])],
        [2, 2],
        [True, False],
    )

    with pytest.raises(ValueError, match="remove_samples\\[0\\] is false"):
        validate_loss_masks_for_removed_samples(
            [torch.tensor([0, 0])],
            [2],
            [False],
        )