|
import os |
|
import logging |
|
import uuid |
|
import json |
|
import pandas as pd |
|
from datetime import datetime |
|
from typing import List, Dict, Any, Optional |
|
from datasets import Dataset, load_dataset |
|
from huggingface_hub import HfApi, HfFolder, CommitOperationAdd |
|
|
|
|
|
# Configure process-wide logging at import time.
# NOTE(review): basicConfig in an importable module affects the whole host
# application's logging; consider leaving configuration to the entry point.
logging.basicConfig(level=logging.INFO)

logger = logging.getLogger(__name__)
|
|
|
class ChatHistoryManager:

    """
    Manages chat history persistence using Hugging Face Datasets.

    Conversations are appended to a local JSONL file (one JSON object per
    line); the full history can optionally be synced to, and restored from,
    a private Hugging Face Hub dataset.
    """

    def __init__(self, dataset_name: Optional[str] = None,
                 local_dir: str = "./data/chat_history"):
        """
        Initialize the chat history manager.

        Args:
            dataset_name: Hugging Face dataset name (username/repo). Falls
                back to the HF_DATASET_NAME environment variable.
            local_dir: Local directory to store chat history.
        """
        self.dataset_name = dataset_name or os.getenv("HF_DATASET_NAME")
        self.local_dir = local_dir
        self.hf_api = HfApi()
        # Hub credentials; both token and dataset_name are required for Hub I/O.
        self.token = os.getenv("HF_API_KEY")

        os.makedirs(self.local_dir, exist_ok=True)

        self.local_file = os.path.join(self.local_dir, "chat_history.jsonl")

        # Create an empty history file on first run so later reads/appends
        # never have to special-case a missing file.
        # encoding pinned to UTF-8: chat text may contain non-ASCII and the
        # platform default codec is not reliable (notably on Windows).
        if not os.path.exists(self.local_file):
            with open(self.local_file, "w", encoding="utf-8") as f:
                f.write("")

        logger.info(f"Chat history manager initialized with local file: {self.local_file}")
        if self.dataset_name:
            logger.info(f"Will sync to HF dataset: {self.dataset_name}")

    def load_history(self) -> List[Dict[str, Any]]:
        """Load chat history from local file or Hugging Face dataset.

        The local JSONL file wins whenever it has content; otherwise, if a
        dataset name and token are configured, history is pulled from the Hub
        and cached back to the local file.

        Returns:
            List of conversation dicts; [] when nothing is stored or on error.
        """
        try:
            # Fast path: local file already has data.
            if os.path.exists(self.local_file) and os.path.getsize(self.local_file) > 0:
                with open(self.local_file, "r", encoding="utf-8") as f:
                    history = [json.loads(line) for line in f if line.strip()]
                logger.info(f"Loaded {len(history)} conversations from local file")
                return history

            # Fall back to the Hub copy when credentials are available.
            if self.dataset_name and self.token:
                try:
                    dataset = load_dataset(self.dataset_name, token=self.token)
                    history = dataset["train"].to_pandas().to_dict("records")
                    logger.info(f"Loaded {len(history)} conversations from Hugging Face")

                    # Cache the Hub copy locally so the next load works offline.
                    self._write_history_to_local(history)
                    return history
                except Exception as e:
                    # Best-effort remote load: fall through to empty history.
                    logger.warning(f"Error loading from Hugging Face: {e}")

            return []
        except Exception as e:
            logger.error(f"Error loading chat history: {e}")
            return []

    def save_conversation(self, conversation: Dict[str, Any]) -> bool:
        """
        Save a conversation to history.

        Args:
            conversation: Dict with keys: user_query, assistant_response,
                timestamp, sources (optional). Missing "id"/"timestamp"
                fields are filled in (the dict is mutated in place).

        Returns:
            bool: True if successful.
        """
        try:
            # Fill in identity/ordering metadata when the caller omitted it.
            if "id" not in conversation:
                conversation["id"] = str(uuid.uuid4())
            if "timestamp" not in conversation:
                # NOTE(review): naive local time; confirm whether UTC is wanted.
                conversation["timestamp"] = datetime.now().isoformat()

            # Append-only JSONL keeps saves O(1) regardless of history size.
            with open(self.local_file, "a", encoding="utf-8") as f:
                f.write(json.dumps(conversation) + "\n")

            logger.info(f"Saved conversation to local file: {conversation['id']}")
            return True
        except Exception as e:
            logger.error(f"Error saving conversation: {e}")
            return False

    def sync_to_hub(self) -> bool:
        """Sync local chat history to Hugging Face Hub.

        Pushes the entire current history as the "train" split of a private
        dataset, replacing whatever is on the Hub.

        Returns:
            bool: True on success; False when credentials are missing,
            history is empty, or the push fails.
        """
        if not self.dataset_name or not self.token:
            logger.warning("Cannot sync to Hub: missing dataset name or token")
            return False

        try:
            history = self.load_history()
            if not history:
                logger.warning("No history to sync")
                return False

            # Round-trip through pandas to get a columnar Dataset.
            ds = Dataset.from_pandas(
                pd.DataFrame(history)
            )

            ds.push_to_hub(
                self.dataset_name,
                token=self.token,
                private=True
            )

            logger.info(f"Successfully synced {len(history)} conversations to Hugging Face Hub")
            return True
        except Exception as e:
            logger.error(f"Error syncing to Hub: {e}")
            return False

    def _write_history_to_local(self, history: List[Dict[str, Any]]) -> bool:
        """Overwrite the local JSONL file with *history*, one record per line.

        Returns:
            bool: True on success, False on any I/O or serialization error.
        """
        try:
            with open(self.local_file, "w", encoding="utf-8") as f:
                for conversation in history:
                    # default=str: records that round-tripped through pandas in
                    # load_history() may contain non-JSON types (e.g. Timestamp,
                    # numpy scalars); stringify them rather than fail the write.
                    f.write(json.dumps(conversation, default=str) + "\n")
            return True
        except Exception as e:
            logger.error(f"Error writing history to local file: {e}")
            return False

    def get_conversations_by_date(self, start_date=None, end_date=None) -> List[Dict[str, Any]]:
        """Get conversations filtered by date range (inclusive on both ends).

        Args:
            start_date: Earliest datetime to include; None means unbounded.
            end_date: Latest datetime to include; None means unbounded.

        Records with missing or unparseable timestamps are skipped.
        NOTE(review): stored timestamps are naive (datetime.now().isoformat());
        passing timezone-aware bounds makes the comparison raise, and such
        records are skipped rather than matched.
        """
        history = self.load_history()

        # No bounds: return everything without parsing timestamps.
        if not start_date and not end_date:
            return history

        filtered = []
        for conv in history:
            timestamp = conv.get("timestamp", "")
            if not timestamp:
                continue

            try:
                conv_date = datetime.fromisoformat(timestamp)

                # A missing bound acts as -infinity / +infinity.
                if ((start_date is None or start_date <= conv_date)
                        and (end_date is None or conv_date <= end_date)):
                    filtered.append(conv)
            except (ValueError, TypeError):
                # ValueError: malformed ISO string. TypeError: non-string
                # timestamp or naive/aware mismatch — skip the record instead
                # of crashing the whole query.
                continue

        return filtered

    def search_conversations(self, query: str) -> List[Dict[str, Any]]:
        """Search conversations by keyword (case-insensitive substring match).

        Args:
            query: Text to look for in user queries and assistant responses.

        Returns:
            All conversations whose user_query or assistant_response
            contains *query* (case-insensitively).
        """
        history = self.load_history()
        query = query.lower()

        results = []
        for conv in history:
            # `or ""` guards records where the field is present but None,
            # which .get(key, "") alone would not protect against.
            user_query = (conv.get("user_query") or "").lower()
            assistant_response = (conv.get("assistant_response") or "").lower()

            if query in user_query or query in assistant_response:
                results.append(conv)

        return results