import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import os
# Model configuration
MODEL_NAME = "Salesforce/codet5-base"
MAX_LENGTH = 128
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3
TRAIN_SIZE = 5000
VAL_SIZE = 500
CHECKPOINT_DIR = "./codet5-sql-finetuned"
def preprocess(example):
    question = example["question"]
    table_headers = ", ".join(example["table"]["header"])
    sql_query = example["sql"]["human_readable"]
    return {
        "input_text": f"### Table columns:\n{table_headers}\n### Question:\n{question}\n### SQL:",
        "target_text": sql_query
    }
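# Illustrative example (not taken from the dataset): for a table with columns
# "Player, Team, Goals" and the question "How many goals did Smith score?",
# preprocess() would return
#   input_text  = "### Table columns:\nPlayer, Team, Goals\n### Question:\nHow many goals did Smith score?\n### SQL:"
#   target_text = the human-readable SQL string stored in the WikiSQL example.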
def main():
    # Set up device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load and preprocess dataset
    print("Loading dataset...")
    try:
        dataset = load_dataset("wikisql")
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        print("Trying with trust_remote_code=True...")
        dataset = load_dataset("wikisql", trust_remote_code=True)

    train_dataset = dataset["train"].select(range(TRAIN_SIZE))
    val_dataset = dataset["validation"].select(range(VAL_SIZE))

    print("Preprocessing datasets...")
    processed_train = train_dataset.map(preprocess, remove_columns=train_dataset.column_names)
    processed_val = val_dataset.map(preprocess, remove_columns=val_dataset.column_names)

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    # Add LoRA adapters
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
        target_modules=["q", "v", "k", "o", "wi", "wo"]
    )
    model = get_peft_model(model, lora_config)
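    # Optional sanity check: PEFT can report how many parameters LoRA actually
    # trains versus the frozen base model (typically well under 1% of the total).
    model.print_trainable_parameters()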
    def tokenize_function(examples):
        # Tokenize prompts and targets; padding to MAX_LENGTH keeps every example
        # the same length, so no return_tensors is needed inside a batched map.
        inputs = tokenizer(
            examples["input_text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )
        targets = tokenizer(
            examples["target_text"],
            padding="max_length",
            truncation=True,
            max_length=MAX_LENGTH,
        )
        # Mask padding in the labels with -100 so it is ignored by the loss
        inputs["labels"] = [
            [(token if token != tokenizer.pad_token_id else -100) for token in label]
            for label in targets["input_ids"]
        ]
        return inputs

    print("Tokenizing datasets...")
    tokenized_train = processed_train.map(
        tokenize_function,
        remove_columns=processed_train.column_names,
        batched=True
    )
    tokenized_val = processed_val.map(
        tokenize_function,
        remove_columns=processed_val.column_names,
        batched=True
    )
    # Training arguments - simplified for stability
    training_args = Seq2SeqTrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=EVAL_BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        logging_dir=os.path.join(CHECKPOINT_DIR, "logs"),
        logging_steps=10,
        save_strategy="epoch",  # save a checkpoint after each epoch, as announced below
        save_total_limit=2,
        predict_with_generate=True,
        no_cuda=True,     # Force CPU training
        fp16=False,       # Disable mixed precision training since we're on CPU
        report_to="none"  # Disable wandb logging
    )
    # Data collator
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding=True
    )
    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        data_collator=data_collator,
    )
    try:
        print("\nStarting training...")
        print("You can stop training at any time by pressing Ctrl+C")
        print("Training will automatically save checkpoints after each epoch")

        # Check for existing checkpoints
        last_checkpoint = None
        if os.path.exists(CHECKPOINT_DIR):
            checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
            if checkpoints:
                last_checkpoint = os.path.join(CHECKPOINT_DIR, sorted(checkpoints, key=lambda x: int(x.split('-')[1]))[-1])
                print(f"\nFound checkpoint: {last_checkpoint}")
                print("Training will resume from this checkpoint.")

        # Start or resume training
        trainer.train(resume_from_checkpoint=last_checkpoint)

        # Save the final model
        trainer.save_model("./final-model")
        print("\nTraining completed successfully!")
        print("Final model saved to: ./final-model")
    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")
        print("Progress is saved in the latest checkpoint.")
        print("To resume, just run the script again.")
    except Exception as e:
        print(f"\nAn error occurred during training: {str(e)}")
        if os.path.exists(CHECKPOINT_DIR):
            error_checkpoint = os.path.join(CHECKPOINT_DIR, "checkpoint-error")
            trainer.save_model(error_checkpoint)
            print(f"Saved error checkpoint to: {error_checkpoint}")
if __name__ == "__main__":
    main()