|
""" |
|
Utility functions for the Gradio interface. |
|
""" |
|
|
|
import os |
|
import tempfile |
|
import uuid |
|
from pathlib import Path |
|
from typing import Any, Dict, List, Optional, Tuple, Union |
|
import time |
|
import gradio as gr |
|
|
|
|
|
def generate_session_id() -> str:
    """Return a freshly generated random (version 4) UUID as a string."""
    session_uuid = uuid.uuid4()
    return str(session_uuid)
|
|
|
|
|
def save_uploaded_files(files: List[gr.File]) -> List[Tuple[str, str]]:
    """
    Save uploaded files to a newly created temporary directory.

    Each upload is copied (when it points at an existing file on disk) or
    written from its in-memory content into a fresh ``rag_uploads_*``
    directory. Stored names are sanitized with ``generate_safe_filename``
    and made unique, so two uploads sharing a name do not overwrite each
    other.

    Args:
        files: List of uploaded file objects. ``None`` entries are skipped.
            Besides objects exposing ``orig_name`` / ``name`` / ``read`` /
            ``content``, plain path strings, ``pathlib.Path`` values and raw
            ``bytes`` are accepted.

    Returns:
        List of (file_path, original_filename) tuples, one per upload that
        was actually written. Uploads with no readable content are skipped
        instead of producing an empty file.
    """
    import shutil

    if not files:
        return []

    saved_files = []
    temp_dir = Path(tempfile.mkdtemp(prefix="rag_uploads_"))

    for file_obj in files:
        if file_obj is None:
            continue

        # Locate the on-disk source of the upload, if it has one.
        if isinstance(file_obj, (str, Path)):
            source_path = file_obj
        else:
            source_path = getattr(file_obj, 'name', None)
        if not isinstance(source_path, (str, Path)) or not str(source_path):
            source_path = None

        # Prefer the client-side name; a missing or None ``orig_name`` falls
        # back to the source path, then to a timestamped placeholder.
        display_name = getattr(file_obj, 'orig_name', None)
        if isinstance(display_name, str) and display_name:
            original_name = display_name
        elif source_path is not None:
            original_name = Path(source_path).name
        else:
            original_name = f"upload_{int(time.time())}"

        safe_name = generate_safe_filename(original_name)
        temp_path = temp_dir / safe_name

        # The directory is private to this call, so anything already there
        # is an earlier upload from the same batch: pick a free name.
        stem, ext = os.path.splitext(safe_name)
        counter = 1
        while temp_path.exists():
            temp_path = temp_dir / f"{stem}_{counter}{ext}"
            counter += 1

        if source_path is not None and os.path.isfile(source_path):
            shutil.copy2(source_path, temp_path)
        else:
            if isinstance(file_obj, (bytes, bytearray)):
                data = file_obj
            elif hasattr(file_obj, 'read'):
                data = file_obj.read()
            elif hasattr(file_obj, 'content'):
                data = file_obj.content
            else:
                data = None

            if data is None:
                # Nothing to persist; do not report a file that was never written.
                continue
            if isinstance(data, str):
                # Text-mode sources yield str; the target is opened in binary mode.
                data = data.encode('utf-8')

            with open(temp_path, 'wb') as f:
                f.write(data)

        saved_files.append((str(temp_path), original_name))

    return saved_files
|
|
|
|
|
def generate_safe_filename(filename: str) -> str:
    """
    Generate a safe filename for saving.

    Every character other than word characters, ``-``, ``_`` and ``.`` is
    replaced with ``_`` (so path separators never survive). Names longer
    than 100 characters are shortened to at most 90 characters of stem plus
    at most 10 characters of extension.

    Args:
        filename: The original, untrusted file name.

    Returns:
        The sanitized name. If nothing usable remains (empty input, or a
        name made only of dots such as ``.`` / ``..``), a timestamped
        ``file_<epoch>.txt`` placeholder is returned instead.
    """
    import re

    safe_name = re.sub(r'[^\w\-_\.]', '_', filename)

    if len(safe_name) > 100:
        name_part, ext = os.path.splitext(safe_name)
        # Bound the extension as well; otherwise a very long "extension"
        # would bypass the length limit entirely.
        safe_name = name_part[:90] + ext[:10]

    # "." and ".." are directory references, not file names.
    if not safe_name.strip('.'):
        safe_name = f"file_{int(time.time())}.txt"

    return safe_name
|
|
|
|
|
def validate_file_types(files: List[gr.File], allowed_extensions: List[str]) -> Tuple[List[gr.File], List[str]]:
    """
    Validate uploaded file types.

    Extensions are compared case-insensitively, and entries of
    ``allowed_extensions`` may be given with or without the leading dot
    (``'.pdf'``, ``'PDF'`` and ``'.Pdf'`` are equivalent).

    Args:
        files: List of uploaded files. ``None`` entries are skipped. Plain
            path strings and ``pathlib.Path`` values are accepted too.
        allowed_extensions: List of allowed file extensions (e.g., ['.pdf', '.docx'])

    Returns:
        Tuple of (valid_files, error_messages)
    """
    if not files:
        return [], ["No files provided"]

    # Normalize once so the per-file check is a cheap set lookup.
    normalized_allowed = set()
    for allowed in allowed_extensions:
        allowed = allowed.strip().lower()
        if allowed and not allowed.startswith('.'):
            allowed = '.' + allowed
        normalized_allowed.add(allowed)

    valid_files = []
    errors = []

    for file_obj in files:
        if file_obj is None:
            continue

        # Prefer the client-side name; tolerate a missing or None ``orig_name``.
        display_name = getattr(file_obj, 'orig_name', None)
        if isinstance(file_obj, (str, Path)):
            source_name = file_obj
        else:
            source_name = getattr(file_obj, 'name', None)

        if isinstance(display_name, str) and display_name:
            filename = display_name
        elif isinstance(source_name, (str, Path)) and str(source_name):
            filename = Path(source_name).name
        else:
            filename = "unknown"

        file_ext = Path(filename).suffix.lower()

        if file_ext not in normalized_allowed:
            errors.append(f"File '{filename}' has unsupported type '{file_ext}'. Allowed types: {', '.join(allowed_extensions)}")
        else:
            valid_files.append(file_obj)

    return valid_files, errors
|
|
|
|
|
def format_file_size(size_bytes: int) -> str:
    """
    Format file size in human readable format.

    Args:
        size_bytes: Size in bytes. Negative values keep their sign and are
            scaled by magnitude.

    Returns:
        ``"0 B"`` for zero, otherwise the value with one decimal and a unit
        from B up to PB using 1024-based steps (e.g. ``"1.5 KB"``).
    """
    if size_bytes == 0:
        return "0 B"

    size_names = ("B", "KB", "MB", "GB", "TB", "PB")
    sign = "-" if size_bytes < 0 else ""
    value = float(abs(size_bytes))
    unit_index = 0

    # Stop at the largest unit so huge values still format.
    while value >= 1024.0 and unit_index < len(size_names) - 1:
        value /= 1024.0
        unit_index += 1

    return f"{sign}{value:.1f} {size_names[unit_index]}"
|
|
|
|
|
def cleanup_temp_files(file_paths: List[str]) -> None:
    """
    Remove the given temporary files on a best-effort basis.

    Paths that do not exist are ignored. Any failure while checking or
    removing a path is reported with a printed warning and does not stop
    the remaining paths from being processed.
    """
    for target in file_paths:
        try:
            if not os.path.exists(target):
                continue
            os.remove(target)
        except Exception as e:
            print(f"Warning: Could not remove temp file {target}: {e}")
|
|
|
|
|
def extract_filename_from_path(file_path: str) -> str:
    """Return the final path component of ``file_path`` (its file name)."""
    path_obj = Path(file_path)
    return path_obj.name
|
|
|
|
|
def create_download_link(file_path: str, filename: Optional[str] = None) -> str:
    """
    Create a download link HTML.

    Args:
        file_path: Path of the file to link to; placed after ``file/`` in
            the link's ``href``.
        filename: Name shown to the user and suggested for the download.
            Defaults to the last component of ``file_path``.

    Returns:
        An ``<a>`` element as an HTML string. Both the path and the display
        name are HTML-escaped, so names containing quotes or angle brackets
        cannot break out of the attributes or inject markup.
    """
    import html

    display_name = filename or Path(file_path).name

    # File names originate from uploads and must be treated as untrusted.
    safe_path = html.escape(str(file_path), quote=True)
    safe_name = html.escape(display_name, quote=True)

    return f"""
    <a href="file/{safe_path}" download="{safe_name}"
       style="color: #667eea; text-decoration: none; font-weight: 500;">
        π₯ Download {safe_name}
    </a>
    """
|
|
|
|
|
def sanitize_html(text: str) -> str:
    """Escape ``&``, ``<``, ``>`` and both quote characters for safe HTML display."""
    from html import escape

    return escape(text, quote=True)
|
|
|
|
|
def truncate_text(text: str, max_length: int = 100, suffix: str = "...") -> str:
    """
    Truncate text to specified length.

    Args:
        text: The text to shorten.
        max_length: Maximum length of the returned string, suffix included.
        suffix: Marker appended when the text is cut.

    Returns:
        ``text`` unchanged if it already fits; otherwise a string of at most
        ``max_length`` characters ending in ``suffix``. When ``max_length``
        is too small to hold the suffix, the text is cut hard without it.
    """
    if len(text) <= max_length:
        return text

    if max_length < len(suffix):
        # A negative slice bound would count from the end of the text and
        # return far more than max_length characters.
        return text[:max(max_length, 0)]

    return text[:max_length - len(suffix)] + suffix
|
|
|
|
|
def format_duration(seconds: float) -> str:
    """
    Format duration in human readable format.

    Args:
        seconds: Duration in seconds.

    Returns:
        ``"<n>ms"`` below one second, ``"<n.n>s"`` below one minute,
        ``"<m>m <s>s"`` below one hour, otherwise ``"<h>h <m>m"``. A value
        that rounds up to the next unit boundary is shown in the larger
        unit, so outputs like ``"1000ms"``, ``"60.0s"`` or ``"1m 60s"`` are
        never produced.
    """
    if seconds < 1:
        millis_text = f"{seconds * 1000:.0f}"
        if millis_text != "1000":
            return f"{millis_text}ms"
        seconds = 1.0  # rounded up to a full second

    if seconds < 60:
        seconds_text = f"{seconds:.1f}"
        if seconds_text != "60.0":
            return f"{seconds_text}s"
        seconds = 60.0  # rounded up to a full minute

    if seconds < 3600:
        # Round to whole seconds first so the remainder can never display as 60.
        whole_seconds = round(seconds)
        if whole_seconds < 3600:
            minutes, remaining_seconds = divmod(whole_seconds, 60)
            return f"{minutes}m {remaining_seconds}s"
        seconds = 3600.0  # rounded up to a full hour

    hours = int(seconds // 3600)
    minutes = int((seconds % 3600) // 60)
    return f"{hours}h {minutes}m"
|
|
|
|
|
def parse_search_filters(filter_text: str) -> Optional[Dict[str, Any]]:
    """
    Parse search filter text into filter dictionary.

    Format: "key:value,key2:value2" or JSON string

    In the key/value form, values are converted as follows: ``true`` /
    ``false`` (any case) become booleans, integer literals (including
    signed ones) become ``int``, other finite numeric literals become
    ``float``, and everything else stays a string. Pairs without a ``:``
    or with an empty key are ignored.

    Args:
        filter_text: Raw filter text entered by the user.

    Returns:
        The parsed filter dictionary, or ``None`` when the input is blank
        or contains no usable key/value pair.
    """
    import json
    import math

    if not filter_text or not filter_text.strip():
        return None

    filter_text = filter_text.strip()

    if filter_text.startswith('{') and filter_text.endswith('}'):
        try:
            return json.loads(filter_text)
        except json.JSONDecodeError:
            # Not valid JSON after all; fall back to key:value parsing.
            pass

    filters: Dict[str, Any] = {}

    for pair in filter_text.split(','):
        # Split on the first ':' only, so values may contain colons (URLs).
        key, separator, value = pair.partition(':')
        key = key.strip()
        value = value.strip()
        if not separator or not key:
            continue

        lowered = value.lower()
        if lowered in ('true', 'false'):
            filters[key] = lowered == 'true'
            continue

        try:
            filters[key] = int(value)
            continue
        except ValueError:
            pass

        try:
            number = float(value)
        except ValueError:
            filters[key] = value
        else:
            # Words such as "nan" or "inf" parse as floats but are almost
            # certainly meant as text; keep them as strings.
            filters[key] = number if math.isfinite(number) else value

    return filters if filters else None
|
|
|
|
|
def create_breadcrumb(current_tab: str, sub_section: str = None) -> str:
    """
    Build the breadcrumb trail shown in the interface.

    The trail always starts with the application label; ``current_tab`` and
    ``sub_section`` are appended, in that order, whenever they are non-empty.
    """
    optional_parts = [part for part in (current_tab, sub_section) if part]
    trail = ["π RAG Assistant"] + optional_parts
    return " β ".join(trail)
|
|
|
|
|
def debounce_function(wait_time: float = 0.3):
    """
    Decorator to debounce function calls.
    Useful for search inputs to avoid too many API calls.

    Every call blocks for ``wait_time`` seconds. When the wait is over, the
    wrapped function runs only if no newer call to the same wrapper was
    started in the meantime; superseded calls return ``None``.

    Args:
        wait_time: Seconds to wait before deciding whether the call is
            still the most recent one.

    Returns:
        A decorator producing the debounced wrapper. The wrapper keeps the
        wrapped function's name and docstring.
    """
    import functools
    import threading

    def decorator(func):
        lock = threading.Lock()
        latest_call = None  # identity token of the most recently started call

        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            nonlocal latest_call

            # A unique object per call avoids relying on float timestamp
            # equality, which breaks when two calls share a clock tick and
            # when wait_time is 0.
            call_token = object()
            with lock:
                latest_call = call_token

            time.sleep(wait_time)

            with lock:
                still_latest = latest_call is call_token

            if still_latest:
                return func(*args, **kwargs)
            return None

        return wrapper

    return decorator
|
|
|
|
|
def get_file_icon(filename: str) -> str:
    """
    Return the icon associated with the file's extension.

    The extension is taken from ``filename`` and compared in lower case;
    extensions without a dedicated entry get the default icon.
    """
    icon_by_extension = {
        '.pdf': 'π',
        '.docx': 'π',
        '.doc': 'π',
        '.txt': 'π',
        '.md': 'π',
        '.py': 'π',
        '.js': 'π¨',
        '.html': 'π',
        '.css': 'π¨',
        '.json': 'π',
        '.xml': 'π',
        '.csv': 'π',
        '.xlsx': 'π',
        '.xls': 'π',
    }

    extension = Path(filename).suffix.lower()
    try:
        return icon_by_extension[extension]
    except KeyError:
        return 'π'
|
|
|
|
|
def create_tooltip(text: str, tooltip_text: str) -> str:
    """
    Wrap ``text`` in a ``<span>`` whose ``title`` shows ``tooltip_text``.

    Both values are passed through ``sanitize_html`` before being placed
    into the markup.
    """
    escaped_tooltip = sanitize_html(tooltip_text)
    escaped_text = sanitize_html(text)

    return f"""
    <span title="{escaped_tooltip}" style="cursor: help; border-bottom: 1px dotted #6b7280;">
        {escaped_text}
    </span>
    """
|
|
|
|
|
def format_timestamp(timestamp: float) -> str:
    """
    Render a POSIX timestamp as ``YYYY-MM-DD HH:MM:SS``.

    The conversion uses ``datetime.fromtimestamp`` without a timezone, so
    the result is expressed in the machine's local time.
    """
    from datetime import datetime as _datetime

    moment = _datetime.fromtimestamp(timestamp)
    return moment.strftime("%Y-%m-%d %H:%M:%S")
|
|
|
|
|
def calculate_reading_time(text: str, words_per_minute: int = 200) -> str:
    """
    Estimate how long ``text`` takes to read.

    Words are counted by splitting on whitespace; the estimate is rounded to
    whole minutes and is never less than one minute. Results under an hour
    read ``"<n> min read"``; longer ones read ``"<h>h read"`` or
    ``"<h>h <m>m read"``.
    """
    total_words = len(text.split())
    total_minutes = max(1, round(total_words / words_per_minute))

    if total_minutes < 60:
        if total_minutes == 1:
            return "1 min read"
        return f"{total_minutes} min read"

    whole_hours, leftover_minutes = divmod(total_minutes, 60)
    if leftover_minutes:
        return f"{whole_hours}h {leftover_minutes}m read"
    return f"{whole_hours}h read"
|
|
|
|
|
def highlight_search_terms(text: str, search_terms: List[str]) -> str:
    """
    Highlight search terms in text.

    Matching is case-insensitive and is done on the raw text in a single
    pass; each piece is HTML-escaped afterwards. A term therefore cannot
    match inside an HTML entity (``amp`` in ``&amp;``) or inside ``<mark>``
    markup added for another term, and terms containing characters such as
    ``<`` or ``&`` still match.

    Args:
        text: Raw (unescaped) text to display.
        search_terms: Terms to highlight. Empty, blank or ``None`` entries
            are ignored.

    Returns:
        HTML-escaped text with every match wrapped in a styled ``<mark>``.
    """
    import html
    import re

    usable_terms = [term for term in (search_terms or []) if term and term.strip()]
    if not usable_terms:
        return html.escape(text)

    # Longest first, so "foobar" wins over "foo" where both could match.
    ordered_terms = sorted(dict.fromkeys(usable_terms), key=len, reverse=True)
    pattern = re.compile(
        "|".join(re.escape(term) for term in ordered_terms),
        re.IGNORECASE,
    )

    pieces = []
    cursor = 0
    for match in pattern.finditer(text):
        pieces.append(html.escape(text[cursor:match.start()]))
        pieces.append(
            f'<mark style="background-color: #fef08a; padding: 0 2px;">{html.escape(match.group(0))}</mark>'
        )
        cursor = match.end()
    pieces.append(html.escape(text[cursor:]))

    return "".join(pieces)