"""Gradio demo that explains a BERT + LDA + linear-classifier prediction.

For an input sentence the app computes a gradient-x-input importance score per
token and shows it as highlighted text, a top-k table, a bar plot, and a 3D PCA
scatter of the token embeddings.
"""

import gradio as gr
import os
import traceback
import joblib
import torch
import numpy as np
import html  # Kept: html.escape may still be needed when building highlighted_text_data
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
import pandas as pd
import matplotlib
matplotlib.use('Agg')  # must be selected before pyplot is imported
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.graph_objects as go

# --- Global Settings and Model Loading ---
hf_logging.set_verbosity_error()

MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"
SAVE_DIR = "저장저장1"
LAYER_ID = 4
SEED = 0
CLF_NAME = "linear"
CLASS_LABEL_MAP = {
    0: "World",
    1: "Sports",
    2: "Business",
    3: "Sci/Tech",
}

TOKENIZER_GLOBAL, MODEL_GLOBAL = None, None
W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL = None, None, None, None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""

try:
    print("Gradio App: Initializing model loading...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")

    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(f"Error: Model storage directory '{SAVE_DIR}' not found.")
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"Error: LDA model file '{lda_file_path}' not found.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"Error: Classifier model file '{clf_file_path}' not found.")

    # NOTE: joblib.load unpickles arbitrary objects; only point SAVE_DIR at trusted files.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap a wrapper estimator that exposes the underlying model as `base_estimator`.
    if hasattr(clf, "base_estimator"):
        clf = clf.base_estimator

    W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    W_P_GLOBAL = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    MODEL_GLOBAL = AutoModel.from_pretrained(
        MODEL_NAME,
        output_hidden_states=True,
        output_attentions=False,
    ).to(DEVICE).eval()

    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: All models and data loaded successfully!")
except Exception as e:
    # Top-level boundary: keep the UI importable and report the failure instead of crashing.
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = (
        f"Critical error during model loading: {str(e)}\n"
        f"Please ensure the '{SAVE_DIR}' folder and its contents are correct."
    )
    print(MODEL_LOADING_ERROR_MESSAGE)


# Helper function: 3D PCA Visualization using Plotly
def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores,
                             title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Build a 3D scatter of projected token embeddings coloured by score.

    Args:
        token_embeddings_3d: 2D array indexed as ``[:, 0..2]`` for the x/y/z
            coordinates of each token.
        tokens: Token strings, one per row of ``token_embeddings_3d``.
        scores: Importance scores, one per token; flattened before use.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``. Only the (up to) 20 highest-scoring
        tokens get a visible text label; every point has hover text.
    """
    num_annotations = min(len(tokens), 20)
    scores_array = np.array(scores).flatten()

    text_annotations = [''] * len(tokens)
    if len(scores_array) > 0 and len(tokens) > 0:
        # argsort is ascending, so the tail holds the highest-scoring tokens.
        for i in np.argsort(scores_array)[-num_annotations:]:
            if i < len(tokens):
                text_annotations[i] = tokens[i]

    # Plotly text uses <br> (not "\n") as its line separator.
    hover_labels = [f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)]

    fig = go.Figure(data=[go.Scatter3d(
        x=token_embeddings_3d[:, 0],
        y=token_embeddings_3d[:, 1],
        z=token_embeddings_3d[:, 2],
        mode='markers+text',
        text=text_annotations,
        textfont=dict(size=9, color='#333333'),
        textposition='top center',
        marker=dict(
            size=6,
            color=scores_array,
            colorscale='RdBu',
            reversescale=True,
            opacity=0.8,
            colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle'),
        ),
        hoverinfo='text',
        hovertext=hover_labels,
    )])

    axis_background = "rgba(230, 230, 230, 0.8)"
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        scene=dict(
            xaxis=dict(title=dict(text='PCA Comp 1', font=dict(size=10)),
                       tickfont=dict(size=9), backgroundcolor=axis_background),
            yaxis=dict(title=dict(text='PCA Comp 2', font=dict(size=10)),
                       tickfont=dict(size=9), backgroundcolor=axis_background),
            zaxis=dict(title=dict(text='PCA Comp 3', font=dict(size=10)),
                       tickfont=dict(size=9), backgroundcolor=axis_background),
            bgcolor="rgba(255, 255, 255, 0.95)",
            camera_eye=dict(x=1.5, y=1.5, z=0.5),
        ),
        margin=dict(l=5, r=5, b=5, t=45),
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig


# Helper function: Create an empty Plotly figure for placeholders
def create_empty_plotly_figure(message="N/A"):
    """Return an axis-less placeholder figure showing ``message`` in its centre.

    Newlines in ``message`` are converted to ``<br>`` because Plotly annotation
    text does not break lines on ``"\\n"``.
    """
    fig = go.Figure()
    fig.add_annotation(
        text=str(message).replace("\n", "<br>"),
        xref="paper", yref="paper",
        x=0.5, y=0.5,
        showarrow=False,
        font=dict(size=12, color="grey"),
    )
    fig.update_layout(
        xaxis={'visible': False},
        yaxis={'visible': False},
        height=300,
        paper_bgcolor='rgba(0,0,0,0)',
        plot_bgcolor='rgba(0,0,0,0)',
    )
    return fig


def _error_outputs(summary, message, figure_message):
    """Build the 6-item output tuple used by every failure path.

    The label output is a plain string: ``gr.Label`` interprets a dict as
    label -> confidence, so string values there are not valid.
    """
    empty_df = pd.DataFrame(columns=['token', 'score'])
    empty_fig = create_empty_plotly_figure(figure_message)
    return [], summary, f"Error: {message}", [], empty_df, empty_fig


# --- Core Analysis Function (returns 6 items for Gradio UI) ---
def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify a sentence and attribute the prediction to its tokens.

    Args:
        sentence_text: Sentence to analyse.
        top_k_value: Number of highest-scoring tokens to list; coerced to int.

    Returns:
        A 6-tuple matching the Gradio outputs, in order:
        highlighted-text pairs ``(token, score-or-None)``, a prediction summary
        string, the label payload (``{class: probability}`` on success, an
        error string on failure), top-k table rows, a bar-plot DataFrame with
        ``token``/``score`` columns, and a Plotly figure.
        Failures never raise; they return placeholder outputs instead.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        return _error_outputs("Model Loading Failed", "Model Loading Failed. Check logs.",
                              "Model Loading Failed")

    if sentence_text is None or not str(sentence_text).strip():
        return _error_outputs("Input Error", "Invalid input, no valid tokens.", "Invalid Input")

    try:
        tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL

        # The slider may deliver a float; slicing needs a non-negative int.
        top_k = max(int(top_k_value), 0)

        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True,
                        max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)

        if input_ids.shape[1] == 0:
            return _error_outputs("Input Error", "Invalid input, no valid tokens.", "Invalid Input")

        # Feed embeddings directly so gradients can be taken with respect to them.
        input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
        input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)

        outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
                        output_hidden_states=True, output_attentions=False)
        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p
        probs = torch.softmax(logit_output, dim=1)

        pred_idx = torch.argmax(probs, dim=1).item()
        pred_prob_val = probs[0, pred_idx].item()

        # autograd.grad differentiates only w.r.t. the input embeddings, so no
        # gradients accumulate on the model parameters across requests.
        (grads,) = torch.autograd.grad(logit_output[0, pred_idx], input_embeds_for_grad,
                                       allow_unused=True)
        if grads is None:
            return _error_outputs("Analysis Error", "Gradient calculation failed.",
                                  "Gradient Error")

        # Gradient-x-input, reduced to one magnitude per token.
        scores = (grads.detach() * input_embeds_detached).norm(dim=2).squeeze(0)
        scores_np = scores.cpu().numpy()
        finite_scores = scores_np[np.isfinite(scores_np)]
        if len(finite_scores) > 0 and finite_scores.max() > 0:
            scores_np = scores_np / (finite_scores.max() + 1e-9)
        else:
            scores_np = np.zeros_like(scores_np)

        # Drop padding positions by index so tokens, ids, scores and embeddings stay aligned.
        token_ids = input_ids[0].tolist()
        tokens_raw = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        kept_positions = [i for i, tid in enumerate(token_ids) if tid != tokenizer.pad_token_id]
        actual_ids = [token_ids[i] for i in kept_positions]
        actual_tokens = [tokens_raw[i] for i in kept_positions]
        actual_scores_np = scores_np[kept_positions]
        actual_input_embeds = input_embeds_detached[0].cpu().numpy()[kept_positions, :]

        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}

        highlighted_text_data = []
        for tok_str, tok_id, tok_score in zip(actual_tokens, actual_ids, actual_scores_np):
            # Strip the "##" continuation marker only when it is a prefix.
            display = (tok_str[2:] if tok_str.startswith("##") else tok_str) + " "
            if tok_id in special_ids:
                highlighted_text_data.append((display, None))
            else:
                clipped = max(0.0, min(1.0, float(tok_score)))
                highlighted_text_data.append((display, round(clipped, 3)))

        content_indices = [i for i, tid in enumerate(actual_ids) if tid not in special_ids]
        ranked_indices = sorted(content_indices, key=lambda i: -actual_scores_np[i])

        top_tokens_for_df, top_tokens_for_barplot_list = [], []
        for token_idx in ranked_indices[:top_k]:
            token_str = actual_tokens[token_idx]
            score_val = float(actual_scores_np[token_idx])
            top_tokens_for_df.append([token_str, f"{score_val:.3f}"])
            top_tokens_for_barplot_list.append({"token": token_str, "score": score_val})
        barplot_df = pd.DataFrame(top_tokens_for_barplot_list, columns=['token', 'score'])

        predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
        prediction_summary_text = (
            f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
        )
        prediction_details_for_label = {predicted_class_label_str: round(pred_prob_val, 3)}

        # A 3-component PCA needs at least three samples.
        if len(content_indices) >= 3:
            pca_tokens = [actual_tokens[i] for i in content_indices]
            pca_embeddings = actual_input_embeds[content_indices, :]
            pca_scores_for_plot = actual_scores_np[content_indices]
            pca = PCA(n_components=3, random_state=SEED)
            token_embeddings_3d = pca.fit_transform(pca_embeddings)
            pca_fig = plot_token_pca_3d_plotly(token_embeddings_3d, pca_tokens, pca_scores_for_plot)
        else:
            pca_fig = create_empty_plotly_figure(
                "PCA Plot N/A\n(Not enough non-special tokens for 3D)"
            )

        return (highlighted_text_data,
                prediction_summary_text,
                prediction_details_for_label,
                top_tokens_for_df,
                barplot_df,
                pca_fig)
    except Exception as e:
        # UI boundary: log the traceback and show a failure state rather than crash the app.
        tb_str = traceback.format_exc()
        print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
        return _error_outputs("Analysis Failed", f"Analysis failed: {str(e)}", "Analysis Error")


# --- Gradio UI Definition (HTML Highlight Tab removed) ---
theme = gr.themes.Monochrome(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    neutral_hue=gr.themes.colors.slate,
).set(
    body_background_fill="#f0f2f6",
    block_shadow="*shadow_drop_lg",
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
)

with gr.Blocks(title="AI Sentence Analyzer XAI 🚀", theme=theme,
               css=".gradio-container {max-width: 98% !important;}") as demo:
    gr.Markdown("# 🚀 AI Sentence Analyzer XAI: Exploring Model Explanations")
    gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                "Explore token importance and their distribution in the embedding space.")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=350):
            with gr.Group():
                gr.Markdown("### ✏️ Input Sentence & Settings")
                input_sentence = gr.Textbox(
                    lines=5,
                    label="English Sentence to Analyze",
                    placeholder="Enter the English sentence you want to analyze here...",
                )
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1,
                                        label="Number of Top-K Tokens")
            submit_button = gr.Button("Analyze Sentence 💫", variant="primary")

        with gr.Column(scale=2):
            with gr.Accordion("🎯 Prediction Outcome", open=True):
                output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2,
                                                       interactive=False)
                output_prediction_details = gr.Label(label="Prediction Details & Confidence")
            with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
                output_top_tokens_df = gr.DataFrame(
                    headers=["Token", "Score"],
                    label="Most Important Tokens",
                    row_count=(1, "dynamic"),
                    col_count=(2, "fixed"),
                    interactive=False,
                    wrap=True,
                )

    gr.Markdown("---")
    gr.Markdown("## 📊 Detailed Visualizations")

    with gr.Group():  # HighlightedText
        gr.Markdown("### 🖍️ Highlighted Text (Gradio)")
        output_highlighted_text = gr.HighlightedText(
            label="Token Importance (Score: 0-1)",
            show_legend=True,
            combine_adjacent=False,
        )

    with gr.Row():  # BarPlot and PCA Plot Side-by-Side
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### 📊 Top-K Bar Plot")
                output_top_tokens_barplot = gr.BarPlot(
                    label="Top-K Token Importance Scores",
                    x="token",
                    y="score",
                    tooltip=['token', 'score'],
                    min_width=300,
                )
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### 🌐 Token Embeddings 3D PCA (Interactive)")
                output_pca_plot = gr.Plot(
                    label="3D PCA of Token Embeddings (Colored by Importance Score)"
                )

    # Order must match the 6-tuple returned by analyze_sentence_for_gradio.
    analysis_outputs = [
        output_highlighted_text,
        output_prediction_summary,
        output_prediction_details,
        output_top_tokens_df,
        output_top_tokens_barplot,
        output_pca_plot,
    ]

    gr.Markdown("---")
    gr.Examples(
        examples=[
            ["This movie is an absolute masterpiece, captivating from start to finish.", 5],
            ["Despite some flaws, the film offers a compelling narrative.", 3],
            ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4],
        ],
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        fn=analyze_sentence_for_gradio,
        cache_examples=False,
    )
    gr.HTML(
        "<p style='text-align: center;'>"
        "Explainable AI Demo powered by Gradio & Hugging Face Transformers"
        "</p>"
    )

    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        api_name="explain_sentence_xai",
    )

if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        print("*" * 80)
        print(f"WARNING: Models failed to load! {MODEL_LOADING_ERROR_MESSAGE}")
        print("The Gradio UI will be displayed, but analysis will fail.")
        print("*" * 80)
    demo.launch()