# train-model / src/utils.py — Multi-Modal Knowledge Distillation Platform
# Initial clean commit by fokan (ab4e093)
"""
Utility Functions
Helper functions for file handling, validation, progress tracking,
and system management for the knowledge distillation application.
"""
import asyncio
import hashlib
import json
import logging
import mimetypes
import os
import shutil
import sys
import tempfile
import time
from datetime import datetime, timedelta
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import psutil
import torch
from fastapi import UploadFile
# Configure logging
def setup_logging(level: str = "INFO", log_file: Optional[str] = None) -> None:
    """
    Set up application logging with console output and, when writable, a file.

    Args:
        level: Logging level name (DEBUG, INFO, WARNING, ERROR). Unknown
            names silently fall back to INFO.
        log_file: Optional log file path. Defaults to logs/app_YYYYMMDD.log.
    """
    log_level = getattr(logging, level.upper(), logging.INFO)
    formatter = logging.Formatter(
        '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
    )
    # Console handler is always available
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    handlers: List[logging.Handler] = [console_handler]
    # File handler only when the logs directory/file is writable
    try:
        logs_dir = Path("logs")
        logs_dir.mkdir(exist_ok=True)
        if log_file is None:
            # Build the default path via pathlib rather than string concatenation
            log_file = str(logs_dir / f"app_{datetime.now().strftime('%Y%m%d')}.log")
        # Probe writability before attaching the handler
        Path(log_file).touch()
        file_handler = logging.FileHandler(log_file)
        file_handler.setFormatter(formatter)
        handlers.append(file_handler)
    except (PermissionError, OSError):
        # Best-effort: fall back to console-only logging
        print("Warning: Cannot write to log file, using console logging only")
    # force=True replaces any handlers configured by earlier calls
    logging.basicConfig(
        level=log_level,
        handlers=handlers,
        force=True
    )
    logger = logging.getLogger(__name__)
    # Lazy %-style args avoid formatting when the level filters the record
    logger.info("Logging initialized with level: %s", level)
def validate_file(file: UploadFile) -> Dict[str, Any]:
    """
    Validate an uploaded model file for security and format compliance.

    Checks filename safety, size bounds (when the framework reports a size),
    extension whitelist, and MIME type.

    Args:
        file: FastAPI UploadFile object

    Returns:
        Dict with 'valid' (bool); on success also 'extension', 'mime_type',
        and 'size'; on failure an 'error' message.
    """
    try:
        # File size limits (in bytes)
        MAX_FILE_SIZE = 5 * 1024 * 1024 * 1024  # 5GB
        MIN_FILE_SIZE = 1024  # 1KB
        # Allowed file extensions
        ALLOWED_EXTENSIONS = {
            '.pt', '.pth', '.bin', '.safetensors',
            '.onnx', '.h5', '.pkl', '.joblib'
        }
        # Allowed MIME types
        ALLOWED_MIME_TYPES = {
            'application/octet-stream',
            'application/x-pytorch',
            'application/x-pickle',
            'application/x-hdf5'
        }
        # Check the filename first: a missing or unsafe name makes every
        # later check meaningless, and Path(None) would raise TypeError
        if not file.filename or not _is_safe_filename(file.filename):
            return {
                'valid': False,
                'error': 'Invalid filename. Contains unsafe characters.'
            }
        # Check file size when the framework provides it
        if hasattr(file, 'size') and file.size:
            if file.size > MAX_FILE_SIZE:
                return {
                    'valid': False,
                    'error': f'File too large. Maximum size: {MAX_FILE_SIZE // (1024**3)}GB'
                }
            if file.size < MIN_FILE_SIZE:
                return {
                    'valid': False,
                    'error': f'File too small. Minimum size: {MIN_FILE_SIZE} bytes'
                }
        # Check file extension (sorted so the error message is deterministic)
        file_extension = Path(file.filename).suffix.lower()
        if file_extension not in ALLOWED_EXTENSIONS:
            return {
                'valid': False,
                'error': f'Invalid file extension. Allowed: {", ".join(sorted(ALLOWED_EXTENSIONS))}'
            }
        # Check MIME type; octet-stream is already whitelisted above, so any
        # mismatch here is merely logged, not rejected
        mime_type, _ = mimetypes.guess_type(file.filename)
        if mime_type and mime_type not in ALLOWED_MIME_TYPES:
            logging.warning(f"Unexpected MIME type: {mime_type} for {file.filename}")
        return {
            'valid': True,
            'extension': file_extension,
            'mime_type': mime_type,
            'size': getattr(file, 'size', None)
        }
    except Exception as e:
        return {
            'valid': False,
            'error': f'Validation error: {str(e)}'
        }
def _is_safe_filename(filename: str) -> bool:
"""Check if filename is safe (no path traversal, etc.)"""
if not filename:
return False
# Check for path traversal attempts
if '..' in filename or '/' in filename or '\\' in filename:
return False
# Check for null bytes
if '\x00' in filename:
return False
# Check for control characters
if any(ord(c) < 32 for c in filename):
return False
return True
def get_system_info() -> Dict[str, Any]:
    """
    Get system information for monitoring and debugging.

    Returns:
        Dict with 'cpu', 'memory', 'disk', 'gpu', 'python_version', and
        'platform' keys, or {'error': ...} if collection fails.
    """
    try:
        # CPU information; cpu_freq is queried once (it can return None on
        # some platforms, e.g. inside containers)
        freq = psutil.cpu_freq()
        cpu_info = {
            'count': psutil.cpu_count(),
            'usage_percent': psutil.cpu_percent(interval=1),
            'frequency': freq._asdict() if freq else None
        }
        # Memory information
        memory = psutil.virtual_memory()
        memory_info = {
            'total_gb': round(memory.total / (1024**3), 2),
            'available_gb': round(memory.available / (1024**3), 2),
            'used_gb': round(memory.used / (1024**3), 2),
            'percent': memory.percent
        }
        # Disk information for the root filesystem
        disk = psutil.disk_usage('/')
        disk_info = {
            'total_gb': round(disk.total / (1024**3), 2),
            'free_gb': round(disk.free / (1024**3), 2),
            'used_gb': round(disk.used / (1024**3), 2),
            'percent': round((disk.used / disk.total) * 100, 2)
        }
        # GPU information (CUDA only)
        if torch.cuda.is_available():
            gpu_info = {
                'available': True,
                'count': torch.cuda.device_count(),
                'current_device': torch.cuda.current_device(),
                'device_name': torch.cuda.get_device_name(),
                'memory_allocated_gb': round(torch.cuda.memory_allocated() / (1024**3), 2),
                'memory_reserved_gb': round(torch.cuda.memory_reserved() / (1024**3), 2)
            }
        else:
            gpu_info = {'available': False}
        return {
            'cpu': cpu_info,
            'memory': memory_info,
            'disk': disk_info,
            'gpu': gpu_info,
            # Use the real sys/os modules, not psutil's internal re-exports
            # (psutil.sys / psutil.os are accidental attributes, not API)
            'python_version': f"{sys.version_info.major}.{sys.version_info.minor}.{sys.version_info.micro}",
            'platform': os.name
        }
    except Exception as e:
        logging.error(f"Error getting system info: {e}")
        return {'error': str(e)}
def cleanup_temp_files(max_age_hours: int = 24) -> Dict[str, Any]:
    """
    Clean up temporary files older than the specified age.

    Scans the working-directory-relative 'temp' and 'uploads' trees and
    deletes regular files whose mtime is older than the cutoff.

    Args:
        max_age_hours: Maximum age of files to keep (in hours)

    Returns:
        Cleanup statistics (files_removed, bytes_freed, directories_cleaned),
        or {'error': ...} on unexpected failure.
    """
    try:
        cleanup_stats = {
            'files_removed': 0,
            'bytes_freed': 0,
            'directories_cleaned': []
        }
        cutoff_time = time.time() - (max_age_hours * 3600)
        # Directories to clean (relative to the current working directory)
        for dir_name in ('temp', 'uploads'):
            dir_path = Path(dir_name)
            if not dir_path.exists():
                continue
            files_removed = 0
            bytes_freed = 0
            for file_path in dir_path.rglob('*'):
                if not file_path.is_file():
                    continue
                try:
                    # Stat once: avoids a second syscall and a race between
                    # the age check and the size read
                    stat_result = file_path.stat()
                    if stat_result.st_mtime < cutoff_time:
                        file_path.unlink()
                        files_removed += 1
                        bytes_freed += stat_result.st_size
                except Exception as e:
                    logging.warning(f"Error removing file {file_path}: {e}")
            if files_removed > 0:
                cleanup_stats['directories_cleaned'].append({
                    'directory': str(dir_path),
                    'files_removed': files_removed,
                    'bytes_freed': bytes_freed
                })
                cleanup_stats['files_removed'] += files_removed
                cleanup_stats['bytes_freed'] += bytes_freed
        logging.info(f"Cleanup completed: {cleanup_stats['files_removed']} files removed, "
                     f"{cleanup_stats['bytes_freed'] / (1024**2):.2f} MB freed")
        return cleanup_stats
    except Exception as e:
        logging.error(f"Error during cleanup: {e}")
        return {'error': str(e)}
def calculate_file_hash(file_path: Union[str, Path], algorithm: str = 'sha256') -> str:
    """
    Compute the hash digest of a file, reading it in 8 KiB chunks.

    Args:
        file_path: Path to the file
        algorithm: Any algorithm name accepted by hashlib (md5, sha1, sha256, ...)

    Returns:
        Hexadecimal digest string

    Raises:
        Re-raises any I/O or hashlib error after logging it.
    """
    try:
        digest = hashlib.new(algorithm)
        with open(file_path, 'rb') as fh:
            while True:
                chunk = fh.read(8192)
                if not chunk:
                    break
                digest.update(chunk)
        return digest.hexdigest()
    except Exception as e:
        logging.error(f"Error calculating hash for {file_path}: {e}")
        raise
def format_bytes(bytes_value: int) -> str:
    """
    Format a byte count as a human-readable string.

    Args:
        bytes_value: Number of bytes

    Returns:
        Formatted string with one decimal place (e.g. "1.5 GB")
    """
    suffixes = ('B', 'KB', 'MB', 'GB', 'TB')
    value = float(bytes_value)
    idx = 0
    # Scale down by 1024 until the value fits the current unit
    while value >= 1024.0 and idx < len(suffixes):
        value /= 1024.0
        idx += 1
    if idx < len(suffixes):
        return f"{value:.1f} {suffixes[idx]}"
    # Anything past TB is reported in petabytes
    return f"{value:.1f} PB"
def format_duration(seconds: float) -> str:
    """
    Format a duration in seconds as a human-readable string.

    Args:
        seconds: Duration in seconds

    Returns:
        "<s>.<d>s" under a minute, "Xm Ys" under an hour, else "Xh Ym Zs"
    """
    if seconds < 60:
        # Sub-minute durations keep one decimal of precision
        return f"{seconds:.1f}s"
    whole = int(seconds)
    if seconds < 3600:
        mins, secs = divmod(whole, 60)
        return f"{mins}m {secs}s"
    hours, remainder = divmod(whole, 3600)
    mins = remainder // 60
    secs = whole % 60
    return f"{hours}h {mins}m {secs}s"
def create_progress_tracker():
    """
    Create a progress tracking utility.

    Returns:
        A tracker object whose ``update(current_step, total_steps, message)``
        returns a dict with the progress ratio, elapsed time, and an ETA
        estimate projected from the average pace so far.
    """
    class ProgressTracker:
        def __init__(self):
            now = time.time()
            self.start_time = now
            self.last_update = now
            self.steps_completed = 0
            self.total_steps = 0

        def update(self, current_step: int, total_steps: int, message: str = ""):
            """Record progress and return a snapshot of the metrics."""
            self.steps_completed = current_step
            self.total_steps = total_steps
            self.last_update = time.time()
            fraction = current_step / total_steps if total_steps > 0 else 0
            elapsed = self.last_update - self.start_time
            # ETA assumes the remaining work proceeds at the average pace so far
            if fraction > 0:
                remaining = (elapsed / fraction) * (1 - fraction)
                eta_str = format_duration(remaining)
            else:
                eta_str = "Unknown"
            return {
                'progress': fraction,
                'current_step': current_step,
                'total_steps': total_steps,
                'elapsed': format_duration(elapsed),
                'eta': eta_str,
                'message': message
            }

    return ProgressTracker()
def safe_json_load(file_path: Union[str, Path]) -> Optional[Dict[str, Any]]:
    """
    Load a JSON file, returning None instead of raising on any failure.

    Args:
        file_path: Path to the JSON file

    Returns:
        Parsed JSON data, or None if the file is missing or malformed
    """
    try:
        raw = Path(file_path).read_text(encoding='utf-8')
        return json.loads(raw)
    except Exception as e:
        logging.warning(f"Error loading JSON from {file_path}: {e}")
        return None
def safe_json_save(data: Dict[str, Any], file_path: Union[str, Path]) -> bool:
    """
    Write data to a JSON file, creating parent directories as needed.

    Args:
        data: Data to serialize
        file_path: Destination file path

    Returns:
        True on success, False if serialization or writing failed
    """
    try:
        target = Path(file_path)
        # Make sure the destination directory exists before writing
        target.parent.mkdir(parents=True, exist_ok=True)
        with open(target, 'w', encoding='utf-8') as fh:
            json.dump(data, fh, indent=2, ensure_ascii=False)
    except Exception as e:
        logging.error(f"Error saving JSON to {file_path}: {e}")
        return False
    return True
def get_available_memory() -> float:
    """
    Get available system memory in GB.

    Returns:
        Available memory in GB, or 0.0 if the query fails
    """
    try:
        return psutil.virtual_memory().available / (1024 ** 3)
    except Exception:
        # Best-effort probe: report zero rather than propagate
        return 0.0
def check_disk_space(path: str = ".", min_gb: float = 1.0) -> bool:
    """
    Check whether a path has at least the required free disk space.

    Args:
        path: Filesystem path to check
        min_gb: Minimum required free space in GB

    Returns:
        True if enough space is available, False on shortage or query failure
    """
    try:
        usage = psutil.disk_usage(path)
    except Exception:
        # Unreadable/invalid path: treat as insufficient space
        return False
    return (usage.free / (1024 ** 3)) >= min_gb