Skip to content

vllm.utils.tensor_schema

logger module-attribute

logger = init_logger(__name__)

TensorSchema

Source code in vllm/utils/tensor_schema.py
class TensorSchema:
    def __init__(
        self,
        *,
        validate: bool = True,
        resolve_bindings: Optional[dict[str, int]] = None,
        **kwargs: Any,
    ) -> None:
        super().__init__()

        self._resolve_bindings = resolve_bindings if resolve_bindings else {}

        for key, value in kwargs.items():
            setattr(self, key, value)

        if validate:
            self.validate()

    def __getitem__(self, key: str) -> Any:
        return getattr(self, key)

    def get(self, key: str, default: Any = None) -> Any:
        return getattr(self, key, default)

    def _match_shape_with_dynamic(
        self,
        actual: tuple[int, ...],
        reference: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        dynamic_dims: set[str],
    ) -> bool:
        if len(actual) != len(reference) or len(actual) > len(expected_shape):
            return False

        for i, (a, r) in enumerate(zip(actual, reference)):
            # When validating list inputs, we match shape suffixes only
            # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
            # to the leading symbolic dim (e.g. "bn"). This allows comparing
            # only the trailing dimensions of each element in the list.
            dim = expected_shape[-len(actual) + i]
            # Skip this dimension if it's marked dynamic
            if dim in dynamic_dims:
                continue
            if a != r:
                return False
        return True

    def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
        if not idxs:
            return ""

        return str(list(idxs))

    def _validate_field(
        self,
        value: object,
        field_name: str,
        expected_shape: tuple[Union[int, str], ...],
        dynamic_dims: set[str],
        leading_idxs: tuple[int, ...] = (),
    ) -> tuple[int, ...]:
        """Validate a field and return the actual shape."""
        if isinstance(value, (int, float)):
            return ()  # Scalar
        if isinstance(value, torch.Tensor):
            return value.shape

        if not isinstance(value, (list, tuple)):
            raise TypeError(
                f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
                f"one of the expected types: int, float, Tensor, list, tuple. "
                f"Got: {type(value)}"
            )

        if len(value) == 0:
            raise ValueError(
                f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence"
            )

        # Ensure all tensors in the list have the same
        # shape, besides dynamic dimensions
        for i, v in enumerate(value):
            shape = self._validate_field(
                v,
                field_name,
                expected_shape[1:],
                dynamic_dims,
                leading_idxs=leading_idxs + (i,),
            )

            if i == 0:
                first_shape = shape
            elif not self._match_shape_with_dynamic(
                shape,
                first_shape,
                expected_shape,
                dynamic_dims,
            ):
                raise ValueError(
                    f"{field_name}{self._fmt_indexer(leading_idxs)} "
                    f"contains inconsistent shapes: {first_shape} "
                    f"(index 0) vs {shape} (index {i})"
                )

        # Treat the list as a stacked tensor:
        # shape = (len(list), *tensor.shape)
        return (len(value),) + first_shape

    def _validate_tensor_shape_expected(
        self,
        actual_shape: tuple[int, ...],
        expected_shape: tuple[Union[int, str], ...],
        field_name: str,
        shape_env: dict[str, int],
        dynamic_dims: set[str],
    ) -> None:
        """Validate that the actual tensor shape matches the expected shape."""

        if len(actual_shape) != len(expected_shape):
            raise ValueError(
                f"{field_name} has rank {len(actual_shape)} "
                f"but expected {len(expected_shape)}. "
                f"Expected shape: {expected_shape}, "
                f"but got {actual_shape}"
            )

        for i, dim in enumerate(expected_shape):
            if dim in dynamic_dims:
                continue
            elif isinstance(dim, int):
                if actual_shape[i] != dim:
                    raise ValueError(
                        f"{field_name} dim[{i}] expected "
                        f"{dim}, got {actual_shape[i]}. "
                        f"Expected shape: {expected_shape}, "
                        f"but got {actual_shape}"
                    )
            elif isinstance(dim, str):
                if dim in shape_env:
                    if actual_shape[i] != shape_env[dim]:
                        raise ValueError(
                            f"{field_name} dim[{i}] expected "
                            f"'{dim}'={shape_env[dim]}, got "
                            f"{actual_shape[i]}"
                        )
                else:
                    shape_env[dim] = actual_shape[i]
            else:
                raise TypeError(
                    f"{field_name} dim[{i}] has unsupported type: {type(dim)}"
                )

    def validate(self) -> None:
        type_hints = get_type_hints(self.__class__, include_extras=True)
        shape_env = dict[str, int]()

        for field_name, field_type in type_hints.items():
            # Check if field is missing
            if not hasattr(self, field_name) or getattr(self, field_name) is None:
                # Check if field is marked as optional
                actual_type = field_type
                if get_origin(field_type) is Annotated:
                    args = get_args(field_type)
                    actual_type = args[0]

                # Check arg was provided as Union
                if get_origin(actual_type) is Union:
                    args = get_args(actual_type)
                    # Skip validation when Union contains None
                    if type(None) in args:
                        continue
                # Otherwise field is required, raise error
                raise ValueError(f"Required field '{field_name}' is missing")

            # Field exists, proceed with validation
            value = getattr(self, field_name)
            if get_origin(field_type) is not None:
                args = get_args(field_type)

                for arg in args:
                    if isinstance(arg, TensorShape):
                        expected_shape = arg.resolve(**self._resolve_bindings)
                        actual_shape = self._validate_field(
                            value,
                            field_name,
                            expected_shape,
                            arg.dynamic_dims,
                        )

                        self._validate_tensor_shape_expected(
                            actual_shape,
                            expected_shape,
                            field_name,
                            shape_env,
                            arg.dynamic_dims,
                        )

    def print_shapes(self) -> None:
        """Print TensorShape annotations for debugging."""
        logger.debug("Shapes in %s:", self.__class__.__name__)
        type_hints = get_type_hints(self.__class__, include_extras=True)

        for field_name, field_type in type_hints.items():
            if get_origin(field_type) is not None:
                args = get_args(field_type)
                for arg in args:
                    if isinstance(arg, TensorShape):
                        logger.debug("  %s: %s", field_name, str(arg))

_resolve_bindings instance-attribute

_resolve_bindings = (
    resolve_bindings if resolve_bindings else {}
)

__getitem__

__getitem__(key: str) -> Any
Source code in vllm/utils/tensor_schema.py
def __getitem__(self, key: str) -> Any:
    return getattr(self, key)

__init__

__init__(
    *,
    validate: bool = True,
    resolve_bindings: Optional[dict[str, int]] = None,
    **kwargs: Any,
) -> None
Source code in vllm/utils/tensor_schema.py
def __init__(
    self,
    *,
    validate: bool = True,
    resolve_bindings: Optional[dict[str, int]] = None,
    **kwargs: Any,
) -> None:
    super().__init__()

    self._resolve_bindings = resolve_bindings if resolve_bindings else {}

    for key, value in kwargs.items():
        setattr(self, key, value)

    if validate:
        self.validate()

_fmt_indexer

_fmt_indexer(idxs: tuple[int, ...]) -> str
Source code in vllm/utils/tensor_schema.py
def _fmt_indexer(self, idxs: tuple[int, ...]) -> str:
    if not idxs:
        return ""

    return str(list(idxs))

_match_shape_with_dynamic

_match_shape_with_dynamic(
    actual: tuple[int, ...],
    reference: tuple[int, ...],
    expected_shape: tuple[Union[int, str], ...],
    dynamic_dims: set[str],
) -> bool
Source code in vllm/utils/tensor_schema.py
def _match_shape_with_dynamic(
    self,
    actual: tuple[int, ...],
    reference: tuple[int, ...],
    expected_shape: tuple[Union[int, str], ...],
    dynamic_dims: set[str],
) -> bool:
    if len(actual) != len(reference) or len(actual) > len(expected_shape):
        return False

    for i, (a, r) in enumerate(zip(actual, reference)):
        # When validating list inputs, we match shape suffixes only
        # (e.g. "p", 3, "h", "w"), assuming the list length corresponds
        # to the leading symbolic dim (e.g. "bn"). This allows comparing
        # only the trailing dimensions of each element in the list.
        dim = expected_shape[-len(actual) + i]
        # Skip this dimension if it's marked dynamic
        if dim in dynamic_dims:
            continue
        if a != r:
            return False
    return True

_validate_field

_validate_field(
    value: object,
    field_name: str,
    expected_shape: tuple[Union[int, str], ...],
    dynamic_dims: set[str],
    leading_idxs: tuple[int, ...] = (),
) -> tuple[int, ...]

Validate a field and return the actual shape.

Source code in vllm/utils/tensor_schema.py
def _validate_field(
    self,
    value: object,
    field_name: str,
    expected_shape: tuple[Union[int, str], ...],
    dynamic_dims: set[str],
    leading_idxs: tuple[int, ...] = (),
) -> tuple[int, ...]:
    """Validate a field and return the actual shape."""
    if isinstance(value, (int, float)):
        return ()  # Scalar
    if isinstance(value, torch.Tensor):
        return value.shape

    if not isinstance(value, (list, tuple)):
        raise TypeError(
            f"{field_name}{self._fmt_indexer(leading_idxs)} is not "
            f"one of the expected types: int, float, Tensor, list, tuple. "
            f"Got: {type(value)}"
        )

    if len(value) == 0:
        raise ValueError(
            f"{field_name}{self._fmt_indexer(leading_idxs)} is an empty sequence"
        )

    # Ensure all tensors in the list have the same
    # shape, besides dynamic dimensions
    for i, v in enumerate(value):
        shape = self._validate_field(
            v,
            field_name,
            expected_shape[1:],
            dynamic_dims,
            leading_idxs=leading_idxs + (i,),
        )

        if i == 0:
            first_shape = shape
        elif not self._match_shape_with_dynamic(
            shape,
            first_shape,
            expected_shape,
            dynamic_dims,
        ):
            raise ValueError(
                f"{field_name}{self._fmt_indexer(leading_idxs)} "
                f"contains inconsistent shapes: {first_shape} "
                f"(index 0) vs {shape} (index {i})"
            )

    # Treat the list as a stacked tensor:
    # shape = (len(list), *tensor.shape)
    return (len(value),) + first_shape

_validate_tensor_shape_expected

_validate_tensor_shape_expected(
    actual_shape: tuple[int, ...],
    expected_shape: tuple[Union[int, str], ...],
    field_name: str,
    shape_env: dict[str, int],
    dynamic_dims: set[str],
) -> None

Validate that the actual tensor shape matches the expected shape.

Source code in vllm/utils/tensor_schema.py
def _validate_tensor_shape_expected(
    self,
    actual_shape: tuple[int, ...],
    expected_shape: tuple[Union[int, str], ...],
    field_name: str,
    shape_env: dict[str, int],
    dynamic_dims: set[str],
) -> None:
    """Validate that the actual tensor shape matches the expected shape."""

    if len(actual_shape) != len(expected_shape):
        raise ValueError(
            f"{field_name} has rank {len(actual_shape)} "
            f"but expected {len(expected_shape)}. "
            f"Expected shape: {expected_shape}, "
            f"but got {actual_shape}"
        )

    for i, dim in enumerate(expected_shape):
        if dim in dynamic_dims:
            continue
        elif isinstance(dim, int):
            if actual_shape[i] != dim:
                raise ValueError(
                    f"{field_name} dim[{i}] expected "
                    f"{dim}, got {actual_shape[i]}. "
                    f"Expected shape: {expected_shape}, "
                    f"but got {actual_shape}"
                )
        elif isinstance(dim, str):
            if dim in shape_env:
                if actual_shape[i] != shape_env[dim]:
                    raise ValueError(
                        f"{field_name} dim[{i}] expected "
                        f"'{dim}'={shape_env[dim]}, got "
                        f"{actual_shape[i]}"
                    )
            else:
                shape_env[dim] = actual_shape[i]
        else:
            raise TypeError(
                f"{field_name} dim[{i}] has unsupported type: {type(dim)}"
            )

get

get(key: str, default: Any = None) -> Any
Source code in vllm/utils/tensor_schema.py
def get(self, key: str, default: Any = None) -> Any:
    return getattr(self, key, default)

print_shapes

print_shapes() -> None

Print TensorShape annotations for debugging.

Source code in vllm/utils/tensor_schema.py
def print_shapes(self) -> None:
    """Print TensorShape annotations for debugging."""
    logger.debug("Shapes in %s:", self.__class__.__name__)
    type_hints = get_type_hints(self.__class__, include_extras=True)

    for field_name, field_type in type_hints.items():
        if get_origin(field_type) is not None:
            args = get_args(field_type)
            for arg in args:
                if isinstance(arg, TensorShape):
                    logger.debug("  %s: %s", field_name, str(arg))

validate

validate() -> None
Source code in vllm/utils/tensor_schema.py
def validate(self) -> None:
    type_hints = get_type_hints(self.__class__, include_extras=True)
    shape_env = dict[str, int]()

    for field_name, field_type in type_hints.items():
        # Check if field is missing
        if not hasattr(self, field_name) or getattr(self, field_name) is None:
            # Check if field is marked as optional
            actual_type = field_type
            if get_origin(field_type) is Annotated:
                args = get_args(field_type)
                actual_type = args[0]

            # Check arg was provided as Union
            if get_origin(actual_type) is Union:
                args = get_args(actual_type)
                # Skip validation when Union contains None
                if type(None) in args:
                    continue
            # Otherwise field is required, raise error
            raise ValueError(f"Required field '{field_name}' is missing")

        # Field exists, proceed with validation
        value = getattr(self, field_name)
        if get_origin(field_type) is not None:
            args = get_args(field_type)

            for arg in args:
                if isinstance(arg, TensorShape):
                    expected_shape = arg.resolve(**self._resolve_bindings)
                    actual_shape = self._validate_field(
                        value,
                        field_name,
                        expected_shape,
                        arg.dynamic_dims,
                    )

                    self._validate_tensor_shape_expected(
                        actual_shape,
                        expected_shape,
                        field_name,
                        shape_env,
                        arg.dynamic_dims,
                    )

TensorShape

Source code in vllm/utils/tensor_schema.py
class TensorShape:
    def __init__(
        self,
        *dims: Union[int, str],
        dynamic_dims: Optional[set[str]] = None,
    ) -> None:
        super().__init__()

        self.dims = dims
        self.dynamic_dims = dynamic_dims if dynamic_dims else set()

    def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]:
        resolved = list[Union[int, str]]()
        for dim in self.dims:
            if isinstance(dim, str) and dim in bindings:
                resolved.append(bindings[dim])
            else:
                resolved.append(dim)
        return tuple(resolved)

    def __str__(self) -> str:
        """Return a string representation of the tensor shape."""
        dim_strs = []
        for dim in self.dims:
            if isinstance(dim, str):
                if dim in self.dynamic_dims:
                    dim_strs.append(f"{dim}*")  # Mark dynamic dimensions with *
                else:
                    dim_strs.append(dim)
            else:
                dim_strs.append(str(dim))
        return f"({', '.join(dim_strs)})"

dims instance-attribute

dims = dims

dynamic_dims instance-attribute

dynamic_dims = dynamic_dims if dynamic_dims else set()

__init__

__init__(
    *dims: Union[int, str],
    dynamic_dims: Optional[set[str]] = None,
) -> None
Source code in vllm/utils/tensor_schema.py
def __init__(
    self,
    *dims: Union[int, str],
    dynamic_dims: Optional[set[str]] = None,
) -> None:
    super().__init__()

    self.dims = dims
    self.dynamic_dims = dynamic_dims if dynamic_dims else set()

__str__

__str__() -> str

Return a string representation of the tensor shape.

Source code in vllm/utils/tensor_schema.py
def __str__(self) -> str:
    """Return a string representation of the tensor shape."""
    dim_strs = []
    for dim in self.dims:
        if isinstance(dim, str):
            if dim in self.dynamic_dims:
                dim_strs.append(f"{dim}*")  # Mark dynamic dimensions with *
            else:
                dim_strs.append(dim)
        else:
            dim_strs.append(str(dim))
    return f"({', '.join(dim_strs)})"

resolve

resolve(**bindings: int) -> tuple[Union[int, str], ...]
Source code in vllm/utils/tensor_schema.py
def resolve(self, **bindings: int) -> tuple[Union[int, str], ...]:
    resolved = list[Union[int, str]]()
    for dim in self.dims:
        if isinstance(dim, str) and dim in bindings:
            resolved.append(bindings[dim])
        else:
            resolved.append(dim)
    return tuple(resolved)