# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch

from megablocks import ops
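
# Each test case is (sl, hs, ne): sequence length, hidden size, number of experts.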
TOPOLOGY_TESTS = (
    (1024, 1536, 2),
    (1024, 1536, 4),
    (1024, 1536, 8),
    (1024, 1536, 16),
    (1024, 1536, 32),
    (1024, 1536, 64),
    (1024, 1536, 128),
    (1024, 1536, 256),
    (1024, 1536, 512),
    (16384, 768, 2),
    (16384, 768, 4),
    (16384, 768, 8),
    (16384, 768, 16),
    (16384, 768, 32),
    (16384, 768, 64),
    (16384, 768, 128),
    (16384, 768, 256),
    (16384, 768, 512),
    (16384, 768, 1024),
    (8, 14336, 8),
)


@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    # Create the data and indices.
    blocking = 128
    assert hs % blocking == 0

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
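    # padded_bins[i] is the end offset (in tokens) of expert i's block-padded
    # range, so the last entry is the total padded token count.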

    # Dimensions for the output indices.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking
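    # One block row per `blocking` padded tokens, one block column per
    # `blocking`-wide slice of the hidden dimension.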

    def topology(
        padded_bins: torch.Tensor,
        blocking: int,
        rows: int,
        columns: int,
    ):
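        # Reference implementation: each padded block row belonging to expert
        # i gets `columns` entries, where entry j is the flattened index of
        # weight block (i, j), i.e. j + i * columns.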
        padded_bins = padded_bins.cpu().numpy()
        out = np.zeros([rows * columns])

        start = 0
        for i in range(padded_bins.shape[0]):
            end = padded_bins[i] // blocking
            while start < end:
                for j in range(columns):
                    out[start * columns + j] = j + i * columns
                start += 1
        return torch.from_numpy(out).cuda().short()

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
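    # ops.topology should reproduce the reference block indices exactly.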
    assert torch.all(torch.eq(out, expected_out))