import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from typing import Tuple
from underthesea import word_tokenize
from huggingface_hub import hf_hub_download
import vlai_template

# Work around environments where this import form raises a RuntimeError at
# import time; fall back to the equivalent `from torchvision import transforms`.
try:
    import torchvision.transforms as transforms
except RuntimeError:
    from torchvision import transforms

from src.vivqax_model import ViVQAX_Model


class ViVQAXPredictor:
    def __init__(self, checkpoint_path: str):
        """
        Initialize the ViVQA-X predictor.

        Args:
            checkpoint_path: Path to the trained model checkpoint
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # The checkpoint bundles the model weights together with its config
        # and the vocabulary/answer mappings.
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.config = checkpoint['config']

        self.word2idx = checkpoint['word2idx']
        self.idx2word = checkpoint['idx2word']
        self.answer2idx = checkpoint['answer2idx']
        self.idx2answer = checkpoint['idx2answer']

        self.model = ViVQAX_Model(
            vocab_size=len(self.word2idx),
            embed_size=self.config['model']['embed_size'],
            hidden_size=self.config['model']['hidden_size'],
            num_layers=self.config['model']['num_layers'],
            num_answers=len(self.answer2idx),
            max_explanation_length=self.config['model']['max_explanation_length'],
            word2idx=self.word2idx
        ).to(self.device)

        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Image preprocessing: resize to 224x224 and normalize with the
        # standard ImageNet channel statistics.
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Model loaded successfully on {self.device}")
        print(f"Vocabulary size: {len(self.word2idx)}")
        print(f"Number of answers: {len(self.answer2idx)}")
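
    # Example behavior: a question such as "Đây là loài chó gì?" is word-segmented
    # by underthesea, mapped to vocabulary ids (<UNK> for out-of-vocabulary tokens),
    # padded or truncated to max_length, and returned as a (1, max_length)
    # LongTensor on self.device.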
    def tokenize_question(self, question: str, max_length: int = 20) -> torch.Tensor:
        """Tokenize and encode a question."""
        tokens = word_tokenize(question.lower())
        token_ids = [self.word2idx.get(token, self.word2idx['<UNK>']) for token in tokens]

        if len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        else:
            token_ids += [self.word2idx['<PAD>']] * (max_length - len(token_ids))

        return torch.LongTensor(token_ids).unsqueeze(0).to(self.device)
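
    # Illustration (hypothetical token ids): a generated sequence corresponding to
    # [<START>, 'một', 'con', 'chó', <END>, <PAD>, ...] decodes to "một con chó";
    # decoding stops at <END> and the special tokens are dropped.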
    def decode_explanation(self, explanation_ids: torch.Tensor) -> str:
        """Decode explanation token IDs to text."""
        ids = explanation_ids.squeeze().detach().cpu().tolist()
        if isinstance(ids, int):
            ids = [ids]
        words = []
        for token_id in ids:
            word = self.idx2word.get(token_id, '<UNK>')
            if word == '<END>':
                break
            if word not in ['<PAD>', '<START>', '<UNK>']:
                words.append(word)
        return ' '.join(words)
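
    # Inference flow: convert the image to RGB and apply the ImageNet transform,
    # encode the question, call self.model.generate_explanation(..., beam_size=3)
    # to obtain answer logits and explanation token ids, then report the softmax
    # probability of the top answer as the confidence score.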
    def predict(self, image: Image.Image, question: str) -> Tuple[str, str, float]:
        """
        Make prediction for an image-question pair.

        Args:
            image: PIL Image
            question: Question text

        Returns:
            Tuple of (predicted_answer, explanation, confidence)
        """
        try:
            with torch.no_grad():
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = self.transform(image).unsqueeze(0).to(self.device)

                question_tensor = self.tokenize_question(question)

                answer_logits, explanation_ids = self.model.generate_explanation(
                    image_tensor, question_tensor, beam_size=3
                )

                answer_probs = F.softmax(answer_logits, dim=1)
                answer_idx = answer_logits.argmax(dim=1).item()
                confidence = answer_probs.max().item()

                predicted_answer = self.idx2answer.get(answer_idx, 'unknown')

                explanation = self.decode_explanation(explanation_ids[0])

                return predicted_answer, explanation, confidence

        except Exception as e:
            return f"Error: {str(e)}", "Could not generate explanation", 0.0
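

# Fetch the released checkpoint from the Hugging Face Hub at startup;
# hf_hub_download caches the file locally, so later runs reuse the cached copy.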
checkpoint_path = hf_hub_download(
    repo_id="VLAI-AIVN/ViVQA-X_LSTM-Generative",
    filename="best_model.pth"
)
predictor = ViVQAXPredictor(checkpoint_path)
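
# Minimal sketch of direct (non-UI) usage, assuming one of the bundled example
# images is available locally:
#
#   img = Image.open("examples/example1.jpg")
#   answer, explanation, confidence = predictor.predict(img, "Đây là loài chó gì?")
#   print(answer, explanation, f"{confidence:.2f}")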


def predict_vqa(image, question):
    """Gradio prediction function."""
    if image is None:
        return "Please upload an image", "No explanation available", "0.00"

    if not question or question.strip() == "":
        return "Please enter a question", "No explanation available", "0.00"

    answer, explanation, confidence = predictor.predict(image, question)

    return answer, explanation, f"{confidence:.2f}"
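

# JS snippet run on page load: if the URL has no `__theme` query parameter,
# reload the page with `__theme=light` so the demo always renders in
# Gradio's light theme.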
force_light_theme_js = """
() => {
    const params = new URLSearchParams(window.location.search);
    if (!params.has('__theme')) {
        params.set('__theme', 'light');
        window.location.search = params.toString();
    }
}
"""


with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, js=force_light_theme_js) as demo:
    vlai_template.create_header()

    gr.Markdown(
        """
### An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset
This demo showcases the **LSTM-Generative** baseline model from our paper, trained on the **ViVQA-X** dataset. It answers questions about an image in Vietnamese and provides a natural language explanation for its answer.

**How to use:**
1. Upload an image.
2. Enter a question about the image in Vietnamese.
3. Click the "Submit" button to get the answer and explanation.
        """
    )

    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=3):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                height=300
            )
            question_input = gr.Textbox(
                label="Question (in Vietnamese)",
                placeholder="Enter your question about the image...",
                lines=2
            )
            submit_btn = gr.Button("Submit / Gửi", elem_classes="full-width-btn", variant="primary")

        with gr.Column(scale=4):
            answer_output = gr.Textbox(
                label="Answer",
                interactive=False
            )
            explanation_output = gr.Textbox(
                label="Explanation",
                lines=4,
                interactive=False
            )
            confidence_output = gr.Textbox(
                label="Confidence Score",
                interactive=False
            )

    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Đây là loài chó gì?"],
            ["examples/example2.jpg", "Đây là loài vật gì?"],
            ["examples/example3.jpg", "Đây là loài vật gì?"]
        ],
        inputs=[image_input, question_input],
        label="Example Questions"
    )

    submit_btn.click(
        fn=predict_vqa,
        inputs=[image_input, question_input],
        outputs=[answer_output, explanation_output, confidence_output]
    )

    question_input.submit(
        fn=predict_vqa,
        inputs=[image_input, question_input],
        outputs=[answer_output, explanation_output, confidence_output]
    )
    gr.Markdown(
        """
### ⭐ Star Us on GitHub!
If you find this project useful, please consider giving us a star on GitHub. Your support is greatly appreciated!
<a href="https://github.com/duongtruongbinh/ViVQA-X" target="_blank">[duongtruongbinh/ViVQA-X]</a>

### Citation
If you use this dataset or model in your research, please cite our paper:
```bibtex
@misc{vivqax2025,
  author = {Duong, Truong-Binh and Tran, Hoang-Minh and Le-Nguyen, Binh-Nam and Duong, Dinh-Thang},
  title = {An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset},
  howpublished = {Accepted for publication in the Proceedings of The International Conference on Intelligent Systems & Networks (ICISN 2025), Springer Lecture Notes in Networks and Systems (LNNS)},
  year = {2025}
}
```
        """.strip()
    )
    vlai_template.create_footer()


if __name__ == "__main__":
    demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static", "examples"])