import torch
from torch._subclasses.fake_tensor import DataDependentOutputException


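# Load a checkpoint onto the CPU so shards saved from different devices can be compared locally.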
def load_checkpoint(path):
    return torch.load(path, map_location=torch.device('cpu'))


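# Compare two tensors with torch.allclose. Under fake-tensor tracing, allclose raises
# DataDependentOutputException because tensor values cannot be inspected; in that case,
# retry the comparison on detached CPU copies.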
def tensors_allclose(val_cur, val_ref):
    try:
        if not torch.allclose(val_cur, val_ref):
            return False
    except DataDependentOutputException:
        real_val_cur = val_cur.detach().cpu() if hasattr(val_cur, "detach") else val_cur
        real_val_ref = val_ref.detach().cpu() if hasattr(val_ref, "detach") else val_ref
        if not torch.allclose(real_val_cur, real_val_ref):
            return False
    return True


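# Compare every checkpoint against the first one (the reference). Dict checkpoints are
# compared key by key: tensors with tensors_allclose, everything else with !=.
# Non-dict checkpoints are compared directly with !=.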
def compare_checkpoints(paths):
    checkpoints = {path: load_checkpoint(path) for path in paths}

    # Use the first checkpoint as the reference that all others are compared against.
    ref_path, ref_ckpt = list(checkpoints.items())[0]
    all_same = True

    for path, ckpt in checkpoints.items():
        if type(ckpt) != type(ref_ckpt):
            print(f"Type mismatch: {path} is of type {type(ckpt)}, expected {type(ref_ckpt)}")
            all_same = False
            continue

        if isinstance(ckpt, dict):
            if set(ckpt.keys()) != set(ref_ckpt.keys()):
                print(f"Key mismatch in {path}.")
                all_same = False
                continue
            for key in ckpt:
                val_ref = ref_ckpt[key]
                val_cur = ckpt[key]

                if isinstance(val_ref, torch.Tensor) and isinstance(val_cur, torch.Tensor):
                    if not tensors_allclose(val_cur, val_ref):
                        print(f"Tensor values differ for key '{key}' in {path}.")
                        all_same = False
                else:
                    if val_cur != val_ref:
                        print(f"Value for key '{key}' differs in {path}.")
                        all_same = False
        else:
            if ckpt != ref_ckpt:
                print(f"Checkpoint {path} differs from {ref_path}.")
                all_same = False

    if all_same:
        print("All checkpoints are identical.")
    else:
        print("Not all checkpoints are identical.")


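# Per-rank actor checkpoint shards (world size 8, ranks 0-7) from the same training step.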
paths = [f"grpo_cg_packfix_128_24576_1_1e-6_0_0_1_new_veRL/global_step_40/actor/model_world_size_8_rank_{i}.pt" for i in range(8)]

compare_checkpoints(paths)