From d1bbb5c859c5d17bb94085c157d084580e89b787 Mon Sep 17 00:00:00 2001 From: Bingran You Date: Sun, 21 Jun 2026 23:08:25 -0700 Subject: [PATCH] Preserve pi-acp model metadata through LiteLLM proxy Mirror BENCHFLOW_PROVIDER_MODELS metadata (maxTokens/contextWindow) onto the LiteLLM proxy alias that pi-acp sees, so model limits survive the mandatory proxy routing introduced in #820. Rebased onto main after #820 restructured litellm_runtime into _apply/_wire_litellm_agent_env. - Add _provider_models_for_proxy_alias + _provider_model_id helpers. - Clone the source model entry under route.model_alias in the pi-acp branch of _wire_litellm_agent_env. - Regression test asserts vLLM/Qwen alias carries maxTokens/contextWindow. Local: tests/test_litellm_runtime.py 23 passed; ruff + ty clean. --- src/benchflow/providers/litellm_runtime.py | 56 +++++++++++++++++++++- tests/test_litellm_runtime.py | 42 ++++++++++++++++ 2 files changed, 97 insertions(+), 1 deletion(-) diff --git a/src/benchflow/providers/litellm_runtime.py b/src/benchflow/providers/litellm_runtime.py index 0e29b83a..31f0047d 100644 --- a/src/benchflow/providers/litellm_runtime.py +++ b/src/benchflow/providers/litellm_runtime.py @@ -15,7 +15,7 @@ import subprocess import sys import tempfile -from collections.abc import Awaitable, Callable +from collections.abc import Awaitable, Callable, Mapping from dataclasses import dataclass from pathlib import Path from typing import Any, NoReturn, cast @@ -41,6 +41,7 @@ LiteLLMRoute, litellm_proxy_config, resolve_litellm_route, + strip_provider_prefix, ) from benchflow.providers.litellm_logging import ( callback_module_source, @@ -877,6 +878,53 @@ def _provider_secret_env_names() -> set[str]: return names +def _provider_model_id(entry: object) -> str | None: + if not isinstance(entry, Mapping): + return None + entry_id = cast("Mapping[str, object]", entry).get("id") + return entry_id if isinstance(entry_id, str) else None + + +def _provider_models_for_proxy_alias( + *, + raw: str | None, + route: LiteLLMRoute, +) -> str | None: + """Mirror model metadata onto the LiteLLM alias Pi sees in proxy mode. + + Pi resolves ``maxTokens``/``contextWindow`` by looking up the model it is + told to use (the LiteLLM alias) in ``BENCHFLOW_PROVIDER_MODELS``. Without an + alias entry that metadata is lost once traffic is routed through the proxy, + so clone the source entry under the alias id/name. + """ + if not raw: + return None + try: + entries = json.loads(raw) + except json.JSONDecodeError: + return None + if not isinstance(entries, list): + return None + wanted = { + route.requested_model, + strip_provider_prefix(route.requested_model), + route.upstream_model, + strip_provider_prefix(route.upstream_model), + } + for entry in entries: + entry_id = _provider_model_id(entry) + if entry_id not in wanted: + continue + alias_entry = dict(cast("Mapping[str, Any]", entry)) + alias_entry["id"] = route.model_alias + alias_entry["name"] = route.model_alias + merged = list(entries) + if not any(_provider_model_id(item) == route.model_alias for item in merged): + merged.append(alias_entry) + return json.dumps(merged) + return None + + # Caller-supplied provider endpoints. If any of these survive in the agent env, # the agent could reach a provider directly and bypass the proxy (the exact way # an Azure ``LLM_BASE_URL`` leaked past the gateway before this hardening). They @@ -1014,6 +1062,12 @@ def _wire_litellm_agent_env( updated["BENCHFLOW_PROVIDER_API_KEY"] = master_key updated["BENCHFLOW_PROVIDER_MODEL"] = route.model_alias updated["BENCHFLOW_PROVIDER_NAME"] = "litellm" + alias_models = _provider_models_for_proxy_alias( + raw=agent_env.get("BENCHFLOW_PROVIDER_MODELS"), + route=route, + ) + if alias_models: + updated["BENCHFLOW_PROVIDER_MODELS"] = alias_models return updated agent_cfg = AGENTS.get(agent) diff --git a/tests/test_litellm_runtime.py b/tests/test_litellm_runtime.py index e357b93c..a0566c7b 100644 --- a/tests/test_litellm_runtime.py +++ b/tests/test_litellm_runtime.py @@ -170,6 +170,48 @@ async def fake_start(**kwargs): ) +@pytest.mark.asyncio +async def test_pi_acp_proxy_preserves_provider_model_metadata(monkeypatch): + """Guards PR #803: Pi metadata follows the LiteLLM alias in proxy mode.""" + + async def fake_start(**kwargs): + return FakeLiteLLMServer("http://172.17.0.1:45678", kwargs["route"]) + + monkeypatch.setattr(runtime_mod, "_start_host_litellm", fake_start) + provider_models = [ + { + "id": "Qwen/Qwen3-4B", + "name": "Qwen/Qwen3-4B", + "reasoning": False, + "input": ["text"], + "contextWindow": 16384, + "maxTokens": 1024, + } + ] + + updated, provider_runtime = await ensure_litellm_runtime( + agent="pi-acp", + agent_env={ + "BENCHFLOW_PROVIDER_BASE_URL": "http://172.17.0.1:8000/v1", + "BENCHFLOW_PROVIDER_API_KEY": "dummy", + "BENCHFLOW_PROVIDER_MODELS": json.dumps(provider_models), + }, + model="vllm/Qwen/Qwen3-4B", + runtime=None, + environment="docker", + session_id="run-1", + usage_tracking="required", + ) + + assert provider_runtime is not None + assert updated["BENCHFLOW_PROVIDER_MODEL"] == "benchflow-vllm-Qwen-Qwen3-4B" + models = json.loads(updated["BENCHFLOW_PROVIDER_MODELS"]) + alias = next(m for m in models if m["id"] == "benchflow-vllm-Qwen-Qwen3-4B") + assert alias["name"] == "benchflow-vllm-Qwen-Qwen3-4B" + assert alias["maxTokens"] == 1024 + assert alias["contextWindow"] == 16384 + + @pytest.mark.asyncio async def test_runtime_reuse_and_stop(monkeypatch): created = []