| """Self-contained subset of :mod:`circuit_sparsity.hook_utils` for inference builds. |
| |
| The full module has no exotic dependencies, but mirroring the definitions here |
| keeps the trimmed :mod:`circuit_sparsity.inference.gpt` module hermetic and easy to vendor. The |
| implementations below are copied with minor tweaks for readability so that code |
| written against :func:`hook_recorder`, :func:`hook_namespace`, and |
| :func:`torch_recompute_preserving_hook_context` behaves identically in both the |
| training and inference configurations. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import re |
| from contextlib import contextmanager |
| from functools import partial |
|
|
| import torch |
| import torch.utils.checkpoint |
|
|
|
|
class HookContext:
    """State container used by the hook helpers.

    Holds the per-recording state:
    - the output dict, the compiled name filter, and the active interventions;
    - whether gradients should be captured;
    - the dotted namespace prefix that hook names are built from.

    It also holds an intervention transformer. This is deliberately left
    untouched by :meth:`_reset`, so it outlives a single recording.
    """

    def __init__(self) -> None:
        self._reset()
        # Identity by default; composed by ``hook_intervention_transform``.
        self.curintervtransformer = lambda x: x

    def _reset(self) -> None:
        """Clear all recording state (the intervention transformer is kept)."""
        self.curcontext = None
        self.curname = ""
        self.curregex = None
        self.curinterventions = None
        self.save_grads = None

    def _get_interventions(self):
        """Return the active interventions after applying the transformer.

        ``None`` interventions are passed to the transformer as an empty dict.
        """
        return self.curintervtransformer(
            self.curinterventions if self.curinterventions is not None else {}
        )

    @contextmanager
    def hook_recorder(self, regex: str = ".*", interventions=None, save_grads: bool = False):
        """Record tensors that pass through hooks matching ``regex``.

        Yields the dict that :meth:`hook_save` fills. Keys are full dotted hook
        names; gradients are stored under ``"<name>.grad"`` when
        ``save_grads`` is set.

        Args:
            regex: pattern matched with ``re.Pattern.match`` (anchored at the
                start) against each full hook name.
            interventions: optional mapping from full hook name to a callable
                whose result replaces the tensor passing through that hook.
            save_grads: also record gradients of matching tensors in backward.

        Raises:
            AssertionError: if a recording is already active on this context.
        """
        assert self.curcontext is None, "reentrancy not allowed!"

        try:
            self.curcontext = {}
            self.curregex = re.compile(regex)
            self.curname = ""
            self.curinterventions = interventions
            self.save_grads = save_grads

            yield self.curcontext
        finally:
            self._reset()
            # NOTE(review): this also resets whatever context is globally
            # active. That is redundant when ``self`` is the global context,
            # and surprising when it is not. Kept unchanged — confirm intent.
            get_context()._reset()

    @contextmanager
    def hook_intervention_transform(self, intervention_transformer):
        """Temporarily compose ``intervention_transformer`` onto the current one.

        Inside the block, interventions are computed as
        ``intervention_transformer(previous_transformer(interventions))``.
        The previous transformer is restored on exit.
        """
        oldintervention_transformer = self.curintervtransformer

        def compose(f, g):
            return lambda x: f(g(x))

        self.curintervtransformer = compose(
            intervention_transformer,
            self.curintervtransformer,
        )

        try:
            yield
        finally:
            self.curintervtransformer = oldintervention_transformer

    @contextmanager
    def hook_namespace(self, name: str):
        """Temporarily push ``name`` onto the hook namespace stack.

        Inside the block, hook names are prefixed with ``name + "."``. The
        previous prefix is restored on exit.
        """
        oldname = self.curname
        self.curname = self.curname + name + "."

        try:
            yield
        finally:
            self.curname = oldname

    def hook_save(self, name: str, tensor: torch.Tensor) -> torch.Tensor:
        """Apply any intervention for ``name`` and optionally record ``tensor``.

        The full hook name is the current namespace prefix plus ``name``.

        Returns:
            The possibly intervened, possibly gradient-wrapped tensor.
            Callers should continue with this tensor instead of the input.
        """
        # Resolve the full name once, at forward time. The gradient hook below
        # runs during backward, after enclosing namespaces have been popped.
        key = self.curname + name

        curinterventions = self._get_interventions()
        if curinterventions is not None and key in curinterventions:
            tensor = curinterventions[key](tensor)

        store = self.curcontext
        if store is None:
            return tensor

        if self.curregex.match(key):
            store[key] = tensor

        grad_key = key + ".grad"
        if self.save_grads and tensor.requires_grad and self.curregex.match(grad_key):

            class _Grad(torch.autograd.Function):
                @staticmethod
                def forward(ctx, input_tensor):
                    return input_tensor

                @staticmethod
                def backward(ctx, grad_output):
                    # Write into the dict captured at forward time. The gradient
                    # then lands in the recording that saw the activation, even
                    # if backward runs after ``hook_recorder`` has exited.
                    store[grad_key] = grad_output
                    return grad_output

            tensor = _Grad.apply(tensor)

        return tensor
|
|
|
|
def set_context(new_context: HookContext) -> None:
    """Install ``new_context`` as the module-wide active :class:`HookContext`.

    Later calls to ``get_context()`` return this object.
    """
    globals()["context"] = new_context
|
|
|
|
def get_context() -> HookContext:
    """Return the module-wide active :class:`HookContext`.

    Reading a module global needs no ``global`` declaration.
    """
    return context
|
|
|
|
def torch_recompute_preserving_hook_context(f, *xs, use_reentrant=None):
    """Wrapper around :func:`torch.utils.checkpoint` that propagates hooks.

    ``f(*xs)`` runs under :func:`torch.utils.checkpoint.checkpoint` with a
    private :class:`HookContext` that snapshots the currently active one:
    - a copy of the recording dict and a copy of the interventions;
    - the regex, namespace and ``save_grads`` flag;
    - the intervention transformer.

    Hooks inside ``f`` therefore behave as they would without checkpointing.
    Tensors recorded on the first call of ``f`` are merged into the caller's
    recording dict. Later calls (the recompute) reuse the same snapshot, but
    their recordings are not merged.

    Args:
        f: callable to checkpoint.
        *xs: positional arguments forwarded to ``f`` through ``checkpoint``.
        use_reentrant: forwarded unchanged to
            ``torch.utils.checkpoint.checkpoint``.

    Returns:
        Whatever ``checkpoint`` returns for ``f(*xs)``.
    """
    oldcontext = get_context()

    curcontext = HookContext()
    curcontext.curcontext = (
        dict(oldcontext.curcontext) if oldcontext.curcontext is not None else None
    )
    curcontext.curregex = oldcontext.curregex
    curcontext.curname = oldcontext.curname
    curcontext.curinterventions = (
        dict(oldcontext.curinterventions)
        if oldcontext.curinterventions is not None
        else None
    )
    curcontext.save_grads = oldcontext.save_grads
    # Carry over any active ``hook_intervention_transform``. A fresh
    # HookContext starts with the identity transformer, which would make
    # interventions inside ``f`` differ from the un-checkpointed path.
    curcontext.curintervtransformer = oldcontext.curintervtransformer

    is_recompute = False

    def _f(curcontext: HookContext, *xs):
        nonlocal is_recompute
        initcontext = get_context()

        set_context(curcontext)
        try:
            res = f(*xs)
            # Only the first call merges its recordings. A recompute must not
            # overwrite the tensors the caller already holds.
            if not is_recompute and oldcontext.curcontext is not None:
                oldcontext.curcontext |= curcontext.curcontext
        finally:
            set_context(initcontext)
            is_recompute = True
        return res

    return torch.utils.checkpoint.checkpoint(
        partial(_f, curcontext), *xs, use_reentrant=use_reentrant
    )
|
|
|
|
# Module-level default HookContext. The module-level helpers below dispatch to
# whatever ``get_context()`` returns. That is this object until
# ``set_context`` rebinds the ``context`` global.
context: HookContext = HookContext()
|
|
|
|
def hook_recorder(*a, **k):
    """Open :meth:`HookContext.hook_recorder` on the currently active context."""
    active = get_context()
    return active.hook_recorder(*a, **k)
|
|
|
|
def hook_namespace(*a, **k):
    """Open :meth:`HookContext.hook_namespace` on the currently active context."""
    active = get_context()
    return active.hook_namespace(*a, **k)
|
|
|
|
def hook_save(*a, **k):
    """Call :meth:`HookContext.hook_save` on the currently active context."""
    active = get_context()
    return active.hook_save(*a, **k)
|
|
|
|
def hook_intervention_transform(*a, **k):
    """Open :meth:`HookContext.hook_intervention_transform` on the active context."""
    active = get_context()
    return active.hook_intervention_transform(*a, **k)