feat: save and display abuse prediction results with timestamped filenames
- train_abuse_model.py: +27 -9
- utils.py: +33 -0

train_abuse_model.py
@@ -6,6 +6,7 @@ import io
 import os
 import time
 import gradio as gr # ✅ required for progress bar
+from datetime import datetime
 from pathlib import Path
 import queue
 
@@ -43,11 +44,26 @@ from utils import (
     AbuseDataset
 )
 
+# Create evaluation results directory if it doesn't exist
+Path("/home/user/app/results_eval").mkdir(parents=True, exist_ok=True)
 
 PERSIST_DIR = Path("/home/user/app")
 MODEL_DIR = PERSIST_DIR / "saved_model"
 LOG_FILE = PERSIST_DIR / "training.log"
 
+
+# Save and print evaluation results
+def save_and_yield_eval(report: str):
+    # Generate versioned filename using timestamp
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    eval_filename = f"eval_report_{timestamp}.txt"
+    eval_filepath = Path("/home/user/app/results_eval") / eval_filename
+
+    with open(eval_filepath, "w") as f:
+        f.write(report)
+    yield f"📄 Evaluation saved to: {eval_filepath.name}"
+    yield report
 # configure logging
 log_buffer = io.StringIO()
 logging.basicConfig(
@@ -100,20 +116,19 @@ def evaluate_model_with_thresholds(trainer, test_dataset):
 
     logger.info("\n📊 Final Evaluation Report (multi-class per label):\n")
     yield "\n📊 Final Evaluation Report (multi-class per label):\n "
-    logger.info(classification_report(
-        true_str,
-        final_pred_str,
-        labels=["no", "plausibly", "yes"],
-        digits=3,
-        zero_division=0
-    ))
-    yield classification_report(
+    report = classification_report(
         true_str,
         final_pred_str,
         labels=["no", "plausibly", "yes"],
         digits=3,
         zero_division=0
     )
+    logger.info(report)
+    yield from save_and_yield_eval(report)
+
+    # Save to file
+    with open("/home/user/app/results_eval/eval_report.txt", "w") as f:
+        f.write(report)
 def load_saved_model_and_tokenizer():
     tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_DIR)
     model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR).to(device)
@@ -296,13 +311,16 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
         progress(1.0)
         yield "✅ Progress: 100%\n"
 
+        PERSIST_DIR = Path("/home/user/app")
+        MODEL_DIR = PERSIST_DIR / "saved_model"
+
         # Save the model and tokenizer
         MODEL_DIR.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(MODEL_DIR)
         tokenizer.save_pretrained(MODEL_DIR)
 
         logger.info(" Training completed and model saved.")
-        yield "🎉 Training complete! Model saved.\n"
+        yield f"🎉 Training complete! Model saved on {MODEL_DIR.resolve()}.\n"
 
     except Exception as e:
         logger.exception( f"❌ Training failed: {e}")
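Note: save_and_yield_eval is a generator, so the timestamped report file is only written when a caller iterates it; the training code above does that with "yield from". A minimal consumption sketch (hypothetical driver, not part of this commit; assumes utils.py is importable and the Space's /home/user/app paths are writable):

    # Hypothetical driver: iterating the generator writes the timestamped
    # report under results_eval/ and streams each status message back.
    from utils import save_and_yield_eval

    report = "precision  recall  f1-score"  # placeholder for classification_report output
    for message in save_and_yield_eval(report):
        print(message)  # first the "saved to" notice, then the report itself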
utils.py
@@ -2,6 +2,39 @@ import numpy as np
 from sklearn.metrics import precision_recall_fscore_support
 import torch
 from torch.utils.data import Dataset
+from datetime import datetime
+from pathlib import Path
+import logging
+
+def save_and_return_prediction(enriched_input: str, predicted_labels: list):
+    Path("/home/user/app/results_pred").mkdir(parents=True, exist_ok=True)
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    pred_filename = f"prediction_{timestamp}.txt"
+    pred_filepath = Path("/home/user/app/results_pred") / pred_filename
+
+    with open(pred_filepath, "w") as f:
+        f.write("===== Enriched Input =====\n")
+        f.write(enriched_input + "\n\n")
+        f.write("===== Predicted Labels =====\n")
+        f.write(", ".join(predicted_labels))
+
+    return str(pred_filepath.name)
+
+# Save and print evaluation results
+def save_and_yield_eval(report: str):
+    # Create evaluation results directories if they don't exist
+    Path("/home/user/app/results_eval").mkdir(parents=True, exist_ok=True)
+
+    # Generate versioned filename using timestamp
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    eval_filename = f"eval_report_{timestamp}.txt"
+    eval_filepath = Path("/home/user/app/results_eval") / eval_filename
+
+    with open(eval_filepath, "w") as f:
+        f.write(report)
+    yield f"📄 Evaluation saved to: {eval_filepath.name}"
+    yield report
 
 # Custom Dataset class
 class AbuseDataset(Dataset):