File size: 2,832 Bytes
61ed118
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import torch

import torch

# Import the exception type to catch
from torch._subclasses.fake_tensor import DataDependentOutputException

def load_checkpoint(path):
    """Load the checkpoint file at ``path`` and return its contents.

    Storages are remapped to the CPU via ``map_location`` so the file can be
    read regardless of the device it was saved from.

    NOTE: ``torch.load`` is pickle-based; only point this at trusted files.
    """
    cpu_device = torch.device('cpu')
    checkpoint = torch.load(path, map_location=cpu_device)
    return checkpoint

def tensors_allclose(val_cur, val_ref):
    """Return True when two tensors have the same shape, dtype and close values.

    Args:
        val_cur: Tensor from the checkpoint under test.
        val_ref: Tensor from the reference checkpoint.

    Returns:
        bool: False if the shapes or dtypes differ, otherwise the result of
        ``torch.allclose`` with its default tolerances.
    """
    # torch.allclose broadcasts its operands, so a (1,) tensor could be
    # reported as "close" to a (3,) tensor; it also rejects mismatched dtypes
    # with an error. For a checkpoint identity check both are plain mismatches.
    if val_cur.shape != val_ref.shape or val_cur.dtype != val_ref.dtype:
        return False

    try:
        return bool(torch.allclose(val_cur, val_ref))
    except torch._subclasses.fake_tensor.DataDependentOutputException:
        # Fallback kept from the original implementation: retry on detached
        # CPU copies when the direct comparison trips the fake-tensor check.
        # NOTE(review): confirm detach().cpu() really yields real tensors here.
        real_val_cur = val_cur.detach().cpu() if hasattr(val_cur, "detach") else val_cur
        real_val_ref = val_ref.detach().cpu() if hasattr(val_ref, "detach") else val_ref
        return bool(torch.allclose(real_val_cur, real_val_ref))

def _values_match(cur, ref):
    """Return True when two checkpoint entries are equivalent, recursing into containers."""
    cur_is_tensor = isinstance(cur, torch.Tensor)
    ref_is_tensor = isinstance(ref, torch.Tensor)
    if cur_is_tensor or ref_is_tensor:
        # A tensor never matches a non-tensor; avoids an ambiguous-bool error
        # from evaluating `tensor != other` in an `if`.
        return cur_is_tensor and ref_is_tensor and tensors_allclose(cur, ref)
    if isinstance(cur, dict) and isinstance(ref, dict):
        return cur.keys() == ref.keys() and all(
            _values_match(cur[key], ref[key]) for key in cur
        )
    if isinstance(cur, (list, tuple)) and isinstance(ref, (list, tuple)):
        return (
            type(cur) is type(ref)
            and len(cur) == len(ref)
            and all(_values_match(a, b) for a, b in zip(cur, ref))
        )
    return not (cur != ref)


def compare_checkpoints(paths):
    """Compare every checkpoint in ``paths`` against the first one and report differences.

    The first path is the reference. Each checkpoint must have the same type
    as the reference; dictionaries are compared key by key (tensors through
    ``tensors_allclose``, other values recursively), anything else is compared
    as a single value. Differences and a final verdict are printed.

    Args:
        paths: Iterable of checkpoint file paths; duplicates are compared once.

    Returns:
        bool: True when no difference was reported (also for empty ``paths``),
        False otherwise.
    """
    # De-duplicate while preserving order (the first path stays the reference).
    unique_paths = list(dict.fromkeys(paths))
    if not unique_paths:
        print("No checkpoint paths given; nothing to compare.")
        return True

    # Use the first checkpoint as the reference.
    ref_path = unique_paths[0]
    ref_ckpt = load_checkpoint(ref_path)
    all_same = True

    for path in unique_paths:
        # Load one checkpoint at a time so only the reference and the current
        # file are held in memory. The reference is still checked against
        # itself, which surfaces e.g. NaN entries.
        ckpt = ref_ckpt if path == ref_path else load_checkpoint(path)

        # Check if the types match.
        if type(ckpt) is not type(ref_ckpt):
            print(f"Type mismatch: {path} is of type {type(ckpt)}, expected {type(ref_ckpt)}")
            all_same = False
            continue

        # If the checkpoints are not dictionaries, compare them as one value.
        if not isinstance(ckpt, dict):
            if not _values_match(ckpt, ref_ckpt):
                print(f"Checkpoint {path} differs from {ref_path}.")
                all_same = False
            continue

        # Dictionary checkpoints: compare keys first, then values.
        if set(ckpt.keys()) != set(ref_ckpt.keys()):
            print(f"Key mismatch in {path}.")
            all_same = False
            continue

        for key in ckpt:
            val_ref = ref_ckpt[key]
            val_cur = ckpt[key]

            if isinstance(val_ref, torch.Tensor) and isinstance(val_cur, torch.Tensor):
                if not tensors_allclose(val_cur, val_ref):
                    print(f"Tensor values differ for key '{key}' in {path}.")
                    all_same = False
            elif not _values_match(val_cur, val_ref):
                print(f"Value for key '{key}' differs in {path}.")
                all_same = False

    if all_same:
        print("All checkpoints are identical.")
    else:
        print("Not all checkpoints are identical.")
    return all_same

# One shard file per rank (0 through 7) under the step-40 actor directory.
paths = [
    f"grpo_cg_packfix_128_24576_1_1e-6_0_0_1_new_veRL/global_step_40/actor/model_world_size_8_rank_{rank}.pt"
    for rank in range(8)
]

# Compare every rank's shard against rank 0 and print the verdict.
compare_checkpoints(paths)