File size: 13,481 Bytes
eae7bce
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
import json
import re
import warnings
from abc import ABC, abstractmethod
from typing import Dict, List, Tuple

import torch
from transformers import PreTrainedTokenizer

from .template import ChatTemplate

# Public parser implementations exported by ``from ... import *``.
__all__ = ["GeneralParser", "HarmonyParser", "ThinkingParser"]


class Parser(ABC):
    """Abstract base class for conversation parsers.

    A parser turns one conversation into the tensors used for training:
    the token ids and a mask marking which of those tokens contribute to
    the loss.  Concrete subclasses implement :meth:`parse`.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
        # Both collaborators are kept on the instance for use by subclasses.
        self.tokenizer, self.chat_template = tokenizer, chat_template

    @abstractmethod
    def parse(
        self, conversation: "Conversation", max_length: int
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Convert a conversation into training tensors.

        Args:
            conversation: The conversation to parse.
            max_length: Upper bound on the tokenized length (concrete
                parsers use it as the tokenizer truncation length).

        Returns:
            A pair ``(input_ids, loss_mask)``.
        """


# Module-level placeholder initialised to None; nothing else in this file reads
# or assigns it.  NOTE(review): presumably intended as a lazily-built Harmony
# encoding cache — confirm it is used elsewhere, otherwise it can be removed.
_harmony_encoding = None


class GeneralParser(Parser):
    """Parser for models whose prompt format is produced by a chat template.

    The conversation is rendered to text (with the tokenizer's chat template,
    or a simple header-based fallback when that fails) and tokenized.  Every
    span matched by ``self.assistant_pattern`` is then marked as trainable in
    the loss mask.
    """

    def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
        super().__init__(tokenizer, chat_template)
        self.system_prompt = chat_template.system_prompt
        self.user_message_separator = f"{chat_template.end_of_turn_token}"
        self.assistant_message_separator = f"{chat_template.assistant_header}"
        self.set_assistant_pattern(chat_template)

    def apply_chat_template(self, messages, **kwargs) -> str:
        """Render *messages* to a string with the tokenizer's chat template.

        Extra keyword arguments are forwarded to
        ``tokenizer.apply_chat_template``.
        """
        conversation = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False, **kwargs
        )
        return conversation

    def set_assistant_pattern(self, chat_template: ChatTemplate):
        """Build the regex whose group 1 spans one assistant turn's trained text.

        The match starts right after the assistant header.  Group 1 runs up
        to and *including* the turn terminator (or the end of the text), so
        the terminator itself is part of the trained span.  For the
        ``"longcat"`` pattern type, the terminator is the next
        ``"[Round N] USER:"`` marker. Otherwise it is the template's
        end-of-turn token.
        """
        if chat_template.assistant_pattern_type == "longcat":
            terminator = re.escape("[Round ") + r"\d+" + re.escape("] USER:")
        else:
            terminator = re.escape(self.chat_template.end_of_turn_token)
        self.assistant_pattern = (
            re.escape(self.assistant_message_separator)
            + r"([\s\S]*?(?:"
            + terminator
            + "|$))"
        )

    def _build_messages(self, conversation) -> List[dict]:
        """Validate role ordering and return the list of messages to render.

        A leading system message from the data overrides the template's system
        prompt.  Validation stops (truncating the conversation, with a
        warning) at the first role-ordering violation or undecodable
        ``tool_calls`` JSON.  The caller's message dicts are never mutated.
        """
        messages = []

        if conversation and conversation[0]["role"] == "system":
            warnings.warn(
                "The first message is from system, we will use the system prompt from the data and ignore the system prompt from the template"
            )
            messages.append({"role": "system", "content": conversation[0]["content"]})
            conversation = conversation[1:]
        elif self.system_prompt:
            messages.append({"role": "system", "content": self.system_prompt})

        for j, sentence in enumerate(conversation):
            role = sentence["role"]
            if j == 0:
                if role != "user":
                    warnings.warn(
                        f"Conversation must start with a 'user' role, but found '{role}'. Conversation truncated."
                    )
                    break
            else:
                prev_role = conversation[j - 1]["role"]
                if role == "tool" and prev_role not in ["assistant", "tool"]:
                    warnings.warn(
                        f"A 'tool' message must follow an 'assistant' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
                    )
                    break
                if role == "assistant" and prev_role not in ["user", "tool"]:
                    warnings.warn(
                        f"An 'assistant' message must follow a 'user' or 'tool' message, but was preceded by '{prev_role}'. Conversation truncated."
                    )
                    break

            tool_calls = sentence.get("tool_calls")
            if isinstance(tool_calls, str):
                try:
                    decoded = json.loads(tool_calls)
                except json.JSONDecodeError:
                    warnings.warn(f"Failed to parse tool_calls JSON: {tool_calls}")
                    break
                # Shallow copy so the caller's (possibly reused) data is untouched.
                sentence = {**sentence, "tool_calls": decoded}
            messages.append(sentence)

        return messages

    def _render_fallback(self, messages) -> str:
        """Render *messages* with the ChatTemplate headers when the tokenizer can't.

        Only system/user/assistant roles are rendered; other roles are skipped.
        """
        parts = []
        bos_token = getattr(self.tokenizer, "bos_token", None)
        user_header = self.chat_template.user_header or ""
        assistant_header = self.chat_template.assistant_header or ""
        end_of_turn = self.chat_template.end_of_turn_token or ""

        if bos_token:
            parts.append(bos_token)

        for msg in messages:
            if msg["role"] == "system":
                parts.append(msg["content"])
            elif msg["role"] == "user":
                parts.append(f"{user_header}{msg['content']}")
            elif msg["role"] == "assistant":
                parts.append(f"{assistant_header}{msg['content']}{end_of_turn}")
        return "".join(parts)

    def _count_prefix_tokens(self, text: str, max_length: int) -> int:
        """Return the number of tokens *text* encodes to, capped at *max_length*."""
        return len(
            self.tokenizer.encode(
                text,
                add_special_tokens=False,
                truncation=True,
                max_length=max_length,
            )
        )

    def parse(
        self,
        conversation: "Conversation",
        max_length: int,
        preformatted: bool = False,
        train_only_last_turn: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a conversation and mark assistant turns as trainable.

        Args:
            conversation: A list of message dicts (``"role"``, ``"content"``,
                optional ``"tool_calls"``), or an already-rendered string when
                ``preformatted`` is True.
            max_length: Truncation length for the tokenized conversation.
            preformatted: Skip validation/rendering and tokenize as-is.
            train_only_last_turn: Only the final assistant span gets loss.
            **kwargs: Forwarded to :meth:`apply_chat_template`.

        Returns:
            ``(input_ids, loss_mask)``, 1-D tensors of equal length, where
            ``loss_mask`` is 1 on assistant-span tokens and 0 elsewhere.
        """
        if not preformatted:
            messages = self._build_messages(conversation)
            try:
                conversation = self.apply_chat_template(messages, **kwargs)
            except (ValueError, TypeError):
                warnings.warn(
                    "Tokenizer does not have a chat_template, using fallback rendering."
                )
                conversation = self._render_fallback(messages)

        # Only substitute when no pad id exists: 0 is a valid pad id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

        encoding = self.tokenizer(
            conversation,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_ids = encoding.input_ids[0]
        num_tokens = len(input_ids)
        loss_mask = torch.zeros(num_tokens, dtype=torch.long)

        matches = list(re.finditer(self.assistant_pattern, conversation, re.DOTALL))
        if train_only_last_turn and matches:
            matches = matches[-1:]

        for match in matches:
            # The token index of a character position is the token count of
            # the text before it.  This is computed by re-encoding prefixes
            # rather than using an offset mapping.  It is clamped to the
            # (possibly truncated) sequence.
            start = min(
                self._count_prefix_tokens(conversation[: match.start(1)], max_length),
                num_tokens,
            )
            end = min(
                self._count_prefix_tokens(conversation[: match.end(1)], max_length),
                num_tokens,
            )
            if start < end:
                loss_mask[start:end] = 1
        return input_ids, loss_mask


class HarmonyParser(Parser):
    """Parser for the Harmony chat format (``<|start|>role...<|end|>`` messages).

    Everything following ``<|start|>assistant`` up to the next user message
    (or the end of the text) is marked as trainable.
    """

    # Group 1 is the assistant content after the ``<|start|>assistant`` marker.
    # It stops right before the next ``<|start|>user<|message|>`` or at the end.
    _ASSISTANT_SPAN_PATTERN = re.compile(
        r"<\|start\|>assistant([\s\S]*?)(?=<\|start\|>user<\|message\|>|$)"
    )

    def __init__(self, tokenizer: PreTrainedTokenizer, chat_template: ChatTemplate):
        super().__init__(tokenizer, chat_template)
        self.reasoning_levels = ["low", "medium", "high"]
        self.default_reasoning_level = "low"

    def build_single_turn_prompt(
        self,
        prompt_text: str,
        role: str,
        content: str,
    ) -> str:
        """Add one message in Harmony format to *prompt_text* and return the result.

        ``"system"`` and ``"assistant_reasoning_effort"`` produce a system
        message that *replaces* ``prompt_text`` instead of appending to it.
        They are therefore only meaningful as the first message.  For
        ``"assistant_reasoning_effort"``, *content* is the reasoning level.
        The ``"user"`` and ``"assistant_{analysis,commentary,final}"`` roles
        are appended.

        Raises:
            ValueError: for any other role.
        """
        if role == "system":
            prompt_text = f"<|start|>system<|message|>{content}<|end|>"
        elif role == "assistant_reasoning_effort":
            prompt_text = f"<|start|>system<|message|>You are ChatGPT, a large language model trained by OpenAI.\nKnowledge cutoff: 2024-06\nCurrent date: 2025-06-28\n\nReasoning: {content.lower()}\n\n# Valid channels: analysis, commentary, final. Channel must be included for every message.<|end|>"
        elif role == "user":
            prompt_text += f"<|start|>user<|message|>{content}<|end|>"
        elif role == "assistant_analysis":
            prompt_text += (
                f"<|start|>assistant<|channel|>analysis<|message|>{content}<|end|>"
            )
        elif role == "assistant_commentary":
            prompt_text += (
                f"<|start|>assistant<|channel|>commentary<|message|>{content}<|end|>"
            )
        elif role == "assistant_final":
            prompt_text += (
                f"<|start|>assistant<|channel|>final<|message|>{content}<|end|>"
            )
        else:
            raise ValueError(f"Unknown role: {role}")
        return prompt_text

    def parse(
        self,
        conversation: "Conversation",
        max_length: int,
        preformatted: bool = False,
        train_only_last_turn: bool = False,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Tokenize a Harmony conversation and mark assistant spans as trainable.

        Args:
            conversation: A list of ``{"role", "content"}`` dicts, or an
                already-rendered string when ``preformatted`` is True.
            max_length: Truncation length for the tokenized conversation.
            preformatted: Tokenize ``conversation`` as-is.
            train_only_last_turn: Only the final assistant span gets loss.

        Returns:
            ``(input_ids, loss_mask)``, 1-D tensors of equal length.
        """
        if not preformatted:
            prompt_text = ""
            for j, message in enumerate(conversation):
                # Inject the default reasoning-effort system header unless the
                # data already starts with its own system-level message.
                if j == 0 and message["role"] not in (
                    "system",
                    "assistant_reasoning_effort",
                ):
                    prompt_text = self.build_single_turn_prompt(
                        prompt_text,
                        "assistant_reasoning_effort",
                        self.default_reasoning_level,
                    )
                prompt_text = self.build_single_turn_prompt(
                    prompt_text, message["role"], message["content"]
                )
            conversation = prompt_text

        # Only substitute when no pad id exists: 0 is a valid pad id.
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.unk_token_id

        encoding = self.tokenizer(
            conversation,
            return_offsets_mapping=True,
            max_length=max_length,
            truncation=True,
            return_tensors="pt",
            add_special_tokens=False,
        )
        input_ids = encoding.input_ids[0]
        loss_mask = torch.zeros(len(input_ids), dtype=torch.long)
        if len(input_ids) == 0:
            return input_ids, loss_mask

        matches = list(self._ASSISTANT_SPAN_PATTERN.finditer(conversation))
        if train_only_last_turn and matches:
            matches = matches[-1:]

        # Per-token character spans, shape (num_tokens, 2).
        offsets = encoding.offset_mapping[0]
        token_starts = offsets[:, 0]
        token_ends = offsets[:, 1]
        for match in matches:
            # Only tokens lying entirely inside the assistant content are trained.
            inside = (token_starts >= match.start(1)) & (token_ends <= match.end(1))
            loss_mask[inside] = 1

        return input_ids, loss_mask


class ThinkingParser(GeneralParser):
    """GeneralParser variant that appends the final assistant turn verbatim.

    All messages except the last are rendered with the chat template and a
    generation prompt.  The last (assistant) message's content is then
    appended unchanged, followed by the end-of-turn token, so the template
    never rewrites the final turn.
    """

    def apply_chat_template(self, messages, **kwargs) -> str:
        """Render history via the template and append the last assistant turn raw.

        Raises:
            RuntimeError: if *messages* is empty or its last message is not
                from the assistant.  This is deliberately not a ValueError, so
                ``GeneralParser.parse`` does not swallow it into the fallback
                renderer.
        """
        if not messages:
            raise RuntimeError(
                "Cannot render an empty conversation: a final assistant message is required"
            )
        last_message = messages[-1]
        if last_message["role"] != "assistant":
            raise RuntimeError(
                f"The last message is not assistant but {last_message['role']}"
            )
        conversation_history = self.tokenizer.apply_chat_template(
            messages[:-1],
            tokenize=False,
            add_generation_prompt=True,
            add_special_tokens=False,
            **kwargs,
        )
        return (
            conversation_history
            + last_message["content"]
            + self.chat_template.end_of_turn_token
        )

    def parse(
        self,
        conversation: "Conversation",
        max_length: int,
        preformatted: bool = False,
        train_only_last_turn: bool = False,
        **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Same as :meth:`GeneralParser.parse`, adding ``enable_thinking=True``.

        The flag is forwarded to the chat template only when the ChatTemplate
        enables thinking.  Otherwise the template's own default applies.
        """
        if self.chat_template.enable_thinking:
            kwargs["enable_thinking"] = True
        return super().parse(
            conversation, max_length, preformatted, train_only_last_turn, **kwargs
        )