diff --git a/datasketches/Cargo.toml b/datasketches/Cargo.toml index eb299e8..3f48668 100644 --- a/datasketches/Cargo.toml +++ b/datasketches/Cargo.toml @@ -45,6 +45,7 @@ frequencies = [] hll = [] tdigest = [] theta = [] +tuple = ["theta"] [dev-dependencies] googletest = { workspace = true } diff --git a/datasketches/src/codec/decode.rs b/datasketches/src/codec/decode.rs index 37c1523..d2f2364 100644 --- a/datasketches/src/codec/decode.rs +++ b/datasketches/src/codec/decode.rs @@ -38,6 +38,16 @@ impl SketchSlice<'_> { self.slice.set_position(pos + n); } + /// Returns the not-yet-read portion of the underlying slice. + /// + /// Useful for handing the remaining bytes to a variable-length decoder that reports how many + /// bytes it consumed; pair it with [`advance`](Self::advance). + pub fn remaining(&self) -> &[u8] { + let buf = self.slice.get_ref(); + let pos = (self.slice.position() as usize).min(buf.len()); + &buf[pos..] + } + /// Reads exactly `buf.len()` bytes from the slice into `buf`. pub fn read_exact(&mut self, buf: &mut [u8]) -> io::Result<()> { self.slice.read_exact(buf) diff --git a/datasketches/src/codec/family.rs b/datasketches/src/codec/family.rs index cc1ce45..c6ab908 100644 --- a/datasketches/src/codec/family.rs +++ b/datasketches/src/codec/family.rs @@ -53,6 +53,15 @@ impl Family { max_pre_longs: 1, }; + /// Tuple Sketch for cardinality estimation with per-key summaries. + #[cfg(feature = "tuple")] + pub const TUPLE: Family = Family { + id: 9, + name: "TUPLE", + min_pre_longs: 1, + max_pre_longs: 3, + }; + /// The Frequency family of sketches. #[cfg(feature = "frequencies")] pub const FREQUENCY: Family = Family { diff --git a/datasketches/src/lib.rs b/datasketches/src/lib.rs index 65f0ddd..4068848 100644 --- a/datasketches/src/lib.rs +++ b/datasketches/src/lib.rs @@ -45,6 +45,8 @@ pub mod hll; pub mod tdigest; #[cfg(feature = "theta")] pub mod theta; +#[cfg(feature = "tuple")] +pub mod tuple; // common modules pub mod codec; diff --git a/datasketches/src/theta/hash_table.rs b/datasketches/src/theta/hash_table.rs index 1ee498d..ccdc8ee 100644 --- a/datasketches/src/theta/hash_table.rs +++ b/datasketches/src/theta/hash_table.rs @@ -24,12 +24,7 @@ use crate::theta::HASH_TABLE_REBUILD_THRESHOLD; use crate::theta::HASH_TABLE_RESIZE_THRESHOLD; use crate::theta::MAX_THETA; use crate::theta::MIN_LG_K; - -/// Stride hash bits (7 bits for stride calculation) -const STRIDE_HASH_BITS: u8 = 7; - -/// Stride mask -const STRIDE_MASK: u64 = (1 << STRIDE_HASH_BITS) - 1; +use crate::theta::STRIDE_MASK; /// Specific hash table for theta sketch /// @@ -391,7 +386,7 @@ impl ThetaHashTable { /// Compute initial lg_size for hash table based on target lg_size, minimum lg_size, and resize /// factor. Make sure `lg_target = lg_init + n * lg_resize_factor`, where `n` is an integer and /// `lg_init >= lg_min` -fn starting_sub_multiple(lg_target: u8, lg_min: u8, lg_resize_factor: u8) -> u8 { +pub(crate) fn starting_sub_multiple(lg_target: u8, lg_min: u8, lg_resize_factor: u8) -> u8 { if lg_target <= lg_min { lg_min } else if lg_resize_factor == 0 { @@ -402,7 +397,7 @@ fn starting_sub_multiple(lg_target: u8, lg_min: u8, lg_resize_factor: u8) -> u8 } /// Compute initial theta for hash table based on sampling probability. -fn starting_theta_from_sampling_probability(sampling_probability: f32) -> u64 { +pub(crate) fn starting_theta_from_sampling_probability(sampling_probability: f32) -> u64 { if sampling_probability < 1.0 { (MAX_THETA as f64 * sampling_probability as f64) as u64 } else { diff --git a/datasketches/src/theta/mod.rs b/datasketches/src/theta/mod.rs index 03b5e2a..f25d890 100644 --- a/datasketches/src/theta/mod.rs +++ b/datasketches/src/theta/mod.rs @@ -45,6 +45,12 @@ mod intersection; mod serialization; mod sketch; +// These helpers are re-exported only for the Tuple sketch, which reuses the Theta hash-table +// sizing. +#[cfg(feature = "tuple")] +pub(crate) use self::hash_table::starting_sub_multiple; +#[cfg(feature = "tuple")] +pub(crate) use self::hash_table::starting_theta_from_sampling_probability; pub use self::intersection::ThetaIntersection; pub use self::sketch::CompactThetaSketch; pub use self::sketch::ThetaSketch; @@ -52,14 +58,18 @@ pub use self::sketch::ThetaSketchBuilder; pub use self::sketch::ThetaSketchView; /// Maximum theta value (signed max for compatibility with Java) -const MAX_THETA: u64 = i64::MAX as u64; +pub(crate) const MAX_THETA: u64 = i64::MAX as u64; /// Minimum log2 of K -const MIN_LG_K: u8 = 5; +pub(crate) const MIN_LG_K: u8 = 5; /// Maximum log2 of K -const MAX_LG_K: u8 = 26; +pub(crate) const MAX_LG_K: u8 = 26; /// Default log2 of K -const DEFAULT_LG_K: u8 = 12; +pub(crate) const DEFAULT_LG_K: u8 = 12; /// Resize threshold (0.5 = 50% load factor) -const HASH_TABLE_RESIZE_THRESHOLD: f64 = 0.5; +pub(crate) const HASH_TABLE_RESIZE_THRESHOLD: f64 = 0.5; /// Rebuild threshold (15/16 = 93.75% load factor) -const HASH_TABLE_REBUILD_THRESHOLD: f64 = 15.0 / 16.0; +pub(crate) const HASH_TABLE_REBUILD_THRESHOLD: f64 = 15.0 / 16.0; +/// Stride hash bits (7 bits for stride calculation) +pub(crate) const STRIDE_HASH_BITS: u8 = 7; +/// Stride mask +pub(crate) const STRIDE_MASK: u64 = (1 << STRIDE_HASH_BITS) - 1; diff --git a/datasketches/src/tuple/a_not_b.rs b/datasketches/src/tuple/a_not_b.rs new file mode 100644 index 0000000..670b12c --- /dev/null +++ b/datasketches/src/tuple/a_not_b.rs @@ -0,0 +1,429 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tuple sketch set difference (`A and not B`). +//! +//! [`TupleAnotB`] computes the set difference of two Tuple sketches: the keys retained in `A` that +//! are not present in `B`. Surviving keys keep their summaries from `A` unchanged, so unlike the +//! union and intersection this operation needs no combine policy. + +use std::collections::HashSet; + +use crate::error::Error; +use crate::hash::DEFAULT_UPDATE_SEED; +use crate::hash::compute_seed_hash; +use crate::theta::MAX_THETA; +use crate::tuple::sketch::CompactTupleSketch; +use crate::tuple::sketch::TupleSketchView; + +/// Set difference operator (`A and not B`) for Tuple sketches. +/// +/// This is a stateless operator (other than the seed): each call to [`compute`](Self::compute) +/// takes two input sketches and returns a new [`CompactTupleSketch`]. Surviving keys carry their +/// summaries straight from `A`. +/// +/// # Examples +/// +/// ``` +/// # use datasketches::tuple::{TupleAnotB, UpdatableTupleSketch}; +/// let mut a = UpdatableTupleSketch::::builder().build(); +/// a.update("apple", 1); +/// a.update("banana", 1); +/// +/// let mut b = UpdatableTupleSketch::::builder().build(); +/// b.update("banana", 1); +/// +/// let a_not_b = TupleAnotB::new_with_default_seed(); +/// let result = a_not_b.compute(&a, &b).unwrap(); +/// assert_eq!(result.num_retained(), 1); // only "apple" survives +/// ``` +#[derive(Debug, Clone, Copy)] +pub struct TupleAnotB { + seed_hash: u16, +} + +impl TupleAnotB { + /// Creates a new set difference operator for the given `seed`. + pub fn new(seed: u64) -> Self { + Self { + seed_hash: compute_seed_hash(seed), + } + } + + /// Creates a new set difference operator with the default seed. + pub fn new_with_default_seed() -> Self { + Self::new(DEFAULT_UPDATE_SEED) + } + + /// Computes `a and not b`, returning an ordered compact sketch. + /// + /// # Errors + /// + /// See [`compute_with_ordered`](Self::compute_with_ordered). + pub fn compute(&self, a: &A, b: &B) -> Result, Error> + where + A: TupleSketchView, + B: TupleSketchView, + S: Clone, + { + self.compute_with_ordered(a, b, true) + } + + /// Computes `a and not b`. + /// + /// The result retains every key of `a` (below the combined theta) that is not present in `b`, + /// keeping the summaries from `a`. If `ordered` is true, the retained entries are sorted + /// ascending by hash. + /// + /// # Errors + /// + /// Returns an error if either non-trivial input has a seed hash that differs from this + /// operator's seed. + pub fn compute_with_ordered( + &self, + a: &A, + b: &B, + ordered: bool, + ) -> Result, Error> + where + A: TupleSketchView, + B: TupleSketchView, + S: Clone, + { + // If A is empty the result is an (empty) copy of A. As with the union and intersection, an + // empty input carries no keys, so its seed is not validated. + if a.is_empty() { + return Ok(Self::compact_from_view(a, ordered)); + } + + // A is non-empty, so its seed must be compatible. + if a.seed_hash() != self.seed_hash { + return Err(Error::invalid_argument(format!( + "A seed hash mismatch: expected {}, got {}", + self.seed_hash, + a.seed_hash() + ))); + } + + // An empty B subtracts nothing, so the result is simply a copy of A. This also covers the + // "A is non-empty but has no retained keys" state: B's seed and theta must not influence + // the result, so we return before touching them. + if b.is_empty() { + return Ok(Self::compact_from_view(a, ordered)); + } + + // B is non-empty, so its seed must be compatible. + if b.seed_hash() != self.seed_hash { + return Err(Error::invalid_argument(format!( + "B seed hash mismatch: expected {}, got {}", + self.seed_hash, + b.seed_hash() + ))); + } + + let theta = a.theta64().min(b.theta64()); + // A is non-empty here; the result only becomes empty if everything is subtracted in exact + // mode (handled below). + let mut is_empty = false; + + let entries: Vec<(u64, S)> = if b.num_retained() == 0 { + a.iter() + .filter(|(hash, _)| *hash < theta) + .map(|(hash, summary)| (hash, summary.clone())) + .collect() + } else { + let mut b_keys: HashSet = HashSet::with_capacity(b.num_retained()); + for (hash, _) in b.iter() { + if hash < theta { + b_keys.insert(hash); + } else if b.is_ordered() { + break; + } + } + + let mut entries = Vec::new(); + for (hash, summary) in a.iter() { + if hash < theta { + if !b_keys.contains(&hash) { + entries.push((hash, summary.clone())); + } + } else if a.is_ordered() { + break; + } + } + entries + }; + + if entries.is_empty() && theta == MAX_THETA { + is_empty = true; + } + + let out_ordered = ordered || a.is_ordered(); + let mut entries = entries; + if ordered && !a.is_ordered() && entries.len() > 1 { + entries.sort_unstable_by_key(|(hash, _)| *hash); + } + + Ok(CompactTupleSketch::from_parts( + entries, + theta, + self.seed_hash, + out_ordered, + is_empty, + )) + } + + /// Builds a compact sketch that is a copy of the view `a`. + fn compact_from_view(a: &V, ordered: bool) -> CompactTupleSketch + where + V: TupleSketchView, + S: Clone, + { + let mut entries: Vec<(u64, S)> = a + .iter() + .map(|(hash, summary)| (hash, summary.clone())) + .collect(); + let out_ordered = ordered || a.is_ordered(); + if ordered && !a.is_ordered() && entries.len() > 1 { + entries.sort_unstable_by_key(|(hash, _)| *hash); + } + CompactTupleSketch::from_parts( + entries, + a.theta64(), + a.seed_hash(), + out_ordered, + a.is_empty(), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::NumStdDev; + use crate::error::ErrorKind; + use crate::tuple::UpdatableTupleSketch; + + fn sorted_entries(sketch: &CompactTupleSketch) -> Vec<(u64, u64)> { + let mut entries: Vec<(u64, u64)> = sketch.iter().map(|(h, &s)| (h, s)).collect(); + entries.sort_unstable(); + entries + } + + #[test] + fn a_not_b_basic_difference() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 500..1500 { + b.update(i, 1u64); + } + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + // Keys 0..500 are only in A (exact mode). + assert_eq!(result.num_retained(), 500); + assert_eq!(result.estimate(), 500.0); + } + + #[test] + fn a_not_b_keeps_summaries_from_a() { + let mut a = UpdatableTupleSketch::::builder().build(); + a.update("only_a", 7u64); + a.update("shared", 7u64); + let mut b = UpdatableTupleSketch::::builder().build(); + b.update("shared", 99u64); + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + assert_eq!(result.num_retained(), 1); + // The surviving key keeps A's summary; B's summary is never combined in. + assert_eq!(result.iter().next().unwrap().1, &7); + } + + #[test] + fn a_not_b_with_empty_b_returns_a() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..100 { + a.update(i, 3u64); + } + let b = UpdatableTupleSketch::::builder().build(); + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + assert_eq!(result.num_retained(), 100); + assert!(result.iter().all(|(_, &s)| s == 3)); + } + + #[test] + fn a_not_b_with_empty_a_is_empty() { + let a = UpdatableTupleSketch::::builder().build(); + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 0..100 { + b.update(i, 1u64); + } + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + assert!(result.is_empty()); + assert_eq!(result.num_retained(), 0); + assert_eq!(result.estimate(), 0.0); + } + + #[test] + fn a_not_b_with_superset_b_is_empty() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..500 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + b.update(i, 1u64); + } + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + assert_eq!(result.num_retained(), 0); + assert_eq!(result.estimate(), 0.0); + } + + #[test] + fn a_not_b_with_disjoint_b_returns_a() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..500 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 500..1000 { + b.update(i, 1u64); + } + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + assert_eq!(result.num_retained(), 500); + } + + #[test] + fn a_not_b_accepts_updatable_and_compact_inputs() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 500..1500 { + b.update(i, 1u64); + } + let b_compact = b.compact(true); + + // a (updatable) not b (compact) + let result = TupleAnotB::new_with_default_seed() + .compute(&a, &b_compact) + .unwrap(); + assert_eq!(result.num_retained(), 500); + + // a (compact) not b (compact) + let a_compact = a.compact(true); + let result2 = TupleAnotB::new_with_default_seed() + .compute(&a_compact, &b_compact) + .unwrap(); + assert_eq!(result2.num_retained(), 500); + } + + #[test] + fn a_not_b_result_is_ordered_when_requested() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 500..1500 { + b.update(i, 1u64); + } + + let result = TupleAnotB::new_with_default_seed() + .compute_with_ordered(&a, &b, true) + .unwrap(); + assert!(result.is_ordered()); + let entries = sorted_entries(&result); + let iter_order: Vec = result.iter().map(|(h, _)| h).collect(); + let sorted_order: Vec = entries.iter().map(|(h, _)| *h).collect(); + assert_eq!(iter_order, sorted_order); + } + + #[test] + fn a_not_b_rejects_seed_mismatch() { + let mut a = UpdatableTupleSketch::::builder().seed(1).build(); + a.update(1, 1u64); + let mut b = UpdatableTupleSketch::::builder().seed(1).build(); + b.update(2, 1u64); + + let err = TupleAnotB::new(2).compute(&a, &b).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InvalidArgument); + } + + #[test] + fn a_not_b_validates_a_seed_even_when_b_is_empty() { + // A is non-empty with a seed that does not match the operator; B is empty. The empty-B fast + // path must not bypass A's seed check. + let mut a = UpdatableTupleSketch::::builder().seed(1).build(); + a.update(1, 1u64); + let b = UpdatableTupleSketch::::builder().seed(1).build(); // empty + + let err = TupleAnotB::new(2).compute(&a, &b).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InvalidArgument); + } + + #[test] + fn a_not_b_empty_b_returns_a_for_non_empty_zero_retained_a() { + // A is logically non-empty but retains no keys (the single update is screened out by the + // sampling theta). + let mut a = UpdatableTupleSketch::::builder() + .sampling_probability(0.001) + .build(); + a.update(1u64, 1u64); + assert!(!a.is_empty()); + assert_eq!(a.num_retained(), 0); + + // B is empty and built with a different seed. Since an empty B subtracts nothing, the + // result must be a copy of A: no seed error, and A's theta is preserved (not + // lowered by B). + let b = UpdatableTupleSketch::::builder().seed(999).build(); + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + assert!(!result.is_empty()); + assert_eq!(result.num_retained(), 0); + assert_eq!(result.theta64(), a.theta64()); + } + + #[test] + fn a_not_b_in_estimation_mode_estimates_within_bounds() { + let mut a = UpdatableTupleSketch::::builder().lg_k(8).build(); + for i in 0..75000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().lg_k(8).build(); + for i in 25000..75000 { + b.update(i, 1u64); + } + + let result = TupleAnotB::new_with_default_seed().compute(&a, &b).unwrap(); + assert!(result.is_estimation_mode()); + // True difference size is 25000 (keys 0..25000). + let lower = result.lower_bound(NumStdDev::Three); + let upper = result.upper_bound(NumStdDev::Three); + assert!( + lower <= 25000.0 && 25000.0 <= upper, + "expected 25000 in [{lower}, {upper}]" + ); + } +} diff --git a/datasketches/src/tuple/hash_table.rs b/datasketches/src/tuple/hash_table.rs new file mode 100644 index 0000000..ebc08e0 --- /dev/null +++ b/datasketches/src/tuple/hash_table.rs @@ -0,0 +1,780 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::hash::Hash; + +use crate::common::ResizeFactor; +use crate::hash::MurmurHash3X64128; +use crate::hash::compute_seed_hash; +use crate::theta::HASH_TABLE_REBUILD_THRESHOLD; +use crate::theta::HASH_TABLE_RESIZE_THRESHOLD; +use crate::theta::MAX_THETA; +use crate::theta::MIN_LG_K; +use crate::theta::STRIDE_MASK; +use crate::theta::starting_sub_multiple; +use crate::theta::starting_theta_from_sampling_probability; + +/// A retained entry: a hash key together with its associated summary. +#[derive(Debug)] +struct TupleEntry { + hash: u64, + summary: S, +} + +/// Specific hash table for tuple sketch. +/// +/// This is the Theta sketch hash table extended so that each retained key carries a user-defined +/// summary. It maintains an array with capacity up to 2^lg_max_size: +/// * Before it reaches the max capacity, it will extend the array based on resize_factor. +/// * After it reaches the capacity bigger than 2^lg_nom_size, every time the number of entries +/// exceeds the threshold, it will rebuild the table: only keep the min 2^lg_nom_size entries and +/// update the theta to the k-th smallest entry. +/// +/// Unlike the Theta hash table, when a key is inserted that already exists, the incoming update is +/// merged into the existing summary rather than discarded. +#[derive(Debug)] +pub(super) struct TupleHashTable { + lg_cur_size: u8, + lg_nom_size: u8, + lg_max_size: u8, + resize_factor: ResizeFactor, + sampling_probability: f32, + hash_seed: u64, + + // Logical emptiness of the source set. + // + // * `false` if any update has been attempted (even if screened by theta) + // * `true` if no updates have been attempted. + // + // This can be false even when `num_retained` is 0. + is_empty: bool, + + theta: u64, + + // Using `None` to represent zero value. + entries: Vec>>, + + // Number of retained non-zero hashes currently stored in `entries`. + num_retained: usize, +} + +impl TupleHashTable { + /// Create a new hash table + pub fn new( + lg_nom_size: u8, + resize_factor: ResizeFactor, + sampling_probability: f32, + hash_seed: u64, + ) -> Self { + let lg_max_size = lg_nom_size + 1; + let lg_cur_size = starting_sub_multiple(lg_max_size, MIN_LG_K, resize_factor.lg_value()); + Self::from_raw_parts( + lg_cur_size, + lg_nom_size, + resize_factor, + sampling_probability, + starting_theta_from_sampling_probability(sampling_probability), + hash_seed, + true, + ) + } + + /// Constructs a table from raw internal state. + /// + /// # Panics + /// + /// Panics if `lg_cur_size > lg_nom_size + 1`. (`lg_nom_size + 1 == lg_max_size`) + pub fn from_raw_parts( + lg_cur_size: u8, + lg_nom_size: u8, + resize_factor: ResizeFactor, + sampling_probability: f32, + theta: u64, + hash_seed: u64, + is_empty: bool, + ) -> Self { + let lg_max_size = lg_nom_size + 1; + assert!( + lg_cur_size <= lg_max_size, + "lg_cur_size must be <= lg_nom_size + 1, got lg_cur_size={lg_cur_size}, lg_nom_size={lg_nom_size}" + ); + let size = if lg_cur_size > 0 { 1 << lg_cur_size } else { 0 }; + let entries = std::iter::repeat_with(|| None).take(size).collect(); + Self { + lg_cur_size, + lg_nom_size, + lg_max_size, + resize_factor, + sampling_probability, + hash_seed, + is_empty, + theta, + entries, + num_retained: 0, + } + } + + /// Hash a value with the table seed and return the hash. + fn hash(&self, value: T) -> u64 { + let mut hasher = MurmurHash3X64128::with_seed(self.hash_seed); + value.hash(&mut hasher); + let (h1, _) = hasher.finish128(); + h1 >> 1 // To make it compatible with Java version + } + + /// Find an entry in the hash table. + /// + /// Returns the index of the entry if found, otherwise None. The entry may have been inserted or + /// empty. + fn find_in_curr_entries(&self, key: u64) -> Option { + Self::find_in_entries(&self.entries, key, self.lg_cur_size) + } + + /// Find index in a given entries. + /// + /// Returns the index of the entry if found, otherwise None. The entry may have been inserted or + /// empty. + fn find_in_entries(entries: &[Option>], key: u64, lg_size: u8) -> Option { + if entries.is_empty() { + return None; + } + + let size = entries.len(); + let mask = size - 1; + let stride = Self::get_stride(key, lg_size); + let mut index = (key as usize) & mask; + let loop_index = index; + + loop { + match &entries[index] { + None => return Some(index), + Some(entry) if entry.hash == key => return Some(index), + _ => {} + } + index = (index + stride) & mask; + if index == loop_index { + return None; + } + } + } + + /// Hashes a key and inserts or updates its summary via a single callback. + /// + /// See [`upsert`](Self::upsert) for the callback contract. Returns true if a new entry was + /// created, false if the key already existed or the hash was screened out by theta. + pub fn update(&mut self, key: T, f: F) -> bool + where + T: Hash, + F: FnOnce(Option<&mut S>) -> Option, + { + let hash = self.hash(key); + self.upsert(hash, f) + } + + /// Inserts or updates the summary slot for a pre-hashed key. + /// + /// The callback `f` is invoked with the current summary for `hash`: + /// * `Some(existing)` if the key is already retained. The callback should modify it in place; + /// its return value is ignored. + /// * `None` if the key is new. The callback returns `Some(summary)` to insert it, or `None` to + /// decline insertion. + /// + /// Using a single callback ensures any captured update value is consumed exactly once, so it + /// works for both the update sketch (folding an update value) and set operations (merging an + /// incoming summary) without requiring the value to be `Copy` or `Clone`. + /// + /// Returns true if a new entry was created, false otherwise (existing key, declined insertion, + /// or a hash screened out by theta). + pub fn upsert(&mut self, hash: u64, f: F) -> bool + where + F: FnOnce(Option<&mut S>) -> Option, + { + self.is_empty = false; + + if hash == 0 || hash >= self.theta { + return false; + } + + let Some(index) = self.find_in_curr_entries(hash) else { + unreachable!( + "Resize or rebuild should be called to make sure it always can find the entry." + ); + }; + + // Already exists: let the callback merge into the retained summary in place. + if let Some(entry) = self.entries[index].as_mut() { + f(Some(&mut entry.summary)); + return false; + } + + // New key: the callback may decline by returning None. + let Some(summary) = f(None) else { + return false; + }; + self.entries[index] = Some(TupleEntry { hash, summary }); + self.num_retained += 1; + + // Check if we need to resize or rebuild + let capacity = self.get_capacity(); + if self.num_retained > capacity { + if self.lg_cur_size <= self.lg_nom_size { + self.resize(); + } else { + self.rebuild(); + } + } + true + } + + /// Get capacity threshold + fn get_capacity(&self) -> usize { + let fraction = if self.lg_cur_size <= self.lg_nom_size { + HASH_TABLE_RESIZE_THRESHOLD + } else { + HASH_TABLE_REBUILD_THRESHOLD + }; + (fraction * self.entries.len() as f64) as usize + } + + /// Resize the hash table + fn resize(&mut self) { + let new_lg_size = std::cmp::min( + self.lg_cur_size + self.resize_factor.lg_value(), + self.lg_max_size, + ); + let new_size = 1 << new_lg_size; + + // Get new entries and rehash all entries + let mut new_entries: Vec>> = + std::iter::repeat_with(|| None).take(new_size).collect(); + for entry in std::mem::take(&mut self.entries).into_iter().flatten() { + let Some(idx) = Self::find_in_entries(&new_entries, entry.hash, new_lg_size) else { + unreachable!( + "find_in_entries should always return Some if the entry is not empty." + ); + }; + new_entries[idx] = Some(entry); + } + + self.entries = new_entries; + self.lg_cur_size = new_lg_size; + } + + /// Rebuild the hash table: + /// The number of entries will be reduced to the nominal size k. + fn rebuild(&mut self) { + let k = 1usize << self.lg_nom_size; + + // Select the k-th smallest entry as new theta and keep the lesser entries. + let mut retained: Vec> = std::mem::take(&mut self.entries) + .into_iter() + .flatten() + .collect(); + let kth_hash = { + let (_lesser, kth, _greater) = retained.select_nth_unstable_by_key(k, |e| e.hash); + kth.hash + }; + self.theta = kth_hash; + retained.truncate(k); + + // Rebuild the table with the lesser entries. + let size = 1 << self.lg_cur_size; + let mut new_entries: Vec>> = + std::iter::repeat_with(|| None).take(size).collect(); + let mut num_inserted = 0; + for entry in retained { + if let Some(idx) = Self::find_in_entries(&new_entries, entry.hash, self.lg_cur_size) { + new_entries[idx] = Some(entry); + num_inserted += 1; + } else { + unreachable!( + "find_in_entries should always return Some if the entry is not empty." + ); + } + } + + assert_eq!( + num_inserted, k, + "Number of inserted entries should be equal to k." + ); + self.num_retained = num_inserted; + self.entries = new_entries; + } + + /// Trim the table to nominal size k + pub fn trim(&mut self) { + if self.num_retained > (1 << self.lg_nom_size) { + self.rebuild(); + } + } + + /// Reset the table to empty state + pub fn reset(&mut self) { + let init_theta = starting_theta_from_sampling_probability(self.sampling_probability); + let init_lg_cur = starting_sub_multiple( + self.lg_nom_size + 1, + MIN_LG_K, + self.resize_factor.lg_value(), + ); + + // clear entries + let size = 1 << init_lg_cur; + self.entries.clear(); + self.entries.resize_with(size, || None); + self.num_retained = 0; + self.theta = init_theta; + self.is_empty = true; + self.lg_cur_size = init_lg_cur; + } + + /// Return number of retained entries + pub fn num_retained(&self) -> usize { + self.num_retained + } + + /// Get theta + pub fn theta(&self) -> u64 { + self.theta + } + + /// Check if emptiness of the source set + pub fn is_empty(&self) -> bool { + self.is_empty + } + + /// Get iterator over retained entries as `(hash, &summary)` pairs. + pub fn iter(&self) -> impl Iterator + '_ { + self.entries + .iter() + .filter_map(|slot| slot.as_ref().map(|entry| (entry.hash, &entry.summary))) + } + + /// Get log2 of nominal size + pub fn lg_nom_size(&self) -> u8 { + self.lg_nom_size + } + + /// Get the hash of the seed that was used to hash the input. + pub fn seed_hash(&self) -> u16 { + compute_seed_hash(self.hash_seed) + } + + /// Returns a reference to the summary stored for `hash`, or `None` if the hash is not retained. + pub fn get(&self, hash: u64) -> Option<&S> { + if hash == 0 { + return None; + } + let index = self.find_in_curr_entries(hash)?; + match &self.entries[index] { + Some(entry) if entry.hash == hash => Some(&entry.summary), + _ => None, + } + } + + /// Inserts a `(hash, summary)` pair, taking ownership of `summary`. + /// + /// Returns true if a new entry was created. Returns false (dropping `summary`) if the hash is + /// already retained or is screened out by theta. This is the summary-carrying analogue of the + /// Theta hash table's `try_insert_hash`. + pub fn try_insert(&mut self, hash: u64, summary: S) -> bool { + self.upsert(hash, |existing| match existing { + Some(_) => None, + None => Some(summary), + }) + } + + /// Set empty flag + pub fn set_empty(&mut self, is_empty: bool) { + self.is_empty = is_empty; + } + + /// Get the hash seed used by this table. + pub fn hash_seed(&self) -> u64 { + self.hash_seed + } + + /// Sets theta value. + pub fn set_theta(&mut self, theta: u64) { + assert!( + (1..=MAX_THETA).contains(&theta), + "theta must be in [1, {MAX_THETA}], got {theta}" + ); + self.theta = theta; + } + + /// Returns minimal lg_size where rebuild-capacity can hold `count`. + pub fn lg_size_from_count_for_rebuild(count: usize, load_factor: f64) -> u8 { + let log2 = |n: usize| { + if n == 0 { 0_u8 } else { n.ilog2() as u8 } + }; + let log2_n = log2(count); + log2_n + + (if count > (((1u128 << ((log2_n as u32) + 1)) as f64) * load_factor) as usize { + 2 + } else { + 1 + }) + } + + /// Get stride for hash table probing + fn get_stride(key: u64, lg_size: u8) -> usize { + (2 * ((key >> (lg_size)) & STRIDE_MASK) + 1) as usize + } + + /// Returns the estimated size of the heap allocations in bytes + pub fn estimated_size(&self) -> usize { + self.entries.capacity() * std::mem::size_of::>>() + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::hash::DEFAULT_UPDATE_SEED; + + /// Inserts a key with count-style summary semantics: a new key starts at 1, a repeated key + /// increments the retained count. Returns true if a new entry was created. + fn insert(table: &mut TupleHashTable, value: impl Hash) -> bool { + table.update(value, |existing| match existing { + Some(count) => { + *count += 1; + None + } + None => Some(1), + }) + } + + /// Collect retained `(hash, count)` pairs. + fn collect(table: &TupleHashTable) -> Vec<(u64, u64)> { + table.iter().map(|(hash, &count)| (hash, count)).collect() + } + + #[test] + fn test_new_hash_table() { + let table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + assert_eq!( + table.lg_cur_size, + starting_sub_multiple(8 + 1, MIN_LG_K, ResizeFactor::X8.lg_value()) + ); + assert_eq!(table.theta, starting_theta_from_sampling_probability(1.0)); + assert_eq!(table.num_retained(), 0); + assert!(table.is_empty()); + assert_eq!(table.iter().count(), 0); + } + + #[test] + fn test_hash_and_theta_screen_behavior() { + let mut table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + // With MAX_THETA, hashes are computed normally. + let hash1 = table.hash("test1"); + let hash2 = table.hash("test2"); + assert_ne!(hash1, 0); + assert_ne!(hash2, 0); + assert_ne!(hash1, hash2); + + // With low theta, update should be screened out. + table.theta = 1; + assert!(!insert(&mut table, "test3")); + } + + #[test] + fn test_insert() { + let mut table = TupleHashTable::::new(5, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + assert!(insert(&mut table, "test_value")); + assert_eq!(table.num_retained(), 1); + assert!(!table.is_empty()); + + // Insert the same value again: not a new entry, but the summary is merged. + assert!(!insert(&mut table, "test_value")); + assert_eq!(table.num_retained(), 1); + assert_eq!(collect(&table), vec![(table.hash("test_value"), 2)]); + + // Force screening and verify insertion fails + table.theta = 0; + assert!(!insert(&mut table, "screened")); + assert_eq!(table.num_retained(), 1); + assert!(!table.is_empty()); + } + + #[test] + fn test_insert_multiple_values() { + let mut table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + let mut inserted_count = 0; + for i in 0..10 { + if insert(&mut table, format!("value_{}", i)) { + inserted_count += 1; + } + } + + assert_eq!(table.num_retained(), inserted_count); + assert!(!table.is_empty()); + assert_eq!(table.iter().count(), inserted_count); + } + + #[test] + fn test_summary_is_merged_on_collision() { + let mut table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + for _ in 0..5 { + insert(&mut table, "same_key"); + } + + assert_eq!(table.num_retained(), 1); + assert_eq!(collect(&table), vec![(table.hash("same_key"), 5)]); + } + + #[test] + fn test_resize() { + fn populate_values(table: &mut TupleHashTable, count: usize) -> usize { + let mut inserted = 0; + for i in 0..count { + if insert(table, format!("value_{}", i)) { + inserted += 1; + } + } + inserted + } + + { + let mut table = + TupleHashTable::::new(8, ResizeFactor::X2, 1.0, DEFAULT_UPDATE_SEED); + + assert_eq!(table.entries.len(), 32); + + // Insert enough values to trigger resize (50% threshold) + // Capacity = 32 * 0.5 = 16 + let inserted = populate_values(&mut table, 20); + + assert!(table.num_retained() > 0); + assert_eq!(table.num_retained(), inserted); + assert_eq!(table.entries.len(), 64); + } + + { + let mut table = + TupleHashTable::::new(8, ResizeFactor::X4, 1.0, DEFAULT_UPDATE_SEED); + + assert_eq!(table.entries.len(), 32); + + let inserted = populate_values(&mut table, 20); + + assert!(table.num_retained() > 0); + assert_eq!(table.num_retained(), inserted); + assert_eq!(table.entries.len(), 128); + } + } + + #[test] + fn test_rebuild() { + let mut table = TupleHashTable::::new(5, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + assert_eq!(table.lg_cur_size, 6); + assert_eq!(table.entries.len(), 64); + assert_eq!(table.theta, MAX_THETA); + + // Insert many values to trigger rebuild + for i in 0..100 { + insert(&mut table, format!("value_{}", i)); + } + + let new_theta = table.theta(); + assert!( + new_theta < MAX_THETA, + "Theta should be reduced after rebuild" + ); + + // Continue to insert values to trigger rebuild again + for i in 100..200 { + insert(&mut table, format!("value_{}", i)); + } + + assert_eq!(table.lg_cur_size, 6); + assert!(table.entries.len() >= 64); + assert!(table.theta < new_theta); + } + + #[test] + fn test_trim() { + let mut table = TupleHashTable::::new(5, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + for i in 0..100 { + insert(&mut table, format!("value_{}", i)); + } + + let before_trim = table.num_retained(); + assert!(before_trim > 32); + + table.trim(); + let after_trim = table.num_retained(); + assert!(after_trim <= 32); + assert!(table.theta() < MAX_THETA); + } + + #[test] + fn test_trim_when_not_needed() { + let mut table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + for i in 0..10 { + insert(&mut table, format!("value_{}", i)); + } + + let before_trim = table.num_retained(); + let before_theta = table.theta(); + table.trim(); + let after_trim = table.num_retained(); + + assert_eq!(before_trim, after_trim); + assert_eq!(before_theta, table.theta()); + } + + #[test] + fn test_reset() { + let mut table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + let init_theta = table.theta(); + let init_lg_cur = table.lg_cur_size; + let init_entries = table.entries.len(); + + for i in 0..10 { + insert(&mut table, format!("value_{}", i)); + } + + assert!(!table.is_empty()); + assert!(table.num_retained() > 0); + + table.reset(); + + assert!(table.is_empty()); + assert_eq!(table.num_retained(), 0); + assert_eq!(table.theta(), init_theta); + assert_eq!(table.lg_cur_size, init_lg_cur); + assert_eq!(table.entries.len(), init_entries); + assert_eq!(table.iter().count(), 0); + } + + #[test] + fn test_table_with_sampling() { + let mut table = TupleHashTable::::new( + 8, + ResizeFactor::X8, + 0.5, // sampling_probability = 0.5 + DEFAULT_UPDATE_SEED, + ); + assert_eq!(table.theta(), (MAX_THETA as f64 * 0.5) as u64); + + for i in 0..10 { + insert(&mut table, format!("value_{}", i)); + } + + table.reset(); + + assert_eq!(table.theta(), (MAX_THETA as f64 * 0.5) as u64); + assert!(table.is_empty()); + } + + #[test] + fn test_iterator() { + let mut table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + let mut inserted_hashes = vec![]; + for i in 0..10 { + let hash = table.hash(i); + if insert(&mut table, i) { + inserted_hashes.push(hash); + } + } + + let iter_hashes: Vec = table.iter().map(|(hash, _)| hash).collect(); + assert_eq!(iter_hashes.len(), table.num_retained()); + assert_eq!(iter_hashes.len(), inserted_hashes.len()); + + for hash in &inserted_hashes { + assert!(iter_hashes.contains(hash)); + } + + assert!(!iter_hashes.contains(&0)); + } + + #[test] + fn test_empty_table_operations() { + let mut table = TupleHashTable::::new(8, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + + assert!(table.is_empty()); + assert_eq!(table.num_retained(), 0); + assert_eq!(table.iter().count(), 0); + + // Trim on empty table should not panic + table.trim(); + assert!(table.is_empty()); + + // Reset on empty table should not panic + table.reset(); + assert!(table.is_empty()); + } + + #[test] + fn test_rebuild_preserves_entries_less_than_kth() { + let mut table = TupleHashTable::::new(5, ResizeFactor::X8, 1.0, DEFAULT_UPDATE_SEED); + let k = 1u64 << 5; // k = 32 + + // Insert many values to trigger rebuild + let mut i = 0; + let mut inserted_hashes = vec![]; + loop { + let hash = table.hash(i); + i += 1; + if insert(&mut table, i - 1) { + inserted_hashes.push(hash); + } + if table.num_retained() >= k as usize { + break; + } + } + + let rebuild_threshold = table.get_capacity(); + + loop { + let hash = table.hash(i); + i += 1; + if insert(&mut table, i - 1) { + inserted_hashes.push(hash); + } + if table.num_retained() >= rebuild_threshold { + break; + } + } + + // trigger rebuild + loop { + let hash = table.hash(i); + i += 1; + if insert(&mut table, i - 1) { + inserted_hashes.push(hash); + break; + } + } + + // assert all entries are less than kth + inserted_hashes.sort(); + let kth = inserted_hashes[k as usize]; + assert!(table.iter().all(|(hash, _)| hash < kth)); + assert_eq!(table.theta(), kth); + } +} diff --git a/datasketches/src/tuple/intersection.rs b/datasketches/src/tuple/intersection.rs new file mode 100644 index 0000000..70fa287 --- /dev/null +++ b/datasketches/src/tuple/intersection.rs @@ -0,0 +1,522 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tuple sketch intersection. +//! +//! [`TupleIntersection`] computes the intersection (set AND) of Tuple sketches. The hash-table +//! bookkeeping mirrors the [Theta intersection](crate::theta), with one Tuple-specific addition: +//! for each key retained in both the running result and the incoming sketch, the two summaries are +//! combined with a [`SummaryCombinePolicy`]. +//! +//! Unlike the union there is no default policy: how to combine the summaries of keys present in +//! both inputs is application-specific, so a policy must always be supplied. + +use crate::common::ResizeFactor; +use crate::error::Error; +use crate::hash::DEFAULT_UPDATE_SEED; +use crate::theta::HASH_TABLE_REBUILD_THRESHOLD; +use crate::theta::MAX_THETA; +use crate::tuple::hash_table::TupleHashTable; +use crate::tuple::policy::SummaryCombinePolicy; +use crate::tuple::sketch::CompactTupleSketch; +use crate::tuple::sketch::TupleSketchView; + +/// Stateful intersection operator for Tuple sketches. +/// +/// `S` is the summary type and `P` is the [`SummaryCombinePolicy`] applied to keys present in more +/// than one input. There is no default policy (see the module docs), so one must be supplied at +/// construction. +/// +/// Before the first [`update`](Self::update), the result is undefined; use +/// [`has_result`](Self::has_result) to check. +/// +/// # Examples +/// +/// ``` +/// use datasketches::tuple::SummaryCombinePolicy; +/// use datasketches::tuple::TupleIntersection; +/// use datasketches::tuple::UpdatableTupleSketch; +/// +/// // Sum the summaries of keys that appear in both inputs. +/// #[derive(Default)] +/// struct SumPolicy; +/// impl SummaryCombinePolicy for SumPolicy { +/// fn combine(&self, summary: &mut u64, other: &u64) { +/// *summary += *other; +/// } +/// } +/// +/// let mut a = UpdatableTupleSketch::::builder().build(); +/// a.update("shared", 3); +/// a.update("only_a", 1); +/// +/// let mut b = UpdatableTupleSketch::::builder().build(); +/// b.update("shared", 4); +/// b.update("only_b", 1); +/// +/// let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); +/// intersection.update(&a).unwrap(); +/// intersection.update(&b).unwrap(); +/// +/// let result = intersection.result(); +/// assert_eq!(result.num_retained(), 1); // only "shared" +/// assert_eq!(result.iter().next().unwrap().1, &7); // 3 + 4 +/// ``` +#[derive(Debug)] +pub struct TupleIntersection { + is_valid: bool, + table: TupleHashTable, + policy: P, +} + +impl TupleIntersection { + /// Creates a new intersection operator for the given `seed` and combine `policy`. + pub fn new(seed: u64, policy: P) -> Self { + Self { + is_valid: false, + table: TupleHashTable::from_raw_parts( + 0, + 0, + ResizeFactor::X1, + 1.0, + MAX_THETA, + seed, + false, + ), + policy, + } + } + + /// Creates a new intersection operator with the default seed and the given combine `policy`. + pub fn new_with_default_seed(policy: P) -> Self { + Self::new(DEFAULT_UPDATE_SEED, policy) + } + + /// Updates the intersection with a given sketch. + /// + /// The intersection can be viewed as starting from the "universe" set, and every update reduces + /// the current set to the keys it shares with `sketch`. Summaries of shared keys are combined + /// via the policy. + /// + /// # Errors + /// + /// Returns an error if `sketch` (when non-empty) has a different seed hash, or if the input + /// appears corrupted (entry counts do not match what the sketch reports). + pub fn update(&mut self, sketch: &V) -> Result<(), Error> + where + V: TupleSketchView, + P: SummaryCombinePolicy, + S: Clone, + { + let new_default_table = |table: &TupleHashTable| { + TupleHashTable::from_raw_parts( + 0, + 0, + ResizeFactor::X1, + 1.0, + table.theta(), + table.hash_seed(), + table.is_empty(), + ) + }; + + if self.table.is_empty() { + return Ok(()); + } + + if !sketch.is_empty() && sketch.seed_hash() != self.table.seed_hash() { + return Err(Error::invalid_argument(format!( + "incompatible seed hash: expected {}, got {}", + self.table.seed_hash(), + sketch.seed_hash() + ))); + } + + if sketch.is_empty() { + self.table.set_empty(true); + } + + self.table.set_theta(if self.table.is_empty() { + MAX_THETA + } else { + self.table.theta().min(sketch.theta64()) + }); + + if self.is_valid && self.table.num_retained() == 0 { + return Ok(()); + } + + if sketch.num_retained() == 0 { + self.is_valid = true; + self.table = new_default_table(&self.table); + return Ok(()); + } + + // first update, copy the incoming sketch's entries (hash + summary) + if !self.is_valid { + self.is_valid = true; + let lg_size = TupleHashTable::::lg_size_from_count_for_rebuild( + sketch.num_retained(), + HASH_TABLE_REBUILD_THRESHOLD, + ); + let mut new_table = TupleHashTable::from_raw_parts( + lg_size, + lg_size - 1, + ResizeFactor::X1, + 1.0, + self.table.theta(), + self.table.hash_seed(), + self.table.is_empty(), + ); + for (hash, summary) in sketch.iter() { + if !new_table.try_insert(hash, summary.clone()) { + return Err(Error::invalid_argument( + "duplicate key, possibly corrupted input sketch", + )); + } + } + // Safety check. + if new_table.num_retained() != sketch.num_retained() { + return Err(Error::invalid_argument( + "num entries mismatch, possibly corrupted input sketch", + )); + } + self.table = new_table; + } else { + let max_matches = self.table.num_retained().min(sketch.num_retained()); + let mut matched_entries: Vec<(u64, S)> = Vec::with_capacity(max_matches); + let mut count = 0; + let policy = &self.policy; + for (hash, incoming) in sketch.iter() { + if hash < self.table.theta() { + if let Some(existing) = self.table.get(hash) { + if matched_entries.len() == max_matches { + return Err(Error::invalid_argument( + "max matches exceeded, possibly corrupted input sketch", + )); + } + let mut combined = existing.clone(); + policy.combine(&mut combined, incoming); + matched_entries.push((hash, combined)); + } + } else if sketch.is_ordered() { + break; // early stop for ordered sketches + } + count += 1; + } + // Safety check. + if count > sketch.num_retained() { + return Err(Error::invalid_argument( + "more keys than expected, possibly corrupted input sketch", + )); + } else if !sketch.is_ordered() && count < sketch.num_retained() { + return Err(Error::invalid_argument( + "fewer keys than expected, possibly corrupted input sketch", + )); + } + if matched_entries.is_empty() { + self.table = new_default_table(&self.table); + if self.table.theta() == MAX_THETA { + self.table.set_empty(true); + } + } else { + let lg_size = TupleHashTable::::lg_size_from_count_for_rebuild( + matched_entries.len(), + HASH_TABLE_REBUILD_THRESHOLD, + ); + let mut new_table = TupleHashTable::from_raw_parts( + lg_size, + lg_size - 1, + ResizeFactor::X1, + 1.0, + self.table.theta(), + self.table.hash_seed(), + self.table.is_empty(), + ); + for (hash, summary) in matched_entries { + if !new_table.try_insert(hash, summary) { + return Err(Error::invalid_argument( + "duplicate key, possibly corrupted input sketch", + )); + } + } + self.table = new_table; + } + } + Ok(()) + } + + /// Returns whether this operator has received at least one update. + pub fn has_result(&self) -> bool { + self.is_valid + } + + /// Returns the intersection result as a compact Tuple sketch (ordered). + /// + /// # Panics + /// + /// Panics if called before the first [`update`](Self::update). + pub fn result(&self) -> CompactTupleSketch + where + S: Clone, + { + self.result_with_ordered(true) + } + + /// Returns the intersection result as a compact Tuple sketch. + /// + /// # Panics + /// + /// Panics if called before the first [`update`](Self::update). + pub fn result_with_ordered(&self, ordered: bool) -> CompactTupleSketch + where + S: Clone, + { + assert!( + self.is_valid, + "TupleIntersection::result() called before first update()" + ); + let mut entries: Vec<(u64, S)> = self + .table + .iter() + .map(|(hash, summary)| (hash, summary.clone())) + .collect(); + if ordered { + entries.sort_unstable_by_key(|(hash, _)| *hash); + } + CompactTupleSketch::from_parts( + entries, + self.table.theta(), + self.table.seed_hash(), + ordered, + self.table.is_empty(), + ) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::tuple::UpdatableTupleSketch; + + #[derive(Debug, Default, Clone, Copy)] + struct SumPolicy; + + impl SummaryCombinePolicy for SumPolicy { + fn combine(&self, summary: &mut u64, other: &u64) { + *summary += *other; + } + } + + fn sorted_entries(sketch: &CompactTupleSketch) -> Vec<(u64, u64)> { + let mut entries: Vec<(u64, u64)> = sketch.iter().map(|(h, &s)| (h, s)).collect(); + entries.sort_unstable(); + entries + } + + #[test] + fn intersection_of_overlapping_sketches() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 500..1500 { + b.update(i, 1u64); + } + + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + intersection.update(&a).unwrap(); + intersection.update(&b).unwrap(); + + let result = intersection.result(); + // Keys 500..1000 are shared (exact mode), each summary is 1 + 1 = 2. + assert_eq!(result.num_retained(), 500); + assert_eq!(result.estimate(), 500.0); + assert!(result.iter().all(|(_, &s)| s == 2)); + } + + #[test] + fn intersection_combines_summaries_of_shared_keys() { + let mut a = UpdatableTupleSketch::::builder().build(); + a.update("shared", 3u64); + a.update("only_a", 100u64); + let mut b = UpdatableTupleSketch::::builder().build(); + b.update("shared", 4u64); + b.update("only_b", 200u64); + + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + intersection.update(&a).unwrap(); + intersection.update(&b).unwrap(); + + let result = intersection.result(); + assert_eq!(sorted_entries(&result).len(), 1); + assert_eq!(result.iter().next().unwrap().1, &7); // 3 + 4 + } + + #[test] + fn intersection_is_order_independent() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 500..1500 { + b.update(i, 1u64); + } + + let mut a_then_b = TupleIntersection::::new_with_default_seed(SumPolicy); + a_then_b.update(&a).unwrap(); + a_then_b.update(&b).unwrap(); + + let mut b_then_a = TupleIntersection::::new_with_default_seed(SumPolicy); + b_then_a.update(&b).unwrap(); + b_then_a.update(&a).unwrap(); + + assert_eq!( + sorted_entries(&a_then_b.result()), + sorted_entries(&b_then_a.result()) + ); + } + + #[test] + fn intersection_accepts_updatable_and_compact_inputs() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 500..1500 { + b.update(i, 1u64); + } + let b_compact = b.compact(true); + + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + intersection.update(&a).unwrap(); + intersection.update(&b_compact).unwrap(); + + assert_eq!(intersection.result().num_retained(), 500); + } + + #[test] + fn intersection_with_disjoint_sketches_is_empty() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 1000..2000 { + b.update(i, 1u64); + } + + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + intersection.update(&a).unwrap(); + intersection.update(&b).unwrap(); + + let result = intersection.result(); + assert_eq!(result.num_retained(), 0); + assert_eq!(result.estimate(), 0.0); + } + + #[test] + fn intersection_with_empty_input_is_empty() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let empty = UpdatableTupleSketch::::builder().build(); + + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + intersection.update(&a).unwrap(); + intersection.update(&empty).unwrap(); + + let result = intersection.result(); + assert!(result.is_empty()); + assert_eq!(result.num_retained(), 0); + assert_eq!(result.estimate(), 0.0); + } + + #[test] + fn intersection_single_update_returns_input() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..100 { + a.update(i, 5u64); + } + + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + intersection.update(&a).unwrap(); + + let result = intersection.result(); + assert_eq!(result.num_retained(), 100); + // A single update copies the input unchanged (summaries not combined with anything). + assert!(result.iter().all(|(_, &s)| s == 5)); + } + + #[test] + fn has_result_reflects_first_update() { + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + assert!(!intersection.has_result()); + + let mut a = UpdatableTupleSketch::::builder().build(); + a.update(1, 1u64); + intersection.update(&a).unwrap(); + assert!(intersection.has_result()); + } + + #[test] + #[should_panic(expected = "before first update")] + fn result_before_update_panics() { + let intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + let _ = intersection.result(); + } + + #[test] + fn intersection_rejects_seed_mismatch() { + let mut a = UpdatableTupleSketch::::builder().seed(1).build(); + a.update(1, 1u64); + + let mut intersection = TupleIntersection::::new(2, SumPolicy); + let err = intersection.update(&a).unwrap_err(); + assert_eq!(err.kind(), crate::error::ErrorKind::InvalidArgument); + } + + #[test] + fn intersection_in_estimation_mode_estimates_within_bounds() { + let mut a = UpdatableTupleSketch::::builder().lg_k(8).build(); + for i in 0..50000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().lg_k(8).build(); + for i in 25000..75000 { + b.update(i, 1u64); + } + + let mut intersection = TupleIntersection::::new_with_default_seed(SumPolicy); + intersection.update(&a).unwrap(); + intersection.update(&b).unwrap(); + + let result = intersection.result(); + assert!(result.is_estimation_mode()); + // True intersection size is 25000 (keys 25000..50000). + let lower = result.lower_bound(crate::common::NumStdDev::Three); + let upper = result.upper_bound(crate::common::NumStdDev::Three); + assert!( + lower <= 25000.0 && 25000.0 <= upper, + "expected 25000 in [{lower}, {upper}]" + ); + } +} diff --git a/datasketches/src/tuple/mod.rs b/datasketches/src/tuple/mod.rs new file mode 100644 index 0000000..7470134 --- /dev/null +++ b/datasketches/src/tuple/mod.rs @@ -0,0 +1,60 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tuple sketch implementation. +//! +//! A Tuple sketch is an extension of the [Theta sketch](crate::theta): in addition to the retained +//! hash values it keeps a user-defined summary associated with every retained key. The hash table +//! mechanics (theta screening, resize, rebuild to nominal size k) mirror the Theta sketch, with the +//! added requirement that colliding keys merge their summaries. +//! +//! The behavior of a summary (how to create, update, and combine it) is supplied externally through +//! policy objects ([`SummaryUpdatePolicy`] and [`SummaryCombinePolicy`]) rather than being baked +//! into the summary type. +//! +//! # Usage +//! +//! ``` +//! # use datasketches::tuple::UpdatableTupleSketch; +//! let mut sketch = UpdatableTupleSketch::::builder().build(); +//! sketch.update("apple", 1); +//! assert!(sketch.estimate() >= 1.0); +//! ``` + +mod a_not_b; +mod hash_table; +mod intersection; +mod policy; +mod serde; +mod serialization; +mod sketch; +mod union; + +pub use self::a_not_b::TupleAnotB; +pub use self::intersection::TupleIntersection; +pub use self::policy::DefaultUnionPolicy; +pub use self::policy::DefaultUpdatePolicy; +pub use self::policy::SummaryCombinePolicy; +pub use self::policy::SummaryUpdatePolicy; +pub use self::serde::PrimitiveSummarySerde; +pub use self::serde::SummarySerde; +pub use self::sketch::CompactTupleSketch; +pub use self::sketch::TupleSketchView; +pub use self::sketch::UpdatableTupleSketch; +pub use self::sketch::UpdatableTupleSketchBuilder; +pub use self::union::TupleUnion; +pub use self::union::TupleUnionBuilder; diff --git a/datasketches/src/tuple/policy.rs b/datasketches/src/tuple/policy.rs new file mode 100644 index 0000000..b641b10 --- /dev/null +++ b/datasketches/src/tuple/policy.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Policies describing how summaries are created, updated, and combined. +//! +//! A Tuple sketch keeps a user-defined summary `S` next to every retained key. The behavior of a +//! summary is supplied externally through policy objects rather than baked into the summary type +//! itself, so the same summary type (for example a plain `u64` or a `Vec`) can be driven by +//! different behaviors and can carry per-instance configuration (such as the number of values in an +//! array-of-doubles summary). +//! +//! This mirrors the policy approach used by the C++ implementation +//! (`default_tuple_update_policy`, `default_tuple_union_policy`) and the Java +//! `SummaryFactory` / `SummarySetOperations` interfaces. + +use std::ops::AddAssign; + +/// Defines how a summary is created and how update values are folded into it. +/// +/// This is used by the update tuple sketch. `S` is the stored summary type and `U` is the type of +/// the update value, which may be a borrowed type such as `&[f64]`. +/// +/// Corresponds to C++ `default_tuple_update_policy` and Java +/// `UpdatableSummary` together with `SummaryFactory`. +pub trait SummaryUpdatePolicy { + /// Creates a new summary for a key seen for the first time. + /// + /// The summary should be in its identity state; the first update value is folded in separately + /// via [`update`](Self::update). + fn create(&self) -> S; + + /// Folds an update value into an existing summary. + fn update(&self, summary: &mut S, value: U); +} + +/// Defines how two summaries that share the same key are combined. +/// +/// This is used by both union and intersection. Each operator is given its own policy instance, +/// because the two operations may combine summaries differently for the same summary type. +/// +/// Corresponds to C++ `default_tuple_union_policy` / the tuple intersection policy and Java +/// `SummarySetOperations`. +pub trait SummaryCombinePolicy { + /// Combines `other` into `summary` in place. + fn combine(&self, summary: &mut S, other: &S); +} + +/// Default update policy for summaries that are default-constructible and additive. +/// +/// This is the convenience policy used when no custom policy is supplied, equivalent to C++ +/// `default_tuple_update_policy` (which folds updates with `summary += update`). It is available +/// for any summary type `S` and update type `U` where `S: Default + AddAssign`. +#[derive(Debug, Default, Clone, Copy)] +pub struct DefaultUpdatePolicy; + +impl SummaryUpdatePolicy for DefaultUpdatePolicy +where + S: Default + AddAssign, +{ + fn create(&self) -> S { + S::default() + } + + fn update(&self, summary: &mut S, value: U) { + *summary += value; + } +} + +/// Default combine policy for additive summaries, used by the union when no custom policy is given. +/// +/// This is equivalent to C++ `default_tuple_union_policy`, which combines two summaries with +/// `summary += other`. It is available for any summary type `S` where `S: AddAssign<&S>`. +/// +/// There is intentionally no default combine policy for the intersection: how to combine summaries +/// of the keys present in both inputs is application-specific, so the intersection always requires +/// an explicit policy. +#[derive(Debug, Default, Clone, Copy)] +pub struct DefaultUnionPolicy; + +impl SummaryCombinePolicy for DefaultUnionPolicy +where + for<'a> S: AddAssign<&'a S>, +{ + fn combine(&self, summary: &mut S, other: &S) { + *summary += other; + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn default_update_policy_update_accumulates() { + let policy = DefaultUpdatePolicy; + let mut summary = 0u64; + policy.update(&mut summary, 3); + policy.update(&mut summary, 4); + assert_eq!(summary, 7); + } + + /// A non-trivial custom policy (keeps the maximum) to exercise the traits beyond the additive + /// default. + #[derive(Debug, Default, Clone, Copy)] + struct MaxPolicy; + + impl SummaryUpdatePolicy for MaxPolicy { + fn create(&self) -> u64 { + 0 + } + + fn update(&self, summary: &mut u64, value: u64) { + *summary = (*summary).max(value); + } + } + + impl SummaryCombinePolicy for MaxPolicy { + fn combine(&self, summary: &mut u64, other: &u64) { + self.update(summary, *other); + } + } + + #[test] + fn custom_update_policy_keeps_max() { + let policy = MaxPolicy; + let mut summary = policy.create(); + policy.update(&mut summary, 3); + policy.update(&mut summary, 7); + policy.update(&mut summary, 2); + assert_eq!(summary, 7); + } + + #[test] + fn custom_combine_policy_keeps_max() { + let policy = MaxPolicy; + let mut summary = 5u64; + policy.combine(&mut summary, &10); + policy.combine(&mut summary, &7); + assert_eq!(summary, 10); + } + + #[test] + fn default_union_policy_combines_additively() { + let policy = DefaultUnionPolicy; + let mut summary = 5u64; + policy.combine(&mut summary, &10); + policy.combine(&mut summary, &7); + assert_eq!(summary, 22); + } +} diff --git a/datasketches/src/tuple/serde.rs b/datasketches/src/tuple/serde.rs new file mode 100644 index 0000000..376701d --- /dev/null +++ b/datasketches/src/tuple/serde.rs @@ -0,0 +1,164 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Serialization of summaries. +//! +//! A Tuple sketch stores a user-defined summary next to every retained key. Because the summary +//! type is opaque to the sketch, (de)serialization of summaries is delegated to a [`SummarySerde`] +//! object, mirroring the C++ `SerDe` template parameter and the Java `SummarySerializer` / +//! `SummaryDeserializer` interfaces. + +use crate::error::Error; + +/// Serializes and deserializes a summary of type `S`. +/// +/// The encoding is entirely up to the implementation; the sketch only requires that +/// [`deserialize`](Self::deserialize) report how many bytes it consumed so it can advance to the +/// next entry. This supports both fixed-width summaries (such as a `u64` counter) and +/// variable-width summaries (such as an array of doubles whose length is encoded in the bytes). +pub trait SummarySerde { + /// Appends the serialized form of `summary` to `out`. + fn serialize(&self, summary: &S, out: &mut Vec); + + /// Reads one summary from the front of `bytes`, returning it together with the number of bytes + /// consumed. + /// + /// # Errors + /// + /// Returns an error if `bytes` is too short or otherwise malformed for this encoding. + fn deserialize(&self, bytes: &[u8]) -> Result<(S, usize), Error>; +} + +/// A [`SummarySerde`] for fixed-width little-endian primitive summaries. +/// +/// This covers the common case where the summary is a single integer or float (`u32`, `u64`, `i32`, +/// `i64`, `f32`, `f64`). The value is stored in little-endian byte order, matching the Java/C++ +/// primitive serializers. +/// +/// # Examples +/// +/// ``` +/// use datasketches::tuple::PrimitiveSummarySerde; +/// use datasketches::tuple::SummarySerde; +/// +/// let serde = PrimitiveSummarySerde; +/// let mut bytes = Vec::new(); +/// serde.serialize(&7u64, &mut bytes); +/// assert_eq!(bytes.len(), 8); +/// let (value, consumed) = serde.deserialize(&bytes).unwrap(); +/// assert_eq!((value, consumed), (7u64, 8)); +/// ``` +#[derive(Debug, Default, Clone, Copy)] +pub struct PrimitiveSummarySerde; + +impl SummarySerde for PrimitiveSummarySerde +where + S: private::LeBytes, +{ + fn serialize(&self, summary: &S, out: &mut Vec) { + summary.write_le(out); + } + + fn deserialize(&self, bytes: &[u8]) -> Result<(S, usize), Error> { + if bytes.len() < S::WIDTH { + return Err(Error::insufficient_data(format!( + "summary: expected {} bytes, got {}", + S::WIDTH, + bytes.len() + ))); + } + Ok((S::read_le(&bytes[..S::WIDTH]), S::WIDTH)) + } +} + +mod private { + /// Sealed helper describing fixed-width little-endian encoding of a primitive. + pub trait LeBytes: Copy { + /// Number of bytes in the little-endian encoding. + const WIDTH: usize; + /// Appends the little-endian bytes of `self` to `out`. + fn write_le(self, out: &mut Vec); + /// Reads the value from exactly `WIDTH` leading bytes of `bytes`. + fn read_le(bytes: &[u8]) -> Self; + } + + macro_rules! impl_le_bytes { + ($($t:ty),* $(,)?) => { + $( + impl LeBytes for $t { + const WIDTH: usize = std::mem::size_of::<$t>(); + + fn write_le(self, out: &mut Vec) { + out.extend_from_slice(&self.to_le_bytes()); + } + + fn read_le(bytes: &[u8]) -> Self { + let mut buf = [0u8; std::mem::size_of::<$t>()]; + buf.copy_from_slice(&bytes[..std::mem::size_of::<$t>()]); + <$t>::from_le_bytes(buf) + } + } + )* + }; + } + + impl_le_bytes!(u32, u64, i32, i64, f32, f64); +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn primitive_serde_round_trips_u64() { + let serde = PrimitiveSummarySerde; + let mut bytes = Vec::new(); + serde.serialize(&123456789u64, &mut bytes); + assert_eq!(bytes.len(), 8); + let (value, consumed): (u64, usize) = serde.deserialize(&bytes).unwrap(); + assert_eq!(value, 123456789); + assert_eq!(consumed, 8); + } + + #[test] + fn primitive_serde_round_trips_f64() { + let serde = PrimitiveSummarySerde; + let mut bytes = Vec::new(); + serde.serialize(&3.5f64, &mut bytes); + let (value, consumed): (f64, usize) = serde.deserialize(&bytes).unwrap(); + assert_eq!(value, 3.5); + assert_eq!(consumed, 8); + } + + #[test] + fn primitive_serde_consumes_only_its_width() { + let serde = PrimitiveSummarySerde; + let mut bytes = Vec::new(); + serde.serialize(&9u32, &mut bytes); + bytes.extend_from_slice(&[0xAA, 0xBB]); // trailing bytes belonging to the next entry + let (value, consumed): (u32, usize) = serde.deserialize(&bytes).unwrap(); + assert_eq!(value, 9); + assert_eq!(consumed, 4); + } + + #[test] + fn primitive_serde_rejects_short_input() { + let serde = PrimitiveSummarySerde; + let err = SummarySerde::::deserialize(&serde, &[0u8; 3]).unwrap_err(); + assert_eq!(err.kind(), crate::error::ErrorKind::InvalidData); + } +} diff --git a/datasketches/src/tuple/serialization.rs b/datasketches/src/tuple/serialization.rs new file mode 100644 index 0000000..9f431ee --- /dev/null +++ b/datasketches/src/tuple/serialization.rs @@ -0,0 +1,37 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Binary serialization format constants for Tuple sketches. +//! +//! The Tuple compact format reuses the uncompressed Theta layout (preamble, flags, theta) but uses +//! the Tuple family id, carries a sketch-type byte, and stores the user summary bytes after each +//! retained hash. See the C++/Java reference implementations for the on-disk format. + +/// Current serial version written by this implementation. +pub(super) const SERIAL_VERSION: u8 = 3; +/// Legacy serial version still accepted on read. +pub(super) const SERIAL_VERSION_LEGACY: u8 = 1; + +/// Current sketch-type byte written by this implementation. +pub(super) const SKETCH_TYPE: u8 = 1; +/// Legacy sketch-type byte still accepted on read. +pub(super) const SKETCH_TYPE_LEGACY: u8 = 5; + +pub(super) const FLAGS_IS_READ_ONLY: u8 = 1 << 1; +pub(super) const FLAGS_IS_EMPTY: u8 = 1 << 2; +pub(super) const FLAGS_IS_COMPACT: u8 = 1 << 3; +pub(super) const FLAGS_IS_ORDERED: u8 = 1 << 4; diff --git a/datasketches/src/tuple/sketch.rs b/datasketches/src/tuple/sketch.rs new file mode 100644 index 0000000..4e41125 --- /dev/null +++ b/datasketches/src/tuple/sketch.rs @@ -0,0 +1,1045 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tuple sketch types. +//! +//! This module provides [`UpdatableTupleSketch`] (mutable) and [`CompactTupleSketch`] (immutable), +//! the Tuple sketch analogues of the Theta sketch. Each retained key carries a user-defined summary +//! whose behavior is supplied by a [`SummaryUpdatePolicy`]. + +use std::hash::Hash; +use std::marker::PhantomData; + +use crate::codec::SketchBytes; +use crate::codec::SketchSlice; +use crate::codec::assert::ensure_preamble_longs_in_range; +use crate::codec::assert::insufficient_data; +use crate::codec::family::Family; +use crate::common::NumStdDev; +use crate::common::ResizeFactor; +use crate::common::binomial_bounds; +use crate::error::Error; +use crate::hash::DEFAULT_UPDATE_SEED; +use crate::hash::compute_seed_hash; +use crate::theta::DEFAULT_LG_K; +use crate::theta::MAX_LG_K; +use crate::theta::MAX_THETA; +use crate::theta::MIN_LG_K; +use crate::tuple::hash_table::TupleHashTable; +use crate::tuple::policy::DefaultUpdatePolicy; +use crate::tuple::policy::SummaryUpdatePolicy; +use crate::tuple::serde::SummarySerde; +use crate::tuple::serialization; + +mod private { + use super::*; + + // Sealed trait to prevent external implementations of TupleSketchView. + pub trait Sealed {} + + impl Sealed for UpdatableTupleSketch {} + impl Sealed for CompactTupleSketch {} +} + +/// Read-only view for Tuple sketches. +/// +/// This trait provides a unified input abstraction for APIs (such as union and intersection) that +/// can accept either a mutable [`UpdatableTupleSketch`] or an immutable [`CompactTupleSketch`]. +/// `S` is the summary type retained by the sketch. +pub trait TupleSketchView: private::Sealed { + /// Returns the 16-bit seed hash. + fn seed_hash(&self) -> u16; + + /// Returns theta as `u64`. + fn theta64(&self) -> u64; + + /// Returns true if this sketch is empty. + fn is_empty(&self) -> bool; + + /// Returns an iterator over retained entries as `(hash, &summary)` pairs. + fn iter<'a>(&'a self) -> impl Iterator + 'a + where + S: 'a; + + /// Returns the number of retained entries. + fn num_retained(&self) -> usize; + + /// Returns whether retained entries are ordered in ascending order by hash. + fn is_ordered(&self) -> bool { + false + } +} + +/// Mutable Tuple sketch for building from input data. +/// +/// `S` is the summary type retained alongside each key, and `P` is the [`SummaryUpdatePolicy`] that +/// defines how summaries are created and updated. For additive summaries the default +/// [`DefaultUpdatePolicy`] is used. +/// +/// # Examples +/// +/// ``` +/// # use datasketches::tuple::UpdatableTupleSketch; +/// let mut sketch = UpdatableTupleSketch::::builder().build(); +/// sketch.update("apple", 1); +/// sketch.update("apple", 1); +/// assert!(sketch.estimate() >= 1.0); +/// assert_eq!(sketch.num_retained(), 1); +/// ``` +#[derive(Debug)] +pub struct UpdatableTupleSketch { + table: TupleHashTable, + policy: P, +} + +impl UpdatableTupleSketch { + /// Creates a new builder using the default update policy. + /// + /// # Examples + /// + /// ``` + /// # use datasketches::tuple::UpdatableTupleSketch; + /// let sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + /// assert_eq!(sketch.lg_k(), 12); + /// ``` + pub fn builder() -> UpdatableTupleSketchBuilder { + UpdatableTupleSketchBuilder::default() + } +} + +impl UpdatableTupleSketch { + /// Updates the sketch with a key and an update value. + /// + /// If the key is new, the policy creates a summary and folds in `value`; if the key already + /// exists, `value` is folded into the retained summary. Updates screened out by theta do not + /// change any summary. + /// + /// # Examples + /// + /// ``` + /// # use datasketches::tuple::UpdatableTupleSketch; + /// let mut sketch = UpdatableTupleSketch::::builder().build(); + /// sketch.update(42, 5); + /// ``` + pub fn update(&mut self, key: impl Hash, value: U) + where + P: SummaryUpdatePolicy, + { + let policy = &self.policy; + self.table.update(key, |existing| match existing { + Some(summary) => { + policy.update(summary, value); + None + } + None => { + let mut summary = policy.create(); + policy.update(&mut summary, value); + Some(summary) + } + }); + } + + /// Returns the cardinality (distinct key count) estimate. + pub fn estimate(&self) -> f64 { + if self.is_empty() { + return 0.0; + } + let num_retained = self.table.num_retained() as f64; + let theta = self.table.theta() as f64 / MAX_THETA as f64; + num_retained / theta + } + + /// Returns theta as a fraction (0.0 to 1.0). + pub fn theta(&self) -> f64 { + self.table.theta() as f64 / MAX_THETA as f64 + } + + /// Returns theta as `u64`. + pub fn theta64(&self) -> u64 { + self.table.theta() + } + + /// Returns the 16-bit seed hash. + pub fn seed_hash(&self) -> u16 { + self.table.seed_hash() + } + + /// Returns true if the sketch is empty. + pub fn is_empty(&self) -> bool { + self.table.is_empty() + } + + /// Returns true if the sketch is in estimation mode. + pub fn is_estimation_mode(&self) -> bool { + self.table.theta() < MAX_THETA + } + + /// Returns the number of retained entries. + pub fn num_retained(&self) -> usize { + self.table.num_retained() + } + + /// Returns lg_k (log2 of the nominal size k). + pub fn lg_k(&self) -> u8 { + self.table.lg_nom_size() + } + + /// Trims the sketch to the nominal size k. + pub fn trim(&mut self) { + self.table.trim(); + } + + /// Resets the sketch to the empty state. + pub fn reset(&mut self) { + self.table.reset(); + } + + /// Returns an iterator over retained entries as `(hash, &summary)` pairs. + pub fn iter(&self) -> impl Iterator + '_ { + self.table.iter() + } + + /// Returns the approximate lower error bound given the number of standard deviations. + pub fn lower_bound(&self, num_std_dev: NumStdDev) -> f64 { + if !self.is_estimation_mode() { + return self.num_retained() as f64; + } + binomial_bounds::lower_bound(self.num_retained() as u64, self.theta(), num_std_dev) + .expect("theta should always be valid") + } + + /// Returns the approximate upper error bound given the number of standard deviations. + pub fn upper_bound(&self, num_std_dev: NumStdDev) -> f64 { + if !self.is_estimation_mode() { + return self.num_retained() as f64; + } + binomial_bounds::upper_bound( + self.num_retained() as u64, + self.theta(), + num_std_dev, + self.is_empty(), + ) + .expect("theta should always be valid") + } + + /// Returns the estimated size of the sketch in bytes. + pub fn estimated_size(&self) -> usize { + std::mem::size_of::() + self.table.estimated_size() + } +} + +impl UpdatableTupleSketch { + /// Returns this sketch in compact (immutable) form. + /// + /// If `ordered` is true, retained entries are sorted by hash in ascending order. + /// + /// # Examples + /// + /// ``` + /// # use datasketches::tuple::UpdatableTupleSketch; + /// let mut sketch = UpdatableTupleSketch::::builder().build(); + /// sketch.update("apple", 1); + /// let compact = sketch.compact(true); + /// assert_eq!(compact.num_retained(), 1); + /// ``` + pub fn compact(&self, ordered: bool) -> CompactTupleSketch { + let mut entries: Vec<(u64, S)> = self + .table + .iter() + .map(|(hash, summary)| (hash, summary.clone())) + .collect(); + + let empty = self.is_empty(); + // Match Theta's behavior for never-updated sketches initialized with p < 1.0. + let theta = if empty { MAX_THETA } else { self.table.theta() }; + let is_single = entries.len() == 1 && theta == MAX_THETA; + // Empty or single-item sketches are always ordered (Java/C++ compatibility). + let ordered = ordered || empty || is_single; + + if ordered && entries.len() > 1 { + entries.sort_unstable_by_key(|(hash, _)| *hash); + } + + CompactTupleSketch::from_parts(entries, theta, self.table.seed_hash(), ordered, empty) + } +} + +impl TupleSketchView for UpdatableTupleSketch { + fn seed_hash(&self) -> u16 { + UpdatableTupleSketch::seed_hash(self) + } + + fn theta64(&self) -> u64 { + UpdatableTupleSketch::theta64(self) + } + + fn is_empty(&self) -> bool { + UpdatableTupleSketch::is_empty(self) + } + + fn iter<'a>(&'a self) -> impl Iterator + 'a + where + S: 'a, + { + UpdatableTupleSketch::iter(self) + } + + fn num_retained(&self) -> usize { + UpdatableTupleSketch::num_retained(self) + } +} + +/// Compact (immutable) Tuple sketch. +/// +/// This is the serialization-friendly form: a compact array of `(hash, summary)` entries plus theta +/// and a 16-bit seed hash. It can be ordered (sorted ascending by hash) or unordered. +#[derive(Clone, Debug)] +pub struct CompactTupleSketch { + entries: Vec<(u64, S)>, + theta: u64, + seed_hash: u16, + ordered: bool, + empty: bool, +} + +impl CompactTupleSketch { + pub(super) fn from_parts( + entries: Vec<(u64, S)>, + theta: u64, + seed_hash: u16, + ordered: bool, + empty: bool, + ) -> Self { + Self { + entries, + theta, + seed_hash, + ordered, + empty, + } + } + + /// Returns the cardinality (distinct key count) estimate. + pub fn estimate(&self) -> f64 { + if self.is_empty() { + return 0.0; + } + let num_retained = self.num_retained() as f64; + if self.theta == MAX_THETA { + return num_retained; + } + let theta = self.theta as f64 / MAX_THETA as f64; + num_retained / theta + } + + /// Returns theta as a fraction (0.0 to 1.0). + pub fn theta(&self) -> f64 { + self.theta as f64 / MAX_THETA as f64 + } + + /// Returns theta as `u64`. + pub fn theta64(&self) -> u64 { + self.theta + } + + /// Returns true if the sketch is empty. + pub fn is_empty(&self) -> bool { + self.empty + } + + /// Returns true if the sketch is in estimation mode. + pub fn is_estimation_mode(&self) -> bool { + self.theta < MAX_THETA + } + + /// Returns the number of retained entries. + pub fn num_retained(&self) -> usize { + self.entries.len() + } + + /// Returns true if retained entries are ordered (sorted ascending by hash). + pub fn is_ordered(&self) -> bool { + self.ordered + } + + /// Returns the 16-bit seed hash. + pub fn seed_hash(&self) -> u16 { + self.seed_hash + } + + /// Returns an iterator over retained entries as `(hash, &summary)` pairs. + pub fn iter(&self) -> impl Iterator + '_ { + self.entries.iter().map(|(hash, summary)| (*hash, summary)) + } + + /// Returns the approximate lower error bound given the number of standard deviations. + pub fn lower_bound(&self, num_std_dev: NumStdDev) -> f64 { + if !self.is_estimation_mode() { + return self.num_retained() as f64; + } + binomial_bounds::lower_bound(self.num_retained() as u64, self.theta(), num_std_dev) + .expect("compact theta should always be valid") + } + + /// Returns the approximate upper error bound given the number of standard deviations. + pub fn upper_bound(&self, num_std_dev: NumStdDev) -> f64 { + if !self.is_estimation_mode() { + return self.num_retained() as f64; + } + binomial_bounds::upper_bound( + self.num_retained() as u64, + self.theta(), + num_std_dev, + self.is_empty(), + ) + .expect("compact theta should always be valid") + } + + /// Returns the estimated size of the sketch in bytes. + pub fn estimated_size(&self) -> usize { + std::mem::size_of::() + self.entries.capacity() * std::mem::size_of::<(u64, S)>() + } + + fn preamble_longs(&self) -> u8 { + if self.is_estimation_mode() { + 3 + } else if self.is_empty() || self.entries.len() == 1 { + 1 + } else { + 2 + } + } + + /// Serializes this sketch into the compact Tuple binary format. + /// + /// Each summary is encoded using `serde`. The layout matches the Java/C++ Tuple sketches, so + /// the output can be read by those implementations given a compatible summary serializer. + /// + /// # Examples + /// + /// ``` + /// # use datasketches::tuple::{PrimitiveSummarySerde, UpdatableTupleSketch}; + /// let mut sketch = UpdatableTupleSketch::::builder().build(); + /// sketch.update("apple", 1); + /// let bytes = sketch.compact(true).serialize(&PrimitiveSummarySerde); + /// assert!(!bytes.is_empty()); + /// ``` + pub fn serialize(&self, serde: &SD) -> Vec + where + SD: SummarySerde, + { + let pre_longs = self.preamble_longs(); + let mut bytes = + SketchBytes::with_capacity(8 * pre_longs as usize + self.entries.len() * 16); + + bytes.write_u8(pre_longs); + bytes.write_u8(serialization::SERIAL_VERSION); + bytes.write_u8(Family::TUPLE.id); + bytes.write_u8(serialization::SKETCH_TYPE); + bytes.write_u8(0); // unused + + let mut flags = serialization::FLAGS_IS_READ_ONLY | serialization::FLAGS_IS_COMPACT; + if self.is_empty() { + flags |= serialization::FLAGS_IS_EMPTY; + } + if self.is_ordered() { + flags |= serialization::FLAGS_IS_ORDERED; + } + bytes.write_u8(flags); + bytes.write_u16_le(self.seed_hash); + + if pre_longs > 1 { + bytes.write_u32_le(self.entries.len() as u32); + bytes.write_u32_le(0); // unused + } + if self.is_estimation_mode() { + bytes.write_u64_le(self.theta); + } + + let mut summary_buf = Vec::new(); + for (hash, summary) in self.entries.iter() { + bytes.write_u64_le(*hash); + summary_buf.clear(); + serde.serialize(summary, &mut summary_buf); + bytes.write(&summary_buf); + } + bytes.into_bytes() + } + + /// Deserializes a compact Tuple sketch using the default seed, decoding summaries with `serde`. + pub fn deserialize(bytes: &[u8], serde: &SD) -> Result + where + SD: SummarySerde, + { + Self::deserialize_with_seed(bytes, DEFAULT_UPDATE_SEED, serde) + } + + /// Deserializes a compact Tuple sketch using the provided expected `seed`, decoding summaries + /// with `serde`. + /// + /// # Errors + /// + /// Returns an error if the bytes are truncated, the family/serial version/sketch type are + /// unexpected, the seed hash does not match (for non-empty sketches), or an entry is corrupted. + pub fn deserialize_with_seed(bytes: &[u8], seed: u64, serde: &SD) -> Result + where + SD: SummarySerde, + { + let mut cursor = SketchSlice::new(bytes); + let pre_longs = cursor + .read_u8() + .map_err(insufficient_data("preamble_longs"))?; + let ser_ver = cursor + .read_u8() + .map_err(insufficient_data("serial_version"))?; + let family_id = cursor.read_u8().map_err(insufficient_data("family_id"))?; + let sketch_type = cursor.read_u8().map_err(insufficient_data("sketch_type"))?; + cursor.read_u8().map_err(insufficient_data(""))?; + let flags = cursor.read_u8().map_err(insufficient_data("flags"))?; + let seed_hash = cursor + .read_u16_le() + .map_err(insufficient_data("seed_hash"))?; + + Family::TUPLE.validate_id(family_id)?; + ensure_preamble_longs_in_range( + Family::TUPLE.min_pre_longs..=Family::TUPLE.max_pre_longs, + pre_longs, + )?; + if ser_ver != serialization::SERIAL_VERSION + && ser_ver != serialization::SERIAL_VERSION_LEGACY + { + return Err(Error::deserial(format!( + "unsupported serial version: expected {} or {}, got {ser_ver}", + serialization::SERIAL_VERSION, + serialization::SERIAL_VERSION_LEGACY, + ))); + } + if sketch_type != serialization::SKETCH_TYPE + && sketch_type != serialization::SKETCH_TYPE_LEGACY + { + return Err(Error::deserial(format!( + "unsupported sketch type: expected {} or {}, got {sketch_type}", + serialization::SKETCH_TYPE, + serialization::SKETCH_TYPE_LEGACY, + ))); + } + + let empty = (flags & serialization::FLAGS_IS_EMPTY) != 0; + let ordered = (flags & serialization::FLAGS_IS_ORDERED) != 0; + + if empty { + return Ok(Self::from_parts( + Vec::new(), + MAX_THETA, + seed_hash, + ordered, + true, + )); + } + + let expected_seed_hash = compute_seed_hash(seed); + if seed_hash != expected_seed_hash { + return Err(Error::deserial(format!( + "incompatible seed hash: expected {expected_seed_hash}, got {seed_hash}", + ))); + } + + let mut theta = MAX_THETA; + let num_entries = if pre_longs == 1 { + 1usize + } else { + let n = cursor + .read_u32_le() + .map_err(insufficient_data("num_entries"))? as usize; + cursor + .read_u32_le() + .map_err(insufficient_data(""))?; + if pre_longs > 2 { + theta = cursor.read_u64_le().map_err(insufficient_data("theta"))?; + } + n + }; + + let mut entries = Vec::with_capacity(num_entries); + for _ in 0..num_entries { + let hash = cursor + .read_u64_le() + .map_err(insufficient_data("entry_hash"))?; + if hash == 0 || hash >= theta { + return Err(Error::deserial("corrupted: invalid retained hash value")); + } + let (summary, consumed) = serde.deserialize(cursor.remaining())?; + cursor.advance(consumed as u64); + entries.push((hash, summary)); + } + + Ok(Self::from_parts(entries, theta, seed_hash, ordered, false)) + } +} + +impl TupleSketchView for CompactTupleSketch { + fn seed_hash(&self) -> u16 { + CompactTupleSketch::seed_hash(self) + } + + fn theta64(&self) -> u64 { + CompactTupleSketch::theta64(self) + } + + fn is_empty(&self) -> bool { + CompactTupleSketch::is_empty(self) + } + + fn iter<'a>(&'a self) -> impl Iterator + 'a + where + S: 'a, + { + CompactTupleSketch::iter(self) + } + + fn num_retained(&self) -> usize { + CompactTupleSketch::num_retained(self) + } + + fn is_ordered(&self) -> bool { + CompactTupleSketch::is_ordered(self) + } +} + +/// Builder for [`UpdatableTupleSketch`]. +/// +/// The summary type `S` is fixed when the builder is created (for example via +/// `UpdatableTupleSketch::::builder()`), and the policy type `P` defaults to +/// [`DefaultUpdatePolicy`]. Use [`policy`](Self::policy) to supply a custom policy. +#[derive(Debug)] +pub struct UpdatableTupleSketchBuilder { + lg_k: u8, + resize_factor: ResizeFactor, + sampling_probability: f32, + seed: u64, + policy: P, + _marker: PhantomData S>, +} + +impl Default for UpdatableTupleSketchBuilder { + fn default() -> Self { + Self { + lg_k: DEFAULT_LG_K, + resize_factor: ResizeFactor::X8, + sampling_probability: 1.0, + seed: DEFAULT_UPDATE_SEED, + policy: DefaultUpdatePolicy, + _marker: PhantomData, + } + } +} + +impl UpdatableTupleSketchBuilder { + /// Sets lg_k (log2 of the nominal size k). + /// + /// # Panics + /// + /// Panics if lg_k is not in range [5, 26]. + pub fn lg_k(mut self, lg_k: u8) -> Self { + assert!( + (MIN_LG_K..=MAX_LG_K).contains(&lg_k), + "lg_k must be in [{MIN_LG_K}, {MAX_LG_K}], got {lg_k}" + ); + self.lg_k = lg_k; + self + } + + /// Sets the resize factor. + pub fn resize_factor(mut self, factor: ResizeFactor) -> Self { + self.resize_factor = factor; + self + } + + /// Sets the sampling probability p. + /// + /// # Panics + /// + /// Panics if p is not in range `(0.0, 1.0]`. + pub fn sampling_probability(mut self, probability: f32) -> Self { + assert!( + (0.0..=1.0).contains(&probability) && probability > 0.0, + "sampling_probability must be in (0.0, 1.0], got {probability}" + ); + self.sampling_probability = probability; + self + } + + /// Sets the hash seed. + pub fn seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } + + /// Sets a custom update policy, changing the builder's policy type. + /// + /// # Examples + /// + /// ``` + /// use datasketches::tuple::SummaryUpdatePolicy; + /// use datasketches::tuple::UpdatableTupleSketch; + /// + /// #[derive(Default)] + /// struct MaxPolicy; + /// impl SummaryUpdatePolicy for MaxPolicy { + /// fn create(&self) -> u64 { + /// 0 + /// } + /// fn update(&self, summary: &mut u64, value: u64) { + /// *summary = (*summary).max(value); + /// } + /// } + /// + /// let mut sketch = UpdatableTupleSketch::::builder() + /// .policy(MaxPolicy) + /// .build(); + /// sketch.update("k", 3); + /// sketch.update("k", 7); + /// ``` + pub fn policy(self, policy: P2) -> UpdatableTupleSketchBuilder { + UpdatableTupleSketchBuilder { + lg_k: self.lg_k, + resize_factor: self.resize_factor, + sampling_probability: self.sampling_probability, + seed: self.seed, + policy, + _marker: PhantomData, + } + } + + /// Builds the [`UpdatableTupleSketch`]. + pub fn build(self) -> UpdatableTupleSketch { + let table = TupleHashTable::new( + self.lg_k, + self.resize_factor, + self.sampling_probability, + self.seed, + ); + UpdatableTupleSketch { + table, + policy: self.policy, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::error::ErrorKind; + use crate::tuple::policy::SummaryUpdatePolicy; + use crate::tuple::serde::PrimitiveSummarySerde; + + fn sorted_updatable_entries(sketch: &UpdatableTupleSketch) -> Vec<(u64, u64)> { + let mut entries: Vec<(u64, u64)> = sketch.iter().map(|(h, &s)| (h, s)).collect(); + entries.sort_unstable(); + entries + } + + fn sorted_compact_entries(sketch: &CompactTupleSketch) -> Vec<(u64, u64)> { + let mut entries: Vec<(u64, u64)> = sketch.iter().map(|(h, &s)| (h, s)).collect(); + entries.sort_unstable(); + entries + } + + fn assert_updatable_and_compact_equivalent( + updatable: &UpdatableTupleSketch, + compact: &CompactTupleSketch, + ) { + assert_eq!(updatable.is_empty(), compact.is_empty()); + assert_eq!(updatable.is_estimation_mode(), compact.is_estimation_mode()); + assert_eq!(updatable.num_retained(), compact.num_retained()); + assert_eq!(updatable.theta64(), compact.theta64()); + assert_eq!(updatable.seed_hash(), compact.seed_hash()); + assert_eq!( + sorted_updatable_entries(updatable), + sorted_compact_entries(compact) + ); + assert!((updatable.estimate() - compact.estimate()).abs() <= 1e-9); + } + + #[test] + fn exact_mode_updatable_and_compact_equivalent() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + for i in 0..2000 { + sketch.update(i, 1u64); + } + assert!(!sketch.is_estimation_mode()); + + for ordered in [false, true] { + let compact = sketch.compact(ordered); + assert_updatable_and_compact_equivalent(&sketch, &compact); + if compact.num_retained() > 1 { + assert_eq!(compact.is_ordered(), ordered); + } + } + } + + #[test] + fn estimation_mode_updatable_and_compact_equivalent() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(5).build(); + for i in 0..5000 { + sketch.update(i, 1u64); + } + assert!(sketch.is_estimation_mode()); + + for ordered in [false, true] { + let compact = sketch.compact(ordered); + assert_updatable_and_compact_equivalent(&sketch, &compact); + } + } + + #[test] + fn summaries_accumulate_with_default_policy() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for _ in 0..5 { + sketch.update("same_key", 2u64); + } + assert_eq!(sketch.num_retained(), 1); + let entries = sorted_updatable_entries(&sketch); + assert_eq!(entries.len(), 1); + assert_eq!(entries[0].1, 10); // 5 updates of 2 + + // Summaries survive the compaction. + let compact = sketch.compact(true); + assert_eq!(sorted_compact_entries(&compact)[0].1, 10); + } + + #[test] + fn empty_sketch_is_ordered_and_zero_estimate() { + let sketch = UpdatableTupleSketch::::builder().build(); + assert!(sketch.is_empty()); + assert_eq!(sketch.estimate(), 0.0); + + let compact = sketch.compact(false); + assert!(compact.is_empty()); + assert!(compact.is_ordered()); + assert_eq!(compact.estimate(), 0.0); + assert_eq!(compact.theta64(), MAX_THETA); + } + + #[test] + fn bounds_bracket_estimate_in_estimation_mode() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + for i in 0..10000 { + sketch.update(i, 1u64); + } + let estimate = sketch.estimate(); + let lower = sketch.lower_bound(NumStdDev::Two); + let upper = sketch.upper_bound(NumStdDev::Two); + assert!(lower <= estimate); + assert!(estimate <= upper); + } + + #[derive(Default)] + struct MaxPolicy; + + impl SummaryUpdatePolicy for MaxPolicy { + fn create(&self) -> u64 { + 0 + } + + fn update(&self, summary: &mut u64, value: u64) { + *summary = (*summary).max(value); + } + } + + #[test] + fn custom_policy_drives_summary_behavior() { + let mut sketch = UpdatableTupleSketch::::builder() + .policy(MaxPolicy) + .build(); + sketch.update("k", 3u64); + sketch.update("k", 7u64); + sketch.update("k", 2u64); + + assert_eq!(sketch.num_retained(), 1); + let entries = sorted_updatable_entries_generic(&sketch); + assert_eq!(entries[0].1, 7); + } + + fn sorted_updatable_entries_generic

( + sketch: &UpdatableTupleSketch, + ) -> Vec<(u64, u64)> { + let mut entries: Vec<(u64, u64)> = sketch.iter().map(|(h, &s)| (h, s)).collect(); + entries.sort_unstable(); + entries + } + + fn view_num_retained>(view: &V) -> usize { + view.num_retained() + } + + fn view_summary_sum>(view: &V) -> u64 { + view.iter().map(|(_, &summary)| summary).sum() + } + + fn view_is_ordered>(view: &V) -> bool { + view.is_ordered() + } + + #[test] + fn view_trait_accepts_both_sketch_types() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for i in 0..100 { + sketch.update(i, 2u64); + } + let compact = sketch.compact(true); + + // Both sketch types are accepted through the shared view trait. + assert_eq!(view_num_retained(&sketch), 100); + assert_eq!(view_num_retained(&compact), 100); + assert_eq!(view_summary_sum(&sketch), view_summary_sum(&compact)); + assert_eq!(view_summary_sum(&compact), 200); // 100 keys * 2 + + // Updatable is unordered by default; compact(true) reports ordered. + assert!(!view_is_ordered(&sketch)); + assert!(view_is_ordered(&compact)); + } + + fn assert_compact_round_trip(original: &CompactTupleSketch) { + let serde = PrimitiveSummarySerde; + let bytes = original.serialize(&serde); + let restored = CompactTupleSketch::::deserialize(&bytes, &serde).unwrap(); + assert_eq!(original.is_empty(), restored.is_empty()); + assert_eq!(original.is_ordered(), restored.is_ordered()); + assert_eq!(original.theta64(), restored.theta64()); + assert_eq!(original.seed_hash(), restored.seed_hash()); + assert_eq!(original.num_retained(), restored.num_retained()); + assert_eq!( + sorted_compact_entries(original), + sorted_compact_entries(&restored) + ); + } + + #[test] + fn serialize_round_trip_exact_mode() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + for i in 0..2000 { + sketch.update(i, 1u64); + } + assert!(!sketch.is_estimation_mode()); + assert_compact_round_trip(&sketch.compact(true)); + assert_compact_round_trip(&sketch.compact(false)); + } + + #[test] + fn serialize_round_trip_estimation_mode() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(5).build(); + for i in 0..5000 { + sketch.update(i, 3u64); + } + let compact = sketch.compact(true); + assert!(compact.is_estimation_mode()); + assert_compact_round_trip(&compact); + assert_compact_round_trip(&sketch.compact(false)); + } + + #[test] + fn serialize_round_trip_empty() { + let sketch = UpdatableTupleSketch::::builder().build(); + let compact = sketch.compact(true); + assert!(compact.is_empty()); + assert_compact_round_trip(&compact); + } + + #[test] + fn serialize_round_trip_single_entry() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + sketch.update("only", 42u64); + let compact = sketch.compact(true); + assert_eq!(compact.num_retained(), 1); + + let serde = PrimitiveSummarySerde; + let bytes = compact.serialize(&serde); + // A single-entry exact sketch uses a 1-long preamble. + assert_eq!(bytes[0], 1); + + let restored = CompactTupleSketch::::deserialize(&bytes, &serde).unwrap(); + assert_eq!(restored.num_retained(), 1); + assert_eq!(restored.iter().next().unwrap().1, &42); + } + + #[test] + fn serialize_header_fields_match_tuple_format() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for i in 0..100 { + sketch.update(i, 1u64); + } + let bytes = sketch.compact(true).serialize(&PrimitiveSummarySerde); + assert_eq!(bytes[0], 2); // preamble longs (exact, multi-entry) + assert_eq!(bytes[1], 3); // serial version + assert_eq!(bytes[2], 9); // TUPLE family id + assert_eq!(bytes[3], 1); // sketch type + } + + #[test] + fn serialize_preserves_summaries() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for i in 0..50 { + sketch.update(i, 1u64); + sketch.update(i, 1u64); // each summary accumulates to 2 + } + let serde = PrimitiveSummarySerde; + let bytes = sketch.compact(true).serialize(&serde); + let restored = CompactTupleSketch::::deserialize(&bytes, &serde).unwrap(); + assert_eq!(restored.num_retained(), 50); + assert!(restored.iter().all(|(_, &s)| s == 2)); + } + + #[test] + fn deserialize_rejects_wrong_family() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for i in 0..10 { + sketch.update(i, 1u64); + } + let mut bytes = sketch.compact(true).serialize(&PrimitiveSummarySerde); + bytes[2] = 3; // pretend it is a THETA sketch + let err = + CompactTupleSketch::::deserialize(&bytes, &PrimitiveSummarySerde).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InvalidData); + } + + #[test] + fn deserialize_rejects_seed_mismatch() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for i in 0..10 { + sketch.update(i, 1u64); + } + let bytes = sketch.compact(true).serialize(&PrimitiveSummarySerde); + let err = + CompactTupleSketch::::deserialize_with_seed(&bytes, 999, &PrimitiveSummarySerde) + .unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InvalidData); + } + + #[test] + fn deserialize_rejects_truncated_summary() { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for i in 0..100 { + sketch.update(i, 1u64); + } + let bytes = sketch.compact(true).serialize(&PrimitiveSummarySerde); + let truncated = &bytes[..bytes.len() - 4]; // cut the last summary in half + let err = + CompactTupleSketch::::deserialize(truncated, &PrimitiveSummarySerde).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InvalidData); + } +} diff --git a/datasketches/src/tuple/union.rs b/datasketches/src/tuple/union.rs new file mode 100644 index 0000000..50acd85 --- /dev/null +++ b/datasketches/src/tuple/union.rs @@ -0,0 +1,495 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Tuple sketch union. +//! +//! [`TupleUnion`] computes the union (set OR) of any number of Tuple sketches. Like the Theta union +//! it keeps an internal "gadget" hash table plus a running `union_theta`, and updates it from each +//! input sketch. The only Tuple-specific behavior is that when an incoming key already exists in +//! the gadget, the two summaries are combined with a [`SummaryCombinePolicy`] instead of one being +//! dropped. + +use std::marker::PhantomData; + +use crate::common::ResizeFactor; +use crate::error::Error; +use crate::hash::DEFAULT_UPDATE_SEED; +use crate::theta::DEFAULT_LG_K; +use crate::theta::MAX_LG_K; +use crate::theta::MIN_LG_K; +use crate::tuple::hash_table::TupleHashTable; +use crate::tuple::policy::DefaultUnionPolicy; +use crate::tuple::policy::SummaryCombinePolicy; +use crate::tuple::sketch::CompactTupleSketch; +use crate::tuple::sketch::TupleSketchView; + +/// Union (set OR) of Tuple sketches. +/// +/// `S` is the summary type and `P` is the [`SummaryCombinePolicy`] applied when a key is present in +/// more than one input. For additive summaries the default [`DefaultUnionPolicy`] is used. +/// +/// # Examples +/// +/// ``` +/// # use datasketches::tuple::{TupleUnion, UpdatableTupleSketch}; +/// let mut a = UpdatableTupleSketch::::builder().build(); +/// a.update("apple", 1); +/// a.update("banana", 1); +/// +/// let mut b = UpdatableTupleSketch::::builder().build(); +/// b.update("banana", 1); +/// b.update("cherry", 1); +/// +/// let mut union = TupleUnion::::builder().build(); +/// union.update(&a).unwrap(); +/// union.update(&b).unwrap(); +/// +/// let result = union.result(true); +/// assert_eq!(result.num_retained(), 3); // apple, banana, cherry +/// ``` +#[derive(Debug)] +pub struct TupleUnion { + table: TupleHashTable, + union_theta: u64, + policy: P, +} + +impl TupleUnion { + /// Creates a new builder using the default union policy. + /// + /// # Examples + /// + /// ``` + /// # use datasketches::tuple::TupleUnion; + /// let union = TupleUnion::::builder().lg_k(12).build(); + /// ``` + pub fn builder() -> TupleUnionBuilder { + TupleUnionBuilder::default() + } +} + +impl TupleUnion { + /// Merges a sketch into the union. + /// + /// Accepts either an [`UpdatableTupleSketch`](crate::tuple::UpdatableTupleSketch) or a + /// [`CompactTupleSketch`] through the shared [`TupleSketchView`] trait. Keys present in both + /// the running union and `sketch` have their summaries combined via the union policy. + /// + /// # Errors + /// + /// Returns an error if `sketch` was built with a different seed than this union (its seed hash + /// does not match). + pub fn update(&mut self, sketch: &V) -> Result<(), Error> + where + V: TupleSketchView, + P: SummaryCombinePolicy, + S: Clone, + { + if sketch.is_empty() { + return Ok(()); + } + if sketch.seed_hash() != self.table.seed_hash() { + return Err(Error::invalid_argument(format!( + "incompatible seed hash: expected {}, got {}", + self.table.seed_hash(), + sketch.seed_hash() + ))); + } + + // Any non-empty input makes the union non-empty, even if every key is screened out by + // theta. + self.table.set_empty(false); + self.union_theta = self.union_theta.min(sketch.theta64()); + + let ordered = sketch.is_ordered(); + let policy = &self.policy; + for (hash, summary) in sketch.iter() { + // A key contributes only if it is below both the running union theta and the gadget + // theta (the gadget theta can drop while we insert). For an ordered input, + // the first key that fails this test means all remaining (larger) keys fail + // too, so we can stop early. + if hash >= self.union_theta || hash >= self.table.theta() { + if ordered { + break; + } + continue; + } + self.table.upsert(hash, |existing| match existing { + Some(existing_summary) => { + policy.combine(existing_summary, summary); + None + } + None => Some(summary.clone()), + }); + } + + // A rebuild during the inserts above may have lowered the gadget theta below union_theta. + self.union_theta = self.union_theta.min(self.table.theta()); + Ok(()) + } + + /// Returns the union as a [`CompactTupleSketch`]. + /// + /// If `ordered` is true, retained entries are sorted ascending by hash. + pub fn result(&self, ordered: bool) -> CompactTupleSketch + where + S: Clone, + { + let seed_hash = self.table.seed_hash(); + if self.table.is_empty() { + return CompactTupleSketch::from_parts( + Vec::new(), + self.union_theta, + seed_hash, + true, + true, + ); + } + + let mut theta = self.union_theta.min(self.table.theta()); + let mut entries: Vec<(u64, S)> = self + .table + .iter() + .filter(|(hash, _)| *hash < theta) + .map(|(hash, summary)| (hash, summary.clone())) + .collect(); + + // Trim down to the nominal size k, lowering theta to the k-th smallest hash if needed. + let nominal = 1usize << self.table.lg_nom_size(); + if entries.len() > nominal { + let (_, kth, _) = entries.select_nth_unstable_by_key(nominal, |(hash, _)| *hash); + theta = kth.0; + entries.truncate(nominal); + } + + if ordered { + entries.sort_unstable_by_key(|(hash, _)| *hash); + } + + CompactTupleSketch::from_parts(entries, theta, seed_hash, ordered, false) + } + + /// Resets the union to its initial empty state. + pub fn reset(&mut self) { + self.table.reset(); + self.union_theta = self.table.theta(); + } +} + +/// Builder for [`TupleUnion`]. +/// +/// The summary type `S` is fixed when the builder is created (for example via +/// `TupleUnion::::builder()`), and the policy type `P` defaults to [`DefaultUnionPolicy`]. Use +/// [`policy`](Self::policy) to supply a custom policy. +#[derive(Debug)] +pub struct TupleUnionBuilder { + lg_k: u8, + resize_factor: ResizeFactor, + sampling_probability: f32, + seed: u64, + policy: P, + _marker: PhantomData S>, +} + +impl Default for TupleUnionBuilder { + fn default() -> Self { + Self { + lg_k: DEFAULT_LG_K, + resize_factor: ResizeFactor::X8, + sampling_probability: 1.0, + seed: DEFAULT_UPDATE_SEED, + policy: DefaultUnionPolicy, + _marker: PhantomData, + } + } +} + +impl TupleUnionBuilder { + /// Sets lg_k (log2 of the nominal size k). + /// + /// # Panics + /// + /// Panics if lg_k is not in range [5, 26]. + pub fn lg_k(mut self, lg_k: u8) -> Self { + assert!( + (MIN_LG_K..=MAX_LG_K).contains(&lg_k), + "lg_k must be in [{MIN_LG_K}, {MAX_LG_K}], got {lg_k}" + ); + self.lg_k = lg_k; + self + } + + /// Sets the resize factor. + pub fn resize_factor(mut self, factor: ResizeFactor) -> Self { + self.resize_factor = factor; + self + } + + /// Sets the sampling probability p. + /// + /// # Panics + /// + /// Panics if p is not in range `(0.0, 1.0]`. + pub fn sampling_probability(mut self, probability: f32) -> Self { + assert!( + (0.0..=1.0).contains(&probability) && probability > 0.0, + "sampling_probability must be in (0.0, 1.0], got {probability}" + ); + self.sampling_probability = probability; + self + } + + /// Sets the hash seed. + pub fn seed(mut self, seed: u64) -> Self { + self.seed = seed; + self + } + + /// Sets a custom union policy, changing the builder's policy type. + pub fn policy(self, policy: P2) -> TupleUnionBuilder { + TupleUnionBuilder { + lg_k: self.lg_k, + resize_factor: self.resize_factor, + sampling_probability: self.sampling_probability, + seed: self.seed, + policy, + _marker: PhantomData, + } + } + + /// Builds the [`TupleUnion`]. + pub fn build(self) -> TupleUnion { + let table = TupleHashTable::new( + self.lg_k, + self.resize_factor, + self.sampling_probability, + self.seed, + ); + let union_theta = table.theta(); + TupleUnion { + table, + union_theta, + policy: self.policy, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::common::NumStdDev; + use crate::error::ErrorKind; + use crate::tuple::UpdatableTupleSketch; + use crate::tuple::policy::SummaryCombinePolicy; + + fn sorted_entries(sketch: &CompactTupleSketch) -> Vec<(u64, u64)> { + let mut entries: Vec<(u64, u64)> = sketch.iter().map(|(h, &s)| (h, s)).collect(); + entries.sort_unstable(); + entries + } + + #[test] + fn union_of_disjoint_sketches_sums_cardinality() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..1000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 1000..2000 { + b.update(i, 1u64); + } + + let mut union = TupleUnion::::builder().build(); + union.update(&a).unwrap(); + union.update(&b).unwrap(); + + let result = union.result(true); + // 2000 distinct keys < k (4096), so the union is in exact mode. + assert!(!result.is_estimation_mode()); + assert_eq!(result.num_retained(), 2000); + assert_eq!(result.estimate(), 2000.0); + // Every summary stays at 1 because the inputs are disjoint. + assert!(result.iter().all(|(_, &s)| s == 1)); + } + + #[test] + fn union_combines_overlapping_summaries() { + let mut a = UpdatableTupleSketch::::builder().build(); + a.update("shared", 3u64); + a.update("only_a", 1u64); + let mut b = UpdatableTupleSketch::::builder().build(); + b.update("shared", 4u64); + b.update("only_b", 1u64); + + let mut union = TupleUnion::::builder().build(); + union.update(&a).unwrap(); + union.update(&b).unwrap(); + + let result = union.result(true); + assert_eq!(result.num_retained(), 3); + + // The shared key's summary is the default-policy sum (3 + 4 = 7); the rest are 1. + let summaries: Vec = sorted_entries(&result) + .into_iter() + .map(|(_, s)| s) + .collect(); + let mut sorted = summaries.clone(); + sorted.sort_unstable(); + assert_eq!(sorted, vec![1, 1, 7]); + } + + #[test] + fn union_result_is_order_independent() { + let mut a = UpdatableTupleSketch::::builder().build(); + a.update("shared", 3u64); + a.update("only_a", 5u64); + let mut b = UpdatableTupleSketch::::builder().build(); + b.update("shared", 4u64); + b.update("only_b", 6u64); + + let mut a_then_b = TupleUnion::::builder().build(); + a_then_b.update(&a).unwrap(); + a_then_b.update(&b).unwrap(); + + let mut b_then_a = TupleUnion::::builder().build(); + b_then_a.update(&b).unwrap(); + b_then_a.update(&a).unwrap(); + + assert_eq!( + sorted_entries(&a_then_b.result(true)), + sorted_entries(&b_then_a.result(true)) + ); + } + + #[test] + fn union_accepts_updatable_and_compact_inputs() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..500 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().build(); + for i in 250..750 { + b.update(i, 1u64); + } + let b_compact = b.compact(true); + + let mut union = TupleUnion::::builder().build(); + union.update(&a).unwrap(); // updatable input + union.update(&b_compact).unwrap(); // compact input + + let result = union.result(true); + assert_eq!(result.num_retained(), 750); // 0..750 distinct + } + + #[test] + fn union_of_empty_inputs_is_empty() { + let empty = UpdatableTupleSketch::::builder().build(); + + let mut union = TupleUnion::::builder().build(); + union.update(&empty).unwrap(); + + let result = union.result(true); + assert!(result.is_empty()); + assert!(result.is_ordered()); + assert_eq!(result.estimate(), 0.0); + assert_eq!(result.num_retained(), 0); + } + + #[test] + fn union_never_updated_is_empty() { + let union = TupleUnion::::builder().build(); + let result = union.result(true); + assert!(result.is_empty()); + assert_eq!(result.estimate(), 0.0); + } + + #[test] + fn union_rejects_seed_mismatch() { + let mut a = UpdatableTupleSketch::::builder().seed(1).build(); + a.update("k", 1u64); + + let mut union = TupleUnion::::builder().seed(2).build(); + let err = union.update(&a).unwrap_err(); + assert_eq!(err.kind(), ErrorKind::InvalidArgument); + } + + #[test] + fn union_in_estimation_mode_estimates_within_bounds() { + let mut a = UpdatableTupleSketch::::builder().lg_k(8).build(); + for i in 0..50000 { + a.update(i, 1u64); + } + let mut b = UpdatableTupleSketch::::builder().lg_k(8).build(); + for i in 25000..75000 { + b.update(i, 1u64); + } + + let mut union = TupleUnion::::builder().lg_k(8).build(); + union.update(&a).unwrap(); + union.update(&b).unwrap(); + + let result = union.result(true); + assert!(result.is_estimation_mode()); + // Union of 0..75000 distinct keys. + let lower = result.lower_bound(NumStdDev::Three); + let upper = result.upper_bound(NumStdDev::Three); + assert!( + lower <= 75000.0 && 75000.0 <= upper, + "expected 75000 in [{lower}, {upper}]" + ); + } + + #[derive(Debug, Default, Clone, Copy)] + struct MaxUnionPolicy; + + impl SummaryCombinePolicy for MaxUnionPolicy { + fn combine(&self, summary: &mut u64, other: &u64) { + *summary = (*summary).max(*other); + } + } + + #[test] + fn union_uses_custom_combine_policy() { + let mut a = UpdatableTupleSketch::::builder().build(); + a.update("shared", 3u64); + let mut b = UpdatableTupleSketch::::builder().build(); + b.update("shared", 9u64); + + let mut union = TupleUnion::::builder().policy(MaxUnionPolicy).build(); + union.update(&a).unwrap(); + union.update(&b).unwrap(); + + let result = union.result(true); + assert_eq!(result.num_retained(), 1); + assert_eq!(result.iter().next().unwrap().1, &9); // max(3, 9) + } + + #[test] + fn union_reset_clears_state() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..100 { + a.update(i, 1u64); + } + + let mut union = TupleUnion::::builder().build(); + union.update(&a).unwrap(); + assert!(!union.result(true).is_empty()); + + union.reset(); + assert!(union.result(true).is_empty()); + } +} diff --git a/datasketches/tests/tuple_intersection_test.rs b/datasketches/tests/tuple_intersection_test.rs new file mode 100644 index 0000000..6ce3e8b --- /dev/null +++ b/datasketches/tests/tuple_intersection_test.rs @@ -0,0 +1,340 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Behavioral tests for the Tuple intersection, mirroring `theta_intersection_test.rs`. +//! +//! Unlike Theta, a Tuple intersection requires an explicit [`SummaryCombinePolicy`] for keys that +//! appear in more than one input. These tests use a `u64` summary and a summing policy, so the +//! distinct-count behavior matches the Theta intersection. + +#![cfg(feature = "tuple")] + +use datasketches::tuple::CompactTupleSketch; +use datasketches::tuple::PrimitiveSummarySerde; +use datasketches::tuple::SummaryCombinePolicy; +use datasketches::tuple::TupleIntersection; +use datasketches::tuple::UpdatableTupleSketch; + +#[derive(Debug, Default, Clone, Copy)] +struct SumPolicy; + +impl SummaryCombinePolicy for SumPolicy { + fn combine(&self, summary: &mut u64, other: &u64) { + *summary += *other; + } +} + +fn sketch_with_range(start: u64, count: u64) -> UpdatableTupleSketch { + let mut sketch = UpdatableTupleSketch::::builder().build(); + for i in 0..count { + sketch.update(start + i, 1u64); + } + sketch +} + +#[test] +fn test_has_result_state_machine() { + let mut a = UpdatableTupleSketch::::builder().build(); + a.update("x", 1u64); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + assert!(!i.has_result()); + i.update(&a).unwrap(); + assert!(i.has_result()); + assert!(i.result().estimate() >= 1.0); +} + +#[test] +fn test_result_before_update_panics() { + let i = TupleIntersection::::new(123, SumPolicy); + let result = std::panic::catch_unwind(|| { + let _ = i.result(); + }); + assert!(result.is_err()); +} + +#[test] +fn test_update_accepts_compact_sketch() { + let mut a = UpdatableTupleSketch::::builder().build(); + a.update("x", 1u64); + a.update("y", 1u64); + + let mut b = UpdatableTupleSketch::::builder().build(); + b.update("y", 1u64); + b.update("z", 1u64); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&a.compact(true)).unwrap(); + i.update(&b).unwrap(); + + let r = i.result(); + assert!(r.estimate() == 1.0); + assert!(r.is_ordered()); + + let mut c = UpdatableTupleSketch::::builder().build(); + c.update("a", 1u64); + c.update("b", 1u64); + c.update("c", 1u64); + + i.update(&c.compact(false)).unwrap(); + + let r = i.result_with_ordered(false); + assert!(r.estimate() == 0.0); + assert!(!r.is_ordered()); +} + +#[test] +fn test_seed_mismatch_behaviour_for_empty_sketch() { + let empty_other_seed = UpdatableTupleSketch::::builder().seed(2).build(); + let mut i = TupleIntersection::::new(1, SumPolicy); + + i.update(&empty_other_seed).unwrap(); + assert!(i.has_result()); + let r = i.result(); + assert!(r.is_empty()); +} + +#[test] +fn test_seed_mismatch_behaviour() { + let mut one_other_seed = UpdatableTupleSketch::::builder().seed(2).build(); + one_other_seed.update("value", 1u64); + let mut i = TupleIntersection::::new(1, SumPolicy); + + assert!(i.update(&one_other_seed).is_err()); +} + +#[test] +fn test_terminal_empty_state_ignores_future_updates() { + let empty = UpdatableTupleSketch::::builder().build(); + + let mut non_empty = UpdatableTupleSketch::::builder().build(); + non_empty.update("x", 1u64); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&empty).unwrap(); + i.update(&non_empty).unwrap(); + + let r = i.result(); + assert!(r.is_empty()); +} + +#[test] +fn test_result_with_ordered_false_is_not_ordered() { + let mut a = UpdatableTupleSketch::::builder().build(); + for i in 0..64 { + a.update(i, 1u64); + } + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&a).unwrap(); + + let r = i.result_with_ordered(false); + assert!(!r.is_ordered()); +} + +#[test] +fn test_empty_update_twice() { + let empty = UpdatableTupleSketch::::builder().build(); + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + + i.update(&empty).unwrap(); + let r1 = i.result(); + assert_eq!(r1.num_retained(), 0); + assert!(r1.is_empty()); + assert!(!r1.is_estimation_mode()); + assert_eq!(r1.estimate(), 0.0); + + i.update(&empty).unwrap(); + let r2 = i.result(); + assert_eq!(r2.num_retained(), 0); + assert!(r2.is_empty()); + assert!(!r2.is_estimation_mode()); + assert_eq!(r2.estimate(), 0.0); +} + +#[test] +fn test_non_empty_no_retained_keys() { + let mut s = UpdatableTupleSketch::::builder() + .sampling_probability(0.001) + .build(); + s.update(1u64, 1u64); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s).unwrap(); + let r1 = i.result(); + assert_eq!(r1.num_retained(), 0); + assert!(!r1.is_empty()); + assert!(r1.is_estimation_mode()); + assert!((r1.theta() - 0.001).abs() < 1e-10); + assert_eq!(r1.estimate(), 0.0); + + i.update(&s).unwrap(); + let r2 = i.result(); + assert_eq!(r2.num_retained(), 0); + assert!(!r2.is_empty()); + assert!(r2.is_estimation_mode()); + assert!((r2.theta() - 0.001).abs() < 1e-10); + assert_eq!(r2.estimate(), 0.0); +} + +#[test] +fn test_exact_half_overlap_unordered() { + let s1 = sketch_with_range(0, 1000); + let s2 = sketch_with_range(500, 1000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1).unwrap(); + i.update(&s2).unwrap(); + let r = i.result(); + + assert!(!r.is_empty()); + assert!(!r.is_estimation_mode()); + assert_eq!(r.estimate(), 500.0); +} + +#[test] +fn test_exact_half_overlap_ordered() { + let s1 = sketch_with_range(0, 1000); + let s2 = sketch_with_range(500, 1000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1.compact(true)).unwrap(); + i.update(&s2.compact(true)).unwrap(); + let r = i.result(); + + assert!(!r.is_empty()); + assert!(!r.is_estimation_mode()); + assert_eq!(r.estimate(), 500.0); +} + +#[test] +fn test_exact_disjoint_unordered() { + let s1 = sketch_with_range(0, 1000); + let s2 = sketch_with_range(1000, 1000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1).unwrap(); + i.update(&s2).unwrap(); + let r = i.result(); + + assert!(r.is_empty()); + assert!(!r.is_estimation_mode()); + assert_eq!(r.estimate(), 0.0); +} + +#[test] +fn test_exact_disjoint_ordered() { + let s1 = sketch_with_range(0, 1000); + let s2 = sketch_with_range(1000, 1000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1.compact(true)).unwrap(); + i.update(&s2.compact(true)).unwrap(); + let r = i.result(); + + assert!(r.is_empty()); + assert!(!r.is_estimation_mode()); + assert_eq!(r.estimate(), 0.0); +} + +#[test] +fn test_estimation_half_overlap_unordered() { + let s1 = sketch_with_range(0, 10000); + let s2 = sketch_with_range(5000, 10000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1).unwrap(); + i.update(&s2).unwrap(); + let r = i.result(); + + assert!(!r.is_empty()); + assert!(r.is_estimation_mode()); + assert!((r.estimate() - 5000.0).abs() <= 5000.0 * 0.02); +} + +#[test] +fn test_estimation_half_overlap_ordered() { + let s1 = sketch_with_range(0, 10000); + let s2 = sketch_with_range(5000, 10000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1.compact(true)).unwrap(); + i.update(&s2.compact(true)).unwrap(); + let r = i.result(); + + assert!(!r.is_empty()); + assert!(r.is_estimation_mode()); + assert!((r.estimate() - 5000.0).abs() <= 5000.0 * 0.02); +} + +#[test] +fn test_estimation_half_overlap_ordered_deserialized_compact() { + let serde = PrimitiveSummarySerde; + let s1 = sketch_with_range(0, 10000); + let s2 = sketch_with_range(5000, 10000); + let c1 = CompactTupleSketch::::deserialize(&s1.compact(true).serialize(&serde), &serde) + .unwrap(); + let c2 = CompactTupleSketch::::deserialize(&s2.compact(true).serialize(&serde), &serde) + .unwrap(); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&c1).unwrap(); + i.update(&c2).unwrap(); + let r = i.result(); + + assert!(!r.is_empty()); + assert!(r.is_estimation_mode()); + assert!((r.estimate() - 5000.0).abs() <= 5000.0 * 0.02); +} + +#[test] +fn test_estimation_disjoint_unordered() { + let s1 = sketch_with_range(0, 10000); + let s2 = sketch_with_range(10000, 10000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1).unwrap(); + i.update(&s2).unwrap(); + let r = i.result(); + + assert!(!r.is_empty()); + assert!(r.is_estimation_mode()); + assert_eq!(r.estimate(), 0.0); +} + +#[test] +fn test_estimation_disjoint_ordered() { + let s1 = sketch_with_range(0, 10000); + let s2 = sketch_with_range(10000, 10000); + + let mut i = TupleIntersection::::new_with_default_seed(SumPolicy); + i.update(&s1.compact(true)).unwrap(); + i.update(&s2.compact(true)).unwrap(); + let r = i.result(); + + assert!(!r.is_empty()); + assert!(r.is_estimation_mode()); + assert_eq!(r.estimate(), 0.0); +} + +#[test] +fn test_seed_mismatch_non_empty_returns_error() { + let mut s = UpdatableTupleSketch::::builder().build(); + s.update(1u64, 1u64); + + let mut i = TupleIntersection::::new(123, SumPolicy); + assert!(i.update(&s).is_err()); +} diff --git a/datasketches/tests/tuple_serialization_test.rs b/datasketches/tests/tuple_serialization_test.rs new file mode 100644 index 0000000..086de77 --- /dev/null +++ b/datasketches/tests/tuple_serialization_test.rs @@ -0,0 +1,113 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Cross-language compatibility tests for Tuple sketch serialization. +//! +//! The fixtures are produced by the upstream Java and C++ generators (see +//! `tools/generate_serialization_test_data.py`): +//! +//! * Java: `TupleCrossLanguageTest.generateForCppIntegerSummary` writes `tuple_int_n{n}_java.sk` +//! using its `IntegerSummary`. +//! * C++: `tuple_sketch_serialize_for_java.cpp` writes `tuple_int_n{n}_cpp.sk` using an `int` +//! summary. +//! +//! Both build a tuple sketch with `update(i, i)` for `i` in `0..n`, so the summary is a 4-byte +//! little-endian signed integer — exactly what [`PrimitiveSummarySerde`] reads into an `i32`. The +//! `aod_*`/`aos_*` fixtures use Array-of-Doubles / Array-of-Strings summaries, which this crate +//! does not implement, so they are intentionally not covered here. + +#![cfg(feature = "tuple")] + +mod common; + +use std::fs; +use std::path::PathBuf; + +use common::serialization_test_data; +use datasketches::tuple::CompactTupleSketch; +use datasketches::tuple::PrimitiveSummarySerde; +use googletest::assert_that; +use googletest::prelude::near; + +fn test_sketch_file(path: PathBuf, expected_cardinality: usize) { + let expected = expected_cardinality as f64; + let serde = PrimitiveSummarySerde; + + let bytes = fs::read(&path).unwrap(); + let sketch1 = CompactTupleSketch::::deserialize(&bytes, &serde) + .unwrap_or_else(|err| panic!("Deserialization failed for {}: {}", path.display(), err)); + + assert_eq!( + sketch1.is_empty(), + expected_cardinality == 0, + "Unexpected is_empty for {}", + path.display() + ); + + let estimate1 = sketch1.estimate(); + assert_that!(estimate1, near(expected, expected * 0.03)); + + // Snapshots from Java/C++ are not required to match byte-for-byte output from this + // implementation. Verify our own serialization is stable across a round-trip instead. + let serialized_bytes = sketch1.serialize(&serde); + let sketch2 = + CompactTupleSketch::::deserialize(&serialized_bytes, &serde).unwrap_or_else(|err| { + panic!( + "Deserialization failed after round-trip for {}: {}", + path.display(), + err + ) + }); + + let serialized_bytes2 = sketch2.serialize(&serde); + assert_eq!( + serialized_bytes, + serialized_bytes2, + "Serialized bytes are unstable after round-trip for {}", + path.display() + ); + + let estimate2 = sketch2.estimate(); + assert_eq!( + estimate1, + estimate2, + "Estimates differ after round-trip for {}", + path.display() + ); +} + +#[test] +fn test_java_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10_000, 100_000, 1_000_000]; + + for n in test_cases { + let filename = format!("tuple_int_n{}_java.sk", n); + let path = serialization_test_data("java_generated_files", &filename); + test_sketch_file(path, n); + } +} + +#[test] +fn test_cpp_compatibility() { + let test_cases = [0, 1, 10, 100, 1000, 10_000, 100_000, 1_000_000]; + + for n in test_cases { + let filename = format!("tuple_int_n{}_cpp.sk", n); + let path = serialization_test_data("cpp_generated_files", &filename); + test_sketch_file(path, n); + } +} diff --git a/datasketches/tests/tuple_sketch_test.rs b/datasketches/tests/tuple_sketch_test.rs new file mode 100644 index 0000000..d68cdfd --- /dev/null +++ b/datasketches/tests/tuple_sketch_test.rs @@ -0,0 +1,324 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Behavioral tests for the Tuple sketch, mirroring `theta_sketch_test.rs`. +//! +//! Updates carry a `u64` summary combined with the default (additive) policy, so the distinct-count +//! behavior matches the Theta sketch while the summaries accumulate alongside each key. + +#![cfg(feature = "tuple")] + +use datasketches::common::NumStdDev; +use datasketches::hash_value; +use datasketches::tuple::UpdatableTupleSketch; + +#[test] +fn test_basic_update() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + assert!(sketch.is_empty()); + assert_eq!(sketch.estimate(), 0.0); + + sketch.update("value1", 1u64); + assert!(!sketch.is_empty()); + assert_eq!(sketch.estimate(), 1.0); + + sketch.update("value2", 1u64); + assert_eq!(sketch.estimate(), 2.0); +} + +#[test] +fn test_summary_accumulates_per_key() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + for _ in 0..5 { + sketch.update("same_key", 2u64); + } + assert_eq!(sketch.estimate(), 1.0); + assert_eq!(sketch.num_retained(), 1); + // The default policy folds each update into the retained summary: 5 * 2 == 10. + assert_eq!(sketch.iter().next().unwrap().1, &10); +} + +#[test] +fn test_update_various_types() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + + sketch.update("string", 1u64); + sketch.update(42i64, 1u64); + sketch.update(42u64, 1u64); + // where floating-point numbers have different representations + sketch.update(hash_value::canonical_float::from_f64(3.15), 1u64); + sketch.update(hash_value::canonical_float::from_f64(3.15), 1u64); + sketch.update(hash_value::canonical_float::from_f32(3.15), 1u64); + sketch.update(hash_value::canonical_float::from_f32(3.15), 1u64); + sketch.update([1u8, 2, 3], 1u64); + + assert!(!sketch.is_empty()); + assert_eq!(sketch.estimate(), 5.0); + + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + + sketch.update("string", 1u64); + sketch.update(42i64, 1u64); + sketch.update(42u64, 1u64); + // where floating-point numbers have the same representation + sketch.update(hash_value::canonical_float::from_f64(5.0), 1u64); + sketch.update(hash_value::canonical_float::from_f64(5.0), 1u64); + sketch.update(hash_value::canonical_float::from_f32(5.0), 1u64); + sketch.update(hash_value::canonical_float::from_f32(5.0), 1u64); + sketch.update([1u8, 2, 3], 1u64); + + assert!(!sketch.is_empty()); + assert_eq!(sketch.estimate(), 4.0); +} + +#[test] +fn test_duplicate_updates() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + + for _ in 0..100 { + sketch.update("same_value", 1u64); + } + + assert_eq!(sketch.estimate(), 1.0); +} + +#[test] +fn test_theta_reduction() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(5).build(); // Small k to trigger theta reduction + assert!(!sketch.is_estimation_mode()); + + // Insert many values to trigger theta reduction + for i in 0..1000 { + sketch.update(format!("value_{}", i), 1u64); + } + + assert!(sketch.is_estimation_mode()); + assert!(sketch.theta() < 1.0); +} + +#[test] +fn test_trim() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(5).build(); + + // Insert many values + for i in 0..1000 { + sketch.update(format!("value_{}", i), 1u64); + } + + let before_trim = sketch.num_retained(); + sketch.trim(); + let after_trim = sketch.num_retained(); + + // After trim, should have approximately k entries + assert!(after_trim <= before_trim); + assert_eq!(sketch.num_retained(), 32); +} + +#[test] +fn test_reset() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(5).build(); + + // Insert many values + for i in 0..1000 { + sketch.update(format!("value_{}", i), 1u64); + } + assert!(!sketch.is_empty()); + assert!(sketch.is_estimation_mode()); + assert!(sketch.num_retained() > 32); + assert!(sketch.theta() < 1.0); + + sketch.reset(); + assert!(sketch.is_empty()); + assert_eq!(sketch.estimate(), 0.0); + assert_eq!(sketch.theta(), 1.0); + assert_eq!(sketch.num_retained(), 0); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.lower_bound(NumStdDev::One), 0.0); + assert_eq!(sketch.upper_bound(NumStdDev::One), 0.0); +} + +#[test] +fn test_iterator() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + + sketch.update("value1", 1u64); + sketch.update("value2", 1u64); + sketch.update("value3", 1u64); + + let count: usize = sketch.iter().count(); + assert_eq!(count, sketch.num_retained()); +} + +#[test] +fn test_bounds_empty_sketch() { + let sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + assert!(sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.theta(), 1.0); + assert_eq!(sketch.estimate(), 0.0); + assert_eq!(sketch.lower_bound(NumStdDev::One), 0.0); + assert_eq!(sketch.upper_bound(NumStdDev::One), 0.0); + assert_eq!(sketch.lower_bound(NumStdDev::Two), 0.0); + assert_eq!(sketch.upper_bound(NumStdDev::Two), 0.0); + assert_eq!(sketch.lower_bound(NumStdDev::Three), 0.0); + assert_eq!(sketch.upper_bound(NumStdDev::Three), 0.0); +} + +#[test] +fn test_bounds_exact_mode() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + for i in 0..2000 { + sketch.update(i, 1u64); + } + assert!(!sketch.is_empty()); + assert!(!sketch.is_estimation_mode()); + assert_eq!(sketch.theta(), 1.0); + assert_eq!(sketch.estimate(), 2000.0); + assert_eq!(sketch.lower_bound(NumStdDev::One), 2000.0); + assert_eq!(sketch.upper_bound(NumStdDev::One), 2000.0); +} + +#[test] +fn test_bounds_estimation_mode() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + let n = 10000; + for i in 0..n { + sketch.update(i, 1u64); + } + assert!(!sketch.is_empty()); + assert!(sketch.is_estimation_mode()); + assert!(sketch.theta() < 1.0); + + let estimate = sketch.estimate(); + let lower_bound_1 = sketch.lower_bound(NumStdDev::One); + let upper_bound_1 = sketch.upper_bound(NumStdDev::One); + let lower_bound_2 = sketch.lower_bound(NumStdDev::Two); + let upper_bound_2 = sketch.upper_bound(NumStdDev::Two); + let lower_bound_3 = sketch.lower_bound(NumStdDev::Three); + let upper_bound_3 = sketch.upper_bound(NumStdDev::Three); + + // Check estimate is within reasonable margin (2% to be safe) + assert!( + (estimate - n as f64).abs() < n as f64 * 0.02, + "estimate {} is not within 2% of {}", + estimate, + n + ); + + // Check bounds are in correct order + assert!(lower_bound_1 < estimate); + assert!(estimate < upper_bound_1); + assert!(lower_bound_2 < estimate); + assert!(estimate < upper_bound_2); + assert!(lower_bound_3 < estimate); + assert!(estimate < upper_bound_3); + + // Check that wider confidence intervals are indeed wider + assert!(lower_bound_3 < lower_bound_2); + assert!(lower_bound_2 < lower_bound_1); + assert!(upper_bound_1 < upper_bound_2); + assert!(upper_bound_2 < upper_bound_3); +} + +#[test] +fn test_bounds_with_sampling() { + let mut sketch = UpdatableTupleSketch::::builder() + .lg_k(12) + .sampling_probability(0.5) + .build(); + + for i in 0..1000 { + sketch.update(i, 1u64); + } + + assert!(!sketch.is_empty()); + assert!(sketch.is_estimation_mode()); + assert!(sketch.theta() < 1.0); + + let estimate = sketch.estimate(); + let lower_bound = sketch.lower_bound(NumStdDev::Two); + let upper_bound = sketch.upper_bound(NumStdDev::Two); + + assert!(lower_bound <= estimate); + assert!(estimate <= upper_bound); +} + +#[test] +fn test_bounds_all_num_std_devs() { + let mut sketch = UpdatableTupleSketch::::builder().lg_k(12).build(); + for i in 0..10000 { + sketch.update(i, 1u64); + } + + let lb1 = sketch.lower_bound(NumStdDev::One); + let lb2 = sketch.lower_bound(NumStdDev::Two); + let lb3 = sketch.lower_bound(NumStdDev::Three); + let ub1 = sketch.upper_bound(NumStdDev::One); + let ub2 = sketch.upper_bound(NumStdDev::Two); + let ub3 = sketch.upper_bound(NumStdDev::Three); + + // Verify the bounds are properly ordered + assert!(lb3 <= lb2); + assert!(lb2 <= lb1); + assert!(ub1 <= ub2); + assert!(ub2 <= ub3); +} + +#[test] +fn test_bounds_empty_estimation_mode() { + // Create a sketch with sampling probability < 1.0 to force estimation mode + let sketch = UpdatableTupleSketch::::builder() + .lg_k(12) + .sampling_probability(0.1) + .build(); + + // The sketch is empty but theta < 1.0, so it's in estimation mode. + // When empty, both bounds should return 0.0 (matching the Java/Theta behavior). + assert!(sketch.is_empty()); + assert!(sketch.is_estimation_mode()); + assert_eq!(sketch.estimate(), 0.0); + assert_eq!(sketch.lower_bound(NumStdDev::One), 0.0); + assert_eq!(sketch.upper_bound(NumStdDev::One), 0.0); +} + +#[test] +fn test_compact_preserves_logical_non_empty_after_screened_update() { + let screened_value = (0u64..) + .find(|candidate| { + let mut sketch = UpdatableTupleSketch::::builder() + .lg_k(12) + .sampling_probability(0.5) + .build(); + sketch.update(*candidate, 1u64); + !sketch.is_empty() && sketch.num_retained() == 0 + }) + .expect("failed to find a value screened out by the sampling theta"); + + let mut sketch = UpdatableTupleSketch::::builder() + .lg_k(12) + .sampling_probability(0.5) + .build(); + sketch.update(screened_value, 1u64); + + assert!(!sketch.is_empty()); + assert_eq!(sketch.num_retained(), 0); + + let compact = sketch.compact(false); + assert!(!compact.is_empty()); + assert_eq!(compact.num_retained(), 0); + assert_eq!(compact.theta64(), sketch.theta64()); +}