Shuu12121 committed on
Commit
6396265
·
verified ·
1 Parent(s): 066cc0b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -4
app.py CHANGED
@@ -25,22 +25,40 @@ def code_search_demo(seed: int):
25
  code_str, doc_str = get_random_query(seed)
26
  query_emb = model.encode(doc_str, convert_to_tensor=True)
27
 
28
- # ランダムに10件取得
29
  candidates = dataset.shuffle(seed=seed).select(range(10))
 
 
 
30
  candidate_codes = [c["code"] for c in candidates]
31
  candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)
32
 
33
- # 類似度スコア算出
34
  cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
35
  results = sorted(zip(candidate_codes, cos_scores), key=lambda x: x[1], reverse=True)
36
 
37
- # 結果出力
 
 
 
 
 
 
 
 
 
38
  output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
 
 
39
  output += "## 🏆 Top Matches:\n"
 
40
  medals = ["🥇", "🥈", "🥉"] + [f"#{i+1}" for i in range(3, len(results))]
41
  for i, (code, score) in enumerate(results):
42
  label = medals[i] if i < len(medals) else f"#{i+1}"
43
- output += f"\n**{label}** - Similarity: {score.item():.4f}\n\n```python\n{code.strip()[:1000]}\n```\n"
 
 
 
 
44
 
45
  return output
46
 
 
25
  code_str, doc_str = get_random_query(seed)
26
  query_emb = model.encode(doc_str, convert_to_tensor=True)
27
 
28
+ # ランダムに10件取得し、正解 index を含めるようにする(※現実には全件評価がおすすめ)
29
  candidates = dataset.shuffle(seed=seed).select(range(10))
30
+ correct_label = dataset[seed]["label"] # 正解 index(全体に対する)
31
+ correct_code = dataset[correct_label]["code"]
32
+
33
  candidate_codes = [c["code"] for c in candidates]
34
  candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)
35
 
 
36
  cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
37
  results = sorted(zip(candidate_codes, cos_scores), key=lambda x: x[1], reverse=True)
38
 
39
+ # 正解コードが Top-K に含まれているかを確認
40
+ top_k = 10
41
+ correct_in_top_k = any(code.strip() == correct_code.strip() for code, _ in results[:top_k])
42
+ mrr = 0.0
43
+ for rank, (code, _) in enumerate(results, start=1):
44
+ if code.strip() == correct_code.strip():
45
+ mrr = 1.0 / rank
46
+ break
47
+
48
+ # 出力構築
49
  output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
50
+ output += f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n"
51
+ output += f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n"
52
  output += "## 🏆 Top Matches:\n"
53
+
54
  medals = ["🥇", "🥈", "🥉"] + [f"#{i+1}" for i in range(3, len(results))]
55
  for i, (code, score) in enumerate(results):
56
  label = medals[i] if i < len(medals) else f"#{i+1}"
57
+ is_correct = "✅" if code.strip() == correct_code.strip() else ""
58
+ output += f"\n**{label}** - Similarity: {score.item():.4f} {is_correct}\n\n```python\n{code.strip()[:1000]}\n```\n"
59
+
60
+ return output
61
+
62
 
63
  return output
64