Avijit Ghosh
commited on
Commit
·
961c6fe
1
Parent(s):
27c66d1
add cached data and preprocessing code
Browse files- .DS_Store +0 -0
- .gitattributes +1 -0
- app.py +488 -497
- models.csv → models_processed.parquet +2 -2
- preprocess.py +371 -0
.DS_Store
ADDED
Binary file (6.15 kB). View file
|
|
.gitattributes
CHANGED
@@ -35,3 +35,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
org_to_artifacts_2l_stats.json filter=lfs diff=lfs merge=lfs -text
|
37 |
models.csv filter=lfs diff=lfs merge=lfs -text
|
|
|
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
org_to_artifacts_2l_stats.json filter=lfs diff=lfs merge=lfs -text
|
37 |
models.csv filter=lfs diff=lfs merge=lfs -text
|
38 |
+
models_processed.parquet filter=lfs diff=lfs merge=lfs -text
|
app.py
CHANGED
@@ -1,550 +1,541 @@
|
|
|
|
|
|
1 |
import json
|
2 |
import gradio as gr
|
3 |
import pandas as pd
|
4 |
import plotly.express as px
|
5 |
import os
|
6 |
-
import numpy as np
|
7 |
-
import
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
-
# Define pipeline tags
|
10 |
PIPELINE_TAGS = [
|
11 |
-
'text-generation',
|
12 |
-
'
|
13 |
-
'text-
|
14 |
-
'
|
15 |
-
'
|
16 |
-
'
|
17 |
-
'
|
18 |
-
'
|
19 |
-
'
|
20 |
-
'
|
21 |
-
'text-to-speech',
|
22 |
-
'automatic-speech-recognition',
|
23 |
-
'image-text-to-text',
|
24 |
-
'token-classification',
|
25 |
-
'sentence-similarity',
|
26 |
-
'question-answering',
|
27 |
-
'image-feature-extraction',
|
28 |
-
'summarization',
|
29 |
-
'zero-shot-image-classification',
|
30 |
-
'object-detection',
|
31 |
-
'image-segmentation',
|
32 |
-
'image-to-image',
|
33 |
-
'image-to-text',
|
34 |
-
'audio-classification',
|
35 |
-
'visual-question-answering',
|
36 |
-
'text-to-video',
|
37 |
-
'zero-shot-classification',
|
38 |
-
'depth-estimation',
|
39 |
-
'text-ranking',
|
40 |
-
'image-to-video',
|
41 |
-
'multiple-choice',
|
42 |
-
'unconditional-image-generation',
|
43 |
-
'video-classification',
|
44 |
-
'text-to-audio',
|
45 |
-
'time-series-forecasting',
|
46 |
-
'any-to-any',
|
47 |
-
'video-text-to-text',
|
48 |
'table-question-answering',
|
49 |
]
|
50 |
|
51 |
-
#
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
59 |
|
60 |
-
#
|
61 |
-
def
|
62 |
-
|
63 |
-
pipeline_tag = row.get("pipeline_tag", "")
|
64 |
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
106 |
|
107 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
if
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
|
118 |
-
|
119 |
-
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
124 |
-
|
125 |
-
|
126 |
-
|
127 |
-
|
128 |
-
|
129 |
-
|
130 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
131 |
|
132 |
-
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
-
|
|
|
|
|
|
|
137 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
138 |
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
|
139 |
-
|
140 |
-
# Create a copy to avoid modifying the original
|
141 |
filtered_df = df.copy()
|
|
|
|
|
|
|
142 |
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
|
147 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
148 |
if pipeline_filter:
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
|
158 |
-
# Add organization column
|
159 |
-
filtered_df["organization"] = filtered_df["id"].apply(extract_org_from_id)
|
160 |
-
|
161 |
-
# Skip organizations if specified
|
162 |
if skip_orgs and len(skip_orgs) > 0:
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
|
171 |
-
|
172 |
-
|
173 |
-
|
174 |
-
|
175 |
-
|
176 |
-
treemap_data
|
177 |
-
|
178 |
-
# Add a root node
|
179 |
-
treemap_data["root"] = "models"
|
180 |
-
|
181 |
-
# Ensure numeric values
|
182 |
-
treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0)
|
183 |
-
|
184 |
return treemap_data
|
185 |
|
186 |
def create_treemap(treemap_data, count_by, title=None):
|
187 |
-
"""Create a Plotly treemap from the prepared data"""
|
188 |
if treemap_data.empty:
|
189 |
-
|
190 |
-
fig =
|
191 |
-
names=["No data matches the selected filters"],
|
192 |
-
values=[1]
|
193 |
-
)
|
194 |
-
fig.update_layout(
|
195 |
-
title="No data matches the selected filters",
|
196 |
-
margin=dict(t=50, l=25, r=25, b=25)
|
197 |
-
)
|
198 |
return fig
|
199 |
-
|
200 |
-
# Create the treemap
|
201 |
fig = px.treemap(
|
202 |
-
treemap_data,
|
203 |
-
path=["root", "organization", "id"],
|
204 |
-
values=count_by,
|
205 |
title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
|
206 |
-
color_discrete_sequence=px.colors.qualitative.Plotly
|
207 |
-
)
|
208 |
-
|
209 |
-
# Update layout
|
210 |
-
fig.update_layout(
|
211 |
-
margin=dict(t=50, l=25, r=25, b=25)
|
212 |
-
)
|
213 |
-
|
214 |
-
# Update traces for better readability
|
215 |
-
fig.update_traces(
|
216 |
-
textinfo="label+value+percent root",
|
217 |
-
hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>"
|
218 |
)
|
219 |
-
|
|
|
220 |
return fig
|
221 |
|
222 |
-
|
223 |
-
|
224 |
-
|
225 |
-
|
226 |
-
# Process the tags column
|
227 |
-
def process_tags(tags_str):
|
228 |
-
if pd.isna(tags_str):
|
229 |
-
return []
|
230 |
-
|
231 |
-
# Clean the string and convert to a list
|
232 |
-
tags_str = tags_str.strip("[]").replace("'", "")
|
233 |
-
tags = [tag.strip() for tag in tags_str.split() if tag.strip()]
|
234 |
-
return tags
|
235 |
-
|
236 |
-
df['tags'] = df['tags'].apply(process_tags)
|
237 |
-
|
238 |
-
# Add more sample data for better visualization
|
239 |
-
add_sample_data(df)
|
240 |
-
|
241 |
-
return df
|
242 |
-
|
243 |
-
def add_sample_data(df):
|
244 |
-
"""Add more sample data to make the visualization more interesting"""
|
245 |
-
# Top organizations to include
|
246 |
-
orgs = ['openai', 'meta', 'google', 'microsoft', 'anthropic', 'nvidia', 'huggingface',
|
247 |
-
'deepseek-ai', 'stability-ai', 'mistralai', 'cerebras', 'databricks', 'together',
|
248 |
-
'facebook', 'amazon', 'deepmind', 'cohere', 'bigscience', 'eleutherai']
|
249 |
-
|
250 |
-
# Common model name formats
|
251 |
-
model_name_patterns = [
|
252 |
-
"model-{size}-{version}",
|
253 |
-
"{prefix}-{size}b",
|
254 |
-
"{prefix}-{size}b-{variant}",
|
255 |
-
"llama-{size}b-{variant}",
|
256 |
-
"gpt-{variant}-{size}b",
|
257 |
-
"{prefix}-instruct-{size}b",
|
258 |
-
"{prefix}-chat-{size}b",
|
259 |
-
"{prefix}-coder-{size}b",
|
260 |
-
"stable-diffusion-{version}",
|
261 |
-
"whisper-{size}",
|
262 |
-
"bert-{size}-{variant}",
|
263 |
-
"roberta-{size}",
|
264 |
-
"t5-{size}",
|
265 |
-
"{prefix}-vision-{size}b"
|
266 |
-
]
|
267 |
-
|
268 |
-
# Common name parts
|
269 |
-
prefixes = ["falcon", "llama", "mistral", "gpt", "phi", "gemma", "qwen", "yi", "mpt", "bloom"]
|
270 |
-
sizes = ["7", "13", "34", "70", "1", "3", "7b", "13b", "70b", "8b", "2b", "1b", "0.5b", "small", "base", "large", "huge"]
|
271 |
-
variants = ["chat", "instruct", "base", "v1.0", "v2", "beta", "turbo", "fast", "xl", "xxl"]
|
272 |
-
|
273 |
-
# Generate sample data
|
274 |
-
sample_data = []
|
275 |
-
for org_idx, org in enumerate(orgs):
|
276 |
-
# Create 5-10 models per organization
|
277 |
-
num_models = np.random.randint(5, 11)
|
278 |
-
|
279 |
-
for i in range(num_models):
|
280 |
-
# Create realistic model name
|
281 |
-
pattern = np.random.choice(model_name_patterns)
|
282 |
-
prefix = np.random.choice(prefixes)
|
283 |
-
size = np.random.choice(sizes)
|
284 |
-
version = f"v{np.random.randint(1, 4)}"
|
285 |
-
variant = np.random.choice(variants)
|
286 |
-
|
287 |
-
model_name = pattern.format(
|
288 |
-
prefix=prefix,
|
289 |
-
size=size,
|
290 |
-
version=version,
|
291 |
-
variant=variant
|
292 |
-
)
|
293 |
-
|
294 |
-
model_id = f"{org}/{model_name}"
|
295 |
-
|
296 |
-
# Select a realistic pipeline tag based on name
|
297 |
-
if "diffusion" in model_name or "image" in model_name:
|
298 |
-
pipeline_tag = np.random.choice(["text-to-image", "image-to-image", "image-segmentation"])
|
299 |
-
elif "whisper" in model_name or "speech" in model_name:
|
300 |
-
pipeline_tag = np.random.choice(["automatic-speech-recognition", "text-to-speech"])
|
301 |
-
elif "coder" in model_name or "code" in model_name:
|
302 |
-
pipeline_tag = "text-generation"
|
303 |
-
elif "bert" in model_name or "roberta" in model_name:
|
304 |
-
pipeline_tag = np.random.choice(["fill-mask", "text-classification", "token-classification"])
|
305 |
-
elif "vision" in model_name:
|
306 |
-
pipeline_tag = np.random.choice(["image-classification", "image-to-text", "visual-question-answering"])
|
307 |
-
else:
|
308 |
-
pipeline_tag = "text-generation" # Most common
|
309 |
-
|
310 |
-
# Generate realistic tags
|
311 |
-
tags = [pipeline_tag]
|
312 |
-
|
313 |
-
if "text-generation" in pipeline_tag:
|
314 |
-
tags.extend(["language-model", "text", "gpt", "llm"])
|
315 |
-
if "instruct" in model_name:
|
316 |
-
tags.append("instruction-following")
|
317 |
-
if "chat" in model_name:
|
318 |
-
tags.append("chat")
|
319 |
-
elif "speech" in pipeline_tag:
|
320 |
-
tags.extend(["audio", "speech", "voice"])
|
321 |
-
elif "image" in pipeline_tag:
|
322 |
-
tags.extend(["vision", "image", "diffusion"])
|
323 |
-
|
324 |
-
# Add language tags
|
325 |
-
if np.random.random() < 0.8: # 80% chance for English
|
326 |
-
tags.append("en")
|
327 |
-
if np.random.random() < 0.3: # 30% chance for multilingual
|
328 |
-
tags.append("multilingual")
|
329 |
-
|
330 |
-
# Generate downloads and likes (weighted by org position for variety)
|
331 |
-
# Earlier orgs get more downloads to make the visualization interesting
|
332 |
-
popularity_factor = (len(orgs) - org_idx) / len(orgs) # 1.0 to 0.0
|
333 |
-
base_downloads = 10000 * (10 ** (2 * popularity_factor))
|
334 |
-
downloads = int(base_downloads * np.random.uniform(0.3, 3.0))
|
335 |
-
likes = int(downloads * np.random.uniform(0.01, 0.1)) # 1-10% like ratio
|
336 |
-
|
337 |
-
# Generate model size (in bytes for params)
|
338 |
-
# Model size should correlate somewhat with the size in the name
|
339 |
-
size_indicator = 1
|
340 |
-
for s in ["70b", "13b", "7b", "3b", "2b", "1b", "large", "huge", "xl", "xxl"]:
|
341 |
-
if s in model_name.lower():
|
342 |
-
size_indicator = float(s.replace("b", "")) if s[0].isdigit() else 3
|
343 |
-
break
|
344 |
-
|
345 |
-
# Size in bytes
|
346 |
-
params = int(np.random.uniform(0.5, 2.0) * size_indicator * 1e9)
|
347 |
-
|
348 |
-
# Create model entry
|
349 |
-
model = {
|
350 |
-
"id": model_id,
|
351 |
-
"author": org,
|
352 |
-
"downloads": downloads,
|
353 |
-
"likes": likes,
|
354 |
-
"pipeline_tag": pipeline_tag,
|
355 |
-
"tags": tags,
|
356 |
-
"params": params
|
357 |
-
}
|
358 |
-
|
359 |
-
sample_data.append(model)
|
360 |
-
|
361 |
-
# Convert sample data to DataFrame and append to original
|
362 |
-
sample_df = pd.DataFrame(sample_data)
|
363 |
-
return pd.concat([df, sample_df], ignore_index=True)
|
364 |
|
365 |
-
# Create Gradio interface
|
366 |
-
with gr.Blocks() as demo:
|
367 |
-
models_data = gr.State() # To store loaded data
|
368 |
-
|
369 |
with gr.Row():
|
370 |
-
gr.Markdown(""
|
371 |
-
# HuggingFace Models TreeMap Visualization
|
372 |
-
|
373 |
-
This app shows how different organizations contribute to the HuggingFace ecosystem with their models.
|
374 |
-
Use the filters to explore models by different metrics, tags, pipelines, and model sizes.
|
375 |
-
|
376 |
-
The treemap visualizes models grouped by organization, with the size of each box representing the selected metric (downloads or likes).
|
377 |
-
""")
|
378 |
-
|
379 |
with gr.Row():
|
380 |
-
with gr.Column(scale=1):
|
381 |
-
count_by_dropdown = gr.Dropdown(
|
382 |
-
|
383 |
-
|
384 |
-
|
385 |
-
|
386 |
-
)
|
387 |
-
|
388 |
-
filter_choice_radio = gr.Radio(
|
389 |
-
label="Filter Type",
|
390 |
-
choices=["None", "Tag Filter", "Pipeline Filter"],
|
391 |
-
value="None",
|
392 |
-
info="Choose how to filter the models"
|
393 |
-
)
|
394 |
-
|
395 |
-
tag_filter_dropdown = gr.Dropdown(
|
396 |
-
label="Select Tag",
|
397 |
-
choices=list(TAG_FILTER_FUNCS.keys()),
|
398 |
-
value=None,
|
399 |
-
visible=False,
|
400 |
-
info="Filter models by domain/category"
|
401 |
-
)
|
402 |
-
|
403 |
-
pipeline_filter_dropdown = gr.Dropdown(
|
404 |
-
label="Select Pipeline Tag",
|
405 |
-
choices=PIPELINE_TAGS,
|
406 |
-
value=None,
|
407 |
-
visible=False,
|
408 |
-
info="Filter models by specific pipeline"
|
409 |
-
)
|
410 |
|
411 |
-
|
412 |
-
|
413 |
-
choices=["None"] + list(MODEL_SIZE_RANGES.keys()),
|
414 |
-
value="None",
|
415 |
-
info="Filter models by their size (using params column)"
|
416 |
-
)
|
417 |
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
|
424 |
-
|
425 |
-
|
426 |
-
|
427 |
-
|
428 |
-
|
429 |
-
|
430 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
431 |
)
|
432 |
|
433 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
434 |
|
435 |
-
|
436 |
-
|
437 |
-
|
|
|
|
|
|
|
|
|
|
|
438 |
|
439 |
-
|
440 |
-
|
441 |
-
|
442 |
-
|
443 |
-
|
444 |
-
|
445 |
-
|
446 |
-
|
447 |
-
|
448 |
-
|
449 |
-
|
450 |
-
|
451 |
-
|
452 |
-
|
453 |
-
|
454 |
-
|
455 |
-
|
456 |
-
|
457 |
-
|
458 |
-
skip_orgs = []
|
459 |
-
if skip_orgs_text and skip_orgs_text.strip():
|
460 |
-
skip_orgs = [org.strip() for org in skip_orgs_text.split(',') if org.strip()]
|
461 |
-
print(f"Skipping organizations: {skip_orgs}")
|
462 |
-
|
463 |
-
# Process data for treemap
|
464 |
-
treemap_data = make_treemap_data(
|
465 |
-
df=data_df,
|
466 |
-
count_by=count_by,
|
467 |
-
top_k=top_k,
|
468 |
-
tag_filter=selected_tag_filter,
|
469 |
-
pipeline_filter=selected_pipeline_filter,
|
470 |
-
size_filter=selected_size_filter,
|
471 |
-
skip_orgs=skip_orgs
|
472 |
-
)
|
473 |
-
|
474 |
-
# Create plot
|
475 |
-
fig = create_treemap(
|
476 |
-
treemap_data=treemap_data,
|
477 |
-
count_by=count_by,
|
478 |
-
title=f"HuggingFace Models - {count_by.capitalize()} by Organization"
|
479 |
-
)
|
480 |
|
481 |
-
|
482 |
-
|
483 |
-
|
484 |
-
|
485 |
-
|
486 |
-
total_value = treemap_data[count_by].sum()
|
487 |
-
top_5_orgs = treemap_data.groupby("organization")[count_by].sum().sort_values(ascending=False).head(5)
|
488 |
-
|
489 |
-
# Format the statistics using clean markdown
|
490 |
-
stats_md = f"""
|
491 |
-
## Statistics
|
492 |
-
- **Total models shown**: {total_models:,}
|
493 |
-
- **Total {count_by}**: {int(total_value):,}
|
494 |
|
495 |
-
|
496 |
|
497 |
-
|
498 |
-
|
499 |
-
|
500 |
-
|
501 |
-
|
502 |
-
|
503 |
-
|
504 |
-
|
505 |
-
|
506 |
-
|
507 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
508 |
|
509 |
-
|
510 |
-
|
511 |
-
|
512 |
-
|
513 |
-
|
514 |
-
|
515 |
-
|
516 |
-
|
517 |
-
|
518 |
-
|
519 |
-
|
520 |
-
|
521 |
-
|
522 |
-
|
523 |
-
|
524 |
-
|
525 |
-
# Load data once at startup
|
526 |
demo.load(
|
527 |
-
fn=
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
528 |
inputs=[],
|
529 |
-
outputs=[
|
530 |
)
|
531 |
|
532 |
-
#
|
533 |
generate_plot_button.click(
|
534 |
-
fn=
|
535 |
-
inputs=[
|
536 |
-
|
537 |
-
|
538 |
-
tag_filter_dropdown,
|
539 |
-
pipeline_filter_dropdown,
|
540 |
-
size_filter_dropdown,
|
541 |
-
top_k_slider,
|
542 |
-
skip_orgs_textbox,
|
543 |
-
models_data
|
544 |
-
],
|
545 |
-
outputs=[plot_output, stats_output]
|
546 |
)
|
547 |
|
548 |
-
|
549 |
if __name__ == "__main__":
|
550 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- START OF FILE app.py ---
|
2 |
+
|
3 |
import json
|
4 |
import gradio as gr
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
7 |
import os
|
8 |
+
import numpy as np # Make sure NumPy is imported
|
9 |
+
import duckdb
|
10 |
+
from tqdm.auto import tqdm # Standard tqdm for console, gr.Progress will track it
|
11 |
+
import time
|
12 |
+
import ast # For safely evaluating string representations of lists/dicts
|
13 |
+
|
14 |
+
# --- Constants ---
|
15 |
+
MODEL_SIZE_RANGES = {
|
16 |
+
"Small (<1GB)": (0, 1), "Medium (1-5GB)": (1, 5), "Large (5-20GB)": (5, 20),
|
17 |
+
"X-Large (20-50GB)": (20, 50), "XX-Large (>50GB)": (50, float('inf'))
|
18 |
+
}
|
19 |
+
PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
|
20 |
+
HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'
|
21 |
+
|
22 |
+
TAG_FILTER_CHOICES = [
|
23 |
+
"Audio & Speech", "Time series", "Robotics", "Music", "Video", "Images",
|
24 |
+
"Text", "Biomedical", "Sciences"
|
25 |
+
]
|
26 |
|
|
|
27 |
PIPELINE_TAGS = [
|
28 |
+
'text-generation', 'text-to-image', 'text-classification', 'text2text-generation',
|
29 |
+
'audio-to-audio', 'feature-extraction', 'image-classification', 'translation',
|
30 |
+
'reinforcement-learning', 'fill-mask', 'text-to-speech', 'automatic-speech-recognition',
|
31 |
+
'image-text-to-text', 'token-classification', 'sentence-similarity', 'question-answering',
|
32 |
+
'image-feature-extraction', 'summarization', 'zero-shot-image-classification',
|
33 |
+
'object-detection', 'image-segmentation', 'image-to-image', 'image-to-text',
|
34 |
+
'audio-classification', 'visual-question-answering', 'text-to-video',
|
35 |
+
'zero-shot-classification', 'depth-estimation', 'text-ranking', 'image-to-video',
|
36 |
+
'multiple-choice', 'unconditional-image-generation', 'video-classification',
|
37 |
+
'text-to-audio', 'time-series-forecasting', 'any-to-any', 'video-text-to-text',
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
'table-question-answering',
|
39 |
]
|
40 |
|
41 |
+
# --- Utility Functions ---
|
42 |
+
def extract_model_size(safetensors_data): # Renamed for consistency if used, preprocessor uses extract_model_file_size_gb
|
43 |
+
try:
|
44 |
+
if pd.isna(safetensors_data): return 0.0
|
45 |
+
data_to_parse = safetensors_data
|
46 |
+
if isinstance(safetensors_data, str):
|
47 |
+
try:
|
48 |
+
if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
|
49 |
+
(safetensors_data.startswith('[') and safetensors_data.endswith(']')):
|
50 |
+
data_to_parse = ast.literal_eval(safetensors_data)
|
51 |
+
else: data_to_parse = json.loads(safetensors_data)
|
52 |
+
except: return 0.0
|
53 |
+
if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
|
54 |
+
try:
|
55 |
+
total_bytes_val = data_to_parse['total']
|
56 |
+
size_bytes = float(total_bytes_val)
|
57 |
+
return size_bytes / (1024 * 1024 * 1024)
|
58 |
+
except (ValueError, TypeError): pass
|
59 |
+
return 0.0
|
60 |
+
except: return 0.0
|
61 |
+
|
62 |
+
def extract_org_from_id(model_id):
|
63 |
+
if pd.isna(model_id): return "unaffiliated"
|
64 |
+
model_id_str = str(model_id)
|
65 |
+
return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
|
66 |
|
67 |
+
# --- THIS IS THE CORRECTED process_tags_for_series from preprocess.py ---
|
68 |
+
def process_tags_for_series(series_of_tags_values, tqdm_cls=None): # Added tqdm_cls for Gradio progress
|
69 |
+
processed_tags_accumulator = []
|
|
|
70 |
|
71 |
+
# Determine the iterable (use tqdm if tqdm_cls is provided, else direct iteration)
|
72 |
+
iterable = series_of_tags_values
|
73 |
+
if tqdm_cls and tqdm_cls != tqdm : # Check if it's Gradio's progress tracker
|
74 |
+
iterable = tqdm_cls(series_of_tags_values, desc="Standardizing Tags (App)", unit="row")
|
75 |
+
elif tqdm_cls == tqdm: # For direct console tqdm if passed
|
76 |
+
iterable = tqdm(series_of_tags_values, desc="Standardizing Tags (App)", unit="row", leave=False)
|
77 |
+
|
78 |
+
|
79 |
+
for i, tags_value_from_series in enumerate(iterable):
|
80 |
+
temp_processed_list_for_row = []
|
81 |
+
current_value_for_error_msg = str(tags_value_from_series)[:200]
|
82 |
+
|
83 |
+
try:
|
84 |
+
if isinstance(tags_value_from_series, list):
|
85 |
+
current_tags_in_list = []
|
86 |
+
for tag_item in tags_value_from_series:
|
87 |
+
try:
|
88 |
+
if pd.isna(tag_item): continue
|
89 |
+
str_tag = str(tag_item)
|
90 |
+
stripped_tag = str_tag.strip()
|
91 |
+
if stripped_tag:
|
92 |
+
current_tags_in_list.append(stripped_tag)
|
93 |
+
except Exception as e_inner_list_proc:
|
94 |
+
print(f"APP ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original: {current_value_for_error_msg}")
|
95 |
+
temp_processed_list_for_row = current_tags_in_list
|
96 |
+
|
97 |
+
elif isinstance(tags_value_from_series, np.ndarray):
|
98 |
+
current_tags_in_list = []
|
99 |
+
for tag_item in tags_value_from_series.tolist():
|
100 |
+
try:
|
101 |
+
if pd.isna(tag_item): continue
|
102 |
+
str_tag = str(tag_item)
|
103 |
+
stripped_tag = str_tag.strip()
|
104 |
+
if stripped_tag:
|
105 |
+
current_tags_in_list.append(stripped_tag)
|
106 |
+
except Exception as e_inner_array_proc:
|
107 |
+
print(f"APP ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original: {current_value_for_error_msg}")
|
108 |
+
temp_processed_list_for_row = current_tags_in_list
|
109 |
+
|
110 |
+
elif tags_value_from_series is None or pd.isna(tags_value_from_series):
|
111 |
+
temp_processed_list_for_row = []
|
112 |
+
|
113 |
+
elif isinstance(tags_value_from_series, str):
|
114 |
+
processed_str_tags = []
|
115 |
+
if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
|
116 |
+
(tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
|
117 |
+
try:
|
118 |
+
evaluated_tags = ast.literal_eval(tags_value_from_series)
|
119 |
+
if isinstance(evaluated_tags, (list, tuple)):
|
120 |
+
current_eval_list = []
|
121 |
+
for tag_item in evaluated_tags:
|
122 |
+
if pd.isna(tag_item): continue
|
123 |
+
str_tag = str(tag_item).strip()
|
124 |
+
if str_tag: current_eval_list.append(str_tag)
|
125 |
+
processed_str_tags = current_eval_list
|
126 |
+
except (ValueError, SyntaxError):
|
127 |
+
pass
|
128 |
+
|
129 |
+
if not processed_str_tags:
|
130 |
+
try:
|
131 |
+
json_tags = json.loads(tags_value_from_series)
|
132 |
+
if isinstance(json_tags, list):
|
133 |
+
current_json_list = []
|
134 |
+
for tag_item in json_tags:
|
135 |
+
if pd.isna(tag_item): continue
|
136 |
+
str_tag = str(tag_item).strip()
|
137 |
+
if str_tag: current_json_list.append(str_tag)
|
138 |
+
processed_str_tags = current_json_list
|
139 |
+
except json.JSONDecodeError:
|
140 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
|
141 |
+
except Exception as e_json_other:
|
142 |
+
print(f"APP ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
|
143 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
|
144 |
+
|
145 |
+
temp_processed_list_for_row = processed_str_tags
|
146 |
+
|
147 |
+
else:
|
148 |
+
if pd.isna(tags_value_from_series):
|
149 |
+
temp_processed_list_for_row = []
|
150 |
+
else:
|
151 |
+
str_val = str(tags_value_from_series).strip()
|
152 |
+
temp_processed_list_for_row = [str_val] if str_val else []
|
153 |
+
|
154 |
+
processed_tags_accumulator.append(temp_processed_list_for_row)
|
155 |
+
|
156 |
+
except Exception as e_outer_tag_proc:
|
157 |
+
print(f"APP CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
|
158 |
+
processed_tags_accumulator.append([])
|
159 |
+
|
160 |
+
return processed_tags_accumulator
|
161 |
+
# --- END OF CORRECTED process_tags_for_series ---
|
162 |
+
|
163 |
+
|
164 |
+
def load_models_data(force_refresh=False, tqdm_cls=None): # tqdm_cls for Gradio progress
|
165 |
+
# ... (initial part of load_models_data for loading pre-processed parquet is the same) ...
|
166 |
+
if tqdm_cls is None: tqdm_cls = tqdm # Default to standard tqdm if None
|
167 |
+
overall_start_time = time.time()
|
168 |
+
print(f"Gradio load_models_data called with force_refresh={force_refresh}")
|
169 |
+
|
170 |
+
expected_cols_in_processed_parquet = [
|
171 |
+
'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags', 'params',
|
172 |
+
'size_category', 'organization', 'has_audio', 'has_speech', 'has_music',
|
173 |
+
'has_robot', 'has_bio', 'has_med', 'has_series', 'has_video', 'has_image',
|
174 |
+
'has_text', 'has_science', 'is_audio_speech', 'is_biomed',
|
175 |
+
'data_download_timestamp'
|
176 |
+
]
|
177 |
+
|
178 |
+
if not force_refresh and os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
179 |
+
print(f"Attempting to load pre-processed data from: {PROCESSED_PARQUET_FILE_PATH}")
|
180 |
+
try:
|
181 |
+
df = pd.read_parquet(PROCESSED_PARQUET_FILE_PATH)
|
182 |
+
elapsed = time.time() - overall_start_time
|
183 |
+
missing_cols = [col for col in expected_cols_in_processed_parquet if col not in df.columns]
|
184 |
+
if missing_cols:
|
185 |
+
raise ValueError(f"Pre-processed Parquet is missing columns: {missing_cols}. Please run preprocessor or refresh data in app.")
|
186 |
+
|
187 |
+
if 'has_robot' in df.columns:
|
188 |
+
robot_count_parquet = df['has_robot'].sum()
|
189 |
+
print(f"DIAGNOSTIC (App - Parquet Load): 'has_robot' column found. Number of True values: {robot_count_parquet}")
|
190 |
+
else:
|
191 |
+
print("DIAGNOSTIC (App - Parquet Load): 'has_robot' column NOT FOUND.")
|
192 |
+
|
193 |
+
msg = f"Successfully loaded pre-processed data in {elapsed:.2f}s. Shape: {df.shape}"
|
194 |
+
print(msg)
|
195 |
+
return df, True, msg
|
196 |
+
except Exception as e:
|
197 |
+
print(f"Could not load pre-processed Parquet: {e}. ")
|
198 |
+
if force_refresh: print("Proceeding to fetch fresh data as force_refresh=True.")
|
199 |
+
else:
|
200 |
+
err_msg = (f"Pre-processed data could not be loaded: {e}. "
|
201 |
+
"Please use 'Refresh Data from Hugging Face' button.")
|
202 |
+
return pd.DataFrame(), False, err_msg
|
203 |
+
|
204 |
+
df_raw = None
|
205 |
+
raw_data_source_msg = ""
|
206 |
+
if force_refresh:
|
207 |
+
print("force_refresh=True (Gradio). Fetching fresh data...")
|
208 |
+
fetch_start = time.time()
|
209 |
+
try:
|
210 |
+
query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"
|
211 |
+
df_raw = duckdb.sql(query).df()
|
212 |
+
if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
|
213 |
+
raw_data_source_msg = f"Fetched by Gradio in {time.time() - fetch_start:.2f}s. Rows: {len(df_raw)}"
|
214 |
+
print(raw_data_source_msg)
|
215 |
+
except Exception as e_hf:
|
216 |
+
return pd.DataFrame(), False, f"Fatal error fetching from Hugging Face (Gradio): {e_hf}"
|
217 |
+
else:
|
218 |
+
err_msg = (f"Pre-processed data '{PROCESSED_PARQUET_FILE_PATH}' not found/invalid. "
|
219 |
+
"Run preprocessor or use 'Refresh Data' button.")
|
220 |
+
return pd.DataFrame(), False, err_msg
|
221 |
+
|
222 |
+
print(f"Initiating processing for data newly fetched by Gradio. {raw_data_source_msg}")
|
223 |
+
df = pd.DataFrame() # This will be our processed DataFrame
|
224 |
+
proc_start = time.time()
|
225 |
|
226 |
+
core_cols = {'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
|
227 |
+
'pipeline_tag': str, 'tags': object, 'safetensors': object}
|
228 |
+
for col, dtype in core_cols.items():
|
229 |
+
if col in df_raw.columns:
|
230 |
+
df[col] = df_raw[col] # Assign raw data first
|
231 |
+
if dtype == float: df[col] = pd.to_numeric(df[col], errors='coerce').fillna(0.0)
|
232 |
+
elif dtype == str: df[col] = df[col].astype(str).fillna('')
|
233 |
+
# For 'tags' and 'safetensors' (object type), no specific conversion here, done later
|
234 |
+
else: # If column is missing in raw data
|
235 |
+
if col in ['downloads', 'downloadsAllTime', 'likes']: df[col] = 0.0
|
236 |
+
elif col == 'pipeline_tag': df[col] = ''
|
237 |
+
elif col == 'tags': df[col] = pd.Series([[] for _ in range(len(df_raw))]) # Default to empty lists
|
238 |
+
elif col == 'safetensors': df[col] = None # Default to None
|
239 |
+
elif col == 'id': return pd.DataFrame(), False, "Critical: 'id' column missing."
|
240 |
|
241 |
+
output_filesize_col_name = 'params'
|
242 |
+
if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
|
243 |
+
df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
|
244 |
+
elif 'safetensors' in df.columns:
|
245 |
+
# Use tqdm_cls for progress tracking if available (Gradio's gr.Progress.tqdm)
|
246 |
+
safetensors_iter = df['safetensors']
|
247 |
+
if tqdm_cls and tqdm_cls != tqdm: # Check if it's Gradio's progress tracker
|
248 |
+
safetensors_iter = tqdm_cls(df['safetensors'], desc="Extracting model sizes (GB)", unit="row")
|
249 |
+
elif tqdm_cls == tqdm: # For direct console tqdm if passed
|
250 |
+
safetensors_iter = tqdm(df['safetensors'], desc="Extracting model sizes (GB)", unit="row", leave=False)
|
251 |
+
|
252 |
+
df[output_filesize_col_name] = [extract_model_size(s) for s in safetensors_iter]
|
253 |
+
df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
|
254 |
+
else:
|
255 |
+
df[output_filesize_col_name] = 0.0
|
256 |
+
|
257 |
+
def get_size_category_gradio(size_gb_val):
|
258 |
+
try: numeric_size_gb = float(size_gb_val)
|
259 |
+
except (ValueError, TypeError): numeric_size_gb = 0.0
|
260 |
+
if pd.isna(numeric_size_gb): numeric_size_gb = 0.0
|
261 |
+
if 0 <= numeric_size_gb < 1: return "Small (<1GB)"
|
262 |
+
elif 1 <= numeric_size_gb < 5: return "Medium (1-5GB)"
|
263 |
+
elif 5 <= numeric_size_gb < 20: return "Large (5-20GB)"
|
264 |
+
elif 20 <= numeric_size_gb < 50: return "X-Large (20-50GB)"
|
265 |
+
elif numeric_size_gb >= 50: return "XX-Large (>50GB)"
|
266 |
+
else: return "Small (<1GB)" # Default
|
267 |
+
df['size_category'] = df[output_filesize_col_name].apply(get_size_category_gradio)
|
268 |
+
|
269 |
+
# >>> USE THE CORRECTED process_tags_for_series HERE <<<
|
270 |
+
df['tags'] = process_tags_for_series(df['tags'], tqdm_cls=tqdm_cls)
|
271 |
|
272 |
+
df['temp_tags_joined'] = df['tags'].apply(
|
273 |
+
lambda tl: '~~~'.join(str(t).lower().strip() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
|
274 |
+
)
|
275 |
+
tag_map = {
|
276 |
+
'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
|
277 |
+
'has_robot': ['robot', 'robotics'],
|
278 |
+
'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
|
279 |
+
'has_series': ['series', 'time-series', 'timeseries'],
|
280 |
+
'has_video': ['video'], 'has_image': ['image', 'vision'],
|
281 |
+
'has_text': ['text', 'nlp', 'llm']
|
282 |
+
}
|
283 |
+
for col, kws in tag_map.items():
|
284 |
+
pattern = '|'.join(kws)
|
285 |
+
df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
|
286 |
+
df['has_science'] = (
|
287 |
+
df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
|
288 |
+
~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
|
289 |
+
)
|
290 |
+
del df['temp_tags_joined']
|
291 |
+
df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
|
292 |
+
df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
|
293 |
+
df['is_biomed'] = df['has_bio'] | df['has_med']
|
294 |
+
df['organization'] = df['id'].apply(extract_org_from_id)
|
295 |
|
296 |
+
# Drop safetensors if params was calculated from it, and params didn't pre-exist as numeric
|
297 |
+
if 'safetensors' in df.columns and \
|
298 |
+
not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
|
299 |
+
df = df.drop(columns=['safetensors'], errors='ignore')
|
300 |
+
|
301 |
+
if force_refresh and 'has_robot' in df.columns:
|
302 |
+
robot_count_app_proc = df['has_robot'].sum()
|
303 |
+
print(f"DIAGNOSTIC (App - Force Refresh Processing): 'has_robot' column processed. Number of True values: {robot_count_app_proc}")
|
304 |
|
305 |
+
print(f"Data processing by Gradio completed in {time.time() - proc_start:.2f}s.")
|
306 |
+
|
307 |
+
total_elapsed = time.time() - overall_start_time
|
308 |
+
final_msg = f"{raw_data_source_msg}. Processing by Gradio took {time.time() - proc_start:.2f}s. Total: {total_elapsed:.2f}s. Shape: {df.shape}"
|
309 |
+
print(final_msg)
|
310 |
+
return df, True, final_msg
|
311 |
+
|
312 |
+
|
313 |
+
# ... (make_treemap_data, create_treemap functions remain unchanged) ...
|
314 |
def make_treemap_data(df, count_by, top_k=25, tag_filter=None, pipeline_filter=None, size_filter=None, skip_orgs=None):
|
315 |
+
if df is None or df.empty: return pd.DataFrame()
|
|
|
316 |
filtered_df = df.copy()
|
317 |
+
col_map = { "Audio & Speech": "is_audio_speech", "Music": "has_music", "Robotics": "has_robot",
|
318 |
+
"Biomedical": "is_biomed", "Time series": "has_series", "Sciences": "has_science",
|
319 |
+
"Video": "has_video", "Images": "has_image", "Text": "has_text"}
|
320 |
|
321 |
+
if 'has_robot' in filtered_df.columns:
|
322 |
+
initial_robot_count = filtered_df['has_robot'].sum()
|
323 |
+
# print(f"DIAGNOSTIC (make_treemap_data entry): Input df has {initial_robot_count} 'has_robot' models.") # Can be noisy
|
324 |
+
# else:
|
325 |
+
# print("DIAGNOSTIC (make_treemap_data entry): 'has_robot' column NOT in input df.")
|
326 |
+
|
327 |
+
if tag_filter and tag_filter in col_map:
|
328 |
+
target_col = col_map[tag_filter]
|
329 |
+
if target_col in filtered_df.columns:
|
330 |
+
# if tag_filter == "Robotics":
|
331 |
+
# count_before_robot_filter = filtered_df[target_col].sum()
|
332 |
+
# print(f"DIAGNOSTIC (make_treemap_data): Applying 'Robotics' filter. Models with '{target_col}'=True: {count_before_robot_filter}")
|
333 |
+
filtered_df = filtered_df[filtered_df[target_col]]
|
334 |
+
# if tag_filter == "Robotics":
|
335 |
+
# print(f"DIAGNOSTIC (make_treemap_data): After 'Robotics' filter ({target_col}), df rows: {len(filtered_df)}")
|
336 |
+
else:
|
337 |
+
print(f"Warning: Tag filter column '{col_map[tag_filter]}' not found in DataFrame.")
|
338 |
if pipeline_filter:
|
339 |
+
if "pipeline_tag" in filtered_df.columns:
|
340 |
+
filtered_df = filtered_df[filtered_df["pipeline_tag"] == pipeline_filter]
|
341 |
+
else:
|
342 |
+
print(f"Warning: 'pipeline_tag' column not found for filtering.")
|
343 |
+
if size_filter and size_filter != "None" and size_filter in MODEL_SIZE_RANGES.keys():
|
344 |
+
if 'size_category' in filtered_df.columns:
|
345 |
+
filtered_df = filtered_df[filtered_df['size_category'] == size_filter]
|
346 |
+
else:
|
347 |
+
print("Warning: 'size_category' column not found for filtering.")
|
|
|
|
|
|
|
|
|
348 |
if skip_orgs and len(skip_orgs) > 0:
|
349 |
+
if "organization" in filtered_df.columns:
|
350 |
+
filtered_df = filtered_df[~filtered_df["organization"].isin(skip_orgs)]
|
351 |
+
else:
|
352 |
+
print("Warning: 'organization' column not found for filtering.")
|
353 |
+
if filtered_df.empty: return pd.DataFrame()
|
354 |
+
# Ensure count_by column is numeric, coercing if necessary
|
355 |
+
if count_by not in filtered_df.columns or not pd.api.types.is_numeric_dtype(filtered_df[count_by]):
|
356 |
+
# print(f"Warning: Column '{count_by}' for treemap values is not numeric or missing. Coercing to numeric, filling NaNs with 0.")
|
357 |
+
filtered_df[count_by] = pd.to_numeric(filtered_df.get(count_by), errors="coerce").fillna(0.0)
|
358 |
+
|
359 |
+
org_totals = filtered_df.groupby("organization")[count_by].sum().nlargest(top_k, keep='first') # Default keep='first'
|
360 |
+
top_orgs_list = org_totals.index.tolist()
|
361 |
+
treemap_data = filtered_df[filtered_df["organization"].isin(top_orgs_list)][["id", "organization", count_by]].copy()
|
362 |
+
treemap_data["root"] = "models" # For treemap structure
|
363 |
+
treemap_data[count_by] = pd.to_numeric(treemap_data[count_by], errors="coerce").fillna(0.0) # Ensure numeric again after subsetting
|
|
|
|
|
|
|
|
|
|
|
|
|
364 |
return treemap_data
|
365 |
|
366 |
def create_treemap(treemap_data, count_by, title=None):
|
|
|
367 |
if treemap_data.empty:
|
368 |
+
fig = px.treemap(names=["No data matches filters"], parents=[""], values=[1]) # Placeholder for empty data
|
369 |
+
fig.update_layout(title="No data matches the selected filters", margin=dict(t=50, l=25, r=25, b=25))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
370 |
return fig
|
|
|
|
|
371 |
fig = px.treemap(
|
372 |
+
treemap_data, path=["root", "organization", "id"], values=count_by,
|
|
|
|
|
373 |
title=title or f"HuggingFace Models - {count_by.capitalize()} by Organization",
|
374 |
+
color_discrete_sequence=px.colors.qualitative.Plotly # Example color sequence
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
375 |
)
|
376 |
+
fig.update_layout(margin=dict(t=50, l=25, r=25, b=25))
|
377 |
+
fig.update_traces(textinfo="label+value+percent root", hovertemplate="<b>%{label}</b><br>%{value:,} " + count_by + "<br>%{percentRoot:.2%} of total<extra></extra>")
|
378 |
return fig
|
379 |
|
380 |
+
# --- Gradio UI and Controllers ---
|
381 |
+
with gr.Blocks(title="HuggingFace Model Explorer") as demo:
|
382 |
+
models_data_state = gr.State(pd.DataFrame())
|
383 |
+
loading_complete_state = gr.State(False) # To control button interactivity
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
|
|
|
|
|
|
|
|
|
385 |
with gr.Row():
|
386 |
+
gr.Markdown("# HuggingFace Models TreeMap Visualization")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
387 |
with gr.Row():
|
388 |
+
with gr.Column(scale=1): # Controls column
|
389 |
+
count_by_dropdown = gr.Dropdown(label="Metric", choices=[("Downloads (last 30 days)", "downloads"), ("Downloads (All Time)", "downloadsAllTime"), ("Likes", "likes")], value="downloads")
|
390 |
+
filter_choice_radio = gr.Radio(label="Filter Type", choices=["None", "Tag Filter", "Pipeline Filter"], value="None")
|
391 |
+
tag_filter_dropdown = gr.Dropdown(label="Select Tag", choices=TAG_FILTER_CHOICES, value=None, visible=False)
|
392 |
+
pipeline_filter_dropdown = gr.Dropdown(label="Select Pipeline Tag", choices=PIPELINE_TAGS, value=None, visible=False)
|
393 |
+
size_filter_dropdown = gr.Dropdown(label="Model Size Filter", choices=["None"] + list(MODEL_SIZE_RANGES.keys()), value="None")
|
394 |
+
top_k_slider = gr.Slider(label="Number of Top Organizations", minimum=5, maximum=50, value=25, step=5)
|
395 |
+
skip_orgs_textbox = gr.Textbox(label="Organizations to Skip (comma-separated)", value="TheBloke,MaziyarPanahi,unsloth,modularai,Gensyn,bartowski") # Common large orgs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
396 |
|
397 |
+
generate_plot_button = gr.Button(value="Generate Plot", variant="primary", interactive=False) # Starts disabled
|
398 |
+
refresh_data_button = gr.Button(value="Refresh Data from Hugging Face", variant="secondary")
|
|
|
|
|
|
|
|
|
399 |
|
400 |
+
with gr.Column(scale=3): # Plot and info column
|
401 |
+
plot_output = gr.Plot()
|
402 |
+
status_message_md = gr.Markdown("Initializing...") # For general status updates
|
403 |
+
data_info_md = gr.Markdown("") # For detailed data stats
|
404 |
+
|
405 |
+
# Enable generate button only after data is loaded
|
406 |
+
def _update_button_interactivity(is_loaded_flag):
|
407 |
+
return gr.update(interactive=is_loaded_flag)
|
408 |
+
loading_complete_state.change(fn=_update_button_interactivity, inputs=loading_complete_state, outputs=generate_plot_button)
|
409 |
+
|
410 |
+
# Show/hide tag/pipeline filters based on radio choice
|
411 |
+
def _toggle_filters_visibility(choice):
|
412 |
+
return gr.update(visible=choice == "Tag Filter"), gr.update(visible=choice == "Pipeline Filter")
|
413 |
+
filter_choice_radio.change(fn=_toggle_filters_visibility, inputs=filter_choice_radio, outputs=[tag_filter_dropdown, pipeline_filter_dropdown])
|
414 |
+
|
415 |
+
|
416 |
+
def ui_load_data_controller(force_refresh_ui_trigger=False, progress=gr.Progress(track_tqdm=True)): # Gradio progress tracker
|
417 |
+
print(f"ui_load_data_controller called with force_refresh_ui_trigger={force_refresh_ui_trigger}")
|
418 |
+
status_msg_ui = "Loading data..."
|
419 |
+
data_info_text = ""
|
420 |
+
current_df = pd.DataFrame()
|
421 |
+
load_success_flag = False
|
422 |
+
data_as_of_date_display = "N/A"
|
423 |
+
|
424 |
+
try:
|
425 |
+
# Pass gr.Progress.tqdm to load_models_data if it's a Gradio call
|
426 |
+
current_df, load_success_flag, status_msg_from_load = load_models_data(
|
427 |
+
force_refresh=force_refresh_ui_trigger, tqdm_cls=progress.tqdm if progress else tqdm
|
428 |
)
|
429 |
|
430 |
+
if load_success_flag:
|
431 |
+
if force_refresh_ui_trigger: # Data was just fetched by Gradio
|
432 |
+
data_as_of_date_display = pd.Timestamp.now(tz='UTC').strftime('%B %d, %Y, %H:%M:%S %Z')
|
433 |
+
# If loaded from pre-processed parquet, check for its timestamp column
|
434 |
+
elif 'data_download_timestamp' in current_df.columns and not current_df.empty and pd.notna(current_df['data_download_timestamp'].iloc[0]):
|
435 |
+
timestamp_from_parquet = pd.to_datetime(current_df['data_download_timestamp'].iloc[0])
|
436 |
+
if timestamp_from_parquet.tzinfo is None: # If no timezone, assume UTC from preprocessor
|
437 |
+
timestamp_from_parquet = timestamp_from_parquet.tz_localize('UTC')
|
438 |
+
data_as_of_date_display = timestamp_from_parquet.strftime('%B %d, %Y, %H:%M:%S %Z')
|
439 |
+
else: # Pre-processed data but no timestamp column or it's NaT
|
440 |
+
data_as_of_date_display = "Pre-processed (date unavailable)"
|
441 |
|
442 |
+
# Build data info string
|
443 |
+
size_dist_lines = []
|
444 |
+
if 'size_category' in current_df.columns:
|
445 |
+
for cat in MODEL_SIZE_RANGES.keys():
|
446 |
+
count = (current_df['size_category'] == cat).sum()
|
447 |
+
size_dist_lines.append(f" - {cat}: {count:,} models")
|
448 |
+
else: size_dist_lines.append(" - Size category information not available.")
|
449 |
+
size_dist = "\n".join(size_dist_lines)
|
450 |
|
451 |
+
data_info_text = (f"### Data Information\n"
|
452 |
+
f"- Overall Status: {status_msg_from_load}\n"
|
453 |
+
f"- Total models loaded: {len(current_df):,}\n"
|
454 |
+
f"- Data as of: {data_as_of_date_display}\n"
|
455 |
+
f"- Size categories:\n{size_dist}")
|
456 |
+
|
457 |
+
if not current_df.empty and 'has_robot' in current_df.columns:
|
458 |
+
robot_true_count = current_df['has_robot'].sum()
|
459 |
+
data_info_text += f"\n- **Models flagged 'has_robot'**: {robot_true_count}"
|
460 |
+
if 0 < robot_true_count <= 10:
|
461 |
+
sample_robot_ids = current_df[current_df['has_robot']]['id'].head(5).tolist()
|
462 |
+
data_info_text += f"\n - Sample 'has_robot' model IDs: `{', '.join(sample_robot_ids)}`"
|
463 |
+
elif not current_df.empty:
|
464 |
+
data_info_text += "\n- **Models flagged 'has_robot'**: 'has_robot' column not found."
|
465 |
+
|
466 |
+
status_msg_ui = "Data loaded successfully. Ready to generate plot."
|
467 |
+
else: # load_success_flag is False
|
468 |
+
data_info_text = f"### Data Load Failed\n- {status_msg_from_load}"
|
469 |
+
status_msg_ui = status_msg_from_load # Pass error message from load_models_data
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
470 |
|
471 |
+
except Exception as e:
|
472 |
+
status_msg_ui = f"An unexpected error occurred in ui_load_data_controller: {str(e)}"
|
473 |
+
data_info_text = f"### Critical Error\n- {status_msg_ui}"
|
474 |
+
print(f"Critical error in ui_load_data_controller: {e}")
|
475 |
+
load_success_flag = False # Ensure this is false on error
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
476 |
|
477 |
+
return current_df, load_success_flag, data_info_text, status_msg_ui
|
478 |
|
479 |
+
def ui_generate_plot_controller(metric_choice, filter_type, tag_choice, pipeline_choice,
|
480 |
+
size_choice, k_orgs, skip_orgs_input, df_current_models):
|
481 |
+
if df_current_models is None or df_current_models.empty:
|
482 |
+
empty_fig = create_treemap(pd.DataFrame(), metric_choice, "Error: Model Data Not Loaded")
|
483 |
+
error_msg = "Model data is not loaded or is empty. Please load or refresh data first."
|
484 |
+
gr.Warning(error_msg) # Display a Gradio warning
|
485 |
+
return empty_fig, error_msg
|
486 |
+
|
487 |
+
tag_to_use = tag_choice if filter_type == "Tag Filter" else None
|
488 |
+
pipeline_to_use = pipeline_choice if filter_type == "Pipeline Filter" else None
|
489 |
+
size_to_use = size_choice if size_choice != "None" else None # Handle "None" string
|
490 |
+
orgs_to_skip = [org.strip() for org in skip_orgs_input.split(',') if org.strip()] if skip_orgs_input else []
|
491 |
+
|
492 |
+
# if 'has_robot' in df_current_models.columns:
|
493 |
+
# robot_count_before_treemap = df_current_models['has_robot'].sum()
|
494 |
+
# print(f"DIAGNOSTIC (ui_generate_plot_controller): df_current_models entering make_treemap_data has {robot_count_before_treemap} 'has_robot' models.")
|
495 |
+
|
496 |
+
treemap_df = make_treemap_data(df_current_models, metric_choice, k_orgs, tag_to_use, pipeline_to_use, size_to_use, orgs_to_skip)
|
497 |
|
498 |
+
title_labels = {"downloads": "Downloads (last 30 days)", "downloadsAllTime": "Downloads (All Time)", "likes": "Likes"}
|
499 |
+
chart_title = f"HuggingFace Models - {title_labels.get(metric_choice, metric_choice)} by Organization"
|
500 |
+
plotly_fig = create_treemap(treemap_df, metric_choice, chart_title)
|
501 |
+
|
502 |
+
if treemap_df.empty:
|
503 |
+
plot_stats_md = "No data matches the selected filters. Try adjusting your filters."
|
504 |
+
else:
|
505 |
+
total_items_in_plot = len(treemap_df['id'].unique()) # Count unique models in plot
|
506 |
+
total_value_in_plot = treemap_df[metric_choice].sum() # Sum of metric in plot
|
507 |
+
plot_stats_md = (f"## Plot Statistics\n- **Models shown**: {total_items_in_plot:,}\n- **Total {metric_choice}**: {int(total_value_in_plot):,}")
|
508 |
+
|
509 |
+
return plotly_fig, plot_stats_md
|
510 |
+
|
511 |
+
# --- Event Handlers ---
|
512 |
+
# Initial data load on app start
|
|
|
|
|
513 |
demo.load(
|
514 |
+
fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=False, progress=progress),
|
515 |
+
inputs=[], # No inputs for initial load
|
516 |
+
outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
|
517 |
+
)
|
518 |
+
|
519 |
+
# Refresh data button
|
520 |
+
refresh_data_button.click(
|
521 |
+
fn=lambda progress=gr.Progress(track_tqdm=True): ui_load_data_controller(force_refresh_ui_trigger=True, progress=progress),
|
522 |
inputs=[],
|
523 |
+
outputs=[models_data_state, loading_complete_state, data_info_md, status_message_md]
|
524 |
)
|
525 |
|
526 |
+
# Generate plot button
|
527 |
generate_plot_button.click(
|
528 |
+
fn=ui_generate_plot_controller,
|
529 |
+
inputs=[count_by_dropdown, filter_choice_radio, tag_filter_dropdown, pipeline_filter_dropdown,
|
530 |
+
size_filter_dropdown, top_k_slider, skip_orgs_textbox, models_data_state],
|
531 |
+
outputs=[plot_output, status_message_md] # Update plot and status message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
532 |
)
|
533 |
|
|
|
534 |
if __name__ == "__main__":
|
535 |
+
if not os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
536 |
+
print(f"WARNING: Pre-processed data file '{PROCESSED_PARQUET_FILE_PATH}' not found.")
|
537 |
+
print("It is highly recommended to run the preprocessing script (preprocess.py) first.")
|
538 |
+
else:
|
539 |
+
print(f"Found pre-processed data file: '{PROCESSED_PARQUET_FILE_PATH}'.")
|
540 |
+
demo.launch()
|
541 |
+
# --- END OF FILE app.py ---
|
models.csv → models_processed.parquet
RENAMED
@@ -1,3 +1,3 @@
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
-
oid sha256:
|
3 |
-
size
|
|
|
1 |
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:998afad6c0c4c64f9e98efd8609d1cbab1dd2ac281b9c2e023878ad436c2fbde
|
3 |
+
size 96033487
|
preprocess.py
ADDED
@@ -0,0 +1,371 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# --- START OF FILE preprocess.py ---
|
2 |
+
|
3 |
+
import pandas as pd
|
4 |
+
import numpy as np
|
5 |
+
import json
|
6 |
+
import ast
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
import time
|
9 |
+
import os
|
10 |
+
import duckdb
|
11 |
+
import re # Import re for the manual regex check in debug
|
12 |
+
|
13 |
+
# --- Constants ---
|
14 |
+
PROCESSED_PARQUET_FILE_PATH = "models_processed.parquet"
|
15 |
+
HF_PARQUET_URL = 'https://huggingface.co/datasets/cfahlgren1/hub-stats/resolve/main/models.parquet'
|
16 |
+
|
17 |
+
MODEL_SIZE_RANGES = {
|
18 |
+
"Small (<1GB)": (0, 1),
|
19 |
+
"Medium (1-5GB)": (1, 5),
|
20 |
+
"Large (5-20GB)": (5, 20),
|
21 |
+
"X-Large (20-50GB)": (20, 50),
|
22 |
+
"XX-Large (>50GB)": (50, float('inf'))
|
23 |
+
}
|
24 |
+
|
25 |
+
# --- Debugging Constant ---
|
26 |
+
# <<<<<<< SET THE MODEL ID YOU WANT TO DEBUG HERE >>>>>>>
|
27 |
+
MODEL_ID_TO_DEBUG = "openvla/openvla-7b"
|
28 |
+
# Example: MODEL_ID_TO_DEBUG = "openai-community/gpt2"
|
29 |
+
# If you don't have a specific ID, the debug block will just report it's not found.
|
30 |
+
|
31 |
+
# --- Utility Functions (extract_model_file_size_gb, extract_org_from_id, process_tags_for_series, get_file_size_category - unchanged from previous correct version) ---
|
32 |
+
def extract_model_file_size_gb(safetensors_data):
|
33 |
+
try:
|
34 |
+
if pd.isna(safetensors_data): return 0.0
|
35 |
+
data_to_parse = safetensors_data
|
36 |
+
if isinstance(safetensors_data, str):
|
37 |
+
try:
|
38 |
+
if (safetensors_data.startswith('{') and safetensors_data.endswith('}')) or \
|
39 |
+
(safetensors_data.startswith('[') and safetensors_data.endswith(']')):
|
40 |
+
data_to_parse = ast.literal_eval(safetensors_data)
|
41 |
+
else: data_to_parse = json.loads(safetensors_data)
|
42 |
+
except Exception: return 0.0
|
43 |
+
if isinstance(data_to_parse, dict) and 'total' in data_to_parse:
|
44 |
+
total_bytes_val = data_to_parse['total']
|
45 |
+
try:
|
46 |
+
size_bytes = float(total_bytes_val)
|
47 |
+
return size_bytes / (1024 * 1024 * 1024)
|
48 |
+
except (ValueError, TypeError): return 0.0
|
49 |
+
return 0.0
|
50 |
+
except Exception: return 0.0
|
51 |
+
|
52 |
+
def extract_org_from_id(model_id):
|
53 |
+
if pd.isna(model_id): return "unaffiliated"
|
54 |
+
model_id_str = str(model_id)
|
55 |
+
return model_id_str.split("/")[0] if "/" in model_id_str else "unaffiliated"
|
56 |
+
|
57 |
+
def process_tags_for_series(series_of_tags_values):
|
58 |
+
processed_tags_accumulator = []
|
59 |
+
|
60 |
+
for i, tags_value_from_series in enumerate(tqdm(series_of_tags_values, desc="Standardizing Tags", leave=False, unit="row")):
|
61 |
+
temp_processed_list_for_row = []
|
62 |
+
current_value_for_error_msg = str(tags_value_from_series)[:200] # Truncate for long error messages
|
63 |
+
|
64 |
+
try:
|
65 |
+
# Order of checks is important!
|
66 |
+
# 1. Handle explicit Python lists first
|
67 |
+
if isinstance(tags_value_from_series, list):
|
68 |
+
current_tags_in_list = []
|
69 |
+
for idx_tag, tag_item in enumerate(tags_value_from_series):
|
70 |
+
try:
|
71 |
+
# Ensure item is not NaN before string conversion if it might be a float NaN in a list
|
72 |
+
if pd.isna(tag_item): continue
|
73 |
+
str_tag = str(tag_item)
|
74 |
+
stripped_tag = str_tag.strip()
|
75 |
+
if stripped_tag:
|
76 |
+
current_tags_in_list.append(stripped_tag)
|
77 |
+
except Exception as e_inner_list_proc:
|
78 |
+
print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a list for row {i}. Error: {e_inner_list_proc}. Original list: {current_value_for_error_msg}")
|
79 |
+
temp_processed_list_for_row = current_tags_in_list
|
80 |
+
|
81 |
+
# 2. Handle NumPy arrays
|
82 |
+
elif isinstance(tags_value_from_series, np.ndarray):
|
83 |
+
# Convert to list, then process elements, handling potential NaNs within the array
|
84 |
+
current_tags_in_list = []
|
85 |
+
for idx_tag, tag_item in enumerate(tags_value_from_series.tolist()): # .tolist() is crucial
|
86 |
+
try:
|
87 |
+
if pd.isna(tag_item): continue # Check for NaN after converting to Python type
|
88 |
+
str_tag = str(tag_item)
|
89 |
+
stripped_tag = str_tag.strip()
|
90 |
+
if stripped_tag:
|
91 |
+
current_tags_in_list.append(stripped_tag)
|
92 |
+
except Exception as e_inner_array_proc:
|
93 |
+
print(f"ERROR processing item '{tag_item}' (type: {type(tag_item)}) within a NumPy array for row {i}. Error: {e_inner_array_proc}. Original array: {current_value_for_error_msg}")
|
94 |
+
temp_processed_list_for_row = current_tags_in_list
|
95 |
+
|
96 |
+
# 3. Handle simple None or pd.NA after lists and arrays (which might contain pd.NA elements handled above)
|
97 |
+
elif tags_value_from_series is None or pd.isna(tags_value_from_series): # Now pd.isna is safe for scalars
|
98 |
+
temp_processed_list_for_row = []
|
99 |
+
|
100 |
+
# 4. Handle strings (could be JSON-like, list-like, or comma-separated)
|
101 |
+
elif isinstance(tags_value_from_series, str):
|
102 |
+
processed_str_tags = []
|
103 |
+
# Attempt ast.literal_eval for strings that look like lists/tuples
|
104 |
+
if (tags_value_from_series.startswith('[') and tags_value_from_series.endswith(']')) or \
|
105 |
+
(tags_value_from_series.startswith('(') and tags_value_from_series.endswith(')')):
|
106 |
+
try:
|
107 |
+
evaluated_tags = ast.literal_eval(tags_value_from_series)
|
108 |
+
if isinstance(evaluated_tags, (list, tuple)): # Check if eval result is a list/tuple
|
109 |
+
# Recursively process this evaluated list/tuple, as its elements could be complex
|
110 |
+
# For simplicity here, assume elements are simple strings after eval
|
111 |
+
current_eval_list = []
|
112 |
+
for tag_item in evaluated_tags:
|
113 |
+
if pd.isna(tag_item): continue
|
114 |
+
str_tag = str(tag_item).strip()
|
115 |
+
if str_tag: current_eval_list.append(str_tag)
|
116 |
+
processed_str_tags = current_eval_list
|
117 |
+
except (ValueError, SyntaxError):
|
118 |
+
pass # If ast.literal_eval fails, let it fall to JSON or comma split
|
119 |
+
|
120 |
+
# If ast.literal_eval didn't populate, try JSON
|
121 |
+
if not processed_str_tags:
|
122 |
+
try:
|
123 |
+
json_tags = json.loads(tags_value_from_series)
|
124 |
+
if isinstance(json_tags, list):
|
125 |
+
# Similar to above, assume elements are simple strings after JSON parsing
|
126 |
+
current_json_list = []
|
127 |
+
for tag_item in json_tags:
|
128 |
+
if pd.isna(tag_item): continue
|
129 |
+
str_tag = str(tag_item).strip()
|
130 |
+
if str_tag: current_json_list.append(str_tag)
|
131 |
+
processed_str_tags = current_json_list
|
132 |
+
except json.JSONDecodeError:
|
133 |
+
# If not a valid JSON list, fall back to comma splitting as the final string strategy
|
134 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()]
|
135 |
+
except Exception as e_json_other:
|
136 |
+
print(f"ERROR during JSON processing for string '{current_value_for_error_msg}' for row {i}. Error: {e_json_other}")
|
137 |
+
processed_str_tags = [tag.strip() for tag in tags_value_from_series.split(',') if tag.strip()] # Fallback
|
138 |
+
|
139 |
+
temp_processed_list_for_row = processed_str_tags
|
140 |
+
|
141 |
+
# 5. Fallback for other scalar types (e.g., int, float that are not NaN)
|
142 |
+
else:
|
143 |
+
# This path is for non-list, non-ndarray, non-None/NaN, non-string types.
|
144 |
+
# Or for NaNs that slipped through if they are not None or pd.NA (e.g. float('nan'))
|
145 |
+
if pd.isna(tags_value_from_series): # Catch any remaining NaNs like float('nan')
|
146 |
+
temp_processed_list_for_row = []
|
147 |
+
else:
|
148 |
+
str_val = str(tags_value_from_series).strip()
|
149 |
+
temp_processed_list_for_row = [str_val] if str_val else []
|
150 |
+
|
151 |
+
processed_tags_accumulator.append(temp_processed_list_for_row)
|
152 |
+
|
153 |
+
except Exception as e_outer_tag_proc:
|
154 |
+
print(f"CRITICAL UNHANDLED ERROR processing row {i}: value '{current_value_for_error_msg}' (type: {type(tags_value_from_series)}). Error: {e_outer_tag_proc}. Appending [].")
|
155 |
+
processed_tags_accumulator.append([])
|
156 |
+
|
157 |
+
return processed_tags_accumulator
|
158 |
+
|
159 |
+
def get_file_size_category(file_size_gb_val):
|
160 |
+
try:
|
161 |
+
numeric_file_size_gb = float(file_size_gb_val)
|
162 |
+
if pd.isna(numeric_file_size_gb): numeric_file_size_gb = 0.0
|
163 |
+
except (ValueError, TypeError): numeric_file_size_gb = 0.0
|
164 |
+
if 0 <= numeric_file_size_gb < 1: return "Small (<1GB)"
|
165 |
+
elif 1 <= numeric_file_size_gb < 5: return "Medium (1-5GB)"
|
166 |
+
elif 5 <= numeric_file_size_gb < 20: return "Large (5-20GB)"
|
167 |
+
elif 20 <= numeric_file_size_gb < 50: return "X-Large (20-50GB)"
|
168 |
+
elif numeric_file_size_gb >= 50: return "XX-Large (>50GB)"
|
169 |
+
else: return "Small (<1GB)"
|
170 |
+
|
171 |
+
|
172 |
+
def main_preprocessor():
|
173 |
+
print(f"Starting pre-processing script. Output: '{PROCESSED_PARQUET_FILE_PATH}'.")
|
174 |
+
overall_start_time = time.time()
|
175 |
+
|
176 |
+
print(f"Fetching fresh data from Hugging Face: {HF_PARQUET_URL}")
|
177 |
+
try:
|
178 |
+
fetch_start_time = time.time()
|
179 |
+
query = f"SELECT * FROM read_parquet('{HF_PARQUET_URL}')"
|
180 |
+
df_raw = duckdb.sql(query).df()
|
181 |
+
data_download_timestamp = pd.Timestamp.now(tz='UTC')
|
182 |
+
|
183 |
+
if df_raw is None or df_raw.empty: raise ValueError("Fetched data is empty or None.")
|
184 |
+
if 'id' not in df_raw.columns: raise ValueError("Fetched data must contain 'id' column.")
|
185 |
+
|
186 |
+
print(f"Fetched data in {time.time() - fetch_start_time:.2f}s. Rows: {len(df_raw)}. Downloaded at: {data_download_timestamp.strftime('%Y-%m-%d %H:%M:%S %Z')}")
|
187 |
+
except Exception as e_fetch:
|
188 |
+
print(f"ERROR: Could not fetch data from Hugging Face: {e_fetch}.")
|
189 |
+
return
|
190 |
+
|
191 |
+
df = pd.DataFrame()
|
192 |
+
print("Processing raw data...")
|
193 |
+
proc_start = time.time()
|
194 |
+
|
195 |
+
expected_cols_setup = {
|
196 |
+
'id': str, 'downloads': float, 'downloadsAllTime': float, 'likes': float,
|
197 |
+
'pipeline_tag': str, 'tags': object, 'safetensors': object
|
198 |
+
}
|
199 |
+
for col_name, target_dtype in expected_cols_setup.items():
|
200 |
+
if col_name in df_raw.columns:
|
201 |
+
df[col_name] = df_raw[col_name]
|
202 |
+
if target_dtype == float: df[col_name] = pd.to_numeric(df[col_name], errors='coerce').fillna(0.0)
|
203 |
+
elif target_dtype == str: df[col_name] = df[col_name].astype(str).fillna('')
|
204 |
+
else:
|
205 |
+
if col_name in ['downloads', 'downloadsAllTime', 'likes']: df[col_name] = 0.0
|
206 |
+
elif col_name == 'pipeline_tag': df[col_name] = ''
|
207 |
+
elif col_name == 'tags': df[col_name] = pd.Series([[] for _ in range(len(df_raw))]) # Initialize with empty lists
|
208 |
+
elif col_name == 'safetensors': df[col_name] = None # Initialize with None
|
209 |
+
elif col_name == 'id': print("CRITICAL ERROR: 'id' column missing."); return
|
210 |
+
|
211 |
+
output_filesize_col_name = 'params'
|
212 |
+
if output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name]):
|
213 |
+
print(f"Using pre-existing '{output_filesize_col_name}' column as file size in GB.")
|
214 |
+
df[output_filesize_col_name] = pd.to_numeric(df_raw[output_filesize_col_name], errors='coerce').fillna(0.0)
|
215 |
+
elif 'safetensors' in df.columns:
|
216 |
+
print(f"Calculating '{output_filesize_col_name}' (file size in GB) from 'safetensors' data...")
|
217 |
+
df[output_filesize_col_name] = df['safetensors'].apply(extract_model_file_size_gb)
|
218 |
+
df[output_filesize_col_name] = pd.to_numeric(df[output_filesize_col_name], errors='coerce').fillna(0.0)
|
219 |
+
else:
|
220 |
+
print(f"Cannot determine file size. Setting '{output_filesize_col_name}' to 0.0.")
|
221 |
+
df[output_filesize_col_name] = 0.0
|
222 |
+
|
223 |
+
df['data_download_timestamp'] = data_download_timestamp
|
224 |
+
print(f"Added 'data_download_timestamp' column.")
|
225 |
+
|
226 |
+
print("Categorizing models by file size...")
|
227 |
+
df['size_category'] = df[output_filesize_col_name].apply(get_file_size_category)
|
228 |
+
|
229 |
+
print("Standardizing 'tags' column...")
|
230 |
+
df['tags'] = process_tags_for_series(df['tags']) # This now uses tqdm internally
|
231 |
+
|
232 |
+
# --- START DEBUGGING BLOCK ---
|
233 |
+
# This block will execute before the main tag processing loop
|
234 |
+
if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values: # Check if ID exists
|
235 |
+
print(f"\n--- Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---")
|
236 |
+
|
237 |
+
# 1. Check the 'tags' column content after process_tags_for_series
|
238 |
+
model_specific_tags_list = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]
|
239 |
+
print(f"1. Tags from df['tags'] (after process_tags_for_series): {model_specific_tags_list}")
|
240 |
+
print(f" Type of tags: {type(model_specific_tags_list)}")
|
241 |
+
if isinstance(model_specific_tags_list, list):
|
242 |
+
for i, tag_item in enumerate(model_specific_tags_list):
|
243 |
+
print(f" Tag item {i}: '{tag_item}' (type: {type(tag_item)}, len: {len(str(tag_item))})")
|
244 |
+
# Detailed check for 'robotics' specifically
|
245 |
+
if 'robotics' in str(tag_item).lower():
|
246 |
+
print(f" DEBUG: Found 'robotics' substring in '{tag_item}'")
|
247 |
+
print(f" - str(tag_item).lower().strip(): '{str(tag_item).lower().strip()}'")
|
248 |
+
print(f" - Is it exactly 'robotics'?: {str(tag_item).lower().strip() == 'robotics'}")
|
249 |
+
print(f" - Ordinals: {[ord(c) for c in str(tag_item)]}")
|
250 |
+
|
251 |
+
# 2. Simulate temp_tags_joined for this specific model
|
252 |
+
if isinstance(model_specific_tags_list, list):
|
253 |
+
simulated_temp_tags_joined = '~~~'.join(str(t).lower().strip() for t in model_specific_tags_list if pd.notna(t) and str(t).strip())
|
254 |
+
else:
|
255 |
+
simulated_temp_tags_joined = ''
|
256 |
+
print(f"2. Simulated 'temp_tags_joined' for this model: '{simulated_temp_tags_joined}'")
|
257 |
+
|
258 |
+
# 3. Simulate 'has_robot' check for this model
|
259 |
+
robot_keywords = ['robot', 'robotics']
|
260 |
+
robot_pattern = '|'.join(robot_keywords)
|
261 |
+
manual_robot_check = bool(re.search(robot_pattern, simulated_temp_tags_joined, flags=re.IGNORECASE))
|
262 |
+
print(f"3. Manual regex check for 'has_robot' ('{robot_pattern}' in '{simulated_temp_tags_joined}'): {manual_robot_check}")
|
263 |
+
print(f"--- End Pre-Loop Debugging for Model ID: {MODEL_ID_TO_DEBUG} ---\n")
|
264 |
+
elif MODEL_ID_TO_DEBUG:
|
265 |
+
print(f"DEBUG: Model ID '{MODEL_ID_TO_DEBUG}' not found in DataFrame for pre-loop debugging.")
|
266 |
+
# --- END DEBUGGING BLOCK ---
|
267 |
+
|
268 |
+
|
269 |
+
print("Vectorized creation of cached tag columns...")
|
270 |
+
tag_time = time.time()
|
271 |
+
# This is the original temp_tags_joined creation:
|
272 |
+
df['temp_tags_joined'] = df['tags'].apply(
|
273 |
+
lambda tl: '~~~'.join(str(t).lower().strip() for t in tl if pd.notna(t) and str(t).strip()) if isinstance(tl, list) else ''
|
274 |
+
)
|
275 |
+
|
276 |
+
tag_map = {
|
277 |
+
'has_audio': ['audio'], 'has_speech': ['speech'], 'has_music': ['music'],
|
278 |
+
'has_robot': ['robot', 'robotics','openvla','vla'],
|
279 |
+
'has_bio': ['bio'], 'has_med': ['medic', 'medical'],
|
280 |
+
'has_series': ['series', 'time-series', 'timeseries'],
|
281 |
+
'has_video': ['video'], 'has_image': ['image', 'vision'],
|
282 |
+
'has_text': ['text', 'nlp', 'llm']
|
283 |
+
}
|
284 |
+
for col, kws in tag_map.items():
|
285 |
+
pattern = '|'.join(kws)
|
286 |
+
df[col] = df['temp_tags_joined'].str.contains(pattern, na=False, case=False, regex=True)
|
287 |
+
|
288 |
+
df['has_science'] = (
|
289 |
+
df['temp_tags_joined'].str.contains('science', na=False, case=False, regex=True) &
|
290 |
+
~df['temp_tags_joined'].str.contains('bigscience', na=False, case=False, regex=True)
|
291 |
+
)
|
292 |
+
del df['temp_tags_joined'] # Clean up temporary column
|
293 |
+
df['is_audio_speech'] = (df['has_audio'] | df['has_speech'] |
|
294 |
+
df['pipeline_tag'].str.contains('audio|speech', case=False, na=False, regex=True))
|
295 |
+
df['is_biomed'] = df['has_bio'] | df['has_med']
|
296 |
+
print(f"Vectorized tag columns created in {time.time() - tag_time:.2f}s.")
|
297 |
+
|
298 |
+
# --- POST-LOOP DIAGNOSTIC for has_robot & a specific model ---
|
299 |
+
if 'has_robot' in df.columns:
|
300 |
+
print("\n--- 'has_robot' Diagnostics (Preprocessor - Post-Loop) ---")
|
301 |
+
print(df['has_robot'].value_counts(dropna=False))
|
302 |
+
|
303 |
+
if MODEL_ID_TO_DEBUG and MODEL_ID_TO_DEBUG in df['id'].values:
|
304 |
+
model_has_robot_val = df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'has_robot'].iloc[0]
|
305 |
+
print(f"Value of 'has_robot' for model '{MODEL_ID_TO_DEBUG}': {model_has_robot_val}")
|
306 |
+
if model_has_robot_val:
|
307 |
+
print(f" Original tags for '{MODEL_ID_TO_DEBUG}': {df.loc[df['id'] == MODEL_ID_TO_DEBUG, 'tags'].iloc[0]}")
|
308 |
+
|
309 |
+
if df['has_robot'].any():
|
310 |
+
print("Sample models flagged as 'has_robot':")
|
311 |
+
print(df[df['has_robot']][['id', 'tags', 'has_robot']].head(5))
|
312 |
+
else:
|
313 |
+
print("No models were flagged as 'has_robot' after processing.")
|
314 |
+
print("--------------------------------------------------------\n")
|
315 |
+
# --- END POST-LOOP DIAGNOSTIC ---
|
316 |
+
|
317 |
+
|
318 |
+
print("Adding organization column...")
|
319 |
+
df['organization'] = df['id'].apply(extract_org_from_id)
|
320 |
+
|
321 |
+
# Drop safetensors if params was calculated from it, and params didn't pre-exist as numeric
|
322 |
+
if 'safetensors' in df.columns and \
|
323 |
+
not (output_filesize_col_name in df_raw.columns and pd.api.types.is_numeric_dtype(df_raw[output_filesize_col_name])):
|
324 |
+
df = df.drop(columns=['safetensors'], errors='ignore')
|
325 |
+
|
326 |
+
final_expected_cols = [
|
327 |
+
'id', 'downloads', 'downloadsAllTime', 'likes', 'pipeline_tag', 'tags',
|
328 |
+
'params', 'size_category', 'organization',
|
329 |
+
'has_audio', 'has_speech', 'has_music', 'has_robot', 'has_bio', 'has_med',
|
330 |
+
'has_series', 'has_video', 'has_image', 'has_text', 'has_science',
|
331 |
+
'is_audio_speech', 'is_biomed',
|
332 |
+
'data_download_timestamp'
|
333 |
+
]
|
334 |
+
# Ensure all final columns exist, adding defaults if necessary
|
335 |
+
for col in final_expected_cols:
|
336 |
+
if col not in df.columns:
|
337 |
+
print(f"Warning: Final expected column '{col}' is missing! Defaulting appropriately.")
|
338 |
+
if col == 'params': df[col] = 0.0
|
339 |
+
elif col == 'size_category': df[col] = "Small (<1GB)" # Default size category
|
340 |
+
elif 'has_' in col or 'is_' in col : df[col] = False # Default boolean flags to False
|
341 |
+
elif col == 'data_download_timestamp': df[col] = pd.NaT # Default timestamp to NaT
|
342 |
+
|
343 |
+
print(f"Data processing completed in {time.time() - proc_start:.2f}s.")
|
344 |
+
try:
|
345 |
+
print(f"Saving processed data to: {PROCESSED_PARQUET_FILE_PATH}")
|
346 |
+
df_to_save = df[final_expected_cols].copy() # Ensure only expected columns are saved
|
347 |
+
df_to_save.to_parquet(PROCESSED_PARQUET_FILE_PATH, index=False, engine='pyarrow')
|
348 |
+
print(f"Successfully saved processed data.")
|
349 |
+
except Exception as e_save:
|
350 |
+
print(f"ERROR: Could not save processed data: {e_save}")
|
351 |
+
return
|
352 |
+
|
353 |
+
total_elapsed_script = time.time() - overall_start_time
|
354 |
+
print(f"Pre-processing finished. Total time: {total_elapsed_script:.2f}s. Final Parquet shape: {df_to_save.shape}")
|
355 |
+
|
356 |
+
if __name__ == "__main__":
|
357 |
+
if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
358 |
+
print(f"Deleting existing '{PROCESSED_PARQUET_FILE_PATH}' to ensure fresh processing...")
|
359 |
+
try: os.remove(PROCESSED_PARQUET_FILE_PATH)
|
360 |
+
except OSError as e: print(f"Error deleting file: {e}. Please delete manually and rerun."); exit()
|
361 |
+
|
362 |
+
main_preprocessor()
|
363 |
+
|
364 |
+
if os.path.exists(PROCESSED_PARQUET_FILE_PATH):
|
365 |
+
print(f"\nTo verify, load parquet and check 'has_robot' and its 'tags':")
|
366 |
+
print(f"import pandas as pd; df_chk = pd.read_parquet('{PROCESSED_PARQUET_FILE_PATH}')")
|
367 |
+
print(f"print(df_chk['has_robot'].value_counts())")
|
368 |
+
if MODEL_ID_TO_DEBUG:
|
369 |
+
print(f"print(df_chk[df_chk['id'] == '{MODEL_ID_TO_DEBUG}'][['id', 'tags', 'has_robot']])")
|
370 |
+
else:
|
371 |
+
print(f"print(df_chk[df_chk['has_robot']][['id', 'tags', 'has_robot']].head())")
|