"""Rule-based reward functions that score chat completions.

Each reward function takes ``completions`` where every item is a list of
message dicts, and scores the ``"content"`` of the first message.  Extra
keyword arguments are accepted and ignored (except where documented).
"""
import re
from typing import Callable, Dict, List, Optional

import numpy as np

# NOTE(review): these tag names are a reconstruction.  In the previous
# revision every angle-bracket tag had been stripped out of the regexes and
# docstrings (e.g. the format pattern had degraded to ``.*?.*?.*?.*?.*?``),
# which made the format / execution / ratio / timing rewards constant.
# ``<code>``, ``<interpreter>`` and ``<answer>`` fit the surviving hints
# ("code-interpreter pairs", "final answer", six tags in the format pattern).
# Confirm against the prompt template used to generate completions.
CODE_OPEN, CODE_CLOSE = "<code>", "</code>"
INTERPRETER_OPEN, INTERPRETER_CLOSE = "<interpreter>", "</interpreter>"
ANSWER_OPEN, ANSWER_CLOSE = "<answer>", "</answer>"

# code block ... interpreter block ... answer block, in that order, with any
# text in between (DOTALL so blocks may span lines).
_FORMAT_RE = re.compile(
    ".*?".join(
        re.escape(tag)
        for tag in (
            CODE_OPEN,
            CODE_CLOSE,
            INTERPRETER_OPEN,
            INTERPRETER_CLOSE,
            ANSWER_OPEN,
            ANSWER_CLOSE,
        )
    ),
    re.DOTALL,
)

# A code block immediately followed (modulo whitespace) by an interpreter
# block; group 1 captures the interpreter output.
_CODE_INTERPRETER_RE = re.compile(
    re.escape(CODE_OPEN) + r".*?" + re.escape(CODE_CLOSE) + r"\s*"
    + re.escape(INTERPRETER_OPEN) + r"(.*?)" + re.escape(INTERPRETER_CLOSE),
    re.DOTALL,
)

# Group 1 captures the body of each code block.
_CODE_BODY_RE = re.compile(
    re.escape(CODE_OPEN) + r"(.*?)" + re.escape(CODE_CLOSE), re.DOTALL
)

# Heuristic failure markers: a case-sensitive substring search, so ordinary
# output that merely prints e.g. "Error" is also counted as a failure.
_ERROR_RE = re.compile(r"Error|Exception|Traceback")

_BOXED_PREFIX = "\\boxed{"

# (low, high, reward) bands, best first; both ends inclusive.  Because the
# first matching band wins, each outer band only applies to the values left
# over by the bands before it.
_CODE_RATIO_BANDS = ((0.2, 0.4, 0.3), (0.1, 0.5, 0.2), (0.05, 0.6, 0.1))
_CODE_TIMING_BANDS = ((0.1, 0.4, 0.3), (0.05, 0.5, 0.2), (0.0, 0.7, 0.1))


def _completion_texts(completions) -> list[str]:
    """Return the ``"content"`` of the first message of each completion."""
    return [completion[0]["content"] for completion in completions]


def _banded_reward(value: float, bands) -> float:
    """Return the reward of the first ``(low, high, reward)`` band containing *value*, else 0.0."""
    for low, high, reward in bands:
        if low <= value <= high:
            return reward
    return 0.0


def _extract_boxed(text: str) -> Optional[str]:
    """Return the contents of the first ``\\boxed{...}`` in *text*, honouring nested braces.

    Returns None when *text* has no ``\\boxed{`` or its braces never balance.
    """
    start = text.find(_BOXED_PREFIX)
    if start == -1:
        return None
    body_start = start + len(_BOXED_PREFIX)
    depth = 1  # the opening brace of \boxed{ is already consumed
    for index, char in enumerate(text[body_start:], start=body_start):
        if char == "{":
            depth += 1
        elif char == "}":
            depth -= 1
            if depth == 0:
                return text[body_start:index]
    return None


def format_reward(completions, **kwargs):
    """Reward 1.0 when a completion contains code, interpreter and answer blocks in order.

    A completion scores 1.0 if it contains a ``<code>…</code>`` block, then an
    ``<interpreter>…</interpreter>`` block, then an ``<answer>…</answer>``
    block (arbitrary text may appear between them); otherwise 0.0.
    """
    return [
        1.0 if _FORMAT_RE.search(text) else 0.0
        for text in _completion_texts(completions)
    ]


def accuracy_reward(
    completions: list[list[dict[str, str]]],
    correct_answers: list[str],
    **kwargs,
) -> list[Optional[float]]:
    """Reward 1.0 when the first ``\\boxed{...}`` answer equals the ground truth.

    Args:
        completions: One message list per completion; the first message's
            ``"content"`` is scored.
        correct_answers: Ground-truth answers, one per completion.  They are
            converted with ``str`` and stripped before comparison.

    Returns:
        1.0 for an exact (whitespace-trimmed) string match, else 0.0.  A
        completion with no well-formed ``\\boxed{...}`` scores 0.0.

    Raises:
        ValueError: if the two lists differ in length (``zip`` would
            otherwise silently drop rewards).
    """
    texts = _completion_texts(completions)
    if len(texts) != len(correct_answers):
        raise ValueError(
            f"accuracy_reward got {len(texts)} completions but "
            f"{len(correct_answers)} correct answers"
        )

    rewards = []
    for text, correct_answer in zip(texts, correct_answers):
        extracted = _extract_boxed(text)
        if extracted is None:
            rewards.append(0.0)
            continue
        # Plain string equality; mathematically equivalent forms
        # (e.g. "0.5" vs "1/2") are not recognised.
        matched = extracted.strip() == str(correct_answer).strip()
        rewards.append(1.0 if matched else 0.0)
    return rewards


def code_execution_reward(completions, **kwargs):
    """Reward the fraction of code/interpreter pairs whose output shows no error.

    Returns 0.0 for a completion without any code block that is directly
    followed by an interpreter block; otherwise ``1 - failed / total``.
    """
    rewards = []
    for text in _completion_texts(completions):
        outputs = _CODE_INTERPRETER_RE.findall(text)
        if not outputs:
            rewards.append(0.0)
            continue
        failures = sum(1 for output in outputs if _ERROR_RE.search(output))
        rewards.append(1.0 - (failures / len(outputs)))
    return rewards


def len_reward(completions, **kwargs):
    """Reward shorter completions, relative to the rest of the batch.

    Lengths are min-max normalised within the batch and inverted, then scaled
    into ``[0, 0.2]`` so this stays a secondary signal.  If every completion
    has the same length all rewards are 0.0; an empty batch yields ``[]``.
    """
    lengths = [len(text) for text in _completion_texts(completions)]
    if not lengths:
        return []

    # Hoisted: computing min/max inside the comprehension would be O(n^2).
    shortest, longest = min(lengths), max(lengths)
    if shortest == longest:
        return [0.0] * len(lengths)

    span = longest - shortest
    return [0.2 * (1.0 - ((length - shortest) / span)) for length in lengths]


def code_ratio_reward(completions, **kwargs):
    """Reward a moderate share of characters inside ``<code>…</code>`` bodies.

    Ratio = code-body characters / total characters.  It earns 0.3 in
    [0.2, 0.4], 0.2 in [0.1, 0.5], 0.1 in [0.05, 0.6] (first band wins) and
    0.0 otherwise or for empty content.
    """
    rewards = []
    for text in _completion_texts(completions):
        if not text:
            rewards.append(0.0)
            continue
        code_chars = sum(len(body) for body in _CODE_BODY_RE.findall(text))
        rewards.append(_banded_reward(code_chars / len(text), _CODE_RATIO_BANDS))
    return rewards


def code_timing_reward(completions, **kwargs):
    """Reward completions whose first ``<code>`` tag appears early-to-mid text.

    The relative position (character offset / length) earns 0.3 in
    [0.1, 0.4], 0.2 in [0.05, 0.5], 0.1 in [0.0, 0.7] (first band wins) and
    0.0 otherwise.  Completions without a ``<code>`` tag score 0.0.
    """
    rewards = []
    for text in _completion_texts(completions):
        first_code_pos = text.find(CODE_OPEN)
        # -1 also covers empty text, so the division below is safe.
        if first_code_pos == -1:
            rewards.append(0.0)
            continue
        rewards.append(_banded_reward(first_code_pos / len(text), _CODE_TIMING_BANDS))
    return rewards


REWARD_FUNCS_REGISTRY: Dict[str, Callable] = {
    "accuracy": accuracy_reward,
    "format": format_reward,
    "code_execution": code_execution_reward,
    "length": len_reward,
    "code_ratio": code_ratio_reward,
    "code_timing": code_timing_reward,
}


def get_reward_funcs(script_args) -> list[Callable]:
    """Return the reward functions named in ``script_args.reward_funcs``, in order.

    Raises:
        KeyError: if any name is not in ``REWARD_FUNCS_REGISTRY``; the message
            lists the available names.
    """
    names = list(script_args.reward_funcs)  # iterate once, even if a generator
    unknown = [name for name in names if name not in REWARD_FUNCS_REGISTRY]
    if unknown:
        raise KeyError(
            f"Unknown reward function(s) {unknown}; "
            f"available: {sorted(REWARD_FUNCS_REGISTRY)}"
        )
    return [REWARD_FUNCS_REGISTRY[name] for name in names]