import json

import gradio as gr
import torch
from PIL import Image
from ruamel import yaml

from data.transforms import ALBEFTextTransform, testing_image_transform
from model import albef_model_for_vqa
data_dir = "./"

# Build the ALBEF VQA model from its config and load the fine-tuned checkpoint.
config = yaml.load(open("./configs/vqa.yaml", "r"), Loader=yaml.Loader)
model = albef_model_for_vqa(config)
checkpoint_url = "https://download.pytorch.org/models/multimodal/albef/finetuned_vqa_checkpoint.pt"
checkpoint = torch.hub.load_state_dict_from_url(checkpoint_url, map_location='cpu')
model.load_state_dict(checkpoint)
model.eval()  # inference only: disable dropout and other train-time behavior
# Transforms for the image, the question, and the candidate answers.
image_transform = testing_image_transform()
question_transform = ALBEFTextTransform(add_end_token=False)
answer_transform = ALBEFTextTransform(do_pre_process=False)

# Demo examples for the UI, plus the closed set of candidate answers that
# ALBEF ranks at inference time.
vqa_data = json.load(open(data_dir + "vqa_data.json", "r"))
answer_list = json.load(open(data_dir + "answer_list.json", "r"))
examples = [[data['image'], data['question']] for data in vqa_data]
title = 'VQA with ALBEF'
description = 'VQA with [ALBEF](https://arxiv.org/abs/2107.07651), adapted from the [torchmultimodal example notebook](https://github.com/facebookresearch/multimodal/blob/main/examples/albef/vqa_with_albef.ipynb).'
article = '''```bibtex
@article{li2021align,
  title={Align before fuse: Vision and language representation learning with momentum distillation},
  author={Li, Junnan and Selvaraju, Ramprasaath and Gotmare, Akhilesh and Joty, Shafiq and Xiong, Caiming and Hoi, Steven Chu Hong},
  journal={Advances in neural information processing systems},
  volume={34},
  pages={9694--9705},
  year={2021}
}
```'''
def infer(image, question):
    # Batch of one: stack the transformed image into a (1, C, H, W) tensor.
    image_input = torch.stack([image_transform(image)], dim=0)
    # Tokenize the question and the candidate answers; attention masks are 1
    # for real tokens and 0 for padding.
    question_input = question_transform([question])
    question_atts = (question_input != 0).type(torch.long)
    answer_input = answer_transform(answer_list)
    answer_atts = (answer_input != 0).type(torch.long)
    # In eval mode the model ranks the candidate answers; k=1 returns only
    # the id of the top-ranked answer.
    with torch.no_grad():
        answer_ids, _ = model(
            image_input,
            question_input,
            question_atts,
            answer_input,
            answer_atts,
            k=1,
            is_train=False,
        )
    predicted_answer_id = answer_ids[0]
    predicted_answer = answer_list[predicted_answer_id]
    return predicted_answer
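# A quick smoke test without the UI (sketch): call infer() directly on a
# local image. "example.jpg" is a hypothetical path, not a file in this Space.
#
#   img = Image.open("example.jpg").convert("RGB")
#   print(infer(img, "What is in the picture?"))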
demo = gr.Interface(
    fn=infer,
    inputs=[gr.Image(label='image', type='pil', image_mode='RGB'), gr.Text(label='question')],
    outputs=gr.Text(label='answer'),
    examples=examples,
    title=title,
    description=description,
    article=article,
)
demo.launch()
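# To query the running demo from another process, one option (sketch) is the
# gradio_client package; the Space id below is a hypothetical placeholder:
#
#   from gradio_client import Client, handle_file
#   client = Client("<user>/<space>")
#   answer = client.predict(
#       handle_file("example.jpg"),  # hypothetical local image path
#       "What is in the picture?",
#       api_name="/predict",         # default endpoint name for gr.Interface
#   )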