import gradio as gr
import pandas as pd
import os
import re
# from dotenv import load_dotenv
import sys
import gc
import traceback
from transformers import AutoTokenizer
from nltk.util import ngrams
from difflib import SequenceMatcher
import io

sys.path.append(os.path.dirname(__file__))
from solar_api import retranslate_single

# --- Configuration ---
# Assuming the script is run from the root of the project
# load_dotenv()
DATA_DIR = './data'
TRAIN1_CSV = os.path.join(DATA_DIR, "train_1.csv")
TRAIN2_CSV = os.path.join(DATA_DIR, "train_2.csv")
DEV_CSV = os.path.join(DATA_DIR, "dev.csv")
TEST_CSV = os.path.join(DATA_DIR, "test.csv")
# Prefix for the sharded Solar result files; the shard number and ".csv" are appended below.
TRAIN_SOLAR_CSV = os.path.join(DATA_DIR, "train_solar_results_filtered")
DEV_SOLAR_CSV = os.path.join(DATA_DIR, "val_solar_results_filtered.csv")

SPECIAL_TOKENS = [
    '#Person1#', '#Person2#', '#Person3#', '#Person4#',
    '#Person5#', '#Person6#', '#Person7#',
    '#PhoneNumber#', '#Address#', '#DateOfBirth#', '#PassportNumber#',
    '#SSN#', '#CardNumber#', '#CarNumber#', '#Email#'
]

# UPSTAGE_API_KEY = os.getenv("UPSTAGE_API_KEY")

# --- Load Data ---
t1 = pd.read_csv(TRAIN1_CSV)
t2 = pd.read_csv(TRAIN2_CSV)
# ignore_index=True keeps the label index aligned with positions, so the
# fname search below can safely feed its result into .iloc.
train_df = pd.concat([t1, t2], ignore_index=True)
del t1, t2
gc.collect()

val_df = pd.read_csv(DEV_CSV)
test_df = pd.read_csv(TEST_CSV)

llm_train_df = None
llm_val_df = None
dfs = []
for i in range(8):
    df = pd.read_csv(f"{TRAIN_SOLAR_CSV}{i + 1}.csv")
    dfs.append(df)
llm_train_df = pd.concat(dfs, ignore_index=True)
del dfs, df
gc.collect()

llm_val_df = pd.read_csv(DEV_SOLAR_CSV)


# --- Block 1: Explore by fname ---
def get_NER(dialogue):
    # Not implemented yet; the UI shows a placeholder message instead.
    return None
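
# A minimal sketch of how get_NER could later be backed by a Hugging Face
# token-classification pipeline. This is an assumption, not the project's
# actual NER model; the checkpoint name is a hypothetical placeholder.
# Kept commented out so the app does not load a model at import time.
#
# from transformers import pipeline
#
# _ner_pipe = pipeline(
#     "token-classification",
#     model="<korean-ner-checkpoint>",  # hypothetical placeholder, replace with a real model
#     aggregation_strategy="simple",    # merge subword pieces into whole entities
# )
#
# def get_NER(dialogue):
#     entities = _ner_pipe(dialogue)
#     # Format as "word (entity_group)" lines for display in the UI.
#     return "\n".join(f"{e['word']} ({e['entity_group']})" for e in entities)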
def get_sample_by_index(dataset_name, index, use_NER=True):
    df = train_df if dataset_name == "train" else val_df
    llm_df = llm_train_df if dataset_name == "train" else llm_val_df
    index = int(index)
    if 0 <= index < len(df):
        sample = df.iloc[index]
        # Get entities from the NER model if one is available.
        ner = get_NER(sample['dialogue']) if use_NER else None
        if ner is None:
            ner = "NER model is not ready."
        # Expected columns: summary_solar_filter, topic_solar_filter, dialogue_ko2en_filter,
        # dialogue_en2ko_filter, re_summary_solar_filter, ner_solar_filter
        try:
            llm_sample = llm_df.iloc[index]
        except Exception:
            llm_sample = {
                'summary_solar_filter': 'N/A',
                'ner_solar_filter': 'N/A',
                're_summary_solar_filter': 'N/A',
                'topic_solar_filter': 'N/A',
                'dialogue_en2ko_filter': 'N/A',
            }
        return (sample['fname'], sample['dialogue'], sample['summary'], ner, sample['topic'],
                llm_sample['summary_solar_filter'], llm_sample['ner_solar_filter'],
                llm_sample['re_summary_solar_filter'], llm_sample['topic_solar_filter'],
                llm_sample['dialogue_en2ko_filter'], index)
    return "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", index


def update_outputs(dataset_name, index):
    fname, dialogue, summary, ner, topic, llm_summary, llm_ner, re_summary, re_topic, re_dialogue, _ = \
        get_sample_by_index(dataset_name, index)
    return (fname, f"```\n{dialogue}\n```", f"```\n{summary}\n```", f"```\n{ner}\n```",
            f"```\n{topic}\n```", f"```\n{llm_summary}\n```", f"```\n{llm_ner}\n```",
            f"```\n{re_summary}\n```", f"```\n{re_topic}\n```", f"```\n{re_dialogue}\n```")


def next_sample(dataset_name, index):
    df = train_df if dataset_name == "train" else val_df
    new_index = min(int(index) + 1, len(df) - 1)
    # Reuse update_outputs so the display formatting lives in one place.
    return (new_index, *update_outputs(dataset_name, new_index))


def prev_sample(dataset_name, index):
    new_index = max(int(index) - 1, 0)
    return (new_index, *update_outputs(dataset_name, new_index))


def reset_index_on_split_change(dataset_name):
    # When the dataset changes, reset the index to 0 and update the display.
    return (0, *update_outputs(dataset_name, 0))


def get_sample_by_fname(dataset_name, fname):
    df = train_df if dataset_name == "train" else val_df
    # Find the index for the given fname.
    if fname in df['fname'].values:
        index = df[df['fname'] == fname].index[0]
        # Use the existing formatting helper with the found index.
        return (index, *update_outputs(dataset_name, index))
    else:
        # fname not found; return 11 values to match the 11 output components.
        return (-1, f"fname '{fname}' not found in {dataset_name} dataset.",
                "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", "N/A", "N/A")
def request_solar_api(api_key, model_name, dataset_name, index):
    if not api_key or not api_key.strip():
        error_msg = ("```\n[Error] No UPSTAGE_API_KEY was provided. "
                     "Please enter an API key and try again.\n```")
        return error_msg, error_msg, error_msg, error_msg, error_msg
    df = train_df if dataset_name == "train" else val_df
    index = int(index)
    row_data = df.iloc[index]
    # process_row expects a tuple of (index, series).
    row_for_api = (index, row_data)
    try:
        # retranslate_single returns (fname, results_list).
        fname, results = retranslate_single(row_for_api, api_key, model=model_name)
        # results is a list: [summary, topic, en2ko, re_summary, ner]
        summary, topic, re_dialogue, re_summary, ner = results
        # Match the output order for the UI:
        # llm_summary_out, llm_ner_out, re_summary_out, re_topic_out, re_dialogue_out
        return (f"```\n{summary}\n```", f"```\n{ner}\n```", f"```\n{re_summary}\n```",
                f"```\n{topic}\n```", f"```\n{re_dialogue}\n```")
    except Exception as e:
        # The API key may be invalid, causing an error inside retranslate_single.
        error_msg = (f"```\n[Error] The API call failed. The API key may be invalid, "
                     f"or there may be a network problem.\n{e}\n```")
        return error_msg, error_msg, error_msg, error_msg, error_msg


# --- Block 2: Explore by Topic ---
# Merge dataframes
train_df['source'] = 'train'
val_df['source'] = 'val'
merged_df = pd.concat([train_df, val_df], ignore_index=True)
merged_df['topic'] = merged_df['topic'].astype(str)
topics = sorted(merged_df['topic'].unique().tolist())
topic_df_grouped = merged_df.groupby('topic')


def get_topic_data(topic_name):
    if topic_name not in topic_df_grouped.groups:
        return "Topic not found.", "Topic not found."
    topic_df = topic_df_grouped.get_group(topic_name)

    # Clean and split the topic name into keywords.
    keywords = re.sub(r"[,.]", " ", topic_name)
    # Collapse whitespace; this must operate on `keywords`, not `topic_name`,
    # or the first substitution is discarded.
    keywords = re.sub(r"\s+", " ", keywords)
    keywords = keywords.strip().split(" ")

    def highlight_keywords(text, keywords):
        # Wrap each keyword occurrence in <mark> tags. The tag markup is
        # reconstructed; the original highlight tags were stripped from the source.
        for keyword in keywords:
            text = re.sub(f"({re.escape(keyword)})", r'<mark>\1</mark>', text, flags=re.IGNORECASE)
        return text

    train_output = ""
    val_output = ""
    for _, row in topic_df.iterrows():
        dialogue = highlight_keywords(row['dialogue'], keywords)
        summary = highlight_keywords(row['summary'], keywords)
        formatted_output = f"fname: {row['fname']}\n\n"
        formatted_output += f"Dialogue:\n{dialogue}\n\n"
        formatted_output += f"Summary:\n{summary}\n\n"
        formatted_output += ("-" * 20) + "\n"
        if row['source'] == 'train':
            train_output += formatted_output
        elif row['source'] == 'val':
            val_output += formatted_output
    return train_output, val_output
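
# Worked example for get_topic_data's keyword handling (hypothetical topic,
# not necessarily present in the dataset): for topic_name = "doctor visit.",
# the two substitutions yield keywords = ["doctor", "visit"], and
# highlight_keywords("Visit the doctor", keywords) returns
# "<mark>Visit</mark> the <mark>doctor</mark>" -- the original casing is
# preserved via the \1 backreference while matching case-insensitively.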
def update_topic_display_by_name(topic_name):
    return get_topic_data(topic_name)


def change_topic(change, current_topic_name):
    current_index = topics.index(current_topic_name)
    new_index = (current_index + change + len(topics)) % len(topics)
    new_topic = topics[new_index]
    train_output, val_output = get_topic_data(new_topic)
    return new_topic, train_output, val_output


# --- Block 3: Validation Inference Exploration ---
def get_rouge_highlighted_html(text_a, text_b, tokenizer_name):
    """Compare two texts and apply HTML highlighting based on ROUGE matches.

    ROUGE-1 (unigram) matches are highlighted in yellow, ROUGE-2 (bigram)
    in light green, and ROUGE-L (LCS) in light coral. The tokenizer is loaded
    with Hugging Face's AutoTokenizer, with SPECIAL_TOKENS registered.
    """
    try:
        # Load the tokenizer by name.
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)
        # Register the predefined SPECIAL_TOKENS with the tokenizer.
        tokenizer.add_special_tokens({'additional_special_tokens': SPECIAL_TOKENS})
    except Exception as e:
        # If the tokenizer fails to load, return an error message.
        # (The <pre> wrapper is reconstructed; the original markup was stripped.)
        error_msg = f"<pre>Error loading tokenizer '{tokenizer_name}': {e}</pre>"
        return error_msg, error_msg, None

    # Tokenize the input texts.
    tokens_a = tokenizer.tokenize(text_a)
    tokens_b = tokenizer.tokenize(text_b)

    # --- ROUGE Score Calculation ---
    # ROUGE-1 (set-based, so duplicate tokens are counted once)
    unigrams_a = set(tokens_a)
    unigrams_b = set(tokens_b)
    common_unigrams_count = len(unigrams_a & unigrams_b)
    r1_precision = common_unigrams_count / len(unigrams_b) if len(unigrams_b) > 0 else 0
    r1_recall = common_unigrams_count / len(unigrams_a) if len(unigrams_a) > 0 else 0
    r1_f1 = 2 * (r1_precision * r1_recall) / (r1_precision + r1_recall) if (r1_precision + r1_recall) > 0 else 0

    # ROUGE-2
    bigrams_a = set(ngrams(tokens_a, 2))
    bigrams_b = set(ngrams(tokens_b, 2))
    common_bigrams_count = len(bigrams_a & bigrams_b)
    r2_precision = common_bigrams_count / len(bigrams_b) if len(bigrams_b) > 0 else 0
    r2_recall = common_bigrams_count / len(bigrams_a) if len(bigrams_a) > 0 else 0
    r2_f1 = 2 * (r2_precision * r2_recall) / (r2_precision + r2_recall) if (r2_precision + r2_recall) > 0 else 0

    # ROUGE-L (SequenceMatcher's matching blocks approximate the LCS; difflib
    # does not guarantee minimal diffs, so this can undercount in rare cases)
    matcher_for_score = SequenceMatcher(None, tokens_a, tokens_b, autojunk=False)
    lcs_length = sum(block.size for block in matcher_for_score.get_matching_blocks())
    rl_precision = lcs_length / len(tokens_b) if len(tokens_b) > 0 else 0
    rl_recall = lcs_length / len(tokens_a) if len(tokens_a) > 0 else 0
    rl_f1 = 2 * (rl_precision * rl_recall) / (rl_precision + rl_recall) if (rl_precision + rl_recall) > 0 else 0

    rouge_scores = {
        "rouge-1": r1_f1,
        "rouge-2": r2_f1,
        "rouge-l": rl_f1,
    }
    # --- End ROUGE Score Calculation ---
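
    # Worked example of the F1 arithmetic above (illustrative tokens, not from
    # the dataset): with tokens_a = ["the", "cat", "sat", "down"] and
    # tokens_b = ["the", "cat", "down"]:
    #   ROUGE-1: 3 shared unigrams -> P = 3/3, R = 3/4, F1 = 2*(1.0*0.75)/1.75 ≈ 0.857
    #   ROUGE-2: 1 shared bigram ("the","cat") -> P = 1/2, R = 1/3, F1 = 0.4
    #   ROUGE-L: LCS ["the","cat","down"] has length 3 -> P = 3/3, R = 3/4, F1 ≈ 0.857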
    # Highlight colors per ROUGE metric.
    colors = {
        'rouge-1': 'yellow',      # unigram (single-token) matches
        'rouge-2': 'lightgreen',  # bigram (two consecutive tokens) matches
        'rouge-l': 'lightcoral'   # Longest Common Subsequence matches
    }

    # Per-token highlight maps for each text, initialized to no highlight (None).
    highlight_map_a = [None] * len(tokens_a)
    highlight_map_b = [None] * len(tokens_b)

    # ROUGE-1 (unigram) highlighting:
    # find all single tokens that appear in both texts.
    common_unigrams = unigrams_a & unigrams_b
    # Mark every position of a common unigram with the 'rouge-1' color.
    for i, token in enumerate(tokens_a):
        if token in common_unigrams:
            highlight_map_a[i] = colors['rouge-1']
    for i, token in enumerate(tokens_b):
        if token in common_unigrams:
            highlight_map_b[i] = colors['rouge-1']

    # ROUGE-2 (bigram) highlighting:
    # find all consecutive token pairs that appear in both texts.
    common_bigrams = bigrams_a & bigrams_b
    # Overwrite the positions of common bigrams with the 'rouge-2' color.
    # This takes precedence over ROUGE-1; ROUGE-L is applied last and overrides both.
    for i in range(len(tokens_a) - 1):
        if (tokens_a[i], tokens_a[i + 1]) in common_bigrams:
            highlight_map_a[i] = colors['rouge-2']
            highlight_map_a[i + 1] = colors['rouge-2']
    for i in range(len(tokens_b) - 1):
        if (tokens_b[i], tokens_b[i + 1]) in common_bigrams:
            highlight_map_b[i] = colors['rouge-2']
            highlight_map_b[i + 1] = colors['rouge-2']

    # ROUGE-L (LCS) highlighting:
    # use SequenceMatcher to find the matching blocks between the token sequences.
    matcher = SequenceMatcher(None, tokens_a, tokens_b, autojunk=False)
    for block in matcher.get_matching_blocks():
        if block.size > 0:
            # Overwrite every token position in a matching block with the
            # 'rouge-l' color; this has the highest precedence.
            for i in range(block.size):
                highlight_map_a[block.a + i] = colors['rouge-l']
                highlight_map_b[block.b + i] = colors['rouge-l']

    def build_html_from_map(tokens, h_map, tokenizer):
        # Build HTML from the highlight map, merging runs of consecutive tokens
        # with the same color into a single tag for efficiency.
        # (The <span> markup is reconstructed; the original tags were stripped.)
        if not tokens:
            return ""
        result_html = []
        # Start from the first token.
        current_color = h_map[0]
        token_buffer = [tokens[0]]
        # Walk the remaining tokens, flushing the buffer whenever the color changes.
        for i in range(1, len(tokens)):
            if h_map[i] != current_color:
                # Color changed: convert the buffered tokens to a string and wrap them.
                text = tokenizer.convert_tokens_to_string(token_buffer)
                if current_color:
                    result_html.append(f'<span style="background-color: {current_color};">{text}</span>')
                else:
                    result_html.append(text)
                # Restart the buffer with the new token and color.
                token_buffer = [tokens[i]]
                current_color = h_map[i]
            else:
                # Same color: keep buffering.
                token_buffer.append(tokens[i])
        # Flush whatever remains in the buffer.
        text = tokenizer.convert_tokens_to_string(token_buffer)
        if current_color:
            result_html.append(f'<span style="background-color: {current_color};">{text}</span>')
        else:
            result_html.append(text)
        return "".join(result_html)
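
    # Illustrative behavior (hypothetical inputs): with tokens ["a", "b", "c"]
    # and h_map = ["yellow", "yellow", None], the first two tokens are flushed
    # together as one '<span style="background-color: yellow;">...</span>'
    # around their joined text, and "c" is appended unwrapped. The exact
    # joining depends on the tokenizer's convert_tokens_to_string.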
    # Build the highlighted HTML for each text.
    html_a = build_html_from_map(tokens_a, highlight_map_a, tokenizer)
    html_b = build_html_from_map(tokens_b, highlight_map_b, tokenizer)

    # Wrap the final HTML in <pre> tags (reconstructed markup) and return it.
    return (f"<pre style='white-space: pre-wrap;'>{html_a}</pre>",
            f"<pre style='white-space: pre-wrap;'>{html_b}</pre>",
            rouge_scores)


def get_validation_data(fname, inference_file, tokenizer_name):
    # 1. Get original data
    if fname not in val_df['fname'].values:
        msg = f"fname '{fname}' not found in validation dataset."
        return msg, msg, msg, msg, "N/A"
    original_sample = val_df[val_df['fname'] == fname].iloc[0]
    dialogue_orig = f"```\n{original_sample['dialogue']}\n```"
    topic_orig = f"```\n{original_sample['topic']}\n```"
    summary_orig_text = original_sample['summary']

    # 2. Defaults for outputs
    # (The <pre> wrapper is reconstructed; the original markup was stripped.)
    summary_orig_out = f"<pre style='white-space: pre-wrap;'>{summary_orig_text}</pre>"
    summary_infer_out = "Please upload an inference CSV file and select a sample."
    rouge_scores_out = "Upload an inference file to see ROUGE scores."

    # 3. Process inference file if available
    if inference_file is not None:
        try:
            inference_df = pd.read_csv(inference_file.name)
            if fname in inference_df['fname'].values:
                inference_sample = inference_df[inference_df['fname'] == fname].iloc[0]
                summary_infer_text = inference_sample['summary']
                summary_orig_out, summary_infer_out, rouge_scores = get_rouge_highlighted_html(
                    summary_orig_text, summary_infer_text, tokenizer_name
                )
                if rouge_scores:
                    # Two trailing spaces force Markdown line breaks.
                    rouge_scores_out = (
                        f"**ROUGE-1 F1:** {rouge_scores['rouge-1']:.4f}  \n"
                        f"**ROUGE-2 F1:** {rouge_scores['rouge-2']:.4f}  \n"
                        f"**ROUGE-L F1:** {rouge_scores['rouge-l']:.4f}"
                    )
                else:
                    rouge_scores_out = "Could not calculate ROUGE scores."
            else:
                summary_infer_out = f"fname '{fname}' not found in the uploaded file."
        except Exception as e:
            tb_str = traceback.format_exc()
            summary_infer_out = f"<pre>Error reading or processing the file: {e}\n\nTraceback:\n{tb_str}</pre>"
            rouge_scores_out = "Error calculating scores."

    return dialogue_orig, summary_orig_out, topic_orig, summary_infer_out, rouge_scores_out


def get_val_sample_by_fname(fname, inference_file, tokenizer_name):
    if fname not in val_df['fname'].values:
        msg = f"fname '{fname}' not found."
        return -1, fname, msg, msg, msg, msg, "N/A"
    index = val_df[val_df['fname'] == fname].index[0]
    d_orig, s_orig, t_orig, s_infer, scores = get_validation_data(fname, inference_file, tokenizer_name)
    return index, fname, d_orig, s_orig, t_orig, s_infer, scores


def change_val_sample(index_change, current_index, inference_file, tokenizer_name):
    current_index = int(current_index)
    new_index = current_index + index_change
    if not (0 <= new_index < len(val_df)):
        new_index = current_index
    fname = val_df.iloc[new_index]['fname']
    d_orig, s_orig, t_orig, s_infer, scores = get_validation_data(fname, inference_file, tokenizer_name)
    return new_index, fname, d_orig, s_orig, t_orig, s_infer, scores


# --- Gradio Interface ---
with gr.Blocks() as demo:
    gr.Markdown("# Data Explorer for Dialogue Summarization")

    with gr.Tab("Explore by Sample (fname)"):
        # State to hold the current index
        current_index = gr.State(0)
        with gr.Row():
            split_select = gr.Radio(["train", "val"], label="Dataset", value="train")
            with gr.Column():
                prev_btn = gr.Button("Previous")
                next_btn = gr.Button("Next")
            fname_out = gr.Textbox(label="fname", interactive=True)
            search_btn = gr.Button("Search")
        with gr.Row():
            with gr.Column():
                origin_out = gr.Markdown(label="original dataset")
                dialogue_out = gr.Markdown(label="Dialogue")
                summary_out = gr.Markdown(label="Summary")
                ner_out = gr.Markdown(label='NER')
                topic_out = gr.Markdown(label="Topic")
            with gr.Column():
                with gr.Row():
                    UPSTAGE_API_KEY = gr.Textbox(label="Your Upstage API Key", interactive=True)
                    solar_model_select = gr.Dropdown(
                        ["solar-pro2", "solar-pro", "solar-mini"],
                        label="Solar Model", value="solar-pro2"
                    )
                    solar_btn = gr.Button("Request")
                llm_summary_desc = gr.Markdown(label="desc1")
                llm_summary_out = gr.Markdown(label="LLM Summary")
                llm_ner_desc = gr.Markdown(label="desc2")
                llm_ner_out = gr.Markdown(label='LLM_NER')
                re_summary_desc = gr.Markdown(label="desc3")
                re_summary_out = gr.Markdown(label="LLM Retranslate Summary")
                re_topic_desc = gr.Markdown(label="desc4")
                re_topic_out = gr.Markdown(label="LLM Retranslate Topic")
                re_dialogue_desc = gr.Markdown(label="desc5")
                re_dialogue_out = gr.Markdown(label="LLM Retranslate dialogue")

        # Initial load
        (initial_fname, initial_dialogue, initial_summary, initial_ner, initial_topic,
         initial_llm_summary_out, initial_llm_ner_out, initial_re_summary_out,
         initial_re_topic_out, initial_re_dialogue_out, _) = get_sample_by_index("train", 0)

        # Set initial values
        demo.load(
            lambda: (
                "### Original Dataset",
                'Solar Summary result', 'Solar NER result',
                'Back-translated (ko>en>ko) Solar Summary',
                'Back-translated (ko>en>ko) Solar topic',
                'Back-translated (ko>en>ko) Dialogue',
                initial_fname,
                f"```\n{initial_dialogue}\n```",
                f"```\n{initial_summary}\n```",
                f"```\n{initial_ner}\n```",
                f"```\n{initial_topic}\n```",
                f"```\n{initial_llm_summary_out}\n```",
                f"```\n{initial_llm_ner_out}\n```",
                f"```\n{initial_re_summary_out}\n```",
                f"```\n{initial_re_topic_out}\n```",
                f"```\n{initial_re_dialogue_out}\n```",
            ),
            inputs=None,
            outputs=[
                origin_out,
                llm_summary_desc, llm_ner_desc, re_summary_desc, re_topic_desc, re_dialogue_desc,
                fname_out, dialogue_out, summary_out, ner_out, topic_out,
                llm_summary_out, llm_ner_out, re_summary_out, re_topic_out, re_dialogue_out,
            ],
        )

        # Event handlers for Block 1
        search_btn.click(
            get_sample_by_fname,
            inputs=[split_select, fname_out],
            outputs=[current_index, fname_out, dialogue_out, summary_out, ner_out, topic_out,
                     llm_summary_out, llm_ner_out, re_summary_out, re_topic_out, re_dialogue_out]
        )
        next_btn.click(
            next_sample,
            inputs=[split_select, current_index],
            outputs=[current_index, fname_out, dialogue_out, summary_out, ner_out, topic_out,
                     llm_summary_out, llm_ner_out, re_summary_out, re_topic_out, re_dialogue_out]
        )
        prev_btn.click(
            prev_sample,
            inputs=[split_select, current_index],
            outputs=[current_index, fname_out, dialogue_out, summary_out, ner_out, topic_out,
                     llm_summary_out, llm_ner_out, re_summary_out, re_topic_out, re_dialogue_out]
        )
        split_select.change(
            reset_index_on_split_change,
            inputs=[split_select],
            outputs=[current_index, fname_out, dialogue_out, summary_out, ner_out, topic_out,
                     llm_summary_out, llm_ner_out, re_summary_out, re_topic_out, re_dialogue_out]
        )
        solar_btn.click(
            request_solar_api,
            inputs=[UPSTAGE_API_KEY, solar_model_select, split_select, current_index],
            outputs=[llm_summary_out, llm_ner_out, re_summary_out, re_topic_out, re_dialogue_out]
        )

    with gr.Tab("Explore by Topic"):
        with gr.Row():
            topic_select = gr.Dropdown(topics, label="Select Topic", value=topics[0])
            prev_topic_btn = gr.Button("Previous Topic")
            next_topic_btn = gr.Button("Next Topic")
        with gr.Row():
            train_topic_display = gr.Markdown(label="Train Data")
            val_topic_display = gr.Markdown(label="Validation Data")

        # Initial load for Block 2
        demo.load(
            lambda: get_topic_data(topics[0]),
            inputs=None,
            outputs=[train_topic_display, val_topic_display]
        )

        # Event handlers for Block 2
        topic_select.change(
            update_topic_display_by_name,
            inputs=[topic_select],
            outputs=[train_topic_display, val_topic_display]
        )
        next_topic_btn.click(
            lambda current_topic: change_topic(1, current_topic),
            inputs=[topic_select],
            outputs=[topic_select, train_topic_display, val_topic_display]
        )
        prev_topic_btn.click(
            lambda current_topic: change_topic(-1, current_topic),
            inputs=[topic_select],
            outputs=[topic_select, train_topic_display, val_topic_display]
        )
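
    # The inference CSV uploaded in the next tab is expected to contain at
    # least `fname` and `summary` columns (see get_validation_data above).
    # A minimal illustrative example (hypothetical row, not real data):
    #
    #   fname,summary
    #   dev_0,"#Person1# and #Person2# discuss a meeting schedule."
    #
    # The fname values must match those in dev.csv.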
    with gr.Tab("Explore Validation Inference"):
        # State
        current_index_2 = gr.State(0)
        with gr.Row():
            inference_file_input = gr.File(label="Upload Inferred Validation CSV", value="./data/dev_sample.csv")
            tokenizer_input = gr.Textbox(label="Tokenizer", value="digit82/kobart-summarization")
        with gr.Row():
            prev_btn_2 = gr.Button("Previous")
            next_btn_2 = gr.Button("Next")
            fname_out_2 = gr.Textbox(label="fname", interactive=True)
            search_btn_2 = gr.Button("Search")
        with gr.Row():
            val_dialogue_out = gr.Markdown(label="Dialogue")
            val_topic_out = gr.Markdown(label="Topic")
        with gr.Row():
            with gr.Column():
                gr.Markdown("### Original Validation Data")
                val_summary_out = gr.HTML(label="Summary")
            with gr.Column():
                gr.Markdown("### Inference Data")
                val_inference_summary_out = gr.HTML(label="Summary")
        rouge_scores_out = gr.Markdown(label="ROUGE F1 Scores")

        # Outputs for the validation inference tab
        outputs_2 = [
            current_index_2, fname_out_2, val_dialogue_out, val_summary_out,
            val_topic_out, val_inference_summary_out, rouge_scores_out
        ]

        # Event handlers
        search_btn_2.click(
            get_val_sample_by_fname,
            inputs=[fname_out_2, inference_file_input, tokenizer_input],
            outputs=outputs_2
        )
        next_btn_2.click(
            lambda idx, file, tok: change_val_sample(1, idx, file, tok),
            inputs=[current_index_2, inference_file_input, tokenizer_input],
            outputs=outputs_2
        )
        prev_btn_2.click(
            lambda idx, file, tok: change_val_sample(-1, idx, file, tok),
            inputs=[current_index_2, inference_file_input, tokenizer_input],
            outputs=outputs_2
        )
        # Trigger an update when the file or tokenizer changes, using the current fname.
        inference_file_input.change(
            get_val_sample_by_fname,
            inputs=[fname_out_2, inference_file_input, tokenizer_input],
            outputs=outputs_2
        )
        tokenizer_input.change(
            get_val_sample_by_fname,
            inputs=[fname_out_2, inference_file_input, tokenizer_input],
            outputs=outputs_2
        )

        # Initial load for the tab
        def initial_load_tab3():
            fname = val_df.iloc[0]['fname']
            d_orig, s_orig, t_orig, s_infer, scores = get_validation_data(
                fname, None, "digit82/kobart-summarization"
            )
            return 0, fname, d_orig, s_orig, t_orig, s_infer, scores

        demo.load(initial_load_tab3, None, outputs_2)


if __name__ == "__main__":
    # To run this script, use the command:
    #   python src/app/app_gradio.py
    demo.launch()