Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
22 commits
Select commit Hold shift + click to select a range
5e7d738
fix(http_utils): disable httpx keepalive to spread load across uvicor…
rmfan May 29, 2026
03fdffb
feat(session): SessionRegistry uuid pinning for multi-process routing
rmfan Jun 4, 2026
59f5794
feat(session): add ASGI front-end SessionRouter for multi-process layout
rmfan Jun 4, 2026
2c721fb
feat(session): --session-server-workers N spawns N processes + router
rmfan Jun 4, 2026
908fdae
test(session): hash agreement + multi-worker startup smoke tests
rmfan Jun 4, 2026
27851f8
feat(session-router): replace hash+rejection with prefix-encoded sess…
rmfan Jun 4, 2026
fa89d49
fix(session-router): stream request/response bodies end-to-end (audit…
rmfan Jun 4, 2026
39efe46
fix(session-server): register reaper BEFORE first .start() (audit H2)
rmfan Jun 4, 2026
7ed52a7
fix(session-server): drop SIGTERM handler chain (audit H3)
rmfan Jun 4, 2026
b0bb0e1
fix(session-server): deep-copy per-worker args to avoid shared-ref bu…
rmfan Jun 4, 2026
355c939
style: black/isort on PR files (pre-commit autofix)
rmfan Jun 4, 2026
e45b90e
fix(session-router): enable HTTP keepalive in router->backend pool
rmfan Jun 4, 2026
50295d5
fix(session-router): expose session_server_instance_id on /health
rmfan Jun 4, 2026
438807a
fix(session-server): prevent port collisions + reap children on SIGTERM
rmfan Jun 4, 2026
fecb048
fix(session-router): route unknown session_ids via round-robin (was 404)
rmfan Jun 4, 2026
0353c73
docs: drop references to design docs not in the repo
rmfan Jun 4, 2026
11cca4a
fix(session-server): explicit error when port walk exhausts 65535
rmfan Jun 4, 2026
aa95c59
perf(session): asyncio.to_thread wrap sync tito-tokenizer calls
rmfan Jun 4, 2026
eee1684
style: pre-commit autofix on prod files
rmfan Jun 4, 2026
1fa0673
fix(session-router): don't leak exception details in 502 response body
rmfan Jun 4, 2026
e6748c5
perf(session-router): K-way SO_REUSEPORT uvicorn workers (opt-in)
rmfan Jun 4, 2026
66f27f6
fix(session-router): return 409 on stale-update guard to prevent curs…
rmfan Jun 5, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 1 addition & 4 deletions miles/backends/experimental/fsdp_utils/checkpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,7 @@
from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict
from torch.distributed.checkpoint.stateful import Stateful

from miles.backends.training_utils.log_utils import (
init_train_step_counter,
save_train_step_counter,
)
from miles.backends.training_utils.log_utils import init_train_step_counter, save_train_step_counter

logger = logging.getLogger(__name__)

Expand Down
12 changes: 4 additions & 8 deletions miles/backends/megatron_utils/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,8 +49,6 @@
logger = logging.getLogger(__name__)

import math
from typing import Any



def validate_rollout_for_grpo_training_step(
Expand All @@ -69,7 +67,6 @@ def validate_rollout_for_grpo_training_step(
Logs useful diagnostics before raising so NCCL-desync root cause is visible
in the first failing rank's log.
"""
import math
import socket
import traceback

Expand Down Expand Up @@ -218,10 +215,7 @@ def _summarize_vector_list(key, limit=3):
shapes.append(type(x).__name__)
dtypes.append(type(x).__name__)
devices.append("python")
return (
f"{key}: len={len(xs)} first_shapes={shapes} "
f"first_dtypes={dtypes} first_devices={devices}"
)
return f"{key}: len={len(xs)} first_shapes={shapes} " f"first_dtypes={dtypes} first_devices={devices}"

def _basic_batch_summary():
keys = sorted(list(rollout_data.keys()))
Expand Down Expand Up @@ -420,7 +414,9 @@ def _basic_batch_summary():
_add_error(f"loss_masks[{i}] has no active tokens, sum={mask_sum}, response_len={resp}")
if mask_sum > resp:
# Warning-only: float/weighted masks can legitimately have sum > resp.
_add_warning(f"loss_masks[{i}] sum={mask_sum} exceeds response_len={resp} (expected for float/weighted masks)")
_add_warning(
f"loss_masks[{i}] sum={mask_sum} exceeds response_len={resp} (expected for float/weighted masks)"
)

if torch.is_tensor(mask):
# Binary check: warning, not fatal, because masks may be float.
Expand Down
4 changes: 1 addition & 3 deletions miles/backends/training_utils/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,9 +127,7 @@ def get_batch(
# same list (aggregate_train_losses keys positionally on the first microbatch).
if has_domains:
if not hasattr(data_iterator, "_all_domains_cache"):
data_iterator._all_domains_cache = sorted(
{d for d in data_iterator.rollout_data["domains"] if d}
)
data_iterator._all_domains_cache = sorted({d for d in data_iterator.rollout_data["domains"] if d})
batch["all_domains"] = data_iterator._all_domains_cache

tokens = batch["tokens"]
Expand Down
44 changes: 22 additions & 22 deletions miles/backends/training_utils/log_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,33 +23,33 @@
# Maps bare metric names to their W&B top-level section(s).
# Keys appearing in multiple sections (e.g. pg_loss) are emitted under each.
_TRAIN_METRIC_GROUPS: dict[str, list[str]] = {
"ppo_kl": ["policy_shift"],
"ois": ["policy_shift"],
"pg_clipfrac": ["policy_shift"],
"pg_loss": ["policy_shift", "optimization"],
"log_probs": ["policy_shift"], # current policy (training forward pass)
"old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout)
"ref_kl": ["policy_shift"],
"ppo_kl": ["policy_shift"],
"ois": ["policy_shift"],
"pg_clipfrac": ["policy_shift"],
"pg_loss": ["policy_shift", "optimization"],
"log_probs": ["policy_shift"], # current policy (training forward pass)
"old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout)
"ref_kl": ["policy_shift"],
"train_rollout_logprob_abs_diff": ["train_inference_mismatch"],
"train_rollout_logprob_diff": ["train_inference_mismatch"],
"tis": ["train_inference_mismatch"],
"tis_abs": ["train_inference_mismatch"],
"tis_clipfrac": ["train_inference_mismatch"],
"loss": ["optimization"],
"entropy_loss": ["optimization"],
"kl_loss": ["optimization"],
"grad_norm": ["optimization"],
"train_rollout_logprob_diff": ["train_inference_mismatch"],
"tis": ["train_inference_mismatch"],
"tis_abs": ["train_inference_mismatch"],
"tis_clipfrac": ["train_inference_mismatch"],
"loss": ["optimization"],
"entropy_loss": ["optimization"],
"kl_loss": ["optimization"],
"grad_norm": ["optimization"],
}

# Maps rollout batch field names to their W&B top-level section.
_ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = {
"log_probs": "train_inference_mismatch", # FSDP log probs at rollout time
"log_probs": "train_inference_mismatch", # FSDP log probs at rollout time
"rollout_log_probs": "train_inference_mismatch", # inference engine log probs
"ref_log_probs": "policy_shift", # reference model log probs
"rewards": "reward",
"raw_reward": "reward",
"advantages": "reward",
"returns": "reward",
"ref_log_probs": "policy_shift", # reference model log probs
"rewards": "reward",
"raw_reward": "reward",
"advantages": "reward",
"returns": "reward",
}

# Cumulative train-step counter across all rollouts. The previous formula
Expand Down Expand Up @@ -570,7 +570,7 @@ def log_train_step(
for full_key, val in log_dict_out.items():
if not full_key.startswith(prefix):
continue
bare_key = full_key[len(prefix):]
bare_key = full_key[len(prefix) :]
# Per-domain keys arrive as "<metric>/<domain>" — route to "<group>/<domain>/<metric>".
metric_name, sep, domain = bare_key.rpartition("/")
lookup = metric_name if (sep and metric_name in _TRAIN_METRIC_GROUPS) else bare_key
Expand Down
8 changes: 6 additions & 2 deletions miles/backends/training_utils/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -775,8 +775,12 @@ def policy_loss_function(
for dd, lm in zip(batch["domains"], batch["loss_masks"], strict=False)
]
reducer = get_sum_of_sample_mean(
total_lengths, response_lengths, masked,
args.calculate_per_token_loss, args.qkv_format, max_seq_lens,
total_lengths,
response_lengths,
masked,
args.calculate_per_token_loss,
args.qkv_format,
max_seq_lens,
loss_agg_mode=getattr(args, "loss_agg_mode", None),
)
for name, t in per_token.items():
Expand Down
Loading
Loading