import torch
from transformers import (
    AutoTokenizer,
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments,
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq,
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import os
# Model Configuration
MODEL_NAME = "Salesforce/codet5-base"
MAX_LENGTH = 128
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3
TRAIN_SIZE = 5000
VAL_SIZE = 500
CHECKPOINT_DIR = "./codet5-sql-finetuned"
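
# Note: the batch size of 2 and the 5,000/500 train/val subsets keep this run
# tractable on CPU (training is forced onto the CPU via no_cuda below)
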
def preprocess(example):
    # Flatten one WikiSQL record into an input prompt and a target SQL string
    question = example["question"]
    table_headers = ", ".join(example["table"]["header"])
    sql_query = example["sql"]["human_readable"]
    return {
        "input_text": f"### Table columns:\n{table_headers}\n### Question:\n{question}\n### SQL:",
        "target_text": sql_query,
    }
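
# Illustrative example of the prompt format produced above (made-up values,
# not drawn from the dataset):
#   input_text:  "### Table columns:\nPlayer, No., Position\n### Question:\nWho wears number 42?\n### SQL:"
#   target_text: "SELECT Player FROM table WHERE No. = 42"
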
def main():
    # Device report only: training is forced onto the CPU below via no_cuda=True
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Load and preprocess dataset
    print("Loading dataset...")
    try:
        dataset = load_dataset("wikisql")
    except Exception as e:
        print(f"Error loading dataset: {e}")
        print("Retrying with trust_remote_code=True...")
        dataset = load_dataset("wikisql", trust_remote_code=True)

    train_dataset = dataset["train"].select(range(TRAIN_SIZE))
    val_dataset = dataset["validation"].select(range(VAL_SIZE))

    print("Preprocessing datasets...")
    processed_train = train_dataset.map(preprocess, remove_columns=train_dataset.column_names)
    processed_val = val_dataset.map(preprocess, remove_columns=val_dataset.column_names)

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)

    # Add LoRA adapters; target_modules names the T5-style attention
    # projections (q, k, v, o) and feed-forward layers (wi, wo) in codet5-base
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
        target_modules=["q", "v", "k", "o", "wi", "wo"],
    )
    model = get_peft_model(model, lora_config)
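
    # Optional sanity check (peft's PeftModel.print_trainable_parameters):
    # confirms that only the LoRA adapter weights are trainable
    model.print_trainable_parameters()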
    def tokenize_function(examples):
        # Tokenize without padding here; DataCollatorForSeq2Seq pads each batch
        # dynamically and fills label padding with -100 so it is ignored by the
        # loss (padding labels to max_length with raw pad ids would be scored)
        model_inputs = tokenizer(
            examples["input_text"],
            truncation=True,
            max_length=MAX_LENGTH,
        )
        labels = tokenizer(
            text_target=examples["target_text"],
            truncation=True,
            max_length=MAX_LENGTH,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs
print("Tokenizing datasets...")
tokenized_train = processed_train.map(
tokenize_function,
remove_columns=processed_train.column_names,
batched=True
)
tokenized_val = processed_val.map(
tokenize_function,
remove_columns=processed_val.column_names,
batched=True
)
    # Training arguments - simplified for stability; saving and evaluating per
    # epoch so the checkpoint messages below hold and predict_with_generate
    # actually runs against the validation set
    training_args = Seq2SeqTrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=EVAL_BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        logging_dir=os.path.join(CHECKPOINT_DIR, "logs"),
        logging_steps=10,
        evaluation_strategy="epoch",
        save_strategy="epoch",
        save_total_limit=2,
        predict_with_generate=True,
        no_cuda=True,  # Force CPU training
        fp16=False,  # Mixed precision is unavailable on CPU
        report_to="none",  # Disable wandb logging
    )
    # Data collator: pads inputs per batch and pads labels with -100 by default
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding=True,
    )
    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        data_collator=data_collator,
    )
    try:
        print("\nStarting training...")
        print("You can stop training at any time by pressing Ctrl+C")
        print("Training will automatically save checkpoints after each epoch")

        # Resume from the most recent checkpoint if one exists
        last_checkpoint = None
        if os.path.exists(CHECKPOINT_DIR):
            checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith("checkpoint-")]
            if checkpoints:
                last_checkpoint = os.path.join(
                    CHECKPOINT_DIR,
                    sorted(checkpoints, key=lambda x: int(x.split("-")[1]))[-1],
                )
                print(f"\nFound checkpoint: {last_checkpoint}")
                print("Training will resume from this checkpoint.")

        # Start or resume training
        trainer.train(resume_from_checkpoint=last_checkpoint)

        # Save the final model (for a PEFT model this saves the LoRA adapter)
        trainer.save_model("./final-model")
        print("\nTraining completed successfully!")
        print("Final model saved to: ./final-model")
    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")
        print("Progress is saved in the latest checkpoint.")
        print("To resume, just run the script again.")
    except Exception as e:
        print(f"\nAn error occurred during training: {e}")
        if os.path.exists(CHECKPOINT_DIR):
            error_checkpoint = os.path.join(CHECKPOINT_DIR, "checkpoint-error")
            trainer.save_model(error_checkpoint)
            print(f"Saved error checkpoint to: {error_checkpoint}")
if __name__ == "__main__":
    main()