From b98e0428c744262a015cb766a7282bf872a429bc Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Tue, 14 Apr 2026 22:42:33 +0800 Subject: [PATCH 1/7] CRCR L2 implementation Co-authored-by: can-gaa-hou Co-authored-by: fffrog --- .../cross-repo-ci-relay-callback/action.yml | 120 ++++++++++ .../workflows/_lambda-do-release-runners.yml | 3 +- aws/lambda/cross_repo_ci_relay/Makefile | 33 +-- aws/lambda/cross_repo_ci_relay/README.md | 213 +++++++++++++++++- .../cross_repo_ci_relay/callback/Makefile | 18 ++ .../callback/lambda_function.py | 71 ++++++ .../callback/result_handler.py | 99 ++++++++ .../cross_repo_ci_relay/local_server.py | 38 +++- .../cross_repo_ci_relay/requirements.txt | 1 + .../tests/test_allowlist.py | 2 +- .../cross_repo_ci_relay/tests/test_config.py | 5 +- .../tests/test_event_handler.py | 15 +- ...nction.py => test_event_handler_lambda.py} | 23 +- .../cross_repo_ci_relay/tests/test_hud.py | 71 ++++++ .../tests/test_jwt_helper.py | 61 +++++ .../tests/test_redis_helper.py | 52 ++++- .../tests/test_result_handler.py | 170 ++++++++++++++ .../tests/test_result_handler_lambda.py | 116 ++++++++++ aws/lambda/cross_repo_ci_relay/utils.py | 13 -- .../{ => utils}/allowlist.py | 9 +- .../cross_repo_ci_relay/{ => utils}/config.py | 31 ++- .../{ => utils}/gh_helper.py | 3 +- aws/lambda/cross_repo_ci_relay/utils/hud.py | 86 +++++++ .../cross_repo_ci_relay/utils/jwt_helper.py | 43 ++++ aws/lambda/cross_repo_ci_relay/utils/misc.py | 53 +++++ .../{ => utils}/redis_helper.py | 101 +++++++-- .../cross_repo_ci_relay/webhook/Makefile | 18 ++ .../{ => webhook}/event_handler.py | 27 ++- .../{ => webhook}/lambda_function.py | 51 ++--- 29 files changed, 1403 insertions(+), 143 deletions(-) create mode 100644 .github/actions/cross-repo-ci-relay-callback/action.yml create mode 100644 aws/lambda/cross_repo_ci_relay/callback/Makefile create mode 100644 aws/lambda/cross_repo_ci_relay/callback/lambda_function.py create mode 100644 aws/lambda/cross_repo_ci_relay/callback/result_handler.py 
rename aws/lambda/cross_repo_ci_relay/tests/{test_lambda_function.py => test_event_handler_lambda.py} (83%) create mode 100644 aws/lambda/cross_repo_ci_relay/tests/test_hud.py create mode 100644 aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py create mode 100644 aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py create mode 100644 aws/lambda/cross_repo_ci_relay/tests/test_result_handler_lambda.py delete mode 100644 aws/lambda/cross_repo_ci_relay/utils.py rename aws/lambda/cross_repo_ci_relay/{ => utils}/allowlist.py (96%) rename aws/lambda/cross_repo_ci_relay/{ => utils}/config.py (81%) rename aws/lambda/cross_repo_ci_relay/{ => utils}/gh_helper.py (98%) create mode 100644 aws/lambda/cross_repo_ci_relay/utils/hud.py create mode 100644 aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py create mode 100644 aws/lambda/cross_repo_ci_relay/utils/misc.py rename aws/lambda/cross_repo_ci_relay/{ => utils}/redis_helper.py (54%) create mode 100644 aws/lambda/cross_repo_ci_relay/webhook/Makefile rename aws/lambda/cross_repo_ci_relay/{ => webhook}/event_handler.py (86%) rename aws/lambda/cross_repo_ci_relay/{ => webhook}/lambda_function.py (69%) diff --git a/.github/actions/cross-repo-ci-relay-callback/action.yml b/.github/actions/cross-repo-ci-relay-callback/action.yml new file mode 100644 index 0000000000..10f86070a3 --- /dev/null +++ b/.github/actions/cross-repo-ci-relay-callback/action.yml @@ -0,0 +1,120 @@ +name: Cross-Repo CI Relay Callback + +description: > + Report the status of a downstream CI workflow back to the Cross-Repo CI + Relay server. The job must have `id-token: write` permission so that a + GitHub OIDC token can be minted and used to authenticate the callback. + + This action is meant to run in a workflow triggered by a `repository_dispatch` + event from the relay. 
It reads the dispatch payload (`github.event.client_payload`) + and the ambient `github` context directly, so workflow authors only need to + supply the relay URL, the status/conclusion, and optional structured test + results. + +inputs: + status: + description: > + Workflow status to report. Must be either "in_progress" or "completed". + required: true + conclusion: + description: > + Conclusion of the workflow run. Required (and must be "success" or + "failure") when status is "completed". Ignored when status is + "in_progress". + required: false + default: '' + test-results: + description: > + Optional JSON string with structured test results produced by the + downstream job. Forwarded verbatim to the relay under `workflow.test_results`. + required: false + default: '' + callback-url: + description: > + Full URL of the result callback endpoint, including the `/github/result` path. + required: false + default: https://zciswjsrynb6ksccc2nb3mckpa0ldzou.lambda-url.us-east-1.on.aws/github/result + +runs: + using: composite + steps: + - name: Mint OIDC token + id: oidc + uses: actions/github-script@v7 + with: + script: | + const audience = `https://github.com/${context.repo.owner}/${context.repo.repo}`; + const token = await core.getIDToken(audience); + core.setSecret(token); + core.setOutput('token', token); + + - name: Send callback to relay server + shell: bash + env: + STATUS: ${{ inputs.status }} + CONCLUSION: ${{ inputs.conclusion }} + WORKFLOW_NAME: ${{ github.workflow }} + WORKFLOW_URL: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }} + TEST_RESULTS: ${{ inputs.test-results }} + CLIENT_PAYLOAD: ${{ toJson(github.event.client_payload) }} + OIDC_TOKEN: ${{ steps.oidc.outputs.token }} + CALLBACK_URL: ${{ inputs.callback-url }} + run: | + set -euo pipefail + + PAYLOAD=$(python3 - <<'PYEOF' + import json, os, sys + + status = os.environ["STATUS"] + if status not in ("in_progress", "completed"): + sys.exit(f"::error::status must be 'in_progress' or 'completed', got {status!r}") 
+ + conclusion = os.environ.get("CONCLUSION", "").strip() or None + if status == "completed" and conclusion not in ("success", "failure"): + sys.exit("::error::conclusion must be 'success' or 'failure' when status is 'completed'") + if status == "in_progress": + conclusion = None + + try: + client_payload = json.loads(os.environ["CLIENT_PAYLOAD"]) + except json.JSONDecodeError as exc: + sys.exit(f"::error::github.event.client_payload is not valid JSON: {exc}") + + # Relay's original dispatch payload (event_type, delivery_id, payload) is + # forwarded verbatim. Downstream-reported fields live in a sibling + # `workflow` dict so the two sources stay clearly separated on the wire. + workflow: dict = { + "status": status, + "conclusion": conclusion, + "name": os.environ["WORKFLOW_NAME"], + "url": os.environ["WORKFLOW_URL"], + } + + test_results = os.environ.get("TEST_RESULTS", "").strip() + if test_results: + try: + workflow["test_results"] = json.loads(test_results) + except json.JSONDecodeError as exc: + sys.exit(f"::error::test-results input is not valid JSON: {exc}") + + client_payload["workflow"] = workflow + print(json.dumps(client_payload)) + PYEOF + ) + + HTTP_CODE=$( + curl --silent --show-error --output /tmp/relay_response.json \ + --write-out "%{http_code}" \ + -X POST \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer ${OIDC_TOKEN}" \ + --data "${PAYLOAD}" \ + "${CALLBACK_URL%/}" + ) + + if [[ "${HTTP_CODE}" -lt 200 || "${HTTP_CODE}" -ge 300 ]]; then + echo "::error::Callback server returned HTTP ${HTTP_CODE}." 
+ exit 1 + fi + + echo "Relay server response HTTP: ${HTTP_CODE}" diff --git a/.github/workflows/_lambda-do-release-runners.yml b/.github/workflows/_lambda-do-release-runners.yml index 07c33be885..0a253308d5 100644 --- a/.github/workflows/_lambda-do-release-runners.yml +++ b/.github/workflows/_lambda-do-release-runners.yml @@ -93,7 +93,8 @@ jobs: { dir-name: 'keep-going-call-log-classifier', zip-name: 'keep-going-call-log-classifier' }, { dir-name: 'buildkite-webhook-handler', zip-name: 'buildkite-webhook-handler' }, { dir-name: 'benchmark_regression_summary_report', zip-name: 'benchmark-regression-summary-report' }, - { dir-name: 'cross_repo_ci_relay', zip-name: 'cross-repo-ci-webhook' }, + { dir-name: 'cross_repo_ci_relay/callback', zip-name: 'cross-repo-ci-result' }, + { dir-name: 'cross_repo_ci_relay/webhook', zip-name: 'cross-repo-ci-webhook' }, ] name: Upload Release for ${{ matrix.dir-name }} lambda runs-on: ubuntu-latest diff --git a/aws/lambda/cross_repo_ci_relay/Makefile b/aws/lambda/cross_repo_ci_relay/Makefile index c1f07bbc56..39fd320825 100644 --- a/aws/lambda/cross_repo_ci_relay/Makefile +++ b/aws/lambda/cross_repo_ci_relay/Makefile @@ -1,21 +1,24 @@ -SHARED := config.py utils.py redis_helper.py allowlist.py gh_helper.py event_handler.py -PIP_FLAGS := --platform manylinux2014_x86_64 --only-binary=:all: --implementation cp --python-version 3.13 -AWS_REGION := us-east-1 -FUNCTION_NAME := cross_repo_ci_webhook +AWS_REGION ?= us-east-1 +CALLBACK_FUNCTION_NAME ?= cross_repo_ci_callback +WEBHOOK_FUNCTION_NAME ?= cross_repo_ci_webhook -deployment.zip: clean - mkdir -p ./deployment - cp $(SHARED) lambda_function.py ./deployment/ - pip3 install --target ./deployment -r requirements.txt $(PIP_FLAGS) - cd deployment && zip -r ../deployment.zip . 
- -deploy: deployment.zip - aws lambda update-function-code --region $(AWS_REGION) --function-name $(FUNCTION_NAME) --zip-file fileb://deployment.zip +.PHONY: test deploy deploy-callback deploy-webhook clean test: python3 -m pytest tests -v -clean: - rm -rf deployment deployment.zip +# Deploy both; keep going on failure so a broken half doesn't block the other. +deploy: + @rc=0; \ + for t in deploy-callback deploy-webhook; do $(MAKE) $$t || rc=$$?; done; \ + exit $$rc + +deploy-callback: + $(MAKE) -C callback deploy AWS_REGION=$(AWS_REGION) FUNCTION_NAME=$(CALLBACK_FUNCTION_NAME) -.PHONY: prepare deploy test clean +deploy-webhook: + $(MAKE) -C webhook deploy AWS_REGION=$(AWS_REGION) FUNCTION_NAME=$(WEBHOOK_FUNCTION_NAME) + +clean: + $(MAKE) -C callback clean + $(MAKE) -C webhook clean diff --git a/aws/lambda/cross_repo_ci_relay/README.md b/aws/lambda/cross_repo_ci_relay/README.md index 2df4335816..331c911d67 100644 --- a/aws/lambda/cross_repo_ci_relay/README.md +++ b/aws/lambda/cross_repo_ci_relay/README.md @@ -1,6 +1,6 @@ # Cross Repo CI Relay -An AWS Lambda function that relays GitHub webhook events from the upstream repository to downstream repositories. +An AWS Lambda function that relays GitHub webhook events from the upstream repository to downstream repositories, and forwards downstream CI results to HUD. For more information, please refer to this [RFC](https://github.com/pytorch/pytorch/issues/175022). @@ -31,33 +31,218 @@ Each entry is either a plain `owner/repo` string or a `owner/repo: oncall1, onca The allowlist is cached in Redis under the key `crcr:allowlist_yaml` with a TTL controlled by `ALLOWLIST_TTL_SECONDS`. On a Redis error the function falls back to fetching directly from GitHub. +## Reporting Results from Downstream CI + +L2+ downstream repositories can report the status of their CI workflows back to the relay server using the [`cross-repo-ci-relay-callback`](../../../.github/actions/cross-repo-ci-relay-callback/action.yml) composite action. 
+ +### Security and the Relay/HUD boundary + +The callback endpoint is a **near-transparent proxy to HUD** with a single +security responsibility: identifying the calling repo. Everything else — +schema validation, persistence, dedup — is HUD's job. + +- **Identity (Relay's job)**: the `Authorization: Bearer ${OIDC_TOKEN}` header + is verified against GitHub's JWKS. The OIDC `repository` claim is the only + trusted identity for the caller and is used for the L2+ allowlist check. + Relay forwards this trusted value to HUD as a top-level `authenticated_repo` + field; HUD should prefer it over anything self-reported in `body`. +- **Schema / business validation (HUD's job)**: the callback body is passed + through to HUD verbatim as a top-level `body` field. Relay does **not** + validate fields inside `body.workflow` or `body.payload` — HUD owns the + schema since it owns persistence. Relay only enforces that `delivery_id` + and `workflow.status` are present (contract violation → `400`). +- **Timing (Relay's job)**: Relay records dispatch/in-progress timestamps in + Redis keyed on `delivery_id` and the authenticated repo, then computes `queue_time` and + `execution_time` and surfaces them to HUD as `ci_metrics`. Each callback + phase reports exactly one metric: + - `in_progress` → `queue_time` (dispatch → in_progress) + - `completed` → `execution_time` (in_progress → completed) + +The HUD request looks like: + +```json +{ + "body": { + "event_type": "pull_request", + "delivery_id": "...delivery id from the dispatch payload...", + "payload": { ...original upstream webhook payload, verbatim... }, + "workflow": { + "status": "completed", + "conclusion": "success", + "name": "CI", + "url": "https://github.com/org/repo/actions/runs/123", + "test_results": { ... } + } + }, + "ci_metrics": { "queue_time": null, "execution_time": 1.23 }, + "authenticated_repo": "org/repo" +} +``` + +Trust boundaries inside `body`: + +- `body.payload` is the upstream webhook payload, transparently forwarded — + trusted at dispatch time, but not re-verified on the callback. 
+- `body.workflow` is **self-reported by the downstream CI** and is not + authenticated. Only `authenticated_repo` carries a cryptographic identity. + +### Error propagation back to the downstream workflow + +| HUD response | Relay behaviour | Effect on downstream CI step | +|---|---|---| +| `2xx` | record delivered | green | +| `4xx` (schema reject) | propagate same status | **red** — author must fix payload | +| `5xx` / network error | log + swallow | green — HUD outage is not the caller's fault | + +The asymmetry is deliberate: `4xx` means the caller sent something wrong and +should see it; `5xx`/network means HUD or its infrastructure is broken and +should not be surfaced as a red CI step across every L2 repo. Operators are +expected to alert on the `HUD forward failed` CloudWatch log pattern. + +#### Known limitations of this model + +A compromised or malicious maintainer of an allowlisted repo can: + +1. Fabricate `workflow.status` / `workflow.conclusion` values for upstream PRs + their repo was never dispatched for — HUD will receive the row, but + `authenticated_repo` always identifies the true caller. +2. Replay an older dispatched payload against the callback endpoint — there + is no dispatch-side nonce. +3. Tamper with any field inside `body` — HUD must trust `authenticated_repo`, + not the body. + +All three attacks are **scoped to the attacker's own OIDC-authenticated repo +identity** — OIDC guarantees they cannot impersonate another allowlisted +repo. Mitigation is operational: every HUD row carries `authenticated_repo`, +so misbehaviour is observable, and the offending repo can be removed from +`allowlist.yaml`. + +If stronger guarantees are required later, the typical next step is a signed +callback token minted by the webhook side plus a one-shot state machine in +Redis keyed on `delivery_id`. This was intentionally deferred to keep the +relay simple — see the PR description for the discussion. 
+ +### Prerequisites + +- The downstream repository must be listed at level **L2 or higher** in the allowlist. +- The **calling job** must declare `permissions: id-token: write` so that the action can mint a GitHub OIDC token for authentication. + +### Usage + +When triggered by a relay `repository_dispatch`, the action automatically +reads `github.event.client_payload` for `delivery_id` and the upstream +webhook payload, and reads `github.workflow` / the current run URL for the +workflow identity. Workflow authors only need to pass `status` (and +`conclusion` when `status=completed`). + +```yaml +on: + repository_dispatch: + types: [pull_request] + +jobs: + my-ci-job: + runs-on: ubuntu-latest + permissions: + id-token: write # required for OIDC token minting + contents: read + steps: + - name: Report in-progress to relay + uses: pytorch/test-infra/.github/actions/cross-repo-ci-relay-callback@main + with: + status: in_progress + + # ... your CI steps ... + + - name: Report final result to relay + if: always() + uses: pytorch/test-infra/.github/actions/cross-repo-ci-relay-callback@main + with: + status: completed + conclusion: ${{ job.status == 'success' && 'success' || 'failure' }} +``` + +### Inputs + +| Input | Required | Default | Description | +|---|---|---|---| +| `status` | **yes** | — | `in_progress` or `completed` | +| `conclusion` | no | `''` | `success` or `failure` (required when `status=completed`) | +| `test-results` | no | `''` | Optional JSON string forwarded under `body.workflow.test_results` | +| `callback-url` | no | see `action.yml` | Callback endpoint URL (overridable for local testing) | + ## Build, Deploy, and Test +### Deployment layout + +The build packages each Lambda as a zip that preserves the package hierarchy: + +``` +deployment/ +├── webhook/ +│ ├── lambda_function.py +│ └── event_handler.py +├── callback/ +│ ├── lambda_function.py +│ └── result_handler.py +└── utils/ + └── ... 
+``` + +This matches the layout used during local development and tests, so imports +behave identically in both environments. Configure the AWS Lambda handlers +as: + +- Webhook Lambda: `webhook.lambda_function.lambda_handler` +- Callback Lambda: `callback.lambda_function.lambda_handler` + ### Make Targets -Build the Lambda zip (output: deployment.zip) +Build the Webhook Lambda zip (output: `webhook/deployment.zip`): + +```bash +cd webhook +make deployment.zip +``` + +Build the Callback Lambda zip (output: `callback/deployment.zip`): + ```bash +cd callback make deployment.zip ``` -Deploy to AWS Lambda (requires AWS CLI v2 configured with permissions) +Deploy both zips to AWS Lambda (requires AWS CLI v2 with permissions): + ```bash -make deploy AWS_REGION=us-east-1 FUNCTION_NAME=crcr-prod-crcr-webhook +make deploy AWS_REGION=us-east-1 \ + WEBHOOK_FUNCTION_NAME=cross_repo_ci_webhook \ + CALLBACK_FUNCTION_NAME=cross_repo_ci_callback ``` -Run all unit tests under tests/ folder +Either side can be deployed independently: + +```bash +make deploy-webhook +make deploy-callback +``` + +Run all unit tests under `tests/`: + ```bash make test ``` -Clean build artifacts +Clean build artifacts: + ```bash make clean ``` ## Local Development -`local_server.py` wraps the Lambda handler in a FastAPI app so you can test the full cross-repo-ci-relay flow without deploying to AWS. +`local_server.py` wraps both Lambda handlers in a FastAPI app so you can test +the full cross-repo-ci-relay flow without deploying to AWS. 
### Prerequisites @@ -72,12 +257,20 @@ make clean redis:7-alpine \ redis-server --requirepass ``` -- [smee.io](https://smee.io) CLI to forward GitHub webhook events to localhost (paste this link to GitHub App webhook URL): +- [smee.io](https://smee.io) + + CLI to forward GitHub webhook events to localhost (paste this link to GitHub App webhook URL): ```bash npm install -g smee-client smee --url https://smee.io/ --path /github/webhook --port 8000 ``` + CLI to forward GitHub result callbacks to localhost (set this URL as `callback-url` in the downstream workflow): + ```bash + npm install -g smee-client + smee --url https://smee.io/ --path /github/result --port 8000 + ``` + #### Remote - GitHub App settings (refer to this [RFC](https://github.com/pytorch/pytorch/issues/175022)) @@ -110,7 +303,7 @@ make clean REDIS_LOGIN=default: ALLOWLIST_TTL_SECONDS=1200 ``` - **Note**: `ALLOWLIST_URL` is required for local development which should point to a GitHub URL that can be different from the real one. + **Note**: `ALLOWLIST_URL` is required for local development and should point to a GitHub URL (it can differ from the production one). 3. Start the server: ```bash @@ -118,3 +311,5 @@ make clean ``` 4. Point your GitHub App's webhook URL to the smee.io channel, then open or update a pull request in the upstream repo to trigger a full relay cycle. + +5. Check whether the workflow run status is reported back through `callback-url`. 
diff --git a/aws/lambda/cross_repo_ci_relay/callback/Makefile b/aws/lambda/cross_repo_ci_relay/callback/Makefile new file mode 100644 index 0000000000..bd2ec01e87 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/callback/Makefile @@ -0,0 +1,18 @@ +UTILS_SRC := $(wildcard ../utils/*.py) +PIP_FLAGS := --platform manylinux2014_x86_64 --only-binary=:all: --implementation cp --python-version 3.13 +AWS_REGION ?= us-east-1 +FUNCTION_NAME ?= cross_repo_ci_callback + +deployment.zip: clean + mkdir -p ./deployment/callback ./deployment/utils + cp *.py ./deployment/callback/ && cp $(UTILS_SRC) ./deployment/utils/ + pip3 install --target ./deployment -r ../requirements.txt $(PIP_FLAGS) + cd deployment && zip -r ../deployment.zip . + +deploy: deployment.zip + aws lambda update-function-code --region $(AWS_REGION) --function-name $(FUNCTION_NAME) --zip-file fileb://deployment.zip + +clean: + rm -rf deployment deployment.zip + +.PHONY: deploy clean diff --git a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py new file mode 100644 index 0000000000..8148773ff6 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import json +import logging + +from utils import jwt_helper +from utils.config import get_config +from utils.misc import HTTPException, JSON_HEADERS, parse_lambda_event + +from . 
import result_handler + +logging.getLogger().setLevel(logging.INFO) +logger = logging.getLogger(__name__) + + +def lambda_handler(event, context): + method, path, body_bytes, headers = parse_lambda_event(event) + + logger.info("request method=%s path=%s", method, path) + + if method != "POST" or path != "/github/result": + if path == "/github/result": + return { + "statusCode": 405, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Method not allowed"}), + } + return { + "statusCode": 404, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Not found"}), + } + + try: + config = get_config() + body = json.loads(body_bytes) if body_bytes else {} + + # OIDC is the only identity check Relay performs. The callback body is + # passed through to HUD untouched — HUD owns schema/business validation. + # Relay reports the OIDC-verified repo to HUD separately as + # `authenticated_repo` so HUD has a trusted source of truth for the + # caller's identity. + oidc_claims = jwt_helper.verify_oidc_token( + config, headers.get("authorization", "") + ) + verified_repo = oidc_claims["repository"] + + result = result_handler.handle(config, body, verified_repo) + return {"statusCode": 200, "headers": JSON_HEADERS, "body": json.dumps(result)} + + except json.JSONDecodeError: + logger.exception("Invalid JSON body") + return { + "statusCode": 400, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Invalid JSON body"}), + } + except HTTPException as exc: + logger.exception(exc.detail) + return { + "statusCode": exc.status_code, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": exc.detail}), + } + except Exception: + logger.exception("Internal server error") + return { + "statusCode": 500, + "headers": JSON_HEADERS, + "body": json.dumps({"detail": "Internal server error"}), + } diff --git a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py new file mode 100644 index 0000000000..d8d191ab07 
--- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py @@ -0,0 +1,99 @@ +from __future__ import annotations + +import logging +import time + +import utils.redis_helper as redis_helper +from utils.allowlist import AllowlistLevel, load_allowlist +from utils.config import RelayConfig +from utils.hud import forward_to_hud +from utils.misc import HTTPException, TimingPhase + + +logger = logging.getLogger(__name__) + + +def _safe_delta( + start_ts: float | None, end_ts: float | None, label: str +) -> float | None: + """Compute end-start, clamping tiny negatives to 0 and returning None + whenever either endpoint is missing (e.g. Redis cache miss).""" + if start_ts is None or end_ts is None: + return None + delta = round(end_ts - start_ts, 3) + if delta < 0: + logger.warning("negative %s computed, start=%s end=%s", label, start_ts, end_ts) + return 0 + return delta + + +def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: + """Forward a downstream callback to HUD. + + ``body`` is the downstream self-report, passed through to HUD verbatim. + It carries the original dispatch envelope (``delivery_id``, ``payload``) + and a sibling ``workflow`` dict with status/conclusion/name/url. + + ``verified_repo`` is the OIDC-authenticated downstream repository — used + for allowlist / timing lookups, and surfaced to HUD as ``authenticated_repo`` + so HUD can trust it over anything self-reported in the body. + """ + allowlist = load_allowlist(config) + l2_repos, _ = allowlist.get_repos_at_or_above_level(AllowlistLevel.L2) + + if verified_repo not in l2_repos: + logger.info( + "verified_repo %s is not configured for L2+ features, ignoring result", + verified_repo, + ) + return {"ok": True, "status": "ignored"} + + # delivery_id and workflow.status are required fields on the callback body — + # the relay-callback action always sets them. 
A missing value is a contract + # violation from the downstream, so fail loudly rather than silently + # producing a HUD row with no timing. + try: + delivery_id = body["delivery_id"] + status = body["workflow"]["status"] + except (KeyError, TypeError) as exc: + raise HTTPException( + 400, f"callback body missing required field: {exc}" + ) from exc + + # Each phase reports exactly one metric so HUD receives a clean, + # single-purpose row per callback: + # in_progress → queue_time (dispatch → in_progress) + # completed → execution_time (in_progress → completed) + # + # Timing keys are indexed by the body-reported delivery_id and the + # OIDC-verified repo. delivery_id is not independently authenticated — + # a tampered value just misses the timing cache, which only hurts the + # attacker's own HUD row. + ci_metrics: dict = {"queue_time": None, "execution_time": None} + if status == "in_progress": + in_progress_at = time.time() + redis_helper.set_timing( + config, delivery_id, verified_repo, TimingPhase.IN_PROGRESS, in_progress_at + ) + dispatch_at = redis_helper.get_timing( + config, delivery_id, verified_repo, TimingPhase.DISPATCH + ) + ci_metrics["queue_time"] = _safe_delta( + dispatch_at, in_progress_at, "queue_time" + ) + elif status == "completed": + completed_at = time.time() + in_progress_at = redis_helper.get_timing( + config, delivery_id, verified_repo, TimingPhase.IN_PROGRESS + ) + ci_metrics["execution_time"] = _safe_delta( + in_progress_at, completed_at, "execution_time" + ) + + # HUD owns schema validation: its 4xx surfaces back to the workflow author + # (forward_to_hud raises HTTPException). 5xx / network failures are + # swallowed inside forward_to_hud — they're HUD/infra problems and should + # not turn every downstream L2 CI red. 
+ forward_to_hud(config, body, ci_metrics, verified_repo) + + return {"ok": True, "status": status} diff --git a/aws/lambda/cross_repo_ci_relay/local_server.py b/aws/lambda/cross_repo_ci_relay/local_server.py index 9f41230b4c..483e3daee6 100644 --- a/aws/lambda/cross_repo_ci_relay/local_server.py +++ b/aws/lambda/cross_repo_ci_relay/local_server.py @@ -7,13 +7,14 @@ load_dotenv(find_dotenv(usecwd=True)) -import lambda_function +from callback import lambda_function as callback_lambda +from webhook import lambda_function as webhook_lambda -webhook_router = APIRouter() +relay_router = APIRouter() -@webhook_router.post("/github/webhook") +@relay_router.post("/github/webhook") async def github_webhook(req: Request): body = await req.body() event = { @@ -28,20 +29,41 @@ async def github_webhook(req: Request): "isBase64Encoded": False, } - result = lambda_function.lambda_handler(event, None) + result = webhook_lambda.lambda_handler(event, None) + return JSONResponse( + status_code=result["statusCode"], content=json.loads(result["body"]) + ) + + +@relay_router.post("/github/result") +async def github_result(req: Request): + body = await req.body() + event = { + "requestContext": { + "http": { + "method": req.method, + "path": req.url.path, + } + }, + "headers": {k.decode(): v.decode() for k, v in req.scope["headers"]}, + "body": body.decode("utf-8"), + "isBase64Encoded": False, + } + + result = callback_lambda.lambda_handler(event, None) return JSONResponse( status_code=result["statusCode"], content=json.loads(result["body"]) ) # ================= FastAPI apps ================= -# - webhook_app: only /github/webhook (for smee forward) +# - relay_router: defines the same endpoints as the Lambda functions, but callable via HTTP for local testing -webhook_app = FastAPI() -webhook_app.include_router(webhook_router) +relay_server = FastAPI() +relay_server.include_router(relay_router) if __name__ == "__main__": import uvicorn - uvicorn.run("local_server:webhook_app", 
host="0.0.0.0", port=8000, reload=True) + uvicorn.run("local_server:relay_server", host="0.0.0.0", port=8000, reload=True) diff --git a/aws/lambda/cross_repo_ci_relay/requirements.txt b/aws/lambda/cross_repo_ci_relay/requirements.txt index f1e84d1e02..172030f82e 100644 --- a/aws/lambda/cross_repo_ci_relay/requirements.txt +++ b/aws/lambda/cross_repo_ci_relay/requirements.txt @@ -2,3 +2,4 @@ PyYAML==6.0.3 PyGithub==2.9.0 redis==7.4.0 boto3==1.42.78 +PyJWT==2.10.1 diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py b/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py index 261ef58154..63c77971b6 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_allowlist.py @@ -1,6 +1,6 @@ import unittest -from allowlist import AllowlistLevel, AllowlistMap +from utils.allowlist import AllowlistLevel, AllowlistMap class TestAllowlistMap(unittest.TestCase): diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_config.py b/aws/lambda/cross_repo_ci_relay/tests/test_config.py index 05ed48f311..a34c820983 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_config.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_config.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import patch -from config import RelayConfig, RelaySecrets +from utils.config import RelayConfig, RelaySecrets _ENV = { @@ -27,7 +27,7 @@ def test_missing_vars_raises(self): with self.assertRaises(RuntimeError): RelayConfig.from_env() - @patch("config.RelaySecrets.from_aws") + @patch("utils.config.RelaySecrets.from_aws") @patch.dict( "os.environ", { @@ -43,6 +43,7 @@ def test_secrets_manager_fallback(self, mock_aws): github_app_secret="s", github_app_private_key="k", redis_login="secret-pass", + hud_bot_key="hud-key", ) cfg = RelayConfig.from_env() self.assertEqual(cfg.github_app_secret, "s") diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py index 
49bfb6d836..48dffd08c4 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py @@ -1,7 +1,7 @@ import unittest from unittest.mock import call, MagicMock, patch -from event_handler import handle +from webhook.event_handler import handle def _cfg(): @@ -9,6 +9,7 @@ def _cfg(): cfg.github_app_id = "12345" cfg.github_app_private_key = "fake-key" cfg.max_dispatch_workers = 4 + cfg.github_app_secret = "test-secret" return cfg @@ -32,9 +33,9 @@ def test_ignored_action(self): {"ignored": True}, ) - @patch("event_handler.gh_helper.create_repository_dispatch") - @patch("event_handler.gh_helper.get_repo_access_token", return_value="tok") - @patch("event_handler.load_allowlist") + @patch("webhook.event_handler.gh_helper.create_repository_dispatch") + @patch("webhook.event_handler.gh_helper.get_repo_access_token", return_value="tok") + @patch("webhook.event_handler.load_allowlist") def test_dispatch_success(self, mock_load, _tok, mock_dispatch): mock_load.return_value = MagicMock( get_repos_at_or_above_level=MagicMock(return_value=(["org/a"], [])) @@ -43,12 +44,12 @@ def test_dispatch_success(self, mock_load, _tok, mock_dispatch): self.assertTrue(result["ok"]) mock_dispatch.assert_called_once() - @patch("event_handler.gh_helper.create_repository_dispatch") + @patch("webhook.event_handler.gh_helper.create_repository_dispatch") @patch( - "event_handler.gh_helper.get_repo_access_token", + "webhook.event_handler.gh_helper.get_repo_access_token", side_effect=["tok-a", "tok-b"], ) - @patch("event_handler.load_allowlist") + @patch("webhook.event_handler.load_allowlist") def test_dispatch_mints_token_per_downstream_repo( self, mock_load, mock_get_repo_access_token, mock_dispatch ): diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_lambda_function.py b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler_lambda.py similarity index 83% rename from 
aws/lambda/cross_repo_ci_relay/tests/test_lambda_function.py rename to aws/lambda/cross_repo_ci_relay/tests/test_event_handler_lambda.py index 517697a49b..f2ecd5a142 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler_lambda.py @@ -4,9 +4,8 @@ import unittest from unittest.mock import MagicMock, patch -import lambda_function -from lambda_function import lambda_handler -from utils import HTTPException +from utils.misc import HTTPException +from webhook.lambda_function import lambda_handler SECRET = "test-key" @@ -54,41 +53,43 @@ def _event(*, method="POST", path="/github/webhook", body=None, headers=None): class TestLambdaHandler(unittest.TestCase): def setUp(self): - lambda_function._cached_config = None + import utils.config + + utils.config._cached_config = None def test_route_error_404_and_405(self): self.assertEqual(lambda_handler(_event(path="/other"), None)["statusCode"], 404) self.assertEqual(lambda_handler(_event(method="GET"), None)["statusCode"], 405) - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_bad_signature_401(self, mock_env): mock_env.return_value = _cfg() ev = _event(headers={"x-hub-signature-256": "sha256=bad"}) self.assertEqual(lambda_handler(ev, None)["statusCode"], 401) - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_success_delegates_to_handler(self, mock_env): mock_env.return_value = _cfg() mock_handle = MagicMock(return_value={"ok": True}) - with patch("lambda_function.event_handler.handle", mock_handle): + with patch("webhook.lambda_function.event_handler.handle", mock_handle): resp = lambda_handler(_event(), None) self.assertEqual(resp["statusCode"], 200) mock_handle.assert_called_once() - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_http_exception_forwarded(self, mock_env): 
mock_env.return_value = _cfg() with patch( - "lambda_function.event_handler.handle", + "webhook.lambda_function.event_handler.handle", MagicMock(side_effect=HTTPException(502, "err")), ): self.assertEqual(lambda_handler(_event(), None)["statusCode"], 502) - @patch("lambda_function.RelayConfig.from_env") + @patch("utils.config.RelayConfig.from_env") def test_config_cached_across_warm_invocations(self, mock_env): mock_env.return_value = _cfg() with patch( - "lambda_function.event_handler.handle", + "webhook.lambda_function.event_handler.handle", MagicMock(return_value={"ok": True}), ): first = lambda_handler(_event(), None) diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_hud.py b/aws/lambda/cross_repo_ci_relay/tests/test_hud.py new file mode 100644 index 0000000000..68ae991d9a --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_hud.py @@ -0,0 +1,71 @@ +import json +import unittest +import urllib.error +from unittest.mock import MagicMock, patch + +from utils.hud import forward_to_hud +from utils.misc import HTTPException + + +def _cfg(url="http://hud/api/oot-ci-events", key="bot-key"): + cfg = MagicMock() + cfg.hud_api_url = url + cfg.hud_bot_key = key + return cfg + + +class TestForwardToHud(unittest.TestCase): + @patch("utils.hud.urllib.request.urlopen") + def test_empty_url_skips_request(self, mock_urlopen): + forward_to_hud(_cfg(url=""), {"delivery_id": "d"}, {}, "org/repo") + mock_urlopen.assert_not_called() + + @patch("utils.hud.urllib.request.urlopen") + def test_hud_payload_has_three_top_level_fields(self, mock_urlopen): + resp = MagicMock() + resp.status = 200 + mock_urlopen.return_value.__enter__.return_value = resp + + report = {"delivery_id": "d", "workflow": {"status": "completed"}} + metrics = {"queue_time": 1.0, "execution_time": 2.0} + forward_to_hud(_cfg(), report, metrics, "org/repo") + + sent = json.loads(mock_urlopen.call_args[0][0].data) + self.assertEqual(sent["body"], report) + self.assertEqual(sent["ci_metrics"], metrics) + 
self.assertEqual(sent["authenticated_repo"], "org/repo") + + @patch("utils.hud.urllib.request.urlopen") + def test_4xx_propagates_with_huds_status(self, mock_urlopen): + # 4xx means the caller sent bad data — propagate so the downstream + # workflow author sees a red step. + mock_urlopen.side_effect = urllib.error.HTTPError( + "http://hud", 422, "bad schema", {}, None + ) + + with self.assertRaises(HTTPException) as ctx: + forward_to_hud(_cfg(), {}, {}, "org/repo") + self.assertEqual(ctx.exception.status_code, 422) + + @patch("utils.hud.urllib.request.urlopen") + def test_5xx_is_swallowed(self, mock_urlopen): + # 5xx is HUD's own problem — don't turn every downstream CI red. + mock_urlopen.side_effect = urllib.error.HTTPError( + "http://hud", 503, "unavailable", {}, None + ) + + # must not raise + forward_to_hud(_cfg(), {}, {}, "org/repo") + + @patch("utils.hud.urllib.request.urlopen") + def test_url_error_is_swallowed(self, mock_urlopen): + # Network-level failure (DNS, timeout, connection refused) is + # infrastructure, not a caller bug. 
+ mock_urlopen.side_effect = urllib.error.URLError("unreachable") + + # must not raise + forward_to_hud(_cfg(), {}, {}, "org/repo") + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py new file mode 100644 index 0000000000..e5fd3faf6c --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py @@ -0,0 +1,61 @@ +import unittest +from unittest.mock import MagicMock, patch + +from utils.jwt_helper import verify_oidc_token +from utils.misc import HTTPException + + +def _cfg(): + return MagicMock() + + +class TestVerifyDownstreamIdentity(unittest.TestCase): + def setUp(self): + self.patcher_jwks = patch( + "utils.jwt_helper._jwks_client.get_signing_key_from_jwt" + ) + self.mock_signing_key = self.patcher_jwks.start() + self.mock_signing_key.return_value = MagicMock(key="fake-key") + + self.patcher_decode = patch("utils.jwt_helper.jwt.decode") + self.mock_decode = self.patcher_decode.start() + + def tearDown(self): + self.patcher_jwks.stop() + self.patcher_decode.stop() + + def test_valid_token_returns_claims(self): + expected = { + "repository": "org/repo", + "sub": "repo:org/repo:ref:refs/heads/main", + } + self.mock_decode.return_value = expected + + claims = verify_oidc_token(_cfg(), "some.oidc.token") + + self.assertEqual(claims, expected) + + def test_bearer_prefix_stripped_before_jwks_lookup(self): + self.mock_decode.return_value = {"repository": "org/repo"} + + verify_oidc_token(_cfg(), "Bearer some.oidc.token") + + self.mock_signing_key.assert_called_once_with("some.oidc.token") + + def test_empty_token_raises_401_without_jwks_lookup(self): + with self.assertRaises(HTTPException) as ctx: + verify_oidc_token(_cfg(), "") + self.assertEqual(ctx.exception.status_code, 401) + self.assertIn("Missing", ctx.exception.detail) + self.mock_signing_key.assert_not_called() + + def test_jwks_lookup_failure_raises_401(self): + 
self.mock_signing_key.side_effect = Exception("JWKS fetch failed") + + with self.assertRaises(HTTPException) as ctx: + verify_oidc_token(_cfg(), "bad.token") + self.assertEqual(ctx.exception.status_code, 401) + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py index f20a1d1f7e..52ce7e828c 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py @@ -2,8 +2,9 @@ from unittest.mock import MagicMock import redis as redis_lib -import redis_helper -from redis_helper import ( +from utils import redis_helper +from utils.misc import TimingPhase +from utils.redis_helper import ( _ALLOWLIST_CACHE_KEY, create_client, get_cached_yaml, @@ -56,5 +57,52 @@ def test_create_client_reuses_cached_client_for_same_url(self): mock_from_url.assert_called_once() +class TestTimingHelpers(unittest.TestCase): + def setUp(self): + redis_helper._cached_client = None + redis_helper._cached_client_url = None + + def test_set_timing_swallows_redis_error(self): + from utils.redis_helper import set_timing + + client = MagicMock() + client.setex.side_effect = redis_lib.exceptions.RedisError("boom") + cfg = MagicMock() + cfg.redis_endpoint = "host:6379" + cfg.redis_login = "" + cfg.oot_status_ttl = 3600 + + # must not raise — signature is (config, delivery_id, downstream_repo, phase, ts, client) + set_timing(cfg, "del-123", "org/repo", TimingPhase.DISPATCH, 1234.5, client) + + def test_get_timing_returns_none_on_cache_miss(self): + from utils.redis_helper import get_timing + + client = MagicMock() + client.get.return_value = None + cfg = MagicMock() + cfg.redis_endpoint = "host:6379" + cfg.redis_login = "" + + result = get_timing(cfg, "del-123", "org/repo", TimingPhase.DISPATCH, client) + + self.assertIsNone(result) + + def test_get_timing_swallows_redis_error(self): + # Timing is a best-effort 
reporting enrichment — a Redis outage must + # degrade gracefully to None rather than breaking the result handler. + from utils.redis_helper import get_timing + + client = MagicMock() + client.get.side_effect = redis_lib.exceptions.RedisError("timeout") + cfg = MagicMock() + cfg.redis_endpoint = "host:6379" + cfg.redis_login = "" + + self.assertIsNone( + get_timing(cfg, "del-123", "org/repo", TimingPhase.DISPATCH, client) + ) + + if __name__ == "__main__": unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py new file mode 100644 index 0000000000..c9a8f4d968 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py @@ -0,0 +1,170 @@ +import time +import unittest +from unittest.mock import MagicMock, patch + +from callback.result_handler import handle +from utils.misc import TimingPhase + + +def _cfg(): + cfg = MagicMock() + cfg.hud_api_url = "http://hud/api/oot-ci-events" + cfg.hud_bot_key = "bot-key-123" + cfg.redis_endpoint = "host:6379" + cfg.redis_login = "" + cfg.oot_status_ttl = 259200 + return cfg + + +def _body(status="completed"): + return { + "event_type": "pull_request", + "delivery_id": "del-123", + "payload": { + "pull_request": {"number": 42, "head": {"sha": "abc123"}}, + "repository": {"full_name": "pytorch/pytorch"}, + }, + "workflow": { + "status": status, + "conclusion": "success" if status == "completed" else None, + "name": "CI", + "url": "http://ci.example.com/run/1", + }, + } + + +class TestResultHandler(unittest.TestCase): + def setUp(self): + self.patcher_allowlist = patch("callback.result_handler.load_allowlist") + self.mock_load_allowlist = self.patcher_allowlist.start() + mock_map = MagicMock() + mock_map.get_repos_at_or_above_level.return_value = (["org/repo"], []) + self.mock_load_allowlist.return_value = mock_map + + self.patcher_redis = patch("callback.result_handler.redis_helper") + self.mock_redis = 
self.patcher_redis.start() + self.mock_redis.create_client.return_value = MagicMock() + self.mock_redis.get_timing.return_value = None + + self.patcher_hud = patch("callback.result_handler.forward_to_hud") + self.mock_hud = self.patcher_hud.start() + + def tearDown(self): + self.patcher_allowlist.stop() + self.patcher_redis.stop() + self.patcher_hud.stop() + + # --- allowlist uses the OIDC-verified repo, not the body --- + + def test_verified_repo_not_in_l2_returns_ignored(self): + mock_map = MagicMock() + mock_map.get_repos_at_or_above_level.return_value = (["other/repo"], []) + self.mock_load_allowlist.return_value = mock_map + + result = handle(_cfg(), _body(), verified_repo="org/repo") + + self.assertEqual(result, {"ok": True, "status": "ignored"}) + self.assertFalse(self.mock_redis.create_client.called) + self.assertFalse(self.mock_hud.called) + + # --- body is forwarded to HUD verbatim; authenticated_repo is a sibling --- + + def test_body_is_passed_to_hud_unchanged(self): + body = _body() + handle(_cfg(), body, verified_repo="org/repo") + + # forward_to_hud(config, downstream_report, ci_metrics, authenticated_repo) + _, report_arg, metrics_arg, auth_repo_arg = self.mock_hud.call_args[0] + self.assertIs(report_arg, body) + self.assertEqual(auth_repo_arg, "org/repo") + # authenticated_repo is a sibling of ci_metrics, not nested inside it. + self.assertNotIn("authenticated_repo", metrics_arg) + + # --- timing --- + + def test_in_progress_records_timing_and_computes_queue_time(self): + dispatch_at = time.time() - 30 + self.mock_redis.get_timing.return_value = dispatch_at + + result = handle(_cfg(), _body(status="in_progress"), verified_repo="org/repo") + + self.assertEqual(result, {"ok": True, "status": "in_progress"}) + # set_timing called with the verified repo, not any body-reported repo. 
+ args, _ = self.mock_redis.set_timing.call_args + self.assertEqual(args[2], "org/repo") + self.assertEqual(args[3], TimingPhase.IN_PROGRESS) + _, _, metrics, _ = self.mock_hud.call_args[0] + self.assertAlmostEqual(metrics["queue_time"], 30, delta=1.0) + self.assertIsNone(metrics["execution_time"]) + + def test_completed_computes_execution_time_only(self): + # Each phase reports exactly one metric: completed → execution_time. + # queue_time was already reported during in_progress, so HUD merges + # the two rows on delivery_id. + in_progress_at = time.time() - 30 + self.mock_redis.get_timing.return_value = in_progress_at + + result = handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + + self.assertEqual(result, {"ok": True, "status": "completed"}) + _, _, metrics, _ = self.mock_hud.call_args[0] + self.assertIsNone(metrics["queue_time"]) + self.assertAlmostEqual(metrics["execution_time"], 30, delta=1.0) + + # --- best-effort redis infra --- + + def test_get_timing_redis_error_does_not_break_handler(self): + self.mock_redis.get_timing.return_value = None + + result = handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + + self.assertEqual(result, {"ok": True, "status": "completed"}) + self.assertTrue(self.mock_hud.called) + _, _, metrics, _ = self.mock_hud.call_args[0] + self.assertIsNone(metrics["queue_time"]) + self.assertIsNone(metrics["execution_time"]) + + def test_redis_client_unavailable_skips_timing(self): + self.mock_redis.create_client.side_effect = RuntimeError("redis down") + + result = handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + + self.assertEqual(result, {"ok": True, "status": "completed"}) + self.assertTrue(self.mock_hud.called) + + # --- HUD 4xx propagates (5xx is swallowed inside forward_to_hud) --- + + def test_hud_4xx_propagates(self): + from utils.misc import HTTPException + + self.mock_hud.side_effect = HTTPException(422, "bad schema") + + with self.assertRaises(HTTPException) as ctx: + 
handle(_cfg(), _body(), verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 422) + + # --- required field validation --- + + def test_missing_delivery_id_returns_400(self): + from utils.misc import HTTPException + + body = _body() + del body["delivery_id"] + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), body, verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 400) + + def test_missing_workflow_status_returns_400(self): + from utils.misc import HTTPException + + body = _body() + del body["workflow"]["status"] + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), body, verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 400) + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler_lambda.py b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler_lambda.py new file mode 100644 index 0000000000..3ee85e3548 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler_lambda.py @@ -0,0 +1,116 @@ +import base64 +import json +import unittest +from unittest.mock import patch + +from callback.lambda_function import lambda_handler +from utils.misc import HTTPException + + +def _event( + *, + method="POST", + path="/github/result", + body=None, + headers=None, + base64_encoded=False, +): + if body is None: + body = json.dumps({"status": "completed", "head_sha": "abc123"}) + if base64_encoded: + body = base64.b64encode(body.encode()).decode() + if headers is None: + hdrs = {"authorization": "Bearer oidc.tok"} + else: + hdrs = dict(headers) + return { + "requestContext": {"http": {"method": method, "path": path}}, + "body": body, + "isBase64Encoded": base64_encoded, + "headers": hdrs, + } + + +class TestResultLambdaHandler(unittest.TestCase): + def setUp(self): + import utils.config + + utils.config._cached_config = None + + def test_route_validation(self): + response = lambda_handler(_event(path="/other"), 
{}) + self.assertEqual(response["statusCode"], 404) + response = lambda_handler(_event(method="GET"), {}) + self.assertEqual(response["statusCode"], 405) + + @patch("callback.lambda_function.get_config") + def test_invalid_json_body_returns_400(self, mock_get_config): + response = lambda_handler(_event(body="not-json"), {}) + self.assertEqual(response["statusCode"], 400) + + @patch("callback.lambda_function.get_config") + def test_missing_authorization_header_returns_401(self, mock_get_config): + response = lambda_handler(_event(headers={}), {}) + self.assertEqual(response["statusCode"], 401) + self.assertIn("Missing", json.loads(response["body"])["detail"]) + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + def test_oidc_failure_returns_401(self, mock_oidc, mock_get_config): + mock_oidc.side_effect = HTTPException(401, "Invalid authorization token") + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 401) + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + @patch("callback.lambda_function.result_handler.handle") + def test_happy_path_forwards_body_and_verified_repo( + self, mock_handle, mock_oidc, mock_get_config + ): + mock_oidc.return_value = {"repository": "org/repo"} + mock_handle.return_value = {"ok": True, "status": "completed"} + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 200) + self.assertEqual( + json.loads(response["body"]), {"ok": True, "status": "completed"} + ) + # Body passed through verbatim, verified_repo comes from OIDC claims. 
+ args = mock_handle.call_args[0] + self.assertEqual(args[1], {"status": "completed", "head_sha": "abc123"}) + self.assertEqual(args[2], "org/repo") + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + @patch("callback.lambda_function.result_handler.handle") + def test_hud_error_from_handler_is_forwarded( + self, mock_handle, mock_oidc, mock_get_config + ): + # HUD's HTTP status propagates out of Relay (transparent proxy). + mock_oidc.return_value = {"repository": "org/repo"} + mock_handle.side_effect = HTTPException(503, "HUD unreachable") + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 503) + self.assertEqual(json.loads(response["body"])["detail"], "HUD unreachable") + + @patch("callback.lambda_function.get_config") + @patch("callback.lambda_function.jwt_helper.verify_oidc_token") + @patch("callback.lambda_function.result_handler.handle") + def test_unhandled_exception_returns_500( + self, mock_handle, mock_oidc, mock_get_config + ): + mock_oidc.return_value = {"repository": "org/repo"} + mock_handle.side_effect = Exception("boom") + + response = lambda_handler(_event(), {}) + + self.assertEqual(response["statusCode"], 500) + + +if __name__ == "__main__": + unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/utils.py b/aws/lambda/cross_repo_ci_relay/utils.py deleted file mode 100644 index 78e613cfe2..0000000000 --- a/aws/lambda/cross_repo_ci_relay/utils.py +++ /dev/null @@ -1,13 +0,0 @@ -from typing import TypedDict - - -class HTTPException(Exception): - def __init__(self, status_code: int, detail): - self.status_code = status_code - self.detail = detail - - -class EventDispatchPayload(TypedDict): - event_type: str - delivery_id: str - payload: dict diff --git a/aws/lambda/cross_repo_ci_relay/allowlist.py b/aws/lambda/cross_repo_ci_relay/utils/allowlist.py similarity index 96% rename from aws/lambda/cross_repo_ci_relay/allowlist.py rename to 
aws/lambda/cross_repo_ci_relay/utils/allowlist.py index 8be55ad639..822ef28c77 100644 --- a/aws/lambda/cross_repo_ci_relay/allowlist.py +++ b/aws/lambda/cross_repo_ci_relay/utils/allowlist.py @@ -13,10 +13,10 @@ from enum import Enum from urllib.parse import urlparse -import gh_helper -import redis_helper import yaml -from config import RelayConfig + +from . import gh_helper, redis_helper +from .config import RelayConfig logger = logging.getLogger(__name__) @@ -96,9 +96,6 @@ def get_repos_at_or_above_level( oncalls.extend(lvl_oncalls) return repos, oncalls - def __bool__(self) -> bool: - return any(bool(entries) for entries in self._levels.values()) - @classmethod def _parse(cls, raw: dict) -> "AllowlistMap": if not isinstance(raw, dict): diff --git a/aws/lambda/cross_repo_ci_relay/config.py b/aws/lambda/cross_repo_ci_relay/utils/config.py similarity index 81% rename from aws/lambda/cross_repo_ci_relay/config.py rename to aws/lambda/cross_repo_ci_relay/utils/config.py index bfc5252a0f..2d9b9e056f 100644 --- a/aws/lambda/cross_repo_ci_relay/config.py +++ b/aws/lambda/cross_repo_ci_relay/utils/config.py @@ -11,13 +11,14 @@ class RelaySecrets: github_app_secret: str = "" github_app_private_key: str = "" redis_login: str = "" + hud_bot_key: str = "" @classmethod def from_aws(cls, secret_store_arn: str, client=None) -> "RelaySecrets": region = os.environ.get("AWS_REGION", "us-east-1") try: if client is None: - client = boto3.session.Session().client( + client = boto3.client( "secretsmanager", region_name=region, config=Config(retries={"max_attempts": 3, "mode": "standard"}), @@ -36,6 +37,7 @@ def from_aws(cls, secret_store_arn: str, client=None) -> "RelaySecrets": github_app_secret=secret.get("GITHUB_APP_SECRET", ""), github_app_private_key=secret.get("GITHUB_APP_PRIVATE_KEY", ""), redis_login=secret.get("REDIS_LOGIN", ""), + hud_bot_key=secret.get("HUD_BOT_KEY", ""), ) @@ -57,6 +59,9 @@ class RelayConfig: redis_login: str allowlist_ttl_seconds: int 
max_dispatch_workers: int + hud_api_url: str + hud_bot_key: str + oot_status_ttl: int @classmethod def from_env(cls) -> "RelayConfig": @@ -65,6 +70,7 @@ def from_env(cls) -> "RelayConfig": github_app_private_key = os.getenv("GITHUB_APP_PRIVATE_KEY", "") redis_login = os.getenv("REDIS_LOGIN", "") secret_store_arn = os.getenv("SECRET_STORE_ARN", "") + hud_bot_key = os.getenv("HUD_BOT_KEY", "") if not github_app_secret or not github_app_private_key or not redis_login: if not secret_store_arn: @@ -87,12 +93,14 @@ def from_env(cls) -> "RelayConfig": github_app_private_key or secrets.github_app_private_key ) redis_login = redis_login or secrets.redis_login + hud_bot_key = hud_bot_key or secrets.hud_bot_key missing_in_secret = [ v for v, val in [ ("GITHUB_APP_SECRET", github_app_secret), ("GITHUB_APP_PRIVATE_KEY", github_app_private_key), ("REDIS_LOGIN", redis_login), + ("HUD_BOT_KEY", hud_bot_key), ] if not val ] @@ -112,6 +120,14 @@ def from_env(cls) -> "RelayConfig": # avoidable rate-limit risk in production. allowlist_ttl_seconds = max(allowlist_ttl_seconds, 900) + # GitHub can keep a workflow job in `pending` state for up to 3 days before + # auto-cancelling it, so OOT-status records must live at least that long. + # Default to 3 days (259200 s). 
+ try: + oot_status_ttl = int(os.getenv("OOT_STATUS_TTL", "259200")) + except ValueError: + raise RuntimeError("OOT_STATUS_TTL must be a valid integer") + return cls( github_app_id=_require("GITHUB_APP_ID"), github_app_secret=github_app_secret, @@ -122,4 +138,17 @@ def from_env(cls) -> "RelayConfig": redis_login=redis_login, allowlist_ttl_seconds=allowlist_ttl_seconds, max_dispatch_workers=int(os.getenv("MAX_DISPATCH_WORKERS", "32")), + hud_api_url=os.getenv("HUD_API_URL", ""), + hud_bot_key=hud_bot_key, + oot_status_ttl=oot_status_ttl, ) + + +_cached_config: RelayConfig | None = None + + +def get_config() -> RelayConfig: + global _cached_config + if _cached_config is None: + _cached_config = RelayConfig.from_env() + return _cached_config diff --git a/aws/lambda/cross_repo_ci_relay/gh_helper.py b/aws/lambda/cross_repo_ci_relay/utils/gh_helper.py similarity index 98% rename from aws/lambda/cross_repo_ci_relay/gh_helper.py rename to aws/lambda/cross_repo_ci_relay/utils/gh_helper.py index 6d0d9c7fa6..4e279d1994 100644 --- a/aws/lambda/cross_repo_ci_relay/gh_helper.py +++ b/aws/lambda/cross_repo_ci_relay/utils/gh_helper.py @@ -4,7 +4,8 @@ import github from github import GithubIntegration -from utils import EventDispatchPayload + +from .misc import EventDispatchPayload logger = logging.getLogger(__name__) diff --git a/aws/lambda/cross_repo_ci_relay/utils/hud.py b/aws/lambda/cross_repo_ci_relay/utils/hud.py new file mode 100644 index 0000000000..87017c20bf --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/utils/hud.py @@ -0,0 +1,86 @@ +import json +import logging +import urllib.error +import urllib.request + +from .config import RelayConfig +from .misc import HTTPException + + +logger = logging.getLogger(__name__) + + +def forward_to_hud( + config: RelayConfig, + body: dict, + ci_metrics: dict, + authenticated_repo: str, +) -> None: + """POST a callback record to HUD. 
+ + The HUD request body has three top-level fields: + + - ``body``: the downstream workflow's callback body, forwarded verbatim. + Contains the original dispatch envelope (``delivery_id``, ``payload``) + plus a ``workflow`` dict the downstream self-reports. Treat every field + here as untrusted — downstream can set them to anything. + - ``ci_metrics``: relay-measured performance of the downstream CI + infrastructure (``queue_time``, ``execution_time``). These come from + relay's own timing records, not from the downstream, so HUD can trust + them as a signal of downstream CI capability. + - ``authenticated_repo``: the OIDC-authenticated downstream repository. + HUD should treat this as the sole trusted identity of the caller and + prefer it over any self-reported repo field inside ``body``. + + Error handling splits by responsibility: + + - HUD 4xx (schema/validation errors, i.e. the caller's fault) is propagated + back to the downstream workflow so the workflow author sees a red CI + step and can fix their payload. + - HUD 5xx and network-level failures (HUD's own problem or infra) are + logged loudly but swallowed. The callback channel is observational — + letting HUD outages turn every downstream L2 CI red would blame the + wrong team. CloudWatch logs and alarms on ``HUD forward failed`` are + the intended operator signal here. + """ + if not config.hud_api_url: + # No HUD configured (e.g. local dev before HUD endpoint exists) — + # log and no-op rather than 500. Remove this branch once HUD is + # mandatory in every environment. 
+ logger.info("HUD_API_URL not configured, skipping HUD write") + return + + hud_payload = json.dumps( + { + "body": dict(body), + "ci_metrics": dict(ci_metrics), + "authenticated_repo": authenticated_repo, + } + ).encode("utf-8") + req = urllib.request.Request( + config.hud_api_url, + data=hud_payload, + headers={ + "Content-Type": "application/json", + "Authorization": config.hud_bot_key, + }, + method="POST", + ) + try: + with urllib.request.urlopen(req, timeout=10) as resp: + logger.info("HUD forward succeeded status=%d", resp.status) + except urllib.error.HTTPError as exc: + if 400 <= exc.code < 500: + detail = f"HUD rejected callback: HTTP {exc.code}: {exc.reason}" + logger.warning("HUD forward failed (client error): %s", detail) + raise HTTPException(exc.code, detail) from exc + # 5xx — HUD's own problem, don't propagate. + logger.exception( + "HUD forward failed (server error), swallowing: HTTP %d %s", + exc.code, + exc.reason, + ) + except urllib.error.URLError as exc: + # Network-level failure (DNS, timeout, connection refused). Treated + # as infrastructure rather than caller error — same as 5xx. + logger.exception("HUD forward failed (unreachable), swallowing: %s", exc.reason) diff --git a/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py b/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py new file mode 100644 index 0000000000..088674de41 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py @@ -0,0 +1,43 @@ +"""JWT utilities for the cross-repo CI relay.""" + +from __future__ import annotations + +import logging + +import jwt +from utils.config import RelayConfig +from utils.misc import HTTPException + + +logger = logging.getLogger(__name__) + +_jwks_client = jwt.PyJWKClient( + "https://token.actions.githubusercontent.com/.well-known/jwks" +) + + +def verify_oidc_token(config: RelayConfig, token: str) -> dict: + """Decode a GitHub Actions OIDC token and return the claims. 
+ + Rejects an empty/missing token up front so every call site gets a uniform + 401 without repeating the check. Raises ``HTTPException(401)`` on any + verification failure. + """ + if not token: + raise HTTPException(401, "Missing authorization token") + + try: + if token.lower().startswith("bearer "): + token = token[7:].strip() + + signing_key = _jwks_client.get_signing_key_from_jwt(token) + return jwt.decode( + token, + signing_key.key, + algorithms=["RS256"], + issuer="https://token.actions.githubusercontent.com", + options={"verify_aud": False}, + ) + except Exception as exc: + logger.exception("OIDC token verification error") + raise HTTPException(401, "Invalid authorization token") from exc diff --git a/aws/lambda/cross_repo_ci_relay/utils/misc.py b/aws/lambda/cross_repo_ci_relay/utils/misc.py new file mode 100644 index 0000000000..d4a197ff48 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/utils/misc.py @@ -0,0 +1,53 @@ +"""Small shared types, exceptions, and Lambda helpers. + +Grouped together because each piece is too small to justify its own module +and there's no shared abstraction to organise them under. +""" + +from __future__ import annotations + +import base64 +from enum import Enum +from typing import TypedDict + + +JSON_HEADERS = {"content-type": "application/json"} + + +class HTTPException(Exception): + def __init__(self, status_code: int, detail): + self.status_code = status_code + self.detail = detail + + +class EventDispatchPayload(TypedDict): + event_type: str + delivery_id: str + payload: dict + + +class TimingPhase(str, Enum): + """Phases recorded in the crcr:timing:* Redis keys. + + - ``DISPATCH``: webhook side, when a repository_dispatch is fired. + - ``IN_PROGRESS``: result side, when the downstream workflow reports it + has started running. 
+ """ + + DISPATCH = "dispatch" + IN_PROGRESS = "in_progress" + + +def parse_lambda_event(event: dict) -> tuple[str, str, bytes, dict]: + """Extract method, path, body bytes, and lower-cased headers from a Lambda event dict.""" + http = event.get("requestContext", {}).get("http", {}) + method = http.get("method", "").upper() + path = http.get("path", "") + raw_body = event.get("body") or "" + body_bytes = ( + base64.b64decode(raw_body) + if event.get("isBase64Encoded") + else raw_body.encode("utf-8") + ) + headers = {k.lower(): v for k, v in (event.get("headers") or {}).items()} + return method, path, body_bytes, headers diff --git a/aws/lambda/cross_repo_ci_relay/redis_helper.py b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py similarity index 54% rename from aws/lambda/cross_repo_ci_relay/redis_helper.py rename to aws/lambda/cross_repo_ci_relay/utils/redis_helper.py index d4e64fcdd3..f54421a767 100644 --- a/aws/lambda/cross_repo_ci_relay/redis_helper.py +++ b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py @@ -1,14 +1,19 @@ import logging import os +from typing import cast from urllib.parse import quote import redis as redis_lib -from config import RelayConfig +from redis.exceptions import RedisError + +from .config import RelayConfig +from .misc import TimingPhase logger = logging.getLogger(__name__) _ALLOWLIST_CACHE_KEY = "crcr:allowlist_yaml" +_TIMING_PREFIX = "crcr:timing:" _cached_client: redis_lib.Redis | None = None _cached_client_url: str | None = None @@ -69,17 +74,20 @@ def create_client(config: RelayConfig) -> redis_lib.Redis: """Create or reuse a Redis client for the given config.""" global _cached_client global _cached_client_url - - redis_url = _build_url(config) - if _cached_client is not None and _cached_client_url == redis_url: - return _cached_client - - client = redis_lib.from_url( - redis_url, - decode_responses=True, - socket_connect_timeout=2, - socket_timeout=2, - ) + try: + redis_url = _build_url(config) + if _cached_client is 
not None and _cached_client_url == redis_url: + return _cached_client + + client = redis_lib.from_url( + redis_url, + decode_responses=True, + socket_connect_timeout=2, + socket_timeout=2, + ) + except Exception: + logger.exception("Error creating Redis client") + raise RuntimeError("Failed to create Redis client") _cached_client = client _cached_client_url = redis_url return client @@ -95,12 +103,10 @@ def get_cached_yaml( value = client.get(_ALLOWLIST_CACHE_KEY) if value is not None: logger.info("allowlist cache hit key=%s", _ALLOWLIST_CACHE_KEY) - return value - except redis_lib.exceptions.RedisError as exc: - error_message = str(exc) - logger.warning( - "redis cache read failed, falling back to source: %s", - error_message, + return cast(str | None, value) + except RedisError: + logger.exception( + "redis cache read failed, falling back to source", ) return None @@ -116,9 +122,56 @@ def set_cached_yaml( logger.info( "allowlist cached %d bytes key=%s", len(yaml_str), _ALLOWLIST_CACHE_KEY ) - except redis_lib.exceptions.RedisError as exc: - error_message = str(exc) - logger.warning( - "redis cache write failed, continuing without cache: %s", - error_message, - ) + except RedisError: + logger.exception("redis cache write failed, continuing without cache") + + +def _timing_key(delivery_id: str, downstream_repo: str, phase: TimingPhase) -> str: + # delivery_id is GitHub's globally-unique X-GitHub-Delivery, so it disambiguates + # retries/reruns that share a head_sha. downstream_repo keeps the fan-out + # dimension since one delivery dispatches to many repos with independent timings. + return f"{_TIMING_PREFIX}{delivery_id}:{downstream_repo}:{phase.value}" + + +def set_timing( + config: RelayConfig, + delivery_id: str, + downstream_repo: str, + phase: TimingPhase, + ts: float, + client: redis_lib.Redis | None = None, +) -> None: + """Set timestamp for a given delivery+repo. 
Best-effort.""" + try: + if client is None: + client = create_client(config) + key = _timing_key(delivery_id, downstream_repo, phase) + client.setex(key, config.oot_status_ttl, ts) + logger.info("%s timing cached key=%s", phase.value, key) + except Exception: + logger.exception("redis set_timing failed phase=%s", phase.value) + + +def get_timing( + config: RelayConfig, + delivery_id: str, + downstream_repo: str, + phase: TimingPhase, + client: redis_lib.Redis | None = None, +) -> float | None: + """Return the stored timestamp as a float, or None on cache miss / Redis error. + + Best-effort: timing data is a reporting-only enrichment, so Redis failures + must not break the result handler. Errors are logged and swallowed. + """ + try: + if client is None: + client = create_client(config) + key = _timing_key(delivery_id, downstream_repo, phase) + value = client.get(key) + if value is None: + return None + return float(value) + except Exception: + logger.exception("redis get_timing failed phase=%s", phase.value) + return None diff --git a/aws/lambda/cross_repo_ci_relay/webhook/Makefile b/aws/lambda/cross_repo_ci_relay/webhook/Makefile new file mode 100644 index 0000000000..75c4a565b8 --- /dev/null +++ b/aws/lambda/cross_repo_ci_relay/webhook/Makefile @@ -0,0 +1,18 @@ +UTILS_SRC := $(wildcard ../utils/*.py) +PIP_FLAGS := --platform manylinux2014_x86_64 --only-binary=:all: --implementation cp --python-version 3.13 +AWS_REGION := us-east-1 +FUNCTION_NAME := cross_repo_ci_webhook + +deployment.zip: clean + mkdir -p ./deployment/webhook ./deployment/utils + cp *.py ./deployment/webhook/ && cp $(UTILS_SRC) ./deployment/utils/ + pip3 install --target ./deployment -r ../requirements.txt $(PIP_FLAGS) + cd deployment && zip -r ../deployment.zip .
+ +deploy: deployment.zip + aws lambda update-function-code --region $(AWS_REGION) --function-name $(FUNCTION_NAME) --zip-file fileb://deployment.zip + +clean: + rm -rf deployment deployment.zip + +.PHONY: deploy clean diff --git a/aws/lambda/cross_repo_ci_relay/event_handler.py b/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py similarity index 86% rename from aws/lambda/cross_repo_ci_relay/event_handler.py rename to aws/lambda/cross_repo_ci_relay/webhook/event_handler.py index 274dffd593..85d8e8f9bc 100644 --- a/aws/lambda/cross_repo_ci_relay/event_handler.py +++ b/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py @@ -2,12 +2,13 @@ import json import logging +import time from concurrent.futures import as_completed, ThreadPoolExecutor -import gh_helper -from allowlist import AllowlistLevel, load_allowlist -from config import RelayConfig -from utils import EventDispatchPayload, HTTPException +from utils import gh_helper, redis_helper +from utils.allowlist import AllowlistLevel, load_allowlist +from utils.config import RelayConfig +from utils.misc import EventDispatchPayload, HTTPException, TimingPhase logger = logging.getLogger(__name__) @@ -33,6 +34,17 @@ def _dispatch_one( client_payload=client_payload, ) + # Record dispatch timestamp for timing calculations (best-effort). + # Keyed by X-GitHub-Delivery (globally unique per webhook delivery) so + # retries/reruns with the same head_sha don't collide. 
+ redis_helper.set_timing( + config, + client_payload.get("delivery_id"), + downstream_repo, + TimingPhase.DISPATCH, + time.time(), + ) + def _dispatch_to_allowlist( *, @@ -77,13 +89,12 @@ def _dispatch_to_allowlist( ) dispatched.append({"repo": downstream_repo}) except Exception as e: - error_message = str(e) - logger.error( - "dispatch failed event_type=%s repo=%s error=%s", + logger.exception( + "dispatch failed event_type=%s repo=%s", event_type, downstream_repo, - error_message, ) + error_message = str(e) failed.append( { "repo": downstream_repo, diff --git a/aws/lambda/cross_repo_ci_relay/lambda_function.py b/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py similarity index 69% rename from aws/lambda/cross_repo_ci_relay/lambda_function.py rename to aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py index e951a5926e..a2f1444806 100644 --- a/aws/lambda/cross_repo_ci_relay/lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py @@ -1,21 +1,18 @@ from __future__ import annotations -import base64 import hashlib import hmac import json import logging -import event_handler -from config import RelayConfig -from utils import HTTPException +from utils.config import get_config +from utils.misc import HTTPException, JSON_HEADERS, parse_lambda_event +from . 
import event_handler logging.getLogger().setLevel(logging.INFO) logger = logging.getLogger(__name__) -_cached_config: RelayConfig | None = None - def _verify_signature(secret: str, body: bytes, signature: str) -> None: if not signature: @@ -27,29 +24,11 @@ def _verify_signature(secret: str, body: bytes, signature: str) -> None: raise HTTPException(status_code=401, detail="Bad signature") -_JSON_HEADERS = {"content-type": "application/json"} _SUPPORTED_EVENTS = frozenset({"pull_request", "push"}) -def _get_config() -> RelayConfig: - global _cached_config - if _cached_config is None: - _cached_config = RelayConfig.from_env() - return _cached_config - - def lambda_handler(event, context): - http = event.get("requestContext", {}).get("http", {}) - method = http.get("method", "").upper() - path = http.get("path", "") - - raw_body = event.get("body") or "" - body_bytes = ( - base64.b64decode(raw_body) - if event.get("isBase64Encoded") - else raw_body.encode("utf-8") - ) - headers = {k.lower(): v for k, v in (event.get("headers") or {}).items()} + method, path, body_bytes, headers = parse_lambda_event(event) delivery = headers.get("x-github-delivery", "") logger.info("request method=%s path=%s delivery=%s", method, path, delivery) @@ -58,12 +37,12 @@ def lambda_handler(event, context): if path == "/github/webhook": return { "statusCode": 405, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Method not allowed"}), } return { "statusCode": 404, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Not found"}), } @@ -72,12 +51,12 @@ def lambda_handler(event, context): logger.info("event=%s ignored before verification", event_type) return { "statusCode": 200, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"ignored": True}), } try: - config = _get_config() + config = get_config() _verify_signature( config.github_app_secret, body_bytes, headers.get("x-hub-signature-256", "") @@ 
-90,7 +69,7 @@ def lambda_handler(event, context): logger.info("repo=%s not upstream, ignored", repo) return { "statusCode": 200, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"ignored": True}), } @@ -100,24 +79,28 @@ def lambda_handler(event, context): event_type=event_type, delivery_id=delivery, ) - return {"statusCode": 200, "headers": _JSON_HEADERS, "body": json.dumps(result)} + return {"statusCode": 200, "headers": JSON_HEADERS, "body": json.dumps(result)} except json.JSONDecodeError: + logger.warning("invalid JSON body in webhook request") return { "statusCode": 400, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Invalid JSON body"}), } except HTTPException as exc: + logger.warning( + "http exception status=%d detail=%s", exc.status_code, exc.detail + ) return { "statusCode": exc.status_code, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": exc.detail}), } except Exception: logger.exception("unhandled error") return { "statusCode": 500, - "headers": _JSON_HEADERS, + "headers": JSON_HEADERS, "body": json.dumps({"detail": "Internal server error"}), } From a8101648c64a5620ec3febac9283f2029373b3e5 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Sat, 9 May 2026 07:03:48 +0000 Subject: [PATCH 2/7] Fix all comments in 0509 MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit - Unified state machine: Single JSON structure `{state, timestamp}` - Per-job timestamps: Each job has independent timing - Repo level tracking: Added `downstream_repo_level` from allowlist (L1-L4) - Error handling: Honest messages, 5xx → green CI, 4xx → red CI --- .../cross-repo-ci-relay-callback/action.yml | 32 ++- aws/lambda/cross_repo_ci_relay/README.md | 158 ++++++------ .../callback/lambda_function.py | 3 +- .../callback/result_handler.py | 232 +++++++++++++---- .../cross_repo_ci_relay/tests/test_hud.py | 84 ++++-- .../tests/test_redis_helper.py | 243 
+++++++++++++++--- .../tests/test_result_handler.py | 153 ++++++----- .../cross_repo_ci_relay/utils/allowlist.py | 8 + .../cross_repo_ci_relay/utils/config.py | 20 ++ aws/lambda/cross_repo_ci_relay/utils/hud.py | 119 +++++---- aws/lambda/cross_repo_ci_relay/utils/misc.py | 33 ++- .../cross_repo_ci_relay/utils/redis_helper.py | 204 ++++++++++++--- .../webhook/event_handler.py | 20 +- .../webhook/lambda_function.py | 1 + 14 files changed, 973 insertions(+), 337 deletions(-) diff --git a/.github/actions/cross-repo-ci-relay-callback/action.yml b/.github/actions/cross-repo-ci-relay-callback/action.yml index 10f86070a3..3e610e7991 100644 --- a/.github/actions/cross-repo-ci-relay-callback/action.yml +++ b/.github/actions/cross-repo-ci-relay-callback/action.yml @@ -25,15 +25,21 @@ inputs: default: '' test-results: description: > - Optional JSON string with structured test results produced by the - downstream job. Forwarded verbatim to the relay under `test_results`. + Optional JSON string with test result summary (counts: passed/failed/skipped). + Note: This should be a summary only, not a full enumeration of all test cases. + Full test results should be uploaded as artifacts and referenced via `artifact-url`. required: false default: '' callback-url: description: > Base URL of the result callback server. + required: true + artifact-url: + description: > + URL to downstream-hosted artifacts (logs, reports, results), + any publicly accessible URL. 
required: false - default: https://zciswjsrynb6ksccc2nb3mckpa0ldzou.lambda-url.us-east-1.on.aws/github/result + default: '' runs: using: composite @@ -51,6 +57,7 @@ runs: - name: Send callback to relay server shell: bash env: + SCHEMA_VERSION: 1 STATUS: ${{ inputs.status }} CONCLUSION: ${{ inputs.conclusion }} WORKFLOW_NAME: ${{ github.workflow }} @@ -59,11 +66,17 @@ runs: CLIENT_PAYLOAD: ${{ toJson(github.event.client_payload) }} OIDC_TOKEN: ${{ steps.oidc.outputs.token }} CALLBACK_URL: ${{ inputs.callback-url }} + ARTIFACT_URL: ${{ inputs.artifact-url }} + JOB_NAME: ${{ github.job }} + CHECK_RUN_ID: ${{ job.check_run_id }} + RUN_ID: ${{ github.run_id }} + RUN_ATTEMPT: ${{ github.run_attempt }} run: | set -euo pipefail PAYLOAD=$(python3 - <<'PYEOF' import json, os, sys + from datetime import datetime, timezone status = os.environ["STATUS"] if status not in ("in_progress", "completed"): @@ -80,14 +93,23 @@ runs: except json.JSONDecodeError as exc: sys.exit(f"::error::github.event.client_payload is not valid JSON: {exc}") + current_time = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + # Relay's original dispatch payload (event_type, delivery_id, payload) is # forwarded verbatim. Downstream-reported fields live in a sibling # `workflow` dict so the two sources stay clearly separated on the wire. 
workflow: dict = { + "schema_version": str(os.environ["SCHEMA_VERSION"]), "status": status, "conclusion": conclusion, "name": os.environ["WORKFLOW_NAME"], "url": os.environ["WORKFLOW_URL"], + "run_attempt": os.environ["RUN_ATTEMPT"], + "job_name": os.environ["JOB_NAME"], + "check_run_id": str(os.environ["CHECK_RUN_ID"]), + "run_id": str(os.environ["RUN_ID"]), + "started_at": None if status == "completed" else current_time, + "completed_at": None if status == "in_progress" else current_time, } test_results = os.environ.get("TEST_RESULTS", "").strip() @@ -97,6 +119,10 @@ runs: except json.JSONDecodeError as exc: sys.exit(f"::error::test-results input is not valid JSON: {exc}") + artifact_url = os.environ.get("ARTIFACT_URL", "").strip() + if artifact_url: + workflow["artifact_url"] = artifact_url + client_payload["workflow"] = workflow print(json.dumps(client_payload)) PYEOF diff --git a/aws/lambda/cross_repo_ci_relay/README.md b/aws/lambda/cross_repo_ci_relay/README.md index 331c911d67..58af886a44 100644 --- a/aws/lambda/cross_repo_ci_relay/README.md +++ b/aws/lambda/cross_repo_ci_relay/README.md @@ -37,54 +37,77 @@ L2+ downstream repositories can report the status of their CI workflows back to ### Security and the Relay/HUD boundary -The callback endpoint is a **near-transparent proxy to HUD** with a single -security responsibility: identifying the calling repo. Everything else — -schema validation, persistence, dedup — is HUD's job. - -- **Identity (Relay's job)**: the `Authorization: Bearer ` header - is verified against GitHub's JWKS. The OIDC `repository` claim is the only - trusted identity for the caller and is used for the L2+ allowlist check. - Relay forwards this trusted value to HUD as a top-level `authenticated_repo` - field; HUD should prefer it over anything self-reported in `body`. -- **Schema / business validation (HUD's job)**: the callback body is passed - through to HUD verbatim as a top-level `body` field. 
Relay does **not** - validate fields inside `body.workflow` or `body.payload` — HUD owns the - schema since it owns persistence. Relay only enforces that `delivery_id` - and `workflow.status` are present (contract violation → `400`). -- **Timing (Relay's job)**: Relay records dispatch/in-progress timestamps in - Redis keyed on `delivery_id`, then computes `queue_time` and - `execution_time` and surfaces them to HUD as `ci_metrics`. Each callback - phase reports exactly one metric: - - `in_progress` → `queue_time` (dispatch → in_progress) - - `completed` → `execution_time` (in_progress → completed) - -The HUD request looks like: +The callback endpoint validates incoming callbacks and forwards them to HUD for persistence. The relay is the gatekeeper for OIDC authentication, allowlist checks, rate limiting, and schema validation — HUD just authenticates the relay and writes what it's told. + +#### Relay's responsibilities: + +- **Identity**: the `Authorization: Bearer ` header is verified against GitHub's JWKS. The OIDC `repository` claim is a trusted identity for the caller and is used for the L2+ allowlist check. Relay forwards this trusted value to HUD as a top-level `verified_repo` field; HUD should prefer it over anything self-reported in `callback_payload`. +- **Repo level**: Relay determines the downstream repository's allowlist level (L1–L4) and forwards it to HUD as `downstream_repo_level`. This authoritative level information is determined once by the relay, ensuring HUD doesn't need to recompute it and avoiding synchronization/timing issues if tiering information becomes dynamic. +- **Schema validation**: Relay validates that required fields (`delivery_id` and `workflow.status`) are present in the callback body. Missing fields result in a `400` error to signal contract violations to the caller. HUD receives validated data and does not need to perform schema checks. 
+- **State machine**: Relay maintains a **unified state machine** in Redis to validate callback lifecycles, compute timing metrics, and support per-job tracking: + - **Unified structure**: Single enum `CallbackState` with states `DISPATCHED` (webhook side, keyed by sentinel `check_run_id="dispatched"`), `IN_PROGRESS`, and `COMPLETED` (callback side, per-job). State records stored as JSON: `{"state": "...", "timestamp": 1234.56, "job_name": "...", "run_id": "..."}`. + - **Dispatch validation**: `DISPATCHED` state proves valid webhook origin. Callbacks without this state are rejected (no prior dispatch). + - **Job-level tracking**: Each job has independent state and timestamps keyed by `check_run_id` (`oot:state:{delivery_id}:{repo}:{check_run_id}`). Supports multiple jobs per webhook. + - **Timing metrics**: `queue_time = dispatch_timestamp → in_progress_timestamp`, `execution_time = in_progress_timestamp → completed_timestamp`. Timestamps extracted from state records. + - **State transitions**: Rejects invalid flows (`COMPLETED` without prior `IN_PROGRESS`, duplicate `IN_PROGRESS` for the same `check_run_id`, duplicate `COMPLETED`, callbacks without a prior `DISPATCHED` record). + Note that the state diagram below covers a single check run; reruns have a different `check_run_id` and are treated as separate jobs, so they do not violate the state machine, because they have no prior `IN_PROGRESS` or `COMPLETED` record.
+ ```mermaid + stateDiagram-v2 + direction LR + + [*] --> DISPATCHED: webhook sends + DISPATCHED --> IN_PROGRESS: first callback + IN_PROGRESS --> COMPLETED: completion + + IN_PROGRESS --> IN_PROGRESS: ❌ duplicate + DISPATCHED --> COMPLETED: ❌ skip IN_PROGRESS + COMPLETED --> COMPLETED: ❌ duplicate + COMPLETED --> IN_PROGRESS: ❌ wrong direction + [*] --> IN_PROGRESS: ❌ no dispatch + [*] --> COMPLETED: ❌ no dispatch + ``` + +The HUD request looks like (two top-level namespaces: `trusted` and `untrusted`): ```json { - "body": { - "event_type": "pull_request", - "delivery_id": "", - "payload": { ...original upstream webhook payload, verbatim... }, - "workflow": { - "status": "completed", - "conclusion": "success", - "name": "CI", - "url": "https://github.com/org/repo/actions/runs/123", - "test_results": { ... } - } + "trusted": { + "ci_metrics": { "queue_time": 1.23, "execution_time": null }, + "verified_repo": "org/repo", + "downstream_repo_level": "L2" }, - "ci_metrics": { "queue_time": 1.23, "execution_time": null }, - "authenticated_repo": "org/repo" + "untrusted": { + "callback_payload": { + "event_type": "pull_request", + "delivery_id": "", + "payload": { ...original upstream webhook payload, verbatim... }, + "workflow": { + "schema_version": "1", + "status": "completed", + "conclusion": "success", + "name": "CI", + "url": "https://github.com/org/repo/actions/runs/123", + "job_name": "my-ci-job", + "started_at": "2026-05-04T20:48:28Z", // when status == in_progress, else null + "completed_at": "2026-05-04T21:23:45Z", // when status == completed, else null + "test_results": { "passed": 42, "failed": 3, "skipped": 5 }, + "artifact_url": "https://github.com/org/repo/actions/runs/123/artifacts" + } + } + } } ``` -Trust boundaries inside `body`: +Notes: +- `trusted` contains relay-generated fields the HUD can rely on (`ci_metrics`, `verified_repo`, and `downstream_repo_level`).
+- `untrusted.callback_payload` contains the downstream-reported callback body; HUD should treat it as untrusted and prefer `trusted.verified_repo` for identity. + +Trust boundaries inside `untrusted.callback_payload`: -- `body.payload` is the upstream webhook payload, transparently forwarded — +- `untrusted.callback_payload.payload` is the upstream webhook payload, transparently forwarded — trusted at dispatch time, but not re-verified on the callback. -- `body.workflow` is **self-reported by the downstream CI** and is not - authenticated. Only `authenticated_repo` carries a cryptographic identity. +- `untrusted.callback_payload.workflow` is **self-reported by the downstream CI** and is not + authenticated. Only `verified_repo` carries a cryptographic identity. ### Error propagation back to the downstream workflow @@ -92,35 +115,19 @@ Trust boundaries inside `body`: |---|---|---| | `2xx` | record delivered | green | | `4xx` (schema reject) | propagate same status | **red** — author must fix payload | -| `5xx` / network error | log + swallow | green — HUD outage is not the caller's fault | +| `5xx` / network error | log + return | green — HUD outage is not the caller's fault | -The asymmetry is deliberate: `4xx` means the caller sent something wrong and -should see it; `5xx`/network means HUD or its infrastructure is broken and -should not be surfaced as a red CI step across every L2 repo. Operators are -expected to alert on the `HUD forward failed` CloudWatch log pattern. +The asymmetry is deliberate: `4xx` means the caller sent something wrong and should see it; `5xx`/network means HUD or its infrastructure is broken and should not be surfaced as a red CI step across every L2+ repo. Operators are expected to alert on the `HUD forward failed` CloudWatch log pattern. #### Known limitations of this model A compromised or malicious maintainer of an allowlisted repo can: -1. 
Fabricate `workflow.status` / `workflow.conclusion` values for upstream PRs - their repo was never dispatched for — HUD will receive the row, but - `authenticated_repo` always identifies the true caller. -2. Replay an older dispatched payload against the callback endpoint — there - is no dispatch-side nonce. -3. Tamper with any field inside `body` — HUD must trust `authenticated_repo`, - not the body. - -All three attacks are **scoped to the attacker's own OIDC-authenticated repo -identity** — OIDC guarantees they cannot impersonate another allowlisted -repo. Mitigation is operational: every HUD row carries `authenticated_repo`, -so misbehaviour is observable, and the offending repo can be removed from -`allowlist.yaml`. - -If stronger guarantees are required later, the typical next step is a signed -callback token minted by the webhook side plus a one-shot state machine in -Redis keyed on `delivery_id`. This was intentionally deferred to keep the -relay simple — see the PR description for the discussion. +1. Fabricate `workflow.status` / `workflow.conclusion` values for upstream PRs their repo was never dispatched for — HUD will receive the row, but `verified_repo` always identifies the true caller. +2. Replay an older dispatched payload against the callback endpoint — there is no dispatch-side nonce. +3. Tamper with any field inside `callback_payload` — HUD must trust `verified_repo`, not the others. + +All three attacks are **scoped to the attacker's own OIDC-authenticated repo identity** — OIDC guarantees they cannot impersonate another allowlisted repo. Mitigation is operational: every HUD row carries `verified_repo`, so misbehaviour is observable, and the offending repo can be removed from `allowlist.yaml`. ### Prerequisites @@ -129,11 +136,7 @@ relay simple — see the PR description for the discussion. 
### Usage -When triggered by a relay `repository_dispatch`, the action automatically -reads `github.event.client_payload` for `delivery_id` and the upstream -webhook payload, and reads `github.workflow` / the current run URL for the -workflow identity. Workflow authors only need to pass `status` (and -`conclusion` when `status=completed`). +When triggered by a relay `repository_dispatch`, the action automatically reads `github.event.client_payload` for `delivery_id` and the upstream webhook payload, and reads `github.workflow` / the current run URL for the workflow identity. Workflow authors are only required to pass `status` and `callback-url` (plus `conclusion` when `status=completed`); all other inputs are optional. ```yaml on: @@ -159,7 +162,7 @@ jobs: uses: pytorch/test-infra/.github/actions/cross-repo-ci-relay-callback@main with: status: completed - conclusion: ${{ job.status == 'success' && 'success' || 'failure' }} + conclusion: ${{ job.status }} ``` ### Inputs @@ -168,8 +171,9 @@ jobs: |---|---|---|---| | `status` | **yes** | — | `in_progress` or `completed` | | `conclusion` | no | `''` | `success` or `failure` (required when `status=completed`) | -| `test-results` | no | `''` | Optional JSON string forwarded under `body.workflow.test_results` | -| `callback-url` | no | see `action.yml` | Callback endpoint URL (overridable for local testing) | +| `test-results` | no | `''` | Optional JSON string with test result summary (counts: passed/failed/skipped) | +| `callback-url` | **yes** | — | Callback endpoint URL (production Lambda URL; set once at the workflow level) | +| `artifact-url` | no | `''` | URL to downstream-hosted artifacts (logs, reports, results) | ## Build, Deploy, and Test @@ -189,9 +193,7 @@ deployment/ └── ... ``` -This matches the layout used during local development and tests, so imports -behave identically in both environments.
Configure the AWS Lambda handlers -as: +This matches the layout used during local development and tests, so imports behave identically in both environments. Configure the AWS Lambda handlers as: - Webhook Lambda: `webhook.lambda_function.lambda_handler` - Callback Lambda: `callback.lambda_function.lambda_handler` @@ -241,8 +243,7 @@ make clean ## Local Development -`local_server.py` wraps both Lambda handlers in a FastAPI app so you can test -the full cross-repo-ci-relay flow without deploying to AWS. +`local_server.py` wraps both Lambda handlers in a FastAPI app so you can test the full cross-repo-ci-relay flow without deploying to AWS. ### Prerequisites @@ -257,9 +258,7 @@ the full cross-repo-ci-relay flow without deploying to AWS. redis:7-alpine \ redis-server --requirepass ``` -- [smee.io](https://smee.io) - - CLI to forward GitHub webhook events to localhost (paste this link to GitHub App webhook URL): +- [smee.io](https://smee.io) CLI to forward GitHub webhook events to localhost (paste this link to GitHub App webhook URL): ```bash npm install -g smee-client smee --url https://smee.io/ --path /github/webhook --port 8000 @@ -302,6 +301,9 @@ the full cross-repo-ci-relay flow without deploying to AWS. REDIS_ENDPOINT=localhost:6379 REDIS_LOGIN=default: ALLOWLIST_TTL_SECONDS=1200 + + # HUD (local testing) + HUD_ENDPOINT= ``` **Note**: `ALLOWLIST_URL` is required for local development and should point to a GitHub URL (it can differ from the production one). diff --git a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py index 8148773ff6..5f50201a43 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py @@ -9,6 +9,7 @@ from . 
import result_handler + logging.getLogger().setLevel(logging.INFO) logger = logging.getLogger(__name__) @@ -38,7 +39,7 @@ def lambda_handler(event, context): # OIDC is the only identity check Relay performs. The callback body is # passed through to HUD untouched — HUD owns schema/business validation. # Relay reports the OIDC-verified repo to HUD separately as - # `authenticated_repo` so HUD has a trusted source of truth for the + # `verified_repo` so HUD has a trusted source of truth for the # caller's identity. oidc_claims = jwt_helper.verify_oidc_token( config, headers.get("authorization", "") diff --git a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py index d8d191ab07..f991f63319 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py +++ b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py @@ -4,10 +4,16 @@ import time import utils.redis_helper as redis_helper -from utils.allowlist import AllowlistLevel, load_allowlist +from utils.allowlist import AllowlistLevel, AllowlistMap, load_allowlist from utils.config import RelayConfig from utils.hud import forward_to_hud -from utils.misc import HTTPException, TimingPhase +from utils.misc import ( + CallbackState, + CallbackStateRecord, + DISPATCH_CHECK_RUN_ID, + HTTPException, +) +from utils.redis_helper import check_rate_limit logger = logging.getLogger(__name__) @@ -27,73 +33,209 @@ def _safe_delta( return delta -def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: - """Forward a downstream callback to HUD. +def _verify_access(config: RelayConfig, verified_repo: str) -> AllowlistMap | None: + """Return the AllowlistMap when ``verified_repo`` is L2+, else None. - ``body`` is the downstream self-report, passed through to HUD verbatim. - It carries the original dispatch envelope (``delivery_id``, ``payload``) - and a sibling ``workflow`` dict with status/conclusion/name/url. 
- - ``verified_repo`` is the OIDC-authenticated downstream repository — used - for allowlist / timing lookups, and surfaced to HUD as ``authenticated_repo`` - so HUD can trust it over anything self-reported in the body. + Raises HTTPException(429) if the per-repo rate limit is exceeded. + A ``None`` return signals the caller to silently ignore the request. """ allowlist = load_allowlist(config) l2_repos, _ = allowlist.get_repos_at_or_above_level(AllowlistLevel.L2) - if verified_repo not in l2_repos: logger.info( "verified_repo %s is not configured for L2+ features, ignoring result", verified_repo, ) - return {"ok": True, "status": "ignored"} + return None + if not check_rate_limit(config, verified_repo): + logger.warning( + "rate limit exceeded for verified_repo=%s, rejecting request", + verified_repo, + ) + raise HTTPException(429, f"rate limit exceeded for {verified_repo}") + return allowlist - # delivery_id and workflow.status are required fields on the callback body — - # the relay-callback action always sets them. A missing value is a contract - # violation from the downstream, so fail loudly rather than silently - # producing a HUD row with no timing. + +def _parse_callback_body(body: dict) -> tuple[str, str, str, str, str]: + """Return (delivery_id, status, check_run_id, job_name, run_id) from ``body``. + + check_run_id is set by GitHub Actions (job.check_run_id context) and + cannot be tampered with, ensuring replay-attack detection integrity. + + Raises HTTPException(400) on any missing or mis-typed field. 
+ """ try: delivery_id = body["delivery_id"] - status = body["workflow"]["status"] + workflow_dict = body["workflow"] + status = workflow_dict["status"] + check_run_id = workflow_dict["check_run_id"] # Required + job_name = workflow_dict["job_name"] # Required for HUD grouping + run_id = workflow_dict["run_id"] # Required for HUD grouping except (KeyError, TypeError) as exc: + logger.warning("missing required field in callback body: %s", exc) raise HTTPException( 400, f"callback body missing required field: {exc}" ) from exc + return delivery_id, status, check_run_id, job_name, run_id + + +def _update_state_and_compute_metrics( + config: RelayConfig, + delivery_id: str, + verified_repo: str, + check_run_id: str, + job_name: str, + run_id: str, + status: str, + dispatch_record: CallbackStateRecord, + job_record: CallbackStateRecord | None, +) -> dict: + """Persist the new job state to Redis and return CI timing metrics. + + Writes IN_PROGRESS or COMPLETED state (with the current timestamp), then + reads back the stored record to compute: + - ``queue_time``: dispatch → in_progress (set on "in_progress" callbacks) + - ``execution_time``: in_progress → completed (set on "completed" callbacks) - # Each phase reports exactly one metric so HUD receives a clean, - # single-purpose row per callback: - # in_progress → queue_time (dispatch → in_progress) - # completed → execution_time (in_progress → completed) - # - # Timing keys are indexed by the body-reported delivery_id and the - # OIDC-verified repo. delivery_id is not independently authenticated — - # a tampered value just misses the timing cache, which only hurts the - # attacker's own HUD row. + Both metrics default to None when the required prior state is unavailable + (e.g. Redis cache miss or rerun without matching prior record). 
+ """ ci_metrics: dict = {"queue_time": None, "execution_time": None} + current_timestamp = time.time() + if status == "in_progress": - in_progress_at = time.time() - redis_helper.set_timing( - config, delivery_id, verified_repo, TimingPhase.IN_PROGRESS, in_progress_at + if not redis_helper.set_callback_state( + config, + delivery_id, + verified_repo, + check_run_id, + CallbackState.IN_PROGRESS, + current_timestamp, + job_name, + run_id, + ): + raise HTTPException( + 400, + f"callback rejected: invalid state transition delivery_id={delivery_id} status={status}", + ) + updated_job_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, check_run_id ) - dispatch_at = redis_helper.get_timing( - config, delivery_id, verified_repo, TimingPhase.DISPATCH + if updated_job_record is not None: + ci_metrics["queue_time"] = _safe_delta( + dispatch_record.timestamp, + updated_job_record.timestamp, + "queue_time", + ) + + elif status == "completed": + if not redis_helper.set_callback_state( + config, + delivery_id, + verified_repo, + check_run_id, + CallbackState.COMPLETED, + current_timestamp, + job_name, + run_id, + ): + raise HTTPException( + 400, + f"callback rejected: invalid state transition delivery_id={delivery_id} status={status}", + ) + updated_job_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, check_run_id ) - ci_metrics["queue_time"] = _safe_delta( - dispatch_at, in_progress_at, "queue_time" + if updated_job_record is not None: + ci_metrics["execution_time"] = _safe_delta( + job_record.timestamp, updated_job_record.timestamp, "execution_time" + ) + + else: + raise HTTPException(400, f"unknown callback status: {status!r}") + + return ci_metrics + + +def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: + """Forward a downstream callback to HUD. + + ``body`` is the downstream self-report, passed through to HUD verbatim. 
+ It carries the original dispatch envelope (``delivery_id``, ``payload``) + and a sibling ``workflow`` dict with status/conclusion/name/url. + + ``verified_repo`` is the OIDC-authenticated downstream repository — used + for allowlist / timing lookups, and surfaced to HUD as ``verified_repo`` + so HUD can trust it over anything self-reported in the body. + + State machine ensures: + - Callbacks without prior dispatch are rejected + - Timestamps (started_at, completed_at) are recorded once only + - Duplicate callbacks are handled gracefully + - State transitions follow valid lifecycle paths + """ + allowlist = _verify_access(config, verified_repo) + if allowlist is None: + return {"ok": True, "status": "ignored"} + + delivery_id, status, check_run_id, job_name, run_id = _parse_callback_body(body) + + # Get dispatch state record (proves valid webhook, provides dispatch timestamp) + dispatch_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, DISPATCH_CHECK_RUN_ID + ) + if not dispatch_record: + logger.warning( + "no dispatch record found for delivery_id=%s, verified_repo=%s; rejecting callback", + delivery_id, + verified_repo, ) - elif status == "completed": - completed_at = time.time() - in_progress_at = redis_helper.get_timing( - config, delivery_id, verified_repo, TimingPhase.IN_PROGRESS + raise HTTPException(400, "callback rejected: no matching dispatch record") + # Get job-level state record + job_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, check_run_id + ) + + ci_metrics = _update_state_and_compute_metrics( + config, + delivery_id, + verified_repo, + check_run_id, + job_name, + run_id, + status, + dispatch_record, + job_record, + ) + + repo_level = allowlist.get_repo_level(verified_repo) + if repo_level is None: + logger.error( + "verified_repo %s not found in allowlist after passing L2+ check", + verified_repo, ) - ci_metrics["execution_time"] = _safe_delta( - in_progress_at, completed_at, 
"execution_time" + raise HTTPException( + 500, f"internal error: repo level lookup failed for {verified_repo}" ) - # HUD owns schema validation: its 4xx surfaces back to the workflow author - # (forward_to_hud raises HTTPException). 5xx / network failures are - # swallowed inside forward_to_hud — they're HUD/infra problems and should - # not turn every downstream L2 CI red. - forward_to_hud(config, body, ci_metrics, verified_repo) + trusted = { + "ci_metrics": ci_metrics, + "verified_repo": verified_repo, + "downstream_repo_level": repo_level.value, + } + # downstream's payload is untrusted — provide it under the "callback_payload" + # key so HUD receives it under the expected untrusted namespace. + untrusted = {"callback_payload": body} + try: + forward_to_hud(config, trusted, untrusted) + except HTTPException as exc: + if 400 <= exc.status_code < 500: + raise + logger.error("HUD internal error (HTTP %d): %s", exc.status_code, exc.detail) + return { + "ok": True, + "status": status, + "warning": "HUD update failed but CI run is valid", + } return {"ok": True, "status": status} diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_hud.py b/aws/lambda/cross_repo_ci_relay/tests/test_hud.py index 68ae991d9a..e3c19bb6f4 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_hud.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_hud.py @@ -7,17 +7,22 @@ from utils.misc import HTTPException -def _cfg(url="http://hud/api/oot-ci-events", key="bot-key"): +def _cfg(url="http://hud/api/oot-ci-events", key="bot-key", max_retries=3): cfg = MagicMock() cfg.hud_api_url = url cfg.hud_bot_key = key + cfg.hud_max_retries = max_retries return cfg class TestForwardToHud(unittest.TestCase): @patch("utils.hud.urllib.request.urlopen") def test_empty_url_skips_request(self, mock_urlopen): - forward_to_hud(_cfg(url=""), {"delivery_id": "d"}, {}, "org/repo") + forward_to_hud( + _cfg(url=""), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {"delivery_id": "d"}}, + 
) mock_urlopen.assert_not_called() @patch("utils.hud.urllib.request.urlopen") @@ -28,12 +33,16 @@ def test_hud_payload_has_three_top_level_fields(self, mock_urlopen): report = {"delivery_id": "d", "workflow": {"status": "completed"}} metrics = {"queue_time": 1.0, "execution_time": 2.0} - forward_to_hud(_cfg(), report, metrics, "org/repo") + forward_to_hud( + _cfg(), + {"ci_metrics": metrics, "verified_repo": "org/repo"}, + {"callback_payload": report}, + ) sent = json.loads(mock_urlopen.call_args[0][0].data) - self.assertEqual(sent["body"], report) - self.assertEqual(sent["ci_metrics"], metrics) - self.assertEqual(sent["authenticated_repo"], "org/repo") + self.assertEqual(sent["trusted"]["ci_metrics"], metrics) + self.assertEqual(sent["trusted"]["verified_repo"], "org/repo") + self.assertEqual(sent["untrusted"]["callback_payload"], report) @patch("utils.hud.urllib.request.urlopen") def test_4xx_propagates_with_huds_status(self, mock_urlopen): @@ -44,27 +53,56 @@ def test_4xx_propagates_with_huds_status(self, mock_urlopen): ) with self.assertRaises(HTTPException) as ctx: - forward_to_hud(_cfg(), {}, {}, "org/repo") + forward_to_hud( + _cfg(), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {}}, + ) self.assertEqual(ctx.exception.status_code, 422) + @patch("utils.hud.time.sleep") @patch("utils.hud.urllib.request.urlopen") - def test_5xx_is_swallowed(self, mock_urlopen): - # 5xx is HUD's own problem — don't turn every downstream CI red. 
- mock_urlopen.side_effect = urllib.error.HTTPError( - "http://hud", 503, "unavailable", {}, None - ) - - # must not raise - forward_to_hud(_cfg(), {}, {}, "org/repo") - + def test_retries_exhausted(self, mock_urlopen, mock_sleep): + """5xx and URLError are retried; after exhaustion an exception is raised.""" + cases = [ + ( + urllib.error.HTTPError("http://hud", 500, "err", {}, None), + HTTPException, + 500, + ), + (urllib.error.URLError("unreachable"), urllib.error.URLError, None), + ] + for exc, expected_type, expected_code in cases: + with self.subTest(exc=exc): + mock_urlopen.reset_mock() + mock_sleep.reset_mock() + mock_urlopen.side_effect = exc + + with self.assertRaises(expected_type) as ctx: + forward_to_hud( + _cfg(max_retries=2), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {}}, + ) + if expected_code is not None: + self.assertEqual(ctx.exception.status_code, expected_code) + self.assertEqual(mock_urlopen.call_count, 3) # 1 + 2 retries + self.assertEqual(mock_sleep.call_count, 2) + + @patch("utils.hud.time.sleep") @patch("utils.hud.urllib.request.urlopen") - def test_url_error_is_swallowed(self, mock_urlopen): - # Network-level failure (DNS, timeout, connection refused) is - # infrastructure, not a caller bug. 
- mock_urlopen.side_effect = urllib.error.URLError("unreachable") - - # must not raise - forward_to_hud(_cfg(), {}, {}, "org/repo") + def test_5xx_succeeds_on_retry(self, mock_urlopen, mock_sleep): + mock_urlopen.side_effect = [ + urllib.error.HTTPError("http://hud", 500, "unavailable", {}, None), + MagicMock(status=200), + ] + forward_to_hud( + _cfg(), + {"ci_metrics": {}, "verified_repo": "org/repo"}, + {"callback_payload": {}}, + ) + self.assertEqual(mock_urlopen.call_count, 2) + self.assertEqual(mock_sleep.call_count, 1) if __name__ == "__main__": diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py index 52ce7e828c..5c1a72e5cd 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py @@ -1,14 +1,19 @@ +import json import unittest +import unittest.mock from unittest.mock import MagicMock import redis as redis_lib from utils import redis_helper -from utils.misc import TimingPhase +from utils.misc import CallbackState, DISPATCH_CHECK_RUN_ID from utils.redis_helper import ( _ALLOWLIST_CACHE_KEY, + CallbackStateRecord, create_client, get_cached_yaml, + get_callback_state, set_cached_yaml, + set_callback_state, ) @@ -17,6 +22,8 @@ def _cfg(): cfg.redis_endpoint = "host:6379" cfg.redis_login = "" cfg.allowlist_ttl_seconds = 600 + cfg.oot_status_ttl = 3600 + cfg.rate_limit_per_min = 20 return cfg @@ -57,51 +64,229 @@ def test_create_client_reuses_cached_client_for_same_url(self): mock_from_url.assert_called_once() -class TestTimingHelpers(unittest.TestCase): +class TestCallbackStateMachine(unittest.TestCase): def setUp(self): redis_helper._cached_client = None redis_helper._cached_client_url = None - def test_set_timing_swallows_redis_error(self): - from utils.redis_helper import set_timing + def test_set_dispatch_state_with_timestamp(self): + """Webhook sets DISPATCHED state.""" + client = MagicMock() + result = 
set_callback_state( + _cfg(), + "del-123", + "org/repo", + DISPATCH_CHECK_RUN_ID, + CallbackState.DISPATCHED, + 1000.0, + client=client, + ) + self.assertTrue(result) + client.setex.assert_called_once() + def test_get_callback_state_parses_json(self): + """get_callback_state returns CallbackStateRecord from JSON.""" client = MagicMock() - client.setex.side_effect = redis_lib.exceptions.RedisError("boom") - cfg = MagicMock() - cfg.redis_endpoint = "host:6379" - cfg.redis_login = "" - cfg.oot_status_ttl = 3600 + client.get.return_value = json.dumps( + { + "state": "IN_PROGRESS", + "timestamp": 1010.5, + "job_name": "test-job", + "run_id": "12345", + } + ) + cfg = _cfg() - # must not raise — signature is (config, delivery_id, downstream_repo, phase, ts, client) - set_timing(cfg, "del-123", "org/repo", TimingPhase.DISPATCH, 1234.5, client) + record = get_callback_state(cfg, "del-123", "org/repo", "test-job", client) - def test_get_timing_returns_none_on_cache_miss(self): - from utils.redis_helper import get_timing + self.assertIsNotNone(record) + self.assertEqual(record.state, CallbackState.IN_PROGRESS) + self.assertEqual(record.timestamp, 1010.5) + self.assertEqual(record.job_name, "test-job") + self.assertEqual(record.run_id, "12345") + def test_get_callback_state_returns_none_on_missing_or_error(self): + """get_callback_state returns None on missing key or Redis error.""" client = MagicMock() + cfg = _cfg() + client.get.return_value = None - cfg = MagicMock() - cfg.redis_endpoint = "host:6379" - cfg.redis_login = "" + self.assertIsNone( + get_callback_state(cfg, "del-123", "org/repo", "test-job", client) + ) - result = get_timing(cfg, "del-123", "org/repo", TimingPhase.DISPATCH, client) + client.get.side_effect = redis_lib.exceptions.RedisError("boom") + self.assertIsNone( + get_callback_state(cfg, "del-123", "org/repo", "test-job", client) + ) - self.assertIsNone(result) + def test_invalid_state_transitions_rejected(self): + """Duplicate or invalid state 
transitions are all rejected.""" + cases = [ + # (check_run_id, new_state, existing_state_value_or_None) + (DISPATCH_CHECK_RUN_ID, CallbackState.DISPATCHED, "DISPATCHED"), + ("check-run", CallbackState.IN_PROGRESS, "IN_PROGRESS"), + ("check-run", CallbackState.COMPLETED, None), # None → COMPLETED + ("check-run", CallbackState.COMPLETED, "COMPLETED"), + ] + for check_run_id, state, existing in cases: + with self.subTest(state=state, existing=existing): + client = MagicMock() + client.get.return_value = ( + json.dumps( + { + "state": existing, + "timestamp": 1000.0, + "job_name": "job", + "run_id": "111", + } + ) + if existing + else None + ) + result = set_callback_state( + _cfg(), + "del-123", + "org/repo", + check_run_id, + state, + 1100.0, + client=client, + ) + self.assertFalse(result) + client.setex.assert_not_called() - def test_get_timing_swallows_redis_error(self): - # Timing is a best-effort reporting enrichment — a Redis outage must - # degrade gracefully to None rather than breaking the result handler. 
- from utils.redis_helper import get_timing + def test_set_completed_from_in_progress_accepts(self): + """IN_PROGRESS → COMPLETED transition is accepted.""" + client = MagicMock() + client.get.return_value = json.dumps( + { + "state": "IN_PROGRESS", + "timestamp": 1010.0, + "job_name": "test-job", + "run_id": "12345", + } + ) + result = set_callback_state( + _cfg(), + "del-123", + "org/repo", + "test-job", + CallbackState.COMPLETED, + 1020.0, + job_name="test-job", + run_id="12345", + client=client, + ) + self.assertTrue(result) + + def test_set_in_progress_accepts_first_callback(self): + """None → IN_PROGRESS is accepted when dispatch record exists.""" + + def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): + if check_run_id_arg == DISPATCH_CHECK_RUN_ID: + return CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + return None client = MagicMock() - client.get.side_effect = redis_lib.exceptions.RedisError("timeout") - cfg = MagicMock() - cfg.redis_endpoint = "host:6379" - cfg.redis_login = "" + with unittest.mock.patch( + "utils.redis_helper.get_callback_state", side_effect=get_side_effect + ): + result = set_callback_state( + _cfg(), + "del-123", + "org/repo", + "check-run-456", + CallbackState.IN_PROGRESS, + 1010.0, + job_name="test-job", + run_id="99999", + client=client, + ) + self.assertTrue(result) - self.assertIsNone( - get_timing(cfg, "del-123", "org/repo", TimingPhase.DISPATCH, client) - ) + def test_set_non_dispatched_state_with_reserved_check_run_id_rejected(self): + """Using the reserved DISPATCH_CHECK_RUN_ID for non-DISPATCHED state is rejected.""" + client = MagicMock() + cfg = _cfg() + + for state in (CallbackState.IN_PROGRESS, CallbackState.COMPLETED): + with self.subTest(state=state): + client.reset_mock() + result = set_callback_state( + cfg, + "del-123", + "org/repo", + DISPATCH_CHECK_RUN_ID, + state, + 1010.0, + client=client, + ) + self.assertFalse(result) + 
client.setex.assert_not_called() + + def test_set_callback_state_redis_exception_returns_false(self): + """Redis write failure is caught and returns False.""" + + def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): + if check_run_id_arg == DISPATCH_CHECK_RUN_ID: + return CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", "11111" + ) + return None + + cfg = _cfg() + client = MagicMock() + client.setex.side_effect = redis_lib.exceptions.RedisError("write failed") + + with unittest.mock.patch( + "utils.redis_helper.get_callback_state", side_effect=get_side_effect + ): + result = set_callback_state( + cfg, + "del-123", + "org/repo", + "check-run-456", + CallbackState.IN_PROGRESS, + 1010.0, + job_name="test-job", + run_id="99999", + client=client, + ) + + self.assertFalse(result) + + +class TestRateLimit(unittest.TestCase): + def setUp(self): + redis_helper._cached_client = None + redis_helper._cached_client_url = None + + def test_check_rate_limit_allowed(self): + from utils.redis_helper import check_rate_limit + + client = MagicMock() + client.zcard.return_value = 10 + self.assertTrue(check_rate_limit(_cfg(), "org/repo", client=client)) + + def test_check_rate_limit_exceeded(self): + from utils.redis_helper import check_rate_limit + + client = MagicMock() + client.zcard.return_value = 25 + self.assertFalse(check_rate_limit(_cfg(), "org/repo", client=client)) + + def test_check_rate_limit_redis_error_raises_500(self): + from utils.misc import HTTPException + from utils.redis_helper import check_rate_limit + + client = MagicMock() + client.zadd.side_effect = redis_lib.exceptions.RedisError("boom") + with self.assertRaises(HTTPException) as ctx: + check_rate_limit(_cfg(), "org/repo", client=client) + self.assertEqual(ctx.exception.status_code, 500) if __name__ == "__main__": diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py index 
c9a8f4d968..17423db1ef 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py @@ -3,7 +3,8 @@ from unittest.mock import MagicMock, patch from callback.result_handler import handle -from utils.misc import TimingPhase +from utils.misc import CallbackState, DISPATCH_CHECK_RUN_ID, HTTPException +from utils.redis_helper import CallbackStateRecord def _cfg(): @@ -13,10 +14,11 @@ def _cfg(): cfg.redis_endpoint = "host:6379" cfg.redis_login = "" cfg.oot_status_ttl = 259200 + cfg.rate_limit_per_min = 20 return cfg -def _body(status="completed"): +def _body(status="completed", job_name="default", check_run_id="12345", run_id="99999"): return { "event_type": "pull_request", "delivery_id": "del-123", @@ -29,6 +31,9 @@ def _body(status="completed"): "conclusion": "success" if status == "completed" else None, "name": "CI", "url": "http://ci.example.com/run/1", + "job_name": job_name, + "check_run_id": check_run_id, + "run_id": run_id, }, } @@ -39,12 +44,30 @@ def setUp(self): self.mock_load_allowlist = self.patcher_allowlist.start() mock_map = MagicMock() mock_map.get_repos_at_or_above_level.return_value = (["org/repo"], []) + mock_map.get_repo_level.return_value = MagicMock(value="L2") self.mock_load_allowlist.return_value = mock_map self.patcher_redis = patch("callback.result_handler.redis_helper") self.mock_redis = self.patcher_redis.start() self.mock_redis.create_client.return_value = MagicMock() - self.mock_redis.get_timing.return_value = None + + # Setup default: dispatch exists, job state is None (in_progress not yet reported) + def default_get_state(cfg, delivery_id, repo, check_run_id_arg, client=None): + if check_run_id_arg == DISPATCH_CHECK_RUN_ID: + return CallbackStateRecord( + CallbackState.DISPATCHED, time.time() - 30, "dispatch-job", 11111 + ) + elif check_run_id_arg == "12345": # default check_run_id in _body() + return CallbackStateRecord( + CallbackState.IN_PROGRESS, 
time.time() - 20, "default", 99999 + ) + return None + + self.mock_redis.get_callback_state.side_effect = default_get_state + + self.patcher_rate_limit = patch("callback.result_handler.check_rate_limit") + self.mock_check_rate_limit = self.patcher_rate_limit.start() + self.mock_check_rate_limit.return_value = True self.patcher_hud = patch("callback.result_handler.forward_to_hud") self.mock_hud = self.patcher_hud.start() @@ -52,6 +75,7 @@ def setUp(self): def tearDown(self): self.patcher_allowlist.stop() self.patcher_redis.stop() + self.patcher_rate_limit.stop() self.patcher_hud.stop() # --- allowlist uses the OIDC-verified repo, not the body --- @@ -67,76 +91,67 @@ def test_verified_repo_not_in_l2_returns_ignored(self): self.assertFalse(self.mock_redis.create_client.called) self.assertFalse(self.mock_hud.called) - # --- body is forwarded to HUD verbatim; authenticated_repo is a sibling --- + # --- body is forwarded to HUD verbatim; verified_repo is a sibling --- def test_body_is_passed_to_hud_unchanged(self): body = _body() handle(_cfg(), body, verified_repo="org/repo") - # forward_to_hud(config, downstream_report, ci_metrics, authenticated_repo) - _, report_arg, metrics_arg, auth_repo_arg = self.mock_hud.call_args[0] - self.assertIs(report_arg, body) - self.assertEqual(auth_repo_arg, "org/repo") - # authenticated_repo is a sibling of ci_metrics, not nested inside it. - self.assertNotIn("authenticated_repo", metrics_arg) - - # --- timing --- - - def test_in_progress_records_timing_and_computes_queue_time(self): - dispatch_at = time.time() - 30 - self.mock_redis.get_timing.return_value = dispatch_at - - result = handle(_cfg(), _body(status="in_progress"), verified_repo="org/repo") - - self.assertEqual(result, {"ok": True, "status": "in_progress"}) - # set_timing called with the verified repo, not any body-reported repo. 
- args, _ = self.mock_redis.set_timing.call_args - self.assertEqual(args[2], "org/repo") - self.assertEqual(args[3], TimingPhase.IN_PROGRESS) - _, _, metrics, _ = self.mock_hud.call_args[0] - self.assertAlmostEqual(metrics["queue_time"], 30, delta=1.0) - self.assertIsNone(metrics["execution_time"]) - - def test_completed_computes_execution_time_only(self): - # Each phase reports exactly one metric: completed → execution_time. - # queue_time was already reported during in_progress, so HUD merges - # the two rows on delivery_id. - in_progress_at = time.time() - 30 - self.mock_redis.get_timing.return_value = in_progress_at - - result = handle(_cfg(), _body(status="completed"), verified_repo="org/repo") - - self.assertEqual(result, {"ok": True, "status": "completed"}) - _, _, metrics, _ = self.mock_hud.call_args[0] - self.assertIsNone(metrics["queue_time"]) - self.assertAlmostEqual(metrics["execution_time"], 30, delta=1.0) - - # --- best-effort redis infra --- - - def test_get_timing_redis_error_does_not_break_handler(self): - self.mock_redis.get_timing.return_value = None - - result = handle(_cfg(), _body(status="completed"), verified_repo="org/repo") - - self.assertEqual(result, {"ok": True, "status": "completed"}) - self.assertTrue(self.mock_hud.called) - _, _, metrics, _ = self.mock_hud.call_args[0] - self.assertIsNone(metrics["queue_time"]) + # forward_to_hud(config, trusted, untrusted) + _, trusted_arg, untrusted_arg = self.mock_hud.call_args[0] + self.assertIs(untrusted_arg["callback_payload"], body) + self.assertEqual(trusted_arg.get("verified_repo"), "org/repo") + # verified_repo is a sibling of ci_metrics, not nested inside it. 
+ self.assertNotIn("verified_repo", trusted_arg.get("ci_metrics", {})) + + # --- timing metrics calculation --- + + def test_queue_time_calculated_from_state_records(self): + """queue_time is the dispatch-to-in_progress delta.""" + dispatch_record = CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + job_record = CallbackStateRecord( + CallbackState.IN_PROGRESS, 1030.0, "default", 99999 + ) + self.mock_redis.get_callback_state.side_effect = [ + dispatch_record, # dispatch lookup + None, # job state: not yet set + job_record, # re-read after set_callback_state + ] + + handle(_cfg(), _body(status="in_progress"), verified_repo="org/repo") + + _, trusted_arg, _ = self.mock_hud.call_args[0] + metrics = trusted_arg["ci_metrics"] + self.assertEqual(metrics["queue_time"], 30.0) self.assertIsNone(metrics["execution_time"]) - def test_redis_client_unavailable_skips_timing(self): - self.mock_redis.create_client.side_effect = RuntimeError("redis down") - - result = handle(_cfg(), _body(status="completed"), verified_repo="org/repo") - - self.assertEqual(result, {"ok": True, "status": "completed"}) - self.assertTrue(self.mock_hud.called) + def test_execution_time_calculated_from_state_records(self): + """execution_time is the in_progress-to-completed delta.""" + dispatch_record = CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + job_record = CallbackStateRecord( + CallbackState.IN_PROGRESS, 1030.0, "default", 99999 + ) + completed_record = CallbackStateRecord( + CallbackState.COMPLETED, 1060.0, "default", 99999 + ) + self.mock_redis.get_callback_state.side_effect = [ + dispatch_record, # dispatch lookup + job_record, # job state: in_progress + completed_record, # re-read after set_callback_state + ] + + handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + + _, trusted_arg, _ = self.mock_hud.call_args[0] + self.assertEqual(trusted_arg["ci_metrics"]["execution_time"], 30.0) # --- HUD 4xx 
propagates (5xx is swallowed inside forward_to_hud) --- def test_hud_4xx_propagates(self): - from utils.misc import HTTPException - self.mock_hud.side_effect = HTTPException(422, "bad schema") with self.assertRaises(HTTPException) as ctx: @@ -146,8 +161,6 @@ def test_hud_4xx_propagates(self): # --- required field validation --- def test_missing_delivery_id_returns_400(self): - from utils.misc import HTTPException - body = _body() del body["delivery_id"] @@ -156,8 +169,6 @@ def test_missing_delivery_id_returns_400(self): self.assertEqual(ctx.exception.status_code, 400) def test_missing_workflow_status_returns_400(self): - from utils.misc import HTTPException - body = _body() del body["workflow"]["status"] @@ -165,6 +176,16 @@ def test_missing_workflow_status_returns_400(self): handle(_cfg(), body, verified_repo="org/repo") self.assertEqual(ctx.exception.status_code, 400) + # --- rate limiting --- + + def test_rate_limit_exceeded_returns_429(self): + self.mock_check_rate_limit.return_value = False + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), _body(), verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 429) + self.assertFalse(self.mock_hud.called) + if __name__ == "__main__": unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/utils/allowlist.py b/aws/lambda/cross_repo_ci_relay/utils/allowlist.py index 822ef28c77..cd972efff8 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/allowlist.py +++ b/aws/lambda/cross_repo_ci_relay/utils/allowlist.py @@ -96,6 +96,14 @@ def get_repos_at_or_above_level( oncalls.extend(lvl_oncalls) return repos, oncalls + def get_repo_level(self, repo: str) -> AllowlistLevel | None: + """Return the level for a specific repo, or None if repo is not in allowlist.""" + for level, entries in self._levels.items(): + for entry in entries: + if entry.repo == repo: + return level + return None + @classmethod def _parse(cls, raw: dict) -> "AllowlistMap": if not isinstance(raw, dict): diff --git 
a/aws/lambda/cross_repo_ci_relay/utils/config.py b/aws/lambda/cross_repo_ci_relay/utils/config.py index 2d9b9e056f..f8920de605 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/config.py +++ b/aws/lambda/cross_repo_ci_relay/utils/config.py @@ -62,6 +62,8 @@ class RelayConfig: hud_api_url: str hud_bot_key: str oot_status_ttl: int + hud_max_retries: int + rate_limit_per_min: int @classmethod def from_env(cls) -> "RelayConfig": @@ -128,6 +130,22 @@ def from_env(cls) -> "RelayConfig": except ValueError: raise RuntimeError("OOT_STATUS_TTL must be a valid integer") + # Maximum number of retry attempts for HUD API calls. + # Default to 3 retries with exponential backoff. + try: + hud_max_retries = int(os.getenv("HUD_MAX_RETRIES", "3")) + if hud_max_retries < 0: + raise ValueError("must be non-negative") + except ValueError: + raise RuntimeError("HUD_MAX_RETRIES must be a non-negative integer") + + try: + rate_limit_per_min = int(os.getenv("RATE_LIMIT_PER_MIN", "20")) + if rate_limit_per_min <= 0: + raise ValueError("must be positive") + except ValueError: + raise RuntimeError("RATE_LIMIT_PER_MIN must be a positive integer") + return cls( github_app_id=_require("GITHUB_APP_ID"), github_app_secret=github_app_secret, @@ -141,6 +159,8 @@ def from_env(cls) -> "RelayConfig": hud_api_url=os.getenv("HUD_API_URL", ""), hud_bot_key=hud_bot_key, oot_status_ttl=oot_status_ttl, + hud_max_retries=hud_max_retries, + rate_limit_per_min=rate_limit_per_min, ) diff --git a/aws/lambda/cross_repo_ci_relay/utils/hud.py b/aws/lambda/cross_repo_ci_relay/utils/hud.py index 87017c20bf..1dc1695651 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/hud.py +++ b/aws/lambda/cross_repo_ci_relay/utils/hud.py @@ -1,5 +1,6 @@ import json import logging +import time import urllib.error import urllib.request @@ -10,38 +11,21 @@ logger = logging.getLogger(__name__) -def forward_to_hud( - config: RelayConfig, - body: dict, - ci_metrics: dict, - authenticated_repo: str, -) -> None: +def forward_to_hud(config: 
RelayConfig, trusted: dict, untrusted: dict) -> None: """POST a callback record to HUD. - The HUD request body has three top-level fields: + This function splits inputs into two explicit namespaces: - - ``body``: the downstream workflow's callback body, forwarded verbatim. - Contains the original dispatch envelope (``delivery_id``, ``payload``) - plus a ``workflow`` dict the downstream self-reports. Treat every field - here as untrusted — downstream can set them to anything. - - ``ci_metrics``: relay-measured performance of the downstream CI - infrastructure (``queue_time``, ``execution_time``). These come from - relay's own timing records, not from the downstream, so HUD can trust - them as a signal of downstream CI capability. - - ``authenticated_repo``: the OIDC-authenticated downstream repository. - HUD should treat this as the sole trusted identity of the caller and - prefer it over any self-reported repo field inside ``body``. + - ``trusted``: a dict supplied by this relay and therefore considered + authoritative. + - ``untrusted``: a dict forwarded from the downstream workflow and treated as + untrusted user-supplied data. - Error handling splits by responsibility: - - - HUD 4xx (schema/validation errors, i.e. the caller's fault) is propagated - back to the downstream workflow so the workflow author sees a red CI - step and can fix their payload. - - HUD 5xx and network-level failures (HUD's own problem or infra) are - logged loudly but swallowed. The callback channel is observational — - letting HUD outages turn every downstream L2 CI red would blame the - wrong team. CloudWatch logs and alarms on ``HUD forward failed`` are - the intended operator signal here. + Retry Behavior: + On server errors (HTTP 5xx) or network failures (URLError), the request + is retried up to ``config.hud_max_retries`` times with exponential backoff + (1s, 2s, 4s, ...). Client errors (HTTP 4xx) are not retried and raise + HTTPException immediately. 
""" if not config.hud_api_url: # No HUD configured (e.g. local dev before HUD endpoint exists) — @@ -52,35 +36,74 @@ def forward_to_hud( hud_payload = json.dumps( { - "body": dict(body), - "ci_metrics": dict(ci_metrics), - "authenticated_repo": authenticated_repo, + "trusted": trusted, + "untrusted": untrusted, } ).encode("utf-8") + req = urllib.request.Request( config.hud_api_url, data=hud_payload, headers={ "Content-Type": "application/json", - "Authorization": config.hud_bot_key, + "X-OOT-Relay-Token": config.hud_bot_key, }, method="POST", ) - try: - with urllib.request.urlopen(req, timeout=10) as resp: - logger.info("HUD forward succeeded status=%d", resp.status) - except urllib.error.HTTPError as exc: - if 400 <= exc.code < 500: - detail = f"HUD rejected callback: HTTP {exc.code}: {exc.reason}" - logger.warning("HUD forward failed (client error): %s", detail) - raise HTTPException(exc.code, detail) from exc - # 5xx — HUD's own problem, don't propagate. + + last_exception = None + for attempt in range(config.hud_max_retries + 1): + try: + with urllib.request.urlopen(req, timeout=10) as resp: + logger.info("HUD forward succeeded status=%d", resp.status) + return + except urllib.error.HTTPError as exc: + if 400 <= exc.code < 500: + detail = f"HUD rejected callback: HTTP {exc.code}: {exc.reason}" + logger.warning("HUD forward failed (client error): %s", detail) + raise HTTPException(exc.code, detail) from exc + last_exception = exc + logger.warning( + "HUD forward failed (server error, attempt %d/%d): HTTP %d %s", + attempt + 1, + config.hud_max_retries + 1, + exc.code, + exc.reason, + ) + except urllib.error.URLError as exc: + last_exception = exc + logger.warning( + "HUD forward failed (unreachable, attempt %d/%d): %s", + attempt + 1, + config.hud_max_retries + 1, + exc.reason, + ) + + # If we have more retries remaining, wait with exponential backoff + if attempt < config.hud_max_retries: + delay = 2**attempt + logger.info("Retrying HUD forward in %d 
seconds...", delay) + time.sleep(delay) + + # All retries exhausted, raise the last exception + if isinstance(last_exception, urllib.error.HTTPError): + logger.exception( + "HUD forward failed after %d attempts: HTTP %d %s", + config.hud_max_retries + 1, + last_exception.code, + last_exception.reason, + ) + raise HTTPException( + 500, + "An internal failure occurred. " + "Your update was not saved, but the CI run is still valid. " + "You can attempt progressive retries after " + f"{60 // config.rate_limit_per_min} seconds or ignore this failure.", + ) from last_exception + else: logger.exception( - "HUD forward failed (server error), swallowing: HTTP %d %s", - exc.code, - exc.reason, + "HUD forward failed after %d attempts: %s", + config.hud_max_retries + 1, + last_exception.reason, ) - except urllib.error.URLError as exc: - # Network-level failure (DNS, timeout, connection refused). Treated - # as infrastructure rather than caller error — same as 5xx. - logger.exception("HUD forward failed (unreachable), swallowing: %s", exc.reason) + raise last_exception diff --git a/aws/lambda/cross_repo_ci_relay/utils/misc.py b/aws/lambda/cross_repo_ci_relay/utils/misc.py index d4a197ff48..aa1f3f2431 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/misc.py +++ b/aws/lambda/cross_repo_ci_relay/utils/misc.py @@ -7,6 +7,7 @@ from __future__ import annotations import base64 +from dataclasses import dataclass from enum import Enum from typing import TypedDict @@ -26,16 +27,34 @@ class EventDispatchPayload(TypedDict): payload: dict -class TimingPhase(str, Enum): - """Phases recorded in the crcr:timing:* Redis keys. +# Specific check_run_id used to identify callbacks +# from dispatches that didn't specify a real check_run_id. +# Here we set as a string for better readability, +# but it could be any unique identifier. +DISPATCH_CHECK_RUN_ID = "dispatched" - - ``DISPATCH``: webhook side, when a repository_dispatch is fired. 
- - ``IN_PROGRESS``: result side, when the downstream workflow reports it - has started running. + +class CallbackState(str, Enum): + """Unified state machine for callback lifecycle (both webhook and callback sides). + + - ``DISPATCHED``: webhook side, when repository_dispatch is sent (job_name=DISPATCH_JOB_NAME). + - ``IN_PROGRESS``: callback side, when downstream workflow reports started (per-job). + - ``COMPLETED``: callback side, when downstream workflow reports finished (per-job). """ - DISPATCH = "dispatch" - IN_PROGRESS = "in_progress" + DISPATCHED = "DISPATCHED" + IN_PROGRESS = "IN_PROGRESS" + COMPLETED = "COMPLETED" + + +@dataclass +class CallbackStateRecord: + """Record containing state, timestamp, and job metadata for HUD grouping.""" + + state: CallbackState + timestamp: float + job_name: str + run_id: str def parse_lambda_event(event: dict) -> tuple[str, str, bytes, dict]: diff --git a/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py index f54421a767..844e48710d 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py +++ b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py @@ -1,5 +1,7 @@ +import json import logging import os +import time from typing import cast from urllib.parse import quote @@ -7,13 +9,19 @@ from redis.exceptions import RedisError from .config import RelayConfig -from .misc import TimingPhase +from .misc import ( + CallbackState, + CallbackStateRecord, + DISPATCH_CHECK_RUN_ID, + HTTPException, +) logger = logging.getLogger(__name__) -_ALLOWLIST_CACHE_KEY = "crcr:allowlist_yaml" -_TIMING_PREFIX = "crcr:timing:" +_ALLOWLIST_CACHE_KEY = "oot:allowlist_yaml" +_STATE_PREFIX = "oot:state:" +_RATE_LIMIT_PREFIX = "oot:rate:" _cached_client: redis_lib.Redis | None = None _cached_client_url: str | None = None @@ -126,52 +134,188 @@ def set_cached_yaml( logger.exception("redis cache write failed, continuing without cache") -def _timing_key(delivery_id: str, downstream_repo: 
str, phase: TimingPhase) -> str: - # delivery_id is GitHub's globally-unique X-GitHub-Delivery, so it disambiguates - # retries/reruns that share a head_sha. downstream_repo keeps the fan-out - # dimension since one delivery dispatches to many repos with independent timings. - return f"{_TIMING_PREFIX}{delivery_id}:{downstream_repo}:{phase.value}" +def check_rate_limit( + config: RelayConfig, + repo: str, + client: redis_lib.Redis | None = None, +) -> bool: + """Check if repo is within rate limit using sliding window. + + Returns True if allowed, False if rate exceeded. + Raises HTTPException(500) on Redis failure (fail-closed). + """ + try: + if client is None: + client = create_client(config) + + key = f"{_RATE_LIMIT_PREFIX}{repo}" + now = time.time() + window_start = now - 60 + + member = f"{now}:{repo}" + client.zadd(key, {member: now}) + client.zremrangebyscore(key, "-inf", window_start) + count = client.zcard(key) + client.expire(key, 120) + + if count > config.rate_limit_per_min: + logger.warning( + "rate limit exceeded key=%s count=%d limit=%d", + key, + count, + config.rate_limit_per_min, + ) + return False + return True + except RedisError as e: + logger.exception("redis rate limit check failed") + raise HTTPException(500, f"rate limit check failed: {e}") from e + + +def _state_key(delivery_id: str, downstream_repo: str, check_run_id: str) -> str: + """Redis key for callback state machine. + + Keyed by delivery_id + repo + check_run_id to support per-execution state tracking. + check_run_id is unique per job execution, enabling replay attack detection. + """ + return f"{_STATE_PREFIX}{delivery_id}:{downstream_repo}:{check_run_id}" -def set_timing( +def get_callback_state( config: RelayConfig, delivery_id: str, downstream_repo: str, - phase: TimingPhase, - ts: float, + check_run_id: str, client: redis_lib.Redis | None = None, -) -> None: - """Set timestamp for a given delivery+repo. 
Best-effort.""" +) -> CallbackStateRecord | None: + """Get callback state record from Redis, or None if no record exists. + + Returns a record containing state, timestamp, and optional job metadata. + """ try: if client is None: client = create_client(config) - key = _timing_key(delivery_id, downstream_repo, phase) - client.setex(key, config.oot_status_ttl, ts) - logger.info("%s timing cached key=%s", phase.value, key) + key = _state_key(delivery_id, downstream_repo, check_run_id) + value = client.get(key) + if value is None: + return None + data = json.loads(value) + return CallbackStateRecord( + state=CallbackState(data["state"]), + timestamp=data["timestamp"], + job_name=data["job_name"], + run_id=data["run_id"], + ) except Exception: - logger.exception("redis set_timing failed phase=%s", phase.value) + logger.exception("redis get_callback_state failed") + return None -def get_timing( +def set_callback_state( config: RelayConfig, delivery_id: str, downstream_repo: str, - phase: TimingPhase, + check_run_id: str, + state: CallbackState, + timestamp: float, + job_name: str | None = None, + run_id: int | None = None, client: redis_lib.Redis | None = None, -) -> float | None: - """Return the stored timestamp as a float, or None on cache miss / Redis error. +) -> bool: + """Set callback state with timestamp in Redis. Returns True on success, False on error. + + State transition validation: + + DISPATCHED state (webhook-side): + - None -> DISPATCHED: accept (initial dispatch) + - DISPATCHED -> DISPATCHED: reject (duplicate webhook) + + IN_PROGRESS state (callback-side): + - None -> IN_PROGRESS: accept (first callback for this check_run_id) + - IN_PROGRESS -> IN_PROGRESS: reject (replay attack for same check_run_id) - Best-effort: timing data is a reporting-only enrichment, so Redis failures - must not break the result handler. Errors are logged and swallowed. 
+ COMPLETED state (callback-side): + - None -> COMPLETED: reject (no prior in_progress) + - IN_PROGRESS -> COMPLETED: accept (normal completion) + - COMPLETED -> COMPLETED: reject (duplicate) """ try: if client is None: client = create_client(config) - key = _timing_key(delivery_id, downstream_repo, phase) - value = client.get(key) - if value is None: - return None - return float(value) + + if check_run_id == DISPATCH_CHECK_RUN_ID and state != CallbackState.DISPATCHED: + logger.warning( + "check_run_id '%s' is preserved for DISPATCHED state only, rejecting invalid state=%s", + DISPATCH_CHECK_RUN_ID, + state.value, + ) + return False + + key = _state_key(delivery_id, downstream_repo, check_run_id) + + current_record = get_callback_state( + config, delivery_id, downstream_repo, check_run_id, client + ) + + if state == CallbackState.DISPATCHED: + if current_record is not None: + logger.warning("rejecting duplicate DISPATCHED key=%s", key) + return False + elif state == CallbackState.IN_PROGRESS: + if current_record is not None: + logger.warning( + "rejecting replay attack IN_PROGRESS for same " + "check_run_id=%s, downstream_repo=%s, job_name=%s, run_id=%s", + check_run_id, + downstream_repo, + job_name, + run_id, + ) + return False + + elif state == CallbackState.COMPLETED: + if current_record is None: + logger.warning( + "rejecting COMPLETED without prior IN_PROGRESS " + "key=%s, downstream_repo=%s, job_name=%s, run_id=%s", + key, + downstream_repo, + job_name, + run_id, + ) + return False + if current_record.state == CallbackState.COMPLETED: + logger.warning("rejecting duplicate COMPLETED key=%s", key) + return False + if current_record.state != CallbackState.IN_PROGRESS: + logger.warning( + "rejecting abnormal state transition %s -> COMPLETED " + "key=%s, downstream_repo=%s, job_name=%s, run_id=%s", + current_record.state.value, + key, + downstream_repo, + job_name, + run_id, + ) + return False + + data: dict = { + "state": state.value, + "timestamp": timestamp, + 
"job_name": job_name, + "run_id": run_id, + } + client.setex(key, config.oot_status_ttl, json.dumps(data)) + logger.info( + "callback state set key=%s state=%s timestamp=%s job_name=%s run_id=%s", + key, + state.value, + timestamp, + job_name, + run_id, + ) + return True + except Exception: - logger.exception("redis get_timing failed phase=%s", phase.value) - return None + logger.exception("redis set_callback_state failed") + return False diff --git a/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py b/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py index 85d8e8f9bc..c763e95230 100644 --- a/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py +++ b/aws/lambda/cross_repo_ci_relay/webhook/event_handler.py @@ -8,7 +8,12 @@ from utils import gh_helper, redis_helper from utils.allowlist import AllowlistLevel, load_allowlist from utils.config import RelayConfig -from utils.misc import EventDispatchPayload, HTTPException, TimingPhase +from utils.misc import ( + CallbackState, + DISPATCH_CHECK_RUN_ID, + EventDispatchPayload, + HTTPException, +) logger = logging.getLogger(__name__) @@ -34,14 +39,15 @@ def _dispatch_one( client_payload=client_payload, ) - # Record dispatch timestamp for timing calculations (best-effort). - # Keyed by X-GitHub-Delivery (globally unique per webhook delivery) so - # retries/reruns with the same head_sha don't collide. - redis_helper.set_timing( + # Set dispatch state with timestamp to prove valid webhook occurred. + # Keyed by delivery_id + repo + DISPATCH_JOB_NAME="*" (repo-level, not job-specific). + # Timestamp is used for queue_time calculation (dispatch → in_progress). 
+ redis_helper.set_callback_state( config, - client_payload.get("delivery_id"), + client_payload["delivery_id"], downstream_repo, - TimingPhase.DISPATCH, + DISPATCH_CHECK_RUN_ID, + CallbackState.DISPATCHED, time.time(), ) diff --git a/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py b/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py index a2f1444806..af67e9c979 100644 --- a/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/webhook/lambda_function.py @@ -10,6 +10,7 @@ from . import event_handler + logging.getLogger().setLevel(logging.INFO) logger = logging.getLogger(__name__) From 820563d2ba834baf7b9c30677f5c6e251d2afb23 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Wed, 13 May 2026 01:14:22 +0000 Subject: [PATCH 3/7] Optimize code - Refactor Makefile: separate directory creation for clarity in deployment process - Enhance Cross-Repo CI Relay Callback: handle edge case for CHECK_RUN_ID and improve repo level verification in result handler - Improve error handling in Cross-Repo CI Relay Callback action --- .../cross-repo-ci-relay-callback/action.yml | 21 +++- .../cross_repo_ci_relay/callback/Makefile | 3 +- .../callback/result_handler.py | 110 ++++++++---------- .../tests/test_result_handler.py | 5 +- .../cross_repo_ci_relay/webhook/Makefile | 3 +- 5 files changed, 70 insertions(+), 72 deletions(-) diff --git a/.github/actions/cross-repo-ci-relay-callback/action.yml b/.github/actions/cross-repo-ci-relay-callback/action.yml index 3e610e7991..b362140878 100644 --- a/.github/actions/cross-repo-ci-relay-callback/action.yml +++ b/.github/actions/cross-repo-ci-relay-callback/action.yml @@ -95,6 +95,12 @@ runs: current_time = datetime.now(timezone.utc).isoformat().replace("+00:00", "Z") + # In case check_run_id is not exist (edge case), replace it + # with {run_id}-{run_attempt}, which is also unique for each job run. 
+ check_run_id = os.environ.get("CHECK_RUN_ID", "").strip() + if not check_run_id: + check_run_id = f"{os.environ['RUN_ID']}-{os.environ['RUN_ATTEMPT']}" + # Relay's original dispatch payload (event_type, delivery_id, payload) is # forwarded verbatim. Downstream-reported fields live in a sibling # `workflow` dict so the two sources stay clearly separated on the wire. @@ -106,7 +112,7 @@ runs: "url": os.environ["WORKFLOW_URL"], "run_attempt": os.environ["RUN_ATTEMPT"], "job_name": os.environ["JOB_NAME"], - "check_run_id": str(os.environ["CHECK_RUN_ID"]), + "check_run_id": check_run_id, "run_id": str(os.environ["RUN_ID"]), "started_at": None if status == "completed" else current_time, "completed_at": None if status == "in_progress" else current_time, @@ -128,8 +134,9 @@ runs: PYEOF ) + set +e HTTP_CODE=$( - curl --silent --show-error --output /tmp/relay_response.json \ + curl --silent --show-error --fail-with-body --output /tmp/relay_response.json \ --write-out "%{http_code}" \ -X POST \ -H "Content-Type: application/json" \ @@ -137,10 +144,16 @@ runs: --data "${PAYLOAD}" \ "${CALLBACK_URL%/}" ) + CURL_EXIT_CODE=$? + set -e - if [[ "${HTTP_CODE}" -lt 200 || "${HTTP_CODE}" -ge 300 ]]; then + if [[ "${CURL_EXIT_CODE}" -ne 0 ]]; then echo "::error::Callback server returned HTTP ${HTTP_CODE}." 
- exit 1 + if [[ -s /tmp/relay_response.json ]]; then + echo "Relay server error response body:" + cat /tmp/relay_response.json + fi + exit "${CURL_EXIT_CODE}" fi echo "Relay server response HTTP: ${HTTP_CODE}" diff --git a/aws/lambda/cross_repo_ci_relay/callback/Makefile b/aws/lambda/cross_repo_ci_relay/callback/Makefile index bd2ec01e87..e8185278b7 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/Makefile +++ b/aws/lambda/cross_repo_ci_relay/callback/Makefile @@ -4,7 +4,8 @@ AWS_REGION := us-east-1 FUNCTION_NAME := cross_repo_ci_callback deployment.zip: clean - mkdir -p ./deployment/{callback,utils} + mkdir -p ./deployment/callback + mkdir -p ./deployment/utils cp *.py ./deployment/callback/ && cp $(UTILS_SRC) ./deployment/utils/ pip3 install --target ./deployment -r ../requirements.txt $(PIP_FLAGS) cd deployment && zip -r ../deployment.zip . diff --git a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py index f991f63319..f5a5f80d05 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py +++ b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py @@ -33,15 +33,17 @@ def _safe_delta( return delta -def _verify_access(config: RelayConfig, verified_repo: str) -> AllowlistMap | None: - """Return the AllowlistMap when ``verified_repo`` is L2+, else None. +def _verify_access( + config: RelayConfig, verified_repo: str +) -> tuple[AllowlistMap, AllowlistLevel] | None: + """Return (AllowlistMap, repo_level) when ``verified_repo`` is L2+, else None. Raises HTTPException(429) if the per-repo rate limit is exceeded. A ``None`` return signals the caller to silently ignore the request. 
""" allowlist = load_allowlist(config) - l2_repos, _ = allowlist.get_repos_at_or_above_level(AllowlistLevel.L2) - if verified_repo not in l2_repos: + repo_level = allowlist.get_repo_level(verified_repo) + if repo_level is None or repo_level.value < AllowlistLevel.L2.value: logger.info( "verified_repo %s is not configured for L2+ features, ignoring result", verified_repo, @@ -53,7 +55,7 @@ def _verify_access(config: RelayConfig, verified_repo: str) -> AllowlistMap | No verified_repo, ) raise HTTPException(429, f"rate limit exceeded for {verified_repo}") - return allowlist + return allowlist, repo_level def _parse_callback_body(body: dict) -> tuple[str, str, str, str, str]: @@ -72,7 +74,7 @@ def _parse_callback_body(body: dict) -> tuple[str, str, str, str, str]: job_name = workflow_dict["job_name"] # Required for HUD grouping run_id = workflow_dict["run_id"] # Required for HUD grouping except (KeyError, TypeError) as exc: - logger.warning("missing required field in callback body: %s", exc) + logger.warning(f"missing required field in callback body: {exc}") raise HTTPException( 400, f"callback body missing required field: {exc}" ) from exc @@ -100,60 +102,50 @@ def _update_state_and_compute_metrics( Both metrics default to None when the required prior state is unavailable (e.g. Redis cache miss or rerun without matching prior record). 
""" + if status not in ("in_progress", "completed"): + raise HTTPException(400, f"unknown callback status: {status!r}") + ci_metrics: dict = {"queue_time": None, "execution_time": None} current_timestamp = time.time() + state = ( + CallbackState.IN_PROGRESS + if status == "in_progress" + else CallbackState.COMPLETED + ) - if status == "in_progress": - if not redis_helper.set_callback_state( - config, - delivery_id, - verified_repo, - check_run_id, - CallbackState.IN_PROGRESS, - current_timestamp, - job_name, - run_id, - ): - raise HTTPException( - 400, - f"callback rejected: invalid state transition delivery_id={delivery_id} status={status}", - ) - updated_job_record = redis_helper.get_callback_state( - config, delivery_id, verified_repo, check_run_id + if not redis_helper.set_callback_state( + config, + delivery_id, + verified_repo, + check_run_id, + state, + current_timestamp, + job_name, + run_id, + ): + raise HTTPException( + 400, + f"callback rejected: invalid state transition delivery_id={delivery_id} status={status}", ) - if updated_job_record is not None: - ci_metrics["queue_time"] = _safe_delta( - dispatch_record.timestamp, - updated_job_record.timestamp, - "queue_time", - ) - elif status == "completed": - if not redis_helper.set_callback_state( - config, - delivery_id, - verified_repo, - check_run_id, - CallbackState.COMPLETED, - current_timestamp, - job_name, - run_id, - ): - raise HTTPException( - 400, - f"callback rejected: invalid state transition delivery_id={delivery_id} status={status}", - ) - updated_job_record = redis_helper.get_callback_state( - config, delivery_id, verified_repo, check_run_id + updated_job_record = redis_helper.get_callback_state( + config, delivery_id, verified_repo, check_run_id + ) + if updated_job_record is None: + return ci_metrics + + if state == CallbackState.IN_PROGRESS: + ci_metrics["queue_time"] = _safe_delta( + dispatch_record.timestamp, + updated_job_record.timestamp, + "queue_time", ) - if updated_job_record is not 
None: + else: + if job_record is not None: ci_metrics["execution_time"] = _safe_delta( job_record.timestamp, updated_job_record.timestamp, "execution_time" ) - else: - raise HTTPException(400, f"unknown callback status: {status!r}") - return ci_metrics @@ -174,13 +166,13 @@ def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: - Duplicate callbacks are handled gracefully - State transitions follow valid lifecycle paths """ - allowlist = _verify_access(config, verified_repo) - if allowlist is None: + result = _verify_access(config, verified_repo) + if result is None: return {"ok": True, "status": "ignored"} + _, repo_level = result delivery_id, status, check_run_id, job_name, run_id = _parse_callback_body(body) - # Get dispatch state record (proves valid webhook, provides dispatch timestamp) dispatch_record = redis_helper.get_callback_state( config, delivery_id, verified_repo, DISPATCH_CHECK_RUN_ID ) @@ -191,7 +183,7 @@ def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: verified_repo, ) raise HTTPException(400, "callback rejected: no matching dispatch record") - # Get job-level state record + job_record = redis_helper.get_callback_state( config, delivery_id, verified_repo, check_run_id ) @@ -208,16 +200,6 @@ def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: job_record, ) - repo_level = allowlist.get_repo_level(verified_repo) - if repo_level is None: - logger.error( - "verified_repo %s not found in allowlist after passing L2+ check", - verified_repo, - ) - raise HTTPException( - 500, f"internal error: repo level lookup failed for {verified_repo}" - ) - trusted = { "ci_metrics": ci_metrics, "verified_repo": verified_repo, diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py index 17423db1ef..e64e29a4d0 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py +++ 
b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py @@ -3,6 +3,7 @@ from unittest.mock import MagicMock, patch from callback.result_handler import handle +from utils.allowlist import AllowlistLevel from utils.misc import CallbackState, DISPATCH_CHECK_RUN_ID, HTTPException from utils.redis_helper import CallbackStateRecord @@ -44,7 +45,7 @@ def setUp(self): self.mock_load_allowlist = self.patcher_allowlist.start() mock_map = MagicMock() mock_map.get_repos_at_or_above_level.return_value = (["org/repo"], []) - mock_map.get_repo_level.return_value = MagicMock(value="L2") + mock_map.get_repo_level.return_value = AllowlistLevel.L2 self.mock_load_allowlist.return_value = mock_map self.patcher_redis = patch("callback.result_handler.redis_helper") @@ -82,7 +83,7 @@ def tearDown(self): def test_verified_repo_not_in_l2_returns_ignored(self): mock_map = MagicMock() - mock_map.get_repos_at_or_above_level.return_value = (["other/repo"], []) + mock_map.get_repo_level.return_value = None self.mock_load_allowlist.return_value = mock_map result = handle(_cfg(), _body(), verified_repo="org/repo") diff --git a/aws/lambda/cross_repo_ci_relay/webhook/Makefile b/aws/lambda/cross_repo_ci_relay/webhook/Makefile index 75c4a565b8..adf53ae9b8 100644 --- a/aws/lambda/cross_repo_ci_relay/webhook/Makefile +++ b/aws/lambda/cross_repo_ci_relay/webhook/Makefile @@ -4,7 +4,8 @@ AWS_REGION := us-east-1 FUNCTION_NAME := cross_repo_ci_webhook deployment.zip: clean - mkdir -p ./deployment/{webhook,utils} + mkdir -p ./deployment/webhook + mkdir -p ./deployment/utils cp *.py ./deployment/webhook/ && cp $(UTILS_SRC) ./deployment/utils/ pip3 install --target ./deployment -r ../requirements.txt $(PIP_FLAGS) cd deployment && zip -r ../deployment.zip . 
From 6519fe2592a6c0bc20ffa8cc45f0cd7daebdbaef Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Fri, 15 May 2026 02:25:18 +0000 Subject: [PATCH 4/7] Remove audience setting in OIDC token generation --- .github/actions/cross-repo-ci-relay-callback/action.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/actions/cross-repo-ci-relay-callback/action.yml b/.github/actions/cross-repo-ci-relay-callback/action.yml index b362140878..109b981132 100644 --- a/.github/actions/cross-repo-ci-relay-callback/action.yml +++ b/.github/actions/cross-repo-ci-relay-callback/action.yml @@ -49,8 +49,7 @@ runs: uses: actions/github-script@v7 with: script: | - const audience = `https://github.com/${context.repo.owner}/${context.repo.repo}`; - const token = await core.getIDToken(audience); + const token = await core.getIDToken(); core.setSecret(token); core.setOutput('token', token); From c50432261d97a6995554dfcc8aad3879d4e86d06 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Tue, 19 May 2026 23:01:25 +0800 Subject: [PATCH 5/7] Fix comments 0519 --- .../actions/cross-repo-ci-relay-callback/action.yml | 2 +- .../cross_repo_ci_relay/callback/lambda_function.py | 4 +--- .../cross_repo_ci_relay/tests/test_jwt_helper.py | 12 ++++-------- aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py | 5 ++--- 4 files changed, 8 insertions(+), 15 deletions(-) diff --git a/.github/actions/cross-repo-ci-relay-callback/action.yml b/.github/actions/cross-repo-ci-relay-callback/action.yml index 109b981132..1a7ae54549 100644 --- a/.github/actions/cross-repo-ci-relay-callback/action.yml +++ b/.github/actions/cross-repo-ci-relay-callback/action.yml @@ -49,7 +49,7 @@ runs: uses: actions/github-script@v7 with: script: | - const token = await core.getIDToken(); + const token = await core.getIDToken("pytorch-cross-repo-ci-relay"); core.setSecret(token); core.setOutput('token', token); diff --git a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py 
b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py index 5f50201a43..6daaf7abf4 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py @@ -41,9 +41,7 @@ def lambda_handler(event, context): # Relay reports the OIDC-verified repo to HUD separately as # `verified_repo` so HUD has a trusted source of truth for the # caller's identity. - oidc_claims = jwt_helper.verify_oidc_token( - config, headers.get("authorization", "") - ) + oidc_claims = jwt_helper.verify_oidc_token(headers.get("authorization", "")) verified_repo = oidc_claims["repository"] result = result_handler.handle(config, body, verified_repo) diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py index e5fd3faf6c..6f26a3a5a6 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py @@ -5,10 +5,6 @@ from utils.misc import HTTPException -def _cfg(): - return MagicMock() - - class TestVerifyDownstreamIdentity(unittest.TestCase): def setUp(self): self.patcher_jwks = patch( @@ -31,20 +27,20 @@ def test_valid_token_returns_claims(self): } self.mock_decode.return_value = expected - claims = verify_oidc_token(_cfg(), "some.oidc.token") + claims = verify_oidc_token("some.oidc.token") self.assertEqual(claims, expected) def test_bearer_prefix_stripped_before_jwks_lookup(self): self.mock_decode.return_value = {"repository": "org/repo"} - verify_oidc_token(_cfg(), "Bearer some.oidc.token") + verify_oidc_token("Bearer some.oidc.token") self.mock_signing_key.assert_called_once_with("some.oidc.token") def test_empty_token_raises_401_without_jwks_lookup(self): with self.assertRaises(HTTPException) as ctx: - verify_oidc_token(_cfg(), "") + verify_oidc_token("") self.assertEqual(ctx.exception.status_code, 401) self.assertIn("Missing", ctx.exception.detail) 
self.mock_signing_key.assert_not_called() @@ -53,7 +49,7 @@ def test_jwks_lookup_failure_raises_401(self): self.mock_signing_key.side_effect = Exception("JWKS fetch failed") with self.assertRaises(HTTPException) as ctx: - verify_oidc_token(_cfg(), "bad.token") + verify_oidc_token("bad.token") self.assertEqual(ctx.exception.status_code, 401) diff --git a/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py b/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py index 088674de41..8149248016 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py +++ b/aws/lambda/cross_repo_ci_relay/utils/jwt_helper.py @@ -5,7 +5,6 @@ import logging import jwt -from utils.config import RelayConfig from utils.misc import HTTPException @@ -16,7 +15,7 @@ ) -def verify_oidc_token(config: RelayConfig, token: str) -> dict: +def verify_oidc_token(token: str) -> dict: """Decode a GitHub Actions OIDC token and return the claims. Rejects an empty/missing token up front so every call site gets a uniform @@ -36,7 +35,7 @@ def verify_oidc_token(config: RelayConfig, token: str) -> dict: signing_key.key, algorithms=["RS256"], issuer="https://token.actions.githubusercontent.com", - options={"verify_aud": False}, + audience="pytorch-cross-repo-ci-relay", ) except Exception as exc: logger.exception("OIDC token verification error") From 7daecbfa69e312c4f50a81680b5d55ee67b9a0e9 Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Thu, 21 May 2026 02:07:02 +0000 Subject: [PATCH 6/7] Fix comments in 0521 --- .../cross-repo-ci-relay-callback/action.yml | 21 +++++ .../callback/result_handler.py | 48 +++++----- .../tests/test_event_handler.py | 6 +- .../tests/test_jwt_helper.py | 18 ++++ .../tests/test_redis_helper.py | 63 ++++++------- .../tests/test_result_handler.py | 29 ++++++ .../cross_repo_ci_relay/utils/config.py | 8 +- aws/lambda/cross_repo_ci_relay/utils/hud.py | 25 +++--- .../cross_repo_ci_relay/utils/redis_helper.py | 88 +++++++++++-------- 9 files changed, 194 insertions(+), 112 deletions(-) diff 
--git a/.github/actions/cross-repo-ci-relay-callback/action.yml b/.github/actions/cross-repo-ci-relay-callback/action.yml index 1a7ae54549..36bced27b8 100644 --- a/.github/actions/cross-repo-ci-relay-callback/action.yml +++ b/.github/actions/cross-repo-ci-relay-callback/action.yml @@ -40,6 +40,21 @@ inputs: any publicly accessible URL. required: false default: '' + max-time: + description: > + Maximum time in seconds to wait for the callback HTTP request to complete. + required: false + default: 10 + max-retries: + description: > + Maximum number of retries for the callback HTTP request in case of failure. + required: false + default: 3 + retry-delay: + description: > + Delay in seconds between retries for the callback HTTP request. + required: false + default: 2 runs: using: composite @@ -70,6 +85,9 @@ runs: CHECK_RUN_ID: ${{ job.check_run_id }} RUN_ID: ${{ github.run_id }} RUN_ATTEMPT: ${{ github.run_attempt }} + MAX_TIME: ${{ inputs.max-time }} + MAX_RETRIES: ${{ inputs.max-retries }} + RETRY_DELAY: ${{ inputs.retry-delay }} run: | set -euo pipefail @@ -138,6 +156,9 @@ runs: curl --silent --show-error --fail-with-body --output /tmp/relay_response.json \ --write-out "%{http_code}" \ -X POST \ + --max-time ${MAX_TIME} \ + --retry ${MAX_RETRIES} \ + --retry-delay ${RETRY_DELAY} \ -H "Content-Type: application/json" \ -H "Authorization: Bearer ${OIDC_TOKEN}" \ --data "${PAYLOAD}" \ diff --git a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py index f5a5f80d05..edf9a46b75 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py +++ b/aws/lambda/cross_repo_ci_relay/callback/result_handler.py @@ -4,6 +4,7 @@ import time import utils.redis_helper as redis_helper +from redis.exceptions import RedisError from utils.allowlist import AllowlistLevel, AllowlistMap, load_allowlist from utils.config import RelayConfig from utils.hud import forward_to_hud @@ -113,20 +114,31 @@ def 
_update_state_and_compute_metrics( else CallbackState.COMPLETED ) - if not redis_helper.set_callback_state( - config, - delivery_id, - verified_repo, - check_run_id, - state, - current_timestamp, - job_name, - run_id, - ): + try: + redis_helper.set_callback_state( + config, + delivery_id, + verified_repo, + check_run_id, + state, + current_timestamp, + job_name, + run_id, + ) + except RedisError: raise HTTPException( - 400, - f"callback rejected: invalid state transition delivery_id={delivery_id} status={status}", + 503, "redis temporary outage: failed to persist callback state" ) + except AssertionError as e: + msg = ( + "callback rejected: invalid state transition delivery_id=%s repo=%s status=%s" + % (delivery_id, + verified_repo, + status) + ) + raise HTTPException(400, msg) from e + except Exception: + raise updated_job_record = redis_helper.get_callback_state( config, delivery_id, verified_repo, check_run_id @@ -209,15 +221,5 @@ def handle(config: RelayConfig, body: dict, verified_repo: str) -> dict: # key so HUD receives it under the expected untrusted namespace. 
untrusted = {"callback_payload": body} - try: - forward_to_hud(config, trusted, untrusted) - except HTTPException as exc: - if 400 <= exc.status_code < 500: - raise - logger.error("HUD internal error (HTTP %d): %s", exc.status_code, exc.detail) - return { - "ok": True, - "status": status, - "warning": "HUD update failed but CI run is valid", - } + forward_to_hud(config, trusted, untrusted) return {"ok": True, "status": status} diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py index 48dffd08c4..9991f6b04e 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_event_handler.py @@ -33,10 +33,11 @@ def test_ignored_action(self): {"ignored": True}, ) + @patch("webhook.event_handler.redis_helper.set_callback_state") @patch("webhook.event_handler.gh_helper.create_repository_dispatch") @patch("webhook.event_handler.gh_helper.get_repo_access_token", return_value="tok") @patch("webhook.event_handler.load_allowlist") - def test_dispatch_success(self, mock_load, _tok, mock_dispatch): + def test_dispatch_success(self, mock_load, _tok, mock_dispatch, _mock_set_state): mock_load.return_value = MagicMock( get_repos_at_or_above_level=MagicMock(return_value=(["org/a"], [])) ) @@ -44,6 +45,7 @@ def test_dispatch_success(self, mock_load, _tok, mock_dispatch): self.assertTrue(result["ok"]) mock_dispatch.assert_called_once() + @patch("webhook.event_handler.redis_helper.set_callback_state") @patch("webhook.event_handler.gh_helper.create_repository_dispatch") @patch( "webhook.event_handler.gh_helper.get_repo_access_token", @@ -51,7 +53,7 @@ def test_dispatch_success(self, mock_load, _tok, mock_dispatch): ) @patch("webhook.event_handler.load_allowlist") def test_dispatch_mints_token_per_downstream_repo( - self, mock_load, mock_get_repo_access_token, mock_dispatch + self, mock_load, mock_get_repo_access_token, mock_dispatch, _mock_set_state ): 
mock_load.return_value = MagicMock( get_repos_at_or_above_level=MagicMock( diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py index 6f26a3a5a6..a9887f67e8 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_jwt_helper.py @@ -30,6 +30,24 @@ def test_valid_token_returns_claims(self): claims = verify_oidc_token("some.oidc.token") self.assertEqual(claims, expected) + self.assertEqual(claims, expected) + self.mock_decode.assert_called_once() + self.assertEqual( + self.mock_decode.call_args.kwargs["audience"], + "pytorch-cross-repo-ci-relay", + ) + self.assertEqual( + self.mock_decode.call_args.kwargs["issuer"], + "https://token.actions.githubusercontent.com", + ) + + def test_wrong_audience_raises_401(self): + import jwt as _jwt + + self.mock_decode.side_effect = _jwt.InvalidAudienceError("bad aud") + with self.assertRaises(HTTPException) as ctx: + verify_oidc_token("token.with.wrong.aud") + self.assertEqual(ctx.exception.status_code, 401) def test_bearer_prefix_stripped_before_jwks_lookup(self): self.mock_decode.return_value = {"repository": "org/repo"} diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py index 5c1a72e5cd..91c652290c 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_redis_helper.py @@ -72,7 +72,7 @@ def setUp(self): def test_set_dispatch_state_with_timestamp(self): """Webhook sets DISPATCHED state.""" client = MagicMock() - result = set_callback_state( + set_callback_state( _cfg(), "del-123", "org/repo", @@ -81,7 +81,6 @@ def test_set_dispatch_state_with_timestamp(self): 1000.0, client=client, ) - self.assertTrue(result) client.setex.assert_called_once() def test_get_callback_state_parses_json(self): @@ -105,8 +104,8 @@ def test_get_callback_state_parses_json(self): 
self.assertEqual(record.job_name, "test-job") self.assertEqual(record.run_id, "12345") - def test_get_callback_state_returns_none_on_missing_or_error(self): - """get_callback_state returns None on missing key or Redis error.""" + def test_get_callback_state_returns_none_on_missing_key_and_on_redis_error(self): + """get_callback_state returns None on missing key and on Redis error.""" client = MagicMock() cfg = _cfg() @@ -144,16 +143,16 @@ def test_invalid_state_transitions_rejected(self): if existing else None ) - result = set_callback_state( - _cfg(), - "del-123", - "org/repo", - check_run_id, - state, - 1100.0, - client=client, - ) - self.assertFalse(result) + with self.assertRaises(AssertionError): + set_callback_state( + _cfg(), + "del-123", + "org/repo", + check_run_id, + state, + 1100.0, + client=client, + ) client.setex.assert_not_called() def test_set_completed_from_in_progress_accepts(self): @@ -167,7 +166,7 @@ def test_set_completed_from_in_progress_accepts(self): "run_id": "12345", } ) - result = set_callback_state( + set_callback_state( _cfg(), "del-123", "org/repo", @@ -178,7 +177,6 @@ def test_set_completed_from_in_progress_accepts(self): run_id="12345", client=client, ) - self.assertTrue(result) def test_set_in_progress_accepts_first_callback(self): """None → IN_PROGRESS is accepted when dispatch record exists.""" @@ -194,7 +192,7 @@ def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): with unittest.mock.patch( "utils.redis_helper.get_callback_state", side_effect=get_side_effect ): - result = set_callback_state( + set_callback_state( _cfg(), "del-123", "org/repo", @@ -205,7 +203,6 @@ def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): run_id="99999", client=client, ) - self.assertTrue(result) def test_set_non_dispatched_state_with_reserved_check_run_id_rejected(self): """Using the reserved DISPATCH_CHECK_RUN_ID for non-DISPATCHED state is rejected.""" @@ -215,20 +212,20 @@ def 
test_set_non_dispatched_state_with_reserved_check_run_id_rejected(self): for state in (CallbackState.IN_PROGRESS, CallbackState.COMPLETED): with self.subTest(state=state): client.reset_mock() - result = set_callback_state( - cfg, - "del-123", - "org/repo", - DISPATCH_CHECK_RUN_ID, - state, - 1010.0, - client=client, - ) - self.assertFalse(result) + with self.assertRaises(AssertionError): + set_callback_state( + cfg, + "del-123", + "org/repo", + DISPATCH_CHECK_RUN_ID, + state, + 1010.0, + client=client, + ) client.setex.assert_not_called() - def test_set_callback_state_redis_exception_returns_false(self): - """Redis write failure is caught and returns False.""" + def test_set_callback_state_redis_exception_raises(self): + """Redis write failure is re-raised as RedisError.""" def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): if check_run_id_arg == DISPATCH_CHECK_RUN_ID: @@ -243,8 +240,8 @@ def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): with unittest.mock.patch( "utils.redis_helper.get_callback_state", side_effect=get_side_effect - ): - result = set_callback_state( + ), self.assertRaises(redis_lib.exceptions.RedisError): + set_callback_state( cfg, "del-123", "org/repo", @@ -256,8 +253,6 @@ def get_side_effect(cfg, delivery_id, repo, check_run_id_arg, client=None): client=client, ) - self.assertFalse(result) - class TestRateLimit(unittest.TestCase): def setUp(self): diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py index e64e29a4d0..506670796f 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py @@ -187,6 +187,35 @@ def test_rate_limit_exceeded_returns_429(self): self.assertEqual(ctx.exception.status_code, 429) self.assertFalse(self.mock_hud.called) + # --- Redis outage during callback is tolerated --- + + def 
test_redis_error_fetching_dispatch_record_rejected(self): + """Redis error on dispatch lookup returns None, causing 400 rejection.""" + self.mock_redis.get_callback_state.side_effect = [ + None, # dispatch lookup returns None (get_callback_state catches RedisError) + ] + + with self.assertRaises(HTTPException) as ctx: + handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + self.assertEqual(ctx.exception.status_code, 400) + + def test_redis_error_fetching_job_record_proceeds(self): + """Redis error on job record lookup returns None; callback proceeds.""" + dispatch_record = CallbackStateRecord( + CallbackState.DISPATCHED, 1000.0, "dispatch-job", 11111 + ) + # Three calls: dispatch lookup, job record lookup, re-read after set. + self.mock_redis.get_callback_state.side_effect = [ + dispatch_record, + None, # job record lookup returns None (get_callback_state catches RedisError) + None, # updated_job_record re-read → early return with empty metrics + ] + + handle(_cfg(), _body(status="completed"), verified_repo="org/repo") + + _, trusted_arg, _ = self.mock_hud.call_args[0] + self.assertIsNone(trusted_arg["ci_metrics"]["execution_time"]) + if __name__ == "__main__": unittest.main() diff --git a/aws/lambda/cross_repo_ci_relay/utils/config.py b/aws/lambda/cross_repo_ci_relay/utils/config.py index f8920de605..707c6ed8a4 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/config.py +++ b/aws/lambda/cross_repo_ci_relay/utils/config.py @@ -146,6 +146,12 @@ def from_env(cls) -> "RelayConfig": except ValueError: raise RuntimeError("RATE_LIMIT_PER_MIN must be a positive integer") + hud_api_url = os.getenv("HUD_API_URL", "") + if hud_api_url and not hud_api_url.startswith("https://"): + raise RuntimeError( + "HUD_API_URL must use https:// to protect the bot key in transit" + ) + return cls( github_app_id=_require("GITHUB_APP_ID"), github_app_secret=github_app_secret, @@ -156,7 +162,7 @@ def from_env(cls) -> "RelayConfig": redis_login=redis_login, 
allowlist_ttl_seconds=allowlist_ttl_seconds, max_dispatch_workers=int(os.getenv("MAX_DISPATCH_WORKERS", "32")), - hud_api_url=os.getenv("HUD_API_URL", ""), + hud_api_url=hud_api_url, hud_bot_key=hud_bot_key, oot_status_ttl=oot_status_ttl, hud_max_retries=hud_max_retries, diff --git a/aws/lambda/cross_repo_ci_relay/utils/hud.py b/aws/lambda/cross_repo_ci_relay/utils/hud.py index 1dc1695651..40ba67ec8c 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/hud.py +++ b/aws/lambda/cross_repo_ci_relay/utils/hud.py @@ -52,7 +52,8 @@ def forward_to_hud(config: RelayConfig, trusted: dict, untrusted: dict) -> None: ) last_exception = None - for attempt in range(config.hud_max_retries + 1): + total_attempts = config.hud_max_retries + 1 + for attempt in range(total_attempts): try: with urllib.request.urlopen(req, timeout=10) as resp: logger.info("HUD forward succeeded status=%d", resp.status) @@ -60,36 +61,34 @@ def forward_to_hud(config: RelayConfig, trusted: dict, untrusted: dict) -> None: except urllib.error.HTTPError as exc: if 400 <= exc.code < 500: detail = f"HUD rejected callback: HTTP {exc.code}: {exc.reason}" - logger.warning("HUD forward failed (client error): %s", detail) + logger.error("HUD forward failed (client error): %s", detail) raise HTTPException(exc.code, detail) from exc last_exception = exc - logger.warning( + logger.debug( "HUD forward failed (server error, attempt %d/%d): HTTP %d %s", attempt + 1, - config.hud_max_retries + 1, + total_attempts, exc.code, exc.reason, ) except urllib.error.URLError as exc: last_exception = exc - logger.warning( + logger.debug( "HUD forward failed (unreachable, attempt %d/%d): %s", attempt + 1, - config.hud_max_retries + 1, + total_attempts, exc.reason, ) # If we have more retries remaining, wait with exponential backoff if attempt < config.hud_max_retries: - delay = 2**attempt - logger.info("Retrying HUD forward in %d seconds...", delay) - time.sleep(delay) + time.sleep(2**attempt) - # All retries exhausted, raise the last 
exception + # All retries exhausted if isinstance(last_exception, urllib.error.HTTPError): - logger.exception( + logger.error( "HUD forward failed after %d attempts: HTTP %d %s", - config.hud_max_retries + 1, + total_attempts, last_exception.code, last_exception.reason, ) @@ -101,7 +100,7 @@ def forward_to_hud(config: RelayConfig, trusted: dict, untrusted: dict) -> None: f"{60 // config.rate_limit_per_min} seconds or ignore this failure.", ) from last_exception else: - logger.exception( + logger.error( "HUD forward failed after %d attempts: %s", config.hud_max_retries + 1, last_exception.reason, diff --git a/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py index 844e48710d..74ac6677f8 100644 --- a/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py +++ b/aws/lambda/cross_repo_ci_relay/utils/redis_helper.py @@ -206,9 +206,11 @@ def get_callback_state( job_name=data["job_name"], run_id=data["run_id"], ) + except RedisError: + logger.exception("redis temporary outage or unreachable") except Exception: logger.exception("redis get_callback_state failed") - return None + return None def set_callback_state( @@ -221,8 +223,8 @@ def set_callback_state( job_name: str | None = None, run_id: int | None = None, client: redis_lib.Redis | None = None, -) -> bool: - """Set callback state with timestamp in Redis. Returns True on success, False on error. +) -> None: + """Set callback state with timestamp in Redis. 
State transition validation: @@ -239,17 +241,19 @@ def set_callback_state( - IN_PROGRESS -> COMPLETED: accept (normal completion) - COMPLETED -> COMPLETED: reject (duplicate) """ + error_msg = "" try: if client is None: client = create_client(config) if check_run_id == DISPATCH_CHECK_RUN_ID and state != CallbackState.DISPATCHED: - logger.warning( - "check_run_id '%s' is preserved for DISPATCHED state only, rejecting invalid state=%s", - DISPATCH_CHECK_RUN_ID, - state.value, + error_msg = ( + "check_run_id '%s' is preserved for DISPATCHED state only, rejecting invalid state=%s" + % ( + DISPATCH_CHECK_RUN_ID, + state.value, + ) ) - return False key = _state_key(delivery_id, downstream_repo, check_run_id) @@ -259,45 +263,50 @@ def set_callback_state( if state == CallbackState.DISPATCHED: if current_record is not None: - logger.warning("rejecting duplicate DISPATCHED key=%s", key) - return False + error_msg = "rejecting duplicate DISPATCHED key=%s" % key elif state == CallbackState.IN_PROGRESS: if current_record is not None: - logger.warning( + error_msg = ( "rejecting replay attack IN_PROGRESS for same " - "check_run_id=%s, downstream_repo=%s, job_name=%s, run_id=%s", - check_run_id, - downstream_repo, - job_name, - run_id, + "check_run_id=%s, downstream_repo=%s, job_name=%s, run_id=%s" + % ( + check_run_id, + downstream_repo, + job_name, + run_id, + ) ) - return False elif state == CallbackState.COMPLETED: if current_record is None: - logger.warning( + error_msg = ( "rejecting COMPLETED without prior IN_PROGRESS " - "key=%s, downstream_repo=%s, job_name=%s, run_id=%s", - key, - downstream_repo, - job_name, - run_id, + "key=%s, downstream_repo=%s, job_name=%s, run_id=%s" + % ( + key, + downstream_repo, + job_name, + run_id, + ) ) - return False - if current_record.state == CallbackState.COMPLETED: - logger.warning("rejecting duplicate COMPLETED key=%s", key) - return False - if current_record.state != CallbackState.IN_PROGRESS: - logger.warning( + elif 
current_record.state == CallbackState.COMPLETED: + error_msg = "rejecting duplicate COMPLETED key=%s" % key + elif current_record.state != CallbackState.IN_PROGRESS: + error_msg = ( "rejecting abnormal state transition %s -> COMPLETED " - "key=%s, downstream_repo=%s, job_name=%s, run_id=%s", - current_record.state.value, - key, - downstream_repo, - job_name, - run_id, + "key=%s, downstream_repo=%s, job_name=%s, run_id=%s" + % ( + current_record.state.value, + key, + downstream_repo, + job_name, + run_id, + ) ) - return False + + if error_msg: + logger.warning(error_msg) + raise AssertionError(error_msg) data: dict = { "state": state.value, @@ -314,8 +323,9 @@ def set_callback_state( job_name, run_id, ) - return True - + except RedisError: + logger.exception("set_callback_state: redis is temporary outage or unreachable") + raise except Exception: logger.exception("redis set_callback_state failed") - return False + raise From d67b442adfa686b504088971737a832ff38aeecd Mon Sep 17 00:00:00 2001 From: KarhouTam Date: Fri, 22 May 2026 02:07:16 +0000 Subject: [PATCH 7/7] Fix codes - Unify the name from "result*" to "callback*" - Fix lints --- .github/workflows/_lambda-do-release-runners.yml | 2 +- aws/lambda/cross_repo_ci_relay/README.md | 6 +++--- .../{result_handler.py => callback_handler.py} | 4 +--- .../cross_repo_ci_relay/callback/lambda_function.py | 8 ++++---- aws/lambda/cross_repo_ci_relay/local_server.py | 4 ++-- ...st_result_handler.py => test_callback_handler.py} | 12 ++++++------ ...ler_lambda.py => test_callback_handler_lambda.py} | 10 +++++----- 7 files changed, 22 insertions(+), 24 deletions(-) rename aws/lambda/cross_repo_ci_relay/callback/{result_handler.py => callback_handler.py} (99%) rename aws/lambda/cross_repo_ci_relay/tests/{test_result_handler.py => test_callback_handler.py} (95%) rename aws/lambda/cross_repo_ci_relay/tests/{test_result_handler_lambda.py => test_callback_handler_lambda.py} (93%) diff --git 
a/.github/workflows/_lambda-do-release-runners.yml b/.github/workflows/_lambda-do-release-runners.yml index 0a253308d5..02500dd143 100644 --- a/.github/workflows/_lambda-do-release-runners.yml +++ b/.github/workflows/_lambda-do-release-runners.yml @@ -93,7 +93,7 @@ jobs: { dir-name: 'keep-going-call-log-classifier', zip-name: 'keep-going-call-log-classifier' }, { dir-name: 'buildkite-webhook-handler', zip-name: 'buildkite-webhook-handler' }, { dir-name: 'benchmark_regression_summary_report', zip-name: 'benchmark-regression-summary-report' }, - { dir-name: 'cross_repo_ci_relay/result', zip-name: 'cross-repo-ci-result' }, + { dir-name: 'cross_repo_ci_relay/callback', zip-name: 'cross-repo-ci-callback' }, { dir-name: 'cross_repo_ci_relay/webhook', zip-name: 'cross-repo-ci-webhook' }, ] name: Upload Release for ${{ matrix.dir-name }} lambda diff --git a/aws/lambda/cross_repo_ci_relay/README.md b/aws/lambda/cross_repo_ci_relay/README.md index 58af886a44..d7d8187960 100644 --- a/aws/lambda/cross_repo_ci_relay/README.md +++ b/aws/lambda/cross_repo_ci_relay/README.md @@ -188,7 +188,7 @@ deployment/ │ └── event_handler.py ├── callback/ │ ├── lambda_function.py -│ └── result_handler.py +│ └── callback_handler.py └── utils/ └── ... 
``` @@ -264,10 +264,10 @@ make clean smee --url https://smee.io/ --path /github/webhook --port 8000 ``` - CLI to forward GitHub result callbacks to localhost (set this URL as `callback-url` in the downstream workflow): + CLI to forward GitHub callback callbacks to localhost (set this URL as `callback-url` in the downstream workflow): ```bash npm install -g smee-client - smee --url https://smee.io/ --path /github/result --port 8000 + smee --url https://smee.io/ --path /github/callback --port 8000 ``` #### Remote diff --git a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py b/aws/lambda/cross_repo_ci_relay/callback/callback_handler.py similarity index 99% rename from aws/lambda/cross_repo_ci_relay/callback/result_handler.py rename to aws/lambda/cross_repo_ci_relay/callback/callback_handler.py index edf9a46b75..34ef63ced1 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/result_handler.py +++ b/aws/lambda/cross_repo_ci_relay/callback/callback_handler.py @@ -132,9 +132,7 @@ def _update_state_and_compute_metrics( except AssertionError as e: msg = ( "callback rejected: invalid state transition delivery_id=%s repo=%s status=%s" - % (delivery_id, - verified_repo, - status) + % (delivery_id, verified_repo, status) ) raise HTTPException(400, msg) from e except Exception: diff --git a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py index 6daaf7abf4..e573e4eeb9 100644 --- a/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py +++ b/aws/lambda/cross_repo_ci_relay/callback/lambda_function.py @@ -7,7 +7,7 @@ from utils.config import get_config from utils.misc import HTTPException, JSON_HEADERS, parse_lambda_event -from . import result_handler +from . 
import callback_handler logging.getLogger().setLevel(logging.INFO) @@ -19,8 +19,8 @@ def lambda_handler(event, context): logger.info("request method=%s path=%s", method, path) - if method != "POST" or path != "/github/result": - if path == "/github/result": + if method != "POST" or path != "/github/callback": + if path == "/github/callback": return { "statusCode": 405, "headers": JSON_HEADERS, @@ -44,7 +44,7 @@ def lambda_handler(event, context): oidc_claims = jwt_helper.verify_oidc_token(headers.get("authorization", "")) verified_repo = oidc_claims["repository"] - result = result_handler.handle(config, body, verified_repo) + result = callback_handler.handle(config, body, verified_repo) return {"statusCode": 200, "headers": JSON_HEADERS, "body": json.dumps(result)} except json.JSONDecodeError: diff --git a/aws/lambda/cross_repo_ci_relay/local_server.py b/aws/lambda/cross_repo_ci_relay/local_server.py index 483e3daee6..2c0edeba60 100644 --- a/aws/lambda/cross_repo_ci_relay/local_server.py +++ b/aws/lambda/cross_repo_ci_relay/local_server.py @@ -35,8 +35,8 @@ async def github_webhook(req: Request): ) -@relay_router.post("/github/result") -async def github_result(req: Request): +@relay_router.post("/github/callback") +async def github_callback(req: Request): body = await req.body() event = { "requestContext": { diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler.py similarity index 95% rename from aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py rename to aws/lambda/cross_repo_ci_relay/tests/test_callback_handler.py index 506670796f..e2ff10bb82 100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler.py @@ -2,7 +2,7 @@ import unittest from unittest.mock import MagicMock, patch -from callback.result_handler import handle +from callback.callback_handler import handle from utils.allowlist import 
AllowlistLevel from utils.misc import CallbackState, DISPATCH_CHECK_RUN_ID, HTTPException from utils.redis_helper import CallbackStateRecord @@ -39,16 +39,16 @@ def _body(status="completed", job_name="default", check_run_id="12345", run_id=" } -class TestResultHandler(unittest.TestCase): +class TestCallbackHandler(unittest.TestCase): def setUp(self): - self.patcher_allowlist = patch("callback.result_handler.load_allowlist") + self.patcher_allowlist = patch("callback.callback_handler.load_allowlist") self.mock_load_allowlist = self.patcher_allowlist.start() mock_map = MagicMock() mock_map.get_repos_at_or_above_level.return_value = (["org/repo"], []) mock_map.get_repo_level.return_value = AllowlistLevel.L2 self.mock_load_allowlist.return_value = mock_map - self.patcher_redis = patch("callback.result_handler.redis_helper") + self.patcher_redis = patch("callback.callback_handler.redis_helper") self.mock_redis = self.patcher_redis.start() self.mock_redis.create_client.return_value = MagicMock() @@ -66,11 +66,11 @@ def default_get_state(cfg, delivery_id, repo, check_run_id_arg, client=None): self.mock_redis.get_callback_state.side_effect = default_get_state - self.patcher_rate_limit = patch("callback.result_handler.check_rate_limit") + self.patcher_rate_limit = patch("callback.callback_handler.check_rate_limit") self.mock_check_rate_limit = self.patcher_rate_limit.start() self.mock_check_rate_limit.return_value = True - self.patcher_hud = patch("callback.result_handler.forward_to_hud") + self.patcher_hud = patch("callback.callback_handler.forward_to_hud") self.mock_hud = self.patcher_hud.start() def tearDown(self): diff --git a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler_lambda.py b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler_lambda.py similarity index 93% rename from aws/lambda/cross_repo_ci_relay/tests/test_result_handler_lambda.py rename to aws/lambda/cross_repo_ci_relay/tests/test_callback_handler_lambda.py index 3ee85e3548..9bb9f592c2 
100644 --- a/aws/lambda/cross_repo_ci_relay/tests/test_result_handler_lambda.py +++ b/aws/lambda/cross_repo_ci_relay/tests/test_callback_handler_lambda.py @@ -10,7 +10,7 @@ def _event( *, method="POST", - path="/github/result", + path="/github/callback", body=None, headers=None, base64_encoded=False, @@ -31,7 +31,7 @@ def _event( } -class TestResultLambdaHandler(unittest.TestCase): +class TestCallbackLambdaHandler(unittest.TestCase): def setUp(self): import utils.config @@ -65,7 +65,7 @@ def test_oidc_failure_returns_401(self, mock_oidc, mock_get_config): @patch("callback.lambda_function.get_config") @patch("callback.lambda_function.jwt_helper.verify_oidc_token") - @patch("callback.lambda_function.result_handler.handle") + @patch("callback.lambda_function.callback_handler.handle") def test_happy_path_forwards_body_and_verified_repo( self, mock_handle, mock_oidc, mock_get_config ): @@ -85,7 +85,7 @@ def test_happy_path_forwards_body_and_verified_repo( @patch("callback.lambda_function.get_config") @patch("callback.lambda_function.jwt_helper.verify_oidc_token") - @patch("callback.lambda_function.result_handler.handle") + @patch("callback.lambda_function.callback_handler.handle") def test_hud_error_from_handler_is_forwarded( self, mock_handle, mock_oidc, mock_get_config ): @@ -100,7 +100,7 @@ def test_hud_error_from_handler_is_forwarded( @patch("callback.lambda_function.get_config") @patch("callback.lambda_function.jwt_helper.verify_oidc_token") - @patch("callback.lambda_function.result_handler.handle") + @patch("callback.lambda_function.callback_handler.handle") def test_unhandled_exception_returns_500( self, mock_handle, mock_oidc, mock_get_config ):