diff --git a/astrbot/core/provider/fallbacks.py b/astrbot/core/provider/fallbacks.py new file mode 100644 index 0000000000..eaaa6f0318 --- /dev/null +++ b/astrbot/core/provider/fallbacks.py @@ -0,0 +1,82 @@ +from __future__ import annotations + +from astrbot.core import logger + + +def _get_provider_source_types(config: dict) -> dict[str, str]: + provider_source_types: dict[str, str] = {} + for provider_source in config.get("provider_sources", []): + if not isinstance(provider_source, dict): + continue + provider_source_id = provider_source.get("id") + if not isinstance(provider_source_id, str) or not provider_source_id: + continue + provider_source_types[provider_source_id] = ( + provider_source.get("provider_type") or "chat_completion" + ) + return provider_source_types + + +def get_enabled_chat_provider_ids(config: dict) -> set[str]: + """Return provider IDs that can be used as chat fallback providers.""" + provider_ids: set[str] = set() + provider_source_types = _get_provider_source_types(config) + for provider in config.get("provider", []): + if not isinstance(provider, dict): + continue + provider_id = provider.get("id") + if not isinstance(provider_id, str) or not provider_id: + continue + if provider.get("enable") is False: + continue + provider_type = provider.get("provider_type") + provider_source_id = provider.get("provider_source_id") + if not provider_type: + if provider_source_id: + if not isinstance(provider_source_id, str): + continue + provider_type = provider_source_types.get(provider_source_id) + else: + provider_type = "chat_completion" + if provider_type != "chat_completion": + continue + provider_ids.add(provider_id) + return provider_ids + + +def prune_fallback_chat_models(config: dict) -> list[str]: + """Drop stale or disabled provider IDs from provider_settings.fallback_chat_models.""" + provider_settings = config.get("provider_settings") + if not isinstance(provider_settings, dict): + return [] + + fallback_ids = provider_settings.get("fallback_chat_models") + if not isinstance(fallback_ids, list): + return [] + + valid_provider_ids = get_enabled_chat_provider_ids(config) + seen: set[str] = set() + pruned_fallback_ids: list[str] = [] + removed_fallback_ids: list[str] = [] + + for fallback_id in fallback_ids: + if not isinstance(fallback_id, str) or not fallback_id: + removed_fallback_ids.append(str(fallback_id)) + continue + if fallback_id in seen: + removed_fallback_ids.append(fallback_id) + continue + seen.add(fallback_id) + if fallback_id not in valid_provider_ids: + removed_fallback_ids.append(fallback_id) + continue + pruned_fallback_ids.append(fallback_id) + + if pruned_fallback_ids != fallback_ids: + provider_settings["fallback_chat_models"] = pruned_fallback_ids + logger.info( + "Removed stale fallback chat providers from config: %s", + removed_fallback_ids, + ) + + return removed_fallback_ids diff --git a/astrbot/core/provider/manager.py b/astrbot/core/provider/manager.py index 2544736e72..f28e7e90a2 100644 --- a/astrbot/core/provider/manager.py +++ b/astrbot/core/provider/manager.py @@ -12,6 +12,7 @@ from ..persona_mgr import PersonaManager from .entities import ProviderType +from .fallbacks import prune_fallback_chat_models from .provider import ( EmbeddingProvider, Provider, @@ -828,6 +829,7 @@ async def delete_provider( config["provider"] = [ prov for prov in config["provider"] if prov.get("id") != tpid ] + prune_fallback_chat_models(config) config.save_config() logger.info(f"Provider {target_prov_ids} 已从配置中删除。") @@ -851,6 +853,7 @@ async def update_provider(self, origin_provider_id: str, new_config: dict) -> No break else: raise ValueError(f"Provider ID {origin_provider_id} not found") + prune_fallback_chat_models(config) config.save_config() # reload instance await self.reload(new_config) @@ -867,6 +870,7 @@ async def create_provider(self, new_config: dict) -> None: raise ValueError(f"Provider ID {npid} already exists") # add to config config["provider"].append(new_config) + prune_fallback_chat_models(config) config.save_config() # load instance await self.load_provider(new_config) diff --git a/astrbot/dashboard/routes/config.py b/astrbot/dashboard/routes/config.py index 3c37a2b788..c80f1bb3ed 100644 --- a/astrbot/dashboard/routes/config.py +++ b/astrbot/dashboard/routes/config.py @@ -21,6 +21,7 @@ from astrbot.core.core_lifecycle import AstrBotCoreLifecycle from astrbot.core.platform.register import platform_cls_map, platform_registry from astrbot.core.provider import Provider +from astrbot.core.provider.fallbacks import prune_fallback_chat_models from astrbot.core.provider.register import provider_registry from astrbot.core.star.star import StarMetadata, star_registry from astrbot.core.utils.astrbot_path import ( @@ -327,9 +328,31 @@ def save_config( if errors: raise ValueError(f"格式校验未通过: {errors}") + if is_core: + prune_fallback_chat_models(post_config) + config.save_config(post_config) +def _provider_config_enabled(provider_config: dict) -> bool: + return provider_config.get("enable") is not False + + +def _provider_config_selectable( + provider_config: dict, + provider_type_ls: list[str], + provider_inst_map: dict, +) -> bool: + if not _provider_config_enabled(provider_config): + return False + provider_id = provider_config.get("id") + if not isinstance(provider_id, str): + return False + if provider_config.get("provider_type") == "agent_runner": + return True + return provider_id in provider_inst_map + + class ConfigRoute(Route): def __init__( self, @@ -783,6 +806,12 @@ async def get_provider_config_list(self): for psrc in self.core_lifecycle.provider_manager.provider_sources_config } for provider in ps: + if not _provider_config_selectable( + provider, + provider_type_ls, + self.core_lifecycle.provider_manager.inst_map, + ): + continue ps_id = provider.get("provider_source_id", None) if ( ps_id diff --git a/tests/unit/test_config_route_provider_list.py b/tests/unit/test_config_route_provider_list.py new file mode 100644 index 0000000000..91ad0ee2b7 --- /dev/null +++ b/tests/unit/test_config_route_provider_list.py @@ -0,0 +1,56 @@ +from astrbot.dashboard.routes.config import ( + _provider_config_enabled, + _provider_config_selectable, +) + + +def test_provider_config_enabled_excludes_only_explicitly_disabled_provider(): + assert _provider_config_enabled({"id": "enabled-provider", "enable": True}) + assert _provider_config_enabled({"id": "legacy-provider-without-enable"}) + assert not _provider_config_enabled({"id": "disabled-provider", "enable": False}) + + +def test_provider_config_selectable_requires_loaded_non_agent_provider(): + inst_map = {"loaded-provider": object()} + + assert _provider_config_selectable( + {"id": "loaded-provider", "enable": True}, + ["chat_completion"], + inst_map, + ) + assert not _provider_config_selectable( + {"enable": True}, + ["chat_completion"], + inst_map, + ) + assert not _provider_config_selectable( + {"id": 123, "enable": True}, + ["chat_completion"], + inst_map, + ) + assert not _provider_config_selectable( + {"id": "unloaded-provider", "enable": True}, + ["chat_completion"], + inst_map, + ) + assert not _provider_config_selectable( + {"id": "disabled-provider", "enable": False}, + ["chat_completion"], + inst_map, + ) + + +def test_provider_config_selectable_keeps_agent_runner_configs(): + assert _provider_config_selectable( + {"id": "agent-runner", "enable": True, "provider_type": "agent_runner"}, + ["agent_runner"], + {}, + ) + + +def test_provider_config_selectable_does_not_bypass_loaded_check_for_mixed_types(): + assert not _provider_config_selectable( + {"id": "unloaded-provider", "enable": True, "provider_type": "chat_completion"}, + ["agent_runner", "chat_completion"], + {}, + ) diff --git a/tests/unit/test_provider_fallbacks.py b/tests/unit/test_provider_fallbacks.py new file mode 100644 index 0000000000..43e096d1d9 --- /dev/null +++ b/tests/unit/test_provider_fallbacks.py @@ -0,0 +1,118 @@ +from astrbot.core.provider.fallbacks import prune_fallback_chat_models + + +def test_prune_fallback_chat_models_removes_disabled_missing_and_duplicate_ids(): + config = { + "provider": [ + { + "id": "enabled-chat", + "provider_type": "chat_completion", + "enable": True, + }, + { + "id": "disabled-chat", + "provider_type": "chat_completion", + "enable": False, + }, + { + "id": "tts-provider", + "provider_type": "text_to_speech", + "enable": True, + }, + ], + "provider_settings": { + "fallback_chat_models": [ + "enabled-chat", + "disabled-chat", + "missing-chat", + "tts-provider", + "enabled-chat", + "", + ], + }, + } + + removed_ids = prune_fallback_chat_models(config) + + assert config["provider_settings"]["fallback_chat_models"] == ["enabled-chat"] + assert removed_ids == [ + "disabled-chat", + "missing-chat", + "tts-provider", + "enabled-chat", + "", + ] + + +def test_prune_fallback_chat_models_ignores_non_list_setting(): + config = { + "provider": [{"id": "enabled-chat", "enable": True}], + "provider_settings": {"fallback_chat_models": "enabled-chat"}, + } + + removed_ids = prune_fallback_chat_models(config) + + assert config["provider_settings"]["fallback_chat_models"] == "enabled-chat" + assert removed_ids == [] + + +def test_prune_fallback_chat_models_uses_provider_source_type(): + config = { + "provider_sources": [ + {"id": "chat-source", "provider_type": "chat_completion"}, + {"id": "legacy-chat-source"}, + {"id": "empty-chat-source", "provider_type": ""}, + {"id": "embedding-source", "provider_type": "embedding"}, + ], + "provider": [ + { + "id": "chat-model", + "provider_source_id": "chat-source", + "enable": True, + }, + { + "id": "legacy-chat-model", + "provider_source_id": "legacy-chat-source", + "enable": True, + }, + { + "id": "empty-chat-model", + "provider_source_id": "empty-chat-source", + "enable": True, + }, + { + "id": "inline-legacy-chat", + "enable": True, + }, + { + "id": "embedding-model", + "provider_source_id": "embedding-source", + "enable": True, + }, + { + "id": "missing-source-model", + "provider_source_id": "missing-source", + "enable": True, + }, + ], + "provider_settings": { + "fallback_chat_models": [ + "chat-model", + "legacy-chat-model", + "empty-chat-model", + "inline-legacy-chat", + "embedding-model", + "missing-source-model", + ], + }, + } + + removed_ids = prune_fallback_chat_models(config) + + assert config["provider_settings"]["fallback_chat_models"] == [ + "chat-model", + "legacy-chat-model", + "empty-chat-model", + "inline-legacy-chat", + ] + assert removed_ids == ["embedding-model", "missing-source-model"]