Skip to content

Activation cache

ActivationCache

Class to store and aggregate activation values from a model. It is a dictionary-like object with additional functionality to aggregate values.

Source code in easyroutine/interpretability/activation_cache.py
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
class ActivationCache():
    r"""
    Class to store and aggregate activation values from a model.
    It is a dictionary-like object with additional functionality to aggregate values.
    """

    def __init__(self):
        # Backing dictionary: key (e.g. "resid_out_0") -> activation value.
        self.cache = {}
        self.logger = Logger(
            logname="ActivationCache",
            level="INFO",
        )

        # Recognized key shapes; unknown keys are still accepted, but trigger
        # a warning in __setitem__. Layer/head indices use \d+ so models with
        # more than 9 layers match too (the original patterns used a single
        # \d for the layer index of the attention-pattern keys).
        self.valid_keys = (
            re.compile(r"resid_out_\d+"),
            re.compile(r"resid_in_\d+"),
            re.compile(r"resid_mid_\d+"),
            re.compile(r"attn_in_\d+"),
            re.compile(r"attn_out_\d+"),
            re.compile(r"avg_attn_pattern_L\d+H\d+"),
            re.compile(r"pattern_L\d+H\d+"),
            re.compile(r"values_\d+"),
            re.compile(r"input_ids"),
            re.compile(r"mapping_index")
        )

        # key-prefix -> merge function, consulted by `cat` when combining
        # two caches. Keys with no matching prefix use `default_aggregation`.
        self.aggregation_strategies = {}
        self.register_aggregation("mapping_index", lambda values: values[0])  # First value
        self.register_aggregation("pattern_", lambda values: values)  # Keep as list
        self.register_aggregation("input_ids", lambda values: values)  # Keep as list
        self.register_aggregation("offset", lambda values: [item for sublist in values for item in sublist])  # Flatten lists

        # While this is a list, `cat` collects incoming caches here instead of
        # aggregating (see `deferred_mode`); `None` means aggregate at once.
        # Fixes the original `self.defferred_cache` typo, which left
        # `deferred_cache` undefined until `deferred_mode` first ran.
        self.deferred_cache = None

    def __repr__(self) -> str:
        """
        Returns:
            str: A string representation of the ActivationCache object.

        Examples:
            >>> cache
            ActivationCache(resid_out_0, resid_in_0, resid_mid_0, attn_in_0, attn_out_0, avg_attn_pattern_L1H1, pattern_L1H1, values_L1H1)
        """
        return f"ActivationCache(`{', '.join(self.cache.keys())}`)"

    def __str__(self) -> str:
        """
        Returns:
            str: A string representation of the ActivationCache object.

        Examples:
            >>> print(cache)
            ActivationCache(resid_out_0: torch.Tensor([1, 2, 3, 4]), resid_in_0: torch.Tensor([1, 2, 3, 4]))
        """
        return f"ActivationCache({', '.join([f'{key}: {value}' for key, value in self.cache.items()])})"

    def __setitem__(self, key:str, value):
        """
        Set a key-value pair in the cache.

        Arguments:
            key (str): The key to store the value.
            value (Any): The value to store.

        Examples:
            >>> cache["resid_out_0"] = torch.randn(1, 3, 16)
        """
        # Unknown keys are allowed (user-defined entries), only warned about.
        if not any(pattern.match(key) for pattern in self.valid_keys):
            self.logger.warning(f"Invalid key: {key}. Valid keys are: {self.valid_keys}. Could be a user-defined key.")
        self.cache[key] = value

    def __getitem__(self, key:str):
        """
        Get a value from the cache.

        Arguments:
            key (str): The key to retrieve the value.

        Raises:
            KeyError: If the key is not present.

        Examples:
            >>> cache["resid_out_0"]
            torch.Tensor([1, 2, 3, 4])
        """
        return self.cache[key]

    def __delitem__(self, key:str):
        """
        Remove a key-value pair from the cache.

        Arguments:
            key (str): The key to remove from the cache.
        """
        del self.cache[key]

    def __add__(self, other) -> "ActivationCache":
        """
        Overload the `+` operator to merge caches efficiently.

        On duplicate keys, the value from `other` wins (dict-merge semantics).

        Arguments:
            other (dict or ActivationCache): Another cache or dictionary to merge with.
        Returns:
            ActivationCache: A new ActivationCache object with merged data.
        """
        if not isinstance(other, (dict, ActivationCache)):
            raise TypeError("Can only add ActivationCache or dict objects.")

        new_cache = ActivationCache()
        new_cache.cache = {**self.cache, **(other.cache if isinstance(other, ActivationCache) else other)}
        return new_cache

    def __contains__(self, key):
        """
        Check if a key is present in the cache.

        Arguments:
            key (str): The key to check.
        Returns:
            bool: True if the key is present, False otherwise
        """
        return key in self.cache

    def items(self):
        """
        Just like the dictionary items method, returns a view of key-value pairs.
        """
        return self.cache.items()

    def keys(self):
        """
        Just like the dictionary keys method, returns a view of keys.
        """
        return self.cache.keys()

    def values(self):
        """
        Just like the dictionary values method, returns a view of values.
        """
        return self.cache.values()

    def update(self, other):
        """
        Updates the cache with values from an additional dictionary or ActivationCache object.

        Arguments:
            other (Union[dict, ActivationCache]): Dictionary or ActivationCache object to update with.

        Raises:
            TypeError: If `other` is neither a dict nor an ActivationCache.
        """
        if isinstance(other, dict):
            self.cache.update(other)
        elif isinstance(other, ActivationCache):
            self.cache.update(other.cache)
        else:
            raise TypeError("Can only update with dict or ActivationCache objects.")

    def to(self, device: Union[str, torch.device]):
        """
        Moves the tensors in the cache to a specified device, in place.

        Args:
            device (Union[str, torch.device]): The device to move the tensors to.
        """
        # Anything exposing a `.to` method (tensors, modules) is moved;
        # other values (lists, ints, ...) are left untouched.
        for key, value in self.cache.items():
            if hasattr(value, "to"):
                self.cache[key] = value.to(device)

    def cpu(self):
        """
        Moves the tensors in the cache to the CPU.
        """
        self.to("cpu")

    def cuda(self):
        """
        Moves the tensors in the cache to the GPU.
        """
        self.to("cuda")

    def register_aggregation(self, key_pattern, function):
        """
        Register a custom aggregation strategy for keys matching a pattern. In this way, you can define how to aggregate values for specific keys when merging caches.

        Arguments:
            key_pattern (str): The key or prefix to match.
            function (callable): The function to apply for aggregation.

        Examples:
            >>> cache.register_aggregation("values_", lambda values: torch.stack(values, dim=0))
        """
        self.aggregation_strategies[key_pattern] = function

    def default_aggregation(self, values):
        """
        Default aggregation strategy for keys without a custom strategy.
        Handles tensors, lists, and scalars.

        Arguments:
            values (List): List of values to aggregate.

        Returns:
            Union[torch.Tensor, List, Any]: The aggregated value.
        """
        if isinstance(values[0], torch.Tensor):
            try:
                return torch.cat(values, dim=0)
            except RuntimeError:
                # cat fails e.g. for 0-dim tensors or mismatched trailing
                # shapes; stacking along a new leading dim is the fallback.
                return torch.stack(values, dim=0)
        elif isinstance(values[0], list):
            return [item for sublist in values for item in sublist]
        else:
            return values[0]  # Fallback to the first value

    @contextlib.contextmanager
    def deferred_mode(self):
        """
        Context manager to enable deferred aggregation.
        Collects all external caches in a list and aggregates them at the end of the context.
        This is most similar to the old way of using the `cat` method. It could (or could not) be more efficient.
        The main difference to direct calls to `cat` is that the cache is not updated until the end of the context.
        It will require more memory, but it could be more efficient.

        Examples:
            >>> with cache.deferred_mode():
            >>>     cache.cat(external_cache1)
            >>>     cache.cat(external_cache2)
        """
        self.deferred_cache = []
        try:
            yield self
            # Disable collection *before* draining, otherwise the `cat`
            # calls below would just re-append to the deferred list.
            pending = self.deferred_cache
            self.deferred_cache = None
            for external_cache in pending:
                self.cat(external_cache)
        finally:
            # Clear the deferred cache (also on exceptions, in which case
            # pending caches are discarded, matching the original contract).
            self.deferred_cache = None

    def cat(self, external_cache):
        """
        Merge the current cache with an external cache using aggregation strategies.

        Inside a `deferred_mode` context the external cache is only collected;
        the actual aggregation happens when the context exits.

        Arguments:
            external_cache (ActivationCache): The external cache to merge with.

        Raises:
            TypeError: If `external_cache` is not an ActivationCache.
            ValueError: If the two (non-empty) caches hold different key sets.

        Examples:
            >>> a, b = ActivationCache(), ActivationCache()
            >>> a["values_0"] = torch.tensor([1, 2])
            >>> b["values_0"] = torch.tensor([1, 4])
            >>> a.cat(b)
            >>> print(a["values_0"].shape)
            torch.Size([4])
            >>> print(a["values_0"])
            tensor([1, 2, 1, 4])
        """
        if not isinstance(external_cache, ActivationCache):
            raise TypeError("external_cache must be an instance of ActivationCache")

        # Deferred mode: just collect; `deferred_mode` drains this list on
        # exit. (The original never populated the deferred list, making
        # `deferred_mode` a no-op.)
        if isinstance(getattr(self, "deferred_cache", None), list):
            self.deferred_cache.append(external_cache)
            return

        # Case 1: Initialize self if it's empty
        if not self.cache and external_cache.cache:
            self.update(external_cache.cache)
            return

        # Case 2: Ensure both caches have the same keys
        self_keys = set(self.cache.keys())
        external_keys = set(external_cache.cache.keys())

        if self_keys != external_keys:
            raise ValueError(
                f"Key mismatch: self has {self_keys - external_keys}, "
                f"external has {external_keys - self_keys}"
            )

        # Case 3: Aggregate matching keys using registered strategies or default
        for key in self.cache:
            # Check for a custom aggregation strategy (first prefix match wins)
            for pattern, strategy in self.aggregation_strategies.items():
                if key.startswith(pattern):
                    self.cache[key] = strategy([self.cache[key], external_cache[key]])
                    break
            else:
                # Use the default aggregation if no custom strategy matches
                self.cache[key] = self.default_aggregation(
                    [self.cache[key], external_cache[key]]
                )

__add__(other)

Overload the + operator to merge caches efficiently. Arguments: other (dict or ActivationCache): Another cache or dictionary to merge with. Returns: ActivationCache: A new ActivationCache object with merged data.

Source code in easyroutine/interpretability/activation_cache.py
101
102
103
104
105
106
107
108
109
110
111
112
113
114
def __add__(self, other) -> "ActivationCache":
    """
    Merge this cache with another via the `+` operator.

    On duplicate keys the entry from `other` wins (dict-merge semantics).

    Arguments:
        other (dict or ActivationCache): Cache or dictionary to merge with.
    Returns:
        ActivationCache: A new ActivationCache object with merged data.
    """
    if isinstance(other, ActivationCache):
        extra = other.cache
    elif isinstance(other, dict):
        extra = other
    else:
        raise TypeError("Can only add ActivationCache or dict objects.")

    merged = ActivationCache()
    merged.cache = {**self.cache, **extra}
    return merged

__contains__(key)

Check if a key is present in the cache. Arguments: key (str): The key to check. Returns: bool: True if the key is present, False otherwise

Source code in easyroutine/interpretability/activation_cache.py
116
117
118
119
120
121
122
123
124
def __contains__(self, key):
    """
    Support `key in cache` membership tests.

    Arguments:
        key (str): The key to check.
    Returns:
        bool: True if the key is present, False otherwise
    """
    present = key in self.cache
    return present

__delitem__(key)

Remove a key-value pair from the cache.

Parameters:

Name Type Description Default
key str

The key to remove from the cache.

required
Source code in easyroutine/interpretability/activation_cache.py
92
93
94
95
96
97
98
99
def __delitem__(self, key:str):
    """
    Delete the entry stored under `key`.

    Arguments:
        key (str): The key to remove from the cache.
    """
    # pop without a default raises KeyError for missing keys, same as `del`.
    self.cache.pop(key)

__getitem__(key)

Get a value from the cache.

Parameters:

Name Type Description Default
key str

The key to retrieve the value.

required

Examples:

>>> cache["resid_out_0"]
torch.Tensor([1, 2, 3, 4])
Source code in easyroutine/interpretability/activation_cache.py
78
79
80
81
82
83
84
85
86
87
88
89
def __getitem__(self, key:str):
    """
    Look up a stored activation by key.

    Arguments:
        key (str): The key to retrieve the value.

    Raises:
        KeyError: If `key` is not present.

    Examples:
        >>> cache["resid_out_0"]
        torch.Tensor([1, 2, 3, 4])
    """
    return self.cache.__getitem__(key)

__repr__()

Returns:

Name Type Description
str str

A string representation of the ActivationCache object.

Examples:

>>> cache
ActivationCache(resid_out_0, resid_in_0, resid_mid_0, attn_in_0, attn_out_0, avg_attn_pattern_L1H1, pattern_L1H1, values_L1H1)
Source code in easyroutine/interpretability/activation_cache.py
41
42
43
44
45
46
47
48
49
50
def __repr__(self) -> str:
    """
    Returns:
        str: A compact representation listing the cached keys.

    Examples:
        >>> cache
        ActivationCache(resid_out_0, resid_in_0, resid_mid_0, attn_in_0, attn_out_0, avg_attn_pattern_L1H1, pattern_L1H1, values_L1H1)
    """
    key_list = ", ".join(self.cache.keys())
    return "ActivationCache(`" + key_list + "`)"

__setitem__(key, value)

Set a key-value pair in the cache.

Parameters:

Name Type Description Default
key str

The key to store the value.

required
value Any

The value to store.

required

Examples:

>>> cache["resid_out_0"] = torch.randn(1, 3, 16)
Source code in easyroutine/interpretability/activation_cache.py
63
64
65
66
67
68
69
70
71
72
73
74
75
76
def __setitem__(self, key:str, value):
    """
    Store `value` under `key`, warning when the key matches no known pattern.

    Arguments:
        key (str): The key to store the value.
        value (Any): The value to store.

    Examples:
        >>> cache["resid_out_0"] = torch.randn(1, 3, 16)
    """
    # Unknown keys are still stored; the warning is informational only.
    recognized = any(pattern.match(key) for pattern in self.valid_keys)
    if not recognized:
        self.logger.warning(f"Invalid key: {key}. Valid keys are: {self.valid_keys}. Could be a user-defined key.")
    self.cache[key] = value

__str__()

Returns:

Name Type Description
str str

A string representation of the ActivationCache object.

Examples:

>>> print(cache)
ActivationCache(resid_out_0: torch.Tensor([1, 2, 3, 4]), resid_in_0: torch.Tensor([1, 2, 3, 4]))
Source code in easyroutine/interpretability/activation_cache.py
52
53
54
55
56
57
58
59
60
61
def __str__(self) -> str:
    """
    Returns:
        str: A string representation showing each key with its stored value.

    Examples:
        >>> print(cache)
        ActivationCache(resid_out_0: torch.Tensor([1, 2, 3, 4]), resid_in_0: torch.Tensor([1, 2, 3, 4]))
    """
    entries = [f"{key}: {value}" for key, value in self.cache.items()]
    return f"ActivationCache({', '.join(entries)})"

cat(external_cache)

Merge the current cache with an external cache using aggregation strategies.

Parameters:

Name Type Description Default
external_cache ActivationCache

The external cache to merge with.

required

Examples:

>>> a, b = ActivationCache(), ActivationCache()
>>> a["values_0"] = torch.tensor([1, 2])
>>> b["values_0"] = torch.tensor([1, 4])
>>> a.cat(b)
>>> print(a["values_0"].shape)
torch.Size([4])
>>> print(a["values_0"])
tensor([1, 2, 1, 4])
Source code in easyroutine/interpretability/activation_cache.py
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
def cat(self, external_cache):
    """
    Merge the current cache with an external cache using aggregation strategies.

    Keys must match exactly between the two caches, unless this cache is
    empty, in which case the external entries are simply copied in. Each key
    is merged by the first registered strategy whose prefix matches, falling
    back to `default_aggregation` (tensor cat / list flatten / first value).

    Arguments:
        external_cache (ActivationCache): The external cache to merge with.

    Raises:
        TypeError: If `external_cache` is not an ActivationCache.
        ValueError: If the two (non-empty) caches hold different key sets.

    Examples:
        >>> a, b = ActivationCache(), ActivationCache()
        >>> a["values_0"] = torch.tensor([1, 2])
        >>> b["values_0"] = torch.tensor([1, 4])
        >>> a.cat(b)
        >>> print(a["values_0"].shape)
        torch.Size([4])
        >>> print(a["values_0"])
        tensor([1, 2, 1, 4])
    """
    if not isinstance(external_cache, ActivationCache):
        raise TypeError("external_cache must be an instance of ActivationCache")

    # Case 1: Initialize self if it's empty
    if not self.cache and external_cache.cache:
        self.update(external_cache.cache)
        return

    # Case 2: Ensure both caches have the same keys
    self_keys = set(self.cache.keys())
    external_keys = set(external_cache.cache.keys())

    if self_keys != external_keys:
        raise ValueError(
            f"Key mismatch: self has {self_keys - external_keys}, "
            f"external has {external_keys - self_keys}"
        )

    # Case 3: Aggregate matching keys using registered strategies or default
    # NOTE(review): `deferred_mode` implies calls made inside its context are
    # collected in `self.deferred_cache`, but this method always aggregates
    # immediately and never appends to that list — confirm intended behavior.
    for key in self.cache:
        # Check for a custom aggregation strategy (first prefix match wins)
        for pattern, strategy in self.aggregation_strategies.items():
            if key.startswith(pattern):
                self.cache[key] = strategy([self.cache[key], external_cache[key]])
                break
        else:
            # Use the default aggregation if no custom strategy matches
            self.cache[key] = self.default_aggregation(
                [self.cache[key], external_cache[key]]
            )

cpu()

Moves the tensors in the cache to the CPU.

Source code in easyroutine/interpretability/activation_cache.py
171
172
173
174
175
def cpu(self):
    """
    Convenience wrapper: move every cached tensor to the CPU.
    """
    self.to(device="cpu")

cuda()

Moves the tensors in the cache to the GPU.

Source code in easyroutine/interpretability/activation_cache.py
177
178
179
180
181
def cuda(self):
    """
    Convenience wrapper: move every cached tensor to the GPU.
    """
    self.to(device="cuda")

default_aggregation(values)

Default aggregation strategy for keys without a custom strategy. Handles tensors, lists, and scalars.

Parameters:

Name Type Description Default
values List

List of values to aggregate.

required

Returns:

Type Description

Union[torch.Tensor, List, Any]: The aggregated value.

Source code in easyroutine/interpretability/activation_cache.py
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
def default_aggregation(self, values):
    """
    Fallback merge strategy used when no registered pattern matches a key.

    Tensors are concatenated along dim 0 (stacked along a new leading dim if
    concatenation fails), lists are flattened into one list, and any other
    type keeps the first value.

    Arguments:
        values (List): List of values to aggregate.

    Returns:
        Union[torch.Tensor, List, Any]: The aggregated value.
    """
    head = values[0]
    if isinstance(head, torch.Tensor):
        try:
            return torch.cat(values, dim=0)
        except RuntimeError:
            # cat fails e.g. for 0-dim tensors; stack as a fallback.
            return torch.stack(values, dim=0)
    if isinstance(head, list):
        flattened = []
        for sublist in values:
            flattened.extend(sublist)
        return flattened
    return head  # Fallback to the first value

deferred_mode()

Context manager to enable deferred aggregation. Collects all external caches in a list and aggregates them at the end of the context. This is most similar to the old way of using the cat method. It could (or could not) be more efficient. The main difference to direct calls to cat is that the cache is not updated until the end of the context, in this way the torch.cat, torch.stack and the other strategies are called only once. It will require more memory, but it could be more efficient.

Examples:

>>> with cache.deferred_mode():
>>>     cache.cat(external_cache1)
>>>     cache.cat(external_cache2)
Source code in easyroutine/interpretability/activation_cache.py
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
@contextlib.contextmanager
def deferred_mode(self):
    """
    Context manager to enable deferred aggregation.
    Collects all external caches in a list and aggregates them at the end of the context.
    This is most similar to the old way of using the `cat` method. It could (or could not) be more efficient.
    The main difference to direct calls to `cat` is that the cache is not updated until the end of the context, in this way the torch.cat, torch.stack and the other strategies are called only once.
    It will require more memory, but it could be more efficient.

    Examples:
        >>> with cache.deferred_mode():
        >>>     cache.cat(external_cache1)
        >>>     cache.cat(external_cache2)
    """
    # NOTE(review): nothing here (or in `cat`) appends to `self.deferred_cache`,
    # so this list appears to stay empty and the loop below is a no-op —
    # confirm whether `cat` is supposed to collect into it while this
    # context is active.
    self.deferred_cache = []
    try:
        yield self
        # Perform aggregation at the end of the context
        for external_cache in self.deferred_cache:
            self.cat(external_cache)
    finally:
        # Clear the deferred cache
        # NOTE(review): the attribute initialized in __init__ is spelled
        # `defferred_cache` — likely a typo for this name; verify.
        self.deferred_cache = None

items()

Just like the dictionary items method, returns a list of key-value pairs.

Source code in easyroutine/interpretability/activation_cache.py
126
127
128
129
130
def items(self):
    """
    Return a (key, value) view of the cache, mirroring `dict.items`.
    """
    backing = self.cache
    return backing.items()

keys()

Just like the dictionary keys method, returns a list of keys.

Source code in easyroutine/interpretability/activation_cache.py
132
133
134
135
136
def keys(self):
    """
    Return a view of the cached keys, mirroring `dict.keys`.
    """
    backing = self.cache
    return backing.keys()

register_aggregation(key_pattern, function)

Register a custom aggregation strategy for keys matching a pattern. In this way, you can define how to aggregate values for specific keys when merging caches.

Parameters:

Name Type Description Default
key_pattern str

The key or prefix to match.

required
function callable

The function to apply for aggregation.

required

Examples:

>>> cache.register_aggregation("values_", lambda values: torch.stack(values, dim=0))
Source code in easyroutine/interpretability/activation_cache.py
184
185
186
187
188
189
190
191
192
193
194
195
def register_aggregation(self, key_pattern, function):
    """
    Attach a custom merge function for keys starting with `key_pattern`,
    used when caches are combined via `cat`.

    Arguments:
        key_pattern (str): The key or prefix to match.
        function (callable): The function to apply for aggregation.

    Examples:
        >>> cache.register_aggregation("values_", lambda values: torch.stack(values, dim=0))
    """
    self.aggregation_strategies.update({key_pattern: function})

to(device)

Moves the tensors in the cache to a specified device.

Parameters:

Name Type Description Default
device Union[str, device]

The device to move the tensors to.

required
Source code in easyroutine/interpretability/activation_cache.py
159
160
161
162
163
164
165
166
167
168
169
def to(self, device: Union[str, torch.device]):
    """
    Move every movable value in the cache to `device`, in place.

    Args:
        device (Union[str, torch.device]): The device to move the tensors to.
    """
    # Anything exposing a `.to` method (tensors, modules) is moved;
    # other values are left untouched.
    for name in self.cache:
        entry = self.cache[name]
        if hasattr(entry, "to"):
            self.cache[name] = entry.to(device)

update(other)

Updates the cache with values from an additional dictionary or ActivationCache object. Arguments: other (Union[dict, ActivationCache]): Dictionary or ActivationCache object to update with.

Source code in easyroutine/interpretability/activation_cache.py
144
145
146
147
148
149
150
151
152
153
154
155
def update(self, other):
    """
    Merge entries from a dict or ActivationCache into this cache in place.

    Arguments:
        other (Union[dict, ActivationCache]): Dictionary or ActivationCache object to update with.

    Raises:
        TypeError: If `other` is neither a dict nor an ActivationCache.
    """
    if isinstance(other, dict):
        self.cache.update(other)
        return
    if isinstance(other, ActivationCache):
        self.cache.update(other.cache)
        return
    raise TypeError("Can only update with dict or ActivationCache objects.")

values()

Just like the dictionary values method, returns a list of values.

Source code in easyroutine/interpretability/activation_cache.py
138
139
140
141
142
def values(self):
    """
    Return a view of the cached values, mirroring `dict.values`.
    """
    backing = self.cache
    return backing.values()