"""Gradio demo that classifies a comma-separated vector with CustomClassifier.

The config and weights are fetched with ``from_pretrained(MODEL_ID)`` when this
module is loaded, and the web UI is launched at the end of the module.
"""

import torch
import gradio as gr
from transformers import AutoConfig  # noqa: F401  (not used below; kept as-is)

from models import CustomClassifier, CustomClassificationConfig

# Repository id passed to both ``from_pretrained`` calls below.
MODEL_ID = "yhamidullah/custom-classifier-demo"

config = CustomClassificationConfig.from_pretrained(MODEL_ID)
model = CustomClassifier.from_pretrained(MODEL_ID)
# Inference only: put the module in evaluation mode.
model.eval()


def _parse_vector(input_csv):
    """Split a comma-separated string into a list of floats.

    Raises:
        ValueError: if the text is empty/blank or any item is not a valid float.
    """
    # Gradio may hand over None for an untouched textbox; treat it as empty.
    text = (input_csv or "").strip()
    if not text:
        raise ValueError("empty input")
    # float() itself tolerates surrounding whitespace, e.g. " 1.5".
    return [float(item) for item in text.split(",")]


def predict(input_csv: str) -> str:
    """Classify a comma-separated vector and describe the predicted class.

    Args:
        input_csv: Text such as ``"0.1, 2, -3.5"``; must contain exactly
            ``config.input_dim`` numbers.

    Returns:
        ``"Predicted class: <index>"`` on success, otherwise a message that
        starts with ``"Error:"``. Problems are returned rather than raised so
        they are shown in the output textbox instead of a generic UI error.
    """
    try:
        vec = _parse_vector(input_csv)
    except ValueError:
        return (
            f"Error: could not parse input; need {config.input_dim} "
            "comma-separated floats"
        )
    if len(vec) != config.input_dim:
        return f"Error: Need {config.input_dim} floats"

    # Wrap in a list to add a batch dimension of size 1.
    x = torch.tensor([vec])
    with torch.no_grad():
        # NOTE(review): the float vector is passed as `input_ids`; this
        # relies on CustomClassifier's forward signature — confirm in models.py.
        logits = model(input_ids=x)["logits"]
    pred = logits.argmax(dim=-1).item()
    return f"Predicted class: {pred}"


demo = gr.Interface(
    fn=predict,
    inputs=gr.Textbox(label="Input Vector (comma-separated)"),
    outputs="text",
    title="Custom Classifier Demo",
)
demo.launch()