diff --git a/.github/actions/cross-repo-ci-relay-callback/action.yml b/.github/actions/cross-repo-ci-relay-callback/action.yml new file mode 100644 index 0000000000..36bced27b8 --- /dev/null +++ b/.github/actions/cross-repo-ci-relay-callback/action.yml @@ -0,0 +1,179 @@ +name: Cross-Repo CI Relay Callback + +description: > + Report the status of a downstream CI workflow back to the Cross-Repo CI + Relay server. The job must have `id-token: write` permission so that a + GitHub OIDC token can be minted and used to authenticate the callback. + + This action is meant to run in a workflow triggered by a `repository_dispatch` + event from the relay. It reads the dispatch payload (`github.event.client_payload`) + and the ambient `github` context directly, so workflow authors only need to + supply the relay URL, the status/conclusion, and optional structured test + results. + +inputs: + status: + description: > + Workflow status to report. Must be either "in_progress" or "completed". + required: true + conclusion: + description: > + Conclusion of the workflow run. Required (and must be "success" or + "failure") when status is "completed". Ignored when status is + "in_progress". + required: false + default: '' + test-results: + description: > + Optional JSON string with test result summary (counts: passed/failed/skipped). + Note: This should be a summary only, not a full enumeration of all test cases. + Full test results should be uploaded as artifacts and referenced via `artifact-url`. + required: false + default: '' + callback-url: + description: > + Base URL of the result callback server. + required: true + artifact-url: + description: > + URL to downstream-hosted artifacts (logs, reports, results), + any publicly accessible URL. + required: false + default: '' + max-time: + description: > + Maximum time in seconds to wait for the callback HTTP request to complete. + required: false + default: 10 + max-retries: + description: > + Maximum number of retries for the callback HTTP request in case of failure. + required: false + default: 3 + retry-delay: + description: > + Delay in seconds between retries for the callback HTTP request. + required: false + default: 2 + +runs: + using: composite + steps: + - name: Mint OIDC token + id: oidc + uses: actions/github-script@v7 + with: + script: | + const token = await core.getIDToken("pytorch-cross-repo-ci-relay"); + core.setSecret(token); + core.setOutput('token', token); + + - name: Send callback to relay server + shell: bash + env: + SCHEMA_VERSION: 1 + STATUS: ${{ inputs.status }} + CONCLUSION: ${{ inputs.conclusion }} + WORKFLOW_NAME: ${{ github.workflow }} + WORKFLOW_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + TEST_RESULTS: ${{ inputs.test-results }} + CLIENT_PAYLOAD: ${{ toJson(github.event.client_payload) }} + OIDC_TOKEN: ${{ steps.oidc.outputs.token }} + CALLBACK_URL: ${{ inputs.callback-url }} + ARTIFACT_URL: ${{ inputs.artifact-url }} + JOB_NAME: ${{ github.job }} + CHECK_RUN_ID: ${{ job.check_run_id }} + RUN_ID: ${{ github.run_id }} + RUN_ATTEMPT: ${{ github.run_attempt }} + MAX_TIME: ${{ inputs.max-time }} + MAX_RETRIES: ${{ inputs.max-retries }} + RETRY_DELAY: ${{ inputs.retry-delay }} + run: | + set -euo pipefail + + PAYLOAD=$(python3 - <<'PYEOF' + import json, os, sys + from datetime import datetime, timezone + + status = os.environ["STATUS"] + if status not in ("in_progress", "completed"): + sys.exit(f"::error::status must be 'in_progress' or 'completed', got {status!r}") + + conclusion = os.environ.get("CONCLUSION", "").strip() or None + if status == "completed" and conclusion not in ("success", "failure"): + sys.exit("::error::conclusion must be 'success' or 'failure' when status is 'completed'") + if status == "in_progress": + conclusion = None + + try: + client_payload = json.loads(os.environ["CLIENT_PAYLOAD"]) + except json.JSONDecodeError as exc: + sys.exit(f"::error::github.event.client_payload is not valid JSON: {exc}") + + current_time = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + + # In case check_run_id is not exist (edge case), replace it + # with {run_id}-{run_attempt}, which is also unique for each job run. + check_run_id = os.environ.get("CHECK_RUN_ID", "").strip() + if not check_run_id: + check_run_id = f"{os.environ['RUN_ID']}-{os.environ['RUN_ATTEMPT']}" + + # Relay's original dispatch payload (event_type, delivery_id, payload) is + # forwarded verbatim. Downstream-reported fields live in a sibling + # `workflow` dict so the two sources stay clearly separated on the wire. + workflow: dict = { + "schema_version": str(os.environ["SCHEMA_VERSION"]), + "status": status, + "conclusion": conclusion, + "name": os.environ["WORKFLOW_NAME"], + "url": os.environ["WORKFLOW_URL"], + "run_attempt": os.environ["RUN_ATTEMPT"], + "job_name": os.environ["JOB_NAME"], + "check_run_id": check_run_id, + "run_id": str(os.environ["RUN_ID"]), + "started_at": None if status == "completed" else current_time, + "completed_at": None if status == "in_progress" else current_time, + } + + test_results = os.environ.get("TEST_RESULTS", "").strip() + if test_results: + try: + workflow["test_results"] = json.loads(test_results) + except json.JSONDecodeError as exc: + sys.exit(f"::error::test-results input is not valid JSON: {exc}") + + artifact_url = os.environ.get("ARTIFACT_URL", "").strip() + if artifact_url: + workflow["artifact_url"] = artifact_url + + client_payload["workflow"] = workflow + print(json.dumps(client_payload)) + PYEOF + ) + + set +e + HTTP_CODE=$( + curl --silent --show-error --fail-with-body --output /tmp/relay_response.json \ + --write-out "%{http_code}" \ + -X POST \ + --max-time ${MAX_TIME} \ + --retry ${MAX_RETRIES} \ + --retry-delay ${RETRY_DELAY} \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${OIDC_TOKEN}" \ + --data "${PAYLOAD}" \ + "${CALLBACK_URL%/}" + ) + CURL_EXIT_CODE=$? + set -e + + if [[ "${CURL_EXIT_CODE}" -ne 0 ]]; then + echo "::error::Callback server returned HTTP ${HTTP_CODE}." + if [[ -s /tmp/relay_response.json ]]; then + echo "Relay server error response body:" + cat /tmp/relay_response.json + fi + exit "${CURL_EXIT_CODE}" + fi + + echo "Relay server response HTTP: ${HTTP_CODE}" diff --git a/.github/workflows/_lambda-do-release-runners.yml b/.github/workflows/_lambda-do-release-runners.yml index 07c33be885..02500dd143 100644 --- a/.github/workflows/_lambda-do-release-runners.yml +++ b/.github/workflows/_lambda-do-release-runners.yml @@ -93,7 +93,8 @@ jobs: { dir-name: 'keep-going-call-log-classifier', zip-name: 'keep-going-call-log-classifier' }, { dir-name: 'buildkite-webhook-handler', zip-name: 'buildkite-webhook-handler' }, { dir-name: 'benchmark_regression_summary_report', zip-name: 'benchmark-regression-summary-report' }, - { dir-name: 'cross_repo_ci_relay', zip-name: 'cross-repo-ci-webhook' }, + { dir-name: 'cross_repo_ci_relay/callback', zip-name: 'cross-repo-ci-callback' }, + { dir-name: 'cross_repo_ci_relay/webhook', zip-name: 'cross-repo-ci-webhook' }, ] name: Upload Release for ${{ matrix.dir-name }} lambda runs-on: ubuntu-latest diff --git a/aws/lambda/cross_repo_ci_relay/Makefile b/aws/lambda/cross_repo_ci_relay/Makefile index c1f07bbc56..39fd320825 100644 --- a/aws/lambda/cross_repo_ci_relay/Makefile +++ b/aws/lambda/cross_repo_ci_relay/Makefile @@ -1,21 +1,24 @@ -SHARED := config.py utils.py redis_helper.py allowlist.py gh_helper.py event_handler.py -PIP_FLAGS := --platform manylinux2014_x86_64 --only-binary=:all: --implementation cp --python-version 3.13 -AWS_REGION := us-east-1 -FUNCTION_NAME := cross_repo_ci_webhook +AWS_REGION ?= us-east-1 +CALLBACK_FUNCTION_NAME ?= cross_repo_ci_callback +WEBHOOK_FUNCTION_NAME ?= cross_repo_ci_webhook -deployment.zip: clean - mkdir -p ./deployment - cp $(SHARED) lambda_function.py ./deployment/ - pip3 install --target ./deployment -r requirements.txt $(PIP_FLAGS) - cd deployment && zip -r ../deployment.zip . - -deploy: deployment.zip - aws lambda update-function-code --region $(AWS_REGION) --function-name $(FUNCTION_NAME) --zip-file fileb://deployment.zip +.PHONY: test deploy deploy-callback deploy-webhook clean test: python3 -m pytest tests -v -clean: - rm -rf deployment deployment.zip +# Deploy both; keep going on failure so a broken half doesn't block the other. +deploy: + @rc=0; \ + for t in deploy-callback deploy-webhook; do $(MAKE) $$t || rc=$$?; done; \ + exit $$rc + +deploy-callback: + $(MAKE) -C callback deploy AWS_REGION=$(AWS_REGION) FUNCTION_NAME=$(CALLBACK_FUNCTION_NAME) -.PHONY: prepare deploy test clean +deploy-webhook: + $(MAKE) -C webhook deploy AWS_REGION=$(AWS_REGION) FUNCTION_NAME=$(WEBHOOK_FUNCTION_NAME) + +clean: + $(MAKE) -C callback clean + $(MAKE) -C webhook clean diff --git a/aws/lambda/cross_repo_ci_relay/README.md b/aws/lambda/cross_repo_ci_relay/README.md index 2df4335816..d7d8187960 100644 --- a/aws/lambda/cross_repo_ci_relay/README.md +++ b/aws/lambda/cross_repo_ci_relay/README.md @@ -1,6 +1,6 @@ # Cross Repo CI Relay -An AWS Lambda function that relays GitHub webhook events from the upstream repository to downstream repositories. +An AWS Lambda function that relays GitHub webhook events from the upstream repository to downstream repositories, and forwards downstream CI results to HUD. For more information, please refer to this [RFC](https://github.com/pytorch/pytorch/issues/175022). @@ -31,33 +31,219 @@ Each entry is either a plain `owner/repo` string or a `owner/repo: oncall1, onca The allowlist is cached in Redis under the key `crcr:allowlist_yaml` with a TTL controlled by `ALLOWLIST_TTL_SECONDS`. On a Redis error the function falls back to fetching directly from GitHub. +## Reporting Results from Downstream CI + +L2+ downstream repositories can report the status of their CI workflows back to the relay server using the [`cross-repo-ci-relay-callback`](../../../.github/actions/cross-repo-ci-relay-callback/action.yml) composite action. + +### Security and the Relay/HUD boundary + +The callback endpoint validates incoming callbacks and forwards them to HUD for persistence. The relay is the gatekeeper for OIDC authentication, allowlist checks, rate limiting, and schema validation — HUD just authenticates the relay and writes what it's told. + +#### Relay's responsibilities: + +- **Identity**: the `Authorization: Bearer ` header is verified against GitHub's JWKS. The OIDC `repository` claim is a trusted identity for the caller and is used for the L2+ allowlist check. Relay forwards this trusted value to HUD as a top-level `verified_repo` field; HUD should prefer it over anything self-reported in `callback_payload`. +- **Repo level**: Relay determines the downstream repository's allowlist level (L1–L4) and forwards it to HUD as `downstream_repo_level`. This authoritative level information is determined once by the relay, ensuring HUD doesn't need to recompute it and avoiding synchronization/timing issues if tiering information becomes dynamic. +- **Schema validation**: Relay validates that required fields (`delivery_id` and `workflow.status`) are present in the callback body. Missing fields result in a `400` error to signal contract violations to the caller. HUD receives validated data and does not need to perform schema checks. +- **State machine**: Relay maintains a **unified state machine** in Redis to validate callback lifecycles, compute timing metrics, and support per-job tracking: + - **Unified structure**: Single enum `CallbackState` with states `DISPATCHED` (webhook side, keyed by sentinel `check_run_id="dispatched"`), `IN_PROGRESS`, and `COMPLETED` (callback side, per-job). State records stored as JSON: `{"state": "...", "timestamp": 1234.56, "job_name": "...", "run_id": "..."}`. + - **Dispatch validation**: `DISPATCHED` state proves valid webhook origin. Callbacks without this state are rejected (no prior dispatch). + - **Job-level tracking**: Each job has independent state and timestamps keyed by `check_run_id` (`oot:state:{delivery_id}:{repo}:{check_run_id}`). Supports multiple jobs per webhook. + - **Timing metrics**: `queue_time = dispatch_timestamp → in_progress_timestamp`, `execution_time = in_progress_timestamp → completed_timestamp`. Timestamps extracted from state records. + - **State transitions**: Rejects invalid flows (`COMPLETED` without prior `IN_PROGRESS`, duplicate `IN_PROGRESS` for the same `check_run_id`, duplicate `COMPLETED`, callbacks without a prior `DISPATCHED` record). + Note that the direction graph below is for a single check run, reruns have different `check_run_id` and are treated as separate jobs, so they won't violate the state machine since they won't have a prior `IN_PROGRESS` or `COMPLETED` record. + ```mermaid + stateDiagram-v2 + direction LR + + [*] --> DISPATCHED: webhook sends + DISPATCHED --> IN_PROGRESS: first callback + IN_PROGRESS --> COMPLETED: completion + + IN_PROGRESS --> IN_PROGRESS: ❌ duplicate + DISPATCHED --> COMPLETED: ❌ skip IN_PROGRESS + COMPLETED --> COMPLETED: ❌ duplicate + COMPLETED --> IN_PROGRESS: ❌ wrong direction + [*] --> IN_PROGRESS: ❌ no dispatch + [*] --> COMPLETED: ❌ no dispatch + ``` + +The HUD request looks like (two top-level namespaces: `trusted` and `untrusted`): + +```json +{ + "trusted": { + "ci_metrics": { "queue_time": 1.23, "execution_time": null }, + "verified_repo": "org/repo", + "downstream_repo_level": "L2" + }, + "untrusted": { + "callback_payload": { + "event_type": "pull_request", + "delivery_id": "", + "payload": { ...original upstream webhook payload, verbatim... }, + "workflow": { + "schema_version": 1, + "status": "completed", + "conclusion": "success", + "name": "CI", + "url": "https://github.com/org/repo/actions/runs/123", + "job_name": "my-ci-job", + "started_at": "2026-05-04T20:48:28Z", // when status == in_progress, else None + "completed_at": "2026-05-04T21:23:45Z", // when status == completed, else None + "test_results": { "passed": 42, "failed": 3, "skipped": 5 }, + "artifact_url": "https://github.com/org/repo/actions/runs/123/artifacts" + } + } + } +} +``` + +Notes: +- `trusted` contains relay-generated fields the HUD can rely on (`ci_metrics`, `verified_repo`, and `downstream_repo_level`). +- `untrusted.callback_payload` contains the downstream-reported callback body; HUD should treat it as untrusted and prefer `trusted.verified_repo` for identity. + +Trust boundaries inside `untrusted.callback_payload`: + +- `untrusted.callback_payload.payload` is the upstream webhook payload, transparently forwarded — + trusted at dispatch time, but not re-verified on the callback. +- `untrusted.callback_payload.workflow` is **self-reported by the downstream CI** and is not + authenticated. Only `verified_repo` carries a cryptographic identity. + +### Error propagation back to the downstream workflow + +| HUD response | Relay behaviour | Effect on downstream CI step | +|---|---|---| +| `2xx` | record delivered | green | +| `4xx` (schema reject) | propagate same status | **red** — author must fix payload | +| `5xx` / network error | log + return | green — HUD outage is not the caller's fault | + +The asymmetry is deliberate: `4xx` means the caller sent something wrong and should see it; `5xx`/network means HUD or its infrastructure is broken and should not be surfaced as a red CI step across every L2+ repo. Operators are expected to alert on the `HUD forward failed` CloudWatch log pattern. + +#### Known limitations of this model + +A compromised or malicious maintainer of an allowlisted repo can: + +1. Fabricate `workflow.status` / `workflow.conclusion` values for upstream PRs their repo was never dispatched for — HUD will receive the row, but `verified_repo` always identifies the true caller. +2. Replay an older dispatched payload against the callback endpoint — there is no dispatch-side nonce. +3. Tamper with any field inside `callback_payload` — HUD must trust `verified_repo`, not the others. + +All three attacks are **scoped to the attacker's own OIDC-authenticated repo identity** — OIDC guarantees they cannot impersonate another allowlisted repo. Mitigation is operational: every HUD row carries `verified_repo`, so misbehaviour is observable, and the offending repo can be removed from `allowlist.yaml`. + +### Prerequisites + +- The downstream repository must be listed at level **L2 or higher** in the allowlist. +- The **calling job** must declare `permissions: id-token: write` so that the action can mint a GitHub OIDC token for authentication. + +### Usage + +When triggered by a relay `repository_dispatch`, the action automatically reads `github.event.client_payload` for `delivery_id` and the upstream webhook payload, and reads `github.workflow` / the current run URL for the workflow identity. Workflow authors only required to pass `status` (and `conclusion` when `status=completed`, others are optional). + +```yaml +on: + repository_dispatch: + types: [pull_request] + +jobs: + my-ci-job: + runs-on: ubuntu-latest + permissions: + id-token: write # required for OIDC token minting + contents: read + steps: + - name: Report in-progress to relay + uses: pytorch/test-infra/.github/actions/cross-repo-ci-relay-callback@main + with: + status: in_progress + + # ... your CI steps ... + + - name: Report final result to relay + if: always() + uses: pytorch/test-infra/.github/actions/cross-repo-ci-relay-callback@main + with: + status: completed + conclusion: ${{ job.status }} +``` + +### Inputs + +| Input | Required | Default | Description | +|---|---|---|---| +| `status` | **yes** | — | `in_progress` or `completed` | +| `conclusion` | no | `''` | `success` or `failure` (required when `status=completed`) | +| `test-results` | no | `''` | Optional JSON string with test result summary (counts: passed/failed/skipped) | +| `callback-url` | **yes** | — | Callback endpoint URL (production Lambda URL; set once at the workflow level) | +| `artifact-url` | no | `''` | URL to downstream-hosted artifacts (logs, reports, results) | + ## Build, Deploy, and Test +### Deployment layout + +The build packages each Lambda as a zip that preserves the package hierarchy: + +``` +deployment/ +├── webhook/ +│ ├── lambda_function.py +│ └── event_handler.py +├── callback/ +│ ├── lambda_function.py +│ └── callback_handler.py +└── utils/ + └── ... +``` + +This matches the layout used during local development and tests, so imports behave identically in both environments. Configure the AWS Lambda handlers as: + +- Webhook Lambda: `webhook.lambda_function.lambda_handler` +- Callback Lambda: `callback.lambda_function.lambda_handler` + ### Make Targets -Build the Lambda zip (output: deployment.zip) +Build the Webhook Lambda zip (output: `webhook/deployment.zip`): + +```bash +cd webhook +make deployment.zip +``` + +Build the Callback Lambda zip (output: `callback/deployment.zip`): + ```bash +cd callback make deployment.zip ``` -Deploy to AWS Lambda (requires AWS CLI v2 configured with permissions) +Deploy both zips to AWS Lambda (requires AWS CLI v2 with permissions): + +```bash +make deploy AWS_REGION=us-east-1 \ + WEBHOOK_FUNCTION_NAME=cross_repo_ci_webhook \ + CALLBACK_FUNCTION_NAME=cross_repo_ci_callback +``` + +Either side can be deployed independently: + ```bash -make deploy AWS_REGION=us-east-1 FUNCTION_NAME=crcr-prod-crcr-webhook +make deploy-webhook +make deploy-callback ``` -Run all unit tests under tests/ folder +Run all unit tests under `tests/`: + ```bash make test ``` -Clean build artifacts +Clean build artifacts: + ```bash make clean ``` ## Local Development -`local_server.py` wraps the Lambda handler in a FastAPI app so you can test the full cross-repo-ci-relay flow without deploying to AWS. +`local_server.py` wraps both Lambda handlers in a FastAPI app so you can test the full cross-repo-ci-relay flow without deploying to AWS. ### Prerequisites @@ -78,6 +264,12 @@ make clean smee --url https://smee.io/ --path /github/webhook --port 8000 ``` + CLI to forward GitHub callback callbacks to localhost (set this URL as `callback-url` in the downstream workflow): + ```bash + npm install -g smee-client + smee --url https://smee.io/ --path /github/callback --port 8000 + ``` + #### Remote - GitHub App settings (refer to this [RFC](https://github.com/pytorch/pytorch/issues/175022)) @@ -109,8 +301,11 @@ make clean REDIS_ENDPOINT=localhost:6379 REDIS_LOGIN=default: ALLOWLIST_TTL_SECONDS=1200 + + # HUD (local testing) + HUD_ENDPOINT= ``` - **Note**: `ALLOWLIST_URL` is required for local development which should point to a GitHub URL that can be different from the real one. + **Note**: `ALLOWLIST_URL` is required for local development and should point to a GitHub URL (it can differ from the production one). 3. Start the server: ```bash @@ -118,3 +313,5 @@ make clean ``` 4. Point your GitHub App's webhook URL to the smee.io channel, then open or update a pull request in the upstream repo to trigger a full relay cycle. + +5. Check whether the workflow run status is reported back through `callback-url`. diff --git a/aws/lambda/cross_repo_ci_relay/callback/Makefile b/aws/lambda/cross_repo_ci_relay/callback/Makefile new file mode 100644 index 0000000000..e8185278b7 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/callback/Makefile @@ -0,0 +1,19 @@ +UTILS_SRC := $(wildcard ../utils/*.py) +PIP_FLAGS := --platform manylinux2014_x86_64 --only-binary=:all: --implementation cp --python-version 3.13 +AWS_REGION := us-east-1 +FUNCTION_NAME := cross_repo_ci_callback + +deployment.zip: clean + mkdir -p ./deployment/callback + mkdir -p ./deployment/utils + cp *.py ./deployment/callback/ && cp $(UTILS_SRC) ./deployment/utils/ + pip3 install --target ./deployment -r ../requirements.txt $(PIP_FLAGS) + cd deployment && zip -r ../deployment.zip . + +deploy: deployment.zip + aws lambda update-function-code --region $(AWS_REGION) --function-name $(FUNCTION_NAME) --zip-file fileb://deployment.zip + +clean: + rm -rf deployment deployment.zip + +.PHONY: deploy clean diff --git a/aws/lambda/cross_repo_ci_relay/callback/callback_handler.py b/aws/lambda/cross_repo_ci_relay/callback/callback_handler.py new file mode 100644 index 0000000000..34ef63ced1 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/callback/callback_handler.py @@ -0,0 +1,223 @@ +from __future__ import annotations + +import logging +import time + +import utils.redis_helper as redis_helper +from redis.exceptions import RedisError +from utils.allowlist import AllowlistLevel, AllowlistMap, load_allowlist +from utils.config import RelayConfig +from utils.hud import forward_to_hud +from utils.misc import ( + CallbackState, + CallbackStateRecord, + DISPATCH_CHECK_RUN_ID, + HTTPException, +) +from utils.redis_helper import check_rate_limit + + +logger = logging.getLogger(__name__) + + +def _safe_delta( + start_ts: float | None, end_ts: float | None, label: str +) -> float | None: + """Compute end-start, clamping tiny negatives to 0 and returning None + whenever either endpoint is missing (e.g. Redis cache miss).""" + if start_ts is None or end_ts is None: + return None + delta = round(end_ts - start_ts, 3) + if delta < 0: + logger.warning("negative %s computed, start=%s end=%s", label, start_ts, end_ts) + return 0 + return delta + + +def _verify_access( + config: RelayConfig, verified_repo: str +) -> tuple[AllowlistMap, AllowlistLevel] | None: + """Return (AllowlistMap, repo_level) when ``verified_repo`` is L2+, else None. + + Raises HTTPException(429) if the per-repo rate limit is exceeded. + A ``None`` return signals the caller to silently ignore the request. + """ + allowlist = load_allowlist(config) + repo_level = allowlist.get_repo_level(verified_repo) + if repo_level is None or repo_level.value < AllowlistLevel.L2.value: + logger.info( + "verified_repo %s is not configured for L2+ features, ignoring result", + verified_repo, + ) + return None + if not check_rate_limit(config, verified_repo): + logger.warning( + "rate limit exceeded for verified_repo=%s, rejecting request", + verified_repo, + ) + raise HTTPException(429, f"rate limit exceeded for {verified_repo}") + return allowlist, repo_level + + +def _parse_callback_body(body: dict) -> tuple[str, str, str, str, str]: + """Return (delivery_id, status, check_run_id, job_name, run_id) from ``body``. + + check_run_id is set by GitHub Actions (job.check_run_id context) and + cannot be tampered with, ensuring replay-attack detection integrity. + + Raises HTTPException(400) on any missing or mis-typed field. + """ + try: + delivery_id = body["delivery_id"] + workflow_dict = body["workflow"] + status = workflow_dict["status"] + check_run_id = workflow_dict["check_run_id"] # Required + job_name = workflow_dict["job_name"] # Required for HUD grouping + run_id = workflow_dict["run_id"] # Required for HUD grouping + except (KeyError, TypeError) as exc: + logger.warning(f"missing required field in callback body: {exc}") + raise HTTPException( + 400, f"callback body missing required field: {exc}" + ) from exc + return delivery_id, status, check_run_id, job_name, run_id + + +def _update_state_and_compute_metrics( + config: RelayConfig, + delivery_id: str, + verified_repo: str, + check_run_id: str, + job_name: str, + run_id: str, + status: str, + dispatch_record: CallbackStateRecord, + job_record: CallbackStateRecord | None, +) -> dict: + """Persist the new job state to Redis and return CI timing metrics. + + Writes IN_PROGRESS or COMPLETED state (with the current timestamp), then + reads back the stored record to compute: + - ``queue_time``: dispatch → in_progress (set on "in_progress" callbacks) + - ``execution_time``: in_progress → completed (set on "completed" callbacks) + + Both metrics default to None when the required prior state is unavailable + (e.g. Redis cache miss or rerun without matching prior record). + """ + if status not in ("in_progress", "completed"): + raise HTTPException(400, f"unknown callback status: {status!r}") + + ci_metrics: dict = {"queue_time": None, "execution_time": None} + current_timestamp = time.time() + state = ( + CallbackState.IN_PROGRESS + if status == "in_progress" + else CallbackState.COMPLETED + ) + + try: + redis_helper.set_callback_state( + config, + delivery_id, + verified_repo, + check_run_id, + state, + current_timestamp, + job_name, + run_id, + ) + except RedisError: + raise HTTPException( + 503, "redis temporary outage: failed to persist callback state" + ) + except AssertionError as e: + msg = ( + "callback rejected: invalid state transition delivery_id=%s repo=%s status=%s" + % (delivery_id, verified_repo, status) + ) + raise HTTPException(400, msg) from e + except Exception: + raise + + updated_job_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, check_run_id + ) + if updated_job_record is None: + return ci_metrics + + if state == CallbackState.IN_PROGRESS: + ci_metrics["queue_time"] = _safe_delta( + dispatch_record.timestamp, + updated_job_record.timestamp, + "queue_time", + ) + else: + if job_record is not None: + ci_metrics["execution_time"] = _safe_delta( + job_record.timestamp, updated_job_record.timestamp, "execution_time" + ) + + return ci_metrics + + +def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: + """Forward a downstream callback to HUD. + + ``body`` is the downstream self-report, passed through to HUD verbatim. + It carries the original dispatch envelope (``delivery_id``, ``payload``) + and a sibling ``workflow`` dict with status/conclusion/name/url. + + ``verified_repo`` is the OIDC-authenticated downstream repository — used + for allowlist / timing lookups, and surfaced to HUD as ``verified_repo`` + so HUD can trust it over anything self-reported in the body. + + State machine ensures: + - Callbacks without prior dispatch are rejected + - Timestamps (started_at, completed_at) are recorded once only + - Duplicate callbacks are handled gracefully + - State transitions follow valid lifecycle paths + """ + result = _verify_access(config, verified_repo) + if result is None: + return {"ok": True, "status": "ignored"} + _, repo_level = result + + delivery_id, status, check_run_id, job_name, run_id = _parse_callback_body(body) + + dispatch_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, DISPATCH_CHECK_RUN_ID + ) + if not dispatch_record: + logger.warning( + "no dispatch record found for delivery_id=%s, verified_repo=%s; rejecting callback", + delivery_id, + verified_repo, + ) + raise HTTPException(400, "callback rejected: no matching dispatch record") + + job_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, check_run_id + ) + + ci_metrics = _update_state_and_compute_metrics( + config, + delivery_id, + verified_repo, + check_run_id, + job_name, + run_id, + status, + dispatch_record, + job_record, + ) + + trusted = { + "ci_metrics": ci_metrics, + "verified_repo": verified_repo, + "downstream_repo_level": repo_level.value, + } + # downstream's payload is untrusted — provide it under the "callback_payload" + # key so HUD receives it under the expected untrusted namespace. + untrusted = {"callback_payload": body} + + forward_to_hud(config, trusted, untrusted) + return {"ok": True, "status": status} diff --git a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py new file mode 100644 index 0000000000..e573e4eeb9 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py @@ -0,0 +1,70 @@ +from __future__ import annotations + +import json +import logging + +from utils import jwt_helper +from utils.config import get_config +from utils.misc import HTTPException, JSON_HEADERS, parse_lambda_event + +from . import callback_handler + + +logging.getLogger().setLevel(logging.INFO) +logger = logging.getLogger(__name__) + + +def lambda_handler(event, context): + method, path, body_bytes, headers = parse_lambda_event(event) + + logger.info("request method=%s path=%s", method, path) + + if method != "POST" or path != "/github/callback": + if path == "/github/callback": + return { + "statusCode": 405, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Method not allowed"}), + } + return { + "statusCode": 404, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Not found"}), + } + + try: + config = get_config() + body = json.loads(body_bytes) if body_bytes else {} + + # OIDC is the only identity check Relay performs. The callback body is + # passed through to HUD untouched — HUD owns schema/business validation. + # Relay reports the OIDC-verified repo to HUD separately as + # `verified_repo` so HUD has a trusted source of truth for the + # caller's identity. + oidc_claims = jwt_helper.verify_oidc_token(headers.get("authorization", "")) + verified_repo = oidc_claims["repository"] + + result = callback_handler.handle(config, body, verified_repo) + return {"statusCode": 200, "headers": JSON_HEADERS, "body": json.dumps(result)} + + except json.JSONDecodeError: + logger.exception("Invalid JSON body") + return { + "statusCode": 400, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Invalid JSON body"}), + } + except HTTPException as exc: + logger.exception(exc.detail) + return { + "statusCode": exc.status_code, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": exc.detail}), + } + except Exception: + logger.exception("Internal server error") + return { + "statusCode": 500, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Internal server error"}), + } diff --git a/aws/lambda/cross_repo_ci_relay/local_server.py b/aws/lambda/cross_repo_ci_relay/local_server.py index 9f41230b4c..2c0edeba60 100644 --- a/aws/lambda/cross_repo_ci_relay/local_server.py +++ b/aws/lambda/cross_repo_ci_relay/local_server.py @@ -7,13 +7,14 @@ load_dotenv(find_dotenv(usecwd=True)) -import lambda_function +from callback import lambda_function as callback_lambda +from webhook import lambda_function as webhook_lambda -webhook_router = APIRouter() +relay_router = APIRouter() -@webhook_router.post("/github/webhook") +@relay_router.post("/github/webhook") async def github_webhook(req: Request): body = await req.body() event = { @@ -28,20 +29,41 @@ async def github_webhook(req: Request): "isBase64Encoded": False, } - result = lambda_function.lambda_handler(event, None) + result = webhook_lambda.lambda_handler(event, None) + return JSONResponse( + status_code=result["statusCode"], content=json.loads(result["body"]) + ) + + +@relay_router.post("/github/callback") +async def github_callback(req: Request): + body = await req.body() + event = { + "requestContext": { + "http": { + "method": req.method, + "path": req.url.path, + } + }, + "headers": {k.decode(): v.decode() for k, v in req.scope["headers"]}, + "body": body.decode("utf-8"), + "isBase64Encoded": False, + } + + result = callback_lambda.lambda_handler(event, None) return JSONResponse( status_code=result["statusCode"], content=json.loads(result["body"]) ) # ================= FastAPI apps ================= -# - webhook_app: only /github/webhook (for smee forward) +# - relay_router: defines the same endpoints as the Lambda functions, but callable via HTTP for local testing -webhook_app = FastAPI() -webhook_app.include_router(webhook_router) +relay_server = FastAPI() +relay_server.include_router(relay_router) if __name__ == "__main__": import uvicorn - uvicorn.run("local_server:webhook_app", host="0.0.0.0", port=8000, reload=True) + uvicorn.run("local_server:relay_server", host="0.0.0.0", port=8000, reload=True) diff --git a/aws/lambda/cross_repo_ci_relay/redis_helper.py b/aws/lambda/cross_repo_ci_relay/redis_helper.py deleted file mode 100644 index d4e64fcdd3..0000000000 --- a/aws/lambda/cross_repo_ci_relay/redis_helper.py +++ /dev/null @@ -1,124 +0,0 @@ -import logging -import os -from urllib.parse import quote - -import redis as redis_lib -from config import RelayConfig - - -logger = logging.getLogger(__name__) - -_ALLOWLIST_CACHE_KEY = "crcr:allowlist_yaml" -_cached_client: redis_lib.Redis | None = None -_cached_client_url: str | None = None - - -def _parse_endpoint(endpoint: str) -> tuple[str, int]: - host = endpoint.strip() - - if not host: - raise RuntimeError("REDIS_ENDPOINT must not be empty") - - if host.startswith(("redis://", "rediss://")): - raise RuntimeError( - "REDIS_ENDPOINT must be a hostname or host:port, not a redis URL" - ) - - if "/" in host: - raise RuntimeError("REDIS_ENDPOINT must be a hostname or host:port") - - port = 6379 - if ":" in host: - maybe_host, maybe_port = host.rsplit(":", 1) - if not maybe_port.isdigit(): - raise RuntimeError(f"REDIS_ENDPOINT has invalid port: {maybe_port!r}") - host, port = maybe_host, int(maybe_port) - - return host, port - - -def _parse_login(login: str) -> tuple[str, str]: - login = login.strip() - if not login: - return "", "" - - if ":" in login: - username, password = login.split(":", 1) - return username, password - - # ElastiCache auth_token config provides only a password, not a username. - return "", login - - -def _build_url(config: RelayConfig) -> str: - host, port = _parse_endpoint(config.redis_endpoint or "") - auth = "" - username, password = _parse_login(config.redis_login or "") - if password and username: - auth = f"{quote(username, safe='')}:{quote(password, safe='')}@" - elif password: - auth = f":{quote(password, safe='')}@" - # Use TLS (rediss://) on AWS Lambda where ElastiCache requires it; - # fall back to plain redis:// for local development. - # AWS_LAMBDA_FUNCTION_NAME is automatically set by the Lambda runtime. - scheme = "rediss" if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") else "redis" - return f"{scheme}://{auth}{host}:{port}/0" - - -def create_client(config: RelayConfig) -> redis_lib.Redis: - """Create or reuse a Redis client for the given config.""" - global _cached_client - global _cached_client_url - - redis_url = _build_url(config) - if _cached_client is not None and _cached_client_url == redis_url: - return _cached_client - - client = redis_lib.from_url( - redis_url, - decode_responses=True, - socket_connect_timeout=2, - socket_timeout=2, - ) - _cached_client = client - _cached_client_url = redis_url - return client - - -def get_cached_yaml( - config: RelayConfig, client: redis_lib.Redis | None = None -) -> str | None: - """Return cached allowlist YAML string, or None on cache miss or Redis error.""" - try: - if client is None: - client = create_client(config) - value = client.get(_ALLOWLIST_CACHE_KEY) - if value is not None: - logger.info("allowlist cache hit key=%s", _ALLOWLIST_CACHE_KEY) - return value - except redis_lib.exceptions.RedisError as exc: - error_message = str(exc) - logger.warning( - "redis cache read failed, falling back to source: %s", - error_message, - ) - return None - - -def set_cached_yaml( - config: RelayConfig, yaml_str: str, client: redis_lib.Redis | None = None -) -> None: - """Cache allowlist YAML string with TTL. Logs and ignores Redis errors.""" - try: - if client is None: - client = create_client(config) - client.setex(_ALLOWLIST_CACHE_KEY, config.allowlist_ttl_seconds, yaml_str) - logger.info( - "allowlist cached %d bytes key=%s", len(yaml_str), _ALLOWLIST_CACHE_KEY - ) - except redis_lib.exceptions.RedisError as exc: - error_message = str(exc) - logger.warning( - "redis cache write failed, continuing without cache: %s", - error_message, - ) diff --git a/aws/lambda/cross_repo_ci_relay/requirements.txt b/aws/lambda/cross_repo_ci_relay/requirements.txt index f1e84d1e02..172030f82e 100644 --- a/aws/lambda/cross_repo_ci_relay/requirements.txt +++ b/aws/lambda/cross_repo_ci_relay/requirements.txt @@ -2,3 +2,4 @@ PyYAML==6.0.3 PyGithub==2.9.0 redis==7.4.0 boto3==1.42.78 +PyJWT==2.10.1 diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py b/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py index 261ef58154..63c77971b6 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py @@ -1,6 +1,6 @@ import unittest -from allowlist import AllowlistLevel, AllowlistMap +from utils.allowlist import AllowlistLevel, AllowlistMap class TestAllowlistMap(unittest.TestCase): diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler.py new file mode 100644 index 0000000000..e2ff10bb82 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler.py @@ -0,0 +1,221 @@ +import time +import unittest +from unittest.mock import MagicMock, patch + +from callback.callback_handler import handle +from utils.allowlist import AllowlistLevel +from utils.misc import CallbackState, DISPATCH_CHECK_RUN_ID, HTTPException +from utils.redis_helper import CallbackStateRecord + + +def _cfg(): + cfg = MagicMock() + cfg.hud_api_url = "http://hud/api/oot-ci-events" + cfg.hud_bot_key = "bot-key-123" + cfg.redis_endpoint = "host:6379" + cfg.redis_login = "" + cfg.oot_status_ttl = 259200 + cfg.rate_limit_per_min = 20 + return cfg + + +def _body(status="completed", job_name="default", check_run_id="12345", run_id="99999"): + return { + "event_type": "pull_request", + "delivery_id": "del-123", + "payload": { + "pull_request": {"number": 42, "head": {"sha": "abc123"}}, + "repository": {"full_name": "pytorch/pytorch"}, + }, + "workflow": { + "status": status, + "conclusion": "success" if status == "completed" else None, + "name": "CI", + "url": "http://ci.example.com/run/1", + "job_name": job_name, + "check_run_id": check_run_id, + "run_id": run_id, + }, + } + + +class TestCallbackHandler(unittest.TestCase): + def setUp(self): + self.patcher_allowlist = patch("callback.callback_handler.load_allowlist") + self.mock_load_allowlist = self.patcher_allowlist.start() + mock_map = MagicMock() + mock_map.get_repos_at_or_above_level.return_value = (["org/repo"], []) + mock_map.get_repo_level.return_value = AllowlistLevel.L2 + self.mock_load_allowlist.return_value = mock_map + + self.patcher_redis = patch("callback.callback_handler.redis_helper") + self.mock_redis = self.patcher_redis.start() + self.mock_redis.create_client.return_value = MagicMock() + + # Setup default: dispatch exists, job state is None (in_progress not yet reported) + def default_get_state(cfg, delivery_id, repo, check_run_id_arg, client=None): + if check_run_id_arg == DISPATCH_CHECK_RUN_ID: + return CallbackStateRecord( + CallbackState.DISPATCHED, time.time() - 30, "dispatch-job", 11111 + ) + elif check_run_id_arg == "12345": # default check_run_id in _body() + return CallbackStateRecord( + CallbackState.IN_PROGRESS, time.time() - 20, "default", 99999 + ) + return None + + self.mock_redis.get_callback_state.side_effect = default_get_state + + self.patcher_rate_limit = patch("callback.callback_handler.check_rate_limit") + self.mock_check_rate_limit = self.patcher_rate_limit.start() + self.mock_check_rate_limit.return_value = True + + self.patcher_hud = patch("callback.callback_handler.forward_to_hud") + self.mock_hud = self.patcher_hud.start() + + def tearDown(self): + self.patcher_allowlist.stop() + self.patcher_redis.stop() + self.patcher_rate_limit.stop() + self.patcher_hud.stop() + + # --- allowlist uses the OIDC-verified repo, not the body --- + + def test_verified_repo_not_in_l2_returns_ignored(self): + mock_map = MagicMock() + mock_map.get_repo_level.return_value = None + self.mock_load_allowlist.return_value = mock_map + + result = handle(_cfg(), _body(), verified_repo="org/repo") + + self.assertEqual(result, {"ok": True, "status": "ignored"}) + self.assertFalse(self.mock_redis.create_client.called) + self.assertFalse(self.mock_hud.called) + + # --- body is forwarded to HUD verbatim; verified_repo is a sibling --- + + def test_body_is_passed_to_hud_unchanged(self): + body = _body() + handle(_cfg(), body, verified_repo="org/repo") + + # forward_to_hud(config, trusted, untrusted) + _, trusted_arg, untrusted_arg = self.mock_hud.call_args[0] + self.assertIs(untrusted_arg["callback_payload"], body) + self.assertEqual(trusted_arg.get("verified_repo"), "org/repo") + # verified_repo is a sibling of ci_metrics, not nested inside it. + self.assertNotIn("verified_repo", trusted_arg.get("ci_metrics", {})) + + # --- timing metrics calculation --- + + def test_queue_time_calculated_from_state_records(self): + """queue_time is the dispatch-to-in_progress delta.""" + dispatch_record = CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + job_record = CallbackStateRecord( + CallbackState.IN_PROGRESS, 1030.0, "default", 99999 + ) + self.mock_redis.get_callback_state.side_effect = [ + dispatch_record, # dispatch lookup + None, # job state: not yet set + job_record, # re-read after set_callback_state + ] + + handle(_cfg(), _body(status="in_progress"), verified_repo="org/repo") + + _, trusted_arg, _ = self.mock_hud.call_args[0] + metrics = trusted_arg["ci_metrics"] + self.assertEqual(metrics["queue_time"], 30.0) + self.assertIsNone(metrics["execution_time"]) + + def test_execution_time_calculated_from_state_records(self): + """execution_time is the in_progress-to-completed delta.""" + dispatch_record = CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + job_record = CallbackStateRecord( + CallbackState.IN_PROGRESS, 1030.0, "default", 99999 + ) + completed_record = CallbackStateRecord( + CallbackState.COMPLETED, 1060.0, "default", 99999 + ) + self.mock_redis.get_callback_state.side_effect = [ + dispatch_record, # dispatch lookup + job_record, # job state: in_progress + completed_record, # re-read after set_callback_state + ] + + handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + + _, trusted_arg, _ = self.mock_hud.call_args[0] + self.assertEqual(trusted_arg["ci_metrics"]["execution_time"], 30.0) + + # --- HUD 4xx propagates (5xx is swallowed inside forward_to_hud) --- + + def test_hud_4xx_propagates(self): + self.mock_hud.side_effect = HTTPException(422, "bad schema") + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), _body(), verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 422) + + # --- required field validation --- + + def test_missing_delivery_id_returns_400(self): + body = _body() + del body["delivery_id"] + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), body, verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 400) + + def test_missing_workflow_status_returns_400(self): + body = _body() + del body["workflow"]["status"] + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), body, verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 400) + + # --- rate limiting --- + + def test_rate_limit_exceeded_returns_429(self): + self.mock_check_rate_limit.return_value = False + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), _body(), verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 429) + self.assertFalse(self.mock_hud.called) + + # --- Redis outage during callback is tolerated --- + + def test_redis_error_fetching_dispatch_record_rejected(self): + """Redis error on dispatch lookup returns None, causing 400 rejection.""" + self.mock_redis.get_callback_state.side_effect = [ + None, # dispatch lookup returns None (get_callback_state catches RedisError) + ] + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 400) + + def test_redis_error_fetching_job_record_proceeds(self): + """Redis error on job record lookup returns None; callback proceeds.""" + dispatch_record = CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + # Three calls: dispatch lookup, job record lookup, re-read after set. + self.mock_redis.get_callback_state.side_effect = [ + dispatch_record, + None, # job record lookup returns None (get_callback_state catches RedisError) + None, # updated_job_record re-read → early return with empty metrics + ] + + handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + + _, trusted_arg, _ = self.mock_hud.call_args[0] + self.assertIsNone(trusted_arg["ci_metrics"]["execution_time"]) + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler_lambda.py b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler_lambda.py new file mode 100644 index 0000000000..9bb9f592c2 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler_lambda.py @@ -0,0 +1,116 @@ +import base64 +import json +import unittest +from unittest.mock import patch + +from callback.lambda_function import lambda_handler +from utils.misc import HTTPException + + +def _event( + *, + method="POST", + path="/github/callback", + body=None, + headers=None, + base64_encoded=False, +): + if body is None: + body = json.dumps({"status": "completed", "head_sha": "abc123"}) + if base64_encoded: + body = base64.b64encode(body.encode()).decode() + if headers is None: + hdrs = {"authorization": "Bearer oidc.tok"} + else: + hdrs = dict(headers) + return { + "requestContext": {"http": {"method": method, "path": path}}, + "body": body, + "isBase64Encoded": base64_encoded, + "headers": hdrs, + } + + +class TestCallbackLambdaHandler(unittest.TestCase): + def setUp(self): + import utils.config + + utils.config._cached_config = None + + def test_route_validation(self): + response = lambda_handler(_event(path="/other"), {}) + self.assertEqual(response["statusCode"], 404) + response = lambda_handler(_event(method="GET"), {}) + self.assertEqual(response["statusCode"], 405) + + @patch("callback.lambda_function.get_config") + def test_invalid_json_body_returns_400(self, mock_get_config): + response = lambda_handler(_event(body="not-json"), {}) + self.assertEqual(response["statusCode"], 400) + + @patch("callback.lambda_function.get_config") + def test_missing_authorization_header_returns_401(self, mock_get_config): + response = lambda_handler(_event(headers={}), {}) + self.assertEqual(response["statusCode"], 401) + self.assertIn("Missing", json.loads(response["body"])["detail"]) + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + def test_oidc_failure_returns_401(self, mock_oidc, mock_get_config): + mock_oidc.side_effect = HTTPException(401, "Invalid authorization token") + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 401) + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + @patch("callback.lambda_function.callback_handler.handle") + def test_happy_path_forwards_body_and_verified_repo( + self, mock_handle, mock_oidc, mock_get_config + ): + mock_oidc.return_value = {"repository": "org/repo"} + mock_handle.return_value = {"ok": True, "status": "completed"} + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 200) + self.assertEqual( + json.loads(response["body"]), {"ok": True, "status": "completed"} + ) + # Body passed through verbatim, verified_repo comes from OIDC claims. + args = mock_handle.call_args[0] + self.assertEqual(args[1], {"status": "completed", "head_sha": "abc123"}) + self.assertEqual(args[2], "org/repo") + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + @patch("callback.lambda_function.callback_handler.handle") + def test_hud_error_from_handler_is_forwarded( + self, mock_handle, mock_oidc, mock_get_config + ): + # HUD's HTTP status propagates out of Relay (transparent proxy). + mock_oidc.return_value = {"repository": "org/repo"} + mock_handle.side_effect = HTTPException(503, "HUD unreachable") + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 503) + self.assertEqual(json.loads(response["body"])["detail"], "HUD unreachable") + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + @patch("callback.lambda_function.callback_handler.handle") + def test_unhandled_exception_returns_500( + self, mock_handle, mock_oidc, mock_get_config + ): + mock_oidc.return_value = {"repository": "org/repo"} + mock_handle.side_effect = Exception("boom") + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 500) + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_config.py b/aws/lambda/cross_repo_ci_relay/tests/test_config.py index 05ed48f311..a34c820983 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_config.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_config.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch -from config import RelayConfig, RelaySecrets +from utils.config import RelayConfig, RelaySecrets _ENV = { @@ -27,7 +27,7 @@ def test_missing_vars_raises(self): with self.assertRaises(RuntimeError): RelayConfig.from_env() - @patch("config.RelaySecrets.from_aws") + @patch("utils.config.RelaySecrets.from_aws") @patch.dict( "os.environ", { @@ -43,6 +43,7 @@ def test_secrets_manager_fallback(self, mock_aws): github_app_secret="s", github_app_private_key="k", redis_login="secret-pass", + hud_bot_key="hud-key", ) cfg = RelayConfig.from_env() self.assertEqual(cfg.github_app_secret, "s") diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py index 49bfb6d836..9991f6b04e 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import call, MagicMock, patch -from event_handler import handle +from webhook.event_handler import handle def _cfg(): @@ -9,6 +9,7 @@ def _cfg(): cfg.github_app_id = "12345" cfg.github_app_private_key = "fake-key" cfg.max_dispatch_workers = 4 + cfg.github_app_secret = "test-secret" return cfg @@ -32,10 +33,11 @@ def test_ignored_action(self): {"ignored": True}, ) - @patch("event_handler.gh_helper.create_repository_dispatch") - @patch("event_handler.gh_helper.get_repo_access_token", return_value="tok") - @patch("event_handler.load_allowlist") - def test_dispatch_success(self, mock_load, _tok, mock_dispatch): + @patch("webhook.event_handler.redis_helper.set_callback_state") + @patch("webhook.event_handler.gh_helper.create_repository_dispatch") + @patch("webhook.event_handler.gh_helper.get_repo_access_token", return_value="tok") + @patch("webhook.event_handler.load_allowlist") + def test_dispatch_success(self, mock_load, _tok, mock_dispatch, _mock_set_state): mock_load.return_value = MagicMock( get_repos_at_or_above_level=MagicMock(return_value=(["org/a"], [])) ) @@ -43,14 +45,15 @@ def test_dispatch_success(self, mock_load, _tok, mock_dispatch): self.assertTrue(result["ok"]) mock_dispatch.assert_called_once() - @patch("event_handler.gh_helper.create_repository_dispatch") + @patch("webhook.event_handler.redis_helper.set_callback_state") + @patch("webhook.event_handler.gh_helper.create_repository_dispatch") @patch( - "event_handler.gh_helper.get_repo_access_token", + "webhook.event_handler.gh_helper.get_repo_access_token", side_effect=["tok-a", "tok-b"], ) - @patch("event_handler.load_allowlist") + @patch("webhook.event_handler.load_allowlist") def test_dispatch_mints_token_per_downstream_repo( - self, mock_load, mock_get_repo_access_token, mock_dispatch + self, mock_load, mock_get_repo_access_token, mock_dispatch, _mock_set_state ): mock_load.return_value = MagicMock( get_repos_at_or_above_level=MagicMock( diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_lambda_function.py b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler_lambda.py similarity index 83% rename from aws/lambda/cross_repo_ci_relay/tests/test_lambda_function.py rename to aws/lambda/cross_repo_ci_relay/tests/test_event_handler_lambda.py index 517697a49b..f2ecd5a142 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler_lambda.py @@ -4,9 +4,8 @@ import unittest from unittest.mock import MagicMock, patch -import lambda_function -from lambda_function import lambda_handler -from utils import HTTPException +from utils.misc import HTTPException +from webhook.lambda_function import lambda_handler SECRET = "test-key" @@ -54,41 +53,43 @@ def _event(*, method="POST", path="/github/webhook", body=None, headers=None): class TestLambdaHandler(unittest.TestCase): def setUp(self): - lambda_function._cached_config = None + import utils.config + + utils.config._cached_config = None def test_route_error_404_and_405(self): self.assertEqual(lambda_handler(_event(path="/other"), None)["statusCode"], 404) self.assertEqual(lambda_handler(_event(method="GET"), None)["statusCode"], 405) - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_bad_signature_401(self, mock_env): mock_env.return_value = _cfg() ev = _event(headers={"x-hub-signature-256": "sha256=bad"}) self.assertEqual(lambda_handler(ev, None)["statusCode"], 401) - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_success_delegates_to_handler(self, mock_env): mock_env.return_value = _cfg() mock_handle = MagicMock(return_value={"ok": True}) - with patch("lambda_function.event_handler.handle", mock_handle): + with patch("webhook.lambda_function.event_handler.handle", mock_handle): resp = lambda_handler(_event(), None) self.assertEqual(resp["statusCode"], 200) mock_handle.assert_called_once() - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_http_exception_forwarded(self, mock_env): mock_env.return_value = _cfg() with patch( - "lambda_function.event_handler.handle", + "webhook.lambda_function.event_handler.handle", MagicMock(side_effect=HTTPException(502, "err")), ): self.assertEqual(lambda_handler(_event(), None)["statusCode"], 502) - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_config_cached_across_warm_invocations(self, mock_env): mock_env.return_value = _cfg() with patch( - "lambda_function.event_handler.handle", + "webhook.lambda_function.event_handler.handle", MagicMock(return_value={"ok": True}), ): first = lambda_handler(_event(), None) diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_hud.py b/aws/lambda/cross_repo_ci_relay/tests/test_hud.py new file mode 100644 index 0000000000..e3c19bb6f4 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_hud.py @@ -0,0 +1,109 @@ +import json +import unittest +import urllib.error +from unittest.mock import MagicMock, patch + +from utils.hud import forward_to_hud +from utils.misc import HTTPException + + +def _cfg(url="http://hud/api/oot-ci-events", key="bot-key", max_retries=3): + cfg = MagicMock() + cfg.hud_api_url = url + cfg.hud_bot_key = key + cfg.hud_max_retries = max_retries + return cfg + + +class TestForwardToHud(unittest.TestCase): + @patch("utils.hud.urllib.request.urlopen") + def test_empty_url_skips_request(self, mock_urlopen): + forward_to_hud( + _cfg(url=""), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {"delivery_id": "d"}}, + ) + mock_urlopen.assert_not_called() + + @patch("utils.hud.urllib.request.urlopen") + def test_hud_payload_has_three_top_level_fields(self, mock_urlopen): + resp = MagicMock() + resp.status = 200 + mock_urlopen.return_value.__enter__.return_value = resp + + report = {"delivery_id": "d", "workflow": {"status": "completed"}} + metrics = {"queue_time": 1.0, "execution_time": 2.0} + forward_to_hud( + _cfg(), + {"ci_metrics": metrics, "verified_repo": "org/repo"}, + {"callback_payload": report}, + ) + + sent = json.loads(mock_urlopen.call_args[0][0].data) + self.assertEqual(sent["trusted"]["ci_metrics"], metrics) + self.assertEqual(sent["trusted"]["verified_repo"], "org/repo") + self.assertEqual(sent["untrusted"]["callback_payload"], report) + + @patch("utils.hud.urllib.request.urlopen") + def test_4xx_propagates_with_huds_status(self, mock_urlopen): + # 4xx means the caller sent bad data — propagate so the downstream + # workflow author sees a red step. + mock_urlopen.side_effect = urllib.error.HTTPError( + "http://hud", 422, "bad schema", {}, None + ) + + with self.assertRaises(HTTPException) as ctx: + forward_to_hud( + _cfg(), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {}}, + ) + self.assertEqual(ctx.exception.status_code, 422) + + @patch("utils.hud.time.sleep") + @patch("utils.hud.urllib.request.urlopen") + def test_retries_exhausted(self, mock_urlopen, mock_sleep): + """5xx and URLError are retried; after exhaustion an exception is raised.""" + cases = [ + ( + urllib.error.HTTPError("http://hud", 500, "err", {}, None), + HTTPException, + 500, + ), + (urllib.error.URLError("unreachable"), urllib.error.URLError, None), + ] + for exc, expected_type, expected_code in cases: + with self.subTest(exc=exc): + mock_urlopen.reset_mock() + mock_sleep.reset_mock() + mock_urlopen.side_effect = exc + + with self.assertRaises(expected_type) as ctx: + forward_to_hud( + _cfg(max_retries=2), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {}}, + ) + if expected_code is not None: + self.assertEqual(ctx.exception.status_code, expected_code) + self.assertEqual(mock_urlopen.call_count, 3) # 1 + 2 retries + self.assertEqual(mock_sleep.call_count, 2) + + @patch("utils.hud.time.sleep") + @patch("utils.hud.urllib.request.urlopen") + def test_5xx_succeeds_on_retry(self, mock_urlopen, mock_sleep): + mock_urlopen.side_effect = [ + urllib.error.HTTPError("http://hud", 500, "unavailable", {}, None), + MagicMock(status=200), + ] + forward_to_hud( + _cfg(), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {}}, + ) + self.assertEqual(mock_urlopen.call_count, 2) + self.assertEqual(mock_sleep.call_count, 1) + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py new file mode 100644 index 0000000000..a9887f67e8 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py @@ -0,0 +1,75 @@ +import unittest +from unittest.mock import MagicMock, patch + +from utils.jwt_helper import verify_oidc_token +from utils.misc import HTTPException + + +class TestVerifyDownstreamIdentity(unittest.TestCase): + def setUp(self): + self.patcher_jwks = patch( + "utils.jwt_helper._jwks_client.get_signing_key_from_jwt" + ) + self.mock_signing_key = self.patcher_jwks.start() + self.mock_signing_key.return_value = MagicMock(key="fake-key") + + self.patcher_decode = patch("utils.jwt_helper.jwt.decode") + self.mock_decode = self.patcher_decode.start() + + def tearDown(self): + self.patcher_jwks.stop() + self.patcher_decode.stop() + + def test_valid_token_returns_claims(self): + expected = { + "repository": "org/repo", + "sub": "repo:org/repo:ref:refs/heads/main", + } + self.mock_decode.return_value = expected + + claims = verify_oidc_token("some.oidc.token") + + self.assertEqual(claims, expected) + self.assertEqual(claims, expected) + self.mock_decode.assert_called_once() + self.assertEqual( + self.mock_decode.call_args.kwargs["audience"], + "pytorch-cross-repo-ci-relay", + ) + self.assertEqual( + self.mock_decode.call_args.kwargs["issuer"], + "https://token.actions.githubusercontent.com", + ) + + def test_wrong_audience_raises_401(self): + import jwt as _jwt + + self.mock_decode.side_effect = _jwt.InvalidAudienceError("bad aud") + with self.assertRaises(HTTPException) as ctx: + verify_oidc_token("token.with.wrong.aud") + self.assertEqual(ctx.exception.status_code, 401) + + def test_bearer_prefix_stripped_before_jwks_lookup(self): + self.mock_decode.return_value = {"repository": "org/repo"} + + verify_oidc_token("Bearer some.oidc.token") + + self.mock_signing_key.assert_called_once_with("some.oidc.token") + + def test_empty_token_raises_401_without_jwks_lookup(self): + with self.assertRaises(HTTPException) as ctx: + verify_oidc_token("") + self.assertEqual(ctx.exception.status_code, 401) + self.assertIn("Missing", ctx.exception.detail) + self.mock_signing_key.assert_not_called() + + def test_jwks_lookup_failure_raises_401(self): + self.mock_signing_key.side_effect = Exception("JWKS fetch failed") + + with self.assertRaises(HTTPException) as ctx: + verify_oidc_token("bad.token") + self.assertEqual(ctx.exception.status_code, 401) + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py index f20a1d1f7e..91c652290c 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py @@ -1,13 +1,19 @@ +import json import unittest +import unittest.mock from unittest.mock import MagicMock import redis as redis_lib -import redis_helper -from redis_helper import ( +from utils import redis_helper +from utils.misc import CallbackState, DISPATCH_CHECK_RUN_ID +from utils.redis_helper import ( _ALLOWLIST_CACHE_KEY, + CallbackStateRecord, create_client, get_cached_yaml, + get_callback_state, set_cached_yaml, + set_callback_state, ) @@ -16,6 +22,8 @@ def _cfg(): cfg.redis_endpoint = "host:6379" cfg.redis_login = "" cfg.allowlist_ttl_seconds = 600 + cfg.oot_status_ttl = 3600 + cfg.rate_limit_per_min = 20 return cfg @@ -56,5 +64,225 @@ def test_create_client_reuses_cached_client_for_same_url(self): mock_from_url.assert_called_once() +class TestCallbackStateMachine(unittest.TestCase): + def setUp(self): + redis_helper._cached_client = None + redis_helper._cached_client_url = None + + def test_set_dispatch_state_with_timestamp(self): + """Webhook sets DISPATCHED state.""" + client = MagicMock() + set_callback_state( + _cfg(), + "del-123", + "org/repo", + DISPATCH_CHECK_RUN_ID, + CallbackState.DISPATCHED, + 1000.0, + client=client, + ) + client.setex.assert_called_once() + + def test_get_callback_state_parses_json(self): + """get_callback_state returns CallbackStateRecord from JSON.""" + client = MagicMock() + client.get.return_value = json.dumps( + { + "state": "IN_PROGRESS", + "timestamp": 1010.5, + "job_name": "test-job", + "run_id": "12345", + } + ) + cfg = _cfg() + + record = get_callback_state(cfg, "del-123", "org/repo", "test-job", client) + + self.assertIsNotNone(record) + self.assertEqual(record.state, CallbackState.IN_PROGRESS) + self.assertEqual(record.timestamp, 1010.5) + self.assertEqual(record.job_name, "test-job") + self.assertEqual(record.run_id, "12345") + + def test_get_callback_state_returns_none_on_missing_key_and_on_redis_error(self): + """get_callback_state returns None on missing key and on Redis error.""" + client = MagicMock() + cfg = _cfg() + + client.get.return_value = None + self.assertIsNone( + get_callback_state(cfg, "del-123", "org/repo", "test-job", client) + ) + + client.get.side_effect = redis_lib.exceptions.RedisError("boom") + self.assertIsNone( + get_callback_state(cfg, "del-123", "org/repo", "test-job", client) + ) + + def test_invalid_state_transitions_rejected(self): + """Duplicate or invalid state transitions are all rejected.""" + cases = [ + # (check_run_id, new_state, existing_state_value_or_None) + (DISPATCH_CHECK_RUN_ID, CallbackState.DISPATCHED, "DISPATCHED"), + ("check-run", CallbackState.IN_PROGRESS, "IN_PROGRESS"), + ("check-run", CallbackState.COMPLETED, None), # None → COMPLETED + ("check-run", CallbackState.COMPLETED, "COMPLETED"), + ] + for check_run_id, state, existing in cases: + with self.subTest(state=state, existing=existing): + client = MagicMock() + client.get.return_value = ( + json.dumps( + { + "state": existing, + "timestamp": 1000.0, + "job_name": "job", + "run_id": "111", + } + ) + if existing + else None + ) + with self.assertRaises(AssertionError): + set_callback_state( + _cfg(), + "del-123", + "org/repo", + check_run_id, + state, + 1100.0, + client=client, + ) + client.setex.assert_not_called() + + def test_set_completed_from_in_progress_accepts(self): + """IN_PROGRESS → COMPLETED transition is accepted.""" + client = MagicMock() + client.get.return_value = json.dumps( + { + "state": "IN_PROGRESS", + "timestamp": 1010.0, + "job_name": "test-job", + "run_id": "12345", + } + ) + set_callback_state( + _cfg(), + "del-123", + "org/repo", + "test-job", + CallbackState.COMPLETED, + 1020.0, + job_name="test-job", + run_id="12345", + client=client, + ) + + def test_set_in_progress_accepts_first_callback(self): + """None → IN_PROGRESS is accepted when dispatch record exists.""" + + def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): + if check_run_id_arg == DISPATCH_CHECK_RUN_ID: + return CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + return None + + client = MagicMock() + with unittest.mock.patch( + "utils.redis_helper.get_callback_state", side_effect=get_side_effect + ): + set_callback_state( + _cfg(), + "del-123", + "org/repo", + "check-run-456", + CallbackState.IN_PROGRESS, + 1010.0, + job_name="test-job", + run_id="99999", + client=client, + ) + + def test_set_non_dispatched_state_with_reserved_check_run_id_rejected(self): + """Using the reserved DISPATCH_CHECK_RUN_ID for non-DISPATCHED state is rejected.""" + client = MagicMock() + cfg = _cfg() + + for state in (CallbackState.IN_PROGRESS, CallbackState.COMPLETED): + with self.subTest(state=state): + client.reset_mock() + with self.assertRaises(AssertionError): + set_callback_state( + cfg, + "del-123", + "org/repo", + DISPATCH_CHECK_RUN_ID, + state, + 1010.0, + client=client, + ) + client.setex.assert_not_called() + + def test_set_callback_state_redis_exception_raises(self): + """Redis write failure is re-raised as RedisError.""" + + def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): + if check_run_id_arg == DISPATCH_CHECK_RUN_ID: + return CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", "11111" + ) + return None + + cfg = _cfg() + client = MagicMock() + client.setex.side_effect = redis_lib.exceptions.RedisError("write failed") + + with unittest.mock.patch( + "utils.redis_helper.get_callback_state", side_effect=get_side_effect + ), self.assertRaises(redis_lib.exceptions.RedisError): + set_callback_state( + cfg, + "del-123", + "org/repo", + "check-run-456", + CallbackState.IN_PROGRESS, + 1010.0, + job_name="test-job", + run_id="99999", + client=client, + ) + + +class TestRateLimit(unittest.TestCase): + def setUp(self): + redis_helper._cached_client = None + redis_helper._cached_client_url = None + + def test_check_rate_limit_allowed(self): + from utils.redis_helper import check_rate_limit + + client = MagicMock() + client.zcard.return_value = 10 + self.assertTrue(check_rate_limit(_cfg(), "org/repo", client=client)) + + def test_check_rate_limit_exceeded(self): + from utils.redis_helper import check_rate_limit + + client = MagicMock() + client.zcard.return_value = 25 + self.assertFalse(check_rate_limit(_cfg(), "org/repo", client=client)) + + def test_check_rate_limit_redis_error_raises_500(self): + from utils.misc import HTTPException + from utils.redis_helper import check_rate_limit + + client = MagicMock() + client.zadd.side_effect = redis_lib.exceptions.RedisError("boom") + with self.assertRaises(HTTPException) as ctx: + check_rate_limit(_cfg(), "org/repo", client=client) + self.assertEqual(ctx.exception.status_code, 500) + + if __name__ == "__main__": unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/utils.py b/aws/lambda/cross_repo_ci_relay/utils.py deleted file mode 100644 index 78e613cfe2..0000000000 --- a/aws/lambda/cross_repo_ci_relay/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import TypedDict - - -class HTTPException(Exception): - def __init__(self, status_code: int, detail): - self.status_code = status_code - self.detail = detail - - -class EventDispatchPayload(TypedDict): - event_type: str - delivery_id: str - payload: dict diff --git a/aws/lambda/cross_repo_ci_relay/allowlist.py b/aws/lambda/cross_repo_ci_relay/utils/allowlist.py similarity index 93% rename from aws/lambda/cross_repo_ci_relay/allowlist.py rename to aws/lambda/cross_repo_ci_relay/utils/allowlist.py index 8be55ad639..cd972efff8 100644 --- a/aws/lambda/cross_repo_ci_relay/allowlist.py +++ b/aws/lambda/cross_repo_ci_relay/utils/allowlist.py @@ -13,10 +13,10 @@ from enum import Enum from urllib.parse import urlparse -import gh_helper -import redis_helper import yaml -from config import RelayConfig + +from . import gh_helper, redis_helper +from .config import RelayConfig logger = logging.getLogger(__name__) @@ -96,8 +96,13 @@ def get_repos_at_or_above_level( oncalls.extend(lvl_oncalls) return repos, oncalls - def __bool__(self) -> bool: - return any(bool(entries) for entries in self._levels.values()) + def get_repo_level(self, repo: str) -> AllowlistLevel | None: + """Return the level for a specific repo, or None if repo is not in allowlist.""" + for level, entries in self._levels.items(): + for entry in entries: + if entry.repo == repo: + return level + return None @classmethod def _parse(cls, raw: dict) -> "AllowlistMap": diff --git a/aws/lambda/cross_repo_ci_relay/config.py b/aws/lambda/cross_repo_ci_relay/utils/config.py similarity index 69% rename from aws/lambda/cross_repo_ci_relay/config.py rename to aws/lambda/cross_repo_ci_relay/utils/config.py index bfc5252a0f..707c6ed8a4 100644 --- a/aws/lambda/cross_repo_ci_relay/config.py +++ b/aws/lambda/cross_repo_ci_relay/utils/config.py @@ -11,13 +11,14 @@ class RelaySecrets: github_app_secret: str = "" github_app_private_key: str = "" redis_login: str = "" + hud_bot_key: str = "" @classmethod def from_aws(cls, secret_store_arn: str, client=None) -> "RelaySecrets": region = os.environ.get("AWS_REGION", "us-east-1") try: if client is None: - client = boto3.session.Session().client( + client = boto3.client( "secretsmanager", region_name=region, config=Config(retries={"max_attempts": 3, "mode": "standard"}), @@ -36,6 +37,7 @@ def from_aws(cls, secret_store_arn: str, client=None) -> "RelaySecrets": github_app_secret=secret.get("GITHUB_APP_SECRET", ""), github_app_private_key=secret.get("GITHUB_APP_PRIVATE_KEY", ""), redis_login=secret.get("REDIS_LOGIN", ""), + hud_bot_key=secret.get("HUD_BOT_KEY", ""), ) @@ -57,6 +59,11 @@ class RelayConfig: redis_login: str allowlist_ttl_seconds: int max_dispatch_workers: int + hud_api_url: str + hud_bot_key: str + oot_status_ttl: int + hud_max_retries: int + rate_limit_per_min: int @classmethod def from_env(cls) -> "RelayConfig": @@ -65,6 +72,7 @@ def from_env(cls) -> "RelayConfig": github_app_private_key = os.getenv("GITHUB_APP_PRIVATE_KEY", "") redis_login = os.getenv("REDIS_LOGIN", "") secret_store_arn = os.getenv("SECRET_STORE_ARN", "") + hud_bot_key = os.getenv("HUD_BOT_KEY", "") if not github_app_secret or not github_app_private_key or not redis_login: if not secret_store_arn: @@ -87,12 +95,14 @@ def from_env(cls) -> "RelayConfig": github_app_private_key or secrets.github_app_private_key ) redis_login = redis_login or secrets.redis_login + hud_bot_key = hud_bot_key or secrets.hud_bot_key missing_in_secret = [ v for v, val in [ ("GITHUB_APP_SECRET", github_app_secret), ("GITHUB_APP_PRIVATE_KEY", github_app_private_key), ("REDIS_LOGIN", redis_login), + ("HUD_BOT_KEY", hud_bot_key), ] if not val ] @@ -112,6 +122,36 @@ def from_env(cls) -> "RelayConfig": # avoidable rate-limit risk in production. allowlist_ttl_seconds = max(allowlist_ttl_seconds, 900) + # GitHub can keep a workflow job in `pending` state for up to 3 days before + # auto-cancelling it, so OOT-status records must live at least that long. + # Default to 3 days (259200 s). + try: + oot_status_ttl = int(os.getenv("OOT_STATUS_TTL", "259200")) + except ValueError: + raise RuntimeError("OOT_STATUS_TTL must be a valid integer") + + # Maximum number of retry attempts for HUD API calls. + # Default to 3 retries with exponential backoff. + try: + hud_max_retries = int(os.getenv("HUD_MAX_RETRIES", "3")) + if hud_max_retries < 0: + raise ValueError("must be non-negative") + except ValueError: + raise RuntimeError("HUD_MAX_RETRIES must be a non-negative integer") + + try: + rate_limit_per_min = int(os.getenv("RATE_LIMIT_PER_MIN", "20")) + if rate_limit_per_min <= 0: + raise ValueError("must be positive") + except ValueError: + raise RuntimeError("RATE_LIMIT_PER_MIN must be a positive integer") + + hud_api_url = os.getenv("HUD_API_URL", "") + if hud_api_url and not hud_api_url.startswith("https://"): + raise RuntimeError( + "HUD_API_URL must use https:// to protect the bot key in transit" + ) + return cls( github_app_id=_require("GITHUB_APP_ID"), github_app_secret=github_app_secret, @@ -122,4 +162,19 @@ def from_env(cls) -> "RelayConfig": redis_login=redis_login, allowlist_ttl_seconds=allowlist_ttl_seconds, max_dispatch_workers=int(os.getenv("MAX_DISPATCH_WORKERS", "32")), + hud_api_url=hud_api_url, + hud_bot_key=hud_bot_key, + oot_status_ttl=oot_status_ttl, + hud_max_retries=hud_max_retries, + rate_limit_per_min=rate_limit_per_min, ) + + +_cached_config: RelayConfig | None = None + + +def get_config() -> RelayConfig: + global _cached_config + if _cached_config is None: + _cached_config = RelayConfig.from_env() + return _cached_config diff --git a/aws/lambda/cross_repo_ci_relay/gh_helper.py b/aws/lambda/cross_repo_ci_relay/utils/gh_helper.py similarity index 98% rename from aws/lambda/cross_repo_ci_relay/gh_helper.py rename to aws/lambda/cross_repo_ci_relay/utils/gh_helper.py index 6d0d9c7fa6..4e279d1994 100644 --- a/aws/lambda/cross_repo_ci_relay/gh_helper.py +++ b/aws/lambda/cross_repo_ci_relay/utils/gh_helper.py @@ -4,7 +4,8 @@ import github from github import GithubIntegration -from utils import EventDispatchPayload + +from .misc import EventDispatchPayload logger = logging.getLogger(__name__) diff --git a/aws/lambda/cross_repo_ci_relay/utils/hud.py b/aws/lambda/cross_repo_ci_relay/utils/hud.py new file mode 100644 index 0000000000..40ba67ec8c --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/utils/hud.py @@ -0,0 +1,108 @@ +import json +import logging +import time +import urllib.error +import urllib.request + +from .config import RelayConfig +from .misc import HTTPException + + +logger = logging.getLogger(__name__) + + +def forward_to_hud(config: RelayConfig, trusted: dict, untrusted: dict) -> None: + """POST a callback record to HUD. + + This function splits inputs into two explicit namespaces: + + - ``trusted``: a dict supplied by this relay and therefore considered + authoritative. + - ``untrusted``: a dict forwarded from the downstream workflow and treated as + untrusted user-supplied data. + + Retry Behavior: + On server errors (HTTP 5xx) or network failures (URLError), the request + is retried up to ``config.hud_max_retries`` times with exponential backoff + (1s, 2s, 4s, ...). Client errors (HTTP 4xx) are not retried and raise + HTTPException immediately. + """ + if not config.hud_api_url: + # No HUD configured (e.g. local dev before HUD endpoint exists) — + # log and no-op rather than 500. Remove this branch once HUD is + # mandatory in every environment. + logger.info("HUD_API_URL not configured, skipping HUD write") + return + + hud_payload = json.dumps( + { + "trusted": trusted, + "untrusted": untrusted, + } + ).encode("utf-8") + + req = urllib.request.Request( + config.hud_api_url, + data=hud_payload, + headers={ + "Content-Type": "application/json", + "X-OOT-Relay-Token": config.hud_bot_key, + }, + method="POST", + ) + + last_exception = None + total_attempts = config.hud_max_retries + 1 + for attempt in range(total_attempts): + try: + with urllib.request.urlopen(req, timeout=10) as resp: + logger.info("HUD forward succeeded status=%d", resp.status) + return + except urllib.error.HTTPError as exc: + if 400 <= exc.code < 500: + detail = f"HUD rejected callback: HTTP {exc.code}: {exc.reason}" + logger.error("HUD forward failed (client error): %s", detail) + raise HTTPException(exc.code, detail) from exc + last_exception = exc + logger.debug( + "HUD forward failed (server error, attempt %d/%d): HTTP %d %s", + attempt + 1, + total_attempts, + exc.code, + exc.reason, + ) + except urllib.error.URLError as exc: + last_exception = exc + logger.debug( + "HUD forward failed (unreachable, attempt %d/%d): %s", + attempt + 1, + total_attempts, + exc.reason, + ) + + # If we have more retries remaining, wait with exponential backoff + if attempt < config.hud_max_retries: + time.sleep(2**attempt) + + # All retries exhausted + if isinstance(last_exception, urllib.error.HTTPError): + logger.error( + "HUD forward failed after %d attempts: HTTP %d %s", + total_attempts, + last_exception.code, + last_exception.reason, + ) + raise HTTPException( + 500, + "An internal failure occurred. " + "Your update was not saved, but the CI run is still valid. " + "You can attempt progressive retries after " + f"{60 // config.rate_limit_per_min} seconds or ignore this failure.", + ) from last_exception + else: + logger.error( + "HUD forward failed after %d attempts: %s", + config.hud_max_retries + 1, + last_exception.reason, + ) + raise last_exception diff --git a/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py b/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py new file mode 100644 index 0000000000..8149248016 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py @@ -0,0 +1,42 @@ +"""JWT utilities for the cross-repo CI relay.""" + +from __future__ import annotations + +import logging + +import jwt +from utils.misc import HTTPException + + +logger = logging.getLogger(__name__) + +_jwks_client = jwt.PyJWKClient( + "https://token.actions.githubusercontent.com/.well-known/jwks" +) + + +def verify_oidc_token(token: str) -> dict: + """Decode a GitHub Actions OIDC token and return the claims. + + Rejects an empty/missing token up front so every call site gets a uniform + 401 without repeating the check. Raises ``HTTPException(401)`` on any + verification failure. + """ + if not token: + raise HTTPException(401, "Missing authorization token") + + try: + if token.lower().startswith("bearer "): + token = token[7:].strip() + + signing_key = _jwks_client.get_signing_key_from_jwt(token) + return jwt.decode( + token, + signing_key.key, + algorithms=["RS256"], + issuer="https://token.actions.githubusercontent.com", + audience="pytorch-cross-repo-ci-relay", + ) + except Exception as exc: + logger.exception("OIDC token verification error") + raise HTTPException(401, "Invalid authorization token") from exc diff --git a/aws/lambda/cross_repo_ci_relay/utils/misc.py b/aws/lambda/cross_repo_ci_relay/utils/misc.py new file mode 100644 index 0000000000..aa1f3f2431 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/utils/misc.py @@ -0,0 +1,72 @@ +"""Small shared types, exceptions, and Lambda helpers. + +Grouped together because each piece is too small to justify its own module +and there's no shared abstraction to organise them under. +""" + +from __future__ import annotations + +import base64 +from dataclasses import dataclass +from enum import Enum +from typing import TypedDict + + +JSON_HEADERS = {"content-type": "application/json"} + + +class HTTPException(Exception): + def __init__(self, status_code: int, detail): + self.status_code = status_code + self.detail = detail + + +class EventDispatchPayload(TypedDict): + event_type: str + delivery_id: str + payload: dict + + +# Specific check_run_id used to identify callbacks +# from dispatches that didn't specify a real check_run_id. +# Here we set as a string for better readability, +# but it could be any unique identifier. +DISPATCH_CHECK_RUN_ID = "dispatched" + + +class CallbackState(str, Enum): + """Unified state machine for callback lifecycle (both webhook and callback sides). + + - ``DISPATCHED``: webhook side, when repository_dispatch is sent (job_name=DISPATCH_JOB_NAME). + - ``IN_PROGRESS``: callback side, when downstream workflow reports started (per-job). + - ``COMPLETED``: callback side, when downstream workflow reports finished (per-job). + """ + + DISPATCHED = "DISPATCHED" + IN_PROGRESS = "IN_PROGRESS" + COMPLETED = "COMPLETED" + + +@dataclass +class CallbackStateRecord: + """Record containing state, timestamp, and job metadata for HUD grouping.""" + + state: CallbackState + timestamp: float + job_name: str + run_id: str + + +def parse_lambda_event(event: dict) -> tuple[str, str, bytes, dict]: + """Extract method, path, body bytes, and lower-cased headers from a Lambda event dict.""" + http = event.get("requestContext", {}).get("http", {}) + method = http.get("method", "").upper() + path = http.get("path", "") + raw_body = event.get("body") or "" + body_bytes = ( + base64.b64decode(raw_body) + if event.get("isBase64Encoded") + else raw_body.encode("utf-8") + ) + headers = {k.lower(): v for k, v in (event.get("headers") or {}).items()} + return method, path, body_bytes, headers diff --git a/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py new file mode 100644 index 0000000000..74ac6677f8 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py @@ -0,0 +1,331 @@ +import json +import logging +import os +import time +from typing import cast +from urllib.parse import quote + +import redis as redis_lib +from redis.exceptions import RedisError + +from .config import RelayConfig +from .misc import ( + CallbackState, + CallbackStateRecord, + DISPATCH_CHECK_RUN_ID, + HTTPException, +) + + +logger = logging.getLogger(__name__) + +_ALLOWLIST_CACHE_KEY = "oot:allowlist_yaml" +_STATE_PREFIX = "oot:state:" +_RATE_LIMIT_PREFIX = "oot:rate:" +_cached_client: redis_lib.Redis | None = None +_cached_client_url: str | None = None + + +def _parse_endpoint(endpoint: str) -> tuple[str, int]: + host = endpoint.strip() + + if not host: + raise RuntimeError("REDIS_ENDPOINT must not be empty") + + if host.startswith(("redis://", "rediss://")): + raise RuntimeError( + "REDIS_ENDPOINT must be a hostname or host:port, not a redis URL" + ) + + if "/" in host: + raise RuntimeError("REDIS_ENDPOINT must be a hostname or host:port") + + port = 6379 + if ":" in host: + maybe_host, maybe_port = host.rsplit(":", 1) + if not maybe_port.isdigit(): + raise RuntimeError(f"REDIS_ENDPOINT has invalid port: {maybe_port!r}") + host, port = maybe_host, int(maybe_port) + + return host, port + + +def _parse_login(login: str) -> tuple[str, str]: + login = login.strip() + if not login: + return "", "" + + if ":" in login: + username, password = login.split(":", 1) + return username, password + + # ElastiCache auth_token config provides only a password, not a username. + return "", login + + +def _build_url(config: RelayConfig) -> str: + host, port = _parse_endpoint(config.redis_endpoint or "") + auth = "" + username, password = _parse_login(config.redis_login or "") + if password and username: + auth = f"{quote(username, safe='')}:{quote(password, safe='')}@" + elif password: + auth = f":{quote(password, safe='')}@" + # Use TLS (rediss://) on AWS Lambda where ElastiCache requires it; + # fall back to plain redis:// for local development. + # AWS_LAMBDA_FUNCTION_NAME is automatically set by the Lambda runtime. + scheme = "rediss" if os.environ.get("AWS_LAMBDA_FUNCTION_NAME") else "redis" + return f"{scheme}://{auth}{host}:{port}/0" + + +def create_client(config: RelayConfig) -> redis_lib.Redis: + """Create or reuse a Redis client for the given config.""" + global _cached_client + global _cached_client_url + try: + redis_url = _build_url(config) + if _cached_client is not None and _cached_client_url == redis_url: + return _cached_client + + client = redis_lib.from_url( + redis_url, + decode_responses=True, + socket_connect_timeout=2, + socket_timeout=2, + ) + except Exception: + logger.exception("Error creating Redis client") + raise RuntimeError("Failed to create Redis client") + _cached_client = client + _cached_client_url = redis_url + return client + + +def get_cached_yaml( + config: RelayConfig, client: redis_lib.Redis | None = None +) -> str | None: + """Return cached allowlist YAML string, or None on cache miss or Redis error.""" + try: + if client is None: + client = create_client(config) + value = client.get(_ALLOWLIST_CACHE_KEY) + if value is not None: + logger.info("allowlist cache hit key=%s", _ALLOWLIST_CACHE_KEY) + return cast(str | None, value) + except RedisError: + logger.exception( + "redis cache read failed, falling back to source", + ) + return None + + +def set_cached_yaml( + config: RelayConfig, yaml_str: str, client: redis_lib.Redis | None = None +) -> None: + """Cache allowlist YAML string with TTL. Logs and ignores Redis errors.""" + try: + if client is None: + client = create_client(config) + client.setex(_ALLOWLIST_CACHE_KEY, config.allowlist_ttl_seconds, yaml_str) + logger.info( + "allowlist cached %d bytes key=%s", len(yaml_str), _ALLOWLIST_CACHE_KEY + ) + except RedisError: + logger.exception("redis cache write failed, continuing without cache") + + +def check_rate_limit( + config: RelayConfig, + repo: str, + client: redis_lib.Redis | None = None, +) -> bool: + """Check if repo is within rate limit using sliding window. + + Returns True if allowed, False if rate exceeded. + Raises HTTPException(500) on Redis failure (fail-closed). + """ + try: + if client is None: + client = create_client(config) + + key = f"{_RATE_LIMIT_PREFIX}{repo}" + now = time.time() + window_start = now - 60 + + member = f"{now}:{repo}" + client.zadd(key, {member: now}) + client.zremrangebyscore(key, "-inf", window_start) + count = client.zcard(key) + client.expire(key, 120) + + if count > config.rate_limit_per_min: + logger.warning( + "rate limit exceeded key=%s count=%d limit=%d", + key, + count, + config.rate_limit_per_min, + ) + return False + return True + except RedisError as e: + logger.exception("redis rate limit check failed") + raise HTTPException(500, f"rate limit check failed: {e}") from e + + +def _state_key(delivery_id: str, downstream_repo: str, check_run_id: str) -> str: + """Redis key for callback state machine. + + Keyed by delivery_id + repo + check_run_id to support per-execution state tracking. + check_run_id is unique per job execution, enabling replay attack detection. + """ + return f"{_STATE_PREFIX}{delivery_id}:{downstream_repo}:{check_run_id}" + + +def get_callback_state( + config: RelayConfig, + delivery_id: str, + downstream_repo: str, + check_run_id: str, + client: redis_lib.Redis | None = None, +) -> CallbackStateRecord | None: + """Get callback state record from Redis, or None if no record exists. + + Returns a record containing state, timestamp, and optional job metadata. + """ + try: + if client is None: + client = create_client(config) + key = _state_key(delivery_id, downstream_repo, check_run_id) + value = client.get(key) + if value is None: + return None + data = json.loads(value) + return CallbackStateRecord( + state=CallbackState(data["state"]), + timestamp=data["timestamp"], + job_name=data["job_name"], + run_id=data["run_id"], + ) + except RedisError: + logger.exception("redis temporary outage or unreachable") + except Exception: + logger.exception("redis get_callback_state failed") + return None + + +def set_callback_state( + config: RelayConfig, + delivery_id: str, + downstream_repo: str, + check_run_id: str, + state: CallbackState, + timestamp: float, + job_name: str | None = None, + run_id: int | None = None, + client: redis_lib.Redis | None = None, +) -> None: + """Set callback state with timestamp in Redis. + + State transition validation: + + DISPATCHED state (webhook-side): + - None -> DISPATCHED: accept (initial dispatch) + - DISPATCHED -> DISPATCHED: reject (duplicate webhook) + + IN_PROGRESS state (callback-side): + - None -> IN_PROGRESS: accept (first callback for this check_run_id) + - IN_PROGRESS -> IN_PROGRESS: reject (replay attack for same check_run_id) + + COMPLETED state (callback-side): + - None -> COMPLETED: reject (no prior in_progress) + - IN_PROGRESS -> COMPLETED: accept (normal completion) + - COMPLETED -> COMPLETED: reject (duplicate) + """ + error_msg = "" + try: + if client is None: + client = create_client(config) + + if check_run_id == DISPATCH_CHECK_RUN_ID and state != CallbackState.DISPATCHED: + error_msg = ( + "check_run_id '%s' is preserved for DISPATCHED state only, rejecting invalid state=%s" + % ( + DISPATCH_CHECK_RUN_ID, + state.value, + ) + ) + + key = _state_key(delivery_id, downstream_repo, check_run_id) + + current_record = get_callback_state( + config, delivery_id, downstream_repo, check_run_id, client + ) + + if state == CallbackState.DISPATCHED: + if current_record is not None: + error_msg = "rejecting duplicate DISPATCHED key=%s" % key + elif state == CallbackState.IN_PROGRESS: + if current_record is not None: + error_msg = ( + "rejecting replay attack IN_PROGRESS for same " + "check_run_id=%s, downstream_repo=%s, job_name=%s, run_id=%s" + % ( + check_run_id, + downstream_repo, + job_name, + run_id, + ) + ) + + elif state == CallbackState.COMPLETED: + if current_record is None: + error_msg = ( + "rejecting COMPLETED without prior IN_PROGRESS " + "key=%s, downstream_repo=%s, job_name=%s, run_id=%s" + % ( + key, + downstream_repo, + job_name, + run_id, + ) + ) + elif current_record.state == CallbackState.COMPLETED: + error_msg = "rejecting duplicate COMPLETED key=%s" % key + elif current_record.state != CallbackState.IN_PROGRESS: + error_msg = ( + "rejecting abnormal state transition %s -> COMPLETED " + "key=%s, downstream_repo=%s, job_name=%s, run_id=%s" + % ( + current_record.state.value, + key, + downstream_repo, + job_name, + run_id, + ) + ) + + if error_msg: + logger.warning(error_msg) + raise AssertionError(error_msg) + + data: dict = { + "state": state.value, + "timestamp": timestamp, + "job_name": job_name, + "run_id": run_id, + } + client.setex(key, config.oot_status_ttl, json.dumps(data)) + logger.info( + "callback state set key=%s state=%s timestamp=%s job_name=%s run_id=%s", + key, + state.value, + timestamp, + job_name, + run_id, + ) + except RedisError: + logger.exception("set_callback_state: redis is temporary outage or unreachable") + raise + except Exception: + logger.exception("redis set_callback_state failed") + raise diff --git a/aws/lambda/cross_repo_ci_relay/webhook/Makefile b/aws/lambda/cross_repo_ci_relay/webhook/Makefile new file mode 100644 index 0000000000..adf53ae9b8 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/webhook/Makefile @@ -0,0 +1,19 @@ +UTILS_SRC := $(wildcard ../utils/*.py) +PIP_FLAGS := --platform manylinux2014_x86_64 --only-binary=:all: --implementation cp --python-version 3.13 +AWS_REGION := us-east-1 +FUNCTION_NAME := cross_repo_ci_webhook + +deployment.zip: clean + mkdir -p ./deployment/webhook + mkdir -p ./deployment/utils + cp *.py ./deployment/webhook/ && cp $(UTILS_SRC) ./deployment/utils/ + pip3 install --target ./deployment -r ../requirements.txt $(PIP_FLAGS) + cd deployment && zip -r ../deployment.zip . + +deploy: deployment.zip + aws lambda update-function-code --region $(AWS_REGION) --function-name $(FUNCTION_NAME) --zip-file fileb://deployment.zip + +clean: + rm -rf deployment deployment.zip + +.PHONY: deploy clean diff --git a/aws/lambda/cross_repo_ci_relay/event_handler.py b/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py similarity index 84% rename from aws/lambda/cross_repo_ci_relay/event_handler.py rename to aws/lambda/cross_repo_ci_relay/webhook/event_handler.py index 274dffd593..c763e95230 100644 --- a/aws/lambda/cross_repo_ci_relay/event_handler.py +++ b/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py @@ -2,12 +2,18 @@ import json import logging +import time from concurrent.futures import as_completed, ThreadPoolExecutor -import gh_helper -from allowlist import AllowlistLevel, load_allowlist -from config import RelayConfig -from utils import EventDispatchPayload, HTTPException +from utils import gh_helper, redis_helper +from utils.allowlist import AllowlistLevel, load_allowlist +from utils.config import RelayConfig +from utils.misc import ( + CallbackState, + DISPATCH_CHECK_RUN_ID, + EventDispatchPayload, + HTTPException, +) logger = logging.getLogger(__name__) @@ -33,6 +39,18 @@ def _dispatch_one( client_payload=client_payload, ) + # Set dispatch state with timestamp to prove valid webhook occurred. + # Keyed by delivery_id + repo + DISPATCH_JOB_NAME="*" (repo-level, not job-specific). + # Timestamp is used for queue_time calculation (dispatch → in_progress). + redis_helper.set_callback_state( + config, + client_payload["delivery_id"], + downstream_repo, + DISPATCH_CHECK_RUN_ID, + CallbackState.DISPATCHED, + time.time(), + ) + def _dispatch_to_allowlist( *, @@ -77,13 +95,12 @@ def _dispatch_to_allowlist( ) dispatched.append({"repo": downstream_repo}) except Exception as e: - error_message = str(e) - logger.error( - "dispatch failed event_type=%s repo=%s error=%s", + logger.exception( + "dispatch failed event_type=%s repo=%s", event_type, downstream_repo, - error_message, ) + error_message = str(e) failed.append( { "repo": downstream_repo, diff --git a/aws/lambda/cross_repo_ci_relay/lambda_function.py b/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py similarity index 69% rename from aws/lambda/cross_repo_ci_relay/lambda_function.py rename to aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py index e951a5926e..af67e9c979 100644 --- a/aws/lambda/cross_repo_ci_relay/lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py @@ -1,21 +1,19 @@ from __future__ import annotations -import base64 import hashlib import hmac import json import logging -import event_handler -from config import RelayConfig -from utils import HTTPException +from utils.config import get_config +from utils.misc import HTTPException, JSON_HEADERS, parse_lambda_event + +from . import event_handler logging.getLogger().setLevel(logging.INFO) logger = logging.getLogger(__name__) -_cached_config: RelayConfig | None = None - def _verify_signature(secret: str, body: bytes, signature: str) -> None: if not signature: @@ -27,29 +25,11 @@ def _verify_signature(secret: str, body: bytes, signature: str) -> None: raise HTTPException(status_code=401, detail="Bad signature") -_JSON_HEADERS = {"content-type": "application/json"} _SUPPORTED_EVENTS = frozenset({"pull_request", "push"}) -def _get_config() -> RelayConfig: - global _cached_config - if _cached_config is None: - _cached_config = RelayConfig.from_env() - return _cached_config - - def lambda_handler(event, context): - http = event.get("requestContext", {}).get("http", {}) - method = http.get("method", "").upper() - path = http.get("path", "") - - raw_body = event.get("body") or "" - body_bytes = ( - base64.b64decode(raw_body) - if event.get("isBase64Encoded") - else raw_body.encode("utf-8") - ) - headers = {k.lower(): v for k, v in (event.get("headers") or {}).items()} + method, path, body_bytes, headers = parse_lambda_event(event) delivery = headers.get("x-github-delivery", "") logger.info("request method=%s path=%s delivery=%s", method, path, delivery) @@ -58,12 +38,12 @@ def lambda_handler(event, context): if path == "/github/webhook": return { "statusCode": 405, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Method not allowed"}), } return { "statusCode": 404, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Not found"}), } @@ -72,12 +52,12 @@ def lambda_handler(event, context): logger.info("event=%s ignored before verification", event_type) return { "statusCode": 200, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"ignored": True}), } try: - config = _get_config() + config = get_config() _verify_signature( config.github_app_secret, body_bytes, headers.get("x-hub-signature-256", "") @@ -90,7 +70,7 @@ def lambda_handler(event, context): logger.info("repo=%s not upstream, ignored", repo) return { "statusCode": 200, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"ignored": True}), } @@ -100,24 +80,28 @@ def lambda_handler(event, context): event_type=event_type, delivery_id=delivery, ) - return {"statusCode": 200, "headers": _JSON_HEADERS, "body": json.dumps(result)} + return {"statusCode": 200, "headers": JSON_HEADERS, "body": json.dumps(result)} except json.JSONDecodeError: + logger.warning("invalid JSON body in webhook request") return { "statusCode": 400, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Invalid JSON body"}), } except HTTPException as exc: + logger.warning( + "http exception status=%d detail=%s", exc.status_code, exc.detail + ) return { "statusCode": exc.status_code, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": exc.detail}), } except Exception: logger.exception("unhandled error") return { "statusCode": 500, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Internal server error"}), }