import torch
from transformers import (
    AutoTokenizer, 
    AutoModelForSeq2SeqLM,
    Seq2SeqTrainingArguments, 
    Seq2SeqTrainer,
    DataCollatorForSeq2Seq
)
from peft import LoraConfig, get_peft_model, TaskType
from datasets import load_dataset
import os

# Model Configuration
MODEL_NAME = "Salesforce/codet5-base"
MAX_LENGTH = 128
TRAIN_BATCH_SIZE = 2
EVAL_BATCH_SIZE = 2
LEARNING_RATE = 1e-4
NUM_EPOCHS = 3
TRAIN_SIZE = 5000
VAL_SIZE = 500
CHECKPOINT_DIR = "./codet5-sql-finetuned"

def preprocess(example):
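    """Turn a WikiSQL record into a text-to-SQL prompt/target pair.

    The prompt lists the table's column headers followed by the question;
    the target is WikiSQL's human-readable SQL query.
    """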
    question = example["question"]
    table_headers = ", ".join(example["table"]["header"])
    sql_query = example["sql"]["human_readable"]
    
    return {
        "input_text": f"### Table columns:\n{table_headers}\n### Question:\n{question}\n### SQL:",
        "target_text": sql_query
    }

def main():
    # Report available hardware; training itself is forced onto the CPU by
    # `no_cuda=True` in the training arguments below.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Detected device: {device} (training runs on CPU)")

    # Load and preprocess dataset
    print("Loading dataset...")
    try:
        dataset = load_dataset("wikisql")
    except Exception as e:
        print(f"Error loading dataset: {str(e)}")
        print("Trying with trust_remote_code=True...")
        dataset = load_dataset("wikisql", trust_remote_code=True)

    train_dataset = dataset["train"].select(range(TRAIN_SIZE))
    val_dataset = dataset["validation"].select(range(VAL_SIZE))
    
    print("Preprocessing datasets...")
    processed_train = train_dataset.map(preprocess, remove_columns=train_dataset.column_names)
    processed_val = val_dataset.map(preprocess, remove_columns=val_dataset.column_names)

    # Load model and tokenizer
    print("Loading model and tokenizer...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_NAME)
    
    # Add LoRA adapters
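    # target_modules names the T5-style projections inside CodeT5:
    # q/k/v/o are the attention projections, wi/wo the feed-forward layers.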
    lora_config = LoraConfig(
        r=8,
        lora_alpha=16,
        lora_dropout=0.1,
        bias="none",
        task_type=TaskType.SEQ_2_SEQ_LM,
        target_modules=["q", "v", "k", "o", "wi", "wo"]
    )
    model = get_peft_model(model, lora_config)
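    # Show how few parameters LoRA leaves trainable relative to the full model.
    model.print_trainable_parameters()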
    
    def tokenize_function(examples):
        # Tokenize prompts and targets without padding; the data collator below
        # pads each batch dynamically and fills label padding with -100 so that
        # padded positions are ignored by the loss.
        model_inputs = tokenizer(
            examples["input_text"],
            truncation=True,
            max_length=MAX_LENGTH,
        )
        labels = tokenizer(
            text_target=examples["target_text"],
            truncation=True,
            max_length=MAX_LENGTH,
        )
        model_inputs["labels"] = labels["input_ids"]
        return model_inputs

    print("Tokenizing datasets...")
    tokenized_train = processed_train.map(
        tokenize_function,
        remove_columns=processed_train.column_names,
        batched=True
    )
    tokenized_val = processed_val.map(
        tokenize_function,
        remove_columns=processed_val.column_names,
        batched=True
    )

    # Training arguments - simplified for stability
    training_args = Seq2SeqTrainingArguments(
        output_dir=CHECKPOINT_DIR,
        per_device_train_batch_size=TRAIN_BATCH_SIZE,
        per_device_eval_batch_size=EVAL_BATCH_SIZE,
        num_train_epochs=NUM_EPOCHS,
        learning_rate=LEARNING_RATE,
        logging_dir=os.path.join(CHECKPOINT_DIR, "logs"),
        logging_steps=10,
        save_strategy="epoch",  # checkpoint at the end of each epoch, as promised at startup
        save_total_limit=2,
        predict_with_generate=True,
        no_cuda=True,  # Force CPU training
        fp16=False,    # Disable mixed precision training since we're on CPU
        report_to="none"  # Disable wandb logging
    )

    # Data collator: pads each batch dynamically and uses -100 for label padding
    data_collator = DataCollatorForSeq2Seq(
        tokenizer,
        model=model,
        padding=True
    )

    # Initialize trainer
    trainer = Seq2SeqTrainer(
        model=model,
        args=training_args,
        train_dataset=tokenized_train,
        eval_dataset=tokenized_val,
        data_collator=data_collator,
    )

    try:
        print("\nStarting training...")
        print("You can stop training at any time by pressing Ctrl+C")
        print("Training will automatically save checkpoints after each epoch")
        
        # Check for existing checkpoints
        last_checkpoint = None
        if os.path.exists(CHECKPOINT_DIR):
            checkpoints = [d for d in os.listdir(CHECKPOINT_DIR) if d.startswith('checkpoint-')]
            if checkpoints:
                last_checkpoint = os.path.join(CHECKPOINT_DIR, sorted(checkpoints, key=lambda x: int(x.split('-')[1]))[-1])
                print(f"\nFound checkpoint: {last_checkpoint}")
                print("Training will resume from this checkpoint.")
        
        # Start or resume training
        trainer.train(resume_from_checkpoint=last_checkpoint)
        
        # Save the final LoRA adapter plus the tokenizer needed to reload it
        trainer.save_model("./final-model")
        tokenizer.save_pretrained("./final-model")
        print("\nTraining completed successfully!")
        print("Final model saved to: ./final-model")
            
    except KeyboardInterrupt:
        print("\nTraining interrupted by user!")
        print("Progress is saved in the latest checkpoint.")
        print("To resume, just run the script again.")
        
    except Exception as e:
        print(f"\nAn error occurred during training: {str(e)}")
        if os.path.exists(CHECKPOINT_DIR):
            error_checkpoint = os.path.join(CHECKPOINT_DIR, "checkpoint-error")
            trainer.save_model(error_checkpoint)
            print(f"Saved error checkpoint to: {error_checkpoint}")

if __name__ == "__main__":
    main()