|
import os |
|
import requests |
|
from typing import Optional |
|
|
|
|
|
# Hugging Face API token, read once when the module is imported;
# None if the HF_TOKEN environment variable is not set.
HF_TOKEN = os.getenv("HF_TOKEN")

# URL of the Mistral inference endpoint, also read once at import time;
# None if the MISTRAL_URL environment variable is not set.
MISTRAL_URL = os.getenv("MISTRAL_URL")

# Headers sent with every request. NOTE: the token is interpolated as-is,
# so an unset HF_TOKEN produces the literal header value "Bearer None".
HEADERS = {
    "Authorization": "Bearer {}".format(HF_TOKEN),
    "Content-Type": "application/json",
}
|
|
|
|
|
def call_mistral(base_prompt: str, tail_prompt: str) -> Optional[str]:
    """Send a two-part instruction prompt to the Mistral endpoint and return its reply.

    The two parts are joined with a blank line, wrapped in Mistral's
    instruction markers (``<s>[INST] ... [/INST]``) and POSTed as JSON
    ``{"inputs": prompt}`` to ``MISTRAL_URL`` using the module-level ``HEADERS``.

    Args:
        base_prompt: First part of the instruction.
        tail_prompt: Second part, appended after a blank line.

    Returns:
        The generated text with any echoed prompt removed and surrounding
        whitespace stripped; ``""`` if the response contains no usable
        ``generated_text`` string; ``None`` if ``MISTRAL_URL`` is unset or the
        request / JSON decoding fails (the error is printed, not raised).
    """
    # Mistral's instruct format is ``<s>[INST] ... [/INST]``. The end-of-sequence
    # token ``</s>`` belongs after the model's answer; placing it right after
    # ``[/INST]`` hands the model an already-terminated sequence.
    full_prompt = f"<s>[INST]{base_prompt}\n\n{tail_prompt}[/INST]"
    payload = {"inputs": full_prompt}
    timeout = (10, 120)  # (connect, read) timeouts in seconds for requests

    if not MISTRAL_URL:
        print("⚠️ Mistral error: MISTRAL_URL is not set")
        return None

    # Best-effort boundary: transport errors, non-2xx statuses and undecodable
    # JSON are reported and turned into None instead of reaching the caller.
    try:
        response = requests.post(
            MISTRAL_URL, headers=HEADERS, json=payload, timeout=timeout
        )
        response.raise_for_status()
        data = response.json()
    except Exception as e:
        print(f"⚠️ Mistral error: {e}")
        return None

    # Accept either ``[{"generated_text": ...}, ...]`` or a bare
    # ``{"generated_text": ...}``; any other shape counts as "no text".
    item = data[0] if isinstance(data, list) and data else data
    raw_output = item.get("generated_text") if isinstance(item, dict) else None
    if not isinstance(raw_output, str):
        return ""

    # The endpoint may echo the prompt in front of the completion; drop it.
    if raw_output.startswith(full_prompt):
        raw_output = raw_output[len(full_prompt):]
    elif "[/INST]" in raw_output:
        # Fallback for an echo that does not match byte-for-byte (e.g. special
        # tokens re-rendered): keep only what follows the last ``[/INST]``.
        raw_output = raw_output.rpartition("[/INST]")[2]

    return raw_output.strip()
|
|