from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import BertTokenizer, BertForSequenceClassification
import torch
import gradio as gr

app = FastAPI()

# Prediction labels
LABELS = [
    'Login Issue',
    'Booking Issue',
    'Delivery Issue',
    'Laboratory Issue',
    'Application Issue'
]

# CORS configuration (permissive: all origins, methods, and headers allowed)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the fine-tuned model and tokenizer, failing fast if they are missing
try:
    model = BertForSequenceClassification.from_pretrained("./saved_model1")
    tokenizer = BertTokenizer.from_pretrained("./saved_model1")
    model.eval()
except Exception as e:
    raise RuntimeError(f"Model loading failed: {str(e)}")

# Request model
class PredictionRequest(BaseModel):
    issue: str

# FastAPI endpoint
@app.post("/predict")
async def predict(request: PredictionRequest):
    try:
        inputs = tokenizer(
            request.issue,
            return_tensors="pt",
            truncation=True,
            padding=True,
            max_length=512
        )
        with torch.no_grad():
            outputs = model(**inputs)
        probabilities = torch.softmax(outputs.logits, dim=1)
        label_idx = torch.argmax(probabilities).item()
        return {
            "category": LABELS[label_idx],
            "confidence": round(probabilities[0][label_idx].item(), 4)
        }
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Prediction error: {str(e)}"
        )

# Gradio prediction function: returns one value per output component
# (predicted label, confidence score, per-class probabilities)
def gradio_classifier(text):
    inputs = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        padding=True,
        max_length=512
    )
    with torch.no_grad():
        outputs = model(**inputs)
    probs = torch.softmax(outputs.logits, dim=1)
    pred_idx = torch.argmax(probs).item()
    return (
        LABELS[pred_idx],
        float(probs[0][pred_idx].item()),
        {label: round(float(probs[0][i].item()), 4) for i, label in enumerate(LABELS)}
    )

# Gradio interface definition
gradio_app = gr.Interface(
    fn=gradio_classifier,
    inputs=gr.Textbox(lines=3, placeholder="Enter issue description...", label="Issue"),
    outputs=[
        gr.Label(label="Predicted Category"),
        gr.Number(label="Confidence Score"),
        gr.JSON(label="Class Probabilities")
    ],
    title="Issue Classifier",
    description="BERT-based classification system for customer support issues"
)

# Mount the Gradio UI onto the FastAPI app at /gradio
app = gr.mount_gradio_app(app, gradio_app, path="/gradio")
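
# --- Usage sketch (not part of the original file; host and port are illustrative assumptions) ---
# A minimal way to launch the combined FastAPI + Gradio app locally, assuming uvicorn is
# installed and the fine-tuned weights exist at "./saved_model1". Run `python <this_file>.py`,
# then POST JSON like {"issue": "I cannot log into my account"} to http://127.0.0.1:8000/predict,
# or open http://127.0.0.1:8000/gradio to use the Gradio UI.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)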