Module Wrapper
Introduction
Module Wrapper is the submodule responsible for managing the module wrappers. Module wrappers are essential for adding custom hooks at points where the original transformers codebase does not expose one. For example, the transformers attention
module does not provide a hook to get the attention matrix of a head; the module wrapper adds this hook. The module_wrapper
submodule is composed of the following files:
- manager.py: Manages the module wrappers; it is the standard interface for wrapping models.
- base.py: The base class for the module wrappers; implements a base form of a Wrapper class.
- model_name_attention.py: The module wrapper for the attention matrix of a single model. When adding a new model, add a new file named model_name_attention.py and implement the ModelNameAttention class. It is essentially a copy of the attention module's forward pass with an added hook to capture the attention matrix (see the sketch after this list).
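As an illustration of that pattern, here is a minimal, hypothetical sketch of such a wrapper. The class and attribute names (ModelNameAttentionSketch, AttentionMatrixHook, _orig_module) mirror the conventions described above but are not the actual easyroutine implementation, and the attention math is deliberately simplified (no heads, no masking, no caching).

```python
import torch
import torch.nn as nn


class AttentionMatrixHook(nn.Module):
    """Identity module: exists only so a forward hook can observe the attention matrix."""

    def forward(self, attention_matrix: torch.Tensor) -> torch.Tensor:
        return attention_matrix


class ModelNameAttentionSketch(nn.Module):
    def __init__(self, original_attention: nn.Module):
        super().__init__()
        self._orig_module = original_attention            # keep the unwrapped module around
        self.attention_matrix_hook = AttentionMatrixHook()

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Re-implementation of the original forward pass (heavily simplified here)...
        q = self._orig_module.q_proj(hidden_states)
        k = self._orig_module.k_proj(hidden_states)
        v = self._orig_module.v_proj(hidden_states)
        attn_weights = torch.softmax(q @ k.transpose(-1, -2) / q.shape[-1] ** 0.5, dim=-1)
        # ...plus the one extra line the wrapper exists for: route the attention
        # matrix through the hook module so external code can intercept it.
        attn_weights = self.attention_matrix_hook(attn_weights)
        return self._orig_module.o_proj(attn_weights @ v)
```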
Manager Wrappers and Abstract Base Class
AttentionWrapperFactory
Maps a given model name to the correct attention wrapper class.
Source code in easyroutine/interpretability/module_wrappers/manager.py, lines 23-82.
get_wrapper_classes(model)
staticmethod
Returns a dictionary mapping module names to their corresponding wrapper classes for all supported modules found in the model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The model to analyze | required |
Returns:

| Type | Description |
|---|---|
| Dict[str, Union[Type[ChameleonAttentionWrapper], Type[LlamaAttentionWrapper], Type[T5AttentionWrapper], Type[MistralAttentionWrapper], Type[Gemma3AttentionWrapper]]] | Dict[str, Type]: Dictionary mapping original module names to wrapper classes |
Source code in easyroutine/interpretability/module_wrappers/manager.py, lines 45-82.
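A hedged usage sketch: the checkpoint is only an example, and the keys of the returned dictionary are assumed to be the original attention class names (see original_name() below).

```python
from transformers import AutoModelForCausalLM
from easyroutine.interpretability.module_wrappers.manager import AttentionWrapperFactory

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")  # example checkpoint

# Which of the model's modules do we know how to wrap?
wrapper_classes = AttentionWrapperFactory.get_wrapper_classes(model)
for original_name, wrapper_cls in wrapper_classes.items():
    print(f"{original_name} -> {wrapper_cls.__name__}")  # e.g. LlamaAttention -> LlamaAttentionWrapper
```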
ModuleWrapperManager
Handles the logic of replacing original modules within a given model with custom wrappers. Supports multiple module types per model. Also allows restoring the original modules if needed.
Source code in easyroutine/interpretability/module_wrappers/manager.py, lines 85-204.
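A usage sketch combining the methods documented below. The checkpoint and the model.model.layers[0].self_attn path are Llama-specific examples, attention_matrix_hook is the hook point exposed by the attention wrappers (see LlamaAttentionWrapper below), and the captured tensor shape in the comment is an assumption.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from easyroutine.interpretability.module_wrappers.manager import ModuleWrapperManager

model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")

manager = ModuleWrapperManager(model, log_level="INFO")
manager.substitute_attention_module(model)        # swap supported modules for their wrappers

captured = []
first_attn = model.model.layers[0].self_attn      # now an attention wrapper
handle = first_attn.attention_matrix_hook.register_forward_hook(
    lambda module, inputs, output: captured.append(output.detach())
)

inputs = tokenizer("Hello world", return_tensors="pt")
with torch.no_grad():
    model(**inputs)
print(captured[0].shape)                          # assumed (batch, heads, seq_len, seq_len)

handle.remove()
manager.restore_original_attention_module(model)  # put the original modules back
```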
__contains__(module_name)
Check if a module name is supported by this manager.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| module_name | str | The name of the module to check | required |
Returns:

| Name | Type | Description |
|---|---|---|
| bool | bool | True if the module is supported, False otherwise |
Source code in easyroutine/interpretability/module_wrappers/manager.py, lines 109-119.
__init__(model, log_level='INFO')
Initializes the manager with a given model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The model whose modules will be wrapped | required |
| log_level | str | Logging level | 'INFO' |
Source code in easyroutine/interpretability/module_wrappers/manager.py, lines 92-107.
restore_original_attention_module(model)
Public method that restores the original modules in the model. Logs each restoration.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The model whose modules will be restored | required |
Source code in easyroutine/interpretability/module_wrappers/manager.py, lines 131-139.
substitute_attention_module(model)
Public method that performs the substitution of all supported modules in the model. Logs each replacement.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | Module | The model whose modules will be wrapped | required |
Source code in easyroutine/interpretability/module_wrappers/manager.py, lines 121-129.
AttentionMatrixHookModule
Bases: Module
Computes the attention matrix. Note: it exists solely to provide a point where custom hooks can be attached.
Source code in easyroutine/interpretability/module_wrappers/base.py, lines 4-11.
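The class body is not reproduced here, but given the description above it is presumably a pass-through module. The self-contained sketch below shows that idea and how a standard PyTorch forward hook can then observe the attention matrix; the class name and shapes are illustrative, not the actual easyroutine code.

```python
import torch
import torch.nn as nn


class PassThroughAttentionHook(nn.Module):
    # Returns the attention matrix unchanged; its only purpose is to be a place
    # where torch forward hooks can be registered.
    def forward(self, attention_matrix: torch.Tensor) -> torch.Tensor:
        return attention_matrix


hook_point = PassThroughAttentionHook()
captured = {}
hook_point.register_forward_hook(lambda mod, inp, out: captured.setdefault("attn", out))

attn = torch.softmax(torch.randn(1, 8, 5, 5), dim=-1)  # (batch, heads, seq, seq)
hook_point(attn)
print(captured["attn"].shape)                           # torch.Size([1, 8, 5, 5])
```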
BaseAttentionWrapper
Bases: Module
A base class for wrapping an original attention module.
Provides:
- _orig_module to store the real (unwrapped) attention.
- A robust __getattr__ that checks:
  1) self.__dict__
  2) self._modules
  3) the base class
  4) fallback to _orig_module
Source code in easyroutine/interpretability/module_wrappers/base.py, lines 13-64.
__getattr__(name)
If name is not in this wrapper, fall back to the original module. Also checks self._modules for submodules, because PyTorch automatically places them there.
Source code in easyroutine/interpretability/module_wrappers/base.py, lines 31-56.
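A generic sketch of this delegation pattern (not the exact easyroutine code): attribute lookups that fail on the wrapper, including the nn.Module machinery that searches self._modules, fall back to the wrapped module.

```python
import torch.nn as nn


class DelegatingWrapper(nn.Module):
    def __init__(self, orig: nn.Module):
        super().__init__()
        # Stored via __dict__ so it is found by normal attribute lookup
        # without being registered as a submodule of the wrapper.
        self.__dict__["_orig_module"] = orig

    def __getattr__(self, name):
        try:
            # nn.Module.__getattr__ searches _parameters, _buffers and _modules.
            return super().__getattr__(name)
        except AttributeError:
            # Final fallback: look the name up on the original, unwrapped module.
            return getattr(self.__dict__["_orig_module"], name)


wrapper = DelegatingWrapper(nn.Linear(4, 8))
print(wrapper.out_features)   # 8, resolved on the wrapped nn.Linear
```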
original_name()
staticmethod
Override this in each derived class if you want the manager code to know which original class name this wrapper replaces.
Source code in easyroutine/interpretability/module_wrappers/base.py, lines 58-64.
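For example, a derived wrapper would typically override it along these lines. The class below is a hypothetical sketch, and the returned string is assumed to be the class name of the original module being replaced.

```python
from easyroutine.interpretability.module_wrappers.base import BaseAttentionWrapper


class Gemma3AttentionWrapperSketch(BaseAttentionWrapper):
    @staticmethod
    def original_name() -> str:
        # Name of the original transformers class this wrapper stands in for.
        return "Gemma3Attention"
```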
Specific Module Wrappers
LlamaAttentionWrapper
Bases: BaseAttentionWrapper
A wrapper around the original LlamaAttention. It has:
- The same named attributes (q_proj, k_proj, etc.), which are references to the original module's submodules/parameters.
- A private reference (_orig_attn) to the entire original attention, for falling back if something isn't found on the wrapper itself.
- An additional attention_matrix_hook for intercepting attention.
Source code in easyroutine/interpretability/module_wrappers/llama_attention.py, lines 78-174.
__init__(original_attention)
Store references to all relevant submodules so the wrapper "feels" the same. Also store a reference to the original module in a private attribute for fallback.
Source code in easyroutine/interpretability/module_wrappers/llama_attention.py, lines 92-123.
repeat_kv(hidden_states, n_rep)
(batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_attention_heads, seq_len, head_dim)
Source code in easyroutine/interpretability/module_wrappers/llama_attention.py, lines 37-48.
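This helper matches the standard repeat_kv used across Hugging Face attention implementations; a self-contained reference version, which may differ cosmetically from the copy in llama_attention.py:

```python
import torch


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    # (batch, num_key_value_heads, seq_len, head_dim)
    #   -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)
    batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)


kv = torch.randn(2, 4, 10, 64)   # 4 key/value heads (grouped-query attention)
print(repeat_kv(kv, 8).shape)    # torch.Size([2, 32, 10, 64]) -> 32 attention heads
```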
ChameleonAttentionWrapper
Bases: BaseAttentionWrapper
Attention wrapper for the Chameleon model.
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py, lines 57-157.
apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
Applies Rotary Position Embedding to the query and key tensors.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| q | `torch.Tensor` | The query tensor. | required |
| k | `torch.Tensor` | The key tensor. | required |
| cos | `torch.Tensor` | The cosine part of the rotary embedding. | required |
| sin | `torch.Tensor` | The sine part of the rotary embedding. | required |
| position_ids | `torch.Tensor`, *optional* | Deprecated and unused. | None |
| unsqueeze_dim | `int`, *optional*, defaults to 1 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2. | 1 |
Returns:

tuple(torch.Tensor) comprising the query and key tensors rotated using the Rotary Position Embedding.
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py, lines 15-40.
repeat_kv(hidden_states, n_rep)
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py, lines 43-52.
rotate_half(x)
Rotates half the hidden dims of the input.
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py, lines 9-13.
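Both rotate_half and apply_rotary_pos_emb follow the standard rotary-embedding reference code from transformers; the self-contained sketch below may differ in minor details from the copies in chameleon_attention.py.

```python
import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    # Split the last dimension in half and rotate: (x1, x2) -> (-x2, x1).
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # Broadcast cos/sin over the heads dimension, then mix each vector with its rotated half.
    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed


q = torch.randn(1, 8, 5, 64)   # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 5, 64)
cos = torch.randn(1, 5, 64)    # (batch, seq_len, head_dim)
sin = torch.randn(1, 5, 64)
q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)
print(q_rot.shape, k_rot.shape)   # shapes are unchanged
```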
T5AttentionWrapper
Bases: BaseAttentionWrapper
Source code in easyroutine/interpretability/module_wrappers/T5_attention.py, lines 10-159.
forward(hidden_states, mask=None, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False, cache_position=None)
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
Source code in easyroutine/interpretability/module_wrappers/T5_attention.py, lines 34-159.