# app.py
# --- ENV FIX: must run at the very top, before ANY other import! ---
import os

os.environ.pop("TORCH_LOGS", None)  # drop an invalid value -> fixes the torch import
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")  # disable Dynamo/JIT (avoids Gemma3 issues)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")

import gradio as gr
import spaces  # provides the @spaces.GPU decorator used on ZeroGPU
import torch
from PIL import Image
from transformers import pipeline

# Optional: suppress Dynamo errors in case it is active after all
try:
    import torch._dynamo as dynamo

    dynamo.config.suppress_errors = True
except Exception:
    pass

# -----------------------------
# Config
# -----------------------------
MODEL_ID = os.getenv("MODEL_ID", "olemeyer/dentalai-nxt-medgemma4b")  # alternatively "google/medgemma-4b-it"
HF_TOKEN = os.getenv("HF_TOKEN")  # set as a Secret in the Space


def _device_index() -> int:
    return 0 if torch.cuda.is_available() else -1


torch.set_grad_enabled(False)
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")

pipe = pipeline(
    task="image-text-to-text",
    model=MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
    device=_device_index(),  # 0 = cuda:0, -1 = CPU
)
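# Note on the dtype choice: bfloat16 halves weight memory relative to float32
# on CUDA; passing None lets transformers fall back to its float32 default on
# CPU. Inference only, so gradients stay disabled via set_grad_enabled(False).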

# -----------------------------
# Utils
# -----------------------------
def _resize_long_side_max(img: Image.Image, max_long_side: int = 800) -> Image.Image:
    """Downscale so the longer side is at most `max_long_side`, keeping aspect ratio."""
    w, h = img.size
    long_side = max(w, h)
    if long_side <= max_long_side:
        return img
    scale = max_long_side / float(long_side)
    new_w = int(round(w * scale))
    new_h = int(round(h * scale))
    return img.resize((new_w, new_h), Image.LANCZOS)
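# Worked example: a 1600x1200 input is scaled by 800/1600 = 0.5 to 800x600,
# while a 640x480 input passes through unchanged (long side already <= 800).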

# -----------------------------
# Inference (training format: user turn ONLY, image first)
# -----------------------------
@spaces.GPU  # on ZeroGPU a GPU is attached only for the duration of this call
def infer(image: Image.Image):
    if image is None:
        return "Please upload an image."
    # scale the image so its long side is at most 800 px
    image = _resize_long_side_max(image, 800)
    messages = [
        {
            "role": "user",
            "content": [
                {"index": 0, "type": "image", "text": None, "image": image},
                {
                    "index": None,
                    "type": "text",
                    # German prompt kept verbatim (it is part of the training
                    # format); it asks: "Analyze this dental image and extract
                    # all available information as JSON."
                    "text": "Analysiere dieses Zahnbild und extrahiere alle verfügbaren Informationen als JSON.",
                },
            ],
        }
    ]
    try:
        out = pipe(
            text=messages,
            max_new_tokens=2048,  # as requested
            temperature=0.1,  # as requested
            do_sample=True,  # temperature only takes effect when sampling is enabled
        )
    except TypeError:
        # fallback: some versions dispatch a chat from the first positional argument
        out = pipe(
            messages,
            max_new_tokens=2048,
            temperature=0.1,
            do_sample=True,
        )
    except Exception as e:
        return f"Inference error: {e}"
    # normalize the output
    try:
        first = out[0]
        gen = first.get("generated_text", first)
        if isinstance(gen, list) and gen:
            last_msg = gen[-1]
            content = last_msg.get("content", "")
            if isinstance(content, list):
                texts = [c.get("text", "") for c in content if isinstance(c, dict) and c.get("type") == "text"]
                joined = "\n".join(t for t in texts if t)
                return joined or str(last_msg)
            if isinstance(content, str):
                return content
            return str(last_msg)
        if isinstance(gen, str):
            return gen
        return str(gen)
    except Exception:
        return str(out)
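# For reference, chat-style pipelines usually return a structure like the
# following (the exact shape varies across transformers versions, hence the
# defensive normalization above):
# [{"generated_text": [
#     {"role": "user", "content": [...]},
#     {"role": "assistant", "content": "...model answer..."},
# ]}]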

# -----------------------------
# UI (simple)
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦷 DentalAI – Training Format\nJust upload an image → **Analyze**.")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", height=360)
            run_btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=22, show_copy_button=True)
    run_btn.click(fn=infer, inputs=[image_in], outputs=[output], api_name="predict")

demo.queue().launch()
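# Example client call against the running Space. A sketch only:
# "user/dentalai-space" is a hypothetical Space ID and "tooth.jpg" a
# placeholder path -- substitute the real ones.
#
# from gradio_client import Client, handle_file
#
# client = Client("user/dentalai-space")
# result = client.predict(handle_file("tooth.jpg"), api_name="/predict")
# print(result)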