diff --git a/agentplatform/_genai/_agent_engines_utils.py b/agentplatform/_genai/_agent_engines_utils.py index b269ba0833..dfa4f52cf5 100644 --- a/agentplatform/_genai/_agent_engines_utils.py +++ b/agentplatform/_genai/_agent_engines_utils.py @@ -17,6 +17,7 @@ import abc import asyncio import base64 +import dataclasses from importlib import metadata as importlib_metadata import inspect import io @@ -134,6 +135,37 @@ ClientFactory = None TaskIdParams = None TaskQueryParams = None +try: + from autogen.agentchat import chat + + AutogenChatResult = chat.ChatResult +except ImportError: + AutogenChatResult = Any +try: + from autogen.io import run_response + + AutogenRunResponse = run_response.RunResponse +except ImportError: + AutogenRunResponse = Any +try: + from llama_index.core.base.response import schema as llama_index_schema + from llama_index.core.base.llms import types as llama_index_types + + LlamaIndexResponse = llama_index_schema.Response + LlamaIndexBaseModel = llama_index_schema.BaseModel + LlamaIndexChatResponse = llama_index_types.ChatResponse +except ImportError: + LlamaIndexResponse = Any + LlamaIndexBaseModel = Any + LlamaIndexChatResponse = Any +try: + import pydantic + + BaseModel = pydantic.BaseModel +except ImportError: + BaseModel = Any + +JsonDict = Dict[str, Any] _ACTIONS_KEY = "actions" _ACTION_APPEND = "append" @@ -1994,3 +2026,235 @@ def _add_telemetry_enablement_env( return env_vars return env_vars | env_to_add + + +def _dataclass_to_dict_or_raise(obj: Any) -> Dict[str, Any]: + """Converts a dataclass to a JSON dictionary.""" + if not dataclasses.is_dataclass(obj): + raise TypeError(f"Object is not a dataclass: {obj}") + return json.loads(json.dumps(dataclasses.asdict(obj))) + + +def _autogen_run_response_protocol_to_dict( + obj: AutogenRunResponse, +) -> Dict[str, Any]: + """Converts an AutogenRunResponse object into a JSON-serializable dictionary.""" + if hasattr(obj, "process"): + obj.process() + last_speaker = None + if getattr(obj, "last_speaker", None) is not None: + last_speaker = { + "name": getattr(obj.last_speaker, "name", None), + "description": getattr(obj.last_speaker, "description", None), + } + cost = None + if getattr(obj, "cost", None) is not None: + if hasattr(obj.cost, "model_dump_json"): + cost = json.loads(obj.cost.model_dump_json()) + else: + cost = str(obj.cost) + result = { + "summary": getattr(obj, "summary", None), + "messages": list(getattr(obj, "messages", [])), + "context_variables": getattr(obj, "context_variables", None), + "last_speaker": last_speaker, + "cost": cost, + } + return json.loads(json.dumps(result)) + + +def to_json_serializable_autogen_object( + obj: Union[ + AutogenChatResult, + AutogenRunResponse, + ], +) -> Dict[str, Any]: + """Converts an Autogen object to a JSON serializable object.""" + if isinstance(obj, AutogenChatResult): + return _dataclass_to_dict_or_raise(obj) + return _autogen_run_response_protocol_to_dict(obj) + + +def _llama_index_response_to_dict(obj: LlamaIndexResponse) -> Any: + response = {} + if hasattr(obj, "response"): + response["response"] = obj.response + if hasattr(obj, "source_nodes"): + response["source_nodes"] = [node.model_dump_json() for node in obj.source_nodes] + if hasattr(obj, "metadata"): + response["metadata"] = obj.metadata + return json.loads(json.dumps(response)) + + +def _llama_index_chat_response_to_dict(obj: LlamaIndexChatResponse) -> Any: + return json.loads(obj.message.model_dump_json()) + + +def _llama_index_base_model_to_dict(obj: LlamaIndexBaseModel) -> Any: + return json.loads(obj.model_dump_json()) + + +def to_json_serializable_llama_index_object( + obj: Union[ + LlamaIndexResponse, + LlamaIndexBaseModel, + LlamaIndexChatResponse, + Sequence[LlamaIndexBaseModel], + ], +) -> Union[str, Dict[str, Any], Sequence[Union[str, Dict[str, Any]]]]: + """Converts a LlamaIndexResponse to a JSON serializable object.""" + if isinstance(obj, LlamaIndexResponse): + return _llama_index_response_to_dict(obj) + if isinstance(obj, LlamaIndexChatResponse): + return _llama_index_chat_response_to_dict(obj) + if isinstance(obj, Sequence): + seq_result = [] + for item in obj: + if isinstance(item, LlamaIndexBaseModel): + seq_result.append(_llama_index_base_model_to_dict(item)) + continue + seq_result.append(str(item)) + return seq_result + if isinstance(obj, LlamaIndexBaseModel): + return _llama_index_base_model_to_dict(obj) + return str(obj) + + +def is_noop_or_proxy_tracer_provider(tracer_provider) -> bool: + """Returns True if the tracer_provider is Proxy or NoOp.""" + opentelemetry = _import_opentelemetry_or_warn() + if not opentelemetry: + return False + ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider + NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider + return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider)) + + +def dump_event_for_json(event: BaseModel) -> Dict[str, Any]: + """Dumps an ADK event to a JSON-serializable dictionary.""" + return json.loads(event.model_dump_json(exclude_none=True)) + + +def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]: + """Tries to import the opentelemetry module.""" + try: + import opentelemetry + + return opentelemetry + except ImportError: + logger.warning( + "opentelemetry-api is not installed. Please call " + "'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + +def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]: + """Tries to import the opentelemetry.sdk.trace module.""" + try: + import opentelemetry.sdk.trace + + return opentelemetry.sdk.trace + except ImportError: + logger.warning( + "opentelemetry-sdk is not installed. Please call " + "'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + +def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]: + """Tries to import the google.cloud.trace_v2 module.""" + try: + import google.cloud.trace_v2 + + return google.cloud.trace_v2 + except ImportError: + logger.warning( + "google-cloud-trace is not installed. Please call " + "'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + +def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]: + """Tries to import the opentelemetry.exporter.cloud_trace module.""" + try: + import opentelemetry.exporter.cloud_trace + + return opentelemetry.exporter.cloud_trace + except ImportError: + logger.warning( + "opentelemetry-exporter-gcp-trace is not installed. Please " + "call 'pip install google-cloud-aiplatform[agent_engines]'." + ) + return None + + +def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]: + """Tries to import the openinference.instrumentation.langchain module.""" + try: + import openinference.instrumentation.langchain + + return openinference.instrumentation.langchain + except ImportError: + logger.warning( + "openinference-instrumentation-langchain is not installed. Please " + "call 'pip install google-cloud-aiplatform[langchain]'." + ) + return None + + +def _import_openinference_autogen_or_warn() -> Optional[types.ModuleType]: + """Tries to import the openinference.instrumentation.autogen module.""" + try: + import openinference.instrumentation.autogen + + return openinference.instrumentation.autogen + except ImportError: + logger.warning( + "openinference-instrumentation-autogen is not installed. Please " + "call 'pip install google-cloud-aiplatform[ag2]'." + ) + return None + + +def _import_openinference_llama_index_or_warn() -> Optional[types.ModuleType]: + """Tries to import the openinference.instrumentation.llama_index module.""" + try: + import openinference.instrumentation.llama_index # noqa:F401 + + return openinference.instrumentation.llama_index + except ImportError: + logger.warning( + "openinference-instrumentation-llama_index is not installed. Please " + "call 'pip install google-cloud-aiplatform[llama_index]'." + ) + return None + + +def _import_nest_asyncio_or_warn() -> Optional[types.ModuleType]: + """Tries to import the nest_asyncio module.""" + try: + import nest_asyncio + + return nest_asyncio + except ImportError: + logger.warning( + "nest_asyncio is not installed. Please call: `pip install nest-asyncio`" + ) + return None + + +def _import_autogen_tools_or_warn() -> Optional[types.ModuleType]: + """Tries to import the autogen.tools module.""" + try: + from autogen import tools + + return tools + except ImportError: + logger.warning( + "autogen.tools is not installed. Please " + "call `pip install google-cloud-aiplatform[ag2]`." + ) + return None diff --git a/agentplatform/agent_engines/_agent_engines.py b/agentplatform/agent_engines/_agent_engines.py index 48d44fa51e..dd2efb73c4 100644 --- a/agentplatform/agent_engines/_agent_engines.py +++ b/agentplatform/agent_engines/_agent_engines.py @@ -45,14 +45,15 @@ from google.cloud.aiplatform import utils as aip_utils from google.cloud.aiplatform_v1 import types as aip_types from google.cloud.aiplatform_v1.types import reasoning_engine_service -from agentplatform.agent_engines import _utils +from agentplatform._genai import _agent_engines_utils import httpx import proto from google.protobuf import field_mask_pb2 -_LOGGER = _utils.LOGGER +_LOGGER = base.Logger("agentplatform.agent_engines") + _SUPPORTED_PYTHON_VERSIONS = ("3.10", "3.11", "3.12", "3.13", "3.14") _DEFAULT_GCS_DIR_NAME = "agent_engine" _BLOB_FILENAME = "agent_engine.pkl" @@ -907,9 +908,9 @@ def delete( operation_future.result() _LOGGER.info(f"Agent Engine deleted. Resource name: {self.resource_name}") - def operation_schemas(self) -> Sequence[_utils.JsonDict]: + def operation_schemas(self) -> Sequence[_agent_engines_utils.JsonDict]: """Returns the (Open)API schemas for the Agent Engine.""" - spec = _utils.to_dict(self._gca_resource.spec) + spec = _agent_engines_utils.to_dict(self._gca_resource.spec) if not hasattr(self, "_operation_schemas") or self._operation_schemas is None: self._operation_schemas = spec.get("class_methods", []) return self._operation_schemas @@ -1159,7 +1160,7 @@ def _validate_requirements_or_raise( logger.info(f"Read the following lines: {requirements}") except IOError as err: raise IOError(f"Failed to read requirements from {requirements=}") from err - requirements = _utils.validate_requirements_or_warn( + requirements = _agent_engines_utils.validate_requirements_or_warn( obj=agent_engine, requirements=requirements, logger=logger, @@ -1175,7 +1176,7 @@ def _validate_extra_packages_or_raise( """Tries to validates the extra packages.""" extra_packages = extra_packages or [] if build_options and _BUILD_OPTIONS_INSTALLATION in build_options: - _utils.validate_installation_scripts_or_raise( + _agent_engines_utils.validate_installation_scripts_or_raise( script_paths=build_options[_BUILD_OPTIONS_INSTALLATION], extra_packages=extra_packages, ) @@ -1195,7 +1196,7 @@ def _get_gcs_bucket( logger: base.Logger = _LOGGER, ) -> storage.Bucket: """Gets or creates the GCS bucket.""" - storage = _utils._import_cloud_storage_or_raise() + storage = _agent_engines_utils._import_cloud_storage_or_raise() storage_client = storage.Client(project=project) staging_bucket = staging_bucket.replace("gs://", "") try: @@ -1216,7 +1217,7 @@ def _upload_agent_engine( logger: base.Logger = _LOGGER, ) -> None: """Uploads the agent engine to GCS.""" - cloudpickle = _utils._import_cloudpickle_or_raise() + cloudpickle = _agent_engines_utils._import_cloudpickle_or_raise() blob = gcs_bucket.blob(f"{gcs_dir_name}/{_BLOB_FILENAME}") with blob.open("wb") as f: try: @@ -1342,7 +1343,7 @@ def _update_deployment_spec_with_env_vars_dict_or_raise( for key, value in env_vars.items(): if isinstance(value, Dict): try: - secret_ref = _utils.to_proto(value, aip_types.SecretRef()) + secret_ref = _agent_engines_utils.to_proto(value, aip_types.SecretRef()) except Exception as e: raise ValueError(f"Failed to convert to secret ref: {value}") from e deployment_spec.secret_env.append( @@ -1546,7 +1547,9 @@ def _generate_update_request_or_raise( ) -def _wrap_query_operation(method_name: str) -> Callable[..., _utils.JsonDict]: +def _wrap_query_operation( + method_name: str, +) -> Callable[..., _agent_engines_utils.JsonDict]: """Wraps an Agent Engine method, creating a callable for `query` API. This function creates a callable object that executes the specified @@ -1562,7 +1565,7 @@ def _wrap_query_operation(method_name: str) -> Callable[..., _utils.JsonDict]: the `query` API. """ - def _method(self, **kwargs) -> _utils.JsonDict: + def _method(self, **kwargs) -> _agent_engines_utils.JsonDict: response = self.execution_api_client.query_reasoning_engine( request=aip_types.QueryReasoningEngineRequest( name=self.resource_name, @@ -1570,7 +1573,7 @@ def _method(self, **kwargs) -> _utils.JsonDict: class_method=method_name, ), ) - output = _utils.to_dict(response) + output = _agent_engines_utils.to_dict(response) return output.get("output", output) return _method @@ -1592,7 +1595,7 @@ def _wrap_async_query_operation(method_name: str) -> Callable[..., Coroutine]: the `query` API. """ - async def _method(self, **kwargs) -> _utils.JsonDict: + async def _method(self, **kwargs) -> _agent_engines_utils.JsonDict: response = await self.execution_async_client.query_reasoning_engine( request=aip_types.QueryReasoningEngineRequest( name=self.resource_name, @@ -1600,7 +1603,7 @@ async def _method(self, **kwargs) -> _utils.JsonDict: class_method=method_name, ), ) - output = _utils.to_dict(response) + output = _agent_engines_utils.to_dict(response) return output.get("output", output) return _method @@ -1631,7 +1634,7 @@ def _method(self, **kwargs) -> Iterable[Any]: ), ) for chunk in response: - for parsed_json in _utils.yield_parsed_json(chunk): + for parsed_json in _agent_engines_utils.yield_parsed_json(chunk): if parsed_json is not None: yield parsed_json @@ -1665,7 +1668,7 @@ async def _method(self, **kwargs) -> AsyncIterable[Any]: ), ) for chunk in response: - for parsed_json in _utils.yield_parsed_json(chunk): + for parsed_json in _agent_engines_utils.yield_parsed_json(chunk): if parsed_json is not None: yield parsed_json @@ -1822,7 +1825,7 @@ async def _method(self, **kwargs) -> Any: def _unregister_api_methods( - obj: "AgentEngine", operation_schemas: Sequence[_utils.JsonDict] + obj: "AgentEngine", operation_schemas: Sequence[_agent_engines_utils.JsonDict] ): """Unregisters Agent Engine API methods based on operation schemas. @@ -1990,12 +1993,14 @@ def _generate_class_methods_spec_or_raise( method = getattr(agent_engine, method_name) try: - schema_dict = _utils.generate_schema(method, schema_name=method_name) + schema_dict = _agent_engines_utils.generate_schema( + method, schema_name=method_name + ) except Exception as e: logger.warning(f"failed to generate schema for {method_name}: {e}") continue - class_method = _utils.to_proto(schema_dict) + class_method = _agent_engines_utils.to_proto(schema_dict) class_method[_MODE_KEY_IN_SCHEMA] = mode # A2A agent card is a special case, when running in A2A mode, if hasattr(agent_engine, "agent_card"): @@ -2013,4 +2018,6 @@ def _class_methods_to_class_methods_spec( class_methods: List[dict[str, Any]], ) -> List[proto.Message]: """Converts a list of class methods to a list of ReasoningEngineSpec.ClassMethod messages.""" - return [_utils.to_proto(class_method) for class_method in class_methods] + return [ + _agent_engines_utils.to_proto(class_method) for class_method in class_methods + ] diff --git a/agentplatform/agent_engines/_utils.py b/agentplatform/agent_engines/_utils.py deleted file mode 100644 index 4b5aeb6daf..0000000000 --- a/agentplatform/agent_engines/_utils.py +++ /dev/null @@ -1,936 +0,0 @@ -# -*- coding: utf-8 -*- -# Copyright 2023 Google LLC -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -import dataclasses -import inspect -import json -import os -import sys -import types -import typing -from typing import ( - Any, - Callable, - Dict, - Iterable, - List, - Mapping, - Optional, - Sequence, - Set, - TypedDict, - Union, -) -from importlib import metadata as importlib_metadata - -import proto - -from google.cloud.aiplatform import base -from google.api import httpbody_pb2 -from google.protobuf import struct_pb2 -from google.protobuf import json_format - -try: - # For LangChain templates, they might not import langchain_core and get - # PydanticUserError: `query` is not fully defined; you should define - # `RunnableConfig`, then call `query.model_rebuild()`. - import langchain_core.runnables.config - - RunnableConfig = langchain_core.runnables.config.RunnableConfig -except ImportError: - RunnableConfig = Any - -try: - import packaging - - SpecifierSet = packaging.specifiers.SpecifierSet -except AttributeError: - SpecifierSet = Any - -try: - _BUILTIN_MODULE_NAMES: Sequence[str] = sys.builtin_module_names -except AttributeError: - _BUILTIN_MODULE_NAMES: Sequence[str] = [] - -try: - # sys.stdlib_module_names is available from Python 3.10 onwards. - _STDLIB_MODULE_NAMES: frozenset = sys.stdlib_module_names -except AttributeError: - _STDLIB_MODULE_NAMES: frozenset = frozenset() - -try: - _PACKAGE_DISTRIBUTIONS: Mapping[str, Sequence[str]] = ( - importlib_metadata.packages_distributions() - ) - -except AttributeError: - _PACKAGE_DISTRIBUTIONS: Mapping[str, Sequence[str]] = {} - -try: - from autogen.agentchat import chat - - AutogenChatResult = chat.ChatResult -except ImportError: - AutogenChatResult = Any - -try: - from autogen.io import run_response - - AutogenRunResponse = run_response.RunResponse -except ImportError: - AutogenRunResponse = Any - -try: - import pydantic - - BaseModel = pydantic.BaseModel -except ImportError: - BaseModel = Any - -JsonDict = Dict[str, Any] - - -class _RequirementsValidationActions(TypedDict): - append: Set[str] - - -class _RequirementsValidationWarnings(TypedDict): - missing: Set[str] - incompatible: Set[str] - - -class _RequirementsValidationResult(TypedDict): - warnings: _RequirementsValidationWarnings - actions: _RequirementsValidationActions - - -LOGGER = base.Logger("agentplatform.agent_engines") - -_BASE_MODULES = set(_BUILTIN_MODULE_NAMES + tuple(_STDLIB_MODULE_NAMES)) -_DEFAULT_REQUIRED_PACKAGES = frozenset(["cloudpickle", "pydantic"]) -_ACTIONS_KEY = "actions" -_ACTION_APPEND = "append" -_WARNINGS_KEY = "warnings" -_WARNING_MISSING = "missing" -_WARNING_INCOMPATIBLE = "incompatible" -_INSTALLATION_SUBDIR = "installation_scripts" - - -def to_proto( - obj: Union[JsonDict, proto.Message], - message: Optional[proto.Message] = None, -) -> proto.Message: - """Parses a JSON-like object into a message. - - If the object is already a message, this will return the object as-is. If - the object is a JSON Dict, this will parse and merge the object into the - message. - - Args: - obj (Union[dict[str, Any], proto.Message]): - Required. The object to convert to a proto message. - message (proto.Message): - Optional. A protocol buffer message to merge the obj into. It - defaults to Struct() if unspecified. - - Returns: - proto.Message: The same message passed as argument. - """ - if message is None: - message = struct_pb2.Struct() - if isinstance(obj, (proto.Message, struct_pb2.Struct)): - return obj - try: - json_format.ParseDict(obj, message._pb) - except AttributeError: - json_format.ParseDict(obj, message) - return message - - -def to_dict(message: proto.Message) -> JsonDict: - """Converts the contents of the protobuf message to JSON format. - - Args: - message (proto.Message): - Required. The proto message to be converted to a JSON dictionary. - - Returns: - dict[str, Any]: A dictionary containing the contents of the proto. - """ - try: - # Best effort attempt to convert the message into a JSON dictionary. - result: JsonDict = json.loads( - json_format.MessageToJson( - message._pb, - preserving_proto_field_name=True, - ) - ) - except AttributeError: - result: JsonDict = json.loads( - json_format.MessageToJson( - message, - preserving_proto_field_name=True, - ) - ) - return result - - -def _dataclass_to_dict_or_raise(obj: Any) -> JsonDict: - """Converts a dataclass to a JSON dictionary. - - Args: - obj (Any): - Required. The dataclass to be converted to a JSON dictionary. - - Returns: - dict[str, Any]: A dictionary containing the contents of the dataclass. - - Raises: - TypeError: If the object is not a dataclass. - """ - if not dataclasses.is_dataclass(obj): - raise TypeError(f"Object is not a dataclass: {obj}") - return json.loads(json.dumps(dataclasses.asdict(obj))) - - -def _autogen_run_response_protocol_to_dict( - obj: AutogenRunResponse, -) -> JsonDict: - """Converts an AutogenRunResponse object into a JSON-serializable dictionary. - - This function takes a `RunResponseProtocol` object and transforms its - relevant attributes into a dictionary format suitable for JSON conversion. - - The `RunResponseProtocol` defines the structure of the response object, - which typically includes: - - * **summary** (`Optional[str]`): - A textual summary of the run. - * **messages** (`Iterable[Message]`): - A sequence of messages exchanged during the run. - Each message is expected to be a JSON-serializable dictionary (`Dict[str, - Any]`). - * **events** (`Iterable[BaseEvent]`): - A sequence of events that occurred during the run. - Note: The `process()` method, if present, is called before conversion, - which typically clears this event queue. - * **context_variables** (`Optional[dict[str, Any]]`): - A dictionary containing contextual variables from the run. - * **last_speaker** (`Optional[Agent]`): - The agent that produced the last message. - The `Agent` object has attributes like `name` (Optional[str]) and - `description` (Optional[str]). - * **cost** (`Optional[Cost]`): - Information about the computational cost of the run. - The `Cost` object inherits from `pydantic.BaseModel` and is converted - to JSON using its `model_dump_json()` method. - * **process** (`Optional[Callable[[], None]]`): - An optional function (like a console event processor) that is called - before the conversion takes place. - Executing this method often clears the `events` queue. - - For a detailed definition of `RunResponseProtocol` and its components, refer - to: https://github.com/ag2ai/ag2/blob/main/autogen/io/run_response.py - - Args: - obj (AutogenRunResponse): The AutogenRunResponse object to convert. This - object must conform to the `RunResponseProtocol`. - - Returns: - JsonDict: A dictionary representation of the AutogenRunResponse, ready - to be serialized into JSON. The dictionary includes keys like - 'summary', 'messages', 'context_variables', 'last_speaker_name', - and 'cost'. - """ - if hasattr(obj, "process"): - obj.process() - - last_speaker = None - if getattr(obj, "last_speaker", None) is not None: - last_speaker = { - "name": getattr(obj.last_speaker, "name", None), - "description": getattr(obj.last_speaker, "description", None), - } - - cost = None - if getattr(obj, "cost", None) is not None: - if hasattr(obj.cost, "model_dump_json"): - cost = json.loads(obj.cost.model_dump_json()) - else: - cost = str(obj.cost) - - result = { - "summary": getattr(obj, "summary", None), - "messages": list(getattr(obj, "messages", [])), - "context_variables": getattr(obj, "context_variables", None), - "last_speaker": last_speaker, - "cost": cost, - } - return json.loads(json.dumps(result)) - - -def to_json_serializable_autogen_object( - obj: Union[ - AutogenChatResult, - AutogenRunResponse, - ], -) -> JsonDict: - """Converts an Autogen object to a JSON serializable object. - - In `ag2<=0.8.4`, `.run()` will return a `ChatResult` object. - In `ag2>=0.8.5`, `.run()` will return a `RunResponse` object. - - Args: - obj (Union[AutogenChatResult, AutogenRunResponse]): - Required. The Autogen object to be converted to a JSON serializable - object. - - Returns: - JsonDict: A JSON serializable object. - """ - if isinstance(obj, AutogenChatResult): - return _dataclass_to_dict_or_raise(obj) - return _autogen_run_response_protocol_to_dict(obj) - - -def yield_parsed_json(body: httpbody_pb2.HttpBody) -> Iterable[Any]: - """Converts the contents of the httpbody message to JSON format. - - Args: - body (httpbody_pb2.HttpBody): - Required. The httpbody body to be converted to a JSON. - - Yields: - Any: A JSON object or the original body if it is not JSON or None. - """ - content_type = getattr(body, "content_type", None) - data = getattr(body, "data", None) - - if content_type is None or data is None or "application/json" not in content_type: - yield body - return - - try: - utf8_data = data.decode("utf-8") - except Exception as e: - LOGGER.warning(f"Failed to decode data: {data}. Exception: {e}") - yield body - return - - if not utf8_data: - yield None - return - - # Handle the case of multiple dictionaries delimited by newlines. - for line in utf8_data.split("\n"): - if line: - try: - line = json.loads(line) - except Exception as e: - LOGGER.warning(f"failed to parse json: {line}. Exception: {e}") - yield line - - -def parse_constraints( - constraints: Sequence[str], -) -> Mapping[str, "SpecifierSet"]: - """Parses a list of constraints into a dict of requirements. - - Args: - constraints (list[str]): - Required. The list of package requirements to parse. This is assumed - to come from the `requirements.txt` file. - - Returns: - dict[str, SpecifierSet]: The specifiers for each package. - """ - requirements = _import_packaging_requirements_or_raise() - result = {} - for constraint in constraints: - try: - if constraint.endswith(".whl"): - constraint = os.path.basename(constraint) - requirement = requirements.Requirement(constraint) - except Exception as e: - LOGGER.warning(f"Failed to parse constraint: {constraint}. Exception: {e}") - continue - result[requirement.name] = requirement.specifier or None - return result - - -def validate_requirements_or_warn( - obj: Any, - requirements: List[str], - logger: base.Logger = LOGGER, -) -> Mapping[str, str]: - """Compiles the requirements into a list of requirements.""" - requirements = requirements.copy() - try: - current_requirements = scan_requirements(obj) - logger.info(f"Identified the following requirements: {current_requirements}") - constraints = parse_constraints(requirements) - missing_requirements = compare_requirements(current_requirements, constraints) - for warning_type, warnings in missing_requirements.get( - _WARNINGS_KEY, {} - ).items(): - if warnings: - logger.warning( - f"The following requirements are {warning_type}: {warnings}" - ) - for action_type, actions in missing_requirements.get(_ACTIONS_KEY, {}).items(): - if actions and action_type == _ACTION_APPEND: - for action in actions: - requirements.append(action) - logger.info(f"The following requirements are appended: {actions}") - except Exception as e: - logger.warning(f"Failed to compile requirements: {e}") - return requirements - - -def compare_requirements( - requirements: Mapping[str, str], - constraints: Union[Sequence[str], Mapping[str, "SpecifierSet"]], - *, - required_packages: Optional[Sequence[str]] = None, -) -> Mapping[str, Mapping[str, Any]]: - """Compares the requirements with the constraints. - - Args: - requirements (Mapping[str, str]): - Required. The packages (and their versions) to compare with the constraints. - This is assumed to be the result of `scan_requirements`. - constraints (Union[Sequence[str], Mapping[str, SpecifierSet]]): - Required. The package constraints to compare against. This is assumed - to be the result of `parse_constraints`. - required_packages (Sequence[str]): - Optional. The set of packages that are required to be in the - constraints. It defaults to the set of packages that are required - for deployment on Agent Engine. - - Returns: - dict[str, dict[str, Any]]: The comparison result as a dictionary containing: - * warnings: - * missing: The set of packages that are not in the constraints. - * incompatible: The set of packages that are in the constraints - but have versions that are not in the constraint specifier. - * actions: - * append: The set of packages that are not in the constraints - but should be appended to the constraints. - """ - packaging_version = _import_packaging_version_or_raise() - if required_packages is None: - required_packages = _DEFAULT_REQUIRED_PACKAGES - result = _RequirementsValidationResult( - warnings=_RequirementsValidationWarnings(missing=set(), incompatible=set()), - actions=_RequirementsValidationActions(append=set()), - ) - if isinstance(constraints, list): - constraints = parse_constraints(constraints) - for package, package_version in requirements.items(): - if package not in constraints: - result[_WARNINGS_KEY][_WARNING_MISSING].add(package) - if package in required_packages: - result[_ACTIONS_KEY][_ACTION_APPEND].add( - f"{package}=={package_version}" - ) - continue - if package_version: - package_specifier = constraints[package] - if not package_specifier: - continue - if packaging_version.Version(package_version) not in package_specifier: - result[_WARNINGS_KEY][_WARNING_INCOMPATIBLE].add( - f"{package}=={package_version} (required: {str(package_specifier)})" - ) - return result - - -def scan_requirements( - obj: Any, - ignore_modules: Optional[Sequence[str]] = None, - package_distributions: Optional[Mapping[str, Sequence[str]]] = None, - inspect_getmembers_kwargs: Optional[Mapping[str, Any]] = None, -) -> Mapping[str, str]: - """Scans the object for modules and returns the requirements discovered. - - This is not a comprehensive scan of the object, and only detects for common - cases based on the members of the object returned by `dir(obj)`. - - Args: - obj (Any): - Required. The object to scan for package requirements. - ignore_modules (Sequence[str]): - Optional. The set of modules to ignore. It defaults to the set of - built-in and stdlib modules. - package_distributions (Mapping[str, Sequence[str]]): - Optional. The mapping of module names to the set of packages that - contain them. It defaults to the set of packages from - `importlib_metadata.packages_distributions()`. - inspect_getmembers_kwargs (Mapping[str, Any]): - Optional. The keyword arguments to pass to `inspect.getmembers`. It - defaults to an empty dictionary. - - Returns: - Sequence[str]: The list of requirements that were discovered. - """ - if ignore_modules is None: - ignore_modules = _BASE_MODULES - if package_distributions is None: - package_distributions = _PACKAGE_DISTRIBUTIONS - modules_found = set(_DEFAULT_REQUIRED_PACKAGES) - inspect_getmembers_kwargs = inspect_getmembers_kwargs or {} - for _, attr in inspect.getmembers(obj, **inspect_getmembers_kwargs): - if not attr or inspect.isbuiltin(attr) or not hasattr(attr, "__module__"): - continue - module_name = (attr.__module__ or "").split(".")[0] - if module_name and module_name not in ignore_modules: - for module in package_distributions.get(module_name, []): - modules_found.add(module) - return {module: importlib_metadata.version(module) for module in modules_found} - - -def _is_pydantic_serializable(param: inspect.Parameter) -> bool: - """Checks if the parameter is pydantic serializable.""" - - if param.annotation == inspect.Parameter.empty: - return True - - if isinstance(param.annotation, str): - return False - pydantic = _import_pydantic_or_raise() - try: - pydantic.TypeAdapter(param.annotation) - return True - except Exception: - return False - - -def generate_schema( - f: Callable[..., Any], - *, - schema_name: Optional[str] = None, - descriptions: Mapping[str, str] = {}, - required: Sequence[str] = [], -) -> JsonDict: - """Generates the OpenAPI Schema for a callable object. - - Only positional and keyword arguments of the function `f` will be supported - in the OpenAPI Schema that is generated. I.e. `*args` and `**kwargs` will - not be present in the OpenAPI schema returned from this function. For those - cases, you can either include it in the docstring for `f`, or modify the - OpenAPI schema returned from this function to include additional arguments. - - Args: - f (Callable): - Required. The function to generate an OpenAPI Schema for. - schema_name (str): - Optional. The name for the OpenAPI schema. If unspecified, the name - of the Callable will be used. - descriptions (Mapping[str, str]): - Optional. A `{name: description}` mapping for annotating input - arguments of the function with user-provided descriptions. It - defaults to an empty dictionary (i.e. there will not be any - description for any of the inputs). - required (Sequence[str]): - Optional. For the user to specify the set of required arguments in - function calls to `f`. If specified, it will be automatically - inferred from `f`. - - Returns: - dict[str, Any]: The OpenAPI Schema for the function `f` in JSON format. - """ - pydantic = _import_pydantic_or_raise() - defaults = dict(inspect.signature(f).parameters) - fields_dict = { - name: ( - # 1. We infer the argument type here: use Any rather than None so - # it will not try to auto-infer the type based on the default value. - (param.annotation if param.annotation != inspect.Parameter.empty else Any), - pydantic.Field( - # 2. We do not support default values for now. - # default=( - # param.default if param.default != inspect.Parameter.empty - # else None - # ), - # 3. We support user-provided descriptions. - description=descriptions.get(name, None), - ), - ) - for name, param in defaults.items() - # We do not support *args or **kwargs - if param.kind - in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_ONLY, - ) - and _is_pydantic_serializable(param) - } - parameters = pydantic.create_model(f.__name__, **fields_dict).schema() - # Postprocessing - # 4. Suppress unnecessary title generation: - # * https://github.com/pydantic/pydantic/issues/1051 - # * http://cl/586221780 - parameters.pop("title", "") - for name, function_arg in parameters.get("properties", {}).items(): - function_arg.pop("title", "") - annotation = defaults[name].annotation - # 5. Nullable fields: - # * https://github.com/pydantic/pydantic/issues/1270 - # * https://stackoverflow.com/a/58841311 - # * https://github.com/pydantic/pydantic/discussions/4872 - if typing.get_origin(annotation) is Union and type(None) in typing.get_args( - annotation - ): - # for "typing.Optional" arguments, function_arg might be a - # dictionary like - # - # {'anyOf': [{'type': 'integer'}, {'type': 'null'}] - for schema in function_arg.pop("anyOf", []): - schema_type = schema.get("type") - if schema_type and schema_type != "null": - function_arg["type"] = schema_type - break - function_arg["nullable"] = True - # 6. Annotate required fields. - if required: - # We use the user-provided "required" fields if specified. - parameters["required"] = required - else: - # Otherwise we infer it from the function signature. - parameters["required"] = [ - k - for k in defaults - if ( - defaults[k].default == inspect.Parameter.empty - and defaults[k].kind - in ( - inspect.Parameter.POSITIONAL_OR_KEYWORD, - inspect.Parameter.KEYWORD_ONLY, - inspect.Parameter.POSITIONAL_ONLY, - ) - ) - ] - schema = dict(name=f.__name__, description=f.__doc__, parameters=parameters) - if schema_name: - schema["name"] = schema_name - return schema - - -def is_noop_or_proxy_tracer_provider(tracer_provider) -> bool: - """Returns True if the tracer_provider is Proxy or NoOp.""" - opentelemetry = _import_opentelemetry_or_warn() - ProxyTracerProvider = opentelemetry.trace.ProxyTracerProvider - NoOpTracerProvider = opentelemetry.trace.NoOpTracerProvider - return isinstance(tracer_provider, (NoOpTracerProvider, ProxyTracerProvider)) - - -def dump_event_for_json(event: BaseModel) -> Dict[str, Any]: - """Dumps an ADK event to a JSON-serializable dictionary.""" - return json.loads(event.model_dump_json(exclude_none=True)) - - -def _import_cloud_storage_or_raise() -> types.ModuleType: - """Tries to import the Cloud Storage module.""" - try: - from google.cloud import storage - except ImportError as e: - raise ImportError( - "Cloud Storage is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) from e - return storage - - -def _import_cloudpickle_or_raise() -> types.ModuleType: - """Tries to import the cloudpickle module.""" - try: - import cloudpickle # noqa:F401 - except ImportError as e: - raise ImportError( - "cloudpickle is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) from e - return cloudpickle - - -def _import_pydantic_or_raise() -> types.ModuleType: - """Tries to import the pydantic module.""" - try: - import pydantic - - _ = pydantic.Field - except AttributeError: - from pydantic import v1 as pydantic - except ImportError as e: - raise ImportError( - "pydantic is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) from e - return pydantic - - -def _import_packaging_requirements_or_raise() -> types.ModuleType: - """Tries to import the packaging.requirements module.""" - try: - from packaging import requirements - except ImportError as e: - raise ImportError( - "packaging.requirements is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) from e - return requirements - - -def _import_packaging_version_or_raise() -> types.ModuleType: - """Tries to import the packaging.requirements module.""" - try: - from packaging import version - except ImportError as e: - raise ImportError( - "packaging.version is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) from e - return version - - -def _import_opentelemetry_or_warn() -> Optional[types.ModuleType]: - """Tries to import the opentelemetry module.""" - try: - import opentelemetry # noqa:F401 - - return opentelemetry - except ImportError: - LOGGER.warning( - "opentelemetry-sdk is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) - return None - - -def _import_opentelemetry_sdk_trace_or_warn() -> Optional[types.ModuleType]: - """Tries to import the opentelemetry.sdk.trace module.""" - try: - import opentelemetry.sdk.trace # noqa:F401 - - return opentelemetry.sdk.trace - except ImportError: - LOGGER.warning( - "opentelemetry-sdk is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) - return None - - -def _import_cloud_trace_v2_or_warn() -> Optional[types.ModuleType]: - """Tries to import the google.cloud.trace_v2 module.""" - try: - import google.cloud.trace_v2 - - return google.cloud.trace_v2 - except ImportError: - LOGGER.warning( - "google-cloud-trace is not installed. Please call " - "'pip install google-cloud-aiplatform[agent_engines]'." - ) - return None - - -def _import_cloud_trace_exporter_or_warn() -> Optional[types.ModuleType]: - """Tries to import the opentelemetry.exporter.cloud_trace module.""" - try: - import opentelemetry.exporter.cloud_trace # noqa:F401 - - return opentelemetry.exporter.cloud_trace - except ImportError: - LOGGER.warning( - "opentelemetry-exporter-gcp-trace is not installed. Please " - "call 'pip install google-cloud-aiplatform[agent_engines]'." - ) - return None - - -def _import_openinference_langchain_or_warn() -> Optional[types.ModuleType]: - """Tries to import the openinference.instrumentation.langchain module.""" - try: - import openinference.instrumentation.langchain # noqa:F401 - - return openinference.instrumentation.langchain - except ImportError: - LOGGER.warning( - "openinference-instrumentation-langchain is not installed. Please " - "call 'pip install google-cloud-aiplatform[langchain]'." - ) - return None - - -def _import_openinference_autogen_or_warn() -> Optional[types.ModuleType]: - """Tries to import the openinference.instrumentation.autogen module.""" - try: - import openinference.instrumentation.autogen # noqa:F401 - - return openinference.instrumentation.autogen - except ImportError: - LOGGER.warning( - "openinference-instrumentation-autogen is not installed. Please " - "call 'pip install google-cloud-aiplatform[ag2]'." - ) - return None - - -def _import_autogen_tools_or_warn() -> Optional[types.ModuleType]: - """Tries to import the autogen.tools module.""" - try: - from autogen import tools - - return tools - except ImportError: - LOGGER.warning( - "autogen.tools is not installed. Please " - "call `pip install google-cloud-aiplatform[ag2]`." - ) - return None - - -def validate_installation_scripts_or_raise( - script_paths: Sequence[str], - extra_packages: Sequence[str], -): - """Validates the installation scripts' path explicitly provided by the user. - - Args: - script_paths (Sequence[str]): - Required. The paths to the installation scripts. - extra_packages (Sequence[str]): - Required. The extra packages to be updated. - - Raises: - ValueError: If a user-defined script is not under the expected - subdirectory, or not in `extra_packages`, or if an extra package is - in the installation scripts subdirectory, but is not specified as an - installation script. - """ - for script_path in script_paths: - if not script_path.startswith(_INSTALLATION_SUBDIR): - LOGGER.warning( - f"User-defined installation script '{script_path}' is not in " - f"the expected '{_INSTALLATION_SUBDIR}' subdirectory. " - f"Ensure it is placed in '{_INSTALLATION_SUBDIR}' within your " - f"`extra_packages`." - ) - raise ValueError( - f"Required installation script '{script_path}' " - f"is not under '{_INSTALLATION_SUBDIR}'" - ) - - if script_path not in extra_packages: - LOGGER.warning( - f"User-defined installation script '{script_path}' is not in " - f"extra_packages. Ensure it is added to `extra_packages`." - ) - raise ValueError( - f"User-defined installation script '{script_path}' " - f"does not exist in `extra_packages`" - ) - - for extra_package in extra_packages: - if ( - extra_package.startswith(_INSTALLATION_SUBDIR) - and extra_package not in script_paths - ): - LOGGER.warning( - f"Extra package '{extra_package}' is in the installation " - "scripts subdirectory, but is not specified as an installation " - "script in `build_options`. " - "Ensure it is added to installation_scripts for " - "automatic execution." - ) - raise ValueError( - f"Extra package '{extra_package}' is in the installation " - "scripts subdirectory, but is not specified as an installation " - "script in `build_options`." - ) - return - - -def _validate_resource_limits_or_raise(resource_limits: dict[str, str]) -> None: - """Validates the resource limits. - - Checks that the resource limits are a dict with 'cpu' and 'memory' keys. - Checks that the 'cpu' value is one of 1, 2, 4, 6, 8. - Checks that the 'memory' value is a string ending with 'Gi'. - Checks that the memory size is smaller than 32Gi. - Checks that the memory size requires at least the specified number of CPUs. - - Args: - resource_limits: The resource limits to be validated. - - Raises: - TypeError: If the resource limits are not a dict. - KeyError: If the resource limits do not contain 'cpu' and 'memory' keys. - ValueError: If the 'cpu' value is not one of 1, 2, 4, 6, 8. - ValueError: If the 'memory' value is not a string ending with 'Gi'. - ValueError: If the memory size is too large. - ValueError: If the memory size requires more CPUs than the specified - 'cpu' value. - """ - if not isinstance(resource_limits, dict): - raise TypeError(f"resource_limits must be a dict. Got {type(resource_limits)}") - if "cpu" not in resource_limits or "memory" not in resource_limits: - raise KeyError("resource_limits must contain 'cpu' and 'memory' keys.") - - cpu = int(resource_limits["cpu"]) - memory_str = resource_limits["memory"] - - if cpu not in [1, 2, 4, 6, 8]: - raise ValueError( - "resource_limits['cpu'] must be one of 1, 2, 4, 6, 8. Got" f" {cpu}" - ) - - if not isinstance(memory_str, str) or not memory_str.endswith("Gi"): - raise ValueError( - "resource_limits['memory'] must be a string ending with 'Gi'." - f" Got {memory_str}" - ) - - try: - memory_gb = int(memory_str[:-2]) - except ValueError: - raise ValueError( - f"Invalid memory value: {memory_str}. Must be an integer" - " followed by 'Gi'." - ) - - # https://cloud.google.com/run/docs/configuring/memory-limits - if memory_gb > 32: - raise ValueError( - f"Memory size of {memory_str} is too large. Must be smaller than 32Gi." - ) - if memory_gb > 24: - min_cpu = 8 - elif memory_gb > 16: - min_cpu = 6 - elif memory_gb > 8: - min_cpu = 4 - elif memory_gb > 4: - min_cpu = 2 - else: - min_cpu = 1 - - if cpu < min_cpu: - raise ValueError( - f"Memory size of {memory_str} requires at least {min_cpu} CPUs." - f" Got {cpu}" - ) diff --git a/agentplatform/agent_engines/templates/adk.py b/agentplatform/agent_engines/templates/adk.py index 5f66ea60a9..a2e8cb5ab8 100644 --- a/agentplatform/agent_engines/templates/adk.py +++ b/agentplatform/agent_engines/templates/adk.py @@ -236,13 +236,13 @@ def __init__(self, **kwargs): # The session ID. def dump(self) -> Dict[str, Any]: - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils result = {} if self.events: result["events"] = [] for event in self.events: - event_dict = _utils.dump_event_for_json(event) + event_dict = _agent_engines_utils.dump_event_for_json(event) event_dict["invocation_id"] = event_dict.get("invocation_id", "") result["events"].append(event_dict) if self.artifacts: @@ -463,9 +463,9 @@ def _detect_cloud_resource_id(project_id: str) -> Optional[str]: # Avoids AttributeError: # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no # attribute 'add_span_processor'. - from agentplatform.agent_engines import _utils + from agentplatform._genai import agent_engines_utils - if _utils.is_noop_or_proxy_tracer_provider(tracer_provider): + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): tracer_provider = opentelemetry.sdk.trace.TracerProvider(resource=resource) opentelemetry.trace.set_tracer_provider(tracer_provider) # Avoids OpenTelemetry client already exists error. @@ -1156,7 +1156,7 @@ async def async_stream_query( a Content object. ValueError: If both session_id and session_events are specified. """ - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils from google.genai import types if isinstance(message, Dict): @@ -1211,7 +1211,7 @@ async def async_stream_query( try: async for event in events_async: # Yield the event data as a dictionary - yield _utils.dump_event_for_json(event) + yield _agent_engines_utils.dump_event_for_json(event) finally: # Avoid telemetry data loss having to do with CPU throttling on instance turndown _ = await _force_flush_otel( @@ -1261,7 +1261,7 @@ def stream_query( DeprecationWarning, stacklevel=2, ) - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils from google.genai import types if isinstance(message, Dict): @@ -1288,7 +1288,7 @@ def stream_query( run_config=run_config, **kwargs, ): - yield _utils.dump_event_for_json(event) + yield _agent_engines_utils.dump_event_for_json(event) else: for event in self._tmpl_attrs.get("runner").run( user_id=user_id, @@ -1296,7 +1296,7 @@ def stream_query( new_message=content, **kwargs, ): - yield _utils.dump_event_for_json(event) + yield _agent_engines_utils.dump_event_for_json(event) async def streaming_agent_run_with_events(self, request_json: str): """Streams responses asynchronously from the ADK application. diff --git a/agentplatform/agent_engines/templates/ag2.py b/agentplatform/agent_engines/templates/ag2.py index 84982ecd4c..97db7a497b 100644 --- a/agentplatform/agent_engines/templates/ag2.py +++ b/agentplatform/agent_engines/templates/ag2.py @@ -89,13 +89,13 @@ def _default_runnable_builder( def _default_instrumentor_builder(project_id: str): - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils - cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn() - cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn() - openinference_autogen = _utils._import_openinference_autogen_or_warn() - opentelemetry = _utils._import_opentelemetry_or_warn() - opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn() + cloud_trace_exporter = _agent_engines_utils._import_cloud_trace_exporter_or_warn() + cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn() + openinference_autogen = _agent_engines_utils._import_openinference_autogen_or_warn() + opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn() + opentelemetry_sdk_trace = _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn() if all( ( cloud_trace_exporter, @@ -142,7 +142,7 @@ def _default_instrumentor_builder(project_id: str): # Avoids AttributeError: # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no # attribute 'add_span_processor'. - if _utils.is_noop_or_proxy_tracer_provider(tracer_provider): + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): tracer_provider = opentelemetry_sdk_trace.TracerProvider() opentelemetry.trace.set_tracer_provider(tracer_provider) # Avoids OpenTelemetry client already exists error. @@ -396,9 +396,9 @@ def set_up(self): tools = self._tmpl_attrs.get("tools") ag2_tool_objects = self._tmpl_attrs.get("ag2_tool_objects") if tools and not ag2_tool_objects: - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils - autogen_tools = _utils._import_autogen_tools_or_warn() + autogen_tools = _agent_engines_utils._import_autogen_tools_or_warn() if autogen_tools: for tool in tools: ag2_tool_objects.append(autogen_tools.Tool(func_or_tool=tool)) @@ -484,6 +484,6 @@ def query( **kwargs, ) - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils - return _utils.to_json_serializable_autogen_object(response) + return _agent_engines_utils.to_json_serializable_autogen_object(response) diff --git a/agentplatform/agent_engines/templates/langchain.py b/agentplatform/agent_engines/templates/langchain.py index 2ee9039a5f..6722a6b834 100644 --- a/agentplatform/agent_engines/templates/langchain.py +++ b/agentplatform/agent_engines/templates/langchain.py @@ -192,13 +192,13 @@ def _default_runnable_builder( def _default_instrumentor_builder(project_id: str): - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils - cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn() - cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn() - openinference_langchain = _utils._import_openinference_langchain_or_warn() - opentelemetry = _utils._import_opentelemetry_or_warn() - opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn() + cloud_trace_exporter = _agent_engines_utils._import_cloud_trace_exporter_or_warn() + cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn() + openinference_langchain = _agent_engines_utils._import_openinference_langchain_or_warn() + opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn() + opentelemetry_sdk_trace = _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn() if all( ( cloud_trace_exporter, @@ -245,7 +245,7 @@ def _default_instrumentor_builder(project_id: str): # Avoids AttributeError: # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no # attribute 'add_span_processor'. - if _utils.is_noop_or_proxy_tracer_provider(tracer_provider): + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): tracer_provider = opentelemetry_sdk_trace.TracerProvider() opentelemetry.trace.set_tracer_provider(tracer_provider) # Avoids OpenTelemetry client already exists error. diff --git a/agentplatform/agent_engines/templates/langgraph.py b/agentplatform/agent_engines/templates/langgraph.py index cdfb417a8e..6de517e171 100644 --- a/agentplatform/agent_engines/templates/langgraph.py +++ b/agentplatform/agent_engines/templates/langgraph.py @@ -168,13 +168,13 @@ def _default_runnable_builder( def _default_instrumentor_builder(project_id: str): - from agentplatform.agent_engines import _utils + from agentplatform._genai import _agent_engines_utils - cloud_trace_exporter = _utils._import_cloud_trace_exporter_or_warn() - cloud_trace_v2 = _utils._import_cloud_trace_v2_or_warn() - openinference_langchain = _utils._import_openinference_langchain_or_warn() - opentelemetry = _utils._import_opentelemetry_or_warn() - opentelemetry_sdk_trace = _utils._import_opentelemetry_sdk_trace_or_warn() + cloud_trace_exporter = _agent_engines_utils._import_cloud_trace_exporter_or_warn() + cloud_trace_v2 = _agent_engines_utils._import_cloud_trace_v2_or_warn() + openinference_langchain = _agent_engines_utils._import_openinference_langchain_or_warn() + opentelemetry = _agent_engines_utils._import_opentelemetry_or_warn() + opentelemetry_sdk_trace = _agent_engines_utils._import_opentelemetry_sdk_trace_or_warn() if all( ( cloud_trace_exporter, @@ -220,7 +220,7 @@ def _default_instrumentor_builder(project_id: str): # Avoids AttributeError: # 'ProxyTracerProvider' and 'NoOpTracerProvider' objects has no # attribute 'add_span_processor'. - if _utils.is_noop_or_proxy_tracer_provider(tracer_provider): + if _agent_engines_utils.is_noop_or_proxy_tracer_provider(tracer_provider): tracer_provider = opentelemetry_sdk_trace.TracerProvider() opentelemetry.trace.set_tracer_provider(tracer_provider) # Avoids OpenTelemetry client already exists error. diff --git a/noxfile.py b/noxfile.py index da8a2863d7..ab494633c4 100644 --- a/noxfile.py +++ b/noxfile.py @@ -223,6 +223,7 @@ def default(session): "--ignore=tests/unit/architecture", "--ignore=tests/unit/vertexai/genai/replays", "--ignore=tests/unit/agentplatform/genai/replays", + "--ignore=tests/unit/agentplatform/frameworks", os.path.join("tests", "unit"), *session.posargs, ) @@ -301,12 +302,12 @@ def unit_ray(session, ray): def unit_langchain(session): # Install all test dependencies, then install this package in-place. - constraints_path = str(CURRENT_DIRECTORY / "testing" / "constraints-langchain.txt") + constraints_path = str(CURRENT_DIRECTORY / "testing" / "constraints-ag2.txt") standard_deps = UNIT_TEST_STANDARD_DEPENDENCIES + UNIT_TEST_DEPENDENCIES session.install(*standard_deps, "-c", constraints_path) # Install langchain extras - session.install("-e", ".[langchain_testing]", "-c", constraints_path) + session.install("-e", ".[adk_testing]", "-c", constraints_path) # Run py.test against the unit tests. session.run( @@ -318,7 +319,7 @@ def unit_langchain(session): "--cov-config=.coveragerc", "--cov-report=", "--cov-fail-under=0", - os.path.join("tests", "unit", "vertex_langchain"), + os.path.join("tests", "unit", "agentplatform", "frameworks", "test_frameworks_adk.py"), *session.posargs, ) diff --git a/setup.py b/setup.py index 65be53bae9..f049bc7b9d 100644 --- a/setup.py +++ b/setup.py @@ -175,6 +175,14 @@ "aiohttp", # for ADK users to use aiohttp rather than httpx client ] +adk_testing_extra_require = list( + set( + adk_extra_require + + reasoning_engine_extra_require + + ["absl-py", "pytest-xdist"] + ) +) + evaluation_extra_require = [ "pandas >= 1.0.0", "tqdm>=4.23.0", @@ -349,6 +357,7 @@ "ray": ray_extra_require, "ray_testing": ray_testing_extra_require, "adk": adk_extra_require, + "adk_testing": adk_testing_extra_require, "reasoningengine": reasoning_engine_extra_require, "agent_engines": agent_engines_extra_require, "evaluation": evaluation_extra_require, diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_adk.py b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py new file mode 100644 index 0000000000..34e0f80c21 --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_adk.py @@ -0,0 +1,1432 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import base64 +import importlib +import json +import os +import re +import sys +from typing import Optional +from unittest import mock +import uuid + +import cloudpickle +from google import auth +from google.api_core import operation as ga_operation +from google.auth import credentials as auth_credentials +from google.auth.transport import mtls +from google.cloud import storage +from google.cloud import aiplatform +import agentplatform +from google.cloud.aiplatform import base +from google.cloud.aiplatform import initializer +from google.cloud.aiplatform_v1 import types as aip_types +from google.cloud.aiplatform_v1.services import reasoning_engine_service +from agentplatform import agent_engines +from agentplatform.agent_engines import _agent_engines +from agentplatform.agent_engines import _utils +from agentplatform.agent_engines.templates import ( + adk as adk_template, +) +from google.genai import types +import pytest + + +try: + from google.adk.agents import llm_agent + + Agent = llm_agent.Agent +except ImportError: + + class Agent: + def __init__(self, name: str, model: str): + self.name = name + self.model = model + + +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_PROJECT_ID = "test-project-id" +_TEST_API_KEY = "test-api-key" +_TEST_MODEL = "gemini-2.0-flash" +_TEST_USER_ID = "test_user_id" +_TEST_AGENT_NAME = "test_agent" +_TEST_AGENT = Agent(name=_TEST_AGENT_NAME, model=_TEST_MODEL) +_TEST_SESSION = { + "id": "ca18c25a-644b-4e13-9b24-78c150ec3eb9", + "app_name": "default_app_name", + "user_id": _TEST_USER_ID, + "events": [ + { + "author": "user", + "content": { + "parts": [{"text": "My cat's name is Garfield"}], + "role": "user", + }, + }, + { + "author": "my_personal_agent", + "content": { + "parts": [{"text": "Okay, good to know!"}], + "role": "model", + }, + }, + ], +} +_TEST_SEARCH_MEMORY_QUERY = "What is my cat's name" +_TEST_RUN_CONFIG = { + "save_input_blobs_as_artifacts": False, + "support_cfc": False, + "streaming_mode": "sse", + "max_llm_calls": 500, +} +_TEST_STAGING_BUCKET = "gs://test-bucket" +_TEST_CREDENTIALS = mock.Mock(spec=auth_credentials.AnonymousCredentials()) +_TEST_PARENT = f"projects/{_TEST_PROJECT}/locations/{_TEST_LOCATION}" +_TEST_RESOURCE_ID = "1028944691210842416" +_TEST_AGENT_ENGINE_RESOURCE_NAME = ( + f"{_TEST_PARENT}/reasoningEngines/{_TEST_RESOURCE_ID}" +) +_TEST_AGENT_ENGINE_DISPLAY_NAME = "Agent Engine Display Name" +_TEST_GCS_DIR_NAME = _agent_engines._DEFAULT_GCS_DIR_NAME +_TEST_BLOB_FILENAME = _agent_engines._BLOB_FILENAME +_TEST_REQUIREMENTS_FILE = _agent_engines._REQUIREMENTS_FILE +_TEST_EXTRA_PACKAGES_FILE = _agent_engines._EXTRA_PACKAGES_FILE +_TEST_AGENT_ENGINE_GCS_URI = "{}/{}/{}".format( + _TEST_STAGING_BUCKET, + _TEST_GCS_DIR_NAME, + _TEST_BLOB_FILENAME, +) +_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI = "{}/{}/{}".format( + _TEST_STAGING_BUCKET, + _TEST_GCS_DIR_NAME, + _TEST_EXTRA_PACKAGES_FILE, +) +_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI = "{}/{}/{}".format( + _TEST_STAGING_BUCKET, + _TEST_GCS_DIR_NAME, + _TEST_REQUIREMENTS_FILE, +) +_TEST_AGENT_ENGINE_PACKAGE_SPEC = aip_types.ReasoningEngineSpec.PackageSpec( + python_version=f"{sys.version_info.major}.{sys.version_info.minor}", + pickle_object_gcs_uri=_TEST_AGENT_ENGINE_GCS_URI, + dependency_files_gcs_uri=_TEST_AGENT_ENGINE_DEPENDENCY_FILES_GCS_URI, + requirements_gcs_uri=_TEST_AGENT_ENGINE_REQUIREMENTS_GCS_URI, +) +_ADK_AGENT_FRAMEWORK = adk_template.AdkApp.agent_framework +_TEST_AGENT_ENGINE_OBJ = aip_types.ReasoningEngine( + name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + display_name=_TEST_AGENT_ENGINE_DISPLAY_NAME, + spec=aip_types.ReasoningEngineSpec( + package_spec=_TEST_AGENT_ENGINE_PACKAGE_SPEC, + agent_framework=_ADK_AGENT_FRAMEWORK, + ), +) + +GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY = ( + "GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY" +) + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def otlp_span_exporter_mock(): + with mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) as otlp_span_exporter_mock: + yield otlp_span_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + import opentelemetry.sdk.trace + + with mock.patch.object( + opentelemetry.sdk.trace, "TracerProvider" + ) as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def trace_provider_force_flush_mock(): + import opentelemetry.trace + import opentelemetry.sdk.trace + + with mock.patch.object( + opentelemetry.trace, "get_tracer_provider" + ) as get_tracer_provider_mock: + get_tracer_provider_mock.return_value = mock.Mock( + spec=opentelemetry.sdk.trace.TracerProvider() + ) + yield get_tracer_provider_mock.return_value.force_flush + + +@pytest.fixture +def logger_provider_force_flush_mock(): + import opentelemetry._logs + import opentelemetry.sdk._logs + + with mock.patch.object( + opentelemetry._logs, "get_logger_provider" + ) as get_logger_provider_mock: + get_logger_provider_mock.return_value = mock.Mock( + spec=opentelemetry.sdk._logs.LoggerProvider() + ) + yield get_logger_provider_mock.return_value.force_flush + + +@pytest.fixture +def default_instrumentor_builder_mock(): + with mock.patch.object( + adk_template, "_default_instrumentor_builder" + ) as default_instrumentor_builder_mock: + yield default_instrumentor_builder_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def adk_version_mock(): + with mock.patch.object(adk_template, "get_adk_version") as adk_version_mock: + yield adk_version_mock + + +@pytest.fixture +def is_version_sufficient_mock(): + with mock.patch.object( + adk_template, "is_version_sufficient" + ) as is_version_sufficient_mock: + is_version_sufficient_mock.return_value = True + yield is_version_sufficient_mock + + +@pytest.fixture +def get_project_id_mock(): + from google.cloud.aiplatform.utils import resource_manager_utils + + with mock.patch.object( + resource_manager_utils, "get_project_id" + ) as get_project_id_mock: + get_project_id_mock.return_value = _TEST_PROJECT_ID + yield get_project_id_mock + + +@pytest.fixture +def warn_if_telemetry_api_disabled_mock(): + with mock.patch.object( + adk_template, "_warn_if_telemetry_api_disabled" + ) as warn_if_telemetry_api_disabled_mock: + yield warn_if_telemetry_api_disabled_mock + + +class _MockRunner: + def run(self, *args, **kwargs): + from google.adk.events import event + + yield event.Event( + **{ + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "thought_signature": b"test_signature", + "function_call": { + "args": { + "currency_date": "2025-04-03", + "currency_from": "USD", + "currency_to": "SEK", + }, + "id": "af-c5a57692-9177-4091-a3df-098f834ee849", + "name": "get_exchange_rate", + }, + } + ], + "role": "model", + }, + "id": "9aaItGK9", + "invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7", + } + ) + + async def run_async(self, *args, **kwargs): + from google.adk.events import event + + yield event.Event( + **{ + "author": "currency_exchange_agent", + "content": { + "parts": [ + { + "thought_signature": b"test_signature", + "function_call": { + "args": { + "currency_date": "2025-04-03", + "currency_from": "USD", + "currency_to": "SEK", + }, + "id": "af-c5a57692-9177-4091-a3df-098f834ee849", + "name": "get_exchange_rate", + }, + } + ], + "role": "model", + }, + "id": "9aaItGK9", + "invocation_id": "e-6543c213-6417-484b-9551-b67915d1d5f7", + } + ) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestAdkApp: + def test_adk_version(self): + with mock.patch.object( + adk_template, + "get_adk_version", + return_value="0.5.0", + ): + with pytest.raises( + ValueError, + match=( + "Unsupported google-adk version: 0.5.0, please use" + " google-adk>=1.5.0 for AdkApp deployment on Agent Engine." + ), + ): + agent_engines.AdkApp(agent=_TEST_AGENT) + + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init(project=_TEST_PROJECT, location=_TEST_LOCATION) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("project") == _TEST_PROJECT + assert app._tmpl_attrs.get("location") == _TEST_LOCATION + assert app._tmpl_attrs.get("runner") is None + + def test_set_up( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + assert app._tmpl_attrs.get("runner") is not None + + def test_clone( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + assert app._tmpl_attrs.get("runner") is not None + app_clone = app.clone() + assert app._tmpl_attrs.get("runner") is not None + assert app_clone._tmpl_attrs.get("runner") is None + app_clone.set_up() + assert app_clone._tmpl_attrs.get("runner") is not None + + def test_register_operations(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + for operations in app.register_operations().values(): + for operation in operations: + assert operation in dir(app) + + def test_stream_query( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = list( + app.stream_query( + user_id=_TEST_USER_ID, + message="test message", + ) + ) + assert len(events) == 1 + + def test_stream_query_with_content( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = list( + app.stream_query( + user_id=_TEST_USER_ID, + message=types.Content( + role="user", + parts=[ + types.Part( + text="test message with content", + ) + ], + ).model_dump(), + ) + ) + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_async_stream_query( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = [] + async for event in app.async_stream_query( + user_id=_TEST_USER_ID, + message="test message", + ): + events.append(event) + assert len(events) == 1 + + @pytest.mark.asyncio + @mock.patch.dict( + os.environ, + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}, + ) + async def test_async_stream_query_force_flush_otel( + self, + trace_provider_force_flush_mock: mock.Mock, + logger_provider_force_flush_mock: mock.Mock, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + async for _ in app.async_stream_query( + user_id=_TEST_USER_ID, + message="test message", + ): + pass + + trace_provider_force_flush_mock.assert_called_once() + logger_provider_force_flush_mock.assert_called_once() + + @pytest.mark.asyncio + async def test_async_stream_query_with_content( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("runner") is None + app.set_up() + app._tmpl_attrs["runner"] = _MockRunner() + events = [] + async for event in app.async_stream_query( + user_id=_TEST_USER_ID, + message=types.Content( + role="user", + parts=[ + types.Part( + text="test message with content", + ) + ], + ).model_dump(), + ): + events.append(event) + assert len(events) == 1 + + @pytest.mark.asyncio + async def test_streaming_agent_run_with_events( + self, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + app._tmpl_attrs["in_memory_runner"] = _MockRunner() + request_json = json.dumps( + { + "authorizations": { + "test_user_id1": {"access_token": "test_access_token"}, + "test_user_id2": {"accessToken": "test-access-token"}, + }, + "user_id": _TEST_USER_ID, + "message": { + "parts": [{"text": "What is the exchange rate from USD to SEK?"}], + "role": "user", + }, + } + ) + events = [] + async for event in app.streaming_agent_run_with_events( + request_json=request_json, + ): + events.append(event) + assert len(events) == 1 + + @pytest.mark.asyncio + @mock.patch.dict( + os.environ, + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}, + ) + async def test_streaming_agent_run_with_events_force_flush_otel( + self, + trace_provider_force_flush_mock: mock.Mock, + logger_provider_force_flush_mock: mock.Mock, + default_instrumentor_builder_mock: mock.Mock, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + app._tmpl_attrs["in_memory_runner"] = _MockRunner() + request_json = json.dumps( + { + "authorizations": { + "test_user_id1": {"access_token": "test_access_token"}, + "test_user_id2": {"accessToken": "test-access-token"}, + }, + "user_id": _TEST_USER_ID, + "message": { + "parts": [{"text": "What is the exchange rate from USD to SEK?"}], + "role": "user", + }, + } + ) + async for _ in app.streaming_agent_run_with_events( + request_json=request_json, + ): + pass + + trace_provider_force_flush_mock.assert_called_once() + logger_provider_force_flush_mock.assert_called_once() + + @pytest.mark.asyncio + async def test_async_create_session(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + session1 = await app.async_create_session(user_id=_TEST_USER_ID) + assert session1["user_id"] == _TEST_USER_ID + session2 = await app.async_create_session( + user_id=_TEST_USER_ID, session_id="test_session_id" + ) + assert session2["user_id"] == _TEST_USER_ID + assert session2["id"] == "test_session_id" + + @pytest.mark.asyncio + async def test_async_get_session(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + session1 = await app.async_create_session(user_id=_TEST_USER_ID) + session2 = await app.async_get_session( + user_id=_TEST_USER_ID, + session_id=session1["id"], + ) + assert session2.user_id == _TEST_USER_ID + assert session1["id"] == session2.id + + @pytest.mark.asyncio + async def test_async_list_sessions(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response0 = await app.async_list_sessions(user_id=_TEST_USER_ID) + assert not response0.sessions + session = await app.async_create_session(user_id=_TEST_USER_ID) + response1 = await app.async_list_sessions(user_id=_TEST_USER_ID) + assert len(response1.sessions) == 1 + assert response1.sessions[0].id == session["id"] + session2 = await app.async_create_session(user_id=_TEST_USER_ID) + response2 = await app.async_list_sessions(user_id=_TEST_USER_ID) + assert len(response2.sessions) == 2 + assert response2.sessions[0].id == session["id"] + assert response2.sessions[1].id == session2["id"] + + @pytest.mark.asyncio + async def test_async_delete_session(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response = await app.async_delete_session( + user_id=_TEST_USER_ID, + session_id="", + ) + assert not response + session = await app.async_create_session(user_id=_TEST_USER_ID) + response1 = await app.async_list_sessions(user_id=_TEST_USER_ID) + assert len(response1.sessions) == 1 + await app.async_delete_session( + user_id=_TEST_USER_ID, + session_id=session["id"], + ) + response0 = await app.async_list_sessions(user_id=_TEST_USER_ID) + assert not response0.sessions + + def test_create_session(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + session1 = app.create_session(user_id=_TEST_USER_ID) + assert session1["user_id"] == _TEST_USER_ID + session2 = app.create_session( + user_id=_TEST_USER_ID, session_id="test_session_id" + ) + assert session2["user_id"] == _TEST_USER_ID + assert session2["id"] == "test_session_id" + + def test_get_session(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + session1 = app.create_session(user_id=_TEST_USER_ID) + session2 = app.get_session( + user_id=_TEST_USER_ID, + session_id=session1["id"], + ) + assert session2.user_id == _TEST_USER_ID + assert session1["id"] == session2.id + + def test_list_sessions(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response0 = app.list_sessions(user_id=_TEST_USER_ID) + assert not response0.sessions + session = app.create_session(user_id=_TEST_USER_ID) + response1 = app.list_sessions(user_id=_TEST_USER_ID) + assert len(response1.sessions) == 1 + assert response1.sessions[0].id == session["id"] + session2 = app.create_session(user_id=_TEST_USER_ID) + response2 = app.list_sessions(user_id=_TEST_USER_ID) + assert len(response2.sessions) == 2 + assert response2.sessions[0].id == session["id"] + assert response2.sessions[1].id == session2["id"] + + def test_delete_session(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response = app.delete_session(user_id=_TEST_USER_ID, session_id="") + assert not response + session = app.create_session(user_id=_TEST_USER_ID) + response1 = app.list_sessions(user_id=_TEST_USER_ID) + assert len(response1.sessions) == 1 + app.delete_session(user_id=_TEST_USER_ID, session_id=session["id"]) + response0 = app.list_sessions(user_id=_TEST_USER_ID) + assert not response0.sessions + + @pytest.mark.asyncio + async def test_async_add_session_to_memory_dict( + self, + get_project_id_mock: mock.Mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response = await app.async_search_memory( + user_id=_TEST_USER_ID, + query=_TEST_SEARCH_MEMORY_QUERY, + ) + assert not response.memories + await app.async_add_session_to_memory(session=_TEST_SESSION) + response = await app.async_search_memory( + user_id=_TEST_USER_ID, + query=_TEST_SEARCH_MEMORY_QUERY, + ) + assert len(response.memories) >= 1 + + @pytest.mark.asyncio + async def test_async_search_memory(self, get_project_id_mock: mock.Mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + response = await app.async_search_memory( + user_id=_TEST_USER_ID, + query=_TEST_SEARCH_MEMORY_QUERY, + ) + assert not response.memories + await app.async_add_session_to_memory(session=_TEST_SESSION) + response = await app.async_search_memory( + user_id=_TEST_USER_ID, + query=_TEST_SEARCH_MEMORY_QUERY, + ) + assert len(response.memories) >= 1 + + @pytest.mark.parametrize( + "adk_version,enable_tracing,enable_telemetry,want_tracing_setup,want_logging_setup", + [ + ("1.16.0", False, False, False, False), + ("1.16.0", False, True, False, True), + ("1.16.0", False, None, False, False), + ("1.16.0", True, False, False, False), + ("1.16.0", True, True, True, True), + ("1.16.0", True, None, True, False), + ("1.16.0", None, False, False, False), + ("1.16.0", None, True, False, True), + ("1.16.0", None, None, False, False), + ("1.16.0", None, "unspecified", False, False), + ("1.16.0", False, "unspecified", False, False), + ("1.16.0", True, "unspecified", True, False), + ("1.17.0", False, False, False, False), + ("1.17.0", False, True, False, True), + ("1.17.0", False, None, False, False), + ("1.17.0", True, False, False, False), + ("1.17.0", True, True, True, True), + ("1.17.0", True, None, True, False), + ("1.17.0", None, False, False, False), + ("1.17.0", None, True, True, True), + ("1.17.0", None, None, False, False), + ("1.17.0", None, "unspecified", False, False), + ("1.17.0", False, "unspecified", False, False), + ("1.17.0", True, "unspecified", True, False), + ], + ) + @mock.patch.dict(os.environ) + def test_default_instrumentor_enablement( + self, + adk_version: str, + enable_tracing: Optional[bool], + enable_telemetry: Optional[bool], + want_tracing_setup: bool, + want_logging_setup: bool, + default_instrumentor_builder_mock: mock.Mock, + warn_if_telemetry_api_disabled_mock: mock.Mock, + get_project_id_mock: mock.Mock, + adk_version_mock: mock.Mock, + ): + # Arrange + adk_version_mock.return_value = adk_version + if enable_telemetry is not None: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"] = str( + enable_telemetry + ) + + app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=enable_tracing) + + # Act + app.set_up() + + # Assert + default_instrumentor_builder_mock.assert_called_once_with( + _TEST_PROJECT_ID, + enable_tracing=want_tracing_setup, + enable_logging=want_logging_setup, + ) + + @pytest.mark.parametrize( + "adk_version,enable_tracing,enable_telemetry,want_custom_instrumentor_called", + [ + ("1.16.0", False, False, False), + ("1.16.0", False, True, False), + ("1.16.0", False, None, False), + ("1.16.0", True, False, False), + ("1.16.0", True, True, True), + ("1.16.0", True, None, True), + ("1.16.0", None, False, False), + ("1.16.0", None, True, False), + ("1.16.0", None, None, False), + ("1.16.0", None, "unspecified", False), + ("1.16.0", False, "unspecified", False), + ("1.16.0", True, "unspecified", True), + ("1.17.0", False, False, False), + ("1.17.0", False, True, False), + ("1.17.0", False, None, False), + ("1.17.0", True, False, False), + ("1.17.0", True, True, True), + ("1.17.0", True, None, True), + ("1.17.0", None, False, False), + ("1.17.0", None, True, True), + ("1.17.0", None, None, False), + ("1.17.0", None, "unspecified", False), + ("1.17.0", False, "unspecified", False), + ("1.17.0", True, "unspecified", True), + ], + ) + @mock.patch.dict(os.environ) + def test_custom_instrumentor_enablement( + self, + adk_version: str, + enable_tracing: Optional[bool], + enable_telemetry: Optional[bool], + want_custom_instrumentor_called: bool, + get_project_id_mock: mock.Mock, + warn_if_telemetry_api_disabled_mock: mock.Mock, + adk_version_mock: mock.Mock, + ): + # Arrange + adk_version_mock.return_value = adk_version + if enable_telemetry is not None: + os.environ["GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY"] = str( + enable_telemetry + ) + custom_instrumentor = mock.Mock() + app = agent_engines.AdkApp( + agent=_TEST_AGENT, + enable_tracing=enable_tracing, + instrumentor_builder=custom_instrumentor, + ) + + # Act + app.set_up() + + # Assert + if want_custom_instrumentor_called: + custom_instrumentor.assert_called_once_with(_TEST_PROJECT_ID) + else: + custom_instrumentor.assert_not_called() + + @mock.patch.dict( + os.environ, + { + "GOOGLE_CLOUD_AGENT_ENGINE_ID": "test_agent_id", + "OTEL_RESOURCE_ATTRIBUTES": "some-attribute=some-value", + }, + ) + def test_tracing_setup( + self, + monkeypatch, + tracer_provider_mock: mock.Mock, + otlp_span_exporter_mock: mock.Mock, + get_project_id_mock: mock.Mock, + warn_if_telemetry_api_disabled_mock: mock.Mock, + ): + monkeypatch.setattr( + "uuid.uuid4", lambda: uuid.UUID("12345678123456781234567812345678") + ) + monkeypatch.setattr("os.getpid", lambda: 123123123) + with mock.patch.object(initializer.global_config, "_project", _TEST_PROJECT): + app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True) + app.set_up() + + otlp_span_exporter_mock.assert_called_once_with( + session=mock.ANY, + endpoint="https://telemetry.googleapis.com/v1/traces", + headers=mock.ANY, + ) + + get_project_id_mock.assert_called_with(_TEST_PROJECT) + + user_agent = otlp_span_exporter_mock.call_args.kwargs["headers"]["User-Agent"] + assert ( + re.fullmatch( + r"Vertex-Agent-Engine\/[\d\.]+ OTel-OTLP-Exporter-Python\/[\d\.]+", + user_agent, + ) + is not None + ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + tracer_provider_mock, + simple_span_processor_mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._tmpl_attrs.get("instrumentor") is not None + # assert ( + # "enable_tracing=True but proceeding with tracing disabled" + # not in caplog.text + # ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + assert app._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # app.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # @pytest.mark.parametrize( + # "enable_tracing,want_warning", + # [ + # (True, False), + # (False, True), + # (None, False), + # ], + # ) + # @pytest.mark.usefixtures("caplog") + # def test_tracing_disabled_warning(self, enable_tracing, want_warning, caplog): + # _ = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=enable_tracing) + # assert ( + # "[WARNING] Your 'enable_tracing=False' setting" in caplog.text + # ) == want_warning + + @mock.patch.dict(os.environ) + def test_span_content_capture_disabled_by_default(self, get_project_id_mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false" + + @mock.patch.dict( + os.environ, {"OTEL_INSTRUMENTATION_GENAI_CAPTURE_MESSAGE_CONTENT": "true"} + ) + def test_span_content_capture_disabled_with_env_var(self, get_project_id_mock): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + app.set_up() + assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "false" + + @mock.patch.dict(os.environ) + def test_span_content_capture_enabled_with_tracing( + self, + get_project_id_mock, + warn_if_telemetry_api_disabled_mock, + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT, enable_tracing=True) + app.set_up() + assert os.environ["ADK_CAPTURE_MESSAGE_CONTENT_IN_SPANS"] == "true" + + +def test_dump_event_for_json(): + from google.adk.events import event + + raw_signature = b"test_signature" + # Create an event with both a ThoughtPart and a FunctionCallPart + test_event = event.Event( + **{ + "author": _TEST_AGENT_NAME, + "content": { + "parts": [ + { + "thought_signature": raw_signature, + "text": "This is a test", + }, + ], + "role": "model", + }, + "id": "test_id", + "invocation_id": "test_invocation_id", + } + ) + dumped_event = _utils.dump_event_for_json(test_event) + + part = dumped_event["content"]["parts"][0] + assert "text" in part + assert part["text"] == "This is a test" + assert "thought_signature" in part + assert isinstance(part["thought_signature"], str) + assert base64.b64decode(part["thought_signature"]) == raw_signature + + +# def test_adk_app_initialization_with_api_key(): +# importlib.reload(initializer) +# importlib.reload(agentplatform) +# try: +# agentplatform.init(api_key=_TEST_API_KEY) +# app = agent_engines.AdkApp(agent=_TEST_AGENT) +# assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY +# assert app._tmpl_attrs.get("runner") is None +# app.set_up() +# assert app._tmpl_attrs.get("runner") is not None +# assert os.environ.get("GOOGLE_API_KEY") == _TEST_API_KEY +# assert "GOOGLE_CLOUD_LOCATION" not in os.environ +# assert "GOOGLE_CLOUD_PROJECT" not in os.environ +# finally: +# initializer.global_pool.shutdown(wait=True) + + +# def test_adk_app_initialization_with_env_api_key(): +# try: +# os.environ["GOOGLE_API_KEY"] == _TEST_API_KEY +# app = agent_engines.AdkApp(agent=_TEST_AGENT) +# assert app._tmpl_attrs.get("express_mode_api_key") == _TEST_API_KEY +# assert app._tmpl_attrs.get("runner") is None +# app.set_up() +# assert app._tmpl_attrs.get("runner") is not None +# assert "GOOGLE_CLOUD_LOCATION" not in os.environ +# assert "GOOGLE_CLOUD_PROJECT" not in os.environ +# finally: +# initializer.global_pool.shutdown(wait=True) + + +@pytest.mark.usefixtures("is_version_sufficient_mock") +class TestAdkAppErrors: + @pytest.mark.asyncio + async def test_raise_get_session_not_found_error(self, get_project_id_mock): + with pytest.raises( + RuntimeError, + match=r"Session not found. Please create it using .create_session()", + ): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + await app.async_get_session( + user_id="non_existent_user", + session_id="test_session_id", + ) + + @pytest.mark.asyncio + async def test_async_stream_query_invalid_message_type(self): + app = agent_engines.AdkApp(agent=_TEST_AGENT) + with pytest.raises( + TypeError, + match="message must be a string or a dictionary representing a Content object.", + ): + async for _ in app.async_stream_query(user_id=_TEST_USER_ID, message=123): + pass + + +@pytest.fixture(scope="module") +def create_agent_engine_mock(): + with mock.patch.object( + reasoning_engine_service.ReasoningEngineServiceClient, + "create_reasoning_engine", + ) as create_agent_engine_mock: + create_agent_engine_lro_mock = mock.Mock(ga_operation.Operation) + create_agent_engine_lro_mock.result.return_value = _TEST_AGENT_ENGINE_OBJ + create_agent_engine_mock.return_value = create_agent_engine_lro_mock + yield create_agent_engine_mock + + +@pytest.fixture(scope="module") +def get_agent_engine_mock(): + with mock.patch.object( + reasoning_engine_service.ReasoningEngineServiceClient, + "get_reasoning_engine", + ) as get_agent_engine_mock: + api_client_mock = mock.Mock() + api_client_mock.get_reasoning_engine.return_value = _TEST_AGENT_ENGINE_OBJ + get_agent_engine_mock.return_value = api_client_mock + yield get_agent_engine_mock + + +@pytest.fixture(scope="module") +def cloud_storage_create_bucket_mock(): + with mock.patch.object(storage, "Client") as cloud_storage_mock: + bucket_mock = mock.Mock(spec=storage.Bucket) + bucket_mock.blob.return_value.open.return_value = "blob_file" + bucket_mock.blob.return_value.upload_from_filename.return_value = None + bucket_mock.blob.return_value.upload_from_string.return_value = None + + cloud_storage_mock.get_bucket = mock.Mock( + side_effect=ValueError("bucket not found") + ) + cloud_storage_mock.bucket.return_value = bucket_mock + cloud_storage_mock.create_bucket.return_value = bucket_mock + + yield cloud_storage_mock + + +@pytest.fixture(scope="module") +def cloudpickle_dump_mock(): + with mock.patch.object(cloudpickle, "dump") as cloudpickle_dump_mock: + yield cloudpickle_dump_mock + + +@pytest.fixture(scope="module") +def cloudpickle_load_mock(): + with mock.patch.object(cloudpickle, "load") as cloudpickle_load_mock: + yield cloudpickle_load_mock + + +@pytest.fixture(scope="function") +def get_gca_resource_mock(): + with mock.patch.object( + base.VertexAiResourceNoun, + "_get_gca_resource", + ) as get_gca_resource_mock: + get_gca_resource_mock.return_value = _TEST_AGENT_ENGINE_OBJ + yield get_gca_resource_mock + + +# Function scope is required for the pytest parameterized tests. +@pytest.fixture(scope="function") +def update_agent_engine_mock(): + with mock.patch.object( + reasoning_engine_service.ReasoningEngineServiceClient, + "update_reasoning_engine", + ) as update_agent_engine_mock: + yield update_agent_engine_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestAgentEngines: + + def setup_method(self): + importlib.reload(initializer) + importlib.reload(aiplatform) + aiplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + credentials=_TEST_CREDENTIALS, + staging_bucket=_TEST_STAGING_BUCKET, + ) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + @pytest.mark.parametrize( + "env_vars,expected_env_vars", + [ + ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified"}), + (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified"}), + ( + {"some_env": "some_val"}, + { + "some_env": "some_val", + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified", + }, + ), + ( + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}, + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}, + ), + ( + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, + ), + ], + ) + def test_create_default_telemetry_enablement( + self, + create_agent_engine_mock: mock.Mock, + cloud_storage_create_bucket_mock: mock.Mock, + cloudpickle_dump_mock: mock.Mock, + cloudpickle_load_mock: mock.Mock, + get_gca_resource_mock: mock.Mock, + env_vars: dict[str, str], + expected_env_vars: dict[str, str], + ): + agent_engines.create( + agent_engine=agent_engines.AdkApp(agent=_TEST_AGENT), + env_vars=env_vars, + ) + deployment_spec = create_agent_engine_mock.call_args.kwargs[ + "reasoning_engine" + ].spec.deployment_spec + assert _utils.to_dict(deployment_spec)["env"] == [ + {"name": key, "value": value} for key, value in expected_env_vars.items() + ] + + @pytest.mark.parametrize( + "env_vars,expected_env_vars", + [ + ({}, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified"}), + (None, {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified"}), + ( + {"some_env": "some_val"}, + { + "some_env": "some_val", + GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "unspecified", + }, + ), + ( + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}, + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "true"}, + ), + ( + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, + {GOOGLE_CLOUD_AGENT_ENGINE_ENABLE_TELEMETRY: "false"}, + ), + ], + ) + def test_update_default_telemetry_enablement( + self, + update_agent_engine_mock: mock.Mock, + cloud_storage_create_bucket_mock: mock.Mock, + cloudpickle_dump_mock: mock.Mock, + cloudpickle_load_mock: mock.Mock, + get_gca_resource_mock: mock.Mock, + get_agent_engine_mock: mock.Mock, + env_vars: dict[str, str], + expected_env_vars: dict[str, str], + ): + agent_engines.update( + resource_name=_TEST_AGENT_ENGINE_RESOURCE_NAME, + description="foobar", # avoid "At least one of ... must be specified" errors. + env_vars=env_vars, + ) + update_agent_engine_mock.assert_called_once() + deployment_spec = update_agent_engine_mock.call_args.kwargs[ + "request" + ].reasoning_engine.spec.deployment_spec + assert _utils.to_dict(deployment_spec)["env"] == [ + {"name": key, "value": value} for key, value in expected_env_vars.items() + ] + + +class TestAdkAppMtls: + """Test cases for mTLS functionality in AdkApp.""" + + def test_use_client_cert_effective_with_should_use_client_cert(self): + """Verifies that it respects the google-auth mTLS enablement check.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + return_value=True, + create=True, + ): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "true"}) + def test_use_client_cert_effective_with_env_var_true(self): + """Verifies that it falls back to the environment variable if google-auth check fails.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + assert adk_template._use_client_cert_effective() is True + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "false"}) + def test_use_client_cert_effective_with_env_var_false(self): + """Verifies that it respects the environment variable being set to false.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_default(self): + """Verifies the default telemetry endpoint is returned when no mTLS is configured.""" + assert ( + adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "always"}) + def test_get_api_endpoint_always_with_cert(self): + """Verifies the mTLS endpoint is used when forced and a certificate is available.""" + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "auto"}) + def test_get_api_endpoint_auto_no_cert(self): + """Verifies it falls back to regular endpoint even if forced if no certificate is provided.""" + assert ( + adk_template._get_api_endpoint() == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "never"}) + def test_get_api_endpoint_never(self): + """Verifies the regular endpoint is used when mTLS is explicitly disabled.""" + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_with_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS enabled.""" + # Mocking to enable mTLS + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, "has_default_client_cert_source", return_value=True + ): + with mock.patch.object( + mtls, + "default_client_cert_source", + return_value=lambda: b"cert", + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Verify the session was configured for mTLS + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + # Verify the exporter was initialized with the mTLS endpoint + mock_exporter.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_with_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS enabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + # Mocking to enable mTLS + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, "has_default_client_cert_source", return_value=True + ): + with mock.patch.object( + mtls, + "default_client_cert_source", + return_value=lambda: b"cert", + ): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was configured for the check request + mock_session.configure_mtls_channel.assert_called_once() + # Verify the check was performed against the mTLS endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT, data=None + ) + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_MTLS_ENDPOINT": "invalid_value"}) + def test_get_api_endpoint_invalid_env(self): + """Verifies it defaults to AUTO and warns on invalid env var.""" + with mock.patch.object(adk_template, "_warn") as mock_warn: + assert ( + adk_template._get_api_endpoint() + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + mock_warn.assert_called_once() + + @mock.patch.dict(os.environ, {"GOOGLE_API_USE_CLIENT_CERTIFICATE": "not_a_bool"}) + def test_use_client_cert_effective_invalid_env(self): + """Verifies it warns on invalid boolean env var.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + side_effect=AttributeError, + create=True, + ): + with mock.patch.object(adk_template, "_warn") as mock_warn: + assert adk_template._use_client_cert_effective() is False + mock_warn.assert_called_once() + + def test_use_client_cert_effective_with_should_use_client_cert_false(self): + """Verifies that it respects google-auth returning False for mTLS.""" + with mock.patch.object( + mtls, + "should_use_client_cert", + return_value=False, + create=True, + ): + assert adk_template._use_client_cert_effective() is False + + def test_get_api_endpoint_auto_with_cert(self): + """Verifies the mTLS endpoint is used in AUTO mode when a cert is available.""" + # AUTO is the default, so we just pass a cert + assert ( + adk_template._get_api_endpoint(client_cert_source=b"cert") + == adk_template._DEFAULT_MTLS_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_no_mtls( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the instrumentor builder with mTLS disabled.""" + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=False + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Verify mTLS channel was NOT configured + mock_session_cls.return_value.configure_mtls_channel.assert_not_called() + # Verify the exporter was initialized with the regular endpoint + mock_exporter.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + def test_warn_if_telemetry_api_disabled_no_mtls( + self, + mock_session_cls, + mock_auth_default, + ): + """Integration test for the telemetry API check with mTLS disabled.""" + mock_session = mock_session_cls.return_value + mock_session.post.return_value = mock.Mock(text="") + + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=False + ): + adk_template._warn_if_telemetry_api_disabled() + + # Verify mTLS channel was NOT configured + mock_session.configure_mtls_channel.assert_not_called() + # Verify the check was performed against the regular endpoint + mock_session.post.assert_called_once_with( + adk_template._DEFAULT_TELEMETRY_ENDPOINT, data=None + ) + + @mock.patch("google.auth.default", return_value=(mock.Mock(), _TEST_PROJECT)) + @mock.patch.object(adk_template.requests_auth, "AuthorizedSession") + @mock.patch( + "opentelemetry.exporter.otlp.proto.http.trace_exporter.OTLPSpanExporter" + ) + def test_default_instrumentor_builder_mtls_no_cert_source( + self, + mock_exporter, + mock_session_cls, + mock_auth_default, + ): + """Tests that it falls back to regular endpoint if mTLS is on but no cert is found.""" + with mock.patch.object( + adk_template, "_use_client_cert_effective", return_value=True + ): + with mock.patch.object( + mtls, + "has_default_client_cert_source", + return_value=False, + ): + adk_template._default_instrumentor_builder( + _TEST_PROJECT_ID, enable_tracing=True + ) + + # Channel is configured, but endpoint remains default due to missing cert source + mock_session_cls.return_value.configure_mtls_channel.assert_called_once() + assert ( + mock_exporter.call_args.kwargs["endpoint"] + == adk_template._DEFAULT_TELEMETRY_ENDPOINT + ) diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_ag2.py b/tests/unit/agentplatform/frameworks/test_frameworks_ag2.py new file mode 100644 index 0000000000..7a02cefacf --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_ag2.py @@ -0,0 +1,389 @@ +# Copyright 2025 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import dataclasses +import importlib +import json +from typing import Optional +from unittest import mock + +from google import auth +import agentplatform +from google.cloud.aiplatform import initializer +from agentplatform import agent_engines +from agentplatform.agent_engines import _utils +import pytest + + +_DEFAULT_PLACE_TOOL_ACTIVITY = "museums" +_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3 +_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400 +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_MODEL = "gemini-1.0-pro" +_TEST_RUNNABLE_NAME = "test-runnable" +_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot." + + +def place_tool_query( + city: str, + activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY, + page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE, +): + """Searches the city for recommendations on the activity.""" + return {"city": city, "activity": activity, "page_size": page_size} + + +def place_photo_query( + photo_reference: str, + maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH, + maxheight: Optional[int] = None, +): + """Returns the photo for a given reference.""" + result = {"photo_reference": photo_reference, "maxwidth": maxwidth} + if maxheight: + result["maxheight"] = maxheight + return result + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def dataclasses_asdict_mock(): + with mock.patch.object(dataclasses, "asdict") as dataclasses_asdict_mock: + dataclasses_asdict_mock.return_value = {} + yield dataclasses_asdict_mock + + +@pytest.fixture +def dataclasses_is_dataclass_mock(): + with mock.patch.object( + dataclasses, "is_dataclass" + ) as dataclasses_is_dataclass_mock: + dataclasses_is_dataclass_mock.return_value = True + yield dataclasses_is_dataclass_mock + + +@pytest.fixture +def to_json_serializable_autogen_object_mock(): + with mock.patch.object( + _utils, + "to_json_serializable_autogen_object", + ) as to_json_serializable_autogen_object_mock: + to_json_serializable_autogen_object_mock.return_value = {} + yield to_json_serializable_autogen_object_mock + + +@pytest.fixture +def cloud_trace_exporter_mock(): + with mock.patch.object( + _utils, + "_import_cloud_trace_exporter_or_warn", + ) as cloud_trace_exporter_mock: + yield cloud_trace_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def autogen_instrumentor_mock(): + with mock.patch.object( + _utils, + "_import_openinference_autogen_or_warn", + ) as autogen_instrumentor_mock: + yield autogen_instrumentor_mock + + +@pytest.fixture +def autogen_instrumentor_none_mock(): + with mock.patch.object( + _utils, + "_import_openinference_autogen_or_warn", + ) as autogen_instrumentor_mock: + autogen_instrumentor_mock.return_value = None + yield autogen_instrumentor_mock + + +@pytest.fixture +def autogen_tools_mock(): + with mock.patch.object( + _utils, + "_import_autogen_tools_or_warn", + ) as autogen_tools_mock: + autogen_tools_mock.return_value = mock.MagicMock() + yield autogen_tools_mock + + +class MockAgent: + def __init__(self, name=None, description=None): + self.name = name + self.description = description + + +class MockCost: + def __init__(self, total_cost=0.0): + self.total_cost = total_cost + + def model_dump_json(self): + return json.dumps({"total_cost": self.total_cost}) + + +@pytest.mark.usefixtures("google_auth_mock") +class TestAG2Agent: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, runnable_name=_TEST_RUNNABLE_NAME + ) + assert agent._tmpl_attrs.get("model_name") == _TEST_MODEL + assert agent._tmpl_attrs.get("runnable_name") == _TEST_RUNNABLE_NAME + assert agent._tmpl_attrs.get("project") == _TEST_PROJECT + assert agent._tmpl_attrs.get("location") == _TEST_LOCATION + assert agent._tmpl_attrs.get("runnable") is None + + def test_initialization_with_tools(self, autogen_tools_mock): + tools = [ + place_tool_query, + place_photo_query, + ] + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + system_instruction=_TEST_SYSTEM_INSTRUCTION, + tools=tools, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + assert agent._tmpl_attrs.get("tools") + assert not agent._tmpl_attrs.get("ag2_tool_objects") + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent._tmpl_attrs.get("ag2_tool_objects") + + def test_set_up(self): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_clone(self): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + runnable_builder=lambda **kwargs: kwargs, + ) + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + agent_clone = agent.clone() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent_clone._tmpl_attrs.get("runnable") is None + agent_clone.set_up() + assert agent_clone._tmpl_attrs.get("runnable") is not None + + def test_query(self, to_json_serializable_autogen_object_mock): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + ) + agent._tmpl_attrs["runnable"] = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._tmpl_attrs["runnable"], attribute="run") + agent.query(input="test query") + mocks.assert_has_calls( + [ + mock.call.run.run( + message={"content": "test query"}, + user_input=False, + tools=[], + max_turns=None, + ) + ] + ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + cloud_trace_exporter_mock, + tracer_provider_mock, + simple_span_processor_mock, + autogen_instrumentor_mock, + ): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._tmpl_attrs.get("instrumentor") is not None + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog, autogen_instrumentor_none_mock): + agent = agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + +def _return_input_no_typing(input_): + """Returns input back to user.""" + return input_ + + +class TestConvertToolsOrRaiseErrors: + def test_raise_untyped_input_args(self, agentplatform_init_mock): + with pytest.raises(TypeError, match=r"has untyped input_arg"): + agent_engines.AG2Agent( + model=_TEST_MODEL, + runnable_name=_TEST_RUNNABLE_NAME, + tools=[_return_input_no_typing], + ) + + +class TestToJsonSerializableAutoGenObject: + """Tests for `_utils.to_json_serializable_autogen_object`.""" + + def test_autogen_chat_result( + self, + dataclasses_asdict_mock, + dataclasses_is_dataclass_mock, + ): + mock_chat_result: _utils.AutogenChatResult = mock.Mock( + spec=_utils.AutogenChatResult + ) + _utils.to_json_serializable_autogen_object(mock_chat_result) + dataclasses_is_dataclass_mock.assert_called_once_with(mock_chat_result) + dataclasses_asdict_mock.assert_called_once_with(mock_chat_result) + + def test_autogen_run_response(self): + mock_response: _utils.AutogenRunResponse = mock.Mock( + spec=_utils.AutogenRunResponse + ) + mock_agent = MockAgent( + name="TestAgent", + description="Agent Description", + ) + mock_cost = MockCost(total_cost=5.5) + mock_response.summary = "summary" + mock_response.messages = [{"role": "user", "content": "Hello"}] + mock_response.context_variables = {"var1": "value1"} + mock_response.last_speaker = mock_agent + mock_response.cost = mock_cost + mock_response.process = mock.MagicMock() + + want = { + "summary": "summary", + "messages": [{"role": "user", "content": "Hello"}], + "context_variables": {"var1": "value1"}, + "last_speaker": { + "name": "TestAgent", + "description": "Agent Description", + }, + "cost": {"total_cost": 5.5}, + } + got = _utils.to_json_serializable_autogen_object(mock_response) + mock_response.process.assert_called_once() + assert got == want + + def test_autogen_empty_run_response(self): + mock_response: _utils.AutogenRunResponse = mock.Mock( + spec=_utils.AutogenRunResponse + ) + mock_response.summary = None + mock_response.messages = [] + mock_response.context_variables = None + mock_response.last_speaker = None + mock_response.cost = None + want = { + "summary": None, + "messages": [], + "context_variables": None, + "last_speaker": None, + "cost": None, + } + got = _utils.to_json_serializable_autogen_object(mock_response) + assert got == want + + +class TestDataClassToJsonSerializable: + """Tests for `_utils._dataclass_to_dict_or_raise`.""" + + def test_valid_dataclass(self): + @dataclasses.dataclass + class SimpleDataClass: + field1: str + field2: int + + instance = SimpleDataClass(field1="value1", field2=123) + want = {"field1": "value1", "field2": 123} + got = _utils._dataclass_to_dict_or_raise(instance) + assert got == want + + def test_not_a_dataclass_raises_type_error(self): + class NotADataclass: + pass + + instance = NotADataclass() + with pytest.raises(TypeError, match="Object is not a dataclass"): + _utils._dataclass_to_dict_or_raise(instance) diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_langchain.py b/tests/unit/agentplatform/frameworks/test_frameworks_langchain.py new file mode 100644 index 0000000000..54c7f5834d --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_langchain.py @@ -0,0 +1,300 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +from typing import Optional +from unittest import mock + +from google import auth +import agentplatform +from google.cloud.aiplatform import initializer +from agentplatform import agent_engines + +from agentplatform.agent_engines import _utils +import pytest + + +from langchain_core import prompts +from langchain_core.load import dump as langchain_load_dump +from langchain_classic.agents.format_scratchpad import ( + format_to_openai_function_messages, +) +from langchain_core.tools import StructuredTool + + +_DEFAULT_PLACE_TOOL_ACTIVITY = "museums" +_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3 +_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400 +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_MODEL = "gemini-1.0-pro" +_TEST_SYSTEM_INSTRUCTION = "You are a helpful bot." + + +def place_tool_query( + city: str, + activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY, + page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE, +): + """Searches the city for recommendations on the activity.""" + return {"city": city, "activity": activity, "page_size": page_size} + + +def place_photo_query( + photo_reference: str, + maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH, + maxheight: Optional[int] = None, +): + """Returns the photo for a given reference.""" + result = {"photo_reference": photo_reference, "maxwidth": maxwidth} + if maxheight: + result["maxheight"] = maxheight + return result + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def langchain_dump_mock(): + with mock.patch.object(langchain_load_dump, "dumpd") as langchain_dump_mock: + yield langchain_dump_mock + + +@pytest.fixture +def cloud_trace_exporter_mock(): + with mock.patch.object( + _utils, + "_import_cloud_trace_exporter_or_warn", + ) as cloud_trace_exporter_mock: + yield cloud_trace_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def langchain_instrumentor_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + yield langchain_instrumentor_mock + + +@pytest.fixture +def langchain_instrumentor_none_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + langchain_instrumentor_mock.return_value = None + yield langchain_instrumentor_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestLangchainAgent: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + self.prompt = { + "input": lambda x: x["input"], + "agent_scratchpad": ( + lambda x: format_to_openai_function_messages(x["intermediate_steps"]) + ), + } | prompts.ChatPromptTemplate.from_messages( + [ + ("user", "{input}"), + prompts.MessagesPlaceholder(variable_name="agent_scratchpad"), + ] + ) + self.output_parser = mock.Mock() + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + agent = agent_engines.LangchainAgent(model=_TEST_MODEL) + assert agent._tmpl_attrs.get("model_name") == _TEST_MODEL + assert agent._tmpl_attrs.get("project") == _TEST_PROJECT + assert agent._tmpl_attrs.get("location") == _TEST_LOCATION + assert agent._tmpl_attrs.get("runnable") is None + + def test_initialization_with_tools(self): + tools = [ + place_tool_query, + StructuredTool.from_function(place_photo_query), + ] + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + system_instruction=_TEST_SYSTEM_INSTRUCTION, + tools=tools, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + for tool, agent_tool in zip(tools, agent._tmpl_attrs.get("tools")): + assert isinstance(agent_tool, type(tool)) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_set_up(self): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_clone(self): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + agent_clone = agent.clone() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent_clone._tmpl_attrs.get("runnable") is None + agent_clone.set_up() + assert agent_clone._tmpl_attrs.get("runnable") is not None + + def test_query(self, langchain_dump_mock): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + ) + agent._tmpl_attrs["runnable"] = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._tmpl_attrs["runnable"], attribute="invoke") + agent.query(input="test query") + mocks.assert_has_calls( + [mock.call.invoke.invoke(input={"input": "test query"}, config=None)] + ) + + def test_stream_query(self, langchain_dump_mock): + agent = agent_engines.LangchainAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].stream.return_value = [] + list(agent.stream_query(input="test stream query")) + agent._tmpl_attrs["runnable"].stream.assert_called_once_with( + input={"input": "test stream query"}, + config=None, + ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + cloud_trace_exporter_mock, + tracer_provider_mock, + simple_span_processor_mock, + langchain_instrumentor_mock, + ): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._tmpl_attrs.get("instrumentor") is not None + # assert ( + # "enable_tracing=True but proceeding with tracing disabled" + # not in caplog.text + # ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock): + agent = agent_engines.LangchainAgent( + model=_TEST_MODEL, + prompt=self.prompt, + output_parser=self.output_parser, + enable_tracing=True, + ) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + +def _return_input_no_typing(input_): + """Returns input back to user.""" + return input_ + + +class TestConvertToolsOrRaiseErrors: + def test_raise_untyped_input_args(self, agentplatform_init_mock): + with pytest.raises(TypeError, match=r"has untyped input_arg"): + agent_engines.LangchainAgent( + model=_TEST_MODEL, + tools=[_return_input_no_typing], + ) + + +class TestSystemInstructionAndPromptRaisesErrors: + def test_raise_both_system_instruction_and_prompt_error(self, agentplatform_init_mock): + with pytest.raises( + ValueError, + match=r"Only one of `prompt` or `system_instruction` should be specified.", + ): + agent_engines.LangchainAgent( + model=_TEST_MODEL, + system_instruction=_TEST_SYSTEM_INSTRUCTION, + prompt=prompts.ChatPromptTemplate.from_messages( + [ + ("user", "{input}"), + ] + ), + ) diff --git a/tests/unit/agentplatform/frameworks/test_frameworks_langgraph.py b/tests/unit/agentplatform/frameworks/test_frameworks_langgraph.py new file mode 100644 index 0000000000..09318ae514 --- /dev/null +++ b/tests/unit/agentplatform/frameworks/test_frameworks_langgraph.py @@ -0,0 +1,365 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +import importlib +from typing import Any, Dict, List, Optional +from unittest import mock + +from google import auth +import agentplatform +from google.cloud.aiplatform import initializer +from agentplatform import agent_engines +from agentplatform.agent_engines import _utils +import pytest + +from langchain_core import runnables +from langchain_core.load import dump as langchain_load_dump +from langchain_core.tools import StructuredTool + + +_DEFAULT_PLACE_TOOL_ACTIVITY = "museums" +_DEFAULT_PLACE_TOOL_PAGE_SIZE = 3 +_DEFAULT_PLACE_PHOTO_MAXWIDTH = 400 +_TEST_LOCATION = "us-central1" +_TEST_PROJECT = "test-project" +_TEST_MODEL = "gemini-1.0-pro" +_TEST_CONFIG = runnables.RunnableConfig(configurable={"thread_id": "thread-values"}) + + +def place_tool_query( + city: str, + activity: str = _DEFAULT_PLACE_TOOL_ACTIVITY, + page_size: int = _DEFAULT_PLACE_TOOL_PAGE_SIZE, +): + """Searches the city for recommendations on the activity.""" + return {"city": city, "activity": activity, "page_size": page_size} + + +def place_photo_query( + photo_reference: str, + maxwidth: int = _DEFAULT_PLACE_PHOTO_MAXWIDTH, + maxheight: Optional[int] = None, +): + """Returns the photo for a given reference.""" + result = {"photo_reference": photo_reference, "maxwidth": maxwidth} + if maxheight: + result["maxheight"] = maxheight + return result + + +def _checkpointer_builder(**unused_kwargs): + try: + from langgraph.checkpoint import memory + except ImportError: + from langgraph_checkpoint.checkpoint import memory + + return memory.MemorySaver() + + +def _get_state_messages(state: Dict[str, Any]) -> List[str]: + messages = [] + for message in state.get("values").get("messages"): + messages.append(message.content) + return messages + + +@pytest.fixture(scope="module") +def google_auth_mock(): + with mock.patch.object(auth, "default") as google_auth_mock: + credentials_mock = mock.Mock() + credentials_mock.with_quota_project.return_value = None + google_auth_mock.return_value = ( + credentials_mock, + _TEST_PROJECT, + ) + yield google_auth_mock + + +@pytest.fixture +def agentplatform_init_mock(): + with mock.patch.object(agentplatform, "init") as agentplatform_init_mock: + yield agentplatform_init_mock + + +@pytest.fixture +def langchain_dump_mock(): + with mock.patch.object(langchain_load_dump, "dumpd") as langchain_dump_mock: + yield langchain_dump_mock + + +@pytest.fixture +def cloud_trace_exporter_mock(): + with mock.patch.object( + _utils, + "_import_cloud_trace_exporter_or_warn", + ) as cloud_trace_exporter_mock: + yield cloud_trace_exporter_mock + + +@pytest.fixture +def tracer_provider_mock(): + with mock.patch("opentelemetry.sdk.trace.TracerProvider") as tracer_provider_mock: + yield tracer_provider_mock + + +@pytest.fixture +def simple_span_processor_mock(): + with mock.patch( + "opentelemetry.sdk.trace.export.SimpleSpanProcessor" + ) as simple_span_processor_mock: + yield simple_span_processor_mock + + +@pytest.fixture +def langchain_instrumentor_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + yield langchain_instrumentor_mock + + +@pytest.fixture +def langchain_instrumentor_none_mock(): + with mock.patch.object( + _utils, + "_import_openinference_langchain_or_warn", + ) as langchain_instrumentor_mock: + langchain_instrumentor_mock.return_value = None + yield langchain_instrumentor_mock + + +@pytest.mark.usefixtures("google_auth_mock") +class TestLanggraphAgent: + def setup_method(self): + importlib.reload(initializer) + importlib.reload(agentplatform) + agentplatform.init( + project=_TEST_PROJECT, + location=_TEST_LOCATION, + ) + + def teardown_method(self): + initializer.global_pool.shutdown(wait=True) + + def test_initialization(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + assert agent._tmpl_attrs.get("model_name") == _TEST_MODEL + assert agent._tmpl_attrs.get("project") == _TEST_PROJECT + assert agent._tmpl_attrs.get("location") == _TEST_LOCATION + assert agent._tmpl_attrs.get("runnable") is None + + def test_initialization_with_tools(self): + tools = [ + place_tool_query, + StructuredTool.from_function(place_photo_query), + ] + agent = agent_engines.LanggraphAgent( + model=_TEST_MODEL, + tools=tools, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + for tool, agent_tool in zip(tools, agent._tmpl_attrs.get("tools")): + assert isinstance(agent_tool, type(tool)) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_set_up(self): + agent = agent_engines.LanggraphAgent( + model=_TEST_MODEL, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + assert agent._tmpl_attrs.get("runnable") is None + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + + def test_clone(self): + agent = agent_engines.LanggraphAgent( + model=_TEST_MODEL, + model_builder=lambda **kwargs: kwargs, + runnable_builder=lambda **kwargs: kwargs, + ) + agent.set_up() + assert agent._tmpl_attrs.get("runnable") is not None + agent_clone = agent.clone() + assert agent._tmpl_attrs.get("runnable") is not None + assert agent_clone._tmpl_attrs.get("runnable") is None + agent_clone.set_up() + assert agent_clone._tmpl_attrs.get("runnable") is not None + + def test_query(self, langchain_dump_mock): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + mocks = mock.Mock() + mocks.attach_mock(mock=agent._tmpl_attrs.get("runnable"), attribute="invoke") + agent.query(input="test query") + mocks.assert_has_calls( + [ + mock.call.invoke.invoke( + input={"input": "test query", "messages": [("user", "test query")]}, + config=None, + ) + ] + ) + + def test_stream_query(self, langchain_dump_mock): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].stream.return_value = [] + list(agent.stream_query(input="test stream query")) + agent._tmpl_attrs["runnable"].stream.assert_called_once_with( + input={ + "input": "test stream query", + "messages": [("user", "test stream query")], + }, + config=None, + ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing( + self, + caplog, + cloud_trace_exporter_mock, + tracer_provider_mock, + simple_span_processor_mock, + langchain_instrumentor_mock, + ): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL, enable_tracing=True) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/384730642): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert agent._instrumentor is not None + # assert ( + # "enable_tracing=True but proceeding with tracing disabled" + # not in caplog.text + # ) + + @pytest.mark.usefixtures("caplog") + def test_enable_tracing_warning(self, caplog, langchain_instrumentor_none_mock): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL, enable_tracing=True) + assert agent._tmpl_attrs.get("instrumentor") is None + # TODO(b/383923584): Re-enable this test once the parent issue is fixed. + # agent.set_up() + # assert "enable_tracing=True but proceeding with tracing disabled" in caplog.text + + def test_get_state_history_empty(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state_history.return_value = [] + history = list(agent.get_state_history()) + assert history == [] + + def test_get_state_history(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state_history.return_value = [ + mock.Mock(), + mock.Mock(), + ] + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 0 + ]._asdict.return_value = {"test_key_1": "test_value_1"} + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 1 + ]._asdict.return_value = {"test_key_2": "test_value_2"} + history = list(agent.get_state_history()) + assert history == [ + {"test_key_1": "test_value_1"}, + {"test_key_2": "test_value_2"}, + ] + + def test_get_state_history_with_config(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state_history.return_value = [ + mock.Mock(), + mock.Mock(), + ] + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 0 + ]._asdict.return_value = {"test_key_1": "test_value_1"} + agent._tmpl_attrs["runnable"].get_state_history.return_value[ + 1 + ]._asdict.return_value = {"test_key_2": "test_value_2"} + history = list(agent.get_state_history(config=_TEST_CONFIG)) + assert history == [ + {"test_key_1": "test_value_1"}, + {"test_key_2": "test_value_2"}, + ] + + def test_get_state(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value._asdict.return_value = { + "test_key": "test_value" + } + state = agent.get_state() + assert state == {"test_key": "test_value"} + + def test_get_state_with_config(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value = mock.Mock() + agent._tmpl_attrs["runnable"].get_state.return_value._asdict.return_value = { + "test_key": "test_value" + } + state = agent.get_state(config=_TEST_CONFIG) + assert state == {"test_key": "test_value"} + + def test_update_state(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent.update_state() + agent._tmpl_attrs["runnable"].update_state.assert_called_once() + + def test_update_state_with_config(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent.update_state(config=_TEST_CONFIG) + agent._tmpl_attrs["runnable"].update_state.assert_called_once_with( + config=_TEST_CONFIG + ) + + def test_update_state_with_config_and_kwargs(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + agent._tmpl_attrs["runnable"] = mock.Mock() + agent.update_state(config=_TEST_CONFIG, test_key="test_value") + agent._tmpl_attrs["runnable"].update_state.assert_called_once_with( + config=_TEST_CONFIG, test_key="test_value" + ) + + def test_register_operations(self): + agent = agent_engines.LanggraphAgent(model=_TEST_MODEL) + expected_operations = { + "": ["query", "get_state", "update_state"], + "stream": ["stream_query", "get_state_history"], + } + assert agent.register_operations() == expected_operations + + +def _return_input_no_typing(input_): + """Returns input back to user.""" + return input_ + + +class TestConvertToolsOrRaiseErrors: + def test_raise_untyped_input_args(self, agentplatform_init_mock): + with pytest.raises(TypeError, match=r"has untyped input_arg"): + agent_engines.LanggraphAgent( + model=_TEST_MODEL, tools=[_return_input_no_typing] + )