diff --git a/docs/docs/beta-release-notes.mdx b/docs/docs/beta-release-notes.mdx index 00dc38d..d7e5d32 100644 --- a/docs/docs/beta-release-notes.mdx +++ b/docs/docs/beta-release-notes.mdx @@ -15,13 +15,18 @@ This page tracks **API and naming changes** since the GeoBrix project started. A ## What's new in v0.3.0 -Released 2026-05-19. Per-version highlights; full migration tables are in the per-component sections below. +Released 2026-05-26. Per-version highlights; full migration tables are in the per-component sections below. - **`rst_clip` CRS axis-order fix (all-black clips).** GDAL 3+ defaults EPSG-imported `SpatialReference`s to authority-compliant axis order (lat/lon for EPSG:4326), which silently swapped axes against JTS/Databricks WKT/WKB cutlines so the clip missed the raster entirely. The reprojection now clones the source/destination `SpatialReference`s and forces `OAMS_TRADITIONAL_GIS_ORDER` before the OGR transform; caller-owned `SpatialReference`s are not mutated. - **EWKT / EWKB support for `rst_clip`.** `JTS.fromWKT` / `JTS.fromWKB` auto-detect EWKT/EWKB; new `JTS.toEWKT` / `JTS.toEWKB` helpers emit SRID-preserving forms. `rst_clip` reprojects the cutline when its SRID differs from the raster CRS, and falls back to the raster's CRS (Mosaic-compatible) when the SRID is `0` / unresolvable. - **`rst_transform` rejects invalid SRIDs.** `targetSrid <= 0` and unresolvable EPSG codes now surface a clear error via tile metadata `error_message` instead of returning a raster with an uninitialized CRS. - **`/vsimem/` path-handling hardening.** `rst_memsize` / `rst_unlink` / GDAL writer in-memory byte fetch now use `startsWith("/vsimem/")` (not `contains`) and null-check `GetMemFileBuffer`, so datasets whose description embeds the substring (e.g. NetCDF subdataset selectors) aren't mis-routed through the in-memory branch. - **`tile.raster` bytes are always self-contained (no VRT payloads).** Three RasterX operations — `MergeRasters` (`gbx_rst_merge`, `gbx_rst_merge_agg`), `MergeBands` (`gbx_rst_frombands`), and `PixelCombineRasters` (`gbx_rst_derivedband`, `gbx_rst_derivedband_agg`, `gbx_rst_combineavg`, `gbx_rst_combineavg_agg`) — used to return tiles whose `metadata("driver")` claimed `VRT` even though the on-disk file was a materialized GTiff. That mis-tag propagated through `RasterDriver.writeToBytes` (which keys both the tempfile extension AND the `-of` flag in the inner `gdal_translate` call off `metadata.driver`), causing the serialized `tile.raster` payload to be VRT XML referencing a `/vsimem/` tempfile only reachable on the producing executor. Single-node testing passed by accident; multi-executor clusters hit `file not found` when the VRT was opened elsewhere. Fix: `GDALTranslate.executeTranslate` now records the **output** dataset's driver in its returned metadata (not the input's), and `RasterDriver.writeToBytes` defensively coerces VRT to GTiff on serialization + sniffs the result to refuse shipping VRT bytes. Regression coverage in [`RST_NoVrtPayloadTest`](https://github.com/databrickslabs/geobrix/blob/main/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_NoVrtPayloadTest.scala). +- **`PixelCombineRasters` pixel function now actually fires (`combineavg` / `derivedband` were silently returning one of the inputs).** `gbx_rst_combineavg`, `gbx_rst_combineavg_agg`, `gbx_rst_derivedband`, and `gbx_rst_derivedband_agg` build a multi-source VRT, inject a `Python` band, and re-open it for `gdal_translate`. The previous implementation re-opened the VRT **before** mutating the XML file, so the in-memory `Dataset` handle never saw the pixel function; `gdal.Translate` then fell back to a default multi-source mosaic (last-source-wins per pixel). On co-extensive inputs (e.g. a monthly EO time-series), the output silently equaled one of the inputs — non-deterministic per partition in a distributed setting, producing visible tile-of-different-years patchwork on multi-executor clusters. Fix: `PixelCombineRasters.combine` now injects the pixel function **before** the VRT is re-opened, and pre-creates the per-JVM `NodeFilePathUtil.rootPath` staging dir itself (previously only `ClipToGeom` did, so `combineavg` would `file not found` if it was the first op to hit a fresh JVM). Regression coverage: `RST_AggregationsTest` "CombineAvg actually averages pixel values" (two constant rasters 50 + 100 → output 75). +- **Friendly error on `ARRAY`-function misuse.** Calling `gbx_rst_combineavg`, `gbx_rst_merge`, `gbx_rst_frombands`, or `gbx_rst_mapalgebra` on a single tile column (instead of an `ARRAY` like `collect_list(tile)`) used to surface as a raw `ClassCastException: StructType cannot be cast to ArrayType` from inside Catalyst analysis — untraceable from a notebook. The four expressions now route through `RST_ExpressionUtil.arrayOfTileRasterType`, which raises a clean `IllegalArgumentException` naming the function, the actual type received, and (where applicable) the aggregator companion the user likely wanted, e.g. `gbx_rst_combineavg expects ARRAY (e.g. collect_list(tile) or array(t1, t2, ...)), but received STRUCT<...>. To aggregate the column across rows, use gbx_rst_combineavg_agg(tile).` +- **Docs: `GDAL_VRT_ENABLE_PYTHON` for custom GDAL code paths.** Built-in `combineavg` / `derivedband` calls auto-enable VRT Python via the in-process `GDALManager.withVrtPython` bracket — no cluster config needed. The new [RasterX § VRT Python pixel functions](./packages/rasterx#vrt-python-pixel-functions) section documents how to enable the same evaluation in your own GDAL calls (Python `gdal.SetConfigOption`, cluster `spark.executorEnv`, or the JVM `withVrtPython` helper) and points to the `TRUSTED_MODULES` variant for less-trusted VRT sources. A cross-reference is added in [Security § 6](./security#6-vrt-python-pixel-functions-off-by-default-by-design) explaining why GeoBrix ships the option `NO` by default. +- **`gbx_rst_derivedband` / `gbx_rst_derivedband_agg` numerical-correctness regression coverage.** These functions share the `PixelCombineRasters` code path with `combineavg`, so they were silently no-opping in the same way (returning one of the inputs unchanged on co-extensive stacks). The ordering fix above repairs both call sites, but the existing tests only checked that the result wasn't null — they would have passed either way. This release adds explicit pixel-value assertions: `RST_AggregationsTest` covers the in-process `RST_DerivedBand` path with a doubling pyfunc and a 3-input numpy-mean pyfunc, and `RST_AggEvalTest` covers the Spark-aggregation `rst_derivedband_agg` path end-to-end (three constant-Byte tiles 10/20/30 with a "mean × 2" pyfunc must yield 40 across the result tile). Two previously-passing tests used `def myfunc(x): return x * 2` — an invalid VRT pixel-function signature — and were updated to the canonical `(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs)` shape; they only "passed" before because the pyfunc never actually ran. +- **`gbx_rst_combineavg` / `gbx_rst_combineavg_agg` math corrected (NoData, valid zeros, rounding).** With the pixel function now firing (previous bullet), several latent bugs in the average kernel surface and are fixed in this release. The pyfunc used to sum every source value blindly — including each band's NoData sentinel (e.g. 255 on Byte EO products) — and counted only strictly-positive cells in the divisor (`np.sum(stacked > 0, axis=0)`), which (a) inflated the numerator with NoData and (b) wrongly excluded valid `0` measurements from the divisor. It also used `np.divide(..., casting='unsafe')`, which **truncates** rather than rounds when casting back to an integer output dtype (Byte / UInt16), producing systematic underbias on integer EO stacks. Now the kernel reads each source band's declared NoData (via `BandAccessors.getNoDataValue`, baked into the pyfunc source as a literal list at VRT-write time), masks NoData cells out of both sum and divisor, includes valid `0`s, uses float64 internally, and rounds-to-nearest-even before the unsafe cast when the output dtype is integer. The bogus `np.clip(out_ar, stacked.min(), stacked.max(), ...)` (the bounds were contaminated by NoData sentinels) is removed. When at least one input declares NoData, that value is also stamped on the output band so downstream `GetNoDataValue` reports all-NoData pixels. Regression coverage in `RST_AggregationsTest`: "excludes declared NoData from both sum and divisor", "counts valid 0 cells in the divisor", "rounds (not truncates) when casting to integer output". - **Scalar args without `f.lit(...)`.** Python wrappers auto-wrap `bool` / `int` / `float` / `bytes`; Scala adds typed overloads. SQL was already natively-typed. String literals still wrap in `f.lit(...)` per pyspark's column-ref convention. Details and migration examples in [Scalar values vs `lit(...)` wrapping](#scalar-values-vs-lit-wrapping). - **Example notebooks — EO Series, xView, and enablement diagrams.** New end-to-end walkthroughs under `docs/examples/` covering EO time-series, xView object-detection rasters, and RasterX architecture diagrams. - **Supply-chain hardening (lockdown).** Jobs pinned to the Databricks-hardened runner group (org-level allowlist, ephemeral VMs, constrained secret access); every Maven dependency, transitive dep, plugin, and plugin dependency is PGP-verified against `.maven-keys.list` before any compile or test execution; pip and Maven routed through JFrog with OIDC; init script + pinned package versions vetted; new [Security](./security.mdx) page in the docs. diff --git a/docs/docs/packages/rasterx.mdx b/docs/docs/packages/rasterx.mdx index ce125e5..492b236 100644 --- a/docs/docs/packages/rasterx.mdx +++ b/docs/docs/packages/rasterx.mdx @@ -112,6 +112,46 @@ Every RasterX function returns a tile whose `raster` field is a **self-contained Functions that internally build via an intermediate VRT — `gbx_rst_merge`, `gbx_rst_merge_agg`, `gbx_rst_frombands`, `gbx_rst_combineavg`, `gbx_rst_combineavg_agg`, `gbx_rst_derivedband`, `gbx_rst_derivedband_agg` — materialize the result to GTiff before returning, so downstream stages on different executors see real raster bytes. Inspect a tile's payload format from `tile.metadata.driver`; for any of the functions above, it will read `GTiff` (not `VRT`). See [Beta Release Notes](../beta-release-notes#whats-new-in-v030) for the v0.3.0 correctness fix that introduced this invariant. +## VRT Python pixel functions + +`gbx_rst_combineavg`, `gbx_rst_combineavg_agg`, `gbx_rst_derivedband`, and `gbx_rst_derivedband_agg` evaluate a Python expression on each pixel via GDAL's [VRT Python pixel-function API](https://gdal.org/en/stable/drivers/raster/vrt.html#using-derived-bands-with-pixel-functions-in-python). That API is gated behind the GDAL config option `GDAL_VRT_ENABLE_PYTHON`, which **GeoBrix sets to `NO` at executor startup** (see [Security § Restrict GDAL drivers](../security#6-vrt-python-pixel-functions-off-by-default-by-design)). When you call one of the four functions above, GeoBrix flips the option to `YES` for the duration of that call only — via the internal `GDALManager.withVrtPython` bracket — and restores `NO` immediately on return. You don't need to set anything on the cluster or in your notebook to use the built-in functions. + +### When you need to enable it yourself + +If you're invoking the GDAL Python bindings (`from osgeo import gdal`) **directly** — outside the built-in RasterX functions — and you read a VRT that declares a `Python` band, you'll get an empty/null read unless you enable the option in the same process. Pick one of: + +**Python — programmatic, scoped to your read.** Recommended in all cases. Mirrors what GeoBrix does internally, works for both driver-side `pyspark.sql` calls and inside `mapPartitions` / `mapInPandas` UDFs that load VRT-with-pyfunc via `osgeo.gdal`, and survives interleaving with GeoBrix built-in calls (each GeoBrix call resets the option to `NO` on exit, so re-set it on every read): + +```python +from osgeo import gdal + +gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "YES") +try: + ds = gdal.Open("/path/to/your/vrt-with-pixel-function.vrt") + arr = ds.GetRasterBand(1).ReadAsArray() + ds = None +finally: + gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "NO") +``` + +**Cluster env var — for Python-worker processes only.** Setting `spark.executorEnv.GDAL_VRT_ENABLE_PYTHON YES` on the cluster works for Python UDF workers (a separate process from the JVM, where GDAL initializes from env vars). It does **not** help JVM-side reads — GeoBrix calls `gdal.SetConfigOption("GDAL_VRT_ENABLE_PYTHON", "NO")` at executor JVM startup, and `SetConfigOption` takes precedence over the env var. Prefer the programmatic form above unless you have a strong reason to globally enable. + +**Scala / JVM code.** If you're writing custom Spark expressions that consume Python-pixel VRTs, wrap the read/translate in the same helper GeoBrix uses internally — it refcounts the option so concurrent tasks on the same executor JVM compose safely: + +```scala +import com.databricks.labs.gbx.rasterx.gdal.GDALManager + +val result = GDALManager.withVrtPython { + val ds = org.gdal.gdal.gdal.Open(vrtPath) + // ... GDAL reads / translates here see the Python pixel function ... + ds +} +``` + +### Trusted-modules variant + +GDAL also accepts `GDAL_VRT_ENABLE_PYTHON=TRUSTED_MODULES` plus a `GDAL_VRT_PYTHON_TRUSTED_MODULES` allowlist if you want pixel-function code restricted to specific Python module prefixes. GeoBrix uses the plain `YES` form because the pixel-function source is constructed in-process from trusted (geobrix-generated) strings, never from user-supplied VRT XML on disk. If your custom code path reads VRTs whose `` originates from less-trusted sources, switch to the `TRUSTED_MODULES` form and allowlist only what you intend to load. + ## Usage Examples ### Python/PySpark diff --git a/docs/docs/security.mdx b/docs/docs/security.mdx index a80eb68..572daba 100644 --- a/docs/docs/security.mdx +++ b/docs/docs/security.mdx @@ -206,6 +206,24 @@ publishing details. See [SECURITY.md](https://github.com/databrickslabs/geobrix/blob/main/SECURITY.md) for what to include in the report. +### 6. VRT Python pixel functions: off by default by design + +GDAL's [VRT Python pixel function API](https://gdal.org/en/stable/drivers/raster/vrt.html#using-derived-bands-with-pixel-functions-in-python) +lets a `` element in a VRT XML file execute arbitrary +Python in-process at band-read time. GeoBrix sets `GDAL_VRT_ENABLE_PYTHON=NO` +at executor startup and only flips it to `YES` for the duration of an +individual `combineavg` / `derivedband` call (via the internal +`GDALManager.withVrtPython` bracket). The four built-in functions inject +pyfunc source generated by GeoBrix itself, never by user input. + +If your own code consumes Python-pixel VRTs from less-trusted sources +(e.g. you pull VRT XML from object storage that other principals can +write to), either keep the option `NO` and pre-translate to GTiff, or +switch to `GDAL_VRT_ENABLE_PYTHON=TRUSTED_MODULES` with a narrow +`GDAL_VRT_PYTHON_TRUSTED_MODULES` allowlist. See +[RasterX § VRT Python pixel functions](./packages/rasterx#vrt-python-pixel-functions) +for the full how-to. + ## Next steps - [Installation Guide](./installation) — apply the init script as part of diff --git a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_CombineAvg.scala b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_CombineAvg.scala index 26c5167..323b0ef 100644 --- a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_CombineAvg.scala +++ b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_CombineAvg.scala @@ -18,7 +18,9 @@ case class RST_CombineAvg( ) extends InvokedExpression { /** Raster DataType from the tile array element struct. */ - private def rasterType = tileExpr.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).dataType + private def rasterType = RST_ExpressionUtil.arrayOfTileRasterType( + RST_CombineAvg.name, tileExpr, aggHint = Some("gbx_rst_combineavg_agg") + ) override def children: Seq[Expression] = Seq(tileExpr, ExpressionConfigExpr()) override def dataType: DataType = RST_ExpressionUtil.tileDataType(rasterType) override def nullable: Boolean = true diff --git a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_MapAlgebra.scala b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_MapAlgebra.scala index 2bce2ef..f6a0387 100644 --- a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_MapAlgebra.scala +++ b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_MapAlgebra.scala @@ -23,7 +23,9 @@ case class RST_MapAlgebra( jsonSpecExpr: Expression ) extends InvokedExpression { - private def rasterType = tileExpr.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).dataType + private def rasterType = RST_ExpressionUtil.arrayOfTileRasterType( + RST_MapAlgebra.name, tileExpr, aggHint = None + ) override def children: Seq[Expression] = Seq(tileExpr, jsonSpecExpr, ExpressionConfigExpr()) override def dataType: DataType = RST_ExpressionUtil.tileDataType(rasterType) override def nullable: Boolean = true diff --git a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_Merge.scala b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_Merge.scala index 4b09408..1c453d4 100644 --- a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_Merge.scala +++ b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/RST_Merge.scala @@ -18,7 +18,9 @@ case class RST_Merge( ) extends InvokedExpression { /** Raster DataType from the tile array element struct. */ - private def rasterType = tileExpr.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).dataType + private def rasterType = RST_ExpressionUtil.arrayOfTileRasterType( + RST_Merge.name, tileExpr, aggHint = Some("gbx_rst_merge_agg") + ) override def children: Seq[Expression] = Seq(tileExpr, ExpressionConfigExpr()) override def dataType: DataType = RST_ExpressionUtil.tileDataType(rasterType) override def nullable: Boolean = true diff --git a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/constructor/RST_FromBands.scala b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/constructor/RST_FromBands.scala index 39b6824..fb2fda0 100644 --- a/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/constructor/RST_FromBands.scala +++ b/src/main/scala/com/databricks/labs/gbx/rasterx/expressions/constructor/RST_FromBands.scala @@ -18,7 +18,9 @@ case class RST_FromBands( ) extends InvokedExpression { /** Raster DataType from the bands array element struct. */ - private def rasterType = bandsExpr.dataType.asInstanceOf[ArrayType].elementType.asInstanceOf[StructType].fields(1).dataType + private def rasterType = RST_ExpressionUtil.arrayOfTileRasterType( + RST_FromBands.name, bandsExpr, aggHint = None + ) override def children: Seq[Expression] = Seq(bandsExpr, ExpressionConfigExpr()) override def dataType: DataType = RST_ExpressionUtil.tileDataType(rasterType) override def nullable: Boolean = true diff --git a/src/main/scala/com/databricks/labs/gbx/rasterx/operations/CombineAVG.scala b/src/main/scala/com/databricks/labs/gbx/rasterx/operations/CombineAVG.scala index e06ea95..d30d73b 100644 --- a/src/main/scala/com/databricks/labs/gbx/rasterx/operations/CombineAVG.scala +++ b/src/main/scala/com/databricks/labs/gbx/rasterx/operations/CombineAVG.scala @@ -2,25 +2,75 @@ package com.databricks.labs.gbx.rasterx.operations import org.gdal.gdal.Dataset -/** Pixel-wise average of input rasters via VRT Python pixel function; output type double. Returns (Dataset, metadata). Caller must release. */ +/** + * Pixel-wise mean across N input rasters via a VRT Python pixel function. + * + * Each input contributes one band; the embedded Python function reads each + * source's declared NoData (via `BandAccessors.getNoDataValue`, baked into + * the pyfunc source as a literal list at VRT-write time) and excludes those + * cells from BOTH the sum and the divisor. Cells with valid value `0` count + * toward the mean (the previous `>0` mask wrongly excluded them). + * + * Output band keeps the VRT's dtype (typically the input dtype). When that + * dtype is integer, the mean is rounded to nearest int before the unsafe + * cast — bare truncation produced systematic underbias on Byte / UInt16 EO + * stacks. When all inputs at a pixel are NoData, the output cell carries the + * first declared input NoData (or 0 if no input declared one), and that + * NoData value is also stamped on the output band so downstream consumers + * can detect all-NoData pixels with `GetNoDataValue`. + */ object CombineAVG { - /** Average = sum/div per pixel (div = count of non-zero); delegates to PixelCombineRasters. */ + /** Average per-pixel, excluding per-source NoData; preserves first-source NoData on output band. */ def compute(rasters: Array[Dataset], options: Map[String, String]): (Dataset, Map[String, String]) = { + val sourceNoData: Array[Option[Double]] = rasters.map { ds => + val v = BandAccessors.getNoDataValue(ds.GetRasterBand(1)) + if (v.isNaN) None else Some(v) + } + val nodataListLiteral = sourceNoData + .map(_.map(d => f"$d%s").getOrElse("None")) + .mkString("[", ", ", "]") + val fallback: Double = sourceNoData.collectFirst { case Some(v) => v }.getOrElse(0.0) + val fallbackLiteral = f"$fallback%s" - val pythonFunc = """ - |import numpy as np - |import sys - | - |def average(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs): - | stacked_array = np.array(in_ar) - | pixel_sum = np.sum(stacked_array, axis=0) - | div = np.sum(stacked_array > 0, axis=0) - | div = np.where(div==0, 1, div) - | np.divide(pixel_sum, div, out=out_ar, casting='unsafe') - | np.clip(out_ar, stacked_array.min(), stacked_array.max(), out=out_ar) - |""".stripMargin - PixelCombineRasters.combine(rasters, options, pythonFunc, "average") + // Pixel function code is interpolated into VRT XML at PixelCombineRasters- + // build time. Keep it self-contained — only depends on numpy + the two + // injected literals (NODATA list, FALLBACK scalar). + val pythonFunc = + s""" + |import numpy as np + | + |NODATA = $nodataListLiteral + |FALLBACK = $fallbackLiteral + | + |def average(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs): + | stacked = np.asarray(in_ar, dtype=np.float64) + | valid = np.ones(stacked.shape, dtype=bool) + | for i, nd in enumerate(NODATA): + | if nd is not None: + | valid[i] = stacked[i] != nd + | sums = np.where(valid, stacked, 0.0).sum(axis=0) + | counts = valid.sum(axis=0) + | means = np.where(counts > 0, sums / np.maximum(counts, 1), FALLBACK) + | if np.issubdtype(out_ar.dtype, np.integer): + | np.copyto(out_ar, np.rint(means), casting='unsafe') + | else: + | np.copyto(out_ar, means, casting='unsafe') + |""".stripMargin + + val (resultDs, resultMeta) = PixelCombineRasters.combine(rasters, options, pythonFunc, "average") + + // Stamp the chosen NoData onto the output band so callers can detect + // all-NoData pixels. Only do this when at least one input declared + // NoData — otherwise we'd invent a sentinel that wasn't present. + if (sourceNoData.exists(_.isDefined) && resultDs != null) { + scala.util.Try { + resultDs.GetRasterBand(1).SetNoDataValue(fallback) + resultDs.FlushCache() + } + } + + (resultDs, resultMeta) } } diff --git a/src/main/scala/com/databricks/labs/gbx/rasterx/operations/PixelCombineRasters.scala b/src/main/scala/com/databricks/labs/gbx/rasterx/operations/PixelCombineRasters.scala index a7ed271..c79a2c3 100644 --- a/src/main/scala/com/databricks/labs/gbx/rasterx/operations/PixelCombineRasters.scala +++ b/src/main/scala/com/databricks/labs/gbx/rasterx/operations/PixelCombineRasters.scala @@ -22,6 +22,13 @@ object PixelCombineRasters { val uuid = java.util.UUID.randomUUID().toString.replace("-", "_") val outShortName = dss.head.GetDriver().getShortName val extension = GDAL.getExtension(outShortName) + // Ensure the per-JVM staging dir exists. gdal.BuildVRT silently + // produces an unwritten Dataset if the parent dir is missing, and + // the subsequent RasterDriver.read(vrtPath) then throws + // "No such file or directory". Only ClipToGeom previously created + // this dir, leaving combineavg / derivedband broken if they were + // the first op to hit a fresh JVM. + Files.createDirectories(NodeFilePathUtil.rootPath) val vrtPath = s"${NodeFilePathUtil.rootPath}/combine_rasters_vrt_$uuid.vrt" val rasterPath = s"/vsimem/combine_rasters_$uuid.$extension" @@ -33,12 +40,19 @@ object PixelCombineRasters { ) vrtRaster._1.delete() + // Inject the pixel function BEFORE re-opening the VRT. gdal.Open + // parses VRT XML into an in-memory band structure at Open time, so + // any mutation of the on-disk file performed after Open is invisible + // to the Dataset handle passed to gdal.Translate — Translate then + // runs a default multi-source mosaic (last-source-wins per pixel) + // instead of evaluating the pixel function, silently returning one + // of the inputs. + addPixelFunction(vrtPath, pythonFunc, pythonFuncName) + // GDAL evaluates Python during VRT read-back and translate. val result = GDALManager.withVrtPython { val vrtRefreshed = RasterDriver.read(vrtPath, vrtRaster._2) - addPixelFunction(vrtPath, pythonFunc, pythonFuncName) - GDALTranslate.executeTranslate( rasterPath, vrtRefreshed, diff --git a/src/main/scala/com/databricks/labs/gbx/rasterx/util/RST_ExpressionUtil.scala b/src/main/scala/com/databricks/labs/gbx/rasterx/util/RST_ExpressionUtil.scala index ca24da1..ad8b1fe 100644 --- a/src/main/scala/com/databricks/labs/gbx/rasterx/util/RST_ExpressionUtil.scala +++ b/src/main/scala/com/databricks/labs/gbx/rasterx/util/RST_ExpressionUtil.scala @@ -20,6 +20,48 @@ object RST_ExpressionUtil { /** DataType of the raster field (second field) of the tile struct for the given tile expression. */ def rasterType(tileExpr: Expression): DataType = tileExpr.dataType.asInstanceOf[StructType].fields(1).dataType + /** + * Raster DataType inside an `ARRAY` expression, with a friendly + * IllegalArgumentException when the caller actually passed a single tile. + * + * Used by the non-aggregating array-of-tiles functions + * (`gbx_rst_combineavg`, `gbx_rst_merge`, `gbx_rst_frombands`, + * `gbx_rst_mapalgebra`). Without this guard, callers who write + * `gbx_rst_combineavg(tile)` instead of `gbx_rst_combineavg(collect_list(tile))` + * or the aggregator variant get a raw `ClassCastException: StructType + * cannot be cast to ArrayType` from inside Spark's CheckAnalysis, + * which is hostile and untraceable from a notebook. + * + * `funcName` is the SQL-facing name surfaced in the error. + * `aggHint` is an optional pointer to the aggregator companion + * (e.g. "gbx_rst_combineavg_agg") for functions where the typical + * mistake is reaching for the non-agg form when an aggregate across + * rows was wanted. + * + * Note: Spark 4.0's `AnalysisException` no longer exposes a + * `(String)` constructor (only the error-class form), so the error + * is raised as `IllegalArgumentException` — still surfaces during + * Catalyst analysis with the full message, and avoids depending on + * Spark-internal error-class catalogs. + */ + def arrayOfTileRasterType( + funcName: String, + tileExpr: Expression, + aggHint: Option[String] = None + ): DataType = tileExpr.dataType match { + case ArrayType(StructType(fields), _) if fields.length >= 2 => + fields(1).dataType + case other => + val aggSuggestion = aggHint + .map(name => s" To aggregate the column across rows, use $name(tile).") + .getOrElse("") + throw new IllegalArgumentException( + s"$funcName expects ARRAY (e.g. collect_list(tile) " + + s"or array(t1, t2, ...)), but received ${other.simpleString}." + + aggSuggestion + ) + } + /** StructType for a tile with the given tile expression's raster type (cellid, raster, metadata). */ def tileDataType(tileExpr: Expression): DataType = { val rasterDataType = rasterType(tileExpr) diff --git a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggEvalTest.scala b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggEvalTest.scala index 30fedb1..bfd95b3 100644 --- a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggEvalTest.scala +++ b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggEvalTest.scala @@ -1,15 +1,34 @@ package com.databricks.labs.gbx.rasterx.expressions import com.databricks.labs.gbx.rasterx.functions +import com.databricks.labs.gbx.rasterx.gdal.RasterDriver import com.databricks.labs.gbx.udfs.st_buffer import org.apache.spark.sql.DataFrame import org.apache.spark.sql.catalyst.plans.PlanTest import org.apache.spark.sql.functions._ import org.apache.spark.sql.test.SilentSparkSession +import org.gdal.gdal.gdal +import org.gdal.gdalconst.gdalconstConstants import org.scalatest.matchers.should.Matchers._ class RST_AggEvalTest extends PlanTest with SilentSparkSession { + /** + * Valid GDAL VRT Python pixel-function signature: (in_ar, out_ar, xoff, + * yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, + * **kwargs). The previous test used `def myfunc(x): return x*2`, which + * GDAL silently rejected; the surrounding `noException` only checked that + * nothing threw, so the malformed pyfunc went unnoticed for as long as + * the underlying PixelCombineRasters bug prevented any pyfunc from + * actually firing. + */ + private val doublePyFunc = + """ + |import numpy as np + |def myfunc(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs): + | out_ar[:] = np.mean(np.asarray(in_ar, dtype=np.float64), axis=0) * 2 + |""".stripMargin + test("RST_AggEvalTest should evaluate expressions on raster columns") { val sc = spark import com.databricks.labs.gbx.rasterx.functions._ @@ -18,8 +37,6 @@ class RST_AggEvalTest extends PlanTest with SilentSparkSession { val tifPath = this.getClass.getResource("/modis/").toString - val pyfunc = "def myfunc(x):\n return x * 2" - def runQuery(df: DataFrame): Unit = { df .withColumn("bbox", rst_boundingbox(col("raster"))) @@ -28,7 +45,7 @@ class RST_AggEvalTest extends PlanTest with SilentSparkSession { .groupBy(lit(1)) .agg( rst_combineavg_agg(col("raster")), - rst_derivedband_agg(col("raster"), pyfunc, "myfunc"), + rst_derivedband_agg(col("raster"), doublePyFunc, "myfunc"), rst_merge_agg(col("raster")) ) .collect() @@ -52,4 +69,102 @@ class RST_AggEvalTest extends PlanTest with SilentSparkSession { } + test("rst_combineavg on a single tile column raises a friendly error pointing at the _agg form") { + // Regression for the user-reported notebook error: + // .selectExpr("gbx_rst_combineavg(tile) AS tile") + // The non-agg form expects ARRAY; passing a single tile struct + // previously produced a raw ClassCastException from inside Catalyst + // analysis. After RST_ExpressionUtil.arrayOfTileRasterType is in + // place we should see an IllegalArgumentException with a message that + // names the function, the actual type received, and the aggregator + // companion that the user likely wanted. + val sc = spark + import com.databricks.labs.gbx.rasterx.functions._ + import sc.implicits._ + functions.register(spark) + + val tifPath = this.getClass.getResource("/modis/").toString + val df = Seq( + s"$tifPath/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF" + ).toDF("path").withColumn("tile", rst_fromfile(col("path"), lit("GTiff"))) + + val thrown = intercept[Throwable] { + df.selectExpr("gbx_rst_combineavg(tile) AS tile").collect() + } + // Spark wraps analysis-time IllegalArgumentException so the actual + // class can be either IllegalArgumentException or one of Spark's + // catalyst-wrapper types — what matters is that our diagnostic + // message survives in the chain. + val joined = LazyList + .iterate(Option(thrown))(_.flatMap(t => Option(t.getCause))) + .takeWhile(_.isDefined) + .flatMap(_.map(_.getMessage).filter(_ != null)) + .mkString(" || ") + joined should include ("gbx_rst_combineavg expects ARRAY") + joined should include ("gbx_rst_combineavg_agg") + } + + test("rst_derivedband_agg actually transforms pixel values (parity with combineavg_agg fix)") { + // End-to-end Spark aggregation: three constant Byte tiles (10, 20, 30) + // averaged then doubled by the pyfunc should yield 40 everywhere + // (mean(10,20,30)=20, *2 = 40). Before the PixelCombineRasters + // ordering fix this returned one of the inputs unchanged through the + // aggregator path, so the output would be 10 / 20 / 30 — not 40. + val sc = spark + import com.databricks.labs.gbx.rasterx.functions._ + import sc.implicits._ + functions.register(spark) + + val tmpDir = java.nio.file.Files.createTempDirectory("gbx_derivedband_agg_").toFile + + val w = 8; val h = 8 + // Inline byte raster writer; can't reuse RST_AggregationsTest's helper + // from a separate suite without lifting it to a shared location and + // this is the only place outside that suite that needs it. + def writeByteConst(p: String, v: Int): Unit = { + val drv = gdal.GetDriverByName("GTiff") + val ds = drv.Create(p, w, h, 1, gdalconstConstants.GDT_Byte, Array[String]("COMPRESS=DEFLATE")) + ds.SetGeoTransform(Array[Double](149.0, 0.01, 0.0, -35.0, 0.0, -0.01)) + val sr = new org.gdal.osr.SpatialReference() + sr.ImportFromEPSG(4326) + ds.SetProjection(sr.ExportToWkt()) + val band = ds.GetRasterBand(1) + band.WriteRaster(0, 0, w, h, Array.fill[Byte](w * h)(v.toByte)) + band.FlushCache() + ds.FlushCache() + ds.delete() + } + val paths = Seq(10, 20, 30).map { v => + val p = s"${tmpDir.getAbsolutePath}/const_$v.tif" + writeByteConst(p, v) + p + } + + try { + val agg = paths.toDF("path") + .withColumn("tile", rst_fromfile(col("path"), lit("GTiff"))) + .groupBy(lit(1).alias("g")) + .agg(rst_derivedband_agg(col("tile"), doublePyFunc, "myfunc").alias("out")) + .select(col("out.raster").alias("raster")) + + val bytes = agg.collect().head.getAs[Array[Byte]]("raster") + // Decode in-memory GTiff bytes; verify uniform 40. + val mem = s"/vsimem/derivedband_agg_check_${java.util.UUID.randomUUID()}.tif" + gdal.FileFromMemBuffer(mem, bytes) + val ds = gdal.Open(mem) + try { + val buf = Array.ofDim[Double](ds.GetRasterXSize * ds.GetRasterYSize) + ds.GetRasterBand(1).ReadRaster(0, 0, ds.GetRasterXSize, ds.GetRasterYSize, gdalconstConstants.GDT_Float64, buf) + buf.min shouldBe 40.0 +- 0.5 + buf.max shouldBe 40.0 +- 0.5 + } finally { + RasterDriver.releaseDataset(ds) + gdal.Unlink(mem) + } + } finally { + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + } diff --git a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggregationsTest.scala b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggregationsTest.scala index 381fb99..526f37b 100644 --- a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggregationsTest.scala +++ b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_AggregationsTest.scala @@ -1,7 +1,10 @@ package com.databricks.labs.gbx.rasterx.expressions import com.databricks.labs.gbx.rasterx.gdal.{GDALManager, RasterDriver} +import com.databricks.labs.gbx.rasterx.operations.BandAccessors import org.gdal.gdal.{Dataset, gdal} +import org.gdal.gdalconst.gdalconstConstants +import org.gdal.osr.SpatialReference import org.scalatest.BeforeAndAfterAll import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers._ @@ -94,6 +97,144 @@ class RST_AggregationsTest extends AnyFunSuite with BeforeAndAfterAll { RasterDriver.releaseDataset(resultDs) } + test("CombineAvg actually averages pixel values (regression: pixel function must fire)") { + // Synthetic control with known answer. Two constant Byte rasters + // (50 and 100); the mean must be 75 everywhere. Before the + // PixelCombineRasters ordering fix, gdal.Open parsed the VRT XML + // BEFORE addPixelFunction wrote PixelFunctionLanguage into the file, + // so the in-memory Dataset never saw the Python function and + // gdal.Translate fell back to a default multi-source mosaic + // (last-source-wins per pixel). On co-extensive inputs the output + // then equaled the last input — 50 or 100, never 75 — silently and + // without any error. This test would fail (output 100 instead of + // 75); after the fix it passes. + val tmpDir = Files.createTempDirectory("gbx_combineavg_").toFile + val const50 = makeConstantByteRaster(s"${tmpDir.getAbsolutePath}/const_50.tif", 50) + val const100 = makeConstantByteRaster(s"${tmpDir.getAbsolutePath}/const_100.tif", 100) + try { + val (_, resultDs, _) = RST_CombineAvg.execute( + Seq((1L, const50, Map.empty), (1L, const100, Map.empty)) + ) + try { + val w = resultDs.GetRasterXSize + val h = resultDs.GetRasterYSize + val band = resultDs.GetRasterBand(1) + val buf = Array.ofDim[Double](w * h) + band.ReadRaster(0, 0, w, h, gdalconstConstants.GDT_Float64, buf) + buf.min shouldBe 75.0 +- 0.5 + buf.max shouldBe 75.0 +- 0.5 + (buf.sum / buf.length) shouldBe 75.0 +- 0.5 + } finally RasterDriver.releaseDataset(resultDs) + } finally { + RasterDriver.releaseDataset(const50) + RasterDriver.releaseDataset(const100) + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + + test("CombineAvg excludes declared NoData from both sum and divisor") { + // Two 4x4 Byte rasters, NoData=255 declared on input A: + // A: half of cells are 100, half are 255 (NoData) + // B: all cells are 50 + // Expected output per cell: + // - A=100 cells: mean(100, 50) = 75 + // - A=255 cells: A excluded, mean = 50 + // Before fix #2 the pyfunc summed 255 in (sum=305, div=2) → 152. + val tmpDir = Files.createTempDirectory("gbx_combineavg_nodata_").toFile + val w = 4; val h = 4 + val aVals = Array[Byte]( + 100, 100, 100, 100, + 100, 100, 100, 100, + -1, -1, -1, -1, // 0xFF = 255 + -1, -1, -1, -1 + ) + val bVals = Array.fill[Byte](w * h)(50) + val a = makeByteRaster(s"${tmpDir.getAbsolutePath}/a.tif", aVals, w, h, nodata = Some(255)) + val b = makeByteRaster(s"${tmpDir.getAbsolutePath}/b.tif", bVals, w, h, nodata = None) + try { + val (_, resultDs, _) = RST_CombineAvg.execute( + Seq((1L, a, Map.empty), (1L, b, Map.empty)) + ) + try { + val band = resultDs.GetRasterBand(1) + val buf = Array.ofDim[Double](w * h) + band.ReadRaster(0, 0, w, h, gdalconstConstants.GDT_Float64, buf) + // Top half: both inputs valid → 75 + buf.slice(0, 8).forall(v => math.abs(v - 75.0) <= 0.5) shouldBe true + // Bottom half: A is NoData, only B contributes → 50 + buf.slice(8, 16).forall(v => math.abs(v - 50.0) <= 0.5) shouldBe true + // Output NoData should be stamped (sourced from input A's 255). + val outNd = BandAccessors.getNoDataValue(band) + outNd shouldBe 255.0 +- 0.5 + } finally RasterDriver.releaseDataset(resultDs) + } finally { + RasterDriver.releaseDataset(a) + RasterDriver.releaseDataset(b) + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + + test("CombineAvg counts valid 0 cells in the divisor") { + // Two 4x4 Byte rasters, NO NoData declared on either: + // A: all cells = 0 (a valid measurement of zero, not NoData) + // B: all cells = 100 + // Expected output: 50 everywhere — both contribute to divisor. + // Before fix #2 the pyfunc divisor was `np.sum(stacked > 0, axis=0)` + // which counted only B → output 100 everywhere. + val tmpDir = Files.createTempDirectory("gbx_combineavg_zero_").toFile + val w = 4; val h = 4 + val a = makeByteRaster(s"${tmpDir.getAbsolutePath}/a.tif", Array.fill[Byte](w * h)(0), w, h, nodata = None) + val b = makeByteRaster(s"${tmpDir.getAbsolutePath}/b.tif", Array.fill[Byte](w * h)(100), w, h, nodata = None) + try { + val (_, resultDs, _) = RST_CombineAvg.execute( + Seq((1L, a, Map.empty), (1L, b, Map.empty)) + ) + try { + val w2 = resultDs.GetRasterXSize + val h2 = resultDs.GetRasterYSize + val buf = Array.ofDim[Double](w2 * h2) + resultDs.GetRasterBand(1).ReadRaster(0, 0, w2, h2, gdalconstConstants.GDT_Float64, buf) + buf.min shouldBe 50.0 +- 0.5 + buf.max shouldBe 50.0 +- 0.5 + } finally RasterDriver.releaseDataset(resultDs) + } finally { + RasterDriver.releaseDataset(a) + RasterDriver.releaseDataset(b) + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + + test("CombineAvg rounds (not truncates) when casting to integer output") { + // mean(50, 99) = 74.5; with truncation that becomes 74, with rounding 74 + // (banker's rounding → nearest even). Pick mean(51, 100) = 75.5 → 76. + // Before fix #2: `np.divide(..., casting='unsafe')` truncated → 75. + val tmpDir = Files.createTempDirectory("gbx_combineavg_round_").toFile + val w = 4; val h = 4 + val a = makeByteRaster(s"${tmpDir.getAbsolutePath}/a.tif", Array.fill[Byte](w * h)(51), w, h, nodata = None) + val b = makeByteRaster(s"${tmpDir.getAbsolutePath}/b.tif", Array.fill[Byte](w * h)(100), w, h, nodata = None) + try { + val (_, resultDs, _) = RST_CombineAvg.execute( + Seq((1L, a, Map.empty), (1L, b, Map.empty)) + ) + try { + val w2 = resultDs.GetRasterXSize + val h2 = resultDs.GetRasterYSize + val buf = Array.ofDim[Double](w2 * h2) + resultDs.GetRasterBand(1).ReadRaster(0, 0, w2, h2, gdalconstConstants.GDT_Float64, buf) + buf.min shouldBe 76.0 +- 0.5 + buf.max shouldBe 76.0 +- 0.5 + } finally RasterDriver.releaseDataset(resultDs) + } finally { + RasterDriver.releaseDataset(a) + RasterDriver.releaseDataset(b) + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + test("CombineAvg should generate operation metadata") { val inputMetadata = Map( "TEST_KEY" -> "TEST_VALUE", @@ -115,6 +256,68 @@ class RST_AggregationsTest extends AnyFunSuite with BeforeAndAfterAll { // RST_DerivedBand Tests (3 tests) // ==================================================================== + test("DerivedBand actually transforms pixel values (regression: pixel function must fire)") { + // Synthetic doubling — confirms the user-supplied pyfunc fires through + // the same PixelCombineRasters path as CombineAvg. Before the fix #1 + // ordering bug was repaired, the pyfunc never executed and gdal.Translate + // returned a multi-source mosaic of the inputs, so a single-input + // doubling pyfunc would have silently returned the input unchanged. + val tmpDir = Files.createTempDirectory("gbx_derivedband_double_").toFile + val input = makeConstantByteRaster(s"${tmpDir.getAbsolutePath}/in.tif", 50) + val pyfunc = """ + |import numpy as np + |def doubleit(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs): + | out_ar[:] = np.asarray(in_ar[0], dtype=np.float64) * 2 + |""".stripMargin + try { + val (resultDs, _) = RST_DerivedBand.execute(Seq(input), Map.empty, pyfunc, "doubleit") + try { + val w = resultDs.GetRasterXSize + val h = resultDs.GetRasterYSize + val buf = Array.ofDim[Double](w * h) + resultDs.GetRasterBand(1).ReadRaster(0, 0, w, h, gdalconstConstants.GDT_Float64, buf) + buf.min shouldBe 100.0 +- 0.5 + buf.max shouldBe 100.0 +- 0.5 + } finally RasterDriver.releaseDataset(resultDs) + } finally { + RasterDriver.releaseDataset(input) + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + + test("DerivedBand averages multi-input pixel values") { + // Three co-extensive inputs 50 / 100 / 150 with a numpy-mean pyfunc; + // expected output 100 everywhere. Equivalent to CombineAvg on the + // simple case but exercising the user-pyfunc code path. + val tmpDir = Files.createTempDirectory("gbx_derivedband_avg_").toFile + val a = makeConstantByteRaster(s"${tmpDir.getAbsolutePath}/a.tif", 50) + val b = makeConstantByteRaster(s"${tmpDir.getAbsolutePath}/b.tif", 100) + val c = makeConstantByteRaster(s"${tmpDir.getAbsolutePath}/c.tif", 150) + val pyfunc = """ + |import numpy as np + |def meanof(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs): + | out_ar[:] = np.mean(np.asarray(in_ar, dtype=np.float64), axis=0) + |""".stripMargin + try { + val (resultDs, _) = RST_DerivedBand.execute(Seq(a, b, c), Map.empty, pyfunc, "meanof") + try { + val w = resultDs.GetRasterXSize + val h = resultDs.GetRasterYSize + val buf = Array.ofDim[Double](w * h) + resultDs.GetRasterBand(1).ReadRaster(0, 0, w, h, gdalconstConstants.GDT_Float64, buf) + buf.min shouldBe 100.0 +- 0.5 + buf.max shouldBe 100.0 +- 0.5 + } finally RasterDriver.releaseDataset(resultDs) + } finally { + RasterDriver.releaseDataset(a) + RasterDriver.releaseDataset(b) + RasterDriver.releaseDataset(c) + tmpDir.listFiles().foreach(_.delete()) + tmpDir.delete() + } + } + test("DerivedBand should apply simple Python averaging function") { val pyfunc = """ import numpy as np @@ -255,4 +458,37 @@ def identity(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize derivedDs should not be null RasterDriver.releaseDataset(derivedDs) } + + private def makeConstantByteRaster(path: String, value: Int, w: Int = 32, h: Int = 32): Dataset = + makeByteRaster(path, Array.fill[Byte](w * h)(value.toByte), w, h, nodata = None) + + /** + * Create a small GTiff backed by `values` (length w*h, row-major). Used by + * the CombineAvg / DerivedBand numerical-correctness tests that need + * deterministic synthetic inputs with a known expected mean. EPSG:4326 + + * a fixed geotransform so gdalbuildvrt aligns the rasters. If `nodata` is + * set, it's stamped on the band so the CombineAVG pyfunc can mask those + * cells out of both sum and divisor. + * + * Note: closes the writer Dataset and re-opens it for reading. Without + * this, downstream gdalbuildvrt sometimes sees an unflushed file and + * pulls in zeros for the pixel values. Also uses the byte[] overload of + * WriteRaster — the int[]/GDT_Byte combo silently writes nothing. + */ + private def makeByteRaster(path: String, values: Array[Byte], w: Int, h: Int, nodata: Option[Int]): Dataset = { + require(values.length == w * h, s"expected ${w * h} values, got ${values.length}") + val drv = gdal.GetDriverByName("GTiff") + val writer = drv.Create(path, w, h, 1, gdalconstConstants.GDT_Byte, Array[String]("COMPRESS=DEFLATE")) + writer.SetGeoTransform(Array[Double](149.0, 0.01, 0.0, -35.0, 0.0, -0.01)) + val sr = new SpatialReference() + sr.ImportFromEPSG(4326) + writer.SetProjection(sr.ExportToWkt()) + val band = writer.GetRasterBand(1) + nodata.foreach(nd => band.SetNoDataValue(nd.toDouble)) + band.WriteRaster(0, 0, w, h, values) + band.FlushCache() + writer.FlushCache() + writer.delete() + gdal.Open(path) + } } diff --git a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_ExpressionExecuteTest.scala b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_ExpressionExecuteTest.scala index 661a10d..87b28fd 100644 --- a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_ExpressionExecuteTest.scala +++ b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_ExpressionExecuteTest.scala @@ -71,7 +71,18 @@ class RST_ExpressionExecuteTest extends AnyFunSuite with BeforeAndAfterAll { } test("RST_DerivedBand should compute derived band from raster") { - val pyfunc = "def compute(pixel):\n return pixel[0] * 2" + // Valid GDAL VRT Python pixel-function signature. The earlier + // `def compute(pixel): return pixel[0]*2` shape silently never + // executed (PixelCombineRasters opened the VRT before injecting + // the pixel function), so the malformed signature went unnoticed. + // After the ordering fix the pixel function actually fires, so + // the signature has to be correct. + val pyfunc = + """ + |import numpy as np + |def compute(in_ar, out_ar, xoff, yoff, xsize, ysize, raster_xsize, raster_ysize, buf_radius, gt, **kwargs): + | out_ar[:] = np.array(in_ar[0]) * 2 + |""".stripMargin val (derivedDs, _) = RST_DerivedBand.execute(Seq(ds, ds), Map.empty, pyfunc, "compute") derivedDs != null shouldBe true derivedDs.getRasterCount == ds.getRasterCount shouldBe true diff --git a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_NoVrtPayloadTest.scala b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_NoVrtPayloadTest.scala index ce998ae..3525be8 100644 --- a/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_NoVrtPayloadTest.scala +++ b/src/test/scala/com/databricks/labs/gbx/rasterx/expressions/RST_NoVrtPayloadTest.scala @@ -102,25 +102,19 @@ class RST_NoVrtPayloadTest extends PlanTest with SilentSparkSession { test("rst_combineavg_agg returns self-contained GTiff bytes (no VRT payload)") { val sc = spark import com.databricks.labs.gbx.rasterx.functions._ - import com.databricks.labs.gbx.udfs.st_buffer import sc.implicits._ functions.register(spark) val tifPath = this.getClass.getResource("/modis/").toString - // rst_clip preamble mirrors RST_AggEvalTest — exercises the same - // path-setup that PixelCombineRasters relies on (the VRT staging dir - // under NodeFilePathUtil.rootPath gets touched during clip). Without - // it, combineavg's gdalbuildvrt → read-back-VRT step can't find its - // own freshly-written VRT in the test environment. + // No rst_clip warmup needed: PixelCombineRasters.combine now + // pre-creates NodeFilePathUtil.rootPath itself, so combineavg_agg + // is safe as the first op in a fresh JVM. val df = Seq( s"$tifPath/MCD43A4.A2018185.h10v07.006.2018194033728_B01.TIF", s"$tifPath/MCD43A4.A2018185.h10v07.006.2018194033728_B02.TIF", s"$tifPath/MCD43A4.A2018185.h10v07.006.2018194033728_B03.TIF" ).toDF("path") .withColumn("tile", rst_fromfile(col("path"), lit("GTiff"))) - .withColumn("bbox", rst_boundingbox(col("tile"))) - .withColumn("clipper", st_buffer(col("bbox"), lit(-500000.0))) - .withColumn("tile", rst_clip(col("tile"), col("clipper"), lit(true))) .groupBy(lit(1).alias("g")) .agg(rst_combineavg_agg(col("tile")).alias("avg")) .select(