Spaces:
Sleeping
Sleeping
| """ | |
| Encoders for MandelMem system. | |
| Converts content to vectors and complex coordinates. | |
| """ | |
| import torch | |
| import torch.nn as nn | |
| import numpy as np | |
| from typing import Union, Dict, Any, Tuple | |
| from dataclasses import dataclass | |
@dataclass
class EncodingResult:
    """Result of an encoding operation.

    The ``@dataclass`` decorator generates the keyword constructor callers
    use (``EncodingResult(vector=..., complex_coord=..., metadata=...)``);
    without it the class would have no field-accepting ``__init__`` and
    construction would raise ``TypeError``.
    """

    # Content embedding vector.
    vector: torch.Tensor
    # Address of the content in the complex plane.
    complex_coord: complex
    # Metadata associated with the encoded content.
    metadata: Dict[str, Any]
class ContentEncoder(nn.Module):
    """Encode text (or pre-tokenized ids) into one pooled vector per sequence.

    Pipeline: character-level tokenization (``ord(c) % vocab_size``, padded
    with id 0 to ``max_len``) -> token embedding + learned positional table
    -> 6-layer Transformer encoder -> mean over the sequence axis ->
    ``tanh(Linear(...))`` projection, so every output element lies in (-1, 1).

    Args:
        embedding_dim: Width of the embeddings and of the output vector.
            Must be divisible by 8, the number of attention heads.
        vocab_size: Number of token ids; character codes are folded into
            ``[0, vocab_size)`` with a modulo.
        max_len: Maximum sequence length and number of rows in the learned
            positional table. Defaults to 512, which keeps the default
            parameter shapes unchanged.
    """

    def __init__(self, embedding_dim: int = 768, vocab_size: int = 50000,
                 max_len: int = 512):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.vocab_size = vocab_size
        self.max_len = max_len
        # Simple text encoder (can be replaced with a pretrained transformer).
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        # Learned positional table (randomly initialised), one row per position.
        self.position_encoding = nn.Parameter(torch.randn(max_len, embedding_dim))
        self.transformer = nn.TransformerEncoder(
            nn.TransformerEncoderLayer(
                d_model=embedding_dim,
                nhead=8,
                dim_feedforward=2048,
                dropout=0.1,
                batch_first=True,
            ),
            num_layers=6,
        )
        self.pooler = nn.Linear(embedding_dim, embedding_dim)

    def tokenize(self, text: str) -> torch.Tensor:
        """Character-level tokenization (placeholder for a real tokenizer).

        Args:
            text: Input string; characters beyond ``max_len`` are dropped.

        Returns:
            A 1-D ``torch.long`` tensor of exactly ``max_len`` ids, right-padded
            with 0.
        """
        tokens = [ord(c) % self.vocab_size for c in text[:self.max_len]]
        tokens += [0] * (self.max_len - len(tokens))
        return torch.tensor(tokens, dtype=torch.long)

    def forward(self, content: Union[str, torch.Tensor]) -> torch.Tensor:
        """Encode content to a pooled vector.

        Args:
            content: Either a string (tokenized with :meth:`tokenize`) or a
                tensor of token ids shaped ``(seq_len,)`` or
                ``(batch, seq_len)``. Sequences longer than ``max_len`` are
                truncated to fit the positional table.

        Returns:
            Tensor of shape ``(batch, embedding_dim)``; ``batch`` is 1 for
            string or 1-D input.
        """
        if isinstance(content, str):
            tokens = self.tokenize(content).unsqueeze(0)
        else:
            tokens = content
            if tokens.dim() == 1:
                # Treat a bare id sequence as a batch of one.
                tokens = tokens.unsqueeze(0)
        # Truncate so the positional slice below always matches seq_len, and
        # move ids next to the parameters (tokenize builds them on the CPU).
        tokens = tokens[:, :self.max_len].to(self.position_encoding.device)

        seq_len = tokens.size(1)
        pos_enc = self.position_encoding[:seq_len].unsqueeze(0)

        embedded = self.embedding(tokens) + pos_enc
        encoded = self.transformer(embedded)

        # NOTE(review): no padding mask is applied, so pad positions (id 0)
        # take part in attention and in this mean; kept as-is to preserve
        # existing outputs.
        pooled = torch.mean(encoded, dim=1)
        return torch.tanh(self.pooler(pooled))
class AddressEncoder(nn.Module):
    """Map a content vector (plus metadata) to a point in the complex plane.

    A shared two-layer ReLU MLP consumes ``[vector, metadata_features]``; two
    linear heads then produce the real and imaginary parts, each squashed with
    ``tanh`` and scaled by 2 so both lie within [-2, 2].

    Args:
        input_dim: Width of the incoming content vector.
        hidden_dim: Width of the shared MLP layers.
        meta_dim: Number of metadata features appended to the vector; the
            featurizer zero-pads or truncates to exactly this length.
    """

    # Recognized ``meta['source']`` values, one-hot encoded in this order.
    SOURCE_TYPES = ('user', 'system', 'external', 'generated')

    def __init__(self, input_dim: int = 768, hidden_dim: int = 256, meta_dim: int = 6):
        super().__init__()
        self.input_dim = input_dim
        self.meta_dim = meta_dim
        # Metadata features are always appended, so the first layer sees
        # input_dim + meta_dim inputs.
        self.shared = nn.Sequential(
            nn.Linear(input_dim + meta_dim, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
        )
        self.real_head = nn.Linear(hidden_dim, 1)
        self.imag_head = nn.Linear(hidden_dim, 1)

    def forward(self, vector: torch.Tensor, meta: Dict[str, Any] = None) -> complex:
        """Convert a single content vector to a complex coordinate.

        Args:
            vector: Content embedding of width ``input_dim``, either 1-D or
                with a leading batch dimension of size 1.
            meta: Optional metadata dict (keys ``importance``, ``source``,
                ``pii``); missing keys fall back to defaults.

        Returns:
            ``complex(real, imag)`` with both parts within [-2, 2].

        Raises:
            RuntimeError: if ``vector`` holds more than one row (the single
                metadata row is not broadcast, and the result is read back
                with ``Tensor.item()``).
        """
        if vector.dim() == 1:
            # Accept a bare (input_dim,) vector as a batch of one.
            vector = vector.unsqueeze(0)
        # _encode_metadata builds a CPU float32 tensor; align it with the
        # vector so the concat works for GPU and non-float32 inputs.
        meta_features = self._encode_metadata(meta or {}).to(
            device=vector.device, dtype=vector.dtype
        )
        combined_input = torch.cat([vector, meta_features], dim=-1)
        shared_repr = self.shared(combined_input)
        real_part = torch.tanh(self.real_head(shared_repr)) * 2.0
        imag_part = torch.tanh(self.imag_head(shared_repr)) * 2.0
        return complex(real_part.item(), imag_part.item())

    def _encode_metadata(self, meta: Dict[str, Any]) -> torch.Tensor:
        """Turn a metadata dict into a ``(1, meta_dim)`` float32 tensor.

        Feature layout before padding/truncation to ``meta_dim``:
          [0]    importance (default 0.5; an explicit ``None`` also maps to 0.5)
          [1:5]  one-hot of ``source`` over SOURCE_TYPES (default 'user';
                 unrecognized values give all zeros)
          [5]    1.0 if ``pii`` is truthy, else 0.0
        """
        importance = meta.get('importance')
        features = [0.5 if importance is None else float(importance)]
        source = meta.get('source', 'user')
        features.extend(1.0 if s == source else 0.0 for s in self.SOURCE_TYPES)
        features.append(1.0 if meta.get('pii', False) else 0.0)
        # Zero-pad then truncate so the width always equals meta_dim.
        features = (features + [0.0] * self.meta_dim)[:self.meta_dim]
        return torch.tensor(features, dtype=torch.float32).unsqueeze(0)
class TimeEncoder(nn.Module):
    """Encode temporal information for time-aware memory.

    ``forward`` maps a timestamp to a ``time_dim`` feature vector through a
    small learned MLP (a learned projection, not a sinusoidal encoding).
    ``get_recency_weight`` returns a half-life decay weight.

    Args:
        time_dim: Width of the hidden layer and of the output features.
    """

    def __init__(self, time_dim: int = 64):
        super().__init__()
        self.time_dim = time_dim
        # Linear -> ReLU -> Linear applied to the single scaled timestamp.
        self.time_encoder = nn.Sequential(
            nn.Linear(1, time_dim),
            nn.ReLU(),
            nn.Linear(time_dim, time_dim),
        )

    def forward(self, timestamp: float) -> torch.Tensor:
        """Encode a timestamp to temporal features.

        Args:
            timestamp: Timestamp value; presumably Unix seconds (it is divided
                by 1e9 before encoding) — confirm against callers.

        Returns:
            Tensor of shape ``(time_dim,)``.
        """
        first_layer = self.time_encoder[0]
        # Build the input with the module's own dtype/device so the encoder
        # also works after .to(device) or .double().
        normalized_time = torch.tensor(
            [timestamp / 1e9],
            dtype=first_layer.weight.dtype,
            device=first_layer.weight.device,
        )
        return self.time_encoder(normalized_time)

    def get_recency_weight(self, timestamp: float, current_time: float,
                           half_life: float = 86400.0) -> float:
        """Return a recency weight in [0, 1] that halves every ``half_life``.

        ``weight = 0.5 ** (age / half_life)`` with
        ``age = current_time - timestamp`` clamped at 0, so an item exactly
        ``half_life`` old weighs 0.5 and future timestamps (for example, from
        clock skew) weigh 1.0 rather than growing without bound.

        Args:
            timestamp: Time the item was recorded.
            current_time: Reference "now", in the same units as ``timestamp``.
            half_life: Age at which the weight reaches 0.5 (default 86400,
                i.e. one day if the units are seconds).

        Returns:
            The decay weight as a Python ``float``.
        """
        age = max(current_time - timestamp, 0.0)
        return float(0.5 ** (age / half_life))
class MultiModalEncoder(nn.Module):
    """Combine the content, address and time encoders into one pipeline.

    Args:
        embedding_dim: Width of the content vector; also passed as the
            address encoder's input width so the two line up.
    """

    def __init__(self, embedding_dim: int = 768):
        super().__init__()
        self.content_encoder = ContentEncoder(embedding_dim)
        self.address_encoder = AddressEncoder(embedding_dim)
        self.time_encoder = TimeEncoder()

    def encode(self, content: str, meta: Dict[str, Any] = None,
               timestamp: float = None) -> EncodingResult:
        """Run the full encoding pipeline on one piece of content.

        Args:
            content: Text to encode.
            meta: Optional metadata forwarded to the address encoder and
                stored on the result (``None`` is stored as ``{}``).
            timestamp: Optional timestamp. When given, the mean of its
                time features is added to every element of the content
                vector (a simplified fusion).

        Returns:
            EncodingResult holding the content vector with its leading batch
            dimension squeezed, the complex address, and the metadata.

        NOTE(review): nn.Module starts in training mode and ContentEncoder
        uses dropout=0.1, so this is stochastic unless ``.eval()`` has been
        called — confirm callers do so before storing encodings.
        """
        vector = self.content_encoder(content)

        if timestamp is not None:
            time_features = self.time_encoder(timestamp)
            # Keep the mean as a 0-d tensor (no .item()) so it stays in the
            # autograd graph: the time encoder receives gradients, and no
            # device->host sync is forced. The values are unchanged.
            vector = vector + torch.mean(time_features)

        complex_coord = self.address_encoder(vector, meta)

        return EncodingResult(
            vector=vector.squeeze(0),
            complex_coord=complex_coord,
            metadata=meta or {},
        )