"""FastAPI service exposing a CodeBERT-based code vulnerability classifier."""

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
import torch
from transformers import RobertaTokenizer, RobertaModel
import uvicorn
from contextlib import asynccontextmanager
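
# Globals populated once at startup by the lifespan handler below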
model = None
tokenizer = None
device = None


class CodeBERTClassifier(torch.nn.Module):
    """CodeBERT encoder with a small MLP classification head."""

    def __init__(self, num_labels=4, dropout=0.3, hidden_size=256):
        super().__init__()
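        # Pretrained CodeBERT encoder (RoBERTa architecture, 768-dim hidden size)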
        self.codebert = RobertaModel.from_pretrained('microsoft/codebert-base')
        self.dropout = torch.nn.Dropout(dropout)
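        # MLP head: 768 -> hidden_size -> hidden_size // 2 -> num_labels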
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(768, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size, hidden_size // 2),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout),
            torch.nn.Linear(hidden_size // 2, num_labels)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.codebert(
            input_ids=input_ids,
            attention_mask=attention_mask
        )
        # Take the first token (<s>, RoBERTa's [CLS] equivalent) as the pooled representation
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)
        return logits


@asynccontextmanager
async def lifespan(app: FastAPI):
    global model, tokenizer, device
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
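
    # Rebuild the architecture with the same hyperparameters used at training
    # time (load_state_dict requires a matching module structure), then restore
    # the fine-tuned checkpoint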
    model = CodeBERTClassifier(num_labels=4, dropout=0.3, hidden_size=256)
    model.load_state_dict(torch.load('best_codebert_model.pt', map_location=device))
    model.to(device)
    model.eval()

    tokenizer = RobertaTokenizer.from_pretrained('microsoft/codebert-base')
    print(f"Model loaded successfully on {device}")

    yield

    print("Shutting down...")


app = FastAPI(title="Vulnerability Detection API", lifespan=lifespan)


class CodeRequest(BaseModel):
    code: str
    max_length: int = 512


class VulnerabilityResponse(BaseModel):
    vulnerability_type: str
    confidence: float
    is_vulnerable: bool
    label: str


@app.post("/detect", response_model=VulnerabilityResponse)
async def detect_vulnerability(request: CodeRequest):
    try:
        # Tokenize the submitted source code, padding/truncating to max_length
        encoding = tokenizer(
            request.code,
            padding='max_length',
            truncation=True,
            max_length=request.max_length,
            return_tensors='pt'
        )
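
        # Move inputs to the same device as the model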
        input_ids = encoding['input_ids'].to(device)
        attention_mask = encoding['attention_mask'].to(device)

        # Run inference without tracking gradients
        with torch.no_grad():
            logits = model(input_ids, attention_mask)
            probs = torch.softmax(logits, dim=1)
            confidence, predicted = torch.max(probs, 1)
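
        # Map the predicted class index to the dataset-specific label codes:
        # the 's*' classes belong to the SQL injection task and the 'v*'
        # classes to the certificate validation task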
        label_map = {0: 's0', 1: 'v0', 2: 's1', 3: 'v1'}
        vuln_type_map = {
            's0': 'SQL Injection',
            'v0': 'Certificate Validation',
            's1': 'SQL Injection',
            'v1': 'Certificate Validation'
        }

        label = label_map[predicted.item()]
        # The '0'-suffixed codes are treated as the vulnerable variants here
        is_vulnerable = label in ['s0', 'v0']

        return VulnerabilityResponse(
            vulnerability_type=vuln_type_map[label],
            confidence=float(confidence.item()),
            is_vulnerable=is_vulnerable,
            label=label
        )

    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


@app.get("/health")
async def health_check():
    return {"status": "healthy", "model_loaded": model is not None}


@app.get("/")
async def root():
    return {
        "message": "Vulnerability Detection API",
        "endpoints": ["/detect", "/health", "/docs"]
    }


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=8000)
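
# Example request once the server is running (hypothetical input snippet;
# endpoint and port match the defaults above):
#   curl -X POST http://localhost:8000/detect \
#        -H "Content-Type: application/json" \
#        -d '{"code": "cursor.execute(\"SELECT * FROM users WHERE id = \" + user_id)"}'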