# doctorecord/src/services/llm_client.py
# Author: levalencia
# Commit: feat: enhance LLMClient with structured output capabilities (c77b1c5)
"""Azure-OpenAI wrapper that exposes the *Responses* API.
Keeps the rest of the codebase insulated from SDK / vendor details.
"""
from __future__ import annotations
from typing import Any, List, Dict, Optional
import time
import random
import json
import openai
from openai import AzureOpenAI
import logging
class LLMClient:
    """Thin wrapper around Azure OpenAI using both Responses and Chat Completions APIs."""

    def __init__(self, settings):
        """Configure Azure OpenAI access and cache deployment/retry settings.

        Sets the module-level ``openai`` globals (used by the Responses API
        calls in :meth:`responses`) and builds a dedicated ``AzureOpenAI``
        client (used by the structured-output methods).

        Args:
            settings: Configuration object exposing ``OPENAI_API_KEY`` /
                ``AZURE_OPENAI_API_KEY``, ``AZURE_OPENAI_ENDPOINT``,
                ``AZURE_OPENAI_API_VERSION``, ``AZURE_OPENAI_DEPLOYMENT``
                and the retry knobs ``LLM_MAX_RETRIES`` / ``LLM_BASE_DELAY``
                / ``LLM_MAX_DELAY``.
        """
        # Configure the global client for Azure (for Responses API)
        openai.api_type = "azure"
        # Prefer OPENAI_API_KEY, fall back to AZURE_OPENAI_API_KEY.
        openai.api_key = settings.OPENAI_API_KEY or settings.AZURE_OPENAI_API_KEY
        openai.api_base = settings.AZURE_OPENAI_ENDPOINT
        openai.api_version = settings.AZURE_OPENAI_API_VERSION
        # Create Azure OpenAI client for structured output
        self.azure_client = AzureOpenAI(
            azure_endpoint=settings.AZURE_OPENAI_ENDPOINT,
            api_key=settings.OPENAI_API_KEY or settings.AZURE_OPENAI_API_KEY,
            api_version=settings.AZURE_OPENAI_API_VERSION
        )
        # Deployment name used as the `model` argument on every call.
        self._deployment = settings.AZURE_OPENAI_DEPLOYMENT
        # Retry policy shared by all public call methods.
        self._max_retries = settings.LLM_MAX_RETRIES
        self._base_delay = settings.LLM_BASE_DELAY
        self._max_delay = settings.LLM_MAX_DELAY
        # Log configuration (without exposing the API key)
        logger = logging.getLogger(__name__)
        logger.info("Azure OpenAI Configuration:")
        logger.info(f"API Type: {openai.api_type}")
        logger.info(f"API Base: {openai.api_base}")
        logger.info(f"API Version from settings: {settings.AZURE_OPENAI_API_VERSION}")
        logger.info(f"API Version in openai client: {openai.api_version}")
        logger.info(f"Deployment: {self._deployment}")
        logger.info(f"API Key present: {'Yes' if openai.api_key else 'No'}")
        logger.info(f"API Key length: {len(openai.api_key) if openai.api_key else 0}")
        logger.info(f"Retry config: max_retries={self._max_retries}, base_delay={self._base_delay}s, max_delay={self._max_delay}s")
def _should_retry(self, exception) -> bool:
"""Determine if an exception should trigger a retry."""
# Retry on 503 Service Unavailable, 500 Internal Server Error, and other server errors
if hasattr(exception, 'status_code'):
return exception.status_code >= 500
# Also retry on connection errors and timeouts
if hasattr(exception, '__class__'):
error_type = exception.__class__.__name__
return any(error in error_type for error in ['Timeout', 'Connection', 'Network'])
return False
def _exponential_backoff(self, attempt: int, base_delay: float = 1.0, max_delay: float = 60.0) -> float:
"""Calculate delay for exponential backoff with jitter."""
delay = min(base_delay * (2 ** attempt), max_delay)
# Add jitter to prevent thundering herd
jitter = random.uniform(0, 0.1 * delay)
return delay + jitter
def _create_structured_output_schema(self, fields: List[str]) -> Dict[str, Any]:
"""Create a dynamic JSON schema for structured output based on the fields."""
properties = {}
required = []
for field in fields:
# Create a property for each field that can contain an array of values
properties[field] = {
"type": "array",
"items": {
"type": ["string", "null"]
},
"description": f"Array of values for the field '{field}'"
}
required.append(field)
return {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False
}
def _create_combinations_schema(self, unique_indices: List[str]) -> Dict[str, Any]:
"""Create a dynamic JSON schema for unique combinations output."""
# Define properties for each unique index
properties = {}
required = []
for index in unique_indices:
properties[index] = {
"type": "string",
"description": f"Value for the unique index '{index}'"
}
required.append(index)
# Return schema with root object containing a combinations array
# Azure OpenAI structured output requires root schema to be an object
return {
"type": "object",
"properties": {
"combinations": {
"type": "array",
"items": {
"type": "object",
"properties": properties,
"required": required,
"additionalProperties": False
},
"description": "Array of unique combinations of indices"
}
},
"required": ["combinations"],
"additionalProperties": False
}
# --------------------------------------------------
    def responses(self, prompt: str, tools: List[dict] | None = None, description: str = "LLM Call",
                  max_retries: Optional[int] = None, base_delay: Optional[float] = None, **kwargs: Any) -> str:
        """Call the Responses API and return the assistant content as string.

        Retries transient failures with exponential backoff (see
        ``_should_retry`` / ``_exponential_backoff``) and, when a ``ctx``
        mapping containing a ``"cost_tracker"`` is supplied via kwargs,
        records token usage before returning.

        Args:
            prompt: Text sent as the request ``input``.
            tools: Optional tool definitions forwarded to the API.
            description: Label attached to the cost-tracker entry.
            max_retries: Per-call override of the instance retry count.
            base_delay: Per-call override of the instance base backoff delay.
            **kwargs: Extra API parameters; ``ctx`` is popped here and NOT
                forwarded to the SDK.

        Returns:
            The first text fragment found in the response output, or a
            best-effort fallback attribute, or ``str(resp)``.

        Raises:
            Exception: The last API error once retries are exhausted or when
                the error is classified as non-retryable.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Making request with API version: {openai.api_version}")
        logger.info(f"Request URL will be: {openai.api_base}/openai/responses?api-version={openai.api_version}")
        # Use instance defaults if not provided
        max_retries = max_retries if max_retries is not None else self._max_retries
        base_delay = base_delay if base_delay is not None else self._base_delay
        # Remove ctx from kwargs before passing to openai
        ctx = kwargs.pop("ctx", None)
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                resp = openai.responses.create(
                    input=prompt,
                    model=self._deployment,
                    tools=tools or [],
                    **kwargs,
                )
                # Log the raw response for debugging
                logging.debug(f"LLM raw response: {resp}")
                # --- Cost tracking: must be BEFORE any return! ---
                logger.info(f"LLMClient.responses: ctx is {ctx}")
                if ctx and "cost_tracker" in ctx:
                    logger.info(f"LLMClient.responses: cost_tracker is {ctx['cost_tracker']}")
                    usage = getattr(resp, "usage", None)
                    if usage:
                        logger.info(f"LLMClient.responses: usage is {usage}")
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=getattr(usage, "input_tokens", 0),
                            output_tokens=getattr(usage, "output_tokens", 0),
                            description=description
                        )
                    logger.info(f"LLMClient.responses: prompt: {prompt[:200]}...")  # Log first 200 chars
                    logger.info(f"LLMClient.responses: resp: {str(resp)[:200]}...")  # Log first 200 chars
                    if usage:
                        logger.info(f"LLMClient.responses: usage.input_tokens={getattr(usage, 'input_tokens', None)}, usage.output_tokens={getattr(usage, 'output_tokens', None)}, usage.total_tokens={getattr(usage, 'total_tokens', None)}")
                    else:
                        # Fallback: estimate tokens (very rough)
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=len(prompt.split()),
                            output_tokens=len(str(resp).split()),
                            description=description
                        )
                # Extract the text content from the response
                if hasattr(resp, "output") and isinstance(resp.output, list):
                    # Handle list of ResponseOutputMessage objects
                    for message in resp.output:
                        if hasattr(message, "content") and isinstance(message.content, list):
                            for content in message.content:
                                # First text fragment wins.
                                if hasattr(content, "text"):
                                    return content.text
                # Fallback methods if the above doesn't work
                if hasattr(resp, "output"):
                    return resp.output
                elif hasattr(resp, "response"):
                    return resp.response
                elif hasattr(resp, "content"):
                    return resp.content
                elif hasattr(resp, "data"):
                    return resp.data
                else:
                    logging.error(f"Could not extract text from response: {resp}")
                    return str(resp)
            except Exception as e:
                last_exception = e
                logger.warning(f"Attempt {attempt + 1}/{max_retries + 1} failed: {type(e).__name__}: {str(e)}")
                # Check if we should retry
                if attempt < max_retries and self._should_retry(e):
                    delay = self._exponential_backoff(attempt, base_delay, self._max_delay)
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    time.sleep(delay)
                    continue
                else:
                    # Either we've exhausted retries or this is not a retryable error
                    if attempt >= max_retries:
                        logger.error(f"Max retries ({max_retries}) exceeded. Last error: {type(e).__name__}: {str(e)}")
                    else:
                        logger.error(f"Non-retryable error: {type(e).__name__}: {str(e)}")
                    raise last_exception
# --------------------------------------------------
    def structured_responses(self, prompt: str, fields: List[str], description: str = "Structured LLM Call",
                             max_retries: Optional[int] = None, base_delay: Optional[float] = None, **kwargs: Any) -> Dict[str, Any]:
        """Call the Azure OpenAI Chat Completions API with structured output and return parsed JSON.

        Builds a JSON schema from *fields* (each field maps to an array of
        nullable strings; see ``_create_structured_output_schema``) and
        enforces it via the ``json_schema`` response format. Retries with
        exponential backoff, tracks token costs via ``ctx["cost_tracker"]``
        when available, and degrades gracefully: structured parse -> raw
        content parsed as JSON -> plain ``responses`` call parsed as JSON ->
        empty per-field result. Never raises on API failure.

        Args:
            prompt: The user prompt.
            fields: Field names the model must extract.
            description: Label attached to the cost-tracker entry.
            max_retries: Per-call override of the instance retry count.
            base_delay: Per-call override of the instance base backoff delay.
            **kwargs: Extra API parameters; ``ctx`` is popped here.

        Returns:
            The parsed result dict, or ``{field: [] for field in fields}``
            on unrecoverable failure.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Making structured request for fields: {fields}")
        # Use instance defaults if not provided
        max_retries = max_retries if max_retries is not None else self._max_retries
        base_delay = base_delay if base_delay is not None else self._base_delay
        # Remove ctx from kwargs before passing to openai
        ctx = kwargs.pop("ctx", None)
        # Create the structured output schema
        schema = self._create_structured_output_schema(fields)
        logger.debug(f"Using schema: {json.dumps(schema, indent=2)}")
        # Create the response format for structured output
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "field_extraction_schema",
                "description": "Schema for extracting structured field data",
                "schema": schema
            }
        }
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                # Use Azure OpenAI Chat Completions API with structured output
                completion = self.azure_client.beta.chat.completions.parse(
                    model=self._deployment,
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    response_format=response_format,
                    **kwargs,
                )
                # Log the raw response for debugging
                logging.debug(f"Structured LLM raw response: {completion}")
                # --- Cost tracking: must be BEFORE any return! ---
                if ctx and "cost_tracker" in ctx:
                    usage = getattr(completion, "usage", None)
                    if usage:
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=getattr(usage, "prompt_tokens", 0),
                            output_tokens=getattr(usage, "completion_tokens", 0),
                            description=description
                        )
                    else:
                        # Fallback: estimate tokens (very rough)
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=len(prompt.split()),
                            output_tokens=len(str(completion).split()),
                            description=description
                        )
                # Extract the structured output from the response
                if hasattr(completion, "choices") and len(completion.choices) > 0:
                    choice = completion.choices[0]
                    if hasattr(choice, "message"):
                        # First try to get the parsed structured output
                        if hasattr(choice.message, "parsed") and choice.message.parsed is not None:
                            result = choice.message.parsed
                            logger.info(f"Successfully parsed structured output: {json.dumps(result, indent=2)}")
                            return result
                        # If parsed is None but content exists, try to parse the content as JSON
                        elif hasattr(choice.message, "content") and choice.message.content:
                            logger.info("Parsed field is None, attempting to parse content as JSON")
                            try:
                                result = json.loads(choice.message.content)
                                logger.info(f"Successfully parsed JSON from content: {json.dumps(result, indent=2)}")
                                return result
                            except json.JSONDecodeError as json_error:
                                # Fall through to the text-extraction fallback below.
                                logger.warning(f"Failed to parse content as JSON: {json_error}")
                                logger.debug(f"Content was: {choice.message.content}")
                        else:
                            logger.warning("No parsed output or content found in message")
                # Fallback: try to extract from text if structured output failed
                logger.warning("Structured output not found, falling back to text extraction")
                text_response = self.responses(prompt, description=description, ctx=ctx, **kwargs)
                try:
                    # Try to parse the text response as JSON
                    result = json.loads(text_response)
                    logger.info(f"Successfully parsed fallback JSON: {json.dumps(result, indent=2)}")
                    return result
                except json.JSONDecodeError:
                    logger.error(f"Failed to parse fallback response as JSON: {text_response}")
                    # Return empty result with the expected structure
                    empty_result = {field: [] for field in fields}
                    logger.warning(f"Returning empty result: {empty_result}")
                    return empty_result
            except Exception as e:
                last_exception = e
                logger.warning(f"Attempt {attempt + 1}/{max_retries + 1} failed: {type(e).__name__}: {str(e)}")
                # Check if we should retry
                if attempt < max_retries and self._should_retry(e):
                    delay = self._exponential_backoff(attempt, base_delay, self._max_delay)
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    time.sleep(delay)
                    continue
                else:
                    # Either we've exhausted retries or this is not a retryable error
                    if attempt >= max_retries:
                        logger.error(f"Max retries ({max_retries}) exceeded. Last error: {type(e).__name__}: {str(e)}")
                    else:
                        logger.error(f"Non-retryable error: {type(e).__name__}: {str(e)}")
                    # Return empty result on final failure (last_exception is
                    # deliberately NOT raised here: callers get an empty dict).
                    empty_result = {field: [] for field in fields}
                    logger.warning(f"Returning empty result due to error: {empty_result}")
                    return empty_result
# --------------------------------------------------
    def structured_combinations(self, prompt: str, unique_indices: List[str], description: str = "Structured Combinations Call",
                                max_retries: Optional[int] = None, base_delay: Optional[float] = None, **kwargs: Any) -> str:
        """Call the Azure OpenAI Chat Completions API with structured output for unique combinations and return JSON string.

        Uses a strict ``json_schema`` response format built by
        ``_create_combinations_schema`` (root object wrapping a
        "combinations" array). Retries with exponential backoff and tracks
        token costs via ``ctx["cost_tracker"]`` when available. Extraction
        order: ``message.parsed`` -> ``message.content`` parsed as JSON
        (object-with-"combinations" or legacy bare array) -> plain
        ``responses`` fallback.

        Args:
            prompt: The user prompt.
            unique_indices: Names of the indices each combination must contain.
            description: Label attached to the cost-tracker entry.
            max_retries: Per-call override of the instance retry count.
            base_delay: Per-call override of the instance base backoff delay.
            **kwargs: Extra parameters; ``ctx`` is popped; ``temperature``
                (default 0.0) is the only other key forwarded to the API call.

        Returns:
            A pretty-printed JSON string of the combinations array, or the
            raw fallback response text.

        Raises:
            Exception: The last API error, only if the final fallback to
                ``responses`` itself fails.
        """
        logger = logging.getLogger(__name__)
        logger.info(f"Making structured combinations request for indices: {unique_indices}")
        # Use instance defaults if not provided
        max_retries = max_retries if max_retries is not None else self._max_retries
        base_delay = base_delay if base_delay is not None else self._base_delay
        # Remove ctx from kwargs before passing to openai
        ctx = kwargs.pop("ctx", None)
        # Create the structured output schema for combinations
        schema = self._create_combinations_schema(unique_indices)
        logger.debug(f"Using combinations schema: {json.dumps(schema, indent=2)}")
        # Create the response format for structured output
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "unique_combinations_schema",
                "description": "Schema for extracting unique combinations of indices",
                "schema": schema,
                "strict": True
            }
        }
        logger.info(f"Using response format: {json.dumps(response_format, indent=2)}")
        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                # Use Azure OpenAI Chat Completions API with structured output
                # NOTE(review): this uses the non-beta `create` (unlike
                # structured_responses' beta `parse`), so the returned message
                # presumably has no `parsed` attribute and the content-parsing
                # branch below is the one actually taken — confirm against SDK.
                completion = self.azure_client.chat.completions.create(
                    model=self._deployment,
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    response_format=response_format,
                    temperature=kwargs.get("temperature", 0.0),
                )
                # Log the raw response for debugging
                logging.debug(f"Structured combinations LLM raw response: {completion}")
                # --- Cost tracking: must be BEFORE any return! ---
                if ctx and "cost_tracker" in ctx:
                    usage = getattr(completion, "usage", None)
                    if usage:
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=getattr(usage, "prompt_tokens", 0),
                            output_tokens=getattr(usage, "completion_tokens", 0),
                            description=description
                        )
                        logger.info(f"Structured combinations costs - Input tokens: {ctx['cost_tracker'].llm_input_tokens}, Output tokens: {ctx['cost_tracker'].llm_output_tokens}")
                        logger.info(f"Structured combinations cost: ${ctx['cost_tracker'].calculate_current_file_costs()['openai']['total_cost']:.4f}")
                    else:
                        # Fallback: estimate tokens (very rough)
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=len(prompt.split()),
                            output_tokens=len(str(completion).split()),
                            description=description
                        )
                # Extract the structured output from the response
                if hasattr(completion, "choices") and len(completion.choices) > 0:
                    choice = completion.choices[0]
                    if hasattr(choice, "message"):
                        # First try to get the parsed structured output
                        if hasattr(choice.message, "parsed") and choice.message.parsed is not None:
                            parsed_result = choice.message.parsed
                            logger.info(f"Successfully got parsed structured combinations output")
                            logger.debug(f"Parsed result: {parsed_result}")
                            # Extract the combinations array from the structured response
                            if isinstance(parsed_result, dict) and "combinations" in parsed_result:
                                combinations = parsed_result["combinations"]
                                logger.info(f"Successfully extracted {len(combinations)} unique combinations from structured output")
                                # Log the first combination as an example
                                if combinations and len(combinations) > 0:
                                    logger.info(f"Example combination: {json.dumps(combinations[0], indent=2)}")
                                # Return the combinations array as JSON string (formatted)
                                return json.dumps(combinations, indent=2)
                            else:
                                logger.warning(f"Unexpected parsed result structure: {parsed_result}")
                        # Fallback to content parsing if parsed is None
                        elif hasattr(choice.message, "content") and choice.message.content:
                            content = choice.message.content
                            logger.info(f"Parsed field is None, attempting to parse content as JSON")
                            logger.debug(f"Raw content: {content}")
                            # Validate that it's valid JSON
                            try:
                                parsed_json = json.loads(content)
                                logger.info(f"Successfully parsed JSON from content")
                                # Extract combinations array if it exists in the expected structure
                                if isinstance(parsed_json, dict) and "combinations" in parsed_json:
                                    combinations = parsed_json["combinations"]
                                    logger.info(f"Successfully extracted {len(combinations)} unique combinations from content")
                                    # Log the first combination as an example
                                    if combinations and len(combinations) > 0:
                                        logger.info(f"Example combination: {json.dumps(combinations[0], indent=2)}")
                                    # Return the combinations array as JSON string (formatted)
                                    return json.dumps(combinations, indent=2)
                                # Fallback: if it's already an array (old format), return as-is
                                elif isinstance(parsed_json, list):
                                    logger.info(f"Content is already an array format with {len(parsed_json)} combinations")
                                    return json.dumps(parsed_json, indent=2)
                                else:
                                    logger.warning(f"Unexpected JSON structure in content: {parsed_json}")
                            except json.JSONDecodeError as json_error:
                                logger.warning(f"Failed to parse content as JSON: {json_error}")
                                logger.debug(f"Content was: {content}")
                        else:
                            logger.warning("No parsed output or content found in message")
                # Fallback: try to extract from text if structured output failed
                logger.warning("Structured output not found, falling back to regular responses method")
                fallback_response = self.responses(prompt, description=description, ctx=ctx, **kwargs)
                logger.info("Fallback to regular responses method successful")
                return fallback_response
            except Exception as e:
                last_exception = e
                logger.warning(f"Attempt {attempt + 1}/{max_retries + 1} failed: {type(e).__name__}: {str(e)}")
                # Check if we should retry
                if attempt < max_retries and self._should_retry(e):
                    delay = self._exponential_backoff(attempt, base_delay, self._max_delay)
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    time.sleep(delay)
                    continue
                else:
                    # Either we've exhausted retries or this is not a retryable error
                    if attempt >= max_retries:
                        logger.error(f"Max retries ({max_retries}) exceeded. Last error: {type(e).__name__}: {str(e)}")
                    else:
                        logger.error(f"Non-retryable error: {type(e).__name__}: {str(e)}")
                    # Fallback to regular responses method on final failure
                    logger.warning("Final fallback to regular responses method")
                    try:
                        fallback_response = self.responses(prompt, description=description, ctx=ctx, **kwargs)
                        logger.info("Final fallback successful")
                        return fallback_response
                    except Exception as fallback_error:
                        logger.error(f"Final fallback also failed: {fallback_error}")
                        raise last_exception