import gradio as gr
import torch
import torch.nn.functional as F
from PIL import Image
from typing import Dict, List, Tuple
from underthesea import word_tokenize
from huggingface_hub import hf_hub_download
import vlai_template # must import
# Import transforms separately to avoid compatibility issues
try:
    import torchvision.transforms as transforms
except RuntimeError:
    # Fallback for compatibility issues
    from torchvision import transforms
# Import model components
from src.vivqax_model import ViVQAX_Model
class ViVQAXPredictor:
    def __init__(self, checkpoint_path: str):
        """
        Initialize the ViVQA-X predictor.

        Args:
            checkpoint_path: Path to the trained model checkpoint
        """
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

        # Load checkpoint
        checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)
        self.config = checkpoint['config']

        # Load vocabularies
        self.word2idx = checkpoint['word2idx']
        self.idx2word = checkpoint['idx2word']
        self.answer2idx = checkpoint['answer2idx']
        self.idx2answer = checkpoint['idx2answer']

        # Initialize model
        self.model = ViVQAX_Model(
            vocab_size=len(self.word2idx),
            embed_size=self.config['model']['embed_size'],
            hidden_size=self.config['model']['hidden_size'],
            num_layers=self.config['model']['num_layers'],
            num_answers=len(self.answer2idx),
            max_explanation_length=self.config['model']['max_explanation_length'],
            word2idx=self.word2idx
        ).to(self.device)

        # Load model weights
        self.model.load_state_dict(checkpoint['model_state_dict'])
        self.model.eval()

        # Image preprocessing
        self.transform = transforms.Compose([
            transforms.Resize((224, 224)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
        ])

        print(f"Model loaded successfully on {self.device}")
        print(f"Vocabulary size: {len(self.word2idx)}")
        print(f"Number of answers: {len(self.answer2idx)}")
    def tokenize_question(self, question: str, max_length: int = 20) -> torch.Tensor:
        """Tokenize and encode a question."""
        tokens = word_tokenize(question.lower())
        token_ids = [self.word2idx.get(token, self.word2idx['<UNK>']) for token in tokens]

        # Pad or truncate
        if len(token_ids) > max_length:
            token_ids = token_ids[:max_length]
        else:
            token_ids += [self.word2idx['<PAD>']] * (max_length - len(token_ids))

        return torch.LongTensor(token_ids).unsqueeze(0).to(self.device)

    def decode_explanation(self, explanation_ids: torch.Tensor) -> str:
        """Decode explanation token IDs to text."""
        ids = explanation_ids.squeeze().detach().cpu().tolist()
        if isinstance(ids, int):
            ids = [ids]

        words = []
        for token_id in ids:
            word = self.idx2word.get(token_id, '<UNK>')
            if word == '<END>':
                break
            if word not in ['<PAD>', '<START>', '<UNK>']:
                words.append(word)

        return ' '.join(words)
    def predict(self, image: Image.Image, question: str) -> Tuple[str, str, float]:
        """
        Make a prediction for an image-question pair.

        Args:
            image: PIL Image
            question: Question text

        Returns:
            Tuple of (predicted_answer, explanation, confidence)
        """
        try:
            with torch.no_grad():
                # Preprocess image
                if image.mode != 'RGB':
                    image = image.convert('RGB')
                image_tensor = self.transform(image).unsqueeze(0).to(self.device)

                # Preprocess question
                question_tensor = self.tokenize_question(question)

                # Generate prediction
                answer_logits, explanation_ids = self.model.generate_explanation(
                    image_tensor, question_tensor, beam_size=3
                )

                # Decode answer
                answer_probs = F.softmax(answer_logits, dim=1)
                answer_idx = answer_logits.argmax(dim=1).item()
                confidence = answer_probs.max().item()
                predicted_answer = self.idx2answer.get(answer_idx, 'unknown')

                # Decode explanation
                explanation = self.decode_explanation(explanation_ids[0])

                return predicted_answer, explanation, confidence
        except Exception as e:
            return f"Error: {str(e)}", "Could not generate explanation", 0.0
# Initialize the predictor
checkpoint_path = hf_hub_download(
repo_id="VLAI-AIVN/ViVQA-X_LSTM-Generative",
filename="best_model.pth"
)
predictor = ViVQAXPredictor(checkpoint_path)
def predict_vqa(image, question):
    """Gradio prediction function."""
    if image is None:
        return "Please upload an image", "No explanation available", "0.00"
    if not question or question.strip() == "":
        return "Please enter a question", "No explanation available", "0.00"

    # Make prediction
    answer, explanation, confidence = predictor.predict(image, question)
    return answer, explanation, f"{confidence:.2f}"
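
# Note: with a missing input the guards above short-circuit before the model runs, e.g.
#   predict_vqa(None, "Đây là loài chó gì?")
#   -> ("Please upload an image", "No explanation available", "0.00")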
force_light_theme_js = """
() => {
const params = new URLSearchParams(window.location.search);
if (!params.has('__theme')) {
params.set('__theme', 'light');
window.location.search = params.toString();
}
}
"""
# ─────────────────────────── Main ───────────────────────────
with gr.Blocks(theme='gstaff/sketch', css=vlai_template.custom_css, js=force_light_theme_js) as demo:
    vlai_template.create_header()  # don't change

    gr.Markdown(
        """
### An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset

This demo showcases the **LSTM-Generative** baseline model from our paper, trained on the **ViVQA-X** dataset. It answers questions about an image in Vietnamese and provides a natural language explanation for its answer.

**How to use:**
1. Upload an image.
2. Enter a question about the image in Vietnamese.
3. Click the "Submit" button to get the answer and explanation.
        """
    )
    with gr.Row(equal_height=True, variant="panel"):
        with gr.Column(scale=3):
            image_input = gr.Image(
                type="pil",
                label="Upload Image",
                height=300
            )
            question_input = gr.Textbox(
                label="Question (in Vietnamese)",
                placeholder="Enter your question about the image...",
                lines=2
            )
            submit_btn = gr.Button("Submit / Gửi", elem_classes="full-width-btn", variant="primary")
        with gr.Column(scale=4):
            answer_output = gr.Textbox(
                label="Answer",
                interactive=False
            )
            explanation_output = gr.Textbox(
                label="Explanation",
                lines=4,
                interactive=False
            )
            confidence_output = gr.Textbox(
                label="Confidence Score",
                interactive=False
            )
    # Examples (the Vietnamese questions mean "What kind of dog is this?" / "What kind of animal is this?")
    gr.Examples(
        examples=[
            ["examples/example1.jpg", "Đây là loài chó gì?"],
            ["examples/example2.jpg", "Đây là loài vật gì?"],
            ["examples/example3.jpg", "Đây là loài vật gì?"]
        ],
        inputs=[image_input, question_input],
        label="Example Questions"
    )

    submit_btn.click(
        fn=predict_vqa,
        inputs=[image_input, question_input],
        outputs=[answer_output, explanation_output, confidence_output]
    )
    question_input.submit(
        fn=predict_vqa,
        inputs=[image_input, question_input],
        outputs=[answer_output, explanation_output, confidence_output]
    )
    gr.Markdown(
        """
### ⭐ Star Us on GitHub!
If you find this project useful, please consider giving us a star on GitHub. Your support is greatly appreciated!

<a href="https://github.com/duongtruongbinh/ViVQA-X" target="_blank">[duongtruongbinh/ViVQA-X]</a>

### Citation
To use this dataset or model in your research, please cite our paper:
```bibtex
@misc{vivqax2025,
  author = {Duong, Truong-Binh and Tran, Hoang-Minh and Le-Nguyen, Binh-Nam and Duong, Dinh-Thang},
  title = {An Automated Pipeline for Constructing a Vietnamese VQA-NLE Dataset},
  howpublished = {Accepted for publication in the Proceedings of The International Conference on Intelligent Systems \& Networks (ICISN 2025), Springer Lecture Notes in Networks and Systems (LNNS)},
  year = {2025}
}
```
        """.strip()
    )

    vlai_template.create_footer()  # don't change
if __name__ == "__main__":
    demo.launch(allowed_paths=["static/aivn_logo.png", "static/vlai_logo.png", "static", "examples"])  # don't change