hamza2923 commited on
Commit
7323fd3
·
verified ·
1 Parent(s): 4ac3ca3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +92 -35
app.py CHANGED
@@ -5,32 +5,47 @@ import io
5
  import time
6
  import datetime
7
  from threading import Semaphore
 
 
 
8
 
9
  app = Flask(__name__)
10
 
 
 
 
 
 
 
11
  # Device check for faster-whisper
12
  device = "cuda" if torch.cuda.is_available() else "cpu"
13
  compute_type = "float16" if device == "cuda" else "int8"
14
  print(f"Using device: {device} with compute_type: {compute_type}")
15
 
16
- # Faster Whisper setup
17
- beamsize = 2
18
- wmodel = WhisperModel("guillaumekln/faster-whisper-small", device=device, compute_type=compute_type)
 
 
 
 
 
19
 
20
  # Concurrency control
21
- MAX_CONCURRENT_REQUESTS = 2 # Adjust based on your server capacity
22
  request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
23
  active_requests = 0
24
 
25
- # Warm up the model (important for CUDA)
26
- print("Warming up the model...")
27
- try:
28
- dummy_audio = io.BytesIO(b'') # Empty audio for warmup
29
- segments, info = wmodel.transcribe(dummy_audio, beam_size=beamsize)
30
- _ = [segment.text for segment in segments] # Force execution
31
- print("Model warmup complete")
32
- except Exception as e:
33
- print(f"Model warmup failed: {str(e)}")
 
 
34
 
35
  @app.route("/health", methods=["GET"])
36
  def health_check():
@@ -40,7 +55,8 @@ def health_check():
40
  'timestamp': datetime.datetime.now().isoformat(),
41
  'device': device,
42
  'compute_type': compute_type,
43
- 'active_requests': active_requests
 
44
  })
45
 
46
  @app.route("/status/busy", methods=["GET"])
@@ -57,7 +73,6 @@ def server_busy():
57
  def whisper_transcribe():
58
  global active_requests
59
 
60
- # Check if server is at capacity
61
  if not request_semaphore.acquire(blocking=False):
62
  return jsonify({
63
  'status': 'Server busy',
@@ -68,50 +83,92 @@ def whisper_transcribe():
68
  active_requests += 1
69
  print(f"Starting transcription (Active requests: {active_requests})")
70
 
 
 
71
  try:
72
  if 'audio' not in request.files:
73
  return jsonify({'error': 'No file provided'}), 400
74
 
75
  audio_file = request.files['audio']
76
- allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a'}
77
- if not (audio_file and audio_file.filename.lower().split('.')[-1] in allowed_extensions):
78
  return jsonify({'error': 'Invalid file format'}), 400
79
 
80
- audio_bytes = audio_file.read()
81
- audio_file = io.BytesIO(audio_bytes)
82
-
 
 
 
 
 
 
 
83
  try:
84
- # Timeout handling (60 seconds max processing time)
85
  start_time = time.time()
86
- segments, info = wmodel.transcribe(audio_file, beam_size=beamsize)
87
 
88
- text = ''
 
 
 
 
 
 
 
 
 
 
 
 
89
  for segment in segments:
90
- if time.time() - start_time > 60: # Timeout after 60 seconds
91
- raise TimeoutError("Transcription took too long")
92
- text += segment.text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94
  processing_time = time.time() - start_time
95
  print(f"Transcription completed in {processing_time:.2f} seconds")
96
 
97
  return jsonify({
98
- 'transcription': text,
99
- 'processing_time': processing_time,
100
- 'language': info.language,
101
- 'language_probability': info.language_probability
 
 
 
102
  })
103
 
104
- except TimeoutError:
105
- print("Transcription timeout")
106
- return jsonify({'error': 'Transcription timeout'}), 504
107
  except Exception as e:
108
  print(f"Transcription error: {str(e)}")
109
- return jsonify({'error': 'Transcription failed'}), 500
110
 
111
  finally:
 
 
112
  active_requests -= 1
113
  request_semaphore.release()
114
  print(f"Request completed (Active requests: {active_requests})")
115
 
116
  if __name__ == "__main__":
117
- app.run(host="0.0.0.0", debug=True, port=7860, threaded=True)
 
 
 
 
 
5
  import time
6
  import datetime
7
  from threading import Semaphore
8
+ import os
9
+ from werkzeug.utils import secure_filename
10
+ import tempfile
11
 
12
  app = Flask(__name__)
13
 
14
# ---- Server configuration knobs ----
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'ogg', 'm4a', 'flac'}  # accepted upload types
MAX_CONCURRENT_REQUESTS = 2       # simultaneous transcriptions; tune to server capacity
MAX_AUDIO_DURATION = 30 * 60      # hard processing-time cap, in seconds (30 minutes)
TEMPORARY_FOLDER = tempfile.gettempdir()  # staging area for uploaded audio files
19
+
20
  # Device check for faster-whisper
21
  device = "cuda" if torch.cuda.is_available() else "cpu"
22
  compute_type = "float16" if device == "cuda" else "int8"
23
  print(f"Using device: {device} with compute_type: {compute_type}")
24
 
25
# Faster Whisper setup with optimized parameters for long audio
beamsize = 5  # Slightly larger beam size can help with long-form accuracy
# NOTE(review): this commit removed the model warmup step, so the first
# request will bear the CUDA/model initialization cost — confirm acceptable.
wmodel = WhisperModel(
    "guillaumekln/faster-whisper-small",
    device=device,              # "cuda" when available, else "cpu" (set above)
    compute_type=compute_type,  # "float16" on GPU, "int8" on CPU (set above)
    download_root="./model_cache"  # Cache model to avoid re-downloading
)
33
 
34
  # Concurrency control
 
35
  request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
36
  active_requests = 0
37
 
38
def allowed_file(filename, allowed=None):
    """Return True if *filename* has an extension in the allowed set.

    Args:
        filename: the uploaded file's name; may lack an extension entirely.
        allowed: optional set of lowercase extensions to check against;
            defaults to the module-level ALLOWED_EXTENSIONS, so existing
            callers are unaffected.

    Returns:
        bool: True only when the name contains a dot and its final
        extension (case-insensitive) is in the allowed set.
    """
    if allowed is None:
        allowed = ALLOWED_EXTENSIONS
    return '.' in filename and \
           filename.rsplit('.', 1)[1].lower() in allowed
41
+
42
def cleanup_temp_files(file_path):
    """Delete the temporary upload at *file_path*, tolerating a missing file.

    Uses EAFP (try/remove/except FileNotFoundError) rather than the original
    exists()-then-remove() sequence, which races when two workers clean up
    the same path concurrently. Other OS errors (e.g. permissions) are
    logged but swallowed deliberately: a cleanup failure must not mask the
    request's real outcome.

    Args:
        file_path: absolute path of the temporary file to remove.
    """
    try:
        os.remove(file_path)
    except FileNotFoundError:
        pass  # already gone — nothing to do
    except OSError as e:
        print(f"Error cleaning up temp file {file_path}: {str(e)}")
49
 
50
  @app.route("/health", methods=["GET"])
51
  def health_check():
 
55
  'timestamp': datetime.datetime.now().isoformat(),
56
  'device': device,
57
  'compute_type': compute_type,
58
+ 'active_requests': active_requests,
59
+ 'max_duration_supported': MAX_AUDIO_DURATION
60
  })
61
 
62
  @app.route("/status/busy", methods=["GET"])
 
73
  def whisper_transcribe():
74
  global active_requests
75
 
 
76
  if not request_semaphore.acquire(blocking=False):
77
  return jsonify({
78
  'status': 'Server busy',
 
83
  active_requests += 1
84
  print(f"Starting transcription (Active requests: {active_requests})")
85
 
86
+ temp_file_path = None
87
+
88
  try:
89
  if 'audio' not in request.files:
90
  return jsonify({'error': 'No file provided'}), 400
91
 
92
  audio_file = request.files['audio']
93
+ if not (audio_file and allowed_file(audio_file.filename)):
 
94
  return jsonify({'error': 'Invalid file format'}), 400
95
 
96
+ # Save to temporary file for large audio processing
97
+ temp_file_path = os.path.join(TEMPORARY_FOLDER, secure_filename(audio_file.filename))
98
+ audio_file.save(temp_file_path)
99
+
100
+ # Get processing parameters from request
101
+ language = request.form.get('language', None)
102
+ task = request.form.get('task', 'transcribe') # 'transcribe' or 'translate'
103
+ vad_filter = request.form.get('vad_filter', 'true').lower() == 'true'
104
+ word_timestamps = request.form.get('word_timestamps', 'false').lower() == 'true'
105
+
106
  try:
 
107
  start_time = time.time()
 
108
 
109
+ # Process in chunks with VAD for long audio
110
+ segments, info = wmodel.transcribe(
111
+ temp_file_path,
112
+ beam_size=beamsize,
113
+ language=language,
114
+ task=task,
115
+ vad_filter=vad_filter,
116
+ word_timestamps=word_timestamps,
117
+ chunk_length=30 # Process in 30-second chunks
118
+ )
119
+
120
+ # Stream results as they become available
121
+ results = []
122
  for segment in segments:
123
+ if time.time() - start_time > MAX_AUDIO_DURATION:
124
+ raise TimeoutError(f"Transcription exceeded maximum allowed duration of {MAX_AUDIO_DURATION} seconds")
125
+
126
+ result = {
127
+ 'text': segment.text,
128
+ 'start': segment.start,
129
+ 'end': segment.end
130
+ }
131
+
132
+ if word_timestamps and segment.words:
133
+ result['words'] = [{
134
+ 'word': word.word,
135
+ 'start': word.start,
136
+ 'end': word.end,
137
+ 'probability': word.probability
138
+ } for word in segment.words]
139
+
140
+ results.append(result)
141
 
142
  processing_time = time.time() - start_time
143
  print(f"Transcription completed in {processing_time:.2f} seconds")
144
 
145
  return jsonify({
146
+ 'segments': results,
147
+ 'summary': {
148
+ 'processing_time': processing_time,
149
+ 'language': info.language,
150
+ 'language_probability': info.language_probability,
151
+ 'duration': sum(seg['end'] - seg['start'] for seg in results)
152
+ }
153
  })
154
 
155
+ except TimeoutError as te:
156
+ print(f"Transcription timeout: {str(te)}")
157
+ return jsonify({'error': str(te)}), 504
158
  except Exception as e:
159
  print(f"Transcription error: {str(e)}")
160
+ return jsonify({'error': 'Transcription failed', 'details': str(e)}), 500
161
 
162
  finally:
163
+ if temp_file_path:
164
+ cleanup_temp_files(temp_file_path)
165
  active_requests -= 1
166
  request_semaphore.release()
167
  print(f"Request completed (Active requests: {active_requests})")
168
 
169
if __name__ == "__main__":
    # Ensure the staging folder exists. tempfile.gettempdir() normally
    # exists already, but exist_ok=True both covers the unusual case and
    # removes the original check-then-create race (TOCTOU).
    os.makedirs(TEMPORARY_FOLDER, exist_ok=True)

    # threaded=True lets Flask serve concurrent requests, which the
    # semaphore above then caps at MAX_CONCURRENT_REQUESTS.
    app.run(host="0.0.0.0", port=7860, threaded=True)