Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions bergson/collector/gradient_collectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,13 +126,6 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]):
def global_project(self) -> None:
"""Concatenate per-module per-example gradients and project.
Sets ``self.mod_grads`` to ``{"gradients": projected}``.

Projects in row-chunks. A naive ``flat = torch.cat(..., dim=1)`` of
all module gradients can need tens of GiB contiguous on rank 0 when
the bin-packer assigns many short examples to a single batch (e.g.
flan_v2 with token_batch_size=2048 packs ~80 rows of ~525 MB each).
The projector is per-row, so chunking is exact; chunk size is sized
to a fixed GPU-byte budget.
"""
# backward_hook fires in reverse forward order, so insertion order in
# mod_grads is deterministic for a given model.
Expand All @@ -150,9 +143,6 @@ def global_project(self) -> None:
projection_type=self.processor.projection_type,
)

# Cap chunk_flat at ~4 GiB per chunk to leave headroom for fast_jl's
# internal scratch buffers alongside the per-module tensors that stay
# alive until the loop exits.
bytes_per_row = total_grad_dim * parts[0].element_size()
chunk_rows = max(1, (4 * 1024**3) // max(bytes_per_row, 1))
chunk_rows = min(chunk_rows, n_rows)
Expand All @@ -171,7 +161,6 @@ def global_project(self) -> None:
)
del chunk_flat, chunk_projected

# Free per-module GPU tensors now that all chunks are projected.
self.mod_grads = {}

self.mod_grads = {"gradients": torch.cat(chunks_cpu, dim=0)}
Expand Down
20 changes: 19 additions & 1 deletion bergson/hessians/apply_hessian.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import os
from dataclasses import dataclass
from pathlib import Path
from typing import Literal

import numpy as np
import torch
Expand All @@ -11,6 +12,7 @@
from simple_parsing import ArgumentParser
from torch import Tensor

from bergson.collector.collector import create_projection_matrix
from bergson.data import create_index, load_gradients
from bergson.distributed import init_dist
from bergson.hessians.sharded_computation import ShardedMul
Expand All @@ -29,6 +31,8 @@ class EkfacConfig:
`HessianConfig.ev_correction=True`."""
debug: bool = False
lambda_damp_factor: float = 0.1
projection_dim: int = 0
projection_type: Literal["normal", "rademacher"] = "rademacher"


class EkfacApplicator:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

TODO: This should also allow us to do $H^{-1/2}$ multiplications (applying the inverse square root of the Hessian); this is needed for KFAC + random projection.

Expand Down Expand Up @@ -72,8 +76,10 @@ def compute_ivhp_sharded(self):
eigen_g[k] = eigen_g[k].to(dtype=torch.float32)
lambda_factor[k] = v.to(dtype=torch.float32)

p = self.cfg.projection_dim
grad_sizes = {
name: eigen_g[name].shape[1] * eigen_a[name].shape[1] for name in eigen_a
name: p * p if p > 0 else eigen_g[name].shape[1] * eigen_a[name].shape[1]
for name in eigen_a
Comment thread
jammastergirish marked this conversation as resolved.
}

mmap = load_gradients(self.gradient_path)
Expand Down Expand Up @@ -151,6 +157,18 @@ def compute_ivhp_sharded(self):
del eigen_a
gc.collect()

if p > 0:
projection_type = self.cfg.projection_type
for k, v in transformed_gradients.items():
d_S, d_A = v.shape[-2:]
P_l = create_projection_matrix(
f"{k}/left", p, d_S, v.dtype, v.device, projection_type
)
P_r = create_projection_matrix(
f"{k}/right", p, d_A, v.dtype, v.device, projection_type
)
transformed_gradients[k] = torch.einsum("ps,nsa,ra->npr", P_l, v, P_r)

torch.cuda.synchronize()
for k, v in transformed_gradients.items():
grad_buffer[k][:] = v.to(device="cpu", non_blocking=True).flatten(1).numpy()
Expand Down
3 changes: 2 additions & 1 deletion bergson/hessians/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,8 @@ def _validate(cfg: IndexConfig):
run_path=transformed_query_path,
ev_correction=hessian_cfg.ev_correction,
lambda_damp_factor=hessian_pipeline_cfg.lambda_damp_factor,
projection_dim=index_cfg.projection_dim,
projection_type=index_cfg.projection_type,
)
launch_distributed_run(
"apply_hessian",
Expand All @@ -116,7 +118,6 @@ def _validate(cfg: IndexConfig):
if not _step_complete(scores_path, resume):
score_index_cfg = deepcopy(index_cfg)
score_index_cfg.run_path = scores_path
score_index_cfg.projection_dim = 0
score_index_cfg.skip_hessians = True
score_cfg.query_path = transformed_query_path
score_cfg.higher_is_better = True
Expand Down
Loading