#!/usr/bin/env python3
"""
HuggingFace Inference Providers API-based answer generation.

This module provides answer generation using HuggingFace's Inference Providers API,
which offers an OpenAI-compatible chat completion format for better reliability and
consistency.
"""
import os
import sys
import logging
import time
import re
from datetime import datetime
from typing import List, Dict, Any, Optional, Tuple
from pathlib import Path

# Import shared components
from .hf_answer_generator import Citation, GeneratedAnswer
from .prompt_templates import TechnicalPromptTemplates
# Check that huggingface_hub is new enough to provide InferenceClient chat completion
try:
    from huggingface_hub import InferenceClient
    from huggingface_hub import __version__ as hf_hub_version

    print(f"Using huggingface_hub version: {hf_hub_version}", file=sys.stderr, flush=True)
except ImportError:
    print(
        "huggingface_hub not found or outdated. Please install: pip install -U huggingface-hub",
        file=sys.stderr,
        flush=True,
    )
    raise

logger = logging.getLogger(__name__)

class InferenceProvidersGenerator:
    """
    Generates answers using the HuggingFace Inference Providers API.

    This uses the OpenAI-compatible chat completion format for better reliability
    compared to the classic Inference API. It provides:
    - A consistent response format across models
    - Better error handling, with fallback across models
    - Support for streaming responses
    - Automatic provider selection and failover
    """
    # Models that work well with the chat completion format
    CHAT_MODELS = [
        "microsoft/DialoGPT-medium",         # Proven conversational model
        "google/gemma-2-2b-it",              # Instruction-tuned, good for Q&A
        "meta-llama/Llama-3.2-3B-Instruct",  # If available with the token
        "Qwen/Qwen2.5-1.5B-Instruct",        # Small, fast, good quality
    ]
    # Fall back to classic API models if chat completion fails
    CLASSIC_FALLBACK_MODELS = [
        "google/flan-t5-small",         # Good for instructions
        "deepset/roberta-base-squad2",  # Q&A specific
        "facebook/bart-base",           # Summarization
    ]
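    # Selection strategy (as implemented in _test_connection below): the generator
    # first probes the configured chat model, then the remaining CHAT_MODELS via
    # chat_completion, and only if none respond does it drop down to the classic
    # text_generation API with one of the CLASSIC_FALLBACK_MODELS above.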
    def __init__(
        self,
        model_name: Optional[str] = None,
        api_token: Optional[str] = None,
        temperature: float = 0.3,
        max_tokens: int = 512,
        timeout: int = 30,
    ):
        """
        Initialize the Inference Providers answer generator.

        Args:
            model_name: Model to use (defaults to the first available chat model)
            api_token: HF API token (uses env vars if not provided)
            temperature: Generation temperature (0.0-1.0)
            max_tokens: Maximum tokens to generate
            timeout: Request timeout in seconds
        """
        # Get the API token from the various supported sources
        self.api_token = (
            api_token
            or os.getenv("HUGGINGFACE_API_TOKEN")
            or os.getenv("HF_TOKEN")
            or os.getenv("HF_API_TOKEN")
        )
        if not self.api_token:
            print("No HF API token found. Inference Providers requires authentication.", file=sys.stderr, flush=True)
            print("Set HF_TOKEN, HUGGINGFACE_API_TOKEN, or HF_API_TOKEN environment variable.", file=sys.stderr, flush=True)
            raise ValueError("HuggingFace API token required for Inference Providers")

        print(f"Found HF token (starts with: {self.api_token[:8]}...)", file=sys.stderr, flush=True)

        # Initialize the client with the token
        self.client = InferenceClient(token=self.api_token)
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.timeout = timeout

        # Select the model
        self.model_name = model_name or self.CHAT_MODELS[0]
        self.using_chat_completion = True
        print(f"Initialized Inference Providers with model: {self.model_name}", file=sys.stderr, flush=True)

        # Test the connection
        self._test_connection()
    def _test_connection(self):
        """Test whether the API is accessible and a model is available."""
        print("Testing Inference Providers API connection...", file=sys.stderr, flush=True)
        try:
            # A minimal test query
            test_messages = [
                {"role": "user", "content": "Hello"}
            ]

            # First try chat completion (preferred)
            try:
                response = self.client.chat_completion(
                    messages=test_messages,
                    model=self.model_name,
                    max_tokens=10,
                    temperature=0.1,
                )
                print(f"Chat completion API working with {self.model_name}", file=sys.stderr, flush=True)
                self.using_chat_completion = True
                return
            except Exception as e:
                print(f"Chat completion failed for {self.model_name}: {e}", file=sys.stderr, flush=True)

            # Try the other chat models
            for model in self.CHAT_MODELS:
                if model != self.model_name:
                    try:
                        print(f"Trying {model}...", file=sys.stderr, flush=True)
                        response = self.client.chat_completion(
                            messages=test_messages,
                            model=model,
                            max_tokens=10,
                        )
                        print(f"Found working model: {model}", file=sys.stderr, flush=True)
                        self.model_name = model
                        self.using_chat_completion = True
                        return
                    except Exception:
                        continue

            # If chat completion fails everywhere, test classic text generation
            print("Falling back to classic text generation API...", file=sys.stderr, flush=True)
            for model in self.CLASSIC_FALLBACK_MODELS:
                try:
                    response = self.client.text_generation(
                        model=model,
                        prompt="Hello",
                        max_new_tokens=10,
                    )
                    print(f"Classic API working with fallback model: {model}", file=sys.stderr, flush=True)
                    self.model_name = model
                    self.using_chat_completion = False
                    return
                except Exception:
                    continue

            raise RuntimeError("No working models found in the Inference Providers API")
        except Exception as e:
            print(f"Inference Providers API test failed: {e}", file=sys.stderr, flush=True)
            raise
    def _format_context(self, chunks: List[Dict[str, Any]]) -> str:
        """Format retrieved chunks into a context string."""
        context_parts = []
        for i, chunk in enumerate(chunks):
            chunk_text = chunk.get('content', chunk.get('text', ''))
            page_num = chunk.get('metadata', {}).get('page_number', 'unknown')
            source = chunk.get('metadata', {}).get('source', 'unknown')
            context_parts.append(
                f"[chunk_{i+1}] (Page {page_num} from {source}):\n{chunk_text}\n"
            )
        return "\n---\n".join(context_parts)
    def _create_messages(self, query: str, context: str) -> List[Dict[str, str]]:
        """Create chat messages using TechnicalPromptTemplates."""
        # Get the appropriate template based on the query type
        prompt_data = TechnicalPromptTemplates.format_prompt_with_template(
            query=query,
            context=context,
        )

        # Create messages for chat completion
        messages = [
            {
                "role": "system",
                "content": prompt_data['system'] + "\n\nMANDATORY: Use [chunk_X] citations for all facts."
            },
            {
                "role": "user",
                "content": prompt_data['user']
            },
        ]
        return messages
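    # The returned list follows the OpenAI-style chat schema that chat_completion
    # expects, e.g. (contents abbreviated):
    #
    #   [
    #       {"role": "system", "content": "<template system prompt> ... MANDATORY: Use [chunk_X] citations ..."},
    #       {"role": "user", "content": "<template user prompt built from query and context>"},
    #   ]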
    def _call_chat_completion(self, messages: List[Dict[str, str]]) -> str:
        """Call the chat completion API."""
        try:
            print(f"Calling Inference Providers chat completion with {self.model_name}...", file=sys.stderr, flush=True)

            # Use chat completion with proper error handling
            response = self.client.chat_completion(
                messages=messages,
                model=self.model_name,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                stream=False,
            )

            # Extract the content from the response
            if hasattr(response, 'choices') and response.choices:
                content = response.choices[0].message.content
                print(f"Got response: {len(content)} characters", file=sys.stderr, flush=True)
                return content
            else:
                print(f"Unexpected response format: {response}", file=sys.stderr, flush=True)
                return str(response)
        except Exception as e:
            print(f"Chat completion error: {e}", file=sys.stderr, flush=True)

            # Try a fallback model
            if self.model_name != "microsoft/DialoGPT-medium":
                print("Trying fallback model: microsoft/DialoGPT-medium", file=sys.stderr, flush=True)
                try:
                    response = self.client.chat_completion(
                        messages=messages,
                        model="microsoft/DialoGPT-medium",
                        temperature=self.temperature,
                        max_tokens=self.max_tokens,
                    )
                    if hasattr(response, 'choices') and response.choices:
                        return response.choices[0].message.content
                except Exception:
                    pass
            raise RuntimeError(f"Chat completion failed: {e}") from e
    def _call_classic_api(self, query: str, context: str) -> str:
        """Fall back to the classic text generation API."""
        print(f"Using classic text generation with {self.model_name}...", file=sys.stderr, flush=True)

        # Format the prompt for the classic API
        if "squad" in self.model_name.lower():
            # Q&A format for SQuAD-tuned models
            prompt = f"Context: {context}\n\nQuestion: {query}\n\nAnswer:"
        elif "flan" in self.model_name.lower():
            # Instruction format for Flan models
            prompt = f"Answer the question based on the context.\n\nContext: {context}\n\nQuestion: {query}\n\nAnswer:"
        else:
            # Generic format
            prompt = f"Based on the following context, answer the question.\n\nContext:\n{context}\n\nQuestion: {query}\n\nAnswer:"

        try:
            response = self.client.text_generation(
                model=self.model_name,
                prompt=prompt,
                max_new_tokens=self.max_tokens,
                temperature=self.temperature,
            )
            return response
        except Exception as e:
            print(f"Classic API error: {e}", file=sys.stderr, flush=True)
            return f"Error generating response: {str(e)}"
    def _extract_citations(self, answer: str, chunks: List[Dict[str, Any]]) -> Tuple[str, List[Citation]]:
        """Extract citations from the answer."""
        citations = []
        citation_pattern = r'\[chunk_(\d+)\]'
        cited_chunks = set()

        # Find explicit citations
        matches = re.finditer(citation_pattern, answer)
        for match in matches:
            chunk_idx = int(match.group(1)) - 1
            if 0 <= chunk_idx < len(chunks):
                cited_chunks.add(chunk_idx)

        # Fallback: create citations for the top chunks if none were found
        if not cited_chunks and chunks and len(answer.strip()) > 50:
            num_fallback = min(3, len(chunks))
            cited_chunks = set(range(num_fallback))
            print(f"Creating {num_fallback} fallback citations", file=sys.stderr, flush=True)

        # Create Citation objects
        chunk_to_source = {}
        for idx in cited_chunks:
            chunk = chunks[idx]
            citation = Citation(
                chunk_id=chunk.get('id', f'chunk_{idx}'),
                page_number=chunk.get('metadata', {}).get('page_number', 0),
                source_file=chunk.get('metadata', {}).get('source', 'unknown'),
                relevance_score=chunk.get('score', 0.0),
                text_snippet=chunk.get('content', chunk.get('text', ''))[:200] + '...',
            )
            citations.append(citation)

            # Map chunk references to natural-language source names
            source_name = chunk.get('metadata', {}).get('source', 'unknown')
            if source_name != 'unknown':
                natural_name = Path(source_name).stem.replace('-', ' ').replace('_', ' ')
                chunk_to_source[f'[chunk_{idx+1}]'] = f"the {natural_name} documentation"
            else:
                chunk_to_source[f'[chunk_{idx+1}]'] = "the documentation"

        # Replace citation markers with natural language
        natural_answer = answer
        for chunk_ref, natural_ref in chunk_to_source.items():
            natural_answer = natural_answer.replace(chunk_ref, natural_ref)

        # Clean up any remaining citation markers
        natural_answer = re.sub(r'\[chunk_\d+\]', 'the documentation', natural_answer)
        natural_answer = re.sub(r'\s+', ' ', natural_answer).strip()

        return natural_answer, citations
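    # For illustration: with chunk 1 sourced from "riscv-spec.pdf", the marker
    # "[chunk_1]" in the model's answer is rewritten to "the riscv spec documentation"
    # (Path("riscv-spec.pdf").stem with dashes replaced by spaces), and a matching
    # Citation object is emitted; any leftover [chunk_N] markers become "the documentation".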
    def _calculate_confidence(self, answer: str, citations: List[Citation], chunks: List[Dict[str, Any]]) -> float:
        """Calculate a confidence score for the answer."""
        if not answer or len(answer.strip()) < 10:
            return 0.1

        # Base confidence from chunk quality
        if len(chunks) >= 3:
            confidence = 0.8
        elif len(chunks) >= 2:
            confidence = 0.7
        else:
            confidence = 0.6

        # Citation bonus
        if citations and chunks:
            citation_ratio = len(citations) / min(len(chunks), 3)
            confidence += 0.15 * citation_ratio

        # Penalize answers that express uncertainty
        uncertainty_phrases = [
            "insufficient information",
            "cannot determine",
            "not available in the provided documents",
            "i don't know",
            "unclear",
        ]
        if any(phrase in answer.lower() for phrase in uncertainty_phrases):
            confidence *= 0.3

        return min(confidence, 0.95)
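    # Worked example of the arithmetic above: with 2 retrieved chunks and 2 citations,
    # the base confidence is 0.7 and the citation ratio is 2 / min(2, 3) = 1.0, giving
    # 0.7 + 0.15 * 1.0 = 0.85; an uncertainty phrase would scale that to 0.255, and the
    # final value is always capped at 0.95.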
    def generate(self, query: str, chunks: List[Dict[str, Any]]) -> GeneratedAnswer:
        """
        Generate an answer using the Inference Providers API.

        Args:
            query: User's question
            chunks: Retrieved document chunks

        Returns:
            GeneratedAnswer with answer, citations, and metadata
        """
        start_time = datetime.now()

        # Check for a no-context situation
        if not chunks or all(len(chunk.get('content', chunk.get('text', ''))) < 20 for chunk in chunks):
            return GeneratedAnswer(
                answer="This information isn't available in the provided documents.",
                citations=[],
                confidence_score=0.05,
                generation_time=0.1,
                model_used=self.model_name,
                context_used=chunks,
            )

        # Format the context
        context = self._format_context(chunks)

        # Generate the answer
        try:
            if self.using_chat_completion:
                # Create chat messages and call the chat completion API
                messages = self._create_messages(query, context)
                answer_text = self._call_chat_completion(messages)
            else:
                # Fall back to the classic API
                answer_text = self._call_classic_api(query, context)

            # Extract citations and clean up the answer
            natural_answer, citations = self._extract_citations(answer_text, chunks)

            # Calculate confidence
            confidence = self._calculate_confidence(natural_answer, citations, chunks)

            generation_time = (datetime.now() - start_time).total_seconds()
            return GeneratedAnswer(
                answer=natural_answer,
                citations=citations,
                confidence_score=confidence,
                generation_time=generation_time,
                model_used=self.model_name,
                context_used=chunks,
            )
        except Exception as e:
            logger.error(f"Error generating answer: {e}")
            print(f"Generation failed: {e}", file=sys.stderr, flush=True)

            # Return an error response
            return GeneratedAnswer(
                answer="I apologize, but I encountered an error while generating the answer. Please try again.",
                citations=[],
                confidence_score=0.0,
                generation_time=(datetime.now() - start_time).total_seconds(),
                model_used=self.model_name,
                context_used=chunks,
            )

# Example usage
if __name__ == "__main__":
    # Test the generator
    print("Testing Inference Providers Generator...")
    try:
        generator = InferenceProvidersGenerator()

        # Test chunks
        test_chunks = [
            {
                "content": "RISC-V is an open-source instruction set architecture (ISA) based on established reduced instruction set computer (RISC) principles.",
                "metadata": {"page_number": 1, "source": "riscv-spec.pdf"},
                "score": 0.95
            },
            {
                "content": "Unlike most other ISA designs, RISC-V is provided under open source licenses that do not require fees to use.",
                "metadata": {"page_number": 2, "source": "riscv-spec.pdf"},
                "score": 0.85
            },
        ]

        # Generate an answer
        result = generator.generate("What is RISC-V and why is it important?", test_chunks)

        print(f"\nAnswer: {result.answer}")
        print(f"Confidence: {result.confidence_score:.1%}")
        print(f"Generation time: {result.generation_time:.2f}s")
        print(f"Model: {result.model_used}")
        print(f"Citations: {len(result.citations)}")
    except Exception as e:
        print(f"Test failed: {e}")
        import traceback
        traceback.print_exc()
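# Note: running this demo requires a valid HuggingFace API token in one of the
# HF_TOKEN, HUGGINGFACE_API_TOKEN, or HF_API_TOKEN environment variables;
# without it the constructor raises ValueError before any request is made.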