Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions bergson/hessians/eigenvectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,19 +391,26 @@ def save_uncorrected_eigenvalues(
get_logger().info(f"Saved uncorrected eigenvalues to {out_dir}")


def save_identity_eigen(
def save_identity_shards(
partial_run_path: str | os.PathLike,
dim_per_key: dict[str, int],
sub_dir: str,
rank: int,
world_size: int,
dtype: torch.dtype = torch.float32,
scale: float | int | Tensor = 1.0,
) -> None:
"""Write per-rank shards of identity Q-side matrices to `sub_dir`.
"""Write per-rank shards of `scale * I_d`
to `<run>/<sub_dir>/shard_<rank>.safetensors`.

`dim_per_key` maps each module name to the size `d` of its
`[d, d]` identity Q.
Each saved tensor for key `name` has shape `[dim/world_size, dim]` (the
rank-local row-slice). `dim_per_key` maps each module name to the size
`d` of its `[d, d]` identity matrix.
"""
if isinstance(scale, Tensor):
scale = scale.item()
scale = float(scale)

payload: dict[str, Tensor] = {}
for name, d in dim_per_key.items():
if d % world_size != 0:
Expand All @@ -412,7 +419,7 @@ def save_identity_eigen(
)
shard_size = d // world_size
shard = torch.zeros(shard_size, d, dtype=dtype)
shard.diagonal(offset=rank * shard_size).fill_(1.0)
shard.diagonal(offset=rank * shard_size).fill_(scale)
payload[name] = shard

out_dir = Path(str(partial_run_path)) / sub_dir
Expand All @@ -432,15 +439,15 @@ def save_identity_factors(
`layer_dims` maps each target module name to its weight shape `(O, I)`.
"""
partial_run_path = Path(str(partial_run_path))
save_identity_eigen(
save_identity_shards(
partial_run_path,
{n: i for n, (_, i) in layer_dims.items()},
"eigen_activation_sharded",
rank,
world_size,
dtype,
)
save_identity_eigen(
save_identity_shards(
partial_run_path,
{n: o for n, (o, _) in layer_dims.items()},
"eigen_gradient_sharded",
Expand Down
57 changes: 51 additions & 6 deletions bergson/hessians/hessian_approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
from bergson.hessians.eigenvectors import (
LambdaCollector,
compute_eigendecomposition,
save_identity_eigen,
save_identity_factors,
save_identity_shards,
save_uncorrected_eigenvalues,
)
from bergson.hessians.foof import ActivationCovarianceCollector
Expand Down Expand Up @@ -156,6 +156,22 @@ def hessian_worker(
target_modules = peft_target_modules
assert target_modules is not None

def _get_target_module(name: str):
"""Resolve a `target_modules` entry against the loaded model.

PEFT-extracted names are full paths from the wrapper top (e.g.
``model.layers.X.mlp.Y.lora_A``), so ``model.get_submodule`` works
thanks to PEFT's ``__getattr__`` shim. Non-PEFT runs (incl. FSDP-
wrapped HF causal LMs) typically pass paths relative to the inner
``base_model`` (e.g. ``layers.X.mlp.Y``), where the wrapper class
has no matching top-level attribute. Fall back to ``base_model``
in that case.
"""
try:
return model.get_submodule(name)
except AttributeError:
return model.base_model.get_submodule(name)

kwargs = {
"model": model,
"data": ds,
Expand All @@ -170,7 +186,7 @@ def hessian_worker(

if hessian_cfg.method == "identity":
layer_dims = {
name: tuple(model.get_submodule(name).weight.shape)
name: tuple(_get_target_module(name).weight.shape)
for name in target_modules
}
dtype = convert_precision_to_torch(hessian_cfg.hessian_dtype)
Expand All @@ -194,6 +210,24 @@ def hessian_worker(
map_location="cpu",
weights_only=False,
)
save_identity_shards(
partial_run_path=index_cfg.partial_run_path,
dim_per_key={n: i for n, (_, i) in layer_dims.items()},
sub_dir="activation_sharded",
rank=rank,
world_size=world_size,
dtype=dtype,
scale=total_processed,
)
save_identity_shards(
partial_run_path=index_cfg.partial_run_path,
dim_per_key={n: o for n, (o, _) in layer_dims.items()},
sub_dir="gradient_sharded",
rank=rank,
world_size=world_size,
dtype=dtype,
scale=total_processed,
)
save_uncorrected_eigenvalues(
partial_run_path=index_cfg.partial_run_path,
eva_a_local={
Expand Down Expand Up @@ -232,22 +266,33 @@ def hessian_worker(
world_size = dist.get_world_size() if dist.is_initialized() else 1

if hessian_cfg.method == "foof":
# F_FOOF = E[aaᵀ] ⊗ I. Synthesise identity Q_G and eva_g = 1 to reuse
# the standard apply path.
# F_FOOF = E[aaᵀ] ⊗ I. Synthesise identity Q_G, eva_g = 1, and
# gradient_sharded = total_processed * I so the on-disk
# layout matches `CovarianceCollector`'s.

# named_modules() returns un-stripped PEFT paths; get_submodule resolves
# the stripped target_modules names via PEFT's __getattr__ shim.
out_dims = {
name: model.get_submodule(name).weight.shape[0] for name in target_modules
name: _get_target_module(name).weight.shape[0] for name in target_modules
}
dtype = convert_precision_to_torch(hessian_cfg.hessian_dtype)
save_identity_eigen(
save_identity_shards(
index_cfg.partial_run_path,
out_dims,
"eigen_gradient_sharded",
rank,
world_size,
dtype,
)
save_identity_shards(
partial_run_path=index_cfg.partial_run_path,
dim_per_key=out_dims,
sub_dir="gradient_sharded",
rank=rank,
world_size=world_size,
dtype=dtype,
scale=total_processed,
)
eva_g = {
name: torch.ones(d // world_size, dtype=dtype)
for name, d in out_dims.items()
Expand Down