# RAGExplo/inference.py
from PIL import Image
import torch
from transformers import CLIPProcessor, CLIPModel, T5Tokenizer, T5ForConditionalGeneration
from sentence_transformers import SentenceTransformer
import faiss
import numpy as np
import json
import spacy
# Load models and resources
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
text_encoder = SentenceTransformer('all-MiniLM-L6-v2')  # note: loaded but not used elsewhere in this script
tokenizer = T5Tokenizer.from_pretrained("t5-small")
generator = T5ForConditionalGeneration.from_pretrained("t5-small")
nlp = spacy.load("en_core_web_sm")
# Load FAISS index and captions
faiss_index = faiss.read_index("./faiss_index.idx")
with open("./captions.json", "r", encoding="utf-8") as f:
    captions = json.load(f)


def extract_image_features(image):
    """
    Extract image features using the CLIP model.
    Input: PIL Image or image path (str).
    Output: L2-normalized image embedding (numpy float32 array), or None on failure.
    """
    try:
        # Accept both a PIL Image and a file path
        if isinstance(image, str):
            image = Image.open(image).convert("RGB")
        else:
            image = image.convert("RGB")
        inputs = clip_processor(images=image, return_tensors="pt")
        with torch.no_grad():
            features = clip_model.get_image_features(**inputs)
            features = torch.nn.functional.normalize(features, p=2, dim=-1)
        return features.squeeze(0).cpu().numpy().astype("float32")
    except Exception as e:
        print(f"Error extracting features: {e}")
        return None


def retrieve_similar_captions(image_embedding, k=5):
    """
    Retrieve the k most similar captions using the FAISS index.
    Input: Image embedding (numpy array).
    Output: List of captions.
    """
    if image_embedding.ndim == 1:
        image_embedding = image_embedding.reshape(1, -1)
    D, I = faiss_index.search(image_embedding.astype("float32"), k)
    # FAISS returns -1 for positions it cannot fill; skip those and any
    # index that falls outside the caption list.
    return [captions[i] for i in I[0] if 0 <= i < len(captions)]
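

# --- Assumption note (sketch, not part of the original upload) ---
# retrieve_similar_captions() assumes faiss_index.idx was built from CLIP image
# embeddings of a set of reference images, with captions[i] in captions.json
# describing the image stored at index position i. The actual build script is not
# included here; a minimal sketch under that assumption (inner-product index over
# normalized vectors, i.e. cosine similarity) might look like the hypothetical
# helper below. It is never called in this file.
def _build_index_sketch(image_caption_pairs, index_path="./faiss_index.idx",
                        captions_path="./captions.json"):
    """Hypothetical example of how faiss_index.idx / captions.json could be produced."""
    embeddings, caps = [], []
    for path, caption in image_caption_pairs:
        emb = extract_image_features(path)
        if emb is not None:
            embeddings.append(emb)
            caps.append(caption)
    index = faiss.IndexFlatIP(len(embeddings[0]))  # inner product on normalized vectors
    index.add(np.stack(embeddings).astype("float32"))
    faiss.write_index(index, index_path)
    with open(captions_path, "w", encoding="utf-8") as fh:
        json.dump(caps, fh)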


def extract_location_names(texts):
    """
    Extract location names from captions using spaCy NER.
    Input: List of captions.
    Output: List of unique location names.
    """
    names = []
    for text in texts:
        doc = nlp(text)
        for ent in doc.ents:
            # GPE = countries/cities/states, LOC = non-GPE locations, FAC = buildings/landmarks
            if ent.label_ in ["GPE", "LOC", "FAC"]:
                names.append(ent.text)
    return list(set(names))


def generate_caption_from_retrieved(retrieved_captions):
    """
    Generate a caption from retrieved captions using T5.
    Input: List of retrieved captions.
    Output: Generated caption (str).
    """
    locations = extract_location_names(retrieved_captions)
    location_hint = f"The place might be: {', '.join(locations)}. " if locations else ""
    prompt = location_hint + " ".join(retrieved_captions) + " Generate a caption with the landmark name:"
    inputs = tokenizer(prompt, return_tensors="pt", truncation=True)
    outputs = generator.generate(
        input_ids=inputs.input_ids,
        attention_mask=inputs.attention_mask,
        max_length=300,
        num_beams=5,
        early_stopping=True,
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
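
# Illustrative example (made-up captions, not from the dataset): with retrieved
# captions ["A view of the Eiffel Tower at night", "Crowds near the Eiffel Tower in Paris"],
# the prompt passed to T5 would look roughly like:
#   "The place might be: Eiffel Tower, Paris. A view of the Eiffel Tower at night
#    Crowds near the Eiffel Tower in Paris Generate a caption with the landmark name:"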


def generate_rag_caption(image):
    """
    Generate a RAG-based caption for an image.
    Input: PIL Image or image path (str).
    Output: Caption (str).
    """
    embedding = extract_image_features(image)
    if embedding is None:
        return "Failed to process image."
    retrieved = retrieve_similar_captions(embedding, k=5)
    if not retrieved:
        return "No similar captions found."
    return generate_caption_from_retrieved(retrieved)


def predict(image):
    """
    API-compatible function for inference.
    Input: PIL Image or image file path.
    Output: Dictionary with caption.
    """
    caption = generate_rag_caption(image)
    return {"caption": caption}