Skip to content

Commit 2e637f3

Browse files
committed
use bool for found
1 parent c1a94c7 commit 2e637f3

2 files changed

Lines changed: 18 additions & 26 deletions

File tree

src/lib.nr

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -111,7 +111,7 @@ where
111111
* @brief determine whether `target` is present in `self.keys`
112112
* @details if `found == false`, `self.keys[found_index] < target < self.keys[found_index + 1]`
113113
**/
114-
unconstrained fn search_for_key(self, target: u32) -> (u32, u32) {
114+
unconstrained fn search_for_key(self, target: u32) -> (bool, u32) {
115115
let mut found = false;
116116
let mut found_index: u32 = 0;
117117
let mut previous_less_than_or_equal_to_target = false;
@@ -129,7 +129,7 @@ where
129129
}
130130
previous_less_than_or_equal_to_target = current_less_than_or_equal_to_target;
131131
}
132-
(found as u32, found_index)
132+
(found, found_index)
133133
}
134134

135135
/**
@@ -138,8 +138,6 @@ where
138138
**/
139139
fn get(self, idx: u32) -> T {
140140
let (found, found_index) = unsafe { self.search_for_key(idx) };
141-
// bool check. 0.25 gates cheaper than a raw `bool` type. need to fix at some point
142-
assert(found * found == found);
143141

144142
// OK! So we have the following cases to check
145143
// 1. if `found` then `self.keys[found_index] == idx`
@@ -150,13 +148,13 @@ where
150148
// combine the two into the following single statement:
151149
// `self.keys[found_index] + 1 - found <= idx <= self.keys[found_index + 1 - found] - 1 + found
152150
let lhs = self.keys[found_index];
153-
let rhs = self.keys[found_index + 1 - found];
154-
assert(lhs + 1 - found <= idx);
155-
assert(idx <= rhs + found - 1);
151+
let rhs = self.keys[found_index + 1 - found as u32];
152+
assert(lhs + 1 - found as u32 <= idx);
153+
assert(idx <= rhs + found as u32 - 1);
156154

157155
// self.keys[i] maps to self.values[i+1]
158156
// however...if we did not find a non-sparse entry, we want to return self.values[0] (the default value)
159-
let value_index = (found_index + 1) * found;
157+
let value_index = (found_index + 1) * found as u32;
160158
self.values[value_index]
161159
}
162160
}

src/mut_sparse_array.nr

Lines changed: 12 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ where
123123
r
124124
}
125125

126-
unconstrained fn search_for_key(self, target: u32) -> (u32, u32) {
126+
unconstrained fn search_for_key(self, target: u32) -> (bool, u32) {
127127
let mut found = false;
128128
let mut found_index = 0;
129129
let mut previous_less_than_or_equal_to_target = false;
@@ -146,19 +146,18 @@ where
146146
// }
147147
}
148148

149-
(found as u32, found_index)
149+
(found, found_index)
150150
}
151151

152-
unconstrained fn __check_if_can_insert(self, found: u32) {
152+
unconstrained fn __check_if_can_insert(self, found: bool) {
153153
assert(
154-
(found == 1) | (self.tail_ptr < N + 2),
154+
(found == true) | (self.tail_ptr < N + 2),
155155
"MutSparseArray::set exceeded maximum size of array",
156156
);
157157
}
158158

159159
fn set(&mut self, idx: u32, value: T) {
160160
let (found, found_index) = unsafe { self.search_for_key(idx) };
161-
assert(found * found == found);
162161

163162
// check can be unsafe because, if check fails, unsatisfiable constraints are created
164163
// due to an array overflow when accesing `self.linked_keys[self.tail_ptr]`
@@ -167,8 +166,6 @@ where
167166
let lhs_index = found_index;
168167
let rhs_index = self.linked_keys[found_index];
169168

170-
assert(found * found == found);
171-
172169
// OK! So we have the following cases to check
173170
// 1. if `found` then `self.keys[found_index] == idx`
174171
// 2. if `!found` then `self.keys[found_index] < idx < self.keys[found_index + 1]
@@ -180,17 +177,17 @@ where
180177
let lhs = self.keys[lhs_index];
181178
let rhs = self.keys[rhs_index];
182179

183-
assert(lhs + 1 - found <= idx);
184-
assert(idx <= rhs + found - 1);
180+
assert(lhs + 1 - found as u32 <= idx);
181+
assert(idx <= rhs + found as u32 - 1);
185182

186183
// lhs points to tail_ptr
187184
// tail_ptr points to rhs
188-
if (found == 0) {
185+
if (found == false) {
189186
self.keys[self.tail_ptr] = idx;
190187

191-
self.linked_keys[found_index] = self.tail_ptr * (1 - found) + found * rhs_index;
188+
self.linked_keys[found_index] = self.tail_ptr;
192189

193-
self.linked_keys[self.tail_ptr] = rhs_index * (1 - found);
190+
self.linked_keys[self.tail_ptr] = rhs_index;
194191
self.values[self.tail_ptr + 1] = value;
195192
self.tail_ptr += 1;
196193
} else {
@@ -200,13 +197,10 @@ where
200197

201198
fn get(self, idx: u32) -> T {
202199
let (found, found_index) = unsafe { self.search_for_key(idx) };
203-
assert(found * found == found);
204200

205201
let lhs_index = found_index;
206202
let rhs_index = self.linked_keys[found_index];
207203

208-
assert(found * found == found);
209-
210204
// OK! So we have the following cases to check
211205
// 1. if `found` then `self.keys[found_index] == idx`
212206
// 2. if `!found` then `self.keys[found_index] < idx < self.keys[found_index + 1]
@@ -217,9 +211,9 @@ where
217211
// `self.keys[found_index] + 1 - found <= idx <= self.keys[found_index + 1 - found] - 1 + found
218212
let lhs = self.keys[lhs_index];
219213
let rhs = self.keys[rhs_index];
220-
assert(lhs + 1 - found <= idx);
221-
assert(idx <= rhs + found - 1);
222-
let value_index = (lhs_index + 1) * found;
214+
assert(lhs + 1 - found as u32 <= idx);
215+
assert(idx <= rhs + found as u32 - 1);
216+
let value_index = (lhs_index + 1) * found as u32;
223217
self.values[value_index]
224218
}
225219
}

0 commit comments

Comments
 (0)