| import json |
| import re |
| import warnings |
| from abc import ABC, abstractmethod |
| from typing import Dict, List, Tuple |
|
|
| import torch |
| from transformers import PreTrainedTokenizer |
|
|
| from .template import ChatTemplate |
|
|
# Public API of this module.
# NOTE(review): ThinkingParser is defined below but not exported here —
# confirm whether that omission is intentional.
__all__ = ["GeneralParser", "HarmonyParser"]
|
|
|
|
class Parser(ABC):
    """Abstract base class for conversation parsers.

    A parser renders a chat-style conversation into token ids plus a loss
    mask selecting which tokens contribute to the training loss.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
        # Tokenizer is used both for rendering and for token/char alignment;
        # the chat template supplies role headers and turn separators.
        self.tokenizer = tokenizer
        self.chat_template = chat_template

    @abstractmethod
    def parse(
        self, conversation: "Conversation", max_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Parse the conversation into token-id and loss-mask tensors.

        Args:
            conversation: The conversation to parse.
            max_length: Maximum number of tokens to keep after truncation.

        Returns:
            A tuple of 1-D tensors: (input_ids, loss_mask).
        """
|
|
|
|
# Lazily-initialized holder, presumably for a cached harmony encoding object.
# NOTE(review): not referenced anywhere in this file's visible code — confirm
# it is populated/used elsewhere before removing.
_harmony_encoding = None
|
|
|
|
class GeneralParser(Parser):
    """Parser for chat templates that delimit turns with header/end-of-turn tokens.

    Renders a conversation to a single string (via the tokenizer's chat
    template, with a plain-text fallback), tokenizes it, and builds a loss
    mask that is 1 only over the assistant-generated spans located by regex.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
        super().__init__(tokenizer, chat_template)
        self.system_prompt = chat_template.system_prompt
        # Textual markers used to locate assistant spans in the rendered string.
        self.user_message_separator = f"{chat_template.end_of_turn_token}"
        self.assistant_message_separator = f"{chat_template.assistant_header}"
        self.set_assistant_pattern(chat_template)

    def apply_chat_template(self, messages, **kwargs) -> str:
        """Render `messages` into one prompt string using the tokenizer's
        built-in chat template (no generation prompt appended)."""
        conversation = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False, **kwargs
        )
        return conversation

    def set_assistant_pattern(self, chat_template: ChatTemplate):
        """Compile the regex that captures assistant message content.

        Group 1 spans everything after the assistant header up to (and
        including) the next turn boundary, or the end of the string.
        """
        if chat_template.assistant_pattern_type == "longcat":
            # longcat-style templates mark the next turn with "[Round N] USER:"
            # instead of a dedicated end-of-turn token.
            self.assistant_pattern = (
                re.escape(self.assistant_message_separator)
                + r"([\s\S]*?(?:"
                + re.escape("[Round ")
                + r"\d+"
                + re.escape("] USER:")
                + "|$))"
            )
        else:
            self.assistant_pattern = (
                re.escape(self.assistant_message_separator)
                + r"([\s\S]*?(?:"
                + re.escape(self.chat_template.end_of_turn_token)
                + "|$))"
            )

    def parse(
        self,
        conversation: "Conversation",
        max_length: int,
        preformatted: bool = False,
        train_only_last_turn: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize `conversation` and build a loss mask over assistant spans.

        Args:
            conversation: A sequence of role/content message dicts, or an
                already-rendered prompt string when `preformatted` is True.
            max_length: Maximum token length; longer inputs are truncated.
            preformatted: Skip message validation and template rendering.
            train_only_last_turn: Mask only the final assistant span.
            **kwargs: Extra arguments forwarded to the chat template.

        Returns:
            A tuple (input_ids, loss_mask) of equal-length 1-D tensors.
        """
        if not preformatted:
            messages = []

            # An explicit system message in the data overrides the template's
            # default system prompt.
            if conversation[0]["role"] == "system":
                warnings.warn(
                    "The first message is from system, we will use the system prompt from the data and ignore the system prompt from the template"
                )
                messages.append(
                    {"role": "system", "content": conversation[0]["content"]}
                )
                conversation = conversation[1:]
            else:
                if self.system_prompt:
                    messages.append({"role": "system", "content": self.system_prompt})

            # Validate role ordering; on the first violation the remainder of
            # the conversation is dropped (truncated), not rejected.
            for j, sentence in enumerate(conversation):
                role = sentence["role"]
                if j == 0:
                    if role != "user":
                        warnings.warn(
                            f"Conversation must start with a 'user' role, but found '{role}'. Conversation truncated."
                        )
                        break
                else:
                    prev_role = conversation[j - 1]["role"]
                    if role == "tool" and prev_role not in ["assistant", "tool"]:
                        warnings.warn(
                            f"A 'tool' message must follow an 'assistant' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
                        )
                        break
                    if role == "assistant" and prev_role not in ["user", "tool"]:
                        warnings.warn(
                            f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
                        )
                        break
                # Tool calls may arrive JSON-encoded; decode in place so the
                # chat template sees structured data.
                tool_calls = sentence.get("tool_calls")
                if isinstance(tool_calls, str):
                    try:
                        sentence["tool_calls"] = json.loads(tool_calls)
                    except json.JSONDecodeError:
                        warnings.warn(f"Failed to parse tool_calls JSON: {tool_calls}")
                        break
                messages.append(sentence)

            try:
                conversation = self.apply_chat_template(messages, **kwargs)
            except (ValueError, TypeError):
                # Tokenizers without a chat template raise here; fall back to a
                # minimal manual rendering built from the template's headers.
                warnings.warn(
                    "Tokenizer does not have a chat_template, using fallback rendering."
                )
                parts = []
                bos_token = getattr(self.tokenizer, "bos_token", None)
                user_header = self.chat_template.user_header or ""
                assistant_header = self.chat_template.assistant_header or ""
                end_of_turn = self.chat_template.end_of_turn_token or ""

                if bos_token:
                    parts.append(bos_token)

                for msg in messages:
                    if msg["role"] == "system":
                        parts.append(msg["content"])
                    elif msg["role"] == "user":
                        parts.append(f"{user_header}{msg['content']}")
                    elif msg["role"] == "assistant":
                        parts.append(f"{assistant_header}{msg['content']}{end_of_turn}")
                conversation = "".join(parts)

        # pad_token_id may legitimately be 0; a truthiness check would wrongly
        # overwrite it, so compare against None explicitly.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

        encoding = self.tokenizer(
            conversation,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_ids = encoding.input_ids[0]
        loss_mask = torch.zeros(len(input_ids), dtype=torch.long)

        matches = list(re.finditer(self.assistant_pattern, conversation, re.DOTALL))
        if train_only_last_turn and matches:
            matches = [matches[-1]]

        for match in matches:
            content_start_char = match.start(1)
            content_end_char = match.end(1)

            # Convert character offsets to token indices by re-tokenizing the
            # prefixes. Assumes tokenizing a prefix yields a prefix of the full
            # tokenization — TODO confirm for any newly added tokenizer.
            prefix_ids = self.tokenizer.encode(
                conversation[:content_start_char],
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
            )
            full_ids = self.tokenizer.encode(
                conversation[:content_end_char],
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
            )

            start_token_idx = len(prefix_ids)
            end_token_idx = len(full_ids)

            # Clamp to the (possibly truncated) encoded length.
            actual_start = min(start_token_idx, len(input_ids))
            actual_end = min(end_token_idx, len(input_ids))

            if actual_start < actual_end:
                loss_mask[actual_start:actual_end] = 1
        return input_ids, loss_mask
|
|
|
|
class HarmonyParser(Parser):
    """Parser for the OpenAI "harmony" chat format.

    Builds the prompt turn by turn as
    ``<|start|>role[<|channel|>channel]<|message|>content<|end|>`` segments and
    masks every assistant segment (analysis/commentary/final) for the loss.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
        super().__init__(tokenizer, chat_template)
        # Reasoning-effort levels accepted in the harmony system header.
        self.reasoning_levels = ["low", "medium", "high"]
        self.default_reasoning_level = "low"

    def build_single_turn_prompt(
        self,
        prompt_text: str,
        role: str,
        content: str,
    ) -> str:
        """Embed user message into the required prompt template.

        NOTE(review): the "system" and "assistant_reasoning_effort" roles
        *replace* the accumulated text instead of appending to it — confirm
        these roles only ever appear as the first message.

        Raises:
            ValueError: If `role` is not a recognized harmony role.
        """
        if role == "system":
            prompt_text = f"<|start|>system<|message|>{content}<|end|>"
        elif role == "assistant_reasoning_effort":
            prompt_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-06-28\n\nReasoning: {content.lower()}\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
        elif role == "user":
            prompt_text += f"<|start|>user<|message|>{content}<|end|>"
        elif role == "assistant_analysis":
            prompt_text += (
                f"<|start|>assistant<|channel|>analysis<|message|>{content}<|end|>"
            )
        elif role == "assistant_commentary":
            prompt_text += (
                f"<|start|>assistant<|channel|>commentary<|message|>{content}<|end|>"
            )
        elif role == "assistant_final":
            prompt_text += (
                f"<|start|>assistant<|channel|>final<|message|>{content}<|end|>"
            )
        else:
            raise ValueError(f"Unknown role: {role}")
        return prompt_text

    def parse(
        self,
        conversation: "Conversation",
        max_length: int,
        preformatted: bool = False,
        train_only_last_turn: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Render, tokenize, and mask a harmony-format conversation.

        Args:
            conversation: A sequence of role/content message dicts, or an
                already-rendered harmony prompt string when `preformatted`.
            max_length: Maximum token length; longer inputs are truncated.
            preformatted: Skip prompt construction.
            train_only_last_turn: Mask only the final assistant segment.

        Returns:
            A tuple (input_ids, loss_mask) of equal-length 1-D tensors.
        """
        if not preformatted:
            prompt_text = ""
            for j, message in enumerate(conversation):
                # Inject the default reasoning-effort header unless the data
                # already starts with its own system-level message.
                # (Bug fix: the original used `or`, which is always true; the
                # intended check is membership.)
                if j == 0 and message["role"] not in (
                    "system",
                    "assistant_reasoning_effort",
                ):
                    prompt_text = self.build_single_turn_prompt(
                        prompt_text,
                        "assistant_reasoning_effort",
                        self.default_reasoning_level,
                    )
                prompt_text = self.build_single_turn_prompt(
                    prompt_text, message["role"], message["content"]
                )
            conversation = prompt_text

        # pad_token_id may legitimately be 0; a truthiness check would wrongly
        # overwrite it, so compare against None explicitly.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

        encoding = self.tokenizer(
            conversation,
            return_offsets_mapping=True,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_ids = encoding.input_ids[0]
        offsets = encoding.offset_mapping[0]
        loss_mask = torch.zeros(len(input_ids), dtype=torch.long)

        # Capture everything after an assistant header up to the next user
        # turn (or end of string); covers all assistant channels.
        pattern = re.compile(
            r"<\|start\|>assistant([\s\S]*?)(?=<\|start\|>user<\|message\|>|$)"
        )

        matches = list(pattern.finditer(conversation))
        if train_only_last_turn and matches:
            matches = [matches[-1]]

        for match in matches:
            start_char = match.start(1)
            end_char = match.end(1)

            # Mark tokens whose character span lies fully inside the match.
            for idx, (ts, te) in enumerate(offsets):
                if ts >= start_char and te <= end_char:
                    loss_mask[idx] = 1

        return input_ids, loss_mask
|
|
|
|
class ThinkingParser(GeneralParser):
    """GeneralParser variant for templates with an explicit thinking mode.

    Renders the history with a generation prompt and appends the final
    assistant content verbatim, so any thinking markup inside it is preserved.
    """

    def apply_chat_template(self, messages, **kwargs) -> str:
        """Render history with a generation prompt, then append the last
        assistant message's content plus the end-of-turn token.

        Raises:
            Exception: If the last message is not from the assistant.
        """
        if messages[-1]["role"] != "assistant":
            raise Exception(
                f"The last message is not assistant but {messages[-1]['role']}"
            )
        conversation_history = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True,
            add_special_tokens=False,
            **kwargs,
        )
        return (
            conversation_history
            + messages[-1]["content"]
            + self.chat_template.end_of_turn_token
        )

    def parse(
        self,
        conversation: "Conversation",
        max_length: int,
        preformatted: bool = False,
        train_only_last_turn: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Parse like GeneralParser, enabling thinking when the template asks.

        See `GeneralParser.parse` for the argument and return contract.
        """
        # Only forward the flag when enabled; some templates reject unknown
        # keyword arguments.
        if self.chat_template.enable_thinking:
            kwargs["enable_thinking"] = True
        return super().parse(
            conversation, max_length, preformatted, train_only_last_turn, **kwargs
        )
|
|