gpaasch committed on
Commit ec82b9a · 1 Parent(s): 54fa492

attention mask is not set

Files changed (1)
  1. src/app.py +83 -29
src/app.py CHANGED
@@ -41,12 +41,11 @@ MODEL_OPTIONS = {
     }
 }
 
-# Initialize Whisper with proper configuration
-# Create components separately
+# Initialize Whisper components
 feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
 tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
-processor = WhisperProcessor(feature_extractor, tokenizer)
 
+# Configure transcription pipeline with only necessary components
 transcriber = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-base.en",
@@ -54,9 +53,48 @@ transcriber = pipeline(
     stride_length_s=5,
     device="cpu",
     torch_dtype=torch.float32,
-    # Remove feature_extractor and tokenizer parameters as they're included in the model
     generate_kwargs={
-        "use_cache": True
+        "use_cache": True,
+        "return_timestamps": True
+    }
+)
+
+# Audio preprocessing function
+def prepare_audio_features(audio_array, sample_rate):
+    """Prepare audio features with proper format."""
+    # Convert stereo to mono
+    if audio_array.ndim > 1:
+        audio_array = audio_array.mean(axis=1)
+    audio_array = audio_array.astype(np.float32)
+
+    # Normalize audio
+    audio_array /= np.max(np.abs(audio_array))
+
+    # Resample to 16kHz if needed
+    if sample_rate != 16000:
+        resampler = T.Resample(orig_freq=sample_rate, new_freq=16000)
+        audio_tensor = torch.FloatTensor(audio_array)
+        audio_tensor = resampler(audio_tensor)
+        audio_array = audio_tensor.numpy()
+
+    # Return proper dictionary format for pipeline
+    return {
+        "raw": audio_array,
+        "sampling_rate": 16000
+    }
+
+# Update transcriber configuration
+transcriber = pipeline(
+    "automatic-speech-recognition",
+    model="openai/whisper-base.en",
+    chunk_length_s=30,
+    stride_length_s=5,
+    device="cpu",
+    torch_dtype=torch.float32,
+    feature_extractor=feature_extractor,
+    generate_kwargs={
+        "use_cache": True,
+        "return_timestamps": True
     }
 )
 
@@ -298,14 +336,23 @@ with gr.Blocks(
     and patients understand potential diagnoses based on described symptoms.
 
     ### How it works:
-    1. Click the microphone button and describe your symptoms
+    1. Either click the record button and describe your symptoms or type them into the textbox
     2. The AI will analyze your description and suggest possible diagnoses
     3. Answer follow-up questions to refine the diagnosis
     """)
 
     with gr.Row():
         with gr.Column(scale=2):
-            # Moved microphone row above chatbot
+            # Add text input above microphone
+            with gr.Row():
+                text_input = gr.Textbox(
+                    label="Type your symptoms",
+                    placeholder="Or type your symptoms here...",
+                    lines=3
+                )
+                submit_btn = gr.Button("Submit", variant="primary")
+
+            # Existing microphone row
             with gr.Row():
                 microphone = gr.Audio(
                     sources=["microphone"],
@@ -371,7 +418,6 @@ with gr.Blocks(
                 return history
 
            try:
-               # Process audio stream
                if isinstance(audio_path, tuple) and len(audio_path) == 2:
                    sample_rate, audio_array = audio_path
 
@@ -381,26 +427,33 @@ with gr.Blocks(
                    audio_array = audio_array.astype(np.float32)
                    audio_array /= np.max(np.abs(audio_array))
 
+                   # Ensure correct sampling rate
+                   if sample_rate != 16000:
+                       resampler = T.Resample(
+                           orig_freq=sample_rate,
+                           new_freq=16000
+                       )
+                       audio_tensor = torch.FloatTensor(audio_array)
+                       audio_tensor = resampler(audio_tensor)
+                       audio_array = audio_tensor.numpy()
+                       sample_rate = 16000
+
+                   # Format input dictionary exactly as required
+                   transcriber_input = {
+                       "raw": audio_array,
+                       "sampling_rate": sample_rate
+                   }
+
                    # Get transcription from Whisper
-                   result = transcriber(
-                       {"sampling_rate": sample_rate, "raw": audio_array},
-                       batch_size=8,
-                       return_timestamps=True
-                   )
+                   result = transcriber(transcriber_input)
 
-                   # Handle different result types
+                   # Extract text from result
                    transcript = ""
                    if isinstance(result, dict):
-                       transcript = result.get("text", "")
+                       transcript = result.get("text", "").strip()
                    elif isinstance(result, str):
-                       transcript = result
-                   elif isinstance(result, (list, tuple)) and len(result) > 0:
-                       transcript = str(result[0])
-                   else:
-                       print(f"Unexpected transcriber result type: {type(result)}")
-                       return history
-
-                   transcript = transcript.strip()
+                       transcript = result.strip()
+
                    if not transcript:
                        return history
 
@@ -470,19 +523,20 @@ with gr.Blocks(
 
            try:
                sample_rate, audio_array = audio
-               input_features = process_audio(audio_array, sample_rate)
 
-               result = transcriber(input_features)
+               # Process audio and get proper format
+               inputs = prepare_audio_features(audio_array, sample_rate)
+
+               # Pass to transcriber
+               result = transcriber(inputs)
 
-               # Handle different result types
+               # Extract text from result
               if isinstance(result, dict):
                   return result.get("text", "").strip()
               elif isinstance(result, str):
                   return result.strip()
-              elif isinstance(result, (list, tuple)) and len(result) > 0:
-                  return str(result[0]).strip()
              return ""
-
+
           except Exception as e:
              print(f"Transcription error: {str(e)}")
              return ""