Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 82 additions & 0 deletions astrbot/core/provider/fallbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
from __future__ import annotations

from astrbot.core import logger


def _get_provider_source_types(config: dict) -> dict[str, str]:
provider_source_types: dict[str, str] = {}
for provider_source in config.get("provider_sources", []):
if not isinstance(provider_source, dict):
continue
provider_source_id = provider_source.get("id")
if not isinstance(provider_source_id, str) or not provider_source_id:
continue
provider_source_types[provider_source_id] = (
provider_source.get("provider_type") or "chat_completion"
)
return provider_source_types


def get_enabled_chat_provider_ids(config: dict) -> set[str]:
"""Return provider IDs that can be used as chat fallback providers."""
provider_ids: set[str] = set()
provider_source_types = _get_provider_source_types(config)
for provider in config.get("provider", []):
if not isinstance(provider, dict):
continue
provider_id = provider.get("id")
if not isinstance(provider_id, str) or not provider_id:
continue
if provider.get("enable") is False:
continue
provider_type = provider.get("provider_type")
provider_source_id = provider.get("provider_source_id")
if not provider_type:
if provider_source_id:
if not isinstance(provider_source_id, str):
continue
provider_type = provider_source_types.get(provider_source_id)
else:
provider_type = "chat_completion"
if provider_type != "chat_completion":
continue
Comment thread
Clhikari marked this conversation as resolved.
provider_ids.add(provider_id)
return provider_ids


def prune_fallback_chat_models(config: dict) -> list[str]:
"""Drop stale or disabled provider IDs from provider_settings.fallback_chat_models."""
provider_settings = config.get("provider_settings")
if not isinstance(provider_settings, dict):
return []

fallback_ids = provider_settings.get("fallback_chat_models")
if not isinstance(fallback_ids, list):
return []

valid_provider_ids = get_enabled_chat_provider_ids(config)
seen: set[str] = set()
pruned_fallback_ids: list[str] = []
removed_fallback_ids: list[str] = []

for fallback_id in fallback_ids:
if not isinstance(fallback_id, str) or not fallback_id:
removed_fallback_ids.append(str(fallback_id))
continue
if fallback_id in seen:
removed_fallback_ids.append(fallback_id)
continue
seen.add(fallback_id)
if fallback_id not in valid_provider_ids:
removed_fallback_ids.append(fallback_id)
continue
pruned_fallback_ids.append(fallback_id)

if pruned_fallback_ids != fallback_ids:
provider_settings["fallback_chat_models"] = pruned_fallback_ids
logger.info(
"Removed stale fallback chat providers from config: %s",
removed_fallback_ids,
)

return removed_fallback_ids
4 changes: 4 additions & 0 deletions astrbot/core/provider/manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

from ..persona_mgr import PersonaManager
from .entities import ProviderType
from .fallbacks import prune_fallback_chat_models
from .provider import (
EmbeddingProvider,
Provider,
Expand Down Expand Up @@ -828,6 +829,7 @@ async def delete_provider(
config["provider"] = [
prov for prov in config["provider"] if prov.get("id") != tpid
]
prune_fallback_chat_models(config)
config.save_config()
logger.info(f"Provider {target_prov_ids} 已从配置中删除。")

Expand All @@ -851,6 +853,7 @@ async def update_provider(self, origin_provider_id: str, new_config: dict) -> No
break
else:
raise ValueError(f"Provider ID {origin_provider_id} not found")
prune_fallback_chat_models(config)
config.save_config()
# reload instance
await self.reload(new_config)
Expand All @@ -867,6 +870,7 @@ async def create_provider(self, new_config: dict) -> None:
raise ValueError(f"Provider ID {npid} already exists")
# add to config
config["provider"].append(new_config)
prune_fallback_chat_models(config)
config.save_config()
# load instance
await self.load_provider(new_config)
Expand Down
29 changes: 29 additions & 0 deletions astrbot/dashboard/routes/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from astrbot.core.core_lifecycle import AstrBotCoreLifecycle
from astrbot.core.platform.register import platform_cls_map, platform_registry
from astrbot.core.provider import Provider
from astrbot.core.provider.fallbacks import prune_fallback_chat_models
from astrbot.core.provider.register import provider_registry
from astrbot.core.star.star import StarMetadata, star_registry
from astrbot.core.utils.astrbot_path import (
Expand Down Expand Up @@ -327,9 +328,31 @@ def save_config(
if errors:
raise ValueError(f"格式校验未通过: {errors}")

if is_core:
prune_fallback_chat_models(post_config)

config.save_config(post_config)


def _provider_config_enabled(provider_config: dict) -> bool:
return provider_config.get("enable") is not False


def _provider_config_selectable(
provider_config: dict,
provider_type_ls: list[str],
provider_inst_map: dict,
) -> bool:
if not _provider_config_enabled(provider_config):
return False
provider_id = provider_config.get("id")
if not isinstance(provider_id, str):
return False
if provider_config.get("provider_type") == "agent_runner":
return True
return provider_id in provider_inst_map


class ConfigRoute(Route):
def __init__(
self,
Expand Down Expand Up @@ -783,6 +806,12 @@ async def get_provider_config_list(self):
for psrc in self.core_lifecycle.provider_manager.provider_sources_config
}
for provider in ps:
if not _provider_config_selectable(
provider,
provider_type_ls,
self.core_lifecycle.provider_manager.inst_map,
):
continue
ps_id = provider.get("provider_source_id", None)
if (
ps_id
Expand Down
56 changes: 56 additions & 0 deletions tests/unit/test_config_route_provider_list.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from astrbot.dashboard.routes.config import (
_provider_config_enabled,
_provider_config_selectable,
)


def test_provider_config_enabled_excludes_only_explicitly_disabled_provider():
assert _provider_config_enabled({"id": "enabled-provider", "enable": True})
assert _provider_config_enabled({"id": "legacy-provider-without-enable"})
assert not _provider_config_enabled({"id": "disabled-provider", "enable": False})


def test_provider_config_selectable_requires_loaded_non_agent_provider():
inst_map = {"loaded-provider": object()}

assert _provider_config_selectable(
{"id": "loaded-provider", "enable": True},
["chat_completion"],
inst_map,
)
assert not _provider_config_selectable(
{"enable": True},
["chat_completion"],
inst_map,
)
assert not _provider_config_selectable(
{"id": 123, "enable": True},
["chat_completion"],
inst_map,
)
assert not _provider_config_selectable(
Comment thread
Clhikari marked this conversation as resolved.
{"id": "unloaded-provider", "enable": True},
["chat_completion"],
inst_map,
)
assert not _provider_config_selectable(
{"id": "disabled-provider", "enable": False},
["chat_completion"],
inst_map,
)


def test_provider_config_selectable_keeps_agent_runner_configs():
assert _provider_config_selectable(
{"id": "agent-runner", "enable": True, "provider_type": "agent_runner"},
["agent_runner"],
{},
)
Comment thread
Clhikari marked this conversation as resolved.


def test_provider_config_selectable_does_not_bypass_loaded_check_for_mixed_types():
assert not _provider_config_selectable(
{"id": "unloaded-provider", "enable": True, "provider_type": "chat_completion"},
["agent_runner", "chat_completion"],
{},
)
118 changes: 118 additions & 0 deletions tests/unit/test_provider_fallbacks.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
from astrbot.core.provider.fallbacks import prune_fallback_chat_models


def test_prune_fallback_chat_models_removes_disabled_missing_and_duplicate_ids():
config = {
"provider": [
{
"id": "enabled-chat",
"provider_type": "chat_completion",
"enable": True,
},
{
"id": "disabled-chat",
"provider_type": "chat_completion",
"enable": False,
},
{
"id": "tts-provider",
"provider_type": "text_to_speech",
"enable": True,
},
],
"provider_settings": {
"fallback_chat_models": [
"enabled-chat",
"disabled-chat",
"missing-chat",
"tts-provider",
"enabled-chat",
"",
],
},
}

removed_ids = prune_fallback_chat_models(config)

assert config["provider_settings"]["fallback_chat_models"] == ["enabled-chat"]
assert removed_ids == [
"disabled-chat",
"missing-chat",
"tts-provider",
"enabled-chat",
"",
]


def test_prune_fallback_chat_models_ignores_non_list_setting():
config = {
"provider": [{"id": "enabled-chat", "enable": True}],
"provider_settings": {"fallback_chat_models": "enabled-chat"},
}

removed_ids = prune_fallback_chat_models(config)

assert config["provider_settings"]["fallback_chat_models"] == "enabled-chat"
assert removed_ids == []


def test_prune_fallback_chat_models_uses_provider_source_type():
config = {
"provider_sources": [
{"id": "chat-source", "provider_type": "chat_completion"},
{"id": "legacy-chat-source"},
{"id": "empty-chat-source", "provider_type": ""},
{"id": "embedding-source", "provider_type": "embedding"},
],
"provider": [
{
"id": "chat-model",
"provider_source_id": "chat-source",
"enable": True,
},
{
"id": "legacy-chat-model",
"provider_source_id": "legacy-chat-source",
"enable": True,
},
{
"id": "empty-chat-model",
"provider_source_id": "empty-chat-source",
"enable": True,
},
{
"id": "inline-legacy-chat",
"enable": True,
},
{
"id": "embedding-model",
"provider_source_id": "embedding-source",
"enable": True,
},
{
"id": "missing-source-model",
"provider_source_id": "missing-source",
"enable": True,
},
],
"provider_settings": {
"fallback_chat_models": [
"chat-model",
"legacy-chat-model",
"empty-chat-model",
"inline-legacy-chat",
"embedding-model",
"missing-source-model",
],
},
}

removed_ids = prune_fallback_chat_models(config)

assert config["provider_settings"]["fallback_chat_models"] == [
"chat-model",
"legacy-chat-model",
"empty-chat-model",
"inline-legacy-chat",
]
assert removed_ids == ["embedding-model", "missing-source-model"]
Loading