# app.py
# --- ENV FIX: MUST be at the very top, before ANY other import! ---
import os

os.environ.pop("TORCH_LOGS", None)  # remove if set to an invalid value -> fixes the torch import
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")  # disable Dynamo/JIT (avoids Gemma3 issues)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline

# Optional: suppress Dynamo errors in case it is still active
try:
    import torch._dynamo as dynamo
    dynamo.config.suppress_errors = True
except Exception:
    pass

# -----------------------------
# Config
# -----------------------------
MODEL_ID = os.getenv("MODEL_ID", "olemeyer/dentalai-nxt-medgemma4b")  # alternatively "google/medgemma-4b-it"
HF_TOKEN = os.getenv("HF_TOKEN")  # set as a secret in the Space


def _device_index() -> int:
    return 0 if torch.cuda.is_available() else -1


torch.set_grad_enabled(False)
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

pipe = pipeline(
    task="image-text-to-text",
    model=MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
    device=_device_index(),  # 0 = CUDA:0, -1 = CPU
)


# -----------------------------
# Utils
# -----------------------------
def _resize_long_side_max(img: Image.Image, max_long_side: int = 800) -> Image.Image:
    """Downscale img so that its longer side is at most max_long_side pixels."""
    w, h = img.size
    long_side = max(w, h)
    if long_side <= max_long_side:
        return img
    scale = max_long_side / float(long_side)
    new_w = int(round(w * scale))
    new_h = int(round(h * scale))
    return img.resize((new_w, new_h), Image.LANCZOS)


# -----------------------------
# Inference (training format: user turn ONLY, image first)
# -----------------------------
@spaces.GPU
def infer(image: Image.Image):
    if image is None:
        return "Please upload an image."

    # Scale the image so its longer side is at most 800 px
    image = _resize_long_side_max(image, 800)

    messages = [
        {
            "role": "user",
            "content": [
                {"index": 0, "type": "image", "text": None, "image": image},
                {
                    "index": None,
                    "type": "text",
                    # Prompt kept in German to match the training format
                    "text": "Analysiere dieses Zahnbild und extrahiere alle verfügbaren Informationen als JSON.",
                },
            ],
        }
    ]

    try:
        out = pipe(
            text=messages,
            max_new_tokens=2048,  # as requested
            temperature=0.1,      # as requested
        )
    except TypeError:
        # Fallback for alternative parameter names
        out = pipe(
            messages=messages,
            max_new_tokens=2048,
            temperature=0.1,
        )
    except Exception as e:
        return f"Inference error: {e}"

    # Normalize the output
    try:
        first = out[0]
        gen = first.get("generated_text", first)
        if isinstance(gen, list) and gen:
            last_msg = gen[-1]
            content = last_msg.get("content", "")
            if isinstance(content, list):
                texts = [
                    c.get("text", "")
                    for c in content
                    if isinstance(c, dict) and c.get("type") == "text"
                ]
                joined = "\n".join(t for t in texts if t)
                return joined or str(last_msg)
            if isinstance(content, str):
                return content
            return str(last_msg)
        if isinstance(gen, str):
            return gen
        return str(gen)
    except Exception:
        return str(out)


# -----------------------------
# UI (simple)
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦷 DentalAI – Training Format\nJust upload an image → **Analyze**.")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", height=360)
            run_btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=22, show_copy_button=True)

    run_btn.click(fn=infer, inputs=[image_in], outputs=[output], api_name="predict")

demo.queue().launch()
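
# -----------------------------
# Example client usage (sketch)
# -----------------------------
# A minimal sketch of how a client could call the "predict" endpoint exposed
# above via gradio_client, assuming the app is reachable at
# http://127.0.0.1:7860 (Gradio's default local launch address; substitute
# your Space URL). "tooth.png" is a placeholder path.
#
#   from gradio_client import Client, handle_file
#
#   client = Client("http://127.0.0.1:7860")
#   result = client.predict(handle_file("tooth.png"), api_name="/predict")
#   print(result)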