|
|
|
|
|
|
|
|
|
import os |
|
import logging |
|
from fastapi import FastAPI, Request, Depends, HTTPException, status, Query |
|
from fastapi.responses import HTMLResponse, RedirectResponse, PlainTextResponse, FileResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from fastapi.templating import Jinja2Templates |
|
from starlette.middleware.base import BaseHTTPMiddleware |
|
from starlette.middleware.sessions import SessionMiddleware |
|
from fastapi.openapi.docs import get_swagger_ui_html |
|
from fastapi.middleware.cors import CORSMiddleware |
|
from api.endpoints import router as api_router |
|
from api.auth import fastapi_users, auth_backend, current_active_user, get_auth_router |
|
from api.database import get_db, User |
|
from api.models import UserRead, UserCreate, UserUpdate |
|
from motor.motor_asyncio import AsyncIOMotorClient |
|
from pydantic import BaseModel |
|
from typing import List |
|
from contextlib import asynccontextmanager |
|
import uvicorn |
|
import markdown2 |
|
from sqlalchemy.orm import Session |
|
from pathlib import Path |
|
from hashlib import md5 |
|
from datetime import datetime |
|
from httpx_oauth.exceptions import GetIdEmailError |
|
import re |
|
from init_db import init_db |
|
|
|
|
|
logging.basicConfig(level=logging.INFO) |
|
logger = logging.getLogger(__name__) |
|
logger.info("Files in current dir: %s", os.listdir(os.getcwd())) |
|
|
|
|
|
def _read_required_env(name: str, *, log_message: str, error_message: str, min_length: int = 1) -> str:
    """Return environment variable *name*, failing fast when it is unusable.

    The variable is rejected when it is unset or shorter than *min_length*
    characters (the default of 1 rejects only the empty string). On rejection,
    *log_message* is logged at ERROR level and ``ValueError(error_message)`` is
    raised, which aborts module import.
    """
    value = os.getenv(name)
    if value is None or len(value) < min_length:
        logger.error(log_message)
        raise ValueError(error_message)
    return value


# Validated in this order; the first missing/invalid variable aborts startup.
HF_TOKEN = _read_required_env(
    "HF_TOKEN",
    log_message="HF_TOKEN is not set in environment variables.",
    error_message="HF_TOKEN is required for Inference API.",
)
MONGO_URI = _read_required_env(
    "MONGODB_URI",
    log_message="MONGODB_URI is not set in environment variables.",
    error_message="MONGODB_URI is required for MongoDB.",
)
# The secret also keys SessionMiddleware, so a minimum length is enforced.
JWT_SECRET = _read_required_env(
    "JWT_SECRET",
    log_message="JWT_SECRET is not set or too short.",
    error_message="JWT_SECRET is required (at least 32 characters).",
    min_length=32,
)
|
|
|
|
|
# One Motor client for the whole process, built from the validated MONGODB_URI.
client = AsyncIOMotorClient(MONGO_URI)
# Application database and the collection keyed by session_id (see setup_mongo_index).
mongo_db = client.get_database("hager")
session_message_counts = mongo_db.get_collection("session_message_counts")
|
|
|
|
|
async def setup_mongo_index():
    """Ensure a unique index on ``session_id`` in the session_message_counts collection."""
    index_field = "session_id"
    await session_message_counts.create_index(index_field, unique=True)
|
|
|
|
|
def _render_markdown(text):
    """Convert Markdown *text* to HTML with markdown2 (registered as the ``markdown`` Jinja filter)."""
    return markdown2.markdown(text)


templates = Jinja2Templates(directory="templates")
templates.env.filters["markdown"] = _render_markdown
|
|
|
|
|
class BlogPost(BaseModel):
    """Pydantic schema describing a blog post; every field is a required string."""

    # Identity and content.
    id: str
    title: str
    content: str
    author: str
    # Timestamps stored as plain strings (format not enforced here).
    date: str
    created_at: str
|
|
|
|
|
# Tuning knobs read once at import; override via environment variables.
# (Neither is referenced elsewhere in the visible part of this module.)
QUEUE_SIZE = int(os.getenv("QUEUE_SIZE", "80"))
CONCURRENCY_LIMIT = int(os.getenv("CONCURRENCY_LIMIT", "20"))
|
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application startup/shutdown hook.

    Startup: run ``init_db()`` and ensure the Mongo session_id index exists.
    Shutdown: close the module-level Motor client so its connection pool is
    released instead of being leaked when the app stops.
    """
    init_db()
    await setup_mongo_index()
    try:
        yield
    finally:
        # Runs on normal shutdown and if the server is torn down with an error.
        client.close()
|
|
|
# FastAPI registers its built-in Swagger UI at `docs_url` (default "/docs") while
# the app object is constructed, i.e. before any @app.get route in this module.
# Starlette matches routes in registration order, so the custom "/docs" page
# defined below would never be reached. Disable the built-in one; Swagger UI is
# served explicitly at "/swagger".
app = FastAPI(title="MGZon Chatbot API", lifespan=lifespan, docs_url=None)

# Signed cookie sessions, keyed with the validated JWT_SECRET.
app.add_middleware(SessionMiddleware, secret_key=JWT_SECRET)

# StaticFiles checks that its directory exists by default, so create it first.
os.makedirs("static", exist_ok=True)
app.mount("/static", StaticFiles(directory="static"), name="static")

# Browser cross-origin access is limited to the hosted Space and local dev.
app.add_middleware(
    CORSMiddleware,
    allow_origins=[
        "https://mgzon-mgzon-app.hf.space",
        "http://localhost:7860",
    ],
    allow_credentials=True,
    allow_methods=["GET", "POST", "OPTIONS"],
    allow_headers=["Accept", "Content-Type", "Authorization"],
)

# API routes first, then the authentication routes wired by api.auth.
app.include_router(api_router)
get_auth_router(app)
|
|
|
|
|
@app.get("/debug/routes", response_class=PlainTextResponse)
async def debug_routes():
    """List every registered route as ``"<methods> <path>"``, one per line, sorted.

    Routes without a ``methods`` attribute (e.g. the /static mount) are shown
    with an empty list; routes without a ``path`` show ``Unknown``.

    NOTE(review): this endpoint has no authentication and exposes the full route
    table; consider disabling or protecting it outside development.
    """
    lines = []
    for route in app.routes:
        # `methods` is a set of strings; its iteration order depends on string-hash
        # randomization, so sort it to make the output stable across processes.
        methods = sorted(getattr(route, "methods", None) or [])
        path = getattr(route, "path", "Unknown")
        lines.append(f"{methods} {path}")
    return "\n".join(sorted(lines))
|
|
|
|
|
class NotFoundMiddleware(BaseHTTPMiddleware):
    """Render HTML error pages for 404 responses and for unhandled exceptions.

    Any downstream response with status 404 is replaced by ``404.html``.
    NOTE(review): that includes 404s raised by JSON API routes. An exception
    escaping the downstream app (or the 404 template rendering) is logged and
    answered with ``500.html``, which receives the exception text as ``error``.
    """

    async def dispatch(self, request: Request, call_next):
        try:
            response = await call_next(request)
            if response.status_code != 404:
                return response
            logger.warning(f"404 Not Found: {request.url}")
            return templates.TemplateResponse("404.html", {"request": request}, status_code=404)
        except Exception as exc:
            logger.exception(f"Error processing request {request.url}: {exc}")
            context = {"request": request, "error": str(exc)}
            return templates.TemplateResponse("500.html", context, status_code=500)


app.add_middleware(NotFoundMiddleware)
|
|
|
|
|
@app.exception_handler(GetIdEmailError)
async def handle_oauth_error(request: Request, exc: GetIdEmailError):
    """Log an OAuth id/email lookup failure and bounce the user back to /login with a message."""
    logger.error(f"OAuth error: {exc}")
    message = "Failed to authenticate with OAuth. Please try again or contact support."
    target = f"/login?error={message}"
    return RedirectResponse(url=target, status_code=302)
|
|
|
|
|
@app.get("/", response_class=HTMLResponse)
async def root(
    request: Request,
    user: User = Depends(fastapi_users.current_user(active=True, optional=True)),
):
    """Render the landing page; ``user`` is the active signed-in user or ``None``.

    The homepage is advertised in /sitemap.xml, so it must render for anonymous
    visitors and crawlers. NOTE(review): ``current_active_user`` from api.auth is
    presumably ``fastapi_users.current_user(active=True)``, which answers 401 when
    nobody is logged in; an optional active-user dependency is used instead —
    confirm against api.auth.
    """
    return templates.TemplateResponse("index.html", {"request": request, "user": user})
|
|
|
|
|
@app.get("/google97468ef1f6b6e804.html", response_class=PlainTextResponse)
async def google_verification():
    """Serve the Google Search Console site-ownership verification file."""
    verification_line = "google-site-verification: google97468ef1f6b6e804.html"
    return verification_line
|
|
|
|
|
@app.get("/login", response_class=HTMLResponse)
async def login_page(
    request: Request,
    user: User = Depends(fastapi_users.current_user(active=True, optional=True)),
):
    """Render the login form, or redirect an already signed-in user to /chat.

    The user dependency is optional so anonymous visitors — the only people who
    need this page — get the form. NOTE(review): ``current_active_user`` from
    api.auth is presumably ``fastapi_users.current_user(active=True)``, which
    answers 401 when nobody is logged in, making this page unreachable for
    logged-out users — confirm against api.auth.
    """
    if user:
        return RedirectResponse(url="/chat", status_code=302)
    return templates.TemplateResponse("login.html", {"request": request})
|
|
|
|
|
@app.get("/register", response_class=HTMLResponse)
async def register_page(
    request: Request,
    user: User = Depends(fastapi_users.current_user(active=True, optional=True)),
):
    """Render the sign-up form, or redirect an already signed-in user to /chat.

    The user dependency is optional so anonymous visitors can reach the form.
    NOTE(review): ``current_active_user`` from api.auth is presumably
    ``fastapi_users.current_user(active=True)``, which answers 401 when nobody is
    logged in, making registration unreachable — confirm against api.auth.
    """
    if user:
        return RedirectResponse(url="/chat", status_code=302)
    return templates.TemplateResponse("register.html", {"request": request})
|
|
|
|
|
@app.get("/chat", response_class=HTMLResponse)
async def chat(request: Request, user: User = Depends(current_active_user)):
    """Render the chat UI for the user resolved by ``current_active_user``."""
    context = {"request": request, "user": user}
    return templates.TemplateResponse("chat.html", context)
|
|
|
|
|
@app.get("/chat/{conversation_id}", response_class=HTMLResponse)
async def chat_conversation(
    request: Request,
    conversation_id: str,
    user: User = Depends(fastapi_users.current_user(active=True, optional=True)),
    db: Session = Depends(get_db),
):
    """Render the chat UI pre-loaded with one of the signed-in user's conversations.

    Anonymous visitors are redirected to /login (the user dependency is optional
    so this branch can actually run). A conversation that does not exist, or
    belongs to another user, yields HTTP 404.
    """
    # `Conversation` is not imported at module level in this file, so the query
    # below raised NameError on every request.
    # NOTE(review): assumes the ORM model lives in api.database next to User/get_db — confirm.
    from api.database import Conversation

    if not user:
        return RedirectResponse(url="/login", status_code=302)

    # Filtering on user_id as well keeps users from opening each other's conversations.
    conversation = (
        db.query(Conversation)
        .filter(
            Conversation.conversation_id == conversation_id,
            Conversation.user_id == user.id,
        )
        .first()
    )
    if not conversation:
        raise HTTPException(status_code=404, detail="Conversation not found")

    return templates.TemplateResponse(
        "chat.html",
        {
            "request": request,
            "user": user,
            "conversation_id": conversation.conversation_id,
            "conversation_title": conversation.title or "Untitled Conversation",
        },
    )
|
|
|
|
|
@app.get("/about", response_class=HTMLResponse)
async def about(
    request: Request,
    user: User = Depends(fastapi_users.current_user(active=True, optional=True)),
):
    """Render the About page; ``user`` is the active signed-in user or ``None``.

    /about is advertised in /sitemap.xml, so it must render for anonymous
    visitors and crawlers. NOTE(review): ``current_active_user`` from api.auth is
    presumably non-optional (401 when logged out); an optional active-user
    dependency is used instead — confirm against api.auth.
    """
    return templates.TemplateResponse("about.html", {"request": request, "user": user})
|
|
|
|
|
@app.get("/static/{path:path}")
async def serve_static(path: str):
    """Serve a file from ./static with Cache-Control, ETag and Last-Modified headers.

    JS/CSS files are cached for 1 hour, everything else for 1 year. Requests that
    resolve outside ./static (``..`` segments, absolute paths) or to anything
    other than a regular file get 404.

    NOTE(review): ``app.mount("/static", StaticFiles(...))`` is registered earlier
    in this module, and Starlette matches routes in registration order, so this
    handler is presumably shadowed by the mount — confirm which one should win.
    """
    from email.utils import formatdate

    static_root = Path("static").resolve()
    # A percent-encoded "?" (%3F) survives into the path parameter; drop it and
    # everything after it, as cache-busting suffixes are not part of the file name.
    clean_path = re.sub(r'\?.*', '', path)
    file_path = (static_root / clean_path).resolve()

    # Block path traversal and absolute-path escapes, and reject directories
    # (which would otherwise make open() fail with a 500).
    if static_root not in file_path.parents or not file_path.is_file():
        raise HTTPException(status_code=404, detail="File not found")

    cache_duration = 3600 if clean_path.endswith(('.js', '.css')) else 31536000

    # Hash in chunks so large assets are not loaded into memory at once.
    digest = md5()
    with open(file_path, "rb") as f:
        for chunk in iter(lambda: f.read(65536), b""):
            digest.update(chunk)

    headers = {
        "Cache-Control": f"public, max-age={cache_duration}",
        # RFC 7232 entity tags are quoted strings.
        "ETag": f'"{digest.hexdigest()}"',
        # formatdate(usegmt=True) yields a locale-independent RFC 7231 HTTP date.
        "Last-Modified": formatdate(file_path.stat().st_mtime, usegmt=True),
    }
    return FileResponse(file_path, headers=headers)
|
|
|
|
|
@app.get("/blog", response_class=HTMLResponse)
async def blog(request: Request, skip: int = Query(0, ge=0), limit: int = Query(10, ge=1, le=100)):
    """Render one page of blog posts using ``skip``/``limit`` pagination (1-100 per page).

    No sort is applied, so posts appear in whatever order MongoDB returns them.
    """
    cursor = mongo_db.blog_posts.find().skip(skip).limit(limit)
    page_posts = await cursor.to_list(limit)
    context = {"request": request, "posts": page_posts}
    return templates.TemplateResponse("blog.html", context)
|
|
|
|
|
@app.get("/blog/{post_id}", response_class=HTMLResponse)
async def blog_post(request: Request, post_id: str):
    """Render a single blog post looked up by its ``id`` field; 404 if none matches."""
    found = await mongo_db.blog_posts.find_one({"id": post_id})
    if found:
        return templates.TemplateResponse("blog_post.html", {"request": request, "post": found})
    raise HTTPException(status_code=404, detail="Post not found")
|
|
|
|
|
@app.get("/docs", response_class=HTMLResponse)
async def docs(request: Request):
    """Render the hand-written documentation page (docs.html)."""
    context = {"request": request}
    return templates.TemplateResponse("docs.html", context)
|
|
|
|
|
@app.get("/swagger", response_class=HTMLResponse)
async def swagger_ui():
    """Serve the interactive Swagger UI backed by the app's /openapi.json schema."""
    page = get_swagger_ui_html(openapi_url="/openapi.json", title="MGZon API Documentation")
    return page
|
|
|
|
|
_SITEMAP_BASE_URL = "https://mgzon-mgzon-app.hf.space"

# (path, changefreq, priority) for the fixed pages, in emission order.
_SITEMAP_STATIC_PAGES = (
    ("/", "daily", "1.0"),
    ("/chat", "daily", "0.8"),
    ("/about", "weekly", "0.7"),
    ("/login", "weekly", "0.8"),
    ("/register", "weekly", "0.8"),
    ("/docs", "weekly", "0.9"),
    ("/blog", "daily", "0.9"),
)


def _sitemap_url_entry(loc: str, lastmod: str, changefreq: str, priority: str) -> str:
    """Render one sitemap ``<url>`` element, XML-escaping the free-form values."""
    from xml.sax.saxutils import escape

    return (
        "  <url>\n"
        f"    <loc>{escape(loc)}</loc>\n"
        f"    <lastmod>{escape(lastmod)}</lastmod>\n"
        f"    <changefreq>{changefreq}</changefreq>\n"
        f"    <priority>{priority}</priority>\n"
        "  </url>\n"
    )


@app.get("/sitemap.xml", response_class=PlainTextResponse)
async def sitemap():
    """Build the XML sitemap: the fixed site pages plus up to 100 blog posts.

    Fixed pages use today's UTC date as ``lastmod``. Each post uses its ``date``
    field, falling back to today when that field is missing or empty. Posts
    without an ``id`` are skipped, since they cannot be linked. The response is
    served as ``application/xml`` rather than ``text/plain``.
    """
    from datetime import timezone
    from urllib.parse import quote

    posts = await mongo_db.blog_posts.find().to_list(100)
    current_date = datetime.now(timezone.utc).strftime("%Y-%m-%d")

    entries = [
        _sitemap_url_entry(f"{_SITEMAP_BASE_URL}{page}", current_date, changefreq, priority)
        for page, changefreq, priority in _SITEMAP_STATIC_PAGES
    ]
    for post in posts:
        post_id = post.get("id")
        if not post_id:
            # Skip unlinkable posts instead of failing the whole sitemap with KeyError.
            continue
        entries.append(
            _sitemap_url_entry(
                f"{_SITEMAP_BASE_URL}/blog/{quote(str(post_id), safe='')}",
                str(post.get("date") or current_date),
                "weekly",
                "0.9",
            )
        )

    body = (
        '<?xml version="1.0" encoding="UTF-8"?>\n'
        '<urlset xmlns="http://www.sitemaps.org/schemas/sitemap/0.9">\n'
        + "".join(entries)
        + "</urlset>"
    )
    return PlainTextResponse(body, media_type="application/xml")
|
|
|
|
|
@app.get("/gradio", response_class=RedirectResponse)
async def launch_chatbot():
    """Legacy /gradio entry point: redirect (302) to the chat page."""
    destination = "/chat"
    return RedirectResponse(url=destination, status_code=302)
|
|
|
if __name__ == "__main__":
    # Direct-run entry point: listen on all interfaces, port from $PORT (default 7860).
    listen_port = int(os.getenv("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=listen_port)
|
|