|
import datetime
import hmac
import io
import logging
import os
import tempfile
import time
import uuid
from threading import Semaphore

import ffmpeg
import torch
import torchaudio
from faster_whisper import WhisperModel, decode_audio
from flask import Flask, request, jsonify, Response
from moviepy.editor import VideoFileClip
from werkzeug.utils import secure_filename
|
|
|
|
|
# Root logger: INFO and above, "time - level - message" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# The Flask application that all routes below are registered on.
app = Flask(__name__)
|
|
|
|
|
# Upper bound on transcriptions handled at the same time (size of the request semaphore).
MAX_CONCURRENT_REQUESTS = 2
# Longest accepted audio/video duration, in seconds (30 minutes).
MAX_FILE_DURATION = 30 * 60
# Directory where uploads and extracted audio are written.
TEMPORARY_FOLDER = tempfile.gettempdir()

# Accepted upload extensions: lower-case, without the leading dot.
ALLOWED_AUDIO_EXTENSIONS = {
    'aac', 'aiff', 'flac', 'm4a', 'mp3', 'ogg', 'opus', 'wav', 'wma',
}
ALLOWED_VIDEO_EXTENSIONS = {
    '3gp', 'avi', 'flv', 'mkv', 'mov', 'mp4', 'mpeg', 'mpg', 'webm', 'wmv',
}
ALLOWED_EXTENSIONS = ALLOWED_AUDIO_EXTENSIONS | ALLOWED_VIDEO_EXTENSIONS

# Secret compared against the X-API-Key request header; None when the env var is unset.
API_KEY = os.environ.get("API_KEY")
# Model identifier handed to WhisperModel; override with the WHISPER_MODEL env var.
MODEL_NAME = os.environ.get("WHISPER_MODEL", "guillaumekln/faster-whisper-large-v2")
|
|
|
|
|
# Run on the GPU with half precision when CUDA is usable; otherwise fall back
# to the CPU with int8 quantised weights.
if torch.cuda.is_available():
    device = "cuda"
    compute_type = "float16"
else:
    device = "cpu"
    compute_type = "int8"

logging.info(f"使用设备: {device},计算类型: {compute_type}")
|
|
|
|
|
# Beam width passed as beam_size to every transcription.
beamsize = 2

# Load the model once at import time; model files go to ./model_cache
# (relative to the working directory). If loading fails, wmodel stays None and
# the transcription endpoint answers with an error instead of crashing.
wmodel = None
try:
    wmodel = WhisperModel(
        MODEL_NAME,
        device=device,
        compute_type=compute_type,
        download_root="./model_cache",
    )
except Exception as e:
    logging.error(f"加载模型 {MODEL_NAME} 失败: {e}")
else:
    logging.info(f"模型 {MODEL_NAME} 加载成功.")
|
|
|
|
|
# Admission control: transcribe() acquires this without blocking and answers
# 503 when all MAX_CONCURRENT_REQUESTS slots are taken.
request_semaphore = Semaphore(value=MAX_CONCURRENT_REQUESTS)

# Count of transcriptions currently in progress, reported by /health and
# /status/busy.
active_requests = 0
|
|
|
|
|
def validate_api_key(request):
    """Check the ``X-API-Key`` header of *request* against ``API_KEY``.

    Fails closed: when ``API_KEY`` is unset or empty, every request is
    rejected. Previously an unset key compared equal to a missing header
    (``None == None``), so unauthenticated requests were accepted.

    Args:
        request: The incoming Flask request; only ``request.headers`` is read.

    Returns:
        True if the header is present and matches the configured key,
        otherwise False.
    """
    if not API_KEY:
        return False

    provided = request.headers.get('X-API-Key')
    if provided is None:
        return False

    # Use a constant-time comparison so response timing does not reveal how
    # much of the key matched. Compare bytes, because compare_digest rejects
    # non-ASCII str arguments.
    return hmac.compare_digest(provided.encode('utf-8'), API_KEY.encode('utf-8'))
|
|
|
|
|
def allowed_file(filename):
    """Return whether *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The extension is whatever follows the last '.', compared in lower case;
    names containing no '.' at all are rejected.
    """
    if '.' not in filename:
        return False
    extension = filename.rsplit('.', 1)[-1]
    return extension.lower() in ALLOWED_EXTENSIONS
|
|
|
|
|
def cleanup_temp_files(*file_paths):
    """Delete the given temporary files on a best-effort basis.

    Falsy entries (such as ``None`` for a file that was never created) and
    paths that do not exist are skipped. Any error while deleting is logged
    and swallowed, so the function is safe to call from a ``finally`` clause.
    """
    for path in file_paths:
        try:
            if not path or not os.path.exists(path):
                continue
            os.remove(path)
            logging.info(f"删除临时文件: {path}")
        except Exception as err:
            logging.error(f"删除临时文件 {path} 出错: {str(err)}")
|
|
|
|
|
def extract_audio_from_video(video_path, output_audio_path):
    """Extract the audio track of *video_path* into a 16-bit PCM WAV file.

    The video duration is checked before ffmpeg runs, so over-long uploads are
    rejected without paying for a full extraction.

    Args:
        video_path: Path of the uploaded video file.
        output_audio_path: Destination of the WAV file; overwritten if present.

    Returns:
        ``output_audio_path``.

    Raises:
        Exception: if the video is longer than ``MAX_FILE_DURATION`` seconds,
            or if probing or extraction fails. The underlying error is chained
            as ``__cause__``.
    """
    try:
        video = VideoFileClip(video_path)
        try:
            duration = video.duration
        finally:
            # Always release the clip, even if reading its duration fails.
            video.close()

        if duration > MAX_FILE_DURATION:
            raise ValueError(f"视频时长超过 {MAX_FILE_DURATION} 秒")

        # overwrite_output() adds -y. Without it ffmpeg waits for a yes/no
        # answer on stdin whenever the destination file already exists.
        (
            ffmpeg.input(video_path)
            .output(output_audio_path, acodec='pcm_s16le')
            .overwrite_output()
            .run(capture_stdout=True, capture_stderr=True)
        )
        return output_audio_path

    except Exception as e:
        # ffmpeg.Error's message only says to look at stderr, and stderr was
        # captured above, so log it explicitly or the real cause is lost.
        if isinstance(e, ffmpeg.Error) and e.stderr:
            logging.error(f"ffmpeg 错误输出: {e.stderr.decode('utf-8', errors='replace')}")
        logging.exception("提取视频中的音频出错")
        raise Exception(f"提取视频中的音频出错: {str(e)}") from e
|
|
|
|
|
@app.route("/health", methods=["GET"])
def health_check():
    """Liveness probe that also reports configuration and current load.

    The JSON body holds a status string, the current local time in ISO 8601,
    the inference device and compute type, the number of in-flight
    transcriptions, the duration limit in seconds, the accepted extensions
    and the model name.
    """
    now_iso = datetime.datetime.now().isoformat()
    info = dict(
        status='API 正在运行',
        timestamp=now_iso,
        device=device,
        compute_type=compute_type,
        active_requests=active_requests,
        max_duration_supported=MAX_FILE_DURATION,
        supported_formats=list(ALLOWED_EXTENSIONS),
        model=MODEL_NAME,
    )
    return jsonify(info)
|
|
|
|
|
@app.route("/status/busy", methods=["GET"])
def server_busy():
    """Report whether every concurrency slot is in use, plus current and maximum load."""
    in_flight = active_requests
    return jsonify({
        'is_busy': in_flight >= MAX_CONCURRENT_REQUESTS,
        'active_requests': in_flight,
        'max_capacity': MAX_CONCURRENT_REQUESTS,
    })
|
|
|
|
|
@app.route("/whisper_transcribe", methods=["POST"])
def transcribe():
    """Transcribe an uploaded audio or video file with the loaded Whisper model.

    Request: multipart form with the upload in field ``file`` and the API key
    in the ``X-API-Key`` header. Uploads whose extension is in
    ``ALLOWED_VIDEO_EXTENSIONS`` are first converted to WAV by
    ``extract_audio_from_video``.

    Responses:
        200 ``{'transcription': str, 'file_type': 'audio' | 'video'}``
        400 no file, unsupported extension, undecodable audio, or audio
            longer than ``MAX_FILE_DURATION`` seconds
        401 invalid API key
        500 model not loaded, or any unexpected failure
        503 all ``MAX_CONCURRENT_REQUESTS`` slots are busy
    """
    global active_requests

    if not validate_api_key(request):
        return jsonify({'error': '无效的 API 密钥'}), 401

    # Non-blocking: reject at once instead of queueing when at capacity.
    if not request_semaphore.acquire(blocking=False):
        return jsonify({'error': '服务器繁忙'}), 503

    active_requests += 1
    start_time = time.time()
    temp_file_path = None
    temp_audio_path = None

    try:
        if wmodel is None:
            return jsonify({'error': '模型加载失败。请检查服务器日志。'}), 500

        if 'file' not in request.files:
            return jsonify({'error': '未提供文件'}), 400

        file = request.files['file']
        if not (file and allowed_file(file.filename)):
            return jsonify({'error': f'无效的文件格式。支持:{", ".join(ALLOWED_EXTENSIONS)}'}), 400

        file_extension = file.filename.rsplit('.', 1)[1].lower()
        is_video = file_extension in ALLOWED_VIDEO_EXTENSIONS

        # Use per-request unique names. The previous scheme (client filename,
        # or a timestamp with one-second resolution) let concurrent requests
        # overwrite, and then delete, each other's temp files.
        request_id = uuid.uuid4().hex
        temp_file_path = os.path.join(TEMPORARY_FOLDER, f"upload_{request_id}.{file_extension}")
        file.save(temp_file_path)

        if is_video:
            temp_audio_path = os.path.join(TEMPORARY_FOLDER, f"temp_audio_{request_id}.wav")
            extract_audio_from_video(temp_file_path, temp_audio_path)
            transcription_file = temp_audio_path
        else:
            transcription_file = temp_file_path

        # Decode once at 16 kHz and reuse the samples for both the duration
        # check and transcription. This replaces the torchaudio probing, which
        # had three problems:
        #   - it passed the upload's extension as the format even for the
        #     extracted WAV;
        #   - it switched the process-wide audio backend from concurrent
        #     request threads;
        #   - it decoded every file a second time.
        sample_rate = 16000
        try:
            audio = decode_audio(transcription_file, sampling_rate=sample_rate)
        except Exception as decode_err:
            logging.exception(f"解码音频文件出错: {transcription_file}")
            return jsonify({'error': f'解码音频文件出错: {str(decode_err)}'}), 400

        duration = len(audio) / sample_rate
        if duration > MAX_FILE_DURATION:
            return jsonify({'error': f'音频时长超过 {MAX_FILE_DURATION} 秒'}), 400

        segments, _ = wmodel.transcribe(
            audio,
            beam_size=beamsize,
            vad_filter=True,
            without_timestamps=True,
            compression_ratio_threshold=2.4,
            word_timestamps=False,
        )

        # faster-whisper documents `segments` as a lazy generator, so the
        # decoding work happens while it is consumed here, inside the
        # semaphore-guarded section. Segment text may start with a space, so
        # strip each piece to avoid doubled spaces in the joined result.
        full_text = " ".join(segment.text.strip() for segment in segments)
        return jsonify({
            'transcription': full_text,
            'file_type': 'video' if is_video else 'audio',
        }), 200

    except Exception as e:
        logging.exception("转录过程中发生异常")
        return jsonify({'error': str(e)}), 500

    finally:
        cleanup_temp_files(temp_file_path, temp_audio_path)
        active_requests -= 1
        request_semaphore.release()
        logging.info(f"处理时间:{time.time() - start_time:.2f}s (活动请求:{active_requests})")
|
|
|
|
|
if __name__ == "__main__":
    # Defensive: make sure the temp directory exists before serving.
    if not os.path.exists(TEMPORARY_FOLDER):
        os.makedirs(TEMPORARY_FOLDER)
        logging.info(f"创建临时文件夹: {TEMPORARY_FOLDER}")

    # Listen on all interfaces. threaded=True lets requests run concurrently,
    # which the request semaphore then bounds.
    app.run(port=7860, threaded=True, host="0.0.0.0")
|
|
|
|
|
|