Custom loss functions

Each operator takes a loss-function callable. The exact signature differs slightly by operator because GGN and empirical Fisher need to know more than just "the loss".

HessianOperator: loss_fn(model, batch) -> Tensor

The simplest case. Whatever closure you'd use to compute the loss in your training loop.

import torch
from hessian_eigenthings.operators import HessianOperator

def loss_fn(model, batch):
    inputs, targets = batch
    logits = model(inputs)
    return torch.nn.functional.cross_entropy(logits, targets)

op = HessianOperator(model=model, dataloader=loader, loss_fn=loss_fn)

For common patterns we ship helpers:

from hessian_eigenthings.loss_fns import supervised_loss, hf_lm_loss, tlens_loss

# Equivalent to the loss_fn above
op = HessianOperator(
    model=model,
    dataloader=loader,
    loss_fn=supervised_loss(torch.nn.functional.cross_entropy),
)

GGNOperator: split into forward_fn + loss_of_output_fn

GGN computes \(G v = J^\top H_\ell J v\), where \(J\) is the Jacobian of the model output with respect to the parameters and \(H_\ell\) is the Hessian of the loss with respect to that output. This factorization needs the model output and the loss-of-output separately; we can't extract them from a single closure efficiently.

from hessian_eigenthings.operators import GGNOperator

def forward_fn(model, batch):
    inputs, _ = batch
    return model(inputs)

def loss_of_output_fn(output, batch):
    _, targets = batch
    return torch.nn.functional.cross_entropy(output, targets)

op = GGNOperator(
    model=model,
    dataloader=loader,
    forward_fn=forward_fn,
    loss_of_output_fn=loss_of_output_fn,
)

Or use the helpers:

from hessian_eigenthings.loss_fns import supervised_forward, supervised_loss_of_output

op = GGNOperator(
    model=model,
    dataloader=loader,
    forward_fn=supervised_forward,
    loss_of_output_fn=supervised_loss_of_output(torch.nn.functional.cross_entropy),
)
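
To make the factorization concrete, the product \(G v\) can be assembled from the two callables with torch.func primitives: a JVP through the model, a Hessian-vector product of the loss with respect to the output, then a VJP back to parameter space. The sketch below illustrates the math for the plain supervised classification case and uses functional_call; it is a sketch only, not the GGNOperator implementation.

import torch
from torch.func import functional_call, grad, jvp, vjp

def ggn_vector_product(model, inputs, targets, v):
    # Sketch of G v = J^T H_ell (J v) for one supervised batch.
    # `v` is a dict of tensors with the same structure as model.named_parameters().
    params = dict(model.named_parameters())

    def output_of_params(p):
        # forward_fn role: model output as a function of the parameters only
        return functional_call(model, p, (inputs,))

    def loss_of_output(out):
        # loss_of_output_fn role: loss as a function of the output alone
        return torch.nn.functional.cross_entropy(out, targets)

    # 1) u = J v: directional derivative of the output along v
    output, Jv = jvp(output_of_params, (params,), (v,))

    # 2) w = H_ell u: Hessian-vector product of the loss w.r.t. the output
    _, Hl_Jv = jvp(grad(loss_of_output), (output,), (Jv,))

    # 3) J^T w: pull the result back to parameter space
    _, pullback = vjp(output_of_params, params)
    (Gv,) = pullback(Hl_Jv)
    return Gv  # dict with the same structure as `params`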

EmpiricalFisherOperator: per_sample_loss_fn(model, sample) -> Tensor

Empirical Fisher needs per-sample gradients. Provide a function that takes a single un-batched sample:

from hessian_eigenthings.operators import EmpiricalFisherOperator

def per_sample_loss(model, sample):
    x, y = sample
    return torch.nn.functional.cross_entropy(
        model(x.unsqueeze(0)), y.unsqueeze(0)
    )

op = EmpiricalFisherOperator(
    model=model,
    dataloader=loader,
    per_sample_loss_fn=per_sample_loss,
    sample_dim=0,    # which axis of the batch tensors is the sample axis
)

The operator uses torch.func.vmap(torch.func.grad(...)) to vectorize over the batch, so per-sample gradients are computed in one batched pass rather than a Python loop.
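
For intuition, the vectorized computation looks roughly like the standard torch.func per-sample-gradients recipe below; `model`, `inputs`, and `targets` are placeholders, and the operator's internals may differ in detail.

import torch
from torch.func import functional_call, grad, vmap

params = dict(model.named_parameters())
buffers = dict(model.named_buffers())

def sample_loss(p, b, x, y):
    # per_sample_loss_fn role: loss of a single (x, y) pair under parameters p
    logits = functional_call(model, (p, b), (x.unsqueeze(0),))
    return torch.nn.functional.cross_entropy(logits, y.unsqueeze(0))

# grad(...) differentiates w.r.t. the first argument (the parameters);
# vmap maps that gradient over the sample axis of the batch in one pass.
per_sample_grads = vmap(grad(sample_loss), in_dims=(None, None, 0, 0))(
    params, buffers, inputs, targets
)
# per_sample_grads[name] has shape (batch_size, *params[name].shape)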

HuggingFace and TransformerLens

For HuggingFace causal LMs:

from hessian_eigenthings.loss_fns import hf_lm_loss, hf_lm_forward, hf_lm_loss_of_output

# Hessian
HessianOperator(model=hf_model, dataloader=batches, loss_fn=hf_lm_loss())

# GGN
GGNOperator(
    model=hf_model, dataloader=batches,
    forward_fn=hf_lm_forward(),
    loss_of_output_fn=hf_lm_loss_of_output(),
)
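
For reference, hf_lm_loss() behaves roughly like the hand-written closure below, relying on the fact that HuggingFace causal LMs shift labels and compute cross-entropy internally when `labels` is passed. This is a sketch; the helper's exact handling of padding and label masking may differ.

def hf_loss_fn(model, batch):
    # Assumes `batch` is the usual dict of input_ids / attention_mask tensors.
    # With `labels` provided, the model returns the mean token-level
    # cross-entropy as `outputs.loss`.
    labels = batch.get("labels", batch["input_ids"])
    inputs = {k: v for k, v in batch.items() if k != "labels"}
    return model(**inputs, labels=labels).loss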

For TransformerLens:

from hessian_eigenthings.loss_fns import tlens_loss, tlens_forward, tlens_loss_of_output

HessianOperator(model=tlens_model, dataloader=tokens, loss_fn=tlens_loss())
GGNOperator(model=tlens_model, dataloader=tokens,
            forward_fn=tlens_forward(),
            loss_of_output_fn=tlens_loss_of_output())
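
Similarly, a hand-written loss for a TransformerLens HookedTransformer can lean on its built-in next-token loss; this is a rough equivalent of tlens_loss(), and the helper may differ in details such as masking.

def tlens_loss_fn(model, tokens):
    # HookedTransformer computes its own next-token cross-entropy
    # when called with return_type="loss".
    return model(tokens, return_type="loss")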

Notes

  • The loss should be a scalar (tensor.dim() == 0).
  • Reduction matters for normalization: reduction='mean' produces the per-sample-mean Hessian, while reduction='sum' is batch_size times larger (see the sketch after this list). Stay consistent across batches in the same dataloader.
  • For losses that depend on multiple samples interacting (contrastive, in-batch negatives), the per-sample-loss API for empirical Fisher is undefined — use Hessian or GGN instead.
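
To see the scaling concretely, here is a tiny check using a toy linear model and torch.func.hessian (names and shapes are illustrative):

import torch
from torch.func import functional_call, hessian

toy_model = torch.nn.Linear(3, 1)
x, y = torch.randn(8, 3), torch.randn(8, 1)
params = dict(toy_model.named_parameters())

def batch_loss(p, reduction):
    pred = functional_call(toy_model, p, (x,))
    return torch.nn.functional.mse_loss(pred, y, reduction=reduction)

H_mean = hessian(lambda p: batch_loss(p, "mean"))(params)
H_sum = hessian(lambda p: batch_loss(p, "sum"))(params)
# Every block of H_sum is batch_size (= 8) times the matching block of H_mean.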