Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -108,10 +108,19 @@ def generate_ocr(method, image):
|
|
108 |
outputs = model(**inputs)
|
109 |
logits = outputs.logits # Get raw logits
|
110 |
probs = F.softmax(logits, dim=1) # Convert logits to probabilities
|
111 |
-
predicted_class = torch.argmax(probs, dim=1).item() # Get the predicted class (0 or 1)
|
112 |
|
113 |
-
#
|
114 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
115 |
|
116 |
# Save results using external function
|
117 |
save_results_to_repo(text_output, label)
|
|
|
108 |
outputs = model(**inputs)
|
109 |
logits = outputs.logits # Get raw logits
|
110 |
probs = F.softmax(logits, dim=1) # Convert logits to probabilities
|
|
|
111 |
|
112 |
+
# Print probabilities for debugging
|
113 |
+
print(f"Probabilities: {probs}")
|
114 |
+
|
115 |
+
# Extract probability values
|
116 |
+
not_spam_prob = probs[0, 0].item()
|
117 |
+
spam_prob = probs[0, 1].item()
|
118 |
+
|
119 |
+
# Ensure classification is based on correct threshold
|
120 |
+
if spam_prob > not_spam_prob:
|
121 |
+
label = "Spam"
|
122 |
+
else:
|
123 |
+
label = "Not Spam"
|
124 |
|
125 |
# Save results using external function
|
126 |
save_results_to_repo(text_output, label)
|