import os, tempfile, time
import gradio as gr
from tool.testv3 import run_autotune_pipeline
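
# NOTE: the dict shape below is inferred from how generate_kernel consumes the
# events yielded by run_autotune_pipeline; the real dicts may carry extra keys.
#
#   {
#       "event":     "iteration_end" | ...,  # pipeline stage marker
#       "iteration": int | None,             # iteration index, when known
#       "message":   str,                    # short status line for the progress bar
#       "code":      str,                    # tuned kernel source, empty until available
#   }
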
# ---------- Core callback ----------
def generate_kernel(text_input, n_iters, progress=gr.Progress()):
"""
text_input : string from textbox (NL description or base CUDA code)
file_input : gr.File upload object (or None)
Returns : (kernel_code_str, downloadable_file_path)
"""
    progress((0, n_iters), desc="Initializing...")

    # 1) Validate the input
    if not text_input.strip():
        return "⚠️ Please paste a description or baseline CUDA code."

    # 2) Write the input to a temporary file for the pipeline
    td = tempfile.mkdtemp(prefix="auto_")
    src_path = os.path.join(td, f"input_{int(time.time())}.txt")
    with open(src_path, "w") as f:
        f.write(text_input)
    best_code = ""
    for info in run_autotune_pipeline(src_path, n_iters):
        # 3) Update the progress bar when an iteration finishes
        if info["event"] == "iteration_end" and info["iteration"] is not None:
            # print(f"Iteration {info['iteration']} / {n_iters}: {info['message']}")
            progress((info["iteration"], n_iters), desc=info["message"])
        # 4) Keep the most recent kernel code the pipeline emits
        if info["code"]:
            best_code = info["code"]

    # The final return fills the code output box
    return best_code
# ---------- Gradio UI ----------
with gr.Blocks(
    title="KernelPilot",
    theme=gr.themes.Soft(
        text_size="lg",
        font=[
            "system-ui",
            "-apple-system",
            "BlinkMacSystemFont",
            "Segoe UI",
            "Roboto",
            "Helvetica Neue",
            "Arial",
            "Noto Sans",
            "sans-serif",
        ],
    ),
) as demo:
    gr.Markdown(
        """# 🚀 KernelPilot
Enter a natural‑language description,
then click **Generate** to obtain the kernel function."""
    )
    with gr.Row():
        txt_input = gr.Textbox(
            label="📝 Input",
            lines=10,
            placeholder="Describe the kernel",
            scale=3
        )
        level = gr.Number(
            label="Optimization Level",
            minimum=1,
            maximum=5,
            value=2,
            step=1,
            scale=1
        )
    gen_btn = gr.Button("⚡ Generate")

    kernel_output = gr.Code(
        label="🎯 Tuned CUDA Kernel",
        language="cpp"
    )
    gen_btn.click(
        fn=generate_kernel,
        inputs=[txt_input, level],
        outputs=[kernel_output],
        queue=True,                      # keep requests queued
        show_progress="full",            # show the full progress bar
        show_progress_on=kernel_output   # render progress on the code output box
    )
if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1, max_size=50)
    demo.launch(server_name="0.0.0.0", server_port=7860)
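
# Usage sketch (assumes this file is saved as app.py and that gradio plus the
# local `tool` package are installed):
#   python app.py
# then open http://localhost:7860 in a browser.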