kebeliu committed
Commit d2cd03f · verified · 1 parent: 7da430a

Create app.py

Files changed (1)
  1. app.py +251 -0
app.py ADDED
@@ -0,0 +1,251 @@
+ import gradio as gr
+ from transformers import AutoConfig
+ from typing import Dict, Any, Tuple, Optional
+ import math
+
+ def get_model_config(model_id: str) -> AutoConfig:
+     """Fetch the model configuration from the Hugging Face Hub."""
+     try:
+         # Use transformers' AutoConfig, which is more reliable than parsing config files by hand
+         config = AutoConfig.from_pretrained(
+             model_id,
+             trust_remote_code=True,  # support models with custom code
+             revision="main"
+         )
+         return config
+     except Exception as e:
+         raise Exception(f"Could not fetch the model configuration: {str(e)}")
+
+ def analyze_attention_mechanism(config: AutoConfig) -> Dict[str, Any]:
+     """Identify the type of attention mechanism the model uses."""
+     model_type = getattr(config, "model_type", "").lower()
+     architecture = getattr(config, "architectures", []) or []
+
+     # Detect the various attention optimizations
+     attention_info = {
+         "uses_gqa": False,
+         "uses_mla": False,
+         "uses_sliding_window": False,
+         "attention_type": "Multi-Head Attention (MHA)"
+     }
+
+     # Detect GQA (Grouped Query Attention)
+     num_attention_heads = getattr(config, "num_attention_heads", getattr(config, "n_head", 0))
+     num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+
+     if num_key_value_heads < num_attention_heads and num_key_value_heads > 0:
+         attention_info["uses_gqa"] = True
+         attention_info["attention_type"] = "Grouped Query Attention (GQA)"
+
+     # Detect MLA (Multi-head Latent Attention) - mainly found in DeepSeek-V2-style models
+     if "deepseek" in model_type or any("deepseek" in str(arch).lower() for arch in architecture):
+         if hasattr(config, "kv_lora_rank") or hasattr(config, "q_lora_rank"):
+             attention_info["uses_mla"] = True
+             attention_info["attention_type"] = "Multi-head Latent Attention (MLA)"
+
+     # Detect sliding-window attention (the attribute may exist but be set to None)
+     if getattr(config, "sliding_window", None) or getattr(config, "attention_window_size", None):
+         attention_info["uses_sliding_window"] = True
+
+     # Model-family specific labels
+     if "llama" in model_type:
+         attention_info["attention_type"] = "RoPE + GQA" if attention_info["uses_gqa"] else "RoPE + MHA"
+     elif "mistral" in model_type:
+         attention_info["attention_type"] = "Sliding Window + GQA" if attention_info["uses_gqa"] else "Sliding Window + MHA"
+     elif "qwen" in model_type:
+         attention_info["attention_type"] = "QWen Attention (GQA)" if attention_info["uses_gqa"] else "QWen Attention"
+
+     return attention_info
+
+ def calculate_kv_cache_size(config: AutoConfig, sequence_length: int = 2048, batch_size: int = 1) -> Dict[str, Any]:
+     """Estimate the KV cache size for a given sequence length and batch size."""
+
+     # Read the basic hyperparameters, tolerating the attribute names used by different configs
+     num_layers = getattr(config, "num_hidden_layers", getattr(config, "n_layer", getattr(config, "num_layers", 0)))
+     num_attention_heads = getattr(config, "num_attention_heads", getattr(config, "n_head", 0))
+     num_key_value_heads = getattr(config, "num_key_value_heads", None) or num_attention_heads
+     hidden_size = getattr(config, "hidden_size", getattr(config, "n_embd", getattr(config, "d_model", 0)))
+
+     # Per-head dimension
+     head_dim = hidden_size // num_attention_heads if num_attention_heads > 0 else 0
+
+     # MLA needs special handling
+     kv_lora_rank = getattr(config, "kv_lora_rank", 0)
+     if kv_lora_rank > 0:  # MLA architecture
+         # With MLA, the KV representation is compressed to the latent rank
+         effective_kv_dim = kv_lora_rank
+     else:
+         effective_kv_dim = head_dim * num_key_value_heads
+
+     # KV cache size per token per layer (Key + Value),
+     # assuming FP16 (2 bytes per element)
+     bytes_per_element = 2
+     kv_size_per_token_per_layer = 2 * effective_kv_dim * bytes_per_element  # K + V
+
+     # Total KV cache size
+     total_kv_cache_bytes = kv_size_per_token_per_layer * num_layers * sequence_length * batch_size
+
+     # Convert to human-readable units
+     def format_bytes(bytes_val):
+         if bytes_val < 1024:
+             return f"{bytes_val} B"
+         elif bytes_val < 1024**2:
+             return f"{bytes_val/1024:.2f} KB"
+         elif bytes_val < 1024**3:
+             return f"{bytes_val/(1024**2):.2f} MB"
+         else:
+             return f"{bytes_val/(1024**3):.2f} GB"
+
+     return {
+         "num_layers": num_layers,
+         "num_attention_heads": num_attention_heads,
+         "num_key_value_heads": num_key_value_heads,
+         "head_dim": head_dim,
+         "hidden_size": hidden_size,
+         "effective_kv_dim": effective_kv_dim,
+         "kv_size_per_token": format_bytes(kv_size_per_token_per_layer * num_layers),
+         "total_kv_cache": format_bytes(total_kv_cache_bytes),
+         "total_kv_cache_bytes": total_kv_cache_bytes,
+         "kv_lora_rank": kv_lora_rank
+     }
+
+ def analyze_model(model_id: str, sequence_length: int = 2048, batch_size: int = 1) -> str:
+     """Analyze the model and return a Markdown report."""
+     try:
+         # Fetch the model configuration
+         config = get_model_config(model_id)
+
+         # Analyze the attention mechanism
+         attention_info = analyze_attention_mechanism(config)
+
+         # Estimate the KV cache size
+         kv_info = calculate_kv_cache_size(config, sequence_length, batch_size)
+
+         # Format the report
+         result = f"""
+ ## Model Analysis - {model_id}
+
+ ### Basic Parameters
+ - **Model type**: {getattr(config, 'model_type', 'Unknown')}
+ - **Layers**: {kv_info['num_layers']}
+ - **Hidden size**: {kv_info['hidden_size']}
+ - **Attention heads**: {kv_info['num_attention_heads']}
+ - **KV heads**: {kv_info['num_key_value_heads']}
+ - **Head dimension**: {kv_info['head_dim']}
+
+ ### Attention Optimizations
+ - **Attention type**: {attention_info['attention_type']}
+ - **Uses GQA**: {'✅ Yes' if attention_info['uses_gqa'] else '❌ No'}
+ - **Uses MLA**: {'✅ Yes' if attention_info['uses_mla'] else '❌ No'}
+ - **Sliding window**: {'✅ Yes' if attention_info['uses_sliding_window'] else '❌ No'}
+
+ ### KV Cache Storage
+ - **Sequence length**: {sequence_length}
+ - **Batch size**: {batch_size}
+ - **Effective KV dimension**: {kv_info['effective_kv_dim']}
+ - **KV storage per token**: {kv_info['kv_size_per_token']}
+ - **Total KV cache size**: {kv_info['total_kv_cache']}
+
+ ### Optimization Impact
+ """
+
+         # Memory saved by GQA
+         if attention_info['uses_gqa']:
+             original_kv_heads = kv_info['num_attention_heads']
+             actual_kv_heads = kv_info['num_key_value_heads']
+             memory_reduction = (1 - actual_kv_heads / original_kv_heads) * 100
+             result += f"- **GQA memory savings**: {memory_reduction:.1f}% (KV heads reduced from {original_kv_heads} to {actual_kv_heads})\n"
+
+         # Note on MLA
+         if attention_info['uses_mla']:
+             result += f"- **MLA compression**: KV representation compressed to {kv_info['kv_lora_rank']} dimensions\n"
+
+         # Memory usage guidance
+         total_gb = kv_info['total_kv_cache_bytes'] / (1024**3)
+         if total_gb > 8:
+             result += f"\n⚠️ **Memory warning**: the KV cache needs {total_gb:.2f} GB; a high-end GPU is recommended"
+         elif total_gb > 4:
+             result += f"\n💡 **Memory note**: the KV cache needs {total_gb:.2f} GB; a mid-range setup can handle it"
+         else:
+             result += f"\n✅ **Memory friendly**: the KV cache only needs {total_gb:.2f} GB"
+
+         return result
+
+     except Exception as e:
+         return f"❌ Analysis failed: {str(e)}"
+
+ # Build the Gradio interface
+ def create_interface():
+     with gr.Blocks(title="Hugging Face Model KV Cache Analyzer", theme=gr.themes.Soft()) as iface:
+         gr.Markdown("# 🤗 Hugging Face Model KV Cache Analyzer")
+         gr.Markdown("Enter a model ID to analyze its KV cache size and attention-mechanism optimizations")
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 model_input = gr.Textbox(
+                     label="Model ID",
+                     placeholder="e.g. microsoft/DialoGPT-medium, meta-llama/Llama-2-7b-hf",
+                     value="microsoft/DialoGPT-medium"
+                 )
+             with gr.Column(scale=1):
+                 seq_len_input = gr.Number(
+                     label="Sequence length",
+                     value=2048,
+                     minimum=1,
+                     maximum=32768
+                 )
+             with gr.Column(scale=1):
+                 batch_size_input = gr.Number(
+                     label="Batch size",
+                     value=1,
+                     minimum=1,
+                     maximum=128
+                 )
+
+         analyze_btn = gr.Button("🔍 Analyze model", variant="primary", size="lg")
+
+         output = gr.Markdown(label="Analysis result")
+
+         # A few example models
+         gr.Markdown("### 💡 Popular model examples")
+         example_models = [
+             ["meta-llama/Llama-2-7b-hf", 2048, 1],
+             ["microsoft/DialoGPT-medium", 1024, 1],
+             ["Qwen/Qwen-7B", 2048, 1],
+             ["mistralai/Mistral-7B-v0.1", 2048, 1],
+             ["deepseek-ai/deepseek-coder-6.7b-base", 2048, 1]
+         ]
+
+         gr.Examples(
+             examples=example_models,
+             inputs=[model_input, seq_len_input, batch_size_input],
+             outputs=output,
+             fn=analyze_model,
+             cache_examples=False
+         )
+
+         analyze_btn.click(
+             fn=analyze_model,
+             inputs=[model_input, seq_len_input, batch_size_input],
+             outputs=output
+         )
+
+         gr.Markdown("""
+ ### 📖 Notes
+ - **GQA**: Grouped Query Attention saves memory by reducing the number of KV heads
+ - **MLA**: Multi-head Latent Attention compresses the KV cache via a low-rank decomposition
+ - **Sliding window**: limits the attention span to reduce compute and memory
+ - KV cache sizes are computed assuming FP16 precision (2 bytes per element)
+ - Configurations are fetched with `transformers.AutoConfig`, so custom models are supported
+
+ ### 🛠️ Install dependencies
+ ```bash
+ pip install gradio transformers torch
+ ```
+ """)
+
+     return iface
+
+ if __name__ == "__main__":
+     app = create_interface()
+     app.launch(share=True)
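
As a quick sanity check of the sizing arithmetic in `calculate_kv_cache_size`, here is a minimal standalone sketch that plugs in the published Llama-2-7B configuration (32 layers, hidden size 4096, 32 attention heads, no GQA). The values are hard-coded for illustration rather than read from the Hub, and this is not output from the app itself.

```python
# Hand-computed KV cache size using the same formula as calculate_kv_cache_size,
# with Llama-2-7B's well-known hyperparameters hard-coded for illustration.
num_layers, hidden_size, num_heads, num_kv_heads = 32, 4096, 32, 32
seq_len, batch = 2048, 1

head_dim = hidden_size // num_heads             # 128
effective_kv_dim = head_dim * num_kv_heads      # 4096 (no GQA, so full width)
bytes_per_element = 2                           # FP16
per_token_per_layer = 2 * effective_kv_dim * bytes_per_element  # K + V = 16384 B

total_bytes = per_token_per_layer * num_layers * seq_len * batch
print(f"{total_bytes / 1024**3:.2f} GB")        # -> 1.00 GB of KV cache at 2048 tokens
```

Doubling the sequence length or the batch size scales this total linearly, which is why the interface exposes both as inputs.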