From 1a0b3113bdd347a9fb69ee355a433b60a55227c4 Mon Sep 17 00:00:00 2001 From: Reason-Wang Date: Tue, 9 Sep 2025 15:24:45 +0000 Subject: [PATCH 1/2] Add templates --- .gitignore | 3 ++ agentfly/templates/templates.py | 62 ++++++++++++++++++++++++++++----- agentfly/templates/utils.py | 8 +++-- verl | 2 +- 4 files changed, 63 insertions(+), 12 deletions(-) diff --git a/.gitignore b/.gitignore index c0918cf..a347287 100644 --- a/.gitignore +++ b/.gitignore @@ -130,6 +130,9 @@ log outputs *.out +/*.png +/*.jpg + # Notebooks agentfly/tests/*.ipynb agentfly/tests/*.jpg diff --git a/agentfly/templates/templates.py b/agentfly/templates/templates.py index 7e54eff..1662e56 100644 --- a/agentfly/templates/templates.py +++ b/agentfly/templates/templates.py @@ -64,8 +64,6 @@ class Template: system_template_with_tools: str = None # The system message system_message: str = "" - # Stop criteria (the default one is EOS token) - stop_words: Union[str, List[str]] = None # Behaviors # The tool template tool_template: str = None @@ -74,6 +72,12 @@ class Template: user_template_with_tools: str = None # The assistant template assistant_template: str = None + + + # Stop criteria (the default one is EOS token) + stop_words: Union[str, List[str]] = None + # Generation prompt + generation_prompt: str = None # Global policy global_policy: "GlobalPolicy" = None # System message policy @@ -283,7 +287,7 @@ def _maybe_add_generation_prompt(self, elements: List[str], roles: List[Role]): """Append the generation prefix so the model knows to continue generating an assistant response.""" - generation_prefix = self._encode_generation_prompt() + generation_prefix, prefix = self._encode_generation_prompt() elements.append(generation_prefix) roles.append(Role.ASSISTANT_PREFIX) @@ -414,24 +418,33 @@ def _encode_tool_message(self, content) -> str: return tool_message def _encode_generation_prompt(self) -> str: + # Use generation prompt if it is set if "{content}" in self.assistant_template: prefix = self.assistant_template.split("{content}")[0] - return prefix + if self.generation_prompt: + generation_prompt = self.generation_prompt + else: + generation_prompt = prefix else: raise ValueError(f"Assistant template {self.assistant_template} does not contain {{content}}") + return generation_prompt, prefix + + def _split_assistant_message(self, assistant_message: str) -> List[str]: # Split the assistant message into generation prefix, content, and generation suffix - generation_prefix = self._encode_generation_prompt() - assert assistant_message.startswith(generation_prefix), f"Assistant message {assistant_message} does not start with {generation_prefix}" - content_suffix = assistant_message[len(generation_prefix):] + _, prefix = self._encode_generation_prompt() + assert assistant_message.startswith(prefix), f"Assistant message {assistant_message} does not start with {generation_prefix}" + content_suffix = assistant_message[len(prefix):] + content = content_suffix + suffix = "" for stop_word in self.stop_words: if stop_word in content_suffix: stop_word_index = content_suffix.index(stop_word) content = content_suffix[:stop_word_index+len(stop_word)] suffix = content_suffix[stop_word_index+len(stop_word):] break - return generation_prefix, content, suffix + return prefix, content, suffix def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str: @@ -937,6 +950,7 @@ def copy(self): assistant_template=self.assistant_template, 
tool_template=self.tool_template, stop_words=self.stop_words, + generation_prompt=self.generation_prompt, vision_start=self.vision_start, vision_end=self.vision_end, image_token=self.image_token, @@ -1145,6 +1159,21 @@ def get_template(name: str) -> Template: ) ) +register_template( + Template( + name="qwen3", + system_template="<|im_start|>system\n{system_message}<|im_end|>\n", + user_template="<|im_start|>user\n{content}<|im_end|>\n", + assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", + stop_words=["<|im_end|>"], + generation_prompt="<|im_start|>assistant\n", + system_policy=SystemPolicy( + use_system=True, + use_system_without_system_message=False, + ), + ) +) + register_template( Template( name="deepseek-prover", @@ -1249,5 +1278,22 @@ def get_template(name: str) -> Template: ) ) +register_template( + Template( + name="deepseek-r1-distill-qwen", + user_template="<|User|>{content}", + assistant_template="<|Assistant|>{content}<|end▁of▁sentence|>", + stop_words=["<|end▁of▁sentence|>"], + generation_prompt="<|Assistant|><think>\n", + global_policy=GlobalPolicy( + prefix="<|begin▁of▁sentence|>" + ), + system_policy=SystemPolicy( + use_system=False, + use_system_without_system_message=False, + ), + ) +) + if __name__ == "__main__": pass \ No newline at end of file diff --git a/agentfly/templates/utils.py b/agentfly/templates/utils.py index 441f705..369b2c2 100644 --- a/agentfly/templates/utils.py +++ b/agentfly/templates/utils.py @@ -326,9 +326,9 @@ def visualize_jinja_template(tokenizer, messages=None, tools=None, **kwargs): prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, **kwargs) print(prompt) -def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add_generation_prompt=False): - official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt) - chat = Chat(template_name, messages, tokenizer) +def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add_generation_prompt=False, **kwargs): + official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs) + chat = Chat(template_name, messages=messages, tokenizer=tokenizer) implemented_prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools) is_equal = official_prompt == implemented_prompt highlighted_prompt = chat.prompt_with_mask(add_generation_prompt=add_generation_prompt, tools=tools) @@ -336,9 +336,11 @@ def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add is_equal_between_implemented_prompts = implemented_prompt == plain_highlighted_prompt jinja_template = chat.template.jinja_template() + official_jinja_prompt = tokenizer.chat_template tokenizer.chat_template = jinja_template implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt) is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt + tokenizer.chat_template = official_jinja_prompt return is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt diff --git a/verl b/verl index 8bc5532..7a99093 160000 --- a/verl +++ b/verl @@ -1 +1 @@ -Subproject commit 8bc553292589eb2b8ace918410e78c1d27ec0c84 +Subproject commit 7a99093cd3bc939d1eeae6d2e6aafea1a70ca417 From e0dea5d630da55d508d3cd1e2b9e17218597857e Mon
Sep 17 00:00:00 2001 From: Reason-Wang Date: Thu, 11 Sep 2025 10:15:25 +0000 Subject: [PATCH 2/2] Add qwen3 template --- agentfly/templates/system_policy.py | 7 +- agentfly/templates/templates.py | 258 ++++++++++++++++-- agentfly/templates/utils.py | 9 +- .../agents/templates/test_qwen3_prompt.py | 54 ++++ .../agents/templates/test_qwen3_tokenize.py | 69 +++++ test_qwen3_template.py | 83 ++++++ 6 files changed, 446 insertions(+), 34 deletions(-) create mode 100644 agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py create mode 100644 test_qwen3_template.py diff --git a/agentfly/templates/system_policy.py b/agentfly/templates/system_policy.py index 8ede734..027b900 100644 --- a/agentfly/templates/system_policy.py +++ b/agentfly/templates/system_policy.py @@ -5,8 +5,9 @@ @dataclasses.dataclass class SystemPolicy: - use_system: bool = True - use_system_without_system_message: bool = True + use_system: bool = True # Global control + use_system_without_system_message: bool = True # When no system message is provided, use the system message even if it is empty (will use the default one if provided) + use_system_with_tools_provided: bool = True # When tools are provided, use the system message with tools even if no system message is provided content_processor: Callable[[str], str] = None @@ -32,7 +33,7 @@ class Llama32DateProcessor(SystemContentProcessor): - Format should be 'dd MMM yyyy' (e.g., '15 Dec 2024') - No external context variables required """ - def __call__(self, system_message: str) -> str: + def __call__(self, system_message: str, tools: str) -> str: return f"Cutting Knowledge Date: December 2023\nToday Date: {datetime.datetime.now().strftime('%d %b %Y')}\n\n{system_message}" def jinja(self) -> str: diff --git a/agentfly/templates/templates.py b/agentfly/templates/templates.py index 1662e56..3c74346 100644 --- a/agentfly/templates/templates.py +++ b/agentfly/templates/templates.py @@ -324,10 +324,14 @@ def _encode_system_tools(self, tools: List[Dict]) -> str: def _encode_system_message_default(self, tools=None) -> str: if not self.system_policy.use_system_without_system_message: - return "" + if tools is None: + return "" + else: + # If tools are provided, use the system message with tools + pass if self.system_policy.content_processor is not None: - system_message = self.system_policy.content_processor(self.system_message) + system_message = self.system_policy.content_processor(self.system_message, tools=tools) else: system_message = self.system_message @@ -347,7 +351,7 @@ def _encode_system_message(self, content, tools=None) -> str: system_message = content[0]['text'] if self.system_policy.content_processor is not None: - system_message = self.system_policy.content_processor(system_message) + system_message = self.system_policy.content_processor(system_message, tools=tools) if tools is None: return self.system_template.format(system_message=system_message) @@ -433,8 +437,8 @@ def _encode_generation_prompt(self) -> str: def _split_assistant_message(self, assistant_message: str) -> List[str]: # Split the assistant message into generation prefix, content, and generation suffix - _, prefix = self._encode_generation_prompt() - assert assistant_message.startswith(prefix), f"Assistant message {assistant_message} does not start with {generation_prefix}" + generation_prefix, prefix = self._encode_generation_prompt() + assert assistant_message.startswith(prefix), f"Assistant message {assistant_message} does not start with {prefix}" content_suffix = assistant_message[len(prefix):] content
= content_suffix suffix = "" @@ -447,7 +451,7 @@ def _split_assistant_message(self, assistant_message: str) -> List[str]: return prefix, content, suffix - def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str: + def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None, **kwargs) -> str: """Encode the messages to token ids. Args: @@ -466,15 +470,15 @@ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_te if self.supports_vision(): # Use vision-aware encoding with proper alignment - return self._encode_with_vision_processor(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, processor=processor) + return self._encode_with_vision_processor(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, processor=processor, **kwargs) else: # Use standard encoding - return self._encode_standard(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt) + return self._encode_standard(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, **kwargs) - def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False) -> str: + def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, **kwargs) -> str: Logger.debug(f"[Template] Encoding standard for template: {self.name}") """Standard encoding without vision support""" - prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt) + prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs) elements, mask_flags = self._postprocess_elements(elements, roles) input_ids = [] attention_mask = [] @@ -511,7 +515,7 @@ def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, inputs = {k: torch.tensor([v]) for k, v in inputs.items()} return inputs - def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str: + def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None, **kwargs) -> str: Logger.debug(f"[Template] Encoding with vision processor for template: {self.name}") """Encode with vision processor handling proper alignment""" from .vision_processor import get_processor @@ -523,7 +527,7 @@ def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrai raise ValueError(f"No vision processor registered for template: {self.name}") # Get base prompt and mask information - prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt) + prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs) elements, mask_flags = self._postprocess_elements(elements, roles) # Extract vision inputs @@ -689,7 +693,7 @@ def _jinja_header_constants(self) -> List[str]: # Compute default system message considering content processor if self.system_policy.content_processor is 
not None: # Apply content processor to system message - processed_system_message = self.system_policy.content_processor(self.system_message) + processed_system_message = self.system_policy.content_processor(self.system_message, tools=None) # TODO: tools is not used here, but we need to pass it for consistency default_system = self.system_template.format(system_message=processed_system_message) else: default_system = self.system_template.format(system_message=self.system_message) @@ -920,9 +924,9 @@ def _jinja_generation_block(self) -> List[str]: ] - def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = False, tools=None): + def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = False, tools=None, **kwargs): from termcolor import colored - prompt, elements, roles = self.render(messages, add_generation_prompt=add_generation_prompt, tools=tools) + prompt, elements, roles = self.render(messages, add_generation_prompt=add_generation_prompt, tools=tools, **kwargs) elements, mask_flags = self._postprocess_elements(elements, roles) @@ -931,7 +935,7 @@ def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = F if mask_flag: prompt += colored(element, "red") else: - prompt += element + prompt += colored(element, "green") return prompt, elements, mask_flags def set_system_message(self, system_message: str): @@ -940,7 +944,7 @@ def set_system_message(self, system_message: str): def copy(self): - return Template( + return self.__class__( name=self.name, system_template=self.system_template, system_template_with_tools=self.system_template_with_tools, @@ -973,6 +977,203 @@ def dict(self): "video_token": self.video_token, } +class Qwen3Template(Template): + def render(self, messages: List[Dict], tools=None, add_generation_prompt: bool = False, enable_thinking: bool = False) -> str: + """Render the Qwen3 template with special thinking logic. 
+ + Args: + messages: The list of messages + tools: The list of tools + add_generation_prompt: Whether to add the generation prefix + enable_thinking: Whether to enable thinking mode + + Returns: + prompt: The final prompt string + elements: The list of string *elements* that compose the prompt + roles: The corresponding list of *roles* (used by downstream post-processing) + """ + + # Step 1 – decide tool placement & clone messages + work_messages, tools_str, insert_tools_idx = self._insert_tools(messages, tools) + + # Step 2 – clean think content from all assistant messages except the last one + work_messages = self._clean_think_content(work_messages) + + # Step 2.5 – reformat think content in the last assistant message if it exists + if work_messages and work_messages[-1].get("role") == "assistant": + work_messages = self._reformat_last_assistant_think_content(work_messages) + + # Step 3 – encode each conversation turn to text tokens + elements, roles = self._encode_turns(work_messages, tools_str, insert_tools_idx) + + # Step 4 – handle special generation prompt logic for Qwen3 + if add_generation_prompt: + self._maybe_add_generation_prompt_qwen3(elements, roles, enable_thinking, work_messages) + elif work_messages and work_messages[-1].get("role") == "assistant": + # Add empty think tokens to the last assistant message if it doesn't already have think tags + self._add_empty_think_to_last_assistant(elements, roles, work_messages) + + # Concatenate the prompt + prompt = "".join(elements) + return prompt, elements, roles + + def _clean_think_content(self, messages: List[Dict]) -> List[Dict]: + """Remove think content (<think>...</think>) from all assistant messages except the last one.""" + cleaned_messages = [] + for i, message in enumerate(messages): + if message.get("role") == "assistant" and i != len(messages) - 1: + cleaned_message = message.copy() + content = message["content"] + + if isinstance(content, str): + # Remove think content from string + cleaned_content = self._remove_think_tags(content) + else: + # Handle list content format + cleaned_content = [] + for item in content: + if item["type"] == "text": + cleaned_text = self._remove_think_tags(item["text"]) + cleaned_content.append({"type": "text", "text": cleaned_text}) + else: + cleaned_content.append(item) + + cleaned_message["content"] = cleaned_content + cleaned_messages.append(cleaned_message) + else: + cleaned_messages.append(message) + + return cleaned_messages + + def _remove_think_tags(self, text: str) -> str: + """Remove <think>...</think> tags from text.""" + import re + # Remove <think>...</think> tags and their content + pattern = r'<think>.*?</think>'
+ return re.sub(pattern, '', text, flags=re.DOTALL) + + def _has_think_tags(self, text: str) -> bool: + """Check if text contains <think> and </think> tags.""" + return '<think>' in text and '</think>' in text + + def _reformat_think_content(self, text: str) -> str: + """Reformat think content to ensure each think block ends with two newlines.""" + import re + + def replace_think_content(match): + think_content = match.group(1) + # Ensure the think content ends with exactly two newlines + think_content = think_content.rstrip('\n') + return f'<think>\n{think_content}\n</think>\n\n' + + # Find and replace think tags, ensuring proper formatting + pattern = r'<think>(.*?)</think>' + return re.sub(pattern, replace_think_content, text, flags=re.DOTALL) + + def _reformat_last_assistant_think_content(self, messages: List[Dict]) -> List[Dict]: + """Reformat think content in the last assistant message.""" + if not messages or messages[-1].get("role") != "assistant": + return messages + + messages = messages.copy() + last_message = messages[-1].copy() + content = last_message["content"] + + if isinstance(content, str): + # Reformat think content in string + last_message["content"] = self._reformat_think_content(content) + else: + # Handle list content format + reformed_content = [] + for item in content: + if item["type"] == "text": + reformed_text = self._reformat_think_content(item["text"]) + reformed_content.append({"type": "text", "text": reformed_text}) + else: + reformed_content.append(item) + last_message["content"] = reformed_content + + messages[-1] = last_message + return messages + + def _maybe_add_generation_prompt_qwen3(self, elements: List[str], roles: List[Role], enable_thinking: bool, work_messages: List[Dict]): + """Append the generation prefix with special Qwen3 thinking logic.""" + if enable_thinking: + # Use standard generation prompt + generation_prefix, prefix = self._encode_generation_prompt() + elements.append(generation_prefix) + roles.append(Role.ASSISTANT_PREFIX) + else: + # Check if the last message has think tags + has_existing_think = False + if work_messages and work_messages[-1].get("role") == "assistant": + content = work_messages[-1]["content"] + if isinstance(content, str): + has_existing_think = self._has_think_tags(content) + elif isinstance(content, list): + for item in content: + if item.get("type") == "text" and self._has_think_tags(item["text"]): + has_existing_think = True + break + + generation_prefix, prefix = self._encode_generation_prompt() + if has_existing_think: + # Don't add empty think tokens if think tags already exist + elements.append(generation_prefix) + else: + # Add empty think tokens after the generation prefix + elements.append(generation_prefix + "<think>\n\n</think>\n\n") + roles.append(Role.ASSISTANT_PREFIX) + + def _add_empty_think_to_last_assistant(self, elements: List[str], roles: List[Role], work_messages: List[Dict]): + """Add empty think tokens to the last assistant message if it doesn't already have think tags.""" + if not elements or not roles or not work_messages: + return + + # Check if the last message has think tags + has_existing_think = False + if work_messages[-1].get("role") == "assistant": + content = work_messages[-1]["content"] + if isinstance(content, str): + has_existing_think = self._has_think_tags(content) + elif isinstance(content, list): + for item in content: + if item.get("type") == "text" and self._has_think_tags(item["text"]): + has_existing_think = True + break + + # Only add empty think tokens if no existing think tags + if not has_existing_think: + generation_prefix, prefix =
self._encode_generation_prompt() + + # Find the last assistant element + for i in range(len(elements) - 1, -1, -1): + if roles[i] == Role.ASSISTANT: + # Add empty think tokens at the start of the assistant message + elements[i] = prefix + "<think>\n\n</think>\n\n" + elements[i][len(prefix):] + break + + def _split_assistant_message(self, assistant_message: str) -> List[str]: + # Split the assistant message into generation prefix, content, and generation suffix + generation_prefix, prefix = self._encode_generation_prompt() + assert assistant_message.startswith(prefix), f"Assistant message {assistant_message} does not start with {prefix}" + + # We need to detect whether the assistant message starts with empty think tokens + # If so, treat the empty think tokens as part of the prefix (not as assistant content) + if assistant_message.startswith(prefix + "<think>\n\n</think>\n\n"): + prefix = prefix + "<think>\n\n</think>\n\n" + + content_suffix = assistant_message[len(prefix):] + content = content_suffix + suffix = "" + for stop_word in self.stop_words: + if stop_word in content_suffix: + stop_word_index = content_suffix.index(stop_word) + content = content_suffix[:stop_word_index+len(stop_word)] + suffix = content_suffix[stop_word_index+len(stop_word):] + break + return prefix, content, suffix + class Chat: def __init__(self, template: str, messages: List[List[str]]=None, tools=None, tokenizer: PreTrainedTokenizer = None): """ @@ -1026,29 +1227,30 @@ def set_messages(self, messages: List[Dict]): """Set the messages for the chat.""" self.messages = self.convert_to_hf_format_messages(messages) - def prompt(self, add_generation_prompt=False, tools=None) -> str: + def prompt(self, add_generation_prompt=False, tools=None, **kwargs) -> str: """Get the prompt for the chat. Args: add_generation_prompt: Whether to add the generation prompt. tools: The tools to use for the chat. + **kwargs: Additional keyword arguments to pass to the template render method. Returns: The prompt for the chat. """ self.flags['add_generation_prompt'] = add_generation_prompt tools = tools or self.tools - prompt, _, _ = self.template.render(messages=self.messages, tools=tools, add_generation_prompt=add_generation_prompt) + prompt, _, _ = self.template.render(messages=self.messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs) return prompt - def prompt_with_mask(self, add_generation_prompt=False, tools=None) -> str: - prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools) + def prompt_with_mask(self, add_generation_prompt=False, tools=None, **kwargs) -> str: + prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools, **kwargs) return prompt_with_mask def vision_inputs(self) -> List[Any]: return self.template.get_vision_inputs(self.messages) - def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None) -> List[int]: + def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None, **kwargs) -> List[int]: """Tokenize the messages.
Args: @@ -1072,7 +1274,7 @@ def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt= if tools is None: tools = self.tools - return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor) + return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor, **kwargs) def append(self, message: Union[Dict]): self._convert_single_message_to_hf_format(message) @@ -1160,17 +1362,19 @@ def get_template(name: str) -> Template: ) register_template( - Template( + Qwen3Template( name="qwen3", system_template="<|im_start|>system\n{system_message}<|im_end|>\n", + system_template_with_tools="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, "arguments": <args-json-object>}}\n</tool_call><|im_end|>\n""", user_template="<|im_start|>user\n{content}<|im_end|>\n", assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", + tool_template="<|im_start|>user\n<tool_response>\n{observation}\n</tool_response><|im_end|>\n", stop_words=["<|im_end|>"], - generation_prompt="<|im_start|>assistant\n", system_policy=SystemPolicy( - use_system=True, use_system_without_system_message=False, + content_processor=lambda system, tools: f"{system}\n\n" if (system != "" and tools) else system, ), + chat_template="{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif
%}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}", ) ) @@ -1268,7 +1472,7 @@ def get_template(name: str) -> Template: stop_words=["<|eot_id|>"], system_policy=SystemPolicy( use_system=True, - content_processor=lambda x: f"\n{x}", + content_processor=lambda system_message, tools: f"\n{system_message}", ), tool_policy=ToolPolicy( placement=ToolPlacement.SYSTEM, diff --git a/agentfly/templates/utils.py b/agentfly/templates/utils.py index 369b2c2..2761822 100644 --- a/agentfly/templates/utils.py +++ b/agentfly/templates/utils.py @@ -106,6 +106,7 @@ def tokenize_conversation( processor=None, return_tensors="pt", add_generation_prompt=False, + **kwargs, # Additional kwargs for the chat template, e.g. enable_thinking ): """ We want to tokenize the whole conversation.
But we can't just simply :return: input_ids, attention_mask, labels, action_mask """ chat = Chat(template=template, messages=messages, tokenizer=tokenizer) - inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools, processor=processor) + inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools, processor=processor, **kwargs) if max_length is not None: inputs['input_ids'] = inputs['input_ids'][:, :max_length] @@ -329,16 +330,16 @@ def visualize_jinja_template(tokenizer, messages=None, tools=None, **kwargs): def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add_generation_prompt=False, **kwargs): official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs) chat = Chat(template_name, messages=messages, tokenizer=tokenizer) - implemented_prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools) + implemented_prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools, **kwargs) is_equal = official_prompt == implemented_prompt - highlighted_prompt = chat.prompt_with_mask(add_generation_prompt=add_generation_prompt, tools=tools) + highlighted_prompt = chat.prompt_with_mask(add_generation_prompt=add_generation_prompt, tools=tools, **kwargs) plain_highlighted_prompt = strip_ansi(highlighted_prompt) is_equal_between_implemented_prompts = implemented_prompt == plain_highlighted_prompt jinja_template = chat.template.jinja_template() official_jinja_prompt = tokenizer.chat_template tokenizer.chat_template = jinja_template - implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt) + implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs) is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt tokenizer.chat_template = official_jinja_prompt return is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py b/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py index e69de29..f319630 100644 --- a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py +++ b/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py @@ -0,0 +1,54 @@ +from .....templates.utils import compare_hf_template +import pytest +from transformers import AutoTokenizer + +@pytest.mark.parametrize("template", ["qwen3"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "<think> This is test thinking content. </think>I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "<think> This is test thinking content. </think>Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''<tool_call>{"name": "multiply", "arguments": {"x": 3, "y": 5}}</tool_call>'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "<think> This is test thinking content.
</think>I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [True, False]) +@pytest.mark.parametrize("enable_thinking", [True, False]) +def test_chat_template_equal(template, messages, tools, add_generation_prompt, enable_thinking): + # Filter invalid combinations + if add_generation_prompt and messages[-1]['role'] == 'assistant': + return + + template_tokenizer_mapping = { + "qwen3": "Qwen/Qwen3-32B", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) + + is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools, add_generation_prompt=add_generation_prompt, enable_thinking=enable_thinking) + + print(f"Official prompt:\n\n{official_prompt}") + print(f"Implemented prompt:\n\n{implemented_prompt}") + print(f"Highlighted prompt:\n\n{highlighted_prompt}") + assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nenable_thinking: {enable_thinking}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}" + assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}" + # print(f"Official prompt:\n\n{official_prompt}") + # print(f"Highlighted prompt:\n\n{highlighted_prompt}") + + diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py b/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py new file mode 100644 index 0000000..6407642 --- /dev/null +++ b/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py @@ -0,0 +1,69 @@ +""" This file is for testing the tokenization of the templates. The templates should align on the following aspects: + - The tokenized prompt should be the same as the one obtained from the HF template with all the following options: + - add_generation_prompt + - tools + - We need to observe the labels and action_mask to make sure they are correct. + +Since the alignment of the textual prompt is already tested in other files, we only need to test the tokenization of the templates.
+""" + +from .....templates.utils import tokenize_conversation +import pytest +from transformers import AutoTokenizer +import torch +from .....templates.templates import Chat + +@pytest.mark.parametrize("template", ["qwen3"]) +@pytest.mark.parametrize("messages", [ + [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "Want to play a game?"}, + {"role": "assistant", "content": "Sure, what game?"}, + ], + [ + {"role": "user", "content": "Help me to calculate 3 times 5."}, + {"role": "assistant", "content": '''<tool_call>{"name": "multiply", "arguments": {"x": 3, "y": 5}}</tool_call>'''}, + {"role": "tool", "content": "15"}, + ], + [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I am fine, thank you."}, + {"role": "user", "content": "What is 3 times 5?"}, + {"role": "assistant", "content": "15"}, + {"role": "user", "content": "OK, what is 3 times 6?"}, + {"role": "assistant", "content": "<think> This is test thinking content. </think>18"}, + ], +]) +@pytest.mark.parametrize("tools", [ + None, + [ + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}}, + ] +]) +@pytest.mark.parametrize("add_generation_prompt", [False, True]) +@pytest.mark.parametrize("enable_thinking", [True, False]) +def test_template_tokenize(template, messages, tools, add_generation_prompt, enable_thinking): + template_tokenizer_mapping = { + "qwen3": "Qwen/Qwen3-32B", + } + tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True) + + chat = Chat(template, messages, tools=tools) + prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools, enable_thinking=enable_thinking) + + hf_inputs = tokenizer(prompt, return_tensors="pt") + + implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=4096, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt", enable_thinking=enable_thinking) + + assert torch.equal(hf_inputs["input_ids"], implemented_inputs["input_ids"]), f"""template: {template} +messages: {messages} +tools: {tools} +add_generation_prompt: {add_generation_prompt} +enable_thinking: {enable_thinking} +prompt: {prompt} +implemented_prompt: shape: {implemented_inputs['input_ids'].shape} {tokenizer.decode(implemented_inputs['input_ids'][0], skip_special_tokens=False)} +hf_inputs: shape: {hf_inputs['input_ids'].shape} {tokenizer.decode(hf_inputs['input_ids'][0], skip_special_tokens=False)} +implemented_inputs: {implemented_inputs}""" diff --git a/test_qwen3_template.py b/test_qwen3_template.py new file mode 100644 index 0000000..172228c --- /dev/null +++ b/test_qwen3_template.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +Test script for Qwen3Template implementation +""" + +import sys +import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__))) + +from agentfly.templates.templates import Qwen3Template + +def test_qwen3_template(): + """Test the Qwen3Template with various scenarios""" + + # Create a Qwen3Template instance + template = Qwen3Template( + name="qwen3-test", + system_template="<|im_start|>system\n{system_message}<|im_end|>\n", + user_template="<|im_start|>user\n{content}<|im_end|>\n", + assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n", + stop_words=["<|im_end|>"], + generation_prompt="<|im_start|>assistant\n", + ) + + # Test case 1: Basic conversation without thinking + print("=== Test Case 1: Basic conversation without thinking ===") + messages1 = [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you!"} + ] + + prompt1, elements1, roles1 = template.render(messages1, add_generation_prompt=False, enable_thinking=False) + print("Prompt:") + print(prompt1) + print() + + # Test case 2: Conversation with thinking content that should be cleaned + print("=== Test Case 2: Conversation with thinking content (should be cleaned) ===") + messages2 = [ + {"role": "user", "content": "What is 2+2?"}, + {"role": "assistant", "content": "<think>I need to add 2 and 2 together. This is basic arithmetic.</think>The answer is 4."} + ] + + prompt2, elements2, roles2 = template.render(messages2, add_generation_prompt=False, enable_thinking=False) + print("Prompt (thinking content should be removed):") + print(prompt2) + print() + + # Test case 3: With add_generation_prompt=True and enable_thinking=False + print("=== Test Case 3: With generation prompt and enable_thinking=False ===") + messages3 = [ + {"role": "user", "content": "Tell me a joke"} + ] + + prompt3, elements3, roles3 = template.render(messages3, add_generation_prompt=True, enable_thinking=False) + print("Prompt (should include empty think tokens):") + print(prompt3) + print() + + # Test case 4: With add_generation_prompt=True and enable_thinking=True + print("=== Test Case 4: With generation prompt and enable_thinking=True ===") + prompt4, elements4, roles4 = template.render(messages3, add_generation_prompt=True, enable_thinking=True) + print("Prompt (should NOT include empty think tokens):") + print(prompt4) + print() + + # Test case 5: Last message is assistant with enable_thinking=False + print("=== Test Case 5: Last message is assistant with enable_thinking=False ===") + messages5 = [ + {"role": "user", "content": "What's the weather like?"}, + {"role": "assistant", "content": "I don't have access to current weather data."} + ] + + prompt5, elements5, roles5 = template.render(messages5, add_generation_prompt=False, enable_thinking=False) + print("Prompt (last assistant message should have empty think tokens):") + print(prompt5) + print() + + print("All tests completed!") + +if __name__ == "__main__": + test_qwen3_template() +