
vllm.model_executor.layers.rotary_embedding.common

logger module-attribute

logger = init_logger(__name__)

_flashinfer_rotary_embedding

_flashinfer_rotary_embedding(
    positions: Tensor,
    query: Tensor,
    key: Tensor,
    head_size: int,
    cos_sin_cache: Tensor,
    is_neox: bool,
) -> None

Custom op wrapper for flashinfer's rotary embedding.

This is an in-place operation that modifies query and key tensors directly.

Source code in vllm/model_executor/layers/rotary_embedding/common.py
def _flashinfer_rotary_embedding(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    """Custom op wrapper for flashinfer's rotary embedding.

    This is an in-place operation that modifies query and key tensors directly.
    """
    from flashinfer.rope import apply_rope_with_cos_sin_cache_inplace

    apply_rope_with_cos_sin_cache_inplace(
        positions=positions,
        query=query,
        key=key,
        head_size=head_size,
        cos_sin_cache=cos_sin_cache,
        is_neox=is_neox,
    )
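
Example (illustrative sketch, not part of the module): calling the wrapper directly with toy shapes. This assumes flashinfer is installed, a CUDA device is available, and a cos/sin cache laid out as [cos | sin] along the last dimension (the convention used by vLLM's rotary embedding layers); in vLLM the op is normally invoked through its registered custom-op entry point rather than called directly.

import torch

from vllm.model_executor.layers.rotary_embedding.common import (
    _flashinfer_rotary_embedding,
)

# Illustrative sizes only; real values come from the model configuration.
num_tokens, num_heads, head_size, max_pos, base = 4, 8, 64, 4096, 10000.0

# Build a [max_pos, head_size] cache with cos in the first half and sin in
# the second half of the last dimension (assumed layout).
inv_freq = 1.0 / (
    base ** (torch.arange(0, head_size, 2, dtype=torch.float32) / head_size)
)
freqs = torch.outer(torch.arange(max_pos, dtype=torch.float32), inv_freq)
cos_sin_cache = torch.cat((freqs.cos(), freqs.sin()), dim=-1).cuda()

positions = torch.arange(num_tokens, device="cuda")
query = torch.randn(
    num_tokens, num_heads * head_size, dtype=torch.float16, device="cuda"
)
key = torch.randn_like(query)

# Rotates query and key in place; the function returns None.
_flashinfer_rotary_embedding(
    positions, query, key, head_size, cos_sin_cache, is_neox=True
)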

_flashinfer_rotary_embedding_fake

_flashinfer_rotary_embedding_fake(
    positions: Tensor,
    query: Tensor,
    key: Tensor,
    head_size: int,
    cos_sin_cache: Tensor,
    is_neox: bool,
) -> None
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def _flashinfer_rotary_embedding_fake(
    positions: torch.Tensor,
    query: torch.Tensor,
    key: torch.Tensor,
    head_size: int,
    cos_sin_cache: torch.Tensor,
    is_neox: bool,
) -> None:
    return
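
The fake implementation gives torch.compile / fake-tensor tracing a side-effect-free stand-in for the real kernel: since the op only mutates query and key in place and returns nothing, there are no output shapes to describe and the fake can simply return. A generic PyTorch sketch of this pattern (not vLLM's actual registration code; assumes the torch.library custom-op API available in PyTorch 2.4+):

import torch

@torch.library.custom_op("demo::rope_inplace", mutates_args=("query", "key"))
def rope_inplace(query: torch.Tensor, key: torch.Tensor) -> None:
    # Stand-in for an in-place RoPE kernel.
    query.mul_(2.0)
    key.mul_(2.0)

@rope_inplace.register_fake
def _(query: torch.Tensor, key: torch.Tensor) -> None:
    # Nothing to describe: the op produces no outputs, it only mutates inputs.
    return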

apply_rotary_emb_dispatch

apply_rotary_emb_dispatch(
    x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool
) -> Tensor

Parameters:

Name           Type    Description                                          Default
x              Tensor  [num_tokens, num_heads, head_size]                   required
cos            Tensor  [num_tokens, head_size // 2]                         required
sin            Tensor  [num_tokens, head_size // 2]                         required
is_neox_style  bool    Whether to use the Neox-style or GPT-J-style rotary  required
                       positional embeddings.
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def apply_rotary_emb_dispatch(
    x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, is_neox_style: bool
) -> torch.Tensor:
    """
    Args:
        x: [num_tokens, num_heads, head_size]
        cos: [num_tokens, head_size // 2]
        sin: [num_tokens, head_size // 2]
        is_neox_style: Whether to use the Neox-style or GPT-J-style rotary
            positional embeddings.
    """
    if current_platform.is_cuda():
        return apply_rotary_emb(x.unsqueeze(0), cos, sin, not is_neox_style).squeeze(0)
    else:
        return apply_rotary_emb_torch(x, cos, sin, is_neox_style)
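
Example (illustrative): the dispatch routes to the fused flash-attention kernel on CUDA platforms and to apply_rotary_emb_torch elsewhere, so tensors should live on the device the dispatched backend expects (and typically in half precision for the fused path).

import torch

from vllm.model_executor.layers.rotary_embedding.common import (
    apply_rotary_emb_dispatch,
)

num_tokens, num_heads, head_size = 3, 4, 8  # illustrative shapes only
device = "cuda" if torch.cuda.is_available() else "cpu"
dtype = torch.float16 if device == "cuda" else torch.float32

x = torch.randn(num_tokens, num_heads, head_size, device=device, dtype=dtype)
theta = torch.rand(num_tokens, head_size // 2, device=device, dtype=dtype)

out = apply_rotary_emb_dispatch(x, theta.cos(), theta.sin(), is_neox_style=True)
assert out.shape == x.shape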

apply_rotary_emb_torch

apply_rotary_emb_torch(
    x: Tensor, cos: Tensor, sin: Tensor, is_neox_style: bool
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def apply_rotary_emb_torch(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    is_neox_style: bool,
) -> torch.Tensor:
    cos = cos.unsqueeze(-2).to(x.dtype)
    sin = sin.unsqueeze(-2).to(x.dtype)
    if is_neox_style:
        x1, x2 = torch.chunk(x, 2, dim=-1)
    else:
        x1 = x[..., ::2]
        x2 = x[..., 1::2]
    o1 = x1 * cos - x2 * sin
    o2 = x2 * cos + x1 * sin
    if is_neox_style:
        return torch.cat((o1, o2), dim=-1)
    else:
        return torch.stack((o1, o2), dim=-1).flatten(-2)
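
Pure-PyTorch reference implementation: it applies a 2D rotation to each (x1, x2) pair, so per-token head norms are preserved in both layouts. Example (illustrative shapes):

import torch

from vllm.model_executor.layers.rotary_embedding.common import apply_rotary_emb_torch

x = torch.randn(2, 3, 8)      # [num_tokens, num_heads, head_size]
theta = torch.rand(2, 4)      # [num_tokens, head_size // 2]
cos, sin = theta.cos(), theta.sin()

out_neox = apply_rotary_emb_torch(x, cos, sin, is_neox_style=True)
out_gptj = apply_rotary_emb_torch(x, cos, sin, is_neox_style=False)

# Rotations preserve the norm of each head vector in both layouts.
assert torch.allclose(out_neox.norm(dim=-1), x.norm(dim=-1), atol=1e-5)
assert torch.allclose(out_gptj.norm(dim=-1), x.norm(dim=-1), atol=1e-5)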

dispatch_rotary_emb_function cached

dispatch_rotary_emb_function(
    default: Optional[Callable[..., Tensor]] = None,
) -> Callable[..., Tensor]
Source code in vllm/model_executor/layers/rotary_embedding/common.py
@cache
def dispatch_rotary_emb_function(
    default: Optional[Callable[..., torch.Tensor]] = None,
) -> Callable[..., torch.Tensor]:
    if current_platform.is_cuda():
        return apply_rotary_emb

    if current_platform.is_rocm():
        if find_spec("flash_attn") is not None:
            from flash_attn.ops.triton.rotary import apply_rotary

            return apply_rotary
        else:
            logger.warning(
                "flash_attn is not installed. Falling back to PyTorch "
                "implementation for rotary embeddings."
            )

    if default is not None:
        return default
    else:
        return apply_rotary_emb_torch
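
The result is memoized by @cache, so the backend is resolved once per distinct default: flash-attention's fused kernel on CUDA, its Triton rotary kernel on ROCm when flash_attn is installed, otherwise the supplied default (or the pure-PyTorch fallback). Note that the returned callable's argument convention depends on the chosen backend, so callers must adapt accordingly. Example (illustrative):

from vllm.model_executor.layers.rotary_embedding.common import (
    apply_rotary_emb_torch,
    dispatch_rotary_emb_function,
)

# Resolve the backend once; later calls with the same default hit the cache.
rotary_fn = dispatch_rotary_emb_function(default=apply_rotary_emb_torch)
print(rotary_fn.__module__, rotary_fn.__qualname__)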

rotate_gptj

rotate_gptj(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def rotate_gptj(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., ::2]
    x2 = x[..., 1::2]
    x = torch.stack((-x2, x1), dim=-1)
    return x.flatten(-2)
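
Rotates adjacent (even, odd) element pairs, i.e. (x1, x2) -> (-x2, x1) in the interleaved GPT-J layout. Example (illustrative):

import torch

from vllm.model_executor.layers.rotary_embedding.common import rotate_gptj

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_gptj(x))  # tensor([-2., 1., -4., 3.])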

rotate_neox

rotate_neox(x: Tensor) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def rotate_neox(x: torch.Tensor) -> torch.Tensor:
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)
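
Rotates by swapping and negating the two halves of the last dimension, i.e. (x1, x2) -> (-x2, x1) in the Neox half-split layout. Example (illustrative):

import torch

from vllm.model_executor.layers.rotary_embedding.common import rotate_neox

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
print(rotate_neox(x))  # tensor([-3., -4., 1., 2.])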

yarn_find_correction_dim

yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_find_correction_dim(
    num_rotations: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> float:
    return (dim * math.log(max_position_embeddings / (num_rotations * 2 * math.pi))) / (
        2 * math.log(base)
    )
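
Inverts the RoPE wavelength relation: dimension pair d has wavelength 2π · base^(2d/dim), so the dimension whose frequency completes num_rotations full turns over max_position_embeddings is d = dim · ln(max_position_embeddings / (2π · num_rotations)) / (2 · ln(base)); the result is generally fractional. Example (illustrative):

from vllm.model_executor.layers.rotary_embedding.common import yarn_find_correction_dim

# Dimension index completing ~32 rotations over a 2048-token window
# with the default base of 10000, for a 128-dim rotary head.
d = yarn_find_correction_dim(num_rotations=32, dim=128)
print(d)  # ≈ 16.13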

yarn_find_correction_range

yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> tuple[int, int]
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_find_correction_range(
    low_rot: int,
    high_rot: int,
    dim: int,
    base: float = 10000,
    max_position_embeddings: int = 2048,
) -> tuple[int, int]:
    low = math.floor(
        yarn_find_correction_dim(low_rot, dim, base, max_position_embeddings)
    )
    high = math.ceil(
        yarn_find_correction_dim(high_rot, dim, base, max_position_embeddings)
    )
    return max(low, 0), min(high, dim - 1)  # Clamp values just in case
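
Converts a pair of rotation-count thresholds (YaRN's beta_fast / beta_slow) into an integer range of dimension indices, clamped to [0, dim - 1]; dimensions outside the range are treated as purely high-frequency or purely low-frequency, while those inside are blended by the ramp mask below. Example (illustrative):

from vllm.model_executor.layers.rotary_embedding.common import (
    yarn_find_correction_range,
)

# Typical YaRN thresholds (beta_fast=32, beta_slow=1) for a 128-dim rotary head.
low, high = yarn_find_correction_range(32, 1, dim=128)
print(low, high)  # 16 41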

yarn_get_mscale

yarn_get_mscale(scale: float = 1) -> float
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_get_mscale(scale: float = 1) -> float:
    if scale <= 1:
        return 1.0
    return 0.1 * math.log(scale) + 1.0
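
YaRN's attention temperature ("mscale"): 1.0 when no context extension is applied (scale <= 1), growing logarithmically with the scaling factor otherwise. Example (illustrative):

from vllm.model_executor.layers.rotary_embedding.common import yarn_get_mscale

print(yarn_get_mscale(1.0))   # 1.0
print(yarn_get_mscale(16.0))  # 0.1 * ln(16) + 1 ≈ 1.277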

yarn_linear_ramp_mask

yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: dtype
) -> Tensor
Source code in vllm/model_executor/layers/rotary_embedding/common.py
def yarn_linear_ramp_mask(
    low: float, high: float, dim: int, dtype: torch.dtype
) -> torch.Tensor:
    if low == high:
        high += 0.001  # Prevent singularity

    linear_func = (torch.arange(dim, dtype=dtype) - low) / (high - low)
    ramp_func = torch.clamp(linear_func, 0, 1)
    return ramp_func
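
Builds a per-dimension blending ramp that is 0 at or below low, 1 at or above high, and linear in between; YaRN uses it to mix interpolated and extrapolated inverse frequencies across the correction range. Example (illustrative):

import torch

from vllm.model_executor.layers.rotary_embedding.common import yarn_linear_ramp_mask

mask = yarn_linear_ramp_mask(low=2.0, high=6.0, dim=8, dtype=torch.float32)
print(mask)  # tensor([0.0000, 0.0000, 0.0000, 0.2500, 0.5000, 0.7500, 1.0000, 1.0000])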