Skip to content
119 changes: 84 additions & 35 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,14 @@ def _init_safety_settings(self) -> None:
and threshold_str in self.THRESHOLD_MAPPING
]

def _supports_multi_tool(self, model_name: str) -> bool:
    """Return whether the model can mix built-in tools with custom functions.

    Known legacy generations (Gemini 1.x and 2.x) are reported as
    unsupported. Every other model name, including Gemini 3.0 and anything
    newer, is assumed to support multi-tool orchestration.

    Args:
        model_name: Model identifier, e.g. ``"gemini-2.5-flash"``.

    Returns:
        ``False`` for Gemini 1.x / 2.x model names, ``True`` otherwise.
    """
    import re  # local import keeps this method self-contained

    # The major-version digit must not be followed by another digit: a plain
    # substring test ("gemini-1" in name) would also classify future
    # multi-digit generations such as "gemini-10" as legacy models.
    if re.search(r"gemini-[12](?!\d)", model_name):
        return False
    # Default: Gemini 3.0 and every newer model is assumed to be supported.
    return True

async def _handle_api_error(self, e: APIError, keys: list[str]) -> bool:
"""处理API错误,返回是否需要重试"""
if e.message is None:
Expand Down Expand Up @@ -227,28 +235,46 @@ async def _prepare_query_config(
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包",
)

supports_multi_tool = self._supports_multi_tool(model_name)
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

issue (complexity): Consider flattening the multi-tool handling logic, making tool_config construction more explicit, and simplifying native_tool_enabled computation to reduce branching and mental overhead.

You can keep the new behavior but simplify some of the branching to reduce mental overhead.

1. Flatten multi-tool + plugin tool list construction

You can separate capability checks from list construction to avoid nested if/else and repeated tool_list fiddling:

supports_multi_tool = self._supports_multi_tool(model_name)
func_desc = tools.get_func_desc_google_genai_style() if tools else None

has_native_tools = bool(tool_list)
has_plugin_tools = bool(func_desc)

if has_plugin_tools:
    if has_native_tools and not supports_multi_tool:
        logger.warning(
            f"模型 {model_name} 不支持多工具混合编排。已启用原生工具,函数工具(本地插件)将被忽略。"
        )
    else:
        if tool_list is None:
            tool_list = []
        tool_list.append(
            types.Tool(function_declarations=func_desc["function_declarations"])
        )

This keeps all behavior but flattens the conditions and makes the combinations “native vs plugin vs capability” easier to read.

2. Make tool_config construction explicit

The kwargs_tool_config dict is only used to add one optional field. You can keep the same behavior with a clearer, explicit keyword-argument construction:

has_func_decl = tool_list and any(t.function_declarations for t in tool_list)
tool_config = None

if has_func_decl:
    has_builtin_tools = any(
        getattr(t, "google_search", None)
        or getattr(t, "code_execution", None)
        or getattr(t, "url_context", None)
        for t in tool_list
    )

    fc_config = types.FunctionCallingConfig(
        mode=(
            types.FunctionCallingConfigMode.ANY
            if tool_choice == "required"
            else types.FunctionCallingConfigMode.AUTO
        )
    )

    if supports_multi_tool and has_builtin_tools:
        tool_config = types.ToolConfig(
            function_calling_config=fc_config,
            include_server_side_tool_invocations=True,
        )
    else:
        tool_config = types.ToolConfig(function_calling_config=fc_config)

This removes the indirection of kwargs_tool_config while keeping all the new flags.

3. Simplify native_tool_enabled derivation

The relationship between supports_multi_tool and native_tool_enabled can be expressed as a single expression:

model_name = cast(str, payloads.get("model", self.get_model()))
supports_multi_tool = self._supports_multi_tool(model_name)

native_tool_enabled = (
    not supports_multi_tool
    and (
        self.provider_config.get("gm_native_coderunner", False)
        or self.provider_config.get("gm_native_search", False)
    )
)

This avoids the temporary False assignment and the subsequent if block, while preserving behavior and making later if native_tool_enabled checks easier to reason about.


if tools and (func_desc := tools.get_func_desc_google_genai_style()):
if tool_list and not supports_multi_tool:
logger.warning(
f"模型 {model_name} 不支持多工具混合编排。已启用原生工具,函数工具(本地插件)将被忽略。"
)
else:
if tool_list is None:
tool_list = []
tool_list.append(
types.Tool(
function_declarations=func_desc["function_declarations"]
),
)

if not tool_list:
tool_list = None

if tools and tool_list:
logger.warning("已启用原生工具,函数工具将被忽略")
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
tool_list = [
types.Tool(function_declarations=func_desc["function_declarations"]),
]

tool_config = None
has_func_decl = tool_list and any(t.function_declarations for t in tool_list)
if has_func_decl:
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
has_builtin_tools = tool_list and any(
getattr(t, "google_search", None)
or getattr(t, "code_execution", None)
or getattr(t, "url_context", None)
for t in tool_list
)
kwargs_tool_config = {
"function_calling_config": types.FunctionCallingConfig(
mode=(
types.FunctionCallingConfigMode.ANY
if tool_choice == "required"
else types.FunctionCallingConfigMode.AUTO
)
)
)
}
if supports_multi_tool and has_builtin_tools:
kwargs_tool_config["include_server_side_tool_invocations"] = True
tool_config = types.ToolConfig(**kwargs_tool_config)

# oper thinking config
thinking_config = None
Expand Down Expand Up @@ -351,12 +377,16 @@ def append_or_extend(
contents.append(content_cls(parts=part))

gemini_contents: list[types.Content] = []
native_tool_enabled = any(
[
self.provider_config.get("gm_native_coderunner", False),
self.provider_config.get("gm_native_search", False),
],
)
model_name = cast(str, payloads.get("model", self.get_model()))
supports_multi_tool = self._supports_multi_tool(model_name)
native_tool_enabled = False
if not supports_multi_tool:
native_tool_enabled = any(
[
self.provider_config.get("gm_native_coderunner", False),
self.provider_config.get("gm_native_search", False),
],
)
for message in payloads["messages"]:
role, content = message["role"], message.get("content")

Expand All @@ -379,11 +409,10 @@ def append_or_extend(
append_or_extend(gemini_contents, parts, types.UserContent)

elif role == "assistant":
if isinstance(content, str):
parts = [types.Part.from_text(text=content)]
append_or_extend(gemini_contents, parts, types.ModelContent)
parts = []
if isinstance(content, str) and content:
parts.append(types.Part.from_text(text=content))
elif isinstance(content, list):
parts = []
thinking_signature = None
text = ""
for part in content:
Expand All @@ -408,14 +437,24 @@ def append_or_extend(
thought_signature=thinking_signature,
)
)
append_or_extend(gemini_contents, parts, types.ModelContent)

elif not native_tool_enabled and "tool_calls" in message:
parts = []
if (
not native_tool_enabled
and "tool_calls" in message
and message["tool_calls"]
):
for tool in message["tool_calls"]:
part = types.Part.from_function_call(
name=tool["function"]["name"],
args=json.loads(tool["function"]["arguments"]),
func_name = tool["function"]["name"]
tool_id = tool.get("id")
# 仅当 ID 不是本地伪造的函数名本身时,才进行传递
fc_id = tool_id if tool_id and tool_id != func_name else None

part = types.Part(
function_call=types.FunctionCall(
name=func_name,
args=json.loads(tool["function"]["arguments"]),
id=fc_id,
)
)
# we should set thought_signature back to part if exists
# for more info about thought_signature, see:
Expand All @@ -429,24 +468,34 @@ def append_or_extend(
if ts_bs64:
part.thought_signature = base64.b64decode(ts_bs64)
parts.append(part)
append_or_extend(gemini_contents, parts, types.ModelContent)
else:

if not parts:
logger.warning("assistant 角色的消息内容为空,已添加空格占位")
if native_tool_enabled and "tool_calls" in message:
logger.warning(
"检测到启用Gemini原生工具,且上下文中存在函数调用,建议使用 /reset 重置上下文",
)
parts = [types.Part.from_text(text=" ")]
append_or_extend(gemini_contents, parts, types.ModelContent)

append_or_extend(gemini_contents, parts, types.ModelContent)

elif role == "tool" and not native_tool_enabled:
func_name = message.get("name", message["tool_call_id"])
part = types.Part.from_function_response(
name=func_name,
response={
"name": func_name,
"content": message["content"],
},
tool_call_id = message.get("tool_call_id")
# 仅当 ID 不是本地伪造的函数名本身时,才进行传递
fr_id = (
tool_call_id if tool_call_id and tool_call_id != func_name else None
)

part = types.Part(
function_response=types.FunctionResponse(
name=func_name,
response={
"name": func_name,
"content": message["content"],
},
id=fr_id,
)
)

parts = [part]
Expand Down
Loading