File size: 3,163 Bytes
efb129d 1abd701 efb129d d01b7c9 1abd701 efb129d 1abd701 efb129d 066cc0b 1abd701 d01b7c9 182358f e3dfc76 efb129d 182358f 00bd88d d01b7c9 00bd88d d01b7c9 00bd88d 1abd701 d01b7c9 efb129d 182358f 00bd88d 1abd701 00bd88d 1abd701 efb129d 00bd88d 1abd701 6396265 182358f 6396265 efb129d 6396265 efb129d 6396265 182358f efb129d e3dfc76 efb129d 90c6104 182358f 1abd701 182358f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
# CodeSearch-ModernBERT-Owl Demo Space using CodeSearchNet Dataset
import gradio as gr
import torch
import random
from sentence_transformers import SentenceTransformer, util
from datasets import load_dataset
from spaces import GPU
import re
# --- Load model ---
# Instantiated once at import time; code_search_demo reuses this single
# module-level instance for every request.
model = SentenceTransformer("Shuu12121/CodeSearch-ModernBERT-Owl")
# Inference only: put the model in evaluation mode.
model.eval()
# --- Load CodeSearchNet dataset (test split only) ---
# NOTE(review): trust_remote_code=True presumably allows the dataset
# repository's own loading script to run on this machine — confirm the source
# is trusted (and consider pinning a revision).
dataset = load_dataset("code_x_glue_tc_nl_code_search_adv", trust_remote_code=True, split="test")
# One left-to-right scan over three kinds of token: triple-quoted strings,
# ordinary one-line string literals, and `#` comments. Ordinary strings get
# their own alternative so that a `#` inside a literal is never mistaken for
# the start of a comment, and so that quotes inside a comment never open a
# string.
_STRING_OR_COMMENT_RE = re.compile(
    r'(?P<triple>"""[\s\S]*?"""|' r"'''[\s\S]*?''')"
    r'|(?P<string>"(?:\\[\s\S]|[^"\\\n])*"|' r"'(?:\\[\s\S]|[^'\\\n])*')"
    r"|(?P<comment>#[^\n]*)"
)


def remove_comments_from_code(code: str) -> str:
    """Strip docstrings and ``#`` comments from a piece of Python source.

    The text is treated purely lexically (it is never parsed), so it does not
    have to be valid Python.

    * Every triple-quoted string is removed, whether or not it is a
      docstring. This includes triple-quoted strings used as values.
    * Every ``#`` comment is removed up to, but not including, the newline.
    * Ordinary single- and double-quoted literals are kept verbatim, even
      when they contain ``#`` or triple-quote characters.

    Args:
        code: Source text to clean.

    Returns:
        The source text with docstrings and comments deleted. Surrounding
        whitespace and newlines are left in place.
    """

    def _keep_plain_strings(match):
        # Only the "string" group survives; triple-quoted strings and
        # comments are replaced by nothing.
        return match.group("string") or ""

    return _STRING_OR_COMMENT_RE.sub(_keep_plain_strings, code)
# --- Query & Candidate Generator ---
def get_query_and_candidates(seed: int = 8520, num_negatives: int = 99):
    """Pick one query example and build a shuffled candidate list around it.

    A single dataset row is chosen as the query. Its comment-stripped code is
    the correct answer, and it is mixed with comment-stripped code from other
    randomly chosen rows.

    Args:
        seed: Seed for the random choices; the same seed yields the same
            query, negatives and shuffle order.
        num_negatives: Number of distractor snippets to draw. Clamped to the
            number of other rows available in the dataset.

    Returns:
        A ``(doc_str, correct_code, candidates)`` tuple: the query docstring,
        the comment-stripped correct snippet, and the shuffled list holding
        the correct snippet plus the negatives.

    Raises:
        ValueError: If the dataset is empty.
    """
    total = len(dataset)
    if total == 0:
        raise ValueError("dataset is empty; cannot pick a query example")

    # A private generator gives the same sequence as seeding the global one,
    # without mutating interpreter-wide random state shared by other callers.
    rng = random.Random(seed)

    idx = rng.randint(0, total - 1)
    query = dataset[idx]
    correct_code = remove_comments_from_code(query["code"])
    doc_str = query["docstring"]

    # Sample positions within the "every row except idx" pool and map them
    # back to dataset indices. This avoids materializing the whole dataset
    # on every call; the sampled positions depend only on the pool size, so
    # the chosen rows should match a sample drawn from the materialized pool.
    pool_size = total - 1
    wanted = max(0, min(num_negatives, pool_size))
    positions = rng.sample(range(pool_size), wanted)
    negative_indices = [pos if pos < idx else pos + 1 for pos in positions]

    candidates = [correct_code]
    candidates.extend(
        remove_comments_from_code(dataset[row]["code"]) for row in negative_indices
    )
    rng.shuffle(candidates)
    return doc_str, correct_code, candidates
@GPU
def code_search_demo(seed: int):
    """Run one docstring-to-code retrieval round and render it as Markdown.

    Builds a query and candidate list from ``seed``, embeds both with the
    module-level model, ranks the candidates by cosine similarity to the
    query docstring, and formats the ranking.

    Args:
        seed: Random seed forwarded to ``get_query_and_candidates``.

    Returns:
        A Markdown string containing the query docstring, whether the correct
        snippet is in the top-k, its reciprocal rank within the top-k, and
        every candidate in rank order with its similarity score.
    """
    doc_str, correct_code, candidates = get_query_and_candidates(seed)

    query_emb = model.encode(doc_str, convert_to_tensor=True)
    candidate_embeddings = model.encode(candidates, convert_to_tensor=True)
    # Convert to plain floats once, instead of sorting on 0-d tensors and
    # calling .item() separately for every candidate.
    scores = util.cos_sim(query_emb, candidate_embeddings)[0].tolist()

    # sorted() is stable, so candidates with equal scores keep their shuffled
    # relative order.
    results = sorted(zip(candidates, scores), key=lambda pair: pair[1], reverse=True)

    top_k = 10
    target = correct_code.strip()
    # 1-based rank of the first candidate whose text equals the correct
    # snippet; None if no candidate matches.
    rank = next(
        (pos for pos, (code, _) in enumerate(results, start=1) if code.strip() == target),
        None,
    )
    correct_in_top_k = rank is not None and rank <= top_k
    # The value is labelled "MRR@top_k", so a match ranked below top_k counts
    # as 0 rather than contributing 1/rank.
    mrr = 1.0 / rank if correct_in_top_k else 0.0

    parts = [
        f"### 🔍 Query Docstring\n\n{doc_str}\n\n",
        f"**✅ 正解は Top-{top_k} に含まれているか?**: {'🟢 Yes' if correct_in_top_k else '🔴 No'}\n\n",
        f"**📈 MRR@{top_k}**: {mrr:.4f}\n\n",
        "## 🏆 Top Matches:\n",
    ]

    medals = ("🥇", "🥈", "🥉")
    for i, (code, score) in enumerate(results):
        label = medals[i] if i < len(medals) else f"#{i+1}"
        snippet = code.strip()
        is_correct = "✅" if snippet == target else ""
        # Each snippet is capped at 1000 characters in the rendered output.
        parts.append(
            f"\n**{label}** - Similarity: {score:.4f} {is_correct}\n\n```python\n{snippet[:1000]}\n```\n"
        )
    return "".join(parts)
# --- Gradio UI ---
# Input: an integer seed slider; output: a Markdown panel showing the ranking.
_seed_input = gr.Slider(0, 100000, value=8520, step=1, label="Random Seed")
_result_output = gr.Markdown(label="Search Result")

# Wire the search function to the two components above.
demo = gr.Interface(
    fn=code_search_demo,
    inputs=_seed_input,
    outputs=_result_output,
    title="🔎 CodeSearch-ModernBERT-Owl🦉 Demo",
    description="docstring から類似 Python 関数を検索(CodeXGlue + ModernBERT-Owl)",
)

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()
|