diff --git a/agentfly/templates/system_policy.py b/agentfly/templates/system_policy.py
index 8ede734..027b900 100644
--- a/agentfly/templates/system_policy.py
+++ b/agentfly/templates/system_policy.py
@@ -5,8 +5,9 @@
@dataclasses.dataclass
class SystemPolicy:
- use_system: bool = True
- use_system_without_system_message: bool = True
+ use_system: bool = True # Global control
+ use_system_without_system_message: bool = True # When no system message is provided, still use the system message even if it is empty (the default one will be used if provided)
+ use_system_with_tools_provided: bool = True # When tools are provided, use the system message with tools even if no system message is provided
content_processor: Callable[[str], str] = None
@@ -32,7 +33,7 @@ class Llama32DateProcessor(SystemContentProcessor):
- Format should be 'dd MMM yyyy' (e.g., '15 Dec 2024')
- No external context variables required
"""
- def __call__(self, system_message: str) -> str:
+ def __call__(self, system_message: str, tools: str) -> str:
return f"Cutting Knowledge Date: December 2023\nToday Date: {datetime.datetime.now().strftime('%d %b %Y')}\n\n{system_message}"
def jinja(self) -> str:
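Reviewer note: `content_processor` callables now receive the tools payload as a second argument. A minimal sketch of a policy written against the new `(system_message, tools)` signature (field values are illustrative, not part of this patch):

```python
# Sketch only: a SystemPolicy whose content_processor follows the new
# (system_message, tools) calling convention introduced in this patch.
from agentfly.templates.system_policy import SystemPolicy

policy = SystemPolicy(
    use_system=True,
    use_system_without_system_message=False,
    use_system_with_tools_provided=True,
    # Append a blank line after the system text only when tools are present,
    # mirroring the qwen3 registration added below.
    content_processor=lambda system_message, tools: (
        f"{system_message}\n\n" if (system_message and tools) else system_message
    ),
)
```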
diff --git a/agentfly/templates/templates.py b/agentfly/templates/templates.py
index 7e54eff..3c74346 100644
--- a/agentfly/templates/templates.py
+++ b/agentfly/templates/templates.py
@@ -64,8 +64,6 @@ class Template:
system_template_with_tools: str = None
# The system message
system_message: str = ""
- # Stop criteria (the default one is EOS token)
- stop_words: Union[str, List[str]] = None
# Behaviors
# The tool template
tool_template: str = None
@@ -74,6 +72,12 @@ class Template:
user_template_with_tools: str = None
# The assistant template
assistant_template: str = None
+
+
+ # Stop criteria (the default one is EOS token)
+ stop_words: Union[str, List[str]] = None
+ # Generation prompt
+ generation_prompt: str = None
# Global policy
global_policy: "GlobalPolicy" = None
# System message policy
@@ -283,7 +287,7 @@ def _maybe_add_generation_prompt(self, elements: List[str], roles: List[Role]):
"""Append the generation prefix so the model knows to continue
generating an assistant response."""
- generation_prefix = self._encode_generation_prompt()
+ generation_prefix, prefix = self._encode_generation_prompt()
elements.append(generation_prefix)
roles.append(Role.ASSISTANT_PREFIX)
@@ -320,10 +324,14 @@ def _encode_system_tools(self, tools: List[Dict]) -> str:
def _encode_system_message_default(self, tools=None) -> str:
if not self.system_policy.use_system_without_system_message:
- return ""
+ if tools is None:
+ return ""
+ else:
+ # If tools are provided, use the system message with tools
+ pass
if self.system_policy.content_processor is not None:
- system_message = self.system_policy.content_processor(self.system_message)
+ system_message = self.system_policy.content_processor(self.system_message, tools=tools)
else:
system_message = self.system_message
@@ -343,7 +351,7 @@ def _encode_system_message(self, content, tools=None) -> str:
system_message = content[0]['text']
if self.system_policy.content_processor is not None:
- system_message = self.system_policy.content_processor(system_message)
+ system_message = self.system_policy.content_processor(system_message, tools=tools)
if tools is None:
return self.system_template.format(system_message=system_message)
@@ -414,27 +422,36 @@ def _encode_tool_message(self, content) -> str:
return tool_message
def _encode_generation_prompt(self) -> str:
+ # Use generation prompt if it is set
if "{content}" in self.assistant_template:
prefix = self.assistant_template.split("{content}")[0]
- return prefix
+ if self.generation_prompt:
+ generation_prompt = self.generation_prompt
+ else:
+ generation_prompt = prefix
else:
raise ValueError(f"Assistant template {self.assistant_template} does not contain {{content}}")
+ return generation_prompt, prefix
+
+
def _split_assistant_message(self, assistant_message: str) -> List[str]:
# Split the assistant message into generation prefix, content, and generation suffix
- generation_prefix = self._encode_generation_prompt()
- assert assistant_message.startswith(generation_prefix), f"Assistant message {assistant_message} does not start with {generation_prefix}"
- content_suffix = assistant_message[len(generation_prefix):]
+ generation_prefix, prefix = self._encode_generation_prompt()
+ assert assistant_message.startswith(prefix), f"Assistant message {assistant_message} does not start with {prefix}"
+ content_suffix = assistant_message[len(prefix):]
+ content = content_suffix
+ suffix = ""
for stop_word in self.stop_words:
if stop_word in content_suffix:
stop_word_index = content_suffix.index(stop_word)
content = content_suffix[:stop_word_index+len(stop_word)]
suffix = content_suffix[stop_word_index+len(stop_word):]
break
- return generation_prefix, content, suffix
+ return prefix, content, suffix
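`_encode_generation_prompt` now returns a pair: the string appended when `add_generation_prompt=True` (overridable via the new `generation_prompt` field) and the raw assistant prefix that `_split_assistant_message` still strips from rendered turns. A rough illustration using the deepseek-r1-distill-qwen values registered later in this diff:

```python
# Illustration only; values are copied from the deepseek-r1-distill-qwen
# registration below rather than computed by the library here.
assistant_template = "<|Assistant|>{content}<|end▁of▁sentence|>"
generation_prompt = "<|Assistant|>\n"  # explicit override via the new field

prefix = assistant_template.split("{content}")[0]  # "<|Assistant|>"

# _encode_generation_prompt() would return (generation_prompt, prefix):
# the first is appended when add_generation_prompt=True, the second is used
# to split an already rendered assistant message into prefix/content/suffix.
assert prefix == "<|Assistant|>"
assert generation_prompt.startswith(prefix)
```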
- def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str:
+ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None, **kwargs) -> str:
"""Encode the messages to token ids.
Args:
@@ -453,15 +470,15 @@ def encode(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_te
if self.supports_vision():
# Use vision-aware encoding with proper alignment
- return self._encode_with_vision_processor(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, processor=processor)
+ return self._encode_with_vision_processor(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, processor=processor, **kwargs)
else:
# Use standard encoding
- return self._encode_standard(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt)
+ return self._encode_standard(messages, tokenizer, return_tensors, tools, add_generation_prompt=add_generation_prompt, **kwargs)
- def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False) -> str:
+ def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, **kwargs) -> str:
Logger.debug(f"[Template] Encoding standard for template: {self.name}")
"""Standard encoding without vision support"""
- prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt)
+ prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs)
elements, mask_flags = self._postprocess_elements(elements, roles)
input_ids = []
attention_mask = []
@@ -498,7 +515,7 @@ def _encode_standard(self, messages: List[Dict], tokenizer: PreTrainedTokenizer,
inputs = {k: torch.tensor([v]) for k, v in inputs.items()}
return inputs
- def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None) -> str:
+ def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrainedTokenizer, return_tensors: str = None, tools=None, add_generation_prompt=False, processor=None, **kwargs) -> str:
Logger.debug(f"[Template] Encoding with vision processor for template: {self.name}")
"""Encode with vision processor handling proper alignment"""
from .vision_processor import get_processor
@@ -510,7 +527,7 @@ def _encode_with_vision_processor(self, messages: List[Dict], tokenizer: PreTrai
raise ValueError(f"No vision processor registered for template: {self.name}")
# Get base prompt and mask information
- prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt)
+ prompt, elements, roles = self.render(messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs)
elements, mask_flags = self._postprocess_elements(elements, roles)
# Extract vision inputs
@@ -676,7 +693,7 @@ def _jinja_header_constants(self) -> List[str]:
# Compute default system message considering content processor
if self.system_policy.content_processor is not None:
# Apply content processor to system message
- processed_system_message = self.system_policy.content_processor(self.system_message)
+ processed_system_message = self.system_policy.content_processor(self.system_message, tools=None) # TODO: tools is not used here, but we need to pass it for consistency
default_system = self.system_template.format(system_message=processed_system_message)
else:
default_system = self.system_template.format(system_message=self.system_message)
@@ -907,9 +924,9 @@ def _jinja_generation_block(self) -> List[str]:
]
- def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = False, tools=None):
+ def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = False, tools=None, **kwargs):
from termcolor import colored
- prompt, elements, roles = self.render(messages, add_generation_prompt=add_generation_prompt, tools=tools)
+ prompt, elements, roles = self.render(messages, add_generation_prompt=add_generation_prompt, tools=tools, **kwargs)
elements, mask_flags = self._postprocess_elements(elements, roles)
@@ -918,7 +935,7 @@ def render_with_mask(self, messages: List[Dict], add_generation_prompt: bool = F
if mask_flag:
prompt += colored(element, "red")
else:
- prompt += element
+ prompt += colored(element, "green")
return prompt, elements, mask_flags
def set_system_message(self, system_message: str):
@@ -927,7 +944,7 @@ def set_system_message(self, system_message: str):
def copy(self):
- return Template(
+ return self.__class__(
name=self.name,
system_template=self.system_template,
system_template_with_tools=self.system_template_with_tools,
@@ -937,6 +954,7 @@ def copy(self):
assistant_template=self.assistant_template,
tool_template=self.tool_template,
stop_words=self.stop_words,
+ generation_prompt=self.generation_prompt,
vision_start=self.vision_start,
vision_end=self.vision_end,
image_token=self.image_token,
@@ -959,6 +977,203 @@ def dict(self):
"video_token": self.video_token,
}
+class Qwen3Template(Template):
+ def render(self, messages: List[Dict], tools=None, add_generation_prompt: bool = False, enable_thinking: bool = False) -> str:
+ """Render the Qwen3 template with special thinking logic.
+
+ Args:
+ messages: The list of messages
+ tools: The list of tools
+ add_generation_prompt: Whether to add the generation prefix
+ enable_thinking: Whether to enable thinking mode
+
+ Returns:
+ prompt: The final prompt string
+ elements: The list of string *elements* that compose the prompt
+ roles: The corresponding list of *roles* (used by downstream post-processing)
+ """
+
+ # Step 1 – decide tool placement & clone messages
+ work_messages, tools_str, insert_tools_idx = self._insert_tools(messages, tools)
+
+ # Step 2 – clean think content from all assistant messages except the last one
+ work_messages = self._clean_think_content(work_messages)
+
+ # Step 2.5 – reformat think content in the last assistant message if it exists
+ if work_messages and work_messages[-1].get("role") == "assistant":
+ work_messages = self._reformat_last_assistant_think_content(work_messages)
+
+ # Step 3 – encode each conversation turn to text tokens
+ elements, roles = self._encode_turns(work_messages, tools_str, insert_tools_idx)
+
+ # Step 4 – handle special generation prompt logic for Qwen3
+ if add_generation_prompt:
+ self._maybe_add_generation_prompt_qwen3(elements, roles, enable_thinking, work_messages)
+ elif work_messages and work_messages[-1].get("role") == "assistant":
+ # Add empty think tokens to the last assistant message if it doesn't already have think tags
+ self._add_empty_think_to_last_assistant(elements, roles, work_messages)
+
+ # Concatenate the prompt
+ prompt = "".join(elements)
+ return prompt, elements, roles
+
+ def _clean_think_content(self, messages: List[Dict]) -> List[Dict]:
+ """Remove all think content (<think>...</think>) from assistant messages and reformat existing think content."""
+ cleaned_messages = []
+ for i, message in enumerate(messages):
+ if message.get("role") == "assistant" and i != len(messages) - 1:
+ cleaned_message = message.copy()
+ content = message["content"]
+
+ if isinstance(content, str):
+ # Remove think content from string
+ cleaned_content = self._remove_think_tags(content)
+ else:
+ # Handle list content format
+ cleaned_content = []
+ for item in content:
+ if item["type"] == "text":
+ cleaned_text = self._remove_think_tags(item["text"])
+ cleaned_content.append({"type": "text", "text": cleaned_text})
+ else:
+ cleaned_content.append(item)
+
+ cleaned_message["content"] = cleaned_content
+ cleaned_messages.append(cleaned_message)
+ else:
+ cleaned_messages.append(message)
+
+ return cleaned_messages
+
+ def _remove_think_tags(self, text: str) -> str:
+ """Remove <think>...</think> tags from text."""
+ import re
+ # Remove <think>...</think> tags and their content
+ pattern = r'<think>.*?</think>'
+ return re.sub(pattern, '', text, flags=re.DOTALL)
+
+ def _has_think_tags(self, text: str) -> bool:
+ """Check if text contains <think> and </think> tags."""
+ return '<think>' in text and '</think>' in text
+
+ def _reformat_think_content(self, text: str) -> str:
+ """Reformat think content to ensure each think token ends with two newlines."""
+ import re
+
+ def replace_think_content(match):
+ think_content = match.group(1)
+ # Ensure the think content ends with exactly two newlines
+ think_content = think_content.rstrip('\n')
+ return f'<think>\n{think_content}\n</think>\n\n'
+
+ # Find and replace think tags, ensuring proper formatting
+ pattern = r'<think>(.*?)</think>'
+ return re.sub(pattern, replace_think_content, text, flags=re.DOTALL)
+
+ def _reformat_last_assistant_think_content(self, messages: List[Dict]) -> List[Dict]:
+ """Reformat think content in the last assistant message."""
+ if not messages or messages[-1].get("role") != "assistant":
+ return messages
+
+ messages = messages.copy()
+ last_message = messages[-1].copy()
+ content = last_message["content"]
+
+ if isinstance(content, str):
+ # Reformat think content in string
+ last_message["content"] = self._reformat_think_content(content)
+ else:
+ # Handle list content format
+ reformed_content = []
+ for item in content:
+ if item["type"] == "text":
+ reformed_text = self._reformat_think_content(item["text"])
+ reformed_content.append({"type": "text", "text": reformed_text})
+ else:
+ reformed_content.append(item)
+ last_message["content"] = reformed_content
+
+ messages[-1] = last_message
+ return messages
+
+ def _maybe_add_generation_prompt_qwen3(self, elements: List[str], roles: List[Role], enable_thinking: bool, work_messages: List[Dict]):
+ """Append the generation prefix with special Qwen3 thinking logic."""
+ if enable_thinking:
+ # Use standard generation prompt
+ generation_prefix, prefix = self._encode_generation_prompt()
+ elements.append(generation_prefix)
+ roles.append(Role.ASSISTANT_PREFIX)
+ else:
+ # Check if the last message has think tags
+ has_existing_think = False
+ if work_messages and work_messages[-1].get("role") == "assistant":
+ content = work_messages[-1]["content"]
+ if isinstance(content, str):
+ has_existing_think = self._has_think_tags(content)
+ elif isinstance(content, list):
+ for item in content:
+ if item.get("type") == "text" and self._has_think_tags(item["text"]):
+ has_existing_think = True
+ break
+
+ generation_prefix, prefix = self._encode_generation_prompt()
+ if has_existing_think:
+ # Don't add empty think tokens if think tags already exist
+ elements.append(generation_prefix)
+ else:
+ # Add empty think tokens after the generation prefix
+ elements.append(generation_prefix + "<think>\n\n</think>\n\n")
+ roles.append(Role.ASSISTANT_PREFIX)
+
+ def _add_empty_think_to_last_assistant(self, elements: List[str], roles: List[Role], work_messages: List[Dict]):
+ """Add empty think tokens to the last assistant message if it doesn't already have think tags."""
+ if not elements or not roles or not work_messages:
+ return
+
+ # Check if the last message has think tags
+ has_existing_think = False
+ if work_messages[-1].get("role") == "assistant":
+ content = work_messages[-1]["content"]
+ if isinstance(content, str):
+ has_existing_think = self._has_think_tags(content)
+ elif isinstance(content, list):
+ for item in content:
+ if item.get("type") == "text" and self._has_think_tags(item["text"]):
+ has_existing_think = True
+ break
+
+ # Only add empty think tokens if no existing think tags
+ if not has_existing_think:
+ generation_prefix, prefix = self._encode_generation_prompt()
+
+ # Find the last assistant element
+ for i in range(len(elements) - 1, -1, -1):
+ if roles[i] == Role.ASSISTANT:
+ # Add empty think tokens at the start of the assistant message
+ elements[i] = prefix + "<think>\n\n</think>\n\n" + elements[i][len(prefix):]
+ break
+
+ def _split_assistant_message(self, assistant_message: str) -> List[str]:
+ # Split the assistant message into generation prefix, content, and generation suffix
+ generation_prefix, prefix = self._encode_generation_prompt()
+ assert assistant_message.startswith(prefix), f"Assistant message {assistant_message} does not start with {prefix}"
+
+ # We need to detect whether the assistant message starts with empty think tokens
+ # If so, we need to set empty think tokens as non-assistant message
+ if assistant_message.startswith(prefix + "<think>\n\n</think>\n\n"):
+ prefix = prefix + "<think>\n\n</think>\n\n"
+
+ content_suffix = assistant_message[len(prefix):]
+ content = content_suffix
+ suffix = ""
+ for stop_word in self.stop_words:
+ if stop_word in content_suffix:
+ stop_word_index = content_suffix.index(stop_word)
+ content = content_suffix[:stop_word_index+len(stop_word)]
+ suffix = content_suffix[stop_word_index+len(stop_word):]
+ break
+ return prefix, content, suffix
+
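For reference, a standalone sketch of what the think-tag helpers above do to an assistant turn; the regexes are re-implemented outside the class purely for illustration:

```python
import re

text = "<think>Let me add the numbers.</think>The answer is 4."

# _remove_think_tags: earlier history turns keep only the visible reply.
removed = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL)
assert removed == "The answer is 4."

# _reformat_think_content: the think block is rewrapped so the closing
# </think> is followed by exactly two newlines before the visible reply.
reformatted = re.sub(
    r"<think>(.*?)</think>",
    lambda m: f"<think>\n{m.group(1).rstrip(chr(10))}\n</think>\n\n",
    text,
    flags=re.DOTALL,
)
assert reformatted == "<think>\nLet me add the numbers.\n</think>\n\nThe answer is 4."
```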
class Chat:
def __init__(self, template: str, messages: List[List[str]]=None, tools=None, tokenizer: PreTrainedTokenizer = None):
"""
@@ -1012,29 +1227,30 @@ def set_messages(self, messages: List[Dict]):
"""Set the messages for the chat."""
self.messages = self.convert_to_hf_format_messages(messages)
- def prompt(self, add_generation_prompt=False, tools=None) -> str:
+ def prompt(self, add_generation_prompt=False, tools=None, **kwargs) -> str:
"""Get the prompt for the chat.
Args:
add_generation_prompt: Whether to add the generation prompt.
tools: The tools to use for the chat.
+ **kwargs: Additional keyword arguments to pass to the template render method.
Returns:
The prompt for the chat.
"""
self.flags['add_generation_prompt'] = add_generation_prompt
tools = tools or self.tools
- prompt, _, _ = self.template.render(messages=self.messages, tools=tools, add_generation_prompt=add_generation_prompt)
+ prompt, _, _ = self.template.render(messages=self.messages, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs)
return prompt
- def prompt_with_mask(self, add_generation_prompt=False, tools=None) -> str:
- prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools)
+ def prompt_with_mask(self, add_generation_prompt=False, tools=None, **kwargs) -> str:
+ prompt_with_mask, _, _ = self.template.render_with_mask(messages=self.messages, add_generation_prompt=add_generation_prompt, tools=tools, **kwargs)
return prompt_with_mask
def vision_inputs(self) -> List[Any]:
return self.template.get_vision_inputs(self.messages)
- def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None) -> List[int]:
+ def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=False, tools=None, processor=None, **kwargs) -> List[int]:
"""Tokenize the messages.
Args:
@@ -1058,7 +1274,7 @@ def tokenize(self, tokenizer: PreTrainedTokenizer = None, add_generation_prompt=
if tools is None:
tools = self.tools
- return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor)
+ return self.template.encode(messages=self.messages, tokenizer=tokenizer, return_tensors="pt", tools=tools, add_generation_prompt=add_generation_prompt, processor=processor, **kwargs)
def append(self, message: Union[Dict]):
self._convert_single_message_to_hf_format(message)
@@ -1145,6 +1361,23 @@ def get_template(name: str) -> Template:
)
)
+register_template(
+ Qwen3Template(
+ name="qwen3",
+ system_template="<|im_start|>system\n{system_message}<|im_end|>\n",
+ system_template_with_tools="""<|im_start|>system\n{system_message}# Tools\n\nYou may call one or more functions to assist with the user query.\n\nYou are provided with function signatures within <tools></tools> XML tags:\n<tools>\n{tools}\n</tools>\n\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\n<tool_call>\n{{"name": <function-name>, "arguments": <args-json-object>}}\n</tool_call><|im_end|>\n""",
+ user_template="<|im_start|>user\n{content}<|im_end|>\n",
+ assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
+ tool_template="<|im_start|>user\n<tool_response>\n{observation}\n</tool_response><|im_end|>\n",
+ stop_words=["<|im_end|>"],
+ system_policy=SystemPolicy(
+ use_system_without_system_message=False,
+ content_processor=lambda system, tools: f"{system}\n\n" if (system != "" and tools) else system,
+ ),
+ chat_template="{%- if tools %}\n {{- '<|im_start|>system\\n' }}\n {%- if messages[0].role == 'system' %}\n {{- messages[0].content + '\\n\\n' }}\n {%- endif %}\n {{- \"# Tools\\n\\nYou may call one or more functions to assist with the user query.\\n\\nYou are provided with function signatures within <tools></tools> XML tags:\\n<tools>\" }}\n {%- for tool in tools %}\n {{- \"\\n\" }}\n {{- tool | tojson }}\n {%- endfor %}\n {{- \"\\n</tools>\\n\\nFor each function call, return a json object with function name and arguments within <tool_call></tool_call> XML tags:\\n<tool_call>\\n{\\\"name\\\": <function-name>, \\\"arguments\\\": <args-json-object>}\\n</tool_call><|im_end|>\\n\" }}\n{%- else %}\n {%- if messages[0].role == 'system' %}\n {{- '<|im_start|>system\\n' + messages[0].content + '<|im_end|>\\n' }}\n {%- endif %}\n{%- endif %}\n{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}\n{%- for message in messages[::-1] %}\n {%- set index = (messages|length - 1) - loop.index0 %}\n {%- if ns.multi_step_tool and message.role == \"user\" and message.content is string and not(message.content.startswith('<tool_response>') and message.content.endswith('</tool_response>')) %}\n {%- set ns.multi_step_tool = false %}\n {%- set ns.last_query_index = index %}\n {%- endif %}\n{%- endfor %}\n{%- for message in messages %}\n {%- if message.content is string %}\n {%- set content = message.content %}\n {%- else %}\n {%- set content = '' %}\n {%- endif %}\n {%- if (message.role == \"user\") or (message.role == \"system\" and not loop.first) %}\n {{- '<|im_start|>' + message.role + '\\n' + content + '<|im_end|>' + '\\n' }}\n {%- elif message.role == \"assistant\" %}\n {%- set reasoning_content = '' %}\n {%- if message.reasoning_content is string %}\n {%- set reasoning_content = message.reasoning_content %}\n {%- else %}\n {%- if '</think>' in content %}\n {%- set reasoning_content = content.split('</think>')[0].rstrip('\\n').split('<think>')[-1].lstrip('\\n') %}\n {%- set content = content.split('</think>')[-1].lstrip('\\n') %}\n {%- endif %}\n {%- endif %}\n {%- if loop.index0 > ns.last_query_index %}\n {%- if loop.last or (not loop.last and reasoning_content) %}\n {{- '<|im_start|>' + message.role + '\\n<think>\\n' + reasoning_content.strip('\\n') + '\\n</think>\\n\\n' + content.lstrip('\\n') }}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- else %}\n {{- '<|im_start|>' + message.role + '\\n' + content }}\n {%- endif %}\n {%- if message.tool_calls %}\n {%- for tool_call in message.tool_calls %}\n {%- if (loop.first and content) or (not loop.first) %}\n {{- '\\n' }}\n {%- endif %}\n {%- if tool_call.function %}\n {%- set tool_call = tool_call.function %}\n {%- endif %}\n {{- '<tool_call>\\n{\"name\": \"' }}\n {{- tool_call.name }}\n {{- '\", \"arguments\": ' }}\n {%- if tool_call.arguments is string %}\n {{- tool_call.arguments }}\n {%- else %}\n {{- tool_call.arguments | tojson }}\n {%- endif %}\n {{- '}\\n</tool_call>' }}\n {%- endfor %}\n {%- endif %}\n {{- '<|im_end|>\\n' }}\n {%- elif message.role == \"tool\" %}\n {%- if loop.first or (messages[loop.index0 - 1].role != \"tool\") %}\n {{- '<|im_start|>user' }}\n {%- endif %}\n {{- '\\n<tool_response>\\n' }}\n {{- content }}\n {{- '\\n</tool_response>' }}\n {%- if loop.last or (messages[loop.index0 + 1].role != \"tool\") %}\n {{- '<|im_end|>\\n' }}\n {%- endif %}\n {%- endif %}\n{%- endfor %}\n{%- if add_generation_prompt %}\n {{- '<|im_start|>assistant\\n' }}\n {%- if enable_thinking is defined and enable_thinking is false %}\n {{- '<think>\\n\\n</think>\\n\\n' }}\n {%- endif %}\n{%- endif %}",
+ )
+)
+
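A quick usage sketch of the new registration; `enable_thinking` reaches `Qwen3Template.render()` through the `**kwargs` plumbing added above:

```python
from agentfly.templates.templates import Chat

chat = Chat(template="qwen3", messages=[
    {"role": "user", "content": "What is 3 times 5?"},
])
# With thinking disabled, the generation prompt is followed by an empty
# "<think>\n\n</think>\n\n" block, mirroring the official Qwen3 chat template.
prompt = chat.prompt(add_generation_prompt=True, enable_thinking=False)
print(prompt)
```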
register_template(
Template(
name="deepseek-prover",
@@ -1239,7 +1472,7 @@ def get_template(name: str) -> Template:
stop_words=["<|eot_id|>"],
system_policy=SystemPolicy(
use_system=True,
- content_processor=lambda x: f"\n{x}",
+ content_processor=lambda system_message, tools: f"\n{system_message}",
),
tool_policy=ToolPolicy(
placement=ToolPlacement.SYSTEM,
@@ -1249,5 +1482,22 @@ def get_template(name: str) -> Template:
)
)
+register_template(
+ Template(
+ name="deepseek-r1-distill-qwen",
+ user_template="<|User|>{content}",
+ assistant_template="<|Assistant|>{content}<|end▁of▁sentence|>",
+ stop_words=["<|end▁of▁sentence|>"],
+ generation_prompt="<|Assistant|>\n",
+ global_policy=GlobalPolicy(
+ prefix="<|begin▁of▁sentence|>"
+ ),
+ system_policy=SystemPolicy(
+ use_system=False,
+ use_system_without_system_message=False,
+ ),
+ )
+)
+
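Sketch of the distilled-R1 registration in use: `generation_prompt` overrides the bare assistant prefix, so `add_generation_prompt` appends `<|Assistant|>\n` (with the trailing newline) rather than `<|Assistant|>` alone. The output shown in the comment is indicative, not asserted:

```python
from agentfly.templates.templates import Chat

chat = Chat(template="deepseek-r1-distill-qwen", messages=[
    {"role": "user", "content": "What is 3 times 5?"},
])
print(chat.prompt(add_generation_prompt=True))
# Expected roughly: <|begin▁of▁sentence|><|User|>What is 3 times 5?<|Assistant|>\n
```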
if __name__ == "__main__":
pass
\ No newline at end of file
diff --git a/agentfly/templates/utils.py b/agentfly/templates/utils.py
index 441f705..2761822 100644
--- a/agentfly/templates/utils.py
+++ b/agentfly/templates/utils.py
@@ -106,6 +106,7 @@ def tokenize_conversation(
processor=None,
return_tensors="pt",
add_generation_prompt=False,
+ **kwargs, # Additional kwargs for the chat template, e.g. enable_thinking
):
"""
We want to tokenize the whole conversation. But we can't just simply
@@ -122,7 +123,7 @@ def tokenize_conversation(
:return: input_ids, attention_mask, labels, action_mask
"""
chat = Chat(template=template, messages=messages, tokenizer=tokenizer)
- inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools, processor=processor)
+ inputs = chat.tokenize(tokenizer, add_generation_prompt=add_generation_prompt, tools=tools, processor=processor, **kwargs)
if max_length is not None:
inputs['input_ids'] = inputs['input_ids'][:, :max_length]
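With the `**kwargs` pass-through, template-specific switches such as `enable_thinking` now flow from `tokenize_conversation` through `Chat.tokenize` into the render call. A small sketch (checkpoint id taken from the tests below):

```python
from transformers import AutoTokenizer
from agentfly.templates.utils import tokenize_conversation

tok = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B", trust_remote_code=True)
inputs = tokenize_conversation(
    [{"role": "user", "content": "Hello"}], tok, "qwen3",
    tools=None, add_generation_prompt=True, enable_thinking=False,
)
# inputs carries input_ids / attention_mask / labels / action_mask tensors.
```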
@@ -326,19 +327,21 @@ def visualize_jinja_template(tokenizer, messages=None, tools=None, **kwargs):
prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, **kwargs)
print(prompt)
-def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add_generation_prompt=False):
- official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt)
- chat = Chat(template_name, messages, tokenizer)
- implemented_prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools)
+def compare_hf_template(tokenizer, template_name, messages=None, tools=None, add_generation_prompt=False, **kwargs):
+ official_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs)
+ chat = Chat(template_name, messages=messages, tokenizer=tokenizer)
+ implemented_prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools, **kwargs)
is_equal = official_prompt == implemented_prompt
- highlighted_prompt = chat.prompt_with_mask(add_generation_prompt=add_generation_prompt, tools=tools)
+ highlighted_prompt = chat.prompt_with_mask(add_generation_prompt=add_generation_prompt, tools=tools, **kwargs)
plain_highlighted_prompt = strip_ansi(highlighted_prompt)
is_equal_between_implemented_prompts = implemented_prompt == plain_highlighted_prompt
jinja_template = chat.template.jinja_template()
+ official_jinja_prompt = tokenizer.chat_template
tokenizer.chat_template = jinja_template
- implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt)
+ implemented_jinja_prompt = tokenizer.apply_chat_template(messages, tokenize=False, tools=tools, add_generation_prompt=add_generation_prompt, **kwargs)
is_equal_between_jinja_prompts = implemented_jinja_prompt == implemented_prompt
+ tokenizer.chat_template = official_jinja_prompt
return is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt
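Usage sketch for the extended comparison helper (which now also restores the tokenizer's original `chat_template` after the jinja round-trip); the checkpoint id is the one used by the tests below:

```python
from transformers import AutoTokenizer
from agentfly.templates.utils import compare_hf_template

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-32B", trust_remote_code=True)
messages = [{"role": "user", "content": "Hello, how are you?"}]

results = compare_hf_template(
    tokenizer, "qwen3", messages=messages, tools=None,
    add_generation_prompt=True, enable_thinking=False,
)
is_equal, is_equal_jinja = results[0], results[2]
# is_equal: our rendered prompt matches tokenizer.apply_chat_template;
# is_equal_jinja: the generated jinja template reproduces the same prompt.
```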
diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py b/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py
index e69de29..f319630 100644
--- a/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py
+++ b/agentfly/tests/unit/agents/templates/test_qwen3_prompt.py
@@ -0,0 +1,54 @@
+from .....templates.utils import compare_hf_template
+import pytest
+from transformers import AutoTokenizer
+
+@pytest.mark.parametrize("template", ["qwen3"])
+@pytest.mark.parametrize("messages", [
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "<think> This is test thinking content. </think> I am fine, thank you."},
+ {"role": "user", "content": "Want to play a game?"},
+ {"role": "assistant", "content": "<think> This is test thinking content. </think> Sure, what game?"},
+ ],
+ [
+ {"role": "user", "content": "Help me to calculate 3 times 5."},
+ {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
+ {"role": "tool", "content": "15"},
+ ],
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "<think> This is test thinking content. </think> I am fine, thank you."},
+ {"role": "user", "content": "What is 3 times 5?"},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+ [
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ ]
+])
+@pytest.mark.parametrize("add_generation_prompt", [True, False])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_chat_template_equal(template, messages, tools, add_generation_prompt, enable_thinking):
+ # Filter invalid combinations
+ if add_generation_prompt and messages[-1]['role'] == 'assistant':
+ return
+
+ template_tokenizer_mapping = {
+ "qwen3": "Qwen/Qwen3-32B",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True)
+
+ is_equal, is_equal_between_implemented_prompts, is_equal_between_jinja_prompts, official_prompt, implemented_prompt, implemented_jinja_prompt, highlighted_prompt = compare_hf_template(tokenizer, template, messages=messages, tools=tools,add_generation_prompt=add_generation_prompt, enable_thinking=enable_thinking)
+
+ print(f"Official prompt:\n\n{official_prompt}")
+ print(f"Implemented prompt:\n\n{implemented_prompt}")
+ print(f"Highlighted prompt:\n\n{highlighted_prompt}")
+ assert is_equal, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nenable_thinking: {enable_thinking}\n\nOfficial prompt:\n\n{official_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}"
+ assert is_equal_between_jinja_prompts, f"Template: {template}\n\nMessages: {messages}\n\ntools: {tools}\n\nadd_generation_prompt: {add_generation_prompt}\n\nImplemented prompt:\n\n{implemented_prompt}\n\nJinja prompt:\n\n{implemented_jinja_prompt}"
+ # print(f"Official prompt:\n\n{official_prompt}")
+ # print(f"Highlighted prompt:\n\n{highlighted_prompt}")
+
+
diff --git a/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py b/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py
new file mode 100644
index 0000000..6407642
--- /dev/null
+++ b/agentfly/tests/unit/agents/templates/test_qwen3_tokenize.py
@@ -0,0 +1,69 @@
+""" This file is for testing the tokenization of the templates. The templates should align on following aspects:
+ - The tokenized prompt should be the same as the one obtained from HF template with all the following options:
+ - add_generation_prompt
+ - tools
+ - We need to observe the labels and action_mask to make sure they are correct.
+
+Since the alignment of the textual prompt is already tested in other files, we only need to test the tokenization of the templates here.
+"""
+
+from .....templates.utils import tokenize_conversation
+import pytest
+from transformers import AutoTokenizer
+import torch
+from .....templates.templates import Chat
+
+@pytest.mark.parametrize("template", ["qwen3"])
+@pytest.mark.parametrize("messages", [
+ [
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "Want to play a game?"},
+ {"role": "assistant", "content": "Sure, what game?"},
+ ],
+ [
+ {"role": "user", "content": "Help me to calculate 3 times 5."},
+ {"role": "assistant", "content": '''{"name": "multiply", "arguments": {"x": 3, "y": 5}}'''},
+ {"role": "tool", "content": "15"},
+ ],
+ [
+ {"role": "system", "content": "You are a helpful assistant."},
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I am fine, thank you."},
+ {"role": "user", "content": "What is 3 times 5?"},
+ {"role": "assistant", "content": "15"},
+ {"role": "user", "content": "OK, what is 3 times 6?"},
+ {"role": "assistant", "content": "<think> This is test thinking content. </think> 18"},
+ ],
+])
+@pytest.mark.parametrize("tools", [
+ None,
+ [
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ {"type": "function", "function": {"name": "multiply", "description": "A function that multiplies two numbers", "parameters": {"type": "object", "properties": {"x": {"type": "number", "description": "The first number to multiply"}, "y": {"type": "number", "description": "The second number to multiply"}}, "required": ["x", "y"]}}},
+ ]
+])
+@pytest.mark.parametrize("add_generation_prompt", [False, True])
+@pytest.mark.parametrize("enable_thinking", [True, False])
+def test_template_tokenize(template, messages, tools, add_generation_prompt, enable_thinking):
+ template_tokenizer_mapping = {
+ "qwen3": "Qwen/Qwen3-32B",
+ }
+ tokenizer = AutoTokenizer.from_pretrained(template_tokenizer_mapping[template], trust_remote_code=True)
+
+ chat = Chat(template, messages, tools=tools)
+ prompt = chat.prompt(add_generation_prompt=add_generation_prompt, tools=tools, enable_thinking=enable_thinking)
+
+ hf_inputs = tokenizer(prompt, return_tensors="pt")
+
+ implemented_inputs = tokenize_conversation(messages, tokenizer, template, max_length=4096, tools=tools, add_generation_prompt=add_generation_prompt, return_tensors="pt", enable_thinking=enable_thinking)
+
+ assert torch.equal(hf_inputs["input_ids"], implemented_inputs["input_ids"]), f"""template: {template}
+messages: {messages}
+tools: {tools}
+add_generation_prompt: {add_generation_prompt}
+enable_thinking: {enable_thinking}
+prompt: {prompt}
+implemented_prompt: shape: {implemented_inputs['input_ids'].shape} {tokenizer.decode(implemented_inputs['input_ids'][0], skip_special_tokens=False)}
+hf_inputs: shape: {hf_inputs['input_ids'].shape} {tokenizer.decode(hf_inputs['input_ids'][0], skip_special_tokens=False)}
+implemented_inputs: {implemented_inputs}"""
diff --git a/test_qwen3_template.py b/test_qwen3_template.py
new file mode 100644
index 0000000..172228c
--- /dev/null
+++ b/test_qwen3_template.py
@@ -0,0 +1,83 @@
+#!/usr/bin/env python3
+"""
+Test script for Qwen3Template implementation
+"""
+
+import sys
+import os
+sys.path.append(os.path.dirname(os.path.abspath(__file__)))
+
+from agentfly.templates.templates import Qwen3Template
+
+def test_qwen3_template():
+ """Test the Qwen3Template with various scenarios"""
+
+ # Create a Qwen3Template instance
+ template = Qwen3Template(
+ name="qwen3-test",
+ system_template="<|im_start|>system\n{system_message}<|im_end|>\n",
+ user_template="<|im_start|>user\n{content}<|im_end|>\n",
+ assistant_template="<|im_start|>assistant\n{content}<|im_end|>\n",
+ stop_words=["<|im_end|>"],
+ generation_prompt="<|im_start|>assistant\n",
+ )
+
+ # Test case 1: Basic conversation without thinking
+ print("=== Test Case 1: Basic conversation without thinking ===")
+ messages1 = [
+ {"role": "user", "content": "Hello, how are you?"},
+ {"role": "assistant", "content": "I'm doing well, thank you!"}
+ ]
+
+ prompt1, elements1, roles1 = template.render(messages1, add_generation_prompt=False, enable_thinking=False)
+ print("Prompt:")
+ print(prompt1)
+ print()
+
+ # Test case 2: Conversation with thinking content that should be cleaned
+ print("=== Test Case 2: Conversation with thinking content (should be cleaned) ===")
+ messages2 = [
+ {"role": "user", "content": "What is 2+2?"},
+ {"role": "assistant", "content": "<think>I need to add 2 and 2 together. This is basic arithmetic.</think>The answer is 4."}
+ ]
+
+ prompt2, elements2, roles2 = template.render(messages2, add_generation_prompt=False, enable_thinking=False)
+ print("Prompt (thinking content should be removed):")
+ print(prompt2)
+ print()
+
+ # Test case 3: With add_generation_prompt=True and enable_thinking=False
+ print("=== Test Case 3: With generation prompt and enable_thinking=False ===")
+ messages3 = [
+ {"role": "user", "content": "Tell me a joke"}
+ ]
+
+ prompt3, elements3, roles3 = template.render(messages3, add_generation_prompt=True, enable_thinking=False)
+ print("Prompt (should include empty think tokens):")
+ print(prompt3)
+ print()
+
+ # Test case 4: With add_generation_prompt=True and enable_thinking=True
+ print("=== Test Case 4: With generation prompt and enable_thinking=True ===")
+ prompt4, elements4, roles4 = template.render(messages3, add_generation_prompt=True, enable_thinking=True)
+ print("Prompt (should NOT include empty think tokens):")
+ print(prompt4)
+ print()
+
+ # Test case 5: Last message is assistant with enable_thinking=False
+ print("=== Test Case 5: Last message is assistant with enable_thinking=False ===")
+ messages5 = [
+ {"role": "user", "content": "What's the weather like?"},
+ {"role": "assistant", "content": "I don't have access to current weather data."}
+ ]
+
+ prompt5, elements5, roles5 = template.render(messages5, add_generation_prompt=False, enable_thinking=False)
+ print("Prompt (last assistant message should have empty think tokens):")
+ print(prompt5)
+ print()
+
+ print("All tests completed!")
+
+if __name__ == "__main__":
+ test_qwen3_template()
+