Skip to content

Loss functions

Standard supervised

supervised_loss(criterion)

Make a loss_fn(model, batch) for HessianOperator from an (input, target) criterion.

Source code in hessian_eigenthings/loss_fns/standard.py
def supervised_loss(
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Callable[[nn.Module, Any], torch.Tensor]:
    """Build a `loss_fn(model, batch)` for `HessianOperator` from an (input, target) criterion.

    The returned callable unpacks `batch` into `(inputs, targets)`, runs the model on
    `inputs`, and scores the predictions with `criterion(predictions, targets)`.
    """

    def _fn(model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        inputs, targets = batch
        predictions = model(inputs)
        return criterion(predictions, targets)

    return _fn

supervised_forward(model, batch)

The forward_fn for GGNOperator on an (input, target) batch: returns model(input).

Source code in hessian_eigenthings/loss_fns/standard.py
def supervised_forward(model: nn.Module, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
    """The `forward_fn` for `GGNOperator` on an (input, target) batch.

    The batch is unpacked as a 2-tuple. The target is ignored, and the raw
    `model(input)` output is returned.
    """
    inputs, _target = batch
    prediction: torch.Tensor = model(inputs)
    return prediction

supervised_loss_of_output(criterion)

Make a loss_of_output_fn for GGNOperator from an (output, target) criterion.

Source code in hessian_eigenthings/loss_fns/standard.py
def supervised_loss_of_output(
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Callable[[torch.Tensor, tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
    """Build a `loss_of_output_fn` for `GGNOperator` from an (output, target) criterion.

    The returned callable ignores the input half of `batch`. It applies `criterion`
    to the precomputed model `output` and the batch's target.
    """

    def _fn(output: torch.Tensor, batch: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        _inputs, targets = batch
        return criterion(output, targets)

    return _fn

supervised_per_sample_loss(criterion)

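Make a per_sample_loss_fn for EmpiricalFisherOperator. Each sample arrives unbatched because vmap strips the batch dimension. The model is run on it with a temporary leading batch dimension of size 1, and the criterion receives the squeezed (unbatched) output together with the sample's target.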

Source code in hessian_eigenthings/loss_fns/standard.py
def supervised_per_sample_loss(
    criterion: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
) -> Callable[[nn.Module, tuple[torch.Tensor, torch.Tensor]], torch.Tensor]:
    """Build a `per_sample_loss_fn` for `EmpiricalFisherOperator`.

    Under `vmap` each call sees one un-batched `(input, target)` sample. The model
    is run on the input with a leading batch dimension of size 1 re-added. That
    dimension is squeezed off the output before `criterion(output, target)` is applied.
    """

    def _fn(model: nn.Module, sample: tuple[torch.Tensor, torch.Tensor]) -> torch.Tensor:
        single_input, single_target = sample
        # Re-add the batch dim that vmap stripped, so the model sees a batch of one.
        batched_output = model(single_input.unsqueeze(0))
        return criterion(batched_output.squeeze(0), single_target)

    return _fn

HuggingFace Transformers

hf_lm_loss()

For autoregressive LMs: loss_fn(model, batch) calls model(**batch).loss.

The batch must include labels so HF computes the loss internally; for causal LMs that's typically labels=input_ids (with the standard internal shift).

Source code in hessian_eigenthings/loss_fns/huggingface.py
def hf_lm_loss() -> Callable[[nn.Module, dict[str, Any]], torch.Tensor]:
    """For autoregressive LMs: `loss_fn(model, batch)` calls `model(**batch).loss`.

    The batch must include `labels` so HF computes the loss internally; for causal LMs
    that's typically `labels=input_ids` (with the standard internal shift).

    Raises:
        ValueError: raised by the returned callable when the model output's `.loss`
            is `None`, typically because no `labels` were passed. This fails fast
            with a clear message, rather than letting `None` surface later as an
            obscure error when the loss is differentiated.
    """

    def _fn(model: nn.Module, batch: dict[str, Any]) -> torch.Tensor:
        out = model(**batch)
        loss = out.loss
        if loss is None:
            raise ValueError(
                "model(**batch).loss is None; include 'labels' in the batch so the "
                "model computes its loss internally (e.g. labels=input_ids for causal LMs)."
            )
        return loss  # type: ignore[no-any-return]

    return _fn

hf_seq2seq_loss()

For seq2seq models (e.g. T5/BART) that compute the decoder cross-entropy internally.

Source code in hessian_eigenthings/loss_fns/huggingface.py
def hf_seq2seq_loss() -> Callable[[nn.Module, dict[str, Any]], torch.Tensor]:
    """For seq2seq models (e.g. T5/BART) that compute the decoder cross-entropy internally.

    This is the same callable as `hf_lm_loss()`: it returns the model's own `.loss`,
    so the batch must include `labels`.
    """
    seq2seq_loss_fn = hf_lm_loss()
    return seq2seq_loss_fn

hf_lm_forward()

forward_fn for GGNOperator on an HF causal LM: returns logits (no loss).

Source code in hessian_eigenthings/loss_fns/huggingface.py
def hf_lm_forward() -> Callable[[nn.Module, dict[str, Any]], torch.Tensor]:
    """`forward_fn` for `GGNOperator` on an HF causal LM: returns `logits` (no loss).

    `labels` is filtered out of a copy of the batch, so the model is not asked to
    compute a loss. The caller's batch is left untouched.
    """

    def _fn(model: nn.Module, batch: dict[str, Any]) -> torch.Tensor:
        model_inputs = {name: value for name, value in batch.items() if name != "labels"}
        outputs = model(**model_inputs)
        return outputs.logits  # type: ignore[no-any-return]

    return _fn

hf_lm_loss_of_output()

loss_of_output_fn for GGNOperator on an HF causal LM: standard shifted CE on logits and labels.

The returned callable carries a .hvp(output, batch, u) method holding the closed-form loss-Hessian-vector product for mean-reduced cross-entropy with softmax: H @ u = (p * u - p * (p · u)) / n per non-ignored position, where p = softmax(logits) and n is the count of non-ignored positions. GGNOperator picks this up automatically and skips the autograd double-backward.

Source code in hessian_eigenthings/loss_fns/huggingface.py
def hf_lm_loss_of_output() -> Callable[[torch.Tensor, dict[str, Any]], torch.Tensor]:
    """`loss_of_output_fn` for `GGNOperator` on an HF causal LM.

    Evaluates the standard shifted cross-entropy on `logits` and `labels`. The
    returned callable also exposes a `.hvp(output, batch, u)` method. That method
    gives the closed-form loss-Hessian-vector product for mean-reduced
    cross-entropy with softmax, `H @ u = (p * u - p * (p · u)) / n` per
    non-ignored position, where `p = softmax(logits)` and `n` is the number of
    non-ignored positions. `GGNOperator` uses this automatically instead of an
    autograd double-backward.
    """
    loss_with_closed_form_hvp = _LossOfOutputWithHvp(_hf_lm_shifted_ce, _hf_lm_ce_hvp)
    return loss_with_closed_form_hvp

TransformerLens

tlens_loss()

For TLens HookedTransformer: loss_fn(model, tokens) = model(tokens, return_type='loss').

Source code in hessian_eigenthings/loss_fns/transformer_lens.py
def tlens_loss() -> Callable[[nn.Module, Any], torch.Tensor]:
    """For TLens HookedTransformer: `loss_fn(model, tokens) = model(tokens, return_type='loss')`.

    `batch` may be the tokens themselves or any mapping (`dict`, `UserDict`, other
    `collections.abc.Mapping` types) holding them under the key `"tokens"`.
    """
    from collections.abc import Mapping

    def _fn(model: nn.Module, batch: Any) -> torch.Tensor:
        # Check against Mapping, not just dict, so dict-like batch containers also work.
        tokens = batch["tokens"] if isinstance(batch, Mapping) else batch
        return model(tokens, return_type="loss")  # type: ignore[no-any-return]

    return _fn

tlens_forward()

forward_fn for GGNOperator: returns the model's logits.

Source code in hessian_eigenthings/loss_fns/transformer_lens.py
def tlens_forward() -> Callable[[nn.Module, Any], torch.Tensor]:
    """`forward_fn` for `GGNOperator`: returns the model's logits.

    `batch` may be the tokens themselves or any mapping (`dict`, `UserDict`, other
    `collections.abc.Mapping` types) holding them under the key `"tokens"`.
    """
    from collections.abc import Mapping

    def _fn(model: nn.Module, batch: Any) -> torch.Tensor:
        # Check against Mapping, not just dict, so dict-like batch containers also work.
        tokens = batch["tokens"] if isinstance(batch, Mapping) else batch
        return model(tokens, return_type="logits")  # type: ignore[no-any-return]

    return _fn

tlens_loss_of_output()

loss_of_output_fn for GGNOperator: shifted CE on the TLens logits/tokens.

Source code in hessian_eigenthings/loss_fns/transformer_lens.py
def tlens_loss_of_output() -> Callable[[torch.Tensor, Any], torch.Tensor]:
    """`loss_of_output_fn` for `GGNOperator`: shifted CE on the TLens logits/tokens.

    The logits at position `t` are scored against the token at position `t + 1`:
    the last position's logits and the first token are dropped, and the
    mean-reduced `cross_entropy` is taken over all remaining positions.
    `batch` may be the tokens themselves or any mapping (`dict`, `UserDict`, other
    `collections.abc.Mapping` types) holding them under the key `"tokens"`.
    """
    from collections.abc import Mapping

    def _fn(logits: torch.Tensor, batch: Any) -> torch.Tensor:
        # Check against Mapping, not just dict, so dict-like batch containers also work.
        tokens = batch["tokens"] if isinstance(batch, Mapping) else batch
        shift_logits = logits[..., :-1, :]
        shift_tokens = tokens[..., 1:]
        # reshape handles the non-contiguous slices directly (copying only if needed).
        return torch.nn.functional.cross_entropy(
            shift_logits.reshape(-1, shift_logits.size(-1)),
            shift_tokens.reshape(-1),
        )

    return _fn