Skip to content

vllm.entrypoints.openai.tool_parsers.llama_tool_parser ¶

logger module-attribute ¶

# Logger for this module, obtained from init_logger with the module's __name__.
logger = init_logger(__name__)

Llama3JsonToolParser ¶

Bases: ToolParser

Tool call parser for Llama 3.x and 4 models intended for use with the examples/tool_chat_template_llama.jinja template.

Used when --enable-auto-tool-choice --tool-call-parser llama3_json or llama4_json are set.

Source code in vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
@ToolParserManager.register_module("llama3_json")
@ToolParserManager.register_module("llama4_json")
class Llama3JsonToolParser(ToolParser):
    """
    Tool call parser for Llama 3.x and 4 models intended for use with the
    examples/tool_chat_template_llama.jinja template.

    Used when --enable-auto-tool-choice --tool-call-parser llama3_json or
    llama4_json are set.

    A tool call is a JSON object with a ``"name"`` key and either an
    ``"arguments"`` or a ``"parameters"`` key. The output may be prefixed
    with ``<|python_tag|>`` and several calls may be chained with
    semicolons.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase):
        """
        Initialize streaming state, the tool-call marker token and the
        legacy JSON-matching regex.

        Args:
            tokenizer: tokenizer used to look up the id of the
                ``<|python_tag|>`` token (first id returned by ``encode``).
        """
        super().__init__(tokenizer)

        # initialize properties used for state when parsing tool calls in
        # streaming mode
        self.prev_tool_call_arr: list[dict] = []
        # index of the tool call currently being streamed; -1 = none yet
        self.current_tool_id: int = -1
        self.current_tool_name_sent: bool = False
        self.streamed_args_for_tool: list[
            str
        ] = []  # map what has been streamed for each tool so far to a list
        self.bot_token = "<|python_tag|>"
        self.bot_token_id = tokenizer.encode(self.bot_token, add_special_tokens=False)[
            0
        ]
        # Matches one or more semicolon-separated brace groups, each with at
        # most ONE level of nested braces. extract_tool_calls no longer
        # relies on it (it decodes JSON directly); the attribute is kept so
        # that code reading `tool_call_regex` keeps working.
        self.tool_call_regex = re.compile(
            r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
            re.DOTALL,
        )

    @staticmethod
    def _decode_json_objects(text: str) -> list[dict]:
        """
        Decode the first run of semicolon-separated JSON objects in *text*.

        The run starts at the first ``{`` from which a complete JSON object
        can be decoded. Further objects are picked up only while each one is
        preceded by ``;`` (optionally surrounded by whitespace). Decoding
        with a real JSON decoder means semicolons and braces inside string
        values, and arbitrarily nested objects, are handled correctly.

        Returns:
            The decoded objects in order; an empty list if none was found.
        """
        decoder = json.JSONDecoder()

        # Find the first "{" that begins a decodable JSON object.
        start = text.find("{")
        while start != -1:
            try:
                first_obj, end = decoder.raw_decode(text, start)
                break
            except json.JSONDecodeError:
                start = text.find("{", start + 1)
        else:
            return []

        objects = [first_obj]
        # Follow "; {...}" continuations; stop at the first thing that is
        # not a well-formed continuation and keep what was decoded so far.
        while True:
            remainder = text[end:].lstrip()
            if not remainder.startswith(";"):
                break
            remainder = remainder[1:].lstrip()
            if not remainder.startswith("{"):
                break
            try:
                next_obj, consumed = decoder.raw_decode(remainder)
            except json.JSONDecodeError:
                break
            objects.append(next_obj)
            # translate the offset inside `remainder` back into `text`
            end = len(text) - len(remainder) + consumed
        return objects

    def extract_tool_calls(
        self, model_output: str, request: ChatCompletionRequest
    ) -> ExtractedToolCallInformation:
        """
        Extract the tool calls from a complete model response.
        Only extracts JSON content and ignores any surrounding plain text.
        Supports both single JSON and multiple JSONs separated by semicolons.

        Args:
            model_output: the full text generated by the model.
            request: the originating request (not used here).

        Returns:
            ``tools_called=True`` with the parsed calls and ``content=None``
            when at least one JSON object was decoded and every decoded
            object has ``"name"`` plus ``"arguments"`` or ``"parameters"``.
            Otherwise ``tools_called=False`` with ``content=model_output``.
            Exceptions are logged and reported as "no tool call".
        """
        # Quick check before doing any parsing
        if not (self.bot_token in model_output or "{" in model_output):
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        try:
            tool_calls: list[ToolCall] = [
                ToolCall(
                    type="function",
                    function=FunctionCall(
                        name=obj["name"],
                        # function call args are JSON but as a string
                        arguments=json.dumps(
                            obj["arguments"]
                            if "arguments" in obj
                            else obj["parameters"],
                            ensure_ascii=False,
                        ),
                    ),
                )
                for obj in self._decode_json_objects(model_output)
            ]
        except Exception:
            logger.exception("Error in extracting tool call from response.")
            # return information to just treat the tool call as regular JSON
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        if not tool_calls:
            # no decodable JSON object anywhere in the output
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=None
        )

    def extract_tool_calls_streaming(
        self,
        previous_text: str,
        current_text: str,
        delta_text: str,
        previous_token_ids: Sequence[int],
        current_token_ids: Sequence[int],
        delta_token_ids: Sequence[int],
        request: ChatCompletionRequest,
    ) -> Union[DeltaMessage, None]:
        """
        Produce the next streaming delta from the text generated so far.

        Every call re-parses ``current_text`` into a list of (possibly
        partial) tool-call objects and compares it with the state kept on
        ``self`` (``current_tool_id``, ``current_tool_name_sent``,
        ``streamed_args_for_tool``, ``prev_tool_call_arr``) to decide what
        to emit. Only ``current_text`` and ``delta_text`` are read; the
        other parameters are unused by this implementation.

        Returns:
            ``DeltaMessage(content=delta_text)`` when ``current_text`` starts
            with neither ``<|python_tag|>`` nor ``{``; a ``DeltaMessage``
            with one ``DeltaToolCall`` carrying the function name or a
            further piece of the arguments JSON; or ``None`` when there is
            nothing to send for this chunk (including when parsing fails or
            any other exception is raised, which is logged).
        """
        # Anything not starting with the marker or "{" is plain content.
        if not (
            current_text.startswith(self.bot_token) or current_text.startswith("{")
        ):
            return DeltaMessage(content=delta_text)

        # bit mask flags for partial JSON parsing. If the name hasn't been
        # sent yet, don't allow sending
        # an incomplete string since OpenAI only ever (as far as I have
        # seen) allows sending the entire tool/ function name at once.
        flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
        try:
            # parsed objects and, per object, whether its text is complete
            tool_call_arr = []
            is_complete = []
            try:
                # depending on the prompt format the Llama model may or may not
                # prefix the output with the <|python_tag|> token
                start_idx = (
                    len(self.bot_token)
                    if current_text.startswith(self.bot_token)
                    else 0
                )
                while start_idx < len(current_text):
                    # end_idx is used as an offset relative to start_idx
                    (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                    is_complete.append(
                        is_complete_json(current_text[start_idx : start_idx + end_idx])
                    )
                    # NOTE(review): assumes consecutive objects are separated
                    # by exactly "; " (two characters) - confirm against the
                    # chat template.
                    start_idx += end_idx + len("; ")
                    # depending on the prompt Llama can use
                    # either arguments or parameters
                    if "parameters" in obj:
                        assert "arguments" not in obj, (
                            "model generated both parameters and arguments"
                        )
                        obj["arguments"] = obj["parameters"]
                    tool_call_arr.append(obj)
            except partial_json_parser.core.exceptions.MalformedJSON:
                logger.debug("not enough tokens to parse into JSON yet")
                return None

            # select as the current tool call the one we're on the state at
            # (with current_tool_id == -1 this picks the last parsed object)
            current_tool_call: dict = (
                tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
            )

            # case -- if no tokens have been streamed for the tool, e.g.
            #   only the array brackets, stream nothing
            if len(tool_call_arr) == 0:
                return None

            # case: we are starting a new tool in the array
            #   -> array has > 0 length AND length has moved past cursor
            elif (
                len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
            ):
                # if we're moving on to a new call, first make sure we
                # haven't missed anything in the previous one that was
                # auto-generated due to JSON completions, but wasn't
                # streamed to the client yet.
                if self.current_tool_id >= 0:
                    cur_arguments = current_tool_call.get("arguments")
                    if cur_arguments:
                        cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                        sent = len(self.streamed_args_for_tool[self.current_tool_id])
                        # whatever follows the already-streamed prefix
                        argument_diff = cur_args_json[sent:]

                        logger.debug("got arguments diff: %s", argument_diff)
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )
                    else:
                        delta = None
                else:
                    delta = None
                # re-set stuff pertaining to progress in the current tool
                self.current_tool_id = len(tool_call_arr) - 1
                self.current_tool_name_sent = False
                self.streamed_args_for_tool.append("")
                logger.debug("starting on new tool %d", self.current_tool_id)
                # note: returns without updating prev_tool_call_arr
                return delta

            # if the current tool name hasn't been sent, send if available
            # - otherwise send nothing
            elif not self.current_tool_name_sent:
                function_name = current_tool_call.get("name")
                if function_name:
                    # first delta for this tool: carries type, a fresh id
                    # and the function name
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                type="function",
                                id=make_tool_call_id(),
                                function=DeltaFunctionCall(
                                    name=function_name
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.current_tool_name_sent = True
                else:
                    delta = None

            # now we know we're on the same tool call and we're streaming
            # arguments
            else:
                cur_arguments = current_tool_call.get("arguments")
                delta = None

                if cur_arguments:
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                        "arguments"
                    )

                    argument_diff = None
                    if is_complete[self.current_tool_id]:
                        # object text is complete: send all that remains
                        argument_diff = cur_args_json[sent:]
                    elif prev_arguments:
                        prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                        if cur_args_json != prev_args_json:
                            # only the part shared by the previous and the
                            # current serialization is sent, minus what
                            # has been streamed already
                            prefix = find_common_prefix(prev_args_json, cur_args_json)
                            argument_diff = prefix[sent:]

                    if argument_diff is not None:
                        delta = DeltaMessage(
                            tool_calls=[
                                DeltaToolCall(
                                    index=self.current_tool_id,
                                    function=DeltaFunctionCall(
                                        arguments=argument_diff
                                    ).model_dump(exclude_none=True),
                                )
                            ]
                        )
                        self.streamed_args_for_tool[self.current_tool_id] += (
                            argument_diff
                        )

            # remember this parse for the next call's prefix comparison
            self.prev_tool_call_arr = tool_call_arr
            return delta

        except Exception:
            logger.exception("Error trying to handle streaming tool call.")
            logger.debug(
                "Skipping chunk as a result of tool streaming extraction error"
            )
            return None

bot_token instance-attribute ¶

bot_token = '<|python_tag|>'

bot_token_id instance-attribute ¶

bot_token_id = encode(bot_token, add_special_tokens=False)[
    0
]

current_tool_id instance-attribute ¶

current_tool_id: int = -1

current_tool_name_sent instance-attribute ¶

current_tool_name_sent: bool = False

prev_tool_call_arr instance-attribute ¶

prev_tool_call_arr: list[dict] = []

streamed_args_for_tool instance-attribute ¶

streamed_args_for_tool: list[str] = []

tool_call_regex instance-attribute ¶

tool_call_regex = compile(
    "{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\\s*;\\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
    DOTALL,
)

__init__ ¶

__init__(tokenizer: PreTrainedTokenizerBase)
Source code in vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
def __init__(self, tokenizer: PreTrainedTokenizerBase):
    """
    Set up streaming state, the tool-call marker token and the regex used
    to locate JSON tool calls.

    Args:
        tokenizer: tokenizer passed to the base class and used to look up
            the id of the ``<|python_tag|>`` token.
    """
    super().__init__(tokenizer)

    # Per-stream state consulted while parsing tool calls incrementally.
    self.current_tool_id: int = -1
    self.current_tool_name_sent: bool = False
    self.prev_tool_call_arr: list[dict] = []
    # One entry per tool: the argument text already streamed for it.
    self.streamed_args_for_tool: list[str] = []

    # Marker token and its id (the first id produced by the tokenizer).
    self.bot_token = "<|python_tag|>"
    bot_token_ids = tokenizer.encode(self.bot_token, add_special_tokens=False)
    self.bot_token_id = bot_token_ids[0]

    # One or more semicolon-separated brace groups; each group may contain
    # brace pairs nested one level deep.
    self.tool_call_regex = re.compile(
        r"{[^{}]*(?:{[^{}]*}[^{}]*)*}(?:\s*;\s*{[^{}]*(?:{[^{}]*}[^{}]*)*})*",
        flags=re.DOTALL,
    )

extract_tool_calls ¶

extract_tool_calls(
    model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation

Extract the tool calls from a complete model response. Only extracts JSON content and ignores any surrounding plain text. Supports both single JSON and multiple JSONs separated by semicolons.

Source code in vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
def extract_tool_calls(
    self, model_output: str, request: ChatCompletionRequest
) -> ExtractedToolCallInformation:
    """
    Extract the tool calls from a complete model response.
    Only extracts JSON content and ignores any surrounding plain text.
    Supports both single JSON and multiple JSONs separated by semicolons.

    The objects are decoded with a JSON decoder rather than by splitting on
    ";", so semicolons or braces inside string values and arbitrarily
    nested arguments are handled correctly.

    Args:
        model_output: the full text generated by the model.
        request: the originating request (not used here).

    Returns:
        ``tools_called=True`` with the parsed calls and ``content=None``
        when at least one JSON object was decoded and every decoded object
        has ``"name"`` plus ``"arguments"`` or ``"parameters"``. Otherwise
        ``tools_called=False`` with ``content=model_output``. Exceptions
        are logged and reported as "no tool call".
    """
    # Quick check before doing any parsing
    if not (self.bot_token in model_output or "{" in model_output):
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

    try:
        decoder = json.JSONDecoder()
        json_objects: list[dict] = []
        end = 0

        # Find the first "{" that begins a decodable JSON object.
        start = model_output.find("{")
        while start != -1 and not json_objects:
            try:
                obj, end = decoder.raw_decode(model_output, start)
            except json.JSONDecodeError:
                start = model_output.find("{", start + 1)
            else:
                json_objects.append(obj)

        # Follow "; {...}" continuations; stop at the first thing that is
        # not a well-formed continuation and keep what was decoded so far.
        while json_objects:
            remainder = model_output[end:].lstrip()
            if not remainder.startswith(";"):
                break
            remainder = remainder[1:].lstrip()
            if not remainder.startswith("{"):
                break
            try:
                obj, consumed = decoder.raw_decode(remainder)
            except json.JSONDecodeError:
                break
            json_objects.append(obj)
            # translate the offset inside `remainder` back into model_output
            end = len(model_output) - len(remainder) + consumed

        if not json_objects:
            # no decodable JSON object anywhere in the output
            return ExtractedToolCallInformation(
                tools_called=False, tool_calls=[], content=model_output
            )

        tool_calls: list[ToolCall] = [
            ToolCall(
                type="function",
                function=FunctionCall(
                    name=obj["name"],
                    # function call args are JSON but as a string
                    arguments=json.dumps(
                        obj["arguments"]
                        if "arguments" in obj
                        else obj["parameters"],
                        ensure_ascii=False,
                    ),
                ),
            )
            for obj in json_objects
        ]

        return ExtractedToolCallInformation(
            tools_called=True, tool_calls=tool_calls, content=None
        )

    except Exception:
        logger.exception("Error in extracting tool call from response.")
        # return information to just treat the tool call as regular JSON
        return ExtractedToolCallInformation(
            tools_called=False, tool_calls=[], content=model_output
        )

extract_tool_calls_streaming ¶

extract_tool_calls_streaming(
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]
Source code in vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py
def extract_tool_calls_streaming(
    self,
    previous_text: str,
    current_text: str,
    delta_text: str,
    previous_token_ids: Sequence[int],
    current_token_ids: Sequence[int],
    delta_token_ids: Sequence[int],
    request: ChatCompletionRequest,
) -> Union[DeltaMessage, None]:
    """
    Produce the next streaming delta from the text generated so far.

    Every call re-parses ``current_text`` into a list of (possibly partial)
    tool-call objects and compares it with the state kept on ``self``
    (``current_tool_id``, ``current_tool_name_sent``,
    ``streamed_args_for_tool``, ``prev_tool_call_arr``) to decide what to
    emit. Only ``current_text`` and ``delta_text`` are read; the other
    parameters are unused by this implementation.

    Returns:
        ``DeltaMessage(content=delta_text)`` when ``current_text`` starts
        with neither ``<|python_tag|>`` nor ``{``; a ``DeltaMessage`` with
        one ``DeltaToolCall`` carrying the function name or a further piece
        of the arguments JSON; or ``None`` when there is nothing to send
        for this chunk (including when parsing fails or any other exception
        is raised, which is logged).
    """
    # Anything not starting with the marker or "{" is plain content.
    if not (
        current_text.startswith(self.bot_token) or current_text.startswith("{")
    ):
        return DeltaMessage(content=delta_text)

    # bit mask flags for partial JSON parsing. If the name hasn't been
    # sent yet, don't allow sending
    # an incomplete string since OpenAI only ever (as far as I have
    # seen) allows sending the entire tool/ function name at once.
    flags = Allow.ALL if self.current_tool_name_sent else Allow.ALL & ~Allow.STR
    try:
        # parsed objects and, per object, whether its text is complete
        tool_call_arr = []
        is_complete = []
        try:
            # depending on the prompt format the Llama model may or may not
            # prefix the output with the <|python_tag|> token
            start_idx = (
                len(self.bot_token)
                if current_text.startswith(self.bot_token)
                else 0
            )
            while start_idx < len(current_text):
                # end_idx is used as an offset relative to start_idx
                (obj, end_idx) = partial_json_loads(current_text[start_idx:], flags)
                is_complete.append(
                    is_complete_json(current_text[start_idx : start_idx + end_idx])
                )
                # NOTE(review): assumes consecutive objects are separated by
                # exactly "; " (two characters) - confirm against the chat
                # template.
                start_idx += end_idx + len("; ")
                # depending on the prompt Llama can use
                # either arguments or parameters
                if "parameters" in obj:
                    assert "arguments" not in obj, (
                        "model generated both parameters and arguments"
                    )
                    obj["arguments"] = obj["parameters"]
                tool_call_arr.append(obj)
        except partial_json_parser.core.exceptions.MalformedJSON:
            logger.debug("not enough tokens to parse into JSON yet")
            return None

        # select as the current tool call the one we're on the state at
        # (with current_tool_id == -1 this picks the last parsed object)
        current_tool_call: dict = (
            tool_call_arr[self.current_tool_id] if len(tool_call_arr) > 0 else {}
        )

        # case -- if no tokens have been streamed for the tool, e.g.
        #   only the array brackets, stream nothing
        if len(tool_call_arr) == 0:
            return None

        # case: we are starting a new tool in the array
        #   -> array has > 0 length AND length has moved past cursor
        elif (
            len(tool_call_arr) > 0 and len(tool_call_arr) > self.current_tool_id + 1
        ):
            # if we're moving on to a new call, first make sure we
            # haven't missed anything in the previous one that was
            # auto-generated due to JSON completions, but wasn't
            # streamed to the client yet.
            if self.current_tool_id >= 0:
                cur_arguments = current_tool_call.get("arguments")
                if cur_arguments:
                    cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                    sent = len(self.streamed_args_for_tool[self.current_tool_id])
                    # whatever follows the already-streamed prefix
                    argument_diff = cur_args_json[sent:]

                    logger.debug("got arguments diff: %s", argument_diff)
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += (
                        argument_diff
                    )
                else:
                    delta = None
            else:
                delta = None
            # re-set stuff pertaining to progress in the current tool
            self.current_tool_id = len(tool_call_arr) - 1
            self.current_tool_name_sent = False
            self.streamed_args_for_tool.append("")
            logger.debug("starting on new tool %d", self.current_tool_id)
            # note: returns without updating prev_tool_call_arr
            return delta

        # if the current tool name hasn't been sent, send if available
        # - otherwise send nothing
        elif not self.current_tool_name_sent:
            function_name = current_tool_call.get("name")
            if function_name:
                # first delta for this tool: carries type, a fresh id and
                # the function name
                delta = DeltaMessage(
                    tool_calls=[
                        DeltaToolCall(
                            index=self.current_tool_id,
                            type="function",
                            id=make_tool_call_id(),
                            function=DeltaFunctionCall(
                                name=function_name
                            ).model_dump(exclude_none=True),
                        )
                    ]
                )
                self.current_tool_name_sent = True
            else:
                delta = None

        # now we know we're on the same tool call and we're streaming
        # arguments
        else:
            cur_arguments = current_tool_call.get("arguments")
            delta = None

            if cur_arguments:
                sent = len(self.streamed_args_for_tool[self.current_tool_id])
                cur_args_json = json.dumps(cur_arguments, ensure_ascii=False)
                prev_arguments = self.prev_tool_call_arr[self.current_tool_id].get(
                    "arguments"
                )

                argument_diff = None
                if is_complete[self.current_tool_id]:
                    # object text is complete: send all that remains
                    argument_diff = cur_args_json[sent:]
                elif prev_arguments:
                    prev_args_json = json.dumps(prev_arguments, ensure_ascii=False)
                    if cur_args_json != prev_args_json:
                        # only the part shared by the previous and the
                        # current serialization is sent, minus what has
                        # been streamed already
                        prefix = find_common_prefix(prev_args_json, cur_args_json)
                        argument_diff = prefix[sent:]

                if argument_diff is not None:
                    delta = DeltaMessage(
                        tool_calls=[
                            DeltaToolCall(
                                index=self.current_tool_id,
                                function=DeltaFunctionCall(
                                    arguments=argument_diff
                                ).model_dump(exclude_none=True),
                            )
                        ]
                    )
                    self.streamed_args_for_tool[self.current_tool_id] += (
                        argument_diff
                    )

        # remember this parse for the next call's prefix comparison
        self.prev_tool_call_arr = tool_call_arr
        return delta

    except Exception:
        logger.exception("Error trying to handle streaming tool call.")
        logger.debug(
            "Skipping chunk as a result of tool streaming extraction error"
        )
        return None