Models

ModelConfig dataclass

Configuration class for storing model-specific parameters.

Attributes:

| Name | Type | Description |
| --- | --- | --- |
| residual_stream_input_hook_name | str | Name of the residual stream torch module where to attach the hook |
| residual_stream_hook_name | str | Name of the residual stream torch module where to attach the hook |
| intermediate_stream_hook_name | str | Name of the intermediate stream torch module where to attach the hook |
| residual_stream_input_post_layernorm_hook_name | str | Name of the residual stream input post-layernorm torch module where to attach the hook |
| attn_value_hook_name | str | Name of the attention value torch module where to attach the hook |
| attn_in_hook_name | str | Name of the attention input torch module where to attach the hook |
| attn_out_hook_name | str | Name of the attention output torch module where to attach the hook |
| attn_matrix_hook_name | str | Name of the attention matrix torch module where to attach the hook |
| attn_out_proj_weight | str | Name of the attention output projection weight |
| attn_out_proj_bias | str | Name of the attention output projection bias |
| embed_tokens | str | Name of the embedding tokens torch module where to attach the hook |
| num_hidden_layers | int | Number of hidden layers |
| num_attention_heads | int | Number of attention heads |
| hidden_size | int | Hidden size of the transformer model |
| num_key_value_heads | int | Number of key-value heads |
| num_key_value_groups | int | Number of key-value groups |
| head_dim | int | Dimension of each attention head |

Source code in easyroutine/interpretability/models.py
@dataclass
class ModelConfig:
    r"""
    Configuration class for storing model specific parameters.

    Attributes:
        residual_stream_input_hook_name (str): Name of the residual stream torch module where to attach the hook
        residual_stream_hook_name (str): Name of the residual stream torch module where to attach the hook
        intermediate_stream_hook_name (str): Name of the intermediate stream torch module where to attach the hook
        residual_stream_input_post_layernorm_hook_name (str): Name of the residual stream input post-layernorm torch module where to attach the hook
        attn_value_hook_name (str): Name of the attention value torch module where to attach the hook
        attn_in_hook_name (str): Name of the attention input torch module where to attach the hook
        attn_out_hook_name (str): Name of the attention output torch module where to attach the hook
        attn_matrix_hook_name (str): Name of the attention matrix torch module where to attach the hook
        attn_out_proj_weight (str): Name of the attention output projection weight
        attn_out_proj_bias (str): Name of the attention output projection bias
        embed_tokens (str): Name of the embedding tokens torch module where to attach the hook
        num_hidden_layers (int): Number of hidden layers
        num_attention_heads (int): Number of attention heads
        hidden_size (int): Hidden size of the transformer model
        num_key_value_heads (int): Number of key-value heads
        num_key_value_groups (int): Number of key-value groups
        head_dim (int): Dimension of each attention head

    """

    residual_stream_input_hook_name: str  # Name of the residual stream torch module where to attach the hook
    residual_stream_hook_name: str  # Name of the residual stream torch module where to attach the hook
    intermediate_stream_hook_name: str  # Name of the intermediate stream torch module where to attach the hook
    residual_stream_input_post_layernorm_hook_name: str  # Name of the residual stream input post-layernorm torch module where to attach the hook
    attn_value_hook_name: str  # Name of the attention value torch module where to attach the hook
    attn_in_hook_name: str  # Name of the attention input torch module where to attach the hook
    attn_out_hook_name: str  # Name of the attention output torch module where to attach the hook
    attn_matrix_hook_name: str  # Name of the attention matrix torch module where to attach the hook

    attn_out_proj_weight: str  # Name of the attention output projection weight
    attn_out_proj_bias: str  # Name of the attention output projection bias
    embed_tokens: str  # Name of the embedding tokens torch module where to attach the hook

    num_hidden_layers: int  # Number of hidden layers
    num_attention_heads: int  # Number of attention heads
    hidden_size: int  # Hidden size of the transformer model
    num_key_value_heads: int  # Number of key-value heads
    num_key_value_groups: int  # Number of key-value groups
    head_dim: int  # Dimension of each attention head
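The hook-name attributes are templates with a `{}` placeholder for the layer index, so a concrete hook path is obtained with `str.format`. Below is a minimal sketch, assuming the package is importable as `easyroutine`; the hook paths mirror those produced by `ModelFactory._create_model_config` for a LLaMA-style decoder, and the numeric sizes are illustrative rather than tied to a real checkpoint.

# Minimal sketch: building a ModelConfig by hand and formatting its per-layer hook templates.
from easyroutine.interpretability.models import ModelConfig

config = ModelConfig(
    residual_stream_input_hook_name="model.layers[{}].input",
    residual_stream_hook_name="model.layers[{}].output",
    intermediate_stream_hook_name="model.layers[{}].post_attention_layernorm.output",
    residual_stream_input_post_layernorm_hook_name="model.layers[{}].self_attn.input",
    attn_value_hook_name="model.layers[{}].self_attn.v_proj.output",
    attn_in_hook_name="model.layers[{}].self_attn.input",
    attn_out_hook_name="model.layers[{}].self_attn.o_proj.output",
    attn_matrix_hook_name="model.layers[{}].self_attn.attention_matrix_hook.output",
    attn_out_proj_weight="model.layers[{}].self_attn.o_proj.weight",
    attn_out_proj_bias="model.layers[{}].self_attn.o_proj.bias",
    embed_tokens="model.embed_tokens.input",
    num_hidden_layers=32,
    num_attention_heads=32,
    hidden_size=4096,
    num_key_value_heads=8,
    num_key_value_groups=4,   # num_attention_heads // num_key_value_heads
    head_dim=128,             # hidden_size // num_attention_heads
)

# Each template is filled with a layer index to get the concrete module path.
print(config.residual_stream_hook_name.format(3))  # -> model.layers[3].output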

ModelFactory

This class is a factory to load the model and the processor. It supports the following models:

Supported Models

The following models are supported by this factory:

  • Chameleon-7b: A 7-billion parameter model for general-purpose tasks.
  • Chameleon-30b: A larger version of the Chameleon series with 30 billion parameters.
  • llava-hf/llava-v1.6-mistral-7b-hf: A 7-billion parameter model for multimodal tasks.
  • Pixtral-12b: Optimized for image-to-text tasks.
  • Emu3-Chat: Fine-tuned for conversational AI.
  • Emu3-Gen: Specialized in text generation tasks.
  • Emu3-Stage1: Pretrained for multi-stage training pipelines.
  • hf-internal-testing: A tiny model for internal testing purposes.
Adding a New Model

To add a new model:

  1. Implement its logic in the load_model method.
  2. Ensure it is correctly initialized and validated.

A hypothetical sketch of such a branch is shown below.
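As a rough, hypothetical illustration of step 1 (not part of the library): a new branch mirrors the existing ones, loading the checkpoint with the appropriate Hugging Face class and building a ModelConfig through _create_model_config. The model name my-org/my-new-model and the AutoModelForCausalLM loader below are placeholders.

# Hypothetical sketch only: the shape of an extra branch for a new text-only model.
# "my-org/my-new-model" and AutoModelForCausalLM are placeholder choices, not library code.
import torch
from transformers import AutoModelForCausalLM
from easyroutine.interpretability.models import ModelFactory


def load_my_new_model(attn_implementation: str, torch_dtype: torch.dtype, device_map: str):
    model = AutoModelForCausalLM.from_pretrained(
        "my-org/my-new-model",
        torch_dtype=torch_dtype,
        device_map=device_map,
        attn_implementation=attn_implementation,
    )
    # Reuse the existing helper so hook names and sizes stay consistent with the library;
    # the prefix depends on where the decoder layers live in the module tree.
    model_config = ModelFactory._create_model_config(model.config, prefix="model.")
    return model, None, model_config  # (model, language_model, model_config)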

Source code in easyroutine/interpretability/models.py
class ModelFactory:
    r"""
    This class is a factory to load the model and the processor. It supports the following models:

    Supported Models:
        The following models are supported by this factory:

        - **Chameleon-7b**: A 7-billion parameter model for general-purpose tasks.
        - **Chameleon-30b**: A larger version of the Chameleon series with 30 billion parameters.
        - **llava-hf/llava-v1.6-mistral-7b-hf**: A 7-billion parameter model for multimodal tasks.
        - **Pixtral-12b**: Optimized for image-to-text tasks.
        - **Emu3-Chat**: Fine-tuned for conversational AI.
        - **Emu3-Gen**: Specialized in text generation tasks.
        - **Emu3-Stage1**: Pretrained for multi-stage training pipelines.
        - **hf-internal-testing**: A tiny model for internal testing purposes.

    Adding a New Model:
        To add a new model:
        1. Implement its logic in the `load_model` method.
        2. Ensure it is correctly initialized and validated.
    """

    @staticmethod
    def load_model(
        model_name: str,
        attn_implementation: str,
        torch_dtype: torch.dtype,
        device_map: str,
    ):
        r"""
        Load the model and its configuration based on the model name.

        Args:
            model_name (str): Name of the model to load.
            attn_implementation (str): Attention implementation type ("eager", "sdpa", "flash_attention_2").
            torch_dtype (torch.dtype): Data type of the model.
            device_map (str): Device map for the model.

        Returns:
            model (HuggingFaceModel): Model instance.
            language_model (Optional[PreTrainedModel]): Underlying language model for multimodal models, otherwise None.
            model_config (ModelConfig): Model configuration.
        """
        if attn_implementation != "eager":
            LambdaLogger.log(
                "Using an attention type different from eager or custom eager could have unexpected behavior in some experiments!",
                "WARNING",
            )

        language_model = None
        if model_name in ["facebook/chameleon-7b", "facebook/chameleon-30b"]:
            model = ChameleonForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
                attn_implementation=attn_implementation,
            )
            model_config = ModelFactory._create_model_config(model.config, prefix="model.")

        elif model_name in [
            "mistral-community/pixtral-12b",
            "llava-hf/llava-v1.6-mistral-7b-hf",
        ]:
            if model_name == "mistral-community/pixtral-12b":
                model = LlavaForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    device_map=device_map,
                    attn_implementation=attn_implementation,
                )
            elif model_name == "llava-hf/llava-v1.6-mistral-7b-hf":
                model = LlavaNextForConditionalGeneration.from_pretrained(
                    model_name,
                    torch_dtype=torch_dtype,
                    device_map=device_map,
                    attn_implementation=attn_implementation,
                )
            else:
                raise ValueError("Unsupported model_name")
            language_model = model.language_model
            model_config = ModelFactory._create_model_config(model.language_model.config, prefix="language_model.model.")

        elif model_name in ["Emu3-Chat", "Emu3-Gen", "Emu3-Stage1"]:
            raise NotImplementedError("Emu3 model not implemented yet")

        elif model_name in ["hf-internal-testing/tiny-random-LlamaForCausalLM"]:
            model = LlamaForCausalLM.from_pretrained(
                model_name, torch_dtype=torch_dtype, device_map=device_map, attn_implementation=attn_implementation
            )
            model_config = ModelFactory._create_model_config(model.config)

        elif model_name in ["CohereForAI/aya-101"]:
            model = T5ForConditionalGeneration.from_pretrained(
                model_name, torch_dtype=torch_dtype, device_map=device_map, attn_implementation=attn_implementation
            )
            language_model = None
            model_config = ModelFactory._create_model_config(model.config, prefix="encoder.")

        else:
            raise ValueError("Unsupported model_name")
        return model, language_model, model_config

    @staticmethod
    def _create_model_config(model_config, prefix="model.",):
        return ModelConfig(
            residual_stream_input_hook_name=f"{prefix}layers[{{}}].input",
            residual_stream_hook_name=f"{prefix}layers[{{}}].output",
            intermediate_stream_hook_name=f"{prefix}layers[{{}}].post_attention_layernorm.output",
            residual_stream_input_post_layernorm_hook_name=f"{prefix}layers[{{}}].self_attn.input",
            attn_value_hook_name=f"{prefix}layers[{{}}].self_attn.v_proj.output",
            attn_out_hook_name=f"{prefix}layers[{{}}].self_attn.o_proj.output",
            attn_in_hook_name=f"{prefix}layers[{{}}].self_attn.input",
            attn_matrix_hook_name=f"{prefix}layers[{{}}].self_attn.attention_matrix_hook.output",
            attn_out_proj_weight=f"{prefix}layers[{{}}].self_attn.o_proj.weight",
            attn_out_proj_bias=f"{prefix}layers[{{}}].self_attn.o_proj.bias",
            embed_tokens=f"{prefix}embed_tokens.input",
            num_hidden_layers=model_config.num_hidden_layers,
            num_attention_heads=model_config.num_attention_heads,
            hidden_size=model_config.hidden_size,
            num_key_value_heads=model_config.num_key_value_heads,
            num_key_value_groups=model_config.num_attention_heads // model_config.num_key_value_heads,
            head_dim=model_config.hidden_size // model_config.num_attention_heads,
        )
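The private helper _create_model_config only reads four fields from a Hugging Face config object (num_hidden_layers, num_attention_heads, hidden_size, num_key_value_heads) and prefixes every hook template with the decoder's attribute path. A small sketch of what it derives, using a stand-in config object rather than a real checkpoint:

# Sketch: what ModelFactory._create_model_config derives from a config-like object.
# SimpleNamespace stands in for a real transformers config here; the values are illustrative.
from types import SimpleNamespace
from easyroutine.interpretability.models import ModelFactory

fake_config = SimpleNamespace(
    num_hidden_layers=2,
    num_attention_heads=8,
    hidden_size=64,
    num_key_value_heads=4,
)

cfg = ModelFactory._create_model_config(fake_config, prefix="language_model.model.")
print(cfg.attn_value_hook_name)   # language_model.model.layers[{}].self_attn.v_proj.output
print(cfg.num_key_value_groups)   # 2  (num_attention_heads // num_key_value_heads)
print(cfg.head_dim)               # 8  (hidden_size // num_attention_heads)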

load_model(model_name, attn_implementation, torch_dtype, device_map) staticmethod

Load the model and its configuration based on the model name.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_name | str | Name of the model to load. | required |
| attn_implementation | str | Attention implementation type ("eager", "sdpa", "flash_attention_2"). | required |
| torch_dtype | dtype | Data type of the model. | required |
| device_map | str | Device map for the model. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| model | HuggingFaceModel | Model instance. |
| language_model | Optional[PreTrainedModel] | Underlying language model for multimodal models, otherwise None. |
| model_config | ModelConfig | Model configuration. |
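A minimal usage sketch; the tiny testing checkpoint is used here only because it is the lightest supported entry, and device_map="cpu" assumes no GPU is needed:

# Minimal usage sketch: load a supported checkpoint and inspect its ModelConfig.
import torch
from easyroutine.interpretability.models import ModelFactory

model, language_model, model_config = ModelFactory.load_model(
    model_name="hf-internal-testing/tiny-random-LlamaForCausalLM",
    attn_implementation="eager",   # anything else triggers the warning logged by load_model
    torch_dtype=torch.float32,
    device_map="cpu",
)

print(model_config.num_hidden_layers)
print(model_config.residual_stream_hook_name.format(0))  # e.g. model.layers[0].output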

Source code in easyroutine/interpretability/models.py
@staticmethod
def load_model(
    model_name: str,
    attn_implementation: str,
    torch_dtype: torch.dtype,
    device_map: str,
):
    r"""
    Load the model and its configuration based on the model name.

    Args:
        model_name (str): Name of the model to load.
        attn_implementation (str): Attention implementation type ("eager", "sdpa", "flash_attention_2").
        torch_dtype (torch.dtype): Data type of the model.
        device_map (str): Device map for the model.

    Returns:
        model (HuggingFaceModel): Model instance.
        language_model (Optional[PreTrainedModel]): Underlying language model for multimodal models, otherwise None.
        model_config (ModelConfig): Model configuration.
    """
    if attn_implementation != "eager":
        LambdaLogger.log(
            "Using an attention type different from eager or custom eager could have unexpected behavior in some experiments!",
            "WARNING",
        )

    language_model = None
    if model_name in ["facebook/chameleon-7b", "facebook/chameleon-30b"]:
        model = ChameleonForConditionalGeneration.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
            attn_implementation=attn_implementation,
        )
        model_config = ModelFactory._create_model_config(model.config, prefix="model.")

    elif model_name in [
        "mistral-community/pixtral-12b",
        "llava-hf/llava-v1.6-mistral-7b-hf",
    ]:
        if model_name == "mistral-community/pixtral-12b":
            model = LlavaForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
                attn_implementation=attn_implementation,
            )
        elif model_name == "llava-hf/llava-v1.6-mistral-7b-hf":
            model = LlavaNextForConditionalGeneration.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
                attn_implementation=attn_implementation,
            )
        else:
            raise ValueError("Unsupported model_name")
        language_model = model.language_model
        model_config = ModelFactory._create_model_config(model.language_model.config, prefix="language_model.model.")

    elif model_name in ["Emu3-Chat", "Emu3-Gen", "Emu3-Stage1"]:
        raise NotImplementedError("Emu3 model not implemented yet")

    elif model_name in ["hf-internal-testing/tiny-random-LlamaForCausalLM"]:
        model = LlamaForCausalLM.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device_map, attn_implementation=attn_implementation
        )
        model_config = ModelFactory._create_model_config(model.config)

    elif model_name in ["CohereForAI/aya-101"]:
        model = T5ForConditionalGeneration.from_pretrained(
            model_name, torch_dtype=torch_dtype, device_map=device_map, attn_implementation=attn_implementation
        )
        language_model = None
        model_config = ModelFactory._create_model_config(model.config, prefix="encoder.")

    else:
        raise ValueError("Unsupported model_name")
    return model, language_model, model_config

TokenizerFactory

This class returns the appropriate tokenizer or processor for the model. If the model is multimodal, a processor is returned and is_a_processor == True.

Source code in easyroutine/interpretability/models.py
class TokenizerFactory:
    r"""
    This class returns the appropriate tokenizer or processor for the model. If the model is multimodal, a processor is returned and is_a_processor == True.
    """
    @staticmethod
    def load_tokenizer(model_name: str, torch_dtype: torch.dtype, device_map: str):
        r"""
        Load the tokenizer based on the model name.

        Args:
            model_name (str): Name of the model to load.
            torch_dtype (torch.dtype): Data type of the model.
            device_map (str): Device map for the model.

        Returns:
            processor (Tokenizer): Processor instance.
            is_a_processor (bool): True if the model is multimodal, False otherwise.
        """
        if model_name in ["facebook/chameleon-7b", "facebook/chameleon-30b"]:
            processor = ChameleonProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
            )
            is_a_processor = True
        elif model_name in ["meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-3B"]:
            processor = LlamaTokenizerFast.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
            )
            is_a_processor = False
        elif model_name in ["mistral-community/pixtral-12b"]:
            processor = PixtralProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
            )
            is_a_processor = True
        elif model_name in ["llava-hf/llava-v1.6-mistral-7b-hf"]:
            processor = LlavaNextProcessor.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
            )
            is_a_processor = True
        elif model_name in ["Emu3-Chat", "Emu3-Gen", "Emu3-Stage1"]:
            raise NotImplementedError("Emu3 model not implemented yet")
        elif model_name in ["hf-internal-testing/tiny-random-LlamaForCausalLM"]:
            processor = LlamaTokenizer.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
            )
            is_a_processor = False
        elif model_name in ["CohereForAI/aya-101"]:
            processor = T5TokenizerFast.from_pretrained(
                model_name,
                torch_dtype=torch_dtype,
                device_map=device_map,
            )
            is_a_processor = False

        else:
            raise ValueError("Unsupported model_name")

        return processor, is_a_processor

load_tokenizer(model_name, torch_dtype, device_map) staticmethod

Load the tokenizer based on the model name.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| model_name | str | Name of the model to load. | required |
| torch_dtype | dtype | Data type of the model. | required |
| device_map | str | Device map for the model. | required |

Returns:

| Name | Type | Description |
| --- | --- | --- |
| processor | Tokenizer | Processor instance. |
| is_a_processor | bool | True if the model is multimodal, False otherwise. |
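A minimal usage sketch of the returned pair; branching on is_a_processor is how calling code can tell whether image inputs are expected:

# Minimal usage sketch: fetch the tokenizer/processor and branch on modality.
import torch
from easyroutine.interpretability.models import TokenizerFactory

processor, is_a_processor = TokenizerFactory.load_tokenizer(
    model_name="hf-internal-testing/tiny-random-LlamaForCausalLM",
    torch_dtype=torch.float32,
    device_map="cpu",
)

if is_a_processor:
    # Multimodal model: the processor also handles images alongside text.
    pass
else:
    inputs = processor("Hello world", return_tensors="pt")
    print(inputs["input_ids"].shape)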

Source code in easyroutine/interpretability/models.py
@staticmethod
def load_tokenizer(model_name: str, torch_dtype: torch.dtype, device_map: str):
    r"""
    Load the tokenizer based on the model name.

    Args:
        model_name (str): Name of the model to load.
        torch_dtype (torch.dtype): Data type of the model.
        device_map (str): Device map for the model.

    Returns:
        processor (Tokenizer): Processor instance.
        is_a_processor (bool): True if the model is multimodal, False otherwise.
    """
    if model_name in ["facebook/chameleon-7b", "facebook/chameleon-30b"]:
        processor = ChameleonProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        is_a_processor = True
    elif model_name in ["meta-llama/Llama-3.2-1B", "meta-llama/Llama-3.2-3B"]:
        processor = LlamaTokenizerFast.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        is_a_processor = False
    elif model_name in ["mistral-community/pixtral-12b"]:
        processor = PixtralProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        is_a_processor = True
    elif model_name in ["llava-hf/llava-v1.6-mistral-7b-hf"]:
        processor = LlavaNextProcessor.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        is_a_processor = True
    elif model_name in ["Emu3-Chat", "Emu3-Gen", "Emu3-Stage1"]:
        raise NotImplementedError("Emu3 model not implemented yet")
    elif model_name in ["hf-internal-testing/tiny-random-LlamaForCausalLM"]:
        processor = LlamaTokenizer.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        is_a_processor = False
    elif model_name in ["CohereForAI/aya-101"]:
        processor = T5TokenizerFast.from_pretrained(
            model_name,
            torch_dtype=torch_dtype,
            device_map=device_map,
        )
        is_a_processor = False

    else:
        raise ValueError("Unsupported model_name")

    return processor, is_a_processor