""" | |
Database manager for the AI Knowledge Distillation Platform | |
""" | |
import sqlite3 | |
import logging | |
from pathlib import Path | |
from typing import Dict, Any, List, Optional | |
from datetime import datetime | |
logger = logging.getLogger(__name__) | |
class DatabaseManager:
    """
    Centralized database manager for all platform data.

    Maintains four separate SQLite files under ``db_dir`` — tokens,
    training sessions, performance metrics and medical datasets — and
    creates every schema up front so callers can query immediately.
    """

    def __init__(self, db_dir: str = "database"):
        """
        Initialize database manager.

        Args:
            db_dir: Directory for database files (created if missing).
        """
        self.db_dir = Path(db_dir)
        self.db_dir.mkdir(parents=True, exist_ok=True)

        # Database file paths — one file per logical domain.
        self.tokens_db = self.db_dir / "tokens.db"
        self.training_db = self.db_dir / "training_sessions.db"
        self.performance_db = self.db_dir / "performance_metrics.db"
        self.medical_db = self.db_dir / "medical_datasets.db"

        # Initialize all databases
        self._init_all_databases()
        logger.info("Database Manager initialized")

    def _init_all_databases(self):
        """Initialize all database schemas (idempotent: CREATE IF NOT EXISTS)."""
        self._init_tokens_database()
        self._init_training_database()
        self._init_performance_database()
        self._init_medical_database()

    def _init_tokens_database(self):
        """Initialize tokens database (token store + usage audit log)."""
        with sqlite3.connect(self.tokens_db) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS tokens (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    name TEXT UNIQUE NOT NULL,
                    token_type TEXT NOT NULL,
                    encrypted_token TEXT NOT NULL,
                    is_default BOOLEAN DEFAULT FALSE,
                    description TEXT,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    last_used TIMESTAMP,
                    usage_count INTEGER DEFAULT 0,
                    is_active BOOLEAN DEFAULT TRUE
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS token_usage_log (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    token_name TEXT NOT NULL,
                    operation TEXT NOT NULL,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    success BOOLEAN,
                    error_message TEXT
                )
            ''')
            conn.commit()
        # `with sqlite3.connect(...)` only manages the transaction; close the handle.
        conn.close()

    def _init_training_database(self):
        """Initialize training sessions database (sessions + per-step logs)."""
        with sqlite3.connect(self.training_db) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS training_sessions (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT UNIQUE NOT NULL,
                    teacher_model TEXT NOT NULL,
                    student_model TEXT NOT NULL,
                    dataset_name TEXT,
                    training_type TEXT NOT NULL,
                    status TEXT DEFAULT 'initialized',
                    progress REAL DEFAULT 0.0,
                    current_step INTEGER DEFAULT 0,
                    total_steps INTEGER,
                    current_loss REAL,
                    best_loss REAL,
                    learning_rate REAL,
                    batch_size INTEGER,
                    temperature REAL,
                    alpha REAL,
                    created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    started_at TIMESTAMP,
                    completed_at TIMESTAMP,
                    error_message TEXT,
                    config_json TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS training_logs (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    session_id TEXT NOT NULL,
                    step INTEGER NOT NULL,
                    loss REAL,
                    learning_rate REAL,
                    memory_usage_mb REAL,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    additional_metrics TEXT
                )
            ''')
            conn.commit()
        conn.close()

    def _init_performance_database(self):
        """Initialize performance metrics database (system + per-model metrics)."""
        with sqlite3.connect(self.performance_db) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS system_metrics (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    cpu_usage_percent REAL,
                    memory_usage_mb REAL,
                    memory_usage_percent REAL,
                    available_memory_gb REAL,
                    disk_usage_percent REAL,
                    temperature_celsius REAL
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS model_performance (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    model_name TEXT NOT NULL,
                    operation TEXT NOT NULL,
                    duration_seconds REAL,
                    memory_peak_mb REAL,
                    throughput_samples_per_second REAL,
                    accuracy REAL,
                    timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP,
                    additional_metrics TEXT
                )
            ''')
            conn.commit()
        conn.close()

    def _init_medical_database(self):
        """Initialize medical datasets database (dataset catalog + DICOM files)."""
        with sqlite3.connect(self.medical_db) as conn:
            conn.execute('''
                CREATE TABLE IF NOT EXISTS medical_datasets (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    dataset_name TEXT UNIQUE NOT NULL,
                    repo_id TEXT NOT NULL,
                    description TEXT,
                    size_gb REAL,
                    num_samples INTEGER,
                    modalities TEXT,
                    specialties TEXT,
                    languages TEXT,
                    last_accessed TIMESTAMP,
                    access_count INTEGER DEFAULT 0,
                    is_cached BOOLEAN DEFAULT FALSE,
                    cache_path TEXT,
                    metadata_json TEXT
                )
            ''')
            conn.execute('''
                CREATE TABLE IF NOT EXISTS dicom_files (
                    id INTEGER PRIMARY KEY AUTOINCREMENT,
                    file_path TEXT UNIQUE NOT NULL,
                    patient_id TEXT,
                    study_date TEXT,
                    modality TEXT,
                    file_size_mb REAL,
                    processed BOOLEAN DEFAULT FALSE,
                    processed_at TIMESTAMP,
                    metadata_json TEXT
                )
            ''')
            conn.commit()
        conn.close()

    def get_connection(self, db_name: str) -> sqlite3.Connection:
        """
        Get a new connection to one of the managed databases.

        Args:
            db_name: One of 'tokens', 'training', 'performance', 'medical'.

        Raises:
            ValueError: If ``db_name`` is not a known database.
        """
        db_map = {
            'tokens': self.tokens_db,
            'training': self.training_db,
            'performance': self.performance_db,
            'medical': self.medical_db
        }
        if db_name not in db_map:
            raise ValueError(f"Unknown database: {db_name}")
        return sqlite3.connect(db_map[db_name])

    def execute_query(self, db_name: str, query: str, params: tuple = ()) -> List[tuple]:
        """Execute a query and return all result rows."""
        with self.get_connection(db_name) as conn:
            rows = conn.execute(query, params).fetchall()
        conn.close()
        return rows

    def execute_update(self, db_name: str, query: str, params: tuple = ()) -> int:
        """Execute an update query, commit, and return the affected row count."""
        with self.get_connection(db_name) as conn:
            cursor = conn.execute(query, params)
            conn.commit()
            affected = cursor.rowcount
        conn.close()
        return affected

    def backup_databases(self, backup_dir: str = "backups") -> Dict[str, str]:
        """
        Create a timestamped file copy of every existing database.

        Returns:
            Mapping of database name -> backup file path (only databases
            whose files exist are included).
        """
        import shutil  # local import kept; hoisted out of the copy loop

        backup_path = Path(backup_dir)
        backup_path.mkdir(parents=True, exist_ok=True)
        timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
        backup_files = {}
        db_files = {
            'tokens': self.tokens_db,
            'training': self.training_db,
            'performance': self.performance_db,
            'medical': self.medical_db
        }
        for db_name, db_file in db_files.items():
            if db_file.exists():
                backup_file = backup_path / f"{db_name}_{timestamp}.db"
                # copy2 preserves file metadata alongside contents.
                shutil.copy2(db_file, backup_file)
                backup_files[db_name] = str(backup_file)
                logger.info(f"Backed up {db_name} database to {backup_file}")
        return backup_files

    def get_database_stats(self) -> Dict[str, Any]:
        """
        Get per-database statistics: file size, per-table row counts,
        and total record count. Databases whose file does not exist are
        reported with status 'not_created'; per-database errors are
        captured rather than raised.
        """
        stats = {}
        db_files = {
            'tokens': self.tokens_db,
            'training': self.training_db,
            'performance': self.performance_db,
            'medical': self.medical_db
        }
        for db_name, db_file in db_files.items():
            if not db_file.exists():
                stats[db_name] = {
                    'file_size_mb': 0,
                    'status': 'not_created'
                }
                continue
            file_size_mb = db_file.stat().st_size / (1024**2)
            try:
                with self.get_connection(db_name) as conn:
                    cursor = conn.execute(
                        "SELECT name FROM sqlite_master WHERE type='table'"
                    )
                    tables = [row[0] for row in cursor.fetchall()]
                    table_counts = {}
                    for table in tables:
                        # Identifiers cannot be bound as parameters; quote the
                        # name (it comes from sqlite_master, not user input).
                        cursor = conn.execute(f'SELECT COUNT(*) FROM "{table}"')
                        table_counts[table] = cursor.fetchone()[0]
                conn.close()
                stats[db_name] = {
                    'file_size_mb': file_size_mb,
                    'tables': table_counts,
                    'total_records': sum(table_counts.values())
                }
            except Exception as e:
                stats[db_name] = {
                    'file_size_mb': file_size_mb,
                    'error': str(e)
                }
        return stats

    def cleanup_old_data(self, days_to_keep: int = 30) -> Dict[str, int]:
        """
        Delete log/metric rows older than ``days_to_keep`` days.

        The ``timestamp`` columns are TEXT values written by SQLite's
        CURRENT_TIMESTAMP (UTC, 'YYYY-MM-DD HH:MM:SS'), so the cutoff is
        computed in SQL via datetime('now', '-N days') to compare like
        with like. (Comparing the column against a Unix-epoch float is
        always false under SQLite's type ordering — REAL sorts before
        TEXT — so nothing would ever be deleted.)

        Returns:
            Mapping of table name -> rows deleted; an 'error' key holds
            the message if cleanup failed partway.
        """
        modifier = f"-{int(days_to_keep)} days"
        cleanup_stats: Dict[str, int] = {}
        targets = [
            ('performance', 'system_metrics'),
            ('training', 'training_logs'),
            ('tokens', 'token_usage_log'),
        ]
        try:
            for db_name, table in targets:
                with self.get_connection(db_name) as conn:
                    cursor = conn.execute(
                        f"DELETE FROM {table} WHERE timestamp < datetime('now', ?)",
                        (modifier,)
                    )
                    cleanup_stats[table] = cursor.rowcount
                    conn.commit()
                conn.close()
            logger.info(f"Cleaned up old data: {cleanup_stats}")
        except Exception as e:
            logger.error(f"Error cleaning up old data: {e}")
            cleanup_stats['error'] = str(e)
        return cleanup_stats