winamnd commited on
Commit
4d58fc5
·
verified ·
1 Parent(s): 39fa45a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +6 -5
app.py CHANGED
@@ -95,7 +95,7 @@ def generate_ocr(method, image):
95
 
96
  # Join extracted text into a single string
97
  text_output = " ".join(extracted_text).strip()
98
-
99
  # If no text detected, return early
100
  if len(text_output) == 0:
101
  return "No text detected!", "Cannot classify"
@@ -106,11 +106,12 @@ def generate_ocr(method, image):
106
  # Perform inference
107
  with torch.no_grad():
108
  outputs = model(**inputs)
109
- probs = F.softmax(outputs.logits, dim=1) # Convert logits to probabilities
110
- spam_prob = probs[0][1].item() # Probability of Spam
 
111
 
112
- # Adjust classification based on threshold
113
- label = "Spam" if spam_prob > 0.6 else "Not Spam"
114
 
115
  # Save results using external function
116
  save_results_to_repo(text_output, label)
 
95
 
96
  # Join extracted text into a single string
97
  text_output = " ".join(extracted_text).strip()
98
+
99
  # If no text detected, return early
100
  if len(text_output) == 0:
101
  return "No text detected!", "Cannot classify"
 
106
  # Perform inference
107
  with torch.no_grad():
108
  outputs = model(**inputs)
109
+ logits = outputs.logits # Get raw logits
110
+ probs = F.softmax(logits, dim=1) # Convert logits to probabilities
111
+ predicted_class = torch.argmax(probs, dim=1).item() # Get the predicted class (0 or 1)
112
 
113
+ # Map predicted class correctly
114
+ label = "Spam" if predicted_class == 1 else "Not Spam"
115
 
116
  # Save results using external function
117
  save_results_to_repo(text_output, label)