Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion assembler/assembly_to_binary.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,11 @@ def _convert_to_binary(self, instruction):
binary_instruction = 0
ow = self.operands_width
opw = self.opcode_width
vector_ops_with_rmask = {"V_ADD_VV", "V_ADD_VF", "V_MUL_VV", "V_SUB_VV", "V_MUL_VF", "V_EXP_V", "V_RECI_V", "V_RED_SUM", "V_RED_MAX"}

if instruction.opcode in vector_ops_with_rmask and rmask is None:
# Treat omitted rmask deterministically as "mask disabled" instead of crashing on None << ...
rmask = 0

if instruction.opcode in [
"S_ADDI_INT",
Expand Down Expand Up @@ -140,4 +145,3 @@ def generate_binary(self, asm_file: str, output_file: str):
self.write_binary_to_file(binary_instructions, output_file)
return binary_instructions


9 changes: 9 additions & 0 deletions assembler/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,9 @@ def parse_asm_file(file_path: str) -> list[Instruction]:
"""
instructions = []

vector_masked_unary_or_reduction_ops = {"V_EXP_V", "V_RECI_V", "V_RED_SUM", "V_RED_MAX"}
vector_masked_binary_ops = {"V_ADD_VV", "V_ADD_VF", "V_MUL_VV", "V_SUB_VV", "V_MUL_VF"}

with open(file_path) as file:
for line in file:
# Remove comments and strip whitespace
Expand Down Expand Up @@ -190,6 +193,12 @@ def parse_reg_or_int(operand):
imm = int(operand_2)
except ValueError:
pass
if opcode in vector_masked_unary_or_reduction_ops:
# Keep rmask/rstride aligned for 3-operand masked unary/reduction forms.
rstride = imm
elif opcode in vector_masked_binary_ops:
# Allow 3-operand vector ALU forms by defaulting omitted rmask to 0.
rstride = 0
elif len(operands) == 4:
operand_0, operand_1, operand_2, operand_3 = operands
rd = parse_reg_or_int(operand_0)
Expand Down
28 changes: 28 additions & 0 deletions assembler/tests/test_vector_rmask_handling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import unittest

from assembler.assembly_to_binary import AssemblyToBinary
from assembler.parser import Instruction, parse_asm_file


class TestVectorRmaskHandling(unittest.TestCase):
def setUp(self):
self.asm = AssemblyToBinary("doc/operation.svh", "doc/configuration.svh")

def test_parser_sets_default_rmask_for_three_operand_vector_binary(self):
asm_path = "/tmp/plena_test_vector_binary_missing_rmask.asm"
with open(asm_path, "w") as f:
f.write("V_ADD_VV gp1, gp2, gp3\n")

parsed = parse_asm_file(asm_path)
self.assertEqual(len(parsed), 1)
self.assertEqual(parsed[0].rmask, 0)

def test_encoder_defaults_missing_rmask_to_zero(self):
explicit_mask = Instruction("V_ADD_VV", 1, 2, 3, 0, None, None, None)
missing_mask = Instruction("V_ADD_VV", 1, 2, 3, None, None, None, None)

self.assertEqual(self.asm._convert_to_binary(missing_mask), self.asm._convert_to_binary(explicit_mask))


if __name__ == "__main__":
unittest.main()
Loading