Transformers from Scratch: Complete Implementation
A comprehensive PyTorch implementation of the Transformer architecture from "Attention Is All You Need", featuring detailed mathematical foundations, educational content, and practical text classification applications.
Model Description
This repository contains a complete, from-scratch implementation of the Transformer architecture. The model demonstrates the core concepts behind modern NLP systems like BERT, GPT, and ChatGPT through a practical sentiment analysis task. This implementation serves as both a working model and an educational resource for understanding the revolutionary attention mechanism.
Architecture Details
- Model Type: Transformer Encoder for Text Classification
- Framework: PyTorch
- Task: Binary sentiment classification (positive/negative movie reviews)
- Model Dimension: 128
- Attention Heads: 8
- Layers: 4 Transformer blocks
- Feed-Forward Dimension: 256
- Total Parameters: ~200K
- Vocabulary Size: Dynamic (built from training data)
Key Components
- Multi-Head Attention: Core mechanism allowing parallel processing of sequences
- Positional Encoding: Sine/cosine embeddings to inject position information
- Transformer Blocks: Attention + feed-forward with residual connections
- Layer Normalization: Stabilizes training and improves convergence
- Classification Head: Global average pooling + linear layer for predictions
Mathematical Foundation
Scaled Dot-Product Attention
Attention(Q, K, V) = softmax(QK^T / √d_k)V
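In code, this formula could be implemented roughly as follows; a minimal sketch, where the function name, signature, and mask handling are illustrative rather than the notebook's exact API:

import math
import torch

def scaled_dot_product_attention(q, k, v, mask=None):
    # q, k, v: (batch, heads, seq_len, d_k)
    d_k = q.size(-1)
    scores = q @ k.transpose(-2, -1) / math.sqrt(d_k)      # (batch, heads, seq, seq)
    if mask is not None:
        scores = scores.masked_fill(mask == 0, float('-inf'))
    weights = torch.softmax(scores, dim=-1)                 # attention distribution
    return weights @ v, weights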
Multi-Head Attention
MultiHead(Q, K, V) = Concat(head_1, ..., head_h)W^O
head_i = Attention(QW_i^Q, KW_i^K, VW_i^V)
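A multi-head attention module built on the scaled_dot_product_attention sketch above might look like this; the notebook's version may differ in details such as dropout or bias handling:

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_model, num_heads):
        super().__init__()
        assert d_model % num_heads == 0
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.w_q = nn.Linear(d_model, d_model)
        self.w_k = nn.Linear(d_model, d_model)
        self.w_v = nn.Linear(d_model, d_model)
        self.w_o = nn.Linear(d_model, d_model)

    def forward(self, x, mask=None):
        batch, seq_len, _ = x.shape
        # Project and reshape into heads: (batch, heads, seq_len, d_k)
        def split(t):
            return t.view(batch, seq_len, self.num_heads, self.d_k).transpose(1, 2)
        q, k, v = split(self.w_q(x)), split(self.w_k(x)), split(self.w_v(x))
        out, weights = scaled_dot_product_attention(q, k, v, mask)
        # Concatenate heads and apply the output projection W^O
        out = out.transpose(1, 2).contiguous().view(batch, seq_len, -1)
        return self.w_o(out)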
Positional Encoding
PE(pos, 2i) = sin(pos/10000^(2i/d_model))
PE(pos, 2i+1) = cos(pos/10000^(2i/d_model))
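The two remaining building blocks referenced by the Quick Start code below, PositionalEncoding and TransformerBlock, could be sketched as follows; this is a hedged approximation of the notebook's classes, reusing MultiHeadAttention from the previous sketch:

class PositionalEncoding(nn.Module):
    def __init__(self, d_model, max_len):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)   # even dimensions: sine
        pe[:, 1::2] = torch.cos(position * div_term)   # odd dimensions: cosine
        self.register_buffer('pe', pe.unsqueeze(0))    # (1, max_len, d_model)

    def forward(self, x):
        return x + self.pe[:, :x.size(1)]

class TransformerBlock(nn.Module):
    def __init__(self, d_model, num_heads, d_ff, dropout=0.1):
        super().__init__()
        self.attention = MultiHeadAttention(d_model, num_heads)
        self.ff = nn.Sequential(nn.Linear(d_model, d_ff), nn.ReLU(), nn.Linear(d_ff, d_model))
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Attention sub-layer with residual connection and layer normalization
        x = self.norm1(x + self.dropout(self.attention(x, mask)))
        # Feed-forward sub-layer with residual connection and layer normalization
        x = self.norm2(x + self.dropout(self.ff(x)))
        return x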
Training Details
- Dataset: Synthetic movie reviews (positive/negative sentiment)
- Optimizer: AdamW with weight decay (0.01)
- Learning Rate: 0.0001 with cosine annealing
- Batch Size: 16
- Max Sequence Length: 24 tokens
- Training Epochs: 30
- Hardware: Optimized for Apple M4 and CUDA GPUs
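A training setup matching these hyperparameters might look roughly like this; train_loader is a placeholder DataLoader name, and model refers to the TransformerClassifier shown in the Usage section below:

import torch
import torch.nn as nn

device = torch.device('mps' if torch.backends.mps.is_available()
                      else 'cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)  # TransformerClassifier from the Usage section

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=30)

for epoch in range(30):
    model.train()
    for inputs, labels in train_loader:          # batches of 16 token-id tensors
        inputs, labels = inputs.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(inputs), labels)
        loss.backward()
        optimizer.step()
    scheduler.step()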
Model Performance
Metrics
- Test Accuracy: 85%+
- Training Time: ~10 minutes on Apple M4
- Model Size: 200K parameters
- Convergence: Stable training without overfitting
Capabilities
- ✅ Binary sentiment classification
- ✅ Attention weight visualization
- ✅ Fast inference on modern hardware
- ✅ Educational transparency
- ✅ Easily extensible architecture
Usage
Quick Start
import torch
import torch.nn as nn
import math

# Load the complete implementation (from notebook)
class TransformerClassifier(nn.Module):
    def __init__(self, vocab_size, d_model, num_heads, num_layers, d_ff, max_len, num_classes):
        super().__init__()
        self.d_model = d_model
        self.embedding = nn.Embedding(vocab_size, d_model)
        self.pos_encoding = PositionalEncoding(d_model, max_len)
        self.transformer_blocks = nn.ModuleList([
            TransformerBlock(d_model, num_heads, d_ff)
            for _ in range(num_layers)
        ])
        self.norm = nn.LayerNorm(d_model)
        self.classifier = nn.Linear(d_model, num_classes)

    def forward(self, x):
        # Embedding + positional encoding
        x = self.embedding(x) * math.sqrt(self.d_model)
        x = self.pos_encoding(x)
        # Transformer blocks
        for transformer in self.transformer_blocks:
            x = transformer(x)
        # Classification
        x = self.norm(x)
        x = x.mean(dim=1)  # Global average pooling
        return self.classifier(x)

# Load trained model
model = TransformerClassifier(
    vocab_size=vocab_size,
    d_model=128,
    num_heads=8,
    num_layers=4,
    d_ff=256,
    max_len=24,
    num_classes=2
)
model.load_state_dict(torch.load('best_transformer_model.pth'))
model.eval()

# Example inference
def predict_sentiment(text, model, vocab_to_idx, max_length=24):
    tokens = tokenize_text(text, vocab_to_idx, max_length)
    with torch.no_grad():
        output = model(tokens.unsqueeze(0))
        prediction = torch.softmax(output, dim=1)
    return "Positive" if prediction[0][1] > 0.5 else "Negative"

# Test the model
result = predict_sentiment("This movie was absolutely fantastic!", model, vocab_to_idx)
print(f"Sentiment: {result}")
Advanced Usage
# Visualize attention weights
def visualize_attention(model, text, vocab_to_idx):
    # Extract attention weights from each layer
    # Create heatmaps showing what the model focuses on
    pass

# Fine-tune on new data
def fine_tune_model(model, new_data_loader, epochs=5):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)
    criterion = nn.CrossEntropyLoss()
    # Continue training on domain-specific data
    model.train()
    for _ in range(epochs):
        for inputs, labels in new_data_loader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            optimizer.step()
Visualizations and Analysis
- Training Curves: Loss and accuracy evolution over epochs
- Attention Heatmaps: Visualize what the model pays attention to
- Performance Metrics: Precision, recall, F1-score breakdowns
- Architecture Diagrams: Component-wise model visualization
- Error Analysis: Common failure cases and model limitations
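As an illustration of the attention-heatmap idea, a small matplotlib helper can render one head's weights over the input tokens; this assumes you already have the attention matrix for a chosen layer and head (for example, returned by the attention sketch shown earlier):

import matplotlib.pyplot as plt

def plot_attention_heatmap(attn_weights, tokens, title="Attention weights"):
    # attn_weights: (seq_len, seq_len) matrix for a single layer/head
    fig, ax = plt.subplots(figsize=(6, 5))
    im = ax.imshow(attn_weights, cmap='viridis')
    ax.set_xticks(range(len(tokens)))
    ax.set_xticklabels(tokens, rotation=90)
    ax.set_yticks(range(len(tokens)))
    ax.set_yticklabels(tokens)
    ax.set_title(title)
    fig.colorbar(im, ax=ax)
    plt.tight_layout()
    plt.show()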
Files and Outputs
- Transformers.ipynb: Complete implementation with educational content
- best_transformer_model.pth: Trained model weights
- m4_transformer_results.png: Training curves and performance metrics
- Architecture visualization and attention weight examples
Educational Value
This implementation is designed as a comprehensive learning resource featuring:
Mathematical Understanding
- Complete Derivations: From attention theory to implementation
- Step-by-Step Breakdown: Each component explained individually
- Visual Mathematics: Attention visualizations and formula explanations
- Practical Examples: Concrete numerical calculations
Implementation Insights
- Clean Code Architecture: Modular, readable, and well-documented
- Best Practices: Modern PyTorch patterns and techniques
- Performance Optimization: Efficient training and inference
- Debugging Techniques: How to monitor and improve training
Real-World Applications
- End-to-End Pipeline: From raw text to predictions
- Production Considerations: Model deployment and optimization
- Extension Examples: How to adapt for different tasks
- Transfer Learning: Building on pre-trained representations
Applications
This Transformer implementation can be adapted for:
Text Classification Tasks
- Sentiment Analysis: Movie reviews, product feedback, social media
- Topic Classification: News categorization, document organization
- Spam Detection: Email filtering, content moderation
- Intent Recognition: Chatbot understanding, voice assistants
Sequence Processing
- Named Entity Recognition: Extract people, places, organizations
- Part-of-Speech Tagging: Grammatical analysis
- Text Similarity: Document matching, plagiarism detection
- Feature Extraction: Dense representations for downstream tasks
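For the feature-extraction use case, the trained encoder can be reused to produce pooled sentence embeddings by stopping before the classification layer; a sketch based on the TransformerClassifier and the hypothetical tokenize_text helper shown earlier:

@torch.no_grad()
def encode_text(text, model, vocab_to_idx, max_length=24):
    # Run the encoder stack but skip the final classification layer
    x = tokenize_text(text, vocab_to_idx, max_length).unsqueeze(0)
    x = model.embedding(x) * math.sqrt(model.d_model)
    x = model.pos_encoding(x)
    for block in model.transformer_blocks:
        x = block(x)
    return model.norm(x).mean(dim=1)   # (1, d_model) dense sentence embedding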
Research and Development
- Architecture Experiments: Test new attention mechanisms
- Ablation Studies: Understand component contributions
- Scaling Experiments: Larger models and datasets
- Novel Applications: Domain-specific adaptations
Comparison with Other Architectures
Advantages over RNNs
- ✅ Parallel Processing: Much faster training and inference
- ✅ Long-Range Dependencies: Better handling of distant relationships
- ✅ Scalability: Efficient on modern hardware
- ✅ Interpretability: Attention weights provide insights
Advantages over CNNs
- ✅ Sequence Modeling: Natural fit for text and time series
- ✅ Variable Length: Handle sequences of any length
- ✅ Global Context: Attend to entire sequence simultaneously
- ✅ Position Awareness: Explicit positional information
Educational Benefits
- Foundation Understanding: Core concepts behind modern NLP
- Mathematical Clarity: Clean mathematical formulations
- Implementation Practice: Hands-on coding experience
- Research Preparation: Basis for advanced architectures
Citation
If you use this implementation in your research or projects, please cite:
@misc{transformers_from_scratch_2024,
  title={Transformers from Scratch: Complete Implementation},
  author={Gruhesh Kurra},
  year={2024},
  url={https://huggingface.co/karthik-2905/TransformersFromScratch}
}
Future Extensions
Planned improvements and research directions:
- Encoder-Decoder Architecture: Full sequence-to-sequence implementation
- Pre-training Pipeline: Large-scale language model training
- Alternative Attention: Sparse, local, and linear attention variants
- Vision Transformers: Adapt architecture for image tasks
- Multimodal Transformers: Text, image, and audio processing
- Scientific Applications: Protein sequences, molecular modeling
License
This project is licensed under the MIT License - see the LICENSE file for details.
Additional Resources
- GitHub Repository: TransformersFromScratch
- Original Paper: "Attention Is All You Need" by Vaswani et al.
- Educational Content: Complete mathematical derivations and examples
- Performance Benchmarks: Detailed analysis and comparisons
Model Card Authors
Gruhesh Kurra - Implementation, documentation, and educational content
Tags: transformers, attention, pytorch, nlp, text-classification, educational
Model Card Last Updated: December 2024