from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import asyncio
import os
import logging

# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Set the Hugging Face cache directory (change this to any directory with write access).
# The explicit cache_dir arguments below keep the cache location unambiguous even if
# HF_HOME is read before this assignment takes effect.
os.environ["HF_HOME"] = "/tmp/cache"
# FastAPI app
app = FastAPI()
# CORS Middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
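# Note: wide-open CORS (all origins, methods, and headers) is convenient for a demo
# front end; restrict allow_origins before exposing this anywhere sensitive.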
# Global variables for model and tokenizer
model = None
tokenizer = None
model_loaded = False
# Load model and tokenizer in the background
async def load_model():
    global model, tokenizer, model_loaded
    model_name = "microsoft/phi-2"  # Swap in a smaller model (e.g. "gpt2") for quick testing
    try:
        logging.info("Starting model and tokenizer loading...")
        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/cache", use_fast=True)
        # Load model with 4-bit quantization (requires bitsandbytes and a CUDA GPU;
        # on CPU-only hardware, drop quantization_config and rely on float32)
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=quantization_config,
            cache_dir="/tmp/cache",
        )
        model_loaded = True
        logging.info("Model and tokenizer loaded successfully")
    except Exception as e:
        logging.error(f"Failed to load model or tokenizer: {e}")
        raise
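# Rough sizing note: phi-2 has about 2.7B parameters, so the 4-bit weights occupy on the
# order of 1.5-2 GB of GPU memory (an estimate, not a measured figure).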
# Startup event: kick off model loading without blocking server startup
@app.on_event("startup")
async def startup_event():
    logging.info("Application starting up...")
    # A BackgroundTasks object built by hand never runs outside a request cycle,
    # so schedule the load on the event loop and keep a reference so it is not GC'd
    app.state.model_load_task = asyncio.create_task(load_model())

@app.on_event("shutdown")
async def shutdown_event():
    logging.info("Application shutting down...")
# Health check endpoint (the "/health" path is an assumed route name)
@app.get("/health")
async def health():
    logging.info("Health check requested")
    status = {"status": "Server is running", "model_loaded": model_loaded}
    return status
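# Example check, assuming a local run on port 7860 (the usual Hugging Face Spaces port):
#   curl http://localhost:7860/health
#   -> {"status": "Server is running", "model_loaded": true}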
# Request body model
class Question(BaseModel):
    question: str
# Async generator for streaming the response back in small chunks
async def generate_response_chunks(prompt: str):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)
    # Run the blocking generate() call in a worker thread so the event loop stays responsive
    output_ids = await asyncio.to_thread(
        model.generate,
        input_ids,
        max_new_tokens=300,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Decode only the newly generated tokens; slicing the decoded text by len(prompt)
    # can misalign when the tokenizer does not round-trip the prompt exactly
    answer = tokenizer.decode(output_ids[0][input_ids.shape[-1]:], skip_special_tokens=True)
    chunk_size = 10
    for i in range(0, len(answer), chunk_size):
        yield answer[i:i + chunk_size]
        await asyncio.sleep(0.01)
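# Note: this streams the already-finished answer in fixed-size character chunks; for true
# token-by-token streaming, transformers' TextIteratorStreamer driven by a background
# generation thread is the usual approach.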
# POST endpoint for asking questions (the "/ask" path is an assumed route name)
@app.post("/ask")
async def ask(question: Question):
    if not model_loaded:
        raise HTTPException(status_code=503, detail="Model is still loading, try again shortly")
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain",
    )
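# Example request against the streaming endpoint (route and port as assumed above):
#   curl -N -X POST http://localhost:7860/ask \
#        -H "Content-Type: application/json" \
#        -d '{"question": "Explain 4-bit quantization in one sentence."}'

if __name__ == "__main__":
    # Convenience entry point for local testing; on Hugging Face Spaces the container
    # command normally starts uvicorn itself, so this block is a sketch, not required.
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)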