from transformers import TrainerCallback
from transformers.trainer_callback import TrainerControl, TrainerState
from transformers.training_args import TrainingArguments

class CurriculumLearningCallback(TrainerCallback):
    """A callback to implement curriculum learning stages during training."""
    def __init__(self, debug=False):
        self.debug = debug
        self.current_stage = "format_stage"
        self.stages = {
            "format_stage": {
                "reward_weights": {"format": 1.0, "accuracy": 0.0, "code_execution": 0.0,
                                   "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
                "beta": 0.1,  # Higher KL - stay close to the base model's format
                "steps": 1000
            },
            "code_execution_stage": {
                "reward_weights": {"format": 0.3, "accuracy": 0.0, "code_execution": 0.7,
                                   "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
                "beta": 0.05,  # Medium KL
                "steps": 2000
            },
            "accuracy_stage": {
                "reward_weights": {"format": 0.2, "accuracy": 0.8, "code_execution": 0.0,
                                   "length": 0.0, "code_ratio": 0.0, "code_timing": 0.0},
                "beta": 0.01,  # Very low KL - allow exploration
                "steps": 3000
            },
            "refinement_stage": {
                "reward_weights": {"format": 0.1, "accuracy": 0.6, "code_execution": 0.1,
                                   "length": 0.1, "code_ratio": 0.05, "code_timing": 0.05},
                "beta": 0.03,  # Medium-low KL - stabilize learning
                "steps": 5000
            }
        }
        self.total_steps = sum(stage_config["steps"] for stage_config in self.stages.values())
        self.stage_transitions = self._calculate_stage_transitions()
        
        print(f"Curriculum learning initialized with {len(self.stages)} stages:")
        for stage, end_step in self.stage_transitions.items():
            print(f"  {stage}: ends at step {end_step}")
    
    def _calculate_stage_transitions(self):
        """Calculate at which step each stage transition occurs."""
        transitions = {}
        current_step = 0
        for stage, config in self.stages.items():
            current_step += config["steps"]
            transitions[stage] = current_step
        return transitions
    
    def on_train_begin(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Initialize reward weights and beta at the start of training."""
        # Note: the stock transformers callback loop does not pass the trainer
        # itself in **kwargs, so this expects a trainer that forwards
        # `trainer=self` to its callbacks (or a reference set on the callback).
        trainer = kwargs.get('trainer')
        if trainer is None:
            return
        
        # Set initial weights and beta from first stage
        first_stage = list(self.stages.keys())[0]
        stage_config = self.stages[first_stage]
        
        # Update reward weights
        if hasattr(trainer, "reward_weights") and hasattr(trainer, "reward_func_names"):
            for i, func_name in enumerate(trainer.reward_func_names):
                if func_name in stage_config["reward_weights"]:
                    trainer.reward_weights[i] = stage_config["reward_weights"][func_name]
                    if self.debug:
                        print(f"Setting initial weight for {func_name}: {trainer.reward_weights[i]}")
        else:
            print("Warning: Trainer doesn't have reward_weights or reward_func_names attributes")
        
        # Update beta (KL coefficient)
        if hasattr(trainer, "beta"):
            trainer.beta = stage_config.get("beta", 0.1)
            if self.debug:
                print(f"Setting initial beta: {trainer.beta}")
        else:
            print("Warning: Trainer doesn't have a beta attribute")
    
    def on_step_end(self, args: TrainingArguments, state: TrainerState, control: TrainerControl, **kwargs):
        """Update reward weights and beta based on the current training stage."""
        trainer = kwargs.get('trainer')
        if trainer is None:
            return

        current_step = state.global_step

        # Determine the current stage: the first stage (in schedule order, i.e.
        # sorted by cumulative end step) whose end step has not yet been passed.
        previous_stage = self.current_stage
        for stage, transition_step in sorted(self.stage_transitions.items(), key=lambda item: item[1]):
            if current_step <= transition_step:
                self.current_stage = stage
                break
        
        # If stage changed, update weights and log the transition
        if previous_stage != self.current_stage:
            print(f"Transitioning from {previous_stage} to {self.current_stage} at step {current_step}")
            
        # Get config for current stage
        stage_config = self.stages[self.current_stage]
        
        # Update reward weights
        if hasattr(trainer, "reward_weights") and hasattr(trainer, "reward_func_names"):
            for i, func_name in enumerate(trainer.reward_func_names):
                if func_name in stage_config["reward_weights"]:
                    new_weight = stage_config["reward_weights"][func_name]
                    if trainer.reward_weights[i] != new_weight:
                        trainer.reward_weights[i] = new_weight
                        if self.debug:
                            print(f"Updated weight for {func_name}: {new_weight}")
        
        # Update beta (KL coefficient)
        if hasattr(trainer, "beta"):
            new_beta = stage_config.get("beta", 0.1)
            if trainer.beta != new_beta:
                trainer.beta = new_beta
                if self.debug:
                    print(f"Updated beta: {new_beta}")