from fastapi import FastAPI, HTTPException, BackgroundTasks, UploadFile, File, Form
from fastapi.responses import FileResponse
from pydantic import BaseModel
from typing import Optional, Dict, Any, List
import uvicorn
import logging
import os
import pandas as pd
from datetime import datetime
import shutil
from pathlib import Path
import numpy as np
import sys
import json
import joblib

from dataset_utils import (
    load_and_preprocess_data,
    save_label_encoders,
    load_label_encoders
)
from config import (
    TEXT_COLUMN,
    LABEL_COLUMNS,
    BATCH_SIZE,
    MODEL_SAVE_DIR
)
from models.tfidf_lgbm import TfidfLightGBM


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(title="LGBM Compliance Predictor API")

# Working directories for uploaded CSVs and serialized model artifacts.
# Note: this Path value overrides the MODEL_SAVE_DIR string imported from config above.
UPLOAD_DIR = Path("uploads")
MODEL_SAVE_DIR = Path("saved_models")
UPLOAD_DIR.mkdir(parents=True, exist_ok=True)
MODEL_SAVE_DIR.mkdir(parents=True, exist_ok=True)

# Artifact paths used by the standalone /predict endpoint.
TFIDF_PATH = os.path.join(str(MODEL_SAVE_DIR), "tfidf_vectorizer.pkl")
MODEL_PATH = os.path.join(str(MODEL_SAVE_DIR), "lgbm_models.pkl")
ENCODERS_PATH = os.path.join(os.path.dirname(__file__), "label_encoders.pkl")

# In-memory training state shared between the background task and /training-status.
training_status = {
    "is_training": False,
    "current_epoch": 0,
    "total_epochs": 0,
    "current_loss": 0.0,
    "start_time": None,
    "end_time": None,
    "status": "idle",
    "metrics": None
}


class TrainingConfig(BaseModel):
    batch_size: int = 32
    num_epochs: int = 1
    random_state: int = 42


class TrainingResponse(BaseModel):
    message: str
    training_id: str
    status: str
    download_url: Optional[str] = None


class ValidationResponse(BaseModel):
    message: str
    metrics: Dict[str, Any]
    predictions: List[Dict[str, Any]]

class TransactionData(BaseModel):
    # Flat per-transaction schema used for single-record prediction; the field names
    # are expected to mirror the screening dataset's column headers.
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool


class PredictionRequest(BaseModel):
    transaction_data: TransactionData
    model_name: str = "lgbm_models"


class BatchPredictionResponse(BaseModel):
    message: str
    predictions: List[Dict[str, Any]]
    metrics: Optional[Dict[str, Any]] = None

@app.get("/")
async def root():
    return {"message": "LGBM Compliance Predictor API"}


@app.get("/v1/lgbm/health")
async def health_check():
    return {"status": "healthy"}


@app.get("/v1/lgbm/training-status")
async def get_training_status():
    return training_status


@app.post("/v1/lgbm/train", response_model=TrainingResponse)
async def start_training(
    background_tasks: BackgroundTasks,
    config: str = Form(...),
    file: UploadFile = File(...)
):
    """Kick off background training on an uploaded CSV and return a training id."""
    if training_status["is_training"]:
        raise HTTPException(status_code=400, detail="Training is already in progress")
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    # The training configuration arrives as a JSON string in a multipart form field.
    try:
        config_dict = json.loads(config)
        training_config = TrainingConfig(**config_dict)
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid config parameters: {str(e)}")
    file_path = UPLOAD_DIR / file.filename
    with file_path.open("wb") as buffer:
        shutil.copyfileobj(file.file, buffer)
    training_id = datetime.now().strftime("%Y%m%d_%H%M%S")
    training_status.update({
        "is_training": True,
        "current_epoch": 0,
        "total_epochs": 1,
        "start_time": datetime.now().isoformat(),
        "status": "starting"
    })
    background_tasks.add_task(train_model_task, training_config, str(file_path), training_id)
    download_url = f"/v1/lgbm/download-model/{training_id}"
    return TrainingResponse(
        message="Training started successfully",
        training_id=training_id,
        status="started",
        download_url=download_url
    )
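# Illustrative request for the training endpoint. The host/port and the file name
# training_data.csv are placeholders, not part of this service; adjust to your setup:
#   curl -X POST http://localhost:7860/v1/lgbm/train \
#     -F 'config={"batch_size": 32, "num_epochs": 1, "random_state": 42}' \
#     -F 'file=@training_data.csv'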

@app.post("/v1/lgbm/validate")
async def validate_model(
    file: UploadFile = File(...),
    model_name: str = "lgbm_models"
):
    """Evaluate a saved model on a labelled CSV and return metrics plus per-row predictions."""
    if not file.filename.endswith('.csv'):
        raise HTTPException(status_code=400, detail="Only CSV files are allowed")
    file_path = UPLOAD_DIR / file.filename
    try:
        with file_path.open("wb") as buffer:
            shutil.copyfileobj(file.file, buffer)
        data_df, label_encoders = load_and_preprocess_data(str(file_path))
        model_path = MODEL_SAVE_DIR / f"{model_name}.pkl"
        if not model_path.exists():
            raise HTTPException(status_code=404, detail="LGBM model file not found")
        model = TfidfLightGBM(label_encoders)
        model.load_model(model_name)
        X = data_df[TEXT_COLUMN]
        y = data_df[LABEL_COLUMNS]

        if not isinstance(X, pd.Series) or not pd.api.types.is_string_dtype(X):
            raise HTTPException(
                status_code=400,
                detail=f"TEXT_COLUMN ('{TEXT_COLUMN}') must be a pandas Series of strings. "
                       f"Got type: {type(X)}, dtype: {getattr(X, 'dtype', None)}"
            )
        reports, y_true_list, y_pred_list = model.evaluate(X, y)
        all_probs = model.predict_proba(X)
        # Decode encoded labels back to their original values and attach
        # per-class probabilities for every target column.
        predictions = []
        for i, col in enumerate(LABEL_COLUMNS):
            label_encoder = label_encoders[col]
            true_labels_orig = label_encoder.inverse_transform(y_true_list[i])
            pred_labels_orig = label_encoder.inverse_transform(y_pred_list[i])
            for true, pred, probs in zip(true_labels_orig, pred_labels_orig, all_probs[i]):
                class_probs = {label: float(prob) for label, prob in zip(label_encoder.classes_, probs)}
                predictions.append({
                    "field": col,
                    "true_label": true,
                    "predicted_label": pred,
                    "probabilities": class_probs
                })
        return ValidationResponse(
            message="Validation completed successfully",
            metrics=reports,
            predictions=predictions
        )
    except HTTPException:
        # Preserve deliberate 4xx responses instead of converting them to 500s.
        raise
    except Exception as e:
        logger.error(f"Validation failed: {str(e)}")
        raise HTTPException(status_code=500, detail=f"Validation failed: {str(e)}")
    finally:
        if os.path.exists(file_path):
            os.remove(file_path)
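# Illustrative request for the validation endpoint; the uploaded CSV must contain the
# configured TEXT_COLUMN and LABEL_COLUMNS. Host, port, and file name are placeholders:
#   curl -X POST "http://localhost:7860/v1/lgbm/validate?model_name=lgbm_models" \
#     -F 'file=@labelled_validation.csv'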

@app.post("/v1/lgbm/predict")
async def predict(
    request: Optional[PredictionRequest] = None,
    file: UploadFile = File(None),
    model_name: str = "lgbm_models"
):
    """Predict labels for a single transaction (JSON body) or a batch (uploaded CSV).

    Callers should supply exactly one of the two inputs; when both are present, the
    uploaded CSV takes precedence. Note that the vectorizer, models, and encoders are
    loaded from fixed module-level paths, so model_name is currently not used for lookup.
    """
    try:
        # Load the persisted TF-IDF vectorizer, LightGBM models, and label encoders.
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        if file and file.filename:
            # Batch mode: score every row of the uploaded CSV.
            if not file.filename.endswith('.csv'):
                raise HTTPException(status_code=400, detail="Only CSV files are allowed")
            file_path = UPLOAD_DIR / file.filename
            with file_path.open("wb") as buffer:
                shutil.copyfileobj(file.file, buffer)
            try:
                data_df, _ = load_and_preprocess_data(str(file_path))

                # Concatenate all non-null column values of a row into one text document.
                texts = data_df.apply(lambda row: " ".join([str(val) for val in row.values if pd.notna(val)]), axis=1)
                X_vec = tfidf.transform(texts)
                preds = model.predict(X_vec)
                predictions = []
                for i, pred in enumerate(preds):
                    decoded = {
                        col: encoders[col].inverse_transform([label])[0]
                        for col, label in zip(LABEL_COLUMNS, pred)
                    }
                    predictions.append({
                        "transaction_id": data_df.iloc[i].get('Transaction_Id', f"transaction_{i}"),
                        "predictions": decoded
                    })
                return BatchPredictionResponse(
                    message="Batch prediction completed successfully",
                    predictions=predictions
                )
            finally:
                if os.path.exists(file_path):
                    os.remove(file_path)

        elif request and request.transaction_data:
            # Single-record mode: build one text document from the request body.
            input_data = pd.DataFrame([request.transaction_data.dict()])
            text_input = " ".join([
                str(val) for val in input_data.iloc[0].values if pd.notna(val)
            ])
            X_vec = tfidf.transform([text_input])
            pred = model.predict(X_vec)[0]
            decoded = {
                col: encoders[col].inverse_transform([p])[0]
                for col, p in zip(LABEL_COLUMNS, pred)
            }
            return decoded
        else:
            raise HTTPException(
                status_code=400,
                detail="Either provide a transaction in the request body or upload a CSV file"
            )
    except HTTPException:
        # Preserve deliberate 4xx responses instead of converting them to 500s.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
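# Illustrative usage of the prediction endpoint. Batch mode uploads a CSV; single mode
# posts a JSON body whose transaction_data object carries every TransactionData field
# (only a fragment is sketched here; values, host, and file names are placeholders):
#   curl -X POST http://localhost:7860/v1/lgbm/predict -F 'file=@transactions.csv'
#   curl -X POST http://localhost:7860/v1/lgbm/predict \
#     -H 'Content-Type: application/json' \
#     -d '{"transaction_data": {"Transaction_Id": "TXN-001", "Hit_Seq": 1, ...}, "model_name": "lgbm_models"}'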

@app.get("/v1/lgbm/download-model/{model_id}")
async def download_model(model_id: str):
    model_path = MODEL_SAVE_DIR / f"{model_id}.pkl"
    if not model_path.exists():
        raise HTTPException(status_code=404, detail="Model not found")
    return FileResponse(
        path=model_path,
        filename=f"lgbm_model_{model_id}.pkl",
        media_type="application/octet-stream"
    )
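# Illustrative download of a trained artifact, where <training_id> is the id returned
# by the /v1/lgbm/train response (host and port are placeholders):
#   curl -OJ http://localhost:7860/v1/lgbm/download-model/<training_id>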

def train_model_task(config: TrainingConfig, file_path: str, training_id: str):
    """Background task: train the TF-IDF + LightGBM model and update training_status.

    Defined as a plain (synchronous) function so Starlette runs it in a worker thread
    and the CPU-bound training does not block the event loop.
    """
    try:
        data_df_original, label_encoders = load_and_preprocess_data(file_path)
        save_label_encoders(label_encoders)
        X = data_df_original[TEXT_COLUMN]
        y = data_df_original[LABEL_COLUMNS]
        model = TfidfLightGBM(label_encoders)
        model.train(X, y)
        model.save_model(training_id)
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "completed"
        })
    except Exception as e:
        logger.error(f"Training failed: {str(e)}")
        training_status.update({
            "is_training": False,
            "end_time": datetime.now().isoformat(),
            "status": "failed",
            "error": str(e)
        })


if __name__ == "__main__":
    port = int(os.environ.get("PORT", 7860))
    uvicorn.run(app, host="0.0.0.0", port=port)