Skip to content

Commit f23cad4

Browse files
committed
feat: add frozen MemTable for nonblocking merge
1 parent 780ed1d commit f23cad4

2 files changed

Lines changed: 175 additions & 27 deletions

File tree

src/db.rs

Lines changed: 171 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@ use std::fmt::Debug;
88
use std::fs;
99
use std::iter::Peekable;
1010
use std::path::PathBuf;
11+
use std::sync::{Arc, mpsc};
12+
use std::thread;
1113
use uuid::Uuid;
1214
use xorf::{BinaryFuse8, Filter};
1315

@@ -21,15 +23,18 @@ struct MergedIterator<'a> {
2123

2224
/// Iterator that scans a collection's in-memory documents first and then its
/// on-disk tables, yielding `ExecutionResult` values.
///
/// The scan advances through the `ScanPhase` states in order: active memtable,
/// frozen memtable (if one exists), then disk. Entries from a later phase are
/// skipped when their id is present in an earlier in-memory map.
struct HybridIterator<'a> {
2325
// Iterator over the active memtable's documents.
mem_iter: std::collections::hash_map::Iter<'a, String, Value>,
26+
// Iterator over the frozen memtable's documents; `None` when there is no
// frozen memtable, in which case that phase falls straight through to disk.
frozen_iter: Option<std::collections::hash_map::Iter<'a, String, Value>>,
2427
// Source of on-disk results. NOTE(review): presumably merges the JSTables
// of the collection — confirm against `MergedIterator`.
disk_iter: MergedIterator<'a>,
2528
// Active memtable map, consulted to skip frozen/disk entries it shadows.
memtable: &'a HashMap<String, Value>,
29+
// Frozen memtable map, consulted to skip disk entries it shadows.
frozen_memtable: Option<&'a HashMap<String, Value>>,
2630
// Which source the next call to `next` reads from.
phase: ScanPhase,
2731
// Optional filter; a document is skipped unless it evaluates to `Value::Bool(true)`.
predicate: Option<Expression<'a>>,
2832
// Optional projection list; when set, each result is rebuilt as an object
// holding one entry per projected expression.
projections: Option<Vec<Expression<'a>>>,
2933
}
3034

3135
/// Source that `HybridIterator` is currently draining.
///
/// The iterator starts in `MemTable`, moves to `FrozenMemTable` once the
/// active memtable is exhausted, and ends in `Disk`.
enum ScanPhase {
3236
// Reading from the active (mutable) memtable.
MemTable,
37+
// Reading from the frozen memtable that is being flushed in the background.
FrozenMemTable,
3338
// Reading from the on-disk iterator.
Disk,
3439
}
3540

@@ -70,6 +75,50 @@ impl<'a> Iterator for HybridIterator<'a> {
7075
}
7176

7277
return Some(ExecutionResult::Value(id.clone(), val.clone()));
78+
} else {
79+
self.phase = ScanPhase::FrozenMemTable;
80+
}
81+
}
82+
ScanPhase::FrozenMemTable => {
83+
if let Some(iter) = &mut self.frozen_iter {
84+
if let Some((id, val)) = iter.next() {
85+
// Check if shadowed by active memtable
86+
if self.memtable.contains_key(id) {
87+
continue;
88+
}
89+
90+
use jsonb_schema::Value as JsonbValue;
91+
if matches!(val, JsonbValue::Null) {
92+
continue; // Tombstone
93+
}
94+
95+
if let Some(pred) = &self.predicate {
96+
if evaluate_expression(pred, val) != Value::Bool(true) {
97+
continue;
98+
}
99+
}
100+
101+
if let Some(projs) = &self.projections {
102+
let mut new_doc = BTreeMap::new();
103+
for expr in projs {
104+
let v = evaluate_expression(expr, val);
105+
let key = match expr {
106+
Expression::FieldReference(_, raw) => raw.to_string(),
107+
Expression::JsonPath(_, raw) => raw.to_string(),
108+
_ => "col".to_string(),
109+
};
110+
new_doc.insert(key, v);
111+
}
112+
return Some(ExecutionResult::Value(
113+
id.clone(),
114+
Value::Object(new_doc),
115+
));
116+
}
117+
118+
return Some(ExecutionResult::Value(id.clone(), val.clone()));
119+
} else {
120+
self.phase = ScanPhase::Disk;
121+
}
73122
} else {
74123
self.phase = ScanPhase::Disk;
75124
}
@@ -79,6 +128,11 @@ impl<'a> Iterator for HybridIterator<'a> {
79128
if self.memtable.contains_key(res.id()) {
80129
continue;
81130
}
131+
if let Some(frozen) = self.frozen_memtable {
132+
if frozen.contains_key(res.id()) {
133+
continue;
134+
}
135+
}
82136
return Some(res);
83137
} else {
84138
return None;
@@ -216,14 +270,16 @@ fn sanitize_filename(name: &str) -> String {
216270
result
217271
}
218272

219-
struct LoadedTable {
273+
/// In-memory lookup data kept for one JSTable after it has been written.
///
/// Built by the flush path from `jstable::read_filter` and
/// `jstable::read_index` and sent back to the collection over the flush channel.
pub(crate) struct LoadedTable {
220274
// Filter returned by `jstable::read_filter` for the table.
filter: BinaryFuse8,
221275
// Index returned by `jstable::read_index`.
// NOTE(review): presumably (document id, file offset) pairs — confirm.
index: Vec<(String, u64)>,
222276
}
223277

224278
struct Collection {
225279
name: String,
226280
pub memtable: MemTable,
281+
pub frozen_memtable: Option<Arc<MemTable>>,
282+
flush_rx: Option<mpsc::Receiver<Result<LoadedTable, String>>>,
227283
dir: PathBuf,
228284
jstable_count: u64,
229285
logger: Box<dyn Log>,
@@ -271,6 +327,8 @@ impl Collection {
271327
Collection {
272328
name,
273329
memtable,
330+
frozen_memtable: None,
331+
flush_rx: None,
274332
dir,
275333
jstable_count,
276334
logger,
@@ -283,8 +341,15 @@ impl Collection {
283341

284342
#[tracing::instrument]
285343
fn insert(&mut self, doc: Value) -> String {
344+
// Check if any background flush finished
345+
self.check_flush_status(false);
346+
286347
if self.memtable.len() >= self.memtable_threshold {
287-
self.flush();
348+
// If we still have a frozen memtable, we must wait for it to clear
349+
if self.frozen_memtable.is_some() {
350+
self.check_flush_status(true);
351+
}
352+
self.trigger_flush();
288353
}
289354
let id = Uuid::now_v7().to_string();
290355
self.logger
@@ -316,30 +381,82 @@ impl Collection {
316381
self.memtable.update(id, doc);
317382
}
318383

319-
fn flush(&mut self) {
320-
let jstable_path = self.dir.join(format!("jstable-{}", self.jstable_count));
321-
let memtable = std::mem::take(&mut self.memtable);
322-
memtable
323-
.flush(
324-
jstable_path.to_str().unwrap(),
325-
self.name.clone(),
326-
self.index_threshold,
327-
)
328-
.unwrap();
329-
330-
// Load the new filter and index
331-
let path_str = jstable_path.to_str().unwrap();
332-
let filter = jstable::read_filter(path_str).unwrap();
333-
let index = jstable::read_index(path_str).unwrap();
384+
fn check_flush_status(&mut self, wait: bool) {
385+
if let Some(rx) = &self.flush_rx {
386+
let res = if wait {
387+
rx.recv().map_err(|e| e.to_string())
388+
} else {
389+
rx.try_recv().map_err(|e| e.to_string())
390+
};
391+
392+
if let Ok(result) = res {
393+
match result {
394+
Ok(table) => {
395+
self.tables.push(table);
396+
self.jstable_count += 1;
397+
if self.jstable_count >= self.jstable_threshold {
398+
self.compact();
399+
}
400+
}
401+
Err(e) => eprintln!("Background flush failed: {}", e),
402+
}
403+
self.frozen_memtable = None;
404+
self.flush_rx = None;
405+
} else if wait {
406+
// If we waited and got error (e.g. channel closed), clear state
407+
self.frozen_memtable = None;
408+
self.flush_rx = None;
409+
}
410+
}
411+
}
334412

335-
self.tables.push(LoadedTable { filter, index });
413+
pub fn wait_for_ongoing_flush(&mut self) {
414+
self.check_flush_status(true);
415+
}
336416

337-
self.jstable_count += 1;
417+
fn trigger_flush(&mut self) {
338418
self.logger.rotate().unwrap();
339419

340-
if self.jstable_count >= self.jstable_threshold {
341-
self.compact();
342-
}
420+
let jstable_path = self.dir.join(format!("jstable-{}", self.jstable_count));
421+
let index_threshold = self.index_threshold;
422+
let collection_name = self.name.clone();
423+
424+
let memtable = std::mem::take(&mut self.memtable);
425+
let frozen = Arc::new(memtable);
426+
self.frozen_memtable = Some(frozen.clone());
427+
428+
let (tx, rx) = mpsc::channel();
429+
self.flush_rx = Some(rx);
430+
431+
thread::spawn(move || {
432+
match frozen.flush(
433+
jstable_path.to_str().unwrap(),
434+
collection_name,
435+
index_threshold,
436+
) {
437+
Ok(_) => {
438+
let path_str = jstable_path.to_str().unwrap();
439+
let filter = match jstable::read_filter(path_str) {
440+
Ok(f) => f,
441+
Err(e) => {
442+
let _ = tx.send(Err(e.to_string()));
443+
return;
444+
}
445+
};
446+
let index = match jstable::read_index(path_str) {
447+
Ok(i) => i,
448+
Err(e) => {
449+
let _ = tx.send(Err(e.to_string()));
450+
return;
451+
}
452+
};
453+
let _ = tx.send(Ok(LoadedTable { filter, index }));
454+
}
455+
Err(e) => {
456+
let _ = tx.send(Err(e.to_string()));
457+
}
458+
}
459+
});
343460
}
344461

345462
fn compact(&mut self) {
@@ -402,8 +519,10 @@ impl Collection {
402519

403520
HybridIterator {
404521
mem_iter: self.memtable.documents.iter(),
522+
frozen_iter: self.frozen_memtable.as_ref().map(|m| m.documents.iter()),
405523
disk_iter,
406524
memtable: &self.memtable.documents,
525+
frozen_memtable: self.frozen_memtable.as_ref().map(|m| &m.documents),
407526
phase: ScanPhase::MemTable,
408527
predicate,
409528
projections,
@@ -420,7 +539,18 @@ impl Collection {
420539
return Some(doc.clone());
421540
}
422541

423-
// 2. Check JSTables (Newer to Older)
542+
// 2. Check Frozen MemTable
543+
if let Some(frozen) = &self.frozen_memtable {
544+
if let Some(doc) = frozen.documents.get(id) {
545+
use jsonb_schema::Value as JsonbValue;
546+
if matches!(doc, JsonbValue::Null) {
547+
return None; // Tombstone
548+
}
549+
return Some(doc.clone());
550+
}
551+
}
552+
553+
// 3. Check JSTables (Newer to Older)
424554
let hash = {
425555
use std::hash::{Hash, Hasher};
426556
let mut hasher = std::collections::hash_map::DefaultHasher::new();
@@ -633,6 +763,11 @@ impl DB {
633763
pub fn get(&self, collection: &str, id: &str) -> Result<Option<Value>, String> {
634764
self.get_collection(collection).map(|c| c.get(id))
635765
}
766+
767+
pub fn wait_for_flush(&mut self, collection: &str) -> Result<(), String> {
768+
self.get_collection_mut(collection)
769+
.map(|c| c.wait_for_ongoing_flush())
770+
}
636771
}
637772

638773
#[cfg(test)]
@@ -668,6 +803,9 @@ mod tests {
668803

669804
db.insert("test", serde_to_jsonb(json!({"a": MEMTABLE_THRESHOLD})))
670805
.unwrap();
806+
807+
db.wait_for_flush("test").unwrap();
808+
671809
let col = db.collections.get("test").unwrap();
672810
assert_eq!(col.memtable.len(), 1);
673811
assert_eq!(col.jstable_count, 1);
@@ -782,11 +920,15 @@ mod tests {
782920
.unwrap();
783921
}
784922

923+
db.wait_for_flush("test").unwrap();
924+
785925
let col = db.collections.get("test").unwrap();
786926
assert_eq!(col.jstable_count, JSTABLE_THRESHOLD - 1);
787927
db.insert("test", serde_to_jsonb(json!({ "a": 999 })))
788928
.unwrap();
789929

930+
db.wait_for_flush("test").unwrap();
931+
790932
let col = db.collections.get("test").unwrap();
791933
assert_eq!(col.jstable_count, 1);
792934
}
@@ -813,6 +955,8 @@ mod tests {
813955
db.insert("test", serde_to_jsonb(json!({ "trigger_1": 1 })))
814956
.unwrap();
815957

958+
db.wait_for_flush("test").unwrap();
959+
816960
let col = db.collections.get("test").unwrap();
817961
assert_eq!(col.jstable_count, 1);
818962

@@ -825,6 +969,8 @@ mod tests {
825969
db.insert("test", serde_to_jsonb(json!({ "trigger_2": 1 })))
826970
.unwrap();
827971

972+
db.wait_for_flush("test").unwrap();
973+
828974
let col = db.collections.get("test").unwrap();
829975
assert_eq!(col.jstable_count, 2);
830976

@@ -837,6 +983,8 @@ mod tests {
837983
.unwrap();
838984
}
839985

986+
db.wait_for_flush("test").unwrap();
987+
840988
let col = db.collections.get("test").unwrap();
841989
assert_eq!(col.jstable_count, 1);
842990

src/storage.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ impl MemTable {
3131
}
3232

3333
pub fn flush(
34-
self,
34+
&self,
3535
path: &str,
3636
collection: String,
3737
index_threshold: u64,
@@ -44,11 +44,11 @@ impl MemTable {
4444
// Sort documents by ID for JSTable
4545
let sorted_docs: BTreeMap<String, StoredValue> = self
4646
.documents
47-
.into_iter()
48-
.map(|(k, v)| (k, StoredValue::Static(v)))
47+
.iter()
48+
.map(|(k, v)| (k.clone(), StoredValue::Static(v.clone())))
4949
.collect();
5050

51-
let jstable = JSTable::new(timestamp, collection, self.schema, sorted_docs);
51+
let jstable = JSTable::new(timestamp, collection, self.schema.clone(), sorted_docs);
5252
jstable.write(path, index_threshold)
5353
}
5454

0 commit comments

Comments
 (0)