diff --git a/package.json b/package.json index 0d7aaacd..0da80649 100644 --- a/package.json +++ b/package.json @@ -62,6 +62,7 @@ "spark-rs": "file:rust/spark-rs/pkg" }, "dependencies": { + "@bokuweb/zstd-wasm": "^0.0.27", "fflate": "^0.8.2" }, "peerDependencies": { diff --git a/rust/Cargo.lock b/rust/Cargo.lock index 7b832572..d97ef639 100644 --- a/rust/Cargo.lock +++ b/rust/Cargo.lock @@ -937,6 +937,12 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" +[[package]] +name = "ruzstd" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fad02996bfc73da3e301efe90b1837be9ed8f4a462b6ed410aa35d00381de89f" + [[package]] name = "ryu" version = "1.0.20" @@ -1054,6 +1060,7 @@ dependencies = [ "miniz_oxide", "ordered-float", "rand_pcg", + "ruzstd", "serde", "serde_json", "smallvec", diff --git a/rust/Cargo.toml b/rust/Cargo.toml index 5b394c7a..68073167 100644 --- a/rust/Cargo.toml +++ b/rust/Cargo.toml @@ -24,6 +24,7 @@ itertools = "0.14.0" js-sys = "0.3.77" miniz_oxide = "0.8.9" ordered-float = "5.1.0" +ruzstd = { version = "0.7.3", default-features = false, features = ["std"] } serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" serde-wasm-bindgen = "0.6.5" diff --git a/rust/spark-lib/Cargo.toml b/rust/spark-lib/Cargo.toml index be19442b..991866b7 100644 --- a/rust/spark-lib/Cargo.toml +++ b/rust/spark-lib/Cargo.toml @@ -14,6 +14,7 @@ glam.workspace = true half.workspace = true ordered-float.workspace = true miniz_oxide.workspace = true +ruzstd.workspace = true serde.workspace = true smallvec.workspace = true itertools.workspace = true diff --git a/rust/spark-lib/src/decoder.rs b/rust/spark-lib/src/decoder.rs index 74c4af30..8525fd5b 100644 --- a/rust/spark-lib/src/decoder.rs +++ b/rust/spark-lib/src/decoder.rs @@ -474,6 +474,10 @@ impl ChunkReceiver for MultiDecoder { if (magic & 0x00ffffff) == PLY_MAGIC { return self.init_file_type(SplatFileType::PLY); } + if magic == SPZ_MAGIC { + // NGSP magic at file start — SPZ v4 (ZSTD multi-stream, not gzip-wrapped) + return self.init_file_type(SplatFileType::SPZ); + } if (magic & 0x00ffffff) == GZIP_MAGIC { // Gzipped file, unpack beginning to check magic number if self.buffer_gz.is_none() { diff --git a/rust/spark-lib/src/spz.rs b/rust/spark-lib/src/spz.rs index 1e60d943..15268056 100644 --- a/rust/spark-lib/src/spz.rs +++ b/rust/spark-lib/src/spz.rs @@ -4,6 +4,7 @@ use miniz_oxide::inflate::core::inflate_flags::{ TINFL_FLAG_USING_NON_WRAPPING_OUTPUT_BUF, }; use miniz_oxide::inflate::TINFLStatus; +use std::io::Read; use crate::decoder::{ChunkReceiver, SetSplatEncoding, SplatGetter, SplatInit, SplatReceiver}; use miniz_oxide::deflate::compress_to_vec; @@ -11,19 +12,65 @@ use miniz_oxide::deflate::compress_to_vec; pub const SPZ_MAGIC: u32 = 0x5053474e; // "NGSP" const SH_C0: f32 = 0.28209479177387814; const MAX_SPLAT_CHUNK: usize = 65536; +const NGSP_HEADER_SIZE: usize = 32; +const TOC_ENTRY_SIZE: usize = 16; // [u64 compressedSize LE][u64 uncompressedSize LE] + +// Header flag bits (byte 14 of the SPZ header). +const FLAG_HAS_EXTENSIONS: u8 = 0x02; #[derive(Debug, Clone, Copy, PartialEq, Eq)] enum SpzDecoderStage { Centers, Alphas, Rgb, Scales, Quats, Sh, Extension, ChildCounts, ChildStarts, Done } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum SpzFormat { + Unknown, // not yet detected (need at least 4 bytes) + Gzip, // legacy v1-v3: header + payload all inside a gzip stream + Ngsp, // v4: 32-byte NGSP header + TOC + ZSTD-compressed attribute streams +} + +/// Parsed v4 NGSP header; cached between calls to `try_decode_v4` so we don't +/// re-parse the 32-byte preamble on every push. +#[derive(Debug, Clone)] +struct V4HeaderInfo { + version: u32, + num_splats: usize, + sh_degree: usize, + fractional_bits: u8, + flags: u8, + num_streams: usize, + toc_byte_offset: usize, + toc_end: usize, // toc_byte_offset + num_streams * TOC_ENTRY_SIZE +} + +/// State machine for the streaming v4 decode. Each variant carries the parsed +/// outputs from the previous stage so we never reparse on subsequent pushes +/// while waiting for more bytes. +enum V4Stage { + NeedHeader, + NeedToc(V4HeaderInfo), + NeedStreams { + header: V4HeaderInfo, + compressed_offsets: Vec<(usize, usize)>, // (offset, size) per stream + total_size: usize, // total file size required (toc_end + sum of stream sizes) + }, + Done, +} + pub struct SpzDecoder { splats: T, + format: SpzFormat, + // Gzip path state (v1-v3) decompressor: DecompressorOxide, compressed: Vec, decompressed: Vec, - buffer: Vec, - state: Option, gzip_header_done: bool, out_pos: usize, + // V4 path state — accumulate the entire file before processing + raw: Vec, + v4_stage: V4Stage, + // Shared: decompressed payload bytes feeding the section state machine + buffer: Vec, + state: Option, done: bool, } @@ -31,6 +78,7 @@ impl SpzDecoder { pub fn new(splats: T) -> Self { Self { splats, + format: SpzFormat::Unknown, decompressor: DecompressorOxide::new(), compressed: Vec::new(), decompressed: vec![0u8; 128 * 1024], @@ -39,6 +87,8 @@ impl SpzDecoder { gzip_header_done: false, out_pos: 0, done: false, + raw: Vec::new(), + v4_stage: V4Stage::NeedHeader, } } @@ -61,24 +111,26 @@ impl SpzDecoder { return Ok(()); } - let magic = read_u32_le(&self.buffer[0..4]); - if magic != SPZ_MAGIC { - return Err(anyhow::anyhow!("Invalid SPZ magic: 0x{:08x}", magic)); - } - - let version = read_u32_le(&self.buffer[4..8]); - if version < 1 || version > 3 { - return Err(anyhow::anyhow!("Unsupported SPZ version: {}", version)); + let h = parse_common_header(&self.buffer)?; + if !(1..=3).contains(&h.version) { + return Err(anyhow::anyhow!("Unsupported legacy SPZ version: {}", h.version)); } - - let num_splats = read_u32_le(&self.buffer[8..12]) as usize; - let sh_degree = self.buffer[12] as usize; - let fractional_bits = self.buffer[13]; - let flags = self.buffer[14]; let _reserved = self.buffer[15]; self.buffer.drain(..16); - let state = SpzDecoderState::new(version as u32, num_splats, sh_degree, fractional_bits, flags)?; + self.init_state(h.version, h.num_splats, h.sh_degree, h.fractional_bits, h.flags)?; + Ok(()) + } + + fn init_state( + &mut self, + version: u32, + num_splats: usize, + sh_degree: usize, + fractional_bits: u8, + flags: u8, + ) -> anyhow::Result<()> { + let state = SpzDecoderState::new(version, num_splats, sh_degree, fractional_bits, flags)?; self.state = Some(state); self.splats.init_splats(&SplatInit { @@ -97,6 +149,90 @@ impl SpzDecoder { Ok(()) } + /// Drive the v4 decode state machine forward. Called every time bytes arrive + /// via `push()` (and once more in `finish()`). Each invocation advances + /// through as many stages as the currently-buffered bytes allow, then + /// returns. Parsed header / TOC outputs are carried in `self.v4_stage` so + /// no work is repeated across calls while waiting for the next size + /// threshold to be reached. + fn try_decode_v4(&mut self) -> anyhow::Result<()> { + loop { + // Take ownership of the current stage so we can match on it without + // holding a borrow on `self`. Each arm puts an updated stage back. + let stage = std::mem::replace(&mut self.v4_stage, V4Stage::Done); + match stage { + V4Stage::Done => { + // Either already finished or an error left us in a terminal + // state; restore Done and exit. `self.done` distinguishes + // success from a half-finished error path. + self.v4_stage = V4Stage::Done; + return Ok(()); + } + V4Stage::NeedHeader => { + if self.raw.len() < NGSP_HEADER_SIZE { + self.v4_stage = V4Stage::NeedHeader; + return Ok(()); + } + let header = parse_v4_header(&self.raw)?; + self.v4_stage = V4Stage::NeedToc(header); + // fall through to next iteration to attempt TOC parse + } + V4Stage::NeedToc(header) => { + if self.raw.len() < header.toc_end { + self.v4_stage = V4Stage::NeedToc(header); + return Ok(()); + } + let (compressed_offsets, total_size) = walk_v4_toc(&self.raw, &header)?; + self.v4_stage = V4Stage::NeedStreams { + header, + compressed_offsets, + total_size, + }; + // fall through to next iteration to attempt stream decompression + } + V4Stage::NeedStreams { + header, + compressed_offsets, + total_size, + } => { + if self.raw.len() < total_size { + self.v4_stage = V4Stage::NeedStreams { + header, + compressed_offsets, + total_size, + }; + return Ok(()); + } + // All bytes present; ZSTD-decompress every stream into + // self.buffer, then run the existing section state machine. + self.buffer.clear(); + for (offset, size) in &compressed_offsets { + let compressed = &self.raw[*offset..*offset + *size]; + let mut decoder = ruzstd::StreamingDecoder::new(compressed) + .map_err(|e| anyhow::anyhow!("v4 ZSTD init failed: {}", e))?; + decoder + .read_to_end(&mut self.buffer) + .map_err(|e| anyhow::anyhow!("v4 ZSTD decompress failed: {}", e))?; + } + self.init_state( + header.version, + header.num_splats, + header.sh_degree, + header.fractional_bits, + header.flags, + )?; + self.poll_sections()?; + // Mark the one-shot v4 decode as complete; finish() validates + // against this same flag for both the streaming gzip path and + // the v4 path. + self.v4_stage = V4Stage::Done; + self.done = true; + return Ok(()); + } + } + } + } + fn poll_sections(&mut self) -> anyhow::Result<()> { let Some(state) = self.state.as_mut() else { unreachable!(); @@ -225,7 +361,7 @@ impl SpzDecoder { } } SpzDecoderStage::Quats => { - let bytes_per_item = if state.version == 3 { 4 } else { 3 }; + let bytes_per_item = if state.version >= 3 { 4 } else { 3 }; let avail_items = self.buffer.len() / bytes_per_item; let remaining = state.num_splats - state.next_splat; if (avail_items < remaining) && (avail_items < MAX_SPLAT_CHUNK) { @@ -236,8 +372,8 @@ impl SpzDecoder { if state.output.len() < chunk * 4 { state.output.resize(chunk * 4, 0.0); } - if state.version == 3 { - // Version 3 uses "smallest three" compression for quaternions (4 bytes per splat) + if state.version >= 3 { + // Version 3 and v4 use "smallest three" compression for quaternions (4 bytes per splat) for i in 0..chunk { let base = i * 4; let comp = (self.buffer[base] as u32) @@ -481,6 +617,112 @@ impl SpzDecoder { } } +/// Fields shared by the v1–v3 (gzip) and v4 (NGSP) SPZ headers — the first 15 +/// bytes of either header have an identical layout, even though byte 15 onward +/// diverges (`_reserved` for legacy, `num_streams` + `toc_byte_offset` + 12 +/// reserved bytes for v4). Caller must guarantee `buf.len() >= 15`. +struct CommonHeaderFields { + version: u32, + num_splats: usize, + sh_degree: usize, + fractional_bits: u8, + flags: u8, +} + +/// Validate the SPZ magic and parse the shared first 15 bytes. Version range +/// validation is left to the caller — v1–v3 and v4 have different acceptable +/// ranges. +fn parse_common_header(buf: &[u8]) -> anyhow::Result { + debug_assert!(buf.len() >= 15); + let magic = read_u32_le(&buf[0..4]); + if magic != SPZ_MAGIC { + return Err(anyhow::anyhow!("Invalid SPZ magic: 0x{:08x}", magic)); + } + Ok(CommonHeaderFields { + version: read_u32_le(&buf[4..8]), + num_splats: read_u32_le(&buf[8..12]) as usize, + sh_degree: buf[12] as usize, + fractional_bits: buf[13], + flags: buf[14], + }) +} + +/// Parse the 32-byte NGSP header at the start of a v4 file. Caller must +/// guarantee `raw.len() >= NGSP_HEADER_SIZE`. +fn parse_v4_header(raw: &[u8]) -> anyhow::Result { + debug_assert!(raw.len() >= NGSP_HEADER_SIZE); + let h = parse_common_header(raw)?; + if h.version != 4 { + return Err(anyhow::anyhow!("Unsupported NGSP version: {}", h.version)); + } + // Extensions are signalled in the flag byte but this decoder does not + // parse extension data. Mirror the reference impl's behaviour: warn the + // user that some packing-affecting metadata may have been skipped, then + // continue decoding the rest of the file as normal. + if h.flags & FLAG_HAS_EXTENSIONS != 0 { + eprintln!( + "[SPZ WARNING] parse_v4_header: extensions were skipped at load time — \ + unpacked data may be incorrect due to unknown packing behavior" + ); + } + let num_streams = raw[15] as usize; + let toc_byte_offset = read_u32_le(&raw[16..20]) as usize; + // bytes 20..32 reserved + + if toc_byte_offset < NGSP_HEADER_SIZE { + return Err(anyhow::anyhow!( + "Invalid v4 tocByteOffset: {} < {}", + toc_byte_offset, + NGSP_HEADER_SIZE + )); + } + let toc_size = num_streams + .checked_mul(TOC_ENTRY_SIZE) + .ok_or_else(|| anyhow::anyhow!("v4 TOC size overflow"))?; + let toc_end = toc_byte_offset + .checked_add(toc_size) + .ok_or_else(|| anyhow::anyhow!("v4 TOC end overflow"))?; + + Ok(V4HeaderInfo { + version: h.version, + num_splats: h.num_splats, + sh_degree: h.sh_degree, + fractional_bits: h.fractional_bits, + flags: h.flags, + num_streams, + toc_byte_offset, + toc_end, + }) +} + +/// Walk the v4 TOC to compute the (offset, size) of every compressed stream +/// and the total file size required. Caller must guarantee +/// `raw.len() >= header.toc_end`. +fn walk_v4_toc( + raw: &[u8], + header: &V4HeaderInfo, +) -> anyhow::Result<(Vec<(usize, usize)>, usize)> { + debug_assert!(raw.len() >= header.toc_end); + let mut compressed_offsets: Vec<(usize, usize)> = Vec::with_capacity(header.num_streams); + let mut data_cursor = header.toc_end; + for i in 0..header.num_streams { + let e = header.toc_byte_offset + i * TOC_ENTRY_SIZE; + let cs_lo = read_u32_le(&raw[e..e + 4]) as u64; + let cs_hi = read_u32_le(&raw[e + 4..e + 8]) as u64; + let _us_lo = read_u32_le(&raw[e + 8..e + 12]) as u64; + let _us_hi = read_u32_le(&raw[e + 12..e + 16]) as u64; + let compressed_size = (cs_lo | (cs_hi << 32)) as usize; + if cs_hi != 0 || compressed_size > usize::MAX / 2 { + return Err(anyhow::anyhow!("v4 stream too large")); + } + compressed_offsets.push((data_cursor, compressed_size)); + data_cursor = data_cursor + .checked_add(compressed_size) + .ok_or_else(|| anyhow::anyhow!("v4 stream offset overflow"))?; + } + Ok((compressed_offsets, data_cursor)) +} + fn parse_gzip_header(buffer: &mut Vec) -> anyhow::Result { if buffer.len() < 10 { return Ok(false); @@ -551,14 +793,67 @@ fn parse_gzip_header(buffer: &mut Vec) -> anyhow::Result { impl ChunkReceiver for SpzDecoder { fn push(&mut self, bytes: &[u8]) -> anyhow::Result<()> { - self.compressed.extend_from_slice(bytes); - self.poll_decompress()?; - Ok(()) + // Phase 1: get the incoming bytes into the format-appropriate buffer. + // On the very first push (or first few, if chunks arrive < 4 bytes at a + // time) the format is still Unknown — we accumulate into `raw` as a + // scratch buffer, detect the format from the first 4 bytes, and (for + // gzip) move what we've collected so far into `compressed`. + if self.format == SpzFormat::Unknown { + self.raw.extend_from_slice(bytes); + if self.raw.len() < 4 { + return Ok(()); + } + let magic = read_u32_le(&self.raw[0..4]); + if magic == SPZ_MAGIC { + self.format = SpzFormat::Ngsp; + // `raw` is already the right destination buffer for v4. + } else if (magic & 0x00ffffff) == 0x00088b1f { + self.format = SpzFormat::Gzip; + // Move the detection scratch into the gzip input buffer. + let buffered = std::mem::take(&mut self.raw); + self.compressed.extend_from_slice(&buffered); + } else { + return Err(anyhow::anyhow!( + "Unrecognized SPZ format: leading bytes 0x{:08x}", magic + )); + } + } else { + // Steady state: append to whichever buffer the detected format uses. + match self.format { + SpzFormat::Gzip => self.compressed.extend_from_slice(bytes), + SpzFormat::Ngsp => self.raw.extend_from_slice(bytes), + SpzFormat::Unknown => unreachable!(), + } + } + + // Phase 2: advance the decoder. Single dispatch point for both formats. + match self.format { + SpzFormat::Gzip => self.poll_decompress(), + SpzFormat::Ngsp => self.try_decode_v4(), + SpzFormat::Unknown => unreachable!(), + } } fn finish(&mut self) -> anyhow::Result<()> { - self.poll_decompress()?; - if !self.done { return Err(anyhow::anyhow!("Truncated gzip stream")); } + match self.format { + SpzFormat::Gzip => { + self.poll_decompress()?; + if !self.done { + return Err(anyhow::anyhow!("Truncated gzip stream")); + } + } + SpzFormat::Ngsp => { + // No new bytes arrive between the last push() and finish(); the v4 + // state machine is already as advanced as the buffered data permits. + // A non-Done state here means the file was truncated. + if !self.done { + return Err(anyhow::anyhow!("Truncated SPZ v4 stream")); + } + } + SpzFormat::Unknown => { + return Err(anyhow::anyhow!("Empty SPZ stream")); + } + } if let Some(state) = &self.state { if state.stage != SpzDecoderStage::Done && !(state.sh_degree == 0 && state.stage == SpzDecoderStage::Sh) { return Err(anyhow::anyhow!("Incomplete SPZ stream: stage = {:?}, sh_degree = {}", state.stage, state.sh_degree)); diff --git a/src/SplatLoader.ts b/src/SplatLoader.ts index 5f2484f4..c8e6a208 100644 --- a/src/SplatLoader.ts +++ b/src/SplatLoader.ts @@ -426,6 +426,10 @@ export function getSplatFileType( // Unknown PKZip file type return undefined; } + if (magic === 0x5053474e) { + // NGSP magic at file start — SPZ v4 (ZSTD multi-stream, not gzip-wrapped) + return SplatFileType.SPZ; + } if (magic === 0x30444152) { return SplatFileType.RAD; } diff --git a/src/spz.ts b/src/spz.ts index 75d3d832..d5df69f3 100644 --- a/src/spz.ts +++ b/src/spz.ts @@ -5,8 +5,23 @@ import { getSplatFileType, getSplatFileTypeFromPath, } from "./SplatLoader"; +import { + compress as zstdCompress, + decompress as zstdDecompress, + init as zstdInit, +} from "@bokuweb/zstd-wasm"; import { GunzipReader, fromHalf, normalize } from "./utils"; +// Lazy, idempotent initialization of the ZSTD WASM module. The first call +// fetches/instantiates the WASM blob; subsequent calls return the cached promise. +let zstdInitPromise: Promise | null = null; +function ensureZstdInit(): Promise { + if (!zstdInitPromise) { + zstdInitPromise = zstdInit(); + } + return zstdInitPromise; +} + import { decodeAntiSplat } from "./antisplat"; import { SplatFileType } from "./defines"; import { decodeKsplat } from "./ksplat"; @@ -16,7 +31,10 @@ import { PlyReader } from "./ply"; export class SpzReader { fileBytes: Uint8Array; - reader: GunzipReader; + // null for v4 (ZSTD), set for v1-v3 (gzip) + reader: GunzipReader | null = null; + // Pre-decompressed attribute streams for v4: [positions, alphas, colors, scales, rotations, sh?] + v4Streams: Uint8Array[] | null = null; version = -1; numSplats = 0; @@ -32,9 +50,19 @@ export class SpzReader { constructor({ fileBytes }: { fileBytes: Uint8Array | ArrayBuffer }) { this.fileBytes = fileBytes instanceof ArrayBuffer ? new Uint8Array(fileBytes) : fileBytes; - this.reader = new GunzipReader({ - fileBytes: this.fileBytes as Uint8Array, - }); + // V4 files start with NGSP magic directly; v1-v3 are gzip-compressed. + const b = this.fileBytes; + const isV4 = + b.length >= 4 && + b[0] === 0x4e && + b[1] === 0x47 && + b[2] === 0x53 && + b[3] === 0x50; + if (!isV4) { + this.reader = new GunzipReader({ + fileBytes: this.fileBytes as Uint8Array, + }); + } } async parseHeader() { @@ -42,26 +70,87 @@ export class SpzReader { throw new Error("SPZ file header already parsed"); } - const header = new DataView((await this.reader.read(16)).buffer); - if (header.getUint32(0, true) !== 0x5053474e) { - throw new Error("Invalid SPZ file"); - } - this.version = header.getUint32(4, true); - if (this.version < 1 || this.version > 3) { - throw new Error(`Unsupported SPZ version: ${this.version}`); + if (this.reader === null) { + // V4: 32-byte NGSP header, attributes in separate ZSTD-compressed streams. + if (this.fileBytes.length < 32) { + throw new Error("SPZ v4 file too short"); + } + const view = new DataView( + this.fileBytes.buffer, + this.fileBytes.byteOffset, + this.fileBytes.byteLength, + ); + this.version = view.getUint32(4, true); + if (this.version !== 4) { + throw new Error(`Unsupported SPZ version: ${this.version}`); + } + this.numSplats = view.getUint32(8, true); + this.shDegree = view.getUint8(12); + this.fractionalBits = view.getUint8(13); + this.flags = view.getUint8(14); + this.flagAntiAlias = (this.flags & 0x01) !== 0; + this.flagLod = (this.flags & 0x80) !== 0; + this.reserved = 0; + const numStreams = view.getUint8(15); + const tocByteOffset = view.getUint32(16, true); + await ensureZstdInit(); + this.v4Streams = this._loadV4Streams(numStreams, tocByteOffset, view); + } else { + // V1-V3: 16-byte NGSP header inside gzip stream. + const header = new DataView((await this.reader.read(16)).buffer); + if (header.getUint32(0, true) !== 0x5053474e) { + throw new Error("Invalid SPZ file"); + } + this.version = header.getUint32(4, true); + if (this.version < 1 || this.version > 3) { + throw new Error(`Unsupported SPZ version: ${this.version}`); + } + this.numSplats = header.getUint32(8, true); + this.shDegree = header.getUint8(12); + this.fractionalBits = header.getUint8(13); + this.flags = header.getUint8(14); + this.flagAntiAlias = (this.flags & 0x01) !== 0; + this.flagLod = (this.flags & 0x80) !== 0; + this.reserved = header.getUint8(15); } - this.numSplats = header.getUint32(8, true); - this.shDegree = header.getUint8(12); - this.fractionalBits = header.getUint8(13); - this.flags = header.getUint8(14); - this.flagAntiAlias = (this.flags & 0x01) !== 0; - this.flagLod = (this.flags & 0x80) !== 0; - this.reserved = header.getUint8(15); this.headerParsed = true; this.parsed = false; } + private _loadV4Streams( + numStreams: number, + tocByteOffset: number, + view: DataView, + ): Uint8Array[] { + // TOC layout: numStreams × 16 bytes, each entry = [compressedSize u64 LE][uncompressedSize u64 LE]. + // Compressed streams follow immediately after the TOC in this order: + // positions, alphas, colors, scales, rotations, SH (zero-size streams skipped) + const tocEntrySize = 16; + const tocEnd = tocByteOffset + numStreams * tocEntrySize; + if (tocEnd > this.fileBytes.byteLength) { + throw new Error("SPZ v4: TOC extends beyond file end"); + } + const streams: Uint8Array[] = []; + let dataOffset = tocEnd; + for (let i = 0; i < numStreams; i++) { + const e = tocByteOffset + i * tocEntrySize; + const compressedSizeLo = view.getUint32(e, true); + const compressedSizeHi = view.getUint32(e + 4, true); + if (compressedSizeHi !== 0) { + throw new Error("SPZ v4: stream size exceeds 4GB"); + } + const compressedSize = compressedSizeLo; + const compressed = this.fileBytes.subarray( + dataOffset, + dataOffset + compressedSize, + ); + streams.push(zstdDecompress(compressed)); + dataOffset += compressedSize; + } + return streams; + } + async parseSplats( centerCallback?: (index: number, x: number, y: number, z: number) => void, alphaCallback?: (index: number, alpha: number) => void, @@ -101,9 +190,19 @@ export class SpzReader { } this.parsed = true; + // Unified attribute reader: v4 returns pre-decompressed streams in order; + // v1-v3 reads sequentially from the gzip stream. + let streamIdx = 0; + const read = + this.v4Streams !== null + ? async (_n: number): Promise => + this.v4Streams![streamIdx++] + : async (n: number): Promise => + await this.reader!.read(n); + if (this.version === 1) { // float16 centers - const centerBytes = await this.reader.read(this.numSplats * 3 * 2); + const centerBytes = await read(this.numSplats * 3 * 2); const centerUint16 = new Uint16Array(centerBytes.buffer); for (let i = 0; i < this.numSplats; i++) { const i3 = i * 3; @@ -112,10 +211,10 @@ export class SpzReader { const z = fromHalf(centerUint16[i3 + 2]); centerCallback?.(i, x, y, z); } - } else if (this.version === 2 || this.version === 3) { - // 24-bit fixed-point centers + } else { + // 24-bit fixed-point centers (v2/v3/v4) const fixed = 1 << this.fractionalBits; - const centerBytes = await this.reader.read(this.numSplats * 3 * 3); + const centerBytes = await read(this.numSplats * 3 * 3); for (let i = 0; i < this.numSplats; i++) { const i9 = i * 9; const x = @@ -138,18 +237,16 @@ export class SpzReader { fixed; centerCallback?.(i, x, y, z); } - } else { - throw new Error("Unreachable"); } { - const bytes = await this.reader.read(this.numSplats); + const bytes = await read(this.numSplats); for (let i = 0; i < this.numSplats; i++) { alphaCallback?.(i, bytes[i] / 255); } } { - const rgbBytes = await this.reader.read(this.numSplats * 3); + const rgbBytes = await read(this.numSplats * 3); const scale = SH_C0 / 0.15; for (let i = 0; i < this.numSplats; i++) { const i3 = i * 3; @@ -160,7 +257,7 @@ export class SpzReader { } } { - const scalesBytes = await this.reader.read(this.numSplats * 3); + const scalesBytes = await read(this.numSplats * 3); for (let i = 0; i < this.numSplats; i++) { const i3 = i * 3; const scaleX = Math.exp(scalesBytes[i3] / 16 - 10); @@ -169,60 +266,37 @@ export class SpzReader { scalesCallback?.(i, scaleX, scaleY, scaleZ); } } - if (this.version === 3) { - // Version 3 uses a trick called "smallest three" to compress the rotation quaternions - // achieving better precision. "Optimizing orientation" section at https://gafferongames.com/post/snapshot_compression/ A quaternion length must be 1: x^2+y^2+z^2+w^2 = 1 - // We can drop one component and reconstruct it with the identity above. - // Largest component is dropped for best numerical precision. - // Quaternion stored in 32 bits - // 10 bits singed integer for each of the 3 components + 2 bits indicating the index of dropped component. - // vs 8 bits for each component uncompressed (spz version < 3) - // Max Value after extracting largest component v is another component v - // (v,v,0,0) - // v^2 + v^2 = 1 - // v = 1 / sqrt(2); - const maxValue = 1 / Math.sqrt(2); // 0.7071 - const quatBytes = await this.reader.read(this.numSplats * 4); + if (this.version >= 3) { + // Smallest-three quaternion encoding (v3 and v4): drop the largest component and + // store the three smallest at 9-bit precision + 1-bit sign, plus 2-bit index of + // the dropped component, all packed into 32 bits. + const maxValue = 1 / Math.sqrt(2); // max magnitude of any non-largest component + const quatBytes = await read(this.numSplats * 4); for (let i = 0; i < this.numSplats; i++) { - const i3 = i * 4; + const i4 = i * 4; const quaternion = [0, 0, 0, 0]; - const values = [ - quatBytes[i3], - quatBytes[i3 + 1], - quatBytes[i3 + 2], - quatBytes[i3 + 3], - ]; - // all values are packed in 32 bits (10 per each of 3 components + 2 bits of index of larged value) const combinedValues = - values[0] + (values[1] << 8) + (values[2] << 16) + (values[3] << 24); - // each component value is 9 bits + sign (1 bit) + quatBytes[i4] + + (quatBytes[i4 + 1] << 8) + + (quatBytes[i4 + 2] << 16) + + (quatBytes[i4 + 3] << 24); const valueMask = (1 << 9) - 1; - // extract index of the largest element. 2 top bits. const largestIndex = combinedValues >>> 30; let remainingValues = combinedValues; let sumSquares = 0; - for (let i = 3; i >= 0; --i) { - if (i !== largestIndex) { - // extract current value and sign. + for (let j = 3; j >= 0; --j) { + if (j !== largestIndex) { const value = remainingValues & valueMask; const sign = (remainingValues >>> 9) & 0x1; - // each value is represented as 10 bits. Shift to next one. remainingValues = remainingValues >>> 10; - // convert to range [0,1] and then to [0, 0.7071] - quaternion[i] = maxValue * (value / valueMask); - // apply sign. - quaternion[i] = sign === 0 ? quaternion[i] : -quaternion[i]; - // accumulate the sum of squares - sumSquares += quaternion[i] * quaternion[i]; + quaternion[j] = maxValue * (value / valueMask); + quaternion[j] = sign === 0 ? quaternion[j] : -quaternion[j]; + sumSquares += quaternion[j] * quaternion[j]; } } - // quartenion length must be 1 (x^2+y^2+z^2+w^2 = 1) - // so can reconstruct largest component from the other 3. - // w = sqrt(1 - x^2 - y^2 - z^2); - const square = 1 - sumSquares; - quaternion[largestIndex] = Math.sqrt(Math.max(square, 0)); + quaternion[largestIndex] = Math.sqrt(Math.max(1 - sumSquares, 0)); quatCallback?.( i, @@ -233,7 +307,8 @@ export class SpzReader { ); } } else { - const quatBytes = await this.reader.read(this.numSplats * 3); + // First-three quaternion encoding (v1/v2): store x/y/z as uint8, reconstruct w. + const quatBytes = await read(this.numSplats * 3); for (let i = 0; i < this.numSplats; i++) { const i3 = i * 3; const quatX = quatBytes[i3] / 127.5 - 1; @@ -250,7 +325,7 @@ export class SpzReader { const sh1 = new Float32Array(3 * 3); const sh2 = this.shDegree >= 2 ? new Float32Array(5 * 3) : undefined; const sh3 = this.shDegree >= 3 ? new Float32Array(7 * 3) : undefined; - const shBytes = await this.reader.read( + const shBytes = await read( this.numSplats * SH_DEGREE_TO_VECS[this.shDegree] * 3, ); @@ -275,7 +350,8 @@ export class SpzReader { shCallback?.(i, sh1, sh2, sh3); } } - if (this.flagLod) { + // LOD extension is only present in gzip-based (v1-v3) files. + if (this.flagLod && this.reader !== null) { let bytes = await this.reader.read(this.numSplats * 2); for (let i = 0; i < this.numSplats; i++) { const i2 = i * 2; @@ -301,12 +377,21 @@ const SH_DEGREE_TO_VECS: Record = { 1: 3, 2: 8, 3: 15 }; const SH_C0 = 0.28209479177387814; export const SPZ_MAGIC = 0x5053474e; // NGSP = Niantic gaussian splat -export const SPZ_VERSION = 3; +export const SPZ_VERSION = 4; export const FLAG_ANTIALIASED = 0x1; +const NGSP_HEADER_SIZE = 32; +const TOC_ENTRY_SIZE = 16; // [compressedSize u64 LE][uncompressedSize u64 LE] +const ZSTD_COMPRESSION_LEVEL = 12; +// SPZ v4 writer: each attribute lives in its own Uint8Array buffer; finalize() ZSTD-compresses +// each one and assembles the [header | TOC | streams] file layout. export class SpzWriter { - buffer: ArrayBuffer; - view: DataView; + positions: Uint8Array; // 9 bytes per splat (24-bit signed fixed-point x,y,z) + alphas: Uint8Array; // 1 byte per splat + colors: Uint8Array; // 3 bytes per splat + scales: Uint8Array; // 3 bytes per splat (log-encoded) + rotations: Uint8Array; // 4 bytes per splat (smallest-three quaternion) + sh: Uint8Array; // SH_DEGREE_TO_VECS[shDegree] * 3 bytes per splat (length 0 if shDegree==0) numSplats: number; shDegree: number; fractionalBits: number; @@ -325,69 +410,48 @@ export class SpzWriter { fractionalBits?: number; flagAntiAlias?: boolean; }) { - const splatSize = - 9 + // Position - 1 + // Opacity - 3 + // Scale - 3 + // DC-rgb - 4 + // Rotation - (shDegree >= 1 ? 9 : 0) + - (shDegree >= 2 ? 15 : 0) + - (shDegree >= 3 ? 21 : 0); - const bufferSize = 16 + numSplats * splatSize; - this.buffer = new ArrayBuffer(bufferSize); - this.view = new DataView(this.buffer); - - this.view.setUint32(0, SPZ_MAGIC, true); // NGSP - this.view.setUint32(4, SPZ_VERSION, true); - this.view.setUint32(8, numSplats, true); - this.view.setUint8(12, shDegree); - this.view.setUint8(13, fractionalBits); - this.view.setUint8(14, flagAntiAlias ? FLAG_ANTIALIASED : 0); - this.view.setUint8(15, 0); // Reserved - this.numSplats = numSplats; this.shDegree = shDegree; this.fractionalBits = fractionalBits; this.fraction = 1 << fractionalBits; this.flagAntiAlias = flagAntiAlias; + + this.positions = new Uint8Array(numSplats * 9); + this.alphas = new Uint8Array(numSplats); + this.colors = new Uint8Array(numSplats * 3); + this.scales = new Uint8Array(numSplats * 3); + this.rotations = new Uint8Array(numSplats * 4); + const shVecs = SH_DEGREE_TO_VECS[shDegree] || 0; + this.sh = new Uint8Array(numSplats * shVecs * 3); } setCenter(index: number, x: number, y: number, z: number) { - // Divide by this.fraction and round to nearest integer, - // then write as 3-bytes per x then y then z. + // Divide by this.fraction, round to nearest integer, write as 3 bytes per axis. const xRounded = Math.round(x * this.fraction); const xInt = Math.max(-0x7fffff, Math.min(0x7fffff, xRounded)); const yRounded = Math.round(y * this.fraction); const yInt = Math.max(-0x7fffff, Math.min(0x7fffff, yRounded)); const zRounded = Math.round(z * this.fraction); const zInt = Math.max(-0x7fffff, Math.min(0x7fffff, zRounded)); - const clipped = xRounded !== xInt || yRounded !== yInt || zRounded !== zInt; - if (clipped) { + if (xRounded !== xInt || yRounded !== yInt || zRounded !== zInt) { this.clippedCount += 1; - // if (this.clippedCount < 10) { - // // Write x y z also in hex - // console.log(`Clipped ${index}: ${x}, ${y}, ${z} (0x${x.toString(16)}, 0x${y.toString(16)}, 0x${z.toString(16)}) -> ${xRounded}, ${yRounded}, ${zRounded} (0x${xRounded.toString(16)}, 0x${yRounded.toString(16)}, 0x${zRounded.toString(16)}) -> ${xInt}, ${yInt}, ${zInt} (0x${xInt.toString(16)}, 0x${yInt.toString(16)}, 0x${zInt.toString(16)})`); - // } } - const i9 = index * 9; - const base = 16 + i9; - this.view.setUint8(base, xInt & 0xff); - this.view.setUint8(base + 1, (xInt >> 8) & 0xff); - this.view.setUint8(base + 2, (xInt >> 16) & 0xff); - this.view.setUint8(base + 3, yInt & 0xff); - this.view.setUint8(base + 4, (yInt >> 8) & 0xff); - this.view.setUint8(base + 5, (yInt >> 16) & 0xff); - this.view.setUint8(base + 6, zInt & 0xff); - this.view.setUint8(base + 7, (zInt >> 8) & 0xff); - this.view.setUint8(base + 8, (zInt >> 16) & 0xff); + const base = index * 9; + this.positions[base] = xInt & 0xff; + this.positions[base + 1] = (xInt >> 8) & 0xff; + this.positions[base + 2] = (xInt >> 16) & 0xff; + this.positions[base + 3] = yInt & 0xff; + this.positions[base + 4] = (yInt >> 8) & 0xff; + this.positions[base + 5] = (yInt >> 16) & 0xff; + this.positions[base + 6] = zInt & 0xff; + this.positions[base + 7] = (zInt >> 8) & 0xff; + this.positions[base + 8] = (zInt >> 16) & 0xff; } setAlpha(index: number, alpha: number) { - const base = 16 + this.numSplats * 9 + index; - this.view.setUint8( - base, - Math.max(0, Math.min(255, Math.round(alpha * 255))), + this.alphas[index] = Math.max( + 0, + Math.min(255, Math.round(alpha * 255)), ); } @@ -397,25 +461,25 @@ export class SpzWriter { } setRgb(index: number, r: number, g: number, b: number) { - const base = 16 + this.numSplats * 10 + index * 3; - this.view.setUint8(base, SpzWriter.scaleRgb(r)); - this.view.setUint8(base + 1, SpzWriter.scaleRgb(g)); - this.view.setUint8(base + 2, SpzWriter.scaleRgb(b)); + const base = index * 3; + this.colors[base] = SpzWriter.scaleRgb(r); + this.colors[base + 1] = SpzWriter.scaleRgb(g); + this.colors[base + 2] = SpzWriter.scaleRgb(b); } setScale(index: number, scaleX: number, scaleY: number, scaleZ: number) { - const base = 16 + this.numSplats * 13 + index * 3; - this.view.setUint8( - base, - Math.max(0, Math.min(255, Math.round((Math.log(scaleX) + 10) * 16))), + const base = index * 3; + this.scales[base] = Math.max( + 0, + Math.min(255, Math.round((Math.log(scaleX) + 10) * 16)), ); - this.view.setUint8( - base + 1, - Math.max(0, Math.min(255, Math.round((Math.log(scaleY) + 10) * 16))), + this.scales[base + 1] = Math.max( + 0, + Math.min(255, Math.round((Math.log(scaleY) + 10) * 16)), ); - this.view.setUint8( - base + 2, - Math.max(0, Math.min(255, Math.round((Math.log(scaleZ) + 10) * 16))), + this.scales[base + 2] = Math.max( + 0, + Math.min(255, Math.round((Math.log(scaleZ) + 10) * 16)), ); } @@ -423,23 +487,21 @@ export class SpzWriter { index: number, ...q: [number, number, number, number] // x, y, z, w ) { - const base = 16 + this.numSplats * 16 + index * 4; - + const base = index * 4; const quat = normalize(q); - // Find largest component + // Smallest-three encoding: drop the largest component and reconstruct from |q|=1. let iLargest = 0; for (let i = 1; i < 4; ++i) { if (Math.abs(quat[i]) > Math.abs(quat[iLargest])) { iLargest = i; } } - - // Since -quat represents the same rotation as quat, transform the quaternion so the largest element - // is positive. This avoids having to send its sign bit. + // -q represents the same rotation as q; flip so the largest element is positive + // and we can avoid sending its sign bit. const negate = quat[iLargest] < 0 ? 1 : 0; - // Do compression using sign bit and 9-bit precision per element. + // Pack: [2-bit iLargest][3 × (1-bit sign + 9-bit magnitude)] = 32 bits total. let comp = iLargest; for (let i = 0; i < 4; ++i) { if (i !== iLargest) { @@ -451,10 +513,10 @@ export class SpzWriter { } } - this.view.setUint8(base, comp & 0xff); - this.view.setUint8(base + 1, (comp >> 8) & 0xff); - this.view.setUint8(base + 2, (comp >> 16) & 0xff); - this.view.setUint8(base + 3, (comp >>> 24) & 0xff); + this.rotations[base] = comp & 0xff; + this.rotations[base + 1] = (comp >> 8) & 0xff; + this.rotations[base + 2] = (comp >> 16) & 0xff; + this.rotations[base + 3] = (comp >>> 24) & 0xff; } static quantizeSh(sh: number, bits: number) { @@ -472,43 +534,86 @@ export class SpzWriter { sh3?: Float32Array, ) { const shVecs = SH_DEGREE_TO_VECS[this.shDegree] || 0; - const base1 = 16 + this.numSplats * 20 + index * shVecs * 3; + const base1 = index * shVecs * 3; for (let j = 0; j < 9; ++j) { - this.view.setUint8(base1 + j, SpzWriter.quantizeSh(sh1[j], 5)); + this.sh[base1 + j] = SpzWriter.quantizeSh(sh1[j], 5); } if (sh2) { const base2 = base1 + 9; for (let j = 0; j < 15; ++j) { - this.view.setUint8(base2 + j, SpzWriter.quantizeSh(sh2[j], 4)); + this.sh[base2 + j] = SpzWriter.quantizeSh(sh2[j], 4); } if (sh3) { const base3 = base2 + 15; for (let j = 0; j < 21; ++j) { - this.view.setUint8(base3 + j, SpzWriter.quantizeSh(sh3[j], 4)); + this.sh[base3 + j] = SpzWriter.quantizeSh(sh3[j], 4); } } } } async finalize(): Promise { - const input = new Uint8Array(this.buffer); - const stream = new ReadableStream({ - async start(controller) { - controller.enqueue(input); - controller.close(); - }, - }); - const compressed = stream.pipeThrough(new CompressionStream("gzip")); - const response = new Response(compressed); - const buffer = await response.arrayBuffer(); + await ensureZstdInit(); + // Stream order matches the C++ reference encoder: positions, alphas, colors, + // scales, rotations, sh. Zero-size streams are skipped. + const rawStreams: Uint8Array[] = [ + this.positions, + this.alphas, + this.colors, + this.scales, + this.rotations, + ]; + if (this.sh.length > 0) { + rawStreams.push(this.sh); + } + + const compressed = rawStreams.map((s) => + zstdCompress(s, ZSTD_COMPRESSION_LEVEL), + ); + + const numStreams = rawStreams.length; + const tocByteOffset = NGSP_HEADER_SIZE; + const tocSize = numStreams * TOC_ENTRY_SIZE; + let totalCompressed = 0; + for (const c of compressed) totalCompressed += c.length; + const totalSize = tocByteOffset + tocSize + totalCompressed; + + const out = new Uint8Array(totalSize); + const view = new DataView(out.buffer); + + // 32-byte NGSP header + view.setUint32(0, SPZ_MAGIC, true); + view.setUint32(4, SPZ_VERSION, true); // 4 + view.setUint32(8, this.numSplats, true); + view.setUint8(12, this.shDegree); + view.setUint8(13, this.fractionalBits); + view.setUint8(14, this.flagAntiAlias ? FLAG_ANTIALIASED : 0); + view.setUint8(15, numStreams); + view.setUint32(16, tocByteOffset, true); + // bytes 20-31: reserved (already zero-initialized) + + // TOC: numStreams × 16 bytes, each [compressedSize u64 LE][uncompressedSize u64 LE] + for (let i = 0; i < numStreams; i++) { + const e = tocByteOffset + i * TOC_ENTRY_SIZE; + view.setUint32(e, compressed[i].length, true); + view.setUint32(e + 4, 0, true); // hi 32 bits of compressedSize + view.setUint32(e + 8, rawStreams[i].length, true); + view.setUint32(e + 12, 0, true); // hi 32 bits of uncompressedSize + } + + // Concatenated compressed streams + let dataOffset = tocByteOffset + tocSize; + for (const c of compressed) { + out.set(c, dataOffset); + dataOffset += c.length; + } + + let totalRaw = 0; + for (const s of rawStreams) totalRaw += s.length; console.log( - "Compressed", - input.length, - "bytes to", - buffer.byteLength, - "bytes", + `SPZ v4: ${this.numSplats} splats, ${totalRaw} bytes raw -> ${totalSize} bytes (header+TOC+ZSTD)`, ); - return new Uint8Array(buffer); + return out; } }