"""Fine-tune CodeT5-base on a WikiSQL subset with LoRA adapters (CPU, resumable)."""

import os

import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoTokenizer,
    DataCollatorForSeq2Seq,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
)

# Model / training configuration.
MODEL_NAME = "Salesforce/codet5-base"
MAX_LENGTH = 128
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3
TRAIN_SIZE = 5000
VAL_SIZE = 500
CHECKPOINT_DIR = "./codet5-sql-finetuned"


def preprocess(example):
    """Turn one WikiSQL record into an ``input_text``/``target_text`` pair.

    Assumes the WikiSQL schema: ``example["table"]["header"]`` is a list of
    column names and ``example["sql"]["human_readable"]`` is the target query.
    """
    question = example["question"]
    table_headers = ", ".join(example["table"]["header"])
    sql_query = example["sql"]["human_readable"]
    return {
        "input_text": f"### Table columns:\n{table_headers}\n### Question:\n{question}\n### SQL:",
        "target_text": sql_query,
    }


def _load_wikisql():
    """Load WikiSQL, retrying with ``trust_remote_code=True`` for newer datasets versions."""
    try:
        return load_dataset("wikisql")
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        print("Trying with trust_remote_code=True...")
        return load_dataset("wikisql", trust_remote_code=True)


def _find_last_checkpoint(checkpoint_dir):
    """Return the path of the highest-numbered ``checkpoint-N`` dir, or None.

    Only numeric suffixes count: the error handler below saves a
    ``checkpoint-error`` directory, and naively parsing its suffix with
    ``int()`` would crash the next resume attempt.
    """
    if not os.path.exists(checkpoint_dir):
        return None
    numbered = []
    for name in os.listdir(checkpoint_dir):
        prefix, _, suffix = name.partition("-")
        if prefix == "checkpoint" and suffix.isdigit():
            numbered.append((int(suffix), name))
    if not numbered:
        return None
    return os.path.join(checkpoint_dir, max(numbered)[1])


def main():
    """Run the full fine-tuning pipeline: data, model, LoRA, training, saving."""
    # Informational only -- training below is pinned to CPU via no_cuda=True.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    print("Loading dataset...")
    dataset = _load_wikisql()

    # Clamp subset sizes so a smaller-than-expected split cannot crash .select().
    train_dataset = dataset["train"].select(
        range(min(TRAIN_SIZE, len(dataset["train"])))
    )
    val_dataset = dataset["validation"].select(
        range(min(VAL_SIZE, len(dataset["validation"])))
    )

    print("Preprocessing datasets...")
    processed_train = train_dataset.map(
        preprocess, remove_columns=train_dataset.column_names
    )
    processed_val = val_dataset.map(
        preprocess, remove_columns=val_dataset.column_names
    )

    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

    # LoRA adapters on CodeT5's attention (q/k/v/o) and feed-forward (wi/wo)
    # projections; only these low-rank matrices are trained.
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
        target_modules=["q", "v", "k", "o", "wi", "wo"],
    )
    model = get_peft_model(model, lora_config)

    def tokenize_function(examples):
        # `text_target` routes the SQL strings through the target-side
        # tokenizer and stores them as `labels`. No padding here: the
        # DataCollatorForSeq2Seq pads each batch dynamically and, crucially,
        # pads labels with -100 so padded positions are ignored by the loss.
        # (The previous max_length padding left real pad-token ids in the
        # labels, training the model to emit padding.)
        return tokenizer(
            examples["input_text"],
            text_target=examples["target_text"],
            truncation=True,
            max_length=MAX_LENGTH,
        )

    print("Tokenizing datasets...")
    tokenized_train = processed_train.map(
        tokenize_function,
        remove_columns=processed_train.column_names,
        batched=True,
    )
    tokenized_val = processed_val.map(
        tokenize_function,
        remove_columns=processed_val.column_names,
        batched=True,
    )

    # Training arguments - simplified for stability
    training_args = Seq2SeqTrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=EVAL_BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        logging_dir=os.path.join(CHECKPOINT_DIR, "logs"),
        logging_steps=10,
        save_total_limit=2,
        predict_with_generate=True,
        no_cuda=True,  # Force CPU training
        fp16=False,  # Disable mixed precision training since we're on CPU
        report_to="none",  # Disable wandb logging
    )

    # Pads inputs with the pad token and labels with -100 (ignored by the
    # loss), per batch, so no global max_length padding is needed above.
    data_collator = DataCollatorForSeq2Seq(tokenizer, model=model, padding=True)

    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        data_collator=data_collator,
    )

    try:
        print("\nStarting training...")
        print("You can stop training at any time by pressing Ctrl+C")
        print("Training will automatically save checkpoints after each epoch")

        last_checkpoint = _find_last_checkpoint(CHECKPOINT_DIR)
        if last_checkpoint is not None:
            print(f"\nFound checkpoint: {last_checkpoint}")
            print("Training will resume from this checkpoint.")

        # Start or resume training.
        trainer.train(resume_from_checkpoint=last_checkpoint)

        # Save the final model (LoRA adapter weights).
        trainer.save_model("./final-model")
        print("\nTraining completed successfully!")
        print("Final model saved to: ./final-model")
    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")
        print("Progress is saved in the latest checkpoint.")
        print("To resume, just run the script again.")
    except Exception as e:
        print(f"\nAn error occurred during training: {str(e)}")
        if os.path.exists(CHECKPOINT_DIR):
            # Best-effort snapshot; _find_last_checkpoint deliberately skips
            # this non-numeric directory when resuming.
            error_checkpoint = os.path.join(CHECKPOINT_DIR, "checkpoint-error")
            trainer.save_model(error_checkpoint)
            print(f"Saved error checkpoint to: {error_checkpoint}")


if __name__ == "__main__":
    main()