vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope

DeepseekScalingRotaryEmbedding

Bases: RotaryEmbedding

RotaryEmbedding extended with YaRN method.

Credits to Peng et al. github.com/jquesnelle/yarn

Source code in vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        scaling_factor: float,
        dtype: torch.dtype,
        *,
        extrapolation_factor: float = 1,
        attn_factor: float = 1,
        beta_fast: int = 32,
        beta_slow: int = 1,
        mscale: float = 1,
        mscale_all_dim: float = 0,
    ) -> None:
        self.scaling_factor = scaling_factor
        self.extrapolation_factor = extrapolation_factor
        self.attn_factor = attn_factor
        self.beta_fast = beta_fast
        self.beta_slow = beta_slow
        # Get n-d magnitude scaling corrected for interpolation.
        self.mscale = float(
            yarn_get_mscale(self.scaling_factor, float(mscale))
            / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
            * attn_factor
        )
        super().__init__(
            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
        )

    def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
        pos_freqs = self.base ** (
            torch.arange(
                0,
                self.rotary_dim,
                2,
                dtype=torch.float,
                device=current_platform.device_type,
            )
            / self.rotary_dim
        )
        inv_freq_extrapolation = 1.0 / pos_freqs
        inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

        low, high = yarn_find_correction_range(
            self.beta_fast,
            self.beta_slow,
            self.rotary_dim,
            self.base,
            self.max_position_embeddings,
        )
        # Get n-d rotational scaling corrected for extrapolation
        inv_freq_mask = (
            1
            - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
        ) * self.extrapolation_factor
        inv_freq = (
            inv_freq_interpolation * (1 - inv_freq_mask)
            + inv_freq_extrapolation * inv_freq_mask
        )
        return inv_freq

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(
            self.max_position_embeddings * self.scaling_factor,
            device=current_platform.device_type,
            dtype=torch.float32,
        )
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """PyTorch-native implementation equivalent to forward()."""
        assert key is not None
        self._match_cos_sin_cache_dtype(query)
        query_rot = query[..., : self.rotary_dim]
        key_rot = key[..., : self.rotary_dim]
        if self.rotary_dim < self.head_size:
            query_pass = query[..., self.rotary_dim :]
            key_pass = key[..., self.rotary_dim :]

        cos_sin = self.cos_sin_cache[
            torch.add(positions, offsets) if offsets is not None else positions
        ]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            # NOTE(woosuk): Here we assume that the positions tensor has the
            # shape [batch_size, seq_len].
            cos = cos.repeat(1, 1, 2).unsqueeze(-2)
            sin = sin.repeat(1, 1, 2).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

        rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
        query_rot = query_rot * cos + rotate_fn(query_rot) * sin
        key_rot = key_rot * cos + rotate_fn(key_rot) * sin

        if self.rotary_dim < self.head_size:
            query = torch.cat((query_rot, query_pass), dim=-1)
            key = torch.cat((key_rot, key_pass), dim=-1)
        else:
            query = query_rot
            key = key_rot
        return query, key

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        return self.forward_native(positions, query, key, offsets)

attn_factor instance-attribute

attn_factor = attn_factor

beta_fast instance-attribute

beta_fast = beta_fast

beta_slow instance-attribute

beta_slow = beta_slow

extrapolation_factor instance-attribute

extrapolation_factor = extrapolation_factor

mscale instance-attribute

mscale = float(
    yarn_get_mscale(scaling_factor, float(mscale))
    / yarn_get_mscale(scaling_factor, float(mscale_all_dim))
    * attn_factor
)

scaling_factor instance-attribute

scaling_factor = scaling_factor

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: dtype,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
    mscale: float = 1,
    mscale_all_dim: float = 0,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    scaling_factor: float,
    dtype: torch.dtype,
    *,
    extrapolation_factor: float = 1,
    attn_factor: float = 1,
    beta_fast: int = 32,
    beta_slow: int = 1,
    mscale: float = 1,
    mscale_all_dim: float = 0,
) -> None:
    self.scaling_factor = scaling_factor
    self.extrapolation_factor = extrapolation_factor
    self.attn_factor = attn_factor
    self.beta_fast = beta_fast
    self.beta_slow = beta_slow
    # Get n-d magnitude scaling corrected for interpolation.
    self.mscale = float(
        yarn_get_mscale(self.scaling_factor, float(mscale))
        / yarn_get_mscale(self.scaling_factor, float(mscale_all_dim))
        * attn_factor
    )
    super().__init__(
        head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
    )
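A minimal construction sketch; the sizes and the 40x scaling factor below are illustrative assumptions rather than values taken from any particular DeepSeek checkpoint:

```python
import torch

from vllm.model_executor.layers.rotary_embedding.deepseek_scaling_rope import (
    DeepseekScalingRotaryEmbedding,
)

# Hypothetical sizes chosen for illustration only.
rope = DeepseekScalingRotaryEmbedding(
    head_size=128,
    rotary_dim=64,
    max_position_embeddings=4096,   # original training context length
    base=10000.0,
    is_neox_style=False,
    scaling_factor=40.0,            # stretch the usable context ~40x via YaRN
    dtype=torch.bfloat16,
)
```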

_compute_cos_sin_cache

_compute_cos_sin_cache() -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
def _compute_cos_sin_cache(self) -> torch.Tensor:
    inv_freq = self._compute_inv_freq(self.scaling_factor)
    t = torch.arange(
        self.max_position_embeddings * self.scaling_factor,
        device=current_platform.device_type,
        dtype=torch.float32,
    )
    freqs = torch.einsum("i,j -> ij", t, inv_freq)
    cos = freqs.cos() * self.mscale
    sin = freqs.sin() * self.mscale
    cache = torch.cat((cos, sin), dim=-1)
    return cache
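The resulting cache has one row per scaled position and rotary_dim columns holding the cos half followed by the sin half, both pre-multiplied by mscale. A quick shape check, assuming the `rope` instance from the construction sketch above:

```python
# One row per position up to max_position_embeddings * scaling_factor;
# rotary_dim columns laid out as [cos | sin], each half scaled by mscale.
num_rows = int(rope.max_position_embeddings * rope.scaling_factor)
assert rope.cos_sin_cache.shape == (num_rows, rope.rotary_dim)
cos, sin = rope.cos_sin_cache[123].chunk(2, dim=-1)   # entries for position 123
```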

_compute_inv_freq

_compute_inv_freq(scaling_factor: float) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
    pos_freqs = self.base ** (
        torch.arange(
            0,
            self.rotary_dim,
            2,
            dtype=torch.float,
            device=current_platform.device_type,
        )
        / self.rotary_dim
    )
    inv_freq_extrapolation = 1.0 / pos_freqs
    inv_freq_interpolation = 1.0 / (scaling_factor * pos_freqs)

    low, high = yarn_find_correction_range(
        self.beta_fast,
        self.beta_slow,
        self.rotary_dim,
        self.base,
        self.max_position_embeddings,
    )
    # Get n-d rotational scaling corrected for extrapolation
    inv_freq_mask = (
        1
        - yarn_linear_ramp_mask(low, high, self.rotary_dim // 2, dtype=torch.float)
    ) * self.extrapolation_factor
    inv_freq = (
        inv_freq_interpolation * (1 - inv_freq_mask)
        + inv_freq_extrapolation * inv_freq_mask
    )
    return inv_freq
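The blend keeps the original ("extrapolated") RoPE frequencies for the fast-rotating dimensions, switches to position-interpolated frequencies for the slow-rotating ones, and ramps linearly in between. A self-contained sketch of the same blend, assuming `yarn_linear_ramp_mask` is the usual clamped linear ramp and with the correction range hard-coded for illustration:

```python
import torch

# Sketch of the YaRN blend for an 8-dim rotary space. The correction range
# (low, high) is hard-coded here purely for illustration; in the class it
# comes from yarn_find_correction_range(beta_fast, beta_slow, ...).
rotary_dim, base, scale = 8, 10000.0, 4.0
pos_freqs = base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float) / rotary_dim)
extrap = 1.0 / pos_freqs                   # original RoPE frequencies
interp = 1.0 / (scale * pos_freqs)         # position-interpolated frequencies
low, high = 1.0, 3.0                       # assumed correction range
ramp = ((torch.arange(rotary_dim // 2, dtype=torch.float) - low) / (high - low)).clamp(0, 1)
mask = 1.0 - ramp                          # 1 -> keep extrapolation, 0 -> interpolate
inv_freq = interp * (1 - mask) + extrap * mask
```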

forward_cuda

forward_cuda(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
def forward_cuda(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    return self.forward_native(positions, query, key, offsets)

forward_native

forward_native(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]

PyTorch-native implementation equivalent to forward().

Source code in vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
def forward_native(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """PyTorch-native implementation equivalent to forward()."""
    assert key is not None
    self._match_cos_sin_cache_dtype(query)
    query_rot = query[..., : self.rotary_dim]
    key_rot = key[..., : self.rotary_dim]
    if self.rotary_dim < self.head_size:
        query_pass = query[..., self.rotary_dim :]
        key_pass = key[..., self.rotary_dim :]

    cos_sin = self.cos_sin_cache[
        torch.add(positions, offsets) if offsets is not None else positions
    ]
    cos, sin = cos_sin.chunk(2, dim=-1)
    if self.is_neox_style:
        # NOTE(woosuk): Here we assume that the positions tensor has the
        # shape [batch_size, seq_len].
        cos = cos.repeat(1, 1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 1, 2).unsqueeze(-2)
    else:
        cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
        sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)

    rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
    query_rot = query_rot * cos + rotate_fn(query_rot) * sin
    key_rot = key_rot * cos + rotate_fn(key_rot) * sin

    if self.rotary_dim < self.head_size:
        query = torch.cat((query_rot, query_pass), dim=-1)
        key = torch.cat((key_rot, key_pass), dim=-1)
    else:
        query = query_rot
        key = key_rot
    return query, key
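A call sketch, assuming the `rope` instance from the construction sketch above; the batch, sequence, and head counts are arbitrary illustrative values, and positions is expected to have shape [batch_size, seq_len] as noted in the code:

```python
batch, seq_len, num_heads = 2, 16, 8
positions = torch.arange(seq_len).expand(batch, seq_len)            # [batch, seq]
query = torch.randn(batch, seq_len, num_heads, rope.head_size)
key = torch.randn(batch, seq_len, num_heads, rope.head_size)
query, key = rope.forward_native(positions, query, key)
```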

yarn_get_mscale

yarn_get_mscale(
    scale: float = 1, mscale: float = 1
) -> float
Source code in vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py
def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * mscale * math.log(scale) + 1.0
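Worked example (the 40x factor matches the illustrative scaling factor used above): for scale > 1 the correction grows logarithmically with the scale, and for scale <= 1 it is a no-op:

```python
yarn_get_mscale(40.0, 1.0)   # 0.1 * 1.0 * ln(40) + 1.0 ≈ 1.369
yarn_get_mscale(1.0, 1.0)    # 1.0 -- no correction when scale <= 1
```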