# app.py
# --- ENV FIX: MUST come at the very top, before ANY other import! ---
import os
os.environ.pop("TORCH_LOGS", None)  # drop if set to an invalid value -> fixes the torch import
os.environ.setdefault("TORCHDYNAMO_DISABLE", "1")  # disable Dynamo/JIT (avoids Gemma3 issues)
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "0")
import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline
# Optional: suppress Dynamo errors in case it is still active
try:
    import torch._dynamo as dynamo
    dynamo.config.suppress_errors = True
except Exception:
    pass
# -----------------------------
# Config
# -----------------------------
MODEL_ID = os.getenv("MODEL_ID", "olemeyer/dentalai-nxt-medgemma4b")  # alternatively "google/medgemma-4b-it"
HF_TOKEN = os.getenv("HF_TOKEN")  # set as a Secret in the Space settings
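# Local override example (a sketch; the token value is a placeholder):
#   MODEL_ID="google/medgemma-4b-it" HF_TOKEN="hf_..." python app.py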
def _device_index() -> int:
    return 0 if torch.cuda.is_available() else -1
torch.set_grad_enabled(False)
if hasattr(torch, "set_float32_matmul_precision"):
    torch.set_float32_matmul_precision("high")
pipe = pipeline(
    task="image-text-to-text",
    model=MODEL_ID,
    token=HF_TOKEN,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else None,
    device=_device_index(),  # 0 = CUDA:0, -1 = CPU
)
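# Quick local sanity check of the pipeline (a sketch; "test.png" is a placeholder path,
# and the message format mirrors infer() below):
#   msgs = [{"role": "user", "content": [
#       {"type": "image", "image": Image.open("test.png")},
#       {"type": "text", "text": "Describe this image."},
#   ]}]
#   print(pipe(text=msgs, max_new_tokens=64))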
# -----------------------------
# Utils
# -----------------------------
def _resize_long_side_max(img: Image.Image, max_long_side: int = 800) -> Image.Image:
    w, h = img.size
    long_side = max(w, h)
    if long_side <= max_long_side:
        return img
    scale = max_long_side / float(long_side)
    new_w = int(round(w * scale))
    new_h = int(round(h * scale))
    return img.resize((new_w, new_h), Image.LANCZOS)
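# Example: a 1600x1200 input is scaled to 800x600, while an 800x600 input is returned
# unchanged; the helper never upscales and always preserves the aspect ratio.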
# -----------------------------
# Inference (training format: user turn ONLY, image first)
# -----------------------------
@spaces.GPU
def infer(image: Image.Image):
    if image is None:
        return "Please upload an image."
    # scale the image so its long side is at most 800 px
    image = _resize_long_side_max(image, 800)
    messages = [
        {
            "role": "user",
            "content": [
                {"index": 0, "type": "image", "text": None, "image": image},
                {
                    "index": None,
                    "type": "text",
                    # German training prompt, kept verbatim to match the fine-tuning format:
                    # "Analyze this dental image and extract all available information as JSON."
                    "text": "Analysiere dieses Zahnbild und extrahiere alle verfügbaren Informationen als JSON.",
                },
            ],
        }
    ]
    try:
        out = pipe(
            text=messages,
            max_new_tokens=2048,  # as requested
            temperature=0.1,      # as requested
            do_sample=True,       # without sampling, temperature would be silently ignored
        )
    except TypeError:
        # fallback for alternative parameter names
        out = pipe(
            messages=messages,
            max_new_tokens=2048,
            temperature=0.1,
            do_sample=True,
        )
    except Exception as e:
        return f"Inference error: {e}"
    # normalize the output
    try:
        first = out[0]
        gen = first.get("generated_text", first)
        if isinstance(gen, list) and gen:
            last_msg = gen[-1]
            content = last_msg.get("content", "")
            if isinstance(content, list):
                texts = [c.get("text", "") for c in content if isinstance(c, dict) and c.get("type") == "text"]
                joined = "\n".join(t for t in texts if t)
                return joined or str(last_msg)
            if isinstance(content, str):
                return content
            return str(last_msg)
        if isinstance(gen, str):
            return gen
        return str(gen)
    except Exception:
        return str(out)
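# For reference, a typical chat-style pipeline output looks roughly like this (a sketch;
# the exact shape depends on the transformers version):
#   [{"generated_text": [{"role": "user", ...},
#                        {"role": "assistant", "content": "{ ...JSON... }"}]}]
# infer() extracts the assistant text from the last message, or falls back to str(out).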
# -----------------------------
# UI (simple)
# -----------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🦷 DentalAI – Training Format\nJust upload an image → **Analyze**.")
    with gr.Row():
        with gr.Column(scale=1):
            image_in = gr.Image(type="pil", label="Image", height=360)
            run_btn = gr.Button("Analyze", variant="primary")
        with gr.Column(scale=1):
            output = gr.Textbox(label="Output", lines=22, show_copy_button=True)
    run_btn.click(fn=infer, inputs=[image_in], outputs=[output], api_name="predict")
demo.queue().launch()
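# Calling the deployed Space remotely (a sketch; the Space id and image path are
# placeholders, and handle_file requires a recent gradio_client):
#   from gradio_client import Client, handle_file
#   client = Client("user/dentalai-space")  # hypothetical Space id
#   result = client.predict(handle_file("tooth.png"), api_name="/predict")
#   print(result)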