Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions justfile
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,9 @@ test-ffn-smollm2-135m:
test-ffn-clm60m:
python3 transactional_emulator/testbench/models/multi_model_ffn_test.py clm60m

test-clm60m-rtl-config rtl_root="../PLENA_RTL":
python3 transactional_emulator/testbench/models/clm60m_rtl_config_test.py --rtl-root {{rtl_root}}

test-decoder-multi-model:
python3 transactional_emulator/testbench/models/multi_model_decoder_test.py

Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ dependencies = [
"matplotlib",
"ruff>=0.12",
"pydantic>=2.0",
"tomlkit>=0.15.0",
]

# Use PyTorch CUDA index only for torch packages
Expand Down
102 changes: 76 additions & 26 deletions tools/memory_mapping/memory_map.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,27 @@ def hex_to_bytes(hex_str):
return bytes.fromhex(hex_str)


def pack_values_to_bytes(values, data_width):
    """Pack fixed-width integer values into bytes, least-significant element first.

    Each value is truncated to its low ``data_width`` bits (so negative
    integers are stored in two's-complement form) and appended to a running
    bit stream. The first value occupies the lowest bits of the first output
    byte. If the total bit count is not a multiple of 8, the final byte is
    zero-padded in its high bits.

    Args:
        values: Iterable of values convertible with ``int()``.
        data_width: Width in bits of each packed value; must be positive.

    Returns:
        The packed bit stream as ``bytes``.

    Raises:
        ValueError: If ``data_width`` is not a positive number of bits.
    """
    # A zero width would silently drop every value (mask == 0 and the bit
    # counter never advances); a negative width would fail later with an
    # opaque "negative shift count" error. Reject both up front.
    if data_width <= 0:
        raise ValueError(f"data_width must be a positive number of bits, got {data_width}")

    data = 0  # bit accumulator; bits not yet emitted live in the low end
    bits_left = 0  # number of valid bits currently held in ``data``
    out = bytearray()
    mask = (1 << data_width) - 1

    for value in values:
        # Place the new element above the bits still pending in the accumulator.
        data |= (int(value) & mask) << bits_left
        bits_left += data_width
        # Drain every complete byte so the accumulator stays small.
        while bits_left >= 8:
            out.append(data & 0xFF)
            data >>= 8
            bits_left -= 8

    # Emit the trailing partial byte; its unused high bits are already zero.
    if bits_left > 0:
        out.append(data & 0xFF)

    return bytes(out)


def map_data_to_fake_hbm_for_rtl_sim(
blocks, element_width, block_width, bias, bias_width, directory, combined_blk_dim, append=True, hbm_row_width=64
):
Expand Down Expand Up @@ -120,7 +141,15 @@ def map_data_to_fake_hbm_for_rtl_sim(


def map_mx_data_to_hbm_for_behave_sim(
blocks, element_width, block_width, bias, bias_width, directory, append=True, hbm_row_width=64
blocks,
element_width,
block_width,
bias,
bias_width,
directory,
append=True,
hbm_row_width=64,
logical_row_elements=None,
):
"""
Maps the quantized blocks and bias to binary memory file for fake HBM memory.
Expand All @@ -140,8 +169,15 @@ def map_mx_data_to_hbm_for_behave_sim(
for row_idx, row in enumerate(blocks):
_ = " ".join(f"0x{val:02X}" for val in row)

hbm_row_elem_num = hbm_row_width // (element_width)
hbm_row_bias_num = hbm_row_width // (bias_width)
hbm_row_bytes = (hbm_row_width + 7) // 8
if logical_row_elements is None:
logical_row_elements = hbm_row_width // element_width
blocks_per_logical_row = (logical_row_elements + block_width - 1) // block_width

scale_row_bits = (hbm_row_width * bias_width + (element_width * block_width) - 1) // (
element_width * block_width
)
scale_row_bytes = (scale_row_bits + 7) // 8

with open(output_file, mode) as f:
# Track total bytes written
Expand All @@ -152,23 +188,30 @@ def map_mx_data_to_hbm_for_behave_sim(
# Process blocks
row_buffer = bytearray()

for i, block in enumerate(blocks):
hex_str = map_block_to_value(block, element_width)
block_bytes = hex_to_bytes(hex_str)
row_buffer.extend(block_bytes)

# Write when row is full
if len(row_buffer) >= hbm_row_elem_num:
f.write(row_buffer[:hbm_row_elem_num])
total_bytes_written += hbm_row_elem_num
blocks_bytes_written += hbm_row_elem_num
row_buffer = bytearray() # Reset buffer after writing
blocks_in_row = 0
for _i, block in enumerate(blocks):
row_buffer.extend(pack_values_to_bytes(block, element_width))
blocks_in_row += 1

if blocks_in_row == blocks_per_logical_row:
if len(row_buffer) > hbm_row_bytes:
raise ValueError(
f"Packed element row ({len(row_buffer)} bytes) exceeds HBM row width "
f"({hbm_row_bytes} bytes)"
)
row_padding = hbm_row_bytes - len(row_buffer)
row_buffer.extend(b"\x00" * row_padding)
f.write(row_buffer)
total_bytes_written += len(row_buffer)
blocks_bytes_written += len(row_buffer)
row_buffer = bytearray()
blocks_in_row = 0

# Flush any remaining block data
blocks_row_padding = 0
if len(row_buffer) > 0:
# Pad to row width
blocks_row_padding = hbm_row_elem_num - len(row_buffer)
blocks_row_padding = hbm_row_bytes - len(row_buffer)
row_buffer.extend(b"\x00" * blocks_row_padding)
f.write(row_buffer)
total_bytes_written += len(row_buffer)
Expand All @@ -177,17 +220,24 @@ def map_mx_data_to_hbm_for_behave_sim(
# Process bias
row_buffer = bytearray()

for i, b in enumerate(bias):
hex_str = map_scale_to_value(b, bias_width)
bias_bytes = hex_to_bytes(hex_str)
row_buffer.extend(bias_bytes)

# Write when row is full
if len(row_buffer) >= hbm_row_bias_num:
f.write(row_buffer[:hbm_row_bias_num])
total_bytes_written += hbm_row_bias_num
bias_bytes_written += hbm_row_bias_num
scales_in_row = 0
for _i, b in enumerate(bias):
row_buffer.extend(pack_values_to_bytes([b], bias_width))
scales_in_row += 1

if scales_in_row == blocks_per_logical_row:
if len(row_buffer) > scale_row_bytes:
raise ValueError(
f"Packed scale row ({len(row_buffer)} bytes) exceeds scale row width "
f"({scale_row_bytes} bytes)"
)
row_padding = scale_row_bytes - len(row_buffer)
row_buffer.extend(b"\x00" * row_padding)
f.write(row_buffer)
total_bytes_written += len(row_buffer)
bias_bytes_written += len(row_buffer)
row_buffer = bytearray()
scales_in_row = 0

# # For Little Endian Purpose
# if len(row_buffer) > 0:
Expand All @@ -201,7 +251,7 @@ def map_mx_data_to_hbm_for_behave_sim(
bias_row_padding = 0
if len(row_buffer) > 0:
# Calculate padding needed
bias_row_padding = hbm_row_bias_num - len(row_buffer)
bias_row_padding = scale_row_bytes - len(row_buffer)
row_buffer.extend(b"\x00" * bias_row_padding)
f.write(row_buffer)
total_bytes_written += len(row_buffer)
Expand Down
Loading
Loading