vikramronavrsc commited on
Commit
f27ac44
·
verified ·
1 Parent(s): 3abbb40

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +77 -14
app.py CHANGED
@@ -1,4 +1,5 @@
1
  from fastapi import FastAPI, HTTPException
 
2
  from pydantic import BaseModel
3
  from transformers import BertTokenizer, BertForSequenceClassification
4
  import torch
@@ -6,30 +7,92 @@ import gradio as gr
6
 
7
  app = FastAPI()
8
 
9
- # Your existing prediction code remains unchanged
10
- Label = ['Login Issue', 'Booking Issue', 'Delivery Issue', 'Laboratory Issue', 'Application Issue']
 
 
 
 
 
 
11
 
12
- # ... [Keep your existing CORS and model loading code] ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
13
 
14
  # Gradio Interface
15
- def gradio_predict(issue_text):
16
- inputs = tokenizer(issue_text, return_tensors="pt", truncation=True, padding=True)
17
  with torch.no_grad():
18
  outputs = model(**inputs)
19
- predictions = torch.softmax(outputs.logits, dim=1)
20
- label_idx = torch.argmax(predictions, dim=1).item()
 
21
  return {
22
- "Predicted Category": Label[label_idx],
23
- "Confidence": f"{predictions[0][label_idx].item():.4f}"
 
 
 
 
24
  }
25
 
 
26
  gradio_app = gr.Interface(
27
- fn=gradio_predict,
28
- inputs=gr.Textbox(label="Enter Issue Description"),
29
- outputs=gr.JSON(label="Prediction Results"),
 
 
 
 
30
  title="Issue Classifier",
31
- description="BERT-based classification demo"
32
  )
33
 
34
- # Mount Gradio to FastAPI
35
  app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
 
1
  from fastapi import FastAPI, HTTPException
2
+ from fastapi.middleware.cors import CORSMiddleware
3
  from pydantic import BaseModel
4
  from transformers import BertTokenizer, BertForSequenceClassification
5
  import torch
 
7
 
8
  app = FastAPI()
9
 
10
+ # Prediction labels
11
+ LABELS = [
12
+ 'Login Issue',
13
+ 'Booking Issue',
14
+ 'Delivery Issue',
15
+ 'Laboratory Issue',
16
+ 'Application Issue'
17
+ ]
18
 
19
+ # CORS Configuration
20
+ app.add_middleware(
21
+ CORSMiddleware,
22
+ allow_origins=["*"],
23
+ allow_credentials=True,
24
+ allow_methods=["*"],
25
+ allow_headers=["*"],
26
+ )
27
+
28
+ # Load model and tokenizer
29
+ try:
30
+ model = BertForSequenceClassification.from_pretrained("./saved_model1")
31
+ tokenizer = BertTokenizer.from_pretrained("./saved_model1")
32
+ model.eval()
33
+ except Exception as e:
34
+ raise RuntimeError(f"Model loading failed: {str(e)}")
35
+
36
+ # Request Model
37
+ class PredictionRequest(BaseModel):
38
+ issue: str
39
+
40
+ # FastAPI Endpoint
41
+ @app.post("/predict")
42
+ async def predict(request: PredictionRequest):
43
+ try:
44
+ inputs = tokenizer(
45
+ request.issue,
46
+ return_tensors="pt",
47
+ truncation=True,
48
+ padding=True,
49
+ max_length=512
50
+ )
51
+
52
+ with torch.no_grad():
53
+ outputs = model(**inputs)
54
+ probabilities = torch.softmax(outputs.logits, dim=1)
55
+ label_idx = torch.argmax(probabilities).item()
56
+
57
+ return {
58
+ "category": LABELS[label_idx],
59
+ "confidence": round(probabilities[0][label_idx].item(), 4)
60
+ }
61
+
62
+ except Exception as e:
63
+ raise HTTPException(
64
+ status_code=500,
65
+ detail=f"Prediction error: {str(e)}"
66
+ )
67
 
68
  # Gradio Interface
69
+ def gradio_classifier(text):
70
+ inputs = tokenizer(text, return_tensors="pt", truncation=True, padding=True)
71
  with torch.no_grad():
72
  outputs = model(**inputs)
73
+ probs = torch.softmax(outputs.logits, dim=1)
74
+ pred_idx = torch.argmax(probs).item()
75
+
76
  return {
77
+ "Prediction": LABELS[pred_idx],
78
+ "Confidence Score": float(probs[0][pred_idx].item()),
79
+ "All Probabilities": {
80
+ label: round(float(probs[0][i]), 4)
81
+ for i, label in enumerate(LABELS)
82
+ }
83
  }
84
 
85
+ # Mount Gradio interface
86
  gradio_app = gr.Interface(
87
+ fn=gradio_classifier,
88
+ inputs=gr.Textbox(lines=3, placeholder="Enter issue description...", label="Issue"),
89
+ outputs=[
90
+ gr.Label(label="Predicted Category"),
91
+ gr.Number(label="Confidence Score"),
92
+ gr.JSON(label="Class Probabilities")
93
+ ],
94
  title="Issue Classifier",
95
+ description="BERT-based classification system for customer support issues"
96
  )
97
 
 
98
  app = gr.mount_gradio_app(app, gradio_app, path="/gradio")