Module Wrapper
Introduction
Module Wrapper is the submodule responsible for managing the module wrappers. Module wrappers are essential for adding custom hooks at points where the original transformer codebase does not expose one. For example, the transformer
module does not have a hook to get the attention matrix of a head; the module wrapper adds this hook. A short usage sketch follows the file list below. The module_wrapper
submodule is composed of the following files:
- manager.py: Responsible for managing the module wrappers. It is the standard interface for wrapping models.
- base.py: The base class for the module wrappers. It implements a basic form of a Wrapper class.
- model_name_attention.py: The module wrapper for the attention matrix of a single model. When adding a new model, add a new file named model_name_attention.py and implement the ModelNameAttention class. It is basically a copy of the attention module's forward pass, with the addition of the hook to get the attention matrix.
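As a quick orientation, the sketch below wraps a model through the manager and reads out attention matrices with a plain PyTorch forward hook. The import path follows the "Source code in ..." references on this page, and the AttentionMatrixHookModule lookup is an assumption based on the base.py section further down, not a verbatim recipe.

```python
import torch
from transformers import AutoModelForCausalLM

from easyroutine.interpretability.module_wrappers.manager import ModuleWrapperManager

# Any LLaMA-family checkpoint works for this sketch.
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")

# Swap every original attention module for its wrapper.
manager = ModuleWrapperManager(model)
manager.substitute_attention_module(model)

# Capture attention matrices via a standard PyTorch forward hook registered on the
# wrappers' attention_matrix_hook submodules (assumed name, see base.py below).
captured = []

def save_attention(module, inputs, output):
    captured.append(output.detach().cpu())

for module in model.modules():
    if module.__class__.__name__ == "AttentionMatrixHookModule":
        module.register_forward_hook(save_attention)

with torch.no_grad():
    model(torch.tensor([[1, 2, 3]]))

# Put the original attention modules back when done.
manager.restore_original_attention_module(model)
```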
Manager Wrappers and Abstract Base Class
AttentionWrapperFactory
Maps a given model name to the correct attention wrapper class.
Source code in easyroutine/interpretability/module_wrappers/manager.py
get_wrapper_class(model)
staticmethod
Returns the attention wrapper class for the specified model name. Raises a ValueError if the model is not supported.
Source code in easyroutine/interpretability/module_wrappers/manager.py
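A rough usage sketch; the page is ambiguous about whether model is the loaded model or just its name, so this sketch assumes the loaded Hugging Face model, and the printed class name is only an illustrative guess:

```python
from easyroutine.interpretability.module_wrappers.manager import AttentionWrapperFactory

try:
    # `model` is assumed to be an already loaded Hugging Face model.
    wrapper_cls = AttentionWrapperFactory.get_wrapper_class(model)
    print(wrapper_cls.__name__)  # e.g. "LlamaAttentionWrapper" for a LLaMA model
except ValueError:
    print("No attention wrapper registered for this model")
```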
ModuleWrapperManager
Handles the logic of replacing an original attention class within a given model with a custom attention wrapper, based on the user-specified model_name. Also allows restoring the original modules if needed, using a single recursive function.
Source code in easyroutine/interpretability/module_wrappers/manager.py
__init__(model, log_level='INFO')
Initializes the manager with a given model name.
Source code in easyroutine/interpretability/module_wrappers/manager.py
restore_original_attention_module(model)
Public method that restores the original attention modules in the model. Logs each restoration.
Source code in easyroutine/interpretability/module_wrappers/manager.py
substitute_attention_module(model)
Public method that performs the substitution of attention modules in the model.
Logs each replacement. This will replace all modules whose class name matches self.target_module_name.
Source code in easyroutine/interpretability/module_wrappers/manager.py
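A minimal substitute/restore round trip might look like the following sketch, where model and inputs are assumed to be an already loaded Hugging Face model and its tokenized inputs:

```python
from easyroutine.interpretability.module_wrappers.manager import ModuleWrapperManager

manager = ModuleWrapperManager(model, log_level="WARNING")

manager.substitute_attention_module(model)  # wrap: logs one line per replaced module
try:
    outputs = model(**inputs)               # run whatever needs the extra hooks
finally:
    manager.restore_original_attention_module(model)  # always undo the wrapping
```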
AttentionMatrixHookModule
Bases: Module
Computes the attention matrix. Note: it has been added only so that custom hooks can be attached to it.
Source code in easyroutine/interpretability/module_wrappers/base.py
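The idea is a pass-through nn.Module that the wrapped attention routes its attention matrix through, purely so forward hooks have an attachment point. A minimal sketch of that pattern (not the literal base.py code):

```python
import torch
from torch import nn

class AttentionMatrixHookModule(nn.Module):
    """Pass-through module: hooks registered here observe the attention matrix."""

    def forward(self, attention_matrix: torch.Tensor) -> torch.Tensor:
        return attention_matrix

# Inside a wrapped attention forward, the matrix is routed through this hook point:
#   attn_weights = torch.softmax(scores, dim=-1)
#   attn_weights = self.attention_matrix_hook(attn_weights)
# so `module.attention_matrix_hook.register_forward_hook(...)` sees every matrix.
```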
BaseAttentionWrapper
Bases: Module
A base class for wrapping an original attention module. Provides _orig_module to store the real (unwrapped) attention, and a robust __getattr__ that checks, in order:
1) self.__dict__
2) self._modules
3) the base class
4) fallback to _orig_module
Source code in easyroutine/interpretability/module_wrappers/base.py
__getattr__(name)
If name is not in this wrapper, fall back to the original module. Also checks self._modules for submodules, because PyTorch automatically places them there.
Source code in easyroutine/interpretability/module_wrappers/base.py
original_name()
staticmethod
Override this in each derived class if you want the manager code to know which original class name this wrapper replaces.
Source code in easyroutine/interpretability/module_wrappers/base.py
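To make the lookup order and the original_name() contract concrete, here is a hypothetical derived wrapper; MyAttention is a placeholder name, and the super().__init__ call assumes the base class stores the original module as _orig_module:

```python
from torch import nn

from easyroutine.interpretability.module_wrappers.base import (
    AttentionMatrixHookModule,
    BaseAttentionWrapper,
)

class MyAttentionWrapper(BaseAttentionWrapper):
    """Hypothetical wrapper around a module whose class is named 'MyAttention'."""

    def __init__(self, original_attention: nn.Module):
        super().__init__(original_attention)     # assumed to stash _orig_module
        self.q_proj = original_attention.q_proj  # re-expose submodules by reference
        self.k_proj = original_attention.k_proj
        self.attention_matrix_hook = AttentionMatrixHookModule()

    @staticmethod
    def original_name() -> str:
        # Tells the manager which original class name this wrapper replaces.
        return "MyAttention"

# Anything not defined on the wrapper (e.g. `wrapper.num_heads`) falls through
# __getattr__ to self.__dict__, then self._modules, then the base class, then _orig_module.
```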
Specific Module Wrappers
LlamaAttentionWrapper
Bases: BaseAttentionWrapper
A wrapper around the original LlamaAttention. It has:
- The same named attributes (q_proj, k_proj, etc.), which are references to the original module's submodules/parameters.
- A private reference (_orig_attn) to the entire original attention, for falling back if something isn't found on the wrapper itself.
- An additional attention_matrix_hook for intercepting attention.
Source code in easyroutine/interpretability/module_wrappers/llama_attention.py
__init__(original_attention)
Store references to all relevant submodules so the wrapper "feels" the same. Also store a reference to the original module in a private attribute for fallback.
Source code in easyroutine/interpretability/module_wrappers/llama_attention.py
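A short sketch of wrapping a single layer by hand and intercepting its attention matrix; it assumes model is an already loaded LlamaForCausalLM and that attention_matrix_hook returns the attention matrix from its forward:

```python
from easyroutine.interpretability.module_wrappers.llama_attention import LlamaAttentionWrapper

layer = model.model.layers[0]                     # first decoder layer of a LlamaForCausalLM
wrapped = LlamaAttentionWrapper(layer.self_attn)  # same q_proj/k_proj/... attributes
layer.self_attn = wrapped

# Capture the attention matrix of this layer on every forward pass.
attn_matrices = []
wrapped.attention_matrix_hook.register_forward_hook(
    lambda module, inputs, output: attn_matrices.append(output.detach().cpu())
)
```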
repeat_kv(hidden_states, n_rep)
(batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_attention_heads, seq_len, head_dim)
Source code in easyroutine/interpretability/module_wrappers/llama_attention.py
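The helper expands grouped key/value heads so they line up with the query heads. The sketch below follows the standard Hugging Face implementation that this wrapper mirrors (the Chameleon wrapper further down exposes the same helper):

```python
import torch

def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """(batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_key_value_heads * n_rep, seq_len, head_dim)."""
    batch, num_key_value_heads, seq_len, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, seq_len, head_dim
    )
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, seq_len, head_dim)

kv = torch.randn(2, 8, 16, 64)   # 8 key/value heads
print(repeat_kv(kv, 4).shape)    # torch.Size([2, 32, 16, 64]) -> 32 attention heads
```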
ChameleonAttentionWrapper
Bases: BaseAttentionWrapper
Attention wrapper for the Chameleon model.
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)
Applies Rotary Position Embedding to the query and key tensors.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| q | `torch.Tensor` | The query tensor. | required |
| k | `torch.Tensor` | The key tensor. | required |
| cos | `torch.Tensor` | The cosine part of the rotary embedding. | required |
| sin | `torch.Tensor` | The sine part of the rotary embedding. | required |
| position_ids | `torch.Tensor`, *optional* | Deprecated and unused. | None |
| unsqueeze_dim | `int`, *optional*, defaults to 1 | The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be properly broadcast to the dimensions of q and k. For example, note that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and k have the shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have the shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2. | 1 |
Returns:
`tuple(torch.Tensor)` comprising the query and key tensors rotated using the Rotary Position Embedding.
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
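For reference, a sketch of the standard Hugging Face RoPE helpers this docstring describes; the file's own code may differ in minor details:

```python
import torch

def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    # cos/sin: (batch_size, seq_len, head_dim); q/k: (batch_size, heads, seq_len, head_dim)
    cos = cos.unsqueeze(unsqueeze_dim)  # -> (batch_size, 1, seq_len, head_dim), broadcastable over heads
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
```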
repeat_kv(hidden_states, n_rep)
This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
rotate_half(x)
Rotates half the hidden dims of the input.
Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
T5AttentionWrapper
Bases: BaseAttentionWrapper
Source code in easyroutine/interpretability/module_wrappers/T5_attention.py
forward(hidden_states, mask=None, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False, cache_position=None)
Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
Source code in easyroutine/interpretability/module_wrappers/T5_attention.py
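A hypothetical sketch of the two call modes; it assumes T5AttentionWrapper is constructed from an existing T5Attention the same way the other wrappers are, since its constructor is not documented on this page:

```python
import torch
from transformers.models.t5.configuration_t5 import T5Config
from transformers.models.t5.modeling_t5 import T5Attention

from easyroutine.interpretability.module_wrappers.T5_attention import T5AttentionWrapper

config = T5Config(d_model=64, d_kv=8, num_heads=8)
wrapped = T5AttentionWrapper(T5Attention(config, has_relative_attention_bias=True))

hidden_states = torch.randn(1, 5, 64)    # current (decoder) states
encoder_states = torch.randn(1, 7, 64)   # "source sentence" states

self_attn_out = wrapped(hidden_states)                                     # self-attention
cross_attn_out = wrapped(hidden_states, key_value_states=encoder_states)   # cross-attention
```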