aarya-vsdk commited on
Commit
0bc408c
·
verified ·
1 Parent(s): 2f08b12

Quickstart updated

Browse files
Files changed (1) hide show
  1. README.md +22 -10
README.md CHANGED
@@ -139,10 +139,12 @@ class TurnDetector:
139
  self.max_length = 512
140
  print("✅ Model and tokenizer loaded successfully.")
141
 
142
- def predict(self, text: str) -> str:
143
  """
144
  Predicts if a given text utterance is the end of a turn.
145
- Returns "End of Turn" or "Not End of Turn".
 
 
146
  """
147
  # Tokenize the input text
148
  inputs = self.tokenizer(
@@ -160,12 +162,20 @@ class TurnDetector:
160
 
161
  # Run inference
162
  outputs = self.session.run(None, feed_dict)
163
- logits = outputs
164
-
165
- # Get the predicted class (0 or 1)
166
- prediction_index = np.argmax(logits, axis=1)
167
-
168
- return "End of Turn" if prediction_index == 1 else "Not End of Turn"
 
 
 
 
 
 
 
 
169
 
170
  # --- Example Usage ---
171
  if __name__ == "__main__":
@@ -177,8 +187,10 @@ if __name__ == "__main__":
177
  ]
178
 
179
  for sentence in sentences:
180
- result = detector.predict(sentence)
181
- print(f"'{sentence}' -> {result}")
 
 
182
 
183
  ```
184
 
 
139
  self.max_length = 512
140
  print("✅ Model and tokenizer loaded successfully.")
141
 
142
+ def predict(self, text: str) -> tuple:
143
  """
144
  Predicts if a given text utterance is the end of a turn.
145
+ Returns (predicted_label, confidence) where:
146
+ - predicted_label: 0 for "Not End of Turn", 1 for "End of Turn"
147
+ - confidence: confidence score between 0 and 1
148
  """
149
  # Tokenize the input text
150
  inputs = self.tokenizer(
 
162
 
163
  # Run inference
164
  outputs = self.session.run(None, feed_dict)
165
+ logits = outputs[0]
166
+
167
+ probabilities = self._softmax(logits[0])
168
+ predicted_label = np.argmax(probabilities)
169
+ print(f"Logits: {logits[0]}")
170
+ print(f"Probabilities:")
171
+ print(f" Not End of Turn: {probabilities[0]:.1%}")
172
+ print(f" End of Turn: {probabilities[1]:.1%}")
173
+ print(f"Predicted Label: {predicted_label}")
174
+ print(f"Confidence: {probabilities[predicted_label]:.1%}")
175
+ print(f"Confidence (class=1): {np.max(probabilities):.1%}")
176
+ confidence = float(np.max(probabilities))
177
+
178
+ return predicted_label, confidence
179
 
180
  # --- Example Usage ---
181
  if __name__ == "__main__":
 
187
  ]
188
 
189
  for sentence in sentences:
190
+ predicted_label, confidence = detector.predict(sentence)
191
+ result = "End of Turn" if predicted_label == 1 else "Not End of Turn"
192
+ print(f"'{sentence}' -> {result} (confidence: {confidence:.3f})")
193
+ print("-" * 50)
194
 
195
  ```
196