diff --git a/datasketches/tests/countmin_serialization_test.rs b/datasketches/tests/countmin_serialization_test.rs new file mode 100644 index 0000000..4492a55 --- /dev/null +++ b/datasketches/tests/countmin_serialization_test.rs @@ -0,0 +1,71 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#![cfg(feature = "countmin")] + +mod common; + +use std::fs; + +use common::serialization_test_data; +use datasketches::countmin::CountMinSketch; +use googletest::assert_that; +use googletest::prelude::contains_substring; + +// This test validates binary format compatibility (deserialize + byte round-trip) for +// C++ Count-Min snapshots. It intentionally does not assert estimate equivalence against +// original input keys because per-row hash seed derivation differs across implementations. +fn assert_cpp_snapshot( + filename: &str, + seed: u64, + expected_num_hashes: u8, + expected_num_buckets: u32, + expected_total_weight: u64, +) { + let path = serialization_test_data("cpp_generated_files", filename); + let bytes = fs::read(&path).unwrap(); + + let sketch = CountMinSketch::::deserialize_with_seed(&bytes, seed).unwrap(); + + assert_eq!(sketch.num_hashes(), expected_num_hashes); + assert_eq!(sketch.num_buckets(), expected_num_buckets); + assert_eq!(sketch.seed(), seed); + assert_eq!(sketch.total_weight(), expected_total_weight); + assert_eq!(sketch.is_empty(), expected_total_weight == 0); + + let roundtrip = sketch.serialize(); + assert_eq!(roundtrip, bytes, "round-trip bytes differ for {filename}"); +} + +#[test] +fn test_deserialize_cpp_empty_snapshot() { + assert_cpp_snapshot("count_min-empty.bin", 9001, 1, 5, 0); +} + +#[test] +fn test_deserialize_cpp_non_empty_snapshot() { + assert_cpp_snapshot("count_min-non-empty.bin", 9001, 3, 1024, 2850); +} + +#[test] +fn test_deserialize_cpp_snapshot_with_wrong_seed() { + let path = serialization_test_data("cpp_generated_files", "count_min-non-empty.bin"); + let bytes = fs::read(&path).unwrap(); + + let err = CountMinSketch::::deserialize_with_seed(&bytes, 9000).unwrap_err(); + assert_that!(err.message(), contains_substring("incompatible seed hash")); +} diff --git a/datasketches/tests/countmin_test.rs b/datasketches/tests/countmin_test.rs index a8681e4..fe5419b 100644 --- a/datasketches/tests/countmin_test.rs +++ b/datasketches/tests/countmin_test.rs @@ -18,6 +18,9 @@ #![cfg(feature = "countmin")] use datasketches::countmin::CountMinSketch; +use googletest::assert_that; +use googletest::prelude::ge; +use googletest::prelude::le; #[test] fn test_init_defaults() { @@ -43,7 +46,7 @@ fn test_parameter_suggestions() { let buckets = CountMinSketch::::suggest_num_buckets(0.1); let sketch = CountMinSketch::::new(3, buckets); - assert!(sketch.relative_error() <= 0.1); + assert_that!(sketch.relative_error(), le(0.1)); } #[test] @@ -56,8 +59,8 @@ fn test_update_and_bounds() { let estimate = sketch.estimate("x"); let upper = sketch.upper_bound("x"); let lower = sketch.lower_bound("x"); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); } #[test] @@ -69,8 +72,8 @@ fn test_update_and_bounds_with_scaling() { let upper = sketch.upper_bound("x"); let lower = sketch.lower_bound("x"); assert_eq!(estimate, 10); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); let eps = sketch.relative_error(); @@ -80,8 +83,8 @@ fn test_update_and_bounds_with_scaling() { let lower = sketch.lower_bound("x"); assert_eq!(sketch.total_weight(), 5); assert_eq!(estimate, 5); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); assert_eq!( upper, estimate + (eps * sketch.total_weight() as f64) as u64 @@ -93,8 +96,8 @@ fn test_update_and_bounds_with_scaling() { let lower = sketch.lower_bound("x"); assert_eq!(sketch.total_weight(), 2); assert_eq!(estimate, 2); - assert!(lower <= estimate); - assert!(estimate <= upper); + assert_that!(estimate, ge(lower)); + assert_that!(estimate, le(upper)); assert_eq!( upper, estimate + (eps * sketch.total_weight() as f64) as u64 @@ -124,13 +127,13 @@ fn test_halve() { } for i in 0..1000usize { - assert!(sketch.estimate(i as u64) >= i as u64); + assert_that!(sketch.estimate(i as u64), ge(i as u64)); } sketch.halve(); for i in 0..1000usize { - assert!(sketch.estimate(i as u64) >= (i as u64) / 2); + assert_that!(sketch.estimate(i as u64), ge((i as u64) / 2)); } } @@ -147,7 +150,7 @@ fn test_decay() { } for i in 0..1000usize { - assert!(sketch.estimate(i as u64) >= i as u64); + assert_that!(sketch.estimate(i as u64), ge(i as u64)); } const FACTOR: f64 = 0.5; @@ -155,7 +158,7 @@ fn test_decay() { for i in 0..1000usize { let expected = ((i as f64) * FACTOR).floor() as u64; - assert!(sketch.estimate(i as u64) >= expected); + assert_that!(sketch.estimate(i as u64), ge(expected)); } } @@ -172,8 +175,8 @@ fn test_merge() { } left.merge(&right); assert_eq!(left.total_weight(), 18); - assert!(left.estimate("a") >= 14); - assert!(left.estimate("b") >= 4); + assert_that!(left.estimate("a"), ge(14)); + assert_that!(left.estimate("b"), ge(4)); } #[test] @@ -247,6 +250,6 @@ fn test_increment_multi_like_rust_count_min_sketch() { sketch.update(i % 100); } for key in 0..100u64 { - assert!(sketch.estimate(key) >= 9_000); + assert_that!(sketch.estimate(key), ge(9_000)); } } diff --git a/tools/generate_serialization_test_data.py b/tools/generate_serialization_test_data.py index c965f39..28314b1 100755 --- a/tools/generate_serialization_test_data.py +++ b/tools/generate_serialization_test_data.py @@ -134,6 +134,10 @@ def generate_cpp_files(workspace_dir, project_root): # 4. Clone repository repo_url = "https://github.com/apache/datasketches-cpp.git" branch = "master" + # Temporary e2e checkout for apache/datasketches-cpp#505. After that PR is + # merged, pin this to the merged master commit and remove the extra fetch. + commit = "af4436280bdab53e0063268e92ff29b3fdcb1b07" + fetch_ref = "refs/pull/505/head" run_command([ "git", "clone", "--depth", "1", @@ -142,6 +146,8 @@ def generate_cpp_files(workspace_dir, project_root): repo_url, str(temp_dir) ]) + run_command(["git", "fetch", "--depth", "1", "origin", fetch_ref], cwd=temp_dir) + run_command(["git", "checkout", "--detach", commit], cwd=temp_dir) # 5. Build and Run CMake build_dir = temp_dir / "build" @@ -166,15 +172,15 @@ def generate_cpp_files(workspace_dir, project_root): output_dir.mkdir(parents=True, exist_ok=True) files_copied = 0 - # Search recursively in build directory for *_cpp.sk - for file_path in build_dir.rglob("*_cpp.sk"): - # Avoid copying from CMakeFiles or other intermediate dirs if possible, but the pattern is specific enough - shutil.copy2(file_path, output_dir) - print(f"Copied: {file_path.name}") - files_copied += 1 + + for pattern in ("*_cpp.sk", "count_min-*.bin"): + for file_path in build_dir.rglob(pattern): + shutil.copy2(file_path, output_dir) + print(f"Copied: {file_path.name}") + files_copied += 1 if files_copied == 0: - print("Warning: No *_cpp.sk files were found to copy.") + print("Warning: No C++ serialization snapshots were found to copy.") else: print(f"Successfully copied {files_copied} files.")