Spaces:
Sleeping
Sleeping
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from typing import Optional | |
import pandas as pd | |
import joblib | |
import os | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.preprocessing import LabelEncoder | |
from sklearn.multioutput import MultiOutputClassifier | |
from sklearn.linear_model import LogisticRegression | |
# --- Configuration ---
# Target columns predicted jointly by the multi-output model. Order matters:
# it is the column order of the training targets and of each prediction row.
LABEL_COLUMNS = [
    "Red_Flag_Reason", "Maker_Action", "Escalation_Level",
    "Risk_Category", "Risk_Drivers", "Investigation_Outcome",
]
# Free-text column the TF-IDF vectorizer is fitted on.
TEXT_COLUMN = "Sanction_Context"
# Directory holding the persisted model artifacts. Defaults to /tmp (the
# original location) but can be overridden with the MODEL_DIR env variable.
# NOTE(review): the artifacts are read back with joblib.load, which unpickles
# them; in a shared, world-writable /tmp another user could plant a malicious
# file. Point MODEL_DIR at a private directory in production.
MODEL_DIR = os.environ.get("MODEL_DIR", "/tmp")
MODEL_PATH = os.path.join(MODEL_DIR, "logreg_model.pkl")
TFIDF_PATH = os.path.join(MODEL_DIR, "tfidf_vectorizer.pkl")
ENCODERS_PATH = os.path.join(MODEL_DIR, "label_encoders.pkl")
# --- FastAPI App ---
# The ASGI application instance.
# NOTE(review): no route decorators (e.g. ``@app.get`` / ``@app.post``) appear
# in this file, so the handler functions defined after this line are not
# registered on ``app`` as written. Confirm whether they were lost (for example
# when the file was copied) and re-attach them.
app = FastAPI()
# --- Input Schema ---
class TransactionData(BaseModel):
    """Pydantic schema for one screened transaction.

    Used as ``PredictionRequest.transaction_data``. Every field is required
    except ``Payment_Sender_Name`` and ``Payment_Reciever_Name``, which default
    to an empty string (being ``Optional``, an explicit null is also accepted).

    The schema contains the model's text input ``Sanction_Context``
    (``TEXT_COLUMN``) as well as the six fields listed in ``LABEL_COLUMNS``
    (the model's prediction targets).

    Date/time fields are plain ``str``; no format is validated here.
    NOTE(review): the expected date/time format is not visible in this file —
    confirm with the data producer.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: Optional[str] = ""
    # "Reciever" is misspelled, but it is the public field name clients send;
    # renaming it would break existing requests.
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str  # listed in LABEL_COLUMNS (model target)
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str  # TEXT_COLUMN: the text the TF-IDF vectorizer is fitted on
    Maker_Action: str  # listed in LABEL_COLUMNS (model target)
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str  # listed in LABEL_COLUMNS (model target)
    Risk_Drivers: str  # listed in LABEL_COLUMNS (model target)
    Alert_Status: str
    Investigation_Outcome: str  # listed in LABEL_COLUMNS (model target)
    Case_Owner_Analyst: str
    Escalation_Level: str  # listed in LABEL_COLUMNS (model target)
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class DataPathInput(BaseModel):
    """Request body for the train / validate / test operations."""

    # Path of the CSV file that is handed straight to ``pd.read_csv``.
    # NOTE(review): read_csv also accepts URLs, so an untrusted caller could make
    # the server read arbitrary local files or fetch remote URLs — confirm
    # whether this input needs restricting.
    data_path: str


class PredictionRequest(BaseModel):
    """Request body for scoring a single transaction."""

    # The transaction to score, validated against the ``TransactionData`` schema.
    transaction_data: TransactionData
def health_check():
    """Report that the service process is up.

    Returns a constant payload with a ``status`` and a human-readable
    ``message``; no model artifacts are loaded or inspected. The message text,
    including its spelling and trailing space, is kept verbatim because clients
    may compare it.
    """
    payload = {
        "status": "healthy",
        "message": "logistic regression complience predictor API ",
    }
    return payload
def train_model(input: DataPathInput):
    """Train the multi-output logistic-regression model from a CSV and persist it.

    Rows missing ``TEXT_COLUMN`` or any of ``LABEL_COLUMNS`` are dropped. Each
    label column is integer-encoded with its own ``LabelEncoder``. The text is
    turned into TF-IDF features (at most 1000 uni/bi-grams, English stop words
    removed), and ``MultiOutputClassifier`` fits one ``LogisticRegression`` per
    label. The model, vectorizer and encoders are written to ``MODEL_PATH``,
    ``TFIDF_PATH`` and ``ENCODERS_PATH``.

    Args:
        input: Request body whose ``data_path`` is the training CSV.

    Returns:
        A ``{"status": ...}`` dict on success.

    Raises:
        HTTPException: 400 if no row has both text and all labels; 500 wrapping
            any other error (unreadable CSV, missing columns, a label column
            with a single class, write failure, ...).
    """
    try:
        df = pd.read_csv(input.data_path)
        df.dropna(subset=[TEXT_COLUMN] + LABEL_COLUMNS, inplace=True)
        if df.empty:
            # Without this the vectorizer fails later with a misleading
            # "empty vocabulary" error.
            raise HTTPException(
                status_code=400,
                detail="No rows with non-missing text and labels to train on.",
            )
        label_encoders = {}
        for col in LABEL_COLUMNS:
            le = LabelEncoder()
            df[col] = le.fit_transform(df[col])
            label_encoders[col] = le
        tfidf = TfidfVectorizer(max_features=1000, ngram_range=(1, 2), stop_words="english")
        # astype(str): non-string cells (e.g. numbers) would otherwise break the
        # vectorizer's lower-casing step.
        X_vec = tfidf.fit_transform(df[TEXT_COLUMN].astype(str))
        y = df[LABEL_COLUMNS]
        model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        model.fit(X_vec, y)
        joblib.dump(model, MODEL_PATH)
        joblib.dump(tfidf, TFIDF_PATH)
        joblib.dump(label_encoders, ENCODERS_PATH)
        return {"status": "β Logistic Regression model trained and saved."}
    except HTTPException:
        # Already carries the intended status/detail; don't re-wrap as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def validate_model(input: DataPathInput):
    """Check that a CSV contains every column the model needs.

    The required columns are ``TEXT_COLUMN`` followed by ``LABEL_COLUMNS``.
    The whole file is parsed, so a CSV that cannot be read also fails here.

    Args:
        input: Request body whose ``data_path`` is the CSV to inspect.

    Returns:
        ``{"status": ...}`` when every required column is present; otherwise
        ``{"status": ..., "missing_columns": [...]}``, listing the absent
        columns in required-column order.

    Raises:
        HTTPException: 500 with a ``"Validation error: ..."`` detail if the
            CSV cannot be read.
    """
    needed = [TEXT_COLUMN, *LABEL_COLUMNS]
    try:
        frame = pd.read_csv(input.data_path)
        absent = [name for name in needed if name not in frame.columns]
    except Exception as exc:
        raise HTTPException(status_code=500, detail=f"Validation error: {str(exc)}")
    if not absent:
        return {"status": "β Input is valid."}
    return {"status": "β Invalid input", "missing_columns": absent}
def test_model(input: DataPathInput):
    """Score a CSV with the saved model and return a preview of the predictions.

    Loads the artifacts written by ``train_model`` (``TFIDF_PATH``,
    ``MODEL_PATH``, ``ENCODERS_PATH``). It then predicts every label in
    ``LABEL_COLUMNS`` from ``TEXT_COLUMN`` for the first five rows that have
    text, and decodes the predictions back to the original label values.

    Args:
        input: Request body whose ``data_path`` is the CSV to score.

    Returns:
        ``{"predictions": [...]}`` with at most five dicts, each mapping a label
        column to its decoded value. The list is empty when no row has text.

    Raises:
        HTTPException: 500 wrapping any error (unreadable CSV, missing
            ``TEXT_COLUMN``, missing or incompatible artifacts, ...).
    """
    preview_rows = 5
    try:
        df = pd.read_csv(input.data_path)
        df = df.dropna(subset=[TEXT_COLUMN])
        # Only the first rows are reported, and predictions are row-independent,
        # so there is no need to vectorize/predict the rest of the file.
        texts = df[TEXT_COLUMN].head(preview_rows).astype(str)
        if texts.empty:
            # sklearn estimators reject 0-sample input; report "nothing scored".
            return {"predictions": []}
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)
        # preds has shape (n_rows, len(LABEL_COLUMNS)).
        preds = model.predict(tfidf.transform(texts))
        # Decode one label column at a time (one inverse_transform call per
        # label). .tolist() turns numpy scalars into plain Python values so the
        # response is JSON-serializable even for numeric labels.
        decoded_by_label = [
            encoders[col].inverse_transform(preds[:, i]).tolist()
            for i, col in enumerate(LABEL_COLUMNS)
        ]
        decoded_preds = [dict(zip(LABEL_COLUMNS, row)) for row in zip(*decoded_by_label)]
        return {"predictions": decoded_preds}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def predict(request: PredictionRequest):
    """Predict every label in ``LABEL_COLUMNS`` for a single transaction.

    Only the ``TEXT_COLUMN`` field (``Sanction_Context``) is vectorized. It is
    the single field ``train_model`` fits the TF-IDF vectorizer on and the one
    ``test_model`` scores. The other request fields are deliberately ignored:
    the vectorizer never saw them, and several of them (``Red_Flag_Reason``,
    ``Maker_Action``, ...) are the very labels being predicted, so mixing them
    into the features would leak the answer.

    Artifacts are re-loaded from disk on every call, so a retrain is picked up
    immediately, at the cost of disk I/O per call.

    Args:
        request: Body holding the ``transaction_data`` to score.

    Returns:
        ``{"prediction": {label_column: decoded_value, ...}}``.

    Raises:
        HTTPException: 500 wrapping any error (missing or incompatible
            artifacts, ...).
    """
    try:
        text_input = str(getattr(request.transaction_data, TEXT_COLUMN))
        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)
        X_vec = tfidf.transform([text_input])
        pred = model.predict(X_vec)[0]
        # .tolist() yields plain Python values (JSON-serializable even when a
        # label column was numeric in the training data).
        decoded = {
            col: encoders[col].inverse_transform([label]).tolist()[0]
            for col, label in zip(LABEL_COLUMNS, pred)
        }
        return {"prediction": decoded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
def train_test_validate(input: DataPathInput):
    """Validate a CSV, train on it, then score it with the freshly saved model.

    NOTE(review): the same file is used for training and testing, so the
    ``test`` output is an in-sample check, not a held-out evaluation.

    Args:
        input: Request body whose ``data_path`` is the CSV to use for every step.

    Returns:
        ``{"train": ..., "validate": <validate_model result>,
        "test": <test_model result>}``.

    Raises:
        HTTPException: 400 (detail = the validation report) when required
            columns are missing; the original status/detail of any
            ``HTTPException`` raised by a step; 500 for anything else.
    """
    try:
        # Validate first: training on a file without the required columns can
        # only fail, and the validation report explains why.
        validate_result = validate_model(input)
        if "missing_columns" in validate_result:
            raise HTTPException(status_code=400, detail=validate_result)
        train_model(input)
        test_result = test_model(input)
        return {
            "train": "β Done",
            "validate": validate_result,
            "test": test_result,
        }
    except HTTPException:
        # Steps already raise HTTPException with a meaningful status/detail;
        # re-wrapping would turn it into a 500 whose detail is str(exc).
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))