# M1-3B-SFT / check_same.py
# Uploaded by JunxiongWang ("Upload folder using huggingface_hub"), commit 61ed118 (verified).
import torch
import torch
# Import the exception type to catch
from torch._subclasses.fake_tensor import DataDependentOutputException
def load_checkpoint(path):
    """Deserialize the checkpoint file at *path* with every storage mapped to CPU.

    ``torch.load`` is pickle-based, so checkpoint files should be treated as
    trusted input.
    """
    cpu_device = torch.device("cpu")
    return torch.load(path, map_location=cpu_device)
def tensors_allclose(val_cur, val_ref, rtol=1e-05, atol=1e-08, equal_nan=False):
    """Return True if two tensors have the same shape and dtype and are element-wise close.

    ``torch.allclose`` alone is not enough for checkpoint comparison:
    - it broadcasts its inputs, so a ``(1,)`` tensor can "match" a ``(3,)`` tensor;
    - it raises ``RuntimeError`` when the dtypes differ.
    Both cases are reported here as "not close" instead.

    Args:
        val_cur: Tensor from the checkpoint being checked.
        val_ref: Tensor from the reference checkpoint.
        rtol: Relative tolerance forwarded to ``torch.allclose``.
        atol: Absolute tolerance forwarded to ``torch.allclose``.
        equal_nan: If True, NaNs in matching positions count as equal.
            The defaults match ``torch.allclose``'s own defaults.

    Returns:
        True if the tensors match, False otherwise.
    """
    # Guard against broadcasting and the dtype-mismatch RuntimeError.
    if val_cur.shape != val_ref.shape or val_cur.dtype != val_ref.dtype:
        return False
    try:
        return torch.allclose(val_cur, val_ref, rtol=rtol, atol=atol, equal_nan=equal_nan)
    except DataDependentOutputException:
        # Fake tensors cannot produce a data-dependent boolean, so retry on
        # detached CPU copies.
        # NOTE(review): this only helps if detach().cpu() yields a real tensor
        # for the values stored in these checkpoints -- confirm.
        real_val_cur = val_cur.detach().cpu()
        real_val_ref = val_ref.detach().cpu()
        return torch.allclose(
            real_val_cur, real_val_ref, rtol=rtol, atol=atol, equal_nan=equal_nan
        )
def _checkpoint_values_equal(val_cur, val_ref):
    """Return True if two checkpoint entries match, recursing into dicts, lists and tuples.

    Tensors are compared with ``tensors_allclose``, and a tensor never matches
    a non-tensor. Containers are compared element-wise, which avoids the
    "Boolean value of Tensor ... is ambiguous" error that plain ``!=`` raises
    on containers holding tensors. Other leaves fall back to ``==``.
    """
    cur_is_tensor = isinstance(val_cur, torch.Tensor)
    ref_is_tensor = isinstance(val_ref, torch.Tensor)
    if cur_is_tensor or ref_is_tensor:
        return cur_is_tensor and ref_is_tensor and tensors_allclose(val_cur, val_ref)
    if isinstance(val_ref, dict):
        return (
            isinstance(val_cur, dict)
            and set(val_cur) == set(val_ref)
            and all(_checkpoint_values_equal(val_cur[k], val_ref[k]) for k in val_ref)
        )
    if isinstance(val_ref, (list, tuple)):
        # Keep list vs. tuple distinct, as ``!=`` does.
        same_kind = (
            isinstance(val_cur, list) if isinstance(val_ref, list) else isinstance(val_cur, tuple)
        )
        return (
            same_kind
            and len(val_cur) == len(val_ref)
            and all(_checkpoint_values_equal(c, r) for c, r in zip(val_cur, val_ref))
        )
    return val_cur == val_ref


def _report_checkpoint_differences(path, ckpt, ref_path, ref_ckpt):
    """Print every difference between *ckpt* and *ref_ckpt*; return True if none were found."""
    if type(ckpt) is not type(ref_ckpt):
        print(f"Type mismatch: {path} is of type {type(ckpt)}, expected {type(ref_ckpt)}")
        return False
    if not isinstance(ckpt, dict):
        # Non-dict checkpoints are compared as a whole.
        if _checkpoint_values_equal(ckpt, ref_ckpt):
            return True
        print(f"Checkpoint {path} differs from {ref_path}.")
        return False
    cur_keys, ref_keys = set(ckpt), set(ref_ckpt)
    if cur_keys != ref_keys:
        print(f"Key mismatch in {path}.")
        print(f"  missing keys: {sorted(ref_keys - cur_keys, key=str)}")
        print(f"  unexpected keys: {sorted(cur_keys - ref_keys, key=str)}")
        return False
    same = True
    for key in ckpt:
        val_cur, val_ref = ckpt[key], ref_ckpt[key]
        if _checkpoint_values_equal(val_cur, val_ref):
            continue
        same = False
        if isinstance(val_cur, torch.Tensor) and isinstance(val_ref, torch.Tensor):
            print(f"Tensor values differ for key '{key}' in {path}.")
        else:
            print(f"Value for key '{key}' differs in {path}.")
    return same


def compare_checkpoints(paths):
    """Compare every checkpoint in *paths* against the first one and print the differences.

    The first path is the reference. The other checkpoints are loaded one at a
    time and released before the next load, so at most two checkpoints
    (reference + current) are held at once. The reference is not compared
    against itself.

    Args:
        paths: Iterable of checkpoint file paths; the first one is the reference.

    Returns:
        True if every checkpoint matches the reference, False otherwise.

    Raises:
        ValueError: If *paths* is empty.
    """
    paths = list(paths)
    if not paths:
        raise ValueError("compare_checkpoints() needs at least one checkpoint path.")
    ref_path = paths[0]
    ref_ckpt = load_checkpoint(ref_path)
    all_same = True
    for path in paths[1:]:
        ckpt = load_checkpoint(path)
        if not _report_checkpoint_differences(path, ckpt, ref_path, ref_ckpt):
            all_same = False
        # Drop this checkpoint before the next load to bound peak memory.
        del ckpt
    if all_same:
        print("All checkpoints are identical.")
    else:
        print("Not all checkpoints are identical.")
    return all_same
if __name__ == "__main__":
    import sys

    # Usage: python check_same.py [CHECKPOINT_DIR] [WORLD_SIZE]
    # With no arguments, this compares the same 8 rank shards as before.
    ckpt_dir = (
        sys.argv[1]
        if len(sys.argv) > 1
        else "grpo_cg_packfix_128_24576_1_1e-6_0_0_1_new_veRL/global_step_40/actor"
    )
    world_size = int(sys.argv[2]) if len(sys.argv) > 2 else 8
    # Generate file paths for ranks 0 through world_size - 1.
    paths = [
        f"{ckpt_dir}/model_world_size_{world_size}_rank_{i}.pt" for i in range(world_size)
    ]
    compare_checkpoints(paths)