Spaces:
Sleeping
Sleeping
#!/usr/bin/env python3 | |
""" | |
ResearchMate Settings | |
Centralized configuration management for ResearchMate | |
""" | |
import os | |
import json | |
import logging | |
from pathlib import Path | |
from typing import Dict, Any, Optional, List | |
from dataclasses import dataclass, asdict, field | |
from dotenv import load_dotenv | |
# Load environment variables from a .env file (python-dotenv) before any
# setting is read; the Settings class below reads its overrides via os.getenv.
load_dotenv()
@dataclass
class ServerConfig:
    """Server configuration settings.

    Plain data holder for the web server's bind address and runtime flags.
    Declared as a dataclass so every instance owns its field values and the
    section can be serialized with ``dataclasses.asdict``.
    """

    host: str = "0.0.0.0"  # bind address
    port: int = 7860  # HF Spaces default
    debug: bool = False
    reload: bool = False
    workers: int = 1
    log_level: str = "info"  # lowercase level name
@dataclass
class DatabaseConfig:
    """Database configuration settings.

    Holds the Chroma persistence directory, collection name, and retrieval
    defaults. Declared as a dataclass so the section can be serialized with
    ``dataclasses.asdict``.
    """

    chroma_persist_dir: str = "/tmp/researchmate/chroma_persist"
    collection_name: str = "research_documents"
    similarity_threshold: float = 0.7  # validated elsewhere to lie in [0.0, 1.0]
    max_results: int = 10
    embedding_model: str = "all-MiniLM-L6-v2"
@dataclass
class AIModelConfig:
    """AI model configuration settings.

    Holds the model name and generation parameters. Declared as a dataclass
    so the section can be serialized with ``dataclasses.asdict``.
    """

    model_name: str = "llama-3.3-70b-versatile"
    temperature: float = 0.7
    max_tokens: int = 4096
    top_p: float = 0.9
    frequency_penalty: float = 0.0
    presence_penalty: float = 0.0
    timeout: int = 30  # NOTE(review): presumably seconds; confirm against the caller
@dataclass
class UploadConfig:
    """File upload configuration settings.

    Declared as a dataclass: ``allowed_extensions`` relies on
    ``field(default_factory=...)``, which only produces a fresh list per
    instance when the class is processed by ``@dataclass``.
    """

    max_file_size: int = 50 * 1024 * 1024  # 50MB
    allowed_extensions: List[str] = field(
        default_factory=lambda: [".pdf", ".txt", ".md", ".docx", ".doc"]
    )
    upload_directory: str = "/tmp/researchmate/uploads"
    temp_directory: str = "/tmp/researchmate/tmp"
@dataclass
class SearchConfig:
    """Search configuration settings.

    Holds retrieval limits and text chunking sizes. Declared as a dataclass
    so the section can be serialized with ``dataclasses.asdict``.
    """

    max_results: int = 10
    similarity_threshold: float = 0.7
    enable_reranking: bool = True
    chunk_size: int = 1000
    chunk_overlap: int = 200
@dataclass
class SecurityConfig:
    """Security configuration settings.

    Declared as a dataclass: the CORS lists rely on
    ``field(default_factory=...)``, which only yields a fresh list per
    instance when the class is processed by ``@dataclass``.
    """

    # "*" is the wildcard default for origins, methods and headers.
    cors_origins: List[str] = field(default_factory=lambda: ["*"])
    cors_methods: List[str] = field(default_factory=lambda: ["*"])
    cors_headers: List[str] = field(default_factory=lambda: ["*"])
    rate_limit_enabled: bool = True
    rate_limit_requests: int = 100
    rate_limit_period: int = 60  # seconds
@dataclass
class LoggingConfig:
    """Logging configuration settings.

    Holds the log level, record format, and log-file options. Declared as a
    dataclass so the section can be serialized with ``dataclasses.asdict``.
    """

    level: str = "INFO"  # uppercase level name
    format: str = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
    file_enabled: bool = True
    file_path: str = "/tmp/researchmate/logs/app.log"
    max_file_size: int = 10 * 1024 * 1024  # 10MB
    backup_count: int = 5
    console_enabled: bool = True
class Settings:
    """Main settings class for ResearchMate.

    Builds one config object per section (server, database, ai_model, upload,
    search, security, logging) and layers values in this order: section
    defaults, the JSON config file (when it exists), environment variables.
    After loading, path settings are forced under ``/tmp/``, out-of-range
    values are reset to defaults, and the required directories are created.
    """

    # Attributes that hold config section objects. Only these names may be
    # populated from the JSON config file, so a stray key such as
    # "config_file" or a method name can never be treated as a section.
    _SECTIONS = (
        "server",
        "database",
        "ai_model",
        "upload",
        "search",
        "security",
        "logging",
    )

    def __init__(self, config_file: Optional[str] = None):
        """Load and validate the configuration.

        Args:
            config_file: Path to a JSON settings file. When omitted or empty,
                ``settings.json`` inside ``CONFIG_DIR`` is used.
        """
        self.config_file = config_file or self._get_default_config_file()
        self.project_root = Path(__file__).parent.parent

        # Initialize configuration objects with HF Spaces-safe defaults
        self.server = ServerConfig()
        self.database = DatabaseConfig()
        self.ai_model = AIModelConfig()
        self.upload = UploadConfig()
        self.search = SearchConfig()
        self.security = SecurityConfig()
        self.logging = LoggingConfig()

        # Load configuration
        self._load_config()
        self._validate_config()

    def _get_default_config_file(self) -> str:
        """Return the default config file path.

        The directory comes from the ``CONFIG_DIR`` environment variable and
        defaults to a writable location under ``/tmp`` (for HF Spaces).
        """
        config_dir = os.environ.get('CONFIG_DIR', '/tmp/researchmate/config')
        return str(Path(config_dir) / "settings.json")

    def _load_config(self):
        """Load configuration from the config file, then from the environment.

        A missing, unreadable or malformed file is not fatal: a warning is
        logged and the defaults stay in place. Environment variables are
        applied afterwards and take precedence over file values.
        """
        config_path = Path(self.config_file)
        if config_path.exists():
            try:
                with open(config_path, 'r', encoding='utf-8') as f:
                    config_data = json.load(f)
                self._apply_config_data(config_data)
            except Exception as e:
                # Best effort on purpose: a broken config file must not stop
                # the settings object from being constructed.
                logging.warning(f"Failed to load config file: {e}")

        # Override with environment variables
        self._load_from_env()

    def _apply_config_data(self, config_data: Dict[str, Any]):
        """Copy values from parsed config data onto the section objects.

        Only known sections whose value is a mapping are applied, and only
        keys that already exist on the section object are set. Anything else
        is skipped so that one malformed entry cannot discard the rest of
        the file.

        Args:
            config_data: Parsed JSON, expected as ``{section: {key: value}}``.
        """
        if not isinstance(config_data, dict):
            logging.warning("Ignoring config file: top-level JSON value is not an object")
            return

        for section, data in config_data.items():
            if section not in self._SECTIONS or not isinstance(data, dict):
                continue
            section_obj = getattr(self, section, None)
            if section_obj is None:
                continue
            for key, value in data.items():
                if hasattr(section_obj, key):
                    setattr(section_obj, key, value)

    @staticmethod
    def _env_number(name: str, current: Any, cast: Any, fallback: Any) -> Any:
        """Return a numeric setting, preferring the environment variable.

        Tries ``cast`` on the environment variable ``name`` first, then on
        ``current`` (the value already configured); the first that converts
        wins. Invalid candidates are logged and skipped rather than raising,
        so a typo such as ``PORT=abc`` cannot crash construction.

        Args:
            name: Environment variable name.
            current: Value currently held by the setting.
            cast: ``int`` or ``float``.
            fallback: Returned when neither candidate converts.
        """
        candidates = (
            ("environment variable", os.getenv(name)),
            ("configured value", current),
        )
        for source, raw in candidates:
            if raw is None:
                continue
            try:
                return cast(raw)
            except (TypeError, ValueError, OverflowError):
                logging.warning("Ignoring invalid %s for %s: %r", source, name, raw)
        return fallback

    @staticmethod
    def _env_flag(name: str, current: Any) -> bool:
        """Return a boolean setting, preferring the environment variable.

        A string counts as true only when it equals ``"true"`` ignoring case.
        When the variable is unset the current value is kept, so a flag
        enabled in the config file is not silently reset.
        """
        raw = os.getenv(name)
        if raw is None:
            raw = current
        if isinstance(raw, str):
            return raw.lower() == "true"
        return bool(raw)

    def _load_from_env(self):
        """Override settings with environment variables.

        Numeric and boolean settings fall back to their current value when
        the variable is unset or invalid. Path settings intentionally ignore
        the config file: they come from the environment or the ``/tmp``
        default, and are then checked by ``_sanitize_paths``.
        """
        # Server configuration
        self.server.host = os.getenv("HOST", self.server.host)
        self.server.port = self._env_number("PORT", self.server.port, int, 7860)
        self.server.debug = self._env_flag("DEBUG", self.server.debug)
        self.server.reload = self._env_flag("RELOAD", self.server.reload)
        self.server.workers = self._env_number("WORKERS", self.server.workers, int, 1)
        # LOG_LEVEL feeds both the server and the logging section. Their
        # defaults use different letter case ("info" vs "INFO"), so each is
        # normalised to its own convention.
        self.server.log_level = str(os.getenv("LOG_LEVEL", self.server.log_level)).lower()

        # Database configuration - ALWAYS use writable tmp paths
        self.database.chroma_persist_dir = os.getenv("CHROMA_DIR", "/tmp/researchmate/chroma_persist")
        self.database.collection_name = os.getenv("COLLECTION_NAME", self.database.collection_name)
        self.database.similarity_threshold = self._env_number(
            "SIMILARITY_THRESHOLD", self.database.similarity_threshold, float, 0.7
        )
        self.database.max_results = self._env_number("MAX_RESULTS", self.database.max_results, int, 10)

        # AI model configuration
        self.ai_model.model_name = os.getenv("MODEL_NAME", self.ai_model.model_name)
        self.ai_model.temperature = self._env_number("TEMPERATURE", self.ai_model.temperature, float, 0.7)
        self.ai_model.max_tokens = self._env_number("MAX_TOKENS", self.ai_model.max_tokens, int, 4096)
        self.ai_model.timeout = self._env_number("MODEL_TIMEOUT", self.ai_model.timeout, int, 30)

        # Upload configuration - ALWAYS use writable tmp paths
        self.upload.max_file_size = self._env_number(
            "MAX_FILE_SIZE", self.upload.max_file_size, int, 50 * 1024 * 1024
        )
        self.upload.upload_directory = os.getenv("UPLOADS_DIR", "/tmp/researchmate/uploads")
        self.upload.temp_directory = os.getenv("TEMP_DIR", "/tmp/researchmate/tmp")

        # Logging configuration - ALWAYS use writable tmp paths
        self.logging.level = str(os.getenv("LOG_LEVEL", self.logging.level)).upper()
        self.logging.file_path = os.getenv("LOG_FILE", "/tmp/researchmate/logs/app.log")

        # Ensure no hardcoded /data paths slip through
        self._sanitize_paths()

    def _sanitize_paths(self):
        """Force every path setting to live under ``/tmp/``.

        Any path that does not start with ``/tmp/`` (which covers ``/data``,
        ``./data`` and ``/app/data``) is replaced by its fallback and a
        warning is printed.
        """
        writable_paths = [
            ('database.chroma_persist_dir', '/tmp/researchmate/chroma_persist'),
            ('upload.upload_directory', '/tmp/researchmate/uploads'),
            ('upload.temp_directory', '/tmp/researchmate/tmp'),
            ('logging.file_path', '/tmp/researchmate/logs/app.log'),
        ]

        for path_attr, fallback in writable_paths:
            obj, attr = path_attr.split('.')
            section_obj = getattr(self, obj)
            current_path = getattr(section_obj, attr)
            if not str(current_path).startswith('/tmp/'):
                print(f"β Warning: Changing {path_attr} from {current_path} to {fallback}")
                setattr(section_obj, attr, fallback)

    def _validate_config(self):
        """Validate settings, reset out-of-range values, create directories.

        A missing ``GROQ_API_KEY`` only produces a warning. Port, temperature,
        max_tokens and similarity_threshold are reset to their defaults when
        outside the accepted range.
        """
        # Validate required environment variables
        required_env_vars = ["GROQ_API_KEY"]
        missing_vars = [var for var in required_env_vars if not os.getenv(var)]
        if missing_vars:
            print(f"β Warning: Missing environment variables: {', '.join(missing_vars)}")
            print("Some features may not work without these variables")

        # Validate server configuration
        if not (1 <= self.server.port <= 65535):
            print(f"β Warning: Invalid port {self.server.port}, using 7860")
            self.server.port = 7860

        # Validate AI model configuration
        if not (0.0 <= self.ai_model.temperature <= 2.0):
            print(f"β Warning: Invalid temperature {self.ai_model.temperature}, using 0.7")
            self.ai_model.temperature = 0.7
        if not (1 <= self.ai_model.max_tokens <= 32768):
            print(f"β Warning: Invalid max_tokens {self.ai_model.max_tokens}, using 4096")
            self.ai_model.max_tokens = 4096

        # Validate database configuration
        if not (0.0 <= self.database.similarity_threshold <= 1.0):
            print(f"β Warning: Invalid similarity_threshold {self.database.similarity_threshold}, using 0.7")
            self.database.similarity_threshold = 0.7

        # Create directories if they don't exist
        self._create_directories()

    def _create_directories(self):
        """Create the data, upload, temp, log and config directories.

        Failures are reported with a printed warning and never raised, so a
        read-only filesystem does not prevent the settings from loading.
        """
        directories = [
            self.database.chroma_persist_dir,
            self.upload.upload_directory,
            self.upload.temp_directory,
            Path(self.logging.file_path).parent,
            Path(self.config_file).parent,
        ]
        for directory in directories:
            try:
                path = Path(directory)
                path.mkdir(parents=True, exist_ok=True)
                # Ensure write permissions.
                # NOTE(review): 0o777 makes the directory world-writable;
                # confirm this is acceptable outside a sandboxed container.
                path.chmod(0o777)
                print(f"β Created/verified directory: {directory}")
            except Exception as e:
                # Continue without raising error
                print(f"β Warning: Could not create directory {directory}: {e}")

    def save_config(self):
        """Write the current configuration of every section to the config file.

        Best effort: serialization or I/O failures are printed as a warning
        and not raised.
        """
        config_path = Path(self.config_file)
        try:
            # Built inside the try so that a section that cannot be
            # serialized is reported like any other save failure.
            config_data = {
                section: asdict(getattr(self, section))
                for section in self._SECTIONS
            }
            config_path.parent.mkdir(parents=True, exist_ok=True)
            with open(config_path, 'w', encoding='utf-8') as f:
                json.dump(config_data, f, indent=2)
            print(f"β Configuration saved to: {config_path}")
        except Exception as e:
            # Don't raise the error for config saving
            print(f"β Warning: Could not save config file: {e}")

    def get_groq_api_key(self) -> str:
        """Return the Groq API key from the ``GROQ_API_KEY`` environment variable.

        When the variable is unset or empty a warning is printed and the
        placeholder ``"dummy_key"`` is returned, so the result is always a
        non-empty string. Callers that need to know whether a real key exists
        must check the environment themselves.
        """
        api_key = os.getenv("GROQ_API_KEY")
        if not api_key:
            print("β Warning: GROQ_API_KEY environment variable is not set")
            return "dummy_key"  # Return dummy key to prevent crashes
        return api_key

    def get_database_url(self) -> str:
        """Return an sqlite URL for ``chroma.db`` inside the persist directory."""
        return f"sqlite:///{self.database.chroma_persist_dir}/chroma.db"

    def get_static_url(self) -> str:
        """Return the URL prefix for static files."""
        return "/static"

    def get_templates_dir(self) -> str:
        """Return ``<project_root>/src/templates`` as a string."""
        return str(self.project_root / "src" / "templates")

    def get_static_dir(self) -> str:
        """Return ``<project_root>/src/static`` as a string."""
        return str(self.project_root / "src" / "static")

    def get_upload_dir(self) -> str:
        """Return the configured upload directory."""
        return self.upload.upload_directory

    def is_development(self) -> bool:
        """Return True when ``ENVIRONMENT`` equals "development" (any case)."""
        return os.getenv("ENVIRONMENT", "production").lower() == "development"

    def is_production(self) -> bool:
        """Return True whenever the environment is not development."""
        return not self.is_development()

    def __str__(self) -> str:
        """Return a short human-readable description."""
        return f"ResearchMate Settings (Config: {self.config_file})"

    def __repr__(self) -> str:
        """Return a constructor-like representation."""
        return f"Settings(config_file='{self.config_file}')"
# Global settings instance.
# Built at import time: constructing Settings reads the config file and the
# environment, may print warnings, and creates the configured directories.
settings = Settings()
# Convenience functions
def get_settings() -> Settings:
    """Return the module-level ``settings`` object.

    The name is looked up on every call, so the instance installed by
    ``reload_settings`` is returned after a reload.
    """
    return settings
def reload_settings():
    """Rebuild the global ``settings`` from the same config file.

    A new ``Settings`` object replaces the module-level one; references to
    the previous instance held elsewhere are not updated.
    """
    global settings
    config_path = settings.config_file
    settings = Settings(config_path)
def create_default_config():
    """Write a settings file at the default location and return its path.

    A fresh ``Settings`` object is built with no explicit config file, saved
    with ``save_config`` (which reports failures instead of raising), and its
    ``config_file`` path is returned.
    """
    fresh = Settings()
    fresh.save_config()
    return fresh.config_file
if __name__ == "__main__":
    # Smoke test: print the effective settings and write them to the config file.
    print("ResearchMate Settings Test")
    print("=" * 40)
    try:
        settings = get_settings()
        print("β Settings loaded successfully")
        print(f"Config file: {settings.config_file}")
        print(f"Server: {settings.server.host}:{settings.server.port}")
        print(f"AI Model: {settings.ai_model.model_name}")
        print(f"Database: {settings.database.chroma_persist_dir}")
        print(f"Upload dir: {settings.get_upload_dir()}")
        # get_groq_api_key() returns a non-empty placeholder when the key is
        # missing, so it cannot tell "set" from "not set"; check the
        # environment variable directly instead.
        print(f"Groq API Key: {'Set' if os.getenv('GROQ_API_KEY') else 'Not set'}")
        print(f"Environment: {'Development' if settings.is_development() else 'Production'}")
        # Save configuration
        settings.save_config()
    except Exception as e:
        print(f"β Error: {e}")
        import traceback
        traceback.print_exc()