""" | |
Utility Functions | |
Helper functions for file handling, validation, progress tracking, | |
and system management for the knowledge distillation application. | |
""" | |
import os
import sys
import logging
import hashlib
import mimetypes
import time
import json
from typing import Dict, Any, Optional, Union
from pathlib import Path
from datetime import datetime

import psutil
import torch
from fastapi import UploadFile

# Configure logging
def setup_logging(level: str = "INFO", log_file: Optional[str] = None) -> None:
    """
    Set up application logging.

    Args:
        level: Logging level (DEBUG, INFO, WARNING, ERROR)
        log_file: Optional log file path
    """
    log_level = getattr(logging, level.upper(), logging.INFO)

    # Shared logging format
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )

    handlers = []

    # Console handler (always available)
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    handlers.append(console_handler)

    # File handler (only if the log location is writable)
    try:
        # Create the logs directory if it doesn't exist
        logs_dir = Path("logs")
        logs_dir.mkdir(exist_ok=True)

        if log_file is None:
            log_file = f"logs/app_{datetime.now().strftime('%Y%m%d')}.log"

        # Verify the log file is writable before attaching the handler
        test_file = Path(log_file)
        test_file.touch()

        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        handlers.append(file_handler)
    except (PermissionError, OSError) as e:
        # Fall back to console-only logging
        print(f"Warning: Cannot write to log file ({e}), using console logging only")

    # Configure the root logger
    logging.basicConfig(
        level=log_level,
        handlers=handlers,
        force=True
    )

    logger = logging.getLogger(__name__)
    logger.info(f"Logging initialized with level: {level}")

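
# Illustrative usage (a sketch, not part of the module's API; assumes the
# process can create a local "logs/" directory, otherwise the function
# falls back to console-only logging):
#
#     setup_logging(level="DEBUG")
#     logging.getLogger("distillation").debug("verbose logging enabled")
#
#     # Or direct output to an explicit file:
#     setup_logging(level="INFO", log_file="logs/distill.log")
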
def validate_file(file: UploadFile) -> Dict[str, Any]:
    """
    Validate an uploaded file for security and format compliance.

    Args:
        file: FastAPI UploadFile object

    Returns:
        Validation result dictionary
    """
    try:
        # File size limits (in bytes)
        MAX_FILE_SIZE = 5 * 1024 * 1024 * 1024  # 5 GB
        MIN_FILE_SIZE = 1024  # 1 KB

        # Allowed file extensions
        ALLOWED_EXTENSIONS = {
            '.pt', '.pth', '.bin', '.safetensors',
            '.onnx', '.h5', '.pkl', '.joblib'
        }

        # Allowed MIME types
        ALLOWED_MIME_TYPES = {
            'application/octet-stream',
            'application/x-pytorch',
            'application/x-pickle',
            'application/x-hdf5'
        }

        # Check the filename first; the checks below all depend on it
        if not file.filename or not _is_safe_filename(file.filename):
            return {
                'valid': False,
                'error': 'Invalid filename. Contains unsafe characters.'
            }

        # Check file size
        if hasattr(file, 'size') and file.size:
            if file.size > MAX_FILE_SIZE:
                return {
                    'valid': False,
                    'error': f'File too large. Maximum size: {MAX_FILE_SIZE // (1024**3)}GB'
                }
            if file.size < MIN_FILE_SIZE:
                return {
                    'valid': False,
                    'error': f'File too small. Minimum size: {MIN_FILE_SIZE} bytes'
                }

        # Check file extension
        file_extension = Path(file.filename).suffix.lower()
        if file_extension not in ALLOWED_EXTENSIONS:
            return {
                'valid': False,
                'error': f'Invalid file extension. Allowed: {", ".join(sorted(ALLOWED_EXTENSIONS))}'
            }

        # Check MIME type; binary uploads commonly arrive as
        # application/octet-stream, so an unexpected type is logged
        # rather than rejected
        mime_type, _ = mimetypes.guess_type(file.filename)
        if mime_type and mime_type not in ALLOWED_MIME_TYPES:
            logging.warning(f"Unexpected MIME type: {mime_type} for {file.filename}")

        return {
            'valid': True,
            'extension': file_extension,
            'mime_type': mime_type,
            'size': getattr(file, 'size', None)
        }

    except Exception as e:
        return {
            'valid': False,
            'error': f'Validation error: {str(e)}'
        }

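
# Illustrative usage inside a hypothetical FastAPI route; the endpoint,
# the app object, and the HTTPException import are assumptions, not part
# of this module:
#
#     @app.post("/upload")
#     async def upload_model(file: UploadFile):
#         result = validate_file(file)
#         if not result['valid']:
#             raise HTTPException(status_code=400, detail=result['error'])
#         # persist the file only after validation passes
#         ...
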
def _is_safe_filename(filename: str) -> bool:
    """Check that a filename is safe (no path traversal, null bytes, etc.)."""
    if not filename:
        return False

    # Reject path traversal attempts and path separators
    if '..' in filename or '/' in filename or '\\' in filename:
        return False

    # Reject null bytes and other control characters
    if any(ord(c) < 32 for c in filename):
        return False

    return True

def get_system_info() -> Dict[str, Any]:
    """
    Get system information for monitoring and debugging.

    Returns:
        System information dictionary
    """
    try:
        # CPU information
        cpu_info = {
            'count': psutil.cpu_count(),
            'usage_percent': psutil.cpu_percent(interval=1),
            'frequency': psutil.cpu_freq()._asdict() if psutil.cpu_freq() else None
        }

        # Memory information
        memory = psutil.virtual_memory()
        memory_info = {
            'total_gb': round(memory.total / (1024**3), 2),
            'available_gb': round(memory.available / (1024**3), 2),
            'used_gb': round(memory.used / (1024**3), 2),
            'percent': memory.percent
        }

        # Disk information
        disk = psutil.disk_usage('/')
        disk_info = {
            'total_gb': round(disk.total / (1024**3), 2),
            'free_gb': round(disk.free / (1024**3), 2),
            'used_gb': round(disk.used / (1024**3), 2),
            'percent': round((disk.used / disk.total) * 100, 2)
        }

        # GPU information
        if torch.cuda.is_available():
            gpu_info = {
                'available': True,
                'count': torch.cuda.device_count(),
                'current_device': torch.cuda.current_device(),
                'device_name': torch.cuda.get_device_name(),
                'memory_allocated_gb': round(torch.cuda.memory_allocated() / (1024**3), 2),
                'memory_reserved_gb': round(torch.cuda.memory_reserved() / (1024**3), 2)
            }
        else:
            gpu_info = {'available': False}

        return {
            'cpu': cpu_info,
            'memory': memory_info,
            'disk': disk_info,
            'gpu': gpu_info,
            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
            'platform': os.name
        }

    except Exception as e:
        logging.error(f"Error getting system info: {e}")
        return {'error': str(e)}

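
# Illustrative usage: the returned dictionary contains only plain types,
# so it can be serialized directly, e.g. for a health endpoint (the key
# access below mirrors the structure built above):
#
#     info = get_system_info()
#     print(json.dumps(info, indent=2))
#     if info.get('gpu', {}).get('available'):
#         print(f"GPU: {info['gpu']['device_name']}")
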
def cleanup_temp_files(max_age_hours: int = 24) -> Dict[str, Any]:
    """
    Clean up temporary files older than the specified age.

    Args:
        max_age_hours: Maximum age of files to keep (in hours)

    Returns:
        Cleanup statistics
    """
    try:
        cleanup_stats = {
            'files_removed': 0,
            'bytes_freed': 0,
            'directories_cleaned': []
        }

        cutoff_time = time.time() - (max_age_hours * 3600)

        # Directories to clean
        temp_dirs = ['temp', 'uploads']

        for dir_name in temp_dirs:
            dir_path = Path(dir_name)
            if not dir_path.exists():
                continue

            files_removed = 0
            bytes_freed = 0

            for file_path in dir_path.rglob('*'):
                if file_path.is_file():
                    try:
                        # Stat once, then remove if older than the cutoff
                        stat_result = file_path.stat()
                        if stat_result.st_mtime < cutoff_time:
                            file_path.unlink()
                            files_removed += 1
                            bytes_freed += stat_result.st_size
                    except Exception as e:
                        logging.warning(f"Error removing file {file_path}: {e}")

            if files_removed > 0:
                cleanup_stats['directories_cleaned'].append({
                    'directory': str(dir_path),
                    'files_removed': files_removed,
                    'bytes_freed': bytes_freed
                })
                cleanup_stats['files_removed'] += files_removed
                cleanup_stats['bytes_freed'] += bytes_freed

        logging.info(f"Cleanup completed: {cleanup_stats['files_removed']} files removed, "
                     f"{cleanup_stats['bytes_freed'] / (1024**2):.2f} MB freed")
        return cleanup_stats

    except Exception as e:
        logging.error(f"Error during cleanup: {e}")
        return {'error': str(e)}

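
# Illustrative usage (the 12-hour threshold is an arbitrary example
# value; format_bytes is defined later in this module):
#
#     stats = cleanup_temp_files(max_age_hours=12)
#     if 'error' not in stats:
#         print(f"Freed {format_bytes(stats['bytes_freed'])} "
#               f"across {stats['files_removed']} files")
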
def calculate_file_hash(file_path: Union[str, Path], algorithm: str = 'sha256') -> str:
    """
    Calculate the hash of a file.

    Args:
        file_path: Path to the file
        algorithm: Hash algorithm (md5, sha1, sha256, etc.)

    Returns:
        Hexadecimal hash string
    """
    try:
        hash_obj = hashlib.new(algorithm)
        with open(file_path, 'rb') as f:
            # Read in chunks so large model files are never loaded
            # into memory at once
            for chunk in iter(lambda: f.read(8192), b""):
                hash_obj.update(chunk)
        return hash_obj.hexdigest()
    except Exception as e:
        logging.error(f"Error calculating hash for {file_path}: {e}")
        raise

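
# Illustrative usage: content hashes make a convenient deduplication or
# integrity key for uploaded checkpoints (the path is an assumption):
#
#     digest = calculate_file_hash("uploads/teacher_model.pt")
#     md5_digest = calculate_file_hash("uploads/teacher_model.pt",
#                                      algorithm="md5")
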
def format_bytes(bytes_value: Union[int, float]) -> str:
    """
    Format a byte count into a human-readable string.

    Args:
        bytes_value: Number of bytes

    Returns:
        Formatted string (e.g., "1.5 GB")
    """
    for unit in ['B', 'KB', 'MB', 'GB', 'TB']:
        if bytes_value < 1024.0:
            return f"{bytes_value:.1f} {unit}"
        bytes_value /= 1024.0
    return f"{bytes_value:.1f} PB"

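
# Expected behavior, doctest-style:
#
#     >>> format_bytes(1536)
#     '1.5 KB'
#     >>> format_bytes(5 * 1024**3)
#     '5.0 GB'
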
def format_duration(seconds: float) -> str:
    """
    Format a duration in seconds into a human-readable string.

    Args:
        seconds: Duration in seconds

    Returns:
        Formatted string (e.g., "2h 30m 15s")
    """
    if seconds < 60:
        return f"{seconds:.1f}s"
    elif seconds < 3600:
        minutes = int(seconds // 60)
        secs = int(seconds % 60)
        return f"{minutes}m {secs}s"
    else:
        hours = int(seconds // 3600)
        minutes = int((seconds % 3600) // 60)
        secs = int(seconds % 60)
        return f"{hours}h {minutes}m {secs}s"

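
# Expected behavior, doctest-style:
#
#     >>> format_duration(42.5)
#     '42.5s'
#     >>> format_duration(9296)
#     '2h 34m 56s'
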
def create_progress_tracker():
    """
    Create a progress tracking utility.

    Returns:
        Progress tracker instance
    """
    class ProgressTracker:
        def __init__(self):
            self.start_time = time.time()
            self.last_update = self.start_time
            self.steps_completed = 0
            self.total_steps = 0

        def update(self, current_step: int, total_steps: int, message: str = "") -> Dict[str, Any]:
            self.steps_completed = current_step
            self.total_steps = total_steps
            self.last_update = time.time()

            # Progress metrics: fraction done, elapsed time, and a simple
            # linear-extrapolation ETA
            progress = current_step / total_steps if total_steps > 0 else 0
            elapsed = self.last_update - self.start_time

            if progress > 0:
                eta = (elapsed / progress) * (1 - progress)
                eta_str = format_duration(eta)
            else:
                eta_str = "Unknown"

            return {
                'progress': progress,
                'current_step': current_step,
                'total_steps': total_steps,
                'elapsed': format_duration(elapsed),
                'eta': eta_str,
                'message': message
            }

    return ProgressTracker()

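
# Illustrative usage inside a training loop (total_steps and the loop
# body are assumptions); the ETA comes from linearly extrapolating the
# elapsed time:
#
#     tracker = create_progress_tracker()
#     for step in range(1, total_steps + 1):
#         ...  # one distillation step
#         status = tracker.update(step, total_steps, message="distilling")
#         print(f"{status['progress']:.0%} done, ETA {status['eta']}")
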
def safe_json_load(file_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
    """
    Safely load a JSON file with error handling.

    Args:
        file_path: Path to JSON file

    Returns:
        Loaded JSON data, or None on error
    """
    try:
        with open(file_path, 'r', encoding='utf-8') as f:
            return json.load(f)
    except Exception as e:
        logging.warning(f"Error loading JSON from {file_path}: {e}")
        return None

def safe_json_save(data: Dict[str, Any], file_path: Union[str, Path]) -> bool:
    """
    Safely save data to a JSON file.

    Args:
        data: Data to save
        file_path: Path to save file

    Returns:
        True if successful, False otherwise
    """
    try:
        # Ensure the parent directory exists
        Path(file_path).parent.mkdir(parents=True, exist_ok=True)

        with open(file_path, 'w', encoding='utf-8') as f:
            json.dump(data, f, indent=2, ensure_ascii=False)
        return True
    except Exception as e:
        logging.error(f"Error saving JSON to {file_path}: {e}")
        return False

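
# Illustrative round-trip (the config path and keys are assumptions):
#
#     safe_json_save({'temperature': 4.0, 'alpha': 0.7},
#                    'configs/distill.json')
#     config = safe_json_load('configs/distill.json')
#     if config is None:
#         config = {}  # fall back to defaults on a missing/corrupt file
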
def get_available_memory() -> float:
    """
    Get available system memory in GB.

    Returns:
        Available memory in GB
    """
    try:
        memory = psutil.virtual_memory()
        return memory.available / (1024**3)
    except Exception:
        return 0.0

def check_disk_space(path: str = ".", min_gb: float = 1.0) -> bool:
    """
    Check whether there is enough free disk space.

    Args:
        path: Path to check
        min_gb: Minimum required space in GB

    Returns:
        True if enough space is available
    """
    try:
        disk = psutil.disk_usage(path)
        free_gb = disk.free / (1024**3)
        return free_gb >= min_gb
    except Exception:
        return False

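
# Illustrative pre-flight check: gate a heavy job on resources before
# starting it (the 8 GB and 10 GB thresholds are assumptions):
#
#     if get_available_memory() < 8.0:
#         raise RuntimeError("Not enough free RAM for distillation")
#     if not check_disk_space(".", min_gb=10.0):
#         raise RuntimeError("Not enough free disk space for outputs")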