Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions miles/rollout/session/linear_trajectory.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ class LinearTrajectory:

lock: asyncio.Lock = field(default_factory=asyncio.Lock, repr=False, compare=False)
closing: bool = field(default=False, repr=False, compare=False)
# DELETE preemption channel: the chat-completions handler stashes its
# in-flight proxy task here so a concurrent DELETE can cancel it without
# waiting on the lock. The chat coroutine catches CancelledError and
# returns 410 Gone. Cleared in the chat handler's finally.
current_proxy_task: "asyncio.Task | None" = field(default=None, repr=False, compare=False)
messages: list[dict[str, Any]] = field(default_factory=list)
records: list[SessionRecord] = field(default_factory=list)
trajectory_token_ids: list[list[int]] = field(default_factory=list)
Expand Down
23 changes: 19 additions & 4 deletions miles/rollout/session/session_errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
Hierarchy
---------
SessionError (base)
├── SessionNotFoundError → 404 session does not exist
├── MessageValidationError → 400 messages structure/content invalid
├── TokenizationError → 500 TITO tokenizer / prefix mismatch
└── UpstreamResponseError → 502 SGLang response invalid or unexpected
├── SessionNotFoundError → 404 session does not exist
├── MessageValidationError → 400 messages structure/content invalid
├── SessionStateConflictError → 409 Phase-3 commit guard fired (defense-in-depth)
├── TokenizationError → 500 TITO tokenizer / prefix mismatch
└── UpstreamResponseError → 502 SGLang response invalid or unexpected
"""


Expand All @@ -32,6 +33,20 @@ class MessageValidationError(SessionError):
status_code: int = 400


class SessionStateConflictError(SessionError):
    """Signal that session state changed between request prep and commit.

    Raised by the Phase-3 commit guard in the chat-completions handler when
    ``session.num_assistant`` no longer matches the value captured before the
    proxy call. Surfaced to callers as HTTP 409.

    Because the chat flow now holds the session lock across Phase 1, 2 and 3,
    no other writer should be able to mutate the session while the proxy call
    is in flight, so this error is not expected to be raised in practice. The
    guard stays as defense in depth: should a later change reopen a split-lock
    window, the caller gets an explicit, retryable conflict rather than a
    silently dropped commit that corrupts the trajectory's accumulated state.
    """

    status_code: int = 409


class TokenizationError(SessionError):
"""Raised when TITO tokenization invariants are violated.

Expand Down
183 changes: 109 additions & 74 deletions miles/rollout/session/sessions.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from miles.rollout.session.session_errors import (
SessionError,
SessionNotFoundError,
SessionStateConflictError,
TokenizationError,
UpstreamResponseError,
)
Expand Down Expand Up @@ -98,10 +99,17 @@ async def get_session(session_id: str):

@app.delete("/sessions/{session_id}")
async def delete_session(session_id: str):
# Lock-restored chat flow holds session.lock through Phase 1+2+3.
# Preserve DELETE preemption by cancelling the in-flight proxy task
# (the cancellation channel stored on the session). The cancelled
# chat returns 410 to its caller, then releases the lock so DELETE
# can acquire it and remove the session.
session = registry.get_session(session_id)
if session.closing:
raise SessionNotFoundError(f"session not found: session_id={session_id}")
session.closing = True
if session.current_proxy_task is not None:
session.current_proxy_task.cancel()
logger.debug(
f"[session-server] DELETE waiting for lock: session={session_id} lock_locked={session.lock.locked()}"
)
Expand All @@ -117,25 +125,36 @@ async def delete_session(session_id: str):
async def chat_completions(request: Request, session_id: str):
"""Proxy a chat completion through SGLang with TITO token tracking.

Flow: prepare pretokenized input_ids (lock held briefly) → inject
SGLang flags → proxy to backend (NO lock) → validate response →
update trajectory checkpoint (lock held briefly) → append session record.

The lock is NOT held during the slow proxy call to avoid blocking
DELETE/other operations when the agent disconnects mid-request.
Flow: ALL three phases run under ``session.lock`` —
Phase 1: prepare pretokenized input_ids + inject SGLang flags
Phase 2: proxy to backend (lock held; cancellation channel via
``session.current_proxy_task`` lets DELETE preempt)
Phase 3: validate response, update trajectory checkpoint, record

Holding the lock through Phase 2 eliminates the cursor-mismatch race
that the split-lock design (lock-Phase-1 → unlock-Phase-2 → relock-
Phase-3) introduced: two same-session writers could both reach
Phase 3, the stale-state guard silently dropped the second commit
while still returning 200, and the caller appended a phantom
assistant record to its trajectory.

DELETE preemption is preserved via the cancellation channel:
``delete_session`` cancels ``session.current_proxy_task``; the chat
coroutine catches ``CancelledError`` and returns 410 Gone, then
releases the lock so DELETE can acquire it.
"""
_inflight_chat["count"] += 1
try:
session = registry.get_or_create_session(session_id)
if session.closing:
raise SessionNotFoundError(f"session not found: session_id={session_id}")

# --- Phase 1: prepare request (lock held briefly) ---
async with session.lock:
# Double-check: session may have been marked closing while waiting for lock.
if session.closing:
raise SessionNotFoundError(f"session not found: session_id={session_id}")

# --- Phase 1: prepare request ---
body = await request.body()
request_body = json.loads(body) if body else {}

Expand Down Expand Up @@ -171,82 +190,99 @@ async def chat_completions(request: Request, session_id: str):

body = json.dumps(request_body).encode()
expected_num_assistant = session.num_assistant
# --- lock released here ---

# --- Phase 2: proxy to SGLang (NO lock held) ---
result = await backend.do_proxy(request, "v1/chat/completions", body=body)

# If SGLang returned a non-200 error (e.g. 400 for context too long),
# pass it through to the agent without recording — the agent can retry
# or handle the error.
if result["status_code"] != 200:
# Rollback failures indicate corrupted prefix-cache state in SGLang.
# Retry once without pretokenized input_ids so SGLang processes the
# request from scratch instead of attempting prefix continuation.
error_body = result.get("response_body") or b""
if isinstance(error_body, bytes):
error_body = error_body.decode("utf-8", errors="replace")
if (
result["status_code"] == 400
and "rollback failed" in error_body.lower()
and "input_ids" in request_body
):
logger.warning(
"SGLang rollback failed for session %s, retrying without prefix continuation",
session_id,
)
request_body.pop("input_ids", None)
retry_body = json.dumps(request_body).encode()
result = await backend.do_proxy(request, "v1/chat/completions", body=retry_body)
if result["status_code"] != 200:
return backend.build_proxy_response(result)
else:
return backend.build_proxy_response(result)

response = json.loads(result["response_body"])
# --- Phase 2: proxy to SGLang (lock held; cancellation channel) ---
session.current_proxy_task = asyncio.create_task(
backend.do_proxy(request, "v1/chat/completions", body=body)
)
try:
result = await session.current_proxy_task
except asyncio.CancelledError:
# DELETE preempted this request. Surface 410 Gone to the caller
# so litellm/harbor treats it as a definite session-closed signal
# (distinct from 404 "never existed" and 409 "retryable conflict").
return JSONResponse(status_code=410, content={"error": "session closing"})
finally:
session.current_proxy_task = None

# If SGLang returned a non-200 error (e.g. 400 for context too long),
# pass it through to the agent without recording — the agent can retry
# or handle the error.
if result["status_code"] != 200:
# Rollback failures indicate corrupted prefix-cache state in SGLang.
# Retry once without pretokenized input_ids so SGLang processes the
# request from scratch instead of attempting prefix continuation.
error_body = result.get("response_body") or b""
if isinstance(error_body, bytes):
error_body = error_body.decode("utf-8", errors="replace")
if (
result["status_code"] == 400
and "rollback failed" in error_body.lower()
and "input_ids" in request_body
):
logger.warning(
"SGLang rollback failed for session %s, retrying without prefix continuation",
session_id,
)
request_body.pop("input_ids", None)
retry_body = json.dumps(request_body).encode()
session.current_proxy_task = asyncio.create_task(
backend.do_proxy(request, "v1/chat/completions", body=retry_body)
)
try:
result = await session.current_proxy_task
except asyncio.CancelledError:
return JSONResponse(status_code=410, content={"error": "session closing"})
finally:
session.current_proxy_task = None
if result["status_code"] != 200:
return backend.build_proxy_response(result)
else:
return backend.build_proxy_response(result)

choice = response.get("choices", [{}])[0]
response = json.loads(result["response_body"])

meta_info = choice.get("meta_info")
if not isinstance(meta_info, dict) or "output_token_logprobs" not in meta_info:
raise UpstreamResponseError(
"meta_info and output_token_logprobs must be in choice (requires logprobs=True)"
)
assistant_message = choice.get("message", {})
if assistant_message.get("content") is None:
raise UpstreamResponseError(
"assistant message content is None, when tool call parser failed SGLang should still return "
"an empty content rather than None. Please check your modified SGLang version."
)
choice = response.get("choices", [{}])[0]

prompt_token_ids = choice.get("prompt_token_ids")
output_token_logprobs = meta_info["output_token_logprobs"]
completion_tokens = meta_info["completion_tokens"]

actual_output_logprobs_len = len(output_token_logprobs)
if actual_output_logprobs_len != completion_tokens:
raise UpstreamResponseError(
"invalid chat completion response: "
f"len(output_token_logprobs)={actual_output_logprobs_len} "
f"!= completion_tokens={completion_tokens}. "
f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue."
)
meta_info = choice.get("meta_info")
if not isinstance(meta_info, dict) or "output_token_logprobs" not in meta_info:
raise UpstreamResponseError(
"meta_info and output_token_logprobs must be in choice (requires logprobs=True)"
)
assistant_message = choice.get("message", {})
if assistant_message.get("content") is None:
raise UpstreamResponseError(
"assistant message content is None, when tool call parser failed SGLang should still return "
"an empty content rather than None. Please check your modified SGLang version."
)

completion_token_ids = [t[1] for t in output_token_logprobs]
prompt_token_ids = choice.get("prompt_token_ids")
output_token_logprobs = meta_info["output_token_logprobs"]
completion_tokens = meta_info["completion_tokens"]

actual_output_logprobs_len = len(output_token_logprobs)
if actual_output_logprobs_len != completion_tokens:
raise UpstreamResponseError(
"invalid chat completion response: "
f"len(output_token_logprobs)={actual_output_logprobs_len} "
f"!= completion_tokens={completion_tokens}. "
f"Please check whether you use the correct SGLang branch which has fix the tokenizer batch decode issue."
)

# --- Phase 3: update state (lock held briefly) ---
async with session.lock:
if session.closing:
logger.warning(f"Session {session_id} closed during proxy, skipping state update")
return backend.build_proxy_response(result)
completion_token_ids = [t[1] for t in output_token_logprobs]

# --- Phase 3: update state (still under the same lock) ---
# Defensive assert: with the lock held through Phase 1+2+3, no
# other writer can mutate num_assistant. If this fires, some
# future change has reintroduced a split-lock window — surface

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Raising SessionStateConflictError here suggests that this conflict can still happen, which is impossible under the new locking protocol. IMO it can be safely removed to avoid confusion, so that only the lock-related change stays.

# 409 instead of silently dropping the commit (the
# cursor-mismatch class of incident under the old design).
if session.num_assistant != expected_num_assistant:
logger.warning(
f"Session {session_id} state changed during proxy "
raise SessionStateConflictError(
f"session {session_id} state changed during proxy "
f"(expected num_assistant={expected_num_assistant}, "
f"got {session.num_assistant}), skipping state update"
f"got {session.num_assistant})"
)
return backend.build_proxy_response(result)

await asyncio.to_thread(
session.update_pretokenized_state,
Expand All @@ -266,7 +302,6 @@ async def chat_completions(request: Request, session_id: str):
response=response,
)
session.append_record(record)
# --- lock released here ---

return backend.build_proxy_response(result)
finally:
Expand Down
19 changes: 13 additions & 6 deletions tests/fast/router/test_session_race_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,7 +321,10 @@ def process_fn(prompt: str) -> ProcessResult:
d1 = delete_1.result(timeout=30.0)
d2 = delete_2.result(timeout=30.0)

assert chat_resp.status_code == 200
# Lock-restored chat flow: DELETE preempts the in-flight chat via the
# cancellation channel; the cancelled chat returns 410 Gone. (Under
# the prior split-lock design the chat would complete with 200.)
assert chat_resp.status_code == 410
# One delete succeeds, the other gets 404
codes = sorted([d1.status_code, d2.status_code])
assert codes == [204, 404], f"Expected [204, 404], got {codes}"
Expand Down Expand Up @@ -385,12 +388,16 @@ def process_fn(prompt: str) -> ProcessResult:

assert delete_resp.status_code == 204

# At least one chat must succeed (the one holding the lock when
# delete arrived). Others may get 200 (acquired lock before
# closing) or 404 (saw closing=True). No 500s allowed.
# Lock-restored chat flow: the in-flight chat at DELETE time gets
# cancelled via the proxy-task channel (→ 410). Queued chats then
# acquire the lock, see closing=True, and return 404. (Under the
# prior split-lock design the in-flight chat would have completed
# with 200 before DELETE acquired the lock.) No 500s allowed.
status_codes = [r.status_code for r in results]
assert all(c in (200, 404) for c in status_codes), f"Unexpected status codes: {status_codes}"
assert 200 in status_codes, f"Expected at least one 200, got {status_codes}"
assert all(c in (200, 404, 410) for c in status_codes), f"Unexpected status codes: {status_codes}"
assert (
200 in status_codes or 410 in status_codes
), f"Expected at least one in-flight chat (200 or 410), got {status_codes}"

def test_rapid_create_chat_delete_cycles(self):
"""Rapidly create, chat, and delete sessions to stress the lifecycle.
Expand Down
Loading