# Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.

# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at

#     http://www.apache.org/licenses/LICENSE-2.0

# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""BotClient class for interacting with bot models."""

import argparse
import json
import logging
import os
import traceback
from collections.abc import Iterator

import jieba
import requests
from openai import OpenAI


class BotClient:
    """Client for interacting with various AI models."""

    def __init__(self, args: argparse.Namespace):
        """
        Initializes the BotClient from command line arguments, configuring retry
        limits, character constraints, model endpoints, and API credentials.
        Missing arguments fall back to defaults so the client remains usable.

        Args:
            args (argparse.Namespace): Command line arguments containing configuration parameters.
                                      Uses getattr() to safely retrieve values with fallback defaults.
        """
        self.logger = logging.getLogger(__name__)

        self.max_retry_num = getattr(args, "max_retry_num", 3)
        self.max_char = getattr(args, "max_char", 8000)

        self.model_map = getattr(args, "model_map", {})
        self.api_key = os.environ.get("API_KEY")

        self.embedding_service_url = getattr(
            args, "embedding_service_url", "embedding_service_url"
        )
        self.embedding_model = getattr(args, "embedding_model", "embedding_model")

        self.web_search_service_url = getattr(
            args, "web_search_service_url", "web_search_service_url"
        )
        self.max_search_results_num = getattr(args, "max_search_results_num", 15)

        # Qianfan services currently reuse the same API_KEY environment variable
        self.qianfan_api_key = os.environ.get("API_KEY")

    def call_back(self, host_url: str, req_data: dict) -> dict:
        """
        Sends a chat completion request to the specified endpoint via the OpenAI
        client and converts the response to a plain dictionary. Errors are logged
        and re-raised.

        Args:
            host_url (str): The URL to send the request to.
            req_data (dict): The data to send in the request body.

        Returns:
            dict: Parsed response from the server in OpenAI-compatible format.

        Raises:
            Exception: Any error from the underlying request, after it is logged.
        """
        try:
            client = OpenAI(base_url=host_url, api_key=self.api_key)
            response = client.chat.completions.create(**req_data)

            # Convert OpenAI response to compatible format
            return response.model_dump()

        except Exception as e:
            self.logger.error(f"Stream request failed: {e}")
            raise

    def call_back_stream(self, host_url: str, req_data: dict) -> Iterator[dict]:
        """
        Makes a streaming HTTP request to the specified host URL using the OpenAI client and yields response chunks
        in real-time while handling any exceptions that may occur during the streaming process.

        Args:
            host_url (str): The URL to send the request to.
            req_data (dict): The data to send in the request body.

        Yields:
            dict: Parsed streaming chunks from the server in OpenAI-compatible format.
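
        Example (illustrative; chunks follow the OpenAI streaming schema):
            >>> for chunk in client.call_back_stream(url, req_data):
            ...     print(chunk["choices"][0]["delta"].get("content") or "", end="")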
        """
        try:
            client = OpenAI(base_url=host_url, api_key=self.api_key)
            response = client.chat.completions.create(
                **req_data,
                stream=True,
            )
            for chunk in response:
                if not chunk.choices:
                    continue

                # Convert OpenAI response to compatible format
                yield chunk.model_dump()

        except Exception as e:
            self.logger.error(f"Stream request failed: {e}")
            raise

    def process(
        self,
        model_name: str,
        req_data: dict,
        max_tokens: int = 2048,
        temperature: float = 1.0,
        top_p: float = 0.7,
    ) -> dict:
        """
        Handles chat completion requests by mapping the model name to its endpoint, preparing request parameters
        including token limits and sampling settings, truncating messages to fit character limits, making API calls
        with built-in retry mechanism, and logging the full request/response cycle for debugging purposes.

        Args:
            model_name (str): Name of the model, used to look up the model URL from model_map.
            req_data (dict): Dictionary containing request data, including information to be processed.
            max_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature to control the diversity of generated text.
            top_p (float): Cumulative probability threshold to control the diversity of generated text.

        Returns:
            dict: Dictionary containing the model's processing results, or an
                empty dict if all retry attempts fail.
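
        Example (illustrative; assumes "demo-model" is registered in model_map):
            >>> res = client.process(
            ...     "demo-model",
            ...     {"messages": [{"role": "user", "content": "Hello"}]},
            ... )
            >>> res["choices"][0]["message"]["content"]  # assistant reply text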
        """
        model_url = self.model_map[model_name]

        req_data["model"] = model_name
        req_data["max_tokens"] = max_tokens
        req_data["temperature"] = temperature
        req_data["top_p"] = top_p
        req_data["messages"] = self.truncate_messages(req_data["messages"])
        res = {}
        for _ in range(self.max_retry_num):
            try:
                self.logger.info(f"[MODEL] {model_url}")
                self.logger.info("[req_data]====>")
                self.logger.info(json.dumps(req_data, ensure_ascii=False))
                res = self.call_back(model_url, req_data)
                self.logger.info("model response")
                self.logger.info(res)
                self.logger.info("-" * 30)
            except Exception as e:
                self.logger.error(e)
                self.logger.error(traceback.format_exc())
                res = {}
            if res and "error" not in res:
                break

        return res

    def process_stream(
        self,
        model_name: str,
        req_data: dict,
        max_tokens: int = 2048,
        temperature: float = 1.0,
        top_p: float = 0.7,
    ) -> Iterator[dict]:
        """
        Processes streaming requests by mapping the model name to its endpoint, configuring request parameters,
        implementing a retry mechanism with logging, and streaming back response chunks in real-time while
        handling any errors that may occur during the streaming session.

        Args:
            model_name (str): Name of the model, used to look up the model URL from model_map.
            req_data (dict): Dictionary containing request data, including information to be processed.
            max_tokens (int): Maximum number of tokens to generate.
            temperature (float): Sampling temperature to control the diversity of generated text.
            top_p (float): Cumulative probability threshold to control the diversity of generated text.

        Yields:
            dict: Dictionary containing the model's processing results.
        """
        model_url = self.model_map[model_name]
        req_data["model"] = model_name
        req_data["max_tokens"] = max_tokens
        req_data["temperature"] = temperature
        req_data["top_p"] = top_p
        req_data["messages"] = self.truncate_messages(req_data["messages"])

        last_error = None
        for attempt in range(self.max_retry_num):
            try:
                self.logger.info(f"[MODEL] {model_url}")
                self.logger.info("[req_data]====>")
                self.logger.info(json.dumps(req_data, ensure_ascii=False))

                yield from self.call_back_stream(model_url, req_data)
                return

            except Exception as e:
                last_error = e
                self.logger.error(
                    f"Stream request failed (attempt {_ + 1}/{self.max_retry_num}): {e}"
                )

        self.logger.error("All retry attempts failed for stream request")
        yield {"error": str(last_error)}

    def cut_chinese_english(self, text: str) -> list:
        """
        Segments mixed Chinese and English text: the input is tokenized with
        jieba, English words are kept as whole units, and any token containing
        Chinese characters (detected via the CJK Unicode range) is split into
        individual characters.

        Args:
            text (str): Input string to be segmented.

        Returns:
            list: A list of segments, where each segment is either a whole
                English word or a single character.
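
        Example (exact segmentation may vary with the installed jieba dictionary):
            >>> client.cut_chinese_english("hello世界")
            ['hello', '世', '界']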
        """
        words = jieba.lcut(text)
        en_ch_words = []

        for word in words:
            if word.isalpha() and not any(
                "\u4e00" <= char <= "\u9fff" for char in word
            ):
                en_ch_words.append(word)
            else:
                en_ch_words.extend(list(word))
        return en_ch_words

    def truncate_messages(self, messages: list[dict]) -> list[dict]:
        """
        Truncates conversation messages to fit within the unit budget given by
        self.max_char (counted in jieba/character segments rather than raw
        characters) while preserving message structure. Truncation follows a
        prioritized order: historical messages first, then the system message,
        and finally the last message.

        Args:
            messages (list[dict]): List of messages to be truncated.

        Returns:
            list[dict]: Modified list of messages after truncation.
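
        Example (units are jieba segments, not raw characters):
            With max_char = 10 and [system (6 units), user (8), user (4)],
            8 units must be removed; the historical user(8) message is
            emptied and dropped, leaving [system(6), user(4)].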
        """
        if not messages:
            return messages

        processed = []
        total_units = 0

        for msg in messages:
            # Handle two content formats: a plain string, or a multi-part list
            # where index 1 is assumed to hold the text part
            if isinstance(msg["content"], str):
                text_content = msg["content"]
            elif isinstance(msg["content"], list):
                text_content = msg["content"][1]["text"]
            else:
                text_content = ""

            # Calculate unit count after tokenization
            units = self.cut_chinese_english(text_content)
            unit_count = len(units)

            processed.append(
                {
                    "role": msg["role"],
                    "original_content": msg["content"],  # Preserve original content
                    "text_content": text_content,  # Extracted plain text
                    "units": units,
                    "unit_count": unit_count,
                }
            )
            total_units += unit_count

        if total_units <= self.max_char:
            return messages

        # Number of units to remove
        to_remove = total_units - self.max_char

        # 1. Truncate historical messages (between the system message and the
        #    last message), oldest first
        for i in range(1, len(processed) - 1):
            if to_remove <= 0:
                break

            if processed[i]["unit_count"] <= to_remove:
                processed[i]["text_content"] = ""
                to_remove -= processed[i]["unit_count"]
                if isinstance(processed[i]["original_content"], str):
                    processed[i]["original_content"] = ""
                elif isinstance(processed[i]["original_content"], list):
                    processed[i]["original_content"][1]["text"] = ""
            else:
                kept_units = processed[i]["units"][:-to_remove]
                new_text = "".join(kept_units)
                processed[i]["text_content"] = new_text
                if isinstance(processed[i]["original_content"], str):
                    processed[i]["original_content"] = new_text
                elif isinstance(processed[i]["original_content"], list):
                    processed[i]["original_content"][1]["text"] = new_text
                to_remove = 0

        # 2. Truncate system message
        if to_remove > 0:
            system_msg = processed[0]
            if system_msg["unit_count"] <= to_remove:
                processed[0]["text_content"] = ""
                to_remove -= system_msg["unit_count"]
                if isinstance(processed[0]["original_content"], str):
                    processed[0]["original_content"] = ""
                elif isinstance(processed[0]["original_content"], list):
                    processed[0]["original_content"][1]["text"] = ""
            else:
                kept_units = system_msg["units"][:-to_remove]
                new_text = "".join(kept_units)
                processed[0]["text_content"] = new_text
                if isinstance(processed[0]["original_content"], str):
                    processed[0]["original_content"] = new_text
                elif isinstance(processed[0]["original_content"], list):
                    processed[0]["original_content"][1]["text"] = new_text
                to_remove = 0

        # 3. Truncate last message
        if to_remove > 0 and len(processed) > 1:
            last_msg = processed[-1]
            if last_msg["unit_count"] > to_remove:
                kept_units = last_msg["units"][:-to_remove]
                new_text = "".join(kept_units)
                last_msg["text_content"] = new_text
                if isinstance(last_msg["original_content"], str):
                    last_msg["original_content"] = new_text
                elif isinstance(last_msg["original_content"], list):
                    last_msg["original_content"][1]["text"] = new_text
            else:
                last_msg["text_content"] = ""
                if isinstance(last_msg["original_content"], str):
                    last_msg["original_content"] = ""
                elif isinstance(last_msg["original_content"], list):
                    last_msg["original_content"][1]["text"] = ""

        result = []
        for msg in processed:
            if msg["text_content"]:
                result.append({"role": msg["role"], "content": msg["original_content"]})

        return result

    def embed_fn(self, text: str) -> list:
        """
        Generate an embedding for the given text using the QianFan API.

        Args:
            text (str): The input text to be embedded.

        Returns:
            list: A list of floats representing the embedding.
        """
        client = OpenAI(
            base_url=self.embedding_service_url, api_key=self.qianfan_api_key
        )
        response = client.embeddings.create(input=[text], model=self.embedding_model)
        return response.data[0].embedding

    def get_web_search_res(self, query_list: list) -> list:
        """
        Send a request to the AI Search service using the provided API key and service URL.

        Args:
            query_list (list): List of queries to send to the AI Search service.

        Returns:
            list: List of responses from the AI Search service.
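
        Example (illustrative; top_k per query is max_search_results_num // len(query_list)):
            >>> refs = client.get_web_search_res(["query one", "query two"])
            >>> len(refs)  # one reference list per successful query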
        """
        headers = {
            "Authorization": "Bearer " + self.qianfan_api_key,
            "Content-Type": "application/json",
        }

        results = []
        if not query_list:
            return results

        # Split the search result budget evenly across the queries
        top_k = self.max_search_results_num // len(query_list)
        for query in query_list:
            payload = {
                "messages": [{"role": "user", "content": query}],
                "resource_type_filter": [{"type": "web", "top_k": top_k}],
            }
            response = requests.post(
                self.web_search_service_url, headers=headers, json=payload
            )

            if response.status_code == 200:
                data = response.json()
                self.logger.info(data)
                results.append(data["references"])
            else:
                self.logger.error(f"Request failed with status code: {response.status_code}")
                self.logger.error(response.text)
        return results
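

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only): the model name and endpoint below
# are hypothetical placeholders. A real model_map must point at an
# OpenAI-compatible server, and API_KEY must be set in the environment.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO)
    demo_args = argparse.Namespace(
        max_retry_num=3,
        max_char=8000,
        model_map={"demo-model": "http://localhost:8000/v1"},  # hypothetical endpoint
    )
    bot = BotClient(demo_args)
    result = bot.process(
        "demo-model",
        {"messages": [{"role": "user", "content": "Hello!"}]},
    )
    print(result.get("choices", [{}])[0].get("message", {}).get("content"))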