# model_handler.py
import os
import json
import logging
from collections.abc import Iterator

import requests
from dotenv import load_dotenv

# Load environment variables from .env file
load_dotenv()

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s"
)
logger = logging.getLogger(__name__)

# Maps provider name (uppercase) to environment variable name for API key
API_KEYS_ENV_VARS = {
  "HUGGINGFACE": 'HF_TOKEN', # Note: HF_TOKEN is often used for general HF auth
  "GROQ": 'GROQ_API_KEY',
  "OPENROUTER": 'OPENROUTER_API_KEY',
  "TOGETHERAI": 'TOGETHERAI_API_KEY',
  "COHERE": 'COHERE_API_KEY',
  "XAI": 'XAI_API_KEY',
  "OPENAI": 'OPENAI_API_KEY',
  "GOOGLE": 'GOOGLE_API_KEY', # Or GOOGLE_GEMINI_API_KEY etc.
}

API_URLS = {
  "HUGGINGFACE": 'https://api-inference.huggingface.co/models/',
  "GROQ": 'https://api.groq.com/openai/v1/chat/completions',
  "OPENROUTER": 'https://openrouter.ai/api/v1/chat/completions',
  "TOGETHERAI": 'https://api.together.ai/v1/chat/completions',
  "COHERE": 'https://api.cohere.ai/v1/chat', # v1 is common for chat, was v2 in ai-learn
  "XAI": 'https://api.x.ai/v1/chat/completions',
  "OPENAI": 'https://api.openai.com/v1/chat/completions',
  "GOOGLE": 'https://generativelanguage.googleapis.com/v1beta/models/',
}

with open("./models.json") as f:
    MODELS_BY_PROVIDER = json.load(f)
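
# The loader above assumes a models.json of roughly this shape (made-up values,
# not the real file): lowercase provider keys, each with a "models" map of
# display name -> model ID and an optional "default" model ID.
#
# {
#   "groq": {
#     "default": "llama-3.1-8b-instant",
#     "models": {
#       "Llama 3.1 8B (fast)": "llama-3.1-8b-instant"
#     }
#   }
# }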


def _get_api_key(provider: str, ui_api_key_override: str | None = None) -> str | None:
    """
    Retrieves API key for a given provider.
    Priority: UI Override > Environment Variable from API_KEYS_ENV_VARS > Specific (e.g. HF_TOKEN for HuggingFace).
    """
    provider_upper = provider.upper()
    if ui_api_key_override and ui_api_key_override.strip():
        logger.debug(f"Using UI-provided API key for {provider_upper}.")
        return ui_api_key_override.strip()

    env_var_name = API_KEYS_ENV_VARS.get(provider_upper)
    if env_var_name:
        env_key = os.getenv(env_var_name)
        if env_key and env_key.strip():
            logger.debug(f"Using API key from env var '{env_var_name}' for {provider_upper}.")
            return env_key.strip()

    # Specific fallback for HuggingFace if HF_TOKEN is set and API_KEYS_ENV_VARS['HUGGINGFACE'] wasn't specific enough
    if provider_upper == 'HUGGINGFACE':
         hf_token_fallback = os.getenv("HF_TOKEN")
         if hf_token_fallback and hf_token_fallback.strip():
             logger.debug("Using HF_TOKEN as fallback for HuggingFace provider.")
             return hf_token_fallback.strip()

    logger.warning(f"API Key not found for provider '{provider_upper}'. Checked UI override and environment variable '{env_var_name or 'N/A'}'.")
    return None

def get_available_providers() -> list[str]:
    """Returns a sorted list of available provider names (e.g., 'groq', 'openai')."""
    return sorted(list(MODELS_BY_PROVIDER.keys()))

def get_model_display_names_for_provider(provider: str) -> list[str]:
    """Returns a sorted list of model display names for a given provider."""
    return sorted(list(MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {}).keys()))

def get_default_model_display_name_for_provider(provider: str) -> str | None:
    """Gets the default model's display name for a provider."""
    provider_data = MODELS_BY_PROVIDER.get(provider.lower(), {})
    models_dict = provider_data.get("models", {})
    default_model_id = provider_data.get("default")

    if default_model_id and models_dict:
        for display_name, model_id_val in models_dict.items():
            if model_id_val == default_model_id:
                return display_name
    
    # Fallback: return the first listed model if the default is missing or unset.
    if models_dict:
        return next(iter(models_dict))
    return None

def get_model_id_from_display_name(provider: str, display_name: str) -> str | None:
    """Gets the actual model ID from its display name for a given provider."""
    models = MODELS_BY_PROVIDER.get(provider.lower(), {}).get("models", {})
    return models.get(display_name)


def call_model_stream(
    provider: str,
    model_display_name: str,
    messages: list[dict],
    api_key_override: str | None = None,
    temperature: float = 0.7,
    max_tokens: int | None = None,
) -> Iterator[str]:
    """
    Calls the specified model via its provider and streams the response.
    Handles provider-specific request formatting and error handling.
    Yields chunks of the response text or an error string.
    """
    provider_lower = provider.lower()
    api_key = _get_api_key(provider_lower, api_key_override)
    base_url = API_URLS.get(provider.upper())
    model_id = get_model_id_from_display_name(provider_lower, model_display_name)

    if not api_key:
        env_var_name = API_KEYS_ENV_VARS.get(provider.upper(), 'N/A')
        yield f"Error: API Key not found for {provider}. Please set it in the UI or env var '{env_var_name}'."
        return
    if not base_url:
        yield f"Error: Unknown provider '{provider}' or missing API URL configuration."
        return
    if not model_id:
         yield f"Error: Model ID not found for '{model_display_name}' under provider '{provider}'. Check configuration."
         return

    headers = {}
    payload = {}
    request_url = base_url

    logger.info(f"Streaming from {provider}/{model_display_name} (ID: {model_id})...")
    
    # --- Standard OpenAI-compatible providers ---
    if provider_lower in ["groq", "openrouter", "togetherai", "openai", "xai"]:
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        payload = {"model": model_id, "messages": messages, "stream": True, "temperature": temperature}
        if max_tokens: payload["max_tokens"] = max_tokens

        if provider_lower == "openrouter":
             headers["HTTP-Referer"] = os.getenv("OPENROUTER_REFERRER") or "http://localhost/gradio" # Example Referer
             headers["X-Title"] = os.getenv("OPENROUTER_X_TITLE") or "Gradio AI Researcher"      # Example Title

        try:
            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            # More robust SSE parsing
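            # Expected wire format (OpenAI-style SSE; payloads illustrative):
            #
            #   data: {"choices":[{"delta":{"content":"Hel"}}]}
            #
            #   data: {"choices":[{"delta":{"content":"lo"}}]}
            #
            #   data: [DONE]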
            buffer = ""
            for chunk in response.iter_content(chunk_size=None): # Process raw bytes
                buffer += chunk.decode('utf-8', errors='replace')
                while '\n\n' in buffer:
                    event_str, buffer = buffer.split('\n\n', 1)
                    if not event_str.strip(): continue

                    content_chunk = ""
                    for line in event_str.splitlines():
                        if line.startswith('data: '):
                            data_json = line[len('data: '):].strip()
                            if data_json == '[DONE]':
                                return # Stream finished
                            try:
                                data = json.loads(data_json)
                                if data.get("choices") and len(data["choices"]) > 0:
                                    delta = data["choices"][0].get("delta", {})
                                    if delta and delta.get("content"):
                                        content_chunk += delta["content"]
                            except json.JSONDecodeError:
                                logger.warning(f"Failed to decode JSON from stream line: {data_json}")
                    if content_chunk:
                        yield content_chunk
            # Process any remaining buffer content (less common with '\n\n' delimiter)
            if buffer.strip():
                logger.debug(f"Remaining buffer after OpenAI-like stream: {buffer}")


        except requests.exceptions.HTTPError as e:
            err_msg = f"API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
            logger.error(f"{err_msg} for {provider}/{model_id}", exc_info=False)
            yield f"Error: {err_msg}"
        except requests.exceptions.RequestException as e:
            logger.error(f"API Request Error for {provider}/{model_id}: {e}", exc_info=False)
            yield f"Error: Could not connect to {provider} ({e})"
        except Exception as e:
            logger.exception(f"Unexpected error during {provider} stream:")
            yield f"Error: An unexpected error occurred: {e}"
        return

    # --- Google Gemini ---
    elif provider_lower == "google":
        system_instruction = None
        filtered_messages = []
        for msg in messages:
            if msg["role"] == "system": system_instruction = {"parts": [{"text": msg["content"]}]}
            else:
                role = "model" if msg["role"] == "assistant" else msg["role"]
                filtered_messages.append({"role": role, "parts": [{"text": msg["content"]}]})

        payload = {
             "contents": filtered_messages,
             "safetySettings": [ # Example: more permissive settings
                 {"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT", "threshold": "BLOCK_NONE"},
                 {"category": "HARM_CATEGORY_HATE_SPEECH", "threshold": "BLOCK_NONE"},
                 {"category": "HARM_CATEGORY_HARASSMENT", "threshold": "BLOCK_NONE"},
                 {"category": "HARM_CATEGORY_DANGEROUS_CONTENT", "threshold": "BLOCK_NONE"},
             ],
             "generationConfig": {"temperature": temperature}
        }
        if max_tokens: payload["generationConfig"]["maxOutputTokens"] = max_tokens
        if system_instruction: payload["system_instruction"] = system_instruction
        
        request_url = f"{base_url}{model_id}:streamGenerateContent?key={api_key}" # API key in query param
        headers = {"Content-Type": "application/json"}

        try:
            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()
            
            # Google's stream is not SSE: it is a JSON array of objects, and
            # chunks can split mid-object (or arrive pretty-printed), which
            # breaks naive line splitting. The parser below assumes
            # newline-delimited objects; this covers the common case but is a
            # known weak point.
            buffer = ""
            for chunk in response.iter_content(chunk_size=None):
                buffer += chunk.decode('utf-8', errors='replace')
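                # Illustrative shape of one streamed object (text made up), as
                # consumed by the candidates/parts walk below:
                #
                #   {"candidates": [{"content": {"parts": [{"text": "Hel"}], "role": "model"}}]}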
                
                while '\n' in buffer:
                    line, buffer = buffer.split('\n', 1)
                    line = line.strip()
                    if not line: continue
                    if line.startswith(','): line = line[1:] # Handle leading commas if splitting an array

                    try:
                        # Remove "data: " prefix if present (less common for Gemini direct API but good practice)
                        if line.startswith('data: '): line = line[len('data: '):]
                        
                        # Gemini often streams an array of objects, or just one object.
                        # Try to parse as a single object first. If fails, try as array.
                        parsed_data = None
                        try:
                            parsed_data = json.loads(line)
                        except json.JSONDecodeError:
                            # Fragile heuristic: fragments of the streamed array can
                            # arrive without enclosing brackets; try wrapping the line
                            # as an array. A stateful JSON parser would be more robust.
                            if line.startswith('{') or line.endswith('}'):
                                try:
                                    temp_parsed_list = json.loads(f"[{line}]")
                                    if temp_parsed_list and isinstance(temp_parsed_list, list):
                                        parsed_data = temp_parsed_list  # keep all objects; the loop below handles lists
                                except json.JSONDecodeError:
                                    logger.warning(f"Google: Still can't parse line even with array wrap: {line}")

                        if parsed_data:
                            data_to_process = [parsed_data] if isinstance(parsed_data, dict) else parsed_data # Ensure list
                            for event_data in data_to_process:
                                if not isinstance(event_data, dict): continue
                                if event_data.get("candidates"):
                                    for candidate in event_data["candidates"]:
                                        if candidate.get("content", {}).get("parts"):
                                            for part in candidate["content"]["parts"]:
                                                if part.get("text"):
                                                    yield part["text"]
                    except json.JSONDecodeError:
                        logger.warning(f"Google: JSONDecodeError for line: {line}")
                    except Exception as e_google_proc:
                        logger.error(f"Google: Error processing stream data: {e_google_proc}, Line: {line}")

        except requests.exceptions.HTTPError as e:
            err_msg = f"Google API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
            logger.error(err_msg, exc_info=False)
            yield f"Error: {err_msg}"
        except Exception as e:
            logger.exception("Unexpected error during Google stream:")
            yield f"Error: An unexpected error occurred with Google API: {e}"
        return

    # --- Cohere ---
    elif provider_lower == "cohere":
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json", "Accept": "application/json"}
        
        # Cohere message format
        chat_history_cohere = []
        preamble_cohere = None
        user_message_cohere = ""

        temp_messages = list(messages) # Work with a copy
        if temp_messages and temp_messages[0]["role"] == "system":
            preamble_cohere = temp_messages.pop(0)["content"]
        
        if temp_messages:
            user_message_cohere = temp_messages.pop()["content"] # Last message is the current user query
            for msg in temp_messages: # Remaining are history
                role = "USER" if msg["role"] == "user" else "CHATBOT"
                chat_history_cohere.append({"role": role, "message": msg["content"]})
        
        if not user_message_cohere:
            yield "Error: User message is empty for Cohere."
            return

        payload = {
            "model": model_id, 
            "message": user_message_cohere, 
            "stream": True, 
            "temperature": temperature
        }
        if max_tokens: payload["max_tokens"] = max_tokens # Cohere uses max_tokens
        if chat_history_cohere: payload["chat_history"] = chat_history_cohere
        if preamble_cohere: payload["preamble"] = preamble_cohere
        
        try:
            response = requests.post(base_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()
            
            # Cohere SSE format is event: type\ndata: {json}\n\n
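            # e.g. (illustrative events):
            #
            #   event: text-generation
            #   data: {"text": "Hel"}
            #
            #   event: stream-end
            #   data: {"finish_reason": "COMPLETE"}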
            buffer = ""
            for chunk_bytes in response.iter_content(chunk_size=None):
                buffer += chunk_bytes.decode('utf-8', errors='replace')
                while '\n\n' in buffer:
                    event_str, buffer = buffer.split('\n\n', 1)
                    if not event_str.strip(): continue
                    
                    event_type = None
                    data_json_str = None
                    for line in event_str.splitlines():
                        if line.startswith("event:"): event_type = line[len("event:"):].strip()
                        elif line.startswith("data:"): data_json_str = line[len("data:"):].strip()
                    
                    if data_json_str:
                        try:
                            data = json.loads(data_json_str)
                            if event_type == "text-generation" and "text" in data:
                                yield data["text"]
                            elif event_type == "stream-end":
                                logger.debug(f"Cohere stream ended. Finish reason: {data.get('finish_reason')}")
                                return 
                        except json.JSONDecodeError:
                            logger.warning(f"Cohere: Failed to decode JSON: {data_json_str}")
            if buffer.strip():
                 logger.debug(f"Cohere: Remaining buffer: {buffer.strip()}")


        except requests.exceptions.HTTPError as e:
            err_msg = f"Cohere API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
            logger.error(err_msg, exc_info=False)
            yield f"Error: {err_msg}"
        except Exception as e:
            logger.exception("Unexpected error during Cohere stream:")
            yield f"Error: An unexpected error occurred with Cohere API: {e}"
        return

    # --- HuggingFace Inference API (Basic TGI support) ---
    # This is very basic and might not work for all models or complex scenarios.
    # Assumes model is deployed with Text Generation Inference (TGI) and supports streaming.
    elif provider_lower == "huggingface":
        headers = {"Authorization": f"Bearer {api_key}", "Content-Type": "application/json"}
        # Construct prompt string for TGI (often needs specific formatting)
        # This is a generic attempt, specific models might need <|user|>, <|assistant|> etc.
        prompt_parts = []
        for msg in messages:
            role_prefix = ""
            if msg['role'] == 'system': role_prefix = "System: " # Or might be ignored/handled differently
            elif msg['role'] == 'user': role_prefix = "User: "
            elif msg['role'] == 'assistant': role_prefix = "Assistant: "
            prompt_parts.append(f"{role_prefix}{msg['content']}")
        
        # TGI typically expects a final "Assistant: " to start generating from
        tgi_prompt = "\n".join(prompt_parts) + "\nAssistant: "
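
        # e.g. for a [system, user] conversation this renders as (illustrative):
        #
        #   System: You are helpful.
        #   User: Hi
        #   Assistant: 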

        payload = {
            "inputs": tgi_prompt,
            "parameters": {
                "temperature": temperature if temperature > 0 else 0.01, # TGI needs temp > 0 for sampling
                "max_new_tokens": max_tokens or 1024, # Default TGI max_new_tokens
                "return_full_text": False, # We only want generated part
                "do_sample": True if temperature > 0 else False,
            },
            "stream": True
        }
        request_url = f"{base_url}{model_id}" # Model ID is part of URL path for HF

        try:
            response = requests.post(request_url, headers=headers, json=payload, stream=True, timeout=180)
            response.raise_for_status()

            # TGI SSE stream: data: {"token": {"id": ..., "text": "...", "logprob": ..., "special": ...}}
            # Or sometimes just data: "text_chunk" for simpler models/configs
            buffer = ""
            for chunk_bytes in response.iter_content(chunk_size=None):
                buffer += chunk_bytes.decode('utf-8', errors='replace')
                while '\n' in buffer: # TGI often uses single newline
                    line, buffer = buffer.split('\n', 1)
                    line = line.strip()
                    if not line: continue

                    if line.startswith('data:'):
                        data_json_str = line[len('data:'):].strip()
                        try:
                            data = json.loads(data_json_str)
                            if "token" in data and "text" in data["token"]:
                                yield data["token"]["text"]
                            elif "generated_text" in data and data.get("details") is None: # Sometimes a final non-streaming like object might appear
                                 # This case is tricky, if it's the *only* thing then it's not really streaming
                                 pass # For now, ignore if it's not a token object
                            # Some TGI might send raw text if not fully SSE compliant for stream
                            # elif isinstance(data, str): yield data 

                        except json.JSONDecodeError:
                            # If it's not JSON, it might be a raw string (less common for TGI stream=True)
                            # For safety, only yield if it's a clear text string
                            if not data_json_str.startswith('{') and not data_json_str.startswith('['):
                                yield data_json_str
                            else:
                                logger.warning(f"HF: Failed to decode JSON and not raw string: {data_json_str}")
            if buffer.strip():
                 logger.debug(f"HF: Remaining buffer: {buffer.strip()}")


        except requests.exceptions.HTTPError as e:
            err_msg = f"HF API HTTP Error ({e.response.status_code}): {e.response.text[:500]}"
            logger.error(err_msg, exc_info=False)
            yield f"Error: {err_msg}"
        except Exception as e:
            logger.exception("Unexpected error during HF stream:")
            yield f"Error: An unexpected error occurred with HF API: {e}"
        return

    else:
        yield f"Error: Provider '{provider}' is not configured for streaming in this handler."
        return
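

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the handler itself). It picks
    # the first configured provider and its default model, and assumes the
    # matching API key is available in the environment (or .env).
    providers = get_available_providers()
    if not providers:
        print("No providers configured in models.json.")
    else:
        provider = providers[0]
        model_name = get_default_model_display_name_for_provider(provider)
        demo_messages = [
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "Say hello in one short sentence."},
        ]
        print(f"Streaming from {provider}/{model_name}:\n")
        for chunk in call_model_stream(provider, model_name, demo_messages, max_tokens=64):
            print(chunk, end="", flush=True)
        print()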