File size: 3,036 Bytes
5318b09
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import time
from typing import Type, Union, Dict, Any
from pydantic import BaseModel
import dirtyjson
import re
# Make sure you install dirtyjson: pip install dirtyjson

# === Optionally, import your Groq client from where you configure it ===

# === Helper function ===

def call_llm_and_parse(
    groq_client,
    prompt: str,
    model: Type[BaseModel],
    max_retries: int = 3,
    delay: float = 1.0
) -> Union[BaseModel, Dict[str, Any]]:
    """
    Call the LLM with a prompt, parse the JSON reply, and validate it with a Pydantic model.

    Each attempt sends the prompt, extracts a JSON object from the reply,
    parses it with dirtyjson and validates the result with
    ``model.model_validate``. Any exception raised during an attempt
    (API error, extraction, parsing or validation) triggers a retry.

    Args:
        groq_client: Client object exposing ``chat.completions.create(...)``.
        prompt (str): The prompt to send to the LLM.
        model (Type[BaseModel]): The Pydantic model to validate against.
        max_retries (int, optional): Total number of attempts. Default is 3.
        delay (float, optional): Base delay (in seconds) between attempts;
            the actual sleep is ``delay * attempt``.

    Returns:
        BaseModel: Validated Pydantic model instance if successful.
        dict: Contains 'error' and 'raw' fields if every attempt failed.
            'raw' is the extracted JSON string of the final attempt, or the
            raw response text if extraction failed, or None if the LLM call
            itself failed (or no attempt was made).
    """
    if max_retries < 1:
        # Without this guard the loop body never runs and the function would
        # silently return None, contradicting the declared return type.
        return {
            "error": f"No attempts made: max_retries must be >= 1 (got {max_retries})",
            "raw": None
        }

    for attempt in range(1, max_retries + 1):
        # Reset per attempt: the failure report must describe *this* attempt,
        # and both names must exist even if the API call itself raises.
        response_text = None
        json_str = None
        try:
            print(f"[call_llm_and_parse] Attempt {attempt}: sending prompt to LLM...")

            # NOTE: model name and sampling parameters are fixed here.
            completion = groq_client.chat.completions.create(
                model="llama3-8b-8192",
                messages=[{"role": "user", "content": prompt}],
                temperature=0.3,
                max_tokens=800
            )

            response_text = completion.choices[0].message.content
            print(f"[call_llm_and_parse] Raw LLM response: {response_text[:200]}...")  # first 200 chars

            # Extract JSON (handle dirty or partial JSON)
            json_str = extract_and_repair_json(response_text)

            # Parse JSON using dirtyjson
            parsed = dirtyjson.loads(json_str)

            # Validate with Pydantic
            validated = model.model_validate(parsed)

            print("[call_llm_and_parse] Successfully parsed and validated.")
            return validated

        except Exception as e:
            # Deliberately broad: any failure in the attempt is retryable.
            print(f"[Retry {attempt}] Error: {e}")
            if attempt < max_retries:
                # Linear back-off: wait longer after each failed attempt.
                time.sleep(delay * attempt)
                continue

            print("[call_llm_and_parse] Failed after retries.")
            return {
                "error": f"Validation failed after {max_retries} retries: {e}",
                # Prefer the extracted JSON; fall back to the raw reply
                # (None when the request never produced a response).
                "raw": json_str if json_str is not None else response_text
            }


def extract_and_repair_json(text: str) -> str:
    """
    Extract the first JSON object from ``text`` and close it if it is truncated.

    Scanning starts at the first '{'. Braces and square brackets are tracked
    on a stack, and characters inside quoted strings (single or double
    quotes, with backslash escapes) are ignored, so a '}' inside a string
    value does not affect balancing.

    - If the object is complete, the returned string ends at its matching
      '}', so anything after it (trailing prose, code fences) is dropped.
    - If the text ends first, an unterminated string is closed and the
      missing ']' / '}' closers are appended in the correct order.

    Args:
        text (str): Raw text that is expected to contain a JSON object.

    Returns:
        str: The extracted (and, if needed, repaired) JSON text.

    Raises:
        ValueError: If ``text`` contains no '{'.
    """
    start = text.find('{')
    if start == -1:
        raise ValueError("No JSON object found.")

    pending_closers = []   # closers still owed, innermost last
    quote_char = None      # active string delimiter, or None outside strings
    escaped = False        # True right after a backslash inside a string

    for idx in range(start, len(text)):
        ch = text[idx]

        if quote_char is not None:
            # Inside a string: only look for escapes and the closing quote.
            if escaped:
                escaped = False
            elif ch == '\\':
                escaped = True
            elif ch == quote_char:
                quote_char = None
            continue

        if ch == '"' or ch == "'":
            quote_char = ch
        elif ch == '{':
            pending_closers.append('}')
        elif ch == '[':
            pending_closers.append(']')
        elif ch == '}' or ch == ']':
            # A mismatched closer is left in place for the parser to judge.
            if pending_closers and pending_closers[-1] == ch:
                pending_closers.pop()
                if not pending_closers:
                    # Outermost object is complete: drop whatever follows.
                    return text[start:idx + 1]

    # Ran out of text before the object closed: repair the tail.
    json_str = text[start:]
    if quote_char is not None:
        if escaped:
            # A dangling backslash would escape the quote we are about to add.
            json_str = json_str[:-1]
        json_str += quote_char
    json_str += ''.join(reversed(pending_closers))
    return json_str