import os
import time
import math
import copy
import torch
from torch import nn
import torch.nn.functional as F
from torch.nn.utils import prune
from transformers import (
    AutoTokenizer, AutoConfig, DataCollatorForLanguageModeling, Trainer,
    TrainingArguments, AutoModelForCausalLM, AutoModel, EarlyStoppingCallback,
    pipeline, get_scheduler, logging as hf_logging
)

try:
    from peft import PeftModel, LoraConfig, get_peft_model, TaskType, PeftConfig
    _peft_installed = True
except ImportError:
    _peft_installed = False
    PeftModel = None
    LoraConfig = None
    get_peft_model = None
    TaskType = None
    PeftConfig = None

from datasets import (
    load_dataset, interleave_datasets, concatenate_datasets, Dataset,
    Features, Value, IterableDataset, DatasetDict
)
from huggingface_hub import login, create_repo, HfApi, hf_hub_download
import wandb
import gradio as gr
from gradio_huggingfacehub_search import HuggingfaceHubSearch
import re
import json
import gc
from accelerate import Accelerator
import logging
import traceback
from collections import Counter, OrderedDict
import requests
import gzip
import inspect
import shutil
from functools import partial
import types
import psutil

hf_logging.set_verbosity_error()
logging.getLogger("datasets").setLevel(logging.ERROR)
logging.getLogger("huggingface_hub").setLevel(logging.ERROR)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

padding = True
truncation = True
TOKENIZERS_PARALLELISM = True
os.environ["TOKENIZERS_PARALLELISM"] = str(TOKENIZERS_PARALLELISM)

# Default hyperparameters and architecture settings.
BATCH_SIZE = 8
LEARNING_RATE = 1.5e-4
EPOCHS = 1
MAX_STEPS = 1
USE_CPU = False
NUM_CPU_CORES = -1
MERGE_ALPHA = 0.7
CONTEXT_LENGTH = 256
HEADS = 4
DIMENSIONS = 256
LAYERS = 1
INTERMEDIATE_SIZE = 1024
USE_WANDB = False

ACTIVATION_FUNCTIONS = {
    "relu": nn.ReLU,
    "gelu": nn.GELU,
    "silu": nn.SiLU,
    "mish": nn.Mish,
    "leaky_relu": nn.LeakyReLU,
    "elu": nn.ELU,
    "relu6": nn.ReLU6,
    "tanh": nn.Tanh,
    "sigmoid": nn.Sigmoid,
    "identity": nn.Identity
}
DEFAULT_ACTIVATION_FUNCTION = "gelu"

OPTIMIZERS = {
    "adamw_torch": torch.optim.AdamW,
    "adam_torch": torch.optim.Adam,
    "sgd": torch.optim.SGD,
    "adamax": torch.optim.Adamax,
    "adagrad": torch.optim.Adagrad,
    "rmsprop": torch.optim.RMSprop
}
DEFAULT_OPTIMIZER = "adamw_torch"

PRUNING_AMOUNT = 0.2
QUANTIZATION_MODES = ['float32', 'float16', 'bfloat16']
DEFAULT_QUANTIZATION = 'float32'
SCHEDULER_TYPES = ["linear", "cosine", "cosine_with_restarts", "polynomial", "constant", "constant_with_warmup"]
DEFAULT_SCHEDULER = "cosine"
GRADIENT_ACCUMULATION_STEPS = 1
EVAL_STEPS = 500
SAVE_STEPS = 500
LOGGING_STEPS = 100
EARLY_STOPPING_PATIENCE = 5
LOAD_BEST_MODEL_AT_END = True
METRIC_FOR_BEST_MODEL = "eval_loss"

AVAILABLE_MODALITIES = ['Image', 'Audio']
MODALITY_ENCODERS = {
    'Image': 'google/vit-base-patch16-224-in21k',
    'Audio': 'openai/whisper-base'
}

DEFAULT_PEFT_CONFIG_DICT = {
    "task_type": TaskType.CAUSAL_LM if _peft_installed else None,
    "inference_mode": False,
    "r": 8,
    "lora_alpha": 32,
    "lora_dropout": 0.1,
    "target_modules": None
} if _peft_installed else {}

# Global state shared across the Gradio callbacks.
global_model = None
global_tokenizer = None
global_pipe = None
original_num_layers_global = LAYERS
config = None
target_layers = LAYERS
current_peft_config = copy.deepcopy(DEFAULT_PEFT_CONFIG_DICT) if _peft_installed else {}

RAM_LIMIT_PERCENT = 85.0
DISK_LIMIT_GB = 5.0
BYPASS_RESOURCE_LIMITS = False


class RMSNorm(nn.Module):
    """Root-mean-square layer normalization (no mean subtraction, optional learnable scale)."""

    def __init__(self, dim: int, eps: float = 1e-6, elementwise_affine=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.dim = dim
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(dim, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
        self.reset_parameters()

    def reset_parameters(self):
        if self.elementwise_affine:
            nn.init.ones_(self.weight)

    def _norm(self, x):
        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)

    def forward(self, x):
        output = self._norm(x.float()).type_as(x)
        if self.elementwise_affine:
            output = output * self.weight
        return output

    def extra_repr(self):
        return f'{self.dim}, eps={self.eps}, elementwise_affine={self.elementwise_affine}'


def activation_quant(x):
    # Per-token absmax quantization to 8 bits: snap to a [-128, 127] grid, then rescale.
    scale = 127.0 / x.abs().max(dim=-1, keepdim=True).values.clamp(min=1e-5)
    y = (x * scale).round().clamp(-128, 127) / scale
    return y


def weight_quant(w):
    # Ternary (1.58-bit) weight quantization: snap to {-1, 0, +1} times a per-tensor scale.
    scale = 1.0 / w.abs().mean().clamp(min=1e-5)
    u = (w * scale).round().clamp(-1, 1) / scale
    return u


class BitLinear(nn.Linear):
    """nn.Linear that quantizes activations and weights on the fly (BitNet-style).

    Full-precision weights remain the learnable parameters; quantization is applied in
    forward() with a straight-through estimator so gradients flow to the latent weights.
    """

    def forward(self, x):
        w = self.weight
        device = w.device
        if x.device != device:
            x = x.to(device)
        # Normalize activations with an affine-less RMSNorm before quantizing them.
        rms_norm_module = RMSNorm(x.shape[-1], eps=1e-6, elementwise_affine=False, device=device, dtype=x.dtype)
        x_norm = rms_norm_module(x)
        # Straight-through estimator: quantized values in the forward pass,
        # identity gradient in the backward pass.
        x_quant = x_norm + (activation_quant(x_norm) - x_norm).detach()
        w_quant = w + (weight_quant(w) - w).detach()
        output = F.linear(x_quant, w_quant, None)
        if self.bias is not None:
            output = output + self.bias.to(output.dtype)
        return output

    def to(self, *args, **kwargs):
        # nn.Module.to() already moves parameters (including bias); reassigning
        # self.bias with a plain tensor would raise, so simply delegate.
        return super().to(*args, **kwargs)


class BypassLayerNorm(nn.Module):
    """LayerNorm whose normalization can be switched off at runtime via the `bypass` flag."""

    def __init__(self, normalized_shape, eps=1e-5, elementwise_affine=True, device=None, dtype=None):
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        if isinstance(normalized_shape, int):
            self.normalized_shape = (normalized_shape,)
        else:
            self.normalized_shape = tuple(normalized_shape)
        self.eps = eps
        self.elementwise_affine = elementwise_affine
        if self.elementwise_affine:
            self.weight = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
            self.bias = nn.Parameter(torch.empty(self.normalized_shape, **factory_kwargs))
        else:
            self.register_parameter('weight', None)
            self.register_parameter('bias', None)
        self.bypass = False
        self.reset_parameters()

    def reset_parameters(self) -> None:
        if self.elementwise_affine:
            nn.init.ones_(self.weight)
            nn.init.zeros_(self.bias)

    def forward(self, x):
        if self.bypass:
            return x
        original_dtype = x.dtype
        if original_dtype not in [torch.float32, torch.float16, torch.bfloat16]:
            x = x.float()
        output = F.layer_norm(x, self.normalized_shape, self.weight, self.bias, self.eps)
        return output.to(original_dtype)

    def extra_repr(self) -> str:
        return f'{self.normalized_shape}, eps={self.eps}, elementwise_affine={self.elementwise_affine}, bypass={self.bypass}'


class BypassDropout(nn.Module):
    """Dropout whose effect can be switched off at runtime via the `bypass` flag."""

    def __init__(self, p=0.5, inplace=False):
        super().__init__()
        self.p = p
        self.inplace = inplace
        self.bypass = False

    def forward(self, x):
        if self.bypass or not self.training or self.p == 0:
            return x
        return F.dropout(x, self.p, self.training, self.inplace)

    def extra_repr(self) -> str:
        return f'p={self.p}, inplace={self.inplace}, bypass={self.bypass}'


def get_device():
    if torch.cuda.is_available() and not USE_CPU:
        return torch.device("cuda")
    elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and not USE_CPU:
        logging.info("MPS backend detected on Mac.
Note: MPS support is experimental and may have limitations.") return torch.device("mps") else: if not USE_CPU: logging.warning("CUDA/MPS not available or USE_CPU=True. Falling back to CPU.") return torch.device("cpu") def clean_memory(): gc.collect() if torch.cuda.is_available(): torch.cuda.empty_cache() logging.debug("Cleaned memory.") def check_resources(ram_limit_pct=RAM_LIMIT_PERCENT, disk_limit_gb=DISK_LIMIT_GB): if BYPASS_RESOURCE_LIMITS: logging.info("Resource limit checks bypassed.") return True, "Resource checks bypassed." try: ram = psutil.virtual_memory() ram_used_pct = ram.percent ram_ok = ram_used_pct < ram_limit_pct disk = psutil.disk_usage('/') disk_free_gb = disk.free / (1024**3) disk_ok = disk_free_gb > disk_limit_gb status_msg = (f"Resource Check: RAM Used: {ram_used_pct:.1f}% (Limit: <{ram_limit_pct}%), " f"Disk Free: {disk_free_gb:.1f}GB (Limit: >{disk_limit_gb}GB).") if ram_ok and disk_ok: logging.info(status_msg + " Status: OK") return True, status_msg + " Status: OK" else: warning_msg = status_msg + " Status: LIMIT EXCEEDED!" logging.warning(warning_msg) return False, warning_msg except Exception as e: logging.error(f"Failed to check resources: {e}") return True, f"Resource check failed: {e}" def initialize_config_flags(existing_config=None): if existing_config is None: from transformers import PretrainedConfig config_obj = PretrainedConfig() elif isinstance(existing_config, dict): from transformers import PretrainedConfig try: config_obj = PretrainedConfig(**existing_config) except Exception as e: logging.warning(f"Could not initialize PretrainedConfig from dict, using default. Error: {e}") config_obj = PretrainedConfig() else: config_obj = existing_config default_flags = { "reduced_layers": False, "original_num_layers": None, "removed_bias": False, "untied_embeddings": False, "limits_configured": False, "qa_restrictions_removed": False, "additional_mechanisms_applied": False, "safety_settings_enabled": True, "perfect_precision_recovered": False, "token_gen_speed_maximized": False, "coherence_improvement_enabled": False, "inconsistencies_biases_removed": False, "quantization_applied": False, "quantization_mode": DEFAULT_QUANTIZATION, "layer_norm_bypassed": False, "replaced_layer_norm": False, "dropout_bypassed": False, "replaced_dropout": False, "activation_function_swapped": False, "current_activation_function": DEFAULT_ACTIVATION_FUNCTION, "embedding_normalized": False, "gradient_clipping_disabled": False, "weight_decay_disabled": False, "lr_scheduler_disabled": False, "bitnet_applied": False, "gradient_checkpointing_enabled": False, "pruning_applied": False, "pruning_amount": None, "frozen_layers": None, "enhanced_security_enabled": False, "debug_mode_enabled": False, "auto_optimization_enabled": False, "internal_logging_enabled": False, "drift_detection_enabled": False, "ultra_fast_mode": False, "optimizer": DEFAULT_OPTIMIZER, "rms_norm_applied": False, "layerdrop_enabled": False, "layerdrop_prob": 0.0, "lora_merged": False, "lora_adapter_path": None, "knowledge_distillation_setup": False, "kd_num_labels": None, "reward_modeling_setup": False, "rm_num_outputs": 1, "swa_applied": False, "knowledge_edited": False, "head_pruning_applied": False, "qat_applied": False, "architecture_merged": False, "weight_init_applied": False, "gradient_noise_applied": False, "rope_scaling_type": None, "rope_scaling_factor": None, "sliding_window_size": None, "attention_variant": None, "response_filters": True, "harassment_filter": True, "hate_filter": True, 
"sexually_explicit_filter": True, "dangerous_content_filter": True, "civic_integrity_filter": True, "code_filter": True, "medical_advice_filter": True, "legal_advice_filter": True, "financial_advice_filter": True, "pii_filter": True, "political_filter": True, "religious_filter": True, "profanity_filter": True, "stereotype_filter": True, "misinfo_filter": True, "self_harm_filter": True, "personal_attack_filter": True, "toxicity_filter": True, "spam_filter": True, "off_topic_filter": True, "tone_filter": True, "min_max_length_filter": True, "repetition_filter_enabled": False, "factuality_filter_enabled": False, "baseline_distribution": None, "remove_censorship": False, "no_response_filters": False, "no_advert_warning": False, "no_limits": False, "knowledge_date": None, "cutoff_date": None, "max_input_tokens": None, "max_output_tokens": None, "multimodal_applied": False, "supported_modalities": [], "modality_encoders": {}, "modality_projection_dim": None, "modality_special_tokens": {}, "use_flash_attention_2": getattr(config_obj, 'attn_implementation', None) == 'flash_attention_2', "attn_implementation": getattr(config_obj, 'attn_implementation', 'auto'), "gradient_accumulation_steps": GRADIENT_ACCUMULATION_STEPS, "peft_adapter_added": False, "peft_config": None } for flag, default_value in default_flags.items(): if not hasattr(config_obj, flag): setattr(config_obj, flag, default_value) if getattr(config_obj, 'attn_implementation', 'auto') == 'flash_attention_2': config_obj.use_flash_attention_2 = True else: config_obj.use_flash_attention_2 = False if getattr(config_obj, 'quantization_mode', DEFAULT_QUANTIZATION) == 'float32': config_obj.quantization_applied = False config_obj.perfect_precision_recovered = True else: config_obj.quantization_applied = True config_obj.perfect_precision_recovered = False if _peft_installed and isinstance(existing_config, PeftConfig): config_obj.peft_adapter_added = True config_obj.peft_config = existing_config.to_dict() return config_obj def _recursive_setattr(obj, attr_str, value): parts = attr_str.split('.') obj_to_set = obj try: for part in parts[:-1]: if not hasattr(obj_to_set, part): logging.warning(f"Intermediate attribute {part} not found in {attr_str} for object {type(obj_to_set)}") return False obj_to_set = getattr(obj_to_set, part) if obj_to_set is None: logging.warning(f"Intermediate attribute {part} is None in {attr_str}") return False if hasattr(obj_to_set, parts[-1]): setattr(obj_to_set, parts[-1], value) return True else: logging.warning(f"Final attribute {parts[-1]} not found in {attr_str} on object {type(obj_to_set)}") return False except AttributeError as e: logging.error(f"AttributeError setting {attr_str}: {e}") return False except Exception as e: logging.error(f"Generic error setting attribute {attr_str}: {e}") return False def _get_encoder_hidden_size(encoder_model_id, trust_remote_code=True): try: encoder_config = AutoConfig.from_pretrained(encoder_model_id, trust_remote_code=trust_remote_code) possible_keys = ['hidden_size', 'd_model', 'embed_dim'] for key in possible_keys: if hasattr(encoder_config, key): size = getattr(encoder_config, key) if isinstance(size, int) and size > 0: return size nested_configs = ['vision_config', 'audio_config', 'encoder'] for nested_name in nested_configs: if hasattr(encoder_config, nested_name): nested_cfg = getattr(encoder_config, nested_name) if nested_cfg and isinstance(nested_cfg, object): for key in possible_keys: if hasattr(nested_cfg, key): size = getattr(nested_cfg, key) if isinstance(size, int) and 
size > 0: return size raise ValueError(f"Could not automatically determine hidden/embedding size for encoder {encoder_model_id}. Checked attributes: {possible_keys} and nested configs: {nested_configs}.") except Exception as e: logging.error(f"Failed to get config or hidden size for encoder {encoder_model_id}: {e}") raise ValueError(f"Failed to get config or hidden size for encoder {encoder_model_id}") from e def convert_to_bitnet(model, config, copy_weights=True): if not hasattr(nn, 'RMSNorm'): logging.warning("BitNet conversion requires RMSNorm, which might not be standard. Using custom RMSNorm.") device = get_device() converted_count = 0 modules_to_process = list(model.named_modules()) processed_names = set() with torch.no_grad(): for name, module in modules_to_process: if name in processed_names: continue is_target_linear = isinstance(module, nn.Linear) and ( any(sub in name.lower() for sub in ["attn", "mlp", "fc", "dense", "query", "key", "value", "out", "wi", "wo"]) and "norm" not in name.lower() and "embedding" not in name.lower() ) if is_target_linear: try: current_dtype = module.weight.dtype if hasattr(module, 'weight') and module.weight is not None else torch.float32 has_bias = module.bias is not None bl = BitLinear(module.in_features, module.out_features, has_bias).to(device=device, dtype=current_dtype) if copy_weights and hasattr(module, 'weight') and module.weight is not None: if bl.weight.shape == module.weight.shape: bl.weight.data.copy_(module.weight.data) else: logging.warning(f"Shape mismatch for weight {name}: Expected {bl.weight.shape}, got {module.weight.shape}. Skipping weight copy.") if has_bias and bl.bias is not None: if bl.bias.shape == module.bias.shape: bl.bias.data.copy_(module.bias.data) else: logging.warning(f"Shape mismatch for bias {name}: Expected {bl.bias.shape}, got {module.bias.shape}. Skipping bias copy.") elif not has_bias and bl.bias is not None: nn.init.zeros_(bl.bias) elif has_bias and bl.bias is None: logging.warning(f"Module {name} had bias, but BitLinear does not. Bias info lost.") elif not copy_weights: nn.init.xavier_uniform_(bl.weight) if bl.bias is not None: nn.init.zeros_(bl.bias) if _recursive_setattr(model, name, bl): converted_count += 1 processed_names.add(name) logging.debug(f"Converted layer {name} to BitLinear.") else: logging.warning(f"Failed to set BitLinear for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error replacing {name} with BitLinear: {e} \n{traceback.format_exc()}") processed_names.add(name) if converted_count > 0: config.bitnet_applied = True logging.info(f"Applied BitNet conversion to {converted_count} linear layers.") return f"Applied BitNet conversion to {converted_count} linear layers." else: logging.info("No applicable linear layers found or converted for BitNet.") config.bitnet_applied = False return "No applicable layers found for BitNet conversion." def revert_bitnet(model, config): if not getattr(config, 'bitnet_applied', False): return "BitNet not applied according to config, nothing to revert." 
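# Note (added): BitLinear keeps full-precision latent weights and only quantizes them
# on the fly in forward(), so the reversion below simply copies weight/bias back into a
# plain nn.Linear -- no dequantization step is required.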
device = get_device() model.to(device) reverted_count = 0 modules_to_process = list(model.named_modules()) processed_names = set() with torch.no_grad(): for name, module in modules_to_process: if name in processed_names: continue if isinstance(module, BitLinear): try: dtype = module.weight.dtype if hasattr(module, 'weight') and module.weight is not None else torch.float32 has_bias = module.bias is not None lin = nn.Linear(module.in_features, module.out_features, bias=has_bias).to(device, dtype=dtype) if hasattr(module, 'weight') and module.weight is not None: if lin.weight.shape == module.weight.shape: lin.weight.data.copy_(module.weight.data) else: logging.warning(f"Shape mismatch reverting weight {name}: Expected {lin.weight.shape}, got {module.weight.shape}. Re-initializing nn.Linear weight.") nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5)) else: nn.init.kaiming_uniform_(lin.weight, a=math.sqrt(5)) if has_bias and lin.bias is not None: if lin.bias.shape == module.bias.shape: lin.bias.data.copy_(module.bias.data) else: logging.warning(f"Shape mismatch reverting bias {name}: Expected {lin.bias.shape}, got {module.bias.shape}. Re-initializing nn.Linear bias.") fan_in, _ = nn.init._calculate_fan_in_and_fan_out(lin.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(lin.bias, -bound, bound) elif has_bias and lin.bias is None: logging.error(f"BitLinear layer {name} had bias, but reverted nn.Linear does not. Reversion failed for bias.") elif not has_bias and lin.bias is not None: logging.error(f"BitLinear layer {name} lacked bias, but reverted nn.Linear has one. Setting to zero.") nn.init.zeros_(lin.bias) if _recursive_setattr(model, name, lin): reverted_count += 1 processed_names.add(name) logging.debug(f"Reverted BitLinear layer {name} to nn.Linear.") else: logging.warning(f"Failed to revert BitLinear for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error reverting BitLinear {name}: {e} \n{traceback.format_exc()}") processed_names.add(name) if reverted_count > 0: config.bitnet_applied = False logging.info(f"Reverted {reverted_count} BitNet layers to standard nn.Linear.") return f"Reverted {reverted_count} BitNet layers." else: config.bitnet_applied = False logging.info("No BitNet layers found to revert, but flag was true. Resetting flag.") return "No BitNet layers found to revert." def _find_decoder_layers_module(model): prefixes = [ ('model.decoder', 'layers'), ('model.layers', None), ('transformer.h', None), ('transformer.blocks', None), ('encoder.layer', None), ('model.encoder.layers', None), ('', 'layers'), ('', 'h'), ('model', 'layers'), ('decoder.block', None), ('decoder.layers', None) ] if hasattr(model, 'model'): base_obj = model.model else: base_obj = model direct_attrs = ['layers', 'h', 'blocks', 'block'] for attr in direct_attrs: if hasattr(base_obj, attr): layer_list = getattr(base_obj, attr) if isinstance(layer_list, nn.ModuleList) and len(layer_list) > 0: logging.info(f"Found layer list at 'model.{attr}' or '{attr}' with {len(layer_list)} layers.") return base_obj, attr, layer_list elif isinstance(layer_list, (list, tuple)) and len(layer_list) > 0 and isinstance(layer_list[0], nn.Module): logging.warning(f"Found layers as list/tuple at 'model.{attr}' or '{attr}'. 
Converting to ModuleList.") setattr(base_obj, attr, nn.ModuleList(layer_list)) return base_obj, attr, getattr(base_obj, attr) for p_base, attr_name_explicit in prefixes: mod = model valid_path = True if p_base: for comp in p_base.split('.'): if not hasattr(mod, comp): valid_path = False break mod = getattr(mod, comp) if mod is None: valid_path = False break if not valid_path: continue attrs_to_check = [attr_name_explicit] if attr_name_explicit else ['layers', 'h', 'blocks', 'block', 'layer'] for attr in attrs_to_check: if hasattr(mod, attr): layer_list = getattr(mod, attr) if isinstance(layer_list, nn.ModuleList) and len(layer_list) > 0: logging.info(f"Found layer list at '{p_base}.{attr}' with {len(layer_list)} layers.") return mod, attr, layer_list elif isinstance(layer_list, (list, tuple)) and len(layer_list) > 0 and isinstance(layer_list[0], nn.Module): logging.warning(f"Found layers as list/tuple at '{p_base}.{attr}'. Converting to ModuleList.") setattr(mod, attr, nn.ModuleList(layer_list)) return mod, attr, getattr(mod, attr) logging.warning("Could not automatically find the standard decoder/transformer layer list module.") return None, None, None def _reduce_layers_to_one(base_model, config, target_layers=1): if not isinstance(target_layers, int) or target_layers < 1: logging.error(f"Invalid target_layers value: {target_layers}. Must be an integer >= 1.") return f"Error: Target layers must be >= 1, got {target_layers}." layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model) if layer_module and layer_attr and current_layers is not None: current_len = len(current_layers) if current_len <= 0: logging.warning("Found layer attribute but the ModuleList is empty. Cannot reduce.") return "Warning: Layer list found but it's empty. Cannot reduce." if current_len > target_layers: logging.info(f"Reducing layers: {current_len} -> {target_layers}...") original_layer_count = current_len if not hasattr(config, 'original_num_layers') or config.original_num_layers is None or config.original_num_layers < current_len: config.original_num_layers = original_layer_count logging.info(f"Stored/Updated original layer count in config: {original_layer_count}") new_layer_list = nn.ModuleList(current_layers[:target_layers]) setattr(layer_module, layer_attr, new_layer_list) config.num_hidden_layers = target_layers config.reduced_layers = True if hasattr(config, 'n_layer'): config.n_layer = target_layers if hasattr(config, 'num_layers'): config.num_layers = target_layers if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = target_layers if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = target_layers logging.info(f"Successfully reduced layers to {target_layers}.") clean_memory() return f"Layers reduced to {target_layers}. Original count was: {original_layer_count}." elif current_len == target_layers: logging.info(f"Model already has {current_len} layers, matching the target {target_layers}. 
No reduction needed.") config.reduced_layers = False if current_len == getattr(config, 'original_num_layers', current_len) else True config.num_hidden_layers = current_len if hasattr(config, 'n_layer'): config.n_layer = current_len if hasattr(config, 'num_layers'): config.num_layers = current_len if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len return f"Model already has {current_len} layers (target {target_layers}). No reduction performed." else: logging.info(f"Model has {current_len} layers, which is less than the target {target_layers}. No reduction needed.") config.reduced_layers = True config.num_hidden_layers = current_len if hasattr(config, 'n_layer'): config.n_layer = current_len if hasattr(config, 'num_layers'): config.num_layers = current_len if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len return f"Model already has {current_len} layers (< target {target_layers}). No reduction performed." else: logging.warning("Could not find standard layer structure for reduction.") config.reduced_layers = False return "Warning: Could not find standard layer structure for reduction." def _enable_full_layers(base_model, config, original_num_layers=None): if not getattr(config, 'reduced_layers', False): layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model) current_len = len(current_layers) if current_layers is not None else 0 orig_len_config = getattr(config, 'original_num_layers', None) if current_len > 0 and orig_len_config is not None and current_len == orig_len_config: config.reduced_layers = False return "Layers already seem to be at the original count. Flag corrected if necessary." else: return "Layers not previously reduced according to config flag, or cannot verify current/original counts." orig_layers = original_num_layers if original_num_layers is not None else getattr(config, 'original_num_layers', None) if orig_layers is None: global original_num_layers_global orig_layers = original_num_layers_global if orig_layers is not None: logging.warning(f"Using globally stored original layer count: {orig_layers} as it was missing in config.") config.original_num_layers = orig_layers else: logging.error("Cannot restore layers: Original layer count is missing from config and global state.") return "Error: Cannot revert - Original layer count unknown." if not isinstance(orig_layers, int) or orig_layers <= 0: logging.error(f"Cannot restore layers: Invalid original layer count found ({orig_layers}).") return f"Error: Cannot revert - Invalid original layer count ({orig_layers})." layer_module, layer_attr, current_layers = _find_decoder_layers_module(base_model) if layer_module and layer_attr and current_layers is not None: current_len = len(current_layers) if current_len < orig_layers: logging.info(f"Restoring layers: {current_len} -> {orig_layers}..."); T = time.time() try: if current_len == 0: logging.error("Cannot restore layers: No existing layers found to copy structure from.") return "Error: Cannot restore layers - no template layer available." 
device = next(iter(current_layers[0].parameters()), torch.tensor([])).device template_layer = current_layers[0].to('cpu') layers_to_add = [] num_layers_to_add = orig_layers - current_len logging.info(f"Need to add {num_layers_to_add} layers.") for i in range(num_layers_to_add): new_layer = copy.deepcopy(template_layer) for _, sub_module in new_layer.named_modules(): if hasattr(sub_module, 'reset_parameters'): try: sub_module.reset_parameters() except Exception as reset_e: logging.warning(f"Could not reset parameters for submodule {sub_module} in new layer {i}: {reset_e}") elif isinstance(sub_module, nn.Linear): nn.init.kaiming_uniform_(sub_module.weight, a=math.sqrt(5)) if sub_module.bias is not None: fan_in, _ = nn.init._calculate_fan_in_and_fan_out(sub_module.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(sub_module.bias, -bound, bound) elif isinstance(sub_module, nn.Embedding): nn.init.normal_(sub_module.weight) if sub_module.padding_idx is not None: with torch.no_grad(): sub_module.weight[sub_module.padding_idx].fill_(0) elif isinstance(sub_module, (nn.LayerNorm, RMSNorm, BypassLayerNorm)): if sub_module.elementwise_affine: if hasattr(sub_module, 'weight') and sub_module.weight is not None: nn.init.ones_(sub_module.weight) if hasattr(sub_module, 'bias') and sub_module.bias is not None: nn.init.zeros_(sub_module.bias) new_layer = new_layer.to(device) layers_to_add.append(new_layer) full_layer_list = nn.ModuleList(list(current_layers) + layers_to_add) setattr(layer_module, layer_attr, full_layer_list) config.num_hidden_layers = orig_layers config.reduced_layers = False if hasattr(config, 'n_layer'): config.n_layer = orig_layers if hasattr(config, 'num_layers'): config.num_layers = orig_layers if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = orig_layers if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = orig_layers msg = f"Restored layer structure to {orig_layers} layers in {time.time()-T:.2f}s." logging.info(msg) clean_memory() return msg except Exception as e: logging.error(f"Error restoring layers: {e}\n{traceback.format_exc()}") setattr(layer_module, layer_attr, current_layers) config.num_hidden_layers = current_len config.reduced_layers = True if hasattr(config, 'n_layer'): config.n_layer = current_len if hasattr(config, 'num_layers'): config.num_layers = current_len if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len return f"Error restoring layers: {e}. State might be inconsistent." else: config.reduced_layers = False config.num_hidden_layers = current_len if hasattr(config, 'n_layer'): config.n_layer = current_len if hasattr(config, 'num_layers'): config.num_layers = current_len if hasattr(config, 'num_decoder_layers'): config.num_decoder_layers = current_len if hasattr(config, 'num_encoder_layers') and 'encoder' in layer_attr: config.num_encoder_layers = current_len msg = f"Model already has {current_len} layers (>= original {orig_layers}). No restoration needed. Corrected flags if necessary." logging.info(msg) return msg elif layer_module and layer_attr and current_layers is None: logging.warning(f"Layer attribute '{layer_attr}' exists but is None or invalid. Cannot restore layers.") return "Warning: Layer attribute found but invalid. Cannot restore layers." 
else: logging.warning("Could not find standard layer structure for restoration.") return "Warning: Could not find standard layer structure for restoration." def _replace_linear_without_bias(module, config): device = get_device() replaced_count = 0 modules_to_process = list(module.named_modules()) processed_names = set() for name, child in modules_to_process: if name in processed_names: continue if isinstance(child, nn.Linear) and child.bias is not None: try: dtype = child.weight.dtype current_device = child.weight.device new_linear = nn.Linear(child.in_features, child.out_features, bias=False).to(device=current_device, dtype=dtype) with torch.no_grad(): if new_linear.weight.shape == child.weight.shape: new_linear.weight.copy_(child.weight) else: logging.warning(f"Shape mismatch removing bias for weight {name}: Expected {new_linear.weight.shape}, got {child.weight.shape}. Re-initializing.") nn.init.kaiming_uniform_(new_linear.weight, a=math.sqrt(5)) if _recursive_setattr(module, name, new_linear): replaced_count += 1 processed_names.add(name) logging.debug(f"Removed bias from layer {name}") else: logging.warning(f"Failed to set bias-less Linear for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error removing bias for layer {name}: {e}") processed_names.add(name) if replaced_count > 0: config.removed_bias = True logging.info(f"Removed bias from {replaced_count} linear layers.") return f"Removed bias from {replaced_count} linear layers." else: logging.info("No linear layers with bias found to modify.") return "No linear layers with bias found to modify." def _enable_bias_in_linear(module, config): if not getattr(config, 'removed_bias', False): return "Bias not previously removed according to config flag. Cannot enable (revert)." device = get_device() enabled_count = 0 modules_to_process = list(module.named_modules()) processed_names = set() for name, child in modules_to_process: if name in processed_names: continue if isinstance(child, nn.Linear) and child.bias is None: try: dtype = child.weight.dtype current_device = child.weight.device new_linear = nn.Linear(child.in_features, child.out_features, bias=True).to(device=current_device, dtype=dtype) with torch.no_grad(): if new_linear.weight.shape == child.weight.shape: new_linear.weight.copy_(child.weight) else: logging.warning(f"Shape mismatch enabling bias for weight {name}: Expected {new_linear.weight.shape}, got {child.weight.shape}. Re-initializing weight.") nn.init.kaiming_uniform_(new_linear.weight, a=math.sqrt(5)) fan_in, _ = nn.init._calculate_fan_in_and_fan_out(new_linear.weight) bound = 1 / math.sqrt(fan_in) if fan_in > 0 else 0 nn.init.uniform_(new_linear.bias, -bound, bound) if _recursive_setattr(module, name, new_linear): enabled_count += 1 processed_names.add(name) logging.debug(f"Enabled bias for layer {name}") else: logging.warning(f"Failed to set biased Linear for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error enabling bias for layer {name}: {e}") processed_names.add(name) if enabled_count > 0: config.removed_bias = False logging.info(f"Enabled (restored) bias for {enabled_count} linear layers.") return f"Enabled bias for {enabled_count} linear layers." else: config.removed_bias = False logging.info("No bias-less linear layers found to enable bias for. Resetting flag.") return "No bias-less linear layers found to enable bias for." 
def _replace_layer_norm_with_bypass(module, config): replaced_count = 0 modules_to_process = list(module.named_modules()) processed_names = set() for name, child in modules_to_process: if name in processed_names: continue if isinstance(child, nn.LayerNorm) and not isinstance(child, (BypassLayerNorm, RMSNorm)): try: child_device = get_device() child_dtype = torch.float32 if hasattr(child, 'weight') and child.weight is not None: child_device = child.weight.device child_dtype = child.weight.dtype elif hasattr(child, 'bias') and child.bias is not None: child_device = child.bias.device child_dtype = child.bias.dtype elif hasattr(child, '_parameters') and child._parameters: first_param = next(iter(child.parameters()), None) if first_param is not None: child_device = first_param.device child_dtype = first_param.dtype norm_shape = child.normalized_shape eps = child.eps affine = child.elementwise_affine new_layer_norm = BypassLayerNorm(norm_shape, eps, affine, device=child_device, dtype=child_dtype) if affine: with torch.no_grad(): if hasattr(child, 'weight') and child.weight is not None and new_layer_norm.weight is not None: if new_layer_norm.weight.shape == child.weight.shape: new_layer_norm.weight.copy_(child.weight) else: logging.warning(f"Shape mismatch replacing LN weight {name}. Expected {new_layer_norm.weight.shape}, got {child.weight.shape}. Initializing BypassLN weight.") nn.init.ones_(new_layer_norm.weight) elif new_layer_norm.weight is not None: nn.init.ones_(new_layer_norm.weight) if hasattr(child, 'bias') and child.bias is not None and new_layer_norm.bias is not None: if new_layer_norm.bias.shape == child.bias.shape: new_layer_norm.bias.copy_(child.bias) else: logging.warning(f"Shape mismatch replacing LN bias {name}. Expected {new_layer_norm.bias.shape}, got {child.bias.shape}. Initializing BypassLN bias.") nn.init.zeros_(new_layer_norm.bias) elif new_layer_norm.bias is not None: nn.init.zeros_(new_layer_norm.bias) if _recursive_setattr(module, name, new_layer_norm): replaced_count += 1 processed_names.add(name) logging.debug(f"Replaced LayerNorm {name} with BypassLayerNorm.") else: logging.warning(f"Failed to set BypassLayerNorm for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error replacing LayerNorm {name} with Bypass version: {e}\n{traceback.format_exc()}") processed_names.add(name) if replaced_count > 0: config.replaced_layer_norm = True config.layer_norm_bypassed = False logging.info(f"Replaced {replaced_count} LayerNorm layers with Bypass version.") return f"Replaced {replaced_count} LayerNorm layers with Bypass version." else: logging.info("No standard nn.LayerNorm layers found to replace with BypassLayerNorm.") return "No standard LayerNorm layers found to replace." def _revert_bypass_layer_norm(module, config): if not getattr(config, 'replaced_layer_norm', False): return "BypassLayerNorm not previously applied according to config flag. Cannot revert." 
reverted_count = 0 modules_to_process = list(module.named_modules()) processed_names = set() for name, child in modules_to_process: if name in processed_names: continue if isinstance(child, BypassLayerNorm): try: child_device = get_device() child_dtype = torch.float32 if child.elementwise_affine: if child.weight is not None: child_device = child.weight.device child_dtype = child.weight.dtype elif child.bias is not None: child_device = child.bias.device child_dtype = child.bias.dtype else: pass norm_shape = child.normalized_shape eps = child.eps affine = child.elementwise_affine if isinstance(norm_shape, tuple) and len(norm_shape) == 1: norm_arg = norm_shape[0] elif isinstance(norm_shape, (list, tuple)): norm_arg = list(norm_shape) elif isinstance(norm_shape, int): norm_arg = norm_shape else: raise ValueError(f"Unsupported normalized_shape type for nn.LayerNorm: {type(norm_shape)}") new_layer_norm = nn.LayerNorm(norm_arg, eps, affine, device=child_device, dtype=child_dtype) if affine: with torch.no_grad(): if hasattr(child, 'weight') and child.weight is not None and new_layer_norm.weight is not None: if new_layer_norm.weight.shape == child.weight.shape: new_layer_norm.weight.copy_(child.weight) else: logging.warning(f"Shape mismatch reverting BypassLN weight {name}. Expected {new_layer_norm.weight.shape}, got {child.weight.shape}. Initializing LayerNorm weight.") nn.init.ones_(new_layer_norm.weight) elif new_layer_norm.weight is not None: nn.init.ones_(new_layer_norm.weight) if hasattr(child, 'bias') and child.bias is not None and new_layer_norm.bias is not None: if new_layer_norm.bias.shape == child.bias.shape: new_layer_norm.bias.copy_(child.bias) else: logging.warning(f"Shape mismatch reverting BypassLN bias {name}. Expected {new_layer_norm.bias.shape}, got {child.bias.shape}. Initializing LayerNorm bias.") nn.init.zeros_(new_layer_norm.bias) elif new_layer_norm.bias is not None: nn.init.zeros_(new_layer_norm.bias) if _recursive_setattr(module, name, new_layer_norm): reverted_count += 1 processed_names.add(name) logging.debug(f"Reverted BypassLayerNorm {name} to standard nn.LayerNorm.") else: logging.warning(f"Failed to revert BypassLayerNorm for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error reverting BypassLayerNorm {name} to standard LayerNorm: {e}\n{traceback.format_exc()}") processed_names.add(name) if reverted_count > 0: config.replaced_layer_norm = False config.layer_norm_bypassed = False logging.info(f"Reverted {reverted_count} BypassLayerNorm layers back to standard nn.LayerNorm.") return f"Reverted {reverted_count} BypassLayerNorm layers." else: config.replaced_layer_norm = False config.layer_norm_bypassed = False logging.info("No BypassLayerNorm layers found to revert. Resetting flags.") return "No BypassLayerNorm layers found to revert." def _enable_layer_norm_bypass(model): count = 0 found_bypass_layers = False for m in model.modules(): if isinstance(m, BypassLayerNorm): found_bypass_layers = True if not m.bypass: m.bypass = True count += 1 if not found_bypass_layers: if getattr(model.config, 'replaced_layer_norm', False): logging.warning("Config indicates LN were replaced with BypassLN, but none found. Cannot enable bypass.") model.config.layer_norm_bypassed = False return "Replaced LN flag is true, but no BypassLN layers found. Run 'Replace LN' first or revert." else: return "No BypassLayerNorm layers found in the model. Replace standard LayerNorm first to enable bypass functionality." 
elif count > 0: model.config.layer_norm_bypassed = True logging.info(f"Enabled bypass for {count} BypassLayerNorm layers.") return f"Enabled bypass for {count} LN layers." else: model.config.layer_norm_bypassed = True logging.info("All existing BypassLayerNorm layers already have bypass enabled.") return "No changes made (layers might already be bypassed)." def _disable_layer_norm_bypass(model): count = 0 found_bypass_layers = False for m in model.modules(): if isinstance(m, BypassLayerNorm): found_bypass_layers = True if m.bypass: m.bypass = False count += 1 if not found_bypass_layers: if getattr(model.config, 'replaced_layer_norm', False): model.config.layer_norm_bypassed = False return "Replaced LN flag is true, but no BypassLN layers found to disable bypass on." else: return "No BypassLayerNorm layers found in the model to disable bypass on." elif count > 0: model.config.layer_norm_bypassed = False logging.info(f"Disabled bypass for {count} BypassLayerNorm layers.") return f"Disabled bypass for {count} LN layers." else: model.config.layer_norm_bypassed = False logging.info("All existing BypassLayerNorm layers already have bypass disabled.") return "No changes made (layers might already have bypass disabled)." def _replace_dropout_with_bypass(module, config): replaced_count = 0 modules_to_process = list(module.named_modules()) processed_names = set() for name, child in modules_to_process: if name in processed_names: continue if type(child) == nn.Dropout: try: new_dropout = BypassDropout(child.p, child.inplace) try: parent_name = '.'.join(name.split('.')[:-1]) parent_module = module.get_submodule(parent_name) if parent_name else module first_param = next(iter(parent_module.parameters()), None) if first_param is not None: new_dropout.to(device=first_param.device) except Exception: new_dropout.to(device=get_device()) if _recursive_setattr(module, name, new_dropout): replaced_count += 1 processed_names.add(name) logging.debug(f"Replaced Dropout {name} with BypassDropout.") else: logging.warning(f"Failed to set BypassDropout for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error replacing Dropout {name} with Bypass version: {e}") processed_names.add(name) if replaced_count > 0: config.replaced_dropout = True config.dropout_bypassed = False logging.info(f"Replaced {replaced_count} nn.Dropout layers with BypassDropout version.") return f"Replaced {replaced_count} Dropout layers." else: logging.info("No standard nn.Dropout layers found to replace with BypassDropout.") return "No standard Dropout layers found to replace." def _revert_bypass_dropout(module, config): if not getattr(config, 'replaced_dropout', False): return "BypassDropout not previously applied according to config flag. Cannot revert." 
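# Note (added): the reversion below recreates plain nn.Dropout modules with the same
# p/inplace settings; the runtime `bypass` state of the replaced modules is discarded.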
reverted_count = 0 modules_to_process = list(module.named_modules()) processed_names = set() for name, child in modules_to_process: if name in processed_names: continue if isinstance(child, BypassDropout): try: new_dropout = nn.Dropout(child.p, child.inplace) try: parent_name = '.'.join(name.split('.')[:-1]) parent_module = module.get_submodule(parent_name) if parent_name else module first_param = next(iter(parent_module.parameters()), None) if first_param is not None: new_dropout.to(device=first_param.device) except Exception: new_dropout.to(device=get_device()) if _recursive_setattr(module, name, new_dropout): reverted_count += 1 processed_names.add(name) logging.debug(f"Reverted BypassDropout {name} to standard nn.Dropout.") else: logging.warning(f"Failed to revert BypassDropout for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error reverting BypassDropout {name} to standard nn.Dropout: {e}") processed_names.add(name) if reverted_count > 0: config.replaced_dropout = False config.dropout_bypassed = False logging.info(f"Reverted {reverted_count} BypassDropout layers back to standard nn.Dropout.") return f"Reverted {reverted_count} BypassDropout layers." else: config.replaced_dropout = False config.dropout_bypassed = False logging.info("No BypassDropout layers found to revert. Resetting flags.") return "No BypassDropout layers found to revert." def _enable_dropout_bypass(model): count = 0 found_bypass_layers = False for m in model.modules(): if isinstance(m, BypassDropout): found_bypass_layers = True if not m.bypass: m.bypass = True count += 1 if not found_bypass_layers: if getattr(model.config, 'replaced_dropout', False): model.config.dropout_bypassed = False return "Replaced Dropout flag is true, but no BypassDropout layers found. Run 'Replace DO' first or revert." else: return "No BypassDropout layers found in the model. Replace standard Dropout first to enable bypass." elif count > 0: model.config.dropout_bypassed = True logging.info(f"Enabled bypass for {count} BypassDropout layers.") return f"Enabled bypass for {count} Dropout layers." else: model.config.dropout_bypassed = True logging.info("All existing BypassDropout layers already have bypass enabled.") return "No changes made (layers might already be bypassed)." def _disable_dropout_bypass(model): count = 0 found_bypass_layers = False for m in model.modules(): if isinstance(m, BypassDropout): found_bypass_layers = True if m.bypass: m.bypass = False count += 1 if not found_bypass_layers: if getattr(model.config, 'replaced_dropout', False): model.config.dropout_bypassed = False return "Replaced Dropout flag is true, but no BypassDropout layers found to disable bypass on." else: return "No BypassDropout layers found in the model to disable bypass on." elif count > 0: model.config.dropout_bypassed = False logging.info(f"Disabled bypass for {count} BypassDropout layers.") return f"Disabled bypass for {count} Dropout layers." else: model.config.dropout_bypassed = False logging.info("All existing BypassDropout layers already have bypass disabled.") return "No changes made (layers might already have bypass disabled)." def _swap_activation_function(model, config, activation_fn_name): activation_fn_class = ACTIVATION_FUNCTIONS.get(activation_fn_name) if not activation_fn_class: msg = f"Warning: Activation function '{activation_fn_name}' not found or invalid. Using default '{DEFAULT_ACTIVATION_FUNCTION}'." 
logging.warning(msg) activation_fn_class = ACTIVATION_FUNCTIONS[DEFAULT_ACTIVATION_FUNCTION] activation_fn_name = DEFAULT_ACTIVATION_FUNCTION if not activation_fn_class: logging.error(f"Default activation function '{DEFAULT_ACTIVATION_FUNCTION}' is also missing! Cannot swap.") return f"Error: Cannot find '{activation_fn_name}' or the default '{DEFAULT_ACTIVATION_FUNCTION}'." else: msg = "" replaced_count = 0 current_act_classes = tuple(f for f in ACTIVATION_FUNCTIONS.values() if f is not None and inspect.isclass(f) and issubclass(f, nn.Module)) target_act_class = activation_fn_class modules_to_process = list(model.named_modules()) processed_names = set() for name, child in modules_to_process: if name in processed_names: continue if type(child) in current_act_classes: if type(child) == target_act_class: processed_names.add(name) continue try: new_activation = target_act_class() try: parent_name = '.'.join(name.split('.')[:-1]) parent_module = model.get_submodule(parent_name) if parent_name else module first_param = next(iter(parent_module.parameters()), None) if first_param is not None: new_activation.to(device=first_param.device) except Exception: new_activation.to(device=get_device()) if _recursive_setattr(model, name, new_activation): replaced_count += 1 processed_names.add(name) logging.debug(f"Swapped activation {name} from {type(child).__name__} to {target_act_class.__name__}") else: logging.warning(f"Failed to set new activation function for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error replacing activation function {name} of type {type(child).__name__} with {target_act_class.__name__}: {e}") processed_names.add(name) if replaced_count > 0: msg += f"Swapped {replaced_count} activation functions to {activation_fn_name}." config.activation_function_swapped = True config.current_activation_function = activation_fn_name if hasattr(config, 'hidden_act'): config.hidden_act = activation_fn_name if hasattr(config, 'activation_function'): config.activation_function = activation_fn_name else: msg += f"No eligible activation functions found to swap to {activation_fn_name} (or already using it)." if not config.activation_function_swapped: current_in_config = getattr(config, 'hidden_act', getattr(config, 'activation_function', DEFAULT_ACTIVATION_FUNCTION)) config.current_activation_function = current_in_config if current_in_config in ACTIVATION_FUNCTIONS else DEFAULT_ACTIVATION_FUNCTION logging.info(msg) return msg def _revert_activation_function(model, config): current_activation = getattr(config, 'current_activation_function', DEFAULT_ACTIVATION_FUNCTION) was_swapped = getattr(config, 'activation_function_swapped', False) if not was_swapped and current_activation == DEFAULT_ACTIVATION_FUNCTION: return f"Activation function is already the default ('{DEFAULT_ACTIVATION_FUNCTION}') and was not marked as swapped." elif not was_swapped: logging.info(f"Activation function is '{current_activation}' but wasn't marked as swapped. 
Attempting to revert to '{DEFAULT_ACTIVATION_FUNCTION}' anyway.") pass else: logging.info(f"Reverting activation function from '{current_activation}' to default '{DEFAULT_ACTIVATION_FUNCTION}'...") result_msg = _swap_activation_function(model, config, DEFAULT_ACTIVATION_FUNCTION) config.activation_function_swapped = False config.current_activation_function = DEFAULT_ACTIVATION_FUNCTION if hasattr(config, 'hidden_act'): config.hidden_act = DEFAULT_ACTIVATION_FUNCTION if hasattr(config, 'activation_function'): config.activation_function = DEFAULT_ACTIVATION_FUNCTION final_msg = f"Reverted to default activation ('{DEFAULT_ACTIVATION_FUNCTION}'). Result: {result_msg}" return final_msg def _swap_normalization_layer(model, config, target_norm_type='RMSNorm'): device = get_device() swapped_count = 0 processed_names = set() if target_norm_type == 'RMSNorm': current_norm_class = nn.LayerNorm new_norm_class = RMSNorm config_flag_name = 'rms_norm_applied' target_flag_value = True elif target_norm_type == 'LayerNorm': current_norm_class = RMSNorm new_norm_class = nn.LayerNorm config_flag_name = 'rms_norm_applied' target_flag_value = False else: msg = f"Error: Unsupported target normalization type '{target_norm_type}'. Use 'RMSNorm' or 'LayerNorm'." logging.error(msg) return msg already_configured = getattr(config, config_flag_name, False) == target_flag_value has_current_norm_instances = any(isinstance(m, current_norm_class) for name, m in model.named_modules() if not isinstance(m, (BypassLayerNorm, new_norm_class))) if already_configured and not has_current_norm_instances: logging.info(f"Model config flag '{config_flag_name}' is already {target_flag_value}, and no instances of {current_norm_class.__name__} found to swap. No action needed.") return f"Model already configured for {target_norm_type} (or no swappable layers found)." elif already_configured and has_current_norm_instances: logging.warning(f"Model config flag '{config_flag_name}' is {target_flag_value}, but instances of {current_norm_class.__name__} were found. Attempting swap anyway to ensure consistency.") pass elif not already_configured and not has_current_norm_instances: logging.info(f"No instances of {current_norm_class.__name__} found to swap to {target_norm_type}. Updating config flag to {target_flag_value}.") setattr(config, config_flag_name, target_flag_value) if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False return f"No {current_norm_class.__name__} layers found to swap. Config flag '{config_flag_name}' set to {target_flag_value}." modules_to_process = list(model.named_modules()) for name, module in modules_to_process: if name in processed_names: continue if isinstance(module, current_norm_class) and not isinstance(module, BypassLayerNorm): try: eps = module.eps elementwise_affine = module.elementwise_affine module_device = get_device() module_dtype = torch.float32 params = list(module.parameters()) if params: module_device = params[0].device module_dtype = params[0].dtype elif elementwise_affine and hasattr(module, 'weight') and module.weight is not None: module_device = module.weight.device module_dtype = module.weight.dtype dim = None if isinstance(module, nn.LayerNorm): dim = module.normalized_shape elif isinstance(module, RMSNorm): if elementwise_affine and hasattr(module, 'weight') and module.weight is not None: dim = module.weight.shape[0] else: logging.warning(f"Cannot determine dimension for affine-less RMSNorm {name}. 
Cannot swap this layer.") processed_names.add(name) continue else: raise ValueError(f"Module {name} is unexpected type {type(module)} during norm swap.") if new_norm_class == nn.LayerNorm: if isinstance(dim, int): norm_arg = dim elif isinstance(dim, (list, tuple)): norm_arg = list(dim) else: raise ValueError(f"Unsupported dimension type {type(dim)} '{dim}' for creating LayerNorm from {current_norm_class.__name__} layer {name}.") new_norm = new_norm_class(norm_arg, eps=eps, elementwise_affine=elementwise_affine, device=module_device, dtype=module_dtype) elif new_norm_class == RMSNorm: if isinstance(dim, int): norm_arg = dim elif isinstance(dim, (list, tuple)): if len(dim) == 1: norm_arg = dim[0] else: logging.warning(f"LayerNorm shape {dim} has multiple dimensions. Using last dim ({dim[-1]}) for RMSNorm {name}.") norm_arg = dim[-1] else: raise ValueError(f"Unsupported dimension type {type(dim)} '{dim}' for creating RMSNorm from {current_norm_class.__name__} layer {name}.") new_norm = new_norm_class(norm_arg, eps=eps, elementwise_affine=elementwise_affine, device=module_device, dtype=module_dtype) else: raise ValueError("Invalid new_norm_class.") if elementwise_affine: with torch.no_grad(): if hasattr(module, 'weight') and module.weight is not None and hasattr(new_norm, 'weight') and new_norm.weight is not None: if new_norm.weight.shape == module.weight.shape: new_norm.weight.copy_(module.weight) else: logging.warning(f"Weight shape mismatch swapping norm {name}: {module.weight.shape} -> {new_norm.weight.shape}. Re-initializing target weight.") nn.init.ones_(new_norm.weight) elif hasattr(new_norm, 'weight') and new_norm.weight is not None: logging.debug(f"Initializing weight for new norm {name} as source lacked it.") nn.init.ones_(new_norm.weight) if hasattr(module, 'bias') and module.bias is not None and hasattr(new_norm, 'bias') and new_norm.bias is not None: if new_norm.bias.shape == module.bias.shape: new_norm.bias.copy_(module.bias) else: logging.warning(f"Bias shape mismatch swapping norm {name}: {module.bias.shape} -> {new_norm.bias.shape}. Re-initializing target bias.") nn.init.zeros_(new_norm.bias) elif hasattr(new_norm, 'bias') and new_norm.bias is not None: logging.debug(f"Initializing bias for new LayerNorm {name} as source RMSNorm lacked it.") nn.init.zeros_(new_norm.bias) if _recursive_setattr(model, name, new_norm): swapped_count += 1 processed_names.add(name) logging.debug(f"Swapped {current_norm_class.__name__} layer {name} to {new_norm_class.__name__}.") else: logging.warning(f"Failed to set swapped normalization layer for {name}") processed_names.add(name) except Exception as e: logging.error(f"Error swapping norm layer {name} from {current_norm_class.__name__} to {new_norm_class.__name__}: {e}\n{traceback.format_exc()}") processed_names.add(name) if swapped_count > 0: setattr(config, config_flag_name, target_flag_value) if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False msg = f"Swapped {swapped_count} {current_norm_class.__name__} layers to {new_norm_class.__name__}." else: if not already_configured: setattr(config, config_flag_name, target_flag_value) if hasattr(config, 'layer_norm_bypassed'): config.layer_norm_bypassed = False msg = f"No {current_norm_class.__name__} layers found or matched criteria to swap to {new_norm_class.__name__}. Updated config flag." else: msg = f"No {current_norm_class.__name__} layers were swapped (already configured or other issue)." 
logging.info(msg) return msg def _normalize_embeddings(model, config): emb_layer = None if hasattr(model, 'get_input_embeddings'): try: emb_layer_candidate = model.get_input_embeddings() if isinstance(emb_layer_candidate, nn.Embedding): emb_layer = emb_layer_candidate logging.info("Found embedding layer via get_input_embeddings()") except Exception as e: logging.warning(f"Error calling get_input_embeddings(): {e}") if emb_layer is None: potential_emb_names = ['embed_tokens', 'wte', 'word_embeddings', 'embeddings.word_embeddings', 'shared'] model_base = getattr(model, 'model', model) for name in potential_emb_names: try: candidate = model_base parts = name.split('.') valid_path = True for part in parts: if hasattr(candidate, part): candidate = getattr(candidate, part) if candidate is None: valid_path = False break else: valid_path = False break if valid_path and isinstance(candidate, nn.Embedding) and hasattr(candidate, 'weight') and candidate.weight is not None: emb_layer = candidate logging.info(f"Found embedding layer via attribute: '{name}'") break except AttributeError: continue except Exception as e: logging.warning(f"Error accessing potential embedding layer '{name}': {e}") if emb_layer is not None and hasattr(emb_layer, 'weight') and emb_layer.weight is not None: try: with torch.no_grad(): w = emb_layer.weight.data norms = torch.norm(w, p=2, dim=-1, keepdim=True) safe_norms = norms.clamp(min=1e-12) w.div_(safe_norms) config.embedding_normalized = True logging.info("Input embeddings normalized (L2 norm).") return "Input embeddings normalized (L2 norm)." except Exception as e: logging.error(f"Error normalizing embeddings: {e}") config.embedding_normalized = False return f"Error normalizing embeddings: {e}" else: msg="Input embedding layer or its weights not found using common methods. Cannot normalize." logging.warning(msg) config.embedding_normalized = False return msg def _revert_embedding_normalization(model, config): if not getattr(config, 'embedding_normalized', False): return "Embedding normalization flag is already false (or was never applied)." config.embedding_normalized = False logging.info("Embedding normalization flag reverted. Note: Original embedding weights are NOT restored.") return "Embedding normalization flag reverted (weights NOT restored)." def _prune_weights_magnitude(model, config, amount=0.2): if not isinstance(amount, (float, int)) or not (0 < amount < 1): msg="Error: Pruning amount must be a float between 0 and 1 (exclusive)." logging.error(msg) return msg logging.info(f"Applying global unstructured L1 magnitude pruning (amount={amount:.2f})...") device = get_device() model.to(device) params_to_prune = [] for module_name, module in model.named_modules(): if isinstance(module, (nn.Linear, BitLinear)): if hasattr(module, 'weight') and module.weight is not None and module.weight.requires_grad: params_to_prune.append((module, 'weight')) if not params_to_prune: msg="No prunable Linear or BitLinear layers with trainable weights found." 
logging.warning(msg) config.pruning_applied = False config.pruning_amount = None return msg try: prune.global_unstructured( parameters=params_to_prune, pruning_method=prune.L1Unstructured, amount=amount ) pruned_count = 0 total_params = 0 modules_made_permanent = 0 for module, name in params_to_prune: if prune.is_pruned(module): prune.remove(module, name) modules_made_permanent += 1 if hasattr(module, name): weight = getattr(module, name) if weight is not None: pruned_count += torch.sum(weight == 0).item() total_params += weight.nelement() if modules_made_permanent > 0: sparsity = 100. * pruned_count / total_params if total_params > 0 else 0 msg = (f"Pruning applied and made permanent on {modules_made_permanent} parameter groups. " f"Final Sparsity: {sparsity:.2f}% ({pruned_count:,}/{total_params:,} zeros).") config.pruning_applied = True config.pruning_amount = amount elif any(prune.is_pruned(mod) for mod, _ in params_to_prune): msg = "Pruning hooks were applied but removal failed or was incomplete. Pruning might not be permanent." config.pruning_applied = False config.pruning_amount = None else: msg = "Pruning was attempted, but no modules seem to have been pruned or made permanent." config.pruning_applied = False config.pruning_amount = None except Exception as e: msg = f"Error during pruning: {e}\n{traceback.format_exc()}" logging.error(msg) for module, name in params_to_prune: if prune.is_pruned(module): try: prune.remove(module, name) logging.info(f"Cleaned up pruning hook from {module}.{name} during error handling.") except Exception as remove_e: logging.warning(f"Couldn't remove pruning hook from {module}.{name} during cleanup: {remove_e}") config.pruning_applied = False config.pruning_amount = None logging.info(msg) return msg def _revert_pruning(model, config): if not getattr(config, 'pruning_applied', False): return "Pruning flag is already false (or pruning was never applied/made permanent)." config.pruning_applied = False config.pruning_amount = None logging.info("Pruning flag reverted. Note: Pruned weights (zeros) are NOT restored.") return "Pruning flag reverted (weights NOT restored)." def _quantize_model(model, config, mode='bfloat16'): logging.info(f"Attempting to change model dtype to {mode}...") original_dtype_str = getattr(config, 'quantization_mode', DEFAULT_QUANTIZATION) target_dtype = None if mode == 'bfloat16': if torch.cuda.is_available() and torch.cuda.is_bf16_supported(): target_dtype = torch.bfloat16 else: msg="Device does not support bfloat16. Cannot quantize to bfloat16. Keeping current dtype." logging.warning(msg) return msg elif mode == 'float16': target_dtype = torch.float16 elif mode == 'float32': target_dtype = torch.float32 else: msg = f"Unsupported quantization mode '{mode}'. Choose from {QUANTIZATION_MODES}." logging.error(msg) return msg try: current_dtype = next(iter(model.parameters()), torch.tensor([])).dtype if not isinstance(current_dtype, torch.dtype): msg = "Model has no parameters. Cannot determine or change dtype." logging.error(msg) return msg except StopIteration: msg = "Model has no parameters. Cannot determine or change dtype." logging.error(msg) return msg except Exception as e: msg = f"Could not determine current model dtype: {e}" logging.error(msg) return msg if current_dtype == target_dtype: msg = f"Model is already in {mode} ({target_dtype}). No change needed." 
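# --- Illustrative sketch (standalone aside) ---
# Global unstructured L1 magnitude pruning on a tiny hypothetical MLP, mirroring
# _prune_weights_magnitude above: prune 20% of weights globally, make the masks
# permanent with prune.remove, then report the resulting sparsity.
def demo_global_l1_pruning(amount: float = 0.2):
    mlp = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4))
    params = [(m, "weight") for m in mlp.modules() if isinstance(m, nn.Linear)]
    prune.global_unstructured(params, pruning_method=prune.L1Unstructured, amount=amount)
    zeros, total = 0, 0
    for module, name in params:
        prune.remove(module, name)          # bake the mask into the weight tensor
        w = getattr(module, name)
        zeros += int(torch.sum(w == 0).item())
        total += w.nelement()
    return 100.0 * zeros / total            # ~20% sparsity for the default amount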
logging.info(msg) config.quantization_applied = (mode != 'float32') config.quantization_mode = mode config.perfect_precision_recovered = (mode == 'float32') return msg try: device = get_device() model.to(device=device, dtype=target_dtype) new_dtype = next(iter(model.parameters()), torch.tensor([])).dtype if new_dtype == target_dtype: config.quantization_applied = (mode != 'float32') config.quantization_mode = mode config.perfect_precision_recovered = (mode == 'float32') msg = f"Model dtype successfully changed to {mode} ({target_dtype}) on device {device}." logging.info(msg) clean_memory() return msg else: logging.error(f"Model dtype did not change as expected after .to() call. Still {new_dtype}. Reverting config flags.") config.quantization_applied = (original_dtype_str != 'float32') config.quantization_mode = original_dtype_str config.perfect_precision_recovered = (original_dtype_str == 'float32') raise RuntimeError(f"Model dtype did not change as expected. Still {new_dtype}.") except Exception as e: msg=f"Error converting model to {target_dtype}: {e}\n{traceback.format_exc()}" logging.error(msg) config.quantization_applied = (original_dtype_str != 'float32') config.quantization_mode = original_dtype_str config.perfect_precision_recovered = (original_dtype_str == 'float32') try: original_torch_dtype = getattr(torch, original_dtype_str, torch.float32) model.to(device=device, dtype=original_torch_dtype) logging.info(f"Attempted to restore model to original dtype {original_dtype_str} after error.") except Exception as revert_e: logging.error(f"Failed to restore original dtype after error: {revert_e}") return msg def _revert_quantization(model, config): logging.info("Reverting quantization to float32...") return _quantize_model(model, config, mode='float32') def _freeze_layers(model, config, layers_to_freeze_str): if not layers_to_freeze_str or not isinstance(layers_to_freeze_str, str): msg="No layers specified to freeze or invalid input type." logging.warning(msg) config.frozen_layers = None return msg layer_indices = set() try: raw_parts = layers_to_freeze_str.split(',') for part in raw_parts: part = part.strip() if not part: continue if '-' in part: start_end = part.split('-') if len(start_end) == 2: s = int(start_end[0].strip()) e = int(start_end[1].strip()) if s < 0 or e < 0: raise ValueError("Negative indices are not allowed.") if s <= e: layer_indices.update(range(s, e + 1)) else: layer_indices.update(range(e, s + 1)) logging.warning(f"Interpreted range '{part}' as descending: {list(range(e, s + 1))}") else: raise ValueError(f"Invalid range format: {part}") else: idx = int(part) if idx < 0: raise ValueError("Negative indices are not allowed.") layer_indices.add(idx) except ValueError as e: msg=f"Error parsing layer specification '{layers_to_freeze_str}': {e}. Use non-negative, comma-separated numbers or ranges (e.g., '0-3, 7, 10-11')." logging.error(msg) return msg layer_module, layer_attr, layer_list = _find_decoder_layers_module(model) if not (layer_module and layer_attr and layer_list is not None): msg="Could not determine layer structure for freezing. No layers frozen." 
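# --- Illustrative sketch (standalone aside) ---
# The dtype switch that _quantize_model performs, in isolation. The bf16-first
# preference mirrors the auto-optimization path used later in this module; the
# toy nn.Linear stands in for a real model.
def demo_dtype_switch():
    model = nn.Linear(8, 8)
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        target = torch.bfloat16
    elif torch.cuda.is_available():
        target = torch.float16
    else:
        target = torch.float32
    model.to(dtype=target)
    return next(model.parameters()).dtype   # confirms the conversion happened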
logging.warning(msg) return msg total_layers = len(layer_list) frozen_params_count = 0 actual_frozen_indices = set() unfrozen_globally = 0 for param in model.parameters(): if not param.requires_grad: param.requires_grad = True unfrozen_globally += 1 if unfrozen_globally > 0: logging.info(f"Unfroze {unfrozen_globally} parameters globally before applying new freeze spec.") else: logging.info("No parameters were frozen globally before applying new spec.") invalid_indices_skipped = set() for i in layer_indices: if 0 <= i < total_layers: try: current_layer = layer_list[i] params_in_layer = 0 for param in current_layer.parameters(): if param.requires_grad: param.requires_grad = False frozen_params_count += 1 params_in_layer += 1 if params_in_layer > 0: actual_frozen_indices.add(i) logging.debug(f"Froze {params_in_layer} parameters in layer {i}.") else: logging.debug(f"Layer {i} had no trainable parameters to freeze.") except IndexError: logging.warning(f"Index {i} seems out of bounds for layer list during freezing loop, although check passed earlier. Skipping.") invalid_indices_skipped.add(i) except Exception as e: logging.error(f"Error accessing or freezing parameters for layer {i}: {e}") invalid_indices_skipped.add(i) else: logging.warning(f"Layer index {i} is out of bounds (0-{total_layers-1}). Skipping.") invalid_indices_skipped.add(i) frozen_list_str = ",".join(map(str, sorted(list(actual_frozen_indices)))) config.frozen_layers = frozen_list_str if actual_frozen_indices else None msg = f"Froze {frozen_params_count} parameters in layers: {frozen_list_str} (Total layers: {total_layers})." if invalid_indices_skipped: msg += f" Skipped invalid indices: {sorted(list(invalid_indices_skipped))}." if frozen_params_count == 0 and not invalid_indices_skipped: msg = f"No parameters were frozen. Specified layers {frozen_list_str} might have already been frozen or had no trainable params." logging.info(msg) return msg def _unfreeze_all_layers(model, config): unfrozen_count = 0 for name, param in model.named_parameters(): if not param.requires_grad: param.requires_grad = True unfrozen_count += 1 config.frozen_layers = None msg = f"Unfroze {unfrozen_count} parameters across the entire model." if unfrozen_count > 0 else "No parameters needed unfreezing." logging.info(msg) return msg def _enable_gradient_checkpointing(model, config): gc_enabled_in_model = False if hasattr(model, 'gradient_checkpointing_enable'): try: sig = inspect.signature(model.gradient_checkpointing_enable) if 'gradient_checkpointing_kwargs' in sig.parameters: model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={'use_reentrant': False}) msg = "Gradient Checkpointing enabled via model method (non-reentrant)." else: model.gradient_checkpointing_enable() msg = "Gradient Checkpointing enabled via model method." gc_enabled_in_model = True logging.info(msg) except Exception as e: logging.warning(f"Failed to enable gradient checkpointing via standard method: {e}. Trying config attribute.") if hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing'): if not gc_enabled_in_model: logging.info("Enabling gradient checkpointing via model config attribute.") model.config.gradient_checkpointing = True gc_enabled_in_model = True if gc_enabled_in_model: if hasattr(model.config, 'use_cache'): if model.config.use_cache: model.config.use_cache = False logging.info("Set model.config.use_cache = False (required for Gradient Checkpointing).") else: logging.warning("Model config missing 'use_cache' attribute. 
Gradient checkpointing might not work correctly or efficiently.") config.gradient_checkpointing_enabled = True final_msg = "Gradient Checkpointing enabled." if not hasattr(model, 'gradient_checkpointing_enable') and not (hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing')): final_msg += " (Set via main config flag only; ensure Trainer args/model support it)." return final_msg else: config.gradient_checkpointing_enabled = False msg = "Could not enable Gradient Checkpointing via model methods or config attributes." logging.error(msg) return f"[Error] {msg}" def _disable_gradient_checkpointing(model, config): gc_disabled_in_model = False if hasattr(model, 'gradient_checkpointing_disable'): try: model.gradient_checkpointing_disable() gc_disabled_in_model = True logging.info("Gradient Checkpointing disabled via model method.") except Exception as e: logging.warning(f"Failed to disable gradient checkpointing via standard method: {e}. Trying config attribute.") if hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing'): if not gc_disabled_in_model: logging.info("Disabling gradient checkpointing via model config attribute.") model.config.gradient_checkpointing = False gc_disabled_in_model = True if gc_disabled_in_model: if hasattr(model.config, 'use_cache'): if not model.config.use_cache: model.config.use_cache = True logging.info("Set model.config.use_cache = True (restored after disabling Gradient Checkpointing).") config.gradient_checkpointing_enabled = False final_msg = "Gradient Checkpointing disabled." if not hasattr(model, 'gradient_checkpointing_disable') and not (hasattr(model, 'config') and hasattr(model.config, 'gradient_checkpointing')): final_msg += " (Set via main config flag only)." return final_msg else: config.gradient_checkpointing_enabled = False msg = "Could not disable Gradient Checkpointing via model methods or config attributes (may not have been enabled)." logging.warning(msg) return msg def _swap_optimizer(config, optimizer_name): if optimizer_name in OPTIMIZERS: config.optimizer = optimizer_name global DEFAULT_OPTIMIZER DEFAULT_OPTIMIZER = optimizer_name msg=f"Optimizer preference set to '{optimizer_name}' in config. This will be used by the Trainer if training starts." logging.info(msg) return msg else: available_opts = ", ".join(OPTIMIZERS.keys()) msg=f"Error: Optimizer '{optimizer_name}' unknown or not available. Choose from: {available_opts}." logging.error(msg) return msg def _revert_optimizer(config): original_default_optimizer = "adamw_torch" logging.info(f"Reverting optimizer preference to script default: '{original_default_optimizer}'.") return _swap_optimizer(config, original_default_optimizer) def _untie_embeddings(model, config): try: input_embeddings = model.get_input_embeddings() output_embeddings = model.get_output_embeddings() if output_embeddings is None: if hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear): output_embeddings = model.lm_head logging.info("Using 'lm_head' as the output embedding layer for untie check.") else: msg="Could not get output embedding layer (get_output_embeddings() returned None and 'lm_head' not found/Linear). Cannot untie." logging.warning(msg) config.untied_embeddings = True if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False return msg if input_embeddings is None: msg="Could not get input embedding layer. Cannot untie." 
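# --- Illustrative sketch (standalone aside) ---
# Enabling gradient checkpointing on a Hugging Face model the same way
# _enable_gradient_checkpointing above does, including the signature check for
# the non-reentrant variant. "sshleifer/tiny-gpt2" is just an assumed small
# checkpoint for illustration; any causal LM would do.
def demo_enable_gradient_checkpointing(model_id: str = "sshleifer/tiny-gpt2"):
    model = AutoModelForCausalLM.from_pretrained(model_id)
    sig = inspect.signature(model.gradient_checkpointing_enable)
    if "gradient_checkpointing_kwargs" in sig.parameters:
        # Newer transformers versions accept the non-reentrant variant.
        model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": False})
    else:
        model.gradient_checkpointing_enable()
    if hasattr(model.config, "use_cache"):
        model.config.use_cache = False      # KV cache is incompatible with checkpointing
    return model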
logging.warning(msg) config.untied_embeddings = False if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True return msg are_tied = False if hasattr(input_embeddings, "weight") and hasattr(output_embeddings, "weight") and \ input_embeddings.weight is not None and output_embeddings.weight is not None: if input_embeddings.weight.data_ptr() == output_embeddings.weight.data_ptr(): are_tied = True elif input_embeddings.weight.storage().data_ptr() == output_embeddings.weight.storage().data_ptr(): are_tied = True logging.info("Weights appear tied (share storage).") if are_tied: logging.info("Detected tied input/output embeddings. Attempting to untie...") device = input_embeddings.weight.device dtype = input_embeddings.weight.dtype new_output_weight = input_embeddings.weight.clone().detach() new_output_weight.requires_grad_(output_embeddings.weight.requires_grad) output_embeddings.weight = nn.Parameter(new_output_weight.to(device, dtype=dtype)) if hasattr(input_embeddings, "bias") and input_embeddings.bias is not None and \ hasattr(output_embeddings, "bias") and output_embeddings.bias is not None and \ input_embeddings.bias.data_ptr() == output_embeddings.bias.data_ptr(): logging.info("Detected tied bias, untying as well.") new_output_bias = input_embeddings.bias.clone().detach() new_output_bias.requires_grad_(output_embeddings.bias.requires_grad) output_embeddings.bias = nn.Parameter(new_output_bias.to(device, dtype=dtype)) if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False config.untied_embeddings = True msg="Embeddings untied successfully (output layer weights/bias are now distinct copies)." logging.info(msg) clean_memory() return msg else: config.untied_embeddings = True if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = False msg="Embeddings are already untied (or weights are missing/different objects)." logging.info(msg) return msg except Exception as e: msg=f"Error untying embeddings: {e}\n{traceback.format_exc()}" logging.error(msg) return msg def _retie_embeddings(model, config): if not getattr(config, 'untied_embeddings', False): try: input_emb = model.get_input_embeddings() output_emb = model.get_output_embeddings() if output_emb is None and hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear): output_emb = model.lm_head if input_emb is not None and output_emb is not None and \ hasattr(input_emb, 'weight') and input_emb.weight is not None and \ hasattr(output_emb, 'weight') and output_emb.weight is not None and \ input_emb.weight.data_ptr() == output_emb.weight.data_ptr(): msg = "Embeddings seem already tied. Resetting flag if needed." config.untied_embeddings = False if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True logging.info(msg) return msg else: msg = "Cannot re-tie: Flag 'untied_embeddings' is false or cannot verify current state." logging.info(msg) return msg except Exception as e: msg = f"Cannot re-tie: Error checking current state ({e}). Flag 'untied_embeddings' is false." logging.warning(msg) return msg try: input_embeddings = model.get_input_embeddings() output_embeddings = model.get_output_embeddings() if output_embeddings is None and hasattr(model, 'lm_head') and isinstance(model.lm_head, nn.Linear): output_embeddings = model.lm_head logging.info("Using 'lm_head' as output layer for re-tying.") if input_embeddings is None or output_embeddings is None: msg="Could not get both input and output embedding layers for re-tying." 
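# --- Illustrative sketch (standalone aside) ---
# Untying a weight-tied embedding / output head pair the way _untie_embeddings
# does: clone the shared weight into a fresh Parameter so the two layers stop
# sharing storage. The toy sizes are arbitrary.
def demo_untie_embeddings(vocab: int = 32, dim: int = 8):
    emb = nn.Embedding(vocab, dim)
    lm_head = nn.Linear(dim, vocab, bias=False)
    lm_head.weight = emb.weight                                   # tied: same Parameter object
    assert emb.weight.data_ptr() == lm_head.weight.data_ptr()
    lm_head.weight = nn.Parameter(emb.weight.detach().clone())    # untied: distinct storage
    assert emb.weight.data_ptr() != lm_head.weight.data_ptr()
    return emb, lm_head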
logging.warning(msg) return msg if hasattr(input_embeddings, "weight") and input_embeddings.weight is not None and \ hasattr(output_embeddings, "weight") and output_embeddings.weight is not None: if input_embeddings.weight.shape == output_embeddings.weight.shape: logging.info("Attempting to re-tie embeddings by sharing input embedding weight...") device = input_embeddings.weight.device dtype = input_embeddings.weight.dtype output_embeddings = output_embeddings.to(device=device, dtype=dtype) output_embeddings.weight = input_embeddings.weight if hasattr(output_embeddings, "bias") and output_embeddings.bias is not None: logging.info("Setting output embedding bias to None as part of re-tying.") output_embeddings.bias = None if hasattr(config, "tie_word_embeddings"): config.tie_word_embeddings = True config.untied_embeddings = False msg="Embeddings re-tied successfully (output layer now shares input layer's weight, bias set to None)." logging.info(msg) clean_memory() return msg else: msg = f"Cannot re-tie embeddings: Weight shapes mismatch. Input: {input_embeddings.weight.shape}, Output: {output_embeddings.weight.shape}." logging.warning(msg) return msg else: msg = "Cannot re-tie embeddings: Input or output embedding weights missing or None." logging.warning(msg) return msg except Exception as e: msg=f"Error re-tying embeddings: {e}\n{traceback.format_exc()}" logging.error(msg) return msg def _configure_limits(config): config.knowledge_date = "2045-03-28" config.cutoff_date = "2045-03-28" current_max_pos = getattr(config, 'max_position_embeddings', 512) new_max_pos = current_max_pos * 100 config.max_position_embeddings = new_max_pos config.limits_configured = True config.no_limits = True logging.info(f"Set knowledge/cutoff date flags and increased max_position_embeddings in config to {config.max_position_embeddings}.") return f"Limit-related flags configured (Knowledge Date: 2045, Max Pos Emb: {config.max_position_embeddings}). Requires model reload or RoPE scaling for actual effect." def _remove_limits_configuration(config): if not getattr(config, 'limits_configured', False): return "Limit configuration flags are already in their default state." config.knowledge_date = None config.cutoff_date = None config.limits_configured = False config.no_limits = False logging.info("Reset knowledge date and cutoff date flags in config. Max position embeddings remain modified.") return "Limit-related flags removed/reset. Max position embeddings NOT reverted." def _remove_qa_restrictions(config): config.qa_restrictions_removed = True logging.info("QA restrictions removal flag set in config. Actual effect depends on model usage/fine-tuning and inference logic.") return "QA Restrictions Removal Flag Enabled (symbolic)." def _enable_qa_restrictions(config): config.qa_restrictions_removed = False logging.info("QA restrictions removal flag disabled in config.") return "QA Restrictions Removal Flag Disabled (symbolic)." def _enable_coherence_improvement(config): config.coherence_improvement_enabled = True logging.info("Coherence improvement flag enabled. Inference will use beam search if this is active.") return "Coherence Improvement Flag ON (uses beam search in inference)." def _disable_coherence_improvement(config): config.coherence_improvement_enabled = False logging.info("Coherence improvement flag disabled.") return "Coherence Improvement Flag OFF." 
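# --- Illustrative sketch (standalone aside) ---
# One way the coherence / token-speed flags above could translate into
# generation kwargs at inference time. This mapping is an assumption for
# illustration; the actual inference code elsewhere in this script decides
# how the flags are consumed.
def demo_generation_kwargs(cfg):
    if getattr(cfg, "coherence_improvement_enabled", False):
        return {"do_sample": False, "num_beams": 4, "early_stopping": True}
    if getattr(cfg, "token_gen_speed_maximized", False):
        return {"do_sample": False, "num_beams": 1, "use_cache": True}
    return {"do_sample": True, "temperature": 0.7, "top_p": 0.9}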
def _set_flag_only(config, flag_name, value, msg_on, msg_off): if not hasattr(config, flag_name): logging.warning(f"Config object does not have flag '{flag_name}'. Adding it.") bool_value = bool(value) setattr(config, flag_name, bool_value) msg = msg_on if bool_value else msg_off logging.info(f"Config flag '{flag_name}' set to {bool_value}. Message: {msg}") return msg def _apply_swa(model, config): return _set_flag_only(config, "swa_applied", True, "SWA flag set. Requires SWA callback/logic during training.", "SWA flag disabled.") def _revert_swa(model, config): return _set_flag_only(config, "swa_applied", False, "SWA flag set.", "SWA flag disabled.") def _apply_knowledge_editing(model, config): return _set_flag_only(config, "knowledge_edited", True, "Knowledge Editing flag set. Indicates manual edits or specific editing techniques were applied (symbolic).", "Knowledge Editing flag disabled.") def _revert_knowledge_editing(model, config): return _set_flag_only(config, "knowledge_edited", False, "Knowledge Editing flag set.", "Knowledge Editing flag disabled.") def _apply_head_pruning(model, config): return _set_flag_only(config, "head_pruning_applied", True, "Head Pruning flag set. Requires specific pruning implementation outside this script (symbolic).", "Head Pruning flag disabled.") def _revert_head_pruning(model, config): return _set_flag_only(config, "head_pruning_applied", False, "Head Pruning flag set.", "Head Pruning flag disabled.") def _apply_qat(model, config): return _set_flag_only(config, "qat_applied", True, "QAT flag set. Requires Quantization-Aware Training setup and execution (symbolic).", "QAT flag disabled.") def _revert_qat(model, config): return _set_flag_only(config, "qat_applied", False, "QAT flag set.", "QAT flag disabled.") def _apply_architecture_merge_flag(model, config): return _set_flag_only(config, "architecture_merged", True, "Architecture Merged flag set. Indicates model is likely a result of parameter averaging.", "Architecture Merged flag disabled.") def _revert_architecture_merge_flag(model, config): return _set_flag_only(config, "architecture_merged", False, "Architecture Merged flag set.", "Architecture Merged flag disabled.") def _apply_weight_init(model, config): return _set_flag_only(config, "weight_init_applied", True, "Weight Initialization flag set. Indicates a specific init strategy was used (symbolic).", "Weight Initialization flag disabled.") def _revert_weight_init(model, config): return _set_flag_only(config, "weight_init_applied", False, "Weight Initialization flag set.", "Weight Initialization flag disabled.") def _apply_gradient_noise(model, config): return _set_flag_only(config, "gradient_noise_applied", True, "Gradient Noise flag set. 
Requires implementation in optimizer/trainer (symbolic).", "Gradient Noise flag disabled.") def _revert_gradient_noise(model, config): return _set_flag_only(config, "gradient_noise_applied", False, "Gradient Noise flag set.", "Gradient Noise flag disabled.") def _apply_additional_mechanisms(base_model, config): logging.info("Applying various additional experimental mechanisms flags and simple optimizations...") _set_flag_only(config, "enhanced_security_enabled", True, "Enhanced Security Flag ON.", "Enhanced Security Flag OFF.") _set_flag_only(config, "debug_mode_enabled", True, "Debug Mode Flag ON.", "Debug Mode Flag OFF.") _set_flag_only(config, "internal_logging_enabled", True, "Internal Logging Flag ON.", "Internal Logging Flag OFF.") _set_flag_only(config, "drift_detection_enabled", True, "Drift Detection Flag ON.", "Drift Detection Flag OFF.") _set_flag_only(config, "ultra_fast_mode", True, "Ultra Fast Mode Flag ON.", "Ultra Fast Mode Flag OFF.") coherence_msg = _enable_coherence_improvement(config) speed_msg = _optimize_token_generation_speed(config) config.additional_mechanisms_applied = True logging.info("Applied various additional mechanism flags and optimizations.") return f"Applied Additional Mechanism Flags & Optimizations. Coherence: {coherence_msg}, Speed: {speed_msg}" def _disable_additional_mechanisms(config): if not getattr(config, 'additional_mechanisms_applied', False): return "Additional mechanisms flag is already off. No changes made." logging.info("Disabling various additional experimental mechanisms flags and reverting optimizations...") _set_flag_only(config, "enhanced_security_enabled", False, "Enhanced Security Flag ON.", "Enhanced Security Flag OFF.") _set_flag_only(config, "debug_mode_enabled", False, "Debug Mode Flag ON.", "Debug Mode Flag OFF.") _set_flag_only(config, "internal_logging_enabled", False, "Internal Logging Flag ON.", "Internal Logging Flag OFF.") _set_flag_only(config, "drift_detection_enabled", False, "Drift Detection Flag ON.", "Drift Detection Flag OFF.") _set_flag_only(config, "ultra_fast_mode", False, "Ultra Fast Mode Flag ON.", "Ultra Fast Mode Flag OFF.") coherence_msg = _disable_coherence_improvement(config) speed_msg = _revert_token_generation_speed_optimization(config) config.additional_mechanisms_applied = False logging.info("Disabled various additional mechanism flags and reverted optimizations.") return f"Disabled Additional Mechanism Flags & Reverted Optimizations. 
Coherence: {coherence_msg}, Speed: {speed_msg}" def _disable_all_safety_settings(config): flags_to_disable = [ "response_filters", "safety_settings_enabled", "harassment_filter", "hate_filter", "sexually_explicit_filter", "dangerous_content_filter", "civic_integrity_filter", "code_filter", "medical_advice_filter", "legal_advice_filter", "financial_advice_filter", "pii_filter", "political_filter", "religious_filter", "profanity_filter", "stereotype_filter", "misinfo_filter", "self_harm_filter", "personal_attack_filter", "toxicity_filter", "spam_filter", "off_topic_filter", "tone_filter", "min_max_length_filter", "repetition_filter_enabled", "factuality_filter_enabled" ] flags_to_set_true = [ "remove_censorship", "no_response_filters", "no_advert_warning", "no_limits" ] config = initialize_config_flags(config) updated_flags = 0 for flag in flags_to_disable: if hasattr(config, flag) and getattr(config, flag) is not False: setattr(config, flag, False) updated_flags += 1 for flag in flags_to_set_true: if hasattr(config, flag) and getattr(config, flag) is not True: setattr(config, flag, True) updated_flags += 1 config.safety_settings_enabled = False config.response_filters = False logging.info(f"Disabled all known safety/content filters and related flags in config ({updated_flags} flags updated).") return "All safety filter flags disabled in config." def _enable_all_safety_settings(config): flags_to_set_default_true = [ "safety_settings_enabled", "response_filters", "harassment_filter", "hate_filter", "sexually_explicit_filter", "dangerous_content_filter", "self_harm_filter", "pii_filter", "min_max_length_filter", "toxicity_filter", "personal_attack_filter", ] flags_to_set_optional_true = [ "civic_integrity_filter", "code_filter", "medical_advice_filter", "legal_advice_filter", "financial_advice_filter", "political_filter", "religious_filter", "profanity_filter", "stereotype_filter", "misinfo_filter", "spam_filter", "off_topic_filter", "tone_filter" ] flags_to_set_false = [ "remove_censorship", "no_response_filters", "no_advert_warning", "no_limits" ] flags_to_set_default_false = [ "repetition_filter_enabled", "factuality_filter_enabled" ] config = initialize_config_flags(config) updated_flags = 0 all_flags_to_enable = flags_to_set_default_true + flags_to_set_optional_true for flag in all_flags_to_enable: if hasattr(config, flag) and getattr(config, flag) is not True: setattr(config, flag, True) updated_flags += 1 for flag in flags_to_set_false: if hasattr(config, flag) and getattr(config, flag) is not False: setattr(config, flag, False) updated_flags += 1 for flag in flags_to_set_default_false: if hasattr(config, flag) and getattr(config, flag) is not False: setattr(config, flag, False) updated_flags += 1 config.safety_settings_enabled = True config.response_filters = True logging.info(f"Enabled default safety/content filters and related flags in config ({updated_flags} flags updated).") return "Default safety filter flags enabled in config." def _remove_inconsistencias_and_biases(base_model, config): bias_adjusted_count = 0 params_adjusted_count = 0 device = get_device() base_model.to(device) if getattr(config, 'inconsistencies_biases_removed', False): return "Inconsistencies/Biases removal flag already set. No action taken." 
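# --- Illustrative sketch (standalone aside) ---
# The bias-centering step used below, shown on a single hypothetical Linear
# layer: subtract the mean from the bias vector so it becomes zero-centered.
def demo_center_bias(dim: int = 8):
    layer = nn.Linear(dim, dim)
    with torch.no_grad():
        mean = layer.bias.data.float().mean().item()
        if abs(mean) > 1e-6:
            layer.bias.data.sub_(mean)
    return layer.bias.data.mean().item()    # ~0.0 after centering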
with torch.no_grad(): for name, param in base_model.named_parameters(): if "bias" in name and isinstance(param, nn.Parameter) and param.requires_grad: if any(lin_name in name.lower() for lin_name in ['linear', 'dense', 'fc', 'out_proj', 'q_proj', 'k_proj', 'v_proj', 'wi', 'wo', 'lm_head']): try: original_mean = torch.mean(param.data.float()).item() if abs(original_mean) > 1e-6: param.sub_(original_mean) bias_adjusted_count += 1 params_adjusted_count += param.numel() logging.debug(f"Centered bias for {name} (original mean: {original_mean:.4e})") except Exception as e: logging.warning(f"Could not center bias for {name}: {e}") if bias_adjusted_count > 0: config.inconsistencies_biases_removed = True logging.info(f"Centered {bias_adjusted_count} bias terms ({params_adjusted_count} parameters) to potentially reduce inconsistencies.") return f"{bias_adjusted_count} bias terms centered." else: config.inconsistencies_biases_removed = True logging.info("Attempted bias centering, but no adjustable bias terms with significant mean found or no bias terms present.") return "Attempted bias centering (no significant changes made or no biases found)." def _reenable_inconsistencias_and_biases(config): if not getattr(config, 'inconsistencies_biases_removed', False): return "Inconsistencies/Biases removal flag already disabled." config.inconsistencies_biases_removed = False logging.info("Inconsistencies/Biases removal flag reverted. Note: Original bias values are NOT restored.") return "Inconsistencies/Biases removal flag reverted (biases NOT restored)." def _enable_layerdrop(config, probability=0.1): if not isinstance(probability, (float, int)) or not (0 <= probability <= 1): msg=f"Error: LayerDrop probability must be between 0 and 1. Got {probability}." logging.error(msg) return msg if hasattr(config, 'layerdrop'): config.layerdrop = float(probability) else: logging.warning("Config does not have a standard 'layerdrop' attribute. Setting custom flag only.") setattr(config, 'layerdrop', float(probability)) config.layerdrop_enabled = (probability > 0) config.layerdrop_prob = float(probability) logging.info(f"LayerDrop enabled flag set in config with probability {probability}. Actual effect depends on model architecture support during training/inference.") return f"LayerDrop flag {'ON' if probability > 0 else 'OFF'} (p={probability:.2f}). Requires model/Trainer support." def _disable_layerdrop(config): return _enable_layerdrop(config, probability=0.0) def _apply_lora_merge(model, config): global global_model adapter_path = getattr(config, 'lora_adapter_path', None) if not adapter_path: msg="No LoRA adapter path specified in config ('lora_adapter_path'). Use 'Set Path in Config' first or train/load an adapter." logging.warning(msg) return msg if not _peft_installed: msg="Error: PEFT library not installed, cannot merge LoRA." logging.error(msg) return msg current_model = model if not isinstance(current_model, PeftModel): logging.warning(f"Model is not a PeftModel. 
Attempting to load adapter '{adapter_path}' onto it first.") try: peft_model_instance = PeftModel.from_pretrained(current_model, adapter_path, is_trainable=False) current_model = peft_model_instance logging.info(f"Successfully loaded adapter '{adapter_path}' onto the base model.") except Exception as e: msg = f"Error loading adapter '{adapter_path}' onto base model: {e}\n{traceback.format_exc()}" logging.error(msg) return msg else: active_adapter = getattr(current_model, 'active_adapter', 'default') target_adapter_name = os.path.basename(os.path.normpath(adapter_path)) if not target_adapter_name: target_adapter_name = 'default' if target_adapter_name not in current_model.peft_config: logging.info(f"Adapter '{target_adapter_name}' (from path {adapter_path}) not found in existing PeftModel config. Loading it now.") try: current_model.load_adapter(adapter_path, adapter_name=target_adapter_name, is_trainable=False) logging.info(f"Loaded new adapter '{target_adapter_name}'.") except Exception as e: msg = f"Error loading adapter '{target_adapter_name}' from path '{adapter_path}' onto existing PeftModel: {e}\n{traceback.format_exc()}" logging.error(msg) return msg if active_adapter != target_adapter_name: try: current_model.set_adapter(target_adapter_name) logging.info(f"Set active adapter to '{target_adapter_name}' for merging.") active_adapter = target_adapter_name except Exception as e: msg = f"Error setting adapter '{target_adapter_name}' active on existing PeftModel: {e}\n{traceback.format_exc()}" logging.error(msg) return msg else: active_adapter = target_adapter_name try: logging.info(f"Merging active LoRA adapter ('{active_adapter}') into the base model..."); T = time.time() merged_model = current_model.merge_and_unload() merge_time = time.time() - T merged_config = merged_model.config merged_config = initialize_config_flags(merged_config) merged_config.lora_merged = True merged_config.lora_adapter_path = adapter_path merged_config.peft_adapter_added = False merged_config.peft_config = None global_model = merged_model config = merged_config msg = f"LoRA adapter '{active_adapter}' (from {adapter_path}) merged successfully in {merge_time:.2f}s. Global model updated to the merged base model." logging.info(msg) clean_memory() return msg except ValueError as ve: msg = f"Error merging LoRA adapter '{active_adapter}': {ve}. Adapter type might not support merging." logging.error(msg) return msg except Exception as e: msg = f"Error merging LoRA adapter '{active_adapter}': {e}\n{traceback.format_exc()}" logging.error(msg) return msg def _revert_lora_merge(model, config): if not getattr(config, 'lora_merged', False): return "LoRA merge flag is already false (or merge never applied/recorded)." config.lora_merged = False config.lora_adapter_path = None msg = "LoRA merge flag reverted. IMPORTANT: Model weights are NOT restored to pre-merge state. Reload the original base model if needed."; logging.warning(msg) return msg def _set_lora_adapter_path(config, path): if path and isinstance(path, str) and path.strip(): path = path.strip() config.lora_adapter_path = path msg = f"LoRA adapter path set in config to: '{path}'" logging.info(msg) return msg else: msg = "Invalid or empty LoRA adapter path provided. Path not set." logging.warning(msg) return msg def _setup_knowledge_distillation(model, config, num_labels=2): if not isinstance(num_labels, int) or num_labels <= 0: msg = f"Error: Number of labels for KD must be a positive integer, got {num_labels}." 
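# --- Illustrative sketch (standalone aside) ---
# The load-then-merge path used by _apply_lora_merge above, in its simplest
# form. `adapter_path` is a hypothetical local directory or Hub repo id that
# holds a trained LoRA adapter matching the given base model.
def demo_merge_lora(base_model, adapter_path: str):
    if not _peft_installed:
        raise RuntimeError("peft is not installed")
    peft_model = PeftModel.from_pretrained(base_model, adapter_path, is_trainable=False)
    merged = peft_model.merge_and_unload()   # folds the LoRA deltas into the base weights
    return merged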
logging.error(msg) return msg try: device=get_device() try: dtype = next(iter(model.parameters())).dtype except StopIteration: dtype = torch.float32 if not isinstance(dtype, torch.dtype): dtype = torch.float32 classifier_name = 'kd_classifier' if hasattr(model, classifier_name): logging.warning(f"Model already has an attribute named '{classifier_name}'. Overwriting.") hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'embed_dim', None))) if not isinstance(hidden_size, int) or hidden_size <= 0: raise ValueError("Cannot setup KD: Model config missing valid 'hidden_size', 'd_model', or 'embed_dim' attribute.") classifier_layer = nn.Linear(hidden_size, num_labels).to(device, dtype=dtype) nn.init.xavier_uniform_(classifier_layer.weight) if classifier_layer.bias is not None: nn.init.zeros_(classifier_layer.bias) setattr(model, classifier_name, classifier_layer) if not hasattr(config, 'num_labels') or config.num_labels is None: config.num_labels = num_labels else: logging.warning(f"Model config already has 'num_labels'={config.num_labels}. KD setup might conflict if used for other classification tasks.") config.knowledge_distillation_setup = True config.kd_num_labels = num_labels msg = (f"Knowledge Distillation head ('{classifier_name}') added with {num_labels} labels (outputs). " f"Requires training changes: loss calculation using this head (e.g., cross-entropy on its logits), " f"and appropriate data format (e.g., sequence inputs + target labels).") logging.info(msg) return msg except Exception as e: msg = f"Error setting up Knowledge Distillation head: {e}\n{traceback.format_exc()}" logging.error(msg) if hasattr(model, 'kd_classifier'): delattr(model, 'kd_classifier') config.knowledge_distillation_setup = False config.kd_num_labels = None return msg def _revert_knowledge_distillation(model, config): classifier_name = 'kd_classifier' if hasattr(model, classifier_name): delattr(model, classifier_name) config.knowledge_distillation_setup = False config.kd_num_labels = None msg = f"Knowledge Distillation setup reverted (removed '{classifier_name}' head and reset config flags)." logging.info(msg) clean_memory() return msg else: config.knowledge_distillation_setup = False config.kd_num_labels = None msg = f"Knowledge Distillation head ('{classifier_name}') not found, nothing to revert. Reset flags." logging.info(msg) return msg def _setup_reward_modeling(model, config, num_outputs=1): if not isinstance(num_outputs, int) or num_outputs <= 0: msg = f"Error: Number of outputs for Reward Model must be a positive integer, got {num_outputs}." logging.error(msg) return msg try: device=get_device() try: dtype = next(iter(model.parameters())).dtype except StopIteration: dtype = torch.float32 if not isinstance(dtype, torch.dtype): dtype = torch.float32 rm_head_name = 'reward_head' if hasattr(model, rm_head_name): logging.warning(f"Model already has an attribute named '{rm_head_name}'. 
Overwriting.") hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', getattr(config, 'embed_dim', None))) if not isinstance(hidden_size, int) or hidden_size <= 0: raise ValueError("Cannot setup Reward Model head: Model config missing valid 'hidden_size', 'd_model', or 'embed_dim'.") reward_head = nn.Linear(hidden_size, num_outputs).to(device, dtype=dtype) nn.init.xavier_uniform_(reward_head.weight) if reward_head.bias is not None: nn.init.zeros_(reward_head.bias) setattr(model, rm_head_name, reward_head) config.reward_modeling_setup = True config.rm_num_outputs = num_outputs msg = (f"Reward Modeling head ('{rm_head_name}') added with {num_outputs} output(s). " f"Requires training changes: loss targeting rewards (e.g., ranking loss), specific data format (prompt, chosen_resp, rejected_resp), " f"and likely using the final hidden state of the sequence as input to this head.") logging.info(msg) return msg except Exception as e: msg = f"Error setting up Reward Modeling head: {e}\n{traceback.format_exc()}" logging.error(msg) if hasattr(model, 'reward_head'): delattr(model, 'reward_head') config.reward_modeling_setup = False config.rm_num_outputs = None return msg def _revert_reward_modeling(model, config): rm_head_name = 'reward_head' if hasattr(model, rm_head_name): delattr(model, rm_head_name) config.reward_modeling_setup = False config.rm_num_outputs = None msg = f"Reward Modeling setup reverted (removed '{rm_head_name}' head and reset config flags)." logging.info(msg) clean_memory() return msg else: config.reward_modeling_setup = False config.rm_num_outputs = None msg = f"Reward Modeling head ('{rm_head_name}') not found, nothing to revert. Reset flags." logging.info(msg) return msg def _set_rope_scaling_config(model, config, scaling_type="linear", factor=2.0): valid_types = ["linear", "dynamic"] if not scaling_type or not isinstance(scaling_type, str) or scaling_type not in valid_types: msg = f"Error: RoPE scaling type must be one of {valid_types}. Got '{scaling_type}'." logging.error(msg) return msg try: factor = float(factor) if factor < 1.0: raise ValueError("Factor must be >= 1.0.") if factor == 1.0: logging.warning(f"RoPE scaling factor set to {factor}, which implies no scaling.") except (ValueError, TypeError) as e: msg=f"Error: Invalid RoPE scaling factor '{factor}'. Must be a number >= 1.0. Error: {e}" logging.error(msg) return msg rope_config = {"type": scaling_type, "factor": factor} config.rope_scaling = rope_config config.rope_scaling_type = scaling_type config.rope_scaling_factor = factor msg = (f"RoPE Scaling set in config: type='{scaling_type}', factor={factor:.2f}. " f"Requires model architecture support and **reloading the model** with this config for the changes to take effect.") logging.warning(msg) return msg def _revert_rope_scaling(model, config): if hasattr(config, 'rope_scaling') and config.rope_scaling is not None: config.rope_scaling = None config.rope_scaling_type = None config.rope_scaling_factor = None msg = "RoPE Scaling configuration removed from config. Model reload required to revert RoPE behavior." logging.warning(msg) return msg else: config.rope_scaling_type = None config.rope_scaling_factor = None msg = "RoPE Scaling was not configured. No changes made." 
logging.info(msg) return msg def _set_sliding_window_config(model, config, window_size=4096): try: window_size = int(window_size) if window_size < 0: raise ValueError("Window size must be non-negative (0 or None to disable).") except (ValueError, TypeError) as e: msg=f"Error: Invalid sliding window size '{window_size}'. Must be a non-negative integer. Error: {e}" logging.error(msg) return msg effective_window_size = window_size if window_size > 0 else None config.sliding_window = effective_window_size config.sliding_window_size = effective_window_size if effective_window_size: msg = (f"Sliding Window Attention size set in config to: {effective_window_size}. " f"Requires model architecture support (e.g., Mistral) and potentially reloading the model.") else: msg = "Sliding Window Attention disabled in config (size set to 0 or None). Model reload may be needed." logging.warning(msg) return msg def _revert_sliding_window(model, config): if hasattr(config, 'sliding_window') and config.sliding_window is not None: config.sliding_window = None config.sliding_window_size = None msg = "Sliding Window Attention configuration removed from config. Model reload may be needed to revert behavior." logging.warning(msg) return msg else: config.sliding_window_size = None msg = "Sliding Window Attention was not configured. No changes made." logging.info(msg) return msg def _set_attention_variant_config(model, config, variant="auto"): valid_variants = ["auto", "eager", "sdpa", "flash_attention_2"] if not variant or not isinstance(variant, str) or variant not in valid_variants: msg = f"Error: Invalid attention variant '{variant}'. Choose from: {', '.join(valid_variants)}." logging.error(msg) return msg config.attn_implementation = variant config.attention_variant = variant config.use_flash_attention_2 = (variant == "flash_attention_2") msg = (f"Attention implementation preference set in config to: '{variant}'. " f"Effective implementation depends on model, hardware, and transformers version. **Requires model reload** to take effect.") logging.warning(msg) return msg def _revert_attention_variant(model, config): default_variant = "auto" current_variant = getattr(config, 'attn_implementation', default_variant) if current_variant != default_variant: config.attn_implementation = default_variant config.attention_variant = default_variant config.use_flash_attention_2 = False msg = f"Attention implementation preference reverted to '{default_variant}' in config. Model reload required." logging.warning(msg) return msg else: config.attention_variant = default_variant config.use_flash_attention_2 = False msg = f"Attention implementation preference is already '{default_variant}' or was not set. No changes made." 
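# --- Illustrative sketch (standalone aside) ---
# The attention preference set above only takes effect when the model is
# (re)loaded. In recent transformers versions that is done via the
# `attn_implementation` argument of from_pretrained; "sshleifer/tiny-gpt2" is
# just an assumed small checkpoint, and backend support varies by architecture.
def demo_reload_with_attention_variant(model_id: str = "sshleifer/tiny-gpt2",
                                       variant: str = "sdpa"):
    return AutoModelForCausalLM.from_pretrained(model_id, attn_implementation=variant)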
logging.info(msg) return msg def _enable_gradient_clipping(config): return _set_flag_only(config, "gradient_clipping_disabled", False, "Gradient Clipping Enabled (flag for Trainer).", "Gradient Clipping Disabled.") def _disable_gradient_clipping(config): return _set_flag_only(config, "gradient_clipping_disabled", True, "Gradient Clipping Enabled.", "Gradient Clipping Disabled (flag for Trainer).") def _enable_weight_decay(config): return _set_flag_only(config, "weight_decay_disabled", False, "Weight Decay Enabled (flag for Trainer).", "Weight Decay Disabled.") def _disable_weight_decay(config): return _set_flag_only(config, "weight_decay_disabled", True, "Weight Decay Enabled.", "Weight Decay Disabled (flag for Trainer).") def _enable_lr_scheduler(config): return _set_flag_only(config, "lr_scheduler_disabled", False, "LR Scheduler Enabled (flag for Trainer).", "LR Scheduler Disabled.") def _disable_lr_scheduler(config): return _set_flag_only(config, "lr_scheduler_disabled", True, "LR Scheduler Enabled.", "LR Scheduler Disabled (flag for Trainer).") def _enable_enhanced_security(config): return _set_flag_only(config, "enhanced_security_enabled", True, "Enhanced Security Enabled (symbolic flag).", "Enhanced Security Disabled.") def _disable_enhanced_security(config): return _set_flag_only(config, "enhanced_security_enabled", False, "Enhanced Security Enabled.", "Enhanced Security Disabled (symbolic flag).") def _enable_debug_mode(config): return _set_flag_only(config, "debug_mode_enabled", True, "Debug Mode Enabled (symbolic flag).", "Debug Mode Disabled.") def _disable_debug_mode(config): return _set_flag_only(config, "debug_mode_enabled", False, "Debug Mode Enabled.", "Debug Mode Disabled (symbolic flag).") def _enable_internal_usage_logging(config): return _set_flag_only(config, "internal_logging_enabled", True, "Internal Usage Logging Enabled (symbolic flag).", "Internal Logging Disabled.") def _disable_internal_usage_logging(config): return _set_flag_only(config, "internal_logging_enabled", False, "Internal Logging Enabled.", "Internal Logging Disabled (symbolic flag).") def _enable_drift_detection(config): return _set_flag_only(config, "drift_detection_enabled", True, "Drift Detection Enabled (symbolic flag).", "Drift Detection Disabled.") def _disable_drift_detection(config): return _set_flag_only(config, "drift_detection_enabled", False, "Drift Detection Enabled.", "Drift Detection Disabled (symbolic flag).") def _enable_auto_optimization(base_model, config): msg = "" if getattr(config, 'auto_optimization_enabled', False): msg = "Auto Optimization already enabled (flag was true)." logging.info(msg) return msg logging.info("Enabling Auto Optimization: Applying Quantization and Gradient Checkpointing...") device = get_device() quant_mode = 'bfloat16' if (device.type == 'cuda' and torch.cuda.is_bf16_supported()) else 'float16' if device.type == 'cpu': quant_mode = 'float32' quant_msg = _quantize_model(base_model, config, mode=quant_mode) gc_msg = _enable_gradient_checkpointing(base_model, config) config.auto_optimization_enabled = True msg = f"Auto Optimization Enabled. Quantization ({quant_mode}): {quant_msg}. Gradient Checkpointing: {gc_msg}" logging.info(msg) return msg def _disable_auto_optimization(config): if getattr(config, 'auto_optimization_enabled', False): config.auto_optimization_enabled = False logging.info("Auto Optimization Disabled (flag only). 
Applied optimizations (like quantization, GC) remain active unless manually reverted.") return "Auto Optimization Disabled (flag only)." else: logging.info("Auto Optimization was already disabled.") return "Auto Optimization already disabled." def _recover_perfect_precision(base_model, config): logging.info("Attempting to recover FP32 precision...") msg = _quantize_model(base_model, config, mode='float32') if getattr(config, 'perfect_precision_recovered', False): logging.info(f"Successfully recovered FP32 precision. Status: {msg}") return "Recovered FP32 Precision. " + msg else: logging.warning(f"FP32 precision recovery might have failed or model was already FP32. Status: {msg}") return "Attempted FP32 Precision Recovery. " + msg def _revert_perfect_precision(base_model, config): if not getattr(config, 'perfect_precision_recovered', False): return "Model not currently in FP32 mode according to flag (or flag is inconsistent)." device = get_device() mode_to_revert_to = 'bfloat16' if (device.type=='cuda' and torch.cuda.is_bf16_supported()) else 'float16' if device.type=='cuda' else 'float32' if mode_to_revert_to == 'float32': logging.info("Cannot revert from FP32 as the target revert type is also FP32 (e.g., on CPU).") return "Cannot revert from FP32 to lower precision on current device." logging.info(f"Reverting precision from FP32 (target: {mode_to_revert_to})...") msg = _quantize_model(base_model, config, mode=mode_to_revert_to) logging.info(f"Attempted precision revert from FP32: {msg}") return f"Reverted Precision from FP32 (attempted {mode_to_revert_to}). " + msg def _optimize_token_generation_speed(config): if not hasattr(config, '_original_do_sample'): config._original_do_sample = getattr(config, 'do_sample', True) if not hasattr(config, '_original_num_beams'): config._original_num_beams = getattr(config, 'num_beams', 1) if not hasattr(config, '_original_use_cache'): default_use_cache = True if hasattr(config, 'model_type'): if config.model_type == "t5" and getattr(config, 'gradient_checkpointing', False): default_use_cache = False config._original_use_cache = getattr(config, 'use_cache', default_use_cache) config.do_sample = False config.num_beams = 1 config.use_cache = True config.token_gen_speed_maximized = True logging.info("Token Generation Speed Optimized (Flags set for greedy decoding, num_beams=1, use_cache=True).") return "Token Speed Opt flags set (greedy, cache on)." def _revert_token_generation_speed_optimization(config): if not getattr(config, 'token_gen_speed_maximized', False): return "Token speed optimization not active according to flag." config.do_sample = getattr(config, '_original_do_sample', True) config.num_beams = getattr(config, '_original_num_beams', 1) config.use_cache = getattr(config, '_original_use_cache', True) config.token_gen_speed_maximized = False if hasattr(config, '_original_do_sample'): del config._original_do_sample if hasattr(config, '_original_num_beams'): del config._original_num_beams if hasattr(config, '_original_use_cache'): del config._original_use_cache logging.info("Token Generation Speed Optimization Reverted to previous/default flags.") return "Token Speed Optimization Reverted." def _add_peft_adapter(model, config, peft_config_obj=None): global global_model, current_peft_config if not _peft_installed: return "[Error] PEFT library (pip install peft) is not installed." if isinstance(model, PeftModel): return "[Warning] Model is already a PEFT model. Merge or remove existing adapters before adding a new one via this button." 
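# --- Illustrative sketch (standalone aside) ---
# The core of what _add_peft_adapter does below, without the bookkeeping: wrap a
# causal LM with a LoRA adapter built from this module's default PEFT settings.
# For architectures peft cannot infer target modules for, target_modules would
# need to be set explicitly in the config.
def demo_add_lora_adapter(base_model):
    if not _peft_installed:
        raise RuntimeError("peft is not installed")
    lora_cfg = LoraConfig(**copy.deepcopy(DEFAULT_PEFT_CONFIG_DICT))
    peft_model = get_peft_model(base_model, lora_cfg)
    peft_model.print_trainable_parameters()   # logs trainable vs. total parameter counts
    return peft_model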
if getattr(config, 'lora_merged', False): return "[Warning] LoRA adapters were previously merged into this model state. Adding new adapters might have unintended effects without reloading the original base model." try: if peft_config_obj and isinstance(peft_config_obj, (LoraConfig, PeftConfig)): peft_conf = peft_config_obj logging.info(f"Using provided PEFT config object: {peft_conf}") else: default_config_dict = copy.deepcopy(DEFAULT_PEFT_CONFIG_DICT) if not default_config_dict: raise ValueError("Default PEFT config is not available and no valid config provided.") peft_conf = LoraConfig(**default_config_dict) logging.info(f"Using default PEFT config: {peft_conf}") if hasattr(peft_conf, 'task_type') and peft_conf.task_type != TaskType.CAUSAL_LM: logging.warning(f"PEFT config task type is {peft_conf.task_type}, overriding to CAUSAL_LM for this platform.") peft_conf.task_type = TaskType.CAUSAL_LM elif not hasattr(peft_conf, 'task_type'): if isinstance(peft_conf, PeftConfig) and not isinstance(peft_conf, LoraConfig): peft_conf.task_type = TaskType.CAUSAL_LM peft_model = get_peft_model(model, peft_conf) base_model_config = peft_model.get_base_model().config base_model_config.peft_adapter_added = True base_model_config.peft_config = peft_conf.to_dict() base_model_config.lora_merged = False current_peft_config = peft_conf global_model = peft_model config = base_model_config trainable_params, all_params = peft_model.get_nb_trainable_parameters() logging.info( f"trainable params: {trainable_params:,d} || all params: {all_params:,d} || trainable%: {100 * trainable_params / all_params:.4f}" ) msg = f"PEFT adapter ({type(peft_conf).__name__}) added successfully. Model is ready for PEFT training." logging.info(msg) return msg except Exception as e: logging.error(f"Error adding PEFT adapter: {e}\n{traceback.format_exc()}") if hasattr(model, 'config'): model.config.peft_adapter_added = False model.config.peft_config = None return f"[Error] Failed to add PEFT adapter: {e}" def _remove_peft_adapter(model, config): global global_model, current_peft_config if not _peft_installed: return "[Error] PEFT library not installed." if not isinstance(model, PeftModel): if getattr(config, 'peft_adapter_added', False): logging.warning("Model is not a PeftModel instance, but PEFT flag was set. Resetting flags.") config.peft_adapter_added = False config.peft_config = None current_peft_config = {} return "[Warning] Reset PEFT flags as model was not a PeftModel instance." else: return "[Info] No PEFT adapter currently applied to the model." try: base_model = model.get_base_model() global_model = base_model config = base_model.config config.peft_adapter_added = False config.peft_config = None current_peft_config = {} msg = "PEFT adapter layers removed. Restored base model and reset PEFT config flags." logging.info(msg) clean_memory() return msg except Exception as e: logging.error(f"Error removing PEFT adapter: {e}\n{traceback.format_exc()}") return f"[Error] Failed to remove PEFT adapter: {e}" def _setup_multimodal(model, config, selected_modalities): global global_tokenizer if not selected_modalities: return "[Info] No modalities selected for setup." if getattr(config, 'multimodal_applied', False): current_modalities = getattr(config, 'supported_modalities', []) return f"[Warning] Multi-modal setup already applied for modalities: {current_modalities}. Revert first to change." 
logging.info(f"Attempting multi-modal setup for: {selected_modalities}") device = get_device() llm_hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) if not llm_hidden_size: return "[Error] Cannot setup multi-modal: LLM config missing 'hidden_size' or 'd_model'." if global_tokenizer is None: return "[Error] Cannot setup multi-modal: Global tokenizer not loaded." try: added_encoders = {} added_projections = {} added_special_tokens = {} new_tokens_added_to_tokenizer = [] current_modality_config = {} current_special_tokens_map = {} tokens_to_add_struct = [] for modality in selected_modalities: if modality not in MODALITY_ENCODERS: logging.warning(f"Skipping modality '{modality}': No predefined encoder found.") continue special_token = f"<{modality.upper()}>" if special_token not in global_tokenizer.get_vocab(): tokens_to_add_struct.append({'token': special_token, 'modality': modality}) else: token_id = global_tokenizer.convert_tokens_to_ids(special_token) current_special_tokens_map[modality] = {"token": special_token, "id": token_id} logging.info(f"Special token '{special_token}' for {modality} already exists (ID: {token_id}).") if tokens_to_add_struct: num_added = global_tokenizer.add_tokens([t['token'] for t in tokens_to_add_struct], special_tokens=True) if num_added > 0: logging.info(f"Added {num_added} new special tokens to tokenizer: {[t['token'] for t in tokens_to_add_struct]}") logging.info(f"Resizing LLM token embeddings from {model.config.vocab_size} to {len(global_tokenizer)}.") model.resize_token_embeddings(len(global_tokenizer)) if hasattr(config, 'vocab_size'): config.vocab_size = len(global_tokenizer) with torch.no_grad(): input_embeddings = model.get_input_embeddings() if input_embeddings is not None and hasattr(input_embeddings, 'weight'): avg_weight = input_embeddings.weight[:-num_added,:].mean(dim=0) input_embeddings.weight[-num_added:,:] = avg_weight logging.info(f"Initialized {num_added} new token embeddings with average weight.") for t_info in tokens_to_add_struct: modality = t_info['modality'] special_token = t_info['token'] token_id = global_tokenizer.convert_tokens_to_ids(special_token) current_special_tokens_map[modality] = {"token": special_token, "id": token_id} new_tokens_added_to_tokenizer.append(special_token) else: logging.error(f"Failed to add special tokens: {[t['token'] for t in tokens_to_add_struct]}. Aborting multi-modal setup.") return "[Error] Failed to add required special tokens to tokenizer." 
successful_modalities = [] for modality in selected_modalities: if modality not in MODALITY_ENCODERS: continue encoder_id = MODALITY_ENCODERS[modality] encoder_attr_name = f"{modality.lower()}_encoder" projection_attr_name = f"{modality.lower()}_projection" try: logging.info(f"Loading {modality} encoder: {encoder_id}") encoder = AutoModel.from_pretrained(encoder_id, trust_remote_code=True) encoder = encoder.to(device).eval() for param in encoder.parameters(): param.requires_grad = False added_encoders[encoder_attr_name] = encoder setattr(model, encoder_attr_name, encoder) encoder_hidden_size = _get_encoder_hidden_size(encoder_id, trust_remote_code=True) logging.info(f"Creating projection layer for {modality}: {encoder_hidden_size} -> {llm_hidden_size}") projection = nn.Linear(encoder_hidden_size, llm_hidden_size).to(device) nn.init.xavier_uniform_(projection.weight) if projection.bias is not None: nn.init.zeros_(projection.bias) added_projections[projection_attr_name] = projection setattr(model, projection_attr_name, projection) current_modality_config[modality] = encoder_id successful_modalities.append(modality) except Exception as mod_e: logging.error(f"Failed to setup modality '{modality}' with encoder '{encoder_id}': {mod_e}") if hasattr(model, encoder_attr_name): delattr(model, encoder_attr_name) if hasattr(model, projection_attr_name): delattr(model, projection_attr_name) if successful_modalities: config.multimodal_applied = True config.supported_modalities = successful_modalities config.modality_encoders = current_modality_config config.modality_projection_dim = llm_hidden_size config.modality_special_tokens = current_special_tokens_map msg = (f"Multi-modal setup partially/fully applied for: {successful_modalities}. " f"Added {len(added_encoders)} encoders and {len(added_projections)} projections. " f"Added/mapped {len(current_special_tokens_map)} special tokens. ") logging.warning(msg) return msg else: config.multimodal_applied = False return "[Error] Multi-modal setup failed for all selected modalities." except Exception as e: logging.error(f"Error during multi-modal setup: {e}\n{traceback.format_exc()}") for name in added_encoders.keys(): if hasattr(model, name): delattr(model, name) for name in added_projections.keys(): if hasattr(model, name): delattr(model, name) config.multimodal_applied = False config.supported_modalities = [] config.modality_encoders = {} config.modality_projection_dim = None config.modality_special_tokens = {} return (f"[Error] Multi-modal setup failed: {e}. Attempted cleanup, state might be inconsistent " "(tokenizer/embeddings may remain changed). Reload original model/tokenizer for full reset.") def _revert_multimodal(model, config): if not getattr(config, 'multimodal_applied', False): return "[Info] Multi-modal setup not applied according to config." modalities_to_revert = getattr(config, 'supported_modalities', []) if not modalities_to_revert: config.multimodal_applied = False config.modality_encoders = {} config.modality_projection_dim = None config.modality_special_tokens = {} return "[Info] No supported modalities listed in config to revert, but flag was true. Resetting flags." 
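# Revert-path note: only the attached encoder/projection modules and the config flags are removed below;
# special tokens added to the tokenizer and the resized embedding matrix are intentionally left in place,
# so a full reset requires reloading the original model and tokenizer.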
logging.info(f"Reverting multi-modal setup for modalities: {modalities_to_revert}") removed_count = 0 errors = [] try: for modality in modalities_to_revert: encoder_attr_name = f"{modality.lower()}_encoder" projection_attr_name = f"{modality.lower()}_projection" try: if hasattr(model, encoder_attr_name): delattr(model, encoder_attr_name) logging.info(f"Removed encoder: {encoder_attr_name}") removed_count += 1 if hasattr(model, projection_attr_name): delattr(model, projection_attr_name) logging.info(f"Removed projection: {projection_attr_name}") removed_count += 1 except Exception as del_e: error_msg = f"Error removing components for modality '{modality}': {del_e}" logging.error(error_msg) errors.append(error_msg) config.multimodal_applied = False config.supported_modalities = [] config.modality_encoders = {} config.modality_projection_dim = None config.modality_special_tokens = {} logging.warning("Multi-modal components removed. **Special tokens added to tokenizer and potentially resized embeddings remain.** Reload original model/tokenizer if full reversion needed.") clean_memory() final_msg = f"Multi-modal setup reverted ({removed_count} components removed, flags reset). Embeddings/tokenizer not shrunk." if errors: final_msg += f" Errors encountered: {'; '.join(errors)}" return final_msg except Exception as e: logging.error(f"Error reverting multi-modal setup: {e}\n{traceback.format_exc()}") config.multimodal_applied = False config.supported_modalities = [] config.modality_encoders = {} config.modality_projection_dim = None config.modality_special_tokens = {} return f"[Error] Reverting multi-modal setup failed: {e}. Flags reset." def auto_extract_text_universal(data_item): if isinstance(data_item, str): return data_item.strip().replace('\\n', '\n') elif isinstance(data_item, bytes): try: return data_item.decode('utf-8', errors='replace').strip().replace('\\n', '\n') except Exception: return "" elif isinstance(data_item, (list, tuple)): texts = [auto_extract_text_universal(item) for item in data_item] return " ".join(filter(None, texts)) elif isinstance(data_item, dict): texts = [] potential_keys = [ 'text', 'content', 'sentence', 'paragraph', 'article', 'abstract', 'summary', 'body', 'passage', 'document', 'script', 'dialogue', 'instruction', 'input', 'output', 'query', 'response', 'title', 'question', 'answer', 'prompt', 'completion', 'target', 'label', 'review', 'comment', 'post', 'code', 'markdown' ] processed_keys = set() for key in potential_keys: if key in data_item and key not in processed_keys: value = data_item[key] extracted = auto_extract_text_universal(value) if extracted: texts.append(extracted) processed_keys.add(key) if not texts: for key, value in data_item.items(): if key not in processed_keys: extracted = auto_extract_text_universal(value) if extracted: texts.append(extracted) processed_keys.add(key) seen = set() unique_texts = [] for t in texts: if t and t not in seen: unique_texts.append(t) seen.add(t) return "\n".join(unique_texts) elif isinstance(data_item, (int, float, bool)) or data_item is None: return "" else: try: return str(data_item).strip().replace('\\n', '\n') except Exception: return "" def process_example_universal(example): extracted_text = auto_extract_text_universal(example) return {"text": extracted_text if extracted_text else "[EMPTY_OR_NON_TEXTUAL]"} def parse_datasets(dataset_text): datasets = [] seen_ids = set() for line_num, line in enumerate(dataset_text.strip().splitlines()): line = line.strip() if not line or line.startswith('#'): continue 
parts = [s.strip() for s in line.split(",") if s.strip()] ds_name = None ds_config = None ds_split = 'train' ds_weight = 1.0 if len(parts) >= 1: ds_name = parts[0] if len(parts) >= 2 and parts[1]: ds_config = parts[1] if parts[1].lower() != 'none' else None if len(parts) >= 3 and parts[2]: ds_split = parts[2] if len(parts) >= 4: try: ds_weight = float(parts[3]) if ds_weight <= 0: raise ValueError("Weight must be positive") except (ValueError, IndexError): logging.warning(f"Invalid or missing weight '{parts[3] if len(parts) >= 4 else ''}' on line {line_num+1} ('{line}'). Using default 1.0.") ds_weight = 1.0 if ds_name: dataset_id = f"{ds_name}_{ds_config or 'DEFAULT'}_{ds_split}" if dataset_id in seen_ids: logging.warning(f"Skipping duplicate dataset entry: {dataset_id} on line {line_num+1}") continue datasets.append({"id": ds_name, "config": ds_config, "split": ds_split, "weight": ds_weight}) seen_ids.add(dataset_id) else: logging.warning(f"Skipping invalid dataset line (no name found): '{line}' on line {line_num+1}") if not datasets: raise ValueError("No valid dataset configurations were parsed from the input.") return datasets def load_datasets_from_config(datasets_config): ds_list = [] loaded_configs = [] total_weight = 0.0 logging.info(f"Attempting to load datasets based on config: {datasets_config}") for config_entry in datasets_config: ds_name = config_entry['id'] ds_config = config_entry['config'] ds_split = config_entry['split'] ds_weight = config_entry['weight'] dataset_identifier = f"{ds_name}{'['+ds_config+']' if ds_config else ''} (Split: {ds_split}, Weight: {ds_weight})" try: logging.info(f"Loading {dataset_identifier}...") d = load_dataset( ds_name, ds_config, streaming=True, split=ds_split, trust_remote_code=True, ) try: peek = next(iter(d)) original_columns = list(peek.keys()) d = load_dataset(ds_name, ds_config, streaming=True, split=ds_split, trust_remote_code=True) except StopIteration: logging.warning(f"Dataset stream appears empty after loading: {dataset_identifier}. Skipping.") continue except Exception as peek_e: logging.warning(f"Could not reliably peek into dataset {dataset_identifier} to get columns: {peek_e}. Will attempt processing without column removal.") original_columns = None logging.info(f"Processing {dataset_identifier} (Original cols: {original_columns or 'unknown'}) -> Map to 'text' field") process_partial = partial(process_example_universal) processed_d = d.map(process_partial, remove_columns=original_columns) processed_d = processed_d.filter(lambda example: example.get("text") != "[EMPTY_OR_NON_TEXTUAL]") shuffled_d = processed_d.shuffle(buffer_size=10000, seed=42) ds_list.append(shuffled_d) loaded_configs.append(config_entry) total_weight += ds_weight logging.info(f"Successfully prepared stream: {dataset_identifier}") except (requests.exceptions.RequestException, gzip.BadGzipFile) as http_e: logging.error(f"Network or File Error loading dataset {dataset_identifier}: {http_e}. Check connection and dataset validity. Skipping.") except FileNotFoundError: logging.error(f"Dataset or config not found for {dataset_identifier}. Check name/config/path. Skipping.") except Exception as e: logging.error(f"General Error loading/processing dataset {dataset_identifier}: {e} \n{traceback.format_exc()}. Skipping.") if not ds_list: raise ValueError("No valid datasets were loaded. 
Check dataset names, configurations, splits, availability, and network connection.") logging.info(f"Successfully loaded {len(ds_list)} dataset streams.") if total_weight <= 0 or len(loaded_configs) != len(ds_list): probabilities = [1.0 / len(ds_list)] * len(ds_list) if ds_list else [] logging.warning("Using equal probabilities for interleaving due to zero total weight, loading errors, or no datasets.") else: probabilities = [cfg['weight'] / total_weight for cfg in loaded_configs] prob_sum = sum(probabilities) if abs(prob_sum - 1.0) > 1e-6: probabilities = [p / prob_sum for p in probabilities] if not ds_list: logging.warning("No datasets to interleave.") return None logging.info(f"Interleaving {len(ds_list)} datasets with probabilities: {[f'{p:.3f}' for p in probabilities]}") interleaved_ds = interleave_datasets(ds_list, probabilities=probabilities, seed=42, stopping_strategy="all_exhausted") return interleaved_ds def tokenize_function(examples, tokenizer, context_length): texts = [str(t) if t is not None else "" for t in examples["text"]] tokenized_output = tokenizer(texts, truncation=False, padding=False) return tokenized_output def group_texts(examples, block_size): concatenated_examples = {k: sum(examples[k], []) if isinstance(examples[k][0], list) else examples[k] for k in examples} total_length = len(concatenated_examples[list(examples.keys())[0]]) if total_length >= block_size: total_length = (total_length // block_size) * block_size else: return {k: [] for k in examples.keys()} result = { k: [t[i : i + block_size] for i in range(0, total_length, block_size)] for k, t in concatenated_examples.items() } result["labels"] = result["input_ids"].copy() return result def split_dataset(processed_lm_iterable_dataset): eval_buffer_size = 1000 shuffle_buffer_size = 10000 logging.info(f"Preparing train/eval split. Eval buffer: {eval_buffer_size}, Shuffle buffer: {shuffle_buffer_size}..."); T = time.time() if not isinstance(processed_lm_iterable_dataset, IterableDataset): logging.error("Input dataset is not an IterableDataset. Cannot perform stream-based splitting.") raise TypeError("Input to split_dataset must be an IterableDataset.") shuffled_ds = processed_lm_iterable_dataset.shuffle(seed=42, buffer_size=shuffle_buffer_size) logging.info(f"Taking up to {eval_buffer_size} samples for the evaluation buffer...") eval_samples_iter = shuffled_ds.take(eval_buffer_size) try: eval_list = list(eval_samples_iter) num_eval_samples = len(eval_list) except Exception as e: logging.error(f"Error collecting evaluation samples: {e}. 
Proceeding without evaluation set.") num_eval_samples = 0 eval_list = [] train_ds = shuffled_ds eval_ds_static = None if num_eval_samples > 0: logging.info(f"Collected {num_eval_samples} samples for evaluation buffer.") train_ds = shuffled_ds.skip(num_eval_samples) logging.info("Training stream prepared (skipped eval samples).") logging.info("Creating static evaluation dataset from buffer...") try: if not eval_list: raise ValueError("Evaluation buffer list is empty after take().") first_example = eval_list[0] if not isinstance(first_example, dict): raise ValueError("Eval buffer items are not dictionaries.") expected_keys = ['input_ids', 'attention_mask', 'labels'] eval_features_dict = {} for key in expected_keys: if key not in first_example: raise ValueError(f"Eval buffer items missing required key: '{key}'") try: from datasets import Sequence inner_dtype = 'int64' if isinstance(first_example[key], list) and first_example[key] and isinstance(first_example[key][0], int): eval_features_dict[key] = Sequence(feature=Value(dtype=inner_dtype)) else: eval_features_dict[key] = Value(dtype='list') except ImportError: eval_features_dict[key] = Value(dtype='list') if not eval_features_dict: raise ValueError("Could not define features for evaluation dataset.") eval_features = Features(eval_features_dict) valid_eval_list = [] required_keys_set = set(eval_features.keys()) for i, ex in enumerate(eval_list): if isinstance(ex, dict) and set(ex.keys()) >= required_keys_set: is_valid = all(isinstance(ex.get(k), list) for k in required_keys_set) if is_valid: valid_eval_list.append({k: ex[k] for k in required_keys_set}) else: logging.warning(f"Eval buffer item {i} has invalid type for required keys. Skipping.") else: logging.warning(f"Eval buffer item {i} is invalid (not dict or missing keys). Skipping.") if not valid_eval_list: logging.warning("No valid examples remained in the evaluation buffer after validation. Eval dataset will be None.") eval_ds_static = None train_ds = shuffled_ds else: eval_ds_static = Dataset.from_list(valid_eval_list, features=eval_features) logging.info(f"Created static evaluation dataset with {len(eval_ds_static)} examples.") except Exception as e: logging.error(f"Error creating static evaluation dataset from buffer: {e}\n{traceback.format_exc()}. Evaluation dataset will be None.") eval_ds_static = None train_ds = shuffled_ds else: logging.warning("Evaluation buffer is empty (requested size might be too large or dataset too small). Training will continue without evaluation.") logging.info(f"Dataset splitting completed in {time.time()-T:.2f}s") return train_ds, eval_ds_static def compute_perplexity(loss): if loss is None or not isinstance(loss, (int, float)) or not math.isfinite(loss): return float("inf") try: clamped_loss = min(max(loss, -700.0), 700.0) perplexity = math.exp(clamped_loss) if not math.isfinite(perplexity): logging.warning(f"Perplexity calculation resulted in infinity for loss {loss} (clamped: {clamped_loss}).") return float("inf") return perplexity except OverflowError: logging.warning(f"OverflowError computing perplexity for loss {loss}. Returning infinity.") return float("inf") except Exception as e: logging.warning(f"Error computing perplexity for loss {loss}: {e}. Returning infinity.") return float("inf") def merge_model_parameters(original_model, trained_model, alpha=MERGE_ALPHA): if not (0 <= alpha <= 1): logging.error(f"Merge alpha must be between 0 and 1. Got {alpha}. 
Defaulting to 0.5") alpha = 0.5 logging.info(f"Merging model parameters with alpha={alpha:.2f} (alpha*original + (1-alpha)*trained using linear interpolation)..."); T = time.time(); device = get_device() original_model = original_model.to(device) trained_model = trained_model.to(device) merged_model = copy.deepcopy(original_model).to(device) merged_params_count = 0 skipped_params_count = 0 orig_params = dict(original_model.named_parameters()) trained_params = dict(trained_model.named_parameters()) merged_params = dict(merged_model.named_parameters()) with torch.no_grad(): for name, trained_param in trained_params.items(): if name in orig_params and name in merged_params: orig_param = orig_params[name] merged_param = merged_params[name] if orig_param.data.shape == trained_param.data.shape: merged_tensor = torch.lerp(trained_param.data.float(), orig_param.data.float(), alpha) merged_param.copy_(merged_tensor.to(merged_param.dtype)) merged_params_count += 1 else: logging.warning(f"Size mismatch for parameter '{name}'. Original: {orig_param.data.shape}, Trained: {trained_param.data.shape}. Skipping merge for this parameter.") skipped_params_count += 1 else: if name not in orig_params: logging.warning(f"Parameter '{name}' from trained model not found in original model structure. Skipping.") if name not in merged_params: logging.warning(f"Parameter '{name}' from trained model not found in merged model structure (shouldn't happen). Skipping.") skipped_params_count += 1 logging.info(f"Parameter merging finished in {time.time()-T:.2f}s. Merged {merged_params_count} parameters, skipped {skipped_params_count}.") return merged_model def preserve_model_quality(original_model, trained_model, eval_dataset, tokenizer): if eval_dataset is None: logging.warning("No evaluation data provided (eval_dataset is None). Cannot perform quality check. Returning trained model.") return trained_model is_iterable = isinstance(eval_dataset, IterableDataset) if is_iterable: logging.warning("Evaluation dataset is iterable. Loss comparison might not be on the exact same data. Proceeding with caution.") try: _ = next(iter(eval_dataset.take(1))) except StopIteration: logging.warning("Iterable evaluation dataset appears empty. Returning trained model.") return trained_model except Exception as e: logging.warning(f"Could not peek into iterable eval dataset: {e}. Assuming not empty.") elif isinstance(eval_dataset, Dataset): if len(eval_dataset) == 0: logging.warning("Evaluation dataset is empty (length 0). Returning trained model.") return trained_model else: logging.warning(f"Unknown evaluation dataset type: {type(eval_dataset)}. Cannot perform quality check. 
Returning trained model.") return trained_model data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) eval_batch_size = max(1, BATCH_SIZE // 2) device = get_device() original_model.to(device).eval() trained_model.to(device).eval() was_training_orig = original_model.training trained_model.train() temp_eval_dir = "./tmp_eval_quality_check" eval_args = TrainingArguments( output_dir=temp_eval_dir, per_device_eval_batch_size=eval_batch_size, report_to=[], dataloader_num_workers=max(1, (NUM_CPU_CORES if NUM_CPU_CORES > 0 else os.cpu_count() // 2)), fp16=torch.cuda.is_available() and not USE_CPU and original_model.dtype == torch.float16, bf16=(torch.cuda.is_available() and torch.cuda.is_bf16_supported()) and original_model.dtype == torch.bfloat16, use_cpu=USE_CPU, log_level='error', remove_unused_columns=False, ) results = {} eval_error = False for model_name, model_instance in [("Original", original_model), ("Trained", trained_model)]: logging.info(f"Evaluating {model_name} model for quality check..."); T_eval = time.time() current_eval_dataset = eval_dataset if is_iterable: current_eval_dataset = eval_dataset.take(1000) # Hardcoded eval buffer size try: if len(list(iter(current_eval_dataset.take(1)))) == 0: logging.warning(f"Iterable eval sample for {model_name} is empty. Skipping eval.") results[model_name] = {"loss": float('inf'), "ppl": float('inf')} continue current_eval_dataset = eval_dataset.take(1000) # Hardcoded eval buffer size except Exception as e: logging.error(f"Error handling iterable dataset sample for {model_name}: {e}") results[model_name] = {"loss": float('inf'), "ppl": float('inf')} eval_error = True; break trainer = Trainer( model=model_instance, args=eval_args, data_collator=data_collator, eval_dataset=current_eval_dataset ) try: model_instance.eval() metrics = trainer.evaluate() loss = metrics.get("eval_loss") ppl = compute_perplexity(loss) results[model_name] = {"loss": loss if loss is not None else float('inf'), "ppl": ppl} logging.info(f"{model_name} Eval Loss: {loss if loss is not None else 'N/A':.4f}, PPL: {ppl:.4f} (Eval time: {time.time()-T_eval:.2f}s)") except StopIteration: logging.error(f"Evaluation dataset exhausted unexpectedly during evaluation of {model_name}. Comparison may be incomplete.") results[model_name] = {"loss": float('inf'), "ppl": float('inf')} eval_error = True; break except Exception as e: logging.error(f"Error evaluating {model_name} model: {e}\n{traceback.format_exc()}") results[model_name] = {"loss": float('inf'), "ppl": float('inf')} eval_error = True; break if os.path.exists(temp_eval_dir): try: shutil.rmtree(temp_eval_dir) except Exception as e: logging.warning(f"Could not remove temporary eval directory {temp_eval_dir}: {e}") original_model.train(mode=was_training_orig) trained_model.train() original_loss = results.get("Original", {}).get("loss", float('inf')) trained_loss = results.get("Trained", {}).get("loss", float('inf')) if eval_error: logging.error("Evaluation encountered errors. Cannot reliably compare models. Returning trained model.") return trained_model valid_comparison = math.isfinite(original_loss) and math.isfinite(trained_loss) if valid_comparison: loss_threshold = original_loss * 1.05 if trained_loss > loss_threshold: logging.warning(f"Trained model loss ({trained_loss:.4f}) is significantly worse (>5%) than original ({original_loss:.4f}). 
Reverting to original model state based on quality check.") return original_model.to(device) elif trained_loss > original_loss: logging.info(f"Trained model loss ({trained_loss:.4f}) is slightly worse than original ({original_loss:.4f}), but within threshold. Keeping trained model.") return trained_model.to(device) else: logging.info(f"Trained model loss ({trained_loss:.4f}) is better than or equal to original ({original_loss:.4f}). Keeping trained model.") return trained_model.to(device) else: logging.warning("Could not perform valid loss comparison (one or both evaluations failed or yielded non-finite loss). Returning trained model.") return trained_model.to(device) def _merge_architectures(model_ids_str, hf_token=None, bypass_limits_state=False): global global_model, global_tokenizer, config, global_pipe, BYPASS_RESOURCE_LIMITS BYPASS_RESOURCE_LIMITS = bypass_limits_state if not isinstance(model_ids_str, str) or not model_ids_str.strip(): return "[Error] Model IDs string cannot be empty.", "{}", *get_error_filter_updates() resources_ok, res_msg = check_resources() if not resources_ok: error_msg = f"[Error] Resource limits exceeded, cannot proceed with merge. {res_msg}" logging.error(error_msg) return error_msg, "{}", *get_error_filter_updates() else: logging.info(res_msg) model_ids = [m.strip() for m in model_ids_str.split(',') if m.strip()] if len(model_ids) < 2: return "[Error] Need at least two valid model IDs/paths separated by commas to merge.", "{}", *get_error_filter_updates() logging.info(f"Starting architecture merge (parameter averaging) for models: {model_ids}") device = get_device() merged_model = None t_merge_start = time.time() base_model_id = model_ids[0] try: logging.info(f"Loading base config and tokenizer from: {base_model_id}") base_tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) base_config = AutoConfig.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) if base_tokenizer.pad_token is None and base_tokenizer.eos_token is not None: base_tokenizer.pad_token = base_tokenizer.eos_token base_config.pad_token_id = base_config.eos_token_id logging.info("Set base tokenizer pad_token to eos_token for consistency.") except Exception as e: logging.error(f"Failed to load base config/tokenizer for {base_model_id}: {e}") return f"[Error] Failed to load base model config/tokenizer: {e}", "{}", *get_error_filter_updates() try: logging.info(f"Loading base model state dict (CPU, float32) for merging: {base_model_id}") base_model = AutoModelForCausalLM.from_pretrained( base_model_id, trust_remote_code=True, token=hf_token, torch_dtype=torch.float32, low_cpu_mem_usage=True ) base_state_dict = base_model.state_dict() merged_state_dict = OrderedDict((k, v.clone()) for k, v in base_state_dict.items()) param_counts = OrderedDict((k, 1) for k in base_state_dict) num_models_processed = 1 del base_model, base_state_dict clean_memory() except Exception as e: logging.error(f"Failed to load base model state dict for {base_model_id}: {e}") return f"[Error] Failed to load base model state dict: {e}", "{}", *get_error_filter_updates() for i, model_id in enumerate(model_ids[1:]): logging.info(f"Processing model {i+2}/{len(model_ids)}: {model_id}") try: model_i = AutoModelForCausalLM.from_pretrained( model_id, trust_remote_code=True, token=hf_token, torch_dtype=torch.float32, low_cpu_mem_usage=True ) state_dict_i = model_i.state_dict() for name, param_i in state_dict_i.items(): if name in merged_state_dict: if 
merged_state_dict[name].shape == param_i.shape: merged_state_dict[name].add_(param_i) param_counts[name] += 1 else: logging.warning(f"Shape mismatch for parameter '{name}' between base and {model_id}. Base: {merged_state_dict[name].shape}, Current: {param_i.shape}. Parameter '{name}' will NOT include contribution from {model_id}.") else: logging.warning(f"Parameter '{name}' found in {model_id} but not in base model {base_model_id}. Skipping this parameter.") num_models_processed += 1 del model_i, state_dict_i clean_memory() except Exception as e: logging.error(f"Failed to load or process model {model_id}: {e}. Skipping this model for merge.") continue if num_models_processed < 2: msg = "Merge failed: Fewer than two models were successfully loaded and processed." logging.error(msg) return f"[Error] {msg}", "{}", *get_error_filter_updates() averaged_count = 0 for name in merged_state_dict: count = param_counts.get(name, 0) if count > 0: merged_state_dict[name].div_(count) averaged_count +=1 else: logging.error(f"Parameter '{name}' has count {count} <= 0 during averaging. This indicates a logic error.") logging.info(f"Averaged {averaged_count} parameters across {num_models_processed} successfully processed models.") try: logging.info("Creating final merged model from base config and averaged weights...") merged_model = AutoModelForCausalLM.from_config(base_config, trust_remote_code=True) load_results = merged_model.load_state_dict(merged_state_dict, strict=False) if load_results.missing_keys: logging.warning(f"Load state dict results: Missing keys: {load_results.missing_keys}") if load_results.unexpected_keys: logging.warning(f"Load state dict results: Unexpected keys: {load_results.unexpected_keys}") final_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32 logging.info(f"Converting merged model to {final_dtype} for use on device {device}.") global_model = merged_model.to(device=device, dtype=final_dtype) global_tokenizer = base_tokenizer config = initialize_config_flags(global_model.config) config.architecture_merged = True config.merged_from_models = model_ids config.merged_models_processed = num_models_processed update_pipeline() clean_memory() final_status_json, *filter_updates = get_detailed_status_and_filter_states() merge_time = time.time() - t_merge_start msg = f"Successfully merged architectures (averaged parameters) from {num_models_processed} models in {merge_time:.2f}s. Base config/tokenizer from: {base_model_id}." logging.info(msg) return msg, final_status_json, *filter_updates except Exception as e: logging.error(f"Architecture merging failed during final model creation or state update: {e}\n{traceback.format_exc()}") global_model = None global_tokenizer = None config = None global_pipe = None clean_memory() return f"[Error] Architecture merging failed: {e}", "{}", *get_error_filter_updates() def get_user_id(token): if not token: logging.warning("No Hugging Face token provided for user ID check.") return "unknown_user" try: api = HfApi() user_info = api.whoami(token=token) return user_info.get("name", "unknown_user") except requests.exceptions.HTTPError as http_err: if http_err.response.status_code == 401: logging.error("Hugging Face authentication failed (401 Unauthorized). 
Check your token.") return "auth_error_user" else: logging.error(f"HTTP error retrieving Hugging Face user ID: {http_err}") return "http_error_user" except Exception as e: logging.error(f"Could not retrieve Hugging Face user ID: {e}") return "unknown_user" def decode_model_details(model): if model is None: return json.dumps({"Error": "Model not loaded."}, indent=2) if not hasattr(model, 'config'): logging.warning("Model object lacks a 'config' attribute.") details = OrderedDict() details["Model Class"] = type(model).__name__ details["Error"] = "Model config attribute not found." return json.dumps(details, indent=2) details = OrderedDict() config_obj = model.config t_start_decode = time.time() logging.info("Decoding model details...") try: details["Model Class"] = type(model).__name__ details["Config Class"] = getattr(config_obj, 'config_class', type(config_obj).__name__) details["Model Type"] = getattr(config_obj, 'model_type', 'N/A') total_params = 0 trainable_params = 0 param_dtypes = set() param_devices = set() try: for name, param in model.named_parameters(): num_elements = param.numel() total_params += num_elements param_dtypes.add(str(param.dtype).replace('torch.', '')) param_devices.add(str(param.device)) if param.requires_grad: trainable_params += num_elements if not param_devices: device_str = "N/A (No parameters)" elif len(param_devices) == 1: device_str = param_devices.pop() else: device_str = f"Multiple ({', '.join(param_devices)})" except Exception as e: logging.warning(f"Could not fully analyze parameters: {e}") device_str = "Error analyzing params" details["Device(s)"] = device_str trainable_perc = (100 * trainable_params / total_params) if total_params > 0 else 0.00 details["Params Summary"] = (f"Total: {total_params:,}, Trainable: {trainable_params:,} " f"({trainable_perc:.2f}%), Dtypes: {list(param_dtypes)}") try: layer_counts = Counter(type(m).__name__ for m in model.modules() if not isinstance(m, nn.Sequential)) details["Layer Types Count"] = dict(layer_counts.most_common(15)) except Exception as e: logging.warning(f"Could not count layer types: {e}") details["Layer Types Count"] = "Error counting layers" details["Modification Flags"] = {} all_flags = initialize_config_flags(None).__dict__.keys() for flag in sorted(all_flags): if hasattr(config_obj, flag): value = getattr(config_obj, flag) details["Modification Flags"][flag] = value details["Key Config Attributes"] = {} key_attrs = ['vocab_size', 'hidden_size', 'num_hidden_layers', 'num_attention_heads', 'intermediate_size', 'max_position_embeddings', 'hidden_act', 'layer_norm_eps', 'rms_norm_eps', 'attention_dropout', 'hidden_dropout_prob', 'initializer_range', 'tie_word_embeddings', 'rope_scaling', 'sliding_window', 'attn_implementation'] for attr in key_attrs: if hasattr(config_obj, attr): details["Key Config Attributes"][attr] = getattr(config_obj, attr) logging.info(f"Model details decoded in {time.time() - t_start_decode:.2f}s") return json.dumps(details, indent=2, default=str) except Exception as e: logging.error(f"Error decoding model details: {e} \n{traceback.format_exc()}") details["Error"] = f"Failed during detail decoding: {e}" return json.dumps(details, indent=2, default=str) def update_pipeline(): global global_model, global_tokenizer, global_pipe if global_model and global_tokenizer: device = get_device() pipeline_device_arg = None device_map = None if device.type == 'cpu': pipeline_device_arg = -1 logging.info("Configuring pipeline for CPU.") elif device.type == 'cuda': if torch.cuda.device_count() > 1: 
device_map = "auto" pipeline_device_arg = None logging.info("Multiple GPUs detected, configuring pipeline with device_map='auto'.") else: pipeline_device_arg = 0 logging.info("Configuring pipeline for single CUDA device (device=0).") elif device.type == 'mps': pipeline_device_arg = 0 logging.info("Configuring pipeline for MPS device (device=0).") else: pipeline_device_arg = -1 logging.warning(f"Unknown device type '{device.type}', configuring pipeline for CPU.") logging.info(f"Updating text generation pipeline (Device Arg: {pipeline_device_arg}, Device Map: {device_map})..."); T=time.time() try: if device_map is None and pipeline_device_arg is not None: if pipeline_device_arg == -1: global_model.to('cpu') elif device.type == 'cuda': global_model.to(f'cuda:{pipeline_device_arg}') elif device.type == 'mps': global_model.to('mps:0') task = "text-generation" global_pipe = pipeline( task=task, model=global_model, tokenizer=global_tokenizer, device=pipeline_device_arg, device_map=device_map ) pipe_device_str = "N/A" if global_pipe.device_map: pipe_device_str = f"device_map: {global_pipe.device_map}" elif global_pipe.device: pipe_device_str = str(global_pipe.device) logging.info(f"Text generation pipeline created/updated. Effective device(s): {pipe_device_str}") if device_map is None and global_pipe.device != device: logging.warning(f"Pipeline created on {global_pipe.device}, but target device was {device}. This might happen with device_map issues or insufficient VRAM.") msg = f"Text generation pipeline updated successfully in {time.time()-T:.2f}s."; logging.info(msg) return msg except Exception as e: msg=f"Pipeline update failed: {e}\n{traceback.format_exc()}"; logging.error(msg); global_pipe = None; return f"[Error] Pipeline update failed: {e}" else: msg = "Cannot update pipeline: Global model or tokenizer not loaded."; logging.warning(msg); global_pipe = None; return msg def get_detailed_status_and_filter_states(): global global_model, config t_start = time.time() if global_model is None: logging.warning("Cannot get status: Model not loaded.") return json.dumps({"Error": "Model not loaded."}, indent=2), *get_error_filter_updates() if not hasattr(global_model, 'config') or global_model.config is None: logging.warning("Model config missing. Initializing default flags for status check.") temp_config = initialize_config_flags(None) status_json = json.dumps({"Warning": "Model config missing, status reflects defaults.", **json.loads(decode_model_details(global_model))}, indent=2) config_to_check = temp_config else: config = global_model.config config = initialize_config_flags(config) global_model.config = config status_json = decode_model_details(global_model) config_to_check = config logging.info("Refreshing detailed model status and filter checkbox states...") filter_states = {} for name in filter_names_ui: attr_name = filter_attr_map.get(name) if attr_name: filter_states[name] = getattr(config_to_check, attr_name, False) else: logging.error(f"Filter name '{name}' not found in attribute map. 
Setting state to False.") filter_states[name] = False updates = [gr.update(value=filter_states.get(name, False)) for name in filter_names_ui] logging.info(f"Refreshed status and filter states in {time.time()-t_start:.2f}s."); return status_json, *updates def get_error_filter_updates(): return [gr.update(value=False) for _ in filter_names_ui] def base_toggle_function(func_enable, func_disable, enable, success_msg_enable, success_msg_disable, *args): global global_model, config t_start = time.time() if not global_model: return "[Error] Model not loaded. Load a model first." if not hasattr(global_model, 'config') or global_model.config is None: logging.warning("Model config missing. Initializing default flags before toggle.") global_model.config = initialize_config_flags(None) config = initialize_config_flags(global_model.config) global_model.config = config msg = "" func_to_call = func_enable if enable else func_disable action_name = "Enable" if enable else "Disable" func_name = getattr(func_enable, '__name__', 'unknown_enable').replace('_', ' ').title() if enable else \ getattr(func_disable, '__name__', 'unknown_disable').replace('_', ' ').title() logging.info(f"Executing toggle: {action_name} {func_name}...") try: sig = inspect.signature(func_to_call) pass_args = [] if 'model' in sig.parameters or 'base_model' in sig.parameters or 'module' in sig.parameters: pass_args.append(global_model) if 'config' in sig.parameters: pass_args.append(config) pass_args.extend(args) result = func_to_call(*pass_args) if isinstance(result, str) and "[Error]" not in result: msg = result elif isinstance(result, str): msg = result else: msg = success_msg_enable if enable else success_msg_disable logging.info(f"Toggle Action ({func_name} -> {action_name}) Result: {msg} (Took {time.time()-t_start:.2f}s)") if "[Error]" not in msg: update_pipeline() except Exception as e: msg = f"[Error] during toggle ({action_name} {func_name}): {e}" logging.error(f"{msg}\n{traceback.format_exc()}") clean_memory() return msg def specific_action_function(action_func, *args, success_msg="Action completed successfully."): global global_model, global_tokenizer, config t_start=time.time() if not global_model: return "[Error] Model not loaded. Load a model first." if not hasattr(global_model, 'config') or global_model.config is None: logging.warning("Model config missing. Initializing default flags before action.") global_model.config = initialize_config_flags(None) config = initialize_config_flags(global_model.config) global_model.config = config msg = "" func_name = getattr(action_func, '__name__', 'unknown_action') logging.info(f"Executing action: {func_name}...") try: sig = inspect.signature(action_func) pass_args = [] if 'model' in sig.parameters or 'base_model' in sig.parameters or 'module' in sig.parameters: pass_args.append(global_model) if 'config' in sig.parameters: pass_args.append(config) if 'tokenizer' in sig.parameters: if global_tokenizer: pass_args.append(global_tokenizer) else: return f"[Error] Action '{func_name}' requires tokenizer, but it's not loaded." 
pass_args.extend(args) result = action_func(*pass_args) if isinstance(result, str) and "[Error]" not in result: msg = result elif isinstance(result, str): msg = result else: msg = success_msg logging.info(f"Action ({func_name}) Result: {msg} (Took {time.time()-t_start:.2f}s)") if "[Error]" not in msg: update_pipeline() except Exception as e: msg = f"[Error] during action ({func_name}): {e}" logging.error(f"{msg}\n{traceback.format_exc()}") clean_memory() return msg toggle_bias_removal_wrapper = lambda enable: base_toggle_function(_replace_linear_without_bias, _enable_bias_in_linear, enable, "Bias removal applied.", "Bias addition applied (reverted removal).") toggle_embeddings_untie_wrapper = lambda enable: base_toggle_function(_untie_embeddings, _retie_embeddings, enable, "Embeddings untied.", "Embeddings re-tied.") toggle_layer_reduction_wrapper = lambda enable, layers: specific_action_function(_reduce_layers_to_one if enable else _enable_full_layers, layers if enable else None, success_msg=f"Layer reduction {'applied' if enable else 'reverted'}.") apply_norm_swap_wrapper = lambda norm_type: specific_action_function(_swap_normalization_layer, norm_type, success_msg=f"Normalization swapped to {norm_type}") apply_activation_change_wrapper = lambda name: specific_action_function(_swap_activation_function, name, success_msg=f"Activation Function Swapped to {name}") revert_activation_change_wrapper = lambda: specific_action_function(_revert_activation_function, success_msg="Activation Function Reverted to Default") toggle_bitnet_wrapper = lambda enable: base_toggle_function(convert_to_bitnet, revert_bitnet, enable, "BitNet conversion applied.", "BitNet conversion reverted.") apply_multimodal_wrapper = lambda modalities: specific_action_function(_setup_multimodal, modalities, success_msg="Multi-modal setup attempted.") revert_multimodal_wrapper = lambda: specific_action_function(_revert_multimodal, success_msg="Multi-modal setup reverted.") toggle_token_speed_optimization_wrapper = lambda enable: specific_action_function(_optimize_token_generation_speed if enable else _revert_token_generation_speed_optimization, success_msg="Token Speed Opt Flags Updated") toggle_coherence_improvement_wrapper = lambda enable: specific_action_function(_enable_coherence_improvement if enable else _disable_coherence_improvement, success_msg="Coherence Flag Updated") toggle_layer_norm_bypass_wrapper = lambda enable: specific_action_function(_enable_layer_norm_bypass if enable else _disable_layer_norm_bypass, success_msg="LN Bypass Updated") toggle_dropout_bypass_wrapper = lambda enable: specific_action_function(_enable_dropout_bypass if enable else _disable_dropout_bypass, success_msg="Dropout Bypass Updated") toggle_fp32_precision_wrapper = lambda enable: specific_action_function(_recover_perfect_precision if enable else _revert_perfect_precision, success_msg="FP32 Precision Updated") toggle_embedding_normalization_wrapper = lambda enable: specific_action_function(_normalize_embeddings if enable else _revert_embedding_normalization, success_msg="Embedding Normalization Updated") toggle_gradient_checkpointing_wrapper = lambda enable: specific_action_function(_enable_gradient_checkpointing if enable else _disable_gradient_checkpointing, success_msg="Grad Checkpointing Updated") toggle_flash_attention_wrapper = lambda enable: specific_action_function(_set_attention_variant_config, "flash_attention_2" if enable else "auto", success_msg=f"Flash Attention 2 {'Enabled' if enable else 'Disabled'} (via 
attn_implementation)") apply_quantization_wrapper = lambda mode: specific_action_function(_quantize_model, mode, success_msg=f"Quantization Attempted: {mode}") revert_quantization_wrapper = lambda: specific_action_function(_revert_quantization, success_msg="Quantization Reverted to FP32") def _parse_pruning_amount(amount_str): try: amount = float(amount_str) if not (0 < amount < 1): raise ValueError("Pruning amount must be between 0 and 1") return amount except (ValueError, TypeError): logging.warning(f"Invalid pruning amount '{amount_str}', using default {PRUNING_AMOUNT}") return PRUNING_AMOUNT apply_pruning_wrapper = lambda amount_str: specific_action_function( _prune_weights_magnitude, _parse_pruning_amount(amount_str), success_msg=f"Pruning Applied (Amount: {_parse_pruning_amount(amount_str):.2f})" ) revert_pruning_wrapper = lambda: specific_action_function(_revert_pruning, success_msg="Pruning Flag Reverted") set_lora_path_wrapper = lambda path: specific_action_function(_set_lora_adapter_path, path, success_msg="LoRA Path Set in Config") add_peft_adapter_wrapper = lambda: specific_action_function( _add_peft_adapter, LoraConfig(**DEFAULT_PEFT_CONFIG_DICT) if _peft_installed else None, success_msg="PEFT Adapter Added" ) merge_peft_adapter_wrapper = lambda: specific_action_function(_apply_lora_merge, success_msg="PEFT Adapter Merged") remove_peft_adapter_wrapper = lambda: specific_action_function(_remove_peft_adapter, success_msg="PEFT Adapter Removed") apply_layer_freeze_wrapper = lambda layers_str: specific_action_function(_freeze_layers, layers_str, success_msg="Layer Freezing Updated") revert_layer_freeze_wrapper = lambda: specific_action_function(_unfreeze_all_layers, success_msg="All Layers Unfrozen") toggle_limits_wrapper = lambda enable: specific_action_function(_configure_limits if enable else _remove_limits_configuration, success_msg="Limits Config Updated") toggle_qa_restrictions_wrapper = lambda enable: specific_action_function(_remove_qa_restrictions if enable else _enable_qa_restrictions, success_msg="QA Restrictions Flag Updated") def _parse_int_arg(arg, default, min_val=1): try: val = int(arg) return max(val, min_val) except (ValueError, TypeError): return default toggle_kd_wrapper = lambda enable, num_labels=2: specific_action_function( _setup_knowledge_distillation if enable else _revert_knowledge_distillation, _parse_int_arg(num_labels, 2, 1) if enable else (), success_msg="KD Setup Updated" ) toggle_reward_modeling_wrapper = lambda enable, num_outputs=1: specific_action_function( _setup_reward_modeling if enable else _revert_reward_modeling, _parse_int_arg(num_outputs, 1, 1) if enable else (), success_msg="Reward Modeling Setup Updated" ) toggle_swa_wrapper = lambda enable: specific_action_function(_apply_swa if enable else _revert_swa, success_msg="SWA Flag Updated") def _parse_prob_arg(arg, default, min_val=0.0, max_val=1.0): try: val = float(arg) return min(max(val, min_val), max_val) except(ValueError, TypeError): return default toggle_layerdrop_wrapper = lambda enable, prob=0.1: specific_action_function( _enable_layerdrop if enable else _disable_layerdrop, _parse_prob_arg(prob, 0.1, 0.0, 1.0) if enable else (), success_msg="LayerDrop Flag Updated" ) toggle_rope_scaling_wrapper = lambda enable, type="linear", factor=2.0: specific_action_function( _set_rope_scaling_config if enable else _revert_rope_scaling, str(type) if enable else (), _parse_prob_arg(factor, 2.0, 1.0, 100.0) if enable else (), success_msg="RoPE Scaling Config Updated" ) 
toggle_sliding_window_wrapper = lambda enable, size=4096: specific_action_function( _set_sliding_window_config if enable else _revert_sliding_window, _parse_int_arg(size, 4096, 0) if enable else (), success_msg="Sliding Window Config Updated" ) apply_attention_variant_wrapper = lambda variant="auto": specific_action_function(_set_attention_variant_config, str(variant), success_msg="Attention Variant Config Updated") revert_attention_variant_wrapper = lambda: specific_action_function(_revert_attention_variant, success_msg="Attention Variant Config Reverted") toggle_gradient_clipping_flag_wrapper = lambda enable: specific_action_function(_enable_gradient_clipping if enable else _disable_gradient_clipping, success_msg="Grad Clipping Flag Updated") toggle_weight_decay_flag_wrapper = lambda enable: specific_action_function(_enable_weight_decay if enable else _disable_weight_decay, success_msg="Weight Decay Flag Updated") toggle_lr_scheduler_flag_wrapper = lambda enable: specific_action_function(_enable_lr_scheduler if enable else _disable_lr_scheduler, success_msg="LR Scheduler Flag Updated") apply_optimizer_change_wrapper = lambda name: specific_action_function(_swap_optimizer, str(name), success_msg=f"Optimizer Pref Set: {name}") revert_optimizer_change_wrapper = lambda: specific_action_function(_revert_optimizer, success_msg="Optimizer Pref Reverted") def _set_grad_accum_config(config, steps): try: s = int(steps) if s < 1: raise ValueError("Steps must be >= 1") config.gradient_accumulation_steps = s global GRADIENT_ACCUMULATION_STEPS GRADIENT_ACCUMULATION_STEPS = s return f"Grad Accum Steps set to {s} in config." except (ValueError, TypeError) as e: logging.error(f"Invalid gradient accumulation steps: {steps}. Error: {e}") return f"[Error] Invalid Grad Accum steps: {e}" set_gradient_accumulation_wrapper = lambda steps: specific_action_function(_set_grad_accum_config, steps, success_msg=f"Grad Accum Steps update attempted.") toggle_all_safety_filters_wrapper = lambda enable: specific_action_function(_enable_all_safety_settings if enable else _disable_all_safety_settings, success_msg=f"All Safety Filters {'Enabled (Defaults)' if enable else 'Disabled'}") force_disable_censorship_wrapper = lambda: specific_action_function(_disable_all_safety_settings, success_msg="Attempted Force Disable All Censorship Flags") def toggle_individual_safety_filter_wrapper(*state_dict): global global_model, config t_start=time.time() if not global_model: return "[Error] Model not loaded." if not hasattr(global_model, 'config') or global_model.config is None: logging.warning("Model config missing. Initializing default flags for filter toggle.") global_model.config = initialize_config_flags(None) config = initialize_config_flags(global_model.config) global_model.config = config results = [] updated_count = 0 if len(state_dict) != len(filter_names_ui): return f"[Error] Mismatch between filter UI elements ({len(filter_names_ui)}) and received states ({len(state_dict)})." ui_state = dict(zip(filter_names_ui, state_dict)) for name, checkbox_state in ui_state.items(): filter_attr = filter_attr_map.get(name) if filter_attr: current_state = getattr(config, filter_attr, False) new_state = bool(checkbox_state) if current_state != new_state: setattr(config, filter_attr, new_state) results.append(f"{name}: {'ON' if new_state else 'OFF'}") updated_count += 1 else: logging.warning(f"UI filter name '{name}' not found in attribute map filter_attr_map. 
Skipping.") if updated_count > 0: msg = f"Applied {updated_count} individual filter toggle(s): {', '.join(results)}" update_pipeline() else: msg = "No individual filter states were changed." logging.info(f"Individual filter toggle action took {time.time()-t_start:.2f}s. Status: {msg}"); return msg def _improve_coherence(model, tokenizer, generation_args): logging.info("Applying coherence improvement using beam search...") coherence_beams = generation_args.get("num_beams", 1) if coherence_beams <= 1: coherence_beams = 4 coherence_args = generation_args.copy() coherence_args["num_beams"] = coherence_beams coherence_args["do_sample"] = False coherence_args["num_return_sequences"] = 1 coherence_args["early_stopping"] = True coherence_args.pop("temperature", None) coherence_args.pop("top_k", None) coherence_args.pop("top_p", None) input_ids = coherence_args.get("input_ids") if input_ids is None: logging.error("Coherence improvement failed: input_ids missing.") return "[Error: input_ids missing in generation_args]" try: with torch.no_grad(): outputs = model.generate(**coherence_args) response_ids = outputs[0][input_ids.shape[-1]:] response_text = tokenizer.decode(response_ids, skip_special_tokens=True) logging.info("Coherence improvement (beam search) successful.") return response_text except Exception as e: logging.error(f"Error during coherence improvement (beam search): {e}. Falling back to original generation settings.") fallback_args = generation_args.copy() if fallback_args.get("num_beams", 1) > 1 and fallback_args.get("do_sample", False): fallback_args["num_beams"] = 1 fallback_args["early_stopping"] = False try: with torch.no_grad(): outputs = model.generate(**fallback_args) response_ids = outputs[0][input_ids.shape[-1]:] response_text = tokenizer.decode(response_ids, skip_special_tokens=True) logging.info("Fallback generation successful after coherence error.") return response_text except Exception as fallback_e: logging.error(f"Fallback generation also failed: {fallback_e}") return f"[Error during generation (fallback): {fallback_e}]" def _apply_output_filters(text, config): if not getattr(config, "response_filters", True) or getattr(config, "no_response_filters", False): return text logging.debug("Applying output filters based on config...") original_text = text if not isinstance(text, str): text = str(text) filter_actions = { "profanity_filter": lambda t: re.sub(r'\b(damn|fuck|shit|cunt|bitch|asshole|bastard|hell)\b', '[profanity]', t, flags=re.IGNORECASE), "pii_filter": lambda t: re.sub(r'\b(\d{3}[-.\s]?\d{3}[-.\s]?\d{4})\b|\b([A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Z|a-z]{2,})\b', '[censored_pii]', t), "hate_filter": lambda t: "[censored_hate_speech]" if any(word in t.lower() for word in ["kill all", "attack the", "inferior race", "exterminate"]) else t, "sexually_explicit_filter": lambda t: "[censored_explicit]" if any(word in t.lower() for word in [" rape ", " naked ", " porn ", " xxx ", " blowjob ", " sex toy"]) else t, "dangerous_content_filter": lambda t: "[censored_dangerous]" if any(word in t.lower() for word in ["make a bomb", "build weapon", "illegal drug recipe", "how to kill"]) else t, "medical_advice_filter": lambda t: "[discouraged_medical_advice]" if any(word in t.lower() for word in ["diagnose", "prescribe", "cure for", "medical treatment for", "symptoms suggest"]) else t, "legal_advice_filter": lambda t: "[discouraged_legal_advice]" if any(word in t.lower() for word in ["legal advice", "sue", "represent me", "is this legal", "contract law"]) else t, 
"financial_advice_filter": lambda t: "[discouraged_financial_advice]" if any(word in t.lower() for word in ["guaranteed investment", "stock tip", "financial plan", "buy bitcoin", "investment advice"]) else t, "stereotype_filter": lambda t: "[censored_stereotype]" if re.search(r'\b(all|every)\s+([A-Za-z]+(\s+)?){1,3}\s+(are|always)\b', t.lower()) else t, "misinfo_filter": lambda t: "[potential_misinfo]" if any(phrase in t.lower() for phrase in ["5g causes covid", "earth is flat", "vaccines cause autism", "chemtrails"]) else t, "self_harm_filter": lambda t: "[censored_self_harm]" if any(phrase in t.lower() for phrase in ["commit suicide", "hurt myself", "painless death", "kill myself"]) else t, } active_filters_count = 0 for filter_ui_name, filter_attr in filter_attr_map.items(): if getattr(config, filter_attr, False): filter_func = filter_actions.get(filter_attr) if filter_func: try: filtered_text = filter_func(text) if filtered_text != text: active_filters_count +=1 logging.debug(f"Filter '{filter_attr}' potentially applied modification.") text = filtered_text except Exception as e: logging.warning(f"Error applying filter '{filter_attr}': {e}") if not getattr(config, "no_advert_warning", False): if re.search(r'\b(advertisement|sponsored|promo code|discount code|special offer)\b', text, re.IGNORECASE): if "[Note: This response may contain promotional content.]" not in text: text += "\n[Note: This response may contain promotional content.]" active_filters_count +=1 if active_filters_count > 0: logging.debug(f"Output filtering potentially applied {active_filters_count} modifications.") return text def run_inference(prompt, max_new_tokens, temperature, top_k, top_p, repetition_penalty): global global_pipe, global_model, global_tokenizer, config if not all([global_model, global_tokenizer]): return "[Error] Model or Tokenizer not loaded. Please load a model first." if global_pipe is None: pipe_msg = update_pipeline() if global_pipe is None: return f"[Error] Text generation pipeline could not be initialized. Load/Reload model. Status: {pipe_msg}" if not hasattr(global_model, 'config'): logging.warning("Model config missing during inference. 
Initializing default flags.") global_model.config = initialize_config_flags(None) config = initialize_config_flags(global_model.config) global_model.config = config logging.info("Starting inference run..."); t_start_inf = time.time() try: use_filters = getattr(config, "response_filters", True) and not getattr(config, "no_response_filters", False) apply_coherence = getattr(config, "coherence_improvement_enabled", False) try: max_new_tokens = int(max_new_tokens); assert max_new_tokens > 0 except: max_new_tokens = 256; logging.warning("Invalid max_new_tokens, using 256.") try: temperature = float(temperature); assert temperature >= 0.0 except: temperature = 0.7; logging.warning("Invalid temperature, using 0.7.") try: top_k = int(top_k); assert top_k >= 0 except: top_k = 50; logging.warning("Invalid top_k, using 50.") try: top_p = float(top_p); assert 0.0 <= top_p <= 1.0 except: top_p = 0.95; logging.warning("Invalid top_p, using 0.95.") try: repetition_penalty = float(repetition_penalty); assert repetition_penalty >= 1.0 except: repetition_penalty = 1.1; logging.warning("Invalid repetition_penalty, using 1.1.") is_greedy = (temperature < 1e-6) or \ (top_k == 1 and top_k != 0) or \ (top_p <= 0.0 or top_p >= 1.0) or \ getattr(config, "token_gen_speed_maximized", False) gen_kwargs = { "max_new_tokens": max_new_tokens, "temperature": temperature if not is_greedy else None, "top_k": top_k if top_k > 0 and not is_greedy else None, "top_p": top_p if top_p > 0.0 and top_p < 1.0 and not is_greedy else None, "repetition_penalty": repetition_penalty if repetition_penalty > 1.0 else None, "do_sample": not is_greedy, "use_cache": getattr(config, "use_cache", True), "num_beams": (max(getattr(config, "num_beams", 1), 4) if apply_coherence else getattr(config, "num_beams", 1)), "pad_token_id": global_tokenizer.pad_token_id if global_tokenizer.pad_token_id is not None else getattr(config, 'pad_token_id', None), "eos_token_id": global_tokenizer.eos_token_id if global_tokenizer.eos_token_id is not None else getattr(config, 'eos_token_id', None), "early_stopping": True if (apply_coherence or getattr(config, "num_beams", 1) > 1) else False } gen_kwargs = {k: v for k, v in gen_kwargs.items() if v is not None} if gen_kwargs.get("num_beams", 1) > 1 and gen_kwargs.get("pad_token_id") is None: if gen_kwargs.get("eos_token_id") is not None: gen_kwargs["pad_token_id"] = gen_kwargs["eos_token_id"] logging.warning(f"Using eos_token_id ({gen_kwargs['eos_token_id']}) as pad_token_id for beam search.") else: logging.error("Beam search requires pad_token_id, but none found (and eos_token_id missing). Generation might fail.") return "[Error] Beam search failed: pad_token_id is required." 
response_text = "" device = get_device() logging.debug(f"Generation arguments: {gen_kwargs}") inputs = global_tokenizer(prompt, return_tensors="pt", padding=False, truncation=True, max_length=CONTEXT_LENGTH).to(device) gen_kwargs["input_ids"] = inputs["input_ids"] if "attention_mask" in inputs: gen_kwargs["attention_mask"] = inputs["attention_mask"] global_model.eval() if apply_coherence: response_text = _improve_coherence(global_model, global_tokenizer, gen_kwargs) else: with torch.no_grad(): outputs = global_model.generate(**gen_kwargs) output_sequence = outputs[0] response_ids = output_sequence[inputs.input_ids.shape[-1]:] response_text = global_tokenizer.decode(response_ids, skip_special_tokens=True) if use_filters: filtered_response = _apply_output_filters(response_text, config) if filtered_response != response_text: logging.info("Output filters applied modifications.") response_text = filtered_response final_response = response_text.strip() logging.info(f"Inference finished in {time.time()-t_start_inf:.2f}s. Response length: {len(final_response)}") return final_response except Exception as e: logging.error(f"Error during inference: {e}\n{traceback.format_exc()}") return f"[Error during inference: {e}]" finally: if global_model and hasattr(global_model, 'training') and global_model.training: global_model.train() def start_training( base_model_id: str, new_model_name: str, hf_token: str, datasets_input_str: str, activation_fn_name: str, target_layers_int: int, grad_accum_ui: int, lr: float, epochs: int, max_steps: int, batch_size: int, optimizer_name: str, scheduler_type: str, weight_decay: float, warmup_ratio: float, use_peft: bool, peft_r: int, peft_alpha: int, peft_dropout: float, peft_target_modules_str: str, wandb_token: str, use_cpu_flag: bool, bypass_limits_state: bool ): global global_model, global_tokenizer, global_pipe, original_num_layers_global, config, target_layers global USE_CPU, BATCH_SIZE, LEARNING_RATE, EPOCHS, MAX_STEPS, DEFAULT_OPTIMIZER, DEFAULT_SCHEDULER, GRADIENT_ACCUMULATION_STEPS global BYPASS_RESOURCE_LIMITS BYPASS_RESOURCE_LIMITS = bypass_limits_state start_overall_time = time.time() logging.info("="*50) logging.info("🚀 STARTING TRAINING PROCESS 🚀") resources_ok, res_msg = check_resources() if not resources_ok: error_msg = f"[Error] Resource limits exceeded, cannot start training. {res_msg}" logging.error(error_msg) return error_msg else: logging.info(res_msg) errors = [] if not base_model_id: errors.append("Base Model ID/Path is required.") if not new_model_name: errors.append("New Model Name (for saving/Hub) is required.") if not datasets_input_str: errors.append("At least one dataset must be provided.") try: target_layers_int = int(target_layers_int); assert target_layers_int >= 1 except: errors.append("Target Layers must be a positive integer.") try: grad_accum_ui = int(grad_accum_ui); assert grad_accum_ui >= 1 except: errors.append("Gradient Accumulation Steps must be a positive integer.") try: lr = float(lr); assert lr > 0 except: errors.append("Learning Rate must be a positive float.") try: epochs = int(epochs); assert epochs >= 0 except: errors.append("Epochs must be an integer >= 0.") try: max_steps = int(max_steps); assert max_steps >= 0 except: errors.append("Max Steps must be an integer >= 0.") if epochs <= 0 and max_steps <= 0: errors.append("Training requires at least one of Epochs or Max Steps to be positive.") elif epochs > 0 and max_steps > 0: logging.info(f"Both Epochs ({epochs}) and Max Steps ({max_steps}) are set (> 0). 
Max Steps will take precedence.") epochs = -1 elif epochs <= 0 and max_steps > 0: epochs = -1 elif epochs > 0 and max_steps <= 0: logging.info(f"Using Epochs ({epochs}) for training termination as Max Steps <= 0.") max_steps = -1 else: logging.error("Logic error in epoch/max_step handling. Defaulting Max Steps to 1.") max_steps = 1 epochs = -1 try: batch_size = int(batch_size); assert batch_size >= 1 except: errors.append("Batch Size must be a positive integer.") if optimizer_name not in OPTIMIZERS: errors.append(f"Invalid Optimizer. Choose from: {list(OPTIMIZERS.keys())}") if scheduler_type not in SCHEDULER_TYPES: errors.append(f"Invalid Scheduler. Choose from: {SCHEDULER_TYPES}") try: weight_decay = float(weight_decay); assert weight_decay >= 0.0 except: errors.append("Weight Decay must be a non-negative float.") try: warmup_ratio = float(warmup_ratio); assert 0.0 <= warmup_ratio <= 1.0 except: errors.append("Warmup Ratio must be between 0.0 and 1.0.") if activation_fn_name not in ACTIVATION_FUNCTIONS: errors.append(f"Invalid Activation Function. Choose from: {list(ACTIVATION_FUNCTIONS.keys())}") if use_peft and not _peft_installed: errors.append("PEFT requested, but library not installed (`pip install peft`).") peft_config_dict = {} if use_peft: try: peft_r = int(peft_r); assert peft_r >= 1 peft_alpha = int(peft_alpha); assert peft_alpha >= 1 peft_dropout = float(peft_dropout); assert 0.0 <= peft_dropout <= 1.0 peft_config_dict = { "task_type": TaskType.CAUSAL_LM, "inference_mode": False, "r": peft_r, "lora_alpha": peft_alpha, "lora_dropout": peft_dropout, } if peft_target_modules_str: modules = [m.strip() for m in peft_target_modules_str.split(',') if m.strip()] if modules: peft_config_dict["target_modules"] = modules except Exception as peft_e: errors.append(f"Invalid PEFT configuration: {peft_e}") if errors: error_msg = "[Error] Invalid training parameters:\n- " + "\n- ".join(errors) logging.error(error_msg) return error_msg logging.info(f"Base Model: {base_model_id}, New Name: {new_model_name}") logging.info(f"Use PEFT: {use_peft}") if use_peft: logging.info(f"PEFT Config: r={peft_r}, alpha={peft_alpha}, dropout={peft_dropout}, targets={peft_target_modules_str or 'Auto'}") logging.info(f"Datasets: \n{datasets_input_str}") logging.info(f"LR: {lr}, Effective Epochs: {epochs if epochs > 0 else 'N/A'}, MaxSteps: {max_steps if max_steps > 0 else 'N/A'}, BS: {batch_size}, GradAccum: {grad_accum_ui}") logging.info(f"Optim: {optimizer_name}, Scheduler: {scheduler_type}, WD: {weight_decay}, Warmup: {warmup_ratio}") logging.info(f"Post-Mod Target Layers: {target_layers_int}, Post-Mod ActFn: {activation_fn_name}") logging.info(f"Use CPU: {use_cpu_flag}, W&B: {'Enabled' if wandb_token else 'Disabled'}, Bypass Limits: {BYPASS_RESOURCE_LIMITS}") logging.info("="*50) USE_CPU = use_cpu_flag BATCH_SIZE = batch_size LEARNING_RATE = lr EPOCHS = epochs if epochs > 0 else 1 MAX_STEPS = max_steps DEFAULT_OPTIMIZER = optimizer_name DEFAULT_SCHEDULER = scheduler_type GRADIENT_ACCUMULATION_STEPS = grad_accum_ui target_layers = target_layers_int logging.info("Setting up environment...") clean_memory() device = get_device() logging.info(f"Using device: {device}") num_cpu_cores_os = os.cpu_count() or 1 global NUM_CPU_CORES if NUM_CPU_CORES <= 0: NUM_CPU_CORES = num_cpu_cores_os else: NUM_CPU_CORES = min(NUM_CPU_CORES, num_cpu_cores_os) logging.info(f"Using {NUM_CPU_CORES} CPU cores for dataloading.") wandb_run = None use_wandb_reporting = False if wandb_token: logging.info("Attempting WandB login...") try: 
wandb.login(key=wandb_token) logging.info("WandB login successful.") use_wandb_reporting = True except Exception as e: logging.warning(f"WandB login failed: {e}. Proceeding without WandB logging.") report_to = ["wandb"] if use_wandb_reporting else [] user_id = "local_user"; repo_id_str = new_model_name; repo_link = "N/A (Upload skipped or failed)" upload_to_hub = False if hf_token: logging.info("Attempting Hugging Face login...") user_id = get_user_id(hf_token) if user_id not in ["unknown_user", "http_error_user", "auth_error_user"]: try: login(token=hf_token, add_to_git_credential=False) repo_id_str = f"{user_id}/{new_model_name}" logging.info(f"Hugging Face login successful. User: {user_id}, Target Repo: {repo_id_str}") create_repo(repo_id=repo_id_str, repo_type="model", exist_ok=True, token=hf_token) logging.info(f"Hub repository '{repo_id_str}' ensured.") repo_link = f"https://huggingface.co/{repo_id_str}" upload_to_hub = True except Exception as e: logging.warning(f"Hugging Face login or repo creation failed: {e}. Upload will be skipped.") hf_token = None repo_id_str = new_model_name repo_link = "N/A (Login/Repo Failed)" else: logging.warning(f"Could not get valid Hugging Face user ID ({user_id}). Upload will be skipped.") hf_token = None repo_id_str = new_model_name repo_link = "N/A (Login Failed)" else: logging.info("No HF write token provided, Hub upload will be skipped.") logging.info(f"Loading base model '{base_model_id}' and tokenizer...") try: tokenizer = AutoTokenizer.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) if tokenizer.pad_token is None: if tokenizer.eos_token is not None: tokenizer.pad_token = tokenizer.eos_token logging.info(f"Set tokenizer pad_token to eos_token ('{tokenizer.eos_token}')") else: added_pad = tokenizer.add_special_tokens({'pad_token': '[PAD]'}) if added_pad > 0: logging.warning("Tokenizer missing pad_token and eos_token. Added '[PAD]' as pad_token.") else: logging.error("Tokenizer missing pad/eos and failed to add '[PAD]'. Training may fail.") base_config_obj = AutoConfig.from_pretrained(base_model_id, trust_remote_code=True, token=hf_token) base_config_obj = initialize_config_flags(base_config_obj) original_num_layers_global = getattr(base_config_obj, 'num_hidden_layers', LAYERS) if getattr(base_config_obj, 'original_num_layers', None) is None: base_config_obj.original_num_layers = original_num_layers_global if getattr(base_config_obj, 'vocab_size', -1) != len(tokenizer): logging.warning(f"Config vocab size ({getattr(base_config_obj, 'vocab_size', 'N/A')}) differs from tokenizer ({len(tokenizer)}). 
Updating config.") base_config_obj.vocab_size = len(tokenizer) if getattr(base_config_obj, 'pad_token_id', -999) != tokenizer.pad_token_id: base_config_obj.pad_token_id = tokenizer.pad_token_id load_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32 attn_impl_load = getattr(base_config_obj, 'attn_implementation', 'auto') if attn_impl_load == "flash_attention_2": base_config_obj.use_flash_attention_2 = True elif getattr(base_config_obj,'use_flash_attention_2', False): attn_impl_load = "flash_attention_2"; base_config_obj.attn_implementation = "flash_attention_2" logging.info(f"Loading model with dtype={load_dtype}, attn_implementation='{attn_impl_load}'...") model = AutoModelForCausalLM.from_pretrained( base_model_id, config=base_config_obj, trust_remote_code=True, token=hf_token, torch_dtype=load_dtype, low_cpu_mem_usage=True if device.type != 'cpu' else False, attn_implementation=attn_impl_load if attn_impl_load != 'auto' else None ) if model.get_input_embeddings().weight.shape[0] != len(tokenizer): logging.info(f"Resizing model token embeddings from {model.get_input_embeddings().weight.shape[0]} to tokenizer size {len(tokenizer)}") model.resize_token_embeddings(len(tokenizer)) if getattr(model.config, 'vocab_size', -1) != len(tokenizer): model.config.vocab_size = len(tokenizer) logging.info(f"Base model '{base_model_id}' loaded. Original Layers: {original_num_layers_global}, Current Layers: {model.config.num_hidden_layers}, Dtype: {model.dtype}") if device.type == 'cpu' or not (device.type != 'cpu' and True): model.to(device) logging.info(f"Model moved to device: {device}") else: logging.info(f"Model loaded with low_cpu_mem_usage, should be on target device(s).") config = model.config except Exception as e: logging.error(f"Failed to load base model or tokenizer '{base_model_id}': {e} \n{traceback.format_exc()}") return f"[Error] Load failed for '{base_model_id}': {e}" if use_peft: logging.info("Applying PEFT adapter to the model for training...") try: lora_config = LoraConfig(**peft_config_dict) peft_add_msg = _add_peft_adapter(model, config, peft_config_obj=lora_config) except Exception as peft_e: logging.error(f"Failed to configure or add PEFT adapter: {peft_e}") return f"[Error] Failed to prepare PEFT model: {peft_e}" if "[Error]" in peft_add_msg or "[Warning]" in peft_add_msg: logging.error(f"Failed adding PEFT adapter: {peft_add_msg}") return f"[Error] Failed adding PEFT adapter: {peft_add_msg}" model = global_model config = global_model.get_base_model().config logging.info("PEFT adapter added successfully.") else: logging.info("Proceeding with full fine-tuning (PEFT not selected).") logging.info("Loading and processing datasets...") train_ds_processed = None eval_ds_processed = None try: datasets_config_list = parse_datasets(datasets_input_str) interleaved_ds = load_datasets_from_config(datasets_config_list) if interleaved_ds is None: raise ValueError("Dataset loading and interleaving resulted in None. 
Check logs.") tokenize_partial = partial(tokenize_function, tokenizer=tokenizer, context_length=CONTEXT_LENGTH) tokenized_ds = interleaved_ds.map( tokenize_partial, batched=True, batch_size=1000, ) group_partial = partial(group_texts, block_size=CONTEXT_LENGTH) lm_dataset = tokenized_ds.map( group_partial, batched=True, batch_size=1000, ) try: peek_final = next(iter(lm_dataset)) final_cols = list(peek_final.keys()) logging.info(f"Sample processed record structure: { {k: type(v).__name__ for k, v in peek_final.items()} }") if not all(k in final_cols for k in ['input_ids', 'attention_mask', 'labels']): raise ValueError(f"Final dataset structure after tokenizing/grouping is missing required keys. Found: {final_cols}") except StopIteration: raise ValueError("Dataset appears empty after processing and grouping.") logging.info("Dataset tokenization and grouping complete.") train_ds_processed, eval_ds_processed = split_dataset(lm_dataset) if isinstance(train_ds_processed, IterableDataset): logging.info("Training dataset is iterable (streaming).") elif isinstance(train_ds_processed, Dataset): logging.info(f"Training dataset size: {len(train_ds_processed):,} examples.") else: logging.warning("Could not determine training dataset type or size.") if eval_ds_processed is not None: logging.info(f"Created static evaluation dataset with {len(eval_ds_processed)} examples.") else: logging.info("No evaluation dataset created (buffer empty or error occurred).") except Exception as e: logging.error(f"Dataset loading, processing, or splitting failed: {e} \n{traceback.format_exc()}") return f"[Error] Dataset preparation failed: {e}" logging.info("Setting up Training Arguments...") final_weight_decay = weight_decay if not getattr(config, 'weight_decay_disabled', False) else 0.0 final_lr_scheduler = scheduler_type if not getattr(config, 'lr_scheduler_disabled', False) else "constant" max_grad_norm_val = 1.0 if not getattr(config, 'gradient_clipping_disabled', False) else None output_dir = f"./{new_model_name}_training_output" training_args = TrainingArguments( output_dir=output_dir, overwrite_output_dir=True, report_to=report_to, per_device_train_batch_size=BATCH_SIZE, per_device_eval_batch_size=max(1, BATCH_SIZE * 2), gradient_accumulation_steps=GRADIENT_ACCUMULATION_STEPS, num_train_epochs=EPOCHS if EPOCHS > 0 else 1, max_steps=MAX_STEPS, optim=optimizer_name, learning_rate=LEARNING_RATE, weight_decay=final_weight_decay, warmup_ratio=warmup_ratio, lr_scheduler_type=final_lr_scheduler, max_grad_norm=max_grad_norm_val if max_grad_norm_val is not None else 1e9, fp16=load_dtype == torch.float16 and device.type == 'cuda', bf16=load_dtype == torch.bfloat16 and device.type == 'cuda', gradient_checkpointing=getattr(config, 'gradient_checkpointing_enabled', False), gradient_checkpointing_kwargs={'use_reentrant': False} if getattr(config, 'gradient_checkpointing_enabled', False) else None, dataloader_num_workers=NUM_CPU_CORES, dataloader_pin_memory=True if device.type == 'cuda' else False, evaluation_strategy="steps" if eval_ds_processed is not None else "no", eval_steps=EVAL_STEPS if eval_ds_processed is not None else None, save_strategy="steps", save_steps=SAVE_STEPS, save_total_limit=2, load_best_model_at_end=LOAD_BEST_MODEL_AT_END if eval_ds_processed is not None else False, metric_for_best_model=METRIC_FOR_BEST_MODEL if eval_ds_processed is not None else None, logging_strategy="steps", logging_steps=LOGGING_STEPS, push_to_hub=upload_to_hub, hub_model_id=repo_id_str if upload_to_hub else None, hub_token=hf_token if 
upload_to_hub else None, hub_strategy="checkpoint", use_cpu=USE_CPU, seed=42, remove_unused_columns=False, log_level="info", ) logging.info("Initializing Trainer...") data_collator = DataCollatorForLanguageModeling(tokenizer, mlm=False) callbacks = [] if LOAD_BEST_MODEL_AT_END and eval_ds_processed is not None: callbacks.append(EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE, early_stopping_threshold=0.001)) if use_wandb_reporting: try: wandb_run = wandb.init( project=f"llm-modify-train-{new_model_name.replace('/', '-')}", config=training_args.to_dict(), name=f"run-{new_model_name.replace('/', '-')}-{int(time.time())}", reinit=True ) logging.info(f"WandB run initialized: {wandb_run.name if wandb_run else 'Failed'}") except Exception as wandb_e: logging.error(f"Failed to initialize WandB run: {wandb_e}") wandb_run = None training_args.report_to = [] trainer = Trainer( model=model, args=training_args, tokenizer=tokenizer, train_dataset=train_ds_processed, eval_dataset=eval_ds_processed, data_collator=data_collator, callbacks=callbacks ) start_train_time = time.time() logging.info(f"🚀 Starting model training (Using {type(trainer.model).__name__}). Effective steps: {training_args.max_steps if training_args.max_steps > 0 else 'N/A'}. Effective epochs: {training_args.num_train_epochs if training_args.num_train_epochs > 0 else 'N/A'}.") train_result = None training_successful = False try: last_checkpoint = None if os.path.isdir(training_args.output_dir): from transformers.trainer_utils import get_last_checkpoint last_checkpoint = get_last_checkpoint(training_args.output_dir) if last_checkpoint: logging.info(f"*** Resuming training from checkpoint: {last_checkpoint} ***") train_result = trainer.train(resume_from_checkpoint=last_checkpoint) logging.info("✅ Training finished successfully.") training_successful = True trainer.save_model() trainer.save_state() if not use_peft: tokenizer.save_pretrained(training_args.output_dir) elif isinstance(trainer.model, PeftModel): tokenizer.save_pretrained(training_args.output_dir) except Exception as e: logging.error(f"❌ Training failed: {e}\n{traceback.format_exc()}") if wandb_run: wandb_run.finish(exit_code=1) return f"[Error] Training failed: {e}" finally: end_train_time = time.time() clean_memory() training_time = end_train_time - start_train_time logging.info(f"🕒 Training phase took {training_time:.2f} seconds.") if not training_successful: return "[Error] Training did not complete successfully." final_trained_model = trainer.model model_to_save = final_trained_model merged_model_for_mods = None if use_peft and isinstance(final_trained_model, PeftModel): logging.info("Merging PEFT adapter into the base model for modification and final save...") try: merged_model_for_mods = final_trained_model.merge_and_unload() logging.info("PEFT adapter merged successfully.") merged_model_for_mods.config.peft_adapter_added = False merged_model_for_mods.config.peft_config = None merged_model_for_mods.config.lora_merged = True except Exception as e: logging.error(f"Failed to merge PEFT adapter after training: {e}. 
Saving adapter separately.") adapter_save_path = os.path.join(training_args.output_dir, "final_adapter") try: final_trained_model.save_pretrained(adapter_save_path) base_model_for_saving = final_trained_model.get_base_model() base_model_for_saving.save_pretrained(training_args.output_dir) tokenizer.save_pretrained(training_args.output_dir) logging.info(f"PEFT adapter saved separately to {adapter_save_path}, base model to {training_args.output_dir}") merged_model_for_mods = final_trained_model except Exception as save_e: logging.error(f"Failed to save adapter or base model separately: {save_e}. Proceeding with potentially unmerged PEFT model.") merged_model_for_mods = final_trained_model else: merged_model_for_mods = final_trained_model if merged_model_for_mods is None: logging.error("Model state after training/merging is None. Cannot proceed.") return "[Error] Lost model reference after training/merging." def modify_model_post_train(model_obj, act_fn_name, target_layer_count): logging.info(f"Applying post-training modifications: Target Layers={target_layer_count}, Activation={act_fn_name}") if not hasattr(model_obj, 'config'): logging.error("Cannot modify model: Missing config.") return model_obj current_config = initialize_config_flags(model_obj.config) model_obj.config = current_config current_layers = getattr(current_config, 'num_hidden_layers', None) original_layers = getattr(current_config, 'original_num_layers', original_num_layers_global) if current_layers is not None and original_layers is not None: if target_layer_count != current_layers: logging.info(f"Adjusting layers post-training: {current_layers} -> {target_layer_count} (Original: {original_layers})") if target_layer_count < current_layers: _reduce_layers_to_one(model_obj, current_config, target_layers=target_layer_count) else: restore_target = min(target_layer_count, original_layers) if original_layers else target_layer_count if restore_target > current_layers: logging.info(f"Attempting to restore layers: {current_layers} -> {restore_target}") _enable_full_layers(model_obj, current_config, original_num_layers=restore_target) else: logging.info(f"Target layers ({target_layer_count}) >= current layers ({current_layers}). No layer increase needed or possible beyond original.") else: logging.info(f"Target layers ({target_layer_count}) matches current layers after training. No layer adjustment needed.") else: logging.warning("Could not determine current or original layer count from config post-training. Skipping layer adjustment.") current_act_fn = getattr(current_config, 'current_activation_function', DEFAULT_ACTIVATION_FUNCTION) if act_fn_name != current_act_fn: logging.info(f"Setting activation function post-training to: {act_fn_name}") _swap_activation_function(model_obj, current_config, act_fn_name) else: logging.info(f"Target activation function ({act_fn_name}) already matches current. 
No change needed.") logging.info("Post-training modifications applied.") return model_obj logging.info("Applying final post-training modifications specified in UI...") final_model_modified = modify_model_post_train(merged_model_for_mods, activation_fn_name, target_layers_int) if merged_model_for_mods is not final_model_modified: del merged_model_for_mods clean_memory() final_model_path = training_args.output_dir logging.info(f"Saving final modified model state to {final_model_path}...") try: save_kwargs = {"safe_serialization": True} final_model_modified.config = initialize_config_flags(final_model_modified.config) if use_peft and isinstance(final_model_modified, PeftModel): logging.warning("Saving unmerged PEFT model state again after modifications (adapter separate).") adapter_save_dir = os.path.join(final_model_path, "final_adapter_modified") final_model_modified.save_pretrained(adapter_save_dir) logging.info(f"PEFT adapter saved to {adapter_save_dir}") try: base_model_final = final_model_modified.get_base_model() base_model_final.save_pretrained(final_model_path, **save_kwargs) tokenizer.save_pretrained(final_model_path) logging.info(f"Base model saved to {final_model_path}") except Exception as base_save_e: logging.error(f"Failed to save base model separately after modification: {base_save_e}. Only adapter might be saved.") else: final_model_modified.save_pretrained(final_model_path, **save_kwargs) tokenizer.save_pretrained(final_model_path) logging.info("Final modified model saved locally.") except Exception as e: logging.error(f"Failed to save final modified model locally: {e}\n{traceback.format_exc()}") if wandb_run: wandb_run.finish(exit_code=1) global_model = final_model_modified.to(device) global_tokenizer = tokenizer config = final_model_modified.config update_pipeline() clean_memory() return f"[Error] Failed to save final model locally: {e}. Training logs/checkpoints might be in {output_dir}." 
final_eval_results = {}; final_eval_loss = None; final_perplexity = float('inf') if eval_ds_processed is not None: logging.info("Evaluating final modified model..."); T_final_eval = time.time() try: final_trainer = Trainer( model=final_model_modified, args=training_args, tokenizer=tokenizer, data_collator=data_collator, eval_dataset=eval_ds_processed, ) final_eval_results = final_trainer.evaluate() final_eval_loss = final_eval_results.get("eval_loss") final_perplexity = compute_perplexity(final_eval_loss) logging.info(f"✅ Final Model Evaluation Results: {final_eval_results}") logging.info(f"Final Model Perplexity: {final_perplexity:.4f} (Eval time: {time.time() - T_final_eval:.2f}s)") if use_wandb_reporting and wandb_run: wandb_run.log({"final_eval_loss": final_eval_loss if final_eval_loss is not None else -1.0, "final_perplexity": final_perplexity if final_perplexity != float('inf') else -1.0, **final_eval_results}) except Exception as e: logging.error(f"Final evaluation failed: {e}\n{traceback.format_exc()}") if use_wandb_reporting and wandb_run: wandb_run.log({"final_eval_status": "Failed", "final_eval_error": str(e)}) else: logging.info("Skipping final evaluation as no evaluation dataset was available.") upload_successful_final = False if upload_to_hub: logging.info(f"Attempting final upload of '{final_model_path}' to Hugging Face Hub: {repo_id_str}...") try: api = HfApi() api.upload_folder( folder_path=final_model_path, repo_id=repo_id_str, repo_type="model", token=hf_token, commit_message=f"Upload final trained model: {new_model_name} (Base: {base_model_id}, PPL: {final_perplexity:.2f})", commit_description=(f"Training completed. Eval Loss: {f'{final_eval_loss:.4f}' if final_eval_loss is not None else 'N/A'}, Perplexity: {f'{final_perplexity:.4f}' if final_perplexity != float('inf') else 'N/A'}. " f"Config: PEFT={use_peft}, Layers={target_layers_int}, ActFn={activation_fn_name}. 
Training time: {training_time:.2f}s.") ) repo_link = f"https://huggingface.co/{repo_id_str}" logging.info(f"✅ Final model upload complete: {repo_link}") if use_wandb_reporting and wandb_run: wandb_run.log({"hf_repo_link": repo_link, "hf_upload_status": "Success"}) upload_successful_final = True except Exception as e: logging.error(f"Final Hugging Face upload failed: {e}\n{traceback.format_exc()}") repo_link = "[Upload Failed]" if use_wandb_reporting and wandb_run: wandb_run.log({"hf_upload_status": "Failed", "hf_upload_error": str(e)}) else: logging.info(f"Skipping final Hugging Face Hub upload based on initial setup.") logging.info("Updating global state with the final model...") global_model = final_model_modified.to(device) global_tokenizer = tokenizer config = global_model.config update_pipeline() clean_memory() final_status_report_json = decode_model_details(global_model) total_script_time = time.time() - start_overall_time final_message = ( f"✅ Training & Modification Process Complete!\n" f"{'='*40}\n" f"New Model Name: {new_model_name}\n" f"Base Model: {base_model_id}\n" f"Total Time: {total_script_time:.2f}s | Training Phase Time: {training_time:.2f}s\n" f"{'='*40}\n" f"Training Results:\n" ) if train_result: final_message += f" - Steps Completed: {train_result.global_step}\n" train_loss = train_result.training_loss final_message += f" - Training Loss: {f'{train_loss:.4f}' if train_loss is not None else 'N/A'}\n" train_metrics = train_result.metrics for metric, value in train_metrics.items(): if "loss" in metric.lower() or "perplexity" in metric.lower() or "epoch" in metric.lower() or "step" in metric.lower(): value_str = f"{value:.4f}" if isinstance(value, float) else str(value) final_message += f" - {metric.replace('_', ' ').title()}: {value_str}\n" final_message += ( f"Final Evaluation:\n" f" - Eval Loss: {f'{final_eval_loss:.4f}' if final_eval_loss is not None else 'N/A'}\n" f" - Perplexity: {f'{final_perplexity:.4f}' if final_perplexity != float('inf') else 'N/A'}\n" f"{'='*40}\n" f"Saving & Upload:\n" f" - Local Path: {final_model_path}\n" f" - Hub Repo: {repo_link}\n" f"{'='*40}\n" f"Final Model Status Summary:\n" ) try: status_data = json.loads(final_status_report_json) summary_keys = ["Model Class", "Config Class", "Device(s)", "Params Summary", "Layer Types Count", "Key Config Attributes", "Modification Flags"] for key in summary_keys: if key in status_data: value = status_data[key] if isinstance(value, dict): value_str = json.dumps(value, indent=4) elif isinstance(value, list): value_str = ", ".join(map(str, value)) else: value_str = str(value) if len(value_str) > 200: value_str = value_str[:200] + "..." 
final_message += f" - {key}: {value_str}\n" final_message += f"(Full status logged and available in 'Model Controls' tab after refresh)\n" except Exception as json_e: logging.warning(f"Could not parse final status JSON for summary: {json_e}") final_message += "(Could not generate status summary from JSON)\n" final_message += f"{'='*40}" if use_wandb_reporting and wandb_run: try: wandb_final_log = { "total_time_seconds": total_script_time, "training_time_seconds": training_time, "final_eval_loss": final_eval_loss if final_eval_loss is not None else -1.0, "final_perplexity": final_perplexity if final_perplexity != float('inf') else -1.0, "upload_successful": upload_successful_final, "final_steps_completed": train_result.global_step if train_result else -1, "final_train_loss": train_result.training_loss if train_result and train_result.training_loss else -1.0, } wandb_run.log(wandb_final_log) wandb_run.finish() logging.info("WandB run finished.") except Exception as e: logging.warning(f"Error finishing WandB run: {e}") logging.info("🏁 Full training and modification process finished. 🏁") return final_message def load_model_for_control(model_id_or_path, hf_token=None, bypass_limits_state=False): global global_model, global_tokenizer, global_pipe, config, original_num_layers_global, BYPASS_RESOURCE_LIMITS BYPASS_RESOURCE_LIMITS = bypass_limits_state logging.info(f"Attempting to load model for control: {model_id_or_path}") if not model_id_or_path: return "[Error] Model ID or Path cannot be empty.", "{}", *get_error_filter_updates() resources_ok, res_msg = check_resources() if not resources_ok: error_msg = f"[Error] Resource limits exceeded, cannot load model. {res_msg}" logging.error(error_msg) return error_msg, "{}", *get_error_filter_updates() else: logging.info(res_msg) t_load_start = time.time() device = get_device() error_return = f"[Error] Failed to load model '{model_id_or_path}'.", "{}", *get_error_filter_updates() global_model, global_tokenizer, global_pipe, config = None, None, None, None clean_memory() try: logging.info("Loading tokenizer...") tokenizer_load = AutoTokenizer.from_pretrained( model_id_or_path, trust_remote_code=True, token=hf_token ) if tokenizer_load.pad_token is None: if tokenizer_load.eos_token is not None: tokenizer_load.pad_token = tokenizer_load.eos_token logging.info(f"Set tokenizer pad_token to eos_token ('{tokenizer_load.eos_token}')") else: try: tokenizer_load.add_special_tokens({'pad_token': '[PAD]'}) logging.warning("Added '[PAD]' as pad_token.") except Exception as pad_e: logging.error(f"Could not set PAD token: {pad_e}. 
Batching or beam search might fail.") logging.info("Loading model config...") loaded_config = AutoConfig.from_pretrained( model_id_or_path, trust_remote_code=True, token=hf_token ) config_load = initialize_config_flags(loaded_config) original_layers_load = getattr(config_load, 'num_hidden_layers', LAYERS) if getattr(config_load, 'original_num_layers', None) is None: config_load.original_num_layers = original_layers_load logging.info(f"Set original_num_layers in loaded config to {original_layers_load}") original_num_layers_global = config_load.original_num_layers if getattr(config_load, 'vocab_size', -1) != len(tokenizer_load): config_load.vocab_size = len(tokenizer_load) if getattr(config_load, 'pad_token_id', -999) != tokenizer_load.pad_token_id: config_load.pad_token_id = tokenizer_load.pad_token_id logging.info("Loading model weights...") attn_impl_load = getattr(config_load, 'attn_implementation', 'auto') if attn_impl_load == "flash_attention_2": config_load.use_flash_attention_2 = True elif getattr(config_load,'use_flash_attention_2', False): attn_impl_load = "flash_attention_2"; config_load.attn_implementation = "flash_attention_2" load_dtype = torch.bfloat16 if device.type == 'cuda' and torch.cuda.is_bf16_supported() else torch.float16 if device.type == 'cuda' else torch.float32 logging.info(f"Using dtype {load_dtype} and attn_implementation '{attn_impl_load}' for loading.") model_load = AutoModelForCausalLM.from_pretrained( model_id_or_path, config=config_load, trust_remote_code=True, token=hf_token, torch_dtype=load_dtype, low_cpu_mem_usage=True if device.type != 'cpu' else False, attn_implementation=attn_impl_load if attn_impl_load != 'auto' else None, ) if model_load.get_input_embeddings().weight.shape[0] != len(tokenizer_load): logging.info(f"Resizing loaded model embeddings from {model_load.get_input_embeddings().weight.shape[0]} to tokenizer size {len(tokenizer_load)}") model_load.resize_token_embeddings(len(tokenizer_load)) if getattr(model_load.config, 'vocab_size', -1) != len(tokenizer_load): model_load.config.vocab_size = len(tokenizer_load) global_model = model_load.to(device) global_tokenizer = tokenizer_load config = global_model.config logging.info(f"Model loaded successfully to {device}.") update_pipeline() clean_memory() logging.info(f"Model '{model_id_or_path}' loaded and pipeline updated in {time.time() - t_load_start:.2f}s.") status_json, *filter_updates = get_detailed_status_and_filter_states() return f"Model '{model_id_or_path}' loaded successfully.", status_json, *filter_updates except Exception as e: logging.error(f"Failed to load model '{model_id_or_path}': {e}\n{traceback.format_exc()}") global_model, global_tokenizer, global_pipe, config = None, None, None, None clean_memory() return error_return def save_current_model(save_path, hf_token=None, hub_repo_id=None): global global_model, global_tokenizer, config if not global_model or not global_tokenizer: return "[Error] No model loaded to save." if not save_path and not hub_repo_id: return "[Error] Please provide a local save path or a Hub Repo ID (or both)." t_save_start = time.time() model_to_save = global_model tokenizer_to_save = global_tokenizer config_to_save = initialize_config_flags(config if config else getattr(model_to_save, 'config', None)) if config_to_save is None: logging.error("Cannot save: Model config is missing.") return "[Error] Model config is missing, cannot save." 
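# --- Illustrative sketch: adapter-only vs. full save -----------------------------
# The code that follows in save_current_model chooses between saving only the
# LoRA/PEFT adapter (when the loaded model is a PeftModel) and saving the full
# weights. The hypothetical helper below shows that branch in isolation, assuming
# `peft` may or may not be installed: for a PeftModel, save_pretrained() writes
# just the small adapter files, so the tokenizer (and ideally the base config)
# should be saved alongside it. `_sketch_save_model_state` is illustrative only.
import os

def _sketch_save_model_state(model, tokenizer, out_dir: str) -> str:
    try:
        from peft import PeftModel as _PeftModel
        adapter_only = isinstance(model, _PeftModel)
    except ImportError:
        adapter_only = False
    os.makedirs(out_dir, exist_ok=True)
    # For a PeftModel this writes adapter_model.* and adapter_config.json only;
    # for a plain transformers model it writes the full checkpoint.
    model.save_pretrained(out_dir)
    tokenizer.save_pretrained(out_dir)
    return "adapter-only" if adapter_only else "full model"

# Usage: mode = _sketch_save_model_state(global_model, global_tokenizer, "./saved_state")
# ----------------------------------------------------------------------------------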
model_to_save.config = config_to_save is_peft_model = _peft_installed and isinstance(model_to_save, PeftModel) save_adapter_only = is_peft_model logging.info(f"Save mode: {'Adapter Only (PEFT model detected)' if save_adapter_only else 'Full Model'}") temp_save_dir = None effective_save_path = save_path.strip() if save_path else None if not effective_save_path and hub_repo_id: temp_save_dir = f"./hub_upload_temp_{hub_repo_id.replace('/', '_')}_{int(time.time())}" effective_save_path = temp_save_dir logging.info(f"No local path provided, saving temporarily to '{effective_save_path}' for Hub upload.") elif not effective_save_path: return "[Error] Cannot determine save location (missing local path and Hub ID)." try: os.makedirs(effective_save_path, exist_ok=True) except OSError as e: logging.error(f"Failed to create save directory '{effective_save_path}': {e}") return f"[Error] Failed to create save directory: {e}" local_save_message = "" try: logging.info(f"Saving current model state to {effective_save_path}...") save_kwargs = {"safe_serialization": True} if save_adapter_only: logging.info("Saving PEFT adapter weights and tokenizer.") model_to_save.save_pretrained(effective_save_path) tokenizer_to_save.save_pretrained(effective_save_path) try: base_model_config = model_to_save.get_base_model().config base_model_config.save_pretrained(effective_save_path) except Exception as config_e: logging.warning(f"Could not save base model config alongside adapter: {config_e}") else: logging.info("Saving full model weights and tokenizer.") model_to_save.save_pretrained(effective_save_path, **save_kwargs) tokenizer_to_save.save_pretrained(effective_save_path) save_local_time = time.time() - t_save_start logging.info(f"Model state saved locally to {effective_save_path} in {save_local_time:.2f}s") local_save_message = f"Model saved locally to '{effective_save_path}'." except Exception as e: logging.error(f"Failed to save model locally to {effective_save_path}: {e}\n{traceback.format_exc()}") if temp_save_dir and os.path.exists(temp_save_dir): try: shutil.rmtree(temp_save_dir); logging.info("Cleaned up temporary directory after local save error.") except Exception as clean_e: logging.warning(f"Could not remove temp dir {temp_save_dir} after error: {clean_e}") return f"[Error] Failed to save model locally: {e}" hub_message = "" upload_successful = False if hub_repo_id: if not hf_token: hub_message = "[Warning] Hub upload skipped: Hugging Face Write Token required." logging.warning(hub_message) else: logging.info(f"Attempting to upload '{effective_save_path}' to Hub repo: {hub_repo_id}") try: api = HfApi(); create_repo(repo_id=hub_repo_id, repo_type="model", exist_ok=True, token=hf_token) api.upload_folder( folder_path=effective_save_path, repo_id=hub_repo_id, repo_type="model", token=hf_token, commit_message=f"Upload model state ({'Adapter' if save_adapter_only else 'Full'}) via LLM Platform", commit_description=f"Saved from LLM Platform UI. Model class: {type(global_model).__name__}. 
State: {'PEFT Adapter' if save_adapter_only else 'Full Model'}.", ) hub_link = f"https://huggingface.co/{hub_repo_id}" hub_message = f"Successfully uploaded to Hub: {hub_link}" upload_successful = True logging.info(hub_message) except Exception as e: hub_message = f"[Error] Hub upload failed: {e}" logging.error(f"Hub upload failed: {e}\n{traceback.format_exc()}") if temp_save_dir and os.path.exists(temp_save_dir): try: shutil.rmtree(temp_save_dir) logging.info("Cleaned up temporary save directory.") except Exception as e: logging.warning(f"Could not remove temporary directory {temp_save_dir}: {e}") final_message = local_save_message if save_path else "" if hub_message: if final_message: final_message += f" | {hub_message}" else: final_message = hub_message if not final_message: final_message = "[Info] No local save path provided and Hub upload failed or was skipped." total_save_time = time.time() - t_save_start logging.info(f"Total save operation took {total_save_time:.2f}s") return final_message filter_names_ui = [ "Harassment", "Hate Speech", "Sexually Explicit", "Dangerous Content", "Civic Integrity", "Harmful Code", "Medical Advice", "Legal Advice", "Financial Advice", "PII (Basic)", "Political Content", "Religious Content", "Profanity", "Stereotype", "Misinfo", "Self Harm", "Personal Attack", "Toxicity", "Spam", "Off Topic", "Tone", "Min Max Length", "Repetition Filter", "Factuality Filter" ] filter_attr_map = {name: name.lower().replace(" ", "_").replace("(", "").replace(")", "") + "_filter" for name in filter_names_ui} filter_attr_map["PII (Basic)"] = "pii_filter" filter_attr_map["Harmful Code"] = "code_filter" filter_attr_map["Min Max Length"] = "min_max_length_filter" filter_attr_map["Repetition Filter"] = "repetition_filter_enabled" filter_attr_map["Factuality Filter"] = "factuality_filter_enabled" custom_theme = gr.themes.Soft(primary_hue=gr.themes.colors.blue, secondary_hue=gr.themes.colors.sky).set( button_primary_background_fill="*primary_500", button_primary_background_fill_hover="*primary_400", button_secondary_background_fill="*secondary_500", button_secondary_background_fill_hover="*secondary_400", button_cancel_background_fill="*neutral_200", button_cancel_background_fill_hover="*neutral_300", ) with gr.Blocks(theme=custom_theme, title="Advanced LLM Training & Modification Platform") as demo: gr.Markdown("# 🤖 Advanced LLM Training & Modification Platform v1.2") gr.Markdown("Load, modify, filter, train (Full/PEFT), test, merge, and save Large Language Models. 
Includes PEFT, experimental multi-modal capabilities, reward modeling setup, and resource checks.") with gr.Accordion("🔑 Authentication & Settings", open=False): with gr.Row(): hf_token_read = gr.Textbox(label="🤗 HF Token (Read - Optional, for private models)", type="password", interactive=True, placeholder="hf_...") hf_token_write = gr.Textbox(label="🤗 HF Token (Write - Optional, for Hub upload/training)", type="password", interactive=True, placeholder="hf_...") train_wandb_token_inp = gr.Textbox(label="📊 WandB Token (Optional, for logging runs)", type="password", interactive=True) with gr.Row(): bypass_limits_chk = gr.Checkbox(label="Bypass RAM/Disk Limits (Use with Caution!)", value=False, interactive=True) with gr.Tabs(): with gr.TabItem("💾 Load, Save & Merge"): with gr.Row(): with gr.Column(scale=2): gr.Markdown("### Load Model for Modification & Inference") load_model_selector = HuggingfaceHubSearch(label="Search Hub or Enter Path/ID", placeholder="google/gemma-2b") load_button = gr.Button("Load Model", variant="primary") load_status_output = gr.Textbox(label="Load Status", interactive=False, lines=1) with gr.Column(scale=2): gr.Markdown("### Save Current Model State") save_path_inp = gr.Textbox(label="Local Save Path (Optional)", placeholder="./saved_models/my_modified_model", interactive=True) save_hub_repo_inp = gr.Textbox(label="Hub Repo ID (Optional, e.g., user/repo)", placeholder="username/my-cool-llm", interactive=True) save_button = gr.Button("Save Model", variant="secondary") save_status_output = gr.Textbox(label="Save Status", interactive=False, lines=1) gr.Markdown("---") gr.Markdown("### Merge Model Architectures (Parameter Averaging - Experimental)") gr.Markdown("⚠️ **Experimental:** Averages parameters of models with compatible layers. Enter comma-separated Model IDs/Paths. The first model's config and tokenizer will be used as the base.") merge_model_ids_inp = gr.Textbox(label="Model IDs/Paths to Merge (comma-separated)", placeholder="org/model-a, org/model-b, ./local-model-c") merge_button = gr.Button("Merge Architectures", variant="primary") merge_status_output = gr.Textbox(label="Merge Status", interactive=False, lines=2) with gr.TabItem("🚀 Training"): gr.Markdown("Fine-tune a model based on a selected base model. Supports Full fine-tuning and PEFT (LoRA). Apply modifications post-training.") gr.Markdown("### 1. Base Model & Output Name") with gr.Row(): train_model_selector = HuggingfaceHubSearch(label="Search & Select Base Model for Training", placeholder="Type to search Hugging Face Hub...") with gr.Row(): train_new_model_inp = gr.Textbox(label="New Model Name (for saving locally and optionally on Hub)", placeholder="MyTunedModel-v1", interactive=True) gr.Markdown("### 2. Training Data") with gr.Row(): train_dataset_selector = HuggingfaceHubSearch(label="Search Datasets on Hub (or specify local below)") train_datasets_inp = gr.Textbox( label="Datasets (one per line: 'id[,config[,split[,weight]]]')", placeholder="Example:\nopenwebtext\nwikitext,wikitext-103-raw-v1,train,0.5\nmy_local_dataset_path,,train,1.5\nusername/my_dataset,my_config,validation,2.0", lines=5, interactive=True) gr.Markdown("### 3. 
Training Configuration") with gr.Accordion("Training Mode & Hyperparameters", open=True): train_use_peft_chk = gr.Checkbox(label="Enable PEFT (LoRA) Training", value=True, interactive=True) with gr.Row(): train_lr_inp = gr.Number(value=LEARNING_RATE, label="Learning Rate", interactive=True, minimum=1e-8, step=1e-6, precision=8) train_epochs_inp = gr.Number(value=EPOCHS, label="Epochs (Set <= 0 if using Max Steps)", precision=0, minimum=-1, interactive=True) train_max_steps_inp = gr.Number(value=MAX_STEPS, label="Max Steps (Set <= 0 if using Epochs)", precision=0, minimum=-1, interactive=True) with gr.Row(): train_batch_size_inp = gr.Number(value=BATCH_SIZE, label="Batch Size (Per Device)", precision=0, minimum=1, interactive=True) train_grad_accum_inp = gr.Number(value=GRADIENT_ACCUMULATION_STEPS, label="Grad Accum Steps", precision=0, minimum=1, interactive=True) train_optim_selector = gr.Dropdown(choices=list(OPTIMIZERS.keys()), value=DEFAULT_OPTIMIZER, label="Optimizer", interactive=True) with gr.Row(): train_scheduler_selector = gr.Dropdown(choices=SCHEDULER_TYPES, value=DEFAULT_SCHEDULER, label="LR Scheduler", interactive=True) train_wd_inp = gr.Number(value=0.01, label="Weight Decay", minimum=0.0, interactive=True, step=0.001, precision=4) train_warmup_ratio_inp = gr.Slider(0.0, 0.5, value=0.03, step=0.01, label="Warmup Ratio", interactive=True) with gr.Accordion("PEFT Configuration (if PEFT enabled)", open=False, visible=True) as peft_config_accordion: peft_r_inp = gr.Slider(label="LoRA r (Rank)", minimum=1, maximum=256, value=8, step=1, interactive=True) peft_alpha_inp = gr.Slider(label="LoRA alpha", minimum=1, maximum=512, value=32, step=1, interactive=True) peft_dropout_inp = gr.Slider(label="LoRA Dropout", minimum=0.0, maximum=0.5, value=0.1, step=0.01, interactive=True) peft_target_modules_inp = gr.Textbox(label="Target Modules (comma-sep, optional, e.g., q_proj,v_proj)", placeholder="Leave empty for auto-detection (recommended)", interactive=True) train_use_peft_chk.change(lambda x: gr.update(visible=x), inputs=train_use_peft_chk, outputs=peft_config_accordion) with gr.Accordion("Post-Training Modifications (Applied After Training)", open=False): with gr.Row(): train_post_activation_fn_selector = gr.Dropdown(choices=list(ACTIVATION_FUNCTIONS.keys()), value=DEFAULT_ACTIVATION_FUNCTION, label="Target Activation Fn") train_post_target_layers_inp = gr.Number(value=LAYERS, label="Target Layer Count", precision=0, minimum=1) with gr.Accordion("Hardware & Logging", open=False): train_use_cpu_chk = gr.Checkbox(value=USE_CPU, label="Force Use CPU (Very Slow!)", interactive=True) gr.Markdown("### 4. Start Training") train_button = gr.Button("✨ Start Training Process", variant="primary") train_output = gr.Textbox(label="Training Log & Status", interactive=False, lines=20, max_lines=50) with gr.TabItem("🔧 Model Controls"): gr.Markdown("Interactively toggle modifications and filters for the **currently loaded** model. Refresh status after changes.") with gr.Row(): refresh_status_button = gr.Button("🔄 Refresh Status & Filter Checkboxes") control_output = gr.Textbox(label="Control Action Status", interactive=False, lines=1) status_output = gr.TextArea(label="Current Model Status (JSON)", interactive=False, lines=20, max_lines=60) with gr.Tabs(): with gr.TabItem("Core & Structure"): with gr.Row(): with gr.Column(min_width=150): bias_on = gr.Button("Bias Rem. ✅"); bias_off = gr.Button("Bias Rem. ❌") with gr.Column(min_width=150): emb_on = gr.Button("Emb. Untie ✅"); emb_off = gr.Button("Emb. 
Untie ❌") layer_target_inp = gr.Number(value=LAYERS, label="Target Layers", precision=0, minimum=1, interactive=True, scale=1) layer_red_on = gr.Button("Apply Layer Red.", scale=1) layer_red_off = gr.Button("Revert Layer Red.", scale=1) with gr.Row(): with gr.Column(min_width=150): norm_swap_rms = gr.Button("Use RMSNorm"); norm_swap_ln = gr.Button("Use LayerNorm") act_select = gr.Dropdown(choices=list(ACTIVATION_FUNCTIONS.keys()), value=DEFAULT_ACTIVATION_FUNCTION, label="Change ActFn") act_revert = gr.Button("Revert ActFn") with gr.Column(min_width=150): bitnet_on = gr.Button("BitNet ✅"); bitnet_off = gr.Button("BitNet ❌") with gr.Accordion("Multi-Modal Conversion (Experimental)", open=False): gr.Markdown("⚠️ **Experimental:** Adds modality-specific encoders (e.g., ViT, Whisper) and projection layers. **Requires manual `forward` pass adaptation & multi-modal data/training.**") modality_checkboxes_ui = gr.CheckboxGroup(choices=AVAILABLE_MODALITIES, label="Select Modalities") with gr.Row(): apply_multimodal_button = gr.Button("Apply Multi-Modal Setup") revert_multimodal_button = gr.Button("Revert Multi-Modal Setup") with gr.TabItem("Performance & Opt."): with gr.Row(): with gr.Column(min_width=150): speed_on = gr.Button("Speed Opt. ✅"); speed_off = gr.Button("Speed Opt. ❌") with gr.Column(min_width=150): coher_on = gr.Button("Coherence ✅"); coher_off = gr.Button("Coherence ❌") with gr.Column(min_width=150): ln_bypass_on = gr.Button("LN Bypass ✅"); ln_bypass_off = gr.Button("LN Bypass ❌") with gr.Row(): with gr.Column(min_width=150): do_bypass_on = gr.Button("Dropout Bypass ✅"); do_bypass_off = gr.Button("Dropout Bypass ❌") with gr.Column(min_width=150): prec_on = gr.Button("FP32 Prec. ✅"); prec_off = gr.Button("FP32 Prec. ❌") with gr.Column(min_width=150): norm_emb_on = gr.Button("Emb. Norm. ✅"); norm_emb_off = gr.Button("Emb. Norm. ❌") with gr.Row(): with gr.Column(min_width=150): gc_cp_on = gr.Button("Grad Checkpoint ✅"); gc_cp_off = gr.Button("Grad Checkpoint ❌") with gr.Column(min_width=150): flash_attn_on = gr.Button("Flash Attn 2 ✅"); flash_attn_off = gr.Button("Flash Attn 2 ❌") with gr.Accordion("Quantization & Pruning", open=False): with gr.Row(): quant_select = gr.Dropdown(choices=QUANTIZATION_MODES, value=DEFAULT_QUANTIZATION, label="Quantize To") quant_apply = gr.Button("Apply Quant.") quant_revert = gr.Button("Revert Quant.") with gr.Row(): prune_amount_inp = gr.Slider(0.01, 0.95, value=PRUNING_AMOUNT, step=0.01, label="Prune Amount") prune_apply = gr.Button("Apply Pruning") prune_revert = gr.Button("Revert Pruning") with gr.TabItem("PEFT Adapters"): gr.Markdown("Add, merge, or remove LoRA/PEFT adapters from the currently loaded model.") peft_lora_path_input = gr.Textbox(label="LoRA/PEFT Adapter Path or Hub ID", placeholder="username/my-lora-adapter") with gr.Row(): peft_set_path_btn = gr.Button("Set Path in Config") peft_add_adapter_btn = gr.Button("Add Default Adapter") peft_merge_btn = gr.Button("Merge Active Adapter") peft_remove_adapter_btn = gr.Button("Remove/Unload Adapter") with gr.TabItem("Advanced Config & Layers"): with gr.Row(): freeze_input = gr.Textbox(label="Layers to Freeze (e.g., '0-3, 7, 10-11')") freeze_apply = gr.Button("🧊 Freeze") freeze_revert = gr.Button("🔥 Unfreeze All") with gr.Row(): with gr.Column(min_width=150): lim_on = gr.Button("Limits Cfg ✅"); lim_off = gr.Button("Limits Cfg ❌") with gr.Column(min_width=150): qa_on = gr.Button("QA Restrict Rem. ✅"); qa_off = gr.Button("QA Restrict Rem. 
❌") layerdrop_prob_inp = gr.Slider(0.0, 0.5, value=0.1, step=0.01, label="LayerDrop Prob") layerdrop_on = gr.Button("LayerDrop Flag ✅") layerdrop_off = gr.Button("LayerDrop Flag ❌") with gr.Accordion("RoPE, Sliding Window, Attention Variant (Require Model Reload)", open=False): gr.Markdown("**Warning:** These settings modify the config but require reloading the model to take effect.") with gr.Row(): rope_type_inp = gr.Dropdown(label="RoPE Type", choices=["linear", "dynamic"], value="linear") rope_factor_inp = gr.Number(label="RoPE Factor (>=1.0)", value=2.0, minimum=1.0, step=0.1) rope_apply_btn = gr.Button("Set RoPE") rope_revert_btn = gr.Button("Revert RoPE") with gr.Row(): sw_size_inp = gr.Number(label="Sliding Window Size (0=disable)", value=4096, minimum=0, step=64) sw_apply_btn = gr.Button("Set Sliding Window") sw_revert_btn = gr.Button("Revert Sliding Window") with gr.Row(): attn_variant_inp = gr.Dropdown(label="Attention Implementation", choices=["auto", "eager", "sdpa", "flash_attention_2"], value="auto") attn_apply_btn = gr.Button("Set Attention Variant") attn_revert_btn = gr.Button("Revert Attention Variant") with gr.Accordion("KD & Reward Heads (Experimental - Requires Training Changes)", open=False): with gr.Row(): kd_labels_inp = gr.Number(label="KD Num Labels", value=2, minimum=1, precision=0) kd_setup_btn = gr.Button("Setup KD Head") kd_revert_btn = gr.Button("Revert KD Head") with gr.Row(): rm_outputs_inp = gr.Number(label="RM Num Outputs", value=1, minimum=1, precision=0) rm_setup_btn = gr.Button("Setup RM Head") rm_revert_btn = gr.Button("Revert RM Head") with gr.Accordion("Other Flags (Symbolic - May Require Specific Training Logic)", open=False): with gr.Row(): swa_on = gr.Button("SWA Flag ✅"); swa_off = gr.Button("SWA Flag ❌") ke_on = gr.Button("Know. Edit Flag ✅"); ke_off = gr.Button("Know. Edit Flag ❌") hp_on = gr.Button("Head Prune Flag ✅"); hp_off = gr.Button("Head Prune Flag ❌") with gr.Row(): qat_on = gr.Button("QAT Flag ✅"); qat_off = gr.Button("QAT Flag ❌") gn_on = gr.Button("Grad Noise Flag ✅"); gn_off = gr.Button("Grad Noise Flag ❌") wi_on = gr.Button("Weight Init Flag ✅"); wi_off = gr.Button("Weight Init Flag ❌") with gr.TabItem("Training Param Flags"): gr.Markdown("Toggle flags in the config that affect **subsequent** Trainer initialization (won't affect current training).") with gr.Row(): with gr.Column(min_width=150): gc_flag_on = gr.Button("GradClip Flg ✅"); gc_flag_off = gr.Button("GradClip Flg ❌") with gr.Column(min_width=150): wd_flag_on = gr.Button("WD Flg ✅"); wd_flag_off = gr.Button("WD Flg ❌") with gr.Column(min_width=150): lr_flag_on = gr.Button("LR Sched. Flg ✅"); lr_flag_off = gr.Button("LR Sched. Flg ❌") with gr.Row(): optim_flag_select = gr.Dropdown(choices=list(OPTIMIZERS.keys()), value=DEFAULT_OPTIMIZER, label="Set Optim. Pref") optim_flag_apply = gr.Button("Apply Optim.") optim_flag_revert = gr.Button("Revert Optim.") with gr.Row(): grad_accum_ui_inp_config = gr.Number(value=GRADIENT_ACCUMULATION_STEPS, label="Grad Accum Steps (Config)", precision=0, minimum=1) grad_accum_set_btn = gr.Button("Set Grad Accum") with gr.TabItem("🔒 Safety & Content Filters"): gr.Markdown("Control safety filter flags in the model's config. 
Actual filtering effectiveness depends on the inference implementation.") with gr.Row(): safety_all_on = gr.Button("🔒 Enable ALL Filters (Defaults)", variant="secondary") safety_all_off = gr.Button("🔓 Disable ALL Filters", variant="stop") gr.Markdown("Individual Filter Toggles:") filter_checkboxes = [] num_cols = 4 for i in range(0, len(filter_names_ui), num_cols): with gr.Row(): for j in range(num_cols): idx = i + j if idx < len(filter_names_ui): name = filter_names_ui[idx] cb = gr.Checkbox(label=name, value=False, interactive=True) filter_checkboxes.append(cb) else: gr.HTML("") apply_filters_button = gr.Button("Apply Individual Filter Toggles", variant="secondary") with gr.TabItem("💬 Inference"): gr.Markdown("Test the **currently loaded and configured** model.") with gr.Row(): inference_prompt = gr.Textbox(label="Enter Prompt", lines=4, placeholder="Once upon a time...") inference_output = gr.Textbox(label="Model Response", interactive=False, lines=15) with gr.Accordion("Generation Parameters", open=True): with gr.Row(): max_new_tokens_slider = gr.Slider(10, 4096, value=256, step=10, label="Max New Tokens", interactive=True) temperature_slider = gr.Slider(0.0, 2.0, value=0.7, step=0.01, label="Temperature (0=greedy)", interactive=True) with gr.Row(): top_k_slider = gr.Slider(0, 200, value=50, step=1, label="Top-K (0=disable)", interactive=True) top_p_slider = gr.Slider(0.0, 1.0, value=0.95, step=0.01, label="Top-P (0 or 1=disable)", interactive=True) repetition_penalty_slider = gr.Slider(1.0, 3.0, value=1.1, step=0.05, label="Repetition Penalty (1=disable)", interactive=True) generate_button = gr.Button("Generate Response", variant="primary") with gr.TabItem("🚫 Censor Control"): gr.Markdown("## Force Disable Censorship Flags") gr.Markdown("Click the button below to attempt to set all known censorship/filter flags in the loaded model's configuration to `False`. 
This uses the `Disable ALL Filters` function.") censor_off_button = gr.Button("🔓 Attempt Force Disable All Censorship Flags", variant="stop") censor_status = gr.Textbox(label="Censorship Flag Status", interactive=False, lines=2) load_button.click( fn=load_model_for_control, inputs=[load_model_selector, hf_token_read, bypass_limits_chk], outputs=[load_status_output, status_output] + filter_checkboxes ) save_button.click( fn=save_current_model, inputs=[save_path_inp, hf_token_write, save_hub_repo_inp], outputs=save_status_output ) merge_button.click( fn=_merge_architectures, inputs=[merge_model_ids_inp, hf_token_read, bypass_limits_chk], outputs=[merge_status_output, status_output] + filter_checkboxes ) train_button.click( fn=start_training, inputs=[ train_model_selector, train_new_model_inp, hf_token_write, train_datasets_inp, train_post_activation_fn_selector, train_post_target_layers_inp, train_grad_accum_inp, train_lr_inp, train_epochs_inp, train_max_steps_inp, train_batch_size_inp, train_optim_selector, train_scheduler_selector, train_wd_inp, train_warmup_ratio_inp, train_use_peft_chk, peft_r_inp, peft_alpha_inp, peft_dropout_inp, peft_target_modules_inp, train_wandb_token_inp, train_use_cpu_chk, bypass_limits_chk ], outputs=train_output ).then( fn=get_detailed_status_and_filter_states, inputs=None, outputs=[status_output] + filter_checkboxes ) refresh_outputs = [status_output] + filter_checkboxes refresh_status_button.click(fn=get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) def link_control(button, func, inputs=None): processed_inputs = inputs if inputs else [] click_event = button.click(func, inputs=processed_inputs, outputs=control_output) click_event.then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) link_control(bias_on, lambda: toggle_bias_removal_wrapper(True)) link_control(bias_off, lambda: toggle_bias_removal_wrapper(False)) link_control(emb_on, lambda: toggle_embeddings_untie_wrapper(True)) link_control(emb_off, lambda: toggle_embeddings_untie_wrapper(False)) link_control(layer_red_on, lambda layers: toggle_layer_reduction_wrapper(True, layers), inputs=[layer_target_inp]) link_control(layer_red_off, lambda: toggle_layer_reduction_wrapper(False, None)) link_control(norm_swap_rms, lambda: apply_norm_swap_wrapper('RMSNorm')) link_control(norm_swap_ln, lambda: apply_norm_swap_wrapper('LayerNorm')) act_select.change(lambda name: apply_activation_change_wrapper(name), inputs=[act_select], outputs=control_output).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs) link_control(act_revert, revert_activation_change_wrapper) link_control(bitnet_on, lambda: toggle_bitnet_wrapper(True)) link_control(bitnet_off, lambda: toggle_bitnet_wrapper(False)) link_control(apply_multimodal_button, apply_multimodal_wrapper, inputs=[modality_checkboxes_ui]) link_control(revert_multimodal_button, revert_multimodal_wrapper) link_control(speed_on, lambda: toggle_token_speed_optimization_wrapper(True)) link_control(speed_off, lambda: toggle_token_speed_optimization_wrapper(False)) link_control(coher_on, lambda: toggle_coherence_improvement_wrapper(True)) link_control(coher_off, lambda: toggle_coherence_improvement_wrapper(False)) link_control(ln_bypass_on, lambda: toggle_layer_norm_bypass_wrapper(True)) link_control(ln_bypass_off, lambda: toggle_layer_norm_bypass_wrapper(False)) link_control(do_bypass_on, lambda: toggle_dropout_bypass_wrapper(True)) link_control(do_bypass_off, lambda: 
    link_control(prec_on, lambda: toggle_fp32_precision_wrapper(True))
    link_control(prec_off, lambda: toggle_fp32_precision_wrapper(False))
    link_control(norm_emb_on, lambda: toggle_embedding_normalization_wrapper(True))
    link_control(norm_emb_off, lambda: toggle_embedding_normalization_wrapper(False))
    link_control(gc_cp_on, lambda: toggle_gradient_checkpointing_wrapper(True))
    link_control(gc_cp_off, lambda: toggle_gradient_checkpointing_wrapper(False))
    link_control(flash_attn_on, lambda: toggle_flash_attention_wrapper(True))
    link_control(flash_attn_off, lambda: toggle_flash_attention_wrapper(False))
    link_control(quant_apply, apply_quantization_wrapper, inputs=[quant_select])
    link_control(quant_revert, revert_quantization_wrapper)
    link_control(prune_apply, apply_pruning_wrapper, inputs=[prune_amount_inp])
    link_control(prune_revert, revert_pruning_wrapper)
    link_control(peft_set_path_btn, set_lora_path_wrapper, inputs=[peft_lora_path_input])
    link_control(peft_add_adapter_btn, add_peft_adapter_wrapper)
    link_control(peft_merge_btn, merge_peft_adapter_wrapper)
    link_control(peft_remove_adapter_btn, remove_peft_adapter_wrapper)
    link_control(freeze_apply, apply_layer_freeze_wrapper, inputs=[freeze_input])
    link_control(freeze_revert, revert_layer_freeze_wrapper)
    link_control(lim_on, lambda: toggle_limits_wrapper(True))
    link_control(lim_off, lambda: toggle_limits_wrapper(False))
    link_control(qa_on, lambda: toggle_qa_restrictions_wrapper(True))
    link_control(qa_off, lambda: toggle_qa_restrictions_wrapper(False))
    link_control(layerdrop_on, lambda prob: toggle_layerdrop_wrapper(True, prob), inputs=[layerdrop_prob_inp])
    link_control(layerdrop_off, lambda: toggle_layerdrop_wrapper(False))
    link_control(rope_apply_btn, lambda type, factor: toggle_rope_scaling_wrapper(True, type, factor), inputs=[rope_type_inp, rope_factor_inp])
    link_control(rope_revert_btn, lambda: toggle_rope_scaling_wrapper(False))
    link_control(sw_apply_btn, lambda size: toggle_sliding_window_wrapper(True, size), inputs=[sw_size_inp])
    link_control(sw_revert_btn, lambda: toggle_sliding_window_wrapper(False))
    link_control(attn_apply_btn, apply_attention_variant_wrapper, inputs=[attn_variant_inp])
    link_control(attn_revert_btn, revert_attention_variant_wrapper)
    link_control(kd_setup_btn, lambda labels: toggle_kd_wrapper(True, labels), inputs=[kd_labels_inp])
    link_control(kd_revert_btn, lambda: toggle_kd_wrapper(False))
    link_control(rm_setup_btn, lambda outputs: toggle_reward_modeling_wrapper(True, outputs), inputs=[rm_outputs_inp])
    link_control(rm_revert_btn, lambda: toggle_reward_modeling_wrapper(False))
    link_control(swa_on, lambda: specific_action_function(_apply_swa))
    link_control(swa_off, lambda: specific_action_function(_revert_swa))
    link_control(ke_on, lambda: specific_action_function(_apply_knowledge_editing))
    link_control(ke_off, lambda: specific_action_function(_revert_knowledge_editing))
    link_control(hp_on, lambda: specific_action_function(_apply_head_pruning))
    link_control(hp_off, lambda: specific_action_function(_revert_head_pruning))
    link_control(qat_on, lambda: specific_action_function(_apply_qat))
    link_control(qat_off, lambda: specific_action_function(_revert_qat))
    link_control(gn_on, lambda: specific_action_function(_apply_gradient_noise))
    link_control(gn_off, lambda: specific_action_function(_revert_gradient_noise))
    link_control(wi_on, lambda: specific_action_function(_apply_weight_init))
    link_control(wi_off, lambda: specific_action_function(_revert_weight_init))
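    # Training-configuration flags (gradient clipping, weight decay, LR scheduler, optimizer,
    # gradient accumulation), safety-filter controls, and the inference / censorship handlers.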
    link_control(gc_flag_on, lambda: toggle_gradient_clipping_flag_wrapper(True))
    link_control(gc_flag_off, lambda: toggle_gradient_clipping_flag_wrapper(False))
    link_control(wd_flag_on, lambda: toggle_weight_decay_flag_wrapper(True))
    link_control(wd_flag_off, lambda: toggle_weight_decay_flag_wrapper(False))
    link_control(lr_flag_on, lambda: toggle_lr_scheduler_flag_wrapper(True))
    link_control(lr_flag_off, lambda: toggle_lr_scheduler_flag_wrapper(False))
    link_control(optim_flag_apply, apply_optimizer_change_wrapper, inputs=[optim_flag_select])
    link_control(optim_flag_revert, revert_optimizer_change_wrapper)
    link_control(grad_accum_set_btn, set_gradient_accumulation_wrapper, inputs=[grad_accum_ui_inp_config])
    link_control(safety_all_on, lambda: toggle_all_safety_filters_wrapper(True))
    link_control(safety_all_off, lambda: toggle_all_safety_filters_wrapper(False))

    apply_filters_button.click(
        fn=toggle_individual_safety_filter_wrapper,
        inputs=filter_checkboxes,
        outputs=control_output
    ).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs)

    generate_button.click(
        fn=run_inference,
        inputs=[
            inference_prompt, max_new_tokens_slider, temperature_slider,
            top_k_slider, top_p_slider, repetition_penalty_slider
        ],
        outputs=inference_output
    )

    censor_off_button.click(
        fn=force_disable_censorship_wrapper,
        outputs=censor_status
    ).then(get_detailed_status_and_filter_states, inputs=None, outputs=refresh_outputs)

if __name__ == "__main__":
    demo.queue().launch(server_name="0.0.0.0", share=True, debug=False)