File size: 3,843 Bytes
707db97
 
 
 
 
c6c2112
707db97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
c6c2112
707db97
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
964e0e1
707db97
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
import torch
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import pandas as pd
import math

# Checkpoints under comparison: display label -> Hugging Face repository id.
model_ids = {
    "ERNIE-4.5-PT": "baidu/ERNIE-4.5-0.3B-PT",
    "ERNIE-4.5-Base-PT": "baidu/ERNIE-4.5-0.3B-Base-PT",
}


def _load_for_each_checkpoint(loader):
    """Apply ``loader`` to every repository id and key the results by label."""
    return {label: loader(repo) for label, repo in model_ids.items()}


# All tokenizers are loaded first, then all models, each keyed by display label.
tokenizers = _load_for_each_checkpoint(AutoTokenizer.from_pretrained)

# Every model is switched to eval mode right after loading.
models = _load_for_each_checkpoint(
    lambda repo: AutoModelForCausalLM.from_pretrained(repo).eval()
)

# Main function: compute token-wise log probabilities and top-k predictions
@torch.no_grad()
def compare_models(text, top_k=5, max_rows=20):
    """Score ``text`` with every loaded model and tabulate the results.

    For each entry of ``model_ids`` the text is tokenized, run through the
    model once, and the log probability the model assigns to each actual
    next token is read off (the logits at position ``i`` are matched with the
    token at position ``i + 1``). The first token therefore never gets a
    score and is not listed.

    Args:
        text: Sentence to score.
        top_k: Number of highest-probability candidates to list per position.
            Coerced to ``int`` (UI sliders may deliver a float) and capped at
            the vocabulary size.
        max_rows: Maximum number of token positions shown in the table. The
            total log probability always covers the whole sequence.

    Returns:
        A ``(merged, summary)`` tuple. ``merged`` is a ``pandas.DataFrame``
        with a ``Token`` column followed by ``"<model> LogProb"`` and
        ``"<model> Top-k"`` columns for every model, in ``model_ids`` order.
        ``summary`` is a multi-line string with each model's total log
        probability.
    """
    # torch.topk requires an integer k; the same value also names the column.
    top_k = int(top_k)
    max_rows = max(0, int(max_rows))
    topk_column = f"Top-{top_k} Predictions"
    results = {}

    for model_name in model_ids:
        tokenizer = tokenizers[model_name]
        model = models[model_name]

        # Tokenize input
        inputs = tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"]

        # With fewer than two tokens there is no (context, target) pair to
        # score, so report an empty table instead of running the model.
        if input_ids.shape[-1] < 2:
            results[model_name] = {
                "df": pd.DataFrame({"Token": [], "LogProb": [], topk_column: []}),
                "total_log_prob": 0.0,
            }
            continue

        # Get model output logits
        outputs = model(**inputs)
        shift_logits = outputs.logits[:, :-1, :]  # Position i predicts token i + 1
        shift_labels = input_ids[:, 1:]           # Targets aligned with those predictions

        # Log probability assigned to each token that actually came next
        log_probs = F.log_softmax(shift_logits, dim=-1)
        token_log_probs = log_probs.gather(2, shift_labels.unsqueeze(-1)).squeeze(-1)
        total_log_prob = token_log_probs.sum().item()

        # Drop the first token: nothing precedes it, so it has no score.
        # NOTE(review): presumably this is the BOS token - confirm that the
        # tokenizer actually prepends one.
        tokens = tokenizer.convert_ids_to_tokens(input_ids[0])[1:]

        # Top-k candidates for the displayed positions, computed in one call.
        rows = min(max_rows, shift_logits.shape[1])
        k = min(top_k, log_probs.shape[-1])
        top = torch.topk(log_probs[0, :rows], k=k, dim=-1)

        topk_list = []
        for cand_ids, cand_scores in zip(top.indices.tolist(), top.values.tolist()):
            cand_tokens = tokenizer.convert_ids_to_tokens(cand_ids)
            topk_list.append(
                ", ".join(
                    f"{tok} ({round(math.exp(score), 4)})"
                    for tok, score in zip(cand_tokens, cand_scores)
                )
            )

        # Prepare dataframe for display
        df = pd.DataFrame({
            "Token": tokens[:rows],
            "LogProb": [round(lp, 4) for lp in token_log_probs[0, :rows].tolist()],
            topk_column: topk_list,
        })

        results[model_name] = {
            "df": df,
            "total_log_prob": total_log_prob,
        }

    # Merge the per-model tables side by side. The "Token" column comes from
    # the first model only.
    # NOTE(review): assumes all models tokenize the text identically; if they
    # do not, pandas aligns the columns by row index - verify.
    columns = {}
    for model_name, result in results.items():
        df = result["df"]
        columns.setdefault("Token", df["Token"])
        columns[f"{model_name} LogProb"] = df["LogProb"]
        columns[f"{model_name} Top-k"] = df[topk_column]
    merged = pd.DataFrame(columns)

    # Summarize total log probability for each model
    summary_lines = ["🧠 Total Log Prob:"]
    summary_lines.extend(
        f"- {model_name}: {result['total_log_prob']:.2f}"
        for model_name, result in results.items()
    )
    summary = "\n".join(summary_lines)

    return merged, summary

# Gradio interface: components are built first, in the order they are shown.
_sentence_input = gr.Textbox(
    lines=2,
    placeholder="Type a sentence here...",
    label="Input Sentence",
)
_top_k_input = gr.Slider(
    minimum=1,
    maximum=10,
    value=3,
    step=1,
    label="Top-k Predictions",
)
_table_output = gr.Dataframe(label="Token LogProbs and Top-k Predictions")
_total_output = gr.Textbox(label="Sentence Total Log Probability", lines=3)

# The two inputs map positionally onto compare_models(text, top_k); the two
# outputs receive its (merged table, summary text) return value.
demo = gr.Interface(
    fn=compare_models,
    inputs=[_sentence_input, _top_k_input],
    outputs=[_table_output, _total_output],
    title="🧪 ERNIE 4.5 Model Comparison with Top-k Predictions",
    description="Compare ERNIE-4.5-0.3B Instruct and Base model by computing token logprobs and Top-k predictions",
)

# Start the web UI only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()