Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions astrapy/admin/admin.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def fetch_raw_database_info_from_id_token(
dev_ops_api=True,
redacted_header_names=_api_options.redacted_header_names,
event_observers=_api_options.event_observers,
ca_cert_path=_api_options.ca_cert_path,
)

gd_response = dev_ops_commander.request(
Expand Down Expand Up @@ -313,6 +314,7 @@ async def async_fetch_raw_database_info_from_id_token(
dev_ops_api=True,
redacted_header_names=_api_options.redacted_header_names,
event_observers=_api_options.event_observers,
ca_cert_path=_api_options.ca_cert_path,
)

gd_response = await dev_ops_commander.async_request(
Expand Down Expand Up @@ -517,6 +519,7 @@ def _get_dev_ops_api_commander(self) -> APICommander:
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)
return dev_ops_commander

Expand Down Expand Up @@ -547,6 +550,7 @@ def _get_dev_ops_regionlist_api_commander(self) -> APICommander:
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)
return rl_dev_ops_commander

Expand Down Expand Up @@ -2761,6 +2765,7 @@ def _get_api_commander(self) -> APICommander:
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down Expand Up @@ -2789,6 +2794,7 @@ def _get_dev_ops_api_commander(self) -> APICommander:
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)
return dev_ops_commander

Expand Down Expand Up @@ -3976,6 +3982,7 @@ def _get_api_commander(self) -> APICommander:
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down
2 changes: 2 additions & 0 deletions astrapy/data/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,7 @@ def _get_api_commander(self) -> APICommander:
handle_decimals_reads=(
self.api_options.serdes_options.use_decimals_in_collections
),
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down Expand Up @@ -3314,6 +3315,7 @@ def _get_api_commander(self) -> APICommander:
handle_decimals_reads=(
self.api_options.serdes_options.use_decimals_in_collections
),
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down
4 changes: 4 additions & 0 deletions astrapy/data/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,7 @@ def _get_api_commander(self, keyspace: str | None) -> APICommander | None:
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down Expand Up @@ -2216,6 +2217,7 @@ def command(
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)

_cmd_desc = ",".join(sorted(body.keys()))
Expand Down Expand Up @@ -2430,6 +2432,7 @@ def _get_api_commander(self, keyspace: str | None) -> APICommander | None:
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down Expand Up @@ -4500,6 +4503,7 @@ async def command(
redacted_header_names=self.api_options.redacted_header_names,
event_observers=self.api_options.event_observers,
spawner=self,
ca_cert_path=self.api_options.ca_cert_path,
)

_cmd_desc = ",".join(sorted(body.keys()))
Expand Down
2 changes: 2 additions & 0 deletions astrapy/data/table.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,6 +386,7 @@ def _get_api_commander(self) -> APICommander:
spawner=self,
handle_decimals_writes=True,
handle_decimals_reads=True,
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down Expand Up @@ -3437,6 +3438,7 @@ def _get_api_commander(self) -> APICommander:
spawner=self,
handle_decimals_writes=True,
handle_decimals_reads=True,
ca_cert_path=self.api_options.ca_cert_path,
)
return api_commander

Expand Down
16 changes: 12 additions & 4 deletions astrapy/utils/api_commander.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,21 +143,28 @@ def __init__(
event_observers: dict[str, Observer | None] = {},
handle_decimals_writes: bool = False,
handle_decimals_reads: bool = False,
ca_cert_path: str | None = None,
) -> None:
self.ca_cert_path = ca_cert_path
if ca_cert_path is not None:
ssl_context = ssl.create_default_context(cafile=ca_cert_path)
else:
ssl_context = CLIENT_SSL_CONTEXT

ssl_control_headers: dict[str, str | None]
if disable_ssl_reuse:
self.client = httpx.Client(
limits=no_pooling_limits,
verify=CLIENT_SSL_CONTEXT,
verify=ssl_context,
)
self.async_client = httpx.AsyncClient(
limits=no_pooling_limits,
verify=CLIENT_SSL_CONTEXT,
verify=ssl_context,
)
ssl_control_headers = {"Connection": "close"}
else:
self.client = httpx.Client(verify=CLIENT_SSL_CONTEXT)
self.async_client = httpx.AsyncClient(verify=CLIENT_SSL_CONTEXT)
self.client = httpx.Client(verify=ssl_context)
self.async_client = httpx.AsyncClient(verify=ssl_context)
ssl_control_headers = {}

self.api_endpoint = api_endpoint.rstrip("/")
Expand Down Expand Up @@ -290,6 +297,7 @@ def _copy(
else self.redacted_header_names
),
dev_ops_api=dev_ops_api if dev_ops_api is not None else self.dev_ops_api,
ca_cert_path=self.ca_cert_path,
)

def _compose_request_url(self, additional_path: str | None) -> str:
Expand Down
21 changes: 21 additions & 0 deletions astrapy/utils/api_options.py
Original file line number Diff line number Diff line change
Expand Up @@ -945,6 +945,12 @@ class APIOptions:
dev_ops_api_url_options: an instance of `DevOpsAPIURLOptions` (see) to
customize the URL used to reach the DevOps API (customizing this setting
is rarely needed; relevant only for Astra DB environments).
ca_cert_path: optional path to a custom CA certificate file (PEM format)
to use for SSL verification when connecting to the API. This is useful
for PrivateLink setups or other environments where a custom CA is needed.
When set, a per-instance SSL context is created with the given CA file
instead of using the default certifi-based CA bundle. Defaults to None
(use the bundled certifi CA roots).

Examples:
>>> from astrapy import DataAPIClient
Expand Down Expand Up @@ -1022,6 +1028,7 @@ class APIOptions:
embedding_api_key: EmbeddingHeadersProvider | UnsetType = _UNSET
reranking_api_key: RerankingHeadersProvider | UnsetType = _UNSET
event_observers: dict[str, Observer | None] | UnsetType = _UNSET
ca_cert_path: str | None | UnsetType = _UNSET

timeout_options: TimeoutOptions | UnsetType = _UNSET
serdes_options: SerdesOptions | UnsetType = _UNSET
Expand All @@ -1039,6 +1046,7 @@ def __init__(
embedding_api_key: str | EmbeddingHeadersProvider | UnsetType = _UNSET,
reranking_api_key: str | RerankingHeadersProvider | UnsetType = _UNSET,
event_observers: dict[str, Observer | None] | UnsetType = _UNSET,
ca_cert_path: str | None | UnsetType = _UNSET,
timeout_options: TimeoutOptions | UnsetType = _UNSET,
serdes_options: SerdesOptions | UnsetType = _UNSET,
data_api_url_options: DataAPIURLOptions | UnsetType = _UNSET,
Expand All @@ -1062,6 +1070,7 @@ def __init__(
reranking_api_key,
)
self.event_observers = event_observers
self.ca_cert_path = ca_cert_path
self.timeout_options = timeout_options
self.serdes_options = serdes_options
self.data_api_url_options = data_api_url_options
Expand Down Expand Up @@ -1133,6 +1142,9 @@ def __repr__(self) -> str:
None
if isinstance(self.dev_ops_api_url_options, UnsetType)
else f"dev_ops_api_url_options={self.dev_ops_api_url_options}",
None
if isinstance(self.ca_cert_path, UnsetType)
else f"ca_cert_path={self.ca_cert_path}",
)
if pc is not None
]
Expand Down Expand Up @@ -1219,6 +1231,7 @@ class FullAPIOptions(APIOptions):
embedding_api_key: EmbeddingHeadersProvider
reranking_api_key: RerankingHeadersProvider
event_observers: dict[str, Observer | None]
ca_cert_path: str | None

timeout_options: FullTimeoutOptions
serdes_options: FullSerdesOptions
Expand All @@ -1237,6 +1250,7 @@ def __init__(
embedding_api_key: str | EmbeddingHeadersProvider,
reranking_api_key: str | RerankingHeadersProvider,
event_observers: dict[str, Observer | None],
ca_cert_path: str | None,
timeout_options: FullTimeoutOptions,
serdes_options: FullSerdesOptions,
data_api_url_options: FullDataAPIURLOptions,
Expand All @@ -1252,6 +1266,7 @@ def __init__(
embedding_api_key=embedding_api_key,
reranking_api_key=reranking_api_key,
event_observers=event_observers,
ca_cert_path=ca_cert_path,
timeout_options=timeout_options,
serdes_options=serdes_options,
data_api_url_options=data_api_url_options,
Expand Down Expand Up @@ -1392,6 +1407,11 @@ def with_override(self, other: APIOptions | None | UnsetType) -> FullAPIOptions:
else self.reranking_api_key
),
event_observers=event_observers,
ca_cert_path=(
other.ca_cert_path
if not isinstance(other.ca_cert_path, UnsetType)
else self.ca_cert_path
),
timeout_options=timeout_options,
serdes_options=serdes_options,
data_api_url_options=data_api_url_options,
Expand Down Expand Up @@ -1451,6 +1471,7 @@ def defaultAPIOptions(environment: str) -> FullAPIOptions:
embedding_api_key=EmbeddingAPIKeyHeaderProvider(None),
reranking_api_key=RerankingAPIKeyHeaderProvider(None),
event_observers={},
ca_cert_path=None,
timeout_options=defaultTimeoutOptions,
serdes_options=defaultSerdesOptions,
data_api_url_options=defaultDataAPIURLOptions,
Expand Down
34 changes: 34 additions & 0 deletions tests/base/unit/test_apicommander.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,40 @@ def test_apicommander_conversions(self) -> None:
)
assert cmd1 == cmd1._copy(dev_ops_api=False)._copy(dev_ops_api=True)

@pytest.mark.describe("test of APICommander ca_cert_path SSL context")
def test_apicommander_ca_cert_path(self) -> None:
import ssl

import certifi

from astrapy.utils.api_commander import CLIENT_SSL_CONTEXT

# default: uses the shared global SSL context
cmd_default = APICommander(
api_endpoint="https://example.com",
path="/v1",
spawner=None,
)
assert cmd_default.ca_cert_path is None
assert cmd_default.client._transport._pool._ssl_context is CLIENT_SSL_CONTEXT # type: ignore[attr-defined]

# custom CA: creates a distinct SSL context
ca_path = certifi.where() # reuse certifi's bundle as a known-valid path
cmd_custom = APICommander(
api_endpoint="https://example.com",
path="/v1",
spawner=None,
ca_cert_path=ca_path,
)
assert cmd_custom.ca_cert_path == ca_path
custom_ctx = cmd_custom.client._transport._pool._ssl_context # type: ignore[attr-defined]
assert isinstance(custom_ctx, ssl.SSLContext)
assert custom_ctx is not CLIENT_SSL_CONTEXT

# _copy preserves ca_cert_path
cmd_copied = cmd_custom._copy(path="/v2")
assert cmd_copied.ca_cert_path == ca_path

@pytest.mark.describe("test of APICommander request, sync")
def test_apicommander_request_sync(self, httpserver: HTTPServer) -> None:
base_endpoint = httpserver.url_for("/")
Expand Down
17 changes: 17 additions & 0 deletions tests/base/unit/test_apioptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,3 +85,20 @@ def test_apioptions_eventobservers(self) -> None:

assert opts_1 == opts_2
assert opts_3n == opts_4n

@pytest.mark.describe("test of ca_cert_path inheritance in APIOptions")
def test_apioptions_ca_cert_path(self) -> None:
opts_d = defaultAPIOptions(environment="prod")
assert opts_d.ca_cert_path is None

# override with a path
opts_1 = opts_d.with_override(APIOptions(ca_cert_path="/some/ca.pem"))
assert opts_1.ca_cert_path == "/some/ca.pem"

# second override replaces the first
opts_2 = opts_1.with_override(APIOptions(ca_cert_path="/other/ca.pem"))
assert opts_2.ca_cert_path == "/other/ca.pem"

# unset override (None) does not overwrite
opts_3 = opts_1.with_override(APIOptions())
assert opts_3.ca_cert_path == "/some/ca.pem"
Loading