# DeepAerenchyma / app.py
import os

# These must be set before importing gradio / huggingface_hub / torch so the
# libraries pick them up at import time.
os.environ["GRADIO_TEMP_DIR"] = "/tmp/gradio"
os.environ["TMPDIR"] = "/tmp"
os.environ["HF_HOME"] = "/tmp/hfhome"
os.environ["HF_HUB_CACHE"] = "/tmp/hfhome/hub"
os.environ["XDG_CACHE_HOME"] = "/tmp/xdgcache"  # just in case
os.environ["TORCH_HOME"] = "/tmp/torchhome"

import gradio as gr
import tempfile, shutil, zipfile, glob, time
from PIL import Image, ImageDraw, ImageFont
import numpy as np
import pandas as pd
import torch
from torchvision import transforms
from transformers import SegformerForSemanticSegmentation
import tifffile
from scipy.ndimage import label
from huggingface_hub import hf_hub_download
PREVSIZE = 600
PLACEHOLDER_W, PLACEHOLDER_H = PREVSIZE, 340
PLACEHOLDER = Image.new("RGB", (PLACEHOLDER_W, PLACEHOLDER_H), (245, 245, 245))
def _safe_rmtree(p):
    try:
        shutil.rmtree(p, ignore_errors=True)
    except Exception:
        pass

# -- startup cleanup (in case anything was left behind) --
for _p in ["/home/user/.cache", "/home/user/.local/share/Trash", "/home/user/app/tmp",
           "/home/user/.gradio", "/tmp/gradio", "/tmp/hfhome", "/tmp/xdgcache", "/tmp/torchhome"]:
    _safe_rmtree(_p)

CACHE_DIR = "/tmp/hfhome/transformers"  # same root as HF_HOME above
# ======== CONFIG ========
DEVICE = "cpu"  # CPU-only Space
TARGET_SIZE = (341, 512)
AR_MODEL_ID = "RomainFernandez/MicroDeepAerenchyma-Lacuna"
CE_MODEL_ID = "RomainFernandez/MicroDeepAerenchyma-Cortex"
# (Optional) demo dataset -> store a small ZIP (a few images at or near 341x512)
DEMO_DATASET_ID = "RomainFernandez/MicroDeepAerenchyma-Demo"  # dataset to create yourself
DEMO_ZIP_FILENAME = "demo_images.zip"  # filename expected inside the dataset
# ======== MODELS (load once) ========
ar_model = SegformerForSemanticSegmentation.from_pretrained(
AR_MODEL_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()
ce_model = SegformerForSemanticSegmentation.from_pretrained(
CE_MODEL_ID, cache_dir=CACHE_DIR).to(DEVICE).eval()
# ---- Optional purge of the HF cache (reduces disk usage) ----
try:
    from huggingface_hub import scan_cache_dir
    # CACHE_DIR = "/tmp/hfhome/transformers" -> go up one level: "/tmp/hfhome"
    cache_root = os.path.abspath(os.path.join(CACHE_DIR, ".."))
    info = scan_cache_dir(cache_root)
    # keep only the two repos this app actually uses
    keep = {AR_MODEL_ID, CE_MODEL_ID}
    for repo in info.repos:
        if getattr(repo, "repo_id", None) not in keep:
            # delete the unused repo entirely
            repo_path = getattr(repo, "repo_path", None)
            if repo_path and os.path.exists(repo_path):
                _safe_rmtree(repo_path)
except Exception:
    # not fatal if the cache layout changes
    pass
transform = transforms.Compose([
transforms.Resize(TARGET_SIZE),
transforms.ToTensor(),
    transforms.Lambda(lambda x: x.repeat(3, 1, 1))  # grayscale -> 3 channels (SegFormer expects RGB)
])
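# For reference: transform(im) on a PIL "L" image yields a float tensor of
# shape (3, 341, 512) with values in [0, 1].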
def predict_mask(model, image_L):
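    """Run a SegFormer model on a grayscale PIL image and return the per-pixel
    argmax label map as a (341, 512) NumPy array."""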
x = transform(image_L).unsqueeze(0).to(DEVICE)
with torch.no_grad():
logits = model(x).logits
up = torch.nn.functional.interpolate(logits, size=x.shape[-2:], mode="bilinear", align_corners=False)
pred = torch.argmax(up, dim=1)[0].cpu().numpy()
return pred
def largest_cc(binary):
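    """Return a uint8 mask of the largest connected component of a binary
    array (all zeros if the array is empty)."""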
lab, n = label(binary)
    if n == 0:
        return np.zeros_like(binary, dtype=np.uint8)
largest = np.argmax(np.bincount(lab.flat)[1:]) + 1
return (lab == largest).astype(np.uint8)
def safe_list_images(root):
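    """Recursively list image files (.png/.jpg/.jpeg/.tif/.tiff) under root, sorted."""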
exts = (".png",".jpg",".jpeg",".tif",".tiff")
out = []
for p in glob.glob(os.path.join(root, "**/*"), recursive=True):
if os.path.isfile(p) and p.lower().endswith(exts):
out.append(p)
return sorted(out)
# Keep track of the session's temporary working directories
ACTIVE_WORKDIRS = set()
BUSY = False
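# NOTE: BUSY is a best-effort single-job guard for this CPU Space; it is not a
# real lock and is not race-free under truly concurrent requests.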
def overlay_counter(img_np: np.ndarray, idx: int, total: int) -> Image.Image:
"""Ajoute 'Image idx/total' en haut-gauche, fond noir, police plus grande."""
im = Image.fromarray(img_np) if isinstance(img_np, np.ndarray) else img_np.copy()
draw = ImageDraw.Draw(im)
text = f"Image {idx}/{total}"
    # Font size proportional to the image width (≈ 3%),
    # clamped to avoid being too small or too large
try:
W = im.width
except Exception:
W = 800
fontsize = max(18, min(int(W * 0.03), 48))
    # Try a real TrueType font (HF Spaces usually ship DejaVuSans)
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans-Bold.ttf", fontsize)
except Exception:
try:
font = ImageFont.truetype("/usr/share/fonts/truetype/dejavu/DejaVuSans.ttf", fontsize)
except Exception:
font = ImageFont.load_default() # fallback
pad_x, pad_y = max(6, fontsize // 3), max(4, fontsize // 4)
    # text bounding box
left, top, right, bottom = draw.textbbox((0, 0), text, font=font)
tw, th = right - left, bottom - top
box = (0, 0, tw + 2 * pad_x, th + 2 * pad_y)
    # opaque black background (for semi-opacity, convert to RGBA and composite)
draw.rectangle(box, fill=(0, 0, 0))
draw.text((pad_x, pad_y), text, fill=(255, 255, 255), font=font)
return im
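# Usage sketch: overlay_counter(panel, 3, 12) returns a PIL image with
# "Image 3/12" stamped on a black box in the top-left corner.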
def run_job(files_or_zip, use_demo=False, progress=gr.Progress()):
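    """Generator that runs one batch job.

    Yields (zip_path, log_message, preview_image, workdir) tuples consumed by
    the Gradio drivers below; zip_path stays None until the final yield.
    """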
    # Guard: minimum free disk space required (e.g. 5 GB)
    MIN_FREE = 5 * (1024**3)
    global BUSY
    first_img_time = None  # duration measured on the first image
    current_preview = PLACEHOLDER.copy()  # defined before any yield that uses it
    try:
        usage = shutil.disk_usage("/")
        if usage.free < MIN_FREE:
            yield None, (
                "⚠️ Low free disk on Space. Please retry later "
                "(administrator: increase persistent storage or clean caches)."
            ), current_preview, None
            return
    except Exception:
        pass
    # Refuse immediately if another job is already running
    if BUSY:
        yield None, (
            "⚠️ Server already in use — please come back in a few minutes "
            "(another batch is running on this CPU Space)."
        ), current_preview, None
        return
    BUSY = True  # reset in the finally block below
try:
        previews = []  # PIL images for the gallery (currently unused)
        # --- Prepare an ephemeral workspace ---
workdir = tempfile.mkdtemp(prefix="job_")
ACTIVE_WORKDIRS.add(workdir)
in_root = os.path.join(workdir, "in")
out_root = os.path.join(workdir, "out")
os.makedirs(in_root, exist_ok=True)
os.makedirs(out_root, exist_ok=True)
        # --- Gather the inputs ---
file_list = []
if use_demo:
demo_zip = hf_hub_download(repo_id=DEMO_DATASET_ID, filename=DEMO_ZIP_FILENAME, repo_type="dataset")
with zipfile.ZipFile(demo_zip) as z:
z.extractall(in_root)
            file_list = safe_list_images(in_root)
            total = len(file_list)
else:
            # files_or_zip is a LIST of file paths (strings) when type="filepath"
paths = files_or_zip or []
for p in paths:
if not p:
continue
src = p if isinstance(p, str) else p.name
if src.lower().endswith(".zip"):
with zipfile.ZipFile(src) as z:
z.extractall(in_root)
else:
                    # copy a single image file
shutil.copy(src, os.path.join(in_root, os.path.basename(src)))
            file_list = safe_list_images(in_root)
            total = len(file_list)
if not file_list:
yield None, "No images found. Upload multiple images or a ZIP (recommended for folders).", current_preview, None
return
start = time.time()
# --- constants for RAW contrast ---
p_lo, p_hi, gamma = 5.0, 99.5, 1.30
yield None, f"Found {total} image(s). Starting…", current_preview, workdir
        progress(0.0)  # initialize the progress bar
        # Group by relative subfolder (to mirror the input tree in the output)
rel_map = {}
for p in file_list:
rel = os.path.relpath(p, in_root)
            rel_dir = os.path.dirname(rel)  # "" when at the root
rel_map.setdefault(rel_dir, []).append(p)
processed = 0
for rel_dir, paths in rel_map.items():
out_dir = os.path.join(out_root, rel_dir)
pred_dir = os.path.join(out_dir, "predicted_images")
ar_dir = os.path.join(pred_dir, "pred_AR")
cortex_dir= os.path.join(pred_dir, "pred_Cortex")
endo_dir = os.path.join(pred_dir, "pred_Endoderm")
global_dir= os.path.join(pred_dir, "pred_global")
for d in [ar_dir, cortex_dir, endo_dir, global_dir]:
os.makedirs(d, exist_ok=True)
rows = []
for p in paths:
image_name = os.path.basename(p)
#yield None, f"Processing {image_name}…", previews,workdir
progress(processed / total if total > 0 else 0.0)
im = Image.open(p).convert("L")
                # ---- RAW preview published while processing (from the 2nd image on) ----
                # Build a RAW "hint" at the network input size (TARGET_SIZE: H=341, W=512)
w_hint, h_hint = TARGET_SIZE[1], TARGET_SIZE[0]
imr_hint = im.resize((w_hint, h_hint), Image.BILINEAR)
raw_hint = np.array(imr_hint, dtype=np.uint8)
                # soft contrast stretch
lo_hint, hi_hint = np.percentile(raw_hint, (p_lo, p_hi))
if hi_hint > lo_hint:
x = (raw_hint.astype(np.float32) - lo_hint) / (hi_hint - lo_hint)
x = np.clip(x, 0.0, 1.0)
x = np.power(x, 1.0 / gamma)
raw_boost_hint = (x * 255.0).astype(np.uint8)
else:
raw_boost_hint = raw_hint
raw_rgb_hint = np.repeat(raw_boost_hint[..., None], 3, axis=-1).astype(np.uint8)
black = np.zeros_like(raw_rgb_hint, dtype=np.uint8)
interim = np.concatenate([raw_rgb_hint, black], axis=1).astype(np.uint8)
                # from the 2nd image on, wait 0.7x the first-image duration
                # (capped at 2 s), then publish the intermediate preview
                if (processed >= 1) and (first_img_time is not None) and (first_img_time > 0):
                    delay = min(0.7 * first_img_time, 2)
if delay > 0.02:
time.sleep(delay)
interim_prev = overlay_counter(interim, processed + 1, total)
if interim_prev.width > PREVSIZE:
hh = int(interim_prev.height * (PREVSIZE / interim_prev.width))
interim_prev = interim_prev.resize((PREVSIZE, hh), Image.LANCZOS)
current_preview = interim_prev
yield None, f"Preview (RAW) {processed+1}/{total}: {image_name}", current_preview, workdir
                # Predictions
t0 = time.time() if processed == 0 else None
pred_ar = predict_mask(ar_model, im)
pred_ce = predict_mask(ce_model, im)
mask_ar = (pred_ar == 1).astype(np.uint8)*255
mask_cortex = (pred_ce == 1).astype(np.uint8)*255
mask_endo = (pred_ce == 2).astype(np.uint8)*255
                # Largest connected component of the cortex+endoderm mask
combined_ce = ((mask_cortex > 0) | (mask_endo > 0)).astype(np.uint8)
largest = largest_cc(combined_ce)
ar_in = ((mask_ar > 0) & (largest > 0)).astype(np.uint8)
cortex_in = ((mask_cortex > 0) & (largest > 0)).astype(np.uint8)
total_ce = int(largest.sum())
total_cortex = int(cortex_in.sum())
ar_px = int(ar_in.sum())
AR_in_CE = (ar_px / total_ce * 100) if total_ce > 0 else 0.0
AR_in_Cortex = (ar_px / total_cortex * 100) if total_cortex > 0 else 0.0
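                # Metrics: aerenchyma (AR) area as a percentage of the whole
                # cortex+endoderm component and of the cortex alone; pixel
                # counts are at the network resolution (341x512), not the
                # original image resolution.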
rows.append({
"image_name": image_name,
"AR_percent_in_Cortex+Endoderm": AR_in_CE,
"AR_percent_in_Cortex_only": AR_in_Cortex,
"Cortex_surface_in_pixels": total_cortex,
"Stele_surface_in_pixels":total_ce-total_cortex,
"AR_surface_in_pixels": ar_px
})
                for d in (ar_dir, cortex_dir, endo_dir, global_dir):
                    os.makedirs(d, exist_ok=True)
                # Write the three masks with a guard (recreate the dir if it vanished mid-run)
                for d, arr in ((ar_dir, mask_ar), (cortex_dir, mask_cortex), (endo_dir, mask_endo)):
                    try:
                        tifffile.imwrite(os.path.join(d, image_name), arr, compression="deflate")
                    except FileNotFoundError:
                        os.makedirs(d, exist_ok=True)
                        tifffile.imwrite(os.path.join(d, image_name), arr, compression="deflate")
# Panel TIFF [raw_boost | overlay RGB]
h, w = mask_ar.shape
imr = im.resize((w, h), Image.BILINEAR)
raw = np.array(imr, dtype=np.uint8)
                # --- RAW contrast (softer) ---
                # p_lo/p_hi/gamma are defined once above the loop; slightly
                # tighter percentiles reduce clipping
                lo, hi = np.percentile(raw, (p_lo, p_hi))
if hi > lo:
                    # linear scaling to 0..1
x = (raw.astype(np.float32) - lo) / (hi - lo)
x = np.clip(x, 0.0, 1.0)
                    # gentle gamma curve (>1 compresses the highlights)
                    x = np.power(x, 1.0 / gamma)  # 1/gamma ≈ 0.77
raw_boost = (x * 255.0).astype(np.uint8)
else:
raw_boost = raw
raw_rgb = np.repeat(raw_boost[..., None], 3, axis=-1).astype(np.uint8)
                RED_SCALE, GREEN_SCALE = 0.5, 0.3  # tunable
r = np.clip(mask_ar.astype(np.float32) * RED_SCALE, 0, 255).astype(np.uint8)
g = np.clip(mask_cortex.astype(np.float32) * GREEN_SCALE, 0, 255).astype(np.uint8)
                b = raw_boost  # contrast-boosted raw goes in the blue channel
rgb = np.stack([r, g, b], axis=-1)
panel = np.concatenate([raw_rgb, rgb], axis=1).astype(np.uint8)
base = os.path.splitext(image_name)[0]
os.makedirs(global_dir, exist_ok=True)
tifffile.imwrite(os.path.join(global_dir, base + "_panel.tif"), panel, photometric="rgb", compression="deflate")
                # Preview with counter (idx = processed + 1)
                # set the reference duration after the first image
                if t0 is not None and first_img_time is None:
                    first_img_time = max(0.02, time.time() - t0)  # floor at 20 ms
preview_img = overlay_counter(panel, processed + 1, total)
                # keep the thumbnail reasonable (avoid huge images client-side)
maxw = PREVSIZE
if preview_img.width > maxw:
h = int(preview_img.height * (maxw / preview_img.width))
preview_img = preview_img.resize((maxw, h), Image.LANCZOS)
                current_preview = preview_img  # replace the current preview
yield None, f"Processed {processed+1}/{total}: {image_name}", current_preview, workdir
processed += 1
progress(processed / total if total > 0 else 1.0)
#yield None, f"Processed {processed}/{len(file_list)}: {image_name}", current_preview, workdir
if rows:
os.makedirs(out_dir, exist_ok=True)
pd.DataFrame(rows).to_csv(os.path.join(out_dir, "pred_metrics.csv"), index=False)
        # Final ZIP
        zip_path = os.path.join(workdir, "results.zip")
        shutil.make_archive(zip_path[:-4], "zip", out_root)  # make_archive appends ".zip" itself
        # free disk space immediately: delete the output folder
        _safe_rmtree(out_root)
        elapsed = time.time() - start
        yield zip_path, f"Done. {processed} image(s) in {elapsed:.1f}s. Download the ZIP below.", current_preview, workdir
        # The workdir is not deleted here, so the download can complete
finally:
BUSY = False
def cleanup_all():
    # Called at the end of the user session
    for wd in list(ACTIVE_WORKDIRS):
        _safe_rmtree(wd)
        ACTIVE_WORKDIRS.discard(wd)
    # Purge all the usual cache/temp locations
    for _p in [
        "/tmp/gradio", "/tmp/hfhome", "/tmp/xdgcache", "/tmp/torchhome",
        "/home/user/.cache", "/home/user/.gradio", "/home/user/.local/share/Trash"
    ]:
        _safe_rmtree(_p)
with gr.Blocks(theme=gr.themes.Soft(), title="MicroDeepAerenchyma — Segmentation (CPU)") as demo:
BANNER_URL = "https://huggingface.co/spaces/RomainFernandez/DeepAerenchyma/resolve/main/banner.png"
gr.HTML(f"""
<style>
    .banner-wrap {{ margin: 4px 0 6px; text-align:center; }} /* reduced margin */
.banner-wrap img {{
display:inline-block;
max-width:100%;
height:auto;
max-height:clamp(140px, 24vh, 256px);
border-radius:12px;
box-shadow: 0 2px 8px rgba(0,0,0,.08);
}}
    /* these rules can stay if you want to fine-tune things later
.prose h1:first-child {{ margin-top: 0.1rem !important; margin-bottom: 0.1rem !important; }}
.prose p {{ margin-top: 0.1rem !important; margin-bottom: 0.1rem !important; }}
*/
</style>
<div class="banner-wrap">
<img src="{BANNER_URL}" alt="Deep Aerenchyma">
</div>
""")
gr.HTML("""
<style>
    #live_preview { min-height: 220px; }  /* reserve a height, avoid layout jumps */
#live_preview img { width: 100% !important; height: auto !important; object-fit: contain; }
</style>
""")
gr.HTML("""
<style>
    /* shrink the Files component's visual width and typography */
    #uploader * { font-size: 0.92rem !important; }
    /* compact drop zone (robust selector via the data-testid present in Gradio v5) */
#uploader [data-testid="dropzone"] {
min-height: 84px !important;
padding: 8px !important;
}
</style>
""")
gr.HTML("""
<style>
html, body { height: 100%; overflow-y: auto !important; }
.gradio-container { min-height: 100vh; overflow-y: auto !important; }
@media (max-width: 600px) {
.gradio-container { padding: 8px !important; }
}
</style>
""")
gr.HTML("""
<style>
    /* slightly larger, more compact H1 */
    #titleblock h1 {
      font-size: 1.55rem !important;   /* ≈ +10–15% */
      line-height: 1.25 !important;
      margin: 6px 0 8px !important;    /* reduce the space below the H1 */
    }
    #titleblock p {
      margin: 2px 0 6px !important;    /* tighten the paragraphs a bit */
}
</style>
""")
title_md = gr.Markdown(
"""
# Deep Aerenchyma
**High-throughput anatomical phenotyping based on Segformer**
**Authors:** Hani Atef, Romain Fernandez (CIRAD, UMR AGAP, France)
**Funder:** Global Methane Hub, CIRAD Arize project
Upload **multiple images** *or* **one ZIP** (supports nested folders).
Output: masks, result panels in `pred_global`, and `pred_metrics.csv`.
""",
elem_id="titleblock"
)
    # Upload row
with gr.Row():
files = gr.Files(
label="Upload images or one ZIP (recommended)",
file_count="multiple",
type="filepath",
elem_id="uploader" # <-- on va le styler
)
    # Button rows
with gr.Row():
start_btn = gr.Button("Start with my own images", variant="primary")
with gr.Row():
demo_btn = gr.Button("Test with demo images", variant="primary")
    # Logs row
with gr.Row():
log = gr.Textbox(label="Logs", lines=4)
    # Live preview row
with gr.Row():
live_preview = gr.Image(
label="Live preview",
value=PLACEHOLDER,
interactive=False,
elem_id="live_preview",
height=None
)
    # Download row
with gr.Row():
out_zip = gr.File(label="Download results (ZIP)")
workdir_state = gr.State(value=None)
    def driver(files):
        if BUSY:
            # immediate toast + log, without starting the job
            yield {log: "⚠️ Server already in use — please retry in a few minutes.",
                   live_preview: PLACEHOLDER,
                   workdir_state: None}
            return
for zip_path, msg, preview, wd in run_job(files, use_demo=False):
update = {log: msg, live_preview: preview, workdir_state: wd}
if zip_path is not None:
update[out_zip] = zip_path
yield update
    def driver_demo():
        if BUSY:
            yield {log: "⚠️ Server already in use — please retry in a few minutes.",
                   live_preview: PLACEHOLDER,
                   workdir_state: None}
            return
for zip_path, msg, preview, wd in run_job(None, use_demo=True):
update = {log: msg, live_preview: preview, workdir_state: wd}
if zip_path is not None:
update[out_zip] = zip_path
yield update
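    # NOTE: the drivers yield dict updates keyed by output components; Gradio
    # applies each dict to just those components and leaves the rest unchanged.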
start_btn.click(
driver,
inputs=[files],
outputs=[out_zip, log, live_preview, workdir_state],
show_progress="hidden" # <-- AJOUT
)
demo_btn.click(
driver_demo,
inputs=None,
outputs=[out_zip, log, live_preview, workdir_state],
show_progress="hidden" # <-- AJOUT
)
demo.unload(lambda: cleanup_all())
demo.queue().launch(ssr_mode=False)