# Linkedin-Assistant / llm_utils.py
# Uploaded by tsrivallabh ("Upload 10 files", commit 5318b09, verified).
import time
from typing import Type, Union, Dict, Any
from pydantic import BaseModel
import dirtyjson
import re
# Make sure you install dirtyjson: pip install dirtyjson
# === Optionally, import your Groq client from where you configure it ===
# === Helper function ===
def call_llm_and_parse(
    groq_client,
    prompt: str,
    model: Type[BaseModel],
    max_retries: int = 3,
    delay: float = 1.0,
    *,
    llm_model: str = "llama3-8b-8192",
    temperature: float = 0.3,
    max_tokens: int = 800,
) -> Union[BaseModel, Dict[str, Any]]:
    """
    Call the LLM with a prompt, parse its JSON reply, and validate it.

    Each attempt sends ``prompt`` as a single user message, extracts the first
    JSON object from the reply with ``extract_and_repair_json``, parses it with
    ``dirtyjson`` and validates the result with ``model.model_validate``. Any
    exception raised along the way triggers a retry.

    Args:
        groq_client: Client object exposing ``chat.completions.create(...)``.
        prompt (str): The prompt to send to the LLM.
        model (Type[BaseModel]): The Pydantic model to validate against.
        max_retries (int, optional): Total number of attempts. Default is 3.
        delay (float, optional): Base delay in seconds between attempts; the
            actual sleep is ``delay * attempt``. Default is 1.0.
        llm_model (str, optional): Model name passed to the completions API.
        temperature (float, optional): Sampling temperature for the request.
        max_tokens (int, optional): Completion token cap for the request.

    Returns:
        BaseModel: Validated Pydantic model instance if successful.
        dict: ``{"error": ..., "raw": ...}`` when every attempt failed (or no
        attempt was made because ``max_retries < 1``). ``raw`` is the extracted
        JSON string of the last attempt if extraction succeeded, otherwise the
        raw response text, otherwise ``None`` when the request itself failed.
    """
    last_error: Any = "no attempts were made (max_retries < 1)"
    raw = None

    for attempt in range(1, max_retries + 1):
        # Reset per-attempt state so the failure report never shows text left
        # over from an earlier attempt (or references an unbound name).
        response_text = None
        json_str = None
        try:
            print(f"[call_llm_and_parse] Attempt {attempt}: sending prompt to LLM...")
            completion = groq_client.chat.completions.create(
                model=llm_model,
                messages=[{"role": "user", "content": prompt}],
                temperature=temperature,
                max_tokens=max_tokens,
            )
            response_text = completion.choices[0].message.content
            if not response_text:
                raise ValueError("LLM returned an empty response.")
            # Only the first 200 characters are logged to keep output short.
            print(f"[call_llm_and_parse] Raw LLM response: {response_text[:200]}...")

            # Extract JSON (handle dirty or partial JSON)
            json_str = extract_and_repair_json(response_text)
            # Parse JSON using dirtyjson
            parsed = dirtyjson.loads(json_str)
            # Validate with Pydantic
            validated = model.model_validate(parsed)
        except Exception as e:
            # Deliberately broad: network, parsing and validation failures are
            # all retried and finally reported through the returned dict.
            print(f"[Retry {attempt}] Error: {e}")
            last_error = e
            raw = json_str if json_str is not None else response_text
            if attempt < max_retries:
                # Linear back-off before the next attempt.
                time.sleep(delay * attempt)
        else:
            print("[call_llm_and_parse] Successfully parsed and validated.")
            return validated

    print("[call_llm_and_parse] Failed after retries.")
    return {
        "error": f"Validation failed after {max_retries} retries: {last_error}",
        "raw": raw,
    }
def extract_and_repair_json(text: str) -> str:
    """
    Extract the first JSON object from ``text`` and close it if truncated.

    Scanning starts at the first ``{``. Brackets inside string literals
    (single- or double-quoted, honouring backslash escapes) are ignored, so
    they never influence the balance. If the object closes, everything after
    its final ``}`` (trailing prose, code fences, ...) is dropped. If the text
    ends first, an open string is terminated and the missing ``}`` / ``]``
    closers are appended innermost-first.

    Args:
        text (str): Raw text that should contain a JSON object.

    Returns:
        str: The extracted (and, if needed, repaired) JSON text.

    Raises:
        TypeError: If ``text`` is not a string.
        ValueError: If ``text`` contains no ``{``.
    """
    if not isinstance(text, str):
        raise TypeError("expected string input")

    start = text.find('{')
    if start == -1:
        raise ValueError("No JSON object found.")

    body = text[start:]
    closer_for = {'{': '}', '[': ']'}
    pending = []      # closers still owed, innermost last
    quote = None      # active string delimiter, or None outside strings
    escaped = False   # True when the previous in-string char was a backslash

    for index, char in enumerate(body):
        if quote is not None:
            if escaped:
                escaped = False
            elif char == '\\':
                escaped = True
            elif char == quote:
                quote = None
            continue

        if char == '"' or char == "'":
            quote = char
        elif char in closer_for:
            pending.append(closer_for[char])
        elif pending and char == pending[-1]:
            pending.pop()
            if not pending:
                # The top-level object is complete; ignore whatever follows.
                return body[:index + 1]
        # A closer that does not match the innermost opener is left in place
        # for the (lenient) parser to deal with.

    # Input ended before the object closed: repair the tail.
    if quote is not None:
        if escaped:
            # Drop a dangling backslash so it cannot escape the added quote.
            body = body[:-1]
        body += quote
    return body + ''.join(reversed(pending))