| """
|
| Read the top of perf_takehome.py for more introduction.
|
|
|
| This file is separate mostly for ease of copying it to freeze the machine and
|
| reference kernel for testing.
|
| """
|
|
|
| from copy import copy
|
| from dataclasses import dataclass
|
| from enum import Enum
|
| from typing import Any, Literal
|
| import random
|
|
|
# Names of the execution engines an instruction bundle may address. "valu" and
# "debug" are included because the simulator accepts both as bundle keys (they
# appear in SLOT_LIMITS and are handled by Machine.step).
Engine = Literal["alu", "valu", "load", "store", "flow", "debug"]

# One VLIW instruction bundle: for each engine, the list of slots (tuples of
# opcode + operands) to execute in the same cycle.
Instruction = dict[Engine, list[tuple]]
|
|
|
|
|
class CoreState(Enum):
    """Execution state of a single simulated core."""

    # The core fetches and executes instructions.
    RUNNING = 1
    # The core executed a "pause" flow op; Machine.run() resumes it.
    PAUSED = 2
    # The core halted or ran off the end of the program.
    STOPPED = 3
|
|
|
|
|
@dataclass
class Core:
    """Mutable per-core state tracked by the simulator."""

    # Core index; also what the "coreid" flow op writes to scratch.
    id: int
    # Scratch space words, addressed by the operands of instruction slots.
    scratch: list[int]
    # Values appended by the "trace_write" flow op.
    trace_buf: list[int]
    # Index into the program of the next instruction bundle to execute.
    pc: int = 0
    # Cores start out running.
    state: CoreState = CoreState.RUNNING
|
|
|
|
|
@dataclass
class DebugInfo:
    """
    We give you some debug info but it's up to you to use it in Machine if you
    want to. You're also welcome to add more.
    """

    # Maps the starting scratch address of a named region to a
    # (name, length-in-words) pair. Used for printing and tracing.
    scratch_map: dict[int, tuple[str, int]]
|
|
|
|
|
def cdiv(a, b):
    """Return a divided by b, rounded up (for the non-negative values the ALU uses)."""
    # Bias the numerator so that floor division rounds towards +infinity.
    biased = a + b - 1
    return biased // b
|
|
|
|
|
# Maximum number of slots each engine may execute in one instruction bundle.
# "debug" slots are checked by the simulator but get no trace thread.
SLOT_LIMITS = dict(
    alu=12,
    valu=6,
    load=2,
    store=2,
    flow=1,
    debug=64,
)

# Number of elements processed by a single vector slot.
VLEN = 8

# NOTE(review): not referenced in this file's visible code; presumably the
# core count callers pass to Machine - confirm against callers.
N_CORES = 1
# Default number of scratch words allocated per core.
SCRATCH_SIZE = 1536
# Offset added to a scratch address to form its thread id in the trace file.
BASE_ADDR_TID = 100000
|
|
|
|
|
| class Machine:
|
| """
|
| Simulator for a custom VLIW SIMD architecture.
|
|
|
| VLIW (Very Large Instruction Word): Cores are composed of different
|
| "engines" each of which can execute multiple "slots" per cycle in parallel.
|
| How many slots each engine can execute per cycle is limited by SLOT_LIMITS.
|
| Effects of instructions don't take effect until the end of cycle. Each
|
| cycle, all engines execute all of their filled slots for that instruction.
|
| Effects like writes to memory take place after all the inputs are read.
|
|
|
| SIMD: There are instructions for acting on vectors of VLEN elements in a
|
| single slot. You can use vload and vstore to load multiple contiguous
|
| elements but not non-contiguous elements. Use vbroadcast to broadcast a
|
| scalar to a vector and then operate on vectors with valu instructions.
|
|
|
| The memory and scratch space are composed of 32-bit words. The solution is
|
| plucked out of the memory at the end of the program. You can think of the
|
| scratch space as serving the purpose of registers, constant memory, and a
|
| manually-managed cache.
|
|
|
| Here's an example of what an instruction might look like:
|
|
|
| {"valu": [("*", 4, 0, 0), ("+", 8, 4, 0)], "load": [("load", 16, 17)]}
|
|
|
| In general every number in an instruction is a scratch address except for
|
| const and jump, and except for store and some flow instructions the first
|
| operand is the destination.
|
|
|
| This comment is not meant to be full ISA documentation though, for the rest
|
| you should look through the simulator code.
|
| """
|
|
|
| def __init__(
|
| self,
|
| mem_dump: list[int],
|
| program: list[Instruction],
|
| debug_info: DebugInfo,
|
| n_cores: int = 1,
|
| scratch_size: int = SCRATCH_SIZE,
|
| trace: bool = False,
|
| value_trace: dict[Any, int] = {},
|
| ):
|
| self.cores = [
|
| Core(id=i, scratch=[0] * scratch_size, trace_buf=[]) for i in range(n_cores)
|
| ]
|
| self.mem = copy(mem_dump)
|
| self.program = program
|
| self.debug_info = debug_info
|
| self.value_trace = value_trace
|
| self.prints = False
|
| self.cycle = 0
|
| self.enable_pause = True
|
| self.enable_debug = True
|
| if trace:
|
| self.setup_trace()
|
| else:
|
| self.trace = None
|
|
|
| def rewrite_instr(self, instr):
|
| """
|
| Rewrite an instruction to use scratch addresses instead of names
|
| """
|
| res = {}
|
| for name, slots in instr.items():
|
| res[name] = []
|
| for slot in slots:
|
| res[name].append(self.rewrite_slot(slot))
|
| return res
|
|
|
| def print_step(self, instr, core):
|
|
|
|
|
| print(self.scratch_map(core))
|
| print(core.pc, instr, self.rewrite_instr(instr))
|
|
|
| def scratch_map(self, core):
|
| res = {}
|
| for addr, (name, length) in self.debug_info.scratch_map.items():
|
| res[name] = core.scratch[addr : addr + length]
|
| return res
|
|
|
| def rewrite_slot(self, slot):
|
| return tuple(
|
| self.debug_info.scratch_map.get(s, (None, None))[0] or s for s in slot
|
| )
|
|
|
| def setup_trace(self):
|
| """
|
| The simulator generates traces in Chrome's Trace Event Format for
|
| visualization in Perfetto (or chrome://tracing if you prefer it). See
|
| the bottom of the file for info about how to use this.
|
|
|
| See the format docs in case you want to add more info to the trace:
|
| https://docs.google.com/document/d/1CvAClvFfyA5R-PhYUmn5OOQtYMH4h6I0nSsKchNAySU/preview
|
| """
|
| self.trace = open("trace.json", "w")
|
| self.trace.write("[")
|
| tid_counter = 0
|
| self.tids = {}
|
| for ci, core in enumerate(self.cores):
|
| self.trace.write(
|
| f'{{"name": "process_name", "ph": "M", "pid": {ci}, "tid": 0, "args": {{"name":"Core {ci}"}}}},\n'
|
| )
|
| for name, limit in SLOT_LIMITS.items():
|
| if name == "debug":
|
| continue
|
| for i in range(limit):
|
| tid_counter += 1
|
| self.trace.write(
|
| f'{{"name": "thread_name", "ph": "M", "pid": {ci}, "tid": {tid_counter}, "args": {{"name":"{name}-{i}"}}}},\n'
|
| )
|
| self.tids[(ci, name, i)] = tid_counter
|
|
|
|
|
| for ci, core in enumerate(self.cores):
|
| for name, limit in SLOT_LIMITS.items():
|
| if name == "debug":
|
| continue
|
| for i in range(limit):
|
| tid = self.tids[(ci, name, i)]
|
| self.trace.write(
|
| f'{{"name": "init", "cat": "op", "ph": "X", "pid": {ci}, "tid": {tid}, "ts": 0, "dur": 0}},\n'
|
| )
|
| for ci, core in enumerate(self.cores):
|
| self.trace.write(
|
| f'{{"name": "process_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": 0, "args": {{"name":"Core {ci} Scratch"}}}},\n'
|
| )
|
| for addr, (name, length) in self.debug_info.scratch_map.items():
|
| self.trace.write(
|
| f'{{"name": "thread_name", "ph": "M", "pid": {len(self.cores) + ci}, "tid": {BASE_ADDR_TID + addr}, "args": {{"name":"{name}-{length}"}}}},\n'
|
| )
|
|
|
| def run(self):
|
| for core in self.cores:
|
| if core.state == CoreState.PAUSED:
|
| core.state = CoreState.RUNNING
|
| while any(c.state == CoreState.RUNNING for c in self.cores):
|
| has_non_debug = False
|
| for core in self.cores:
|
| if core.state != CoreState.RUNNING:
|
| continue
|
| if core.pc >= len(self.program):
|
| core.state = CoreState.STOPPED
|
| continue
|
| instr = self.program[core.pc]
|
| if self.prints:
|
| self.print_step(instr, core)
|
| core.pc += 1
|
| self.step(instr, core)
|
| if any(name != "debug" for name in instr.keys()):
|
| has_non_debug = True
|
| if has_non_debug:
|
| self.cycle += 1
|
|
|
| def alu(self, core, op, dest, a1, a2):
|
| a1 = core.scratch[a1]
|
| a2 = core.scratch[a2]
|
| match op:
|
| case "+":
|
| res = a1 + a2
|
| case "-":
|
| res = a1 - a2
|
| case "*":
|
| res = a1 * a2
|
| case "//":
|
| res = a1 // a2
|
| case "cdiv":
|
| res = cdiv(a1, a2)
|
| case "^":
|
| res = a1 ^ a2
|
| case "&":
|
| res = a1 & a2
|
| case "|":
|
| res = a1 | a2
|
| case "<<":
|
| res = a1 << a2
|
| case ">>":
|
| res = a1 >> a2
|
| case "%":
|
| res = a1 % a2
|
| case "<":
|
| res = int(a1 < a2)
|
| case "==":
|
| res = int(a1 == a2)
|
| case _:
|
| raise NotImplementedError(f"Unknown alu op {op}")
|
| res = res % (2**32)
|
| self.scratch_write[dest] = res
|
|
|
| def valu(self, core, *slot):
|
| match slot:
|
| case ("vbroadcast", dest, src):
|
| for i in range(VLEN):
|
| self.scratch_write[dest + i] = core.scratch[src]
|
| case ("multiply_add", dest, a, b, c):
|
| for i in range(VLEN):
|
| mul = (core.scratch[a + i] * core.scratch[b + i]) % (2**32)
|
| self.scratch_write[dest + i] = (mul + core.scratch[c + i]) % (2**32)
|
| case (op, dest, a1, a2):
|
| for i in range(VLEN):
|
| self.alu(core, op, dest + i, a1 + i, a2 + i)
|
| case _:
|
| raise NotImplementedError(f"Unknown valu op {slot}")
|
|
|
| def load(self, core, *slot):
|
| match slot:
|
| case ("load", dest, addr):
|
|
|
| self.scratch_write[dest] = self.mem[core.scratch[addr]]
|
| case ("load_offset", dest, addr, offset):
|
|
|
| self.scratch_write[dest + offset] = self.mem[
|
| core.scratch[addr + offset]
|
| ]
|
| case ("vload", dest, addr):
|
| addr = core.scratch[addr]
|
| for vi in range(VLEN):
|
| self.scratch_write[dest + vi] = self.mem[addr + vi]
|
| case ("const", dest, val):
|
| self.scratch_write[dest] = (val) % (2**32)
|
| case _:
|
| raise NotImplementedError(f"Unknown load op {slot}")
|
|
|
| def store(self, core, *slot):
|
| match slot:
|
| case ("store", addr, src):
|
| addr = core.scratch[addr]
|
| self.mem_write[addr] = core.scratch[src]
|
| case ("vstore", addr, src):
|
| addr = core.scratch[addr]
|
| for vi in range(VLEN):
|
| self.mem_write[addr + vi] = core.scratch[src + vi]
|
| case _:
|
| raise NotImplementedError(f"Unknown store op {slot}")
|
|
|
| def flow(self, core, *slot):
|
| match slot:
|
| case ("select", dest, cond, a, b):
|
| self.scratch_write[dest] = (
|
| core.scratch[a] if core.scratch[cond] != 0 else core.scratch[b]
|
| )
|
| case ("add_imm", dest, a, imm):
|
| self.scratch_write[dest] = (core.scratch[a] + imm) % (2**32)
|
| case ("vselect", dest, cond, a, b):
|
| for vi in range(VLEN):
|
| self.scratch_write[dest + vi] = (
|
| core.scratch[a + vi]
|
| if core.scratch[cond + vi] != 0
|
| else core.scratch[b + vi]
|
| )
|
| case ("halt",):
|
| core.state = CoreState.STOPPED
|
| case ("pause",):
|
| if self.enable_pause:
|
| core.state = CoreState.PAUSED
|
| case ("trace_write", val):
|
| core.trace_buf.append(core.scratch[val])
|
| case ("cond_jump", cond, addr):
|
| if core.scratch[cond] != 0:
|
| core.pc = addr
|
| case ("cond_jump_rel", cond, offset):
|
| if core.scratch[cond] != 0:
|
| core.pc += offset
|
| case ("jump", addr):
|
| core.pc = addr
|
| case ("jump_indirect", addr):
|
| core.pc = core.scratch[addr]
|
| case ("coreid", dest):
|
| self.scratch_write[dest] = core.id
|
| case _:
|
| raise NotImplementedError(f"Unknown flow op {slot}")
|
|
|
| def trace_post_step(self, instr, core):
|
|
|
| for addr, (name, length) in self.debug_info.scratch_map.items():
|
| if any((addr + vi) in self.scratch_write for vi in range(length)):
|
| val = str(core.scratch[addr : addr + length])
|
| val = val.replace("[", "").replace("]", "")
|
| self.trace.write(
|
| f'{{"name": "{val}", "cat": "op", "ph": "X", "pid": {len(self.cores) + core.id}, "tid": {BASE_ADDR_TID + addr}, "ts": {self.cycle}, "dur": 1 }},\n'
|
| )
|
|
|
| def trace_slot(self, core, slot, name, i):
|
| self.trace.write(
|
| f'{{"name": "{slot[0]}", "cat": "op", "ph": "X", "pid": {core.id}, "tid": {self.tids[(core.id, name, i)]}, "ts": {self.cycle}, "dur": 1, "args":{{"slot": "{str(slot)}", "named": "{str(self.rewrite_slot(slot))}" }} }},\n'
|
| )
|
|
|
| def step(self, instr: Instruction, core):
|
| """
|
| Execute all the slots in each engine for a single instruction bundle
|
| """
|
| ENGINE_FNS = {
|
| "alu": self.alu,
|
| "valu": self.valu,
|
| "load": self.load,
|
| "store": self.store,
|
| "flow": self.flow,
|
| }
|
| self.scratch_write = {}
|
| self.mem_write = {}
|
| for name, slots in instr.items():
|
| if name == "debug":
|
| if not self.enable_debug:
|
| continue
|
| for slot in slots:
|
| if slot[0] == "compare":
|
| loc, key = slot[1], slot[2]
|
| ref = self.value_trace[key]
|
| res = core.scratch[loc]
|
| assert res == ref, f"{res} != {ref} for {key} at pc={core.pc}"
|
| elif slot[0] == "vcompare":
|
| loc, keys = slot[1], slot[2]
|
| ref = [self.value_trace[key] for key in keys]
|
| res = core.scratch[loc : loc + VLEN]
|
| assert res == ref, (
|
| f"{res} != {ref} for {keys} at pc={core.pc} loc={loc}"
|
| )
|
| continue
|
| assert len(slots) <= SLOT_LIMITS[name]
|
| for i, slot in enumerate(slots):
|
| if self.trace is not None:
|
| self.trace_slot(core, slot, name, i)
|
| ENGINE_FNS[name](core, *slot)
|
| for addr, val in self.scratch_write.items():
|
| core.scratch[addr] = val
|
| for addr, val in self.mem_write.items():
|
| self.mem[addr] = val
|
|
|
| if self.trace:
|
| self.trace_post_step(instr, core)
|
|
|
| del self.scratch_write
|
| del self.mem_write
|
|
|
| def __del__(self):
|
| if self.trace is not None:
|
| self.trace.write("]")
|
| self.trace.close()
|
|
|
|
|
@dataclass
class Tree:
    """
    An implicit perfect balanced binary tree with values on the nodes.
    """

    # Number of levels below the root; a tree of height h has 2**(h+1) - 1 nodes.
    height: int
    # Node values as a flat list, root first.
    values: list[int]

    @staticmethod
    def generate(height: int):
        """Build a tree of the given height with random values in [0, 2**30)."""
        max_value = 2**30 - 1
        node_count = 2 ** (height + 1) - 1
        node_values = []
        for _ in range(node_count):
            node_values.append(random.randint(0, max_value))
        return Tree(height=height, values=node_values)
|
|
|
|
|
@dataclass
class Input:
    """
    A batch of inputs, indices to nodes (starting as 0) and initial input
    values. We then iterate these for a specified number of rounds.
    """

    # Current tree node index of each batch element.
    indices: list[int]
    # Current value of each batch element.
    values: list[int]
    # Number of traversal rounds to run.
    rounds: int

    @staticmethod
    def generate(forest: Tree, batch_size: int, rounds: int):
        """
        Create a batch of batch_size elements that all start at node 0 with
        random values in [0, 2**30). The forest argument is not read.
        """
        max_value = 2**30 - 1
        start_indices = [0] * batch_size
        start_values = []
        for _ in range(batch_size):
            start_values.append(random.randint(0, max_value))
        return Input(indices=start_indices, values=start_values, rounds=rounds)
|
|
|
|
|
# Stages of the hash used by the kernel. Each entry is
# (op1, val1, op2, op3, val3) and updates the running value `a` as
#     a = (a op1 val1) op2 (a op3 val3)
# with every intermediate result reduced modulo 2**32.
HASH_STAGES = [
    ("+", 0x7ED55D16, "+", "<<", 12),
    ("^", 0xC761C23C, "^", ">>", 19),
    ("+", 0x165667B1, "+", "<<", 5),
    ("+", 0xD3A2646C, "^", "<<", 9),
    ("+", 0xFD7046C5, "+", "<<", 3),
    ("^", 0xB55A4F09, "^", ">>", 16),
]
|
|
|
|
|
def myhash(a: int) -> int:
    """A simple 32-bit hash function"""
    modulus = 2**32

    def apply(op, lhs, rhs):
        # Evaluate one of the four operators that HASH_STAGES uses.
        if op == "+":
            return lhs + rhs
        if op == "^":
            return lhs ^ rhs
        if op == "<<":
            return lhs << rhs
        if op == ">>":
            return lhs >> rhs
        raise KeyError(op)

    for first_op, constant, combine_op, shift_op, shift in HASH_STAGES:
        # Both halves are derived from the value at the start of the stage.
        mixed = apply(first_op, a, constant) % modulus
        shifted = apply(shift_op, a, shift) % modulus
        a = apply(combine_op, mixed, shifted) % modulus

    return a
|
|
|
|
|
def reference_kernel(t: Tree, inp: Input):
    """
    Reference implementation of the kernel.

    A parallel tree traversal where at each node we set
    cur_inp_val = myhash(cur_inp_val ^ node_val)
    and then choose the left branch if cur_inp_val is even.
    If we reach the bottom of the tree we wrap around to the top.

    Updates inp.indices and inp.values in place; returns nothing.
    """
    node_count = len(t.values)
    for _ in range(inp.rounds):
        for slot, node in enumerate(inp.indices):
            hashed = myhash(inp.values[slot] ^ t.values[node])
            # Even hash -> left child (2i + 1), odd hash -> right child (2i + 2).
            if hashed % 2 == 0:
                child = 2 * node + 1
            else:
                child = 2 * node + 2
            # Falling off the bottom of the tree wraps back to the root.
            if child >= node_count:
                child = 0
            inp.values[slot] = hashed
            inp.indices[slot] = child
|
|
|
|
|
def build_mem_image(t: Tree, inp: Input) -> list[int]:
    """
    Build a flat memory image of the problem.

    Layout (word addresses):
        0  rounds
        1  number of tree nodes
        2  batch size
        3  tree height
        4  pointer to the forest values
        5  pointer to the input indices
        6  pointer to the input values
        then the forest values, the input indices, the input values and
        finally a zero-filled block of extra room.

    The header is 7 words, so address 7 is the first forest value. No pointer
    to the extra room is stored in the image; it starts directly after the
    input values.
    """
    header = 7
    n_nodes = len(t.values)
    batch_size = len(inp.indices)
    extra_room = n_nodes + batch_size * 2 + VLEN * 2 + 32

    forest_values_p = header
    inp_indices_p = forest_values_p + n_nodes
    # Advance by the length of the region just laid out (the indices).
    inp_values_p = inp_indices_p + batch_size
    extra_room_p = inp_values_p + len(inp.values)

    mem = [0] * (extra_room_p + extra_room)

    mem[0] = inp.rounds
    mem[1] = n_nodes
    mem[2] = batch_size
    mem[3] = t.height
    mem[4] = forest_values_p
    mem[5] = inp_indices_p
    mem[6] = inp_values_p

    # Bounded slice assignments keep the list length intact; an open-ended
    # `mem[inp_values_p:] = ...` would cut off the extra room allocated above.
    mem[forest_values_p:inp_indices_p] = t.values
    mem[inp_indices_p:inp_values_p] = inp.indices
    mem[inp_values_p:extra_room_p] = inp.values
    return mem
|
|
|
|
|
def myhash_traced(a: int, trace: dict[Any, int], round: int, batch_i: int) -> int:
    """
    A simple 32-bit hash function that also records the value after every
    stage in trace under the key (round, batch_i, "hash_stage", stage_index).
    """
    modulus = 2**32

    def apply(op, lhs, rhs):
        # Evaluate one of the four operators that HASH_STAGES uses.
        if op == "+":
            return lhs + rhs
        if op == "^":
            return lhs ^ rhs
        if op == "<<":
            return lhs << rhs
        if op == ">>":
            return lhs >> rhs
        raise KeyError(op)

    stage = 0
    for first_op, constant, combine_op, shift_op, shift in HASH_STAGES:
        # Both halves are derived from the value at the start of the stage.
        mixed = apply(first_op, a, constant) % modulus
        shifted = apply(shift_op, a, shift) % modulus
        a = apply(combine_op, mixed, shifted) % modulus
        trace[(round, batch_i, "hash_stage", stage)] = a
        stage += 1

    return a
|
|
|
|
|
| def reference_kernel2(mem: list[int], trace: dict[Any, int] = {}):
|
| """
|
| Reference implementation of the kernel on a flat memory.
|
| """
|
|
|
| rounds = mem[0]
|
| n_nodes = mem[1]
|
| batch_size = mem[2]
|
| forest_height = mem[3]
|
|
|
| forest_values_p = mem[4]
|
| inp_indices_p = mem[5]
|
| inp_values_p = mem[6]
|
| yield mem
|
| for h in range(rounds):
|
| for i in range(batch_size):
|
| idx = mem[inp_indices_p + i]
|
| trace[(h, i, "idx")] = idx
|
| val = mem[inp_values_p + i]
|
| trace[(h, i, "val")] = val
|
| node_val = mem[forest_values_p + idx]
|
| trace[(h, i, "node_val")] = node_val
|
| val = myhash_traced(val ^ node_val, trace, h, i)
|
| trace[(h, i, "hashed_val")] = val
|
| idx = 2 * idx + (1 if val % 2 == 0 else 2)
|
| trace[(h, i, "next_idx")] = idx
|
| idx = 0 if idx >= n_nodes else idx
|
| trace[(h, i, "wrapped_idx")] = idx
|
| mem[inp_values_p + i] = val
|
| mem[inp_indices_p + i] = idx
|
|
|
|
|
|
|
| yield mem
|
|
|