import os
import tempfile
import time

import gradio as gr

from tool.testv3 import run_autotune_pipeline


# ---------- Core callback ----------
def generate_kernel(text_input, n_iters, progress=gr.Progress()):
    """Run the autotuning pipeline on the user's input and return the kernel code.

    Args:
        text_input: Contents of the input textbox (a natural-language
            description or baseline CUDA code). May be empty or None.
        n_iters: Value of the "Optimization Level" number field; passed to
            ``run_autotune_pipeline`` as its iteration count and used as the
            progress-bar total.
        progress: Gradio progress tracker, injected by Gradio at call time.

    Returns:
        A single string for the single ``gr.Code`` output. This is the last
        non-empty ``code`` value yielded by the pipeline ("" if none was
        yielded), or a warning message when the inputs are invalid.
    """
    # Validate first. The click handler has exactly one output component,
    # so every return path must produce exactly one string.
    if not text_input or not text_input.strip():
        return "⚠️ Please paste a description or baseline CUDA code."

    # gr.Number delivers a float (or None if the field was cleared), but an
    # iteration count must be an integer.
    try:
        n_iters = int(n_iters)
    except (TypeError, ValueError, OverflowError):
        return "⚠️ Please enter a valid optimization level."

    progress((0, n_iters), desc="Initializing...")

    best_code = ""
    # The pipeline takes a file path, so stage the input in a private temp
    # dir. The context manager deletes it once the pipeline is exhausted,
    # instead of leaking one directory per request.
    with tempfile.TemporaryDirectory(prefix="auto_") as td:
        src_path = os.path.join(td, f"input_{int(time.time())}.txt")
        with open(src_path, "w", encoding="utf-8") as f:
            f.write(text_input)

        for info in run_autotune_pipeline(src_path, n_iters):
            # Advance the progress bar when an iteration with a known index ends.
            if info["event"] == "iteration_end" and info["iteration"] is not None:
                progress((info["iteration"], n_iters), desc=info["message"])

            # Keep the most recent non-empty code reported by the pipeline.
            if info["code"]:
                best_code = info["code"]

    return best_code


# ---------- Gradio UI ----------
with gr.Blocks(
    title="KernelPilot",
    theme=gr.themes.Soft(
        text_size="lg",
        font=[
            "system-ui",
            "-apple-system",
            "BlinkMacSystemFont",
            "Segoe UI",
            "Roboto",
            "Helvetica Neue",
            "Arial",
            "Noto Sans",
            "sans-serif",
        ],
    ),
) as demo:
    gr.Markdown(
        """# 🚀 KernelPilot
Enter a natural‑language description, then click **Generate** to obtain the kernel function."""
    )
    with gr.Row():
        txt_input = gr.Textbox(
            label="📝 Input",
            lines=10,
            placeholder="Describe the kernel",
            scale=3,
        )
        level = gr.Number(
            label="Optimization Level",
            minimum=1,
            maximum=5,
            value=2,
            step=1,
            precision=0,  # round and deliver an int, not a float
            scale=1,
        )
    gen_btn = gr.Button("⚡ Generate")
    kernel_output = gr.Code(
        label="🎯 Tuned CUDA Kernel",
        language="cpp",
    )

    gen_btn.click(
        fn=generate_kernel,
        inputs=[txt_input, level],
        outputs=[kernel_output],
        queue=True,  # route through the request queue configured below
        show_progress="full",  # documented literal (same as the default)
        show_progress_on=kernel_output,  # draw the progress overlay on the code box
    )

if __name__ == "__main__":
    # One pipeline run at a time; at most 50 requests waiting.
    demo.queue(default_concurrency_limit=1, max_size=50)
    demo.launch(server_name="0.0.0.0", server_port=7860)