from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, validator
from typing import Optional
import pandas as pd
import joblib

app = FastAPI()

# === Paths ===
# Relative paths: resolved against the process working directory at startup.
TFIDF_PATH = "models/tfidf_vectorizer.pkl"
MODEL_PATH = "models/logreg_model.pkl"
ENCODER_PATH = "models/label_encoders.pkl"

# === Load artifacts ===
# Loaded once at import time, in the order vectorizer -> models -> encoders.
# joblib.load unpickles the files, so they must come from a trusted source.
tfidf_vectorizer, models, label_encoders = (
    joblib.load(artifact_path)
    for artifact_path in (TFIDF_PATH, MODEL_PATH, ENCODER_PATH)
)


# === Input Schema ===
class TransactionData(BaseModel):
    """Schema for a single screened transaction submitted for scoring.

    Every field is required except ``Payment_Sender_Name`` and
    ``Payment_Reciever_Name``, which default to an empty string (and also
    accept an explicit null). ``Sanction_Context`` must be at least 5
    characters long and must not be whitespace only.

    The misspelling ``Payment_Reciever_Name`` is part of the request contract;
    renaming it would break existing clients.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # Optional: may be omitted (defaults to "") or sent as null.
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    # Length check counts whitespace; the validator below rejects blank values.
    Sanction_Context: str = Field(..., min_length=5)
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool

    @validator("Sanction_Context")
    def context_not_blank(cls, v):
        """Accept the context unchanged unless it is whitespace only."""
        if v.strip():
            return v
        raise ValueError("Sanction_Context must not be blank.")
class PredictionRequest(BaseModel):
    """Request envelope wrapping the single transaction to validate or score."""

    transaction_data: TransactionData


@app.get("/")
def root():
    """Liveness probe: report that the service is up."""
    # NOTE(review): the message says "XGBoost", but MODEL_PATH points at
    # "logreg_model.pkl"; confirm which model family is actually deployed.
    return {"status": "healthy", "message": "XGBoost TF-IDF API is running"}


@app.post("/validate")
def validate_input(request: PredictionRequest):
    """Confirm that a request body conforms to ``PredictionRequest``.

    FastAPI validates ``request`` against the schema before this handler is
    invoked and responds with HTTP 422 on failure. Reaching this body
    therefore already means the input is valid, so no error handling is
    needed here.
    """
    return {"status": "success", "message": "Input schema is valid."}


# (text label, TransactionData field) pairs rendered, in this order, into the
# document passed to the TF-IDF vectorizer.
# NOTE(review): this presumably mirrors the text layout used when the
# vectorizer/models were fitted; confirm before changing labels or order.
_TEXT_FIELDS = (
    ("Transaction ID", "Transaction_Id"),
    ("Origin", "Origin"),
    ("Designation", "Designation"),
    ("Keywords", "Keywords"),
    ("Name", "Name"),
    ("SWIFT Tag", "SWIFT_Tag"),
    ("Currency", "Currency"),
    ("Entity", "Entity"),
    ("Message", "Message"),
    ("City", "City"),
    ("Country", "Country"),
    ("State", "State"),
    ("Hit Type", "Hit_Type"),
    ("Record Matching String", "Record_Matching_String"),
    ("WatchList Match String", "WatchList_Match_String"),
    ("Payment Sender", "Payment_Sender_Name"),
    ("Payment Receiver", "Payment_Reciever_Name"),
    ("Swift Message Type", "Swift_Message_Type"),
    ("Text Sanction Data", "Text_Sanction_Data"),
    ("Matched Sanctioned Entity", "Matched_Sanctioned_Entity"),
    ("Red Flag Reason", "Red_Flag_Reason"),
    ("Risk Level", "Risk_Level"),
    ("Risk Score", "Risk_Score"),
    ("CDD Level", "CDD_Level"),
    ("PEP Status", "PEP_Status"),
    ("Sanction Description", "Sanction_Description"),
    ("Checker Notes", "Checker_Notes"),
    ("Sanction Context", "Sanction_Context"),
    ("Maker Action", "Maker_Action"),
    ("Customer Type", "Customer_Type"),
    ("Industry", "Industry"),
    ("Transaction Type", "Transaction_Type"),
    ("Transaction Channel", "Transaction_Channel"),
    ("Geographic Origin", "Geographic_Origin"),
    ("Geographic Destination", "Geographic_Destination"),
    ("Risk Category", "Risk_Category"),
    ("Risk Drivers", "Risk_Drivers"),
    ("Alert Status", "Alert_Status"),
    ("Investigation Outcome", "Investigation_Outcome"),
    ("Source of Funds", "Source_Of_Funds"),
    ("Purpose of Transaction", "Purpose_Of_Transaction"),
    ("Beneficial Owner", "Beneficial_Owner"),
)


def _build_text_input(fields):
    """Render transaction ``fields`` (a dict) as labelled text for the vectorizer.

    One ``"<label>: <value>"`` line per entry of ``_TEXT_FIELDS``, wrapped in a
    leading and trailing newline. A ``None`` value (possible for the Optional
    payment-party fields) is rendered as "" to match the schema default,
    rather than as the literal text "None".

    Raises:
        KeyError: if ``fields`` lacks one of the expected keys.
    """
    # NOTE(review): layout whitespace differs from the original inline
    # template (whose indentation is not recoverable); presumably irrelevant
    # because scikit-learn's analyzers tokenize/normalize whitespace — confirm
    # against the fitted vectorizer's analyzer settings.
    lines = []
    for label, key in _TEXT_FIELDS:
        value = fields[key]
        rendered = "" if value is None else value
        lines.append(f"{label}: {rendered}")
    return "\n" + "\n".join(lines) + "\n"


def _to_builtin(value):
    """Return ``value`` as a plain Python object so FastAPI can JSON-encode it.

    NumPy scalars (e.g. elements of ``LabelEncoder.classes_``) expose
    ``.item()``, which yields the equivalent Python value; anything without
    it is returned unchanged.
    """
    item = getattr(value, "item", None)
    return item() if callable(item) else value


def _score_label(label, model, X_tfidf):
    """Predict one target ``label`` with ``model`` on the vectorized input.

    Returns ``{"prediction": <decoded class>, "probabilities": {class: p}}``.
    """
    encoder = label_encoders[label]
    proba = model.predict_proba(X_tfidf)[0]
    # NOTE(review): assumes predict_proba column i is encoded class i, i.e.
    # the model was fitted on the full LabelEncoder output
    # (model.classes_ == range(n_classes)) — confirm.
    pred_idx = int(proba.argmax())
    return {
        "prediction": _to_builtin(encoder.inverse_transform([pred_idx])[0]),
        "probabilities": {
            _to_builtin(encoder.classes_[i]): float(p) for i, p in enumerate(proba)
        },
    }


@app.post("/predict")
def predict(request: PredictionRequest):
    """Score one transaction with every loaded per-label model.

    Returns:
        ``{label: {"prediction": ..., "probabilities": {...}}}`` for each entry
        in ``models``, in that dict's iteration order.

    Raises:
        HTTPException: status 500 wrapping any error raised while building
            features or scoring.
    """
    try:
        text_input = _build_text_input(request.transaction_data.dict())
        X_tfidf = tfidf_vectorizer.transform([text_input])
        return {
            label: _score_label(label, model, X_tfidf)
            for label, model in models.items()
        }
    except Exception as e:
        # NOTE(review): str(e) exposes internal error text to API clients;
        # consider logging it server-side and returning a generic detail.
        raise HTTPException(status_code=500, detail=str(e)) from e