attention mask is not set
src/app.py  CHANGED  (+83 -29)
@@ -41,12 +41,11 @@ MODEL_OPTIONS = {
     }
 }
 
-# Initialize Whisper
-# Create components separately
+# Initialize Whisper components
 feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
 tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
-processor = WhisperProcessor(feature_extractor, tokenizer)
 
+# Configure transcription pipeline with only necessary components
 transcriber = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-base.en",
@@ -54,9 +53,48 @@ transcriber = pipeline(
     stride_length_s=5,
     device="cpu",
     torch_dtype=torch.float32,
-    # Remove feature_extractor and tokenizer parameters as they're included in the model
     generate_kwargs={
-        "use_cache": True
+        "use_cache": True,
+        "return_timestamps": True
+    }
+)
+
+# Audio preprocessing function
+def prepare_audio_features(audio_array, sample_rate):
+    """Prepare audio features with proper format."""
+    # Convert stereo to mono
+    if audio_array.ndim > 1:
+        audio_array = audio_array.mean(axis=1)
+    audio_array = audio_array.astype(np.float32)
+
+    # Normalize audio
+    audio_array /= np.max(np.abs(audio_array))
+
+    # Resample to 16kHz if needed
+    if sample_rate != 16000:
+        resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
+        audio_tensor = torch.FloatTensor(audio_array)
+        audio_tensor = resampler(audio_tensor)
+        audio_array = audio_tensor.numpy()
+
+    # Return proper dictionary format for pipeline
+    return {
+        "raw": audio_array,
+        "sampling_rate": 16000
+    }
+
+# Update transcriber configuration
+transcriber = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-base.en",
+    chunk_length_s=30,
+    stride_length_s=5,
+    device="cpu",
+    torch_dtype=torch.float32,
+    feature_extractor=feature_extractor,
+    generate_kwargs={
+        "use_cache": True,
+        "return_timestamps": True
     }
 )
 
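Note that the second `pipeline(...)` call above rebuilds `transcriber` and supersedes the first one, this time passing `feature_extractor` explicitly. The commit title presumably refers to transformers' "attention mask is not set" generation warning; the change does not set a mask directly but instead feeds the pipeline the dictionary format it accepts, `{"raw": <float32 mono audio>, "sampling_rate": 16000}`. A minimal usage sketch for `prepare_audio_features`, assuming the app's own imports are in scope (`numpy as np`, `torch`, and `T` as `torchaudio.transforms`):

import numpy as np

# Synthetic 1-second stereo clip at 44.1 kHz, standing in for microphone input.
sr = 44100
t = np.linspace(0, 1.0, sr, endpoint=False)
stereo = np.stack([np.sin(2 * np.pi * 440 * t)] * 2, axis=1)

inputs = prepare_audio_features(stereo, sr)
print(inputs["sampling_rate"])                   # 16000
print(inputs["raw"].dtype, inputs["raw"].shape)  # float32, mono and resampled

# result = transcriber(inputs)   # downloads/loads openai/whisper-base.en on first use
# print(result["text"])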
@@ -298,14 +336,23 @@ with gr.Blocks(
     and patients understand potential diagnoses based on described symptoms.
 
     ### How it works:
-    1.
+    1. Either click the record button and describe your symptoms or type them into the textbox
     2. The AI will analyze your description and suggest possible diagnoses
     3. Answer follow-up questions to refine the diagnosis
     """)
 
     with gr.Row():
         with gr.Column(scale=2):
-            #
+            # Add text input above microphone
+            with gr.Row():
+                text_input = gr.Textbox(
+                    label="Type your symptoms",
+                    placeholder="Or type your symptoms here...",
+                    lines=3
+                )
+                submit_btn = gr.Button("Submit", variant="primary")
+
+            # Existing microphone row
             with gr.Row():
                 microphone = gr.Audio(
                     sources=["microphone"],
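This hunk adds `text_input` and `submit_btn`, but the click wiring sits outside the excerpt. A minimal, self-contained sketch of how such a button is typically attached in Gradio; the handler name `respond_to_text` and its chat-history signature are assumptions for illustration, not taken from app.py:

import gradio as gr

def respond_to_text(message, history):
    # Hypothetical handler: append the typed symptoms to the chat history.
    history = history or []
    history.append((message, "Analyzing your symptoms..."))
    return history, ""  # second output clears the textbox

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    text_input = gr.Textbox(
        label="Type your symptoms",
        placeholder="Or type your symptoms here...",
        lines=3
    )
    submit_btn = gr.Button("Submit", variant="primary")
    submit_btn.click(
        fn=respond_to_text,
        inputs=[text_input, chatbot],
        outputs=[chatbot, text_input]
    )

if __name__ == "__main__":
    demo.launch()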
@@ -371,7 +418,6 @@ with gr.Blocks(
             return history
 
         try:
-            # Process audio stream
             if isinstance(audio_path, tuple) and len(audio_path) == 2:
                 sample_rate, audio_array = audio_path
 
@@ -381,26 +427,33 @@ with gr.Blocks(
             audio_array = audio_array.astype(np.float32)
             audio_array /= np.max(np.abs(audio_array))
 
+            # Ensure correct sampling rate
+            if sample_rate != 16000:
+                resampler = T.Resample(
+                    orig_freq=sample_rate,
+                    new_freq=16000
+                )
+                audio_tensor = torch.FloatTensor(audio_array)
+                audio_tensor = resampler(audio_tensor)
+                audio_array = audio_tensor.numpy()
+                sample_rate = 16000
+
+            # Format input dictionary exactly as required
+            transcriber_input = {
+                "raw": audio_array,
+                "sampling_rate": sample_rate
+            }
+
             # Get transcription from Whisper
-            result = transcriber(
-                {"sampling_rate": sample_rate, "raw": audio_array},
-                batch_size=8,
-                return_timestamps=True
-            )
+            result = transcriber(transcriber_input)
 
-            #
+            # Extract text from result
             transcript = ""
             if isinstance(result, dict):
-                transcript = result.get("text", "")
+                transcript = result.get("text", "").strip()
             elif isinstance(result, str):
-                transcript = result
-            elif isinstance(result, (list, tuple)) and len(result) > 0:
-                transcript = str(result[0])
-            else:
-                print(f"Unexpected transcriber result type: {type(result)}")
-                return history
-
-            transcript = transcript.strip()
+                transcript = result.strip()
+
             if not transcript:
                 return history
 
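Both transcription paths peak-normalize with `audio_array /= np.max(np.abs(audio_array))`, which divides by zero (and yields NaNs) when the buffer is pure silence. Not part of this commit, but a small defensive variant for reference:

import numpy as np

def normalize_audio(audio_array: np.ndarray) -> np.ndarray:
    # Peak-normalize to [-1, 1]; the epsilon floor avoids dividing by zero
    # when the recorded clip is all zeros (silence).
    audio_array = audio_array.astype(np.float32)
    peak = float(np.max(np.abs(audio_array)))
    return audio_array / max(peak, 1e-8)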
@@ -470,19 +523,20 @@ with gr.Blocks(
 
         try:
             sample_rate, audio_array = audio
-            input_features = process_audio(audio_array, sample_rate)
 
-
+            # Process audio and get proper format
+            inputs = prepare_audio_features(audio_array, sample_rate)
+
+            # Pass to transcriber
+            result = transcriber(inputs)
 
-            #
+            # Extract text from result
             if isinstance(result, dict):
                 return result.get("text", "").strip()
            elif isinstance(result, str):
                 return result.strip()
-            elif isinstance(result, (list, tuple)) and len(result) > 0:
-                return str(result[0]).strip()
             return ""
-
+
         except Exception as e:
             print(f"Transcription error: {str(e)}")
             return ""
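After this change the same isinstance cascade for pulling text out of the pipeline result appears in both handlers. A shared helper both call sites could use; this is a hypothetical refactor sketched here for clarity, not something present in the commit:

def extract_transcript(result) -> str:
    # Mirrors the dict / str handling now duplicated at both call sites.
    if isinstance(result, dict):
        return result.get("text", "").strip()
    if isinstance(result, str):
        return result.strip()
    return ""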