Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions crates/backend/air/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,3 +6,6 @@ edition.workspace = true
[dependencies]
field = { path = "../field", package = "mt-field" }
poly = { path = "../poly", package = "mt-poly" }

[dev-dependencies]
koala-bear = { path = "../koala-bear", package = "mt-koala-bear" }
270 changes: 255 additions & 15 deletions crates/backend/air/src/symbolic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use core::iter::{Product, Sum};
use core::marker::PhantomData;
use core::ops::{Add, AddAssign, Mul, MulAssign, Neg, Sub, SubAssign};
use std::cell::RefCell;
use std::sync::atomic::{AtomicU32, Ordering};

use field::{Algebra, Field, InjectiveMonomial, PrimeCharacteristicRing};

Expand Down Expand Up @@ -73,37 +74,149 @@ pub struct SymbolicNode<F: Copy> {
pub rhs: SymbolicExpression<F>, // dummy (ZERO) for Neg
}

// We use an arena as a trick to allow SymbolicExpression to be Copy
// (ugly trick but fine in practice since SymbolicExpression is only used once at the start of the program)
/// Opaque handle into the thread-local symbolic arena.
///
/// Handles are scoped to a specific arena (thread) and generation (clear cycle).
/// Using a handle from a different thread or after the arena has been cleared will
/// produce a deterministic error instead of undefined behaviour.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct SymbolicNodeRef<F> {
arena_id: u32,
generation: u32,
offset: u32,
_phantom: PhantomData<fn() -> F>,
}

#[derive(Debug, Copy, Clone, PartialEq, Eq)]
pub enum SymbolicNodeAccessError {
WrongArena,
StaleGeneration,
OutOfBounds,
}

impl core::fmt::Display for SymbolicNodeAccessError {
fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
match self {
Self::WrongArena => {
write!(f, "symbolic node handle belongs to a different thread's arena")
}
Self::StaleGeneration => {
write!(f, "symbolic node handle is stale (arena was cleared)")
}
Self::OutOfBounds => write!(f, "symbolic node handle offset is out of bounds"),
}
}
}

impl std::error::Error for SymbolicNodeAccessError {}

#[derive(Debug)]
struct ArenaState {
arena_id: u32,
generation: u32,
bytes: Vec<u8>,
}

impl ArenaState {
fn new() -> Self {
Self {
arena_id: next_arena_id(),
generation: 0,
bytes: Vec::new(),
}
}
}

static NEXT_ARENA_ID: AtomicU32 = AtomicU32::new(1);

fn next_arena_id_after(id: u32) -> Option<u32> {
id.checked_add(1)
}

fn next_arena_id() -> u32 {
NEXT_ARENA_ID
.fetch_update(Ordering::Relaxed, Ordering::Relaxed, next_arena_id_after)
.expect("symbolic arena id overflow")
}

fn checked_arena_allocation_range(offset: usize, node_size: usize) -> (u32, usize) {
let end = offset
.checked_add(node_size)
.expect("symbolic arena allocation overflow");
let offset_u32 = u32::try_from(offset).expect("symbolic arena exceeded u32::MAX bytes");
u32::try_from(end).expect("symbolic arena exceeded u32::MAX bytes");
(offset_u32, end)
}

// We use an arena as a trick to allow SymbolicExpression to be Copy.
// Handles carry arena_id + generation so that stale or cross-thread use
// is caught deterministically instead of reading garbage bytes.
thread_local! {
static ARENA: RefCell<Vec<u8>> = const { RefCell::new(Vec::new()) };
static ARENA: RefCell<ArenaState> = RefCell::new(ArenaState::new());
}

fn alloc_node<F: Field>(node: SymbolicNode<F>) -> u32 {
fn clear_arena() {
ARENA.with(|arena| {
let mut bytes = arena.borrow_mut();
let mut state = arena.borrow_mut();
state.generation = state
.generation
.checked_add(1)
.expect("symbolic arena generation overflow");
state.bytes.clear();
});
}

fn alloc_node<F: Field>(node: SymbolicNode<F>) -> SymbolicNodeRef<F> {
ARENA.with(|arena| {
let mut state = arena.borrow_mut();
let node_size = std::mem::size_of::<SymbolicNode<F>>();
let idx = bytes.len();
bytes.resize(idx + node_size, 0);
let offset = state.bytes.len();
let (offset_u32, end) = checked_arena_allocation_range(offset, node_size);
state.bytes.resize(end, 0);
// SAFETY: We just resized the buffer to `end` bytes, so `offset..end` is valid.
unsafe {
std::ptr::write_unaligned(bytes.as_mut_ptr().add(idx) as *mut SymbolicNode<F>, node);
std::ptr::write_unaligned(state.bytes.as_mut_ptr().add(offset).cast::<SymbolicNode<F>>(), node);
}
SymbolicNodeRef {
arena_id: state.arena_id,
generation: state.generation,
offset: offset_u32,
_phantom: PhantomData,
}
idx as u32
})
}

pub fn get_node<F: Field>(idx: u32) -> SymbolicNode<F> {
pub fn try_get_node<F: Field>(handle: SymbolicNodeRef<F>) -> Result<SymbolicNode<F>, SymbolicNodeAccessError> {
ARENA.with(|arena| {
let bytes = arena.borrow();
unsafe { std::ptr::read_unaligned(bytes.as_ptr().add(idx as usize) as *const SymbolicNode<F>) }
let state = arena.borrow();
if state.arena_id != handle.arena_id {
return Err(SymbolicNodeAccessError::WrongArena);
}
if state.generation != handle.generation {
return Err(SymbolicNodeAccessError::StaleGeneration);
}
let offset = handle.offset as usize;
let node_size = std::mem::size_of::<SymbolicNode<F>>();
let end = offset
.checked_add(node_size)
.ok_or(SymbolicNodeAccessError::OutOfBounds)?;
if end > state.bytes.len() {
return Err(SymbolicNodeAccessError::OutOfBounds);
}
// SAFETY: We verified that `offset..end` is within the arena buffer.
Ok(unsafe { std::ptr::read_unaligned(state.bytes.as_ptr().add(offset).cast::<SymbolicNode<F>>()) })
})
}

pub fn get_node<F: Field>(handle: SymbolicNodeRef<F>) -> SymbolicNode<F> {
try_get_node(handle).expect("invalid or stale symbolic node handle")
}

#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum SymbolicExpression<F: Copy> {
Variable(SymbolicVariable<F>),
Constant(F),
Operation(u32), // index into thread-local arena
Operation(SymbolicNodeRef<F>),
}

impl<F: Field> Default for SymbolicExpression<F> {
Expand Down Expand Up @@ -321,8 +434,7 @@ pub fn get_symbolic_constraints_and_bus_data_values<F: Field, A: Air>(air: &A) -
where
A::ExtraData: Default,
{
// Clear the arena before building constraints
ARENA.with(|arena| arena.borrow_mut().clear());
clear_arena();

let mut builder = SymbolicAirBuilder::<F>::new(air.n_columns(), air.n_shift_columns());
air.eval(&mut builder, &Default::default());
Expand All @@ -332,3 +444,131 @@ where
builder.bus_data_values.unwrap(),
)
}

#[cfg(test)]
mod tests {
use super::*;
use koala_bear::KoalaBear;

type F = KoalaBear;

const _: () = {
const fn assert_copy<T: Copy>() {}
assert_copy::<SymbolicExpression<F>>();
assert_copy::<SymbolicNodeRef<F>>();
};

#[test]
fn roundtrip_alloc_get() {
clear_arena();
let a = SymbolicExpression::<F>::Constant(F::ONE);
let b = SymbolicExpression::<F>::Constant(F::TWO);
let handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Add,
lhs: a,
rhs: b,
});
let node = get_node::<F>(handle);
assert_eq!(node.op, SymbolicOperation::Add);
assert_eq!(node.lhs, a);
assert_eq!(node.rhs, b);
}

#[test]
fn stale_handle_rejected_after_clear() {
clear_arena();
let handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Mul,
lhs: SymbolicExpression::<F>::ONE,
rhs: SymbolicExpression::<F>::TWO,
});
assert!(try_get_node::<F>(handle).is_ok());
clear_arena();
assert!(matches!(
try_get_node::<F>(handle),
Err(SymbolicNodeAccessError::StaleGeneration)
));
}

#[test]
fn old_handle_cannot_read_new_generation_bytes() {
clear_arena();
let old_handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Add,
lhs: SymbolicExpression::<F>::ONE,
rhs: SymbolicExpression::<F>::TWO,
});
clear_arena();
let _new_handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Sub,
lhs: SymbolicExpression::<F>::ZERO,
rhs: SymbolicExpression::<F>::ONE,
});
assert!(matches!(
try_get_node::<F>(old_handle),
Err(SymbolicNodeAccessError::StaleGeneration)
));
}

#[test]
fn wrong_thread_handle_rejected() {
clear_arena();
let handle = alloc_node(SymbolicNode {
op: SymbolicOperation::Neg,
lhs: SymbolicExpression::<F>::ONE,
rhs: SymbolicExpression::<F>::ZERO,
});
let result = std::thread::spawn(move || try_get_node::<F>(handle)).join().unwrap();
assert!(matches!(result, Err(SymbolicNodeAccessError::WrongArena)));
}

#[test]
fn out_of_bounds_handle_rejected() {
clear_arena();
let bogus = SymbolicNodeRef::<F> {
arena_id: ARENA.with(|a| a.borrow().arena_id),
generation: ARENA.with(|a| a.borrow().generation),
offset: 999_999,
_phantom: PhantomData,
};
assert!(matches!(
try_get_node::<F>(bogus),
Err(SymbolicNodeAccessError::OutOfBounds)
));
}

#[test]
fn offset_truncation_detected() {
assert!(std::panic::catch_unwind(|| checked_arena_allocation_range(u32::MAX as usize, 1)).is_err());
}

#[test]
fn arena_id_overflow_detected() {
assert!(next_arena_id_after(u32::MAX).is_none());
}

#[test]
fn arithmetic_produces_valid_handles() {
clear_arena();
let var = SymbolicExpression::<F>::Variable(SymbolicVariable::new(0));
let c = SymbolicExpression::<F>::Constant(F::TWO);
let sum = var + c;
if let SymbolicExpression::Operation(handle) = sum {
let node = get_node::<F>(handle);
assert_eq!(node.op, SymbolicOperation::Add);
assert_eq!(node.lhs, var);
assert_eq!(node.rhs, c);
} else {
panic!("expected Operation variant from variable + constant");
}

let neg = -var;
if let SymbolicExpression::Operation(handle) = neg {
let node = get_node::<F>(handle);
assert_eq!(node.op, SymbolicOperation::Neg);
assert_eq!(node.lhs, var);
} else {
panic!("expected Operation variant from neg(variable)");
}
}
}
Loading