"""Train, validate and smoke-test a multi-output sanction-screening classifier.

The script offers three independent command-line actions:

* ``--train``     fit a TF-IDF + logistic-regression pipeline on ``DATA_PATH``
                  and persist the model, label encoders and vectorizer.
* ``--validate``  check the bundled sample transaction against the
                  ``TransactionData`` schema.
* ``--test``      POST the bundled sample payload to ``API_URL``.
"""

import argparse
import json
import os
from typing import List

import joblib
import pandas as pd
import requests
from pydantic import BaseModel, ValidationError
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split
from sklearn.multioutput import MultiOutputClassifier
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import LabelEncoder

# --- CONFIG ---
DATA_PATH = "data.csv"
TEXT_COLUMN = "Sanction_Context"
LABEL_COLUMNS = [
    "Red_Flag_Reason",
    "Maker_Action",
    "Escalation_Level",
    "Risk_Category",
    "Risk_Drivers",
    "Investigation_Outcome",
]
MODEL_SAVE_DIR = "models"
LABEL_ENCODERS_PATH = os.path.join(MODEL_SAVE_DIR, "label_encoders.pkl")
TFIDF_MAX_FEATURES = 1000
NGRAM_RANGE = (1, 2)
USE_STOPWORDS = True
RANDOM_STATE = 42
TEST_SIZE = 0.2
API_URL = "https://your-hf-api-url.hf.space/predict"  # Replace with actual URL

# Created at import time so every later save can assume the directory exists.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)


# --- Pydantic schema ---
class TransactionData(BaseModel):
    """Schema of a single screened transaction record.

    Every field is required. Field names are used verbatim as payload keys,
    so their spelling must not be changed.
    """

    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: str
    Payment_Reciever_Name: str
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int
    Red_Flag_Reason: str
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str
    Maker_Action: str
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str
    Risk_Drivers: str
    Alert_Status: str
    Investigation_Outcome: str
    Case_Owner_Analyst: str
    Escalation_Level: str
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool


# --- Train function ---
def train_pipeline() -> None:
    """Fit the text classifier on ``DATA_PATH`` and persist its artifacts.

    Rows missing the text column or any label column are dropped. Each label
    column is integer-encoded with its own ``LabelEncoder``, a TF-IDF +
    multi-output logistic-regression pipeline is fitted on the training
    split, and the pipeline, the encoders and the fitted vectorizer are
    written under ``MODEL_SAVE_DIR``. Per-label accuracy on the held-out
    split is printed afterwards.

    Raises:
        ValueError: If no usable rows remain after dropping incomplete ones.
    """
    print("📥 Loading dataset...")
    df = pd.read_csv(DATA_PATH)
    df.dropna(subset=[TEXT_COLUMN] + LABEL_COLUMNS, inplace=True)
    if df.empty:
        raise ValueError(
            f"No rows in {DATA_PATH} have values for {TEXT_COLUMN} and all "
            "label columns; nothing to train on."
        )

    # Encode into a fresh frame instead of overwriting columns of ``df``.
    label_encoders = {}
    encoded_labels = {}
    for col in LABEL_COLUMNS:
        le = LabelEncoder()
        encoded_labels[col] = le.fit_transform(df[col])
        label_encoders[col] = le

    X = df[TEXT_COLUMN]
    Y = pd.DataFrame(encoded_labels, index=df.index)

    print("✂️ Splitting train/test...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    print("🔧 Building pipeline with Logistic Regression...")
    stop_words = "english" if USE_STOPWORDS else None
    pipeline = Pipeline([
        (
            'tfidf',
            TfidfVectorizer(
                max_features=TFIDF_MAX_FEATURES,
                ngram_range=NGRAM_RANGE,
                stop_words=stop_words,
            ),
        ),
        (
            'clf',
            MultiOutputClassifier(
                LogisticRegression(random_state=RANDOM_STATE, max_iter=1000)
            ),
        ),
    ])

    print("🏋️ Training...")
    pipeline.fit(X_train, y_train)

    model_path = os.path.join(MODEL_SAVE_DIR, "logreg_model.pkl")
    print(f"💾 Saving model to {model_path}")
    joblib.dump(pipeline, model_path)

    print(f"💾 Saving label encoders to {LABEL_ENCODERS_PATH}")
    joblib.dump(label_encoders, LABEL_ENCODERS_PATH)

    # The vectorizer is also stored on its own, separately from the pipeline.
    tfidf_path = os.path.join(MODEL_SAVE_DIR, "tfidf_vectorizer.pkl")
    joblib.dump(pipeline.named_steps["tfidf"], tfidf_path)

    # Use the held-out split, which was previously computed but never used.
    # Runs after saving so an evaluation failure cannot lose the artifacts.
    print("📊 Evaluating on held-out split...")
    predictions = pipeline.predict(X_test)
    for idx, col in enumerate(LABEL_COLUMNS):
        accuracy = (predictions[:, idx] == y_test[col].to_numpy()).mean()
        print(f"   {col}: accuracy={accuracy:.3f}")

    print("✅ Training complete.")


# --- Input Validator ---
def validate_sample_input(sample_input: dict) -> bool:
    """Validate ``sample_input`` against ``TransactionData`` and report it.

    Args:
        sample_input: Mapping of field name to value for one transaction.

    Returns:
        ``True`` when the input satisfies the schema, ``False`` otherwise.
        Validation errors are printed as JSON rather than raised.
    """
    try:
        TransactionData(**sample_input)
    except ValidationError as e:
        print("❌ Validation error:")
        print(e.json(indent=2))
        return False
    print("✅ Input is valid.")
    return True


# --- API Test ---
def test_api(sample_payload: dict, timeout: float = 30.0) -> None:
    """POST ``sample_payload`` as JSON to ``API_URL`` and print the reply.

    Args:
        sample_payload: JSON-serialisable request body.
        timeout: Seconds to wait for the server before giving up, so an
            unresponsive endpoint cannot hang the script indefinitely.

    Raises:
        requests.RequestException: On connection failures or timeouts.
    """
    headers = {"Content-Type": "application/json"}
    print(f"🚀 Posting to {API_URL}")
    response = requests.post(
        API_URL,
        headers=headers,
        data=json.dumps(sample_payload),
        timeout=timeout,
    )
    print("📥 Status Code:", response.status_code)
    try:
        body = response.json()
    except ValueError as e:
        # JSON decoding errors from requests are ValueError subclasses.
        print("❌ Failed to parse response:", str(e))
    else:
        print("📤 Response:", json.dumps(body, indent=2))


# --- Sample Payload (unchanged) ---
sample_payload = {
    "transaction_data": {
        "Transaction_Id": "TXN12345",
        "Hit_Seq": 1,
        "Hit_Id_List": "HIT789",
        "Origin": "India",
        "Designation": "Manager",
        "Keywords": "fraud",
        "Name": "John Doe",
        "SWIFT_Tag": "TAG001",
        "Currency": "INR",
        "Entity": "ABC Ltd",
        "Message": "Payment for services",
        "City": "Hyderabad",
        "Country": "India",
        "State": "Telangana",
        "Hit_Type": "Individual",
        "Record_Matching_String": "John Doe",
        "WatchList_Match_String": "Doe, John",
        "Payment_Sender_Name": "John Doe",
        "Payment_Reciever_Name": "Jane Smith",
        "Swift_Message_Type": "MT103",
        "Text_Sanction_Data": "Suspicious transfer to offshore account",
        "Matched_Sanctioned_Entity": "John Doe",
        "Is_Match": 1,
        "Red_Flag_Reason": "High value transaction",
        "Risk_Level": "High",
        "Risk_Score": 87.5,
        "Risk_Score_Description": "Very High",
        "CDD_Level": "Enhanced",
        "PEP_Status": "Yes",
        "Value_Date": "2023-01-01",
        "Last_Review_Date": "2023-06-01",
        "Next_Review_Date": "2024-06-01",
        "Sanction_Description": "OFAC List",
        "Checker_Notes": "Urgent check required",
        "Sanction_Context": "Payment matched with OFAC entry",
        "Maker_Action": "Escalate",
        "Customer_ID": 1001,
        "Customer_Type": "Corporate",
        "Industry": "Finance",
        "Transaction_Date_Time": "2023-12-15T10:00:00",
        "Transaction_Type": "Credit",
        "Transaction_Channel": "Online",
        "Originating_Bank": "ABC Bank",
        "Beneficiary_Bank": "XYZ Bank",
        "Geographic_Origin": "India",
        "Geographic_Destination": "USA",
        "Match_Score": 96.2,
        "Match_Type": "Exact",
        "Sanctions_List_Version": "2023-V5",
        "Screening_Date_Time": "2023-12-15T09:55:00",
        "Risk_Category": "Sanctions",
        "Risk_Drivers": "PEP, High Value",
        "Alert_Status": "Open",
        "Investigation_Outcome": "Pending",
        "Case_Owner_Analyst": "analyst1",
        "Escalation_Level": "L2",
        "Escalation_Date": "2023-12-16",
        "Regulatory_Reporting_Flags": True,
        "Audit_Trail_Timestamp": "2023-12-15T10:05:00",
        "Source_Of_Funds": "Corporate Account",
        "Purpose_Of_Transaction": "Service Payment",
        "Beneficial_Owner": "John Doe",
        "Sanctions_Exposure_History": False,
    }
}


# --- Main Entry ---
def main() -> None:
    """Parse command-line flags and run each requested action in order."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--train", action="store_true", help="Train the model")
    parser.add_argument(
        "--validate", action="store_true", help="Validate sample input"
    )
    parser.add_argument(
        "--test", action="store_true", help="Test prediction API"
    )
    args = parser.parse_args()

    if args.train:
        train_pipeline()
    if args.validate:
        # The schema describes the inner record, not the request envelope.
        validate_sample_input(sample_payload["transaction_data"])
    if args.test:
        test_api(sample_payload)


if __name__ == "__main__":
    main()