import json
import logging
import os
import re
import sys
from contextlib import contextmanager

import torch
import torch.distributed as dist
from torch.distributed._tensor import DTensor, Shard, distribute_tensor
from transformers import AutoConfig, PretrainedConfig

logger = logging.getLogger(__name__)
@contextmanager
def rank_0_priority():
    """Let rank 0 run the guarded block before all other ranks.

    Rank 0 executes the block and then reaches the barrier; every other rank
    waits at the barrier first, so rank 0's side effects (e.g. writing a
    cache or config file) are visible before the remaining ranks proceed.
    """
    rank = dist.get_rank()
    if rank == 0:
        yield
        dist.barrier()
    else:
        dist.barrier()
        yield
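
# Example (a minimal sketch, assuming torch.distributed is already initialized,
# e.g. via torchrun; `build_cache` is a hypothetical helper):
#
#   with rank_0_priority():
#       cache = build_cache()  # rank 0 populates the cache before other ranks read it
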
@contextmanager
def default_torch_dtype(dtype: torch.dtype):
    """Temporarily set the default torch dtype, restoring the previous one on exit."""
    current_dtype = torch.get_default_dtype()
    torch.set_default_dtype(dtype)
    try:
        yield
    finally:
        # Restore the previous default even if the guarded block raises.
        torch.set_default_dtype(current_dtype)
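
# Example (a minimal sketch):
#
#   with default_torch_dtype(torch.bfloat16):
#       layer = torch.nn.Linear(16, 16)  # parameters are created in bf16
#   assert layer.weight.dtype == torch.bfloat16
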
@torch.no_grad()
def padding(tensor, left=True):
    """Shift a (batch, seq_len, ...) tensor by one position along dim 1, zero-filling the gap.

    left=True prepends a zero column and drops the last position;
    left=False drops the first position and appends a zero column.
    """
    zeropadding = torch.zeros_like(tensor[:, -1:])
    if left:
        tensor = torch.cat((zeropadding, tensor[:, :-1]), dim=1)
    else:
        tensor = torch.cat((tensor[:, 1:], zeropadding), dim=1)
    return tensor
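
# Example (a minimal sketch):
#
#   x = torch.arange(6).reshape(1, 3, 2)  # [[[0, 1], [2, 3], [4, 5]]]
#   padding(x, left=True)                 # [[[0, 0], [0, 1], [2, 3]]]
#   padding(x, left=False)                # [[[2, 3], [4, 5], [0, 0]]]
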
def load_config_from_file(config_path: str):
    """Load a JSON config file and build a transformers PretrainedConfig from it."""
    with open(config_path, "r") as f:
        config = json.load(f)
    return PretrainedConfig.from_dict(config)
def print_with_rank(message):
    """Log a message prefixed with the current distributed rank, if one exists."""
    if dist.is_available() and dist.is_initialized():
        logger.info(f"rank {dist.get_rank()}: {message}")
    else:
        logger.info(f"non-distributed: {message}")
def print_args_with_dots(args):
    """Pretty-print parsed args on rank 0, aligning values with a dotted leader."""
    if dist.get_rank() == 0:
        args_dict = vars(args)
        max_key_length = max(len(key) for key in args_dict.keys())
        total_width = 50
        print("\n -----------【args】-----------")
        for key, value in args_dict.items():
            key_str = f"{key:<{max_key_length}}"
            value_str = str(value)
            # Guard against long values: never let the dot fill go negative.
            dot_count = max(total_width - len(key_str) - len(value_str), 1)
            dot_fill = "·" * dot_count
            print(f"{key_str} {dot_fill} {value_str}")
def print_on_rank0(message):
    """Log a message only on rank 0."""
    if dist.get_rank() == 0:
        logger.info(message)
def get_last_checkpoint(folder, prefix="epoch"):
    """Return the path of the highest-numbered `{prefix}_<N>` checkpoint directory
    inside `folder`, or None if no such directory exists."""
    content = os.listdir(folder)
    _re_checkpoint = re.compile(r"^" + prefix + r"_(\d+)$")
    checkpoints = [
        path
        for path in content
        if _re_checkpoint.search(path) is not None
        and os.path.isdir(os.path.join(folder, path))
    ]
    if len(checkpoints) == 0:
        return None
    return os.path.join(
        folder,
        max(checkpoints, key=lambda x: int(_re_checkpoint.search(x).groups()[0])),
    )
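
# Example (a minimal sketch; assumes an output directory laid out as
# out/epoch_0, out/epoch_1, ...; `resume_from` is a hypothetical helper):
#
#   last = get_last_checkpoint("out", prefix="epoch")  # -> "out/epoch_1"
#   if last is not None:
#       resume_from(last)
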
def generate_draft_model_config(
target_model_path: str, template_config_path: str = None, cache_dir: str = None
):
"""
Auto-generate draft model config based on target model parameters aligned with template config
Args:
target_model_path (str): Path to the target model
template_config_path (str, optional): Template config file path, defaults to llama3-8B-eagle3.json
cache_dir (str, optional): Cache directory
Returns:
dict: Generated draft model config dictionary
"""
# Get target model config
target_config = AutoConfig.from_pretrained(target_model_path, cache_dir=cache_dir)
    # If no template is specified, fall back to the default llama3-8B-eagle3.json
    if template_config_path is None:
        # Resolve the template relative to the script's location
        script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        project_root = os.path.dirname(script_dir)  # Go up one level from scripts/
        template_config_path = os.path.join(
            project_root, "configs", "llama3-8B-eagle3.json"
        )
# Read template config
with open(template_config_path, "r") as f:
draft_config = json.load(f)
    # The draft model always uses the llama architecture, regardless of the
    # target model's own model_type
    if hasattr(target_config, "model_type"):
        draft_config["model_type"] = "llama"
# Align key parameters
param_mappings = {
"vocab_size": "vocab_size",
"hidden_size": "hidden_size",
"num_attention_heads": "num_attention_heads",
"num_key_value_heads": "num_key_value_heads",
"intermediate_size": "intermediate_size",
"max_position_embeddings": "max_position_embeddings",
"rms_norm_eps": "rms_norm_eps",
"hidden_act": "hidden_act",
"bos_token_id": "bos_token_id",
"eos_token_id": "eos_token_id",
"torch_dtype": "torch_dtype",
}
# Copy parameters from target model to draft config
for target_param, draft_param in param_mappings.items():
if hasattr(target_config, target_param):
value = getattr(target_config, target_param)
# Special handling for torch_dtype to make it JSON serializable
if target_param == "torch_dtype" and isinstance(value, torch.dtype):
value = str(value).replace("torch.", "")
draft_config[draft_param] = value
    # EAGLE3-specific override: the draft model always uses a single decoder layer
    draft_config["num_hidden_layers"] = 1
# Keep some fixed draft model specific parameters
draft_config["tie_word_embeddings"] = False
draft_config["use_cache"] = True
# If template doesn't have draft_vocab_size, set default
if "draft_vocab_size" not in draft_config:
draft_config["draft_vocab_size"] = 32000 # Default value
return draft_config
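
# Example (a minimal sketch; the target model path is a placeholder):
#
#   cfg = generate_draft_model_config("meta-llama/Meta-Llama-3-8B-Instruct")
#   cfg["num_hidden_layers"]  # -> 1
#   cfg["hidden_size"]        # matches the target model's hidden_size
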
def save_draft_model_config(config_dict: dict, output_path: str):
"""
Save draft model config to file
Args:
config_dict (dict): Config dictionary
output_path (str): Output file path
"""
    # os.path.dirname is "" when output_path has no directory component;
    # only create parent directories when one is actually present.
    parent_dir = os.path.dirname(output_path)
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with open(output_path, "w", encoding="utf-8") as f:
        json.dump(config_dict, f, indent=2, ensure_ascii=False)
    print(f"Draft model config saved to: {output_path}")
def create_draft_config_from_target(
target_model_path: str,
output_dir: str = None,
template_config_path: str = None,
cache_dir: str = None,
):
"""
Convenient function to create draft model config file from target model
Args:
target_model_path (str): Target model path
output_dir (str, optional): Output directory, defaults to configs folder in current directory
template_config_path (str, optional): Template config path
cache_dir (str, optional): Cache directory
Returns:
str: Generated config file path
"""
# Generate config
rank = dist.get_rank()
if rank == 0:
print_with_rank(
"No draft model config provided, auto-generating from target model..."
)
config_dict = generate_draft_model_config(
target_model_path, template_config_path, cache_dir
)
dist.barrier()
    # Determine output path
    if output_dir is None:
        # Resolve the output directory relative to the script's location
        script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
        project_root = os.path.dirname(script_dir)  # Go up one level from scripts/
        output_dir = os.path.join(project_root, "configs")
    # Derive the config filename from the model path (robust to trailing slashes)
    model_name = os.path.basename(os.path.normpath(target_model_path)).lower()
    output_filename = f"{model_name}-eagle3-auto.json"
output_path = os.path.join(output_dir, output_filename)
# Save config
if rank == 0:
save_draft_model_config(config_dict, output_path)
print_with_rank(f"Auto-generated draft model config saved to: {output_path}")
dist.barrier()
return output_path
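
# Example (a minimal sketch, assuming torch.distributed is initialized on all
# ranks; the target model path is a placeholder):
#
#   config_path = create_draft_config_from_target("meta-llama/Meta-Llama-3-8B-Instruct")
#   # every rank returns the same path; only rank 0 writes the file
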
def get_full_optimizer_state(optimizer_state_dict: dict):
"""
Convert optimizer state dict with DTensor to full tensors for saving
Args:
optimizer_state_dict (dict): Optimizer state dict possibly containing DTensors
Returns:
dict: Optimizer state dict with full tensors
"""
full_optimizer_state_dict = {
k: v for k, v in optimizer_state_dict.items() if k != "state"
}
if "state" in optimizer_state_dict:
full_optimizer_state_dict["state"] = {
param_id: {
state_key: (
state_tensor.full_tensor()
if isinstance(state_tensor, torch.distributed.tensor.DTensor)
else state_tensor
)
for state_key, state_tensor in param_state.items()
}
for param_id, param_state in optimizer_state_dict["state"].items()
}
return full_optimizer_state_dict
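
# Example (a minimal sketch; `optimizer` is assumed to hold DTensor-sharded
# state built by a parallelized training setup):
#
#   full_sd = get_full_optimizer_state(optimizer.state_dict())
#   if dist.get_rank() == 0:
#       torch.save(full_sd, "optimizer.pt")
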
def shard_optimizer_state_with_dtensor(bf16_optimizer, device_mesh):
    """
    Shards the optimizer state tensors of a BF16Optimizer instance using DTensor.
    Args:
        bf16_optimizer (BF16Optimizer): An instance of BF16Optimizer, which contains
            the actual optimizer (e.g., torch.optim.Adam) as its `.optimizer` attribute.
        device_mesh (DeviceMesh): Mesh along which each state tensor is sharded on dim 0.
    """
    optim = bf16_optimizer.optimizer
    for group in optim.param_groups:
        for p in group["params"]:
            if not isinstance(p, DTensor):
                continue
            state = optim.state.get(p, None)
            if state is None:
                continue
            placements = (Shard(dim=0),)
            for k, v in list(state.items()):
                # Skip the scalar step counter, state that is already a DTensor,
                # and any non-tensor entries.
                if k == "step":
                    continue
                if isinstance(v, DTensor):
                    continue
                if not isinstance(v, torch.Tensor):
                    continue
                state[k] = distribute_tensor(
                    v.to(p.device), device_mesh=device_mesh, placements=placements
                )
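
# Example (a minimal sketch; `bf16_optimizer` comes from a hypothetical
# FSDP/tensor-parallel training setup built elsewhere):
#
#   from torch.distributed.device_mesh import init_device_mesh
#   mesh = init_device_mesh("cuda", (dist.get_world_size(),))
#   shard_optimizer_state_with_dtensor(bf16_optimizer, mesh)
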
def safe_conversations_generator(file_path):
"""
Generator that:
1. Extracts the 'conversations' field.
2. Preserves all original fields within each message.
3. [Key step] Converts all list/dict-type field values to strings to resolve mixed-type conflicts (e.g., for Arrow compatibility).
"""
with open(file_path, "r", encoding="utf-8") as f:
for i, line in enumerate(f):
line = line.strip()
if not line:
continue
try:
row = json.loads(line)
raw_convs = row.get("conversations", [])
# 1. Ensure 'conversations' is a list
if not isinstance(raw_convs, list):
# If it's None or some unexpected type, treat as empty or skip
if raw_convs is None:
raw_convs = []
else:
# Edge case: 'conversations' is a plain string or non-iterable—skip this line
logger.warning(
f"Line {i + 1}: 'conversations' is not a list. Please check!"
)
continue
cleaned_convs = []
for msg in raw_convs:
# 2. Ensure each item in the list is a dictionary
if not isinstance(msg, dict):
# Skip if an element is not a dict (e.g., malformed like ["user", "hi"])
continue
# 3. [Core logic] Iterate over all fields in the message (role, content, tools, etc.)
new_msg = {}
for k, v in msg.items():
# If the value is a list or dict, serialize it to a JSON string
# This ensures Arrow treats the column as string type instead of list/struct
if isinstance(v, (list, dict)):
new_msg[k] = json.dumps(v, ensure_ascii=False)
else:
# Keep primitive types (str, int, float, bool, None) unchanged
new_msg[k] = v
cleaned_convs.append(new_msg)
# Yield only the processed 'conversations'
yield {"conversations": cleaned_convs}
except Exception as e:
logger.warning(f"Skipping line {i + 1}: {e}")
continue
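
# Example (a minimal sketch; pairing the generator with Hugging Face `datasets`,
# where "train.jsonl" is a placeholder path):
#
#   from datasets import Dataset
#   ds = Dataset.from_generator(
#       safe_conversations_generator,
#       gen_kwargs={"file_path": "train.jsonl"},
#   )
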