|
import torch |
|
import torch.nn.functional as F |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import gradio as gr |
|
import pandas as pd |
|
import math |
|
|
|
|
|
# Hugging Face Hub ids for the two ERNIE-4.5 0.3B checkpoints being compared:
# the instruction-tuned ("PT") and the base pre-trained variant.
model_ids = {
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
}

# One tokenizer and one eval-mode model per checkpoint, keyed by the display
# names above. Weights are downloaded on first run.
tokenizers = {}
models = {}
for display_name, repo_id in model_ids.items():
    tokenizers[display_name] = AutoTokenizer.from_pretrained(repo_id)
    models[display_name] = AutoModelForCausalLM.from_pretrained(repo_id).eval()
|
|
|
|
|
@torch.no_grad()
def compare_models(text, top_k=5):
    """Score *text* under both ERNIE checkpoints and compare them token by token.

    For each model, computes the log-probability the model assigns to every
    token of *text* given its prefix, plus the top-k next-token predictions at
    each position, then merges both models' results into one table.

    Args:
        text: Input sentence to score.
        top_k: Number of top predictions to show per position. Coerced to
            ``int`` because Gradio sliders may deliver a float.

    Returns:
        Tuple ``(merged, summary)``: a pandas DataFrame with per-token
        log-probs and top-k predictions for both models (first ``max_rows``
        scored tokens only), and a text summary of each model's total
        log-probability.
    """
    top_k = int(top_k)  # torch.topk requires an integer k; Gradio may pass a float
    max_rows = 20  # cap the displayed table at the first 20 scored tokens
    results = {}

    for model_name in model_ids:
        tokenizer = tokenizers[model_name]
        model = models[model_name]

        inputs = tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"]

        outputs = model(**inputs)
        # Standard causal-LM shift: logits at position i predict token i+1.
        shift_logits = outputs.logits[:, :-1, :]
        shift_labels = input_ids[:, 1:]

        log_probs = F.log_softmax(shift_logits, dim=-1)
        # Log-prob the model assigned to each *actual* next token.
        token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)

        total_log_prob = token_log_probs.sum().item()
        # Drop the first token: it has no preceding context to be scored on.
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[1:]

        topk_list = []
        for i in range(min(max_rows, shift_logits.shape[1])):
            topk = torch.topk(log_probs[0, i], k=top_k)
            topk_tokens = tokenizer.convert_ids_to_tokens(topk.indices.tolist())
            topk_probs = [round(math.exp(s), 4) for s in topk.values.tolist()]
            topk_list.append(
                ", ".join(f"{tok} ({prob})" for tok, prob in zip(topk_tokens, topk_probs))
            )

        df = pd.DataFrame({
            "Token": tokens[:max_rows],
            "LogProb": [round(float(x), 4) for x in token_log_probs[0][:max_rows]],
            f"Top-{top_k} Predictions": topk_list,
        })

        results[model_name] = {
            "df": df,
            "total_log_prob": total_log_prob,
        }

    # Align the two per-model tables row-by-row. NOTE(review): this assumes
    # both checkpoints tokenize *text* identically (the Token column is taken
    # from the first model only) — confirm both repos ship the same tokenizer.
    merged = pd.DataFrame({
        "Token": results["ERNIE-4.5-PT"]["df"]["Token"],
        "ERNIE-4.5-PT LogProb": results["ERNIE-4.5-PT"]["df"]["LogProb"],
        "ERNIE-4.5-PT Top-k": results["ERNIE-4.5-PT"]["df"][f"Top-{top_k} Predictions"],
        "ERNIE-4.5-Base-PT LogProb": results["ERNIE-4.5-Base-PT"]["df"]["LogProb"],
        "ERNIE-4.5-Base-PT Top-k": results["ERNIE-4.5-Base-PT"]["df"][f"Top-{top_k} Predictions"],
    })

    summary = (
        f"🧠 Total Log Prob:\n"
        f"- ERNIE-4.5-PT: {results['ERNIE-4.5-PT']['total_log_prob']:.2f}\n"
        f"- ERNIE-4.5-Base-PT: {results['ERNIE-4.5-Base-PT']['total_log_prob']:.2f}"
    )

    return merged, summary
|
|
|
|
|
# Gradio UI: one text box plus a top-k slider, mapped directly onto
# compare_models' (text, top_k) signature; the two outputs mirror its
# (DataFrame, str) return value.
demo = gr.Interface(
    fn=compare_models,
    inputs=[
        gr.Textbox(lines=2, placeholder="Type a sentence here...", label="Input Sentence"),
        gr.Slider(minimum=1, maximum=10, value=3, step=1, label="Top-k Predictions")
    ],
    outputs=[
        gr.Dataframe(label="Token LogProbs and Top-k Predictions"),
        gr.Textbox(label="Sentence Total Log Probability", lines=3)
    ],
    title="🧪 ERNIE 4.5 Model Comparison with Top-k Predictions",
    description="Compare ERNIE-4.5-0.3B Instruct and Base model by computing token logprobs and Top-k predictions"
)


# Launch the web server only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()
|
|