Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
d8d4f4f
Add store_act_asm template for VRAM to HBM activation storage
gaoziqian123 Feb 7, 2026
5740122
Merge remote-tracking branch 'origin/main' into feature/store-act-asm
gaoziqian123 Feb 15, 2026
9ec3a6e
sync local compiler changes
gaoziqian123 Feb 15, 2026
6898656
Add tilelang runtime compiler materials
gaoziqian123 Apr 26, 2026
5fbece5
update tilelang runtime compiler support
gaoziqian123 Apr 27, 2026
8d15546
add tilelang tvm compiler package
gaoziqian123 Apr 30, 2026
03299f3
sync compiler tree from PLENA_Simulator working copy
gaoziqian123 May 6, 2026
2936ae0
migrate frontend to all-graph-layer pipeline
gaoziqian123 May 9, 2026
2280c81
docs: refresh PIPELINE_ARCHITECTURE / MIGRATION_PLAN / AI_AGENT_GUIDE…
gaoziqian123 May 10, 2026
28d7602
unify VRAM/MRAM ≥2D buffer layout to BSHD; consolidate row_*_at addre…
gaoziqian123 May 11, 2026
0259583
remove legacy non-min kernel demos
gaoziqian123 May 11, 2026
9ead591
add SPMD_REWRITE.md: design for replacing 4 lane-fusion graph passes
gaoziqian123 May 11, 2026
da1ed20
SPMD step 1: classify_lane_use pass + unit tests
gaoziqian123 May 11, 2026
e7feb7d
mid_ir pipeline: drop graph layer, add cluster_dim, BTMV/MV/vram-vram…
gaoziqian123 May 12, 2026
4c246b4
rope_min: v↔fp transfer treats cluster phase as 0 + multi-lane wrap
gaoziqian123 May 12, 2026
5409a30
register_alloc: spill_borrow also filters pinned GPs
gaoziqian123 May 13, 2026
75b7b08
gemm region+dim_roles schema, pinned globals row-major-flat, row_stac…
gaoziqian123 May 14, 2026
17e5b50
SSB chain support: concat_min kernel, head-layout helpers, mid_ir + k…
gaoziqian123 May 16, 2026
32a9bd5
Merge branch 'main' into plena-compiler
gaoziqian123 May 17, 2026
4369210
Add loop register alloc/interchange/fusion passes, GQA flash attentio…
gaoziqian123 May 19, 2026
6e40b6d
v2 backend: PreIsaIR v2 → MIR → ISA with scope-recursive register all…
gaoziqian123 May 21, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
371 changes: 371 additions & 0 deletions SPMD_REWRITE.md

Large diffs are not rendered by default.

3 changes: 2 additions & 1 deletion asm_templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,8 +30,8 @@
from .reset_reg_asm import reset_fpreg_asm, reset_reg_asm
from .rope_asm import rope_asm
from .silu_asm import silu_asm
from .gelu_asm import gelu_asm
from .store_act_asm import store_act_asm
from .gemv_asm import gemv_asm

__all__ = [
"batched_matmul_asm",
Expand All @@ -56,5 +56,6 @@
"rms_norm_asm",
"rope_asm",
"silu_asm",
"gelu_asm",
"store_act_asm",
]
1 change: 1 addition & 0 deletions asm_templates/preload_act.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ def preload_act_asm(
inner_loop_register = alive_registers[4]

stride_len = vlen if stride_size is None else stride_size
scale_len = hidden_size * batch if scale_size is None else scale_size

# Set scale offset
generated_code += _load_large_int(a_actual_register, hidden_size * batch)
Expand Down
95 changes: 60 additions & 35 deletions assembler/assembly_to_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,30 +38,43 @@ def _convert_to_binary(self, instruction):
imm = instruction.imm
rmask = instruction.rmask
binary_instruction = 0

imm_mask = (1 << self.imm_width) - 1 if self.imm_width > 0 else 0
if imm_mask and isinstance(imm, int) and (imm < 0 or imm > imm_mask):
print(
f"[assembler] WARN: imm overflow on {instruction.opcode}: "
f"raw imm={imm} (0x{imm & 0xFFFFFFFFFFFFFFFF:X}), "
f"IMM_WIDTH={self.imm_width}, masking to 0x{imm & imm_mask:X}"
)
imm = imm & imm_mask
# print(f"Converting instruction: {instruction.opcode} with opcode={hex(opcode)}, rd={rd}, rs1={rs1}, rs2={rs2}, rstride={rstride}, funct1={funct1}, funct2={funct2}, imm={imm}")
ow = self.operands_width
opw = self.opcode_width

if instruction.opcode in [
"S_ADDI_INT",
"M_MM_WO",
"S_LD_FP",
"S_ST_FP",
"S_LD_INT",
"S_ST_INT",
"S_MAP_V_FP",
"V_RED_MAX",
"V_RECI_V",
"V_EXP_V",
]:
binary_instruction = (imm << (opw + 2 * ow)) + (rs1 << (opw + ow)) + (rd << opw) + opcode
if instruction.opcode in ["S_ADDI_INT", "S_SLLI_INT", "S_SRLI_INT", "M_MM_WO", "S_LD_FP", "S_ST_FP", "S_LD_INT", "S_ST_INT", "S_MAP_V_FP", "S_MAP_FP_V"]:
binary_instruction = (
(imm << (opw + 2 * ow)) +
(rs1 << (opw + ow)) +
(rd << opw) +
opcode
)
elif instruction.opcode in ["S_LUI_INT", "M_MV_WO", "M_BMM_WO", "M_BMV_WO"]:
binary_instruction = (imm << (opw + ow)) + (rd << opw) + opcode
elif instruction.opcode in ["S_MV_FP", "S_RECI_FP", "S_EXP_FP", "S_SQRT_FP", "V_EXP_V", "V_RED_SUM"]:
binary_instruction = (rs1 << (opw + ow)) + (rd << opw) + opcode
elif instruction.opcode in ["C_BREAK"]:
binary_instruction = opcode
elif instruction.opcode in ["C_SET_SCALE_REG", "C_SET_STRIDE_REG", "C_SET_V_MASK_REG", "C_LOOP_END"]:
binary_instruction = (rd << opw) + opcode
binary_instruction = (
(imm << (opw + ow)) +
(rd << opw) +
opcode
)
elif instruction.opcode in [ "S_MV_FP", "S_RECI_FP", "S_EXP_FP", "S_SQRT_FP"]:
binary_instruction = (
(rs1 << (opw + ow)) +
(rd << opw) +
opcode
)
elif instruction.opcode in [ "C_SET_SCALE_REG", "C_SET_STRIDE_REG", "C_SET_V_MASK_REG", "C_LOOP_END"]:
binary_instruction = (
(rd << opw) +
opcode
)
elif instruction.opcode in ["C_LOOP_START"]:
# C_LOOP_START rd, imm - uses 22-bit immediate like S_LUI_INT
binary_instruction = (imm << (opw + ow)) + (rd << opw) + opcode
Expand All @@ -74,17 +87,15 @@ def _convert_to_binary(self, instruction):
+ (rd << opw)
+ opcode
)
elif instruction.opcode in [
"V_ADD_VV",
"V_ADD_VF",
"V_MUL_VV",
"V_SUB_VV",
"V_MUL_VF",
"V_EXP_V",
"V_RECI_V",
"V_RED_SUM",
"V_RED_MAX",
]:
elif instruction.opcode in ["V_EXP_V", "V_RECI_V", "V_RED_SUM", "V_RED_MAX"]:
binary_instruction = (
(rmask << (opw + 3 * ow)) +
(0 << (opw + 2 * ow)) +
(rs1 << (opw + ow)) +
(rd << opw) +
opcode
)
elif instruction.opcode in ["V_ADD_VV", "V_ADD_VF", "V_MUL_VV", "V_SUB_VV", "V_MUL_VF"]:
binary_instruction = (
(rmask << (opw + 3 * ow)) + (rs2 << (opw + 2 * ow)) + (rs1 << (opw + ow)) + (rd << opw) + opcode
)
Expand Down Expand Up @@ -122,10 +133,17 @@ def _convert_to_binary(self, instruction):
return binary_instruction

def write_binary_to_file(self, binary_instructions, output_file: str):
with open(output_file, "w") as file:
for instruction in binary_instructions:
file.write(f"0x{instruction:08X}\n")

instr_mask = (1 << self.instruction_length) - 1 if self.instruction_length > 0 else 0xFFFFFFFF
with open(output_file, 'w') as file:
for idx, instruction in enumerate(binary_instructions):
if instruction & ~instr_mask:
print(
f"[assembler] WARN: instruction #{idx} overflows "
f"INSTRUCTION_LENGTH={self.instruction_length}: "
f"raw=0x{instruction:X}, truncating to 0x{instruction & instr_mask:08X}"
)
file.write(f"0x{instruction & instr_mask:08X}\n")

def generate_binary(self, asm_file: str, output_file: str):
"""
Generate binary instructions from the assembled instructions.
Expand All @@ -141,3 +159,10 @@ def generate_binary(self, asm_file: str, output_file: str):
return binary_instructions


# isa_file_path = '../../src/definitions/operation.svh'
# config_file_path = '../../src/definitions/configuration.svh'
# asm_file_path = f'../../test/{args.test_type}/{args.layer}.asm'
# print(f'Assembling {asm_file_path} to {args.layer}.mem')
# output_file_path = f'../../test/{args.test_type}/{args.layer}.mem'
# assembler = AssemblyToBinary(isa_file_path, config_file_path)
# assembler.generate_binary(asm_file_path, output_file_path)
20 changes: 14 additions & 6 deletions assembler/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,14 +182,22 @@ def parse_reg_or_int(operand):
imm = int(operand_1)
except ValueError:
imm = None
# If it looks like register, rs2; else, imm (overwrites imm if rs1 not present)
if operand_2.strip().startswith(("gp", "f", "a")):
rs2 = parse_reg_or_int(operand_2)
else:
# Some vector ops use a 3-operand form where the last field is
# rmask, not rs2/imm.
if opcode in {"V_EXP_V", "V_RECI_V", "V_RED_SUM", "V_RED_MAX"}:
try:
imm = int(operand_2)
rstride = int(operand_2)
except ValueError:
pass
rstride = None
else:
# If it looks like register, rs2; else, imm (overwrites imm if rs1 not present)
if operand_2.strip().startswith(('gp','f','a')):
rs2 = parse_reg_or_int(operand_2)
else:
try:
imm = int(operand_2)
except ValueError:
pass
elif len(operands) == 4:
operand_0, operand_1, operand_2, operand_3 = operands
rd = parse_reg_or_int(operand_0)
Expand Down
8 changes: 8 additions & 0 deletions doc/operation.svh
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ typedef enum logic [instruction_pkg::OPCODE_WIDTH - 1:0] {
S_LD_FP = 6'h1E,
S_ST_FP = 6'h1F,
S_MAP_V_FP = 6'h20,
S_MAP_FP_V = 6'h35,

// Scalar Operations (INT)
S_ADD_INT = 6'h21,
Expand All @@ -155,6 +156,13 @@ typedef enum logic [instruction_pkg::OPCODE_WIDTH - 1:0] {
S_LUI_INT = 6'h25,
S_LD_INT = 6'h26,
S_ST_INT = 6'h27,
// Logical shifts. Arithmetic-right (SRA) intentionally omitted: PLENA's
// integer domain is unsigned address/index arithmetic, no negative values
// ever flow through the GP pool, so sign-extension on shift is a no-op.
S_SLL_INT = 6'h36,
S_SLLI_INT = 6'h37,
S_SRL_INT = 6'h38,
S_SRLI_INT = 6'h39,

// Memory Operations
H_PREFETCH_M = 6'h28,
Expand Down
54 changes: 54 additions & 0 deletions doc/plena_isa_spec.md
Original file line number Diff line number Diff line change
Expand Up @@ -407,6 +407,60 @@ S_ADDI_INT gp2, gp1, 64 ; gp2 = 128 + 64 = 192

Load upper immediate value into the integer register.

#### S_SLL_INT

**Format:** `S_SLL_INT rd, rs1, rs2`

**Operation:** `gp_reg<rd> = gp_reg<rs1> << (gp_reg<rs2> & 0x1F)`

**Description:**

Logical shift left, shift amount taken from the lower 5 bits of `gp_reg<rs2>`.

#### S_SLLI_INT

**Format:** `S_SLLI_INT rd, rs1, imm`

**Operation:** `gp_reg<rd> = gp_reg<rs1> << (imm & 0x1F)`

**Description:**

Logical shift left by an immediate. The lower 5 bits of `imm` are used as the shift amount; upper bits are ignored.

**Example:**
```asm
S_ADDI_INT gp1, gp0, 5 ; gp1 = 5
S_SLLI_INT gp2, gp1, 3 ; gp2 = 5 << 3 = 40
```

#### S_SRL_INT

**Format:** `S_SRL_INT rd, rs1, rs2`

**Operation:** `gp_reg<rd> = gp_reg<rs1> >> (gp_reg<rs2> & 0x1F)` (logical, zero-fill)

**Description:**

Logical shift right, shift amount taken from the lower 5 bits of `gp_reg<rs2>`.

#### S_SRLI_INT

**Format:** `S_SRLI_INT rd, rs1, imm`

**Operation:** `gp_reg<rd> = gp_reg<rs1> >> (imm & 0x1F)` (logical, zero-fill)

**Description:**

Logical shift right by an immediate. The lower 5 bits of `imm` are used as the shift amount.

**Example:**
```asm
S_ADDI_INT gp1, gp0, 200 ; gp1 = 200
S_SRLI_INT gp2, gp1, 3 ; gp2 = 200 >> 3 = 25
```

**Note on omissions:** PLENA does **not** provide arithmetic right shift (SRA) -- the integer pool is treated as unsigned for address arithmetic. There is also no AND/OR/XOR; bit-level masking is not part of the kernel programming model.

#### S_LD_INT

**Format:** `S_LD_INT rd, rs1, imm`
Expand Down
Loading