Operators

CurvatureOperator

Bases: ABC

Symmetric matrix-free linear operator over a flat parameter vector.

Source code in hessian_eigenthings/operators/base.py
class CurvatureOperator(ABC):
    """Symmetric matrix-free linear operator over a flat parameter vector."""

    @property
    @abstractmethod
    def size(self) -> int:
        """Number of elements in the operator's input/output vector."""

    @property
    @abstractmethod
    def device(self) -> torch.device:
        """Device on which `matvec` produces its output."""

    @property
    @abstractmethod
    def dtype(self) -> torch.dtype:
        """Dtype of the output of `matvec`."""

    @abstractmethod
    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        """Compute `M @ v` where `M` is this operator. `v` is a flat 1-D tensor of length `self.size`."""

    def rmatvec(self, v: torch.Tensor) -> torch.Tensor:
        """Compute `v^T @ M`. Equal to `matvec` because curvature operators are symmetric."""
        return self.matvec(v)

    def __call__(self, v: torch.Tensor) -> torch.Tensor:
        return self.matvec(v)

    def __matmul__(self, v: torch.Tensor) -> torch.Tensor:
        return self.matvec(v)

size abstractmethod property

Number of elements in the operator's input/output vector.

device abstractmethod property

Device on which matvec produces its output.

dtype abstractmethod property

Dtype of the output of matvec.

matvec(v) abstractmethod

Compute M @ v where M is this operator. v is a flat 1-D tensor of length self.size.

Source code in hessian_eigenthings/operators/base.py
@abstractmethod
def matvec(self, v: torch.Tensor) -> torch.Tensor:
    """Compute `M @ v` where `M` is this operator. `v` is a flat 1-D tensor of length `self.size`."""

rmatvec(v)

Compute v^T @ M. Equal to matvec because curvature operators are symmetric.

Source code in hessian_eigenthings/operators/base.py
def rmatvec(self, v: torch.Tensor) -> torch.Tensor:
    """Compute `v^T @ M`. Equal to `matvec` because curvature operators are symmetric."""
    return self.matvec(v)
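
For orientation, here is a minimal concrete subclass (not part of the library) that satisfies the abstract contract with a toy diagonal matrix and exercises the op(v) and op @ v call sugar. The import path mirrors the source location shown above and is an assumption; adjust it if the package re-exports the class elsewhere.

import torch

from hessian_eigenthings.operators.base import CurvatureOperator  # assumed import path


class DiagonalOperator(CurvatureOperator):
    """Toy operator M = diag(d), illustrating the abstract contract."""

    def __init__(self, diag: torch.Tensor) -> None:
        self._diag = diag.reshape(-1)

    @property
    def size(self) -> int:
        return self._diag.numel()

    @property
    def device(self) -> torch.device:
        return self._diag.device

    @property
    def dtype(self) -> torch.dtype:
        return self._diag.dtype

    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        return self._diag * v


op = DiagonalOperator(torch.tensor([3.0, 2.0, 1.0]))
v = torch.ones(op.size)
assert torch.equal(op @ v, op.matvec(v))   # __matmul__ delegates to matvec
assert torch.equal(op(v), op.rmatvec(v))   # symmetric: rmatvec == matvec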

LambdaOperator

Bases: CurvatureOperator

Wrap a callable as a CurvatureOperator. Useful for tests and ad-hoc operators.

Source code in hessian_eigenthings/operators/base.py
class LambdaOperator(CurvatureOperator):
    """Wrap a callable as a CurvatureOperator. Useful for tests and ad-hoc operators."""

    def __init__(
        self,
        matvec_fn: Callable[[torch.Tensor], torch.Tensor],
        size: int,
        device: torch.device | str = "cpu",
        dtype: torch.dtype = torch.float32,
    ) -> None:
        self._matvec_fn = matvec_fn
        self._size = size
        self._device = torch.device(device)
        self._dtype = dtype

    @property
    def size(self) -> int:
        return self._size

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        return self._matvec_fn(v)
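
A quick illustrative use (assumed import path; the matrix and sizes are arbitrary): wrap an explicit symmetric matrix so it can stand in for a curvature operator in tests.

import torch

from hessian_eigenthings.operators.base import LambdaOperator  # assumed import path

# Wrap an explicit symmetric matrix as a matrix-free operator (handy in tests).
A = torch.randn(10, 10)
A = 0.5 * (A + A.T)  # symmetrise so the CurvatureOperator contract holds

op = LambdaOperator(lambda v: A @ v, size=10, device="cpu", dtype=torch.float32)

v = torch.randn(10)
torch.testing.assert_close(op @ v, A @ v)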

HessianOperator

Bases: CurvatureOperator

Hessian of loss_fn(model, batch) averaged over batches in dataloader.

Two HVP methods are supported via method=:

  • "autograd" (default): exact double-backward via torch.autograd.grad with create_graph=True. Numerically exact (to rounding); ideal for single-device analysis up to ~7B parameters.

  • "finite_difference": central-difference (∇L(θ+εv) − ∇L(θ−εv)) / 2ε per Granziol & Juarev 2026. Two normal forward+backward passes per HVP, no second-backward graph anywhere — works with FSDP/HSDP/TP without any special handling. Trade-off: O(ε²) truncation bias plus precision-dependent roundoff (~1e-5 fp32, ~1e-2 bf16). Suitable for spectral analysis at scale.
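
For orientation, a hypothetical usage sketch follows. The model, dataloader, and loss function are placeholders; only the loss_fn(model, batch) -> scalar convention and the constructor keywords come from the source below.

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from hessian_eigenthings.operators.hessian import HessianOperator  # assumed import path

# Placeholder model and data; any loss_fn(model, batch) -> scalar works.
model = nn.Linear(20, 5)
data = TensorDataset(torch.randn(64, 20), torch.randint(0, 5, (64,)))
loader = DataLoader(data, batch_size=16)

def loss_fn(model: nn.Module, batch) -> torch.Tensor:
    x, y = batch
    return F.cross_entropy(model(x), y)

# Default: exact double-backward HVPs, averaged over every batch in the loader.
op = HessianOperator(model, loader, loss_fn, full_dataset=True)

# Drop-in switch to the finite-difference estimator; fd_eps falls back to a
# dtype-dependent default when not given.
op_fd = HessianOperator(model, loader, loss_fn, method="finite_difference")

v = torch.randn(op.size)
hv = op.matvec(v)  # Hessian-vector product, shape (op.size,)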

Source code in hessian_eigenthings/operators/hessian.py
class HessianOperator(CurvatureOperator):
    """Hessian of `loss_fn(model, batch)` averaged over batches in `dataloader`.

    Two HVP methods are supported via `method=`:

    * ``"autograd"`` (default): exact double-backward via `torch.autograd.grad` with
      `create_graph=True`. Numerically exact (to rounding); ideal for single-device
      analysis up to ~7B parameters.

    * ``"finite_difference"``: central-difference `(∇L(θ+εv) − ∇L(θ−εv)) / 2ε` per
      Granziol & Juarev 2026. Two normal forward+backward passes per HVP, no
      second-backward graph anywhere — works with FSDP/HSDP/TP without any
      special handling. Trade-off: O(ε²) truncation bias plus precision-dependent
      roundoff (~1e-5 fp32, ~1e-2 bf16). Suitable for spectral analysis at scale.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: Iterable[Any],
        loss_fn: LossFn,
        *,
        param_filter: ParamFilter | None = None,
        full_dataset: bool = True,
        num_batches: int | None = None,
        microbatch_size: int | None = None,
        microbatch_unsafe: bool = False,
        method: HvpMethod = "autograd",
        fd_eps: float | None = None,
        backend: LinAlgBackend[torch.Tensor] | None = None,
    ) -> None:
        self.model = model
        self.dataloader = dataloader
        self.loss_fn = loss_fn
        self.full_dataset = full_dataset
        self.num_batches = num_batches
        self.microbatch_size = microbatch_size
        self.method: HvpMethod = method
        self.backend: LinAlgBackend[torch.Tensor] = backend or SingleDeviceBackend()

        if microbatch_size is not None and not microbatch_unsafe:
            assert_microbatch_safe(model)

        self._params = select_parameters(model, param_filter)
        self._param_list = list(self._params.values())
        self._sizes = [int(p.numel()) for p in self._param_list]
        self._size = total_size(self._params)

        first = self._param_list[0]
        self._device = first.device
        self._dtype = first.dtype

        self.fd_eps = fd_eps if fd_eps is not None else _FD_EPS_BY_DTYPE.get(self._dtype, 1e-3)

        self._batch_iter: Iterator[Any] | None = None

    @property
    def size(self) -> int:
        return self._size

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        if v.shape != (self._size,):
            raise ValueError(f"expected vector of shape ({self._size},), got {tuple(v.shape)}")

        if self.full_dataset:
            return self._matvec_full(v)
        return self._matvec_one_batch(v, self._next_batch())

    def _matvec_full(self, v: torch.Tensor) -> torch.Tensor:
        total = self.backend.zeros_like(v)
        n = 0
        for batch in iterate_batches(self.dataloader, self.num_batches):
            chunk = self._matvec_one_batch(v, batch)
            total = self.backend.axpy(1.0, chunk, total)
            n += 1
        if n == 0:
            raise RuntimeError("dataloader yielded no batches")
        return self.backend.scale(1.0 / n, total)

    def _matvec_one_batch(self, v: torch.Tensor, batch: Any) -> torch.Tensor:
        batch = move_batch_to_device(batch, self._device)
        v_split = self._split(v)

        if self.microbatch_size is None:
            return self._hvp(v_split, batch)
        return self._hvp_microbatched(v_split, batch)

    def _hvp(self, v_split: list[torch.Tensor], batch: Any) -> torch.Tensor:
        if self.method == "autograd":
            return self._hvp_autograd(v_split, batch)
        if self.method == "finite_difference":
            return self._hvp_finite_difference(v_split, batch)
        raise ValueError(f"unknown method={self.method!r}")  # pragma: no cover

    def _hvp_autograd(self, v_split: list[torch.Tensor], batch: Any) -> torch.Tensor:
        loss = self.loss_fn(self.model, batch)
        grads = torch.autograd.grad(loss, self._param_list, create_graph=True)
        hvp = torch.autograd.grad(grads, self._param_list, grad_outputs=v_split)
        return torch.cat([h.reshape(-1) for h in hvp])

    def _hvp_finite_difference(self, v_split: list[torch.Tensor], batch: Any) -> torch.Tensor:
        eps = self.fd_eps
        snapshot = [p.detach().clone() for p in self._param_list]
        try:
            with torch.no_grad():
                for p, dv in zip(self._param_list, v_split, strict=True):
                    p.add_(dv, alpha=eps)
            g_plus = self._compute_grad_flat(batch)

            with torch.no_grad():
                for p, dv in zip(self._param_list, v_split, strict=True):
                    p.add_(dv, alpha=-2.0 * eps)
            g_minus = self._compute_grad_flat(batch)
        finally:
            with torch.no_grad():
                for p, snap in zip(self._param_list, snapshot, strict=True):
                    p.copy_(snap)

        return (g_plus - g_minus) / (2.0 * eps)

    def _compute_grad_flat(self, batch: Any) -> torch.Tensor:
        loss = self.loss_fn(self.model, batch)
        grads = torch.autograd.grad(loss, self._param_list)
        return torch.cat([g.reshape(-1).detach() for g in grads])

    def _hvp_microbatched(self, v_split: list[torch.Tensor], batch: Any) -> torch.Tensor:
        assert self.microbatch_size is not None
        chunks = _split_batch(batch, self.microbatch_size)
        if not chunks:
            raise RuntimeError("microbatching produced no chunks")
        total: torch.Tensor | None = None
        for chunk in chunks:
            hvp = self._hvp(v_split, chunk)
            total = hvp if total is None else total + hvp
        assert total is not None
        return total / len(chunks)

    def _split(self, v: torch.Tensor) -> list[torch.Tensor]:
        out: list[torch.Tensor] = []
        offset = 0
        for n, p in zip(self._sizes, self._param_list, strict=True):
            out.append(v[offset : offset + n].reshape_as(p))
            offset += n
        return out

    def _next_batch(self) -> Any:
        if self._batch_iter is None:
            self._batch_iter = iter(self.dataloader)
        try:
            return next(self._batch_iter)
        except StopIteration:
            self._batch_iter = iter(self.dataloader)
            return next(self._batch_iter)

GGNOperator

Bases: CurvatureOperator

Generalized Gauss-Newton matrix G = J^T H_loss J.

For convex per-sample losses (cross-entropy + softmax, MSE) H_loss is PSD so G is PSD by construction. For cross-entropy + softmax classification, G equals the Fisher information matrix.

The two-function API (forward_fn returns the model output, loss_of_output_fn converts that output + batch into a scalar loss) lets us compute J v, the loss-Hessian-vector product H_loss · (Jv), and J^T · (H_loss · Jv) without coupling to the loss internals.

Two implementations of the matvec are available via loss_hvp=:

  • "analytical" (default): finite-difference JVP + analytical loss-Hessian-vec product (read from loss_of_output_fn.hvp, which must be present) + a single normal backward to apply J^T. Memory footprint matches one normal training step. Required for LM-scale use; see the OOM diagnostic in scripts/repro_ggn_oom.py.

  • "autograd": the original torch.func.jvp + autograd double-backward + torch.func.vjp path. Numerically exact and supports any loss, but memory scales badly with output size — for cross-entropy heads with large vocab the create_graph=True step alone can dominate. Kept as a fallback for losses without an analytical .hvp.
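
A hypothetical sketch of the two-function API (placeholder model and data). It assumes cross_entropy_loss_of_output() from hessian_eigenthings.loss_fns, named in the source below, is a zero-argument factory whose result carries the analytical .hvp, and that batches are (inputs, targets) tuples; if the factory takes arguments in the actual library, adjust accordingly.

import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset

from hessian_eigenthings.loss_fns import cross_entropy_loss_of_output  # assumed factory signature
from hessian_eigenthings.operators.ggn import GGNOperator  # assumed import path

model = nn.Linear(20, 5)
data = TensorDataset(torch.randn(64, 20), torch.randint(0, 5, (64,)))
loader = DataLoader(data, batch_size=16)

def forward_fn(model: nn.Module, batch) -> torch.Tensor:
    x, _ = batch
    return model(x)  # raw logits; the loss is applied separately

# Carries an analytical .hvp, so the memory-lean default path can be used.
loss_of_output = cross_entropy_loss_of_output()

op = GGNOperator(model, loader, forward_fn, loss_of_output, loss_hvp="analytical")

v = torch.randn(op.size)
gv = op.matvec(v)  # G v = J^T H_loss J v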

Source code in hessian_eigenthings/operators/ggn.py
class GGNOperator(CurvatureOperator):
    """Generalized Gauss-Newton matrix `G = J^T H_loss J`.

    For convex per-sample losses (cross-entropy + softmax, MSE) `H_loss` is PSD so
    `G` is PSD by construction. For cross-entropy + softmax classification, `G`
    equals the Fisher information matrix.

    The two-function API (`forward_fn` returns the model output, `loss_of_output_fn`
    converts that output + batch into a scalar loss) lets us compute `J v`, the
    loss-Hessian-vector product `H_loss · (Jv)`, and `J^T · (H_loss · Jv)` without
    coupling to the loss internals.

    Two implementations of the matvec are available via `loss_hvp=`:

    * ``"analytical"`` (default): finite-difference JVP + analytical loss-Hessian-vec
      product (read from `loss_of_output_fn.hvp`, which must be present) + a single
      normal backward to apply `J^T`. Memory footprint matches one normal training
      step. Required for LM-scale use; see the OOM diagnostic in
      `scripts/repro_ggn_oom.py`.

    * ``"autograd"``: the original `torch.func.jvp` + autograd double-backward +
      `torch.func.vjp` path. Numerically exact and supports any loss, but memory
      scales badly with output size — for cross-entropy heads with large vocab the
      `create_graph=True` step alone can dominate. Kept as a fallback for losses
      without an analytical `.hvp`.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: Iterable[Any],
        forward_fn: ForwardFn,
        loss_of_output_fn: LossOfOutputFn,
        *,
        param_filter: ParamFilter | None = None,
        full_dataset: bool = True,
        num_batches: int | None = None,
        loss_hvp: LossHvpMethod = "analytical",
        fd_eps: float | None = None,
        backend: LinAlgBackend[torch.Tensor] | None = None,
    ) -> None:
        self.model = model
        self.dataloader = dataloader
        self.forward_fn = forward_fn
        self.loss_of_output_fn = loss_of_output_fn
        self.full_dataset = full_dataset
        self.num_batches = num_batches
        self.loss_hvp: LossHvpMethod = loss_hvp
        self.backend: LinAlgBackend[torch.Tensor] = backend or SingleDeviceBackend()

        if loss_hvp not in ("analytical", "autograd"):
            raise ValueError(f"loss_hvp={loss_hvp!r} not in ('analytical', 'autograd')")
        if loss_hvp == "analytical" and not hasattr(loss_of_output_fn, "hvp"):
            raise ValueError(
                "loss_hvp='analytical' requires `loss_of_output_fn.hvp(output, "
                "batch, u)` to be defined (use `cross_entropy_loss_of_output()` "
                "or `mse_loss_of_output()` from `hessian_eigenthings.loss_fns`, "
                "or wrap your callable with `_LossOfOutputWithHvp`). Pass "
                "loss_hvp='autograd' to fall back to the double-backward path."
            )

        self._params = select_parameters(model, param_filter)
        self._param_names = list(self._params)
        self._param_list = list(self._params.values())
        self._sizes = [int(p.numel()) for p in self._param_list]
        self._size = total_size(self._params)
        self._fixed_params = {n: p for n, p in model.named_parameters() if n not in self._params}
        self._buffers = dict(model.named_buffers())

        first = self._param_list[0]
        self._device = first.device
        self._dtype = first.dtype

        self.fd_eps = fd_eps if fd_eps is not None else _FD_EPS_BY_DTYPE.get(self._dtype, 1e-3)

    @property
    def size(self) -> int:
        return self._size

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        if v.shape != (self._size,):
            raise ValueError(f"expected vector of shape ({self._size},), got {tuple(v.shape)}")
        if self.full_dataset:
            return self._matvec_full(v)
        return self._matvec_one_batch(v, next(iter(self.dataloader)))

    def _matvec_full(self, v: torch.Tensor) -> torch.Tensor:
        total = self.backend.zeros_like(v)
        n = 0
        for batch in iterate_batches(self.dataloader, self.num_batches):
            chunk = self._matvec_one_batch(v, batch)
            total = self.backend.axpy(1.0, chunk, total)
            n += 1
        if n == 0:
            raise RuntimeError("dataloader yielded no batches")
        return self.backend.scale(1.0 / n, total)

    def _matvec_one_batch(self, v: torch.Tensor, batch: Any) -> torch.Tensor:
        batch = move_batch_to_device(batch, self._device)
        if self.loss_hvp == "analytical":
            return self._matvec_fd_jvp(v, batch)
        return self._matvec_autograd(v, batch)

    # --- analytical (FD JVP + analytical H_loss + single backward) ----------

    def _matvec_fd_jvp(self, v: torch.Tensor, batch: Any) -> torch.Tensor:
        """`fd_jvp_single_vjp`: 2 no-grad forwards (FD JVP) + analytical loss-HVP +
        1 grad-enabled forward+backward to apply `J^T`. Memory peaks at one
        normal training step.
        """
        v_split = self._split(v)

        # Normalise v internally so eps * ||v|| can't underflow on tiny v.
        # We compute matvec(v / s) and then multiply by s — `G` is linear in v.
        v_norm = float(torch.linalg.vector_norm(v).item())
        scale = max(v_norm, _V_NORM_FLOOR)
        if scale != 1.0:
            v_split = [vs / scale for vs in v_split]

        eps = self.fd_eps
        snapshot = [p.detach().clone() for p in self._param_list]
        try:
            with torch.no_grad():
                self._add_inplace(v_split, +eps)
                out_plus = self.forward_fn(self.model, batch).detach().clone()
                self._add_inplace(v_split, -2.0 * eps)
                out_minus = self.forward_fn(self.model, batch).detach().clone()
        finally:
            with torch.no_grad():
                for p, snap in zip(self._param_list, snapshot, strict=True):
                    p.copy_(snap)
        del snapshot

        jvp_out = (out_plus - out_minus) / (2.0 * eps)
        del out_plus, out_minus

        # Single grad-enabled forward; we'll reuse `logits` both for the
        # analytical loss-HVP and as the source for `J^T h_loss_jvp`.
        logits = self.forward_fn(self.model, batch)
        # `.hvp` is guaranteed to exist for loss_hvp=="analytical" — checked in __init__.
        h_loss_jvp = self.loss_of_output_fn.hvp(logits.detach(), batch, jvp_out)  # type: ignore[attr-defined]
        del jvp_out

        grads = torch.autograd.grad(logits, self._param_list, grad_outputs=h_loss_jvp)
        result = torch.cat([g.reshape(-1) for g in grads])
        if scale != 1.0:
            result = result * scale
        return result

    def _add_inplace(self, v_split: list[torch.Tensor], alpha: float) -> None:
        for p, dv in zip(self._param_list, v_split, strict=True):
            p.add_(dv, alpha=alpha)

    def _split(self, v: torch.Tensor) -> list[torch.Tensor]:
        out: list[torch.Tensor] = []
        offset = 0
        for n, p in zip(self._sizes, self._param_list, strict=True):
            out.append(v[offset : offset + n].reshape_as(p))
            offset += n
        return out

    # --- autograd fallback (original implementation) ------------------------

    def _matvec_autograd(self, v: torch.Tensor, batch: Any) -> torch.Tensor:
        v_dict = self._unflatten(v)
        param_dict = dict(zip(self._param_names, self._param_list, strict=True))

        def model_call(p_subset: dict[str, torch.Tensor]) -> torch.Tensor:
            full = {**self._fixed_params, **p_subset, **self._buffers}
            adapter = _FunctionalModel(self.model, full)
            return self.forward_fn(adapter, batch)  # type: ignore[arg-type]

        jvp_result = cast(
            tuple[torch.Tensor, torch.Tensor],
            torch.func.jvp(model_call, (param_dict,), (v_dict,)),
        )
        output, jvp_out = jvp_result

        output_leaf = output.detach().requires_grad_(True)
        loss = self.loss_of_output_fn(output_leaf, batch)
        grad_loss = torch.autograd.grad(loss, output_leaf, create_graph=True)[0]
        h_loss_jvp = torch.autograd.grad(grad_loss, output_leaf, grad_outputs=jvp_out)[0]

        vjp_result = cast(
            tuple[torch.Tensor, Callable[[torch.Tensor], tuple[dict[str, torch.Tensor]]]],
            torch.func.vjp(model_call, param_dict),
        )
        _, vjp_fn = vjp_result
        gv_dict = vjp_fn(h_loss_jvp)[0]

        return torch.cat([gv_dict[n].reshape(-1) for n in self._param_names])

    def _unflatten(self, v: torch.Tensor) -> dict[str, torch.Tensor]:
        out: dict[str, torch.Tensor] = {}
        offset = 0
        for n, p, sz in zip(self._param_names, self._param_list, self._sizes, strict=True):
            out[n] = v[offset : offset + sz].reshape_as(p)
            offset += sz
        return out

EmpiricalFisherOperator

Bases: CurvatureOperator

Empirical Fisher F = (1/N) Σ_i g_i g_i^T where g_i = ∂loss_i/∂θ are per-sample grads.

Empirical Fisher uses the true labels in the loss (unlike the MC Fisher which samples labels from the model's predictive distribution), and is therefore a biased estimator of the actual Fisher information. Conflating the two is the classic GGN-vs-Fisher-vs-empirical-Fisher pitfall — see Martens 2014.

Per-sample gradients are computed in one pass via torch.func.vmap(grad(...)), so the cost is one forward+backward per batch, not per sample.

The per_sample_loss_fn(model, sample) -> Tensor takes a single (un-batched) sample. The sample_dim argument tells the operator which axis to vmap over when receiving a batch from the dataloader.
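
A hypothetical usage sketch (placeholder model and data, assuming batches are (inputs, targets) tuples with samples along dim 0; per_sample_loss_fn receives one un-batched sample, as described above).

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset

from hessian_eigenthings.operators.fisher import EmpiricalFisherOperator  # assumed import path

model = nn.Linear(20, 5)
data = TensorDataset(torch.randn(64, 20), torch.randint(0, 5, (64,)))
loader = DataLoader(data, batch_size=16)

def per_sample_loss_fn(model: nn.Module, sample) -> torch.Tensor:
    x, y = sample  # one un-batched example: x has shape (20,), y is a scalar label
    # Add a batch dim of 1 so the usual batched ops apply under vmap.
    return F.cross_entropy(model(x.unsqueeze(0)), y.unsqueeze(0))

op = EmpiricalFisherOperator(model, loader, per_sample_loss_fn, sample_dim=0)

v = torch.randn(op.size)
fv = op.matvec(v)  # F v = (1/N) G^T (G v)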

Source code in hessian_eigenthings/operators/fisher.py
class EmpiricalFisherOperator(CurvatureOperator):
    """Empirical Fisher `F = (1/N) Σ_i g_i g_i^T` where `g_i = ∂loss_i/∂θ` are per-sample grads.

    Empirical Fisher uses the *true* labels in the loss (unlike the MC Fisher which
    samples labels from the model's predictive distribution), and is therefore a
    biased estimator of the actual Fisher information. Conflating the two is the
    classic GGN-vs-Fisher-vs-empirical-Fisher pitfall — see Martens 2014.

    Per-sample gradients are computed in one pass via `torch.func.vmap(grad(...))`,
    so the cost is one forward+backward per batch, not per sample.

    The `per_sample_loss_fn(model, sample) -> Tensor` takes a single (un-batched)
    sample. The `sample_dim` argument tells the operator which axis to vmap over
    when receiving a batch from the dataloader.
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: Iterable[Any],
        per_sample_loss_fn: PerSampleLossFn,
        *,
        param_filter: ParamFilter | None = None,
        full_dataset: bool = True,
        num_batches: int | None = None,
        sample_dim: int = 0,
        backend: LinAlgBackend[torch.Tensor] | None = None,
    ) -> None:
        self.model = model
        self.dataloader = dataloader
        self.per_sample_loss_fn = per_sample_loss_fn
        self.full_dataset = full_dataset
        self.num_batches = num_batches
        self.sample_dim = sample_dim
        self.backend: LinAlgBackend[torch.Tensor] = backend or SingleDeviceBackend()

        self._params = select_parameters(model, param_filter)
        self._param_names = list(self._params)
        self._param_list = list(self._params.values())
        self._sizes = [int(p.numel()) for p in self._param_list]
        self._size = total_size(self._params)
        self._fixed_params = {n: p for n, p in model.named_parameters() if n not in self._params}
        self._buffers = dict(model.named_buffers())

        first = self._param_list[0]
        self._device = first.device
        self._dtype = first.dtype

    @property
    def size(self) -> int:
        return self._size

    @property
    def device(self) -> torch.device:
        return self._device

    @property
    def dtype(self) -> torch.dtype:
        return self._dtype

    def matvec(self, v: torch.Tensor) -> torch.Tensor:
        if v.shape != (self._size,):
            raise ValueError(f"expected vector of shape ({self._size},), got {tuple(v.shape)}")
        if self.full_dataset:
            return self._matvec_full(v)
        return self._matvec_one_batch(v, next(iter(self.dataloader)))

    def _matvec_full(self, v: torch.Tensor) -> torch.Tensor:
        total = self.backend.zeros_like(v)
        n = 0
        for batch in iterate_batches(self.dataloader, self.num_batches):
            chunk = self._matvec_one_batch(v, batch)
            total = self.backend.axpy(1.0, chunk, total)
            n += 1
        if n == 0:
            raise RuntimeError("dataloader yielded no batches")
        return self.backend.scale(1.0 / n, total)

    def _matvec_one_batch(self, v: torch.Tensor, batch: Any) -> torch.Tensor:
        batch = move_batch_to_device(batch, self._device)
        per_sample_grads = self._per_sample_grads(batch)
        # Stack into (n_samples, n_params).
        grads_mat = torch.cat(
            [
                per_sample_grads[n].reshape(per_sample_grads[n].shape[0], -1)
                for n in self._param_names
            ],
            dim=1,
        )
        n_samples = grads_mat.shape[0]
        # F v = (1/N) G^T (G v).
        return (grads_mat.t() @ (grads_mat @ v)) / n_samples

    def _per_sample_grads(self, batch: Any) -> Mapping[str, torch.Tensor]:
        param_dict: dict[str, torch.Tensor] = dict(
            zip(self._param_names, self._param_list, strict=True)
        )

        def loss_at(p_subset: dict[str, torch.Tensor], single: Any) -> torch.Tensor:
            full = {**self._fixed_params, **p_subset, **self._buffers}
            adapter = cast(nn.Module, _FunctionalModel(self.model, full))
            return self.per_sample_loss_fn(adapter, single)

        grad_fn = torch.func.grad(loss_at, argnums=0)
        in_dims = (None, _broadcast_dim(batch, self.sample_dim))
        result = torch.func.vmap(grad_fn, in_dims=in_dims)(param_dict, batch)
        return cast(Mapping[str, torch.Tensor], result)

DDPHessianOperator

Bases: HessianOperator

HessianOperator that all-reduces the HVP across torch.distributed ranks.

The model passed in may already be wrapped with torch.nn.parallel.DistributedDataParallel; we read params from it directly. Each rank should be receiving its own shard of the dataset (typical pattern: a torch.utils.data.distributed.DistributedSampler).
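
A hypothetical multi-GPU sketch (assumes a torchrun launch that sets the environment variables init_process_group needs; the model, data, and device selection are placeholders).

import torch
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler

from hessian_eigenthings.operators.distributed.ddp import DDPHessianOperator  # assumed import path

dist.init_process_group("nccl")  # env:// rendezvous from torchrun
rank = dist.get_rank()
device = torch.device(f"cuda:{rank % torch.cuda.device_count()}")  # single-node placeholder

model = DDP(nn.Linear(20, 5).to(device), device_ids=[device.index])
data = TensorDataset(torch.randn(256, 20), torch.randint(0, 5, (256,)))
loader = DataLoader(data, batch_size=32, sampler=DistributedSampler(data))

def loss_fn(model: nn.Module, batch) -> torch.Tensor:
    x, y = batch
    return F.cross_entropy(model(x), y)

op = DDPHessianOperator(model, loader, loss_fn, method="finite_difference")

v = torch.randn(op.size, device=device)
hv = op.matvec(v)  # local HVP, mean all-reduced across ranks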

Source code in hessian_eigenthings/operators/distributed/ddp.py
class DDPHessianOperator(HessianOperator):
    """HessianOperator that all-reduces the HVP across `torch.distributed` ranks.

    The model passed in may already be wrapped with
    `torch.nn.parallel.DistributedDataParallel`; we read params from it directly.
    Each rank should be receiving its own shard of the dataset (typical pattern: a
    `torch.utils.data.distributed.DistributedSampler`).
    """

    def __init__(
        self,
        model: nn.Module,
        dataloader: Iterable[Any],
        loss_fn: LossFn,
        *,
        param_filter: ParamFilter | None = None,
        full_dataset: bool = True,
        num_batches: int | None = None,
        method: HvpMethod = "autograd",
        fd_eps: float | None = None,
        backend: LinAlgBackend[torch.Tensor] | None = None,
        process_group: dist.ProcessGroup | None = None,
    ) -> None:
        super().__init__(
            model=model,
            dataloader=dataloader,
            loss_fn=loss_fn,
            param_filter=param_filter,
            full_dataset=full_dataset,
            num_batches=num_batches,
            method=method,
            fd_eps=fd_eps,
            backend=backend,
        )
        self.process_group = process_group
        if dist.is_available() and dist.is_initialized():
            self._world_size = dist.get_world_size(group=process_group)
        else:
            self._world_size = 1

    def _hvp_autograd(self, v_split: list[torch.Tensor], batch: Any) -> torch.Tensor:
        loss = self.loss_fn(self.model, batch)
        grads = torch.autograd.grad(loss, self._param_list, create_graph=True)
        if self._world_size > 1:
            grads = tuple(self._all_reduce_mean(g) for g in grads)
        hvp = torch.autograd.grad(grads, self._param_list, grad_outputs=v_split)
        if self._world_size > 1:
            hvp = tuple(self._all_reduce_mean(h) for h in hvp)
        return torch.cat([h.reshape(-1) for h in hvp])

    def _hvp_finite_difference(self, v_split: list[torch.Tensor], batch: Any) -> torch.Tensor:
        # Each rank's _compute_grad_flat returns its local gradient, so the
        # central difference computed below is a per-rank HVP; all-reducing that
        # difference is equivalent (by linearity) to all-reducing g+ and g-.
        hvp_local = super()._hvp_finite_difference(v_split, batch)
        if self._world_size > 1:
            hvp_local = self._all_reduce_mean(hvp_local)
        return hvp_local

    def _all_reduce_mean(self, t: torch.Tensor) -> torch.Tensor:
        reduced: torch.Tensor = dist_nn.all_reduce(  # type: ignore[no-untyped-call]
            t, op=dist.ReduceOp.SUM, group=self.process_group
        )
        return reduced / self._world_size