vikramronavrsc commited on
Commit
58ed1c6
·
verified ·
1 Parent(s): c59fc20

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +25 -38
app.py CHANGED
@@ -1,48 +1,35 @@
1
- import gradio as gr
2
- from transformers import AutoTokenizer, AutoModelForSequenceClassification
 
3
  import torch
 
 
 
 
 
 
4
 
5
- # Load model and tokenizer from local directory
6
- MODEL_PATH = "./saved_model1"
7
- tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH)
8
- model = AutoModelForSequenceClassification.from_pretrained(MODEL_PATH)
9
 
10
- def predict(text):
11
- # Preprocess input
12
- inputs = tokenizer(
13
- text,
14
- padding=True,
15
- truncation=True,
16
- max_length=512,
17
- return_tensors="pt"
18
- )
19
-
20
- # Inference
21
  with torch.no_grad():
22
  outputs = model(**inputs)
23
-
24
- # Postprocess output (modify based on your task)
25
- probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
26
- predicted_class = torch.argmax(probabilities).item()
27
-
28
  return {
29
- "text": text,
30
- "predicted_class": predicted_class,
31
- "probabilities": probabilities.tolist()[0]
32
  }
33
 
34
- # Create Gradio interface
35
- demo = gr.Interface(
36
- fn=predict,
37
- inputs=gr.Textbox(label="Input Text", lines=3),
38
- outputs=[
39
- gr.Textbox(label="Processed Text"),
40
- gr.Number(label="Predicted Class"),
41
- gr.Label(label="Class Probabilities")
42
- ],
43
- title="BERT Model Deployment",
44
- examples=[["Sample text 1"], ["Another example text"]]
45
  )
46
 
47
- if __name__ == "__main__":
48
- demo.launch()
 
1
+ from fastapi import FastAPI, HTTPException
2
+ from pydantic import BaseModel
3
+ from transformers import BertTokenizer, BertForSequenceClassification
4
  import torch
5
+ import gradio as gr
6
+
7
+ app = FastAPI()
8
+
9
+ # Your existing prediction code remains unchanged
10
+ Label = ['Login Issue', 'Booking Issue', 'Delivery Issue', 'Laboratory Issue', 'Application Issue']
11
 
12
+ # ... [Keep your existing CORS and model loading code] ...
 
 
 
13
 
14
+ # Gradio Interface
15
+ def gradio_predict(issue_text):
16
+ inputs = tokenizer(issue_text, return_tensors="pt", truncation=True, padding=True)
 
 
 
 
 
 
 
 
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
+ predictions = torch.softmax(outputs.logits, dim=1)
20
+ label_idx = torch.argmax(predictions, dim=1).item()
 
 
 
21
  return {
22
+ "Predicted Category": Label[label_idx],
23
+ "Confidence": f"{predictions[0][label_idx].item():.4f}"
 
24
  }
25
 
26
+ gradio_app = gr.Interface(
27
+ fn=gradio_predict,
28
+ inputs=gr.Textbox(label="Enter Issue Description"),
29
+ outputs=gr.JSON(label="Prediction Results"),
30
+ title="Issue Classifier",
31
+ description="BERT-based classification demo"
 
 
 
 
 
32
  )
33
 
34
+ # Mount Gradio to FastAPI
35
+ app = gr.mount_gradio_app(app, gradio_app, path="/gradio")