import argparse
import gc
import os
import warnings

import numpy as np
import torch

from whisperx.alignment import align, load_align_model
from whisperx.asr import load_model
from whisperx.audio import load_audio
from whisperx.diarize import DiarizationPipeline, assign_word_speakers
from whisperx.types import AlignedTranscriptionResult, TranscriptionResult
from whisperx.utils import LANGUAGES, TO_LANGUAGE_CODE, get_writer


def transcribe_task(args: dict, parser: argparse.ArgumentParser):
    """Transcription task to be called from the CLI.

    Args:
        args: Dictionary of command-line arguments.
        parser: argparse.ArgumentParser object.
    """
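    # Typical invocation from the CLI entry point (sketch; argument values come
    # from the user's command line):
    #   transcribe_task(vars(parser.parse_args()), parser)

    # Core model, batching, output, and device options.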
    model_name: str = args.pop("model")
    batch_size: int = args.pop("batch_size")
    model_dir: str = args.pop("model_dir")
    model_cache_only: bool = args.pop("model_cache_only")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    device_index: int = args.pop("device_index")
    compute_type: str = args.pop("compute_type")
    verbose: bool = args.pop("verbose")

    os.makedirs(output_dir, exist_ok=True)

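    # Word-alignment options.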
    align_model: str = args.pop("align_model")
    interpolate_method: str = args.pop("interpolate_method")
    no_align: bool = args.pop("no_align")
    task: str = args.pop("task")
    if task == "translate":
        # Translation output is English text, so it cannot be word-aligned
        # against the source-language audio.
        no_align = True

    return_char_alignments: bool = args.pop("return_char_alignments")

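    # Hugging Face token (used when downloading gated VAD/diarization models)
    # and VAD tuning parameters.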
    hf_token: str = args.pop("hf_token")
    vad_method: str = args.pop("vad_method")
    vad_onset: float = args.pop("vad_onset")
    vad_offset: float = args.pop("vad_offset")

    chunk_size: int = args.pop("chunk_size")

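    # Speaker-diarization options.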
    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")
    diarize_model_name: str = args.pop("diarize_model")
    print_progress: bool = args.pop("print_progress")
    return_speaker_embeddings: bool = args.pop("speaker_embeddings")

    if return_speaker_embeddings and not diarize:
        warnings.warn("--speaker_embeddings has no effect without --diarize")

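    # Normalize the language argument: accept full names as well as ISO codes.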
    if args["language"] is not None:
        args["language"] = args["language"].lower()
        if args["language"] not in LANGUAGES:
            if args["language"] in TO_LANGUAGE_CODE:
                args["language"] = TO_LANGUAGE_CODE[args["language"]]
            else:
                raise ValueError(f"Unsupported language: {args['language']}")

    if model_name.endswith(".en") and args["language"] != "en":
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but received '{args['language']}'; using English instead."
            )
        args["language"] = "en"
    # Default to the English alignment model when no language is specified.
    align_language = args["language"] if args["language"] is not None else "en"

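    # Build the temperature fallback schedule used when decoding fails
    # quality checks (compression ratio / log-probability thresholds).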
    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    faster_whisper_threads = 4
    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)
        faster_whisper_threads = threads

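    # Decoding options forwarded to the underlying faster-whisper model.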
    asr_options = {
        "beam_size": args.pop("beam_size"),
        "patience": args.pop("patience"),
        "length_penalty": args.pop("length_penalty"),
        "temperatures": temperature,
        "compression_ratio_threshold": args.pop("compression_ratio_threshold"),
        "log_prob_threshold": args.pop("logprob_threshold"),
        "no_speech_threshold": args.pop("no_speech_threshold"),
        "condition_on_previous_text": False,
        "initial_prompt": args.pop("initial_prompt"),
        "suppress_tokens": [int(x) for x in args.pop("suppress_tokens").split(",")],
        "suppress_numerals": args.pop("suppress_numerals"),
    }

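    # Subtitle-writer options; the word-level flags need alignment to produce
    # word timestamps.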
    writer = get_writer(output_format, output_dir)
    word_options = ["highlight_words", "max_line_count", "max_line_width"]
    if no_align:
        for option in word_options:
            if args[option]:
                parser.error(f"--{option} not possible with --no_align")
    if args["max_line_count"] and not args["max_line_width"]:
        warnings.warn("--max_line_count has no effect without --max_line_width")
    writer_args = {arg: args.pop(arg) for arg in word_options}

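    # Part 1: VAD-batched transcription of each input file.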
    results = []
    tmp_results = []

    model = load_model(
        model_name,
        device=device,
        device_index=device_index,
        download_root=model_dir,
        compute_type=compute_type,
        language=args["language"],
        asr_options=asr_options,
        vad_method=vad_method,
        vad_options={
            "chunk_size": chunk_size,
            "vad_onset": vad_onset,
            "vad_offset": vad_offset,
        },
        task=task,
        local_files_only=model_cache_only,
        threads=faster_whisper_threads,
    )

    for audio_path in args.pop("audio"):
        audio = load_audio(audio_path)

        print(">>Performing transcription...")
        result: TranscriptionResult = model.transcribe(
            audio,
            batch_size=batch_size,
            chunk_size=chunk_size,
            print_progress=print_progress,
            verbose=verbose,
        )
        results.append((result, audio_path))

    # Unload the Whisper model to free (GPU) memory before loading the aligner.
    del model
    gc.collect()
    torch.cuda.empty_cache()

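    # Part 2: forced phoneme alignment (skipped with --no_align or when translating).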
    if not no_align:
        tmp_results = results
        results = []
        align_model, align_metadata = load_align_model(
            align_language, device, model_name=align_model
        )
        for result, audio_path in tmp_results:
            # With multiple inputs, pass the path and let align() reload the
            # audio; with a single input, reuse the waveform from part 1.
            if len(tmp_results) > 1:
                input_audio = audio_path
            else:
                input_audio = audio

            if align_model is not None and len(result["segments"]) > 0:
                if result.get("language", "en") != align_metadata["language"]:
                    # The detected language differs from the loaded aligner's;
                    # swap in an alignment model for the new language.
                    print(
                        f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language..."
                    )
                    align_model, align_metadata = load_align_model(
                        result["language"], device
                    )
                print(">>Performing alignment...")
                result: AlignedTranscriptionResult = align(
                    result["segments"],
                    align_model,
                    align_metadata,
                    input_audio,
                    device,
                    interpolate_method=interpolate_method,
                    return_char_alignments=return_char_alignments,
                    print_progress=print_progress,
                )

            results.append((result, audio_path))

        # Unload the alignment model to free memory before diarization.
        del align_model
        gc.collect()
        torch.cuda.empty_cache()

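    # Part 3: optional speaker diarization.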
    if diarize:
        if hf_token is None:
            print(
                "Warning: no --hf_token provided. Loading the diarization model will fail unless a token is available in the environment."
            )
        tmp_results = results
        print(">>Performing diarization...")
        print(">>Using model:", diarize_model_name)
        results = []
        diarize_model = DiarizationPipeline(
            model_name=diarize_model_name, use_auth_token=hf_token, device=device
        )
        for result, input_audio_path in tmp_results:
            diarize_result = diarize_model(
                input_audio_path,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
                return_embeddings=return_speaker_embeddings,
            )

            # The pipeline returns (segments, embeddings) when embeddings are requested.
            if return_speaker_embeddings:
                diarize_segments, speaker_embeddings = diarize_result
            else:
                diarize_segments = diarize_result
                speaker_embeddings = None

            result = assign_word_speakers(diarize_segments, result, speaker_embeddings)
            results.append((result, input_audio_path))

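    # Finally, write each result in the requested output format.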
    for result, audio_path in results:
        result["language"] = align_language
        writer(result, audio_path, writer_args)