"""Gradio app: transcribe audio/video files with a fine-tuned MMS (Wav2Vec2) model.

The model and processor are loaded once at import time; each request resamples
the input to 16 kHz, runs CTC decoding, and restores punctuation with a
multilingual punctuation model.
"""
import os

import gradio as gr
import torch
import torchaudio
from deepmultilingualpunctuation import PunctuationModel
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor

MODEL = "xTorch8/fine-tuned-mms"
TOKEN = os.getenv("TOKEN")  # Hugging Face access token for the model repo

# Loaded once at module import so every request reuses the same instances.
model = Wav2Vec2ForCTC.from_pretrained(MODEL, token=TOKEN)
processor = Wav2Vec2Processor.from_pretrained(MODEL, token=TOKEN)

# NOTE(review): set_audio_backend is deprecated in recent torchaudio releases;
# kept for compatibility with the version this app was built against.
torchaudio.set_audio_backend("soundfile")

language_model = PunctuationModel()

TARGET_SAMPLE_RATE = 16000  # Wav2Vec2/MMS models expect 16 kHz input


def transcription(audio_stream, is_video=False):
    """Transcribe an uploaded audio (or video) file and restore punctuation.

    Args:
        audio_stream: File path from the Gradio Audio component, or a tuple
            whose first element is the path (some Gradio versions deliver
            tuples).
        is_video: When True, force torchaudio to read the input as WAV
            (presumably the audio track of a video upload — behavior kept
            from the original).

    Returns:
        The punctuated transcription string, or an error message string if
        anything fails.
    """
    try:
        # Unwrap (path, ...) tuples delivered by some Gradio versions.
        if isinstance(audio_stream, tuple):
            audio_stream = audio_stream[0]

        if is_video:
            waveform, sample_rate = torchaudio.load(audio_stream, format="wav")
        else:
            waveform, sample_rate = torchaudio.load(audio_stream)

        if sample_rate != TARGET_SAMPLE_RATE:
            resampler = torchaudio.transforms.Resample(
                orig_freq=sample_rate, new_freq=TARGET_SAMPLE_RATE
            )
            waveform = resampler(waveform)

        # Downmix multi-channel audio to mono. squeeze() alone leaves a 2-D
        # array for stereo input, which would break the processor call.
        if waveform.dim() > 1 and waveform.size(0) > 1:
            waveform = waveform.mean(dim=0, keepdim=True)

        input_values = processor(
            waveform.squeeze().numpy(),
            return_tensors="pt",
            sampling_rate=TARGET_SAMPLE_RATE,
        ).input_values

        with torch.no_grad():
            logits = model(input_values).logits
        predicted_ids = torch.argmax(logits, dim=-1)
        text = processor.batch_decode(predicted_ids)[0]

        # CTC output carries no punctuation; restore it before returning.
        return language_model.restore_punctuation(text)
    except Exception as e:
        # Return a string (not the exception object) so the Textbox output
        # component always receives a valid value.
        return f"Error: {e}"


demo = gr.Interface(
    fn=transcription,
    inputs=[
        gr.Audio(label="Upload Audio/Video", type="filepath"),
        gr.Checkbox(label="Is this a video file?"),
    ],
    outputs=gr.Textbox(label="Transcription Output"),
    title="MMS Audio/Video Transcription",
    allow_flagging="never",
)

if __name__ == "__main__":
    demo.launch()
# Trigger rebuild