Skip to content
137 changes: 126 additions & 11 deletions astrbot/core/provider/sources/gemini_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -227,18 +227,58 @@ async def _prepare_query_config(
"当前 SDK 版本不支持 URL 上下文工具,已忽略该设置,请升级 google-genai 包",
)

# 将自定义工具追加进 tool_list
if tools and (func_desc := tools.get_func_desc_google_genai_style()):
if tool_list is None:
tool_list = []
tool_list.append(
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
types.Tool(function_declarations=func_desc["function_declarations"])
)

if not tool_list:
tool_list = None

if tools and tool_list:
logger.warning("已启用原生工具,函数工具将被忽略")
elif tools and (func_desc := tools.get_func_desc_google_genai_style()):
tool_list = [
types.Tool(function_declarations=func_desc["function_declarations"]),
]
# 1. 检查是否已经存在原生工具(搜索/代码执行/URL上下文)
has_native_before = tool_list and any(
getattr(t, "google_search", None)
or getattr(t, "code_execution", None)
or getattr(t, "url_context", None)
for t in tool_list
)

# 2. 判断是否为 Gemini 3 或更新的模型
is_gemini_3_or_later = any(
model_name.startswith(p) for p in ("gemini-3-", "gemini-3.")
)

# 3. 追加自定义工具的逻辑(全文件仅保留这一处)
if tools and (func_desc := tools.get_func_desc_google_genai_style()):
# 如果是老模型且已开启原生工具,强制忽略自定义插件,保障不崩溃
if not is_gemini_3_or_later and has_native_before:
logger.warning(
"当前模型不支持多工具混合编排。已启用原生工具,自定义函数工具将被忽略"
)
else:
if tool_list is None:
tool_list = []
tool_list.append(
types.Tool(function_declarations=func_desc["function_declarations"])
)

if not tool_list:
tool_list = None

tool_config = None
has_func_decl = tool_list and any(t.function_declarations for t in tool_list)

# 4. 再次确认最终工具链
has_native_tool = tool_list and any(
getattr(t, "google_search", None)
or getattr(t, "code_execution", None)
or getattr(t, "url_context", None)
for t in tool_list
)

if has_func_decl:
tool_config = types.ToolConfig(
function_calling_config=types.FunctionCallingConfig(
Expand All @@ -247,7 +287,11 @@ async def _prepare_query_config(
if tool_choice == "required"
else types.FunctionCallingConfigMode.AUTO
)
)
),
# 仅针对 Gemini 3+ 模型注入混合编排开关
include_server_side_tool_invocations=True
if (has_native_tool and is_gemini_3_or_later)
else None,
)

# oper thinking config
Expand Down Expand Up @@ -357,6 +401,13 @@ def append_or_extend(
self.provider_config.get("gm_native_search", False),
],
)

# 判断当前请求的模型是否为 Gemini 3+
model_name = payloads.get("model", self.get_model())
is_gemini_3_or_later = any(
str(model_name).startswith(p) for p in ("gemini-3-", "gemini-3.")
)

for message in payloads["messages"]:
role, content = message["role"], message.get("content")

Expand Down Expand Up @@ -410,13 +461,40 @@ def append_or_extend(
)
append_or_extend(gemini_contents, parts, types.ModelContent)

elif not native_tool_enabled and "tool_calls" in message:
# 只有 Gemini 3 系列及以后或未开启原生工具时,才允许还原函数调用历史
elif (
is_gemini_3_or_later or not native_tool_enabled
) and "tool_calls" in message:
parts = []
for tool in message["tool_calls"]:
# 兼容历史或异常日志中的非 JSON arguments,避免重放工具历史时报错
raw_args = tool.get("function", {}).get("arguments")
parsed_args = None
if isinstance(raw_args, (dict, list)):
parsed_args = raw_args
else:
try:
parsed_args = (
json.loads(raw_args)
if raw_args is not None
else None
Comment thread
sourcery-ai[bot] marked this conversation as resolved.
Outdated
)
except (TypeError, json.JSONDecodeError):
parsed_args = raw_args

part = types.Part.from_function_call(
name=tool["function"]["name"],
args=json.loads(tool["function"]["arguments"]),
args=parsed_args,
)

# 还原 Assistant 历史消息里工具调用的唯一 ID,并用 hasattr 确保向后兼容性
if (
"id" in tool
and part.function_call
and hasattr(part.function_call, "id")
):
part.function_call.id = tool["id"]

# we should set thought_signature back to part if exists
# for more info about thought_signature, see:
# https://ai.google.dev/gemini-api/docs/thought-signatures
Expand All @@ -439,8 +517,17 @@ def append_or_extend(
parts = [types.Part.from_text(text=" ")]
append_or_extend(gemini_contents, parts, types.ModelContent)

elif role == "tool" and not native_tool_enabled:
func_name = message.get("name", message["tool_call_id"])
# 移除了 and not native_tool_enabled 限制
elif role == "tool" and (is_gemini_3_or_later or not native_tool_enabled):
func_name = message.get("name") or message.get("tool_call_id")
tool_call_id = message.get("tool_call_id")

if not func_name:
logger.warning(
"跳过一条缺失 name 和 tool_call_id 的非法工具响应记录"
)
continue

part = types.Part.from_function_response(
name=func_name,
response={
Expand All @@ -449,6 +536,14 @@ def append_or_extend(
},
)

# 使用 hasattr 检查保护此赋值,以确保向后兼容性
if (
tool_call_id
and part.function_response
and hasattr(part.function_response, "id")
):
part.function_response.id = tool_call_id

parts = [part]
append_or_extend(gemini_contents, parts, types.UserContent)

Expand Down Expand Up @@ -756,6 +851,26 @@ async def _query_stream(
llm_response,
validate_output=False,
)

# 如果在这个 chunk 之前已经有流式文本被累积了,则把它强行塞回消息链的最前端
if (
accumulated_text
and llm_response.result_chain
and hasattr(llm_response.result_chain, "chain")
):
llm_response.result_chain.chain.insert(
0, Comp.Plain(accumulated_text)
)

# 同样,如果之前已经累积了推理(思考)文本,也需要保留,确保对话历史中不会中断丢失
if accumulated_reasoning:
if llm_response.reasoning_content:
llm_response.reasoning_content = (
accumulated_reasoning + llm_response.reasoning_content
)
else:
llm_response.reasoning_content = accumulated_reasoning

llm_response.id = chunk.response_id
if chunk.usage_metadata:
llm_response.usage = self._extract_usage(chunk.usage_metadata)
Expand Down
Loading