Add automatic Lightning SCAFFOLD support #4838

Open
holgerroth wants to merge 9 commits into NVIDIA:main from holgerroth:codex/lightning-scaffold

holgerroth commented Jun 29, 2026

Collaborator

Summary

  • automatically detect SCAFFOLD server metadata in nvflare.client.lightning.patch()
  • apply PTScaffoldHelper around Lightning optimizer steps and return SCAFFOLD_CTRL_DIFF
  • preserve local controls across rounds while keeping the FedAvg path unchanged
  • add fail-fast diagnostics for unsupported optimization modes and malformed controls
  • update the Hello Lightning example, recipe documentation, and integration coverage

Motivation

ScaffoldRecipe requires every client to apply control-variate corrections during local training and return a control delta. Raw PyTorch clients already do this explicitly with PTScaffoldHelper, but the Lightning patch() callback previously only handled model receive/send. As a result, changing a patched Lightning job from FedAvgRecipe to ScaffoldRecipe left out the client-side algorithm and caused server aggregation to fail.

This change closes that contract gap without changing the public patch() API. Patched Lightning clients now activate SCAFFOLD automatically when the received FLModel contains global controls.

Availability

This feature targets NVFlare 2.9.0. The Hello Lightning requirement pins nvflare~=2.9.0rc, and the example documentation directs main-branch users to install NVFlare from this repository until the 2.9 package is published. The feature is intentionally not listed in the 2.8 release notes.

Implementation

  • Keep FLCallback algorithm-neutral through a private algorithm-handler manager.
  • Lazy-load and instantiate SCAFFOLD only after SCAFFOLD_CTRL_GLOBAL metadata is received; FedAvg never creates a SCAFFOLD handler.
  • Initialize PTScaffoldHelper once per SCAFFOLD client and preserve its local controls across rounds.
  • Store the global-weight snapshot as detached CPU tensors instead of deep-copying the complete Lightning module.
  • Track the active training round so mid-fit validation collects metrics without receiving or reloading the global model.
  • Retain the module associated with validation input and apply that input to a different fit module without another receive().
  • Capture the actual optimizer learning rate in on_before_optimizer_step and apply corrections after completed optimizer steps, including with gradient accumulation.
  • Permit finite zero-LR warmup steps while requiring positive total LR exposure for a completed round.
  • Apply SCAFFOLD correction in place to trainable parameters only. Buffers such as BatchNorm running statistics remain ordinary aggregated model state and are omitted from control deltas.
  • Reject collisions between user metadata and automatically generated algorithm metadata instead of overwriting SCAFFOLD_CTRL_DIFF.
  • Preserve explicit user step metadata and the existing FedAvg step-count fallback.
  • Reject manual optimization, scaler-backed mixed precision, multiple optimizers, unequal parameter-group learning rates, negative/non-finite learning rates, malformed controls, missing later-round controls, and rounds with no completed optimizer step or no positive LR exposure.
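The per-step correction and the end-of-round control update listed above follow the standard SCAFFOLD rule. A minimal sketch on plain Python floats, assuming one scalar per parameter name. The function names `apply_correction` and `control_delta` are hypothetical and are not the `PTScaffoldHelper` API; the inline `w -= lr * grad` line stands in for a real optimizer step.

```python
from typing import Dict

Params = Dict[str, float]


def apply_correction(weights: Params, lr: float, c_global: Params, c_local: Params) -> None:
    """After an optimizer step, shift each trainable parameter in place by -lr * (c_global - c_local)."""
    for name in weights:
        weights[name] -= lr * (c_global[name] - c_local[name])


def control_delta(global_w: Params, local_w: Params, c_global: Params, c_local: Params,
                  steps: int, lr_sum: float) -> Params:
    """Return new_c_local - c_local, with new_c_local = c_local - c_global + (x - y) / (steps * avg_lr)."""
    if steps <= 0 or lr_sum <= 0.0:
        raise ValueError("a round needs a completed optimizer step and positive total LR exposure")
    delta = {}
    for name in local_w:
        # steps * avg_lr equals the sum of the per-step learning rates.
        new_c = c_local[name] - c_global[name] + (global_w[name] - local_w[name]) / lr_sum
        delta[name] = new_c - c_local[name]
    return delta


# Two local steps on a single parameter "w", starting from the global weight.
global_w = {"w": 1.0}
local_w = dict(global_w)
c_global, c_local = {"w": 0.5}, {"w": 0.0}
steps, lr_sum = 0, 0.0
for lr, grad in [(0.1, 2.0), (0.1, 1.0)]:
    local_w["w"] -= lr * grad                      # stand-in for optimizer.step()
    apply_correction(local_w, lr, c_global, c_local)
    steps += 1
    lr_sum += lr                                   # zero-LR warmup steps would add nothing here
delta = control_delta(global_w, local_w, c_global, c_local, steps, lr_sum)
```

With a zero initial local control, the new local control in this toy run equals the mean of the two local gradients, which is the quantity SCAFFOLD uses to estimate client drift.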

The automatic path supports Lightning automatic optimization with one optimizer and precision="32-true" or precision="bf16-mixed". Raw PyTorch loops, manual Lightning optimization, and scaler-backed mixed precision must use an explicit receive/train/send loop with PTScaffoldHelper; patch() is not used for that path. FedProx support is intentionally out of scope.
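The support boundary above can be sketched as a pair of fail-fast checks. This is an illustration only: `check_supported` and `shared_learning_rate` are hypothetical names, the inputs are plain values rather than Lightning objects, and the only assumption about the optimizer is that its parameter groups are mappings with an `"lr"` entry.

```python
import math
from typing import Mapping, Sequence

SUPPORTED_PRECISIONS = ("32-true", "bf16-mixed")


def check_supported(automatic_optimization: bool, num_optimizers: int, precision: str) -> None:
    """Raise early, with an actionable message, for setups outside the automatic path."""
    if not automatic_optimization:
        raise RuntimeError("manual optimization: use an explicit receive/train/send loop instead")
    if num_optimizers != 1:
        raise RuntimeError(f"expected exactly one optimizer, got {num_optimizers}")
    if precision not in SUPPORTED_PRECISIONS:
        raise RuntimeError(f"precision {precision!r} is unsupported; use one of {SUPPORTED_PRECISIONS}")


def shared_learning_rate(param_groups: Sequence[Mapping[str, float]]) -> float:
    """Return the one LR shared by all parameter groups; zero is allowed for warmup steps."""
    rates = [float(group["lr"]) for group in param_groups]
    if not rates:
        raise RuntimeError("optimizer has no parameter groups")
    for lr in rates:
        if not math.isfinite(lr) or lr < 0.0:
            raise RuntimeError(f"learning rate must be finite and non-negative, got {lr}")
    if any(lr != rates[0] for lr in rates):
        raise RuntimeError(f"parameter groups must share one learning rate, got {rates}")
    return rates[0]


check_supported(automatic_optimization=True, num_optimizers=1, precision="bf16-mixed")
lr = shared_learning_rate([{"lr": 0.05}, {"lr": 0.05}])
```

Checking each group before comparing them keeps the error for a non-finite rate distinct from the error for unequal rates.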

Validation

  • combined Lightning, PyTorch SCAFFOLD, controller, and recipe unit suites: 151 passed, 4 optional-dependency skips
  • two-client, two-round fit-only BatchNorm H100 smoke test: passed, valid checkpoint, zero errors
  • unchanged two-client, two-round Hello Lightning SCAFFOLD H100 job: passed, valid checkpoint, zero errors
  • ./runtest.sh -s --skip-install: passed
  • ./build_doc.sh --html --skip-api: passed (existing documentation warnings only)
  • GitHub pre-merge matrix: all checks passing, including unit tests on Python 3.10-3.14 and Ubuntu 22.04/24.04

H100 end-to-end evaluation

The evaluation harness and results are not included in this PR.

Commit d11f4198d was tested on NVIDIA H100 NVL GPUs with the advanced CIFAR-10 example setup:

  • same algorithm-agnostic patched Lightning client for FedAvg and SCAFFOLD
  • no client-provided NUM_STEPS_CURRENT_ROUND, exercising automatic SCAFFOLD step accounting
  • ModerateCNN, 8 clients, Dirichlet alpha 0.1, seed 0
  • 50 rounds, 4 local epochs, SGD with momentum, cosine LR schedule
  • batch size 64, initial LR 0.05

| Algorithm | Final test accuracy | Runtime | Successful aggregations | Tracebacks/errors |
|---|---|---|---|---|
| FedAvg | 80.59% | 607.5 s | 50/50 | 0 |
| SCAFFOLD | 82.74% | 710.1 s | 50/50 | 0 |

SCAFFOLD improved final accuracy by 2.15 percentage points with 16.9% additional runtime. Successful aggregation in every SCAFFOLD round verifies that all eight patched Lightning clients returned the required parameter-only control delta and per-round step metadata automatically. The FedAvg run verifies that the generic manager preserves behavior without loading a SCAFFOLD handler.
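For reference, the two headline figures can be re-derived from the table values. This snippet only repeats the arithmetic and does not touch the evaluation harness; the variable names are illustrative.

```python
fedavg_acc, scaffold_acc = 80.59, 82.74      # final test accuracy, percent
fedavg_time, scaffold_time = 607.5, 710.1    # runtime, seconds

accuracy_gain_pp = scaffold_acc - fedavg_acc
runtime_overhead_pct = (scaffold_time / fedavg_time - 1.0) * 100.0

print(f"accuracy gain: {accuracy_gain_pp:.2f} percentage points")   # prints 2.15
print(f"runtime overhead: {runtime_overhead_pct:.1f}%")             # prints 16.9%
```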

holgerroth changed the title from "[codex] Add automatic Lightning SCAFFOLD support" to "Add automatic Lightning SCAFFOLD support" Jun 29, 2026

codecov-commenter commented Jun 29, 2026


Codecov Report

❌ Patch coverage is 96.31148% with 9 lines in your changes missing coverage. Please review.
✅ Project coverage is 56.69%. Comparing base (a6c83e2) to head (04ed1ab).

| Files with missing lines | Patch % | Lines |
|---|---|---|
| nvflare/app_opt/lightning/scaffold.py | 95.38% | 6 Missing ⚠️ |
| nvflare/app_opt/pt/scaffold.py | 88.23% | 2 Missing ⚠️ |
| nvflare/app_opt/lightning/api.py | 98.43% | 1 Missing ⚠️ |

Additional details and impacted files
@@            Coverage Diff             @@
##             main    #4838      +/-   ##
==========================================
+ Coverage   56.52%   56.69%   +0.16%     
==========================================
  Files         969      971       +2     
  Lines       92255    92457     +202     
==========================================
+ Hits        52151    52417     +266     
+ Misses      40104    40040      -64     

| Flag | Coverage Δ |
|---|---|
| unit-tests | 56.69% <96.31%> (+0.16%) ⬆️ |


Comment thread nvflare/app_opt/lightning/callbacks.py Outdated
Comment thread nvflare/app_opt/lightning/api.py Outdated
@holgerroth holgerroth force-pushed the codex/lightning-scaffold branch from ebf1743 to 5c31e8b Compare June 30, 2026 01:13
@holgerroth holgerroth marked this pull request as ready for review June 30, 2026 03:49

greptile-apps Bot commented Jun 30, 2026

Contributor

Greptile Summary

This PR closes the gap between FedAvgRecipe and ScaffoldRecipe for patched Lightning clients by automatically detecting SCAFFOLD global controls and applying PTScaffoldHelper corrections around Lightning's optimizer steps, returning the required control delta without any public API change. It also restricts SCAFFOLD corrections to trainable named parameters only (fixing the prior state-dict round-trip in PTScaffoldHelper), adds a _StateDictSnapshot to store the global-weight snapshot as detached CPU tensors, and resolves the double-receive issue for train_with_evaluation via a _pending_train_model pattern.

  • New Lightning SCAFFOLD path (algorithm.py, scaffold.py): lazily instantiates _ScaffoldHandler on first SCAFFOLD round, preserves local controls across rounds, and validates/rejects unsupported modes (manual optimization, gradient scaler, multiple optimizers, unequal LRs) with actionable errors.
  • PTScaffoldHelper refactor (pt/scaffold.py): model_update and terms_update now iterate named_parameters() with in-place add_ instead of a full state_dict round-trip; c_delta_para is a fresh dict of parameter-only numpy arrays, leaving buffer controls untouched at the server.
  • Test coverage (api_test.py, scaffold_test.py, fedavg_test.py): ~650 new lines covering gradient accumulation, train-with-evaluation, mid-fit validation, CPU snapshots, buffer isolation, and a server-side parameter-only-delta acceptance test.

Confidence Score: 5/5

Safe to merge; the SCAFFOLD and FedAvg paths are well-separated, the PTScaffoldHelper changes are backward-compatible, and the new Lightning hooks are no-ops when SCAFFOLD controls are absent.

The implementation is algorithmically correct: control corrections apply only to trainable parameters, local controls accumulate across rounds, the pending-train-model pattern eliminates the double-receive race for train_with_evaluation, and all fail-fast checks are exercised by the new unit suite. The two observations flagged are a defensive-copy omission in _StateDictSnapshot.state_dict() and an undocumented server-protocol assumption (buffer keys always present in global controls) — neither affects current correctness.

nvflare/app_opt/lightning/scaffold.py — _StateDictSnapshot.state_dict() and _validate_controls are the only areas worth a second look before future changes to the server-side control protocol.
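The defensive-copy observation can be illustrated without Lightning or PyTorch. In this sketch both classes are hypothetical stand-ins, not the PR's `_StateDictSnapshot`, and lists of floats stand in for tensors.

```python
from typing import Dict, List


class SnapshotByReference:
    """Hypothetical snapshot whose state_dict() exposes the internal mapping."""

    def __init__(self, state: Dict[str, List[float]]) -> None:
        self._state = dict(state)

    def state_dict(self) -> Dict[str, List[float]]:
        return self._state  # a caller can add or drop keys in the snapshot itself


class SnapshotWithCopy(SnapshotByReference):
    """Same snapshot, but state_dict() returns a shallow copy of the mapping."""

    def state_dict(self) -> Dict[str, List[float]]:
        return dict(self._state)  # protects the key set, not the stored values


by_ref = SnapshotByReference({"layer.weight": [1.0, 2.0]})
by_ref.state_dict().pop("layer.weight")      # silently empties the snapshot

with_copy = SnapshotWithCopy({"layer.weight": [1.0, 2.0]})
with_copy.state_dict().pop("layer.weight")   # only the returned copy changes
```

A shallow copy is cheap because it duplicates the mapping only; the stored values are still shared, so in-place edits to them would remain visible through the snapshot.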

Important Files Changed

| Filename | Overview |
|---|---|
| nvflare/app_opt/lightning/scaffold.py | New file: implements _ScaffoldHandler and _StateDictSnapshot; state_dict() returns internal dict by reference rather than a copy |
| nvflare/app_opt/lightning/algorithm.py | New file: _AlgorithmHandlerManager lazily instantiates _ScaffoldHandler on first SCAFFOLD round; FedAvg path unchanged |
| nvflare/app_opt/lightning/api.py | Integrates algorithm handler into FLCallback; adds on_before_optimizer_step and on_train_batch_end hooks; fixes double receive() with _pending_train_model pattern |
| nvflare/app_opt/pt/scaffold.py | model_update and terms_update now operate only on named_parameters with requires_grad=True; buffers excluded from corrections; no load_state_dict round-trip on model_update |
| tests/unit_test/app_opt/lightning/api_test.py | Adds ~580 lines of new tests covering SCAFFOLD handler, gradient accumulation, train_with_evaluation, mid-fit validation, metadata conflict detection, and real Lightning integration |
| tests/unit_test/app_opt/pt/scaffold_test.py | Adds tests for CPU-snapshot acceptance, trainable-parameter-only corrections, and buffer aggregation handling in PTScaffoldHelper |
| tests/unit_test/app_common/workflow/fedavg_test.py | Adds server-side test verifying parameter-only delta leaves integer buffer controls unchanged in _global_ctrl_weights |
| examples/hello-world/hello-lightning/job.py | Adds --algorithm CLI flag selecting FedAvgRecipe or ScaffoldRecipe; job name includes algorithm for teardown safety |

Sequence Diagram

%%{init: {'theme': 'neutral'}}%%
sequenceDiagram
    participant Server
    participant FLCallback
    participant AlgorithmHandlerManager
    participant ScaffoldHandler
    participant PTScaffoldHelper
    participant LightningModule

    Note over FLCallback: on_validation_start (sanity val)
    FLCallback->>FLCallback: _receive_and_update_model()
    FLCallback->>FLCallback: store _pending_train_model

    Note over FLCallback: on_train_start
    FLCallback->>FLCallback: consume _pending_train_model
    FLCallback->>AlgorithmHandlerManager: start_round(trainer, pl_module, input_model)
    AlgorithmHandlerManager->>ScaffoldHandler: start_round()
    ScaffoldHandler->>ScaffoldHandler: validate global controls (all keys)
    ScaffoldHandler->>PTScaffoldHelper: init(pl_module) [first round only]
    ScaffoldHandler->>PTScaffoldHelper: load_global_controls()
    ScaffoldHandler->>ScaffoldHandler: _StateDictSnapshot(local_state)

    loop For each optimizer step
        Note over FLCallback: on_before_optimizer_step
        FLCallback->>AlgorithmHandlerManager: before_optimizer_step(optimizer)
        AlgorithmHandlerManager->>ScaffoldHandler: capture _pending_lr
        Note over LightningModule: Optimizer step applied
        Note over FLCallback: on_train_batch_end
        FLCallback->>AlgorithmHandlerManager: after_train_batch(pl_module)
        AlgorithmHandlerManager->>ScaffoldHandler: after_train_batch()
        ScaffoldHandler->>PTScaffoldHelper: model_update(model, curr_lr, ...)
        PTScaffoldHelper->>LightningModule: apply SCAFFOLD correction (params only)
    end

    Note over FLCallback: on_train_end
    FLCallback->>AlgorithmHandlerManager: finish_round(pl_module)
    AlgorithmHandlerManager->>ScaffoldHandler: finish_round()
    ScaffoldHandler->>PTScaffoldHelper: terms_update(model, avg_lr, ...)
    PTScaffoldHelper->>PTScaffoldHelper: compute c_delta_para (params only)
    ScaffoldHandler-->>FLCallback: "{SCAFFOLD_CTRL_DIFF: delta}"
    FLCallback->>Server: "send FLModel(params, meta={SCAFFOLD_CTRL_DIFF, NUM_STEPS_CURRENT_ROUND})"

Reviews (6): Last reviewed commit: "Merge branch 'main' into codex/lightning..."

Comment thread nvflare/app_opt/lightning/callbacks.py Outdated
Comment thread nvflare/app_opt/lightning/api.py Outdated
Comment thread nvflare/app_opt/lightning/callbacks.py Outdated
Comment thread nvflare/app_opt/lightning/api.py Outdated
Comment thread nvflare/app_opt/lightning/api.py Outdated
Comment thread nvflare/app_opt/lightning/api.py Outdated

@chesterxgchen chesterxgchen left a comment

Collaborator


It seems the SCAFFOLD handling logic leaks into the generic Lightning training process.

Signed-off-by: Holger Roth <hroth@nvidia.com>

holgerroth commented Jun 30, 2026

Collaborator Author

@chesterxgchen Addressed the requested changes in e30cec82d.

The Lightning API is now algorithm-neutral: FLCallback depends only on a private generic algorithm-handler manager. On the FedAvg path, the manager remains empty and never imports or instantiates SCAFFOLD. SCAFFOLD is lazy-loaded only when SCAFFOLD_CTRL_GLOBAL is received, and the activated handler is retained so missing controls in a later round still fail explicitly.
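The activation rule in the paragraph above can be sketched as follows. The class names are hypothetical stand-ins for the PR's private classes, the metadata key is a placeholder string rather than the real NVFlare constant value, and a real implementation would place the SCAFFOLD import inside the activation branch.

```python
from typing import Any, Dict, Optional

SCAFFOLD_CTRL_GLOBAL = "SCAFFOLD_CTRL_GLOBAL"  # placeholder key; the real value comes from NVFlare


class ScaffoldHandler:
    """Hypothetical handler that insists on global controls in every round."""

    def __init__(self) -> None:
        self.rounds_started = 0

    def start_round(self, meta: Dict[str, Any]) -> None:
        if SCAFFOLD_CTRL_GLOBAL not in meta:
            raise RuntimeError("SCAFFOLD was active in an earlier round, but global controls are missing")
        self.rounds_started += 1


class AlgorithmHandlerManager:
    """Stays empty on the FedAvg path; creates the handler on the first SCAFFOLD round."""

    def __init__(self) -> None:
        self._handler: Optional[ScaffoldHandler] = None

    @property
    def active(self) -> bool:
        return self._handler is not None

    def start_round(self, meta: Dict[str, Any]) -> None:
        if self._handler is None and SCAFFOLD_CTRL_GLOBAL in meta:
            self._handler = ScaffoldHandler()   # lazy: never reached for FedAvg metadata
        if self._handler is not None:
            self._handler.start_round(meta)     # retained handler rejects missing controls


fedavg = AlgorithmHandlerManager()
fedavg.start_round({})                          # no controls: manager remains empty

scaffold = AlgorithmHandlerManager()
scaffold.start_round({SCAFFOLD_CTRL_GLOBAL: {"w": 0.0}})
```

Retaining the handler after activation is what turns a later round without controls into an explicit error instead of a silent fallback to FedAvg.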

The other review items were addressed as follows:

  • Replaced the full LightningModule deep copy with a detached, CPU-backed state-dict snapshot. PTScaffoldHelper.terms_update() stages each snapshot tensor onto the active parameter device as needed, preserving its public signature and numerical behavior.
  • Made validation-before-training explicit: a training model received during validation is retained and consumed by on_train_start, rather than relying on a second cached receive() call.
  • Improved the interrupted optimizer-hook error to explain that an optimizer step was observed but on_train_batch_end did not run, which points users toward a Lightning/plugin hook-sequence interruption.
  • Preserved public APIs, FedAvg step weighting, gradient accumulation behavior, user metadata, distributed handling, and persistent SCAFFOLD local controls.
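The validation-before-training item above can be sketched as a small state machine. This is not the PR's `FLCallback`: the class is hypothetical, the hook methods drop the `trainer` and `pl_module` arguments that Lightning callbacks receive, and `receive` is any zero-argument callable standing in for the client receive call.

```python
from typing import Callable, Optional


class PendingModelCallback:
    """Hypothetical callback that receives at most one model per training round."""

    def __init__(self, receive: Callable[[], dict]) -> None:
        self._receive = receive
        self._pending_train_model: Optional[dict] = None
        self._round_active = False
        self.train_model: Optional[dict] = None

    def on_validation_start(self) -> None:
        if self._round_active:
            return  # mid-fit validation: collect metrics only, no receive or reload
        if self._pending_train_model is None:
            self._pending_train_model = self._receive()

    def on_train_start(self) -> None:
        # Consume the model kept by a preceding validation, or receive one now.
        model, self._pending_train_model = self._pending_train_model, None
        self.train_model = model if model is not None else self._receive()
        self._round_active = True

    def on_train_end(self) -> None:
        self._round_active = False


calls = []


def fake_receive() -> dict:
    calls.append(len(calls))
    return {"round": len(calls)}


cb = PendingModelCallback(fake_receive)
cb.on_validation_start()   # validation before fit: one receive, model kept for fit
cb.on_train_start()        # consumes the kept model without a second receive
cb.on_validation_start()   # mid-fit validation: no receive
cb.on_train_end()
receives_after_first_round = len(calls)
```

Keeping the pending model as explicit state makes the single-receive guarantee visible in the callback itself instead of depending on a cached second call.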

Validation completed on this commit:

  • Combined Lightning and PyTorch SCAFFOLD unit suites: 49 passed, 2 optional skips
  • Repository style check and documentation build: passed
  • Two-client/two-round Lightning SCAFFOLD H100 integration: passed
  • Eight-client, alpha=0.1, 50-round, four-local-epoch H100 comparison: 50/50 aggregations for both FedAvg and SCAFFOLD, zero tracebacks, valid final checkpoints
  • GitHub CI: all checks passed

I also replied point-by-point on each inline review thread with the corresponding implementation and validation details, and requested re-review.

@holgerroth holgerroth enabled auto-merge (squash) June 30, 2026 20:44
4 participants