File size: 6,892 Bytes
e72aab1
 
 
 
 
 
 
8c721cb
6a814f1
e72aab1
6a814f1
 
 
e72aab1
6a814f1
 
 
 
 
 
08484be
937d1d0
8c721cb
 
 
08484be
 
e72aab1
08484be
0054c76
e72aab1
 
 
8c721cb
0054c76
cf6cb86
8c721cb
 
 
 
e72aab1
3bec787
e72aab1
 
 
08484be
 
 
 
 
 
3bec787
6a814f1
08484be
 
 
 
3bec787
 
08484be
e72aab1
08484be
3bec787
 
 
08484be
3bec787
 
08484be
 
e72aab1
 
6a814f1
08484be
 
 
e72aab1
 
 
 
8c721cb
 
 
e72aab1
8c721cb
 
 
 
 
e72aab1
 
 
0054c76
08484be
 
 
0054c76
8c721cb
e72aab1
 
 
8c721cb
937d1d0
c4fd269
e72aab1
 
c4fd269
e72aab1
 
08484be
 
e72aab1
 
 
 
 
 
 
 
 
 
 
 
 
682c96e
e72aab1
08484be
 
 
 
 
 
 
e72aab1
 
08484be
e72aab1
 
 
 
 
 
 
 
 
 
 
 
 
 
0e2d025
08484be
e72aab1
 
 
 
 
 
 
 
 
 
 
 
3817b9e
e72aab1
 
 
c4fd269
 
e72aab1
c4fd269
3817b9e
 
edbbbaa
0054c76
08484be
0054c76
6a814f1
 
 
 
 
 
 
 
edbbbaa
3817b9e
e72aab1
 
 
3817b9e
 
e72aab1
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
# app.py
# FastAPI backend for a Hugging Face Space (CPU tier)
# • Only MedGemma-4B-IT, no Parakeet, no tool-calling
# • Reads HF_TOKEN from Space secrets, uses /tmp for writable cache
# • /chat endpoint expects {"messages":[{"role":"user","content": "..."}]}

import os, pathlib, uuid
from typing import List, Optional

# Redirect every Hugging Face cache location to one writable directory.
# This has to happen BEFORE transformers is imported further down.
CACHE_DIR = "/tmp/hf_cache"
os.makedirs(CACHE_DIR, exist_ok=True)

# All cache-related variables share the same target directory.
os.environ.update(
    dict.fromkeys(
        (
            "HF_HOME",
            "TRANSFORMERS_CACHE",
            "HF_HUB_CACHE",
            "HUGGINGFACE_HUB_CACHE",
            "HF_DATASETS_CACHE",
        ),
        CACHE_DIR,
    )
)

import torch
from fastapi import FastAPI, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from transformers import pipeline

# Get token
HF_TOKEN = os.getenv("HF_TOKEN") or os.getenv("HUGGINGFACE_TOKEN")

# ------------------------------------------------------------
# 2.  Simple Pydantic request model
# ------------------------------------------------------------
class Message(BaseModel):
    """One chat turn inside a /chat request body."""

    # Speaker of the turn, e.g. "user"; the value itself is not validated.
    role: str
    # Text of the turn. The explicit None default keeps the field optional on
    # every Pydantic major version (without it, Pydantic v2 treats the field
    # as required); the /chat handler already maps a missing value to "".
    content: Optional[str] = None

class ChatCompletionRequest(BaseModel):
    """Request body accepted by the /chat endpoint."""

    # Conversation turns; the /chat handler only reads the last entry.
    messages: List[Message]

# ------------------------------------------------------------
# 3.  Lazy MedGemma loader with memory optimization
# ------------------------------------------------------------
# NOTE(review): DEVICE and DTYPE are not read by get_medgemma() below, which
# passes device_map="auto" and torch.float16 directly — confirm whether these
# constants are still needed or should drive the loader.
DEVICE = "cpu"
DTYPE  = torch.float32
# Process-wide pipeline cache; stays None until get_medgemma() loads the model.
medgemma_pipe = None

def get_medgemma():
    """Return the shared MedGemma pipeline, loading it on first use.

    The pipeline is cached in the module-level ``medgemma_pipe``. When loading
    fails, diagnostics are printed, the cache stays ``None`` and the next call
    attempts the load again.

    Returns:
        The loaded text-generation pipeline, or ``None`` if loading failed.
    """
    global medgemma_pipe
    if medgemma_pipe is not None:
        return medgemma_pipe

    try:
        print("🚀 Loading MedGemma-4B-IT with memory optimization...")
        print(f"Using cache directory: {CACHE_DIR}")

        # Only budget accelerator 0 when a CUDA device actually exists;
        # otherwise the device map would reference hardware that is missing.
        max_memory = {"cpu": "8GB"}
        if torch.cuda.is_available():
            max_memory[0] = "6GB"

        medgemma_pipe = pipeline(
            "text-generation",
            model="google/medgemma-4b-it",
            torch_dtype=torch.float16,  # Use float16 to reduce memory
            device_map="auto",
            token=HF_TOKEN,
            trust_remote_code=True,
            # Loader-only options belong in model_kwargs so they reach
            # from_pretrained(). Passed as bare pipeline() keyword arguments
            # they are forwarded to the pipeline itself and end up among the
            # generation arguments instead of configuring the load.
            model_kwargs={
                "cache_dir": CACHE_DIR,
                "low_cpu_mem_usage": True,
                "max_memory": max_memory,  # Limit memory usage
                # Weights that exceed the budgets above are spilled to disk,
                # which needs a writable folder.
                "offload_folder": os.path.join(CACHE_DIR, "offload"),
            },
        )
        print("✅ MedGemma loaded successfully!")

    except Exception as e:
        # Best-effort loader: report why the load failed and leave the cache
        # empty so the caller can answer with a 503 and retry later.
        print(f"❌ Error loading MedGemma: {e}")
        print(f"Cache directory exists: {os.path.exists(CACHE_DIR)}")
        print(f"Cache directory writable: {os.access(CACHE_DIR, os.W_OK)}")
        print(f"HF_TOKEN present: {bool(HF_TOKEN)}")
        medgemma_pipe = None
    return medgemma_pipe

# ------------------------------------------------------------
# 4.  FastAPI app with permissive CORS (for Replit frontend)
# ------------------------------------------------------------
# Permissive cross-origin policy: any origin, method and header is accepted,
# and credentials are allowed.
_CORS_POLICY = {
    "allow_origins": ["*"],  # adjust in prod
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

app = FastAPI(title="MedGemma Radiology Chat")
app.add_middleware(CORSMiddleware, **_CORS_POLICY)

# ------------------------------------------------------------
# 5.  System prompt
# ------------------------------------------------------------
# Instruction text placed ahead of the user's message when /chat builds its prompt.
SYSTEM_PROMPT = " ".join(
    [
        "You are MedGemma, a medical vision-language assistant specialised in radiology.",
        "When given a patient case or study description, respond with a concise, professional",
        "radiology report. Use headings such as FINDINGS and IMPRESSION.",
    ]
)

# ------------------------------------------------------------
# 6.  /chat endpoint
# ------------------------------------------------------------
def _assistant_response(content, status_code=200):
    """Wrap *content* in the chat-completion style envelope returned by /chat."""
    return JSONResponse(
        {
            "id": f"chatcmpl-{uuid.uuid4().hex[:8]}",
            "choices": [{
                "message": {
                    "role": "assistant",
                    "content": content,
                }
            }],
        },
        status_code=status_code,
    )


@app.post("/chat")
async def chat(request: Request):
    """Generate a radiology report for the last message of the request.

    Expects a JSON body of the form
    ``{"messages": [{"role": "user", "content": "..."}]}``.

    Returns a JSON envelope with an ``id`` and a single assistant message:
        * 200 - generated text, or a retry notice if generation itself failed
        * 422 - body is not valid JSON, does not match the schema, or
          contains no messages
        * 503 - the model could not be loaded
        * 500 - any other unexpected failure
    """
    try:
        # Client mistakes are reported as 4xx instead of falling through to
        # the catch-all 500 below. JSON decoding errors and Pydantic
        # validation errors are ValueError subclasses; a non-object JSON body
        # makes the ** expansion raise TypeError.
        try:
            body = await request.json()
            payload = ChatCompletionRequest(**body)
        except (ValueError, TypeError) as e:
            print(f"Chat request rejected: {e}")
            return _assistant_response(
                "Invalid request. Expected a JSON body with a 'messages' list.",
                status_code=422,
            )

        # messages[-1] on an empty list would raise IndexError.
        if not payload.messages:
            return _assistant_response(
                "Invalid request. 'messages' must contain at least one message.",
                status_code=422,
            )

        user_msg = payload.messages[-1].content or ""
        prompt = f"{SYSTEM_PROMPT}\n\n{user_msg}\n\nRadiology Report:\n"

        pipe = get_medgemma()
        if pipe is None:
            return _assistant_response(
                "MedGemma model is unavailable. "
                "Check your gated-model access and HF_TOKEN.",
                status_code=503,
            )

        try:
            # NOTE(review): this call is synchronous inside an async handler,
            # so other requests wait while it runs — consider moving it to a
            # worker thread once concurrent model access has been reviewed.
            result = pipe(
                prompt,
                max_new_tokens=256,
                do_sample=True,
                temperature=0.7,
                pad_token_id=pipe.tokenizer.eos_token_id,
                return_full_text=False,
            )
            assistant_text = result[0]["generated_text"].strip() if result else "No response."
        except Exception as e:
            # Deliberate best-effort: a failed generation still answers 200
            # with a retry notice rather than an error status.
            print("Generation error:", e)
            assistant_text = "Error generating response. Please retry later."

        return _assistant_response(assistant_text)
    except Exception as e:
        print(f"Chat endpoint error: {e}")
        return _assistant_response(
            "Server error. Please try again later.",
            status_code=500,
        )

# ------------------------------------------------------------
# 7.  Root and health endpoints
# ------------------------------------------------------------
@app.get("/")
async def root():
    """Return a fixed status payload showing that the API is up."""
    return dict(status="healthy", message="MedGemma API is running")

@app.get("/health")
async def health():
    """Report model, token and cache diagnostics.

    Only whether HF_TOKEN is present is reported, never its value. The model
    is not loaded by this call; ``model_loaded`` reflects the current cache.
    """
    reported_vars = ("HF_HOME", "TRANSFORMERS_CACHE", "HF_HUB_CACHE")
    cache_exists = os.path.exists(CACHE_DIR)
    cache_writable = os.access(CACHE_DIR, os.W_OK)
    return {
        "status": "ok",
        "model_loaded": medgemma_pipe is not None,
        "hf_token_present": bool(HF_TOKEN),
        "cache_dir": CACHE_DIR,
        "cache_exists": cache_exists,
        "cache_writable": cache_writable,
        "env_vars": {name: os.environ.get(name) for name in reported_vars},
    }

# ------------------------------------------------------------
# 8.  For local dev (won't run inside Space runtime)
# ------------------------------------------------------------
if __name__ == "__main__":
    # Imported lazily: uvicorn is only needed when the file is run as a script.
    import uvicorn

    uvicorn.run(app=app, port=7860, host="0.0.0.0")