Update app.py
Browse files
app.py
CHANGED
@@ -25,22 +25,40 @@ def code_search_demo(seed: int):
|
|
25 |
code_str, doc_str = get_random_query(seed)
|
26 |
query_emb = model.encode(doc_str, convert_to_tensor=True)
|
27 |
|
28 |
-
# ランダムに10
|
29 |
candidates = dataset.shuffle(seed=seed).select(range(10))
|
|
|
|
|
|
|
30 |
candidate_codes = [c["code"] for c in candidates]
|
31 |
candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)
|
32 |
|
33 |
-
# 類似度スコア算出
|
34 |
cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
|
35 |
results = sorted(zip(candidate_codes, cos_scores), key=lambda x: x[1], reverse=True)
|
36 |
|
37 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
|
|
|
|
|
39 |
output += "## 🏆 Top Matches:\n"
|
|
|
40 |
medals = ["🥇", "🥈", "🥉"] + [f"#{i+1}" for i in range(3, len(results))]
|
41 |
for i, (code, score) in enumerate(results):
|
42 |
label = medals[i] if i < len(medals) else f"#{i+1}"
|
43 |
-
|
|
|
|
|
|
|
|
|
44 |
|
45 |
return output
|
46 |
|
|
|
25 |
code_str, doc_str = get_random_query(seed)
|
26 |
query_emb = model.encode(doc_str, convert_to_tensor=True)
|
27 |
|
28 |
+
# ランダムに10件取得し、正解 index を含めるようにする(※現実には全件評価がおすすめ)
|
29 |
candidates = dataset.shuffle(seed=seed).select(range(10))
|
30 |
+
correct_label = dataset[seed]["label"] # 正解 index(全体に対する)
|
31 |
+
correct_code = dataset[correct_label]["code"]
|
32 |
+
|
33 |
candidate_codes = [c["code"] for c in candidates]
|
34 |
candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)
|
35 |
|
|
|
36 |
cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
|
37 |
results = sorted(zip(candidate_codes, cos_scores), key=lambda x: x[1], reverse=True)
|
38 |
|
39 |
+
# 正解コードが Top-K に含まれているかを確認
|
40 |
+
top_k = 10
|
41 |
+
correct_in_top_k = any(code.strip() == correct_code.strip() for code, _ in results[:top_k])
|
42 |
+
mrr = 0.0
|
43 |
+
for rank, (code, _) in enumerate(results, start=1):
|
44 |
+
if code.strip() == correct_code.strip():
|
45 |
+
mrr = 1.0 / rank
|
46 |
+
break
|
47 |
+
|
48 |
+
# 出力構築
|
49 |
output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
|
50 |
+
output += f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n"
|
51 |
+
output += f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n"
|
52 |
output += "## 🏆 Top Matches:\n"
|
53 |
+
|
54 |
medals = ["🥇", "🥈", "🥉"] + [f"#{i+1}" for i in range(3, len(results))]
|
55 |
for i, (code, score) in enumerate(results):
|
56 |
label = medals[i] if i < len(medals) else f"#{i+1}"
|
57 |
+
is_correct = "✅" if code.strip() == correct_code.strip() else ""
|
58 |
+
output += f"\n**{label}** - Similarity: {score.item():.4f} {is_correct}\n\n```python\n{code.strip()[:1000]}\n```\n"
|
59 |
+
|
60 |
+
return output
|
61 |
+
|
62 |
|
63 |
return output
|
64 |
|