from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import asyncio
import os
import logging

# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Set cache directory (change this to any directory with write access)
os.environ["HF_HOME"] = "/tmp/cache"

# FastAPI app
app = FastAPI()

# CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer
model = None
tokenizer = None
model_loaded = False


# Load model and tokenizer in the background
async def load_model():
    global model, tokenizer, model_loaded
    model_name = "microsoft/phi-2"  # Swap in a smaller model (e.g., "gpt2") for testing
    try:
        logging.info("Starting model and tokenizer loading...")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, cache_dir="/tmp/cache", use_fast=True
        )

        # Load model with 4-bit quantization
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=quantization_config,
            cache_dir="/tmp/cache",
        )

        model_loaded = True
        logging.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logging.error(f"Failed to load model or tokenizer: {e}")
        raise


# Startup event to trigger model loading
@app.on_event("startup")
async def startup_event():
    logging.info("Application starting up...")
    # BackgroundTasks only runs within a request/response cycle, so adding the
    # task to a locally created BackgroundTasks object here would never execute
    # it. Schedule the coroutine on the event loop instead, so the server can
    # start accepting requests while the model loads.
    # (Note: from_pretrained itself is blocking, so the event loop stalls during
    # the actual download/load unless it is moved to a worker thread.)
    asyncio.create_task(load_model())


@app.on_event("shutdown")
async def shutdown_event():
    logging.info("Application shutting down...")


# Health check endpoint
@app.get("/health")
async def health():
    logging.info("Health check requested")
    return {"status": "Server is running", "model_loaded": model_loaded}


# Request body model
class Question(BaseModel):
    question: str


# Async generator for streaming response
async def generate_response_chunks(prompt: str):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    # Generation is synchronous: the full answer is produced first and then
    # streamed back to the client in small chunks.
    output_ids = model.generate(
        input_ids,
        max_new_tokens=300,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )

    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    answer = output_text[len(prompt):]

    chunk_size = 10
    for i in range(0, len(answer), chunk_size):
        yield answer[i:i + chunk_size]
        await asyncio.sleep(0.01)


# POST endpoint for asking questions
@app.post("/ask")
async def ask(question: Question):
    if not model_loaded:
        # Avoid an opaque NoneType crash if a request arrives before loading finishes
        raise HTTPException(status_code=503, detail="Model is still loading, try again shortly")
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )
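

# --- Optional local runner and example request (a minimal sketch) ---
# The original script does not show how it is launched; the block below assumes
# uvicorn is installed and that port 8000 is free. Adjust host/port as needed.
#
# Example request once the model has finished loading (hypothetical prompt):
#   curl -N -X POST http://localhost:8000/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What is FastAPI?"}'
#   (-N disables curl's output buffering so streamed chunks appear as they arrive)

if __name__ == "__main__":
    import uvicorn

    # Run the ASGI app directly; in production you would typically invoke
    # `uvicorn <module>:app` from the command line instead.
    uvicorn.run(app, host="0.0.0.0", port=8000)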