# app.py
# --- ENV FIX: MUST come first, before ANY other import! ---
import os
os.environ.pop("TORCH_LOGS", None)                 # if set to an invalid value -> fixes the torch import
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")  # disable Dynamo/JIT (avoids Gemma3 issues)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline

# Optional: suppress Dynamo errors in case it is active after all
try:
    import torch._dynamo as dynamo
    dynamo.config.suppress_errors = True
except Exception:
    pass

# -----------------------------
# Config
# -----------------------------
MODEL_ID = os.getenv("MODEL_ID", "olemeyer/dentalai-nxt-medgemma4b")  # alternatively "google/medgemma-4b-it"
HF_TOKEN = os.getenv("HF_TOKEN")  # set as a secret in the Space

def _device_index() -> int:
    return 0 if torch.cuda.is_available() else -1

torch.set_grad_enabled(False)  # inference only, no gradients needed
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")  # allow TF32 matmuls on supported GPUs

pipe = pipeline(
    task="image-text-to-text",
    model=MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
    device=_device_index(),  # 0=CUDA:0, -1=CPU
)
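# NOTE: with device=-1 the pipeline runs on CPU and keeps the model's default
# dtype; with device=0 it loads the bfloat16 weights onto cuda:0.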

# -----------------------------
# Utils
# -----------------------------
def _resize_long_side_max(img: Image.Image, max_long_side: int = 800) -> Image.Image:
    w, h = img.size
    long_side = max(w, h)
    if long_side <= max_long_side:
        return img
    scale = max_long_side / float(long_side)
    new_w = int(round(w * scale))
    new_h = int(round(h * scale))
    return img.resize((new_w, new_h), Image.LANCZOS)
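
# A quick sanity check of the helper above (illustrative values only):
# a 1600x1200 input scales by 800/1600 = 0.5 to 800x600, while a 640x480
# input is already within the limit and is returned unchanged.
#
#   >>> _resize_long_side_max(Image.new("RGB", (1600, 1200))).size
#   (800, 600)
#   >>> _resize_long_side_max(Image.new("RGB", (640, 480))).size
#   (640, 480)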

# -----------------------------
# Inference (training format: user turn ONLY, image first)
# -----------------------------
@spaces.GPU
def infer(image: Image.Image):
    if image is None:
        return "Bitte ein Bild hochladen."

    # Scale the image so its long side is at most 800 px
    image = _resize_long_side_max(image, 800)

    messages = [
        {
            "role": "user",
            "content": [
                {"index": 0, "type": "image", "text": None, "image": image},
                {
                    "index": None,
                    "type": "text",
                    "text": "Analysiere dieses Zahnbild und extrahiere alle verfügbaren Informationen als JSON.",
                },
            ],
        }
    ]

    try:
        out = pipe(
            text=messages,
            max_new_tokens=2048,   # as requested
            temperature=0.1,       # as requested; note: only takes effect with do_sample=True
        )
    except TypeError:
        # Fallback: some pipeline versions take the chat as the first positional argument
        out = pipe(
            messages,
            max_new_tokens=2048,
            temperature=0.1,
        )
    except Exception as e:
        return f"Fehler bei der Inferenz: {e}"

    # Normalize the output
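    # Expected structure (an assumption; it varies across transformers
    # versions): out == [{"generated_text": [...input messages...,
    # {"role": "assistant", "content": "<answer>"}]}], where "content" may
    # also be a list of {"type": "text", "text": ...} parts -- hence the
    # branches below.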
    try:
        first = out[0]
        gen = first.get("generated_text", first)
        if isinstance(gen, list) and gen:
            last_msg = gen[-1]
            content = last_msg.get("content", "")
            if isinstance(content, list):
                texts = [c.get("text", "") for c in content if isinstance(c, dict) and c.get("type") == "text"]
                joined = "\n".join(t for t in texts if t)
                return joined or str(last_msg)
            if isinstance(content, str):
                return content
            return str(last_msg)
        if isinstance(gen, str):
            return gen
        return str(gen)
    except Exception:
        return str(out)

# -----------------------------
# UI (simple)
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦷 DentalAI – Trainingsformat\nNur Bild hochladen → **Analysieren**.")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Bild", height=360)
            run_btn = gr.Button("Analysieren", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Ausgabe", lines=22, show_copy_button=True)

    run_btn.click(fn=infer, inputs=[image_in], outputs=[output], api_name="predict")

demo.queue().launch()
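
# Example client call against the /predict endpoint exposed above -- a minimal
# sketch, assuming the Space is reachable and gradio_client is installed; the
# Space id "user/dentalai-space" and the file name are placeholders:
#
#   from gradio_client import Client, handle_file
#   client = Client("user/dentalai-space")
#   result = client.predict(handle_file("tooth.jpg"), api_name="/predict")
#   print(result)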