Spaces:
Sleeping
Sleeping
Update tasks/text.py
Browse files
- tasks/text.py +38 -1
tasks/text.py
CHANGED
|
@@ -62,6 +62,43 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 62 |
tracker.start_task("inference")
|
| 63 |
|
| 64 |
#--------------------------------------------------------------------------------------------
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 65 |
|
| 66 |
|
| 67 |
#predictions = xgb.predict(embeddings)
|
|
@@ -80,7 +117,7 @@ async def evaluate_text(request: TextEvaluationRequest):
|
|
| 80 |
emissions_data = tracker.stop_task()
|
| 81 |
|
| 82 |
# Calculate accuracy
|
| 83 |
-
accuracy = accuracy_score(true_labels,
|
| 84 |
|
| 85 |
# Prepare results dictionary
|
| 86 |
results = {
|
|
|
|
| 62 |
tracker.start_task("inference")
|
| 63 |
|
| 64 |
#--------------------------------------------------------------------------------------------
|
| 65 |
+
# Load a pre-trained Sentence-BERT model
|
| 66 |
+
model = SentenceTransformer('sentence-transformers/all-MPNET-base-v2', device='cpu')
|
| 67 |
+
# Generate sentence embeddings
|
| 68 |
+
sentence_embeddings = model.encode(test_dataset["quote"])
|
| 69 |
+
|
| 70 |
+
#load the models
|
| 71 |
+
with open("xgb_bin.pkl","rb") as f:
|
| 72 |
+
xgb_bin = pickle.load(f)
|
| 73 |
+
|
| 74 |
+
with open("xgb_multi.pkl","rb") as f:
|
| 75 |
+
xgb_multi = pickle.load(f)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
X_train = test_dataset["quote"]
|
| 79 |
+
|
| 80 |
+
y_train = test_dataset["label"].copy()
|
| 81 |
+
|
| 82 |
+
#binary
|
| 83 |
+
y_train_binary = y_train.copy()
|
| 84 |
+
y_train_binary[y_train_binary != 0] = 1
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
#multi class
|
| 88 |
+
X_train_multi = X_train[y_train != 0]
|
| 89 |
+
|
| 90 |
+
y_train_multi = y_train[y_train != 0]
|
| 91 |
+
|
| 92 |
+
#predictions
|
| 93 |
+
y_pred_bin = xgb_bin.predict(X_train)
|
| 94 |
+
|
| 95 |
+
y_pred_multi = xgb_multi.predict(X_train_multi) + 1
|
| 96 |
+
|
| 97 |
+
y_pred_bin[y_pred_bin==1] = y_pred_multi
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
|
| 102 |
|
| 103 |
|
| 104 |
#predictions = xgb.predict(embeddings)
|
|
|
|
| 117 |
emissions_data = tracker.stop_task()
|
| 118 |
|
| 119 |
# Calculate accuracy
|
| 120 |
+
accuracy = accuracy_score(true_labels, y_true))
|
| 121 |
|
| 122 |
# Prepare results dictionary
|
| 123 |
results = {
|