Skip to content

Add first-class K-FAC support to build#275

Open
jammastergirish wants to merge 13 commits into
mainfrom
feat/first-class-kfac-build-support
Open

Add first-class K-FAC support to build#275
jammastergirish wants to merge 13 commits into
mainfrom
feat/first-class-kfac-build-support

Conversation

@jammastergirish

@jammastergirish jammastergirish commented May 5, 2026

Copy link
Copy Markdown
Collaborator

Linear task

Summary

Compress the K-FAC IVHP output to match bergson build's compressed gradient store. K-FAC factors are computed on uncompressed gradients (unchanged from main + #273); the compression happens at apply time as a post-projection of the IVHP output:

G̃_q = P_S · (H⁻¹ G_q) · P_Aᵀ      ∈ [N, p, p]

P_S and P_A are the same per-layer Rademacher / normal matrices (f"{name}/left", f"{name}/right") that bergson build uses on training gradients (via the existing create_projection_matrix(...) helper). Train-side stored gradients are P_S · G_t · P_Aᵀ by construction, so their Frobenius inner product with the compressed query approximates <G_q, H⁻¹ G_t> per layer up to a known (p/d)² constant.

IndexConfig.projection_dim (the existing knob that controls compression in build/score) now also drives apply-time K-FAC compression; no new user-facing flag.

Diff

Net +16 lines vs kfac-full-support (#273):

  • apply_hessian.py — two EkfacConfig fields, a grad_sizes ternary, and a 7-line post-projection block at the end of compute_ivhp_sharded. The legacy rotate-divide-rotate above it is byte-identical to main.
  • pipeline.py — thread index_cfg.projection_dim / projection_type into EkfacConfig (step 3), and remove the score_index_cfg.projection_dim = 0 override (step 4) so the training-side gradient store at score time is built at [p, p] to match the apply output.

HessianConfig is unchanged — no new fields. bergson hessian is unchanged.

Composability

  • Works with ev_correction=True — the post-projection step is independent of which lambda the rotate-divide-rotate uses.
  • Works for any Kronecker-factored method (kfac, tkfac, shampoo) — the apply path consumes Q and λ that compute_eigendecomposition already writes for all of them.
  • One-sided application (per the brief): H⁻¹ only ever touches the query; training gradients flow straight from build through to score with no inverse-Hessian application.

Known caveat (pre-existing)

K-FAC + include_bias=True has a pre-existing shape mismatch — filed as #277. Out of scope here; this PR doesn't make it worse.

Test plan

Validated end-to-end on Pythia-14m / pile-10k (200 train docs, 10 query docs, 24 modules):

  • bergson hessian --method kfac baseline — unchanged from main + kfac support by setting ev_correction True/False in the HessianCfg #273.
  • bergson ekfac --projection_dim 16 --projection_target per_modulekfac_query/info.json shows per-module grad_sizes = 256 = 16², scores/processor_config.yaml carries the propagated projection settings, scores produced cleanly.
  • bergson ekfac --projection_dim 0 — legacy regression: full d_S · d_A per-module sizes, scores produced.
  • bergson ekfac --projection_dim 16 --ev_correction true — EV correction + compression compose (no guard needed).

Bergson supports random-projection compression of gradients in
`bergson build` and `bergson score`. This wires the same per-layer
projection through the K-FAC / EK-FAC Hessian path so compressed
gradient stores can be scored against a compressed Hessian.

- New `HessianConfig.projection_dim` / `projection_type` fields.
  Default `projection_dim=0` keeps existing behavior.
- `compute_eigendecomposition` now takes optional
  `projection_dim` / `projection_type` / `side` and compresses each
  gathered covariance to `P @ M @ P.T` before `eigh`, using the same
  per-layer projection identifier convention as `collector.py`.
- `LambdaCollector` projects activations (right) and output gradients
  (left) before applying the now-`[p, p]` eigenvectors so the
  eigenvalue corrections are in the compressed space.
- Hessian worker validates that compression is only enabled with
  `method='kfac'` and `projection_target='per_module'`.
- One-sided application only: query gets projected and `apply_hessian`
  runs unchanged.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment thread bergson/config.py Outdated
Comment thread bergson/hessians/hessian_approximations.py
@luciaquirke luciaquirke requested a review from LouisYRYJ May 6, 2026 06:21
@luciaquirke

Copy link
Copy Markdown
Collaborator

@LouisYRYJ you don't need to read the PR description but if you could double check my implementation review that would be fantastic

@LouisYRYJ

Copy link
Copy Markdown
Contributor

Thanks, will check now.
Can we either merge #273 first or make this PR on top that one?

Comment thread bergson/hessians/hessian_approximations.py Outdated
@LouisYRYJ

Copy link
Copy Markdown
Contributor

Ok I think this is almost correct, but importantly the compression should be the last step in our kfac pipeline. Here is the sketch:

  1. Compute KFAC as usual including its eigenvalues kfac support by setting ev_correction True/False in the HessianCfg #273 Here we should add an option where save the eigenvalues of A and S (the covariance matrices), so that on disk we have Q_A, Q_S, the eigenvectors, and E_A, E_S, the respective eigenvalues. We have the equality A= Q_A E_A Q_A.T and S= Q_S E_S Q_S.T (or A= Q_A.T E_A Q_A resp. S not sure right now, but in this case it doesnt matter because A and S are symmetric).
  2. Compute the (damped) inverses (E_A+lambda)^{-1} and compute the matrix A'=(A+lambda *id)^-1= Q_A (E_A+lambda)^{-1} Q_A.T. Likewise for S.
  3. Now thisthis is the matrix we random project: A'_P= A'*P where P is the Rademacher or whatever random projection. shape[A'_P]=shape[P] and so this will be our new in-place curvature + random projection operation all in one go. Same story for S to get the other side of the projection

This works for any method, once we are provided with A and S, regardless of how they were computed. Eigenvalue correction will not work because it does not only operate on either rows or columns (whereas A, S and their decompositions do).

@jammastergirish jammastergirish marked this pull request as draft May 6, 2026 21:36
…st-class-kfac-build-support

# Conflicts:
#	bergson/hessians/eigenvectors.py
@jammastergirish jammastergirish changed the base branch from main to kfac-full-support May 6, 2026 21:46
jammastergirish and others added 4 commits May 6, 2026 14:58
Per Louis's review (#275): the random projection should be the *last*
step in the K-FAC pipeline, folded into the damped inverse rather than
applied to covariances before eigh. This commit tears out the
project-before-eigh path so the next stages can introduce the correct
post-inverse compression.

- compute_eigendecomposition: drop projection_dim/projection_type/side
  kwargs and the per-key P @ M @ P.T compression. Returns the same
  eigenvalue dict (from #273) but always in the full-rank space.
- LambdaCollector: drop projection_dim/projection_type fields and the
  pre-multiplication branches in forward_hook/backward_hook.
- hessian_approximations.py: drop method != 'kfac' validation (Louis:
  any Kronecker-factored method works). Replace it with a check that
  ev_correction=True is incompatible with projection_dim > 0, since the
  EK-FAC eigenvalue correction does not decompose into per-side row/col
  ops (Lucia + Louis confirmed).
- HessianConfig.projection_dim doc: drop the EK-FAC mention, broaden to
  any Kronecker-factored method, note ev_correction incompatibility.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Stage 3 (precompute the damped-inverse-projection M matrices) needs
E_A and E_S separately. The #273 merge gave us their outer product on
disk via `save_uncorrected_eigenvalues` (-> `eigenvalue_sharded/`),
but not the per-side spectra.

`compute_eigendecomposition` now writes the per-key eigenvalue shards
alongside the eigenvectors, mirroring the existing naming:

  eigen_activation_sharded/    (Q_A)        eigenvalue_activation_sharded/ (E_A)
  eigen_gradient_sharded/      (Q_S)        eigenvalue_gradient_sharded/   (E_S)

The function still returns the eigenvalue dict so the in-process
`save_uncorrected_eigenvalues` call can use them without reloading.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
For each layer, fold the per-side damped curvature inverse and the
per-layer random projection into a single [d, projection_dim] matrix:

    M = Q · diag((E + α)^{-1}) · Q^T · P^T

where Q, E come from the saved eigendecomposition, P is the same
Rademacher / normal projection used by `bergson build` (identifier
`f"{name}/{side}"`), and α = lambda_damp_factor * mean(E) is the
adaptive damping (mirrors `ShardedMul._sharded_hadamard`).

`compute_whitening_projection_matrices` reads Q and E shards from disk,
distributes layers across ranks via `fair_distribute_by_cost`, computes
M in float64 on the assigned rank, then re-shards along dim 0 (rows of
d) so each rank ends up with `M[rank*d/W : (rank+1)*d/W, :]` — matching
the existing Q sharding so `ShardedMul._matmul(activations, M_shard)`
is the natural way to use these in stage 4.

Outputs land in:

    whitening_projection_activation_sharded/  (M_A, side='right')
    whitening_projection_gradient_sharded/    (M_S, side='left')

`HessianConfig.lambda_damp_factor` (new field, default 0.1) controls
the damping baked into M.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
When hessian build saves whitening-projection matrices (stage 3), the
IVHP computation collapses to a single sharded matmul per side per
layer:

    G̃_q = M_Sᵀ · G_q · M_A          shape: [N, p_S, p_A]

No eigenbasis rotation, no eigenvalue divide, no rotation back — the
damped curvature inverse and the random projection are already folded
into M.

`compute_ivhp_sharded` now dispatches: if
``whitening_projection_activation_sharded/`` exists under the hessian
path, take the new path; otherwise fall through to the legacy
rotate-divide-rotate flow. The new path also rejects
``ev_correction=True`` at runtime, mirroring the build-time guard
added in stage 1.

Output shape is ``[p_S, p_A]`` flattened per layer (down from
``[d_S, d_A]`` in the legacy path), which is what makes this a
compression of the IVHP store, not just a reformulation.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jammastergirish

jammastergirish commented May 6, 2026

Copy link
Copy Markdown
Collaborator Author

While digging through the apply path during the redesign, Claude noticed an issue. I'm working through everything now and am not fully certain of it, so will discuss/let you know shortly:

K-FAC + include_bias=True is broken on main (and therefore in this PR too). The K-FAC activation covariance is built without the bias column, so the per-layer reshape in compute_ivhp_sharded fails for any layer where bias is being collected.

Filed separately as #277 so the fix doesn't get tangled with the compression work here. This PR's compression path inherits whatever fix lands there for free, since it consumes covariances by their declared dimension.

@LouisYRYJ

Copy link
Copy Markdown
Contributor

Yes, I never implemented support for bias. IMO will not be load bearing, but could be nice for sake of completeness (and technically shampoo or others also dont support it)

Step 4 of `hessian_pipeline` always reset `score_index_cfg.projection_dim = 0`,
which was correct for the legacy IVHP path (apply_hessian writes the
transformed query at full `[d_S, d_A]`, and training gradients must match
that for a meaningful inner product at score time).

With Hessian-factor compression on (`hessian_cfg.projection_dim > 0`),
apply_hessian now writes the query at `[p, p]` per layer, so training
gradients also need to be projected to `[p, p]` per layer. Otherwise the
score step compares `[d_S, d_A]` against `[p, p]` and either errors out
or computes a meaningless inner product.

Inherit the matching `projection_dim`, `projection_type`, and
`projection_target='per_module'` from `hessian_cfg` when compression is on;
keep the `projection_dim=0` override for the legacy path.

This is the last piece needed to run `bergson ekfac` end-to-end with
compression enabled.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Comment thread bergson/hessians/hessian_approximations.py Outdated
Comment thread bergson/hessians/eigenvectors.py Outdated
jammastergirish and others added 3 commits May 7, 2026 05:35
Per Lucia's review (#275), the build-time `M` precompute was overkill.
The brief always asked for: K-FAC computed on uncompressed gradients
(unchanged), and compression added at apply time. This commit restores
that minimal shape.

Replaces the `M = (cov + αI)⁻¹ Pᵀ` precompute (stages 2-4) with a
trivial post-projection of the legacy IVHP output:

    G̃_q = P_S · (H⁻¹ G_q) · P_Aᵀ      ∈ [N, p, p]

The legacy rotate-divide-rotate is unchanged; the new compression block
runs after it, using the same `f"{name}/{side}"` per-layer projection
identifiers as `bergson build`. Train-side stored gradients are already
`P_S · G_t · P_Aᵀ` (collector.py's `g_proj.T @ a_proj` collapses to
this), so the Frobenius inner product approximates `<G_q, H⁻¹ G_t>`
up to a known `(p/d)²` constant.

Deletions vs current branch:
- `eigenvectors.py`: revert stage 2's per-side eigenvalue saving;
  delete `compute_whitening_projection_matrices` entirely.
- `hessian_approximations.py`: delete the M precompute calls and the
  `projection_target == 'per_module'` + `ev_correction` validation
  block (the post-projection design composes with EV correction).
- `config.py`: delete `HessianConfig.projection_dim`,
  `projection_type`, `lambda_damp_factor` — compression is now
  controlled entirely by `IndexConfig.projection_dim` (already used by
  build/score) and the new `EkfacConfig.projection_dim`.
- `apply_hessian.py`: delete `_compute_ivhp_with_whitening_projection`
  and the dispatch.

Additions:
- `EkfacConfig`: new `projection_dim`, `projection_type` fields.
- `apply_hessian.py:compute_ivhp_sharded`: ~15 lines projecting the
  rotated-and-divided IVHP output per layer when `projection_dim > 0`,
  plus a `grad_sizes` adjustment.
- `pipeline.py` step 3: pass `index_cfg.projection_dim` and
  `projection_type` into `EkfacConfig`. Step 4 stops overriding
  `score_index_cfg.projection_dim = 0` — the deepcopy of `index_cfg`
  now correctly carries the projection settings to the score step, so
  training-side gradient shapes match the apply output automatically.

Net: -258 lines vs current branch, +~25 lines vs main.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Drop one-arg-per-line `create_projection_matrix(...)` calls in favor of
positional args, fold the `grad_sizes` if/else into a dict-comprehension
ternary, and remove the trailing debug log. Same behavior, ~17 fewer
lines. Brings this PR to net +14 vs `kfac-full-support` (#273).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@jammastergirish jammastergirish marked this pull request as ready for review May 7, 2026 21:51
Comment thread bergson/hessians/apply_hessian.py
Comment thread bergson/hessians/apply_hessian.py Outdated
@jammastergirish

jammastergirish commented May 8, 2026 via email

Copy link
Copy Markdown
Collaborator Author

@luciaquirke luciaquirke left a comment

Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

🔥 🔥

@luciaquirke luciaquirke changed the base branch from kfac-full-support to main May 8, 2026 08:13
@jammastergirish jammastergirish force-pushed the feat/first-class-kfac-build-support branch from b7921ac to 8c9c30f Compare May 21, 2026 18:49
projection_type: Literal["normal", "rademacher"] = "rademacher"


class EkfacApplicator:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: This should allow us to also do H-1/2 multiplications, this is needed for KFAC + random projection.

print(f" [{label}] took {elapsed:.1f}s")


def hessian_pipeline(

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: throw error when we do compression + ev correction. We right now only support kfac without eigenvalue correction.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also this will work not only for kfac, but other methods too (e.g. shampoo)

[ekfac_cfg],
index_cfg.distributed,
)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Step 3.5: Compute R @ H^-1/2.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

THis needs to be computed for the right and the left projection seperately and you do this by multiplying left_proj with (act or grad cov inverse) and then right_proj with (grad or act cov inverse)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Save to disk

score_cfg.higher_is_better = True
_validate(score_index_cfg)

score_dataset(score_index_cfg, score_cfg, preprocess_cfg)

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

go in here, and go deep into the logic to find out where random projections are multiplied with the grads. And then add a config or whatever that realizes that we should use the modified random projections we just computed and saved to disk

jammastergirish added a commit that referenced this pull request Jun 2, 2026
Per Louis Jaburi's review on #275: compose the random projection with
H^{-1/2} into a single saved matrix M = R · cov^{-1/2}, then use M
wherever the gradient collector used to sample R. Query and training
gradients both go through M, so <ĝ_z, ĝ_q> ≈ <g_z, H⁻¹ g_q> without
ever applying H⁻¹ in low-dim space.

- eigenvectors.py: persist per-side eigenvalues to eigval_{activation,
  gradient}_sharded/ alongside the existing eigenvectors.
- apply_hessian.py: build_kfac_projections builds
  M = R · Q · diag((E + λ·mean(E))^{-1/2}) · Qᵀ per (layer, side) and
  writes projection_{left,right}_sharded/. _apply_compressed applies
  ĝ_q = M_l · G_q · M_rᵀ. Legacy rotate-divide-rotate path is unchanged.
  Early guard rejects ev_correction + compression (the joint S⊗A
  eigenvalue correction breaks the Kronecker structure).
- config.py + gradient_collectors.py: new IndexConfig.kfac_projection_path
  pre-populates processor._projection_matrices from disk in
  GradientCollector.setup(), so HookCollectorBase.projection() returns
  the saved M instead of sampling a fresh R at hook time.
- pipeline.py: visible Step 3.5 builds M and projects the query;
  Step 4 sets kfac_projection_path so the training-side collector uses
  the same M. Top-level guard fails fast on projection_dim > 0 +
  ev_correction=True before any work runs.

Untested on GPU (Mac, no CUDA); smoke yaml at runs/p_sift_smoke.yaml.

Co-Authored-By: Claude Opus 4.7 <noreply@anthropic.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants