diff --git a/gigl/common/types/resource_config.py b/gigl/common/types/resource_config.py
index a24a97a47..a2a0ffe85 100644
--- a/gigl/common/types/resource_config.py
+++ b/gigl/common/types/resource_config.py
@@ -15,3 +15,7 @@ class CommonPipelineComponentConfigs:
     additional_job_args: dict[GiGLComponents, dict[str, str]] = field(
         default_factory=dict
     )
+    # Environment variables baked into every GiGL-owned container at compile time.
+    # Applied uniformly across all SPECED_COMPONENTS plus the GLT eligibility check
+    # and log_metrics_to_ui tasks. The managed VertexNotificationEmailOp is excluded.
+    env_vars: dict[str, str] = field(default_factory=dict)
diff --git a/gigl/orchestration/kubeflow/kfp_orchestrator.py b/gigl/orchestration/kubeflow/kfp_orchestrator.py
index 374347ac1..e3323ce24 100644
--- a/gigl/orchestration/kubeflow/kfp_orchestrator.py
+++ b/gigl/orchestration/kubeflow/kfp_orchestrator.py
@@ -56,6 +56,7 @@ def compile(
         dataflow_container_image: str,
         dst_compiled_pipeline_path: Uri = DEFAULT_KFP_COMPILED_PIPELINE_DEST_PATH,
         additional_job_args: Optional[dict[GiGLComponents, dict[str, str]]] = None,
+        env_vars: Optional[dict[str, str]] = None,
         tag: Optional[str] = None,
     ) -> Uri:
         """
@@ -68,6 +69,9 @@ def compile(
             dst_compiled_pipeline_path (Uri): Destination path for the compiled pipeline YAML file.
                 Defaults to :data:`~gigl.constants.DEFAULT_KFP_COMPILED_PIPELINE_DEST_PATH`.
             additional_job_args (Optional[dict[GiGLComponents, dict[str, str]]]): Additional arguments to be passed into components, organized by component.
+            env_vars (Optional[dict[str, str]]): Environment variables baked into every GiGL-owned container at compile time.
+                Applied uniformly across all SPECED_COMPONENTS plus the GLT eligibility check and ``log_metrics_to_ui`` tasks.
+                The managed ``VertexNotificationEmailOp`` exit handler is intentionally excluded.
             tag (Optional[str]): Optional tag to include in the pipeline description.
         Returns:
             Uri: The URI of the compiled pipeline.
@@ -85,6 +89,7 @@ def compile(
             cpu_container_image=cpu_container_image,
             dataflow_container_image=dataflow_container_image,
             additional_job_args=additional_job_args or {},
+            env_vars=env_vars or {},
         )
 
         Compiler().compile(
diff --git a/gigl/orchestration/kubeflow/kfp_pipeline.py b/gigl/orchestration/kubeflow/kfp_pipeline.py
index 524107ce2..7ee844b60 100644
--- a/gigl/orchestration/kubeflow/kfp_pipeline.py
+++ b/gigl/orchestration/kubeflow/kfp_pipeline.py
@@ -51,6 +51,16 @@
 }
 
 
+def _apply_env_vars(task: PipelineTask, env_vars: dict[str, str]) -> None:
+    """Apply each entry in ``env_vars`` to ``task`` via ``set_env_variable``.
+
+    The KFP v2 SDK's ``PipelineTask.set_env_variable`` takes a single name/value
+    pair per call, so we loop instead of passing a dict.
+    """
+    for name, value in env_vars.items():
+        task.set_env_variable(name=name, value=value)
+
+
 def _generate_component_task(
     component: GiGLComponents,
     job_name: str,
@@ -124,6 +134,10 @@ def _generate_component_task(
         task=component_task,
         common_pipeline_component_configs=common_pipeline_component_configs,
     )
+    _apply_env_vars(
+        task=component_task,
+        env_vars=common_pipeline_component_configs.env_vars,
+    )
 
     return component_task
 
@@ -145,10 +159,15 @@ def _generate_component_tasks(
         resource_config_uri=resource_config_uri,
         common_pipeline_component_configs=common_pipeline_component_configs,
    )
-    should_use_glt = check_glt_backend_eligibility_component(
+    glt_eligibility_task = check_glt_backend_eligibility_component(
         task_config_uri=template_or_frozen_config_uri,
         base_image=common_pipeline_component_configs.cpu_container_image,
     )
+    _apply_env_vars(
+        task=glt_eligibility_task,
+        env_vars=common_pipeline_component_configs.env_vars,
+    )
+    should_use_glt = glt_eligibility_task.output
 
     with kfp.dsl.Condition(start_at == GiGLComponents.ConfigPopulator.value):
         config_populator_task = _create_config_populator_task_op(
@@ -249,6 +268,9 @@ def pipeline(
     stop_after: Optional[str] = None,
     notification_emails: Optional[List[str]] = None,
 ):
+    # VertexNotificationEmailOp is a Google-managed component, so we
+    # intentionally do not apply common_pipeline_component_configs.env_vars
+    # to it; see docs/plans/20260429-add-env-var-injection-to-kfp-runner.md.
     with kfp.dsl.ExitHandler(
         VertexNotificationEmailOp(recipients=notification_emails),
         name="Gigl Alerts",
@@ -451,6 +473,10 @@ def _create_trainer_task_op(
     )
     log_metrics_component.set_display_name(name="Log Trainer Eval Metrics")
     log_metrics_component.after(trainer_task)
+    _apply_env_vars(
+        task=log_metrics_component,
+        env_vars=common_pipeline_component_configs.env_vars,
+    )
 
     with kfp.dsl.Condition(stop_after != GiGLComponents.Trainer.value):
         inference_task = _create_inferencer_task_op(
@@ -484,4 +510,8 @@ def _create_post_processor_task_op(
     )
     log_metrics_component.set_display_name(name="Log PostProcessor Eval Metrics")
     log_metrics_component.after(post_processor_task)
+    _apply_env_vars(
+        task=log_metrics_component,
+        env_vars=common_pipeline_component_configs.env_vars,
+    )
     return post_processor_task
diff --git a/gigl/orchestration/kubeflow/runner.py b/gigl/orchestration/kubeflow/runner.py
index 790b000dc..8f9de90ef 100644
--- a/gigl/orchestration/kubeflow/runner.py
+++ b/gigl/orchestration/kubeflow/runner.py
@@ -34,6 +34,11 @@
     --notification_emails: Emails to send notification to. See
         https://cloud.google.com/vertex-ai/docs/pipelines/email-notifications for more details.
         Example: --notification_emails=user@example.com --notification_emails=user2@example.com
+    --env_vars: Environment variables baked into every GiGL-owned container at compile time
+        (every invocation of --action=run recompiles, so each run gets a fresh bake).
+        The value has to be of the form "<NAME>=<VALUE>". GiGL does not interpret the contents.
+        This argument can be repeated.
+        Example: --env_vars=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python --env_vars=FOO=bar
 
 You can alternatively run_no_compile if you have a precompiled pipeline somewhere.
 python gigl.orchestration.kubeflow.runner --action=run_no_compile ...args
@@ -48,6 +53,8 @@
     --pipeline_tag
     --notification_emails
     --wait
+    NOTE: --env_vars is rejected for --action=run_no_compile because env vars are baked at
+    compile time. Recompile via --action=run or --action=compile to change them.
 
 COMPILING A PIPELINE:
 A strict subset of running a pipeline,
@@ -68,6 +75,10 @@
         --additional_job_args=split_generator.some_other_arg='value'
     This passes additional_spark35_jar_file_uris="gs://path/to/jar" to subgraph_sampler at compile time
     and some_other_arg="value" to split_generator at compile time.
+    --env_vars: Environment variables baked into every GiGL-owned container at compile time.
+        The value has to be of the form "<NAME>=<VALUE>". GiGL does not interpret the contents.
+        This argument can be repeated.
+        Example: --env_vars=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python --env_vars=FOO=bar
 """
 
 from __future__ import annotations
@@ -170,6 +181,13 @@ def _assert_required_flags(args: argparse.Namespace) -> None:
             "Please use the run action to run a pipeline with labels."
             f"Labels provided: {args.run_labels}"
         )
+    if args.action == Action.RUN_NO_COMPILE and args.env_vars:
+        raise ValueError(
+            "--env_vars is not supported for the run_no_compile action because "
+            "environment variables are baked into the pipeline at compile time. "
+            "Recompile via --action=run or --action=compile to apply env vars. "
+            f"env_vars provided: {args.env_vars}"
+        )
 
 
 logger = Logger()
@@ -225,6 +243,22 @@ def _parse_labels(labels: list[str]) -> dict[str, str]:
     return result
 
 
+def _parse_env_vars(env_vars: list[str]) -> dict[str, str]:
+    """
+    Parse environment variables to bake into every GiGL-owned container at compile time.
+    Args:
+        env_vars (list[str]): Each element is of the form "<NAME>=<VALUE>".
+            Example: ["FOO=bar", "BAZ=qux"].
+    Returns dict[str, str]: The parsed environment variables.
+    """
+    result: dict[str, str] = {}
+    for entry in env_vars:
+        name, value = entry.split("=", 1)
+        result[name] = value
+    logger.info(f"Parsed env_vars: {result}")
+    return result
+
+
 def _get_parser() -> argparse.ArgumentParser:
     """
     Get the parser for the runner.py script.
@@ -326,6 +360,19 @@ def _get_parser() -> argparse.ArgumentParser:
         Which will target the pipeline run with gigl-integration-test=true and user=me.
         """,
     )
+    parser.add_argument(
+        "--env_vars",
+        action="append",
+        default=[],
+        help="""Environment variables baked into every GiGL-owned container at compile time, of the form:
+        --env_vars=KEY=VALUE. GiGL does not interpret the contents; the values flow opaquely to all
+        SPECED_COMPONENTS plus the GLT eligibility check and log_metrics_to_ui tasks.
+        Only applicable for run and compile actions; rejected with --action=run_no_compile because env vars
+        are baked at compile time and the flag would silently do nothing in that mode.
+        KFP itself reserves a small set of names (e.g. KFP_*) and may reject those at runtime.
+        Example: --env_vars=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python --env_vars=FOO=bar
+        """,
+    )
     parser.add_argument(
         "--notification_emails",
         action="append",
@@ -345,6 +392,7 @@ def _get_parser() -> argparse.ArgumentParser:
 
     parsed_additional_job_args = _parse_additional_job_args(args.additional_job_args)
     parsed_labels = _parse_labels(args.run_labels)
+    parsed_env_vars = _parse_env_vars(args.env_vars)
 
     # Set the default value for compiled_pipeline_path as we cannot set it in argparse as
     # for compile action this is a required flag so we cannot provide it a default value.
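Reviewer note on the parsing semantics above: a minimal, stdlib-only sketch of how `split("=", 1)` behaves for the cases the tests below exercise. The `parse_env_vars` name here is a hypothetical stand-in mirroring `_parse_env_vars`, not part of the patch:

```python
# Hypothetical stand-in mirroring _parse_env_vars: split on the FIRST '=' only,
# so values may themselves contain '='. An entry with no '=' raises ValueError
# when the 2-tuple unpack fails; duplicate names resolve to the last occurrence.
def parse_env_vars(env_vars: list[str]) -> dict[str, str]:
    result: dict[str, str] = {}
    for entry in env_vars:
        name, value = entry.split("=", 1)
        result[name] = value
    return result

assert parse_env_vars(["FOO=bar", "BAZ=qux"]) == {"FOO": "bar", "BAZ": "qux"}
assert parse_env_vars(["URL=https://example.com/?q=1&r=2"]) == {"URL": "https://example.com/?q=1&r=2"}
assert parse_env_vars(["FOO="]) == {"FOO": ""}
assert parse_env_vars(["FOO=first", "FOO=second"]) == {"FOO": "second"}
```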
@@ -388,6 +436,7 @@ def _get_parser() -> argparse.ArgumentParser:
             dataflow_container_image=dataflow_container_image,
             dst_compiled_pipeline_path=compiled_pipeline_path,
             additional_job_args=parsed_additional_job_args,
+            env_vars=parsed_env_vars,
             tag=args.pipeline_tag,
         )
         assert path == compiled_pipeline_path, (
@@ -415,6 +464,7 @@ def _get_parser() -> argparse.ArgumentParser:
             dataflow_container_image=dataflow_container_image,
             dst_compiled_pipeline_path=compiled_pipeline_path,
             additional_job_args=parsed_additional_job_args,
+            env_vars=parsed_env_vars,
             tag=args.pipeline_tag,
         )
         logger.info(
diff --git a/gigl/orchestration/kubeflow/utils/glt_backend.py b/gigl/orchestration/kubeflow/utils/glt_backend.py
index 862021864..a0c709bf1 100644
--- a/gigl/orchestration/kubeflow/utils/glt_backend.py
+++ b/gigl/orchestration/kubeflow/utils/glt_backend.py
@@ -3,12 +3,18 @@
 
 def check_glt_backend_eligibility_component(
     task_config_uri: str, base_image: str
-) -> bool:
+) -> dsl.PipelineTask:
+    """Construct the KFP task that decides whether to use the GLT backend.
+
+    Returns the underlying ``PipelineTask`` so callers can attach resource
+    requirements, environment variables, or other task-level configuration
+    before consuming ``.output`` for downstream conditionals.
+    """
     comp = dsl.component(
         func=_check_glt_backend_eligibility_component, base_image=base_image
     )
     comp.description = "Check whether to use GLT Backend"
-    return comp(task_config_uri=task_config_uri).output
+    return comp(task_config_uri=task_config_uri)
 
 
 def _check_glt_backend_eligibility_component(
diff --git a/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py b/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py
index 34cef81f9..ea4a2dd96 100644
--- a/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py
+++ b/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py
@@ -1,8 +1,11 @@
+import tempfile
+from pathlib import Path
 from unittest.mock import ANY, patch
 
+import yaml
 from absl.testing import absltest
 
-from gigl.common import GcsUri
+from gigl.common import GcsUri, LocalUri
 from gigl.common.logger import Logger
 from gigl.orchestration.kubeflow.kfp_orchestrator import KfpOrchestrator
 from tests.test_assets.test_case import TestCase
@@ -29,6 +32,86 @@ def test_compile_uploads_compiled_yaml(self, MockFileLoader):
         mock_file_loader_instance.load_file.assert_called_once_with(
             file_uri_src=ANY, file_uri_dst=dst_compiled_pipeline_path
         )
+
+    def test_compile_bakes_env_vars_into_every_gigl_owned_executor(self):
+        """env_vars passed to compile() should appear on every GiGL-owned executor's container env.
+
+        The managed VertexNotificationEmailOp exit handler is the documented
+        carve-out and must not receive the env vars.
+        """
+        env_vars = {
+            "FOO": "bar",
+            "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python",
+        }
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dst = LocalUri(str(Path(tmp_dir) / "pipeline.yaml"))
+            KfpOrchestrator.compile(
+                cuda_container_image="SOME NONEXISTENT IMAGE 1",
+                cpu_container_image="SOME NONEXISTENT IMAGE 2",
+                dataflow_container_image="SOME NONEXISTENT IMAGE 3",
+                dst_compiled_pipeline_path=dst,
+                env_vars=env_vars,
+            )
+
+            with open(dst.uri, "r") as f:
+                compiled = yaml.safe_load(f)
+
+        executors = compiled["deploymentSpec"]["executors"]
+        self.assertGreater(len(executors), 0, "Expected at least one executor in IR.")
+
+        gigl_owned_with_env: list[str] = []
+        notification_executors_without_env: list[str] = []
+        for executor_id, executor_spec in executors.items():
+            container = executor_spec.get("container", {})
+            env_list = container.get("env", [])
+            env_dict = {entry["name"]: entry["value"] for entry in env_list}
+            is_notification = "notification-email" in executor_id.lower()
+            if is_notification:
+                # The managed notification op must not receive our env vars.
+                for name in env_vars:
+                    self.assertNotIn(
+                        name,
+                        env_dict,
+                        f"Env var {name} unexpectedly applied to managed "
+                        f"notification executor {executor_id}.",
+                    )
+                notification_executors_without_env.append(executor_id)
+            else:
+                for name, value in env_vars.items():
+                    self.assertEqual(
+                        env_dict.get(name),
+                        value,
+                        f"Executor {executor_id} missing env var {name}={value}; "
+                        f"actual env: {env_dict}.",
+                    )
+                gigl_owned_with_env.append(executor_id)
+
+        self.assertGreater(
+            len(gigl_owned_with_env),
+            0,
+            "Expected at least one GiGL-owned executor to receive env vars.",
+        )
+
+    def test_compile_without_env_vars_does_not_inject_env(self):
+        """When env_vars is omitted, no GiGL-owned executor should pick up phantom env entries from this code path."""
+        with tempfile.TemporaryDirectory() as tmp_dir:
+            dst = LocalUri(str(Path(tmp_dir) / "pipeline.yaml"))
+            KfpOrchestrator.compile(
+                cuda_container_image="SOME NONEXISTENT IMAGE 1",
+                cpu_container_image="SOME NONEXISTENT IMAGE 2",
+                dataflow_container_image="SOME NONEXISTENT IMAGE 3",
+                dst_compiled_pipeline_path=dst,
+            )
+
+            with open(dst.uri, "r") as f:
+                compiled = yaml.safe_load(f)
+
+        # Only assert the default (unset) case adds no FOO key — we don't make
+        # claims about other env entries that KFP itself may inject.
+        for executor_spec in compiled["deploymentSpec"]["executors"].values():
+            env_list = executor_spec.get("container", {}).get("env", [])
+            env_names = {entry["name"] for entry in env_list}
+            self.assertNotIn("FOO", env_names)
 
 
 if __name__ == "__main__":
     absltest.main()
diff --git a/tests/unit/orchestration/kubeflow/kfp_runner_test.py b/tests/unit/orchestration/kubeflow/kfp_runner_test.py
index eade9b3ed..368ed61cb 100644
--- a/tests/unit/orchestration/kubeflow/kfp_runner_test.py
+++ b/tests/unit/orchestration/kubeflow/kfp_runner_test.py
@@ -5,6 +5,7 @@
     _assert_required_flags,
     _get_parser,
     _parse_additional_job_args,
+    _parse_env_vars,
     _parse_labels,
 )
 from gigl.src.common.constants.components import GiGLComponents
@@ -110,6 +111,76 @@ def test_assert_required_flags_success(self):
         # Should not raise any exception
         _assert_required_flags(args)
 
+    def test_parse_env_vars_single(self):
+        parsed = _parse_env_vars(["FOO=bar"])
+        self.assertEqual(parsed, {"FOO": "bar"})
+
+    def test_parse_env_vars_multiple(self):
+        parsed = _parse_env_vars(["FOO=bar", "BAZ=qux"])
+        self.assertEqual(parsed, {"FOO": "bar", "BAZ": "qux"})
+
+    def test_parse_env_vars_value_contains_equals(self):
+        # split("=", 1) means only the first '=' delimits key/value; the rest is value.
+        parsed = _parse_env_vars(["URL=https://example.com/?q=1&r=2"])
+        self.assertEqual(parsed, {"URL": "https://example.com/?q=1&r=2"})
+
+    def test_parse_env_vars_empty_value(self):
+        parsed = _parse_env_vars(["FOO="])
+        self.assertEqual(parsed, {"FOO": ""})
+
+    def test_parse_env_vars_empty_list(self):
+        self.assertEqual(_parse_env_vars([]), {})
+
+    def test_parse_env_vars_malformed_raises(self):
+        # No '=' in the entry — split("=", 1) returns a single-element list and the
+        # tuple unpack raises ValueError, mirroring _parse_labels semantics.
+        with self.assertRaises(ValueError):
+            _parse_env_vars(["NOT_A_VALID_ENTRY"])
+
+    def test_parse_env_vars_duplicate_keys_last_wins(self):
+        parsed = _parse_env_vars(["FOO=first", "FOO=second"])
+        self.assertEqual(parsed, {"FOO": "second"})
+
+    def test_assert_required_flags_rejects_env_vars_with_run_no_compile(self):
+        """--env_vars must not be combined with --action=run_no_compile."""
+        parser = _get_parser()
+        args = parser.parse_args(
+            [
+                "--action=run_no_compile",
+                "--task_config_uri=gs://bucket/task_config.yaml",
+                "--resource_config_uri=gs://bucket/resource_config.yaml",
+                "--compiled_pipeline_path=gs://bucket/pipeline.yaml",
+                "--env_vars=FOO=bar",
+            ]
+        )
+        with self.assertRaises(ValueError):
+            _assert_required_flags(args)
+
+    def test_assert_required_flags_allows_env_vars_with_run(self):
+        """--env_vars is valid for --action=run."""
+        parser = _get_parser()
+        args = parser.parse_args(
+            [
+                "--action=run",
+                "--task_config_uri=gs://bucket/task_config.yaml",
+                "--resource_config_uri=gs://bucket/resource_config.yaml",
+                "--env_vars=FOO=bar",
+            ]
+        )
+        _assert_required_flags(args)
+
+    def test_assert_required_flags_allows_env_vars_with_compile(self):
+        """--env_vars is valid for --action=compile."""
+        parser = _get_parser()
+        args = parser.parse_args(
+            [
+                "--action=compile",
+                "--compiled_pipeline_path=gs://bucket/pipeline.yaml",
+                "--env_vars=FOO=bar",
+            ]
+        )
+        _assert_required_flags(args)
+
 
 if __name__ == "__main__":
     absltest.main()
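Reviewer note: for spot-checking a compiled pipeline by hand, a minimal sketch of the inspection the tests above automate. It assumes PyYAML and a local `pipeline.yaml` produced via `--action=compile`; `FOO` is just the example name used in the tests:

```python
import yaml

# Env vars applied via PipelineTask.set_env_variable surface in the compiled
# KFP v2 IR under deploymentSpec.executors.<id>.container.env as a list of
# {"name": ..., "value": ...} entries.
with open("pipeline.yaml") as f:
    compiled = yaml.safe_load(f)

for executor_id, spec in compiled["deploymentSpec"]["executors"].items():
    env_entries = spec.get("container", {}).get("env", [])
    env = {entry["name"]: entry["value"] for entry in env_entries}
    print(f"{executor_id}: FOO={env.get('FOO', '<unset>')}")
```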