from typing import (
    List,
    Dict,
    Union,
    Generator,
    AsyncGenerator,
)
from aworld.config import ConfigDict
from aworld.config.conf import AgentConfig, ClientType
from aworld.logs.util import logger

from aworld.core.llm_provider_base import LLMProviderBase
from aworld.models.openai_provider import OpenAIProvider, AzureOpenAIProvider
from aworld.models.anthropic_provider import AnthropicProvider
from aworld.models.ant_provider import AntProvider
from aworld.models.model_response import ModelResponse

# Predefined model names for common providers
MODEL_NAMES = {
    "anthropic": ["claude-3-5-sonnet-20241022", "claude-3-5-sonnet-20240620", "claude-3-opus-20240229"],
    "openai": ["gpt-4o", "gpt-4", "gpt-3.5-turbo", "o3-mini", "gpt-4o-mini"],
    "azure_openai": ["gpt-4", "gpt-4-turbo", "gpt-4o", "gpt-35-turbo"],
}

# Endpoint patterns for identifying providers
ENDPOINT_PATTERNS = {
    "openai": ["api.openai.com"],
    "anthropic": ["api.anthropic.com", "claude-api"],
    "azure_openai": ["openai.azure.com"],
    "ant": ["zdfmng.alipay.com"],
}

# Provider class mapping
PROVIDER_CLASSES = {
    "openai": OpenAIProvider,
    "anthropic": AnthropicProvider,
    "azure_openai": AzureOpenAIProvider,
    "ant": AntProvider,
}


class LLMModel:
    """Unified large model interface, encapsulates different model implementations, provides a unified completion method.
    """

    def __init__(self, conf: Union[ConfigDict, AgentConfig] = None, custom_provider: LLMProviderBase = None, **kwargs):
        """Initialize unified model interface.

        Args:
            conf: Agent configuration, if provided, create model based on configuration.
            custom_provider: Custom LLMProviderBase instance, if provided, use it directly.
            **kwargs: Other parameters, may include:
                - base_url: Specify model endpoint.
                - api_key: API key.
                - model_name: Model name.
                - temperature: Temperature parameter.
        """

        # If custom_provider instance is provided, use it directly
        if custom_provider is not None:
            if not isinstance(custom_provider, LLMProviderBase):
                raise TypeError(
                    "custom_provider must be an instance of LLMProviderBase")
            self.provider_name = "custom"
            self.provider = custom_provider
            return

        # Get basic parameters
        base_url = kwargs.get("base_url") or (
            conf.llm_base_url if conf_contains_key(conf, "llm_base_url") else None)
        model_name = kwargs.get("model_name") or (
            conf.llm_model_name if conf_contains_key(conf, "llm_model_name") else None)
        llm_provider = conf.llm_provider if conf_contains_key(
            conf, "llm_provider") else None

        # Get API key from configuration (if any); conf_contains_key guards
        # ConfigDict instances that lack the key
        if conf_contains_key(conf, "llm_api_key") and conf.llm_api_key:
            kwargs["api_key"] = conf.llm_api_key

        # Identify provider
        self.provider_name = self._identify_provider(
            llm_provider, base_url, model_name)

        # Fill basic parameters
        kwargs['base_url'] = base_url
        kwargs['model_name'] = model_name

        # Fill parameters for llm provider
        kwargs['sync_enabled'] = conf.llm_sync_enabled if conf_contains_key(
            conf, "llm_sync_enabled") else True
        kwargs['async_enabled'] = conf.llm_async_enabled if conf_contains_key(
            conf, "llm_async_enabled") else True
        kwargs['client_type'] = conf.llm_client_type if conf_contains_key(
            conf, "llm_client_type") else ClientType.SDK

        kwargs.update(self._transfer_conf_to_args(conf))

        # Create model provider based on provider_name
        self._create_provider(**kwargs)

    def _transfer_conf_to_args(self, conf: Union[ConfigDict, AgentConfig] = None) -> dict:
        """
        Transfer parameters from conf to args

        Args:
            conf: config object
        """
        if not conf:
            return {}

        # Get all parameters from conf
        if type(conf).__name__ == 'AgentConfig':
            conf_dict = conf.model_dump()
        else:  # ConfigDict
            conf_dict = conf

        ignored_keys = ["llm_provider", "llm_base_url", "llm_model_name", "llm_api_key", "llm_sync_enabled",
                        "llm_async_enabled", "llm_client_type"]
        args = {}
        # Filter out used parameters and add remaining parameters to args
        for key, value in conf_dict.items():
            if key not in ignored_keys and value is not None:
                args[key] = value

        return args

    def _identify_provider(self, provider: str = None, base_url: str = None, model_name: str = None) -> str:
        """Identify LLM provider.

        Identification logic:
        1. If a supported provider is explicitly specified, it takes precedence.
        2. Otherwise, if base_url is provided, identify the provider from base_url.
        3. Otherwise, if model_name is provided, identify the provider from model_name.
        4. If none match, default to "openai".

        Args:
            provider: Specified provider.
            base_url: Service URL.
            model_name: Model name.

        Returns:
            str: Identified provider.
        """
        # Default provider
        identified_provider = "openai"

        # Identify provider based on base_url
        if base_url:
            for p, patterns in ENDPOINT_PATTERNS.items():
                if any(pattern in base_url for pattern in patterns):
                    identified_provider = p
                    logger.info(
                        f"Identified provider: {identified_provider} based on base_url: {base_url}")
                    # Don't return early: an explicitly specified provider may
                    # still override this match (step 1 in the docstring).
                    break

        # Identify provider based on model_name
        if model_name and not base_url:
            for p, models in MODEL_NAMES.items():
                if model_name in models or any(model_name.startswith(model) for model in models):
                    identified_provider = p
                    logger.info(
                        f"Identified provider: {identified_provider} based on model_name: {model_name}")
                    break

        # An explicitly specified, supported provider takes precedence over
        # the one identified above (step 1 in the docstring)
        if provider and provider in PROVIDER_CLASSES and identified_provider != provider:
            logger.warning(
                f"Provider mismatch: {provider} != {identified_provider}, using {provider} as provider")
            identified_provider = provider

        return identified_provider
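
    # Illustrative expectations (assumed from the tables above; not tests):
    #   _identify_provider(base_url="https://api.anthropic.com/v1") -> "anthropic"
    #   _identify_provider(model_name="gpt-4o")                      -> "openai"
    #   _identify_provider(provider="ant")                           -> "ant" (with a warning)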

    def _create_provider(self, **kwargs):
        """Instantiate the provider matching self.provider_name and store it on self.provider.

        Args:
            **kwargs: Parameters, may include:
                - base_url: Model endpoint.
                - api_key: API key.
                - model_name: Model name.
                - temperature: Temperature parameter.
                - timeout: Timeout.
                - max_retries: Maximum number of retries.
        """
        self.provider = PROVIDER_CLASSES[self.provider_name](**kwargs)

    @classmethod
    def supported_providers(cls) -> List[str]:
        return list(PROVIDER_CLASSES.keys())

    def supported_models(self) -> List[str]:
        """Get supported models for the current provider.
        Returns:
            list: Supported models.
        """
        return self.provider.supported_models() if self.provider else []

    async def acompletion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> ModelResponse:
        """Asynchronously call model to generate response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object.
        """
        # Call provider's acompletion method directly
        return await self.provider.acompletion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )

    def completion(self,
                   messages: List[Dict[str, str]],
                   temperature: float = 0.0,
                   max_tokens: int = None,
                   stop: List[str] = None,
                   **kwargs) -> ModelResponse:
        """Synchronously call model to generate response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object.
        """
        # Call provider's completion method directly
        return self.provider.completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )

    def stream_completion(self,
                          messages: List[Dict[str, str]],
                          temperature: float = 0.0,
                          max_tokens: int = None,
                          stop: List[str] = None,
                          **kwargs) -> Generator[ModelResponse, None, None]:
        """Synchronously call model to generate streaming response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            Generator yielding ModelResponse chunks.
        """
        # Call provider's stream_completion method directly
        return self.provider.stream_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )

    async def astream_completion(self,
                                 messages: List[Dict[str, str]],
                                 temperature: float = 0.0,
                                 max_tokens: int = None,
                                 stop: List[str] = None,
                                 **kwargs) -> AsyncGenerator[ModelResponse, None]:
        """Asynchronously call model to generate streaming response.

        Args:
            messages: Message list, format is [{"role": "system", "content": "..."}, {"role": "user", "content": "..."}].
            temperature: Temperature parameter.
            max_tokens: Maximum number of tokens to generate.
            stop: List of stop sequences.
            **kwargs: Other parameters.

        Returns:
            AsyncGenerator yielding ModelResponse chunks.
        """
        # Call provider's astream_completion method directly
        async for chunk in self.provider.astream_completion(
                messages=messages,
                temperature=temperature,
                max_tokens=max_tokens,
                stop=stop,
                **kwargs
        ):
            yield chunk

    def speech_to_text(self,
                       audio_file: str,
                       language: str = None,
                       prompt: str = None,
                       **kwargs) -> ModelResponse:
        """Convert speech to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When LLM response error occurs.
            NotImplementedError: When provider does not support speech to text conversion.
        """
        return self.provider.speech_to_text(
            audio_file=audio_file,
            language=language,
            prompt=prompt,
            **kwargs
        )

    async def aspeech_to_text(self,
                              audio_file: str,
                              language: str = None,
                              prompt: str = None,
                              **kwargs) -> ModelResponse:
        """Asynchronously convert speech to text.

        Args:
            audio_file: Path to audio file or file object.
            language: Audio language, optional.
            prompt: Transcription prompt, optional.
            **kwargs: Other parameters.

        Returns:
            ModelResponse: Unified model response object, with content field containing the transcription result.

        Raises:
            LLMResponseError: When LLM response error occurs.
            NotImplementedError: When provider does not support speech to text conversion.
        """
        return await self.provider.aspeech_to_text(
            audio_file=audio_file,
            language=language,
            prompt=prompt,
            **kwargs
        )
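

# Hedged usage sketch: build an LLMModel from kwargs and run one chat turn.
# The model name, endpoint, and API key are illustrative placeholders, not
# values mandated by this module.
def _example_llm_model_completion() -> str:
    model = LLMModel(
        model_name="gpt-4o",                   # assumed OpenAI model name
        base_url="https://api.openai.com/v1",  # matches ENDPOINT_PATTERNS["openai"]
        api_key="sk-...",                      # placeholder API key
    )
    response = model.completion(
        messages=[
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Say hello."},
        ],
        temperature=0.0,
    )
    return response.content  # ModelResponse content field (see docstrings above)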


def register_llm_provider(provider: str, provider_class: type):
    """Register a custom LLM provider.

    Args:
        provider: Provider name.
        provider_class: Provider class, must inherit from LLMProviderBase.
    """
    if not issubclass(provider_class, LLMProviderBase):
        raise TypeError("provider_class must be a subclass of LLMProviderBase")
    PROVIDER_CLASSES[provider] = provider_class
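

# Hedged registration sketch: the subclass body is elided because
# LLMProviderBase's abstract surface is defined in aworld.core.llm_provider_base.
# `EchoProvider` is hypothetical; only the registration call shape is shown.
#
#   class EchoProvider(LLMProviderBase):
#       ...  # implement completion/acompletion/etc. per the base class
#
#   register_llm_provider("echo", EchoProvider)
#   assert "echo" in LLMModel.supported_providers()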


def conf_contains_key(conf: Union[ConfigDict, AgentConfig], key: str) -> bool:
    """Check if conf contains key.
    Args:
        conf: Config object.
        key: Key to check.
    Returns:
        bool: Whether conf contains key.
    """
    if not conf:
        return False
    if type(conf).__name__ == 'AgentConfig':
        return hasattr(conf, key)
    else:
        return key in conf


def get_llm_model(conf: Union[ConfigDict, AgentConfig] = None,
                  custom_provider: LLMProviderBase = None,
                  **kwargs) -> Union[LLMModel, 'ChatOpenAI']:
    """Get a unified LLM model instance.

    Args:
        conf: Agent configuration, if provided, create model based on configuration.
        custom_provider: Custom LLMProviderBase instance, if provided, use it directly.
        **kwargs: Other parameters, may include:
            - base_url: Specify model endpoint.
            - api_key: API key.
            - model_name: Model name.
            - temperature: Temperature parameter.

    Returns:
        LLMModel, or a langchain ChatOpenAI instance when llm_provider is "chatopenai".
    """
    llm_provider = conf.llm_provider if conf_contains_key(
        conf, "llm_provider") else None

    # Special case: return a langchain ChatOpenAI client when requested
    if llm_provider == "chatopenai":
        from langchain_openai import ChatOpenAI

        base_url = kwargs.get("base_url") or (
            conf.llm_base_url if conf_contains_key(conf, "llm_base_url") else None)
        model_name = kwargs.get("model_name") or (
            conf.llm_model_name if conf_contains_key(conf, "llm_model_name") else None)
        api_key = kwargs.get("api_key") or (
            conf.llm_api_key if conf_contains_key(conf, "llm_api_key") else None)

        return ChatOpenAI(
            model=model_name,
            temperature=kwargs.get("temperature", conf.llm_temperature if conf_contains_key(
                conf, "llm_temperature") else 0.0),
            base_url=base_url,
            api_key=api_key,
        )

    return LLMModel(conf=conf, custom_provider=custom_provider, **kwargs)
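

# Hedged sketch: obtain a model from an AgentConfig. Field values are
# illustrative; AgentConfig's full schema lives in aworld.config.conf.
def _example_get_llm_model() -> LLMModel:
    conf = AgentConfig(
        llm_provider="openai",
        llm_model_name="gpt-4o",  # placeholder model name
        llm_api_key="sk-...",     # placeholder key
    )
    return get_llm_model(conf)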


def call_llm_model(
        llm_model: LLMModel,
        messages: List[Dict[str, str]],
        temperature: float = 0.0,
        max_tokens: int = None,
        stop: List[str] = None,
        stream: bool = False,
        **kwargs
) -> Union[ModelResponse, Generator[ModelResponse, None, None]]:
    """Convenience function to call LLM model.

    Args:
        llm_model: LLM model instance.
        messages: Message list.
        temperature: Temperature parameter.
        max_tokens: Maximum number of tokens to generate.
        stop: List of stop sequences.
        stream: Whether to return a streaming response.
        **kwargs: Other parameters.

    Returns:
        Model response or response generator.
    """
    if stream:
        return llm_model.stream_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )
    else:
        return llm_model.completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
        )
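

# Hedged sketch: with stream=True, call_llm_model returns a generator of
# ModelResponse chunks. The prompt text is illustrative.
def _example_stream_consumption(model: LLMModel) -> None:
    for chunk in call_llm_model(
        model,
        messages=[{"role": "user", "content": "Stream a short reply."}],
        stream=True,
    ):
        print(chunk.content, end="", flush=True)  # assumed streaming content field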


async def acall_llm_model(
        llm_model: LLMModel,
        messages: List[Dict[str, str]],
        temperature: float = 0.0,
        max_tokens: int = None,
        stop: List[str] = None,
        stream: bool = False,
        **kwargs
) -> ModelResponse:
    """Convenience function to asynchronously call LLM model.

    Args:
        llm_model: LLM model instance.
        messages: Message list.
        temperature: Temperature parameter.
        max_tokens: Maximum number of tokens to generate.
        stop: List of stop sequences.
        stream: Accepted for signature parity with call_llm_model, but ignored
            here; use acall_llm_model_stream for async streaming.
        **kwargs: Other parameters.

    Returns:
        ModelResponse: Model response.
    """
    return await llm_model.acompletion(
        messages=messages,
        temperature=temperature,
        max_tokens=max_tokens,
        stop=stop,
        **kwargs
    )


async def acall_llm_model_stream(
        llm_model: LLMModel,
        messages: List[Dict[str, str]],
        temperature: float = 0.0,
        max_tokens: int = None,
        stop: List[str] = None,
        **kwargs
) -> AsyncGenerator[ModelResponse, None]:
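    """Convenience function to asynchronously stream responses from an LLM model.

    Args:
        llm_model: LLM model instance.
        messages: Message list.
        temperature: Temperature parameter.
        max_tokens: Maximum number of tokens to generate.
        stop: List of stop sequences.
        **kwargs: Other parameters.

    Yields:
        ModelResponse: Streamed response chunks.
    """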
    async for chunk in llm_model.astream_completion(
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            stop=stop,
            **kwargs
    ):
        yield chunk
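

# Hedged sketch: consuming the async streaming helper. The prompt text is
# illustrative; drive it with e.g. asyncio.run(_example_astream_consumption(m)).
async def _example_astream_consumption(model: LLMModel) -> None:
    async for chunk in acall_llm_model_stream(
        model,
        messages=[{"role": "user", "content": "Stream asynchronously."}],
    ):
        print(chunk.content, end="", flush=True)  # assumed streaming content field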


def speech_to_text(
        llm_model: LLMModel,
        audio_file: str,
        language: str = None,
        prompt: str = None,
        **kwargs
) -> ModelResponse:
    """Convenience function to convert speech to text.

    Args:
        llm_model: LLM model instance.
        audio_file: Path to audio file or file object.
        language: Audio language, optional.
        prompt: Transcription prompt, optional.
        **kwargs: Other parameters.

    Returns:
        ModelResponse: Unified model response object, with content field containing the transcription result.
    """
    if llm_model.provider_name != "openai":
        raise NotImplementedError(
            f"Speech-to-text is currently only supported by the OpenAI-compatible provider; current provider: {llm_model.provider_name}")

    return llm_model.speech_to_text(
        audio_file=audio_file,
        language=language,
        prompt=prompt,
        **kwargs
    )


async def aspeech_to_text(
        llm_model: LLMModel,
        audio_file: str,
        language: str = None,
        prompt: str = None,
        **kwargs
) -> ModelResponse:
    """Convenience function to asynchronously convert speech to text.

    Args:
        llm_model: LLM model instance.
        audio_file: Path to audio file or file object.
        language: Audio language, optional.
        prompt: Transcription prompt, optional.
        **kwargs: Other parameters.

    Returns:
        ModelResponse: Unified model response object, with content field containing the transcription result.
    """
    if llm_model.provider_name != "openai":
        raise NotImplementedError(
            f"Speech-to-text is currently only supported by the OpenAI-compatible provider; current provider: {llm_model.provider_name}")

    return await llm_model.aspeech_to_text(
        audio_file=audio_file,
        language=language,
        prompt=prompt,
        **kwargs
    )
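

# Hedged sketch: transcribe a local audio file via the convenience wrapper
# above. The file path and language are placeholders; non-OpenAI providers
# raise NotImplementedError, per the provider check in speech_to_text.
def _example_speech_to_text(model: LLMModel) -> str:
    result = speech_to_text(model, audio_file="sample.wav", language="en")
    return result.content  # transcription text, per the docstring above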