# app.py
# FastAPI backend for a Hugging Face Space (CPU tier)
# • Only MedGemma-4B-IT, no Parakeet, no tool-calling
# • Reads HF_TOKEN from Space secrets, uses /tmp for writable cache
# • /chat endpoint expects {"messages":[{"role":"user","content": "..."}]}
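# Illustrative request/response shapes (derived from the /chat handler below):
#   POST /chat
#   {"messages": [{"role": "user", "content": "Chest X-ray, 58-year-old male with cough"}]}
#   -> {"id": "chatcmpl-xxxxxxxx",
#       "choices": [{"message": {"role": "assistant", "content": "FINDINGS: ..."}}]}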
import os, uuid
from typing import List, Optional

# ------------------------------------------------------------
# 1. Writable cache configuration (set BEFORE importing transformers)
# ------------------------------------------------------------
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# Point every cache-related environment variable at the writable directory
os.environ["HF_HOME"] = CACHE_DIR
os.environ["TRANSFORMERS_CACHE"] = CACHE_DIR
os.environ["HF_HUB_CACHE"] = CACHE_DIR
os.environ["HUGGINGFACE_HUB_CACHE"] = CACHE_DIR
os.environ["HF_DATASETS_CACHE"] = CACHE_DIR
import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline
# Get token
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")
# ------------------------------------------------------------
# 2. Simple Pydantic request model
# ------------------------------------------------------------
class Message(BaseModel):
    role: str
    content: Optional[str] = None


class ChatCompletionRequest(BaseModel):
    messages: List[Message]
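# Illustrative: ChatCompletionRequest(**{"messages": [{"role": "user", "content": "hi"}]})
# validates the body; the /chat handler reads payload.messages[-1].content.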
# ------------------------------------------------------------
# 3. Lazy MedGemma loader with memory optimization
# ------------------------------------------------------------
DEVICE = "cpu"
DTYPE = torch.float32
medgemma_pipe = None
def get_medgemma():
    global medgemma_pipe
    if medgemma_pipe is None:
        try:
            print("🚀 Loading MedGemma-4B-IT with memory optimization...")
            print(f"Using cache directory: {CACHE_DIR}")
            medgemma_pipe = pipeline(
                "text-generation",
                model="google/medgemma-4b-it",
                torch_dtype=DTYPE,  # float32: float16 is poorly supported on CPU
                device=DEVICE,      # CPU tier, so pin to CPU instead of device_map
                token=HF_TOKEN,
                trust_remote_code=True,
                model_kwargs={
                    "cache_dir": CACHE_DIR,
                    "low_cpu_mem_usage": True,  # stream weights to reduce peak RAM
                },
            )
            print("✅ MedGemma loaded successfully!")
        except Exception as e:
            print(f"❌ Error loading MedGemma: {e}")
            print(f"Cache directory exists: {os.path.exists(CACHE_DIR)}")
            print(f"Cache directory writable: {os.access(CACHE_DIR, os.W_OK)}")
            print(f"HF_TOKEN present: {bool(HF_TOKEN)}")
            medgemma_pipe = None
    return medgemma_pipe
# ------------------------------------------------------------
# 4. FastAPI app with permissive CORS (for Replit frontend)
# ------------------------------------------------------------
app = FastAPI(title="MedGemma Radiology Chat")
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # adjust in prod
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
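# Optional warm-up sketch (commented out): loading at startup trades a slower boot
# for a faster first /chat call; enable only if the Space's startup timeout allows it.
# @app.on_event("startup")
# async def _warm_up():
#     get_medgemma()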
# ------------------------------------------------------------
# 5. System prompt
# ------------------------------------------------------------
SYSTEM_PROMPT = (
"You are MedGemma, a medical vision-language assistant specialised in radiology. "
"When given a patient case or study description, respond with a concise, professional "
"radiology report. Use headings such as FINDINGS and IMPRESSION."
)
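# The /chat handler below assembles the final prompt as:
#   f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"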
# ------------------------------------------------------------
# 6. /chat endpoint
# ------------------------------------------------------------
@app.post("/chat")
async def chat(request: Request):
    try:
        body = await request.json()
        payload = ChatCompletionRequest(**body)
        user_msg = payload.messages[-1].content or ""
        prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"
        pipe = get_medgemma()
        if pipe is None:
            return JSONResponse(
                {
                    "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                    "choices": [{
                        "message": {
                            "role": "assistant",
                            "content": "MedGemma model is unavailable. "
                                       "Check your gated-model access and HF_TOKEN.",
                        }
                    }],
                },
                status_code=503,
            )
        try:
            result = pipe(
                prompt,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                pad_token_id=pipe.tokenizer.eos_token_id,
                return_full_text=False,
            )
            assistant_text = result[0]["generated_text"].strip() if result else "No response."
        except Exception as e:
            print("Generation error:", e)
            assistant_text = "Error generating response. Please retry later."
        return JSONResponse(
            {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": assistant_text,
                    }
                }]
            }
        )
    except Exception as e:
        print(f"Chat endpoint error: {e}")
        return JSONResponse(
            {
                "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
                "choices": [{
                    "message": {
                        "role": "assistant",
                        "content": "Server error. Please try again later.",
                    }
                }]
            },
            status_code=500,
        )
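# Example invocation from a shell (illustrative; replace <space-host> with your Space URL):
#   curl -X POST https://<space-host>/chat \
#     -H "Content-Type: application/json" \
#     -d '{"messages":[{"role":"user","content":"CT head without contrast, acute trauma"}]}'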
# ------------------------------------------------------------
# 7. Health endpoint
# ------------------------------------------------------------
@app.get("/")
async def root():
return {"status": "healthy", "message": "MedGemma API is running"}
@app.get("/health")
async def health():
    return {
        "status": "ok",
        "model_loaded": medgemma_pipe is not None,
        "hf_token_present": bool(HF_TOKEN),
        "cache_dir": CACHE_DIR,
        "cache_exists": os.path.exists(CACHE_DIR),
        "cache_writable": os.access(CACHE_DIR, os.W_OK),
        "env_vars": {
            "HF_HOME": os.environ.get("HF_HOME"),
            "TRANSFORMERS_CACHE": os.environ.get("TRANSFORMERS_CACHE"),
            "HF_HUB_CACHE": os.environ.get("HF_HUB_CACHE"),
        },
    }
# ------------------------------------------------------------
# 8. For local dev (won't run inside Space runtime)
# ------------------------------------------------------------
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)
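# Local dev sketch: `python app.py` or `uvicorn app:app --host 0.0.0.0 --port 7860`;
# the Space runtime launches the app itself, so this block is only for local testing.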