"""Gradio playground that overlays a Surya-1.0 embedding heatmap on a solar image."""

import gradio as gr
import torch
from torch import nn
from einops import rearrange
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import requests
import os
import sys
import warnings

# Silence the timm deprecation warning seen on HF Spaces. This filter must be
# installed before the `surya` import below, which presumably pulls in timm.
warnings.filterwarnings(
    "ignore",
    message="Importing from timm.models.layers is deprecated, please import via timm.layers",
    category=FutureWarning,
)

# Directory containing this file. Made absolute so it is never "" (which on
# sys.path would silently mean "the current working directory").
_HERE = os.path.dirname(os.path.abspath(__file__))

# Make sure the local `surya` package is importable even if CWD is different.
sys.path.append(_HERE)

# ================================
# 1. Download the Surya-1.0 weights
# ================================
MODEL_URL = "https://huggingface.co/nasa-ibm-ai4science/Surya-1.0/resolve/main/surya.366m.v1.pt"

# Local checkpoints, in order of preference. The last entry doubles as the
# download target when none of them exists yet.
MODEL_CANDIDATES = [
    os.path.join(_HERE, "surya_model.pt"),
    os.path.join(_HERE, "surya.366m.v1.pt"),
]

# (connect, read) timeouts in seconds for the checkpoint download.
_DOWNLOAD_TIMEOUT = (10, 60)
# Stream the checkpoint in 1 MiB chunks instead of buffering it all in memory.
_DOWNLOAD_CHUNK = 1 << 20


def _pick_model_file():
    """Return the first existing path in MODEL_CANDIDATES, else the last candidate."""
    return next(
        (path for path in MODEL_CANDIDATES if os.path.exists(path)),
        MODEL_CANDIDATES[-1],
    )


MODEL_FILE = _pick_model_file()


def download_model():
    """Download MODEL_URL to MODEL_FILE unless MODEL_FILE already exists.

    The response is streamed into ``MODEL_FILE + ".part"``. That file is renamed
    onto MODEL_FILE only after the whole body was written. As a result, an HTTP
    error page or an interrupted transfer never ends up at MODEL_FILE, where the
    existence check would otherwise prevent it from ever being re-downloaded.

    Network and HTTP errors (``requests.RequestException``) are printed and
    swallowed rather than raised. The weight loader later in this script
    already catches load failures and keeps going.
    """
    if os.path.exists(MODEL_FILE):
        return
    print("Baixando pesos do Surya-1.0...")
    tmp_path = MODEL_FILE + ".part"
    try:
        with requests.get(MODEL_URL, stream=True, timeout=_DOWNLOAD_TIMEOUT) as resp:
            resp.raise_for_status()  # never persist an error page as a checkpoint
            with open(tmp_path, "wb") as f:
                for chunk in resp.iter_content(chunk_size=_DOWNLOAD_CHUNK):
                    f.write(chunk)
        # Atomic rename: MODEL_FILE either does not exist or is complete.
        os.replace(tmp_path, MODEL_FILE)
    except requests.RequestException as e:
        print(f"Falha ao baixar pesos de {MODEL_URL}: {e}")
        return
    finally:
        # Remove any leftover partial download (absent after a successful rename).
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
    print("Download concluído!")


download_model()

# ================================
# 2. HelioSpectFormer model class
# ================================
# Imported from the local `surya` package directory (made importable via the
# sys.path tweak above). If that package is not available, paste the
# HelioSpectFormer class definition from the Surya repository here instead.
from surya.models.helio_spectformer import HelioSpectFormer
# ================================
# 3. Instantiate the model with default parameters
# ================================
model = HelioSpectFormer(
    img_size=224,
    patch_size=16,
    in_chans=1,
    embed_dim=368,
    time_embedding={"type": "linear", "time_dim": 1},
    depth=8,
    n_spectral_blocks=4,
    num_heads=8,
    mlp_ratio=4.0,
    drop_rate=0.0,
    window_size=7,
    dp_rank=1,
    learned_flow=False,
    finetune=True,
)

# Keys under which training scripts commonly nest the parameter dict.
_STATE_DICT_KEYS = ("state_dict", "model_state_dict", "model")
# Prefix added to every key when a module is wrapped in
# nn.DataParallel / DistributedDataParallel before being saved.
_WRAPPER_PREFIX = "module."


def _unwrap_state_dict(ckpt):
    """Return the flat ``name -> tensor`` mapping contained in a loaded checkpoint.

    If ``ckpt`` is a dict that holds another dict under one of
    ``_STATE_DICT_KEYS``, that inner dict is used. If every key then starts
    with ``"module."``, that prefix is stripped. Any other object is returned
    unchanged; a flat state dict is never altered, because its values are
    tensors, not dicts.
    """
    if isinstance(ckpt, dict):
        for key in _STATE_DICT_KEYS:
            inner = ckpt.get(key)
            if isinstance(inner, dict):
                ckpt = inner
                break
    if (
        isinstance(ckpt, dict)
        and ckpt
        and all(isinstance(k, str) and k.startswith(_WRAPPER_PREFIX) for k in ckpt)
    ):
        ckpt = {k[len(_WRAPPER_PREFIX):]: v for k, v in ckpt.items()}
    return ckpt


def _try_load_weights(m: nn.Module, path: str) -> None:
    """Load every parameter from ``path`` whose name and shape match ``m``.

    This is best effort: entries that do not fit are skipped and reported, and
    any failure (missing file, unreadable checkpoint, ...) is printed rather
    than raised, so the app still starts. Set NO_WEIGHTS to 1/true/yes to skip
    loading entirely.

    Reported counts:
        Ok          -- tensors copied into ``m``.
        Missing     -- parameters of ``m`` that received no value.
        Unexpected  -- checkpoint keys that ``m`` does not have.
        Dropped     -- keys present in both but with different shapes.
    """
    if os.environ.get("NO_WEIGHTS", "").lower() in {"1", "true", "yes"}:
        print("NO_WEIGHTS=1 -> pulando carregamento de pesos")
        return
    try:
        # NOTE(review): torch.load unpickles arbitrary Python objects. Only point
        # it at trusted checkpoints (weights_only=True is safer on torch>=1.13).
        raw_sd = _unwrap_state_dict(torch.load(path, map_location=torch.device("cpu")))
        model_sd = m.state_dict()
        filtered = {}
        dropped = []     # (key, checkpoint shape, model shape) for shape mismatches
        unexpected = []  # checkpoint keys the model does not define
        for k, v in raw_sd.items():
            if k not in model_sd:
                unexpected.append(k)
                continue
            # Non-tensor values have no shape; treat them as mismatches instead
            # of letting an AttributeError abort the whole load.
            ckpt_shape = tuple(v.shape) if hasattr(v, "shape") else None
            model_shape = tuple(model_sd[k].shape)
            if ckpt_shape == model_shape:
                filtered[k] = v
            else:
                dropped.append((k, ckpt_shape, model_shape))
        # `filtered` only holds known keys, so load_state_dict's own
        # unexpected list is always empty; ours above is the meaningful one.
        missing, _ = m.load_state_dict(filtered, strict=False)
        print(f"Pesos carregados parcialmente. Ok={len(filtered)} Missing={len(missing)} Unexpected={len(unexpected)} Dropped={len(dropped)}")
        if dropped:
            print("Algumas chaves foram descartadas por mismatch (ex.:)", dropped[:5])
        if missing:
            print("Exemplos de missing:", missing[:10])
        if unexpected:
            print("Exemplos de unexpected:", unexpected[:10])
    except Exception as e:  # top-level best-effort boundary: log and continue
        print(f"Falha ao carregar pesos de {path}: {e}")


_try_load_weights(model, MODEL_FILE)
model.eval()
# ================================
# 4. Inference function for the heatmap
# ================================
def infer_solar_image_heatmap(img):
    """Run Surya on ``img`` and return a figure with a per-patch heatmap overlay.

    Preprocessing:
        - The image is converted to grayscale, resized to 224x224 and scaled
          to [0, 1].
        - It is fed to the module-level ``model`` as a one-timestep batch.

    Heatmap construction:
        - If the model exposes a positional embedding of the same shape as the
          output tokens, it is subtracted first. This keeps the map from
          looking the same for every image.
        - Each patch token is reduced to its RMS over the embedding dimension.
        - The resulting grid is min-max normalised to [0, 1].
        - The grid is drawn semi-transparently on top of the grayscale input.

    Args:
        img: PIL image from ``gr.Image(type="pil")``, or ``None`` when the user
            submits without uploading anything.

    Returns:
        A ``matplotlib.figure.Figure`` for ``gr.Plot`` to render.

    Raises:
        gr.Error: if no image was provided.
        ValueError: if the model output is not a square grid of patch tokens.
    """
    # A standalone Figure is used instead of pyplot. Pyplot's global
    # "current figure" state is not thread-safe under concurrent Gradio
    # requests, and figures opened via plt.figure() were never closed.
    from matplotlib.figure import Figure

    size = 224  # must match the img_size the model was built with

    if img is None:
        raise gr.Error("Envie uma imagem solar antes de gerar o heatmap.")

    # Preprocessing: grayscale, resize, scale to [0, 1].
    img_np = np.array(img.convert("L").resize((size, size)))
    # Add batch, channel and time axes: [B=1, C=1, T=1, H, W]
    ts_tensor = torch.tensor(img_np, dtype=torch.float32)[None, None, None] / 255.0
    batch = {"ts": ts_tensor, "time_delta_input": torch.zeros((1, 1))}

    with torch.no_grad():
        # Per the original author, finetune=True makes the model return tokens [1, L, D].
        tokens = model(batch).squeeze(0).cpu()  # [L, D]

        # Remove the static positional component so maps differ between images.
        # This stays inside no_grad and is detached: if pos_embed is a trainable
        # Parameter, the subtraction would otherwise track gradients and the
        # later .numpy() call would raise.
        try:
            pos = model.embedding.pos_embed.detach().squeeze(0).to(tokens.dtype).cpu()
        except AttributeError:
            pos = None  # no accessible positional embedding; keep tokens as-is
        if pos is not None and pos.shape == tokens.shape:
            tokens = tokens - pos

    # Per-patch energy (RMS over the embedding dim), reshaped to the patch grid.
    n_tokens = tokens.shape[0]
    side = int(round(n_tokens ** 0.5))  # 14 for 224/16
    if side * side != n_tokens:
        raise ValueError(f"Expected a square grid of patch tokens, got L={n_tokens}")
    heat = torch.sqrt((tokens ** 2).mean(dim=1)).reshape(side, side).numpy()
    heat = (heat - heat.min()) / (heat.max() - heat.min() + 1e-8)

    # Overlay on the original image. `extent` stretches the side x side grid
    # over the whole image with nearest-neighbour cells. This matches integer
    # upsampling when `side` divides `size`, and stays aligned when it does not.
    fig = Figure(figsize=(5, 5))
    ax = fig.add_subplot()
    ax.imshow(img_np, cmap="gray")
    ax.imshow(
        heat,
        cmap="inferno",
        alpha=0.5,
        vmin=0.0,
        vmax=1.0,
        interpolation="nearest",
        extent=(-0.5, size - 0.5, size - 0.5, -0.5),
    )
    ax.axis("off")
    fig.tight_layout()
    return fig


# ================================
# 5. Gradio interface
# ================================
interface = gr.Interface(
    fn=infer_solar_image_heatmap,
    inputs=gr.Image(type="pil"),
    outputs=gr.Plot(label="Heatmap do embedding Surya"),
    title="Playground Surya-1.0 com Heatmap",
    description="Upload de imagem solar → visualize heatmap gerado pelo Surya-1.0",
)

interface.launch()