Spaces:
Sleeping
Sleeping
import os | |
import json | |
import joblib | |
import requests | |
import pandas as pd | |
from typing import List | |
from sklearn.model_selection import train_test_split | |
from sklearn.feature_extraction.text import TfidfVectorizer | |
from sklearn.multioutput import MultiOutputClassifier | |
from sklearn.pipeline import Pipeline | |
from sklearn.preprocessing import LabelEncoder | |
from sklearn.linear_model import LogisticRegression | |
from pydantic import BaseModel, ValidationError | |
import argparse | |
# --- CONFIG ---
# Training data and target columns.
DATA_PATH = "data.csv"
TEXT_COLUMN = "Sanction_Context"  # free-text column fed to the TF-IDF vectorizer
LABEL_COLUMNS = [  # each column is label-encoded and predicted separately
    "Red_Flag_Reason", "Maker_Action", "Escalation_Level",
    "Risk_Category", "Risk_Drivers", "Investigation_Outcome",
]

# Where trained artifacts are written.
MODEL_SAVE_DIR = "models"
LABEL_ENCODERS_PATH = os.path.join(MODEL_SAVE_DIR, "label_encoders.pkl")

# Vectorizer / model hyperparameters.
TFIDF_MAX_FEATURES = 1000
NGRAM_RANGE = (1, 2)  # unigrams and bigrams
USE_STOPWORDS = True  # True -> scikit-learn's built-in "english" stop-word list
RANDOM_STATE = 42
TEST_SIZE = 0.2  # fraction of rows held out as the test split

# Prediction endpoint used by the API smoke test. Set the API_URL environment
# variable to target the deployed service instead of editing this file; the
# default is only a placeholder.
API_URL = os.environ.get("API_URL", "https://your-hf-api-url.hf.space/predict")

# Created at import time so artifact writes do not fail on a missing folder.
os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
# --- Pydantic schema ---
class TransactionData(BaseModel):
    """Flat record schema for one screening hit and its transaction.

    Every field is required (none declares a default). Constructing the
    model raises pydantic's ``ValidationError`` when a field is missing or
    its value cannot be validated as the annotated type.

    Field names, including the ``Payment_Reciever_Name`` spelling, match
    the keys of the module's sample payload record, so keep the two in
    sync. ``Sanction_Context`` is the training text column, and
    ``Red_Flag_Reason``, ``Maker_Action``, ``Escalation_Level``,
    ``Risk_Category``, ``Risk_Drivers`` and ``Investigation_Outcome`` are
    the training label columns (see the CONFIG section).

    Date/time fields are typed ``str``, so no date parsing happens here.
    The sample payload uses ISO-8601 strings. TODO(review): confirm
    producers always send that format.
    """
    Transaction_Id: str
    Hit_Seq: int
    Hit_Id_List: str  # a single string, not a list -- NOTE(review): confirm delimiter/format
    Origin: str
    Designation: str
    Keywords: str
    Name: str
    SWIFT_Tag: str
    Currency: str
    Entity: str
    Message: str
    City: str
    Country: str
    State: str
    Hit_Type: str
    Record_Matching_String: str
    WatchList_Match_String: str
    Payment_Sender_Name: str
    Payment_Reciever_Name: str  # sic ("Reciever"); renaming would break existing payloads
    Swift_Message_Type: str
    Text_Sanction_Data: str
    Matched_Sanctioned_Entity: str
    Is_Match: int  # sample uses 1; presumably a 0/1 flag -- confirm
    Red_Flag_Reason: str  # training label column
    Risk_Level: str
    Risk_Score: float
    Risk_Score_Description: str
    CDD_Level: str
    PEP_Status: str
    Value_Date: str
    Last_Review_Date: str
    Next_Review_Date: str
    Sanction_Description: str
    Checker_Notes: str
    Sanction_Context: str  # training text column
    Maker_Action: str  # training label column
    Customer_ID: int
    Customer_Type: str
    Industry: str
    Transaction_Date_Time: str
    Transaction_Type: str
    Transaction_Channel: str
    Originating_Bank: str
    Beneficiary_Bank: str
    Geographic_Origin: str
    Geographic_Destination: str
    Match_Score: float
    Match_Type: str
    Sanctions_List_Version: str
    Screening_Date_Time: str
    Risk_Category: str  # training label column
    Risk_Drivers: str  # training label column; sample packs several drivers into one comma-separated string
    Alert_Status: str
    Investigation_Outcome: str  # training label column
    Case_Owner_Analyst: str
    Escalation_Level: str  # training label column
    Escalation_Date: str
    Regulatory_Reporting_Flags: bool  # plural name but a single bool -- NOTE(review): confirm intent
    Audit_Trail_Timestamp: str
    Source_Of_Funds: str
    Purpose_Of_Transaction: str
    Beneficial_Owner: str
    Sanctions_Exposure_History: bool
# --- Train function ---
def train_pipeline():
    """Train the TF-IDF + logistic-regression model and save its artifacts.

    Steps:
      1. Load ``DATA_PATH`` and drop rows missing ``TEXT_COLUMN`` or any
         column in ``LABEL_COLUMNS``.
      2. Replace each label column with integer codes from its own
         ``LabelEncoder``.
      3. Split rows into train/test (``TEST_SIZE``, ``RANDOM_STATE``).
      4. Fit a ``TfidfVectorizer`` -> ``MultiOutputClassifier(LogisticRegression)``
         pipeline (one classifier per label column) on the training rows.
      5. Print per-label accuracy on the held-out test rows.
      6. Write to ``MODEL_SAVE_DIR``:
         - the fitted pipeline (``logreg_model.pkl``);
         - the dict of encoders keyed by label column
           (``LABEL_ENCODERS_PATH``);
         - the fitted vectorizer on its own (``tfidf_vectorizer.pkl``).

    Raises:
        ValueError: if no rows survive the missing-value filter, or if a
            label column has fewer than two distinct classes (a logistic
            regression cannot be fit on a single class).
    """
    # NOTE(review): the status-message prefixes below look like mis-encoded
    # emoji; they are left unchanged.
    print("π₯ Loading dataset...")
    df = pd.read_csv(DATA_PATH)
    required_columns = [TEXT_COLUMN] + LABEL_COLUMNS
    df.dropna(subset=required_columns, inplace=True)
    if df.empty:
        raise ValueError(
            f"No rows left in {DATA_PATH!r} after dropping rows with missing "
            f"values in {required_columns}"
        )

    label_encoders = {}
    for col in LABEL_COLUMNS:
        le = LabelEncoder()
        df[col] = le.fit_transform(df[col])
        # Fail early with a clear message instead of deep inside LogisticRegression.fit.
        if len(le.classes_) < 2:
            raise ValueError(
                f"Label column {col!r} has only one class ({le.classes_.tolist()}); "
                "at least two are needed to train a classifier"
            )
        label_encoders[col] = le

    X = df[TEXT_COLUMN]
    Y = df[LABEL_COLUMNS]

    print("βοΈ Splitting train/test...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, Y, test_size=TEST_SIZE, random_state=RANDOM_STATE
    )

    print("π§ Building pipeline with Logistic Regression...")
    stop_words = "english" if USE_STOPWORDS else None
    pipeline = Pipeline([
        ('tfidf', TfidfVectorizer(max_features=TFIDF_MAX_FEATURES, ngram_range=NGRAM_RANGE, stop_words=stop_words)),
        ('clf', MultiOutputClassifier(LogisticRegression(random_state=RANDOM_STATE, max_iter=1000)))
    ])

    print("ποΈ Training...")
    pipeline.fit(X_train, y_train)

    # Use the held-out split; previously it was computed and never looked at.
    print(f"Evaluating on {len(X_test)} held-out rows...")
    y_pred = pipeline.predict(X_test)  # shape (n_test, len(LABEL_COLUMNS))
    for idx, col in enumerate(LABEL_COLUMNS):
        accuracy = (y_pred[:, idx] == y_test[col].to_numpy()).mean()
        print(f"   {col}: accuracy = {accuracy:.3f}")

    # The directory is also created at import time; repeat here in case it was removed since.
    os.makedirs(MODEL_SAVE_DIR, exist_ok=True)
    model_path = os.path.join(MODEL_SAVE_DIR, "logreg_model.pkl")
    print(f"πΎ Saving model to {model_path}")
    joblib.dump(pipeline, model_path)
    print(f"πΎ Saving label encoders to {LABEL_ENCODERS_PATH}")
    joblib.dump(label_encoders, LABEL_ENCODERS_PATH)
    # Standalone copy of the fitted vectorizer (it is also inside the saved pipeline).
    tfidf_path = os.path.join(MODEL_SAVE_DIR, "tfidf_vectorizer.pkl")
    joblib.dump(pipeline.named_steps["tfidf"], tfidf_path)
    print("β Training complete.")
# --- Input Validator ---
def validate_sample_input(sample_input):
    """Validate a raw record against ``TransactionData`` and report the result.

    On success, prints a confirmation message. On failure, prints the
    pydantic validation errors as indented JSON.

    Args:
        sample_input: mapping of field name -> value, unpacked as keyword
            arguments into ``TransactionData``.

    Returns:
        The validated ``TransactionData`` instance, or ``None`` if
        validation failed.
    """
    try:
        validated = TransactionData(**sample_input)
    except ValidationError as e:
        print("β Validation error:")
        print(e.json(indent=2))
        return None
    print("β Input is valid.")
    return validated
# --- API Test ---
def test_api(sample_payload, url=None, timeout=30):
    """POST ``sample_payload`` as JSON to the prediction API and print the reply.

    Args:
        sample_payload: JSON-serializable object sent as the request body.
        url: endpoint to call; defaults to the module-level ``API_URL``.
        timeout: seconds to wait for the server, passed to ``requests.post``.
            Without one, ``requests`` can block indefinitely.

    Returns:
        The ``requests.Response``, or ``None`` if the request itself failed
        (connection error, timeout, and so on).
    """
    target = API_URL if url is None else url
    headers = {"Content-Type": "application/json"}
    print(f"π Posting to {target}")
    try:
        response = requests.post(
            target, headers=headers, data=json.dumps(sample_payload), timeout=timeout
        )
    except requests.RequestException as e:
        print("Request failed:", str(e))
        return None
    print("π₯ Status Code:", response.status_code)
    try:
        body = response.json()
    except ValueError as e:  # requests' JSON decode errors subclass ValueError
        print("β Failed to parse response:", str(e))
    else:
        print("π€ Response:", json.dumps(body, indent=2))
    return response
# --- Sample payload ---
# One complete record whose keys mirror the TransactionData schema, wrapped
# under a "transaction_data" key. The --validate action checks the inner
# record; the --test action posts the whole wrapper to the API.
_sample_record = dict(
    Transaction_Id="TXN12345",
    Hit_Seq=1,
    Hit_Id_List="HIT789",
    Origin="India",
    Designation="Manager",
    Keywords="fraud",
    Name="John Doe",
    SWIFT_Tag="TAG001",
    Currency="INR",
    Entity="ABC Ltd",
    Message="Payment for services",
    City="Hyderabad",
    Country="India",
    State="Telangana",
    Hit_Type="Individual",
    Record_Matching_String="John Doe",
    WatchList_Match_String="Doe, John",
    Payment_Sender_Name="John Doe",
    Payment_Reciever_Name="Jane Smith",
    Swift_Message_Type="MT103",
    Text_Sanction_Data="Suspicious transfer to offshore account",
    Matched_Sanctioned_Entity="John Doe",
    Is_Match=1,
    Red_Flag_Reason="High value transaction",
    Risk_Level="High",
    Risk_Score=87.5,
    Risk_Score_Description="Very High",
    CDD_Level="Enhanced",
    PEP_Status="Yes",
    Value_Date="2023-01-01",
    Last_Review_Date="2023-06-01",
    Next_Review_Date="2024-06-01",
    Sanction_Description="OFAC List",
    Checker_Notes="Urgent check required",
    Sanction_Context="Payment matched with OFAC entry",
    Maker_Action="Escalate",
    Customer_ID=1001,
    Customer_Type="Corporate",
    Industry="Finance",
    Transaction_Date_Time="2023-12-15T10:00:00",
    Transaction_Type="Credit",
    Transaction_Channel="Online",
    Originating_Bank="ABC Bank",
    Beneficiary_Bank="XYZ Bank",
    Geographic_Origin="India",
    Geographic_Destination="USA",
    Match_Score=96.2,
    Match_Type="Exact",
    Sanctions_List_Version="2023-V5",
    Screening_Date_Time="2023-12-15T09:55:00",
    Risk_Category="Sanctions",
    Risk_Drivers="PEP, High Value",
    Alert_Status="Open",
    Investigation_Outcome="Pending",
    Case_Owner_Analyst="analyst1",
    Escalation_Level="L2",
    Escalation_Date="2023-12-16",
    Regulatory_Reporting_Flags=True,
    Audit_Trail_Timestamp="2023-12-15T10:05:00",
    Source_Of_Funds="Corporate Account",
    Purpose_Of_Transaction="Service Payment",
    Beneficial_Owner="John Doe",
    Sanctions_Exposure_History=False,
)
sample_payload = {"transaction_data": _sample_record}
# --- Main Entry ---
def main(argv=None):
    """Parse command-line flags and run the selected actions.

    Selected actions run in the fixed order train -> validate -> test, so a
    single invocation can combine them. With no flag, prints usage help
    instead of silently exiting.

    Args:
        argv: argument list to parse; ``None`` means ``sys.argv[1:]``.
    """
    parser = argparse.ArgumentParser(
        description="Train the model, validate the sample input, or test the prediction API."
    )
    parser.add_argument("--train", action="store_true", help="Train the model")
    parser.add_argument("--validate", action="store_true", help="Validate sample input")
    parser.add_argument("--test", action="store_true", help="Test prediction API")
    args = parser.parse_args(argv)

    if not (args.train or args.validate or args.test):
        parser.print_help()
        return

    if args.train:
        train_pipeline()
    if args.validate:
        validate_sample_input(sample_payload["transaction_data"])
    if args.test:
        test_api(sample_payload)


if __name__ == "__main__":
    main()