from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy


# Load pretrained models: CLIP for image features, T5 for caption generation,
# and spaCy for named-entity recognition. The sentence-transformer text encoder
# is loaded here but is not used in the inference path below.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer("all-MiniLM-L6-v2")
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
nlp = spacy.load("en_core_web_sm")


# Load the prebuilt FAISS index and the caption store aligned with it:
# captions[i] must correspond to vector i in the index.
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)

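# Hypothetical sketch of how these artifacts could be produced offline; the actual
# build step is not shown in this file. It assumes the index holds CLIP image
# embeddings of reference photos, one caption per photo. Because the embeddings are
# L2-normalized, an inner-product index is equivalent to cosine similarity.
# build_index, reference_image_paths, and reference_captions are illustrative names,
# and extract_image_features is the function defined below.
def build_index(reference_image_paths, reference_captions,
                index_path="./faiss_index.idx", captions_path="./captions.json"):
    """Illustrative offline indexing step; assumes one caption per reference image."""
    vectors = np.stack([extract_image_features(p) for p in reference_image_paths])
    index = faiss.IndexFlatIP(vectors.shape[1])  # inner product over unit vectors
    index.add(vectors)
    faiss.write_index(index, index_path)
    with open(captions_path, "w", encoding="utf-8") as out:
        json.dump(reference_captions, out, ensure_ascii=False)

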
def extract_image_features(image):
    """
    Extract image features using the CLIP model.
    Input: PIL Image or image path (str).
    Output: L2-normalized image embedding (float32 numpy array), or None on failure.
    """
    try:
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None


def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve the k most similar captions via the FAISS index.
    Input: Image embedding (numpy array).
    Output: List of captions.
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    D, I = faiss_index.search(image_embedding, k)
    # FAISS pads results with -1 when the index holds fewer than k vectors.
    return [captions[i] for i in I[0] if i != -1]


def extract_location_names(texts):
    """
    Extract location names from captions using spaCy.
    Input: List of captions.
    Output: List of unique location names.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            if ent.label_ in ["GPE", "LOC", "FAC"]:
                names.append(ent.text)
    return list(set(names))


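# Illustrative example (entity spans depend on the spaCy model, so treat it as
# approximate): extract_location_names(["A sunset over the Eiffel Tower in Paris"])
# would typically return something like ["Eiffel Tower", "Paris"], in arbitrary order.

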
def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a caption from retrieved captions using T5.
    Input: List of retrieved captions.
    Output: Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.
    Input: PIL Image or image path (str).
    Output: Caption (str).
    """
    embedding = extract_image_features(image)
    if embedding is None:
        return "Failed to process image."
    retrieved = retrieve_similar_captions(embedding, k=5)
    if not retrieved:
        return "No similar captions found."
    return generate_caption_from_retrieved(retrieved)


def predict(image):
    """
    API-compatible function for inference.
    Input: PIL Image or image file path.
    Output: Dictionary with caption.
    """
    caption = generate_rag_caption(image)
    return {"caption": caption}
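

# Minimal usage sketch: run the full pipeline on a local image.
# "example.jpg" is a placeholder path; substitute a real image file.
if __name__ == "__main__":
    result = predict("example.jpg")
    print(result["caption"])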