PhilippSpohn committed on
Commit
b9b96cc
·
0 Parent(s):

Initial commit: Token Probability Analyzer web application

Browse files
Files changed (7) hide show
  1. .gitignore +38 -0
  2. README.md +35 -0
  3. app.py +98 -0
  4. requirements.txt +4 -0
  5. static/script.js +145 -0
  6. static/style.css +195 -0
  7. templates/index.html +60 -0
.gitignore ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Python
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+ *.so
6
+ .Python
7
+ build/
8
+ develop-eggs/
9
+ dist/
10
+ downloads/
11
+ eggs/
12
+ .eggs/
13
+ lib/
14
+ lib64/
15
+ parts/
16
+ sdist/
17
+ var/
18
+ wheels/
19
+ *.egg-info/
20
+ .installed.cfg
21
+ *.egg
22
+
23
+ # Virtual Environment
24
+ venv/
25
+ env/
26
+ ENV/
27
+
28
+ # IDE
29
+ .idea/
30
+ .vscode/
31
+ *.swp
32
+ *.swo
33
+
34
+ # Misc
35
+ .DS_Store
36
+ .env
37
+ .env.local
38
+ .env.*.local
README.md ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Token Probability Analyzer
2
+
3
+ A web application that analyzes token probabilities using various language models. This tool helps visualize and understand how language models predict tokens in a given text sequence.
4
+
5
+ ## Features
6
+
7
+ - Support for multiple language models (GPT-2, TinyLlama, etc.)
8
+ - Token-by-token probability analysis
9
+ - Percentile scoring for token probabilities
10
+ - Top-k predictions for each position
11
+ - Joint and average log likelihood calculations
12
+
13
+ ## Setup
14
+
15
+ 1. Install the required dependencies:
16
+ ```bash
17
+ pip install -r requirements.txt
18
+ ```
19
+
20
+ 2. Run the application:
21
+ ```bash
22
+ python app.py
23
+ ```
24
+
25
+ 3. Open your browser and navigate to `http://localhost:5000`
26
+
27
+ ## Usage
28
+
29
+ 1. Select a language model from the dropdown menu
30
+ 2. Enter your text in the input field
31
+ 3. Click "Analyze" to see the token probabilities and predictions
32
+
33
+ ## Technical Details
34
+
35
+ The application uses Flask for the backend and provides a simple web interface. It leverages the Hugging Face Transformers library to load and run various language models for token probability analysis.
app.py ADDED
@@ -0,0 +1,98 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request, jsonify
2
+ from transformers import AutoModelForCausalLM, AutoTokenizer
3
+ import torch
4
+ import torch.nn.functional as F
5
+ from scipy.stats import percentileofscore
6
+
7
app = Flask(__name__)

# Model listed first in the UI's model dropdown.
DEFAULT_MODEL = "gpt2"

# Process-wide caches keyed by Hugging Face model name, so each model and
# tokenizer is loaded at most once per server process.
model_cache = {}
tokenizer_cache = {}


def get_model_and_tokenizer(model_name):
    """Return the ``(model, tokenizer)`` pair for *model_name*, loading on first use.

    Args:
        model_name: Hugging Face model identifier passed to ``from_pretrained``.

    Returns:
        A ``(model, tokenizer)`` tuple; repeated calls with the same name
        return the same cached objects.

    Raises:
        Whatever ``from_pretrained`` raises when the model or tokenizer
        cannot be loaded. In that case neither cache is modified, so a later
        call retries the load instead of failing on a half-filled cache.
    """
    if model_name not in model_cache or model_name not in tokenizer_cache:
        # Only this one model is allowed to execute code shipped in its repo.
        trust_code = model_name == "microsoft/phi-1_5"
        model = AutoModelForCausalLM.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        tokenizer = AutoTokenizer.from_pretrained(
            model_name, trust_remote_code=trust_code
        )
        # Publish both entries only after both loads succeeded; the tokenizer
        # goes first so a model entry always implies a tokenizer entry.
        tokenizer_cache[model_name] = tokenizer
        model_cache[model_name] = model
    return model_cache[model_name], tokenizer_cache[model_name]
25
+
26
+
27
# Models the UI offers. /analyze refuses anything else, so a request cannot
# make the server download or load an arbitrary checkpoint.
AVAILABLE_MODELS = (
    DEFAULT_MODEL,
    # Larger alternatives, currently disabled:
    # "gpt2-medium",
    # "gpt2-large",
    # "gpt2-xl",
    # "EleutherAI/pythia-1.4b",
    # "facebook/opt-1.3b",
    # "bigscience/bloom-1b7",
    # "microsoft/phi-1_5",
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
)

# Number of highest-scoring candidate tokens reported per position.
TOP_K = 5


def _error(message, status=400):
    """Build a JSON error response of the form ``{"error": message}``."""
    return jsonify({"error": message}), status


def _percentile_ranks(population, scores):
    """Return the percentile rank of each entry of *scores* within *population*.

    Uses the same definition as ``scipy.stats.percentileofscore`` with
    ``kind="rank"`` (ties receive the average rank), but sorts the population
    once instead of rescanning it for every score.

    Args:
        population: 1-D tensor of reference values.
        scores: 1-D tensor of values to rank.

    Returns:
        A list of floats in ``[0, 100]``, one per score; empty when either
        input is empty.
    """
    total = population.numel()
    if total == 0 or scores.numel() == 0:
        return []
    ordered = torch.sort(population).values
    scores = scores.contiguous()
    below = torch.searchsorted(ordered, scores, right=False)
    at_or_below = torch.searchsorted(ordered, scores, right=True)
    # (count(<) + count(<=) + 1 if the score occurs) * 50 / n
    ranks = below + at_or_below + (at_or_below > below)
    return (ranks.to(torch.float64) * (50.0 / total)).tolist()


@app.route("/")
def index():
    """Render the single-page UI with the list of selectable models."""
    return render_template("index.html", models=list(AVAILABLE_MODELS))


@app.route("/analyze", methods=["POST"])
def analyze():
    """Score every token of the submitted text under the selected model.

    Expects a JSON body ``{"text": str, "model": str}`` where ``model`` is one
    of ``AVAILABLE_MODELS``.

    Returns:
        JSON with ``tokens`` (all input tokens) and, for every token after the
        first, ``log_probs``, ``percentiles`` and ``top_k_predictions``, plus
        the ``joint_log_likelihood`` and ``average_log_likelihood`` of the
        sequence. Malformed requests get HTTP 400 with ``{"error": ...}``.
    """
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return _error("Request body must be a JSON object.")

    text = data.get("text")
    model_name = data.get("model")
    if not isinstance(text, str) or not isinstance(model_name, str):
        return _error("'text' and 'model' must be strings.")
    if model_name not in AVAILABLE_MODELS:
        return _error(f"Unknown model: {model_name}")
    if not text:
        return _error("Text must not be empty.")

    model, tokenizer = get_model_and_tokenizer(model_name)
    model.eval()

    with torch.no_grad():
        inputs = tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"][0]
        if input_ids.numel() == 0:
            return _error("Text produced no tokens.")
        logits = model(**inputs).logits

        # Position i predicts token i + 1, so the last position has no target.
        position_log_probs = F.log_softmax(logits[0, :-1, :], dim=-1)
        targets = input_ids[1:]
        target_log_probs = position_log_probs.gather(
            1, targets.unsqueeze(1)
        ).squeeze(1)

        k = min(TOP_K, position_log_probs.shape[-1])
        top_values, top_indices = torch.topk(position_log_probs, k, dim=-1)

        # Each token is ranked against every log-probability at every position.
        percentiles = _percentile_ranks(
            position_log_probs.flatten(), target_log_probs
        )

    tokens = tokenizer.convert_ids_to_tokens(input_ids.tolist())
    log_probs = target_log_probs.tolist()
    top_k_predictions = [
        [
            {"token": token, "log_prob": value}
            for token, value in zip(
                tokenizer.convert_ids_to_tokens(candidate_ids), candidate_values
            )
        ]
        for candidate_ids, candidate_values in zip(
            top_indices.tolist(), top_values.tolist()
        )
    ]

    joint_log_likelihood = sum(log_probs)
    average_log_likelihood = (
        joint_log_likelihood / len(log_probs) if log_probs else 0
    )

    return jsonify({
        "tokens": tokens,
        "percentiles": percentiles,
        "log_probs": log_probs,
        "top_k_predictions": top_k_predictions,
        "joint_log_likelihood": joint_log_likelihood,
        "average_log_likelihood": average_log_likelihood,
    })


if __name__ == "__main__":
    # NOTE: debug=True turns on Flask's reloader and interactive debugger;
    # keep it for local development only.
    app.run(debug=True)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ flask
2
+ transformers
3
+ torch
4
+ scipy
static/script.js ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
// Options shared by every token tooltip.
const TOKEN_TOOLTIP_OPTIONS = {
    allowHTML: true,
    theme: 'custom',
    placement: 'top',
    interactive: true
};

/** Escape a value so it can be interpolated into tooltip HTML as plain text. */
function escapeHtml(value) {
    return String(value)
        .replace(/&/g, "&amp;")
        .replace(/</g, "&lt;")
        .replace(/>/g, "&gt;")
        .replace(/"/g, "&quot;")
        .replace(/'/g, "&#39;");
}

/** Replace tokenizer whitespace markers with the characters they stand for. */
function cleanTokenText(token) {
    return token
        .replace(/\u2581/g, " ") // Replace underscore token
        .replace(/Ġ/g, " ")      // Replace GPT-2 space token
        .replace(/Ċ/g, "\n");    // Replace GPT-2 newline token
}

/**
 * Describe a special token that appears after the first position.
 * Returns "" for ordinary tokens.
 */
function describeSpecialToken(token) {
    if (token === "<s>") {
        return "Beginning of Sequence";
    }
    // The first token is rendered separately, so "<|endoftext|>" seen here
    // marks an end of text rather than a beginning.
    if (token === "</s>" || token === "<|endoftext|>") {
        return "End of Sequence";
    }
    if (token === "<0x0A>") {
        return "Newline";
    }
    if (token.startsWith("<") && token.endsWith(">")) {
        return "Special Token: " + token;
    }
    return "";
}

/** Build the gray span for the first token, which has no prediction stats. */
function buildFirstTokenSpan(token) {
    const span = document.createElement("span");
    span.classList.add("token");
    span.style.backgroundColor = "#808080";

    const isSequenceStart = token === "<s>" || token === "<|endoftext|>";
    span.textContent = isSequenceStart ? "■" : cleanTokenText(token);

    const label = isSequenceStart ? "Beginning of Sequence" : "First Token";
    tippy(span, Object.assign(
        { content: `<div><strong>${label}</strong></div>` },
        TOKEN_TOOLTIP_OPTIONS
    ));
    return span;
}

/** Build the tooltip HTML for one scored token; all token text is escaped. */
function buildTooltipHtml(specialTokenDescription, topKPredictions, percentile, logProb) {
    let html = "";
    if (specialTokenDescription) {
        html += `<div style="font-weight: bold; margin-bottom: 8px;">${escapeHtml(specialTokenDescription)}</div>`;
    }

    html += `<div style="font-weight: bold; margin-bottom: 4px;">Top 5 Predictions:</div>`;
    topKPredictions.forEach(pred => {
        let predToken = pred.token;
        if (predToken === "<0x0A>") {
            predToken = "\\n";
        } else if (predToken.startsWith("<") && predToken.endsWith(">")) {
            predToken = "[SPECIAL]";
        } else {
            predToken = cleanTokenText(predToken);
        }
        html += `<div style="padding-left: 8px;">${escapeHtml(predToken)}: ${pred.log_prob.toFixed(4)}</div>`;
    });

    html += `<div style="margin-top: 8px; border-top: 1px solid #555; padding-top: 8px;">
        <div><strong>Stats:</strong></div>
        <div style="padding-left: 8px;">Percentile: ${percentile.toFixed(2)}</div>
        <div style="padding-left: 8px;">Log-Likelihood: ${logProb.toFixed(4)}</div>
    </div>`;
    return html;
}

// Send the text to /analyze and render one colored, tooltipped span per token.
document.getElementById("analyze-button").addEventListener("click", async () => {
    const text = document.getElementById("input-text").value;
    const model = document.getElementById("model-select").value;

    const coloredTextDiv = document.getElementById("colored-text");
    const jointElement = document.getElementById("joint-log-likelihood");
    const averageElement = document.getElementById("average-log-likelihood");

    let data;
    try {
        const response = await fetch("/analyze", {
            method: "POST",
            headers: {
                "Content-Type": "application/json"
            },
            body: JSON.stringify({ text, model })
        });
        data = await response.json();
        if (!response.ok) {
            throw new Error(data.error || `Request failed (${response.status})`);
        }
    } catch (error) {
        // Network failure, non-JSON body, or an error status from the server.
        coloredTextDiv.textContent = `Analysis failed: ${error.message}`;
        jointElement.textContent = "-";
        averageElement.textContent = "-";
        return;
    }

    coloredTextDiv.innerHTML = "";

    if (!Array.isArray(data.tokens) || data.tokens.length === 0) {
        jointElement.textContent = "-";
        averageElement.textContent = "-";
        return;
    }

    // Always add the first token
    coloredTextDiv.appendChild(buildFirstTokenSpan(data.tokens[0]));

    // log_probs[index] scores tokens[index + 1].
    data.log_probs.forEach((logProb, index) => {
        const token = data.tokens[index + 1];
        const specialTokenDescription = describeSpecialToken(token);

        const tokenSpan = document.createElement("span");
        tokenSpan.classList.add("token");
        tokenSpan.style.backgroundColor = getColor(data.log_probs, logProb);
        tokenSpan.textContent = specialTokenDescription ? "■" : cleanTokenText(token);

        tippy(tokenSpan, Object.assign(
            {
                content: buildTooltipHtml(
                    specialTokenDescription,
                    data.top_k_predictions[index],
                    data.percentiles[index],
                    logProb
                )
            },
            TOKEN_TOOLTIP_OPTIONS
        ));

        coloredTextDiv.appendChild(tokenSpan);
        if (token === "<0x0A>") {
            coloredTextDiv.appendChild(document.createElement("br"));
        }
    });

    jointElement.textContent = data.joint_log_likelihood.toFixed(4);
    averageElement.textContent = data.average_log_likelihood.toFixed(4);
});
130
+
131
/**
 * Map a log-probability to a red-to-green HSL color relative to the
 * other log-probabilities of the same analysis.
 *
 * @param {number[]} allLogProbs - every log-probability being displayed.
 * @param {number} currentLogProb - the value to color.
 * @returns {string} CSS color, from hsl(0, ...) for the lowest value to
 *     hsl(120, ...) for the highest.
 */
function getColor(allLogProbs, currentLogProb) {
    // A plain loop instead of Math.min(...array): spreading a very long
    // array can exceed the engine's argument limit.
    let minLogProb = Infinity;
    let maxLogProb = -Infinity;
    for (const value of allLogProbs) {
        if (value < minLogProb) minLogProb = value;
        if (value > maxLogProb) maxLogProb = value;
    }

    const range = maxLogProb - minLogProb;
    let normalizedLogProb;
    if (range > 0 && Number.isFinite(range)) {
        // Normalize to 0-1 range
        normalizedLogProb = (currentLogProb - minLogProb) / range;
        normalizedLogProb = Math.max(0, Math.min(1, normalizedLogProb)); // Clamp
    } else {
        // Empty input or all values equal: there is no scale to place the
        // value on, so use the midpoint instead of dividing by zero (NaN).
        normalizedLogProb = 0.5;
    }

    // Power below 1 spreads out differences at the lower end of the scale.
    const power = 0.7;
    normalizedLogProb = Math.pow(normalizedLogProb, power);

    const hue = normalizedLogProb * 120; // 0 (red) to 120 (green)
    return `hsl(${hue}, 100%, 50%)`;
}
static/style.css ADDED
@@ -0,0 +1,195 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ :root {
2
+ --primary-color: #2563eb;
3
+ --primary-hover: #1d4ed8;
4
+ --background-color: #f8fafc;
5
+ --card-background: #ffffff;
6
+ --text-primary: #1e293b;
7
+ --text-secondary: #64748b;
8
+ --border-color: #e2e8f0;
9
+ --token-hover: #f1f5f9;
10
+ }
11
+
12
+ * {
13
+ margin: 0;
14
+ padding: 0;
15
+ box-sizing: border-box;
16
+ }
17
+
18
+ body {
19
+ font-family: 'Inter', sans-serif;
20
+ background-color: var(--background-color);
21
+ color: var(--text-primary);
22
+ line-height: 1.5;
23
+ }
24
+
25
+ .container {
26
+ max-width: 1200px;
27
+ margin: 0 auto;
28
+ padding: 2rem;
29
+ }
30
+
31
+ header {
32
+ text-align: center;
33
+ margin-bottom: 2rem;
34
+ }
35
+
36
+ h1 {
37
+ font-size: 2.5rem;
38
+ font-weight: 600;
39
+ color: var(--text-primary);
40
+ margin-bottom: 0.5rem;
41
+ }
42
+
43
+ .subtitle {
44
+ color: var(--text-secondary);
45
+ font-size: 1.1rem;
46
+ }
47
+
48
+ .control-panel {
49
+ background-color: var(--card-background);
50
+ border-radius: 12px;
51
+ padding: 1.5rem;
52
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
53
+ margin-bottom: 2rem;
54
+ }
55
+
56
+ .input-group {
57
+ margin-bottom: 1.5rem;
58
+ }
59
+
60
+ label {
61
+ display: block;
62
+ margin-bottom: 0.5rem;
63
+ font-weight: 500;
64
+ color: var(--text-primary);
65
+ }
66
+
67
+ .styled-select {
68
+ width: 100%;
69
+ padding: 0.75rem;
70
+ border: 1px solid var(--border-color);
71
+ border-radius: 6px;
72
+ font-size: 1rem;
73
+ background-color: white;
74
+ cursor: pointer;
75
+ }
76
+
77
+ textarea {
78
+ width: 100%;
79
+ min-height: 120px;
80
+ padding: 0.75rem;
81
+ border: 1px solid var(--border-color);
82
+ border-radius: 6px;
83
+ font-size: 1rem;
84
+ font-family: inherit;
85
+ resize: vertical;
86
+ }
87
+
88
+ .primary-button {
89
+ background-color: var(--primary-color);
90
+ color: white;
91
+ border: none;
92
+ padding: 0.75rem 1.5rem;
93
+ border-radius: 6px;
94
+ font-size: 1rem;
95
+ font-weight: 500;
96
+ cursor: pointer;
97
+ transition: background-color 0.2s;
98
+ }
99
+
100
+ .primary-button:hover {
101
+ background-color: var(--primary-hover);
102
+ }
103
+
104
+ .output-panel {
105
+ background-color: var(--card-background);
106
+ border-radius: 12px;
107
+ padding: 1.5rem;
108
+ box-shadow: 0 1px 3px rgba(0, 0, 0, 0.1);
109
+ }
110
+
111
+ .output-section {
112
+ margin-bottom: 2rem;
113
+ }
114
+
115
+ h2 {
116
+ font-size: 1.5rem;
117
+ margin-bottom: 1rem;
118
+ color: var(--text-primary);
119
+ }
120
+
121
+ .token-display {
122
+ background-color: white;
123
+ border-radius: 8px;
124
+ padding: 1rem;
125
+ line-height: 1.3;
126
+ min-height: 100px;
127
+ font-size: 1rem;
128
+ white-space: pre-wrap;
129
+ }
130
+
131
+ .token {
132
+ padding: 0;
133
+ border-radius: 0;
134
+ margin: 0;
135
+ cursor: pointer;
136
+ transition: background-color 0.15s;
137
+ display: inline;
138
+ }
139
+
140
+ .token:hover {
141
+ background-color: rgba(0, 0, 0, 0.05) !important;
142
+ }
143
+
144
+ .stats-grid {
145
+ display: grid;
146
+ grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
147
+ gap: 1rem;
148
+ }
149
+
150
+ .stat-card {
151
+ background-color: white;
152
+ padding: 1rem;
153
+ border-radius: 8px;
154
+ border: 1px solid var(--border-color);
155
+ }
156
+
157
+ .stat-label {
158
+ font-size: 0.875rem;
159
+ color: var(--text-secondary);
160
+ margin-bottom: 0.5rem;
161
+ }
162
+
163
+ .stat-value {
164
+ font-size: 1.25rem;
165
+ font-weight: 600;
166
+ color: var(--text-primary);
167
+ }
168
+
169
+ /* Tippy custom theme */
170
+ .tippy-box[data-theme~='custom'] {
171
+ background-color: white;
172
+ color: var(--text-primary);
173
+ border: 1px solid var(--border-color);
174
+ box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1);
175
+ border-radius: 8px;
176
+ font-size: 0.875rem;
177
+ }
178
+
179
+ .tippy-box[data-theme~='custom'] .tippy-content {
180
+ padding: 1rem;
181
+ }
182
+
183
+ @media (max-width: 768px) {
184
+ .container {
185
+ padding: 1rem;
186
+ }
187
+
188
+ h1 {
189
+ font-size: 2rem;
190
+ }
191
+
192
+ .stats-grid {
193
+ grid-template-columns: 1fr;
194
+ }
195
+ }
templates/index.html ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
<!DOCTYPE html>
<html lang="en">
<head>
    {# Declare the encoding explicitly and enable the stylesheet's mobile media query. #}
    <meta charset="utf-8">
    <meta name="viewport" content="width=device-width, initial-scale=1">
    <title>LLM Token Visualization</title>
    <link rel="stylesheet" href="{{ url_for('static', filename='style.css') }}">
    <script src="https://unpkg.com/@popperjs/core@2"></script>
    <script src="https://unpkg.com/tippy.js@6"></script>
    <link rel="stylesheet" href="https://unpkg.com/tippy.js@6/themes/light.css">
    <link href="https://fonts.googleapis.com/css2?family=Inter:wght@400;500;600&display=swap" rel="stylesheet">
</head>
<body>
    <div class="container">
        <header>
            <h1>LLM Token Visualization</h1>
            <p class="subtitle">Analyze how language models process and predict text</p>
        </header>

        <div class="control-panel">
            <div class="input-group">
                <label for="model-select">Model:</label>
                <select id="model-select" class="styled-select">
                    {% for model in models %}
                    <option value="{{ model }}">{{ model }}</option>
                    {% endfor %}
                </select>
            </div>

            <div class="input-group">
                <label for="input-text">Text to Analyze:</label>
                <textarea id="input-text" placeholder="Enter your text here..."></textarea>
            </div>

            <button id="analyze-button" class="primary-button">Analyze</button>
        </div>

        <div id="output" class="output-panel">
            <div class="output-section">
                <h2>Token Analysis</h2>
                <div id="colored-text" class="token-display"></div>
            </div>

            <div class="stats-section">
                <h2>Statistics</h2>
                <div class="stats-grid">
                    <div class="stat-card">
                        <div class="stat-label">Joint Log-Likelihood</div>
                        <div class="stat-value" id="joint-log-likelihood">-</div>
                    </div>
                    <div class="stat-card">
                        <div class="stat-label">Average Log-Likelihood</div>
                        <div class="stat-value" id="average-log-likelihood">-</div>
                    </div>
                </div>
            </div>
        </div>
    </div>

    {# Loaded last so the elements it looks up by id already exist. #}
    <script src="{{ url_for('static', filename='script.js') }}"></script>
</body>
</html>