#!/usr/bin/env python3
import gradio as gr
import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import logging

# Constants and configuration
SAMPLE_RATE = 16000
CHUNK_SECONDS = 30  # Split audio into 30-second chunks
CHUNK_SAMPLES = SAMPLE_RATE * CHUNK_SECONDS
MODEL_NAME = "openpecha/general_stt_base_model"

title = "# Tibetan Speech-to-Text"
description = """
This application transcribes Tibetan audio files using:
- Wav2Vec2 model fine-tuned on Garchen Rinpoche's teachings
- 30-second fixed chunking for long audio processing
"""


def init_model():
    # Load the Wav2Vec2 model and its processor, and put the model in eval mode
    model = Wav2Vec2ForCTC.from_pretrained(MODEL_NAME)
    processor = Wav2Vec2Processor.from_pretrained(MODEL_NAME)
    model.eval()
    return model, processor


# Initialize the model once at import time so every request reuses it
model, processor = init_model()


def process_audio(audio_path: str):
    if not audio_path:
        return "Please upload an audio file first"

    logging.info(f"Processing audio file: {audio_path}")
    try:
        # Load the audio, resample to 16 kHz, and downmix to mono
        wav, sr = torchaudio.load(audio_path)
        if sr != SAMPLE_RATE:
            wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
        wav = wav.mean(dim=0)  # convert to mono

        # Split the audio into fixed 30-second chunks. Note that fixed
        # boundaries may cut through a word; chunks are transcribed
        # independently and joined afterwards.
        audio_length = wav.shape[0]
        transcriptions = []
        for start_sample in range(0, audio_length, CHUNK_SAMPLES):
            end_sample = min(start_sample + CHUNK_SAMPLES, audio_length)
            chunk = wav[start_sample:end_sample]

            # Skip chunks shorter than 0.5 seconds; they rarely contain speech
            if chunk.shape[0] < 0.5 * SAMPLE_RATE:
                continue

            # Run the chunk through the model (the processor expects a
            # numpy array or list, so convert the tensor first)
            inputs = processor(
                chunk.numpy(),
                sampling_rate=SAMPLE_RATE,
                return_tensors="pt",
                padding=True,
            )
            with torch.no_grad():
                logits = model(**inputs).logits

            # Greedy CTC decoding: take the most likely token at each frame
            predicted_ids = torch.argmax(logits, dim=-1)
            transcription = processor.decode(predicted_ids[0])

            # Skip empty transcriptions
            if transcription.strip():
                transcriptions.append(transcription)

        if not transcriptions:
            return "No speech detected or recognized"

        # Join the per-chunk transcriptions into one text
        return " ".join(transcriptions)

    except Exception as e:
        logging.error(f"Error processing audio: {e}")
        return f"Error processing audio: {e}"


demo = gr.Blocks()
with demo:
    gr.Markdown(title)
    with gr.Row():
        audio_input = gr.Audio(
            sources=["upload"],
            type="filepath",
            label="Upload audio file",
        )
    process_button = gr.Button("Transcribe Audio")
    with gr.Row():
        text_output = gr.Textbox(
            label="Transcription",
            placeholder="Transcribed text will appear here...",
            lines=8,
        )
    process_button.click(
        process_audio,
        inputs=[audio_input],
        outputs=[text_output],
    )
    gr.Markdown(description)

if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)
    demo.launch(share=True)
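# A minimal sketch of how process_audio could be exercised headlessly,
# without launching the Gradio UI (useful for smoke-testing the model).
# This assumes the script is saved as app.py and that "sample.wav" is a
# hypothetical local audio file, not something shipped with this app:
#
#     from app import process_audio
#     print(process_audio("sample.wav"))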