Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 55 additions & 1 deletion src/benchflow/providers/litellm_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import subprocess
import sys
import tempfile
from collections.abc import Awaitable, Callable
from collections.abc import Awaitable, Callable, Mapping
from dataclasses import dataclass
from pathlib import Path
from typing import Any, NoReturn, cast
Expand All @@ -41,6 +41,7 @@
LiteLLMRoute,
litellm_proxy_config,
resolve_litellm_route,
strip_provider_prefix,
)
from benchflow.providers.litellm_logging import (
callback_module_source,
Expand Down Expand Up @@ -877,6 +878,53 @@ def _provider_secret_env_names() -> set[str]:
return names


def _provider_model_id(entry: object) -> str | None:
if not isinstance(entry, Mapping):
return None
entry_id = cast("Mapping[str, object]", entry).get("id")
return entry_id if isinstance(entry_id, str) else None


def _provider_models_for_proxy_alias(
    *,
    raw: str | None,
    route: LiteLLMRoute,
) -> str | None:
    """Mirror model metadata onto the LiteLLM alias Pi sees in proxy mode.

    Pi resolves ``maxTokens``/``contextWindow`` by looking up the model it is
    told to use (the LiteLLM alias) in ``BENCHFLOW_PROVIDER_MODELS``. Without an
    alias entry that metadata is lost once traffic is routed through the proxy,
    so clone the source entry under the alias id/name.

    Args:
        raw: JSON text of the provider model list, or ``None``/empty when the
            variable is unset.
        route: Resolved route. Its ``requested_model``, ``upstream_model`` and
            ``model_alias`` attributes are read.

    Returns:
        The JSON-encoded model list, with a copy of the matching entry
        appended under ``route.model_alias`` unless an entry with that id is
        already present. ``None`` when ``raw`` is empty, is not valid JSON, is
        not a JSON list, or contains no entry whose ``id`` matches the route.
        Every ``None`` outcome other than an empty ``raw`` is logged, because
        the caller then leaves the alias without metadata.
    """
    # Function-local import so this block does not depend on a module-level
    # logger being defined elsewhere in the file.
    import logging

    log = logging.getLogger(__name__)

    if not raw:
        return None
    try:
        entries = json.loads(raw)
    except json.JSONDecodeError as exc:
        log.warning(
            "BENCHFLOW_PROVIDER_MODELS is not valid JSON (%s); "
            "no metadata mirrored onto alias %r",
            exc,
            route.model_alias,
        )
        return None
    if not isinstance(entries, list):
        log.warning(
            "BENCHFLOW_PROVIDER_MODELS must be a JSON list, got %s; "
            "no metadata mirrored onto alias %r",
            type(entries).__name__,
            route.model_alias,
        )
        return None

    # The source entry may be keyed by either the requested or the upstream
    # model, with or without its provider prefix.
    wanted = {
        route.requested_model,
        strip_provider_prefix(route.requested_model),
        route.upstream_model,
        strip_provider_prefix(route.upstream_model),
    }
    for entry in entries:
        if _provider_model_id(entry) not in wanted:
            continue
        merged = list(entries)
        alias_present = any(
            _provider_model_id(item) == route.model_alias for item in merged
        )
        if not alias_present:
            # Clone the first matching entry and re-key it to the alias.
            alias_entry = dict(cast("Mapping[str, Any]", entry))
            alias_entry["id"] = route.model_alias
            alias_entry["name"] = route.model_alias
            merged.append(alias_entry)
        return json.dumps(merged)

    log.warning(
        "No BENCHFLOW_PROVIDER_MODELS entry matches any of %s; "
        "alias %r is left without model metadata",
        sorted(map(str, wanted)),
        route.model_alias,
    )
    return None


# Caller-supplied provider endpoints. If any of these survive in the agent env,
# the agent could reach a provider directly and bypass the proxy (the exact way
# an Azure ``LLM_BASE_URL`` leaked past the gateway before this hardening). They
Expand Down Expand Up @@ -1014,6 +1062,12 @@ def _wire_litellm_agent_env(
updated["BENCHFLOW_PROVIDER_API_KEY"] = master_key
updated["BENCHFLOW_PROVIDER_MODEL"] = route.model_alias
updated["BENCHFLOW_PROVIDER_NAME"] = "litellm"
alias_models = _provider_models_for_proxy_alias(
raw=agent_env.get("BENCHFLOW_PROVIDER_MODELS"),
route=route,
)
if alias_models:
updated["BENCHFLOW_PROVIDER_MODELS"] = alias_models
Comment on lines +1067 to +1070

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Silent failure when no entry matches

If `_provider_models_for_proxy_alias` returns `None` (e.g., `BENCHFLOW_PROVIDER_MODELS` is set but none of its entries has an `id` matching any variant in `wanted`), `updated["BENCHFLOW_PROVIDER_MODELS"]` is left unchanged — keeping the original entry list, which does not contain the alias ID. Pi-acp would then receive a model list without the proxied alias, causing the same lookup failure this PR aims to fix, with no log to indicate the miss. Adding a debug/warning log on the `None` return paths in `_provider_models_for_proxy_alias` would make this much easier to diagnose in production.

return updated

agent_cfg = AGENTS.get(agent)
Expand Down
42 changes: 42 additions & 0 deletions tests/test_litellm_runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,48 @@ async def fake_start(**kwargs):
)


@pytest.mark.asyncio
async def test_pi_acp_proxy_preserves_provider_model_metadata(monkeypatch):
    """Regression guard for PR #803.

    In proxy mode the provider-model metadata must be mirrored onto the
    LiteLLM alias that Pi is told to use.
    """
    alias_id = "benchflow-vllm-Qwen-Qwen3-4B"
    source_entry = {
        "id": "Qwen/Qwen3-4B",
        "name": "Qwen/Qwen3-4B",
        "reasoning": False,
        "input": ["text"],
        "contextWindow": 16384,
        "maxTokens": 1024,
    }

    # Replace the real proxy launcher with a stub that echoes the route back.
    async def fake_start(**kwargs):
        return FakeLiteLLMServer("http://172.17.0.1:45678", kwargs["route"])

    monkeypatch.setattr(runtime_mod, "_start_host_litellm", fake_start)

    caller_env = {
        "BENCHFLOW_PROVIDER_BASE_URL": "http://172.17.0.1:8000/v1",
        "BENCHFLOW_PROVIDER_API_KEY": "dummy",
        "BENCHFLOW_PROVIDER_MODELS": json.dumps([source_entry]),
    }
    env_out, runtime_handle = await ensure_litellm_runtime(
        agent="pi-acp",
        agent_env=caller_env,
        model="vllm/Qwen/Qwen3-4B",
        runtime=None,
        environment="docker",
        session_id="run-1",
        usage_tracking="required",
    )

    assert runtime_handle is not None
    assert env_out["BENCHFLOW_PROVIDER_MODEL"] == alias_id

    # The alias entry must carry the alias name plus the source token limits.
    advertised = json.loads(env_out["BENCHFLOW_PROVIDER_MODELS"])
    mirrored = next(item for item in advertised if item["id"] == alias_id)
    assert mirrored["name"] == alias_id
    assert mirrored["maxTokens"] == 1024
    assert mirrored["contextWindow"] == 16384
Comment on lines +206 to +212

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 name field of alias entry is not asserted

The test confirms `maxTokens` and `contextWindow` are preserved on the alias entry, but does not assert `alias["name"]`. Because `setdefault` is a no-op when `"name"` already exists in the source entry, the alias entry currently carries `name = "Qwen/Qwen3-4B"` rather than the alias `"benchflow-vllm-Qwen-Qwen3-4B"`. Adding `assert alias["name"] == "benchflow-vllm-Qwen-Qwen3-4B"` would lock in the expected behaviour and catch a regression if the name handling changes.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!



@pytest.mark.asyncio
async def test_runtime_reuse_and_stop(monkeypatch):
created = []
Expand Down
Loading