# LOGREG_TTV / train_test_validate.py
# Uploaded by ganeshkonapalli — commit c778365 ("Create train_test_validate.py", verified)
import os
import json
import joblib
import requests
import pandas as pd
from typing import List
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder
from sklearn.linear_model import LogisticRegression
from pydantic import BaseModel, ValidationError
import argparse
# --- CONFIG ---
# Settings shared by training (--train), schema validation (--validate) and
# the API smoke test (--test).
DATA_PATH = "data.csv"
# Free-text column used as the only model input.
TEXT_COLUMN = "Sanction_Context"
# Target columns; each is label-encoded and predicted by its own classifier.
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
MODEL_SAVE_DIR = "models"
LABEL_ENCODERS_PATH = os.path.join(MODEL_SAVE_DIR, "label_encoders.pkl")
# TF-IDF vectorizer settings.
TFIDF_MAX_FEATURES = 1000
NGRAM_RANGE = (1, 2)
USE_STOPWORDS = True
# Seed for both the train/test split and LogisticRegression.
RANDOM_STATE = 42
# Fraction of rows held out for evaluation.
TEST_SIZE = 0.2
# The default is a placeholder; set the API_URL environment variable to point
# --test at the real deployed endpoint without editing this file.
API_URL = os.environ.get("API_URL", "https://your-hf-api-url.hf.space/predict")
# Created at import time so the save paths above are always writable.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# --- Pydantic schema ---
class TransactionData(BaseModel):
    """Pydantic schema for one transaction record.

    Every field is required (none declares a default). Constructing the model
    with bad or missing values raises pydantic's ValidationError;
    validate_sample_input relies on that to check
    sample_payload["transaction_data"].
    """

    # Transaction / screening-hit identifiers and hit metadata.
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: str
    # NOTE: "Reciever" is misspelled, but the sample payload uses the same key.
    # Renaming the field would change which input key the schema accepts.
    Payment_Reciever_Name: str
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int  # presumably a 0/1 flag (sample uses 1) — TODO confirm
    Red_Flag_Reason: str  # also a training target (in LABEL_COLUMNS)
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    # Dates are typed as plain strings; this schema does not parse or check
    # their format.
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str  # model input text (TEXT_COLUMN)
    Maker_Action: str  # training target (in LABEL_COLUMNS)
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str  # training target (in LABEL_COLUMNS)
    Risk_Drivers: str  # training target (in LABEL_COLUMNS)
    Alert_Status: str
    Investigation_Outcome: str  # training target (in LABEL_COLUMNS)
    Case_Owner_Analyst: str
    Escalation_Level: str  # training target (in LABEL_COLUMNS)
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
# --- Train function ---
def _load_encoded_dataset():
    """Load DATA_PATH, drop incomplete rows and label-encode every target column.

    Rows missing TEXT_COLUMN or any of LABEL_COLUMNS are dropped.

    Returns:
        tuple: ``(df, label_encoders)``. ``df`` is the cleaned DataFrame, with
        LABEL_COLUMNS replaced by integer codes. ``label_encoders`` maps each
        label column name to the LabelEncoder fitted on it.

    Raises:
        ValueError: if no rows remain after dropping incomplete ones.
    """
    df = pd.read_csv(DATA_PATH)
    df.dropna(subset=[TEXT_COLUMN] + LABEL_COLUMNS, inplace=True)
    if df.empty:
        # Fail with a clear message instead of an opaque error from the split.
        raise ValueError(
            f"No usable rows in {DATA_PATH!r}: every row is missing "
            f"{TEXT_COLUMN} or one of {LABEL_COLUMNS}"
        )
    label_encoders = {}
    for col in LABEL_COLUMNS:
        encoder = LabelEncoder()
        df[col] = encoder.fit_transform(df[col])
        label_encoders[col] = encoder
    return df, label_encoders


def _build_pipeline():
    """Return an unfitted TF-IDF -> MultiOutputClassifier(LogisticRegression) pipeline.

    MultiOutputClassifier fits a separate LogisticRegression for each target
    column.
    """
    stop_words = "english" if USE_STOPWORDS else None
    return Pipeline([
        ("tfidf", TfidfVectorizer(
            max_features=TFIDF_MAX_FEATURES,
            ngram_range=NGRAM_RANGE,
            stop_words=stop_words,
        )),
        ("clf", MultiOutputClassifier(
            LogisticRegression(random_state=RANDOM_STATE, max_iter=1000)
        )),
    ])


def _evaluate_on_test(pipeline, X_test, y_test):
    """Print and return the per-label accuracy of *pipeline* on the held-out rows.

    MultiOutputClassifier.predict returns one column per target, in the same
    order as the columns of the y it was fitted on. That is also the column
    order of *y_test*.

    Returns:
        dict: label column name -> accuracy in [0, 1].
    """
    predictions = pipeline.predict(X_test)
    scores = {}
    for idx, col in enumerate(y_test.columns):
        scores[col] = float((predictions[:, idx] == y_test[col].to_numpy()).mean())
        print(f"   {col}: accuracy = {scores[col]:.3f}")
    return scores


def train_pipeline():
    """Train, evaluate and save the multi-label text classifier.

    Steps:
        1. Load DATA_PATH.
        2. Hold out TEST_SIZE of the rows, seeded by RANDOM_STATE.
        3. Fit the TF-IDF + LogisticRegression pipeline on the remaining rows.
        4. Print per-label accuracy on the held-out rows.
        5. Write three files under MODEL_SAVE_DIR: the pipeline
           (``logreg_model.pkl``), the label encoders (LABEL_ENCODERS_PATH) and
           the fitted vectorizer (``tfidf_vectorizer.pkl``).

    Returns:
        Pipeline: the fitted pipeline. Existing callers may ignore it.

    Raises:
        ValueError: if the dataset has no complete rows.
    """
    print("📥 Loading dataset...")
    df, label_encoders = _load_encoded_dataset()
    X = df[TEXT_COLUMN]
    Y = df[LABEL_COLUMNS]

    print("✂️ Splitting train/test...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    print("🔧 Building pipeline with Logistic Regression...")
    pipeline = _build_pipeline()

    print("🏋️ Training...")
    pipeline.fit(X_train, y_train)

    # The held-out split exists to measure generalization, so score it
    # instead of discarding it.
    print("📊 Evaluating on test split...")
    _evaluate_on_test(pipeline, X_test, y_test)

    # Recreate the output directory in case it was removed after import.
    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
    model_path = os.path.join(MODEL_SAVE_DIR, "logreg_model.pkl")
    print(f"💾 Saving model to {model_path}")
    joblib.dump(pipeline, model_path)
    print(f"💾 Saving label encoders to {LABEL_ENCODERS_PATH}")
    joblib.dump(label_encoders, LABEL_ENCODERS_PATH)
    # The vectorizer is also saved on its own, although the pipeline saved
    # above already contains it.
    tfidf_path = os.path.join(MODEL_SAVE_DIR, "tfidf_vectorizer.pkl")
    joblib.dump(pipeline.named_steps["tfidf"], tfidf_path)
    print("✅ Training complete.")
    return pipeline
# --- Input Validator ---
def validate_sample_input(sample_input):
    """Validate *sample_input* against TransactionData and print the outcome.

    Args:
        sample_input: Mapping of field name to value. It is unpacked as
            keyword arguments into TransactionData.

    Returns:
        The validated TransactionData instance, or None if validation failed.
        On failure the pydantic error report is printed as indented JSON.
    """
    try:
        validated = TransactionData(**sample_input)
    except ValidationError as e:
        print("❌ Validation error:")
        print(e.json(indent=2))
        return None
    print("✅ Input is valid.")
    return validated
# --- API Test ---
def test_api(sample_payload, url=None, timeout=30):
    """POST *sample_payload* as JSON to the prediction API and print the reply.

    Args:
        sample_payload: JSON-serializable request body.
        url: Endpoint to call. None means the module-level API_URL.
        timeout: Seconds ``requests`` waits for the connection and the
            response. Without it, an unreachable endpoint could hang the
            script forever.

    Returns:
        The ``requests.Response``. Network errors, including timeouts, are not
        caught and propagate as ``requests`` exceptions.
    """
    target = API_URL if url is None else url
    headers = {"Content-Type": "application/json"}
    print(f"🚀 Posting to {target}")
    response = requests.post(
        target, headers=headers, data=json.dumps(sample_payload), timeout=timeout
    )
    print("📥 Status Code:", response.status_code)
    try:
        body = response.json()
    except ValueError as e:
        # requests raises a ValueError subclass when the body is not valid JSON.
        print("❌ Failed to parse response:", str(e))
    else:
        print("📤 Response:", json.dumps(body, indent=2))
    return response
# --- Sample Payload ---
# Example request used by the CLI:
#   --validate checks the "transaction_data" entry against TransactionData.
#   --test POSTs the whole dict to API_URL.
# The inner dict has one key per TransactionData field, in the same order.
# The outer "transaction_data" wrapper is presumably the body shape the
# prediction API expects — TODO confirm against the service.
sample_payload = {
    "transaction_data": {
        "Transaction_Id": "TXN12345",
        "Hit_Seq": 1,
        "Hit_Id_List": "HIT789",
        "Origin": "India",
        "Designation": "Manager",
        "Keywords": "fraud",
        "Name": "John Doe",
        "SWIFT_Tag": "TAG001",
        "Currency": "INR",
        "Entity": "ABC Ltd",
        "Message": "Payment for services",
        "City": "Hyderabad",
        "Country": "India",
        "State": "Telangana",
        "Hit_Type": "Individual",
        "Record_Matching_String": "John Doe",
        "WatchList_Match_String": "Doe, John",
        "Payment_Sender_Name": "John Doe",
        # Spelling matches the TransactionData field name.
        "Payment_Reciever_Name": "Jane Smith",
        "Swift_Message_Type": "MT103",
        "Text_Sanction_Data": "Suspicious transfer to offshore account",
        "Matched_Sanctioned_Entity": "John Doe",
        "Is_Match": 1,
        "Red_Flag_Reason": "High value transaction",
        "Risk_Level": "High",
        "Risk_Score": 87.5,
        "Risk_Score_Description": "Very High",
        "CDD_Level": "Enhanced",
        "PEP_Status": "Yes",
        "Value_Date": "2023-01-01",
        "Last_Review_Date": "2023-06-01",
        "Next_Review_Date": "2024-06-01",
        "Sanction_Description": "OFAC List",
        "Checker_Notes": "Urgent check required",
        # Model input text (TEXT_COLUMN).
        "Sanction_Context": "Payment matched with OFAC entry",
        "Maker_Action": "Escalate",
        "Customer_ID": 1001,
        "Customer_Type": "Corporate",
        "Industry": "Finance",
        "Transaction_Date_Time": "2023-12-15T10:00:00",
        "Transaction_Type": "Credit",
        "Transaction_Channel": "Online",
        "Originating_Bank": "ABC Bank",
        "Beneficiary_Bank": "XYZ Bank",
        "Geographic_Origin": "India",
        "Geographic_Destination": "USA",
        "Match_Score": 96.2,
        "Match_Type": "Exact",
        "Sanctions_List_Version": "2023-V5",
        "Screening_Date_Time": "2023-12-15T09:55:00",
        "Risk_Category": "Sanctions",
        "Risk_Drivers": "PEP, High Value",
        "Alert_Status": "Open",
        "Investigation_Outcome": "Pending",
        "Case_Owner_Analyst": "analyst1",
        "Escalation_Level": "L2",
        "Escalation_Date": "2023-12-16",
        "Regulatory_Reporting_Flags": True,
        "Audit_Trail_Timestamp": "2023-12-15T10:05:00",
        "Source_Of_Funds": "Corporate Account",
        "Purpose_Of_Transaction": "Service Payment",
        "Beneficial_Owner": "John Doe",
        "Sanctions_Exposure_History": False
    }
}
# --- Main Entry ---
def main(argv=None):
    """Parse command-line flags and run each requested action in order.

    The actions run in this order: train, then validate, then test.
    If no action flag is given, print the usage text instead of silently
    doing nothing.

    Args:
        argv: Arguments to parse. None means ``sys.argv[1:]``, which is the
            argparse default.
    """
    parser = argparse.ArgumentParser(
        description="Train the logistic-regression model, validate the sample "
                    "input, and/or test the prediction API."
    )
    parser.add_argument("--train", action="store_true", help="Train the model")
    parser.add_argument("--validate", action="store_true", help="Validate sample input")
    parser.add_argument("--test", action="store_true", help="Test prediction API")
    args = parser.parse_args(argv)
    if not (args.train or args.validate or args.test):
        parser.print_help()
        return
    if args.train:
        train_pipeline()
    if args.validate:
        validate_sample_input(sample_payload["transaction_data"])
    if args.test:
        test_api(sample_payload)


if __name__ == "__main__":
    main()