import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from IndicTransToolkit.processor import IndicProcessor

# Translation direction: Kashmiri (Perso-Arabic script) -> English.
src_lang, tgt_lang = "kas_Arab", "eng_Latn"
model_name = "ai4bharat/indictrans2-indic-en-dist-200M"

tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)

# FIX: the original loaded float16 weights but never placed the model on a GPU.
# Half precision on CPU is unsupported for several ops (e.g. addmm during beam
# search) and slow even when it works — use fp16 only when CUDA is available,
# fp32 otherwise, and keep model and inputs on the same device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
model = AutoModelForSeq2SeqLM.from_pretrained(
    model_name,
    trust_remote_code=True,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
).to(DEVICE)
model.eval()  # FIX: disable dropout etc. for deterministic inference

ip = IndicProcessor(inference=True)


def tranlate_ks(text: str) -> str:
    """Translate one Kashmiri (Perso-Arabic) sentence to English.

    Args:
        text: Source sentence in Kashmiri, Perso-Arabic script.

    Returns:
        The English translation of ``text`` (first beam, first sequence).
    """
    # Normalize/tag the input the way IndicTrans2 expects (batch of one).
    batch = ip.preprocess_batch([text], src_lang=src_lang, tgt_lang=tgt_lang)
    inputs = tokenizer(
        batch,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    ).to(DEVICE)  # FIX: inputs must live on the same device as the model
    with torch.no_grad():
        generated_tokens = model.generate(
            **inputs,
            use_cache=True,
            min_length=0,
            max_length=256,
            num_beams=5,
            num_return_sequences=1,
        )
    decoded = tokenizer.batch_decode(
        generated_tokens,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    # Undo IndicTrans2's preprocessing (detokenization, entity restoration).
    translations = ip.postprocess_batch(decoded, lang=tgt_lang)
    print('[TRANSLATION] Done . . .')
    return translations[0]


# Backward-compatible, correctly spelled alias for the misspelled public name.
translate_ks = tranlate_ks