import io
import logging
import os
import uuid
from datetime import datetime
from typing import List, Optional

from fastapi import APIRouter, Depends, File, HTTPException, Request, UploadFile, status
from fastapi.responses import StreamingResponse
from motor.motor_asyncio import AsyncIOMotorClient
from openai import OpenAI
from sqlalchemy import delete, select
from sqlalchemy.ext.asyncio import AsyncSession

from api.auth import current_active_user
from api.database import Conversation, Message, User, get_db
from api.models import ConversationCreate, ConversationOut, QueryRequest, UserUpdate
from utils.generation import check_model_availability, request_generation, select_model

router = APIRouter()
logger = logging.getLogger(__name__)
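
# Routes in this module cover chat, transcription, text-to-speech, code
# generation, and analysis, all backed by Hugging Face inference endpoints.
# Conversations are persisted through SQLAlchemy; anonymous-session quotas
# are tracked in MongoDB (see handle_session below).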

HF_TOKEN = os.getenv("HF_TOKEN")
if not HF_TOKEN:
    logger.error("HF_TOKEN is not set in environment variables.")
    raise ValueError("HF_TOKEN is required for the Inference API.")

BACKUP_HF_TOKEN = os.getenv("BACKUP_HF_TOKEN")
if not BACKUP_HF_TOKEN:
    logger.warning("BACKUP_HF_TOKEN is not set. Fallback to the secondary model will not work if the primary token fails.")

ROUTER_API_URL = os.getenv("ROUTER_API_URL", "https://router.huggingface.co")
API_ENDPOINT = os.getenv("API_ENDPOINT", "https://api-inference.huggingface.co/v1")
FALLBACK_API_ENDPOINT = os.getenv("FALLBACK_API_ENDPOINT", "https://api-inference.huggingface.co/v1")
MODEL_NAME = os.getenv("MODEL_NAME", "openai/gpt-oss-120b")
SECONDARY_MODEL_NAME = os.getenv("SECONDARY_MODEL_NAME", "mistralai/Mixtral-8x7B-Instruct-v0.1")
TERTIARY_MODEL_NAME = os.getenv("TERTIARY_MODEL_NAME", "meta-llama/Llama-3-8b-chat-hf")
CLIP_BASE_MODEL = os.getenv("CLIP_BASE_MODEL", "Salesforce/blip-image-captioning-large")
CLIP_LARGE_MODEL = os.getenv("CLIP_LARGE_MODEL", "openai/clip-vit-large-patch14")
ASR_MODEL = os.getenv("ASR_MODEL", "openai/whisper-large-v3")
TTS_MODEL = os.getenv("TTS_MODEL", "facebook/mms-tts-ara")

MODEL_ALIASES = {
    "advanced": MODEL_NAME,
    "standard": SECONDARY_MODEL_NAME,
    "light": TERTIARY_MODEL_NAME,
    "image_base": CLIP_BASE_MODEL,
    "image_advanced": CLIP_LARGE_MODEL,
    "audio": ASR_MODEL,
    "tts": TTS_MODEL,
}
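
# MongoDB is used here only for anonymous-session bookkeeping; conversation
# data itself lives in the SQL database accessed via get_db().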

MONGO_URI = os.getenv("MONGODB_URI")
mongo_client = AsyncIOMotorClient(MONGO_URI)
mongo_db = mongo_client["hager"]
session_message_counts = mongo_db["session_message_counts"]


async def handle_session(request: Request):
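    """Track an anonymous visitor's session and enforce the free-message quota.

    A UUID session id is stored in the cookie session and a per-session
    message counter in MongoDB; once the counter passes 4 messages, a 403 is
    raised asking the visitor to log in.
    """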
    if not hasattr(request, "session"):
        raise HTTPException(status_code=500, detail="Session middleware not configured")
    session_id = request.session.get("session_id")
    if not session_id:
        session_id = str(uuid.uuid4())
        request.session["session_id"] = session_id
        await session_message_counts.insert_one({"session_id": session_id, "message_count": 0})

    session_doc = await session_message_counts.find_one({"session_id": session_id})
    if not session_doc:
        session_doc = {"session_id": session_id, "message_count": 0}
        await session_message_counts.insert_one(session_doc)

    message_count = session_doc["message_count"] + 1
    await session_message_counts.update_one(
        {"session_id": session_id},
        {"$set": {"message_count": message_count}}
    )
    if message_count > 4:
        raise HTTPException(
            status_code=status.HTTP_403_FORBIDDEN,
            detail="Message limit reached. Please log in to continue."
        )
    return session_id


def enhance_system_prompt(system_prompt: str, message: str, user: Optional[User] = None) -> str:
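    """Augment the base system prompt with language and user-profile context.

    If the message contains Arabic characters (U+0600-U+06FF), the model is
    instructed to respond in Arabic; for logged-in users, profile details and
    the preferred conversation style are appended.
    """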
    enhanced_prompt = system_prompt
    if any(0x0600 <= ord(char) <= 0x06FF for char in message):
        enhanced_prompt += "\nRespond in Arabic with clear, concise, and accurate information tailored to the user's query."
    if user and user.additional_info:
        enhanced_prompt += f"\nUser Profile: {user.additional_info}\nConversation Style: {user.conversation_style or 'default'}"
    return enhanced_prompt


@router.get("/api/settings")
async def get_settings(user: User = Depends(current_active_user)):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")
    return {
        "available_models": [
            {"alias": "advanced", "description": "High-performance model for complex queries"},
            {"alias": "standard", "description": "Balanced model for general use"},
            {"alias": "light", "description": "Lightweight model for quick responses"}
        ],
        "conversation_styles": ["default", "concise", "analytical", "creative"],
        "user_settings": {
            "display_name": user.display_name,
            "preferred_model": user.preferred_model,
            "job_title": user.job_title,
            "education": user.education,
            "interests": user.interests,
            "additional_info": user.additional_info,
            "conversation_style": user.conversation_style
        }
    }


@router.get("/api/model-info")
async def model_info():
    return {
        "available_models": [
            {"alias": "advanced", "description": "High-performance model for complex queries"},
            {"alias": "standard", "description": "Balanced model for general use"},
            {"alias": "light", "description": "Lightweight model for quick responses"},
            {"alias": "image_base", "description": "Basic image analysis model"},
            {"alias": "image_advanced", "description": "Advanced image analysis model"},
            {"alias": "audio", "description": "Audio transcription model (default)"},
            {"alias": "tts", "description": "Text-to-speech model (default)"}
        ],
        "api_base": API_ENDPOINT,
        "fallback_api_base": FALLBACK_API_ENDPOINT,
        "status": "online"
    }


@router.get("/api/performance")
async def performance_stats():
    return {
        "queue_size": int(os.getenv("QUEUE_SIZE", 80)),
        "concurrency_limit": int(os.getenv("CONCURRENCY_LIMIT", 20)),
        "uptime": os.popen("uptime").read().strip()
    }


@router.post("/api/chat")
async def chat_endpoint(
    request: Request,
    req: QueryRequest,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
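    """Handle a chat message and return the model's reply.

    Anonymous users are rate-limited via handle_session(). For logged-in
    users the exchange is appended to their most recently updated
    conversation (one is created if none exists). If the primary model is
    unavailable or returns an empty reply, the secondary model is tried once
    before failing.
    """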
    logger.info(f"Received chat request: {req}")

    if not user:
        await handle_session(request)

    conversation = None
    if user:
        title = req.title or (req.message[:50] + "..." if len(req.message) > 50 else req.message or "Untitled Conversation")
        # Reuse the most recently updated conversation. limit(1) + first() is
        # required here: scalar_one_or_none() raises MultipleResultsFound once
        # a user has more than one conversation.
        result = await db.execute(
            select(Conversation)
            .filter(Conversation.user_id == user.id)
            .order_by(Conversation.updated_at.desc())
            .limit(1)
        )
        conversation = result.scalars().first()
        if not conversation:
            conversation_id = str(uuid.uuid4())
            conversation = Conversation(
                conversation_id=conversation_id,
                user_id=user.id,
                title=title
            )
            db.add(conversation)
            await db.commit()
            await db.refresh(conversation)

        user_msg = Message(role="user", content=req.message, conversation_id=conversation.id)
        db.add(user_msg)
        await db.commit()

    preferred_model = user.preferred_model if user else None
    model_name, api_endpoint = select_model(req.message, input_type="text", preferred_model=preferred_model)

    is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
    if not is_available:
        logger.warning(f"Model {model_name} is not available at {api_endpoint}, trying fallback model.")
        model_name = SECONDARY_MODEL_NAME
        is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
        if not is_available:
            logger.error(f"Fallback model {model_name} is not available at {selected_endpoint}")
            raise HTTPException(status_code=503, detail=f"No available models. Tried {MODEL_NAME} and {SECONDARY_MODEL_NAME}.")

    system_prompt = enhance_system_prompt(req.system_prompt, req.message, user)

    stream = request_generation(
        api_key=api_key,
        api_base=selected_endpoint,
        message=req.message,
        system_prompt=system_prompt,
        model_name=model_name,
        chat_history=req.history,
        temperature=req.temperature,
        max_new_tokens=req.max_new_tokens or 2048,
        deep_search=req.enable_browsing,
        input_type="text",
        output_format=req.output_format
    )

    if req.output_format == "audio":
        audio_chunks = []
        try:
            for chunk in stream:
                if isinstance(chunk, bytes):
                    audio_chunks.append(chunk)
                else:
                    logger.warning(f"Unexpected non-bytes chunk in audio stream: {chunk}")
            if not audio_chunks:
                logger.error("No audio data generated.")
                raise HTTPException(status_code=502, detail="No audio data generated. Model may be unavailable.")
            audio_data = b"".join(audio_chunks)
            return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
        except HTTPException:
            raise  # preserve the 502 above instead of re-wrapping it below
        except Exception as e:
            logger.error(f"Audio generation failed: {e}")
            raise HTTPException(status_code=502, detail=f"Audio generation failed: {str(e)}")

    response_chunks = []
    try:
        for chunk in stream:
            if isinstance(chunk, str):
                response_chunks.append(chunk)
            else:
                logger.warning(f"Unexpected non-string chunk in text stream: {chunk}")
        response = "".join(response_chunks)
        if not response.strip():
            logger.warning(f"Empty response from {model_name}. Trying fallback model {SECONDARY_MODEL_NAME}.")

            model_name = SECONDARY_MODEL_NAME
            is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
            if not is_available:
                logger.error(f"Fallback model {model_name} is not available at {selected_endpoint}")
                raise HTTPException(status_code=503, detail=f"No available models. Tried {MODEL_NAME} and {SECONDARY_MODEL_NAME}.")

            stream = request_generation(
                api_key=api_key,
                api_base=selected_endpoint,
                message=req.message,
                system_prompt=system_prompt,
                model_name=model_name,
                chat_history=req.history,
                temperature=req.temperature,
                max_new_tokens=req.max_new_tokens or 2048,
                deep_search=req.enable_browsing,
                input_type="text",
                output_format=req.output_format
            )
            response_chunks = []
            for chunk in stream:
                if isinstance(chunk, str):
                    response_chunks.append(chunk)
                else:
                    logger.warning(f"Unexpected non-string chunk in text stream: {chunk}")
            response = "".join(response_chunks)
            if not response.strip():
                logger.error(f"Empty response from fallback model {model_name}.")
                raise HTTPException(status_code=502, detail=f"Empty response from both {MODEL_NAME} and {SECONDARY_MODEL_NAME}.")
        logger.info(f"Chat response: {response[:100]}...")
    except HTTPException:
        raise  # keep the 502/503 raised above rather than converting it to 500
    except Exception as e:
        logger.error(f"Chat generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Chat generation failed: {str(e)}")

    if user and conversation:
        assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
        db.add(assistant_msg)
        conversation.updated_at = datetime.utcnow()
        await db.commit()
        return {
            "response": response,
            "conversation_id": conversation.conversation_id,
            "conversation_url": f"https://mgzon-mgzon-app.hf.space/chat/{conversation.conversation_id}",
            "conversation_title": conversation.title
        }

    return {"response": response}


@router.post("/api/audio-transcription")
async def audio_transcription_endpoint(
    request: Request,
    file: UploadFile = File(...),
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
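    """Transcribe an uploaded audio file with the configured ASR model.

    Logged-in users get the transcription stored in their most recently
    updated conversation; anonymous users are rate-limited via
    handle_session().
    """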
    logger.info(f"Received audio transcription request for file: {file.filename}")

    if not user:
        await handle_session(request)

    conversation = None
    if user:
        title = "Audio Transcription"
        result = await db.execute(
            select(Conversation)
            .filter(Conversation.user_id == user.id)
            .order_by(Conversation.updated_at.desc())
            .limit(1)
        )
        conversation = result.scalars().first()
        if not conversation:
            conversation_id = str(uuid.uuid4())
            conversation = Conversation(
                conversation_id=conversation_id,
                user_id=user.id,
                title=title
            )
            db.add(conversation)
            await db.commit()
            await db.refresh(conversation)

        user_msg = Message(role="user", content="Audio message", conversation_id=conversation.id)
        db.add(user_msg)
        await db.commit()

    model_name, api_endpoint = select_model("transcribe audio", input_type="audio")

    is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
    if not is_available:
        logger.error(f"Model {model_name} is not available at {api_endpoint}")
        raise HTTPException(status_code=503, detail=f"Model {model_name} is not available. Please try another model.")

    audio_data = await file.read()
    stream = request_generation(
        api_key=api_key,
        api_base=selected_endpoint,
        message="Transcribe audio",
        system_prompt="Transcribe the provided audio using Whisper. Ensure accurate transcription in the detected language.",
        model_name=model_name,
        temperature=0.7,
        max_new_tokens=2048,
        input_type="audio",
        audio_data=audio_data,
        output_format="text"
    )
    response_chunks = []
    try:
        for chunk in stream:
            if isinstance(chunk, str):
                response_chunks.append(chunk)
            else:
                logger.warning(f"Unexpected non-string chunk in transcription stream: {chunk}")
        response = "".join(response_chunks)
        if not response.strip():
            logger.error("Empty transcription generated.")
            raise HTTPException(status_code=500, detail="Empty transcription generated from model.")
        logger.info(f"Audio transcription response: {response[:100]}...")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Audio transcription failed: {e}")
        raise HTTPException(status_code=500, detail=f"Audio transcription failed: {str(e)}")

    if user and conversation:
        assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
        db.add(assistant_msg)
        conversation.updated_at = datetime.utcnow()
        await db.commit()
        return {
            "transcription": response,
            "conversation_id": conversation.conversation_id,
            "conversation_url": f"https://mgzon-mgzon-app.hf.space/chat/{conversation.conversation_id}",
            "conversation_title": conversation.title
        }

    return {"transcription": response}


@router.post("/api/text-to-speech")
async def text_to_speech_endpoint(
    request: Request,
    req: dict,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
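    """Synthesize speech for the given text and stream it back as WAV audio."""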
    if not user:
        await handle_session(request)

    text = req.get("text", "")
    if not text.strip():
        raise HTTPException(status_code=400, detail="Text input is required for text-to-speech.")

    model_name, api_endpoint = select_model("text to speech", input_type="tts")

    is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
    if not is_available:
        logger.error(f"Model {model_name} is not available at {api_endpoint}")
        raise HTTPException(status_code=503, detail=f"Model {model_name} is not available. Please try another model.")

    stream = request_generation(
        api_key=api_key,
        api_base=selected_endpoint,
        message=text,
        system_prompt="Convert the provided text to speech using a text-to-speech model. Ensure clear and natural pronunciation, especially for Arabic text.",
        model_name=model_name,
        temperature=0.7,
        max_new_tokens=2048,
        input_type="tts",
        output_format="audio"
    )
    audio_chunks = []
    try:
        for chunk in stream:
            if isinstance(chunk, bytes):
                audio_chunks.append(chunk)
            else:
                logger.warning(f"Unexpected non-bytes chunk in TTS stream: {chunk}")
        if not audio_chunks:
            logger.error("No audio data generated for TTS.")
            raise HTTPException(status_code=500, detail="No audio data generated for text-to-speech.")
        audio_data = b"".join(audio_chunks)
        return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Text-to-speech generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Text-to-speech generation failed: {str(e)}")


@router.post("/api/code")
async def code_endpoint(
    request: Request,
    req: dict,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
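    """Generate code for a described task, optionally building on existing code."""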
    if not user:
        await handle_session(request)

    # Default the framework so the prompt does not read "using None" when the
    # client omits it.
    framework = req.get("framework") or "any suitable framework"
    task = req.get("task")
    code = req.get("code", "")
    output_format = req.get("output_format", "text")
    if not task:
        raise HTTPException(status_code=400, detail="Task description is required.")

    prompt = f"Generate code for task: {task} using {framework}. Existing code: {code}"
    preferred_model = user.preferred_model if user else None
    model_name, api_endpoint = select_model(prompt, input_type="text", preferred_model=preferred_model)

    is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
    if not is_available:
        logger.error(f"Model {model_name} is not available at {api_endpoint}")
        raise HTTPException(status_code=503, detail=f"Model {model_name} is not available. Please try another model.")

    system_prompt = enhance_system_prompt(
        "You are a coding expert. Provide detailed, well-commented code with examples and explanations.",
        prompt, user
    )

    stream = request_generation(
        api_key=api_key,
        api_base=selected_endpoint,
        message=prompt,
        system_prompt=system_prompt,
        model_name=model_name,
        temperature=0.7,
        max_new_tokens=2048,
        input_type="text",
        output_format=output_format
    )
    if output_format == "audio":
        audio_chunks = []
        try:
            for chunk in stream:
                if isinstance(chunk, bytes):
                    audio_chunks.append(chunk)
                else:
                    logger.warning(f"Unexpected non-bytes chunk in code audio stream: {chunk}")
            if not audio_chunks:
                logger.error("No audio data generated for code.")
                raise HTTPException(status_code=500, detail="No audio data generated for code.")
            audio_data = b"".join(audio_chunks)
            return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Code audio generation failed: {e}")
            raise HTTPException(status_code=500, detail=f"Code audio generation failed: {str(e)}")

    response_chunks = []
    try:
        for chunk in stream:
            if isinstance(chunk, str):
                response_chunks.append(chunk)
            else:
                logger.warning(f"Unexpected non-string chunk in code stream: {chunk}")
        response = "".join(response_chunks)
        if not response.strip():
            logger.error("Empty code response generated.")
            raise HTTPException(status_code=500, detail="Empty code response generated from model.")
        return {"generated_code": response}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Code generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Code generation failed: {str(e)}")


@router.post("/api/analysis")
async def analysis_endpoint(
    request: Request,
    req: dict,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
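    """Run a free-form analysis of the submitted text, returned as text or audio."""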
    if not user:
        await handle_session(request)

    message = req.get("text", "")
    output_format = req.get("output_format", "text")
    if not message.strip():
        raise HTTPException(status_code=400, detail="Text input is required for analysis.")

    preferred_model = user.preferred_model if user else None
    model_name, api_endpoint = select_model(message, input_type="text", preferred_model=preferred_model)

    is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
    if not is_available:
        logger.error(f"Model {model_name} is not available at {api_endpoint}")
        raise HTTPException(status_code=503, detail=f"Model {model_name} is not available. Please try another model.")

    system_prompt = enhance_system_prompt(
        "You are an expert analyst. Provide detailed analysis with step-by-step reasoning and examples.",
        message, user
    )

    stream = request_generation(
        api_key=api_key,
        api_base=selected_endpoint,
        message=message,
        system_prompt=system_prompt,
        model_name=model_name,
        temperature=0.7,
        max_new_tokens=2048,
        input_type="text",
        output_format=output_format
    )
    if output_format == "audio":
        audio_chunks = []
        try:
            for chunk in stream:
                if isinstance(chunk, bytes):
                    audio_chunks.append(chunk)
                else:
                    logger.warning(f"Unexpected non-bytes chunk in analysis audio stream: {chunk}")
            if not audio_chunks:
                logger.error("No audio data generated for analysis.")
                raise HTTPException(status_code=500, detail="No audio data generated for analysis.")
            audio_data = b"".join(audio_chunks)
            return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Analysis audio generation failed: {e}")
            raise HTTPException(status_code=500, detail=f"Analysis audio generation failed: {str(e)}")

    response_chunks = []
    try:
        for chunk in stream:
            if isinstance(chunk, str):
                response_chunks.append(chunk)
            else:
                logger.warning(f"Unexpected non-string chunk in analysis stream: {chunk}")
        response = "".join(response_chunks)
        if not response.strip():
            logger.error("Empty analysis response generated.")
            raise HTTPException(status_code=500, detail="Empty analysis response generated from model.")
        return {"analysis": response}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Analysis generation failed: {e}")
        raise HTTPException(status_code=500, detail=f"Analysis generation failed: {str(e)}")


@router.post("/api/image-analysis")
async def image_analysis_endpoint(
    request: Request,
    file: UploadFile = File(...),
    output_format: str = "text",
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
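    """Analyze an uploaded image and return a description as text or audio."""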
    if not user:
        await handle_session(request)

    conversation = None
    if user:
        title = "Image Analysis"
        result = await db.execute(
            select(Conversation)
            .filter(Conversation.user_id == user.id)
            .order_by(Conversation.updated_at.desc())
            .limit(1)
        )
        conversation = result.scalars().first()
        if not conversation:
            conversation_id = str(uuid.uuid4())
            conversation = Conversation(
                conversation_id=conversation_id,
                user_id=user.id,
                title=title
            )
            db.add(conversation)
            await db.commit()
            await db.refresh(conversation)

        user_msg = Message(role="user", content="Image analysis request", conversation_id=conversation.id)
        db.add(user_msg)
        await db.commit()

    preferred_model = user.preferred_model if user else None
    model_name, api_endpoint = select_model("analyze image", input_type="image", preferred_model=preferred_model)

    is_available, api_key, selected_endpoint = check_model_availability(model_name, HF_TOKEN)
    if not is_available:
        logger.error(f"Model {model_name} is not available at {api_endpoint}")
        raise HTTPException(status_code=503, detail=f"Model {model_name} is not available. Please try another model.")

    image_data = await file.read()
    system_prompt = enhance_system_prompt(
        "You are an expert in image analysis. Provide detailed descriptions or classifications based on the query.",
        "Analyze this image", user
    )

    stream = request_generation(
        api_key=api_key,
        api_base=selected_endpoint,
        message="Analyze this image",
        system_prompt=system_prompt,
        model_name=model_name,
        temperature=0.7,
        max_new_tokens=2048,
        input_type="image",
        image_data=image_data,
        output_format=output_format
    )
    if output_format == "audio":
        audio_chunks = []
        try:
            for chunk in stream:
                if isinstance(chunk, bytes):
                    audio_chunks.append(chunk)
                else:
                    logger.warning(f"Unexpected non-bytes chunk in image analysis audio stream: {chunk}")
            if not audio_chunks:
                logger.error("No audio data generated for image analysis.")
                raise HTTPException(status_code=500, detail="No audio data generated for image analysis.")
            audio_data = b"".join(audio_chunks)
            return StreamingResponse(io.BytesIO(audio_data), media_type="audio/wav")
        except HTTPException:
            raise
        except Exception as e:
            logger.error(f"Image analysis audio generation failed: {e}")
            raise HTTPException(status_code=500, detail=f"Image analysis audio generation failed: {str(e)}")

    response_chunks = []
    try:
        for chunk in stream:
            if isinstance(chunk, str):
                response_chunks.append(chunk)
            else:
                logger.warning(f"Unexpected non-string chunk in image analysis stream: {chunk}")
        response = "".join(response_chunks)
        if not response.strip():
            logger.error("Empty image analysis response generated.")
            raise HTTPException(status_code=500, detail="Empty image analysis response generated from model.")

        if user and conversation:
            assistant_msg = Message(role="assistant", content=response, conversation_id=conversation.id)
            db.add(assistant_msg)
            conversation.updated_at = datetime.utcnow()
            await db.commit()
            return {
                "image_analysis": response,
                "conversation_id": conversation.conversation_id,
                "conversation_url": f"https://mgzon-mgzon-app.hf.space/chat/{conversation.conversation_id}",
                "conversation_title": conversation.title
            }

        return {"image_analysis": response}
    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Image analysis failed: {e}")
        raise HTTPException(status_code=500, detail=f"Image analysis failed: {str(e)}")


@router.get("/api/test-model")
async def test_model(model: str = MODEL_NAME, endpoint: str = API_ENDPOINT):
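    """Send a one-off test completion to a model; useful for debugging deployments."""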
    try:
        is_available, api_key, selected_endpoint = check_model_availability(model, HF_TOKEN)
        if not is_available:
            logger.error(f"Model {model} is not available at {endpoint}")
            raise HTTPException(status_code=503, detail=f"Model {model} is not available.")

        client = OpenAI(api_key=api_key, base_url=selected_endpoint, timeout=60.0)
        response = client.chat.completions.create(
            model=model,
            messages=[{"role": "user", "content": "Test"}],
            max_tokens=50
        )
        return {"status": "success", "response": response.choices[0].message.content}
    except HTTPException:
        raise  # keep the 503 above instead of re-wrapping it as a 500
    except Exception as e:
        logger.error(f"Test model failed: {e}")
        raise HTTPException(status_code=500, detail=f"Test model failed: {str(e)}")


@router.post("/api/conversations", response_model=ConversationOut)
async def create_conversation(
    req: ConversationCreate,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")
    conversation_id = str(uuid.uuid4())
    conversation = Conversation(
        conversation_id=conversation_id,
        title=req.title or "Untitled Conversation",
        user_id=user.id
    )
    db.add(conversation)
    await db.commit()
    await db.refresh(conversation)
    return ConversationOut.from_orm(conversation)


@router.get("/api/conversations/{conversation_id}", response_model=ConversationOut)
async def get_conversation(
    conversation_id: str,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")
    result = await db.execute(
        select(Conversation).filter(
            Conversation.conversation_id == conversation_id,
            Conversation.user_id == user.id
        )
    )
    conversation = result.scalar_one_or_none()
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")
    return ConversationOut.from_orm(conversation)


@router.get("/api/conversations", response_model=List[ConversationOut])
async def list_conversations(
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")
    result = await db.execute(
        select(Conversation).filter(Conversation.user_id == user.id).order_by(Conversation.created_at.desc())
    )
    conversations = result.scalars().all()
    return [ConversationOut.from_orm(conv) for conv in conversations]


@router.put("/api/conversations/{conversation_id}/title")
async def update_conversation_title(
    conversation_id: str,
    title: str,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")
    result = await db.execute(
        select(Conversation).filter(
            Conversation.conversation_id == conversation_id,
            Conversation.user_id == user.id
        )
    )
    conversation = result.scalar_one_or_none()
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    conversation.title = title
    conversation.updated_at = datetime.utcnow()
    await db.commit()
    return {"message": "Conversation title updated", "title": conversation.title}


@router.delete("/api/conversations/{conversation_id}")
async def delete_conversation(
    conversation_id: str,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")
    result = await db.execute(
        select(Conversation).filter(
            Conversation.conversation_id == conversation_id,
            Conversation.user_id == user.id
        )
    )
    conversation = result.scalar_one_or_none()
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    await db.execute(delete(Message).filter(Message.conversation_id == conversation.id))
    await db.delete(conversation)
    await db.commit()
    return {"message": "Conversation deleted successfully"}


@router.get("/users/me")
async def get_user_settings(user: User = Depends(current_active_user)):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")
    return {
        "id": user.id,
        "email": user.email,
        "display_name": user.display_name,
        "preferred_model": user.preferred_model,
        "job_title": user.job_title,
        "education": user.education,
        "interests": user.interests,
        "additional_info": user.additional_info,
        "conversation_style": user.conversation_style,
        "is_active": user.is_active,
        "is_superuser": user.is_superuser
    }


@router.put("/users/me")
async def update_user_settings(
    settings: UserUpdate,
    user: User = Depends(current_active_user),
    db: AsyncSession = Depends(get_db)
):
    if not user:
        raise HTTPException(status_code=401, detail="Login required")

    if settings.preferred_model and settings.preferred_model not in MODEL_ALIASES:
        raise HTTPException(status_code=400, detail="Invalid model alias")

    if settings.display_name is not None:
        user.display_name = settings.display_name
    if settings.preferred_model is not None:
        user.preferred_model = settings.preferred_model
    if settings.job_title is not None:
        user.job_title = settings.job_title
    if settings.education is not None:
        user.education = settings.education
    if settings.interests is not None:
        user.interests = settings.interests
    if settings.additional_info is not None:
        user.additional_info = settings.additional_info
    if settings.conversation_style is not None:
        user.conversation_style = settings.conversation_style

    await db.commit()
    await db.refresh(user)
    return {
        "message": "Settings updated successfully",
        "user": {
            "id": user.id,
            "email": user.email,
            "display_name": user.display_name,
            "preferred_model": user.preferred_model,
            "job_title": user.job_title,
            "education": user.education,
            "interests": user.interests,
            "additional_info": user.additional_info,
            "conversation_style": user.conversation_style,
            "is_active": user.is_active,
            "is_superuser": user.is_superuser
        }
    }