# Hugging Face Space status at time of capture: Runtime error
import gradio as gr
import torch
from fastapi import FastAPI, HTTPException
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from transformers import BertTokenizer, BertForSequenceClassification
# Prediction labels: the classifier's argmax class index is used directly as
# an index into this list.
# NOTE(review): the order must match the label ids the model in ./saved_model1
# was trained with (its config.id2label) — confirm.
LABELS = [
    "Login Issue",
    "Booking Issue",
    "Delivery Issue",
    "Laboratory Issue",
    "Application Issue",
]

# The ASGI application that the CORS middleware, the API route and the Gradio
# UI are all attached to.
app = FastAPI()
# CORS Configuration: accept requests from any origin, with any method and
# any header.
# NOTE(review): the FastAPI CORS docs say wildcard origins/methods/headers
# should not be combined with allow_credentials=True. Nothing in this file
# appears to read cookies or auth headers, so consider listing explicit
# origins or turning credentials off — confirm with the frontend first.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_headers=["*"],
    allow_methods=["*"],
    allow_origins=["*"],
)
# Load model and tokenizer once, at import time. Any failure here raises and
# aborts application startup.
MODEL_DIR = "./saved_model1"  # relative path: resolved against the process cwd

try:
    model = BertForSequenceClassification.from_pretrained(MODEL_DIR)
    tokenizer = BertTokenizer.from_pretrained(MODEL_DIR)
    # Inference only: eval() disables training-time behaviour such as dropout.
    model.eval()
except Exception as e:
    # Chain the cause explicitly so the loader's own traceback stays attached.
    raise RuntimeError(f"Model loading failed: {str(e)}") from e
# Request Model
class PredictionRequest(BaseModel):
    """JSON request body carrying the free-text issue to classify."""

    # min_length=1 rejects an empty string with a 422 validation error instead
    # of running the classifier on no text at all.
    issue: str = Field(..., min_length=1)
# FastAPI Endpoint
def _classify_issue(text: str) -> dict:
    """Run the classifier on *text* and return the top category.

    Deliberately synchronous: model inference is blocking work, so ``predict``
    runs this in a worker thread rather than on the event loop.

    Returns:
        ``{"category": <label from LABELS>, "confidence": <probability, 4 dp>}``.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        # Presumably the model's max_position_embeddings (512 for standard
        # BERT) — confirm against the saved config.
        max_length=512,
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probabilities = torch.softmax(outputs.logits, dim=1)
    # Single-row batch, so the flattened argmax is the index within row 0.
    label_idx = torch.argmax(probabilities).item()
    return {
        "category": LABELS[label_idx],
        "confidence": round(probabilities[0][label_idx].item(), 4),
    }


# NOTE(review): the route path was not present in the original code; "/predict"
# follows the function name — confirm it matches what clients call.
@app.post("/predict")
async def predict(request: PredictionRequest):
    """Classify ``request.issue`` and return its category and confidence.

    Returns:
        ``{"category": str, "confidence": float}``, confidence rounded to 4 dp.

    Raises:
        HTTPException: status 500 wrapping any error raised during inference.
    """
    try:
        # Offload the blocking forward pass so concurrent requests (and the
        # mounted Gradio app) are not stalled while the model runs.
        return await run_in_threadpool(_classify_issue, request.issue)
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}",
        ) from e
# Gradio Interface
def gradio_classifier(text):
    """Classify *text* for the Gradio UI.

    Returns:
        A 3-tuple ``(category, confidence, probabilities)``:
        the top label from ``LABELS``, its softmax probability as a float, and
        a ``{label: probability rounded to 4 dp}`` dict over all labels.
    """
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        # Same cap as the API path; without an explicit max_length, truncation
        # falls back to the tokenizer's model_max_length setting.
        max_length=512,
    )
    with torch.no_grad():
        logits = model(**inputs).logits
    # Single input string -> take row 0 of the (1, num_classes) result.
    probs = torch.softmax(logits, dim=1)[0]
    pred_idx = int(torch.argmax(probs).item())
    all_probs = {label: round(float(probs[i]), 4) for i, label in enumerate(LABELS)}
    # Return a tuple, not a dict: with several output components Gradio assigns
    # the return values positionally (Label, Number, JSON).
    return LABELS[pred_idx], float(probs[pred_idx].item()), all_probs
# Mount Gradio interface.
# NOTE(review): with three output components Gradio expects gradio_classifier
# to return three values, in this order: category, confidence, probabilities.
gradio_app = gr.Interface(
    title="Issue Classifier",
    description="BERT-based classification system for customer support issues",
    fn=gradio_classifier,
    inputs=gr.Textbox(
        label="Issue",
        lines=3,
        placeholder="Enter issue description...",
    ),
    outputs=[
        gr.Label(label="Predicted Category"),
        gr.Number(label="Confidence Score"),
        gr.JSON(label="Class Probabilities"),
    ],
)

# Serve the Gradio UI under /gradio on the same FastAPI app; the returned
# application object replaces the module-level `app`.
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")