#!/usr/bin/env python3
"""
Vision Transformer Essence Generator for Tag Collector Game
Based on "What do Vision Transformers Learn? A Visual Exploration"
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.transforms.functional import to_pil_image
from PIL import Image
import numpy as np
import os
import re
import math
import json
import timm
import streamlit as st
from tqdm import tqdm
from scipy.ndimage import gaussian_filter
from functools import wraps, lru_cache
from safetensors.torch import load_file
import time
import tag_storage  # Import for saving game state

from game_constants import RARITY_LEVELS, ENKEPHALIN_CURRENCY_NAME, ENKEPHALIN_ICON
from tag_categories import TAG_CATEGORIES

os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Define essence quality levels with thresholds and styles
ESSENCE_QUALITY_LEVELS = {
    "ZAYIN": {"threshold": 0.0, "color": "#1CFC00", "description": "Basic representation with minimal details."},
    "TETH": {"threshold": 3.0, "color": "#389DDF", "description": "Clear representation with recognizable features."},
    "HE": {"threshold": 5.0, "color": "#FEF900", "description": "Refined representation with distinctive elements."},
    "WAW": {"threshold": 10.0, "color": "#7930F1", "description": "Advanced representation with precise details."},
    "ALEPH": {"threshold": 12.0, "color": "#FF0000", "description": "Perfect representation with extraordinary precision."}
}

# Essence generation costs in enkephalin based on tag rarity
ESSENCE_COSTS = {
    "Special": 0,
    "Canard": 100,
    "Urban Myth": 125,
    "Urban Legend": 150,
    "Urban Plague": 200,
    "Urban Nightmare": 250,
    "Star of the City": 300,
    "Impuritas Civitas": 400
}

# Default essence generation settings
DEFAULT_ESSENCE_SETTINGS = {
    "iterations": 256,
    "lr": 0.05,
    "ensemble_k": 8,
    "neighbor_count": 8,
    "image_size": 512,
    "layer_emphasis": "balanced",
    "tv_weight": 1e-3
}

def initialize_essence_settings():
    """Initialize essence generator settings if not already present"""
    if 'essence_custom_settings' not in st.session_state:
        # Try to load from storage first
        loaded_state = tag_storage.load_essence_state()
        
        if loaded_state and 'essence_custom_settings' in loaded_state:
            old_settings = loaded_state['essence_custom_settings']
            # Validate and merge with current defaults
            new_settings = DEFAULT_ESSENCE_SETTINGS.copy()
            
            # Only keep valid settings that exist in current defaults
            for key in DEFAULT_ESSENCE_SETTINGS.keys():
                if key in old_settings:
                    # Validate layer_emphasis values
                    if key == 'layer_emphasis' and old_settings[key] not in ['balanced', 'early', 'mid', 'late']:
                        continue  # Use default
                    new_settings[key] = old_settings[key]
            
            st.session_state.essence_custom_settings = new_settings
        else:
            st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()

def initialize_manual_tags():
    """Initialize manual tags if not already present"""
    if 'manual_tags' not in st.session_state:
        # Try to load from storage first
        loaded_state = tag_storage.load_essence_state()
        
        if loaded_state and 'manual_tags' in loaded_state:
            st.session_state.manual_tags = loaded_state['manual_tags']
        else:
            st.session_state.manual_tags = {
                "hatsune_miku": {"rarity": "Special", "description": "Popular virtual singer with long teal twin-tails"},
            }

def timeout(seconds, fallback_value=None):
    """Simple timeout utility for functions."""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            start_time = time.time()
            result = func(*args, **kwargs)
            elapsed = time.time() - start_time
            
            if elapsed > seconds:
                print(f"WARNING: Function {func.__name__} took {elapsed:.2f} seconds (expected max {seconds}s)")
                
            return result
        return wrapper
    return decorator
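
# Hypothetical usage sketch for the timing decorator above ("slow_inference" is
# an illustrative name, not a function defined in this module):
#
#   @timeout(30)
#   def slow_inference(batch):
#       ...  # any call expected to finish within roughly 30 seconds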

class TaggerTorch(nn.Module):
    def __init__(self, backbone_name="vit_base_patch16_384", img_size=512, num_tags=70527, normalize=True):
        super().__init__()
        # num_classes=0 -> return features; we add our own head
        self.backbone = timm.create_model(backbone_name, pretrained=False, num_classes=0, img_size=img_size)
        in_features = self.backbone.num_features  # 768 for vit_base_patch16_384
        self.head = nn.Linear(in_features, num_tags)

        # Most ViT taggers expect ImageNet normalization; keep it configurable
        self.normalize = normalize
        if self.normalize:
            self.register_buffer("mean", torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1))
            self.register_buffer("std",  torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1))

    def forward(self, x):
        if self.normalize:
            x = (x - self.mean) / self.std
        feats = self.backbone.forward_features(x)   # [B, C] or [B, tokens, C]
        if feats.ndim == 3:                          # if tokens, take CLS
            feats = feats[:, 0, :]
        return self.head(feats)

def _remap_backbone_keys(sd):
    out = {}
    for k, v in sd.items():
        if k.startswith("module."): k = k[7:]

        # collapse ImageTagger → TaggerTorch vit paths
        if k.startswith("backbone.vit."):
            k = "backbone." + k[len("backbone.vit."):]
        elif k.startswith("vit."):
            k = "backbone." + k[len("vit."):]
        elif k.startswith(("pos_embed","patch_embed.","blocks.","norm.","cls_token")):
            k = "backbone." + k

        out[k] = v
    return out

def _get_logits_from_output(out):
    if isinstance(out, dict):
        # avoid `x or y` on tensors: truth-testing a multi-element tensor raises an error
        logits = out.get("refined_predictions")
        return logits if logits is not None else out.get("initial_predictions")
    return out

def build_torch_model_from_safetensors(ckpt_path, num_tags, backbone="vit_base_patch16_384", img_size=512):
    model = TaggerTorch(backbone_name=backbone, img_size=img_size, num_tags=num_tags, normalize=True)
    sd = load_file(ckpt_path)
    sd = _remap_backbone_keys(sd)

    # pull out tag embedding/bias if present and later copy into the linear head
    te_w = sd.pop("tag_embedding.weight", sd.pop("module.tag_embedding.weight", None))
    te_b = sd.pop("tag_bias", sd.pop("module.tag_bias", None))

    # load backbone etc.
    missing, unexpected = model.load_state_dict(sd, strict=False)
    print("[load] missing:", missing[:20], "…")
    print("[load] unexpected:", unexpected[:20], "…")

    # copy tag embedding → head
    with torch.no_grad():
        if te_w is not None and te_w.shape == model.head.weight.shape:
            model.head.weight.copy_(te_w)
            print("[load] copied tag_embedding.weight → head.weight")
        if te_b is not None and model.head.bias is not None and te_b.shape == model.head.bias.shape:
            model.head.bias.copy_(te_b)
            print("[load] copied tag_bias → head.bias")

    return model
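
# Illustrative usage sketch (the checkpoint path and sizes below are assumptions;
# at runtime the real values are read from camie-tagger-v2-metadata.json in
# generate_essence_for_tag):
#
#   model = build_torch_model_from_safetensors("camie-tagger-v2.safetensors",
#                                              num_tags=70527, img_size=512)
#   model.eval()
#   with torch.no_grad():
#       logits = model(torch.rand(1, 3, 512, 512))  # [1, num_tags], inputs in [0, 1]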

@torch.no_grad()
def _get_classifier_matrix(model):
    # [T, C]; works for both ImageTagger and TaggerTorch
    if hasattr(model, "tag_embedding"):
        return model.tag_embedding.weight.detach()
    if hasattr(model, "head") and hasattr(model.head, "weight"):
        return model.head.weight.detach()
    raise AttributeError("Model has neither tag_embedding nor head.weight")

@torch.no_grad()
def neighbor_sets_from_embedding(model, class_idx, k_pos=8, k_neg=8):
    """
    Returns (pos_idx, pos_sims, neg_idx, neg_sims)
    pos: highest cosine neighbors (exclude self)
    neg: lowest cosine neighbors (most dissimilar)
    """
    W = _get_classifier_matrix(model)               # [T, C]
    Wn = F.normalize(W, dim=1)
    q  = Wn[class_idx:class_idx+1]                  # [1, C]
    sims = (q @ Wn.T).squeeze(0)                    # [T]
    sims[class_idx] = -9e9                          # mask self

    # positives: largest similarities
    pos_vals, pos_idx = torch.topk(sims, k=min(k_pos, sims.numel()-1))
    # negatives: smallest similarities (most negative / least similar)
    neg_vals, neg_idx = torch.topk(-sims, k=min(k_neg, sims.numel()-1))
    neg_vals = -neg_vals

    # weights (clip for stability)
    pos_w = torch.clamp(pos_vals, 0.0, 1.0).tolist()
    neg_w = torch.clamp(neg_vals.abs(), 0.0, 1.0).tolist()
    return pos_idx.tolist(), pos_w, neg_idx.tolist(), neg_w

def weighted_class_objective(logits, main_idx,
                             plus_idxs=(), plus_w=None, alpha=0.25,
                             minus_idxs=(), minus_w=None, beta=0.15):
    score = logits[:, main_idx].mean()
    if plus_idxs:
        w = torch.tensor(plus_w or [1.0]*len(plus_idxs), device=logits.device, dtype=logits.dtype)
        w = w / (w.sum() + 1e-8)
        score = score + alpha * (logits[:, plus_idxs] * w).sum(dim=1).mean()
    if minus_idxs:
        w = torch.tensor(minus_w or [1.0]*len(minus_idxs), device=logits.device, dtype=logits.dtype)
        w = w / (w.sum() + 1e-8)
        score = score - beta * (logits[:, minus_idxs] * w).sum(dim=1).mean()
    return score
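
# Illustrative sketch of how the two helpers above combine (this mirrors the call
# in ViTEssenceGenerator.generate_essence; `model`, `logits`, and `tag_idx` are
# placeholders):
#
#   pos_idx, pos_w, neg_idx, neg_w = neighbor_sets_from_embedding(model, tag_idx, k_pos=8, k_neg=8)
#   score = weighted_class_objective(logits, tag_idx,
#                                     plus_idxs=pos_idx, plus_w=pos_w, alpha=0.25,
#                                     minus_idxs=neg_idx, minus_w=neg_w, beta=0.15)
#   loss = -score  # the optimizer maximizes the weighted class score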

def idx_to_name(idx, dataset=None):
    if dataset is not None and hasattr(dataset, "idx_to_tag"):
        return dataset.idx_to_tag.get(int(idx), f"Tag {idx}")
    # fallback to your cached JSON
    meta = _load_tagger_metadata_cached()
    return meta.get("dataset_info",{}).get("tag_mapping",{}).get("idx_to_tag",{}).get(str(int(idx)), f"Tag {idx}")

# Core Classes for ViT Essence Generation
class ViTLayerHook:
    """Hook for capturing ViT feed-forward layer activations."""
    def __init__(self, layer, layer_name):
        self.layer = layer
        self.layer_name = layer_name
        self.features = None
        self.hook = layer.register_forward_hook(self.hook_fn)
    
    def hook_fn(self, module, input, output):
        """Store the output activations."""
        self.features = output
    
    def close(self):
        self.hook.remove()

class ViTFeatureAnalyzer:
    """Analyzes ViT architecture to find optimal layers for visualization."""
    
    def __init__(self, model):
        self.model = model
        self.layer_info = self._analyze_architecture()
    
    def _analyze_architecture(self):
        """Analyze the ViT architecture and identify feed-forward layers."""
        layer_info = {}
        
        def traverse_modules(module, prefix=''):
            for name, child in module.named_children():
                full_name = f"{prefix}.{name}" if prefix else name
                
                # Look for transformer blocks and their MLP components
                if 'mlp' in full_name.lower() and (hasattr(child, 'act') or 'act' in dict(child.named_children())):
                    # prefer the actual activation submodule
                    act = getattr(child, 'act', None)
                    if act is not None:
                        layer_info[full_name + ".act"] = {
                            'type': 'mlp_activation',
                            'module': act,
                            'block_idx': self._extract_block_number(full_name)
                        }
                    else:
                        # fallback: search by name
                        for n2, c2 in child.named_children():
                            if 'act' in n2.lower():
                                layer_info[full_name + f".{n2}"] = {
                                    'type': 'mlp_activation',
                                    'module': c2,
                                    'block_idx': self._extract_block_number(full_name)
                                }
                elif 'gelu' in str(type(child)).lower() or 'activation' in name.lower():
                    # Direct activation layers (GELU, etc.) inside MLP/FFN blocks
                    if 'mlp' in prefix.lower() or 'ffn' in prefix.lower():
                        layer_info[full_name] = {
                            'type': 'activation',
                            'module': child,
                            'block_idx': self._extract_block_number(full_name)
                        }
                
                # Recurse into children
                traverse_modules(child, full_name)
        
        traverse_modules(self.model)
        return layer_info
    
    def _extract_block_number(self, layer_name):
        """Extract block/layer number from layer name."""
        import re
        numbers = re.findall(r'\.(\d+)\.', layer_name)
        if numbers:
            return int(numbers[0])
        return 0
    
    def get_visualization_layers(self, layer_emphasis="balanced"):
        """Get the best layers for visualization based on emphasis."""
        if not self.layer_info:
            print("Warning: No suitable ViT layers found for visualization")
            return []
        
        # Sort MLP activation layers by block index
        sorted_layers = sorted(
            [(n, info) for n, info in self.layer_info.items() if 'mlp' in n.lower() and 'act' in n.lower()],
            key=lambda x: x[1]['block_idx']
        )
        
        if not sorted_layers:
            print("Warning: No MLP activation layers found for visualization")
            return []
        
        total_blocks = max(info['block_idx'] for _, info in sorted_layers) + 1
        
        if layer_emphasis == "early":
            # Focus on first 1/3 of blocks
            target_blocks = list(range(0, max(1, total_blocks // 3)))
        elif layer_emphasis == "mid":
            # Focus on middle 1/3 of blocks
            start = total_blocks // 3
            end = 2 * total_blocks // 3
            target_blocks = list(range(start, max(start + 1, end)))
        elif layer_emphasis == "late":
            # Focus on last 1/3 of blocks
            start = 2 * total_blocks // 3
            target_blocks = list(range(start, total_blocks))
        else:  # balanced
            # Sample across all blocks
            if total_blocks <= 3:
                target_blocks = list(range(total_blocks))
            else:
                target_blocks = [0, total_blocks // 2, total_blocks - 1]
        
        # Select layers from target blocks
        selected_layers = []
        for layer_name, info in sorted_layers:
            if info['block_idx'] in target_blocks:
                selected_layers.append(layer_name)
        
        return selected_layers

def _jitter_reflect_crop(x, pad=16):
    b, c, h, w = x.shape
    padded = F.pad(x, (pad, pad, pad, pad), mode='reflect').contiguous()
    off_h = torch.randint(0, 2 * pad + 1, (b,), device=x.device)
    off_w = torch.randint(0, 2 * pad + 1, (b,), device=x.device)
    crops = []
    for i in range(b):
        hs, ws = int(off_h[i]), int(off_w[i])
        crop = padded[i:i+1, :, hs:hs+h, ws:ws+w].contiguous()
        crops.append(crop)
    return torch.cat(crops, 0).contiguous()

def _channel_affine(x):
    # per-channel affine: σ ~ exp(U[-1,1]), μ ~ U[-1,1]
    b, c, _, _ = x.shape
    mu = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0)
    log_sigma = torch.empty(b, c, 1, 1, device=x.device, dtype=x.dtype).uniform_(-1.0, 1.0)
    sigma = torch.exp(log_sigma)
    return (x * sigma + mu)

def _add_gaussian_noise(x, std=0.15):
    return (x + torch.randn_like(x) * std)

def _augment_once(x, noise_std=0.15):
    z = _jitter_reflect_crop(x)
    z = _channel_affine(z)
    z = _add_gaussian_noise(z, std=noise_std)
    return z

def _augment_batch(x, K=8, noise_std=0.15):
    augs = []
    for _ in range(K):
        z = _augment_once(x, noise_std=noise_std)
        augs.append(z)
    return torch.cat(augs, dim=0).contiguous()
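
# Shape note: with `x` of shape [1, 3, H, W] and K copies, _augment_batch returns
# [K, 3, H, W]; each copy gets its own reflect-pad crop, per-channel affine, and
# Gaussian-noise draw, and the class score is later averaged over this ensemble.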

class ViTEssenceGenerator:
    """
    ViT Essence Generator based on the methodology from 
    'What do Vision Transformers Learn? A Visual Exploration'
    """
    
    def __init__(
        self,
        model,
        tag_to_name=None,
        iterations=500,
        learning_rate=0.05,
        layer_emphasis="balanced",
        ensemble_K=8,
        tv_weight=1e-3
    ):
        """Initialize the ViT Essence Generator"""
        self.model = model
        self.tag_to_name = tag_to_name
        self.iterations = iterations
        self.lr = learning_rate
        self.layer_emphasis = layer_emphasis
        self.ensemble_K = ensemble_K
        self.tv_weight = tv_weight
                    
        # Set device
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.eval().to(self.device)

        self.imagenet_mean = torch.tensor([0.485, 0.456, 0.406], device=self.device).view(1,3,1,1)
        self.imagenet_std  = torch.tensor([0.229, 0.224, 0.225], device=self.device).view(1,3,1,1)
        self.expect_imagenet = not (hasattr(self.model, "normalize") and getattr(self.model, "normalize") is True)
        
        # Analyze ViT architecture
        self.analyzer = ViTFeatureAnalyzer(self.model)
        
        # Initialize hooks
        self.hooks = {}
        self.selected_layers = []
        
        print(f"ViT Essence Generator initialized on {self.device}")
    
    def _preprocess(self, x):
        return (x - self.imagenet_mean) / self.imagenet_std if self.expect_imagenet else x
    
    def setup_hooks(self, tag_idx):
        """Setup hooks for multi-layer visualization."""
        self.close_hooks()
        
        names = self.analyzer.get_visualization_layers(self.layer_emphasis)

        if not names:
            print("Warning: No suitable layers found for visualization")
            return {}

        print(f"Setting up hooks on {len(names)} ViT layer(s)")
        layer_weights = {}
        for i, layer_name in enumerate(names):
            try:
                layer_info = self.analyzer.layer_info[layer_name]
                layer_module = layer_info['module']
                self.hooks[layer_name] = ViTLayerHook(layer_module, layer_name)

                weight = 0.3 + 0.7 * (i / max(1, len(names) - 1))
                layer_weights[layer_name] = weight
                print(f"  - {layer_name} (block {layer_info['block_idx']}, weight: {weight:.2f})")
            except Exception as e:
                print(f"Failed to setup hook for {layer_name}: {e}")

        self.selected_layers = names
        return layer_weights

    def close_hooks(self):
        """Clean up hooks to avoid memory leaks."""
        for hook in self.hooks.values():
            hook.close()
        self.hooks.clear()

    def _fourier_init(self, size=224, decay=1.5):
        H = W = size
        # complex spectrum (rFFT domain)
        spec = torch.randn(1, 3, H, W//2 + 1, dtype=torch.complex64, device=self.device)
        fy = torch.fft.fftfreq(H, device=self.device).abs().view(H, 1)
        fx = torch.fft.rfftfreq(W, device=self.device).abs().view(1, W//2 + 1)
        radius = (fy**2 + fx**2).sqrt().clamp_(min=1e-6)
        spec = spec * (1.0 / (radius ** decay))  # 1/f^decay
        img = torch.fft.irfft2(spec, s=(H, W))   # [1,3,H,W], roughly zero-mean
        # scale to [0,1]
        img = (img - img.min(dim=-1, keepdim=True)[0].min(dim=-2, keepdim=True)[0])
        img = img / (img.max(dim=-1, keepdim=True)[0].max(dim=-2, keepdim=True)[0] + 1e-8)
        return img
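
    # The 1/f^decay spectrum above biases the initial image toward low spatial
    # frequencies, a common feature-visualization trick that gives smoother
    # starting points than uniform pixel noise.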

    def create_optimizable_image(self, size=224, use_fourier=True):
        if use_fourier:
            with torch.no_grad():
                image = self._fourier_init(size)
            image = image.to(self.device)
        else:
            image = torch.rand(1, 3, size, size, device=self.device)
        image = image.detach().contiguous().requires_grad_(True)
        return image

    def total_variation_loss(self, image):
        # image: [B,3,H,W]
        diff_y = torch.abs(image[:, :, 1:, :] - image[:, :, :-1, :])
        diff_x = torch.abs(image[:, :, :, 1:] - image[:, :, :, :-1])
        tv_per_sample = diff_y.mean(dim=(1,2,3)) + diff_x.mean(dim=(1,2,3))  # [B]
        return tv_per_sample.mean()
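
    # Per sample, this is the anisotropic total-variation penalty
    #   TV(x) = mean|x[h+1, w] - x[h, w]| + mean|x[h, w+1] - x[h, w]|
    # averaged over the batch; it discourages high-frequency noise in the image.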
    
    def get_feature_activations(self, layer_weights, topk_channels=None):
        total = 0.0
        for name, hook in self.hooks.items():
            feats = hook.features  # [B, tokens, C] from GELU
            if feats is None:
                continue
            w = layer_weights.get(name, 0.5)
            # aggregate: sum over tokens; then (optionally) top-k over channels
            agg = feats.sum(dim=1)  # [B, C]
            if topk_channels is not None and topk_channels > 0 and agg.shape[1] > topk_channels:
                # take the mean of top-k channels for stability
                vals, _ = torch.topk(agg, k=topk_channels, dim=1)
                act = vals.mean()
            else:
                act = agg.mean()
            total = total + w * act
        return total
    
    def generate_essence(self, tag_idx, neighbor_count=8, image_size=224, return_score=True, progress_callback=None):
        """Generate an essence visualization for a ViT model."""
        # Get tag name for logging
        tag_name = self.tag_to_name.get(tag_idx, f"Tag {tag_idx}") if self.tag_to_name else f"Tag {tag_idx}"
        print(f"Generating ViT essence for '{tag_name}' (index: {tag_idx})...")
        
        # Setup hooks for this tag
        layer_weights = self.setup_hooks(tag_idx)
        
        if not self.hooks and not hasattr(self.model, 'head'):
            print("Warning: No hooks set up and no classifier head found")
            return self._create_fallback_image(image_size), 0.0
        
        # Initialize optimizable image
        image = self.create_optimizable_image(image_size)
        
        # Create optimizer
        optimizer = torch.optim.Adam([image], lr=self.lr)
        scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            optimizer, T_max=self.iterations, eta_min=self.lr * 0.01
        )
        
        best_score = -float('inf')
        best_image = None
        
        print(f"Starting optimization for {self.iterations} iterations...")

        # Choose auxiliaries once per run
        pos_idx, pos_w, neg_idx, neg_w = neighbor_sets_from_embedding(
            self.model, tag_idx, k_pos=neighbor_count, k_neg=neighbor_count
        )

        for i in range(self.iterations):
            optimizer.zero_grad()
            
            # Clear previous activations
            for hook in self.hooks.values():
                hook.features = None
            
            # Ensemble of augmented copies of the current image, scored in one forward pass
            aug_batch = _augment_batch(image, K=self.ensemble_K, noise_std=0.15)
            out = self.model(self._preprocess(aug_batch))
            logits = _get_logits_from_output(out)  # [K, T]

            cls_term = weighted_class_objective(
                logits, main_idx=tag_idx,
                plus_idxs=pos_idx, plus_w=pos_w, alpha=0.25,
                minus_idxs=neg_idx, minus_w=neg_w, beta=0.15
            )

            # Feature regularizer from the hooked MLP activations
            feat_term = 0.0
            if self.hooks:
                feat_term = self.get_feature_activations(layer_weights, topk_channels=64)

            Ltv = self.total_variation_loss(aug_batch)
            total_loss = -(cls_term + 0.5 * feat_term) + self.tv_weight * Ltv
            
            # Backward pass
            total_loss.backward()
            
            # Skip the update on a non-finite loss so NaN/Inf gradients never
            # reach the image; perturb it slightly to escape the bad point
            if not torch.isfinite(total_loss.detach()):
                print("WARN: non-finite loss; skipping this update")
                optimizer.zero_grad(set_to_none=True)
                with torch.no_grad():
                    image.add_(0.05 * torch.randn_like(image)).clamp_(0.0, 1.0)
                scheduler.step()
                continue
            
            if image.grad is None or not torch.isfinite(image.grad).all():
                print("WARN: no/invalid grad reaching the image; check hook & loss wiring.")
            
            # Gradient clip to avoid exploding updates
            torch.nn.utils.clip_grad_norm_([image], max_norm=3.0)
            optimizer.step()
            scheduler.step()

            # Keep pixels in valid range
            with torch.no_grad():
                image.clamp_(0.0, 1.0)
            
            # Track best result - using original score calculation
            with torch.no_grad():
                score_tensor = -(total_loss - self.tv_weight * Ltv)
                current_score = float(score_tensor.item())
                
                if current_score > best_score:
                    best_score = current_score
                    best_image = image.detach().clone()
            
            # Progress reporting
            if progress_callback and i % max(1, self.iterations // 20) == 0:
                progress_callback(
                    scale_idx=0,
                    scale_count=1,
                    iter_idx=i,
                    iter_count=self.iterations,
                    score=current_score
                )
            
            # Logging
            if i % max(1, self.iterations // 10) == 0:
                print(f"Iteration {i}/{self.iterations}: Score = {current_score:.4f}")
        
        # Use best image if we found one
        if best_image is not None:
            final_image = best_image
        else:
            final_image = image.detach()
        
        # Convert to PIL image
        final_image = torch.clamp(final_image, 0, 1)
        pil_img = to_pil_image(final_image[0].cpu())
        
        # Clean up hooks
        self.close_hooks()
        
        print(f"ViT essence generation complete for '{tag_name}'. Final score: {best_score:.4f}")
        
        if return_score:
            return pil_img, best_score
        else:
            return pil_img
    
    def _create_fallback_image(self, size):
        """Create a fallback image when generation fails."""
        # Create a simple noise pattern
        image = torch.randn(1, 3, size, size) * 0.5 + 0.5
        image = torch.clamp(image, 0, 1)
        return to_pil_image(image[0])

# Utility Functions
def get_quality_level(score):
    """Determine the quality level of an essence based on its score"""
    for level in reversed(list(ESSENCE_QUALITY_LEVELS.keys())):
        if score >= ESSENCE_QUALITY_LEVELS[level]["threshold"]:
            return level
    return "ZAYIN"  # Default to lowest level

def get_essence_cost(rarity):
    """Calculate the cost to generate an essence image based on tag rarity"""
    return ESSENCE_COSTS.get(rarity, 100)  # Default to 100 if rarity unknown

def save_essence_to_game_folder(image, tag, score, quality_level):
    """Save the generated essence image to a persistent game folder"""
    # Create game folder paths with better structure
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    game_data_dir = os.path.join(base_dir, "game_data")
    essence_folder = os.path.join(game_data_dir, "essences")
    
    # Make sure all parent directories exist
    os.makedirs(game_data_dir, exist_ok=True)
    os.makedirs(essence_folder, exist_ok=True)
    
    # Organize essences by quality level for easier browsing
    quality_folder = os.path.join(essence_folder, quality_level)
    os.makedirs(quality_folder, exist_ok=True)
    
    # Create filename with more details and better organization
    safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
    timestamp = time.strftime("%Y%m%d_%H%M%S")
    filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
    filepath = os.path.join(quality_folder, filename)
    
    # Save the image
    image.save(filepath)
    
    print(f"Saved ViT essence to: {filepath}")
    return filepath

def load_tagger_metadata():
    """Load the camie-tagger-v2-metadata.json file from parent directory."""
    try:
        # Look for metadata file in parent directory
        current_dir = os.path.dirname(os.path.abspath(__file__))
        parent_dir = os.path.dirname(current_dir)
        metadata_path = os.path.join(parent_dir, "camie-tagger-v2-metadata.json")
        
        if os.path.exists(metadata_path):
            with open(metadata_path, 'r', encoding='utf-8') as f:
                metadata = json.load(f)
            print(f"Loaded tagger metadata from: {metadata_path}")
            return metadata
        else:
            print(f"Metadata file not found at: {metadata_path}")
            return None
    except Exception as e:
        print(f"Error loading tagger metadata: {e}")
        return None
    
@lru_cache(maxsize=1)
def _load_tagger_metadata_cached():
    meta = load_tagger_metadata()
    return meta or {}

def resolve_tag_index(tag, dataset=None):
    """Robustly resolve tag -> index using dataset, session metadata, then camie-tagger-v2-metadata.json."""
    if not isinstance(tag, str):
        return int(tag)

    # normalize variants
    cands = {tag.strip(), tag.strip().replace(" ", "_")}
    cands |= {c.lower() for c in list(cands)}

    # 1) dataset.tag_to_idx
    if dataset is not None and hasattr(dataset, "tag_to_idx"):
        for c in cands:
            if c in dataset.tag_to_idx:
                return int(dataset.tag_to_idx[c])

    # 2) session metadata
    sm = getattr(st.session_state, "metadata", {}) or {}
    m = sm.get("tag_to_idx", {}) if isinstance(sm, dict) else {}
    for c in cands:
        if c in m:
            return int(m[c])

    # 3) JSON metadata (cached)
    meta = _load_tagger_metadata_cached()
    mjson = (meta.get("dataset_info", {})
                .get("tag_mapping", {})
                .get("tag_to_idx", {})) if isinstance(meta, dict) else {}
    for c in cands:
        if c in mjson:
            return int(mjson[c])

    return None

def generate_essence_for_tag(tag, model, dataset, custom_settings=None):
    """
    Generate an essence image for a specific tag using the ViT generator
    
    Args:
        tag: The tag name or index
        model: The ViT model to use
        dataset: The dataset containing tag information
        custom_settings: Optional dictionary with custom generation settings
        
    Returns:
        PIL Image of the generated essence, score, quality level
    """
    
    print(f"\n=== Starting ViT essence generation for tag '{tag}' ===")
    
    # Check if tag is discovered or a manual tag
    is_manual_tag = hasattr(st.session_state, 'manual_tags') and tag in st.session_state.manual_tags
    is_discovered = hasattr(st.session_state, 'discovered_tags') and tag in st.session_state.discovered_tags
    
    if not is_discovered and not is_manual_tag:
        st.error(f"Tag '{tag}' has not been discovered yet.")
        return None, 0, None
    
    # Get tag rarity and calculate cost
    if is_discovered:
        rarity = st.session_state.discovered_tags[tag].get("rarity", "Canard")
    elif is_manual_tag:
        rarity = st.session_state.manual_tags[tag].get("rarity", "Canard")
    else:
        rarity = "Canard"
    
    # Calculate cost based on rarity
    cost = get_essence_cost(rarity)
    
    # Check if player has enough Enkephalin
    if st.session_state.enkephalin < cost:
        st.error(f"Not enough {ENKEPHALIN_CURRENCY_NAME} to generate this essence. You need {cost} {ENKEPHALIN_ICON} but have {st.session_state.enkephalin} {ENKEPHALIN_ICON}.")
        return None, 0, None
    
    # Use provided settings or defaults
    settings = custom_settings or DEFAULT_ESSENCE_SETTINGS.copy()
    print(f"Using settings: {settings}")
    
    # UI containers for progress
    preview_container = st.empty()
    progress_container = st.empty()
    message_container = st.empty()
    
    try:
        message_container.info(f"Generating ViT essence for '{tag}' with {settings.get('layer_emphasis', 'balanced')} layer emphasis...")
        
        # Progress callback function
        def progress_callback(scale_idx, scale_count, iter_idx, iter_count, score):
            progress = iter_idx / iter_count
            progress_container.progress(progress, f"Iteration {iter_idx}/{iter_count}")
            message_container.info(f"Current score: {score:.4f}")
            
            if iter_idx % 50 == 0:
                print(f"Progress: Iteration {iter_idx}/{iter_count}, Score: {score:.4f}")
        
        # Convert tag name to index
        tag_idx = None
        
        if isinstance(tag, str):
            tag_idx = resolve_tag_index(tag, dataset)
            if tag_idx is None:
                st.error(
                    f"Tag '{tag}' index not found in dataset or metadata. "
                    f"Make sure it exists in camie-tagger-v2-metadata.json."
                )
                return None, 0, None
        else:
            tag_idx = int(tag)

        print(f"Resolved tag '{tag}' -> index {tag_idx}")
                
        # Create tag-to-name mapping
        tag_to_name = {tag_idx: tag}

        # Get or create Torch model specifically for essence generation
        torch_model = getattr(st.session_state, "model_torch", None)
        if not isinstance(torch_model, nn.Module):
            # safetensors lives ONE directory above this file
            current_dir = os.path.dirname(os.path.abspath(__file__))
            parent_dir = os.path.dirname(current_dir)
            ckpt = os.path.join(parent_dir, "camie-tagger-v2.safetensors")
            if not os.path.exists(ckpt):
                st.error(f"Missing safetensors checkpoint at: {ckpt}")
                return None, 0, None

            # metadata-driven sizes
            meta = _load_tagger_metadata_cached()
            num_tags = int(meta.get("dataset_info", {}).get("total_tags", 70527))
            img_size = int(meta.get("model_info", {}).get("img_size", 512))

            torch_model = build_torch_model_from_safetensors(
                ckpt_path=ckpt,
                num_tags=num_tags,
                backbone="vit_base_patch16_384",
                img_size=img_size
            )
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
            torch_model = torch_model.to(device).eval()
            st.session_state.model_torch = torch_model  # cache for later
        
        # Create ViT essence generator with settings from UI
        generator = ViTEssenceGenerator(
            model=torch_model,
            tag_to_name=tag_to_name,
            iterations=settings.get('iterations', 500),
            learning_rate=settings.get('lr', 0.05),
            layer_emphasis=settings.get('layer_emphasis', 'balanced'),
            ensemble_K=settings.get('ensemble_k', 8),
            tv_weight=settings.get('tv_weight', 1e-3)
        )
        
        image, score = generator.generate_essence(
            tag_idx=tag_idx, 
            neighbor_count=settings.get('neighbor_count', 8),
            image_size=settings.get('image_size', 512), 
            return_score=True,
            progress_callback=progress_callback
        )
        
        # Determine quality level
        quality_level = get_quality_level(score)
        
        # Deduct enkephalin cost
        st.session_state.enkephalin -= cost
        st.session_state.game_stats["enkephalin_spent"] = st.session_state.game_stats.get("enkephalin_spent", 0) + cost
        
        # Increment essence counter
        st.session_state.game_stats["essences_generated"] = st.session_state.game_stats.get("essences_generated", 0) + 1
        
        # Save to persistent location
        filepath = save_essence_to_game_folder(image, tag, score, quality_level)
        
        # Update UI with result
        preview_container.image(image, caption=f"ViT Essence of '{tag}' - Quality: {quality_level}", width=400)
        
        # Clear progress elements
        progress_container.empty()
        message_container.empty()
        
        # Store in session state
        if 'generated_essences' not in st.session_state:
            st.session_state.generated_essences = {}
        
        st.session_state.generated_essences[tag] = {
            "path": filepath,
            "score": score,
            "quality": quality_level,
            "rarity": rarity,
            "settings": settings,
            "generated_time": time.strftime("%Y-%m-%d %H:%M:%S")
        }
        
        # Show success message
        st.success(f"Successfully generated {quality_level} ViT essence for '{tag}' with score {score:.4f}! Spent {cost} {ENKEPHALIN_ICON}")
        print(f"=== ViT essence generation complete for '{tag}' ===\n")

        # Save state
        tag_storage.save_essence_state(session_state=st.session_state)
        
        return image, score, quality_level
    
    except Exception as e:
        st.error(f"Error generating ViT essence: {str(e)}")
        print(f"EXCEPTION in generate_essence_for_tag: {str(e)}")
        import traceback
        err_traceback = traceback.format_exc()
        print(err_traceback)
        st.code(err_traceback)
        return None, 0, None

# Utility Functions for Model Analysis and Layer Selection
def get_model_layers(model):
    """Utility function to get all available layers in a model."""
    layers = []
    for name, _ in model.named_modules():
        if name:  # Skip empty name (the model itself)
            layers.append(name)
    return layers

def get_key_layers(model, max_layers=15):
    """
    Get a curated list of the most relevant layers for visualization.
    """
    all_layers = get_model_layers(model)
    
    # For models with hundreds of layers, we need to be selective
    if len(all_layers) > 30:
        # Extract patterns to identify layer types
        block_patterns = {}
        
        # Find common patterns in layer names
        for layer in all_layers:
            # Extract the main component (e.g., "backbone.features")
            parts = layer.split(".")
            if len(parts) >= 2:
                prefix = ".".join(parts[:2])
                if prefix not in block_patterns:
                    block_patterns[prefix] = []
                block_patterns[prefix].append(layer)
        
        # Now select representative layers from each major block
        key_layers = {
            "early": [],
            "middle": [],
            "late": []
        }
        
        # For each major block, select layers at strategic positions
        for prefix, layers in block_patterns.items():
            if len(layers) > 3:  # Only process significant blocks
                # Sort by natural depth (assuming numerical components indicate depth)
                layers.sort(key=lambda x: [int(s) if s.isdigit() else s for s in re.findall(r'\d+|\D+', x)])
                
                # Get layers at strategic positions
                early = layers[0]
                middle = layers[len(layers) // 2]
                late = layers[-1]
                
                key_layers["early"].append(early)
                key_layers["middle"].append(middle)
                key_layers["late"].append(late)
        
        # Ensure we don't have too many layers
        # If we need to reduce further, prioritize middle and late layers
        flattened = []
        for _, group_layers in key_layers.items():
            flattened.extend(group_layers)
        
        if len(flattened) > max_layers:
            # Calculate how many to keep from each group
            total = len(flattened)
            # Prioritize keeping late layers (for character recognition)
            late_count = min(len(key_layers["late"]), max_layers // 3)
            # Allocate remaining slots between early and middle
            remaining = max_layers - late_count
            middle_count = min(len(key_layers["middle"]), remaining // 2)
            early_count = min(len(key_layers["early"]), remaining - middle_count)
            
            # Take only the needed number from each category
            key_layers["early"] = key_layers["early"][:early_count]
            key_layers["middle"] = key_layers["middle"][:middle_count]
            key_layers["late"] = key_layers["late"][:late_count]
    else:
        # For simpler models, use standard distribution
        n = len(all_layers)
        key_layers = {
            "early": all_layers[:n//3][:3],  # First few layers
            "middle": all_layers[n//3:2*n//3][:4],  # Middle layers
            "late": all_layers[2*n//3:][:3]  # Last few layers
        }
    
    # Try to identify the classifier/final layer
    classifier_layers = [layer for layer in all_layers if any(x in layer.lower() 
                      for x in ["classifier", "fc", "linear", "output", "logits", "head"])]
    if classifier_layers:
        key_layers["classifier"] = [classifier_layers[-1]]
    
    return key_layers

def get_suggested_layers(model, layer_type="balanced"):
    """
    Get suggested layers based on the desired feature type.
    """
    key_layers = get_key_layers(model)
    
    # Flatten all layers for reference
    all_key_layers = []
    for layers in key_layers.values():
        all_key_layers.extend(layers)
    
    # Choose layers based on the requested emphasis
    if layer_type == "low":
        # Focus on early visual features (textures, patterns, colors)
        selected = key_layers.get("early", [])
        # Add one middle layer for stability
        if "middle" in key_layers and key_layers["middle"]:
            selected.append(key_layers["middle"][0])
    
    elif layer_type == "mid":
        # Focus on mid-level features (parts, components)
        selected = key_layers.get("middle", [])
        # Add one early layer for context
        if "early" in key_layers and key_layers["early"]:
            selected.append(key_layers["early"][-1])
    
    elif layer_type == "high":
        # Focus on high-level semantic features (objects, characters)
        selected = key_layers.get("late", [])
        selected.extend(key_layers.get("classifier", []))
        # Add one middle layer for context
        if "middle" in key_layers and key_layers["middle"]:
            selected.append(key_layers["middle"][-1])
    
    else:  # balanced
        # Use a mix of early, middle and late layers
        selected = []
        for category in ["early", "middle", "late", "classifier"]:
            if category in key_layers and key_layers[category]:
                # Take one from each category
                selected.append(key_layers[category][0])
                # For middle and late, also take the last one if different
                if category in ["middle", "late"] and len(key_layers[category]) > 1:
                    selected.append(key_layers[category][-1])
    
    # Ensure we have at least one layer
    if not selected and all_key_layers:
        selected = [all_key_layers[-1]]  # Use the last layer as fallback
    
    return selected

# Game UI and Integration Functions
def display_essence_generator():
    """
    Display the essence generator interface
    """
    # Initialize settings
    initialize_essence_settings()
    
    st.title("🎨 Tag Essence Generator")
    st.write("Generate visual representations of what the AI model recognizes for specific tags.")
    
    # Add detailed explanation of what essences are for
    with st.expander("What are Tag Essences & How to Use Them", expanded=True):
        st.markdown("""
        ### 💡 Understanding Tag Essences
        
        Tag Essences are visual representations of what the AI model recognizes for specific tags. They can be extremely valuable for your tag collection strategy!
        
        **How to use Tag Essences:**
        1. **Generate a high-quality essence** for a tag you want to collect more of (only available on tags discovered in the library)
        2. **Save the essence image** to your computer
        3. **Upload the essence image** back into the tagger
        4. The tagger will **almost always detect the original tag**
        5. It will often also **detect related rare tags** from the same category
        
        **Strategic Value:**
        - Character essences can help unlock other tags associated with that character
        - Category essences can help discover rare tags within that category
        - High-quality essences (WAW, ALEPH) have the strongest effect
        
        **This is why Enkephalin costs are high** - essences are powerful tools that can help you discover rare tags much more efficiently than random image scanning!
        """)
    
    
    # Check for model availability
    model_available = hasattr(st.session_state, 'model')
    if not model_available:
        st.warning("Model not available. You can browse your tags but cannot generate essences.")
    
    # Create tabs for the different sections
    tabs = st.tabs(["Generate Essence", "My Essences"])
    
    with tabs[0]:
        # Check for pending generation from previous interaction
        if hasattr(st.session_state, 'selected_tag') and st.session_state.selected_tag:
            tag = st.session_state.selected_tag
            
            st.subheader(f"Generating Essence for '{tag}'")
            
            # Generate the essence
            image, score, quality = generate_essence_for_tag(
                tag, 
                st.session_state.model, 
                st.session_state.model.dataset,
                st.session_state.essence_custom_settings
            )
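            # generate_essence_for_tag returns (image, score, quality); image is None when generation fails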
            
            # Show usage tips if successful
            if image is not None:
                with st.expander("Essence Usage", expanded=True):
                    st.markdown("""
                    💡 **Tag Essence Usage Tips:**
                    1. Look for similar patterns, colors, and elements in real images
                    2. The essence reveals what features the AI model recognizes for this tag
                    3. Use this as inspiration when creating or finding images to get this tag
                    """)
            else:
                st.error("Essence generation failed. Please check the error messages above and try again with different settings.")
            
            # Clear selected tag
            st.session_state.selected_tag = None
        else:
            # Show the interface to select a tag
            selected_tag = display_essence_generation_interface(model_available)
            
            # If a tag was selected, store it for the next run and rerun
            if selected_tag:
                st.session_state.selected_tag = selected_tag
                st.rerun()
    
    with tabs[1]:
        display_saved_essences()

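# Essence images are expected on disk under
#   game_data/essences/<QUALITY>/<safe_tag>_<score>_<YYYYmmdd_HHMMSS>.png
# where <safe_tag> is the tag with path separators and spaces replaced by underscores.
# The helpers below create that layout and reconcile it with st.session_state.generated_essences.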
def essence_folder_path():
    """Get the path to the essence folder, creating it if necessary"""
    base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    game_data_dir = os.path.join(base_dir, "game_data")
    essence_folder = os.path.join(game_data_dir, "essences")
    
    # Make sure all directories exist
    os.makedirs(game_data_dir, exist_ok=True)
    os.makedirs(essence_folder, exist_ok=True)
    
    return essence_folder

def display_saved_essences():
    """Display the user's saved essence images"""
    st.subheader("My Generated Essences")
    
    if not hasattr(st.session_state, 'generated_essences') or not st.session_state.generated_essences:
        st.info("You haven't generated any essences yet. Go to the Generate Essence tab to create some!")
        return
        
    # Add usage instructions at the top
    st.markdown("""
    ### How to Use Your Essences
    
    1. **Click on any essence image** to open it in full size
    2. **Save the image** to your computer (right-click → Save image)
    3. **Go to the Scan Images tab** and upload the saved essence
    4. The tagger will likely detect the original tag and potentially related rare tags!
    
    Higher quality essences (WAW, ALEPH) generally produce the best results.
    """)
    
    # Get the essence folder path
    essence_dir = essence_folder_path()
    
    # Try to locate any missing files
    for tag, info in st.session_state.generated_essences.items():
        if "path" in info and not os.path.exists(info["path"]):
            # Try to find the file in the essence directory
            quality = info.get("quality", "ZAYIN")
            quality_dir = os.path.join(essence_dir, quality)
            
            if os.path.exists(quality_dir):
                # Check for files with this tag name
                safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
                matching_files = [f for f in os.listdir(quality_dir) if f.startswith(safe_tag + '_')]
                
                if matching_files:
                    # Use the most recent file if there are multiple
                    matching_files.sort(reverse=True)
                    info["path"] = os.path.join(quality_dir, matching_files[0])
                    print(f"Reconnected essence for {tag} to {info['path']}")
    
    # List essences by quality level
    essences_by_quality = {}
    for tag, info in st.session_state.generated_essences.items():
        quality = info.get("quality", "ZAYIN")  # Default to lowest if not set
        if quality not in essences_by_quality:
            essences_by_quality[quality] = []
        essences_by_quality[quality].append((tag, info))
    
    # Check if any essences exist on disk but are not tracked in session state
    untracked_essences = {}
    try:
        for quality in ESSENCE_QUALITY_LEVELS.keys():
            quality_dir = os.path.join(essence_dir, quality)
            if os.path.exists(quality_dir):
                essence_files = os.listdir(quality_dir)
                
                # Filter to only show PNG files
                essence_files = [f for f in essence_files if f.lower().endswith('.png')]
                
                if essence_files:
                    # Check if any of these files aren't in our tracked essences
                    for filename in essence_files:
                        # Filenames follow "<safe_tag>_<score>_<YYYYmmdd>_<HHMMSS>.png", so split from
                        # the right to keep any underscores that belong to the tag name itself
                        parts = filename[:-4].rsplit('_', 3)
                        if len(parts) >= 2:
                            tag = parts[0].replace('_', ' ')
                            
                            # Check if file is already tracked
                            is_tracked = False
                            for tracked_tag, tracked_info in st.session_state.generated_essences.items():
                                if "path" in tracked_info and os.path.basename(tracked_info["path"]) == filename:
                                    is_tracked = True
                                    break
                            
                            if not is_tracked:
                                if quality not in untracked_essences:
                                    untracked_essences[quality] = []
                                untracked_essences[quality].append((tag, {
                                    "path": os.path.join(quality_dir, filename),
                                    "quality": quality,
                                    "discovered_on_disk": True
                                }))
    except Exception as e:
        print(f"Error checking for untracked essences: {e}")
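        # Non-fatal: fall back to whatever was collected before the error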
    
    # Combine tracked and untracked essences
    for quality, essences in untracked_essences.items():
        if quality not in essences_by_quality:
            essences_by_quality[quality] = []
        for tag, info in essences:
            # Only add if we don't already have this tag in this quality level
            if not any(tracked_tag == tag for tracked_tag, _ in essences_by_quality[quality]):
                essences_by_quality[quality].append((tag, info))
    
    # Show essences from highest to lowest quality
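    # (ESSENCE_QUALITY_LEVELS is assumed to be ordered lowest -> highest, so reversing puts ALEPH/WAW first)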
    for quality in list(ESSENCE_QUALITY_LEVELS.keys())[::-1]:
        if quality in essences_by_quality:
            essences = essences_by_quality[quality]
            color = ESSENCE_QUALITY_LEVELS[quality]["color"]
            
            with st.expander(f"{quality} Essences ({len(essences)})", expanded=quality in ["ALEPH", "WAW"]):
                # Create grid layout
                cols = st.columns(3)
                for i, (tag, info) in enumerate(sorted(essences, key=lambda x: x[1].get("score", 0), reverse=True)):
                    col_idx = i % 3
                    with cols[col_idx]:
                        try:
                            # Try to load the image from path
                            if "path" in info and os.path.exists(info["path"]):
                                image = Image.open(info["path"])
                                rarity = info.get("rarity", "Canard")
                                score = info.get("score", 0)
                                
                                # Get color for rarity
                                rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                                
                                # Display the image with metadata
                                st.image(image, caption=tag, use_container_width=True)
                                
                                # Build the rarity span, with special styling for the rarer tiers
                                if rarity == "Impuritas Civitas":
                                    rarity_span = f"<span style='animation: rainbow-text 4s linear infinite;font-weight:bold;'>{rarity}</span>"
                                elif rarity == "Star of the City":
                                    rarity_span = f"<span style='color:{rarity_color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity}</span>"
                                elif rarity == "Urban Nightmare":
                                    rarity_span = f"<span style='color:{rarity_color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity}</span>"
                                elif rarity == "Urban Plague":
                                    rarity_span = f"<span style='color:{rarity_color};text-shadow:0 0 1px #9C27B0;font-weight:bold;'>{rarity}</span>"
                                else:
                                    rarity_span = f"<span style='color:{rarity_color};font-weight:bold;'>{rarity}</span>"
                                
                                # Quality, rarity and score on a single line
                                st.markdown(f"""
                                <span style='color:{color};font-weight:bold;'>{quality}</span> | 
                                {rarity_span} | 
                                Score: {score:.2f}
                                """, unsafe_allow_html=True)
                                
                                # Add file info
                                if "discovered_on_disk" in info and info["discovered_on_disk"]:
                                    st.info("Found on disk (not in session state)")
                                
                                # Add button to open folder
                                if st.button("Open Folder", key=f"open_folder_{tag}_{quality}"):
                                    folder_path = os.path.dirname(info["path"])
                                    try:
                                        # Try different methods to open folder based on platform
                                        if os.name == 'nt':  # Windows
                                            os.startfile(folder_path)
                                        elif os.name == 'posix':  # macOS or Linux
                                            import subprocess
                                            import sys
                                            if sys.platform == 'darwin':  # macOS
                                                subprocess.call(['open', folder_path])
                                            else:  # Linux
                                                subprocess.call(['xdg-open', folder_path])
                                        st.success(f"Opened folder: {folder_path}")
                                    except Exception as e:
                                        st.error(f"Could not open folder: {str(e)}")
                                        # Provide the path for manual navigation
                                        st.code(folder_path)
                            else:
                                # Could not find image
                                st.warning(f"Image file not found: {info.get('path', 'No path available')}")
                                
                                # Show quality and tag name
                                st.markdown(f"""
                                <span style='color:{color};font-weight:bold;'>{quality}</span> | {tag}
                                """, unsafe_allow_html=True)
                                
                                # Only add reconnect button if we have some metadata
                                if "rarity" in info and "score" in info:
                                    if st.button("Reconnect File", key=f"reconnect_{tag}_{quality}"):
                                        # Update path in session state
                                        safe_tag = tag.replace('/', '_').replace('\\', '_').replace(' ', '_')
                                        score = info.get("score", 0)
                                        quality_dir = os.path.join(essence_dir, quality)
                                        
                                        # Create directory if it doesn't exist
                                        os.makedirs(quality_dir, exist_ok=True)
                                        
                                        # Set a path - user will need to manually add the image
                                        timestamp = time.strftime("%Y%m%d_%H%M%S")
                                        filename = f"{safe_tag}_{score:.2f}_{timestamp}.png"
                                        info["path"] = os.path.join(quality_dir, filename)
                                        
                                        st.info(f"Please save your image to this location: {info['path']}")
                                        st.session_state.generated_essences[tag] = info
                                        tag_storage.save_essence_state(session_state=st.session_state)
                                        st.rerun()
                        
                        except Exception as e:
                            st.write(f"Error loading {tag}: {str(e)}")
    
    # Add option to clean up missing files
    st.divider()
    if st.button("Clean Up Missing Files", help="Remove entries for essences where the file no longer exists"):
        # Find all entries with missing files
        to_remove = []
        for tag, info in st.session_state.generated_essences.items():
            if "path" in info and not os.path.exists(info["path"]):
                to_remove.append(tag)
        
        # Remove them
        for tag in to_remove:
            del st.session_state.generated_essences[tag]
        
        # Save state
        tag_storage.save_essence_state(session_state=st.session_state)
        
        if to_remove:
            st.success(f"Removed {len(to_remove)} entries with missing files")
        else:
            st.success("No missing files found")
        
        st.rerun()

def display_essence_generation_interface(model_available):
    """Display the interface for generating new essences"""
    # Initialize manual tags
    initialize_manual_tags()
    
    st.subheader("Generate Tag Essence")
    st.write("Select a tag to generate its essence. Higher quality essences can help unlock rare related tags when uploaded back into the tagger.")
    
    # Settings column
    col1, col2 = st.columns(2)
    
    with col1:
        st.write("Generation Settings:")
        
        # Add reset button
        if st.button("Reset to Defaults", help="Clear saved settings and use default values"):
            st.session_state.essence_custom_settings = DEFAULT_ESSENCE_SETTINGS.copy()
            tag_storage.save_essence_state(session_state=st.session_state)
            st.success("Settings reset to defaults!")
            st.rerun()
        
        # Advanced settings with better organization
        with st.expander("Advanced Settings", expanded=True):
            col_a, col_b = st.columns(2)
            
            with col_a:
                # Core generation parameters
                st.write("**Core Parameters**")
                iterations = st.slider(
                    "Iterations", 
                    min_value=64, 
                    max_value=2048, 
                    value=st.session_state.essence_custom_settings.get("iterations", 500), 
                    step=64,
                    help="More iterations improve quality but take longer"
                )
                
                lr = st.slider(
                    "Learning Rate", 
                    min_value=0.01, 
                    max_value=0.2, 
                    value=st.session_state.essence_custom_settings.get("lr", 0.05), 
                    step=0.01,
                    help="Higher learning rates converge faster but may be less stable"
                )
                
                ensemble_k = st.slider(
                    "Ensemble Size", 
                    min_value=1, 
                    max_value=16, 
                    value=st.session_state.essence_custom_settings.get("ensemble_k", 8), 
                    step=1,
                    help="Number of augmented versions per iteration. Higher = more stable but slower"
                )
            
            with col_b:
                # Multi-tag parameters
                st.write("**Multi-Tag Enhancement**")
                neighbor_count = st.slider(
                    "Neighbor Tags", 
                    min_value=0, 
                    max_value=16, 
                    value=st.session_state.essence_custom_settings.get("neighbor_count", 8), 
                    step=1,
                    help="Number of similar/dissimilar tags to consider. 0 = only target tag"
                )
                
                tv_weight = st.select_slider(
                    "Smoothness", 
                    options=[1e-4, 5e-4, 1e-3, 2e-3, 5e-3, 1e-2],
                    value=st.session_state.essence_custom_settings.get("tv_weight", 1e-3),
                    format_func=lambda x: f"{x:.0e}",
                    help="Higher values create smoother, less noisy images"
                )
                
                # Layer emphasis selection (restore the previously saved choice if there is one)
                layer_options = ["balanced", "early", "mid", "late"]
                saved_emphasis = st.session_state.essence_custom_settings.get("layer_emphasis", "balanced")
                layer_emphasis = st.selectbox(
                    "Feature Targeting", 
                    options=layer_options,
                    index=layer_options.index(saved_emphasis) if saved_emphasis in layer_options else 0,
                    format_func=lambda x: {
                        "balanced": "Balanced (mix of features)",
                        "early": "Early (textures, patterns)", 
                        "mid": "Mid (parts, components)",
                        "late": "Late (characters, objects)"
                    }.get(x, x),
                    help="Controls which model features to emphasize"
                )
        
        # Save settings
        st.session_state.essence_custom_settings = {
            "iterations": iterations,
            "lr": lr,
            "ensemble_k": ensemble_k,
            "neighbor_count": neighbor_count,
            "image_size": 512,  # Fixed for now
            "layer_emphasis": layer_emphasis,
            "tv_weight": tv_weight
        }
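        # These values are passed to generate_essence_for_tag() once a tag is selected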
        
        # Show current settings summary
        st.info(f"""
        **Current Settings:**
        - Iterations: {iterations}
        - Learning Rate: {lr}
        - Ensemble Size: {ensemble_k}
        - Neighbor Tags: {neighbor_count}
        - Smoothness: {tv_weight:.0e}
        - Feature Focus: {layer_emphasis.capitalize()}
        """)
    
    with col2:
        # Show quality level descriptions
        st.write("Quality Levels:")
        for level, info in ESSENCE_QUALITY_LEVELS.items():
            st.markdown(f"""
            <div style="padding:5px;margin-bottom:5px;border-radius:4px;background-color:rgba({int(info['color'][1:3], 16)},{int(info['color'][3:5], 16)},{int(info['color'][5:7], 16)},0.1);border-left:3px solid {info['color']}">
                <span style="color:{info['color']};font-weight:bold;">{level}</span> (Score {info['threshold']:.0f}+): {info['description']}
            </div>
            """, unsafe_allow_html=True)
        
        # Feature targeting explanation
        st.write("Feature Targeting Explanation:")
        st.markdown("""
        - **Early**: Textures, colors, simple patterns
        - **Mid**: Parts, components, intermediate features  
        - **Late**: Characters, objects, high-level concepts
        - **Balanced**: Mix of all feature levels
        """)
    
    # Show current Enkephalin
    st.markdown(f"### Your {ENKEPHALIN_CURRENCY_NAME}: **{st.session_state.enkephalin}** {ENKEPHALIN_ICON}")
    st.divider()
    
    # Add CSS for animations matching tag collection display
    st.markdown("""
    <style>
    @keyframes rainbow-text {
        0% { color: red; }
        14% { color: orange; }
        28% { color: yellow; }
        42% { color: green; }
        57% { color: blue; }
        71% { color: indigo; }
        85% { color: violet; }
        100% { color: red; }
    }
    
    .impuritas-text {
        font-weight: bold;
        animation: rainbow-text 4s linear infinite;
    }
    
    @keyframes glow-text {
        0% { text-shadow: 0 0 2px gold; }
        50% { text-shadow: 0 0 6px gold; }
        100% { text-shadow: 0 0 2px gold; }
    }
    
    .star-text {
        color: #FFEB3B;
        text-shadow: 0 0 3px gold;
        animation: glow-text 2s infinite;
        font-weight: bold;
    }
    
    @keyframes pulse-text {
        0% { opacity: 0.8; }
        50% { opacity: 1; }
        100% { opacity: 0.8; }
    }
    
    .nightmare-text {
        color: #FF9800;
        text-shadow: 0 0 1px #FF5722;
        animation: pulse-text 3s infinite;
        font-weight: bold;
    }
    
    .plague-text {
        color: #9C27B0;
        text-shadow: 0 0 1px #9C27B0;
        font-weight: bold;
    }
    
    .category-section {
        margin-top: 20px;
        margin-bottom: 30px;
        padding: 10px;
        border-radius: 5px;
        border-left: 5px solid;
    }
    </style>
    """, unsafe_allow_html=True)
    
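    # The text classes above (impuritas-text, star-text, nightmare-text, plague-text) style tag names
    # in the grids below; the grid-* classes used for the rarity counts are assumed to come from the
    # shared tag-collection CSS injected elsewhere in the app.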
    # Tag collection display (unchanged from original)
    # Gather all tags for essence generation
    all_tags = []
    
    # Process discovered tags
    if hasattr(st.session_state, 'discovered_tags'):
        for tag, info in st.session_state.discovered_tags.items():
            tag_info = {
                "tag": tag,
                "rarity": info.get("rarity", "Unknown"),
                "category": info.get("category", "unknown"),
                "source": "discovered",
                "library_floor": info.get("library_floor", ""),
                "discovery_time": info.get("discovery_time", "")
            }
            all_tags.append(tag_info)
    
    # Process manual tags
    if hasattr(st.session_state, 'manual_tags'):
        for tag, info in st.session_state.manual_tags.items():
            tag_info = {
                "tag": tag,
                "rarity": info.get("rarity", "Special"),
                "category": info.get("category", "special"),
                "source": "manual",
                "description": info.get("description", "")
            }
            all_tags.append(tag_info)
    
    # Count tags by rarity
    rarity_counts = {}
    for info in all_tags:
        rarity = info["rarity"]
        if rarity not in rarity_counts:
            rarity_counts[rarity] = 0
        rarity_counts[rarity] += 1
    
    # Display rarity counts at the top
    st.subheader("Available Tags for Essence Generation")
    st.write(f"You have {len(all_tags)} tags available for essence generation. Collect more from the library!")
    
    # Without any tags there is nothing to display or select
    if not all_tags:
        return None
    
    # Display rarity distribution
    rarity_cols = st.columns(len(rarity_counts))
    for i, (rarity, count) in enumerate(sorted(rarity_counts.items(), 
                                      key=lambda x: list(RARITY_LEVELS.keys()).index(x[0]) if x[0] in RARITY_LEVELS else 999)):
        with rarity_cols[i]:
            # Get color with fallback
            color = RARITY_LEVELS.get(rarity, {}).get("color", "#888888")
            
            # Apply special styling based on rarity
            style = f"color:{color};font-weight:bold;"
            class_name = ""
            
            if rarity == "Impuritas Civitas":
                class_name = "grid-impuritas"
            elif rarity == "Star of the City":
                class_name = "grid-star"
            elif rarity == "Urban Nightmare":
                class_name = "grid-nightmare"
            elif rarity == "Urban Plague":
                class_name = "grid-plague"
            
            if class_name:
                st.markdown(
                    f"<div style='text-align:center;'><span class='{class_name}' style='font-weight:bold;'>{rarity.capitalize()}</span><br>{count}</div>",
                    unsafe_allow_html=True
                )
            else:
                st.markdown(
                    f"<div style='text-align:center;'><span style='{style}'>{rarity.capitalize()}</span><br>{count}</div>",
                    unsafe_allow_html=True
                )
    
    # Search box for all tags
    search_term = st.text_input("Search tags", "", key="essence_search_tags")
    
    # Sort options
    sort_options = ["Category (rarest first)", "Rarity", "Discovery Time"]
    selected_sort = st.selectbox("Sort tags by:", sort_options, key="essence_tags_sort")
    
    # Filter tags by search term if provided
    if search_term:
        all_tags = [info for info in all_tags if search_term.lower() in info["tag"].lower()]
    
    selected_tag = None
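    # Any "Generate" button below sets selected_tag; the caller stores it in session state and reruns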
    
    # Sort and group tags based on selection (rest of the display logic unchanged)
    if selected_sort == "Category (rarest first)":
        # Group tags by category
        categories = {}
        for info in all_tags:
            category = info["category"]
            if category not in categories:
                categories[category] = []
            categories[category].append(info)
        
        # Display tags by category in expanders
        for category, tags in sorted(categories.items()):
            # Sort tags by rarity (rarest first); RARITY_LEVELS is ordered common -> rare
            rarity_order = list(RARITY_LEVELS.keys())
            
            def get_rarity_index(info):
                rarity = info["rarity"]
                return rarity_order.index(rarity) if rarity in rarity_order else -1
            
            sorted_tags = sorted(tags, key=get_rarity_index, reverse=True)
            
            # Check if category has any rare tags
            has_rare_tags = any(info["rarity"] in ["Impuritas Civitas", "Star of the City"] 
                               for info in sorted_tags)
            
            # Get category info if available
            category_display = category.capitalize()
            if category in TAG_CATEGORIES:
                category_info = TAG_CATEGORIES[category]
                icon = category_info.get("icon", "")
                color = category_info.get("color", "#888888")
                category_display = f"<span style='color:{color};'>{icon} {category.capitalize()}</span>"
            
            # Create header with information about rare tags if present
            header = f"{category_display} ({len(tags)} tags)"
            if has_rare_tags:
                header += " ✨ Contains rare tags!"
                
            # Display category header and expander
            st.markdown(header, unsafe_allow_html=True)
            with st.expander("Show/Hide", expanded=has_rare_tags):
                # Create grid layout for tags
                cols = st.columns(3)
                for i, info in enumerate(sorted_tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]
                        source = info["source"]
                        
                        # Get rarity color
                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                        
                        # Check if this tag has an essence already
                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                        
                        # Get cost for this tag
                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost
                        
                        # Format tag display with special styling
                        if rarity == "Impuritas Civitas":
                            tag_display = f'<span class="impuritas-text">{tag}</span>'
                        elif rarity == "Star of the City":
                            tag_display = f'<span class="star-text">{tag}</span>'
                        elif rarity == "Urban Nightmare":
                            tag_display = f'<span class="nightmare-text">{tag}</span>'
                        elif rarity == "Urban Plague":
                            tag_display = f'<span class="plague-text">{tag}</span>'
                        else:
                            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'
                        
                        # Show tag with rarity badge and cost
                        st.markdown(
                            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})',
                            unsafe_allow_html=True
                        )
                        
                        # Show discovery details if available
                        if source == "discovered" and "library_floor" in info and info["library_floor"]:
                            st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', 
                                      unsafe_allow_html=True)
                        elif source == "manual" and "description" in info and info["description"]:
                            st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', 
                                      unsafe_allow_html=True)
                        
                        # Add generation button
                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford):
                            selected_tag = tag
                            
    elif selected_sort == "Rarity":
        # Group tags by rarity
        rarity_groups = {}
        for info in all_tags:
            rarity = info["rarity"]
            if rarity not in rarity_groups:
                rarity_groups[rarity] = []
            rarity_groups[rarity].append(info)
        
        # Get ordered rarities (rarest first)
        ordered_rarities = list(RARITY_LEVELS.keys())
        ordered_rarities.reverse()  # Reverse to show rarest first
        
        # Add any rarities not in RARITY_LEVELS
        for rarity in rarity_groups.keys():
            if rarity not in ordered_rarities:
                ordered_rarities.append(rarity)
        
        # Display tags by rarity
        for rarity in ordered_rarities:
            if rarity in rarity_groups:
                tags = rarity_groups[rarity]
                color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                
                # Add special styling for rare rarities
                rarity_html = f"<span style='color:{color};font-weight:bold;'>{rarity.capitalize()}</span>"
                if rarity == "Impuritas Civitas":
                    rarity_html = f"<span style='animation:rainbow-text 4s linear infinite;font-weight:bold;'>{rarity.capitalize()}</span>"
                elif rarity == "Star of the City":
                    rarity_html = f"<span style='color:{color};text-shadow:0 0 3px gold;font-weight:bold;'>{rarity.capitalize()}</span>"
                elif rarity == "Urban Nightmare":
                    rarity_html = f"<span style='color:{color};text-shadow:0 0 1px #FF5722;font-weight:bold;'>{rarity.capitalize()}</span>"
                
                # First create the title with HTML, then use it in the expander
                st.markdown(f"### {rarity_html} ({len(tags)} tags)", unsafe_allow_html=True)
                with st.expander("Show/Hide", expanded=rarity in ["Impuritas Civitas", "Star of the City"]):
                    # Create grid layout for tags
                    cols = st.columns(3)
                    for i, info in enumerate(sorted(tags, key=lambda x: x["tag"])):
                        with cols[i % 3]:
                            tag = info["tag"]
                            source = info["source"]
                            
                            # Check if this tag has an essence already
                            has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                            
                            # Get cost for this tag
                            cost = get_essence_cost(rarity)
                            can_afford = st.session_state.enkephalin >= cost
                            
                            # Show tag with cost
                            st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")
                            
                            # Show discovery details if available
                            if source == "discovered" and "library_floor" in info and info["library_floor"]:
                                st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', 
                                          unsafe_allow_html=True)
                            elif source == "manual" and "description" in info and info["description"]:
                                st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', 
                                          unsafe_allow_html=True)
                            
                            # Add generation button
                            button_label = "Generate" if not has_essence else "Regenerate ✓"
                            if st.button(button_label, key=f"gen_{tag}_{source}", disabled=not model_available or not can_afford):
                                selected_tag = tag
                            
    elif selected_sort == "Discovery Time":
        # Filter to just discovered tags (manual tags don't have discovery time)
        discovered_tags = [info for info in all_tags if info["source"] == "discovered" and "discovery_time" in info]
        
        # Sort all tags by discovery time (newest first)
        sorted_tags = sorted(discovered_tags, key=lambda x: x["discovery_time"], reverse=True)
        
        # Group by date
        date_groups = {}
        for info in sorted_tags:
            time_str = info["discovery_time"]
            # Extract just the date part if timestamp has date and time
            date = time_str.split()[0] if " " in time_str else time_str
            
            if date not in date_groups:
                date_groups[date] = []
            date_groups[date].append(info)
        
        # Display tags grouped by discovery date
        for date, tags in date_groups.items():
            date_display = date if date else "Unknown date"
            st.markdown(f"### Discovered on {date_display} ({len(tags)} tags)")
            
            with st.expander("Show/Hide", expanded=date == list(date_groups.keys())[0]):  # Expand most recent by default
                # Create grid layout for tags
                cols = st.columns(3)
                for i, info in enumerate(tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]
                        
                        # Get rarity color
                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                        
                        # Check if this tag has an essence already
                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                        
                        # Get cost for this tag
                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost
                        
                        # Format tag display with special styling
                        if rarity == "Impuritas Civitas":
                            tag_display = f'<span class="impuritas-text">{tag}</span>'
                        elif rarity == "Star of the City":
                            tag_display = f'<span class="star-text">{tag}</span>'
                        elif rarity == "Urban Nightmare":
                            tag_display = f'<span class="nightmare-text">{tag}</span>'
                        elif rarity == "Urban Plague":
                            tag_display = f'<span class="plague-text">{tag}</span>'
                        else:
                            tag_display = f'<span style="color:{rarity_color};font-weight:bold;">{tag}</span>'
                        
                        # Show tag with rarity badge and cost
                        st.markdown(
                            f'{tag_display} <span style="background-color:{rarity_color};color:white;padding:2px 6px;border-radius:10px;font-size:0.8em;">{rarity.capitalize()}</span> ({cost} {ENKEPHALIN_ICON})',
                            unsafe_allow_html=True
                        )
                        
                        # Show discovery details
                        if "library_floor" in info and info["library_floor"]:
                            st.markdown(f'<span style="font-size:0.85em;">Found in: {info["library_floor"]}</span>', 
                                      unsafe_allow_html=True)
                        
                        # Add generation button
                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_disc", disabled=not model_available or not can_afford):
                            selected_tag = tag
        
        # Show manual tags separately if we have any
        manual_tags = [info for info in all_tags if info["source"] == "manual"]
        if manual_tags:
            st.markdown("### Manual Tags")
            with st.expander("Show/Hide"):
                # Create grid layout for tags
                cols = st.columns(3)
                for i, info in enumerate(manual_tags):
                    with cols[i % 3]:
                        tag = info["tag"]
                        rarity = info["rarity"]
                        
                        # Get rarity color
                        rarity_color = RARITY_LEVELS.get(rarity, {}).get("color", "#AAAAAA")
                        
                        # Check if this tag has an essence already
                        has_essence = hasattr(st.session_state, 'generated_essences') and tag in st.session_state.generated_essences
                        
                        # Get cost for this tag
                        cost = get_essence_cost(rarity)
                        can_afford = st.session_state.enkephalin >= cost
                        
                        # Show tag with rarity badge and cost
                        st.markdown(f"**{tag}** ({cost} {ENKEPHALIN_ICON})")
                        
                        # Show description if available
                        if "description" in info and info["description"]:
                            st.markdown(f'<span style="font-size:0.85em;font-style:italic;">{info["description"]}</span>', 
                                      unsafe_allow_html=True)
                        
                        # Add generation button
                        button_label = "Generate" if not has_essence else "Regenerate ✓"
                        if st.button(button_label, key=f"gen_{tag}_manual", disabled=not model_available or not can_afford):
                            selected_tag = tag
    
    return selected_tag