import os
import json
import torch
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple, Any

import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset, load_dataset
import numpy as np
from accelerate import Accelerator
from safetensors import safe_open
from safetensors.torch import save_file, load_file

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class TensorInfo:
    """Stores metadata about tensor indices and shape."""
    shape: Tuple[int, ...]
    dtype: str
    indices: Optional[torch.Tensor] = None
    hcf_patterns: Optional[Dict] = None


class SafeTensorHCFAnalyzer:
    """
    Analyzes HCF patterns in model weights using the SafeTensors format.
    Handles efficient loading and analysis of large model weights.
    """

    def __init__(self, tolerance: float = 1e-5):
        self.tolerance = tolerance
        self.tensor_info = {}
        self.metadata = {}

    def load_safetensor_file(self, filepath: str, device: str = 'cpu',
                             load_indices: bool = True) -> Dict[str, TensorInfo]:
        """
        Load and parse a SafeTensor file with proper memory management.

        Args:
            filepath: Path to .safetensors file
            device: Device to load tensors to
            load_indices: Whether to load weight indices

        Returns:
            Dictionary mapping tensor names to their metadata
        """
        try:
            # First read the header metadata to check structure. SafeTensors
            # metadata is a flat str -> str mapping, so structured per-tensor
            # entries are stored as JSON strings and decoded here.
            with safe_open(filepath, framework="pt") as f:
                raw_metadata = f.metadata() or {}
            self.metadata = {}
            for key, value in raw_metadata.items():
                try:
                    self.metadata[key] = json.loads(value)
                except (TypeError, json.JSONDecodeError):
                    self.metadata[key] = value

            # Load tensors efficiently
            tensors = load_file(filepath, device=device)

            for tensor_name, tensor in tensors.items():
                self.tensor_info[tensor_name] = TensorInfo(
                    shape=tuple(tensor.shape),
                    dtype=str(tensor.dtype)
                )

                # Load indices if available in metadata
                if load_indices and tensor_name in self.metadata:
                    if 'indices' in self.metadata[tensor_name]:
                        indices_data = self.metadata[tensor_name]['indices']
                        if isinstance(indices_data, list):
                            self.tensor_info[tensor_name].indices = torch.tensor(
                                indices_data, device=device
                            )
                        elif isinstance(indices_data, str) and os.path.exists(indices_data):
                            # Load indices from a separate file if provided as a path
                            self.tensor_info[tensor_name].indices = torch.load(indices_data)

            return self.tensor_info

        except Exception as e:
            raise RuntimeError(f"Error loading SafeTensor file: {str(e)}") from e
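    # Illustrative sketch of the header metadata layout that load_safetensor_file
    # expects (an assumption based on the accesses above, not a format defined by
    # safetensors itself): each tensor name maps to a JSON string whose "indices"
    # field is either an inline list of group ids or a path to a torch-saved
    # index tensor, e.g.
    #
    #   save_file(tensors, "model.safetensors", metadata={
    #       "layer.0.weight": json.dumps({"indices": [0, 0, 1, 1, 2, 2]}),
    #       "layer.1.weight": json.dumps({"indices": "layer1_indices.pt"}),
    #   })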
    def analyze_safetensor_weights(self, filepath: str, batch_size: int = 1000) -> Dict:
        """
        Analyze weights from a SafeTensor file in memory-efficient batches.

        Args:
            filepath: Path to .safetensors file
            batch_size: Number of weights to process at once

        Returns:
            Analysis results including HCF patterns and optimization opportunities
        """
        results = {
            'tensor_hcfs': {},
            'shared_patterns': [],
            'optimization_suggestions': [],
            'memory_impact': {}
        }

        # Process tensors in batches
        with safe_open(filepath, framework="pt") as f:
            for tensor_name in f.keys():
                # Get tensor info
                tensor_data = f.get_tensor(tensor_name)
                # numel() keeps the size a plain Python int so the results stay
                # JSON-serializable downstream
                tensor_size = tensor_data.numel()

                if tensor_name in self.tensor_info and self.tensor_info[tensor_name].indices is not None:
                    indices = self.tensor_info[tensor_name].indices
                    unique_indices = torch.unique(indices)

                    # Process each index group
                    tensor_hcfs = {}
                    for idx in unique_indices:
                        mask = (indices == idx)
                        indexed_weights = tensor_data[mask]

                        # Process in batches if needed
                        if len(indexed_weights) > batch_size:
                            hcf = self._process_large_weight_group(indexed_weights, batch_size)
                        else:
                            hcf = self._calculate_hcf(indexed_weights)

                        tensor_hcfs[idx.item()] = hcf

                    results['tensor_hcfs'][tensor_name] = tensor_hcfs

                    # Find optimization opportunities
                    patterns = self._analyze_weight_patterns(tensor_data, indices)
                    self.tensor_info[tensor_name].hcf_patterns = patterns

                    # Calculate potential memory savings
                    savings = self._estimate_memory_savings(patterns, tensor_data.dtype)
                    results['memory_impact'][tensor_name] = {
                        'original_size': tensor_size * tensor_data.element_size(),
                        'potential_savings': savings
                    }

        # Find shared patterns across tensors
        results['shared_patterns'] = self._find_shared_patterns()
        results['optimization_suggestions'] = self._generate_optimization_suggestions(results)

        return results

    def _calculate_hcf(self, weights: torch.Tensor) -> float:
        """Calculate the HCF for a tensor of weights, with tolerance for floating point."""
        # Implementation placeholder - the actual implementation depends on specific needs
        if len(weights) == 0:
            return 0.0
        return 1.0  # Simplified for example

    def _gcd_float(self, a: float, b: float) -> float:
        """Calculate the greatest common divisor of two floating point numbers."""
        # Implementation placeholder
        return min(a, b)  # Simplified for example

    def _process_large_weight_group(self, weights: torch.Tensor, batch_size: int) -> float:
        """Process large weight groups in batches to manage memory."""
        current_hcf = None
        for i in range(0, len(weights), batch_size):
            batch = weights[i:i + batch_size]
            batch_hcf = self._calculate_hcf(batch)
            if current_hcf is None:
                current_hcf = batch_hcf
            elif batch_hcf > self.tolerance:
                current_hcf = self._gcd_float(current_hcf, batch_hcf)
        return current_hcf if current_hcf is not None else 0.0

    def _analyze_weight_patterns(self, weights: torch.Tensor, indices: torch.Tensor) -> Dict:
        """Analyze weight patterns within indexed groups."""
        patterns = {}
        unique_indices = torch.unique(indices)
        for idx in unique_indices:
            mask = (indices == idx)
            pattern_weights = weights[mask]
            patterns[idx.item()] = {
                'mean': float(pattern_weights.mean()),
                'std': float(pattern_weights.std()),
                'size': len(pattern_weights),
                'hcf': self._calculate_hcf(pattern_weights)
            }
        return patterns

    def _estimate_memory_savings(self, patterns: Dict, dtype: torch.dtype) -> int:
        """Estimate potential memory savings from patterns."""
        # Implementation placeholder
        return sum(p['size'] for p in patterns.values()) // 2  # Simplified estimate

    def _find_shared_patterns(self) -> List[Dict]:
        """Find patterns that could be shared across tensors."""
        shared_patterns = []
        pattern_groups = {}

        for tensor_name, info in self.tensor_info.items():
            if info.hcf_patterns:
                for idx, pattern in info.hcf_patterns.items():
                    # Create pattern signature
                    signature = f"{pattern['mean']:.4f}_{pattern['std']:.4f}"
                    if signature not in pattern_groups:
                        pattern_groups[signature] = []
                    pattern_groups[signature].append({
                        'tensor': tensor_name,
                        'index': idx,
                        'pattern': pattern
                    })

        # Find groups with similar patterns
        for signature, group in pattern_groups.items():
            if len(group) > 1:
                shared_patterns.append({
                    'signature': signature,
                    'occurrences': group,
                    'potential_savings': sum(p['pattern']['size'] for p in group[1:])
                })

        return shared_patterns

    def _generate_optimization_suggestions(self, results: Dict) -> List[Dict]:
        """Generate optimization suggestions based on the analysis."""
        # Implementation placeholder
        suggestions = []
        for tensor_name, impact in results['memory_impact'].items():
            if impact['potential_savings'] > 1000000:  # If savings > 1MB
                suggestions.append({
                    'tensor': tensor_name,
                    'suggestion': 'Consider weight quantization',
                    'impact': f"Save {impact['potential_savings'] / 1024 / 1024:.2f}MB"
                })
        return suggestions

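# The _calculate_hcf / _gcd_float methods above are explicit placeholders that
# return simplified values. Below is a minimal sketch of one way a
# tolerance-based floating point GCD could be implemented, assuming weights are
# treated as integer multiples of the tolerance grid. The helper names are
# illustrative and not part of the original analyzer.
def example_float_gcd(a: float, b: float, tol: float = 1e-5) -> float:
    """Approximate GCD of two floats by snapping both onto a tolerance grid."""
    import math
    ia, ib = round(abs(a) / tol), round(abs(b) / tol)
    if ia == 0 and ib == 0:
        return 0.0
    return math.gcd(ia, ib) * tol


def example_tensor_hcf(weights: torch.Tensor, tol: float = 1e-5) -> float:
    """Fold example_float_gcd over every element of a weight tensor."""
    hcf = 0.0
    for w in weights.flatten().tolist():
        hcf = example_float_gcd(hcf, w, tol)
    return hcf
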
@dataclass
class TrainingStatistics:
    """Statistics collected during HCF-aware training."""
    memory_savings: int = 0
    quantization_error: float = 0.0
    convergence_rate: float = 0.0
    epoch: int = 0
    batch_count: int = 0

    def update(self, batch_stats: Dict[str, Any]):
        """Update statistics with batch results."""
        self.memory_savings += batch_stats.get('memory_savings', 0)
        self.quantization_error = batch_stats.get('quantization_error', self.quantization_error)
        self.convergence_rate = batch_stats.get('convergence_rate', self.convergence_rate)
        self.batch_count += 1


class HCFTrainingOptimizer(torch.optim.Adam):
    """
    Adam optimizer with HCF awareness for more efficient training.
    """

    def __init__(self, params, lr=0.001, betas=(0.9, 0.999), eps=1e-8,
                 weight_decay=0, weight_quantization=True, maintain_patterns=True):
        super().__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay)
        self.weight_quantization = weight_quantization
        self.maintain_patterns = maintain_patterns
        self.analyzer = SafeTensorHCFAnalyzer()
        self.stats = {'memory_savings': 0, 'quantization_error': 0.0}

    def step(self, closure=None):
        """Perform an optimization step with HCF awareness."""
        # Run the standard Adam step
        loss = super().step(closure)

        # Apply HCF optimizations if enabled
        if self.weight_quantization:
            self._apply_weight_quantization()
        if self.maintain_patterns:
            self._maintain_weight_patterns()

        return loss

    def _apply_weight_quantization(self):
        """Apply dynamic weight quantization using HCF patterns."""
        savings = 0
        total_error = 0.0

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or not p.requires_grad:
                    continue

                # Apply weight quantization logic based on HCF analysis.
                # This is a simplified placeholder - a real implementation would be more complex.
                if p.dim() > 1:  # Only apply to matrices/tensors
                    # Find a suitable quantization factor (symmetric 8-bit example)
                    factor = torch.max(torch.abs(p.data)) / 127
                    if factor == 0:
                        continue  # all-zero parameter, nothing to quantize

                    # Quantize weights
                    quantized = torch.round(p.data / factor) * factor

                    # Calculate error and savings
                    error = torch.mean((p.data - quantized) ** 2).item()
                    savings += p.numel() * (p.element_size() - 1)  # Assuming 8-bit savings

                    # Apply quantized weights
                    p.data.copy_(quantized)
                    total_error += error

        # Update statistics
        self.stats['memory_savings'] = savings
        self.stats['quantization_error'] = total_error

    def _maintain_weight_patterns(self):
        """Maintain efficient weight patterns identified by HCF analysis."""
        # Placeholder for pattern maintenance logic.
        # A real implementation would analyze weight matrices and enforce patterns.
        pass

    def get_stats(self):
        """Get current optimization statistics."""
        return self.stats

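# A standalone sketch of the symmetric 8-bit fake quantization performed per
# parameter in _apply_weight_quantization above. The helper name is illustrative
# and not part of the original optimizer.
def example_fake_quantize_int8(weights: torch.Tensor) -> Tuple[torch.Tensor, float]:
    """Round weights onto a 255-level symmetric grid and report the MSE introduced."""
    scale = torch.max(torch.abs(weights)) / 127  # one grid step per int8 level
    if scale == 0:
        return weights.clone(), 0.0
    quantized = torch.round(weights / scale) * scale
    error = torch.mean((weights - quantized) ** 2).item()
    return quantized, error

# e.g. quantized, mse = example_fake_quantize_int8(torch.randn(256, 256))
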
class HCFAwareTrainer:
    """
    Trainer that incorporates HCF analysis for better training efficiency.
    """

    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.analyzer = SafeTensorHCFAnalyzer()

    def train_epoch(self, train_loader, criterion, epoch):
        """Train one epoch with HCF awareness."""
        self.model.train()
        stats = TrainingStatistics(epoch=epoch)

        for batch_idx, batch in enumerate(train_loader):
            # Get data
            inputs, targets = self._prepare_batch(batch)

            # Forward pass
            self.optimizer.zero_grad()
            outputs = self.model(inputs)
            # Hugging Face models return an output object; fall back to raw tensors otherwise
            logits = outputs.logits if hasattr(outputs, "logits") else outputs
            # Shift so that each position predicts the next token (causal LM objective)
            loss = criterion(
                logits[..., :-1, :].reshape(-1, logits.size(-1)),
                targets[..., 1:].reshape(-1)
            )

            # Backward pass
            loss.backward()

            # Optimize with HCF awareness
            self.optimizer.step()

            # Get batch statistics
            batch_stats = self.optimizer.get_stats()
            stats.update(batch_stats)

            # Log progress
            if batch_idx % 50 == 0:
                logger.info(f"Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | "
                            f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB | "
                            f"Quantization Error: {stats.quantization_error:.6f}")

        # End of epoch analysis
        self._analyze_model_weights()

        return stats

    def _prepare_batch(self, batch):
        """Prepare batch data for training."""
        # Implementation depends on dataset structure
        if isinstance(batch, dict):
            inputs = batch.get('input_ids')
            targets = batch.get('labels', inputs)
        else:
            # Assume batch is a tuple of (inputs, targets)
            inputs, targets = batch
        # Move data onto the same device as the model parameters
        device = next(self.model.parameters()).device
        return inputs.to(device), targets.to(device)

    def _analyze_model_weights(self):
        """Analyze model weights for patterns and optimizations."""
        # Save the model to a temporary safetensor file for analysis.
        # Detached, contiguous CPU copies keep safetensors happy with tied or
        # non-contiguous parameters.
        model_path = "temp_model.safetensors"
        tensors = {
            name: param.detach().cpu().clone().contiguous()
            for name, param in self.model.named_parameters()
        }
        save_file(tensors, model_path)

        # Analyze weights
        results = self.analyzer.analyze_safetensor_weights(model_path)

        # Log findings
        logger.info(f"Weight Analysis: Found {len(results['shared_patterns'])} shared patterns")
        logger.info(f"Potential memory savings: "
                    f"{sum(i['potential_savings'] for i in results['memory_impact'].values())/1024/1024:.2f}MB")

        # Clean up
        if os.path.exists(model_path):
            os.remove(model_path)


@dataclass
class ModelConfig:
    name: str
    model_id: str
    tokenizer_id: str


CONFIGS = {
    "7b": ModelConfig(
        name="7b",
        model_id="scrapegoat/ScrapeGoat-Music-Stage1",
        tokenizer_id="scrapegoat/ScrapeGoat-Music-Stage1"
    ),
    "1b": ModelConfig(
        name="1b",
        model_id="scrapegoat/ScrapeGoat-Music-Stage2",
        tokenizer_id="scrapegoat/ScrapeGoat-Music-Stage2"
    )
}


class MusicFineTuner:
    def __init__(
        self,
        model_size: str,
        dataset_path: str,
        output_dir: str,
        device: str = "auto",
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-5,
        num_epochs: int = 3,
        use_hcf: bool = True
    ):
        self.config = CONFIGS[model_size]
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.device = self._setup_device(device)
        self.use_hcf = use_hcf

        self.training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            logging_steps=100,
            save_steps=1000,
            evaluation_strategy="steps",
            eval_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            gradient_checkpointing=True,
            # Weights are loaded in bfloat16 on CUDA, so train in bf16 rather than fp16
            bf16=torch.cuda.is_available(),
            optim="adamw_torch"
        )

    def _setup_device(self, device: str) -> str:
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device
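    # Sketch of the dataset layout that _prepare_dataset below expects. The field
    # names come from the accesses in that method; the exact nesting shown here
    # is an illustrative assumption:
    #
    #   <dataset_path>/metadata/dataset_info.json
    #   {
    #     "files": [
    #       {"genre": "ambient", "duration": 183.5,
    #        "title": "Example Track", "artist": "Example Artist"},
    #       ...
    #     ]
    #   }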
    def _load_model_and_tokenizer(self):
        logger.info(f"Loading model {self.config.model_id}")

        # Determine dtype based on device
        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            self.config.model_id,
            torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            attn_implementation="flash_attention_2" if self.device == "cuda" else "eager"
        )

        tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_id)
        # Some causal LM tokenizers ship without a pad token; padding below needs one
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        return model, tokenizer

    def _prepare_dataset(self, tokenizer):
        logger.info("Preparing dataset")

        with open(self.dataset_path / "metadata" / "dataset_info.json") as f:
            metadata = json.load(f)

        def generate_text(item):
            return (f"Genre: {item['genre']}\n"
                    f"Duration: {item['duration']:.2f}s\n"
                    f"Title: {item['title']}\n"
                    f"Artist: {item['artist']}\n")

        texts = [generate_text(item) for item in metadata["files"]]
        dataset = Dataset.from_dict({"text": texts})

        def tokenize(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset

    def train(self):
        # Create output directory
        self.output_dir.mkdir(parents=True, exist_ok=True)

        # Load model and tokenizer
        model, tokenizer = self._load_model_and_tokenizer()

        # Prepare dataset
        dataset = self._prepare_dataset(tokenizer)

        # Split dataset
        dataset = dataset.train_test_split(test_size=0.1)

        if self.use_hcf:
            logger.info("Using HCF-aware training")

            # Create custom HCF optimizer
            optimizer = HCFTrainingOptimizer(
                model.parameters(),
                lr=self.training_args.learning_rate,
                weight_quantization=True,
                maintain_patterns=True
            )

            # Create HCF trainer
            hcf_trainer = HCFAwareTrainer(model, optimizer)

            # Custom training loop; the collator turns tokenized examples into
            # padded tensors and adds causal-LM labels
            train_loader = torch.utils.data.DataLoader(
                dataset["train"],
                batch_size=self.training_args.per_device_train_batch_size,
                shuffle=True,
                collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False)
            )

            # Training loop with HCF awareness
            criterion = torch.nn.CrossEntropyLoss()
            for epoch in range(int(self.training_args.num_train_epochs)):
                stats = hcf_trainer.train_epoch(train_loader, criterion, epoch)

                # Log training metrics
                logger.info(f"Epoch {epoch} completed")
                logger.info(f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB")
                logger.info(f"Quantization Error: {stats.quantization_error:.6f}")
                logger.info(f"Convergence Rate: {stats.convergence_rate:.4f}")

                # Save checkpoint
                self._save_hcf_checkpoint(model, tokenizer, epoch)
        else:
            # Use the standard HuggingFace Trainer
            logger.info("Using standard training")
            trainer = Trainer(
                model=model,
                args=self.training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )

            # Train
            logger.info("Starting training")
            trainer.train()

        # Save final model
        logger.info("Saving model")
        model.save_pretrained(str(self.output_dir / "final_model"))
        tokenizer.save_pretrained(str(self.output_dir / "final_model"))
    def _save_hcf_checkpoint(self, model, tokenizer, epoch):
        """Save a checkpoint with HCF metadata."""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
        checkpoint_dir.mkdir(exist_ok=True)

        # Save model and tokenizer
        model.save_pretrained(str(checkpoint_dir))
        tokenizer.save_pretrained(str(checkpoint_dir))

        # Analyze and save HCF metadata
        analyzer = SafeTensorHCFAnalyzer()

        # Analyze the saved tensors if they were written in safetensors format
        model_path = str(checkpoint_dir / "model.safetensors")
        if os.path.exists(model_path):
            results = analyzer.analyze_safetensor_weights(model_path)

            # Save analysis results
            with open(checkpoint_dir / "hcf_analysis.json", "w") as f:
                json.dump(results, f, indent=2)

        logger.info(f"Saved checkpoint at {checkpoint_dir}")


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument("--model_size", type=str, choices=["1b", "7b"], required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--use_hcf", action="store_true", help="Enable HCF-aware training")

    args = parser.parse_args()

    fine_tuner = MusicFineTuner(
        model_size=args.model_size,
        dataset_path=args.dataset_path,
        output_dir=args.output_dir,
        device=args.device,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        use_hcf=args.use_hcf
    )

    fine_tuner.train()
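# Example invocation (script filename and paths are illustrative):
#
#   python finetune_music.py \
#       --model_size 1b \
#       --dataset_path ./data/music_dataset \
#       --output_dir ./outputs/run1 \
#       --batch_size 4 \
#       --gradient_accumulation_steps 4 \
#       --use_hcf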