Create app.py
app.py
ADDED
@@ -0,0 +1,251 @@
import gradio as gr
from transformers import AutoConfig
from typing import Dict, Any, Tuple, Optional
import math

def get_model_config(model_id: str) -> AutoConfig:
    """Fetch the model configuration."""
    try:
        # AutoConfig from transformers is more reliable than parsing config.json by hand
        config = AutoConfig.from_pretrained(
            model_id,
            trust_remote_code=True,  # support models that ship custom code
            revision="main"
        )
        return config
    except Exception as e:
        raise Exception(f"Could not fetch the model config: {str(e)}")

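# Added note (commentary, not in the original commit): trust_remote_code=True
# lets AutoConfig import and execute configuration code shipped in the model
# repository, so it should only be used with repositories you trust; a stricter
# variant could expose it as a user-facing option that defaults to False.
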
def analyze_attention_mechanism(config: AutoConfig) -> Dict[str, Any]:
    """Classify the attention mechanism used by the model."""
    model_type = getattr(config, "model_type", "").lower()
    architecture = getattr(config, "architectures", []) or []

    # Detect the KV-cache-related optimizations
    attention_info = {
        "uses_gqa": False,
        "uses_mla": False,
        "uses_sliding_window": False,
        "attention_type": "Multi-Head Attention (MHA)"
    }

    # Detect GQA (Grouped Query Attention)
    num_attention_heads = getattr(config, "num_attention_heads", getattr(config, "n_head", 0))
    num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)

    if 0 < num_key_value_heads < num_attention_heads:
        attention_info["uses_gqa"] = True
        attention_info["attention_type"] = "Grouped Query Attention (GQA)"

    # Detect MLA (Multi-head Latent Attention), mainly used by DeepSeek-V2-style models
    if "deepseek" in model_type or any("deepseek" in str(arch).lower() for arch in architecture):
        if hasattr(config, "kv_lora_rank") or hasattr(config, "q_lora_rank"):
            attention_info["uses_mla"] = True
            attention_info["attention_type"] = "Multi-head Latent Attention (MLA)"

    # Detect sliding-window attention
    if hasattr(config, "sliding_window") or hasattr(config, "attention_window_size"):
        attention_info["uses_sliding_window"] = True

    # Model-family specific labels
    if "llama" in model_type:
        attention_info["attention_type"] = "RoPE + GQA" if attention_info["uses_gqa"] else "RoPE + MHA"
    elif "mistral" in model_type:
        attention_info["attention_type"] = "Sliding Window + GQA" if attention_info["uses_gqa"] else "Sliding Window + MHA"
    elif "qwen" in model_type:
        attention_info["attention_type"] = "QWen Attention (GQA)" if attention_info["uses_gqa"] else "QWen Attention"

    return attention_info

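# Added example (illustrative, not from the original commit): a GQA config such
# as Llama-2-70B reports num_attention_heads=64 and num_key_value_heads=8, so
# the detector above flags GQA and the KV cache is roughly 8x smaller than with
# full MHA; DeepSeek-V2-style MLA configs instead expose kv_lora_rank, which
# replaces the per-head K/V dimensions entirely.
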
def calculate_kv_cache_size(config: AutoConfig, sequence_length: int = 2048, batch_size: int = 1) -> Dict[str, Any]:
    """Estimate the KV cache footprint for the given config."""

    # Read the basic parameters, tolerating the different field names used across configs
    num_layers = getattr(config, "num_hidden_layers", getattr(config, "n_layer", getattr(config, "num_layers", 0)))
    num_attention_heads = getattr(config, "num_attention_heads", getattr(config, "n_head", 0))
    num_key_value_heads = getattr(config, "num_key_value_heads", num_attention_heads)
    hidden_size = getattr(config, "hidden_size", getattr(config, "n_embd", getattr(config, "d_model", 0)))

    # Per-head dimension
    head_dim = hidden_size // num_attention_heads if num_attention_heads > 0 else 0

    # MLA needs special handling: the KV states are stored in a compressed latent space
    kv_lora_rank = getattr(config, "kv_lora_rank", 0)
    if kv_lora_rank > 0:  # MLA architecture
        effective_kv_dim = kv_lora_rank
    else:
        effective_kv_dim = head_dim * num_key_value_heads

    # KV cache size per token per layer (Key + Value), assuming FP16 (2 bytes per element)
    bytes_per_element = 2
    kv_size_per_token_per_layer = 2 * effective_kv_dim * bytes_per_element  # K + V

    # Total KV cache size
    total_kv_cache_bytes = kv_size_per_token_per_layer * num_layers * sequence_length * batch_size

    # Convert to human-readable units
    def format_bytes(bytes_val):
        if bytes_val < 1024:
            return f"{bytes_val} B"
        elif bytes_val < 1024**2:
            return f"{bytes_val/1024:.2f} KB"
        elif bytes_val < 1024**3:
            return f"{bytes_val/(1024**2):.2f} MB"
        else:
            return f"{bytes_val/(1024**3):.2f} GB"

    return {
        "num_layers": num_layers,
        "num_attention_heads": num_attention_heads,
        "num_key_value_heads": num_key_value_heads,
        "head_dim": head_dim,
        "hidden_size": hidden_size,
        "effective_kv_dim": effective_kv_dim,
        "kv_size_per_token": format_bytes(kv_size_per_token_per_layer * num_layers),
        "total_kv_cache": format_bytes(total_kv_cache_bytes),
        "total_kv_cache_bytes": total_kv_cache_bytes,
        "kv_lora_rank": kv_lora_rank
    }

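# Added sanity check (illustrative, not from the original commit): for a
# Llama-2-7B-style config (32 layers, hidden_size 4096, 32 heads, 32 KV heads,
# head_dim 128), one token costs 2 (K and V) * 4096 * 2 bytes = 16 KiB per layer,
# or 512 KiB across all 32 layers, so a 2048-token sequence at batch size 1
# needs about 1 GiB of FP16 KV cache, which is the value this calculator reports.
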
def analyze_model(model_id: str, sequence_length: int = 2048, batch_size: int = 1) -> str:
    """Analyze a model and return a Markdown report."""
    try:
        # Gradio Number inputs arrive as floats; normalize them
        sequence_length = int(sequence_length)
        batch_size = int(batch_size)

        # Fetch the model config
        config = get_model_config(model_id)

        # Analyze the attention mechanism
        attention_info = analyze_attention_mechanism(config)

        # Estimate the KV cache size
        kv_info = calculate_kv_cache_size(config, sequence_length, batch_size)

        # Format the report
        result = f"""
## Model Analysis - {model_id}

### Basic Parameters
- **Model type**: {getattr(config, 'model_type', 'Unknown')}
- **Layers**: {kv_info['num_layers']}
- **Hidden size**: {kv_info['hidden_size']}
- **Attention heads**: {kv_info['num_attention_heads']}
- **KV heads**: {kv_info['num_key_value_heads']}
- **Head dimension**: {kv_info['head_dim']}

### Attention Optimizations
- **Attention type**: {attention_info['attention_type']}
- **Uses GQA**: {'✅ Yes' if attention_info['uses_gqa'] else '❌ No'}
- **Uses MLA**: {'✅ Yes' if attention_info['uses_mla'] else '❌ No'}
- **Sliding window**: {'✅ Yes' if attention_info['uses_sliding_window'] else '❌ No'}

### KV Cache Storage
- **Sequence length**: {sequence_length}
- **Batch size**: {batch_size}
- **Effective KV dimension**: {kv_info['effective_kv_dim']}
- **KV storage per token**: {kv_info['kv_size_per_token']}
- **Total KV cache size**: {kv_info['total_kv_cache']}

### Optimization Impact
"""

        # Memory savings from GQA
        if attention_info['uses_gqa']:
            original_kv_heads = kv_info['num_attention_heads']
            actual_kv_heads = kv_info['num_key_value_heads']
            memory_reduction = (1 - actual_kv_heads / original_kv_heads) * 100
            result += f"- **GQA memory savings**: {memory_reduction:.1f}% (KV heads reduced from {original_kv_heads} to {actual_kv_heads})\n"

        # MLA note
        if attention_info['uses_mla']:
            result += f"- **MLA compression**: KV states are compressed to {kv_info['kv_lora_rank']} dimensions\n"

        # Memory usage advice
        total_gb = kv_info['total_kv_cache_bytes'] / (1024**3)
        if total_gb > 8:
            result += f"\n⚠️ **Memory warning**: the KV cache needs {total_gb:.2f} GB; a high-end GPU is recommended"
        elif total_gb > 4:
            result += f"\n💡 **Memory note**: the KV cache needs {total_gb:.2f} GB; a mid-range setup can handle it"
        else:
            result += f"\n✅ **Memory friendly**: the KV cache only needs {total_gb:.2f} GB"

        return result

    except Exception as e:
        return f"❌ Analysis failed: {str(e)}"

# Build the Gradio interface
def create_interface():
    with gr.Blocks(title="Hugging Face Model KV Cache Analyzer", theme=gr.themes.Soft()) as iface:
        gr.Markdown("# 🤗 Hugging Face Model KV Cache Analyzer")
        gr.Markdown("Enter a model ID to analyze its KV cache size and attention optimizations")

        with gr.Row():
            with gr.Column(scale=3):
                model_input = gr.Textbox(
                    label="Model ID",
                    placeholder="e.g. microsoft/DialoGPT-medium, meta-llama/Llama-2-7b-hf",
                    value="microsoft/DialoGPT-medium"
                )
            with gr.Column(scale=1):
                seq_len_input = gr.Number(
                    label="Sequence length",
                    value=2048,
                    minimum=1,
                    maximum=32768
                )
            with gr.Column(scale=1):
                batch_size_input = gr.Number(
                    label="Batch size",
                    value=1,
                    minimum=1,
                    maximum=128
                )

        analyze_btn = gr.Button("🔍 Analyze model", variant="primary", size="lg")

        output = gr.Markdown(label="Analysis result")

        # A few popular example models
        gr.Markdown("### 💡 Popular model examples")
        example_models = [
            ["meta-llama/Llama-2-7b-hf", 2048, 1],
            ["microsoft/DialoGPT-medium", 1024, 1],
            ["Qwen/Qwen-7B", 2048, 1],
            ["mistralai/Mistral-7B-v0.1", 2048, 1],
            ["deepseek-ai/deepseek-coder-6.7b-base", 2048, 1]
        ]

        gr.Examples(
            examples=example_models,
            inputs=[model_input, seq_len_input, batch_size_input],
            outputs=output,
            fn=analyze_model,
            cache_examples=False
        )

        analyze_btn.click(
            fn=analyze_model,
            inputs=[model_input, seq_len_input, batch_size_input],
            outputs=output
        )

        gr.Markdown("""
### 📖 Notes
- **GQA**: Grouped Query Attention saves memory by reducing the number of KV heads
- **MLA**: Multi-head Latent Attention compresses the KV cache via a low-rank decomposition
- **Sliding window**: limits the attention span to reduce compute and memory
- KV cache sizes assume FP16 precision (2 bytes per element)
- Configs are fetched with `transformers.AutoConfig`, so custom models are supported

### 🛠️ Dependencies
```bash
pip install gradio transformers torch
```
""")

    return iface

if __name__ == "__main__":
    app = create_interface()
    app.launch(share=True)
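
# Added usage note (a sketch, not part of the original commit): the helpers can
# also be used without the UI, assuming network access to the Hugging Face Hub:
#
#     cfg = get_model_config("mistralai/Mistral-7B-v0.1")
#     print(analyze_attention_mechanism(cfg)["attention_type"])
#     print(calculate_kv_cache_size(cfg, sequence_length=4096)["total_kv_cache"])
#
# When hosted on a Hugging Face Space, share=True is not required; a plain
# app.launch() is enough.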