
LiteLLMModelInterface

LiteLLMInferenceModel

Bases: BaseInferenceModel

Source code in easyroutine/inference/litellm_model_interface.py
class LiteLLMInferenceModel(BaseInferenceModel):

    def __init__(self, config: LiteLLMInferenceModelConfig):
        self.config = config
        self.set_os_env()

    def set_os_env(self):
        """Export the configured API keys as environment variables for LiteLLM.

        Note: this overwrites any existing values, even when a key is left empty.
        """
        import os
        os.environ['OPENAI_API_KEY'] = self.config.openai_api_key
        os.environ['ANTHROPIC_API_KEY'] = self.config.anthropic_api_key
        os.environ['XAI_API_KEY'] = self.config.xai_api_key

    def convert_chat_messages_to_custom_format(self, chat_messages: List[dict[str, str]]) -> List[dict[str, str]]:
        """
        For now, LiteLLM accepts the chat template format we already use, so messages are passed through unchanged.
        """
        return chat_messages

    def chat(self, chat_messages: List[dict[str, str]], use_tqdm=False, **kwargs) -> list:
        """
        Generate a response based on the provided chat messages.

        Arguments:
            chat_messages (List[dict[str, str]]): List of chat messages to process.
            use_tqdm (bool): Accepted for interface compatibility; currently unused.
            **kwargs: Additional parameters for the model.

        Returns:
            list: The list of response choices from the model.
        """
        chat_messages = self.convert_chat_messages_to_custom_format(chat_messages)

        response = completion(
            model=self.config.model_name,
            messages=chat_messages,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            max_tokens=self.config.max_new_tokens,
        )
        return response['choices']

    def batch_chat(self, chat_messages: List[List[dict[str, str]]], use_tqdm=False, **kwargs) -> List[list]:
        """
        Generate responses for a batch of chat messages.

        Arguments:
            chat_messages (List[List[dict[str, str]]]): List of chat message lists to process.
            use_tqdm (bool): Accepted for interface compatibility; currently unused.
            **kwargs: Additional parameters for the model.

        Returns:
            List[list]: List of generated responses from the model.
        """
        chat_messages = [self.convert_chat_messages_to_custom_format(msg) for msg in chat_messages]

        responses = batch_completion(
            model=self.config.model_name,
            messages=chat_messages,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
            max_tokens=self.config.max_new_tokens,
        )
        return responses
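
A minimal usage sketch, assuming the class and config are importable from easyroutine.inference.litellm_model_interface; the model name and key below are illustrative placeholders, not values from the library:

from easyroutine.inference.litellm_model_interface import (
    LiteLLMInferenceModel,
    LiteLLMInferenceModelConfig,
)

config = LiteLLMInferenceModelConfig(
    model_name="gpt-4o-mini",   # any identifier LiteLLM can route
    openai_api_key="sk-...",    # placeholder key
)
model = LiteLLMInferenceModel(config)
choices = model.chat([{"role": "user", "content": "Hello!"}])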

batch_chat(chat_messages, use_tqdm=False, **kwargs)

Generate responses for a batch of chat messages.

Parameters:

    Name            Type                          Description                                       Default
    chat_messages   List[List[dict[str, str]]]    List of chat message lists to process.            required
    use_tqdm        bool                          Accepted for interface compatibility; unused.     False
    **kwargs                                      Additional parameters for the model.              {}

Returns:

    Type          Description
    List[list]    List of generated responses from the model.

Source code in easyroutine/inference/litellm_model_interface.py
def batch_chat(self, chat_messages: List[List[dict[str, str]]], use_tqdm=False, **kwargs) -> List[list]:
    """
    Generate responses for a batch of chat messages.

    Arguments:
        chat_messages (List[List[dict[str, str]]]): List of chat message lists to process.
        use_tqdm (bool): Accepted for interface compatibility; currently unused.
        **kwargs: Additional parameters for the model.

    Returns:
        List[list]: List of generated responses from the model.
    """
    chat_messages = [self.convert_chat_messages_to_custom_format(msg) for msg in chat_messages]

    responses = batch_completion(
        model=self.config.model_name,
        messages=chat_messages,
        temperature=self.config.temperature,
        top_p=self.config.top_p,
        max_tokens=self.config.max_new_tokens,
    )
    return responses
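
A batch sketch under the same assumptions as above; batch_completion returns one response per conversation, so the result can be iterated directly (conversation contents are illustrative):

batches = [
    [{"role": "user", "content": "Summarize Hamlet in one sentence."}],
    [{"role": "user", "content": "What is 2 + 2?"}],
]
for response in model.batch_chat(batches):
    print(response.choices[0].message.content)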

chat(chat_messages, use_tqdm=False, **kwargs)

Generate a response based on the provided chat messages.

Parameters:

    Name            Type                    Description                                       Default
    chat_messages   List[dict[str, str]]    List of chat messages to process.                 required
    use_tqdm        bool                    Accepted for interface compatibility; unused.     False
    **kwargs                                Additional parameters for the model.              {}

Returns:

    Type    Description
    list    The list of response choices from the model.

Source code in easyroutine/inference/litellm_model_interface.py
def chat(self, chat_messages: List[dict[str, str]], use_tqdm=False, **kwargs) -> list:
    """
    Generate a response based on the provided chat messages.

    Arguments:
        chat_messages (List[dict[str, str]]): List of chat messages to process.
        use_tqdm (bool): Accepted for interface compatibility; currently unused.
        **kwargs: Additional parameters for the model.

    Returns:
        list: The list of response choices from the model.
    """
    chat_messages = self.convert_chat_messages_to_custom_format(chat_messages)

    response = completion(
        model=self.config.model_name,
        messages=chat_messages,
        temperature=self.config.temperature,
        top_p=self.config.top_p,
        max_tokens=self.config.max_new_tokens,
    )
    return response['choices']
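
Because chat returns the raw choices list, the text of the first completion can be read as follows (assuming a standard OpenAI-style response object, which LiteLLM mirrors):

choices = model.chat([{"role": "user", "content": "Name three prime numbers."}])
first_reply = choices[0].message.content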

convert_chat_messages_to_custom_format(chat_messages)

For now, LiteLLM accepts the chat template format we already use, so messages are passed through unchanged.

Source code in easyroutine/inference/litellm_model_interface.py
def convert_chat_messages_to_custom_format(self, chat_messages: List[dict[str, str]]) -> List[dict[str, str]]:
    """
    For now, LiteLLM accepts the chat template format we already use, so messages are passed through unchanged.
    """
    return chat_messages
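
The expected input is the standard role/content message format; the method is a pass-through, as this sketch illustrates:

messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]
assert model.convert_chat_messages_to_custom_format(messages) == messages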

LiteLLMInferenceModelConfig dataclass

Bases: BaseInferenceModelConfig

Just a placeholder for now, as we don't have any LiteLLM-specific config.

Source code in easyroutine/inference/litellm_model_interface.py
@dataclass
class LiteLLMInferenceModelConfig(BaseInferenceModelConfig):
    """Just a placeholder for now, as we don't have any LiteLLM-specific config."""
    model_name: str

    n_gpus: int = 0          # not referenced by this interface
    dtype: str = 'bfloat16'  # not referenced by this interface
    temperature: float = 0
    top_p: float = 0.95
    max_new_tokens: int = 5000

    openai_api_key: str = ''
    anthropic_api_key: str = ''
    xai_api_key: str = ''
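
A construction sketch; the model name and key are placeholders. Keys left at their empty defaults are still exported by set_os_env, so set every provider you intend to call:

config = LiteLLMInferenceModelConfig(
    model_name="claude-3-5-sonnet-20240620",  # illustrative Anthropic model id
    temperature=0.2,
    max_new_tokens=1024,
    anthropic_api_key="...",  # placeholder
)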