import gradio as gr import tempfile, os, shutil, zipfile, glob, time import os from PIL import Image, ImageDraw, ImageFont os.environ["GRADIO_TEMP_DIR"] = "/tmp/gradio" os.environ["TMPDIR"] = "/tmp" os.environ["HF_HOME"] = "/tmp/hfhome" os.environ["HF_HUB_CACHE"] = "/tmp/hfhome/hub" os.environ["XDG_CACHE_HOME"] = "/tmp/xdgcache" # au cas où os.environ["TORCH_HOME"] = "/tmp/torchhome" import numpy as np import pandas as pd import torch from torchvision import transforms from transformers import SegformerForSemanticSegmentation import tifffile from scipy.ndimage import label from huggingface_hub import hf_hub_download PREVSIZE=600 PLACEHOLDER_W, PLACEHOLDER_H = PREVSIZE, 340 PLACEHOLDER = Image.new("RGB", (PLACEHOLDER_W, PLACEHOLDER_H), (245, 245, 245)) def _safe_rmtree(p): try: shutil.rmtree(p, ignore_errors=True) except Exception: pass # -- ménage de démarrage (si quelque chose traîne) -- for _p in ["/home/user/.cache", "/home/user/.local/share/Trash", "/home/user/app/tmp"]: _safe_rmtree(_p) CACHE_DIR = "/tmp/hfhome/transformers" # même valeur que ci-dessus for _p in ["/home/user/.cache", "/home/user/.local/share/Trash", "/home/user/app/tmp","/home/user/.gradio", "/tmp/gradio", "/tmp/hfhome", "/tmp/xdgcache", "/tmp/torchhome"]: _safe_rmtree(_p) # ======== CONFIG ======== DEVICE = "cpu" # Space CPU TARGET_SIZE = (341, 512) AR_MODEL_ID = "RomainFernandez/MicroDeepAerenchyma-Lacuna" CE_MODEL_ID = "RomainFernandez/MicroDeepAerenchyma-Cortex" # (Optionnel) dataset de démo -> stocke un petit ZIP (qq images 341x512 ou proches) DEMO_DATASET_ID = "RomainFernandez/MicroDeepAerenchyma-Demo" # à créer par toi DEMO_ZIP_FILENAME = "demo_images.zip" # mets ce nom dans le dataset # ======== MODELS (load once) ======== ar_model = SegformerForSemanticSegmentation.from_pretrained( AR_MODEL_ID, cache_dir=CACHE_DIR).to(DEVICE).eval() ce_model = SegformerForSemanticSegmentation.from_pretrained( CE_MODEL_ID, cache_dir=CACHE_DIR).to(DEVICE).eval() # ---- PURGE OPTIONNELLE DU CACHE HF (réduit l'usage disque) ---- try: from huggingface_hub import scan_cache_dir # CACHE_DIR = "/tmp/hfhome/transformers" -> on monte d'un cran: "/tmp/hfhome" cache_root = os.path.abspath(os.path.join(CACHE_DIR, "..")) info = scan_cache_dir(cache_root) # ne garder que les 2 dépôts utiles à l'app keep = {AR_MODEL_ID, CE_MODEL_ID} for repo in info.repos: if getattr(repo, "repo_id", None) not in keep: # supprime entièrement le repo non utilisé repo_path = getattr(repo, "repo_path", None) if repo_path and os.path.exists(repo_path): _safe_rmtree(repo_path) except Exception: # pas bloquant si la structure du cache change pass transform = transforms.Compose([ transforms.Resize(TARGET_SIZE), transforms.ToTensor(), transforms.Lambda(lambda x: x.repeat(3, 1, 1)) # grayscale -> 3 canaux (SegFormer) ]) def predict_mask(model, image_L): x = transform(image_L).unsqueeze(0).to(DEVICE) with torch.no_grad(): logits = model(x).logits up = torch.nn.functional.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False) pred = torch.argmax(up, dim=1)[0].cpu().numpy() return pred def largest_cc(binary): lab, n = label(binary) if n == 0: return binary*0 largest = np.argmax(np.bincount(lab.flat)[1:]) + 1 return (lab == largest).astype(np.uint8) def safe_list_images(root): exts = (".png",".jpg",".jpeg",".tif",".tiff") out = [] for p in glob.glob(os.path.join(root, "**/*"), recursive=True): if os.path.isfile(p) and p.lower().endswith(exts): out.append(p) return sorted(out) # On garde trace des répertoires temporaires de la session ACTIVE_WORKDIRS = set() BUSY = False def overlay_counter(img_np: np.ndarray, idx: int, total: int) -> Image.Image: """Ajoute 'Image idx/total' en haut-gauche, fond noir, police plus grande.""" im = Image.fromarray(img_np) if isinstance(img_np, np.ndarray) else img_np.copy() draw = ImageDraw.Draw(im) text = f"Image {idx}/{total}" # Taille de police proportionnelle à la largeur de l'image (≈ 3%) # min/max pour éviter trop petit/grand try: W = im.width except Exception: W = 800 fontsize = max(18, min(int(W * 0.03), 48)) # Essaie une vraie police (HF Spaces a en général DejaVuSans) try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", fontsize) except Exception: try: font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", fontsize) except Exception: font = ImageFont.load_default() # fallback pad_x, pad_y = max(6, fontsize // 3), max(4, fontsize // 4) # bbox du texte left, top, right, bottom = draw.textbbox((0, 0), text, font=font) tw, th = right - left, bottom - top box = (0, 0, tw + 2 * pad_x, th + 2 * pad_y) # fond noir (opaque). Si tu veux semi-opaque, convertis en RGBA + composite. draw.rectangle(box, fill=(0, 0, 0)) draw.text((pad_x, pad_y), text, fill=(255, 255, 255), font=font) return im def run_job(files_or_zip, use_demo=False, progress=gr.Progress()): # garde: espace disque minimal requis (ex: 5 Go) MIN_FREE = 5 * (1024**3) try: usage = shutil.disk_usage("/") if usage.free < MIN_FREE: yield None, ( "⚠️ Low free disk on Space. Please retry later " "(administrator: increase persistent storage or clean caches)." ), current_preview, None return except Exception: pass global BUSY first_img_time = None # durée mesurée sur la 1ère image current_preview = PLACEHOLDER.copy() # Refuser immédiatement si un autre job tourne if BUSY: yield None, ( "⚠️ Server already in use — please come back in a few minutes " "(another batch is running on this CPU Space)." ), current_preview, None return # BUSY = True try: previews = [] # liste d’images PIL pour la galerie # --- Prépare un espace de travail éphémère --- workdir = tempfile.mkdtemp(prefix="job_") ACTIVE_WORKDIRS.add(workdir) in_root = os.path.join(workdir, "in") out_root = os.path.join(workdir, "out") os.makedirs(in_root, exist_ok=True) os.makedirs(out_root, exist_ok=True) # --- Récupère les entrées --- file_list = [] if use_demo: demo_zip = hf_hub_download(repo_id=DEMO_DATASET_ID, filename=DEMO_ZIP_FILENAME, repo_type="dataset") with zipfile.ZipFile(demo_zip) as z: z.extractall(in_root) file_list = safe_list_images(in_root) total = len(file_list) previews = [] # images pour la galerie else: # files_or_zip est une LISTE de chemins (strings) quand type="filepath" paths = files_or_zip or [] for p in paths: if not p: continue src = p if isinstance(p, str) else p.name if src.lower().endswith(".zip"): with zipfile.ZipFile(src) as z: z.extractall(in_root) else: # copie d'un fichier image individuel shutil.copy(src, os.path.join(in_root, os.path.basename(src))) file_list = safe_list_images(in_root) total = len(file_list) previews = [] # images pour la galerie if not file_list: yield None, "No images found. Upload multiple images or a ZIP (recommended for folders).", current_preview, None return start = time.time() # --- constants for RAW contrast --- p_lo, p_hi, gamma = 5.0, 99.5, 1.30 yield None, f"Found {total} image(s). Starting…", current_preview, workdir progress(0.0) # initialise # Regroupe par sous-dossier relatif (pour mirrorer la sortie) rel_map = {} for p in file_list: rel = os.path.relpath(p, in_root) rel_dir = os.path.dirname(rel) # "" si racine rel_map.setdefault(rel_dir, []).append(p) processed = 0 for rel_dir, paths in rel_map.items(): out_dir = os.path.join(out_root, rel_dir) pred_dir = os.path.join(out_dir, "predicted_images") ar_dir = os.path.join(pred_dir, "pred_AR") cortex_dir= os.path.join(pred_dir, "pred_Cortex") endo_dir = os.path.join(pred_dir, "pred_Endoderm") global_dir= os.path.join(pred_dir, "pred_global") for d in [ar_dir, cortex_dir, endo_dir, global_dir]: os.makedirs(d, exist_ok=True) rows = [] for p in paths: image_name = os.path.basename(p) #yield None, f"Processing {image_name}…", previews,workdir progress(processed / total if total > 0 else 0.0) im = Image.open(p).convert("L") # ---- PREVIEW RAW à T/4 (à partir de la 2e image) ---- # On fabrique un RAW "hint" à la taille du réseau (TARGET_SIZE: H=341, W=512) w_hint, h_hint = TARGET_SIZE[1], TARGET_SIZE[0] imr_hint = im.resize((w_hint, h_hint), Image.BILINEAR) raw_hint = np.array(imr_hint, dtype=np.uint8) # contraste doux lo_hint, hi_hint = np.percentile(raw_hint, (p_lo, p_hi)) if hi_hint > lo_hint: x = (raw_hint.astype(np.float32) - lo_hint) / (hi_hint - lo_hint) x = np.clip(x, 0.0, 1.0) x = np.power(x, 1.0 / gamma) raw_boost_hint = (x * 255.0).astype(np.uint8) else: raw_boost_hint = raw_hint raw_rgb_hint = np.repeat(raw_boost_hint[..., None], 3, axis=-1).astype(np.uint8) black = np.zeros_like(raw_rgb_hint, dtype=np.uint8) interim = np.concatenate([raw_rgb_hint, black], axis=1).astype(np.uint8) # à partir de la 2e image, attendre T/4 (max 0.30s), puis publier l'intermédiaire if (processed >= 1) and (first_img_time is not None) and (first_img_time > 0): delay = min(0.7 * first_img_time, 2) if delay > 0.02: time.sleep(delay) interim_prev = overlay_counter(interim, processed + 1, total) if interim_prev.width > PREVSIZE: hh = int(interim_prev.height * (PREVSIZE / interim_prev.width)) interim_prev = interim_prev.resize((PREVSIZE, hh), Image.LANCZOS) current_preview = interim_prev yield None, f"Preview (RAW) {processed+1}/{total}: {image_name}", current_preview, workdir # Prédictions t0 = time.time() if processed == 0 else None pred_ar = predict_mask(ar_model, im) pred_ce = predict_mask(ce_model, im) mask_ar = (pred_ar == 1).astype(np.uint8)*255 mask_cortex = (pred_ce == 1).astype(np.uint8)*255 mask_endo = (pred_ce == 2).astype(np.uint8)*255 # Plus grande composante CE combined_ce = ((mask_cortex > 0) | (mask_endo > 0)).astype(np.uint8) largest = largest_cc(combined_ce) ar_in = ((mask_ar > 0) & (largest > 0)).astype(np.uint8) cortex_in = ((mask_cortex > 0) & (largest > 0)).astype(np.uint8) total_ce = int(largest.sum()) total_cortex = int(cortex_in.sum()) ar_px = int(ar_in.sum()) AR_in_CE = (ar_px / total_ce * 100) if total_ce > 0 else 0.0 AR_in_Cortex = (ar_px / total_cortex * 100) if total_cortex > 0 else 0.0 rows.append({ "image_name": image_name, "AR_percent_in_Cortex+Endoderm": AR_in_CE, "AR_percent_in_Cortex_only": AR_in_Cortex, "Cortex_surface_in_pixels": total_cortex, "Stele_surface_in_pixels":total_ce-total_cortex, "AR_surface_in_pixels": ar_px }) for d in (ar_dir, cortex_dir, endo_dir, global_dir): os.makedirs(d, exist_ok=True) # Écrit avec garde-fous try: tifffile.imwrite(os.path.join(ar_dir, image_name), mask_ar, compression="deflate") except FileNotFoundError: os.makedirs(ar_dir, exist_ok=True) tifffile.imwrite(os.path.join(ar_dir, image_name), mask_ar, compression="deflate") try: tifffile.imwrite(os.path.join(cortex_dir, image_name), mask_cortex, compression="deflate") except FileNotFoundError: os.makedirs(cortex_dir, exist_ok=True) tifffile.imwrite(os.path.join(cortex_dir, image_name), mask_cortex, compression="deflate") try: tifffile.imwrite(os.path.join(endo_dir, image_name), mask_endo, compression="deflate") except FileNotFoundError: os.makedirs(endo_dir, exist_ok=True) tifffile.imwrite(os.path.join(endo_dir, image_name), mask_endo, compression="deflate") # Panel TIFF [raw_boost | overlay RGB] h, w = mask_ar.shape imr = im.resize((w, h), Image.BILINEAR) raw = np.array(imr, dtype=np.uint8) # --- RAW contrast (softer) --- # percentiles un peu plus serrés pour réduire l'écrêtage p_lo, p_hi = 5.0, 99.5 lo, hi = np.percentile(raw, (p_lo, p_hi)) if hi > lo: # mise à l'échelle linéaire 0..1 x = (raw.astype(np.float32) - lo) / (hi - lo) x = np.clip(x, 0.0, 1.0) # courbe gamma douce (>1 comprime les hautes lumières) gamma = 1.30 x = np.power(x, 1.0 / gamma) # 1/gamma ~ 0.833 raw_boost = (x * 255.0).astype(np.uint8) else: raw_boost = raw raw_rgb = np.repeat(raw_boost[..., None], 3, axis=-1).astype(np.uint8) RED_SCALE, GREEN_SCALE = 0.5, 0.3 # ajustables r = np.clip(mask_ar.astype(np.float32) * RED_SCALE, 0, 255).astype(np.uint8) g = np.clip(mask_cortex.astype(np.float32) * GREEN_SCALE, 0, 255).astype(np.uint8) b = raw_boost # on garde le raw inchangé en B (tu peux mettre raw_boost si tu préfères) rgb = np.stack([r, g, b], axis=-1) panel = np.concatenate([raw_rgb, rgb], axis=1).astype(np.uint8) base = os.path.splitext(image_name)[0] os.makedirs(global_dir, exist_ok=True) tifffile.imwrite(os.path.join(global_dir, base + "_panel.tif"), panel, photometric="rgb", compression="deflate") # Aperçu avec compteur (idx=processed+1) # fixe la durée de référence après la 1re image if t0 is not None and first_img_time is None: first_img_time = max(0.02, time.time() - t0) # borne basse 20 ms preview_img = overlay_counter(panel, processed + 1, total) # miniature raisonnable (éviter des images immenses côté client) maxw = PREVSIZE if preview_img.width > maxw: h = int(preview_img.height * (maxw / preview_img.width)) preview_img = preview_img.resize((maxw, h), Image.LANCZOS) current_preview = preview_img # <- on remplace l’aperçu courant yield None, f"Processed {processed+1}/{total}: {image_name}", current_preview, workdir processed += 1 progress(processed / total if total > 0 else 1.0) #yield None, f"Processed {processed}/{len(file_list)}: {image_name}", current_preview, workdir if rows: os.makedirs(out_dir, exist_ok=True) pd.DataFrame(rows).to_csv(os.path.join(out_dir, "pred_metrics.csv"), index=False) # ZIP final zip_path = os.path.join(workdir, "results.zip") shutil.make_archive(zip_path[:-4], "zip", out_root) # libère l’espace immédiatement : on supprime le dossier des sorties _safe_rmtree(out_root) elapsed = time.time() - start yield zip_path, f"Done. {processed} image(s) in {elapsed:.1f}s. Download the ZIP below.", current_preview, workdir # On ne supprime pas ici pour laisser le téléchargement se faire finally: BUSY = False def cleanup_all(): # Appelé à la fin de la session utilisateur for wd in list(ACTIVE_WORKDIRS): try: _safe_rmtree(wd) shutil.rmtree(wd, ignore_errors=True) ACTIVE_WORKDIRS.discard(wd) except Exception: pass # purge des caches au cas où for _p in ["/tmp/hfhome", "/tmp/xdgcache", "/tmp/torchhome"]: _safe_rmtree(_p) for wd in list(ACTIVE_WORKDIRS): try: _safe_rmtree(wd) ACTIVE_WORKDIRS.discard(wd) except Exception: pass # Purge de tous les caches/temp usuels for _p in [ "/tmp/gradio", "/tmp/hfhome", "/tmp/xdgcache", "/tmp/torchhome", "/home/user/.cache", "/home/user/.gradio", "/home/user/.local/share/Trash" ]: _safe_rmtree(_p) with gr.Blocks(theme=gr.themes.Soft(), title="MicroDeepAerenchyma — Segmentation (CPU)") as demo: BANNER_URL = "https://huggingface.co/spaces/RomainFernandez/DeepAerenchyma/resolve/main/banner.png" gr.HTML(f"""
""") gr.HTML(""" """) gr.HTML(""" """) gr.HTML(""" """) gr.HTML(""" """) title_md = gr.Markdown( """ # Deep Aerenchyma **High-throughput anatomical phenotyping based on Segformer** **Authors:** Hani Atef, Romain Fernandez (CIRAD, UMR AGAP, France) **Funder:** Global Methane Hub, CIRAD Arize project Upload **multiple images** *or* **one ZIP** (supports nested folders). Output: masks, result panels in `pred_global`, and `pred_metrics.csv`. """, elem_id="titleblock" ) # Upload (ligne 1) with gr.Row(): files = gr.Files( label="Upload images or one ZIP (recommended)", file_count="multiple", type="filepath", elem_id="uploader" # <-- on va le styler ) # Boutons (ligne 2) with gr.Row(): start_btn = gr.Button("Start with my own images", variant="primary") with gr.Row(): demo_btn = gr.Button("Test with demo images", variant="primary") # Logs (ligne 5) with gr.Row(): log = gr.Textbox(label="Logs", lines=4) # Préview (ligne 3) with gr.Row(): live_preview = gr.Image( label="Live preview", value=PLACEHOLDER, interactive=False, elem_id="live_preview", height=None ) # Download (ligne 4) with gr.Row(): out_zip = gr.File(label="Download results (ZIP)") workdir_state = gr.State(value=None) def driver(files): from builtins import globals as _g if _g().get("BUSY", False): # toast + log immédiats (sans lancer le job) yield {log: "⚠️ Server already in use — please retry in a few minutes.", live_preview: PLACEHOLDER, workdir_state: None} return for zip_path, msg, preview, wd in run_job(files, use_demo=False): update = {log: msg, live_preview: preview, workdir_state: wd} if zip_path is not None: update[out_zip] = zip_path yield update def driver_demo(): from builtins import globals as _g if _g().get("BUSY", False): yield {log: "⚠️ Server already in use — please retry in a few minutes.", live_preview: PLACEHOLDER, workdir_state: None} return for zip_path, msg, preview, wd in run_job(None, use_demo=True): update = {log: msg, live_preview: preview, workdir_state: wd} if zip_path is not None: update[out_zip] = zip_path yield update start_btn.click( driver, inputs=[files], outputs=[out_zip, log, live_preview, workdir_state], show_progress="hidden" # <-- AJOUT ) demo_btn.click( driver_demo, inputs=None, outputs=[out_zip, log, live_preview, workdir_state], show_progress="hidden" # <-- AJOUT ) demo.unload(lambda: cleanup_all()) demo.queue().launch(ssr_mode=False)