from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy

# Load models and resources
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')  # not used in the retrieval path below
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
nlp = spacy.load("en_core_web_sm")

# Load FAISS index and captions
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)


def extract_image_features(image):
    """
    Extract image features using the CLIP model.
    Input: PIL Image or image path (str).
    Output: Normalized image embedding (numpy array).
    """
    try:
        # Handle both PIL Image and file path
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
        features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None


def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve the k most similar captions using the FAISS index.
    Input: Image embedding (numpy array).
    Output: List of captions.
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    distances, indices = faiss_index.search(image_embedding, k)
    return [captions[i] for i in indices[0]]


def extract_location_names(texts):
    """
    Extract location names from captions using spaCy NER.
    Input: List of captions.
    Output: List of unique location names.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            if ent.label_ in ["GPE", "LOC", "FAC"]:
                names.append(ent.text)
    return list(set(names))


def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a caption from the retrieved captions using T5.
    Input: List of retrieved captions.
    Output: Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)


def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.
    Input: PIL Image or image path (str).
    Output: Caption (str).
    """
    embedding = extract_image_features(image)
    if embedding is None:
        return "Failed to process image."
    retrieved = retrieve_similar_captions(embedding, k=5)
    if not retrieved:
        return "No similar captions found."
    return generate_caption_from_retrieved(retrieved)


def predict(image):
    """
    API-compatible function for inference.
    Input: PIL Image or image file path.
    Output: Dictionary with caption.
    """
    caption = generate_rag_caption(image)
    return {"caption": caption}
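

# --- Usage sketch (illustrative only) ---
# A minimal example of calling the pipeline end to end. "example.jpg" is a
# placeholder path, and this assumes ./faiss_index.idx and ./captions.json
# exist alongside this module as loaded above.
if __name__ == "__main__":
    result = predict("example.jpg")
    print(result["caption"])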