"""Gradio demo: IndicBART multilingual translation / completion on CPU.

Loads ai4bharat/IndicBART once at import time and exposes a small Blocks UI
for translation, text completion, and generation across 11 Indic languages
plus English.
"""

import gradio as gr
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# Model configuration
model_name = "ai4bharat/IndicBART"

# Load tokenizer and model on CPU once at import time so every request
# reuses them instead of re-loading per call.
print("Loading IndicBART tokenizer...")
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    do_lower_case=False,
    use_fast=False,       # IndicBART ships a slow (SentencePiece) tokenizer
    keep_accents=True,    # accents are significant in Indic scripts
)

print("Loading IndicBART model on CPU...")
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    torch_dtype=torch.float32,  # float32 for better CPU performance
    device_map="cpu",
)

# Display name -> IndicBART language-tag token. The tag is appended to the
# encoder input and also used as the decoder start token for the target side.
LANGUAGE_CODES = {
    "Assamese": "<2as>",
    "Bengali": "<2bn>",
    "English": "<2en>",
    "Gujarati": "<2gu>",
    "Hindi": "<2hi>",
    "Kannada": "<2kn>",
    "Malayalam": "<2ml>",
    "Marathi": "<2mr>",
    "Oriya": "<2or>",
    "Punjabi": "<2pa>",
    "Tamil": "<2ta>",
    "Telugu": "<2te>",
}


def _prepare_prompt(input_text, src_code, tgt_code, task_type):
    """Return ``(formatted_input, decoder_start_token)`` for the given task.

    NOTE(review): the IndicBART model card formats inputs as
    ``"text </s> <2xx>"``; this code omits the explicit ``</s>`` and relies on
    the tokenizer's own special tokens — confirm against the model card.
    """
    if task_type == "Translation":
        return f"{input_text} {src_code}", tgt_code
    if task_type == "Text Completion":
        # For completion, tag the input with the *target* language so the
        # model continues in that language.
        return f"{input_text} {tgt_code}", tgt_code
    # Text Generation
    return f"{input_text} {src_code}", tgt_code


def _token_to_id(token):
    """Resolve a language-tag token to its vocabulary id.

    Older tokenizer versions expose ``_convert_token_to_id_with_added_voc``;
    fall back to the public API when that private method doesn't exist.
    """
    try:
        return tokenizer._convert_token_to_id_with_added_voc(token)
    except AttributeError:
        # Fallback if the method doesn't exist on this tokenizer version.
        return tokenizer.convert_tokens_to_ids(token)


def generate_response(input_text, source_lang, target_lang, task_type, max_length):
    """Generate response using IndicBART on CPU.

    Args:
        input_text: Raw user text in the source language.
        source_lang / target_lang: Display names; keys of ``LANGUAGE_CODES``.
        task_type: "Translation", "Text Completion", or "Text Generation".
        max_length: Generation length cap passed to ``model.generate``.

    Returns:
        The decoded model output with special tokens stripped.
    """
    src_code = LANGUAGE_CODES[source_lang]
    tgt_code = LANGUAGE_CODES[target_lang]

    formatted_input, decoder_start_token = _prepare_prompt(
        input_text, src_code, tgt_code, task_type
    )

    # Tokenize input (tensors stay on CPU).
    inputs = tokenizer(
        formatted_input,
        return_tensors="pt",
        padding=True,
        truncation=True,
        max_length=512,
    )

    decoder_start_token_id = _token_to_id(decoder_start_token)

    # Generate on CPU; deterministic settings keep latency predictable.
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            decoder_start_token_id=decoder_start_token_id,
            max_length=max_length,
            num_beams=2,  # Reduced for faster CPU inference
            early_stopping=True,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
            use_cache=True,
            do_sample=False,  # Deterministic for CPU
        )

    # Decode output
    generated_text = tokenizer.decode(
        outputs[0], skip_special_tokens=True, clean_up_tokenization_spaces=False
    )
    return generated_text


# --- Gradio interface -------------------------------------------------------
with gr.Blocks(title="IndicBART CPU Multilingual Assistant", theme=gr.themes.Soft()) as demo:
    gr.Markdown("""
    # 🇮🇳 IndicBART Multilingual Assistant (CPU Version)

    Experience IndicBART - trained on **11 Indian languages**! Perfect for translation, text completion, and multilingual generation.

    **Supported Languages**: Assamese, Bengali, Gujarati, Hindi, Kannada, Malayalam, Marathi, Oriya, Punjabi, Tamil, Telugu, English

    *Note: Running on CPU - responses may take longer than GPU version.*
    """)

    with gr.Row():
        with gr.Column(scale=3):
            input_text = gr.Textbox(
                label="Input Text",
                placeholder="Enter text in any supported language...",
                lines=3,
            )
            output_text = gr.Textbox(
                label="Generated Output",
                lines=5,
                interactive=False,
            )
            with gr.Row():
                generate_btn = gr.Button("Generate", variant="primary", size="lg")
                clear_btn = gr.Button("Clear", variant="secondary")

        with gr.Column(scale=1):
            task_type = gr.Dropdown(
                choices=["Translation", "Text Completion", "Text Generation"],
                value="Translation",
                label="Task Type",
            )
            source_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="English",
                label="Source Language",
            )
            target_lang = gr.Dropdown(
                choices=list(LANGUAGE_CODES.keys()),
                value="Hindi",
                label="Target Language",
            )
            max_length = gr.Slider(
                minimum=20,
                maximum=200,  # Reduced for faster CPU processing
                value=80,
                step=10,
                label="Max Length",
            )

    # Examples (non-ASCII rows restored from cp866 mojibake in the original).
    gr.Markdown("### 💡 Try these examples:")
    examples = [
        ["Hello, how are you?", "English", "Hindi", "Translation", 80],
        ["मैं एक छात्र हूं", "Hindi", "English", "Translation", 80],
        ["আমি ভাত খাই", "Bengali", "English", "Translation", 80],
        ["भारत एक", "Hindi", "Hindi", "Text Completion", 100],
        ["The capital of India", "English", "English", "Text Completion", 80],
    ]
    gr.Examples(
        examples=examples,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
        fn=generate_response,
    )

    # Event handlers
    def clear_fields():
        """Reset both text boxes to empty strings."""
        return "", ""

    # Connect buttons
    generate_btn.click(
        generate_response,
        inputs=[input_text, source_lang, target_lang, task_type, max_length],
        outputs=output_text,
    )
    clear_btn.click(
        clear_fields,
        outputs=[input_text, output_text],
    )


if __name__ == "__main__":
    demo.launch(
        share=True,
        ssr_mode=False,  # Disable SSR mode to fix the 500 error
        server_name="0.0.0.0",
        server_port=7860,
        show_error=True,
    )