"""FastAPI service exposing a local GGUF model through llama-cpp-python.

Endpoints:
    POST /generate -- text completion for ``inputs``, optionally wrapped in a
                      simple system/User/Assistant prompt template.
    GET  /health   -- liveness plus whether the model has been loaded.
    GET  /         -- service banner pointing at the interactive docs.
"""

import asyncio
import logging
import os
import threading
from contextlib import asynccontextmanager
from typing import List, Optional

import uvicorn
from fastapi import FastAPI, HTTPException
from llama_cpp import Llama
from pydantic import BaseModel

# Set up logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Default model location; can be overridden with the MODEL_PATH env variable.
DEFAULT_MODEL_PATH = os.path.join("./model", "gema-4b-indra10k-model1-q4_k_m.gguf")

# Generation defaults. They serve as the request-schema defaults and as the
# fallback when a client explicitly sends JSON null for a sampling parameter.
DEFAULT_MAX_TOKENS = 512
DEFAULT_TEMPERATURE = 0.7
DEFAULT_TOP_K = 50
DEFAULT_TOP_P = 0.9
DEFAULT_REPEAT_PENALTY = 1.1

# Global model instance; None until lifespan() has loaded it (and again after
# shutdown).
model: Optional[Llama] = None

# Completions run in worker threads (so they don't block the event loop), but
# a single Llama instance should not be driven by several threads at once, so
# every call into the model is serialized through this lock.
_model_lock = threading.Lock()


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Load the GGUF model on startup and release it on shutdown.

    Raises:
        RuntimeError: if the model file does not exist.
        Exception: any error raised while constructing ``Llama`` is logged and
            re-raised unchanged, which aborts application startup.
    """
    global model
    model_gguf_path = os.environ.get("MODEL_PATH", DEFAULT_MODEL_PATH)
    try:
        if not os.path.exists(model_gguf_path):
            raise RuntimeError(f"Model file not found at: {model_gguf_path}")

        logger.info(f"Loading model from: {model_gguf_path}")
        model = Llama(
            model_path=model_gguf_path,
            n_ctx=2048,  # Context length
            n_gpu_layers=0,  # Set to a positive number if GPU is available
            n_threads=os.cpu_count() or 1,
            verbose=True,
        )
        logger.info("Model loaded successfully using llama-cpp-python!")
    except Exception as e:
        logger.error(f"Failed to load model: {e}")
        raise  # bare raise keeps the original traceback

    try:
        yield
    finally:
        logger.info("Application is shutting down.")
        # Drop the global reference so the model can be freed; in-flight
        # requests hold their own local reference (see generate_text).
        model = None


app = FastAPI(title="Gema 4B Model API", version="1.0.0", lifespan=lifespan)


class TextRequest(BaseModel):
    """Body of POST /generate."""

    inputs: str
    system_prompt: Optional[str] = None
    # None is forwarded as-is to llama-cpp-python for max_tokens.
    max_tokens: Optional[int] = DEFAULT_MAX_TOKENS
    # For the sampling parameters below, an explicit null falls back to the
    # default shown here.
    temperature: Optional[float] = DEFAULT_TEMPERATURE
    top_k: Optional[int] = DEFAULT_TOP_K
    top_p: Optional[float] = DEFAULT_TOP_P
    repeat_penalty: Optional[float] = DEFAULT_REPEAT_PENALTY
    stop: Optional[List[str]] = None


class TextResponse(BaseModel):
    """Response of POST /generate."""

    generated_text: str


def _or_default(value, default):
    """Return *value*, or *default* when *value* is None."""
    return default if value is None else value


def _build_prompt(request: TextRequest) -> str:
    """Return the prompt text sent to the model for *request*.

    With a non-empty system prompt the input is wrapped in a plain
    ``User:``/``Assistant:`` template; otherwise the raw input is used.
    """
    if request.system_prompt:
        return f"{request.system_prompt}\n\nUser: {request.inputs}\nAssistant:"
    return request.inputs


def _complete(llm: Llama, prompt: str, request: TextRequest) -> str:
    """Run one blocking completion on *llm* and return the stripped text.

    Intended to run in a worker thread; holds ``_model_lock`` for the whole
    call so only one completion touches the model at a time.
    """
    with _model_lock:
        output = llm(
            prompt=prompt,
            max_tokens=request.max_tokens,
            temperature=_or_default(request.temperature, DEFAULT_TEMPERATURE),
            top_p=_or_default(request.top_p, DEFAULT_TOP_P),
            top_k=_or_default(request.top_k, DEFAULT_TOP_K),
            repeat_penalty=_or_default(request.repeat_penalty, DEFAULT_REPEAT_PENALTY),
            stop=request.stop or [],
        )
    # Completion responses carry the generated text in choices[0]["text"].
    return output["choices"][0]["text"].strip()


@app.post("/generate", response_model=TextResponse)
async def generate_text(request: TextRequest):
    """Generate a completion for the request.

    Raises:
        HTTPException(503): the model is not loaded.
        HTTPException(500): generation failed for any other reason.
    """
    # Capture the model once so a concurrent shutdown can't swap it mid-request.
    llm = model
    if llm is None:
        raise HTTPException(status_code=503, detail="Model is not ready or failed to load.")

    try:
        full_prompt = _build_prompt(request)
        # The llama-cpp call is synchronous and CPU-bound; running it in the
        # default executor keeps the event loop (and /health) responsive.
        loop = asyncio.get_running_loop()
        generated_text = await loop.run_in_executor(
            None, _complete, llm, full_prompt, request
        )
        return TextResponse(generated_text=generated_text)
    except Exception as e:
        logger.exception(f"Generation error: {e}")
        raise HTTPException(status_code=500, detail=f"Generation failed: {str(e)}") from e


@app.get("/health")
async def health_check():
    """Report liveness and whether the model is currently loaded."""
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/")
async def root():
    """Return a short banner pointing at the API docs."""
    return {"message": "Gema 4B Model API", "docs": "/docs"}


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")