|
""" |
|
This module handles the creation of API clients for the application. |
|
It includes the logic for instantiating the Hugging Face InferenceClient |
|
and the Tavily Search client. |
|
|
|
The get_inference_client function is critical for enabling the "user-pays" |
|
model in a Hugging Face Space. It prioritizes the API token of the logged-in |
|
user, ensuring their account is billed for inference costs. |
|
""" |
|
|
|
import os |
|
from typing import Optional |
|
|
|
from huggingface_hub import InferenceClient |
|
from tavily import TavilyClient |
|
|
|
|
|
|
|
|
|
|
|
# Space-level token read from the environment when the module is imported.
# It is None when the HF_TOKEN environment variable is not set.
HF_TOKEN = os.getenv('HF_TOKEN')
|
|
|
def get_inference_client(
    model_id: str,
    provider: str = "auto",
    user_token: Optional[str] = None,
) -> InferenceClient:
    """
    Creates and returns a Hugging Face InferenceClient.

    This function implements the "user-pays" logic. It prioritizes using the token
    provided by the logged-in user (`user_token`). If that is not available,
    it falls back to the Space owner's token (`HF_TOKEN`).

    Args:
        model_id (str): The ID of the model to be used
            (e.g., "mistralai/Mistral-7B-Instruct-v0.2"). Within this function it
            is only used to pick a provider override; it is not passed to the
            client itself.
        provider (str): The specific inference provider (e.g., "groq").
            Defaults to "auto". Replaced by "groq" when `model_id` is
            "moonshotai/Kimi-K2-Instruct".
        user_token (Optional[str]): The API token of the logged-in user, passed
            from the Gradio app.

    Returns:
        InferenceClient: An initialized client ready for making API calls.

    Raises:
        ValueError: If no API token can be found (neither from the user nor the
            environment).
    """
    # An empty or missing user token falls through to the module-level HF_TOKEN.
    token_to_use = user_token or HF_TOKEN

    if not token_to_use:
        raise ValueError(
            "Cannot proceed without an API token. "
            "Please log into Hugging Face, or ensure the HF_TOKEN environment secret is set for this Space."
        )

    # This model is always routed through Groq, regardless of the requested provider.
    if model_id == "moonshotai/Kimi-K2-Instruct":
        provider = "groq"

    return InferenceClient(
        provider=provider,
        api_key=token_to_use,
    )
|
|
|
|
|
|
|
|
|
|
|
# Optional Tavily search client. `tavily_client` is left as None when
# TAVILY_API_KEY is not set or when constructing the client fails.
TAVILY_API_KEY = os.getenv('TAVILY_API_KEY')
tavily_client = None

if TAVILY_API_KEY:
    try:
        tavily_client = TavilyClient(api_key=TAVILY_API_KEY)
    except Exception as e:
        # Best-effort: a failure here is reported and the module keeps loading
        # without web search instead of raising at import time.
        print(f"Warning: Failed to initialize Tavily client. Web search will be unavailable. Error: {e}")
        tavily_client = None