2 changes: 1 addition & 1 deletion docs/en/advanced/miles_server_args.md
@@ -136,7 +136,7 @@ Arguments for sampling strategies and data filtering during rollout and buffer m
| `--partial-rollout` | Enable partial rollout for **dynamic sampling**: cache partially generated (aborted/unfinished) samples and resume generation in later rollout steps, reducing wasted compute for long responses. Cached samples are stored in the rollout buffer and can be prioritized/selected via `--buffer-filter-path` (default FIFO behavior). See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native |
| `--mask-offpolicy-in-partial-rollout` | When using partial rollout, mask the previously generated (cached) response tokens so they do not contribute to the loss; only tokens generated after resuming are used for training. This helps avoid training on a cached prefix produced by an older policy version. See [Partial Rollout](../get_started/quick_start.md#partial-rollout). | `False` | bool flag (set to enable) | Miles Native |
| `--buffer-filter-path` | Path to the function to filter or sort samples in the rollout buffer before training. [Ref](../get_started/customization.md#5-buffer-filter---buffer-filter-path) | `None` | Type: str | Miles Native |
| `--rollout-sample-filter-path` | Path to the function that marks individual samples to be excluded from loss calculation. [Ref](../get_started/customization.md#6-rollout-sample-filter---rollout-sample-filter-path) | `None` | Type: str | Miles Native |
| `--rollout-sample-filter-path` | Path to the function that marks individual samples to be excluded from reward/advantage normalization and loss calculation. [Ref](../get_started/customization.md#6-rollout-sample-filter---rollout-sample-filter-path) | `None` | Type: str | Miles Native |
| `--rollout-all-samples-process-path` | Path to the function to process all samples (including filtered ones) after rollout. [Ref](../get_started/customization.md#7-rollout-all-samples-process---rollout-all-samples-process-path) | `None` | Type: str | Miles Native |

## Data Arguments
7 changes: 3 additions & 4 deletions docs/en/get_started/customization.md
@@ -13,7 +13,7 @@ Below is a summary of all available customization interfaces and their purposes.
| [`--custom-rm-path`](#3-reward-model---custom-rm-path) | Implement custom reward computation logic. |
| [`--dynamic-sampling-filter-path`](#4-dynamic-sampling-filter---dynamic-sampling-filter-path) | Filter samples during dynamic sampling (e.g., DAPO). |
| [`--buffer-filter-path`](#5-buffer-filter---buffer-filter-path) | Filter samples in the rollout buffer before training. |
| [`--rollout-sample-filter-path`](#6-rollout-sample-filter---rollout-sample-filter-path) | Determine if individual samples participate in loss calculation. |
| [`--rollout-sample-filter-path`](#6-rollout-sample-filter---rollout-sample-filter-path) | Determine if individual samples participate in training. |
| [`--rollout-all-samples-process-path`](#7-rollout-all-samples-process---rollout-all-samples-process-path) | Process all samples (including filtered ones) after rollout. |
| [`--rollout-data-postprocess-path`](#8-rollout-data-postprocess---rollout-data-postprocess-path) | Post-process rollout data after log probs are computed. |
| [`--custom-loss-function-path`](#9-custom-loss-function---custom-loss-function-path) | Implement custom training loss computation. |
@@ -161,14 +161,15 @@ def buffer_filter(samples: list[list[Sample]]) -> list[list[Sample]]

**Default**: `None`

**Purpose**: Determine whether individual samples participate in loss calculation.
**Purpose**: Determine whether individual samples participate in training.

**Signature**:
```python
def filter_function(args, samples: list[Sample]) -> None
```

**Note**: This function should directly modify the `remove_sample` attribute of each `Sample` object.
Removed samples stay available in rollout artifacts, but are excluded from reward/advantage normalization and loss calculation.
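
As a rough illustration of the contract described in the note above, here is a minimal sketch of a filter function. The `Sample` dataclass is a hypothetical stand-in that carries only the fields this example needs (it is not the real `miles.utils.types.Sample`), and `min_response_chars` is an invented option, not an actual Miles argument.

```python
from dataclasses import dataclass
from types import SimpleNamespace


@dataclass
class Sample:
    # Hypothetical stand-in for miles.utils.types.Sample; only the fields used here.
    response: str
    reward: float
    remove_sample: bool = False


def filter_function(args, samples: list[Sample]) -> None:
    """Mark very short responses as removed. Mutates in place and returns None."""
    for sample in samples:
        # min_response_chars is a made-up option for this sketch.
        if len(sample.response) < args.min_response_chars:
            sample.remove_sample = True


args = SimpleNamespace(min_response_chars=5)
samples = [Sample("ok", 1.0), Sample("a longer answer", 0.0)]
filter_function(args, samples)
flags = [s.remove_sample for s in samples]
```

The function only flips `remove_sample`; it does not drop items from the list, which matches the statement that removed samples stay available in rollout artifacts.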

**Use Cases**:
- Filtering samples based on response quality
@@ -438,5 +439,3 @@ For detailed explanation of R3 and MilesRouter, see [Miles Router](../advanced/m
```python
def custom_model_provider(pre_process: bool, post_process: bool, vp_stage: int | None = None) -> GPTModel
```


15 changes: 13 additions & 2 deletions miles/backends/megatron_utils/actor.py
@@ -24,6 +24,7 @@
from miles.utils.replay_base import all_replay_managers
from miles.utils.timer import Timer, inverse_timer, timer
from miles.utils.tracking_utils import init_tracking
from miles.utils.training_semantics import validate_loss_masks_for_removed_samples
from miles.utils.types import RolloutBatch

from ...utils.profile_utils import TrainProfiler
@@ -334,6 +335,13 @@ def _basic_batch_summary():
if got != n:
_add_error(f"{key!r} length mismatch: got {got}, expected {n}")

remove_samples = rollout_data.get("remove_samples")
if remove_samples is not None:
if not _is_seq(remove_samples):
_add_error(f"'remove_samples' must be list/tuple, got {type(remove_samples).__name__}")
elif len(remove_samples) != n:
_add_error(f"'remove_samples' length mismatch: got {len(remove_samples)}, expected {n}")

token_key = None
if _present("tokens"):
token_key = "tokens"
@@ -416,8 +424,6 @@ def _basic_batch_summary():
continue

mask_sum = _sum_float(mask)
if mask_sum <= 0:
_add_error(f"loss_masks[{i}] has no active tokens, sum={mask_sum}, response_len={resp}")
if mask_sum > resp:
# Warning-only: float/weighted masks can legitimately have sum > resp.
_add_warning(f"loss_masks[{i}] sum={mask_sum} exceeds response_len={resp} (expected for float/weighted masks)")
@@ -431,6 +437,11 @@ def _basic_batch_summary():
except Exception as e:
_add_warning(f"binary check failed for loss_masks[{i}]: {type(e).__name__}: {e}")

try:
validate_loss_masks_for_removed_samples(rollout_data["loss_masks"], response_lengths, remove_samples)
except ValueError as e:
_add_error(str(e))

# max_seq_lens if present.
if _present("max_seq_lens"):
xs = rollout_data["max_seq_lens"]
15 changes: 14 additions & 1 deletion miles/backends/training_utils/loss.py
@@ -18,6 +18,7 @@
get_reinforce_plus_plus_baseline_advantages,
get_reinforce_plus_plus_returns,
)
from miles.utils.training_semantics import validate_loss_masks_for_removed_samples
from miles.utils.types import RolloutBatch

from .cp_utils import (
@@ -305,13 +306,16 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
values: None | list[torch.Tensor] = rollout_data.get("values")
response_lengths: list[int] = rollout_data.get("response_lengths")
loss_masks: list[torch.Tensor] = rollout_data.get("loss_masks")
remove_samples: list[bool] | None = rollout_data.get("remove_samples")
total_lengths: list[int] = rollout_data.get("total_lengths")
max_seq_lens: list[int] | None = rollout_data.get("max_seq_lens", None)

# return when not the last pp stage.
if log_probs is None and values is None:
return

validate_loss_masks_for_removed_samples(loss_masks, response_lengths, remove_samples)

if args.kl_coef == 0 or not log_probs:
# when kl_coef is 0, we won't compute ref_log_prob
xs = log_probs if log_probs is not None else values
@@ -352,6 +356,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
rewards=rewards,
kl=kl,
loss_masks=loss_masks,
remove_samples=remove_samples,
response_lengths=response_lengths,
total_lengths=total_lengths,
kl_coef=args.kl_coef,
@@ -429,7 +434,7 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)

all_masks = torch.cat(mask_chunks)

if all_masks.numel() > 0:
if all_masks.numel() > 0 and all_masks.sum().item() > 0:
assert (
all_advs.size() == all_masks.size()
), f"Shape mismatch before whitening: advantages {all_advs.size()}, masks {all_masks.size()}"
@@ -444,6 +449,14 @@ def compute_advantages_and_returns(args: Namespace, rollout_data: RolloutBatch)
chunk_lengths = [chunk.size(0) for chunk in advantages]
advantages = list(torch.split(whitened_advs_flat, chunk_lengths))

if remove_samples is not None:
if len(remove_samples) != len(advantages):
raise ValueError(f"remove_samples length {len(remove_samples)} != advantages length {len(advantages)}")
for i, remove_sample in enumerate(remove_samples):
if remove_sample:
advantages[i] = torch.zeros_like(advantages[i])
returns[i] = torch.zeros_like(returns[i])

rollout_data["advantages"] = advantages
rollout_data["returns"] = returns
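
The two changes in this hunk (skip whitening when no mask token is active, then zero out removed samples) can be sketched in plain Python. This is only an approximation under stated assumptions: the actual whitening helper is not shown in the diff, so a simple masked mean/variance whitening is assumed here, and `whiten_then_zero_removed` is a hypothetical name.

```python
import math


def whiten_then_zero_removed(
    advantages: list[list[float]],
    masks: list[list[float]],
    removed: list[bool],
    eps: float = 1e-8,
) -> list[list[float]]:
    """Masked whitening over all tokens, followed by zeroing of removed samples."""
    pairs = [(a, m) for adv, mask in zip(advantages, masks) for a, m in zip(adv, mask)]
    active = sum(m for _, m in pairs)
    # Guard mirroring `all_masks.sum().item() > 0`: skip whitening if nothing is active.
    if active > 0:
        mean = sum(a * m for a, m in pairs) / active
        var = sum(m * (a - mean) ** 2 for a, m in pairs) / active
        scale = 1.0 / math.sqrt(var + eps)
        advantages = [[(a - mean) * scale for a in adv] for adv in advantages]
    # Removed samples contribute zero advantage regardless of the whitening result.
    return [[0.0] * len(adv) if rm else list(adv) for adv, rm in zip(advantages, removed)]


result = whiten_then_zero_removed([[1.0, 3.0], [10.0, 10.0]], [[1, 1], [0, 0]], [False, True])
all_removed = whiten_then_zero_removed([[2.0], [4.0]], [[0], [0]], [True, True])
```

In the second call every sample is removed and every mask is zero, so the guard avoids a division by zero and the output is all zeros.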

26 changes: 4 additions & 22 deletions miles/ray/rollout.py
@@ -48,6 +48,7 @@
from miles.utils.misc import load_function
from miles.utils.ray_utils import Box
from miles.utils.seqlen_balancing import get_seqlen_balanced_partitions
from miles.utils.training_semantics import post_process_rewards_excluding_removed
from miles.utils.tracking_utils import init_tracking
from miles.utils.types import Sample

@@ -678,28 +679,7 @@ def _post_process_rewards(self, samples: list[Sample] | list[list[Sample]]):
if self.custom_reward_post_process_func is not None:
return self.custom_reward_post_process_func(self.args, samples)

raw_rewards = [sample.get_reward_value(self.args) for sample in samples]
if (
self.args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"]
and self.args.rewards_normalization
):
# group norm
rewards = torch.tensor(raw_rewards, dtype=torch.float)
if rewards.shape[-1] == self.args.n_samples_per_prompt * self.args.rollout_batch_size:
rewards = rewards.reshape(-1, self.args.n_samples_per_prompt)
else:
# when samples count are not equal in each group
rewards = rewards.view(-1, rewards.shape[-1])
mean = rewards.mean(dim=-1, keepdim=True)
rewards = rewards - mean

if self.args.advantage_estimator in ["grpo", "gspo"] and self.args.grpo_std_normalization:
std = rewards.std(dim=-1, keepdim=True)
rewards = rewards / (std + 1e-6)

return raw_rewards, rewards.flatten().tolist()

return raw_rewards, raw_rewards
return post_process_rewards_excluding_removed(self.args, samples)

def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sample]]):
"""
@@ -722,6 +702,7 @@ def _convert_samples_to_train_data(self, samples: list[Sample] | list[list[Sampl
"raw_reward": raw_rewards,
"truncated": [1 if sample.status == Sample.Status.TRUNCATED else 0 for sample in samples],
"sample_indices": [sample.index for sample in samples],
"remove_samples": [sample.remove_sample for sample in samples],
}

# loss mask
@@ -873,6 +854,7 @@ def _stat(xs):
"rewards",
"truncated",
"loss_masks",
"remove_samples",
"round_number",
"sample_indices",
"rollout_log_probs",
4 changes: 2 additions & 2 deletions miles/utils/arguments.py
@@ -1495,10 +1495,10 @@ def add_rollout_buffer_arguments(parser):
default=None,
help=(
"Path to the rollout sample filter function. "
"This function determines whether a sample will participate in loss calculation. "
"This function determines whether a sample will participate in training. "
"The function should take args and samples (list[Sample]) as input, and return None. "
"Please directly modify the remove_sample attribute of Sample. "
"Note: This attribute does not determine whether the sample participates in advantage normalization."
"Removed samples are excluded from reward/advantage normalization and loss calculation."
),
)
parser.add_argument(
12 changes: 12 additions & 0 deletions miles/utils/ppo_utils.py
@@ -212,6 +212,7 @@ def get_reinforce_plus_plus_returns(
rewards: torch.Tensor,
kl: list[torch.Tensor],
loss_masks: list[torch.Tensor],
remove_samples: list[bool] | None,
response_lengths: list[int],
total_lengths: list[int],
kl_coef: float,
@@ -224,6 +225,7 @@ def get_reinforce_plus_plus_returns(
rewards (Tensor): A tensor of scalar rewards for each sequence.
kl (List[Tensor]): List of per-token KL divergence tensors for sequence chunks.
loss_masks (List[Tensor]): List of response-only loss masks for each full sequence.
remove_samples (List[bool] | None): Whether each sequence is intentionally removed.
response_lengths (List[int]): The full length of each response sequence.
total_lengths (List[int]): The full length of each sequence (prompt + response).
kl_coef (float): Coefficient for the KL penalty.
@@ -241,6 +243,7 @@ def get_reinforce_plus_plus_returns(
for i in range(len(rewards)):
local_kl_chunk = kl[i]
total_len, response_len = total_lengths[i], response_lengths[i]
remove_sample = remove_samples is not None and remove_samples[i]

if cp_size > 1:
# Step 1,2:Gather all chunks and token_offsets from all ranks and reconstruct the full response tensor by splitting and placing each part
@@ -251,6 +254,15 @@ def get_reinforce_plus_plus_returns(
full_kl_response = local_kl_chunk

# Step 3: Compute returns on full response kl tensor.
if remove_sample:
returns_for_seq = torch.zeros_like(full_kl_response)
if cp_size > 1:
from miles.backends.training_utils.cp_utils import slice_log_prob_with_cp

returns_for_seq = slice_log_prob_with_cp(returns_for_seq, total_len, response_len)
final_returns_chunks.append(returns_for_seq)
continue

token_level_rewards = -kl_coef * full_kl_response
full_mask = loss_masks[i]
assert full_mask.sum().item() > 0, f"Sequence at index {i} is fully masked."
82 changes: 82 additions & 0 deletions miles/utils/training_semantics.py
@@ -0,0 +1,82 @@
import math
from typing import Any

import torch

from miles.utils.types import Sample


def _reward_value(args: Any, sample: Sample) -> Any:
return sample.get_reward_value(args)


def _finite_reward(value: Any, sample_index: int) -> float:
if not isinstance(value, (int, float)) or not math.isfinite(float(value)):
raise ValueError(f"sample {sample_index} has non-finite reward for training: {value!r}")
return float(value)


def _group_ranges(args: Any, sample_count: int) -> list[range]:
samples_per_prompt = args.n_samples_per_prompt
if sample_count == samples_per_prompt * args.rollout_batch_size:
return [range(start, start + samples_per_prompt) for start in range(0, sample_count, samples_per_prompt)]
return [range(0, sample_count)]


def post_process_rewards_excluding_removed(args: Any, samples: list[Sample]) -> tuple[list[Any], list[float]]:
raw_rewards = [_reward_value(args, sample) for sample in samples]
remove_samples = [sample.remove_sample for sample in samples]

processed_rewards = [0.0] * len(samples)
if not (
args.advantage_estimator in ["grpo", "gspo", "reinforce_plus_plus_baseline"]
and args.rewards_normalization
):
for i, raw_reward in enumerate(raw_rewards):
if not remove_samples[i]:
processed_rewards[i] = _finite_reward(raw_reward, i)
return raw_rewards, processed_rewards

for group in _group_ranges(args, len(samples)):
kept_indices = [i for i in group if not remove_samples[i]]
if not kept_indices:
continue

kept_rewards = torch.tensor(
[_finite_reward(raw_rewards[i], i) for i in kept_indices],
dtype=torch.float,
)
kept_rewards = kept_rewards - kept_rewards.mean()

if args.advantage_estimator in ["grpo", "gspo"] and args.grpo_std_normalization:
if kept_rewards.numel() > 1:
std = kept_rewards.std()
else:
std = torch.tensor(0.0, dtype=kept_rewards.dtype)
kept_rewards = kept_rewards / (std + 1e-6)

for i, reward in zip(kept_indices, kept_rewards.tolist(), strict=True):
processed_rewards[i] = reward

return raw_rewards, processed_rewards


def validate_loss_masks_for_removed_samples(
loss_masks: list[Any],
response_lengths: list[int],
remove_samples: list[bool] | None,
) -> None:
if remove_samples is None:
remove_samples = [False] * len(loss_masks)
if len(remove_samples) != len(loss_masks):
raise ValueError(f"remove_samples length {len(remove_samples)} != loss_masks length {len(loss_masks)}")

for i, (loss_mask, response_length, remove_sample) in enumerate(
zip(loss_masks, response_lengths, remove_samples, strict=True)
):
mask_sum = loss_mask.sum().item() if torch.is_tensor(loss_mask) else sum(loss_mask)
if mask_sum <= 0 and not remove_sample:
raise ValueError(
f"loss_masks[{i}] has no active tokens, sum={mask_sum}, "
f"response_len={response_length}, remove_samples[{i}] is false"
)
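
To make the per-group behavior of `post_process_rewards_excluding_removed` easier to follow, here is a torch-free sketch of the normalization for a single group. It is an approximation, not the code from this PR: `normalize_group` is a hypothetical helper, it works in float64 rather than `torch.float`, and it relies on `statistics.stdev` using the same n-1 denominator as the default `Tensor.std()`.

```python
import statistics


def normalize_group(rewards: list[float], removed: list[bool], use_std: bool = True) -> list[float]:
    """Center (and optionally scale) kept rewards; removed entries stay at 0.0."""
    out = [0.0] * len(rewards)
    kept = [i for i, is_removed in enumerate(removed) if not is_removed]
    if not kept:
        return out  # an entirely removed group contributes nothing
    kept_values = [rewards[i] for i in kept]
    mean = statistics.fmean(kept_values)
    centered = [value - mean for value in kept_values]
    if use_std:
        # A single kept sample has no sample std; fall back to 0.0 as the diff does.
        std = statistics.stdev(kept_values) if len(kept_values) > 1 else 0.0
        centered = [value / (std + 1e-6) for value in centered]
    for i, value in zip(kept, centered):
        out[i] = value
    return out


result = normalize_group([1.0, 0.0, 5.0, 0.0], [False, False, True, False])
single = normalize_group([3.0, 7.0], [True, False])
no_std = normalize_group([1.0, 0.0, 5.0, 0.0], [False, False, True, False], use_std=False)
```

The removed reward of 5.0 does not shift the group mean, so the kept samples are normalized as if the group had three members; a group with one kept sample collapses to zero.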
2 changes: 1 addition & 1 deletion miles/utils/types.py
@@ -254,7 +254,7 @@ class ParamInfo:
# A dict-based batch produced along the rollout -> training path
# In Megatron backend, several fields are converted to torch.Tensor lists on GPU
# before being consumed by data iterators (see megatron_utils.actor._get_rollout_data).
RolloutBatch = dict[str, list[torch.Tensor] | list[int] | list[float] | list[str]]
RolloutBatch = dict[str, list[torch.Tensor] | list[int] | list[float] | list[str] | list[bool]]


@dataclass