baoyin2024 committed on
Commit
f5dc719
·
verified ·
1 Parent(s): 9702cbe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +138 -135
app.py CHANGED
@@ -1,153 +1,156 @@
1
- from flask import Flask, request, jsonify, Response
2
- from faster_whisper import WhisperModel
3
- import torch
4
- import io
5
- import time
6
- import datetime
7
- from threading import Semaphore
8
  import os
9
- from werkzeug.utils import secure_filename
 
 
 
10
  import tempfile
11
- from moviepy.editor import VideoFileClip # Added for video processing
 
 
12
 
13
  app = Flask(__name__)
14
 
15
# ---------------------------------------------------------------------------
# Service configuration
# ---------------------------------------------------------------------------
MAX_CONCURRENT_REQUESTS = 2  # Adjust based on server capacity
MAX_FILE_DURATION = 30 * 60  # Maximum media duration in seconds (30 minutes)
TEMPORARY_FOLDER = tempfile.gettempdir()

ALLOWED_AUDIO_EXTENSIONS = {
    'mp3', 'wav', 'ogg', 'm4a', 'flac', 'aac', 'wma', 'opus', 'aiff',
}
ALLOWED_VIDEO_EXTENSIONS = {
    'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp',
}
ALLOWED_EXTENSIONS = ALLOWED_AUDIO_EXTENSIONS | ALLOWED_VIDEO_EXTENSIONS

# Pick the inference device: CUDA with float16 when available, otherwise
# CPU with int8.
if torch.cuda.is_available():
    device = "cuda"
    compute_type = "float16"
else:
    device = "cpu"
    compute_type = "int8"
print(f"Using device: {device} with compute_type: {compute_type}")

# The faster-whisper model is loaded once at import time and shared by all
# requests; weights are cached under ./model_cache.
beamsize = 2
wmodel = WhisperModel(
    "guillaumekln/faster-whisper-small",
    device=device,
    compute_type=compute_type,
    download_root="./model_cache",
)

# Concurrency control: the semaphore caps simultaneous transcriptions and the
# counter is what the status endpoints report.
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
active_requests = 0
40
 
41
def allowed_file(filename, allowed_extensions=None):
    """Return True when *filename* ends in a supported extension.

    Args:
        filename: Name of the uploaded file; may be empty or None.
        allowed_extensions: Optional collection of lowercase extensions
            (without the dot). Defaults to the module-level
            ``ALLOWED_EXTENSIONS``.

    Returns:
        True if the text after the last dot, lowercased, is in the allowed
        collection; False otherwise, including for a missing name or a name
        without a dot.
    """
    # A missing filename is "not allowed" rather than a TypeError.
    if not filename or '.' not in filename:
        return False
    if allowed_extensions is None:
        allowed_extensions = ALLOWED_EXTENSIONS
    return filename.rsplit('.', 1)[1].lower() in allowed_extensions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
44
 
45
def cleanup_temp_files(*file_paths):
    """Delete the given temporary files on a best-effort basis.

    Args:
        *file_paths: Paths to remove. Falsy entries (None, "") are skipped.

    A path that no longer exists is ignored. Any other OS-level failure is
    printed and does not interrupt the cleanup of the remaining paths.
    """
    for file_path in file_paths:
        if not file_path:
            continue
        try:
            # Remove directly instead of checking existence first: the file
            # may vanish between the check and the removal.
            os.remove(file_path)
        except FileNotFoundError:
            continue
        except OSError as e:
            print(f"Error cleaning up temp file {file_path}: {str(e)}")
53
-
54
def extract_audio_from_video(video_path, output_audio_path):
    """Extract the audio track of a video into *output_audio_path*.

    Args:
        video_path: Path of the video file to read.
        output_audio_path: Path the extracted audio is written to.

    Returns:
        ``output_audio_path`` on success.

    Raises:
        Exception: If the video is longer than ``MAX_FILE_DURATION`` seconds,
            has no audio track, or cannot be read or written. The original
            error is attached as the cause.
    """
    try:
        video = VideoFileClip(video_path)
        try:
            if video.duration > MAX_FILE_DURATION:
                raise ValueError(f"Video duration exceeds {MAX_FILE_DURATION} seconds")
            # A clip without an audio track has ``audio`` set to None.
            if video.audio is None:
                raise ValueError("Video has no audio track")
            video.audio.write_audiofile(output_audio_path)
        finally:
            # Close the clip on every path, including a failed write.
            video.close()
        return output_audio_path
    except Exception as e:
        raise Exception(f"Failed to extract audio from video: {str(e)}") from e
66
 
67
@app.route("/health", methods=["GET"])
def health_check():
    """Report that the API is up, along with its runtime configuration."""
    payload = {
        'status': 'API is running',
        'timestamp': datetime.datetime.now().isoformat(),
        'device': device,
        'compute_type': compute_type,
        'active_requests': active_requests,
        'max_duration_supported': MAX_FILE_DURATION,
        'supported_formats': list(ALLOWED_EXTENSIONS),
    }
    return jsonify(payload)
79
 
80
@app.route("/status/busy", methods=["GET"])
def server_busy():
    """Report whether every transcription slot is currently in use."""
    at_capacity = active_requests >= MAX_CONCURRENT_REQUESTS
    report = {
        'is_busy': at_capacity,
        'active_requests': active_requests,
        'max_capacity': MAX_CONCURRENT_REQUESTS,
    }
    return jsonify(report)
 
 
89
 
90
@app.route("/whisper_transcribe", methods=["POST"])
def transcribe():
    """Transcribe an uploaded audio or video file and return the text as JSON.

    Responds with 503 when every concurrency slot is taken, 400 when the
    upload is missing or has an unsupported extension, 500 when processing
    raises, and 200 with ``transcription`` and ``file_type`` on success.
    """
    global active_requests

    # Non-blocking acquire: reject immediately instead of queueing the caller.
    if not request_semaphore.acquire(blocking=False):
        return jsonify({'error': 'Server busy'}), 503

    active_requests += 1
    start_time = time.time()
    temp_file_path = None
    temp_audio_path = None

    try:
        if 'file' not in request.files:
            return jsonify({'error': 'No file provided'}), 400

        file = request.files['file']
        if not (file and allowed_file(file.filename)):
            return jsonify({'error': f'Invalid file format. Supported: {", ".join(ALLOWED_EXTENSIONS)}'}), 400

        file_extension = file.filename.rsplit('.', 1)[1].lower()
        is_video = file_extension in ALLOWED_VIDEO_EXTENSIONS

        # Give every request its own uniquely named temp file. Reusing the
        # client-supplied name lets two concurrent uploads with the same
        # filename overwrite (and later delete) each other's data.
        fd, temp_file_path = tempfile.mkstemp(suffix=f".{file_extension}", dir=TEMPORARY_FOLDER)
        os.close(fd)
        file.save(temp_file_path)

        # Videos are first reduced to a WAV file; audio is transcribed as-is.
        if is_video:
            # A timestamp-based name would collide for requests that arrive
            # within the same second, so use another unique temp file.
            fd, temp_audio_path = tempfile.mkstemp(suffix=".wav", dir=TEMPORARY_FOLDER)
            os.close(fd)
            extract_audio_from_video(temp_file_path, temp_audio_path)
            transcription_file = temp_audio_path
        else:
            transcription_file = temp_file_path

        # Transcribe the audio file
        segments, _ = wmodel.transcribe(
            transcription_file,
            beam_size=beamsize,
            vad_filter=True,
            without_timestamps=True,
            compression_ratio_threshold=2.4,
            word_timestamps=False
        )

        full_text = " ".join(segment.text for segment in segments)
        return jsonify({
            'transcription': full_text,
            'file_type': 'video' if is_video else 'audio'
        }), 200

    except Exception as e:
        return jsonify({'error': str(e)}), 500

    finally:
        # Always delete temp files and give the slot back, whatever the outcome.
        cleanup_temp_files(temp_file_path, temp_audio_path)
        active_requests -= 1
        request_semaphore.release()
        print(f"Processed in {time.time()-start_time:.2f}s (Active: {active_requests})")
147
 
148
if __name__ == "__main__":
    # Create the temporary folder if it doesn't exist. exist_ok avoids the
    # race between an existence check and the directory creation.
    os.makedirs(TEMPORARY_FOLDER, exist_ok=True)

    app.run(host="0.0.0.0", port=7860, threaded=True)
 
1
import gc
import hmac
import io
import os
import tempfile
from datetime import datetime
from threading import Semaphore

import ffmpeg
import torch
import torchaudio
import whisperx
from flask import Flask, request, jsonify
11
 
12
  app = Flask(__name__)
13
 
14
# Read the API key from the environment; only a warning is printed when it is
# missing, so the process still starts.
api_key = os.environ.get("API_KEY")
if not api_key:
    print("Error: API_KEY environment variable not set!")

# Semaphore that limits the number of requests processed concurrently.
MAX_CONCURRENT_REQUESTS = 2
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)

# Inference device: use the GPU with float16 when CUDA is available and fall
# back to CPU with int8 otherwise, instead of assuming a GPU is present.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
26
+
27
def validate_api_key(request):
    """Validate the API key sent with a request against the API_KEY env var.

    The key is accepted from the ``X-API-Key`` header, the ``api_key`` query
    parameter, or the ``api_key`` form field; a match on any one is enough.

    Args:
        request: Flask request object (anything exposing ``headers``,
            ``args`` and ``form`` mappings with ``.get``).

    Returns:
        A ``(is_valid, error_message)`` tuple: ``(True, None)`` when a
        supplied key matches, otherwise ``(False, <reason>)``.
    """
    expected_key = os.environ.get("API_KEY")
    if not expected_key:
        return False, "API_KEY environment variable not set"

    expected_bytes = expected_key.encode("utf-8")
    supplied_keys = (
        request.headers.get("X-API-Key"),
        request.args.get("api_key"),
        request.form.get("api_key"),
    )
    for supplied in supplied_keys:
        if supplied is None:
            continue
        # Constant-time comparison so response timing does not reveal how
        # much of the key matched. Compare bytes because compare_digest
        # rejects non-ASCII str arguments.
        if hmac.compare_digest(supplied.encode("utf-8"), expected_bytes):
            return True, None
    return False, "Invalid API Key"
50
+
51
+
52
@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe an uploaded audio or video file and return the text as JSON.

    The request must carry a valid API key (401 otherwise) and a ``file``
    upload with a supported extension (400 otherwise). Video uploads are
    converted to WAV with ffmpeg first. Audio is downmixed to mono, resampled
    to 16 kHz and rejected with 400 if longer than 10 minutes. Processing
    failures return 500. On success the response is 200 with
    ``transcription`` and ``file_type``.
    """
    is_valid, message = validate_api_key(request)  # Validate the API key
    if not is_valid:
        return jsonify({"error": message}), 401

    with request_semaphore:
        if 'file' not in request.files:
            return jsonify({'error': 'No file uploaded'}), 400

        file = request.files['file']
        if file.filename == '':
            return jsonify({'error': 'No file selected'}), 400

        video_extensions = {'mp4', 'avi', 'mov', 'mkv', 'webm', 'flv', 'wmv', 'mpeg', 'mpg', '3gp'}
        allowed_extensions = {'mp3', 'wav', 'ogg', 'm4a', 'flac', 'aac', 'wma', 'opus', 'aiff'} | video_extensions

        # rpartition gives an empty extension for a name without a dot, so
        # such uploads get a 400 instead of raising IndexError.
        _, dot, suffix = file.filename.rpartition('.')
        file_extension = suffix.lower() if dot else ''
        if file_extension not in allowed_extensions:
            return jsonify({'error': f'Invalid file format. Supported: {", ".join(allowed_extensions)}'}), 400

        # Every temp file created below is recorded here and removed in the
        # outer ``finally``, so no return path leaves files behind.
        temp_paths = []
        try:
            # Reserve a uniquely named temp file, close the handle, then let
            # the upload write to that path.
            with tempfile.NamedTemporaryFile(delete=False, suffix=f'.{file_extension}') as temp_file:
                temp_file_path = temp_file.name
            temp_paths.append(temp_file_path)
            file.save(temp_file_path)

            if file_extension in video_extensions:
                file_type = "video"
                with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as audio_file:
                    audio_file_path = audio_file.name
                temp_paths.append(audio_file_path)
                try:
                    # Extract audio from video using ffmpeg
                    ffmpeg.input(temp_file_path).output(audio_file_path, format='wav', acodec='pcm_s16le').run(quiet=True, overwrite_output=True)
                except Exception as e:
                    return jsonify({'error': f'Failed to extract audio from video: {str(e)}'}), 500
                audio_file_path_final = audio_file_path
            else:
                file_type = "audio"
                audio_file_path_final = temp_file_path

            # Load the audio, downmix to mono and resample to 16 kHz.
            try:
                audio, samplerate = torchaudio.load(audio_file_path_final)
                if audio.shape[0] > 1:
                    audio = audio.mean(dim=0, keepdim=True)
                audio = audio.squeeze(0)
                if samplerate != 16000:
                    audio = torchaudio.functional.resample(audio, samplerate, 16000)
            except Exception as e:
                return jsonify({'error': f'Failed to load audio file: {str(e)}'}), 500

            # Ensure the audio duration does not exceed 10 minutes
            max_duration = 10 * 60  # 10 minutes in seconds
            if audio.shape[-1] / 16000 > max_duration:
                return jsonify({'error': 'Audio duration exceeds the maximum allowed duration of 10 minutes'}), 400

            # Perform transcription
            try:
                wmodel, model_options = get_model()

                # whisperx takes a NumPy waveform and returns a dict whose
                # "segments" entry is a list of dicts with a "text" field.
                result = wmodel.transcribe(audio.numpy(), batch_size=model_options.get("batch_size", None))
                transcription = "".join(segment["text"] for segment in result["segments"])
            except Exception as e:
                return jsonify({'error': f'Transcription failed: {str(e)}'}), 500
            finally:
                gc.collect()
                torch.cuda.empty_cache()

            return jsonify({'transcription': transcription, 'file_type': file_type}), 200

        except Exception as e:
            return jsonify({'error': str(e)}), 500

        finally:
            # Best-effort cleanup: a failed delete must not mask the response.
            for temp_path in temp_paths:
                try:
                    os.remove(temp_path)
                except FileNotFoundError:
                    pass
                except OSError as e:
                    print(f"Error cleaning up temp file {temp_path}: {str(e)}")
 
 
 
 
 
 
 
 
 
 
 
 
 
137
 
138
@app.route("/health", methods=["GET"])
def health_check():
    """Liveness probe: always answers 200 with a fixed "healthy" status."""
    payload = {"status": "healthy"}
    return jsonify(payload), 200
 
 
 
 
 
 
 
 
 
141
 
142
@app.route("/status/busy", methods=["GET"])
def status_busy():
    """Report whether every concurrent-request slot is currently taken.

    Returns:
        JSON ``{"busy": <bool>}`` with status 200.
    """
    # Probe with a non-blocking acquire instead of reading the semaphore's
    # private ``_value`` attribute. The slot is released immediately.
    if request_semaphore.acquire(blocking=False):
        request_semaphore.release()
        busy = False
    else:
        busy = True
    return jsonify({"busy": busy}), 200
145
+
146
def get_model():
    """Load the Whisper model and return it together with its options.

    Returns:
        A ``(wmodel, model_options)`` tuple: the object returned by
        ``whisperx.load_model`` for the module-level ``device`` and
        ``compute_type``, and a dict containing ``beam_size`` 5.
    """
    # NOTE(review): the model is loaded again on every call; confirm whether
    # per-request loading is intended.
    options = {"beam_size": 5}
    loaded_model = whisperx.load_model(
        "guillaumekln/faster-whisper-small",
        device,
        compute_type=compute_type,
    )
    return loaded_model, options
153
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
154
 
155
if __name__ == "__main__":
    # Debug mode turns on the interactive Werkzeug debugger, which can run
    # arbitrary code. Keep it off unless FLASK_DEBUG explicitly enables it.
    debug_enabled = os.environ.get("FLASK_DEBUG", "").lower() in {"1", "true", "yes"}
    app.run(debug=debug_enabled, port=int(os.environ.get("PORT", 7860)))