import os
import tempfile
import time

import gradio as gr

from tool.testv3 import run_autotune_pipeline

# ---------- Core callback ----------
def generate_kernel(text_input, n_iters, progress=gr.Progress()):
    """
    text_input : string from the textbox (NL description or baseline CUDA code)
    n_iters    : number of autotuning iterations to run
    Returns    : the tuned kernel code as a string
    """
    progress((0, n_iters), desc="Initializing...")

    # 1) Validate the input
    if not text_input.strip():
        return "⚠️ Please paste a description or baseline CUDA code."

    # 2) Write the input to a temporary file for the pipeline
    td = tempfile.mkdtemp(prefix="auto_")
    src_path = os.path.join(td, f"input_{int(time.time())}.txt")
    with open(src_path, "w") as f:
        f.write(text_input)

    # 3) Run the autotuning pipeline and stream progress to the UI
    best_code = ""
    for info in run_autotune_pipeline(src_path, n_iters):
        # Update the progress bar whenever an iteration finishes
        if info["event"] == "iteration_end" and info["iteration"] is not None:
            progress((info["iteration"], n_iters), desc=info["message"])
        # Keep the most recent kernel code emitted by the pipeline
        if info["code"]:
            best_code = info["code"]

    return best_code
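
# The loop above assumes run_autotune_pipeline(src_path, n_iters) is a generator
# yielding dicts with "event", "iteration", "message", and "code" keys; that is
# the interface the callback consumes. The sketch below is a hypothetical
# stand-in under that assumption (it is not wired into the UI); swapping it in
# for the real import lets the interface be exercised without tool.testv3.
def _mock_autotune_pipeline(src_path, n_iters):
    """Hypothetical stand-in mimicking the assumed yield schema of run_autotune_pipeline."""
    with open(src_path) as f:
        spec = f.read()
    for i in range(1, int(n_iters) + 1):
        yield {
            "event": "iteration_end",
            "iteration": i,
            "message": f"Finished iteration {i}/{int(n_iters)}",
            "code": f"// placeholder kernel after iteration {i}\n"
                    f"// input spec (truncated): {spec[:60]!r}",
        }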
# ---------- Gradio UI ---------- | |
with gr.Blocks(title="KernelPilot", theme=gr.themes.Soft(text_size="lg", font=[ | |
"system-ui", | |
"-apple-system", | |
"BlinkMacSystemFont", | |
"Segoe UI", | |
"Roboto", | |
"Helvetica Neue", | |
"Arial", | |
"Noto Sans", | |
"sans-serif" | |
])) as demo: | |
    gr.Markdown(
        """# 🚀 KernelPilot
Enter a natural-language description (or paste baseline CUDA code),
then click **Generate** to obtain the tuned kernel function."""
    )
    with gr.Row():
        txt_input = gr.Textbox(
            label="📝 Input",
            lines=10,
            placeholder="Describe the kernel",
            scale=3,
        )
        level = gr.Number(
            label="Optimization Level",
            minimum=1,
            maximum=5,
            value=2,
            step=1,
            scale=1,
        )
    gen_btn = gr.Button("⚡ Generate")

    kernel_output = gr.Code(
        label="🎯 Tuned CUDA Kernel",
        language="cpp",
    )
    gen_btn.click(
        fn=generate_kernel,
        inputs=[txt_input, level],
        outputs=[kernel_output],
        queue=True,                       # queue concurrent requests
        show_progress="full",             # show the progress bar
        show_progress_on=kernel_output,   # render progress on the output component
    )

if __name__ == "__main__":
    demo.queue(default_concurrency_limit=1, max_size=50)
    demo.launch(server_name="0.0.0.0", server_port=7860)
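    # Note: default_concurrency_limit=1 runs one generation at a time while
    # max_size=50 caps how many requests may wait in the queue; with
    # server_name="0.0.0.0" and server_port=7860 the app is reachable at
    # http://<host>:7860, the port Hugging Face Spaces expects by default.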