"""Gradio app that classifies meme images and extracts their text.

On import this module:
  * tries to load the TrOCR printed-text model (optional; failures are
    reported and OCR is disabled),
  * loads the image-classification model from ``MODEL_PATH`` (required;
    failures are reported and re-raised),
  * builds the Gradio ``demo`` interface.

Running the file as a script launches the interface on port 7860.
"""

import os

import gradio as gr
import torch
from PIL import Image
from transformers import AutoModelForImageClassification, SiglipImageProcessor


# Alternative OCR using transformers
def setup_alternative_ocr():
    """Load the TrOCR processor and model used for text extraction.

    Returns:
        A ``(processor, model, available)`` tuple. On any failure (missing
        dependency, download error, ...) the failure is printed and
        ``(None, None, False)`` is returned so the app can run without OCR.
    """
    try:
        # Imported lazily so a transformers build without TrOCR support only
        # disables OCR instead of breaking the whole app.
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        print("Setting up TrOCR for text extraction...")
        ocr_processor = TrOCRProcessor.from_pretrained("microsoft/trocr-base-printed")
        ocr_model = VisionEncoderDecoderModel.from_pretrained("microsoft/trocr-base-printed")
        print("✅ TrOCR model loaded successfully!")
        return ocr_processor, ocr_model, True
    except Exception as e:
        # Deliberate best-effort: OCR is optional.
        print(f"⚠️ Could not load TrOCR: {e}")
        return None, None, False


# Try to setup OCR
OCR_PROCESSOR, OCR_MODEL, OCR_AVAILABLE = setup_alternative_ocr()

# Model path
MODEL_PATH = "./model"

try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {os.listdir(MODEL_PATH)}")

    # Load the model
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("✅ Model loaded successfully!")

    # Load image processor
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("✅ Image processor loaded from local files!")
    except Exception as e:
        # Fall back to the base SigLIP preprocessing config when the local
        # model directory does not ship one.
        print(f"⚠️ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("✅ Image processor loaded from base model!")

    # Get labels
    if hasattr(model.config, 'id2label') and model.config.id2label:
        labels = model.config.id2label
        print(f"✅ Found {len(labels)} labels in model config")
    else:
        num_labels = model.config.num_labels if hasattr(model.config, 'num_labels') else 2
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"✅ Created {len(labels)} generic labels")

    print("🎉 Model setup complete!")

except Exception as e:
    print(f"❌ Error loading model: {e}")
    # Listing the directory is only a debugging aid. If the directory itself
    # is missing, listing it would raise and hide the original error, so
    # guard it and always re-raise the original exception.
    try:
        print(f"Files in model directory: {os.listdir(MODEL_PATH)}")
    except OSError as list_error:
        print(f"Could not list model directory {MODEL_PATH!r}: {list_error}")
    raise


def extract_text_alternative(image):
    """Extract text from ``image`` using the TrOCR model.

    Args:
        image: A PIL image; it is converted to RGB if it is in another mode.

    Returns:
        The decoded text, ``"OCR not available"`` when TrOCR failed to load,
        or an ``"OCR error: ..."`` message if extraction raised.
    """
    if not OCR_AVAILABLE:
        return "OCR not available"

    try:
        # Convert to RGB if needed
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Process with TrOCR
        pixel_values = OCR_PROCESSOR(image, return_tensors="pt").pixel_values
        generated_ids = OCR_MODEL.generate(pixel_values)
        generated_text = OCR_PROCESSOR.batch_decode(generated_ids, skip_special_tokens=True)[0]

        return generated_text
    except Exception as e:
        return f"OCR error: {str(e)}"


def classify_meme(image: Image.Image):
    """Classify a meme image and extract its text.

    Args:
        image: The uploaded PIL image. ``None`` (nothing uploaded) is
            reported as an error instead of being processed.

    Returns:
        A ``(predictions, text)`` tuple. ``predictions`` maps each label to
        its softmax probability, ordered from most to least confident, and
        ``text`` is the stripped OCR result. On failure the tuple is
        ``({"Error": 1.0}, error_message)``.
    """
    try:
        if image is None:
            # The interface passes None when submitted without an image.
            raise ValueError("no image provided")

        # The classifier's processor is not guaranteed to convert RGBA,
        # palette or grayscale inputs, so normalise once for both the OCR and
        # classification paths.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Extract text using alternative OCR
        if OCR_AVAILABLE:
            extracted_text = extract_text_alternative(image)
        else:
            extracted_text = "OCR not available in this environment"
        extracted_text = extracted_text.strip()

        # Process image for classification
        inputs = processor(images=image, return_tensors="pt")

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Get predictions. Iterate over the logits actually produced rather
        # than len(labels), so a label map that disagrees with the model head
        # can neither index out of range nor drop classes.
        predictions = {
            labels.get(index, f"class_{index}"): float(score)
            for index, score in enumerate(probs[0].tolist())
        }

        # Sort predictions by confidence
        sorted_predictions = dict(
            sorted(predictions.items(), key=lambda item: item[1], reverse=True)
        )

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        for rank, (label, prob) in enumerate(list(sorted_predictions.items())[:3], start=1):
            print(f" {rank}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text

    except Exception as e:
        error_msg = f"Error processing image: {str(e)}"
        print(f"❌ {error_msg}")
        return {"Error": 1.0}, error_msg


# Create Gradio interface
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(type="pil", label="Upload Meme Image"),
    outputs=[
        gr.Label(num_top_classes=5, label="Meme Classification"),
        gr.Textbox(label="Extracted Text", lines=3)
    ],
    title="🎭 Meme Classifier" + (" with TrOCR" if OCR_AVAILABLE else ""),
    # Kept flush-left: the description is rendered as Markdown, where leading
    # indentation could be interpreted as a code block.
    description=f"""
Upload a meme image to **classify** its content using your trained SigLIP2_77 model.

{'✅ **Text extraction** available via TrOCR (Microsoft Transformer OCR)' if OCR_AVAILABLE else '⚠️ **Text extraction** not available'}

Your model will predict the category/sentiment of the uploaded meme.
""",
    examples=None,
    allow_flagging="never"
)

if __name__ == "__main__":
    print("🚀 Starting Gradio interface...")
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False
    )