import os
import time
import logging
import asyncio
from typing import List, Optional, Dict, Any
from fastapi import FastAPI, HTTPException, Request, status
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import pipeline
from concurrent.futures import ThreadPoolExecutor
# -------------------------
# Configuration (via env)
# -------------------------
REPO_ID = os.getenv("REPO_ID", "unsloth/gemma-3-270m-it-GGUF")
MAX_WORKERS = int(os.getenv("MAX_WORKERS", "2"))  # ThreadPool workers (reduced for speed)
MAX_CONCURRENT_REQUESTS = int(os.getenv("MAX_CONCURRENT_REQUESTS", "1"))  # Reduced for speed
RATE_LIMIT_PER_MIN = int(os.getenv("RATE_LIMIT_PER_MIN", "60"))
ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS", "*")
REQUEST_TIMEOUT = int(os.getenv("REQUEST_TIMEOUT", "120"))  # Per-request generation timeout (seconds)
# Settings carried over from a llama-cpp-python configuration; the transformers
# pipeline below only uses N_GPU_LAYERS (to pick the device). N_CTX and N_THREADS are unused.
N_CTX = int(os.getenv("N_CTX", "2048"))  # Context window
N_THREADS = int(os.getenv("N_THREADS", "4"))  # CPU threads
N_GPU_LAYERS = int(os.getenv("N_GPU_LAYERS", "0"))  # GPU layers (0 for CPU only)
# -------------------------
# Logging
# -------------------------
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("gemma_api")

# -------------------------
# FastAPI app
# -------------------------
app = FastAPI(title="Gemma 3 270M ThreadPool API")

origins = ["*"] if ALLOWED_ORIGINS == "*" else ALLOWED_ORIGINS.split(",")
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_methods=["*"],
    allow_headers=["*"],
)
# -------------------------
# Request / Response Models
# -------------------------
class Message(BaseModel):
    role: str
    content: str

class GenerationRequest(BaseModel):
    messages: Optional[List[Message]] = None
    prompt: Optional[str] = None
    max_new_tokens: int = Field(50, ge=1, le=500)  # Reduced for faster response
    temperature: float = Field(0.7, ge=0.0, le=2.0)
    top_p: float = Field(0.9, ge=0.0, le=1.0)
    do_sample: bool = Field(True)
    # Speed optimization parameters
    num_beams: int = Field(1, ge=1, le=4)  # Greedy decoding by default
    early_stopping: bool = Field(True)
    use_cache: bool = Field(True)

class GenerationResponse(BaseModel):
    generated_text: str
    model: str
    runtime_seconds: float
# -------------------------
# Global objects
# -------------------------
LLM_MODEL: Optional[Any] = None
executor = ThreadPoolExecutor(max_workers=MAX_WORKERS)
model_semaphore = asyncio.Semaphore(MAX_CONCURRENT_REQUESTS)
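# Note: with the defaults above, MAX_CONCURRENT_REQUESTS=1 means the semaphore
# serializes inference even though the thread pool has two workers; raising both
# together allows parallel generations if the hardware can handle them.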
# -------------------------
# Rate limiting (simple token-bucket per IP)
# -------------------------
class RateLimiter:
    def __init__(self, per_minute: int):
        self.per_minute = per_minute
        self.storage: Dict[str, Dict[str, Any]] = {}
        self.lock = asyncio.Lock()

    async def allow(self, key: str) -> bool:
        now = time.time()
        async with self.lock:
            rec = self.storage.get(key)
            if not rec:
                self.storage[key] = {"tokens": self.per_minute - 1, "ts": now}
                return True
            elapsed = now - rec["ts"]
            refill = (elapsed / 60.0) * self.per_minute
            rec["tokens"] = min(self.per_minute, rec["tokens"] + refill)
            rec["ts"] = now
            if rec["tokens"] >= 1:
                rec["tokens"] -= 1
                return True
            return False

rate_limiter = RateLimiter(RATE_LIMIT_PER_MIN)
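# Example of the refill math: with RATE_LIMIT_PER_MIN=60 the bucket gains
# (elapsed / 60) * 60 = 1 token per second, so a client can burst up to 60
# requests and then sustain roughly one request per second.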
# -------------------------
# Utility functions
# -------------------------
# build_prompt_from_messages function removed - the chat messages are passed
# to the transformers pipeline directly.
def generate_sync(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, num_beams: int = 1, early_stopping: bool = True, use_cache: bool = True) -> str:
    # transformers pipeline generation parameters
    generation_kwargs = {
        "max_new_tokens": max_new_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "do_sample": do_sample,
        "num_beams": num_beams,
        "early_stopping": early_stopping,
        "use_cache": use_cache,
    }
    # Generate using the transformers pipeline
    response = LLM_MODEL(messages, **generation_kwargs)
    generated = response[0]["generated_text"]
    # Chat-style input returns the full message list; plain prompts return a string
    return generated[-1]["content"] if isinstance(generated, list) else generated

async def generate_async(messages: List[Dict[str, str]], max_new_tokens: int, temperature: float, top_p: float, do_sample: bool, num_beams: int = 1, early_stopping: bool = True, use_cache: bool = True) -> str:
    # Run the blocking pipeline call in the thread pool so the event loop stays responsive
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(
        executor,
        lambda: generate_sync(messages, max_new_tokens, temperature, top_p, do_sample, num_beams, early_stopping, use_cache)
    )
# -------------------------
# Startup
# -------------------------
@app.on_event("startup")
async def on_startup():
    global LLM_MODEL
    try:
        logger.info(f"Loading model from {REPO_ID}...")
        LLM_MODEL = pipeline(
            "text-generation",
            model=REPO_ID,
            device_map="auto" if N_GPU_LAYERS > 0 else "cpu"
        )
        logger.info("Model loaded successfully.")
        # Warm up the model with a dummy request for faster first inference
        logger.info("Warming up model...")
        dummy_messages = [{"role": "user", "content": "Hello"}]
        _ = LLM_MODEL(
            dummy_messages,
            max_new_tokens=5,
            temperature=0.1
        )
        logger.info("Model warmed up successfully.")
    except Exception as e:
        logger.error(f"Failed to load model {REPO_ID}: {e}")
        raise RuntimeError(f"Model loading failed: {e}") from e
# -------------------------
# Endpoints
# -------------------------
@app.get("/")
async def root():
    return {"status": "Gemma 3 API is running 🎉", "model": REPO_ID}

@app.get("/health")
async def health():
    return {"status": "ok", "model_loaded": LLM_MODEL is not None}

@app.get("/metrics")
async def metrics():
    return {
        "model": REPO_ID,
        "max_concurrent_requests": MAX_CONCURRENT_REQUESTS,
        # Semaphore._value is the number of free slots, not the locked count
        "current_semaphore_available": model_semaphore._value if hasattr(model_semaphore, "_value") else None,
        "threadpool_workers": MAX_WORKERS
    }

@app.post("/generate", response_model=GenerationResponse)
async def generate(req: GenerationRequest, request: Request):
    client_ip = request.client.host if request.client else "unknown"
    allowed = await rate_limiter.allow(client_ip)
    if not allowed:
        raise HTTPException(status_code=status.HTTP_429_TOO_MANY_REQUESTS, detail="Rate limit exceeded")
    # Convert to the chat messages format expected by the transformers pipeline
    if req.messages:
        chat_messages = [{"role": msg.role, "content": msg.content} for msg in req.messages]
    elif req.prompt:
        chat_messages = [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": req.prompt}
        ]
    else:
        raise HTTPException(status_code=400, detail="Provide either 'messages' or 'prompt'.")
    start = time.time()
    try:
        async with model_semaphore:
            # Bound generation by REQUEST_TIMEOUT; note the thread pool task keeps
            # running after a timeout, the client just receives a 504.
            generated_text = await asyncio.wait_for(
                generate_async(
                    chat_messages,
                    max_new_tokens=req.max_new_tokens,
                    temperature=req.temperature,
                    top_p=req.top_p,
                    do_sample=req.do_sample,
                    num_beams=req.num_beams,
                    early_stopping=req.early_stopping,
                    use_cache=req.use_cache
                ),
                timeout=REQUEST_TIMEOUT
            )
    except asyncio.TimeoutError:
        raise HTTPException(status_code=504, detail="Generation timed out or concurrency queue full")
    runtime = time.time() - start
    return GenerationResponse(
        generated_text=generated_text,
        model=REPO_ID,
        runtime_seconds=round(runtime, 3)
    )
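# -------------------------
# Example usage (illustrative)
# -------------------------
# A minimal sketch of how this API could be exercised, assuming the file is saved
# as app.py and served with uvicorn on port 7860 (the usual Hugging Face Spaces
# port); adjust module name, host, and port to your deployment:
#
#   uvicorn app:app --host 0.0.0.0 --port 7860
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"prompt": "Write a haiku about the sea.", "max_new_tokens": 50}'
#
# The response body mirrors GenerationResponse, e.g.
#   {"generated_text": "...", "model": "unsloth/gemma-3-270m-it-GGUF", "runtime_seconds": 1.234}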