from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import asyncio
import os
import logging
# Set up logging
logging.basicConfig(level=logging.DEBUG)
# Set cache directory (Change this to a writable directory if necessary)
os.environ["HF_HOME"] = "/tmp/cache" # You can modify this to any directory with write access
# FastAPI app
app = FastAPI()
# CORS Middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables for model and tokenizer
model = None
tokenizer = None
model_loaded = False
# Load model and tokenizer in the background
async def load_model():
    global model, tokenizer, model_loaded
    model_name = "microsoft/phi-2"  # Use a different model if necessary (e.g., "gpt2" for testing)
    try:
        logging.info("Starting model and tokenizer loading...")
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/cache", use_fast=True)
        # Load model with 4-bit quantization; bitsandbytes requires a CUDA GPU,
        # so fall back to an unquantized load on CPU
        quantization_config = BitsAndBytesConfig(load_in_4bit=True) if torch.cuda.is_available() else None
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=quantization_config,
            cache_dir="/tmp/cache",
        )
        model_loaded = True
        logging.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logging.error(f"Failed to load model or tokenizer: {e}")
        raise
# Startup event to trigger model loading
@app.on_event("startup")
async def startup_event():
    logging.info("Application starting up...")
    # BackgroundTasks only runs when attached to a response, so schedule the
    # model load on the event loop instead so it actually executes at startup
    asyncio.create_task(load_model())
@app.on_event("shutdown")
async def shutdown_event():
    logging.info("Application shutting down...")
# Health check endpoint
@app.get("/health")
async def health():
    logging.info("Health check requested")
    status = {"status": "Server is running", "model_loaded": model_loaded}
    return status
# Request body model
class Question(BaseModel):
    question: str
# Async generator for streaming response
async def generate_response_chunks(prompt: str):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)
    # Run the blocking generate call in a worker thread so the event loop stays responsive
    output_ids = await asyncio.to_thread(
        model.generate,
        input_ids,
        max_new_tokens=300,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
    )
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    # Drop the echoed prompt and stream the remainder in small chunks
    answer = output_text[len(prompt):]
    chunk_size = 10
    for i in range(0, len(answer), chunk_size):
        yield answer[i:i + chunk_size]
        await asyncio.sleep(0.01)
# POST endpoint for asking questions
@app.post("/ask")
async def ask(question: Question):
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model is still loading; try again shortly")
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain"
    )
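# A minimal local entry point, added here as a sketch (the original Space may start
# the server through its own launcher); running `python app.py` serves the API on
# port 7860 if uvicorn is installed. The host and port below are assumptions.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)

# Example request against the streaming /ask endpoint (host/port are assumptions):
#   curl -N -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Explain 4-bit quantization in one sentence."}'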