baoyin2024 committed on
Commit 0666a2d · verified · 1 Parent(s): d1f434b

Update app.py

Files changed (1)
  1. app.py +170 -124
app.py CHANGED
@@ -1,156 +1,202 @@
- from flask import Flask, request, jsonify
- import os
  import io
- import whisperx
- import torchaudio
- import gc
- import tempfile
- import ffmpeg
- from datetime import datetime
  from threading import Semaphore
-
- app = Flask(__name__)
-
- # Read API_KEY from the environment variable
- api_key = os.environ.get("API_KEY")
- if not api_key:
-     print("Error: API_KEY environment variable not set!")
-
- # Semaphore used to limit the number of concurrent requests
  MAX_CONCURRENT_REQUESTS = 2
  request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
-
- # GPU device
- device = "cuda"
- compute_type = "float16"
-
  def validate_api_key(request):
-     """
-     Validate the API key: read it from the request and compare it with the API key in the environment variable.
-
-     Args:
-         request: Flask request object.
-
-     Returns:
-         True if the API key is valid, otherwise False.
-     """
-     api_key_header = request.headers.get("X-API-Key")
-     api_key_query = request.args.get("api_key")
-     api_key_form = request.form.get("api_key")
-
-     api_key_env = os.environ.get("API_KEY")
-
-     if not api_key_env:
-         return False, "API_KEY environment variable not set"
-
-     if api_key_header == api_key_env or api_key_query == api_key_env or api_key_form == api_key_env:
-         return True, None
-     else:
-         return False, "Invalid API Key"
-
- @app.route("/whisper_transcribe", methods=["POST"])
- def whisper_transcribe():
-     is_valid, message = validate_api_key(request)  # Validate the API key
-     if not is_valid:
-         return jsonify({"error": message}), 401
-
-     with request_semaphore:
-         if 'file' not in request.files:
-             return jsonify({'error': 'No file uploaded'}), 400
-
-         file = request.files['file']
-         if file.filename == '':
-             return jsonify({'error': 'No file selected'}), 400
-
-         filename = file.filename
-         file_extension = filename.rsplit('.', 1)[1].lower()
-         allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a', 'flac', 'aac', 'wma', 'opus', 'aiff', 'mp4', 'avi', 'mov',
-                               'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp'}
-         if file_extension not in allowed_extensions:
-             return jsonify({'error': f'Invalid file format. Supported: {", ".join(allowed_extensions)}'}), 400
-
-         try:
-             # Save the uploaded file to a temporary file
-             with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_extension}') as temp_file:
-                 file.save(temp_file.name)
-                 temp_file_path = temp_file.name
-
-             # Determine if the file is a video file
-             video_extensions = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp'}
-             if file_extension in video_extensions:
-                 file_type = "video"
-                 try:
-                     # Extract audio from video using ffmpeg
-                     audio_file_path = tempfile.NamedTemporaryFile(delete=False, suffix=".wav").name
-                     ffmpeg.input(temp_file_path).output(audio_file_path, format='wav', acodec='pcm_s16le').run(quiet=True, overwrite_output=True)
-                 except Exception as e:
-                     return jsonify({'error': f'Failed to extract audio from video: {str(e)}'}), 500
-
-                 # Delete the temporary video file
-                 os.remove(temp_file_path)
-                 audio_file_path_final = audio_file_path
-             else:
-                 file_type = "audio"
-                 audio_file_path_final = temp_file_path
-
-             # Load the audio file
-             try:
-                 audio, samplerate = torchaudio.load(audio_file_path_final)
-                 audio = audio.to(device)
-                 if audio.shape[0] > 1:
-                     audio = audio.mean(dim=0, keepdim=True)
-                 audio = audio.squeeze()
-                 if samplerate != 16000:
-                     audio = torchaudio.functional.resample(audio, samplerate, 16000)
-             except Exception as e:
-                 return jsonify({'error': f'Failed to load audio file: {str(e)}'}), 500
-
-             # Ensure the audio duration does not exceed 10 minutes
-             max_duration = 10 * 60  # 10 minutes in seconds
-             if audio.shape[-1] / 16000 > max_duration:
-                 return jsonify({'error': 'Audio duration exceeds the maximum allowed duration of 10 minutes'}), 400
-
-             # Perform transcription
-             try:
-                 wmodel, model_options = get_model()
-
-                 segments, info = wmodel.transcribe(audio, batch_size=model_options.get("batch_size", None))
-                 segments = list(segments)  # Convert generator to list
-
-                 transcription = ""
-                 for segment in segments:
-                     transcription += segment.text
-
-             except Exception as e:
-                 return jsonify({'error': f'Transcription failed: {str(e)}'}), 500
-             finally:
-                 # Clean up temporary files
-                 os.remove(audio_file_path_final)
-                 gc.collect()
-                 torch.cuda.empty_cache()
-
-             return jsonify({'transcription': transcription, 'file_type': file_type}), 200
-
-         except Exception as e:
-             return jsonify({'error': str(e)}), 500
-
- @app.route("/health", methods=["GET"])
- def health_check():
-     return jsonify({"status": "healthy"}), 200
-
- @app.route("/status/busy", methods=["GET"])
- def status_busy():
-     return jsonify({"busy": request_semaphore._value == 0}), 200
-
- def get_model():
-     """Load model"""
-     model_name = "guillaumekln/faster-whisper-large-v2"
-     model_options = {"beam_size": 5}
-     wmodel = whisperx.load_model(model_name, device, compute_type=compute_type)
-
-     return wmodel, model_options
-
- if __name__ == "__main__":
-     app.run(debug=True, port=int(os.environ.get("PORT", 7860)))
+ from flask import Flask, request, jsonify, Response
+ from faster_whisper import WhisperModel
+ import torch
  import io
+ import time
+ import datetime
  from threading import Semaphore
+ import os
+ from werkzeug.utils import secure_filename
+ import tempfile
+ from moviepy.editor import VideoFileClip
+ import logging
+ import torchaudio  # Import torchaudio
+
+ # Configure logging
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
+
+ app = Flask(__name__)
+
+ # Configuration
  MAX_CONCURRENT_REQUESTS = 2
+ MAX_FILE_DURATION = 60 * 30
+ TEMPORARY_FOLDER = tempfile.gettempdir()
+ ALLOWED_AUDIO_EXTENSIONS = {'mp3', 'wav', 'ogg', 'm4a', 'flac', 'aac', 'wma', 'opus', 'aiff'}
+ ALLOWED_VIDEO_EXTENSIONS = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp'}
+ ALLOWED_EXTENSIONS = ALLOWED_AUDIO_EXTENSIONS.union(ALLOWED_VIDEO_EXTENSIONS)
+
+ API_KEY = os.environ.get("API_KEY")  # Load API key from environment
+ MODEL_NAME = os.environ.get("WHISPER_MODEL", "guillaumekln/faster-whisper-large-v2")  # Configurable model
+
+ # Device check for faster-whisper
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ compute_type = "float16" if device == "cuda" else "int8"
+ logging.info(f"Using device: {device} with compute_type: {compute_type}")
+
+ # Faster Whisper setup
+ beamsize = 2
+ try:
+     wmodel = WhisperModel(
+         MODEL_NAME,
+         device=device,
+         compute_type=compute_type,
+         download_root="./model_cache"
+     )
+     logging.info(f"Model {MODEL_NAME} loaded successfully.")
+ except Exception as e:
+     logging.error(f"Failed to load model {MODEL_NAME}: {e}")
+     wmodel = None
+
+ # Concurrency control
  request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
+ active_requests = 0
+
  def validate_api_key(request):
+     api_key = request.headers.get('X-API-Key')
+     if api_key == API_KEY:
+         return True
+     else:
+         return False
+
+ def allowed_file(filename):
+     return '.' in filename and \
+            filename.rsplit('.', 1)[1].lower() in ALLOWED_EXTENSIONS
+
+ def cleanup_temp_files(*file_paths):
+     for file_path in file_paths:
+         try:
+             if file_path and os.path.exists(file_path):
+                 os.remove(file_path)
+                 logging.info(f"Deleted temporary file: {file_path}")
+         except Exception as e:
+             logging.error(f"Error cleaning up temp file {file_path}: {str(e)}")
+
+ def extract_audio_from_video(video_path, output_audio_path):
+     try:
+         video = VideoFileClip(video_path)
+         if video.duration > MAX_FILE_DURATION:
+             video.close()
+             raise ValueError(f"Video duration exceeds {MAX_FILE_DURATION} seconds")
+         video.audio.write_audiofile(output_audio_path, codec='pcm_s16le')  # Specify codec
+         video.close()
+         return output_audio_path
+     except Exception as e:
+         logging.exception("Error extracting audio from video")
+         raise Exception(f"Failed to extract audio from video: {str(e)}")
+
+ @app.route("/health", methods=["GET"])
+ def health_check():
+     return jsonify({
+         'status': 'API is running',
+         'timestamp': datetime.datetime.now().isoformat(),
+         'device': device,
+         'compute_type': compute_type,
+         'active_requests': active_requests,
+         'max_duration_supported': MAX_FILE_DURATION,
+         'supported_formats': list(ALLOWED_EXTENSIONS),
+         'model': MODEL_NAME
+     })
+
+ @app.route("/status/busy", methods=["GET"])
+ def server_busy():
+     is_busy = active_requests >= MAX_CONCURRENT_REQUESTS
+     return jsonify({
+         'is_busy': is_busy,
+         'active_requests': active_requests,
+         'max_capacity': MAX_CONCURRENT_REQUESTS
+     })
+
+ @app.route("/whisper_transcribe", methods=["POST"])
+ def transcribe():
+     global active_requests
+
+     if not validate_api_key(request):
+         return jsonify({'error': 'Invalid API key'}), 401
+
+     if not request_semaphore.acquire(blocking=False):
+         return jsonify({'error': 'Server busy'}), 503
+
+     active_requests += 1
+     start_time = time.time()
+     temp_file_path = None
+     temp_audio_path = None
+
+     try:
+         if wmodel is None:
+             return jsonify({'error': 'Model failed to load. Check server logs.'}), 500
+
+         if 'file' not in request.files:
+             return jsonify({'error': 'No file provided'}), 400
+
+         file = request.files['file']
+         if not (file and allowed_file(file.filename)):
+             return jsonify({'error': f'Invalid file format. Supported: {", ".join(ALLOWED_EXTENSIONS)}'}), 400
+
+         # Save uploaded file to temporary location
+         temp_file_path = os.path.join(TEMPORARY_FOLDER, secure_filename(file.filename))
+         file.save(temp_file_path)
+
+         # Check if file is a video and extract audio if necessary
+         file_extension = file.filename.rsplit('.', 1)[1].lower()
+         is_video = file_extension in ALLOWED_VIDEO_EXTENSIONS
+
+         if is_video:
+             temp_audio_path = os.path.join(TEMPORARY_FOLDER, f"temp_audio_{int(time.time())}.wav")
+             extract_audio_from_video(temp_file_path, temp_audio_path)
+             transcription_file = temp_audio_path
+         else:
+             transcription_file = temp_file_path
+             # Check audio file duration directly
+             try:
+                 info = torchaudio.info(transcription_file)
+                 duration = info.num_frames / info.sample_rate
+                 if duration > MAX_FILE_DURATION:
+                     raise ValueError(f"Audio duration exceeds {MAX_FILE_DURATION} seconds")
+             except Exception as duration_err:
+                 logging.exception(f"Error getting/checking audio duration for {transcription_file}")
+                 return jsonify({'error': f'Error getting/checking audio duration: {str(duration_err)}'}), 400
+
+         # Transcribe the audio file
+         segments, _ = wmodel.transcribe(
+             transcription_file,
+             beam_size=beamsize,
+             vad_filter=True,
+             without_timestamps=True,
+             compression_ratio_threshold=2.4,
+             word_timestamps=False
+         )
+
+         full_text = " ".join(segment.text for segment in segments)
+         return jsonify({
+             'transcription': full_text,
+             'file_type': 'video' if is_video else 'audio'
+         }), 200
+
+     except Exception as e:
+         logging.exception("Exception during transcription process")
+         return jsonify({'error': str(e)}), 500
+
+     finally:
+         cleanup_temp_files(temp_file_path, temp_audio_path)
+         active_requests -= 1
+         request_semaphore.release()
+         print(f"Processed in {time.time() - start_time:.2f}s (Active: {active_requests})")
+
+ if __name__ == "__main__":
+     # Create temporary folder if it doesn't exist
+     if not os.path.exists(TEMPORARY_FOLDER):
+         os.makedirs(TEMPORARY_FOLDER)
+         logging.info(f"Created temporary folder: {TEMPORARY_FOLDER}")
+
+     app.run(host="0.0.0.0", port=7860, threaded=True)
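
For reference, a minimal client sketch against the updated endpoints. It is an illustration, not part of the commit: it assumes the server is reachable at http://localhost:7860 (the port passed to app.run), that the same API_KEY value is exported in the client's environment, and that "sample.mp3" stands in for any supported audio or video file; the third-party requests library is used as the HTTP client.

import os
import requests  # third-party HTTP client, assumed installed

API_URL = "http://localhost:7860"   # assumption: server running locally on the port from app.run
API_KEY = os.environ["API_KEY"]     # same key the server reads via os.environ.get("API_KEY")

# Optional: ask the server whether it is at capacity before uploading
busy = requests.get(f"{API_URL}/status/busy", timeout=10).json()
print("busy:", busy["is_busy"], "active:", busy["active_requests"])

# POST the media file with the X-API-Key header checked by validate_api_key()
with open("sample.mp3", "rb") as f:  # "sample.mp3" is a placeholder path
    resp = requests.post(
        f"{API_URL}/whisper_transcribe",
        headers={"X-API-Key": API_KEY},
        files={"file": f},
        timeout=600,  # long files can take a while to transcribe
    )

resp.raise_for_status()
result = resp.json()
print(result["file_type"], result["transcription"])

A 401 response means the X-API-Key header did not match, and a 503 means the semaphore rejected the request because MAX_CONCURRENT_REQUESTS transcriptions are already in flight; callers can poll /status/busy and retry.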