diff --git a/astrapy/admin/admin.py b/astrapy/admin/admin.py index f43255b5..16399679 100644 --- a/astrapy/admin/admin.py +++ b/astrapy/admin/admin.py @@ -215,6 +215,7 @@ def fetch_raw_database_info_from_id_token( dev_ops_api=True, redacted_header_names=_api_options.redacted_header_names, event_observers=_api_options.event_observers, + ca_cert_path=_api_options.ca_cert_path, ) gd_response = dev_ops_commander.request( @@ -313,6 +314,7 @@ async def async_fetch_raw_database_info_from_id_token( dev_ops_api=True, redacted_header_names=_api_options.redacted_header_names, event_observers=_api_options.event_observers, + ca_cert_path=_api_options.ca_cert_path, ) gd_response = await dev_ops_commander.async_request( @@ -517,6 +519,7 @@ def _get_dev_ops_api_commander(self) -> APICommander: redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) return dev_ops_commander @@ -547,6 +550,7 @@ def _get_dev_ops_regionlist_api_commander(self) -> APICommander: redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) return rl_dev_ops_commander @@ -2761,6 +2765,7 @@ def _get_api_commander(self) -> APICommander: redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander @@ -2789,6 +2794,7 @@ def _get_dev_ops_api_commander(self) -> APICommander: redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) return dev_ops_commander @@ -3976,6 +3982,7 @@ def _get_api_commander(self) -> APICommander: redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander diff --git a/astrapy/data/collection.py b/astrapy/data/collection.py index b2ca5056..81062784 100644 --- a/astrapy/data/collection.py +++ b/astrapy/data/collection.py @@ -343,6 +343,7 @@ def _get_api_commander(self) -> APICommander: handle_decimals_reads=( self.api_options.serdes_options.use_decimals_in_collections ), + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander @@ -3314,6 +3315,7 @@ def _get_api_commander(self) -> APICommander: handle_decimals_reads=( self.api_options.serdes_options.use_decimals_in_collections ), + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander diff --git a/astrapy/data/database.py b/astrapy/data/database.py index c0295017..914d6876 100644 --- a/astrapy/data/database.py +++ b/astrapy/data/database.py @@ -210,6 +210,7 @@ def _get_api_commander(self, keyspace: str | None) -> APICommander | None: redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander @@ -2216,6 +2217,7 @@ def command( redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) _cmd_desc = ",".join(sorted(body.keys())) @@ -2430,6 +2432,7 @@ def _get_api_commander(self, keyspace: str | None) -> APICommander | None: redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander @@ -4500,6 +4503,7 @@ async def command( redacted_header_names=self.api_options.redacted_header_names, event_observers=self.api_options.event_observers, spawner=self, + ca_cert_path=self.api_options.ca_cert_path, ) _cmd_desc = ",".join(sorted(body.keys())) diff --git a/astrapy/data/table.py b/astrapy/data/table.py index 9823a399..1d9d284e 100644 --- a/astrapy/data/table.py +++ b/astrapy/data/table.py @@ -386,6 +386,7 @@ def _get_api_commander(self) -> APICommander: spawner=self, handle_decimals_writes=True, handle_decimals_reads=True, + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander @@ -3437,6 +3438,7 @@ def _get_api_commander(self) -> APICommander: spawner=self, handle_decimals_writes=True, handle_decimals_reads=True, + ca_cert_path=self.api_options.ca_cert_path, ) return api_commander diff --git a/astrapy/utils/api_commander.py b/astrapy/utils/api_commander.py index b43b9036..6f877125 100644 --- a/astrapy/utils/api_commander.py +++ b/astrapy/utils/api_commander.py @@ -143,21 +143,28 @@ def __init__( event_observers: dict[str, Observer | None] = {}, handle_decimals_writes: bool = False, handle_decimals_reads: bool = False, + ca_cert_path: str | None = None, ) -> None: + self.ca_cert_path = ca_cert_path + if ca_cert_path is not None: + ssl_context = ssl.create_default_context(cafile=ca_cert_path) + else: + ssl_context = CLIENT_SSL_CONTEXT + ssl_control_headers: dict[str, str | None] if disable_ssl_reuse: self.client = httpx.Client( limits=no_pooling_limits, - verify=CLIENT_SSL_CONTEXT, + verify=ssl_context, ) self.async_client = httpx.AsyncClient( limits=no_pooling_limits, - verify=CLIENT_SSL_CONTEXT, + verify=ssl_context, ) ssl_control_headers = {"Connection": "close"} else: - self.client = httpx.Client(verify=CLIENT_SSL_CONTEXT) - self.async_client = httpx.AsyncClient(verify=CLIENT_SSL_CONTEXT) + self.client = httpx.Client(verify=ssl_context) + self.async_client = httpx.AsyncClient(verify=ssl_context) ssl_control_headers = {} self.api_endpoint = api_endpoint.rstrip("/") @@ -290,6 +297,7 @@ def _copy( else self.redacted_header_names ), dev_ops_api=dev_ops_api if dev_ops_api is not None else self.dev_ops_api, + ca_cert_path=self.ca_cert_path, ) def _compose_request_url(self, additional_path: str | None) -> str: diff --git a/astrapy/utils/api_options.py b/astrapy/utils/api_options.py index 8226c968..4e357abc 100644 --- a/astrapy/utils/api_options.py +++ b/astrapy/utils/api_options.py @@ -945,6 +945,12 @@ class APIOptions: dev_ops_api_url_options: an instance of `DevOpsAPIURLOptions` (see) to customize the URL used to reach the DevOps API (customizing this setting is rarely needed; relevant only for Astra DB environments). + ca_cert_path: optional path to a custom CA certificate file (PEM format) + to use for SSL verification when connecting to the API. This is useful + for PrivateLink setups or other environments where a custom CA is needed. + When set, a per-instance SSL context is created with the given CA file + instead of using the default certifi-based CA bundle. Defaults to None + (use the bundled certifi CA roots). Examples: >>> from astrapy import DataAPIClient @@ -1022,6 +1028,7 @@ class APIOptions: embedding_api_key: EmbeddingHeadersProvider | UnsetType = _UNSET reranking_api_key: RerankingHeadersProvider | UnsetType = _UNSET event_observers: dict[str, Observer | None] | UnsetType = _UNSET + ca_cert_path: str | None | UnsetType = _UNSET timeout_options: TimeoutOptions | UnsetType = _UNSET serdes_options: SerdesOptions | UnsetType = _UNSET @@ -1039,6 +1046,7 @@ def __init__( embedding_api_key: str | EmbeddingHeadersProvider | UnsetType = _UNSET, reranking_api_key: str | RerankingHeadersProvider | UnsetType = _UNSET, event_observers: dict[str, Observer | None] | UnsetType = _UNSET, + ca_cert_path: str | None | UnsetType = _UNSET, timeout_options: TimeoutOptions | UnsetType = _UNSET, serdes_options: SerdesOptions | UnsetType = _UNSET, data_api_url_options: DataAPIURLOptions | UnsetType = _UNSET, @@ -1062,6 +1070,7 @@ def __init__( reranking_api_key, ) self.event_observers = event_observers + self.ca_cert_path = ca_cert_path self.timeout_options = timeout_options self.serdes_options = serdes_options self.data_api_url_options = data_api_url_options @@ -1133,6 +1142,9 @@ def __repr__(self) -> str: None if isinstance(self.dev_ops_api_url_options, UnsetType) else f"dev_ops_api_url_options={self.dev_ops_api_url_options}", + None + if isinstance(self.ca_cert_path, UnsetType) + else f"ca_cert_path={self.ca_cert_path}", ) if pc is not None ] @@ -1219,6 +1231,7 @@ class FullAPIOptions(APIOptions): embedding_api_key: EmbeddingHeadersProvider reranking_api_key: RerankingHeadersProvider event_observers: dict[str, Observer | None] + ca_cert_path: str | None timeout_options: FullTimeoutOptions serdes_options: FullSerdesOptions @@ -1237,6 +1250,7 @@ def __init__( embedding_api_key: str | EmbeddingHeadersProvider, reranking_api_key: str | RerankingHeadersProvider, event_observers: dict[str, Observer | None], + ca_cert_path: str | None, timeout_options: FullTimeoutOptions, serdes_options: FullSerdesOptions, data_api_url_options: FullDataAPIURLOptions, @@ -1252,6 +1266,7 @@ def __init__( embedding_api_key=embedding_api_key, reranking_api_key=reranking_api_key, event_observers=event_observers, + ca_cert_path=ca_cert_path, timeout_options=timeout_options, serdes_options=serdes_options, data_api_url_options=data_api_url_options, @@ -1392,6 +1407,11 @@ def with_override(self, other: APIOptions | None | UnsetType) -> FullAPIOptions: else self.reranking_api_key ), event_observers=event_observers, + ca_cert_path=( + other.ca_cert_path + if not isinstance(other.ca_cert_path, UnsetType) + else self.ca_cert_path + ), timeout_options=timeout_options, serdes_options=serdes_options, data_api_url_options=data_api_url_options, @@ -1451,6 +1471,7 @@ def defaultAPIOptions(environment: str) -> FullAPIOptions: embedding_api_key=EmbeddingAPIKeyHeaderProvider(None), reranking_api_key=RerankingAPIKeyHeaderProvider(None), event_observers={}, + ca_cert_path=None, timeout_options=defaultTimeoutOptions, serdes_options=defaultSerdesOptions, data_api_url_options=defaultDataAPIURLOptions, diff --git a/tests/base/unit/test_apicommander.py b/tests/base/unit/test_apicommander.py index 99a58030..3c860434 100644 --- a/tests/base/unit/test_apicommander.py +++ b/tests/base/unit/test_apicommander.py @@ -71,6 +71,40 @@ def test_apicommander_conversions(self) -> None: ) assert cmd1 == cmd1._copy(dev_ops_api=False)._copy(dev_ops_api=True) + @pytest.mark.describe("test of APICommander ca_cert_path SSL context") + def test_apicommander_ca_cert_path(self) -> None: + import ssl + + import certifi + + from astrapy.utils.api_commander import CLIENT_SSL_CONTEXT + + # default: uses the shared global SSL context + cmd_default = APICommander( + api_endpoint="https://example.com", + path="/v1", + spawner=None, + ) + assert cmd_default.ca_cert_path is None + assert cmd_default.client._transport._pool._ssl_context is CLIENT_SSL_CONTEXT # type: ignore[attr-defined] + + # custom CA: creates a distinct SSL context + ca_path = certifi.where() # reuse certifi's bundle as a known-valid path + cmd_custom = APICommander( + api_endpoint="https://example.com", + path="/v1", + spawner=None, + ca_cert_path=ca_path, + ) + assert cmd_custom.ca_cert_path == ca_path + custom_ctx = cmd_custom.client._transport._pool._ssl_context # type: ignore[attr-defined] + assert isinstance(custom_ctx, ssl.SSLContext) + assert custom_ctx is not CLIENT_SSL_CONTEXT + + # _copy preserves ca_cert_path + cmd_copied = cmd_custom._copy(path="/v2") + assert cmd_copied.ca_cert_path == ca_path + @pytest.mark.describe("test of APICommander request, sync") def test_apicommander_request_sync(self, httpserver: HTTPServer) -> None: base_endpoint = httpserver.url_for("/") diff --git a/tests/base/unit/test_apioptions.py b/tests/base/unit/test_apioptions.py index 0ec2c439..e2d1182c 100644 --- a/tests/base/unit/test_apioptions.py +++ b/tests/base/unit/test_apioptions.py @@ -85,3 +85,20 @@ def test_apioptions_eventobservers(self) -> None: assert opts_1 == opts_2 assert opts_3n == opts_4n + + @pytest.mark.describe("test of ca_cert_path inheritance in APIOptions") + def test_apioptions_ca_cert_path(self) -> None: + opts_d = defaultAPIOptions(environment="prod") + assert opts_d.ca_cert_path is None + + # override with a path + opts_1 = opts_d.with_override(APIOptions(ca_cert_path="/some/ca.pem")) + assert opts_1.ca_cert_path == "/some/ca.pem" + + # second override replaces the first + opts_2 = opts_1.with_override(APIOptions(ca_cert_path="/other/ca.pem")) + assert opts_2.ca_cert_path == "/other/ca.pem" + + # unset override (None) does not overwrite + opts_3 = opts_1.with_override(APIOptions()) + assert opts_3.ca_cert_path == "/some/ca.pem"