# Hugging Face Spaces status: Sleeping
import gradio as gr | |
import os | |
import joblib | |
import torch | |
import numpy as np | |
import html  # Kept because html.escape may still be used when building highlighted_text_data (not referenced elsewhere in the visible code)
from transformers import AutoTokenizer, AutoModel, logging as hf_logging | |
import pandas as pd | |
import matplotlib | |
matplotlib.use('Agg') | |
import matplotlib.pyplot as plt | |
from sklearn.decomposition import PCA | |
import plotly.graph_objects as go | |
# --- Global Settings and Model Loading ---
# Restrict transformers' own logging to ERROR level and above.
hf_logging.set_verbosity(hf_logging.ERROR)
# Backbone model and the device everything runs on.
MODEL_NAME = "bert-base-uncased"
DEVICE = "cpu"
# Directory holding the LDA / classifier .pkl artifacts.
# NOTE(review): this literal looks like Korean text mis-decoded (UTF-8 bytes shown
# as Greek letters, possibly "저장저장1"); confirm it matches the folder on disk.
SAVE_DIR = "μ μ₯μ μ₯1"
# LAYER_ID: index into BERT hidden_states and part of the artifact file names;
# SEED: artifact seed (also the PCA random_state); CLF_NAME: classifier file prefix.
LAYER_ID, SEED, CLF_NAME = 4, 0, "linear"
# Classifier output index -> display label.
CLASS_LABEL_MAP = dict(enumerate(["World", "Sports", "Business", "Sci/Tech"]))
# Filled in by the loading block that follows; they stay None / False if loading fails.
TOKENIZER_GLOBAL = MODEL_GLOBAL = None
W_GLOBAL = MU_GLOBAL = W_P_GLOBAL = B_P_GLOBAL = None
MODELS_LOADED_SUCCESSFULLY = False
MODEL_LOADING_ERROR_MESSAGE = ""
try:
    print("Gradio App: Initializing model loading...")
    lda_file_path = os.path.join(SAVE_DIR, f"lda_layer{LAYER_ID}_seed{SEED}.pkl")
    clf_file_path = os.path.join(SAVE_DIR, f"{CLF_NAME}_layer{LAYER_ID}_projlda_seed{SEED}.pkl")
    # Fail early with a specific message for each missing artifact.
    if not os.path.isdir(SAVE_DIR):
        raise FileNotFoundError(f"Error: Model storage directory '{SAVE_DIR}' not found.")
    if not os.path.exists(lda_file_path):
        raise FileNotFoundError(f"Error: LDA model file '{lda_file_path}' not found.")
    if not os.path.exists(clf_file_path):
        raise FileNotFoundError(f"Error: Classifier model file '{clf_file_path}' not found.")
    # NOTE(review): joblib.load unpickles arbitrary objects and can execute code;
    # only point SAVE_DIR at trusted, locally produced artifacts.
    lda = joblib.load(lda_file_path)
    clf = joblib.load(clf_file_path)
    # Unwrap a meta-estimator only when the loaded object has no linear weights of
    # its own and the wrapped estimator does. Checking `coef_` avoids replacing clf
    # with a placeholder value (e.g. sklearn's "deprecated" `base_estimator`).
    if not hasattr(clf, "coef_"):
        for wrapped_attr in ("base_estimator", "estimator"):
            inner = getattr(clf, wrapped_attr, None)
            if hasattr(inner, "coef_"):
                clf = inner
                break
    W_P_GLOBAL = torch.tensor(clf.coef_, dtype=torch.float32, device=DEVICE)
    B_P_GLOBAL = torch.tensor(clf.intercept_, dtype=torch.float32, device=DEVICE)
    # LinearDiscriminantAnalysis.transform (svd solver) computes
    # (X - xbar_) @ scalings_ and then keeps only the leading components, so
    # scalings_ may be wider than the classifier's input. Slice it to the
    # classifier's feature count so inference reproduces transform().
    n_lda_components = W_P_GLOBAL.shape[1]
    W_GLOBAL = torch.tensor(lda.scalings_[:, :n_lda_components], dtype=torch.float32, device=DEVICE)
    # xbar_ is only defined for the svd solver, so this assumes solver="svd".
    MU_GLOBAL = torch.tensor(lda.xbar_, dtype=torch.float32, device=DEVICE)
    TOKENIZER_GLOBAL = AutoTokenizer.from_pretrained(MODEL_NAME, use_fast=True)
    # Parameters are frozen: the app only needs gradients w.r.t. the input
    # embeddings, so .grad should never be accumulated on the BERT weights.
    MODEL_GLOBAL = (
        AutoModel.from_pretrained(MODEL_NAME, output_hidden_states=True, output_attentions=False)
        .to(DEVICE)
        .eval()
        .requires_grad_(False)
    )
    MODELS_LOADED_SUCCESSFULLY = True
    print("Gradio App: All models and data loaded successfully!")
except Exception as e:
    # Keep the module importable so the UI can still start; the analysis
    # function checks MODELS_LOADED_SUCCESSFULLY before doing any work.
    MODELS_LOADED_SUCCESSFULLY = False
    MODEL_LOADING_ERROR_MESSAGE = f"Critical error during model loading: {str(e)}\nPlease ensure the '{SAVE_DIR}' folder and its contents are correct."
    print(MODEL_LOADING_ERROR_MESSAGE)
# Helper function: 3D PCA Visualization using Plotly
def plot_token_pca_3d_plotly(token_embeddings_3d, tokens, scores, title="Token Embeddings 3D PCA (Colored by Importance)"):
    """Render 3-dimensional token coordinates as an interactive Plotly 3D scatter.

    Args:
        token_embeddings_3d: 2-D array; columns 0, 1 and 2 become the x, y and z
            coordinates, one row per token.
        tokens: Token strings, aligned with the rows of ``token_embeddings_3d``.
        scores: Per-token importance values; flattened, then used for the marker
            colour scale and the hover text.
        title: Figure title, centred above the plot.

    Returns:
        A ``go.Figure``. At most the 20 highest-scoring tokens get a visible text
        label; every point shows its token and score on hover.
    """
    flat_scores = np.array(scores).flatten()
    n_tokens = len(tokens)

    # Label only the top-scoring tokens so the scene stays readable.
    point_labels = [''] * n_tokens
    if len(flat_scores) > 0 and n_tokens > 0:
        label_budget = min(n_tokens, 20)
        # argsort is ascending, so the tail holds the highest scores.
        for pos in np.argsort(flat_scores)[-label_budget:]:
            if pos < n_tokens:
                point_labels[pos] = tokens[pos]

    hover_labels = [f"Token: {tok}<br>Score: {val:.3f}" for tok, val in zip(tokens, flat_scores)]

    scatter = go.Scatter3d(
        x=token_embeddings_3d[:, 0],
        y=token_embeddings_3d[:, 1],
        z=token_embeddings_3d[:, 2],
        mode='markers+text',
        text=point_labels,
        textfont=dict(size=9, color='#333333'),
        textposition='top center',
        marker=dict(
            size=6,
            color=flat_scores,
            colorscale='RdBu',
            reversescale=True,
            opacity=0.8,
            colorbar=dict(title='Importance', tickfont=dict(size=9), len=0.75, yanchor='middle'),
        ),
        hoverinfo='text',
        hovertext=hover_labels,
    )

    axis_background = "rgba(230, 230, 230, 0.8)"

    def _axis(caption):
        # Shared styling for the three scene axes.
        return dict(
            title=dict(text=caption, font=dict(size=10)),
            tickfont=dict(size=9),
            backgroundcolor=axis_background,
        )

    fig = go.Figure(data=[scatter])
    fig.update_layout(
        title=dict(text=title, x=0.5, font=dict(size=16)),
        scene=dict(
            xaxis=_axis('PCA Comp 1'),
            yaxis=_axis('PCA Comp 2'),
            zaxis=_axis('PCA Comp 3'),
            bgcolor="rgba(255, 255, 255, 0.95)",
            camera_eye=dict(x=1.5, y=1.5, z=0.5),
        ),
        margin=dict(l=5, r=5, b=5, t=45),
        paper_bgcolor='rgba(0,0,0,0)',
    )
    return fig
# Helper function: Create an empty Plotly figure for placeholders
def create_empty_plotly_figure(message="N/A"):
    """Return a transparent, axis-free Plotly figure with *message* centred in it.

    Used as a placeholder wherever a real plot cannot be produced.
    """
    transparent = 'rgba(0,0,0,0)'
    placeholder = go.Figure()
    placeholder.add_annotation(
        text=message,
        xref="paper",
        yref="paper",
        x=0.5,
        y=0.5,
        showarrow=False,
        font=dict(size=12, color="grey"),
    )
    placeholder.update_layout(
        xaxis=dict(visible=False),
        yaxis=dict(visible=False),
        height=300,
        paper_bgcolor=transparent,
        plot_bgcolor=transparent,
    )
    return placeholder
# --- Core Analysis Function (returns 6 items for Gradio UI) ---
def _error_outputs(plot_message, summary_text, label_text):
    """Return the 6-item UI tuple used on every failure path (empty data plus messages)."""
    empty_df = pd.DataFrame(columns=['token', 'score'])
    empty_fig = create_empty_plotly_figure(plot_message)
    # gr.Label accepts a plain string or a {label: float confidence} mapping; a dict
    # with string values is not a valid confidence mapping, so a string is returned.
    return [], summary_text, label_text, [], empty_df, empty_fig


def _highlighted_text_entries(tokens, token_ids, scores, special_ids):
    """Convert WordPiece tokens and their scores into gr.HighlightedText (text, score) pairs."""
    entries = []
    for tok_str, tok_id, score in zip(tokens, token_ids, scores):
        is_continuation = tok_str.startswith("##")
        text = tok_str[2:] if is_continuation else tok_str
        if is_continuation and entries:
            # Attach a "##" piece to the previous piece so a word is not split by a space.
            prev_text, prev_score = entries[-1]
            entries[-1] = (prev_text.rstrip(" "), prev_score)
        if tok_id in special_ids:
            # [CLS]/[SEP] are displayed but not scored.
            entries.append((text + " ", None))
        else:
            # Clip to [0, 1]; plain Python float keeps the value JSON-serialisable.
            clipped = max(0.0, min(1.0, float(score)))
            entries.append((text + " ", round(clipped, 3)))
    return entries


def analyze_sentence_for_gradio(sentence_text, top_k_value):
    """Classify *sentence_text* and explain the prediction with gradient x input.

    Pipeline: the [CLS] vector of BERT ``hidden_states[LAYER_ID]`` is projected with
    ``(h - MU_GLOBAL) @ W_GLOBAL`` and scored by ``z @ W_P_GLOBAL.T + B_P_GLOBAL``.
    Token importance is the L2 norm (over the embedding dimension) of gradient x
    input on the word embeddings, divided by the largest finite score.

    Args:
        sentence_text: Raw text from the UI textbox.
        top_k_value: Number of top tokens to report; coerced to ``int`` in case
            the slider delivers a float.

    Returns:
        A 6-tuple matching the UI outputs: (highlighted-text pairs, summary string,
        gr.Label value {class: probability}, top-k table rows, bar-plot DataFrame,
        PCA figure). Every failure returns the same shape with empty placeholders
        and an error string as the gr.Label value.
    """
    if not MODELS_LOADED_SUCCESSFULLY:
        return _error_outputs("Model Loading Failed", "Model Loading Failed",
                              "Error: Model Loading Failed. Check logs.")
    if sentence_text is None or not str(sentence_text).strip():
        return _error_outputs("Invalid Input", "Input Error",
                              "Error: Invalid input, no valid tokens.")
    try:
        tokenizer, model = TOKENIZER_GLOBAL, MODEL_GLOBAL
        W, mu, w_p, b_p = W_GLOBAL, MU_GLOBAL, W_P_GLOBAL, B_P_GLOBAL
        top_k = max(0, int(top_k_value))

        enc = tokenizer(sentence_text, return_tensors="pt", truncation=True, max_length=510, padding=True)
        input_ids = enc["input_ids"].to(DEVICE)
        attn_mask = enc["attention_mask"].to(DEVICE)
        if input_ids.shape[1] == 0:
            return _error_outputs("Invalid Input", "Input Error",
                                  "Error: Invalid input, no valid tokens.")

        # Attribution is taken w.r.t. a detached copy of the word-embedding lookup.
        input_embeds_detached = model.embeddings.word_embeddings(input_ids).detach().clone()
        input_embeds_for_grad = input_embeds_detached.clone().requires_grad_(True)
        outputs = model(inputs_embeds=input_embeds_for_grad, attention_mask=attn_mask,
                        output_hidden_states=True, output_attentions=False)

        # [CLS] vector at LAYER_ID -> LDA projection -> linear classifier logits.
        cls_vec = outputs.hidden_states[LAYER_ID][:, 0, :]
        z_projected = (cls_vec - mu) @ W
        logit_output = z_projected @ w_p.T + b_p
        probs = torch.softmax(logit_output, dim=1)
        pred_idx = int(torch.argmax(probs, dim=1).item())
        class_probs = probs[0].detach().cpu().tolist()
        pred_prob_val = class_probs[pred_idx]

        # autograd.grad returns only the input gradient; unlike .backward() it never
        # accumulates .grad into model parameters from one request to the next.
        (grads,) = torch.autograd.grad(logit_output[0, pred_idx], input_embeds_for_grad, allow_unused=True)
        if grads is None:
            return _error_outputs("Gradient Error", "Analysis Error",
                                  "Error: Gradient calculation failed.")

        # Gradient x input, L2 norm over the embedding dimension -> one score per token.
        scores_np = (grads.detach() * input_embeds_detached).norm(dim=2).squeeze(0).cpu().numpy()
        finite_scores = scores_np[np.isfinite(scores_np)]
        if finite_scores.size > 0 and finite_scores.max() > 0:
            scores_np = scores_np / (finite_scores.max() + 1e-9)
        else:
            scores_np = np.zeros_like(scores_np)

        ids = input_ids[0].tolist()
        tokens_raw = tokenizer.convert_ids_to_tokens(ids, skip_special_tokens=False)
        actual_tokens = [tok for tok, tid in zip(tokens_raw, ids) if tid != tokenizer.pad_token_id]
        n_tokens = len(actual_tokens)  # assumes any padding is trailing
        actual_ids = ids[:n_tokens]
        actual_scores_np = scores_np[:n_tokens]
        actual_input_embeds = input_embeds_detached[0, :n_tokens, :].cpu().numpy()

        special_ids = {tokenizer.cls_token_id, tokenizer.sep_token_id}
        highlighted_text_data = _highlighted_text_entries(actual_tokens, actual_ids, actual_scores_np, special_ids)

        # Non-special token positions: ranked for the top-k views, reused for PCA.
        content_indices = [i for i, tid in enumerate(actual_ids) if tid not in special_ids]
        top_indices = sorted(content_indices, key=lambda i: -actual_scores_np[i])[:top_k]
        top_tokens_for_df = [[actual_tokens[i], f"{actual_scores_np[i]:.3f}"] for i in top_indices]
        barplot_rows = [{"token": actual_tokens[i], "score": float(actual_scores_np[i])} for i in top_indices]
        barplot_df = pd.DataFrame(barplot_rows) if barplot_rows else pd.DataFrame(columns=['token', 'score'])

        predicted_class_label_str = CLASS_LABEL_MAP.get(pred_idx, f"Unknown Index ({pred_idx})")
        prediction_summary_text = f"Predicted Class: {predicted_class_label_str}\nProbability: {pred_prob_val:.3f}"
        # Full distribution so gr.Label can display every class's confidence.
        prediction_details_for_label = {
            CLASS_LABEL_MAP.get(i, f"Unknown Index ({i})"): round(p, 3)
            for i, p in enumerate(class_probs)
        }

        # Plotly annotations break lines on <br>, not on "\n".
        pca_fig = create_empty_plotly_figure("PCA Plot N/A<br>(Not enough non-special tokens for 3D)")
        # A 3-component PCA needs at least 3 samples.
        if len(content_indices) >= 3:
            pca = PCA(n_components=3, random_state=SEED)
            token_embeddings_3d = pca.fit_transform(actual_input_embeds[content_indices, :])
            pca_fig = plot_token_pca_3d_plotly(
                token_embeddings_3d,
                [actual_tokens[i] for i in content_indices],
                actual_scores_np[content_indices],
            )

        return (highlighted_text_data,
                prediction_summary_text, prediction_details_for_label,
                top_tokens_for_df, barplot_df,
                pca_fig)
    except Exception as e:
        # Top-level UI boundary: log the traceback and return a displayable error.
        import traceback
        print(f"analyze_sentence_for_gradio error: {e}\n{traceback.format_exc()}")
        return _error_outputs("Analysis Error", "Analysis Failed", f"Error: Analysis failed: {str(e)}")
# --- Gradio UI Definition ---
# Monochrome base theme with blue / sky / slate hues plus a few colour overrides.
theme = gr.themes.Monochrome(
    primary_hue=gr.themes.colors.blue,
    secondary_hue=gr.themes.colors.sky,
    neutral_hue=gr.themes.colors.slate,
)
theme = theme.set(
    body_background_fill="#f0f2f6",
    block_shadow="*shadow_drop_lg",
    button_primary_background_fill="*primary_500",
    button_primary_text_color="white",
)
with gr.Blocks(title="AI Sentence Analyzer XAI π", theme=theme, css=".gradio-container {max-width: 98% !important;}") as demo:
    # Layout: input + prediction row, detailed visualizations, examples, footer.
    # NOTE(review): several emoji in the labels below appear mis-encoded in this copy
    # (UTF-8 bytes rendered as Greek letters); verify against the deployed file.
    gr.Markdown("# π AI Sentence Analyzer XAI: Exploring Model Explanations")
    gr.Markdown("Analyze English sentences to understand BERT model predictions through various XAI visualization techniques. "
                "Explore token importance and their distribution in the embedding space.")
    with gr.Row(equal_height=False):
        with gr.Column(scale=1, min_width=350):
            with gr.Group():
                gr.Markdown("### βοΈ Input Sentence & Settings")
                input_sentence = gr.Textbox(lines=5, label="English Sentence to Analyze", placeholder="Enter the English sentence you want to analyze here...")
                input_top_k = gr.Slider(minimum=1, maximum=10, value=5, step=1, label="Number of Top-K Tokens")
                submit_button = gr.Button("Analyze Sentence π«", variant="primary")
        with gr.Column(scale=2):
            with gr.Accordion("π― Prediction Outcome", open=True):
                output_prediction_summary = gr.Textbox(label="Prediction Summary", lines=2, interactive=False)
                output_prediction_details = gr.Label(label="Prediction Details & Confidence")
            with gr.Accordion("β Top-K Important Tokens (Table)", open=True):
                output_top_tokens_df = gr.DataFrame(headers=["Token", "Score"], label="Most Important Tokens",
                                                    row_count=(1, "dynamic"), col_count=(2, "fixed"), interactive=False, wrap=True)
    gr.Markdown("---")
    gr.Markdown("## π Detailed Visualizations")
    with gr.Group():  # HighlightedText
        gr.Markdown("### ποΈ Highlighted Text (Gradio)")
        output_highlighted_text = gr.HighlightedText(
            label="Token Importance (Score: 0-1)",
            show_legend=True,
            combine_adjacent=False
        )
    with gr.Row():  # Bar plot and PCA plot side by side
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### π Top-K Bar Plot")
                output_top_tokens_barplot = gr.BarPlot(
                    label="Top-K Token Importance Scores",
                    x="token",
                    y="score",
                    tooltip=['token', 'score'],
                    min_width=300
                )
        with gr.Column(scale=1, min_width=400):
            with gr.Group():
                gr.Markdown("### π Token Embeddings 3D PCA (Interactive)")
                output_pca_plot = gr.Plot(label="3D PCA of Token Embeddings (Colored by Importance Score)")
    gr.Markdown("---")

    # Single source of truth for the output wiring; order must match the
    # 6-tuple returned by analyze_sentence_for_gradio.
    analysis_outputs = [
        output_highlighted_text,
        output_prediction_summary, output_prediction_details,
        output_top_tokens_df, output_top_tokens_barplot,
        output_pca_plot,
    ]
    gr.Examples(
        examples=[
            # News-topic sentences aimed at the classes in CLASS_LABEL_MAP
            # (World, Sports, Business, Sci/Tech).
            ["World leaders met in Geneva to negotiate a ceasefire between the two neighboring countries.", 5],
            ["The striker scored twice as the home team clinched the league championship on Sunday.", 5],
            ["Shares of the retailer fell sharply after it reported lower quarterly profits.", 5],
            ["Researchers unveiled a new processor chip that doubles battery life in smartphones.", 5],
            # Original review-style examples.
            ["This movie is an absolute masterpiece, captivating from start to finish.", 5],
            ["Despite some flaws, the film offers a compelling narrative.", 3],
            ["I was thoroughly disappointed with the lackluster performance and predictable plot.", 4]
        ],
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        fn=analyze_sentence_for_gradio,
        cache_examples=False
    )
    gr.HTML("<p style='text-align: center; color: #4a5568;'>Explainable AI Demo powered by Gradio & Hugging Face Transformers</p>")
    submit_button.click(
        fn=analyze_sentence_for_gradio,
        inputs=[input_sentence, input_top_k],
        outputs=analysis_outputs,
        api_name="explain_sentence_xai"
    )
if __name__ == "__main__":
    # The UI launches even when loading failed; the analysis function then
    # short-circuits with an error result, so warn loudly on the console first.
    if not MODELS_LOADED_SUCCESSFULLY:
        rule = "*" * 80
        print("\n".join((
            rule,
            f"WARNING: Models failed to load! {MODEL_LOADING_ERROR_MESSAGE}",
            "The Gradio UI will be displayed, but analysis will fail.",
            rule,
        )))
    demo.launch()