""" Calculation Handler Component This module provides mathematical calculation capabilities for the GAIA agent, handling numeric questions, expressions, and table-based calculations. """ import re import logging import math import operator from typing import Dict, Any, List, Optional, Union, Tuple import traceback import numpy as np from collections import defaultdict logger = logging.getLogger("gaia_agent.components.calculation_handler") class CalculationHandler: """ Handles mathematical calculations, expression parsing, and numeric operations. Provides capabilities for answering numerical questions. """ def __init__(self): """Initialize the calculation handler with supported operations.""" # Map operators to their functions self.binary_ops = { '+': operator.add, '-': operator.sub, '*': operator.mul, '/': operator.truediv, '**': operator.pow, '^': operator.pow, '%': operator.mod, '//': operator.floordiv } # Functions that can be called in expressions self.math_functions = { 'sqrt': math.sqrt, 'sin': math.sin, 'cos': math.cos, 'tan': math.tan, 'abs': abs, 'log': math.log, 'log10': math.log10, 'exp': math.exp, 'ceil': math.ceil, 'floor': math.floor, 'round': round, 'sum': sum, 'mean': lambda x: sum(x) / len(x) if x else 0, 'median': lambda x: sorted(x)[len(x) // 2] if x else 0, 'min': min, 'max': max } logger.info("CalculationHandler initialized") def extract_expression(self, text: str) -> Optional[str]: """ Extract mathematical expressions from text input. Args: text: Input text containing potential mathematical expressions Returns: Extracted mathematical expression or None if not found """ # Try to extract expressions in various formats patterns = [ r'calculate\s+([\d\+\-\*\/\(\)\^\.\s]+)', r'compute\s+([\d\+\-\*\/\(\)\^\.\s]+)', r'evaluate\s+([\d\+\-\*\/\(\)\^\.\s]+)', r'what is\s+([\d\+\-\*\/\(\)\^\.\s]+)', r'(\d+[\d\+\-\*\/\(\)\^\.\s]+\d+)' ] for pattern in patterns: match = re.search(pattern, text, re.IGNORECASE) if match: return match.group(1).strip() # Try to find equations equation_match = re.search(r'([\d\+\-\*\/\(\)\^\.\s]+\=[\d\+\-\*\/\(\)\^\.\s]+)', text) if equation_match: return equation_match.group(1).strip() return None def parse_expression(self, expression: str) -> float: """ Parse and evaluate a mathematical expression. Args: expression: Mathematical expression as a string Returns: Calculated result Raises: ValueError: If expression parsing fails """ try: # Improved and more secure parser # First, normalize and sanitize the expression expression = self._normalize_expression(expression) # Handle special functions for func_name, func in self.math_functions.items(): pattern = fr'{func_name}\(([^)]+)\)' for match in re.finditer(pattern, expression): args_str = match.group(1) # Recursively parse arguments if ',' in args_str: args = [self.parse_expression(arg.strip()) for arg in args_str.split(',')] result = func(args) else: arg = self.parse_expression(args_str.strip()) result = func(arg) expression = expression.replace(match.group(0), str(result)) # For security, parse and evaluate the expression manually # rather than using eval() directly return self._recursive_parse(expression) except Exception as e: logger.error(f"Error parsing expression '{expression}': {str(e)}") raise ValueError(f"Could not parse mathematical expression: {str(e)}") def _normalize_expression(self, expression: str) -> str: """ Normalize a mathematical expression by handling different formats and notations. Args: expression: The raw expression string Returns: Normalized expression string """ # Remove whitespace expression = expression.strip() # Replace ^ with ** for exponentiation expression = expression.replace('^', '**') # Convert × to * and ÷ to / expression = expression.replace('×', '*').replace('÷', '/') # Handle implied multiplication (e.g., 2(3+4) → 2*(3+4)) expression = re.sub(r'(\d+)(\()', r'\1*\2', expression) # Handle percentage expressions expression = re.sub(r'(\d+)%', r'(\1/100)', expression) # Replace common mathematical constants expression = expression.replace('pi', str(math.pi)) expression = expression.replace('e', str(math.e)) return expression def _recursive_parse(self, expression: str) -> float: """ Recursively parse and evaluate an expression using operator precedence. Args: expression: The normalized expression string Returns: Evaluated result Raises: ValueError: If parsing fails """ # Remove all whitespace expression = re.sub(r'\s', '', expression) # Handle parentheses first (highest precedence) paren_pattern = r'\(([^()]+)\)' while '(' in expression: match = re.search(paren_pattern, expression) if not match: raise ValueError(f"Mismatched parentheses in expression: {expression}") # Recursively evaluate the parenthesized sub-expression sub_expr = match.group(1) sub_result = self._recursive_parse(sub_expr) # Replace the entire parenthesized expression with its result expression = expression.replace(f"({sub_expr})", str(sub_result)) # Handle addition and subtraction (lowest precedence) terms = self._split_by_operators(expression, ['+', '-']) if len(terms) > 1: # Parse the first term result = self._recursive_parse(terms[0]) # Process each operator and subsequent term i = 1 while i < len(terms): op = terms[i] next_term = terms[i+1] # Perform the operation if op == '+': result += self._recursive_parse(next_term) elif op == '-': result -= self._recursive_parse(next_term) i += 2 return result # Handle multiplication and division (medium precedence) factors = self._split_by_operators(expression, ['*', '/', '**', '//']) if len(factors) > 1: # Parse the first factor result = self._recursive_parse(factors[0]) # Process each operator and subsequent factor i = 1 while i < len(factors): op = factors[i] next_factor = factors[i+1] # Perform the operation if op == '*': result *= self._recursive_parse(next_factor) elif op == '/': divisor = self._recursive_parse(next_factor) if divisor == 0: raise ValueError("Division by zero") result /= divisor elif op == '**': result = pow(result, self._recursive_parse(next_factor)) elif op == '//': divisor = self._recursive_parse(next_factor) if divisor == 0: raise ValueError("Division by zero") result //= divisor i += 2 return result # Base case: just a number try: return float(expression) except ValueError: # If it's not a simple number, check if it's a constant safe_constants = { 'pi': math.pi, 'e': math.e } if expression in safe_constants: return safe_constants[expression] raise ValueError(f"Cannot parse expression part: {expression}") def _split_by_operators(self, expression: str, operators: List[str]) -> List[str]: """ Split an expression by specified operators, preserving their positions. Args: expression: Expression string to split operators: List of operators to split by Returns: List alternating between terms and operators """ if not expression: return [] # Combine operators into a regex pattern, escaping special chars op_pattern = '|'.join(re.escape(op) for op in sorted(operators, key=len, reverse=True)) # Split the expression, keeping the operators parts = re.split(f'({op_pattern})', expression) # Filter out empty parts return [p for p in parts if p] def extract_numbers(self, text: str) -> List[float]: """ Extract all numbers from a text string. Args: text: Text to extract numbers from Returns: List of extracted numbers """ # Extract numbers (including decimals and negative numbers) number_pattern = r'-?\d+(?:\.\d+)?' return [float(match) for match in re.findall(number_pattern, text)] def check_commutative_property(self, operation: str, values: List[float]) -> bool: """ Check if the given operation is commutative for the provided values. Args: operation: Operation to check ('+', '*', etc.) values: List of numeric values to test Returns: True if commutative, False otherwise """ if len(values) < 2: return True if operation not in self.binary_ops: return False op_func = self.binary_ops[operation] # Test commutativity: a op b == b op a for i in range(len(values)): for j in range(i + 1, len(values)): a, b = values[i], values[j] if abs(op_func(a, b) - op_func(b, a)) > 1e-10: return False return True def create_frequency_table(self, data: List[Any]) -> Dict[Any, int]: """ Create a frequency table from a list of data. Args: data: List of values Returns: Dictionary mapping values to their frequencies """ freq_table = defaultdict(int) for item in data: freq_table[item] += 1 return dict(freq_table) def parse_table_data(self, table_text: str) -> Tuple[List[str], List[List[Any]]]: """ Parse tabular data from text representation. Args: table_text: Text containing table data Returns: Tuple of (column_headers, rows) """ lines = table_text.strip().split('\n') # Extract headers (first line) if '|' in lines[0]: # Markdown table format headers = [h.strip() for h in lines[0].split('|')] # Remove empty entries at start/end from the pipe chars headers = [h for h in headers if h] # Skip separator line if present start_idx = 1 if len(lines) > 1 and all(c == '-' or c == '|' or c == ' ' for c in lines[1]): start_idx = 2 # Extract rows rows = [] for i in range(start_idx, len(lines)): if '|' in lines[i]: row_values = [cell.strip() for cell in lines[i].split('|')] # Remove empty entries at start/end row_values = [cell for cell in row_values if cell != ''] # Convert numeric values converted_values = [] for val in row_values: try: # Try to convert to number if possible if '.' in val: converted_values.append(float(val)) else: converted_values.append(int(val)) except ValueError: converted_values.append(val) rows.append(converted_values) else: # CSV or space-delimited format delimiter = ',' if ',' in lines[0] else None headers = [h.strip() for h in lines[0].split(delimiter)] # Extract rows rows = [] for i in range(1, len(lines)): row_values = [cell.strip() for cell in lines[i].split(delimiter)] # Convert numeric values converted_values = [] for val in row_values: try: # Try to convert to number if possible if '.' in val: converted_values.append(float(val)) else: converted_values.append(int(val)) except ValueError: converted_values.append(val) rows.append(converted_values) return headers, rows def perform_set_operation(self, table_data: Tuple[List[str], List[List[Any]]], operation: str) -> Any: """ Perform set operations on table data. Args: table_data: Table data as (headers, rows) operation: Operation to perform (union, intersection, etc.) Returns: Result of the operation """ headers, rows = table_data # Extract columns as sets columns = {} for i, header in enumerate(headers): if i < len(rows[0]): # Ensure column index is valid column_data = [row[i] for row in rows if i < len(row)] columns[header] = set(column_data) if operation == "union": # Union of all sets result = set() for column_set in columns.values(): result = result.union(column_set) return result elif operation == "intersection": # Intersection of all sets sets = list(columns.values()) if not sets: return set() result = sets[0].copy() for s in sets[1:]: result = result.intersection(s) return result elif operation == "difference": # Difference between first set and all others sets = list(columns.values()) if not sets: return set() result = sets[0].copy() for s in sets[1:]: result = result.difference(s) return result elif operation == "symmetric_difference": # Symmetric difference (elements in either set but not both) sets = list(columns.values()) if not sets: return set() result = sets[0].copy() for s in sets[1:]: result = result.symmetric_difference(s) return result raise ValueError(f"Unsupported set operation: {operation}") def analyze_question(self, question: str) -> Dict[str, Any]: """ Analyze a question to determine if it requires calculation. Args: question: The question to analyze Returns: Dict containing analysis results, including: - requires_calculation: Whether question requires calculation - calculation_type: Type of calculation needed (expression, numeric, table) - expression: Extracted expression if found - answer: Calculated answer if possible - confidence: Confidence in the answer """ result = { "question": question, "requires_calculation": False, "calculation_type": None, "expression": None, "numbers": [], "answer": None, "confidence": 0.0 } # Check for mathematical expressions expression = self.extract_expression(question) if expression: result["requires_calculation"] = True result["calculation_type"] = "expression" result["expression"] = expression try: calculated_result = self.parse_expression(expression) formatted_result = f"{calculated_result:.4f}".rstrip('0').rstrip('.') if '.' in f"{calculated_result}" else f"{calculated_result}" result["answer"] = formatted_result result["confidence"] = 0.95 except ValueError as e: logger.warning(f"Failed to calculate expression: {str(e)}") result["answer"] = f"I couldn't calculate that expression: {str(e)}" result["confidence"] = 0.0 return result # Check for commutative property questions if "commutative" in question.lower(): result["requires_calculation"] = True result["calculation_type"] = "property_check" # Determine the operation being asked about if "addition" in question.lower() or "+" in question: operation = "+" elif "multiplication" in question.lower() or "*" in question or "×" in question: operation = "*" elif "subtraction" in question.lower() or "-" in question: operation = "-" elif "division" in question.lower() or "/" in question or "÷" in question: operation = "/" else: operation = None # Extract numbers if present numbers = self.extract_numbers(question) result["numbers"] = numbers if operation and numbers: is_commutative = self.check_commutative_property(operation, numbers) result["answer"] = "Yes" if is_commutative else "No" result["confidence"] = 0.9 result["explanation"] = f"Testing commutativity of {operation} with values {', '.join(str(n) for n in numbers)}: {'commutative' if is_commutative else 'not commutative'}" elif operation: # If operation is known but no specific numbers provided is_commutative = operation in ["+", "*"] # Only + and * are commutative result["answer"] = "Yes" if is_commutative else "No" result["confidence"] = 0.85 result["explanation"] = f"The {'addition' if operation == '+' else 'multiplication' if operation == '*' else 'subtraction' if operation == '-' else 'division'} operation is {'' if is_commutative else 'not '}commutative." return result # Check for numeric questions (e.g., sum, average, etc.) numeric_indicators = [ "sum", "add", "total", "average", "mean", "median", "minimum", "maximum", "min", "max", "count", "how many" ] if any(indicator in question.lower() for indicator in numeric_indicators): result["requires_calculation"] = True # Extract numbers if present numbers = self.extract_numbers(question) result["numbers"] = numbers if numbers: result["calculation_type"] = "numeric" if "sum" in question.lower() or "add" in question.lower() or "total" in question.lower(): result["answer"] = str(sum(numbers)) result["confidence"] = 0.9 elif "average" in question.lower() or "mean" in question.lower(): result["answer"] = str(sum(numbers) / len(numbers)) result["confidence"] = 0.9 elif "median" in question.lower(): sorted_nums = sorted(numbers) mid = len(sorted_nums) // 2 if len(sorted_nums) % 2 == 0: result["answer"] = str((sorted_nums[mid-1] + sorted_nums[mid]) / 2) else: result["answer"] = str(sorted_nums[mid]) result["confidence"] = 0.9 elif "min" in question.lower() or "minimum" in question.lower(): result["answer"] = str(min(numbers)) result["confidence"] = 0.9 elif "max" in question.lower() or "maximum" in question.lower(): result["answer"] = str(max(numbers)) result["confidence"] = 0.9 elif "count" in question.lower() or "how many" in question.lower(): result["answer"] = str(len(numbers)) result["confidence"] = 0.9 return result # If no calculation pattern detected return result def process_table_calculation(self, question: str, table_data: str) -> Dict[str, Any]: """ Process calculations on tabular data. Args: question: Question about the table table_data: String representation of the table Returns: Dict containing analysis results """ result = { "question": question, "requires_calculation": True, "calculation_type": "table", "answer": None, "confidence": 0.0, "explanation": None, "operations_performed": [] } try: parsed_table = self.parse_table_data(table_data) headers, rows = parsed_table # Extract numeric columns numeric_columns = {} for i, header in enumerate(headers): if i < len(rows[0]): # Ensure column index is valid column_data = [row[i] for row in rows if i < len(row)] if all(isinstance(val, (int, float)) for val in column_data): numeric_columns[header] = column_data # Determine what calculation to perform based on the question question_lower = question.lower() # Find which column is being asked about target_column = None for header in headers: if header.lower() in question_lower: target_column = header break if "sum" in question_lower or "total" in question_lower: if target_column and target_column in numeric_columns: result["answer"] = str(sum(numeric_columns[target_column])) result["confidence"] = 0.9 else: # Sum of all numeric values all_nums = [val for col in numeric_columns.values() for val in col] result["answer"] = str(sum(all_nums)) result["confidence"] = 0.7 elif "average" in question_lower or "mean" in question_lower: if target_column and target_column in numeric_columns: values = numeric_columns[target_column] avg_value = sum(values) / len(values) result["answer"] = f"{avg_value:.2f}" result["confidence"] = 0.9 result["explanation"] = f"Calculated the average of {len(values)} values in column '{target_column}'" result["operations_performed"].append({"operation": "average", "column": target_column, "result": avg_value}) else: # Handle case where no target column is found or it's not numeric pass elif "commutative" in question_lower: # Try to identify which operation to test operation = None if "add" in question_lower or "sum" in question_lower or "addition" in question_lower or "+" in question_lower: operation = "+" elif "multipl" in question_lower or "product" in question_lower or "*" in question_lower or "×" in question_lower: operation = "*" if operation and numeric_columns: # Identify columns to test commutativity on columns_to_test = [] # Check if specific columns are mentioned for header in numeric_columns.keys(): if header.lower() in question_lower: columns_to_test.append(header) # If no specific columns found, use all numeric columns if not columns_to_test and len(numeric_columns) >= 2: columns_to_test = list(numeric_columns.keys())[:2] # Use first two columns if len(columns_to_test) >= 2: col1, col2 = columns_to_test[0], columns_to_test[1] values1 = numeric_columns[col1] values2 = numeric_columns[col2] # Test commutativity all_commutative = True test_pairs = [] # Get operation function op_func = self.binary_ops.get(operation) # Only test on first 5 pairs for efficiency max_tests = min(5, len(values1), len(values2)) for i in range(max_tests): a, b = values1[i], values2[i] result1 = op_func(a, b) result2 = op_func(b, a) is_equal = abs(result1 - result2) < 1e-10 test_pairs.append({ "a": a, "b": b, "a_op_b": result1, "b_op_a": result2, "equal": is_equal }) if not is_equal: all_commutative = False result["answer"] = "Yes" if all_commutative else "No" result["confidence"] = 0.95 result["explanation"] = f"Tested commutativity of {operation} between columns '{col1}' and '{col2}'" result["operations_performed"].append({ "operation": "commutativity_check", "columns": [col1, col2], "test_operation": operation, "result": all_commutative, "test_pairs": test_pairs }) else: result["answer"] = "Cannot check commutativity without at least two numeric columns" result["confidence"] = 0.7 else: # Average of all numeric values all_nums = [val for col in numeric_columns.values() for val in col] result["answer"] = str(sum(all_nums) / len(all_nums)) result["confidence"] = 0.7 elif "maximum" in question_lower or "max" in question_lower: if target_column and target_column in numeric_columns: result["answer"] = str(max(numeric_columns[target_column])) result["confidence"] = 0.9 else: # Maximum of all numeric values all_nums = [val for col in numeric_columns.values() for val in col] result["answer"] = str(max(all_nums)) result["confidence"] = 0.7 elif "minimum" in question_lower or "min" in question_lower: if target_column and target_column in numeric_columns: result["answer"] = str(min(numeric_columns[target_column])) result["confidence"] = 0.9 else: # Minimum of all numeric values all_nums = [val for col in numeric_columns.values() for val in col] result["answer"] = str(min(all_nums)) result["confidence"] = 0.7 elif "count" in question_lower or "how many" in question_lower: if "rows" in question_lower: result["answer"] = str(len(rows)) result["confidence"] = 0.95 elif "columns" in question_lower: result["answer"] = str(len(headers)) result["confidence"] = 0.95 elif target_column: # Count values in the column column_idx = headers.index(target_column) column_values = [row[column_idx] for row in rows if column_idx < len(row)] result["answer"] = str(len(column_values)) result["confidence"] = 0.9 elif "set" in question_lower or "union" in question_lower or "intersection" in question_lower: # Determine set operation if "union" in question_lower: operation = "union" elif "intersection" in question_lower or "common" in question_lower: operation = "intersection" elif "difference" in question_lower: operation = "difference" elif "symmetric" in question_lower and "difference" in question_lower: operation = "symmetric_difference" else: operation = None if operation: set_result = self.perform_set_operation(parsed_table, operation) result["answer"] = str(set_result) result["confidence"] = 0.85 # If no specific calculation identified if result["answer"] is None: # Default to returning basic table statistics result["answer"] = (f"Table has {len(headers)} columns and {len(rows)} rows. " f"Columns: {', '.join(headers)}.") result["confidence"] = 0.5 except Exception as e: logger.error(f"Error processing table calculation: {str(e)}") logger.debug(traceback.format_exc()) result["answer"] = f"Could not process table calculation: {str(e)}" result["confidence"] = 0.0 return result