File size: 6,397 Bytes
2fc6f4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import re
from typing import List, Dict, Optional, Callable
import numpy as np

def format_reward(completions, **kwargs):
    """Score 1.0 when a completion has a code block followed by an answer block.

    The first message of each completion must contain a ``<code>...</code>``
    span and, somewhere after it, an ``<answer>...</answer>`` span. Tag contents
    may span multiple lines. Anything else scores 0.0.
    """
    layout = re.compile(
        r".*?<code>.*?</code>.*?<answer>.*?</answer>.*?",
        re.DOTALL | re.MULTILINE,
    )
    texts = [completion[0]["content"] for completion in completions]
    # A Match object is always truthy, so a plain truth test suffices here.
    return [1.0 if layout.search(text) else 0.0 for text in texts]

def accuracy_reward(completions: list[list[dict[str, str]]], correct_answers: list[str], **kwargs) -> list[Optional[float]]:
    r"""Reward 1.0 when a completion's boxed answer matches the ground truth.

    The answer is taken from the first ``<answer>\boxed{...}</answer>`` span in
    the first message of each completion. Whitespace between the tags and
    ``\boxed{...}`` is tolerated. The extracted text and the expected answer are
    both stripped before an exact string comparison.

    Args:
        completions: Per-completion message lists; only ``completion[0]["content"]``
            is read.
        correct_answers: Ground-truth answers aligned with ``completions``.
            Non-string values (e.g. ints) are compared via ``str()``; ``None``
            never matches.

    Returns:
        One reward per (completion, answer) pair: 1.0 on a match, else 0.0.
        Pairing follows ``zip``, so surplus items in the longer list are ignored.
    """
    # Literal backslash + "boxed", escaped braces; the lazy group still captures
    # nested braces because it must be followed by "}" + optional space + "</answer>".
    boxed_answer = re.compile(r"<answer>\s*\\boxed\{(.*?)\}\s*</answer>", re.DOTALL)
    contents = [completion[0]["content"] for completion in completions]
    rewards = []

    for content, correct_answer in zip(contents, correct_answers):
        answer_match = boxed_answer.search(content)
        if answer_match is None or correct_answer is None:
            rewards.append(0.0)
            continue

        extracted_answer = answer_match.group(1).strip()
        # Exact textual comparison; equivalent math expressions written
        # differently (e.g. "0.5" vs "1/2") are NOT recognised as equal.
        expected = str(correct_answer).strip()
        rewards.append(1.0 if extracted_answer == expected else 0.0)

    return rewards

def code_execution_reward(completions, **kwargs):
    """Reward the fraction of code executions whose interpreter output is error-free.

    Every ``<code>...</code>`` block immediately followed (after optional
    whitespace) by an ``<interpreter>...</interpreter>`` block counts as one
    execution. An execution counts as failed when its interpreter output contains
    any of the case-sensitive substrings ``Error``, ``Exception`` or
    ``Traceback``. This is a heuristic: any output text containing those words
    counts as a failure.

    Returns:
        One float per completion: ``1 - failed / executions``, or 0.0 when the
        completion has no code/interpreter pair.
    """
    pair_re = re.compile(r'<code>.*?</code>\s*<interpreter>(.*?)</interpreter>', re.DOTALL)
    # findall returns only the captured interpreter BODY (tags stripped), so the
    # error check must look at bare text, not for <interpreter> tags.
    error_re = re.compile(r'Error|Exception|Traceback')

    completion_contents = [completion[0]["content"] for completion in completions]
    rewards = []
    for content in completion_contents:
        interpreter_outputs = pair_re.findall(content)
        if not interpreter_outputs:
            rewards.append(0.0)
            continue

        error_count = sum(1 for output in interpreter_outputs if error_re.search(output))
        rewards.append(1.0 - error_count / len(interpreter_outputs))

    return rewards

def len_reward(completions, **kwargs):
    """Reward shorter completions to encourage efficiency.

    Character lengths of each completion's first message are min-max normalised
    across the batch and inverted. The shortest completion gets 0.2, the longest
    0.0, with linear interpolation in between. The 0.2 scale keeps this a
    secondary signal.

    Returns:
        One float per completion. All 0.0 when every length is equal, and an
        empty list for an empty batch.
    """
    completion_contents = [completion[0]["content"] for completion in completions]
    if not completion_contents:
        # min()/max() would raise ValueError on an empty batch.
        return []

    lengths = [len(content) for content in completion_contents]
    shortest, longest = min(lengths), max(lengths)

    # No spread to normalise against: return neutral rewards.
    if shortest == longest:
        return [0.0] * len(lengths)

    span = longest - shortest
    return [0.2 * (1.0 - (length - shortest) / span) for length in lengths]

def code_ratio_reward(completions, **kwargs):
    """Reward completions whose share of code falls in a target band.

    The ratio is the total length of all ``<code>...</code>`` bodies divided by
    the full length of the completion's first message:

    * ratio in [0.2, 0.4]  -> 0.3
    * ratio in [0.1, 0.5]  -> 0.2 (outside the band above)
    * ratio in [0.05, 0.6] -> 0.1 (outside both bands above)
    * otherwise, or an empty completion -> 0.0
    """
    texts = [completion[0]["content"] for completion in completions]
    # Nested (low, high, reward) bands, narrowest first; the first band that
    # contains the ratio wins, which reproduces the half-open side ranges.
    bands = ((0.2, 0.4, 0.3), (0.1, 0.5, 0.2), (0.05, 0.6, 0.1))

    def score(text):
        code_chars = sum(len(body) for body in re.findall(r'<code>(.*?)</code>', text, re.DOTALL))
        total = len(text)
        if total == 0:
            return 0.0
        ratio = code_chars / total
        return next((reward for low, high, reward in bands if low <= ratio <= high), 0.0)

    return [score(text) for text in texts]

def code_timing_reward(completions, **kwargs):
    """Reward completions that first invoke code early-to-mid reasoning.

    The position of the first ``<code>`` tag, as a fraction of the message
    length, is scored:

    * position in [0.1, 0.4]  -> 0.3
    * position in [0.05, 0.5] -> 0.2 (outside the band above)
    * position in [0.0, 0.7]  -> 0.1 (outside both bands above)
    * later than 0.7, or no ``<code>`` tag at all -> 0.0
    """
    texts = [completion[0]["content"] for completion in completions]
    # Nested (low, high, reward) bands, narrowest first; first containing band wins.
    bands = ((0.1, 0.4, 0.3), (0.05, 0.5, 0.2), (0.0, 0.7, 0.1))

    scores = []
    for text in texts:
        start = text.find('<code>')
        if start < 0:
            scores.append(0.0)
            continue

        # len(text) >= len('<code>') here, so the division is safe.
        position = start / len(text)
        for low, high, reward in bands:
            if low <= position <= high:
                scores.append(reward)
                break
        else:
            scores.append(0.0)

    return scores

def get_reward_funcs(script_args) -> list[Callable]:
    """Resolve the reward-function names in ``script_args.reward_funcs``.

    Recognised names: ``accuracy``, ``format``, ``code_execution``, ``length``,
    ``code_ratio`` and ``code_timing``.

    Args:
        script_args: Any object with an iterable ``reward_funcs`` attribute of
            registry names.

    Returns:
        The matching reward callables, in the order requested.

    Raises:
        KeyError: If a requested name is not registered. The message lists the
            valid names.
    """
    REWARD_FUNCS_REGISTRY = {
        "accuracy": accuracy_reward,
        "format": format_reward,
        "code_execution": code_execution_reward,
        "length": len_reward,
        "code_ratio": code_ratio_reward,
        "code_timing": code_timing_reward,
    }

    reward_funcs = []
    for name in script_args.reward_funcs:
        try:
            reward_funcs.append(REWARD_FUNCS_REGISTRY[name])
        except KeyError:
            # Keep the KeyError type callers may already catch, but say what is valid.
            available = ", ".join(sorted(REWARD_FUNCS_REGISTRY))
            raise KeyError(f"Unknown reward function {name!r}; available: {available}") from None
    return reward_funcs