import os
import requests
from typing import Optional

# πŸ” Load HF credentials and endpoint URL from environment variables
HF_TOKEN = os.environ.get("HF_TOKEN")
MISTRAL_URL = os.environ.get("MISTRAL_URL")

# πŸ“œ Headers for HF Inference Endpoint
HEADERS = {
    "Authorization": f"Bearer {HF_TOKEN}",
    "Content-Type": "application/json"
}

# πŸ” Call Mistral using HF Inference Endpoint
def call_mistral(base_prompt: str, tail_prompt: str) -> Optional[str]:
    full_prompt = f"<s>[INST]{base_prompt}\n\n{tail_prompt}[/INST]</s>"
    payload = {
        "inputs": full_prompt
    }

    try:
        timeout = (10, 120)  # (connect, read) timeouts in seconds
        response = requests.post(MISTRAL_URL, headers=HEADERS, json=payload, timeout=timeout)
        response.raise_for_status()
        data = response.json()

        # The endpoint may return either a list of generations or a single dict
        raw_output = ""
        if isinstance(data, list) and data:
            raw_output = data[0].get("generated_text", "")
        elif isinstance(data, dict):
            raw_output = data.get("generated_text", "")

        if "[/INST]</s>" in raw_output:
            return raw_output.split("[/INST]</s>")[-1].strip()
        return raw_output.strip()

    except Exception as e:
        print(f"⚠️ Mistral error: {e}")
        return None
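
# Minimal usage sketch (assumption: this module is run directly with HF_TOKEN and
# MISTRAL_URL already exported; the example prompts below are purely illustrative).
if __name__ == "__main__":
    answer = call_mistral(
        base_prompt="You are a concise technical assistant.",
        tail_prompt="In one sentence, explain what an inference endpoint is.",
    )
    print(answer if answer is not None else "No response received from the endpoint.")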