import os
import json
import torch
import logging
from pathlib import Path
from dataclasses import dataclass
from typing import Optional, List, Dict, Tuple, Any

import transformers
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling
)
from datasets import Dataset, load_dataset
import numpy as np
from accelerate import Accelerator
from safetensors import safe_open
from safetensors.torch import save_file, load_file

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


@dataclass
class TensorInfo:
    """Stores metadata about tensor indices and shape"""
    shape: Tuple[int, ...]
    dtype: str
    indices: Optional[torch.Tensor] = None
    hcf_patterns: Optional[Dict] = None


class SafeTensorHCFAnalyzer:
    """
    Analyzes HCF patterns in model weights using SafeTensors format.
    Handles efficient loading and analysis of large model weights.
    """

    def __init__(self, tolerance: float = 1e-5):
        self.tolerance = tolerance
        self.tensor_info = {}
        self.metadata = {}

    def load_safetensor_file(self,
                             filepath: str,
                             device: str = 'cpu',
                             load_indices: bool = True) -> Dict[str, TensorInfo]:
        """
        Load and parse a SafeTensor file with proper memory management.

        Args:
            filepath: Path to .safetensors file
            device: Device to load tensors to
            load_indices: Whether to load weight indices

        Returns:
            Dictionary mapping tensor names to their metadata
        """
        try:
            with safe_open(filepath, framework="pt") as f:
                # safe_open().metadata() returns a dict of strings (or None);
                # JSON-decode per-tensor entries such as stored index lists.
                self.metadata = {}
                for key, value in (f.metadata() or {}).items():
                    try:
                        self.metadata[key] = json.loads(value)
                    except (TypeError, json.JSONDecodeError):
                        self.metadata[key] = value

            tensors = load_file(filepath, device=device)

            for tensor_name, tensor in tensors.items():
                self.tensor_info[tensor_name] = TensorInfo(
                    shape=tuple(tensor.shape),
                    dtype=str(tensor.dtype)
                )

                if load_indices and isinstance(self.metadata.get(tensor_name), dict):
                    if 'indices' in self.metadata[tensor_name]:
                        indices_data = self.metadata[tensor_name]['indices']
                        if isinstance(indices_data, list):
                            self.tensor_info[tensor_name].indices = torch.tensor(
                                indices_data, device=device
                            )
                        elif isinstance(indices_data, str) and os.path.exists(indices_data):
                            # Indices stored in a separate file referenced by path.
                            self.tensor_info[tensor_name].indices = torch.load(
                                indices_data, map_location=device
                            )

            return self.tensor_info

        except Exception as e:
            raise RuntimeError(f"Error loading SafeTensor file: {str(e)}") from e

    def analyze_safetensor_weights(self,
                                   filepath: str,
                                   batch_size: int = 1000) -> Dict:
        """
        Analyze weights from SafeTensor file in memory-efficient batches.

        Args:
            filepath: Path to .safetensors file
            batch_size: Number of weights to process at once

        Returns:
            Analysis results including HCF patterns and optimization opportunities
        """
        results = {
            'tensor_hcfs': {},
            'shared_patterns': [],
            'optimization_suggestions': [],
            'memory_impact': {}
        }

        with safe_open(filepath, framework="pt") as f:
            for tensor_name in f.keys():
                tensor_data = f.get_tensor(tensor_name)
                tensor_size = tensor_data.numel()  # plain int keeps the results JSON-serializable

                # Per-tensor HCF analysis requires indices loaded via load_safetensor_file().
                if tensor_name in self.tensor_info and self.tensor_info[tensor_name].indices is not None:
                    indices = self.tensor_info[tensor_name].indices
                    unique_indices = torch.unique(indices)

                    tensor_hcfs = {}
                    for idx in unique_indices:
                        mask = (indices == idx)
                        indexed_weights = tensor_data[mask]

                        if len(indexed_weights) > batch_size:
                            hcf = self._process_large_weight_group(indexed_weights, batch_size)
                        else:
                            hcf = self._calculate_hcf(indexed_weights)

                        tensor_hcfs[idx.item()] = hcf

                    results['tensor_hcfs'][tensor_name] = tensor_hcfs

                    patterns = self._analyze_weight_patterns(tensor_data, indices)
                    self.tensor_info[tensor_name].hcf_patterns = patterns

                    savings = self._estimate_memory_savings(patterns, tensor_data.dtype)
                    results['memory_impact'][tensor_name] = {
                        'original_size': tensor_size * tensor_data.element_size(),
                        'potential_savings': savings
                    }

        results['shared_patterns'] = self._find_shared_patterns()
        results['optimization_suggestions'] = self._generate_optimization_suggestions(results)

        return results

    def _calculate_hcf(self, weights: torch.Tensor) -> float:
        """Calculate HCF for a tensor of weights, with tolerance for floating point"""
        if len(weights) == 0:
            return 0.0
        # Fold the tolerance-aware float GCD over the absolute weight values.
        values = torch.abs(weights.flatten()).tolist()
        hcf = values[0]
        for value in values[1:]:
            hcf = self._gcd_float(hcf, value)
        return hcf

    def _gcd_float(self, a: float, b: float) -> float:
        """Calculate greatest common divisor for floating point numbers"""
        # Euclidean algorithm, terminating once the remainder falls within tolerance.
        while b > self.tolerance:
            a, b = b, a % b
        return a
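
    # Worked example (illustrative, assuming the Euclidean reduction above): _gcd_float(0.06, 0.04)
    # steps through (0.06, 0.04) -> (0.04, ~0.02) -> (0.02, ~0.0) and returns ~0.02
    # once the remainder falls below the configured tolerance.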

    def _process_large_weight_group(self,
                                    weights: torch.Tensor,
                                    batch_size: int) -> float:
        """Process large weight groups in batches to manage memory."""
        current_hcf = None

        for i in range(0, len(weights), batch_size):
            batch = weights[i:i + batch_size]
            batch_hcf = self._calculate_hcf(batch)

            if current_hcf is None:
                current_hcf = batch_hcf
            elif batch_hcf > self.tolerance:
                current_hcf = self._gcd_float(current_hcf, batch_hcf)

        return current_hcf if current_hcf is not None else 0.0

    def _analyze_weight_patterns(self,
                                 weights: torch.Tensor,
                                 indices: torch.Tensor) -> Dict:
        """Analyze weight patterns within indexed groups."""
        patterns = {}
        unique_indices = torch.unique(indices)

        for idx in unique_indices:
            mask = (indices == idx)
            pattern_weights = weights[mask]

            patterns[idx.item()] = {
                'mean': float(pattern_weights.mean()),
                'std': float(pattern_weights.std()),
                'size': len(pattern_weights),
                'hcf': self._calculate_hcf(pattern_weights)
            }

        return patterns

    def _estimate_memory_savings(self, patterns: Dict, dtype: torch.dtype) -> int:
        """Estimate potential memory savings (in bytes) from patterns"""
        # Rough heuristic: assume about half of the grouped weights could be shared,
        # scaled by the element size of the tensor's dtype.
        element_size = torch.empty((), dtype=dtype).element_size()
        return sum(p['size'] for p in patterns.values()) * element_size // 2

    def _find_shared_patterns(self) -> List[Dict]:
        """Find patterns that could be shared across tensors."""
        shared_patterns = []
        pattern_groups = {}

        for tensor_name, info in self.tensor_info.items():
            if info.hcf_patterns:
                for idx, pattern in info.hcf_patterns.items():
                    signature = f"{pattern['mean']:.4f}_{pattern['std']:.4f}"

                    if signature not in pattern_groups:
                        pattern_groups[signature] = []
                    pattern_groups[signature].append({
                        'tensor': tensor_name,
                        'index': idx,
                        'pattern': pattern
                    })

        for signature, group in pattern_groups.items():
            if len(group) > 1:
                shared_patterns.append({
                    'signature': signature,
                    'occurrences': group,
                    'potential_savings': sum(p['pattern']['size'] for p in group[1:])
                })

        return shared_patterns

    def _generate_optimization_suggestions(self, results: Dict) -> List[Dict]:
        """Generate optimization suggestions based on analysis"""
        suggestions = []
        for tensor_name, impact in results['memory_impact'].items():
            if impact['potential_savings'] > 1000000:
                suggestions.append({
                    'tensor': tensor_name,
                    'suggestion': 'Consider weight quantization',
                    'impact': f"Save {impact['potential_savings'] / 1024 / 1024:.2f}MB"
                })
        return suggestions
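

# Illustrative usage of SafeTensorHCFAnalyzer (a sketch only; the file path below is a
# placeholder, and per-tensor HCF statistics appear only when the file's metadata carries
# index information loaded via load_safetensor_file):
#
#   analyzer = SafeTensorHCFAnalyzer(tolerance=1e-5)
#   analyzer.load_safetensor_file("model.safetensors", device="cpu")
#   report = analyzer.analyze_safetensor_weights("model.safetensors", batch_size=1000)
#   for suggestion in report["optimization_suggestions"]:
#       logger.info(suggestion)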


@dataclass
class TrainingStatistics:
    """Statistics collected during HCF-aware training"""
    memory_savings: int = 0
    quantization_error: float = 0.0
    convergence_rate: float = 0.0
    epoch: int = 0
    batch_count: int = 0

    def update(self, batch_stats: Dict[str, Any]):
        """Update statistics with batch results"""
        self.memory_savings += batch_stats.get('memory_savings', 0)
        self.quantization_error = batch_stats.get('quantization_error', self.quantization_error)
        self.convergence_rate = batch_stats.get('convergence_rate', self.convergence_rate)
        self.batch_count += 1


class HCFTrainingOptimizer(torch.optim.Adam):
    """
    Optimizer with HCF-awareness for more efficient training
    """
    def __init__(self,
                 params,
                 lr=0.001,
                 betas=(0.9, 0.999),
                 eps=1e-8,
                 weight_decay=0,
                 weight_quantization=True,
                 maintain_patterns=True):
        super().__init__(params, lr, betas, eps, weight_decay)
        self.weight_quantization = weight_quantization
        self.maintain_patterns = maintain_patterns
        self.analyzer = SafeTensorHCFAnalyzer()
        self.stats = {'memory_savings': 0, 'quantization_error': 0.0}

    def step(self, closure=None):
        """Perform optimization step with HCF awareness"""
        loss = super().step(closure)

        if self.weight_quantization:
            self._apply_weight_quantization()

        if self.maintain_patterns:
            self._maintain_weight_patterns()

        return loss

    def _apply_weight_quantization(self):
        """Apply dynamic weight quantization using HCF patterns"""
        savings = 0
        total_error = 0.0

        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None or not p.requires_grad:
                    continue

                if p.dim() > 1:
                    # Symmetric int8-style scale based on the largest weight magnitude.
                    factor = torch.max(torch.abs(p.data)) / 127
                    if factor == 0:
                        continue

                    quantized = torch.round(p.data / factor) * factor

                    error = torch.mean((p.data - quantized) ** 2).item()
                    savings += p.numel() * (p.element_size() - 1)

                    p.data.copy_(quantized)
                    total_error += error

        self.stats['memory_savings'] = savings
        self.stats['quantization_error'] = total_error

    def _maintain_weight_patterns(self):
        """Maintain efficient weight patterns identified by HCF analysis"""
        pass

    def get_stats(self):
        """Get current optimization statistics"""
        return self.stats


class HCFAwareTrainer:
    """
    Trainer that incorporates HCF analysis for better training efficiency
    """
    def __init__(self, model, optimizer):
        self.model = model
        self.optimizer = optimizer
        self.analyzer = SafeTensorHCFAnalyzer()

    def train_epoch(self, train_loader, criterion, epoch):
        """Train one epoch with HCF awareness"""
        self.model.train()
        stats = TrainingStatistics(epoch=epoch)

        for batch_idx, batch in enumerate(train_loader):
            inputs, targets = self._prepare_batch(batch)

            self.optimizer.zero_grad()
            # Hugging Face causal LM heads return a shifted LM loss when labels are supplied;
            # fall back to the provided criterion for plain models that return raw outputs.
            outputs = self.model(inputs, labels=targets)
            loss = outputs.loss if getattr(outputs, "loss", None) is not None else criterion(outputs, targets)

            loss.backward()

            self.optimizer.step()

            batch_stats = self.optimizer.get_stats()
            stats.update(batch_stats)

            if batch_idx % 50 == 0:
                logger.info(f"Epoch {epoch} | Batch {batch_idx}/{len(train_loader)} | "
                            f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB | "
                            f"Quantization Error: {stats.quantization_error:.6f}")

        self._analyze_model_weights()

        return stats

    def _prepare_batch(self, batch):
        """Prepare batch data for training"""
        if isinstance(batch, dict):
            inputs = batch.get('input_ids')
            targets = batch.get('labels', inputs)
        else:
            inputs, targets = batch

        # Move the batch onto the same device as the model parameters.
        device = next(self.model.parameters()).device
        return inputs.to(device), targets.to(device)

    def _analyze_model_weights(self):
        """Analyze model weights for patterns and optimizations"""
        model_path = "temp_model.safetensors"
        # safetensors expects standalone CPU tensors (no shared storage), so detach and clone.
        tensors = {name: param.detach().cpu().clone() for name, param in self.model.named_parameters()}
        save_file(tensors, model_path)

        results = self.analyzer.analyze_safetensor_weights(model_path)

        logger.info(f"Weight Analysis: Found {len(results['shared_patterns'])} shared patterns")
        logger.info(f"Potential memory savings: "
                    f"{sum(i['potential_savings'] for i in results['memory_impact'].values())/1024/1024:.2f}MB")

        if os.path.exists(model_path):
            os.remove(model_path)
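

# Illustrative wiring of the HCF-aware training pieces (a sketch only; `model` and
# `train_loader` are assumed to exist, e.g. a causal LM and a DataLoader of tokenized batches):
#
#   optimizer = HCFTrainingOptimizer(model.parameters(), lr=1e-5,
#                                    weight_quantization=True, maintain_patterns=True)
#   trainer = HCFAwareTrainer(model, optimizer)
#   stats = trainer.train_epoch(train_loader, torch.nn.CrossEntropyLoss(), epoch=0)
#   logger.info(f"Savings: {stats.memory_savings} bytes, error: {stats.quantization_error}")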


@dataclass
class ModelConfig:
    name: str
    model_id: str
    tokenizer_id: str


CONFIGS = {
    "7b": ModelConfig(
        name="7b",
        model_id="scrapegoat/ScrapeGoat-Music-Stage1",
        tokenizer_id="scrapegoat/ScrapeGoat-Music-Stage1"
    ),
    "1b": ModelConfig(
        name="1b",
        model_id="scrapegoat/ScrapeGoat-Music-Stage2",
        tokenizer_id="scrapegoat/ScrapeGoat-Music-Stage2"
    )
}


class MusicFineTuner:
    def __init__(
        self,
        model_size: str,
        dataset_path: str,
        output_dir: str,
        device: str = "auto",
        batch_size: int = 4,
        gradient_accumulation_steps: int = 4,
        learning_rate: float = 1e-5,
        num_epochs: int = 3,
        use_hcf: bool = True
    ):
        self.config = CONFIGS[model_size]
        self.dataset_path = Path(dataset_path)
        self.output_dir = Path(output_dir)
        self.device = self._setup_device(device)
        self.use_hcf = use_hcf
        self.training_args = TrainingArguments(
            output_dir=str(self.output_dir),
            per_device_train_batch_size=batch_size,
            gradient_accumulation_steps=gradient_accumulation_steps,
            learning_rate=learning_rate,
            num_train_epochs=num_epochs,
            logging_steps=100,
            save_steps=1000,
            evaluation_strategy="steps",
            eval_steps=500,
            save_total_limit=3,
            load_best_model_at_end=True,
            gradient_checkpointing=True,
            fp16=torch.cuda.is_available(),
            optim="adamw_torch"
        )

    def _setup_device(self, device: str) -> str:
        if device == "auto":
            if torch.cuda.is_available():
                return "cuda"
            elif torch.backends.mps.is_available():
                return "mps"
            else:
                return "cpu"
        return device

    def _load_model_and_tokenizer(self):
        logger.info(f"Loading model {self.config.model_id}")

        dtype = torch.bfloat16 if self.device == "cuda" else torch.float32

        model = AutoModelForCausalLM.from_pretrained(
            self.config.model_id,
            torch_dtype=dtype,
            device_map="auto" if self.device == "cuda" else None,
            # flash_attention_2 requires the flash-attn package; other devices fall back to eager.
            attn_implementation="flash_attention_2" if self.device == "cuda" else "eager"
        )

        tokenizer = AutoTokenizer.from_pretrained(self.config.tokenizer_id)
        # Padding and the LM data collator need a pad token; reuse EOS if none is defined.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token
        return model, tokenizer

    def _prepare_dataset(self, tokenizer):
        logger.info("Preparing dataset")

        with open(self.dataset_path / "metadata" / "dataset_info.json") as f:
            metadata = json.load(f)

        def generate_text(item):
            return f"Genre: {item['genre']}\nDuration: {item['duration']:.2f}s\nTitle: {item['title']}\nArtist: {item['artist']}\n"

        texts = [generate_text(item) for item in metadata["files"]]
        dataset = Dataset.from_dict({"text": texts})

        def tokenize(examples):
            return tokenizer(
                examples["text"],
                truncation=True,
                padding="max_length",
                max_length=512,
                return_tensors="pt"
            )

        tokenized_dataset = dataset.map(
            tokenize,
            batched=True,
            remove_columns=dataset.column_names
        )

        return tokenized_dataset
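
    # Example prompt produced by generate_text for a single metadata entry
    # (the genre/duration/title/artist values below are illustrative, not from the dataset):
    #
    #   Genre: jazz
    #   Duration: 182.50s
    #   Title: Blue Horizon
    #   Artist: Example Quartet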

    def train(self):
        self.output_dir.mkdir(parents=True, exist_ok=True)

        model, tokenizer = self._load_model_and_tokenizer()

        dataset = self._prepare_dataset(tokenizer)

        dataset = dataset.train_test_split(test_size=0.1)

        if self.use_hcf:
            logger.info("Using HCF-aware training")

            optimizer = HCFTrainingOptimizer(
                model.parameters(),
                lr=self.training_args.learning_rate,
                weight_quantization=True,
                maintain_patterns=True
            )

            hcf_trainer = HCFAwareTrainer(model, optimizer)

            train_loader = torch.utils.data.DataLoader(
                dataset["train"],
                batch_size=self.training_args.per_device_train_batch_size,
                shuffle=True,
                # Collate tokenized examples into padded tensors and add LM labels.
                collate_fn=DataCollatorForLanguageModeling(tokenizer, mlm=False)
            )

            criterion = torch.nn.CrossEntropyLoss()
            for epoch in range(int(self.training_args.num_train_epochs)):
                stats = hcf_trainer.train_epoch(train_loader, criterion, epoch)

                logger.info(f"Epoch {epoch} completed")
                logger.info(f"Memory Savings: {stats.memory_savings/1024/1024:.2f}MB")
                logger.info(f"Quantization Error: {stats.quantization_error:.6f}")
                logger.info(f"Convergence Rate: {stats.convergence_rate:.4f}")

                self._save_hcf_checkpoint(model, tokenizer, epoch)
        else:
            logger.info("Using standard training")
            trainer = Trainer(
                model=model,
                args=self.training_args,
                train_dataset=dataset["train"],
                eval_dataset=dataset["test"],
                data_collator=DataCollatorForLanguageModeling(tokenizer, mlm=False),
            )

            logger.info("Starting training")
            trainer.train()

        logger.info("Saving model")
        model.save_pretrained(str(self.output_dir / "final_model"))
        tokenizer.save_pretrained(str(self.output_dir / "final_model"))

    def _save_hcf_checkpoint(self, model, tokenizer, epoch):
        """Save checkpoint with HCF metadata"""
        checkpoint_dir = self.output_dir / f"checkpoint-{epoch}"
        checkpoint_dir.mkdir(exist_ok=True)

        model.save_pretrained(str(checkpoint_dir))
        tokenizer.save_pretrained(str(checkpoint_dir))

        analyzer = SafeTensorHCFAnalyzer()

        model_path = str(checkpoint_dir / "model.safetensors")
        if os.path.exists(model_path):
            results = analyzer.analyze_safetensor_weights(model_path)

            with open(checkpoint_dir / "hcf_analysis.json", "w") as f:
                json.dump(results, f, indent=2)

        logger.info(f"Saved checkpoint at {checkpoint_dir}")


if __name__ == "__main__":
    import argparse
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_size", type=str, choices=["1b", "7b"], required=True)
    parser.add_argument("--dataset_path", type=str, required=True)
    parser.add_argument("--output_dir", type=str, required=True)
    parser.add_argument("--device", type=str, default="auto")
    parser.add_argument("--batch_size", type=int, default=4)
    parser.add_argument("--gradient_accumulation_steps", type=int, default=4)
    parser.add_argument("--learning_rate", type=float, default=1e-5)
    parser.add_argument("--num_epochs", type=int, default=3)
    parser.add_argument("--use_hcf", action="store_true", help="Enable HCF-aware training")
    args = parser.parse_args()

    fine_tuner = MusicFineTuner(
        model_size=args.model_size,
        dataset_path=args.dataset_path,
        output_dir=args.output_dir,
        device=args.device,
        batch_size=args.batch_size,
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        learning_rate=args.learning_rate,
        num_epochs=args.num_epochs,
        use_hcf=args.use_hcf
    )
    fine_tuner.train()