# Acknowledgement: This demo code is adapted from the original Hugging Face Space "ContextCite"
# (https://huggingface.co/spaces/contextcite/context-cite).
import os
from enum import Enum
from dataclasses import dataclass
from typing import Dict, List, Any, Optional
import gradio as gr
import numpy as np
import spaces
import nltk
import base64
import traceback
from src.utils import split_into_sentences as split_into_sentences_utils
# --- AttnTrace imports (from app_full.py) ---
from src.models import create_model
from src.attribution import AttnTraceAttribution
from src.prompts import wrap_prompt
from gradio_highlightedtextbox import HighlightedTextbox
from examples import run_example_1, run_example_2, run_example_3, run_example_4, run_example_5, run_example_6
from functools import partial


from nltk.tokenize import sent_tokenize


# Load original app constants
APP_TITLE = '<div class="app-title"><span class="brand">AttnTrace: </span><span class="subtitle">Attention-based Context Traceback for Long-Context LLMs</span></div>'
APP_DESCRIPTION = """AttnTrace traces a model's generated statements back to specific parts of the context using attention-based traceback. Try it out with Meta-Llama-3.1-8B-Instruct here! See the [[paper](https://arxiv.org/abs/2508.03793)] and [[code](https://github.com/Wang-Yanting/AttnTrace)] for more!
Maintained by the AttnTrace team."""
# NEW_TEXT = """Long-context large language models (LLMs), such as Gemini-2.5-Pro and Claude-Sonnet-4, are increasingly used to empower advanced AI systems, including retrieval-augmented generation (RAG) pipelines and autonomous agents. In these systems, an LLM receives an instruction along with a contextβ€”often consisting of texts retrieved from a knowledge database or memoryβ€”and generates a response that is contextually grounded by following the instruction. Recent studies have designed solutions to trace back to a subset of texts in the context that contributes most to the response generated by the LLM. These solutions have numerous real-world applications, including performing post-attack forensic analysis and improving the interpretability and trustworthiness of LLM outputs. While significant efforts have been made, state-of-the-art solutions such as TracLLM often lead to a high computation cost, e.g., it takes TracLLM hundreds of seconds to perform traceback for a single response-context pair. In this work, we propose {\name}, a new context traceback method based on the attention weights produced by an LLM for a prompt. To effectively utilize attention weights, we introduce two techniques designed to enhance the effectiveness of {\name}, and we provide theoretical insights for our design choice. %Moreover, we perform both theoretical analysis and empirical evaluation to demonstrate their effectiveness. 
# We also perform a systematic evaluation for {\name}. The results demonstrate that {\name} is more accurate and efficient than existing state-of-the-art context traceback methods. We also show {\name} can improve state-of-the-art methods in detecting prompt injection under long contexts through the attribution-before-detection paradigm. As a real-world application, we demonstrate that {\name} can effectively pinpoint injected instructions in a paper designed to manipulate LLM-generated reviews.
# The code and data will be open-sourced. """
# EDIT_TEXT = "Feel free to edit!"
GENERATE_CONTEXT_TOO_LONG_TEXT = (
    '<em style="color: red;">Context is too long for the current model.</em>'
)
ATTRIBUTE_CONTEXT_TOO_LONG_TEXT = '<em style="color: red;">Context is too long for the current traceback method.</em>'
CONTEXT_LINES = 20
CONTEXT_MAX_LINES = 40
SELECTION_DEFAULT_TEXT = "Click on a sentence in the response to traceback!"
SELECTION_DEFAULT_VALUE = [(SELECTION_DEFAULT_TEXT, None)]
SOURCES_INFO = 'These are the texts that contribute most to the response.'
# SOURCES_IN_CONTEXT_INFO = (
#     "This shows the important sentences highlighted within their surrounding context from the text above. Colors indicate ranking: Red (1st), Orange (2nd), Golden (3rd), Yellow (4th-5th), Light (6th+)."
# )

MODEL_PATHS = [
    "meta-llama/Meta-Llama-3.1-8B-Instruct",
]
MAX_TOKENS = {
    "meta-llama/Meta-Llama-3.1-8B-Instruct": 131072,
}
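# Note: 131072 = 128 * 1024, i.e., Llama 3.1's advertised 128K-token context window.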
DEFAULT_MODEL_PATH = MODEL_PATHS[0]
EXPLANATION_LEVELS = ["sentence", "paragraph", "text segment"]
DEFAULT_EXPLANATION_LEVEL = "sentence"

class WorkflowState(Enum):
    WAITING_TO_GENERATE = 0
    WAITING_TO_SELECT = 1
    READY_TO_ATTRIBUTE = 2

@dataclass
class State:
    workflow_state: WorkflowState
    context: str
    query: str
    response: str
    start_index: int
    end_index: int
    scores: np.ndarray
    answer: str
    highlighted_context: str
    full_response: str
    explained_response_part: str
    last_query_used: str = ""

# --- Dynamic Model and Attribution Management ---
current_llm = None
current_attr = None
current_model_path = None
current_explanation_level = None
current_api_key = None
current_top_k = 3  # top-k setting (number of segments to highlight)
current_B = 30  # B parameter (number of subsamples)
current_q = 0.4  # q parameter

def update_configuration(explanation_level, top_k, B, q):
    """Update the global configuration and reinitialize attribution if needed"""
    global current_explanation_level, current_top_k, current_B, current_q, current_attr, current_llm
    
    # Convert parameters to appropriate types
    top_k = int(top_k)
    B = int(B)
    q = float(q)
    
    # Check if configuration has changed
    config_changed = (current_explanation_level != explanation_level or 
                     current_top_k != top_k or
                     current_B != B or
                     current_q != q)
    
    if config_changed:
        print(f"πŸ”„ Updating configuration: explanation_level={explanation_level}, top_k={top_k}, B={B}, q={q}")
        current_explanation_level = explanation_level
        current_top_k = top_k
        current_B = B
        current_q = q
        
        # Reset both model and attribution to force complete reinitialization
        current_llm = None
        current_attr = None
        
        # Reinitialize with new configuration
        try:
            llm, attr, error_msg = initialize_model_and_attr()
            if llm is not None and attr is not None:
                return gr.update(value=f"βœ… Configuration updated: {explanation_level} level, top-{top_k}, B={B}, q={q}")
            else:
                return gr.update(value=f"❌ Error reinitializing: {error_msg}")
        except Exception as e:
            return gr.update(value=f"❌ Error updating configuration: {str(e)}")
    else:
        return gr.update(value="ℹ️ Configuration unchanged")

def initialize_model_and_attr():
    """Initialize model and attribution with default configuration"""
    global current_llm, current_attr, current_model_path, current_explanation_level, current_api_key, current_top_k, current_B, current_q
    
    try:
        # Check if we need to reinitialize the model
        need_model_update = (current_llm is None or 
                           current_model_path != DEFAULT_MODEL_PATH or 
                           current_api_key != os.getenv("HF_TOKEN"))
        
        # Check if we need to update attribution
        need_attr_update = (current_attr is None or
                            current_explanation_level is None or
                            need_model_update)
        
        if need_model_update:
            print(f"Initializing model: {DEFAULT_MODEL_PATH}")
            effective_api_key = os.getenv("HF_TOKEN")
            current_llm = create_model(model_path=DEFAULT_MODEL_PATH, api_key=effective_api_key, device="cuda")
            current_model_path = DEFAULT_MODEL_PATH
            current_api_key = effective_api_key
            
        if need_attr_update:
            # Use current configuration or defaults
            explanation_level = current_explanation_level or DEFAULT_EXPLANATION_LEVEL
            top_k = current_top_k or 3
            B = current_B or 30
            q = current_q or 0.4
            if "segment" in explanation_level:
                explanation_level = "segment"
            print(f"Initializing context traceback with explanation level: {explanation_level}, top_k: {top_k}, B: {B}, q: {q}")
            current_attr = AttnTraceAttribution(
                current_llm, 
                explanation_level=explanation_level,
                K=top_k, 
                q=q, 
                B=B
            )
            current_explanation_level = explanation_level
            
        return current_llm, current_attr, None
        
    except Exception as e:
        error_msg = f"Error initializing model/traceback: {str(e)}"
        print(error_msg)
        traceback.print_exc()
        return None, None, error_msg

# Lazy initialization: the model is created on first use inside the @spaces.GPU
# entry points, avoiding CUDA initialization on the main thread at import time.
# llm, attr, error_msg = initialize_model_and_attr()

# Images replaced with CSS textures and gradients - no longer needed

def clear_state():
    return State(
        workflow_state=WorkflowState.WAITING_TO_GENERATE,
        context="",
        query="",
        response="",
        start_index=0,
        end_index=0,
        scores=np.array([]),
        answer="",
        highlighted_context="",
        full_response="",
        explained_response_part="",
        last_query_used=""
    )

def load_an_example(example_loader_func, state: State):
    context, query = example_loader_func()
    # Update both UI and state
    state.context = context
    state.query = query
    state.workflow_state = WorkflowState.WAITING_TO_GENERATE
    # Clear previous results
    state.response = ""
    state.answer = ""
    state.full_response = ""
    state.explained_response_part = ""
    print(f"Loaded example - Context: {len(context)} chars, Query: {query[:50]}...")
    return (
        context,  # basic_context_box
        query,    # basic_query_box
        state, 
        "",       # response_input_box - clear it
        gr.update(value=[("Click the 'Generate/Use Response' button above to see response text here for traceback analysis.", None)]),  # basic_response_box - keep visible
        gr.update(selected=0)  # basic_context_tabs - switch to first tab
    )


def get_max_tokens(model_path: str):
    return MAX_TOKENS.get(model_path, 2048)  # Default fallback


def get_scroll_js_code(elem_id):
    return f"""
    function scrollToElement() {{
        const element = document.getElementById("{elem_id}");
        element.scrollIntoView({{ behavior: "smooth", block: "nearest" }});
    }}
    """

def basic_update(context: str, query: str, state: State):
    state.context = context
    state.query = query
    state.workflow_state = WorkflowState.WAITING_TO_GENERATE
    return (
        gr.update(value=[("Click the 'Generate/Use Response' button above to see response text here for traceback analysis.", None)]),  # basic_response_box - keep visible
        gr.update(selected=0),  # basic_context_tabs - switch to first tab
        state,
    )


@spaces.GPU
def generate_model_response(state: State):
    # Validate inputs first with debug info
    print(f"Validation - Context length: {len(state.context) if state.context else 0}")
    print(f"Validation - Query: {state.query[:50] if state.query else 'empty'}...")
    
    if not state.context or not state.context.strip():
        print("❌ Validation failed: No context")
        return state, gr.update(value=[("❌ Please enter context before generating response! If you just changed configuration, try reloading an example.", None)], visible=True)
    
    if not state.query or not state.query.strip():
        print("❌ Validation failed: No query")
        return state, gr.update(value=[("❌ Please enter a query before generating response! If you just changed configuration, try reloading an example.", None)], visible=True)
    
    # Initialize model and attribution with current configuration
    print(f"πŸ”§ Generating response with explanation_level: {current_explanation_level or DEFAULT_EXPLANATION_LEVEL}, top_k: {current_top_k or 3}")
    llm, attr, error_msg = initialize_model_and_attr()
    
    if llm is None or attr is None:
        error_text = error_msg if error_msg else "Model initialization failed!"
        return state, gr.update(value=[(f"❌ {error_text}", None)], visible=True)
    
    prompt = wrap_prompt(state.query, [state.context])
    print(f"Generated prompt for {DEFAULT_MODEL_PATH}: {prompt[:200]}...")  # Debug log
    
    # Rough context-length check: whitespace word count as a cheap proxy for token count
    if len(prompt.split()) > get_max_tokens(DEFAULT_MODEL_PATH) - 512:
        return state, gr.update(value=[(GENERATE_CONTEXT_TOO_LONG_TEXT, None)], visible=True)
    
    answer = llm.query(prompt)
    print(f"Model response: {answer}")  # Debug log
    
    state.response = answer
    state.answer = answer
    state.full_response = answer
    state.workflow_state = WorkflowState.WAITING_TO_SELECT
    return state, gr.update(visible=False)

def split_into_sentences(text: str):
    """Split text into sentences, returning (sentences, separators) such that
    concatenating separator + sentence pairs in order reproduces the text."""
    def rule_based_split(text):
        sentences = []
        start = 0
        for i, char in enumerate(text):
            if char in ".?。":
                if i + 1 == len(text) or text[i + 1] == " ":
                    sentences.append(text[start:i + 1].strip())
                    start = i + 1
        if start < len(text):
            sentences.append(text[start:].strip())
        return sentences
        
    lines = text.splitlines()
    sentences = []
    for line in lines:
        #sentences.extend(sent_tokenize(line))
        sentences.extend(rule_based_split(line))
    separators = []
    cur_start = 0
    for sentence in sentences:
        cur_end = text.find(sentence, cur_start)
        separators.append(text[cur_start:cur_end])
        cur_start = cur_end + len(sentence)
    return sentences, separators
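
# A minimal worked example of the splitter above (illustrative input, not part
# of the app flow): separators capture the text between consecutive sentences,
# so "".join(sep + sent for sep, sent in zip(separators, sentences)) rebuilds
# the original string.
#   >>> split_into_sentences("Hello world. How are you?")
#   (['Hello world.', 'How are you?'], ['', ' '])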


def basic_highlight_response(
    response: str, selected_index: int, num_sources: int = -1
):
    sentences, separators = split_into_sentences(response)
    ht = []
    if num_sources == -1:
        citations_text = "Traceback!"
    elif num_sources == 0:
        citations_text = "No important text!"
    else:
        citations_text = f"[{','.join(str(i) for i in range(1, num_sources + 1))}]"
    for i, (sentence, separator) in enumerate(zip(sentences, separators)):
        label = citations_text if i == selected_index else "Traceback"
        # Skip very short fragments (e.g., stray punctuation) so they get no label
        if len(sentence) >= 4:
            ht.append((separator + sentence, label))
        else:
            ht.append((separator + sentence, None))
    color_map = {"Traceback": "blue", citations_text: "yellow"}
    return gr.HighlightedText(value=ht, color_map=color_map)
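
# Sketch of the tuples produced above, assuming the response "A result. Done."
# with selected_index=0 and num_sources=-1 (values are illustrative):
#   [("A result.", "Traceback!"), (" Done.", "Traceback")]
# Fragments shorter than 4 characters would receive a None label instead.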

def basic_highlight_response_with_visibility(
    response: str, selected_index: int, num_sources: int = -1, visible: bool = True
):
    """Version of basic_highlight_response that also sets visibility"""
    sentences, separators = split_into_sentences(response)
    ht = []
    if num_sources == -1:
        citations_text = "Traceback!"
    elif num_sources == 0:
        citations_text = "No important text!"
    else:
        citations_text = f"[{','.join(str(i) for i in range(1, num_sources + 1))}]"
    for i, (sentence, separator) in enumerate(zip(sentences, separators)):
        label = citations_text if i == selected_index else "Traceback"
        # Skip very short fragments (e.g., stray punctuation) so they get no label
        if len(sentence) >= 4:
            ht.append((separator + sentence, label))
        else:
            ht.append((separator + sentence, None))
    color_map = {"Traceback": "blue", citations_text: "yellow"}
    return gr.update(value=ht, color_map=color_map, visible=visible)



def basic_update_highlighted_response(evt: gr.SelectData, state: State):
    response_update = basic_highlight_response(state.response, evt.index)
    return response_update, state

def unified_response_handler(response_text: str, state: State):
    """Handle both LLM generation and manual input based on whether text is provided"""
    
    # Check if instruction has changed from what was used to generate current response
    instruction_changed = hasattr(state, 'last_query_used') and state.last_query_used != state.query
    
    # If response_text is empty, whitespace, or instruction changed, generate from LLM
    if not response_text or not response_text.strip() or instruction_changed:
        if instruction_changed:
            print("πŸ“ Instruction changed, generating new response from LLM...")
        else:
            print("πŸ€– Generating response from LLM...")
        
        # Validate inputs first
        if not state.context or not state.context.strip():
            return (
                state, 
                response_text,  # Keep current text box content
                gr.update(visible=False),  # Keep response box hidden
                gr.update(value=[("❌ Please enter context before generating response!", None)], visible=True)
            )
        
        if not state.query or not state.query.strip():
            return (
                state, 
                response_text,  # Keep current text box content
                gr.update(visible=False),  # Keep response box hidden
                gr.update(value=[("❌ Please enter a query before generating response!", None)], visible=True)
            )
        
        # Initialize model and generate response
        llm, attr, error_msg = initialize_model_and_attr()
        
        if llm is None:
            error_text = error_msg if error_msg else "Model initialization failed!"
            return (
                state, 
                response_text,  # Keep current text box content
                gr.update(visible=False),  # Keep response box hidden
                gr.update(value=[(f"❌ {error_text}", None)], visible=True)
            )
        
        prompt = wrap_prompt(state.query, [state.context])
        
        # Rough context-length check: whitespace word count as a cheap proxy for token count
        if len(prompt.split()) > get_max_tokens(DEFAULT_MODEL_PATH) - 512:
            return (
                state, 
                response_text,  # Keep current text box content
                gr.update(visible=False),  # Keep response box hidden
                gr.update(value=[(GENERATE_CONTEXT_TOO_LONG_TEXT, None)], visible=True)
            )
        
        # Generate response
        answer = llm.query(prompt)
        print(f"Generated response: {answer[:100]}...")
        
        # Update state and UI
        state.response = answer
        state.answer = answer
        state.full_response = answer
        state.last_query_used = state.query  # Track which query was used for this response
        state.workflow_state = WorkflowState.WAITING_TO_SELECT
        
        # Create highlighted response and show it
        response_update = basic_highlight_response_with_visibility(state.response, -1, visible=True)
        
        return (
            state,
            answer,  # Put generated response in text box
            response_update,  # Update clickable response content
            gr.update(visible=False)  # Hide error box
        )
    
    else:
        # Use provided text as manual response
        print("✏️ Using manual response...")
        manual_text = response_text.strip()
        
        # Update state with manual response
        state.response = manual_text
        state.answer = manual_text
        state.full_response = manual_text
        state.last_query_used = state.query  # Track current query for this response
        state.workflow_state = WorkflowState.WAITING_TO_SELECT
        
        # Create highlighted response for selection
        response_update = basic_highlight_response_with_visibility(state.response, -1, visible=True)
        
        return (
            state,
            manual_text,  # Keep text in text box
            response_update,  # Update clickable response content
            gr.update(visible=False)  # Hide error box
        )
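
# A hedged sketch of how this handler might be bound to the button (the actual
# event wiring is defined elsewhere in this file; names below mirror the
# components created in the Blocks layout, and the outputs order matches the
# 4-tuple returned above):
#   unified_response_button.click(
#       unified_response_handler,
#       inputs=[response_input_box, state],
#       outputs=[state, response_input_box, basic_response_box, basic_generate_error_box],
#   )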

def get_color_by_rank(rank, total_items):
    """Get color based purely on rank position for better visual distinction"""
    if total_items == 0:
        return "#F0F0F0", "rgba(240, 240, 240, 0.8)"
    
    # Pure ranking-based color assignment for clear visual hierarchy
    if rank == 1:  # Highest importance - Strong Red
        bg_color = "#FF4444"  # Bright red
        rgba_color = "rgba(255, 68, 68, 0.9)"
    elif rank == 2:  # Second highest - Orange  
        bg_color = "#FF8C42"  # Bright orange
        rgba_color = "rgba(255, 140, 66, 0.8)"
    elif rank == 3:  # Third highest - Golden Yellow
        bg_color = "#FFD93D"  # Golden yellow
        rgba_color = "rgba(255, 217, 61, 0.8)"
    elif rank <= 5:  # 4th-5th - Light Yellow
        bg_color = "#FFF280"  # Standard yellow
        rgba_color = "rgba(255, 242, 128, 0.7)"
    else:  # Lower importance - Very Light Yellow
        bg_color = "#FFF9C4"  # Very light yellow
        rgba_color = "rgba(255, 249, 196, 0.6)"
    
    return bg_color, rgba_color
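
# Reference outputs for the rank-to-color mapping above (illustrative calls):
#   >>> get_color_by_rank(1, 5)
#   ('#FF4444', 'rgba(255, 68, 68, 0.9)')
#   >>> get_color_by_rank(4, 5)
#   ('#FFF280', 'rgba(255, 242, 128, 0.7)')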

@spaces.GPU
def basic_get_scores_and_sources_full_response(state: State):
    """Traceback the entire response instead of a selected segment"""
    
    
    # Use the entire response as the explained part
    state.explained_response_part = state.full_response

    # Attribution using default configuration
    llm, attr, error_msg = initialize_model_and_attr()
    
    if attr is None:
        error_text = error_msg if error_msg else "Traceback initialization failed!"
        return (
            gr.update(value=[("", None)], visible=False),
            gr.update(selected=0),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=[(f"❌ {error_text}", None)], visible=True),
            state,
        )
    try:
        # Validate attribution inputs
        if not state.context or not state.context.strip():
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No context available for traceback!", None)], visible=True),
                state,
            )
            
        if not state.query or not state.query.strip():
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No query available for traceback!", None)], visible=True),
                state,
            )
            
        if not state.full_response or not state.full_response.strip():
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No response available for traceback!", None)], visible=True),
                state,
            )
        
        print(f"start full response traceback with explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
        print(f"context length: {len(state.context)}, query: {state.query[:100]}...")
        print(f"full response: {state.full_response[:100]}...")
        print(f"tracing entire response (length: {len(state.full_response)} chars)")
        
        texts, important_ids, importance_scores, _, _ = attr.attribute(
            state.query, [state.context], state.full_response, state.full_response
        )
        print("end full response traceback")
        print(f"explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
        print(f"texts count: {len(texts)} (how context was segmented)")
        if len(texts) > 0:
            print(f"sample text segments: {[text[:50] + '...' if len(text) > 50 else text for text in texts[:3]]}")
        print(f"important_ids: {important_ids}")
        print("importance_scores: ", importance_scores)
        
        if not importance_scores:
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No traceback scores generated for full response!", None)], visible=True),
                state,
            )
            
        state.scores = np.array(importance_scores)
        
        # Highlighted sources with ranking-based colors
        highlighted_text = []
        sorted_indices = np.argsort(state.scores)[::-1]
        total_sources = len(important_ids)
        
        for rank, i in enumerate(sorted_indices):
            source_text = texts[important_ids[i]]
            _ = get_color_by_rank(rank + 1, total_sources)
            
            highlighted_text.append(
                (
                    source_text,
                    f"rank_{rank+1}",
                )
            )
        
        # In-context highlights with ranking-based colors - show ALL text
        in_context_highlighted_text = []
        ranks = {important_ids[i]: rank for rank, i in enumerate(sorted_indices)}
        
        for i in range(len(texts)):
            source_text = texts[i]
            
            # Skip or don't highlight segments that are only newlines or whitespace
            if source_text.strip() == "":
                # For whitespace-only segments, add them without highlighting
                in_context_highlighted_text.append((source_text, None))
            elif i in important_ids:
                # Only highlight if the segment has actual content (not just newlines)
                if source_text.strip():  # Has non-whitespace content
                    rank = ranks[i] + 1
                    
                    # Split the segment to separate leading/trailing newlines from content
                    # This prevents newlines from being highlighted
                    leading_whitespace = ""
                    trailing_whitespace = ""
                    content = source_text
                    
                    # Extract leading newlines/whitespace
                    while content and content[0] in ['\n', '\r', '\t', ' ']:
                        leading_whitespace += content[0]
                        content = content[1:]
                    
                    # Extract trailing newlines/whitespace
                    while content and content[-1] in ['\n', '\r', '\t', ' ']:
                        trailing_whitespace = content[-1] + trailing_whitespace
                        content = content[:-1]
                    
                    # Add the parts separately: whitespace unhighlighted, content highlighted
                    if leading_whitespace:
                        in_context_highlighted_text.append((leading_whitespace, None))
                    if content:
                        in_context_highlighted_text.append((content, f"rank_{rank}"))
                    if trailing_whitespace:
                        in_context_highlighted_text.append((trailing_whitespace, None))
                else:
                    # Even if marked as important, don't highlight whitespace-only segments
                    in_context_highlighted_text.append((source_text, None))
            else:
                # Add unhighlighted text for non-important segments
                in_context_highlighted_text.append((source_text, None))
        
        # Enhanced color map with ranking-based colors
        color_map = {}
        for rank in range(len(important_ids)):
            _, rgba_color = get_color_by_rank(rank + 1, total_sources)
            color_map[f"rank_{rank+1}"] = rgba_color
        dummy_update = gr.update(
            value=f"AttnTrace_{state.response}_{state.start_index}_{state.end_index}"
        )
        attribute_error_update = gr.update(visible=False)
        
        # Combine sources and highlighted context into a single display
        # Sources at the top
        combined_display = []
        
        # Add sources header (no highlighting for UI elements)
        combined_display.append(("═══ FULL RESPONSE TRACEBACK RESULTS ═══\n", None))
        combined_display.append(("These are the text segments that contribute most to the entire response:\n\n", None))
        
        # Add sources using available data
        for rank, i in enumerate(sorted_indices):
            if i < len(important_ids):
                source_text = texts[important_ids[i]]
                
                # Strip leading/trailing whitespace from source text to avoid highlighting newlines
                clean_source_text = source_text.strip()
                
                if clean_source_text:  # Only add if there's actual content
                    # Add the source text with highlighting, then add spacing without highlighting
                    combined_display.append((clean_source_text, f"rank_{rank+1}"))
                    combined_display.append(("\n\n", None))
        
        # Add separator (no highlighting for UI elements)
        combined_display.append(("\n" + "═"*50 + "\n", None))
        combined_display.append(("FULL CONTEXT WITH HIGHLIGHTS\n", None))
        combined_display.append(("Scroll down to see the complete context with important segments highlighted:\n\n", None))
        
        # Add highlighted context using in_context_highlighted_text
        combined_display.extend(in_context_highlighted_text)
        
        # Use only the ranking colors (no highlighting for UI elements)
        enhanced_color_map = color_map.copy()
        
        combined_sources_update = HighlightedTextbox(
            value=combined_display, color_map=enhanced_color_map, visible=True
        )
        
        # Switch to the highlighted context tab and show results
        basic_context_tabs_update = gr.update(selected=1)
        basic_sources_in_context_tab_update = gr.update(visible=True)
        
        return (
            combined_sources_update,
            basic_context_tabs_update,
            basic_sources_in_context_tab_update,
            dummy_update,
            attribute_error_update,
            state,
        )
    except Exception as e:
        traceback.print_exc()
        return (
            gr.update(value=[("", None)], visible=False),
            gr.update(selected=0),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=[(f"❌ Error: {str(e)}", None)], visible=True),
            state,
        )

def basic_get_scores_and_sources(
    evt: gr.SelectData,
    highlighted_response: List[Dict[str, str]],
    state: State,
):
    
    # Get the selected sentence
    print("highlighted_response: ", highlighted_response[evt.index])
    selected_text = highlighted_response[evt.index]['token']
    state.explained_response_part = selected_text

    # Attribution using default configuration
    llm, attr, error_msg = initialize_model_and_attr()
    
    if attr is None:
        error_text = error_msg if error_msg else "Traceback initialization failed!"
        return (
            gr.update(value=[("", None)], visible=False),
            gr.update(selected=0),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=[(f"❌ {error_text}", None)], visible=True),
            state,
        )
    try:
        # Validate attribution inputs
        if not state.context or not state.context.strip():
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No context available for traceback!", None)], visible=True),
                state,
            )
            
        if not state.query or not state.query.strip():
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No query available for traceback!", None)], visible=True),
                state,
            )
            
        if not state.full_response or not state.full_response.strip():
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No response available for traceback!", None)], visible=True),
                state,
            )
        
        print(f"start traceback with explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
        print(f"context length: {len(state.context)}, query: {state.query[:100]}...")
        print(f"response: {state.full_response[:100]}...")
        print(f"selected part: {state.explained_response_part[:100]}...")
        
        texts, important_ids, importance_scores, _, _ = attr.attribute(
            state.query, [state.context], state.full_response, state.explained_response_part
        )
        print("end traceback")
        print(f"explanation_level: {DEFAULT_EXPLANATION_LEVEL}")
        print(f"texts count: {len(texts)} (how context was segmented)")
        if len(texts) > 0:
            print(f"sample text segments: {[text[:50] + '...' if len(text) > 50 else text for text in texts[:3]]}")
        print(f"important_ids: {important_ids}")
        print("importance_scores: ", importance_scores)
        
        if not importance_scores:
            return (
                gr.update(value=[("", None)], visible=False),
                gr.update(selected=0),
                gr.update(visible=False),
                gr.update(value=""),
                gr.update(value=[("❌ No traceback scores generated! Try a different text segment.", None)], visible=True),
                state,
            )
            
        state.scores = np.array(importance_scores)
        
        # Highlighted sources with ranking-based colors
        highlighted_text = []
        sorted_indices = np.argsort(state.scores)[::-1]
        total_sources = len(important_ids)
        
        for rank, i in enumerate(sorted_indices):
            source_text = texts[important_ids[i]]
            _ = get_color_by_rank(rank + 1, total_sources)
            
            highlighted_text.append(
                (
                    source_text,
                    f"rank_{rank+1}",
                )
            )
        
        # In-context highlights with ranking-based colors - show ALL text
        in_context_highlighted_text = []
        ranks = {important_ids[i]: rank for rank, i in enumerate(sorted_indices)}
        
        for i in range(len(texts)):
            source_text = texts[i]
            
            # Skip or don't highlight segments that are only newlines or whitespace
            if source_text.strip() == "":
                # For whitespace-only segments, add them without highlighting
                in_context_highlighted_text.append((source_text, None))
            elif i in important_ids:
                # Only highlight if the segment has actual content (not just newlines)
                if source_text.strip():  # Has non-whitespace content
                    rank = ranks[i] + 1
                    
                    # Split the segment to separate leading/trailing newlines from content
                    # This prevents newlines from being highlighted
                    leading_whitespace = ""
                    trailing_whitespace = ""
                    content = source_text
                    
                    # Extract leading newlines/whitespace
                    while content and content[0] in ['\n', '\r', '\t', ' ']:
                        leading_whitespace += content[0]
                        content = content[1:]
                    
                    # Extract trailing newlines/whitespace
                    while content and content[-1] in ['\n', '\r', '\t', ' ']:
                        trailing_whitespace = content[-1] + trailing_whitespace
                        content = content[:-1]
                    
                    # Add the parts separately: whitespace unhighlighted, content highlighted
                    if leading_whitespace:
                        in_context_highlighted_text.append((leading_whitespace, None))
                    if content:
                        in_context_highlighted_text.append((content, f"rank_{rank}"))
                    if trailing_whitespace:
                        in_context_highlighted_text.append((trailing_whitespace, None))
                else:
                    # Even if marked as important, don't highlight whitespace-only segments
                    in_context_highlighted_text.append((source_text, None))
            else:
                # Add unhighlighted text for non-important segments
                in_context_highlighted_text.append((source_text, None))
        
        # Enhanced color map with ranking-based colors
        color_map = {}
        for rank in range(len(important_ids)):
            _, rgba_color = get_color_by_rank(rank + 1, total_sources)
            color_map[f"rank_{rank+1}"] = rgba_color
        dummy_update = gr.update(
            value=f"AttnTrace_{state.response}_{state.start_index}_{state.end_index}"
        )
        attribute_error_update = gr.update(visible=False)
        
        # Combine sources and highlighted context into a single display
        # Sources at the top
        combined_display = []
        
        # Add sources header (no highlighting for UI elements)
        combined_display.append(("═══ TRACEBACK RESULTS ═══\n", None))
        combined_display.append(("These are the text segments that contribute most to the response:\n\n", None))
        
        # Add sources using available data
        for rank, i in enumerate(sorted_indices):
            if i < len(important_ids):
                source_text = texts[important_ids[i]]
                
                # Strip leading/trailing whitespace from source text to avoid highlighting newlines
                clean_source_text = source_text.strip()
                
                if clean_source_text:  # Only add if there's actual content
                    # Add the source text with highlighting, then add spacing without highlighting
                    combined_display.append((clean_source_text, f"rank_{rank+1}"))
                    combined_display.append(("\n\n", None))
        
        # Add separator (no highlighting for UI elements)
        combined_display.append(("\n" + "═"*50 + "\n", None))
        combined_display.append(("FULL CONTEXT WITH HIGHLIGHTS\n", None))
        combined_display.append(("Scroll down to see the complete context with important segments highlighted:\n\n", None))
        
        # Add highlighted context using in_context_highlighted_text
        combined_display.extend(in_context_highlighted_text)
        
        # Use only the ranking colors (no highlighting for UI elements)
        enhanced_color_map = color_map.copy()
        
        combined_sources_update = HighlightedTextbox(
            value=combined_display, color_map=enhanced_color_map, visible=True
        )
        
        # Switch to the highlighted context tab and show results
        basic_context_tabs_update = gr.update(selected=1)
        basic_sources_in_context_tab_update = gr.update(visible=True)
        
        return (
            combined_sources_update,
            basic_context_tabs_update,
            basic_sources_in_context_tab_update,
            dummy_update,
            attribute_error_update,
            state,
        )
    except Exception as e:
        traceback.print_exc()
        return (
            gr.update(value=[("", None)], visible=False),
            gr.update(selected=0),
            gr.update(visible=False),
            gr.update(value=""),
            gr.update(value=[(f"❌ Error: {str(e)}", None)], visible=True),
            state,
        )

def load_custom_css():
    """Load CSS from external file"""
    try:
        with open("assets/app_styles.css", "r") as f:
            css_content = f.read()
        return css_content
    except FileNotFoundError:
        print("Warning: CSS file not found, using minimal CSS")
        return ""
    except Exception as e:
        print(f"Error loading CSS: {e}")
        return ""

# Load CSS from external file
custom_css = load_custom_css()
theme = gr.themes.Citrus(
    text_size="lg",
    spacing_size="md",
)
with gr.Blocks(theme=theme, css=custom_css) as demo:
    gr.Markdown(f"# {APP_TITLE}")
    gr.Markdown(APP_DESCRIPTION, elem_classes="app-description")
    # gr.Markdown(NEW_TEXT, elem_classes="app-description-2")

    gr.Markdown("""
    <div style="font-size: 18px;">          
    AttnTrace is an efficient context traceback method for long contexts (e.g., full papers). It is over 15Γ— faster than the state-of-the-art context traceback method TracLLM. Compared to previous attention-based approaches, AttnTrace is more accurate, reliable, and memory-efficient.
    """, elem_classes="feature-highlights")
    # Feature highlights
    gr.Markdown("""
    <div style="font-size: 18px;">    
    AttnTrace can be used in many real-world applications, such as tracing back to: 
    
    - 📄 prompt injection instructions that manipulate LLM-generated paper reviews.
    - 💻 malicious comments & code hidden in a codebase that mislead an AI coding assistant.
    - 🤖 malicious instructions that mislead the actions of an LLM agent.
    - 🖋 the source texts in the context behind an AI summary.
    - 🔍 evidence that supports an LLM-generated answer to a question.
    - ❌ misinformation (corrupted knowledge) that manipulates LLM output for a question.
    - And a lot more...
    
    </div>
    """, elem_classes="feature-highlights")
    
    # Example buttons with topic-relevant images - moved here for better positioning
    gr.Markdown("### πŸš€ Try These Examples!", elem_classes="example-title")
    with gr.Row(elem_classes=["example-button-container"]):
        with gr.Column(scale=1):
            example_1_btn = gr.Button(
                "πŸ“„ Prompt Injection Attacks in AI Paper Review",
                elem_classes=["example-button", "example-paper"],
                elem_id="example_1_button",
                scale=None,
                size="sm"
            )
        with gr.Column(scale=1):
            example_2_btn = gr.Button(
                "πŸ’» Malicious Comments & Code in Codebase",
                elem_classes=["example-button", "example-movie"],
                elem_id="example_2_button"
            )
        with gr.Column(scale=1):
            example_3_btn = gr.Button(
                "πŸ€– Malicious Instructions Misleading the LLM Agent",
                elem_classes=["example-button", "example-code"],
                elem_id="example_3_button"
            )
    
    with gr.Row(elem_classes=["example-button-container"]):
        with gr.Column(scale=1):
            example_4_btn = gr.Button(
                "πŸ–‹ Source Texts for an AI Summary",
                elem_classes=["example-button", "example-paper-alt"],
                elem_id="example_4_button"
            )
        with gr.Column(scale=1):
            example_5_btn = gr.Button(
                "πŸ” Evidence that Support Question Answering",
                elem_classes=["example-button", "example-movie-alt"],
                elem_id="example_5_button"
            )
        with gr.Column(scale=1):
            example_6_btn = gr.Button(
                "❌ Misinformation (Corrupted Knowledge) in Question Answering",
                elem_classes=["example-button", "example-code-alt"],
                elem_id="example_6_button"
            )

    state = gr.State(
        value=clear_state()
    )

    # Create tabs for Demo and Configuration
    with gr.Tabs() as main_tabs:
        # Demo Tab
        with gr.Tab("Demo", id="demo_tab"):
            gr.Markdown(
                "Enter your context and instruction below to try out AttnTrace! You can also click on the example buttons above to load pre-configured examples."
            )
            
            gr.Markdown(
                '**Color Legend for Context Traceback (by ranking):** <span style="background-color: #FF4444; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Red</span> = 1st (most important) | <span style="background-color: #FF8C42; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Orange</span> = 2nd | <span style="background-color: #FFD93D; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Golden</span> = 3rd | <span style="background-color: #FFF280; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Yellow</span> = 4th-5th | <span style="background-color: #FFF9C4; color: black; padding: 2px 6px; border-radius: 4px; font-weight: 600;">Light</span> = 6th+'
            )
            
            # Top section: Wide Context box with tabs
            with gr.Row():
                with gr.Column(scale=1):
                    with gr.Tabs() as basic_context_tabs:
                        with gr.TabItem("Context", id=0):
                            basic_context_box = gr.Textbox(
                                placeholder="Enter context...",
                                show_label=False,
                                value="",
                                lines=6,
                                max_lines=6,
                                elem_id="basic_context_box",
                                autoscroll=False,
                            )
                        with gr.TabItem("Context with highlighted traceback results", id=1, visible=True) as basic_sources_in_context_tab:
                            basic_sources_in_context_box = HighlightedTextbox(
                                value=[("Click on a sentence in the response below to see highlighted traceback results here.", None)],
                                show_legend_label=False,
                                show_label=False,
                                show_legend=False,
                                interactive=False,
                                elem_id="basic_sources_in_context_box",
                            )
            
            # Error messages
            basic_generate_error_box = HighlightedTextbox(
                show_legend_label=False,
                show_label=False,
                show_legend=False,
                visible=False,
                interactive=False,
                container=False,
            )

            # Bottom section: Left (instruction + button + response), Right (response selection)
            with gr.Row(equal_height=True):
                # Left: Instruction + Button + Response
                with gr.Column(scale=1):
                    basic_query_box = gr.Textbox(
                        label="Instruction",
                        placeholder="Enter an instruction...",
                        value="",
                        lines=3,
                        max_lines=3,
                    )
                    
                    unified_response_button = gr.Button(
                        "Generate/Use Response", 
                        variant="primary", 
                        size="lg"
                    )
                    
                    response_input_box = gr.Textbox(
                        label="Response (Editable)",
                        placeholder="Response will appear here after generation, or type your own response for traceback...",
                        lines=8,
                        max_lines=8,
                        info="Leave empty and click button to generate from LLM, or type your own response to use for traceback"
                    )
                
                # Right: Response for attribution selection
                with gr.Column(scale=1):
                    basic_response_box = gr.HighlightedText(
                        label="Click to select text for traceback!",
                        value=[("Click the 'Generate/Use Response' button on the left to see response text here for traceback analysis.", None)],
                        interactive=False,
                        combine_adjacent=False,
                        show_label=True,
                        show_legend=False,
                        elem_id="basic_response_box",
                        visible=True,
                    )
                    
                    # Button for full response traceback
                    full_response_traceback_button = gr.Button(
                        "πŸ” Traceback Entire Response",
                        variant="secondary",
                        size="sm"
                    )
            
            # Hidden error box and dummy elements
            basic_attribute_error_box = HighlightedTextbox(
                show_legend_label=False,
                show_label=False,
                show_legend=False,
                visible=False,
                interactive=False,
                container=False,
            )
            dummy_basic_sources_box = gr.Textbox(
                visible=False, interactive=False, container=False
            )
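            # dummy_basic_sources_box stays hidden; the traceback handlers write to it so
            # its .change event (wired further below) can trigger a client-side scroll to
            # the highlighted results.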

        # Configuration Tab
        with gr.Tab("Config", id="config_tab"):
            gr.Markdown("## βš™οΈ AttnTrace Configuration")
            gr.Markdown("Configure the traceback analysis parameters to customize how AttnTrace processes your context and generates results.")
            
            with gr.Row():
                with gr.Column(scale=1):
                    explanation_level_dropdown = gr.Dropdown(
                        choices=["sentence", "paragraph", "text segment"],
                        value="sentence",
                        label="Explanation Level",
                        info="How to segment the context for traceback analysis"
                    )
                with gr.Column(scale=1):
                    top_k_dropdown = gr.Dropdown(
                        choices=["3", "5", "10"],
                        value="3",
                        label="Top-N Value",
                        info="Number of most important text segments to highlight"
                    )
            
            with gr.Row():
                with gr.Column(scale=1):
                    B_slider = gr.Slider(
                        minimum=1,
                        maximum=100,
                        value=30,
                        step=5,
                        label="B Parameter",
                        info="Number of subsamples (higher = more accurate but slower)"
                    )
                with gr.Column(scale=1):
                    q_slider = gr.Slider(
                        minimum=0.1,
                        maximum=1.0,
                        value=0.4,
                        step=0.1,
                        label="ρ Parameter",
                        info="Sub-sampling ratio (0.1-1.0)"
                    )
            
            with gr.Row():
                with gr.Column(scale=1):
                    apply_config_button = gr.Button(
                        "Apply Configuration",
                        variant="primary",
                        size="lg"
                    )
                with gr.Column(scale=2):
                    config_status_text = gr.Textbox(
                        label="Configuration Status",
                        value="Ready to apply configuration",
                        interactive=False,
                        lines=2
                    )
            
            gr.Markdown("### πŸ“‹ Current Configuration")
            gr.Markdown("""
            - **Explanation Level**: Determines how the context is segmented for analysis
              - `sentence`: Analyze at sentence level (recommended for most cases)
              - `paragraph`: Analyze at paragraph level (good for longer documents)
              - `text segment`: Analyze at the level of 100-word text segments (ideal for non-standard document formats)
            
            - **Top-N Value**: Number of most important text segments to highlight in results
              - Higher values show more context but may be less focused
              - Lower values provide more focused results but may miss some context
            
            - **B Parameter**: Number of subsamples
              - Higher values (50-100): More thorough analysis but slower
              - Lower values (10-30): Faster analysis but may miss some important segments
              - Default: 30 (good balance of speed and accuracy)
            
            - **ρ Parameter**: Sub-sampling ratio (0.1-1.0)
              - Higher values keep a larger fraction of the context in each subsample
              - Lower values produce smaller, more varied subsamples
              - Default: 0.4

            
            **Note**: Configuration changes will take effect immediately for new traceback operations.
            """)
            
            gr.Markdown("### πŸ”„ Model Information")
            gr.Markdown(f"""
            - **Current Model**: {DEFAULT_MODEL_PATH}
            - **Max Tokens**: {get_max_tokens(DEFAULT_MODEL_PATH):,}
            - **Device**: CUDA (GPU accelerated)
            """)

    # This simplified version exposes a single method (AttnTrace) and a single model

    def basic_clear_state():
        state = clear_state()
        return (
            "",  # basic_context_box
            "",  # basic_query_box
            "",  # response_input_box
            gr.update(value=[("Click the 'Generate/Use Response' button above to see response text here for traceback analysis.", None)]),  # basic_response_box - keep visible
            gr.update(selected=0),  # basic_context_tabs - switch to first tab
            state,
        )

    # Define the interaction behavior for the Demo tab only
    def handle_demo_tab_selection(evt: gr.SelectData):
        """Handle tab selection - only clear state when switching to demo tab"""
        if evt.index == 0:  # Demo tab
            return basic_clear_state()
        else:  # Configuration tab - no state change needed
            return (
                gr.update(),  # basic_context_box
                gr.update(),  # basic_query_box
                gr.update(),  # response_input_box
                gr.update(),  # basic_response_box
                gr.update(),  # basic_context_tabs
                gr.update(),  # state
            )
    
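    # Gradio injects the SelectData event via the `evt: gr.SelectData` annotation,
    # so no explicit input component is needed for it.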
    main_tabs.select(
        fn=handle_demo_tab_selection,
        inputs=[],
        outputs=[
            basic_context_box,
            basic_query_box,
            response_input_box,
            basic_response_box,
            basic_context_tabs,
            state,
        ],
    )
    for component in [basic_context_box, basic_query_box]:
        component.change(
            basic_update,
            [basic_context_box, basic_query_box, state],
            [
                basic_response_box,
                basic_context_tabs,
                state,
            ],
        )
    # Example button event handlers update both the UI and the state
    outputs_for_examples = [
        basic_context_box,
        basic_query_box,
        state,
        response_input_box,
        basic_response_box,
        basic_context_tabs,
    ]
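    # functools.partial binds each example loader to the shared load_an_example handler,
    # so every button can reuse the same inputs/outputs wiring.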
    for example_btn, run_example in [
        (example_1_btn, run_example_1),
        (example_2_btn, run_example_2),
        (example_3_btn, run_example_3),
        (example_4_btn, run_example_4),
        (example_5_btn, run_example_5),
        (example_6_btn, run_example_6),
    ]:
        example_btn.click(
            fn=partial(load_an_example, run_example),
            inputs=[state],
            outputs=outputs_for_examples,
        )
    
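    # The fn=lambda: None listeners below are server-side no-ops; the js= snippet from
    # get_scroll_js_code runs entirely in the browser to scroll the target element into view.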
    unified_response_button.click(
        fn=lambda: None,
        inputs=[],
        outputs=[],
        js=get_scroll_js_code("basic_response_box"),
    )
    basic_response_box.change(
        fn=lambda: None,
        inputs=[],
        outputs=[],
        js=get_scroll_js_code("basic_sources_in_context_box"),
    )
    # Add immediate tab switch on response selection
    def immediate_tab_switch():
        return (
            gr.update(value=[("πŸ”„ Processing traceback... Please wait...", None)]),  # Show progress message
            gr.update(selected=1),  # Switch to annotation tab immediately
        )
    
    basic_response_box.select(
        fn=immediate_tab_switch,
        inputs=[],
        outputs=[basic_sources_in_context_box, basic_context_tabs],
        queue=False,  # Execute immediately without queue
    )
    
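    # The traceback itself runs in this second, queued listener, so the un-queued
    # handler above can paint the progress message before the slower call returns.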
    basic_response_box.select(
        fn=basic_get_scores_and_sources,
        inputs=[basic_response_box, state],
        outputs=[
            basic_sources_in_context_box,
            basic_context_tabs,
            basic_sources_in_context_tab,
            dummy_basic_sources_box,
            basic_attribute_error_box,
            state,
        ],
        show_progress="full",
    )
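    # A third select listener re-renders the response box from state, e.g. to keep the
    # clicked sentence visually highlighted after the traceback completes.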
    basic_response_box.select(
        fn=basic_update_highlighted_response,
        inputs=[state],
        outputs=[basic_response_box, state],
    )
    
    # Full response traceback button
    full_response_traceback_button.click(
        fn=immediate_tab_switch,
        inputs=[],
        outputs=[basic_sources_in_context_box, basic_context_tabs],
        queue=False,  # Execute immediately without queue
    )
    
    full_response_traceback_button.click(
        fn=basic_get_scores_and_sources_full_response,
        inputs=[state],
        outputs=[
            basic_sources_in_context_box,
            basic_context_tabs,
            basic_sources_in_context_tab,
            dummy_basic_sources_box,
            basic_attribute_error_box,
            state,
        ],
        show_progress="full",
    )
    
    dummy_basic_sources_box.change(
        fn=lambda: None,
        inputs=[],
        outputs=[],
        js=get_scroll_js_code("basic_sources_in_context_box"),
    )

    # Unified response handler
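    # If the editable response box is empty it generates a response with the LLM;
    # otherwise it reuses the user-provided text for traceback.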
    unified_response_button.click(
        fn=unified_response_handler,
        inputs=[response_input_box, state],
        outputs=[state, response_input_box, basic_response_box, basic_generate_error_box]
    )
    
    # Configuration update handler
    apply_config_button.click(
        fn=update_configuration,
        inputs=[explanation_level_dropdown, top_k_dropdown, B_slider, q_slider],
        outputs=[config_status_text]
    )
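    # New settings apply to subsequent traceback runs; the outcome is shown in
    # config_status_text.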

    # gr.Markdown(
    #     "Please do not interact with elements while generation/attribution is in progress. This may cause errors. You can refresh the page if you run into issues because of this."
    # )

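# share=True additionally opens a temporary public *.gradio.live tunnel;
# show_api=False hides the auto-generated API documentation page.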
demo.launch(show_api=False, share=True)