yhamidullah commited on
Commit
f51abca
·
verified ·
1 Parent(s): f787795

create app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -0
app.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import gradio as gr
3
+ from transformers import AutoConfig
4
+ from your_model_code import CustomClassifier
5
+
6
+ MODEL_ID = "your-username/custom-classifier-demo"
7
+ config = AutoConfig.from_pretrained(MODEL_ID)
8
+ model = CustomClassifier.from_pretrained(MODEL_ID)
9
+ model.eval()
10
+
11
+ def predict(input_csv: str):
12
+ vec = [float(x) for x in input_csv.split(",")]
13
+ if len(vec) != config.input_dim:
14
+ return f"Error: Need {config.input_dim} floats"
15
+ x = torch.tensor([vec])
16
+ with torch.no_grad():
17
+ logits = model(input_ids=x)["logits"]
18
+ pred = logits.argmax(dim=-1).item()
19
+ return f"Predicted class: {pred}"
20
+
21
+ demo = gr.Interface(
22
+ fn=predict,
23
+ inputs=gr.Textbox(label="Input Vector (comma-separated)"),
24
+ outputs="text",
25
+ title="Custom Classifier Demo",
26
+ )
27
+ demo.launch()