|
import logging |
|
import asyncio |
|
import functools |
|
from typing import Any, Callable, Dict, List, Optional |
|
import time |
|
import json |
|
from pathlib import Path |
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
def async_timer(func: Callable) -> Callable:
    """Decorator that logs how long an async function takes to run.

    On success the elapsed time is logged at DEBUG level. If the wrapped
    coroutine raises an ``Exception``, the elapsed time and the error text
    are logged at ERROR level and the exception is re-raised unchanged.

    Args:
        func: The coroutine function to wrap.

    Returns:
        An async wrapper that forwards all arguments to ``func`` and returns
        its result.
    """
    @functools.wraps(func)
    async def wrapper(*args, **kwargs):
        # perf_counter() is monotonic, so the measured duration cannot be
        # skewed (or become negative) if the system clock is adjusted while
        # the coroutine is running.
        started = time.perf_counter()
        try:
            result = await func(*args, **kwargs)
        except Exception as e:
            elapsed = time.perf_counter() - started
            logger.error(f"{func.__name__} failed after {elapsed:.3f}s: {str(e)}")
            raise
        elapsed = time.perf_counter() - started
        logger.debug(f"{func.__name__} completed in {elapsed:.3f}s")
        return result

    return wrapper
|
|
|
def retry_async(max_attempts: int = 3, delay: float = 1.0, backoff: float = 2.0):
    """Decorator factory that retries an async function with exponential backoff.

    The wrapped coroutine is awaited up to ``max_attempts`` times. After each
    failed attempt except the last, the wrapper sleeps for the current delay
    and then multiplies that delay by ``backoff``. The exception raised by
    the final attempt is re-raised to the caller.

    Args:
        max_attempts: Total number of attempts (must be at least 1).
        delay: Seconds to sleep after the first failed attempt.
        backoff: Factor applied to the delay after every failed attempt.

    Returns:
        A decorator that wraps a coroutine function with the retry logic.

    Raises:
        ValueError: If ``max_attempts`` is less than 1. Without this check
            the wrapped function would never be called and the wrapper
            would silently return ``None``.
    """
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be at least 1, got {max_attempts}")

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        async def wrapper(*args, **kwargs):
            current_delay = delay

            for attempt in range(1, max_attempts + 1):
                try:
                    return await func(*args, **kwargs)
                except Exception as e:
                    if attempt == max_attempts:
                        logger.error(f"{func.__name__} failed after {max_attempts} attempts: {str(e)}")
                        raise

                    logger.warning(f"{func.__name__} attempt {attempt} failed: {str(e)}")
                    logger.info(f"Retrying in {current_delay}s...")

                # Only reached after a failed, non-final attempt: the success
                # path returns and the final failure re-raises above.
                await asyncio.sleep(current_delay)
                current_delay *= backoff

        return wrapper

    return decorator
|
|
|
class MCPToolResponse:
    """Uniform result envelope returned by MCP tools."""

    def __init__(self, success: bool, data: Any = None, error: str = None,
                 metadata: Dict[str, Any] = None):
        self.success = success
        self.data = data
        self.error = error
        # A falsy metadata argument (None or an empty dict) becomes a new dict.
        self.metadata = metadata if metadata else {}
        self.timestamp = time.time()

    def to_dict(self) -> Dict[str, Any]:
        """Serialize the response into a plain dictionary.

        The result always carries "success" and "timestamp"; it carries
        "data" for a successful response or "error" otherwise, and
        "metadata" only when the metadata is non-empty.
        """
        if self.success:
            payload_key, payload_value = "data", self.data
        else:
            payload_key, payload_value = "error", self.error

        serialized = {
            "success": self.success,
            "timestamp": self.timestamp,
            payload_key: payload_value,
        }

        if self.metadata:
            serialized["metadata"] = self.metadata

        return serialized

    @classmethod
    def success_response(cls, data: Any, metadata: Dict[str, Any] = None):
        """Build a response flagged as successful that carries ``data``."""
        return cls(success=True, data=data, metadata=metadata)

    @classmethod
    def error_response(cls, error: str, metadata: Dict[str, Any] = None):
        """Build a response flagged as failed that carries ``error``."""
        return cls(success=False, error=error, metadata=metadata)
|
|
|
def validate_required_params(params: Dict[str, Any], required: List[str]) -> Optional[str]:
    """Check that every required parameter is present and not None.

    Returns an error message naming the absent parameters, in the order
    they appear in ``required``, or None when nothing is missing.
    """
    absent = [name for name in required if params.get(name) is None]
    return f"Missing required parameters: {', '.join(absent)}" if absent else None
|
|
|
def sanitize_filename(filename: str) -> str:
    """Sanitize a filename for safe storage.

    Reserved characters (``< > : " / \\ | ? *``) and ASCII control characters
    (including NUL) are replaced with underscores, leading/trailing dots and
    spaces are stripped, and the result is limited to 255 characters while
    keeping the extension when it fits.

    Args:
        filename: The raw filename to clean.

    Returns:
        The sanitized name, or ``"unnamed_file"`` if nothing usable remains.
    """
    import re

    max_length = 255

    # Control characters (\x00-\x1f) are replaced as well: a NUL byte makes
    # open() raise and the others are rejected by some filesystems.
    filename = re.sub(r'[<>:"/\\|?*\x00-\x1f]', '_', filename)

    filename = filename.strip('. ')

    if len(filename) > max_length:
        name, ext = Path(filename).stem, Path(filename).suffix
        if len(ext) >= max_length:
            # The "extension" alone exceeds the limit, so it cannot be
            # preserved; fall back to a plain cut (a negative slice bound
            # would otherwise leave the result over-long).
            filename = filename[:max_length]
        else:
            filename = name[:max_length - len(ext)] + ext
        # Truncation may expose a trailing dot or space again.
        filename = filename.rstrip('. ')

    if not filename:
        filename = "unnamed_file"

    return filename
|
|
|
def truncate_text(text: str, max_length: int, add_ellipsis: bool = True) -> str:
    """Truncate text to at most ``max_length`` characters.

    Args:
        text: The text to shorten.
        max_length: Maximum length of the returned string.
        add_ellipsis: When True and ``max_length`` is greater than 3, the
            last three characters of the truncated result are ``"..."``.

    Returns:
        ``text`` unchanged if it already fits, otherwise a truncated copy.
        A ``max_length`` of zero or less yields an empty string.
    """
    # A negative limit would otherwise be read as a slice offset from the
    # end, returning almost the whole text instead of less of it.
    if max_length <= 0:
        return ""

    if len(text) <= max_length:
        return text

    if add_ellipsis and max_length > 3:
        return text[:max_length - 3] + "..."

    return text[:max_length]
|
|
|
def extract_file_info(file_path: str) -> Dict[str, Any]:
    """Extract information about a filesystem path.

    Args:
        file_path: Path of the file or directory to inspect.

    Returns:
        A dictionary with the keys "filename", "extension" (lower-cased),
        "size_bytes", "size_mb" (rounded to 2 decimals), "created_time",
        "modified_time", "exists", "is_file" and "is_dir". If the path cannot
        be inspected (for example it does not exist), a dictionary with a
        single "error" key holding the exception text is returned instead.
    """
    import stat as stat_module

    try:
        path = Path(file_path)
        info = path.stat()
    except Exception as e:
        # Deliberately broad: this helper reports problems in-band rather
        # than raising.
        return {"error": str(e)}

    # All fields are derived from the single stat() result, so they describe
    # the same filesystem state and need no further system calls.
    mode = info.st_mode
    return {
        "filename": path.name,
        "extension": path.suffix.lower(),
        "size_bytes": info.st_size,
        "size_mb": round(info.st_size / (1024 * 1024), 2),
        # st_ctime is the inode-change time on Unix; prefer the real creation
        # time where the platform exposes it.
        "created_time": getattr(info, "st_birthtime", info.st_ctime),
        "modified_time": info.st_mtime,
        # stat() succeeded, so the path exists.
        "exists": True,
        "is_file": stat_module.S_ISREG(mode),
        "is_dir": stat_module.S_ISDIR(mode)
    }
|
|
|
async def batch_process(items: List[Any], processor: Callable, batch_size: int = 10,
                        max_concurrent: int = 5) -> List[Any]:
    """Process items in batches with concurrency control.

    Items are handled one batch at a time; a batch must finish before the
    next one starts. Within a batch, at most ``max_concurrent`` calls to
    ``processor`` run at once.

    Args:
        items: The items to process.
        processor: Async callable invoked with a single item.
        batch_size: Number of items scheduled together (must be >= 1).
        max_concurrent: Maximum simultaneous ``processor`` calls (must be >= 1).

    Returns:
        One entry per item, in input order. An item whose processing raised
        contributes the exception object instead of a result.

    Raises:
        ValueError: If ``batch_size`` or ``max_concurrent`` is less than 1.
    """
    # A negative batch_size would silently skip every item, and a
    # max_concurrent of 0 would block forever on the semaphore.
    if batch_size < 1:
        raise ValueError(f"batch_size must be at least 1, got {batch_size}")
    if max_concurrent < 1:
        raise ValueError(f"max_concurrent must be at least 1, got {max_concurrent}")

    results = []
    semaphore = asyncio.Semaphore(max_concurrent)

    async def process_item(item):
        async with semaphore:
            return await processor(item)

    for start in range(0, len(items), batch_size):
        batch = items[start:start + batch_size]
        batch_results = await asyncio.gather(
            *(process_item(item) for item in batch),
            return_exceptions=True,
        )
        results.extend(batch_results)

    return results
|
|
|
def format_file_size(size_bytes: int) -> str:
    """Render a byte count as a string with one decimal and a unit.

    The value is divided by 1024 until it drops below 1024, stepping through
    B, KB, MB, GB and TB; anything still larger is reported in PB.
    """
    units = ('B', 'KB', 'MB', 'GB', 'TB')
    value = size_bytes
    position = 0
    while position < len(units):
        if value < 1024.0:
            return f"{value:.1f} {units[position]}"
        value = value / 1024.0
        position += 1
    return f"{value:.1f} PB"
|
|
|
def calculate_reading_time(text: str, words_per_minute: int = 200) -> int:
    """Estimate reading time in whole minutes, never less than one.

    Words are counted by splitting ``text`` on whitespace.
    """
    words = text.split()
    minutes = round(len(words) / words_per_minute)
    return minutes if minutes > 1 else 1
|
|
|
class ProgressTracker:
    """Track progress of long-running operations."""

    def __init__(self, total_items: int, description: str = "Processing"):
        """Start tracking ``total_items`` units of work.

        Args:
            total_items: Number of items expected to be processed.
            description: Label included in progress reports.
        """
        self.total_items = total_items
        self.completed_items = 0
        self.description = description
        self.start_time = time.time()
        self.errors: List[str] = []

    def update(self, completed: int = 1, error: Optional[str] = None):
        """Record progress.

        Args:
            completed: Number of items to add to the completed count. The
                count is increased whether or not ``error`` is given.
            error: Optional error message to remember for this update.
        """
        self.completed_items += completed
        if error:
            self.errors.append(error)

    def get_progress(self) -> Dict[str, Any]:
        """Return a snapshot of the current progress.

        Returns:
            A dictionary with the description, item counts, percentage
            complete, elapsed seconds, estimated remaining seconds (never
            negative), the number of recorded errors and the five most
            recent error messages.
        """
        elapsed_time = time.time() - self.start_time
        progress_percent = (self.completed_items / self.total_items) * 100 if self.total_items > 0 else 0

        if self.completed_items > 0:
            avg_time_per_item = elapsed_time / self.completed_items
            # Clamp at zero: if more items were reported than expected, the
            # estimate would otherwise come out as a negative duration.
            remaining_items = max(0, self.total_items - self.completed_items)
            estimated_remaining_time = avg_time_per_item * remaining_items
        else:
            # No throughput information yet, so no estimate can be made.
            estimated_remaining_time = 0

        return {
            "description": self.description,
            "total_items": self.total_items,
            "completed_items": self.completed_items,
            "progress_percent": round(progress_percent, 1),
            "elapsed_time_seconds": round(elapsed_time, 1),
            "estimated_remaining_seconds": round(estimated_remaining_time, 1),
            "errors_count": len(self.errors),
            "errors": self.errors[-5:]
        }

    def is_complete(self) -> bool:
        """Return True once at least ``total_items`` items are completed."""
        return self.completed_items >= self.total_items
|
|
|
def load_json_config(config_path: str, default_config: Dict[str, Any] = None) -> Dict[str, Any]:
    """Load configuration from a JSON file with fallback to defaults.

    Args:
        config_path: Path of the JSON file to read (decoded as UTF-8).
        default_config: Values to fall back to when the file is missing or
            does not contain valid JSON.

    Returns:
        The parsed file content, or a shallow copy of ``default_config``
        (an empty dict if none was given) when the file is missing or
        invalid. Other I/O errors propagate to the caller.
    """
    try:
        # JSON text is UTF-8; relying on the platform default encoding would
        # misread non-ASCII values on systems with a different locale.
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
    except FileNotFoundError:
        logger.warning(f"Configuration file {config_path} not found, using defaults")
    except json.JSONDecodeError as e:
        logger.error(f"Invalid JSON in configuration file {config_path}: {str(e)}")
    else:
        logger.info(f"Loaded configuration from {config_path}")
        return config

    # Hand back a copy so that changes the caller makes to the returned
    # configuration do not leak into the shared defaults object.
    return dict(default_config) if default_config else {}
|
|
|
def save_json_config(config: Dict[str, Any], config_path: str) -> bool:
    """Save configuration to a JSON file.

    Missing parent directories are created. The configuration is written as
    UTF-8 JSON with a two-space indent.

    Args:
        config: The configuration to serialize.
        config_path: Destination file path.

    Returns:
        True if the file was written, False if serialization or writing
        failed (the failure is logged).
    """
    try:
        # Serialize before touching the file: opening in 'w' mode truncates
        # it, so a non-serializable value would otherwise wipe out an
        # existing, valid configuration file.
        serialized = json.dumps(config, indent=2)

        Path(config_path).parent.mkdir(parents=True, exist_ok=True)

        with open(config_path, 'w', encoding='utf-8') as f:
            f.write(serialized)

        logger.info(f"Saved configuration to {config_path}")
        return True
    except Exception as e:
        logger.error(f"Failed to save configuration to {config_path}: {str(e)}")
        return False
|
|
|
class RateLimiter:
    """Simple sliding-window rate limiter for API calls."""

    def __init__(self, max_calls: int, time_window: float):
        """Create a limiter allowing ``max_calls`` per ``time_window`` seconds.

        Args:
            max_calls: Maximum number of calls permitted inside one window.
            time_window: Length of the window in seconds.

        Raises:
            ValueError: If ``max_calls`` is less than 1; such a limiter
                could never grant a call.
        """
        if max_calls < 1:
            raise ValueError(f"max_calls must be at least 1, got {max_calls}")

        self.max_calls = max_calls
        self.time_window = time_window
        # Timestamps (time.monotonic() readings) of the calls granted within
        # the current window, oldest first.
        self.calls: List[float] = []

    async def acquire(self):
        """Wait until a call is permitted, then record it.

        Returns immediately if fewer than ``max_calls`` calls were granted
        during the last ``time_window`` seconds; otherwise sleeps until the
        oldest recorded call leaves the window and checks again.
        """
        while True:
            # The monotonic clock is immune to system clock adjustments,
            # which could otherwise stretch or shorten the wait.
            now = time.monotonic()

            # Forget calls that have left the window.
            self.calls = [call_time for call_time in self.calls if now - call_time < self.time_window]

            if len(self.calls) < self.max_calls:
                self.calls.append(now)
                return

            # At capacity: wait for the oldest call to expire, then
            # re-evaluate, since other waiters may have taken the free slot.
            wait_time = self.time_window - (now - min(self.calls))
            await asyncio.sleep(max(wait_time, 0))
|
|
|
def escape_markdown(text: str) -> str:
    """Prefix each markdown special character in ``text`` with a backslash.

    The characters handled are: * _ ` [ ] ( ) # + - ! and the backslash.
    """
    import re

    special_chars = re.compile(r'([*_`\[\]()#+\-!\\])')
    escaped = special_chars.sub(r'\\\1', text)
    return escaped
|
|
|
def create_error_summary(errors: List[Exception]) -> str:
    """Summarize a list of exceptions as counts per exception class name.

    Classes are listed in order of first appearance; a count above one gets
    an "s" appended to the class name. An empty list gives "No errors".
    """
    if not errors:
        return "No errors"

    tally = {}
    for exc in errors:
        label = type(exc).__name__
        tally[label] = 1 + tally.get(label, 0)

    pieces = [
        f"1 {label}" if occurrences == 1 else f"{occurrences} {label}s"
        for label, occurrences in tally.items()
    ]

    return f"Encountered {len(errors)} total errors: " + ", ".join(pieces)
|
|
|
async def safe_execute(func: Callable, *args, default_return=None, **kwargs):
    """Execute a function and return a default value if it raises.

    Coroutine functions are awaited; any other callable is called directly.

    Args:
        func: The callable to run.
        *args: Positional arguments forwarded to ``func``.
        default_return: Value returned when ``func`` raises an ``Exception``.
        **kwargs: Keyword arguments forwarded to ``func``.

    Returns:
        The result of ``func``, or ``default_return`` if it raised (the
        error is logged).
    """
    try:
        if asyncio.iscoroutinefunction(func):
            return await func(*args, **kwargs)
        return func(*args, **kwargs)
    except Exception as e:
        # Not every callable has __name__ (functools.partial objects and
        # callable instances do not); reading it unguarded would raise from
        # inside this handler instead of returning the default.
        func_name = getattr(func, "__name__", None) or repr(func)
        logger.error(f"Error executing {func_name}: {str(e)}")
        return default_return
|
|
|
def get_content_preview(content: str, max_length: int = 200) -> str:
    """Get a short, single-line preview of content for display.

    Runs of whitespace are collapsed to single spaces. Content longer than
    ``max_length`` is cut, preferring (in this order) the last sentence end
    (``.``, ``!`` or ``?``) or the last space, provided that break lies
    beyond 70% of ``max_length``; otherwise it is cut hard at ``max_length``.
    Cuts that do not end on a sentence get ``"..."`` appended, so the result
    can be up to three characters longer than ``max_length``.

    Args:
        content: The text to preview.
        max_length: Number of characters after which the content is cut.

    Returns:
        The preview, or ``"No content"`` when ``content`` is empty or
        consists only of whitespace.
    """
    if not content:
        return "No content"

    content = ' '.join(content.split())

    # Whitespace-only input collapses to an empty string; report it the same
    # way as empty input instead of returning a blank preview.
    if not content:
        return "No content"

    if len(content) <= max_length:
        return content

    preview = content[:max_length]
    # A break is only accepted if it keeps more than 70% of the allowed length.
    min_break = max_length * 0.7

    last_sentence_end = max(preview.rfind('.'), preview.rfind('!'), preview.rfind('?'))
    if last_sentence_end > min_break:
        return preview[:last_sentence_end + 1]

    last_space = preview.rfind(' ')
    if last_space > min_break:
        return preview[:last_space] + "..."

    return preview + "..."
|
|
|
class MemoryUsageTracker:
    """Measure how the process's memory usage changes over time."""

    def __init__(self):
        # Baseline against which later readings are compared.
        self.start_memory = self._get_memory_usage()

    def _get_memory_usage(self) -> float:
        """Return the resident set size of this process in MB.

        Falls back to 0.0 when psutil cannot be imported.
        """
        try:
            import psutil
            rss_bytes = psutil.Process().memory_info().rss
            return rss_bytes / 1024 / 1024
        except ImportError:
            return 0.0

    def get_usage_delta(self) -> float:
        """Return the current reading minus the baseline taken at creation."""
        return self._get_memory_usage() - self.start_memory

    def log_usage(self, operation_name: str):
        """Log the memory delta at INFO level, labelled with the operation."""
        change_mb = self.get_usage_delta()
        logger.info(f"{operation_name} memory delta: {change_mb:.1f} MB")