#!/usr/bin/env python3
"""
Vision Transformer Essence Generator for Tag Collector Game
Based on "What do Vision Transformers Learn? A Visual Exploration"
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import numpy as np
import os
import re
import math
import json
import timm
import streamlit as st
from tqdm import tqdm
from scipy.ndimage import gaussian_filter
from functools import wraps, lru_cache
from safetensors.torch import load_file
import time
import tag_storage  # Import for saving game state
from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON
from tag_categories import TAG_CATEGORIES

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Define essence quality levels with thresholds and styles
ESSENCE_QUALITY_LEVELS = {
    "ZAYIN": {"threshold": 0.0, "color": "#1CFC00", "description": "Basic representation with minimal details."},
    "TETH": {"threshold": 3.0, "color": "#389DDF", "description": "Clear representation with recognizable features."},
    "HE": {"threshold": 5.0, "color": "#FEF900", "description": "Refined representation with distinctive elements."},
    "WAW": {"threshold": 10.0, "color": "#7930F1", "description": "Advanced representation with precise details."},
    "ALEPH": {"threshold": 12.0, "color": "#FF0000", "description": "Perfect representation with extraordinary precision."}
}

# Essence generation costs in enkephalin based on tag rarity
ESSENCE_COSTS = {
    "Special": 0,
    "Canard": 100,
    "Urban Myth": 125,
    "Urban Legend": 150,
    "Urban Plague": 200,
    "Urban Nightmare": 250,
    "Star of the City": 300,
    "Impuritas Civitas": 400
}

# Default essence generation settings
DEFAULT_ESSENCE_SETTINGS = {
    "iterations": 256,
    "lr": 0.05,
    "ensemble_k": 8,
    "neighbor_count": 8,
    "image_size": 512,
    "layer_emphasis": "balanced",
    "tv_weight": 1e-3
}


def initialize_essence_settings():
    """Initialize essence generator settings if not already present."""
    if 'essence_custom_settings' not in st.session_state:
        # Try to load from storage first
        loaded_state = tag_storage.load_essence_state()
        if loaded_state and 'essence_custom_settings' in loaded_state:
            old_settings = loaded_state['essence_custom_settings']
            # Validate and merge with current defaults
            new_settings = DEFAULT_ESSENCE_SETTINGS.copy()
            # Only keep valid settings that exist in current defaults
            for key in DEFAULT_ESSENCE_SETTINGS.keys():
                if key in old_settings:
                    # Validate layer_emphasis values
                    if key == 'layer_emphasis' and old_settings[key] not in ['balanced', 'early', 'mid', 'late']:
                        continue  # Use default
                    new_settings[key] = old_settings[key]
            st.session_state.essence_custom_settings = new_settings
        else:
            st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()


def initialize_manual_tags():
    """Initialize manual tags if not already present."""
    if 'manual_tags' not in st.session_state:
        # Try to load from storage first
        loaded_state = tag_storage.load_essence_state()
        if loaded_state and 'manual_tags' in loaded_state:
            st.session_state.manual_tags = loaded_state['manual_tags']
        else:
            st.session_state.manual_tags = {
                "hatsune_miku": {"rarity": "Special", "description": "Popular virtual singer with long teal twin-tails"},
            }
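
# Illustrative mapping (doctest-style sketch, not executed; get_quality_level()
# and get_essence_cost() are defined later in this module):
# >>> get_quality_level(4.2)   # clears the 3.0 "TETH" threshold but not 5.0 "HE"
# 'TETH'
# >>> get_essence_cost("Star of the City")
# 300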
def timeout(seconds, fallback_value=None):
    """Simple timeout utility for functions."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed = time.time() - start_time
            if elapsed > seconds:
                print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)")
            return result
        return wrapper
    return decorator


class TaggerTorch(nn.Module):
    def __init__(self, backbone_name="vit_base_patch16_384", img_size=512, num_tags=70527, normalize=True):
        super().__init__()
        # num_classes=0 -> return features; we add our own head
        self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0, img_size=img_size)
        in_features = self.backbone.num_features  # 768 for vit_base_patch16_384
        self.head = nn.Linear(in_features, num_tags)
        # Most ViT taggers expect ImageNet normalization; keep it configurable
        self.normalize = normalize
        if self.normalize:
            self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer("std", torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        if self.normalize:
            x = (x - self.mean) / self.std
        feats = self.backbone.forward_features(x)  # [B, C] or [B, tokens, C]
        if feats.ndim == 3:
            # if tokens, take CLS
            feats = feats[:, 0, :]
        return self.head(feats)


def _remap_backbone_keys(sd):
    out = {}
    for k, v in sd.items():
        if k.startswith("module."):
            k = k[7:]
        # collapse ImageTagger → TaggerTorch vit paths
        if k.startswith("backbone.vit."):
            k = "backbone." + k[len("backbone.vit."):]
        elif k.startswith("vit."):
            k = "backbone." + k[len("vit."):]
        elif k.startswith(("pos_embed", "patch_embed.", "blocks.", "norm.", "cls_token")):
            k = "backbone." + k
        out[k] = v
    return out


def _get_logits_from_output(out):
    if isinstance(out, dict):
        return out.get("refined_predictions") or out.get("initial_predictions")
    return out


def build_torch_model_from_safetensors(ckpt_path, num_tags, backbone="vit_base_patch16_384", img_size=512):
    model = TaggerTorch(backbone_name=backbone, img_size=img_size, num_tags=num_tags, normalize=True)
    sd = load_file(ckpt_path)
    sd = _remap_backbone_keys(sd)
    # pull out tag embedding/bias if present and later copy into the linear head
    te_w = sd.pop("tag_embedding.weight", sd.pop("module.tag_embedding.weight", None))
    te_b = sd.pop("tag_bias", sd.pop("module.tag_bias", None))
    # load backbone etc.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("[load] missing:", missing[:20], "…")
    print("[load] unexpected:", unexpected[:20], "…")
    # copy tag embedding → head
    with torch.no_grad():
        if te_w is not None and te_w.shape == model.head.weight.shape:
            model.head.weight.copy_(te_w)
            print("[load] copied tag_embedding.weight → head.weight")
        if te_b is not None and model.head.bias is not None and te_b.shape == model.head.bias.shape:
            model.head.bias.copy_(te_b)
            print("[load] copied tag_bias → head.bias")
    return model
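
# Example usage (sketch; the checkpoint path here is an assumption — the real
# path and sizes are resolved from camie-tagger-v2-metadata.json inside
# generate_essence_for_tag() below):
# >>> model = build_torch_model_from_safetensors(
# ...     "camie-tagger-v2.safetensors", num_tags=70527,
# ...     backbone="vit_base_patch16_384", img_size=512)
# >>> logits = model(torch.rand(1, 3, 512, 512))  # raw tag logits, shape [1, 70527]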
@torch.no_grad()
def _get_classifier_matrix(model):
    # [T, C] — works for both ImageTagger and TaggerTorch
    if hasattr(model, "tag_embedding"):
        return model.tag_embedding.weight.detach()
    if hasattr(model, "head") and hasattr(model.head, "weight"):
        return model.head.weight.detach()
    raise AttributeError("Model has neither tag_embedding nor head.weight")


@torch.no_grad()
def neighbor_sets_from_embedding(model, class_idx, k_pos=8, k_neg=8):
    """
    Returns (pos_idx, pos_w, neg_idx, neg_w)
      pos: highest cosine neighbors (excluding self), with clipped-similarity weights
      neg: lowest cosine neighbors (most dissimilar), with clipped |similarity| weights
    """
    W = _get_classifier_matrix(model)  # [T, C]
    Wn = F.normalize(W, dim=1)
    q = Wn[class_idx:class_idx + 1]  # [1, C]
    sims = (q @ Wn.T).squeeze(0)     # [T]

    # positives: largest similarities (mask self with -inf so it can't be picked)
    sims_pos = sims.clone()
    sims_pos[class_idx] = -float('inf')
    pos_vals, pos_idx = torch.topk(sims_pos, k=min(k_pos, sims.numel() - 1))

    # negatives: smallest similarities (mask self with +inf so -sims excludes it,
    # otherwise the masked self would be selected as the "most dissimilar" tag)
    sims_neg = sims.clone()
    sims_neg[class_idx] = float('inf')
    neg_vals, neg_idx = torch.topk(-sims_neg, k=min(k_neg, sims.numel() - 1))
    neg_vals = -neg_vals

    # weights (clip for stability)
    pos_w = torch.clamp(pos_vals, 0.0, 1.0).tolist()
    neg_w = torch.clamp(neg_vals.abs(), 0.0, 1.0).tolist()
    return pos_idx.tolist(), pos_w, neg_idx.tolist(), neg_w


def weighted_class_objective(logits, main_idx, plus_idxs=(), plus_w=None, alpha=0.25,
                             minus_idxs=(), minus_w=None, beta=0.15):
    score = logits[:, main_idx].mean()
    if plus_idxs:
        w = torch.tensor(plus_w or [1.0] * len(plus_idxs), device=logits.device, dtype=logits.dtype)
        w = w / (w.sum() + 1e-8)
        score = score + alpha * (logits[:, plus_idxs] * w).sum(dim=1).mean()
    if minus_idxs:
        w = torch.tensor(minus_w or [1.0] * len(minus_idxs), device=logits.device, dtype=logits.dtype)
        w = w / (w.sum() + 1e-8)
        score = score - beta * (logits[:, minus_idxs] * w).sum(dim=1).mean()
    return score


def idx_to_name(idx, dataset=None):
    if dataset is not None and hasattr(dataset, "idx_to_tag"):
        return dataset.idx_to_tag.get(int(idx), f"Tag {idx}")
    # fallback to the cached JSON metadata
    meta = _load_tagger_metadata_cached()
    return meta.get("dataset_info", {}).get("tag_mapping", {}).get("idx_to_tag", {}).get(str(int(idx)), f"Tag {idx}")
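
# How the two pieces above combine during optimization (sketch; tag index 1234
# is hypothetical):
# >>> pos_idx, pos_w, neg_idx, neg_w = neighbor_sets_from_embedding(model, 1234, k_pos=4, k_neg=4)
# >>> logits = model(torch.rand(2, 3, 512, 512))
# >>> obj = weighted_class_objective(logits, main_idx=1234,
# ...                                plus_idxs=pos_idx, plus_w=pos_w,
# ...                                minus_idxs=neg_idx, minus_w=neg_w)
# `obj` rewards the target tag, mildly rewards its nearest-neighbor tags
# (scaled by alpha), and mildly penalizes the most dissimilar ones (scaled by beta).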
# Core Classes for ViT Essence Generation

class ViTLayerHook:
    """Hook for capturing ViT feed-forward layer activations."""

    def __init__(self, layer, layer_name):
        self.layer = layer
        self.layer_name = layer_name
        self.features = None
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Store the output activations."""
        self.features = output

    def close(self):
        self.hook.remove()


class ViTFeatureAnalyzer:
    """Analyzes ViT architecture to find optimal layers for visualization."""

    def __init__(self, model):
        self.model = model
        self.layer_info = self._analyze_architecture()

    def _analyze_architecture(self):
        """Analyze the ViT architecture and identify feed-forward layers."""
        layer_info = {}

        def traverse_modules(module, prefix=''):
            for name, child in module.named_children():
                full_name = f"{prefix}.{name}" if prefix else name
                # Look for transformer blocks and their MLP components
                if 'mlp' in full_name.lower() and (hasattr(child, 'act') or 'act' in dict(child.named_children())):
                    # prefer the actual activation submodule
                    act = getattr(child, 'act', None)
                    if act is not None:
                        layer_info[full_name + ".act"] = {
                            'type': 'mlp_activation',
                            'module': act,
                            'block_idx': self._extract_block_number(full_name)
                        }
                    else:
                        # fallback: search by name
                        for n2, c2 in child.named_children():
                            if 'act' in n2.lower():
                                layer_info[full_name + f".{n2}"] = {
                                    'type': 'mlp_activation',
                                    'module': c2,
                                    'block_idx': self._extract_block_number(full_name)
                                }
                elif 'gelu' in str(type(child)).lower() or 'activation' in name.lower():
                    # Direct activation layers (GELU, etc.)
                    if 'mlp' in prefix.lower() or 'ffn' in prefix.lower():
                        layer_info[full_name] = {
                            'type': 'activation',
                            'module': child,
                            'block_idx': self._extract_block_number(full_name)
                        }
                # Recurse into children
                traverse_modules(child, full_name)

        traverse_modules(self.model)
        return layer_info

    def _extract_block_number(self, layer_name):
        """Extract block/layer number from layer name."""
        numbers = re.findall(r'\.(\d+)\.', layer_name)
        if numbers:
            return int(numbers[0])
        return 0

    def get_visualization_layers(self, layer_emphasis="balanced"):
        """Get the best layers for visualization based on emphasis."""
        if not self.layer_info:
            print("Warning: No suitable ViT layers found for visualization")
            return []

        # Sort layers by block index
        sorted_layers = sorted(
            [(n, info) for n, info in self.layer_info.items() if 'mlp' in n.lower() and 'act' in n.lower()],
            key=lambda x: x[1]['block_idx']
        )
        if not sorted_layers:
            return []
        total_blocks = max([info['block_idx'] for _, info in sorted_layers]) + 1

        if layer_emphasis == "early":
            # Focus on first 1/3 of blocks
            target_blocks = list(range(0, max(1, total_blocks // 3)))
        elif layer_emphasis == "mid":
            # Focus on middle 1/3 of blocks
            start = total_blocks // 3
            end = 2 * total_blocks // 3
            target_blocks = list(range(start, max(start + 1, end)))
        elif layer_emphasis == "late":
            # Focus on last 1/3 of blocks
            start = 2 * total_blocks // 3
            target_blocks = list(range(start, total_blocks))
        else:  # balanced
            # Sample across all blocks
            if total_blocks <= 3:
                target_blocks = list(range(total_blocks))
            else:
                target_blocks = [0, total_blocks // 2, total_blocks - 1]

        # Select layers from target blocks
        selected_layers = []
        for layer_name, info in sorted_layers:
            if info['block_idx'] in target_blocks:
                selected_layers.append(layer_name)
        return selected_layers


def _jitter_reflect_crop(x, pad=16):
    b, c, h, w = x.shape
    padded = F.pad(x, (pad, pad, pad, pad), mode='reflect').contiguous()
    off_h = torch.randint(0, 2 * pad + 1, (b,), device=x.device)
    off_w = torch.randint(0, 2 * pad + 1, (b,), device=x.device)
    crops = []
    for i in range(b):
        hs, ws = int(off_h[i]), int(off_w[i])
        crop = padded[i:i + 1, :, hs:hs + h, ws:ws + w].contiguous()
        crops.append(crop)
    return torch.cat(crops, 0).contiguous()


def _channel_affine(x):
    # per-channel affine: σ ~ exp(U[-1,1]), μ ~ U[-1,1]
    b, c, _, _ = x.shape
    mu = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0)
    log_sigma = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0)
    sigma = torch.exp(log_sigma)
    return (x * sigma + mu)


def _add_gaussian_noise(x, std=0.15):
    return (x + torch.randn_like(x) * std)


def _augment_once(x, noise_std=0.15):
    z = _jitter_reflect_crop(x)
    z = _channel_affine(z)
    z = _add_gaussian_noise(z, std=noise_std)
    return z


def _augment_batch(x, K=8, noise_std=0.15):
    augs = []
    for _ in range(K):
        z = _augment_once(x, noise_std=noise_std)
        augs.append(z)
    return torch.cat(augs, dim=0).contiguous()
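
# Shape sketch: each of the K copies gets an independent jitter/affine/noise
# draw, so gradients through the ensemble average over augmentations.
# >>> _augment_batch(torch.rand(1, 3, 224, 224), K=8).shape
# torch.Size([8, 3, 224, 224])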
class ViTEssenceGenerator:
    """
    ViT Essence Generator based on the methodology from
    'What do Vision Transformers Learn? A Visual Exploration'
    """

    def __init__(self, model, tag_to_name=None, iterations=500, learning_rate=0.05,
                 layer_emphasis="balanced", ensemble_K=8, tv_weight=1e-3):
        """Initialize the ViT Essence Generator."""
        self.model = model
        self.tag_to_name = tag_to_name
        self.iterations = iterations
        self.lr = learning_rate
        self.layer_emphasis = layer_emphasis
        self.ensemble_K = ensemble_K
        self.tv_weight = tv_weight

        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval().to(self.device)

        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        # If the model normalizes internally (TaggerTorch with normalize=True),
        # skip normalization here so it isn't applied twice
        self.expect_imagenet = not (hasattr(self.model, "normalize") and getattr(self.model, "normalize") is True)

        # Analyze ViT architecture
        self.analyzer = ViTFeatureAnalyzer(self.model)

        # Initialize hooks
        self.hooks = {}
        self.selected_layers = []
        print(f"ViT Essence Generator initialized on {self.device}")

    def _preprocess(self, x):
        return (x - self.imagenet_mean) / self.imagenet_std if self.expect_imagenet else x

    def setup_hooks(self, tag_idx):
        """Setup hooks for multi-layer visualization."""
        self.close_hooks()
        names = self.analyzer.get_visualization_layers(self.layer_emphasis)
        if not names:
            print("Warning: No suitable layers found for visualization")
            return {}

        print(f"Setting up hooks on {len(names)} ViT layer(s)")
        layer_weights = {}
        for i, layer_name in enumerate(names):
            try:
                layer_info = self.analyzer.layer_info[layer_name]
                layer_module = layer_info['module']
                self.hooks[layer_name] = ViTLayerHook(layer_module, layer_name)
                weight = 0.3 + 0.7 * (i / max(1, len(names) - 1))
                layer_weights[layer_name] = weight
                print(f"  - {layer_name} (block {layer_info['block_idx']}, weight: {weight:.2f})")
            except Exception as e:
                print(f"Failed to setup hook for {layer_name}: {e}")
        self.selected_layers = names
        return layer_weights

    def close_hooks(self):
        """Clean up hooks to avoid memory leaks."""
        for hook in self.hooks.values():
            hook.close()
        self.hooks.clear()

    def _fourier_init(self, size=224, decay=1.5):
        H = W = size
        # complex spectrum (rFFT domain)
        spec = torch.randn(1, 3, H, W // 2 + 1, dtype=torch.complex64, device=self.device)
        fy = torch.fft.fftfreq(H, device=self.device).abs().view(H, 1)
        fx = torch.fft.rfftfreq(W, device=self.device).abs().view(1, W // 2 + 1)
        radius = (fy ** 2 + fx ** 2).sqrt().clamp_(min=1e-6)
        spec = spec * (1.0 / (radius ** decay))  # 1/f^decay
        img = torch.fft.irfft2(spec, s=(H, W))   # [1,3,H,W], roughly zero-mean
        # scale to [0,1]
        img = (img - img.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0])
        img = img / (img.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] + 1e-8)
        return img

    def create_optimizable_image(self, size=224, use_fourier=True):
        if use_fourier:
            with torch.no_grad():
                image = self._fourier_init(size)
            image = image.to(self.device)
        else:
            image = torch.rand(1, 3, size, size, device=self.device)
        image = image.detach().contiguous().requires_grad_(True)
        return image

    def total_variation_loss(self, image):
        # image: [B,3,H,W]
        diff_y = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
        diff_x = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
        tv_per_sample = diff_y.mean(dim=(1, 2, 3)) + diff_x.mean(dim=(1, 2, 3))  # [B]
        return tv_per_sample.mean()
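
    # The method above is the anisotropic total variation,
    #   TV(x) = mean|x[h+1, w] - x[h, w]| + mean|x[h, w+1] - x[h, w]|,
    # computed per sample and averaged over the batch; generate_essence() scales
    # it by tv_weight to discourage high-frequency noise in the optimized image.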
    def get_feature_activations(self, layer_weights, topk_channels=None):
        total = 0.0
        for name, hook in self.hooks.items():
            feats = hook.features  # [B, tokens, C] from GELU
            if feats is None:
                continue
            w = layer_weights.get(name, 0.5)
            # aggregate: sum over tokens; then (optionally) top-k over channels
            agg = feats.sum(dim=1)  # [B, C]
            if topk_channels is not None and topk_channels > 0 and agg.shape[1] > topk_channels:
                # take the mean of top-k channels for stability
                vals, _ = torch.topk(agg, k=topk_channels, dim=1)
                act = vals.mean()
            else:
                act = agg.mean()
            total = total + w * act
        return total

    def generate_essence(self, tag_idx, neighbor_count=8, image_size=224, return_score=True, progress_callback=None):
        """Generate an essence visualization for a ViT model."""
        # Get tag name for logging
        tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}"
        print(f"Generating ViT essence for '{tag_name}' (index: {tag_idx})...")

        # Setup hooks for this tag
        layer_weights = self.setup_hooks(tag_idx)
        if not self.hooks and not hasattr(self.model, 'head'):
            print("Warning: No hooks set up and no classifier head found")
            return self._create_fallback_image(image_size), 0.0

        # Initialize optimizable image
        image = self.create_optimizable_image(image_size)

        # Create optimizer
        optimizer = torch.optim.Adam([image], lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.iterations, eta_min=self.lr * 0.01
        )

        best_score = -float('inf')
        best_image = None
        print(f"Starting optimization for {self.iterations} iterations...")

        # Choose auxiliaries once per run
        pos_idx, pos_w, neg_idx, neg_w = neighbor_sets_from_embedding(
            self.model, tag_idx, k_pos=neighbor_count, k_neg=neighbor_count
        )

        for i in range(self.iterations):
            optimizer.zero_grad()

            # Clear previous activations
            for hook in self.hooks.values():
                hook.features = None

            # Original generation logic - tag mode only
            aug_batch = _augment_batch(image, K=self.ensemble_K, noise_std=0.15)
            out = self.model(self._preprocess(aug_batch))
            logits = out["refined_predictions"] if isinstance(out, dict) else out  # [K, T]

            cls_term = weighted_class_objective(
                logits, main_idx=tag_idx,
                plus_idxs=pos_idx, plus_w=pos_w, alpha=0.25,
                minus_idxs=neg_idx, minus_w=neg_w, beta=0.15
            )

            # feature regularizer from the hooked activations
            feat_term = 0.0
            if self.hooks:
                feat_term = self.get_feature_activations(layer_weights, topk_channels=64)

            Ltv = self.total_variation_loss(aug_batch)
            total_loss = -(cls_term + 0.5 * feat_term) + self.tv_weight * Ltv

            # Handle non-finite losses before stepping on garbage gradients
            if not torch.isfinite(total_loss.detach()):
                print("WARN: non-finite loss; resetting image step")
                optimizer.zero_grad(set_to_none=True)
                with torch.no_grad():
                    # Small reset toward noise
                    image.add_(0.05 * torch.randn_like(image)).clamp_(0.0, 1.0)
                continue

            # Backward pass
            total_loss.backward()
            if image.grad is None or not torch.isfinite(image.grad).all():
                print("WARN: no/invalid grad reaching the image; check hook & loss wiring.")

            # Gradient clip to avoid exploding updates
            torch.nn.utils.clip_grad_norm_([image], max_norm=3.0)
            optimizer.step()
            scheduler.step()

            # Keep pixels in valid range
            with torch.no_grad():
                image.clamp_(0.0, 1.0)

            # Track best result - using original score calculation
            with torch.no_grad():
                score_tensor = -(total_loss - self.tv_weight * Ltv)
                current_score = float(score_tensor.item())
                if current_score > best_score:
                    best_score = current_score
                    best_image = image.detach().clone()
            # Progress reporting
            if progress_callback and i % max(1, self.iterations // 20) == 0:
                progress_callback(
                    scale_idx=0, scale_count=1,
                    iter_idx=i, iter_count=self.iterations,
                    score=current_score
                )

            # Logging
            if i % max(1, self.iterations // 10) == 0:
                print(f"Iteration {i}/{self.iterations}: Score = {current_score:.4f}")

        # Use best image if we found one
        if best_image is not None:
            final_image = best_image
        else:
            final_image = image.detach()

        # Convert to PIL image
        final_image = torch.clamp(final_image, 0, 1)
        pil_img = to_pil_image(final_image[0].cpu())

        # Clean up hooks
        self.close_hooks()

        print(f"ViT essence generation complete for '{tag_name}'. Final score: {best_score:.4f}")
        if return_score:
            return pil_img, best_score
        else:
            return pil_img

    def _create_fallback_image(self, size):
        """Create a fallback image when generation fails."""
        # Create a simple noise pattern
        image = torch.randn(1, 3, size, size) * 0.5 + 0.5
        image = torch.clamp(image, 0, 1)
        return to_pil_image(image[0])


# Utility Functions

def get_quality_level(score):
    """Determine the quality level of an essence based on its score."""
    for level in reversed(list(ESSENCE_QUALITY_LEVELS.keys())):
        if score >= ESSENCE_QUALITY_LEVELS[level]["threshold"]:
            return level
    return "ZAYIN"  # Default to lowest level


def get_essence_cost(rarity):
    """Calculate the cost to generate an essence image based on tag rarity."""
    return ESSENCE_COSTS.get(rarity, 100)  # Default to 100 if rarity unknown


def save_essence_to_game_folder(image, tag, score, quality_level):
    """Save the generated essence image to a persistent game folder."""
    # Create game folder paths with better structure
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    game_data_dir = os.path.join(base_dir, "game_data")
    essence_folder = os.path.join(game_data_dir, "essences")

    # Make sure all parent directories exist
    os.makedirs(game_data_dir, exist_ok=True)
    os.makedirs(essence_folder, exist_ok=True)

    # Organize essences by quality level for easier browsing
    quality_folder = os.path.join(essence_folder, quality_level)
    os.makedirs(quality_folder, exist_ok=True)

    # Create filename with more details and better organization
    safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
    filepath = os.path.join(quality_folder, filename)

    # Save the image
    image.save(filepath)
    print(f"Saved ViT essence to: {filepath}")
    return filepath
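
# Resulting on-disk layout (sketch; the concrete filename below is hypothetical):
#   game_data/essences/<QUALITY>/<safe_tag>_<score>_<YYYYmmdd_HHMMSS>.png
#   e.g. game_data/essences/WAW/hatsune_miku_10.53_20250101_120000.png
# display_saved_essences() later parses the tag back out of these names.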
def load_tagger_metadata():
    """Load the camie-tagger-v2-metadata.json file from parent directory."""
    try:
        # Look for metadata file in parent directory
        current_dir = os.path.dirname(os.path.abspath(__file__))
        parent_dir = os.path.dirname(current_dir)
        metadata_path = os.path.join(parent_dir, "camie-tagger-v2-metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            print(f"Loaded tagger metadata from: {metadata_path}")
            return metadata
        else:
            print(f"Metadata file not found at: {metadata_path}")
            return None
    except Exception as e:
        print(f"Error loading tagger metadata: {e}")
        return None


@lru_cache(maxsize=1)
def _load_tagger_metadata_cached():
    meta = load_tagger_metadata()
    return meta or {}


def resolve_tag_index(tag, dataset=None):
    """Robustly resolve tag -> index using dataset, session metadata, then camie-tagger-v2-metadata.json."""
    if not isinstance(tag, str):
        return int(tag)

    # normalize variants
    cands = {tag.strip(), tag.strip().replace(" ", "_")}
    cands |= {c.lower() for c in list(cands)}

    # 1) dataset.tag_to_idx
    if dataset is not None and hasattr(dataset, "tag_to_idx"):
        for c in cands:
            if c in dataset.tag_to_idx:
                return int(dataset.tag_to_idx[c])

    # 2) session metadata
    sm = getattr(st.session_state, "metadata", {}) or {}
    m = sm.get("tag_to_idx", {}) if isinstance(sm, dict) else {}
    for c in cands:
        if c in m:
            return int(m[c])

    # 3) JSON metadata (cached)
    meta = _load_tagger_metadata_cached()
    mjson = (meta.get("dataset_info", {})
                 .get("tag_mapping", {})
                 .get("tag_to_idx", {})) if isinstance(meta, dict) else {}
    for c in cands:
        if c in mjson:
            return int(mjson[c])
    return None
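
# Resolution sketch (indices are hypothetical): string lookups try the raw tag,
# an underscored variant, and lowercase forms of both, across the three sources
# in order.
# >>> resolve_tag_index("Hatsune Miku")  # tries "Hatsune Miku", "Hatsune_Miku", lowercased
# 1234
# >>> resolve_tag_index(987)             # non-strings pass straight through
# 987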
def generate_essence_for_tag(tag, model, dataset, custom_settings=None):
    """
    Generate an essence image for a specific tag using the ViT generator

    Args:
        tag: The tag name or index
        model: The ViT model to use
        dataset: The dataset containing tag information
        custom_settings: Optional dictionary with custom generation settings

    Returns:
        PIL Image of the generated essence, score, quality level
    """
    print(f"\n=== Starting ViT essence generation for tag '{tag}' ===")

    # Check if tag is discovered or a manual tag
    is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags
    is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags
    if not is_discovered and not is_manual_tag:
        st.error(f"Tag '{tag}' has not been discovered yet.")
        return None, 0, None

    # Get tag rarity and calculate cost
    if is_discovered:
        rarity = st.session_state.discovered_tags[tag].get("rarity", "Canard")
    elif is_manual_tag:
        rarity = st.session_state.manual_tags[tag].get("rarity", "Canard")
    else:
        rarity = "Canard"

    # Calculate cost based on rarity
    cost = get_essence_cost(rarity)

    # Check if player has enough Enkephalin
    if st.session_state.enkephalin < cost:
        st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.")
        return None, 0, None

    # Use provided settings or defaults
    settings = custom_settings or DEFAULT_ESSENCE_SETTINGS.copy()
    print(f"Using settings: {settings}")

    # UI containers for progress
    preview_container = st.empty()
    progress_container = st.empty()
    message_container = st.empty()

    try:
        message_container.info(f"Generating ViT essence for '{tag}' with {settings.get('layer_emphasis', 'balanced')} layer emphasis...")

        # Progress callback function
        def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score):
            progress = iter_idx / iter_count
            progress_container.progress(progress, f"Iteration {iter_idx}/{iter_count}")
            message_container.info(f"Current score: {score:.4f}")
            if iter_idx % 50 == 0:
                print(f"Progress: Iteration {iter_idx}/{iter_count}, Score: {score:.4f}")

        # Convert tag name to index
        tag_idx = None
        if isinstance(tag, str):
            tag_idx = resolve_tag_index(tag, dataset)
            if tag_idx is None:
                st.error(
                    f"Tag '{tag}' index not found in dataset or metadata. "
                    f"Make sure it exists in camie-tagger-v2-metadata.json."
                )
                return None, 0, None
        else:
            tag_idx = int(tag)
        print(f"Resolved tag '{tag}' -> index {tag_idx}")

        # Create tag-to-name mapping
        tag_to_name = {tag_idx: tag}

        # Get or create Torch model specifically for essence generation
        torch_model = getattr(st.session_state, "model_torch", None)
        if not isinstance(torch_model, nn.Module):
            # safetensors lives ONE directory above this file
            current_dir = os.path.dirname(os.path.abspath(__file__))
            parent_dir = os.path.dirname(current_dir)
            ckpt = os.path.join(parent_dir, "camie-tagger-v2.safetensors")
            if not os.path.exists(ckpt):
                st.error(f"Missing safetensors checkpoint at: {ckpt}")
                return None, 0, None

            # metadata-driven sizes
            meta = _load_tagger_metadata_cached()
            num_tags = int(meta.get("dataset_info", {}).get("total_tags", 70527))
            img_size = int(meta.get("model_info", {}).get("img_size", 512))

            torch_model = build_torch_model_from_safetensors(
                ckpt_path=ckpt, num_tags=num_tags,
                backbone="vit_base_patch16_384", img_size=img_size
            )
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            torch_model = torch_model.to(device).eval()
            st.session_state.model_torch = torch_model  # cache for later

        # Create ViT essence generator with settings from UI
        generator = ViTEssenceGenerator(
            model=torch_model,
            tag_to_name=tag_to_name,
            iterations=settings.get('iterations', 500),
            learning_rate=settings.get('lr', 0.05),
            layer_emphasis=settings.get('layer_emphasis', 'balanced'),
            ensemble_K=settings.get('ensemble_k', 8),
            tv_weight=settings.get('tv_weight', 1e-3)
        )

        image, score = generator.generate_essence(
            tag_idx=tag_idx,
            neighbor_count=settings.get('neighbor_count', 8),
            image_size=settings.get('image_size', 512),
            return_score=True,
            progress_callback=progress_callback
        )

        # Determine quality level
        quality_level = get_quality_level(score)

        # Deduct enkephalin cost
        st.session_state.enkephalin -= cost
        st.session_state.game_stats["enkephalin_spent"] = st.session_state.game_stats.get("enkephalin_spent", 0) + cost

        # Increment essence counter
        st.session_state.game_stats["essences_generated"] = st.session_state.game_stats.get("essences_generated", 0) + 1

        # Save to persistent location
        filepath = save_essence_to_game_folder(image, tag, score, quality_level)

        # Update UI with result
        preview_container.image(image, caption=f"ViT Essence of '{tag}' - Quality: {quality_level}", width=400)

        # Clear progress elements
        progress_container.empty()
        message_container.empty()

        # Store in session state
        if 'generated_essences' not in st.session_state:
            st.session_state.generated_essences = {}
        st.session_state.generated_essences[tag] = {
            "path": filepath,
            "score": score,
            "quality": quality_level,
            "rarity": rarity,
            "settings": settings,
            "generated_time": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        # Show success message
        st.success(f"Successfully generated {quality_level} ViT essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}")
        print(f"=== ViT essence generation complete for '{tag}' ===\n")

        # Save state
        tag_storage.save_essence_state(session_state=st.session_state)

        return image, score, quality_level

    except Exception as e:
        st.error(f"Error generating ViT essence: {str(e)}")
        print(f"EXCEPTION in generate_essence_for_tag: {str(e)}")
        import traceback
        err_traceback = traceback.format_exc()
        print(err_traceback)
        st.code(err_traceback)
        return None, 0, None


# Utility Functions for Model Analysis and Layer Selection

def get_model_layers(model):
    """Utility function to get all available layers in a model."""
    layers = []
    for name, _ in model.named_modules():
        if name:  # Skip empty name (the model itself)
            layers.append(name)
    return layers


def get_key_layers(model, max_layers=15):
    """
    Get a curated list of the most relevant layers for visualization.
    """
    all_layers = get_model_layers(model)

    # For models with hundreds of layers, we need to be selective
    if len(all_layers) > 30:
        # Extract patterns to identify layer types
        block_patterns = {}

        # Find common patterns in layer names
        for layer in all_layers:
            # Extract the main component (e.g., "backbone.features")
            parts = layer.split(".")
            if len(parts) >= 2:
                prefix = ".".join(parts[:2])
                if prefix not in block_patterns:
                    block_patterns[prefix] = []
                block_patterns[prefix].append(layer)

        # Now select representative layers from each major block
        key_layers = {
            "early": [],
            "middle": [],
            "late": []
        }

        # For each major block, select layers at strategic positions
        for prefix, layers in block_patterns.items():
            if len(layers) > 3:  # Only process significant blocks
                # Sort by natural depth (assuming numerical components indicate depth)
                layers.sort(key=lambda x: [int(s) if s.isdigit() else s for s in re.findall(r'\d+|\D+', x)])

                # Get layers at strategic positions
                early = layers[0]
                middle = layers[len(layers) // 2]
                late = layers[-1]

                key_layers["early"].append(early)
                key_layers["middle"].append(middle)
                key_layers["late"].append(late)

        # Ensure we don't have too many layers
        # If we need to reduce further, prioritize middle and late layers
        flattened = []
        for _, group_layers in key_layers.items():
            flattened.extend(group_layers)

        if len(flattened) > max_layers:
            # Prioritize keeping late layers (for character recognition)
            late_count = min(len(key_layers["late"]), max_layers // 3)

            # Allocate remaining slots between early and middle
            remaining = max_layers - late_count
            middle_count = min(len(key_layers["middle"]), remaining // 2)
            early_count = min(len(key_layers["early"]), remaining - middle_count)

            # Take only the needed number from each category
            key_layers["early"] = key_layers["early"][:early_count]
            key_layers["middle"] = key_layers["middle"][:middle_count]
            key_layers["late"] = key_layers["late"][:late_count]
    else:
        # For simpler models, use standard distribution
        n = len(all_layers)
        key_layers = {
            "early": all_layers[:n // 3][:3],             # First few layers
            "middle": all_layers[n // 3:2 * n // 3][:4],  # Middle layers
            "late": all_layers[2 * n // 3:][:3]           # Last few layers
        }

    # Try to identify the classifier/final layer
    classifier_layers = [layer for layer in all_layers
                         if any(x in layer.lower() for x in ["classifier", "fc", "linear", "output", "logits", "head"])]
    if classifier_layers:
        key_layers["classifier"] = [classifier_layers[-1]]

    return key_layers
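
# Shape of the result (sketch; module names are hypothetical and depend on the
# model — each entry is resolvable with model.get_submodule(name)):
# {
#     "early": ["backbone.blocks.0", ...],
#     "middle": ["backbone.blocks.6", ...],
#     "late": ["backbone.blocks.11", ...],
#     "classifier": ["head"],  # present only if a classifier-like layer is found
# }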
""" key_layers = get_key_layers(model) # Flatten all layers for reference all_key_layers = [] for layers in key_layers.values(): all_key_layers.extend(layers) # Choose layers based on the requested emphasis if layer_type == "low": # Focus on early visual features (textures, patterns, colors) selected = key_layers.get("early", []) # Add one middle layer for stability if "middle" in key_layers and key_layers["middle"]: selected.append(key_layers["middle"][0]) elif layer_type == "mid": # Focus on mid-level features (parts, components) selected = key_layers.get("middle", []) # Add one early layer for context if "early" in key_layers and key_layers["early"]: selected.append(key_layers["early"][-1]) elif layer_type == "high": # Focus on high-level semantic features (objects, characters) selected = key_layers.get("late", []) selected.extend(key_layers.get("classifier", [])) # Add one middle layer for context if "middle" in key_layers and key_layers["middle"]: selected.append(key_layers["middle"][-1]) else: # balanced # Use a mix of early, middle and late layers selected = [] for category in ["early", "middle", "late", "classifier"]: if category in key_layers and key_layers[category]: # Take one from each category selected.append(key_layers[category][0]) # For middle and late, also take the last one if different if category in ["middle", "late"] and len(key_layers[category]) > 1: selected.append(key_layers[category][-1]) # Ensure we have at least one layer if not selected and all_key_layers: selected = [all_key_layers[-1]] # Use the last layer as fallback return selected # Game UI and Integration Functions def display_essence_generator(): """ Display the essence generator interface """ # Initialize settings initialize_essence_settings() st.title("🎨 Tag Essence Generator") st.write("Generate visual representations of what the AI model recognizes for specific tags.") # Add detailed explanation of what essences are for with st.expander("What are Tag Essences & How to Use Them", expanded=True): st.markdown(""" ### 💡 Understanding Tag Essences Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy! **How to use Tag Essences:** 1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library) 2. **Save the essence image** to your computer 3. **Upload the essence image** back into the tagger 4. The tagger will **almost always detect the original tag** 5. It will often also **detect related rare tags** from the same category **Strategic Value:** - Character essences can help unlock other tags associated with that character - Category essences can help discover rare tags within that category - High-quality essences (WAW, ALEPH) have the strongest effect **This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning! """) # Check for model availability model_available = hasattr(st.session_state, 'model') if not model_available: st.warning("Model not available. 
    # Create tabs for the different sections
    tabs = st.tabs(["Generate Essence", "My Essences"])

    with tabs[0]:
        # Check for pending generation from previous interaction
        if hasattr(st.session_state, 'selected_tag') and st.session_state.selected_tag:
            tag = st.session_state.selected_tag
            st.subheader(f"Generating Essence for '{tag}'")

            # Generate the essence
            image, score, quality = generate_essence_for_tag(
                tag,
                st.session_state.model,
                st.session_state.model.dataset,
                st.session_state.essence_custom_settings
            )

            # Show usage tips if successful
            if image is not None:
                with st.expander("Essence Usage", expanded=True):
                    st.markdown("""
                    💡 **Tag Essence Usage Tips:**
                    1. Look for similar patterns, colors, and elements in real images
                    2. The essence reveals what features the AI model recognizes for this tag
                    3. Use this as inspiration when creating or finding images to get this tag
                    """)
            else:
                st.error("Essence generation failed. Please check the error messages above and try again with different settings.")

            # Clear selected tag
            st.session_state.selected_tag = None
        else:
            # Show the interface to select a tag
            selected_tag = display_essence_generation_interface(model_available)

            # If a tag was selected, store it for the next run and rerun
            if selected_tag:
                st.session_state.selected_tag = selected_tag
                st.rerun()

    with tabs[1]:
        display_saved_essences()


def essence_folder_path():
    """Get the path to the essence folder, creating it if necessary."""
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    game_data_dir = os.path.join(base_dir, "game_data")
    essence_folder = os.path.join(game_data_dir, "essences")

    # Make sure all directories exist
    os.makedirs(game_data_dir, exist_ok=True)
    os.makedirs(essence_folder, exist_ok=True)
    return essence_folder


def display_saved_essences():
    """Display the user's saved essence images."""
    st.subheader("My Generated Essences")

    if not hasattr(st.session_state, 'generated_essences') or not st.session_state.generated_essences:
        st.info("You haven't generated any essences yet. Go to the Generate tab to create some!")
        return

    # Add usage instructions at the top
    st.markdown("""
    ### How to Use Your Essences
    1. **Click on any essence image** to open it in full size
    2. **Save the image** to your computer (right-click → Save image)
    3. **Go to the Scan Images tab** and upload the saved essence
    4. The tagger will likely detect the original tag and potentially related rare tags!

    Higher quality essences (WAW, ALEPH) generally produce the best results.
    """)
""") # Get the essence folder path essence_dir = essence_folder_path() # Try to locate any missing files for tag, info in st.session_state.generated_essences.items(): if "path" in info and not os.path.exists(info["path"]): # Try to find the file in the essence directory quality = info.get("quality", "ZAYIN") quality_dir = os.path.join(essence_dir, quality) if os.path.exists(quality_dir): # Check for files with this tag name safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_') matching_files = [f for f in os.listdir(quality_dir) if f.startswith(safe_tag)] if matching_files: # Use the most recent file if there are multiple matching_files.sort(reverse=True) info["path"] = os.path.join(quality_dir, matching_files[0]) print(f"Reconnected essence for {tag} to {info['path']}") # List essences by quality level essences_by_quality = {} for tag, info in st.session_state.generated_essences.items(): quality = info.get("quality", "ZAYIN") # Default to lowest if not set if quality not in essences_by_quality: essences_by_quality[quality] = [] essences_by_quality[quality].append((tag, info)) # Check if any essences exist on disk but are not tracked in session state try: untracked_essences = {} for quality in ESSENCE_QUALITY_LEVELS.keys(): quality_dir = os.path.join(essence_dir, quality) if os.path.exists(quality_dir): essence_files = os.listdir(quality_dir) # Filter to only show PNG files essence_files = [f for f in essence_files if f.lower().endswith('.png')] if essence_files: # Check if any of these files aren't in our tracked essences for filename in essence_files: # Extract tag name from filename parts = filename.split('_') if len(parts) >= 2: tag = parts[0].replace('_', ' ') # Check if file is already tracked is_tracked = False for tracked_tag, tracked_info in st.session_state.generated_essences.items(): if "path" in tracked_info and os.path.basename(tracked_info["path"]) == filename: is_tracked = True break if not is_tracked: if quality not in untracked_essences: untracked_essences[quality] = [] untracked_essences[quality].append((tag, { "path": os.path.join(quality_dir, filename), "quality": quality, "discovered_on_disk": True })) except Exception as e: print(f"Error checking for untracked essences: {e}") # Combine tracked and untracked essences for quality, essences in untracked_essences.items(): if quality not in essences_by_quality: essences_by_quality[quality] = [] for tag, info in essences: # Only add if we don't already have this tag in this quality level if not any(tracked_tag == tag for tracked_tag, _ in essences_by_quality[quality]): essences_by_quality[quality].append((tag, info)) # Show essences from highest to lowest quality for quality in list(ESSENCE_QUALITY_LEVELS.keys())[::-1]: if quality in essences_by_quality: essences = essences_by_quality[quality] color = ESSENCE_QUALITY_LEVELS[quality]["color"] with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ["ALEPH", "WAW"]): # Create grid layout cols = st.columns(3) for i, (tag, info) in enumerate(sorted(essences, key=lambda x: x[1].get("score", 0), reverse=True)): col_idx = i % 3 with cols[col_idx]: try: # Try to load the image from path if "path" in info and os.path.exists(info["path"]): image = Image.open(info["path"]) rarity = info.get("rarity", "Canard") score = info.get("score", 0) # Get color for rarity rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") # Display the image with metadata st.image(image, caption=tag, use_container_width=True) # Use special styling for rare 
                                # Show quality/rarity metadata, colored per rarity
                                st.markdown(
                                    f"<span style='color:{color}'>{quality}</span> | "
                                    f"<span style='color:{rarity_color}'>{rarity}</span> | "
                                    f"Score: {score:.2f}",
                                    unsafe_allow_html=True
                                )

                                # Add file info
                                if "discovered_on_disk" in info and info["discovered_on_disk"]:
                                    st.info("Found on disk (not in session state)")

                                # Add button to open folder
                                if st.button(f"Open Folder", key=f"open_folder_{tag}_{quality}"):
                                    folder_path = os.path.dirname(info["path"])
                                    try:
                                        # Try different methods to open folder based on platform
                                        if os.name == 'nt':  # Windows
                                            os.startfile(folder_path)
                                        elif os.name == 'posix':  # macOS or Linux
                                            import subprocess
                                            if 'darwin' in os.sys.platform:  # macOS
                                                subprocess.call(['open', folder_path])
                                            else:  # Linux
                                                subprocess.call(['xdg-open', folder_path])
                                        st.success(f"Opened folder: {folder_path}")
                                    except Exception as e:
                                        st.error(f"Could not open folder: {str(e)}")
                                        # Provide the path for manual navigation
                                        st.code(folder_path)
                            else:
                                # Could not find image
                                st.warning(f"Image file not found: {info.get('path', 'No path available')}")

                                # Show quality and tag name
                                st.markdown(
                                    f"<span style='color:{color}'>{quality}</span> | {tag}",
                                    unsafe_allow_html=True
                                )

                                # Only add reconnect button if we have some metadata
                                if "rarity" in info and "score" in info:
                                    if st.button(f"Reconnect File", key=f"reconnect_{tag}_{quality}"):
                                        # Update path in session state
                                        safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
                                        score = info.get("score", 0)
                                        quality_dir = os.path.join(essence_dir, quality)

                                        # Create directory if it doesn't exist
                                        os.makedirs(quality_dir, exist_ok=True)

                                        # Set a path - user will need to manually add the image
                                        timestamp = time.strftime("%Y%m%d_%H%M%S")
                                        filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
                                        info["path"] = os.path.join(quality_dir, filename)

                                        st.info(f"Please save your image to this location: {info['path']}")
                                        st.session_state.generated_essences[tag] = info
                                        tag_storage.save_essence_state(session_state=st.session_state)
                                        st.rerun()
                        except Exception as e:
                            st.write(f"Error loading {tag}: {str(e)}")

    # Add option to clean up missing files
    st.divider()
    if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"):
        # Find all entries with missing files
        to_remove = []
        for tag, info in st.session_state.generated_essences.items():
            if "path" in info and not os.path.exists(info["path"]):
                to_remove.append(tag)

        # Remove them
        for tag in to_remove:
            del st.session_state.generated_essences[tag]

        # Save state
        tag_storage.save_essence_state(session_state=st.session_state)

        if to_remove:
            st.success(f"Removed {len(to_remove)} entries with missing files")
        else:
            st.success("No missing files found")
        st.rerun()


def display_essence_generation_interface(model_available):
    """Display the interface for generating new essences."""
    # Initialize manual tags
    initialize_manual_tags()

    st.subheader("Generate Tag Essence")
    st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.")
    # Settings column
    col1, col2 = st.columns(2)

    with col1:
        st.write("Generation Settings:")

        # Add reset button
        if st.button("Reset to Defaults", help="Clear saved settings and use default values"):
            st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()
            tag_storage.save_essence_state(session_state=st.session_state)
            st.success("Settings reset to defaults!")
            st.rerun()

        # Advanced settings with better organization
        with st.expander("Advanced Settings", expanded=True):
            col_a, col_b = st.columns(2)

            with col_a:
                # Core generation parameters
                st.write("**Core Parameters**")
                iterations = st.slider(
                    "Iterations",
                    min_value=64, max_value=2048,
                    value=st.session_state.essence_custom_settings.get("iterations", 500),
                    step=64,
                    help="More iterations improve quality but take longer"
                )
                lr = st.slider(
                    "Learning Rate",
                    min_value=0.01, max_value=0.2,
                    value=st.session_state.essence_custom_settings.get("lr", 0.05),
                    step=0.01,
                    help="Higher learning rates converge faster but may be less stable"
                )
                ensemble_k = st.slider(
                    "Ensemble Size",
                    min_value=1, max_value=16,
                    value=st.session_state.essence_custom_settings.get("ensemble_k", 8),
                    step=1,
                    help="Number of augmented versions per iteration. Higher = more stable but slower"
                )

            with col_b:
                # Multi-tag parameters
                st.write("**Multi-Tag Enhancement**")
                neighbor_count = st.slider(
                    "Neighbor Tags",
                    min_value=0, max_value=16,
                    value=st.session_state.essence_custom_settings.get("neighbor_count", 8),
                    step=1,
                    help="Number of similar/dissimilar tags to consider. 0 = only target tag"
                )
                tv_weight = st.select_slider(
                    "Smoothness",
                    options=[1e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2],
                    value=st.session_state.essence_custom_settings.get("tv_weight", 1e-3),
                    format_func=lambda x: f"{x:.0e}",
                    help="Higher values create smoother, less noisy images"
                )

            # Layer emphasis selection (restore the saved choice instead of always defaulting)
            emphasis_options = ["balanced", "early", "mid", "late"]
            saved_emphasis = st.session_state.essence_custom_settings.get("layer_emphasis", "balanced")
            layer_emphasis = st.selectbox(
                "Feature Targeting",
                options=emphasis_options,
                index=emphasis_options.index(saved_emphasis) if saved_emphasis in emphasis_options else 0,
                format_func=lambda x: {
                    "balanced": "Balanced (mix of features)",
                    "early": "Early (textures, patterns)",
                    "mid": "Mid (parts, components)",
                    "late": "Late (characters, objects)"
                }.get(x, x),
                help="Controls which model features to emphasize"
            )

        # Save settings
        st.session_state.essence_custom_settings = {
            "iterations": iterations,
            "lr": lr,
            "ensemble_k": ensemble_k,
            "neighbor_count": neighbor_count,
            "image_size": 512,  # Fixed for now
            "layer_emphasis": layer_emphasis,
            "tv_weight": tv_weight
        }

        # Show current settings summary
        st.info(f"""
        **Current Settings:**
        - Iterations: {iterations}
        - Learning Rate: {lr}
        - Ensemble Size: {ensemble_k}
        - Neighbor Tags: {neighbor_count}
        - Feature Focus: {layer_emphasis.capitalize()}
        """)

    with col2:
        # Show quality level descriptions
        st.write("Quality Levels:")
        for level, info in ESSENCE_QUALITY_LEVELS.items():
            st.markdown(f"""