import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, AutoConfig
from peft import PeftModel
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import uvicorn
import os
import time  # For checking model load status

# --- Global Variables for Model and Tokenizer ---
model = None
tokenizer = None
model_loaded_successfully = False  # Flag to indicate model status
device = "cuda" if torch.cuda.is_available() else "cpu"
print(f"--- Initializing on Device: {device} ---")

# --- Pydantic Model for Request Body ---
class PromptRequest(BaseModel):
    prompt: str
    max_new_tokens: int = 256
    temperature: float = 0.7
    top_p: float = 0.9
    top_k: int = 50

# --- FastAPI App Initialization ---
app = FastAPI()

def load_model_and_tokenizer():
    global model, tokenizer, model_loaded_successfully

    base_model_id = os.environ.get("BASE_MODEL_ID")
    adapter_path = os.environ.get("ADAPTER_PATH")
    hf_token = os.environ.get("HF_TOKEN")

    if not base_model_id:
        print("CRITICAL ERROR: BASE_MODEL_ID environment variable not set.")
        # In a real app, you might want to prevent startup or handle this more gracefully
        return
    if not adapter_path:
        print("CRITICAL ERROR: ADAPTER_PATH environment variable not set.")
        return

    print(f"Using device: {device}")
    print(f"Attempting to load base model: {base_model_id}")
    print(f"Attempting to load adapter from: {adapter_path}")

    try:
        # --- Load Tokenizer ---
        print("Loading tokenizer...")
        try:
            tokenizer = AutoTokenizer.from_pretrained(adapter_path, token=hf_token, trust_remote_code=True)
            print(f"Loaded tokenizer from adapter path: {adapter_path}")
        except Exception as e:
            print(f"Could not load tokenizer from adapter path: {e}. Loading from base model path: {base_model_id}")
            tokenizer = AutoTokenizer.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)

        if tokenizer.pad_token is None:
            if tokenizer.eos_token is not None:
                print("Setting pad_token to eos_token.")
                tokenizer.pad_token = tokenizer.eos_token
            else:
                print("Adding new pad_token '[PAD]'.")
                tokenizer.add_special_tokens({'pad_token': '[PAD]'})
        tokenizer.padding_side = "left"

        # --- Configure Quantization ---
        print("Configuring 4-bit quantization...")
        compute_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() and device == "cuda" else torch.float16
        bnb_config = None
        if device == "cuda":
            bnb_config = BitsAndBytesConfig(
                load_in_4bit=True,
                bnb_4bit_quant_type="nf4",
                bnb_4bit_compute_dtype=compute_dtype,
                bnb_4bit_use_double_quant=True,
            )
            print(f"Using BNB config with compute_dtype: {compute_dtype}")
        else:
            print("Running on CPU, BNB quantization will not be applied.")

        # --- Load Base Model with Quantization ---
        print(f"Loading base model: {base_model_id}...")
        config = AutoConfig.from_pretrained(base_model_id, token=hf_token, trust_remote_code=True)
        if getattr(config, "pretraining_tp", 1) != 1:
            print(f"Overriding pretraining_tp from {getattr(config, 'pretraining_tp', 'N/A')} to 1.")
            config.pretraining_tp = 1

        base_model_instance = AutoModelForCausalLM.from_pretrained(
            base_model_id,
            config=config,
            quantization_config=bnb_config if device == "cuda" else None,
            device_map={"": device},
            token=hf_token,
            trust_remote_code=True,
            low_cpu_mem_usage=True if device == "cuda" else False
        )
        print("Base model loaded.")

        if tokenizer.pad_token_id is not None and tokenizer.pad_token_id >= base_model_instance.config.vocab_size:
            print("Resizing token embeddings for base model.")
            base_model_instance.resize_token_embeddings(len(tokenizer))

        # --- Load LoRA Adapter ---
        print(f"Loading LoRA adapter from: {adapter_path}...")
        model = PeftModel.from_pretrained(base_model_instance, adapter_path)
        model.eval()
        print("LoRA adapter loaded and model is in eval mode.")
        print(f"Model is on device: {model.device}")

        model_loaded_successfully = True  # Set flag on successful load
        print("Model and tokenizer loaded successfully.")

    except Exception as e:
        print(f"CRITICAL ERROR during model/tokenizer loading: {e}")
        model_loaded_successfully = False
        # Optionally, re-raise or handle to prevent app from starting if model load fails.
        # For now, it will print the error; the /generate endpoint will show the model as not loaded,
        # and the health check will show the model as not ready.

@app.on_event("startup")
async def startup_event():
    print("Server startup event: Initiating model and tokenizer loading...")
    # Model loading can take time, so it's done here.
    # Health checks might hit the server before this completes.
    load_model_and_tokenizer()
    if model_loaded_successfully:
        print("Model loading process completed successfully within startup event.")
    else:
        print("Model loading process encountered an error or did not complete within startup event.")

# <<< --- HEALTH CHECK ENDPOINTS --- >>>
@app.get("/")
async def health_check():
    """Basic health check endpoint."""
    if model_loaded_successfully and model is not None and tokenizer is not None:
        return {"status": "ok", "message": "Model is loaded and ready."}
    else:
        # Return a 503 if the model isn't ready yet, so Spaces knows it's still starting up
        # or that loading failed.
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading.")

@app.get("/health")  # Common alternative health check path
async def health_check_alternative():
    return await health_check()
# <<< --- END OF HEALTH CHECK ENDPOINTS --- >>>

@app.post("/generate/")
async def generate_text(request: PromptRequest):
    global model, tokenizer, model_loaded_successfully
    if not model_loaded_successfully or model is None or tokenizer is None:
        raise HTTPException(status_code=503, detail="Model is not loaded or still loading. Please try again shortly or check server logs.")

    try:
        inputs = tokenizer(request.prompt, return_tensors="pt", padding=True, truncation=True, max_length=512).to(device)
        print(f"Received prompt: {request.prompt}")
        print("Generating...")

        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=request.max_new_tokens,
                num_return_sequences=1,
                do_sample=True,
                temperature=request.temperature,
                top_p=request.top_p,
                top_k=request.top_k,
                pad_token_id=tokenizer.pad_token_id,
                eos_token_id=tokenizer.eos_token_id
            )

        prompt_tokens = inputs.input_ids.shape[-1]
        if outputs[0].size(0) > prompt_tokens:
            generated_sequence = outputs[0][prompt_tokens:]
            generated_text = tokenizer.decode(generated_sequence, skip_special_tokens=True)
        else:
            generated_text = ""

        print(f"Generated text: {generated_text}")
        return {"generated_text": generated_text}

    except Exception as e:
        print(f"Error during generation: {e}")
        raise HTTPException(status_code=500, detail=str(e))

if __name__ == "__main__":
    print("Starting Uvicorn server directly from app.py for local testing...")
    port = int(os.environ.get("PORT", 8000))
    host = "0.0.0.0"
    print(f"Uvicorn will attempt to listen on host {host}, port {port}")
    print("Set BASE_MODEL_ID and ADAPTER_PATH environment variables for model loading.")
    # The @app.on_event("startup") handler will be called by Uvicorn.
    try:
        uvicorn.run(app, host=host, port=port)
    except Exception as e:
        print(f"Error attempting to run uvicorn: {e}")