diff --git a/.github/actions/assert-is-collaborator/action.yml b/.github/actions/assert-is-collaborator/action.yml index cd1e6feb8..2551f5bea 100644 --- a/.github/actions/assert-is-collaborator/action.yml +++ b/.github/actions/assert-is-collaborator/action.yml @@ -5,7 +5,7 @@ inputs: description: The GitHub username to check required: true initiating-pr-number: - description: The PR number that the check may be associated with, if provided will comment on the PR incase of failures + description: The PR number that the check may be associated with, if provided will comment on the PR incase of failures required: false runs: using: "composite" @@ -22,7 +22,7 @@ runs: script: | try { const username = "${{ inputs.username }}"; - const result = await github.rest.repos.checkCollaborator({ + const result = await github.rest.repos.checkCollaborator({ owner: context.repo.owner, repo: context.repo.repo, username: username @@ -41,7 +41,7 @@ runs: console.log(`Error checking collaborator status: ${error.message}`); } } - + - name: Comment workflow permissions if: ${{ failure() && steps.assert-is-collaborator.conclusion == 'failure' && inputs.initiating-pr-number != '' }} uses: snapchat/gigl/.github/actions/comment-on-pr@main @@ -49,4 +49,3 @@ runs: pr_number: ${{ inputs.initiating-pr-number }} message: | πŸ”’ User ${{ inputs.username }} does not have permissions to run this workflow - \ No newline at end of file diff --git a/.github/actions/comment-on-pr/action.yml b/.github/actions/comment-on-pr/action.yml index 1f7bdf894..9151c2c5d 100644 --- a/.github/actions/comment-on-pr/action.yml +++ b/.github/actions/comment-on-pr/action.yml @@ -16,7 +16,7 @@ outputs: comment_id: description: 'The ID of the created or updated comment' value: ${{steps.comment.outputs.result}} - + runs: using: 'composite' @@ -67,4 +67,4 @@ runs: }); return response.data.id; - } \ No newline at end of file + } diff --git a/.github/actions/get-pr-src-branch/action.yml 
b/.github/actions/get-pr-src-branch/action.yml index 56c83ae1b..17cf87821 100644 --- a/.github/actions/get-pr-src-branch/action.yml +++ b/.github/actions/get-pr-src-branch/action.yml @@ -30,4 +30,4 @@ runs: }); const branch_name = pr.data.head.ref; console.log("Branch name is:", branch_name); - return branch_name; \ No newline at end of file + return branch_name; diff --git a/docs/plans/20260429-add-env-var-injection-to-kfp-runner.md b/docs/plans/20260429-add-env-var-injection-to-kfp-runner.md new file mode 100644 index 000000000..73f97bc83 --- /dev/null +++ b/docs/plans/20260429-add-env-var-injection-to-kfp-runner.md @@ -0,0 +1,118 @@ +# Plan: Add env-var injection to the KFP runner + +## 1. Background + +Today, every container launched by the GiGL Kubeflow pipeline (config_validator, config_populator, data_preprocessor, subgraph_sampler, split_generator, trainer, inferencer, post_processor β€” see `SPECED_COMPONENTS` at `gigl/orchestration/kubeflow/kfp_pipeline.py:32-41`) inherits its environment from the image. There is no first-class way for a caller of `runner.py` to set arbitrary `ENV` values on those containers at compile time. + +Wiring each downstream consumer in as a typed flag (e.g. one flag per image URI, one per Python-interop toggle, one per protobuf-implementation switch) does not scale and leaks consumer-specific concepts into the GiGL surface area. A single generic `--env KEY=VALUE` flag β€” threaded into `KfpOrchestrator.compile(...)` and applied per-task via the KFP v2 SDK's [`PipelineTask.set_env_variable(name, value)`](https://kubeflow-pipelines.readthedocs.io/en/stable/source/dsl.html#kfp.dsl.PipelineTask.set_env_variable) β€” is the cleanest carrier: GiGL knows nothing about the contents, callers own the meaning. GiGL's own code paths that need env config will continue to use the `GIGL_*` prefix; everything else is opaque transport. + +## 2. 
API surface + +**New CLI flag on `runner.py`** β€” repeatable, mirrors `--run_labels` exactly: + +``` +--env KEY=VALUE +``` + +Argparse spec β€” added next to `--run_labels` in `_get_parser` (`runner.py:319-328`): + +- `action="append"`, `default=[]`, `help` describes the format and that values flow to every component at compile time. + +**Parser helper** β€” added next to `_parse_labels` (`runner.py:212-225`); identical signature shape: + +``` +def _parse_env_vars(env_vars: list[str]) -> dict[str, str] +``` + +The template to mirror is the existing `_parse_labels` body verbatim β€” `split("=", 1)` on each entry, populate `dict[str, str]`, log the parsed result. Same error semantics: a malformed entry (no `=`) raises `ValueError` from `str.split` unpacking, exactly like `_parse_labels` does today. + +**Plumbing path:** + +1. `runner.py:__main__` calls `parsed_env_vars = _parse_env_vars(args.env)` alongside `parsed_additional_job_args` / `parsed_labels` (`runner.py:346-347`). +2. Both `KfpOrchestrator.compile` call sites (`runner.py:385` for the RUN action, `runner.py:412` for the COMPILE action) gain `env_vars=parsed_env_vars`. +3. `KfpOrchestrator.compile` (`kfp_orchestrator.py:51-109`) gains `env_vars: Optional[dict[str, str]] = None`, packs it into `CommonPipelineComponentConfigs` (a new field β€” see Β§3), and passes it through `generate_pipeline(...)`. +4. `kfp_pipeline.py:_generate_component_task` (`kfp_pipeline.py:54-128`) β€” after `add_task_resource_requirements(...)` is called and before `return component_task`, loop the dict and call `component_task.set_env_variable(name=k, value=v)` per entry. + +Per the [KFP v2 API](https://kubeflow-pipelines.readthedocs.io/en/stable/source/dsl.html#kfp.dsl.PipelineTask.set_env_variable), `set_env_variable` takes one `name`/`value` pair per call and returns the task for chaining; we must loop, not pass a dict. + +## 3. 
Where the env vars get applied + +**All eight components in `SPECED_COMPONENTS`.** Justification: + +- The whole point of a generic carrier is uniformity β€” callers expect that if they say `--env FOO=bar`, *every* container they're paying for sees `FOO=bar`. +- `_generate_component_task` is the single funnel through which every `PipelineTask` in `SPECED_COMPONENTS` is constructed (`kfp_pipeline.py:54-128`). Adding the loop there covers config_validator, config_populator, data_preprocessor, subgraph_sampler, split_generator, trainer, inferencer, post_processor in one place. +- The two non-`SPECED_COMPONENTS` tasks in the pipeline are `check_glt_backend_eligibility_component` (`utils/glt_backend.py`) and `log_metrics_to_ui` (`utils/log_metrics.py`). Decision: include them too, for a complete answer to "every container my pipeline launches sees these vars." Concretely, both are constructed inside `kfp_pipeline.py` (`_generate_component_tasks` for the GLT check, `_create_trainer_task_op` / `_create_post_processor_task_op` for log metrics). Apply the same loop right after each is built. +- Implementation choice that keeps the callsite count low: introduce a private helper `_apply_env_vars(task, env_vars)` inside `kfp_pipeline.py` (parallel in spirit to `add_task_resource_requirements`) and call it from `_generate_component_task` plus the two non-SPECED sites. One source of truth, no risk of drift. + +The `env_vars` dict rides on `CommonPipelineComponentConfigs` (`gigl/common/types/resource_config.py:8-17`), added as `env_vars: dict[str, str] = field(default_factory=dict)` β€” same shape as the existing `additional_job_args` field. + +## 4. Compile-time vs run-time + +**Compile-time bake-in.** `set_env_variable` records the name/value into the compiled pipeline IR (the YAML at `dst_compiled_pipeline_path`); the values are baked at compile time, not resolved per run. 
That is the right choice here: + +- For `--action=run`, `runner.py:384-395` *recompiles* on every invocation before calling `orchestrator.run`, so each run gets a fresh bake of whatever `--env` values were passed on that invocation. No UX loss. +- For `--action=compile`, `runner.py:411-422` produces a static artifact β€” the user explicitly wanted those values frozen in. +- For `--action=run_no_compile`, the user is opting into a pre-compiled pipeline; whatever envs were baked at compile time is what runs. This is consistent with how `additional_job_args` already behaves. + +A run-time-resolved alternative would require declaring a KFP pipeline parameter (e.g. `env_vars: dict`) on the `@kfp.dsl.pipeline` function and passing it through every component op. KFP v2 has poor ergonomics for `dict` pipeline parameters and the immediate use case (callers injecting a config value per compile) does not benefit from late binding. We do *not* take that path; if a future caller needs per-run env, that's a separate proposal. + +## 5. Validation / failure modes + +Mirror `_parse_labels` exactly so behavior is predictable across all three flags: + +- Malformed entry (no `=`): `str.split("=", 1)` returns one element, the unpack to `(name, value)` raises `ValueError`. Same as `_parse_labels` today β€” propagate, do not catch. +- Empty value (`--env FOO=`): valid, `value=""`. Mirrors `_parse_labels` β€” that flag also accepts empty values, so we stay consistent. +- Empty key (`--env =bar`): not validated by `_parse_labels` either; KFP's own `set_env_variable` will reject it downstream. Consistent failure surface; do not pre-empt. +- Duplicate keys across multiple `--env` invocations: last one wins (dict overwrite), same as `_parse_labels`. +- Reserved names: do *not* enforce a denylist inside GiGL. KFP itself reserves a small set (e.g. `KFP_*`); leaving validation to KFP keeps GiGL generic. Document the caveat in the help text. 
+- Cross-flag conflict: `_assert_required_flags` (`runner.py:134-172`) does not need a new check. `--env` is valid for all three actions (`run`, `run_no_compile`, `compile`), unlike `--run_labels` which is run-only. Document this difference. + +## 6. Step-by-step implementation plan + +Each bullet is one commit-sized unit: + +1. **Add `env_vars` field to `CommonPipelineComponentConfigs`** at `gigl/common/types/resource_config.py:8-17`. Default-empty dict. +2. **Add `_apply_env_vars(task, env_vars)` helper in `kfp_pipeline.py`** β€” single-responsibility, loops the dict and calls `task.set_env_variable(name=k, value=v)`. +3. **Wire the helper into `_generate_component_task`** at `kfp_pipeline.py:54-128`, right after `add_task_resource_requirements(...)`. Also call it on `check_glt_backend_eligibility_component` and `log_metrics_to_ui` task results. +4. **Thread `env_vars: Optional[dict[str, str]] = None`** through `KfpOrchestrator.compile` (`kfp_orchestrator.py:51-109`) into `CommonPipelineComponentConfigs(...)` (existing site at `kfp_orchestrator.py:83-88`). +5. **Add `--env` arg to `_get_parser`** (`runner.py:_get_parser`) β€” `action="append"`, `default=[]`, help text describing the format and noting GiGL does not interpret values. +6. **Add `_parse_env_vars` helper** next to `_parse_labels` (`runner.py:212-225`). +7. **Plumb `parsed_env_vars` into both `KfpOrchestrator.compile(...)` call sites** in `runner.py` (`runner.py:385-395` and `runner.py:411-422`). +8. **Update the `runner.py` module docstring** (lines 1-71): document `--env` under both `RUN` and `COMPILE` action sections; note compile-time bake-in semantics. +9. **Unit test: parser** in `tests/unit/orchestration/kubeflow/runner_test.py` (new file, mirror layout of `kfp_orchestrator_test.py`). Cover: single var, multiple vars, value containing `=`, malformed entry raises `ValueError`, empty list returns empty dict. +10. 
**Unit test: compile-time injection** extending `kfp_orchestrator_test.py:KfpOrchestratorTest` — call `KfpOrchestrator.compile(..., env_vars={"FOO": "bar"})` writing to a tmp path, parse the resulting YAML, assert that every component spec under the compiled pipeline IR has an `env` entry with `name: FOO`, `value: bar`. The IR shape is stable for KFP v2 (`spec.executors.<executor>.container.env`). + +## 7. Test strategy + +- **Parser unit test** (step 9 above): pure function, no I/O. Exhaustive on malformed inputs since this is user-facing CLI. +- **Compile integration test** (step 10): writes the compiled pipeline to a temp file, loads with `yaml.safe_load`, walks every executor in `pipelineSpec.deploymentSpec.executors`, asserts the env list contains all expected pairs. This proves the loop ran on every component, not just one. Run via `make unit_test_py PY_TEST_FILES="kfp_orchestrator_test.py"`. +- **Smoke test, manual**: from a downstream `Makefile`, add `--env=KEY1=value1 --env=KEY2=value2` to one `compile_*_kubeflow_pipeline` target, compile, and `grep -A2 "name: KEY1" build/gigl_pipeline_gnn.yaml` to confirm presence on every executor. + +## 8. Rollout + +GiGL-side callers of `gigl.orchestration.kubeflow.runner` to inventory: + +- `Makefile:258 compile_gigl_kubeflow_pipeline` +- `Makefile:283 run_dev_gnn_kubeflow_pipeline` +- `Makefile:307 compile_simple_gigl_kubeflow_pipeline` +- `Makefile:328-ish run_dev_simple_kubeflow_pipeline` (the simple-GiGL run target — verify exact line) +- `Makefile:352, 360, 368, 376, 384, 392, 400, 407` — the e2e test targets that depend on `compile_gigl_kubeflow_pipeline`. They inherit whatever the compile target sets; no per-target change unless they need different envs. + +GiGL repo: no callers beyond this `runner.py` module need changes. The flag is opt-in, default empty, fully backward-compatible. 
+ +Downstream repos that vendor GiGL as a submodule: a separate follow-up commit (out of scope of the GiGL PR) wires `--env=...` into whichever Makefile targets need it. That commit lives in the consumer repo, not GiGL. + +## 9. Open questions + +1. **Pipeline-parameter env**: do we expect any caller to want envs that vary *per run* of the same compiled pipeline? Not for the immediate use case driving this work; flag if that changes. +2. **VertexNotificationEmailOp env**: `_generate_component_tasks` wraps everything in an `ExitHandler(VertexNotificationEmailOp(...))` (`kfp_pipeline.py:252-256`). The notification op is a Google-managed component; should env vars be applied to it? Default answer: no β€” it's a managed op outside the user's control surface, applying envs is at best inert and at worst rejected. Confirm during implementation by checking whether `set_env_variable` on it raises. +3. **Naming β€” `--env` vs `--env_var` vs `--env_vars`**: `--env` is shortest and matches `docker run --env`; `--env_var` is more discoverable in `--help`; `--env_vars` would be most consistent with the existing plural `--run_labels`. Repeatable flags in this file use both styles (`--run_labels` plural, `--notification_emails` plural, `--additional_job_args` plural). Suggest `--env_vars` and document it as repeatable. Confirm with reviewer. +4. **Should the helper land on `CommonPipelineComponentConfigs` or be a separate parameter to `generate_pipeline`?** Going via `CommonPipelineComponentConfigs` is consistent with `additional_job_args` and minimizes signature churn. No open question; calling out the design choice. 
+ +### Critical Files for Implementation + +- `gigl/orchestration/kubeflow/runner.py` +- `gigl/orchestration/kubeflow/kfp_orchestrator.py` +- `gigl/orchestration/kubeflow/kfp_pipeline.py` +- `gigl/common/types/resource_config.py` +- `tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py` diff --git a/gigl/common/omegaconf_resolvers.py b/gigl/common/omegaconf_resolvers.py index 6653620e5..0ba8c119e 100644 --- a/gigl/common/omegaconf_resolvers.py +++ b/gigl/common/omegaconf_resolvers.py @@ -5,6 +5,7 @@ """ import subprocess +from collections.abc import Mapping from datetime import datetime, timedelta from omegaconf import OmegaConf @@ -16,6 +17,12 @@ _SUPPORTED_UNITS = ("weeks", "days", "seconds", "minutes", "hours") +# Module-level value dict consumed by the ``gigl`` resolver. Populated +# by ``set_gigl_resolver_values`` (typically called by ``launch_custom`` +# right before re-resolving a ``CustomResourceConfig``'s ``command`` / +# ``args`` strings). +_GIGL_RESOLVER_VALUES: dict[str, str] = {} + def now_resolver(*args: str) -> str: """Resolver that creates a string representing the current time (with optional offset) using strftime. @@ -157,6 +164,53 @@ def git_hash_resolver() -> str: return "" +def gigl_resolver(key: str) -> str: + """Resolve ``${gigl:}`` from a module-level value dict. + + Registered resolvers are invoked with colon syntax in OmegaConf + (``${gigl:foo}``), not dotted (``${gigl.foo}``). The fallback string + therefore mirrors that form so the placeholder can round-trip through + a first-pass YAML load and be re-resolved later once runtime values + are populated via ``set_gigl_resolver_values``. + + Args: + key: The runtime-value key being looked up + (e.g. ``cuda_docker_image``). + + Returns: + The runtime value if ``set_gigl_resolver_values`` has populated + ``key``; otherwise the literal placeholder string + ``"${gigl:}"`` so the early YAML pass + (``ProtoUtils.read_proto_from_yaml``) is lossless. 
+ + Example: + In a YAML loaded by ``ProtoUtils.read_proto_from_yaml``: + + ```yaml + custom_trainer_config: + command: "${gigl:task_config_uri}" + ``` + + ``launch_custom`` later sets ``task_config_uri`` via + ``set_gigl_resolver_values`` and re-resolves the field. + """ + return _GIGL_RESOLVER_VALUES.get(key, f"${{gigl:{key}}}") + + +def set_gigl_resolver_values(values: Mapping[str, str]) -> None: + """Replace the module-level dict the ``gigl`` resolver reads from. + + The dict is cleared before the new values are written so stale + runtime values from a prior ``launch_custom`` call cannot leak into + a subsequent invocation. + + Args: + values: New keyβ†’value mapping. + """ + _GIGL_RESOLVER_VALUES.clear() + _GIGL_RESOLVER_VALUES.update(values) + + def register_resolvers() -> None: """Register all custom OmegaConf resolvers. @@ -178,3 +232,11 @@ def register_resolvers() -> None: logger.debug( "OmegaConf resolver 'git_hash' already registered, skipping registration" ) + + if not OmegaConf.has_resolver("gigl"): + logger.info("Registering OmegaConf resolver 'gigl'") + OmegaConf.register_new_resolver("gigl", gigl_resolver) + else: + logger.debug( + "OmegaConf resolver 'gigl' already registered, skipping registration" + ) diff --git a/gigl/common/types/resource_config.py b/gigl/common/types/resource_config.py index a24a97a47..a2a0ffe85 100644 --- a/gigl/common/types/resource_config.py +++ b/gigl/common/types/resource_config.py @@ -15,3 +15,7 @@ class CommonPipelineComponentConfigs: additional_job_args: dict[GiGLComponents, dict[str, str]] = field( default_factory=dict ) + # Environment variables baked into every GiGL-owned container at compile time. + # Applied uniformly across all SPECED_COMPONENTS plus the GLT eligibility check + # and log_metrics_to_ui tasks. The managed VertexNotificationEmailOp is excluded. 
+ env_vars: dict[str, str] = field(default_factory=dict) diff --git a/gigl/env/runtime.py b/gigl/env/runtime.py new file mode 100644 index 000000000..82fb31aa6 --- /dev/null +++ b/gigl/env/runtime.py @@ -0,0 +1,69 @@ +"""Detect the execution environment for distributed training processes.""" + +import os +from enum import Enum + + +class RuntimeEnv(str, Enum): + """Supported execution environments. + + RAY: Running inside a Ray cluster. + VERTEX_AI: Running inside a Vertex AI custom job container. + UNKNOWN: No signal matched; caller must fall back to defaults. + """ + + RAY = "ray" + VERTEX_AI = "vertex_ai" + UNKNOWN = "unknown" + + +def is_ray_runtime() -> bool: + """True when the current process is running inside a Ray job. + + **Authoritative signal**: ``GIGL_RAY_RUNTIME=1``. A custom launcher + integrating GiGL with a Ray platform is expected to set this env var on + the head spec so the entrypoint can branch reliably. This is the only + signal GiGL fully controls, so it's checked first. + + **Fallbacks for callers that don't set the authoritative var**: KubeRay + injects ``RAY_DASHBOARD_ADDRESS`` into the submitter pod (not + ``RAY_ADDRESS`` β€” verified against Ray 2.54 docs: + https://docs.ray.io/en/releases-2.54.0/cluster/kubernetes/getting-started/rayjob-quick-start.html). + ``RAY_ADDRESS`` is checked last and may or may not be present depending + on how Ray was started. ``ray.is_initialized()`` returns ``False`` before + ``ray.init()``, so it only helps if Ray has already been bootstrapped. + + **Callable-ordering constraint**: this is a snapshot of process env, not + a poll. It MUST be called after the pod's env is populated (always true + for the entrypoint process, which inherits pod-level env from the + container start). Repeat callers get the same answer; don't rely on it + to "become true" mid-process after ``ray.init()`` finishes. 
+ """ + if os.environ.get("GIGL_RAY_RUNTIME") == "1": + return True + if os.environ.get("RAY_DASHBOARD_ADDRESS"): + return True + if os.environ.get("RAY_ADDRESS"): + return True + try: + import ray # type: ignore[import-not-found] + + return ray.is_initialized() + except ImportError: + return False + + +def get_runtime_env() -> RuntimeEnv: + """Classify the current process's execution environment. + + Returns: + ``RuntimeEnv.RAY`` when :func:`is_ray_runtime` is True. + ``RuntimeEnv.VERTEX_AI`` when Vertex AI env vars + (``CLOUD_ML_JOB_ID`` or ``AIP_MODEL_DIR``) are set. + ``RuntimeEnv.UNKNOWN`` otherwise. + """ + if is_ray_runtime(): + return RuntimeEnv.RAY + if os.environ.get("CLOUD_ML_JOB_ID") or os.environ.get("AIP_MODEL_DIR"): + return RuntimeEnv.VERTEX_AI + return RuntimeEnv.UNKNOWN diff --git a/gigl/orchestration/kubeflow/kfp_orchestrator.py b/gigl/orchestration/kubeflow/kfp_orchestrator.py index 374347ac1..e3323ce24 100644 --- a/gigl/orchestration/kubeflow/kfp_orchestrator.py +++ b/gigl/orchestration/kubeflow/kfp_orchestrator.py @@ -56,6 +56,7 @@ def compile( dataflow_container_image: str, dst_compiled_pipeline_path: Uri = DEFAULT_KFP_COMPILED_PIPELINE_DEST_PATH, additional_job_args: Optional[dict[GiGLComponents, dict[str, str]]] = None, + env_vars: Optional[dict[str, str]] = None, tag: Optional[str] = None, ) -> Uri: """ @@ -68,6 +69,9 @@ def compile( dst_compiled_pipeline_path (Uri): Destination path for the compiled pipeline YAML file. Defaults to :data:`~gigl.constants.DEFAULT_KFP_COMPILED_PIPELINE_DEST_PATH`. additional_job_args (Optional[dict[GiGLComponents, dict[str, str]]]): Additional arguments to be passed into components, organized by component. + env_vars (Optional[dict[str, str]]): Environment variables baked into every GiGL-owned container at compile time. + Applied uniformly across all SPECED_COMPONENTS plus the GLT eligibility check and ``log_metrics_to_ui`` tasks. 
+ The managed ``VertexNotificationEmailOp`` exit handler is intentionally excluded. tag (Optional[str]): Optional tag to include in the pipeline description. Returns: Uri: The URI of the compiled pipeline. @@ -85,6 +89,7 @@ def compile( cpu_container_image=cpu_container_image, dataflow_container_image=dataflow_container_image, additional_job_args=additional_job_args or {}, + env_vars=env_vars or {}, ) Compiler().compile( diff --git a/gigl/orchestration/kubeflow/kfp_pipeline.py b/gigl/orchestration/kubeflow/kfp_pipeline.py index 524107ce2..7ee844b60 100644 --- a/gigl/orchestration/kubeflow/kfp_pipeline.py +++ b/gigl/orchestration/kubeflow/kfp_pipeline.py @@ -51,6 +51,16 @@ } +def _apply_env_vars(task: PipelineTask, env_vars: dict[str, str]) -> None: + """Apply each entry in ``env_vars`` to ``task`` via ``set_env_variable``. + + The KFP v2 SDK's ``PipelineTask.set_env_variable`` takes a single name/value + pair per call, so we loop instead of passing a dict. + """ + for name, value in env_vars.items(): + task.set_env_variable(name=name, value=value) + + def _generate_component_task( component: GiGLComponents, job_name: str, @@ -124,6 +134,10 @@ def _generate_component_task( task=component_task, common_pipeline_component_configs=common_pipeline_component_configs, ) + _apply_env_vars( + task=component_task, + env_vars=common_pipeline_component_configs.env_vars, + ) return component_task @@ -145,10 +159,15 @@ def _generate_component_tasks( resource_config_uri=resource_config_uri, common_pipeline_component_configs=common_pipeline_component_configs, ) - should_use_glt = check_glt_backend_eligibility_component( + glt_eligibility_task = check_glt_backend_eligibility_component( task_config_uri=template_or_frozen_config_uri, base_image=common_pipeline_component_configs.cpu_container_image, ) + _apply_env_vars( + task=glt_eligibility_task, + env_vars=common_pipeline_component_configs.env_vars, + ) + should_use_glt = glt_eligibility_task.output with kfp.dsl.Condition(start_at 
== GiGLComponents.ConfigPopulator.value): config_populator_task = _create_config_populator_task_op( @@ -249,6 +268,9 @@ def pipeline( stop_after: Optional[str] = None, notification_emails: Optional[List[str]] = None, ): + # VertexNotificationEmailOp is a Google-managed component, so we + # intentionally do not apply common_pipeline_component_configs.env_vars + # to it; see docs/plans/20260429-add-env-var-injection-to-kfp-runner.md. with kfp.dsl.ExitHandler( VertexNotificationEmailOp(recipients=notification_emails), name="Gigl Alerts", @@ -451,6 +473,10 @@ def _create_trainer_task_op( ) log_metrics_component.set_display_name(name="Log Trainer Eval Metrics") log_metrics_component.after(trainer_task) + _apply_env_vars( + task=log_metrics_component, + env_vars=common_pipeline_component_configs.env_vars, + ) with kfp.dsl.Condition(stop_after != GiGLComponents.Trainer.value): inference_task = _create_inferencer_task_op( @@ -484,4 +510,8 @@ def _create_post_processor_task_op( ) log_metrics_component.set_display_name(name="Log PostProcessor Eval Metrics") log_metrics_component.after(post_processor_task) + _apply_env_vars( + task=log_metrics_component, + env_vars=common_pipeline_component_configs.env_vars, + ) return post_processor_task diff --git a/gigl/orchestration/kubeflow/runner.py b/gigl/orchestration/kubeflow/runner.py index 790b000dc..8f9de90ef 100644 --- a/gigl/orchestration/kubeflow/runner.py +++ b/gigl/orchestration/kubeflow/runner.py @@ -34,6 +34,11 @@ --notification_emails: Emails to send notification to. See https://cloud.google.com/vertex-ai/docs/pipelines/email-notifications for more details. Example: --notification_emails=user@example.com --notification_emails=user2@example.com + --env_vars: Environment variables baked into every GiGL-owned container at compile time + (every invocation of --action=run recompiles, so each run gets a fresh bake). + The value has to be of form: "=". GiGL does not interpret the contents. + This argument can be repeated. 
+ Example: --env_vars=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python --env_vars=FOO=bar You can alternatively run_no_compile if you have a precompiled pipeline somewhere. python gigl.orchestration.kubeflow.runner --action=run_no_compile ...args @@ -48,6 +53,8 @@ --pipeline_tag --notification_emails --wait + NOTE: --env_vars is rejected for --action=run_no_compile because env vars are baked at + compile time. Recompile via --action=run or --action=compile to change them. COMPILING A PIPELINE: A strict subset of running a pipeline, @@ -68,6 +75,10 @@ --additional_job_args=split_generator.some_other_arg='value' This passes additional_spark35_jar_file_uris="gs://path/to/jar" to subgraph_sampler at compile time and some_other_arg="value" to split_generator at compile time. + --env_vars: Environment variables baked into every GiGL-owned container at compile time. + The value has to be of form: "<KEY>=<VALUE>". GiGL does not interpret the contents. + This argument can be repeated. + Example: --env_vars=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python --env_vars=FOO=bar """ from __future__ import annotations @@ -170,6 +181,13 @@ def _assert_required_flags(args: argparse.Namespace) -> None: "Please use the run action to run a pipeline with labels." f"Labels provided: {args.run_labels}" ) + if args.action == Action.RUN_NO_COMPILE and args.env_vars: + raise ValueError( + "--env_vars is not supported for the run_no_compile action because " + "environment variables are baked into the pipeline at compile time. " + "Recompile via --action=run or --action=compile to apply env vars. " + f"env_vars provided: {args.env_vars}" + ) logger = Logger() @@ -225,6 +243,22 @@ def _parse_labels(labels: list[str]) -> dict[str, str]: return result +def _parse_env_vars(env_vars: list[str]) -> dict[str, str]: + """ + Parse environment variables to bake into every GiGL-owned container at compile time. + Args: + env_vars list[str]: Each element is of form: "<KEY>=<VALUE>". + Example: ["FOO=bar", "BAZ=qux"]. 
+ Returns dict[str, str]: The parsed environment variables. + """ + result: dict[str, str] = {} + for entry in env_vars: + name, value = entry.split("=", 1) + result[name] = value + logger.info(f"Parsed env_vars: {result}") + return result + + def _get_parser() -> argparse.ArgumentParser: """ Get the parser for the runner.py script. @@ -326,6 +360,19 @@ def _get_parser() -> argparse.ArgumentParser: Which will taget the pipeline run with gigl-integration-test=true and user=me. """, ) + parser.add_argument( + "--env_vars", + action="append", + default=[], + help="""Environment variables baked into every GiGL-owned container at compile time, of the form: + --env_vars=KEY=VALUE. GiGL does not interpret the contents; the values flow opaquely to all + SPECED_COMPONENTS plus the GLT eligibility check and log_metrics_to_ui tasks. + Only applicable for run and compile actions; rejected with --action=run_no_compile because envs + are baked at compile time and the flag would silently do nothing in that mode. + KFP itself reserves a small set of names (e.g. KFP_*) and may reject those at runtime. + Example: --env_vars=PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python --env_vars=FOO=bar + """, + ) parser.add_argument( "--notification_emails", action="append", @@ -345,6 +392,7 @@ def _get_parser() -> argparse.ArgumentParser: parsed_additional_job_args = _parse_additional_job_args(args.additional_job_args) parsed_labels = _parse_labels(args.run_labels) + parsed_env_vars = _parse_env_vars(args.env_vars) # Set the default value for compiled_pipeline_path as we cannot set it in argparse as # for compile action this is a required flag so we cannot provide it a default value. 
@@ -388,6 +436,7 @@ def _get_parser() -> argparse.ArgumentParser: dataflow_container_image=dataflow_container_image, dst_compiled_pipeline_path=compiled_pipeline_path, additional_job_args=parsed_additional_job_args, + env_vars=parsed_env_vars, tag=args.pipeline_tag, ) assert path == compiled_pipeline_path, ( @@ -415,6 +464,7 @@ def _get_parser() -> argparse.ArgumentParser: dataflow_container_image=dataflow_container_image, dst_compiled_pipeline_path=compiled_pipeline_path, additional_job_args=parsed_additional_job_args, + env_vars=parsed_env_vars, tag=args.pipeline_tag, ) logger.info( diff --git a/gigl/orchestration/kubeflow/utils/glt_backend.py b/gigl/orchestration/kubeflow/utils/glt_backend.py index 862021864..a0c709bf1 100644 --- a/gigl/orchestration/kubeflow/utils/glt_backend.py +++ b/gigl/orchestration/kubeflow/utils/glt_backend.py @@ -3,12 +3,18 @@ def check_glt_backend_eligibility_component( task_config_uri: str, base_image: str -) -> bool: +) -> dsl.PipelineTask: + """Construct the KFP task that decides whether to use the GLT backend. + + Returns the underlying ``PipelineTask`` so callers can attach resource + requirements, environment variables, or other task-level configuration + before consuming ``.output`` for downstream conditionals. + """ comp = dsl.component( func=_check_glt_backend_eligibility_component, base_image=base_image ) comp.description = "Check whether to use GLT Backend" - return comp(task_config_uri=task_config_uri).output + return comp(task_config_uri=task_config_uri) def _check_glt_backend_eligibility_component( diff --git a/gigl/src/common/custom_launcher.py b/gigl/src/common/custom_launcher.py new file mode 100644 index 000000000..9a99f8b23 --- /dev/null +++ b/gigl/src/common/custom_launcher.py @@ -0,0 +1,143 @@ +"""Subprocess dispatch for ``CustomResourceConfig``-backed launchers. 
+ +Resolves ``${gigl:*}`` placeholders in ``CustomResourceConfig.command`` / +``CustomResourceConfig.args`` against a runtime context (task config URI, +applied task identifier, component, …), then shells out via +``subprocess.run(shell_line, shell=True)``. The shell-style invocation +honors leading ``KEY=VALUE`` env-var assignments in ``command`` so +callers can self-document required env without forcing the dispatcher to +parse env separately. + +The receiving subprocess has no special protocol β€” it is expected to be +a plain CLI that argparses whatever flags the YAML wires up via +``args[]``. When more context is needed than ``CustomResourceConfig`` can +carry directly, the YAML embeds ``${gigl:}`` placeholders; this +module populates the values just before exec via the ``gigl`` resolver +(see ``gigl.common.omegaconf_resolvers``). +""" + +import shlex +import subprocess +from collections.abc import Mapping +from typing import Optional + +from omegaconf import OmegaConf + +from gigl.common import Uri +from gigl.common.logger import Logger +from gigl.common.omegaconf_resolvers import ( + register_resolvers, + set_gigl_resolver_values, +) +from gigl.src.common.constants.components import GiGLComponents +from snapchat.research.gbml.gigl_resource_config_pb2 import CustomResourceConfig + +logger = Logger() + +_LAUNCHABLE_COMPONENTS: frozenset[GiGLComponents] = frozenset( + {GiGLComponents.Trainer, GiGLComponents.Inferencer} +) + + +def launch_custom( + custom_resource_config: CustomResourceConfig, + applied_task_identifier: str, + task_config_uri: Uri, + resource_config_uri: Uri, + process_command: str, + process_runtime_args: Mapping[str, str], + cpu_docker_uri: Optional[str], + cuda_docker_uri: Optional[str], + component: GiGLComponents, + is_dry_run: bool = False, +) -> None: + """Resolve ``custom_resource_config`` and shell out to the configured command. 
+ + Populates the ``gigl`` OmegaConf resolver from the runtime kwargs, + re-resolves ``custom_resource_config.command`` / ``.args`` so any + ``${gigl:*}`` placeholders bind to runtime values, then invokes the + composed shell line via ``subprocess.run(shell=True, check=True)``. + + ``process_command`` and ``process_runtime_args`` are accepted for + back-compat with the existing GLT trainer / inferencer call sites + but are intentionally NOT plumbed through to the subprocess β€” + consumers re-derive them from the ``--gbml_uri`` (or equivalent) + they receive. + + Args: + custom_resource_config: Proto whose ``command`` is the shell + snippet to execute and whose ``args`` are positional + arguments. Both fields support ``${gigl:}`` + interpolation. + applied_task_identifier: Stable identifier for the job; exposed + as ``${gigl:applied_task_identifier}``. + task_config_uri: URI of the GbmlConfig serialized as YAML; + exposed as ``${gigl:task_config_uri}``. + resource_config_uri: URI of the GiglResourceConfig serialized as + YAML; exposed as ``${gigl:resource_config_uri}``. + process_command: Accepted for back-compat; ignored. + process_runtime_args: Accepted for back-compat; ignored. + cpu_docker_uri: Optional CPU Docker image URI; exposed as + ``${gigl:cpu_docker_image}`` (empty string when ``None``). + cuda_docker_uri: Optional CUDA Docker image URI; exposed as + ``${gigl:cuda_docker_image}`` (empty string when ``None``). + component: Which GiGL component is being launched. Must be in + ``_LAUNCHABLE_COMPONENTS``. Exposed as ``${gigl:component}`` + (Title-case ``.name``, e.g. ``"Trainer"``). + is_dry_run: If True, the resolved shell line is logged and the + function returns without spawning a subprocess. Exposed as + ``${gigl:is_dry_run}`` (string ``"1"`` or ``"0"``). + + Raises: + ValueError: If ``component`` is not Trainer or Inferencer, or if + ``custom_resource_config.command`` is empty. 
+ subprocess.CalledProcessError: If the spawned subprocess exits + non-zero. + """ + if component not in _LAUNCHABLE_COMPONENTS: + raise ValueError(f"Invalid component: {component}") + if not custom_resource_config.command: + raise ValueError("CustomResourceConfig.command must be set") + + # Defensive registration β€” direct callers (tests, scripts) may build + # a CustomResourceConfig programmatically without going through + # ProtoUtils, which is the usual registration path. + register_resolvers() + + set_gigl_resolver_values( + { + "applied_task_identifier": applied_task_identifier, + "task_config_uri": str(task_config_uri), + "resource_config_uri": str(resource_config_uri), + # Title-case component name so the receiving CLI's argparse + # ``choices=["Trainer", "Inferencer"]`` accepts it. ``.value`` + # is lowercase and would mismatch. + "component": component.name, + "cpu_docker_image": cpu_docker_uri or "", + "cuda_docker_image": cuda_docker_uri or "", + "is_dry_run": "1" if is_dry_run else "0", + } + ) + + # Re-resolve via OmegaConf so any ${gigl:*} placeholder in the + # proto's command/args strings binds to the now-populated runtime + # value. 
+ container = OmegaConf.create( + { + "command": custom_resource_config.command, + "args": list(custom_resource_config.args), + } + ) + resolved_command: str = container.command # type: ignore[assignment] + resolved_args: list[str] = list(container.args) # type: ignore[arg-type] + + shell_line = " ".join( + [resolved_command, *(shlex.quote(a) for a in resolved_args)] + ) + logger.info( + f"Launching {component.name} via subprocess: {shell_line!r} " + f"dry_run={is_dry_run}" + ) + if is_dry_run: + return + subprocess.run(shell_line, shell=True, check=True) diff --git a/gigl/src/common/types/pb_wrappers/gigl_resource_config.py b/gigl/src/common/types/pb_wrappers/gigl_resource_config.py index 93eb95b8e..d174d8b94 100644 --- a/gigl/src/common/types/pb_wrappers/gigl_resource_config.py +++ b/gigl/src/common/types/pb_wrappers/gigl_resource_config.py @@ -8,6 +8,7 @@ from gigl.common.logger import Logger from gigl.src.common.constants.components import GiGLComponents from snapchat.research.gbml.gigl_resource_config_pb2 import ( + CustomResourceConfig, DataflowResourceConfig, DataPreprocessorConfig, DistributedTrainerConfig, @@ -37,12 +38,14 @@ _KFP_TRAINER_CONFIG = "kfp_trainer_config" _LOCAL_TRAINER_CONFIG = "local_trainer_config" _VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG = "vertex_ai_graph_store_trainer_config" +_CUSTOM_TRAINER_CONFIG = "custom_trainer_config" _INFERENCER_CONFIG_FIELD = "inferencer_config" _VERTEX_AI_INFERENCER_CONFIG = "vertex_ai_inferencer_config" _DATAFLOW_INFERENCER_CONFIG = "dataflow_inferencer_config" _LOCAL_INFERENCER_CONFIG = "local_inferencer_config" _VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG = "vertex_ai_graph_store_inferencer_config" +_CUSTOM_INFERENCER_CONFIG = "custom_inferencer_config" @dataclass @@ -55,6 +58,7 @@ class GiglResourceConfigWrapper: KFPResourceConfig, LocalResourceConfig, VertexAiGraphStoreConfig, + CustomResourceConfig, ] ] = None _inference_config: Optional[ @@ -63,6 +67,7 @@ class GiglResourceConfigWrapper: VertexAiResourceConfig, 
LocalResourceConfig, VertexAiGraphStoreConfig, + CustomResourceConfig, ] ] = None @@ -283,9 +288,10 @@ def trainer_config( KFPResourceConfig, LocalResourceConfig, VertexAiGraphStoreConfig, + CustomResourceConfig, ]: """ - Returns the trainer config specified in the resource config. (e.g. Vertex AI, KFP, Local) + Returns the trainer config specified in the resource config. (e.g. Vertex AI, KFP, Local, Custom) """ if not self._trainer_config: @@ -305,6 +311,7 @@ def trainer_config( KFPResourceConfig, LocalResourceConfig, VertexAiGraphStoreConfig, + CustomResourceConfig, ] if ( deprecated_config.WhichOneof(_TRAINER_CONFIG_FIELD) # type: ignore[arg-type] @@ -365,6 +372,11 @@ def trainer_config( == _VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG ): _trainer_config = config.vertex_ai_graph_store_trainer_config + elif ( + config.WhichOneof(_TRAINER_CONFIG_FIELD) # type: ignore[arg-type] + == _CUSTOM_TRAINER_CONFIG + ): + _trainer_config = config.custom_trainer_config else: raise ValueError(f"Invalid trainer_config type: {config}") else: @@ -383,9 +395,10 @@ def inferencer_config( VertexAiResourceConfig, LocalResourceConfig, VertexAiGraphStoreConfig, + CustomResourceConfig, ]: """ - Returns the inferencer config specified in the resource config. (Dataflow) + Returns the inferencer config specified in the resource config. (e.g. 
Dataflow, Vertex AI, Local, Custom) """ if self._inference_config is None: # TODO: (svij) Marked for deprecation @@ -421,6 +434,11 @@ def inferencer_config( self._inference_config = ( config.vertex_ai_graph_store_inferencer_config ) + elif ( + config.WhichOneof(_INFERENCER_CONFIG_FIELD) # type: ignore[arg-type] + == _CUSTOM_INFERENCER_CONFIG + ): + self._inference_config = config.custom_inferencer_config else: raise ValueError("Invalid inferencer_config type") else: diff --git a/gigl/src/data_preprocessor/data_preprocessor.py b/gigl/src/data_preprocessor/data_preprocessor.py index 95d063aaf..49309c01c 100644 --- a/gigl/src/data_preprocessor/data_preprocessor.py +++ b/gigl/src/data_preprocessor/data_preprocessor.py @@ -352,6 +352,12 @@ def __build_data_reference_str(references: Iterable[DataReference]) -> str: edge_ref_to_preprocessing_spec ) + if num_dataflow_jobs == 0: + logger.info("No data references to preprocess; skipping Dataflow.") + return PreprocessedMetadataReferences( + node_data=node_refs_and_results, edge_data=edge_refs_and_results + ) + with concurrent.futures.ThreadPoolExecutor( max_workers=num_dataflow_jobs ) as executor: diff --git a/gigl/src/data_preprocessor/lib/enumerate/utils.py b/gigl/src/data_preprocessor/lib/enumerate/utils.py index b606d7edb..930854db9 100644 --- a/gigl/src/data_preprocessor/lib/enumerate/utils.py +++ b/gigl/src/data_preprocessor/lib/enumerate/utils.py @@ -247,6 +247,10 @@ def __enumerate_all_node_references( ) -> list[EnumeratorNodeTypeMetadata]: results: list[EnumeratorNodeTypeMetadata] = [] + if not node_data_references: + logger.info("No node references to enumerate; skipping.") + return results + logger.info( f"Launch {len(node_data_references)} node enumeration jobs in parallel." 
) @@ -274,6 +278,10 @@ def __enumerate_all_edge_references( ) -> list[EnumeratorEdgeTypeMetadata]: results: list[EnumeratorEdgeTypeMetadata] = [] + if not edge_data_references: + logger.info("No edge references to enumerate; skipping.") + return results + logger.info( f"Launch {len(edge_data_references)} edge enumeration jobs in parallel." ) diff --git a/gigl/src/inference/v2/glt_inferencer.py b/gigl/src/inference/v2/glt_inferencer.py index 1587828da..f9271f242 100644 --- a/gigl/src/inference/v2/glt_inferencer.py +++ b/gigl/src/inference/v2/glt_inferencer.py @@ -5,6 +5,7 @@ from gigl.common.logger import Logger from gigl.env.pipelines_config import get_resource_config from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.custom_launcher import launch_custom from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( @@ -16,6 +17,7 @@ launch_single_pool_job, ) from snapchat.research.gbml.gigl_resource_config_pb2 import ( + CustomResourceConfig, LocalResourceConfig, VertexAiGraphStoreConfig, VertexAiResourceConfig, @@ -90,6 +92,21 @@ def __execute_VAI_inference( cuda_docker_uri=cuda_docker_uri, component=GiGLComponents.Inferencer, ) + elif isinstance( + resource_config_wrapper.inferencer_config, CustomResourceConfig + ): + launch_custom( + custom_resource_config=resource_config_wrapper.inferencer_config, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + process_command=inference_process_command, + process_runtime_args=inference_process_runtime_args, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=GiGLComponents.Inferencer, + is_dry_run=False, + ) else: raise NotImplementedError( f"Unsupported resource config for glt inference: 
{type(resource_config_wrapper.inferencer_config).__name__}" @@ -112,10 +129,16 @@ def run( raise NotImplementedError( f"Local GLT Inferencer is not yet supported, please specify a {VertexAiResourceConfig.__name__} or {VertexAiGraphStoreConfig.__name__} resource config field." ) - elif isinstance( - resource_config_wrapper.inferencer_config, VertexAiResourceConfig - ) or isinstance( - resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig + elif ( + isinstance( + resource_config_wrapper.inferencer_config, VertexAiResourceConfig + ) + or isinstance( + resource_config_wrapper.inferencer_config, VertexAiGraphStoreConfig + ) + or isinstance( + resource_config_wrapper.inferencer_config, CustomResourceConfig + ) ): self.__execute_VAI_inference( applied_task_identifier=applied_task_identifier, diff --git a/gigl/src/training/v2/glt_trainer.py b/gigl/src/training/v2/glt_trainer.py index 2f8ecbbbe..d0047be25 100644 --- a/gigl/src/training/v2/glt_trainer.py +++ b/gigl/src/training/v2/glt_trainer.py @@ -5,6 +5,7 @@ from gigl.common.logger import Logger from gigl.env.pipelines_config import get_resource_config from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.custom_launcher import launch_custom from gigl.src.common.types import AppliedTaskIdentifier from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( @@ -16,6 +17,7 @@ launch_single_pool_job, ) from snapchat.research.gbml.gigl_resource_config_pb2 import ( + CustomResourceConfig, LocalResourceConfig, VertexAiGraphStoreConfig, VertexAiResourceConfig, @@ -86,6 +88,19 @@ def __execute_VAI_training( cuda_docker_uri=cuda_docker_uri, component=GiGLComponents.Trainer, ) + elif isinstance(resource_config.trainer_config, CustomResourceConfig): + launch_custom( + custom_resource_config=resource_config.trainer_config, + applied_task_identifier=applied_task_identifier, + 
task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + process_command=training_process_command, + process_runtime_args=training_process_runtime_args, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=GiGLComponents.Trainer, + is_dry_run=False, + ) else: raise NotImplementedError( f"Unsupported resource config for glt training: {type(resource_config.trainer_config).__name__}" @@ -110,8 +125,10 @@ def run( raise NotImplementedError( f"Local GLT Trainer is not yet supported, please specify a {VertexAiResourceConfig.__name__} or {VertexAiGraphStoreConfig.__name__} resource config field." ) - elif isinstance(trainer_config, VertexAiResourceConfig) or isinstance( - trainer_config, VertexAiGraphStoreConfig + elif ( + isinstance(trainer_config, VertexAiResourceConfig) + or isinstance(trainer_config, VertexAiGraphStoreConfig) + or isinstance(trainer_config, CustomResourceConfig) ): self.__execute_VAI_training( applied_task_identifier=applied_task_identifier, diff --git a/gigl/src/validation_check/config_validator.py b/gigl/src/validation_check/config_validator.py index ec0ca4caf..f1fda69d8 100644 --- a/gigl/src/validation_check/config_validator.py +++ b/gigl/src/validation_check/config_validator.py @@ -16,6 +16,7 @@ assert_trained_model_exists, ) from gigl.src.validation_check.libs.gbml_and_resource_config_compatibility_checks import ( + check_custom_resource_config_requires_glt_backend, check_inferencer_graph_store_compatibility, check_trainer_graph_store_compatibility, ) @@ -23,6 +24,7 @@ check_if_kfp_pipeline_job_name_valid, ) from gigl.src.validation_check.libs.resource_config_checks import ( + check_if_custom_resource_config_dry_run_valid, check_if_inferencer_resource_config_valid, check_if_preprocessor_resource_config_valid, check_if_shared_resource_config_valid, @@ -202,25 +204,31 @@ GiGLComponents.ConfigPopulator.value: [ check_trainer_graph_store_compatibility, check_inferencer_graph_store_compatibility, + 
check_custom_resource_config_requires_glt_backend, ], GiGLComponents.DataPreprocessor.value: [ check_trainer_graph_store_compatibility, check_inferencer_graph_store_compatibility, + check_custom_resource_config_requires_glt_backend, ], GiGLComponents.SubgraphSampler.value: [ check_trainer_graph_store_compatibility, check_inferencer_graph_store_compatibility, + check_custom_resource_config_requires_glt_backend, ], GiGLComponents.SplitGenerator.value: [ check_trainer_graph_store_compatibility, check_inferencer_graph_store_compatibility, + check_custom_resource_config_requires_glt_backend, ], GiGLComponents.Trainer.value: [ check_trainer_graph_store_compatibility, check_inferencer_graph_store_compatibility, + check_custom_resource_config_requires_glt_backend, ], GiGLComponents.Inferencer.value: [ check_inferencer_graph_store_compatibility, + check_custom_resource_config_requires_glt_backend, ], # PostProcessor doesn't need graph store compatibility checks } @@ -275,6 +283,9 @@ def kfp_validation_checks( start_at: str, resource_config_uri: Uri, stop_after: Optional[str] = None, + check_custom_launcher_dry_run: bool = False, + cpu_docker_uri: Optional[str] = None, + cuda_docker_uri: Optional[str] = None, ) -> None: # check if job_name is valid check_if_kfp_pipeline_job_name_valid(job_name=job_name) @@ -347,6 +358,22 @@ def kfp_validation_checks( resource_config_wrapper=resource_config_wrapper, ) + # Optional: invoke any CustomResourceConfig launchers in dry-run mode so + # they can validate their inputs without spawning remote jobs. Default is + # False because the dry-run submission may require LCA credentials that + # CI does not have; users opt in with --check_custom_launcher_dry_run. 
+ if check_custom_launcher_dry_run: + for component in (GiGLComponents.Trainer, GiGLComponents.Inferencer): + check_if_custom_resource_config_dry_run_valid( + resource_config_pb=resource_config_pb, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + applied_task_identifier=job_name, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=component, + ) + # check if trained model file exist when skipping training if gbml_config_pb.shared_config.should_skip_training == True: assert_trained_model_exists(gbml_config_pb=gbml_config_pb) @@ -383,6 +410,29 @@ def kfp_validation_checks( type=str, help="Runtime argument for resource and env specifications of each component", ) + parser.add_argument( + "--check_custom_launcher_dry_run", + action="store_true", + default=False, + help=( + "If set, invoke any CustomResourceConfig-backed trainer/inferencer " + "launcher with is_dry_run=True for early validation. Off by " + "default because the dry-run call may require LCA credentials " + "that CI does not have; authors run this locally with creds." 
+ ), + ) + parser.add_argument( + "--cpu_docker_uri", + type=str, + default=None, + help="Optional CPU Docker image URI forwarded to the custom launcher dry-run.", + ) + parser.add_argument( + "--cuda_docker_uri", + type=str, + default=None, + help="Optional CUDA Docker image URI forwarded to the custom launcher dry-run.", + ) args = parser.parse_args() kfp_validation_checks( @@ -391,4 +441,7 @@ def kfp_validation_checks( start_at=args.start_at, resource_config_uri=UriFactory.create_uri(args.resource_config_uri), stop_after=args.stop_after, + check_custom_launcher_dry_run=args.check_custom_launcher_dry_run, + cpu_docker_uri=args.cpu_docker_uri, + cuda_docker_uri=args.cuda_docker_uri, ) diff --git a/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py b/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py index fc12d1939..dcf22c4c9 100644 --- a/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py +++ b/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py @@ -135,3 +135,63 @@ def check_inferencer_graph_store_compatibility( raise AssertionError( f"If one of GbmlConfig.inferencer_config.graph_store_storage_config or GiglResourceConfig.inferencer_resource_config is set, the other must also be set. GbmlConfig.inferencer_config.graph_store_storage_config is set: {gbml_has_graph_store}, GiglResourceConfig.inferencer_resource_config is set: {resource_has_graph_store}." ) + + +def check_custom_resource_config_requires_glt_backend( + gbml_config_pb_wrapper: GbmlConfigPbWrapper, + resource_config_wrapper: GiglResourceConfigWrapper, +) -> None: + """Enforce that ``CustomResourceConfig`` is only used with the GLT (v2) backend. 
+ + The v1 trainer/inferencer dispatchers never consult the + ``custom_trainer_config`` / ``custom_inferencer_config`` oneof, so pairing + a ``CustomResourceConfig`` with a task config that has + ``should_use_glt_backend=False`` would silently fall through the v1 path + and fail at runtime. Catch it up-front here so the failure is loud and + actionable at validation time. + + Note on naming: the wrapper exposes ``should_use_glt_backend`` (bool) but + the raw YAML key users set is ``feature_flags.should_run_glt_backend``. + The wrapper translates one into the other; this check always reads the + wrapper property and never the raw map. + + Args: + gbml_config_pb_wrapper: The GbmlConfig wrapper (template config). + resource_config_wrapper: The GiglResourceConfig wrapper (resource config). + + Raises: + ValueError: If either the trainer or inferencer resource config is a + ``CustomResourceConfig`` and ``should_use_glt_backend`` is False. + """ + logger.info( + "Config validation check: CustomResourceConfig requires GLT (v2) backend." + ) + trainer_is_custom = isinstance( + resource_config_wrapper.trainer_config, + gigl_resource_config_pb2.CustomResourceConfig, + ) + inferencer_is_custom = isinstance( + resource_config_wrapper.inferencer_config, + gigl_resource_config_pb2.CustomResourceConfig, + ) + if not (trainer_is_custom or inferencer_is_custom): + return + + if not gbml_config_pb_wrapper.should_use_glt_backend: + offending: list[str] = [] + if trainer_is_custom: + offending.append("trainer_resource_config.custom_trainer_config") + if inferencer_is_custom: + offending.append("inferencer_resource_config.custom_inferencer_config") + raise ValueError( + "CustomResourceConfig is only wired into the GLT (v2) dispatchers " + "(glt_trainer.py / glt_inferencer.py); the v1 trainer/inferencer " + "never consult the custom oneof and would fall through to an " + "'Unsupported resource config' error at runtime. 
The following " + f"custom resource configs were set: {offending}, but the task " + "config has should_use_glt_backend=False (raw YAML key: " + "feature_flags.should_run_glt_backend). Either set " + "feature_flags.should_run_glt_backend='True' in the task config, " + "or replace the CustomResourceConfig with a built-in resource " + "config." + ) diff --git a/gigl/src/validation_check/libs/resource_config_checks.py b/gigl/src/validation_check/libs/resource_config_checks.py index 98a12a360..86d517ecb 100644 --- a/gigl/src/validation_check/libs/resource_config_checks.py +++ b/gigl/src/validation_check/libs/resource_config_checks.py @@ -1,8 +1,10 @@ -from typing import Final, Union +from typing import Final, Optional, Union from google.cloud.aiplatform_v1.types.accelerator_type import AcceleratorType +from gigl.common import Uri from gigl.common.logger import Logger +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( GiglResourceConfigWrapper, @@ -149,7 +151,15 @@ def check_if_trainer_resource_config_valid( gigl_resource_config_pb2.VertexAiResourceConfig, gigl_resource_config_pb2.KFPResourceConfig, gigl_resource_config_pb2.VertexAiGraphStoreConfig, + gigl_resource_config_pb2.CustomResourceConfig, ] = wrapper.trainer_config + if isinstance(trainer_config, gigl_resource_config_pb2.CustomResourceConfig): + logger.info( + "Skipping trainer machine-shape validation: trainer_config is a " + "CustomResourceConfig (launcher-pluggable; no concrete machine " + "spec to validate)." 
+ ) + return _validate_machine_config(config=trainer_config) @@ -163,6 +173,13 @@ def check_if_inferencer_resource_config_valid( resource_config=resource_config_pb ) inferencer_config = resource_config_wrapper.inferencer_config + if isinstance(inferencer_config, gigl_resource_config_pb2.CustomResourceConfig): + logger.info( + "Skipping inferencer machine-shape validation: inferencer_config " + "is a CustomResourceConfig (launcher-pluggable; no concrete " + "machine spec to validate)." + ) + return _validate_machine_config(config=inferencer_config) @@ -297,6 +314,99 @@ def _validate_machine_config( ) +def check_if_custom_resource_config_dry_run_valid( + resource_config_pb: gigl_resource_config_pb2.GiglResourceConfig, + task_config_uri: Uri, + resource_config_uri: Uri, + applied_task_identifier: str, + cpu_docker_uri: Optional[str], + cuda_docker_uri: Optional[str], + component: GiGLComponents, +) -> None: + """Invoke the custom launcher with ``is_dry_run=True`` for early validation. + + Resolves the component's resource config through the wrapper; if it is not + a ``CustomResourceConfig`` this helper is a no-op. Otherwise it dispatches + through ``launch_custom(..., is_dry_run=True)`` so the user-supplied + launcher can validate its inputs without actually spawning remote jobs. + + The import of ``launch_custom`` is intentionally lazy: the dry-run hook is + only reachable when the caller opts in via + ``--check_custom_launcher_dry_run``, and keeping the import inside the + function ensures ``assert_yaml_configs_parse`` (and other static config + validators) do not transitively pull in launcher-side dependencies (which + may be cluster-management clients such as a Ray platform SDK). + + Auth note: dry-run submission may call out to managed services that the + launcher integrates with; the submitter must have whatever credentials + those services require. See the custom launcher's own documentation for + specifics. 
+ + Args: + resource_config_pb: The resource config to inspect. The trainer or + inferencer oneof (depending on ``component``) is pulled out of the + wrapper and, if it resolves to ``CustomResourceConfig``, dispatched + to the launcher. + task_config_uri: URI of the GbmlConfig YAML. + resource_config_uri: URI of the GiglResourceConfig YAML. + applied_task_identifier: Stable identifier for the job. + cpu_docker_uri: Optional CPU Docker image URI forwarded to the launcher. + cuda_docker_uri: Optional CUDA Docker image URI forwarded to the launcher. + component: Which GiGL component to dry-run. Must be Trainer or + Inferencer; other components never carry a ``CustomResourceConfig``. + + Raises: + ValueError: If ``component`` is not Trainer or Inferencer. + """ + # Lazy import β€” assert_yaml_configs_parse must stay import-free of + # launcher-side deps (the resolved launcher may pull in a cluster SDK). + from gigl.src.common.custom_launcher import launch_custom + + if component not in {GiGLComponents.Trainer, GiGLComponents.Inferencer}: + raise ValueError( + f"check_if_custom_resource_config_dry_run_valid only supports " + f"Trainer and Inferencer components; got {component}." + ) + + wrapper = GiglResourceConfigWrapper(resource_config=resource_config_pb) + component_config: Union[ + gigl_resource_config_pb2.LocalResourceConfig, + gigl_resource_config_pb2.VertexAiResourceConfig, + gigl_resource_config_pb2.KFPResourceConfig, + gigl_resource_config_pb2.VertexAiGraphStoreConfig, + gigl_resource_config_pb2.DataflowResourceConfig, + gigl_resource_config_pb2.CustomResourceConfig, + ] + if component == GiGLComponents.Trainer: + component_config = wrapper.trainer_config + else: + component_config = wrapper.inferencer_config + + if not isinstance(component_config, gigl_resource_config_pb2.CustomResourceConfig): + logger.info( + f"Skipping custom-launcher dry-run for {component.value}: " + f"{type(component_config).__name__} is not a CustomResourceConfig." 
+ ) + return + + logger.info( + f"Invoking custom launcher dry-run for {component.value} via " + f"{component_config.command!r}." + ) + launch_custom( + custom_resource_config=component_config, + applied_task_identifier=applied_task_identifier, + task_config_uri=task_config_uri, + resource_config_uri=resource_config_uri, + process_command="", + process_runtime_args={}, + cpu_docker_uri=cpu_docker_uri, + cuda_docker_uri=cuda_docker_uri, + component=component, + is_dry_run=True, + ) + + def check_if_trainer_graph_store_storage_command_valid( gbml_config_pb_wrapper: GbmlConfigPbWrapper, ) -> None: diff --git a/proto/snapchat/research/gbml/gigl_resource_config.proto b/proto/snapchat/research/gbml/gigl_resource_config.proto index 0d930949b..fd59d9753 100644 --- a/proto/snapchat/research/gbml/gigl_resource_config.proto +++ b/proto/snapchat/research/gbml/gigl_resource_config.proto @@ -166,6 +166,25 @@ message VertexAiResourceConfig { // If unset, and no accelerators are available, will use 1. int32 compute_cluster_local_world_size = 3; } + +// Lets user-defined launchers be piped in. +// The launcher dispatcher invokes `command` (interpreted by /bin/sh -c so +// leading "KEY=VALUE" assignments parse as inline env vars) with `args` +// appended as positional arguments. String fields support OmegaConf +// `${gigl:}` substitutions, which the dispatcher resolves at exec +// time from the runtime context (task_config_uri, applied_task_identifier, +// component, etc.). +message CustomResourceConfig { + // Shell snippet invoked via /bin/sh -c. Leading "KEY=VALUE" assignments + // are honored by the shell, so callers can inline env vars (e.g. + // "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python -m my.cli"). + string command = 1; + // Positional arguments appended after the command. Each element is + // shell-quoted by the dispatcher so values containing spaces/quotes + // survive the shell pass. 
+ repeated string args = 2; +} + // (deprecated) // Configuration for distributed training resources message DistributedTrainerConfig { @@ -183,6 +202,7 @@ message TrainerResourceConfig { KFPResourceConfig kfp_trainer_config = 2; LocalResourceConfig local_trainer_config = 3; VertexAiGraphStoreConfig vertex_ai_graph_store_trainer_config = 4; + CustomResourceConfig custom_trainer_config = 5; } } @@ -193,6 +213,7 @@ message InferencerResourceConfig { DataflowResourceConfig dataflow_inferencer_config = 2; LocalResourceConfig local_inferencer_config = 3; VertexAiGraphStoreConfig vertex_ai_graph_store_inferencer_config = 4; + CustomResourceConfig custom_inferencer_config = 5; } } diff --git a/proto/snapchat/research/gbml/postprocessed_metadata.proto b/proto/snapchat/research/gbml/postprocessed_metadata.proto index c03b7caed..a093ecbb3 100644 --- a/proto/snapchat/research/gbml/postprocessed_metadata.proto +++ b/proto/snapchat/research/gbml/postprocessed_metadata.proto @@ -5,4 +5,4 @@ package snapchat.research.gbml; message PostProcessedMetadata{ // The path to the post processor evaluation results string post_processor_log_metrics_uri = 1; -} \ No newline at end of file +} diff --git a/proto/snapchat/research/gbml/subgraph_sampling_strategy.proto b/proto/snapchat/research/gbml/subgraph_sampling_strategy.proto index 6f55e457b..19fbc5665 100644 --- a/proto/snapchat/research/gbml/subgraph_sampling_strategy.proto +++ b/proto/snapchat/research/gbml/subgraph_sampling_strategy.proto @@ -4,7 +4,7 @@ package snapchat.research.gbml; import "snapchat/research/gbml/graph_schema.proto"; -message RandomUniform { // Randomly sample nodes from the neighborhood without replacement. +message RandomUniform { // Randomly sample nodes from the neighborhood without replacement. 
int32 num_nodes_to_sample = 1; } diff --git a/scala/.scalafix.conf b/scala/.scalafix.conf index 195ccd431..9e45edd17 100644 --- a/scala/.scalafix.conf +++ b/scala/.scalafix.conf @@ -1,7 +1,7 @@ rules = [ ExplicitResultTypes NoValInForComprehension, - OrganizeImports, + OrganizeImports, ProcedureSyntax, RedundantSyntax, ] diff --git a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala index 0cafea1ca..afe385a0d 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala @@ -38,7 +38,7 @@ final case class DatasetMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { outputMetadata.supervisedNodeClassificationDataset.foreach { __v => @@ -165,7 +165,7 @@ object DatasetMetadata extends scalapb.GeneratedMessageCompanion[snapchat.resear override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class SupervisedNodeClassificationDataset(value: snapchat.research.gbml.dataset_metadata.SupervisedNodeClassificationDataset) extends snapchat.research.gbml.dataset_metadata.DatasetMetadata.OutputMetadata { type ValueType = snapchat.research.gbml.dataset_metadata.SupervisedNodeClassificationDataset diff --git a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala index 7f2bfe943..5393b2b91 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala +++ 
b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala @@ -60,4 +60,4 @@ object DatasetMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala index 63cd8933f..44fd8509f 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala @@ -21,21 +21,21 @@ final case class NodeAnchorBasedLinkPredictionDataset( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainMainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = testMainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = valMainDataUri if (!__value.isEmpty) { @@ -64,7 +64,7 @@ final case class NodeAnchorBasedLinkPredictionDataset( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -250,14 +250,14 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = 
key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -274,7 +274,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -318,7 +318,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp def companion: snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry.type = snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry]) } - + object TrainNodeTypeToRandomNegativeDataUriEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry = { @@ -383,7 +383,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry]) } - + @SerialVersionUID(0L) final case class ValNodeTypeToRandomNegativeDataUriEntry( key: _root_.scala.Predef.String = "", @@ -394,14 +394,14 @@ object NodeAnchorBasedLinkPredictionDataset extends 
scalapb.GeneratedMessageComp private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -418,7 +418,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -462,7 +462,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp def companion: snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry.type = snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry]) } - + object ValNodeTypeToRandomNegativeDataUriEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry = { @@ -527,7 +527,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry]) } - + 
@SerialVersionUID(0L) final case class TestNodeTypeToRandomNegativeDataUriEntry( key: _root_.scala.Predef.String = "", @@ -538,14 +538,14 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -562,7 +562,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -606,7 +606,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp def companion: snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry.type = snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry]) } - + object TestNodeTypeToRandomNegativeDataUriEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry = { @@ -671,7 +671,7 @@ object NodeAnchorBasedLinkPredictionDataset extends 
scalapb.GeneratedMessageComp ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry]) } - + implicit class NodeAnchorBasedLinkPredictionDatasetLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset](_l) { def trainMainDataUri: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.trainMainDataUri)((c_, f_) => c_.copy(trainMainDataUri = f_)) def testMainDataUri: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.testMainDataUri)((c_, f_) => c_.copy(testMainDataUri = f_)) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala index 2aca6bc33..453be5b92 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala @@ -18,21 +18,21 @@ final case class SupervisedLinkBasedTaskSplitDataset( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = testDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = valDataUri if (!__value.isEmpty) { @@ -49,7 +49,7 @@ final case class SupervisedLinkBasedTaskSplitDataset( __serializedSizeMemoized = __size 
} __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala index ac5f71b11..2411d36d2 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala @@ -18,21 +18,21 @@ final case class SupervisedNodeClassificationDataset( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = testDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = valDataUri if (!__value.isEmpty) { @@ -49,7 +49,7 @@ final case class SupervisedNodeClassificationDataset( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala index 0395d2bd1..6f8ab4d22 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala @@ -38,7 +38,7 @@ final case class FlattenedGraphMetadata( __serializedSizeMemoized = __size } __size - 1 - + } 
def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { outputMetadata.supervisedNodeClassificationOutput.foreach { __v => @@ -165,7 +165,7 @@ object FlattenedGraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class SupervisedNodeClassificationOutput(value: snapchat.research.gbml.flattened_graph_metadata.SupervisedNodeClassificationOutput) extends snapchat.research.gbml.flattened_graph_metadata.FlattenedGraphMetadata.OutputMetadata { type ValueType = snapchat.research.gbml.flattened_graph_metadata.SupervisedNodeClassificationOutput diff --git a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala index b0102ad23..1018c5787 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala @@ -50,4 +50,4 @@ object FlattenedGraphMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala index c31af112a..375313e97 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala @@ -22,7 +22,7 @@ final case class NodeAnchorBasedLinkPredictionOutput( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = tfrecordUriPrefix if (!__value.isEmpty) { @@ -43,7 +43,7 @@ final case class NodeAnchorBasedLinkPredictionOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -153,14 +153,14 @@ object NodeAnchorBasedLinkPredictionOutput extends scalapb.GeneratedMessageCompa private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -177,7 +177,7 @@ object NodeAnchorBasedLinkPredictionOutput extends scalapb.GeneratedMessageCompa __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -221,7 +221,7 @@ object NodeAnchorBasedLinkPredictionOutput extends 
scalapb.GeneratedMessageCompa def companion: snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry.type = snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry]) } - + object NodeTypeToRandomNegativeTfrecordUriPrefixEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry = { @@ -286,7 +286,7 @@ object NodeAnchorBasedLinkPredictionOutput extends scalapb.GeneratedMessageCompa ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry]) } - + implicit class NodeAnchorBasedLinkPredictionOutputLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput](_l) { def tfrecordUriPrefix: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.tfrecordUriPrefix)((c_, f_) => c_.copy(tfrecordUriPrefix = f_)) def nodeTypeToRandomNegativeTfrecordUriPrefix: _root_.scalapb.lenses.Lens[UpperPB, 
_root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.nodeTypeToRandomNegativeTfrecordUriPrefix)((c_, f_) => c_.copy(nodeTypeToRandomNegativeTfrecordUriPrefix = f_)) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala index 8cfd94948..7af22ab07 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala @@ -20,14 +20,14 @@ final case class SupervisedLinkBasedTaskOutput( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = labeledTfrecordUriPrefix if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = unlabeledTfrecordUriPrefix if (!__value.isEmpty) { @@ -44,7 +44,7 @@ final case class SupervisedLinkBasedTaskOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala index cb3ed0bb3..538828f18 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala @@ -20,14 +20,14 @@ final case class SupervisedNodeClassificationOutput( 
private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = labeledTfrecordUriPrefix if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = unlabeledTfrecordUriPrefix if (!__value.isEmpty) { @@ -44,7 +44,7 @@ final case class SupervisedNodeClassificationOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala index 871d66432..849972498 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala @@ -24,63 +24,63 @@ sealed abstract class Component(val value: _root_.scala.Int) extends _root_.scal object Component extends _root_.scalapb.GeneratedEnumCompanion[Component] { sealed trait Recognized extends Component implicit def enumCompanion: _root_.scalapb.GeneratedEnumCompanion[Component] = this - + @SerialVersionUID(0L) case object Component_Unknown extends Component(0) with Component.Recognized { val index = 0 val name = "Component_Unknown" override def isComponentUnknown: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Config_Validator extends Component(1) with Component.Recognized { val index = 1 val name = "Component_Config_Validator" override def isComponentConfigValidator: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Config_Populator extends Component(2) with Component.Recognized { val index = 2 val name = "Component_Config_Populator" override def isComponentConfigPopulator: _root_.scala.Boolean = 
true } - + @SerialVersionUID(0L) case object Component_Data_Preprocessor extends Component(3) with Component.Recognized { val index = 3 val name = "Component_Data_Preprocessor" override def isComponentDataPreprocessor: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Subgraph_Sampler extends Component(4) with Component.Recognized { val index = 4 val name = "Component_Subgraph_Sampler" override def isComponentSubgraphSampler: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Split_Generator extends Component(5) with Component.Recognized { val index = 5 val name = "Component_Split_Generator" override def isComponentSplitGenerator: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Trainer extends Component(6) with Component.Recognized { val index = 6 val name = "Component_Trainer" override def isComponentTrainer: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Inferencer extends Component(7) with Component.Recognized { val index = 7 val name = "Component_Inferencer" override def isComponentInferencer: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) final case class Unrecognized(unrecognizedValue: _root_.scala.Int) extends Component(unrecognizedValue) with _root_.scalapb.UnrecognizedEnum lazy val values = scala.collection.immutable.Seq(Component_Unknown, Component_Config_Validator, Component_Config_Populator, Component_Data_Preprocessor, Component_Subgraph_Sampler, Component_Split_Generator, Component_Trainer, Component_Inferencer) @@ -97,4 +97,4 @@ object Component extends _root_.scalapb.GeneratedEnumCompanion[Component] { } def javaDescriptor: _root_.com.google.protobuf.Descriptors.EnumDescriptor = GiglResourceConfigProto.javaDescriptor.getEnumTypes().get(0) def scalaDescriptor: _root_.scalapb.descriptors.EnumDescriptor = GiglResourceConfigProto.scalaDescriptor.enums(0) -} \ No newline at end of file +} diff --git 
a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/CustomResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/CustomResourceConfig.scala new file mode 100644 index 000000000..33ca7b857 --- /dev/null +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/CustomResourceConfig.scala @@ -0,0 +1,159 @@ +// Generated by the Scala Plugin for the Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Lets user-defined launchers be piped in. + * The launcher dispatcher invokes `command` (interpreted by /bin/sh -c so + * leading "KEY=VALUE" assignments parse as inline env vars) with `args` + * appended as positional arguments. String fields support OmegaConf + * `${gigl:<key>}` substitutions, which the dispatcher resolves at exec + * time from the runtime context (task_config_uri, applied_task_identifier, + * component, etc.). + * + * @param command + * Shell snippet invoked via /bin/sh -c. Leading "KEY=VALUE" assignments + * are honored by the shell, so callers can inline env vars (e.g. + * "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python -m my.cli"). + * @param args + * Positional arguments appended after the command. Each element is + * shell-quoted by the dispatcher so values containing spaces/quotes + * survive the shell pass. 
+ */ +@SerialVersionUID(0L) +final case class CustomResourceConfig( + command: _root_.scala.Predef.String = "", + args: _root_.scala.Seq[_root_.scala.Predef.String] = _root_.scala.Seq.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[CustomResourceConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + + { + val __value = command + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) + } + }; + args.foreach { __item => + val __value = __item + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + { + val __v = command + if (!__v.isEmpty) { + _output__.writeString(1, __v) + } + }; + args.foreach { __v => + val __m = __v + _output__.writeString(2, __m) + }; + unknownFields.writeTo(_output__) + } + def withCommand(__v: _root_.scala.Predef.String): CustomResourceConfig = copy(command = __v) + def clearArgs = copy(args = _root_.scala.Seq.empty) + def addArgs(__vs: _root_.scala.Predef.String *): CustomResourceConfig = addAllArgs(__vs) + def addAllArgs(__vs: Iterable[_root_.scala.Predef.String]): CustomResourceConfig = copy(args = args ++ __vs) + def withArgs(__v: _root_.scala.Seq[_root_.scala.Predef.String]): CustomResourceConfig = copy(args = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = 
_root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => { + val __t = command + if (__t != "") __t else null + } + case 2 => args + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PString(command) + case 2 => _root_.scalapb.descriptors.PRepeated(args.iterator.map(_root_.scalapb.descriptors.PString(_)).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig.type = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.CustomResourceConfig]) +} + +object CustomResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = { + var __command: _root_.scala.Predef.String = "" + val __args: _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] = new _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __command = _input__.readStringRequireUtf8() + case 18 => + __args += _input__.readStringRequireUtf8() + 
case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command = __command, + args = __args.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), + args = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = throw new MatchError(__number) + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command = "", + args = 
_root_.scala.Seq.empty + ) + implicit class CustomResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig](_l) { + def command: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.command)((c_, f_) => c_.copy(command = f_)) + def args: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.args)((c_, f_) => c_.copy(args = f_)) + } + final val COMMAND_FIELD_NUMBER = 1 + final val ARGS_FIELD_NUMBER = 2 + def of( + command: _root_.scala.Predef.String, + args: _root_.scala.Seq[_root_.scala.Predef.String] + ): _root_.snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command, + args + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.CustomResourceConfig]) +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala index 75c2e54af..7d9951c80 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala @@ -35,7 +35,7 @@ final case class DataPreprocessorConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { edgePreprocessorConfig.foreach { __v => diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala index 
8363bdb1f..2198a2eb5 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala @@ -38,7 +38,7 @@ final case class DistributedInferencerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { trainerConfig.vertexAiInferencerConfig.foreach { __v => @@ -165,7 +165,7 @@ object DistributedInferencerConfig extends scalapb.GeneratedMessageCompanion[sna override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig) extends snapchat.research.gbml.gigl_resource_config.DistributedInferencerConfig.TrainerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala index 676b61794..60313b1cc 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala @@ -131,8 +131,8 @@ object DistributedTrainerConfig extends scalapb.GeneratedMessageCompanion[snapch ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11) + def 
javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala index d88d363e9..16ff1d6a6 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala @@ -275,8 +275,8 @@ object GiglResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.res ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(15) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(15) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(16) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(16) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala 
b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index a086f6113..603a940e4 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -20,6 +20,7 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { snapchat.research.gbml.gigl_resource_config.KFPResourceConfig, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig, + snapchat.research.gbml.gigl_resource_config.CustomResourceConfig, snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig, @@ -65,65 +66,70 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { AEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0IT4j8QEg5ncmFwaFN0b3JlUG9vbFIOZ 3JhcGhTdG9yZVBvb2wSYwoMY29tcHV0ZV9wb29sGAIgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291c mNlQ29uZmlnQhDiPw0SC2NvbXB1dGVQb29sUgtjb21wdXRlUG9vbBJpCiBjb21wdXRlX2NsdXN0ZXJfbG9jYWxfd29ybGRfc2l6Z - RgDIAEoBUIh4j8eEhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplUhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplIp0DC - hhEaXN0cmlidXRlZFRyYWluZXJDb25maWcShAEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzItLnNuYXBjaGF0LnJlc - 2VhcmNoLmdibWwuVmVydGV4QWlUcmFpbmVyQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyY - WluZXJDb25maWcSbwoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMiguc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBUcmFpbmVyQ - 29uZmlnQhXiPxISEGtmcFRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZxJ3ChRsb2NhbF90cmFpbmVyX2NvbmZpZxgDI - AEoCzIqLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxUcmFpbmVyQ29uZmlnQhfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAU - 
hJsb2NhbFRyYWluZXJDb25maWdCEAoOdHJhaW5lcl9jb25maWcixwQKFVRyYWluZXJSZXNvdXJjZUNvbmZpZxKFAQoYdmVydGV4X - 2FpX3RyYWluZXJfY29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhriP - xcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25maWcScAoSa2ZwX3RyYWluZXJfY29uZmlnGAIgA - SgLMikuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBSZXNvdXJjZUNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmc - FRyYWluZXJDb25maWcSeAoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsU - mVzb3VyY2VDb25maWdCF+I/FBISbG9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsVHJhaW5lckNvbmZpZxKnAQokdmVydGV4X2FpX - 2dyYXBoX3N0b3JlX3RyYWluZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3Rvc - mVDb25maWdCJOI/IRIfdmVydGV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0gAUh92ZXJ0ZXhBaUdyYXBoU3RvcmVUcmFpbmVyQ - 29uZmlnQhAKDnRyYWluZXJfY29uZmlnIocFChhJbmZlcmVuY2VyUmVzb3VyY2VDb25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY - 2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0Z - XhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnEo0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2Nvb - mZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvdXJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZ - mVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnEoEBChdsb2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoC - zIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZXNvdXJjZUNvbmZpZ0Ia4j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIA - FIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnErABCid2ZXJ0ZXhfYWlfZ3JhcGhfc3RvcmVfaW5mZXJlbmNlcl9jb25maWcYBCABKAsyM - C5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZ0In4j8kEiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJb - mZlcmVuY2VyQ29uZmlnSABSInZlcnRleEFpR3JhcGhTdG9yZUluZmVyZW5jZXJDb25maWdCEwoRaW5mZXJlbmNlcl9jb25maWcil - wgKFFNoYXJlZFJlc291cmNlQ29uZmlnEn4KD3Jlc291cmNlX2xhYmVscxgBIAMoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU - 2hhcmVkUmVzb3VyY2VDb25maWcuUmVzb3VyY2VMYWJlbHNFbnRyeUIT4j8QEg5yZXNvdXJjZUxhYmVsc1IOcmVzb3VyY2VMYWJlb - 
HMSjgEKFWNvbW1vbl9jb21wdXRlX2NvbmZpZxgCIAEoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb - 25maWcuQ29tbW9uQ29tcHV0ZUNvbmZpZ0IY4j8VEhNjb21tb25Db21wdXRlQ29uZmlnUhNjb21tb25Db21wdXRlQ29uZmlnGpQFC - hNDb21tb25Db21wdXRlQ29uZmlnEiYKB3Byb2plY3QYASABKAlCDOI/CRIHcHJvamVjdFIHcHJvamVjdBIjCgZyZWdpb24YAiABK - AlCC+I/CBIGcmVnaW9uUgZyZWdpb24SQwoSdGVtcF9hc3NldHNfYnVja2V0GAMgASgJQhXiPxISEHRlbXBBc3NldHNCdWNrZXRSE - HRlbXBBc3NldHNCdWNrZXQSXAobdGVtcF9yZWdpb25hbF9hc3NldHNfYnVja2V0GAQgASgJQh3iPxoSGHRlbXBSZWdpb25hbEFzc - 2V0c0J1Y2tldFIYdGVtcFJlZ2lvbmFsQXNzZXRzQnVja2V0EkMKEnBlcm1fYXNzZXRzX2J1Y2tldBgFIAEoCUIV4j8SEhBwZXJtQ - XNzZXRzQnVja2V0UhBwZXJtQXNzZXRzQnVja2V0EloKG3RlbXBfYXNzZXRzX2JxX2RhdGFzZXRfbmFtZRgGIAEoCUIc4j8ZEhd0Z - W1wQXNzZXRzQnFEYXRhc2V0TmFtZVIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWUSVgoZZW1iZWRkaW5nX2JxX2RhdGFzZXRfbmFtZ - RgHIAEoCUIb4j8YEhZlbWJlZGRpbmdCcURhdGFzZXROYW1lUhZlbWJlZGRpbmdCcURhdGFzZXROYW1lElYKGWdjcF9zZXJ2aWNlX - 2FjY291bnRfZW1haWwYCCABKAlCG+I/GBIWZ2NwU2VydmljZUFjY291bnRFbWFpbFIWZ2NwU2VydmljZUFjY291bnRFbWFpbBI8C - g9kYXRhZmxvd19ydW5uZXIYCyABKAlCE+I/EBIOZGF0YWZsb3dSdW5uZXJSDmRhdGFmbG93UnVubmVyGlcKE1Jlc291cmNlTGFiZ - WxzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEi9 - wgKEkdpZ2xSZXNvdXJjZUNvbmZpZxJbChpzaGFyZWRfcmVzb3VyY2VfY29uZmlnX3VyaRgBIAEoCUIc4j8ZEhdzaGFyZWRSZXNvd - XJjZUNvbmZpZ1VyaUgAUhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaRJ/ChZzaGFyZWRfcmVzb3VyY2VfY29uZmlnGAIgASgLMiwuc - 25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZ0IZ4j8WEhRzaGFyZWRSZXNvdXJjZUNvbmZpZ0gAUhRza - GFyZWRSZXNvdXJjZUNvbmZpZxJ4ChNwcmVwcm9jZXNzb3JfY29uZmlnGAwgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EY - XRhUHJlcHJvY2Vzc29yQ29uZmlnQhfiPxQSEnByZXByb2Nlc3NvckNvbmZpZ1IScHJlcHJvY2Vzc29yQ29uZmlnEn8KF3N1YmdyY - XBoX3NhbXBsZXJfY29uZmlnGA0gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQhriPxcSF - XN1YmdyYXBoU2FtcGxlckNvbmZpZ1IVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnEnwKFnNwbGl0X2dlbmVyYXRvcl9jb25maWcYDiABK - 
AsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCGeI/FhIUc3BsaXRHZW5lcmF0b3JDb25maWdSF - HNwbGl0R2VuZXJhdG9yQ29uZmlnEm0KDnRyYWluZXJfY29uZmlnGA8gASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EaXN0c - mlidXRlZFRyYWluZXJDb25maWdCFBgB4j8PEg10cmFpbmVyQ29uZmlnUg10cmFpbmVyQ29uZmlnEnQKEWluZmVyZW5jZXJfY29uZ - mlnGBAgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291cmNlQ29uZmlnQhcYAeI/EhIQaW5mZXJlbmNlc - kNvbmZpZ1IQaW5mZXJlbmNlckNvbmZpZxKBAQoXdHJhaW5lcl9yZXNvdXJjZV9jb25maWcYESABKAsyLS5zbmFwY2hhdC5yZXNlY - XJjaC5nYm1sLlRyYWluZXJSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV0cmFpbmVyUmVzb3VyY2VDb25maWdSFXRyYWluZXJSZXNvdXJjZ - UNvbmZpZxKNAQoaaW5mZXJlbmNlcl9yZXNvdXJjZV9jb25maWcYEiABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkluZmVyZ - W5jZXJSZXNvdXJjZUNvbmZpZ0Id4j8aEhhpbmZlcmVuY2VyUmVzb3VyY2VDb25maWdSGGluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ - 0IRCg9zaGFyZWRfcmVzb3VyY2Uq4wMKCUNvbXBvbmVudBItChFDb21wb25lbnRfVW5rbm93bhAAGhbiPxMSEUNvbXBvbmVudF9Vb - mtub3duEj8KGkNvbXBvbmVudF9Db25maWdfVmFsaWRhdG9yEAEaH+I/HBIaQ29tcG9uZW50X0NvbmZpZ19WYWxpZGF0b3ISPwoaQ - 29tcG9uZW50X0NvbmZpZ19Qb3B1bGF0b3IQAhof4j8cEhpDb21wb25lbnRfQ29uZmlnX1BvcHVsYXRvchJBChtDb21wb25lbnRfR - GF0YV9QcmVwcm9jZXNzb3IQAxog4j8dEhtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3ISPwoaQ29tcG9uZW50X1N1YmdyYXBoX - 1NhbXBsZXIQBBof4j8cEhpDb21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchI9ChlDb21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEAUaH - uI/GxIZQ29tcG9uZW50X1NwbGl0X0dlbmVyYXRvchItChFDb21wb25lbnRfVHJhaW5lchAGGhbiPxMSEUNvbXBvbmVudF9UcmFpb - mVyEjMKFENvbXBvbmVudF9JbmZlcmVuY2VyEAcaGeI/FhIUQ29tcG9uZW50X0luZmVyZW5jZXJiBnByb3RvMw==""" + RgDIAEoBUIh4j8eEhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplUhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplIl0KF + EN1c3RvbVJlc291cmNlQ29uZmlnEiYKB2NvbW1hbmQYASABKAlCDOI/CRIHY29tbWFuZFIHY29tbWFuZBIdCgRhcmdzGAIgAygJQ + gniPwYSBGFyZ3NSBGFyZ3MinQMKGERpc3RyaWJ1dGVkVHJhaW5lckNvbmZpZxKEAQoYdmVydGV4X2FpX3RyYWluZXJfY29uZmlnG + AEgASgLMi0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVRyYWluZXJDb25maWdCGuI/FxIVdmVydGV4QWlUcmFpbmVyQ + 
29uZmlnSABSFXZlcnRleEFpVHJhaW5lckNvbmZpZxJvChJrZnBfdHJhaW5lcl9jb25maWcYAiABKAsyKC5zbmFwY2hhdC5yZXNlY + XJjaC5nYm1sLktGUFRyYWluZXJDb25maWdCFeI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEncKFGxvY + 2FsX3RyYWluZXJfY29uZmlnGAMgASgLMiouc25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFRyYWluZXJDb25maWdCF+I/FBISb + G9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsVHJhaW5lckNvbmZpZ0IQCg50cmFpbmVyX2NvbmZpZyLFBQoVVHJhaW5lclJlc291c + mNlQ29uZmlnEoUBChh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25maWcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRle + EFpUmVzb3VyY2VDb25maWdCGuI/FxIVdmVydGV4QWlUcmFpbmVyQ29uZmlnSABSFXZlcnRleEFpVHJhaW5lckNvbmZpZxJwChJrZ + nBfdHJhaW5lcl9jb25maWcYAiABKAsyKS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLktGUFJlc291cmNlQ29uZmlnQhXiPxISEGtmc + FRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZxJ4ChRsb2NhbF90cmFpbmVyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0L + nJlc2VhcmNoLmdibWwuTG9jYWxSZXNvdXJjZUNvbmZpZ0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ + 29uZmlnEqcBCiR2ZXJ0ZXhfYWlfZ3JhcGhfc3RvcmVfdHJhaW5lcl9jb25maWcYBCABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nY + m1sLlZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZ0Ik4j8hEh92ZXJ0ZXhBaUdyYXBoU3RvcmVUcmFpbmVyQ29uZmlnSABSH3ZlcnRle + EFpR3JhcGhTdG9yZVRyYWluZXJDb25maWcSfAoVY3VzdG9tX3RyYWluZXJfY29uZmlnGAUgASgLMiwuc25hcGNoYXQucmVzZWFyY + 2guZ2JtbC5DdXN0b21SZXNvdXJjZUNvbmZpZ0IY4j8VEhNjdXN0b21UcmFpbmVyQ29uZmlnSABSE2N1c3RvbVRyYWluZXJDb25ma + WdCEAoOdHJhaW5lcl9jb25maWcijwYKGEluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZxKOAQobdmVydGV4X2FpX2luZmVyZW5jZXJfY + 29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQh3iPxoSGHZlcnRleEFpS + W5mZXJlbmNlckNvbmZpZ0gAUhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWcSjQEKGmRhdGFmbG93X2luZmVyZW5jZXJfY29uZmlnG + AIgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291cmNlQ29uZmlnQh3iPxoSGGRhdGFmbG93SW5mZXJlb + mNlckNvbmZpZ0gAUhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWcSgQEKF2xvY2FsX2luZmVyZW5jZXJfY29uZmlnGAMgASgLMisuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhriPxcSFWxvY2FsSW5mZXJlbmNlckNvbmZpZ0gAUhVsb + 
2NhbEluZmVyZW5jZXJDb25maWcSsAEKJ3ZlcnRleF9haV9ncmFwaF9zdG9yZV9pbmZlcmVuY2VyX2NvbmZpZxgEIAEoCzIwLnNuY + XBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlHcmFwaFN0b3JlQ29uZmlnQifiPyQSInZlcnRleEFpR3JhcGhTdG9yZUluZmVyZ + W5jZXJDb25maWdIAFIidmVydGV4QWlHcmFwaFN0b3JlSW5mZXJlbmNlckNvbmZpZxKFAQoYY3VzdG9tX2luZmVyZW5jZXJfY29uZ + mlnGAUgASgLMiwuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5DdXN0b21SZXNvdXJjZUNvbmZpZ0Ib4j8YEhZjdXN0b21JbmZlcmVuY + 2VyQ29uZmlnSABSFmN1c3RvbUluZmVyZW5jZXJDb25maWdCEwoRaW5mZXJlbmNlcl9jb25maWcilwgKFFNoYXJlZFJlc291cmNlQ + 29uZmlnEn4KD3Jlc291cmNlX2xhYmVscxgBIAMoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25ma + WcuUmVzb3VyY2VMYWJlbHNFbnRyeUIT4j8QEg5yZXNvdXJjZUxhYmVsc1IOcmVzb3VyY2VMYWJlbHMSjgEKFWNvbW1vbl9jb21wd + XRlX2NvbmZpZxgCIAEoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWcuQ29tbW9uQ29tcHV0Z + UNvbmZpZ0IY4j8VEhNjb21tb25Db21wdXRlQ29uZmlnUhNjb21tb25Db21wdXRlQ29uZmlnGpQFChNDb21tb25Db21wdXRlQ29uZ + mlnEiYKB3Byb2plY3QYASABKAlCDOI/CRIHcHJvamVjdFIHcHJvamVjdBIjCgZyZWdpb24YAiABKAlCC+I/CBIGcmVnaW9uUgZyZ + Wdpb24SQwoSdGVtcF9hc3NldHNfYnVja2V0GAMgASgJQhXiPxISEHRlbXBBc3NldHNCdWNrZXRSEHRlbXBBc3NldHNCdWNrZXQSX + AobdGVtcF9yZWdpb25hbF9hc3NldHNfYnVja2V0GAQgASgJQh3iPxoSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldFIYdGVtcFJlZ + 2lvbmFsQXNzZXRzQnVja2V0EkMKEnBlcm1fYXNzZXRzX2J1Y2tldBgFIAEoCUIV4j8SEhBwZXJtQXNzZXRzQnVja2V0UhBwZXJtQ + XNzZXRzQnVja2V0EloKG3RlbXBfYXNzZXRzX2JxX2RhdGFzZXRfbmFtZRgGIAEoCUIc4j8ZEhd0ZW1wQXNzZXRzQnFEYXRhc2V0T + mFtZVIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWUSVgoZZW1iZWRkaW5nX2JxX2RhdGFzZXRfbmFtZRgHIAEoCUIb4j8YEhZlbWJlZ + GRpbmdCcURhdGFzZXROYW1lUhZlbWJlZGRpbmdCcURhdGFzZXROYW1lElYKGWdjcF9zZXJ2aWNlX2FjY291bnRfZW1haWwYCCABK + AlCG+I/GBIWZ2NwU2VydmljZUFjY291bnRFbWFpbFIWZ2NwU2VydmljZUFjY291bnRFbWFpbBI8Cg9kYXRhZmxvd19ydW5uZXIYC + yABKAlCE+I/EBIOZGF0YWZsb3dSdW5uZXJSDmRhdGFmbG93UnVubmVyGlcKE1Jlc291cmNlTGFiZWxzRW50cnkSGgoDa2V5GAEgA + SgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEi9wgKEkdpZ2xSZXNvdXJjZUNvb + 
mZpZxJbChpzaGFyZWRfcmVzb3VyY2VfY29uZmlnX3VyaRgBIAEoCUIc4j8ZEhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaUgAUhdza + GFyZWRSZXNvdXJjZUNvbmZpZ1VyaRJ/ChZzaGFyZWRfcmVzb3VyY2VfY29uZmlnGAIgASgLMiwuc25hcGNoYXQucmVzZWFyY2guZ + 2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZ0IZ4j8WEhRzaGFyZWRSZXNvdXJjZUNvbmZpZ0gAUhRzaGFyZWRSZXNvdXJjZUNvbmZpZ + xJ4ChNwcmVwcm9jZXNzb3JfY29uZmlnGAwgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhUHJlcHJvY2Vzc29yQ29uZ + mlnQhfiPxQSEnByZXByb2Nlc3NvckNvbmZpZ1IScHJlcHJvY2Vzc29yQ29uZmlnEn8KF3N1YmdyYXBoX3NhbXBsZXJfY29uZmlnG + A0gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQhriPxcSFXN1YmdyYXBoU2FtcGxlckNvb + mZpZ1IVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnEnwKFnNwbGl0X2dlbmVyYXRvcl9jb25maWcYDiABKAsyKy5zbmFwY2hhdC5yZXNlY + XJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCGeI/FhIUc3BsaXRHZW5lcmF0b3JDb25maWdSFHNwbGl0R2VuZXJhdG9yQ29uZ + mlnEm0KDnRyYWluZXJfY29uZmlnGA8gASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EaXN0cmlidXRlZFRyYWluZXJDb25ma + WdCFBgB4j8PEg10cmFpbmVyQ29uZmlnUg10cmFpbmVyQ29uZmlnEnQKEWluZmVyZW5jZXJfY29uZmlnGBAgASgLMi4uc25hcGNoY + XQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291cmNlQ29uZmlnQhcYAeI/EhIQaW5mZXJlbmNlckNvbmZpZ1IQaW5mZXJlbmNlc + kNvbmZpZxKBAQoXdHJhaW5lcl9yZXNvdXJjZV9jb25maWcYESABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlRyYWluZXJSZ + XNvdXJjZUNvbmZpZ0Ia4j8XEhV0cmFpbmVyUmVzb3VyY2VDb25maWdSFXRyYWluZXJSZXNvdXJjZUNvbmZpZxKNAQoaaW5mZXJlb + mNlcl9yZXNvdXJjZV9jb25maWcYEiABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ + 0Id4j8aEhhpbmZlcmVuY2VyUmVzb3VyY2VDb25maWdSGGluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ0IRCg9zaGFyZWRfcmVzb3VyY + 2Uq4wMKCUNvbXBvbmVudBItChFDb21wb25lbnRfVW5rbm93bhAAGhbiPxMSEUNvbXBvbmVudF9Vbmtub3duEj8KGkNvbXBvbmVud + F9Db25maWdfVmFsaWRhdG9yEAEaH+I/HBIaQ29tcG9uZW50X0NvbmZpZ19WYWxpZGF0b3ISPwoaQ29tcG9uZW50X0NvbmZpZ19Qb + 3B1bGF0b3IQAhof4j8cEhpDb21wb25lbnRfQ29uZmlnX1BvcHVsYXRvchJBChtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3IQA + xog4j8dEhtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3ISPwoaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXIQBBof4j8cEhpDb + 
21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchI9ChlDb21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEAUaHuI/GxIZQ29tcG9uZW50X1Nwb + Gl0X0dlbmVyYXRvchItChFDb21wb25lbnRfVHJhaW5lchAGGhbiPxMSEUNvbXBvbmVudF9UcmFpbmVyEjMKFENvbXBvbmVudF9Jb + mZlcmVuY2VyEAcaGeI/FhIUQ29tcG9uZW50X0luZmVyZW5jZXJiBnByb3RvMw==""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala index 77a949c19..dd637b565 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala @@ -32,6 +32,10 @@ final case class InferencerResourceConfig( val __value = inferencerConfig.vertexAiGraphStoreInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (inferencerConfig.customInferencerConfig.isDefined) { + val __value = inferencerConfig.customInferencerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -69,6 +73,12 @@ final case class InferencerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + inferencerConfig.customInferencerConfig.foreach { __v => + val __m = __v + _output__.writeTag(5, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = 
inferencerConfig.vertexAiInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -79,6 +89,8 @@ final case class InferencerResourceConfig( def withLocalInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__v)) def getVertexAiGraphStoreInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = inferencerConfig.vertexAiGraphStoreInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) def withVertexAiGraphStoreInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__v)) + def getCustomInferencerConfig: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = inferencerConfig.customInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.CustomResourceConfig.defaultInstance) + def withCustomInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(__v)) def clearInferencerConfig: InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) def withInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig): InferencerResourceConfig = copy(inferencerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -89,6 +101,7 @@ final case class 
InferencerResourceConfig( case 2 => inferencerConfig.dataflowInferencerConfig.orNull case 3 => inferencerConfig.localInferencerConfig.orNull case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.orNull + case 5 => inferencerConfig.customInferencerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -98,6 +111,7 @@ final case class InferencerResourceConfig( case 2 => inferencerConfig.dataflowInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => inferencerConfig.localInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 5 => inferencerConfig.customInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -123,6 +137,8 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__inferencerConfig.localInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__inferencerConfig.vertexAiGraphStoreInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 42 => + __inferencerConfig = 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(__inferencerConfig.customInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -143,12 +159,13 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(_))) + 
.orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(5).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(14) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(14) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -156,6 +173,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch case 2 => __out = snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + case 5 => __out = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig } __out } @@ -171,10 +189,12 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def isDataflowInferencerConfig: _root_.scala.Boolean = false def isLocalInferencerConfig: _root_.scala.Boolean = false def isVertexAiGraphStoreInferencerConfig: 
_root_.scala.Boolean = false + def isCustomInferencerConfig: _root_.scala.Boolean = false def vertexAiInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def dataflowInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = _root_.scala.None def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None + def customInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = _root_.scala.None } object InferencerConfig { @SerialVersionUID(0L) @@ -214,18 +234,27 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) override def number: _root_.scala.Int = 4 } + @SerialVersionUID(0L) + final case class CustomInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig + override def isCustomInferencerConfig: _root_.scala.Boolean = true + override def customInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = Some(value) + override def number: _root_.scala.Int = 5 + } } implicit class InferencerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig](_l) 
{ def vertexAiInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(f_))) def dataflowInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = field(_.getDataflowInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(f_))) def localInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(f_))) def vertexAiGraphStoreInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(f_))) + def customInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = field(_.getCustomInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(f_))) def inferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig] = field(_.inferencerConfig)((c_, f_) => c_.copy(inferencerConfig = f_)) } final val VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER = 1 final val DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER = 2 final val 
LOCAL_INFERENCER_CONFIG_FIELD_NUMBER = 3 final val VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG_FIELD_NUMBER = 4 + final val CUSTOM_INFERENCER_CONFIG_FIELD_NUMBER = 5 def of( inferencerConfig: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig( diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala index 909ec979b..1225ba210 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala @@ -32,35 +32,35 @@ final case class KFPTrainerConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = cpuRequest if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = memoryRequest if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = gpuType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { val __value = gpuLimit if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(4, __value) } }; - + { val __value = numReplicas if (__value != 0) { @@ -77,7 +77,7 @@ final case class KFPTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git 
a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala index ba2cc9389..86e238074 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala @@ -17,7 +17,7 @@ final case class LocalTrainerConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numWorkers if (__value != 0) { @@ -34,7 +34,7 @@ final case class LocalTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala index 393ebe301..bdeda8bdb 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala @@ -116,8 +116,8 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(14) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(14) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(15) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = 
GiglResourceConfigProto.scalaDescriptor.messages(15) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala index d32c915cb..96f354f47 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala @@ -25,21 +25,21 @@ final case class SparkResourceConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = machineType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = numLocalSsds if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(2, __value) } }; - + { val __value = numReplicas if (__value != 0) { @@ -56,7 +56,7 @@ final case class SparkResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala index 4249c27fe..2323c5f45 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala @@ -32,6 +32,10 @@ final case class TrainerResourceConfig( val 
__value = trainerConfig.vertexAiGraphStoreTrainerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (trainerConfig.customTrainerConfig.isDefined) { + val __value = trainerConfig.customTrainerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -69,6 +73,12 @@ final case class TrainerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + trainerConfig.customTrainerConfig.foreach { __v => + val __m = __v + _output__.writeTag(5, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = trainerConfig.vertexAiTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -79,6 +89,8 @@ final case class TrainerResourceConfig( def withLocalTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__v)) def getVertexAiGraphStoreTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = trainerConfig.vertexAiGraphStoreTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) def withVertexAiGraphStoreTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__v)) + def getCustomTrainerConfig: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = 
trainerConfig.customTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.CustomResourceConfig.defaultInstance) + def withCustomTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(__v)) def clearTrainerConfig: TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) def withTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig): TrainerResourceConfig = copy(trainerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -89,6 +101,7 @@ final case class TrainerResourceConfig( case 2 => trainerConfig.kfpTrainerConfig.orNull case 3 => trainerConfig.localTrainerConfig.orNull case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.orNull + case 5 => trainerConfig.customTrainerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -98,6 +111,7 @@ final case class TrainerResourceConfig( case 2 => trainerConfig.kfpTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => trainerConfig.localTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 5 => trainerConfig.customTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -123,6 +137,8 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
__trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__trainerConfig.localTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__trainerConfig.vertexAiGraphStoreTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 42 => + __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(__trainerConfig.customTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -143,12 +159,13 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
.orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(5).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) + def 
scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -156,6 +173,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. case 2 => __out = snapchat.research.gbml.gigl_resource_config.KFPResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + case 5 => __out = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig } __out } @@ -171,10 +189,12 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. def isKfpTrainerConfig: _root_.scala.Boolean = false def isLocalTrainerConfig: _root_.scala.Boolean = false def isVertexAiGraphStoreTrainerConfig: _root_.scala.Boolean = false + def isCustomTrainerConfig: _root_.scala.Boolean = false def vertexAiTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def kfpTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = _root_.scala.None def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None + def customTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = _root_.scala.None } object TrainerConfig { @SerialVersionUID(0L) @@ -214,18 +234,27 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
override def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) override def number: _root_.scala.Int = 4 } + @SerialVersionUID(0L) + final case class CustomTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig + override def isCustomTrainerConfig: _root_.scala.Boolean = true + override def customTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = Some(value) + override def number: _root_.scala.Int = 5 + } } implicit class TrainerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig](_l) { def vertexAiTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(f_))) def kfpTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = field(_.getKfpTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(f_))) def localTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(f_))) def vertexAiGraphStoreTrainerConfig: 
_root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(f_))) + def customTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = field(_.getCustomTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(f_))) def trainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig] = field(_.trainerConfig)((c_, f_) => c_.copy(trainerConfig = f_)) } final val VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER = 1 final val KFP_TRAINER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_TRAINER_CONFIG_FIELD_NUMBER = 3 final val VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG_FIELD_NUMBER = 4 + final val CUSTOM_TRAINER_CONFIG_FIELD_NUMBER = 5 def of( trainerConfig: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig( diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala index c088bafc2..37d730799 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala @@ -29,28 +29,28 @@ final case class VertexAiTrainerConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - 
+ { val __value = machineType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = gpuType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = gpuLimit if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(3, __value) } }; - + { val __value = numReplicas if (__value != 0) { @@ -67,7 +67,7 @@ final case class VertexAiTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala index cd1501b1f..dc359d1e3 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala @@ -33,14 +33,14 @@ final case class Edge( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = srcNodeId if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(1, __value) } }; - + { val __value = dstNodeId if (__value != 0) { @@ -65,7 +65,7 @@ final case class Edge( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala index d2a68d8b9..439ae2229 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala @@ 
-21,21 +21,21 @@ final case class EdgeType( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = relation if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = srcNodeType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = dstNodeType if (!__value.isEmpty) { @@ -52,7 +52,7 @@ final case class EdgeType( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala index 0e4e3105d..7c317bb6d 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala @@ -35,7 +35,7 @@ final case class Graph( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { nodes.foreach { __v => diff --git a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala index 8c6307580..856a159a0 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala @@ -58,7 +58,7 @@ final case class GraphMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { nodeTypes.foreach { __v => @@ -205,7 +205,7 @@ object GraphMetadata extends 
scalapb.GeneratedMessageCompanion[snapchat.research private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { @@ -226,7 +226,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -269,7 +269,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research def companion: snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry.type = snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GraphMetadata.CondensedEdgeTypeMapEntry]) } - + object CondensedEdgeTypeMapEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry = { @@ -341,7 +341,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GraphMetadata.CondensedEdgeTypeMapEntry]) } - + @SerialVersionUID(0L) final case class CondensedNodeTypeMapEntry( key: _root_.scala.Int = 0, @@ -352,14 +352,14 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { __size += 
_root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -376,7 +376,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -420,7 +420,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research def companion: snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry.type = snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GraphMetadata.CondensedNodeTypeMapEntry]) } - + object CondensedNodeTypeMapEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry = { @@ -485,7 +485,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GraphMetadata.CondensedNodeTypeMapEntry]) } - + implicit class GraphMetadataLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.graph_schema.GraphMetadata]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.graph_schema.GraphMetadata](_l) { def nodeTypes: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.nodeTypes)((c_, f_) => c_.copy(nodeTypes = f_)) def edgeTypes: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[snapchat.research.gbml.graph_schema.EdgeType]] = field(_.edgeTypes)((c_, f_) => 
c_.copy(edgeTypes = f_)) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala index 28c2eff67..0cd9e60ca 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala @@ -50,4 +50,4 @@ object GraphSchemaProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala index e8a4c6f98..7e4c5c8b9 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala @@ -27,7 +27,7 @@ final case class Node( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = nodeId if (__value != 0) { @@ -52,7 +52,7 @@ final case class Node( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala b/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala index 6668d67ce..a58d891c0 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala @@ -31,7 +31,7 @@ final 
case class InferenceMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { nodeTypeToInferencerOutputInfoMap.foreach { __v => @@ -123,7 +123,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { @@ -144,7 +144,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -187,7 +187,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese def companion: snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry.type = snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry]) } - + object NodeTypeToInferencerOutputInfoMapEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry = { @@ -259,7 +259,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese ) // 
@@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry]) } - + implicit class InferenceMetadataLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.inference_metadata.InferenceMetadata]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.inference_metadata.InferenceMetadata](_l) { def nodeTypeToInferencerOutputInfoMap: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, snapchat.research.gbml.inference_metadata.InferenceOutput]] = field(_.nodeTypeToInferencerOutputInfoMap)((c_, f_) => c_.copy(nodeTypeToInferencerOutputInfoMap = f_)) } diff --git a/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala index 7e1424d6e..9c335401b 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala @@ -35,4 +35,4 @@ object InferenceMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala b/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala index 845fbb29c..66c68b47b 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala @@ -38,7 +38,7 @@ final case class InferenceOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { embeddingsPath.foreach { __v => diff --git a/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala b/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala index a0399d909..8ebed0a48 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala @@ -17,7 +17,7 @@ final case class PostProcessedMetadata( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = postProcessorLogMetricsUri if (!__value.isEmpty) { @@ -34,7 +34,7 @@ final case class PostProcessedMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala 
b/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala index 4b0e94597..d36401685 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala @@ -28,4 +28,4 @@ object PostprocessedMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala b/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala index 6a160cdd4..80160636b 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala @@ -38,7 +38,7 @@ final case class PreprocessedMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { condensedNodeTypeToPreprocessedMetadata.foreach { __v => @@ -181,7 +181,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = nodeIdKey if (!__value.isEmpty) { @@ -196,28 +196,28 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = __item __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } - + { val __value = tfrecordUriPrefix if (!__value.isEmpty) { 
__size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(4, __value) } }; - + { val __value = schemaUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(5, __value) } }; - + { val __value = enumeratedNodeIdsBqTable if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(6, __value) } }; - + { val __value = enumeratedNodeDataBqTable if (!__value.isEmpty) { @@ -228,7 +228,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = featureDim.get __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(8, __value) }; - + { val __value = transformFnAssetsUri if (!__value.isEmpty) { @@ -245,7 +245,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -366,7 +366,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.NodeMetadataOutput]) } - + object NodeMetadataOutput extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput = { @@ -499,7 +499,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // 
@@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.NodeMetadataOutput]) } - + /** Houses metadata of edge features output from DataPreprocessor * * @param featureKeys @@ -540,21 +540,21 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = __item __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } - + { val __value = tfrecordUriPrefix if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { val __value = schemaUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(4, __value) } }; - + { val __value = enumeratedEdgeDataBqTable if (!__value.isEmpty) { @@ -565,7 +565,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = featureDim.get __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(6, __value) }; - + { val __value = transformFnAssetsUri if (!__value.isEmpty) { @@ -582,7 +582,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { featureKeys.foreach { __v => @@ -679,7 +679,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataInfo]) } - + object EdgeMetadataInfo extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo = { @@ -792,7 +792,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataInfo]) } - + /** Houses metadata about edge TFTransform output from DataPreprocessor. * * @param srcNodeIdKey @@ -819,14 +819,14 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = srcNodeIdKey if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = dstNodeIdKey if (!__value.isEmpty) { @@ -855,7 +855,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -932,7 +932,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataOutput]) } - + object EdgeMetadataOutput extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput = { @@ -1035,7 +1035,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataOutput]) } - + @SerialVersionUID(0L) final case class CondensedNodeTypeToPreprocessedMetadataEntry( key: _root_.scala.Int = 0, @@ -1046,7 +1046,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { @@ -1067,7 +1067,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -1110,7 +1110,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry]) } - + object CondensedNodeTypeToPreprocessedMetadataEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry = { @@ -1182,7 +1182,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry]) } - + @SerialVersionUID(0L) final case class CondensedEdgeTypeToPreprocessedMetadataEntry( key: _root_.scala.Int = 0, @@ -1193,7 +1193,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { @@ -1214,7 +1214,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -1257,7 +1257,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry]) } - + object CondensedEdgeTypeToPreprocessedMetadataEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry = { @@ -1329,7 +1329,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry]) } - + implicit class PreprocessedMetadataLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata](_l) { def condensedNodeTypeToPreprocessedMetadata: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Int, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput]] = field(_.condensedNodeTypeToPreprocessedMetadata)((c_, f_) => c_.copy(condensedNodeTypeToPreprocessedMetadata = f_)) def condensedEdgeTypeToPreprocessedMetadata: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Int, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput]] = field(_.condensedEdgeTypeToPreprocessedMetadata)((c_, f_) => c_.copy(condensedEdgeTypeToPreprocessedMetadata = f_)) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala index b6e8d0d6d..becc2d068 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala +++ 
b/scala/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala @@ -61,4 +61,4 @@ object PreprocessedMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala index c56f47fa4..2a6227235 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala @@ -15,7 +15,7 @@ final case class GlobalRandomUniformStrategy( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numHops if (__value != 0) { @@ -36,7 +36,7 @@ final case class GlobalRandomUniformStrategy( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala index dc3069c8b..4aad9e19c 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala @@ -15,7 +15,7 @@ final case class MessagePassingPath( private[this] var __serializedSizeMemoized: 
_root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = rootNodeType if (!__value.isEmpty) { @@ -36,7 +36,7 @@ final case class MessagePassingPath( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala index fd8a906af..ee3aa77a0 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala @@ -34,7 +34,7 @@ final case class MessagePassingPathStrategy( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { paths.foreach { __v => diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala index 5579eca5a..ef0fde958 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala @@ -5,7 +5,7 @@ package snapchat.research.gbml.subgraph_sampling_strategy -/** Randomly sample nodes from the neighborhood without replacement. +/** Randomly sample nodes from the neighborhood without replacement. 
*/ @SerialVersionUID(0L) final case class RandomUniform( @@ -16,7 +16,7 @@ final case class RandomUniform( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numNodesToSample if (__value != 0) { @@ -33,7 +33,7 @@ final case class RandomUniform( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala index 389735b80..395f3876b 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala @@ -17,14 +17,14 @@ final case class RandomWeighted( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numNodesToSample if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(1, __value) } }; - + { val __value = edgeFeatName if (!__value.isEmpty) { @@ -41,7 +41,7 @@ final case class RandomWeighted( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala index 0563e9b71..8f521cf2e 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala +++ 
b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala @@ -16,7 +16,7 @@ sealed abstract class SamplingDirection(val value: _root_.scala.Int) extends _ro object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingDirection] { sealed trait Recognized extends SamplingDirection implicit def enumCompanion: _root_.scalapb.GeneratedEnumCompanion[SamplingDirection] = this - + /** Sample incoming edges to the dst nodes (default) */ @SerialVersionUID(0L) @@ -25,7 +25,7 @@ object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingD val name = "INCOMING" override def isIncoming: _root_.scala.Boolean = true } - + /** Sample outgoing edges from the src nodes */ @SerialVersionUID(0L) @@ -34,7 +34,7 @@ object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingD val name = "OUTGOING" override def isOutgoing: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) final case class Unrecognized(unrecognizedValue: _root_.scala.Int) extends SamplingDirection(unrecognizedValue) with _root_.scalapb.UnrecognizedEnum lazy val values = scala.collection.immutable.Seq(INCOMING, OUTGOING) @@ -45,4 +45,4 @@ object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingD } def javaDescriptor: _root_.com.google.protobuf.Descriptors.EnumDescriptor = SubgraphSamplingStrategyProto.javaDescriptor.getEnumTypes().get(0) def scalaDescriptor: _root_.scalapb.descriptors.EnumDescriptor = SubgraphSamplingStrategyProto.scalaDescriptor.enums(0) -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala index 513ceaf83..6afa2e8b6 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala +++ 
b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala @@ -28,7 +28,7 @@ final case class SamplingOp( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = opName if (!__value.isEmpty) { @@ -59,7 +59,7 @@ final case class SamplingOp( val __value = samplingMethod.userDefined.get __size += 2 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; - + { val __value = samplingDirection.value if (__value != 0) { @@ -76,7 +76,7 @@ final case class SamplingOp( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -293,7 +293,7 @@ object SamplingOp extends scalapb.GeneratedMessageCompanion[snapchat.research.gb override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class RandomUniform(value: snapchat.research.gbml.subgraph_sampling_strategy.RandomUniform) extends snapchat.research.gbml.subgraph_sampling_strategy.SamplingOp.SamplingMethod { type ValueType = snapchat.research.gbml.subgraph_sampling_strategy.RandomUniform diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala index 5f8e3dee6..4abc7182b 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala @@ -32,7 +32,7 @@ final case class SubgraphSamplingStrategy( __serializedSizeMemoized = __size } __size - 1 - + } def 
writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { strategy.messagePassingPaths.foreach { __v => @@ -143,7 +143,7 @@ object SubgraphSamplingStrategy extends scalapb.GeneratedMessageCompanion[snapch override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class MessagePassingPaths(value: snapchat.research.gbml.subgraph_sampling_strategy.MessagePassingPathStrategy) extends snapchat.research.gbml.subgraph_sampling_strategy.SubgraphSamplingStrategy.Strategy { type ValueType = snapchat.research.gbml.subgraph_sampling_strategy.MessagePassingPathStrategy diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala index 709ae1159..e6ae31c4c 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala @@ -67,4 +67,4 @@ object SubgraphSamplingStrategyProto extends _root_.scalapb.GeneratedFileObject } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala index 86bda274e..39c5388e4 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala @@ -17,14 +17,14 @@ final case class TopK( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numNodesToSample if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(1, __value) } }; - + { val __value = edgeFeatName if (!__value.isEmpty) { @@ -41,7 +41,7 @@ final case class TopK( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala index a055bfbaa..cc00d0f9d 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala @@ -20,7 +20,7 @@ final case class UserDefined( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = pathToUdf if (!__value.isEmpty) { @@ -41,7 +41,7 @@ final case class UserDefined( __serializedSizeMemoized = __size } __size - 1 - + } def 
writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -151,14 +151,14 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -175,7 +175,7 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -219,7 +219,7 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g def companion: snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry.type = snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.UserDefined.ParamsEntry]) } - + object ParamsEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry = { @@ -284,7 +284,7 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.UserDefined.ParamsEntry]) } - + implicit class UserDefinedLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.subgraph_sampling_strategy.UserDefined]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, 
snapchat.research.gbml.subgraph_sampling_strategy.UserDefined](_l) { def pathToUdf: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.pathToUdf)((c_, f_) => c_.copy(pathToUdf = f_)) def params: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.params)((c_, f_) => c_.copy(params = f_)) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala b/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala index bcf95c046..2c5a042f9 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala @@ -26,28 +26,28 @@ final case class TrainedModelMetadata( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainedModelUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = scriptedModelUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = evalMetricsUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { val __value = tensorboardLogsUri if (!__value.isEmpty) { @@ -64,7 +64,7 @@ final case class TrainedModelMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala 
b/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala index b06e0d55a..1262d9517 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala @@ -31,4 +31,4 @@ object TrainedModelMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala index 21884234b..3289f2d76 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala @@ -15,14 +15,14 @@ final case class Label( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = labelType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = label if (__value != 0) { @@ -39,7 +39,7 @@ final case class Label( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala index a05a6ef18..a1f7def5a 100644 --- 
a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala @@ -64,7 +64,7 @@ final case class NodeAnchorBasedLinkPredictionSample( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootNode.foreach { __v => diff --git a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala index f326f0375..cdd1d4a4c 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala @@ -41,7 +41,7 @@ final case class RootedNodeNeighborhood( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootNode.foreach { __v => diff --git a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala index 4ffbeb328..9707ea68a 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala @@ -48,7 +48,7 @@ final case class SupervisedLinkBasedTaskSample( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootEdge.foreach { __v => diff --git 
a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala index 35c594ca3..00304a823 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala @@ -43,7 +43,7 @@ final case class SupervisedNodeClassificationSample( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootNode.foreach { __v => diff --git a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala index f4187ea60..f02891e26 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala @@ -51,4 +51,4 @@ object TrainingSamplesSchemaProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala/common/src/test/assets/resource_config.yaml b/scala/common/src/test/assets/resource_config.yaml index 5d0979b72..702b9f917 100644 --- a/scala/common/src/test/assets/resource_config.yaml +++ b/scala/common/src/test/assets/resource_config.yaml @@ -42,4 +42,4 @@ inferencer_config: num_workers: 1 max_num_workers: 256 machine_type: "c3-standard-22" - disk_size_gb: 100 \ No newline at end of file + disk_size_gb: 100 diff --git a/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml b/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml index b1a1421cd..a0a2ba944 100644 --- a/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml +++ b/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml @@ -44,4 +44,4 @@ graphMetadata: relation: engage srcNodeType: user nodeTypes: - - user \ No newline at end of file + - user diff --git a/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml b/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml index 429035cbb..f44cdbb44 100644 --- a/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml +++ b/scala/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml @@ -15,4 +15,4 @@ condensedNodeTypeToPreprocessedMetadata: - f1 nodeIdKey: node_id schemaUri: not.used.for.test - tfrecordUriPrefix: not.used.for.test \ No newline at end of file + tfrecordUriPrefix: not.used.for.test diff --git 
a/scala/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml b/scala/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml index 90762002c..68711a211 100644 --- a/scala/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml +++ b/scala/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml @@ -28,4 +28,4 @@ sharedConfig: supervisedNodeClassificationOutput: labeledTfrecordUriPrefix: common/src/test/assets/split_generator/supervised_node_classification/sgs_output/labeled/samples/ unlabeledTfrecordUriPrefix: common/src/test/assets/split_generator/supervised_node_classification/sgs_output/unlabeled/samples/ - preprocessedMetadataUri: common/src/test/assets/split_generator/supervised_node_classification/preprocessed_metadata.yaml \ No newline at end of file + preprocessedMetadataUri: common/src/test/assets/split_generator/supervised_node_classification/preprocessed_metadata.yaml diff --git a/scala/common/src/test/assets/subgraph_sampler/heterogeneous/node_anchor_based_link_prediction/frozen_gbml_config_graphdb_dblp_local.yaml b/scala/common/src/test/assets/subgraph_sampler/heterogeneous/node_anchor_based_link_prediction/frozen_gbml_config_graphdb_dblp_local.yaml index 02a166cf1..6fea68ba6 100755 --- a/scala/common/src/test/assets/subgraph_sampler/heterogeneous/node_anchor_based_link_prediction/frozen_gbml_config_graphdb_dblp_local.yaml +++ b/scala/common/src/test/assets/subgraph_sampler/heterogeneous/node_anchor_based_link_prediction/frozen_gbml_config_graphdb_dblp_local.yaml @@ -50,4 +50,4 @@ graphMetadata: srcNodeType: paper nodeTypes: - author - - paper \ No newline at end of file + - paper diff --git a/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/frozen_gbml_config.yaml b/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/frozen_gbml_config.yaml 
index 404c5b0f9..c2cde89f6 100755 --- a/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/frozen_gbml_config.yaml +++ b/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/frozen_gbml_config.yaml @@ -37,4 +37,4 @@ graphMetadata: relation: friend srcNodeType: user nodeTypes: - - user \ No newline at end of file + - user diff --git a/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/preprocessed_metadata.yaml b/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/preprocessed_metadata.yaml index a9ed15158..0102fe86c 100755 --- a/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/preprocessed_metadata.yaml +++ b/scala/common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/preprocessed_metadata.yaml @@ -31,4 +31,4 @@ condensedNodeTypeToPreprocessedMetadata: - f1 nodeIdKey: node_id schemaUri: not.used.for.test - tfrecordUriPrefix: common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/node_data \ No newline at end of file + tfrecordUriPrefix: common/src/test/assets/subgraph_sampler/node_anchor_based_link_prediction/node_data diff --git a/scala/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml b/scala/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml index cb1587a83..798350dd0 100755 --- a/scala/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml +++ b/scala/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml @@ -24,4 +24,4 @@ sharedConfig: supervisedNodeClassificationOutput: labeledTfrecordUriPrefix: common/src/test/assets/subgraph_sampler/supervised_node_classification/output/labeled/samples/ unlabeledTfrecordUriPrefix: common/src/test/assets/subgraph_sampler/supervised_node_classification/output/unlabeled/samples/ - 
preprocessedMetadataUri: common/src/test/assets/subgraph_sampler/supervised_node_classification/preprocessed_metadata.yaml \ No newline at end of file + preprocessedMetadataUri: common/src/test/assets/subgraph_sampler/supervised_node_classification/preprocessed_metadata.yaml diff --git a/scala/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala b/scala/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala index 90524f6a0..891198953 100644 --- a/scala/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala +++ b/scala/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala @@ -22,7 +22,7 @@ object AbstractAssigners { * e.g. could be assigning a NodePb (T) to some Enum (S). * * @param obj the object to hash - * @return + * @return */ def assign(obj: T): S } @@ -59,7 +59,7 @@ object AbstractAssigners { /** * Relative width of each bucket in the hash space. e.g. [0.2, 0.4, 0.4] would indicate 3 buckets, where - * the second and third bucket are twice as prominent as the first bucket. + * the second and third bucket are twice as prominent as the first bucket. 
*/ lazy val weights: Seq[Float] = bucketWeights.values.toList diff --git a/scala/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala b/scala/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala index ba15966a3..b0d7e1146 100644 --- a/scala/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala +++ b/scala/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala @@ -32,7 +32,7 @@ abstract class SplitStrategy[A](splitStrategyArgs: Map[String, String]) extends val graphMetadataPbWrapper: GraphMetadataPbWrapper /** - * Takes in a single "un-split" training sample instance output by SubgraphSampler, + * Takes in a single "un-split" training sample instance output by SubgraphSampler, * and a DatasetSplit(TRAIN, TEST, VAL) and outputs the the "split" samples for that dataset split * * @param sample : Input Sample from SGS diff --git a/scala/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala b/scala/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala index 61fefef6f..12e378ebe 100644 --- a/scala/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala +++ b/scala/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala @@ -37,7 +37,7 @@ class UDLAnchorBasedSupervisionEdgeSplitStrategy( * (a) All pos_edges and hard_neg_edges belonging to the split. * (b) message passing structure which should be pb.neighborhood and therefore the same across all splits * (i.e. no masking). - * (c) The message passing structure may be filtered down to only include edges that are not in the pos_edges + * (c) The message passing structure may be filtered down to only include edges that are not in the pos_edges * and hard_neg_edges. * An output train-split sample needs to have >0 pos_edges in this setting for loss computation. 
* Output val/test-split samples may have 0 pos_edges (and even 0 hard_neg_edges), since these diff --git a/scala/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala b/scala/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala index f50bb5d0f..8327695b3 100644 --- a/scala/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala +++ b/scala/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala @@ -17,7 +17,7 @@ object TaskOutputValidator { * is present in the neighborhood nodes. * This method does a dataset.map() on the final output produced by SGS and returns the same dataset * if there is no validation failure. Raises and excpetion if there is some error - * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any + * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any * duplication of computation due to this validation code. * * @param mainSampleDS @@ -50,7 +50,7 @@ object TaskOutputValidator { * is present in the neighborhood nodes. * This method does a dataset.map() on the final output produced by SGS and returns the same dataset * if there is no validation failure. Raises and excpetion if there is some error - * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any + * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any * duplication of computation due to this validation code. 
* * @param mainSampleDS diff --git a/scala_spark35/.gitignore b/scala_spark35/.gitignore index f7d538fde..3c3cd13b5 100644 --- a/scala_spark35/.gitignore +++ b/scala_spark35/.gitignore @@ -267,4 +267,4 @@ spark-warehouse/ .metals/ .bloop/ .ammonite/ -metals.sbt \ No newline at end of file +metals.sbt diff --git a/scala_spark35/.scalafix.conf b/scala_spark35/.scalafix.conf index 195ccd431..9e45edd17 100644 --- a/scala_spark35/.scalafix.conf +++ b/scala_spark35/.scalafix.conf @@ -1,7 +1,7 @@ rules = [ ExplicitResultTypes NoValInForComprehension, - OrganizeImports, + OrganizeImports, ProcedureSyntax, RedundantSyntax, ] diff --git a/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaGraphDBClient.scala b/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaGraphDBClient.scala index 79cdf9889..449ddd653 100644 --- a/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaGraphDBClient.scala +++ b/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaGraphDBClient.scala @@ -27,8 +27,8 @@ import scala.collection.JavaConversions._ To use SessionPool, you must config the graph space to connect for SessionPool. The SessionPool is thread-safe, and support retry(release old session and get available session from SessionPool) for both connection error, session error and execution error(caused by bad storaged server), and the retry mechanism needs users to config retryTimes and intervalTime between retrys. 
- - + + This class needs to be serializable if defined outside of mapPartitions, nebula client however is has underlying classes that are not serializable ConnectionPool + getSession - java.io.NotSerializableException: com.vesoft.nebula.client.graph.net.RoundRobinLoadBalancer SessionPool - Task not serializable: java.io.NotSerializableException: java.util.concurrent.ScheduledThreadPoolExecutor diff --git a/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaQueryResponseTranslator.scala b/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaQueryResponseTranslator.scala index 61731262b..769f432a5 100644 --- a/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaQueryResponseTranslator.scala +++ b/scala_spark35/common/src/main/scala/graphdb/nebula/NebulaQueryResponseTranslator.scala @@ -60,7 +60,7 @@ class NebulaQueryResponseTranslator( case SamplingOp.SamplingMethod.RandomUniform(value) => { val numNodesToSample = value.numNodesToSample s"""GO 1 STEP - FROM ${nebulaVID} + FROM ${nebulaVID} OVER ${nebulaEdgeType} ${outgoingEdgesModifier} YIELD src(edge) as ${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME}, @@ -74,33 +74,33 @@ class NebulaQueryResponseTranslator( case SamplingOp.SamplingMethod.RandomWeighted(value) => { val numNodesToSample = value.numNodesToSample val edgeFeatName = value.edgeFeatName - s"""GO 1 STEP - FROM ${nebulaVID} + s"""GO 1 STEP + FROM ${nebulaVID} OVER ${nebulaEdgeType} ${outgoingEdgesModifier} - YIELD + YIELD src(edge) as ${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME}, dst(edge) as ${NebulaQueryResponseTranslator.RESULT_DST_NODE_ID_COL_NAME}, ${nebulaEdgeType}.${edgeFeatName} * rand() as ${edgeFeatName} | ORDER BY $$-.${edgeFeatName} DESC | LIMIT ${numNodesToSample} | - YIELD - $$-.${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME} AS ${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME}, + YIELD + $$-.${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME} AS 
${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME}, $$-.${NebulaQueryResponseTranslator.RESULT_DST_NODE_ID_COL_NAME} AS ${NebulaQueryResponseTranslator.RESULT_DST_NODE_ID_COL_NAME}""" } case SamplingOp.SamplingMethod.TopK(value) => { val numNodesToSample = value.numNodesToSample val edgeFeatName = value.edgeFeatName - s"""GO 1 STEP - FROM ${nebulaVID} + s"""GO 1 STEP + FROM ${nebulaVID} OVER ${nebulaEdgeType} ${outgoingEdgesModifier} - YIELD + YIELD src(edge) as ${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME}, dst(edge) as ${NebulaQueryResponseTranslator.RESULT_DST_NODE_ID_COL_NAME}, ${nebulaEdgeType}.${edgeFeatName} as ${edgeFeatName} | ORDER BY $$-.${edgeFeatName} DESC | LIMIT ${numNodesToSample} | - YIELD - $$-.${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME} AS ${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME}, + YIELD + $$-.${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME} AS ${NebulaQueryResponseTranslator.RESULT_SRC_NODE_ID_COL_NAME}, $$-.${NebulaQueryResponseTranslator.RESULT_DST_NODE_ID_COL_NAME} AS ${NebulaQueryResponseTranslator.RESULT_DST_NODE_ID_COL_NAME}""" } case SamplingOp.SamplingMethod.UserDefined(value) => { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala index 0cafea1ca..afe385a0d 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadata.scala @@ -38,7 +38,7 @@ final case class DatasetMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { outputMetadata.supervisedNodeClassificationDataset.foreach { __v => @@ -165,7 +165,7 @@ object DatasetMetadata extends 
scalapb.GeneratedMessageCompanion[snapchat.resear override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class SupervisedNodeClassificationDataset(value: snapchat.research.gbml.dataset_metadata.SupervisedNodeClassificationDataset) extends snapchat.research.gbml.dataset_metadata.DatasetMetadata.OutputMetadata { type ValueType = snapchat.research.gbml.dataset_metadata.SupervisedNodeClassificationDataset diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala index 7f2bfe943..5393b2b91 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/DatasetMetadataProto.scala @@ -60,4 +60,4 @@ object DatasetMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala index 63cd8933f..44fd8509f 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/NodeAnchorBasedLinkPredictionDataset.scala @@ -21,21 +21,21 @@ final case class NodeAnchorBasedLinkPredictionDataset( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainMainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = testMainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = valMainDataUri if (!__value.isEmpty) { @@ -64,7 +64,7 @@ final case class NodeAnchorBasedLinkPredictionDataset( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -250,14 +250,14 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -274,7 +274,7 @@ object 
NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -318,7 +318,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp def companion: snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry.type = snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry]) } - + object TrainNodeTypeToRandomNegativeDataUriEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry = { @@ -383,7 +383,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TrainNodeTypeToRandomNegativeDataUriEntry]) } - + @SerialVersionUID(0L) final case class ValNodeTypeToRandomNegativeDataUriEntry( key: _root_.scala.Predef.String = "", @@ -394,14 +394,14 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if 
(!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -418,7 +418,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -462,7 +462,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp def companion: snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry.type = snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry]) } - + object ValNodeTypeToRandomNegativeDataUriEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry = { @@ -527,7 +527,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.ValNodeTypeToRandomNegativeDataUriEntry]) } - + @SerialVersionUID(0L) final case class TestNodeTypeToRandomNegativeDataUriEntry( key: _root_.scala.Predef.String = "", @@ -538,14 +538,14 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp 
private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -562,7 +562,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -606,7 +606,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp def companion: snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry.type = snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry]) } - + object TestNodeTypeToRandomNegativeDataUriEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry = { @@ -671,7 +671,7 @@ object NodeAnchorBasedLinkPredictionDataset extends scalapb.GeneratedMessageComp ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionDataset.TestNodeTypeToRandomNegativeDataUriEntry]) } - + implicit class 
NodeAnchorBasedLinkPredictionDatasetLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.dataset_metadata.NodeAnchorBasedLinkPredictionDataset](_l) { def trainMainDataUri: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.trainMainDataUri)((c_, f_) => c_.copy(trainMainDataUri = f_)) def testMainDataUri: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.testMainDataUri)((c_, f_) => c_.copy(testMainDataUri = f_)) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala index 2aca6bc33..453be5b92 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedLinkBasedTaskSplitDataset.scala @@ -18,21 +18,21 @@ final case class SupervisedLinkBasedTaskSplitDataset( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = testDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = valDataUri if (!__value.isEmpty) { @@ -49,7 +49,7 @@ final case class SupervisedLinkBasedTaskSplitDataset( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git 
a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala index ac5f71b11..2411d36d2 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/dataset_metadata/SupervisedNodeClassificationDataset.scala @@ -18,21 +18,21 @@ final case class SupervisedNodeClassificationDataset( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = testDataUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = valDataUri if (!__value.isEmpty) { @@ -49,7 +49,7 @@ final case class SupervisedNodeClassificationDataset( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala index 0395d2bd1..6f8ab4d22 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadata.scala @@ -38,7 +38,7 @@ final case class FlattenedGraphMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: 
_root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { outputMetadata.supervisedNodeClassificationOutput.foreach { __v => @@ -165,7 +165,7 @@ object FlattenedGraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class SupervisedNodeClassificationOutput(value: snapchat.research.gbml.flattened_graph_metadata.SupervisedNodeClassificationOutput) extends snapchat.research.gbml.flattened_graph_metadata.FlattenedGraphMetadata.OutputMetadata { type ValueType = snapchat.research.gbml.flattened_graph_metadata.SupervisedNodeClassificationOutput diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala index b0102ad23..1018c5787 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/FlattenedGraphMetadataProto.scala @@ -50,4 +50,4 @@ object FlattenedGraphMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala index c31af112a..375313e97 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/NodeAnchorBasedLinkPredictionOutput.scala @@ -22,7 +22,7 @@ final case class NodeAnchorBasedLinkPredictionOutput( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = tfrecordUriPrefix if (!__value.isEmpty) { @@ -43,7 +43,7 @@ final case class NodeAnchorBasedLinkPredictionOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -153,14 +153,14 @@ object NodeAnchorBasedLinkPredictionOutput extends scalapb.GeneratedMessageCompa private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -177,7 +177,7 @@ object NodeAnchorBasedLinkPredictionOutput extends scalapb.GeneratedMessageCompa __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -221,7 +221,7 @@ object 
NodeAnchorBasedLinkPredictionOutput extends scalapb.GeneratedMessageCompa def companion: snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry.type = snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry]) } - + object NodeTypeToRandomNegativeTfrecordUriPrefixEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry = { @@ -286,7 +286,7 @@ object NodeAnchorBasedLinkPredictionOutput extends scalapb.GeneratedMessageCompa ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.NodeAnchorBasedLinkPredictionOutput.NodeTypeToRandomNegativeTfrecordUriPrefixEntry]) } - + implicit class NodeAnchorBasedLinkPredictionOutputLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.flattened_graph_metadata.NodeAnchorBasedLinkPredictionOutput](_l) { def tfrecordUriPrefix: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.tfrecordUriPrefix)((c_, f_) => c_.copy(tfrecordUriPrefix = f_)) def nodeTypeToRandomNegativeTfrecordUriPrefix: 
_root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.nodeTypeToRandomNegativeTfrecordUriPrefix)((c_, f_) => c_.copy(nodeTypeToRandomNegativeTfrecordUriPrefix = f_)) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala index 8cfd94948..7af22ab07 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedLinkBasedTaskOutput.scala @@ -20,14 +20,14 @@ final case class SupervisedLinkBasedTaskOutput( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = labeledTfrecordUriPrefix if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = unlabeledTfrecordUriPrefix if (!__value.isEmpty) { @@ -44,7 +44,7 @@ final case class SupervisedLinkBasedTaskOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala index cb3ed0bb3..538828f18 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala +++ 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/flattened_graph_metadata/SupervisedNodeClassificationOutput.scala @@ -20,14 +20,14 @@ final case class SupervisedNodeClassificationOutput( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = labeledTfrecordUriPrefix if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = unlabeledTfrecordUriPrefix if (!__value.isEmpty) { @@ -44,7 +44,7 @@ final case class SupervisedNodeClassificationOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala index d61aa9295..3169265e7 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/Component.scala @@ -25,63 +25,63 @@ sealed abstract class Component(val value: _root_.scala.Int) extends _root_.scal object Component extends _root_.scalapb.GeneratedEnumCompanion[Component] { sealed trait Recognized extends Component implicit def enumCompanion: _root_.scalapb.GeneratedEnumCompanion[Component] = this - + @SerialVersionUID(0L) case object Component_Unknown extends Component(0) with Component.Recognized { val index = 0 val name = "Component_Unknown" override def isComponentUnknown: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Config_Validator extends Component(1) with Component.Recognized { val index = 1 val name = "Component_Config_Validator" override def isComponentConfigValidator: _root_.scala.Boolean = true } - + 
@SerialVersionUID(0L) case object Component_Config_Populator extends Component(2) with Component.Recognized { val index = 2 val name = "Component_Config_Populator" override def isComponentConfigPopulator: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Data_Preprocessor extends Component(3) with Component.Recognized { val index = 3 val name = "Component_Data_Preprocessor" override def isComponentDataPreprocessor: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Subgraph_Sampler extends Component(4) with Component.Recognized { val index = 4 val name = "Component_Subgraph_Sampler" override def isComponentSubgraphSampler: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Split_Generator extends Component(5) with Component.Recognized { val index = 5 val name = "Component_Split_Generator" override def isComponentSplitGenerator: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Trainer extends Component(6) with Component.Recognized { val index = 6 val name = "Component_Trainer" override def isComponentTrainer: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) case object Component_Inferencer extends Component(7) with Component.Recognized { val index = 7 val name = "Component_Inferencer" override def isComponentInferencer: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) final case class Unrecognized(unrecognizedValue: _root_.scala.Int) extends Component(unrecognizedValue) with _root_.scalapb.UnrecognizedEnum lazy val values: scala.collection.immutable.Seq[ValueType] = scala.collection.immutable.Seq(Component_Unknown, Component_Config_Validator, Component_Config_Populator, Component_Data_Preprocessor, Component_Subgraph_Sampler, Component_Split_Generator, Component_Trainer, Component_Inferencer) @@ -98,4 +98,4 @@ object Component extends _root_.scalapb.GeneratedEnumCompanion[Component] { } def javaDescriptor: 
_root_.com.google.protobuf.Descriptors.EnumDescriptor = GiglResourceConfigProto.javaDescriptor.getEnumTypes().get(0) def scalaDescriptor: _root_.scalapb.descriptors.EnumDescriptor = GiglResourceConfigProto.scalaDescriptor.enums(0) -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/CustomResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/CustomResourceConfig.scala new file mode 100644 index 000000000..33ca7b857 --- /dev/null +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/CustomResourceConfig.scala @@ -0,0 +1,159 @@ +// Generated by the Scala Plugin for the Protocol Buffer Compiler. +// Do not edit! +// +// Protofile syntax: PROTO3 + +package snapchat.research.gbml.gigl_resource_config + +/** Lets user-defined launchers be piped in. + * The launcher dispatcher invokes `command` (interpreted by /bin/sh -c so + * leading "KEY=VALUE" assignments parse as inline env vars) with `args` + * appended as positional arguments. String fields support OmegaConf + * `${gigl:<key>}` substitutions, which the dispatcher resolves at exec + * time from the runtime context (task_config_uri, applied_task_identifier, + * component, etc.). + * + * @param command + * Shell snippet invoked via /bin/sh -c. Leading "KEY=VALUE" assignments + * are honored by the shell, so callers can inline env vars (e.g. + * "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python -m my.cli"). + * @param args + * Positional arguments appended after the command. Each element is + * shell-quoted by the dispatcher so values containing spaces/quotes + * survive the shell pass. 
+ */ +@SerialVersionUID(0L) +final case class CustomResourceConfig( + command: _root_.scala.Predef.String = "", + args: _root_.scala.Seq[_root_.scala.Predef.String] = _root_.scala.Seq.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[CustomResourceConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + + { + val __value = command + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) + } + }; + args.foreach { __item => + val __value = __item + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + { + val __v = command + if (!__v.isEmpty) { + _output__.writeString(1, __v) + } + }; + args.foreach { __v => + val __m = __v + _output__.writeString(2, __m) + }; + unknownFields.writeTo(_output__) + } + def withCommand(__v: _root_.scala.Predef.String): CustomResourceConfig = copy(command = __v) + def clearArgs = copy(args = _root_.scala.Seq.empty) + def addArgs(__vs: _root_.scala.Predef.String *): CustomResourceConfig = addAllArgs(__vs) + def addAllArgs(__vs: Iterable[_root_.scala.Predef.String]): CustomResourceConfig = copy(args = args ++ __vs) + def withArgs(__v: _root_.scala.Seq[_root_.scala.Predef.String]): CustomResourceConfig = copy(args = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = 
_root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => { + val __t = command + if (__t != "") __t else null + } + case 2 => args + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PString(command) + case 2 => _root_.scalapb.descriptors.PRepeated(args.iterator.map(_root_.scalapb.descriptors.PString(_)).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig.type = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.CustomResourceConfig]) +} + +object CustomResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = { + var __command: _root_.scala.Predef.String = "" + val __args: _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] = new _root_.scala.collection.immutable.VectorBuilder[_root_.scala.Predef.String] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __command = _input__.readStringRequireUtf8() + case 18 => + __args += _input__.readStringRequireUtf8() + 
case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command = __command, + args = __args.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), + args = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Seq[_root_.scala.Predef.String]]).getOrElse(_root_.scala.Seq.empty) + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = throw new MatchError(__number) + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command = "", + args = 
_root_.scala.Seq.empty + ) + implicit class CustomResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig](_l) { + def command: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.command)((c_, f_) => c_.copy(command = f_)) + def args: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.args)((c_, f_) => c_.copy(args = f_)) + } + final val COMMAND_FIELD_NUMBER = 1 + final val ARGS_FIELD_NUMBER = 2 + def of( + command: _root_.scala.Predef.String, + args: _root_.scala.Seq[_root_.scala.Predef.String] + ): _root_.snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.CustomResourceConfig( + command, + args + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.CustomResourceConfig]) +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala index 75c2e54af..7d9951c80 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DataPreprocessorConfig.scala @@ -35,7 +35,7 @@ final case class DataPreprocessorConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { edgePreprocessorConfig.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala index 8363bdb1f..2198a2eb5 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedInferencerConfig.scala @@ -38,7 +38,7 @@ final case class DistributedInferencerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { trainerConfig.vertexAiInferencerConfig.foreach { __v => @@ -165,7 +165,7 @@ object DistributedInferencerConfig extends scalapb.GeneratedMessageCompanion[sna override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class VertexAiInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig) extends snapchat.research.gbml.gigl_resource_config.DistributedInferencerConfig.TrainerConfig { type ValueType = snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala index 676b61794..60313b1cc 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/DistributedTrainerConfig.scala @@ -131,8 +131,8 @@ object DistributedTrainerConfig extends scalapb.GeneratedMessageCompanion[snapch ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = 
GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(11) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(11) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala index d88d363e9..16ff1d6a6 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfig.scala @@ -275,8 +275,8 @@ object GiglResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.res ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(15) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(15) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(16) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(16) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: 
@_root_.scala.unchecked) match { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala index a086f6113..603a940e4 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/GiglResourceConfigProto.scala @@ -20,6 +20,7 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { snapchat.research.gbml.gigl_resource_config.KFPResourceConfig, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig, + snapchat.research.gbml.gigl_resource_config.CustomResourceConfig, snapchat.research.gbml.gigl_resource_config.DistributedTrainerConfig, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig, @@ -65,65 +66,70 @@ object GiglResourceConfigProto extends _root_.scalapb.GeneratedFileObject { AEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0IT4j8QEg5ncmFwaFN0b3JlUG9vbFIOZ 3JhcGhTdG9yZVBvb2wSYwoMY29tcHV0ZV9wb29sGAIgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291c mNlQ29uZmlnQhDiPw0SC2NvbXB1dGVQb29sUgtjb21wdXRlUG9vbBJpCiBjb21wdXRlX2NsdXN0ZXJfbG9jYWxfd29ybGRfc2l6Z - RgDIAEoBUIh4j8eEhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplUhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplIp0DC - hhEaXN0cmlidXRlZFRyYWluZXJDb25maWcShAEKGHZlcnRleF9haV90cmFpbmVyX2NvbmZpZxgBIAEoCzItLnNuYXBjaGF0LnJlc - 2VhcmNoLmdibWwuVmVydGV4QWlUcmFpbmVyQ29uZmlnQhriPxcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyY - WluZXJDb25maWcSbwoSa2ZwX3RyYWluZXJfY29uZmlnGAIgASgLMiguc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBUcmFpbmVyQ - 
29uZmlnQhXiPxISEGtmcFRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZxJ3ChRsb2NhbF90cmFpbmVyX2NvbmZpZxgDI - AEoCzIqLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxUcmFpbmVyQ29uZmlnQhfiPxQSEmxvY2FsVHJhaW5lckNvbmZpZ0gAU - hJsb2NhbFRyYWluZXJDb25maWdCEAoOdHJhaW5lcl9jb25maWcixwQKFVRyYWluZXJSZXNvdXJjZUNvbmZpZxKFAQoYdmVydGV4X - 2FpX3RyYWluZXJfY29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQhriP - xcSFXZlcnRleEFpVHJhaW5lckNvbmZpZ0gAUhV2ZXJ0ZXhBaVRyYWluZXJDb25maWcScAoSa2ZwX3RyYWluZXJfY29uZmlnGAIgA - SgLMikuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5LRlBSZXNvdXJjZUNvbmZpZ0IV4j8SEhBrZnBUcmFpbmVyQ29uZmlnSABSEGtmc - FRyYWluZXJDb25maWcSeAoUbG9jYWxfdHJhaW5lcl9jb25maWcYAyABKAsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkxvY2FsU - mVzb3VyY2VDb25maWdCF+I/FBISbG9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsVHJhaW5lckNvbmZpZxKnAQokdmVydGV4X2FpX - 2dyYXBoX3N0b3JlX3RyYWluZXJfY29uZmlnGAQgASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaUdyYXBoU3Rvc - mVDb25maWdCJOI/IRIfdmVydGV4QWlHcmFwaFN0b3JlVHJhaW5lckNvbmZpZ0gAUh92ZXJ0ZXhBaUdyYXBoU3RvcmVUcmFpbmVyQ - 29uZmlnQhAKDnRyYWluZXJfY29uZmlnIocFChhJbmZlcmVuY2VyUmVzb3VyY2VDb25maWcSjgEKG3ZlcnRleF9haV9pbmZlcmVuY - 2VyX2NvbmZpZxgBIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlSZXNvdXJjZUNvbmZpZ0Id4j8aEhh2ZXJ0Z - XhBaUluZmVyZW5jZXJDb25maWdIAFIYdmVydGV4QWlJbmZlcmVuY2VyQ29uZmlnEo0BChpkYXRhZmxvd19pbmZlcmVuY2VyX2Nvb - mZpZxgCIAEoCzIuLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuRGF0YWZsb3dSZXNvdXJjZUNvbmZpZ0Id4j8aEhhkYXRhZmxvd0luZ - mVyZW5jZXJDb25maWdIAFIYZGF0YWZsb3dJbmZlcmVuY2VyQ29uZmlnEoEBChdsb2NhbF9pbmZlcmVuY2VyX2NvbmZpZxgDIAEoC - zIrLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuTG9jYWxSZXNvdXJjZUNvbmZpZ0Ia4j8XEhVsb2NhbEluZmVyZW5jZXJDb25maWdIA - FIVbG9jYWxJbmZlcmVuY2VyQ29uZmlnErABCid2ZXJ0ZXhfYWlfZ3JhcGhfc3RvcmVfaW5mZXJlbmNlcl9jb25maWcYBCABKAsyM - C5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZ0In4j8kEiJ2ZXJ0ZXhBaUdyYXBoU3RvcmVJb - mZlcmVuY2VyQ29uZmlnSABSInZlcnRleEFpR3JhcGhTdG9yZUluZmVyZW5jZXJDb25maWdCEwoRaW5mZXJlbmNlcl9jb25maWcil - 
wgKFFNoYXJlZFJlc291cmNlQ29uZmlnEn4KD3Jlc291cmNlX2xhYmVscxgBIAMoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU - 2hhcmVkUmVzb3VyY2VDb25maWcuUmVzb3VyY2VMYWJlbHNFbnRyeUIT4j8QEg5yZXNvdXJjZUxhYmVsc1IOcmVzb3VyY2VMYWJlb - HMSjgEKFWNvbW1vbl9jb21wdXRlX2NvbmZpZxgCIAEoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb - 25maWcuQ29tbW9uQ29tcHV0ZUNvbmZpZ0IY4j8VEhNjb21tb25Db21wdXRlQ29uZmlnUhNjb21tb25Db21wdXRlQ29uZmlnGpQFC - hNDb21tb25Db21wdXRlQ29uZmlnEiYKB3Byb2plY3QYASABKAlCDOI/CRIHcHJvamVjdFIHcHJvamVjdBIjCgZyZWdpb24YAiABK - AlCC+I/CBIGcmVnaW9uUgZyZWdpb24SQwoSdGVtcF9hc3NldHNfYnVja2V0GAMgASgJQhXiPxISEHRlbXBBc3NldHNCdWNrZXRSE - HRlbXBBc3NldHNCdWNrZXQSXAobdGVtcF9yZWdpb25hbF9hc3NldHNfYnVja2V0GAQgASgJQh3iPxoSGHRlbXBSZWdpb25hbEFzc - 2V0c0J1Y2tldFIYdGVtcFJlZ2lvbmFsQXNzZXRzQnVja2V0EkMKEnBlcm1fYXNzZXRzX2J1Y2tldBgFIAEoCUIV4j8SEhBwZXJtQ - XNzZXRzQnVja2V0UhBwZXJtQXNzZXRzQnVja2V0EloKG3RlbXBfYXNzZXRzX2JxX2RhdGFzZXRfbmFtZRgGIAEoCUIc4j8ZEhd0Z - W1wQXNzZXRzQnFEYXRhc2V0TmFtZVIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWUSVgoZZW1iZWRkaW5nX2JxX2RhdGFzZXRfbmFtZ - RgHIAEoCUIb4j8YEhZlbWJlZGRpbmdCcURhdGFzZXROYW1lUhZlbWJlZGRpbmdCcURhdGFzZXROYW1lElYKGWdjcF9zZXJ2aWNlX - 2FjY291bnRfZW1haWwYCCABKAlCG+I/GBIWZ2NwU2VydmljZUFjY291bnRFbWFpbFIWZ2NwU2VydmljZUFjY291bnRFbWFpbBI8C - g9kYXRhZmxvd19ydW5uZXIYCyABKAlCE+I/EBIOZGF0YWZsb3dSdW5uZXJSDmRhdGFmbG93UnVubmVyGlcKE1Jlc291cmNlTGFiZ - WxzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEi9 - wgKEkdpZ2xSZXNvdXJjZUNvbmZpZxJbChpzaGFyZWRfcmVzb3VyY2VfY29uZmlnX3VyaRgBIAEoCUIc4j8ZEhdzaGFyZWRSZXNvd - XJjZUNvbmZpZ1VyaUgAUhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaRJ/ChZzaGFyZWRfcmVzb3VyY2VfY29uZmlnGAIgASgLMiwuc - 25hcGNoYXQucmVzZWFyY2guZ2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZ0IZ4j8WEhRzaGFyZWRSZXNvdXJjZUNvbmZpZ0gAUhRza - GFyZWRSZXNvdXJjZUNvbmZpZxJ4ChNwcmVwcm9jZXNzb3JfY29uZmlnGAwgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EY - XRhUHJlcHJvY2Vzc29yQ29uZmlnQhfiPxQSEnByZXByb2Nlc3NvckNvbmZpZ1IScHJlcHJvY2Vzc29yQ29uZmlnEn8KF3N1YmdyY - 
XBoX3NhbXBsZXJfY29uZmlnGA0gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQhriPxcSF - XN1YmdyYXBoU2FtcGxlckNvbmZpZ1IVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnEnwKFnNwbGl0X2dlbmVyYXRvcl9jb25maWcYDiABK - AsyKy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCGeI/FhIUc3BsaXRHZW5lcmF0b3JDb25maWdSF - HNwbGl0R2VuZXJhdG9yQ29uZmlnEm0KDnRyYWluZXJfY29uZmlnGA8gASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EaXN0c - mlidXRlZFRyYWluZXJDb25maWdCFBgB4j8PEg10cmFpbmVyQ29uZmlnUg10cmFpbmVyQ29uZmlnEnQKEWluZmVyZW5jZXJfY29uZ - mlnGBAgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291cmNlQ29uZmlnQhcYAeI/EhIQaW5mZXJlbmNlc - kNvbmZpZ1IQaW5mZXJlbmNlckNvbmZpZxKBAQoXdHJhaW5lcl9yZXNvdXJjZV9jb25maWcYESABKAsyLS5zbmFwY2hhdC5yZXNlY - XJjaC5nYm1sLlRyYWluZXJSZXNvdXJjZUNvbmZpZ0Ia4j8XEhV0cmFpbmVyUmVzb3VyY2VDb25maWdSFXRyYWluZXJSZXNvdXJjZ - UNvbmZpZxKNAQoaaW5mZXJlbmNlcl9yZXNvdXJjZV9jb25maWcYEiABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkluZmVyZ - W5jZXJSZXNvdXJjZUNvbmZpZ0Id4j8aEhhpbmZlcmVuY2VyUmVzb3VyY2VDb25maWdSGGluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ - 0IRCg9zaGFyZWRfcmVzb3VyY2Uq4wMKCUNvbXBvbmVudBItChFDb21wb25lbnRfVW5rbm93bhAAGhbiPxMSEUNvbXBvbmVudF9Vb - mtub3duEj8KGkNvbXBvbmVudF9Db25maWdfVmFsaWRhdG9yEAEaH+I/HBIaQ29tcG9uZW50X0NvbmZpZ19WYWxpZGF0b3ISPwoaQ - 29tcG9uZW50X0NvbmZpZ19Qb3B1bGF0b3IQAhof4j8cEhpDb21wb25lbnRfQ29uZmlnX1BvcHVsYXRvchJBChtDb21wb25lbnRfR - GF0YV9QcmVwcm9jZXNzb3IQAxog4j8dEhtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3ISPwoaQ29tcG9uZW50X1N1YmdyYXBoX - 1NhbXBsZXIQBBof4j8cEhpDb21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchI9ChlDb21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEAUaH - uI/GxIZQ29tcG9uZW50X1NwbGl0X0dlbmVyYXRvchItChFDb21wb25lbnRfVHJhaW5lchAGGhbiPxMSEUNvbXBvbmVudF9UcmFpb - mVyEjMKFENvbXBvbmVudF9JbmZlcmVuY2VyEAcaGeI/FhIUQ29tcG9uZW50X0luZmVyZW5jZXJiBnByb3RvMw==""" + RgDIAEoBUIh4j8eEhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplUhxjb21wdXRlQ2x1c3RlckxvY2FsV29ybGRTaXplIl0KF + EN1c3RvbVJlc291cmNlQ29uZmlnEiYKB2NvbW1hbmQYASABKAlCDOI/CRIHY29tbWFuZFIHY29tbWFuZBIdCgRhcmdzGAIgAygJQ + 
gniPwYSBGFyZ3NSBGFyZ3MinQMKGERpc3RyaWJ1dGVkVHJhaW5lckNvbmZpZxKEAQoYdmVydGV4X2FpX3RyYWluZXJfY29uZmlnG + AEgASgLMi0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVRyYWluZXJDb25maWdCGuI/FxIVdmVydGV4QWlUcmFpbmVyQ + 29uZmlnSABSFXZlcnRleEFpVHJhaW5lckNvbmZpZxJvChJrZnBfdHJhaW5lcl9jb25maWcYAiABKAsyKC5zbmFwY2hhdC5yZXNlY + XJjaC5nYm1sLktGUFRyYWluZXJDb25maWdCFeI/EhIQa2ZwVHJhaW5lckNvbmZpZ0gAUhBrZnBUcmFpbmVyQ29uZmlnEncKFGxvY + 2FsX3RyYWluZXJfY29uZmlnGAMgASgLMiouc25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFRyYWluZXJDb25maWdCF+I/FBISb + G9jYWxUcmFpbmVyQ29uZmlnSABSEmxvY2FsVHJhaW5lckNvbmZpZ0IQCg50cmFpbmVyX2NvbmZpZyLFBQoVVHJhaW5lclJlc291c + mNlQ29uZmlnEoUBChh2ZXJ0ZXhfYWlfdHJhaW5lcl9jb25maWcYASABKAsyLi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlZlcnRle + EFpUmVzb3VyY2VDb25maWdCGuI/FxIVdmVydGV4QWlUcmFpbmVyQ29uZmlnSABSFXZlcnRleEFpVHJhaW5lckNvbmZpZxJwChJrZ + nBfdHJhaW5lcl9jb25maWcYAiABKAsyKS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLktGUFJlc291cmNlQ29uZmlnQhXiPxISEGtmc + FRyYWluZXJDb25maWdIAFIQa2ZwVHJhaW5lckNvbmZpZxJ4ChRsb2NhbF90cmFpbmVyX2NvbmZpZxgDIAEoCzIrLnNuYXBjaGF0L + nJlc2VhcmNoLmdibWwuTG9jYWxSZXNvdXJjZUNvbmZpZ0IX4j8UEhJsb2NhbFRyYWluZXJDb25maWdIAFISbG9jYWxUcmFpbmVyQ + 29uZmlnEqcBCiR2ZXJ0ZXhfYWlfZ3JhcGhfc3RvcmVfdHJhaW5lcl9jb25maWcYBCABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nY + m1sLlZlcnRleEFpR3JhcGhTdG9yZUNvbmZpZ0Ik4j8hEh92ZXJ0ZXhBaUdyYXBoU3RvcmVUcmFpbmVyQ29uZmlnSABSH3ZlcnRle + EFpR3JhcGhTdG9yZVRyYWluZXJDb25maWcSfAoVY3VzdG9tX3RyYWluZXJfY29uZmlnGAUgASgLMiwuc25hcGNoYXQucmVzZWFyY + 2guZ2JtbC5DdXN0b21SZXNvdXJjZUNvbmZpZ0IY4j8VEhNjdXN0b21UcmFpbmVyQ29uZmlnSABSE2N1c3RvbVRyYWluZXJDb25ma + WdCEAoOdHJhaW5lcl9jb25maWcijwYKGEluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZxKOAQobdmVydGV4X2FpX2luZmVyZW5jZXJfY + 29uZmlnGAEgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5WZXJ0ZXhBaVJlc291cmNlQ29uZmlnQh3iPxoSGHZlcnRleEFpS + W5mZXJlbmNlckNvbmZpZ0gAUhh2ZXJ0ZXhBaUluZmVyZW5jZXJDb25maWcSjQEKGmRhdGFmbG93X2luZmVyZW5jZXJfY29uZmlnG + AIgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291cmNlQ29uZmlnQh3iPxoSGGRhdGFmbG93SW5mZXJlb + 
mNlckNvbmZpZ0gAUhhkYXRhZmxvd0luZmVyZW5jZXJDb25maWcSgQEKF2xvY2FsX2luZmVyZW5jZXJfY29uZmlnGAMgASgLMisuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5Mb2NhbFJlc291cmNlQ29uZmlnQhriPxcSFWxvY2FsSW5mZXJlbmNlckNvbmZpZ0gAUhVsb + 2NhbEluZmVyZW5jZXJDb25maWcSsAEKJ3ZlcnRleF9haV9ncmFwaF9zdG9yZV9pbmZlcmVuY2VyX2NvbmZpZxgEIAEoCzIwLnNuY + XBjaGF0LnJlc2VhcmNoLmdibWwuVmVydGV4QWlHcmFwaFN0b3JlQ29uZmlnQifiPyQSInZlcnRleEFpR3JhcGhTdG9yZUluZmVyZ + W5jZXJDb25maWdIAFIidmVydGV4QWlHcmFwaFN0b3JlSW5mZXJlbmNlckNvbmZpZxKFAQoYY3VzdG9tX2luZmVyZW5jZXJfY29uZ + mlnGAUgASgLMiwuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5DdXN0b21SZXNvdXJjZUNvbmZpZ0Ib4j8YEhZjdXN0b21JbmZlcmVuY + 2VyQ29uZmlnSABSFmN1c3RvbUluZmVyZW5jZXJDb25maWdCEwoRaW5mZXJlbmNlcl9jb25maWcilwgKFFNoYXJlZFJlc291cmNlQ + 29uZmlnEn4KD3Jlc291cmNlX2xhYmVscxgBIAMoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25ma + WcuUmVzb3VyY2VMYWJlbHNFbnRyeUIT4j8QEg5yZXNvdXJjZUxhYmVsc1IOcmVzb3VyY2VMYWJlbHMSjgEKFWNvbW1vbl9jb21wd + XRlX2NvbmZpZxgCIAEoCzJALnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuU2hhcmVkUmVzb3VyY2VDb25maWcuQ29tbW9uQ29tcHV0Z + UNvbmZpZ0IY4j8VEhNjb21tb25Db21wdXRlQ29uZmlnUhNjb21tb25Db21wdXRlQ29uZmlnGpQFChNDb21tb25Db21wdXRlQ29uZ + mlnEiYKB3Byb2plY3QYASABKAlCDOI/CRIHcHJvamVjdFIHcHJvamVjdBIjCgZyZWdpb24YAiABKAlCC+I/CBIGcmVnaW9uUgZyZ + Wdpb24SQwoSdGVtcF9hc3NldHNfYnVja2V0GAMgASgJQhXiPxISEHRlbXBBc3NldHNCdWNrZXRSEHRlbXBBc3NldHNCdWNrZXQSX + AobdGVtcF9yZWdpb25hbF9hc3NldHNfYnVja2V0GAQgASgJQh3iPxoSGHRlbXBSZWdpb25hbEFzc2V0c0J1Y2tldFIYdGVtcFJlZ + 2lvbmFsQXNzZXRzQnVja2V0EkMKEnBlcm1fYXNzZXRzX2J1Y2tldBgFIAEoCUIV4j8SEhBwZXJtQXNzZXRzQnVja2V0UhBwZXJtQ + XNzZXRzQnVja2V0EloKG3RlbXBfYXNzZXRzX2JxX2RhdGFzZXRfbmFtZRgGIAEoCUIc4j8ZEhd0ZW1wQXNzZXRzQnFEYXRhc2V0T + mFtZVIXdGVtcEFzc2V0c0JxRGF0YXNldE5hbWUSVgoZZW1iZWRkaW5nX2JxX2RhdGFzZXRfbmFtZRgHIAEoCUIb4j8YEhZlbWJlZ + GRpbmdCcURhdGFzZXROYW1lUhZlbWJlZGRpbmdCcURhdGFzZXROYW1lElYKGWdjcF9zZXJ2aWNlX2FjY291bnRfZW1haWwYCCABK + AlCG+I/GBIWZ2NwU2VydmljZUFjY291bnRFbWFpbFIWZ2NwU2VydmljZUFjY291bnRFbWFpbBI8Cg9kYXRhZmxvd19ydW5uZXIYC + 
yABKAlCE+I/EBIOZGF0YWZsb3dSdW5uZXJSDmRhdGFmbG93UnVubmVyGlcKE1Jlc291cmNlTGFiZWxzRW50cnkSGgoDa2V5GAEgA + SgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEi9wgKEkdpZ2xSZXNvdXJjZUNvb + mZpZxJbChpzaGFyZWRfcmVzb3VyY2VfY29uZmlnX3VyaRgBIAEoCUIc4j8ZEhdzaGFyZWRSZXNvdXJjZUNvbmZpZ1VyaUgAUhdza + GFyZWRSZXNvdXJjZUNvbmZpZ1VyaRJ/ChZzaGFyZWRfcmVzb3VyY2VfY29uZmlnGAIgASgLMiwuc25hcGNoYXQucmVzZWFyY2guZ + 2JtbC5TaGFyZWRSZXNvdXJjZUNvbmZpZ0IZ4j8WEhRzaGFyZWRSZXNvdXJjZUNvbmZpZ0gAUhRzaGFyZWRSZXNvdXJjZUNvbmZpZ + xJ4ChNwcmVwcm9jZXNzb3JfY29uZmlnGAwgASgLMi4uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EYXRhUHJlcHJvY2Vzc29yQ29uZ + mlnQhfiPxQSEnByZXByb2Nlc3NvckNvbmZpZ1IScHJlcHJvY2Vzc29yQ29uZmlnEn8KF3N1YmdyYXBoX3NhbXBsZXJfY29uZmlnG + A0gASgLMisuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5TcGFya1Jlc291cmNlQ29uZmlnQhriPxcSFXN1YmdyYXBoU2FtcGxlckNvb + mZpZ1IVc3ViZ3JhcGhTYW1wbGVyQ29uZmlnEnwKFnNwbGl0X2dlbmVyYXRvcl9jb25maWcYDiABKAsyKy5zbmFwY2hhdC5yZXNlY + XJjaC5nYm1sLlNwYXJrUmVzb3VyY2VDb25maWdCGeI/FhIUc3BsaXRHZW5lcmF0b3JDb25maWdSFHNwbGl0R2VuZXJhdG9yQ29uZ + mlnEm0KDnRyYWluZXJfY29uZmlnGA8gASgLMjAuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5EaXN0cmlidXRlZFRyYWluZXJDb25ma + WdCFBgB4j8PEg10cmFpbmVyQ29uZmlnUg10cmFpbmVyQ29uZmlnEnQKEWluZmVyZW5jZXJfY29uZmlnGBAgASgLMi4uc25hcGNoY + XQucmVzZWFyY2guZ2JtbC5EYXRhZmxvd1Jlc291cmNlQ29uZmlnQhcYAeI/EhIQaW5mZXJlbmNlckNvbmZpZ1IQaW5mZXJlbmNlc + kNvbmZpZxKBAQoXdHJhaW5lcl9yZXNvdXJjZV9jb25maWcYESABKAsyLS5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLlRyYWluZXJSZ + XNvdXJjZUNvbmZpZ0Ia4j8XEhV0cmFpbmVyUmVzb3VyY2VDb25maWdSFXRyYWluZXJSZXNvdXJjZUNvbmZpZxKNAQoaaW5mZXJlb + mNlcl9yZXNvdXJjZV9jb25maWcYEiABKAsyMC5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ + 0Id4j8aEhhpbmZlcmVuY2VyUmVzb3VyY2VDb25maWdSGGluZmVyZW5jZXJSZXNvdXJjZUNvbmZpZ0IRCg9zaGFyZWRfcmVzb3VyY + 2Uq4wMKCUNvbXBvbmVudBItChFDb21wb25lbnRfVW5rbm93bhAAGhbiPxMSEUNvbXBvbmVudF9Vbmtub3duEj8KGkNvbXBvbmVud + F9Db25maWdfVmFsaWRhdG9yEAEaH+I/HBIaQ29tcG9uZW50X0NvbmZpZ19WYWxpZGF0b3ISPwoaQ29tcG9uZW50X0NvbmZpZ19Qb + 
3B1bGF0b3IQAhof4j8cEhpDb21wb25lbnRfQ29uZmlnX1BvcHVsYXRvchJBChtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3IQA + xog4j8dEhtDb21wb25lbnRfRGF0YV9QcmVwcm9jZXNzb3ISPwoaQ29tcG9uZW50X1N1YmdyYXBoX1NhbXBsZXIQBBof4j8cEhpDb + 21wb25lbnRfU3ViZ3JhcGhfU2FtcGxlchI9ChlDb21wb25lbnRfU3BsaXRfR2VuZXJhdG9yEAUaHuI/GxIZQ29tcG9uZW50X1Nwb + Gl0X0dlbmVyYXRvchItChFDb21wb25lbnRfVHJhaW5lchAGGhbiPxMSEUNvbXBvbmVudF9UcmFpbmVyEjMKFENvbXBvbmVudF9Jb + mZlcmVuY2VyEAcaGeI/FhIUQ29tcG9uZW50X0luZmVyZW5jZXJiBnByb3RvMw==""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala index 77a949c19..dd637b565 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/InferencerResourceConfig.scala @@ -32,6 +32,10 @@ final case class InferencerResourceConfig( val __value = inferencerConfig.vertexAiGraphStoreInferencerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (inferencerConfig.customInferencerConfig.isDefined) { + val __value = inferencerConfig.customInferencerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -69,6 +73,12 @@ final case class InferencerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + inferencerConfig.customInferencerConfig.foreach { __v => + val __m = __v + _output__.writeTag(5, 2) + 
_output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = inferencerConfig.vertexAiInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -79,6 +89,8 @@ final case class InferencerResourceConfig( def withLocalInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__v)) def getVertexAiGraphStoreInferencerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = inferencerConfig.vertexAiGraphStoreInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) def withVertexAiGraphStoreInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__v)) + def getCustomInferencerConfig: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = inferencerConfig.customInferencerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.CustomResourceConfig.defaultInstance) + def withCustomInferencerConfig(__v: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig): InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(__v)) def clearInferencerConfig: InferencerResourceConfig = copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) def withInferencerConfig(__v: 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig): InferencerResourceConfig = copy(inferencerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -89,6 +101,7 @@ final case class InferencerResourceConfig( case 2 => inferencerConfig.dataflowInferencerConfig.orNull case 3 => inferencerConfig.localInferencerConfig.orNull case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.orNull + case 5 => inferencerConfig.customInferencerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -98,6 +111,7 @@ final case class InferencerResourceConfig( case 2 => inferencerConfig.dataflowInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => inferencerConfig.localInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 4 => inferencerConfig.vertexAiGraphStoreInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 5 => inferencerConfig.customInferencerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -123,6 +137,8 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(__inferencerConfig.localInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => __inferencerConfig = 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(__inferencerConfig.vertexAiGraphStoreInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 42 => + __inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(__inferencerConfig.customInferencerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -143,12 +159,13 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(_))) 
.orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(5).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(14) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(14) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -156,6 +173,7 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch case 2 => __out = snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig case 4 => __out = 
snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + case 5 => __out = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig } __out } @@ -171,10 +189,12 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch def isDataflowInferencerConfig: _root_.scala.Boolean = false def isLocalInferencerConfig: _root_.scala.Boolean = false def isVertexAiGraphStoreInferencerConfig: _root_.scala.Boolean = false + def isCustomInferencerConfig: _root_.scala.Boolean = false def vertexAiInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def dataflowInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = _root_.scala.None def localInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None + def customInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = _root_.scala.None } object InferencerConfig { @SerialVersionUID(0L) @@ -214,18 +234,27 @@ object InferencerResourceConfig extends scalapb.GeneratedMessageCompanion[snapch override def vertexAiGraphStoreInferencerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) override def number: _root_.scala.Int = 4 } + @SerialVersionUID(0L) + final case class CustomInferencerConfig(value: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig) extends snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig + override def isCustomInferencerConfig: _root_.scala.Boolean = true + override def customInferencerConfig: 
_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = Some(value) + override def number: _root_.scala.Int = 5 + } } implicit class InferencerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig](_l) { def vertexAiInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiInferencerConfig(f_))) def dataflowInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.DataflowResourceConfig] = field(_.getDataflowInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.DataflowInferencerConfig(f_))) def localInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.LocalInferencerConfig(f_))) def vertexAiGraphStoreInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.VertexAiGraphStoreInferencerConfig(f_))) + def customInferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = field(_.getCustomInferencerConfig)((c_, f_) => c_.copy(inferencerConfig = 
snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig.CustomInferencerConfig(f_))) def inferencerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig] = field(_.inferencerConfig)((c_, f_) => c_.copy(inferencerConfig = f_)) } final val VERTEX_AI_INFERENCER_CONFIG_FIELD_NUMBER = 1 final val DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_INFERENCER_CONFIG_FIELD_NUMBER = 3 final val VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG_FIELD_NUMBER = 4 + final val CUSTOM_INFERENCER_CONFIG_FIELD_NUMBER = 5 def of( inferencerConfig: snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig.InferencerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.InferencerResourceConfig( diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala index 909ec979b..1225ba210 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/KFPTrainerConfig.scala @@ -32,35 +32,35 @@ final case class KFPTrainerConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = cpuRequest if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = memoryRequest if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = gpuType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { 
val __value = gpuLimit if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(4, __value) } }; - + { val __value = numReplicas if (__value != 0) { @@ -77,7 +77,7 @@ final case class KFPTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala index ba2cc9389..86e238074 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/LocalTrainerConfig.scala @@ -17,7 +17,7 @@ final case class LocalTrainerConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numWorkers if (__value != 0) { @@ -34,7 +34,7 @@ final case class LocalTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala index 393ebe301..bdeda8bdb 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SharedResourceConfig.scala @@ -116,8 +116,8 @@ object SharedResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat.r ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: 
_root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(14) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(14) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(15) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(15) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala index d32c915cb..96f354f47 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/SparkResourceConfig.scala @@ -25,21 +25,21 @@ final case class SparkResourceConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = machineType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = numLocalSsds if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(2, __value) } }; - + { val __value = numReplicas if (__value != 0) { @@ -56,7 +56,7 @@ final case class SparkResourceConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git 
a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala index 4249c27fe..2323c5f45 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/TrainerResourceConfig.scala @@ -32,6 +32,10 @@ final case class TrainerResourceConfig( val __value = trainerConfig.vertexAiGraphStoreTrainerConfig.get __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; + if (trainerConfig.customTrainerConfig.isDefined) { + val __value = trainerConfig.customTrainerConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -69,6 +73,12 @@ final case class TrainerResourceConfig( _output__.writeUInt32NoTag(__m.serializedSize) __m.writeTo(_output__) }; + trainerConfig.customTrainerConfig.foreach { __v => + val __m = __v + _output__.writeTag(5, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; unknownFields.writeTo(_output__) } def getVertexAiTrainerConfig: snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig = trainerConfig.vertexAiTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig.defaultInstance) @@ -79,6 +89,8 @@ final case class TrainerResourceConfig( def withLocalTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.LocalResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__v)) def getVertexAiGraphStoreTrainerConfig: 
snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig = trainerConfig.vertexAiGraphStoreTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig.defaultInstance) def withVertexAiGraphStoreTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__v)) + def getCustomTrainerConfig: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig = trainerConfig.customTrainerConfig.getOrElse(snapchat.research.gbml.gigl_resource_config.CustomResourceConfig.defaultInstance) + def withCustomTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig): TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(__v)) def clearTrainerConfig: TrainerResourceConfig = copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) def withTrainerConfig(__v: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig): TrainerResourceConfig = copy(trainerConfig = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -89,6 +101,7 @@ final case class TrainerResourceConfig( case 2 => trainerConfig.kfpTrainerConfig.orNull case 3 => trainerConfig.localTrainerConfig.orNull case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.orNull + case 5 => trainerConfig.customTrainerConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -98,6 +111,7 @@ final case class TrainerResourceConfig( case 2 => trainerConfig.kfpTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 3 => 
trainerConfig.localTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) case 4 => trainerConfig.vertexAiGraphStoreTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) + case 5 => trainerConfig.customTrainerConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -123,6 +137,8 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(__trainerConfig.localTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case 34 => __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(__trainerConfig.vertexAiGraphStoreTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) + case 42 => + __trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(__trainerConfig.customTrainerConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -143,12 +159,13 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
.orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(3).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(_))) .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(4).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(_))) + .orElse[snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig](__fieldsMap.get(scalaDescriptor.findFieldByNumber(5).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig]]).map(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(_))) .getOrElse(snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(12) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(12) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = GiglResourceConfigProto.javaDescriptor.getMessageTypes().get(13) + def 
scalaDescriptor: _root_.scalapb.descriptors.Descriptor = GiglResourceConfigProto.scalaDescriptor.messages(13) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -156,6 +173,7 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. case 2 => __out = snapchat.research.gbml.gigl_resource_config.KFPResourceConfig case 3 => __out = snapchat.research.gbml.gigl_resource_config.LocalResourceConfig case 4 => __out = snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig + case 5 => __out = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig } __out } @@ -171,10 +189,12 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. def isKfpTrainerConfig: _root_.scala.Boolean = false def isLocalTrainerConfig: _root_.scala.Boolean = false def isVertexAiGraphStoreTrainerConfig: _root_.scala.Boolean = false + def isCustomTrainerConfig: _root_.scala.Boolean = false def vertexAiTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = _root_.scala.None def kfpTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = _root_.scala.None def localTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = _root_.scala.None def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = _root_.scala.None + def customTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = _root_.scala.None } object TrainerConfig { @SerialVersionUID(0L) @@ -214,18 +234,27 @@ object TrainerResourceConfig extends scalapb.GeneratedMessageCompanion[snapchat. 
override def vertexAiGraphStoreTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = Some(value) override def number: _root_.scala.Int = 4 } + @SerialVersionUID(0L) + final case class CustomTrainerConfig(value: snapchat.research.gbml.gigl_resource_config.CustomResourceConfig) extends snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig { + type ValueType = snapchat.research.gbml.gigl_resource_config.CustomResourceConfig + override def isCustomTrainerConfig: _root_.scala.Boolean = true + override def customTrainerConfig: _root_.scala.Option[snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = Some(value) + override def number: _root_.scala.Int = 5 + } } implicit class TrainerResourceConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig](_l) { def vertexAiTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiResourceConfig] = field(_.getVertexAiTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiTrainerConfig(f_))) def kfpTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.KFPResourceConfig] = field(_.getKfpTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.KfpTrainerConfig(f_))) def localTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.LocalResourceConfig] = field(_.getLocalTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.LocalTrainerConfig(f_))) def vertexAiGraphStoreTrainerConfig: 
_root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.VertexAiGraphStoreConfig] = field(_.getVertexAiGraphStoreTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.VertexAiGraphStoreTrainerConfig(f_))) + def customTrainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.CustomResourceConfig] = field(_.getCustomTrainerConfig)((c_, f_) => c_.copy(trainerConfig = snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig.CustomTrainerConfig(f_))) def trainerConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig] = field(_.trainerConfig)((c_, f_) => c_.copy(trainerConfig = f_)) } final val VERTEX_AI_TRAINER_CONFIG_FIELD_NUMBER = 1 final val KFP_TRAINER_CONFIG_FIELD_NUMBER = 2 final val LOCAL_TRAINER_CONFIG_FIELD_NUMBER = 3 final val VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG_FIELD_NUMBER = 4 + final val CUSTOM_TRAINER_CONFIG_FIELD_NUMBER = 5 def of( trainerConfig: snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig.TrainerConfig ): _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig = _root_.snapchat.research.gbml.gigl_resource_config.TrainerResourceConfig( diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala index c088bafc2..37d730799 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gigl_resource_config/VertexAiTrainerConfig.scala @@ -29,28 +29,28 @@ final case class VertexAiTrainerConfig( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): 
_root_.scala.Int = { var __size = 0 - + { val __value = machineType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = gpuType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = gpuLimit if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(3, __value) } }; - + { val __value = numReplicas if (__value != 0) { @@ -67,7 +67,7 @@ final case class VertexAiTrainerConfig( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala index cd1501b1f..dc359d1e3 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Edge.scala @@ -33,14 +33,14 @@ final case class Edge( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = srcNodeId if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(1, __value) } }; - + { val __value = dstNodeId if (__value != 0) { @@ -65,7 +65,7 @@ final case class Edge( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala index d2a68d8b9..439ae2229 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala 
+++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/EdgeType.scala @@ -21,21 +21,21 @@ final case class EdgeType( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = relation if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = srcNodeType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = dstNodeType if (!__value.isEmpty) { @@ -52,7 +52,7 @@ final case class EdgeType( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala index 0e4e3105d..7c317bb6d 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Graph.scala @@ -35,7 +35,7 @@ final case class Graph( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { nodes.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala index 8c6307580..856a159a0 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphMetadata.scala @@ -58,7 +58,7 @@ final case class GraphMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: 
_root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { nodeTypes.foreach { __v => @@ -205,7 +205,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { @@ -226,7 +226,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -269,7 +269,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research def companion: snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry.type = snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GraphMetadata.CondensedEdgeTypeMapEntry]) } - + object CondensedEdgeTypeMapEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.graph_schema.GraphMetadata.CondensedEdgeTypeMapEntry = { @@ -341,7 +341,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GraphMetadata.CondensedEdgeTypeMapEntry]) } - + @SerialVersionUID(0L) final case class CondensedNodeTypeMapEntry( key: _root_.scala.Int = 0, @@ -352,14 +352,14 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def 
__computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -376,7 +376,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -420,7 +420,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research def companion: snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry.type = snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GraphMetadata.CondensedNodeTypeMapEntry]) } - + object CondensedNodeTypeMapEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.graph_schema.GraphMetadata.CondensedNodeTypeMapEntry = { @@ -485,7 +485,7 @@ object GraphMetadata extends scalapb.GeneratedMessageCompanion[snapchat.research ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GraphMetadata.CondensedNodeTypeMapEntry]) } - + implicit class GraphMetadataLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.graph_schema.GraphMetadata]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.graph_schema.GraphMetadata](_l) { def nodeTypes: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[_root_.scala.Predef.String]] = field(_.nodeTypes)((c_, f_) => c_.copy(nodeTypes = f_)) def edgeTypes: 
_root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Seq[snapchat.research.gbml.graph_schema.EdgeType]] = field(_.edgeTypes)((c_, f_) => c_.copy(edgeTypes = f_)) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala index 28c2eff67..0cd9e60ca 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/GraphSchemaProto.scala @@ -50,4 +50,4 @@ object GraphSchemaProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala index e8a4c6f98..7e4c5c8b9 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/graph_schema/Node.scala @@ -27,7 +27,7 @@ final case class Node( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = nodeId if (__value != 0) { @@ -52,7 +52,7 @@ final case class Node( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala index 6668d67ce..a58d891c0 100644 --- 
a/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadata.scala @@ -31,7 +31,7 @@ final case class InferenceMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { nodeTypeToInferencerOutputInfoMap.foreach { __v => @@ -123,7 +123,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { @@ -144,7 +144,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -187,7 +187,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese def companion: snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry.type = snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry]) } - + object NodeTypeToInferencerOutputInfoMapEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.inference_metadata.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry = { 
@@ -259,7 +259,7 @@ object InferenceMetadata extends scalapb.GeneratedMessageCompanion[snapchat.rese ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.InferenceMetadata.NodeTypeToInferencerOutputInfoMapEntry]) } - + implicit class InferenceMetadataLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.inference_metadata.InferenceMetadata]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.inference_metadata.InferenceMetadata](_l) { def nodeTypeToInferencerOutputInfoMap: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, snapchat.research.gbml.inference_metadata.InferenceOutput]] = field(_.nodeTypeToInferencerOutputInfoMap)((c_, f_) => c_.copy(nodeTypeToInferencerOutputInfoMap = f_)) } diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala index 7e1424d6e..9c335401b 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceMetadataProto.scala @@ -35,4 +35,4 @@ object InferenceMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala index 845fbb29c..66c68b47b 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/inference_metadata/InferenceOutput.scala @@ -38,7 +38,7 @@ final case class InferenceOutput( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { embeddingsPath.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala index a0399d909..8ebed0a48 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostProcessedMetadata.scala @@ -17,7 +17,7 @@ final case class PostProcessedMetadata( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = postProcessorLogMetricsUri if (!__value.isEmpty) { @@ -34,7 +34,7 @@ final case class PostProcessedMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala index 4b0e94597..d36401685 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/postprocessed_metadata/PostprocessedMetadataProto.scala @@ -28,4 +28,4 @@ object PostprocessedMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala index 6a160cdd4..80160636b 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadata.scala @@ -38,7 +38,7 @@ final case class PreprocessedMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { condensedNodeTypeToPreprocessedMetadata.foreach { __v => @@ -181,7 +181,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = nodeIdKey if (!__value.isEmpty) { @@ -196,28 +196,28 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = __item __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } - + { val 
__value = tfrecordUriPrefix if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(4, __value) } }; - + { val __value = schemaUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(5, __value) } }; - + { val __value = enumeratedNodeIdsBqTable if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(6, __value) } }; - + { val __value = enumeratedNodeDataBqTable if (!__value.isEmpty) { @@ -228,7 +228,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = featureDim.get __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(8, __value) }; - + { val __value = transformFnAssetsUri if (!__value.isEmpty) { @@ -245,7 +245,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -366,7 +366,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.NodeMetadataOutput]) } - + object NodeMetadataOutput extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput = { @@ -499,7 +499,7 @@ object PreprocessedMetadata extends 
scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.NodeMetadataOutput]) } - + /** Houses metadata of edge features output from DataPreprocessor * * @param featureKeys @@ -540,21 +540,21 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = __item __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } - + { val __value = tfrecordUriPrefix if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { val __value = schemaUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(4, __value) } }; - + { val __value = enumeratedEdgeDataBqTable if (!__value.isEmpty) { @@ -565,7 +565,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r val __value = featureDim.get __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(6, __value) }; - + { val __value = transformFnAssetsUri if (!__value.isEmpty) { @@ -582,7 +582,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { featureKeys.foreach { __v => @@ -679,7 +679,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataInfo]) } - + object EdgeMetadataInfo extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataInfo = { @@ -792,7 +792,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataInfo]) } - + /** Houses metadata about edge TFTransform output from DataPreprocessor. * * @param srcNodeIdKey @@ -819,14 +819,14 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = srcNodeIdKey if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = dstNodeIdKey if (!__value.isEmpty) { @@ -855,7 +855,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -932,7 +932,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataOutput]) } - + object EdgeMetadataOutput extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput = { @@ -1035,7 +1035,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.EdgeMetadataOutput]) } - + @SerialVersionUID(0L) final case class CondensedNodeTypeToPreprocessedMetadataEntry( key: _root_.scala.Int = 0, @@ -1046,7 +1046,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { @@ -1067,7 +1067,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -1110,7 +1110,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry]) } - + object CondensedNodeTypeToPreprocessedMetadataEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry = { @@ -1182,7 +1182,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.CondensedNodeTypeToPreprocessedMetadataEntry]) } - + @SerialVersionUID(0L) final case class CondensedEdgeTypeToPreprocessedMetadataEntry( key: _root_.scala.Int = 0, @@ -1193,7 +1193,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (__value != 0) { @@ -1214,7 +1214,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -1257,7 +1257,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r def companion: snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry.type = snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry]) } - + object CondensedEdgeTypeToPreprocessedMetadataEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry] { implicit def messageCompanion: 
scalapb.GeneratedMessageCompanion[snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry = { @@ -1329,7 +1329,7 @@ object PreprocessedMetadata extends scalapb.GeneratedMessageCompanion[snapchat.r ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.PreprocessedMetadata.CondensedEdgeTypeToPreprocessedMetadataEntry]) } - + implicit class PreprocessedMetadataLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata](_l) { def condensedNodeTypeToPreprocessedMetadata: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Int, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.NodeMetadataOutput]] = field(_.condensedNodeTypeToPreprocessedMetadata)((c_, f_) => c_.copy(condensedNodeTypeToPreprocessedMetadata = f_)) def condensedEdgeTypeToPreprocessedMetadata: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Int, snapchat.research.gbml.preprocessed_metadata.PreprocessedMetadata.EdgeMetadataOutput]] = field(_.condensedEdgeTypeToPreprocessedMetadata)((c_, f_) => c_.copy(condensedEdgeTypeToPreprocessedMetadata = f_)) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala index b6e8d0d6d..becc2d068 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala +++ 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/preprocessed_metadata/PreprocessedMetadataProto.scala @@ -61,4 +61,4 @@ object PreprocessedMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala index c56f47fa4..2a6227235 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/GlobalRandomUniformStrategy.scala @@ -15,7 +15,7 @@ final case class GlobalRandomUniformStrategy( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numHops if (__value != 0) { @@ -36,7 +36,7 @@ final case class GlobalRandomUniformStrategy( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala index dc3069c8b..4aad9e19c 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPath.scala @@ -15,7 +15,7 @@ final case class 
MessagePassingPath( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = rootNodeType if (!__value.isEmpty) { @@ -36,7 +36,7 @@ final case class MessagePassingPath( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala index fd8a906af..ee3aa77a0 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/MessagePassingPathStrategy.scala @@ -34,7 +34,7 @@ final case class MessagePassingPathStrategy( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { paths.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala index 5579eca5a..ef0fde958 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomUniform.scala @@ -5,7 +5,7 @@ package snapchat.research.gbml.subgraph_sampling_strategy -/** Randomly sample nodes from the neighborhood without replacement. +/** Randomly sample nodes from the neighborhood without replacement. 
*/ @SerialVersionUID(0L) final case class RandomUniform( @@ -16,7 +16,7 @@ final case class RandomUniform( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numNodesToSample if (__value != 0) { @@ -33,7 +33,7 @@ final case class RandomUniform( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala index 389735b80..395f3876b 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/RandomWeighted.scala @@ -17,14 +17,14 @@ final case class RandomWeighted( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numNodesToSample if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(1, __value) } }; - + { val __value = edgeFeatName if (!__value.isEmpty) { @@ -41,7 +41,7 @@ final case class RandomWeighted( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala index b35d80157..6b6b756d5 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala +++ 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingDirection.scala @@ -17,7 +17,7 @@ sealed abstract class SamplingDirection(val value: _root_.scala.Int) extends _ro object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingDirection] { sealed trait Recognized extends SamplingDirection implicit def enumCompanion: _root_.scalapb.GeneratedEnumCompanion[SamplingDirection] = this - + /** Sample incoming edges to the dst nodes (default) */ @SerialVersionUID(0L) @@ -26,7 +26,7 @@ object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingD val name = "INCOMING" override def isIncoming: _root_.scala.Boolean = true } - + /** Sample outgoing edges from the src nodes */ @SerialVersionUID(0L) @@ -35,7 +35,7 @@ object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingD val name = "OUTGOING" override def isOutgoing: _root_.scala.Boolean = true } - + @SerialVersionUID(0L) final case class Unrecognized(unrecognizedValue: _root_.scala.Int) extends SamplingDirection(unrecognizedValue) with _root_.scalapb.UnrecognizedEnum lazy val values: scala.collection.immutable.Seq[ValueType] = scala.collection.immutable.Seq(INCOMING, OUTGOING) @@ -46,4 +46,4 @@ object SamplingDirection extends _root_.scalapb.GeneratedEnumCompanion[SamplingD } def javaDescriptor: _root_.com.google.protobuf.Descriptors.EnumDescriptor = SubgraphSamplingStrategyProto.javaDescriptor.getEnumTypes().get(0) def scalaDescriptor: _root_.scalapb.descriptors.EnumDescriptor = SubgraphSamplingStrategyProto.scalaDescriptor.enums(0) -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala index 513ceaf83..6afa2e8b6 100644 --- 
a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SamplingOp.scala @@ -28,7 +28,7 @@ final case class SamplingOp( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = opName if (!__value.isEmpty) { @@ -59,7 +59,7 @@ final case class SamplingOp( val __value = samplingMethod.userDefined.get __size += 2 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize }; - + { val __value = samplingDirection.value if (__value != 0) { @@ -76,7 +76,7 @@ final case class SamplingOp( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -293,7 +293,7 @@ object SamplingOp extends scalapb.GeneratedMessageCompanion[snapchat.research.gb override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class RandomUniform(value: snapchat.research.gbml.subgraph_sampling_strategy.RandomUniform) extends snapchat.research.gbml.subgraph_sampling_strategy.SamplingOp.SamplingMethod { type ValueType = snapchat.research.gbml.subgraph_sampling_strategy.RandomUniform diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala index 5f8e3dee6..4abc7182b 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala +++ 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategy.scala @@ -32,7 +32,7 @@ final case class SubgraphSamplingStrategy( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { strategy.messagePassingPaths.foreach { __v => @@ -143,7 +143,7 @@ object SubgraphSamplingStrategy extends scalapb.GeneratedMessageCompanion[snapch override def number: _root_.scala.Int = 0 override def value: _root_.scala.Nothing = throw new java.util.NoSuchElementException("Empty.value") } - + @SerialVersionUID(0L) final case class MessagePassingPaths(value: snapchat.research.gbml.subgraph_sampling_strategy.MessagePassingPathStrategy) extends snapchat.research.gbml.subgraph_sampling_strategy.SubgraphSamplingStrategy.Strategy { type ValueType = snapchat.research.gbml.subgraph_sampling_strategy.MessagePassingPathStrategy diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala index 709ae1159..e6ae31c4c 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/SubgraphSamplingStrategyProto.scala @@ -67,4 +67,4 @@ object SubgraphSamplingStrategyProto extends _root_.scalapb.GeneratedFileObject } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala index 86bda274e..39c5388e4 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/TopK.scala @@ -17,14 +17,14 @@ final case class TopK( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = numNodesToSample if (__value != 0) { __size += _root_.com.google.protobuf.CodedOutputStream.computeInt32Size(1, __value) } }; - + { val __value = edgeFeatName if (!__value.isEmpty) { @@ -41,7 +41,7 @@ final case class TopK( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala index a055bfbaa..cc00d0f9d 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/subgraph_sampling_strategy/UserDefined.scala @@ -20,7 +20,7 @@ final case class UserDefined( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = pathToUdf if (!__value.isEmpty) { @@ -41,7 +41,7 @@ final case class UserDefined( 
__serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -151,14 +151,14 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = key if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = value if (!__value.isEmpty) { @@ -175,7 +175,7 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { @@ -219,7 +219,7 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g def companion: snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry.type = snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.UserDefined.ParamsEntry]) } - + object ParamsEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry] { implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry] = this def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.subgraph_sampling_strategy.UserDefined.ParamsEntry = { @@ -284,7 +284,7 @@ object UserDefined extends scalapb.GeneratedMessageCompanion[snapchat.research.g ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.UserDefined.ParamsEntry]) } - + implicit class UserDefinedLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.subgraph_sampling_strategy.UserDefined]) 
extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.subgraph_sampling_strategy.UserDefined](_l) { def pathToUdf: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.pathToUdf)((c_, f_) => c_.copy(pathToUdf = f_)) def params: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.params)((c_, f_) => c_.copy(params = f_)) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala index bcf95c046..2c5a042f9 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadata.scala @@ -26,28 +26,28 @@ final case class TrainedModelMetadata( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = trainedModelUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = scriptedModelUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) } }; - + { val __value = evalMetricsUri if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(3, __value) } }; - + { val __value = tensorboardLogsUri if (!__value.isEmpty) { @@ -64,7 +64,7 @@ final case class TrainedModelMetadata( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala 
b/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala index b06e0d55a..1262d9517 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/trained_model_metadata/TrainedModelMetadataProto.scala @@ -31,4 +31,4 @@ object TrainedModelMetadataProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala index 21884234b..3289f2d76 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/Label.scala @@ -15,14 +15,14 @@ final case class Label( private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 private[this] def __computeSerializedSize(): _root_.scala.Int = { var __size = 0 - + { val __value = labelType if (!__value.isEmpty) { __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) } }; - + { val __value = label if (__value != 0) { @@ -39,7 +39,7 @@ final case class Label( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { { diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala 
index a05a6ef18..a1f7def5a 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/NodeAnchorBasedLinkPredictionSample.scala @@ -64,7 +64,7 @@ final case class NodeAnchorBasedLinkPredictionSample( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootNode.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala index f326f0375..cdd1d4a4c 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/RootedNodeNeighborhood.scala @@ -41,7 +41,7 @@ final case class RootedNodeNeighborhood( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootNode.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala index 4ffbeb328..9707ea68a 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedLinkBasedTaskSample.scala @@ -48,7 +48,7 @@ final case class SupervisedLinkBasedTaskSample( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: 
_root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootEdge.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala index 35c594ca3..00304a823 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/SupervisedNodeClassificationSample.scala @@ -43,7 +43,7 @@ final case class SupervisedNodeClassificationSample( __serializedSizeMemoized = __size } __size - 1 - + } def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { rootNode.foreach { __v => diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala index f4187ea60..f02891e26 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/training_samples_schema/TrainingSamplesSchemaProto.scala @@ -51,4 +51,4 @@ object TrainingSamplesSchemaProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/userDefinedAggregators/RnnUDAF.scala b/scala_spark35/common/src/main/scala/userDefinedAggregators/RnnUDAF.scala index 1900f8e07..856e64e47 100644 --- a/scala_spark35/common/src/main/scala/userDefinedAggregators/RnnUDAF.scala +++ b/scala_spark35/common/src/main/scala/userDefinedAggregators/RnnUDAF.scala @@ -70,16 +70,16 @@ class RnnUDAF( extends Aggregator[RnnUDAF.InTwoHopData, RnnUDAF.BufferRNN, Array[Byte]] { /** - * Introduces a custom user defined aggregation function that + * Introduces a custom user defined aggregation function that * allows for more efficient "GROUP BY" on "root_node_id" when formulating a 2 hop subgraph, * as compared to using default Spark aggregate functions like array_append, array_union, array_agg, et al. * These functions are quite expensive and not suitable for aggregating all types of columns. - * + * * The UDAF is used to aggregate the 2-hop subgraph information into a single RootedNodeNeighborhood * protobuf message (byte array). - * + * * sampleN: Option[Int] - The number of edges to sample from the 1-hop and 2-hop neighbors of the root node. - * + * * Example usage: * spark.udf.register("rnnUDAF", F.udaf(new RnnUDAF(sampleN = Some(VAL)))) * ... @@ -99,9 +99,9 @@ class RnnUDAF( * _2_hop_node_features, * _2_hop_edge_features, * _2_hop_edge_type - * ) as result + * ) as result * FROM - * ... + * ... 
* GROUP BY * _root_node_id, _root_node_type */ diff --git a/scala_spark35/common/src/main/scala/utils/SlottedJoiner.scala b/scala_spark35/common/src/main/scala/utils/SlottedJoiner.scala index d9635d857..2006b079d 100644 --- a/scala_spark35/common/src/main/scala/utils/SlottedJoiner.scala +++ b/scala_spark35/common/src/main/scala/utils/SlottedJoiner.scala @@ -9,7 +9,7 @@ object SlottedJoiner { /** * This class helps handle OOM and disk space issues in Spark jobs during large table joins. - * Instead of one big join, it partitions the left table into smaller tables and joins + * Instead of one big join, it partitions the left table into smaller tables and joins * them iteratively with the right table, ensuring better scalability with commodity hardware. * * Usage: @@ -17,11 +17,11 @@ object SlottedJoiner { * vla rightDf = ... * val numSlots = 10 * val slottedLeftDf = SlottedJoiner.computeSlotsOnDataframe( - * df=leftDf, - * columnToComputeSlotOn="joinKey", + * df=leftDf, + * columnToComputeSlotOn="joinKey", * numSlots=numSlots * ) - * + * * // Caching helps us avoid recomputing the tables * cacher.createDiskPartitionedTable( * df = leftSlottedDF, @@ -119,7 +119,7 @@ object SlottedJoiner { val queryWithSlottedTables = f""" WITH ${leftTableName} as ( - SELECT * + SELECT * FROM ${leftSlottedTableName} WHERE ${SLOT_NUM_COLUMN_NAME} = ${slotNum} ) ${currSlotQuery} @@ -171,12 +171,12 @@ object SlottedJoiner { .replace(rightSlottedTableName, rightTableName) val queryWithSlottedTables = f""" WITH ${leftTableName} as ( - SELECT * + SELECT * FROM ${leftSlottedTableName} WHERE ${SLOT_NUM_COLUMN_NAME} = ${slotNum} ), ${rightTableName} as ( - SELECT * + SELECT * FROM ${rightSlottedTableName} WHERE ${SLOT_NUM_COLUMN_NAME} = ${slotNum} ) ${currSlotQuery} diff --git a/scala_spark35/common/src/test/assets/resource_config.yaml b/scala_spark35/common/src/test/assets/resource_config.yaml index 1a795ef14..8eff18103 100644 --- a/scala_spark35/common/src/test/assets/resource_config.yaml +++ 
b/scala_spark35/common/src/test/assets/resource_config.yaml @@ -41,4 +41,4 @@ inferencer_config: num_workers: 1 max_num_workers: 256 machine_type: "c3-standard-22" - disk_size_gb: 100 \ No newline at end of file + disk_size_gb: 100 diff --git a/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml b/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml index b1a1421cd..a0a2ba944 100644 --- a/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml +++ b/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/frozen_gbml_config.yaml @@ -44,4 +44,4 @@ graphMetadata: relation: engage srcNodeType: user nodeTypes: - - user \ No newline at end of file + - user diff --git a/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml b/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml index 429035cbb..f44cdbb44 100644 --- a/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml +++ b/scala_spark35/common/src/test/assets/split_generator/node_anchor_based_link_prediction/preprocessed_metadata.yaml @@ -15,4 +15,4 @@ condensedNodeTypeToPreprocessedMetadata: - f1 nodeIdKey: node_id schemaUri: not.used.for.test - tfrecordUriPrefix: not.used.for.test \ No newline at end of file + tfrecordUriPrefix: not.used.for.test diff --git a/scala_spark35/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml b/scala_spark35/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml index 90762002c..68711a211 100644 --- a/scala_spark35/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml +++ 
b/scala_spark35/common/src/test/assets/split_generator/supervised_node_classification/frozen_gbml_config.yaml @@ -28,4 +28,4 @@ sharedConfig: supervisedNodeClassificationOutput: labeledTfrecordUriPrefix: common/src/test/assets/split_generator/supervised_node_classification/sgs_output/labeled/samples/ unlabeledTfrecordUriPrefix: common/src/test/assets/split_generator/supervised_node_classification/sgs_output/unlabeled/samples/ - preprocessedMetadataUri: common/src/test/assets/split_generator/supervised_node_classification/preprocessed_metadata.yaml \ No newline at end of file + preprocessedMetadataUri: common/src/test/assets/split_generator/supervised_node_classification/preprocessed_metadata.yaml diff --git a/scala_spark35/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml b/scala_spark35/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml index cb1587a83..798350dd0 100644 --- a/scala_spark35/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml +++ b/scala_spark35/common/src/test/assets/subgraph_sampler/supervised_node_classification/frozen_gbml_config.yaml @@ -24,4 +24,4 @@ sharedConfig: supervisedNodeClassificationOutput: labeledTfrecordUriPrefix: common/src/test/assets/subgraph_sampler/supervised_node_classification/output/labeled/samples/ unlabeledTfrecordUriPrefix: common/src/test/assets/subgraph_sampler/supervised_node_classification/output/unlabeled/samples/ - preprocessedMetadataUri: common/src/test/assets/subgraph_sampler/supervised_node_classification/preprocessed_metadata.yaml \ No newline at end of file + preprocessedMetadataUri: common/src/test/assets/subgraph_sampler/supervised_node_classification/preprocessed_metadata.yaml diff --git a/scala_spark35/common/src/test/scala/userDefinedAggregators/RnnUDAFTest.scala b/scala_spark35/common/src/test/scala/userDefinedAggregators/RnnUDAFTest.scala index 699ae40f5..ca3a1b984 100644 
--- a/scala_spark35/common/src/test/scala/userDefinedAggregators/RnnUDAFTest.scala +++ b/scala_spark35/common/src/test/scala/userDefinedAggregators/RnnUDAFTest.scala @@ -263,10 +263,10 @@ class RnnUDAFTest extends AnyFunSuite with BeforeAndAfterAll with SharedSparkSes _2_hop_node_features, _2_hop_edge_features, _2_hop_edge_type - ) as result - FROM - test_view - GROUP BY + ) as result + FROM + test_view + GROUP BY _root_node_id, _root_node_type """) diff --git a/scala_spark35/common/src/test/scala/utils/SlottedJoinerTest.scala b/scala_spark35/common/src/test/scala/utils/SlottedJoinerTest.scala index 49a77ce58..f6fa5a1d8 100644 --- a/scala_spark35/common/src/test/scala/utils/SlottedJoinerTest.scala +++ b/scala_spark35/common/src/test/scala/utils/SlottedJoinerTest.scala @@ -89,11 +89,11 @@ class SlottedJoinerTest extends AnyFunSuite with BeforeAndAfterAll with SharedSp slottedOnSrc.dst_node as root_node, slottedOnDst.dst_node as 1_hop_node, slottedOnDst.src_node as 2_hop_node - FROM - slottedOnSrc - JOIN - slottedOnDst - ON + FROM + slottedOnSrc + JOIN + slottedOnDst + ON slottedOnSrc.src_node = slottedOnDst.dst_node """, numSlots = numSlots, diff --git a/scala_spark35/split_generator/src/main/scala/Main.scala b/scala_spark35/split_generator/src/main/scala/Main.scala index b54f019e4..8c8fa5b2e 100644 --- a/scala_spark35/split_generator/src/main/scala/Main.scala +++ b/scala_spark35/split_generator/src/main/scala/Main.scala @@ -17,7 +17,7 @@ object Main { val resourceConfigYamlGcsUri = args(2) println(f""" - Starting Split Generator with the following arguments: + Starting Split Generator with the following arguments: sparkAppName=${sparkAppName}, frozenGbmlConfigYamlGcsUri=${frozenGbmlConfigYamlGcsUri}, resourceConfigYamlGcsUri=${resourceConfigYamlGcsUri} diff --git a/scala_spark35/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala b/scala_spark35/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala index 90524f6a0..891198953 100644 
--- a/scala_spark35/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala +++ b/scala_spark35/split_generator/src/main/scala/lib/assigners/AbstractAssigners.scala @@ -22,7 +22,7 @@ object AbstractAssigners { * e.g. could be assigning a NodePb (T) to some Enum (S). * * @param obj the object to hash - * @return + * @return */ def assign(obj: T): S } @@ -59,7 +59,7 @@ object AbstractAssigners { /** * Relative width of each bucket in the hash space. e.g. [0.2, 0.4, 0.4] would indicate 3 buckets, where - * the second and third bucket are twice as prominent as the first bucket. + * the second and third bucket are twice as prominent as the first bucket. */ lazy val weights: Seq[Float] = bucketWeights.values.toList diff --git a/scala_spark35/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala b/scala_spark35/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala index ba15966a3..b0d7e1146 100644 --- a/scala_spark35/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala +++ b/scala_spark35/split_generator/src/main/scala/lib/split_strategies/SplitStrategy.scala @@ -32,7 +32,7 @@ abstract class SplitStrategy[A](splitStrategyArgs: Map[String, String]) extends val graphMetadataPbWrapper: GraphMetadataPbWrapper /** - * Takes in a single "un-split" training sample instance output by SubgraphSampler, + * Takes in a single "un-split" training sample instance output by SubgraphSampler, * and a DatasetSplit(TRAIN, TEST, VAL) and outputs the the "split" samples for that dataset split * * @param sample : Input Sample from SGS diff --git a/scala_spark35/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala b/scala_spark35/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala index 61fefef6f..12e378ebe 100644 --- a/scala_spark35/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala 
+++ b/scala_spark35/split_generator/src/main/scala/lib/split_strategies/UDLAnchorBasedSupervisionEdgeSplitStrategy.scala @@ -37,7 +37,7 @@ class UDLAnchorBasedSupervisionEdgeSplitStrategy( * (a) All pos_edges and hard_neg_edges belonging to the split. * (b) message passing structure which should be pb.neighborhood and therefore the same across all splits * (i.e. no masking). - * (c) The message passing structure may be filtered down to only include edges that are not in the pos_edges + * (c) The message passing structure may be filtered down to only include edges that are not in the pos_edges * and hard_neg_edges. * An output train-split sample needs to have >0 pos_edges in this setting for loss computation. * Output val/test-split samples may have 0 pos_edges (and even 0 hard_neg_edges), since these diff --git a/scala_spark35/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala b/scala_spark35/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala index a1c5b9081..dd63cc113 100644 --- a/scala_spark35/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala +++ b/scala_spark35/subgraph_sampler/src/main/scala/libs/task/TaskOutputValidator.scala @@ -17,7 +17,7 @@ object TaskOutputValidator { * is present in the neighborhood nodes. * This method does a dataset.map() on the final output produced by SGS and returns the same dataset * if there is no validation failure. Raises and excpetion if there is some error - * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any + * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any * duplication of computation due to this validation code. * * @param mainSampleDS @@ -50,7 +50,7 @@ object TaskOutputValidator { * is present in the neighborhood nodes. * This method does a dataset.map() on the final output produced by SGS and returns the same dataset * if there is no validation failure. 
Raises and excpetion if there is some error - * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any + * @spark: dataset.map() is not an action (unlike foreach) and does not lead to any * duplication of computation due to this validation code. * * @param mainSampleDS diff --git a/scala_spark35/subgraph_sampler/src/main/scala/libs/task/pureSparkV2/EgoNetGeneration.scala b/scala_spark35/subgraph_sampler/src/main/scala/libs/task/pureSparkV2/EgoNetGeneration.scala index b66300d9e..68ad26f3d 100644 --- a/scala_spark35/subgraph_sampler/src/main/scala/libs/task/pureSparkV2/EgoNetGeneration.scala +++ b/scala_spark35/subgraph_sampler/src/main/scala/libs/task/pureSparkV2/EgoNetGeneration.scala @@ -46,7 +46,7 @@ object EgoNetGeneration { spark.sql( s""" SELECT DISTINCT * FROM ( - SELECT + SELECT dst_node_id as src_node_id, src_node_id as dst_node_id, ${DEFAULT_EDGE_TYPE} as edge_type, @@ -454,10 +454,10 @@ class EgoNetGeneration( val toNodeIdColumn = flags.to_node_id_column println(s""" - Running EgoNetGeneration Job w/ - nodeTableName: ${nodeTableName}, - edgeTableName: ${edgeTableName}, - fromNodeIdColumn: ${fromNodeIdColumn}, + Running EgoNetGeneration Job w/ + nodeTableName: ${nodeTableName}, + edgeTableName: ${edgeTableName}, + fromNodeIdColumn: ${fromNodeIdColumn}, toNodeIdColumn: ${toNodeIdColumn} """) diff --git a/snapchat/research/gbml/gigl_resource_config_pb2.py b/snapchat/research/gbml/gigl_resource_config_pb2.py index bbda8cf57..3073ce92a 100644 --- a/snapchat/research/gbml/gigl_resource_config_pb2.py +++ b/snapchat/research/gbml/gigl_resource_config_pb2.py @@ -15,7 +15,7 @@ -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 
\x01(\r\"\x83\x01\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"O\n\x1bVertexAiReservationAffinity\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\"\n\x1areservation_resource_names\x18\x02 \x03(\t\"\xa2\x02\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\x12\x1b\n\x13scheduling_strategy\x18\x07 \x01(\t\x12\x19\n\x11\x62oot_disk_size_gb\x18\x08 \x01(\r\x12Q\n\x14reservation_affinity\x18\t \x01(\x0b\x32\x33.snapchat.research.gbml.VertexAiReservationAffinity\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\xd4\x01\n\x18VertexAiGraphStoreConfig\x12H\n\x10graph_store_pool\x18\x01 
\x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12\x44\n\x0c\x63ompute_pool\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12(\n compute_cluster_local_world_size\x18\x03 \x01(\x05\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\xf5\x02\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12`\n$vertex_ai_graph_store_trainer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x10\n\x0etrainer_config\"\x91\x03\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12\x63\n\'vertex_ai_graph_store_inferencer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 
\x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 \x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b \x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 \x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 \x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') +DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n1snapchat/research/gbml/gigl_resource_config.proto\x12\x16snapchat.research.gbml\"Y\n\x13SparkResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x16\n\x0enum_local_ssds\x18\x02 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x03 \x01(\r\"\x83\x01\n\x16\x44\x61taflowResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\x12\x17\n\x0fmax_num_workers\x18\x02 \x01(\r\x12\x14\n\x0cmachine_type\x18\x03 \x01(\t\x12\x14\n\x0c\x64isk_size_gb\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\"\xbc\x01\n\x16\x44\x61taPreprocessorConfig\x12P\n\x18\x65\x64ge_preprocessor_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\x12P\n\x18node_preprocessor_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfig\"h\n\x15VertexAiTrainerConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\"z\n\x10KFPTrainerConfig\x12\x13\n\x0b\x63pu_request\x18\x01 \x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\")\n\x12LocalTrainerConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"O\n\x1bVertexAiReservationAffinity\x12\x0c\n\x04type\x18\x01 \x01(\t\x12\"\n\x1areservation_resource_names\x18\x02 \x03(\t\"\xa2\x02\n\x16VertexAiResourceConfig\x12\x14\n\x0cmachine_type\x18\x01 \x01(\t\x12\x10\n\x08gpu_type\x18\x02 \x01(\t\x12\x11\n\tgpu_limit\x18\x03 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x04 \x01(\r\x12\x0f\n\x07timeout\x18\x05 \x01(\r\x12\x1b\n\x13gcp_region_override\x18\x06 \x01(\t\x12\x1b\n\x13scheduling_strategy\x18\x07 \x01(\t\x12\x19\n\x11\x62oot_disk_size_gb\x18\x08 \x01(\r\x12Q\n\x14reservation_affinity\x18\t \x01(\x0b\x32\x33.snapchat.research.gbml.VertexAiReservationAffinity\"{\n\x11KFPResourceConfig\x12\x13\n\x0b\x63pu_request\x18\x01 
\x01(\t\x12\x16\n\x0ememory_request\x18\x02 \x01(\t\x12\x10\n\x08gpu_type\x18\x03 \x01(\t\x12\x11\n\tgpu_limit\x18\x04 \x01(\r\x12\x14\n\x0cnum_replicas\x18\x05 \x01(\r\"*\n\x13LocalResourceConfig\x12\x13\n\x0bnum_workers\x18\x01 \x01(\r\"\xd4\x01\n\x18VertexAiGraphStoreConfig\x12H\n\x10graph_store_pool\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12\x44\n\x0c\x63ompute_pool\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfig\x12(\n compute_cluster_local_world_size\x18\x03 \x01(\x05\"5\n\x14\x43ustomResourceConfig\x12\x0f\n\x07\x63ommand\x18\x01 \x01(\t\x12\x0c\n\x04\x61rgs\x18\x02 \x03(\t\"\x93\x02\n\x18\x44istributedTrainerConfig\x12Q\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32-.snapchat.research.gbml.VertexAiTrainerConfigH\x00\x12\x46\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32(.snapchat.research.gbml.KFPTrainerConfigH\x00\x12J\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32*.snapchat.research.gbml.LocalTrainerConfigH\x00\x42\x10\n\x0etrainer_config\"\xc4\x03\n\x15TrainerResourceConfig\x12R\n\x18vertex_ai_trainer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12G\n\x12kfp_trainer_config\x18\x02 \x01(\x0b\x32).snapchat.research.gbml.KFPResourceConfigH\x00\x12K\n\x14local_trainer_config\x18\x03 \x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12`\n$vertex_ai_graph_store_trainer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x12M\n\x15\x63ustom_trainer_config\x18\x05 \x01(\x0b\x32,.snapchat.research.gbml.CustomResourceConfigH\x00\x42\x10\n\x0etrainer_config\"\xe3\x03\n\x18InferencerResourceConfig\x12U\n\x1bvertex_ai_inferencer_config\x18\x01 \x01(\x0b\x32..snapchat.research.gbml.VertexAiResourceConfigH\x00\x12T\n\x1a\x64\x61taflow_inferencer_config\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigH\x00\x12N\n\x17local_inferencer_config\x18\x03 
\x01(\x0b\x32+.snapchat.research.gbml.LocalResourceConfigH\x00\x12\x63\n\'vertex_ai_graph_store_inferencer_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.VertexAiGraphStoreConfigH\x00\x12P\n\x18\x63ustom_inferencer_config\x18\x05 \x01(\x0b\x32,.snapchat.research.gbml.CustomResourceConfigH\x00\x42\x13\n\x11inferencer_config\"\xa3\x04\n\x14SharedResourceConfig\x12Y\n\x0fresource_labels\x18\x01 \x03(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.ResourceLabelsEntry\x12_\n\x15\x63ommon_compute_config\x18\x02 \x01(\x0b\x32@.snapchat.research.gbml.SharedResourceConfig.CommonComputeConfig\x1a\x97\x02\n\x13\x43ommonComputeConfig\x12\x0f\n\x07project\x18\x01 \x01(\t\x12\x0e\n\x06region\x18\x02 \x01(\t\x12\x1a\n\x12temp_assets_bucket\x18\x03 \x01(\t\x12#\n\x1btemp_regional_assets_bucket\x18\x04 \x01(\t\x12\x1a\n\x12perm_assets_bucket\x18\x05 \x01(\t\x12#\n\x1btemp_assets_bq_dataset_name\x18\x06 \x01(\t\x12!\n\x19\x65mbedding_bq_dataset_name\x18\x07 \x01(\t\x12!\n\x19gcp_service_account_email\x18\x08 \x01(\t\x12\x17\n\x0f\x64\x61taflow_runner\x18\x0b \x01(\t\x1a\x35\n\x13ResourceLabelsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\"\xc8\x05\n\x12GiglResourceConfig\x12$\n\x1ashared_resource_config_uri\x18\x01 \x01(\tH\x00\x12N\n\x16shared_resource_config\x18\x02 \x01(\x0b\x32,.snapchat.research.gbml.SharedResourceConfigH\x00\x12K\n\x13preprocessor_config\x18\x0c \x01(\x0b\x32..snapchat.research.gbml.DataPreprocessorConfig\x12L\n\x17subgraph_sampler_config\x18\r \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12K\n\x16split_generator_config\x18\x0e \x01(\x0b\x32+.snapchat.research.gbml.SparkResourceConfig\x12L\n\x0etrainer_config\x18\x0f \x01(\x0b\x32\x30.snapchat.research.gbml.DistributedTrainerConfigB\x02\x18\x01\x12M\n\x11inferencer_config\x18\x10 \x01(\x0b\x32..snapchat.research.gbml.DataflowResourceConfigB\x02\x18\x01\x12N\n\x17trainer_resource_config\x18\x11 
\x01(\x0b\x32-.snapchat.research.gbml.TrainerResourceConfig\x12T\n\x1ainferencer_resource_config\x18\x12 \x01(\x0b\x32\x30.snapchat.research.gbml.InferencerResourceConfigB\x11\n\x0fshared_resource*\xf3\x01\n\tComponent\x12\x15\n\x11\x43omponent_Unknown\x10\x00\x12\x1e\n\x1a\x43omponent_Config_Validator\x10\x01\x12\x1e\n\x1a\x43omponent_Config_Populator\x10\x02\x12\x1f\n\x1b\x43omponent_Data_Preprocessor\x10\x03\x12\x1e\n\x1a\x43omponent_Subgraph_Sampler\x10\x04\x12\x1d\n\x19\x43omponent_Split_Generator\x10\x05\x12\x15\n\x11\x43omponent_Trainer\x10\x06\x12\x18\n\x14\x43omponent_Inferencer\x10\x07\x62\x06proto3') _COMPONENT = DESCRIPTOR.enum_types_by_name['Component'] Component = enum_type_wrapper.EnumTypeWrapper(_COMPONENT) @@ -40,6 +40,7 @@ _KFPRESOURCECONFIG = DESCRIPTOR.message_types_by_name['KFPResourceConfig'] _LOCALRESOURCECONFIG = DESCRIPTOR.message_types_by_name['LocalResourceConfig'] _VERTEXAIGRAPHSTORECONFIG = DESCRIPTOR.message_types_by_name['VertexAiGraphStoreConfig'] +_CUSTOMRESOURCECONFIG = DESCRIPTOR.message_types_by_name['CustomResourceConfig'] _DISTRIBUTEDTRAINERCONFIG = DESCRIPTOR.message_types_by_name['DistributedTrainerConfig'] _TRAINERRESOURCECONFIG = DESCRIPTOR.message_types_by_name['TrainerResourceConfig'] _INFERENCERRESOURCECONFIG = DESCRIPTOR.message_types_by_name['InferencerResourceConfig'] @@ -124,6 +125,13 @@ }) _sym_db.RegisterMessage(VertexAiGraphStoreConfig) +CustomResourceConfig = _reflection.GeneratedProtocolMessageType('CustomResourceConfig', (_message.Message,), { + 'DESCRIPTOR' : _CUSTOMRESOURCECONFIG, + '__module__' : 'snapchat.research.gbml.gigl_resource_config_pb2' + # @@protoc_insertion_point(class_scope:snapchat.research.gbml.CustomResourceConfig) + }) +_sym_db.RegisterMessage(CustomResourceConfig) + DistributedTrainerConfig = _reflection.GeneratedProtocolMessageType('DistributedTrainerConfig', (_message.Message,), { 'DESCRIPTOR' : _DISTRIBUTEDTRAINERCONFIG, '__module__' : 'snapchat.research.gbml.gigl_resource_config_pb2' @@ 
-184,8 +192,8 @@ _GIGLRESOURCECONFIG.fields_by_name['trainer_config']._serialized_options = b'\030\001' _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._options = None _GIGLRESOURCECONFIG.fields_by_name['inferencer_config']._serialized_options = b'\030\001' - _COMPONENT._serialized_start=3848 - _COMPONENT._serialized_end=4091 + _COMPONENT._serialized_start=4064 + _COMPONENT._serialized_end=4307 _SPARKRESOURCECONFIG._serialized_start=77 _SPARKRESOURCECONFIG._serialized_end=166 _DATAFLOWRESOURCECONFIG._serialized_start=169 @@ -208,18 +216,20 @@ _LOCALRESOURCECONFIG._serialized_end=1307 _VERTEXAIGRAPHSTORECONFIG._serialized_start=1310 _VERTEXAIGRAPHSTORECONFIG._serialized_end=1522 - _DISTRIBUTEDTRAINERCONFIG._serialized_start=1525 - _DISTRIBUTEDTRAINERCONFIG._serialized_end=1800 - _TRAINERRESOURCECONFIG._serialized_start=1803 - _TRAINERRESOURCECONFIG._serialized_end=2176 - _INFERENCERRESOURCECONFIG._serialized_start=2179 - _INFERENCERRESOURCECONFIG._serialized_end=2580 - _SHAREDRESOURCECONFIG._serialized_start=2583 - _SHAREDRESOURCECONFIG._serialized_end=3130 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=2796 - _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=3075 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=3077 - _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=3130 - _GIGLRESOURCECONFIG._serialized_start=3133 - _GIGLRESOURCECONFIG._serialized_end=3845 + _CUSTOMRESOURCECONFIG._serialized_start=1524 + _CUSTOMRESOURCECONFIG._serialized_end=1577 + _DISTRIBUTEDTRAINERCONFIG._serialized_start=1580 + _DISTRIBUTEDTRAINERCONFIG._serialized_end=1855 + _TRAINERRESOURCECONFIG._serialized_start=1858 + _TRAINERRESOURCECONFIG._serialized_end=2310 + _INFERENCERRESOURCECONFIG._serialized_start=2313 + _INFERENCERRESOURCECONFIG._serialized_end=2796 + _SHAREDRESOURCECONFIG._serialized_start=2799 + _SHAREDRESOURCECONFIG._serialized_end=3346 + _SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_start=3012 + 
_SHAREDRESOURCECONFIG_COMMONCOMPUTECONFIG._serialized_end=3291 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_start=3293 + _SHAREDRESOURCECONFIG_RESOURCELABELSENTRY._serialized_end=3346 + _GIGLRESOURCECONFIG._serialized_start=3349 + _GIGLRESOURCECONFIG._serialized_end=4061 # @@protoc_insertion_point(module_scope) diff --git a/snapchat/research/gbml/gigl_resource_config_pb2.pyi b/snapchat/research/gbml/gigl_resource_config_pb2.pyi index 6198d1076..dbf842ea8 100644 --- a/snapchat/research/gbml/gigl_resource_config_pb2.pyi +++ b/snapchat/research/gbml/gigl_resource_config_pb2.pyi @@ -396,6 +396,41 @@ class VertexAiGraphStoreConfig(google.protobuf.message.Message): global___VertexAiGraphStoreConfig = VertexAiGraphStoreConfig +class CustomResourceConfig(google.protobuf.message.Message): + """Lets user-defined launchers be piped in. + The launcher dispatcher invokes `command` (interpreted by /bin/sh -c so + leading "KEY=VALUE" assignments parse as inline env vars) with `args` + appended as positional arguments. String fields support OmegaConf + `${gigl:}` substitutions, which the dispatcher resolves at exec + time from the runtime context (task_config_uri, applied_task_identifier, + component, etc.). + """ + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + COMMAND_FIELD_NUMBER: builtins.int + ARGS_FIELD_NUMBER: builtins.int + command: builtins.str + """Shell snippet invoked via /bin/sh -c. Leading "KEY=VALUE" assignments + are honored by the shell, so callers can inline env vars (e.g. + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python python -m my.cli"). + """ + @property + def args(self) -> google.protobuf.internal.containers.RepeatedScalarFieldContainer[builtins.str]: + """Positional arguments appended after the command. Each element is + shell-quoted by the dispatcher so values containing spaces/quotes + survive the shell pass. 
+ """ + def __init__( + self, + *, + command: builtins.str = ..., + args: collections.abc.Iterable[builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["args", b"args", "command", b"command"]) -> None: ... + +global___CustomResourceConfig = CustomResourceConfig + class DistributedTrainerConfig(google.protobuf.message.Message): """(deprecated) Configuration for distributed training resources @@ -434,6 +469,7 @@ class TrainerResourceConfig(google.protobuf.message.Message): KFP_TRAINER_CONFIG_FIELD_NUMBER: builtins.int LOCAL_TRAINER_CONFIG_FIELD_NUMBER: builtins.int VERTEX_AI_GRAPH_STORE_TRAINER_CONFIG_FIELD_NUMBER: builtins.int + CUSTOM_TRAINER_CONFIG_FIELD_NUMBER: builtins.int @property def vertex_ai_trainer_config(self) -> global___VertexAiResourceConfig: ... @property @@ -442,6 +478,8 @@ class TrainerResourceConfig(google.protobuf.message.Message): def local_trainer_config(self) -> global___LocalResourceConfig: ... @property def vertex_ai_graph_store_trainer_config(self) -> global___VertexAiGraphStoreConfig: ... + @property + def custom_trainer_config(self) -> global___CustomResourceConfig: ... def __init__( self, *, @@ -449,10 +487,11 @@ class TrainerResourceConfig(google.protobuf.message.Message): kfp_trainer_config: global___KFPResourceConfig | None = ..., local_trainer_config: global___LocalResourceConfig | None = ..., vertex_ai_graph_store_trainer_config: global___VertexAiGraphStoreConfig | None = ..., + custom_trainer_config: global___CustomResourceConfig | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_graph_store_trainer_config", b"vertex_ai_graph_store_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> builtins.bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal["kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_graph_store_trainer_config", b"vertex_ai_graph_store_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["trainer_config", b"trainer_config"]) -> typing_extensions.Literal["vertex_ai_trainer_config", "kfp_trainer_config", "local_trainer_config", "vertex_ai_graph_store_trainer_config"] | None: ... + def HasField(self, field_name: typing_extensions.Literal["custom_trainer_config", b"custom_trainer_config", "kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_graph_store_trainer_config", b"vertex_ai_graph_store_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["custom_trainer_config", b"custom_trainer_config", "kfp_trainer_config", b"kfp_trainer_config", "local_trainer_config", b"local_trainer_config", "trainer_config", b"trainer_config", "vertex_ai_graph_store_trainer_config", b"vertex_ai_graph_store_trainer_config", "vertex_ai_trainer_config", b"vertex_ai_trainer_config"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["trainer_config", b"trainer_config"]) -> typing_extensions.Literal["vertex_ai_trainer_config", "kfp_trainer_config", "local_trainer_config", "vertex_ai_graph_store_trainer_config", "custom_trainer_config"] | None: ... 
global___TrainerResourceConfig = TrainerResourceConfig @@ -465,6 +504,7 @@ class InferencerResourceConfig(google.protobuf.message.Message): DATAFLOW_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int LOCAL_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int VERTEX_AI_GRAPH_STORE_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int + CUSTOM_INFERENCER_CONFIG_FIELD_NUMBER: builtins.int @property def vertex_ai_inferencer_config(self) -> global___VertexAiResourceConfig: ... @property @@ -473,6 +513,8 @@ class InferencerResourceConfig(google.protobuf.message.Message): def local_inferencer_config(self) -> global___LocalResourceConfig: ... @property def vertex_ai_graph_store_inferencer_config(self) -> global___VertexAiGraphStoreConfig: ... + @property + def custom_inferencer_config(self) -> global___CustomResourceConfig: ... def __init__( self, *, @@ -480,10 +522,11 @@ class InferencerResourceConfig(google.protobuf.message.Message): dataflow_inferencer_config: global___DataflowResourceConfig | None = ..., local_inferencer_config: global___LocalResourceConfig | None = ..., vertex_ai_graph_store_inferencer_config: global___VertexAiGraphStoreConfig | None = ..., + custom_inferencer_config: global___CustomResourceConfig | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_graph_store_inferencer_config", b"vertex_ai_graph_store_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> builtins.bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal["dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_graph_store_inferencer_config", b"vertex_ai_graph_store_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> None: ... - def WhichOneof(self, oneof_group: typing_extensions.Literal["inferencer_config", b"inferencer_config"]) -> typing_extensions.Literal["vertex_ai_inferencer_config", "dataflow_inferencer_config", "local_inferencer_config", "vertex_ai_graph_store_inferencer_config"] | None: ... + def HasField(self, field_name: typing_extensions.Literal["custom_inferencer_config", b"custom_inferencer_config", "dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_graph_store_inferencer_config", b"vertex_ai_graph_store_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["custom_inferencer_config", b"custom_inferencer_config", "dataflow_inferencer_config", b"dataflow_inferencer_config", "inferencer_config", b"inferencer_config", "local_inferencer_config", b"local_inferencer_config", "vertex_ai_graph_store_inferencer_config", b"vertex_ai_graph_store_inferencer_config", "vertex_ai_inferencer_config", b"vertex_ai_inferencer_config"]) -> None: ... + def WhichOneof(self, oneof_group: typing_extensions.Literal["inferencer_config", b"inferencer_config"]) -> typing_extensions.Literal["vertex_ai_inferencer_config", "dataflow_inferencer_config", "local_inferencer_config", "vertex_ai_graph_store_inferencer_config", "custom_inferencer_config"] | None: ... 
global___InferencerResourceConfig = InferencerResourceConfig diff --git a/tests/unit/common/test_omegaconf_resolvers.py b/tests/unit/common/test_omegaconf_resolvers.py index 3fa9da934..9ba32f712 100644 --- a/tests/unit/common/test_omegaconf_resolvers.py +++ b/tests/unit/common/test_omegaconf_resolvers.py @@ -6,7 +6,10 @@ from absl.testing import absltest from omegaconf import OmegaConf -from gigl.common.omegaconf_resolvers import register_resolvers +from gigl.common.omegaconf_resolvers import ( + register_resolvers, + set_gigl_resolver_values, +) from tests.test_assets.test_case import TestCase @@ -126,5 +129,58 @@ def test_git_hash_resolver_not_git_repo(self, mock_subprocess_run): self.assertEqual(OmegaConf.create(yaml_config).experiment.commit, "") +class TestGiglResolver(TestCase): + """Tests for the ``gigl`` OmegaConf resolver. + + The resolver pulls runtime values (set by ``launch_custom`` before + re-resolving the proto's ``command`` / ``args`` strings) from a + module-level dict, falling back to the literal placeholder + ``"${gigl:}"`` so the first-pass YAML load is lossless. + """ + + def setUp(self): + register_resolvers() + # Reset the resolver dict between tests so values from prior cases + # do not leak β€” set_gigl_resolver_values clears before populating. + set_gigl_resolver_values({}) + + def test_gigl_resolver_returns_value_when_set(self): + set_gigl_resolver_values({"foo": "bar"}) + cfg = OmegaConf.create({"v": "${gigl:foo}"}) + self.assertEqual(cfg.v, "bar") + + def test_gigl_resolver_returns_placeholder_when_unset(self): + cfg = OmegaConf.create({"v": "${gigl:foo}"}) + self.assertEqual(cfg.v, "${gigl:foo}") + + def test_gigl_resolver_round_trips_unset_keys_through_yaml(self): + # Models the first-pass YAML load that ProtoUtils does β€” the + # placeholder must survive a YAMLβ†’OmegaConf round trip so + # launch_custom can re-resolve it later with values set. 
+ yaml_config = """ + custom: + command: "${gigl:task_config_uri}" + args: + - "--component=${gigl:component}" + - "--applied_task_identifier=${gigl:applied_task_identifier}" + """ + cfg = OmegaConf.create(yaml.safe_load(yaml_config)) + self.assertEqual(cfg.custom.command, "${gigl:task_config_uri}") + self.assertEqual(cfg.custom.args[0], "--component=${gigl:component}") + self.assertEqual( + cfg.custom.args[1], + "--applied_task_identifier=${gigl:applied_task_identifier}", + ) + + def test_set_gigl_resolver_values_overwrites_prior_values(self): + set_gigl_resolver_values({"a": "1", "b": "2"}) + set_gigl_resolver_values({"a": "10"}) + cfg = OmegaConf.create({"a": "${gigl:a}", "b": "${gigl:b}"}) + self.assertEqual(cfg.a, "10") + # ``b`` was not in the second call, so it must fall back to the + # placeholder rather than retaining the prior call's "2" value. + self.assertEqual(cfg.b, "${gigl:b}") + + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/env/runtime_test.py b/tests/unit/env/runtime_test.py new file mode 100644 index 000000000..52b7a6be0 --- /dev/null +++ b/tests/unit/env/runtime_test.py @@ -0,0 +1,64 @@ +"""Tests for ``gigl.env.runtime`` execution-environment detection.""" + +import os +from unittest.mock import patch + +from absl.testing import absltest + +from gigl.env.runtime import RuntimeEnv, get_runtime_env, is_ray_runtime +from tests.test_assets.test_case import TestCase + + +class TestIsRayRuntime(TestCase): + """Exercises each branch of the ``is_ray_runtime`` priority chain.""" + + def test_gigl_ray_runtime_authoritative_signal_wins(self) -> None: + with patch.dict(os.environ, {"GIGL_RAY_RUNTIME": "1"}, clear=True): + self.assertTrue(is_ray_runtime()) + + def test_ray_dashboard_address_triggers_ray(self) -> None: + with patch.dict( + os.environ, + {"RAY_DASHBOARD_ADDRESS": "http://10.0.0.1:8265"}, + clear=True, + ): + self.assertTrue(is_ray_runtime()) + + def test_ray_address_triggers_ray(self) -> None: + with patch.dict( + 
os.environ, {"RAY_ADDRESS": "ray://10.0.0.1:10001"}, clear=True + ): + self.assertTrue(is_ray_runtime()) + + def test_gigl_ray_runtime_must_be_one(self) -> None: + # Only the literal string "1" is treated as the authoritative signal. + with patch.dict(os.environ, {"GIGL_RAY_RUNTIME": "0"}, clear=True): + # Cannot `ray.init` in a unit test; rely on the ImportError fallback + # path returning False when ray itself is unavailable. If ray *is* + # installed, fall through to is_initialized() which must be False. + self.assertFalse(is_ray_runtime()) + + +class TestGetRuntimeEnv(TestCase): + """Exercises each branch of ``get_runtime_env``.""" + + def test_gigl_ray_runtime_returns_ray(self) -> None: + with patch.dict(os.environ, {"GIGL_RAY_RUNTIME": "1"}, clear=True): + self.assertEqual(get_runtime_env(), RuntimeEnv.RAY) + + def test_cloud_ml_job_id_returns_vertex_ai(self) -> None: + with patch.dict(os.environ, {"CLOUD_ML_JOB_ID": "12345"}, clear=True): + self.assertEqual(get_runtime_env(), RuntimeEnv.VERTEX_AI) + + def test_aip_model_dir_returns_vertex_ai(self) -> None: + with patch.dict(os.environ, {"AIP_MODEL_DIR": "gs://bucket/model"}, clear=True): + self.assertEqual(get_runtime_env(), RuntimeEnv.VERTEX_AI) + + def test_no_env_returns_unknown(self) -> None: + with patch.dict(os.environ, {}, clear=True): + self.assertEqual(get_runtime_env(), RuntimeEnv.UNKNOWN) + self.assertFalse(is_ray_runtime()) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py b/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py index 34cef81f9..ea4a2dd96 100644 --- a/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py +++ b/tests/unit/orchestration/kubeflow/kfp_orchestrator_test.py @@ -1,8 +1,11 @@ +import tempfile +from pathlib import Path from unittest.mock import ANY, patch +import yaml from absl.testing import absltest -from gigl.common import GcsUri +from gigl.common import GcsUri, LocalUri from 
gigl.common.logger import Logger from gigl.orchestration.kubeflow.kfp_orchestrator import KfpOrchestrator from tests.test_assets.test_case import TestCase @@ -29,6 +32,86 @@ def test_compile_uploads_compiled_yaml(self, MockFileLoader): file_uri_src=ANY, file_uri_dst=dst_compiled_pipeline_path ) + def test_compile_bakes_env_vars_into_every_gigl_owned_executor(self): + """env_vars passed to compile() should appear on every GiGL-owned executor's container env. + + The managed VertexNotificationEmailOp exit handler is the documented + carve-out and must not receive the env vars. + """ + env_vars = { + "FOO": "bar", + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": "python", + } + with tempfile.TemporaryDirectory() as tmp_dir: + dst = LocalUri(str(Path(tmp_dir) / "pipeline.yaml")) + KfpOrchestrator.compile( + cuda_container_image="SOME NONEXISTENT IMAGE 1", + cpu_container_image="SOME NONEXISTENT IMAGE 2", + dataflow_container_image="SOME NONEXISTENT IMAGE 3", + dst_compiled_pipeline_path=dst, + env_vars=env_vars, + ) + + with open(dst.uri, "r") as f: + compiled = yaml.safe_load(f) + + executors = compiled["deploymentSpec"]["executors"] + self.assertGreater(len(executors), 0, "Expected at least one executor in IR.") + + gigl_owned_with_env: list[str] = [] + notification_executors_without_env: list[str] = [] + for executor_id, executor_spec in executors.items(): + container = executor_spec.get("container", {}) + env_list = container.get("env", []) + env_dict = {entry["name"]: entry["value"] for entry in env_list} + is_notification = "notification-email" in executor_id.lower() + if is_notification: + # The managed notification op must not receive our env vars. 
+ for name in env_vars: + self.assertNotIn( + name, + env_dict, + f"Env var {name} unexpectedly applied to managed " + f"notification executor {executor_id}.", + ) + notification_executors_without_env.append(executor_id) + else: + for name, value in env_vars.items(): + self.assertEqual( + env_dict.get(name), + value, + f"Executor {executor_id} missing env var {name}={value}; " + f"actual env: {env_dict}.", + ) + gigl_owned_with_env.append(executor_id) + + self.assertGreater( + len(gigl_owned_with_env), + 0, + "Expected at least one GiGL-owned executor to receive env vars.", + ) + + def test_compile_without_env_vars_does_not_inject_env(self): + """When env_vars is omitted, no GiGL-owned executor should pick up phantom env entries from this code path.""" + with tempfile.TemporaryDirectory() as tmp_dir: + dst = LocalUri(str(Path(tmp_dir) / "pipeline.yaml")) + KfpOrchestrator.compile( + cuda_container_image="SOME NONEXISTENT IMAGE 1", + cpu_container_image="SOME NONEXISTENT IMAGE 2", + dataflow_container_image="SOME NONEXISTENT IMAGE 3", + dst_compiled_pipeline_path=dst, + ) + + with open(dst.uri, "r") as f: + compiled = yaml.safe_load(f) + + # Only assert the default (unset) case adds no FOO key β€” we don't make + # claims about other env entries that KFP itself may inject. 
+ for executor_spec in compiled["deploymentSpec"]["executors"].values(): + env_list = executor_spec.get("container", {}).get("env", []) + env_names = {entry["name"] for entry in env_list} + self.assertNotIn("FOO", env_names) + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/orchestration/kubeflow/kfp_runner_test.py b/tests/unit/orchestration/kubeflow/kfp_runner_test.py index eade9b3ed..368ed61cb 100644 --- a/tests/unit/orchestration/kubeflow/kfp_runner_test.py +++ b/tests/unit/orchestration/kubeflow/kfp_runner_test.py @@ -5,6 +5,7 @@ _assert_required_flags, _get_parser, _parse_additional_job_args, + _parse_env_vars, _parse_labels, ) from gigl.src.common.constants.components import GiGLComponents @@ -110,6 +111,76 @@ def test_assert_required_flags_success(self): # Should not raise any exception _assert_required_flags(args) + def test_parse_env_vars_single(self): + parsed = _parse_env_vars(["FOO=bar"]) + self.assertEqual(parsed, {"FOO": "bar"}) + + def test_parse_env_vars_multiple(self): + parsed = _parse_env_vars(["FOO=bar", "BAZ=qux"]) + self.assertEqual(parsed, {"FOO": "bar", "BAZ": "qux"}) + + def test_parse_env_vars_value_contains_equals(self): + # split("=", 1) means only the first '=' delimits key/value; the rest is value. + parsed = _parse_env_vars(["URL=https://example.com/?q=1&r=2"]) + self.assertEqual(parsed, {"URL": "https://example.com/?q=1&r=2"}) + + def test_parse_env_vars_empty_value(self): + parsed = _parse_env_vars(["FOO="]) + self.assertEqual(parsed, {"FOO": ""}) + + def test_parse_env_vars_empty_list(self): + self.assertEqual(_parse_env_vars([]), {}) + + def test_parse_env_vars_malformed_raises(self): + # No '=' in the entry β€” split("=", 1) returns a single-element list and the + # tuple unpack raises ValueError, mirroring _parse_labels semantics. 
+ with self.assertRaises(ValueError): + _parse_env_vars(["NOT_A_VALID_ENTRY"]) + + def test_parse_env_vars_duplicate_keys_last_wins(self): + parsed = _parse_env_vars(["FOO=first", "FOO=second"]) + self.assertEqual(parsed, {"FOO": "second"}) + + def test_assert_required_flags_rejects_env_vars_with_run_no_compile(self): + """--env_vars must not be combined with --action=run_no_compile.""" + parser = _get_parser() + args = parser.parse_args( + [ + "--action=run_no_compile", + "--task_config_uri=gs://bucket/task_config.yaml", + "--resource_config_uri=gs://bucket/resource_config.yaml", + "--compiled_pipeline_path=gs://bucket/pipeline.yaml", + "--env_vars=FOO=bar", + ] + ) + with self.assertRaises(ValueError): + _assert_required_flags(args) + + def test_assert_required_flags_allows_env_vars_with_run(self): + """--env_vars is valid for --action=run.""" + parser = _get_parser() + args = parser.parse_args( + [ + "--action=run", + "--task_config_uri=gs://bucket/task_config.yaml", + "--resource_config_uri=gs://bucket/resource_config.yaml", + "--env_vars=FOO=bar", + ] + ) + _assert_required_flags(args) + + def test_assert_required_flags_allows_env_vars_with_compile(self): + """--env_vars is valid for --action=compile.""" + parser = _get_parser() + args = parser.parse_args( + [ + "--action=compile", + "--compiled_pipeline_path=gs://bucket/pipeline.yaml", + "--env_vars=FOO=bar", + ] + ) + _assert_required_flags(args) + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/src/common/custom_launcher_test.py b/tests/unit/src/common/custom_launcher_test.py new file mode 100644 index 000000000..3c18cbac1 --- /dev/null +++ b/tests/unit/src/common/custom_launcher_test.py @@ -0,0 +1,211 @@ +"""Unit tests for ``gigl.src.common.custom_launcher``.""" + +from unittest.mock import MagicMock, patch + +from absl.testing import absltest + +from gigl.common import Uri +from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.custom_launcher import 
launch_custom +from snapchat.research.gbml import gigl_resource_config_pb2 +from tests.test_assets.test_case import TestCase + + +class TestLaunchCustom(TestCase): + """Exercises ``launch_custom`` subprocess dispatch and guards. + + The launcher resolves ``${gigl:*}`` placeholders in the proto's + ``command`` / ``args`` fields against the runtime kwargs and shells + out via ``subprocess.run``. Tests patch ``subprocess.run`` to + capture the resolved shell line without actually spawning processes. + """ + + def _build_config( + self, + command: str, + args: list[str] | None = None, + ) -> gigl_resource_config_pb2.CustomResourceConfig: + return gigl_resource_config_pb2.CustomResourceConfig( + command=command, + args=args or [], + ) + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_dispatches_subprocess_with_resolved_command_and_args( + self, mock_run: MagicMock + ) -> None: + config = self._build_config( + command="python -m my.cli", + args=[ + "--task_config_uri=${gigl:task_config_uri}", + "--component=${gigl:component}", + "--cuda=${gigl:cuda_docker_image}", + "--applied_task_identifier=${gigl:applied_task_identifier}", + ], + ) + launch_custom( + custom_resource_config=config, + applied_task_identifier="job-42", + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="ignored", + process_runtime_args={"ignored": "v"}, + cpu_docker_uri="gcr.io/p/cpu:tag", + cuda_docker_uri="gcr.io/p/cuda:tag", + component=GiGLComponents.Trainer, + is_dry_run=False, + ) + + mock_run.assert_called_once() + shell_line = mock_run.call_args.args[0] + self.assertIn("python -m my.cli", shell_line) + self.assertIn("--task_config_uri=gs://bucket/task.yaml", shell_line) + # component should resolve to the Title-case name (matching CLI + # argparse choices), NOT the lowercase enum value. 
+ self.assertIn("--component=Trainer", shell_line) + self.assertIn("--cuda=gcr.io/p/cuda:tag", shell_line) + self.assertIn("--applied_task_identifier=job-42", shell_line) + # No leftover ${gigl:*} placeholders in the shell line. + self.assertNotIn("${gigl:", shell_line) + # subprocess invoked with shell=True and check=True. + self.assertTrue(mock_run.call_args.kwargs.get("shell", False)) + self.assertTrue(mock_run.call_args.kwargs.get("check", False)) + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_is_dry_run_skips_subprocess(self, mock_run: MagicMock) -> None: + config = self._build_config(command="echo", args=["hi"]) + launch_custom( + custom_resource_config=config, + applied_task_identifier="job", + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="", + process_runtime_args={}, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Trainer, + is_dry_run=True, + ) + mock_run.assert_not_called() + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_is_dry_run_defaults_to_false(self, mock_run: MagicMock) -> None: + config = self._build_config(command="echo", args=[]) + launch_custom( + custom_resource_config=config, + applied_task_identifier="job-43", + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="", + process_runtime_args={}, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Inferencer, + ) + mock_run.assert_called_once() + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_empty_command_raises_value_error(self, mock_run: MagicMock) -> None: + config = self._build_config(command="", args=["ignored"]) + with self.assertRaises(ValueError): + launch_custom( + custom_resource_config=config, + applied_task_identifier="job", + task_config_uri=Uri("gs://bucket/task.yaml"), + 
resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="", + process_runtime_args={}, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Trainer, + ) + mock_run.assert_not_called() + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_invalid_component_raises_value_error(self, mock_run: MagicMock) -> None: + config = self._build_config(command="echo") + with self.assertRaises(ValueError): + launch_custom( + custom_resource_config=config, + applied_task_identifier="job", + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="", + process_runtime_args={}, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.DataPreprocessor, + ) + mock_run.assert_not_called() + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_args_with_spaces_are_shell_quoted(self, mock_run: MagicMock) -> None: + config = self._build_config( + command="echo", args=["a b c", "--name=with space"] + ) + launch_custom( + custom_resource_config=config, + applied_task_identifier="job", + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="", + process_runtime_args={}, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Trainer, + ) + shell_line = mock_run.call_args.args[0] + # shlex.quote wraps tokens with spaces in single quotes so the + # shell sees one argv element per proto args[] entry. + self.assertIn("'a b c'", shell_line) + self.assertIn("'--name=with space'", shell_line) + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_process_command_and_runtime_args_are_not_plumbed( + self, mock_run: MagicMock + ) -> None: + # Confirm the resolver dict does not carry process_command or + # process_runtime_args β€” consumers re-derive them from + # ${gigl:task_config_uri} on the receiving side. 
+ config = self._build_config(command="python", args=["-m", "foo"]) + launch_custom( + custom_resource_config=config, + applied_task_identifier="job", + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="should-not-appear", + process_runtime_args={"unused_lr": "0.42"}, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Trainer, + ) + shell_line = mock_run.call_args.args[0] + self.assertNotIn("should-not-appear", shell_line) + self.assertNotIn("unused_lr", shell_line) + self.assertNotIn("0.42", shell_line) + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_logs_resolved_shell_line(self, mock_run: MagicMock) -> None: + config = self._build_config(command="echo", args=["${gigl:component}"]) + mock_logger = MagicMock() + with patch("gigl.src.common.custom_launcher.logger", new=mock_logger): + launch_custom( + custom_resource_config=config, + applied_task_identifier="job", + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + process_command="", + process_runtime_args={}, + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Inferencer, + is_dry_run=False, + ) + mock_logger.info.assert_called_once() + (log_line,), _ = mock_logger.info.call_args + self.assertIn("Inferencer", log_line) + self.assertIn("dry_run=False", log_line) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/src/common/utils/types/pb_wrappers/gigl_resource_config_test.py b/tests/unit/src/common/utils/types/pb_wrappers/gigl_resource_config_test.py index c9394857d..bba658738 100644 --- a/tests/unit/src/common/utils/types/pb_wrappers/gigl_resource_config_test.py +++ b/tests/unit/src/common/utils/types/pb_wrappers/gigl_resource_config_test.py @@ -241,6 +241,23 @@ def test_trainer_config_vertex_ai_graph_store(self): wrapper = GiglResourceConfigWrapper(resource_config=config) 
self.assertEqual(wrapper.trainer_config, trainer_config) + def test_trainer_config_custom(self): + """Test trainer_config with Custom (user-supplied launcher) configuration.""" + config = self._create_gigl_resource_config_with_direct_shared_config() + trainer_config = gigl_resource_config_pb2.CustomResourceConfig( + command="python -m my_project.launchers.ray.launch", + args=["--cluster=dev", "--num_workers=4"], + ) + config.trainer_resource_config.custom_trainer_config.CopyFrom( + copy.deepcopy(trainer_config) + ) + + wrapper = GiglResourceConfigWrapper(resource_config=config) + self.assertIsInstance( + wrapper.trainer_config, gigl_resource_config_pb2.CustomResourceConfig + ) + self.assertEqual(wrapper.trainer_config, trainer_config) + def test_trainer_config_missing(self): """Test that ValueError is raised when trainer config is missing.""" config = self._create_gigl_resource_config_with_direct_shared_config() @@ -355,6 +372,23 @@ def test_inferencer_config_vertex_ai_graph_store(self): wrapper = GiglResourceConfigWrapper(resource_config=config) self.assertEqual(wrapper.inferencer_config, inferencer_config) + def test_inferencer_config_custom(self): + """Test inferencer_config with Custom (user-supplied launcher) configuration.""" + config = self._create_gigl_resource_config_with_direct_shared_config() + inferencer_config = gigl_resource_config_pb2.CustomResourceConfig( + command="python -m my_project.launchers.ray.launch", + args=["--cluster=prod", "--shards=8"], + ) + config.inferencer_resource_config.custom_inferencer_config.CopyFrom( + copy.deepcopy(inferencer_config) + ) + + wrapper = GiglResourceConfigWrapper(resource_config=config) + self.assertIsInstance( + wrapper.inferencer_config, gigl_resource_config_pb2.CustomResourceConfig + ) + self.assertEqual(wrapper.inferencer_config, inferencer_config) + def test_inferencer_config_missing(self): """Test that ValueError is raised when inferencer config is missing.""" config = 
self._create_gigl_resource_config_with_direct_shared_config() diff --git a/tests/unit/src/inference/v2/__init__.py b/tests/unit/src/inference/v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/src/inference/v2/glt_inferencer_test.py b/tests/unit/src/inference/v2/glt_inferencer_test.py new file mode 100644 index 000000000..d4c8c2cdc --- /dev/null +++ b/tests/unit/src/inference/v2/glt_inferencer_test.py @@ -0,0 +1,131 @@ +"""Tests for ``gigl.src.inference.v2.glt_inferencer`` dispatch wiring. + +Covers the CustomResourceConfig branch added alongside the existing +VertexAiResourceConfig / VertexAiGraphStoreConfig branches. VAI dispatch is +exercised by existing integration flows; these tests only assert that +CustomResourceConfig reaches ``launch_custom`` with the expected kwargs. +""" + +from typing import Final +from unittest.mock import patch + +from absl.testing import absltest + +from gigl.common import Uri +from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( + GiglResourceConfigWrapper, +) +from gigl.src.inference.v2.glt_inferencer import GLTInferencer +from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 +from tests.test_assets.test_case import TestCase + +_PROCESS_COMMAND: Final[str] = "python -m gigl.src.inference.v2.glt_inferencer" + + +def _build_resource_config_with_custom_inferencer() -> ( + gigl_resource_config_pb2.GiglResourceConfig +): + shared = gigl_resource_config_pb2.SharedResourceConfig( + resource_labels={ + "env": "test", + "cost_resource_group_tag": "unittest_COMPONENT", + "cost_resource_group": "gigl_test", + }, + common_compute_config=( + gigl_resource_config_pb2.SharedResourceConfig.CommonComputeConfig( + project="test-project", + region="us-central1", + 
temp_assets_bucket="gs://test-temp-bucket", + temp_regional_assets_bucket="gs://test-temp-regional-bucket", + perm_assets_bucket="gs://test-perm-bucket", + temp_assets_bq_dataset_name="test_temp_dataset", + embedding_bq_dataset_name="test_embeddings_dataset", + gcp_service_account_email=( + "test-sa@test-project.iam.gserviceaccount.com" + ), + dataflow_runner="DataflowRunner", + ) + ), + ) + resource_config = gigl_resource_config_pb2.GiglResourceConfig( + shared_resource_config=shared, + ) + resource_config.inferencer_resource_config.custom_inferencer_config.CopyFrom( + gigl_resource_config_pb2.CustomResourceConfig( + command="python -m my_project.launchers.ray.launch", + args=["--cluster=dev", "--num_workers=8"], + ) + ) + return resource_config + + +def _build_gbml_config_with_inferencer_command() -> gbml_config_pb2.GbmlConfig: + return gbml_config_pb2.GbmlConfig( + inferencer_config=gbml_config_pb2.GbmlConfig.InferencerConfig( + command=_PROCESS_COMMAND, + inferencer_args={"batch_size": "64"}, + ), + ) + + +class TestGLTInferencerCustomDispatch(TestCase): + """Asserts CustomResourceConfig routes to ``launch_custom``.""" + + @patch("gigl.src.inference.v2.glt_inferencer.launch_custom") + @patch( + "gigl.src.inference.v2.glt_inferencer.GbmlConfigPbWrapper" + ".get_gbml_config_pb_wrapper_from_uri" + ) + @patch("gigl.src.inference.v2.glt_inferencer.get_resource_config") + def test_custom_resource_config_dispatches_to_launch_custom( + self, + mock_get_resource_config, + mock_get_gbml, + mock_launch_custom, + ): + resource_config = _build_resource_config_with_custom_inferencer() + mock_get_resource_config.return_value = GiglResourceConfigWrapper( + resource_config=resource_config + ) + mock_get_gbml.return_value = GbmlConfigPbWrapper( + gbml_config_pb=_build_gbml_config_with_inferencer_command() + ) + + task_uri = Uri("gs://bucket/task.yaml") + resource_uri = Uri("gs://bucket/resource.yaml") + GLTInferencer().run( + 
applied_task_identifier=AppliedTaskIdentifier("job-99"), + task_config_uri=task_uri, + resource_config_uri=resource_uri, + cpu_docker_uri="gcr.io/p/cpu:tag", + cuda_docker_uri="gcr.io/p/cuda:tag", + ) + + mock_launch_custom.assert_called_once() + call_kwargs = mock_launch_custom.call_args.kwargs + self.assertEqual(call_kwargs["component"], GiGLComponents.Inferencer) + self.assertEqual(call_kwargs["applied_task_identifier"], "job-99") + self.assertEqual(call_kwargs["task_config_uri"], task_uri) + self.assertEqual(call_kwargs["resource_config_uri"], resource_uri) + self.assertEqual(call_kwargs["process_command"], _PROCESS_COMMAND) + self.assertEqual( + dict(call_kwargs["process_runtime_args"]), + {"batch_size": "64"}, + ) + self.assertEqual(call_kwargs["cpu_docker_uri"], "gcr.io/p/cpu:tag") + self.assertEqual(call_kwargs["cuda_docker_uri"], "gcr.io/p/cuda:tag") + self.assertFalse(call_kwargs["is_dry_run"]) + + forwarded = call_kwargs["custom_resource_config"] + self.assertEqual(forwarded.command, "python -m my_project.launchers.ray.launch") + self.assertEqual( + list(forwarded.args), + ["--cluster=dev", "--num_workers=8"], + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/src/training/v2/__init__.py b/tests/unit/src/training/v2/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/src/training/v2/glt_trainer_test.py b/tests/unit/src/training/v2/glt_trainer_test.py new file mode 100644 index 000000000..f4b6dc2e3 --- /dev/null +++ b/tests/unit/src/training/v2/glt_trainer_test.py @@ -0,0 +1,134 @@ +"""Tests for ``gigl.src.training.v2.glt_trainer`` dispatch wiring. + +Covers the CustomResourceConfig branch added alongside the existing +VertexAiResourceConfig / VertexAiGraphStoreConfig branches. VAI dispatch +is exercised by existing integration flows; these tests only assert that +CustomResourceConfig reaches ``launch_custom`` with the expected kwargs. 
+""" + +from typing import Final +from unittest.mock import patch + +from absl.testing import absltest + +from gigl.common import Uri +from gigl.src.common.constants.components import GiGLComponents +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( + GiglResourceConfigWrapper, +) +from gigl.src.training.v2.glt_trainer import GLTTrainer +from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 +from tests.test_assets.test_case import TestCase + +_PROCESS_COMMAND: Final[str] = "python -m gigl.src.training.v2.glt_trainer" + + +def _build_resource_config_with_custom_trainer() -> ( + gigl_resource_config_pb2.GiglResourceConfig +): + shared = gigl_resource_config_pb2.SharedResourceConfig( + resource_labels={ + "env": "test", + "cost_resource_group_tag": "unittest_COMPONENT", + "cost_resource_group": "gigl_test", + }, + common_compute_config=( + gigl_resource_config_pb2.SharedResourceConfig.CommonComputeConfig( + project="test-project", + region="us-central1", + temp_assets_bucket="gs://test-temp-bucket", + temp_regional_assets_bucket="gs://test-temp-regional-bucket", + perm_assets_bucket="gs://test-perm-bucket", + temp_assets_bq_dataset_name="test_temp_dataset", + embedding_bq_dataset_name="test_embeddings_dataset", + gcp_service_account_email=( + "test-sa@test-project.iam.gserviceaccount.com" + ), + dataflow_runner="DataflowRunner", + ) + ), + ) + resource_config = gigl_resource_config_pb2.GiglResourceConfig( + shared_resource_config=shared, + ) + resource_config.trainer_resource_config.custom_trainer_config.CopyFrom( + gigl_resource_config_pb2.CustomResourceConfig( + command="python -m my_project.launchers.ray.launch", + args=["--cluster=dev", "--num_workers=4"], + ) + ) + return resource_config + + +def _build_gbml_config_with_trainer_command() -> gbml_config_pb2.GbmlConfig: + return 
gbml_config_pb2.GbmlConfig( + trainer_config=gbml_config_pb2.GbmlConfig.TrainerConfig( + command=_PROCESS_COMMAND, + trainer_args={"lr": "0.01", "epochs": "5"}, + ), + ) + + +class TestGLTTrainerCustomDispatch(TestCase): + """Asserts CustomResourceConfig routes to ``launch_custom``.""" + + @patch("gigl.src.training.v2.glt_trainer.launch_custom") + @patch( + "gigl.src.training.v2.glt_trainer.GbmlConfigPbWrapper" + ".get_gbml_config_pb_wrapper_from_uri" + ) + @patch("gigl.src.training.v2.glt_trainer.get_resource_config") + def test_custom_resource_config_dispatches_to_launch_custom( + self, + mock_get_resource_config, + mock_get_gbml, + mock_launch_custom, + ): + resource_config = _build_resource_config_with_custom_trainer() + mock_get_resource_config.return_value = GiglResourceConfigWrapper( + resource_config=resource_config + ) + mock_get_gbml.return_value = GbmlConfigPbWrapper( + gbml_config_pb=_build_gbml_config_with_trainer_command() + ) + + task_uri = Uri("gs://bucket/task.yaml") + resource_uri = Uri("gs://bucket/resource.yaml") + GLTTrainer().run( + applied_task_identifier=AppliedTaskIdentifier("job-77"), + task_config_uri=task_uri, + resource_config_uri=resource_uri, + cpu_docker_uri="gcr.io/p/cpu:tag", + cuda_docker_uri="gcr.io/p/cuda:tag", + ) + + mock_launch_custom.assert_called_once() + call_kwargs = mock_launch_custom.call_args.kwargs + self.assertEqual(call_kwargs["component"], GiGLComponents.Trainer) + self.assertEqual(call_kwargs["applied_task_identifier"], "job-77") + self.assertEqual(call_kwargs["task_config_uri"], task_uri) + self.assertEqual(call_kwargs["resource_config_uri"], resource_uri) + self.assertEqual(call_kwargs["process_command"], _PROCESS_COMMAND) + # trainer_args is a proto ScalarMap; compare by equality to plain dict. 
+ self.assertEqual( + dict(call_kwargs["process_runtime_args"]), + {"lr": "0.01", "epochs": "5"}, + ) + self.assertEqual(call_kwargs["cpu_docker_uri"], "gcr.io/p/cpu:tag") + self.assertEqual(call_kwargs["cuda_docker_uri"], "gcr.io/p/cuda:tag") + self.assertFalse(call_kwargs["is_dry_run"]) + + # The forwarded CustomResourceConfig matches what we put in the + # resource config. + forwarded = call_kwargs["custom_resource_config"] + self.assertEqual(forwarded.command, "python -m my_project.launchers.ray.launch") + self.assertEqual( + list(forwarded.args), + ["--cluster=dev", "--num_workers=4"], + ) + + +if __name__ == "__main__": + absltest.main() diff --git a/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py b/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py index c70450501..3ae8485e9 100644 --- a/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py +++ b/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py @@ -5,12 +5,18 @@ GiglResourceConfigWrapper, ) from gigl.src.validation_check.libs.gbml_and_resource_config_compatibility_checks import ( + check_custom_resource_config_requires_glt_backend, check_inferencer_graph_store_compatibility, check_trainer_graph_store_compatibility, ) from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 from tests.test_assets.test_case import TestCase +# Placeholder shell snippet used by CustomResourceConfig fixtures in this +# module — these tests only exercise type-of-config dispatch, not actual +# subprocess execution. 
+_FAKE_COMMAND = "echo fake" + # Helper functions for creating VertexAiGraphStoreConfig @@ -214,5 +220,97 @@ def test_resource_has_inferencer_graph_store_template_does_not(self): ) +# Helper functions for custom + glt-backend compatibility tests + + +def _create_gbml_config_with_glt_flag(value: str) -> GbmlConfigPbWrapper: + """Create a GbmlConfig whose feature_flags.should_run_glt_backend is set. + + Note the raw YAML key is ``should_run_glt_backend`` (not + ``should_use_glt_backend``). The wrapper's ``should_use_glt_backend`` + property reads this key from the ``feature_flags`` map and converts it to + a bool. + """ + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.feature_flags["should_run_glt_backend"] = value + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_resource_config_with_custom_trainer() -> GiglResourceConfigWrapper: + """Create a GiglResourceConfig whose trainer is a CustomResourceConfig.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + _create_shared_resource_config(config) + config.trainer_resource_config.custom_trainer_config.command = _FAKE_COMMAND + # Inferencer uses a built-in config so only the trainer path is custom. 
+ config.inferencer_resource_config.vertex_ai_inferencer_config.CopyFrom( + _create_vertex_ai_resource_config() + ) + return GiglResourceConfigWrapper(resource_config=config) + + +def _create_resource_config_with_custom_inferencer() -> GiglResourceConfigWrapper: + """Create a GiglResourceConfig whose inferencer is a CustomResourceConfig.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + _create_shared_resource_config(config) + config.trainer_resource_config.vertex_ai_trainer_config.CopyFrom( + _create_vertex_ai_resource_config() + ) + config.inferencer_resource_config.custom_inferencer_config.command = _FAKE_COMMAND + return GiglResourceConfigWrapper(resource_config=config) + + +class TestCustomResourceConfigRequiresGltBackend(TestCase): + """Test suite for the CustomResourceConfig + GLT-backend compatibility guard. + + Because v1 trainer/inferencer dispatchers don't consult the custom oneof, + pairing a ``CustomResourceConfig`` with + ``feature_flags.should_run_glt_backend = "False"`` must be caught at + validation time rather than at runtime. 
+ """ + + def test_custom_trainer_without_glt_raises(self): + """CustomResourceConfig trainer + glt=False raises a clear ValueError.""" + gbml_config = _create_gbml_config_with_glt_flag("False") + resource_config = _create_resource_config_with_custom_trainer() + with self.assertRaises(ValueError) as ctx: + check_custom_resource_config_requires_glt_backend( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + self.assertIn("should_run_glt_backend", str(ctx.exception)) + self.assertIn("custom_trainer_config", str(ctx.exception)) + + def test_custom_inferencer_without_glt_raises(self): + """CustomResourceConfig inferencer + glt=False raises a clear ValueError.""" + gbml_config = _create_gbml_config_with_glt_flag("False") + resource_config = _create_resource_config_with_custom_inferencer() + with self.assertRaises(ValueError) as ctx: + check_custom_resource_config_requires_glt_backend( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + self.assertIn("custom_inferencer_config", str(ctx.exception)) + + def test_custom_trainer_with_glt_passes(self): + """CustomResourceConfig trainer + glt=True passes validation.""" + gbml_config = _create_gbml_config_with_glt_flag("True") + resource_config = _create_resource_config_with_custom_trainer() + # Should not raise any exception + check_custom_resource_config_requires_glt_backend( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_no_custom_config_passes_without_glt(self): + """No CustomResourceConfig at all passes regardless of glt flag.""" + gbml_config = _create_gbml_config_with_glt_flag("False") + resource_config = _create_resource_config_without_graph_stores() + # Should not raise any exception: no custom oneof means nothing to enforce. 
+ check_custom_resource_config_requires_glt_backend( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + if __name__ == "__main__": absltest.main() diff --git a/tests/unit/src/validation/lib/resource_config_checks_test.py b/tests/unit/src/validation/lib/resource_config_checks_test.py index 7f00b7102..1f3ba7371 100644 --- a/tests/unit/src/validation/lib/resource_config_checks_test.py +++ b/tests/unit/src/validation/lib/resource_config_checks_test.py @@ -1,11 +1,16 @@ +from unittest.mock import patch + from absl.testing import absltest +from gigl.common import Uri +from gigl.src.common.constants.components import GiGLComponents from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.validation_check.libs.resource_config_checks import ( _check_if_dataflow_resource_config_valid, _check_if_spark_resource_config_valid, _validate_accelerator_type, _validate_machine_config, + check_if_custom_resource_config_dry_run_valid, check_if_inferencer_graph_store_storage_command_valid, check_if_inferencer_resource_config_valid, check_if_preprocessor_resource_config_valid, @@ -18,6 +23,11 @@ from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 from tests.test_assets.test_case import TestCase +# Placeholder shell snippet used by CustomResourceConfig fixtures — +# subprocess invocation is patched in the dry-run tests below, so the
+_FAKE_COMMAND = "echo fake" + # Helper functions for creating valid configurations @@ -204,6 +214,28 @@ def _create_valid_local_inferencer_config() -> ( return config +def _create_valid_custom_trainer_config( + command: str = _FAKE_COMMAND, + args: list[str] | None = None, +) -> gigl_resource_config_pb2.GiglResourceConfig: + """Create a GiglResourceConfig with a CustomResourceConfig trainer.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + config.trainer_resource_config.custom_trainer_config.command = command + config.trainer_resource_config.custom_trainer_config.args.extend(args or []) + return config + + +def _create_valid_custom_inferencer_config( + command: str = _FAKE_COMMAND, + args: list[str] | None = None, +) -> gigl_resource_config_pb2.GiglResourceConfig: + """Create a GiglResourceConfig with a CustomResourceConfig inferencer.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + config.inferencer_resource_config.custom_inferencer_config.command = command + config.inferencer_resource_config.custom_inferencer_config.args.extend(args or []) + return config + + def _create_valid_vertex_ai_config() -> gigl_resource_config_pb2.VertexAiResourceConfig: """Create a valid Vertex AI resource configuration.""" config = gigl_resource_config_pb2.VertexAiResourceConfig() @@ -791,6 +823,155 @@ def test_no_graph_store_config(self): check_if_inferencer_graph_store_storage_command_valid(gbml_config) +class TestCustomResourceConfigBypass(TestCase): + """Test suite for CustomResourceConfig caller-level bypass. + + ``CustomResourceConfig`` is launcher-pluggable: it has no concrete machine + shape to validate. The callers (``check_if_trainer_resource_config_valid`` + and ``check_if_inferencer_resource_config_valid``) short-circuit before + reaching ``_validate_machine_config``, which keeps that helper's contract + ("validate a concrete machine spec") intact. 
+ """ + + def test_trainer_custom_config_bypasses_machine_validation(self): + """CustomResourceConfig trainer bypasses _validate_machine_config entirely.""" + config = _create_valid_custom_trainer_config( + args=["--cluster_size=4"] + ) + with patch( + "gigl.src.validation_check.libs.resource_config_checks._validate_machine_config" + ) as mock_validate: + check_if_trainer_resource_config_valid(resource_config_pb=config) + mock_validate.assert_not_called() + + def test_inferencer_custom_config_bypasses_machine_validation(self): + """CustomResourceConfig inferencer bypasses _validate_machine_config entirely.""" + config = _create_valid_custom_inferencer_config( + args=["--cluster_size=4"] + ) + with patch( + "gigl.src.validation_check.libs.resource_config_checks._validate_machine_config" + ) as mock_validate: + check_if_inferencer_resource_config_valid(resource_config_pb=config) + mock_validate.assert_not_called() + + def test_vertex_ai_trainer_still_calls_machine_validation(self): + """Sanity: non-custom trainer still dispatches to _validate_machine_config.""" + config = _create_valid_vertex_ai_trainer_config() + with patch( + "gigl.src.validation_check.libs.resource_config_checks._validate_machine_config" + ) as mock_validate: + check_if_trainer_resource_config_valid(resource_config_pb=config) + mock_validate.assert_called_once() + + def test_vertex_ai_inferencer_still_calls_machine_validation(self): + """Sanity: non-custom inferencer still dispatches to _validate_machine_config.""" + config = _create_valid_vertex_ai_inferencer_config() + with patch( + "gigl.src.validation_check.libs.resource_config_checks._validate_machine_config" + ) as mock_validate: + check_if_inferencer_resource_config_valid(resource_config_pb=config) + mock_validate.assert_called_once() + + def test_empty_command_raises_via_dry_run(self): + """Empty command is caught by launch_custom's guard at dry-run time.""" + config = _create_valid_custom_trainer_config(command="") + with 
self.assertRaises(ValueError): + check_if_custom_resource_config_dry_run_valid( + resource_config_pb=config, + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + applied_task_identifier="job-empty-command", + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Trainer, + ) + + +class TestCustomResourceConfigDryRun(TestCase): + """Test suite for ``check_if_custom_resource_config_dry_run_valid``. + + Dry-run dispatch flows through ``launch_custom``, which logs the + resolved shell line and returns *before* spawning a subprocess. Tests + patch ``subprocess.run`` to assert it is never invoked on the dry-run + path. + """ + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_trainer_dry_run_does_not_spawn_subprocess(self, mock_run): + """Dry-run logs the resolved shell line; subprocess.run is not called.""" + config = _create_valid_custom_trainer_config(args=["--cluster_size=4"]) + check_if_custom_resource_config_dry_run_valid( + resource_config_pb=config, + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + applied_task_identifier="job-trainer-dry-run", + cpu_docker_uri="gcr.io/p/cpu:tag", + cuda_docker_uri="gcr.io/p/cuda:tag", + component=GiGLComponents.Trainer, + ) + mock_run.assert_not_called() + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_inferencer_dry_run_does_not_spawn_subprocess(self, mock_run): + """Symmetric to the trainer case.""" + config = _create_valid_custom_inferencer_config(args=["--cluster_size=8"]) + check_if_custom_resource_config_dry_run_valid( + resource_config_pb=config, + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + applied_task_identifier="job-inferencer-dry-run", + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Inferencer, + ) + mock_run.assert_not_called() + + 
@patch("gigl.src.common.custom_launcher.subprocess.run") + def test_non_custom_trainer_is_no_op(self, mock_run): + """Non-custom trainer config is a no-op (subprocess never invoked).""" + config = _create_valid_vertex_ai_trainer_config() + check_if_custom_resource_config_dry_run_valid( + resource_config_pb=config, + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + applied_task_identifier="job-noop", + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Trainer, + ) + mock_run.assert_not_called() + + @patch("gigl.src.common.custom_launcher.subprocess.run") + def test_non_custom_inferencer_is_no_op(self, mock_run): + """Non-custom inferencer config is a no-op (subprocess never invoked).""" + config = _create_valid_vertex_ai_inferencer_config() + check_if_custom_resource_config_dry_run_valid( + resource_config_pb=config, + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + applied_task_identifier="job-noop", + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.Inferencer, + ) + mock_run.assert_not_called() + + def test_unsupported_component_raises(self): + """Only Trainer and Inferencer are supported; other components raise ValueError.""" + config = _create_valid_custom_trainer_config() + with self.assertRaises(ValueError): + check_if_custom_resource_config_dry_run_valid( + resource_config_pb=config, + task_config_uri=Uri("gs://bucket/task.yaml"), + resource_config_uri=Uri("gs://bucket/resource.yaml"), + applied_task_identifier="job-bad-component", + cpu_docker_uri=None, + cuda_docker_uri=None, + component=GiGLComponents.DataPreprocessor, + ) + + class TestReservationAffinityValidation(TestCase): """Validate VertexAiResourceConfig.reservation_affinity handling."""