diff --git a/bergson/collector/gradient_collectors.py b/bergson/collector/gradient_collectors.py index 19982264..a1c817e9 100644 --- a/bergson/collector/gradient_collectors.py +++ b/bergson/collector/gradient_collectors.py @@ -126,13 +126,6 @@ def backward_hook(self, module: nn.Module, g: Float[Tensor, "N S O"]): def global_project(self) -> None: """Concatenate per-module per-example gradients and project. Sets ``self.mod_grads`` to ``{"gradients": projected}``. - - Projects in row-chunks. A naive ``flat = torch.cat(..., dim=1)`` of - all module gradients can need tens of GiB contiguous on rank 0 when - the bin-packer assigns many short examples to a single batch (e.g. - flan_v2 with token_batch_size=2048 packs ~80 rows of ~525 MB each). - The projector is per-row, so chunking is exact; chunk size is sized - to a fixed GPU-byte budget. """ # backward_hook fires in reverse forward order, so insertion order in # mod_grads is deterministic for a given model. @@ -150,9 +143,6 @@ def global_project(self) -> None: projection_type=self.processor.projection_type, ) - # Cap chunk_flat at ~4 GiB per chunk to leave headroom for fast_jl's - # internal scratch buffers alongside the per-module tensors that stay - # alive until the loop exits. bytes_per_row = total_grad_dim * parts[0].element_size() chunk_rows = max(1, (4 * 1024**3) // max(bytes_per_row, 1)) chunk_rows = min(chunk_rows, n_rows) @@ -171,7 +161,6 @@ def global_project(self) -> None: ) del chunk_flat, chunk_projected - # Free per-module GPU tensors now that all chunks are projected. self.mod_grads = {} self.mod_grads = {"gradients": torch.cat(chunks_cpu, dim=0)} diff --git a/bergson/hessians/apply_hessian.py b/bergson/hessians/apply_hessian.py index 700c61eb..ea226851 100644 --- a/bergson/hessians/apply_hessian.py +++ b/bergson/hessians/apply_hessian.py @@ -3,6 +3,7 @@ import os from dataclasses import dataclass from pathlib import Path +from typing import Literal import numpy as np import torch @@ -11,6 +12,7 @@ from simple_parsing import ArgumentParser from torch import Tensor +from bergson.collector.collector import create_projection_matrix from bergson.data import create_index, load_gradients from bergson.distributed import init_dist from bergson.hessians.sharded_computation import ShardedMul @@ -29,6 +31,8 @@ class EkfacConfig: `HessianConfig.ev_correction=True`.""" debug: bool = False lambda_damp_factor: float = 0.1 + projection_dim: int = 0 + projection_type: Literal["normal", "rademacher"] = "rademacher" class EkfacApplicator: @@ -72,8 +76,10 @@ def compute_ivhp_sharded(self): eigen_g[k] = eigen_g[k].to(dtype=torch.float32) lambda_factor[k] = v.to(dtype=torch.float32) + p = self.cfg.projection_dim grad_sizes = { - name: eigen_g[name].shape[1] * eigen_a[name].shape[1] for name in eigen_a + name: p * p if p > 0 else eigen_g[name].shape[1] * eigen_a[name].shape[1] + for name in eigen_a } mmap = load_gradients(self.gradient_path) @@ -151,6 +157,18 @@ def compute_ivhp_sharded(self): del eigen_a gc.collect() + if p > 0: + projection_type = self.cfg.projection_type + for k, v in transformed_gradients.items(): + d_S, d_A = v.shape[-2:] + P_l = create_projection_matrix( + f"{k}/left", p, d_S, v.dtype, v.device, projection_type + ) + P_r = create_projection_matrix( + f"{k}/right", p, d_A, v.dtype, v.device, projection_type + ) + transformed_gradients[k] = torch.einsum("ps,nsa,ra->npr", P_l, v, P_r) + torch.cuda.synchronize() for k, v in transformed_gradients.items(): grad_buffer[k][:] = v.to(device="cpu", non_blocking=True).flatten(1).numpy() diff --git a/bergson/hessians/pipeline.py b/bergson/hessians/pipeline.py index fa2078ff..f98c64ed 100644 --- a/bergson/hessians/pipeline.py +++ b/bergson/hessians/pipeline.py @@ -103,6 +103,8 @@ def _validate(cfg: IndexConfig): run_path=transformed_query_path, ev_correction=hessian_cfg.ev_correction, lambda_damp_factor=hessian_pipeline_cfg.lambda_damp_factor, + projection_dim=index_cfg.projection_dim, + projection_type=index_cfg.projection_type, ) launch_distributed_run( "apply_hessian", @@ -116,7 +118,6 @@ def _validate(cfg: IndexConfig): if not _step_complete(scores_path, resume): score_index_cfg = deepcopy(index_cfg) score_index_cfg.run_path = scores_path - score_index_cfg.projection_dim = 0 score_index_cfg.skip_hessians = True score_cfg.query_path = transformed_query_path score_cfg.higher_is_better = True