diff --git a/justfile b/justfile index e4dd071..e243e81 100644 --- a/justfile +++ b/justfile @@ -175,6 +175,9 @@ test-ffn-smollm2-135m: test-ffn-clm60m: python3 transactional_emulator/testbench/models/multi_model_ffn_test.py clm60m +test-clm60m-rtl-config rtl_root="../PLENA_RTL": + python3 transactional_emulator/testbench/models/clm60m_rtl_config_test.py --rtl-root {{rtl_root}} + test-decoder-multi-model: python3 transactional_emulator/testbench/models/multi_model_decoder_test.py diff --git a/pyproject.toml b/pyproject.toml index e2739fc..af6004e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -16,6 +16,7 @@ dependencies = [ "matplotlib", "ruff>=0.12", "pydantic>=2.0", + "tomlkit>=0.15.0", ] # Use PyTorch CUDA index only for torch packages diff --git a/tools/memory_mapping/memory_map.py b/tools/memory_mapping/memory_map.py index 4ac7782..84796b4 100644 --- a/tools/memory_mapping/memory_map.py +++ b/tools/memory_mapping/memory_map.py @@ -58,6 +58,27 @@ def hex_to_bytes(hex_str): return bytes.fromhex(hex_str) +def pack_values_to_bytes(values, data_width): + """Pack fixed-width integer values into bytes, least-significant element first.""" + data = 0 + bits_left = 0 + out = bytearray() + mask = (1 << data_width) - 1 + + for value in values: + data |= (int(value) & mask) << bits_left + bits_left += data_width + while bits_left >= 8: + out.append(data & 0xFF) + data >>= 8 + bits_left -= 8 + + if bits_left > 0: + out.append(data & 0xFF) + + return bytes(out) + + def map_data_to_fake_hbm_for_rtl_sim( blocks, element_width, block_width, bias, bias_width, directory, combined_blk_dim, append=True, hbm_row_width=64 ): @@ -120,7 +141,15 @@ def map_data_to_fake_hbm_for_rtl_sim( def map_mx_data_to_hbm_for_behave_sim( - blocks, element_width, block_width, bias, bias_width, directory, append=True, hbm_row_width=64 + blocks, + element_width, + block_width, + bias, + bias_width, + directory, + append=True, + hbm_row_width=64, + logical_row_elements=None, ): """ Maps the quantized blocks and bias to binary memory file for fake HBM memory. @@ -140,8 +169,15 @@ def map_mx_data_to_hbm_for_behave_sim( for row_idx, row in enumerate(blocks): _ = " ".join(f"0x{val:02X}" for val in row) - hbm_row_elem_num = hbm_row_width // (element_width) - hbm_row_bias_num = hbm_row_width // (bias_width) + hbm_row_bytes = (hbm_row_width + 7) // 8 + if logical_row_elements is None: + logical_row_elements = hbm_row_width // element_width + blocks_per_logical_row = (logical_row_elements + block_width - 1) // block_width + + scale_row_bits = (hbm_row_width * bias_width + (element_width * block_width) - 1) // ( + element_width * block_width + ) + scale_row_bytes = (scale_row_bits + 7) // 8 with open(output_file, mode) as f: # Track total bytes written @@ -152,23 +188,30 @@ def map_mx_data_to_hbm_for_behave_sim( # Process blocks row_buffer = bytearray() - for i, block in enumerate(blocks): - hex_str = map_block_to_value(block, element_width) - block_bytes = hex_to_bytes(hex_str) - row_buffer.extend(block_bytes) - - # Write when row is full - if len(row_buffer) >= hbm_row_elem_num: - f.write(row_buffer[:hbm_row_elem_num]) - total_bytes_written += hbm_row_elem_num - blocks_bytes_written += hbm_row_elem_num - row_buffer = bytearray() # Reset buffer after writing + blocks_in_row = 0 + for _i, block in enumerate(blocks): + row_buffer.extend(pack_values_to_bytes(block, element_width)) + blocks_in_row += 1 + + if blocks_in_row == blocks_per_logical_row: + if len(row_buffer) > hbm_row_bytes: + raise ValueError( + f"Packed element row ({len(row_buffer)} bytes) exceeds HBM row width " + f"({hbm_row_bytes} bytes)" + ) + row_padding = hbm_row_bytes - len(row_buffer) + row_buffer.extend(b"\x00" * row_padding) + f.write(row_buffer) + total_bytes_written += len(row_buffer) + blocks_bytes_written += len(row_buffer) + row_buffer = bytearray() + blocks_in_row = 0 # Flush any remaining block data blocks_row_padding = 0 if len(row_buffer) > 0: # Pad to row width - blocks_row_padding = hbm_row_elem_num - len(row_buffer) + blocks_row_padding = hbm_row_bytes - len(row_buffer) row_buffer.extend(b"\x00" * blocks_row_padding) f.write(row_buffer) total_bytes_written += len(row_buffer) @@ -177,17 +220,24 @@ def map_mx_data_to_hbm_for_behave_sim( # Process bias row_buffer = bytearray() - for i, b in enumerate(bias): - hex_str = map_scale_to_value(b, bias_width) - bias_bytes = hex_to_bytes(hex_str) - row_buffer.extend(bias_bytes) - - # Write when row is full - if len(row_buffer) >= hbm_row_bias_num: - f.write(row_buffer[:hbm_row_bias_num]) - total_bytes_written += hbm_row_bias_num - bias_bytes_written += hbm_row_bias_num + scales_in_row = 0 + for _i, b in enumerate(bias): + row_buffer.extend(pack_values_to_bytes([b], bias_width)) + scales_in_row += 1 + + if scales_in_row == blocks_per_logical_row: + if len(row_buffer) > scale_row_bytes: + raise ValueError( + f"Packed scale row ({len(row_buffer)} bytes) exceeds scale row width " + f"({scale_row_bytes} bytes)" + ) + row_padding = scale_row_bytes - len(row_buffer) + row_buffer.extend(b"\x00" * row_padding) + f.write(row_buffer) + total_bytes_written += len(row_buffer) + bias_bytes_written += len(row_buffer) row_buffer = bytearray() + scales_in_row = 0 # # For Little Endian Purpose # if len(row_buffer) > 0: @@ -201,7 +251,7 @@ def map_mx_data_to_hbm_for_behave_sim( bias_row_padding = 0 if len(row_buffer) > 0: # Calculate padding needed - bias_row_padding = hbm_row_bias_num - len(row_buffer) + bias_row_padding = scale_row_bytes - len(row_buffer) row_buffer.extend(b"\x00" * bias_row_padding) f.write(row_buffer) total_bytes_written += len(row_buffer) diff --git a/transactional_emulator/lib/quantize/src/dtype.rs b/transactional_emulator/lib/quantize/src/dtype.rs index 2ed0e21..f501551 100644 --- a/transactional_emulator/lib/quantize/src/dtype.rs +++ b/transactional_emulator/lib/quantize/src/dtype.rs @@ -107,21 +107,20 @@ impl FpType { } // Inf/NaN -> Inf/NaN _ if exponent == exponent_mask => (new_exponent_mask, 0), - // Normal number bias conversion - _ if self.exponent <= new_ty.exponent => { - (exponent + ((new_exponent_mask - exponent_mask) >> 1), 0) - } + // Normal number bias conversion. _ => { - // TODO: Needs to reimplment the underflow and overflow treatment. - let bias_diff = (exponent - new_exponent_mask) >> 1; - if exponent <= bias_diff { - // Underflow: saturate to zero (subnormal) + let src_bias = (exponent_mask >> 1) as i32; + let dst_bias = (new_exponent_mask >> 1) as i32; + let dst_exp = exponent as i32 - src_bias + dst_bias; + + if dst_exp <= 0 { + // Underflow: saturate to zero. (0, 0) - } else if exponent - bias_diff >= new_exponent_mask { - // Overflow: saturate to infinity + } else if dst_exp >= new_exponent_mask as i32 { + // Overflow: saturate to infinity. (new_exponent_mask, 0) } else { - (exponent - bias_diff, 0) + (dst_exp as u32, 0) } } }; @@ -177,10 +176,77 @@ impl FpType { Self::F32.cast(self, float.to_bits()) } + /// Convert f32 to finite minifloat bits, matching the Python hardware + /// quantizer used for non-IEEE PLENA FP formats. + pub fn bits_from_f32_no_specials(self, float: f32) -> u32 { + if float == 0.0 || float.is_nan() { + return 0; + } + + let sign_bit = if self.sign && float.is_sign_negative() { + 1 + } else { + 0 + }; + if !self.sign && float.is_sign_negative() { + return 0; + } + + let value = float.abs(); + let exponent_bias = (mask(self.exponent) >> 1) as i32; + let exponent_min = -exponent_bias; + let exponent_max = (1i32 << self.exponent) - 2 - exponent_bias; + let mut exponent = (value + 1e-9).log2().floor() as i32; + let overflow = exponent > exponent_max; + exponent = exponent.clamp(exponent_min, exponent_max); + + let shift = 1u32 << self.mantissa; + let shifted_mantissa_max = shift - 1; + let mantissa = value / 2.0f32.powi(exponent); + let shifted = if exponent == exponent_min { + (mantissa * shift as f32).round() + } else { + ((mantissa - 1.0) * shift as f32).round() + }; + let mut shifted_mantissa = shifted.clamp(0.0, shifted_mantissa_max as f32) as u32; + if overflow { + shifted_mantissa = shifted_mantissa_max; + } + + let exponent_bits = (exponent + exponent_bias) as u32; + (sign_bit << (self.exponent + self.mantissa)) + | (exponent_bits << self.mantissa) + | shifted_mantissa + } + /// Convert bits to f32. Only lower `bits()` bits are used. pub const fn convert_bits_to_f32(self, bits: u32) -> f32 { f32::from_bits(self.cast(Self::F32, bits)) } + + /// Convert minifloat bits to f32 without reserving Inf/NaN encodings. + /// + /// HBM MXFP uses the Python hardware pack/unpack path, which treats all + /// exponent patterns as finite. Plain/scalar FP still uses IEEE specials. + pub fn convert_bits_to_f32_no_specials(self, bits: u32) -> f32 { + let sign = if self.sign { + (bits >> (self.exponent + self.mantissa)) & 1 + } else { + 0 + }; + let exponent = (bits >> self.mantissa) & mask(self.exponent); + let mantissa_bits = bits & mask(self.mantissa); + let exponent_bias = (mask(self.exponent) >> 1) as i32; + let exponent_val = exponent as i32 - exponent_bias; + let mantissa_scale = (1u32 << self.mantissa) as f32; + let mantissa = if exponent == 0 { + mantissa_bits as f32 / mantissa_scale + } else { + 1.0 + mantissa_bits as f32 / mantissa_scale + }; + let value = 2.0f32.powi(exponent_val) * mantissa; + if sign == 1 { -value } else { value } + } } #[test] @@ -217,6 +283,26 @@ fn test_f16() { ); } +#[test] +fn test_e6m5_scalar_roundtrip() { + let ty = FpType { + sign: true, + exponent: 6, + mantissa: 5, + }; + + assert_eq!(ty.convert_bits_to_f32(ty.bits_from_f32(0.25)), 0.25); + assert_eq!(ty.convert_bits_to_f32(ty.bits_from_f32(1.0)), 1.0); + assert_eq!( + ty.convert_bits_to_f32(ty.bits_from_f32(1.0 / 16.0)), + 1.0 / 16.0 + ); + assert_eq!( + ty.convert_bits_to_f32(ty.bits_from_f32(f32::NEG_INFINITY)), + f32::NEG_INFINITY + ); +} + #[test] fn test_e4m3_subnormal() { // E4M3 format: 1 sign, 4 exp, 3 mantissa. Bias = 7. @@ -422,6 +508,13 @@ impl DataType { } } + pub fn bits_from_f32_no_specials(self, float: f32) -> u32 { + match self { + DataType::Fp(fp_type) => fp_type.bits_from_f32_no_specials(float), + DataType::Int(int_type) => int_type.bits_from_f32(float), + } + } + pub const fn convert_bits_to_f32(self, bits: u32) -> f32 { match self { DataType::Fp(fp_type) => fp_type.convert_bits_to_f32(bits), @@ -429,8 +522,23 @@ impl DataType { } } + pub fn convert_bits_to_f32_no_specials(self, bits: u32) -> f32 { + match self { + DataType::Fp(fp_type) => fp_type.convert_bits_to_f32_no_specials(bits), + DataType::Int(int_type) => int_type.convert_bits_to_f32(bits), + } + } + /// Convert bytes to vector of f32. pub fn convert_bytes_to_f32_vec(self, mut bytes: &[u8], out: &mut [f32]) { + self.convert_bytes_to_f32_vec_impl(&mut bytes, out, false); + } + + pub fn convert_bytes_to_f32_vec_no_specials(self, mut bytes: &[u8], out: &mut [f32]) { + self.convert_bytes_to_f32_vec_impl(&mut bytes, out, true); + } + + fn convert_bytes_to_f32_vec_impl(self, bytes: &mut &[u8], out: &mut [f32], no_specials: bool) { let bits = self.size_in_bits(); let mut data = 0; let mut bits_left = 0; @@ -438,35 +546,53 @@ impl DataType { while bits_left < bits { data |= (bytes[0] as u32) << bits_left; bits_left += 8; - bytes = &bytes[1..]; + *bytes = &bytes[1..]; } - *out = self.convert_bits_to_f32(data); + *out = if no_specials { + self.convert_bits_to_f32_no_specials(data) + } else { + self.convert_bits_to_f32(data) + }; bits_left -= bits; data >>= bits; } } - pub fn bytes_from_f32(self, input: &[f32], mut out: &mut [u8]) { + pub fn bytes_from_f32(self, input: &[f32], out: &mut [u8]) { + self.bytes_from_f32_impl(input, out, false); + } + + pub fn bytes_from_f32_no_specials(self, input: &[f32], out: &mut [u8]) { + self.bytes_from_f32_impl(input, out, true); + } + + fn bytes_from_f32_impl(self, input: &[f32], out: &mut [u8], no_specials: bool) { let bits = self.size_in_bits(); let mut data = 0; let mut bits_left = 0u8; + let mut out_idx = 0usize; for elem in input.iter().copied() { while bits_left >= 8 { - out[0] = data as u8; - out = &mut out[1..]; + out[out_idx] = data as u8; + out_idx += 1; data >>= 8; bits_left -= 8; } - data |= self.bits_from_f32(elem) << bits_left; + let elem_bits = if no_specials { + self.bits_from_f32_no_specials(elem) + } else { + self.bits_from_f32(elem) + }; + data |= elem_bits << bits_left; bits_left += bits; } while bits_left > 0 { - out[0] = data as u8; - out = &mut out[1..]; + out[out_idx] = data as u8; + out_idx += 1; data >>= 8; bits_left = bits_left.saturating_sub(8); } diff --git a/transactional_emulator/lib/quantize/src/tensor.rs b/transactional_emulator/lib/quantize/src/tensor.rs index 311840d..fc9f477 100644 --- a/transactional_emulator/lib/quantize/src/tensor.rs +++ b/transactional_emulator/lib/quantize/src/tensor.rs @@ -130,7 +130,6 @@ impl QuantTensor { let elem_ty = ty.element_type(); let mut vec = vec![0f32; len]; - elem_ty.convert_bytes_to_f32_vec(bytes, &mut vec); if let MxDataType::Mx { elem: _, @@ -138,9 +137,10 @@ impl QuantTensor { block, } = ty { + elem_ty.convert_bytes_to_f32_vec_no_specials(bytes, &mut vec); let mut scale_vec = vec![0f32; len / block as usize]; - scale.convert_bytes_to_f32_vec(&scale_bytes, &mut scale_vec); + scale.convert_bytes_to_f32_vec_no_specials(scale_bytes, &mut scale_vec); for (elem, scale) in vec .chunks_mut(block as usize) @@ -150,6 +150,8 @@ impl QuantTensor { *elem *= scale; } } + } else { + elem_ty.convert_bytes_to_f32_vec(bytes, &mut vec); } let tensor = tch::Tensor::from_slice(&vec); diff --git a/transactional_emulator/lib/vector_sram/src/lib.rs b/transactional_emulator/lib/vector_sram/src/lib.rs index 38ba4a2..17f444f 100644 --- a/transactional_emulator/lib/vector_sram/src/lib.rs +++ b/transactional_emulator/lib/vector_sram/src/lib.rs @@ -1,4 +1,4 @@ -use quantize::{DataType, MxDataType, QuantTensor}; +use quantize::{DataType, FpType, MxDataType, QuantTensor}; use tch::Tensor; use tokio::sync::oneshot::Receiver; use tokio::sync::Mutex; @@ -32,6 +32,10 @@ enum RowData { } impl VectorSram { + fn row_width_bytes(vlen: u32, fp_type: DataType) -> usize { + (vlen as usize * fp_type.size_in_bits() as usize).div_ceil(8) + } + /// Create a new Vector SRAM with given vector length, depth, and data types. /// /// # Arguments @@ -40,9 +44,7 @@ impl VectorSram { /// * `fp_type` - Floating point data type for FP operations /// * `int_size_bytes` - Size of integer in bytes (typically 4 for i32) pub fn new(vlen: u32, depth: usize, fp_type: DataType, int_size_bytes: usize) -> Self { - // Use FP type size for row width (can be changed if needed) - let element_size = fp_type.size_in_bits() as usize / 8; - let row_width = vlen as usize * element_size; + let row_width = Self::row_width_bytes(vlen, fp_type); let rows = (0..depth) .map(|_| Mutex::new(RowData::Ready(vec![0u8; row_width]))) @@ -80,9 +82,7 @@ impl VectorSram { /// Get the size of the SRAM in bytes pub fn size_in_bytes(&self) -> usize { - let element_size = self.fp_type.size_in_bits() as usize / 8; - let row_width = self.vlen as usize * element_size; - row_width * self.depth + Self::row_width_bytes(self.vlen, self.fp_type) * self.depth } /// Read a vector from the SRAM at the given address as FP (QuantTensor). @@ -263,9 +263,8 @@ impl VectorSram { /// /// This is used for preloading the SRAM with test data. pub async fn load_from_bytes(&self, bytes: &[u8]) { - let element_size = self.fp_type.size_in_bits() as usize / 8; - let bytes_per_element = element_size; - let total_elements = bytes.len() / bytes_per_element; + let bits_per_element = self.fp_type.size_in_bits() as usize; + let total_elements = (bytes.len() * 8) / bits_per_element; let num_rows = (total_elements + self.vlen as usize - 1) / self.vlen as usize; for row_idx in 0..num_rows.min(self.depth) { @@ -273,13 +272,24 @@ impl VectorSram { let end_element = (start_element + self.vlen as usize).min(total_elements); let elements_in_row = end_element - start_element; - let start_byte = start_element * bytes_per_element; - let end_byte = end_element * bytes_per_element; + let start_bit = start_element * bits_per_element; + let end_bit = end_element * bits_per_element; + assert!( + start_bit.is_multiple_of(8) && end_bit.is_multiple_of(8), + "Vector SRAM preload rows must be byte-aligned" + ); + let start_byte = start_bit / 8; + let end_byte = end_bit / 8; // Convert bytes to f32 values let mut vec = vec![0f32; elements_in_row]; - self.fp_type - .convert_bytes_to_f32_vec(&bytes[start_byte..end_byte], &mut vec); + if self.use_no_specials_fp() { + self.fp_type + .convert_bytes_to_f32_vec_no_specials(&bytes[start_byte..end_byte], &mut vec); + } else { + self.fp_type + .convert_bytes_to_f32_vec(&bytes[start_byte..end_byte], &mut vec); + } // Pad with zeros if needed if elements_in_row < self.vlen as usize { @@ -354,19 +364,28 @@ impl VectorSram { let total_bits = len * self.fp_type.size_in_bits() as usize; let bytes_needed = (total_bits + 7) / 8; let mut bytes = vec![0u8; bytes_needed]; - self.fp_type.bytes_from_f32(f32_slice, &mut bytes); + if self.use_no_specials_fp() { + self.fp_type + .bytes_from_f32_no_specials(f32_slice, &mut bytes); + } else { + self.fp_type.bytes_from_f32(f32_slice, &mut bytes); + } bytes } /// Convert bytes to QuantTensor (FP format) fn bytes_to_quant_tensor(&self, bytes: &[u8], expected_len: u32) -> QuantTensor { - let bytes_per_element = self.fp_type.size_in_bits() as usize / 8; - let num_elements = bytes.len() / bytes_per_element; + let bits_per_element = self.fp_type.size_in_bits() as usize; + let num_elements = (bytes.len() * 8) / bits_per_element; let actual_len = num_elements.min(expected_len as usize); let mut vec = vec![0f32; actual_len]; - self.fp_type - .convert_bytes_to_f32_vec(&bytes[..actual_len * bytes_per_element], &mut vec); + if self.use_no_specials_fp() { + self.fp_type + .convert_bytes_to_f32_vec_no_specials(bytes, &mut vec); + } else { + self.fp_type.convert_bytes_to_f32_vec(bytes, &mut vec); + } // Pad to expected_len if needed if actual_len < expected_len as usize { @@ -377,6 +396,13 @@ impl VectorSram { QuantTensor::quantize(tensor, MxDataType::Plain(self.fp_type)) } + fn use_no_specials_fp(&self) -> bool { + match self.fp_type { + DataType::Fp(fp_type) => !matches!(fp_type, FpType::F16 | FpType::BF16 | FpType::F32), + DataType::Int(_) => false, + } + } + /// Convert integer vector to bytes fn int_vec_to_bytes(&self, int_vec: &[i32], expected_len: u32) -> Vec { let mut bytes = Vec::with_capacity(expected_len as usize * self.int_size_bytes); diff --git a/transactional_emulator/src/load_config.rs b/transactional_emulator/src/load_config.rs index 2655792..f98f36b 100644 --- a/transactional_emulator/src/load_config.rs +++ b/transactional_emulator/src/load_config.rs @@ -396,6 +396,10 @@ pub static CONFIG: LazyLock = LazyLock::new(|| { // Configuration loading functions pub fn load_config() -> Result> { + if let Ok(config_path) = env::var("PLENA_SETTINGS_TOML") { + return load_config_from_file(&config_path); + } + let config_path = env::current_dir() .unwrap() .parent() diff --git a/transactional_emulator/src/main.rs b/transactional_emulator/src/main.rs index 6d70e54..a9631f4 100644 --- a/transactional_emulator/src/main.rs +++ b/transactional_emulator/src/main.rs @@ -14,9 +14,9 @@ use std::sync::LazyLock; use clap::Parser; use futures::StreamExt; use futures::stream::FuturesUnordered; -use half::{bf16, f16}; +use half::f16; use memory::{ErasedMemoryModel, MemoryModel}; -use quantize::{MxDataType, QuantTensor}; +use quantize::{DataType, MxDataType, QuantTensor}; use runtime::{Duration, Executor, Instant}; use tch::{IndexOp, Tensor}; use vector_sram::VectorSram; @@ -60,6 +60,7 @@ static MATRIX_WEIGHT_TYPE: LazyLock = LazyLock::new(|| matrix_weight static MATRIX_KV_TYPE: LazyLock = LazyLock::new(|| matrix_kv_type()); static VECTOR_ACTIVATION_TYPE: LazyLock = LazyLock::new(|| vector_activation_type()); static VECTOR_KV_TYPE: LazyLock = LazyLock::new(|| vector_kv_type()); +static SCALAR_FP_TYPE: LazyLock = LazyLock::new(|| scalar_fp_type()); static PREFETCH_M_AMOUNT: LazyLock = LazyLock::new(|| hbm_m_prefetch_amount()); static PREFETCH_V_AMOUNT: LazyLock = LazyLock::new(|| hbm_v_prefetch_amount()); static STORE_V_AMOUNT: LazyLock = LazyLock::new(|| hbm_v_writeback_amount()); @@ -899,7 +900,7 @@ impl VectorMachine { async fn exp(&mut self, vd: u32, vs1: u32, rmask: u8, mask: u32) { let a = self.vram.read(vs1).await; - // Clamp inputs to [-88, 88] to prevent bf16 overflow (exp(89) > bf16_max). + // Clamp inputs to [-88, 88] to prevent FP overflow in exp. // This matches what hardware exp units do (saturate instead of producing inf/NaN). let clamped = a.as_tensor().clamp(-88.0f64, 88.0f64); if rmask == 0 { @@ -948,16 +949,13 @@ impl VectorMachine { } } - async fn vector_transfer_fp(&mut self, vd: u32, f: &[bf16]) { + async fn vector_transfer_fp(&mut self, vd: u32, f: &[f32]) { assert_eq!( f.len(), self.vram.tile_size() as usize, "Input vector length must match tile_size" ); - // Convert bf16 slice to f32 vector - let f32_vec: Vec = f.iter().map(|x| f32::from(*x)).collect(); - // Create tensor from f32 vector - let tensor = tch::Tensor::from_slice(&f32_vec); + let tensor = tch::Tensor::from_slice(f); // Quantize the tensor according to vram data type let c = QuantTensor::quantize(tensor, self.vram.ty()); cycle!(*VLEN); @@ -1026,13 +1024,13 @@ struct Accelerator { hbm: Arc, reg_file: AcceeleratorRegFile, intsram: Vec, - fpsram: Vec, + fpsram: Vec, loop_stack: Vec, // Stack for nested loops } struct AcceeleratorRegFile { gp_reg: [u32; 16], - fp_reg: [bf16; 8], + fp_reg: [f32; 8], hbm_addr_reg: [u64; 16], scale: u32, stride: u32, @@ -1097,8 +1095,7 @@ impl Accelerator { assert!(element_bits.is_power_of_two()); let len_in_bits_per_load = element_bits as u32 * load_dim; - assert!(len_in_bits_per_load.is_multiple_of(8 * 64)); - let len_in_bytes_per_load = len_in_bits_per_load / 8; + let len_in_bytes_per_load = len_in_bits_per_load.div_ceil(8); // Calculate scale bytes per load iteration (for Mx types) let (scale_len_in_bytes_per_load, block) = if let MxDataType::Mx { @@ -1110,8 +1107,7 @@ impl Accelerator { let scale_bits = scale.size_in_bits(); assert!(scale_bits.is_power_of_two()); let scale_len_in_bits_per_load = scale_bits as u32 * (load_dim / block); - assert!(scale_len_in_bits_per_load.is_multiple_of(8)); - (scale_len_in_bits_per_load / 8, block as usize) + (scale_len_in_bits_per_load.div_ceil(8), block as usize) } else { (0, usize::MAX) }; @@ -1147,42 +1143,46 @@ impl Accelerator { as usize + block_idx as usize * scale_len_in_bytes_per_load as usize; - // Element chunks: - for i in 0..(len_in_bytes_per_load as usize + 63) / 64 { - let chunk_offset = byte_offset + i * 64; - let chunk_size = std::cmp::min(64, total_bytes - chunk_offset); - let addr = element_addr + (i * 64) as u64; - assert!(addr.is_multiple_of(64)); + // Element chunks may start at sub-64B HBM offsets for smaller RTL + // shapes. Fetch aligned chunks and copy only the requested bytes. + let mut element_bytes_read = 0usize; + while element_bytes_read < len_in_bytes_per_load as usize { + let current_addr = element_addr + element_bytes_read as u64; + let aligned_addr = (current_addr / 64) * 64; + let within_chunk_offset = (current_addr % 64) as usize; + let bytes_remaining = len_in_bytes_per_load as usize - element_bytes_read; + let chunk_size = std::cmp::min(64 - within_chunk_offset, bytes_remaining); + let chunk_offset = byte_offset + element_bytes_read; futures.push(Box::pin(async move { - let data = hbm_clone.read(addr).await; - ChunkType::Element(chunk_offset, data, chunk_size) + let data = hbm_clone.read(aligned_addr).await; + let mut selected = [0u8; 64]; + selected[..chunk_size].copy_from_slice( + &data[within_chunk_offset..within_chunk_offset + chunk_size], + ); + ChunkType::Element(chunk_offset, selected, chunk_size) })); + element_bytes_read += chunk_size; } // Scale chunks (if Mx type) - if scale_len_in_bytes_per_load > 0 { - // Always align to 64-byte chunk boundary for loading - // For scale_addr, we fetch the aligned 64-byte block, and mask/select out only what is needed - let aligned_scale_addr = (scale_addr / 64) * 64; - let within_chunk_offset = (scale_addr % 64) as usize; - let chunk_offset = scale_byte_offset; // where to write in scale_bytes - let chunk_size = std::cmp::min(64, total_scale_bytes - chunk_offset); + let mut scale_bytes_read = 0usize; + while scale_bytes_read < scale_len_in_bytes_per_load as usize { + let current_scale_addr = scale_addr + scale_bytes_read as u64; + let aligned_scale_addr = (current_scale_addr / 64) * 64; + let within_chunk_offset = (current_scale_addr % 64) as usize; + let bytes_remaining = + scale_len_in_bytes_per_load as usize - scale_bytes_read; + let chunk_size = std::cmp::min(64 - within_chunk_offset, bytes_remaining); + let chunk_offset = scale_byte_offset + scale_bytes_read; futures.push(Box::pin(async move { let data = hbm_clone.read(aligned_scale_addr).await; - // println!("aligned_scale_addr = {:?}", aligned_scale_addr); - // Copy out only the relevant bytes for this scale_addr - // scale_len_in_bytes_per_load says how many bytes to copy from within the chunk - let end_offset = std::cmp::min( - within_chunk_offset + scale_len_in_bytes_per_load as usize, - 64, - ); let mut selected = [0u8; 64]; - let len_to_copy = end_offset - within_chunk_offset; - selected[..len_to_copy] - .copy_from_slice(&data[within_chunk_offset..end_offset]); - // println!("selected scale = {:?}", selected); - ChunkType::Scale(chunk_offset, selected, len_to_copy) + selected[..chunk_size].copy_from_slice( + &data[within_chunk_offset..within_chunk_offset + chunk_size], + ); + ChunkType::Scale(chunk_offset, selected, chunk_size) })); + scale_bytes_read += chunk_size; } } } @@ -1212,10 +1212,7 @@ impl Accelerator { let bytes_start = (write_idx * write_amount) as usize * len_in_bytes_per_load as usize; - element_ty.convert_bytes_to_f32_vec( - &bytes[bytes_start..bytes_start + write_elements * (element_bits as usize / 8)], - &mut vec, - ); + let write_element_bytes = (write_elements * element_bits as usize).div_ceil(8); // Apply scaling if needed if let MxDataType::Mx { @@ -1224,13 +1221,17 @@ impl Accelerator { block, } = hbm_type { + element_ty.convert_bytes_to_f32_vec_no_specials( + &bytes[bytes_start..bytes_start + write_element_bytes], + &mut vec, + ); let nblocks = write_elements / block as usize; let scale_bytes_start = (write_idx * write_amount) as usize * scale_len_in_bytes_per_load as usize; let mut scale_vec = vec![0f32; nblocks]; - scale.convert_bytes_to_f32_vec( + scale.convert_bytes_to_f32_vec_no_specials( &scale_bytes[scale_bytes_start - ..scale_bytes_start + nblocks * (scale_bits as usize / 8)], + ..scale_bytes_start + (nblocks * scale_bits as usize).div_ceil(8)], &mut scale_vec, ); for (elem_block, scale_val) in vec @@ -1241,6 +1242,11 @@ impl Accelerator { *elem *= scale_val; } } + } else { + element_ty.convert_bytes_to_f32_vec( + &bytes[bytes_start..bytes_start + write_element_bytes], + &mut vec, + ); } let tensor = tch::Tensor::from_slice(&vec); @@ -1308,8 +1314,7 @@ impl Accelerator { assert!(element_bits.is_power_of_two()); let len_in_bits_per_store = element_bits as u32 * store_dim; - assert!(len_in_bits_per_store.is_multiple_of(8 * 64)); - let len_in_bytes_per_store = len_in_bits_per_store / 8; + let len_in_bytes_per_store = len_in_bits_per_store.div_ceil(8); // Calculate scale bytes per store iteration (for Mx types) let (scale_len_in_bytes_per_store, block) = if let MxDataType::Mx { @@ -1321,8 +1326,7 @@ impl Accelerator { let scale_bits = scale.size_in_bits(); assert!(scale_bits.is_power_of_two()); let scale_len_in_bits_per_store = scale_bits as u32 * (store_dim / block); - assert!(scale_len_in_bits_per_store.is_multiple_of(8)); - (scale_len_in_bits_per_store / 8, block as usize) + (scale_len_in_bits_per_store.div_ceil(8), block as usize) } else { (0, usize::MAX) }; @@ -1394,20 +1398,29 @@ impl Accelerator { let element_addr = index + (store_iter * stride) as u64; let scale_addr = scale_index + (store_iter as f32 * stride_scale) as u64; - // Write element bytes to HBM (64-byte aligned chunks) - for i in 0..(len_in_bytes_per_store as usize + 63) / 64 { - let chunk_offset = i * 64; - let chunk_size = std::cmp::min(64, len_in_bytes_per_store as usize - chunk_offset); - let addr = element_addr + (i * 64) as u64; - assert!(addr.is_multiple_of(64)); - - let mut chunk = [0u8; 64]; - if chunk_offset < element_bytes.len() { - let copy_len = std::cmp::min(chunk_size, element_bytes.len() - chunk_offset); - chunk[..copy_len] - .copy_from_slice(&element_bytes[chunk_offset..chunk_offset + copy_len]); + // Write element bytes to HBM, preserving unrelated bytes in the same + // aligned chunk when the logical row is smaller than 64B. + let mut element_bytes_written = 0usize; + let total_element_bytes = len_in_bytes_per_store as usize; + while element_bytes_written < total_element_bytes { + let current_addr = element_addr + element_bytes_written as u64; + let aligned_addr = (current_addr / 64) * 64; + let within_chunk_offset = (current_addr % 64) as usize; + let bytes_remaining = total_element_bytes - element_bytes_written; + let bytes_in_chunk = std::cmp::min(64 - within_chunk_offset, bytes_remaining); + let bytes_to_copy = + std::cmp::min(bytes_in_chunk, element_bytes.len() - element_bytes_written); + + let mut chunk = hbm_clone.read(aligned_addr).await; + if bytes_to_copy > 0 { + chunk[within_chunk_offset..within_chunk_offset + bytes_to_copy] + .copy_from_slice( + &element_bytes + [element_bytes_written..element_bytes_written + bytes_to_copy], + ); } - hbm_clone.write(addr, chunk).await; + hbm_clone.write(aligned_addr, chunk).await; + element_bytes_written += bytes_in_chunk; } // Write scale bytes to HBM (if Mx type) @@ -1855,7 +1868,7 @@ impl Accelerator { mask, ) .await; - self.reg_file.fp_reg[*rd as usize] = bf16::from_f32(result); + self.reg_file.fp_reg[*rd as usize] = quantize_scalar_fp(result); } op::Opcode::V_RED_MAX { rd, rs1, rmask } => { let mask = if *rmask == 0 { @@ -1872,7 +1885,7 @@ impl Accelerator { mask, ) .await; - self.reg_file.fp_reg[*rd as usize] = bf16::from_f32(result); + self.reg_file.fp_reg[*rd as usize] = quantize_scalar_fp(result); } // Write to fp0 is a no-op. @@ -1885,41 +1898,44 @@ impl Accelerator { | op::Opcode::S_SQRT_FP { rd: 0, .. } => {} op::Opcode::S_ADD_FP { rd, rs1, rs2 } => { - self.reg_file.fp_reg[*rd as usize] = - self.reg_file.fp_reg[*rs1 as usize] + self.reg_file.fp_reg[*rs2 as usize]; + self.reg_file.fp_reg[*rd as usize] = quantize_scalar_fp( + self.reg_file.fp_reg[*rs1 as usize] + self.reg_file.fp_reg[*rs2 as usize], + ); cycle!(*SCALAR_FP_BASIC_CYCLES); } op::Opcode::S_SUB_FP { rd, rs1, rs2 } => { - self.reg_file.fp_reg[*rd as usize] = - self.reg_file.fp_reg[*rs1 as usize] - self.reg_file.fp_reg[*rs2 as usize]; + self.reg_file.fp_reg[*rd as usize] = quantize_scalar_fp( + self.reg_file.fp_reg[*rs1 as usize] - self.reg_file.fp_reg[*rs2 as usize], + ); cycle!(*SCALAR_FP_BASIC_CYCLES); } op::Opcode::S_MAX_FP { rd, rs1, rs2 } => { - self.reg_file.fp_reg[*rd as usize] = bf16::max( + self.reg_file.fp_reg[*rd as usize] = quantize_scalar_fp(f32::max( self.reg_file.fp_reg[*rs1 as usize], self.reg_file.fp_reg[*rs2 as usize], - ); + )); cycle!(*SCALAR_FP_BASIC_CYCLES); } op::Opcode::S_MUL_FP { rd, rs1, rs2 } => { - self.reg_file.fp_reg[*rd as usize] = - self.reg_file.fp_reg[*rs1 as usize] * self.reg_file.fp_reg[*rs2 as usize]; + self.reg_file.fp_reg[*rd as usize] = quantize_scalar_fp( + self.reg_file.fp_reg[*rs1 as usize] * self.reg_file.fp_reg[*rs2 as usize], + ); cycle!(*SCALAR_FP_BASIC_CYCLES); } op::Opcode::S_EXP_FP { rd, rs1 } => { - let val: f32 = self.reg_file.fp_reg[*rs1 as usize].into(); + let val = self.reg_file.fp_reg[*rs1 as usize]; let clamped = val.clamp(-88.0, 88.0); - self.reg_file.fp_reg[*rd as usize] = bf16::from_f32(clamped.exp()); + self.reg_file.fp_reg[*rd as usize] = quantize_scalar_fp(clamped.exp()); cycle!(*SCALAR_FP_EXP_CYCLES); } op::Opcode::S_RECI_FP { rd, rs1 } => { self.reg_file.fp_reg[*rd as usize] = - bf16::ONE / self.reg_file.fp_reg[*rs1 as usize]; + quantize_scalar_fp(1.0 / self.reg_file.fp_reg[*rs1 as usize]); cycle!(*SCALAR_FP_RECI_CYCLES); } op::Opcode::S_SQRT_FP { rd, rs1 } => { self.reg_file.fp_reg[*rd as usize] = - bf16::from_f32(f32::from(self.reg_file.fp_reg[*rs1 as usize]).sqrt()); + quantize_scalar_fp(self.reg_file.fp_reg[*rs1 as usize].sqrt()); cycle!(*SCALAR_FP_SQRT_CYCLES); } op::Opcode::S_LD_FP { rd, rs1, imm } => { @@ -2307,6 +2323,10 @@ fn is_quiet() -> bool { QUIET_MODE.load(std::sync::atomic::Ordering::Relaxed) } +fn quantize_scalar_fp(value: f32) -> f32 { + SCALAR_FP_TYPE.convert_bits_to_f32(SCALAR_FP_TYPE.bits_from_f32(value)) +} + async fn start() { let opts = Opts::parse(); QUIET_MODE.store(opts.quiet, std::sync::atomic::Ordering::Relaxed); @@ -2367,7 +2387,7 @@ async fn start() { hbm: hbm.clone(), reg_file: AcceeleratorRegFile { gp_reg: [0; 16], - fp_reg: [bf16::ZERO; 8], + fp_reg: [0.0; 8], hbm_addr_reg: [0; 16], scale: 0, stride: 1, @@ -2378,7 +2398,7 @@ async fn start() { v_mask: 0, }, intsram: vec![0; 1024], - fpsram: vec![bf16::ZERO; 1024], + fpsram: vec![0.0; 1024], loop_stack: Vec::new(), }; @@ -2400,13 +2420,13 @@ async fn start() { // Load fpsram and intsram as raw bytes and map to the vector files. // - fpsram Preload let fpsram_data = std::fs::read(opts.fpsram).unwrap(); - let fp_vals: Vec = { + let fp_vals: Vec = { let n = fpsram_data.len() / std::mem::size_of::(); let f16_slice: &[f16] = unsafe { std::slice::from_raw_parts(fpsram_data.as_ptr() as *const f16, n) }; f16_slice .iter() - .map(|x| bf16::from_f32(f32::from(*x))) + .map(|x| quantize_scalar_fp(f32::from(*x))) .collect() }; @@ -2474,11 +2494,10 @@ async fn start() { // Dump FPSRAM let fpsram_dump_path = "fpsram_dump.bin"; - let fpsram_bytes: Vec = accelerator - .fpsram - .iter() - .flat_map(|f| f.to_le_bytes()) - .collect(); + let scalar_fp = *SCALAR_FP_TYPE; + let total_bits = accelerator.fpsram.len() * scalar_fp.size_in_bits() as usize; + let mut fpsram_bytes = vec![0u8; total_bits.div_ceil(8)]; + scalar_fp.bytes_from_f32(&accelerator.fpsram, &mut fpsram_bytes); let mut fpsram_file = std::fs::File::create(fpsram_dump_path).unwrap(); fpsram_file.write_all(&fpsram_bytes).unwrap(); if !is_quiet() { diff --git a/transactional_emulator/testbench/emulator_runner.py b/transactional_emulator/testbench/emulator_runner.py index a89abca..053be77 100644 --- a/transactional_emulator/testbench/emulator_runner.py +++ b/transactional_emulator/testbench/emulator_runner.py @@ -10,6 +10,7 @@ import sys from pathlib import Path +import tomlkit from transactional_emulator.tools.check_mem import compare_vram_with_golden, print_comparison_results from transactional_emulator.testbench.config_utils import update_plena_config @@ -113,12 +114,14 @@ def compare_emulator_output(build_dir: Path) -> tuple: with open(params_file) as f: params = json.load(f) + exp_width, man_width, bits_per_val = _current_vector_sram_fp_format() results = compare_vram_with_golden( vram_file, golden_file, - exp_width=8, - man_width=7, - num_bytes_per_val=2, + exp_width=exp_width, + man_width=man_width, + num_bytes_per_val=max(1, (bits_per_val + 7) // 8), + bits_per_val=bits_per_val, row_dim=params.get("row_dim", 64), start_row_idx=params["start_row_idx"], num_batches=params["num_batches"], @@ -131,6 +134,18 @@ def compare_emulator_output(build_dir: Path) -> tuple: return results, params +def _current_vector_sram_fp_format() -> tuple[int, int, int]: + """Return VECTOR_SRAM_TYPE as (exp, mant, total_bits) from the active TOML.""" + config_path = Path(os.environ.get("PLENA_SETTINGS_TOML", Path(__file__).parents[2] / "plena_settings.toml")) + with open(config_path) as f: + config = tomlkit.load(f) + data_type = config["BEHAVIOR"]["PRECISION"]["VECTOR_SRAM_TYPE"]["DATA_TYPE"] + exp_width = int(data_type["exponent"]) + man_width = int(data_type["mantissa"]) + sign_width = 1 if bool(data_type.get("sign", True)) else 0 + return exp_width, man_width, sign_width + exp_width + man_width + + def run_and_assert(build_dir: Path, op_name: str, mlen: int = 64, blen: int = 4) -> None: """ Sync HW config, run the Rust emulator, compare output, exit(1) on failure. @@ -142,7 +157,9 @@ def run_and_assert(build_dir: Path, op_name: str, mlen: int = 64, blen: int = 4) blen: Batch tile length — synced to plena_settings.toml before running. """ # VLEN must equal mlen so the emulator's row-address alignment check passes. - update_plena_config(vlen=mlen, mlen=mlen, blen=blen, verbose=False) + # When PLENA_SETTINGS_TOML is set, the caller owns the full generated config. + if "PLENA_SETTINGS_TOML" not in os.environ: + update_plena_config(vlen=mlen, mlen=mlen, blen=blen, verbose=False) print("\n--- Running Rust transactional emulator ---") run_emulator(build_dir) diff --git a/transactional_emulator/testbench/models/clm60m_rtl_config_test.py b/transactional_emulator/testbench/models/clm60m_rtl_config_test.py new file mode 100644 index 0000000..39f14b7 --- /dev/null +++ b/transactional_emulator/testbench/models/clm60m_rtl_config_test.py @@ -0,0 +1,72 @@ +"""Run a CLM-60M sliced decoder test using the local PLENA_RTL config.""" + +from __future__ import annotations + +import argparse +from pathlib import Path +import sys + + +REPO_ROOT = Path(__file__).resolve().parents[3] +for path in (REPO_ROOT, REPO_ROOT / "tools", REPO_ROOT / "PLENA_Compiler"): + path_str = str(path) + if path_str not in sys.path: + sys.path.insert(0, path_str) + + +def main() -> int: + from transactional_emulator.testbench.rtl_config import default_rtl_root, rtl_plena_settings + from transactional_emulator.testbench.sliced_layer_test_builder import build_and_run_sliced_decoder_layer_test + + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--rtl-root", type=Path, default=None, help="Path to local PLENA_RTL checkout") + parser.add_argument("--seq-len", type=int, default=None, help="Sequence length; defaults to RTL MLEN") + parser.add_argument("--hidden-size", type=int, default=None, help="Sliced hidden size; defaults to RTL MLEN") + parser.add_argument("--inter-dim", type=int, default=None, help="Sliced FFN intermediate; defaults to 2*MLEN") + parser.add_argument("--layer-idx", type=int, default=0) + parser.add_argument("--build-dir", type=Path, default=Path("/tmp/clm60m_rtl_config")) + args = parser.parse_args() + + rtl_root = (args.rtl_root or default_rtl_root(REPO_ROOT)).resolve() + plena_toml = REPO_ROOT / "plena_settings.toml" + + with rtl_plena_settings(plena_toml, rtl_root) as rtl: + mlen = rtl["MLEN"] + blen = rtl["BLEN"] + seq_len = args.seq_len or mlen + hidden_size = args.hidden_size or mlen + inter_dim = args.inter_dim or (2 * mlen) + + print("Using RTL configuration:") + print(f" rtl_root={rtl_root}") + print( + f" MLEN={rtl['MLEN']} VLEN={rtl['VLEN']} BLEN={rtl['BLEN']} " + f"HLEN={rtl['HLEN']} BROADCAST_AMOUNT={rtl['BROADCAST_AMOUNT']}" + ) + print( + f" M_FP=e{rtl['M_FP_EXP_WIDTH']}m{rtl['M_FP_MANT_WIDTH']} " + f"V_FP=e{rtl['V_FP_EXP_WIDTH']}m{rtl['V_FP_MANT_WIDTH']} " + f"S_FP=e{rtl['S_FP_EXP_WIDTH']}m{rtl['S_FP_MANT_WIDTH']}" + ) + print( + f" CLM-60M sliced decoder: seq_len={seq_len}, " + f"hidden_size={hidden_size}, inter_dim={inter_dim}" + ) + + build_and_run_sliced_decoder_layer_test( + model_id="AICrossSim/clm-60m", + asm_name=f"clm60m_rtl_m{mlen}_b{blen}", + build_dir=args.build_dir, + layer_idx=args.layer_idx, + seq_len=seq_len, + hidden_size=hidden_size, + inter_dim=inter_dim, + mlen=mlen, + blen=blen, + ) + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/transactional_emulator/testbench/rtl_config.py b/transactional_emulator/testbench/rtl_config.py new file mode 100644 index 0000000..76adcaf --- /dev/null +++ b/transactional_emulator/testbench/rtl_config.py @@ -0,0 +1,167 @@ +"""Bridge a local PLENA_RTL configuration into the simulator TOML.""" + +from __future__ import annotations + +from collections.abc import Iterator +from contextlib import contextmanager +import os +from pathlib import Path +import re + +import tomlkit + + +_PARAM_RE = re.compile(r"\s*(?:localparam|parameter)\s+(?:[\w]+\s+)*(?P\w+)\s*=\s*(?P[^;]+);") + + +def _load_svh_ints(path: Path) -> dict[str, int]: + values: dict[str, int] = {} + for line in path.read_text().splitlines(): + line = line.split("//", 1)[0] + match = _PARAM_RE.match(line) + if not match: + continue + value = match.group("value").strip() + try: + values[match.group("name")] = int(value, 0) + except ValueError: + continue + return values + + +def _set_value(config, section: str, name: str, value: int) -> None: + if name in config["BEHAVIOR"][section]: + config["BEHAVIOR"][section][name]["value"] = value + + +def _set_plain_fp(node, *, sign: bool, exponent: int, mantissa: int) -> None: + node["format"] = "Plain" + node["DATA_TYPE"]["type"] = "Fp" + node["DATA_TYPE"]["sign"] = sign + node["DATA_TYPE"]["exponent"] = exponent + node["DATA_TYPE"]["mantissa"] = mantissa + + +def _set_mx_fp(node, *, block: int, elem_exp: int, elem_mant: int, scale_exp: int) -> None: + node["format"] = "Mx" + node["block"] = block + node["ELEM"]["type"] = "Fp" + node["ELEM"]["sign"] = True + node["ELEM"]["exponent"] = elem_exp + node["ELEM"]["mantissa"] = elem_mant + node["SCALE"]["type"] = "Fp" + node["SCALE"]["sign"] = False + node["SCALE"]["exponent"] = scale_exp + node["SCALE"]["mantissa"] = 0 + + +def _set_scalar_fp(node, *, sign: bool, exponent: int, mantissa: int) -> None: + node["type"] = "Fp" + node["sign"] = sign + node["exponent"] = exponent + node["mantissa"] = mantissa + + +def default_rtl_root(repo_root: Path) -> Path: + env_path = os.environ.get("PLENA_RTL_LOCAL") or os.environ.get("PLENA_RTL_ROOT") + if env_path: + return Path(env_path).expanduser().resolve() + return (repo_root.parent / "PLENA_RTL").resolve() + + +def apply_rtl_settings_to_toml(config, rtl_root: Path) -> dict[str, int]: + definitions = rtl_root / "src" / "definitions" + config_svh = definitions / "configuration.svh" + precision_svh = definitions / "precision.svh" + if not config_svh.exists() or not precision_svh.exists(): + raise FileNotFoundError(f"Missing RTL configuration files under {definitions}") + + rtl_config = _load_svh_ints(config_svh) + rtl_precision = _load_svh_ints(precision_svh) + + mlen = rtl_config["MLEN"] + vlen = rtl_config["VLEN"] + blen = rtl_config["BLEN"] + hlen = rtl_config.get("HLEN", config["BEHAVIOR"]["CONFIG"]["HLEN"]["value"]) + broadcast_amount = rtl_config.get("BROADCAST_AMOUNT", mlen // hlen) + + _set_value(config, "CONFIG", "MLEN", mlen) + _set_value(config, "CONFIG", "VLEN", vlen) + _set_value(config, "CONFIG", "BLEN", blen) + _set_value(config, "CONFIG", "HLEN", hlen) + _set_value(config, "CONFIG", "BROADCAST_AMOUNT", broadcast_amount) + _set_value(config, "CONFIG", "HBM_M_Prefetch_Amount", rtl_config.get("HBM_M_Prefetch_Amount", 16)) + _set_value(config, "CONFIG", "HBM_V_Prefetch_Amount", rtl_config.get("HBM_V_Prefetch_Amount", 16)) + _set_value(config, "CONFIG", "HBM_V_Writeback_Amount", rtl_config.get("HBM_V_Writeback_Amount", 4)) + _set_value(config, "CONFIG", "MATRIX_SRAM_SIZE", rtl_config.get("MATRIX_SRAM_DEPTH", 1024)) + _set_value(config, "CONFIG", "VECTOR_SRAM_SIZE", rtl_config.get("VECTOR_SRAM_DEPTH", 1024)) + + wt_exp = rtl_precision.get("WT_MX_EXP_WIDTH", 4) + wt_mant = rtl_precision.get("WT_MX_MANT_WIDTH", 3) + kv_exp = rtl_precision.get("KV_MX_EXP_WIDTH", wt_exp) + kv_mant = rtl_precision.get("KV_MX_MANT_WIDTH", wt_mant) + act_exp = rtl_precision.get("ACT_MXFP_EXP_WIDTH", 4) + act_mant = rtl_precision.get("ACT_MXFP_MANT_WIDTH", 3) + scale_exp = rtl_precision.get("MX_SCALE_WIDTH", 8) + block = rtl_precision.get("BLOCK_DIM", 4) + + raw_hbm_width = (wt_mant + wt_exp + 1) * mlen + hbm_width = 1 << ((raw_hbm_width * 2 - 1).bit_length()) + _set_value(config, "CONFIG", "HBM_WIDTH", rtl_config.get("HBM_WIDTH", hbm_width)) + + precision = config["BEHAVIOR"]["PRECISION"] + _set_plain_fp( + precision["MATRIX_SRAM_TYPE"], + sign=True, + exponent=rtl_precision.get("M_FP_EXP_WIDTH", 8), + mantissa=rtl_precision.get("M_FP_MANT_WIDTH", 7), + ) + _set_plain_fp( + precision["VECTOR_SRAM_TYPE"], + sign=True, + exponent=rtl_precision.get("V_FP_EXP_WIDTH", 8), + mantissa=rtl_precision.get("V_FP_MANT_WIDTH", 7), + ) + _set_mx_fp(precision["HBM_M_WEIGHT_TYPE"], block=block, elem_exp=wt_exp, elem_mant=wt_mant, scale_exp=scale_exp) + _set_mx_fp(precision["HBM_M_KV_TYPE"], block=block, elem_exp=kv_exp, elem_mant=kv_mant, scale_exp=scale_exp) + _set_mx_fp(precision["HBM_V_ACT_TYPE"], block=block, elem_exp=act_exp, elem_mant=act_mant, scale_exp=scale_exp) + _set_mx_fp(precision["HBM_V_KV_TYPE"], block=block, elem_exp=kv_exp, elem_mant=kv_mant, scale_exp=scale_exp) + _set_scalar_fp( + precision["SCALAR_FP"], + sign=True, + exponent=rtl_precision.get("S_FP_EXP_WIDTH", 8), + mantissa=rtl_precision.get("S_FP_MANT_WIDTH", 7), + ) + + return { + "MLEN": mlen, + "VLEN": vlen, + "BLEN": blen, + "HLEN": hlen, + "BROADCAST_AMOUNT": broadcast_amount, + "M_FP_EXP_WIDTH": rtl_precision.get("M_FP_EXP_WIDTH", 8), + "M_FP_MANT_WIDTH": rtl_precision.get("M_FP_MANT_WIDTH", 7), + "V_FP_EXP_WIDTH": rtl_precision.get("V_FP_EXP_WIDTH", 8), + "V_FP_MANT_WIDTH": rtl_precision.get("V_FP_MANT_WIDTH", 7), + "S_FP_EXP_WIDTH": rtl_precision.get("S_FP_EXP_WIDTH", 8), + "S_FP_MANT_WIDTH": rtl_precision.get("S_FP_MANT_WIDTH", 7), + } + + +@contextmanager +def rtl_plena_settings(plena_toml: Path, rtl_root: Path) -> Iterator[dict[str, int]]: + """Temporarily rewrite plena_settings.toml from RTL .svh files and restore it.""" + original = plena_toml.read_text() + previous_env = os.environ.get("PLENA_SETTINGS_TOML") + try: + config = tomlkit.loads(original) + summary = apply_rtl_settings_to_toml(config, rtl_root) + plena_toml.write_text(tomlkit.dumps(config)) + os.environ["PLENA_SETTINGS_TOML"] = str(plena_toml) + yield summary + finally: + plena_toml.write_text(original) + if previous_env is None: + os.environ.pop("PLENA_SETTINGS_TOML", None) + else: + os.environ["PLENA_SETTINGS_TOML"] = previous_env diff --git a/transactional_emulator/testbench/sim_env_utils.py b/transactional_emulator/testbench/sim_env_utils.py new file mode 100644 index 0000000..af2a77a --- /dev/null +++ b/transactional_emulator/testbench/sim_env_utils.py @@ -0,0 +1,228 @@ +"""Build simulator memory artifacts from testbench tensor files.""" + +from __future__ import annotations + +import logging +import os +import sys +from pathlib import Path + +import torch + + +REPO_ROOT = Path(__file__).resolve().parents[2] +TOOLS_PATH = str(REPO_ROOT / "tools") +if TOOLS_PATH not in sys.path: + sys.path.insert(0, TOOLS_PATH) + +from compiler.assembler.assembly_to_binary import AssemblyToBinary # noqa: E402 +from memory_mapping.memory_map import ( # noqa: E402 + map_mx_data_to_hbm_for_behave_sim, + map_normal_data_to_hbm_for_behave_sim, +) +from memory_mapping.rand_gen import RandomMxfpTensorGenerator # noqa: E402 +from utils.load_config import load_toml_config # noqa: E402 +from utils.logger import get_logger # noqa: E402 + + +logger = get_logger("testbench") +logger.setLevel(logging.DEBUG) + + +class MemoryDataManager: + """Collect MX and integer memory payloads for HBM setup.""" + + def __init__(self) -> None: + self.mx_entries = [] + self.int_entries = [] + + def add_mx_file(self, filename, blocks, bias, quant_config) -> None: + self.mx_entries.append( + { + "filename": filename, + "type": "mx", + "blocks": blocks, + "bias": bias, + "quant_config": quant_config, + } + ) + + def add_int_file(self, filename, data) -> None: + self.int_entries.append({"filename": filename, "type": "int", "data": data}) + + def get_all_entries(self): + return [*self.mx_entries, *self.int_entries] + + +def _mx_quant_config(precision_node, precision_settings): + return { + "exp_width": precision_node["ELEM"]["exponent"], + "man_width": precision_node["ELEM"]["mantissa"], + "exp_bias_width": precision_node["SCALE"]["exponent"], + "block_size": [1, precision_node["block"]], + "int_width": precision_settings["HBM_V_INT_TYPE"]["DATA_TYPE"]["width"], + "skip_first_dim": False, + } + + +def _precision_for_tensor(stem: str, precision_settings): + if stem == "V" or stem.startswith("V_"): + return precision_settings["HBM_M_KV_TYPE"] + if stem == "K" or stem.startswith(("K_", "W_")): + return precision_settings["HBM_M_WEIGHT_TYPE"] + return precision_settings["HBM_V_ACT_TYPE"] + + +def create_mem_for_sim( + data_size=256, + mode="behave_sim", + asm="attn", + data=None, + specified_data_order=None, + build_path=None, +): + plena_toml_path = str(REPO_ROOT / "plena_settings.toml") + config_settings = load_toml_config(plena_toml_path, "CONFIG") + precision_settings = load_toml_config(plena_toml_path, "PRECISION") + + if mode == "behave_sim": + target_dir = Path(build_path) if build_path is not None else REPO_ROOT / "transactional_emulator/testbench/build" + asm_file = target_dir / "generated_asm_code.asm" + else: + asm_file = REPO_ROOT / "test" / "Instr_Level_Benchmark" / f"{asm}.asm" + target_dir = asm_file.parent + + init_mem(asm_file.parent) + + data_config = { + "tensor_size": [1, data_size], + "block_size": [1, precision_settings["HBM_M_WEIGHT_TYPE"]["block"]], + } + quant_config = { + "exp_width": precision_settings["HBM_V_ACT_TYPE"]["ELEM"]["exponent"], + "man_width": precision_settings["HBM_V_ACT_TYPE"]["ELEM"]["mantissa"], + "exp_bias_width": precision_settings["HBM_V_ACT_TYPE"]["SCALE"]["exponent"], + "block_size": data_config["block_size"], + "int_width": precision_settings["HBM_V_INT_TYPE"]["DATA_TYPE"]["width"], + "skip_first_dim": False, + } + + memory_data_manager = MemoryDataManager() + if mode != "behave_sim": + raw_data = RandomMxfpTensorGenerator( + shape=tuple(data_config["tensor_size"]), + quant_config=quant_config, + config_settings=config_settings, + directory=asm_file.parent, + filename=Path(f"{asm}/fake_test_raw_data.pt"), + ) + raw_data.tensor_gen() + raw_tensor = raw_data.tensor_load() + blocks, bias = raw_data.quantize_tensor(raw_tensor) + memory_data_manager.add_mx_file("fake_test_raw_data.pt", blocks, bias, quant_config) + else: + if specified_data_order is not None: + pt_files = [target_dir / f"{name}.pt" for name in specified_data_order] + else: + pt_files = list(target_dir.glob("*.pt")) + list(target_dir.glob("*.pth")) + + for pt_file in pt_files: + if pt_file.stem == "int": + memory_data_manager.add_int_file(pt_file.name, torch.load(pt_file)) + continue + + file_quant_config = _mx_quant_config(_precision_for_tensor(pt_file.stem, precision_settings), precision_settings) + file_raw_data = RandomMxfpTensorGenerator( + shape=tuple(data_config["tensor_size"]), + quant_config=file_quant_config, + config_settings=config_settings, + directory=asm_file.parent, + filename=pt_file, + ) + file_tensor = file_raw_data.tensor_load() + blocks, bias = file_raw_data.quantize_tensor(file_tensor) + memory_data_manager.add_mx_file(pt_file.name, blocks, bias, file_quant_config) + + env_setup( + memory_data_manager, + asm_file.parent, + data_config, + quant_config, + hbm_row_width=config_settings["HBM_WIDTH"]["value"], + logical_row_elements=config_settings["MLEN"]["value"], + ) + + +def env_setup( + memory_data_manager, + build_path: Path, + data_config, + quant_config, + hbm_row_width=256, + logical_row_elements=None, +) -> None: + isa_file_path = REPO_ROOT / "PLENA_Compiler" / "doc" / "operation.svh" + config_file_path = REPO_ROOT / "PLENA_Compiler" / "doc" / "configuration.svh" + + assembler = AssemblyToBinary(str(isa_file_path), str(config_file_path)) + assembler.generate_binary(build_path / "generated_asm_code.asm", build_path / "generated_machine_code.mem") + + for entry in memory_data_manager.get_all_entries(): + if entry["type"] == "mx": + entry_quant_config = entry.get("quant_config", quant_config) + map_mx_data_to_hbm_for_behave_sim( + blocks=entry["blocks"], + element_width=entry_quant_config["exp_width"] + entry_quant_config["man_width"] + 1, + block_width=entry_quant_config["block_size"][1], + bias=entry["bias"], + bias_width=entry_quant_config["exp_bias_width"], + directory=build_path, + append=True, + hbm_row_width=hbm_row_width, + logical_row_elements=logical_row_elements, + ) + elif entry["type"] == "int": + map_normal_data_to_hbm_for_behave_sim( + data=entry["data"], + data_width=quant_config["int_width"], + directory=build_path, + append=True, + hbm_row_width=hbm_row_width, + ) + + +def init_mem(build_path: Path) -> None: + build_path.mkdir(parents=True, exist_ok=True) + + hbm_bin_file = build_path / "hbm_for_behave_sim.bin" + if hbm_bin_file.exists(): + hbm_bin_file.unlink() + + hbm_element_file = build_path / "hbm_ele.mem" + hbm_scale_file = build_path / "hbm_scale.mem" + hbm_file_for_behave_sim = build_path / "hbm_for_behave_sim.mem" + instr_file = build_path / "machine_code.mem" + + os.environ["HBM_ELEMENT_FILE"] = str(hbm_element_file) + os.environ["HBM_SCALE_FILE"] = str(hbm_scale_file) + os.environ["HBM_FOR_BEHAVE_SIM_FILE"] = str(hbm_file_for_behave_sim) + os.environ["INSTR_FILE"] = str(instr_file) + + hbm_write_element_m_file = build_path / "hbm_write_m_ele.mem" + hbm_write_element_v_file = build_path / "hbm_write_v_ele.mem" + hbm_write_scale_m_file = build_path / "hbm_write_m_scale.mem" + hbm_write_scale_v_file = build_path / "hbm_write_v_scale.mem" + vector_mem_result_file = build_path / "vector_result.mem" + + hbm_write_element_m_file.touch() + hbm_write_element_v_file.touch() + hbm_write_scale_m_file.touch() + hbm_write_scale_v_file.touch() + vector_mem_result_file.touch() + + os.environ["VECTOR_MEM_RESULT_FILE"] = str(vector_mem_result_file) + os.environ["FAKE_HBM_ELEMENT_WRITE_M_FILE"] = str(hbm_write_element_m_file) + os.environ["FAKE_HBM_ELEMENT_WRITE_V_FILE"] = str(hbm_write_element_v_file) + os.environ["FAKE_HBM_SCALE_WRITE_M_FILE"] = str(hbm_write_scale_m_file) + os.environ["FAKE_HBM_SCALE_WRITE_V_FILE"] = str(hbm_write_scale_v_file) + os.environ["ASM_FILE"] = str(instr_file) diff --git a/transactional_emulator/testbench/sliced_layer_test_builder.py b/transactional_emulator/testbench/sliced_layer_test_builder.py index 6a8cd01..0f3c424 100644 --- a/transactional_emulator/testbench/sliced_layer_test_builder.py +++ b/transactional_emulator/testbench/sliced_layer_test_builder.py @@ -12,11 +12,13 @@ """ import sys +import os from dataclasses import dataclass from pathlib import Path import json +import tomlkit import torch import torch.nn.functional as F @@ -26,7 +28,7 @@ from compiler.aten.plena import PlenaCompiler from transactional_emulator.tools.create_sim_env import create_sim_env -from compiler.sim_env_utils import create_mem_for_sim +from transactional_emulator.testbench.sim_env_utils import create_mem_for_sim from transactional_emulator.testbench.emulator_runner import run_and_assert @@ -154,22 +156,100 @@ def load_ffn_weights( # --------------------------------------------------------------------------- -# MXFP8 quantization +# Active precision helpers # --------------------------------------------------------------------------- -def quantize_to_mxfp(tensor: torch.Tensor) -> torch.Tensor: - """Quantize tensor to MXFP8 matching HBM hardware format; return dequantized result.""" + + +def _active_precision_settings(): + config_path = Path(os.environ.get("PLENA_SETTINGS_TOML", Path(__file__).parents[2] / "plena_settings.toml")) + with open(config_path) as f: + return tomlkit.load(f)["BEHAVIOR"]["PRECISION"] + + +def _quantize_plain_fp_no_specials(tensor: torch.Tensor, *, exponent: int, mantissa: int, sign: bool) -> torch.Tensor: + x = tensor.float() + out = torch.zeros_like(x) + finite = torch.isfinite(x) & (x != 0) + if not sign: + finite &= x > 0 + if not torch.any(finite): + return out + + values = x[finite].abs() + exp_bias = (1 << exponent) // 2 - 1 + exp_min = -exp_bias + exp_max = (1 << exponent) - 2 - exp_bias + raw_exp = torch.floor(torch.log2(values + 1e-9)) + overflow = raw_exp > exp_max + clamped_exp = torch.clamp(raw_exp, exp_min, exp_max) + + shift = 1 << mantissa + scaled = values / torch.pow(torch.tensor(2.0, device=x.device), clamped_exp) + subnormal = clamped_exp == exp_min + shifted = torch.where(subnormal, scaled * shift, (scaled - 1.0) * shift) + shifted = torch.round(shifted).clamp(0, shift - 1) + shifted = torch.where(overflow, torch.full_like(shifted, shift - 1), shifted) + + exp_bits = clamped_exp + exp_bias + decoded_exp = exp_bits - exp_bias + decoded_base = torch.where(exp_bits == 0, shifted / shift, 1.0 + shifted / shift) + decoded = decoded_base * torch.pow(torch.tensor(2.0, device=x.device), decoded_exp) + decoded = torch.where(x[finite] < 0, -decoded, decoded) + out[finite] = decoded + return out + + +def quantize_to_vector_fp(tensor: torch.Tensor, precision=None) -> torch.Tensor: + """Quantize through the active VECTOR_SRAM_TYPE plain FP format.""" + precision = precision or _active_precision_settings() + data_type = precision["VECTOR_SRAM_TYPE"]["DATA_TYPE"] + exponent = int(data_type["exponent"]) + mantissa = int(data_type["mantissa"]) + sign = bool(data_type.get("sign", True)) + if sign and exponent == 8 and mantissa == 7: + return tensor.float().to(torch.bfloat16).float() + if sign and exponent == 5 and mantissa == 10: + return tensor.float().to(torch.float16).float() + if sign and exponent == 8 and mantissa == 23: + return tensor.float() + return _quantize_plain_fp_no_specials(tensor, exponent=exponent, mantissa=mantissa, sign=sign) + + +def quantize_to_mxfp(tensor: torch.Tensor, precision_node=None) -> torch.Tensor: + """Quantize tensor to the configured HBM MXFP format; return dequantized result.""" + if precision_node is None: + width = 8 + exponent_width = 4 + exponent_bias_width = 8 + block_size = [1, 8] + else: + width = int(precision_node["ELEM"]["exponent"]) + int(precision_node["ELEM"]["mantissa"]) + 1 + exponent_width = int(precision_node["ELEM"]["exponent"]) + exponent_bias_width = int(precision_node["SCALE"]["exponent"]) + block_size = [1, int(precision_node["block"])] + orig_shape = tensor.shape tensor_2d = tensor.float().reshape(-1, tensor.shape[-1]) bm_x, _, _, _ = _mx_fp_quantize_hardware( tensor_2d, - width=8, - exponent_width=4, - exponent_bias_width=8, - block_size=[1, 8], + width=width, + exponent_width=exponent_width, + exponent_bias_width=exponent_bias_width, + block_size=block_size, ) return bm_x.reshape(orig_shape) +def _load_to_vector_fp(tensor: torch.Tensor, hbm_precision, vector_precision) -> torch.Tensor: + return quantize_to_vector_fp(quantize_to_mxfp(tensor, hbm_precision), vector_precision) + + +def _rms_norm_vector_ref(x: torch.Tensor, eps: float, precision) -> torch.Tensor: + x_q = quantize_to_vector_fp(x, precision) + rms = quantize_to_vector_fp(torch.rsqrt(x_q.float().pow(2).mean(-1, keepdim=True) + eps), precision) + return quantize_to_vector_fp(x_q * rms, precision) + + # --------------------------------------------------------------------------- # Hardware-accurate golden reference # --------------------------------------------------------------------------- @@ -676,12 +756,19 @@ def build_and_run_sliced_decoder_layer_test( cos, sin = _make_rope_tables(seq_len, head_dim, theta=rope_theta) - # Precompute Q_rot from bfloat16-approximated intermediate - # PLENA computes embedding_add + rms_norm in bfloat16; Q_rot must match - X_embed_bf16 = token_embeds.to(torch.bfloat16) + pos_weight.to(torch.bfloat16) - rms_bf16 = torch.rsqrt(X_embed_bf16.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) - X_norm_bf16 = X_embed_bf16 * rms_bf16 - Q_rot = _rotate_half(X_norm_bf16.float()) + precision = _active_precision_settings() + hbm_act_precision = precision["HBM_V_ACT_TYPE"] + hbm_weight_precision = precision["HBM_M_WEIGHT_TYPE"] + hbm_kv_precision = precision["HBM_M_KV_TYPE"] + + # Precompute Q_rot from the same vector precision path used by the simulator. + X_embed_q = quantize_to_vector_fp( + _load_to_vector_fp(token_embeds, hbm_act_precision, precision) + + _load_to_vector_fp(pos_weight, hbm_act_precision, precision), + precision, + ) + X_norm_q = _rms_norm_vector_ref(X_embed_q, eps, precision) + Q_rot = _rotate_half(X_norm_q.float()) print(f"\ntoken_embeds: {token_embeds.shape}, range [{token_embeds.min():.3f}, {token_embeds.max():.3f}]") print(f"pos_weight: {pos_weight.shape}, range [{pos_weight.min():.3f}, {pos_weight.max():.3f}]") @@ -696,33 +783,25 @@ def build_and_run_sliced_decoder_layer_test( print(f"\nattn_scale: {scale:.6f}") # ----------------------------------------------------------- golden ref - # Apply MXFP8 quantization to all HBM-stored tensors (matching hardware storage). - # K/V from real weights can have values up to ±15 — coarse quantization at that - # scale causes large attention errors unless the golden accounts for it. - K_q = quantize_to_mxfp(K_mat) - V_q = quantize_to_mxfp(V_mat) - W_gate_q = quantize_to_mxfp(W_gate) - W_up_q = quantize_to_mxfp(W_up) - W_down_q = quantize_to_mxfp(W_down) - - print("\n--- CPU Golden Reference (MXFP8 quantized HBM tensors + BF16 intermediates) ---") - - X_gold = token_embeds.clone() - X_gold = X_gold + pos_weight # embedding_add - # Use bfloat16 rms_norm to match PLENA's quantised intermediate - X_gold_bf16 = X_gold.to(torch.bfloat16) - rms_gold = torch.rsqrt(X_gold_bf16.float().pow(2).mean(-1, keepdim=True) + eps).to(torch.bfloat16) - X_gold = (X_gold_bf16 * rms_gold).float() # rms_norm (bfloat16) - Q_rot_gold = _rotate_half(X_gold) # consistent Q_rot - X_gold = X_gold * cos + Q_rot_gold * sin # rope - X_gold = _flash_attn_ref(X_gold, K_q, V_q, scale) # flash_attn (MXFP8 K/V) - # FFN with MXFP8 weights + BF16 intermediates (matches hardware VRAM storage) - X_gold_attn = X_gold.to(torch.bfloat16) - up_out = torch.matmul(X_gold_attn.float(), W_up_q.float()).to(torch.bfloat16) - gate_out = torch.matmul(X_gold_attn.float(), W_gate_q.float()).to(torch.bfloat16) - silu_gate = (F.silu(up_out.float()) * gate_out.float()).to(torch.bfloat16) - X_gold = torch.matmul(silu_gate.float(), W_down_q.float()).to(torch.bfloat16).float() - X_gold = _rms_norm_ref(X_gold, eps) # final rms_norm + K_q = quantize_to_mxfp(K_mat, hbm_weight_precision) + V_q = quantize_to_mxfp(V_mat, hbm_kv_precision) + W_gate_q = quantize_to_mxfp(W_gate, hbm_weight_precision) + W_up_q = quantize_to_mxfp(W_up, hbm_weight_precision) + W_down_q = quantize_to_mxfp(W_down, hbm_weight_precision) + + print("\n--- CPU Golden Reference (active TOML HBM + vector precision) ---") + + X_gold = X_embed_q + Q_rot_gold = _load_to_vector_fp(Q_rot, hbm_act_precision, precision) + cos_q = _load_to_vector_fp(cos, hbm_act_precision, precision) + sin_q = _load_to_vector_fp(sin, hbm_act_precision, precision) + X_gold = quantize_to_vector_fp(X_gold * cos_q + Q_rot_gold * sin_q, precision) # rope + X_gold = quantize_to_vector_fp(_flash_attn_ref(X_gold, K_q, V_q, scale), precision) + up_out = quantize_to_vector_fp(torch.matmul(X_gold.float(), W_up_q.float()), precision) + gate_out = quantize_to_vector_fp(torch.matmul(X_gold.float(), W_gate_q.float()), precision) + silu_gate = quantize_to_vector_fp(F.silu(up_out.float()) * gate_out.float(), precision) + X_gold = quantize_to_vector_fp(torch.matmul(silu_gate.float(), W_down_q.float()), precision) + X_gold = _rms_norm_vector_ref(X_gold, eps, precision) golden_out = X_gold print(f" golden_out: {golden_out.shape}") diff --git a/transactional_emulator/tools/check_mem.py b/transactional_emulator/tools/check_mem.py index 34056cd..f7e90e4 100644 --- a/transactional_emulator/tools/check_mem.py +++ b/transactional_emulator/tools/check_mem.py @@ -39,7 +39,14 @@ def parse_golden_output(golden_file_path): def read_bin_file_as_array( - bin_file, exp_width, man_width, row_dim, num_bytes_per_val=2, start_row_idx=0, num_rows=None + bin_file, + exp_width, + man_width, + row_dim, + num_bytes_per_val=2, + start_row_idx=0, + num_rows=None, + bits_per_val=None, ): """ Read binary file and convert to numpy array (similar to view_bin_file_by_row but returns array). @@ -59,7 +66,8 @@ def read_bin_file_as_array( """ sign_width = 1 total_width = sign_width + exp_width + man_width - if total_width > num_bytes_per_val * 8: + storage_width = bits_per_val if bits_per_val is not None else num_bytes_per_val * 8 + if total_width > storage_width: raise ValueError("num_bytes_per_val is too small for given bit widths.") def raw_to_fp(bits_val): @@ -92,7 +100,7 @@ def raw_to_fp(bits_val): with open(bin_file, "rb") as f: data = f.read() - num_vals = len(data) // num_bytes_per_val + num_vals = (len(data) * 8) // storage_width total_rows = (num_vals + row_dim - 1) // row_dim # Calculate which rows to read @@ -106,11 +114,16 @@ def raw_to_fp(bits_val): if val_idx >= num_vals: # Reached end of data, pad with None or break break - chunk = data[val_idx * num_bytes_per_val : (val_idx + 1) * num_bytes_per_val] - if not chunk or len(chunk) < num_bytes_per_val: + bit_offset = val_idx * storage_width + byte_offset = bit_offset // 8 + bit_shift = bit_offset % 8 + bytes_needed = (bit_shift + storage_width + 7) // 8 + chunk = data[byte_offset : byte_offset + bytes_needed] + if not chunk or len(chunk) < bytes_needed: break # Use little-endian byte order to match Rust's byte packing - bits_val = int.from_bytes(chunk, byteorder="little") + raw = int.from_bytes(chunk, byteorder="little") + bits_val = (raw >> bit_shift) & ((1 << total_width) - 1) float_val = raw_to_fp(bits_val) values.append(float_val) @@ -204,6 +217,7 @@ def compare_vram_with_golden( exp_width=8, man_width=7, num_bytes_per_val=2, + bits_per_val=None, row_dim=64, start_row_idx=0, num_batches=4, @@ -255,7 +269,14 @@ def compare_vram_with_golden( # Read binary file (now properly handles row-based indexing) simulated_np = read_bin_file_as_array( - bin_file, exp_width, man_width, row_dim, num_bytes_per_val, start_row_idx, num_rows + bin_file, + exp_width, + man_width, + row_dim, + num_bytes_per_val, + start_row_idx, + num_rows, + bits_per_val=bits_per_val, ) # Apply slice mode: extract first slice_per_row elements from each row diff --git a/transactional_emulator/tools/create_sim_env.py b/transactional_emulator/tools/create_sim_env.py index 60aa869..8ce9965 100644 --- a/transactional_emulator/tools/create_sim_env.py +++ b/transactional_emulator/tools/create_sim_env.py @@ -23,6 +23,23 @@ def np_array_to_str_2f(arr): ) +def np_array_to_str_precise(arr): + arr = np.asarray(arr) + if arr.ndim == 1: + return "[" + " ".join([f"{v:.9g}" for v in arr]) + "]" + if arr.ndim == 2: + rows = [" " + " ".join([f"{v:.9g}" for v in row]) for row in arr] + return "[\n" + "\n".join(rows) + "\n]" + + import sys as _sys + + return np.array2string( + arr, + formatter={"float_kind": lambda x: f"{x:.9g}"}, + threshold=_sys.maxsize, + ) + + def create_sim_env( input_tensor, generated_code, @@ -76,7 +93,7 @@ def create_sim_env( f.write("\n\nOriginal Output:\n") # Convert BFloat16 to float32 before converting to numpy output_np = golden_result["original_output"].detach().cpu().float().numpy() - f.write(np_array_to_str_2f(output_np)) + f.write(np_array_to_str_precise(output_np)) if vram_preload is not None: # vram_preload: a flat tensor or numpy array of fp16 values representing diff --git a/uv.lock b/uv.lock index eb5864b..5d302c5 100644 --- a/uv.lock +++ b/uv.lock @@ -1,5 +1,5 @@ version = 1 -revision = 2 +revision = 3 requires-python = ">=3.12" [[package]] @@ -926,6 +926,7 @@ dependencies = [ { name = "pytest" }, { name = "ruff" }, { name = "toml" }, + { name = "tomlkit" }, { name = "torch" }, { name = "tqdm" }, { name = "transformers" }, @@ -942,6 +943,7 @@ requires-dist = [ { name = "pytest" }, { name = "ruff", specifier = ">=0.12" }, { name = "toml" }, + { name = "tomlkit", specifier = ">=0.15.0" }, { name = "torch", specifier = "==2.7.1+cu126", index = "https://download.pytorch.org/whl/cu126" }, { name = "tqdm" }, { name = "transformers" }, @@ -1356,6 +1358,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/44/6f/7120676b6d73228c96e17f1f794d8ab046fc910d781c8d151120c3f1569e/toml-0.10.2-py2.py3-none-any.whl", hash = "sha256:806143ae5bfb6a3c6e736a764057db0e6a0e05e338b5630894a5f779cabb4f9b", size = 16588, upload-time = "2020-11-01T01:40:20.672Z" }, ] +[[package]] +name = "tomlkit" +version = "0.15.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/51/db/03eaf4331631ef6b27d6e3c9b68c54dc6f0d63d87201fed600cc409307fd/tomlkit-0.15.0.tar.gz", hash = "sha256:7d1a9ecba3086638211b13814ea79c90dd54dd11993564376f3aa92271f5c7a3", size = 161875, upload-time = "2026-05-10T07:38:22.245Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/6a/43/8bd850ee71a191bf072e31302c73a66be413fecdd98fdcd111ecbcce13ca/tomlkit-0.15.0-py3-none-any.whl", hash = "sha256:4dbc8f0fc024412b57ced8757ac7461305126a648ff8c2c807fcb8e133a78738", size = 41328, upload-time = "2026-05-10T07:38:23.517Z" }, +] + [[package]] name = "torch" version = "2.7.1+cu126" @@ -1385,12 +1396,12 @@ dependencies = [ { name = "typing-extensions" }, ] wheels = [ - { url = "https://download.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:63bce0590bc540fc16139e2be0177847585182b8c5e68d7f9213789d1d96c978" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp312-cp312-win_amd64.whl", hash = "sha256:7d897b5ff67e778de4a2a05d4528377003105e29854fd73ecbe965287533f08b" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a05c0001fd1d0ceae9cda8c8c1b8a16ed5def858fe996c9237a28016559dad52" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313-win_amd64.whl", hash = "sha256:a38a903c9b55cea1217100e0851b25659765b6bb8cd75e6de6bbf0063a2cd51e" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:27d396231f33dc6103ba26ec6ec2ec5939d9850b599e32da711b038af272954e" }, - { url = "https://download.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313t-win_amd64.whl", hash = "sha256:d4e68a1aeb2a6272d0234b7575089fc70757a93d24dccde8e962a3b18aef77d1" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:63bce0590bc540fc16139e2be0177847585182b8c5e68d7f9213789d1d96c978", upload-time = "2025-06-03T18:29:29Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp312-cp312-win_amd64.whl", hash = "sha256:7d897b5ff67e778de4a2a05d4528377003105e29854fd73ecbe965287533f08b", upload-time = "2025-06-03T18:29:36Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:a05c0001fd1d0ceae9cda8c8c1b8a16ed5def858fe996c9237a28016559dad52", upload-time = "2025-06-03T18:29:44Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313-win_amd64.whl", hash = "sha256:a38a903c9b55cea1217100e0851b25659765b6bb8cd75e6de6bbf0063a2cd51e", upload-time = "2025-06-03T18:29:58Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:27d396231f33dc6103ba26ec6ec2ec5939d9850b599e32da711b038af272954e", upload-time = "2025-06-03T18:29:59Z" }, + { url = "https://download-r2.pytorch.org/whl/cu126/torch-2.7.1%2Bcu126-cp313-cp313t-win_amd64.whl", hash = "sha256:d4e68a1aeb2a6272d0234b7575089fc70757a93d24dccde8e962a3b18aef77d1", upload-time = "2025-06-03T18:30:01Z" }, ] [[package]]