# app.py
# FastAPI backend for a Hugging Face Space (CPU tier)
# • Only MedGemma-4B-IT, no Parakeet, no tool-calling
# • Reads HF_TOKEN from Space secrets, uses /tmp for writable cache
# • /chat endpoint expects {"messages":[{"role":"user","content": "..."}]}
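#
# Example request against the contract above (illustrative; assumes the server is
# reachable on port 7860, as configured in the __main__ block at the bottom):
#
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages":[{"role":"user","content":"Chest X-ray: describe the findings."}]}'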
import os, uuid
from typing import List, Optional
# ------------------------------------------------------------
# 1. Writable cache (set BEFORE importing transformers)
# ------------------------------------------------------------
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Point every cache-related environment variable at the writable /tmp directory
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline
# Get token from Space secrets
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
# ------------------------------------------------------------
# 2. Simple Pydantic request model
# ------------------------------------------------------------
class Message(BaseModel):
    role: str
    content: Optional[str] = None
class ChatCompletionRequest(BaseModel):
    messages: List[Message]
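# Illustrative validation example (the clinical text is made up): a body of the
# documented shape parses into the models above, and the handler reads the last message.
#
#   payload = ChatCompletionRequest(
#       messages=[{"role": "user", "content": "CT chest: 8 mm right upper lobe nodule."}]
#   )
#   payload.messages[-1].content  # -> "CT chest: 8 mm right upper lobe nodule."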
# ------------------------------------------------------------
# 3. Lazy MedGemma loader with memory optimization
# ------------------------------------------------------------
DEVICE = "cpu"              # CPU-tier Space: no GPU available
DTYPE = torch.float32       # fp16 is unreliable on CPU; float32 (or bfloat16) is the safe choice
medgemma_pipe = None
def get_medgemma():
    global medgemma_pipe
    if medgemma_pipe is None:
        try:
            print("🚀 Loading MedGemma-4B-IT with memory optimization...")
            print(f"Using cache directory: {CACHE_DIR}")
            medgemma_pipe = pipeline(
                "text-generation",
                model="google/medgemma-4b-it",
                torch_dtype=DTYPE,
                device=DEVICE,
                token=HF_TOKEN,
                trust_remote_code=True,
                # Loading options are forwarded to from_pretrained via model_kwargs
                model_kwargs={
                    "cache_dir": CACHE_DIR,
                    "low_cpu_mem_usage": True,  # stream weights to keep peak RAM down
                },
            )
            print("✅ MedGemma loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading MedGemma: {e}")
            print(f"Cache directory exists: {os.path.exists(CACHE_DIR)}")
            print(f"Cache directory writable: {os.access(CACHE_DIR, os.W_OK)}")
            print(f"HF_TOKEN present: {bool(HF_TOKEN)}")
            medgemma_pipe = None
    return medgemma_pipe
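# Quick smoke test for the loader (illustrative; run from a Python shell rather than at
# import time, and note that the first call downloads the gated weights, which is slow on CPU):
#
#   pipe = get_medgemma()
#   if pipe is not None:
#       out = pipe("Describe a normal PA chest radiograph.",
#                  max_new_tokens=64, return_full_text=False)
#       print(out[0]["generated_text"])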
# ------------------------------------------------------------
# 4. FastAPI app with permissive CORS (for Replit frontend)
# ------------------------------------------------------------
app = FastAPI(title="MedGemma Radiology Chat")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # adjust in prod: browsers reject "*" when credentials are allowed
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
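# In production, replace the wildcard with the actual frontend origin(s), e.g.
# (placeholder URL):
#
#   allow_origins=["https://your-frontend.replit.app"]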
# ------------------------------------------------------------
# 5. System prompt
# ------------------------------------------------------------
SYSTEM_PROMPT = (
    "You are MedGemma, a medical vision-language assistant specialised in radiology. "
    "When given a patient case or study description, respond with a concise, professional "
    "radiology report. Use headings such as FINDINGS and IMPRESSION."
)
# ------------------------------------------------------------
# 6. /chat endpoint
# ------------------------------------------------------------
@app.post("/chat")
async def chat(request: Request):
    try:
        body = await request.json()
        payload = ChatCompletionRequest(**body)
        user_msg = payload.messages[-1].content or ""
        prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"
        pipe = get_medgemma()
        if pipe is None:
            return JSONResponse(
                {
                    "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                    "choices": [{
                        "message": {
                            "role": "assistant",
                            "content": "MedGemma model is unavailable. "
                                       "Check your gated-model access and HF_TOKEN.",
                        }
                    }],
                },
                status_code=503,
            )
        try:
            result = pipe(
                prompt,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                pad_token_id=pipe.tokenizer.eos_token_id,
                return_full_text=False,
            )
            assistant_text = result[0]["generated_text"].strip() if result else "No response."
        except Exception as e:
            print("Generation error:", e)
            assistant_text = "Error generating response. Please retry later."
        return JSONResponse(
            {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": assistant_text,
                    }
                }],
            }
        )
    except Exception as e:
        print(f"Chat endpoint error: {e}")
        return JSONResponse(
            {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "Server error. Please try again later.",
                    }
                }],
            },
            status_code=500,
        )
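# A successful response mirrors the minimal OpenAI chat-completion shape built above
# (the id below is illustrative):
#
#   {
#     "id": "chatcmpl-1a2b3c4d",
#     "choices": [
#       {"message": {"role": "assistant", "content": "FINDINGS: ...\nIMPRESSION: ..."}}
#     ]
#   }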
# ------------------------------------------------------------
# 7. Health endpoints
# ------------------------------------------------------------
@app.get("/")
async def root():
    return {"status": "healthy", "message": "MedGemma API is running"}
@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model_loaded": medgemma_pipe is not None,
        "hf_token_present": bool(HF_TOKEN),
        "cache_dir": CACHE_DIR,
        "cache_exists": os.path.exists(CACHE_DIR),
        "cache_writable": os.access(CACHE_DIR, os.W_OK),
        "env_vars": {
            "HF_HOME": os.environ.get("HF_HOME"),
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE"),
            "HF_HUB_CACHE": os.environ.get("HF_HUB_CACHE"),
        },
    }
# ------------------------------------------------------------
# 8. For local dev (won't run inside Space runtime)
# ------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)