Hanrui / SpecForge /tests /test_layers /test_decoder.py
Lekr0's picture
Add files using upload-large-folder tool
7a60a87 verified
import os
import time
import unittest
import torch
import torch.multiprocessing as mp
from accelerate.utils import set_seed
from torch import nn
from transformers import PretrainedConfig
from yunchang import EXTRACT_FUNC_DICT
from specforge.core.eagle3_adapters import SdpaLikeAdapter, UspAdapter
from specforge.data.preprocessing import build_offline_eagle3_dataset
# Project-specific imports
from specforge.distributed import destroy_distributed, init_distributed
from specforge.modeling.draft.llama3_eagle import LlamaDecoderLayer
from specforge.utils import padding
from tests.utils import get_available_port
def get_model_config():
"""Create and return the model configuration."""
config_dict = {
"architectures": ["LlamaForCausalLMEagle3"],
"eagle_config": {
"eagle_aux_hidden_state_layer_ids": [1, 29, 57],
"use_aux_hidden_state": True,
},
"bos_token_id": 128000,
"eos_token_id": 128001,
"hidden_act": "silu",
"hidden_size": 7168,
"initializer_range": 0.02,
"intermediate_size": 29568,
"max_position_embeddings": 32768,
"model_type": "llama",
"num_attention_heads": 32,
"num_key_value_heads": 8,
"num_hidden_layers": 1,
"pad_token_id": 0,
"rms_norm_eps": 1e-05,
"tie_word_embeddings": False,
"torch_dtype": "float16",
"transformers_version": "4.28.1",
"use_cache": True,
"rope_scaling": None,
"vocab_size": 129280,
"draft_vocab_size": 32000,
"pretraining_tp": 1,
}
return PretrainedConfig.from_dict(config_dict)
def setup_env(rank, world_size, port):
"""Set up distributed environment variables."""
os.environ["RANK"] = str(rank)
os.environ["WORLD_SIZE"] = str(world_size)
os.environ["MASTER_ADDR"] = "localhost"
os.environ["MASTER_PORT"] = str(port)
torch.cuda.set_device(rank)
def dbg(rank, msg):
print(f"[rank{rank}] {msg}", flush=True)
def wait_for_file(path, timeout_s=60, poll_s=0.1):
start = time.time()
while time.time() - start < timeout_s:
if os.path.exists(path):
return True
time.sleep(poll_s)
return False
def run_iterative_pass(
decoder_layer,
embed_tokens,
input_ids,
hidden_states,
attention_mask,
position_ids,
ttt_length,
):
"""
Core loop: execute the forward pass `ttt_length` times.
Used for both Golden (SDPA) and Distributed (USP) runs to ensure logic consistency.
"""
# Clone to avoid side effects on original tensors
curr_input_ids = input_ids.clone()
curr_hidden_states = hidden_states.clone()
# Init cache
cache_hidden = [[], []]
past_key_values = None
final_output = None
for idx in range(ttt_length):
is_last = idx == ttt_length - 1
# 1. Embed inputs
inputs_embeds = embed_tokens(curr_input_ids).to(curr_hidden_states.dtype)
# 2. Forward pass
output_hidden_states = decoder_layer(
input_emb=inputs_embeds,
hidden_states=curr_hidden_states,
cache_hidden=cache_hidden,
attention_mask=attention_mask,
position_ids=position_ids,
past_key_values=past_key_values,
output_attentions=False,
use_cache=False,
)
# Update states for next iteration
curr_hidden_states = output_hidden_states
final_output = output_hidden_states
# 3. Simulate TTT padding/shift
if not is_last:
curr_input_ids = padding(curr_input_ids, left=False)
return final_output
def run_test_case(rank, world_size, port):
"""Worker function executed in each process."""
setup_env(rank, world_size, port)
device = torch.device(f"cuda:{rank}")
set_seed(42)
dbg(rank, "env setup complete")
# --- Data & Config Preparation ---
config = get_model_config()
seq_len = 1560
batch_size = 1
ttt_length = 3
# Generate dummy data on GPU
data_input_ids = torch.randint(0, 10000, (batch_size, seq_len), device=device)
data_hidden_states = torch.randn(
batch_size, seq_len, config.hidden_size, device=device, dtype=torch.bfloat16
)
attention_mask = torch.tril(torch.ones(seq_len, seq_len, device=device)).view(
1, 1, seq_len, seq_len
)
position_ids = torch.arange(seq_len, device=device).unsqueeze(0)
# Shared embedding layer
embed_tokens = nn.Embedding(
config.vocab_size, config.hidden_size, config.pad_token_id
).to(device)
# --- Phase 1: Golden Run (FA) ---
# Init dist briefly for internal checks, even if running single-device logic
init_distributed(tp_size=1, sp_ulysses_size=1, sp_ring_size=1)
dbg(rank, "init_distributed (FA) done")
sdpa_decoder = (
LlamaDecoderLayer(config, attention_backend="fa").to(device).to(torch.bfloat16)
)
dbg(rank, "FA decoder created")
# Adapter smoke test for FA/SDPA-style path
dummy_model = type("Dummy", (), {})()
sdpa_adapter = SdpaLikeAdapter(dummy_model)
sdpa_target_p = torch.zeros((1, seq_len, 8), device=device, dtype=torch.float32)
sdpa_position_mask = torch.ones((1, seq_len, 1), device=device, dtype=torch.float32)
sdpa_state = sdpa_adapter.step_view(
idx=0,
ttt_length=ttt_length,
global_input_ids=data_input_ids,
attention_mask=attention_mask,
loss_mask=torch.ones((1, seq_len, 1), device=device, dtype=torch.float32),
position_ids=position_ids,
hidden_states=data_hidden_states,
target_p_padded=sdpa_target_p,
position_mask=sdpa_position_mask,
seq_length=seq_len,
)
assert sdpa_state.input_ids.shape[1] == seq_len
assert sdpa_state.hidden_states.shape[1] == seq_len
with torch.no_grad():
sdpa_output = run_iterative_pass(
decoder_layer=sdpa_decoder,
embed_tokens=embed_tokens,
input_ids=data_input_ids,
hidden_states=data_hidden_states,
attention_mask=attention_mask,
position_ids=position_ids,
ttt_length=ttt_length,
)
dbg(rank, "FA forward done")
# Save weights for alignment and cleanup SDPA model
state_dict = sdpa_decoder.state_dict()
del sdpa_decoder
destroy_distributed()
dbg(rank, "destroy_distributed (FA) done")
# --- Phase 2: Distributed Run (USP) ---
def subtest_usp(sp_ulysses_degree, sp_ring_degree):
"""Run USP with specific topology and compare against Golden."""
try:
init_distributed(
tp_size=1,
sp_ulysses_size=sp_ulysses_degree,
sp_ring_size=sp_ring_degree,
)
dbg(
rank,
f"init_distributed (USP U{sp_ulysses_degree} R{sp_ring_degree}) done",
)
# Dataset + adapter smoke test (USP path)
tmp_dir = "./tmp/usp_dataset_shared"
try:
if rank == 0:
os.makedirs(tmp_dir, exist_ok=True)
sample = {
"input_ids": data_input_ids[0].cpu(),
"loss_mask": torch.ones_like(data_input_ids[0].cpu()),
"hidden_state": data_hidden_states[0].cpu().unsqueeze(0),
"aux_hidden_state": data_hidden_states[0].cpu().unsqueeze(0),
}
torch.save(sample, os.path.join(tmp_dir, "data_0.ckpt"))
dbg(rank, "wrote sample ckpt")
ready_flag = os.path.join(tmp_dir, "ready.flag")
with open(ready_flag, "w", encoding="utf-8") as f:
f.write("ready\n")
if rank != 0:
ready_flag = os.path.join(tmp_dir, "ready.flag")
assert wait_for_file(
ready_flag, timeout_s=60
), "timeout waiting for ready flag"
dbg(rank, "dataset sync done")
assert os.path.exists(
os.path.join(tmp_dir, "data_0.ckpt")
), f"Expected sample not found at {tmp_dir}"
dbg(rank, "sample exists")
ds = build_offline_eagle3_dataset(
tmp_dir,
max_len=seq_len,
ttt_length=ttt_length,
use_usp_preprocess=True,
)
dbg(rank, "dataset built")
item = ds[0]
dbg(rank, "dataset item loaded")
assert "position_ids" in item
dummy_model = type("Dummy", (), {})()
adapter = UspAdapter(dummy_model)
local_seq_len = item["input_ids"].shape[1]
target_p_padded = torch.zeros(
(1, local_seq_len, 8), device=device, dtype=torch.float32
)
position_mask = torch.ones(
(1, local_seq_len, 1), device=device, dtype=torch.float32
)
state = adapter.step_view(
idx=0,
ttt_length=ttt_length,
global_input_ids=item["input_ids"].to(device),
attention_mask=item["attention_mask"].to(device),
loss_mask=item["loss_mask"].to(device).unsqueeze(-1),
position_ids=item["position_ids"].to(device),
hidden_states=item["hidden_state"].to(device),
target_p_padded=target_p_padded,
position_mask=position_mask,
seq_length=local_seq_len,
)
assert state.input_ids.shape[1] == local_seq_len - ttt_length
assert state.hidden_states.shape[1] == local_seq_len - ttt_length
dbg(rank, "adapter step_view ok")
finally:
if rank == 0:
done_flag = os.path.join(tmp_dir, "done.flag")
assert wait_for_file(
done_flag, timeout_s=60
), "timeout waiting for done flag"
try:
for root, _, files in os.walk(tmp_dir):
for name in files:
os.remove(os.path.join(root, name))
os.rmdir(tmp_dir)
except OSError:
pass
else:
done_flag = os.path.join(tmp_dir, "done.flag")
with open(done_flag, "w", encoding="utf-8") as f:
f.write("done\n")
# Init USP model and load golden weights
usp_decoder = (
LlamaDecoderLayer(config, attention_backend="usp")
.to(device)
.to(torch.bfloat16)
)
usp_decoder.load_state_dict(state_dict)
dbg(rank, "USP decoder loaded")
# Shard data (Split Input)
extract_func = EXTRACT_FUNC_DICT["basic"]
local_input_ids = (
extract_func(
data_input_ids,
rank,
world_size=world_size,
rd=sp_ring_degree,
ud=sp_ulysses_degree,
)
.detach()
.clone()
)
local_hidden_states = (
extract_func(
data_hidden_states,
rank,
world_size=world_size,
rd=sp_ring_degree,
ud=sp_ulysses_degree,
)
.detach()
.clone()
)
dbg(rank, "USP local inputs prepared")
total_degree = sp_ring_degree * sp_ulysses_degree
chunk_size = sdpa_output.shape[1] // total_degree
start_idx = (rank % total_degree) * chunk_size
local_len = local_input_ids.shape[1]
local_position_ids = (
torch.arange(start_idx, start_idx + local_len, device=device)
.unsqueeze(0)
.long()
)
local_attention_mask = torch.tril(
torch.ones(local_len, local_len, device=device)
).view(1, 1, local_len, local_len)
# Run USP forward
if sp_ring_degree > 1:
usp_attention_mask = local_attention_mask
usp_position_ids = local_position_ids
else:
usp_attention_mask = attention_mask
usp_position_ids = position_ids
with torch.no_grad():
usp_output = run_iterative_pass(
decoder_layer=usp_decoder,
embed_tokens=embed_tokens,
input_ids=local_input_ids,
hidden_states=local_hidden_states,
attention_mask=usp_attention_mask,
position_ids=usp_position_ids,
ttt_length=ttt_length,
)
dbg(rank, "USP forward done")
# Verify results
# Slice the golden output to match the current rank's chunk
end_idx = start_idx + chunk_size
golden_chunk = sdpa_output[:, start_idx:end_idx, :]
assert torch.allclose(usp_output, golden_chunk, rtol=2e-2, atol=2e-2), (
f"[Rank {rank}] USP (U{sp_ulysses_degree}R{sp_ring_degree}) mismatch!\n"
f"Max Diff: {(usp_output - golden_chunk).abs().max().item()}"
)
dbg(rank, "USP output verified")
finally:
destroy_distributed()
dbg(rank, "destroy_distributed (USP) done")
# Case 1: Hybrid (Ulysses=2, Ring=1)
subtest_usp(sp_ulysses_degree=2, sp_ring_degree=1)
# Case 2: Hybrid (Ulysses=1, Ring=2)
subtest_usp(sp_ulysses_degree=1, sp_ring_degree=2)
class TestTTTDistributed(unittest.TestCase):
def test_llama_usp_decoder(self):
world_size = 2
port = get_available_port()
mp.spawn(run_test_case, nprocs=world_size, args=(world_size, port))
if __name__ == "__main__":
unittest.main()