| import os |
| import time |
| import unittest |
|
|
| import torch |
| import torch.multiprocessing as mp |
| from accelerate.utils import set_seed |
| from torch import nn |
| from transformers import PretrainedConfig |
| from yunchang import EXTRACT_FUNC_DICT |
|
|
| from specforge.core.eagle3_adapters import SdpaLikeAdapter, UspAdapter |
| from specforge.data.preprocessing import build_offline_eagle3_dataset |
|
|
| |
| from specforge.distributed import destroy_distributed, init_distributed |
| from specforge.modeling.draft.llama3_eagle import LlamaDecoderLayer |
| from specforge.utils import padding |
| from tests.utils import get_available_port |
|
|
|
|
def get_model_config():
    """Build the PretrainedConfig shared by every decoder layer in this test.

    Mirrors a Llama-style config with EAGLE-3 auxiliary-hidden-state settings;
    a single hidden layer keeps the test fast.
    """
    eagle_settings = {
        "eagle_aux_hidden_state_layer_ids": [1, 29, 57],
        "use_aux_hidden_state": True,
    }
    raw_config = dict(
        architectures=["LlamaForCausalLMEagle3"],
        eagle_config=eagle_settings,
        bos_token_id=128000,
        eos_token_id=128001,
        hidden_act="silu",
        hidden_size=7168,
        initializer_range=0.02,
        intermediate_size=29568,
        max_position_embeddings=32768,
        model_type="llama",
        num_attention_heads=32,
        num_key_value_heads=8,
        num_hidden_layers=1,
        pad_token_id=0,
        rms_norm_eps=1e-05,
        tie_word_embeddings=False,
        torch_dtype="float16",
        transformers_version="4.28.1",
        use_cache=True,
        rope_scaling=None,
        vocab_size=129280,
        draft_vocab_size=32000,
        pretraining_tp=1,
    )
    return PretrainedConfig.from_dict(raw_config)
|
|
|
|
def setup_env(rank, world_size, port):
    """Export the torch.distributed rendezvous variables and pin this rank's GPU."""
    rendezvous = {
        "RANK": str(rank),
        "WORLD_SIZE": str(world_size),
        "MASTER_ADDR": "localhost",
        "MASTER_PORT": str(port),
    }
    os.environ.update(rendezvous)
    torch.cuda.set_device(rank)
|
|
|
|
def dbg(rank, msg):
    """Flush-print a progress line tagged with the emitting rank."""
    line = f"[rank{rank}] {msg}"
    print(line, flush=True)
|
|
|
|
def wait_for_file(path, timeout_s=60, poll_s=0.1):
    """Poll until ``path`` exists or the timeout elapses.

    Args:
        path: Filesystem path to watch.
        timeout_s: Maximum number of seconds to keep polling.
        poll_s: Sleep interval between existence checks.

    Returns:
        True if the file appeared within the timeout, False otherwise.
    """
    # Use a monotonic deadline: time.monotonic() is immune to wall-clock
    # adjustments (NTP steps, manual changes), which could make the original
    # time.time()-based wait expire early or stall.
    deadline = time.monotonic() + timeout_s
    while True:
        # Check before sleeping so an already-present file is detected
        # immediately, even with timeout_s=0.
        if os.path.exists(path):
            return True
        if time.monotonic() >= deadline:
            return False
        time.sleep(poll_s)
|
|
|
|
def run_iterative_pass(
    decoder_layer,
    embed_tokens,
    input_ids,
    hidden_states,
    attention_mask,
    position_ids,
    ttt_length,
):
    """
    Run ``ttt_length`` chained forward passes through ``decoder_layer``.

    The same loop drives both the golden (SDPA/FA) and the distributed (USP)
    runs so any numeric mismatch comes from the attention backend, not from
    differing test logic. Returns the hidden states produced by the final step.
    """
    token_ids = input_ids.clone()
    step_hidden = hidden_states.clone()

    # Rolling cache handed to every step; the decoder layer is expected to
    # read/extend it in place across iterations — TODO confirm against the
    # LlamaDecoderLayer implementation.
    kv_hidden_cache = [[], []]
    last_hidden = None

    for step in range(ttt_length):
        step_embeds = embed_tokens(token_ids).to(step_hidden.dtype)

        step_hidden = decoder_layer(
            input_emb=step_embeds,
            hidden_states=step_hidden,
            cache_hidden=kv_hidden_cache,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=None,
            output_attentions=False,
            use_cache=False,
        )
        last_hidden = step_hidden

        # Shift the token ids for every step except the last one.
        if step < ttt_length - 1:
            token_ids = padding(token_ids, left=False)

    return last_hidden
|
|
|
|
def run_test_case(rank: int, world_size: int, port: int):
    """Per-rank worker: compare a USP-sharded decoder against a golden run.

    Flow:
      1. Build identical random inputs on every rank (same seed everywhere).
      2. Run the iterative pass with the "fa" attention backend on the full
         sequence — this is the golden reference output.
      3. For two USP topologies (Ulysses=2/Ring=1, then Ulysses=1/Ring=2),
         run the same pass on this rank's sequence shard and assert it
         matches the corresponding slice of the golden output.
    """
    setup_env(rank, world_size, port)
    device = torch.device(f"cuda:{rank}")
    set_seed(42)  # identical seed on every rank -> identical random inputs
    dbg(rank, "env setup complete")

    # Test dimensions: single batch, fixed sequence length, 3 TTT steps.
    config = get_model_config()
    seq_len = 1560
    batch_size = 1
    ttt_length = 3

    # Full-sequence inputs; identical across ranks thanks to the shared seed.
    data_input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device)
    data_hidden_states = torch.randn(
        batch_size, seq_len, config.hidden_size, device=device, dtype=torch.bfloat16
    )
    attention_mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).view(
        1, 1, seq_len, seq_len
    )
    position_ids = torch.arange(seq_len, device=device).unsqueeze(0)

    # Embedding table shared by both the golden and the USP runs.
    embed_tokens = nn.Embedding(
        config.vocab_size, config.hidden_size, config.pad_token_id
    ).to(device)

    # --- Golden pass: trivial topology, no sequence parallelism. ---
    init_distributed(tp_size=1, sp_ulysses_size=1, sp_ring_size=1)
    dbg(rank, "init_distributed (FA) done")

    sdpa_decoder = (
        LlamaDecoderLayer(config, attention_backend="fa").to(device).to(torch.bfloat16)
    )
    dbg(rank, "FA decoder created")

    # Smoke-test the SDPA-like adapter view: at idx=0 it should expose the
    # full, unsharded sequence length.
    dummy_model = type("Dummy", (), {})()
    sdpa_adapter = SdpaLikeAdapter(dummy_model)
    sdpa_target_p = torch.zeros((1, seq_len, 8), device=device, dtype=torch.float32)
    sdpa_position_mask = torch.ones((1, seq_len, 1), device=device, dtype=torch.float32)
    sdpa_state = sdpa_adapter.step_view(
        idx=0,
        ttt_length=ttt_length,
        global_input_ids=data_input_ids,
        attention_mask=attention_mask,
        loss_mask=torch.ones((1, seq_len, 1), device=device, dtype=torch.float32),
        position_ids=position_ids,
        hidden_states=data_hidden_states,
        target_p_padded=sdpa_target_p,
        position_mask=sdpa_position_mask,
        seq_length=seq_len,
    )
    assert sdpa_state.input_ids.shape[1] == seq_len
    assert sdpa_state.hidden_states.shape[1] == seq_len

    with torch.no_grad():
        sdpa_output = run_iterative_pass(
            decoder_layer=sdpa_decoder,
            embed_tokens=embed_tokens,
            input_ids=data_input_ids,
            hidden_states=data_hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            ttt_length=ttt_length,
        )
    dbg(rank, "FA forward done")

    # Keep the golden weights so the USP decoder can load the exact same
    # parameters, then tear down the trivial process group.
    state_dict = sdpa_decoder.state_dict()
    del sdpa_decoder
    destroy_distributed()
    dbg(rank, "destroy_distributed (FA) done")

    def subtest_usp(sp_ulysses_degree: int, sp_ring_degree: int):
        """Run USP with specific topology and compare against Golden."""
        try:
            init_distributed(
                tp_size=1,
                sp_ulysses_size=sp_ulysses_degree,
                sp_ring_size=sp_ring_degree,
            )
            dbg(
                rank,
                f"init_distributed (USP U{sp_ulysses_degree} R{sp_ring_degree}) done",
            )

            # File-based rendezvous: rank 0 materializes one offline sample and
            # touches ready.flag; the other ranks poll for the flag instead of
            # using a collective barrier.
            tmp_dir = "./tmp/usp_dataset_shared"
            try:
                if rank == 0:
                    os.makedirs(tmp_dir, exist_ok=True)
                    sample = {
                        "input_ids": data_input_ids[0].cpu(),
                        "loss_mask": torch.ones_like(data_input_ids[0].cpu()),
                        "hidden_state": data_hidden_states[0].cpu().unsqueeze(0),
                        "aux_hidden_state": data_hidden_states[0].cpu().unsqueeze(0),
                    }
                    torch.save(sample, os.path.join(tmp_dir, "data_0.ckpt"))
                    dbg(rank, "wrote sample ckpt")
                    ready_flag = os.path.join(tmp_dir, "ready.flag")
                    with open(ready_flag, "w", encoding="utf-8") as f:
                        f.write("ready\n")
                if rank != 0:
                    ready_flag = os.path.join(tmp_dir, "ready.flag")
                    assert wait_for_file(
                        ready_flag, timeout_s=60
                    ), "timeout waiting for ready flag"
                dbg(rank, "dataset sync done")
                assert os.path.exists(
                    os.path.join(tmp_dir, "data_0.ckpt")
                ), f"Expected sample not found at {tmp_dir}"
                dbg(rank, "sample exists")

                # Build the offline dataset with USP preprocessing enabled and
                # sanity-check that it emits position ids for the local shard.
                ds = build_offline_eagle3_dataset(
                    tmp_dir,
                    max_len=seq_len,
                    ttt_length=ttt_length,
                    use_usp_preprocess=True,
                )
                dbg(rank, "dataset built")
                item = ds[0]
                dbg(rank, "dataset item loaded")
                assert "position_ids" in item

                # Smoke-test the USP adapter view. At idx=0 the exposed length
                # is local_seq_len - ttt_length (presumably the TTT lookahead
                # tokens are trimmed — confirm against UspAdapter.step_view).
                dummy_model = type("Dummy", (), {})()
                adapter = UspAdapter(dummy_model)
                local_seq_len = item["input_ids"].shape[1]
                target_p_padded = torch.zeros(
                    (1, local_seq_len, 8), device=device, dtype=torch.float32
                )
                position_mask = torch.ones(
                    (1, local_seq_len, 1), device=device, dtype=torch.float32
                )
                state = adapter.step_view(
                    idx=0,
                    ttt_length=ttt_length,
                    global_input_ids=item["input_ids"].to(device),
                    attention_mask=item["attention_mask"].to(device),
                    loss_mask=item["loss_mask"].to(device).unsqueeze(-1),
                    position_ids=item["position_ids"].to(device),
                    hidden_states=item["hidden_state"].to(device),
                    target_p_padded=target_p_padded,
                    position_mask=position_mask,
                    seq_length=local_seq_len,
                )
                assert state.input_ids.shape[1] == local_seq_len - ttt_length
                assert state.hidden_states.shape[1] == local_seq_len - ttt_length
                dbg(rank, "adapter step_view ok")
            finally:
                # Teardown rendezvous: non-zero ranks signal completion via
                # done.flag; rank 0 waits for it before deleting the shared
                # directory so files are not removed while still being read.
                if rank == 0:
                    done_flag = os.path.join(tmp_dir, "done.flag")
                    assert wait_for_file(
                        done_flag, timeout_s=60
                    ), "timeout waiting for done flag"
                    try:
                        for root, _, files in os.walk(tmp_dir):
                            for name in files:
                                os.remove(os.path.join(root, name))
                        os.rmdir(tmp_dir)
                    except OSError:
                        pass  # best-effort cleanup; leftover tmp files are harmless
                else:
                    done_flag = os.path.join(tmp_dir, "done.flag")
                    with open(done_flag, "w", encoding="utf-8") as f:
                        f.write("done\n")

            # USP decoder initialized with the exact golden weights.
            usp_decoder = (
                LlamaDecoderLayer(config, attention_backend="usp")
                .to(device)
                .to(torch.bfloat16)
            )
            usp_decoder.load_state_dict(state_dict)
            dbg(rank, "USP decoder loaded")

            # Shard the full-sequence inputs using yunchang's "basic" layout.
            extract_func = EXTRACT_FUNC_DICT["basic"]

            local_input_ids = (
                extract_func(
                    data_input_ids,
                    rank,
                    world_size=world_size,
                    rd=sp_ring_degree,
                    ud=sp_ulysses_degree,
                )
                .detach()
                .clone()
            )

            local_hidden_states = (
                extract_func(
                    data_hidden_states,
                    rank,
                    world_size=world_size,
                    rd=sp_ring_degree,
                    ud=sp_ulysses_degree,
                )
                .detach()
                .clone()
            )
            dbg(rank, "USP local inputs prepared")
            # This rank is assumed to own the contiguous global-sequence chunk
            # [start_idx, start_idx + chunk_size) — i.e. the "basic" extract
            # yields contiguous equal chunks; TODO confirm against yunchang.
            total_degree = sp_ring_degree * sp_ulysses_degree
            chunk_size = sdpa_output.shape[1] // total_degree
            start_idx = (rank % total_degree) * chunk_size
            local_len = local_input_ids.shape[1]
            local_position_ids = (
                torch.arange(start_idx, start_idx + local_len, device=device)
                .unsqueeze(0)
                .long()
            )
            local_attention_mask = torch.tril(
                torch.ones(local_len, local_len, device=device)
            ).view(1, 1, local_len, local_len)

            # NOTE(review): ring parallelism gets a local causal mask plus
            # global-offset position ids, while pure Ulysses reuses the
            # full-sequence mask/ids — presumably because Ulysses restores the
            # full sequence inside attention; verify against the backend.
            if sp_ring_degree > 1:
                usp_attention_mask = local_attention_mask
                usp_position_ids = local_position_ids
            else:
                usp_attention_mask = attention_mask
                usp_position_ids = position_ids
            with torch.no_grad():
                usp_output = run_iterative_pass(
                    decoder_layer=usp_decoder,
                    embed_tokens=embed_tokens,
                    input_ids=local_input_ids,
                    hidden_states=local_hidden_states,
                    attention_mask=usp_attention_mask,
                    position_ids=usp_position_ids,
                    ttt_length=ttt_length,
                )
            dbg(rank, "USP forward done")

            # Compare this rank's USP output to its slice of the golden output.
            end_idx = start_idx + chunk_size

            golden_chunk = sdpa_output[:, start_idx:end_idx, :]

            # Loose tolerances: bfloat16 arithmetic plus differing reduction
            # orders between attention backends.
            assert torch.allclose(usp_output, golden_chunk, rtol=2e-2, atol=2e-2), (
                f"[Rank {rank}] USP (U{sp_ulysses_degree}R{sp_ring_degree}) mismatch!\n"
                f"Max Diff: {(usp_output - golden_chunk).abs().max().item()}"
            )
            dbg(rank, "USP output verified")

        finally:
            destroy_distributed()
            dbg(rank, "destroy_distributed (USP) done")

    # Topology 1: Ulysses degree 2, Ring degree 1.
    subtest_usp(sp_ulysses_degree=2, sp_ring_degree=1)

    # Topology 2: Ulysses degree 1, Ring degree 2.
    subtest_usp(sp_ulysses_degree=1, sp_ring_degree=2)
|
|
|
|
class TestTTTDistributed(unittest.TestCase):
    """Spawns one worker per rank and runs the USP-vs-golden comparison in each."""

    def test_llama_usp_decoder(self):
        num_procs = 2
        free_port = get_available_port()
        mp.spawn(run_test_case, nprocs=num_procs, args=(num_procs, free_port))
|
|
|
|
if __name__ == "__main__":
    # Allow running this file directly (python <file>) to execute the suite.
    unittest.main()
|
|