Skip to content

Parameter selection

select_parameters(model, param_filter=None)

Return the trainable parameters (those with `requires_grad=True`) matching `param_filter`, in `named_parameters` order.

Source code in hessian_eigenthings/param_utils.py
def select_parameters(
    model: nn.Module,
    param_filter: ParamFilter | None = None,
) -> dict[str, nn.Parameter]:
    """Return trainable parameters matching `param_filter`, in `named_parameters` order.

    Parameters with ``requires_grad=False`` are always skipped, before the
    filter is consulted.

    Args:
        model: Module whose ``named_parameters()`` are scanned.
        param_filter: Optional predicate called as ``param_filter(name, param)``;
            a parameter is kept when it returns True. ``None`` keeps every
            trainable parameter.

    Returns:
        Dict mapping parameter name to parameter, in ``named_parameters`` order.

    Raises:
        ValueError: If no parameter survives both the ``requires_grad`` check
            and the filter. The message says whether the model had no
            trainable parameters at all or the filter rejected all of them.
    """
    out: dict[str, nn.Parameter] = {}
    n_trainable = 0
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        n_trainable += 1
        if param_filter is not None and not param_filter(name, p):
            continue
        out[name] = p
    if not out:
        # Keep the historical message prefix so callers matching on it still
        # work, but say *why* nothing was selected: a fully frozen model and a
        # filter that matches nothing need very different fixes.
        if n_trainable == 0:
            raise ValueError(
                "param_filter selected zero parameters: "
                "model has no parameters with requires_grad=True"
            )
        raise ValueError(
            f"param_filter selected zero parameters "
            f"(out of {n_trainable} trainable parameters)"
        )
    return out

match_names(*patterns)

Glob-style name match (case-sensitive, via `fnmatch.fnmatchcase`). Matches if *any* pattern matches the parameter name.

Source code in hessian_eigenthings/param_utils.py
def match_names(*patterns: str) -> ParamFilter:
    """Build a filter that selects parameters by glob-style name match.

    The returned predicate is True when *any* of ``patterns`` matches the whole
    parameter name under ``fnmatch.fnmatchcase`` (case-sensitive). The
    parameter object itself is ignored. With no patterns, nothing matches.
    """
    globs = tuple(patterns)

    def _filter(name: str, _: nn.Parameter) -> bool:
        for glob in globs:
            if fnmatch.fnmatchcase(name, glob):
                return True
        return False

    return _filter

match_regex(*patterns)

Regex name match. Matches if *any* compiled pattern is found anywhere in the parameter name (`re.search` semantics; anchor with `^`/`$` for a full match).

Source code in hessian_eigenthings/param_utils.py
def match_regex(*patterns: str) -> ParamFilter:
    """Build a filter that selects parameters whose name matches a regex.

    Every pattern is compiled up front, so an invalid pattern raises
    ``re.error`` here rather than when the filter is first called. The
    predicate is True when *any* pattern is found anywhere in the name
    (``re.search``; use ``^``/``$`` anchors for a full match). The parameter
    object itself is ignored.
    """
    regexes = tuple(map(re.compile, patterns))

    def _filter(name: str, _: nn.Parameter) -> bool:
        for rx in regexes:
            if rx.search(name) is not None:
                return True
        return False

    return _filter

params_to_vector(params)

Concatenate per-param tensors into a single flat vector. Order follows the mapping's iteration order.

Source code in hessian_eigenthings/param_utils.py
def params_to_vector(params: Mapping[str, torch.Tensor]) -> torch.Tensor:
    """Concatenate per-param tensors into a single flat 1-D vector.

    Each tensor is flattened with ``reshape(-1)`` and the pieces are joined in
    the mapping's iteration order. ``vector_to_params`` splits a vector back
    apart in that same order.

    Args:
        params: Mapping from parameter name to tensor.

    Returns:
        A 1-D tensor whose length is the total number of elements across all
        tensors in ``params``. An empty mapping yields an empty vector (default
        dtype, CPU) instead of the error ``torch.cat`` raises on an empty list.
    """
    flat = [p.reshape(-1) for p in params.values()]
    if not flat:
        # torch.cat rejects an empty sequence; an empty vector is the natural
        # result and keeps the round trip with an empty reference consistent.
        return torch.empty(0)
    return torch.cat(flat)

vector_to_params(vec, reference)

Split a flat vector into a dict of param-shaped tensors matching `reference`.

Source code in hessian_eigenthings/param_utils.py
def vector_to_params(
    vec: torch.Tensor, reference: Mapping[str, torch.Tensor]
) -> dict[str, torch.Tensor]:
    """Split a flat vector into a dict of param-shaped tensors matching `reference`.

    Consecutive slices of ``vec`` are assigned to the entries of ``reference``
    in its iteration order. Each slice is reshaped to that entry's shape.

    Args:
        vec: 1-D tensor whose length must equal ``total_size(reference)``.
        reference: Mapping whose tensors supply the names, order, and shapes.

    Returns:
        Dict with the same keys, in the same order, as ``reference``. Each
        value is ``vec[start:stop].reshape_as(ref)``, so it may share storage
        with ``vec``.

    Raises:
        ValueError: If ``vec`` is not 1-D, or if its element count differs from
            ``total_size(reference)``.
    """
    if vec.dim() != 1:
        raise ValueError(f"expected 1-D vector, got shape {tuple(vec.shape)}")
    expected = total_size(reference)
    if vec.numel() != expected:
        raise ValueError(f"vector has {vec.numel()} elements, reference expects {expected}")

    pieces: dict[str, torch.Tensor] = {}
    start = 0
    for key, template in reference.items():
        stop = start + template.numel()
        pieces[key] = vec[start:stop].reshape_as(template)
        start = stop
    return pieces

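ParamFilter = Callable[[str, nn.Parameter], bool] (module attribute): a predicate called with a parameter's name and the parameter itself; returning True selects that parameter.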