Spaces:
Sleeping
Sleeping
import time | |
from typing import Type, Union, Dict, Any | |
from pydantic import BaseModel | |
import dirtyjson | |
import re | |
# Make sure you install dirtyjson: pip install dirtyjson | |
# === Optionally, import your Groq client from where you configure it === | |
# === Helper function === | |
def call_llm_and_parse(
    groq_client,
    prompt: str,
    model: Type[BaseModel],
    max_retries: int = 3,
    delay: float = 1.0,
    *,
    llm_model: str = "llama3-8b-8192",
    temperature: float = 0.3,
    max_tokens: int = 800,
) -> Union[BaseModel, Dict[str, Any]]:
    """
    Call an LLM with a prompt, parse the JSON response, and validate it with a Pydantic model.

    Any failure during an attempt (API error, missing or garbled JSON, validation
    error) is caught and the attempt is retried. After the last attempt fails,
    an error dict is returned instead of raising.

    Args:
        groq_client: Client object exposing ``chat.completions.create(...)``,
            called with ``model``, ``messages``, ``temperature`` and ``max_tokens``.
        prompt (str): The prompt to send to the LLM as a single user message.
        model (Type[BaseModel]): The Pydantic model to validate against.
        max_retries (int, optional): Total number of attempts. Default is 3.
            Values below 1 are treated as 1, so at least one call is made.
        delay (float, optional): Base delay in seconds. The wait after attempt
            ``k`` is ``delay * k`` (linear backoff).
        llm_model (str, optional): Model name passed to the API.
        temperature (float, optional): Sampling temperature passed to the API.
        max_tokens (int, optional): Completion token limit passed to the API.

    Returns:
        BaseModel: Validated Pydantic model instance if successful.
        dict: ``{"error": str, "raw": str | None}`` if every attempt failed.
            ``raw`` comes from the final attempt. It is the extracted JSON text
            if extraction succeeded, otherwise the raw response text. It is
            ``None`` if the API call itself failed.
    """
    attempts = max(1, max_retries)
    last_error = None
    last_raw = None

    for attempt in range(1, attempts + 1):
        # Reset per attempt so the error report never shows stale text
        # left over from an earlier attempt.
        response_text = None
        json_str = None
        try:
            print(f"[call_llm_and_parse] Attempt {attempt}: sending prompt to LLM...")
            completion = groq_client.chat.completions.create(
                model=llm_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            )
            # Message content may be None (e.g. an empty completion); normalise
            # to "" so the slice below cannot raise TypeError.
            response_text = completion.choices[0].message.content or ""
            print(f"[call_llm_and_parse] Raw LLM response: {response_text[:200]}...")  # first 200 chars
            # Extract JSON (handle dirty or partial JSON)
            json_str = extract_and_repair_json(response_text)
            # Parse leniently with dirtyjson, then validate with Pydantic.
            parsed = dirtyjson.loads(json_str)
            validated = model.model_validate(parsed)
            print("[call_llm_and_parse] Successfully parsed and validated.")
            return validated
        except Exception as e:  # deliberate best-effort boundary: every failure is retried
            last_error = e
            last_raw = json_str if json_str is not None else response_text
            print(f"[Retry {attempt}] Error: {e}")
            if attempt < attempts:
                time.sleep(delay * attempt)

    print("[call_llm_and_parse] Failed after retries.")
    return {
        "error": f"Validation failed after {attempts} retries: {last_error}",
        "raw": last_raw,
    }
def extract_and_repair_json(text: str) -> str:
    """
    Extract the first JSON object from ``text`` and repair simple truncation.

    Scanning starts at the first ``{``. Open ``{``/``[`` brackets are tracked on
    a stack while quoted strings are skipped. Both double and single quotes are
    recognised, and backslash escapes are honoured, so braces inside string
    values are not miscounted. Extraction stops at the brace that closes the
    first object, which drops any trailing prose or markdown fence after the
    JSON. If the text ends before the object is closed (e.g. the LLM hit its
    token limit), an open string is terminated and the missing ``]``/``}``
    closers are appended in correct nesting order.

    Args:
        text: Raw LLM output that may contain a JSON object.

    Returns:
        str: The extracted (and possibly repaired) JSON object text.

    Raises:
        ValueError: If ``text`` contains no ``{``.
    """
    begin = text.find("{")
    if begin == -1:
        raise ValueError("No JSON object found.")

    closers = {"{": "}", "[": "]"}
    stack = []       # currently open '{' / '[' characters, innermost last
    quote = None     # active string delimiter, or None when outside a string
    escaped = False  # previous character was a backslash inside a string

    for i, ch in enumerate(text[begin:], start=begin):
        if quote is not None:
            if escaped:
                escaped = False
            elif ch == "\\":
                escaped = True
            elif ch == quote:
                quote = None
        elif ch in ('"', "'"):
            quote = ch
        elif ch in closers:
            stack.append(ch)
        elif ch in ("}", "]"):
            # Only pop a matching opener; a stray closer is left for the parser.
            if stack and closers[stack[-1]] == ch:
                stack.pop()
                if not stack:
                    return text[begin:i + 1]

    # Text ended before the first object closed: repair the truncation.
    json_str = text[begin:]
    if quote is not None:
        if escaped:
            # A dangling backslash would escape the quote we are about to add.
            json_str = json_str[:-1]
        json_str += quote
    json_str += "".join(closers[opener] for opener in reversed(stack))
    return json_str