# app.py
# FastAPI backend for a Hugging Face Space (CPU tier)
# • Only MedGemma-4B-IT, no Parakeet, no tool-calling
# • Reads HF_TOKEN from Space secrets, uses /tmp for writable cache
# • /chat endpoint expects {"messages":[{"role":"user","content": "..."}]}
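
# Dependencies: a sketch of the requirements.txt this Space is assumed to need
# (versions unpinned here; `accelerate` is required because the loader below
# uses device_map="auto"):
#   fastapi
#   uvicorn
#   pydantic
#   torch
#   transformers
#   accelerate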

import os, uuid
from typing import List, Optional

# ------------------------------------------------------------
# 1. Writable cache setup (must run BEFORE transformers is imported)
# ------------------------------------------------------------
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point every cache-related environment variable at the writable /tmp directory
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR

import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline

# Hugging Face token for the gated MedGemma checkpoint (set in Space secrets)
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

# ------------------------------------------------------------
# 2. Simple Pydantic request model
# ------------------------------------------------------------
class Message(BaseModel):
    role: str
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    messages: List[Message]


# ------------------------------------------------------------
# 3. Lazy MedGemma loader with memory optimization
# ------------------------------------------------------------
DEVICE = "cpu"          # CPU-tier Space: device_map="auto" below resolves to CPU
DTYPE = torch.bfloat16  # halves memory vs float32 and, unlike float16, works for CPU inference

medgemma_pipe = None


def get_medgemma():
    """Load the pipeline on first use and cache it in the module-level global."""
    global medgemma_pipe
    if medgemma_pipe is None:
        try:
            print("🚀 Loading MedGemma-4B-IT with memory optimization...")
            print(f"Using cache directory: {CACHE_DIR}")
            medgemma_pipe = pipeline(
                "text-generation",
                model="google/medgemma-4b-it",
                torch_dtype=DTYPE,
                device_map="auto",
                token=HF_TOKEN,
                trust_remote_code=True,
                # Loading options must go through model_kwargs so they reach from_pretrained
                model_kwargs={
                    "cache_dir": CACHE_DIR,
                    "low_cpu_mem_usage": True,
                    "max_memory": {"cpu": "8GB"},  # cap host RAM; the CPU tier has no GPU
                },
            )
            print("✅ MedGemma loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading MedGemma: {e}")
            print(f"Cache directory exists: {os.path.exists(CACHE_DIR)}")
            print(f"Cache directory writable: {os.access(CACHE_DIR, os.W_OK)}")
            print(f"HF_TOKEN present: {bool(HF_TOKEN)}")
            medgemma_pipe = None
    return medgemma_pipe


# ------------------------------------------------------------
# 4. FastAPI app with permissive CORS (for the Replit frontend)
# ------------------------------------------------------------
app = FastAPI(title="MedGemma Radiology Chat")

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # tighten in production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# ------------------------------------------------------------
# 5. System prompt
# ------------------------------------------------------------
SYSTEM_PROMPT = (
    "You are MedGemma, a medical vision-language assistant specialised in radiology. "
    "When given a patient case or study description, respond with a concise, professional "
    "radiology report. Use headings such as FINDINGS and IMPRESSION."
)

# ------------------------------------------------------------
# 6. /chat endpoint
# ------------------------------------------------------------
@app.post("/chat")
async def chat(request: Request):
    try:
        body = await request.json()
        payload = ChatCompletionRequest(**body)
        user_msg = payload.messages[-1].content or ""

        prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"

        pipe = get_medgemma()
        if pipe is None:
            return JSONResponse(
                {
                    "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                    "choices": [{
                        "message": {
                            "role": "assistant",
                            "content": "MedGemma model is unavailable. "
                                       "Check your gated-model access and HF_TOKEN.",
                        }
                    }],
                },
                status_code=503,
            )

        try:
            result = pipe(
                prompt,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                pad_token_id=pipe.tokenizer.eos_token_id,
                return_full_text=False,
            )
            assistant_text = result[0]["generated_text"].strip() if result else "No response."
        except Exception as e:
            print("Generation error:", e)
            assistant_text = "Error generating response. Please retry later."

        return JSONResponse(
            {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": assistant_text,
                    }
                }],
            }
        )
    except Exception as e:
        print(f"Chat endpoint error: {e}")
        return JSONResponse(
            {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "Server error. Please try again later.",
                    }
                }],
            },
            status_code=500,
        )


# ------------------------------------------------------------
# 7. Health endpoints
# ------------------------------------------------------------
@app.get("/")
async def root():
    return {"status": "healthy", "message": "MedGemma API is running"}


@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model_loaded": medgemma_pipe is not None,
        "hf_token_present": bool(HF_TOKEN),
        "cache_dir": CACHE_DIR,
        "cache_exists": os.path.exists(CACHE_DIR),
        "cache_writable": os.access(CACHE_DIR, os.W_OK),
        "env_vars": {
            "HF_HOME": os.environ.get("HF_HOME"),
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE"),
            "HF_HUB_CACHE": os.environ.get("HF_HUB_CACHE"),
        },
    }


# ------------------------------------------------------------
# 8. For local dev (won't run inside the Space runtime)
# ------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
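

# ------------------------------------------------------------
# 9. Example client call (illustration only; never invoked by the app)
# ------------------------------------------------------------
# A minimal sketch of how a frontend could exercise the /chat contract described
# in the header. The base URL, the sample case text, and the use of the
# third-party `requests` package are assumptions for illustration.
def example_chat_request(base_url: str = "http://localhost:7860") -> str:
    """Send one user message to /chat and return the assistant's report text."""
    import requests  # local import so the app itself does not depend on it

    resp = requests.post(
        f"{base_url}/chat",
        json={"messages": [{"role": "user",
                            "content": "CT chest: 8 mm nodule in the right upper lobe."}]},
        timeout=300,  # the first call may wait on the model download
    )
    resp.raise_for_status()
    return resp.json()["choices"][0]["message"]["content"]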