Spaces:
Sleeping
Sleeping
import torch | |
import os | |
from PIL import Image | |
from transformers import AutoModelForImageClassification, SiglipImageProcessor | |
import gradio as gr | |
# Alternative OCR using transformers | |
def setup_alternative_ocr():
    """Load the TrOCR processor and model used for text extraction.

    The transformers classes are imported inside the ``try`` block, so an
    import failure is handled the same way as a failed model load.

    Returns:
        A ``(processor, model, available)`` tuple. On success the first two
        items are the loaded TrOCR objects and ``available`` is ``True``;
        on any exception they are ``None`` and ``available`` is ``False``.
    """
    checkpoint = "microsoft/trocr-base-printed"
    try:
        from transformers import TrOCRProcessor, VisionEncoderDecoderModel

        print("Setting up TrOCR for text extraction...")
        trocr_processor = TrOCRProcessor.from_pretrained(checkpoint)
        trocr_model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
        print("β TrOCR model loaded successfully!")
        outcome = (trocr_processor, trocr_model, True)
    except Exception as e:
        # Best-effort: any failure only disables OCR, it never propagates.
        print(f"β οΈ Could not load TrOCR: {e}")
        outcome = (None, None, False)
    return outcome
# Try to set up OCR. This is best-effort: on failure OCR_AVAILABLE is False
# and both OCR objects are None.
OCR_PROCESSOR, OCR_MODEL, OCR_AVAILABLE = setup_alternative_ocr()

# Local directory the classification model is loaded from.
MODEL_PATH = "./model"

try:
    print(f"=== Loading model from: {MODEL_PATH} ===")
    print(f"Available files: {os.listdir(MODEL_PATH)}")

    # Load the model strictly from local files.
    print("Loading model...")
    model = AutoModelForImageClassification.from_pretrained(MODEL_PATH, local_files_only=True)
    print("β Model loaded successfully!")

    # Load the image processor, preferring the one stored next to the model
    # and falling back to the base SigLIP processor if that fails.
    print("Loading image processor...")
    try:
        processor = SiglipImageProcessor.from_pretrained(MODEL_PATH, local_files_only=True)
        print("β Image processor loaded from local files!")
    except Exception as e:
        print(f"β οΈ Could not load local processor: {e}")
        print("Loading image processor from base SigLIP model...")
        processor = SiglipImageProcessor.from_pretrained("google/siglip-base-patch16-224")
        print("β Image processor loaded from base model!")

    # Get labels: use the mapping from the model config when it is present
    # and non-empty, otherwise synthesize generic "class_<i>" names.
    config_labels = getattr(model.config, 'id2label', None)
    if config_labels:
        labels = config_labels
        print(f"β Found {len(labels)} labels in model config")
    else:
        num_labels = getattr(model.config, 'num_labels', 2)
        labels = {i: f"class_{i}" for i in range(num_labels)}
        print(f"β Created {len(labels)} generic labels")

    print("π Model setup complete!")
except Exception as e:
    print(f"β Error loading model: {e}")
    # The directory itself may be what is missing or unreadable. Listing it
    # unguarded would raise a second exception from inside this handler and
    # the original error would never be re-raised directly.
    try:
        print(f"Files in model directory: {os.listdir(MODEL_PATH)}")
    except OSError as list_error:
        print(f"Could not list model directory: {list_error}")
    raise
def extract_text_alternative(image):
    """Run the TrOCR model on *image* and return the decoded text.

    Args:
        image: Image object exposing ``mode`` and ``convert``; it is
            converted to RGB first when it is in any other mode.

    Returns:
        The first decoded sequence, ``"OCR not available"`` when the OCR
        model was not loaded, or ``"OCR error: <message>"`` when any
        exception is raised during processing.
    """
    if not OCR_AVAILABLE:
        return "OCR not available"

    try:
        rgb_image = image if image.mode == 'RGB' else image.convert('RGB')
        encoded = OCR_PROCESSOR(rgb_image, return_tensors="pt")
        token_ids = OCR_MODEL.generate(encoded.pixel_values)
        decoded = OCR_PROCESSOR.batch_decode(token_ids, skip_special_tokens=True)
        return decoded[0]
    except Exception as e:
        # Failures are reported as text rather than raised to the caller.
        return f"OCR error: {e}"
def classify_meme(image: Image.Image):
    """Classify a meme image and extract the text it contains.

    Args:
        image: The image to classify. ``None`` (nothing uploaded) is
            reported through the normal error result instead of failing
            with an unrelated attribute error.

    Returns:
        A ``(predictions, text)`` tuple. ``predictions`` maps each label to
        its softmax probability, ordered from most to least likely, and
        ``text`` is the stripped OCR output (or a notice that OCR is
        unavailable). If anything fails, the result is
        ``({"Error": 1.0}, "Error processing image: <message>")``.
    """
    try:
        if image is None:
            raise ValueError("no image provided")

        # extract_text_alternative converts to RGB for OCR only; convert
        # once up front so the classifier receives the same RGB image.
        if image.mode != 'RGB':
            image = image.convert('RGB')

        # Extract text using alternative OCR
        if OCR_AVAILABLE:
            extracted_text = extract_text_alternative(image)
        else:
            extracted_text = "OCR not available in this environment"
        extracted_text = extracted_text.strip()

        # Process image for classification
        inputs = processor(images=image, return_tensors="pt")

        # Run inference
        with torch.no_grad():
            outputs = model(**inputs)
            probs = torch.nn.functional.softmax(outputs.logits, dim=-1)

        # Iterate over the probabilities the model actually produced rather
        # than over len(labels), so a label map that is shorter or longer
        # than the logits neither drops classes nor raises IndexError.
        predictions = {
            labels.get(i, f"class_{i}"): prob
            for i, prob in enumerate(probs[0].tolist())
        }

        # Sort predictions by confidence
        sorted_predictions = dict(
            sorted(predictions.items(), key=lambda item: item[1], reverse=True)
        )

        # Debug prints
        print("=== Classification Results ===")
        print(f"Extracted Text: '{extracted_text}'")
        print("Top 3 Predictions:")
        top_three = list(sorted_predictions.items())[:3]
        for rank, (label, prob) in enumerate(top_three, start=1):
            print(f" {rank}. {label}: {prob:.4f}")

        return sorted_predictions, extracted_text
    except Exception as e:
        # Top-level UI boundary: report the failure instead of raising.
        error_msg = f"Error processing image: {str(e)}"
        print(f"β {error_msg}")
        return {"Error": 1.0}, error_msg
# Build the Gradio UI: one image input, and two outputs (label scores and
# extracted text) matching the tuple returned by classify_meme. The title
# and description mention TrOCR only when the OCR model was loaded.
demo = gr.Interface(
    fn=classify_meme,
    inputs=gr.Image(label="Upload Meme Image", type="pil"),
    outputs=[
        gr.Label(label="Meme Classification", num_top_classes=5),
        gr.Textbox(lines=3, label="Extracted Text"),
    ],
    title=(
        "π Meme Classifier with TrOCR" if OCR_AVAILABLE else "π Meme Classifier"
    ),
    description=(
        "\n"
        "Upload a meme image to **classify** its content using your trained SigLIP2_77 model.\n"
        + (
            "β **Text extraction** available via TrOCR (Microsoft Transformer OCR)"
            if OCR_AVAILABLE
            else "β οΈ **Text extraction** not available"
        )
        + "\n"
        "Your model will predict the category/sentiment of the uploaded meme.\n"
    ),
    examples=None,
    allow_flagging="never",
)
if __name__ == "__main__":
    # Only start the server when run as a script, not when imported.
    print("π Starting Gradio interface...")
    demo.launch(share=False, server_port=7860, server_name="0.0.0.0")