from fastapi import FastAPI
from pydantic import BaseModel
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModelForCausalLM, BitsAndBytesConfig
import torch
import asyncio
import os
import logging

# Set up logging
logging.basicConfig(level=logging.DEBUG)

# Set cache directory (Change this to a writable directory if necessary)
os.environ["HF_HOME"] = "/tmp/cache"  # You can modify this to any directory with write access

# FastAPI app
app = FastAPI()

# CORS Middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables for model and tokenizer
model = None
tokenizer = None
model_loaded = False

# Load model and tokenizer (blocking; scheduled in a worker thread at startup)
def load_model():
    global model, tokenizer, model_loaded
    model_name = "microsoft/phi-2"  # Use a different model if necessary (e.g., "gpt2" for testing)
    
    try:
        logging.info("Starting model and tokenizer loading...")

        # Load tokenizer
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir="/tmp/cache", use_fast=True)
        
        # Load model with 4-bit quantization (typically requires the bitsandbytes package and a CUDA GPU)
        quantization_config = BitsAndBytesConfig(load_in_4bit=True)
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto",
            quantization_config=quantization_config,
            cache_dir="/tmp/cache"
        )
        
        model_loaded = True
        logging.info("Model and tokenizer loaded successfully")

    except Exception as e:
        logging.error(f"Failed to load model or tokenizer: {e}")
        raise

# Startup event to trigger model loading
@app.on_event("startup")
async def startup_event():
    logging.info("Application starting up...")
    # Schedule the blocking load in a worker thread so the event loop
    # (and the /health endpoint) stays responsive while the weights load.
    app.state.model_load_task = asyncio.create_task(asyncio.to_thread(load_model))

@app.on_event("shutdown")
async def shutdown_event():
    logging.info("Application shutting down...")

# Health check endpoint
@app.get("/health")
async def health():
    logging.info("Health check requested")
    status = {"status": "Server is running", "model_loaded": model_loaded}
    return status

# Request body model
class Question(BaseModel):
    question: str

# Async generator for streaming response
async def generate_response_chunks(prompt: str):
    # Refuse gracefully if a request arrives before the model has finished loading
    if not model_loaded:
        yield "Model is still loading, please try again in a moment."
        return

    input_ids = tokenizer(prompt, return_tensors="pt").input_ids
    input_ids = input_ids.to(model.device)

    # Run the blocking generate() call in a worker thread so other requests are not stalled
    output_ids = await asyncio.to_thread(
        model.generate,
        input_ids,
        max_new_tokens=300,
        do_sample=True,
        top_p=0.95,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # avoid the "pad_token_id not set" warning
    )

    # Drop the echoed prompt and stream only the newly generated text
    output_text = tokenizer.decode(output_ids[0], skip_special_tokens=True)
    answer = output_text[len(prompt):]

    chunk_size = 10
    for i in range(0, len(answer), chunk_size):
        yield answer[i:i + chunk_size]
        await asyncio.sleep(0.01)

# POST endpoint for asking questions
@app.post("/ask")
async def ask(question: Question):
    return StreamingResponse(
        generate_response_chunks(question.question),
        media_type="text/plain"
    )
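
# Usage sketch (illustrative only; the module name "main", port 8000, and the httpx
# client below are assumptions, not part of this service). It shows how the server
# might be launched and how a client could consume the streamed /ask response.
#
# Run the server (assumes uvicorn is installed):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#
# Poll /health until "model_loaded" is true, then stream an answer (assumes httpx):
#   import httpx
#
#   with httpx.stream(
#       "POST",
#       "http://localhost:8000/ask",
#       json={"question": "What is FastAPI?"},
#       timeout=None,
#   ) as response:
#       for chunk in response.iter_text():
#           print(chunk, end="", flush=True)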