Shuu12121 commited on
Commit
00bd88d
·
verified ·
1 Parent(s): 6396265

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -17
app.py CHANGED
@@ -14,29 +14,32 @@ model.eval()
14
  dataset = load_dataset("code_x_glue_tc_nl_code_search_adv", trust_remote_code=True, split="test")
15
 
16
  # --- Query & Candidate Generator ---
17
- def get_random_query(seed: int = 42):
18
  random.seed(seed)
19
  idx = random.randint(0, len(dataset) - 1)
20
- sample = dataset[idx]
21
- return sample["code"], sample["docstring"]
 
 
 
 
 
 
 
 
 
 
22
 
23
  @GPU
24
  def code_search_demo(seed: int):
25
- code_str, doc_str = get_random_query(seed)
26
- query_emb = model.encode(doc_str, convert_to_tensor=True)
27
 
28
- # ランダムに10件取得し、正解 index を含めるようにする(※現実には全件評価がおすすめ)
29
- candidates = dataset.shuffle(seed=seed).select(range(10))
30
- correct_label = dataset[seed]["label"] # 正解 index(全体に対する)
31
- correct_code = dataset[correct_label]["code"]
32
-
33
- candidate_codes = [c["code"] for c in candidates]
34
- candidate_embeddings = model.encode(candidate_codes, convert_to_tensor=True)
35
 
36
  cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
37
- results = sorted(zip(candidate_codes, cos_scores), key=lambda x: x[1], reverse=True)
38
 
39
- # 正解コードが Top-K に含まれているかを確認
40
  top_k = 10
41
  correct_in_top_k = any(code.strip() == correct_code.strip() for code, _ in results[:top_k])
42
  mrr = 0.0
@@ -45,7 +48,6 @@ def code_search_demo(seed: int):
45
  mrr = 1.0 / rank
46
  break
47
 
48
- # 出力構築
49
  output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
50
  output += f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n"
51
  output += f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n"
@@ -60,8 +62,6 @@ def code_search_demo(seed: int):
60
  return output
61
 
62
 
63
- return output
64
-
65
  # --- Gradio UI ---
66
  demo = gr.Interface(
67
  fn=code_search_demo,
 
14
  dataset = load_dataset("code_x_glue_tc_nl_code_search_adv", trust_remote_code=True, split="test")
15
 
16
  # --- Query & Candidate Generator ---
17
+ def get_query_and_candidates(seed: int = 42):
18
  random.seed(seed)
19
  idx = random.randint(0, len(dataset) - 1)
20
+ query = dataset[idx]
21
+ correct_code = query["code"]
22
+ doc_str = query["docstring"]
23
+
24
+ # 正例 + ランダム負例(正例を除く)
25
+ candidate_pool = [example for i, example in enumerate(dataset) if i != idx]
26
+ negatives = random.sample(candidate_pool, k=9) # 9件の負例
27
+ candidates = [correct_code] + [neg["code"] for neg in negatives]
28
+ random.shuffle(candidates)
29
+
30
+ return doc_str, correct_code, candidates
31
+
32
 
33
  @GPU
34
  def code_search_demo(seed: int):
35
+ doc_str, correct_code, candidates = get_query_and_candidates(seed)
 
36
 
37
+ query_emb = model.encode(doc_str, convert_to_tensor=True)
38
+ candidate_embeddings = model.encode(candidates, convert_to_tensor=True)
 
 
 
 
 
39
 
40
  cos_scores = util.cos_sim(query_emb, candidate_embeddings)[0]
41
+ results = sorted(zip(candidates, cos_scores), key=lambda x: x[1], reverse=True)
42
 
 
43
  top_k = 10
44
  correct_in_top_k = any(code.strip() == correct_code.strip() for code, _ in results[:top_k])
45
  mrr = 0.0
 
48
  mrr = 1.0 / rank
49
  break
50
 
 
51
  output = f"### 🔍 Query Docstring\n\n{doc_str}\n\n"
52
  output += f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n"
53
  output += f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n"
 
62
  return output
63
 
64
 
 
 
65
  # --- Gradio UI ---
66
  demo = gr.Interface(
67
  fn=code_search_demo,