File size: 3,163 Bytes
efb129d
1abd701
efb129d
 
 
 
 
d01b7c9
 
 
1abd701
efb129d
 
 
1abd701
efb129d
066cc0b
1abd701
d01b7c9
 
 
 
 
 
 
 
 
 
 
182358f
e3dfc76
efb129d
182358f
00bd88d
d01b7c9
00bd88d
 
 
d01b7c9
 
00bd88d
 
 
 
1abd701
d01b7c9
efb129d
182358f
00bd88d
1abd701
00bd88d
 
1abd701
efb129d
00bd88d
1abd701
6396265
 
 
 
 
 
 
 
182358f
6396265
 
efb129d
6396265
efb129d
 
 
6396265
 
 
 
 
182358f
 
efb129d
 
e3dfc76
efb129d
90c6104
182358f
1abd701
 
 
182358f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
# CodeSearch-ModernBERT-Owl Demo Space using CodeSearchNet Dataset
import io
import random
import re
import tokenize

import gradio as gr
import torch
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from spaces import GPU



# --- Load the sentence-embedding model once, at import time ---
model = SentenceTransformer(
    "Shuu12121/CodeSearch-ModernBERT-Owl",
)
model.eval()  # put the underlying torch module into inference (eval) mode

# --- Load the CodeXGLUE adversarial NL->code search data (test split only) ---
dataset = load_dataset(
    "code_x_glue_tc_nl_code_search_adv",
    split="test",
    trust_remote_code=True,
)

def remove_comments_from_code(code: str) -> str:
    """Strip triple-quoted strings and ``#`` comments from a Python snippet.

    Triple-quoted strings (docstrings included) are deleted first with a
    regex. ``#`` comments are then cut from their start to the end of the
    line. Comments are located with :mod:`tokenize`, so a ``#`` inside a
    string literal (e.g. ``"#fff"`` or a URL fragment) is kept. If the
    remaining text cannot be tokenized (unbalanced brackets, inconsistent
    dedent, ...), the function falls back to a plain regex that drops
    everything after any ``#`` on a line.

    Args:
        code: Python source text.

    Returns:
        The source with those spans removed. Line breaks are preserved.
        Whitespace that preceded a removed comment is left in place.
    """
    # Remove multi-line strings (docstrings included); presumably so the
    # docstring used as the query does not appear verbatim in the candidate.
    code = re.sub(r'"""[\s\S]*?"""', '', code)
    code = re.sub(r"'''[\s\S]*?'''", '', code)

    # 1-based line number -> column where a genuine comment token starts.
    try:
        comment_starts = {
            tok.start[0]: tok.start[1]
            for tok in tokenize.generate_tokens(io.StringIO(code).readline)
            if tok.type == tokenize.COMMENT
        }
    except (tokenize.TokenError, SyntaxError):
        # Not tokenizable: fall back to the naive "drop after #" rule.
        return re.sub(r'#.*', '', code)

    if not comment_starts:
        return code

    # Split lines exactly like the readline fed to tokenize so rows line up.
    lines = io.StringIO(code).readlines()
    for row, col in comment_starts.items():
        line = lines[row - 1]
        ending = line[len(line.rstrip("\r\n")):]  # keep the original newline
        lines[row - 1] = line[:col] + ending
    return "".join(lines)


# --- Query & Candidate Generator ---
def get_query_and_candidates(seed: int = 8520, num_negatives: int = 99):
    """Pick a query example and assemble a shuffled list of candidate snippets.

    Every random choice comes from a private ``random.Random(seed)``, so the
    same seed always yields the same query and candidate order. The global
    ``random`` state is left untouched. The draws (query index, negatives,
    shuffle) happen in the same order as after ``random.seed(seed)`` with the
    module-level functions.

    Args:
        seed: Seed for the local random generator.
        num_negatives: Number of other examples used as distractors; capped
            at ``len(dataset) - 1``.

    Returns:
        ``(doc_str, correct_code, candidates)``:

        - ``doc_str`` is the query example's ``"docstring"`` field.
        - ``correct_code`` is its ``"code"`` field passed through
          ``remove_comments_from_code``.
        - ``candidates`` is a shuffled list holding ``correct_code`` plus the
          cleaned code of the sampled negatives.

    Raises:
        ValueError: If ``dataset`` is empty.
    """
    rng = random.Random(seed)
    size = len(dataset)
    idx = rng.randint(0, size - 1)
    query = dataset[idx]
    correct_code = remove_comments_from_code(query["code"])
    doc_str = query["docstring"]

    # Draw positions from the virtual list "every row except idx" and shift
    # those at or past idx by one. random.sample just indexes into the
    # sequence, so this selects the same rows as sampling a materialized
    # list of the other examples, without decoding the whole split per call.
    k = min(num_negatives, size - 1)
    negative_indices = [
        pos if pos < idx else pos + 1
        for pos in rng.sample(range(size - 1), k=k)
    ]
    candidates = [correct_code] + [
        remove_comments_from_code(dataset[i]["code"]) for i in negative_indices
    ]
    rng.shuffle(candidates)

    return doc_str, correct_code, candidates



@GPU
def code_search_demo(seed: int):
    """Run one code-search episode and render the ranking as Markdown.

    The query docstring is embedded with ``model`` and compared by cosine
    similarity against every candidate returned by
    ``get_query_and_candidates(seed)``. Candidates are then listed from most
    to least similar.

    Args:
        seed: Forwarded to ``get_query_and_candidates`` to select the query
            and the candidate set.

    Returns:
        A Markdown string containing:

        - the query docstring;
        - whether the correct code is in the top 10;
        - MRR@10 (the reciprocal rank of the correct code, or 0 when it
          ranks below 10);
        - every candidate with its similarity score, each snippet truncated
          to 1000 characters.
    """
    doc_str, correct_code, candidates = get_query_and_candidates(seed)

    query_emb = model.encode(doc_str, convert_to_tensor=True)
    candidate_embeddings = model.encode(candidates, convert_to_tensor=True)

    # One score per candidate; plain floats sort and format without
    # 0-d tensor comparisons or per-item .item() calls.
    scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()
    results = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)

    top_k = 10
    target = correct_code.strip()

    # 1-based rank of the first candidate whose text equals the correct code.
    correct_rank = next(
        (rank for rank, (code, _) in enumerate(results, start=1) if code.strip() == target),
        None,
    )
    correct_in_top_k = correct_rank is not None and correct_rank <= top_k
    # MRR@k only credits a hit inside the cutoff; anything ranked lower scores 0.
    mrr = 1.0 / correct_rank if correct_in_top_k else 0.0

    parts = [
        f"### 🔍 Query Docstring\n\n{doc_str}\n\n",
        f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n",
        f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n",
        "## 🏆 Top Matches:\n",
    ]

    medals = ("🥇", "🥈", "🥉")
    for i, (code, score) in enumerate(results):
        label = medals[i] if i < len(medals) else f"#{i + 1}"
        is_correct = "✅" if code.strip() == target else ""
        parts.append(
            f"\n**{label}** - Similarity: {score:.4f} {is_correct}\n\n"
            f"```python\n{code.strip()[:1000]}\n```\n"
        )

    return "".join(parts)


# --- Gradio UI: one seed slider drives a single retrieval episode ---
demo = gr.Interface(
    fn=code_search_demo,
    inputs=gr.Slider(
        minimum=0,
        maximum=100000,
        value=8520,
        step=1,
        label="Random Seed",
    ),
    outputs=gr.Markdown(label="Search Result"),
    title="🔎 CodeSearch-ModernBERT-Owl🦉 Demo",
    description="docstring から類似 Python 関数を検索(CodeXGlue + ModernBERT-Owl)",
)

if __name__ == "__main__":
    # Start the Gradio server only when executed as a script.
    demo.launch()