import os import torch import torch.nn as nn from torchvision import models, transforms from PIL import Image import json import sys # Print debug information print("Handler module loaded") print(f"Python version: {sys.version}") print(f"PyTorch version: {torch.__version__}") print(f"Directory contents: {os.listdir('.')}") if os.path.exists('/repository'): print(f"Repository directory contents: {os.listdir('/repository')}") # For debugging class ViTForImageClassification: @staticmethod def from_pretrained(model_dir): # This is a fake method to catch erroneous imports print(f"ERROR: ViTForImageClassification.from_pretrained was called with {model_dir}") raise ValueError("ViTForImageClassification is not the correct model for this application") class EndpointHandler: def __init__(self, model_dir): """ Initialize the model for AI image detection """ print(f"Initializing EndpointHandler with model_dir: {model_dir}") # Set device self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") print(f"Using device: {self.device}") # Define transforms first (in case model loading fails) self.transform = transforms.Compose([ transforms.Resize((224, 224)), transforms.ToTensor(), transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]) ]) # Class names self.classes = ["Real Image", "AI-Generated Image"] # Load model try: self.model = self._load_model(model_dir) print("Model loaded successfully") except Exception as e: print(f"Error loading model: {e}") # Create a dummy model as fallback print("Creating a dummy model as fallback") self.model = models.efficientnet_v2_s(pretrained=True) self.model.classifier[-1] = nn.Linear( self.model.classifier[-1].in_features, 2 ) self.model.eval() def _load_model(self, model_dir): print(f"Loading model from directory: {model_dir}") print(f"Directory contents: {os.listdir(model_dir)}") # Create model architecture model = models.efficientnet_v2_s(weights=None) # Recreate classifier exactly as in training model.classifier = nn.Sequential( nn.Linear(model.classifier[1].in_features, 1024), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(1024, 512), nn.ReLU(), nn.Dropout(p=0.3), nn.Linear(512, 2) ) # Try to find model file in multiple possible locations model_found = False possible_paths = [ os.path.join(model_dir, "best_model_improved.pth"), os.path.join(model_dir, "pytorch_model.bin"), "best_model_improved.pth", "/repository/best_model_improved.pth" ] for model_path in possible_paths: print(f"Trying model path: {model_path}") if os.path.exists(model_path): print(f"Found model at: {model_path}") model.load_state_dict(torch.load(model_path, map_location=self.device)) model_found = True break if not model_found: # Check if we need to copy the model file if os.path.exists('best_model_improved.pth') and not os.path.exists(os.path.join(model_dir, 'best_model_improved.pth')): import shutil print(f"Copying model file to {model_dir}") shutil.copy('best_model_improved.pth', os.path.join(model_dir, 'best_model_improved.pth')) model.load_state_dict(torch.load(os.path.join(model_dir, 'best_model_improved.pth'), map_location=self.device)) model_found = True if not model_found: raise FileNotFoundError(f"Model file not found in any of these locations: {possible_paths}") model.to(self.device) model.eval() return model def __call__(self, data): """ Run prediction on the input data """ try: print(f"Received prediction request with data type: {type(data)}") # Parse request data if isinstance(data, dict) and "inputs" in data: # API format input_data = data["inputs"] print(f"Extracted input data from API format, type: {type(input_data)}") else: # Direct image input_data = data # Process image if isinstance(input_data, str): # Base64 string print("Processing base64 string image") import base64 from io import BytesIO # Decode base64 image if ',' in input_data: input_data = input_data.split(",", 1)[1] image_bytes = base64.b64decode(input_data) image = Image.open(BytesIO(image_bytes)).convert("RGB") elif hasattr(input_data, "read"): # File-like object print("Processing file-like object image") image = Image.open(input_data).convert("RGB") elif isinstance(input_data, Image.Image): print("Processing PIL Image") image = input_data else: print(f"Unsupported input type: {type(input_data)}") return {"error": f"Unsupported input type: {type(input_data)}"} # Preprocess image image_tensor = self.transform(image).unsqueeze(0).to(self.device) # Make prediction with torch.no_grad(): outputs = self.model(image_tensor) probabilities = torch.nn.functional.softmax(outputs, dim=1)[0] prediction = torch.argmax(probabilities).item() # Format results real_prob = probabilities[0].item() * 100 ai_prob = probabilities[1].item() * 100 # 修改这里: 返回符合 API 要求的格式 (Array) # 而不是返回原来的字典格式 return [ { "label": "Real Image", "score": float(real_prob) }, { "label": "AI-Generated Image", "score": float(ai_prob) } ] except Exception as e: import traceback print(f"Error during prediction: {e}") traceback.print_exc() return {"error": str(e), "traceback": traceback.format_exc()}