import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from typing import Dict, List, Tuple
from underthesea import word_tokenize
from huggingface_hub import hf_hub_download
import vlai_template # must import
# Import transforms separately to avoid compatibility issues
try:
import torchvision.transforms as transforms
except RuntimeError:
# Fallback for compatibility issues
from torchvision import transforms
# Import model components
from src.vivqax_model import ViVQAX_Model
class ViVQAXPredictor:
def __init__(self, checkpoint_path: str):
"""
Initialize the ViVQA-X predictor.
Args:
checkpoint_path: Path to the trained model checkpoint
"""
self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Load checkpoint
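        # weights_only=False is needed here: the checkpoint stores the config dict
        # and vocabularies as pickled Python objects, not just tensors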
checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
self.config = checkpoint['config']
# Load vocabularies
self.word2idx = checkpoint['word2idx']
self.idx2word = checkpoint['idx2word']
self.answer2idx = checkpoint['answer2idx']
self.idx2answer = checkpoint['idx2answer']
# Initialize model
self.model = ViVQAX_Model(
vocab_size=len(self.word2idx),
embed_size=self.config['model']['embed_size'],
hidden_size=self.config['model']['hidden_size'],
num_layers=self.config['model']['num_layers'],
num_answers=len(self.answer2idx),
max_explanation_length=self.config['model']['max_explanation_length'],
word2idx=self.word2idx
).to(self.device)
# Load model weights
self.model.load_state_dict(checkpoint['model_state_dict'])
self.model.eval()
# Image preprocessing
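        # Resize to 224x224 and normalize with the standard ImageNet mean/std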
self.transform = transforms.Compose([
transforms.Resize((224, 224)),
transforms.ToTensor(),
transforms.Normalize(mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])
print(f"Model loaded successfully on {self.device}")
print(f"Vocabulary size: {len(self.word2idx)}")
print(f"Number of answers: {len(self.answer2idx)}")
def tokenize_question(self, question: str, max_length: int = 20) -> torch.Tensor:
"""Tokenize and encode a question."""
tokens = word_tokenize(question.lower())
token_ids = [self.word2idx.get(token, self.word2idx['<UNK>']) for token in tokens]
# Pad or truncate
if len(token_ids) > max_length:
token_ids = token_ids[:max_length]
else:
token_ids += [self.word2idx['<PAD>']] * (max_length - len(token_ids))
return torch.LongTensor(token_ids).unsqueeze(0).to(self.device)
def decode_explanation(self, explanation_ids: torch.Tensor) -> str:
"""Decode explanation token IDs to text."""
ids = explanation_ids.squeeze().detach().cpu().tolist()
if isinstance(ids, int):
ids = [ids]
words = []
for token_id in ids:
word = self.idx2word.get(token_id, '<UNK>')
if word == '<END>':
break
if word not in ['<PAD>', '<START>', '<UNK>']:
words.append(word)
return ' '.join(words)
def predict(self, image: Image.Image, question: str) -> Tuple[str, str, float]:
"""
Make prediction for an image-question pair.
Args:
image: PIL Image
question: Question text
Returns:
Tuple of (predicted_answer, explanation, confidence)
"""
try:
with torch.no_grad():
# Preprocess image
if image.mode != 'RGB':
image = image.convert('RGB')
image_tensor = self.transform(image).unsqueeze(0).to(self.device)
# Preprocess question
question_tensor = self.tokenize_question(question)
# Generate prediction
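                # generate_explanation returns the answer logits and the
                # beam-search-decoded explanation token IDs (beam_size=3)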
answer_logits, explanation_ids = self.model.generate_explanation(
image_tensor, question_tensor, beam_size=3
)
# Decode answer
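                # Softmax converts logits to probabilities; confidence is the
                # probability of the top-scoring answer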
answer_probs = F.softmax(answer_logits, dim=1)
answer_idx = answer_logits.argmax(dim=1).item()
confidence = answer_probs.max().item()
predicted_answer = self.idx2answer.get(answer_idx, 'unknown')
# Decode explanation
explanation = self.decode_explanation(explanation_ids[0])
return predicted_answer, explanation, confidence
except Exception as e:
return f"Error: {str(e)}", "Could not generate explanation", 0.0
# Initialize the predictor
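# Download the trained checkpoint from the Hugging Face Hub (cached locally after the first run)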
checkpoint_path = hf_hub_download(
repo_id="VLAI-AIVN/ViVQA-X_LSTM-Generative",
filename="best_model.pth"
)
predictor = ViVQAXPredictor(checkpoint_path)
def predict_vqa(image, question):
"""Gradio prediction function."""
if image is None:
return "Please upload an image", "No explanation available", "0.00"
if not question or question.strip() == "":
return "Please enter a question", "No explanation available", "0.00"
# Make prediction
answer, explanation, confidence = predictor.predict(image, question)
return answer, explanation, f"{confidence:.2f}"
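# JavaScript run on page load: force Gradio's light theme by setting the __theme query parameter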
force_light_theme_js = """
() => {
const params = new URLSearchParams(window.location.search);
if (!params.has('__theme')) {
params.set('__theme', 'light');
window.location.search = params.toString();
}
}
"""
# ──────────────────────────── Main ─────────────────────────
with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, js=force_light_theme_js) as demo:
vlai_template.create_header() # don't change
gr.Markdown(
"""
### An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset
This demo showcases the **LSTM-Generative** baseline model from our paper, trained on the **ViVQA-X** dataset. It answers Vietnamese questions about an image and provides a natural language explanation for its answer.
**How to use:**
1. Upload an image.
2. Enter a question about the image in Vietnamese.
3. Click the "Submit" button to get the answer and explanation.
"""
)
with gr.Row(equal_height=True, variant="panel"):
with gr.Column(scale=3):
image_input = gr.Image(
type="pil",
label="Upload Image",
height=300
)
question_input = gr.Textbox(
label="Question (in Vietnamese)",
placeholder="Enter your question about the image...",
lines=2
)
submit_btn = gr.Button("Submit / Gα»­i πŸ”§", elem_classes="full-width-btn", variant="primary")
with gr.Column(scale=4):
answer_output = gr.Textbox(
label="Answer",
interactive=False
)
explanation_output = gr.Textbox(
label="Explanation",
lines=4,
interactive=False
)
confidence_output = gr.Textbox(
label="Confidence Score",
interactive=False
)
# Examples
gr.Examples(
examples=[
["examples/example1.jpg", "ĐÒy là loài chó gì?"],
["examples/example2.jpg", "ĐÒy là loài vật gì?"],
["examples/example3.jpg", "ĐÒy là loài vật gì?"]
],
inputs=[image_input, question_input],
label="Example Questions"
)
submit_btn.click(
fn=predict_vqa,
inputs=[image_input, question_input],
outputs=[answer_output, explanation_output, confidence_output]
)
question_input.submit(
fn=predict_vqa,
inputs=[image_input, question_input],
outputs=[answer_output, explanation_output, confidence_output]
)
gr.Markdown(
"""
### ⭐ Star Us on GitHub!
If you find this project useful, please consider giving us a star on GitHub. Your support is greatly appreciated!
<a href="https://github.com/duongtruongbinh/ViVQA-X" target="_blank">[duongtruongbinh/ViVQA-X]</a>
### πŸ“œ Citation
If you use this dataset or model in your research, please cite our paper:
```bibtex
@misc{vivqax2025,
author = {Duong, Truong-Binh and Tran, Hoang-Minh and Le-Nguyen, Binh-Nam and Duong, Dinh-Thang},
title = {An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset},
howpublished = {Accepted for publication in the Proceedings of The International Conference on Intelligent Systems & Networks (ICISN 2025), Springer Lecture Notes in Networks and Systems (LNNS)},
year = {2025}
}
```
""".strip()
)
vlai_template.create_footer() # don't change
if __name__ == "__main__":
demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static", "examples"]) # don't change