# ammu / app.py
# vikramronavrsc's picture
# Update app.py
# f27ac44 verified
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import gradio as gr
# ASGI application object; the CORS middleware, the /predict route and the
# Gradio UI further down in this file are all attached to it.
app = FastAPI()
# Prediction labels
# Human-readable class names. Both inference paths look a name up by the
# argmax index of the model output, so position i must name output class i.
LABELS = [
    "Login Issue",
    "Booking Issue",
    "Delivery Issue",
    "Laboratory Issue",
    "Application Issue",
]
# CORS Configuration
# Fully permissive policy: any origin, any method, any header.
# NOTE(review): FastAPI's CORS documentation advises against combining
# wildcard origins with allow_credentials=True. Confirm whether credentialed
# cross-origin requests are actually needed; if so, list explicit origins.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
# Load model and tokenizer
# Both artifacts are read from the same local directory, so the path is
# defined once instead of being repeated per call.
MODEL_DIR = "./saved_model1"
try:
    model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
    # Switch to evaluation mode: this process only serves predictions.
    model.eval()
except Exception as e:
    # Fail at import time rather than on the first request, and chain the
    # original error so its traceback is reported as the direct cause.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
# Request Model
# JSON body accepted by POST /predict.
class PredictionRequest(BaseModel):
    # Free-text description of the issue to classify.
    issue: str
# FastAPI Endpoint
def _classify_issue(text):
    """Run one blocking forward pass and return ``(class_index, confidence)``.

    Kept synchronous on purpose so the async endpoint can hand it to a
    worker thread. ``confidence`` is the softmax probability of the winning
    class, rounded to 4 decimals.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over dim 1 (the class axis); row 0 is the single submitted text.
    probabilities = torch.softmax(outputs.logits, dim=1)[0]
    label_idx = int(torch.argmax(probabilities).item())
    return label_idx, round(probabilities[label_idx].item(), 4)


@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify an issue description into one of the known categories.

    Returns a JSON object with ``category`` (an entry of ``LABELS``) and
    ``confidence`` (probability of that category, rounded to 4 decimals).
    Any failure during tokenization, inference or label lookup is reported
    as HTTP 500.
    """
    # Imported locally so this block does not rely on the file's import header.
    from fastapi.concurrency import run_in_threadpool

    try:
        # The forward pass is CPU-bound and blocking; awaiting it in a worker
        # thread keeps the event loop free to serve other requests meanwhile.
        label_idx, confidence = await run_in_threadpool(
            _classify_issue, request.issue
        )
        return {
            "category": LABELS[label_idx],
            "confidence": confidence,
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        ) from e
# Gradio Interface
def gradio_classifier(text):
    """Classify ``text`` and return one value per Gradio output component.

    Args:
        text: Issue description entered in the Gradio textbox.

    Returns:
        A 3-tuple ``(label, confidence, probabilities)``:
        the predicted entry of ``LABELS``, its softmax probability as a
        float, and a dict mapping every label to its probability rounded to
        4 decimals.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512,  # same cap as the /predict endpoint
    )
    # Inference only: no gradients are needed.
    with torch.no_grad():
        outputs = model(**inputs)
    # Softmax over dim 1 (the class axis); row 0 is the single submitted text.
    probs = torch.softmax(outputs.logits, dim=1)[0]
    pred_idx = int(torch.argmax(probs).item())
    all_probabilities = {
        label: round(float(probs[i]), 4)
        for i, label in enumerate(LABELS)
    }
    # The interface built from this function declares three outputs
    # (Label, Number, JSON), so Gradio needs three positional values; a single
    # dict with string keys cannot be mapped onto them.
    return LABELS[pred_idx], float(probs[pred_idx].item()), all_probabilities
# Mount Gradio interface
# Browser UI around gradio_classifier: one textbox in, three components out
# (label, number, JSON).
gradio_app = gr.Interface(
    title="Issue Classifier",
    description="BERT-based classification system for customer support issues",
    fn=gradio_classifier,
    inputs=gr.Textbox(
        label="Issue",
        lines=3,
        placeholder="Enter issue description...",
    ),
    outputs=[
        gr.Label(label="Predicted Category"),
        gr.Number(label="Confidence Score"),
        gr.JSON(label="Class Probabilities"),
    ],
)
# Serve the UI from the same FastAPI application under /gradio; the value
# returned by mount_gradio_app is rebound to ``app``.
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")