Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
339 changes: 337 additions & 2 deletions lib/crewai/src/crewai/llms/providers/openai_compatible/completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,19 @@
from __future__ import annotations

from dataclasses import dataclass, field
import json
import logging
import os
from typing import Any
from typing import TYPE_CHECKING, Any

from pydantic import model_validator
from pydantic import BaseModel, model_validator

from crewai.llms.providers.openai.completion import OpenAICompletion
from crewai.utilities.types import LLMMessage


if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool


@dataclass(frozen=True)
Expand All @@ -32,6 +39,7 @@ class ProviderConfig:
default_headers: HTTP headers to include in all requests.
api_key_required: Whether an API key is required for this provider.
default_api_key: Default API key to use if none is provided and not required.
supports_json_schema: Whether the provider supports json_schema response_format type.
"""

base_url: str
Expand All @@ -40,6 +48,7 @@ class ProviderConfig:
default_headers: dict[str, str] = field(default_factory=dict)
api_key_required: bool = True
default_api_key: str | None = None
supports_json_schema: bool = True


OPENAI_COMPATIBLE_PROVIDERS: dict[str, ProviderConfig] = {
Expand All @@ -55,6 +64,7 @@ class ProviderConfig:
api_key_env="DEEPSEEK_API_KEY",
base_url_env="DEEPSEEK_BASE_URL",
api_key_required=True,
supports_json_schema=False,
),
"ollama": ProviderConfig(
base_url="http://localhost:11434/v1",
Expand Down Expand Up @@ -250,6 +260,331 @@ def _resolve_headers(

return merged if merged else None

@property
def _provider_supports_json_schema(self) -> bool:
"""Check if the current provider supports json_schema response_format."""
config = OPENAI_COMPATIBLE_PROVIDERS.get(self.provider)
if config is None:
return True
return config.supports_json_schema

def _prepare_completion_params(
self,
messages: list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
) -> dict[str, Any]:
"""Prepare params, stripping json_schema response_format if unsupported."""
params = super()._prepare_completion_params(messages, tools)

if not self._provider_supports_json_schema:
rf = params.get("response_format")
if isinstance(rf, dict) and rf.get("type") == "json_schema":
schema_info = rf.get("json_schema", {})
schema = schema_info.get("schema", schema_info)
self._inject_schema_instructions(params, schema)
del params["response_format"]

return params

def _inject_schema_instructions(
self,
params: dict[str, Any],
schema: dict[str, Any],
) -> None:
"""Inject JSON schema instructions into the system message."""
schema_str = json.dumps(schema, indent=2)
instruction = (
"\nYou must respond with a valid JSON object that conforms to this JSON schema:\n"
f"```json\n{schema_str}\n```\n"
"Respond ONLY with valid JSON, no additional text or markdown."
)
msgs = params.get("messages", [])
for msg in msgs:
if msg.get("role") == "system":
msg["content"] = (msg.get("content") or "") + instruction
return
params["messages"] = [
{"role": "system", "content": instruction.lstrip()},
*msgs,
]

@staticmethod
def _extract_json_from_text(text: str) -> str:
"""Extract JSON from text that may be wrapped in markdown code blocks."""
stripped = text.strip()
if stripped.startswith("```"):
lines = stripped.split("\n")
json_lines: list[str] = []
in_block = False
for line in lines:
if line.strip().startswith("```"):
if in_block:
break
in_block = True
continue
if in_block:
json_lines.append(line)
return "\n".join(json_lines).strip()
return stripped

def _handle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle completion, falling back for providers without json_schema."""
if response_model and not self._provider_supports_json_schema:
return self._handle_completion_fallback(
params, available_functions, from_task, from_agent, response_model
)
return super()._handle_completion(
params, available_functions, from_task, from_agent, response_model
)

def _handle_completion_fallback(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle structured output via prompt injection instead of json_schema."""
from crewai.events.types.llm_events import LLMCallType

schema_dict = response_model.model_json_schema() if response_model else {}
modified_params = dict(params)
modified_params.pop("response_format", None)

self._inject_schema_instructions(modified_params, schema_dict)

response = self._get_sync_client().chat.completions.create(**modified_params)

usage = self._extract_openai_token_usage(response)
self._track_token_usage_internal(usage)

message = response.choices[0].message

if message.tool_calls and not available_functions:
self._emit_call_completed_event(
response=list(message.tool_calls),
call_type=LLMCallType.TOOL_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage,
)
return list(message.tool_calls)

content = message.content or ""
if response_model:
try:
json_content = self._extract_json_from_text(content)
parsed = response_model.model_validate_json(json_content)
self._emit_call_completed_event(
response=parsed.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage,
)
return parsed
except Exception as e:
logging.warning(
f"Structured output parsing failed, returning raw content: {e}"
)

content = self._apply_stop_words(content)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage,
)
return content

async def _ahandle_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async completion, falling back for providers without json_schema."""
if response_model and not self._provider_supports_json_schema:
return await self._ahandle_completion_fallback(
params, available_functions, from_task, from_agent, response_model
)
return await super()._ahandle_completion(
params, available_functions, from_task, from_agent, response_model
)

async def _ahandle_completion_fallback(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle async structured output via prompt injection instead of json_schema."""
from crewai.events.types.llm_events import LLMCallType

schema_dict = response_model.model_json_schema() if response_model else {}
modified_params = dict(params)
modified_params.pop("response_format", None)

self._inject_schema_instructions(modified_params, schema_dict)

response = await self._get_async_client().chat.completions.create(
**modified_params
)

usage = self._extract_openai_token_usage(response)
self._track_token_usage_internal(usage)

message = response.choices[0].message

if message.tool_calls and not available_functions:
self._emit_call_completed_event(
response=list(message.tool_calls),
call_type=LLMCallType.TOOL_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage,
)
return list(message.tool_calls)

content = message.content or ""
if response_model:
try:
json_content = self._extract_json_from_text(content)
parsed = response_model.model_validate_json(json_content)
self._emit_call_completed_event(
response=parsed.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage,
)
return parsed
except Exception as e:
logging.warning(
f"Structured output parsing failed, returning raw content: {e}"
)

content = self._apply_stop_words(content)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage,
)
return content

def _handle_streaming_completion(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | list[dict[str, Any]] | Any:
"""Handle streaming completion, falling back for providers without json_schema."""
if response_model and not self._provider_supports_json_schema:
return self._handle_streaming_completion_fallback(
params, available_functions, from_task, from_agent, response_model
)
return super()._handle_streaming_completion(
params, available_functions, from_task, from_agent, response_model
)

def _handle_streaming_completion_fallback(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle streaming structured output via prompt injection."""
from crewai.events.types.llm_events import LLMCallType

schema_dict = response_model.model_json_schema() if response_model else {}
modified_params = dict(params)
modified_params.pop("response_format", None)

self._inject_schema_instructions(modified_params, schema_dict)

full_response = ""
usage_data: dict[str, Any] | None = None

completion_stream = self._get_sync_client().chat.completions.create(
**modified_params
)

for chunk in completion_stream:
response_id_stream = chunk.id if hasattr(chunk, "id") else None

if hasattr(chunk, "usage") and chunk.usage:
usage_data = self._extract_openai_token_usage(chunk)
continue

if not chunk.choices:
continue

choice = chunk.choices[0]
delta = choice.delta

if delta.content:
full_response += delta.content
self._emit_stream_chunk_event(
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
response_id=response_id_stream,
)

if usage_data:
self._track_token_usage_internal(usage_data)

if response_model:
try:
json_content = self._extract_json_from_text(full_response)
parsed = response_model.model_validate_json(json_content)
self._emit_call_completed_event(
response=parsed.model_dump_json(),
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage_data,
)
return parsed
except Exception as e:
logging.warning(f"Structured output parsing failed in stream: {e}")

self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=modified_params["messages"],
usage=usage_data,
)
return full_response

def supports_function_calling(self) -> bool:
"""Check if the provider supports function calling.

Expand Down
Loading
Loading