Spaces:
Sleeping
Sleeping
Upload viz_utils.py
Browse files- viz_utils.py +20 -8
viz_utils.py
CHANGED
|
@@ -5,36 +5,48 @@ from collections import defaultdict
|
|
| 5 |
import random
|
| 6 |
|
| 7 |
def color_for_label(label):
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
return f"rgb({random.randint(50,200)}, {random.randint(50,200)}, {random.randint(50,200)})"
|
| 10 |
|
| 11 |
def generate_force_graph(sentences, labels):
|
| 12 |
nodes = []
|
| 13 |
links = []
|
| 14 |
-
label_map =
|
|
|
|
| 15 |
for i, (s, l) in enumerate(zip(sentences, labels)):
|
| 16 |
color = color_for_label(l)
|
| 17 |
-
nodes.append({"name": s, "symbolSize": 10, "category": int(l), "itemStyle": {"color": color}})
|
| 18 |
-
label_map
|
| 19 |
|
| 20 |
for group in label_map.values():
|
|
|
|
|
|
|
| 21 |
for i in group:
|
|
|
|
| 22 |
for j in group:
|
| 23 |
if i < j:
|
| 24 |
links.append({"source": sentences[i], "target": sentences[j]})
|
|
|
|
|
|
|
|
|
|
| 25 |
return {"type": "force", "nodes": nodes, "links": links}
|
| 26 |
|
| 27 |
def generate_bubble_chart(sentences, labels):
|
| 28 |
counts = defaultdict(int)
|
| 29 |
for l in labels:
|
| 30 |
counts[l] += 1
|
| 31 |
-
data = [{"name": f"簇{l}", "value": v, "itemStyle": {"color": color_for_label(l)}} for l, v in counts.items()]
|
| 32 |
return {"type": "bubble", "series": [{"type": "scatter", "data": data}]}
|
| 33 |
|
| 34 |
def generate_umap_plot(embeddings, labels):
|
| 35 |
-
reducer = umap.UMAP(n_components=2)
|
| 36 |
umap_emb = reducer.fit_transform(embeddings)
|
| 37 |
scaled = MinMaxScaler().fit_transform(umap_emb)
|
| 38 |
-
data = [{"x": float(x), "y": float(y), "label": int(l), "itemStyle": {"color": color_for_label(l)}}
|
| 39 |
-
for (x, y), l in zip(scaled, labels)]
|
| 40 |
return {"type": "scatter", "series": [{"data": data}]}
|
|
|
|
| 5 |
import random
|
| 6 |
|
| 7 |
def color_for_label(label):
|
| 8 |
+
try:
|
| 9 |
+
label_int = int(label)
|
| 10 |
+
except:
|
| 11 |
+
label_int = -1
|
| 12 |
+
if label_int < 0:
|
| 13 |
+
return "rgb(150,150,150)" # 噪声点(-1)用灰色
|
| 14 |
+
random.seed(label_int + 1000)
|
| 15 |
return f"rgb({random.randint(50,200)}, {random.randint(50,200)}, {random.randint(50,200)})"
|
| 16 |
|
| 17 |
def generate_force_graph(sentences, labels):
|
| 18 |
nodes = []
|
| 19 |
links = []
|
| 20 |
+
label_map = defaultdict(list)
|
| 21 |
+
|
| 22 |
for i, (s, l) in enumerate(zip(sentences, labels)):
|
| 23 |
color = color_for_label(l)
|
| 24 |
+
nodes.append({"name": s, "symbolSize": 10, "category": int(l) if l >=0 else 0, "itemStyle": {"color": color}})
|
| 25 |
+
label_map[l].append(i)
|
| 26 |
|
| 27 |
for group in label_map.values():
|
| 28 |
+
# 可选:限制边数,避免边太多
|
| 29 |
+
max_edges_per_node = 10
|
| 30 |
for i in group:
|
| 31 |
+
connected = 0
|
| 32 |
for j in group:
|
| 33 |
if i < j:
|
| 34 |
links.append({"source": sentences[i], "target": sentences[j]})
|
| 35 |
+
connected += 1
|
| 36 |
+
if connected >= max_edges_per_node:
|
| 37 |
+
break
|
| 38 |
return {"type": "force", "nodes": nodes, "links": links}
|
| 39 |
|
| 40 |
def generate_bubble_chart(sentences, labels):
|
| 41 |
counts = defaultdict(int)
|
| 42 |
for l in labels:
|
| 43 |
counts[l] += 1
|
| 44 |
+
data = [{"name": f"簇{l}" if l >=0 else "噪声", "value": v, "itemStyle": {"color": color_for_label(l)}} for l, v in counts.items()]
|
| 45 |
return {"type": "bubble", "series": [{"type": "scatter", "data": data}]}
|
| 46 |
|
| 47 |
def generate_umap_plot(embeddings, labels):
|
| 48 |
+
reducer = umap.UMAP(n_components=2, random_state=42)
|
| 49 |
umap_emb = reducer.fit_transform(embeddings)
|
| 50 |
scaled = MinMaxScaler().fit_transform(umap_emb)
|
| 51 |
+
data = [{"x": float(x), "y": float(y), "label": int(l), "itemStyle": {"color": color_for_label(l)}} for (x, y), l in zip(scaled, labels)]
|
|
|
|
| 52 |
return {"type": "scatter", "series": [{"data": data}]}
|