# LOGREG_TTV / app.py
# Last commit: "Update app.py" by ganeshkonapalli (efb2161, verified)
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from typing import Optional
import pandas as pd
import joblib
import os
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.preprocessing import LabelEncoder
from sklearn.multioutput import MultiOutputClassifier
from sklearn.linear_model import LogisticRegression
# --- Configuration ---
# Column whose free text is turned into TF-IDF features.
TEXT_COLUMN = "Sanction_Context"
# Target columns; the multi-output model predicts one label for each of them.
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
# Directory for the artifacts that /train writes and /test and /predict load back.
MODEL_DIR = "/tmp"
MODEL_PATH, TFIDF_PATH, ENCODERS_PATH = (
    os.path.join(MODEL_DIR, filename)
    for filename in ("logreg_model.pkl", "tfidf_vectorizer.pkl", "label_encoders.pkl")
)
# --- FastAPI App ---
# Single application instance; every endpoint below registers on it through
# the ``@app.get`` / ``@app.post`` decorators.
app = FastAPI()
# --- Input Schema ---
class TransactionData(BaseModel):
    """Schema of one screened transaction, nested in ``PredictionRequest``.

    Attribute names are the JSON keys clients send (no aliases are declared),
    including the ``Payment_Reciever_Name`` spelling, so renaming any of them
    would change the API.

    NOTE(review): ``Red_Flag_Reason``, ``Maker_Action``, ``Escalation_Level``,
    ``Risk_Category``, ``Risk_Drivers`` and ``Investigation_Outcome`` are the
    ``LABEL_COLUMNS`` the model predicts, yet they are required fields here —
    confirm whether they should be Optional for inference requests.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    # The only two optional fields; both default to an empty string.
    Payment_Sender_Name: Optional[str] = ""
    Payment_Reciever_Name: Optional[str] = ""
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    # Same name as TEXT_COLUMN, the column the TF-IDF vectorizer is trained on.
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    # Dates/timestamps are plain strings; no format is validated here.
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
class PredictionRequest(BaseModel):
    """Body of ``POST /predict``: one transaction nested under ``transaction_data``."""

    transaction_data: TransactionData
class DataPathInput(BaseModel):
    """Body shared by ``/train``, ``/validate``, ``/test`` and ``/ttv``."""

    # Handed unchanged to ``pd.read_csv`` by those endpoints.
    # NOTE(review): read_csv also accepts URLs and any server-side path; if
    # clients are untrusted, consider restricting this to an allowed directory.
    data_path: str
@app.get("/")
def health_check():
    """Report that the service process is up.

    Returns a static payload; it does not check whether a model has been
    trained or whether the artifact files exist.
    """
    # Message typo fixed ("complience" -> "compliance") and trailing space dropped.
    return {"status": "healthy", "message": "logistic regression compliance predictor API"}
@app.post("/train")
def train_model(input: DataPathInput):
    """Train the multi-output logistic-regression model and save its artifacts.

    Steps:
    1. Read the CSV at ``input.data_path``.
    2. Drop rows with a missing value in ``TEXT_COLUMN`` or in any of
       ``LABEL_COLUMNS``.
    3. Label-encode every label column.
    4. Fit a TF-IDF vectorizer (1000 features, uni+bigrams, English stop
       words) on the text column.
    5. Fit a ``MultiOutputClassifier`` of ``LogisticRegression`` on the
       encoded labels.
    6. Write the model, the vectorizer and the ``{column: LabelEncoder}``
       dict with ``joblib`` to ``MODEL_PATH``, ``TFIDF_PATH`` and
       ``ENCODERS_PATH``, replacing whatever was there.

    Args:
        input: request body carrying ``data_path``.

    Returns:
        A status dict.

    Raises:
        HTTPException: 400 if required columns are missing or no complete rows
            remain; 500 for any other failure (unreadable file, fit error, ...).
    """
    try:
        df = pd.read_csv(input.data_path)
        required_columns = [TEXT_COLUMN] + LABEL_COLUMNS
        missing = [col for col in required_columns if col not in df.columns]
        if missing:
            raise HTTPException(
                status_code=400, detail=f"Missing required columns: {missing}"
            )
        df.dropna(subset=required_columns, inplace=True)
        if df.empty:
            raise HTTPException(
                status_code=400,
                detail="No rows left after dropping rows with missing text/label values.",
            )

        label_encoders = {}
        for col in LABEL_COLUMNS:
            le = LabelEncoder()
            df[col] = le.fit_transform(df[col])
            label_encoders[col] = le

        tfidf = TfidfVectorizer(max_features=1000, ngram_range=(1, 2), stop_words="english")
        # TfidfVectorizer only accepts strings; cast so numeric cells don't crash it.
        X_vec = tfidf.fit_transform(df[TEXT_COLUMN].astype(str))
        y = df[LABEL_COLUMNS]

        model = MultiOutputClassifier(LogisticRegression(max_iter=1000))
        model.fit(X_vec, y)

        joblib.dump(model, MODEL_PATH)
        joblib.dump(tfidf, TFIDF_PATH)
        joblib.dump(label_encoders, ENCODERS_PATH)
        return {"status": "✅ Logistic Regression model trained and saved."}
    except HTTPException:
        # Already carries the intended status code; don't re-wrap it as a 500.
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/validate")
def validate_model(input: DataPathInput):
    """Check that the CSV at ``input.data_path`` has every column training needs.

    Only column presence is checked (``TEXT_COLUMN`` plus ``LABEL_COLUMNS``),
    not whether the values are usable.

    Returns:
        ``{"status": ...}``. When columns are missing, the dict also contains
        ``"missing_columns"`` with the absent names, and the response is still
        HTTP 200.

    Raises:
        HTTPException: 500 if the file cannot be read or parsed.
    """
    try:
        df = pd.read_csv(input.data_path)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Validation error: {str(e)}") from e

    required_columns = [TEXT_COLUMN] + LABEL_COLUMNS
    missing = [col for col in required_columns if col not in df.columns]
    if missing:
        return {"status": "❌ Invalid input", "missing_columns": missing}
    return {"status": "✅ Input is valid."}
@app.post("/test")
def test_model(input: DataPathInput):
    """Run the saved model on a CSV and return a preview of decoded predictions.

    Loads the artifacts written by ``/train`` and drops rows with no
    ``TEXT_COLUMN`` value. It then predicts labels for the first 5 remaining
    rows, which are the same rows the response has always contained. Each
    prediction maps every name in ``LABEL_COLUMNS`` to its decoded label.
    No accuracy or other metric is computed.

    Returns:
        ``{"predictions": [...]}``. The list is empty when no row has a text
        value.

    Raises:
        HTTPException: 500 on any failure (unreadable CSV, missing text column,
            model not yet trained, ...).
    """
    preview_rows = 5
    try:
        df = pd.read_csv(input.data_path)
        # Only the preview rows are returned, so only they are vectorized and
        # predicted. Predictions are per-row, so the result is unchanged.
        texts = df[TEXT_COLUMN].dropna().head(preview_rows).astype(str)
        if texts.empty:
            return {"predictions": []}

        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        # Shape (n_rows, len(LABEL_COLUMNS)); column i holds encoded labels for LABEL_COLUMNS[i].
        preds = model.predict(tfidf.transform(texts))
        # Decode one label column at a time. .tolist() yields plain Python values
        # that FastAPI can JSON-encode (numpy integer scalars are not).
        decoded_columns = [
            encoders[col].inverse_transform(preds[:, i]).tolist()
            for i, col in enumerate(LABEL_COLUMNS)
        ]
        decoded_preds = [dict(zip(LABEL_COLUMNS, row)) for row in zip(*decoded_columns)]
        return {"predictions": decoded_preds}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/predict")
def predict(request: PredictionRequest):
    """Predict every label in ``LABEL_COLUMNS`` for a single transaction.

    Only the ``TEXT_COLUMN`` field (``Sanction_Context``) is vectorized. It is
    the only column ``/train`` fits the TF-IDF vectorizer and model on.

    Returns:
        ``{"prediction": {label_column: decoded_label, ...}}``.

    Raises:
        HTTPException: 500 on any failure (e.g. model not yet trained).
    """
    try:
        # Joining the other fields would add tokens the model never saw in
        # training. That changes the TF-IDF normalisation and feeds the label
        # fields back in as input.
        text_input = str(getattr(request.transaction_data, TEXT_COLUMN))

        tfidf = joblib.load(TFIDF_PATH)
        model = joblib.load(MODEL_PATH)
        encoders = joblib.load(ENCODERS_PATH)

        pred = model.predict(tfidf.transform([text_input]))[0]
        # .tolist() converts numpy scalars into JSON-serializable Python values.
        decoded = {
            col: encoders[col].inverse_transform([label]).tolist()[0]
            for col, label in zip(LABEL_COLUMNS, pred)
        }
        return {"prediction": decoded}
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
@app.post("/ttv")
def train_test_validate(input: DataPathInput):
    """Validate, train and test on the same CSV in one call.

    Validation runs first, so a CSV missing required columns is rejected with a
    400 before any training starts. Errors raised by the train/test endpoints
    are propagated with their own status code and detail rather than being
    re-wrapped. ``str()`` of an ``HTTPException`` does not give back its detail.

    Returns:
        ``{"train": ..., "validate": <validate response>, "test": <test response>}``.

    Raises:
        HTTPException: 400 if validation reports missing columns; otherwise
            whatever the called endpoints raise (500 for unexpected errors).
    """
    validate_result = validate_model(input)
    if "missing_columns" in validate_result:
        raise HTTPException(status_code=400, detail=validate_result)

    try:
        train_model(input)
        test_result = test_model(input)
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {
        "train": "✅ Done",
        "validate": validate_result,
        "test": test_result,
    }