rshakked committed
Commit 980da81 · 1 Parent(s): 1d998d0

feat: save and display abuse prediction results with timestamped filenames

Files changed (2):
  1. train_abuse_model.py +27 -9
  2. utils.py +33 -0
train_abuse_model.py CHANGED
@@ -6,6 +6,7 @@ import io
 import os
 import time
 import gradio as gr  # ✅ required for progress bar
+from datetime import datetime
 from pathlib import Path
 import queue
 
@@ -43,11 +44,26 @@ from utils import (
     AbuseDataset
 )
 
-
+# Create evaluation results directory if it doesn't exist
+Path("/home/user/app/results_eval").mkdir(parents=True, exist_ok=True)
+
 PERSIST_DIR = Path("/home/user/app")
 MODEL_DIR = PERSIST_DIR / "saved_model"
 LOG_FILE = PERSIST_DIR / "training.log"
 
+
+# Save and print evaluation results
+def save_and_yield_eval(report: str):
+    # Generate versioned filename using timestamp
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    eval_filename = f"eval_report_{timestamp}.txt"
+    eval_filepath = Path("/home/user/app/results_eval") / eval_filename
+
+    with open(eval_filepath, "w") as f:
+        f.write(report)
+    yield f"📄 Evaluation saved to: {eval_filepath.name}"
+    yield report
 # configure logging
 log_buffer = io.StringIO()
 logging.basicConfig(
@@ -100,20 +116,19 @@ def evaluate_model_with_thresholds(trainer, test_dataset):
 
     logger.info("\n📊 Final Evaluation Report (multi-class per label):\n")
     yield "\n📊 Final Evaluation Report (multi-class per label):\n "
-    logger.info(classification_report(
-        true_str,
-        final_pred_str,
-        labels=["no", "plausibly", "yes"],
-        digits=3,
-        zero_division=0
-    ))
-    yield classification_report(
+    report = classification_report(
         true_str,
         final_pred_str,
         labels=["no", "plausibly", "yes"],
         digits=3,
         zero_division=0
     )
+    logger.info(report)
+    yield from save_and_yield_eval(report)
+
+    # Save to file
+    with open("/home/user/app/results_eval/eval_report.txt", "w") as f:
+        f.write(report)
 def load_saved_model_and_tokenizer():
     tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_DIR)
     model = AutoModelForSequenceClassification.from_pretrained(MODEL_DIR).to(device)
@@ -296,13 +311,16 @@ def run_training(progress=gr.Progress(track_tqdm=True)):
         progress(1.0)
         yield "✅ Progress: 100%\n"
 
+        PERSIST_DIR = Path("/home/user/app")
+        MODEL_DIR = PERSIST_DIR / "saved_model"
+
         # Save the model and tokenizer
         MODEL_DIR.mkdir(parents=True, exist_ok=True)
        model.save_pretrained(MODEL_DIR)
         tokenizer.save_pretrained(MODEL_DIR)
 
         logger.info(" Training completed and model saved.")
-        yield "🎉 Training complete! Model saved.\n"
+        yield f"🎉 Training complete! Model saved on {MODEL_DIR.resolve()}.\n"
 
     except Exception as e:
         logger.exception( f"❌ Training failed: {e}")
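
`save_and_yield_eval` is itself a generator: `evaluate_model_with_thresholds` forwards its messages with `yield from`, so the Gradio stream shows the save notice and then the report body, while the extra `# Save to file` block keeps the latest run at the stable path `eval_report.txt` alongside the timestamped copies. Below is a minimal sketch of that delegation pattern; `stream_eval` and the local `results_eval_demo` directory are hypothetical stand-ins for the real caller and `/home/user/app/results_eval`.

# Sketch: how a generator caller forwards save_and_yield_eval's messages.
# stream_eval and results_eval_demo are hypothetical stand-ins; the real
# code writes under /home/user/app/results_eval.
from datetime import datetime
from pathlib import Path

RESULTS_DIR = Path("results_eval_demo")

def save_and_yield_eval(report: str):
    RESULTS_DIR.mkdir(parents=True, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    eval_filepath = RESULTS_DIR / f"eval_report_{timestamp}.txt"
    eval_filepath.write_text(report)
    yield f"📄 Evaluation saved to: {eval_filepath.name}"
    yield report

def stream_eval(report: str):
    yield "📊 Final Evaluation Report:\n"
    # yield from forwards both of the helper's messages, in order
    yield from save_and_yield_eval(report)

for message in stream_eval("precision / recall / f1 table"):
    print(message)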
utils.py CHANGED
@@ -2,6 +2,39 @@ import numpy as np
 from sklearn.metrics import precision_recall_fscore_support
 import torch
 from torch.utils.data import Dataset
+from datetime import datetime
+from pathlib import Path
+import logging
+
+def save_and_return_prediction(enriched_input: str, predicted_labels: list):
+    Path("/home/user/app/results_pred").mkdir(parents=True, exist_ok=True)
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+    pred_filename = f"prediction_{timestamp}.txt"
+    pred_filepath = Path("/home/user/app/results_pred") / pred_filename
+
+    with open(pred_filepath, "w") as f:
+        f.write("===== Enriched Input =====\n")
+        f.write(enriched_input + "\n\n")
+        f.write("===== Predicted Labels =====\n")
+        f.write(", ".join(predicted_labels))
+
+    return str(pred_filepath.name)
+
+# Save and print evaluation results
+def save_and_yield_eval(report: str):
+    # Create evaluation results directories if they don't exist
+    Path("/home/user/app/results_eval").mkdir(parents=True, exist_ok=True)
+
+    # Generate versioned filename using timestamp
+    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
+
+    eval_filename = f"eval_report_{timestamp}.txt"
+    eval_filepath = Path("/home/user/app/results_eval") / eval_filename
+
+    with open(eval_filepath, "w") as f:
+        f.write(report)
+    yield f"📄 Evaluation saved to: {eval_filepath.name}"
+    yield report
 
 # Custom Dataset class
 class AbuseDataset(Dataset):
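
`save_and_return_prediction` mirrors the evaluation helper but is a plain function: it writes the enriched input and predicted labels to a timestamped file under `/home/user/app/results_pred` and returns only the file name (`pred_filepath.name`), not the full path. A hypothetical caller, with example values that are not part of this commit:

# Hypothetical usage of save_and_return_prediction; the inputs below are
# illustrative only. The file lands under /home/user/app/results_pred,
# as hard-coded in utils.py.
from utils import save_and_return_prediction

enriched_input = "original text plus enrichment features"  # assumed format
predicted_labels = ["no", "plausibly", "yes"]  # one value per abuse label

saved_name = save_and_return_prediction(enriched_input, predicted_labels)
print(f"📄 Prediction saved to: {saved_name}")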