
vllm.transformers_utils.configs.medusa

MedusaConfig

Bases: PretrainedConfig

Source code in vllm/transformers_utils/configs/medusa.py
import os
from typing import Optional, Union

from transformers import PretrainedConfig


class MedusaConfig(PretrainedConfig):
    model_type = "medusa"

    def __init__(
        self,
        hidden_size: int = 4096,
        vocab_size: int = 32001,
        num_heads: int = 5,
        num_hidden_layers: int = 1,
        max_paths: int = 64,
        topk: int = 10,
        truncated_vocab_size: Optional[int] = None,
        **kwargs,
    ):
        self.hidden_size = hidden_size
        self.vocab_size = vocab_size
        self.num_heads = num_heads
        self.num_hidden_layers = num_hidden_layers
        self.max_paths = max_paths
        self.topk = topk
        self.max_seq_len = int(2**20)
        # Fall back to the full vocabulary size when no truncated size is given.
        self.truncated_vocab_size = (
            vocab_size if truncated_vocab_size is None else truncated_vocab_size
        )
        if "architectures" not in kwargs:
            kwargs["architectures"] = ["MedusaModel"]

        super().__init__(**kwargs)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ) -> "MedusaConfig":
        config_dict, kwargs = cls.get_config_dict(
            pretrained_model_name_or_path, **kwargs
        )
        # Normalize checkpoint key names: any key containing "num" plus
        # "heads" or "layers" is mapped onto num_heads / num_hidden_layers.
        for k in list(config_dict.keys()):
            if "num" in k:
                if "heads" in k:
                    config_dict["num_heads"] = config_dict.pop(k)
                elif "layers" in k:
                    config_dict["num_hidden_layers"] = config_dict.pop(k)
        return cls.from_dict(config_dict, **kwargs)

    @property
    def num_attention_heads(self):
        return 0

    @property
    def num_lookahead_tokens(self):
        return self.num_heads

    @num_lookahead_tokens.setter
    def num_lookahead_tokens(self, num_lookahead_tokens: int):
        self.num_heads = num_lookahead_tokens
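
A minimal usage sketch (not part of the vLLM source): constructing the config directly and inspecting the values that __init__ fills in when they are not passed explicitly.

from vllm.transformers_utils.configs.medusa import MedusaConfig

config = MedusaConfig(num_heads=4, num_hidden_layers=1)
print(config.model_type)      # "medusa"
print(config.architectures)   # ["MedusaModel"], injected when not supplied
print(config.max_seq_len)     # 1048576, i.e. int(2**20)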

hidden_size instance-attribute

hidden_size = hidden_size

max_paths instance-attribute

max_paths = max_paths

max_seq_len instance-attribute

max_seq_len = int(2 ** 20)

model_type class-attribute instance-attribute

model_type = 'medusa'

num_attention_heads property

num_attention_heads

num_heads instance-attribute

num_heads = num_heads

num_hidden_layers instance-attribute

num_hidden_layers = num_hidden_layers

num_lookahead_tokens property writable

num_lookahead_tokens

topk instance-attribute

topk = topk

truncated_vocab_size instance-attribute

truncated_vocab_size = (
    vocab_size
    if truncated_vocab_size is None
    else truncated_vocab_size
)

vocab_size instance-attribute

vocab_size = vocab_size

__init__

__init__(
    hidden_size: int = 4096,
    vocab_size: int = 32001,
    num_heads: int = 5,
    num_hidden_layers: int = 1,
    max_paths: int = 64,
    topk: int = 10,
    truncated_vocab_size: Optional[int] = None,
    **kwargs,
)
Source code in vllm/transformers_utils/configs/medusa.py
def __init__(
    self,
    hidden_size: int = 4096,
    vocab_size: int = 32001,
    num_heads: int = 5,
    num_hidden_layers: int = 1,
    max_paths: int = 64,
    topk: int = 10,
    truncated_vocab_size: Optional[int] = None,
    **kwargs,
):
    self.hidden_size = hidden_size
    self.vocab_size = vocab_size
    self.num_heads = num_heads
    self.num_hidden_layers = num_hidden_layers
    self.max_paths = max_paths
    self.topk = topk
    self.max_seq_len = int(2**20)
    self.truncated_vocab_size = (
        vocab_size if truncated_vocab_size is None else truncated_vocab_size
    )
    if "architectures" not in kwargs:
        kwargs["architectures"] = ["MedusaModel"]

    super().__init__(**kwargs)
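
A hedged illustration of the two fallbacks above (not part of the source), assuming MedusaConfig is imported as in the sketch earlier on this page: truncated_vocab_size defaults to vocab_size when omitted, and num_lookahead_tokens is a writable alias backed by num_heads.

cfg = MedusaConfig(vocab_size=32001)
assert cfg.truncated_vocab_size == 32001   # falls back to vocab_size

cfg = MedusaConfig(vocab_size=32001, truncated_vocab_size=30000)
assert cfg.truncated_vocab_size == 30000   # explicit value is kept

cfg.num_lookahead_tokens = 3               # writable property ...
assert cfg.num_heads == 3                  # ... backed by num_heads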

from_pretrained classmethod

from_pretrained(
    pretrained_model_name_or_path: Union[str, PathLike],
    **kwargs,
) -> MedusaConfig
Source code in vllm/transformers_utils/configs/medusa.py
@classmethod
def from_pretrained(
    cls,
    pretrained_model_name_or_path: Union[str, os.PathLike],
    **kwargs,
) -> "MedusaConfig":
    config_dict, kwargs = cls.get_config_dict(
        pretrained_model_name_or_path, **kwargs
    )
    for k in list(config_dict.keys()):
        if "num" in k:
            if "heads" in k:
                config_dict["num_heads"] = config_dict.pop(k)
            elif "layers" in k:
                config_dict["num_hidden_layers"] = config_dict.pop(k)
    return cls.from_dict(config_dict, **kwargs)
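
A minimal sketch of the key normalization that from_pretrained performs, using a hand-written dict in place of a downloaded config.json (so no checkpoint is needed). The "medusa_num_heads" / "medusa_num_layers" key names are an assumed example of an upstream Medusa checkpoint config, not something taken from this page.

config_dict = {
    "medusa_num_heads": 5,    # contains "num" and "heads" -> num_heads
    "medusa_num_layers": 1,   # contains "num" and "layers" -> num_hidden_layers
    "hidden_size": 4096,
    "vocab_size": 32000,
}

# Same remapping loop as in from_pretrained above.
for k in list(config_dict.keys()):
    if "num" in k:
        if "heads" in k:
            config_dict["num_heads"] = config_dict.pop(k)
        elif "layers" in k:
            config_dict["num_hidden_layers"] = config_dict.pop(k)

config = MedusaConfig.from_dict(config_dict)
print(config.num_heads, config.num_hidden_layers)  # 5 1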