diff --git a/docs/design/step_ca_ephemeral_admin_certs.md b/docs/design/step_ca_ephemeral_admin_certs.md new file mode 100644 index 0000000000..33a0ca2732 --- /dev/null +++ b/docs/design/step_ca_ephemeral_admin_certs.md @@ -0,0 +1,154 @@ +# Ephemeral Admin Certificates with step-ca + +## Goal + +Allow an admin to authenticate with OIDC and receive a short-lived FLARE admin +certificate without adding OIDC handling to the FLARE server or clients. + +```text +admin CLI -> step CLI -> step-ca -> OIDC provider +admin CLI <- short-lived admin certificate/key +admin CLI -> existing FLARE mTLS login and job signing +``` + +The built-in `step_ca` provider delegates OIDC discovery, browser login, token +validation, claim mapping, and certificate issuance to step-ca. + +## Trust Model + +step-ca signs admin certificates with an intermediate CA rooted in the FLARE +project root. Existing servers and clients validate the resulting chain with +`rootCA.pem`. The FLARE server cannot mint an admin certificate or job +signature unless it controls step-ca, its signing key, or an admin private key. + +The issued leaf certificate must contain the fields FLARE already consumes: + +- `commonName`: authenticated admin identity +- `organizationName`: FLARE organization +- `unstructuredName`: `project_admin`, `org_admin`, `lead`, or `member` + +The admin private key is generated on the admin machine by `step ca +certificate` and is not sent to the OIDC provider or FLARE server. + +## Runtime Behavior + +An ephemeral admin startup kit contains `ephemeral_admin_cert` instead of +static `client.crt` and `client.key` files. The admin client: + +1. Loads a valid cached credential or invokes its configured provider. +2. Validates the certificate chain, validity, identity fields, allowed role, + and certificate/private-key match. +3. Uses the existing certificate challenge, mTLS, authorization, and job-signing + paths. +4. Reacquires credentials when the certificate enters its renewal window. 
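
Steps 1 and 4 above can be sketched as follows. This is a minimal illustration, not the implementation: `CachedCredential` and `select_credential` are hypothetical names, and the renewal check mirrors the `needs_renewal` logic that this change adds in `nvflare/fuel/sec/ephemeral_admin_cert.py`. It assumes `renewal_window` is expressed in seconds, since the check compares against `time.time()`.

```python
import time
from dataclasses import dataclass
from typing import Callable, Optional


@dataclass
class CachedCredential:
    """Hypothetical stand-in for the credential record kept by the admin client."""

    client_cert: str
    client_key: str
    expires_at: float = 0.0  # POSIX timestamp of the certificate's notAfter

    def needs_renewal(self, renewal_window: float = 60.0, now: Optional[float] = None) -> bool:
        # An unknown expiry is treated as "renew now".
        if not self.expires_at:
            return True
        now = time.time() if now is None else now
        # Renew once the remaining lifetime is inside the renewal window.
        return self.expires_at - now <= max(0.0, renewal_window)


def select_credential(
    cached: Optional[CachedCredential],
    renewal_window: float,
    acquire: Callable[[], CachedCredential],
    now: Optional[float] = None,
) -> CachedCredential:
    """Return the cached credential if it is still usable, else acquire a new one."""
    if cached is not None and not cached.needs_renewal(renewal_window, now):
        return cached
    # `acquire` stands for the configured provider call (assumption for this sketch).
    return acquire()
```

Passing `now` explicitly keeps the check deterministic in tests; in normal use it defaults to the current time.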
+ +The certificate-chain support is implemented separately. This feature adds no +OIDC token format, server login mode, or server-signed job manifest to FLARE. + +## Provisioning + +Static and ephemeral admins can coexist in one `project.yml`: + +```yaml +participants: + - name: static-admin@example.com + type: admin + org: example_org + role: project_admin + + - name: sso-admin-kit + type: admin + ephemeral_admin_cert: + provider: step_ca + renewal_window: 60 + provider_config: + ca_url: https://step-ca.example.com + provisioner: nvflare-admin-oidc + cert_ttl: 24h + command_timeout: 300 +``` + +The ephemeral participant omits `org` and `role`; both come from the issued +certificate. Its name identifies a generic startup kit, not a user. The kit +contains `rootCA.pem` and provider configuration but no static admin +certificate or private key. Server and site startup kits are unchanged. + +## Provider and Cache + +`ephemeral_admin_cert.provider` is a built-in provider name or a +`module:function` path. A provider receives its configuration and the project +root certificate, and returns local certificate/key paths. FLARE applies the +same validation to built-in and custom provider results. + +Valid credentials are cached per OS user under +`~/.nvflare/ephemeral_admin_certs`. The cache entry is bound to the provider +configuration and project root. Files are private to the OS user, concurrent +CLI processes serialize acquisition, and immutable credential directories +prevent cert/key replacement races. + +The cache is required because each `nvflare` command starts a new process; +without it every command would repeat browser login. Users must not share an OS +account because that also shares its cached credential. Deleting the cache +forces a fresh OIDC login. 
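
The binding between a cache entry, the provider configuration, and the project root can be illustrated with a short sketch. It follows the shape of the `_cache_key` helper added in this change, with one simplification that is an assumption of this example: the root certificate is passed as bytes (`root_ca_pem`) instead of being read from a file path.

```python
import hashlib
import json
from typing import Mapping


def cache_key(provider: str, provider_config: Mapping, root_ca_pem: bytes) -> str:
    """Derive a stable cache directory name from provider settings and the project root."""
    material = {
        "version": 1,
        "root_ca_sha256": hashlib.sha256(root_ca_pem).hexdigest(),
        "provider": provider,
        "provider_config": dict(provider_config),
    }
    # sort_keys and compact separators make the encoding independent of key order.
    encoded = json.dumps(material, sort_keys=True, separators=(",", ":")).encode("utf-8")
    return hashlib.sha256(encoded).hexdigest()
```

Because the key covers both the configuration and the root certificate hash, a change to either one selects a different cache directory, so a credential issued for one project or provisioner is not reused for another.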
+ +## step-ca Requirements + +Operators configure step-ca, not FLARE, with: + +- an intermediate CA signed by the FLARE project root +- an OIDC provisioner and matching IdP loopback redirect URI +- the OIDC client credentials and scopes needed by the X.509 template +- an X.509 template that writes the required FLARE identity fields +- a short maximum/default certificate duration, normally 24 hours +- renewal disabled so extending access requires another OIDC login + +The root CA private key returns to offline storage after signing the +intermediate. The intermediate key remains with step-ca and may be protected by +an HSM/KMS. Neither private key is distributed to FLARE servers, sites, or +admins. + +### Organization and Role Mapping + +The step-ca template must map an exact, allowlisted IdP role to one +`(organization, FLARE role)` pair. Organization and role must not be accepted as +independent user-controlled claims. + +Example mappings for one project and organization: + +```text +nvflare-demo-example-project_admin -> (example, project_admin) +nvflare-demo-example-org_admin -> (example, org_admin) +nvflare-demo-example-lead -> (example, lead) +nvflare-demo-example-member -> (example, member) +``` + +When several mapped roles for the same organization are present, the template +selects the highest privilege in this order: + +```text +project_admin > org_admin > lead > member +``` + +Mappings that produce more than one organization are ambiguous and must fail +closed. Separate provisioners/templates per organization are the simplest +deployment model. The FLARE server does not map or rewrite certificate +organization or role values. + +## Lifetime and Clone Behavior + +The built-in provider requests a 24-hour certificate by default. FLARE does not +perform revocation checks, so disabling a user prevents new issuance but does +not invalidate an existing certificate. 
The lifetime must cover expected queue +and deployment delays because clients verify that the signing certificate is +still valid when the job is deployed. + +`clone_job` copies the original submitter signature without contacting the +admin client. A clone therefore becomes unusable after the original certificate +expires. FLARE reports this before cloning when it can inspect the stored +certificate. A future client-assisted clone could re-sign and resubmit the job. + +## References + +- [step-ca](https://smallstep.com/docs/step-ca/) +- [`step ca certificate`](https://smallstep.com/docs/step-cli/reference/ca/certificate/) +- [step-ca provisioners](https://smallstep.com/docs/step-ca/provisioners/) diff --git a/docs/programming_guide/provisioning_system.rst b/docs/programming_guide/provisioning_system.rst index f96b060ec1..0d344a49b0 100644 --- a/docs/programming_guide/provisioning_system.rst +++ b/docs/programming_guide/provisioning_system.rst @@ -452,7 +452,8 @@ Edit the project.yml configuration file to meet your project requirements: - "api_version" should be set to 3 or 4. Version 4 adds support for multi-study configuration (see :ref:`multi_study_guide`) - "name" is used to identify this project. - "participants" describes the different parties in the FL system, distinguished by type. For all participants, "name" - should be unique, and "org" should be defined in AuthPolicyBuilder. The "name" of the server should + should be unique. ``org`` is required except for ephemeral admin kit entries, whose organization comes from the + issued certificate. The "name" of the server should be in the format of a fully qualified domain name. 
It is possible to use a unique hostname rather than FQDN, with the IP mapped to the hostname by having it added to ``/etc/hosts``: @@ -461,10 +462,78 @@ Edit the project.yml configuration file to meet your project requirements: - "fed_learn_port" is the port number for communication between the FL server and FL clients - "admin_port" is the port number for communication between the FL server and FL administration client - Type "client" describes the FL clients, with one "org" and "name" for each client as well as "enable_byoc" settings. - - Type "admin" describes the admin clients with the name being a unique email. The role must be one of "project_admin", "org_admin", "lead" and "member". + - Type "admin" describes the admin clients. For traditional static + admin certificates, the name must be a unique email. For ephemeral + admin certificate startup kits, the name may be a unique kit name + such as ``sso-admin-kit`` because the real admin identity comes from + the short-lived certificate issued after login. Static admins must + define ``org`` and a role of "project_admin", "org_admin", "lead" or + "member". Ephemeral admin kit entries omit ``org`` and ``role``; + those values come from the issued certificate. - "builders" contains all of the builders and the args to be passed into each. See the details in docstrings of the :ref:`bundled_builders`. - "studies" (optional, requires ``api_version: 4``): defines named studies with per-study site enrollment and admin role mappings. See :ref:`multi_study_guide` for the full schema and examples. +Ephemeral admin certificate configuration +========================================= + +Use ephemeral admin certificates when admin users should authenticate through an +external certificate provider instead of receiving long-lived private keys in +their startup kits. Server and client startup kits are unchanged. + +The built-in ``step_ca`` provider delegates OIDC login and short-lived +certificate issuance to step-ca. 
FLARE only stores the provider configuration in +the generated admin startup kit and then validates the returned certificate +before using it. The admin machine must have the ``step`` CLI installed. + +Example configuration: + +.. code-block:: yaml + + participants: + - name: static-admin@example.com + type: admin + org: nvidia + role: project_admin + + - name: sso-admin-kit + type: admin + ephemeral_admin_cert: + provider: step_ca + renewal_window: 60 + provider_config: + ca_url: https://step-ca.example.com + provisioner: nvflare-admin-oidc + cert_ttl: 24h + command_timeout: 300 + +Only admin participants with ``ephemeral_admin_cert`` receive SSO-backed startup +kits. The generated ``sso-admin-kit`` startup kit contains +``ephemeral_admin_cert`` in ``fed_admin.json`` and omits static admin +``client.crt`` and ``client.key``. Traditional admin participants still receive +static admin certificate material. The admin client invokes the configured +provider when the cached certificate is missing, invalid, expired, or close to +expiry. Cached certificate material is stored under +``~/.nvflare/ephemeral_admin_certs`` and can be removed manually if a fresh SSO +login is required. The returned certificate must chain to ``rootCA.pem``, match +its private key, contain a valid FLARE organization and admin role, and be valid +for the current time. If ``cert_ttl`` is omitted, the built-in ``step_ca`` +provider requests ``24h``. + +The certificate provider must map authenticated IdP claims to one allowed +``(organization, role)`` pair. For example, an IdP role such as +``nvflare-demo-example-project_admin`` can map to organization ``example`` and +role ``project_admin``. Perform this mapping in the step-ca X.509 template (or +in the IdP), using exact allowlisted values. Do not accept organization and role +as independent user-controlled claims. 
If several allowed roles for the same +organization are present, select the highest privilege; fail closed if the +organization is ambiguous. + +Custom certificate providers can be configured with +``provider: module:function``. The function receives ``provider_config`` and +``root_ca_file`` and returns paths for the admin certificate and key. +FLARE performs the same certificate validation for custom providers as it does +for ``step_ca``. + .. _project_yml: Default project.yml file diff --git a/docs/user_guide/admin_guide/deployment/overview.rst b/docs/user_guide/admin_guide/deployment/overview.rst index f022cb7c74..96f7540fa6 100644 --- a/docs/user_guide/admin_guide/deployment/overview.rst +++ b/docs/user_guide/admin_guide/deployment/overview.rst @@ -260,6 +260,14 @@ you will need to modify the corresponding script. The same applies to the other The email to participate this FL project is embedded in the CN field of client certificate, which uniquely identifies the participant. As such, please safeguard its private key, client.key. +Some projects use ephemeral admin certificates. In that case, the admin startup +kit contains ``ephemeral_admin_cert`` in ``fed_admin.json`` instead of static +``client.crt`` and ``client.key`` files. The admin client obtains a short-lived +admin certificate and private key from the configured provider when connecting +to the server, then uses the same certificate login and job-signing flow as a +static admin kit. The startup kit name can be a generic name such as +``sso-admin-kit``; the issued certificate contains the real admin identity. + .. 
attention:: You will need write access in the directory containing the "startup" folder because the "transfer" directory for diff --git a/docs/user_guide/admin_guide/security/identity_security.rst b/docs/user_guide/admin_guide/security/identity_security.rst index f10180339d..9d8ef10ca1 100644 --- a/docs/user_guide/admin_guide/security/identity_security.rst +++ b/docs/user_guide/admin_guide/security/identity_security.rst @@ -40,6 +40,38 @@ The security of the system comes from the PKI credentials in the Startup Kits. A :ref:`NVFlare Dashboard ` is a website that supports user and site registration. Users will be able to download their Startup Kits (and other artifacts) from the website. +Ephemeral Admin Certificates +---------------------------- +For admin users, NVFLARE can also provision startup kits that do not contain a +static admin certificate or private key. In this mode, the admin startup kit +contains an ``ephemeral_admin_cert`` provider configuration. When the admin +client starts, it asks that provider for a short-lived admin certificate and +private key, validates the returned certificate against the project +``rootCA.pem``, and then uses the normal certificate login and job-signing path. +Valid ephemeral admin cert/key material is cached under +``~/.nvflare/ephemeral_admin_certs`` so repeated CLI commands do not require a +new browser login until the certificate is invalid, expired, or close to +expiry. The cache is private to the OS user, so administrators should not share +an OS account. The startup kit can use a generic name such as +``sso-admin-kit``; the actual admin identity comes from the certificate issued +after SSO login. + +The built-in provider is ``step_ca``. With this provider, step-ca owns OIDC +login, role claim handling, and certificate issuance. 
The issued certificate +must contain the same FLARE identity fields that the existing PKI path consumes: +``commonName`` for the admin identity, ``organizationName`` for the FLARE org, +and ``unstructuredName`` for one FLARE role: ``project_admin``, ``org_admin``, +``lead``, or ``member``. + +The step-ca template must map an exact, allowlisted IdP role to both the FLARE +organization and role. This binds the authorization tuple before the +certificate reaches the FLARE server; the server does not derive or rewrite +either value. + +This mode reduces the distribution risk of long-lived admin private keys while +preserving the existing server and client trust model. Server and FL client +startup kits still use their normal PKI credentials. + .. _federated_authorization: diff --git a/nvflare/apis/job_def.py b/nvflare/apis/job_def.py index 880c49b999..2414704b28 100644 --- a/nvflare/apis/job_def.py +++ b/nvflare/apis/job_def.py @@ -52,6 +52,7 @@ class JobMetaKey(str, Enum): SUBMITTER_NAME = "submitter_name" SUBMITTER_ORG = "submitter_org" SUBMITTER_ROLE = "submitter_role" + SUBMITTER_CERT_VALIDITY = "submitter_cert_validity" STATUS = "status" DATA_STORAGE_FORMAT = "data_storage_format" DEPLOY_MAP = "deploy_map" diff --git a/nvflare/apis/utils/format_check.py b/nvflare/apis/utils/format_check.py index e2f8afa0ae..dbb06ed3a7 100644 --- a/nvflare/apis/utils/format_check.py +++ b/nvflare/apis/utils/format_check.py @@ -23,6 +23,7 @@ "job_name": r"^[A-Za-z0-9][A-Za-z0-9._-]*$", "relay": r"^[A-Za-z0-9-_]+$", "admin": r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}$", + "admin_kit": r"^[A-Za-z0-9-_]+$", "email": r"^[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,}$", "org": r"^[A-Za-z0-9_]+$", "simple_name": r"^[A-Za-z0-9_]+$", diff --git a/nvflare/fuel/hci/client/api.py b/nvflare/fuel/hci/client/api.py index 350075acb9..d904190cc8 100644 --- a/nvflare/fuel/hci/client/api.py +++ b/nvflare/fuel/hci/client/api.py @@ -58,6 +58,11 @@ from nvflare.fuel.hci.reg import CommandEntry, 
CommandModule, CommandRegister from nvflare.fuel.hci.table import Table from nvflare.fuel.sec.authn import set_add_auth_headers_filters +from nvflare.fuel.sec.ephemeral_admin_cert import ( + get_ephemeral_admin_cert_renewal_window, + obtain_ephemeral_admin_cert_files, + validate_ephemeral_admin_cert_config, +) from nvflare.fuel.utils.admin_name_utils import new_admin_client_name from nvflare.fuel.utils.log_utils import get_obj_logger from nvflare.private.aux_runner import AuxMsgTarget, AuxRunner @@ -297,6 +302,18 @@ def __init__( self.ca_cert = admin_config.get(AdminConfigKey.CA_CERT) self.client_cert = admin_config.get(AdminConfigKey.CLIENT_CERT) self.client_key = admin_config.get(AdminConfigKey.CLIENT_KEY) + self.ephemeral_admin_cert_files = None + self.ephemeral_admin_cert_config = admin_config.get(AdminConfigKey.EPHEMERAL_ADMIN_CERT) + try: + if self.ephemeral_admin_cert_config: + self.ephemeral_admin_cert_config = validate_ephemeral_admin_cert_config( + self.ephemeral_admin_cert_config + ) + self.ephemeral_admin_cert_renewal_window = get_ephemeral_admin_cert_renewal_window( + self.ephemeral_admin_cert_config or {} + ) + except ValueError as ex: + raise ConfigError(str(ex)) from ex self.uid_source = admin_config.get(AdminConfigKey.UID_SOURCE, UidSource.USER_INPUT) self.host = admin_config.get(AdminConfigKey.HOST, "localhost") self.port = admin_config.get(AdminConfigKey.PORT, 8002) @@ -309,6 +326,12 @@ def __init__( if not self.ca_cert: raise ConfigError("missing CA Cert file name") + if self.ephemeral_admin_cert_config: + if self.client_cert or self.client_key: + raise ConfigError( + "client_cert and client_key must both be omitted when ephemeral_admin_cert is configured" + ) + self.ensure_client_cert_valid() if not self.client_cert: raise ConfigError("missing Client Cert file name") if not self.client_key: @@ -367,6 +390,50 @@ def __init__( ) self.file_download_waiters = {} # tx_id => Threading.Event + def ensure_client_cert_valid(self): + if not getattr(self, 
"ephemeral_admin_cert_config", None): + return False + if self.ephemeral_admin_cert_files and not self.ephemeral_admin_cert_files.needs_renewal( + renewal_window=self.ephemeral_admin_cert_renewal_window + ): + return False + + renewing = self.ephemeral_admin_cert_files is not None + try: + new_files = obtain_ephemeral_admin_cert_files( + config=self.ephemeral_admin_cert_config, + root_ca_file=self.ca_cert, + ) + except Exception as ex: + raise ConfigError(f"failed to obtain ephemeral admin certificate: {secure_format_exception(ex)}") from ex + + self.ephemeral_admin_cert_files = new_files + self.client_cert = new_files.client_cert + self.client_key = new_files.client_key + if self.uid_source == UidSource.CERT: + cert = load_cert_file(self.client_cert) + self.user_name = get_cn_from_cert(cert) + if renewing: + self.fl_ctx_mgr.identity_name = self.user_name + if renewing: + self._reset_cell() + return True + + def _reset_cell(self): + self.server_sess_active = False + self.token = None + self.login_result = None + try: + self.shutdown_streamer() + finally: + try: + if self.cell: + self.cell.stop() + finally: + self.cell = None + self.aux_runner = None + self.object_streamer = None + def new_context(self): return self.fl_ctx_mgr.new_context() @@ -385,6 +452,7 @@ def connect(self, timeout=None): self._print_hci("Connecting to FLARE ...") if self.cell: return + self.ensure_client_cert_valid() my_fqcn = new_admin_client_name() credentials = { @@ -405,71 +473,75 @@ def connect(self, timeout=None): self.debug(f"Creating cell: {my_fqcn=} {root_url=} {secure_conn=} {credentials=}") - self.cell = Cell( - fqcn=my_fqcn, - root_url=root_url, - secure=secure_conn, - credentials=credentials, - create_internal_listener=False, - parent_url=None, - auth_identity_map={FQCN.ROOT_SERVER: self.server_identity}, - ) - - self.cell.register_request_cb( - channel=CellChannel.HCI, - topic="SESSION_EXPIRED", - cb=self._handle_session_expired, - ) + try: + self.cell = Cell( + fqcn=my_fqcn, + 
root_url=root_url, + secure=secure_conn, + credentials=credentials, + create_internal_listener=False, + parent_url=None, + auth_identity_map={FQCN.ROOT_SERVER: self.server_identity}, + ) - NetAgent(self.cell) - self.cell.start() + self.cell.register_request_cb( + channel=CellChannel.HCI, + topic="SESSION_EXPIRED", + cb=self._handle_session_expired, + ) - # authenticate - authenticator = Authenticator( - cell=self.cell, - project_name=self.project_name, - client_name=self.user_name, - client_type=ClientType.ADMIN, - expected_sp_identity=self.server_identity, - secure_mode=True, # always True to authenticate the cell endpoint! - root_cert_file=self.ca_cert, - private_key_file=self.client_key, - cert_file=self.client_cert, - msg_timeout=self.authenticate_msg_timeout, - retry_interval=1.0, - timeout=timeout, - ) + NetAgent(self.cell) + self.cell.start() + + # authenticate + authenticator = Authenticator( + cell=self.cell, + project_name=self.project_name, + client_name=self.user_name, + client_type=ClientType.ADMIN, + expected_sp_identity=self.server_identity, + secure_mode=True, # always True to authenticate the cell endpoint! 
+ root_cert_file=self.ca_cert, + private_key_file=self.client_key, + cert_file=self.client_cert, + msg_timeout=self.authenticate_msg_timeout, + retry_interval=1.0, + timeout=timeout, + ) - abort_signal = Signal() - shared_fl_ctx = FLContext() - shared_fl_ctx.set_public_props({ReservedKey.IDENTITY_NAME: self.user_name}) - token, token_signature, ssid, token_verifier = authenticator.authenticate( - shared_fl_ctx=shared_fl_ctx, - abort_signal=abort_signal, - ) + abort_signal = Signal() + shared_fl_ctx = FLContext() + shared_fl_ctx.set_public_props({ReservedKey.IDENTITY_NAME: self.user_name}) + token, token_signature, ssid, token_verifier = authenticator.authenticate( + shared_fl_ctx=shared_fl_ctx, + abort_signal=abort_signal, + ) - if not isinstance(token_verifier, TokenVerifier): - raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}") + if not isinstance(token_verifier, TokenVerifier): + raise RuntimeError(f"expect token_verifier to be TokenVerifier but got {type(token_verifier)}") - set_add_auth_headers_filters(self.cell, self.user_name, token, token_signature, ssid) + set_add_auth_headers_filters(self.cell, self.user_name, token, token_signature, ssid) - self.cell.core_cell.add_incoming_filter( - channel="*", - topic="*", - cb=validate_auth_headers, - token_verifier=token_verifier, - logger=self.logger, - ) - self.debug(f"Successfully authenticated to {self.server_identity}: {token=} {ssid=}") + self.cell.core_cell.add_incoming_filter( + channel="*", + topic="*", + cb=validate_auth_headers, + token_verifier=token_verifier, + logger=self.logger, + ) + self.debug(f"Successfully authenticated to {self.server_identity}: {token=} {ssid=}") - self.aux_runner = AuxRunner(self) - self.object_streamer = ObjectStreamer(self.aux_runner) + self.aux_runner = AuxRunner(self) + self.object_streamer = ObjectStreamer(self.aux_runner) - self.cell.register_request_cb( - channel=CellChannel.AUX_COMMUNICATION, - topic="*", - 
cb=self._handle_aux_message, - ) + self.cell.register_request_cb( + channel=CellChannel.AUX_COMMUNICATION, + topic="*", + cb=self._handle_aux_message, + ) + except Exception: + self._reset_cell() + raise def _handle_aux_message(self, request: CellMessage) -> CellMessage: assert isinstance(request, CellMessage), "request must be CellMessage but got {}".format(type(request)) @@ -671,8 +743,10 @@ def _user_login(self): Returns: A dict of login status and details """ + self.ensure_client_cert_valid() + if not self.cell: + self.connect() command = f"{InternalCommands.CERT_LOGIN} {self.user_name}" - id_asserter = IdentityAsserter(private_key_file=self.client_key, cert_file=self.client_cert) cn_signature = id_asserter.sign_common_name(nonce="") diff --git a/nvflare/fuel/hci/client/api_spec.py b/nvflare/fuel/hci/client/api_spec.py index 9fc3fab5c1..9fa6dad401 100644 --- a/nvflare/fuel/hci/client/api_spec.py +++ b/nvflare/fuel/hci/client/api_spec.py @@ -218,6 +218,7 @@ class AdminConfigKey: CONNECTION_SECURITY = "connection_security" CLIENT_KEY = "client_key" CLIENT_CERT = "client_cert" + EPHEMERAL_ADMIN_CERT = "ephemeral_admin_cert" CA_CERT = "ca_cert" HOST = "host" PORT = "port" diff --git a/nvflare/fuel/hci/client/file_transfer.py b/nvflare/fuel/hci/client/file_transfer.py index 8100802522..2cdf65fbee 100644 --- a/nvflare/fuel/hci/client/file_transfer.py +++ b/nvflare/fuel/hci/client/file_transfer.py @@ -309,6 +309,23 @@ def push_folder(self, args, ctx: CommandContext): # sign folders and files (skip gracefully when key is absent — e.g. 
simulator) api = ctx.get_api() + ensure_client_cert_valid = getattr(api, "ensure_client_cert_valid", None) + if ensure_client_cert_valid: + try: + ensure_client_cert_valid() + except Exception as e: + return {"status": APIStatus.ERROR_RUNTIME, "details": f"Failed to refresh admin certificate: {e}"} + if getattr(api, "cell", None) is None: + try: + api.connect() + login_result = api.login() + except Exception as e: + return { + "status": APIStatus.ERROR_RUNTIME, + "details": f"Failed to reconnect with refreshed admin certificate: {e}", + } + if login_result.get("status") != APIStatus.SUCCESS: + return login_result client_key_file_path = api.client_key if client_key_file_path and os.path.exists(client_key_file_path) and api.client_cert: try: @@ -325,6 +342,8 @@ def push_folder(self, args, ctx: CommandContext): folder_name = split_path(full_path)[1] parts = [cmd_entry.full_command_name(), folder_name] + if getattr(api, "ephemeral_admin_cert_config", None): + parts.append("--ephemeral-admin-cert") parts.extend(submit_args) command = join_args(parts) sender = _FileSender(out_file) diff --git a/nvflare/fuel/sec/ephemeral_admin_cert.py b/nvflare/fuel/sec/ephemeral_admin_cert.py new file mode 100644 index 0000000000..3ad9d51027 --- /dev/null +++ b/nvflare/fuel/sec/ephemeral_admin_cert.py @@ -0,0 +1,321 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import datetime +import hashlib +import importlib +import json +import os +import shutil +import tempfile +import time +from contextlib import contextmanager +from dataclasses import dataclass, field +from pathlib import Path +from typing import Mapping, Optional + +from cryptography.hazmat.primitives import serialization + +from nvflare.fuel.sec.admin_cert import validate_admin_leaf_cert +from nvflare.lighter.utils import load_crt, load_crt_chain, load_private_key_file, verify_cert_chain + +EPHEMERAL_ADMIN_CERT_PROVIDER_CONFIG_KEY = "provider_config" +EPHEMERAL_ADMIN_CERT_PROVIDER_KEY = "provider" +EPHEMERAL_ADMIN_CERT_CACHE_DIR = "ephemeral_admin_certs" +EPHEMERAL_ADMIN_CERT_CLIENT_CERT = "client.crt" +EPHEMERAL_ADMIN_CERT_CLIENT_KEY = "client.key" +EPHEMERAL_ADMIN_CERT_CACHE_LOCK = ".lock" +BUILTIN_EPHEMERAL_ADMIN_CERT_PROVIDERS = { + "step_ca": "nvflare.fuel.sec.step_ca_admin_cert:obtain_step_ca_admin_cert_files", +} + + +class EphemeralAdminCertError(ValueError): + """Raised when short-lived admin certificate acquisition fails.""" + + +@dataclass +class EphemeralAdminCertFiles: + client_key: str + client_cert: str + expires_at: float = 0.0 + temp_dir: Optional[tempfile.TemporaryDirectory] = field(default=None, repr=False) + + def needs_renewal(self, renewal_window: float = 60.0, now: Optional[float] = None) -> bool: + if not self.expires_at: + return True + now = time.time() if now is None else now + return self.expires_at - now <= max(0.0, renewal_window) + + def cleanup(self): + if self.temp_dir: + self.temp_dir.cleanup() + self.temp_dir = None + + +def obtain_ephemeral_admin_cert_files(config: Mapping, root_ca_file: str) -> EphemeralAdminCertFiles: + config = validate_ephemeral_admin_cert_config(config) + if not root_ca_file: + raise EphemeralAdminCertError("root_ca_file is required") + + provider = config.get(EPHEMERAL_ADMIN_CERT_PROVIDER_KEY) + provider_name = provider + provider_config = config[EPHEMERAL_ADMIN_CERT_PROVIDER_CONFIG_KEY] + 
renewal_window = get_ephemeral_admin_cert_renewal_window(config) + cache_dir = _cache_base_dir() / _cache_key( + provider=provider_name, provider_config=provider_config, root_ca_file=root_ca_file + ) + with _cache_lock(cache_dir): + cached_files = _load_cached_ephemeral_admin_cert_files(cache_dir, root_ca_file, renewal_window) + if cached_files: + return cached_files + + provider_func = _load_provider(provider_name) + files = provider_func(config=provider_config, root_ca_file=root_ca_file) + if not isinstance(files, EphemeralAdminCertFiles): + raise EphemeralAdminCertError(f"ephemeral admin cert provider returned {type(files)}") + + try: + cert = validate_ephemeral_admin_cert_files(files.client_cert, files.client_key, root_ca_file) + except Exception: + files.cleanup() + raise + files.expires_at = cert_time(cert, "not_valid_after").timestamp() + return _store_ephemeral_admin_cert_files(files, cache_dir, root_ca_file) + + +def validate_ephemeral_admin_cert_files( + cert_path: str, + key_path: str, + root_ca_file: str, +): + cert_chain = load_crt_chain(cert_path) + cert = cert_chain[0] + root_ca_cert = load_crt(root_ca_file) + try: + verify_cert_chain(leaf_cert=cert, intermediate_certs=cert_chain[1:], root_ca_cert=root_ca_cert) + except Exception as ex: + raise EphemeralAdminCertError("ephemeral admin certificate does not chain to the configured CA") from ex + + private_key = load_private_key_file(key_path) + if _public_key_pem(cert.public_key()) != _public_key_pem(private_key.public_key()): + raise EphemeralAdminCertError("ephemeral admin certificate is for a different private key") + + try: + validate_admin_leaf_cert(cert) + except Exception as ex: + raise EphemeralAdminCertError(str(ex)) from ex + return cert + + +def cert_time(cert, field_name: str) -> datetime.datetime: + value = getattr(cert, f"{field_name}_utc", None) + if value is not None: + return value + return getattr(cert, field_name).replace(tzinfo=datetime.timezone.utc) + + +def 
_load_cached_ephemeral_admin_cert_files( + cache_dir: Path, + root_ca_file: str, + renewal_window: float, +) -> Optional[EphemeralAdminCertFiles]: + issuance_dirs = sorted( + (path for path in cache_dir.iterdir() if path.is_dir() and not path.name.startswith(".")), + key=lambda path: path.name, + reverse=True, + ) + for issuance_dir in issuance_dirs: + cert_path = issuance_dir / EPHEMERAL_ADMIN_CERT_CLIENT_CERT + key_path = issuance_dir / EPHEMERAL_ADMIN_CERT_CLIENT_KEY + if not cert_path.is_file() or not key_path.is_file(): + continue + try: + cert = validate_ephemeral_admin_cert_files(str(cert_path), str(key_path), root_ca_file) + files = EphemeralAdminCertFiles( + client_key=str(key_path), + client_cert=str(cert_path), + expires_at=cert_time(cert, "not_valid_after").timestamp(), + ) + if not files.needs_renewal(renewal_window=renewal_window): + return files + except Exception: + continue + return None + + +def _store_ephemeral_admin_cert_files( + files: EphemeralAdminCertFiles, + cache_dir: Path, + root_ca_file: str, +) -> EphemeralAdminCertFiles: + _ensure_private_dir(cache_dir) + temp_dir = Path(tempfile.mkdtemp(prefix=".new-", dir=cache_dir)) + issuance_dir = cache_dir / str(time.time_ns()) + cert_path = temp_dir / EPHEMERAL_ADMIN_CERT_CLIENT_CERT + key_path = temp_dir / EPHEMERAL_ADMIN_CERT_CLIENT_KEY + + try: + _copy_file_private(files.client_cert, cert_path) + _copy_file_private(files.client_key, key_path) + os.replace(temp_dir, issuance_dir) + except Exception: + shutil.rmtree(temp_dir, ignore_errors=True) + files.cleanup() + raise + + files.cleanup() + _remove_stale_cache_entries(cache_dir, root_ca_file, keep=issuance_dir) + return EphemeralAdminCertFiles( + client_key=str(issuance_dir / EPHEMERAL_ADMIN_CERT_CLIENT_KEY), + client_cert=str(issuance_dir / EPHEMERAL_ADMIN_CERT_CLIENT_CERT), + expires_at=files.expires_at, + ) + + +def _remove_stale_cache_entries(cache_dir: Path, root_ca_file: str, keep: Path): + now = time.time() + for issuance_dir in 
cache_dir.iterdir(): + if not issuance_dir.is_dir() or issuance_dir == keep or issuance_dir.name.startswith("."): + continue + cert_path = issuance_dir / EPHEMERAL_ADMIN_CERT_CLIENT_CERT + key_path = issuance_dir / EPHEMERAL_ADMIN_CERT_CLIENT_KEY + try: + cert = validate_ephemeral_admin_cert_files(str(cert_path), str(key_path), root_ca_file) + expired = cert_time(cert, "not_valid_after").timestamp() <= now + except Exception: + expired = True + if expired: + shutil.rmtree(issuance_dir, ignore_errors=True) + + +@contextmanager +def _cache_lock(cache_dir: Path): + try: + import fcntl + except ImportError as ex: + raise EphemeralAdminCertError("ephemeral admin certificate caching requires POSIX file locking") from ex + + _ensure_private_dir(cache_dir) + lock_path = cache_dir / EPHEMERAL_ADMIN_CERT_CACHE_LOCK + with open(lock_path, "a+b") as lock_file: + os.chmod(lock_path, 0o600) + fcntl.flock(lock_file, fcntl.LOCK_EX) + try: + yield + finally: + fcntl.flock(lock_file, fcntl.LOCK_UN) + + +def _cache_base_dir() -> Path: + cache_dir = Path.home() / ".nvflare" / EPHEMERAL_ADMIN_CERT_CACHE_DIR + _ensure_private_dir(cache_dir) + return cache_dir + + +def _ensure_private_dir(path: Path): + path.mkdir(mode=0o700, parents=True, exist_ok=True) + os.chmod(path, 0o700) + + +def _copy_file_private(src: str, dst: Path): + fd = os.open(dst, os.O_WRONLY | os.O_CREAT | os.O_EXCL, 0o600) + try: + with open(src, "rb") as in_file: + with os.fdopen(fd, "wb") as out_file: + fd = None + shutil.copyfileobj(in_file, out_file) + finally: + if fd is not None: + os.close(fd) + + +def _cache_key(provider: str, provider_config: Mapping, root_ca_file: str) -> str: + with open(root_ca_file, "rb") as f: + root_ca_hash = hashlib.sha256(f.read()).hexdigest() + + cache_material = { + "version": 1, + "root_ca_sha256": root_ca_hash, + "provider": provider, + "provider_config": provider_config, + } + encoded = json.dumps(cache_material, sort_keys=True, separators=(",", ":")).encode("utf-8") + return 
hashlib.sha256(encoded).hexdigest() + + +def get_ephemeral_admin_cert_renewal_window(config: Mapping) -> float: + renewal_window = config.get("renewal_window", 60.0) + try: + renewal_window = float(renewal_window) + except (TypeError, ValueError) as ex: + raise EphemeralAdminCertError("ephemeral_admin_cert.renewal_window must be a number") from ex + if renewal_window <= 0.0: + raise EphemeralAdminCertError("ephemeral_admin_cert.renewal_window must be greater than zero") + return renewal_window + + +def validate_ephemeral_admin_cert_config(config: Mapping) -> dict: + if not isinstance(config, Mapping): + raise EphemeralAdminCertError(f"ephemeral_admin_cert must be a mapping but got {type(config)}") + + result = dict(config) + provider = result.get(EPHEMERAL_ADMIN_CERT_PROVIDER_KEY) + if not isinstance(provider, str) or not provider: + raise EphemeralAdminCertError(f"ephemeral_admin_cert.{EPHEMERAL_ADMIN_CERT_PROVIDER_KEY} is required") + _validate_provider_name(provider) + + provider_config = result.get(EPHEMERAL_ADMIN_CERT_PROVIDER_CONFIG_KEY) or {} + if not isinstance(provider_config, Mapping): + raise EphemeralAdminCertError( + f"ephemeral_admin_cert.{EPHEMERAL_ADMIN_CERT_PROVIDER_CONFIG_KEY} must be a mapping" + ) + result[EPHEMERAL_ADMIN_CERT_PROVIDER_CONFIG_KEY] = dict(provider_config) + get_ephemeral_admin_cert_renewal_window(result) + return result + + +def _validate_provider_name(provider: str): + provider_path = BUILTIN_EPHEMERAL_ADMIN_CERT_PROVIDERS.get(provider, provider) + if ":" not in provider_path: + raise EphemeralAdminCertError( + f"ephemeral admin cert provider '{provider}' must be a built-in provider name or module:function path" + ) + module_name, func_name = provider_path.split(":", 1) + if not module_name or not func_name: + raise EphemeralAdminCertError( + f"ephemeral admin cert provider '{provider}' must be a built-in provider name or module:function path" + ) + + +def _load_provider(provider: str): + provider_path = 
BUILTIN_EPHEMERAL_ADMIN_CERT_PROVIDERS.get(provider, provider) + _validate_provider_name(provider) + module_name, func_name = provider_path.split(":", 1) + + try: + module = importlib.import_module(module_name) + provider_func = getattr(module, func_name) + except Exception as ex: + raise EphemeralAdminCertError(f"cannot load ephemeral admin cert provider '{provider}': {ex}") from ex + if not callable(provider_func): + raise EphemeralAdminCertError(f"ephemeral admin cert provider '{provider}' is not callable") + return provider_func + + +def _public_key_pem(public_key) -> bytes: + return public_key.public_bytes( + encoding=serialization.Encoding.PEM, + format=serialization.PublicFormat.SubjectPublicKeyInfo, + ) diff --git a/nvflare/fuel/sec/step_ca_admin_cert.py b/nvflare/fuel/sec/step_ca_admin_cert.py new file mode 100644 index 0000000000..d9cdd4f07d --- /dev/null +++ b/nvflare/fuel/sec/step_ca_admin_cert.py @@ -0,0 +1,125 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import os +import subprocess +import tempfile +from typing import Mapping, Sequence +from urllib.parse import urlparse + +from nvflare.fuel.sec.admin_cert import ADMIN_CERT_PLACEHOLDER_CN +from nvflare.fuel.sec.ephemeral_admin_cert import EphemeralAdminCertError, EphemeralAdminCertFiles + +DEFAULT_STEP_CA_CERT_TTL = "24h" +DEFAULT_STEP_CA_REQUEST_NAME = ADMIN_CERT_PLACEHOLDER_CN +DEFAULT_STEP_CA_COMMAND_TIMEOUT = 300.0 + + +def obtain_step_ca_admin_cert_files(config: Mapping, root_ca_file: str) -> EphemeralAdminCertFiles: + temp_dir = tempfile.TemporaryDirectory(prefix="nvflare-step-ca-admin-") + cert_path = os.path.join(temp_dir.name, "client.crt") + key_path = os.path.join(temp_dir.name, "client.key") + command = _build_step_ca_command( + config=config, + root_ca_file=root_ca_file, + cert_path=cert_path, + key_path=key_path, + ) + command_timeout = float(config.get("command_timeout") or DEFAULT_STEP_CA_COMMAND_TIMEOUT) + + try: + _run_step(command, timeout=command_timeout) + except Exception: + temp_dir.cleanup() + raise + + return EphemeralAdminCertFiles( + client_key=key_path, + client_cert=cert_path, + temp_dir=temp_dir, + ) + + +def validate_step_ca_admin_cert_config(config: Mapping) -> dict: + if not isinstance(config, Mapping): + raise EphemeralAdminCertError(f"step_ca provider_config must be a mapping but got {type(config)}") + result = dict(config) + ca_url = result.get("ca_url") + if not ca_url: + raise EphemeralAdminCertError("step_ca provider_config.ca_url is required") + _validate_step_ca_url(str(ca_url)) + if not result.get("provisioner"): + raise EphemeralAdminCertError("step_ca provider_config.provisioner is required") + _command_timeout(result) + return result + + +def _build_step_ca_command( + config: Mapping, + root_ca_file: str, + cert_path: str, + key_path: str, +) -> Sequence[str]: + config = validate_step_ca_admin_cert_config(config) + step_bin = str(config.get("step_bin") or "step") + ca_url = str(config.get("ca_url")) + provisioner = 
str(config.get("provisioner")) + command = [step_bin, "ca", "certificate", "--ca-url", ca_url, "--root", root_ca_file] + command.extend(["--provisioner", provisioner]) + command.extend(["--not-after", str(config.get("cert_ttl") or DEFAULT_STEP_CA_CERT_TTL)]) + command.extend(["--kty", "RSA", "--size", "2048"]) + command.extend([DEFAULT_STEP_CA_REQUEST_NAME, cert_path, key_path]) + return command + + +def _command_timeout(config: Mapping) -> float: + command_timeout_config = config.get("command_timeout") + if command_timeout_config is None or command_timeout_config == "": + return DEFAULT_STEP_CA_COMMAND_TIMEOUT + try: + command_timeout = float(command_timeout_config) + except (TypeError, ValueError) as ex: + raise EphemeralAdminCertError("step_ca provider_config.command_timeout must be a number") from ex + if command_timeout <= 0.0: + raise EphemeralAdminCertError("step_ca provider_config.command_timeout must be greater than zero") + return command_timeout + + +def _run_step( + command: Sequence[str], + timeout: float = DEFAULT_STEP_CA_COMMAND_TIMEOUT, +): + try: + return subprocess.run( + command, + check=True, + timeout=timeout, + ) + except FileNotFoundError as ex: + raise EphemeralAdminCertError( + "step-ca admin certs require the 'step' CLI in PATH or step_ca provider_config.step_bin" + ) from ex + except subprocess.TimeoutExpired as ex: + raise EphemeralAdminCertError(f"step ca certificate timed out after {timeout} seconds") from ex + except subprocess.CalledProcessError as ex: + raise EphemeralAdminCertError(f"step ca certificate failed with exit code {ex.returncode}") from ex + + +def _validate_step_ca_url(url: str): + parsed = urlparse(url) + if parsed.scheme == "https" and parsed.netloc: + return + if parsed.scheme == "http" and parsed.hostname in {"127.0.0.1", "::1", "localhost"}: + return + raise EphemeralAdminCertError("step_ca provider_config.ca_url must use https; http is only allowed for localhost") diff --git a/nvflare/lighter/constants.py 
b/nvflare/lighter/constants.py index d264826912..2a2fe89ea9 100644 --- a/nvflare/lighter/constants.py +++ b/nvflare/lighter/constants.py @@ -59,6 +59,7 @@ class PropKey: # ever shipped with the previous name. ALLOW_LOG_STREAMING = "allow_log_streaming" CONN_SECURITY = "connection_security" + EPHEMERAL_ADMIN_CERT = "ephemeral_admin_cert" AUTH_IDENTITY = "auth_identity" CUSTOM_CA_CERT = "custom_ca_cert" SCHEME = "scheme" diff --git a/nvflare/lighter/entity.py b/nvflare/lighter/entity.py index 8e32fb972e..7960f84791 100644 --- a/nvflare/lighter/entity.py +++ b/nvflare/lighter/entity.py @@ -238,7 +238,7 @@ def __repr__(self): class Participant(Entity): - def __init__(self, type: str, name: str, org: str, props: Optional[dict] = None, project: Entity = None): + def __init__(self, type: str, name: str, org: Optional[str], props: Optional[dict] = None, project: Entity = None): """Class to represent a participant. Each participant communicates to other participant. Therefore, each participant has its @@ -256,8 +256,11 @@ def __init__(self, type: str, name: str, org: str, props: Optional[dict] = None, """ Entity.__init__(self, f"{type}::{name}", name, props, parent=project) + ephemeral_admin = type == ParticipantType.ADMIN and bool(props and props.get(PropKey.EPHEMERAL_ADMIN_CERT)) if type in DEFINED_PARTICIPANT_TYPES: err, reason = name_check(name, type) + if err and ephemeral_admin: + err, reason = name_check(name, "admin_kit") if err: raise ValueError(reason) else: @@ -266,24 +269,32 @@ def __init__(self, type: str, name: str, org: str, props: Optional[dict] = None, raise ValueError(reason) print(f"Warning: participant type '{type}' of {name} is not a defined type {DEFINED_PARTICIPANT_TYPES}") - err, reason = name_check(org, "org") - if err: - raise ValueError(reason) + if ephemeral_admin and org: + raise ValueError(f"ephemeral admin '{name}' must not define org; org comes from issued cert") + if org: + err, reason = name_check(org, "org") + if err: + raise 
ValueError(reason) + elif not ephemeral_admin: + raise ValueError(f"missing participant {PropKey.ORG}") if type == ParticipantType.ADMIN: - if not props: - raise ValueError(f"missing role for admin '{name}'") - - role = props.get(PropKey.ROLE) - if not role: + if ephemeral_admin: + if props.get(PropKey.ROLE): + raise ValueError(f"ephemeral admin '{name}' must not define role; role comes from issued cert") + elif not props: raise ValueError(f"missing role for admin '{name}'") + else: + role = props.get(PropKey.ROLE) + if not role: + raise ValueError(f"missing role for admin '{name}'") - err, reason = name_check(role, "simple_name") - if err: - raise ValueError(f"bad role value '{role}' for admin '{name}': {reason}") + err, reason = name_check(role, "simple_name") + if err: + raise ValueError(f"bad role value '{role}' for admin '{name}': {reason}") - if role not in DEFINED_ROLES: - print(f"Warning: '{role}' of admin '{name}' is not a defined role {DEFINED_ROLES}") + if role not in DEFINED_ROLES: + print(f"Warning: '{role}' of admin '{name}' is not a defined role {DEFINED_ROLES}") self.type = type self.org = org @@ -374,7 +385,7 @@ def participant_from_dict(participant_def: dict) -> Participant: name = _must_get(participant_def, PropKey.NAME) t = _must_get(participant_def, PropKey.TYPE) - org = _must_get(participant_def, PropKey.ORG) + org = participant_def.pop(PropKey.ORG, None) return Participant(type=t, name=name, org=org, props=participant_def) diff --git a/nvflare/lighter/ephemeral_admin.py b/nvflare/lighter/ephemeral_admin.py new file mode 100644 index 0000000000..ab62abc020 --- /dev/null +++ b/nvflare/lighter/ephemeral_admin.py @@ -0,0 +1,29 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import Optional + +from nvflare.fuel.sec.ephemeral_admin_cert import validate_ephemeral_admin_cert_config +from nvflare.lighter.constants import PropKey + + +def get_admin_ephemeral_cert_config(admin) -> Optional[dict]: + config = admin.get_prop(PropKey.EPHEMERAL_ADMIN_CERT) + if not config: + return None + scope = f"admin {admin.name}.{PropKey.EPHEMERAL_ADMIN_CERT}" + try: + return validate_ephemeral_admin_cert_config(config) + except ValueError as ex: + raise ValueError(f"invalid {scope}: {ex}") from ex diff --git a/nvflare/lighter/impl/cert.py b/nvflare/lighter/impl/cert.py index 1f26cc5cd1..c2f3078a2b 100644 --- a/nvflare/lighter/impl/cert.py +++ b/nvflare/lighter/impl/cert.py @@ -24,6 +24,7 @@ from nvflare.lighter.constants import CertFileBasename, CtxKey, ParticipantType, PropKey from nvflare.lighter.ctx import ProvisionContext from nvflare.lighter.entity import Participant, Project +from nvflare.lighter.ephemeral_admin import get_admin_ephemeral_cert_config from nvflare.lighter.spec import Builder from nvflare.lighter.utils import Identity, generate_cert, generate_keys, serialize_cert, serialize_pri_key @@ -265,6 +266,10 @@ def _build_write_cert_pair(self, participant: Participant, base_name, ctx: Provi if participant.type in [ParticipantType.CLIENT, ParticipantType.RELAY]: self._build_internal_listener_cert(participant, ctx) + self._write_root_ca(participant, ctx) + + def _write_root_ca(self, participant: Participant, ctx: ProvisionContext): + dest_dir = ctx.get_kit_dir(participant) with open(os.path.join(dest_dir, "rootCA.pem"), 
"wb") as f: f.write(self.serialized_cert) @@ -333,7 +338,10 @@ def build(self, project: Project, ctx: ProvisionContext): self._build_write_cert_pair(relay, CertFileBasename.CLIENT, ctx) for admin in project.get_admins(): - self._build_write_cert_pair(admin, CertFileBasename.CLIENT, ctx) + if get_admin_ephemeral_cert_config(admin): + self._write_root_ca(admin, ctx) + else: + self._build_write_cert_pair(admin, CertFileBasename.CLIENT, ctx) def get_pri_key_cert(self, participant: Participant): pri_key, pub_key = generate_keys() diff --git a/nvflare/lighter/impl/static_file.py b/nvflare/lighter/impl/static_file.py index dfe562da59..3873c80326 100644 --- a/nvflare/lighter/impl/static_file.py +++ b/nvflare/lighter/impl/static_file.py @@ -30,6 +30,7 @@ TemplateSectionKey, ) from nvflare.lighter.entity import Participant +from nvflare.lighter.ephemeral_admin import get_admin_ephemeral_cert_config from nvflare.lighter.spec import Builder, Project, ProvisionContext _logger = logging.getLogger(__name__) @@ -633,7 +634,8 @@ def prepare_admin_config(self, admin: Participant, ctx: ProvisionContext): if not conn_sec: conn_sec = ConnSecurity.MTLS - uid_source = "user_input" + ephemeral_admin_cert = get_admin_ephemeral_cert_config(admin) + uid_source = "cert" if ephemeral_admin_cert else "user_input" provision_mode = ctx.get_provision_mode() if provision_mode == ProvisionMode.POC: uid_source = "cert" @@ -644,7 +646,7 @@ def prepare_admin_config(self, admin: Participant, ctx: ProvisionContext): replacement_dict = { "project_name": project.name, - "username": "" if provision_mode == ProvisionMode.POC else admin.name, + "username": "" if provision_mode == ProvisionMode.POC or ephemeral_admin_cert else admin.name, "server_identity": self._get_auth_identity(server), "scheme": self.scheme, "conn_sec": conn_sec, @@ -657,6 +659,8 @@ def prepare_admin_config(self, admin: Participant, ctx: ProvisionContext): temp_section=TemplateSectionKey.FED_ADMIN, file_name=ProvFileName.FED_ADMIN_JSON, 
replacement=replacement_dict, + content_modify_cb=_modify_fed_admin_config, + ephemeral_admin_cert=ephemeral_admin_cert, ) # create default resources in local @@ -1035,6 +1039,20 @@ def _remove_undefined_port(section: str) -> str: return section +def _modify_fed_admin_config(section: str, ephemeral_admin_cert=None) -> str: + if not ephemeral_admin_cert: + return section + + admin_config = json.loads(section) + admin = admin_config.get("admin", {}) + admin.pop("client_key", None) + admin.pop("client_cert", None) + admin["username"] = "" + admin["uid_source"] = "cert" + admin[PropKey.EPHEMERAL_ADMIN_CERT] = dict(ephemeral_admin_cert) + return json.dumps(admin_config, indent=2) + + def check_parent(c: Participant, path: list): if c.name in path: return f"circular parent ref {c.name}" diff --git a/nvflare/private/fed/server/job_cmds.py b/nvflare/private/fed/server/job_cmds.py index 8e21fd92d0..e52f38ae2a 100644 --- a/nvflare/private/fed/server/job_cmds.py +++ b/nvflare/private/fed/server/job_cmds.py @@ -16,6 +16,7 @@ import io import json import os +import posixpath import shutil import threading import uuid @@ -66,8 +67,11 @@ from nvflare.fuel.hci.server.authz import PreAuthzReturnCode from nvflare.fuel.hci.server.binary_transfer import BinaryTransfer from nvflare.fuel.hci.server.constants import ConnProps +from nvflare.fuel.sec.ephemeral_admin_cert import cert_time from nvflare.fuel.utils.argument_utils import SafeArgumentParser from nvflare.fuel.utils.log_utils import get_obj_logger +from nvflare.lighter.tool_consts import NVFLARE_SUBMITTER_CRT_FILE +from nvflare.lighter.utils import load_crt_chain_bytes from nvflare.private.defs import RequestHeader, TrainingTopic from nvflare.private.fed.server.admin import new_message from nvflare.private.fed.server.job_meta_validator import JobMetaValidator @@ -98,12 +102,67 @@ def __init__(self, record: dict): JobMetaKey.MANDATORY_CLIENTS.value, JobMetaKey.DATA_STORAGE_FORMAT.value, JobMetaKey.STUDY.value, + 
JobMetaKey.SUBMITTER_CERT_VALIDITY.value, AppValidationKey.BYOC, } JSON_LOG_FILE_NAME = "log.json" +def _submitter_cert_validity(zip_file_name: str) -> Optional[dict]: + zip_source = io.BytesIO(zip_file_name) if isinstance(zip_file_name, bytes) else zip_file_name + try: + zip_file = ZipFile(zip_source) + except (BadZipFile, OSError): + return None + + validity = None + with zip_file: + for info in zip_file.infolist(): + if info.is_dir() or posixpath.basename(info.filename) != NVFLARE_SUBMITTER_CRT_FILE: + continue + + try: + cert_chain = load_crt_chain_bytes(zip_file.read(info)) + except Exception: + return {} + cert = cert_chain[0] + cert_validity = { + "not_before": cert_time(cert, "not_valid_before").timestamp(), + "not_after": cert_time(cert, "not_valid_after").timestamp(), + } + if validity is None: + validity = cert_validity + else: + validity["not_before"] = max(validity["not_before"], cert_validity["not_before"]) + validity["not_after"] = min(validity["not_after"], cert_validity["not_after"]) + return validity + + +def _clone_signature_error(job: Job) -> str: + validity = job.meta.get(JobMetaKey.SUBMITTER_CERT_VALIDITY.value) + if validity is None: + return "" + if not isinstance(validity, dict): + validity = {} + try: + not_before = float(validity["not_before"]) + not_after = float(validity["not_after"]) + except (KeyError, TypeError, ValueError): + return ( + "Cannot clone this job because the stored submitter certificate cannot be inspected. " + "Download and submit the job again so it is signed with a current admin certificate." + ) + + now = datetime.datetime.now(datetime.timezone.utc).timestamp() + if now < not_before or now >= not_after: + return ( + "Cannot clone this job because the stored submitter certificate is no longer valid. " + "Download and submit the job again so it is signed with a current admin certificate." 
+ ) + return "" + + def _active_study_from_conn(conn: Connection) -> str: return conn.get_prop(ConnProps.ACTIVE_STUDY, DEFAULT_STUDY) or DEFAULT_STUDY @@ -132,6 +191,7 @@ def _create_submit_job_cmd_parser(): parser = SafeArgumentParser(prog=AdminCommandNames.SUBMIT_JOB) parser.add_argument("folder_name", help="Uploaded job folder name") parser.add_argument("--submit-token", dest="submit_token", help="retry-safe submit token") + parser.add_argument("--ephemeral-admin-cert", action="store_true") return parser @@ -1074,6 +1134,11 @@ def clone_job(self, conn: Connection, args: List[str]): f"job_def_manager in engine is not of type JobDefManagerSpec, but got {type(job_def_manager)}" ) with engine.new_context() as fl_ctx: + clone_error = _clone_signature_error(job) + if clone_error: + conn.append_error(clone_error, meta=make_meta(MetaStatusValue.INVALID_JOB_DEFINITION, clone_error)) + return + job_meta = {str(k): job.meta[k] for k in job.meta.keys() & CLONED_META_KEYS} # set the submitter info for the new job @@ -1548,6 +1613,7 @@ def submit_job(self, conn: Connection, args: List[str]): parsed_args = parser.parse_args(args[1:]) folder_name = parsed_args.folder_name submit_token = validate_submit_token(parsed_args.submit_token) + ephemeral_admin_cert = parsed_args.ephemeral_admin_cert except ValueError as e: conn.append_error(str(e), meta=make_meta(MetaStatusValue.SYNTAX_ERROR, str(e))) return @@ -1572,6 +1638,7 @@ def submit_job(self, conn: Connection, args: List[str]): meta.pop(JobMetaKey.FROM_HUB_SITE.value, None) # Submit-token is server-owned submission metadata. User job metadata must not expose it. 
meta.pop(SubmitRecordKey.SUBMIT_TOKEN.value, None) + meta.pop(JobMetaKey.SUBMITTER_CERT_VALIDITY.value, None) job_def_manager = engine.job_def_manager if not isinstance(job_def_manager, JobDefManagerSpec): @@ -1616,6 +1683,11 @@ def submit_job(self, conn: Connection, args: List[str]): ) return + if ephemeral_admin_cert: + submitter_cert_validity = _submitter_cert_validity(zip_file_name) + if submitter_cert_validity is not None: + meta[JobMetaKey.SUBMITTER_CERT_VALIDITY.value] = submitter_cert_validity + # set submitter info submitter = self._submitter_from_conn(conn) meta[JobMetaKey.SUBMITTER_NAME.value] = submitter["name"] diff --git a/nvflare/tool/kit/kit_config.py b/nvflare/tool/kit/kit_config.py index 2d17cc9b2b..abf7022e02 100644 --- a/nvflare/tool/kit/kit_config.py +++ b/nvflare/tool/kit/kit_config.py @@ -45,9 +45,9 @@ ADMIN_STARTUP_KIT_REQUIRED_FILES = ( os.path.join("startup", "fed_admin.json"), - os.path.join("startup", "client.crt"), os.path.join("startup", "rootCA.pem"), ) +ADMIN_STATIC_CERT_FILE = os.path.join("startup", "client.crt") SITE_STARTUP_KIT_REQUIRED_FILES = (os.path.join("startup", "fed_client.json"),) SERVER_STARTUP_KIT_REQUIRED_FILES = (os.path.join("startup", "fed_server.json"),) STARTUP_KIT_KIND_ADMIN = "admin" @@ -236,13 +236,30 @@ def _has_required_files(startup_kit_dir: Path, required_files) -> bool: return all((startup_kit_dir / rel_path).is_file() for rel_path in required_files) +def _is_admin_startup_kit(startup_kit_dir: Path) -> bool: + if not _has_required_files(startup_kit_dir, ADMIN_STARTUP_KIT_REQUIRED_FILES): + return False + if (startup_kit_dir / ADMIN_STATIC_CERT_FILE).is_file(): + return True + + try: + config = json.loads((startup_kit_dir / "startup" / "fed_admin.json").read_text()) + except Exception: + return False + + admin_config = config.get("admin") if isinstance(config, dict) else None + return isinstance(admin_config, dict) and bool(admin_config.get("ephemeral_admin_cert")) + + def classify_startup_kit(path: str) 
-> Tuple[str, str]: """Return (kind, normalized participant dir) for a generated startup kit.""" startup_path = _as_existing_dir(path) startup_kit_dir = startup_path.parent if startup_path.name == "startup" else startup_path + if _is_admin_startup_kit(startup_kit_dir): + return STARTUP_KIT_KIND_ADMIN, str(startup_kit_dir.resolve()) + for kind, required_files in ( - (STARTUP_KIT_KIND_ADMIN, ADMIN_STARTUP_KIT_REQUIRED_FILES), (STARTUP_KIT_KIND_SITE, SITE_STARTUP_KIT_REQUIRED_FILES), (STARTUP_KIT_KIND_SERVER, SERVER_STARTUP_KIT_REQUIRED_FILES), ): @@ -465,9 +482,13 @@ def _certificate_expiration_metadata(cert, cert_path: str) -> Tuple[Dict, list]: ) -def _inspect_admin_cert_metadata(startup_dir: str, metadata: Dict) -> None: +def _inspect_admin_cert_metadata(startup_dir: str, metadata: Dict, has_ephemeral_admin_cert: bool = False) -> None: cert_path = os.path.join(startup_dir, "client.crt") if not os.path.isfile(cert_path): + if has_ephemeral_admin_cert: + metadata["certificate"] = {"status": "runtime_issued"} + metadata["credential_source"] = "ephemeral_admin_cert" + return metadata["findings"].append( _finding( "STARTUP_KIT_CERT_MISSING", @@ -526,15 +547,18 @@ def inspect_startup_kit_metadata(path: str) -> Dict: startup_dir = os.path.join(startup_kit_dir, "startup") if kind == STARTUP_KIT_KIND_ADMIN: + has_ephemeral_admin_cert = False try: fed_admin_config = ConfigFactory.load_config("fed_admin.json", [startup_dir]) if fed_admin_config: config_dict = fed_admin_config.to_dict() - metadata["identity"] = config_dict.get("admin", {}).get("username") + admin_config = config_dict.get("admin", {}) + metadata["identity"] = admin_config.get("username") + has_ephemeral_admin_cert = bool(admin_config.get("ephemeral_admin_cert")) except Exception: pass - _inspect_admin_cert_metadata(startup_dir, metadata) + _inspect_admin_cert_metadata(startup_dir, metadata, has_ephemeral_admin_cert=has_ephemeral_admin_cert) if not metadata["identity"]: metadata["identity"] = 
os.path.basename(startup_kit_dir) diff --git a/nvflare/tool/package_checker/check_rule.py b/nvflare/tool/package_checker/check_rule.py index 5d297087a8..13a72c60d0 100644 --- a/nvflare/tool/package_checker/check_rule.py +++ b/nvflare/tool/package_checker/check_rule.py @@ -135,6 +135,7 @@ def __call__(self, package_path, data): host = admin["host"] port = admin["port"] scheme = admin.get("scheme", "grpc") + uses_ephemeral_admin_cert = bool(admin.get("ephemeral_admin_cert")) else: # For client/server, the FL server endpoint is in servers[0].service.target servers = fed_config.get("servers", []) @@ -152,25 +153,30 @@ def __call__(self, package_path, data): ) host, port = target.split(":")[0], int(target.split(":")[1]) scheme = get_communication_scheme(package_path, nvf_config, default_scheme="grpc") + uses_ephemeral_admin_cert = False - # Check connectivity based on the communication scheme - if scheme in ["grpc", "agrpc"]: - if not check_grpc_server_running(startup=startup, host=host, port=int(port)): - return CheckResult( - f"Can't connect to {scheme} server ({host}:{port})", - "Please check if server is up.", - ) - elif scheme in ["http", "https", "tcp", "stcp"]: - # HTTP/HTTPS use WebSocket, TCP/STCP use raw sockets - both checked via socket connection - if not check_socket_server_running(startup=startup, host=host, port=int(port), scheme=scheme): - return CheckResult( - f"Can't connect to {scheme} server ({host}:{port})", - "Please check if server is up.", - ) - else: + supported_schemes = {"grpc", "agrpc", "http", "https", "tcp", "stcp"} + if scheme not in supported_schemes: return CheckResult( f"Unsupported communication scheme: {scheme}", f"Scheme '{scheme}' is not supported for connectivity check.", ) + # Check connectivity based on the communication scheme + if uses_ephemeral_admin_cert: + # Preflight must not trigger interactive SSO. A TCP connection proves + # endpoint reachability; the real command validates mTLS after login. 
+ server_running = check_socket_server_running(startup=startup, host=host, port=int(port), scheme="tcp") + elif scheme in ["grpc", "agrpc"]: + server_running = check_grpc_server_running(startup=startup, host=host, port=int(port)) + else: + server_running = check_socket_server_running(startup=startup, host=host, port=int(port), scheme=scheme) + + if not server_running: + probe = "TCP reachability" if uses_ephemeral_admin_cert else scheme + return CheckResult( + f"Can't connect to {scheme} server ({host}:{port}) using {probe}", + "Please check if server is up.", + ) + return CheckResult(CHECK_PASSED, "N/A") diff --git a/nvflare/tool/package_checker/nvflare_console_package_checker.py b/nvflare/tool/package_checker/nvflare_console_package_checker.py index 03e5ea8581..5a9da11fa2 100644 --- a/nvflare/tool/package_checker/nvflare_console_package_checker.py +++ b/nvflare/tool/package_checker/nvflare_console_package_checker.py @@ -12,9 +12,15 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import json import os +import shutil + +from nvflare.fuel.sec.ephemeral_admin_cert import validate_ephemeral_admin_cert_config +from nvflare.fuel.sec.step_ca_admin_cert import validate_step_ca_admin_cert_config from .client_package_checker import ClientPackageChecker +from .package_checker import CheckStatus from .utils import NVFlareConfig, NVFlareRole @@ -27,3 +33,38 @@ def get_dry_run_command(self) -> str: def get_dry_run_inputs(self): return os.path.basename(os.path.normpath(self.package_path)) + + def check_dry_run(self) -> CheckStatus: + startup = os.path.join(self.package_path, "startup") + try: + with open(os.path.join(startup, NVFlareConfig.ADMIN), "r") as f: + config = json.load(f) + except (OSError, json.JSONDecodeError): + return super().check_dry_run() + ephemeral_config = config.get("admin", {}).get("ephemeral_admin_cert") + if ephemeral_config: + try: + ephemeral_config = validate_ephemeral_admin_cert_config(ephemeral_config) + root_ca_file = os.path.join(startup, "rootCA.pem") + if not os.path.isfile(root_ca_file): + raise ValueError(f"missing project root certificate: {root_ca_file}") + if ephemeral_config["provider"] == "step_ca": + provider_config = validate_step_ca_admin_cert_config(ephemeral_config["provider_config"]) + step_bin = str(provider_config.get("step_bin") or "step") + if not shutil.which(step_bin): + raise ValueError(f"step CLI is not available: {step_bin}") + except ValueError as ex: + self.add_report( + "Check ephemeral admin certificate", + str(ex), + "Correct the startup kit configuration and install the configured certificate provider.", + ) + return CheckStatus.FAIL + + self.add_report( + "Check dry run", + "SKIPPED", + "Certificate acquisition requires interactive login; run an NVFlare command to test it.", + ) + return CheckStatus.PASS + return super().check_dry_run() diff --git a/tests/unit_test/fuel/hci/client/test_push_folder_key_guard.py b/tests/unit_test/fuel/hci/client/test_push_folder_key_guard.py index 
de5aabce85..fcd9cc926c 100644 --- a/tests/unit_test/fuel/hci/client/test_push_folder_key_guard.py +++ b/tests/unit_test/fuel/hci/client/test_push_folder_key_guard.py @@ -45,6 +45,7 @@ def _make_push_folder_args_and_ctx(key_path, cert_path, folder_name="test_job"): api = MagicMock() api.client_key = key_path api.client_cert = cert_path + api.ephemeral_admin_cert_config = None ctx = MagicMock() ctx.get_command_entry.return_value = _make_cmd_entry() @@ -194,3 +195,61 @@ def test_push_folder_preserves_submit_args_after_folder(tmp_path): server_execute.assert_called_once() command = server_execute.call_args.args[0] assert command == "admin.push_folder test_job --submit-token retry-1" + + +def test_push_folder_refreshes_ephemeral_cert_before_signing(tmp_path): + upload_dir = str(tmp_path / "upload") + download_dir = str(tmp_path / "dl") + folder_name = "test_job" + os.makedirs(os.path.join(upload_dir, folder_name), exist_ok=True) + os.makedirs(download_dir, exist_ok=True) + + key_file = tmp_path / "test.key" + key_file.write_text("fake key content") + module = FileTransferModule(upload_dir=upload_dir, download_dir=download_dir) + args, ctx = _make_push_folder_args_and_ctx(str(key_file), "/path/to/cert.crt", folder_name) + api = ctx.get_api.return_value + api.ensure_client_cert_valid = MagicMock() + api.ephemeral_admin_cert_config = {"provider": "step_ca"} + + with ( + patch("nvflare.fuel.hci.client.file_transfer.load_private_key_file", return_value=MagicMock()), + patch("nvflare.fuel.hci.client.file_transfer.sign_folders"), + patch("nvflare.fuel.hci.client.file_transfer.zip_directory_to_file"), + patch.object(api, "server_execute", return_value={}) as server_execute, + ): + module.push_folder(args, ctx) + + api.ensure_client_cert_valid.assert_called_once_with() + assert server_execute.call_args.args[0] == "admin.push_folder test_job --ephemeral-admin-cert" + + +def test_push_folder_reconnects_when_cell_is_missing_after_failed_renewal_reconnect(tmp_path): + from 
nvflare.fuel.hci.client.api_status import APIStatus + + upload_dir = str(tmp_path / "upload") + download_dir = str(tmp_path / "dl") + folder_name = "test_job" + os.makedirs(os.path.join(upload_dir, folder_name), exist_ok=True) + os.makedirs(download_dir, exist_ok=True) + + key_file = tmp_path / "test.key" + key_file.write_text("fake key content") + module = FileTransferModule(upload_dir=upload_dir, download_dir=download_dir) + args, ctx = _make_push_folder_args_and_ctx(str(key_file), "/path/to/cert.crt", folder_name) + api = ctx.get_api.return_value + api.ensure_client_cert_valid = MagicMock(return_value=False) + api.cell = None + api.login.return_value = {"status": APIStatus.SUCCESS} + + with ( + patch("nvflare.fuel.hci.client.file_transfer.load_private_key_file", return_value=MagicMock()), + patch("nvflare.fuel.hci.client.file_transfer.sign_folders"), + patch("nvflare.fuel.hci.client.file_transfer.zip_directory_to_file"), + patch.object(api, "server_execute", return_value={}), + ): + result = module.push_folder(args, ctx) + + assert result == {} + api.connect.assert_called_once_with() + api.login.assert_called_once_with() diff --git a/tests/unit_test/fuel/hci/client_api_props_test.py b/tests/unit_test/fuel/hci/client_api_props_test.py index 3be9c33730..a663104b38 100644 --- a/tests/unit_test/fuel/hci/client_api_props_test.py +++ b/tests/unit_test/fuel/hci/client_api_props_test.py @@ -51,6 +51,7 @@ def test_user_login_sends_study_header(monkeypatch): api.user_name = "admin@nvidia.com" api.study = "cancer-research" api.login_result = None + api.cell = object() captured = {} class _FakeIdentityAsserter: @@ -88,6 +89,7 @@ def test_user_login_defaults_study_header(monkeypatch): api.user_name = "admin@nvidia.com" api.study = DEFAULT_STUDY api.login_result = None + api.cell = object() captured = {} class _FakeIdentityAsserter: @@ -121,6 +123,7 @@ def test_user_login_parses_structured_reject_code(monkeypatch): api.user_name = "admin@nvidia.com" api.study = 
"cancer-research" api.login_result = None + api.cell = object() class _FakeIdentityAsserter: cert_data = "cert-data" diff --git a/tests/unit_test/fuel/hci/ephemeral_admin_cert_api_test.py b/tests/unit_test/fuel/hci/ephemeral_admin_cert_api_test.py new file mode 100644 index 0000000000..19fcf58b5e --- /dev/null +++ b/tests/unit_test/fuel/hci/ephemeral_admin_cert_api_test.py @@ -0,0 +1,280 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest + +from nvflare.fuel.common.excepts import ConfigError +from nvflare.fuel.hci.client.api import AdminAPI +from nvflare.fuel.hci.client.api_spec import AdminConfigKey +from nvflare.fuel.hci.proto import InternalCommands +from nvflare.fuel.sec.ephemeral_admin_cert import EphemeralAdminCertFiles + + +def test_admin_api_hydrates_missing_cert_pair_from_step_ca(monkeypatch, tmp_path): + key_file = tmp_path / "client.key" + cert_file = tmp_path / "client.crt" + key_file.write_text("key", encoding="utf-8") + cert_file.write_text("cert", encoding="utf-8") + resolved_files = EphemeralAdminCertFiles(client_key=str(key_file), client_cert=str(cert_file)) + captured = {} + + def _fake_obtain(config, root_ca_file): + captured["config"] = config + captured["root_ca_file"] = root_ca_file + return resolved_files + + monkeypatch.setattr("nvflare.fuel.hci.client.api.obtain_ephemeral_admin_cert_files", _fake_obtain) + + api = AdminAPI( + user_name="alice@nvidia.com", + admin_config={ + AdminConfigKey.PROJECT_NAME: "project", + AdminConfigKey.CA_CERT: "rootCA.pem", + AdminConfigKey.EPHEMERAL_ADMIN_CERT: { + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + }, + }, + cmd_modules=[], + ) + + assert api.client_key == str(key_file) + assert api.client_cert == str(cert_file) + assert api.ephemeral_admin_cert_files is resolved_files + assert "subject" not in captured["config"] + assert captured["root_ca_file"] == "rootCA.pem" + + +def test_admin_api_reports_invalid_ephemeral_renewal_window_as_config_error(): + with pytest.raises(ConfigError, match="renewal_window must be a number"): + AdminAPI( + user_name="alice@nvidia.com", + admin_config={ + AdminConfigKey.PROJECT_NAME: "project", + AdminConfigKey.CA_CERT: "rootCA.pem", + AdminConfigKey.EPHEMERAL_ADMIN_CERT: { + "provider": "step_ca", + "renewal_window": "one minute", + "provider_config": {}, + }, + }, + cmd_modules=[], + ) + + +def 
test_admin_api_renews_expiring_ephemeral_cert_and_resets_connection(monkeypatch, tmp_path): + old_key_file = tmp_path / "old.key" + old_cert_file = tmp_path / "old.crt" + new_key_file = tmp_path / "new.key" + new_cert_file = tmp_path / "new.crt" + for path in (old_key_file, old_cert_file, new_key_file, new_cert_file): + path.write_text(path.name, encoding="utf-8") + + class _FakeCell: + stopped = False + + def stop(self): + self.stopped = True + + old_files = EphemeralAdminCertFiles( + client_key=str(old_key_file), + client_cert=str(old_cert_file), + expires_at=1.0, + ) + new_files = EphemeralAdminCertFiles( + client_key=str(new_key_file), + client_cert=str(new_cert_file), + expires_at=9999999999.0, + ) + issued_files = [old_files, new_files] + + def _fake_obtain(config, root_ca_file): + return issued_files.pop(0) + + monkeypatch.setattr("nvflare.fuel.hci.client.api.obtain_ephemeral_admin_cert_files", _fake_obtain) + monkeypatch.setattr("nvflare.fuel.hci.client.api.load_cert_file", lambda path: path) + monkeypatch.setattr( + "nvflare.fuel.hci.client.api.get_cn_from_cert", + lambda path: "alice@nvidia.com" if path == str(old_cert_file) else "bob@nvidia.com", + ) + + api = AdminAPI( + user_name="alice@nvidia.com", + admin_config={ + AdminConfigKey.PROJECT_NAME: "project", + AdminConfigKey.CA_CERT: "rootCA.pem", + AdminConfigKey.UID_SOURCE: "cert", + AdminConfigKey.EPHEMERAL_ADMIN_CERT: { + "provider": "step_ca", + "renewal_window": 60.0, + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + }, + }, + cmd_modules=[], + ) + + assert api.client_key == str(old_key_file) + assert api.user_name == "alice@nvidia.com" + assert api.fl_ctx_mgr.identity_name == "alice@nvidia.com" + old_cell = _FakeCell() + api.cell = old_cell + api.aux_runner = object() + api.server_sess_active = True + api.token = "old-token" + api.login_result = "old-result" + + assert api.ensure_client_cert_valid() + + assert api.client_key == 
str(new_key_file) + assert api.client_cert == str(new_cert_file) + assert api.user_name == "bob@nvidia.com" + assert api.fl_ctx_mgr.identity_name == "bob@nvidia.com" + assert old_cell.stopped + assert api.cell is None + assert api.aux_runner is None + assert not api.server_sess_active + assert api.token is None + assert api.login_result is None + + +def test_user_login_builds_command_after_identity_renewal(monkeypatch): + class _FakeIdentityAsserter: + cert_data = b"cert" + + def __init__(self, private_key_file, cert_file): + pass + + def sign_common_name(self, nonce): + return b"signature" + + api = object.__new__(AdminAPI) + api.user_name = "alice@nvidia.com" + api.client_key = "client.key" + api.client_cert = "client.crt" + api.study = "default" + api.cell = object() + api.login_result = None + captured = {} + + def _renew(): + api.user_name = "bob@nvidia.com" + api.cell = None + return True + + def _connect(): + api.cell = object() + + def _server_execute(command, reply_processor, headers): + captured["command"] = command + captured["headers"] = headers + api.login_result = "REJECT" + + api.ensure_client_cert_valid = _renew + api.connect = _connect + api.server_execute = _server_execute + monkeypatch.setattr("nvflare.fuel.hci.client.api.IdentityAsserter", _FakeIdentityAsserter) + + api._user_login() + + assert captured["command"] == f"{InternalCommands.CERT_LOGIN} bob@nvidia.com" + assert captured["headers"]["user_name"] == "bob@nvidia.com" + + +def test_user_login_reconnects_when_cell_is_missing_without_another_renewal(monkeypatch): + class _FakeIdentityAsserter: + cert_data = b"cert" + + def __init__(self, private_key_file, cert_file): + pass + + def sign_common_name(self, nonce): + return b"signature" + + api = object.__new__(AdminAPI) + api.user_name = "alice@nvidia.com" + api.client_key = "client.key" + api.client_cert = "client.crt" + api.study = "default" + api.cell = None + api.login_result = None + connected = [] + + api.ensure_client_cert_valid = 
lambda: False + + def _connect(): + connected.append(True) + api.cell = object() + + def _server_execute(command, reply_processor, headers): + api.login_result = "REJECT" + + api.connect = _connect + api.server_execute = _server_execute + monkeypatch.setattr("nvflare.fuel.hci.client.api.IdentityAsserter", _FakeIdentityAsserter) + + api._user_login() + + assert connected == [True] + + +def test_connect_cleans_up_partial_cell_after_authentication_failure(monkeypatch): + class _FakeCell: + stopped = False + + def __init__(self, **kwargs): + pass + + def register_request_cb(self, **kwargs): + pass + + def start(self): + pass + + def stop(self): + self.stopped = True + + class _FailingAuthenticator: + def __init__(self, **kwargs): + pass + + def authenticate(self, **kwargs): + raise RuntimeError("authentication failed") + + monkeypatch.setattr("nvflare.fuel.hci.client.api.Cell", _FakeCell) + monkeypatch.setattr("nvflare.fuel.hci.client.api.NetAgent", lambda cell: None) + monkeypatch.setattr("nvflare.fuel.hci.client.api.Authenticator", _FailingAuthenticator) + + api = AdminAPI( + user_name="alice@nvidia.com", + admin_config={ + AdminConfigKey.PROJECT_NAME: "project", + AdminConfigKey.CA_CERT: "rootCA.pem", + AdminConfigKey.CLIENT_CERT: "client.crt", + AdminConfigKey.CLIENT_KEY: "client.key", + }, + cmd_modules=[], + ) + + with pytest.raises(RuntimeError, match="authentication failed"): + api.connect() + + assert api.cell is None + assert api.aux_runner is None + assert api.object_streamer is None diff --git a/tests/unit_test/fuel/sec/ephemeral_admin_cert_test.py b/tests/unit_test/fuel/sec/ephemeral_admin_cert_test.py new file mode 100644 index 0000000000..b79837c8a2 --- /dev/null +++ b/tests/unit_test/fuel/sec/ephemeral_admin_cert_test.py @@ -0,0 +1,525 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import datetime +import json +import os +import threading +import time +from concurrent.futures import ThreadPoolExecutor + +import pytest +from cryptography.x509.oid import NameOID + +from nvflare.fuel.sec.ephemeral_admin_cert import ( + EphemeralAdminCertError, + EphemeralAdminCertFiles, + obtain_ephemeral_admin_cert_files, + validate_ephemeral_admin_cert_files, +) +from nvflare.fuel.sec.step_ca_admin_cert import ( + DEFAULT_STEP_CA_REQUEST_NAME, + obtain_step_ca_admin_cert_files, + validate_step_ca_admin_cert_config, +) +from nvflare.lighter.tool_consts import NVFLARE_SUBMITTER_CRT_FILE +from nvflare.lighter.utils import ( + Identity, + generate_cert, + generate_keys, + load_crt_chain, + load_private_key_file, + serialize_cert, + serialize_pri_key, + sign_folders, + verify_folder_signature, +) +from nvflare.private.fed.utils.identity_utils import IdentityAsserter, IdentityVerifier, load_cert_file + + +@pytest.fixture(autouse=True) +def _use_test_home(monkeypatch, tmp_path): + home_dir = tmp_path / "home" + home_dir.mkdir() + monkeypatch.setenv("HOME", str(home_dir)) + + +def _make_root_ca(tmp_path): + ca_key, ca_pub_key = generate_keys() + ca_cert = generate_cert( + subject=Identity("root", "nvidia"), + issuer=Identity("root", "nvidia"), + signing_pri_key=ca_key, + subject_pub_key=ca_pub_key, + ca=True, + ) + root_ca_path = tmp_path / "rootCA.pem" + root_ca_path.write_bytes(serialize_cert(ca_cert)) + return ca_key, ca_cert, root_ca_path + + +def _write_key_cert(key_path, cert_path, private_key, certs): + with open(key_path, "wb") as f: + 
f.write(serialize_pri_key(private_key)) + with open(cert_path, "wb") as f: + for cert in certs: + f.write(serialize_cert(cert)) + + +def _make_admin_cert_files( + tmp_path, + signing_key, + issuer_identity, + chain=(), + issued_subject="alice@nvidia.com", + role="lead", + ca=False, + extra_extensions=None, +): + admin_key, admin_pub_key = generate_keys() + now = datetime.datetime.now(datetime.timezone.utc) + cert = generate_cert( + subject=Identity(issued_subject, "nvidia", role), + issuer=issuer_identity, + signing_pri_key=signing_key, + subject_pub_key=admin_pub_key, + not_valid_before=now - datetime.timedelta(seconds=1), + not_valid_after=now + datetime.timedelta(hours=1), + ca=ca, + extra_extensions=extra_extensions, + ) + cert_src = tmp_path / "issued.crt" + key_src = tmp_path / "issued.key" + _write_key_cert(key_src, cert_src, admin_key, [cert, *chain]) + return cert_src, key_src + + +def _custom_ephemeral_provider(config, root_ca_file): + return EphemeralAdminCertFiles( + client_key=config["key_path"], + client_cert=config["cert_path"], + expires_at=config.get("expires_at", 0.0), + ) + + +def _fake_step( + monkeypatch, + tmp_path, + cert_src=None, + key_src=None, + command_log=None, + exit_code=None, + sleep=None, +): + fake_step = tmp_path / "step" + fake_step.write_text( + "\n".join( + [ + "#!/usr/bin/env python3", + "import json, os, shutil, sys, time", + "if os.environ.get('NVFLARE_TEST_STEP_SLEEP'):", + " time.sleep(float(os.environ['NVFLARE_TEST_STEP_SLEEP']))", + "if os.environ.get('NVFLARE_TEST_STEP_EXIT'):", + " raise SystemExit(int(os.environ['NVFLARE_TEST_STEP_EXIT']))", + "if os.environ.get('NVFLARE_TEST_STEP_COMMAND_LOG'):", + " with open(os.environ['NVFLARE_TEST_STEP_COMMAND_LOG'], 'a') as f:", + " f.write(json.dumps(sys.argv[1:]) + '\\n')", + "if sys.argv[1:3] == ['ca', 'certificate']:", + " shutil.copyfile(os.environ['NVFLARE_TEST_STEP_CERT'], sys.argv[-2])", + " shutil.copyfile(os.environ['NVFLARE_TEST_STEP_KEY'], sys.argv[-1])", + "else:", 
+ " raise SystemExit(3)", + ] + ), + encoding="utf-8", + ) + fake_step.chmod(0o755) + if cert_src: + monkeypatch.setenv("NVFLARE_TEST_STEP_CERT", str(cert_src)) + if key_src: + monkeypatch.setenv("NVFLARE_TEST_STEP_KEY", str(key_src)) + if command_log: + monkeypatch.setenv("NVFLARE_TEST_STEP_COMMAND_LOG", str(command_log)) + if exit_code is not None: + monkeypatch.setenv("NVFLARE_TEST_STEP_EXIT", str(exit_code)) + if sleep is not None: + monkeypatch.setenv("NVFLARE_TEST_STEP_SLEEP", str(sleep)) + return fake_step + + +def _read_command_log(command_log): + return [json.loads(line) for line in command_log.read_text(encoding="utf-8").splitlines()] + + +def test_step_ca_source_invokes_step_and_cert_works_with_existing_flare_paths(monkeypatch, tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_src, key_src = _make_admin_cert_files(tmp_path, signing_key=ca_key, issuer_identity=Identity("root", "nvidia")) + command_log = tmp_path / "commands.jsonl" + fake_step = _fake_step(monkeypatch, tmp_path, cert_src=cert_src, key_src=key_src, command_log=command_log) + + files = obtain_ephemeral_admin_cert_files( + config={ + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "cert_ttl": "1h", + "step_bin": str(fake_step), + }, + }, + root_ca_file=str(root_ca_path), + ) + + try: + command = _read_command_log(command_log)[0] + assert command[:2] == ["ca", "certificate"] + assert "--ca-url" in command + assert "https://step-ca.example.com" in command + assert "--provisioner" in command + assert "nvflare-admin-oidc" in command + assert "--token" not in command + assert "--kty" in command + assert "RSA" in command + assert command[-3] == DEFAULT_STEP_CA_REQUEST_NAME + assert os.path.isfile(files.client_key) + assert os.path.isfile(files.client_cert) + assert files.expires_at > time.time() + + asserter = IdentityAsserter(private_key_file=files.client_key, cert_file=files.client_cert) + 
signature = asserter.sign_common_name(nonce="") + cert_chain = load_crt_chain(files.client_cert) + assert IdentityVerifier(root_cert_file=str(root_ca_path)).verify_common_name( + asserted_cn="alice@nvidia.com", + nonce="", + asserter_cert=asserter.cert, + signature=signature, + intermediate_certs=cert_chain[1:], + ) + + cert = load_cert_file(files.client_cert) + assert cert.subject.get_attributes_for_oid(NameOID.ORGANIZATION_NAME)[0].value == "nvidia" + assert cert.subject.get_attributes_for_oid(NameOID.UNSTRUCTURED_NAME)[0].value == "lead" + + job_dir = tmp_path / "job" + job_dir.mkdir() + (job_dir / "config.json").write_text(json.dumps({"name": "job"}), encoding="utf-8") + sign_folders(str(job_dir), asserter.pri_key, files.client_cert) + assert (job_dir / NVFLARE_SUBMITTER_CRT_FILE).is_file() + assert verify_folder_signature(str(job_dir), str(root_ca_path)) + finally: + files.cleanup() + + +def test_step_ca_source_accepts_intermediate_chain(monkeypatch, tmp_path): + root_key, _root_cert, root_ca_path = _make_root_ca(tmp_path) + intermediate_key, intermediate_pub_key = generate_keys() + intermediate_cert = generate_cert( + subject=Identity("step-ca-intermediate", "nvidia"), + issuer=Identity("root", "nvidia"), + signing_pri_key=root_key, + subject_pub_key=intermediate_pub_key, + ca=True, + ) + cert_src, key_src = _make_admin_cert_files( + tmp_path, + signing_key=intermediate_key, + issuer_identity=Identity("step-ca-intermediate", "nvidia"), + chain=(intermediate_cert,), + ) + command_log = tmp_path / "commands.jsonl" + fake_step = _fake_step(monkeypatch, tmp_path, cert_src=cert_src, key_src=key_src, command_log=command_log) + + files = obtain_ephemeral_admin_cert_files( + config={ + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "step_bin": str(fake_step), + }, + }, + root_ca_file=str(root_ca_path), + ) + + try: + command = _read_command_log(command_log)[0] + assert "24h" in command 
+ assert len(load_crt_chain(files.client_cert)) == 2 + job_dir = tmp_path / "job" + job_dir.mkdir() + (job_dir / "config.json").write_text(json.dumps({"name": "job"}), encoding="utf-8") + sign_folders(str(job_dir), load_private_key_file(files.client_key), files.client_cert) + assert verify_folder_signature(str(job_dir), str(root_ca_path)) + finally: + files.cleanup() + + +def test_custom_ephemeral_admin_cert_provider_path_is_supported(tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_path, key_path = _make_admin_cert_files( + tmp_path, + signing_key=ca_key, + issuer_identity=Identity("root", "nvidia"), + ) + + files = obtain_ephemeral_admin_cert_files( + config={ + "provider": "tests.unit_test.fuel.sec.ephemeral_admin_cert_test:_custom_ephemeral_provider", + "provider_config": { + "cert_path": str(cert_path), + "key_path": str(key_path), + }, + }, + root_ca_file=str(root_ca_path), + ) + + assert os.path.isfile(files.client_cert) + assert os.path.isfile(files.client_key) + assert files.client_cert != str(cert_path) + assert files.client_key != str(key_path) + assert files.expires_at > time.time() + + +def test_ephemeral_admin_cert_uses_validated_certificate_expiry(tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_path, key_path = _make_admin_cert_files( + tmp_path, + signing_key=ca_key, + issuer_identity=Identity("root", "nvidia"), + ) + + files = obtain_ephemeral_admin_cert_files( + config={ + "provider": "tests.unit_test.fuel.sec.ephemeral_admin_cert_test:_custom_ephemeral_provider", + "provider_config": { + "cert_path": str(cert_path), + "key_path": str(key_path), + "expires_at": time.time() + 86400, + }, + }, + root_ca_file=str(root_ca_path), + ) + + cert = load_cert_file(files.client_cert) + assert files.expires_at == cert.not_valid_after_utc.timestamp() + + +def test_ephemeral_admin_cert_cache_reuses_valid_cert(monkeypatch, tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_src, 
key_src = _make_admin_cert_files(tmp_path, signing_key=ca_key, issuer_identity=Identity("root", "nvidia")) + command_log = tmp_path / "commands.jsonl" + fake_step = _fake_step(monkeypatch, tmp_path, cert_src=cert_src, key_src=key_src, command_log=command_log) + config = { + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "step_bin": str(fake_step), + }, + } + + first = obtain_ephemeral_admin_cert_files(config=config, root_ca_file=str(root_ca_path)) + second = obtain_ephemeral_admin_cert_files(config=config, root_ca_file=str(root_ca_path)) + + assert len(_read_command_log(command_log)) == 1 + assert first.client_cert == second.client_cert + assert first.client_key == second.client_key + + +def test_ephemeral_admin_cert_cache_refreshes_inside_renewal_window(monkeypatch, tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_src, key_src = _make_admin_cert_files(tmp_path, signing_key=ca_key, issuer_identity=Identity("root", "nvidia")) + command_log = tmp_path / "commands.jsonl" + fake_step = _fake_step(monkeypatch, tmp_path, cert_src=cert_src, key_src=key_src, command_log=command_log) + config = { + "provider": "step_ca", + "renewal_window": 7200, + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "step_bin": str(fake_step), + }, + } + + first = obtain_ephemeral_admin_cert_files(config=config, root_ca_file=str(root_ca_path)) + second = obtain_ephemeral_admin_cert_files(config=config, root_ca_file=str(root_ca_path)) + + assert len(_read_command_log(command_log)) == 2 + assert first.client_cert != second.client_cert + assert os.path.isfile(first.client_cert) + assert os.path.isfile(first.client_key) + + +def test_ephemeral_admin_cert_cache_refreshes_invalid_cache(monkeypatch, tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_src, key_src = _make_admin_cert_files(tmp_path, 
signing_key=ca_key, issuer_identity=Identity("root", "nvidia")) + command_log = tmp_path / "commands.jsonl" + fake_step = _fake_step(monkeypatch, tmp_path, cert_src=cert_src, key_src=key_src, command_log=command_log) + config = { + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "step_bin": str(fake_step), + }, + } + + files = obtain_ephemeral_admin_cert_files(config=config, root_ca_file=str(root_ca_path)) + with open(files.client_key, "w", encoding="utf-8") as f: + f.write("not a private key") + + refreshed_files = obtain_ephemeral_admin_cert_files(config=config, root_ca_file=str(root_ca_path)) + + assert len(_read_command_log(command_log)) == 2 + assert os.path.isfile(refreshed_files.client_cert) + assert os.path.isfile(refreshed_files.client_key) + + +def test_ephemeral_admin_cert_cache_serializes_concurrent_acquisition(monkeypatch, tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_src, key_src = _make_admin_cert_files(tmp_path, signing_key=ca_key, issuer_identity=Identity("root", "nvidia")) + provider_started = threading.Event() + provider_release = threading.Event() + provider_calls = 0 + + def _provider(config, root_ca_file): + nonlocal provider_calls + provider_calls += 1 + provider_started.set() + assert provider_release.wait(timeout=5) + return EphemeralAdminCertFiles(client_key=str(key_src), client_cert=str(cert_src)) + + monkeypatch.setattr("nvflare.fuel.sec.ephemeral_admin_cert._load_provider", lambda _provider_name: _provider) + config = {"provider": "test.provider:obtain_certificate", "provider_config": {}} + + with ThreadPoolExecutor(max_workers=2) as executor: + first_future = executor.submit(obtain_ephemeral_admin_cert_files, config, str(root_ca_path)) + assert provider_started.wait(timeout=5) + second_future = executor.submit(obtain_ephemeral_admin_cert_files, config, str(root_ca_path)) + provider_release.set() + first = 
first_future.result(timeout=5) + second = second_future.result(timeout=5) + + assert provider_calls == 1 + assert first.client_cert == second.client_cert + assert first.client_key == second.client_key + + +def test_validate_ephemeral_admin_cert_files_rejects_missing_role(tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + admin_key, admin_pub_key = generate_keys() + cert = generate_cert( + subject=Identity("alice@nvidia.com", "nvidia"), + issuer=Identity("root", "nvidia"), + signing_pri_key=ca_key, + subject_pub_key=admin_pub_key, + ) + key_path = tmp_path / "client.key" + cert_path = tmp_path / "client.crt" + _write_key_cert(key_path, cert_path, admin_key, [cert]) + + with pytest.raises(EphemeralAdminCertError, match="unstructuredName"): + validate_ephemeral_admin_cert_files(str(cert_path), str(key_path), str(root_ca_path)) + + +def test_step_ca_source_reports_step_failure(monkeypatch, tmp_path): + _ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + fake_step = _fake_step(monkeypatch, tmp_path, exit_code=1) + + with pytest.raises(EphemeralAdminCertError, match="step ca certificate failed"): + obtain_ephemeral_admin_cert_files( + config={ + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "step_bin": str(fake_step), + }, + }, + root_ca_file=str(root_ca_path), + ) + + +def test_step_ca_source_requires_explicit_provisioner(monkeypatch, tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_src, key_src = _make_admin_cert_files(tmp_path, signing_key=ca_key, issuer_identity=Identity("root", "nvidia")) + fake_step = _fake_step(monkeypatch, tmp_path, cert_src=cert_src, key_src=key_src) + + with pytest.raises(EphemeralAdminCertError, match="provisioner"): + obtain_step_ca_admin_cert_files( + config={"ca_url": "https://step-ca.example.com", "step_bin": str(fake_step)}, + root_ca_file=str(root_ca_path), + ) + + +def 
test_step_ca_source_requires_ca_url(): + with pytest.raises(EphemeralAdminCertError, match="ca_url is required"): + validate_step_ca_admin_cert_config({"provisioner": "nvflare-admin-oidc"}) + + +def test_step_ca_source_accepts_ipv6_loopback_http(): + config = {"ca_url": "http://[::1]:9000", "provisioner": "nvflare-admin-oidc"} + + assert validate_step_ca_admin_cert_config(config) == config + + +def test_step_ca_source_times_out_step_command(monkeypatch, tmp_path): + _ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + fake_step = _fake_step(monkeypatch, tmp_path, sleep=2) + + with pytest.raises(EphemeralAdminCertError, match="timed out"): + obtain_step_ca_admin_cert_files( + config={ + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "step_bin": str(fake_step), + "command_timeout": 0.01, + }, + root_ca_file=str(root_ca_path), + ) + + +def test_step_ca_source_rejects_placeholder_common_name(monkeypatch, tmp_path): + ca_key, _ca_cert, root_ca_path = _make_root_ca(tmp_path) + cert_src, key_src = _make_admin_cert_files( + tmp_path, + signing_key=ca_key, + issuer_identity=Identity("root", "nvidia"), + issued_subject=DEFAULT_STEP_CA_REQUEST_NAME, + ) + fake_step = _fake_step(monkeypatch, tmp_path, cert_src=cert_src, key_src=key_src) + + with pytest.raises(EphemeralAdminCertError, match="commonName"): + obtain_ephemeral_admin_cert_files( + config={ + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "step_bin": str(fake_step), + }, + }, + root_ca_file=str(root_ca_path), + ) + + +def test_ephemeral_cert_files_need_renewal_inside_window(): + cert_files = EphemeralAdminCertFiles( + client_key="client.key", + client_cert="client.crt", + expires_at=time.time() + 30.0, + ) + + assert cert_files.needs_renewal(renewal_window=60.0) + assert not cert_files.needs_renewal(renewal_window=10.0) diff --git a/tests/unit_test/lighter/cert_builder_ephemeral_admin_test.py 
b/tests/unit_test/lighter/cert_builder_ephemeral_admin_test.py new file mode 100644 index 0000000000..bae63a194d --- /dev/null +++ b/tests/unit_test/lighter/cert_builder_ephemeral_admin_test.py @@ -0,0 +1,49 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from nvflare.lighter.constants import ParticipantType +from nvflare.lighter.ctx import ProvisionContext +from nvflare.lighter.entity import Participant, Project +from nvflare.lighter.impl.cert import CertBuilder + + +def test_cert_builder_omits_static_admin_cert_for_ephemeral_cert(tmp_path): + server = Participant(type=ParticipantType.SERVER, name="server", org="nvidia") + admin = Participant( + type=ParticipantType.ADMIN, + name="sso-admin-kit", + org=None, + props={ + "ephemeral_admin_cert": { + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + }, + }, + ) + project = Project(name="project", description="desc", participants=[server, admin]) + ctx = ProvisionContext(workspace_root_dir=str(tmp_path), project=project) + for participant in (server, admin): + (tmp_path / "wip" / participant.name / "startup").mkdir(parents=True) + builder = CertBuilder() + + builder.initialize(project, ctx) + builder.build(project, ctx) + + admin_startup = tmp_path / "wip" / "sso-admin-kit" / "startup" + assert (admin_startup / "rootCA.pem").is_file() + assert not (admin_startup / 
"client.key").exists() + assert not (admin_startup / "client.crt").exists() diff --git a/tests/unit_test/lighter/ephemeral_admin_test.py b/tests/unit_test/lighter/ephemeral_admin_test.py new file mode 100644 index 0000000000..6f27feedd3 --- /dev/null +++ b/tests/unit_test/lighter/ephemeral_admin_test.py @@ -0,0 +1,148 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import pytest + +from nvflare.lighter.constants import ParticipantType +from nvflare.lighter.entity import Participant, Project +from nvflare.lighter.ephemeral_admin import get_admin_ephemeral_cert_config + + +def _project(props=None, admin_props=None): + server = Participant(type=ParticipantType.SERVER, name="server", org="nvidia") + admin_props = admin_props or {} + if "ephemeral_admin_cert" in admin_props: + participant_props = admin_props + org = None + else: + participant_props = {"role": "project_admin", **admin_props} + org = "nvidia" + admin = Participant( + type=ParticipantType.ADMIN, + name="admin@example.com", + org=org, + props=participant_props, + ) + project = Project(name="project", description="desc", participants=[server, admin], props=props or {}) + return project, admin + + +def _ephemeral_cert_config(ca_url="https://step-ca.example.com", provisioner="nvflare-admin-oidc"): + provider_config = {"ca_url": ca_url} + if provisioner: + provider_config["provisioner"] = provisioner + return { + "provider": "step_ca", + 
"renewal_window": 60, + "provider_config": provider_config, + } + + +def test_admin_without_ephemeral_cert_config_has_no_ephemeral_cert_config(): + project, admin = _project() + + assert get_admin_ephemeral_cert_config(admin) is None + + +def test_per_admin_ephemeral_cert_config_supplies_admin_config(): + _project_obj, admin = _project( + admin_props={ + "ephemeral_admin_cert": { + "provider": "step_ca", + "renewal_window": 60, + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + "cert_ttl": "1h", + }, + } + }, + ) + + assert get_admin_ephemeral_cert_config(admin)["provider_config"]["ca_url"] == "https://step-ca.example.com" + assert get_admin_ephemeral_cert_config(admin)["provider_config"]["cert_ttl"] == "1h" + assert "subject" not in get_admin_ephemeral_cert_config(admin) + + +def test_per_admin_ephemeral_cert_config_allows_admin_kit_name(): + admin = Participant( + type=ParticipantType.ADMIN, + name="sso-admin-kit", + org=None, + props={"ephemeral_admin_cert": _ephemeral_cert_config()}, + ) + + assert admin.name == "sso-admin-kit" + assert get_admin_ephemeral_cert_config(admin)["provider"] == "step_ca" + + +def test_per_admin_ephemeral_cert_config_is_used_directly(): + _project_obj, admin = _project( + admin_props={ + "ephemeral_admin_cert": _ephemeral_cert_config(ca_url="https://admin-step-ca.example.com"), + }, + ) + + assert get_admin_ephemeral_cert_config(admin)["provider_config"]["ca_url"] == "https://admin-step-ca.example.com" + + +def test_ephemeral_cert_config_requires_provider(): + project, admin = _project( + admin_props={"ephemeral_admin_cert": {"provider_config": {"ca_url": "https://step-ca.example.com"}}} + ) + + with pytest.raises(ValueError, match="provider"): + get_admin_ephemeral_cert_config(admin) + + +def test_provider_specific_config_is_not_validated_by_generic_provisioning_helper(): + project, admin = _project( + admin_props={ + "ephemeral_admin_cert": { + "provider": "step_ca", + 
"provider_config": { + "ca_url": "https://step-ca.example.com", + }, + } + } + ) + + assert get_admin_ephemeral_cert_config(admin)["provider_config"]["ca_url"] == "https://step-ca.example.com" + + +def test_ephemeral_cert_config_rejects_unknown_provider_name_shape(): + _project_obj, admin = _project( + admin_props={ + "ephemeral_admin_cert": { + "provider": "step-ca", + "provider_config": {}, + } + } + ) + + with pytest.raises(ValueError, match="built-in provider name or module:function path"): + get_admin_ephemeral_cert_config(admin) + + +def test_ephemeral_cert_config_accepts_custom_provider_path_without_importing_it(): + _project_obj, admin = _project( + admin_props={ + "ephemeral_admin_cert": { + "provider": "customer.cert_provider:obtain_certificate", + "provider_config": {}, + } + } + ) + + assert get_admin_ephemeral_cert_config(admin)["provider"] == "customer.cert_provider:obtain_certificate" diff --git a/tests/unit_test/lighter/participant_test.py b/tests/unit_test/lighter/participant_test.py index be1b69e0b9..8aa9ab48ad 100644 --- a/tests/unit_test/lighter/participant_test.py +++ b/tests/unit_test/lighter/participant_test.py @@ -34,6 +34,40 @@ def test_invalid_name(self, type, invalid_name): with pytest.raises(ValueError): _ = Participant(name=invalid_name, org="org", type=type) + def test_ephemeral_admin_allows_kit_name(self): + participant = Participant( + name="sso-admin-kit", + org=None, + type="admin", + props={ + "ephemeral_admin_cert": { + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + }, + }, + ) + + assert participant.name == "sso-admin-kit" + assert not participant.org + + @pytest.mark.parametrize( + "props,org,match", + [ + ({"role": "project_admin", "ephemeral_admin_cert": {"provider": "step_ca"}}, None, "must not define role"), + ({"ephemeral_admin_cert": {"provider": "step_ca"}}, "org", "must not define org"), + ], + ) + def 
test_ephemeral_admin_rejects_project_time_identity(self, props, org, match): + with pytest.raises(ValueError, match=match): + _ = Participant(name="sso-admin-kit", org=org, type="admin", props=props) + + def test_static_admin_rejects_kit_name(self): + with pytest.raises(ValueError): + _ = Participant(name="sso-admin-kit", org="org", type="admin", props={"role": "project_admin"}) + @pytest.mark.parametrize( "invalid_org", [("org-"), ("org@"), ("org!"), ("org~")], diff --git a/tests/unit_test/lighter/provision_test.py b/tests/unit_test/lighter/provision_test.py index 7f26690269..a00876c3f3 100644 --- a/tests/unit_test/lighter/provision_test.py +++ b/tests/unit_test/lighter/provision_test.py @@ -61,6 +61,35 @@ def test_prepare_project_accepts_api_version_4(self): assert [p.name for p in project.get_clients()] == ["client1"] assert [p.name for p in project.get_admins()] == ["admin1@org.com"] + def test_prepare_project_accepts_ephemeral_admin_without_org_or_role(self): + project_config = { + "api_version": 3, + "name": "mytest", + "description": "test", + "participants": [ + {"type": "server", "name": "server1", "org": "org"}, + { + "type": "admin", + "name": "sso-admin-kit", + "ephemeral_admin_cert": { + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + }, + }, + ], + } + + project = prepare_project(project_dict=project_config) + admin = project.get_admins()[0] + + assert admin.name == "sso-admin-kit" + assert not admin.org + assert admin.get_prop("role") is None + assert admin.get_prop("ephemeral_admin_cert")["provider"] == "step_ca" + def test_prepare_project_requires_api_version_4_for_studies(self): project_config = self._base_project( api_version=3, studies={"study-a": {"site_orgs": {"org": ["client1"]}, "admins": []}} diff --git a/tests/unit_test/lighter/static_file_builder_test.py b/tests/unit_test/lighter/static_file_builder_test.py index 99dd10ecef..63dce6172b 100644 --- 
a/tests/unit_test/lighter/static_file_builder_test.py +++ b/tests/unit_test/lighter/static_file_builder_test.py @@ -21,7 +21,7 @@ from nvflare.lighter.constants import CtxKey from nvflare.lighter.ctx import ProvisionContext from nvflare.lighter.entity import Participant, Project -from nvflare.lighter.impl.static_file import StaticFileBuilder +from nvflare.lighter.impl.static_file import StaticFileBuilder, _modify_fed_admin_config class _FakeCtx: @@ -100,6 +100,43 @@ def test_scheme(self, scheme): builder = StaticFileBuilder(scheme=scheme) assert builder.scheme == scheme + def test_ephemeral_admin_cert_config_omits_static_client_cert_material(self): + config = _modify_fed_admin_config( + json.dumps( + { + "admin": { + "username": "admin@example.com", + "client_key": "client.key", + "client_cert": "client.crt", + "ca_cert": "rootCA.pem", + } + } + ), + ephemeral_admin_cert={ + "provider": "step_ca", + "renewal_window": 60, + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + }, + ) + + admin = json.loads(config)["admin"] + assert "client_key" not in admin + assert "client_cert" not in admin + assert admin["username"] == "" + assert admin["uid_source"] == "cert" + assert admin["ca_cert"] == "rootCA.pem" + assert admin["ephemeral_admin_cert"] == { + "provider": "step_ca", + "renewal_window": 60, + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + } + def test_build_server_emits_study_registry_when_studies_exist(self, tmp_path): server = Participant(type="server", name="server1", org="org") project = Project( diff --git a/tests/unit_test/private/fed/server/job_cmds_test.py b/tests/unit_test/private/fed/server/job_cmds_test.py index 631337ad5a..75ea4efbe7 100644 --- a/tests/unit_test/private/fed/server/job_cmds_test.py +++ b/tests/unit_test/private/fed/server/job_cmds_test.py @@ -12,6 +12,7 @@ # See the License for the specific language governing 
permissions and # limitations under the License. +import datetime import gc import io from argparse import Namespace @@ -154,6 +155,7 @@ class _FakeJobDefManager: def __init__(self): self.created_meta = None self.cloned_meta = None + self.content = None def create(self, meta, uploaded_content, fl_ctx): self.created_meta = dict(meta) @@ -167,6 +169,9 @@ def clone(self, from_jid, meta, fl_ctx): result[JobMetaKey.JOB_ID.value] = "cloned-job-id" return result + def get_content(self, meta, fl_ctx): + return self.content + class _FakeSubmitTokenJobDefManager: def __init__(self): @@ -544,6 +549,63 @@ def test_submit_job_strips_user_supplied_from_hub_site(monkeypatch): assert JobMetaKey.FROM_HUB_SITE.value not in engine.job_def_manager.created_meta +@pytest.mark.parametrize("ephemeral_admin_cert", [False, True]) +def test_submit_job_records_only_ephemeral_submitter_cert_validity(monkeypatch, tmp_path, ephemeral_admin_cert): + now = datetime.datetime.now(datetime.timezone.utc) + + class _Cert: + not_valid_before_utc = now - datetime.timedelta(minutes=1) + not_valid_after_utc = now + datetime.timedelta(hours=1) + + class _Validator: + def validate(self, folder_name, zip_file_name): + return True, "", {} + + zip_path = tmp_path / "job.zip" + with ZipFile(zip_path, "w") as zip_file: + zip_file.writestr(f"source/app/{NVFLARE_SUBMITTER_CRT_FILE}", b"cert") + + monkeypatch.setattr(job_cmds_module, "JobMetaValidator", _Validator) + monkeypatch.setattr(job_cmds_module, "JobDefManagerSpec", object) + monkeypatch.setattr(job_cmds_module, "load_crt_chain_bytes", lambda _data: [_Cert()]) + + engine = _FakeEngine() + conn = _submit_conn(engine, str(zip_path)) + + args = ["submit_job", "job_folder"] + if ephemeral_admin_cert: + args.append("--ephemeral-admin-cert") + JobCommandModule().submit_job(conn, args) + + assert conn.errors == [] + if ephemeral_admin_cert: + assert engine.job_def_manager.created_meta[JobMetaKey.SUBMITTER_CERT_VALIDITY.value] == { + "not_before": 
_Cert.not_valid_before_utc.timestamp(), + "not_after": _Cert.not_valid_after_utc.timestamp(), + } + else: + assert JobMetaKey.SUBMITTER_CERT_VALIDITY.value not in engine.job_def_manager.created_meta + + +def test_submit_job_strips_forged_submitter_cert_validity(monkeypatch): + monkeypatch.setattr(job_cmds_module, "JobDefManagerSpec", object) + monkeypatch.setattr( + job_cmds_module, + "JobMetaValidator", + lambda: _FakeJobMetaValidatorWithMeta( + {JobMetaKey.SUBMITTER_CERT_VALIDITY.value: {"not_before": 0, "not_after": 9999999999}} + ), + ) + + engine = _FakeEngine() + conn = _submit_conn(engine, "job.zip") + + JobCommandModule().submit_job(conn, ["submit_job", "job_folder"]) + + assert conn.errors == [] + assert JobMetaKey.SUBMITTER_CERT_VALIDITY.value not in engine.job_def_manager.created_meta + + def test_submit_job_defaults_study_when_cmd_props_missing(monkeypatch): monkeypatch.setattr(job_cmds_module, "JobMetaValidator", _FakeJobMetaValidator) monkeypatch.setattr(job_cmds_module, "JobDefManagerSpec", object) @@ -1153,6 +1215,46 @@ def test_clone_job_preserves_byoc_flag(monkeypatch): assert engine.job_def_manager.cloned_meta[AppValidationKey.BYOC] is True +def test_clone_job_rejects_expired_stored_submitter_cert(monkeypatch): + monkeypatch.setattr(job_cmds_module, "ServerEngine", object) + monkeypatch.setattr(job_cmds_module, "JobDefManagerSpec", object) + + now = datetime.datetime.now(datetime.timezone.utc) + + source_job = _FakeListedJob( + { + JobMetaKey.JOB_ID.value: "source-job", + JobMetaKey.JOB_NAME.value: "source", + JobMetaKey.SUBMITTER_CERT_VALIDITY.value: { + "not_before": (now - datetime.timedelta(days=2)).timestamp(), + "not_after": (now - datetime.timedelta(days=1)).timestamp(), + }, + } + ) + engine = _FakeEngine() + engine.job_def_manager.get_content = MagicMock(side_effect=AssertionError("clone must not load job content")) + conn = _MockConnection( + app_ctx=engine, + props={ + JobCommandModule.JOB: source_job, + JobCommandModule.JOB_ID: 
"source-job", + ConnProps.USER_NAME: "submitter", + ConnProps.USER_ORG: "org", + ConnProps.USER_ROLE: "role", + }, + ) + + JobCommandModule().clone_job(conn, ["clone_job", "source-job"]) + + assert conn.successes == [] + assert engine.job_def_manager.cloned_meta is None + assert conn.errors + msg, meta = conn.errors[0] + assert "stored submitter certificate" in msg + assert meta[MetaKey.STATUS] == MetaStatusValue.INVALID_JOB_DEFINITION + engine.job_def_manager.get_content.assert_not_called() + + def test_list_jobs_filters_legacy_jobs_into_default_study(monkeypatch): monkeypatch.setattr(job_cmds_module, "JobDefManagerSpec", object) jobs = [ diff --git a/tests/unit_test/tool/kit/kit_config_test.py b/tests/unit_test/tool/kit/kit_config_test.py index 4501f8c476..750ac68e4b 100644 --- a/tests/unit_test/tool/kit/kit_config_test.py +++ b/tests/unit_test/tool/kit/kit_config_test.py @@ -45,6 +45,19 @@ def _make_poc_admin_startup_kit(parent: Path, name: str = "admin@nvidia.com") -> return kit_dir +def _make_ephemeral_cert_admin_startup_kit(parent: Path, name: str = "sso-admin-kit") -> Path: + kit_dir = parent / name + startup_dir = kit_dir / "startup" + startup_dir.mkdir(parents=True) + (startup_dir / "fed_admin.json").write_text( + '{"admin": {"username": "", "uid_source": "cert", "ephemeral_admin_cert": ' + '{"provider": "step_ca", "provider_config": ' + '{"ca_url": "https://step-ca.example.com", "provisioner": "nvflare-admin-oidc"}}}}\n' + ) + (startup_dir / "rootCA.pem").write_text("dummy root ca\n") + return kit_dir + + def _make_site_startup_kit(parent: Path, name: str = "site-1") -> Path: kit_dir = parent / name startup_dir = kit_dir / "startup" @@ -197,6 +210,21 @@ def test_set_active_validates_registered_path(self, tmp_path): assert updated.get("startup_kits.active") == "project_admin" + def test_ephemeral_cert_admin_startup_kit_does_not_require_static_client_cert(self, tmp_path): + from nvflare.tool.kit import kit_config + + kit_dir = 
_make_ephemeral_cert_admin_startup_kit(tmp_path, "sso-admin-kit") + config = CF.parse_string("version = 2") + + updated = kit_config.add_startup_kit_entry(config, "ephemeral_cert_admin", str(kit_dir)) + + assert _entry_path(updated, "ephemeral_cert_admin") == kit_dir.resolve() + assert kit_config.validate_admin_startup_kit(str(kit_dir / "startup")) == str(kit_dir.resolve()) + metadata = kit_config.inspect_startup_kit_metadata(str(kit_dir)) + assert metadata["certificate"]["status"] == "runtime_issued" + assert metadata["credential_source"] == "ephemeral_admin_cert" + assert metadata["findings"] == [] + def test_set_active_rejects_unknown_id_without_mutating_active(self, tmp_path): from nvflare.tool.kit import kit_config diff --git a/tests/unit_test/tool/package_checker/ephemeral_admin_test.py b/tests/unit_test/tool/package_checker/ephemeral_admin_test.py new file mode 100644 index 0000000000..8b16e91194 --- /dev/null +++ b/tests/unit_test/tool/package_checker/ephemeral_admin_test.py @@ -0,0 +1,141 @@ +# Copyright (c) 2026, NVIDIA CORPORATION. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import json +from unittest.mock import patch + +from nvflare.tool.package_checker.check_rule import CHECK_PASSED, CheckServerAvailable +from nvflare.tool.package_checker.nvflare_console_package_checker import NVFlareConsolePackageChecker +from nvflare.tool.package_checker.package_checker import CheckStatus, PackageChecker +from nvflare.tool.package_checker.utils import NVFlareRole + + +def _write_ephemeral_admin_config(package_dir): + startup = package_dir / "startup" + startup.mkdir() + (startup / "fed_admin.json").write_text( + json.dumps( + { + "admin": { + "host": "localhost", + "port": 8003, + "scheme": "grpc", + "ephemeral_admin_cert": { + "provider": "step_ca", + "provider_config": { + "ca_url": "https://step-ca.example.com", + "provisioner": "nvflare-admin-oidc", + }, + }, + } + } + ), + encoding="utf-8", + ) + (startup / "rootCA.pem").write_text("root", encoding="utf-8") + + +def test_server_check_uses_non_interactive_reachability_for_ephemeral_admin(tmp_path): + _write_ephemeral_admin_config(tmp_path) + rule = CheckServerAvailable(name="Check server available", role=NVFlareRole.ADMIN) + + with patch( + "nvflare.tool.package_checker.check_rule.check_socket_server_running", return_value=True + ) as socket_check: + with patch("nvflare.tool.package_checker.check_rule.check_grpc_server_running") as grpc_check: + result = rule(str(tmp_path), data=None) + + assert result.problem == CHECK_PASSED + socket_check.assert_called_once_with(startup=str(tmp_path / "startup"), host="localhost", port=8003, scheme="tcp") + grpc_check.assert_not_called() + + +def test_console_check_skips_interactive_dry_run_for_ephemeral_admin(tmp_path): + _write_ephemeral_admin_config(tmp_path) + checker = NVFlareConsolePackageChecker() + checker.init(str(tmp_path)) + + with ( + patch.object(PackageChecker, "check_dry_run") as inherited_check, + patch( + "nvflare.tool.package_checker.nvflare_console_package_checker.shutil.which", return_value="/usr/bin/step" + ), + ): + status = 
checker.check_dry_run() + + assert status == CheckStatus.PASS + assert checker.report[str(tmp_path.resolve())][-1][0:2] == ("Check dry run", "SKIPPED") + inherited_check.assert_not_called() + + +def test_console_check_rejects_invalid_ephemeral_config(tmp_path): + _write_ephemeral_admin_config(tmp_path) + config_path = tmp_path / "startup" / "fed_admin.json" + config = json.loads(config_path.read_text(encoding="utf-8")) + config["admin"]["ephemeral_admin_cert"] = True + config_path.write_text(json.dumps(config), encoding="utf-8") + checker = NVFlareConsolePackageChecker() + checker.init(str(tmp_path)) + + assert checker.check_dry_run() == CheckStatus.FAIL + assert "must be a mapping" in checker.report[str(tmp_path.resolve())][-1][1] + + +def test_console_check_rejects_missing_root_ca(tmp_path): + _write_ephemeral_admin_config(tmp_path) + (tmp_path / "startup" / "rootCA.pem").unlink() + checker = NVFlareConsolePackageChecker() + checker.init(str(tmp_path)) + + assert checker.check_dry_run() == CheckStatus.FAIL + assert "missing project root certificate" in checker.report[str(tmp_path.resolve())][-1][1] + + +def test_console_check_rejects_missing_step_cli(tmp_path): + _write_ephemeral_admin_config(tmp_path) + checker = NVFlareConsolePackageChecker() + checker.init(str(tmp_path)) + + with patch("nvflare.tool.package_checker.nvflare_console_package_checker.shutil.which", return_value=None): + assert checker.check_dry_run() == CheckStatus.FAIL + assert "step CLI is not available" in checker.report[str(tmp_path.resolve())][-1][1] + + +def test_server_check_rejects_unsupported_scheme_before_ephemeral_probe(tmp_path): + _write_ephemeral_admin_config(tmp_path) + config_path = tmp_path / "startup" / "fed_admin.json" + config = json.loads(config_path.read_text(encoding="utf-8")) + config["admin"]["scheme"] = "ws" + config_path.write_text(json.dumps(config), encoding="utf-8") + rule = CheckServerAvailable(name="Check server available", role=NVFlareRole.ADMIN) + + with 
patch("nvflare.tool.package_checker.check_rule.check_socket_server_running") as socket_check: + result = rule(str(tmp_path), data=None) + + assert result.problem == "Unsupported communication scheme: ws" + socket_check.assert_not_called() + + +def test_console_check_falls_back_to_normal_dry_run_for_invalid_config(tmp_path): + startup = tmp_path / "startup" + startup.mkdir() + (startup / "fed_admin.json").write_text("{", encoding="utf-8") + checker = NVFlareConsolePackageChecker() + checker.init(str(tmp_path)) + + with patch.object(PackageChecker, "check_dry_run", return_value=CheckStatus.FAIL) as inherited_check: + status = checker.check_dry_run() + + assert status == CheckStatus.FAIL + inherited_check.assert_called_once_with()
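
Reviewer note on the provider-name tests in `tests/unit_test/lighter/ephemeral_admin_test.py`: those tests accept `step_ca` and `customer.cert_provider:obtain_certificate`, and reject `step-ca` with a message containing "built-in provider name or module:function path". The sketch below shows one way such a shape check could behave. It is not the code in `nvflare.lighter.ephemeral_admin`: the function name `check_provider_shape` and both regular expressions are hypothetical, and the real helper may instead consult a registry of built-in names.

```python
import re

# Assumption: a built-in provider name looks like a Python identifier, e.g. "step_ca".
_BUILTIN_NAME = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")

# Assumption: a custom provider is written as "dotted.module:function".
_MODULE_FUNCTION = re.compile(
    r"[A-Za-z_][A-Za-z0-9_]*(\.[A-Za-z_][A-Za-z0-9_]*)*:[A-Za-z_][A-Za-z0-9_]*"
)


def check_provider_shape(provider):
    """Hypothetical helper: return the provider string if its shape is acceptable.

    Only the text is inspected. A custom "module:function" path is never
    imported here, which mirrors the intent of the test named
    ..._accepts_custom_provider_path_without_importing_it.
    """
    if not isinstance(provider, str) or not provider:
        raise ValueError("provider must be a non-empty string")
    if _BUILTIN_NAME.fullmatch(provider) or _MODULE_FUNCTION.fullmatch(provider):
        return provider
    raise ValueError(
        f"provider '{provider}' must be a built-in provider name or module:function path"
    )
```

Under these assumptions, `step-ca` fails because the hyphen matches neither pattern, while the custom path passes on shape alone, so provisioning does not need the customer module to be installed.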