import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from typing import Dict, List, Tuple
from underthesea import word_tokenize
from huggingface_hub import hf_hub_download
import vlai_template  # must import

# Import transforms separately to avoid compatibility issues
try:
    import torchvision.transforms as transforms
except RuntimeError:
    # Fallback for compatibility issues
    from torchvision import transforms

# Import model components
from src.vivqax_model import ViVQAX_Model


class ViVQAXPredictor:
    def __init__(self, checkpoint_path: str):
        """
        Initialize the ViVQA-X predictor.

        Args:
            checkpoint_path: Path to the trained model checkpoint
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.config = checkpoint['config']

        # Load vocabularies
        self.word2idx = checkpoint['word2idx']
        self.idx2word = checkpoint['idx2word']
        self.answer2idx = checkpoint['answer2idx']
        self.idx2answer = checkpoint['idx2answer']

        # Initialize model
        self.model = ViVQAX_Model(
            vocab_size=len(self.word2idx),
            embed_size=self.config['model']['embed_size'],
            hidden_size=self.config['model']['hidden_size'],
            num_layers=self.config['model']['num_layers'],
            num_answers=len(self.answer2idx),
            max_explanation_length=self.config['model']['max_explanation_length'],
            word2idx=self.word2idx
        ).to(self.device)

        # Load model weights
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Image preprocessing (ImageNet normalization statistics)
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Model loaded successfully on {self.device}")
        print(f"Vocabulary size: {len(self.word2idx)}")
        print(f"Number of answers: {len(self.answer2idx)}")

    # NOTE: the special-token names used below ('<unk>', '<pad>', '<start>', '<end>')
    # are assumed placeholders; they must match the vocabulary stored in the checkpoint.

    def tokenize_question(self, question: str, max_length: int = 20) -> torch.Tensor:
        """Tokenize and encode a question."""
        tokens = word_tokenize(question.lower())
        token_ids = [self.word2idx.get(token, self.word2idx['<unk>']) for token in tokens]

        # Pad or truncate
        if len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        else:
            token_ids += [self.word2idx['<pad>']] * (max_length - len(token_ids))

        return torch.LongTensor(token_ids).unsqueeze(0).to(self.device)

    def decode_explanation(self, explanation_ids: torch.Tensor) -> str:
        """Decode explanation token IDs to text."""
        ids = explanation_ids.squeeze().detach().cpu().tolist()
        if isinstance(ids, int):
            ids = [ids]

        words = []
        for token_id in ids:
            word = self.idx2word.get(token_id, '<unk>')
            if word == '<end>':
                break
            if word not in ['<pad>', '<start>', '<unk>']:
                words.append(word)

        return ' '.join(words)

    def predict(self, image: Image.Image, question: str) -> Tuple[str, str, float]:
        """
        Make prediction for an image-question pair.
        Args:
            image: PIL Image
            question: Question text

        Returns:
            Tuple of (predicted_answer, explanation, confidence)
        """
        try:
            with torch.no_grad():
                # Preprocess image
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = self.transform(image).unsqueeze(0).to(self.device)

                # Preprocess question
                question_tensor = self.tokenize_question(question)

                # Generate prediction
                answer_logits, explanation_ids = self.model.generate_explanation(
                    image_tensor, question_tensor, beam_size=3
                )

                # Decode answer
                answer_probs = F.softmax(answer_logits, dim=1)
                answer_idx = answer_logits.argmax(dim=1).item()
                confidence = answer_probs.max().item()
                predicted_answer = self.idx2answer.get(answer_idx, 'unknown')

                # Decode explanation
                explanation = self.decode_explanation(explanation_ids[0])

                return predicted_answer, explanation, confidence

        except Exception as e:
            return f"Error: {str(e)}", "Could not generate explanation", 0.0


# Initialize the predictor
checkpoint_path = hf_hub_download(
    repo_id="VLAI-AIVN/ViVQA-X_LSTM-Generative",
    filename="best_model.pth"
)
predictor = ViVQAXPredictor(checkpoint_path)


def predict_vqa(image, question):
    """Gradio prediction function."""
    if image is None:
        return "Please upload an image", "No explanation available", "0.00"

    if not question or question.strip() == "":
        return "Please enter a question", "No explanation available", "0.00"

    # Make prediction
    answer, explanation, confidence = predictor.predict(image, question)

    return answer, explanation, f"{confidence:.2f}"


force_light_theme_js = """
() => {
    const params = new URLSearchParams(window.location.search);
    if (!params.has('__theme')) {
        params.set('__theme', 'light');
        window.location.search = params.toString();
    }
}
"""

# ──────────────────────────── Main ─────────────────────────
with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, js=force_light_theme_js) as demo:
    vlai_template.create_header()  # don't change

    gr.Markdown(
        """
### An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset
This demo showcases the **LSTM-Generative** baseline model from our paper, trained on the **ViVQA-X** dataset.
It answers questions about an image in Vietnamese and provides a natural language explanation for its answer.

**How to use:**
1. Upload an image.
2. Enter a question about the image in Vietnamese.
3. Click the "Submit" button to get the answer and explanation.
        """
    )

    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=3):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                height=300
            )
            question_input = gr.Textbox(
                label="Question (in Vietnamese)",
                placeholder="Enter your question about the image...",
                lines=2
            )
            submit_btn = gr.Button("Submit / Gửi 🔧", elem_classes="full-width-btn", variant="primary")

        with gr.Column(scale=4):
            answer_output = gr.Textbox(
                label="Answer",
                interactive=False
            )
            explanation_output = gr.Textbox(
                label="Explanation",
                lines=4,
                interactive=False
            )
            confidence_output = gr.Textbox(
                label="Confidence Score",
                interactive=False
            )

    # Examples
    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Đây là loài chó gì?"],
            ["examples/example2.jpg", "Đây là loài vật gì?"],
            ["examples/example3.jpg", "Đây là loài vật gì?"]
        ],
        inputs=[image_input, question_input],
        label="Example Questions"
    )

    submit_btn.click(
        fn=predict_vqa,
        inputs=[image_input, question_input],
        outputs=[answer_output, explanation_output, confidence_output]
    )

    question_input.submit(
        fn=predict_vqa,
        inputs=[image_input, question_input],
        outputs=[answer_output, explanation_output, confidence_output]
    )

    gr.Markdown(
        """
### ⭐ Star Us on GitHub!
If you find this project useful, please consider giving us a star on GitHub. Your support is greatly appreciated! [duongtruongbinh/ViVQA-X]

### 📜 Citation
To use this dataset or model in your research, please cite our paper:
```bibtex
@misc{vivqax2025,
    author = {Duong, Truong-Binh and Tran, Hoang-Minh and Le-Nguyen, Binh-Nam and Duong, Dinh-Thang},
    title = {An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset},
    howpublished = {Accepted for publication in the Proceedings of The International Conference on Intelligent Systems & Networks (ICISN 2025), Springer Lecture Notes in Networks and Systems (LNNS)},
    year = {2025}
}
```
        """.strip()
    )

    vlai_template.create_footer()  # don't change

if __name__ == "__main__":
    demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static", "examples"])  # don't change