# app.py
# FastAPI backend for a Hugging Face Space (CPU tier)
# • Only MedGemma-4B-IT, no Parakeet, no tool-calling
# • Reads HF_TOKEN from Space secrets, uses /tmp for writable cache
# • /chat endpoint expects {"messages":[{"role":"user","content": "..."}]}
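#
# Example request (illustrative; "<your-space>" is a placeholder for the real
# Space URL, and the app listens on port 7860 when run locally):
#   curl -X POST https://<your-space>.hf.space/chat \
#        -H "Content-Type: application/json" \
#        -d '{"messages":[{"role":"user","content":"PA chest X-ray, persistent cough"}]}'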
import os, uuid
from typing import List, Optional
# Set all cache-related environment variables BEFORE importing transformers
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)
# Set ALL possible cache environment variables
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline
# ------------------------------------------------------------
# 1. Hugging Face access token (from Space secrets)
# ------------------------------------------------------------
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
# ------------------------------------------------------------
# 2. Simple Pydantic request model
# ------------------------------------------------------------
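# These models mirror the OpenAI chat-completions request body; the /chat
# handler below only reads the content of the last message in the list.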
class Message(BaseModel):
    role: str
    content: Optional[str] = None  # may be null or omitted in the request

class ChatCompletionRequest(BaseModel):
    messages: List[Message]
# ------------------------------------------------------------
# 3. Lazy MedGemma loader with memory optimization
# ------------------------------------------------------------
DEVICE = "cpu"
DTYPE = torch.float32
medgemma_pipe = None
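# The pipeline is built lazily on the first /chat call so the Space can boot
# (and answer health checks) before the multi-GB model download finishes.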
def get_medgemma():
    global medgemma_pipe
    if medgemma_pipe is None:
        try:
            print("🚀 Loading MedGemma-4B-IT with memory optimization...")
            print(f"Using cache directory: {CACHE_DIR}")
            medgemma_pipe = pipeline(
                "text-generation",
                model="google/medgemma-4b-it",
                torch_dtype=DTYPE,    # float32: fp16 kernels are not available on CPU
                device=DEVICE,        # CPU-only Space, so no GPU device_map / memory caps
                token=HF_TOKEN,
                trust_remote_code=True,
                # from_pretrained options must be passed through model_kwargs
                model_kwargs={
                    "cache_dir": CACHE_DIR,
                    "low_cpu_mem_usage": True,  # stream weights to keep peak RAM down
                },
            )
            print("✅ MedGemma loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading MedGemma: {e}")
            print(f"Cache directory exists: {os.path.exists(CACHE_DIR)}")
            print(f"Cache directory writable: {os.access(CACHE_DIR, os.W_OK)}")
            print(f"HF_TOKEN present: {bool(HF_TOKEN)}")
            medgemma_pipe = None
    return medgemma_pipe
# ------------------------------------------------------------
# 4. FastAPI app with permissive CORS (for Replit frontend)
# ------------------------------------------------------------
app = FastAPI(title="MedGemma Radiology Chat")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # adjust in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
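# Wildcard origins plus allow_credentials=True is only suitable for development:
# the CORS spec forbids "Access-Control-Allow-Origin: *" on credentialed
# requests, so pin allow_origins to the Replit frontend URL in production.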
# ------------------------------------------------------------
# 5. System prompt
# ------------------------------------------------------------
SYSTEM_PROMPT = (
    "You are MedGemma, a medical vision-language assistant specialised in radiology. "
    "When given a patient case or study description, respond with a concise, professional "
    "radiology report. Use headings such as FINDINGS and IMPRESSION."
)
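# Possible refinement (not wired in below): instruction-tuned Gemma-family models
# are trained against a chat template, so passing a messages list directly to the
# pipeline (which then applies tokenizer.apply_chat_template) may produce better
# reports than manually concatenating SYSTEM_PROMPT and the user text, e.g.:
#   pipe([{"role": "user", "content": f"{SYSTEM_PROMPT}\n\n{user_msg}"}],
#        max_new_tokens=256, return_full_text=False)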
# ------------------------------------------------------------
# 6. /chat endpoint
# ------------------------------------------------------------
@app.post("/chat")
async def chat(request: Request):
try:
body = await request.json()
payload = ChatCompletionRequest(**body)
user_msg = payload.messages[-1].content or ""
prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"
pipe = get_medgemma()
if pipe is None:
return JSONResponse(
{
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"choices": [{
"message": {
"role": "assistant",
"content": "MedGemma model is unavailable. "
"Check your gated-model access and HF_TOKEN.",
}
}],
},
status_code=503,
)
try:
result = pipe(
prompt,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
pad_token_id=pipe.tokenizer.eos_token_id,
return_full_text=False,
)
assistant_text = result[0]["generated_text"].strip() if result else "No response."
except Exception as e:
print("Generation error:", e)
assistant_text = "Error generating response. Please retry later."
return JSONResponse(
{
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"choices": [{
"message": {
"role": "assistant",
"content": assistant_text,
}
}]
}
)
except Exception as e:
print(f"Chat endpoint error: {e}")
return JSONResponse(
{
"id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
"choices": [{
"message": {
"role": "assistant",
"content": "Server error. Please try again later.",
}
}]
},
status_code=500
)
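# For reference, a successful call returns an OpenAI-style body shaped like:
#   {"id": "chatcmpl-1a2b3c4d",
#    "choices": [{"message": {"role": "assistant", "content": "FINDINGS: ..."}}]}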
# ------------------------------------------------------------
# 7. Health endpoint
# ------------------------------------------------------------
@app.get("/")
async def root():
return {"status": "healthy", "message": "MedGemma API is running"}
@app.get("/health")
async def health():
return {
"status": "ok",
"model_loaded": medgemma_pipe is not None,
"hf_token_present": bool(HF_TOKEN),
"cache_dir": CACHE_DIR,
"cache_exists": os.path.exists(CACHE_DIR),
"cache_writable": os.access(CACHE_DIR, os.W_OK),
"env_vars": {
"HF_HOME": os.environ.get("HF_HOME"),
"TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE"),
"HF_HUB_CACHE": os.environ.get("HF_HUB_CACHE"),
}
}
# ------------------------------------------------------------
# 8. For local dev (won't run inside Space runtime)
# ------------------------------------------------------------
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860)