"""Azure-OpenAI wrapper that exposes the *Responses* API. | |
Keeps the rest of the codebase insulated from SDK / vendor details. | |
""" | |
from __future__ import annotations

import json
import logging
import random
import time
from typing import Any, Dict, List, Optional

import openai
from openai import AzureOpenAI

class LLMClient:
    """Thin wrapper around Azure OpenAI using both the Responses and Chat Completions APIs."""

    def __init__(self, settings):
        # Configure the module-level client for Azure (used by the Responses API)
        openai.api_type = "azure"
        openai.api_key = settings.OPENAI_API_KEY or settings.AZURE_OPENAI_API_KEY
        openai.api_base = settings.AZURE_OPENAI_ENDPOINT
        openai.api_version = settings.AZURE_OPENAI_API_VERSION

        # Dedicated Azure OpenAI client for structured output
        self.azure_client = AzureOpenAI(
            azure_endpoint=settings.AZURE_OPENAI_ENDPOINT,
            api_key=settings.OPENAI_API_KEY or settings.AZURE_OPENAI_API_KEY,
            api_version=settings.AZURE_OPENAI_API_VERSION,
        )

        self._deployment = settings.AZURE_OPENAI_DEPLOYMENT
        self._max_retries = settings.LLM_MAX_RETRIES
        self._base_delay = settings.LLM_BASE_DELAY
        self._max_delay = settings.LLM_MAX_DELAY

        # Log configuration (without exposing the API key itself)
        logger = logging.getLogger(__name__)
        logger.info("Azure OpenAI Configuration:")
        logger.info(f"API Type: {openai.api_type}")
        logger.info(f"API Base: {openai.api_base}")
        logger.info(f"API Version: {openai.api_version}")
        logger.info(f"Deployment: {self._deployment}")
        logger.info(f"API Key present: {'Yes' if openai.api_key else 'No'}")
        logger.info(f"API Key length: {len(openai.api_key) if openai.api_key else 0}")
        logger.info(f"Retry config: max_retries={self._max_retries}, base_delay={self._base_delay}s, max_delay={self._max_delay}s")

    def _should_retry(self, exception) -> bool:
        """Determine whether an exception should trigger a retry."""
        # Retry on server errors: 500 Internal Server Error, 503 Service Unavailable, etc.
        if hasattr(exception, 'status_code'):
            return exception.status_code >= 500
        # Also retry on connection errors and timeouts, matched by exception class name
        error_type = exception.__class__.__name__
        return any(marker in error_type for marker in ('Timeout', 'Connection', 'Network'))
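
    # A sketch of how this classifies the openai SDK's exceptions (APIStatusError
    # carries `status_code`; APITimeoutError / APIConnectionError match by name):
    #
    #   a 503 APIStatusError      -> status_code >= 500 -> True  (retry)
    #   a 400 BadRequestError     -> status_code < 500  -> False (no retry)
    #   an APITimeoutError        -> 'Timeout' in name  -> True  (retry)
    #   an APIConnectionError     -> 'Connection' in name -> True (retry)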

    def _exponential_backoff(self, attempt: int, base_delay: float = 1.0, max_delay: float = 60.0) -> float:
        """Calculate the delay for exponential backoff with jitter."""
        delay = min(base_delay * (2 ** attempt), max_delay)
        # Add jitter (up to 10% of the delay) to prevent a thundering herd
        jitter = random.uniform(0, 0.1 * delay)
        return delay + jitter
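
    # With the defaults, the pre-jitter delay doubles per attempt and is then
    # capped: attempt 0 -> 1s, 1 -> 2s, 2 -> 4s, ..., attempt 6 and beyond -> 60s,
    # with up to 10% random jitter added on top of each value.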

    def _create_structured_output_schema(self, fields: List[str]) -> Dict[str, Any]:
        """Create a dynamic JSON schema for structured output based on the given fields."""
        properties = {}
        required = []
        for field in fields:
            # Each field maps to an array of (possibly null) string values
            properties[field] = {
                "type": "array",
                "items": {
                    "type": ["string", "null"]
                },
                "description": f"Array of values for the field '{field}'"
            }
            required.append(field)
        return {
            "type": "object",
            "properties": properties,
            "required": required,
            "additionalProperties": False
        }
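
    # For example, _create_structured_output_schema(["invoice_id", "amount"])
    # (illustrative field names) yields:
    #
    # {
    #   "type": "object",
    #   "properties": {
    #     "invoice_id": {"type": "array", "items": {"type": ["string", "null"]},
    #                    "description": "Array of values for the field 'invoice_id'"},
    #     "amount":     {"type": "array", "items": {"type": ["string", "null"]},
    #                    "description": "Array of values for the field 'amount'"}
    #   },
    #   "required": ["invoice_id", "amount"],
    #   "additionalProperties": false
    # }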

    def _create_combinations_schema(self, unique_indices: List[str]) -> Dict[str, Any]:
        """Create a dynamic JSON schema for unique combinations output."""
        # Define properties for each unique index
        properties = {}
        required = []
        for index in unique_indices:
            properties[index] = {
                "type": "string",
                "description": f"Value for the unique index '{index}'"
            }
            required.append(index)
        # Return a schema whose root object contains a combinations array;
        # Azure OpenAI structured output requires the root schema to be an object
        return {
            "type": "object",
            "properties": {
                "combinations": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": properties,
                        "required": required,
                        "additionalProperties": False
                    },
                    "description": "Array of unique combinations of indices"
                }
            },
            "required": ["combinations"],
            "additionalProperties": False
        }
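
    # For example, _create_combinations_schema(["country", "year"]) (illustrative
    # index names) produces a root object whose "combinations" array holds items
    # like {"country": "...", "year": "..."}, with both keys required and no
    # additional properties allowed.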

    # --------------------------------------------------
    def responses(self, prompt: str, tools: List[dict] | None = None, description: str = "LLM Call",
                  max_retries: Optional[int] = None, base_delay: Optional[float] = None, **kwargs: Any) -> str:
        """Call the Responses API and return the assistant content as a string."""
        logger = logging.getLogger(__name__)
        logger.info(f"Making request with API version: {openai.api_version}")
        logger.info(f"Request URL will be: {openai.api_base}/openai/responses?api-version={openai.api_version}")

        # Fall back to instance defaults if not provided
        max_retries = max_retries if max_retries is not None else self._max_retries
        base_delay = base_delay if base_delay is not None else self._base_delay

        # Remove ctx from kwargs before passing the rest to the SDK
        ctx = kwargs.pop("ctx", None)

        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                resp = openai.responses.create(
                    input=prompt,
                    model=self._deployment,
                    tools=tools or [],
                    **kwargs,
                )
                logger.debug(f"LLM raw response: {resp}")

                # --- Cost tracking: must happen BEFORE any return! ---
                if ctx and "cost_tracker" in ctx:
                    usage = getattr(resp, "usage", None)
                    if usage:
                        logger.info(
                            f"LLMClient.responses: usage.input_tokens={getattr(usage, 'input_tokens', None)}, "
                            f"usage.output_tokens={getattr(usage, 'output_tokens', None)}, "
                            f"usage.total_tokens={getattr(usage, 'total_tokens', None)}"
                        )
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=getattr(usage, "input_tokens", 0),
                            output_tokens=getattr(usage, "output_tokens", 0),
                            description=description
                        )
                    else:
                        # Fallback: estimate tokens from word counts (very rough)
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=len(prompt.split()),
                            output_tokens=len(str(resp).split()),
                            description=description
                        )
                    logger.info(f"LLMClient.responses: prompt: {prompt[:200]}...")  # First 200 chars
                    logger.info(f"LLMClient.responses: resp: {str(resp)[:200]}...")  # First 200 chars

                # Extract the text content from the response
                if hasattr(resp, "output") and isinstance(resp.output, list):
                    # Handle a list of ResponseOutputMessage objects
                    for message in resp.output:
                        if hasattr(message, "content") and isinstance(message.content, list):
                            for content in message.content:
                                if hasattr(content, "text"):
                                    return content.text
                # Fallbacks if the structured walk above finds nothing
                if hasattr(resp, "output"):
                    return resp.output
                elif hasattr(resp, "response"):
                    return resp.response
                elif hasattr(resp, "content"):
                    return resp.content
                elif hasattr(resp, "data"):
                    return resp.data
                else:
                    logger.error(f"Could not extract text from response: {resp}")
                    return str(resp)
            except Exception as e:
                last_exception = e
                logger.warning(f"Attempt {attempt + 1}/{max_retries + 1} failed: {type(e).__name__}: {str(e)}")
                if attempt < max_retries and self._should_retry(e):
                    delay = self._exponential_backoff(attempt, base_delay, self._max_delay)
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    time.sleep(delay)
                    continue
                # Either retries are exhausted or the error is not retryable
                if attempt >= max_retries:
                    logger.error(f"Max retries ({max_retries}) exceeded. Last error: {type(e).__name__}: {str(e)}")
                else:
                    logger.error(f"Non-retryable error: {type(e).__name__}: {str(e)}")
                raise last_exception
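
    # Usage sketch (the cost tracker is any object exposing add_llm_tokens(),
    # per the ctx contract above; names here are illustrative):
    #
    #   text = client.responses(
    #       "Summarize this document...",
    #       description="Document summary",
    #       ctx={"cost_tracker": tracker},
    #   )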

    # --------------------------------------------------
    def structured_responses(self, prompt: str, fields: List[str], description: str = "Structured LLM Call",
                             max_retries: Optional[int] = None, base_delay: Optional[float] = None, **kwargs: Any) -> Dict[str, Any]:
        """Call the Azure OpenAI Chat Completions API with structured output and return the parsed JSON."""
        logger = logging.getLogger(__name__)
        logger.info(f"Making structured request for fields: {fields}")

        # Fall back to instance defaults if not provided
        max_retries = max_retries if max_retries is not None else self._max_retries
        base_delay = base_delay if base_delay is not None else self._base_delay

        # Remove ctx from kwargs before passing the rest to the SDK
        ctx = kwargs.pop("ctx", None)

        # Build the structured-output schema and response format
        schema = self._create_structured_output_schema(fields)
        logger.debug(f"Using schema: {json.dumps(schema, indent=2)}")
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "field_extraction_schema",
                "description": "Schema for extracting structured field data",
                "schema": schema
            }
        }

        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                # Chat Completions with structured output
                completion = self.azure_client.beta.chat.completions.parse(
                    model=self._deployment,
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    response_format=response_format,
                    **kwargs,
                )
                logger.debug(f"Structured LLM raw response: {completion}")

                # --- Cost tracking: must happen BEFORE any return! ---
                if ctx and "cost_tracker" in ctx:
                    usage = getattr(completion, "usage", None)
                    if usage:
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=getattr(usage, "prompt_tokens", 0),
                            output_tokens=getattr(usage, "completion_tokens", 0),
                            description=description
                        )
                    else:
                        # Fallback: estimate tokens from word counts (very rough)
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=len(prompt.split()),
                            output_tokens=len(str(completion).split()),
                            description=description
                        )

                # Extract the structured output from the response
                if hasattr(completion, "choices") and len(completion.choices) > 0:
                    choice = completion.choices[0]
                    if hasattr(choice, "message"):
                        # Prefer the parsed structured output
                        if hasattr(choice.message, "parsed") and choice.message.parsed is not None:
                            result = choice.message.parsed
                            logger.info(f"Successfully parsed structured output: {json.dumps(result, indent=2)}")
                            return result
                        # If parsed is None but content exists, try to parse the content as JSON
                        elif hasattr(choice.message, "content") and choice.message.content:
                            logger.info("Parsed field is None, attempting to parse content as JSON")
                            try:
                                result = json.loads(choice.message.content)
                                logger.info(f"Successfully parsed JSON from content: {json.dumps(result, indent=2)}")
                                return result
                            except json.JSONDecodeError as json_error:
                                logger.warning(f"Failed to parse content as JSON: {json_error}")
                                logger.debug(f"Content was: {choice.message.content}")
                        else:
                            logger.warning("No parsed output or content found in message")

                # Fallback: retry as a plain text call and parse the result as JSON
                logger.warning("Structured output not found, falling back to text extraction")
                text_response = self.responses(prompt, description=description, ctx=ctx, **kwargs)
                try:
                    result = json.loads(text_response)
                    logger.info(f"Successfully parsed fallback JSON: {json.dumps(result, indent=2)}")
                    return result
                except json.JSONDecodeError:
                    logger.error(f"Failed to parse fallback response as JSON: {text_response}")
                    # Return an empty result with the expected structure
                    empty_result = {field: [] for field in fields}
                    logger.warning(f"Returning empty result: {empty_result}")
                    return empty_result
            except Exception as e:
                last_exception = e
                logger.warning(f"Attempt {attempt + 1}/{max_retries + 1} failed: {type(e).__name__}: {str(e)}")
                if attempt < max_retries and self._should_retry(e):
                    delay = self._exponential_backoff(attempt, base_delay, self._max_delay)
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    time.sleep(delay)
                    continue
                # Either retries are exhausted or the error is not retryable
                if attempt >= max_retries:
                    logger.error(f"Max retries ({max_retries}) exceeded. Last error: {type(e).__name__}: {str(e)}")
                else:
                    logger.error(f"Non-retryable error: {type(e).__name__}: {str(e)}")
                # Return an empty result on final failure
                empty_result = {field: [] for field in fields}
                logger.warning(f"Returning empty result due to error: {empty_result}")
                return empty_result
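
    # Example return shape for fields=["invoice_id", "amount"] (illustrative
    # field names): {"invoice_id": ["INV-001"], "amount": ["42.50"]}. On failure
    # the method degrades to {"invoice_id": [], "amount": []} rather than raising.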

    # --------------------------------------------------
    def structured_combinations(self, prompt: str, unique_indices: List[str], description: str = "Structured Combinations Call",
                                max_retries: Optional[int] = None, base_delay: Optional[float] = None, **kwargs: Any) -> str:
        """Call the Azure OpenAI Chat Completions API with structured output for unique combinations and return a JSON string."""
        logger = logging.getLogger(__name__)
        logger.info(f"Making structured combinations request for indices: {unique_indices}")

        # Fall back to instance defaults if not provided
        max_retries = max_retries if max_retries is not None else self._max_retries
        base_delay = base_delay if base_delay is not None else self._base_delay

        # Remove ctx from kwargs before passing the rest to the SDK
        ctx = kwargs.pop("ctx", None)

        # Build the structured-output schema and response format for combinations
        schema = self._create_combinations_schema(unique_indices)
        logger.debug(f"Using combinations schema: {json.dumps(schema, indent=2)}")
        response_format = {
            "type": "json_schema",
            "json_schema": {
                "name": "unique_combinations_schema",
                "description": "Schema for extracting unique combinations of indices",
                "schema": schema,
                "strict": True
            }
        }
        logger.info(f"Using response format: {json.dumps(response_format, indent=2)}")

        last_exception = None
        for attempt in range(max_retries + 1):
            try:
                # Chat Completions with structured output; `create` (unlike `parse`)
                # returns the JSON in message.content, so the content branch below
                # is the usual path
                completion = self.azure_client.chat.completions.create(
                    model=self._deployment,
                    messages=[
                        {"role": "user", "content": prompt}
                    ],
                    response_format=response_format,
                    temperature=kwargs.get("temperature", 0.0),
                )
                logger.debug(f"Structured combinations LLM raw response: {completion}")

                # --- Cost tracking: must happen BEFORE any return! ---
                if ctx and "cost_tracker" in ctx:
                    usage = getattr(completion, "usage", None)
                    if usage:
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=getattr(usage, "prompt_tokens", 0),
                            output_tokens=getattr(usage, "completion_tokens", 0),
                            description=description
                        )
                        logger.info(f"Structured combinations costs - Input tokens: {ctx['cost_tracker'].llm_input_tokens}, Output tokens: {ctx['cost_tracker'].llm_output_tokens}")
                        logger.info(f"Structured combinations cost: ${ctx['cost_tracker'].calculate_current_file_costs()['openai']['total_cost']:.4f}")
                    else:
                        # Fallback: estimate tokens from word counts (very rough)
                        ctx["cost_tracker"].add_llm_tokens(
                            input_tokens=len(prompt.split()),
                            output_tokens=len(str(completion).split()),
                            description=description
                        )

                # Extract the structured output from the response
                if hasattr(completion, "choices") and len(completion.choices) > 0:
                    choice = completion.choices[0]
                    if hasattr(choice, "message"):
                        # Prefer a parsed structured output if one is present
                        if hasattr(choice.message, "parsed") and choice.message.parsed is not None:
                            parsed_result = choice.message.parsed
                            logger.info("Successfully got parsed structured combinations output")
                            logger.debug(f"Parsed result: {parsed_result}")
                            if isinstance(parsed_result, dict) and "combinations" in parsed_result:
                                combinations = parsed_result["combinations"]
                                logger.info(f"Successfully extracted {len(combinations)} unique combinations from structured output")
                                if combinations:
                                    logger.info(f"Example combination: {json.dumps(combinations[0], indent=2)}")
                                # Return the combinations array as a formatted JSON string
                                return json.dumps(combinations, indent=2)
                            else:
                                logger.warning(f"Unexpected parsed result structure: {parsed_result}")
                        # Otherwise parse the message content as JSON
                        elif hasattr(choice.message, "content") and choice.message.content:
                            content = choice.message.content
                            logger.info("Parsed field is None, attempting to parse content as JSON")
                            logger.debug(f"Raw content: {content}")
                            try:
                                parsed_json = json.loads(content)
                                logger.info("Successfully parsed JSON from content")
                                if isinstance(parsed_json, dict) and "combinations" in parsed_json:
                                    combinations = parsed_json["combinations"]
                                    logger.info(f"Successfully extracted {len(combinations)} unique combinations from content")
                                    if combinations:
                                        logger.info(f"Example combination: {json.dumps(combinations[0], indent=2)}")
                                    return json.dumps(combinations, indent=2)
                                # Fallback: the model returned a bare array (old format)
                                elif isinstance(parsed_json, list):
                                    logger.info(f"Content is already an array format with {len(parsed_json)} combinations")
                                    return json.dumps(parsed_json, indent=2)
                                else:
                                    logger.warning(f"Unexpected JSON structure in content: {parsed_json}")
                            except json.JSONDecodeError as json_error:
                                logger.warning(f"Failed to parse content as JSON: {json_error}")
                                logger.debug(f"Content was: {content}")
                        else:
                            logger.warning("No parsed output or content found in message")

                # Fallback: retry via the plain Responses API
                logger.warning("Structured output not found, falling back to regular responses method")
                fallback_response = self.responses(prompt, description=description, ctx=ctx, **kwargs)
                logger.info("Fallback to regular responses method successful")
                return fallback_response
            except Exception as e:
                last_exception = e
                logger.warning(f"Attempt {attempt + 1}/{max_retries + 1} failed: {type(e).__name__}: {str(e)}")
                if attempt < max_retries and self._should_retry(e):
                    delay = self._exponential_backoff(attempt, base_delay, self._max_delay)
                    logger.info(f"Retrying in {delay:.2f} seconds...")
                    time.sleep(delay)
                    continue
                # Either retries are exhausted or the error is not retryable
                if attempt >= max_retries:
                    logger.error(f"Max retries ({max_retries}) exceeded. Last error: {type(e).__name__}: {str(e)}")
                else:
                    logger.error(f"Non-retryable error: {type(e).__name__}: {str(e)}")
                # Final fallback: plain Responses API
                logger.warning("Final fallback to regular responses method")
                try:
                    fallback_response = self.responses(prompt, description=description, ctx=ctx, **kwargs)
                    logger.info("Final fallback successful")
                    return fallback_response
                except Exception as fallback_error:
                    logger.error(f"Final fallback also failed: {fallback_error}")
                    raise last_exception
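

# Minimal smoke-test sketch. Assumptions: LLMClient normally receives the
# project's settings object; the SimpleNamespace below is a hypothetical
# stand-in carrying the same attribute names read in __init__, populated from
# environment variables. The API call only runs when credentials are present.
if __name__ == "__main__":
    import os
    from types import SimpleNamespace

    demo_settings = SimpleNamespace(
        OPENAI_API_KEY=None,
        AZURE_OPENAI_API_KEY=os.environ.get("AZURE_OPENAI_API_KEY"),
        AZURE_OPENAI_ENDPOINT=os.environ.get("AZURE_OPENAI_ENDPOINT", ""),
        AZURE_OPENAI_API_VERSION=os.environ.get("AZURE_OPENAI_API_VERSION", "2024-10-21"),
        AZURE_OPENAI_DEPLOYMENT=os.environ.get("AZURE_OPENAI_DEPLOYMENT", ""),
        LLM_MAX_RETRIES=3,
        LLM_BASE_DELAY=1.0,
        LLM_MAX_DELAY=60.0,
    )
    logging.basicConfig(level=logging.INFO)
    if demo_settings.AZURE_OPENAI_API_KEY and demo_settings.AZURE_OPENAI_ENDPOINT:
        client = LLMClient(demo_settings)
        print(client.responses("Reply with the single word: pong", description="Smoke test"))
    else:
        print("Set the AZURE_OPENAI_* environment variables to run this smoke test.")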