from fastapi import FastAPI, UploadFile, File, Form, HTTPException, Request
from fastapi.staticfiles import StaticFiles
from fastapi.responses import RedirectResponse
from transformers import pipeline, ViltProcessor, ViltForQuestionAnswering, M2M100ForConditionalGeneration, M2M100Tokenizer
from typing import Optional, Dict, Any
import logging
import time
import os
import io
import re
from PIL import Image
from docx import Document
import fitz  # PyMuPDF
import pandas as pd
from functools import lru_cache
import torch
import numpy as np
from pydantic import BaseModel
import asyncio
import google.generativeai as genai
from spellchecker import SpellChecker
import nltk
from nltk.tokenize import sent_tokenize

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger("cosmic_ai")

# Make sure NLTK data lives in a writable location (e.g. inside a container).
nltk_data_dir = os.getenv('NLTK_DATA', '/tmp/nltk_data')
os.makedirs(nltk_data_dir, exist_ok=True)
nltk.data.path.append(nltk_data_dir)

try:
    nltk.download('punkt_tab', download_dir=nltk_data_dir, quiet=True, raise_on_error=True)
    logger.info(f"NLTK punkt_tab verified in {nltk_data_dir}")
except Exception as e:
    logger.error(f"Error verifying NLTK punkt_tab: {str(e)}")
    raise RuntimeError(f"Failed to verify NLTK punkt_tab: {str(e)}") from e

# Created for uploaded-file persistence; nothing in this module writes here yet.
upload_dir = os.getenv('UPLOAD_DIR', '/tmp/uploads')
os.makedirs(upload_dir, exist_ok=True)

app = FastAPI(
    title="Cosmic AI Assistant",
    description="An advanced AI assistant with a space-themed interface, translation, and file question-answering features",
    version="2.0.0"
)

app.mount("/static", StaticFiles(directory="static"), name="static")
app.mount("/images", StaticFiles(directory="images"), name="images")

# Read the Gemini API key from the environment instead of hardcoding a secret.
API_KEY = os.getenv("GEMINI_API_KEY")
if not API_KEY:
    raise RuntimeError("GEMINI_API_KEY environment variable is not set")
genai.configure(api_key=API_KEY)

MODELS = {
    "summarization": "sshleifer/distilbart-cnn-12-6",
    "image-to-text": "Salesforce/blip-image-captioning-large",
    "visual-qa": "dandelin/vilt-b32-finetuned-vqa",
    "chatbot": "gemini-1.5-pro",
    "translation": "facebook/m2m100_418M",
    "file-qa": "distilbert-base-cased-distilled-squad"
}

SUPPORTED_LANGUAGES = {
    "english": "en",
    "french": "fr",
    "german": "de",
    "spanish": "es",
    "italian": "it",
    "russian": "ru",
    "chinese": "zh",
    "japanese": "ja",
    "arabic": "ar",
    "hindi": "hi",
    "portuguese": "pt",
    "korean": "ko"
}

# Populated at startup (see startup_event) or lazily on first use.
translation_model = None
translation_tokenizer = None

spell = SpellChecker()


@lru_cache(maxsize=8)
def load_model(task: str, model_name: str = None):
    """Cached model loader with proper task names and error handling."""
    try:
        logger.info(f"Loading model for task: {task}, model: {model_name or MODELS.get(task)}")

        model_to_load = model_name or MODELS.get(task)

        if task == "chatbot":
            return genai.GenerativeModel(model_to_load)

        if task == "visual-qa":
            processor = ViltProcessor.from_pretrained(model_to_load)
            model = ViltForQuestionAnswering.from_pretrained(model_to_load)
            device = "cuda" if torch.cuda.is_available() else "cpu"
            model.to(device)

            def vqa_function(image, question, **generate_kwargs):
                # ViLT is a classifier over a fixed answer vocabulary: the
                # answer is the label with the highest logit.
                if image.mode != "RGB":
                    image = image.convert("RGB")
                inputs = processor(image, question, return_tensors="pt").to(device)
                logger.info(f"VQA inputs - question: {question}, image size: {image.size}")
                with torch.no_grad():
                    outputs = model(**inputs)
                logits = outputs.logits
                idx = logits.argmax(-1).item()
                answer = model.config.id2label[idx]
                logger.info(f"VQA raw output: {answer}")
                return answer

            return vqa_function

        return pipeline(
            task if task != "file-qa" else "question-answering",
            model=model_to_load,
            tokenizer_kwargs={"clean_up_tokenization_spaces": True}
        )

    except Exception as e:
        logger.error(f"Model load failed for {task}: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Model loading failed: {task} - {str(e)}")
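
# Example (hypothetical usage sketch; the lru_cache means repeated calls are cheap):
#   summarizer = load_model("summarization")
#   result = summarizer("A very long passage...", max_length=60, min_length=20)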


def get_gemini_response(user_input: str, is_generation: bool = False):
    """Generate a response with Gemini for both chat and text generation."""
    if not user_input:
        return "Please provide some input."
    try:
        chatbot = load_model("chatbot")
        if is_generation:
            prompt = f"Generate creative text based on this prompt: {user_input}"
        else:
            prompt = user_input
        response = chatbot.generate_content(prompt)
        return response.text.strip()
    except Exception as e:
        return f"Error: {str(e)}"


def translate_text(text: str, target_language: str):
    """Translate text to any supported target language using the M2M100 model."""
    if not text:
        return "Please provide text to translate."

    try:
        global translation_model, translation_tokenizer

        target_lang = target_language.lower()
        if target_lang not in SUPPORTED_LANGUAGES:
            similar = [lang for lang in SUPPORTED_LANGUAGES if target_lang in lang or lang in target_lang]
            if similar:
                target_lang = similar[0]
            else:
                return f"Language '{target_language}' not supported. Available languages: {', '.join(SUPPORTED_LANGUAGES.keys())}"

        lang_code = SUPPORTED_LANGUAGES[target_lang]

        if translation_model is None or translation_tokenizer is None:
            logger.info("Translation model not pre-loaded, loading on demand...")
            model_name = MODELS["translation"]
            translation_model = M2M100ForConditionalGeneration.from_pretrained(
                model_name,
                cache_dir=os.getenv("HF_HOME", "/app/cache")
            )
            translation_tokenizer = M2M100Tokenizer.from_pretrained(
                model_name,
                cache_dir=os.getenv("HF_HOME", "/app/cache")
            )
            device = "cuda" if torch.cuda.is_available() else "cpu"
            translation_model.to(device)
            logger.info("Translation model loaded on demand successfully")

        # Strip instruction phrasing ("translate ... to X:", "how to say ... in X")
        # so only the payload text is translated.
        match = re.search(r'how to say\s+(.+?)\s+in\s+(\w+)', text.lower())
        if match:
            text_to_translate = match.group(1)
        else:
            content_match = re.search(r'(?:translate|convert).*to\s+[a-zA-Z]+\s*[:\s]*(.+)', text, re.IGNORECASE)
            text_to_translate = content_match.group(1) if content_match else text

        translation_tokenizer.src_lang = "en"
        encoded = translation_tokenizer(text_to_translate, return_tensors="pt", padding=True, truncation=True).to(translation_model.device)

        start_time = time.time()
        generated_tokens = translation_model.generate(
            **encoded,
            forced_bos_token_id=translation_tokenizer.get_lang_id(lang_code),
            max_length=512,
            num_beams=1
        )
        translated_text = translation_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
        logger.info(f"Translation took {time.time() - start_time:.2f} seconds")

        return translated_text

    except Exception as e:
        logger.error(f"Translation error: {str(e)}", exc_info=True)
        return f"Translation error: {str(e)}"


def detect_intent(text: str = None, file: UploadFile = None, intent: str = None) -> tuple[str, str]:
    """Enhanced intent detection with explicit intent parameter support."""
    target_language = "English"
    valid_intents = [
        "chatbot", "translate", "file-translate", "summarize", "image-to-text",
        "visual-qa", "visualize", "text-generation", "file-qa"
    ]
    # Patterns that capture a target language from phrasing such as
    # "translate ... to French: ..." or "how to say ... in Spanish".
    translate_patterns = [
        r'translate.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
        r'convert.*to\s+\[?([a-zA-Z]+)\]?:?\s*(.*)',
        r'how to say.*in\s+\[?([a-zA-Z]+)\]?:?\s*(.*)'
    ]

    # 1. An explicit, valid intent always wins; for translation intents the
    #    target language still needs to be extracted from the text.
    if intent and intent in valid_intents:
        logger.info(f"Using explicit intent: {intent}")
        if intent in ["translate", "file-translate"] and text:
            for pattern in translate_patterns:
                translate_match = re.search(pattern, text.lower())
                if translate_match:
                    potential_lang = translate_match.group(1).lower()
                    if potential_lang in SUPPORTED_LANGUAGES:
                        target_language = potential_lang.capitalize()
                        break
        return intent, target_language

    # 2. File plus text: route by file type and question phrasing.
    if file and text:
        text_lower = text.lower()
        filename = file.filename.lower() if file.filename else ""

        for pattern in translate_patterns:
            translate_match = re.search(pattern, text_lower)
            if translate_match and filename.endswith(('.pdf', '.docx', '.txt', '.rtf')):
                potential_lang = translate_match.group(1).lower()
                if potential_lang in SUPPORTED_LANGUAGES:
                    target_language = potential_lang.capitalize()
                    return "file-translate", target_language

        content_type = file.content_type.lower() if file.content_type else ""
        if content_type.startswith('image/'):
            if "what's this" in text_lower or "does this fly" in text_lower or ("fly" in text_lower and any(q in text_lower for q in ['does', 'can', 'will'])):
                return "visual-qa", target_language
            if any(q in text_lower for q in ['what is', "what's", 'describe', 'tell me about', 'explain', 'how many', 'what color', 'is there', 'are they', 'does the']):
                return "visual-qa", target_language
            if "caption" in text_lower:
                return "image-to-text", target_language

        if filename.endswith(('.xlsx', '.xls', '.csv')):
            return "visualize", target_language
        elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
            if any(q in text_lower for q in ['what is', 'who is', 'where', 'when', 'why', 'how', 'what are', 'who are']):
                return "file-qa", target_language
            return "summarize", target_language

    # 3. File only: pick a sensible default per file type.
    if not text:
        if file:
            filename = file.filename.lower() if file.filename else ""
            content_type = file.content_type.lower() if file.content_type else ""
            if content_type.startswith('image/'):
                return "image-to-text", target_language
            elif filename.endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
                return "summarize", target_language
            elif filename.endswith(('.xlsx', '.xls', '.csv')):
                return "visualize", target_language
        return "chatbot", target_language

    # 4. Text only: keyword and pattern heuristics.
    text_lower = text.lower()

    if any(keyword in text_lower for keyword in ['chat', 'talk', 'converse', 'ask gemini']):
        return "chatbot", target_language

    for pattern in translate_patterns:
        translate_match = re.search(pattern, text_lower)
        if translate_match:
            potential_lang = translate_match.group(1).lower()
            if potential_lang in SUPPORTED_LANGUAGES:
                target_language = potential_lang.capitalize()
                return "translate", target_language
            else:
                logger.warning(f"Invalid language detected: {potential_lang}")
                return "chatbot", target_language

    vqa_patterns = [
        r'how (many|much)',
        r'what (color|size|position|shape)',
        r'is (there|that|this) (a|an)',
        r'are (they|there) (any|some)',
        r'does (the|this) (image|picture) (show|contain)'
    ]
    if any(re.search(pattern, text_lower) for pattern in vqa_patterns):
        return "visual-qa", target_language

    summarization_patterns = [
        r'\b(summar(y|ize|ise)|brief( overview)?)\b',
        r'\b(long article|text|document)\b',
        r'\bcan you (summar|brief|condense)\b',
        r'\b(short summary|brief explanation)\b',
        r'\b(overview|main points|key ideas)\b',
        r'\b(tl;?dr|too long didn\'?t read)\b'
    ]
    if any(re.search(pattern, text_lower) for pattern in summarization_patterns):
        return "summarize", target_language

    generation_patterns = [
        r'\b(write|generate|create|compose)\b',
        r'\b(story|poem|essay|text|content)\b'
    ]
    if any(re.search(pattern, text_lower) for pattern in generation_patterns):
        return "text-generation", target_language

    # Long free-form text with no other signal defaults to summarization.
    if len(text) > 100:
        return "summarize", target_language

    return "chatbot", target_language


def preprocess_text(text: str) -> str:
    """Correct spelling errors and improve text readability."""
    # spell.correction() can return None for unknown tokens; keep the original word then.
    corrected_words = []
    for word in text.split():
        correction = spell.correction(word)
        corrected_words.append(correction or word)
    corrected_text = " ".join(corrected_words)
    # sent_tokenize keeps terminal punctuation, so join with spaces and only
    # uppercase the first character of each sentence (str.capitalize() would
    # lowercase the rest).
    sentences = sent_tokenize(corrected_text)
    return " ".join(s[0].upper() + s[1:] if s else s for s in sentences)


class ProcessResponse(BaseModel):
    response: str
    type: str
    additional_data: Optional[Dict[str, Any]] = None


@app.get("/chatbot")
async def chatbot_interface():
    """Redirect to the static index.html file for the chatbot interface."""
    return RedirectResponse(url="/static/index.html")


@app.post("/chat")
async def chat_endpoint(data: dict):
    """Endpoint for chatbot interactions."""
    message = data.get("message", "")
    if not message:
        raise HTTPException(status_code=400, detail="No message provided")
    try:
        response = get_gemini_response(message)
        return {"response": response}
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Chat error: {str(e)}")
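
# Example request (hypothetical host/port):
#   curl -X POST http://localhost:7860/chat \
#        -H 'Content-Type: application/json' \
#        -d '{"message": "Hello there"}'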


@app.post("/process", response_model=ProcessResponse)
async def process_input(
    request: Request,
    text: str = Form(None),
    file: UploadFile = File(None),
    intent: str = Form(None)
):
    """Enhanced unified endpoint with dynamic translation and file translation."""
    start_time = time.time()
    client_ip = request.client.host
    logger.info(f"Request from {client_ip}: text={text[:50] + '...' if text and len(text) > 50 else text}, file={file.filename if file else None}, intent={intent}")

    detected_intent, target_language = detect_intent(text, file, intent)
    logger.info(f"Detected intent: {detected_intent}, target_language: {target_language}")

    try:
        if detected_intent == "chatbot":
            response = get_gemini_response(text)
            return {"response": response, "type": "chat"}

        elif detected_intent == "translate":
            content = await extract_text_from_file(file) if file else text
            if text and "all languages" in text.lower():
                translations = {}
                phrase_to_translate = "I want to explore the stars" if "I want to explore the stars" in text else content
                # Reuse translate_text per language so the model is lazily
                # loaded even if startup pre-loading failed.
                for lang in SUPPORTED_LANGUAGES:
                    translations[lang] = translate_text(phrase_to_translate, lang)
                response = "\n".join(f"{lang.capitalize()}: {translations[lang]}" for lang in translations)
                logger.info(f"Translated to all supported languages: {', '.join(translations.keys())}")
                return {"response": response, "type": "translation"}
            else:
                translated_text = translate_text(content, target_language)
                return {"response": translated_text, "type": "translation"}

        elif detected_intent == "file-translate":
            if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.txt', '.rtf')):
                raise HTTPException(status_code=400, detail="A text-based file (PDF, DOCX, TXT, RTF) is required")
            if not text:
                raise HTTPException(status_code=400, detail="Please specify a target language for translation")

            content = await extract_text_from_file(file)
            if not content.strip():
                raise HTTPException(status_code=400, detail="No text could be extracted from the file")

            # Translate in chunks to stay within the model's sequence limit.
            max_chunk_size = 512
            chunks = [content[i:i + max_chunk_size] for i in range(0, len(content), max_chunk_size)]
            translated_chunks = [translate_text(chunk, target_language) for chunk in chunks]

            translated_text = " ".join(translated_chunks).strip()
            # Uppercase only the first character; str.capitalize() would
            # lowercase the rest of the document.
            if translated_text:
                translated_text = translated_text[0].upper() + translated_text[1:]
            if not translated_text.endswith(('.', '!', '?')):
                translated_text += '.'

            logger.info(f"File translated to {target_language}: {translated_text[:100]}...")

            return {
                "response": translated_text,
                "type": "file_translation",
                "additional_data": {
                    "file_name": file.filename,
                    "target_language": target_language
                }
            }

        elif detected_intent == "summarize":
            content = await extract_text_from_file(file) if file else text
            if not content.strip():
                raise HTTPException(status_code=400, detail="No content to summarize")

            content = preprocess_text(content)
            logger.info(f"Preprocessed content: {content[:100]}...")

            summarizer = load_model("summarization")

            # Scale summary length to the input so short inputs do not get
            # padded summaries.
            content_length = len(content.split())
            max_len = max(50, min(200, content_length))
            min_len = max(20, min(50, content_length // 3))

            try:
                if len(content) > 1024:
                    chunks = [content[i:i + 1024] for i in range(0, len(content), 1024)]
                    summaries = []
                    for chunk in chunks[:3]:
                        summary = summarizer(
                            chunk,
                            max_length=max_len,
                            min_length=min_len,
                            do_sample=False,
                            truncation=True
                        )
                        summaries.append(summary[0]['summary_text'])
                    final_summary = " ".join(summaries)
                else:
                    summary = summarizer(
                        content,
                        max_length=max_len,
                        min_length=min_len,
                        do_sample=False,
                        truncation=True
                    )
                    final_summary = summary[0]['summary_text']

                final_summary = re.sub(r'\s+', ' ', final_summary).strip()
                if not final_summary or final_summary.lower().startswith(content.lower()[:30]):
                    logger.warning("Summarizer produced inadequate output, falling back to Gemini")
                    final_summary = get_gemini_response(
                        f"Summarize this text in a concise and meaningful way: {content}"
                    )

                if not final_summary.endswith(('.', '!', '?')):
                    final_summary += '.'

                logger.info(f"Generated summary: {final_summary}")
                # response_model drops unknown top-level keys, so the
                # preprocessing note rides in additional_data.
                return {"response": final_summary, "type": "summary",
                        "additional_data": {"message": "Text was preprocessed to correct spelling errors"}}

            except Exception as e:
                logger.error(f"Summarization error: {str(e)}")
                final_summary = get_gemini_response(
                    f"Summarize this text in a concise and meaningful way: {content}"
                )
                return {"response": final_summary, "type": "summary",
                        "additional_data": {"message": "Text was preprocessed to correct spelling errors"}}

        elif detected_intent == "image-to-text":
            if not file or not file.content_type or not file.content_type.startswith('image/'):
                raise HTTPException(status_code=400, detail="An image file is required")

            image = Image.open(io.BytesIO(await file.read()))
            captioner = load_model("image-to-text")

            caption = captioner(image, max_new_tokens=50)

            return {
                "response": caption[0]['generated_text'],
                "type": "caption",
                "additional_data": {
                    "image_size": f"{image.width}x{image.height}"
                }
            }

        elif detected_intent == "visual-qa":
            if not file or not file.content_type or not file.content_type.startswith('image/'):
                raise HTTPException(status_code=400, detail="An image file is required")
            if not text:
                raise HTTPException(status_code=400, detail="A question is required for VQA")

            image = Image.open(io.BytesIO(await file.read())).convert("RGB")
            vqa_pipeline = load_model("visual-qa")

            question = text.strip()
            if not question.endswith('?'):
                question += '?'

            answer = vqa_pipeline(image=image, question=question)

            answer = answer.strip()
            if not answer or answer.lower() == question.lower():
                logger.warning(f"VQA failed to generate a meaningful answer: {answer}")
                answer = "I couldn't determine the answer from the image."
            else:
                answer = answer.capitalize()
                if not answer.endswith(('.', '!', '?')):
                    answer += '.'

            # Keep factual answers verbatim; embellish only open-ended ones.
            factual_questions = ['color', 'size', 'number', 'how many', 'what is the']
            is_factual = any(keyword in question.lower() for keyword in factual_questions)

            if is_factual:
                final_answer = answer
            else:
                chatbot = load_model("chatbot")
                if "fly" in question.lower():
                    final_answer = chatbot.generate_content(f"Make this fun and spacey: {answer}").text.strip()
                else:
                    final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {answer}").text.strip()

            logger.info(f"Final VQA answer: {final_answer}")

            return {
                "response": final_answer,
                "type": "visual_qa",
                "additional_data": {
                    "question": text,
                    "image_size": f"{image.width}x{image.height}"
                }
            }

        elif detected_intent == "visualize":
            if not file:
                raise HTTPException(status_code=400, detail="A spreadsheet file (XLSX, XLS, CSV) is required")

            file_content = await file.read()

            if file.filename.endswith('.csv'):
                df = pd.read_csv(io.BytesIO(file_content))
            else:
                df = pd.read_excel(io.BytesIO(file_content))

            code = generate_visualization_code(df, text)
            stats = df.describe().to_string()
            response = f"Stats:\n{stats}\n\nChart Code:\n{code}"

            return {"response": response, "type": "visualization_code"}

        elif detected_intent == "text-generation":
            response = get_gemini_response(text, is_generation=True)
            # Split on sentence boundaries so generated verse reads line by line.
            lines = response.split(". ")
            formatted_poem = "\n".join(line.strip() + ("." if not line.endswith(".") else "") for line in lines if line)
            return {"response": formatted_poem, "type": "generated_text"}

        elif detected_intent == "file-qa":
            if not file or not file.filename.lower().endswith(('.pdf', '.docx', '.doc', '.txt', '.rtf')):
                raise HTTPException(status_code=400, detail="A text-based file (PDF, DOCX, TXT, RTF) is required")
            if not text:
                raise HTTPException(status_code=400, detail="A question about the file is required")

            content = await extract_text_from_file(file)
            if not content.strip():
                raise HTTPException(status_code=400, detail="No text could be extracted from the file")

            qa_pipeline = load_model("file-qa")

            question = text.strip()
            if not question.endswith('?'):
                question += '?'

            # Extractive QA works on short contexts, so scan the first few
            # 512-character chunks and keep the highest-scoring answer.
            if len(content) > 512:
                chunks = [content[i:i + 512] for i in range(0, len(content), 512)]
                answers = []
                for chunk in chunks[:3]:
                    result = qa_pipeline(question=question, context=chunk)
                    if result['score'] > 0.1:
                        answers.append((result['answer'], result['score']))
                if answers:
                    best_answer = max(answers, key=lambda x: x[1])[0]
                else:
                    best_answer = "I couldn't find a clear answer in the document."
            else:
                result = qa_pipeline(question=question, context=content)
                best_answer = result['answer'] if result['score'] > 0.1 else "I couldn't find a clear answer in the document."

            best_answer = best_answer.strip().capitalize()
            if not best_answer.endswith(('.', '!', '?')):
                best_answer += '.'

            try:
                chatbot = load_model("chatbot")
                final_answer = chatbot.generate_content(f"Make this cosmic and poetic: {best_answer}").text.strip()
            except Exception as e:
                logger.warning(f"Failed to add cosmic tone: {str(e)}. Using raw answer.")
                final_answer = best_answer

            logger.info(f"File QA answer: {final_answer}")

            return {
                "response": final_answer,
                "type": "file_qa",
                "additional_data": {
                    "question": text,
                    "file_name": file.filename
                }
            }

        else:
            response = get_gemini_response(text or "Hello! How can I assist you?")
            return {"response": response, "type": "chat"}

    except HTTPException:
        # Re-raise as-is so 400-level validation errors are not converted to 500s.
        raise
    except Exception as e:
        logger.error(f"Processing error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=str(e))
    finally:
        process_time = time.time() - start_time
        logger.info(f"Request processed in {process_time:.2f} seconds")
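
# Example request (hypothetical host/port and file name):
#   curl -X POST http://localhost:7860/process \
#        -F 'text=What is this document about?' \
#        -F 'file=@report.pdf'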


async def extract_text_from_file(file: UploadFile) -> str:
    """Enhanced text extraction with multiple fallbacks."""
    if not file:
        return ""

    content = await file.read()
    filename = file.filename.lower()

    try:
        if filename.endswith('.pdf'):
            try:
                doc = fitz.open(stream=content, filetype="pdf")
                if doc.is_encrypted:
                    return "PDF is encrypted and cannot be read"
                text = ""
                for page in doc:
                    text += page.get_text()
                return text
            except Exception as pdf_error:
                logger.warning(f"PyMuPDF failed: {str(pdf_error)}. Trying pdfminer.six...")
                from pdfminer.high_level import extract_text
                from io import BytesIO
                return extract_text(BytesIO(content))

        elif filename.endswith(('.docx', '.doc')):
            # Note: python-docx only reads .docx; legacy .doc files may fail here.
            doc = Document(io.BytesIO(content))
            return "\n".join(para.text for para in doc.paragraphs)

        elif filename.endswith('.txt'):
            return content.decode('utf-8', errors='replace')

        elif filename.endswith('.rtf'):
            # Crude RTF handling: strip control words and braces rather than
            # pulling in a full RTF parser.
            text = content.decode('utf-8', errors='replace')
            text = re.sub(r'\\[a-z]+\d*', ' ', text)
            text = re.sub(r'[{}\\]', '', text)
            return text

        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file format: {filename}")

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"File extraction error: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=f"Error extracting text: {str(e)}. Supported formats: PDF, DOCX, TXT, RTF"
        )


def generate_visualization_code(df: pd.DataFrame, request: str = None) -> str:
    """Generate visualization code based on data analysis.

    The returned string is a standalone script; 'data.xlsx' inside it is a
    placeholder path for the user's own file.
    """
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    categorical_cols = df.select_dtypes(include=['object']).columns.tolist()

    request_lower = request.lower() if request else ""

    if len(numeric_cols) >= 2 and ("scatter" in request_lower or "correlation" in request_lower):
        x_col = numeric_cols[0]
        y_col = numeric_cols[1]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(10, 6))
sns.regplot(x='{x_col}', y='{y_col}', data=df, scatter_kws={{'alpha': 0.6}})
plt.title('Correlation between {x_col} and {y_col}')
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('correlation_plot.png')
plt.show()

correlation = df['{x_col}'].corr(df['{y_col}'])
print(f"Correlation coefficient: {{correlation:.4f}}")"""

    elif len(numeric_cols) >= 1 and len(categorical_cols) >= 1 and ("bar" in request_lower or "comparison" in request_lower):
        cat_col = categorical_cols[0]
        num_col = numeric_cols[0]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(12, 7))
ax = sns.barplot(x='{cat_col}', y='{num_col}', data=df, palette='viridis')
for p in ax.patches:
    ax.annotate(f'{{p.get_height():.1f}}',
                (p.get_x() + p.get_width() / 2., p.get_height()),
                ha='center', va='bottom', fontsize=10, color='black', xytext=(0, 5),
                textcoords='offset points')
plt.title('Comparison of {num_col} by {cat_col}', fontsize=15)
plt.xlabel('{cat_col}', fontsize=12)
plt.ylabel('{num_col}', fontsize=12)
plt.xticks(rotation=45, ha='right')
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig('comparison_chart.png')
plt.show()"""

    elif len(numeric_cols) >= 1 and ("distribution" in request_lower or "histogram" in request_lower):
        num_col = numeric_cols[0]
        return f"""import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

df = pd.read_excel('data.xlsx')
plt.figure(figsize=(10, 6))
sns.histplot(df['{num_col}'], kde=True, bins=20, color='purple')
plt.title('Distribution of {num_col}', fontsize=15)
plt.xlabel('{num_col}', fontsize=12)
plt.ylabel('Frequency', fontsize=12)
plt.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('distribution_plot.png')
plt.show()

print(df['{num_col}'].describe())"""

    else:
        # No values are interpolated here, so this template is a plain string
        # (not an f-string); the braces belong to the generated script itself.
        # As an f-string, {col}/{num_col}/{cat_col} would raise NameError.
        return """import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

df = pd.read_excel('data.xlsx')
print("Descriptive statistics:")
print(df.describe())

fig, axes = plt.subplots(2, 2, figsize=(15, 12))
numeric_df = df.select_dtypes(include=[np.number])
if not numeric_df.empty and numeric_df.shape[1] > 1:
    sns.heatmap(numeric_df.corr(), annot=True, cmap='coolwarm', fmt='.2f', ax=axes[0, 0])
    axes[0, 0].set_title('Correlation Matrix')
if not numeric_df.empty:
    for col in numeric_df.columns[:1]:
        sns.histplot(df[col], kde=True, ax=axes[0, 1], color='purple')
        axes[0, 1].set_title(f'Distribution of {col}')
        axes[0, 1].set_xlabel(col)
        axes[0, 1].set_ylabel('Frequency')
categorical_cols = df.select_dtypes(include=['object']).columns
if len(categorical_cols) > 0 and not numeric_df.empty:
    cat_col = categorical_cols[0]
    num_col = numeric_df.columns[0]
    sns.barplot(x=cat_col, y=num_col, data=df, ax=axes[1, 0], palette='viridis')
    axes[1, 0].set_title(f'{num_col} by {cat_col}')
    axes[1, 0].set_xticklabels(axes[1, 0].get_xticklabels(), rotation=45, ha='right')
if not numeric_df.empty and len(categorical_cols) > 0:
    cat_col = categorical_cols[0]
    num_col = numeric_df.columns[0]
    sns.boxplot(x=cat_col, y=num_col, data=df, ax=axes[1, 1], palette='Set3')
    axes[1, 1].set_title(f'Distribution of {num_col} by {cat_col}')
    axes[1, 1].set_xticklabels(axes[1, 1].get_xticklabels(), rotation=45, ha='right')
plt.tight_layout()
plt.savefig('dashboard.png')
plt.show()"""
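
# Example (hypothetical): for a DataFrame with columns ['region', 'sales'],
# generate_visualization_code(df, "bar comparison") returns a seaborn barplot
# script comparing sales by region.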


@app.get("/", include_in_schema=False)
async def home():
    """Redirect to the static index.html file."""
    return RedirectResponse(url="/static/index.html")


@app.get("/health", include_in_schema=True)
async def health_check():
    """Health check endpoint."""
    return {"status": "healthy", "version": "2.0.0"}


@app.get("/models", include_in_schema=True)
async def list_models():
    """List available models."""
    return {"models": MODELS}


@app.on_event("startup")
async def startup_event():
    """Pre-load models at startup with timeout and fallback."""
    global translation_model, translation_tokenizer
    logger.info("Starting model pre-loading...")

    async def load_model_with_timeout(task, model_name=None):
        try:
            await asyncio.wait_for(
                asyncio.to_thread(load_model, task, model_name),
                timeout=60.0
            )
            logger.info(f"Successfully pre-loaded {task} model")
        except asyncio.TimeoutError:
            logger.warning(f"Timeout loading {task} model - will load on demand")
        except Exception as e:
            logger.error(f"Error pre-loading {task}: {str(e)}")

    # The translation model is loaded eagerly (not via load_model) because
    # translate_text uses the module-level model and tokenizer directly.
    try:
        model_name = MODELS["translation"]
        logger.info(f"Attempting to load translation model: {model_name}")
        translation_model = M2M100ForConditionalGeneration.from_pretrained(
            model_name,
            cache_dir=os.getenv("HF_HOME", "/app/cache")
        )
        translation_tokenizer = M2M100Tokenizer.from_pretrained(
            model_name,
            cache_dir=os.getenv("HF_HOME", "/app/cache")
        )
        device = "cuda" if torch.cuda.is_available() else "cpu"
        translation_model.to(device)
        logger.info("Translation model pre-loaded successfully")
    except Exception as e:
        logger.error(f"Error pre-loading translation model: {str(e)}")
        # translate_text will retry loading on demand.
        translation_model = None
        translation_tokenizer = None

    await asyncio.gather(
        load_model_with_timeout("summarization"),
        load_model_with_timeout("image-to-text"),
        load_model_with_timeout("visual-qa"),
        load_model_with_timeout("chatbot"),
        load_model_with_timeout("file-qa")
    )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run("app:app", host="0.0.0.0", port=7860, reload=True)