""" State-of-the-Art Vision Transformer for Defect Detection Implements modern ViT architecture with advanced features """ import torch import torch.nn as nn import torch.nn.functional as F from torchvision import models import timm from typing import Optional, Tuple import math class PatchEmbedding(nn.Module): """Convert image into patches and embed them""" def __init__(self, img_size=224, patch_size=16, in_channels=3, embed_dim=768): super().__init__() self.img_size = img_size self.patch_size = patch_size self.n_patches = (img_size // patch_size) ** 2 self.projection = nn.Conv2d( in_channels, embed_dim, kernel_size=patch_size, stride=patch_size ) def forward(self, x): x = self.projection(x) # (B, E, H/P, W/P) x = x.flatten(2) # (B, E, N) x = x.transpose(1, 2) # (B, N, E) return x class MultiHeadAttention(nn.Module): """Multi-Head Self Attention with improvements""" def __init__(self, embed_dim, num_heads, dropout=0.1): super().__init__() self.embed_dim = embed_dim self.num_heads = num_heads self.head_dim = embed_dim // num_heads self.qkv = nn.Linear(embed_dim, embed_dim * 3) self.attn_dropout = nn.Dropout(dropout) self.projection = nn.Linear(embed_dim, embed_dim) self.proj_dropout = nn.Dropout(dropout) def forward(self, x): B, N, E = x.shape qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim) qkv = qkv.permute(2, 0, 3, 1, 4) q, k, v = qkv[0], qkv[1], qkv[2] attn = (q @ k.transpose(-2, -1)) * (self.head_dim ** -0.5) attn = F.softmax(attn, dim=-1) attn = self.attn_dropout(attn) x = (attn @ v).transpose(1, 2).reshape(B, N, E) x = self.projection(x) x = self.proj_dropout(x) return x class TransformerBlock(nn.Module): """Transformer block with LayerNorm and residual connections""" def __init__(self, embed_dim, num_heads, mlp_ratio=4, dropout=0.1): super().__init__() self.norm1 = nn.LayerNorm(embed_dim) self.attn = MultiHeadAttention(embed_dim, num_heads, dropout) self.norm2 = nn.LayerNorm(embed_dim) mlp_hidden = int(embed_dim * mlp_ratio) self.mlp = nn.Sequential( nn.Linear(embed_dim, mlp_hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(mlp_hidden, embed_dim), nn.Dropout(dropout) ) def forward(self, x): x = x + self.attn(self.norm1(x)) x = x + self.mlp(self.norm2(x)) return x class VisionTransformer(nn.Module): """ Vision Transformer for Defect Detection Supports multiple configurations from small to large """ def __init__( self, img_size=224, patch_size=16, in_channels=3, num_classes=2, embed_dim=768, depth=12, num_heads=12, mlp_ratio=4, dropout=0.1, use_pretrained=True ): super().__init__() self.patch_embed = PatchEmbedding(img_size, patch_size, in_channels, embed_dim) self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) self.pos_embed = nn.Parameter(torch.zeros(1, 1 + self.patch_embed.n_patches, embed_dim)) self.pos_dropout = nn.Dropout(dropout) self.blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout) for _ in range(depth) ]) self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, num_classes) # Initialize weights nn.init.trunc_normal_(self.pos_embed, std=0.02) nn.init.trunc_normal_(self.cls_token, std=0.02) def forward(self, x): B = x.shape[0] x = self.patch_embed(x) cls_tokens = self.cls_token.expand(B, -1, -1) x = torch.cat((cls_tokens, x), dim=1) x = x + self.pos_embed x = self.pos_dropout(x) for block in self.blocks: x = block(x) x = self.norm(x) cls_token_final = x[:, 0] x = self.head(cls_token_final) return x class HybridViT(nn.Module): """ Hybrid Vision Transformer combining CNN backbone with ViT Better for smaller 
datasets """ def __init__(self, num_classes=2, pretrained=True): super().__init__() # Use ResNet50 as feature extractor resnet = models.resnet50(pretrained=pretrained) self.features = nn.Sequential(*list(resnet.children())[:-2]) # ViT on top of CNN features self.embed_dim = 768 self.projection = nn.Conv2d(2048, self.embed_dim, kernel_size=1) # Transformer blocks self.blocks = nn.ModuleList([ TransformerBlock(self.embed_dim, num_heads=12, mlp_ratio=4) for _ in range(6) ]) self.norm = nn.LayerNorm(self.embed_dim) self.head = nn.Linear(self.embed_dim, num_classes) def forward(self, x): # CNN feature extraction x = self.features(x) x = self.projection(x) # Flatten spatial dimensions B, C, H, W = x.shape x = x.flatten(2).transpose(1, 2) # Transformer blocks for block in self.blocks: x = block(x) x = self.norm(x) x = x.mean(dim=1) # Global average pooling x = self.head(x) return x class EfficientViT(nn.Module): """ Efficient Vision Transformer using timm library Provides state-of-the-art pretrained models """ def __init__(self, model_name='vit_base_patch16_224', num_classes=2, pretrained=True): super().__init__() self.model = timm.create_model(model_name, pretrained=pretrained, num_classes=num_classes) def forward(self, x): return self.model(x) class DualAttentionViT(nn.Module): """ Implements Dual Attention Vision Transformer (DaViT) inspired architecture Combines spatial and channel attention """ def __init__(self, img_size=224, patch_size=16, num_classes=2, embed_dim=768): super().__init__() self.patch_embed = PatchEmbedding(img_size, patch_size, 3, embed_dim) # Spatial attention branch self.spatial_blocks = nn.ModuleList([ TransformerBlock(embed_dim, num_heads=12) for _ in range(6) ]) # Channel attention branch self.channel_attention = nn.Sequential( nn.AdaptiveAvgPool1d(1), nn.Linear(embed_dim, embed_dim // 16), nn.ReLU(), nn.Linear(embed_dim // 16, embed_dim), nn.Sigmoid() ) self.fusion = nn.Linear(embed_dim * 2, embed_dim) self.norm = nn.LayerNorm(embed_dim) self.head = nn.Linear(embed_dim, num_classes) def forward(self, x): x = self.patch_embed(x) # Spatial attention path spatial_feat = x for block in self.spatial_blocks: spatial_feat = block(spatial_feat) # Channel attention B, N, C = x.shape channel_weights = self.channel_attention(x.transpose(1, 2)).transpose(1, 2) channel_feat = x * channel_weights # Fusion combined = torch.cat([spatial_feat, channel_feat], dim=-1) x = self.fusion(combined) x = self.norm(x) x = x.mean(dim=1) # Global pooling x = self.head(x) return x def get_model(model_type='efficient_vit', num_classes=2, pretrained=True): """ Factory function to get different Vision Transformer variants Args: model_type: One of ['vit', 'hybrid_vit', 'efficient_vit', 'dual_attention_vit'] num_classes: Number of output classes pretrained: Whether to use pretrained weights Returns: Model instance """ if model_type == 'vit': return VisionTransformer(num_classes=num_classes, use_pretrained=pretrained) elif model_type == 'hybrid_vit': return HybridViT(num_classes=num_classes, pretrained=pretrained) elif model_type == 'efficient_vit': # Using state-of-the-art pretrained models from timm # Options: vit_large_patch16_224, vit_huge_patch14_224, deit_base_patch16_224 return EfficientViT(model_name='vit_large_patch16_224', num_classes=num_classes, pretrained=pretrained) elif model_type == 'dual_attention_vit': return DualAttentionViT(num_classes=num_classes) else: raise ValueError(f"Unknown model type: {model_type}") # Model configurations for different use cases MODEL_CONFIGS = { 'small': { 
if __name__ == "__main__":
    # Smoke-test the models
    print("Testing Vision Transformer models...")

    # Test input
    x = torch.randn(2, 3, 224, 224)

    # Test the different variants
    models_to_test = ['vit', 'hybrid_vit', 'efficient_vit', 'dual_attention_vit']

    for model_type in models_to_test:
        print(f"\nTesting {model_type}...")
        model = get_model(model_type, num_classes=2, pretrained=False)
        model.eval()
        with torch.no_grad():  # no gradients needed for a forward-pass shape check
            output = model(x)
        print(f" Input shape: {x.shape}")
        print(f" Output shape: {output.shape}")
        print(f" Parameters: {sum(p.numel() for p in model.parameters()):,}")
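# Hedged usage sketch (not part of the original module): single-image inference with the
# ImageNet normalization typically expected by pretrained torchvision/timm backbones.
# The function name `predict_defect_probability` is hypothetical; adjust the preprocessing
# if your training pipeline differs.
def predict_defect_probability(model, image):
    """Return class probabilities for one PIL image. Illustrative only."""
    from torchvision import transforms  # local import to leave the module-level imports unchanged
    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=[0.485, 0.456, 0.406],
                             std=[0.229, 0.224, 0.225]),
    ])
    model.eval()
    with torch.no_grad():
        batch = preprocess(image).unsqueeze(0)       # (1, 3, 224, 224)
        logits = model(batch)
        return F.softmax(logits, dim=-1).squeeze(0)  # probabilities over num_classes outputs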