# dentalAI-nxt / app.py
# Author: olemeyer — commit f3086c3 ("Update app.py", verified)
# app.py
# --- ENV FIX: MUST stay at the very top, before ANY other import! ---
# The environment variables below have to be in place before torch is
# imported (an invalid TORCH_LOGS value breaks `import torch` itself).
# NOTE(review): gradio/spaces/transformers are imported after this block too,
# presumably because they may import torch internally — confirm before reordering.
import os
os.environ.pop("TORCH_LOGS", None)  # remove TORCH_LOGS in case it holds an invalid value -> keeps the torch import working
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")  # turn Dynamo/JIT off (avoids Gemma3 issues); setdefault keeps an explicitly set value
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")  # expose only GPU 0 unless the environment already chose otherwise
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline
# Optional: suppress Dynamo errors in case Dynamo is active after all
# (e.g. TORCHDYNAMO_DISABLE was already set to a different value).
# Deliberately best-effort: any failure importing/configuring torch._dynamo is ignored.
try:
    import torch._dynamo as dynamo
    dynamo.config.suppress_errors = True
except Exception:
    pass
# -----------------------------
# Config
# -----------------------------
# Model repository to load; override with the MODEL_ID env var
# (an alternative would be e.g. "google/medgemma-4b-it").
MODEL_ID = os.environ.get("MODEL_ID", "olemeyer/dentalai-nxt-medgemma4b")
# Hugging Face access token, configured as a Space secret; None when unset.
HF_TOKEN = os.environ.get("HF_TOKEN")


def _device_index() -> int:
    """Return the pipeline device index: 0 (cuda:0) with CUDA, otherwise -1 (CPU)."""
    if torch.cuda.is_available():
        return 0
    return -1
# Inference-only app: switch autograd off globally.
torch.set_grad_enabled(False)
# Allow faster float32 matmul kernels where this torch build offers the setting.
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")


def _pipeline_dtype():
    """Return torch.bfloat16 when CUDA is available, else None (no dtype is forced)."""
    return torch.bfloat16 if torch.cuda.is_available() else None


# Module-level pipeline shared by every request.
pipe = pipeline(
    task="image-text-to-text",
    model=MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=_pipeline_dtype(),
    device=_device_index(),  # 0 = cuda:0, -1 = CPU
)
# -----------------------------
# Utils
# -----------------------------
def _resize_long_side_max(img: Image.Image, max_long_side: int = 800) -> Image.Image:
    """Downscale *img* so that its longer side is at most *max_long_side* pixels.

    Both sides are scaled by the same factor (aspect ratio preserved), rounded
    to the nearest pixel, and resampled with LANCZOS. Images that already fit
    are returned unchanged (the same object, not a copy); images are never
    upscaled.

    Args:
        img: Source PIL image.
        max_long_side: Maximum allowed length of the longer side, in pixels.

    Returns:
        The original image, or a resized copy.

    Raises:
        ValueError: If ``max_long_side`` is smaller than 1.
    """
    if max_long_side < 1:
        raise ValueError(f"max_long_side must be >= 1, got {max_long_side}")
    w, h = img.size
    long_side = max(w, h)
    if long_side <= max_long_side:
        return img
    scale = max_long_side / long_side
    # Clamp each side to >= 1 px: for extremely elongated images (e.g. 8000x3)
    # the short side would otherwise round to 0, and PIL rejects zero sizes.
    new_size = (max(1, round(w * scale)), max(1, round(h * scale)))
    return img.resize(new_size, Image.LANCZOS)
# -----------------------------
# Inference (training format: ONLY a user turn, image first)
# -----------------------------
def _build_messages(image: Image.Image) -> list:
    """Return the chat input for one image in the fine-tuning format.

    A single "user" turn (no system/assistant turns) whose content lists the
    image entry first and the German analysis prompt second. The extra
    "index"/"text" keys are kept verbatim — presumably they mirror the
    training records (per the section header); verify before changing them.
    """
    return [
        {
            "role": "user",
            "content": [
                {"index": 0, "type": "image", "text": None, "image": image},
                {
                    "index": None,
                    "type": "text",
                    "text": "Analysiere dieses Zahnbild und extrahiere alle verfügbaren Informationen als JSON.",
                },
            ],
        }
    ]


def _run_pipeline(messages: list, max_new_tokens: int = 2048, temperature: float = 0.1):
    """Run the module-level pipeline on *messages* and return its raw output.

    Calls ``pipe(text=...)`` first and retries once with ``pipe(messages=...)``
    if that raises ``TypeError`` (fallback for alternative parameter names).
    Exceptions from either call propagate to the caller.
    """

    def generate_kwargs() -> dict:
        # Generation options go through `generate_kwargs`: the image-text-to-text
        # pipeline treats unrecognised keyword arguments (e.g. a bare
        # `temperature=`) as processor kwargs, so they never reached `generate()`.
        # A fresh dict per call, because the pipeline may add keys to it.
        # NOTE(review): temperature only matters if sampling is enabled in the
        # model's generation config (do_sample=True) — confirm for this model.
        return {"max_new_tokens": max_new_tokens, "temperature": temperature}

    try:
        return pipe(text=messages, generate_kwargs=generate_kwargs())
    except TypeError:
        # Fallback for alternative parameter names.
        return pipe(messages=messages, generate_kwargs=generate_kwargs())


def _extract_text(out) -> str:
    """Reduce raw pipeline output to the text shown in the UI.

    Reads ``out[0]["generated_text"]`` (or ``out[0]`` itself when that key is
    missing). For a non-empty chat transcript (list of messages), the last
    message's content is used: a string is returned as-is, and a list of parts
    is reduced to its non-empty "text" parts joined by newlines (falling back
    to ``str(last_msg)``). A plain string result is returned unchanged; any
    other shape is returned via ``str()``. If indexing or lookup fails,
    ``str(out)`` is returned instead of raising.
    """
    try:
        first = out[0]
        gen = first.get("generated_text", first)
        if isinstance(gen, list) and gen:
            last_msg = gen[-1]
            content = last_msg.get("content", "")
            if isinstance(content, list):
                texts = [
                    c.get("text", "")
                    for c in content
                    if isinstance(c, dict) and c.get("type") == "text"
                ]
                joined = "\n".join(t for t in texts if t)
                return joined or str(last_msg)
            if isinstance(content, str):
                return content
            return str(last_msg)
        if isinstance(gen, str):
            return gen
        return str(gen)
    except Exception:
        # Deliberate best-effort: an unexpected output shape should not fail
        # the request — show the raw output instead.
        return str(out)


@spaces.GPU
def infer(image: Image.Image):
    """Analyse an uploaded dental image and return the model output as text.

    Decorated with ``spaces.GPU`` so the call runs with a Space GPU.

    Args:
        image: PIL image from the Gradio input, or None if nothing was uploaded.

    Returns:
        The generated text, or a German user-facing message when no image was
        provided or the inference call failed.
    """
    if image is None:
        return "Bitte ein Bild hochladen."
    # Scale the image so its long side is at most 800 px.
    image = _resize_long_side_max(image, 800)
    messages = _build_messages(image)
    try:
        out = _run_pipeline(messages)
    except Exception as e:
        # Also covers failures of the `messages=` fallback call.
        return f"Fehler bei der Inferenz: {e}"
    return _extract_text(out)
# -----------------------------
# UI (simple)
# -----------------------------
# Two-column Blocks layout: image upload + button on the left, text output on
# the right. The order in which components are created inside the `with`
# contexts defines the layout.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦷 DentalAI – Trainingsformat\nNur Bild hochladen → **Analysieren**.")
    with gr.Row():
        with gr.Column(scale=1):
            # type="pil" hands `infer` a PIL image (None when nothing was uploaded).
            image_in = gr.Image(type="pil", label="Bild", height=360)
            run_btn = gr.Button("Analysieren", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Ausgabe", lines=22, show_copy_button=True)
    # The click handler runs `infer`; it is also exposed as the named API endpoint "predict".
    # It must be registered inside the Blocks context.
    run_btn.click(fn=infer, inputs=[image_in], outputs=[output], api_name="predict")
# Enable request queuing, then start the server.
demo.queue().launch()