|
import gradio as gr |
|
import torch |
|
from phaseformer import PhaseFormerWrapper |
|
|
|
def run_phaseformer(mode: str, seq_len: float, batch_size: float, input_dim: float, t: float) -> str:
    """Run a freshly built PhaseFormerWrapper on random input and report the output shape.

    Args:
        mode: Model variant, forwarded as ``mode`` to ``PhaseFormerWrapper``.
        seq_len: First dimension of the random input tensor.
        batch_size: Second dimension of the random input tensor.
        input_dim: Third dimension of the random input tensor; also forwarded
            as ``input_dim`` to ``PhaseFormerWrapper``.
        t: Time-step value passed through unchanged as the model's second
            argument.

    Returns:
        A string of the form ``"Output shape: (…)"`` describing the shape of
        the model output.
    """
    # UI sliders may deliver floats (e.g. 2.0 or 2.3), but torch.randn only
    # accepts integer sizes, so coerce the dimension arguments up front.
    seq_len = int(seq_len)
    batch_size = int(batch_size)
    input_dim = int(input_dim)

    model = PhaseFormerWrapper(mode=mode, input_dim=input_dim)
    x = torch.randn(seq_len, batch_size, input_dim)

    # Only the output shape is reported, so gradient tracking is unnecessary.
    with torch.no_grad():
        out = model(x, t)

    return f"Output shape: {tuple(out.shape)}"
|
|
|
# Gradio UI: one radio button for the model variant plus sliders for the
# random-input dimensions and the time step. The three dimension sliders use
# step=1 so the UI only offers whole numbers (they end up as tensor sizes).
iface = gr.Interface(
    fn=run_phaseformer,
    inputs=[
        gr.Radio(["mlp", "transformer"], label="Select Model"),
        gr.Slider(1, 128, value=10, step=1, label="Sequence Length"),
        gr.Slider(1, 64, value=2, step=1, label="Batch Size"),
        gr.Slider(1, 512, value=64, step=1, label="Input Dimension"),
        gr.Slider(0.0, 10.0, value=5.0, step=0.1, label="Time Step (t)"),
    ],
    outputs="text",
    title="🧠 Perceive PhaseFormer Demo",
    description="""
Choose model mode and input specs. Outputs the shape of the result.
""",
)


# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()
|
|