import asyncio
import logging
import os
import time
from concurrent.futures import ThreadPoolExecutor
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import pipeline

# -------------------------
# Configuration (via env)
# -------------------------
# NOTE(review): the default repo name ends in "-GGUF"; transformers' `pipeline`
# normally needs safetensors/PyTorch weights (or an explicit `gguf_file=`).
# Confirm this default actually loads, or point REPO_ID at a non-GGUF repo.
REPO_ID = os.getenv("REPO_ID", "unsloth/gemma-3-270m-it-GGUF")
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2"))  # ThreadPoolExecutor workers
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", "1"))  # generations in flight
RATE_LIMIT_PER_MIN = int(os.getenv("RATE_LIMIT_PER_MIN", "60"))  # requests per client IP per minute
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*")  # "*" or comma-separated origins
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))  # seconds per /generate call; <= 0 disables

# Settings carried over from a llama-cpp-python version of this service.
# Only N_GPU_LAYERS is still read (it selects the transformers device_map);
# N_CTX and N_THREADS are currently unused.
N_CTX = int(os.getenv("N_CTX", "2048"))
N_THREADS = int(os.getenv("N_THREADS", "4"))
N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))

# -------------------------
# Logging
# -------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gemma_api")

# -------------------------
# FastAPI app
# -------------------------
app = FastAPI(title="Gemma 3 270M ThreadPool API")

# "*" allows any origin; otherwise a comma-separated list. Whitespace around
# entries is stripped and empty entries are dropped so "a.com, b.com" works.
origins = (
    ["*"]
    if ALLOWED_ORIGINS == "*"
    else [origin.strip() for origin in ALLOWED_ORIGINS.split(",") if origin.strip()]
)
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
)


# -------------------------
# Request / Response Models
# -------------------------
class Message(BaseModel):
    """One chat turn, forwarded to the pipeline as {"role": ..., "content": ...}."""

    role: str
    content: str


class GenerationRequest(BaseModel):
    """Body of POST /generate.

    Supply either ``messages`` (a full chat) or ``prompt`` (a single user turn
    that is prefixed with a default system message). A non-empty ``messages``
    list takes precedence over ``prompt``.
    """

    messages: Optional[List[Message]] = None
    prompt: Optional[str] = None
    max_new_tokens: int = Field(50, ge=1, le=500)  # kept small for fast responses
    temperature: float = Field(0.7, ge=0.0, le=2.0)  # 0.0 is served as greedy decoding
    top_p: float = Field(0.9, ge=0.0, le=1.0)  # only used when sampling
    do_sample: bool = Field(True)
    # Decoding-speed knobs
    num_beams: int = Field(1, ge=1, le=4)  # 1 = no beam search
    early_stopping: bool = Field(True)  # only meaningful when num_beams > 1
    use_cache: bool = Field(True)


class GenerationResponse(BaseModel):
    """Body returned by POST /generate."""

    generated_text: str
    model: str
    runtime_seconds: float


# -------------------------
# Global objects
# -------------------------
LLM_MODEL: Optional[Any] = None  # transformers text-generation pipeline, set in on_startup
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
# Bounds the number of concurrent generations.
# NOTE(review): created at import time; on Python < 3.10 asyncio primitives
# created outside the serving loop may bind to a different event loop.
model_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)


# -------------------------
# Rate limiting (simple token-bucket per IP)
# -------------------------
class RateLimiter:
    """Per-key token bucket: up to ``per_minute`` tokens, refilled continuously.

    Each allowed call consumes one token. The first call for a new key is
    always allowed.
    """

    # A bucket idle for this long has fully refilled, which makes it
    # indistinguishable from a missing entry, so it can be dropped safely.
    _IDLE_SECONDS = 60.0
    # Sweep idle buckets whenever the table reaches this many keys, so memory
    # does not grow without bound as new client IPs appear.
    _PRUNE_THRESHOLD = 10_000

    def __init__(self, per_minute: int):
        self.per_minute = per_minute
        self.storage: Dict[str, Dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    def _prune(self, now: float) -> None:
        """Drop buckets that have been idle long enough to be fully refilled."""
        # With per_minute <= 0 a bucket never refills, so dropping it would
        # change decisions; only prune when refilling actually happens.
        if self.per_minute <= 0:
            return
        stale = [key for key, rec in self.storage.items() if now - rec["ts"] >= self._IDLE_SECONDS]
        for key in stale:
            del self.storage[key]

    async def allow(self, key: str) -> bool:
        """Consume one token for ``key``; return False when the bucket is empty."""
        # Monotonic clock: wall-clock jumps must not drain or mint tokens.
        now = time.monotonic()
        async with self.lock:
            rec = self.storage.get(key)
            if rec is None:
                if len(self.storage) >= self._PRUNE_THRESHOLD:
                    self._prune(now)
                self.storage[key] = {"tokens": self.per_minute - 1, "ts": now}
                return True

            elapsed = now - rec["ts"]
            refill = (elapsed / 60.0) * self.per_minute
            rec["tokens"] = min(self.per_minute, rec["tokens"] + refill)
            rec["ts"] = now
            if rec["tokens"] >= 1:
                rec["tokens"] -= 1
                return True
            return False


rate_limiter = RateLimiter(RATE_LIMIT_PER_MIN)


# -------------------------
# Utility functions
# -------------------------
def generate_sync(
    messages: List[Dict[str, str]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    do_sample: bool,
    num_beams: int = 1,
    early_stopping: bool = True,
    use_cache: bool = True,
) -> str:
    """Run the loaded pipeline on chat ``messages`` and return the generated text.

    Blocking; meant to run on ``executor`` (see ``generate_async``).

    Raises:
        RuntimeError: if the model has not been loaded yet.
    """
    if LLM_MODEL is None:
        raise RuntimeError("Model is not loaded")

    # transformers rejects temperature <= 0 when sampling, so a 0.0 request
    # (allowed by the schema) is served as greedy decoding instead of a 500.
    sample = do_sample and temperature > 0.0
    generation_kwargs: Dict[str, Any] = {
        "max_new_tokens": max_new_tokens,
        "do_sample": sample,
        "num_beams": num_beams,
        "use_cache": use_cache,
    }
    # Pass decoding options only where they have an effect; transformers warns
    # about sampling params without sampling and early_stopping without beams.
    if sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p
    if num_beams > 1:
        generation_kwargs["early_stopping"] = early_stopping

    response = LLM_MODEL(messages, **generation_kwargs)
    generated = response[0]["generated_text"]
    # Chat-style input yields the whole conversation as a list of message
    # dicts; the last one is the newly generated assistant turn.
    if isinstance(generated, list):
        return generated[-1]["content"]
    return generated


async def generate_async(
    messages: List[Dict[str, str]],
    max_new_tokens: int,
    temperature: float,
    top_p: float,
    do_sample: bool,
    num_beams: int = 1,
    early_stopping: bool = True,
    use_cache: bool = True,
) -> str:
    """Run ``generate_sync`` on the thread pool without blocking the event loop."""
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        executor,
        generate_sync,
        messages,
        max_new_tokens,
        temperature,
        top_p,
        do_sample,
        num_beams,
        early_stopping,
        use_cache,
    )


# -------------------------
# Startup
# -------------------------
@app.on_event("startup")
async def on_startup():
    """Load the text-generation pipeline into ``LLM_MODEL`` and warm it up.

    Raises:
        RuntimeError: if loading or warm-up fails (aborts application startup).
    """
    global LLM_MODEL
    try:
        logger.info(f"Loading model from {REPO_ID}...")
        LLM_MODEL = pipeline(
            "text-generation",
            model=REPO_ID,
            device_map="auto" if N_GPU_LAYERS > 0 else "cpu",
        )
        logger.info("Model loaded successfully.")

        # One tiny generation so the first real request doesn't pay warm-up cost.
        logger.info("Warming up model...")
        dummy_messages = [{"role": "user", "content": "Hello"}]
        _ = LLM_MODEL(dummy_messages, max_new_tokens=5, temperature=0.1)
        logger.info("Model warmed up successfully.")
    except Exception as e:
        logger.exception(f"Failed to load model {REPO_ID}: {e}")
        raise RuntimeError(f"Model loading failed: {e}") from e


# -------------------------
# Endpoints
# -------------------------
@app.get("/")
async def root():
    return {"status": "Gemma 3 API is running 🎉", "model": REPO_ID}


@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": LLM_MODEL is not None}


@app.get("/metrics")
async def metrics():
    # NOTE(review): despite the key name, `_value` is the number of FREE
    # semaphore slots (a private asyncio attribute), not a locked flag. The key
    # is kept as-is for compatibility with existing consumers.
    return {
        "model": REPO_ID,
        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
        "current_semaphore_locked": model_semaphore._value if hasattr(model_semaphore, "_value") else None,
        "threadpool_workers": MAX_WORKERS,
    }


@app.post("/generate", response_model=GenerationResponse)
async def generate(req: GenerationRequest, request: Request):
    """Rate-limit by client IP, then generate a reply for the request's chat.

    Errors: 429 when rate-limited, 400 when neither ``messages`` nor ``prompt``
    is given, 503 when the model is not loaded, 504 when waiting for a slot plus
    generating exceeds ``REQUEST_TIMEOUT``, 500 on any other generation failure.
    """
    client_ip = request.client.host if request.client else "unknown"
    allowed = await rate_limiter.allow(client_ip)
    if not allowed:
        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded")

    # Convert to the chat-messages format accepted by the transformers pipeline.
    if req.messages:
        chat_messages = [{"role": msg.role, "content": msg.content} for msg in req.messages]
    elif req.prompt:
        chat_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": req.prompt},
        ]
    else:
        raise HTTPException(status_code=400, detail="Provide either 'messages' or 'prompt'.")

    if LLM_MODEL is None:
        raise HTTPException(status_code=503, detail="Model is not loaded")

    async def _guarded_generation() -> str:
        # Waiting for a concurrency slot counts toward the timeout as well.
        async with model_semaphore:
            return await generate_async(
                chat_messages,
                max_new_tokens=req.max_new_tokens,
                temperature=req.temperature,
                top_p=req.top_p,
                do_sample=req.do_sample,
                num_beams=req.num_beams,
                early_stopping=req.early_stopping,
                use_cache=req.use_cache,
            )

    timeout = REQUEST_TIMEOUT if REQUEST_TIMEOUT > 0 else None
    start = time.perf_counter()
    try:
        # On timeout the semaphore is released, but the worker thread already
        # running generate_sync cannot be interrupted and finishes in the background.
        generated_text = await asyncio.wait_for(_guarded_generation(), timeout=timeout)
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Generation timed out or concurrency queue full")
    except Exception as exc:
        logger.exception("Generation failed")
        raise HTTPException(status_code=500, detail="Generation failed") from exc
    runtime = time.perf_counter() - start

    return GenerationResponse(
        generated_text=generated_text,
        model=REPO_ID,
        runtime_seconds=round(runtime, 3),
    )