diff --git a/bergson/hessians/eigenvectors.py b/bergson/hessians/eigenvectors.py index 333fcc65..b5b7a0dc 100644 --- a/bergson/hessians/eigenvectors.py +++ b/bergson/hessians/eigenvectors.py @@ -391,19 +391,26 @@ def save_uncorrected_eigenvalues( get_logger().info(f"Saved uncorrected eigenvalues to {out_dir}") -def save_identity_eigen( +def save_identity_shards( partial_run_path: str | os.PathLike, dim_per_key: dict[str, int], sub_dir: str, rank: int, world_size: int, dtype: torch.dtype = torch.float32, + scale: float | int | Tensor = 1.0, ) -> None: - """Write per-rank shards of identity Q-side matrices to `sub_dir`. + """Write per-rank shards of `scale * I_d` + to `//shard_.safetensors`. - `dim_per_key` maps each module name to the size `d` of its - `[d, d]` identity Q. + Each saved tensor for key `name` has shape `[dim/world_size, dim]` (the + rank-local row-slice). `dim_per_key` maps each module name to the size + `d` of its `[d, d]` identity matrix. """ + if isinstance(scale, Tensor): + scale = scale.item() + scale = float(scale) + payload: dict[str, Tensor] = {} for name, d in dim_per_key.items(): if d % world_size != 0: @@ -412,7 +419,7 @@ def save_identity_eigen( ) shard_size = d // world_size shard = torch.zeros(shard_size, d, dtype=dtype) - shard.diagonal(offset=rank * shard_size).fill_(1.0) + shard.diagonal(offset=rank * shard_size).fill_(scale) payload[name] = shard out_dir = Path(str(partial_run_path)) / sub_dir @@ -432,7 +439,7 @@ def save_identity_factors( `layer_dims` maps each target module name to its weight shape `(O, I)`. """ partial_run_path = Path(str(partial_run_path)) - save_identity_eigen( + save_identity_shards( partial_run_path, {n: i for n, (_, i) in layer_dims.items()}, "eigen_activation_sharded", @@ -440,7 +447,7 @@ def save_identity_factors( world_size, dtype, ) - save_identity_eigen( + save_identity_shards( partial_run_path, {n: o for n, (o, _) in layer_dims.items()}, "eigen_gradient_sharded", diff --git a/bergson/hessians/hessian_approximations.py b/bergson/hessians/hessian_approximations.py index a6c2c1f8..b217fc5c 100644 --- a/bergson/hessians/hessian_approximations.py +++ b/bergson/hessians/hessian_approximations.py @@ -18,8 +18,8 @@ from bergson.hessians.eigenvectors import ( LambdaCollector, compute_eigendecomposition, - save_identity_eigen, save_identity_factors, + save_identity_shards, save_uncorrected_eigenvalues, ) from bergson.hessians.foof import ActivationCovarianceCollector @@ -156,6 +156,22 @@ def hessian_worker( target_modules = peft_target_modules assert target_modules is not None + def _get_target_module(name: str): + """Resolve a `target_modules` entry against the loaded model. + + PEFT-extracted names are full paths from the wrapper top (e.g. + ``model.layers.X.mlp.Y.lora_A``), so ``model.get_submodule`` works + thanks to PEFT's ``__getattr__`` shim. Non-PEFT runs (incl. FSDP- + wrapped HF causal LMs) typically pass paths relative to the inner + ``base_model`` (e.g. ``layers.X.mlp.Y``), where the wrapper class + has no matching top-level attribute. Fall back to ``base_model`` + in that case. + """ + try: + return model.get_submodule(name) + except AttributeError: + return model.base_model.get_submodule(name) + kwargs = { "model": model, "data": ds, @@ -170,7 +186,7 @@ def hessian_worker( if hessian_cfg.method == "identity": layer_dims = { - name: tuple(model.get_submodule(name).weight.shape) + name: tuple(_get_target_module(name).weight.shape) for name in target_modules } dtype = convert_precision_to_torch(hessian_cfg.hessian_dtype) @@ -194,6 +210,24 @@ def hessian_worker( map_location="cpu", weights_only=False, ) + save_identity_shards( + partial_run_path=index_cfg.partial_run_path, + dim_per_key={n: i for n, (_, i) in layer_dims.items()}, + sub_dir="activation_sharded", + rank=rank, + world_size=world_size, + dtype=dtype, + scale=total_processed, + ) + save_identity_shards( + partial_run_path=index_cfg.partial_run_path, + dim_per_key={n: o for n, (o, _) in layer_dims.items()}, + sub_dir="gradient_sharded", + rank=rank, + world_size=world_size, + dtype=dtype, + scale=total_processed, + ) save_uncorrected_eigenvalues( partial_run_path=index_cfg.partial_run_path, eva_a_local={ @@ -232,15 +266,17 @@ def hessian_worker( world_size = dist.get_world_size() if dist.is_initialized() else 1 if hessian_cfg.method == "foof": - # F_FOOF = E[aaᵀ] ⊗ I. Synthesise identity Q_G and eva_g = 1 to reuse - # the standard apply path. + # F_FOOF = E[aaᵀ] ⊗ I. Synthesise identity Q_G, eva_g = 1, and + # gradient_sharded = total_processed * I so the on-disk + # layout matches `CovarianceCollector`'s. + # named_modules() returns un-stripped PEFT paths; get_submodule resolves # the stripped target_modules names via PEFT's __getattr__ shim. out_dims = { - name: model.get_submodule(name).weight.shape[0] for name in target_modules + name: _get_target_module(name).weight.shape[0] for name in target_modules } dtype = convert_precision_to_torch(hessian_cfg.hessian_dtype) - save_identity_eigen( + save_identity_shards( index_cfg.partial_run_path, out_dims, "eigen_gradient_sharded", @@ -248,6 +284,15 @@ def hessian_worker( world_size, dtype, ) + save_identity_shards( + partial_run_path=index_cfg.partial_run_path, + dim_per_key=out_dims, + sub_dir="gradient_sharded", + rank=rank, + world_size=world_size, + dtype=dtype, + scale=total_processed, + ) eva_g = { name: torch.ones(d // world_size, dtype=dtype) for name, d in out_dims.items()