## using pre-trained swin transformer to train the model

In [1]:
import torch
from torchvision import datasets, transforms
from torch.utils.data import DataLoader

# Define transformation pipeline
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to 224x224 for Swin Transformer
    transforms.ToTensor(),
])

# Load dataset
train_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/data/brain_tumor_dataset/train', transform=transform)
val_dataset = datasets.ImageFolder(root='/content/drive/MyDrive/data/brain_tumor_dataset/test', transform=transform)

# Create dataloaders
train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False)


In [2]:
from transformers import SwinForImageClassification

# Load the pre-trained Swin Transformer model with 4 output classes
model = SwinForImageClassification.from_pretrained(
    'microsoft/swin-tiny-patch4-window7-224',
    num_labels=2,  # Number of tumor types
    ignore_mismatched_sizes=True  # Ignore size mismatch for the classifier layer
)

config.json:   0%|          | 0.00/71.8k [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/113M [00:00<?, ?B/s]

Some weights of SwinForImageClassification were not initialized from the model checkpoint at microsoft/swin-tiny-patch4-window7-224 and are newly initialized because the shapes did not match:
- classifier.bias: found shape torch.Size([1000]) in the checkpoint and torch.Size([2]) in the model instantiated
- classifier.weight: found shape torch.Size([1000, 768]) in the checkpoint and torch.Size([2, 768]) in the model instantiated
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
# Freeze all layers except the final classification head
for param in model.parameters():
    param.requires_grad = False

# Unfreeze the classification layer
for param in model.classifier.parameters():
    param.requires_grad = True


In [4]:
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn as nn

# Set up optimizer and loss function
optimizer = optim.Adam(model.parameters(), lr=1e-4)
criterion = nn.CrossEntropyLoss()

# Move model to GPU if available
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model.to(device)

SwinForImageClassification(
  (swin): SwinModel(
    (embeddings): SwinEmbeddings(
      (patch_embeddings): SwinPatchEmbeddings(
        (projection): Conv2d(3, 96, kernel_size=(4, 4), stride=(4, 4))
      )
      (norm): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
      (dropout): Dropout(p=0.0, inplace=False)
    )
    (encoder): SwinEncoder(
      (layers): ModuleList(
        (0): SwinStage(
          (blocks): ModuleList(
            (0-1): 2 x SwinLayer(
              (layernorm_before): LayerNorm((96,), eps=1e-05, elementwise_affine=True)
              (attention): SwinAttention(
                (self): SwinSelfAttention(
                  (query): Linear(in_features=96, out_features=96, bias=True)
                  (key): Linear(in_features=96, out_features=96, bias=True)
                  (value): Linear(in_features=96, out_features=96, bias=True)
                  (dropout): Dropout(p=0.0, inplace=False)
                )
                (output): SwinSelfOutput(
  

In [5]:
# Training loop
for epoch in range(10):  # Train for 10 epochs (adjust as needed)
    model.train()
    running_loss = 0.0
    correct = 0
    total = 0

    for inputs, labels in train_loader:
        inputs, labels = inputs.to(device), labels.to(device)

        # Zero the parameter gradients
        optimizer.zero_grad()

        # Forward pass
        outputs = model(inputs).logits
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        loss.backward()
        optimizer.step()

        # Calculate accuracy
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()
        running_loss += loss.item()

    # Print training stats
    print(f'Epoch [{epoch+1}/10], Loss: {running_loss/len(train_loader)}, Accuracy: {100 * correct / total}%')

    # Validation
    model.eval()
    val_correct = 0
    val_total = 0
    with torch.no_grad():
        for inputs, labels in val_loader:
            inputs, labels = inputs.to(device), labels.to(device)
            outputs = model(inputs).logits
            _, predicted = torch.max(outputs, 1)
            val_total += labels.size(0)
            val_correct += (predicted == labels).sum().item()

    print(f'Validation Accuracy: {100 * val_correct / val_total}%')

Epoch [1/10], Loss: 0.6803878396749496, Accuracy: 58.13692480359147%
Validation Accuracy: 59.060402684563755%
Epoch [2/10], Loss: 0.6379124437059674, Accuracy: 72.61503928170595%
Validation Accuracy: 69.12751677852349%
Epoch [3/10], Loss: 0.6045689774411065, Accuracy: 77.32884399551067%
Validation Accuracy: 74.49664429530202%
Epoch [4/10], Loss: 0.5734582436936242, Accuracy: 79.12457912457913%
Validation Accuracy: 78.52348993288591%
Epoch [5/10], Loss: 0.5508207274334771, Accuracy: 80.13468013468014%
Validation Accuracy: 80.53691275167785%
Epoch [6/10], Loss: 0.5296014377049038, Accuracy: 80.69584736251403%
Validation Accuracy: 78.52348993288591%
Epoch [7/10], Loss: 0.5103116855025291, Accuracy: 82.37934904601572%
Validation Accuracy: 79.86577181208054%
Epoch [8/10], Loss: 0.48474655938999994, Accuracy: 83.72615039281706%
Validation Accuracy: 77.85234899328859%
Epoch [9/10], Loss: 0.48020742727177484, Accuracy: 83.16498316498317%
Validation Accuracy: 79.19463087248322%
Epoch [10/10], L

In [6]:
# Test the model
model.eval()
test_correct = 0
test_total = 0

with torch.no_grad():
    for inputs, labels in val_loader:
        inputs, labels = inputs.to(device), labels.to(device)
        outputs = model(inputs).logits
        _, predicted = torch.max(outputs, 1)
        test_total += labels.size(0)
        test_correct += (predicted == labels).sum().item()

print(f'Test Accuracy: {100 * test_correct / test_total}%')

Test Accuracy: 79.86577181208054%


In [7]:
torch.save(model.state_dict(), 'swin_brain_tumor_classifier.pth')

In [8]:
from PIL import Image

In [20]:
# Load the saved model
model.load_state_dict(torch.load('swin_brain_tumor_classifier.pth'))
model.eval()

# Make predictions on new data
img = Image.open('/content/drive/MyDrive/data/brain_tumor_dataset/test/healthy/0566.jpg')
img = transform(img).unsqueeze(0).to(device)

# Predict
output = model(img).logits
_, predicted = torch.max(output, 1)
print(f'Predicted class: {predicted.item()}')

  model.load_state_dict(torch.load('swin_brain_tumor_classifier.pth'))


Predicted class: 0


In [21]:
#healthy tumor
path = '/content/drive/MyDrive/data/brain_tumor_dataset/test/healthy'


In [22]:
import os

In [25]:
files = os.listdir(path)

for f in files:
  try:
    img = Image.open(os.path.join(path,f))
    img = transform(img).unsqueeze(0).to(device)
    output = model(img).logits
    _, predicted = torch.max(output, 1)
    print(f'predicted class: {predicted.item()} filename: {f} actual class: 0')
  except Exception as e:
    print(e)
    continue

predicted class: 0 filename: 0796.jpg actual class: 0
predicted class: 0 filename: 0676.jpg actual class: 0
predicted class: 1 filename: 0698.jpg actual class: 0
predicted class: 1 filename: 0601.jpg actual class: 0
predicted class: 0 filename: 0861.jpg actual class: 0
predicted class: 1 filename: 0615.jpg actual class: 0
predicted class: 0 filename: 0874.jpg actual class: 0
predicted class: 0 filename: 0820.jpg actual class: 0
predicted class: 0 filename: 0785.jpg actual class: 0
predicted class: 0 filename: 0792.jpg actual class: 0
predicted class: 0 filename: 0731.jpg actual class: 0
predicted class: 0 filename: 0762.jpg actual class: 0
predicted class: 1 filename: 0710.jpg actual class: 0
predicted class: 0 filename: 0858.jpg actual class: 0
predicted class: 0 filename: 0691.jpg actual class: 0
predicted class: 0 filename: 0791.jpg actual class: 0
predicted class: 1 filename: 0639.jpg actual class: 0
predicted class: 1 filename: 0596.jpg actual class: 0
predicted class: 1 filename:

In [39]:
#calculating the accuracy
def calculate_accuracy(model, img_path, img_files, actual_class):
  total_images = len(img_files)
  predicted_ones = 0
  for i in img_files:
    try:
      img = Image.open(os.path.join(img_path,i))
      img = transform(img).unsqueeze(0).to(device)
      output = model(img).logits
      _, predicted = torch.max(output, 1)
      if int(predicted.item()) == int(actual_class):
        predicted_ones += 1
    except Exception as e:
      continue
  accuracy_score = (predicted_ones/total_images)*100
  return accuracy_score

In [40]:
img_path = '/content/drive/MyDrive/data/brain_tumor_dataset/train/healthy'
img_files = os.listdir(img_path)
print("Accuracy score:",calculate_accuracy(model, img_path, img_files, 0))

Accuracy score: 62.02830188679245


In [41]:
img_path = '/content/drive/MyDrive/data/brain_tumor_dataset/train/tumor'
img_files = os.listdir(img_path)
print("Accuracy score:",calculate_accuracy(model, img_path, img_files, 1))

Accuracy score: 85.65310492505354


### For healthy class model accuracy score is 62%
### For tumor images model accuracy score is 85%