diff --git a/miles/backends/experimental/fsdp_utils/checkpoint.py b/miles/backends/experimental/fsdp_utils/checkpoint.py index 8a0b5ff5d3..80eeb64110 100644 --- a/miles/backends/experimental/fsdp_utils/checkpoint.py +++ b/miles/backends/experimental/fsdp_utils/checkpoint.py @@ -12,10 +12,7 @@ from torch.distributed.checkpoint.state_dict import get_state_dict, set_state_dict from torch.distributed.checkpoint.stateful import Stateful -from miles.backends.training_utils.log_utils import ( - init_train_step_counter, - save_train_step_counter, -) +from miles.backends.training_utils.log_utils import init_train_step_counter, save_train_step_counter logger = logging.getLogger(__name__) diff --git a/miles/backends/megatron_utils/actor.py b/miles/backends/megatron_utils/actor.py index d0a0daec75..c138ea8746 100644 --- a/miles/backends/megatron_utils/actor.py +++ b/miles/backends/megatron_utils/actor.py @@ -49,8 +49,6 @@ logger = logging.getLogger(__name__) import math -from typing import Any - def validate_rollout_for_grpo_training_step( @@ -69,7 +67,6 @@ def validate_rollout_for_grpo_training_step( Logs useful diagnostics before raising so NCCL-desync root cause is visible in the first failing rank's log. """ - import math import socket import traceback @@ -218,10 +215,7 @@ def _summarize_vector_list(key, limit=3): shapes.append(type(x).__name__) dtypes.append(type(x).__name__) devices.append("python") - return ( - f"{key}: len={len(xs)} first_shapes={shapes} " - f"first_dtypes={dtypes} first_devices={devices}" - ) + return f"{key}: len={len(xs)} first_shapes={shapes} " f"first_dtypes={dtypes} first_devices={devices}" def _basic_batch_summary(): keys = sorted(list(rollout_data.keys())) @@ -420,7 +414,9 @@ def _basic_batch_summary(): _add_error(f"loss_masks[{i}] has no active tokens, sum={mask_sum}, response_len={resp}") if mask_sum > resp: # Warning-only: float/weighted masks can legitimately have sum > resp. - _add_warning(f"loss_masks[{i}] sum={mask_sum} exceeds response_len={resp} (expected for float/weighted masks)") + _add_warning( + f"loss_masks[{i}] sum={mask_sum} exceeds response_len={resp} (expected for float/weighted masks)" + ) if torch.is_tensor(mask): # Binary check: warning, not fatal, because masks may be float. diff --git a/miles/backends/training_utils/data.py b/miles/backends/training_utils/data.py index 2733c93442..1ac3fd578f 100644 --- a/miles/backends/training_utils/data.py +++ b/miles/backends/training_utils/data.py @@ -127,9 +127,7 @@ def get_batch( # same list (aggregate_train_losses keys positionally on the first microbatch). if has_domains: if not hasattr(data_iterator, "_all_domains_cache"): - data_iterator._all_domains_cache = sorted( - {d for d in data_iterator.rollout_data["domains"] if d} - ) + data_iterator._all_domains_cache = sorted({d for d in data_iterator.rollout_data["domains"] if d}) batch["all_domains"] = data_iterator._all_domains_cache tokens = batch["tokens"] diff --git a/miles/backends/training_utils/log_utils.py b/miles/backends/training_utils/log_utils.py index 2b5e190e34..a20b2d6d8b 100644 --- a/miles/backends/training_utils/log_utils.py +++ b/miles/backends/training_utils/log_utils.py @@ -23,33 +23,33 @@ # Maps bare metric names to their W&B top-level section(s). # Keys appearing in multiple sections (e.g. pg_loss) are emitted under each. _TRAIN_METRIC_GROUPS: dict[str, list[str]] = { - "ppo_kl": ["policy_shift"], - "ois": ["policy_shift"], - "pg_clipfrac": ["policy_shift"], - "pg_loss": ["policy_shift", "optimization"], - "log_probs": ["policy_shift"], # current policy (training forward pass) - "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) - "ref_kl": ["policy_shift"], + "ppo_kl": ["policy_shift"], + "ois": ["policy_shift"], + "pg_clipfrac": ["policy_shift"], + "pg_loss": ["policy_shift", "optimization"], + "log_probs": ["policy_shift"], # current policy (training forward pass) + "old_log_probs": ["policy_shift"], # old policy (rollout or FSDP rollout) + "ref_kl": ["policy_shift"], "train_rollout_logprob_abs_diff": ["train_inference_mismatch"], - "train_rollout_logprob_diff": ["train_inference_mismatch"], - "tis": ["train_inference_mismatch"], - "tis_abs": ["train_inference_mismatch"], - "tis_clipfrac": ["train_inference_mismatch"], - "loss": ["optimization"], - "entropy_loss": ["optimization"], - "kl_loss": ["optimization"], - "grad_norm": ["optimization"], + "train_rollout_logprob_diff": ["train_inference_mismatch"], + "tis": ["train_inference_mismatch"], + "tis_abs": ["train_inference_mismatch"], + "tis_clipfrac": ["train_inference_mismatch"], + "loss": ["optimization"], + "entropy_loss": ["optimization"], + "kl_loss": ["optimization"], + "grad_norm": ["optimization"], } # Maps rollout batch field names to their W&B top-level section. _ROLLOUT_DATA_METRIC_GROUPS: dict[str, str] = { - "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time + "log_probs": "train_inference_mismatch", # FSDP log probs at rollout time "rollout_log_probs": "train_inference_mismatch", # inference engine log probs - "ref_log_probs": "policy_shift", # reference model log probs - "rewards": "reward", - "raw_reward": "reward", - "advantages": "reward", - "returns": "reward", + "ref_log_probs": "policy_shift", # reference model log probs + "rewards": "reward", + "raw_reward": "reward", + "advantages": "reward", + "returns": "reward", } # Cumulative train-step counter across all rollouts. The previous formula @@ -570,7 +570,7 @@ def log_train_step( for full_key, val in log_dict_out.items(): if not full_key.startswith(prefix): continue - bare_key = full_key[len(prefix):] + bare_key = full_key[len(prefix) :] # Per-domain keys arrive as "/" — route to "//". metric_name, sep, domain = bare_key.rpartition("/") lookup = metric_name if (sep and metric_name in _TRAIN_METRIC_GROUPS) else bare_key diff --git a/miles/backends/training_utils/loss.py b/miles/backends/training_utils/loss.py index 317286f6e0..bf1eaf75a2 100644 --- a/miles/backends/training_utils/loss.py +++ b/miles/backends/training_utils/loss.py @@ -775,8 +775,12 @@ def policy_loss_function( for dd, lm in zip(batch["domains"], batch["loss_masks"], strict=False) ] reducer = get_sum_of_sample_mean( - total_lengths, response_lengths, masked, - args.calculate_per_token_loss, args.qkv_format, max_seq_lens, + total_lengths, + response_lengths, + masked, + args.calculate_per_token_loss, + args.qkv_format, + max_seq_lens, loss_agg_mode=getattr(args, "loss_agg_mode", None), ) for name, t in per_token.items(): diff --git a/miles/ray/rollout.py b/miles/ray/rollout.py index 7d968f0ed8..86f04e2b0c 100644 --- a/miles/ray/rollout.py +++ b/miles/ray/rollout.py @@ -1,3 +1,4 @@ +import copy import dataclasses import itertools import logging @@ -560,6 +561,7 @@ def _get_rollout_data(self, rollout_id): path = Path(self.args.load_debug_rollout_data.format(rollout_id=rollout_id)) if path.suffix == ".parquet": import pyarrow.parquet as pq + data = [Sample.from_dict(row) for row in pq.read_table(path).to_pylist()] else: data = torch.load(path, weights_only=False)["samples"] @@ -650,6 +652,7 @@ def _save_debug_rollout_data(self, data, rollout_id, evaluation: bool): if save_format == "parquet": import pyarrow as pa import pyarrow.parquet as pq + path = path.with_suffix(".parquet") table = pa.Table.from_pylist(samples) table = table.replace_schema_metadata({b"rollout_id": str(rollout_id).encode()}) @@ -897,16 +900,8 @@ def _stat(xs): rollout_data[key] = data[key] token_lens = [total_lengths[j] for j in partition] - response_lens = ( - [data["response_lengths"][j] for j in partition] - if "response_lengths" in data - else [] - ) - loss_mask_lens = ( - [_safe_len(data["loss_masks"][j]) for j in partition] - if "loss_masks" in data - else [] - ) + response_lens = [data["response_lengths"][j] for j in partition] if "response_lengths" in data else [] + loss_mask_lens = [_safe_len(data["loss_masks"][j]) for j in partition] if "loss_masks" in data else [] payload_bytes = _estimate_payload_bytes(rollout_data) ref = ray.put(rollout_data) @@ -1304,6 +1299,11 @@ def _start_session_server(args): The session server runs as a separate process with its own port and proxies inference requests directly to SGLang worker engines. It is always started as a standalone process regardless of whether ``--use-miles-router`` is active. + + When ``--session-server-workers N`` is > 1, this function spawns N backend + SessionServer processes on consecutive ports and an ASGI front-end on + ``args.session_server_port`` that consistent-hash-routes by ``session_id``. + See ``miles/rollout/session/session_router.py``. """ if not getattr(args, "use_session_server", False): return @@ -1327,14 +1327,220 @@ def _start_session_server(args): ) router_url = f"http://{args.sglang_router_ip}:{args.sglang_router_port}" + worker_count = max(1, int(getattr(args, "session_server_workers", 1) or 1)) from miles.rollout.session.session_server import run_session_server - process = multiprocessing.Process(target=run_session_server, args=(args, router_url)) - process.daemon = True - process.start() - wait_for_server_ready(ip, port, process, timeout=30) - logger.info(f"Session server launched at {ip}:{port}") + if worker_count == 1: + # Preserve exact pre-existing behavior when the flag is default. + process = multiprocessing.Process(target=run_session_server, args=(args, router_url)) + process.daemon = True + process.start() + wait_for_server_ready(ip, port, process, timeout=30) + logger.info(f"Session server launched at {ip}:{port}") + return + + # Multi-process layout: N backends on consecutive ports starting at + # port + 1; ASGI front-end on `port` (the user-facing one). + from miles.rollout.session.session_router import run_session_router + + # Mutable list so the reaper sees children as they're spawned. We + # MUST register the reaper BEFORE the first .start() call: the + # backend ready-window is ~60s per worker, and if the parent gets + # SIGTERM (Ray actor shutdown) mid-startup, otherwise all N children + # leak holding their ports — the next rollout then trips the + # "stale session server" RuntimeError. See audit H2. + tracked_processes: list[multiprocessing.Process] = [] + backend_processes: list[multiprocessing.Process] = [] + backend_urls: list[str] = [] + # Track ports we've already handed out so worker i+1 doesn't race + # worker i: the child process hasn't bound yet when the next + # iteration calls is_port_available(), so without this set both + # children may target the same port and one crashes on bind. See + # PR #31 H2. + chosen_ports: set[int] = set() + _register_session_server_reaper(tracked_processes) + try: + for i in range(worker_count): + backend_port = port + 1 + i + # Find a free port at-or-above the desired one. We start one + # above args.session_server_port so the front-end has port + # itself reserved. Skip ports already chosen for previous + # workers — they may not be bound yet but the spawn is + # already in flight. + while not is_port_available(backend_port) or backend_port in chosen_ports: + backend_port += 1 + if backend_port > 65535: + raise RuntimeError( + f"all ports exhausted while allocating port for " + f"session-server worker {i}/{worker_count}; " + f"started at {port + 1 + i}, walked past 65535. " + f"chosen so far: {sorted(chosen_ports)}" + ) + chosen_ports.add(backend_port) + worker_args = _per_worker_args_copy(args) + worker_args.session_server_port = backend_port + worker_args.session_server_worker_index = i + worker_args.session_server_worker_count = worker_count + # Give each worker a stable, distinguishable instance id for + # log correlation while keeping the shared `args.session_server_instance_id` + # as the cluster-facing one. + worker_args.session_server_instance_id = f"{args.session_server_instance_id}-w{i}" + p = multiprocessing.Process(target=run_session_server, args=(worker_args, router_url)) + p.daemon = True + p.start() + tracked_processes.append(p) + backend_processes.append(p) + backend_urls.append(f"http://{ip}:{backend_port}") + + # Wait for every backend to come up before starting the front-end + # so the router never sees connection-refused races on first call. + for p, url in zip(backend_processes, backend_urls, strict=True): + backend_port = int(url.rsplit(":", 1)[1]) + wait_for_server_ready(ip, backend_port, p, timeout=60) + + router_worker_count = max(1, int(getattr(args, "session_router_workers", 1) or 1)) + if router_worker_count == 1: + # Existing path — single router process. Bit-for-bit identical + # to the pre-existing behavior so default deployments don't + # see any change. + router_process = multiprocessing.Process( + target=run_session_router, args=(args, backend_urls), name="session-router" + ) + router_process.daemon = True + router_process.start() + tracked_processes.append(router_process) + wait_for_server_ready(ip, port, router_process, timeout=30) + else: + # Multi-worker: spawn K independent uvicorn workers, all + # SO_REUSEPORT-bound to the same port. The Linux kernel + # hash-distributes incoming connections across them. The + # router state is per-process (rr counter, httpx pool); no + # cross-worker coordination is needed, and routing decisions + # are pure functions of the URL prefix so any worker can + # answer any request correctly. + router_workers: list[multiprocessing.Process] = [] + for i in range(router_worker_count): + # Each worker gets a distinguishable instance_id for log + # correlation; the cluster-facing id stays on `args`. + rargs = _per_worker_args_copy(args) + rargs.session_server_instance_id = f"{args.session_server_instance_id}-router{i}" + p = multiprocessing.Process( + target=run_session_router, + args=(rargs, backend_urls), + name=f"session-router-w{i}", + ) + p.daemon = True + p.start() + router_workers.append(p) + tracked_processes.append(p) + logger.info("spawned session-router worker %d on :%d (pid=%d)", i, port, p.pid) + # One health-check suffices: SO_REUSEPORT means any worker + # can serve the probe. + wait_for_server_ready(ip, port, router_workers[0], timeout=30) + except Exception: + # Make sure we don't leak orphan backend workers if anything + # above raises (e.g. a backend never becomes ready). The parent + # is daemonized so children would otherwise outlive a failed + # start. + for p in tracked_processes: + if p.is_alive(): + p.terminate() + raise + + logger.info( + "Session server launched at %s:%s with %d workers on ports %s-%s", + ip, + port, + worker_count, + port + 1, + port + worker_count, + ) + + +def _per_worker_args_copy(args): + """Return a deep-isolated copy of ``args`` safe for per-worker mutation. + + We mutate a handful of ``session_server_*`` attributes on the copy + before handing it to ``multiprocessing.Process``. The previous + implementation was ``copy.copy(args)`` (a shallow copy), which is + fine for scalar fields but shares references for any nested mutable + (list / dict / Namespace). Any future field that happens to be a + list/dict would have all N worker copies aliasing the same object — + mutating it in one worker (or in this very function, in a loop) + would silently corrupt the others. + + ``copy.deepcopy`` is the safe default here. The args object is + parsed once at startup and is small (< 1 KB worth of strings and + primitives in practice), so the deepcopy cost is negligible + compared to the multiprocessing fork overhead. + """ + return copy.deepcopy(args) + + +def _register_session_server_reaper(processes): + """Make sure session-server child processes die with the parent. + + Three layers: + + * ``daemon=True`` on each child — Python's stdlib terminates + daemonic children automatically when the parent exits. + * ``atexit`` — runs on a normal Python exit (e.g. clean Ray + actor shutdown). + * ``SIGTERM`` handler — covers the case Ray actor preemption + sends SIGTERM and the parent stays alive briefly (the audit's + previous H3 fix removed this entirely, but PR #31's deep + review showed the resulting atexit-only reaper leaks zombies + that hold the session-server port — next rollout then trips + the "stale session server" RuntimeError). + + The SIGTERM handler is intentionally simple: it does NOT chain to + any previous handler (chaining via ``signal.getsignal`` is racy + with Ray and corrupts the captured ``prev`` if this function is + called twice in one process). It just reaps and lets the default + SIGTERM action run via ``signal.SIG_DFL``. + + After ``terminate()`` we ``join(timeout=10)`` so the child is + actually reaped — without this the child becomes a zombie until + the parent itself exits, which is exactly the leak we're trying + to prevent. If it's still alive after the join, escalate to + ``kill()`` so the port is definitely released. + """ + import atexit + import signal + + def _reap(*_): + for p in processes: + try: + if p.is_alive(): + p.terminate() + except Exception: + pass + for p in processes: + try: + p.join(timeout=10) + if p.is_alive(): + p.kill() + p.join(timeout=2) + except Exception: + pass + + atexit.register(_reap) + + def _sigterm_handler(signum, frame): + # Simple delegation: reap, then re-raise the default SIGTERM + # behavior so the parent dies promptly. No chain semantics. + _reap() + signal.signal(signal.SIGTERM, signal.SIG_DFL) + os.kill(os.getpid(), signal.SIGTERM) + + try: + signal.signal(signal.SIGTERM, _sigterm_handler) + except (ValueError, OSError): + # Not in main thread (e.g. running inside a Ray worker thread): + # signal handlers can only be installed from the main thread. + # The atexit + daemon=True fallbacks still cover us in that case. + pass def _log_eval_rollout_data(rollout_id, args, data, extra_metrics: dict[str, Any] | None = None): @@ -1386,7 +1592,7 @@ def _log_rollout_data(rollout_id, args, samples, rollout_extra_metrics, rollout_ # Mirror reward/* and response_stats/* as top-level wandb panels. for full_key, val in list(log_dict.items()): if full_key.startswith(("rollout/reward/", "rollout/response_stats/")): - log_dict[full_key[len("rollout/"):]] = val + log_dict[full_key[len("rollout/") :]] = val logger.info(f"perf {rollout_id}: {log_dict}") step = compute_rollout_step(args, rollout_id) log_dict["rollout/step"] = step @@ -1600,9 +1806,7 @@ def _compute_grouped_response_metrics(args, group: list[Sample], prefix: str) -> } -def _compute_group_outcome_metrics( - args, all_samples: list[Sample], prefix: str = "reward" -) -> dict: +def _compute_group_outcome_metrics(args, all_samples: list[Sample], prefix: str = "reward") -> dict: """Fraction of prompt groups that are unanimously correct or incorrect. GRPO only.""" if args.advantage_estimator == "ppo": return {} diff --git a/miles/rollout/generate_utils/openai_endpoint_utils.py b/miles/rollout/generate_utils/openai_endpoint_utils.py index 6c328719a1..7ba101ac7c 100644 --- a/miles/rollout/generate_utils/openai_endpoint_utils.py +++ b/miles/rollout/generate_utils/openai_endpoint_utils.py @@ -247,5 +247,5 @@ def _truncate_sample_output(sample: Sample, keep_tokens: int, tokenizer) -> None if sample.loss_mask is not None: sample.loss_mask = sample.loss_mask[:keep_tokens] if sample.rollout_routed_experts is not None: - sample.rollout_routed_experts = sample.rollout_routed_experts[:len(sample.tokens) - 1] + sample.rollout_routed_experts = sample.rollout_routed_experts[: len(sample.tokens) - 1] sample.status = Sample.Status.TRUNCATED diff --git a/miles/rollout/session/linear_trajectory.py b/miles/rollout/session/linear_trajectory.py index 31acbeec3c..766652f3fe 100644 --- a/miles/rollout/session/linear_trajectory.py +++ b/miles/rollout/session/linear_trajectory.py @@ -301,16 +301,43 @@ class SessionRegistry: LinearTrajectory; called by the route handler under session.lock. """ - def __init__(self, args, tokenizer: Any, *, tito_tokenizer: TITOTokenizer): + def __init__( + self, + args, + tokenizer: Any, + *, + tito_tokenizer: TITOTokenizer, + worker_index: int = 0, + worker_count: int = 1, + ): self.sessions: dict[str, LinearTrajectory] = {} self._session_last_access: dict[str, float] = {} self.args = args self.tokenizer = tokenizer self.tito_tokenizer = tito_tokenizer self.comparator = tito_tokenizer.create_comparator() + if worker_count < 1: + raise ValueError(f"worker_count must be >= 1, got {worker_count}") + if not 0 <= worker_index < worker_count: + raise ValueError(f"worker_index must be in [0, {worker_count}), got {worker_index}") + self.worker_index = worker_index + self.worker_count = worker_count def create_session(self) -> str: - session_id = uuid.uuid4().hex + """Generate a session_id that routes to this worker. + + Uses Stripe-style prefix encoding: ``w{worker_index}-{uuid4hex}``. + The front-end router parses the ``w-`` prefix to route + subsequent ``/sessions/{id}/...`` calls back to this worker. + + Single-worker deployments (``worker_count == 1``) keep emitting + bare uuid hex for backwards compatibility with existing tests + and operator tooling. + """ + if self.worker_count == 1: + session_id = uuid.uuid4().hex + else: + session_id = f"w{self.worker_index}-{uuid.uuid4().hex}" self.sessions[session_id] = LinearTrajectory() return session_id @@ -340,10 +367,7 @@ def _evict_stale_sessions(self) -> None: if not self._session_last_access: return now = time.monotonic() - stale = [ - sid for sid, ts in self._session_last_access.items() - if now - ts > self._SESSION_TTL_SECS - ] + stale = [sid for sid, ts in self._session_last_access.items() if now - ts > self._SESSION_TTL_SECS] for sid in stale: self.sessions.pop(sid, None) self._session_last_access.pop(sid, None) diff --git a/miles/rollout/session/session_errors.py b/miles/rollout/session/session_errors.py index 30e6784384..af0bdea8b6 100644 --- a/miles/rollout/session/session_errors.py +++ b/miles/rollout/session/session_errors.py @@ -6,7 +6,8 @@ ├── SessionNotFoundError → 404 session does not exist ├── MessageValidationError → 400 messages structure/content invalid ├── TokenizationError → 500 TITO tokenizer / prefix mismatch -└── UpstreamResponseError → 502 SGLang response invalid or unexpected +├── UpstreamResponseError → 502 SGLang response invalid or unexpected +└── SessionStateConflictError → 409 state changed during unlocked proxy phase """ @@ -49,3 +50,24 @@ class UpstreamResponseError(SessionError): """ status_code: int = 502 + + +class SessionStateConflictError(SessionError): + """Raised when session state changed during the unlocked proxy phase. + + The split-lock chat flow releases ``session.lock`` while proxying to + SGLang. If another writer commits an assistant turn in that window, + this writer cannot safely commit its own response: the trajectory's + accumulated_token_ids would no longer line up with the records list, + causing the cursor-mismatch assertion in + ``compute_samples_from_openai_records`` to fire downstream. + + We return 409 so the caller (litellm/harbor) treats this as a + retryable conflict and does NOT incorporate the dropped turn into its + local trajectory. Evidence: run 1711903 (~/run_analysis/1711903/ + 1711903_errors_rca.md) — 24 ``state changed during proxy`` warnings + led to 2 cursor-mismatch failures when the dropped turns were silently + returned as 200. + """ + + status_code: int = 409 diff --git a/miles/rollout/session/session_router.py b/miles/rollout/session/session_router.py new file mode 100644 index 0000000000..6313176eb2 --- /dev/null +++ b/miles/rollout/session/session_router.py @@ -0,0 +1,331 @@ +"""ASGI front-end for the multi-process session server. + +When ``--session-server-workers N`` is set with N > 1, ``_start_session_server`` +spawns N backend SessionServer processes on consecutive ports and runs this +front-end on ``args.session_server_port``. The front-end: + + * Parses ``session_id`` from the URL path (``/sessions/{id}/...``). + * Routes by parsing the ``w-`` prefix stamped onto the id by + ``SessionRegistry.create_session``. Prefix-encoded ids (Stripe-style) + eliminate the hash-agreement risk between router and backend — there + is no shared algorithm to drift on. + * For the stateless ``POST /sessions`` and ``GET /health`` paths, + routes by a round-robin counter (any worker will do; the chosen + worker stamps its own index on the returned id). + * Streams the response body through verbatim (no JSON re-encoding, + no full-body buffering). + +The router does almost no per-request CPU work (path-parse + str.split + +httpx passthrough), so its GIL does not become the new bottleneck — +all the tokenizer / TITO work happens in the backend workers, each in +its own process. +""" + +import itertools +import logging +import re +import socket + +import httpx +import setproctitle +import uvicorn +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse, Response, StreamingResponse + +logger = logging.getLogger(__name__) + +# Matches /sessions/{id}/... and captures {id}. Bare POST /sessions +# (creating a new session, no id yet) is intentionally excluded. +_SESSION_PATH_RE = re.compile(r"^/sessions/([^/]+)(?:/|$)") + +# Matches the ``w-`` prefix that ``SessionRegistry.create_session`` +# stamps onto every multi-worker session id. +_WORKER_PREFIX_RE = re.compile(r"^w(\d+)-") + + +def parse_worker_index(session_id: str, worker_count: int) -> int: + """Parse the ``w-`` prefix and return the worker index. + + Raises ``ValueError`` if the prefix is missing or the parsed index + is out of range for the current worker_count. + """ + m = _WORKER_PREFIX_RE.match(session_id) + if m is None: + raise ValueError(f"session_id {session_id!r} does not have the expected 'w-' prefix") + idx = int(m.group(1)) + if not 0 <= idx < worker_count: + raise ValueError( + f"session_id {session_id!r} parses to worker index {idx}, " f"out of range for worker_count={worker_count}" + ) + return idx + + +class SessionRouter: + """FastAPI app that hash-routes session requests to backend workers.""" + + def __init__(self, args, backend_urls: list[str]): + if not backend_urls: + raise ValueError("SessionRouter requires at least one backend URL") + self.backend_urls = backend_urls + self.worker_count = len(backend_urls) + # Cluster-facing instance id. ``OpenAIEndpointTracer.create`` + # (miles/rollout/generate_utils/openai_endpoint_utils.py) reads + # this from ``/health`` and stamps it on trial metadata; if it + # falls through to None, downstream assertions in the test suite + # (test_sessions.py, test_multi_turn.py) regress. See PR #31 H1. + self.session_server_instance_id = getattr(args, "session_server_instance_id", None) + self.app = FastAPI() + + timeout = getattr(args, "miles_router_timeout", 600.0) + # Connection pool sized for ~N backends * ~hundreds of in-flight + # requests each. Keepalive MUST be enabled here: every request is + # already explicitly routed by ``session_id``, so there's no + # "pinning to one backend" risk (the analogy to + # ``init_http_client``'s keepalive=0 setting does not apply). + # With keepalive=0 every request did a full TCP handshake and + # then sat in TIME_WAIT, leading to ephemeral port exhaustion + # under sustained load (see PR #31 deep review B1). + self.client = httpx.AsyncClient( + limits=httpx.Limits(max_connections=4096, max_keepalive_connections=1024), + timeout=httpx.Timeout(timeout), + ) + self.app.router.on_shutdown.append(self.client.aclose) + + # Round-robin counter for stateless paths (e.g. POST /sessions). + self._rr_counter = itertools.count() + + self._setup_routes() + + def pick_backend(self, path: str) -> str: + """Pick a backend URL for ``path``. + + Stateful paths (``/sessions/{id}/...``) parse the ``w-`` + prefix stamped onto the id by ``SessionRegistry.create_session`` + and route to the owning backend. Stateless paths round-robin so + we don't hot-spot worker 0. + + Rolling-deploy shrink safety net: if a session_id doesn't carry + the ``w-`` prefix or names a worker outside + ``[0, worker_count)`` (e.g. a trial in-flight from a previous + deploy with a larger worker_count, or a legacy bare-uuid id from + an N=1 run), we fall back to round-robin instead of 404. + The chosen backend's ``get_or_create_session`` will reseed the + session under a freshly-minted prefix; the trial loses state but + recovers in-place rather than dying mid-rollout. See PR #31 M. + """ + m = _SESSION_PATH_RE.match(path) + if m is not None: + session_id = m.group(1) + try: + idx = parse_worker_index(session_id, self.worker_count) + except ValueError as exc: + logger.info( + "[session-router] session_id %r doesn't route cleanly (%s); " + "falling back to round-robin for get_or_create reseed", + session_id, + exc, + ) + # next() on itertools.count() is atomic under the CPython GIL. + idx = next(self._rr_counter) % self.worker_count + else: + idx = next(self._rr_counter) % self.worker_count + return self.backend_urls[idx] + + # Hop-by-hop headers per RFC 7230 §6.1 — these must NOT be forwarded + # because they describe the single hop, not the end-to-end message. + # httpx will recompute content-length / transfer-encoding itself. + _HOP_BY_HOP_HEADERS = frozenset( + { + "connection", + "keep-alive", + "proxy-authenticate", + "proxy-authorization", + "te", + "trailers", + "transfer-encoding", + "upgrade", + "content-length", + } + ) + + async def proxy(self, request: Request) -> Response: + path = request.url.path + # pick_backend no longer raises for malformed session_ids — it + # falls back to round-robin so the backend's get_or_create_session + # can reseed. See PR #31 M. + backend = self.pick_backend(path) + url = f"{backend}{path}" + if request.url.query: + url = f"{url}?{request.url.query}" + + # Strip framing / host headers — httpx will recompute them. + # Stream the request body straight through (no buffering in + # router RAM) so multi-MB SGLang request payloads don't OOM + # the router at high concurrency. + req_headers = { + k: v + for k, v in request.headers.items() + if k.lower() not in self._HOP_BY_HOP_HEADERS and k.lower() != "host" + } + + # Use the streaming client.send path so both request and + # response bodies are streamed end-to-end. We must close the + # upstream response when the client disconnects; StreamingResponse + # does that by raising inside the generator. + upstream_req = self.client.build_request( + request.method, + url, + content=request.stream(), + headers=req_headers, + ) + try: + upstream_resp = await self.client.send(upstream_req, stream=True) + except httpx.TransportError as exc: + logger.warning( + "[session-router] backend transport error: %s %s -> %s: %s", + request.method, + path, + backend, + exc, + ) + # Log the full exception above; return a sanitized message to + # the caller so internal backend hostnames / paths can't leak. + return JSONResponse( + status_code=502, + content={"error": "session-router backend transport error"}, + ) + + # Filter hop-by-hop response headers per RFC 7230 §6.1. Also + # drop "server" (cosmetic — was stripped in the old buffered + # path). Keep content-type as-is so charset hints survive. + resp_headers = { + k: v + for k, v in upstream_resp.headers.items() + if k.lower() not in self._HOP_BY_HOP_HEADERS and k.lower() != "server" + } + + async def _body_stream(): + try: + async for chunk in upstream_resp.aiter_raw(): + yield chunk + finally: + await upstream_resp.aclose() + + return StreamingResponse( + _body_stream(), + status_code=upstream_resp.status_code, + headers=resp_headers, + media_type=upstream_resp.headers.get("content-type"), + ) + + def _setup_routes(self) -> None: + @self.app.get("/health") + async def health(): + # Front-end-local health: do NOT proxy. Operators want to + # tell "is the router itself up?" separately from "is any + # backend up?". Per-worker health is reachable via the + # backend ports directly during debugging. + return { + "status": "ok", + "role": "session-router", + "worker_count": self.worker_count, + "backends": self.backend_urls, + # H1: must be present so OpenAIEndpointTracer.create can + # stamp ``session_server_instance_id`` on trial metadata + # in multi-worker mode (the router is the user-facing + # ``session_url`` then). + "session_server_instance_id": self.session_server_instance_id, + } + + @self.app.api_route( + "/{path:path}", + methods=["GET", "POST", "PUT", "DELETE", "PATCH"], + ) + async def catchall(request: Request, path: str): + return await self.proxy(request) + + +def _make_reuseport_socket(host: str, port: int) -> socket.socket: + """Open a TCP listener bound with SO_REUSEPORT for kernel-level load balancing. + + Multiple processes can bind the same (host, port) when each one sets + SO_REUSEPORT before bind(). The Linux kernel (>= 3.9) then hash-distributes + incoming connections across the listener sockets, giving us K-way + parallelism without a userspace dispatcher. + + Requires Linux. macOS has SO_REUSEPORT but with different semantics + (last-bound wins for unicast); the multi-worker path is intended for + Slurm/Linux production. The single-worker (K=1) path is unchanged. + """ + family = socket.AF_INET6 if host and ":" in host else socket.AF_INET + sock = socket.socket(family=family, type=socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEPORT, 1) + sock.bind((host, port)) + # Match uvicorn's own default backlog. + sock.listen(2048) + sock.set_inheritable(True) + return sock + + +def run_session_router(args, backend_urls: list[str]): + """Entry point for the front-end process started by _start_session_server. + + Honors ``args.session_router_workers`` (default 1): + + * K=1: identical to pre-existing behavior (``uvicorn.run`` binds the + listening socket internally). + * K>1: the caller has already forked K worker processes; this one + opens a SO_REUSEPORT socket and hands it to + ``uvicorn.Server.run(sockets=[sock])``. The Linux kernel + distributes connections across the K workers. + + NOTE on the uvicorn API: as of uvicorn 0.40--0.49 (the versions we + currently ship), ``uvicorn.Config`` does **not** have a ``reuse_port`` + kwarg / attribute. The supported path for pre-bound listeners is + ``Server.run(sockets=[...])``. So we open the SO_REUSEPORT socket + ourselves in each worker and pass it in. See PR description for the + bench evidence motivating this change. + """ + setproctitle.setproctitle("miles-session-router") + router = SessionRouter(args, backend_urls) + router_worker_count = max(1, int(getattr(args, "session_router_workers", 1) or 1)) + if router_worker_count <= 1: + # Preserve exact pre-existing behavior — bit-for-bit identical to + # the previous implementation, log line included. + logger.info( + "[session-router] Starting on %s:%s, routing to %d backends: %s", + args.session_server_ip, + args.session_server_port, + len(backend_urls), + backend_urls, + ) + uvicorn.run( + router.app, + host=args.session_server_ip, + port=args.session_server_port, + log_level="info", + ) + return + + logger.info( + "[session-router] Starting on %s:%s, routing to %d backends: %s (router_workers=%d, SO_REUSEPORT)", + args.session_server_ip, + args.session_server_port, + len(backend_urls), + backend_urls, + router_worker_count, + ) + + # K>1: open our own SO_REUSEPORT socket and hand it to uvicorn.Server. + sock = _make_reuseport_socket(args.session_server_ip, args.session_server_port) + config = uvicorn.Config( + router.app, + host=args.session_server_ip, + port=args.session_server_port, + log_level="info", + loop="asyncio", + access_log=False, + ) + server = uvicorn.Server(config) + server.run(sockets=[sock]) diff --git a/miles/rollout/session/sessions.py b/miles/rollout/session/sessions.py index 0c285bf8bb..ed8840785d 100644 --- a/miles/rollout/session/sessions.py +++ b/miles/rollout/session/sessions.py @@ -1,3 +1,4 @@ +import asyncio import json import logging import time @@ -10,6 +11,7 @@ from miles.rollout.session.session_errors import ( SessionError, SessionNotFoundError, + SessionStateConflictError, TokenizationError, UpstreamResponseError, ) @@ -27,6 +29,8 @@ def setup_session_routes(app, backend, args): return session_server_instance_id = getattr(args, "session_server_instance_id", None) + worker_index = getattr(args, "session_server_worker_index", 0) + worker_count = getattr(args, "session_server_worker_count", 1) tokenizer = load_tokenizer( hf_checkpoint, chat_template_path=getattr(args, "chat_template_path", None), trust_remote_code=True @@ -38,13 +42,24 @@ def setup_session_routes(app, backend, args): allowed_append_roles=getattr(args, "tito_allowed_append_roles", None), ) - registry = SessionRegistry(args, tokenizer, tito_tokenizer=tito_tokenizer) + registry = SessionRegistry( + args, + tokenizer, + tito_tokenizer=tito_tokenizer, + worker_index=worker_index, + worker_count=worker_count, + ) @app.get("/health") async def health(): body = {"status": "ok"} if session_server_instance_id is not None: body["session_server_instance_id"] = session_server_instance_id + # Surface worker identity so operators can correlate logs across + # multi-process deployments. Always present (defaults 0/1) so + # parsers can rely on the schema. + body["worker_index"] = worker_index + body["worker_count"] = worker_count return body # --- DEBUG: track in-flight chat_completions --- @@ -155,7 +170,12 @@ async def chat_completions(request: Request, session_id: str): request_body["no_stop_trim"] = False request_messages = request_body.get("messages", []) - pretokenized = session.prepare_pretokenized( + # Run the sync tito-tokenizer call in a thread so the event + # loop isn't blocked while merge_tokens / chat-template render + # holds the GIL. At 300+ in-flight sessions this shaved ~40% + # off server p99 in microbench. + pretokenized = await asyncio.to_thread( + session.prepare_pretokenized, request_messages, tools=request_body.get("tools"), tito_tokenizer=registry.tito_tokenizer, @@ -239,14 +259,33 @@ async def chat_completions(request: Request, session_id: str): return backend.build_proxy_response(result) if session.num_assistant != expected_num_assistant: + # Another writer committed an assistant turn while we were + # in Phase 2 (unlocked proxy). We cannot commit this + # response: doing so would either (a) corrupt the + # trajectory's accumulated_token_ids prefix invariant, or + # (b) drop the state update and silently return a 200, + # causing the cursor-mismatch assertion in + # compute_samples_from_openai_records to fire downstream. + # Return 409 so the caller treats this turn as a + # retryable conflict and does not record it locally. + # See run 1711903 evidence in + # ~/run_analysis/1711903/1711903_errors_rca.md. logger.warning( f"Session {session_id} state changed during proxy " f"(expected num_assistant={expected_num_assistant}, " - f"got {session.num_assistant}), skipping state update" + f"got {session.num_assistant}), returning 409" + ) + raise SessionStateConflictError( + f"session {session_id} state changed during proxy " + f"(expected num_assistant={expected_num_assistant}, " + f"got {session.num_assistant})" ) - return backend.build_proxy_response(result) - session.update_pretokenized_state( + # Same rationale as the prepare_pretokenized call above — + # offload the sync merge_tokens / state update to a thread + # so concurrent in-flight sessions keep moving. + await asyncio.to_thread( + session.update_pretokenized_state, request_messages, assistant_message, prompt_token_ids=prompt_token_ids, diff --git a/miles/utils/arguments.py b/miles/utils/arguments.py index 1d1ada930a..f65b7cc94b 100644 --- a/miles/utils/arguments.py +++ b/miles/utils/arguments.py @@ -1632,6 +1632,30 @@ def add_session_arguments(parser): default=None, help="Port of the standalone session server. Auto-allocated if not set.", ) + parser.add_argument( + "--session-server-workers", + type=int, + default=1, + help="Number of session-server worker processes. When >1, a " + "lightweight ASGI front-end binds to --session-server-port and " + "consistent-hash-routes requests by session_id across N backend " + "workers, each on its own port. Each worker loads its own " + "tokenizer (N x memory). Default 1 preserves current single-" + "process behavior.", + ) + parser.add_argument( + "--session-router-workers", + type=int, + default=1, + help="Number of uvicorn worker processes for the session-router " + "ASGI front-end. When >1, K independent uvicorn workers bind to " + "--session-server-port via SO_REUSEPORT and the Linux kernel " + "hash-distributes incoming connections across them. The router " + "is stateless (per-process round-robin counter for stateless " + "paths, otherwise pure-function URL-prefix routing), so adding " + "workers is safe. Default 1 preserves the existing single-" + "process behavior. Recommended in prod: ~= backends / 2.", + ) parser.add_argument( "--tito-model", type=str, diff --git a/miles/utils/http_utils.py b/miles/utils/http_utils.py index 0aaf792659..3681152163 100644 --- a/miles/utils/http_utils.py +++ b/miles/utils/http_utils.py @@ -228,8 +228,18 @@ def init_http_client(args): _client_concurrency = args.sglang_server_concurrency * args.rollout_num_gpus // args.rollout_num_gpus_per_engine if _http_client is None: + # max_keepalive_connections=0: defeat connection reuse so each /run + # opens a fresh TCP. A pooled httpx.AsyncClient against a uvicorn + # multi-worker server pins all traffic to the few workers that + # originally accept()-won the pooled connections (uvicorn shares a + # single listen socket across --workers; no SO_REUSEPORT, no + # work-stealing). With keepalive off, every request runs its own + # accept() race and spreads across all workers. _http_client = httpx.AsyncClient( - limits=httpx.Limits(max_connections=_client_concurrency), + limits=httpx.Limits( + max_connections=_client_concurrency, + max_keepalive_connections=0, + ), timeout=httpx.Timeout(None), ) diff --git a/miles/utils/replay_base.py b/miles/utils/replay_base.py index 8e19b1ba67..e4a003a730 100644 --- a/miles/utils/replay_base.py +++ b/miles/utils/replay_base.py @@ -123,9 +123,7 @@ def _get_replay_result(top_indices, scores, topk, *args, **kwargs): _, sorted_free = masked_scores.sort(dim=1, descending=True) # The k-th -1 slot in each row gets sorted_free[row, k]. pad_cumsum = torch.cumsum(padding_mask.long(), dim=1) - 1 - fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to( - top_indices.dtype - ) + fill_values = torch.gather(sorted_free, 1, pad_cumsum.clamp(min=0)).to(top_indices.dtype) top_indices = torch.where(padding_mask, fill_values, top_indices) if return_probs: diff --git a/tests/fast/router/test_multi_worker_startup.py b/tests/fast/router/test_multi_worker_startup.py new file mode 100644 index 0000000000..941c21b8e0 --- /dev/null +++ b/tests/fast/router/test_multi_worker_startup.py @@ -0,0 +1,211 @@ +"""Smoke test for multi-process session-server startup. + +Does NOT load real tokenizers (would download HF weights). Instead patches +``run_session_server`` with a tiny fake uvicorn app and verifies that: + + * N worker processes are spawned on distinct ports. + * The front-end ASGI router comes up on ``session_server_port``. + * The router consistent-hash routes ``/sessions/{id}/...`` to the + correct backend. + +Importing this file is heavy (multiprocessing.Process), so it is grouped +with the rest of router fast tests but skips in CI environments where +binding ports is unreliable. +""" + +import socket +import threading +import time +import uuid +from http.server import BaseHTTPRequestHandler, HTTPServer +from types import SimpleNamespace + +import pytest +import requests + +from miles.utils.http_utils import find_available_port + + +class _BackendHandler(BaseHTTPRequestHandler): + """Echoes back the port it's serving on, so the test can verify routing.""" + + def log_message(self, format, *args): # noqa: A002 - stdlib API + pass # silence stderr in CI + + def do_GET(self): + port = self.server.server_address[1] + body = f'{{"port": {port}, "path": "{self.path}"}}'.encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) + + do_POST = do_GET + do_DELETE = do_GET + + +def _spawn_fake_backend(port: int) -> HTTPServer: + server = HTTPServer(("127.0.0.1", port), _BackendHandler) + threading.Thread(target=server.serve_forever, daemon=True).start() + return server + + +def _wait_port(port: int, timeout: float = 5.0) -> None: + deadline = time.time() + timeout + while time.time() < deadline: + try: + with socket.create_connection(("127.0.0.1", port), timeout=0.5): + return + except OSError: + time.sleep(0.05) + raise RuntimeError(f"port {port} never came up") + + +@pytest.fixture +def fake_backends(): + """Spin up 4 fake HTTP backends on consecutive ports.""" + base = find_available_port(40000) + ports = [] + servers = [] + # Find 4 free consecutive-ish ports (best-effort; doesn't have to be + # contiguous because we control the URL list passed to the router). + p = base + while len(ports) < 4: + try: + servers.append(_spawn_fake_backend(p)) + ports.append(p) + except OSError: + pass + p += 1 + for port in ports: + _wait_port(port) + yield ports + for s in servers: + s.shutdown() + + +def test_session_router_routes_by_session_id(fake_backends): + """End-to-end: SessionRouter started in-process, requests reach the right backend.""" + from miles.rollout.session.session_router import SessionRouter + from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + backend_urls = [f"http://127.0.0.1:{p}" for p in fake_backends] + args = SimpleNamespace(miles_router_timeout=10.0) + router = SessionRouter(args, backend_urls) + + router_port = find_available_port(41000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=router_port) + server.start() + try: + _wait_port(router_port) + worker_count = len(fake_backends) + + # Hit the router with N distinct session_ids carrying the + # ``w-`` prefix; verify each one reached the backend whose + # port matches its parsed worker index. + for _ in range(20): + for worker_index in range(worker_count): + sid = f"w{worker_index}-{uuid.uuid4().hex}" + expected_port = fake_backends[worker_index] + resp = requests.get( + f"http://127.0.0.1:{router_port}/sessions/{sid}/v1/chat/completions", + timeout=5, + ) + assert resp.status_code == 200, resp.text + assert ( + resp.json()["port"] == expected_port + ), f"sid={sid} expected port {expected_port}, got {resp.json()}" + finally: + server.stop() + + +def test_session_router_health_no_proxy(fake_backends): + """/health on the router returns its own status, does not hit any backend.""" + from miles.rollout.session.session_router import SessionRouter + from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + backend_urls = [f"http://127.0.0.1:{p}" for p in fake_backends] + args = SimpleNamespace(miles_router_timeout=10.0) + router = SessionRouter(args, backend_urls) + + router_port = find_available_port(42000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=router_port) + server.start() + try: + _wait_port(router_port) + resp = requests.get(f"http://127.0.0.1:{router_port}/health", timeout=5) + assert resp.status_code == 200 + body = resp.json() + assert body["status"] == "ok" + assert body["role"] == "session-router" + assert body["worker_count"] == len(fake_backends) + assert body["backends"] == backend_urls + finally: + server.stop() + + +def test_session_router_malformed_id_routes_via_round_robin(fake_backends): + """Malformed session_ids reach a backend rather than 404 (PR #31 M). + + Rolling-deploy shrink: an in-flight trial may hold a ``w-uuid`` + minted under a wider fleet. Instead of dying with 404, the router + routes it to a backend; the backend's ``get_or_create_session`` + reseeds the session cleanly under a fresh prefix. + """ + from miles.rollout.session.session_router import SessionRouter + from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + backend_urls = [f"http://127.0.0.1:{p}" for p in fake_backends] + args = SimpleNamespace(miles_router_timeout=10.0) + router = SessionRouter(args, backend_urls) + + router_port = find_available_port(44000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=router_port) + server.start() + try: + _wait_port(router_port) + # No prefix at all. + resp = requests.get( + f"http://127.0.0.1:{router_port}/sessions/legacy-bare-uuid/v1/chat/completions", + timeout=5, + ) + assert resp.status_code == 200, resp.text + assert resp.json()["port"] in fake_backends + # Out-of-range index (id minted under wider fleet). + resp = requests.get( + f"http://127.0.0.1:{router_port}/sessions/w99-{uuid.uuid4().hex}/v1/chat/completions", + timeout=5, + ) + assert resp.status_code == 200, resp.text + assert resp.json()["port"] in fake_backends + finally: + server.stop() + + +def test_session_router_health_exposes_instance_id(fake_backends): + """``/health`` includes session_server_instance_id (PR #31 H1). + + OpenAIEndpointTracer.create reads this field from /health and + stamps it on trial metadata; in multi-worker mode the router is + the user-facing session_url, so the field MUST be present here. + """ + from miles.rollout.session.session_router import SessionRouter + from miles.utils.test_utils.uvicorn_thread_server import UvicornThreadServer + + backend_urls = [f"http://127.0.0.1:{p}" for p in fake_backends] + instance_id = "test-instance-" + uuid.uuid4().hex[:8] + args = SimpleNamespace(miles_router_timeout=10.0, session_server_instance_id=instance_id) + router = SessionRouter(args, backend_urls) + + router_port = find_available_port(43000) + server = UvicornThreadServer(router.app, host="127.0.0.1", port=router_port) + server.start() + try: + _wait_port(router_port) + resp = requests.get(f"http://127.0.0.1:{router_port}/health", timeout=5) + assert resp.status_code == 200 + body = resp.json() + assert body["session_server_instance_id"] == instance_id + finally: + server.stop() diff --git a/tests/fast/router/test_router_multi_worker.py b/tests/fast/router/test_router_multi_worker.py new file mode 100644 index 0000000000..81b33d872e --- /dev/null +++ b/tests/fast/router/test_router_multi_worker.py @@ -0,0 +1,145 @@ +"""Unit tests for the multi-worker SO_REUSEPORT path in ``run_session_router``. + +The router is stateless (per-process round-robin counter, pure-function +URL-prefix routing), so running K parallel uvicorn workers behind the +same port is safe. These tests pin down the launch-time contract: + + * K=1 (default): existing single-process behavior is preserved exactly + (``uvicorn.run`` is called with the same kwargs as before; no + SO_REUSEPORT socket is opened). + * K>1: a SO_REUSEPORT socket is opened and handed to + ``uvicorn.Server.run(sockets=[sock])``. + * Per-router-worker args copy carries a distinguishable + ``session_server_instance_id`` so log lines from different worker + PIDs can be told apart (``-router{i}`` suffix applied in + ``_start_session_server``). + +These tests do NOT spawn real processes or bind real ports — they patch +the uvicorn entrypoints and assert on the call arguments. That makes +them fast and safe to run on macOS CI (SO_REUSEPORT semantics differ +there; the multi-worker path is Linux-only at runtime). +""" + +from types import SimpleNamespace +from unittest.mock import patch + +import pytest + + +# --------------------------------------------------------------------------- +# run_session_router uvicorn invocation +# --------------------------------------------------------------------------- + + +@pytest.fixture +def fake_router_args(): + """Minimal args namespace accepted by SessionRouter + run_session_router.""" + return SimpleNamespace( + session_server_ip="127.0.0.1", + session_server_port=0, # we never actually bind in these tests + session_server_instance_id="test-instance", + miles_router_timeout=10.0, + session_router_workers=1, + ) + + +def test_run_session_router_single_worker_calls_uvicorn_run(fake_router_args): + """K=1 path: must call ``uvicorn.run`` with the legacy kwargs, no + sockets, no SO_REUSEPORT. + + This is the backward-compat guarantee — production deployments that + don't opt in must see literally zero behavior change. + """ + from miles.rollout.session import session_router as sr + + backend_urls = ["http://127.0.0.1:6001", "http://127.0.0.1:6002"] + with patch.object(sr, "uvicorn") as mock_uvicorn, patch.object(sr, "_make_reuseport_socket") as mock_sock: + sr.run_session_router(fake_router_args, backend_urls) + # The legacy code-path is uvicorn.run(...). Server / Config must + # not be touched, and crucially no SO_REUSEPORT socket is opened. + mock_uvicorn.run.assert_called_once() + _args, kwargs = mock_uvicorn.run.call_args + assert kwargs["host"] == "127.0.0.1" + assert kwargs["port"] == 0 + assert kwargs["log_level"] == "info" + mock_uvicorn.Config.assert_not_called() + mock_uvicorn.Server.assert_not_called() + mock_sock.assert_not_called() + + +def test_run_session_router_multi_worker_uses_reuseport_socket(fake_router_args): + """K>1 path: must open a SO_REUSEPORT socket and hand it to + ``uvicorn.Server.run(sockets=[sock])``. ``uvicorn.run`` must NOT + be called (it would bind its own socket and conflict). + """ + from miles.rollout.session import session_router as sr + + fake_router_args.session_router_workers = 4 + backend_urls = ["http://127.0.0.1:6001", "http://127.0.0.1:6002"] + fake_sock = object() + with patch.object(sr, "uvicorn") as mock_uvicorn, patch.object( + sr, "_make_reuseport_socket", return_value=fake_sock + ) as mock_sock: + sr.run_session_router(fake_router_args, backend_urls) + # Legacy uvicorn.run MUST NOT be used in the multi-worker path. + mock_uvicorn.run.assert_not_called() + # SO_REUSEPORT socket must be opened at the configured host:port. + mock_sock.assert_called_once_with("127.0.0.1", 0) + # Server.run must receive the pre-bound socket. + mock_uvicorn.Config.assert_called_once() + cfg_args, cfg_kwargs = mock_uvicorn.Config.call_args + assert cfg_kwargs["host"] == "127.0.0.1" + assert cfg_kwargs["port"] == 0 + assert cfg_kwargs["log_level"] == "info" + # Server() instantiated with the Config; .run(sockets=[sock]) called. + mock_uvicorn.Server.assert_called_once() + server_instance = mock_uvicorn.Server.return_value + server_instance.run.assert_called_once_with(sockets=[fake_sock]) + + +def test_run_session_router_default_treats_missing_attr_as_one(fake_router_args): + """A pre-existing args object built before this PR landed will not have + ``session_router_workers``. Treat that the same as K=1 (legacy path) + so a partial deploy doesn't silently flip behavior. + """ + from miles.rollout.session import session_router as sr + + del fake_router_args.session_router_workers + backend_urls = ["http://127.0.0.1:6001"] + with patch.object(sr, "uvicorn") as mock_uvicorn, patch.object(sr, "_make_reuseport_socket") as mock_sock: + sr.run_session_router(fake_router_args, backend_urls) + mock_uvicorn.run.assert_called_once() + mock_sock.assert_not_called() + + +# --------------------------------------------------------------------------- +# per-worker args copy: -router{i} suffix on session_server_instance_id +# --------------------------------------------------------------------------- + + +def test_per_worker_args_copy_isolated_from_caller(): + """``_per_worker_args_copy`` must return an independent copy so we can + safely overwrite ``session_server_instance_id`` per worker without + aliasing the caller's args. + + The router-worker spawn loop in ``_start_session_server`` stamps + ``-router{i}`` onto the copy; this test pins down that the stamp on + one copy does not bleed into the next copy or the original. + """ + from miles.ray.rollout import _per_worker_args_copy + + args = SimpleNamespace(session_server_instance_id="base-id", some_list=[1, 2, 3]) + c0 = _per_worker_args_copy(args) + c1 = _per_worker_args_copy(args) + c0.session_server_instance_id = f"{args.session_server_instance_id}-router0" + c1.session_server_instance_id = f"{args.session_server_instance_id}-router1" + assert c0.session_server_instance_id == "base-id-router0" + assert c1.session_server_instance_id == "base-id-router1" + # Caller's args is untouched. + assert args.session_server_instance_id == "base-id" + # And the nested mutable is deep-copied — mutating c0 doesn't leak + # into c1 or the original (this is the invariant `_per_worker_args_copy` + # exists to guarantee). + c0.some_list.append(99) + assert args.some_list == [1, 2, 3] + assert c1.some_list == [1, 2, 3] diff --git a/tests/fast/router/test_session_race_conditions.py b/tests/fast/router/test_session_race_conditions.py index 95c6a69cba..3ba2ef862f 100644 --- a/tests/fast/router/test_session_race_conditions.py +++ b/tests/fast/router/test_session_race_conditions.py @@ -100,7 +100,8 @@ def test_same_session_concurrent_requests_reach_backend(self): Phase 2 (proxy) runs without the lock, so concurrent requests are not serialized at the backend level. Phase 3 state updates are still serialized; the stale-update guard ensures only one writer wins per - generation, so no state corruption occurs. + generation — losers receive 409 SessionStateConflictError so the + caller doesn't record a phantom turn (see run 1711903 evidence). """ def process_fn(prompt: str) -> ProcessResult: @@ -128,12 +129,16 @@ def process_fn(prompt: str) -> ProcessResult: futures = [pool.submit(_chat, env.url, session_id, retry_payload) for _ in range(4)] responses = [f.result(timeout=30.0) for f in futures] - # All requests should succeed (200) — no 500s. - assert all(resp.status_code == 200 for resp in responses) + # All requests reach the backend (split-lock allows concurrency). assert len(env.backend.request_log) == 4 - # With split-lock, concurrent backend access is expected (not == 1). assert env.backend.max_concurrent >= 1 + # Exactly one writer wins (200); losers get 409 conflict — no 500s, + # and no silent 200 that would drop the state update. + status_codes = sorted(resp.status_code for resp in responses) + assert all(c in (200, 409) for c in status_codes), f"Unexpected codes: {status_codes}" + assert status_codes.count(200) == 1, f"Expected exactly 1 winner, got {status_codes}" + def test_different_sessions_can_run_in_parallel(self): def process_fn(prompt: str) -> ProcessResult: return ProcessResult(text="parallel-ok", finish_reason="stop") @@ -420,3 +425,122 @@ def lifecycle_cycle(idx: int) -> bool: results = [f.result(timeout=60.0) for f in futures] assert all(results) + + +class TestStateConflictNoCursorMismatch: + """Regression: stale-update guard must not silently drop state. + + Before this fix, when ``session.num_assistant`` changed during the + unlocked proxy phase, the chat handler returned the SGLang response + body with HTTP 200 but skipped both ``update_pretokenized_state`` and + ``append_record``. The caller treated the 200 as a real turn and + appended the assistant message to its local trajectory, so on a later + ``compute_samples_from_openai_records`` the assertion + ``cursor == len(accumulated_token_ids)`` fired with a delta equal to + the dropped turn's token count. + + Evidence: run 1711903 — 24 ``state changed during proxy`` warnings + produced 2 cursor-mismatch failures with deltas of 88 and 102 tokens. + See ``~/run_analysis/1711903/1711903_errors_rca.md``. + + The fix returns 409 so the caller does NOT record the dropped turn. + """ + + def test_state_conflict_returns_409_and_session_records_match_token_ids(self): + """Concurrent same-session writers: exactly one wins; session state + and records remain mutually consistent (cursor invariant holds). + """ + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="conflict-test", finish_reason="stop") + + with _router_env(process_fn, latency=0.2) as env: + session_id = _create_session(env.url) + + # Warm up an assistant checkpoint so retry payloads are valid. + warmup_payload = {"messages": [{"role": "user", "content": "warmup"}]} + warmup_resp = _chat(env.url, session_id, warmup_payload) + assert warmup_resp.status_code == 200 + assistant = warmup_resp.json()["choices"][0]["message"] + + retry_payload = { + "messages": [ + {"role": "user", "content": "warmup"}, + assistant, + {"role": "system", "content": "retry-conflict"}, + ] + } + + with ThreadPoolExecutor(max_workers=4) as pool: + futures = [pool.submit(_chat, env.url, session_id, retry_payload) for _ in range(4)] + responses = [f.result(timeout=30.0) for f in futures] + + status_codes = sorted(r.status_code for r in responses) + # Exactly one writer commits; the other three see the bumped + # num_assistant and get 409. + assert status_codes.count(200) == 1, f"expected 1 winner, got {status_codes}" + assert status_codes.count(409) == 3, f"expected 3 conflicts, got {status_codes}" + + # 409 body carries an "error" field naming the conflict so + # callers can surface a useful retry message. + for resp in responses: + if resp.status_code == 409: + body = resp.json() + assert "error" in body + assert "state changed during proxy" in body["error"] + + # Session state must remain self-consistent: records list (one + # per committed turn) should align with the number of assistant + # checkpoints in trajectory_token_ids, so cursor walking in + # compute_samples_from_openai_records will succeed. + get_resp = requests.get(f"{env.url}/sessions/{session_id}", timeout=5.0) + assert get_resp.status_code == 200 + body = get_resp.json() + # warmup (1) + exactly one conflict-winner (1) = 2 records. + assert len(body["records"]) == 2, f"expected 2 records, got {len(body['records'])}" + # accumulated_token_ids reflects the latest checkpoint; non-empty. + assert body["metadata"]["accumulated_token_ids"], "accumulated_token_ids must be non-empty" + + def test_serial_followup_after_conflict_uses_winner_state(self): + """A serial turn after a conflict-resolved burst must succeed and + build on the winner's checkpoint — no stale-state crash.""" + + def process_fn(prompt: str) -> ProcessResult: + return ProcessResult(text="serial-followup", finish_reason="stop") + + with _router_env(process_fn, latency=0.15) as env: + session_id = _create_session(env.url) + warmup_resp = _chat(env.url, session_id, {"messages": [{"role": "user", "content": "warm"}]}) + assert warmup_resp.status_code == 200 + assistant1 = warmup_resp.json()["choices"][0]["message"] + + retry_payload = { + "messages": [ + {"role": "user", "content": "warm"}, + assistant1, + {"role": "system", "content": "retry-burst"}, + ] + } + with ThreadPoolExecutor(max_workers=3) as pool: + futures = [pool.submit(_chat, env.url, session_id, retry_payload) for _ in range(3)] + burst = [f.result(timeout=30.0) for f in futures] + + winners = [r for r in burst if r.status_code == 200] + assert len(winners) == 1 + assistant2 = winners[0].json()["choices"][0]["message"] + + # Build the next message list on top of the winner. + followup_payload = { + "messages": [ + {"role": "user", "content": "warm"}, + assistant1, + {"role": "system", "content": "retry-burst"}, + assistant2, + {"role": "system", "content": "after-conflict"}, + ] + } + followup_resp = _chat(env.url, session_id, followup_payload) + assert followup_resp.status_code == 200, ( + f"serial followup after conflict-resolved burst failed: {followup_resp.status_code} " + f"{followup_resp.text}" + ) diff --git a/tests/fast/router/test_session_uuid_routing.py b/tests/fast/router/test_session_uuid_routing.py new file mode 100644 index 0000000000..f8472a9089 --- /dev/null +++ b/tests/fast/router/test_session_uuid_routing.py @@ -0,0 +1,188 @@ +"""Unit tests for multi-process session routing. + +These verify the load-bearing invariant of the multi-process session-server +design: every session_id that ``SessionRegistry.create_session`` returns +parses — via the prefix-encoding contract — back to the worker that +created it. If this ever breaks, sticky routing breaks, and the +session-server falls back to the auto-reseed path on every turn (silently +losing state). +""" + +from types import SimpleNamespace +from typing import Any + +import pytest + +from miles.rollout.session.linear_trajectory import SessionRegistry +from miles.rollout.session.session_router import parse_worker_index +from miles.utils.chat_template_utils.tito_tokenizer import TITOTokenizer + + +class _MockTITOTokenizer(TITOTokenizer): + """Stub: no real tokenizer work needed for routing tests.""" + + def create_comparator(self): + return None + + def tokenize_additional_non_assistant( + self, + old_messages: list[dict[str, Any]], + new_messages: list[dict[str, Any]], + tools: list[dict[str, Any]] | None = None, + ) -> list[int]: + return [] + + def merge_tokens( + self, + old_messages: list[dict[str, Any]], + new_messages: list[dict[str, Any]], + pretokenized_token_ids: list[int], + tools: list[dict[str, Any]] | None = None, + ) -> list[int]: + return list(pretokenized_token_ids) + + +def _make_registry(worker_index: int, worker_count: int) -> SessionRegistry: + args = SimpleNamespace() + mock_tito = _MockTITOTokenizer( + tokenizer=None, + assistant_start_str="<|im_start|>assistant", + allowed_append_roles=None, + ) + return SessionRegistry( + args, + tokenizer=None, + tito_tokenizer=mock_tito, + worker_index=worker_index, + worker_count=worker_count, + ) + + +class TestParseWorkerIndex: + def test_parses_well_formed_prefix(self): + assert parse_worker_index("w0-abcdef", 4) == 0 + assert parse_worker_index("w3-abcdef", 4) == 3 + assert parse_worker_index("w7-deadbeef", 8) == 7 + + def test_rejects_missing_prefix(self): + # Bare uuid hex (the single-worker shape) has no w- prefix. + with pytest.raises(ValueError): + parse_worker_index("a" * 32, 4) + + def test_rejects_non_numeric_prefix(self): + with pytest.raises(ValueError): + parse_worker_index("wx-abcdef", 4) + + def test_rejects_out_of_range_index(self): + # Worker minted with worker_count=8 but router is now running + # with worker_count=4 — the parsed index is out of range. + with pytest.raises(ValueError): + parse_worker_index("w5-abcdef", 4) + + def test_rejects_negative_index_via_no_match(self): + # Regex `^w(\d+)-` doesn't accept a minus sign, so this is a + # missing-prefix failure rather than an out-of-range failure. + with pytest.raises(ValueError): + parse_worker_index("w-1-abcdef", 4) + + +class TestSessionRegistryRouting: + @pytest.mark.parametrize("worker_count", [1, 2, 4, 8, 16]) + def test_create_session_id_parses_to_self(self, worker_count: int): + """Every session_id created by worker i must parse back to index i.""" + for worker_index in range(worker_count): + registry = _make_registry(worker_index, worker_count) + for _ in range(50): + sid = registry.create_session() + if worker_count == 1: + # Single-worker deployments keep emitting bare uuid hex + # for back-compat; routing is trivially worker 0. + assert len(sid) == 32 + else: + assert parse_worker_index(sid, worker_count) == worker_index, ( + f"session_id {sid} from worker {worker_index}/{worker_count} " + f"did not parse to its own index" + ) + + def test_default_single_worker_behavior(self): + """worker_count=1 is the existing behavior; bare uuid hex.""" + registry = _make_registry(0, 1) + sid = registry.create_session() + assert len(sid) == 32 # uuid4 hex, no prefix + assert sid in registry.sessions + + def test_multi_worker_ids_carry_prefix(self): + """worker_count>1 ids must carry the Stripe-style w- prefix.""" + registry = _make_registry(worker_index=2, worker_count=8) + sid = registry.create_session() + assert sid.startswith("w2-") + assert sid in registry.sessions + + def test_invalid_worker_index(self): + with pytest.raises(ValueError): + _make_registry(worker_index=5, worker_count=4) + with pytest.raises(ValueError): + _make_registry(worker_index=-1, worker_count=4) + + def test_invalid_worker_count(self): + with pytest.raises(ValueError): + _make_registry(worker_index=0, worker_count=0) + + +class TestRouterAgreement: + """The front-end router and SessionRegistry must agree on the routing + contract — now a prefix parse, not a hash. With prefix encoding there + is no "shared algorithm" to drift on, but the contract still has to + hold end-to-end. + """ + + @pytest.mark.parametrize("worker_count", [2, 3, 4, 7, 8]) + def test_session_router_pick_matches_creator(self, worker_count: int): + from miles.rollout.session.session_router import SessionRouter + + args = SimpleNamespace(miles_router_timeout=1.0) + backends = [f"http://127.0.0.1:{6000 + i}" for i in range(worker_count)] + router = SessionRouter(args, backends) + + for worker_index in range(worker_count): + registry = _make_registry(worker_index, worker_count) + for _ in range(20): + sid = registry.create_session() + # The router's pick_backend on a stateful path must + # match the URL of the worker that created the session. + picked = router.pick_backend(f"/sessions/{sid}/v1/chat/completions") + assert picked == backends[worker_index], ( + f"session_id={sid} created by worker {worker_index} " f"but router routed to {picked}" + ) + + def test_router_unknown_session_id_falls_back_to_round_robin(self): + """Malformed/out-of-range session_ids round-robin to a backend. + + Rolling-deploy safety net: rather than 404, the router falls + back to round-robin so the backend's ``get_or_create_session`` + can reseed. See PR #31 finding M. + """ + from miles.rollout.session.session_router import SessionRouter + + args = SimpleNamespace(miles_router_timeout=1.0) + backends = [f"http://127.0.0.1:{6000 + i}" for i in range(4)] + router = SessionRouter(args, backends) + + # No w- prefix -> round-robin, never raises. + picks_no_prefix = [router.pick_backend("/sessions/badid_no_prefix/v1/chat/completions") for _ in range(40)] + assert set(picks_no_prefix) == set(backends) + + # Out-of-range worker index (id minted under wider fleet) -> round-robin. + picks_oor = [router.pick_backend("/sessions/w9-deadbeef/v1/chat/completions") for _ in range(40)] + assert set(picks_oor) == set(backends) + + def test_router_stateless_paths_round_robin(self): + """POST /sessions (no id) and other unmatched paths should not pin to one backend.""" + from miles.rollout.session.session_router import SessionRouter + + args = SimpleNamespace(miles_router_timeout=1.0) + backends = [f"http://127.0.0.1:{7000 + i}" for i in range(4)] + router = SessionRouter(args, backends) + picks = [router.pick_backend("/sessions") for _ in range(40)] + # Round-robin: every backend must appear at least once. + assert set(picks) == set(backends) diff --git a/tests/fast/utils/chat_template_utils/test_tito_k2v3.py b/tests/fast/utils/chat_template_utils/test_tito_k2v3.py index e3b394db0b..35e5ece3c3 100644 --- a/tests/fast/utils/chat_template_utils/test_tito_k2v3.py +++ b/tests/fast/utils/chat_template_utils/test_tito_k2v3.py @@ -60,15 +60,8 @@ from miles.rollout.session.linear_trajectory import LinearTrajectory from miles.rollout.session.session_errors import TokenizationError -from miles.utils.chat_template_utils import ( - MismatchType, - apply_chat_template, - try_get_fixed_chat_template, -) -from miles.utils.chat_template_utils.tito_tokenizer import ( - TITOTokenizerType, - get_tito_tokenizer, -) +from miles.utils.chat_template_utils import MismatchType, apply_chat_template, try_get_fixed_chat_template +from miles.utils.chat_template_utils.tito_tokenizer import TITOTokenizerType, get_tito_tokenizer from miles.utils.processing_utils import load_tokenizer from miles.utils.test_utils.mock_trajectories import ( LongChainThinkingTrajectory, @@ -80,7 +73,6 @@ SingleToolTrajectory, ) - # --------------------------------------------------------------------------- # Path + fixtures # --------------------------------------------------------------------------- @@ -140,6 +132,7 @@ def tito_tok(tokenizer): # Trajectories — realistic conversation shapes from mock_trajectories # --------------------------------------------------------------------------- + def _with_synthetic_thinking( trajectory_cls: type, reasoning: str = "Let me work through this step by step.", @@ -171,19 +164,18 @@ class _Synthesized: # ... \ncontent<|im_end|>). CONVERSATIONS: list[tuple[str, type]] = [ # Single assistant turn — single tool call. - ("single_tool", SingleToolTrajectory), - ("single_tool_thinking", SingleToolThinkingTrajectory), + ("single_tool", SingleToolTrajectory), + ("single_tool_thinking", SingleToolThinkingTrajectory), # Multiple assistant turns — single tool call per turn. - ("multi_turn", MultiTurnTrajectory), - ("multi_turn_thinking", MultiTurnThinkingTrajectory), + ("multi_turn", MultiTurnTrajectory), + ("multi_turn_thinking", MultiTurnThinkingTrajectory), # Single assistant turn — multiple parallel tool calls. - ("multi_tool_single_turn", MultiToolSingleTurnTrajectory), + ("multi_tool_single_turn", MultiToolSingleTurnTrajectory), # No native thinking variant exists for parallel-tools-single-turn; # synthesize by injecting reasoning_content into the assistant turn. - ("multi_tool_single_turn_thinking", - _with_synthetic_thinking(MultiToolSingleTurnTrajectory)), + ("multi_tool_single_turn_thinking", _with_synthetic_thinking(MultiToolSingleTurnTrajectory)), # Multiple assistant turns AND tool calls (chain shape). - ("multi_tool_multi_turn", LongChainTrajectory), + ("multi_tool_multi_turn", LongChainTrajectory), ("multi_tool_multi_turn_thinking", LongChainThinkingTrajectory), ] @@ -262,25 +254,28 @@ def _realistic_emit_ids( emit_ids = tokenizer.encode(emit_text) """ full_text = _render_text( - request_messages + [assistant_message], tokenizer, tools, + request_messages + [assistant_message], + tokenizer, + tools, add_generation_prompt=False, ) prompt_text = _render_text( - request_messages, tokenizer, tools, + request_messages, + tokenizer, + tools, add_generation_prompt=True, ) assert full_text.startswith(prompt_text), ( "chat template not append-only: prompt-only render is not a prefix " "of full render. TITO's premise breaks here." ) - emit_text = full_text[len(prompt_text):] + emit_text = full_text[len(prompt_text) :] # Strip the trailing newline(s) the jinja whitespace adds after # `<|im_end|>`. The model autoregressively stops at the stop token # without producing them. emit_text_stop = emit_text.rstrip("\n") assert emit_text_stop.endswith("<|im_end|>"), ( - f"unexpected emit_text shape (does not end with <|im_end|>): " - f"{emit_text_stop!r}" + f"unexpected emit_text shape (does not end with <|im_end|>): " f"{emit_text_stop!r}" ) return list(tokenizer.encode(emit_text_stop, add_special_tokens=False)) @@ -305,15 +300,15 @@ def _drive_session_through_trajectory( pre = session.prepare_pretokenized(request_messages, tools, tito_tokenizer=tito_tok) if pre is None: prompt_ids = _render_ids( - request_messages, tito_tok.tokenizer, tools, + request_messages, + tito_tok.tokenizer, + tools, add_generation_prompt=True, ) else: prompt_ids = list(pre["input_ids"]) - emit_ids = _realistic_emit_ids( - request_messages, assistant_message, tools, tito_tok.tokenizer - ) + emit_ids = _realistic_emit_ids(request_messages, assistant_message, tools, tito_tok.tokenizer) session.update_pretokenized_state( request_messages=request_messages, @@ -369,19 +364,19 @@ def test_buffer_matches_canonical_under_realistic_rollout(name, trajectory_cls, # end-of-sequence ``<|im_end|>`` vs ``<|im_end|>\\n`` differences if the # trajectory has only ONE assistant turn). expected_final = _render_ids( - session.messages, tito_tok.tokenizer, tools, + session.messages, + tito_tok.tokenizer, + tools, add_generation_prompt=False, ) actual_final = list(session.token_ids) severe_final = [ - m for m in comparator.compare_sequences(expected_final, actual_final) - if m.type != MismatchType.ASSISTANT_TEXT + m for m in comparator.compare_sequences(expected_final, actual_final) if m.type != MismatchType.ASSISTANT_TEXT ] if severe_final: details = "\n".join( f" {m.type.value} at segment {m.segment_index}: " - f"expected={m.expected_text!r} actual={m.actual_text!r}" - + (f" — {m.detail}" if m.detail else "") + f"expected={m.expected_text!r} actual={m.actual_text!r}" + (f" — {m.detail}" if m.detail else "") for m in severe_final[:5] ) pytest.fail( @@ -407,18 +402,18 @@ def test_buffer_matches_canonical_under_realistic_rollout(name, trajectory_cls, ) merged = list(pre["input_ids"]) expected_next = _render_ids( - extended_messages, tito_tok.tokenizer, tools, + extended_messages, + tito_tok.tokenizer, + tools, add_generation_prompt=True, ) severe_next = [ - m for m in comparator.compare_sequences(expected_next, merged) - if m.type != MismatchType.ASSISTANT_TEXT + m for m in comparator.compare_sequences(expected_next, merged) if m.type != MismatchType.ASSISTANT_TEXT ] if severe_next: details = "\n".join( f" {m.type.value} at segment {m.segment_index}: " - f"expected={m.expected_text!r} actual={m.actual_text!r}" - + (f" — {m.detail}" if m.detail else "") + f"expected={m.expected_text!r} actual={m.actual_text!r}" + (f" — {m.detail}" if m.detail else "") for m in severe_next[:5] ) pytest.fail( @@ -442,6 +437,7 @@ def test_buffer_matches_canonical_under_realistic_rollout(name, trajectory_cls, class _EnvAppendShape: """Generic env append shape — the messages to be appended after the session has been driven through some trajectory.""" + name: str appended_messages: list[dict] required_contents: tuple[str, ...] @@ -458,8 +454,7 @@ class _EnvAppendShape: _EnvAppendShape( name="env_tool", appended_messages=[ - {"role": "tool", "tool_call_id": "call_test_xyz", - "content": "_marker_tool_xyz_42_"}, + {"role": "tool", "tool_call_id": "call_test_xyz", "content": "_marker_tool_xyz_42_"}, ], required_contents=("_marker_tool_xyz_42_",), ), @@ -480,11 +475,9 @@ class _EnvAppendShape: _EnvAppendShape( name="env_alternating_user_tool", appended_messages=[ - {"role": "tool", "tool_call_id": "call_alt_1", - "content": "_marker_alt_tool1_aaa_"}, + {"role": "tool", "tool_call_id": "call_alt_1", "content": "_marker_alt_tool1_aaa_"}, {"role": "user", "content": "_marker_alt_user1_bbb_"}, - {"role": "tool", "tool_call_id": "call_alt_2", - "content": "_marker_alt_tool2_ccc_"}, + {"role": "tool", "tool_call_id": "call_alt_2", "content": "_marker_alt_tool2_ccc_"}, {"role": "user", "content": "_marker_alt_user2_ddd_"}, ], required_contents=( @@ -498,11 +491,14 @@ class _EnvAppendShape: @pytest.mark.parametrize( - "traj_name, traj_cls", CONVERSATIONS, + "traj_name, traj_cls", + CONVERSATIONS, ids=lambda x: x if isinstance(x, str) else None, ) @pytest.mark.parametrize( - "env_shape", _ENV_APPEND_SHAPES, ids=lambda s: s.name, + "env_shape", + _ENV_APPEND_SHAPES, + ids=lambda s: s.name, ) def test_append_via_realistic_buffer(traj_name, traj_cls, env_shape, tito_tok): """Invariants I3+I4 (core): ``merge_tokens`` against a realistic @@ -525,10 +521,7 @@ def test_append_via_realistic_buffer(traj_name, traj_cls, env_shape, tito_tok): _drive_session_through_trajectory(session, tito_tok, messages, tools) pretokenized_buffer = list(session.token_ids) - assert ( - pretokenized_buffer - and pretokenized_buffer[-1] == tito_tok._im_end_id - ), ( + assert pretokenized_buffer and pretokenized_buffer[-1] == tito_tok._im_end_id, ( f"K2V3 [{traj_name} + {env_shape.name}] setup error: pretokenized " f"buffer should end at <|im_end|> after drive, got last token " f"{pretokenized_buffer[-1] if pretokenized_buffer else 'EMPTY'}" @@ -544,20 +537,18 @@ def test_append_via_realistic_buffer(traj_name, traj_cls, env_shape, tito_tok): merged = list(pre["input_ids"]) expected = _render_ids( - extended, tito_tok.tokenizer, tools, + extended, + tito_tok.tokenizer, + tools, add_generation_prompt=True, ) comparator = tito_tok.create_comparator() - severe = [ - m for m in comparator.compare_sequences(expected, merged) - if m.type != MismatchType.ASSISTANT_TEXT - ] + severe = [m for m in comparator.compare_sequences(expected, merged) if m.type != MismatchType.ASSISTANT_TEXT] if severe: details = "\n".join( f" {m.type.value} at segment {m.segment_index}: " - f"expected={m.expected_text!r} actual={m.actual_text!r}" - + (f" — {m.detail}" if m.detail else "") + f"expected={m.expected_text!r} actual={m.actual_text!r}" + (f" — {m.detail}" if m.detail else "") for m in severe[:5] ) pytest.fail( @@ -567,9 +558,7 @@ def test_append_via_realistic_buffer(traj_name, traj_cls, env_shape, tito_tok): ) # required-contents-in-order check on the incremental segment. - incremental_text = tito_tok.tokenizer.decode( - merged[len(pretokenized_buffer):], skip_special_tokens=False - ) + incremental_text = tito_tok.tokenizer.decode(merged[len(pretokenized_buffer) :], skip_special_tokens=False) cursor = 0 for content in env_shape.required_contents: found = incremental_text.find(content, cursor) @@ -623,16 +612,19 @@ def _load_sglang_parsers(): fcp_cls = None try: from sglang.srt.function_call.function_call_parser import FunctionCallParser + fcp_cls = FunctionCallParser except ImportError: pass rp_cls = None try: from sglang.srt.parser.reasoning_parser import ReasoningParser + rp_cls = ReasoningParser except ImportError: try: from sglang.srt.reasoning_parser import ReasoningParser # older SGLang layout + rp_cls = ReasoningParser except ImportError: pass @@ -644,6 +636,7 @@ def _try_json_decode_tool_args(tool_calls: list[dict]) -> list[dict]: Hermes parser returns it as a JSON string. Decode for template compatibility — this mirrors what production agent loops do.""" import json + out = [] for tc in tool_calls: fn = tc.get("function", {}) @@ -658,7 +651,8 @@ def _try_json_decode_tool_args(tool_calls: list[dict]) -> list[dict]: @pytest.mark.parametrize( - "traj_name, traj_cls", CONVERSATIONS, + "traj_name, traj_cls", + CONVERSATIONS, ids=lambda x: x if isinstance(x, str) else None, ) def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cls, tito_tok): @@ -696,21 +690,23 @@ def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cl # 1) Render truth_msg via chat_template — that is the raw emit shape. full_text = _render_text( - request_messages + [truth_msg], tokenizer, tools, + request_messages + [truth_msg], + tokenizer, + tools, add_generation_prompt=False, ) prompt_text = _render_text( - request_messages, tokenizer, tools, + request_messages, + tokenizer, + tools, add_generation_prompt=True, ) assert full_text.startswith(prompt_text), ( - f"K2V3 [{traj_name}] chat template not append-only: prompt-only " - f"render is not a prefix of full render." + f"K2V3 [{traj_name}] chat template not append-only: prompt-only " f"render is not a prefix of full render." ) - raw_assistant_emit = full_text[len(prompt_text):].rstrip("\n") + raw_assistant_emit = full_text[len(prompt_text) :].rstrip("\n") assert raw_assistant_emit.endswith("<|im_end|>"), ( - f"K2V3 [{traj_name}] unexpected raw_assistant_emit shape: " - f"{raw_assistant_emit!r}" + f"K2V3 [{traj_name}] unexpected raw_assistant_emit shape: " f"{raw_assistant_emit!r}" ) # 2) Run real ReasoningParser on the raw emit (only if the trajectory's @@ -724,10 +720,7 @@ def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cl try: rp = RP(model_type=_K2V3_REASONING_PARSER) except Exception as e: - pytest.skip( - f"reasoning parser {_K2V3_REASONING_PARSER!r} unsupported " - f"by this SGLang build: {e}" - ) + pytest.skip(f"reasoning parser {_K2V3_REASONING_PARSER!r} unsupported " f"by this SGLang build: {e}") r_out, n_out = rp.parse_non_stream(raw_assistant_emit) parsed_reasoning = r_out or "" text_after_reasoning = n_out if n_out is not None else "" @@ -741,10 +734,7 @@ def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cl try: fcp = FCP(tools=sglang_tools, tool_call_parser=_K2V3_TOOL_PARSER) except Exception as e: - pytest.skip( - f"tool parser {_K2V3_TOOL_PARSER!r} unsupported by this SGLang " - f"build: {e}" - ) + pytest.skip(f"tool parser {_K2V3_TOOL_PARSER!r} unsupported by this SGLang " f"build: {e}") normal_text, tool_call_items = fcp.parse_non_stream(text_after_reasoning) parsed_content = normal_text if normal_text is not None else "" parsed_tool_calls = [ @@ -776,7 +766,10 @@ def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cl # and would create a spurious extra special-token mismatch. emit_ids = list(tokenizer.encode(raw_assistant_emit, add_special_tokens=False)) prompt_ids = _render_ids( - request_messages, tokenizer, tools, add_generation_prompt=True, + request_messages, + tokenizer, + tools, + add_generation_prompt=True, ) session = LinearTrajectory() session.update_pretokenized_state( @@ -791,7 +784,10 @@ def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cl # against ``apply_chat_template(session.messages)`` canonical (which # re-renders parsed_msg back to text). Severe types only. expected = _render_ids( - session.messages, tokenizer, tools, add_generation_prompt=False, + session.messages, + tokenizer, + tools, + add_generation_prompt=False, ) actual = list(session.token_ids) comparator = tito_tok.create_comparator() @@ -800,8 +796,7 @@ def test_chat_template_round_trip_through_real_sglang_parsers(traj_name, traj_cl if severe: details = "\n".join( f" {m.type.value} at segment {m.segment_index}: " - f"expected={m.expected_text!r} actual={m.actual_text!r}" - + (f" — {m.detail}" if m.detail else "") + f"expected={m.expected_text!r} actual={m.actual_text!r}" + (f" — {m.detail}" if m.detail else "") for m in severe[:8] ) pytest.fail( @@ -852,19 +847,16 @@ class _BossFlow: name="multi_turn_thinking + tool_followup", trajectory_cls=MultiTurnThinkingTrajectory, final_env=[ - {"role": "tool", "tool_call_id": "boss_call_1", - "content": "_boss_tool_followup_xyz_42_"}, + {"role": "tool", "tool_call_id": "boss_call_1", "content": "_boss_tool_followup_xyz_42_"}, ], ), _BossFlow( name="multi_tool_multi_turn_thinking + alternating_user_tool_followup", trajectory_cls=LongChainThinkingTrajectory, final_env=[ - {"role": "tool", "tool_call_id": "boss_call_2a", - "content": "_boss_alt_tool1_aaa_"}, + {"role": "tool", "tool_call_id": "boss_call_2a", "content": "_boss_alt_tool1_aaa_"}, {"role": "user", "content": "_boss_alt_user1_bbb_"}, - {"role": "tool", "tool_call_id": "boss_call_2b", - "content": "_boss_alt_tool2_ccc_"}, + {"role": "tool", "tool_call_id": "boss_call_2b", "content": "_boss_alt_tool2_ccc_"}, {"role": "user", "content": "_boss_alt_user2_ddd_"}, ], ), @@ -872,22 +864,18 @@ class _BossFlow: name="multi_tool_single_turn_thinking + system_inject", trajectory_cls=_MultiToolSingleTurnThinking, final_env=[ - {"role": "system", - "content": "_boss_system_inject_def_77_"}, + {"role": "system", "content": "_boss_system_inject_def_77_"}, ], ), _BossFlow( name="multi_tool_multi_turn_thinking + complex_env_chain", trajectory_cls=LongChainThinkingTrajectory, final_env=[ - {"role": "tool", "tool_call_id": "boss_call_4a", - "content": "_boss_chain_tool1_AAA_"}, + {"role": "tool", "tool_call_id": "boss_call_4a", "content": "_boss_chain_tool1_AAA_"}, {"role": "user", "content": "_boss_chain_user1_BBB_"}, - {"role": "tool", "tool_call_id": "boss_call_4b", - "content": "_boss_chain_tool2_CCC_"}, + {"role": "tool", "tool_call_id": "boss_call_4b", "content": "_boss_chain_tool2_CCC_"}, {"role": "system", "content": "_boss_chain_system_DDD_"}, - {"role": "tool", "tool_call_id": "boss_call_4c", - "content": "_boss_chain_tool3_EEE_"}, + {"role": "tool", "tool_call_id": "boss_call_4c", "content": "_boss_chain_tool3_EEE_"}, ], ), ] @@ -911,10 +899,7 @@ def _run_parsers_on_emit( try: rp = rp_cls(model_type=_K2V3_REASONING_PARSER) except Exception as e: - pytest.skip( - f"reasoning parser {_K2V3_REASONING_PARSER!r} unsupported " - f"by this SGLang build: {e}" - ) + pytest.skip(f"reasoning parser {_K2V3_REASONING_PARSER!r} unsupported " f"by this SGLang build: {e}") r_out, n_out = rp.parse_non_stream(raw_emit) parsed_reasoning = r_out or "" text_after_reasoning = n_out if n_out is not None else "" @@ -927,10 +912,7 @@ def _run_parsers_on_emit( try: fcp = fcp_cls(tools=sglang_tools, tool_call_parser=_K2V3_TOOL_PARSER) except Exception as e: - pytest.skip( - f"tool parser {_K2V3_TOOL_PARSER!r} unsupported by this SGLang " - f"build: {e}" - ) + pytest.skip(f"tool parser {_K2V3_TOOL_PARSER!r} unsupported by this SGLang " f"build: {e}") normal_text, tool_call_items = fcp.parse_non_stream(text_after_reasoning) parsed_content = normal_text if normal_text is not None else "" parsed_tool_calls = [ @@ -963,25 +945,30 @@ def _drive_one_assistant_turn_through_real_parsers( tokenizer = tito_tok.tokenizer full_text = _render_text( - request_messages + [truth_assistant_msg], tokenizer, tools, + request_messages + [truth_assistant_msg], + tokenizer, + tools, add_generation_prompt=False, ) prompt_text = _render_text( - request_messages, tokenizer, tools, + request_messages, + tokenizer, + tools, add_generation_prompt=True, ) assert full_text.startswith(prompt_text), ( - f"chat template not append-only between " - f"render(request_messages) and render(request_messages + [truth_msg])" - ) - raw_emit = full_text[len(prompt_text):].rstrip("\n") - assert raw_emit.endswith("<|im_end|>"), ( - f"unexpected raw_emit shape: {raw_emit!r}" + "chat template not append-only between " "render(request_messages) and render(request_messages + [truth_msg])" ) + raw_emit = full_text[len(prompt_text) :].rstrip("\n") + assert raw_emit.endswith("<|im_end|>"), f"unexpected raw_emit shape: {raw_emit!r}" has_reasoning = bool(truth_assistant_msg.get("reasoning_content")) parsed_content, parsed_tool_calls, parsed_reasoning = _run_parsers_on_emit( - raw_emit, tools, fcp_cls=fcp_cls, rp_cls=rp_cls, has_reasoning=has_reasoning, + raw_emit, + tools, + fcp_cls=fcp_cls, + rp_cls=rp_cls, + has_reasoning=has_reasoning, ) parsed_msg: dict = { @@ -995,7 +982,10 @@ def _drive_one_assistant_turn_through_real_parsers( pre = session.prepare_pretokenized(request_messages, tools, tito_tokenizer=tito_tok) if pre is None: prompt_ids = _render_ids( - request_messages, tokenizer, tools, add_generation_prompt=True, + request_messages, + tokenizer, + tools, + add_generation_prompt=True, ) else: prompt_ids = list(pre["input_ids"]) @@ -1056,8 +1046,10 @@ def test_end_to_end_realistic_rollout_with_real_parsers(flow: _BossFlow, tito_to truth_msg = messages[asst_idx] parsed_msg = _drive_one_assistant_turn_through_real_parsers( - session, tito_tok, - fcp_cls=FCP, rp_cls=RP, + session, + tito_tok, + fcp_cls=FCP, + rp_cls=RP, request_messages=request_messages, truth_assistant_msg=truth_msg, tools=tools, @@ -1077,19 +1069,18 @@ def test_end_to_end_realistic_rollout_with_real_parsers(flow: _BossFlow, tito_to merged = list(pre["input_ids"]) expected = _render_ids( - extended, tito_tok.tokenizer, tools, add_generation_prompt=True, + extended, + tito_tok.tokenizer, + tools, + add_generation_prompt=True, ) comparator = tito_tok.create_comparator() - severe = [ - m for m in comparator.compare_sequences(expected, merged) - if m.type != MismatchType.ASSISTANT_TEXT - ] + severe = [m for m in comparator.compare_sequences(expected, merged) if m.type != MismatchType.ASSISTANT_TEXT] if severe: details = "\n".join( f" {m.type.value} at segment {m.segment_index}: " - f"expected={m.expected_text!r} actual={m.actual_text!r}" - + (f" — {m.detail}" if m.detail else "") + f"expected={m.expected_text!r} actual={m.actual_text!r}" + (f" — {m.detail}" if m.detail else "") for m in severe[:8] ) pytest.fail( @@ -1105,9 +1096,7 @@ def test_end_to_end_realistic_rollout_with_real_parsers(flow: _BossFlow, tito_to # the final env chain's content (which includes user/tool/system # markers) actually flows into the incremental tokens in order. pretokenized_buffer = list(session.token_ids) - incremental_text = tito_tok.tokenizer.decode( - merged[len(pretokenized_buffer):], skip_special_tokens=False - ) + incremental_text = tito_tok.tokenizer.decode(merged[len(pretokenized_buffer) :], skip_special_tokens=False) cursor = 0 for env_msg in flow.final_env: marker = env_msg.get("content", "") @@ -1148,7 +1137,10 @@ def test_production_prefix_check_raises_on_intentional_violation(tito_tok): # Seed: drive a single normal turn so the session has stored token_ids. prompt_ids = _render_ids( - [user_q], tito_tok.tokenizer, tools=None, add_generation_prompt=True, + [user_q], + tito_tok.tokenizer, + tools=None, + add_generation_prompt=True, ) eos = getattr(tito_tok.tokenizer, "eos_token_id", None) completion_ids = list(tito_tok.tokenizer.encode("ok", add_special_tokens=False))