Spaces:
Running
Running
#!/usr/bin/env python3 | |
""" | |
Vision Transformer Essence Generator for Tag Collector Game | |
Based on "What do Vision Transformers Learn? A Visual Exploration" | |
""" | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torchvision.transforms.functional import to_pil_image | |
from PIL import Image | |
import numpy as np | |
import os | |
import re | |
import math | |
import json | |
import timm | |
import streamlit as st | |
from tqdm import tqdm | |
from scipy.ndimage import gaussian_filter | |
from functools import wraps, lru_cache | |
from safetensors.torch import load_file | |
import time | |
import tag_storage # Import for saving game state | |
from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON | |
from tag_categories import TAG_CATEGORIES | |
# Force synchronous CUDA kernel launches and deterministic, non-benchmarked
# cuDNN algorithm selection for reproducible essence generation.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Essence quality tiers as (name, minimum score, UI colour, description),
# listed in ascending threshold order.
_QUALITY_LEVEL_TABLE = (
    ("ZAYIN", 0.0, "#1CFC00", "Basic representation with minimal details."),
    ("TETH", 3.0, "#389DDF", "Clear representation with recognizable features."),
    ("HE", 5.0, "#FEF900", "Refined representation with distinctive elements."),
    ("WAW", 10.0, "#7930F1", "Advanced representation with precise details."),
    ("ALEPH", 12.0, "#FF0000", "Perfect representation with extraordinary precision."),
)
ESSENCE_QUALITY_LEVELS = {
    level: {"threshold": threshold, "color": color, "description": description}
    for level, threshold, color, description in _QUALITY_LEVEL_TABLE
}

# Enkephalin price of generating an essence, keyed by tag rarity name.
ESSENCE_COSTS = {
    "Special": 0,
    "Canard": 100,
    "Urban Myth": 125,
    "Urban Legend": 150,
    "Urban Plague": 200,
    "Urban Nightmare": 250,
    "Star of the City": 300,
    "Impuritas Civitas": 400,
}

# Default generation settings (these keys are read when an essence is generated).
DEFAULT_ESSENCE_SETTINGS = {
    "iterations": 256,              # optimization steps
    "lr": 0.05,                     # initial Adam learning rate
    "ensemble_k": 8,                # augmented copies evaluated per step
    "neighbor_count": 8,            # similar / dissimilar tags added to the objective
    "image_size": 512,              # side length of the generated image
    "layer_emphasis": "balanced",   # which ViT blocks receive feature hooks
    "tv_weight": 1e-3,              # total-variation penalty weight
}
def initialize_essence_settings():
    """Create ``st.session_state.essence_custom_settings`` once per session.

    If ``tag_storage.load_essence_state()`` returns saved settings, they are
    merged over ``DEFAULT_ESSENCE_SETTINGS``:
    - only keys present in the defaults are kept;
    - a ``layer_emphasis`` outside balanced/early/mid/late keeps the default.

    Otherwise a fresh copy of the defaults is stored. When the session key
    already exists, nothing happens.
    """
    if 'essence_custom_settings' in st.session_state:
        return

    saved = tag_storage.load_essence_state()
    if not (saved and 'essence_custom_settings' in saved):
        st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()
        return

    previous = saved['essence_custom_settings']
    valid_emphasis = ['balanced', 'early', 'mid', 'late']
    merged = DEFAULT_ESSENCE_SETTINGS.copy()
    for name in DEFAULT_ESSENCE_SETTINGS:
        if name not in previous:
            continue
        value = previous[name]
        if name == 'layer_emphasis' and value not in valid_emphasis:
            continue  # unknown emphasis: leave the default in place
        merged[name] = value
    st.session_state.essence_custom_settings = merged
def initialize_manual_tags():
    """Create ``st.session_state.manual_tags`` once per session.

    Uses the ``manual_tags`` entry of ``tag_storage.load_essence_state()``
    when available; otherwise seeds a single built-in example tag.
    """
    if 'manual_tags' in st.session_state:
        return

    saved = tag_storage.load_essence_state()
    has_saved_tags = bool(saved) and 'manual_tags' in saved
    if has_saved_tags:
        tags = saved['manual_tags']
    else:
        tags = {
            "hatsune_miku": {"rarity": "Special", "description": "Popular virtual singer with long teal twin-tails"},
        }
    st.session_state.manual_tags = tags
def timeout(seconds, fallback_value=None):
    """Decorator that warns when the wrapped call takes longer than *seconds*.

    Despite its name this does not interrupt or cancel the call: it always
    returns the wrapped function's result (or propagates its exception) and
    only prints a warning when the elapsed time exceeds the budget.

    Args:
        seconds: soft time budget in seconds.
        fallback_value: accepted for API compatibility; currently unused.
    """
    def decorator(func):
        @wraps(func)  # keep __name__/__doc__ of the decorated function
        def wrapper(*args, **kwargs):
            # perf_counter is monotonic, so wall-clock adjustments cannot skew the measurement.
            start_time = time.perf_counter()
            try:
                return func(*args, **kwargs)
            finally:
                # Report slow calls even when they end in an exception.
                elapsed = time.perf_counter() - start_time
                if elapsed > seconds:
                    print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)")
        return wrapper
    return decorator
class TaggerTorch(nn.Module):
    """timm ViT backbone followed by a single linear tag head.

    Args:
        backbone_name: timm model name. It is created with ``pretrained=False``
            and ``num_classes=0`` so the backbone yields features, not logits.
        img_size: input resolution passed to timm.
        num_tags: number of output logits (rows of ``head``).
        normalize: when True, ``forward`` applies ImageNet mean/std
            normalization itself (inputs presumably in [0, 1]; confirm at
            call sites).
    """

    def __init__(self, backbone_name="vit_base_patch16_384", img_size=512, num_tags=70527, normalize=True):
        super().__init__()
        self.backbone = timm.create_model(
            backbone_name,
            pretrained=False,
            num_classes=0,
            img_size=img_size,
        )
        feature_dim = self.backbone.num_features
        self.head = nn.Linear(feature_dim, num_tags)

        self.normalize = normalize
        if self.normalize:
            imagenet_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
            imagenet_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
            self.register_buffer("mean", imagenet_mean)
            self.register_buffer("std", imagenet_std)

    def forward(self, x):
        """Return tag logits for image batch *x*."""
        if self.normalize:
            x = (x - self.mean) / self.std
        features = self.backbone.forward_features(x)
        # Token-shaped output ([B, tokens, C]) is reduced to its first token
        # (the CLS token for standard timm ViTs); 2-D output is used as is.
        pooled = features[:, 0, :] if features.ndim == 3 else features
        return self.head(pooled)
def _remap_backbone_keys(sd): | |
out = {} | |
for k, v in sd.items(): | |
if k.startswith("module."): k = k[7:] | |
# collapse ImageTagger β TaggerTorch vit paths | |
if k.startswith("backbone.vit."): | |
k = "backbone." + k[len("backbone.vit."):] | |
elif k.startswith("vit."): | |
k = "backbone." + k[len("vit."):] | |
elif k.startswith(("pos_embed","patch_embed.","blocks.","norm.","cls_token")): | |
k = "backbone." + k | |
out[k] = v | |
return out | |
def _get_logits_from_output(out): | |
if isinstance(out, dict): | |
return out.get("refined_predictions") or out.get("initial_predictions") | |
return out | |
def build_torch_model_from_safetensors(ckpt_path, num_tags, backbone="vit_base_patch16_384", img_size=512):
    """Create a ``TaggerTorch`` and fill it from a safetensors checkpoint.

    Loading proceeds as follows:
    - checkpoint keys are normalized with ``_remap_backbone_keys`` and loaded
      with ``strict=False``; the first 20 missing/unexpected keys are printed;
    - the checkpoint's ``tag_embedding.weight`` / ``tag_bias`` are copied into
      ``head.weight`` / ``head.bias`` when their shapes match;
    - a shape mismatch, or a head that received no checkpoint weights at all,
      is reported with a warning instead of being silently ignored.

    Args:
        ckpt_path: path to the ``.safetensors`` file.
        num_tags: number of tag logits (rows of the head).
        backbone: timm model name for the ViT backbone.
        img_size: input resolution for the backbone.

    Returns:
        The ``TaggerTorch`` model. It is not moved to a device or switched to
        eval mode here.
    """
    model = TaggerTorch(backbone_name=backbone, img_size=img_size, num_tags=num_tags, normalize=True)
    sd = _remap_backbone_keys(load_file(ckpt_path))

    # Pull the tag classifier out of the state dict; it is copied into the linear head below.
    te_w = sd.pop("tag_embedding.weight", sd.pop("module.tag_embedding.weight", None))
    te_b = sd.pop("tag_bias", sd.pop("module.tag_bias", None))

    # Load backbone and everything else that matches by name.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("[load] missing:", missing[:20], "...")
    print("[load] unexpected:", unexpected[:20], "...")

    with torch.no_grad():
        head = model.head
        if te_w is not None:
            if te_w.shape == head.weight.shape:
                head.weight.copy_(te_w)
                print("[load] copied tag_embedding.weight -> head.weight")
            else:
                print(
                    f"[load] WARNING: tag_embedding.weight shape {tuple(te_w.shape)} does not match "
                    f"head.weight shape {tuple(head.weight.shape)}; head.weight NOT copied"
                )
        elif "head.weight" in missing:
            # Neither a tag embedding nor a direct head.weight was found, so the
            # head keeps nn.Linear's default initialization and logits are meaningless.
            print("[load] WARNING: checkpoint has neither tag_embedding.weight nor head.weight; head is untrained")

        if te_b is not None and head.bias is not None:
            if te_b.shape == head.bias.shape:
                head.bias.copy_(te_b)
                print("[load] copied tag_bias -> head.bias")
            else:
                print(
                    f"[load] WARNING: tag_bias shape {tuple(te_b.shape)} does not match "
                    f"head.bias shape {tuple(head.bias.shape)}; head.bias NOT copied"
                )
    return model
def _get_classifier_matrix(model): | |
# [T, C] β works for both ImageTagger and TaggerTorch | |
if hasattr(model, "tag_embedding"): | |
return model.tag_embedding.weight.detach() | |
if hasattr(model, "head") and hasattr(model.head, "weight"): | |
return model.head.weight.detach() | |
raise AttributeError("Model has neither tag_embedding nor head.weight") | |
def neighbor_sets_from_embedding(model, class_idx, k_pos=8, k_neg=8): | |
""" | |
Returns (pos_idx, pos_sims, neg_idx, neg_sims) | |
pos: highest cosine neighbors (exclude self) | |
neg: lowest cosine neighbors (most dissimilar) | |
""" | |
W = _get_classifier_matrix(model) # [T, C] | |
Wn = F.normalize(W, dim=1) | |
q = Wn[class_idx:class_idx+1] # [1, C] | |
sims = (q @ Wn.T).squeeze(0) # [T] | |
sims[class_idx] = -9e9 # mask self | |
# positives: largest similarities | |
pos_vals, pos_idx = torch.topk(sims, k=min(k_pos, sims.numel()-1)) | |
# negatives: smallest similarities (most negative / least similar) | |
neg_vals, neg_idx = torch.topk(-sims, k=min(k_neg, sims.numel()-1)) | |
neg_vals = -neg_vals | |
# weights (clip for stability) | |
pos_w = torch.clamp(pos_vals, 0.0, 1.0).tolist() | |
neg_w = torch.clamp(neg_vals.abs(), 0.0, 1.0).tolist() | |
return pos_idx.tolist(), pos_w, neg_idx.tolist(), neg_w | |
def weighted_class_objective(logits, main_idx,
                             plus_idxs=(), plus_w=None, alpha=0.25,
                             minus_idxs=(), minus_w=None, beta=0.15):
    """Scalar objective built from the target logit and neighbour logits.

    score = mean_b logits[b, main_idx]
          + alpha * mean_b sum_j wp_j * logits[b, plus_idxs[j]]
          - beta  * mean_b sum_j wn_j * logits[b, minus_idxs[j]]

    Weight lists are normalized to sum to ~1. They are uniform when the
    corresponding ``*_w`` is None or empty. A neighbour term is skipped
    entirely when its index list is empty.
    """
    def _weighted_neighbour_mean(idxs, weights):
        raw = weights or [1.0] * len(idxs)
        w = torch.tensor(raw, device=logits.device, dtype=logits.dtype)
        w = w / (w.sum() + 1e-8)
        return (logits[:, idxs] * w).sum(dim=1).mean()

    score = logits[:, main_idx].mean()
    if plus_idxs:
        score = score + alpha * _weighted_neighbour_mean(plus_idxs, plus_w)
    if minus_idxs:
        score = score - beta * _weighted_neighbour_mean(minus_idxs, minus_w)
    return score
def idx_to_name(idx, dataset=None):
    """Map a tag index to its name.

    Lookup order:
    1. ``dataset.idx_to_tag`` (int keys), when *dataset* provides it;
    2. otherwise ``dataset_info.tag_mapping.idx_to_tag`` (string keys) from
       the cached tagger metadata.

    Returns ``"Tag {idx}"`` when the index is not found.
    """
    fallback = f"Tag {idx}"
    if dataset is not None and hasattr(dataset, "idx_to_tag"):
        return dataset.idx_to_tag.get(int(idx), fallback)

    mapping = _load_tagger_metadata_cached()
    for section in ("dataset_info", "tag_mapping", "idx_to_tag"):
        mapping = mapping.get(section, {})
    return mapping.get(str(int(idx)), fallback)
# Core Classes for ViT Essence Generation | |
class ViTLayerHook:
    """Forward hook that keeps the most recent output of one module.

    The hook is registered on construction. ``features`` is None until the
    module runs (or after a caller resets it); ``close()`` deregisters the hook.
    """

    def __init__(self, layer, layer_name):
        self.layer, self.layer_name = layer, layer_name
        self.features = None
        self.hook = layer.register_forward_hook(self.hook_fn)

    def hook_fn(self, module, input, output):
        """Keep a reference to *output*; its autograd graph is left intact."""
        self.features = output

    def close(self):
        """Remove the forward hook from the module."""
        self.hook.remove()
class ViTFeatureAnalyzer:
    """Locate MLP activation modules in a ViT and choose which to visualize."""

    def __init__(self, model):
        self.model = model
        # qualified module name -> {'type': str, 'module': nn.Module, 'block_idx': int}
        self.layer_info = self._analyze_architecture()

    def _analyze_architecture(self):
        """Walk the module tree and record activation modules inside MLP/FFN blocks.

        Two patterns are recognized:

        * A sub-module whose qualified name contains "mlp" and that exposes an
          ``act`` attribute or child. It is recorded as ``<name>.act``; when
          ``act`` is None, each child whose name contains "act" is recorded
          instead.
        * A module whose type name contains "gelu", or whose attribute name
          contains "activation", and whose parent's qualified name contains
          "mlp" or "ffn". It is recorded under its own name.

        Both patterns can hit the same module under the same key. The later
        entry then overwrites the value, but the key keeps its original position.

        Returns:
            dict mapping qualified module name -> info dict.
        """
        layer_info = {}

        def traverse_modules(module, prefix=''):
            for name, child in module.named_children():
                full_name = f"{prefix}.{name}" if prefix else name
                # Look for transformer blocks and their MLP components.
                if 'mlp' in full_name.lower() and (hasattr(child, 'act') or 'act' in dict(child.named_children())):
                    # Prefer the actual activation submodule.
                    act = getattr(child, 'act', None)
                    if act is not None:
                        layer_info[full_name + ".act"] = {
                            'type': 'mlp_activation',
                            'module': act,
                            'block_idx': self._extract_block_number(full_name),
                        }
                    else:
                        # Fallback: search the children by name.
                        for n2, c2 in child.named_children():
                            if 'act' in n2.lower():
                                layer_info[f"{full_name}.{n2}"] = {
                                    'type': 'mlp_activation',
                                    'module': c2,
                                    'block_idx': self._extract_block_number(full_name),
                                }
                elif 'gelu' in str(type(child)).lower() or 'activation' in name.lower():
                    # Direct activation layers (GELU, etc.) inside an MLP/FFN parent.
                    if 'mlp' in prefix.lower() or 'ffn' in prefix.lower():
                        layer_info[full_name] = {
                            'type': 'activation',
                            'module': child,
                            'block_idx': self._extract_block_number(full_name),
                        }
                traverse_modules(child, full_name)

        traverse_modules(self.model)
        return layer_info

    def _extract_block_number(self, layer_name):
        """Return the first ``.<digits>.`` component of *layer_name*, or 0 if there is none."""
        match = re.search(r'\.(\d+)\.', layer_name)
        return int(match.group(1)) if match else 0

    def get_visualization_layers(self, layer_emphasis="balanced"):
        """Return names of MLP activation layers to hook for *layer_emphasis*.

        Only entries whose name contains both "mlp" and "act" are candidates.
        Selection by emphasis:
        - "early", "mid" and "late" pick the first, middle and last third of
          the blocks;
        - any other value (e.g. "balanced") picks the first, middle and last
          block, or every block when there are at most three.

        Names are returned ordered by block index, and [] when no suitable
        layer exists.
        """
        if not self.layer_info:
            print("Warning: No suitable ViT layers found for visualization")
            return []

        sorted_layers = sorted(
            [(n, info) for n, info in self.layer_info.items() if 'mlp' in n.lower() and 'act' in n.lower()],
            key=lambda item: item[1]['block_idx'],
        )
        if not sorted_layers:
            # layer_info may only hold activations without "mlp" in their name
            # (e.g. "ffn.activation"); max() below would raise on an empty sequence.
            print("Warning: No suitable ViT layers found for visualization")
            return []

        total_blocks = max(info['block_idx'] for _, info in sorted_layers) + 1
        if layer_emphasis == "early":
            target_blocks = range(0, max(1, total_blocks // 3))
        elif layer_emphasis == "mid":
            start = total_blocks // 3
            end = 2 * total_blocks // 3
            target_blocks = range(start, max(start + 1, end))
        elif layer_emphasis == "late":
            target_blocks = range(2 * total_blocks // 3, total_blocks)
        elif total_blocks <= 3:
            target_blocks = range(total_blocks)
        else:  # balanced: sample first, middle and last block
            target_blocks = (0, total_blocks // 2, total_blocks - 1)

        wanted = set(target_blocks)
        return [name for name, info in sorted_layers if info['block_idx'] in wanted]
def _jitter_reflect_crop(x, pad=16): | |
b, c, h, w = x.shape | |
padded = F.pad(x, (pad, pad, pad, pad), mode='reflect').contiguous() | |
off_h = torch.randint(0, 2 * pad + 1, (b,), device=x.device) | |
off_w = torch.randint(0, 2 * pad + 1, (b,), device=x.device) | |
crops = [] | |
for i in range(b): | |
hs, ws = int(off_h[i]), int(off_w[i]) | |
crop = padded[i:i+1, :, hs:hs+h, ws:ws+w].contiguous() | |
crops.append(crop) | |
return torch.cat(crops, 0).contiguous() | |
def _channel_affine(x): | |
# per-channel affine: Ο ~ exp(U[-1,1]), ΞΌ ~ U[-1,1] | |
b, c, _, _ = x.shape | |
mu = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0) | |
log_sigma = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0) | |
sigma = torch.exp(log_sigma) | |
return (x * sigma + mu) | |
def _add_gaussian_noise(x, std=0.15): | |
return (x + torch.randn_like(x) * std) | |
def _augment_once(x, noise_std=0.15): | |
z = _jitter_reflect_crop(x) | |
z = _channel_affine(z) | |
z = _add_gaussian_noise(z, std=noise_std) | |
return z | |
def _augment_batch(x, K=8, noise_std=0.15): | |
augs = [] | |
for _ in range(K): | |
z = _augment_once(x, noise_std=noise_std) | |
augs.append(z) | |
return torch.cat(augs, dim=0).contiguous() | |
class ViTEssenceGenerator:
    """
    ViT Essence Generator based on the methodology from
    'What do Vision Transformers Learn? A Visual Exploration'

    Optimizes a single input image so that, averaged over random augmentations:
    - the model's logit for a target tag rises;
    - the most similar tags (by classifier-weight cosine) rise slightly;
    - the most dissimilar tags fall.

    A small total-variation penalty is applied to the image.
    """

    def __init__(
        self,
        model,
        tag_to_name=None,
        iterations=500,
        learning_rate=0.05,
        layer_emphasis="balanced",
        ensemble_K=8,
        tv_weight=1e-3,
        feature_weight=0.0,
    ):
        """Initialize the ViT Essence Generator.

        Args:
            model: torch module mapping image batches to tag logits (or to a
                dict holding "refined_predictions"). It is put in eval mode and
                moved to CUDA when available.
            tag_to_name: optional mapping tag index -> display name (logging only).
            iterations: number of optimization steps.
            learning_rate: initial Adam learning rate, cosine-annealed to 1% of it.
            layer_emphasis: which blocks to hook ("early", "mid", "late";
                anything else samples first/middle/last).
            ensemble_K: augmented copies evaluated per step.
            tv_weight: weight of the total-variation penalty.
            feature_weight: weight of the hooked MLP-activation term. The
                default 0.0 keeps the objective purely class-based, and the
                feature pass is then skipped.
        """
        self.model = model
        self.tag_to_name = tag_to_name
        self.iterations = iterations
        self.lr = learning_rate
        self.layer_emphasis = layer_emphasis
        self.ensemble_K = ensemble_K
        self.tv_weight = tv_weight
        self.feature_weight = feature_weight
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval().to(self.device)
        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1, 3, 1, 1)
        self.imagenet_std = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1, 3, 1, 1)
        # Models that normalize internally (normalize=True) must not be normalized twice.
        self.expect_imagenet = not (hasattr(self.model, "normalize") and getattr(self.model, "normalize") is True)
        # Analyze ViT architecture
        self.analyzer = ViTFeatureAnalyzer(self.model)
        # Active hooks: layer name -> ViTLayerHook
        self.hooks = {}
        self.selected_layers = []
        print(f"ViT Essence Generator initialized on {self.device}")

    def _preprocess(self, x):
        """Apply ImageNet normalization unless the model already does it."""
        return (x - self.imagenet_mean) / self.imagenet_std if self.expect_imagenet else x

    def setup_hooks(self, tag_idx):
        """Register forward hooks on the layers chosen for ``self.layer_emphasis``.

        Existing hooks are removed first. *tag_idx* is currently unused.

        Returns:
            dict layer name -> weight. Weights rise linearly from 0.3 (first
            layer) to 1.0 (last layer); a single layer gets 0.3.
        """
        self.close_hooks()
        names = self.analyzer.get_visualization_layers(self.layer_emphasis)
        if not names:
            print("Warning: No suitable layers found for visualization")
            return {}
        print(f"Setting up hooks on {len(names)} ViT layer(s)")
        layer_weights = {}
        for i, layer_name in enumerate(names):
            try:
                layer_info = self.analyzer.layer_info[layer_name]
                layer_module = layer_info['module']
                self.hooks[layer_name] = ViTLayerHook(layer_module, layer_name)
                weight = 0.3 + 0.7 * (i / max(1, len(names) - 1))
                layer_weights[layer_name] = weight
                print(f"  - {layer_name} (block {layer_info['block_idx']}, weight: {weight:.2f})")
            except Exception as e:
                print(f"Failed to setup hook for {layer_name}: {e}")
        self.selected_layers = names
        return layer_weights

    def close_hooks(self):
        """Remove every registered hook so none outlives this generation run."""
        for hook in self.hooks.values():
            hook.close()
        self.hooks.clear()

    def _fourier_init(self, size=224, decay=1.5):
        """Return a [1, 3, size, size] image with a 1/f^decay amplitude spectrum.

        Each channel is rescaled to [0, 1].
        """
        H = W = size
        # complex spectrum (rFFT domain)
        spec = torch.randn(1, 3, H, W // 2 + 1, dtype=torch.complex64, device=self.device)
        fy = torch.fft.fftfreq(H, device=self.device).abs().view(H, 1)
        fx = torch.fft.rfftfreq(W, device=self.device).abs().view(1, W // 2 + 1)
        radius = (fy ** 2 + fx ** 2).sqrt().clamp_(min=1e-6)
        spec = spec * (1.0 / (radius ** decay))  # 1/f^decay
        img = torch.fft.irfft2(spec, s=(H, W))  # [1, 3, H, W]
        # per-channel min/max rescale to [0, 1]
        img = img - img.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0]
        img = img / (img.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] + 1e-8)
        return img

    def create_optimizable_image(self, size=224, use_fourier=True):
        """Return a [1, 3, size, size] leaf tensor on ``self.device`` that requires grad."""
        if use_fourier:
            with torch.no_grad():
                image = self._fourier_init(size)
        else:
            image = torch.rand(1, 3, size, size, device=self.device)
        return image.detach().contiguous().requires_grad_(True)

    def total_variation_loss(self, image):
        """Return the mean absolute difference between neighbouring pixels.

        For a [B, C, H, W] tensor, the vertical and horizontal means are summed
        per sample and then averaged over the batch.
        """
        diff_y = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
        diff_x = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
        tv_per_sample = diff_y.mean(dim=(1, 2, 3)) + diff_x.mean(dim=(1, 2, 3))  # [B]
        return tv_per_sample.mean()

    def get_feature_activations(self, layer_weights, topk_channels=None):
        """Return a weighted sum over hooks of the token-summed activations.

        Assumes each hooked output is [B, tokens, C] (TODO: confirm for non-ViT
        models). Per hook, activations are summed over dim 1 and then averaged;
        with *topk_channels*, only the top-k channels per sample are averaged.
        Hooks without weights in *layer_weights* use 0.5.
        """
        total = 0.0
        for name, hook in self.hooks.items():
            feats = hook.features
            if feats is None:
                continue
            w = layer_weights.get(name, 0.5)
            agg = feats.sum(dim=1)  # [B, C]
            if topk_channels is not None and topk_channels > 0 and agg.shape[1] > topk_channels:
                # mean of the top-k channels for stability
                vals, _ = torch.topk(agg, k=topk_channels, dim=1)
                act = vals.mean()
            else:
                act = agg.mean()
            total = total + w * act
        return total

    def generate_essence(self, tag_idx, neighbor_count=8, image_size=224, return_score=True, progress_callback=None):
        """Generate an essence visualization for tag *tag_idx*.

        Args:
            tag_idx: index of the target tag (column of the model's logits).
            neighbor_count: number of most-similar and most-dissimilar tags
                added to the objective.
            image_size: side length of the square image.
            return_score: return ``(PIL image, score)`` when True, else only the image.
            progress_callback: optional callable. It is invoked periodically
                with keyword args scale_idx, scale_count, iter_idx, iter_count
                and score.

        The score is the objective value of the best image (class term plus
        ``feature_weight`` times the feature term). It is -inf when no finite
        step happened.
        """
        tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}"
        print(f"Generating ViT essence for '{tag_name}' (index: {tag_idx})...")

        layer_weights = self.setup_hooks(tag_idx)
        try:
            if not self.hooks and not hasattr(self.model, 'head'):
                print("Warning: No hooks set up and no classifier head found")
                fallback = self._create_fallback_image(image_size)
                return (fallback, 0.0) if return_score else fallback
            best_image, best_score, image = self._optimize(
                tag_idx, neighbor_count, image_size, layer_weights, progress_callback
            )
        finally:
            # Always remove the hooks. The model may be shared (it is cached in
            # session state by the caller in this file), and hooks left behind
            # after an exception would keep firing and holding activations.
            self.close_hooks()

        final_image = best_image if best_image is not None else image.detach()
        final_image = torch.clamp(final_image, 0, 1)
        pil_img = to_pil_image(final_image[0].cpu())
        print(f"ViT essence generation complete for '{tag_name}'. Final score: {best_score:.4f}")
        return (pil_img, best_score) if return_score else pil_img

    def _optimize(self, tag_idx, neighbor_count, image_size, layer_weights, progress_callback):
        """Run the optimization loop.

        Returns:
            A tuple ``(best_image, best_score, image)``:
            - ``best_image`` is a detached copy of the image that produced
              ``best_score``, or None if no finite step happened;
            - ``image`` is the final optimized tensor.
        """
        image = self.create_optimizable_image(image_size)
        optimizer = torch.optim.Adam([image], lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.iterations, eta_min=self.lr * 0.01
        )
        best_score = -float('inf')
        best_image = None
        print(f"Starting optimization for {self.iterations} iterations...")

        # Auxiliary tags are chosen once per run.
        pos_idx, pos_w, neg_idx, neg_w = neighbor_sets_from_embedding(
            self.model, tag_idx, k_pos=neighbor_count, k_neg=neighbor_count
        )
        # The feature pass only matters when it carries weight in the loss.
        use_features = bool(self.hooks) and self.feature_weight != 0

        for i in range(self.iterations):
            optimizer.zero_grad()
            for hook in self.hooks.values():
                hook.features = None

            aug_batch = _augment_batch(image, K=self.ensemble_K, noise_std=0.15)
            out = self.model(self._preprocess(aug_batch))
            logits = out["refined_predictions"] if isinstance(out, dict) else out  # [K, T]
            cls_term = weighted_class_objective(
                logits, main_idx=tag_idx,
                plus_idxs=pos_idx, plus_w=pos_w, alpha=0.25,
                minus_idxs=neg_idx, minus_w=neg_w, beta=0.15
            )
            feat_term = (
                self.get_feature_activations(layer_weights, topk_channels=64)
                if use_features else 0.0
            )
            objective = cls_term + self.feature_weight * feat_term
            # TV regularizes the image being optimized. Measured on the augmented
            # batch, the injected noise would dominate neighbouring-pixel differences.
            tv = self.total_variation_loss(image)
            total_loss = -objective + self.tv_weight * tv

            # Validate BEFORE stepping. A NaN/inf loss or gradient would
            # otherwise be written into the image and into Adam's moment
            # estimates, corrupting every later step.
            finite = bool(torch.isfinite(total_loss.detach()))
            if finite:
                total_loss.backward()
                if image.grad is None:
                    print("WARN: no/invalid grad reaching the image; check hook & loss wiring.")
                elif not bool(torch.isfinite(image.grad).all()):
                    finite = False
            if not finite:
                print("WARN: non-finite loss or gradient; skipping step and perturbing image")
                optimizer.zero_grad(set_to_none=True)
                with torch.no_grad():
                    # Small nudge toward noise to escape the bad region.
                    image.add_(0.05 * torch.randn_like(image)).clamp_(0.0, 1.0)
                scheduler.step()
                continue

            # Score the image that was actually evaluated, i.e. before this step changes it.
            current_score = float(objective.detach().item())
            if current_score > best_score:
                best_score = current_score
                best_image = image.detach().clone()

            # Gradient clip to avoid exploding updates
            torch.nn.utils.clip_grad_norm_([image], max_norm=3.0)
            optimizer.step()
            scheduler.step()
            # Keep pixels in valid range
            with torch.no_grad():
                image.clamp_(0.0, 1.0)

            if progress_callback and i % max(1, self.iterations // 20) == 0:
                progress_callback(
                    scale_idx=0,
                    scale_count=1,
                    iter_idx=i,
                    iter_count=self.iterations,
                    score=current_score
                )
            if i % max(1, self.iterations // 10) == 0:
                print(f"Iteration {i}/{self.iterations}: Score = {current_score:.4f}")

        return best_image, best_score, image

    def _create_fallback_image(self, size):
        """Return a clamped Gaussian-noise PIL image of size x size, used when generation cannot run."""
        image = torch.randn(1, 3, size, size) * 0.5 + 0.5
        image = torch.clamp(image, 0, 1)
        return to_pil_image(image[0])
# Utility Functions | |
def get_quality_level(score, levels=None):
    """Return the name of the highest quality level whose threshold <= *score*.

    Args:
        score: essence score produced by the generator.
        levels: optional mapping ``name -> {"threshold": float, ...}``;
            defaults to ``ESSENCE_QUALITY_LEVELS``.

    Levels are ranked by threshold, so the mapping's insertion order does not
    matter. A score below every threshold, or NaN, maps to the lowest-threshold
    level ("ZAYIN" for the default table).
    """
    if levels is None:
        levels = ESSENCE_QUALITY_LEVELS
    ranked = sorted(levels.items(), key=lambda item: item[1]["threshold"], reverse=True)
    for name, spec in ranked:
        if score >= spec["threshold"]:
            return name
    return ranked[-1][0] if ranked else "ZAYIN"
def get_essence_cost(rarity):
    """Return the enkephalin price of an essence for a tag of *rarity*.

    Rarities not listed in ``ESSENCE_COSTS`` cost 100.
    """
    try:
        return ESSENCE_COSTS[rarity]
    except KeyError:
        return 100
def save_essence_to_game_folder(image, tag, score, quality_level):
    """Save a generated essence under ``game_data/essences/<quality_level>/``.

    ``game_data`` lives two directory levels above this file. The file name is
    ``<tag>_<score:.2f>_<YYYYmmdd_HHMMSS>.png``. In the tag part, path
    separators, whitespace, control characters and any of ``< > : " | ? *``
    are replaced with ``_`` so the name is valid on common platforms.

    Args:
        image: object with a ``save(path)`` method (a PIL image at the call
            site in this file).
        tag: tag name; non-str values are converted with ``str``.
        score: essence score, embedded in the file name.
        quality_level: quality tier name, used as the sub-folder.

    Returns:
        The path of the written file.
    """
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    # Organize essences by quality level for easier browsing.
    quality_folder = os.path.join(base_dir, "game_data", "essences", quality_level)
    # makedirs also creates game_data/ and essences/ when they are missing.
    os.makedirs(quality_folder, exist_ok=True)

    # Tags such as ":d", ">_<" or "?" contain characters that Windows forbids
    # in file names, in addition to slashes and spaces.
    safe_tag = re.sub(r'[\\/<>:"|?*\s\x00-\x1f]', "_", str(tag))
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filepath = os.path.join(quality_folder, f"{safe_tag}_{score:.2f}_{timestamp}.png")

    image.save(filepath)
    print(f"Saved ViT essence to: {filepath}")
    return filepath
def load_tagger_metadata():
    """Read ``camie-tagger-v2-metadata.json`` from the directory above this file.

    Returns the parsed JSON, or None when the file is missing or cannot be
    read or parsed. Problems are printed, not raised.
    """
    try:
        here = os.path.dirname(os.path.abspath(__file__))
        metadata_path = os.path.join(os.path.dirname(here), "camie-tagger-v2-metadata.json")
        if not os.path.exists(metadata_path):
            print(f"Metadata file not found at: {metadata_path}")
            return None
        with open(metadata_path, 'r', encoding='utf-8') as handle:
            metadata = json.load(handle)
        print(f"Loaded tagger metadata from: {metadata_path}")
        return metadata
    except Exception as e:
        print(f"Error loading tagger metadata: {e}")
        return None
# Holds the first successfully loaded tagger metadata under the key "meta".
_TAGGER_METADATA_CACHE = {}


def _load_tagger_metadata_cached():
    """Return the tagger metadata, reading the JSON file at most once per process.

    Only a successful load is memoized. While the file is missing or unreadable,
    each call retries and returns {}. Cache hits return the same object, so
    callers must treat it as read-only.
    """
    cached = _TAGGER_METADATA_CACHE.get("meta")
    if cached is None:
        loaded = load_tagger_metadata()
        if not loaded:
            return {}
        _TAGGER_METADATA_CACHE["meta"] = cached = loaded
    return cached
def resolve_tag_index(tag, dataset=None):
    """Resolve a tag name (or index) to an integer tag index.

    Non-str input is returned as ``int(tag)``. For strings, candidate spellings
    are tried in a fixed priority order:
    1. the stripped tag;
    2. the same with spaces replaced by underscores;
    3. lowercase forms of both.

    The candidates are looked up in each source in turn:
    ``dataset.tag_to_idx``, then ``st.session_state.metadata["tag_to_idx"]``,
    then the metadata JSON's ``dataset_info.tag_mapping.tag_to_idx``.

    Returns:
        The index, or None when no source contains any candidate.
    """
    if not isinstance(tag, str):
        return int(tag)

    stripped = tag.strip()
    underscored = stripped.replace(" ", "_")
    # dict.fromkeys de-duplicates while preserving priority order. With a set,
    # the winner (e.g. "Foo" vs "foo") would vary between runs.
    candidates = list(dict.fromkeys([stripped, underscored, stripped.lower(), underscored.lower()]))

    def _first_match(mapping):
        for candidate in candidates:
            if candidate in mapping:
                return int(mapping[candidate])
        return None

    # 1) dataset.tag_to_idx
    if dataset is not None and hasattr(dataset, "tag_to_idx"):
        found = _first_match(dataset.tag_to_idx)
        if found is not None:
            return found

    # 2) session metadata
    session_meta = getattr(st.session_state, "metadata", {}) or {}
    session_map = session_meta.get("tag_to_idx", {}) if isinstance(session_meta, dict) else {}
    found = _first_match(session_map)
    if found is not None:
        return found

    # 3) JSON metadata (cached)
    meta = _load_tagger_metadata_cached()
    json_map = (meta.get("dataset_info", {})
                    .get("tag_mapping", {})
                    .get("tag_to_idx", {})) if isinstance(meta, dict) else {}
    return _first_match(json_map)
def generate_essence_for_tag(tag, model, dataset, custom_settings=None):
    """
    Generate an essence image for a specific tag using the ViT generator.

    The tag must be discovered or a manual tag. Its rarity determines the
    enkephalin cost, which is charged only after the image has been generated
    and saved. The result is recorded in ``st.session_state.generated_essences``
    and the state is persisted via ``tag_storage``.

    Args:
        tag: The tag name or index
        model: Not used here; the essence model is built from
            ``camie-tagger-v2.safetensors`` and cached in
            ``st.session_state.model_torch``.
        dataset: Object that may expose ``tag_to_idx`` (used to resolve names)
        custom_settings: Optional dict of generation settings; missing keys
            fall back to ``DEFAULT_ESSENCE_SETTINGS``.

    Returns:
        ``(PIL image, score, quality level)``, or ``(None, 0, None)`` on any
        failure (errors are shown in the UI).
    """
    print(f"\n=== Starting ViT essence generation for tag '{tag}' ===")

    # Only discovered or manually added tags can be turned into essences.
    is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags
    is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags
    if not is_discovered and not is_manual_tag:
        st.error(f"Tag '{tag}' has not been discovered yet.")
        return None, 0, None

    # Rarity (and therefore price) comes from whichever collection holds the tag.
    source = st.session_state.discovered_tags if is_discovered else st.session_state.manual_tags
    rarity = source[tag].get("rarity", "Canard")
    cost = get_essence_cost(rarity)

    if st.session_state.enkephalin < cost:
        st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.")
        return None, 0, None

    # Fill any missing keys from the defaults. A partial custom_settings dict
    # would otherwise pick up fallbacks that disagree with them (e.g. 500
    # iterations instead of 256).
    settings = {**DEFAULT_ESSENCE_SETTINGS, **(custom_settings or {})}
    print(f"Using settings: {settings}")

    # UI containers for progress
    preview_container = st.empty()
    progress_container = st.empty()
    message_container = st.empty()

    try:
        message_container.info(f"Generating ViT essence for '{tag}' with {settings.get('layer_emphasis', 'balanced')} layer emphasis...")

        def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score):
            progress = iter_idx / iter_count
            progress_container.progress(progress, f"Iteration {iter_idx}/{iter_count}")
            message_container.info(f"Current score: {score:.4f}")
            if iter_idx % 50 == 0:
                print(f"Progress: Iteration {iter_idx}/{iter_count}, Score: {score:.4f}")

        # Convert tag name to index
        if isinstance(tag, str):
            tag_idx = resolve_tag_index(tag, dataset)
            if tag_idx is None:
                st.error(
                    f"Tag '{tag}' index not found in dataset or metadata. "
                    f"Make sure it exists in camie-tagger-v2-metadata.json."
                )
                return None, 0, None
        else:
            tag_idx = int(tag)
        print(f"Resolved tag '{tag}' -> index {tag_idx}")

        tag_to_name = {tag_idx: tag}

        # Get or create the Torch model used for essence generation.
        torch_model = getattr(st.session_state, "model_torch", None)
        if not isinstance(torch_model, nn.Module):
            # safetensors lives ONE directory above this file
            current_dir = os.path.dirname(os.path.abspath(__file__))
            parent_dir = os.path.dirname(current_dir)
            ckpt = os.path.join(parent_dir, "camie-tagger-v2.safetensors")
            if not os.path.exists(ckpt):
                st.error(f"Missing safetensors checkpoint at: {ckpt}")
                return None, 0, None
            # metadata-driven sizes
            meta = _load_tagger_metadata_cached()
            num_tags = int(meta.get("dataset_info", {}).get("total_tags", 70527))
            img_size = int(meta.get("model_info", {}).get("img_size", 512))
            torch_model = build_torch_model_from_safetensors(
                ckpt_path=ckpt,
                num_tags=num_tags,
                backbone="vit_base_patch16_384",
                img_size=img_size
            )
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            torch_model = torch_model.to(device).eval()
            st.session_state.model_torch = torch_model  # cache for later

        generator = ViTEssenceGenerator(
            model=torch_model,
            tag_to_name=tag_to_name,
            iterations=settings['iterations'],
            learning_rate=settings['lr'],
            layer_emphasis=settings['layer_emphasis'],
            ensemble_K=settings['ensemble_k'],
            tv_weight=settings['tv_weight']
        )
        image, score = generator.generate_essence(
            tag_idx=tag_idx,
            neighbor_count=settings['neighbor_count'],
            image_size=settings['image_size'],
            return_score=True,
            progress_callback=progress_callback
        )

        quality_level = get_quality_level(score)

        # Save first and charge afterwards: if writing the file fails, the
        # handler below runs and the player keeps their enkephalin.
        filepath = save_essence_to_game_folder(image, tag, score, quality_level)

        st.session_state.enkephalin -= cost
        stats = st.session_state.game_stats
        stats["enkephalin_spent"] = stats.get("enkephalin_spent", 0) + cost
        stats["essences_generated"] = stats.get("essences_generated", 0) + 1

        # Update UI with result and clear progress elements
        preview_container.image(image, caption=f"ViT Essence of '{tag}' - Quality: {quality_level}", width=400)
        progress_container.empty()
        message_container.empty()

        if 'generated_essences' not in st.session_state:
            st.session_state.generated_essences = {}
        st.session_state.generated_essences[tag] = {
            "path": filepath,
            "score": score,
            "quality": quality_level,
            "rarity": rarity,
            "settings": settings,
            "generated_time": time.strftime("%Y-%m-%d %H:%M:%S")
        }

        st.success(f"Successfully generated {quality_level} ViT essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}")
        print(f"=== ViT essence generation complete for '{tag}' ===\n")

        tag_storage.save_essence_state(session_state=st.session_state)
        return image, score, quality_level

    except Exception as e:
        st.error(f"Error generating ViT essence: {str(e)}")
        print(f"EXCEPTION in generate_essence_for_tag: {str(e)}")
        import traceback
        err_traceback = traceback.format_exc()
        print(err_traceback)
        st.code(err_traceback)
        return None, 0, None
# Utility Functions for Model Analysis and Layer Selection | |
def get_model_layers(model):
    """
    Return the dotted names of all submodules of ``model``.

    Names come from ``model.named_modules()`` in the order it yields them; the
    entry with an empty name (the model itself) is left out.
    """
    return [module_name for module_name, _submodule in model.named_modules() if module_name]
def get_key_layers(model, max_layers=15):
    """
    Pick a representative subset of a model's named layers for visualization.

    Models with more than 30 named submodules are grouped by the first two
    dotted components of each layer name (layers with a single component are
    ignored). Every group with more than three members contributes its first,
    middle and last member (in natural-sort order, digit runs compared
    numerically) to the "early", "middle" and "late" buckets. If the buckets
    hold more than ``max_layers`` names in total, each bucket is truncated,
    reserving at most ``max_layers // 3`` slots for "late".

    Smaller models are cut into thirds instead: up to 3 names from the first
    third, 4 from the second and 3 from the last.

    When some layer name contains a head-like keyword, a "classifier" bucket
    holding the last such name is added.

    NOTE(review): "early"/"middle"/"late" are positions *within* each
    two-component group, not depths in the network; and the keyword "fc" also
    matches MLP sublayers (names like "mlp.fc2"), so the "classifier" pick may
    not be the model head — confirm this is intended.

    Returns:
        dict mapping bucket name -> list of layer names.
    """
    names = get_model_layers(model)
    count = len(names)

    def natural_key(layer_name):
        # Split into digit / non-digit runs so "x.10" sorts after "x.2".
        return [int(tok) if tok.isdigit() else tok for tok in re.findall(r'\d+|\D+', layer_name)]

    if count <= 30:
        buckets = {
            "early": names[:count // 3][:3],
            "middle": names[count // 3:2 * count // 3][:4],
            "late": names[2 * count // 3:][:3],
        }
    else:
        groups = {}
        for layer_name in names:
            pieces = layer_name.split(".")
            if len(pieces) < 2:
                continue
            groups.setdefault(".".join(pieces[:2]), []).append(layer_name)

        buckets = {"early": [], "middle": [], "late": []}
        for members in groups.values():
            if len(members) <= 3:
                continue  # too small to be a meaningful block
            members.sort(key=natural_key)
            buckets["early"].append(members[0])
            buckets["middle"].append(members[len(members) // 2])
            buckets["late"].append(members[-1])

        if sum(len(bucket) for bucket in buckets.values()) > max_layers:
            # Late layers get first claim on the budget, then middle, then early.
            n_late = min(len(buckets["late"]), max_layers // 3)
            budget = max_layers - n_late
            n_middle = min(len(buckets["middle"]), budget // 2)
            n_early = min(len(buckets["early"]), budget - n_middle)
            buckets["early"] = buckets["early"][:n_early]
            buckets["middle"] = buckets["middle"][:n_middle]
            buckets["late"] = buckets["late"][:n_late]

    head_keywords = ("classifier", "fc", "linear", "output", "logits", "head")
    head_like = [layer_name for layer_name in names
                 if any(keyword in layer_name.lower() for keyword in head_keywords)]
    if head_like:
        buckets["classifier"] = [head_like[-1]]
    return buckets
def get_suggested_layers(model, layer_type="balanced"):
    """
    Suggest which model layers to visualize for a requested feature emphasis.

    Args:
        model: Model handed to get_key_layers() (needs ``named_modules()``).
        layer_type: Which features to emphasize:
            - "low" / "early": early-bucket layers plus the first middle layer
              (textures, patterns, colors);
            - "mid" / "middle": middle-bucket layers plus the last early layer
              (parts, components);
            - "high" / "late": late and classifier layers plus the last middle
              layer (objects, characters);
            - anything else: a balanced mix (first layer of every bucket, plus
              the last middle/late layer when that bucket has more than one).
            "early"/"late" are accepted because those are the values the
            "Feature Targeting" selector stores in the essence settings; before,
            they silently fell through to the balanced mix.

    Returns:
        List of unique layer names in selection order. When the chosen buckets
        are all empty, the last key layer is used as a fallback; an empty list
        is returned only if get_key_layers() found nothing.
    """
    aliases = {"early": "low", "middle": "mid", "late": "high"}
    if isinstance(layer_type, str):
        layer_type = aliases.get(layer_type, layer_type)

    key_layers = get_key_layers(model)
    # Work on copies so the selection never mutates get_key_layers' buckets.
    early = list(key_layers.get("early", []))
    middle = list(key_layers.get("middle", []))
    late = list(key_layers.get("late", []))
    classifier = list(key_layers.get("classifier", []))
    all_key_layers = [layer for layers in key_layers.values() for layer in layers]

    if layer_type == "low":
        # Early visual features, with one middle layer for stability.
        selected = early + middle[:1]
    elif layer_type == "mid":
        # Mid-level features, with one early layer for context.
        selected = middle + early[-1:]
    elif layer_type == "high":
        # High-level semantics, with one middle layer for context.
        selected = late + classifier + middle[-1:]
    else:
        selected = []
        for category, layers in (("early", early), ("middle", middle),
                                 ("late", late), ("classifier", classifier)):
            if not layers:
                continue
            selected.append(layers[0])
            if category in ("middle", "late") and len(layers) > 1:
                selected.append(layers[-1])

    # The same layer can land in two buckets (e.g. a head-like late layer);
    # hooking it twice would only double-count it.
    selected = list(dict.fromkeys(selected))

    if not selected and all_key_layers:
        selected = [all_key_layers[-1]]
    return selected
# Game UI and Integration Functions | |
def display_essence_generator():
    """
    Render the Tag Essence Generator page.

    The "Generate Essence" tab either runs a generation queued on the previous
    script run (``st.session_state.selected_tag``) or shows the tag picker. A
    tag picked there is stored in session state and the script is rerun, so the
    generation itself happens on the next run. The "My Essences" tab lists the
    saved essences.
    """
    # Presumably populates st.session_state.essence_custom_settings (read below).
    initialize_essence_settings()

    st.title("🎨 Tag Essence Generator")
    st.write("Generate visual representations of what the AI model recognizes for specific tags.")

    with st.expander("What are Tag Essences & How to Use Them", expanded=True):
        st.markdown("""
        ### 💡 Understanding Tag Essences

        Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy!

        **How to use Tag Essences:**

        1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library)
        2. **Save the essence image** to your computer
        3. **Upload the essence image** back into the tagger
        4. The tagger will **almost always detect the original tag**
        5. It will often also **detect related rare tags** from the same category

        **Strategic Value:**

        - Character essences can help unlock other tags associated with that character
        - Category essences can help discover rare tags within that category
        - High-quality essences (WAW, ALEPH) have the strongest effect

        **This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning!
        """)

    model_available = hasattr(st.session_state, 'model')
    if not model_available:
        st.warning("Model not available. You can browse your tags but cannot generate essences.")

    tabs = st.tabs(["Generate Essence", "My Essences"])

    with tabs[0]:
        pending_tag = getattr(st.session_state, 'selected_tag', None)

        if pending_tag and not model_available:
            # A generation was queued but there is no model to run it with;
            # drop the request instead of crashing on st.session_state.model.
            st.error(f"Cannot generate an essence for '{pending_tag}' because the model is not loaded.")
            st.session_state.selected_tag = None
            pending_tag = None

        if pending_tag:
            st.subheader(f"Generating Essence for '{pending_tag}'")
            model = st.session_state.model
            image, score, quality = generate_essence_for_tag(
                pending_tag,
                model,
                model.dataset,
                st.session_state.essence_custom_settings
            )

            if image is not None:
                with st.expander("Essence Usage", expanded=True):
                    st.markdown("""
                    💡 **Tag Essence Usage Tips:**

                    1. Look for similar patterns, colors, and elements in real images
                    2. The essence reveals what features the AI model recognizes for this tag
                    3. Use this as inspiration when creating or finding images to get this tag
                    """)
            else:
                st.error("Essence generation failed. Please check the error messages above and try again with different settings.")

            # The request has been handled (successfully or not); clear it.
            st.session_state.selected_tag = None
        else:
            selected_tag = display_essence_generation_interface(model_available)
            if selected_tag:
                # Queue the tag and rerun so generation happens on a fresh run.
                st.session_state.selected_tag = selected_tag
                st.rerun()

    with tabs[1]:
        display_saved_essences()
def essence_folder_path():
    """
    Return the absolute path of the ``game_data/essences`` folder.

    ``game_data`` lives two directory levels above this source file. Both
    ``game_data`` and its ``essences`` subfolder are created when missing, so
    the returned directory can be written to immediately.
    """
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    data_root = os.path.join(project_root, "game_data")
    target = os.path.join(data_root, "essences")
    for directory in (data_root, target):
        os.makedirs(directory, exist_ok=True)
    return target
# Essence files are named "{safe_tag}_{score:.2f}_{YYYYmmdd}_{HHMMSS}.png"
# (the scheme the "Reconnect File" button below generates). This matches that
# trailing score/timestamp suffix so the tag part can be recovered even when
# the tag itself contains underscores.
_ESSENCE_FILENAME_SUFFIX = re.compile(r"_-?\d+(?:\.\d+)?_\d{8}_\d{6}\.png$", re.IGNORECASE)


def _safe_essence_tag(tag):
    """Sanitize a tag the same way essence filenames do (/, \\ and space -> _)."""
    return tag.replace('/', '_').replace('\\', '_').replace(' ', '_')


def _essence_filename_tag(filename):
    """
    Recover the sanitized tag prefix from an essence filename.

    Names following the "{tag}_{score}_{date}_{time}.png" scheme yield
    everything before that suffix ("long_hair_0.85_20240101_120000.png" ->
    "long_hair"). Other names fall back to the text before the first
    underscore, or None when there is no underscore at all.
    """
    match = _ESSENCE_FILENAME_SUFFIX.search(filename)
    if match:
        return filename[:match.start()]
    head, sep, _ = filename.partition("_")
    return head if sep else None


def _essence_file_matches_tag(filename, safe_tag):
    """Return whether ``filename`` looks like an essence saved for ``safe_tag``."""
    if _ESSENCE_FILENAME_SUFFIX.search(filename):
        # Exact tag comparison, so "cat" no longer claims "cat_ears_..." files.
        return _essence_filename_tag(filename) == safe_tag
    # Unknown naming scheme: keep the original loose prefix match.
    return filename.startswith(safe_tag)


def _reconnect_missing_essences(essence_dir):
    """
    Re-point tracked essences whose file is gone at a matching file on disk.

    Searches ``essence_dir/<quality>`` for files belonging to the entry's tag
    and, when several match, uses the most recently modified one. Only the
    in-memory session state is updated.
    """
    for tag, info in st.session_state.generated_essences.items():
        if "path" not in info or os.path.exists(info["path"]):
            continue
        quality_dir = os.path.join(essence_dir, info.get("quality", "ZAYIN"))
        if not os.path.exists(quality_dir):
            continue
        safe_tag = _safe_essence_tag(tag)
        candidates = [
            os.path.join(quality_dir, name)
            for name in os.listdir(quality_dir)
            if _essence_file_matches_tag(name, safe_tag)
        ]
        if candidates:
            info["path"] = max(candidates, key=os.path.getmtime)
            print(f"Reconnected essence for {tag} to {info['path']}")


def _find_untracked_essences(essence_dir):
    """
    Collect PNG essence files on disk that no session entry points at.

    Returns:
        dict mapping quality level -> list of (tag, info) pairs, where info
        holds the file path, the quality and ``discovered_on_disk=True``.
        Filesystem errors are logged and whatever was gathered so far is
        returned (best effort, as before).
    """
    untracked = {}
    try:
        # Built once instead of rescanning every tracked entry per file.
        tracked_names = {
            os.path.basename(info["path"])
            for info in st.session_state.generated_essences.values()
            if "path" in info
        }
        for quality in ESSENCE_QUALITY_LEVELS:
            quality_dir = os.path.join(essence_dir, quality)
            if not os.path.exists(quality_dir):
                continue
            for filename in os.listdir(quality_dir):
                if not filename.lower().endswith('.png') or filename in tracked_names:
                    continue
                tag = _essence_filename_tag(filename)
                if tag is None:
                    continue
                untracked.setdefault(quality, []).append((tag, {
                    "path": os.path.join(quality_dir, filename),
                    "quality": quality,
                    "discovered_on_disk": True
                }))
    except Exception as e:
        print(f"Error checking for untracked essences: {e}")
    return untracked


def _open_folder(folder_path):
    """Try to open ``folder_path`` in the host's file manager and report the outcome."""
    import subprocess
    import sys

    try:
        if os.name == 'nt':  # Windows
            os.startfile(folder_path)
        elif os.name == 'posix':  # macOS or Linux
            opener = 'open' if 'darwin' in sys.platform else 'xdg-open'
            return_code = subprocess.call([opener, folder_path])
            if return_code != 0:
                raise OSError(f"{opener} exited with status {return_code}")
        st.success(f"Opened folder: {folder_path}")
    except Exception as e:
        st.error(f"Could not open folder: {str(e)}")
        # Provide the path for manual navigation
        st.code(folder_path)


def _render_essence_card(tag, info, quality, color, essence_dir):
    """
    Render one essence tile: the image with quality/rarity/score, or a
    "file not found" notice with an optional "Reconnect File" button.
    """
    import html

    path = info.get("path")
    if path and os.path.exists(path):
        # Copy into memory so the file handle is closed right away.
        with Image.open(path) as opened:
            image = opened.copy()
        rarity = info.get("rarity", "Canard")
        score = info.get("score", 0)
        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")

        st.image(image, caption=tag, use_container_width=True)

        # Special styling for the rarest tiers, plain colored text otherwise.
        rarity_style = {
            "Impuritas Civitas": "animation: rainbow-text 4s linear infinite;font-weight:bold;",
            "Star of the City": f"color:{rarity_color};text-shadow:0 0 3px gold;font-weight:bold;",
            "Urban Nightmare": f"color:{rarity_color};text-shadow:0 0 1px #FF5722;font-weight:bold;",
            "Urban Plague": f"color:{rarity_color};text-shadow:0 0 1px #9C27B0;font-weight:bold;",
        }.get(rarity, f"color:{rarity_color};font-weight:bold;")
        st.markdown(
            f"<span style='color:{color};font-weight:bold;'>{quality}</span> | "
            f"<span style='{rarity_style}'>{rarity}</span> | "
            f"Score: {score:.2f}",
            unsafe_allow_html=True
        )

        if info.get("discovered_on_disk"):
            st.info("Found on disk (not in session state)")

        if st.button("Open Folder", key=f"open_folder_{tag}_{quality}"):
            _open_folder(os.path.dirname(path))
        return

    st.warning(f"Image file not found: {info.get('path', 'No path available')}")
    # Tags are interpolated into raw HTML, so escape them (e.g. the "<3" tag).
    st.markdown(
        f"<span style='color:{color};font-weight:bold;'>{quality}</span> | {html.escape(tag)}",
        unsafe_allow_html=True
    )

    # Only offer reconnection when there is enough metadata to rebuild a name.
    if "rarity" in info and "score" in info:
        if st.button("Reconnect File", key=f"reconnect_{tag}_{quality}"):
            score = info.get("score", 0)
            quality_dir = os.path.join(essence_dir, quality)
            os.makedirs(quality_dir, exist_ok=True)
            # Set a path - the user will need to manually add the image.
            timestamp = time.strftime("%Y%m%d_%H%M%S")
            filename = f"{_safe_essence_tag(tag)}_{score:.2f}_{timestamp}.png"
            info["path"] = os.path.join(quality_dir, filename)
            st.info(f"Please save your image to this location: {info['path']}")
            st.session_state.generated_essences[tag] = info
            tag_storage.save_essence_state(session_state=st.session_state)
            st.rerun()


def display_saved_essences():
    """
    Render the "My Essences" tab.

    Tries to reconnect tracked entries whose file moved, merges in PNG files
    found on disk that are not tracked, and shows everything grouped by
    quality level, highest first, three tiles per row. A button at the bottom
    removes entries whose file no longer exists and persists the state.
    """
    st.subheader("My Generated Essences")

    if not getattr(st.session_state, 'generated_essences', None):
        st.info("You haven't generated any essences yet. Go to the Generate tab to create some!")
        return

    st.markdown("""
    ### How to Use Your Essences

    1. **Click on any essence image** to open it in full size
    2. **Save the image** to your computer (right-click → Save image)
    3. **Go to the Scan Images tab** and upload the saved essence
    4. The tagger will likely detect the original tag and potentially related rare tags!

    Higher quality essences (WAW, ALEPH) generally produce the best results.
    """)

    essence_dir = essence_folder_path()
    _reconnect_missing_essences(essence_dir)

    # Group tracked essences by quality (lowest level when unset).
    essences_by_quality = {}
    for tag, info in st.session_state.generated_essences.items():
        essences_by_quality.setdefault(info.get("quality", "ZAYIN"), []).append((tag, info))

    # Add untracked files, skipping tags already listed at that quality.
    for quality, found in _find_untracked_essences(essence_dir).items():
        bucket = essences_by_quality.setdefault(quality, [])
        for tag, info in found:
            if not any(existing_tag == tag for existing_tag, _ in bucket):
                bucket.append((tag, info))

    # Show essences from highest to lowest quality.
    for quality in reversed(list(ESSENCE_QUALITY_LEVELS)):
        essences = essences_by_quality.get(quality)
        if not essences:
            continue
        color = ESSENCE_QUALITY_LEVELS[quality]["color"]
        with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ("ALEPH", "WAW")):
            cols = st.columns(3)
            ranked = sorted(essences, key=lambda item: item[1].get("score", 0), reverse=True)
            for i, (tag, info) in enumerate(ranked):
                with cols[i % 3]:
                    try:
                        _render_essence_card(tag, info, quality, color, essence_dir)
                    except Exception as e:
                        st.write(f"Error loading {tag}: {str(e)}")

    st.divider()
    if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"):
        to_remove = [
            tag for tag, info in st.session_state.generated_essences.items()
            if "path" in info and not os.path.exists(info["path"])
        ]
        for tag in to_remove:
            del st.session_state.generated_essences[tag]

        tag_storage.save_essence_state(session_state=st.session_state)

        if to_remove:
            st.success(f"Removed {len(to_remove)} entries with missing files")
        else:
            st.success("No missing files found")
        st.rerun()
def display_essence_generation_interface(model_available): | |
"""Display the interface for generating new essences""" | |
# Initialize manual tags | |
initialize_manual_tags() | |
st.subheader("Generate Tag Essence") | |
st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.") | |
# Settings column | |
col1, col2 = st.columns(2) | |
with col1: | |
st.write("Generation Settings:") | |
# Add reset button | |
if st.button("Reset to Defaults", help="Clear saved settings and use default values"): | |
st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy() | |
tag_storage.save_essence_state(session_state=st.session_state) | |
st.success("Settings reset to defaults!") | |
st.rerun() | |
# Advanced settings with better organization | |
with st.expander("Advanced Settings", expanded=True): | |
col_a, col_b = st.columns(2) | |
with col_a: | |
# Core generation parameters | |
st.write("**Core Parameters**") | |
iterations = st.slider( | |
"Iterations", | |
min_value=64, | |
max_value=2048, | |
value=st.session_state.essence_custom_settings.get("iterations", 500), | |
step=64, | |
help="More iterations improve quality but take longer" | |
) | |
lr = st.slider( | |
"Learning Rate", | |
min_value=0.01, | |
max_value=0.2, | |
value=st.session_state.essence_custom_settings.get("lr", 0.05), | |
step=0.01, | |
help="Higher learning rates converge faster but may be less stable" | |
) | |
ensemble_k = st.slider( | |
"Ensemble Size", | |
min_value=1, | |
max_value=16, | |
value=st.session_state.essence_custom_settings.get("ensemble_k", 8), | |
step=1, | |
help="Number of augmented versions per iteration. Higher = more stable but slower" | |
) | |
with col_b: | |
# Multi-tag parameters | |
st.write("**Multi-Tag Enhancement**") | |
neighbor_count = st.slider( | |
"Neighbor Tags", | |
min_value=0, | |
max_value=16, | |
value=st.session_state.essence_custom_settings.get("neighbor_count", 8), | |
step=1, | |
help="Number of similar/dissimilar tags to consider. 0 = only target tag" | |
) | |
tv_weight = st.select_slider( | |
"Smoothness", | |
options=[1e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2], | |
value=st.session_state.essence_custom_settings.get("tv_weight", 1e-3), | |
format_func=lambda x: f"{x:.0e}", | |
help="Higher values create smoother, less noisy images" | |
) | |
# Layer emphasis selection | |
layer_emphasis = st.selectbox( | |
"Feature Targeting", | |
options=["balanced", "early", "mid", "late"], | |
index=0, | |
format_func=lambda x: { | |
"balanced": "Balanced (mix of features)", | |
"early": "Early (textures, patterns)", | |
"mid": "Mid (parts, components)", | |
"late": "Late (characters, objects)" | |
}.get(x, x), | |
help="Controls which model features to emphasize" | |
) | |
# Save settings | |
st.session_state.essence_custom_settings = { | |
"iterations": iterations, | |
"lr": lr, | |
"ensemble_k": ensemble_k, | |
"neighbor_count": neighbor_count, | |
"image_size": 512, # Fixed for now | |
"layer_emphasis": layer_emphasis, | |
"tv_weight": tv_weight | |
} | |
# Show current settings summary | |
st.info(f""" | |
**Current Settings:** | |
- Iterations: {iterations} | |
- Learning Rate: {lr} | |
- Ensemble Size: {ensemble_k} | |
- Neighbor Tags: {neighbor_count} | |
- Feature Focus: {layer_emphasis.capitalize()} | |
""") | |
with col2: | |
# Show quality level descriptions | |
st.write("Quality Levels:") | |
for level, info in ESSENCE_QUALITY_LEVELS.items(): | |
st.markdown(f""" | |
<div style="padding:5px;margin-bottom:5px;border-radius:4px;background-color:rgba({int(info['color'][1:3], 16)},{int(info['color'][3:5], 16)},{int(info['color'][5:7], 16)},0.1);border-left:3px solid {info['color']}"> | |
<span style="color:{info['color']};font-weight:bold;">{level}</span> ({info['threshold']:.0f} Score+): {info['description']} | |
</div> | |
""", unsafe_allow_html=True) | |
# Feature targeting explanation | |
st.write("Feature Targeting Explanation:") | |
st.markdown(""" | |
- **Early**: Textures, colors, simple patterns | |
- **Mid**: Parts, components, intermediate features | |
- **Late**: Characters, objects, high-level concepts | |
- **Balanced**: Mix of all feature levels | |
""") | |
# Show current Enkephalin | |
st.markdown(f"### Your {ENKEPHALIN_CURRENCY_NAME}: **{st.session_state.enkephalin}** {ENKEPHALIN_ICON}") | |
st.divider() | |
# Add CSS for animations matching tag collection display | |
st.markdown(""" | |
<style> | |
@keyframes rainbow-text { | |
0% { color: red; } | |
14% { color: orange; } | |
28% { color: yellow; } | |
42% { color: green; } | |
57% { color: blue; } | |
71% { color: indigo; } | |
85% { color: violet; } | |
100% { color: red; } | |
} | |
.impuritas-text { | |
font-weight: bold; | |
animation: rainbow-text 4s linear infinite; | |
} | |
@keyframes glow-text { | |
0% { text-shadow: 0 0 2px gold; } | |
50% { text-shadow: 0 0 6px gold; } | |
100% { text-shadow: 0 0 2px gold; } | |
} | |
.star-text { | |
color: #FFEB3B; | |
text-shadow: 0 0 3px gold; | |
animation: glow-text 2s infinite; | |
font-weight: bold; | |
} | |
@keyframes pulse-text { | |
0% { opacity: 0.8; } | |
50% { opacity: 1; } | |
100% { opacity: 0.8; } | |
} | |
.nightmare-text { | |
color: #FF9800; | |
text-shadow: 0 0 1px #FF5722; | |
animation: pulse-text 3s infinite; | |
font-weight: bold; | |
} | |
.plague-text { | |
color: #9C27B0; | |
text-shadow: 0 0 1px #9C27B0; | |
font-weight: bold; | |
} | |
.category-section { | |
margin-top: 20px; | |
margin-bottom: 30px; | |
padding: 10px; | |
border-radius: 5px; | |
border-left: 5px solid; | |
} | |
</style> | |
""", unsafe_allow_html=True) | |
# Tag collection display (unchanged from original) | |
# Gather all tags for essence generation | |
all_tags = [] | |
# Process discovered tags | |
if hasattr(st.session_state, 'discovered_tags'): | |
for tag, info in st.session_state.discovered_tags.items(): | |
tag_info = { | |
"tag": tag, | |
"rarity": info.get("rarity", "Unknown"), | |
"category": info.get("category", "unknown"), | |
"source": "discovered", | |
"library_floor": info.get("library_floor", ""), | |
"discovery_time": info.get("discovery_time", "") | |
} | |
all_tags.append(tag_info) | |
# Process manual tags | |
if hasattr(st.session_state, 'manual_tags'): | |
for tag, info in st.session_state.manual_tags.items(): | |
tag_info = { | |
"tag": tag, | |
"rarity": info.get("rarity", "Special"), | |
"category": info.get("category", "special"), | |
"source": "manual", | |
"description": info.get("description", "") | |
} | |
all_tags.append(tag_info) | |
# Count tags by rarity | |
rarity_counts = {} | |
for info in all_tags: | |
rarity = info["rarity"] | |
if rarity not in rarity_counts: | |
rarity_counts[rarity] = 0 | |
rarity_counts[rarity] += 1 | |
# Display rarity counts at the top | |
st.subheader("Available Tags for Essence Generation") | |
st.write(f"You have {len(all_tags)} tags available for essence generation. Collect more from the library!") | |
# Display rarity distribution | |
rarity_cols = st.columns(len(rarity_counts)) | |
for i, (rarity, count) in enumerate(sorted(rarity_counts.items(), | |
key=lambda x: list(RARITY_LEVELS.keys()).index(x[0]) if x[0] in RARITY_LEVELS else 999)): | |
with rarity_cols[i]: | |
# Get color with fallback | |
color = RARITY_LEVELS.get(rarity, {}).get("color", "#888888") | |
# Apply special styling based on rarity | |
style = f"color:{color};font-weight:bold;" | |
class_name = "" | |
if rarity == "Impuritas Civitas": | |
class_name = "grid-impuritas" | |
elif rarity == "Star of the City": | |
class_name = "grid-star" | |
elif rarity == "Urban Nightmare": | |
class_name = "grid-nightmare" | |
elif rarity == "Urban Plague": | |
class_name = "grid-plague" | |
if class_name: | |
st.markdown( | |
f"<div style='text-align:center;'><span class='{class_name}' style='font-weight:bold;'>{rarity.capitalize()}</span><br>{count}</div>", | |
unsafe_allow_html=True | |
) | |
else: | |
st.markdown( | |
f"<div style='text-align:center;'><span style='{style}'>{rarity.capitalize()}</span><br>{count}</div>", | |
unsafe_allow_html=True | |
) | |
# Search box for all tags | |
search_term = st.text_input("Search tags", "", key="essence_search_tags") | |
# Sort options | |
sort_options = ["Category (rarest first)", "Rarity", "Discovery Time"] | |
selected_sort = st.selectbox("Sort tags by:", sort_options, key="essence_tags_sort") | |
# Filter tags by search term if provided | |
if search_term: | |
all_tags = [info for info in all_tags if search_term.lower() in info["tag"].lower()] | |
selected_tag = None | |
# Sort and group tags based on selection (rest of the display logic unchanged) | |
if selected_sort == "Category (rarest first)": | |
# Group tags by category | |
categories = {} | |
for info in all_tags: | |
category = info["category"] | |
if category not in categories: | |
categories[category] = [] | |
categories[category].append(info) | |
# Display tags by category in expanders | |
for category, tags in sorted(categories.items()): | |
# Get rarity order for sorting | |
rarity_order = list(reversed(RARITY_LEVELS.keys())) | |
# Sort tags by rarity (rarest first) | |
def get_rarity_index(info):
    """Return a sort weight for a tag record based on its ``"rarity"`` value.

    Looks the rarity up in ``rarity_order`` (a list taken from the enclosing
    scope). Entries nearer the front of that list get larger weights: the
    first entry maps to ``len(rarity_order)``, the last to 1. A rarity not
    present in ``rarity_order`` maps to 0, so those tags sort after all known
    rarities when the caller sorts with ``reverse=True``.

    Raises:
        KeyError: if ``info`` has no ``"rarity"`` key.
    """
    tag_rarity = info["rarity"]
    # Unknown rarities share the lowest weight.
    if tag_rarity not in rarity_order:
        return 0
    position = rarity_order.index(tag_rarity)
    return len(rarity_order) - position
sorted_tags = sorted(tags, key=get_rarity_index, reverse=True) | |
# Check if category has any rare tags | |
has_rare_tags = any(info["rarity"] in ["Impuritas Civitas", "Star of the City"] | |
for info in sorted_tags) | |
# Get category info if available | |
category_display = category.capitalize() | |
if category in TAG_CATEGORIES: | |
category_info = TAG_CATEGORIES[category] | |
icon = category_info.get("icon", "") | |
color = category_info.get("color", "#888888") | |
category_display = f"<span style='color:{color};'>{icon} {category.capitalize()}</span>" | |
# Create header with information about rare tags if present | |
header = f"{category_display} ({len(tags)} tags)" | |
if has_rare_tags: | |
header += " β¨ Contains rare tags!" | |
# Display category header and expander | |
st.markdown(header, unsafe_allow_html=True) | |
with st.expander("Show/Hide", expanded=has_rare_tags): | |
# Create grid layout for tags | |
cols = st.columns(3) | |
for i, info in enumerate(sorted_tags): | |
with cols[i % 3]: | |
tag = info["tag"] | |
rarity = info["rarity"] | |
source = info["source"] | |
# Get rarity color | |
rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") | |
# Check if this tag has an essence already | |
has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences | |
# Get cost for this tag | |
cost = get_essence_cost(rarity) | |
can_afford = st.session_state.enkephalin >= cost | |
# Format tag display with special styling | |
if rarity == "Impuritas Civitas": | |
tag_display = f'<span class="impuritas-text">{tag}</span>' | |
elif rarity == "Star of the City": | |
tag_display = f'<span class="star-text">{tag}</span>' | |
elif rarity == "Urban Nightmare": | |
tag_display = f'<span class="nightmare-text">{tag}</span>' | |
elif rarity == "Urban Plague": | |
tag_display = f'<span class="plague-text">{tag}</span>' | |
else: | |
tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>' | |
# Show tag with rarity badge and cost | |
st.markdown( | |
f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})', | |
unsafe_allow_html=True | |
) | |
# Show discovery details if available | |
if source == "discovered" and "library_floor" in info and info["library_floor"]: | |
st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', | |
unsafe_allow_html=True) | |
elif source == "manual" and "description" in info and info["description"]: | |
st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', | |
unsafe_allow_html=True) | |
# Add generation button | |
button_label = "Generate" if not has_essence else "Regenerate β" | |
if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford): | |
selected_tag = tag | |
elif selected_sort == "Rarity": | |
# Group tags by rarity | |
rarity_groups = {} | |
for info in all_tags: | |
rarity = info["rarity"] | |
if rarity not in rarity_groups: | |
rarity_groups[rarity] = [] | |
rarity_groups[rarity].append(info) | |
# Get ordered rarities (rarest first) | |
ordered_rarities = list(RARITY_LEVELS.keys()) | |
ordered_rarities.reverse() # Reverse to show rarest first | |
# Add any rarities not in RARITY_LEVELS | |
for rarity in rarity_groups.keys(): | |
if rarity not in ordered_rarities: | |
ordered_rarities.append(rarity) | |
# Display tags by rarity | |
for rarity in ordered_rarities: | |
if rarity in rarity_groups: | |
tags = rarity_groups[rarity] | |
color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") | |
# Add special styling for rare rarities | |
rarity_html = f"<span style='color:{color};font-weight:bold;'>{rarity.capitalize()}</span>" | |
if rarity == "Impuritas Civitas": | |
rarity_html = f"<span style='animation:rainbow-text 4s linear infinite;font-weight:bold;'>{rarity.capitalize()}</span>" | |
elif rarity == "Star of the City": | |
rarity_html = f"<span style='color:{color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity.capitalize()}</span>" | |
elif rarity == "Urban Nightmare": | |
rarity_html = f"<span style='color:{color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity.capitalize()}</span>" | |
# First create the title with HTML, then use it in the expander | |
st.markdown(f"### {rarity_html} ({len(tags)} tags)", unsafe_allow_html=True) | |
with st.expander("Show/Hide", expanded=rarity in ["Impuritas Civitas", "Star of the City"]): | |
# Create grid layout for tags | |
cols = st.columns(3) | |
for i, info in enumerate(sorted(tags, key=lambda x: x["tag"])): | |
with cols[i % 3]: | |
tag = info["tag"] | |
source = info["source"] | |
# Check if this tag has an essence already | |
has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences | |
# Get cost for this tag | |
cost = get_essence_cost(rarity) | |
can_afford = st.session_state.enkephalin >= cost | |
# Show tag with cost | |
st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})") | |
# Show discovery details if available | |
if source == "discovered" and "library_floor" in info and info["library_floor"]: | |
st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', | |
unsafe_allow_html=True) | |
elif source == "manual" and "description" in info and info["description"]: | |
st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', | |
unsafe_allow_html=True) | |
# Add generation button | |
button_label = "Generate" if not has_essence else "Regenerate β" | |
if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford): | |
selected_tag = tag | |
elif selected_sort == "Discovery Time": | |
# Filter to just discovered tags (manual tags don't have discovery time) | |
discovered_tags = [info for info in all_tags if info["source"] == "discovered" and "discovery_time" in info] | |
# Sort all tags by discovery time (newest first) | |
sorted_tags = sorted(discovered_tags, key=lambda x: x["discovery_time"], reverse=True) | |
# Group by date | |
date_groups = {} | |
for info in sorted_tags: | |
time_str = info["discovery_time"] | |
# Extract just the date part if timestamp has date and time | |
date = time_str.split()[0] if " " in time_str else time_str | |
if date not in date_groups: | |
date_groups[date] = [] | |
date_groups[date].append(info) | |
# Display tags grouped by discovery date | |
for date, tags in date_groups.items(): | |
date_display = date if date else "Unknown date" | |
st.markdown(f"### Discovered on {date_display} ({len(tags)} tags)") | |
with st.expander("Show/Hide", expanded=date == list(date_groups.keys())[0]): # Expand most recent by default | |
# Create grid layout for tags | |
cols = st.columns(3) | |
for i, info in enumerate(tags): | |
with cols[i % 3]: | |
tag = info["tag"] | |
rarity = info["rarity"] | |
# Get rarity color | |
rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") | |
# Check if this tag has an essence already | |
has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences | |
# Get cost for this tag | |
cost = get_essence_cost(rarity) | |
can_afford = st.session_state.enkephalin >= cost | |
# Format tag display with special styling | |
if rarity == "Impuritas Civitas": | |
tag_display = f'<span class="impuritas-text">{tag}</span>' | |
elif rarity == "Star of the City": | |
tag_display = f'<span class="star-text">{tag}</span>' | |
elif rarity == "Urban Nightmare": | |
tag_display = f'<span class="nightmare-text">{tag}</span>' | |
elif rarity == "Urban Plague": | |
tag_display = f'<span class="plague-text">{tag}</span>' | |
else: | |
tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>' | |
# Show tag with rarity badge and cost | |
st.markdown( | |
f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})', | |
unsafe_allow_html=True | |
) | |
# Show discovery details | |
if "library_floor" in info and info["library_floor"]: | |
st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', | |
unsafe_allow_html=True) | |
# Add generation button | |
button_label = "Generate" if not has_essence else "Regenerate β" | |
if st.button(button_label, key=f"gen_{tag}_disc", disabled=not model_available or not can_afford): | |
selected_tag = tag | |
# Show manual tags separately if we have any | |
manual_tags = [info for info in all_tags if info["source"] == "manual"] | |
if manual_tags: | |
st.markdown("### Manual Tags") | |
with st.expander("Show/Hide"): | |
# Create grid layout for tags | |
cols = st.columns(3) | |
for i, info in enumerate(manual_tags): | |
with cols[i % 3]: | |
tag = info["tag"] | |
rarity = info["rarity"] | |
# Get rarity color | |
rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA") | |
# Check if this tag has an essence already | |
has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences | |
# Get cost for this tag | |
cost = get_essence_cost(rarity) | |
can_afford = st.session_state.enkephalin >= cost | |
# Show tag with rarity badge and cost | |
st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})") | |
# Show description if available | |
if "description" in info and info["description"]: | |
st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', | |
unsafe_allow_html=True) | |
# Add generation button | |
button_label = "Generate" if not has_essence else "Regenerate β" | |
if st.button(button_label, key=f"gen_{tag}_manual", disabled=not model_available or not can_afford): | |
selected_tag = tag | |
return selected_tag |