# app.py — Flask transcription API built on faster-whisper.
# (Header cleaned up: removed Hugging Face web-page residue that was not Python.)
from flask import Flask, request, jsonify
from faster_whisper import WhisperModel
import torch
import io
import time
import datetime
from threading import Semaphore
import os
from werkzeug.utils import secure_filename
import tempfile
# Flask application instance; served below on port 7860.
app = Flask(__name__)

# Configuration
MAX_CONCURRENT_REQUESTS = 2  # Adjust based on your server capacity
# NOTE(review): this value is also used as a wall-clock processing timeout in
# whisper_transcribe, not only as an audio-length limit — confirm intent.
MAX_AUDIO_DURATION = 60 * 30  # 30 minutes maximum audio duration (adjust as needed)
TEMPORARY_FOLDER = tempfile.gettempdir()  # uploads are staged here before transcription
ALLOWED_EXTENSIONS = {'mp3', 'wav', 'ogg', 'm4a', 'flac'}  # accepted upload extensions

# Device check for faster-whisper: prefer GPU with float16, otherwise
# fall back to int8 quantization on CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "int8"
print(f"Using device: {device} with compute_type: {compute_type}")

# Faster Whisper setup with optimized parameters for long audio
beamsize = 5  # Slightly larger beam size can help with long-form accuracy
wmodel = WhisperModel(
    "guillaumekln/faster-whisper-small",
    device=device,
    compute_type=compute_type,
    download_root="./model_cache"  # Cache model to avoid re-downloading
)

# Concurrency control: the semaphore caps simultaneous transcriptions;
# active_requests is an informational counter exposed by the status endpoints.
request_semaphore = Semaphore(MAX_CONCURRENT_REQUESTS)
active_requests = 0
def allowed_file(filename):
    """Return True iff *filename* has an extension listed in ALLOWED_EXTENSIONS."""
    _stem, dot, ext = filename.rpartition('.')
    return bool(dot) and ext.lower() in ALLOWED_EXTENSIONS
def cleanup_temp_files(file_path):
    """Best-effort removal of a temporary file; logs instead of raising."""
    if not os.path.exists(file_path):
        return
    try:
        os.remove(file_path)
    except Exception as exc:
        print(f"Error cleaning up temp file {file_path}: {str(exc)}")
@app.route("/health", methods=["GET"])
def health_check():
    """Liveness probe: report runtime configuration and current load."""
    payload = {
        'status': 'API is running',
        'timestamp': datetime.datetime.now().isoformat(),
        'device': device,
        'compute_type': compute_type,
        'active_requests': active_requests,
        'max_duration_supported': MAX_AUDIO_DURATION,
    }
    return jsonify(payload)
@app.route("/status/busy", methods=["GET"])
def server_busy():
    """Tell clients whether every transcription slot is currently taken."""
    return jsonify({
        'is_busy': active_requests >= MAX_CONCURRENT_REQUESTS,
        'active_requests': active_requests,
        'max_capacity': MAX_CONCURRENT_REQUESTS,
    })
@app.route("/whisper_transcribe", methods=["POST"])
def whisper_transcribe():
    """Transcribe (or translate) an uploaded audio file with faster-whisper.

    Form fields:
        audio (file, required): audio with an extension in ALLOWED_EXTENSIONS.
        language (str, optional): source-language hint; auto-detected if omitted.
        task (str): 'transcribe' (default) or 'translate'.
        vad_filter (str): 'true'/'false' — apply voice-activity detection (default 'true').
        word_timestamps (str): 'true'/'false' — include per-word timings (default 'false').

    Returns:
        JSON with per-segment text and timings plus a summary, or an error
        payload with HTTP status 400 (bad input), 500 (transcription failure),
        503 (server busy), or 504 (processing timeout).
    """
    global active_requests
    # Reject immediately rather than queueing when all slots are busy.
    if not request_semaphore.acquire(blocking=False):
        return jsonify({
            'status': 'Server busy',
            'message': f'Currently processing {active_requests} requests',
            'suggestion': 'Please try again shortly'
        }), 503
    # NOTE(review): active_requests is mutated from multiple threads without a
    # lock; it is informational only, so minor races are tolerated here.
    active_requests += 1
    print(f"Starting transcription (Active requests: {active_requests})")
    temp_file_path = None
    try:
        if 'audio' not in request.files:
            return jsonify({'error': 'No file provided'}), 400
        audio_file = request.files['audio']
        if not (audio_file and allowed_file(audio_file.filename)):
            return jsonify({'error': 'Invalid file format'}), 400
        # BUG FIX: stage the upload under a unique temp name. Building the path
        # from the client-supplied filename let concurrent uploads with the same
        # name clobber each other, and secure_filename() can return an empty
        # string (which made the path the temp directory itself).
        suffix = os.path.splitext(secure_filename(audio_file.filename))[1]
        with tempfile.NamedTemporaryFile(
            dir=TEMPORARY_FOLDER, suffix=suffix, delete=False
        ) as tmp:
            temp_file_path = tmp.name
        audio_file.save(temp_file_path)
        # Processing parameters supplied by the client.
        language = request.form.get('language', None)
        task = request.form.get('task', 'transcribe')  # 'transcribe' or 'translate'
        vad_filter = request.form.get('vad_filter', 'true').lower() == 'true'
        word_timestamps = request.form.get('word_timestamps', 'false').lower() == 'true'
        try:
            start_time = time.time()
            # Process in chunks with VAD for long audio.
            segments, info = wmodel.transcribe(
                temp_file_path,
                beam_size=beamsize,
                language=language,
                task=task,
                vad_filter=vad_filter,
                word_timestamps=word_timestamps,
                chunk_length=30  # Process in 30-second chunks
            )
            # segments is a lazy generator; results accumulate as decoding runs.
            results = []
            for segment in segments:
                # MAX_AUDIO_DURATION doubles here as a wall-clock cap on
                # processing time; decoding aborts once it is exceeded.
                if time.time() - start_time > MAX_AUDIO_DURATION:
                    raise TimeoutError(f"Transcription exceeded maximum allowed duration of {MAX_AUDIO_DURATION} seconds")
                result = {
                    'text': segment.text,
                    'start': segment.start,
                    'end': segment.end
                }
                if word_timestamps and segment.words:
                    result['words'] = [{
                        'word': word.word,
                        'start': word.start,
                        'end': word.end,
                        'probability': word.probability
                    } for word in segment.words]
                results.append(result)
            processing_time = time.time() - start_time
            print(f"Transcription completed in {processing_time:.2f} seconds")
            return jsonify({
                'segments': results,
                'summary': {
                    'processing_time': processing_time,
                    'language': info.language,
                    'language_probability': info.language_probability,
                    # BUG FIX: results holds dicts, so the previous
                    # hasattr(seg, 'end') test was always False and the
                    # reported duration was always 0.
                    'duration': sum(seg['end'] - seg['start'] for seg in results)
                }
            })
        except TimeoutError as te:
            print(f"Transcription timeout: {str(te)}")
            return jsonify({'error': str(te)}), 504
        except Exception as e:
            print(f"Transcription error: {str(e)}")
            return jsonify({'error': 'Transcription failed', 'details': str(e)}), 500
    finally:
        # Always reclaim the slot and remove the staged upload, even on the
        # early 400 returns above.
        if temp_file_path:
            cleanup_temp_files(temp_file_path)
        active_requests -= 1
        request_semaphore.release()
        print(f"Request completed (Active requests: {active_requests})")
if __name__ == "__main__":
    # Ensure the staging folder exists (tempfile.gettempdir() normally does).
    os.makedirs(TEMPORARY_FOLDER, exist_ok=True)
    app.run(host="0.0.0.0", port=7860, threaded=True)