# app.py - CORRECTED VERSION
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer, RobertaModel
import uvicorn
from contextlib import asynccontextmanager
# Global variables
model = None
tokenizer = None
device = None
# THIS MUST MATCH YOUR TRAINING CODE EXACTLY
class CodeBERTClassifier(torch.nn.Module):
    def __init__(self, num_labels=4, dropout=0.3, hidden_size=256):
        super(CodeBERTClassifier, self).__init__()

        # Load pre-trained CodeBERT
        self.codebert = RobertaModel.from_pretrained('microsoft/codebert-base')

        # Dropout
        self.dropout = torch.nn.Dropout(dropout)

        # Multi-layer feedforward network - MUST MATCH TRAINING
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, hidden_size),               # 768 -> 256
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, hidden_size // 2),  # 256 -> 128
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size // 2, num_labels)    # 128 -> 4
        )

    def forward(self, input_ids, attention_mask):
        # Get CodeBERT embeddings
        outputs = self.codebert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )

        # CRITICAL: Use [CLS] token from last_hidden_state (matching training)
        pooled_output = outputs.last_hidden_state[:, 0, :]

        # Apply dropout
        pooled_output = self.dropout(pooled_output)

        # Classification
        logits = self.classifier(pooled_output)
        return logits
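
# Quick shape sanity check (illustrative only, run in a separate session,
# not at import time):
#   m = CodeBERTClassifier()
#   ids = torch.randint(0, 50265, (1, 512))      # 50265 = roberta-base vocab size
#   mask = torch.ones(1, 512, dtype=torch.long)
#   m(ids, mask).shape                           # -> torch.Size([1, 4])
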
# Use lifespan instead of deprecated on_event
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    global model, tokenizer, device

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # Load your trained model with EXACT same architecture
    model = CodeBERTClassifier(num_labels=4, dropout=0.3, hidden_size=256)
    model.load_state_dict(torch.load('best_codebert_model.pt', map_location=device))
    model.to(device)
    model.eval()

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    print(f"Model loaded successfully on {device}")

    yield

    # Shutdown
    print("Shutting down...")
app = FastAPI(title="Vulnerability Detection API", lifespan=lifespan)
class CodeRequest(BaseModel):
    code: str
    max_length: int = 512

class VulnerabilityResponse(BaseModel):
    vulnerability_type: str
    confidence: float
    is_vulnerable: bool
    label: str
@app.post("/detect", response_model=VulnerabilityResponse)
async def detect_vulnerability(request: CodeRequest):
    try:
        # Tokenize
        encoding = tokenizer(
            request.code,
            padding='max_length',
            truncation=True,
            max_length=request.max_length,
            return_tensors='pt'
        )

        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        # Predict
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)

        # Label mapping - VERIFY THIS MATCHES YOUR TRAINING
        label_map = {0: 's0', 1: 'v0', 2: 's1', 3: 'v1'}
        vuln_type_map = {
            's0': 'SQL Injection',
            'v0': 'Certificate Validation',
            's1': 'SQL Injection',
            'v1': 'Certificate Validation'
        }

        label = label_map[predicted.item()]
        is_vulnerable = label in ['s0', 'v0']

        return VulnerabilityResponse(
            vulnerability_type=vuln_type_map[label],
            confidence=float(confidence.item()),
            is_vulnerable=is_vulnerable,
            label=label
        )
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
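
# Example request (illustrative; assumes the server is running on localhost:8000):
#   curl -X POST http://localhost:8000/detect \
#        -H "Content-Type: application/json" \
#        -d '{"code": "cursor.execute(\"SELECT * FROM users WHERE id=\" + uid)"}'
# Example response body (illustrative):
#   {"vulnerability_type": "SQL Injection", "confidence": 0.97,
#    "is_vulnerable": true, "label": "s0"}
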
@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model is not None}

@app.get("/")
async def root():
    return {
        "message": "Vulnerability Detection API",
        "endpoints": ["/detect", "/health", "/docs"]
    }
if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)