Instructions to use mangsense/codet5-java-vulnerability-lora with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- PEFT
How to use mangsense/codet5-java-vulnerability-lora with PEFT:
```python
from peft import PeftModel
from transformers import AutoModelForSequenceClassification

base_model = AutoModelForSequenceClassification.from_pretrained("Salesforce/codet5-small")
model = PeftModel.from_pretrained(base_model, "mangsense/codet5-java-vulnerability-lora")
```
- Transformers
How to use mangsense/codet5-java-vulnerability-lora with Transformers:
```python
# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("mangsense/codet5-java-vulnerability-lora", dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| import os | |
| import torch | |
| from transformers import AutoTokenizer, T5ForSequenceClassification | |
| from typing import Dict, List, Any | |
class EndpointHandler:
    """Hugging Face Inference Endpoint handler for Java vulnerability detection.

    Wraps a CodeT5-based sequence-classification model (LoRA fine-tuned) and
    exposes it through ``__call__``, which runs three stages in order:
    ``preprocess`` -> ``inference`` -> ``postprocess``.
    """

    # Human-readable names for the class indices; any other index falls back
    # to the generic "LABEL_<n>" name in ``postprocess``.
    _LABEL_NAMES = {
        0: "safe",
        1: "vulnerable",
    }

    def __init__(self, path="."):
        """Load the tokenizer and the classification model.

        Args:
            path (str): Directory (or model id) the tokenizer and model are
                loaded from.
        """
        print(f"๐ Loading Java Vulnerability Detection Model from {path}")

        # Pick the device once; every tensor is moved to it in ``preprocess``.
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"๐ Device: {self.device}")

        self.tokenizer = AutoTokenizer.from_pretrained(path)

        # Half precision on GPU, full precision on CPU.
        self.model = T5ForSequenceClassification.from_pretrained(
            path,
            torch_dtype=torch.float16 if self.device == "cuda" else torch.float32,
        )

        # Move to the target device and switch to evaluation mode.
        self.model.to(self.device)
        self.model.eval()
        print("โ Model loaded successfully!")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Run the full prediction pipeline for one request.

        Args:
            data (dict): Request payload carrying the Java source under the
                ``"inputs"`` key or, alternatively, the ``"code"`` key.

        Returns:
            list: A one-element list holding the prediction dictionary built
            by ``postprocess``.

        Raises:
            ValueError: If the payload has an unsupported shape or carries no
                usable code string.
        """
        inputs = self.preprocess(data)
        logits = self.inference(inputs)
        return self.postprocess(logits)

    @staticmethod
    def _extract_code(request: Any) -> str:
        """Pull the Java source string out of a request payload."""
        # A non-empty list is reduced to its first element; that element may
        # itself be a payload dict or a bare code string.
        if isinstance(request, list) and len(request) > 0:
            request = request[0]

        if isinstance(request, dict):
            code = request.get("inputs") or request.get("code")
        elif isinstance(request, str):
            code = request
        else:
            raise ValueError(
                "Invalid request format. Expected {'inputs': 'Java code here'} "
                "or {'code': 'Java code here'}"
            )

        if not code:
            raise ValueError("No code provided in request")
        # Reject non-string payloads instead of silently formatting their
        # repr() into the prompt.
        if not isinstance(code, str):
            raise ValueError(
                f"Code must be a string, got {type(code).__name__}"
            )
        return code

    def preprocess(self, request: Dict[str, Any]) -> Dict[str, torch.Tensor]:
        """Turn a request payload into tokenized model inputs.

        Args:
            request: A dict with the code under ``"inputs"`` or ``"code"``, a
                bare code string, or a non-empty list whose first element is
                one of those (only the first element is used).

        Returns:
            dict: Tokenizer output tensors, moved to ``self.device``.

        Raises:
            ValueError: If the payload shape is unsupported, the code is
                missing or empty, or the code is not a string.
        """
        code = self._extract_code(request)

        # Prompt template the input is wrapped in before tokenization.
        input_text = f"Is this Java code vulnerable?:\n{code}"

        encoded = self.tokenizer(
            input_text,
            max_length=512,
            truncation=True,
            padding="max_length",
            return_tensors="pt",
        )

        return {name: tensor.to(self.device) for name, tensor in encoded.items()}

    def inference(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
        """Run the model forward pass without tracking gradients.

        Args:
            inputs (dict): Tokenized tensors as produced by ``preprocess``.

        Returns:
            torch.Tensor: The ``logits`` attribute of the model output.
        """
        with torch.no_grad():
            return self.model(**inputs).logits

    def postprocess(self, logits: torch.Tensor) -> List[Dict[str, Any]]:
        """Convert raw logits into a readable prediction.

        Args:
            logits (torch.Tensor): Model logits; the last dimension is the
                number of outputs. Only the first row is reported.

        Returns:
            list: A one-element list with a dict containing ``"label"``,
            ``"score"``, ``"probabilities"`` and ``"details"``.
        """
        # The model may run in float16; compute probabilities in float32 so
        # the reported scores are not rounded to half precision.
        logits = logits.float()

        if logits.shape[-1] == 1:
            # Single output unit: sigmoid gives the probability of class 1.
            prob = torch.sigmoid(logits).item()
            predicted_class = 1 if prob > 0.5 else 0
            confidence = prob if predicted_class == 1 else (1 - prob)
            probabilities = {
                "LABEL_0": 1 - prob,
                "LABEL_1": prob,
            }
        else:
            # One output unit per class: softmax over the class dimension.
            probs = torch.softmax(logits, dim=1)[0]
            predicted_class = torch.argmax(logits, dim=1).item()
            confidence = probs[predicted_class].item()
            probabilities = {
                f"LABEL_{index}": value
                for index, value in enumerate(probs.tolist())
            }

        result = {
            "label": self._LABEL_NAMES.get(
                predicted_class, f"LABEL_{predicted_class}"
            ),
            "score": confidence,
            "probabilities": probabilities,
            "details": {
                "is_vulnerable": predicted_class == 1,
                "confidence_percentage": f"{confidence * 100:.2f}%",
                "safe_probability": probabilities.get("LABEL_0", 0),
                "vulnerable_probability": probabilities.get("LABEL_1", 0),
            },
        }
        return [result]
# Local smoke test: only runs when this file is executed directly.
if __name__ == "__main__":
    # Sample Java snippet that builds a SQL query by string concatenation.
    test_code = """
import java.sql.*;
public class SQLInjectionVulnerable {
public void getUser(String userInput) {
String query = "SELECT * FROM users WHERE username = '" + userInput + "'";
Statement statement = connection.createStatement();
ResultSet resultSet = statement.executeQuery(query);
}
}
"""

    # Load the handler from the current directory and classify the sample.
    handler = EndpointHandler(path=".")
    result = handler({"inputs": test_code})

    print("\n๐ Test Result:", result, sep="\n")