kikikara's picture
Update app.py
234dafc verified
import gradio as gr
import os
import joblib
import torch
import numpy as np
import html  # Kept because html.escape may still be used when building highlighted_text_data
from transformers import AutoTokenizer, AutoModel, logging as hf_logging
import pandas as pd
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from sklearn.decomposition import PCA
import plotly.graph_objects as go
# --- Global Settings and Model Loading ---
# Only let error-level messages through from the transformers logger.
hf_logging.set_verbosity_error()

# Backbone checkpoint and the on-disk artifacts that belong to it.
MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"
SAVE_DIR = "μ €μž₯μ €μž₯1"
LAYER_ID = 4
SEED = 0
CLF_NAME = "linear"

# Human-readable name for each classifier output index.
CLASS_LABEL_MAP = dict(enumerate(("World", "Sports", "Business", "Sci/Tech")))

# Placeholders that the loading step below fills in on success.
TOKENIZER_GLOBAL = None
MODEL_GLOBAL = None
W_GLOBAL = None
MU_GLOBAL = None
W_P_GLOBAL = None
B_P_GLOBAL = None

# Loading status, consulted by the analysis function and the entry point.
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""
# Load the LDA projection, the classifier and the BERT backbone at import time.
# Any failure is recorded in MODELS_LOADED_SUCCESSFULLY and
# MODEL_LOADING_ERROR_MESSAGE instead of being raised, so the module (and the
# UI defined further down) can still be imported.
try:
    print("Gradio App: Initializing model loading...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(
        SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl"
    )

    # Check each artifact separately so the message names the missing piece.
    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(f"Error: Model storage directory '{SAVE_DIR}' not found.")
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"Error: LDA model file '{lda_file_path}' not found.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"Error: Classifier model file '{clf_file_path}' not found.")

    # NOTE: joblib.load unpickles its input; SAVE_DIR must only hold trusted files.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # If the loaded object exposes `base_estimator`, use that inner estimator.
    clf = getattr(clf, "base_estimator", clf)

    # Copy the fitted arrays into float32 tensors on the target device.
    W_GLOBAL = torch.tensor(lda.scalings_, dtype=torch.float32, device=DEVICE)
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    W_P_GLOBAL = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)

    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    MODEL_GLOBAL = (
        AutoModel.from_pretrained(
            MODEL_NAME,
            output_hidden_states=True,
            output_attentions=False,
        )
        .to(DEVICE)
        .eval()
    )

    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: All models and data loaded successfully!")
except Exception as exc:
    # Top-level boundary: remember the reason and keep the process alive.
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = (
        f"Critical error during model loading: {exc!s}\n"
        f"Please ensure the '{SAVE_DIR}' folder and its contents are correct."
    )
    print(MODEL_LOADING_ERROR_MESSAGE)
# Helper function: 3D PCA Visualization using Plotly
def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Build a 3D scatter figure of tokens coloured by their scores.

    Args:
        token_embeddings_3d: 2-D array indexed as ``[:, 0]``, ``[:, 1]`` and
            ``[:, 2]`` for the x, y and z coordinates.
        tokens: Sequence of token strings, one per point.
        scores: Per-token scores; flattened to 1-D and used as marker colour.
        title: Figure title.

    Returns:
        A ``plotly.graph_objects.Figure``. Only the (up to) 20 highest-scoring
        tokens get a visible text label; every point carries hover text.
    """
    scores_array = np.array(scores).flatten()
    token_count = len(tokens)

    # Label only the top-scoring points; all other labels stay empty strings.
    labels = [''] * token_count
    if len(scores_array) > 0 and token_count > 0:
        label_budget = min(token_count, 20)
        for position in np.argsort(scores_array)[-label_budget:]:
            if position < token_count:
                labels[position] = tokens[position]

    hover_lines = [
        f"Token: {t}<br>Score: {s:.3f}" for t, s in zip(tokens, scores_array)
    ]
    marker_style = dict(
        size=6,
        color=scores_array,
        colorscale='RdBu',
        reversescale=True,
        opacity=0.8,
        colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle'),
    )
    scatter = go.Scatter3d(
        x=token_embeddings_3d[:, 0],
        y=token_embeddings_3d[:, 1],
        z=token_embeddings_3d[:, 2],
        mode='markers+text',
        text=labels,
        textfont=dict(size=9, color='#333333'),
        textposition='top center',
        marker=marker_style,
        hoverinfo='text',
        hovertext=hover_lines,
    )
    fig = go.Figure(data=[scatter])

    def axis_style(axis_label):
        # The three axes share one look and differ only in their title text.
        return dict(
            title=dict(text=axis_label, font=dict(size=10)),
            tickfont=dict(size=9),
            backgroundcolor="rgba(230, 230, 230, 0.8)",
        )

    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        scene=dict(
            xaxis=axis_style('PCA Comp 1'),
            yaxis=axis_style('PCA Comp 2'),
            zaxis=axis_style('PCA Comp 3'),
            bgcolor="rgba(255, 255, 255, 0.95)",
            camera_eye=dict(x=1.5, y=1.5, z=0.5),
        ),
        margin=dict(l=5, r=5, b=5, t=45),
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
# Helper function: Create an empty Plotly figure for placeholders
def create_empty_plotly_figure(message="N/A"):
    """Return an axis-less, transparent figure showing ``message`` centred.

    Args:
        message: Text placed at the centre of the figure in paper coordinates.

    Returns:
        A ``plotly.graph_objects.Figure`` with height 300 and hidden axes.
    """
    transparent = 'rgba(0,0,0,0)'
    placeholder = go.Figure()
    placeholder.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=12, color="grey"),
    )
    placeholder.update_layout(
        xaxis={'visible': False},
        yaxis={'visible': False},
        height=300,
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
    )
    return placeholder
# --- Core Analysis Function (returns 6 items for Gradio UI) ---
def _analysis_error_outputs(summary_text, label_text, figure_message):
    """Build the 6-item output tuple shared by every failure path.

    The label slot is a plain string: ``gr.Label`` accepts a bare class name,
    whereas a dict whose values are strings is not a valid
    label -> confidence mapping.
    """
    empty_df = pd.DataFrame(columns=['token', 'score'])
    empty_fig = create_empty_plotly_figure(figure_message)
    return [], summary_text, label_text, [], empty_df, empty_fig  # 6 items


def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify a sentence and attribute the prediction to its tokens.

    Token importance is the L2 norm of (gradient * input embedding), where the
    gradient is that of the predicted-class logit with respect to the word
    embeddings. Scores are divided by the largest finite score.

    Args:
        sentence_text: Sentence to analyse.
        top_k_value: Number of highest-scoring non-special tokens to report.
            Converted with ``int()`` because it is used as a slice bound.

    Returns:
        A 6-tuple, in the order the UI outputs are wired:
            1. list of ``(token + " ", score or None)`` pairs (None for the
               CLS/SEP tokens),
            2. prediction summary string,
            3. ``{class_label: probability}`` on success, or an error string,
            4. list of ``[token, "score"]`` rows for the top-k table,
            5. DataFrame with ``token`` / ``score`` columns for the bar plot,
            6. Plotly figure (3D PCA of token embeddings, or a placeholder).

        Failures never raise; they return placeholder values in the same
        6-slot layout.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        return _analysis_error_outputs(
            "Model Loading Failed",
            "Error: Model Loading Failed. Check logs.",
            "Model Loading Failed",
        )
    try:
        # The value may arrive as a float (e.g. 5.0); slicing needs an int,
        # and a negative bound would silently drop items from the end.
        top_k = max(0, int(top_k_value))

        tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL

        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        if input_ids.shape[1] == 0:
            return _analysis_error_outputs(
                "Input Error",
                "Error: Invalid input, no valid tokens.",
                "Invalid Input",
            )

        # Feed embeddings (not ids) so gradients can be taken w.r.t. them.
        input_embeds_detached = model.embeddings.word_embeddings(input_ids).clone().detach()
        input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
        outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
                        output_hidden_states=True, output_attentions=False)

        # Position-0 hidden state of layer LAYER_ID -> LDA projection -> linear head.
        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p
        probs = torch.softmax(logit_output, dim=1)
        pred_idx = torch.argmax(probs, dim=1).item()
        pred_prob_val = probs[0, pred_idx].item()

        # Differentiate only w.r.t. the embeddings. Calling .backward() here
        # would also accumulate gradients into the shared model's parameters
        # on every request.
        (grads,) = torch.autograd.grad(
            logit_output[0, pred_idx], input_embeds_for_grad, allow_unused=True
        )
        if grads is None:
            return _analysis_error_outputs(
                "Analysis Error",
                "Error: Gradient calculation failed.",
                "Gradient Error",
            )
        grads = grads.detach()

        scores = (grads * input_embeds_detached).norm(dim=2).squeeze(0)
        scores_np = scores.cpu().numpy()
        finite_scores = scores_np[np.isfinite(scores_np)]
        if finite_scores.size > 0 and finite_scores.max() > 0:
            scores_np = scores_np / (finite_scores.max() + 1e-9)
        else:
            scores_np = np.zeros_like(scores_np)

        token_ids = input_ids[0].tolist()
        tokens_raw = tokenizer.convert_ids_to_tokens(token_ids, skip_special_tokens=False)
        pad_token_id = tokenizer.pad_token_id
        actual_tokens = [tok for tok, tok_id in zip(tokens_raw, token_ids) if tok_id != pad_token_id]
        # Assumes any padding sits at the end of the sequence, so a prefix
        # slice lines scores and embeddings up with the kept tokens.
        n_tokens = len(actual_tokens)
        actual_token_ids = token_ids[:n_tokens]
        actual_scores_np = scores_np[:n_tokens]
        actual_input_embeds = input_embeds_detached[0, :n_tokens, :].cpu().numpy()

        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}

        # HTML generation logic removed
        highlighted_text_data = []
        for tok_str, tok_id, raw_score in zip(actual_tokens, actual_token_ids, actual_scores_np):
            # Strip the WordPiece continuation marker only when it is a prefix.
            clean_tok_str = tok_str[2:] if tok_str.startswith("##") else tok_str
            if tok_id in special_ids:
                highlighted_text_data.append((clean_tok_str + " ", None))
            else:
                # Plain Python float so the value serialises without NumPy types.
                clipped = float(max(0.0, min(1.0, raw_score)))
                highlighted_text_data.append((clean_tok_str + " ", round(clipped, 3)))

        # Indices of ordinary tokens; shared by the top-k ranking and the PCA plot.
        content_indices = [idx for idx, tok_id in enumerate(actual_token_ids)
                           if tok_id not in special_ids]
        ranked_indices = sorted(content_indices, key=lambda idx: -actual_scores_np[idx])

        top_tokens_for_df, barplot_rows = [], []
        for token_idx in ranked_indices[:top_k]:
            token_str = actual_tokens[token_idx]
            score_val = float(actual_scores_np[token_idx])
            top_tokens_for_df.append([token_str, f"{score_val:.3f}"])
            barplot_rows.append({"token": token_str, "score": score_val})
        if barplot_rows:
            barplot_df = pd.DataFrame(barplot_rows)
        else:
            barplot_df = pd.DataFrame(columns=['token', 'score'])

        predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
        prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
        prediction_details_for_label = {predicted_class_label_str: float(f"{pred_prob_val:.3f}")}

        # A 3-component PCA needs at least three samples.
        if len(content_indices) >= 3:
            pca_tokens = [actual_tokens[i] for i in content_indices]
            pca_embeddings = actual_input_embeds[content_indices, :]
            pca_scores_for_plot = actual_scores_np[content_indices]
            pca = PCA(n_components=3, random_state=SEED)
            token_embeddings_3d = pca.fit_transform(pca_embeddings)
            pca_fig = plot_token_pca_3d_plotly(token_embeddings_3d, pca_tokens, pca_scores_for_plot)
        else:
            # Plotly text breaks lines on <br>; "\n" is not rendered as one.
            pca_fig = create_empty_plotly_figure("PCA Plot N/A<br>(Not enough non-special tokens for 3D)")

        return (highlighted_text_data,  # HTML output removed
                prediction_summary_text, prediction_details_for_label,
                top_tokens_for_df, barplot_df,
                pca_fig)  # 6 items
    except Exception as e:
        # UI boundary: log the traceback and show a readable failure
        # instead of letting the request crash.
        import traceback
        tb_str = traceback.format_exc()
        # HTML output removed
        print(f"analyze_sentence_for_gradio error: {e}\n{tb_str}")
        return _analysis_error_outputs(
            "Analysis Failed",
            f"Error: Analysis failed: {str(e)}",
            "Analysis Error",
        )
# --- Gradio UI Definition (HTML Highlight Tab removed) ---
# Start from the Monochrome theme with custom hues, then override a few
# individual style variables.
theme = gr.themes.Monochrome(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    neutral_hue=gr.themes.colors.slate,
)
theme = theme.set(
    body_background_fill="#f0f2f6",
    block_shadow="*shadow_drop_lg",
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
)
# Page layout and event wiring. Components are created in display order, so
# statement order here is the visual order on the page.
# NOTE(review): the nesting below was reconstructed from the `with` statements;
# confirm container membership (e.g. whether the button sits inside the input
# Group) against the deployed layout.
with gr.Blocks(title="AI Sentence Analyzer XAI πŸš€", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
    gr.Markdown("# πŸš€ AI Sentence Analyzer XAI: Exploring Model Explanations")
    gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                "Explore token importance and their distribution in the embedding space.")
    # Top row: inputs in the narrower column, prediction outputs in the wider one.
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=350):
            with gr.Group():
                gr.Markdown("### ✏️ Input Sentence & Settings")
                input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze", placeholder="Enter the English sentence you want to analyze here...")
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Top-K Tokens")
            submit_button = gr.Button("Analyze Sentence πŸ’«", variant="primary")
        with gr.Column(scale=2):
            with gr.Accordion("🎯 Prediction Outcome", open=True):
                output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
                output_prediction_details = gr.Label(label="Prediction Details & Confidence")
            with gr.Accordion("⭐ Top-K Important Tokens (Table)", open=True):
                output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
                                                    row_count=(1,"dynamic"), col_count=(2,"fixed"), interactive=False, wrap=True)
    gr.Markdown("---")
    gr.Markdown("## πŸ“Š Detailed Visualizations")
    # HTML Highlight (Custom) section removed
    with gr.Group(): # HighlightedText
        gr.Markdown("### πŸ–οΈ Highlighted Text (Gradio)")
        output_highlighted_text = gr.HighlightedText(
            label="Token Importance (Score: 0-1)",
            show_legend=True,
            combine_adjacent=False
        )
    with gr.Row(): # BarPlot and PCA Plot Side-by-Side
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### πŸ“Š Top-K Bar Plot")
                # x/y name the DataFrame columns produced by the analysis function.
                output_top_tokens_barplot = gr.BarPlot(
                    label="Top-K Token Importance Scores",
                    x="token",
                    y="score",
                    tooltip=['token', 'score'],
                    min_width=300
                )
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### 🌐 Token Embeddings 3D PCA (Interactive)")
                output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
    gr.Markdown("---")
    # Each example row is [sentence, top-k], matching `inputs` below.
    # NOTE(review): the example sentences are movie-review style while
    # CLASS_LABEL_MAP lists news topics; confirm these examples are intended.
    gr.Examples(
        examples=[
            ["This movie is an absolute masterpiece, captivating from start to finish.", 5],
            ["Despite some flaws, the film offers a compelling narrative.", 3],
            ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4]
        ],
        inputs=[input_sentence, input_top_k],
        outputs=[ # output_html_visualization removed
            output_highlighted_text,
            output_prediction_summary, output_prediction_details,
            output_top_tokens_df, output_top_tokens_barplot,
            output_pca_plot
        ],
        fn=analyze_sentence_for_gradio,
        cache_examples=False
    )
    gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")
    # The outputs list must stay in the same order as the 6 values returned by
    # analyze_sentence_for_gradio.
    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=[ # output_html_visualization removed
            output_highlighted_text,
            output_prediction_summary, output_prediction_details,
            output_top_tokens_df, output_top_tokens_barplot,
            output_pca_plot
        ],
        api_name="explain_sentence_xai"
    )
# Script entry point: warn on the console when loading failed, then serve the
# UI regardless so the failure is also visible in the browser.
if __name__ == "__main__":
    if not MODELS_LOADED_SUCCESSFULLY:
        banner = "*" * 80
        print(banner)
        print(f"WARNING: Models failed to load! {MODEL_LOADING_ERROR_MESSAGE}")
        print("The Gradio UI will be displayed, but analysis will fail.")
        print(banner)
    demo.launch()