Update app.py
Browse files
app.py
CHANGED
@@ -14,29 +14,32 @@ model.eval()
|
|
14 |
dataset = load_dataset("code_x_glue_tc_nl_code_search_adv", trust_remote_code=True, split="test")
|
15 |
|
16 |
# --- Query & Candidate Generator ---
|
17 |
-
def
|
18 |
random.seed(seed)
|
19 |
idx = random.randint(0, len(dataset) - 1)
|
20 |
-
|
21 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
|
23 |
@GPU
|
24 |
def code_search_demo(seed: int):
|
25 |
-
|
26 |
-
query_emb = model.encode(doc_str, convert_to_tensor=True)
|
27 |
|
28 |
-
|
29 |
-
|
30 |
-
correct_label = dataset[seed]["label"] # 正解 index(全体に対する)
|
31 |
-
correct_code = dataset[correct_label]["code"]
|
32 |
-
|
33 |
-
candidate_codes = [c["code"] for c in candidates]
|
34 |
-
candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)
|
35 |
|
36 |
cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
|
37 |
-
results = sorted(zip(
|
38 |
|
39 |
-
# 正解コードが Top-K に含まれているかを確認
|
40 |
top_k = 10
|
41 |
correct_in_top_k = any(code.strip() == correct_code.strip() for code, _ in results[:top_k])
|
42 |
mrr = 0.0
|
@@ -45,7 +48,6 @@ def code_search_demo(seed: int):
|
|
45 |
mrr = 1.0 / rank
|
46 |
break
|
47 |
|
48 |
-
# 出力構築
|
49 |
output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
|
50 |
output += f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n"
|
51 |
output += f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n"
|
@@ -60,8 +62,6 @@ def code_search_demo(seed: int):
|
|
60 |
return output
|
61 |
|
62 |
|
63 |
-
return output
|
64 |
-
|
65 |
# --- Gradio UI ---
|
66 |
demo = gr.Interface(
|
67 |
fn=code_search_demo,
|
|
|
14 |
dataset = load_dataset("code_x_glue_tc_nl_code_search_adv", trust_remote_code=True, split="test")
|
15 |
|
16 |
# --- Query & Candidate Generator ---
|
17 |
+
def get_query_and_candidates(seed: int = 42):
|
18 |
random.seed(seed)
|
19 |
idx = random.randint(0, len(dataset) - 1)
|
20 |
+
query = dataset[idx]
|
21 |
+
correct_code = query["code"]
|
22 |
+
doc_str = query["docstring"]
|
23 |
+
|
24 |
+
# 正例 + ランダム負例(正例を除く)
|
25 |
+
candidate_pool = [example for i, example in enumerate(dataset) if i != idx]
|
26 |
+
negatives = random.sample(candidate_pool, k=9) # 9件の負例
|
27 |
+
candidates = [correct_code] + [neg["code"] for neg in negatives]
|
28 |
+
random.shuffle(candidates)
|
29 |
+
|
30 |
+
return doc_str, correct_code, candidates
|
31 |
+
|
32 |
|
33 |
@GPU
|
34 |
def code_search_demo(seed: int):
|
35 |
+
doc_str, correct_code, candidates = get_query_and_candidates(seed)
|
|
|
36 |
|
37 |
+
query_emb = model.encode(doc_str, convert_to_tensor=True)
|
38 |
+
candidate_embeddings = model.encode(candidates, convert_to_tensor=True)
|
|
|
|
|
|
|
|
|
|
|
39 |
|
40 |
cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
|
41 |
+
results = sorted(zip(candidates, cos_scores), key=lambda x: x[1], reverse=True)
|
42 |
|
|
|
43 |
top_k = 10
|
44 |
correct_in_top_k = any(code.strip() == correct_code.strip() for code, _ in results[:top_k])
|
45 |
mrr = 0.0
|
|
|
48 |
mrr = 1.0 / rank
|
49 |
break
|
50 |
|
|
|
51 |
output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
|
52 |
output += f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n"
|
53 |
output += f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n"
|
|
|
62 |
return output
|
63 |
|
64 |
|
|
|
|
|
65 |
# --- Gradio UI ---
|
66 |
demo = gr.Interface(
|
67 |
fn=code_search_demo,
|