|
--- |
|
license: other |
|
license_name: codet5-salesforce |
|
license_link: LICENSE |
|
language: |
|
- en |
|
base_model: |
|
- Salesforce/codet5-small |
|
pipeline_tag: text-generation |
|
tags: |
|
- code |
|
- commit |
|
- gitdiff |
|
datasets: |
|
- seniruk/git-diff_to_commit_msg_large |
|
--- |
|
|
|
## Fine-tuned from the Salesforce/codet5-small base model for 5 epochs on 1,000,000 rows of git commit data covering a variety of programming languages |
|
## Training took a total of 10 hours on an RTX 4060 Ti GPU (16 GB VRAM). |
|
## Use the instructions below for inference |
|
## Required modules: transformers, PyTorch, CUDA |
|
``` |
|
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM

# Load the CodeT5 tokenizer and model.
# NOTE(review): this id is the base checkpoint listed under `base_model` in the
# model card. To run the fine-tuned weights, point `model_name` at this model's
# own repository id or a local checkpoint directory -- confirm the correct id.
model_name = "Salesforce/codet5-small"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSeq2SeqLM.from_pretrained(model_name)

# Example Git diff input.
git_diff = """
diff --git a/example.py b/example.py
index 3b18e12..b3f7e54 100644
--- a/example.py
+++ b/example.py
@@ -1,5 +1,6 @@
-def greet():
-    print("Hello, world!")
+def greet_user(name):
+    print(f"Hello, {name}!")

-def farewell():
-    print("Goodbye!")
+def farewell_user(name):
+    print(f"Goodbye, {name}!")
"""

# Keep the instruction unchanged: the model was trained on this static
# instruction, so it must be reproduced exactly at inference time.
instruction = "Generate a commit message based on the following Git diff:\n\n"

task_input = instruction + git_diff

# Tokenize the input, truncating anything beyond 512 tokens.
inputs = tokenizer(
    task_input,
    max_length=512,  # Truncate if necessary
    truncation=True,
    padding="max_length",
    return_tensors="pt",
)

# Generate the commit message with beam search.
# The attention mask is passed explicitly because the input is padded to
# max_length; the padding positions should not be attended to.
# `temperature` and `top_p` are intentionally not passed: they only take effect
# when `do_sample=True`, so under plain beam search they were ignored. To get
# sampled output instead, add `do_sample=True` together with those arguments.
outputs = model.generate(
    inputs["input_ids"],
    attention_mask=inputs["attention_mask"],
    max_length=50,
    num_beams=5,  # Use beam search
    early_stopping=True,
)

# Decode the best beam into text, dropping special tokens such as padding.
commit_message = tokenizer.decode(outputs[0], skip_special_tokens=True)

# Print the result.
print("Generated Commit Message:")
print(commit_message)
|
``` |