gpaasch committed
Commit 54fa492 · 1 Parent(s): 00bcf43

audio processing pipeline

Files changed (1):
src/app.py +56 -21
src/app.py CHANGED
@@ -15,7 +15,9 @@ import io
 import base64
 import numpy as np
 from transformers.pipelines import pipeline  # Changed from transformers import pipeline
-from transformers import WhisperProcessor
+from transformers import WhisperFeatureExtractor, WhisperTokenizer, WhisperProcessor
+import torchaudio
+import torchaudio.transforms as T
 
 # Model options mapped to their requirements
 MODEL_OPTIONS = {
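The new torchaudio imports exist to convert whatever sample rate the browser microphone delivers down to the 16 kHz that Whisper checkpoints expect. A minimal standalone check of that transform (a sketch, assuming torchaudio is installed; the 440 Hz tone is just a stand-in signal):

```python
import torch
import torchaudio.transforms as T

# One second of a 440 Hz sine at 44.1 kHz, resampled to Whisper's 16 kHz.
resampler = T.Resample(orig_freq=44100, new_freq=16000)
waveform = torch.sin(2 * torch.pi * 440 * torch.arange(44100) / 44100)
resampled = resampler(waveform)
print(resampled.shape)  # torch.Size([16000])
```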
@@ -40,25 +42,24 @@ MODEL_OPTIONS = {
 }
 
 # Initialize Whisper with proper configuration
+# Create components separately
+feature_extractor = WhisperFeatureExtractor.from_pretrained("openai/whisper-base.en")
+tokenizer = WhisperTokenizer.from_pretrained("openai/whisper-base.en")
+processor = WhisperProcessor(feature_extractor, tokenizer)
+
 transcriber = pipeline(
     "automatic-speech-recognition",
     model="openai/whisper-base.en",
     chunk_length_s=30,
     stride_length_s=5,
-    return_timestamps=True,
-    device="cpu",  # Explicitly set to CPU since we're seeing GPU warnings
+    device="cpu",
     torch_dtype=torch.float32,
+    # Remove feature_extractor and tokenizer parameters as they're included in the model
     generate_kwargs={
-        "task": "transcribe",
-        "language": "en",
-        "use_cache": True,
-        "return_timestamps": True
+        "use_cache": True
     }
 )
 
-# Create processor for proper attention mask
-processor = WhisperProcessor.from_pretrained("openai/whisper-base.en")
-
 def get_system_specs() -> Dict[str, float]:
     """Get system specifications."""
     # Get RAM
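Since the pipeline loads its own preprocessor from the checkpoint, the hand-built processor is only needed for manual feature extraction later in the file. A quick smoke test of the configured transcriber (a sketch; the one-second silent buffer is a made-up input):

```python
import numpy as np

# Feed one second of 16 kHz silence through the pipeline and print the text.
silence = np.zeros(16000, dtype=np.float32)
result = transcriber({"raw": silence, "sampling_rate": 16000})
print(result["text"])
```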
@@ -207,12 +208,23 @@ def process_speech(audio_data, history):
     audio_array = audio_array.astype(np.float32)
     audio_array /= np.max(np.abs(audio_array))
 
+    # Ensure correct sampling rate
+    if sample_rate != 16000:
+        resampler = T.Resample(sample_rate, 16000)
+        audio_tensor = torch.FloatTensor(audio_array)
+        audio_tensor = resampler(audio_tensor)
+        audio_array = audio_tensor.numpy()
+        sample_rate = 16000
+
     # Transcribe with error handling
     try:
-        result = transcriber(
-            {"sampling_rate": sample_rate, "raw": audio_array},
-            batch_size=8
-        )
+        # Format dictionary correctly with required keys
+        inputs = {
+            "raw": audio_array,
+            "sampling_rate": sample_rate
+        }
+
+        result = transcriber(inputs)
 
         # Handle different result types
         if isinstance(result, dict) and "text" in result:
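One edge case the surrounding normalization does not guard against: a silent buffer makes np.max(np.abs(audio_array)) zero, and the in-place division then fills the array with NaNs. A defensive variant (a sketch, not part of this commit):

```python
import numpy as np

def normalize_audio(audio_array: np.ndarray) -> np.ndarray:
    """Peak-normalize to [-1, 1], skipping silent (all-zero) buffers."""
    audio_array = audio_array.astype(np.float32)
    peak = np.max(np.abs(audio_array))
    if peak > 0:
        audio_array /= peak
    return audio_array
```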
@@ -422,6 +434,34 @@ with gr.Blocks(
         queue=True  # Enable queuing for better stream handling
     )
 
+    def process_audio(audio_array, sample_rate):
+        """Pre-process audio for Whisper."""
+        if audio_array.ndim > 1:
+            audio_array = audio_array.mean(axis=1)
+
+        # Convert to tensor for resampling
+        audio_tensor = torch.FloatTensor(audio_array)
+
+        # Resample to 16kHz if needed
+        if sample_rate != 16000:
+            resampler = T.Resample(sample_rate, 16000)
+            audio_tensor = resampler(audio_tensor)
+
+        # Normalize
+        audio_tensor = audio_tensor / torch.max(torch.abs(audio_tensor))
+
+        # Use feature extractor with correct sampling rate
+        features = feature_extractor(
+            audio_tensor.numpy(),
+            sampling_rate=16000,  # Always use 16kHz
+            return_tensors="pt"
+        )
+
+        return {
+            "input_features": features.input_features,
+            "sampling_rate": 16000  # Return resampled rate
+        }
+
     # Update transcription handler
     def update_live_transcription(audio):
         """Real-time transcription updates."""
@@ -430,14 +470,9 @@
 
         try:
             sample_rate, audio_array = audio
-            if audio_array.ndim > 1:
-                audio_array = audio_array.mean(axis=1)
-            audio_array = audio_array.astype(np.float32)
-            audio_array /= np.max(np.abs(audio_array))
+            input_features = process_audio(audio_array, sample_rate)
 
-            result = transcriber(
-                {"sampling_rate": sample_rate, "raw": audio_array}
-            )
+            result = transcriber(input_features)
 
             # Handle different result types
             if isinstance(result, dict):
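One caveat worth flagging: the transformers ASR pipeline does its own preprocessing, and its documented dict input uses the "raw" and "sampling_rate" keys, so handing it the precomputed input_features returned by process_audio may be rejected. A variant handler that keeps the resampling and normalization but feeds the pipeline raw audio (a sketch under that assumption; update_live_transcription_raw is a hypothetical name, and np, torch, T, and transcriber come from the module above):

```python
def update_live_transcription_raw(audio):
    """Sketch: resample and normalize, then pass raw audio to the pipeline."""
    if audio is None:
        return ""
    sample_rate, audio_array = audio
    if audio_array.ndim > 1:
        audio_array = audio_array.mean(axis=1)
    audio_array = audio_array.astype(np.float32)
    peak = np.max(np.abs(audio_array))
    if peak > 0:
        audio_array /= peak
    if sample_rate != 16000:
        audio_array = T.Resample(sample_rate, 16000)(torch.FloatTensor(audio_array)).numpy()
    result = transcriber({"raw": audio_array, "sampling_rate": 16000})
    return result["text"] if isinstance(result, dict) else str(result)
```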
 