"""Gradio demo that classifies an uploaded image as "Cancer" or "Non-Cancer".

A Swin Transformer image classifier is built from the
``microsoft/swin-tiny-patch4-window7-224`` checkpoint with a 2-class head,
its weights are then overwritten from a local ``.pth`` state dict, and a
single-image prediction function is exposed through a Gradio interface.
"""

import gradio as gr
import torch
from torchvision import transforms
from PIL import Image
from transformers import SwinForImageClassification, AutoImageProcessor
import os

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load image processor
processor = AutoImageProcessor.from_pretrained("microsoft/swin-tiny-patch4-window7-224")

# Load model
# num_labels=2 gives the model a 2-class head; ignore_mismatched_sizes=True
# lets from_pretrained skip checkpoint tensors whose shape no longer matches.
model = SwinForImageClassification.from_pretrained(
    "microsoft/swin-tiny-patch4-window7-224",
    num_labels=2,
    ignore_mismatched_sizes=True,
)
# Overwrite the weights with the local checkpoint.
# NOTE(review): torch.load unpickles the file; only load checkpoints you
# trust, and consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load("model/oral_cancer_swin_new.pth", map_location=device))
model.to(device)
model.eval()

# labels[i] is the text returned when logit index i is the argmax.
# NOTE(review): this order presumably matches the training label encoding
# -- confirm against the training script.
labels = ["Cancer", "Non-Cancer"]


def predict(image) -> str:
    """Classify one image and return its label.

    Args:
        image: The PIL image supplied by the Gradio ``Image`` input, or
            ``None`` when the user submits without uploading anything.

    Returns:
        The entry of ``labels`` whose index has the highest logit.

    Raises:
        gr.Error: If ``image`` is ``None``.
    """
    if image is None:
        # Surface a readable message in the UI instead of an opaque
        # failure from inside the image processor.
        raise gr.Error("Please upload an image before submitting.")

    # The model takes 3-channel input; uploads may be RGBA, grayscale or
    # palette images, so normalise the mode before preprocessing.
    if image.mode != "RGB":
        image = image.convert("RGB")

    inputs = processor(images=image, return_tensors="pt").to(device)
    with torch.no_grad():
        outputs = model(**inputs)
    pred = torch.argmax(outputs.logits, dim=1).item()
    return labels[pred]


demo = gr.Interface(
    fn=predict,
    inputs=gr.Image(type="pil"),
    outputs="text",
    title="Oral Cancer Detection",
    description="Upload a tongue image to detect whether it shows signs of Cancer or not.",
)

# Launched at module level, as in the original script, so running or
# importing this file starts the app.
demo.launch()