Skip to content

vllm.model_executor.layers.rotary_embedding.phi3_long_rope_scaled_rope

Phi3LongRoPEScaledRotaryEmbedding

Bases: Module

Long-RoPE scaled rotary embedding for the Phi-3 family of models.

Based on the original RotaryEmbedding implementation.

Source code in vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
class Phi3LongRoPEScaledRotaryEmbedding(nn.Module):
    """Phi3 family of models scaled rotary embedding.

    Based on the original RotaryEmbedding implementation.

    Two cos/sin caches are precomputed and stored back to back in the
    non-persistent buffer ``long_short_cos_sin_cache``:

    * rows ``[0, original_max_position_embeddings)`` hold the "short" cache,
      built from ``short_factor`` and scaled by ``short_mscale``;
    * the following ``max_position_embeddings`` rows hold the "long" cache,
      built from ``long_factor`` and scaled by ``long_mscale``.

    In ``forward``, if any position in the batch is at or beyond
    ``original_max_position_embeddings``, every position is looked up in the
    long cache; otherwise the short cache is used. Only neox-style rotation
    is supported.
    """

    def __init__(
        self,
        head_size: int,
        rotary_dim: int,
        max_position_embeddings: int,
        original_max_position_embeddings: int,
        base: float,
        is_neox_style: bool,
        dtype: torch.dtype,
        short_factor: list[float],
        long_factor: list[float],
        short_mscale: Optional[float] = None,
        long_mscale: Optional[float] = None,
    ):
        """Build the concatenated short/long cos/sin cache.

        Args:
            head_size: Size of each attention head in ``query``/``key``.
            rotary_dim: Number of leading channels per head that are rotated;
                the remaining ``head_size - rotary_dim`` pass through.
            max_position_embeddings: Number of rows in the long cache.
            original_max_position_embeddings: Number of rows in the short
                cache, and the position threshold for switching to the long
                cache.
            base: RoPE frequency base.
            is_neox_style: Must be truthy; only neox-style is supported.
            dtype: dtype the cos/sin caches are cast to.
            short_factor: Per-frequency rescale factors for the short cache.
            long_factor: Per-frequency rescale factors for the long cache.
            short_mscale: Magnitude scale for the short cache. Defaults to the
                computed ``scaling_factor`` when ``None``.
            long_mscale: Magnitude scale for the long cache. Defaults to the
                computed ``scaling_factor`` when ``None``.

        Raises:
            ValueError: If ``is_neox_style`` is falsy.
        """
        super().__init__()

        # ``not`` (rather than ``is False``) also rejects other falsy values
        # such as 0 or None.
        if not is_neox_style:
            raise ValueError(
                "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
            )

        self.rotary_dim = rotary_dim
        self.head_size = head_size
        self.max_position_embeddings = max_position_embeddings
        self.original_max_position_embeddings = original_max_position_embeddings
        self.base = base
        self.short_factor = short_factor
        self.long_factor = long_factor

        # Default magnitude scale:
        # sqrt(1 + ln(scale) / ln(original_max)) when the context is
        # extended (scale > 1), otherwise 1.0.
        scale = self.max_position_embeddings / self.original_max_position_embeddings
        if scale <= 1.0:
            scaling_factor = 1.0
        else:
            scaling_factor = math.sqrt(
                1 + math.log(scale) / math.log(self.original_max_position_embeddings)
            )
        if short_mscale is None:
            short_mscale = scaling_factor
        if long_mscale is None:
            long_mscale = scaling_factor

        self.short_mscale = short_mscale
        self.long_mscale = long_mscale

        short_cache = self._compute_cos_sin_cache(
            original_max_position_embeddings, short_factor, short_mscale
        )
        short_cache = short_cache.to(dtype)

        long_cache = self._compute_cos_sin_cache(
            max_position_embeddings, long_factor, long_mscale
        )
        long_cache = long_cache.to(dtype)

        # Short rows first, then long rows. ``_get_cache_indices`` relies on
        # the long cache starting at row ``original_max_position_embeddings``.
        long_short_cache = torch.cat([short_cache, long_cache], dim=0)
        self.register_buffer(
            "long_short_cos_sin_cache", long_short_cache, persistent=False
        )

    def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
        """Return ``1 / (rescale_factors[i] * base ** (2i / rotary_dim))``.

        ``i`` runs over ``0, 1, ...`` for each even channel index below
        ``rotary_dim``. The result is float32.
        """
        rescale_factors = torch.tensor(rescale_factors, dtype=torch.float32)
        inv_freq = 1.0 / (
            rescale_factors
            * (
                self.base
                ** (
                    torch.arange(0, self.rotary_dim, 2, dtype=torch.float)
                    / self.rotary_dim
                )
            )
        )
        return inv_freq

    def _compute_cos_sin_cache(
        self,
        max_position_embeddings: int,
        rescale_factors: list[float],
        mscale: float,
    ) -> torch.Tensor:
        """Build a ``(max_position_embeddings, 2 * len(inv_freq))`` cache.

        Each row ``p`` holds ``cos(p * inv_freq) * mscale`` followed by
        ``sin(p * inv_freq) * mscale`` along the last dimension.
        """
        inv_freq = self._compute_inv_freq(rescale_factors)
        t = torch.arange(max_position_embeddings, dtype=torch.float)
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * mscale
        sin = freqs.sin() * mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def _get_cache_indices(
        self,
        positions: torch.Tensor,
        offsets: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Map token positions to rows of ``long_short_cos_sin_cache``.

        If any position is ``>= original_max_position_embeddings``, the whole
        batch is shifted by that amount so it reads from the long cache.
        ``offsets``, when given, is added afterwards.
        """
        k = self.original_max_position_embeddings
        # ``>=`` rather than ``>``: the short cache only has rows 0..k-1, so a
        # position equal to k must already be served by the long cache. With
        # ``>`` it would read row k of the concatenated buffer, which is the
        # long cache's position-0 entry.
        # The switch is kept in tensor arithmetic rather than a Python ``if``
        # on a tensor value (presumably to avoid a device->host sync).
        long_prompt_offset = (
            torch.any(positions >= k).float() * torch.full_like(positions, k)
        ).long()
        idx = torch.add(positions, long_prompt_offset)
        if offsets is not None:
            idx = torch.add(idx, offsets)
        return idx

    def forward(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: Optional[torch.Tensor] = None,
        offsets: Optional[torch.Tensor] = None,
    ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
        """Apply neox-style rotary embedding to ``query`` and ``key``.

        Args:
            positions: Token positions. After offsetting they are passed to
                ``torch.index_select``, which requires a 1-D index.
            query: Query tensor whose last dim is split into heads of
                ``head_size``.
            key: Key tensor, same layout as ``query``. Must not be ``None``.
            offsets: Optional per-token offsets added to the cache index.

        Returns:
            ``(query, key)`` with the first ``rotary_dim`` channels of each
            head rotated and the head dim flattened back.
        """
        assert key is not None
        # Split the last dim into (-1, head_size) so rotation is per head.
        query = query.view(*query.shape[:-1], -1, self.head_size)
        key = key.view(*key.shape[:-1], -1, self.head_size)

        idx = self._get_cache_indices(positions, offsets)
        cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

        # Each cache row is [cos | sin]. Duplicate each half to cover all
        # rotary channels, then add a head axis for broadcasting.
        cos, sin = cos_sin.chunk(2, dim=-1)
        cos = cos.repeat(1, 2).unsqueeze(-2)
        sin = sin.repeat(1, 2).unsqueeze(-2)

        query_rot = query[..., : self.rotary_dim]
        query_pass = query[..., self.rotary_dim :]
        query_rot = query_rot * cos + rotate_neox(query_rot) * sin
        query = torch.cat((query_rot, query_pass), dim=-1)

        key_rot = key[..., : self.rotary_dim]
        key_pass = key[..., self.rotary_dim :]
        key_rot = key_rot * cos + rotate_neox(key_rot) * sin
        key = torch.cat((key_rot, key_pass), dim=-1)

        return query.flatten(-2), key.flatten(-2)

base instance-attribute

base = base

head_size instance-attribute

head_size = head_size

long_factor instance-attribute

long_factor = long_factor

long_mscale instance-attribute

long_mscale = long_mscale

max_position_embeddings instance-attribute

max_position_embeddings = max_position_embeddings

original_max_position_embeddings instance-attribute

original_max_position_embeddings = (
    original_max_position_embeddings
)

rotary_dim instance-attribute

rotary_dim = rotary_dim

short_factor instance-attribute

short_factor = short_factor

short_mscale instance-attribute

short_mscale = short_mscale

__init__

__init__(
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    original_max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: dtype,
    short_factor: list[float],
    long_factor: list[float],
    short_mscale: Optional[float] = None,
    long_mscale: Optional[float] = None,
)
Source code in vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
def __init__(
    self,
    head_size: int,
    rotary_dim: int,
    max_position_embeddings: int,
    original_max_position_embeddings: int,
    base: float,
    is_neox_style: bool,
    dtype: torch.dtype,
    short_factor: list[float],
    long_factor: list[float],
    short_mscale: Optional[float] = None,
    long_mscale: Optional[float] = None,
):
    """Build the concatenated short/long cos/sin cache.

    Args:
        head_size: Size of each attention head in ``query``/``key``.
        rotary_dim: Number of leading channels per head that are rotated.
        max_position_embeddings: Number of rows in the long cache.
        original_max_position_embeddings: Number of rows in the short cache.
        base: RoPE frequency base.
        is_neox_style: Must be truthy; only neox-style is supported.
        dtype: dtype the cos/sin caches are cast to.
        short_factor: Per-frequency rescale factors for the short cache.
        long_factor: Per-frequency rescale factors for the long cache.
        short_mscale: Magnitude scale for the short cache. Defaults to the
            computed ``scaling_factor`` when ``None``.
        long_mscale: Magnitude scale for the long cache. Defaults to the
            computed ``scaling_factor`` when ``None``.

    Raises:
        ValueError: If ``is_neox_style`` is falsy.
    """
    super().__init__()

    # ``not`` (rather than ``is False``) also rejects other falsy values such
    # as 0 or None.
    if not is_neox_style:
        raise ValueError(
            "`Phi3LongRoPEScaledRotaryEmbedding` only supports neox_style."
        )

    self.rotary_dim = rotary_dim
    self.head_size = head_size
    self.max_position_embeddings = max_position_embeddings
    self.original_max_position_embeddings = original_max_position_embeddings
    self.base = base
    self.short_factor = short_factor
    self.long_factor = long_factor

    # Default magnitude scale: sqrt(1 + ln(scale) / ln(original_max)) when
    # the context is extended (scale > 1), otherwise 1.0.
    scale = self.max_position_embeddings / self.original_max_position_embeddings
    if scale <= 1.0:
        scaling_factor = 1.0
    else:
        scaling_factor = math.sqrt(
            1 + math.log(scale) / math.log(self.original_max_position_embeddings)
        )
    if short_mscale is None:
        short_mscale = scaling_factor
    if long_mscale is None:
        long_mscale = scaling_factor

    self.short_mscale = short_mscale
    self.long_mscale = long_mscale

    short_cache = self._compute_cos_sin_cache(
        original_max_position_embeddings, short_factor, short_mscale
    )
    short_cache = short_cache.to(dtype)

    long_cache = self._compute_cos_sin_cache(
        max_position_embeddings, long_factor, long_mscale
    )
    long_cache = long_cache.to(dtype)

    # Short rows first, then long rows: the long cache starts at row
    # ``original_max_position_embeddings``.
    long_short_cache = torch.cat([short_cache, long_cache], dim=0)
    self.register_buffer(
        "long_short_cos_sin_cache", long_short_cache, persistent=False
    )

_compute_cos_sin_cache

_compute_cos_sin_cache(
    max_position_embeddings: int,
    rescale_factors: list[float],
    mscale: float,
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
def _compute_cos_sin_cache(
    self,
    max_position_embeddings: int,
    rescale_factors: list[float],
    mscale: float,
) -> torch.Tensor:
    """Return a cache whose row ``p`` is ``[cos(p*f) | sin(p*f)] * mscale``.

    ``f`` is the inverse-frequency vector from ``self._compute_inv_freq``.
    The result has ``max_position_embeddings`` rows, and its last dimension
    is twice the length of ``f``.
    """
    frequencies = self._compute_inv_freq(rescale_factors)
    steps = torch.arange(max_position_embeddings, dtype=torch.float)
    # Outer product: angle[p, j] = steps[p] * frequencies[j].
    angles = torch.einsum("i,j -> ij", steps, frequencies)
    return torch.cat((angles.cos() * mscale, angles.sin() * mscale), dim=-1)

_compute_inv_freq

_compute_inv_freq(rescale_factors: list[float]) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
def _compute_inv_freq(self, rescale_factors: list[float]) -> torch.Tensor:
    """Return the rescaled inverse frequencies as a float32 tensor.

    Element ``i`` is ``1 / (rescale_factors[i] * base ** (2i / rotary_dim))``,
    one entry per even channel index below ``rotary_dim``.
    """
    factors = torch.tensor(rescale_factors, dtype=torch.float32)
    exponents = (
        torch.arange(0, self.rotary_dim, 2, dtype=torch.float) / self.rotary_dim
    )
    base_powers = self.base ** exponents
    return 1.0 / (factors * base_powers)

forward

forward(
    positions: Tensor,
    query: Tensor,
    key: Optional[Tensor] = None,
    offsets: Optional[Tensor] = None,
) -> tuple[Tensor, Optional[Tensor]]
Source code in vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py
def forward(
    self,
    positions: torch.Tensor,
    query: torch.Tensor,
    key: Optional[torch.Tensor] = None,
    offsets: Optional[torch.Tensor] = None,
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
    """Apply neox-style rotary embedding to ``query`` and ``key``.

    If any position is ``>= original_max_position_embeddings``, every position
    in the batch is looked up in the long half of
    ``long_short_cos_sin_cache``; otherwise the short half is used.

    Args:
        positions: Token positions. After offsetting they are passed to
            ``torch.index_select``, which requires a 1-D index.
        query: Query tensor whose last dim is split into heads of
            ``head_size``.
        key: Key tensor, same layout as ``query``. Must not be ``None``.
        offsets: Optional per-token offsets added to the cache index.

    Returns:
        ``(query, key)`` with the first ``rotary_dim`` channels of each head
        rotated and the head dim flattened back.
    """
    assert key is not None
    # Split the last dim into (-1, head_size) so rotation is per head.
    query = query.view(*query.shape[:-1], -1, self.head_size)
    key = key.view(*key.shape[:-1], -1, self.head_size)

    k = self.original_max_position_embeddings
    # ``>=`` rather than ``>``: the short cache only has rows 0..k-1, so a
    # position equal to k must already be served by the long cache. With
    # ``>`` it would read row k of the concatenated buffer, which is the long
    # cache's position-0 entry.
    # The switch stays in tensor arithmetic rather than a Python ``if`` on a
    # tensor value (presumably to avoid a device->host sync).
    long_prompt_offset = (
        torch.any(positions >= k).float() * torch.full_like(positions, k)
    ).long()
    idx = torch.add(positions, long_prompt_offset)
    if offsets is not None:
        idx = torch.add(idx, offsets)
    cos_sin = torch.index_select(self.long_short_cos_sin_cache, 0, idx)

    # Each cache row is [cos | sin]. Duplicate each half to cover all rotary
    # channels, then add a head axis for broadcasting.
    cos, sin = cos_sin.chunk(2, dim=-1)
    cos = cos.repeat(1, 2).unsqueeze(-2)
    sin = sin.repeat(1, 2).unsqueeze(-2)

    query_rot = query[..., : self.rotary_dim]
    query_pass = query[..., self.rotary_dim :]
    query_rot = query_rot * cos + rotate_neox(query_rot) * sin
    query = torch.cat((query_rot, query_pass), dim=-1)

    key_rot = key[..., : self.rotary_dim]
    key_pass = key[..., self.rotary_dim :]
    key_rot = key_rot * cos + rotate_neox(key_rot) * sin
    key = torch.cat((key_rot, key_pass), dim=-1)

    return query.flatten(-2), key.flatten(-2)