| | import os |
| | import sys |
| | import argparse |
| | import numpy as np |
| | import tensorflow as tf |
| | from sklearn.model_selection import train_test_split |
| |
|
# Put the parent of this script's directory on the import path; presumably
# this is the project root that holds the `src` package imported below.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(_SCRIPT_DIR))
| |
|
| | from src.model import create_malconv_model |
| | from src.utils import ( |
| | configure_gpu_memory, |
| | plot_training_history, |
| | evaluate_model, |
| | get_file_paths_and_labels, |
| | data_generator, |
| | read_binary_file |
| | ) |
| |
|
| | def train_malconv(data_source, |
| | epochs=10, |
| | batch_size=256, |
| | max_length=2_000_000, |
| | validation_split=0.2, |
| | save_path="models/malconv_model.h5"): |
| | """ |
| | MalConv ๋ชจ๋ธ ํ๋ จ (๋ฐ์ดํฐ ์ ๋๋ ์ดํฐ ์ฌ์ฉ) |
| | |
| | Args: |
| | data_source: (malware_dir, benign_dir) ํํ |
| | epochs: ํ๋ จ ์ํฌํฌ ์ |
| | batch_size: ๋ฐฐ์น ํฌ๊ธฐ |
| | max_length: ์ต๋ ์
๋ ฅ ๊ธธ์ด (2MB) |
| | validation_split: ๊ฒ์ฆ ๋ฐ์ดํฐ ๋น์จ |
| | save_path: ๋ชจ๋ธ ์ ์ฅ ๊ฒฝ๋ก |
| | """ |
| | |
| | print("=" * 60) |
| | print("MalConv ๋ชจ๋ธ ํ๋ จ ์์ (๋ฐ์ดํฐ ์ ๋๋ ์ดํฐ ๋ชจ๋)") |
| | print("=" * 60) |
| | |
| | |
| | configure_gpu_memory() |
| | |
| | |
| | if isinstance(data_source, tuple) and len(data_source) == 2: |
| | malware_dir, benign_dir = data_source |
| | filepaths, labels = get_file_paths_and_labels(malware_dir, benign_dir) |
| | else: |
| | raise ValueError("data_source๋ (malware_dir, benign_dir) ํํ์ด์ด์ผ ํฉ๋๋ค.") |
| |
|
| | |
| | filepaths_train, filepaths_val, labels_train, labels_val = train_test_split( |
| | filepaths, labels, test_size=validation_split, random_state=42, stratify=labels |
| | ) |
| | |
| | print(f"์ด ๋ฐ์ดํฐ: {len(filepaths)}") |
| | print(f"ํ๋ จ ๋ฐ์ดํฐ: {len(filepaths_train)}, ๊ฒ์ฆ ๋ฐ์ดํฐ: {len(filepaths_val)}") |
| |
|
| | |
| | train_gen = data_generator(filepaths_train, labels_train, batch_size, max_length) |
| | val_gen = data_generator(filepaths_val, labels_val, batch_size, max_length, shuffle=False) |
| |
|
| | |
| | print("MalConv ๋ชจ๋ธ ์์ฑ ์ค...") |
| | model = create_malconv_model(max_length) |
| | |
| | |
| | dummy_input = np.zeros((1, max_length), dtype=np.uint8) |
| | _ = model(dummy_input) |
| | |
| | print("\n=== ๋ชจ๋ธ ์ํคํ
์ฒ ===") |
| | model.summary() |
| | print(f"์ด ํ๋ผ๋ฏธํฐ ์: {model.count_params():,}") |
| | |
| | |
| | callbacks = [ |
| | tf.keras.callbacks.EarlyStopping( |
| | monitor='val_loss', |
| | patience=5, |
| | restore_best_weights=True, |
| | verbose=1 |
| | ), |
| | tf.keras.callbacks.ModelCheckpoint( |
| | save_path, |
| | monitor='val_auc', |
| | save_best_only=True, |
| | verbose=1, |
| | mode='max' |
| | ) |
| | ] |
| | |
| | |
| | print(f"\n=== ํ๋ จ ์์ ===") |
| | print(f"๋ฐฐ์น ํฌ๊ธฐ: {batch_size}") |
| | print(f"์ํฌํฌ: {epochs}") |
| | |
| | history = model.fit( |
| | train_gen, |
| | steps_per_epoch=len(filepaths_train) // batch_size, |
| | epochs=epochs, |
| | validation_data=val_gen, |
| | validation_steps=len(filepaths_val) // batch_size, |
| | callbacks=callbacks, |
| | verbose=1 |
| | ) |
| | |
| | |
| | print("\n=== ์ต์ข
ํ๊ฐ ===") |
| | num_eval_samples = min(len(filepaths_val), 1024) |
| | X_eval = np.array([read_binary_file(fp, max_length) for fp in filepaths_val[:num_eval_samples]]) |
| | y_eval = np.array(labels_val[:num_eval_samples]) |
| | |
| | if X_eval.size > 0: |
| | results = evaluate_model(model, X_eval, y_eval, batch_size=batch_size//2) |
| | else: |
| | print("ํ๊ฐํ ๋ฐ์ดํฐ๊ฐ ์์ต๋๋ค.") |
| | results = {} |
| |
|
| | |
| | plot_training_history(history) |
| | |
| | print(f"\n๋ชจ๋ธ์ด ์ ์ฅ๋์์ต๋๋ค: {save_path}") |
| | |
| | return model, history, results |
| |
|
def main(argv=None):
    """Parse command-line options and start MalConv training.

    Args:
        argv: Optional list of argument strings. With the default ``None``
            argparse reads ``sys.argv[1:]``, so a plain ``main()`` call
            behaves as before.
    """
    parser = argparse.ArgumentParser(description='MalConv 모델 훈련')

    # Required dataset locations.
    parser.add_argument('--malware_dir', required=True, help='악성코드 디렉토리')
    parser.add_argument('--benign_dir', required=True, help='정상파일 디렉토리')

    # Optional training settings. These CLI defaults (20 epochs, batch size
    # 64) are passed explicitly and so override train_malconv's own defaults.
    parser.add_argument('--epochs', type=int, default=20, help='에포크 수')
    parser.add_argument('--batch_size', type=int, default=64, help='배치 크기')
    parser.add_argument('--max_length', type=int, default=2_000_000, help='최대 입력 길이')
    parser.add_argument('--save_path', default='models/malconv_model.h5', help='모델 저장 경로')

    args = parser.parse_args(argv)

    # os.makedirs('') raises FileNotFoundError, so only create the output
    # directory when save_path actually has a directory part.
    save_dir = os.path.dirname(args.save_path)
    if save_dir:
        os.makedirs(save_dir, exist_ok=True)

    train_malconv(
        data_source=(args.malware_dir, args.benign_dir),
        epochs=args.epochs,
        batch_size=args.batch_size,
        max_length=args.max_length,
        save_path=args.save_path,
    )


if __name__ == "__main__":
    main()