import time
import signal
import sys

import torch
import uvicorn
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, BitsAndBytesConfig
from IndicTransToolkit.processor import IndicProcessor
BATCH_SIZE = 4
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# Set to "4-bit" or "8-bit" to enable bitsandbytes quantization; None loads full precision.
quantization = None
def initialize_model_and_tokenizer(ckpt_dir, quantization):
    """Load an IndicTrans2 checkpoint and its tokenizer, optionally quantized."""
    if quantization == "4-bit":
        qconfig = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=torch.bfloat16,
        )
    elif quantization == "8-bit":
        qconfig = BitsAndBytesConfig(
            load_in_8bit=True,
            bnb_8bit_use_double_quant=True,
            bnb_8bit_compute_dtype=torch.bfloat16,
        )
    else:
        qconfig = None

    tokenizer = AutoTokenizer.from_pretrained(ckpt_dir, trust_remote_code=True)
    model = AutoModelForSeq2SeqLM.from_pretrained(
        ckpt_dir,
        trust_remote_code=True,
        low_cpu_mem_usage=True,
        quantization_config=qconfig,
    )

    if qconfig is None:
        model = model.to(DEVICE)
        # Use half precision only on CUDA; fp16 on CPU is slow or unsupported for many ops.
        # (DEVICE is only "cuda" when torch.cuda.is_available(), so no second check is needed.)
        if DEVICE == "cuda":
            model.half()

    model.eval()
    return tokenizer, model
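# Example (sketch): to load a checkpoint 4-bit quantized instead of full precision,
# one could call:
#   initialize_model_and_tokenizer("ai4bharat/indictrans2-en-indic-dist-200M", "4-bit")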
def batch_translate(input_sentences, src_lang, tgt_lang, model, tokenizer, ip):
    translations = []
    for i in range(0, len(input_sentences), BATCH_SIZE):
        batch = input_sentences[i : i + BATCH_SIZE]

        # Preprocess the batch and extract entity mappings
        batch = ip.preprocess_batch(batch, src_lang=src_lang, tgt_lang=tgt_lang)

        # Tokenize the batch and generate input encodings
        inputs = tokenizer(
            batch,
            truncation=True,
            padding="longest",
            return_tensors="pt",
            return_attention_mask=True,
        ).to(DEVICE)

        # Generate translations using the model
        with torch.no_grad():
            generated_tokens = model.generate(
                **inputs,
                use_cache=True,
                min_length=0,
                max_length=256,
                num_beams=4,
                num_return_sequences=1,
            )

        # Decode the generated tokens into text
        generated_tokens = tokenizer.batch_decode(
            generated_tokens,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # Postprocess the translations, including entity replacement
        translations += ip.postprocess_batch(generated_tokens, lang=tgt_lang)

        del inputs
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    return translations
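# Example (sketch): once the models below are loaded, translating two English
# sentences to Hindi would look like:
#   batch_translate(["Hello.", "How are you?"], "eng_Latn", "hin_Deva",
#                   en_indic_model, en_indic_tokenizer, ip)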
# en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-1B"  # larger 1B checkpoint
en_indic_ckpt_dir = "ai4bharat/indictrans2-en-indic-dist-200M"
en_indic_tokenizer, en_indic_model = initialize_model_and_tokenizer(en_indic_ckpt_dir, quantization)

indic_en_ckpt_dir = "ai4bharat/indictrans2-indic-en-dist-200M"
indic_en_tokenizer, indic_en_model = initialize_model_and_tokenizer(indic_en_ckpt_dir, quantization)

ip = IndicProcessor(inference=True)
app = FastAPI()

class Translate(BaseModel):
    input_sentence: str
    source_lang: str
    target_lang: str
lang_list = [
    "eng_Latn",  # English (Latin script)
    "ben_Beng",  # Bengali
    "pan_Guru",  # Punjabi
    "asm_Beng",  # Assamese
    "gom_Deva",  # Konkani
    "guj_Gujr",  # Gujarati
    "hin_Deva",  # Hindi
    "kan_Knda",  # Kannada
    "mal_Mlym",  # Malayalam
    "ory_Orya",  # Odia
    "tam_Taml",  # Tamil
    "tel_Telu",  # Telugu
]
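# The codes above follow the FLORES-200 convention (language_Script) used by IndicTrans2.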
# POST endpoint to translate (route path "/translate" assumed from the handler's purpose;
# the original code was missing the decorator entirely)
@app.post("/translate")
def translate(input: Translate):
    start_time = time.time()

    if input.source_lang not in lang_list or input.target_lang not in lang_list:
        return {
            "message": "Not a valid language code",
            "translation": None,
        }

    # Pick the translation direction: Indic -> English or English -> Indic
    if input.target_lang == "eng_Latn":
        model = indic_en_model
        tokenizer = indic_en_tokenizer
    else:
        model = en_indic_model
        tokenizer = en_indic_tokenizer

    translation = batch_translate(
        [input.input_sentence],  # batch_translate expects a list
        src_lang=input.source_lang,
        tgt_lang=input.target_lang,
        model=model,
        tokenizer=tokenizer,
        ip=ip,
    )

    processing_time = round(time.time() - start_time, 2)
    return {
        "message": f"translation processed successfully in {processing_time} seconds",
        "translation": translation[0],
    }
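# Example request (hypothetical host/port, assuming the server runs locally on 9000):
#   curl -X POST http://localhost:9000/translate \
#     -H "Content-Type: application/json" \
#     -d '{"input_sentence": "How are you?", "source_lang": "eng_Latn", "target_lang": "hin_Deva"}'
# Expected response shape:
#   {"message": "translation processed successfully in ... seconds", "translation": "..."}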
# Health-check endpoint (route path "/health" assumed; the decorator was missing)
@app.get("/health")
def health_check():
    return {
        "status": "healthy",
        "gpu_available": torch.cuda.is_available(),
        "gpu_count": torch.cuda.device_count() if torch.cuda.is_available() else 0,
    }
# Signal handler for graceful shutdown
def handle_sigterm(signum, frame):
    print("Received SIGTERM signal. Cleaning up models and exiting...")
    # Delete model references to free GPU memory
    global en_indic_tokenizer, en_indic_model, indic_en_tokenizer, indic_en_model
    del en_indic_tokenizer, en_indic_model
    del indic_en_tokenizer, indic_en_model
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    sys.exit(0)

# Register the signal handler
signal.signal(signal.SIGTERM, handle_sigterm)
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=9000)
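# Alternatively, run via the uvicorn CLI (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 9000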