Skip to content

vllm.model_executor.layers.fla.ops.index

prepare_chunk_indices

prepare_chunk_indices(
    cu_seqlens: LongTensor, chunk_size: int
) -> LongTensor
Source code in vllm/model_executor/layers/fla/ops/index.py
@tensor_cache
def prepare_chunk_indices(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    """Enumerate every ``chunk_size``-token chunk of every packed sequence.

    Args:
        cu_seqlens: Cumulative sequence offsets; sequence ``i`` spans
            ``cu_seqlens[i]:cu_seqlens[i + 1]``. Assumed 1-D and
            non-decreasing (NOTE(review): confirm with callers).
        chunk_size: Tokens per chunk; a partial tail chunk is still counted
            as one chunk.

    Returns:
        A ``(num_chunks, 2)`` tensor with the dtype and device of
        ``cu_seqlens``. Row ``k`` is ``[seq_idx, chunk_idx]``: the index of
        the sequence that chunk ``k`` belongs to, and the chunk's position
        within that sequence. Zero-length sequences contribute no rows but
        still occupy a sequence index; an empty batch yields a ``(0, 2)``
        tensor instead of raising.
    """
    # Chunks per sequence, rounded up so a partial tail chunk is counted.
    chunks_per_seq = triton.cdiv(prepare_lens(cu_seqlens), chunk_size).long()
    device = chunks_per_seq.device
    num_seqs = chunks_per_seq.numel()
    # Sequence index of each chunk, taken directly from the sequence ids.
    # Counting "chunk_idx == 0" rows instead would silently shift every id
    # after a zero-length sequence, since such a sequence emits no chunks.
    seq_ids = torch.repeat_interleave(
        torch.arange(num_seqs, device=device), chunks_per_seq
    )
    # Global position of each sequence's first chunk (exclusive prefix sum).
    first_chunk = torch.cumsum(chunks_per_seq, 0) - chunks_per_seq
    chunk_ids = torch.arange(seq_ids.numel(), device=device) - first_chunk[seq_ids]
    return torch.stack([seq_ids, chunk_ids], 1).to(cu_seqlens)

prepare_chunk_offsets

prepare_chunk_offsets(
    cu_seqlens: LongTensor, chunk_size: int
) -> LongTensor
Source code in vllm/model_executor/layers/fla/ops/index.py
@tensor_cache
def prepare_chunk_offsets(
    cu_seqlens: torch.LongTensor, chunk_size: int
) -> torch.LongTensor:
    """Return running chunk totals at each sequence boundary.

    Entry ``i`` is the number of ``chunk_size``-token chunks (a partial tail
    chunk counts as one) contained in sequences ``0 .. i-1``. The result
    therefore starts at 0 and has one more element than there are sequences.

    NOTE: ``torch.cumsum`` promotes integer inputs to int64, so the result
    may not share ``cu_seqlens``' dtype when that is int32.
    """
    # Per-sequence chunk counts, rounded up.
    num_chunks = triton.cdiv(prepare_lens(cu_seqlens), chunk_size)
    # Leading zero so the first offset is 0 and the last is the total.
    zero = cu_seqlens.new_tensor([0])
    padded = torch.cat([zero, num_chunks])
    return padded.cumsum(-1)

prepare_lens

prepare_lens(cu_seqlens: LongTensor) -> LongTensor
Source code in vllm/model_executor/layers/fla/ops/index.py
@tensor_cache
def prepare_lens(cu_seqlens: torch.LongTensor) -> torch.LongTensor:
    """Recover per-sequence lengths from cumulative sequence offsets.

    Element ``i`` of the result is ``cu_seqlens[i + 1] - cu_seqlens[i]``.
    For a non-empty 1-D input, the output has one fewer element than
    ``cu_seqlens``.
    """
    upper = cu_seqlens[1:]
    lower = cu_seqlens[:-1]
    return upper - lower