Loss functions¶
Standard supervised¶
supervised_loss(criterion)¶
Make a loss_fn(model, batch) for HessianOperator from a standard (output, target) criterion; the batch is an (input, target) pair.
Source code in hessian_eigenthings/loss_fns/standard.py
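The adapter amounts to a small closure; a minimal sketch, assuming the batch is an (input, target) pair and the criterion takes (output, target) — the library's actual source may differ:

```python
def supervised_loss(criterion):
    """Sketch: wrap an (output, target) criterion into a loss_fn(model, batch)."""
    def loss_fn(model, batch):
        inputs, targets = batch          # batch is an (input, target) pair
        return criterion(model(inputs), targets)
    return loss_fn

# Toy usage: a linear "model" and a plain mean-squared-error criterion.
model = lambda xs: [2 * x for x in xs]
mse = lambda out, tgt: sum((o - t) ** 2 for o, t in zip(out, tgt)) / len(out)
loss_fn = supervised_loss(mse)
loss = loss_fn(model, ([1.0, 2.0], [2.0, 4.0]))   # perfect fit, so loss == 0.0
```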
supervised_forward(model, batch)¶
The forward_fn for GGNOperator on an (input, target) batch: returns model(input).
Source code in hessian_eigenthings/loss_fns/standard.py
supervised_loss_of_output(criterion)¶
Make a loss_of_output_fn for GGNOperator from an (output, target) criterion.
Source code in hessian_eigenthings/loss_fns/standard.py
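The GGN pair splits the same computation into a forward stage and a loss-of-output stage; composing them recovers the plain supervised loss. A hedged sketch, with signatures inferred from the descriptions above:

```python
def supervised_forward(model, batch):
    inputs, _ = batch
    return model(inputs)                 # model output only, no loss yet

def supervised_loss_of_output(criterion):
    def loss_of_output_fn(output, batch):
        _, targets = batch
        return criterion(output, targets)
    return loss_of_output_fn

# Composition check with a toy model and a mean-absolute-error criterion.
model = lambda xs: [x + 1 for x in xs]
mae = lambda out, tgt: sum(abs(o - t) for o, t in zip(out, tgt)) / len(out)
batch = ([1.0, 2.0], [3.0, 3.0])
out = supervised_forward(model, batch)
loss = supervised_loss_of_output(mae)(out, batch)  # same value as mae(model(inputs), targets)
```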
supervised_per_sample_loss(criterion)¶
Make a per_sample_loss_fn for EmpiricalFisherOperator. The criterion is called on
a single un-batched sample after vmap strips the batch dimension.
Source code in hessian_eigenthings/loss_fns/standard.py
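The per-sample variant differs only in that the criterion never sees a batch axis; vmap supplies the mapping over the batch. A sketch emulating that with a plain loop for illustration (the toy model and criterion here are hypothetical):

```python
def supervised_per_sample_loss(criterion):
    def per_sample_loss_fn(model, sample):
        x, y = sample                    # a single sample: no leading batch dim
        return criterion(model(x), y)
    return per_sample_loss_fn

# A loop stands in for vmap here; each call sees exactly one (x, y) pair.
model = lambda x: 3 * x
se = lambda out, tgt: (out - tgt) ** 2   # squared error on one un-batched sample
fn = supervised_per_sample_loss(se)
per_sample = [fn(model, (x, y)) for x, y in zip([1.0, 2.0], [3.0, 7.0])]
```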
HuggingFace Transformers¶
hf_lm_loss()¶
For autoregressive LMs: loss_fn(model, batch) calls model(**batch).loss.
The batch must include labels so HF computes the loss internally; for causal LMs
that's typically labels=input_ids (with the standard internal shift).
Source code in hessian_eigenthings/loss_fns/huggingface.py
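The "standard internal shift" means position t's logits are scored against token t+1, which is why labels=input_ids is correct even though the two tensors are identical. A numpy sketch of that shifted cross-entropy (ignore_index handling omitted; this mirrors the HF convention, not the library's source):

```python
import numpy as np

def shifted_ce(logits, input_ids):
    # HF causal LMs drop the last logit row and the first label before CE.
    shift_logits = logits[:-1]                     # (T-1, V)
    shift_labels = input_ids[1:]                   # (T-1,)
    # log-softmax, stabilized by subtracting the row max
    z = shift_logits - shift_logits.max(axis=-1, keepdims=True)
    logp = z - np.log(np.exp(z).sum(axis=-1, keepdims=True))
    return -logp[np.arange(len(shift_labels)), shift_labels].mean()

# Uniform logits over V classes give loss log(V) regardless of the tokens.
loss = shifted_ce(np.zeros((4, 5)), np.array([1, 2, 3, 4]))
```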
hf_seq2seq_loss()¶
For seq2seq models (e.g. T5/BART) that compute the decoder cross-entropy internally.
hf_lm_forward()¶
forward_fn for GGNOperator on an HF causal LM: returns logits (no loss).
Source code in hessian_eigenthings/loss_fns/huggingface.py
hf_lm_loss_of_output()¶
loss_of_output_fn for GGNOperator on an HF causal LM: standard shifted CE on logits and labels.
The returned callable carries a .hvp(output, batch, u) method holding the
closed-form loss-Hessian-vector product for mean-reduced cross-entropy with
softmax: H @ u = (p * u - p * (p · u)) / n per non-ignored position,
where p = softmax(logits) and n is the count of non-ignored positions.
GGNOperator picks this up automatically and skips the autograd
double-backward.
Source code in hessian_eigenthings/loss_fns/huggingface.py
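That closed form is the per-position softmax Hessian diag(p) − p pᵀ applied to u, scaled by 1/n for the mean reduction. A numpy sketch checking it against a finite-difference Hessian-vector product (no ignored positions assumed; standalone, not the library's code):

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ce_grad(logits, labels):
    # gradient of mean cross-entropy wrt logits: (softmax - one_hot) / n
    n = logits.shape[0]
    g = softmax(logits)
    g[np.arange(n), labels] -= 1.0
    return g / n

def ce_hvp(logits, u):
    # closed form from above: H @ u = (p * u - p * (p . u)) / n per position
    n = logits.shape[0]
    p = softmax(logits)
    return (p * u - p * (p * u).sum(axis=-1, keepdims=True)) / n

rng = np.random.default_rng(0)
logits = rng.normal(size=(6, 5))
labels = rng.integers(0, 5, size=6)
u = rng.normal(size=(6, 5))

# Central finite difference of the gradient along u approximates H @ u.
eps = 1e-5
fd = (ce_grad(logits + eps * u, labels) - ce_grad(logits - eps * u, labels)) / (2 * eps)
assert np.allclose(fd, ce_hvp(logits, u), atol=1e-6)
```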
TransformerLens¶
tlens_loss()¶
For TLens HookedTransformer: loss_fn(model, tokens) = model(tokens, return_type='loss').
Source code in hessian_eigenthings/loss_fns/transformer_lens.py
tlens_forward()¶
forward_fn for GGNOperator: returns the model's logits.
Source code in hessian_eigenthings/loss_fns/transformer_lens.py
tlens_loss_of_output()¶
loss_of_output_fn for GGNOperator: shifted CE on the TLens logits/tokens.