Module Wrapper

Introduction

Module Wrapper is the submodule responsible for managing the module wrappers. Module wrappers are essential for adding custom hooks at points where the original transformer codebase does not expose one. For example, the transformer attention module has no hook that returns the attention matrix of a head; a module wrapper adds that hook. The module_wrapper submodule is composed of the following files:

- manager.py: manages the module wrappers and is the standard interface for wrapping a model.
- base.py: defines the base class for module wrappers, implementing a minimal Wrapper class.
- model_name_attention.py: the module wrapper for the attention module of a single model. When adding a new model, create a new file named model_name_attention.py and implement the ModelNameAttention class. It is essentially a copy of the attention module's forward pass, with the addition of the hook that exposes the attention matrix.
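
As a quick orientation, here is a minimal usage sketch. It assumes a tiny Llama-style checkpoint from transformers and the usual model.model.layers[i].self_attn submodule path; both are illustrative and may differ for your model.

# Minimal sketch (assumptions: a tiny Llama checkpoint and the usual
# model.model.layers[i].self_attn submodule path).
import torch
from transformers import AutoModelForCausalLM
from easyroutine.interpretability.module_wrappers.manager import ModuleWrapperManager

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM"
)

manager = ModuleWrapperManager(model)       # resolves LlamaAttentionWrapper via the factory
manager.substitute_attention_module(model)  # swap every LlamaAttention for its wrapper

captured = []

# attention_matrix_hook is an identity module, so a standard forward hook sees the matrix.
model.model.layers[0].self_attn.attention_matrix_hook.register_forward_hook(
    lambda module, inputs, output: captured.append(output.detach())
)

with torch.no_grad():
    model(input_ids=torch.tensor([[1, 2, 3]]))

print(captured[0].shape)                    # (batch, num_heads, seq_len, seq_len)
manager.restore_original_attention_module(model)  # undo the substitution when done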

Manager Wrappers and Abstract Base Class

AttentionWrapperFactory

Maps a given model name to the correct attention wrapper class.

Source code in easyroutine/interpretability/module_wrappers/manager.py
class AttentionWrapperFactory:
    """
    Maps a given model name to the correct attention wrapper class.
    """

    AVAILABLE_MODULE_WRAPPERS:dict = {
                                    ChameleonAttentionWrapper.original_name(): ChameleonAttentionWrapper, 
                                    LlamaAttentionWrapper.original_name(): LlamaAttentionWrapper, 
                                    T5AttentionWrapper.original_name(): T5AttentionWrapper, 
                                    MistralAttentionWrapper.original_name(): MistralAttentionWrapper 
    }

    # MODEL_NAME_TO_WRAPPER = {
    #     "facebook/chameleon-7b": ChameleonAttentionWrapper,
    #     "facebook/chameleon-30b": ChameleonAttentionWrapper,
    #     "mistral-community/pixtral-12b": LlamaAttentionWrapper,
    #     "llava-hf/llava-v1.6-mistral-7b-hf": LlamaAttentionWrapper,
    #     "hf-internal-testing/tiny-random-LlamaForCausalLM": LlamaAttentionWrapper,
    #     "ChoereForAI/aya-101": T5AttentionWrapper,
    # }

    @staticmethod
    def get_wrapper_class(
        model: nn.Module,
    ) -> Union[
        Type[ChameleonAttentionWrapper],
        Type[LlamaAttentionWrapper],
        Type[T5AttentionWrapper],
        Type[MistralAttentionWrapper],
    ]:
        """
        Returns the attention wrapper class for the specified model name.
        Raises a ValueError if the model is not supported.
        """
        all_modules = find_all_modules(model, return_only_names=True)

        for candidate_name, candidate_wrappers in AttentionWrapperFactory.AVAILABLE_MODULE_WRAPPERS.items():
            if candidate_name in all_modules:
                LambdaLogger().info(f"Found a wrapper for {candidate_name}")
                return candidate_wrappers

        LambdaLogger().warning(f"Do not have any wrapper for {model}")

get_wrapper_class(model) staticmethod

Returns the attention wrapper class for the specified model name. Raises a ValueError if the model is not supported.

Source code in easyroutine/interpretability/module_wrappers/manager.py
@staticmethod
def get_wrapper_class(
    model: nn.Module,
) -> Union[
    Type[ChameleonAttentionWrapper],
    Type[LlamaAttentionWrapper],
    Type[T5AttentionWrapper],
    Type[MistralAttentionWrapper],
]:
    """
    Returns the attention wrapper class for the specified model name.
    Raises a ValueError if the model is not supported.
    """
    all_modules = find_all_modules(model, return_only_names=True)

    for candidate_name, candidate_wrappers in AttentionWrapperFactory.AVAILABLE_MODULE_WRAPPERS.items():
        if candidate_name in all_modules:
            LambdaLogger().info(f"Found a wrapper for {candidate_name}")
            return candidate_wrappers

    LambdaLogger().warning(f"Do not have any wrapper for {model}")
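
For reference, a short usage sketch of the factory; the checkpoint is illustrative, and the lookup is driven purely by the class names of the model's submodules.

# Sketch: resolve the wrapper class for a loaded model.
from transformers import AutoModelForCausalLM
from easyroutine.interpretability.module_wrappers.manager import AttentionWrapperFactory

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM"
)

wrapper_cls = AttentionWrapperFactory.get_wrapper_class(model)
print(wrapper_cls.__name__)         # "LlamaAttentionWrapper"
print(wrapper_cls.original_name())  # "LlamaAttention", the class name it will replace

Because get_wrapper_class scans the model's module class names, the result does not depend on the checkpoint name used to load the model.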

ModuleWrapperManager

Handles the logic of replacing an original attention class within a given model with a custom attention wrapper, based on user-specified model_name. Also allows restoring the original modules if needed, using a single recursive function.

Source code in easyroutine/interpretability/module_wrappers/manager.py
class ModuleWrapperManager:
    """
    Handles the logic of replacing an original attention class within a given model
    with a custom attention wrapper, based on user-specified model_name.
    Also allows restoring the original modules if needed, using a single
    recursive function.
    """

    def __init__(self, model: nn.Module, log_level: str = "INFO"):
        """
        Initializes the manager with a given model name.
        """
        self.logger = Logger(logname="ModuleWrapperManager", level=log_level)

        # Fetch the appropriate wrapper class for the given model name
        self.attention_wrapper_class = AttentionWrapperFactory.get_wrapper_class(model) # TODO: extend to support multiple module type for model
        # The original attention class name is fetched via a class method or attribute in the wrapper
        self.target_module_name = self.attention_wrapper_class.original_name() # TODO: extend to support multiple module type for model

        # Dictionary to store submodule_path -> original attention module
        self.original_modules = {}

    def __contains__(self, module_name:str):
        return module_name == self.target_module_name # TODO: extend to support multiple module type for model

    def substitute_attention_module(self, model: nn.Module) -> None:
        """
        Public method that performs the substitution of attention modules in the model.
        Logs each replacement. This will replace *all* modules whose class name
        matches `self.target_module_name`.
        """
        self._traverse_and_modify(model, parent_path="", mode="substitute")

    def restore_original_attention_module(self, model: nn.Module) -> None:
        """
        Public method that restores the original attention modules in the model.
        Logs each restoration.
        """
        self._traverse_and_modify(model, parent_path="", mode="restore")

    def _traverse_and_modify(self, module: nn.Module, parent_path: str, mode: str) -> None:
        """
        Recursively traverses `module` and either substitutes or restores each matching
        submodule, depending on `mode`.

        - mode="substitute": Replaces the original module (with class name == self.target_module_name)
                            with the wrapper, storing the original in self.original_modules.
        - mode="restore": Replaces the wrapper submodule (class name == self.attention_wrapper_class.__name__)
                        with the original module from self.original_modules.

        Args:
            module (nn.Module): The current module to inspect.
            parent_path (str): A string that tracks the 'path' of this submodule in the overall model hierarchy.
            mode (str): Either "substitute" or "restore".
        """
        for name, child in list(module.named_children()):
            # Identify the submodule path (e.g. "encoder.layer.0.attention")
            submodule_path = f"{parent_path}.{name}" if parent_path else name

            if mode == "substitute":
                # Look for the original module class name
                if child.__class__.__name__ == self.target_module_name:
                    # Store the original
                    self.original_modules[submodule_path] = child
                    # Wrap it
                    wrapped_module = self.attention_wrapper_class(child)
                    setattr(module, name, wrapped_module)

                    self.logger.info(
                        f"Substituted '{submodule_path}' with wrapper for {self.target_module_name}."
                    )
                else:
                    # Recurse
                    self._traverse_and_modify(child, submodule_path, mode="substitute")

            elif mode == "restore":
                # Look for the wrapper class name
                if child.__class__.__name__ == self.attention_wrapper_class.__name__:
                    if submodule_path in self.original_modules:
                        original_module = self.original_modules[submodule_path]
                        setattr(module, name, original_module)
                        self.logger.info(
                            f"Restored '{submodule_path}' to original {self.target_module_name}."
                        )
                    else:
                        self.logger.warning(
                            f"Found a wrapped submodule '{submodule_path}' but no original stored. Skipping."
                        )
                else:
                    # Recurse
                    self._traverse_and_modify(child, submodule_path, mode="restore")

__init__(model, log_level='INFO')

Initializes the manager with a given model name.

Source code in easyroutine/interpretability/module_wrappers/manager.py
def __init__(self, model: nn.Module, log_level: str = "INFO"):
    """
    Initializes the manager with a given model name.
    """
    self.logger = Logger(logname="ModuleWrapperManager", level=log_level)

    # Fetch the appropriate wrapper class for the given model name
    self.attention_wrapper_class = AttentionWrapperFactory.get_wrapper_class(model) # TODO: extend to support multiple module type for model
    # The original attention class name is fetched via a class method or attribute in the wrapper
    self.target_module_name = self.attention_wrapper_class.original_name() # TODO: extend to support multiple module type for model

    # Dictionary to store submodule_path -> original attention module
    self.original_modules = {}

restore_original_attention_module(model)

Public method that restores the original attention modules in the model. Logs each restoration.

Source code in easyroutine/interpretability/module_wrappers/manager.py
def restore_original_attention_module(self, model: nn.Module) -> None:
    """
    Public method that restores the original attention modules in the model.
    Logs each restoration.
    """
    self._traverse_and_modify(model, parent_path="", mode="restore")

substitute_attention_module(model)

Public method that performs the substitution of attention modules in the model. Logs each replacement. This will replace all modules whose class name matches self.target_module_name.

Source code in easyroutine/interpretability/module_wrappers/manager.py
def substitute_attention_module(self, model: nn.Module) -> None:
    """
    Public method that performs the substitution of attention modules in the model.
    Logs each replacement. This will replace *all* modules whose class name
    matches `self.target_module_name`.
    """
    self._traverse_and_modify(model, parent_path="", mode="substitute")
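
A sketch of the substitute/restore round trip; the submodule path is an assumption for a Llama-style model.

# Sketch: substitution and restoration are keyed purely on class names.
from transformers import AutoModelForCausalLM
from easyroutine.interpretability.module_wrappers.manager import ModuleWrapperManager

model = AutoModelForCausalLM.from_pretrained(
    "hf-internal-testing/tiny-random-LlamaForCausalLM"
)
manager = ModuleWrapperManager(model)

print(type(model.model.layers[0].self_attn).__name__)  # "LlamaAttention"
manager.substitute_attention_module(model)              # logs each replaced submodule path
print(type(model.model.layers[0].self_attn).__name__)  # "LlamaAttentionWrapper"
manager.restore_original_attention_module(model)        # puts the stored originals back
print(type(model.model.layers[0].self_attn).__name__)  # "LlamaAttention" again

Restoration relies on the submodule paths stored during substitution, so it should be called on the same model instance that was substituted.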

AttentionMatrixHookModule

Bases: Module

Computation of the attention matrix. Note: it has been added just for adding custom hooks.

Source code in easyroutine/interpretability/module_wrappers/base.py
class AttentionMatrixHookModule(nn.Module):
    """Computation of the attention matrix. *Note*: it has been added just for adding custom hooks."""

    def forward(
            self,
            attention_matrix: torch.Tensor,
    ):
        return attention_matrix
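
Because the module is an identity function, a standard PyTorch forward hook can both read and overwrite the attention matrix. A small illustrative sketch (the head ablation is just an example):

# Sketch: read or edit the attention matrix through a regular forward hook.
import torch
from easyroutine.interpretability.module_wrappers.base import AttentionMatrixHookModule

hook_module = AttentionMatrixHookModule()

def zero_head_0(module, inputs, output):
    # Returning a tensor from a forward hook replaces the module's output.
    patched = output.clone()
    patched[:, 0] = 0.0  # illustrative: ablate attention head 0
    return patched

hook_module.register_forward_hook(zero_head_0)

attn = torch.rand(2, 4, 5, 5)  # (batch, heads, query_len, key_len)
print(hook_module(attn)[:, 0].abs().sum())  # tensor(0.) -- head 0 was zeroed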

BaseAttentionWrapper

Bases: Module

A base class for wrapping an original attention module.

Provides:

- _orig_module to store the real (unwrapped) attention.
- A robust __getattr__ that checks: 1) self.__dict__, 2) self._modules, 3) the wrapper class itself, 4) fallback to _orig_module.

Source code in easyroutine/interpretability/module_wrappers/base.py
class BaseAttentionWrapper(nn.Module):
    """
    A base class for wrapping an original attention module.

    Provides:
        `_orig_module` to store the real (unwrapped) attention.
        A robust `__getattr__` that checks:
            1) self.__dict__
            2) self._modules
            3) the base class
            4) fallback to `_orig_module`
    """

    def __init__(self, original_module: nn.Module):
        super().__init__()
        # store the original module in a private attribute
        object.__setattr__(self, "_orig_module", original_module)

    def __getattr__(self, name: str):
        """
        If name is not in this wrapper, fall back to the original module.
        Also checks `self._modules` for submodules, because PyTorch
        automatically places them there.
        """
        # 1) get this wrapper's __dict__
        wrapper_dict = object.__getattribute__(self, "__dict__")

        # 2) if name is in our own instance dictionary, return it
        if name in wrapper_dict:
            return wrapper_dict[name]

        # 3) if name is in our submodules, return it
        modules_dict = wrapper_dict["_modules"]
        if name in modules_dict:
            return modules_dict[name]

        # 4) check if name is in our class (methods, etc.)
        cls = object.__getattribute__(self, "__class__")
        if hasattr(cls, name):
            return getattr(cls, name)

        # 5) fallback to _orig_module
        orig = wrapper_dict["_orig_module"]
        return getattr(orig, name)

    @staticmethod
    def original_name() -> str:
        """
        By default, you might override this in each derived class if you want
        your manager code to know which original class name this wrapper replaces.
        """
        return "BaseAttention"

__getattr__(name)

If name is not in this wrapper, fall back to the original module. Also checks self._modules for submodules, because PyTorch automatically places them there.

Source code in easyroutine/interpretability/module_wrappers/base.py
def __getattr__(self, name: str):
    """
    If name is not in this wrapper, fall back to the original module.
    Also checks `self._modules` for submodules, because PyTorch
    automatically places them there.
    """
    # 1) get this wrapper's __dict__
    wrapper_dict = object.__getattribute__(self, "__dict__")

    # 2) if name is in our own instance dictionary, return it
    if name in wrapper_dict:
        return wrapper_dict[name]

    # 3) if name is in our submodules, return it
    modules_dict = wrapper_dict["_modules"]
    if name in modules_dict:
        return modules_dict[name]

    # 4) check if name is in our class (methods, etc.)
    cls = object.__getattribute__(self, "__class__")
    if hasattr(cls, name):
        return getattr(cls, name)

    # 5) fallback to _orig_module
    orig = wrapper_dict["_orig_module"]
    return getattr(orig, name)

original_name() staticmethod

By default, you might override this in each derived class if you want your manager code to know which original class name this wrapper replaces.

Source code in easyroutine/interpretability/module_wrappers/base.py
@staticmethod
def original_name() -> str:
    """
    By default, you might override this in each derived class if you want
    your manager code to know which original class name this wrapper replaces.
    """
    return "BaseAttention"

Specific Module Wrappers

LlamaAttentionWrapper

Bases: BaseAttentionWrapper

A wrapper around the original LlamaAttention. It has:

- The same named attributes (q_proj, k_proj, etc.), which are references to the original module's submodules/parameters.
- A private reference (_orig_attn) to the entire original attention, for falling back if something isn't found on the wrapper itself.
- An additional attention_matrix_hook for intercepting attention.

Source code in easyroutine/interpretability/module_wrappers/llama_attention.py
class LlamaAttentionWrapper(BaseAttentionWrapper):
    """
    A wrapper around the original LlamaAttention. It has:
    - The same named attributes (q_proj, k_proj, etc.), which are references
        to the original module's submodules/parameters.
    - A private reference (`_orig_attn`) to the entire original attention,
        for falling back if something isn't found on the wrapper itself.
    - An additional `attention_matrix_hook` for intercepting attention.
    """

    @staticmethod
    def original_name():
        return "LlamaAttention"

    def __init__(self, original_attention: nn.Module):
        """
        Store references to all relevant submodules so the wrapper
        "feels" the same. Also store a reference to the original module
        in a private attribute for fallback.
        """
        super().__init__(original_attention)

        # This is the private reference to the entire original attention.
        # We'll fallback to it for any attribute we haven't explicitly set.
        object.__setattr__(self, "_orig_attn", original_attention)

        # Now replicate the original attention's submodules as attributes of *this* wrapper.
        # These are direct references, not new modules:
        self.q_proj = original_attention.q_proj
        self.k_proj = original_attention.k_proj
        self.v_proj = original_attention.v_proj
        self.o_proj = original_attention.o_proj

        # Copy over any scalar attributes you need
        # self.num_heads = original_attention.num_heads
        # self.num_key_value_heads = original_attention.num_key_value_heads
        # self.num_key_value_groups = original_attention.num_key_value_groups
        self.head_dim = original_attention.head_dim
        # self.hidden_size = original_attention.hidden_size
        self.attention_dropout = original_attention.attention_dropout
        self.layer_idx = original_attention.layer_idx
        self.config = original_attention.config

        # Add your custom hook module
        self.attention_matrix_hook = AttentionMatrixHookModule()

    def forward(
        self,
        hidden_states: torch.Tensor,
        position_embeddings: Tuple[torch.Tensor, torch.Tensor],
        attention_mask: Optional[torch.Tensor],
        past_key_value: Optional[Cache] = None,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs: Unpack[FlashAttentionKwargs],
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, self.head_dim)

        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)

        cos, sin = position_embeddings
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(
                key_states, value_states, self.layer_idx, cache_kwargs
            )

        # Inline eager_attention_forward logic
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) * self.scaling

        if attention_mask is not None:
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = self.attention_matrix_hook(attn_weights)
        attn_weights = nn.functional.dropout(
            attn_weights,
            p=0.0 if not self.training else self.attention_dropout,
            training=self.training,
        )

        attn_output = torch.matmul(attn_weights, value_states).transpose(1, 2).contiguous()
        # End inline

        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
        attn_output = self.o_proj(attn_output)
        return attn_output, attn_weights # type: ignore

__init__(original_attention)

Store references to all relevant submodules so the wrapper "feels" the same. Also store a reference to the original module in a private attribute for fallback.

Source code in easyroutine/interpretability/module_wrappers/llama_attention.py
def __init__(self, original_attention: nn.Module):
    """
    Store references to all relevant submodules so the wrapper
    "feels" the same. Also store a reference to the original module
    in a private attribute for fallback.
    """
    super().__init__(original_attention)

    # This is the private reference to the entire original attention.
    # We'll fallback to it for any attribute we haven't explicitly set.
    object.__setattr__(self, "_orig_attn", original_attention)

    # Now replicate the original attention's submodules as attributes of *this* wrapper.
    # These are direct references, not new modules:
    self.q_proj = original_attention.q_proj
    self.k_proj = original_attention.k_proj
    self.v_proj = original_attention.v_proj
    self.o_proj = original_attention.o_proj

    # Copy over any scalar attributes you need
    # self.num_heads = original_attention.num_heads
    # self.num_key_value_heads = original_attention.num_key_value_heads
    # self.num_key_value_groups = original_attention.num_key_value_groups
    self.head_dim = original_attention.head_dim
    # self.hidden_size = original_attention.hidden_size
    self.attention_dropout = original_attention.attention_dropout
    self.layer_idx = original_attention.layer_idx
    self.config = original_attention.config

    # Add your custom hook module
    self.attention_matrix_hook = AttentionMatrixHookModule()

repeat_kv(hidden_states, n_rep)

(batch, num_key_value_heads, seq_len, head_dim) -> (batch, num_attention_heads, seq_len, head_dim)

Source code in easyroutine/interpretability/module_wrappers/llama_attention.py
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    (batch, num_key_value_heads, seq_len, head_dim)
        -> (batch, num_attention_heads, seq_len, head_dim)
    """
    bsz, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        bsz, num_kv_heads, n_rep, slen, head_dim
    )
    return hidden_states.reshape(bsz, num_kv_heads * n_rep, slen, head_dim)
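
A quick shape check of repeat_kv (the import path follows the source location given above):

# Sketch: repeat_kv gives every attention head its own copy of the grouped KV heads.
import torch
from easyroutine.interpretability.module_wrappers.llama_attention import repeat_kv

kv = torch.randn(2, 4, 7, 16)  # (batch, num_key_value_heads, seq_len, head_dim)
out = repeat_kv(kv, n_rep=3)   # 4 KV heads serving 12 attention heads
print(out.shape)               # torch.Size([2, 12, 7, 16])
# Same result as torch.repeat_interleave on the head dimension.
print(torch.equal(out, torch.repeat_interleave(kv, repeats=3, dim=1)))  # True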

ChameleonAttentionWrapper

Bases: BaseAttentionWrapper

Attention wrapper for the Chameleon model.

Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
class ChameleonAttentionWrapper(BaseAttentionWrapper):
    """
    Attention wrapper for the Chameleon model.
    """

    @staticmethod
    def original_name():
        return "ChameleonAttention"

    def __init__(self, original_attention: nn.Module):
        super().__init__(original_attention)

        self.q_proj = original_attention.q_proj
        self.k_proj = original_attention.k_proj
        self.v_proj = original_attention.v_proj
        self.q_norm = original_attention.q_norm
        self.k_norm = original_attention.k_norm
        self.o_proj = original_attention.o_proj
        # self.softmax = original_attention.softmax
        self.attention_dropout = original_attention.attention_dropout
        self.training = original_attention.training
        self.layer_idx = original_attention.layer_idx
        self.num_heads = original_attention.num_heads
        self.num_key_value_heads = original_attention.num_key_value_heads
        self.num_key_value_groups = original_attention.num_key_value_groups
        self.head_dim = original_attention.head_dim
        self.hidden_size = original_attention.hidden_size
        self.rotary_emb = original_attention.rotary_emb


        self.attention_matrix_hook = AttentionMatrixHookModule()

        self.original_attention = original_attention


    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Cache] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: Optional[torch.LongTensor] = None,
        **kwargs,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()

        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

        query_states = query_states.reshape(-1, self.num_heads, self.head_dim)
        query_states = self.q_norm(query_states)

        key_states = key_states.reshape(-1, self.num_key_value_heads, self.head_dim)
        key_states = self.k_norm(key_states)

        query_states = query_states.reshape(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.reshape(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        cos, sin = self.rotary_emb(value_states, position_ids)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

        if past_key_value is not None:
            # sin and cos are specific to RoPE models; position_ids needed for the static cache
            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)

        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

        if attention_mask is not None:  # no matter the length, we just slice it
            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
            attn_weights = attn_weights + causal_mask

        # upcast attention to fp32
        # attn_weights = nn.functional.softmax(attn_weights, dim=-1).to(query_states.dtype)
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
        attn_weights = self.attention_matrix_hook(attn_weights)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value # type: ignore

apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1)

Applies Rotary Position Embedding to the query and key tensors.

Parameters:

- q (torch.Tensor): The query tensor. Required.
- k (torch.Tensor): The key tensor. Required.
- cos (torch.Tensor): The cosine part of the rotary embedding. Required.
- sin (torch.Tensor): The sine part of the rotary embedding. Required.
- position_ids (torch.Tensor, optional): Deprecated and unused. Default: None.
- unsqueeze_dim (int, optional, defaults to 1): The dimension along which to unsqueeze cos[position_ids] and sin[position_ids] so that they can be broadcast to the dimensions of q and k. For example, if cos[position_ids] and sin[position_ids] have shape [batch_size, seq_len, head_dim] and q and k have shape [batch_size, heads, seq_len, head_dim], setting unsqueeze_dim=1 makes them broadcastable to q and k; if q and k have shape [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.

Returns: tuple(torch.Tensor) comprising the query and key tensors rotated using the Rotary Position Embedding.

Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
def apply_rotary_pos_emb( q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
    """Applies Rotary Position Embedding to the query and key tensors.

    Args:
        q (`torch.Tensor`): The query tensor.
        k (`torch.Tensor`): The key tensor.
        cos (`torch.Tensor`): The cosine part of the rotary embedding.
        sin (`torch.Tensor`): The sine part of the rotary embedding.
        position_ids (`torch.Tensor`, *optional*):
            Deprecated and unused.
        unsqueeze_dim (`int`, *optional*, defaults to 1):
            The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
            sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
            that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
            k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
            cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
            the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
    Returns:
        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
    """

    cos = cos.unsqueeze(unsqueeze_dim)
    sin = sin.unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

repeat_kv(hidden_states, n_rep)

This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)

rotate_half(x)

Rotates half the hidden dims of the input.

Source code in easyroutine/interpretability/module_wrappers/chameleon_attention.py
def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
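
A small shape sketch for the RoPE helpers above; the cos/sin shapes follow the convention described in the apply_rotary_pos_emb docstring.

# Sketch: RoPE rotates query/key features in pairs; shapes are unchanged.
import torch
from easyroutine.interpretability.module_wrappers.chameleon_attention import (
    apply_rotary_pos_emb,
    rotate_half,
)

q = torch.randn(1, 8, 5, 64)  # (batch, heads, seq_len, head_dim)
k = torch.randn(1, 8, 5, 64)
cos = torch.randn(1, 5, 64)   # (batch, seq_len, head_dim)
sin = torch.randn(1, 5, 64)

q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin)  # unsqueeze_dim=1 broadcasts over heads
print(q_rot.shape, k_rot.shape)  # torch.Size([1, 8, 5, 64]) torch.Size([1, 8, 5, 64])
print(rotate_half(q).shape)      # rotate_half swaps and negates the two halves of head_dim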

T5AttentionWrapper

Bases: BaseAttentionWrapper

Source code in easyroutine/interpretability/module_wrappers/T5_attention.py
class T5AttentionWrapper(BaseAttentionWrapper):
    @staticmethod
    def original_name() -> str:
        return "T5Attention"

    def __init__(self, original_attention: nn.Module):
        super().__init__(original_attention)
        self.q = original_attention.q
        self.k = original_attention.k
        self.v = original_attention.v
        self.o = original_attention.o
        self.dropout = original_attention.dropout
        self.layer_idx = original_attention.layer_idx
        self.n_heads = original_attention.n_heads
        self.key_value_proj_dim = original_attention.key_value_proj_dim
        self.inner_dim = original_attention.inner_dim
        self.has_relative_attention_bias = (
            original_attention.has_relative_attention_bias
        )
        self.pruned_heads = original_attention.pruned_heads
        self.attention_matrix_hook = AttentionMatrixHookModule()
        self.original_attention = original_attention

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
        cache_position=None,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
        batch_size, seq_length = hidden_states.shape[:2]

        # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
        is_cross_attention = key_value_states is not None

        query_states = self.q(hidden_states)
        query_states = query_states.view(
            batch_size, -1, self.n_heads, self.key_value_proj_dim
        ).transpose(1, 2)

        if past_key_value is not None:
            is_updated = past_key_value.is_updated.get(self.layer_idx)
            if is_cross_attention:
                # after the first generated id, we can subsequently re-use all key/value_states from cache
                curr_past_key_value = past_key_value.cross_attention_cache
            else:
                curr_past_key_value = past_key_value.self_attention_cache

        current_states = key_value_states if is_cross_attention else hidden_states
        if is_cross_attention and past_key_value is not None and is_updated:  # type: ignore
            # reuse k,v, cross_attentions
            key_states = curr_past_key_value.key_cache[self.layer_idx]  # type: ignore
            value_states = curr_past_key_value.value_cache[self.layer_idx]  # type: ignore
        else:
            key_states = self.k(current_states)
            value_states = self.v(current_states)
            key_states = key_states.view(
                batch_size, -1, self.n_heads, self.key_value_proj_dim
            ).transpose(1, 2)
            value_states = value_states.view(
                batch_size, -1, self.n_heads, self.key_value_proj_dim
            ).transpose(1, 2)

            if past_key_value is not None:
                # save all key/value_states to cache to be re-used for fast auto-regressive generation
                cache_position = cache_position if not is_cross_attention else None
                key_states, value_states = curr_past_key_value.update(  # type: ignore
                    key_states,
                    value_states,
                    self.layer_idx,
                    {"cache_position": cache_position},
                )
                # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                if is_cross_attention:
                    past_key_value.is_updated[self.layer_idx] = True

        # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        if position_bias is None:
            key_length = key_states.shape[-2]
            # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
            real_seq_length = (
                query_length if query_length is not None else cache_position[-1] + 1 # type: ignore
            )  # type: ignore
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype,
                )
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(
                    real_seq_length,
                    key_length,
                    device=scores.device,
                    cache_position=cache_position,
                )
                position_bias = position_bias[:, :, -seq_length:, :]

            if mask is not None:
                causal_mask = mask[:, :, :, : key_states.shape[-2]]
                position_bias = position_bias + causal_mask

        if self.pruned_heads:
            mask = torch.ones(position_bias.shape[1])
            mask[list(self.pruned_heads)] = 0
            position_bias_masked = position_bias[:, mask.bool()]
        else:
            position_bias_masked = position_bias

        scores += position_bias_masked

        # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
        attn_weights = self.attention_matrix_hook(attn_weights)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )

        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask

        attn_output = torch.matmul(attn_weights, value_states)

        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.view(batch_size, -1, self.inner_dim)
        attn_output = self.o(attn_output)

        outputs = (attn_output, past_key_value, position_bias)

        if output_attentions:
            outputs = outputs + (attn_weights,)
        return outputs

forward(hidden_states, mask=None, key_value_states=None, position_bias=None, past_key_value=None, layer_head_mask=None, query_length=None, use_cache=False, output_attentions=False, cache_position=None)

Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

Source code in easyroutine/interpretability/module_wrappers/T5_attention.py
def forward(
    self,
    hidden_states,
    mask=None,
    key_value_states=None,
    position_bias=None,
    past_key_value=None,
    layer_head_mask=None,
    query_length=None,
    use_cache=False,
    output_attentions=False,
    cache_position=None,
):
    """
    Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
    """
    # Input is (batch_size, seq_length, dim)
    # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
    batch_size, seq_length = hidden_states.shape[:2]

    # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
    is_cross_attention = key_value_states is not None

    query_states = self.q(hidden_states)
    query_states = query_states.view(
        batch_size, -1, self.n_heads, self.key_value_proj_dim
    ).transpose(1, 2)

    if past_key_value is not None:
        is_updated = past_key_value.is_updated.get(self.layer_idx)
        if is_cross_attention:
            # after the first generated id, we can subsequently re-use all key/value_states from cache
            curr_past_key_value = past_key_value.cross_attention_cache
        else:
            curr_past_key_value = past_key_value.self_attention_cache

    current_states = key_value_states if is_cross_attention else hidden_states
    if is_cross_attention and past_key_value is not None and is_updated:  # type: ignore
        # reuse k,v, cross_attentions
        key_states = curr_past_key_value.key_cache[self.layer_idx]  # type: ignore
        value_states = curr_past_key_value.value_cache[self.layer_idx]  # type: ignore
    else:
        key_states = self.k(current_states)
        value_states = self.v(current_states)
        key_states = key_states.view(
            batch_size, -1, self.n_heads, self.key_value_proj_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, -1, self.n_heads, self.key_value_proj_dim
        ).transpose(1, 2)

        if past_key_value is not None:
            # save all key/value_states to cache to be re-used for fast auto-regressive generation
            cache_position = cache_position if not is_cross_attention else None
            key_states, value_states = curr_past_key_value.update(  # type: ignore
                key_states,
                value_states,
                self.layer_idx,
                {"cache_position": cache_position},
            )
            # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
            if is_cross_attention:
                past_key_value.is_updated[self.layer_idx] = True

    # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
    scores = torch.matmul(query_states, key_states.transpose(3, 2))

    if position_bias is None:
        key_length = key_states.shape[-2]
        # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
        real_seq_length = (
            query_length if query_length is not None else cache_position[-1] + 1 # type: ignore
        )  # type: ignore
        if not self.has_relative_attention_bias:
            position_bias = torch.zeros(
                (1, self.n_heads, seq_length, key_length),
                device=scores.device,
                dtype=scores.dtype,
            )
            if self.gradient_checkpointing and self.training:
                position_bias.requires_grad = True
        else:
            position_bias = self.compute_bias(
                real_seq_length,
                key_length,
                device=scores.device,
                cache_position=cache_position,
            )
            position_bias = position_bias[:, :, -seq_length:, :]

        if mask is not None:
            causal_mask = mask[:, :, :, : key_states.shape[-2]]
            position_bias = position_bias + causal_mask

    if self.pruned_heads:
        mask = torch.ones(position_bias.shape[1])
        mask[list(self.pruned_heads)] = 0
        position_bias_masked = position_bias[:, mask.bool()]
    else:
        position_bias_masked = position_bias

    scores += position_bias_masked

    # (batch_size, n_heads, seq_length, key_length)
    attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
    attn_weights = self.attention_matrix_hook(attn_weights)
    attn_weights = nn.functional.dropout(
        attn_weights, p=self.dropout, training=self.training
    )

    # Mask heads if we want to
    if layer_head_mask is not None:
        attn_weights = attn_weights * layer_head_mask

    attn_output = torch.matmul(attn_weights, value_states)

    attn_output = attn_output.transpose(1, 2).contiguous()
    attn_output = attn_output.view(batch_size, -1, self.inner_dim)
    attn_output = self.o(attn_output)

    outputs = (attn_output, past_key_value, position_bias)

    if output_attentions:
        outputs = outputs + (attn_weights,)
    return outputs
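
The same manager flow applies to encoder-decoder models whose attention class name appears in the factory table. A hedged sketch for a T5-style model (checkpoint and submodule path are illustrative and depend on your transformers version):

# Sketch: wrap a T5 model and capture an encoder self-attention matrix.
import torch
from transformers import T5ForConditionalGeneration
from easyroutine.interpretability.module_wrappers.manager import ModuleWrapperManager

model = T5ForConditionalGeneration.from_pretrained("t5-small")
manager = ModuleWrapperManager(model)       # finds "T5Attention" among the module names
manager.substitute_attention_module(model)

attentions = []
model.encoder.block[0].layer[0].SelfAttention.attention_matrix_hook.register_forward_hook(
    lambda module, inputs, output: attentions.append(output.detach())
)

with torch.no_grad():
    model(
        input_ids=torch.tensor([[13, 7, 1]]),
        decoder_input_ids=torch.tensor([[0]]),
    )

print(attentions[0].shape)  # (batch, n_heads, seq_len, key_len)
manager.restore_original_attention_module(model)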