diff --git a/LICENSE-binary b/LICENSE-binary index e04f78a80070b..a578bc76478c7 100644 --- a/LICENSE-binary +++ b/LICENSE-binary @@ -242,7 +242,7 @@ io.dropwizard.metrics:metrics-json io.dropwizard.metrics:metrics-jvm io.fabric8:kubernetes-client io.fabric8:kubernetes-client-api -io.fabric8:kubernetes-httpclient-okhttp +io.fabric8:kubernetes-httpclient-vertx io.fabric8:kubernetes-model-admissionregistration io.fabric8:kubernetes-model-apiextensions io.fabric8:kubernetes-model-apps @@ -305,6 +305,7 @@ joda-time:joda-time net.sf.opencsv:opencsv net.sf.supercsv:super-csv net.sf.jpam:jpam +org.apache.arrow:arrow-compression org.apache.arrow:arrow-format org.apache.arrow:arrow-memory-core org.apache.arrow:arrow-memory-netty @@ -349,7 +350,7 @@ org.apache.logging.log4j:log4j-1.2-api org.apache.logging.log4j:log4j-api org.apache.logging.log4j:log4j-core org.apache.logging.log4j:log4j-layout-template-json -org.apache.logging.log4j:log4j-slf4j-impl +org.apache.logging.log4j:log4j-slf4j2-impl org.apache.orc:orc-core org.apache.orc:orc-format org.apache.orc:orc-mapreduce @@ -510,6 +511,7 @@ javax.transaction:transaction-api Common Development and Distribution License (CDDL) 1.1 ------------------------------------------------------ javax.transaction:jta http://www.oracle.com/technetwork/java/index.html +javax.servlet:javax.servlet-api https://oss.oracle.com/licenses/CDDL+GPL-1.1 javax.xml.bind:jaxb-api https://github.com/javaee/jaxb-v2 diff --git a/R/pkg/DESCRIPTION b/R/pkg/DESCRIPTION index 78283891dea81..8259cae059502 100644 --- a/R/pkg/DESCRIPTION +++ b/R/pkg/DESCRIPTION @@ -1,6 +1,6 @@ Package: SparkR Type: Package -Version: 4.2.0.1-4.3.0-1 +Version: 4.2.0.1-4.3.0-2 Title: R Front End for 'Apache Spark' Description: Provides an R Front end for 'Apache Spark' . Authors@R: diff --git a/assembly/pom.xml b/assembly/pom.xml index aba28e2cf858d..277d7e0a6dc86 100644 --- a/assembly/pom.xml +++ b/assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/common/kvstore/pom.xml b/common/kvstore/pom.xml index 6a936b31f28a0..0410ed159af2b 100644 --- a/common/kvstore/pom.xml +++ b/common/kvstore/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java index d80e002ddb06e..03bd2a3f12485 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/LevelDBIterator.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.lang.ref.Cleaner; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; @@ -289,17 +290,10 @@ private boolean isEndMarker(byte[] key) { key[key.length - 1] == LevelDBTypeInfo.END_MARKER[0]); } + @VisibleForTesting static int compare(byte[] a, byte[] b) { - int diff = 0; - int minLen = Math.min(a.length, b.length); - for (int i = 0; i < minLen; i++) { - diff += (a[i] - b[i]); - if (diff != 0) { - return diff; - } - } - - return a.length - b.length; + // Unsigned bytewise comparison, matching LevelDB's key ordering. + return Arrays.compareUnsigned(a, b); } static class ResourceCleaner implements Runnable { diff --git a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java index d37a4bd7b0b2d..a77f399a49c8b 100644 --- a/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java +++ b/common/kvstore/src/main/java/org/apache/spark/util/kvstore/RocksDBIterator.java @@ -278,17 +278,10 @@ private boolean isEndMarker(byte[] key) { key[key.length - 1] == RocksDBTypeInfo.END_MARKER[0]); } + @VisibleForTesting static int compare(byte[] a, byte[] b) { - int diff = 0; - int minLen = Math.min(a.length, b.length); - for (int i = 0; i < minLen; i++) { - diff += (a[i] - b[i]); - if (diff != 0) { - return diff; - } - } - - return a.length - b.length; + // Unsigned bytewise comparison, matching RocksDB's key ordering. + return Arrays.compareUnsigned(a, b); } static class ResourceCleaner implements Runnable { diff --git a/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorCompareSuite.java b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorCompareSuite.java new file mode 100644 index 0000000000000..b399d250bff25 --- /dev/null +++ b/common/kvstore/src/test/java/org/apache/spark/util/kvstore/DBIteratorCompareSuite.java @@ -0,0 +1,72 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.kvstore; + +import org.junit.jupiter.api.Test; +import static org.junit.jupiter.api.Assertions.*; + +/** + * Unit tests for the static byte-array comparator that LevelDBIterator and RocksDBIterator + * each define (the two backends are kept independent, so the method is duplicated). The + * comparator is pure and needs no database, so these tests run on every platform. + */ +public class DBIteratorCompareSuite { + + @FunctionalInterface + private interface ByteArrayComparator { + int compare(byte[] a, byte[] b); + } + + @Test + public void testLevelDBIteratorCompare() { + checkUnsignedByteOrdering(LevelDBIterator::compare); + } + + @Test + public void testRocksDBIteratorCompare() { + checkUnsignedByteOrdering(RocksDBIterator::compare); + } + + private static void checkUnsignedByteOrdering(ByteArrayComparator cmp) { + // Equal arrays compare equal. + assertEquals(0, cmp.compare(new byte[] { 1, 2, 3 }, new byte[] { 1, 2, 3 })); + + // Empty arrays compare equal, and an empty array sorts before any non-empty one. + assertEquals(0, cmp.compare(new byte[] {}, new byte[] {})); + assertTrue(cmp.compare(new byte[] {}, new byte[] { 1 }) < 0); + assertTrue(cmp.compare(new byte[] { 1 }, new byte[] {}) > 0); + + // A prefix sorts before the longer array that extends it. + assertTrue(cmp.compare(new byte[] { 1, 2 }, new byte[] { 1, 2, 3 }) < 0); + assertTrue(cmp.compare(new byte[] { 1, 2, 3 }, new byte[] { 1, 2 }) > 0); + + // Bytes must be ordered as unsigned, matching the underlying key ordering: 0x80 (128) is + // greater than 0x7f (127). A signed comparison would wrongly treat 0x80 as -128. + byte[] highBit = new byte[] { (byte) 0x80 }; + byte[] lowBit = new byte[] { 0x7f }; + assertTrue(cmp.compare(highBit, lowBit) > 0); + assertTrue(cmp.compare(lowBit, highBit) < 0); + + // 0xff (255) is the largest byte value when unsigned, so it sorts after 0x00. + assertTrue(cmp.compare(new byte[] { (byte) 0xff }, new byte[] { 0x00 }) > 0); + + // The first differing byte decides the order, regardless of later bytes. + assertTrue(cmp.compare(new byte[] { 0x01, (byte) 0xff }, new byte[] { 0x02, 0x00 }) < 0); + } + +} diff --git a/common/network-common/pom.xml b/common/network-common/pom.xml index 72dc7bef3b5f3..dc542de0a4ef1 100644 --- a/common/network-common/pom.xml +++ b/common/network-common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java index 8bab808ad6864..c4e94847ea0ab 100644 --- a/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java +++ b/common/network-common/src/main/java/org/apache/spark/network/protocol/Encoders.java @@ -20,6 +20,7 @@ import java.io.IOException; import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; +import java.util.Objects; import io.netty.buffer.ByteBuf; import org.roaringbitmap.RoaringBitmap; @@ -41,6 +42,7 @@ public static void encode(ByteBuf buf, String s) { public static String decode(ByteBuf buf) { int length = buf.readInt(); + Objects.checkFromIndexSize(0, length, buf.readableBytes()); byte[] bytes = new byte[length]; buf.readBytes(bytes); return new String(bytes, StandardCharsets.UTF_8); @@ -105,6 +107,7 @@ public static void encode(ByteBuf buf, byte[] arr) { public static byte[] decode(ByteBuf buf) { int length = buf.readInt(); + Objects.checkFromIndexSize(0, length, buf.readableBytes()); byte[] bytes = new byte[length]; buf.readBytes(bytes); return bytes; @@ -130,6 +133,7 @@ public static void encode(ByteBuf buf, String[] strings) { public static String[] decode(ByteBuf buf) { int numStrings = buf.readInt(); + Objects.checkFromIndexSize(0, numStrings, buf.readableBytes() / 4); String[] strings = new String[numStrings]; for (int i = 0; i < strings.length; i ++) { strings[i] = Strings.decode(buf); @@ -153,6 +157,7 @@ public static void encode(ByteBuf buf, int[] ints) { public static int[] decode(ByteBuf buf) { int numInts = buf.readInt(); + Objects.checkFromIndexSize(0, numInts, buf.readableBytes() / 4); int[] ints = new int[numInts]; for (int i = 0; i < ints.length; i ++) { ints[i] = buf.readInt(); @@ -176,6 +181,7 @@ public static void encode(ByteBuf buf, long[] longs) { public static long[] decode(ByteBuf buf) { int numLongs = buf.readInt(); + Objects.checkFromIndexSize(0, numLongs, buf.readableBytes() / 8); long[] longs = new long[numLongs]; for (int i = 0; i < longs.length; i ++) { longs[i] = buf.readLong(); @@ -207,6 +213,9 @@ public static void encode(ByteBuf buf, RoaringBitmap[] bitmaps) { public static RoaringBitmap[] decode(ByteBuf buf) { int numBitmaps = buf.readInt(); + // The divisor 8 is the minimum on-wire size of one element, since an empty RoaringBitmap + // serializes to 8 bytes (a 4-byte cookie followed by a 4-byte size). + Objects.checkFromIndexSize(0, numBitmaps, buf.readableBytes() / 8); RoaringBitmap[] bitmaps = new RoaringBitmap[numBitmaps]; for (int i = 0; i < bitmaps.length; i ++) { bitmaps[i] = Bitmaps.decode(buf); diff --git a/common/network-common/src/test/java/org/apache/spark/network/protocol/EncodersSuite.java b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncodersSuite.java index 127835d29bc01..91c95a3c76e09 100644 --- a/common/network-common/src/test/java/org/apache/spark/network/protocol/EncodersSuite.java +++ b/common/network-common/src/test/java/org/apache/spark/network/protocol/EncodersSuite.java @@ -48,6 +48,28 @@ public void testRoaringBitmapEncodeShouldFailWhenBufferIsSmall() { () -> Encoders.Bitmaps.encode(buf, bitmap)); } + @Test + public void testStringsEncodeDecode() { + String s = "spark"; + ByteBuf buf = Unpooled.buffer(Encoders.Strings.encodedLength(s)); + Encoders.Strings.encode(buf, s); + assertEquals(s, Encoders.Strings.decode(buf)); + } + + @Test + public void testStringsDecodeShouldFailWhenLengthIsNegative() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(-1); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.Strings.decode(buf)); + } + + @Test + public void testStringsDecodeShouldFailWhenLengthExceedsReadableBytes() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(Integer.MAX_VALUE); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.Strings.decode(buf)); + } + @Test public void testBitmapArraysEncodeDecode() { RoaringBitmap[] bitmaps = new RoaringBitmap[] { @@ -66,4 +88,106 @@ public void testBitmapArraysEncodeDecode() { RoaringBitmap[] decodedBitmaps = Encoders.BitmapArrays.decode(buf); assertArrayEquals(bitmaps, decodedBitmaps); } + + @Test + public void testBitmapArraysDecodeShouldFailWhenLengthIsNegative() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(-1); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.BitmapArrays.decode(buf)); + } + + @Test + public void testBitmapArraysDecodeShouldFailWhenLengthExceedsReadableBytes() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(Integer.MAX_VALUE); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.BitmapArrays.decode(buf)); + } + + @Test + public void testByteArraysEncodeDecode() { + byte[] arr = new byte[] { 1, 2, 3, 4, 5 }; + ByteBuf buf = Unpooled.buffer(Encoders.ByteArrays.encodedLength(arr)); + Encoders.ByteArrays.encode(buf, arr); + assertArrayEquals(arr, Encoders.ByteArrays.decode(buf)); + } + + @Test + public void testByteArraysDecodeShouldFailWhenLengthIsNegative() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(-1); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.ByteArrays.decode(buf)); + } + + @Test + public void testByteArraysDecodeShouldFailWhenLengthExceedsReadableBytes() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(Integer.MAX_VALUE); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.ByteArrays.decode(buf)); + } + + @Test + public void testIntArraysEncodeDecode() { + int[] arr = new int[] { 1, 2, 3, 4, 5 }; + ByteBuf buf = Unpooled.buffer(Encoders.IntArrays.encodedLength(arr)); + Encoders.IntArrays.encode(buf, arr); + assertArrayEquals(arr, Encoders.IntArrays.decode(buf)); + } + + @Test + public void testIntArraysDecodeShouldFailWhenLengthIsNegative() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(-1); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.IntArrays.decode(buf)); + } + + @Test + public void testIntArraysDecodeShouldFailWhenLengthExceedsReadableBytes() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(Integer.MAX_VALUE); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.IntArrays.decode(buf)); + } + + @Test + public void testLongArraysEncodeDecode() { + long[] arr = new long[] { 1L, 2L, 3L, 4L, 5L }; + ByteBuf buf = Unpooled.buffer(Encoders.LongArrays.encodedLength(arr)); + Encoders.LongArrays.encode(buf, arr); + assertArrayEquals(arr, Encoders.LongArrays.decode(buf)); + } + + @Test + public void testLongArraysDecodeShouldFailWhenLengthIsNegative() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(-1); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.LongArrays.decode(buf)); + } + + @Test + public void testLongArraysDecodeShouldFailWhenLengthExceedsReadableBytes() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(Integer.MAX_VALUE); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.LongArrays.decode(buf)); + } + + @Test + public void testStringArraysEncodeDecode() { + String[] arr = new String[] { "spark", "", "rocks" }; + ByteBuf buf = Unpooled.buffer(Encoders.StringArrays.encodedLength(arr)); + Encoders.StringArrays.encode(buf, arr); + assertArrayEquals(arr, Encoders.StringArrays.decode(buf)); + } + + @Test + public void testStringArraysDecodeShouldFailWhenLengthIsNegative() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(-1); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.StringArrays.decode(buf)); + } + + @Test + public void testStringArraysDecodeShouldFailWhenLengthExceedsReadableBytes() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(Integer.MAX_VALUE); + assertThrows(IndexOutOfBoundsException.class, () -> Encoders.StringArrays.decode(buf)); + } } diff --git a/common/network-shuffle/pom.xml b/common/network-shuffle/pom.xml index 9b5a916587056..758077d064979 100644 --- a/common/network-shuffle/pom.xml +++ b/common/network-shuffle/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/network-yarn/pom.xml b/common/network-yarn/pom.xml index cf85e7577a759..863392e43329b 100644 --- a/common/network-yarn/pom.xml +++ b/common/network-yarn/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java index b5718946252e1..ab23585170b4a 100644 --- a/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java +++ b/common/network-yarn/src/main/java/org/apache/spark/network/yarn/YarnShuffleService.java @@ -433,7 +433,8 @@ public void initializeApplication(ApplicationInitializationContext context) { MDC.of(LogKeys.APP_ID, appId)); } } catch (IOException ioe) { - logger.warn("Unable to parse application data for service: " + payload); + logger.warn("Unable to parse the application data for application {}", + MDC.of(LogKeys.APP_ID, appId)); metaInfo = null; } if (isAuthenticationEnabled()) { diff --git a/common/sketch/pom.xml b/common/sketch/pom.xml index d5d9a986d664d..b22d89af9e99f 100644 --- a/common/sketch/pom.xml +++ b/common/sketch/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/tags/pom.xml b/common/tags/pom.xml index b49d6baa14607..7ba63816225a9 100644 --- a/common/tags/pom.xml +++ b/common/tags/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/unsafe/pom.xml b/common/unsafe/pom.xml index 55bb994fa9b15..68784dce25a44 100644 --- a/common/unsafe/pom.xml +++ b/common/unsafe/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java index 2b9457c58560f..c9fee02125fd8 100644 --- a/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationAwareUTF8String.java @@ -1519,7 +1519,7 @@ public static UTF8String trimRight( if (charIndex == src.length()) { return srcString; } - if (lastNonSpacePosition == srcString.numChars()) { + if (lastNonSpacePosition == src.length()) { return UTF8String.fromString(src.substring(0, charIndex)); } return UTF8String.fromString( diff --git a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java index 96b103ae33881..ac41812bc6d35 100644 --- a/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java +++ b/common/unsafe/src/main/java/org/apache/spark/unsafe/types/UTF8String.java @@ -724,7 +724,8 @@ public int getChar(int charIndex) { /** * Returns the code point starting from the byte at position `byteIndex`. - * If byte index is invalid, throws exception. + * If byte index is invalid, throws exception. If the sequence is truncated (the leader byte + * declares more bytes than remain), the missing continuation bytes are treated as 0. */ public int codePointFrom(int byteIndex) { if (byteIndex < 0 || byteIndex >= numBytes) { @@ -736,18 +737,28 @@ public int codePointFrom(int byteIndex) { case 1 -> b & 0x7F; case 2 -> - ((b & 0x1F) << 6) | (getByte(byteIndex + 1) & 0x3F); + ((b & 0x1F) << 6) | continuationByte(byteIndex + 1); case 3 -> - ((b & 0x0F) << 12) | ((getByte(byteIndex + 1) & 0x3F) << 6) | - (getByte(byteIndex + 2) & 0x3F); + ((b & 0x0F) << 12) | (continuationByte(byteIndex + 1) << 6) | + continuationByte(byteIndex + 2); case 4 -> - ((b & 0x07) << 18) | ((getByte(byteIndex + 1) & 0x3F) << 12) | - ((getByte(byteIndex + 2) & 0x3F) << 6) | (getByte(byteIndex + 3) & 0x3F); + ((b & 0x07) << 18) | (continuationByte(byteIndex + 1) << 12) | + (continuationByte(byteIndex + 2) << 6) | continuationByte(byteIndex + 3); default -> throw new IllegalStateException("Error in UTF-8 code point"); }; } + /** + * Returns the low 6 bits of the UTF-8 continuation byte at `byteIndex`, or 0 when `byteIndex` + * is past the end of the string. The bounds check stops a truncated trailing multi-byte + * sequence (a leader byte whose declared width exceeds the bytes that remain) from reading + * past the end of the backing memory. + */ + private int continuationByte(int byteIndex) { + return byteIndex < numBytes ? getByte(byteIndex) & 0x3F : 0; + } + public boolean matchAt(final UTF8String s, int pos) { if (s.numBytes + pos > numBytes || pos < 0) { return false; @@ -944,7 +955,10 @@ public int findInSet(UTF8String match) { * @return a new UTF8String in the position of [start, end] of current UTF8String bytes. */ public UTF8String copyUTF8String(int start, int end) { - int len = end - start + 1; + // Clamp to the bytes that actually remain so an out-of-range `end` (for example, derived + // from a truncated trailing multi-byte sequence) can't copy past the end of the backing + // memory. + int len = Math.min(end - start + 1, numBytes - start); byte[] newBytes = new byte[len]; copyMemory(base, offset + start, newBytes, BYTE_ARRAY_OFFSET, len); return UTF8String.fromBytes(newBytes); @@ -1137,7 +1151,9 @@ public UTF8String trimRight(UTF8String trimString) { stringCharPos[numChars - 1], stringCharPos[numChars - 1] + stringCharLen[numChars - 1] - 1); if (trimString.find(searchChar, 0) >= 0) { - trimEnd -= stringCharLen[numChars - 1]; + // Advance by the bytes the character actually occupies. A truncated trailing leader is + // shorter than the width its leader byte declares, so use the (clamped) search char. + trimEnd -= searchChar.numBytes; } else { break; } @@ -1160,7 +1176,7 @@ public UTF8String reverse() { int i = 0; // position in byte while (i < numBytes) { - int len = Math.min(numBytesForFirstByte(getByte(i)), numBytes); + int len = Math.min(numBytesForFirstByte(getByte(i)), numBytes - i); int targetOffset = Math.max(result.length - i - len, 0); copyMemory(this.base, this.offset + i, result, BYTE_ARRAY_OFFSET + targetOffset, len); diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java index 1db163c1c822d..6372d7e4663c1 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/CollationSupportSuite.java @@ -3647,6 +3647,32 @@ public void testStringTrimRight() throws SparkException { assertStringTrimRight(UTF8_LCASE, "𝔸", "a", "𝔸"); assertStringTrimRight(UNICODE, "𝔸", "a", "𝔸"); assertStringTrimRight(UNICODE_CI, "𝔸", "a", ""); + // RTRIM-modifier collations ignore trailing spaces while matching the trim characters, then + // re-append them. The behaviour must agree across the UTF8_BINARY (binaryTrimRight), + // UTF8_LCASE (lowercaseTrimRight), and ICU (trimRight) paths. The supplementary-character + // cases below (trailing-space count == supplementary code-point count) regressed on the ICU + // path before SPARK-57506, which compared a Java-char index against a code-point count. + assertStringTrimRight("UTF8_BINARY_RTRIM", "x ", "x", " "); + assertStringTrimRight("UTF8_LCASE_RTRIM", "x ", "x", " "); + assertStringTrimRight("UNICODE_RTRIM", "x ", "x", " "); + assertStringTrimRight("UTF8_BINARY_RTRIM", " ", "x", " "); + assertStringTrimRight("UTF8_LCASE_RTRIM", " ", "x", " "); + assertStringTrimRight("UNICODE_RTRIM", " ", "x", " "); + assertStringTrimRight("UTF8_BINARY_RTRIM", "𝔸 ", "𝔸", " "); + assertStringTrimRight("UTF8_LCASE_RTRIM", "𝔸 ", "𝔸", " "); + assertStringTrimRight("UNICODE_RTRIM", "𝔸 ", "𝔸", " "); + assertStringTrimRight("UTF8_BINARY_RTRIM", "𝔸 ", "𝔸", " "); + assertStringTrimRight("UTF8_LCASE_RTRIM", "𝔸 ", "𝔸", " "); + assertStringTrimRight("UNICODE_RTRIM", "𝔸 ", "𝔸", " "); + assertStringTrimRight("UTF8_BINARY_RTRIM", "𝔸𝔸 ", "𝔸", " "); + assertStringTrimRight("UTF8_LCASE_RTRIM", "𝔸𝔸 ", "𝔸", " "); + assertStringTrimRight("UNICODE_RTRIM", "𝔸𝔸 ", "𝔸", " "); + // Case-folding interacts with space preservation per path: only UTF8_LCASE folds B to b, so + // only it trims the trailing 'B' and re-appends the space; binary and (case-sensitive) ICU + // leave the input unchanged. This exercises the lcase space-preservation branch on its own. + assertStringTrimRight("UTF8_BINARY_RTRIM", "xB ", "b", "xB "); + assertStringTrimRight("UTF8_LCASE_RTRIM", "xB ", "b", "x "); + assertStringTrimRight("UNICODE_RTRIM", "xB ", "b", "xB "); } /** diff --git a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java index 26b96155377e8..0374a1672d22b 100644 --- a/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java +++ b/common/unsafe/src/test/java/org/apache/spark/unsafe/types/UTF8StringSuite.java @@ -325,6 +325,30 @@ public void reverse() { assertEquals(EMPTY_UTF8, EMPTY_UTF8.reverse()); assertEquals(fromString("者行孙"), fromString("孙行者").reverse()); assertEquals(fromString("者行孙 olleh"), fromString("hello 孙行者").reverse()); + // Malformed UTF-8: a truncated trailing multi-byte sequence must be reversed as orphan + // bytes without reading past the end of the string. The backing arrays carry an extra + // trailing byte so a regression that over-reads would produce a deterministically wrong + // result rather than reading uninitialized memory. + // 'A' followed by an incomplete 2-byte leader (0xCE). + byte[] truncated2 = new byte[]{0x41, (byte) 0xCE, 0x42}; + assertEquals( + fromBytes(new byte[]{(byte) 0xCE, 0x41}), + fromBytes(truncated2, 0, 2).reverse()); + // 'A' followed by an incomplete 3-byte leader (0xE4 0xB8). + byte[] truncated3 = new byte[]{0x41, (byte) 0xE4, (byte) 0xB8, 0x42}; + assertEquals( + fromBytes(new byte[]{(byte) 0xE4, (byte) 0xB8, 0x41}), + fromBytes(truncated3, 0, 3).reverse()); + // 'A' followed by an incomplete 4-byte leader (0xF0 0x90). + byte[] truncated4 = new byte[]{0x41, (byte) 0xF0, (byte) 0x90, 0x42}; + assertEquals( + fromBytes(new byte[]{(byte) 0xF0, (byte) 0x90, 0x41}), + fromBytes(truncated4, 0, 3).reverse()); + // A complete 3-byte character (U+4E16) followed by an incomplete 2-byte leader (0xCE). + byte[] truncatedMid = new byte[]{(byte) 0xE4, (byte) 0xB8, (byte) 0x96, (byte) 0xCE, 0x42}; + assertEquals( + fromBytes(new byte[]{(byte) 0xCE, (byte) 0xE4, (byte) 0xB8, (byte) 0x96}), + fromBytes(truncatedMid, 0, 4).reverse()); } @Test @@ -1204,6 +1228,61 @@ public void testCodePointFrom() { assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(-1)); assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length())); assertThrows(IndexOutOfBoundsException.class, () -> s.codePointFrom(str.length() + 1)); + + // Truncated trailing multi-byte sequence: the leader declares more bytes than remain. + // codePointFrom should decode only the bytes present (missing continuation bytes count as + // 0) and not read past the end. Each backing array has extra trailing bytes, so an + // over-read regression would show up in the result. + // 2-byte leader 0xCE with no continuation byte present. + assertEquals((0xCE & 0x1F) << 6, + fromBytes(new byte[] {(byte) 0xCE, 0x42}, 0, 1).codePointFrom(0)); + // 3-byte leader 0xE4 with no continuation bytes present. + assertEquals((0xE4 & 0x0F) << 12, + fromBytes(new byte[] {(byte) 0xE4, 0x42, 0x43}, 0, 1).codePointFrom(0)); + // 3-byte leader 0xE4 0xB8 with the final continuation byte missing. + assertEquals(((0xE4 & 0x0F) << 12) | ((0xB8 & 0x3F) << 6), + fromBytes(new byte[] {(byte) 0xE4, (byte) 0xB8, 0x42}, 0, 2).codePointFrom(0)); + // 4-byte leader 0xF1 with no continuation bytes present. + assertEquals((0xF1 & 0x07) << 18, + fromBytes(new byte[] {(byte) 0xF1, 0x42, 0x43, 0x44}, 0, 1).codePointFrom(0)); + // 4-byte leader 0xF1 with two continuation bytes present and only the last one missing, + // so just the final read crosses the end. + assertEquals(((0xF1 & 0x07) << 18) | ((0x9F & 0x3F) << 12) | ((0x8F & 0x3F) << 6), + fromBytes(new byte[] {(byte) 0xF1, (byte) 0x9F, (byte) 0x8F, 0x42}, 0, 3).codePointFrom(0)); + } + + @Test + public void copyUTF8StringClampsToRemainingBytes() { + // Here `end` runs one byte past the string, as it would for a truncated trailing sequence. + // copyUTF8String should clamp to the available bytes; the extra backing byte would show up + // in the result if it over-read. + byte[] backing = new byte[] {0x41, 0x42, 0x43}; + UTF8String s = fromBytes(backing, 0, 2); // views "AB" + // `end` (2) is one past the last valid byte index (1); only the two real bytes are copied. + assertEquals(fromString("AB"), s.copyUTF8String(0, 2)); + // Same with a non-zero start, so the clamp uses `numBytes - start`, not `numBytes`. + assertEquals(fromString("B"), s.copyUTF8String(1, 2)); + // In-bounds copies are unaffected. + assertEquals(fromString("AB"), s.copyUTF8String(0, 1)); + assertEquals(fromString("B"), s.copyUTF8String(1, 1)); + } + + @Test + public void trimTruncatedTrailingSequence() { + // trimLeft/trimRight build the search character through copyUTF8String, so an over-read would + // make it longer than the bytes that remain. The backing arrays carry an extra trailing byte + // to make any over-read deterministic. + // A lone truncated 2-byte leader (0xC2): the clamped search char is just the leader, which + // matches the 1-byte trim set and is trimmed away. + UTF8String lone = fromBytes(new byte[] {(byte) 0xC2, 0x42}, 0, 1); + UTF8String trim2 = fromBytes(new byte[] {(byte) 0xC2}); + assertEquals(EMPTY_UTF8, lone.trimLeft(trim2)); + assertEquals(EMPTY_UTF8, lone.trimRight(trim2)); + // 'A' followed by a truncated 3-byte leader (0xE4). Trimming the leader from the right must + // keep 'A': the trailing character occupies only one real byte, so only that byte is removed. + UTF8String prefixed = fromBytes(new byte[] {0x41, (byte) 0xE4, 0x42}, 0, 2); + UTF8String trim3 = fromBytes(new byte[] {(byte) 0xE4}); + assertEquals(fromString("A"), prefixed.trimRight(trim3)); } @Test diff --git a/common/utils-java/pom.xml b/common/utils-java/pom.xml index 433bffd7e405e..d536bea3d44ac 100644 --- a/common/utils-java/pom.xml +++ b/common/utils-java/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/utils/pom.xml b/common/utils/pom.xml index 296c30a6d25f7..f26599a5fbe43 100644 --- a/common/utils/pom.xml +++ b/common/utils/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/utils/src/main/resources/error/error-conditions.json b/common/utils/src/main/resources/error/error-conditions.json index 926019df1e74f..f86dbf6d61ec3 100644 --- a/common/utils/src/main/resources/error/error-conditions.json +++ b/common/utils/src/main/resources/error/error-conditions.json @@ -2931,12 +2931,6 @@ "" ] }, - "COLUMN_ID_MISMATCH" : { - "message" : [ - "Column IDs have changed:", - "" - ] - }, "METADATA_COLUMNS_MISMATCH" : { "message" : [ "Metadata columns have changed:", diff --git a/common/variant/pom.xml b/common/variant/pom.xml index 2ddd78eb7f17d..45f8889f5a3d1 100644 --- a/common/variant/pom.xml +++ b/common/variant/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java index ac93246991c0e..24f8f0fafc404 100644 --- a/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java +++ b/common/variant/src/main/java/org/apache/spark/types/variant/VariantUtil.java @@ -26,6 +26,7 @@ import java.math.BigInteger; import java.nio.ByteBuffer; import java.nio.ByteOrder; +import java.nio.charset.StandardCharsets; import java.util.Arrays; import java.util.UUID; @@ -507,7 +508,10 @@ public static String getString(byte[] value, int pos) { length = readUnsigned(value, pos + 1, U32_SIZE); } checkIndex(start + length - 1, value.length); - return new String(value, start, length); + // The string content is UTF-8 encoded (it is written by `VariantBuilder.appendString`). + // Decode with UTF-8 explicitly rather than relying on the JVM default charset, which is + // platform-dependent on Java 17. + return new String(value, start, length, StandardCharsets.UTF_8); } throw unexpectedType(Type.STRING); } @@ -693,6 +697,8 @@ public static String getMetadataKey(byte[] metadata, int id) { int nextOffset = readUnsigned(metadata, 1 + (id + 2) * offsetSize, offsetSize); if (offset > nextOffset) throw malformedVariant(); checkIndex(stringStart + nextOffset - 1, metadata.length); - return new String(metadata, stringStart + offset, nextOffset - offset); + // Dictionary keys are UTF-8 encoded (see `VariantBuilder.addKey`). Decode with UTF-8 + // explicitly rather than relying on the platform-dependent JVM default charset. + return new String(metadata, stringStart + offset, nextOffset - offset, StandardCharsets.UTF_8); } } diff --git a/common/variant/src/test/scala/org/apache/spark/types/variant/VariantUtf8DecodeSuite.scala b/common/variant/src/test/scala/org/apache/spark/types/variant/VariantUtf8DecodeSuite.scala new file mode 100644 index 0000000000000..e196209c82fa7 --- /dev/null +++ b/common/variant/src/test/scala/org/apache/spark/types/variant/VariantUtf8DecodeSuite.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.types.variant + +import java.nio.charset.StandardCharsets +import java.nio.file.Paths + +import org.scalatest.funsuite.AnyFunSuite // scalastyle:ignore funsuite + +class VariantUtf8DecodeSuite extends AnyFunSuite { // scalastyle:ignore funsuite + + test("SPARK-57599: keys and string values decode as UTF-8 under a non-UTF-8 default charset") { + // The reader must decode object keys and string values as UTF-8 regardless of the JVM default + // charset. That can only be exercised by changing the default charset, which is fixed at JVM + // startup and pinned to UTF-8 in the test JVM, so fork a child JVM with a non-UTF-8 + // -Dfile.encoding and round-trip a variant with non-ASCII content there (see + // VariantUtf8DecodeChecker). With the pre-fix default-charset decode the characters are + // corrupted and the child exits non-zero. + val javaExe = Paths.get(sys.props("java.home"), "bin", "java").toString + val command = Seq( + javaExe, + "-Dfile.encoding=ISO-8859-1", + "-cp", sys.props("java.class.path"), + VariantUtf8DecodeChecker.getClass.getName.stripSuffix("$")) + val process = new ProcessBuilder(command: _*).redirectErrorStream(true).start() + val output = new String(process.getInputStream.readAllBytes(), StandardCharsets.UTF_8).trim + val exitCode = process.waitFor() + if (output.contains("RESULT=INCONCLUSIVE")) { + // The child JVM ended up with a UTF-8 default charset, so it cannot exercise the fix. + assume(false, s"Inconclusive; child output: $output") + } + assert(exitCode === 0, s"child JVM reported a decode failure; output: $output") + } +} + +/** + * Entry point run in a child JVM by `VariantUtf8DecodeSuite` to verify that the Variant reader + * decodes object keys and string values as UTF-8 independently of the JVM default charset. + * + * This has to run as a separate `main`: the only way to actually exercise the bug is to start a JVM + * with a non-UTF-8 `-Dfile.encoding`, because the default charset is fixed at JVM startup and is + * pinned to UTF-8 in the test JVM. With the pre-fix default-charset decode, the non-ASCII content + * round-tripped here is corrupted and the process exits non-zero. + * + * All output is ASCII (characters are reported as code points) so it is readable regardless of the + * child's default charset. The process prints `RESULT=OK` and exits 0 on success, `RESULT=FAIL` and + * exits 1 if any value was corrupted, or `RESULT=INCONCLUSIVE` and exits 0 if the child's default + * charset turned out to be UTF-8 (in which case the fixed and buggy code cannot be told apart). + */ +private[variant] object VariantUtf8DecodeChecker { + + // This object is a standalone process whose stdout is its result protocol, so println is used + // deliberately as the communication channel with the parent test. + // scalastyle:off println + + // scalastyle:off nonascii + // (key, value) pairs covering 2-byte, 3-byte, and LONG_STR (> 63 UTF-8 bytes) UTF-8 content. + private val cases: Seq[(String, String)] = Seq( + "café" -> "résumé", + "你好" -> "世界", + "é" -> ("你" * 22)) // value is 66 UTF-8 bytes, over MAX_SHORT_STR_SIZE -> LONG_STR encoding + // scalastyle:on nonascii + + def main(args: Array[String]): Unit = { + // Read the startup `file.encoding`, which determines the JVM default charset that the buggy + // `new String(bytes)` decode used. (Querying the default charset directly is banned by + // scalastyle.) + val fileEncoding = sys.props.getOrElse("file.encoding", "") + if (fileEncoding.equalsIgnoreCase("UTF-8")) { + // The fixed and buggy code are indistinguishable when the default charset is already UTF-8. + println(s"RESULT=INCONCLUSIVE fileEncoding=$fileEncoding") + System.exit(0) + } + var failures = 0 + cases.foreach { case (key, value) => + val json = "{" + jsonString(key) + ":" + jsonString(value) + "}" + val field = VariantBuilder.parseJson(json, false).getFieldAtIndex(0) + if (field.key != key) { + println(s"KEY_MISMATCH expected=${codePoints(key)} actual=${codePoints(field.key)}") + failures += 1 + } + val decoded = field.value.getString + if (decoded != value) { + println(s"VALUE_MISMATCH expected=${codePoints(value)} actual=${codePoints(decoded)}") + failures += 1 + } + } + if (failures == 0) { + println(s"RESULT=OK fileEncoding=$fileEncoding") + System.exit(0) + } else { + println(s"RESULT=FAIL failures=$failures fileEncoding=$fileEncoding") + System.exit(1) + } + } + + private def jsonString(s: String): String = "\"" + s + "\"" + + private def codePoints(s: String): String = + s.codePoints().toArray.map(c => "U+%04X".format(c)).mkString("[", ",", "]") + + // scalastyle:on println +} diff --git a/connector/avro/pom.xml b/connector/avro/pom.xml index b4f5cd72f551d..30a4c876cfcda 100644 --- a/connector/avro/pom.xml +++ b/connector/avro/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala index 56b107c14f57f..376d12a8a4923 100644 --- a/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala +++ b/connector/avro/src/test/scala/org/apache/spark/sql/avro/AvroSuite.scala @@ -3191,6 +3191,181 @@ abstract class AvroSuite } } + test("TIME type read/write with Avro format") { + withTempPath { dir => + // Test boundary values and NULL handling + val df = spark.sql(""" + SELECT + TIME'00:00:00.123456' as midnight, + TIME'12:34:56.789012' as noon, + TIME'23:59:59.999999' as max_time, + CAST(NULL AS TIME) as null_time + """) + + df.write.format("avro").save(dir.toString) + val readDf = spark.read.format("avro").load(dir.toString) + + checkAnswer(readDf, df) + + // Verify schema - all should be default TimeType(6) + readDf.schema.fields.foreach { field => + assert(field.dataType == TimeType(), s"Field ${field.name} should be TimeType") + } + + // Verify boundary values + val row = readDf.collect()(0) + assert(row.getAs[java.time.LocalTime]("midnight") == + java.time.LocalTime.of(0, 0, 0, 123456000)) + assert(row.getAs[java.time.LocalTime]("noon") == + java.time.LocalTime.of(12, 34, 56, 789012000)) + assert(row.getAs[java.time.LocalTime]("max_time") == + java.time.LocalTime.of(23, 59, 59, 999999000)) + assert(row.get(3) == null, "NULL time should be preserved") + } + } + + test("TIME type in nested structures in Avro") { + withTempPath { dir => + // Test TIME type in arrays and structs with different precisions + val df = spark.sql(""" + SELECT + named_struct('start', CAST(TIME'09:00:00.123' AS TIME(3)), + 'end', CAST(TIME'17:30:45.654321' AS TIME(6))) as schedule, + array(TIME'08:15:30.111222', TIME'12:45:15.333444', TIME'16:20:50.555666') as checkpoints + """) + + df.write.format("avro").save(dir.toString) + val readDf = spark.read.format("avro").load(dir.toString) + + checkAnswer(readDf, df) + } + } + + test("TIME type precision metadata is preserved in Avro") { + withTempPath { dir => + // Test all TIME precisions (0-6) with multiple columns + val df = spark.sql(""" + SELECT + id, + CAST(TIME '12:34:56' AS TIME(0)) as time_p0, + CAST(TIME '12:34:56.1' AS TIME(1)) as time_p1, + CAST(TIME '12:34:56.12' AS TIME(2)) as time_p2, + CAST(TIME '12:34:56.123' AS TIME(3)) as time_p3, + CAST(TIME '12:34:56.1234' AS TIME(4)) as time_p4, + CAST(TIME '12:34:56.12345' AS TIME(5)) as time_p5, + CAST(TIME '12:34:56.123456' AS TIME(6)) as time_p6, + description + FROM VALUES + (1, 'Morning'), + (2, 'Evening') + AS t(id, description) + """) + + // Verify original schema has all precisions + (0 to 6).foreach { p => + assert(df.schema(s"time_p$p").dataType == TimeType(p)) + } + + // Write to Avro and read back + df.write.format("avro").save(dir.toString) + val readDf = spark.read.format("avro").load(dir.toString) + + // Verify ALL precisions are preserved after round-trip + (0 to 6).foreach { p => + assert(readDf.schema(s"time_p$p").dataType == TimeType(p), + s"Precision $p should be preserved") + } + + // Verify data integrity + checkAnswer(readDf, df) + } + } + + test("SPARK-57581: TIME is written as unit-correct time-micros for external readers") { + // Expected microseconds-since-midnight for TIME'12:34:56.123456' truncated to each precision. + val baseSeconds = (12 * 3600 + 34 * 60 + 56).toLong + val expectedMicros = Map( + 0 -> (baseSeconds * 1000000L + 0L), + 1 -> (baseSeconds * 1000000L + 100000L), + 2 -> (baseSeconds * 1000000L + 120000L), + 3 -> (baseSeconds * 1000000L + 123000L), + 4 -> (baseSeconds * 1000000L + 123400L), + 5 -> (baseSeconds * 1000000L + 123450L), + 6 -> (baseSeconds * 1000000L + 123456L)) + // Valid micros-of-day range; values mislabeled as micros but holding nanos would exceed this. + val microsPerDay = 24L * 3600L * 1000000L + + (0 to 6).foreach { p => + withTempPath { dir => + spark.sql(s"SELECT CAST(TIME'12:34:56.123456' AS TIME($p)) as t") + .write.format("avro").save(dir.toString) + + val avroFile = dir.listFiles() + .filter(f => f.isFile && f.getName.endsWith("avro")) + .head + val reader = new DataFileReader[GenericRecord]( + avroFile, new GenericDatumReader[GenericRecord]()) + try { + // The Avro field must be annotated with the time-micros logical type. + val fieldSchema = reader.getSchema.getField("t").schema() + val timeSchema = if (fieldSchema.getType == Type.UNION) { + fieldSchema.getTypes.asScala.find(_.getType == Type.LONG).get + } else { + fieldSchema + } + assert(timeSchema.getLogicalType.getName == "time-micros", + s"precision $p should be written as time-micros") + + assert(reader.hasNext) + val record = reader.next() + val stored = record.get("t").asInstanceOf[Long] + assert(stored == expectedMicros(p), + s"precision $p should store micros-of-day ${expectedMicros(p)}, but was $stored") + assert(stored >= 0 && stored < microsPerDay, + s"precision $p stored value $stored is outside the valid micros-of-day range") + } finally { + reader.close() + } + } + } + } + + test("SPARK-57581: TIME read from a plain time-micros Avro file (no catalyst prop)") { + withTempDir { dir => + // Build an Avro file the way an external tool (Hive/Trino/fastavro) would: a `time-micros` + // long with no `spark.sql.catalyst.type` property. Spark must read it back as TIME, + // converting the stored microseconds-since-midnight to its internal nanoseconds and + // defaulting to the micros precision TIME(6). This pins the deserializer's micros -> nanos + // conversion independently of the write path. + val micros = (12L * 3600 + 34 * 60 + 56) * 1000000L + 123456L + val avroSchema = new Schema.Parser().parse( + """ + |{ + | "type": "record", + | "name": "top", + | "fields": [ + | {"name": "t", "type": {"type": "long", "logicalType": "time-micros"}} + | ] + |} + """.stripMargin) + val avroFile = new File(dir, "external.avro") + val datumWriter = new GenericDatumWriter[GenericRecord](avroSchema) + val dataFileWriter = new DataFileWriter[GenericRecord](datumWriter) + dataFileWriter.create(avroSchema, avroFile) + try { + val record = new GenericData.Record(avroSchema) + record.put("t", micros) + dataFileWriter.append(record) + } finally { + dataFileWriter.close() + } + + val readDf = spark.read.format("avro").load(dir.toString) + assert(readDf.schema("t").dataType == TimeType(TimeType.MICROS_PRECISION)) + checkAnswer(readDf, Row(java.time.LocalTime.of(12, 34, 56, 123456000))) + } + } + } class AvroV1Suite extends AvroSuite { @@ -3392,96 +3567,6 @@ class AvroV2Suite extends AvroSuite with ExplainSuiteHelper { } } - test("TIME type read/write with Avro format") { - withTempPath { dir => - // Test boundary values and NULL handling - val df = spark.sql(""" - SELECT - TIME'00:00:00.123456' as midnight, - TIME'12:34:56.789012' as noon, - TIME'23:59:59.999999' as max_time, - CAST(NULL AS TIME) as null_time - """) - - df.write.format("avro").save(dir.toString) - val readDf = spark.read.format("avro").load(dir.toString) - - checkAnswer(readDf, df) - - // Verify schema - all should be default TimeType(6) - readDf.schema.fields.foreach { field => - assert(field.dataType == TimeType(), s"Field ${field.name} should be TimeType") - } - - // Verify boundary values - val row = readDf.collect()(0) - assert(row.getAs[java.time.LocalTime]("midnight") == - java.time.LocalTime.of(0, 0, 0, 123456000)) - assert(row.getAs[java.time.LocalTime]("noon") == - java.time.LocalTime.of(12, 34, 56, 789012000)) - assert(row.getAs[java.time.LocalTime]("max_time") == - java.time.LocalTime.of(23, 59, 59, 999999000)) - assert(row.get(3) == null, "NULL time should be preserved") - } - } - - test("TIME type in nested structures in Avro") { - withTempPath { dir => - // Test TIME type in arrays and structs with different precisions - val df = spark.sql(""" - SELECT - named_struct('start', CAST(TIME'09:00:00.123' AS TIME(3)), - 'end', CAST(TIME'17:30:45.654321' AS TIME(6))) as schedule, - array(TIME'08:15:30.111222', TIME'12:45:15.333444', TIME'16:20:50.555666') as checkpoints - """) - - df.write.format("avro").save(dir.toString) - val readDf = spark.read.format("avro").load(dir.toString) - - checkAnswer(readDf, df) - } - } - - test("TIME type precision metadata is preserved in Avro") { - withTempPath { dir => - // Test all TIME precisions (0-6) with multiple columns - val df = spark.sql(""" - SELECT - id, - CAST(TIME '12:34:56' AS TIME(0)) as time_p0, - CAST(TIME '12:34:56.1' AS TIME(1)) as time_p1, - CAST(TIME '12:34:56.12' AS TIME(2)) as time_p2, - CAST(TIME '12:34:56.123' AS TIME(3)) as time_p3, - CAST(TIME '12:34:56.1234' AS TIME(4)) as time_p4, - CAST(TIME '12:34:56.12345' AS TIME(5)) as time_p5, - CAST(TIME '12:34:56.123456' AS TIME(6)) as time_p6, - description - FROM VALUES - (1, 'Morning'), - (2, 'Evening') - AS t(id, description) - """) - - // Verify original schema has all precisions - (0 to 6).foreach { p => - assert(df.schema(s"time_p$p").dataType == TimeType(p)) - } - - // Write to Avro and read back - df.write.format("avro").save(dir.toString) - val readDf = spark.read.format("avro").load(dir.toString) - - // Verify ALL precisions are preserved after round-trip - (0 to 6).foreach { p => - assert(readDf.schema(s"time_p$p").dataType == TimeType(p), - s"Precision $p should be preserved") - } - - // Verify data integrity - checkAnswer(readDf, df) - } - } - test("SPARK-56457: Avro V2 formatName matches V1 FileFormat.toString") { val v2Provider = DataSource.lookupDataSourceV2("avro", spark.sessionState.conf) assert(v2Provider.isDefined) diff --git a/connector/docker-integration-tests/pom.xml b/connector/docker-integration-tests/pom.xml index 66022cab7c77b..d468c758f5873 100644 --- a/connector/docker-integration-tests/pom.xml +++ b/connector/docker-integration-tests/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala index baed9f5c7a5e0..73a8ad4b7d9da 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/OracleDatabaseOnDocker.scala @@ -23,7 +23,7 @@ class OracleDatabaseOnDocker extends DatabaseOnDocker with Logging { // sarutak/oracle-free is a custom fork of gvenzl/oracle-free which allows to set timeout for // password initialization. See SPARK-54076 for details. lazy override val imageName = - sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "sarutak/oracle-free:23.26.1-slim") + sys.env.getOrElse("ORACLE_DOCKER_IMAGE_NAME", "sarutak/oracle-free:23.26.2-slim") val oracle_password = "Th1s1sThe0racle#Pass" override val env = Map( "ORACLE_PWD" -> oracle_password, // oracle images uses this diff --git a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala index c71f9ae7688f3..594819689e6f2 100644 --- a/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala +++ b/connector/docker-integration-tests/src/test/scala/org/apache/spark/sql/jdbc/v2/OracleIntegrationSuite.scala @@ -340,5 +340,13 @@ class OracleIntegrationSuite extends DockerJDBCIntegrationV2Suite with V2JDBCTes assert(rows10.length === 2) assert(rows10(0).getString(0) === "amy") assert(rows10(1).getString(0) === "alex") + + // SPARK-57364: verify MONTH truncation pushes down correctly (not as 'IW') + val df11 = sql(s"SELECT name FROM $tbl WHERE trunc(date1, 'MONTH') = date'2022-05-01'") + checkFilterPushed(df11) + val rows11 = df11.collect() + assert(rows11.length === 2) + assert(rows11(0).getString(0) === "amy") + assert(rows11(1).getString(0) === "alex") } } diff --git a/connector/kafka-0-10-assembly/pom.xml b/connector/kafka-0-10-assembly/pom.xml index b86c94f3e35af..3c0358c2b704a 100644 --- a/connector/kafka-0-10-assembly/pom.xml +++ b/connector/kafka-0-10-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/kafka-0-10-sql/pom.xml b/connector/kafka-0-10-sql/pom.xml index 4980e94c45776..649deaebce221 100644 --- a/connector/kafka-0-10-sql/pom.xml +++ b/connector/kafka-0-10-sql/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala index 828891f0b4983..17323165b451f 100644 --- a/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala +++ b/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchStream.scala @@ -357,7 +357,7 @@ private[kafka010] class KafkaMicroBatchStream( } } } else { - Some(latestPartitionOffsets) + Option(latestPartitionOffsets) } KafkaMicroBatchStream.metrics(latestConsumedOffset, reCalculatedLatestPartitionOffsets) diff --git a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala index e6a794c22e1f2..274be623da6fa 100644 --- a/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala +++ b/connector/kafka-0-10-sql/src/test/scala/org/apache/spark/sql/kafka010/KafkaMicroBatchSourceSuite.scala @@ -1883,6 +1883,32 @@ abstract class KafkaMicroBatchV2SourceSuite extends KafkaMicroBatchSourceSuiteBa // test null latestAvailablePartitionOffsets assert(KafkaMicroBatchStream.metrics(Optional.ofNullable(offset), None).isEmpty) } + + test("SPARK-57438: metrics should not NPE when latestPartitionOffsets is null") { + // Construct a KafkaMicroBatchStream instance without calling latestOffset(), + // so latestPartitionOffsets remains null (its default uninitialized value). + // Calling metrics() on this instance exercises the real non-RTM code path. + val topic = newTopic() + val tp = new TopicPartition(topic, 0) + + SparkSession.setActiveSession(spark) + withTempDir { dir => + val provider = new KafkaSourceProvider() + val options = Map( + "kafka.bootstrap.servers" -> testUtils.brokerAddress, + "subscribe" -> topic + ) + val dsOptions = new CaseInsensitiveStringMap(options.asJava) + val table = provider.getTable(dsOptions) + val stream = table.newScanBuilder(dsOptions).build().toMicroBatchStream(dir.getAbsolutePath) + .asInstanceOf[KafkaMicroBatchStream] + + // latestPartitionOffsets is still null - metrics() must not NPE + val offset = KafkaSourceOffset(Map(tp -> 0L)) + val result = stream.metrics(Optional.of(offset)) + assert(result.isEmpty) + } + } } class KafkaMicroBatchV1SourceWithAdminSuite extends KafkaMicroBatchV1SourceSuite { diff --git a/connector/kafka-0-10-token-provider/pom.xml b/connector/kafka-0-10-token-provider/pom.xml index 5c471db25becb..7fdc489b24cc3 100644 --- a/connector/kafka-0-10-token-provider/pom.xml +++ b/connector/kafka-0-10-token-provider/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/kafka-0-10/pom.xml b/connector/kafka-0-10/pom.xml index a7b5b06a6ff58..ce0feb19a0d1d 100644 --- a/connector/kafka-0-10/pom.xml +++ b/connector/kafka-0-10/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/kinesis-asl-assembly/pom.xml b/connector/kinesis-asl-assembly/pom.xml index c73a0015c416e..e90db9c8e3521 100644 --- a/connector/kinesis-asl-assembly/pom.xml +++ b/connector/kinesis-asl-assembly/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/kinesis-asl/pom.xml b/connector/kinesis-asl/pom.xml index c24bd4886e770..27efdf4c9eeea 100644 --- a/connector/kinesis-asl/pom.xml +++ b/connector/kinesis-asl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/profiler/pom.xml b/connector/profiler/pom.xml index 93572d6d671d3..9652d89802cc8 100644 --- a/connector/profiler/pom.xml +++ b/connector/profiler/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/protobuf/pom.xml b/connector/protobuf/pom.xml index e9521f9418c1f..e90ecb2927b22 100644 --- a/connector/protobuf/pom.xml +++ b/connector/protobuf/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/connector/spark-ganglia-lgpl/pom.xml b/connector/spark-ganglia-lgpl/pom.xml index 7b18a97cbd9de..cedd15d9c7aa7 100644 --- a/connector/spark-ganglia-lgpl/pom.xml +++ b/connector/spark-ganglia-lgpl/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/core/pom.xml b/core/pom.xml index 6b228a86f3535..99ad16676c9cb 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 95778771d4239..e95c4bd8c6222 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -135,14 +135,6 @@ private[deploy] class Master( private var restServer: Option[StandaloneRestServer] = None private var restServerBoundPort: Option[Int] = None - { - val authKey = SecurityManager.SPARK_AUTH_SECRET_CONF - require(conf.getOption(authKey).isEmpty || !restServerEnabled, - s"The RestSubmissionServer does not support authentication via ${authKey}. Either turn " + - "off the RestSubmissionServer with spark.master.rest.enabled=false, or do not use " + - "authentication.") - } - override def onStart(): Unit = { logInfo(log"Starting Spark master at ${MDC(LogKeys.MASTER_URL, masterUrl)}") logInfo(log"Running Spark version" + diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index 66036e7a5e5ce..00fe3d260cabc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -72,10 +72,11 @@ class MasterWebUI( if (decommissionEnabled) { attachHandler(createServletHandler("/workers/kill", new HttpServlet { override def doPost(req: HttpServletRequest, resp: HttpServletResponse): Unit = { + if (!master.securityMgr.checkModifyPermissions(req.getRemoteUser)) return val hostnames: Seq[String] = Option(req.getParameterValues("host")) .getOrElse(Array[String]()).toImmutableArraySeq if (!isDecommissioningRequestAllowed(req)) { - resp.sendError(HttpServletResponse.SC_METHOD_NOT_ALLOWED) + resp.sendError(HttpServletResponse.SC_FORBIDDEN) } else { val removedWorkers = masterEndpointRef.askSync[Integer]( DecommissionWorkersOnHosts(hostnames)) diff --git a/core/src/main/scala/org/apache/spark/internal/config/UI.scala b/core/src/main/scala/org/apache/spark/internal/config/UI.scala index 20c99794cd215..2e83c1f64f2ec 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/UI.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/UI.scala @@ -136,6 +136,16 @@ private[spark] object UI { .stringConf .createOptional + val UI_CONTENT_SECURITY_POLICY_ENABLED = + ConfigBuilder("spark.ui.contentSecurityPolicy.enabled") + .doc("Whether to set the HTTP Content-Security-Policy (CSP) response header for the " + + "Spark UI. When enabled, CSP restricts the sources from which the browser is allowed " + + "to load resources, providing defense-in-depth against cross-site scripting (XSS).") + .version("4.2.0") + .withBindingPolicy(ConfigBindingPolicy.NOT_APPLICABLE) + .booleanConf + .createWithDefault(false) + val UI_REQUEST_HEADER_SIZE = ConfigBuilder("spark.ui.requestHeaderSize") .doc("Value for HTTP request header size in bytes.") .version("2.2.3") diff --git a/core/src/main/scala/org/apache/spark/internal/config/package.scala b/core/src/main/scala/org/apache/spark/internal/config/package.scala index 1bbd733909bd8..9a43bcfdd3a2f 100644 --- a/core/src/main/scala/org/apache/spark/internal/config/package.scala +++ b/core/src/main/scala/org/apache/spark/internal/config/package.scala @@ -2987,15 +2987,6 @@ package object config { "The value only can be one or more of 'stdout, stderr'.") .createWithDefault(Seq("stdout", "stderr")) - private[spark] val YARN_AM_LIMIT_ACTIVE_PROCESSOR_COUNT_ENABLED = - ConfigBuilder("spark.yarn.am.limitActiveProcessorCount.enabled") - .doc("Whether to add -XX:ActiveProcessorCount= to the YARN " + - "Application Master JVM options in client mode. In cluster mode, use " + - "`spark.driver.limitActiveProcessorCount.enabled` instead.") - .version("4.2.0") - .booleanConf - .createWithDefault(false) - private[spark] val DRIVER_LIMIT_ACTIVE_PROCESSOR_COUNT_ENABLED = ConfigBuilder("spark.driver.limitActiveProcessorCount.enabled") .doc("Whether to add -XX:ActiveProcessorCount= to the driver JVM " + diff --git a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala index a7e0bb683dd06..d06f2162d2f6c 100644 --- a/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala +++ b/core/src/main/scala/org/apache/spark/ui/HttpSecurityFilter.scala @@ -51,10 +51,12 @@ private class HttpSecurityFilter( val cspNonce = CspNonce.generate() try { - hres.setHeader("Content-Security-Policy", - s"default-src 'self'; script-src 'self' 'nonce-$cspNonce'; " + - s"style-src 'self' 'unsafe-inline'; img-src 'self' data:; " + - s"object-src 'none'; base-uri 'self';") + if (conf.get(UI_CONTENT_SECURITY_POLICY_ENABLED)) { + hres.setHeader("Content-Security-Policy", + s"default-src 'self'; script-src 'self' 'nonce-$cspNonce'; " + + s"style-src 'self' 'unsafe-inline'; img-src 'self' data:; " + + s"object-src 'none'; base-uri 'self';") + } val requestUser = hreq.getRemoteUser() diff --git a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java index 6c1e84f6157cc..055eaf6ccae61 100644 --- a/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java +++ b/core/src/test/java/org/apache/spark/launcher/SparkLauncherSuite.java @@ -152,8 +152,10 @@ private void inProcessLauncherTestImpl() throws Exception { // SPARK-23020: see doc for InProcessTestApp.LOCK for a description of the race. Here // we wait until we know that the connection between the app and the launcher has been // established before allowing the app to finish. + // Use a generous timeout because, under heavy CI load, establishing the connection + // between the in-process app and the launcher can take longer than a few seconds. final SparkAppHandle _handle = handle; - eventually(Duration.ofSeconds(5), Duration.ofMillis(10), () -> { + eventually(Duration.ofSeconds(30), Duration.ofMillis(100), () -> { assertNotEquals(SparkAppHandle.State.UNKNOWN, _handle.getState()); }); diff --git a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala index e64bc724cfba0..c138d1142a474 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/MasterSuite.scala @@ -249,4 +249,14 @@ class MasterSuite extends MasterSuiteBase { eventLogCodec = None) assert(master.invokePrivate(_createApplication(desc, null)).id === "spark-45756") } + + test("SPARK-57451: Allows REST server and spark.authenticate.secret to be enabled together") { + val conf = new SparkConf() + .set(MASTER_REST_SERVER_ENABLED, true) + .set(AUTH_SECRET, "secret") + // Creating a Master must not fail when both the REST server and an auth secret are set. + noException should be thrownBy { + makeMaster(conf) + } + } } diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUIAclSuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUIAclSuite.scala new file mode 100644 index 0000000000000..b3e773f075c7d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUIAclSuite.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.master.ui + +import jakarta.servlet.{Filter, FilterChain, ServletRequest, ServletResponse} +import jakarta.servlet.http.{HttpServletRequest, HttpServletRequestWrapper} +import org.mockito.Mockito.{mock, never, verify, when} + +import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} +import org.apache.spark.deploy.DeployMessages.DecommissionWorkersOnHosts +import org.apache.spark.deploy.master.Master +import org.apache.spark.internal.config.DECOMMISSION_ENABLED +import org.apache.spark.internal.config.UI._ +import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} +import org.apache.spark.util.Utils + +/** + * Tests the modify ACL enforcement of the Master Web UI's `/workers/kill` endpoint. + * A [[FakeAuthFilter]] injects the remote user so the modify permission can be exercised. + */ +class MasterWebUIAclSuite extends SparkFunSuite { + import MasterWebUISuite._ + + val conf = new SparkConf() + .set(DECOMMISSION_ENABLED, true) + .set(UI_FILTERS, Seq(classOf[FakeAuthFilter].getName)) + .set(ACLS_ENABLE, true) + .set(UI_VIEW_ACLS, Seq("*")) + .set(MODIFY_ACLS, Seq("alice")) + val securityMgr = new SecurityManager(conf) + val rpcEnv = mock(classOf[RpcEnv]) + val master = mock(classOf[Master]) + val masterEndpointRef = mock(classOf[RpcEndpointRef]) + when(master.securityMgr).thenReturn(securityMgr) + when(master.conf).thenReturn(conf) + when(master.rpcEnv).thenReturn(rpcEnv) + when(master.self).thenReturn(masterEndpointRef) + val masterWebUI = new MasterWebUI(master, 0) + + override def beforeAll(): Unit = { + super.beforeAll() + masterWebUI.bind() + } + + override def afterAll(): Unit = { + try { + masterWebUI.stop() + } finally { + super.afterAll() + } + } + + private def killWorkers(hostnames: Seq[String], user: String): Unit = { + val url = s"http://${Utils.localHostNameForURI()}:${masterWebUI.boundPort}/workers/kill/" + val body = convPostDataToString(hostnames.map(("host", _))) + val headers = Seq(FakeAuthFilter.FAKE_HTTP_USER -> user) + val conn = sendHttpRequest(url, "POST", body, headers) + // The master is mocked here, so cannot assert on the response code. + conn.getResponseCode + } + + test("Allow the worker kill request with the modify permission") { + val hostnames = Seq("allowed") + killWorkers(hostnames, "alice") + verify(masterEndpointRef).askSync[Integer](DecommissionWorkersOnHosts(hostnames)) + } + + test("Reject the worker kill request without the modify permission") { + val hostnames = Seq("denied") + killWorkers(hostnames, "nobody") + verify(masterEndpointRef, never()).askSync[Integer](DecommissionWorkersOnHosts(hostnames)) + } +} + +/** Test filter that sets the remote user from the request's HTTP_USER header. */ +class FakeAuthFilter extends Filter { + override def doFilter(req: ServletRequest, res: ServletResponse, chain: FilterChain): Unit = { + val hreq = req.asInstanceOf[HttpServletRequest] + val wrapped = new HttpServletRequestWrapper(hreq) { + override def getRemoteUser(): String = hreq.getHeader(FakeAuthFilter.FAKE_HTTP_USER) + } + chain.doFilter(wrapped, res) + } +} + +object FakeAuthFilter { + val FAKE_HTTP_USER = "HTTP_USER" +} diff --git a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala index 5e75d1c424eab..332e7bf9b5e04 100644 --- a/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala +++ b/core/src/test/scala/org/apache/spark/deploy/master/ui/MasterWebUISuite.scala @@ -24,6 +24,7 @@ import java.util.Date import scala.collection.mutable.HashMap +import jakarta.servlet.http.HttpServletResponse.SC_FORBIDDEN import org.mockito.Mockito.{mock, times, verify, when} import org.apache.spark.{SecurityManager, SparkConf, SparkFunSuite} @@ -31,6 +32,7 @@ import org.apache.spark.deploy.DeployMessages.{DecommissionWorkersOnHosts, KillD import org.apache.spark.deploy.DeployTestUtils._ import org.apache.spark.deploy.master._ import org.apache.spark.internal.config.DECOMMISSION_ENABLED +import org.apache.spark.internal.config.UI.MASTER_UI_DECOMMISSION_ALLOW_MODE import org.apache.spark.rpc.{RpcEndpointRef, RpcEnv} import org.apache.spark.util.Utils @@ -106,6 +108,26 @@ class MasterWebUISuite extends SparkFunSuite { test("Kill multiple hosts") { testKillWorkers(Seq("noSuchHost", "LocalHost")) } + + test("SPARK-57509: /workers/kill responds with 403 Forbidden when the request is not allowed") { + val denyConf = new SparkConf() + .set(DECOMMISSION_ENABLED, true) + .set(MASTER_UI_DECOMMISSION_ALLOW_MODE.key, "DENY") + val denyMaster = mock(classOf[Master]) + when(denyMaster.securityMgr).thenReturn(new SecurityManager(denyConf)) + when(denyMaster.conf).thenReturn(denyConf) + when(denyMaster.rpcEnv).thenReturn(rpcEnv) + when(denyMaster.self).thenReturn(masterEndpointRef) + val denyWebUI = new MasterWebUI(denyMaster, 0) + try { + denyWebUI.bind() + val url = s"http://${Utils.localHostNameForURI()}:${denyWebUI.boundPort}/workers/kill/" + val body = convPostDataToString(Seq(("host", Utils.localHostNameForURI()))) + assert(sendHttpRequest(url, "POST", body).getResponseCode === SC_FORBIDDEN) + } finally { + denyWebUI.stop() + } + } } object MasterWebUISuite { @@ -124,9 +146,11 @@ object MasterWebUISuite { private[ui] def sendHttpRequest( url: String, method: String, - body: String = ""): HttpURLConnection = { + body: String = "", + headers: Seq[(String, String)] = Nil): HttpURLConnection = { val conn = new URI(url).toURL.openConnection().asInstanceOf[HttpURLConnection] conn.setRequestMethod(method) + headers.foreach { case (k, v) => conn.setRequestProperty(k, v) } if (body.nonEmpty) { conn.setDoOutput(true) conn.setRequestProperty("Content-Type", "application/x-www-form-urlencoded") diff --git a/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala index 9965751c42cce..776f2e4496cc8 100644 --- a/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/HttpSecurityFilterSuite.scala @@ -119,6 +119,7 @@ class HttpSecurityFilterSuite extends SparkFunSuite { .set(UI_X_XSS_PROTECTION, "xssProtection") .set(UI_X_CONTENT_TYPE_OPTIONS, true) .set(UI_STRICT_TRANSPORT_SECURITY, "tsec") + .set(UI_CONTENT_SECURITY_POLICY_ENABLED, true) val secMgr = new SecurityManager(conf) val req = mockRequest() val res = mock(classOf[HttpServletResponse]) @@ -149,6 +150,19 @@ class HttpSecurityFilterSuite extends SparkFunSuite { } } + test("Content-Security-Policy header is not set by default") { + val conf = new SparkConf(false) + val secMgr = new SecurityManager(conf) + val req = mockRequest() + val res = mock(classOf[HttpServletResponse]) + val chain = mock(classOf[FilterChain]) + + val filter = new HttpSecurityFilter(conf, secMgr) + filter.doFilter(req, res, chain) + + verify(res, times(0)).setHeader(meq("Content-Security-Policy"), any()) + } + test("doAs impersonation") { val conf = new SparkConf(false) .set(ACLS_ENABLE, true) @@ -186,7 +200,7 @@ class HttpSecurityFilterSuite extends SparkFunSuite { } test("CSP nonce is available during chain.doFilter and cleared after") { - val conf = new SparkConf(false) + val conf = new SparkConf(false).set(UI_CONTENT_SECURITY_POLICY_ENABLED, true) val secMgr = new SecurityManager(conf) val req = mockRequest() val res = mock(classOf[HttpServletResponse]) diff --git a/docs/_config.yml b/docs/_config.yml index 5109d8d338a78..d80b90433e356 100644 --- a/docs/_config.yml +++ b/docs/_config.yml @@ -19,8 +19,8 @@ include: # These allow the documentation to be updated with newer releases # of Spark, Scala. -SPARK_VERSION: 4.2.0.1-4.3.0-1 -SPARK_VERSION_SHORT: 4.2.0.1-4.3.0-1 +SPARK_VERSION: 4.2.0.1-4.3.0-2 +SPARK_VERSION_SHORT: 4.2.0.1-4.3.0-2 SCALA_BINARY_VERSION: "2.13" SCALA_VERSION: "2.13.18" SPARK_ISSUE_TRACKER_URL: https://issues.apache.org/jira/browse/SPARK @@ -37,7 +37,7 @@ DOCSEARCH_SCRIPT: | inputSelector: '#docsearch-input', enhancedSearchInput: true, algoliaOptions: { - 'facetFilters': ["version:4.2.0.1-4.3.0-1"] + 'facetFilters': ["version:4.2.0.1-4.3.0-2"] }, debug: false // Set debug to true if you want to inspect the dropdown }); diff --git a/docs/security.md b/docs/security.md index 484017df84fb8..b3743d9a1b0a6 100644 --- a/docs/security.md +++ b/docs/security.md @@ -803,6 +803,16 @@ Security. 2.3.0 + + spark.ui.contentSecurityPolicy.enabled + false + + When enabled, the Content-Security-Policy (CSP) HTTP response header is set for the Spark UI, + restricting the sources from which the browser is allowed to load resources as a + defense-in-depth measure against cross-site scripting (XSS). + + 4.2.0 + diff --git a/docs/sql-migration-guide.md b/docs/sql-migration-guide.md index 620e3800ff010..fd4e85f136dac 100644 --- a/docs/sql-migration-guide.md +++ b/docs/sql-migration-guide.md @@ -30,6 +30,11 @@ license: | - Since Spark 4.2, the virtual `system` catalog hosts the new `system.builtin` and `system.session` namespaces. `system.builtin` exposes built-in functions and functions injected through `SparkSessionExtensions`; `system.session` exposes temporary views, temporary functions, and session variables created in the current session. As a result, 2-part references like `builtin.func()` and `session.func()` now follow a mini-path that tries the system namespace first and the current catalog second, so a persistent schema named `builtin` or `session` is no longer reached by `builtin.func()` / `session.func()` when the system namespace contains an object of the same name. To restore the previous behavior (current catalog first), set `spark.sql.legacy.persistentCatalogFirst` to `true`. Persistent schemas with these names are still allowed but should be reached with an explicit catalog prefix (for example, `spark_catalog.session.x`). See [Reserved system names](sql-ref-identifier.html#reserved-system-names). - Since Spark 4.2, `CREATE TEMPORARY VIEW`, `CREATE TEMPORARY FUNCTION`, and the corresponding `DROP` statements accept the `session` and `system.session` qualifiers on the object name (in addition to the previously supported unqualified form); for example, `CREATE TEMPORARY VIEW system.session.v AS ...` and `DROP TEMPORARY FUNCTION session.f` are now valid. Any other qualifier on a temporary object is rejected with `INVALID_TEMP_OBJ_QUALIFIER`. - Since Spark 4.2, the SQL standard `PATH` feature is available: the `SET PATH` statement, the `current_path()` function, path-based resolution of unqualified routines, tables, views, and session variables, and the configurations `spark.sql.path.enabled` (default `false`) and `spark.sql.defaultPath`. The feature is opt-in; when `spark.sql.path.enabled` is `false`, unqualified resolution falls back to a fixed default path and `SET PATH` is rejected with `UNSUPPORTED_FEATURE.SET_PATH_WHEN_DISABLED`. See [SET PATH](sql-ref-syntax-aux-conf-mgmt-set-path.html) and [Name Resolution](sql-ref-name-resolution.html). +- Since Spark 4.2, duplicate Common Table Expression (CTE) names within a single `WITH` clause are detected case-insensitively at parse time, so names that differ only in case (for example, `WITH a AS (...), A AS (...)`) are rejected with `DUPLICATED_CTE_NAMES`. This check is always case-insensitive and does not depend on `spark.sql.caseSensitive`. Rename the conflicting CTEs. +- Since Spark 4.2, `NATURAL JOIN` honors the `spark.sql.caseSensitive` configuration when determining the common columns to join on. This can change which columns are used as join keys, and therefore the query result, when column names differ only in case. +- Since Spark 4.2, when a SQL UDF has a parameter whose name matches a parameterless built-in function (`current_user`, `current_date`, `current_time`, `current_timestamp`, `user`, `session_user`, `grouping__id`), a bare reference to that name in the function body resolves to the built-in function instead of the parameter, matching the documented name resolution rules. Rename the parameter to avoid the collision, or set `spark.sql.legacy.allowUdfParameterToShadowParameterlessFunction` to `true` to restore the previous behavior. +- Since Spark 4.2, `SET CATALOG ` resolves a bare (unquoted) name as a session variable first, using the variable's value as the catalog name when such a variable exists, and otherwise treats the name as a literal catalog name. Use a string literal (`SET CATALOG 'name'`) to force it to be interpreted literally. +- Since Spark 4.2, when an error occurs while collecting observed metrics, `Observation.get` raises the underlying exception (for example, `SparkRuntimeException` in Scala or `PySparkException` in Python) instead of silently returning an empty result. Add error handling if your code relied on receiving an empty result on failure. ## Upgrading from Spark SQL 4.0 to 4.1 diff --git a/docs/streaming/ss-migration-guide.md b/docs/streaming/ss-migration-guide.md index a0c0a397edeae..210f752d9d17d 100644 --- a/docs/streaming/ss-migration-guide.md +++ b/docs/streaming/ss-migration-guide.md @@ -23,6 +23,10 @@ Note that this migration guide describes the items specific to Structured Stream Many items of SQL migration can be applied when migrating Structured Streaming to higher versions. Please refer [Migration Guide: SQL, Datasets and DataFrame](../sql-migration-guide.html). +## Upgrading from Structured Streaming 4.1 to 4.2 + +- Since Spark 4.2, restarting a streaming query from a checkpoint whose metadata file is missing while the offset or commit logs contain data fails with `STREAMING_CHECKPOINT_MISSING_METADATA_FILE`, instead of silently generating a new query ID (which can duplicate data in exactly-once sinks). Restore the metadata file or use a new checkpoint location. To restore the previous behavior, set `spark.sql.streaming.checkpoint.verifyMetadataExists.enabled` to `false`. (See [SPARK-55058](https://issues.apache.org/jira/browse/SPARK-55058) for more details.) + ## Upgrading from Structured Streaming 4.0 to 4.1 - Since Spark 4.1, AQE is supported for stateless workloads, and it could affect the behavior of the query after upgrade (especially since AQE is turned on by default). In general, it helps to achieve better performance including resolution of skewed partition, but you can turn off AQE via changing `spark.sql.adaptive.streaming.stateless.enabled` to `false` to restore the behavior if you see regression. diff --git a/examples/pom.xml b/examples/pom.xml index 7edcc47c2a2ff..1b6e02ccb8833 100644 --- a/examples/pom.xml +++ b/examples/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/graphx/pom.xml b/graphx/pom.xml index 0d7048731b297..7af8aaf18f670 100644 --- a/graphx/pom.xml +++ b/graphx/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/hadoop-cloud/pom.xml b/hadoop-cloud/pom.xml index a308783f8a120..ee45f47430df0 100644 --- a/hadoop-cloud/pom.xml +++ b/hadoop-cloud/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/launcher/pom.xml b/launcher/pom.xml index a3443e4478391..6913b1c1020a7 100644 --- a/launcher/pom.xml +++ b/launcher/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/mllib-local/pom.xml b/mllib-local/pom.xml index f4302f99265be..3d00a19769f9d 100644 --- a/mllib-local/pom.xml +++ b/mllib-local/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/mllib/pom.xml b/mllib/pom.xml index e2f4aff537abe..2d3e4ea5c782b 100644 --- a/mllib/pom.xml +++ b/mllib/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/pom.xml b/pom.xml index 0c0f250073b8a..c8840312cba86 100644 --- a/pom.xml +++ b/pom.xml @@ -26,7 +26,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 pom Spark Project Parent POM https://spark.apache.org/ diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 5d3dfb94c36cf..900a6dcba82f6 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -1643,6 +1643,7 @@ object Unidoc { .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/kafka010"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/types/variant"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/ui/flamegraph"))) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/udf/worker"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/collection"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/io"))) .map(_.filterNot(_.getCanonicalPath.contains("org/apache/spark/util/kvstore"))) @@ -1674,6 +1675,59 @@ object Unidoc { .map(_.filterNot(_.data.getCanonicalPath.contains("connect-shims"))) } + // genjavadoc emits top-level package-private Scala types (`private` or `private[x]`, e.g. + // `private[spark] trait Foo` or a bare `private class Bar`) as *public* Java stubs even with + // `-P:genjavadoc:strictVisibility=true`, because a top-level package-private type compiles to a + // JVM-public symbol, and the Javadoc `-public` option cannot drop a stub that really is public. + // ScalaDoc honors the qualifier and hides such types, so we drop their stubs here to keep the + // published Java API doc aligned with the Scala one. A stub `/target/java//.java` + // is dropped iff EVERY top-level Scala declaration of `` in that package is `private` or + // `private[...]`; a public class with a private companion object (e.g. `SparkConf`) is kept, since + // the class itself is public. The private regex tolerates other modifiers around the access + // qualifier (e.g. `final private[x] class`). + private val publicTopTypeRe = + """(?m)^(?:@\w+(?:\([^\n)]*\))?\s+)*(?:(?:final|sealed|abstract|implicit|case)\s+)*(?:class|trait|object)\s+(\w+)""".r + private val privateTopTypeRe = + """(?m)^(?:@\w+(?:\([^\n)]*\))?\s+)*(?:(?:final|sealed|abstract|implicit|case)\s+)*private(?:\[[^\]]+\])?\s+(?:(?:final|sealed|abstract|implicit|case)\s+)*(?:class|trait|object)\s+(\w+)""".r + + private def dropPackagePrivateJavaStubs(sources: Seq[Seq[File]]): Seq[Seq[File]] = { + val cache = scala.collection.mutable.Map.empty[String, (Set[String], Set[String])] + def scanPkg(dir: String): (Set[String], Set[String]) = cache.getOrElseUpdate(dir, { + val d = new File(dir) + val scalaFiles = + if (d.isDirectory) d.listFiles.filter(_.getName.endsWith(".scala")) else Array.empty[File] + val text = scalaFiles + .map(f => new String(java.nio.file.Files.readAllBytes(f.toPath), "UTF-8")) + .mkString("\n") + (publicTopTypeRe.findAllMatchIn(text).map(_.group(1)).toSet, + privateTopTypeRe.findAllMatchIn(text).map(_.group(1)).toSet) + }) + val marker = "/target/java/" + sources.map(_.filterNot { f => + val path = f.getCanonicalPath.replace('\\', '/') + val idx = path.indexOf(marker) + if (idx < 0 || !path.endsWith(".java")) { + false + } else { + val rel = path.substring(idx + marker.length) // /.java + val slash = rel.lastIndexOf('/') + if (slash < 0) { + false + } else { + val moduleRoot = path.substring(0, idx) + val pkgPath = rel.substring(0, slash) + val name = f.getName.stripSuffix(".java") + val (pub, priv) = Seq("src/main/scala", "src/main/scala-2.13") + .map(d => scanPkg(s"$moduleRoot/$d/$pkgPath")) + .foldLeft((Set.empty[String], Set.empty[String])) { + case ((p, q), (a, b)) => (p ++ a, q ++ b) + } + priv.contains(name) && !pub.contains(name) + } + } + }) + } + val unidocSourceBase = settingKey[String]("Base URL of source links in Scaladoc.") lazy val settings = BaseUnidocPlugin.projectSettings ++ @@ -1698,8 +1752,9 @@ object Unidoc { // Skip class names containing $ and some internal packages in Javadocs (JavaUnidoc / unidoc / unidocAllSources) := { - ignoreUndocumentedPackages((JavaUnidoc / unidoc / unidocAllSources).value) - .map(_.filterNot(_.getCanonicalPath.contains("org/apache/hadoop"))) + dropPackagePrivateJavaStubs( + ignoreUndocumentedPackages((JavaUnidoc / unidoc / unidocAllSources).value) + .map(_.filterNot(_.getCanonicalPath.contains("org/apache/hadoop")))) }, (JavaUnidoc / unidoc / javacOptions) := { diff --git a/python/docs/source/migration_guide/pyspark_upgrade.rst b/python/docs/source/migration_guide/pyspark_upgrade.rst index dba2a6266939a..9fcf83f4cb9ee 100644 --- a/python/docs/source/migration_guide/pyspark_upgrade.rst +++ b/python/docs/source/migration_guide/pyspark_upgrade.rst @@ -26,6 +26,12 @@ Upgrading from PySpark 4.1 to 4.2 * In Spark 4.2, columnar data exchange between PySpark and the JVM uses Apache Arrow by default. The configuration ``spark.sql.execution.arrow.pyspark.enabled`` now defaults to true. To restore the legacy (non-Arrow) row-based data exchange, set ``spark.sql.execution.arrow.pyspark.enabled`` to ``false``. * In Spark 4.2, regular Python UDFs are Arrow-optimized by default. The configuration ``spark.sql.execution.pythonUDF.arrow.enabled`` now defaults to true. To restore the legacy behavior for Python UDF execution, set ``spark.sql.execution.pythonUDF.arrow.enabled`` to ``false``. * In Spark 4.2, regular Python UDTFs are Arrow-optimized by default. The configuration ``spark.sql.execution.pythonUDTF.arrow.enabled`` now defaults to true. To restore the legacy behavior for Python UDTF execution, set ``spark.sql.execution.pythonUDTF.arrow.enabled`` to ``false``. +* In Spark 4.2, PyPy is no longer officially supported. Run PySpark on CPython instead. +* In Spark 4.2, ``SparkSession.createDataFrame`` from a NumPy ``ndarray`` requires PyArrow (instead of pandas) and converts the array directly to Arrow rather than through pandas. Install PyArrow; if you previously ran with Arrow disabled and relied on NumPy-dtype-based schema inference, review the inferred schema, as it now follows Arrow's type mapping. +* In Spark 4.2, when a pandas UDF receives a nullable integer column whose batch contains nulls, the column is delivered as a pandas nullable integer extension dtype (``Int8``/``Int16``/``Int32``/``Int64``) instead of ``float64``. Update UDF code that assumed ``float64`` input for nullable integer columns. +* In Spark 4.2, ``DataFrame.drop`` and ``Series.drop`` in pandas API on Spark raise a ``KeyError`` when any of the specified labels is missing, instead of only when all of them are missing, matching pandas. Make sure all labels exist before dropping, filter to existing labels, or pass ``errors="ignore"``. +* In Spark 4.2, a Python Data Source whose returned Arrow data has column types that do not match its declared schema fails with ``DATA_SOURCE_RETURN_SCHEMA_MISMATCH`` (column count and name mismatches already raised this error in earlier versions). Make the data source return data whose types match its declared schema. +* In Spark 4.2, a ``SimpleDataSourceStreamReader`` whose ``read()`` returns a non-empty batch without advancing the end offset past the start offset fails with ``SIMPLE_STREAM_READER_OFFSET_DID_NOT_ADVANCE``, instead of reprocessing the same batch and growing the prefetch cache without bound. Ensure the returned end offset advances past the last record. Upgrading from PySpark 4.0 to 4.1 --------------------------------- diff --git a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py index 04c6b84e02f34..1975cc1b899b6 100644 --- a/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py +++ b/python/pyspark/pandas/tests/plot/test_series_plot_matplotlib.py @@ -321,7 +321,7 @@ def check_box_plot(pser, psser, *args, **kwargs): check_box_plot(p, k, showfliers=True) check_box_plot(p, k, sym="") check_box_plot(p, k, sym=".", color="r") - check_box_plot(p, k, use_index=False, labels=["Test"]) + check_box_plot(p, k, use_index=False) check_box_plot(p, k, usermedians=[2.0]) check_box_plot(p, k, conf_intervals=[(1.0, 3.0)]) diff --git a/python/pyspark/pipelines/add_pipeline_analysis_context.py b/python/pyspark/pipelines/add_pipeline_analysis_context.py index 6ecabdf43b072..6d0bd4dd7308c 100644 --- a/python/pyspark/pipelines/add_pipeline_analysis_context.py +++ b/python/pyspark/pipelines/add_pipeline_analysis_context.py @@ -45,4 +45,7 @@ def add_pipeline_analysis_context( extension_id = client.add_threadlocal_user_context_extension(extension) yield finally: - client.remove_user_context_extension(extension_id) + # extension_id stays None if registering the extension above failed; skip cleanup in that + # case so we don't call remove_user_context_extension(None) and mask the original error. + if extension_id is not None: + client.remove_user_context_extension(extension_id) diff --git a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py index 44e60f4597c69..f99e7d56c3d51 100644 --- a/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py +++ b/python/pyspark/pipelines/tests/test_add_pipeline_analysis_context.py @@ -15,6 +15,7 @@ # limitations under the License. # import unittest +from unittest import mock from pyspark.testing.connectutils import ( ReusedConnectTestCase, @@ -89,6 +90,43 @@ def test_nested_add_pipeline_analysis_context(self): thread_local_extensions_after_2 = self.spark.client.thread_local.user_context_extensions self.assertEqual(len(thread_local_extensions_after_2), 0) + def test_setup_failure_does_not_mask_original_error(self): + # If any setup step fails before the extension is registered, extension_id stays None and + # the finally block must skip remove_user_context_extension(None) - which would raise + # AttributeError and mask the original error. Cover each step that can fail. + import pyspark.sql.connect.proto as pb2 + from google.protobuf import any_pb2 + + failing_any = mock.MagicMock() + failing_any.Pack.side_effect = ValueError("boom") + + failure_points = { + "context construction": mock.patch.object( + pb2, "PipelineAnalysisContext", side_effect=ValueError("boom") + ), + "extension packing": mock.patch.object(any_pb2, "Any", return_value=failing_any), + "extension registration": mock.patch.object( + self.spark.client, + "add_threadlocal_user_context_extension", + side_effect=ValueError("boom"), + ), + } + for step, patcher in failure_points.items(): + with self.subTest(failing_step=step): + with patcher: + with self.assertRaises(ValueError): + with add_pipeline_analysis_context( + self.spark, "test_dataflow_graph_id", None + ): + pass + # A failed setup should not leave an extension registered. Read lazily, since the + # attribute is only created on first successful registration, which never happens + # here. + thread_local_extensions = getattr( + self.spark.client.thread_local, "user_context_extensions", [] + ) + self.assertEqual(len(thread_local_extensions), 0) + if __name__ == "__main__": from pyspark.testing import main diff --git a/python/pyspark/version.py b/python/pyspark/version.py index 42e21847f472b..e2fcb088630ea 100644 --- a/python/pyspark/version.py +++ b/python/pyspark/version.py @@ -16,4 +16,4 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__: str = "4.2.0.1+4.3.0.1" +__version__: str = "4.2.0.1+4.3.0.2" diff --git a/repl/pom.xml b/repl/pom.xml index 9558939a61a6d..0752520fe5b35 100644 --- a/repl/pom.xml +++ b/repl/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/resource-managers/kubernetes/core/pom.xml b/resource-managers/kubernetes/core/pom.xml index 459c0818e52d1..0efdbc1719af2 100644 --- a/resource-managers/kubernetes/core/pom.xml +++ b/resource-managers/kubernetes/core/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../pom.xml diff --git a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile index 4af2359237d36..09e1b9f1ecd27 100644 --- a/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile +++ b/resource-managers/kubernetes/docker/src/main/dockerfiles/spark/Dockerfile @@ -34,13 +34,6 @@ ARG spark_uid=185 # docker build -t spark:latest -f kubernetes/dockerfiles/spark/Dockerfile . RUN set -ex && \ - if [ -f /etc/apt/sources.list.d/ubuntu.sources ]; then \ - printf 'Types: deb\nURIs: https://mirrors.edge.kernel.org/ubuntu\nSuites: noble noble-updates noble-security\nComponents: main restricted universe multiverse\nSigned-By: /usr/share/keyrings/ubuntu-archive-keyring.gpg\n' > /etc/apt/sources.list.d/mirror.sources; \ - elif [ -f /etc/apt/sources.list ]; then \ - grep -q 'archive.ubuntu.com' /etc/apt/sources.list 2>/dev/null && \ - sed 's|archive.ubuntu.com|mirrors.edge.kernel.org|g;s|security.ubuntu.com|mirrors.edge.kernel.org|g' \ - /etc/apt/sources.list > /etc/apt/sources.list.d/mirror.list || true; \ - fi && \ apt-get update && \ ln -s /lib /lib64 && \ apt install -y --no-install-recommends bash tini libc6 libpam-modules krb5-user libnss3 procps net-tools logrotate libssl-dev && \ diff --git a/resource-managers/kubernetes/integration-tests/pom.xml b/resource-managers/kubernetes/integration-tests/pom.xml index ee806afeaba7f..c5c0994fe465c 100644 --- a/resource-managers/kubernetes/integration-tests/pom.xml +++ b/resource-managers/kubernetes/integration-tests/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../pom.xml diff --git a/resource-managers/yarn/pom.xml b/resource-managers/yarn/pom.xml index aecf6a1ceb626..57aed854e5ebb 100644 --- a/resource-managers/yarn/pom.xml +++ b/resource-managers/yarn/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config/package.scala b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config/package.scala index 8b63d117e7fa6..59106684b191c 100644 --- a/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config/package.scala +++ b/resource-managers/yarn/src/main/scala/org/apache/spark/deploy/yarn/config/package.scala @@ -301,6 +301,15 @@ package object config extends Logging { .intConf .createWithDefault(1) + private[spark] val YARN_AM_LIMIT_ACTIVE_PROCESSOR_COUNT_ENABLED = + ConfigBuilder("spark.yarn.am.limitActiveProcessorCount.enabled") + .doc("Whether to add -XX:ActiveProcessorCount= to the YARN " + + "Application Master JVM options in client mode. In cluster mode, use " + + "`spark.driver.limitActiveProcessorCount.enabled` instead.") + .version("4.2.0") + .booleanConf + .createWithDefault(false) + private[spark] val AM_DEFAULT_JAVA_OPTIONS = ConfigBuilder("spark.yarn.am.defaultJavaOptions") .doc("Default Java options for the client-mode AM to prepend to " + "`spark.yarn.am.extraJavaOptions`. This is intended to be set by administrators.") diff --git a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala index 1f755ba5efee5..69c515d6a3812 100644 --- a/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala +++ b/resource-managers/yarn/src/test/scala/org/apache/spark/deploy/yarn/BaseYarnClusterSuite.scala @@ -103,6 +103,13 @@ abstract class BaseYarnClusterSuite extends SparkFunSuite with Matchers { yarnConf.set("yarn.scheduler.capacity.root.default.acl_submit_applications", "*") yarnConf.set("yarn.scheduler.capacity.root.default.acl_administer_queue", "*") yarnConf.setInt("yarn.scheduler.capacity.node-locality-delay", -1) + // `maximum-am-resource-percent` defaults to 0.1, which caps the queue's total AM resource + // usage to 10% of its capacity. On memory-constrained CI runners this becomes ~1GB, smaller + // than the AM/driver memory these tests request (1-2GB), so applications get stuck in the + // ACCEPTED state (never activated) and the suite times out waiting for a final state. Let + // AMs use the whole queue in tests so applications are always activated. + yarnConf.setFloat("yarn.scheduler.capacity.maximum-am-resource-percent", 1.0f) + yarnConf.setFloat("yarn.scheduler.capacity.root.default.maximum-am-resource-percent", 1.0f) // Support both IPv4 and IPv6 yarnConf.set("yarn.resourcemanager.hostname", Utils.localHostNameForURI()) diff --git a/sql/api/pom.xml b/sql/api/pom.xml index 9c8a28c179f7a..db2ec1077def5 100644 --- a/sql/api/pom.xml +++ b/sql/api/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala index 3a0e4d45f937c..89859c0326b29 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/Row.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/Row.scala @@ -32,12 +32,12 @@ import org.json4s.jackson.JsonMethods.{compact, pretty, render} import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.annotation.{Stable, Unstable} import org.apache.spark.sql.catalyst.expressions.GenericRow +import org.apache.spark.sql.catalyst.types.ops.TypeApiOps import org.apache.spark.sql.catalyst.util.{DateFormatter, SparkDateTimeUtils, TimeFormatter, TimestampFormatter, UDTUtils} import org.apache.spark.sql.errors.DataTypeErrors import org.apache.spark.sql.errors.DataTypeErrors.{toSQLType, toSQLValue} import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.unsafe.types.CalendarInterval import org.apache.spark.util.ArrayImplicits._ diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala index bad673672188c..2dfdce7c10e6a 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/encoders/RowEncoder.scala @@ -22,10 +22,10 @@ import scala.reflect.classTag import org.apache.spark.sql.{AnalysisException, Row} import org.apache.spark.sql.catalyst.encoders.AgnosticEncoders.{BinaryEncoder, BoxedBooleanEncoder, BoxedByteEncoder, BoxedDoubleEncoder, BoxedFloatEncoder, BoxedIntEncoder, BoxedLongEncoder, BoxedShortEncoder, CalendarIntervalEncoder, CharEncoder, DateEncoder, DayTimeIntervalEncoder, EncoderField, GeographyEncoder, GeometryEncoder, InstantEncoder, IterableEncoder, JavaDecimalEncoder, LocalDateEncoder, LocalDateTimeEncoder, LocalTimeEncoder, MapEncoder, NullEncoder, RowEncoder => AgnosticRowEncoder, StringEncoder, TimestampEncoder, UDTEncoder, VarcharEncoder, VariantEncoder, YearMonthIntervalEncoder} +import org.apache.spark.sql.catalyst.types.ops.TypeApiOps import org.apache.spark.sql.errors.DataTypeErrorsBase import org.apache.spark.sql.internal.SqlApiConf import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.util.ArrayImplicits._ /** diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeApiOps.scala similarity index 98% rename from sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeApiOps.scala index dd8f0398aba9c..fbe942c65eb17 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TimeTypeApiOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeApiOps.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.types.ops +package org.apache.spark.sql.catalyst.types.ops import java.time.LocalTime diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TypeApiOps.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TypeApiOps.scala similarity index 99% rename from sql/api/src/main/scala/org/apache/spark/sql/types/ops/TypeApiOps.scala rename to sql/api/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TypeApiOps.scala index fff5b8b6a022e..ac3347efc1abc 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/ops/TypeApiOps.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TypeApiOps.scala @@ -15,7 +15,7 @@ * limitations under the License. */ -package org.apache.spark.sql.types.ops +package org.apache.spark.sql.catalyst.types.ops import org.apache.arrow.vector.types.pojo.ArrowType diff --git a/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/FieldMetadataUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/FieldMetadataUtils.scala new file mode 100644 index 0000000000000..c27aef033d5b1 --- /dev/null +++ b/sql/api/src/main/scala/org/apache/spark/sql/catalyst/util/FieldMetadataUtils.scala @@ -0,0 +1,23 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.catalyst.util + +object FieldMetadataUtils { + // Metadata key for the field ID used to track column identity across schema evolution + val FIELD_ID_METADATA_KEY = "__FIELD_ID" +} diff --git a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala index eb3d30051880a..5bb12cf80e45c 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/types/StructField.scala @@ -26,6 +26,7 @@ import org.json4s.JsonDSL._ import org.apache.spark.SparkException import org.apache.spark.annotation.Stable import org.apache.spark.sql.catalyst.util.{CollationFactory, QuotingUtils, StringConcat} +import org.apache.spark.sql.catalyst.util.FieldMetadataUtils.FIELD_ID_METADATA_KEY import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.{CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY} import org.apache.spark.util.SparkSchemaUtils @@ -243,6 +244,43 @@ case class StructField( metadata.contains(EXISTS_DEFAULT_COLUMN_METADATA_KEY) } + /** + * Updates the field with an ID for column identity tracking. + */ + def withId(id: String): StructField = { + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .putString(FIELD_ID_METADATA_KEY, id) + .build() + copy(metadata = newMetadata) + } + + /** + * Returns the ID of this field, if set. + */ + def id: Option[String] = { + if (metadata.contains(FIELD_ID_METADATA_KEY)) { + Some(metadata.getString(FIELD_ID_METADATA_KEY)) + } else { + None + } + } + + /** + * Returns a copy of this field with the field ID removed, or this field if no ID is set. + */ + def clearId(): StructField = { + if (metadata.contains(FIELD_ID_METADATA_KEY)) { + val newMetadata = new MetadataBuilder() + .withMetadata(metadata) + .remove(FIELD_ID_METADATA_KEY) + .build() + copy(metadata = newMetadata) + } else { + this + } + } + private def getDDLDefault = getDefault() .orElse(getCurrentDefaultValue()) .map(" DEFAULT " + _) diff --git a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala index 1c1024fc0152e..f695c079ade40 100644 --- a/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala +++ b/sql/api/src/main/scala/org/apache/spark/sql/util/ArrowUtils.scala @@ -27,9 +27,9 @@ import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, Interval import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} import org.apache.spark.SparkException +import org.apache.spark.sql.catalyst.types.ops.TypeApiOps import org.apache.spark.sql.errors.ExecutionErrors import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.util.ArrayImplicits._ private[sql] object ArrowUtils { diff --git a/sql/catalyst/pom.xml b/sql/catalyst/pom.xml index 64fcb1fff6847..0126f4d0f2086 100644 --- a/sql/catalyst/pom.xml +++ b/sql/catalyst/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java index 537c2edd11285..f150a1fbe9ffb 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Column.java @@ -18,12 +18,17 @@ package org.apache.spark.sql.connector.catalog; import java.util.Map; +import java.util.Objects; +import java.util.stream.Stream; import javax.annotation.Nullable; +import org.apache.spark.SparkIllegalArgumentException; import org.apache.spark.annotation.Evolving; import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.internal.connector.ColumnImpl; import org.apache.spark.sql.types.DataType; +import org.apache.spark.sql.types.StructField; +import org.apache.spark.sql.util.SchemaUtils; /** * An interface representing a column of a {@link Table}. It defines basic properties of a column, @@ -40,11 +45,11 @@ public interface Column { static Column create(String name, DataType dataType) { - return create(name, dataType, true); + return builderFor(name, dataType).build(); } static Column create(String name, DataType dataType, boolean nullable) { - return create(name, dataType, nullable, null, null); + return builderFor(name, dataType).nullable(nullable).build(); } static Column create( @@ -53,16 +58,11 @@ static Column create( boolean nullable, String comment, String metadataInJSON) { - return new ColumnImpl( - name, - dataType, - nullable, - comment, - /* defaultValue = */ null, - /* generationExpression = */ null, - /* identityColumnSpec = */ null, - metadataInJSON, - /* id = */ null); + return builderFor(name, dataType) + .nullable(nullable) + .comment(comment) + .metadata(metadataInJSON) + .build(); } static Column create( @@ -72,16 +72,12 @@ static Column create( String comment, ColumnDefaultValue defaultValue, String metadataInJSON) { - return new ColumnImpl( - name, - dataType, - nullable, - comment, - defaultValue, - /* generationExpression = */ null, - /* identityColumnSpec = */ null, - metadataInJSON, - /* id = */ null); + return builderFor(name, dataType) + .nullable(nullable) + .comment(comment) + .defaultValue(defaultValue) + .metadata(metadataInJSON) + .build(); } static Column create( @@ -91,35 +87,57 @@ static Column create( String comment, String generationExpression, String metadataInJSON) { - return new ColumnImpl( - name, - dataType, - nullable, - comment, - /* defaultValue = */ null, - generationExpression, - /* identityColumnSpec = */ null, - metadataInJSON, - /* id = */ null); + return builderFor(name, dataType) + .nullable(nullable) + .comment(comment) + .generationExpression(generationExpression) + .metadata(metadataInJSON) + .build(); } static Column create( - String name, - DataType dataType, - boolean nullable, - String comment, - IdentityColumnSpec identityColumnSpec, - String metadataInJSON) { - return new ColumnImpl( - name, - dataType, - nullable, - comment, - /* defaultValue = */ null, - /* generationExpression = */ null, - identityColumnSpec, - metadataInJSON, - /* id = */ null); + String name, + DataType dataType, + boolean nullable, + String comment, + IdentityColumnSpec identityColumnSpec, + String metadataInJSON) { + return builderFor(name, dataType) + .nullable(nullable) + .comment(comment) + .identityColumnSpec(identityColumnSpec) + .metadata(metadataInJSON) + .build(); + } + + /** + * Creates a builder for a new column with the given name and data type. + * + * @param name the name of the column + * @param dataType the data type of the column + * @return a new builder + * @since 4.2.0 + */ + static Builder builderFor(String name, DataType dataType) { + return new Builder(name, dataType); + } + + /** + * Creates a builder with pre-populated info from an existing column. + * + * @param column the source column + * @return a new builder seeded with the column's current state + * @since 4.2.0 + */ + static Builder builderFrom(Column column) { + return new Builder(column.name(), column.dataType()) + .nullable(column.nullable()) + .comment(column.comment()) + .defaultValue(column.defaultValue()) + .generationExpression(column.generationExpression()) + .identityColumnSpec(column.identityColumnSpec()) + .metadata(column.metadataInJSON()) + .id(column.id()); } /** @@ -193,12 +211,99 @@ static Column create( * others. *

* This API covers top-level columns only. Nested struct fields, array elements, and map - * keys/values do not have separate IDs. Connectors that track nested field IDs can encode - * them into the returned top-level Column ID string to detect nested changes, since Spark - * only compares string equality. + * keys/values carry their own IDs in struct field metadata. Spark validates both top-level and + * nested struct field IDs as part of schema compatibility checks (array elements and map/key + * values' validation is not supported yet). See {@link StructField#id()}. */ @Nullable default String id() { return null; } + + /** + * A builder for {@link Column}. + * + * @since 4.2.0 + */ + class Builder { + private final String name; + private DataType dataType; + private boolean nullable = true; + private String comment = null; + private ColumnDefaultValue defaultValue = null; + private String genExpr = null; + private IdentityColumnSpec identityColumnSpec = null; + private String metadataInJSON = null; + private String id = null; + + private Builder(String name, DataType dataType) { + this.name = Objects.requireNonNull(name, "name must not be null"); + this.dataType = Objects.requireNonNull(dataType, "dataType must not be null"); + } + + public Builder nullable(boolean nullable) { + this.nullable = nullable; + return this; + } + + public Builder comment(String comment) { + this.comment = comment; + return this; + } + + public Builder defaultValue(ColumnDefaultValue defaultValue) { + this.defaultValue = defaultValue; + return this; + } + + public Builder generationExpression(String sql) { + this.genExpr = sql; + return this; + } + + public Builder identityColumnSpec(IdentityColumnSpec identityColumnSpec) { + this.identityColumnSpec = identityColumnSpec; + return this; + } + + public Builder metadata(String metadataInJSON) { + this.metadataInJSON = metadataInJSON; + return this; + } + + public Builder id(String id) { + this.id = id; + return this; + } + + public Builder clearIds() { + this.id = null; + this.dataType = SchemaUtils.clearFieldIds(dataType); + return this; + } + + public Column build() { + validateState(); + return new ColumnImpl( + name, dataType, nullable, comment, defaultValue, + genExpr, identityColumnSpec, metadataInJSON, id); + } + + private void validateState() { + if (hasConflictingDefinitions()) { + throw new SparkIllegalArgumentException( + "INTERNAL_ERROR", + Map.of("message", + "Column '" + name + "' cannot have more than one definition of: " + + "default value, generation expression, identity column spec")); + } + } + + private boolean hasConflictingDefinitions() { + long definitionCount = Stream.of(defaultValue, genExpr, identityColumnSpec) + .filter(Objects::nonNull) + .count(); + return definitionCount > 1; + } + } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataTable.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingTable.java similarity index 67% rename from sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataTable.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingTable.java index 1d1acfde80f9d..2e8d63eb16d94 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/MetadataTable.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/DelegatingTable.java @@ -27,40 +27,35 @@ import org.apache.spark.sql.connector.expressions.Transform; /** - * A concrete {@code Table} implementation that contains only table metadata, deferring - * read/write to Spark. It represents a general Spark data source table or a Spark view; - * Spark resolves the table provider into a data source (for tables) or expands the view text - * (for views) at read time. + * A concrete {@link Table} that adapts a {@link TableInfo} -- it contains only table metadata and + * defers read/write to Spark, which resolves the table provider into a data source at read time. *

- * Catalogs build the metadata via {@link TableInfo.Builder} (for data-source tables) or - * {@link ViewInfo.Builder} (for views). A {@code MetadataTable} wrapping a - * {@link TableInfo} can be returned from {@link TableCatalog#loadTable(Identifier)} for a - * data-source table; a {@code MetadataTable} wrapping a {@link ViewInfo} can be returned - * from {@link TableViewCatalog#loadTableOrView(Identifier)} as the single-RPC perf opt-in - * for a view. - * Downstream consumers distinguish the two by checking - * {@code getTableInfo() instanceof ViewInfo}. + * Catalogs build the metadata via {@link TableInfo.Builder} and return a {@code DelegatingTable} + * from {@link TableCatalog#loadTable(Identifier)} (or {@link RelationCatalog#loadRelation} for a + * data-source table) when they want Spark to handle the underlying source. A catalog that has its + * own {@link Table} object returns that instead. Views are never represented as a + * {@code DelegatingTable}: a view is a {@link View}, which is itself a {@link Relation}. * * @since 4.2.0 */ @Evolving -public class MetadataTable implements Table { +public class DelegatingTable implements Table { private final TableInfo info; private final String name; /** - * @param info metadata for the table or view. Pass a {@link ViewInfo} for a view. + * @param info the table metadata to delegate to. * @param name human-readable name for this table, returned by {@link #name()} and surfaced * in places that read it (e.g. {@code BatchScan} plan-tree labels and * partition-management error messages). {@code DESCRIBE TABLE EXTENDED} does * not read this field; it emits the resolved identifier as structured * {@code Catalog} / {@code Namespace} / {@code Table} rows. Catalogs returning - * a {@code MetadataTable} from {@link TableCatalog#loadTable} or - * {@link TableViewCatalog#loadTableOrView} should typically pass + * a {@code DelegatingTable} from {@link TableCatalog#loadTable} or + * {@link RelationCatalog#loadRelation} should typically pass * {@code ident.toString()}, matching the quoted multi-part form used elsewhere * for v2 identifiers. */ - public MetadataTable(TableInfo info, String name) { + public DelegatingTable(TableInfo info, String name) { this.info = Objects.requireNonNull(info, "info should not be null"); this.name = Objects.requireNonNull(name, "name should not be null"); } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Relation.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Relation.java new file mode 100644 index 0000000000000..747a22c489c89 --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Relation.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog; + +import java.util.Map; + +import org.apache.spark.annotation.Evolving; + +/** + * A relation in a catalog: either a {@link Table} or a {@link View}. This is the common type + * returned by {@link RelationCatalog#loadRelation} so a catalog that exposes both kinds can + * answer a single read in one round trip; callers discriminate with {@code instanceof Table} / + * {@code instanceof View}. + *

+ * The two kinds are deliberately asymmetric, mirroring how Spark treats them: a {@link Table} is + * an object Spark reads from and writes to, while a {@link View} carries only metadata (Spark + * expands its query text at read time and never builds a view object). Modeling both as siblings + * of {@code Relation} -- rather than smuggling a view through the {@code Table} surface -- keeps + * table-only concepts (partitioning, constraints, scans, writes) off the view side. + *

+ * In practice the only two kinds are {@link Table} and {@link View}. {@code Relation} is left + * un-sealed because {@link Table} is itself an open interface (so a closed hierarchy would add + * little) and because a sealed Java interface trips Scala's pattern-match analysis. + * + * @since 4.2.0 + */ +@Evolving +public interface Relation { + + /** + * The columns of this relation. + */ + Column[] columns(); + + /** + * The string map of properties of this relation. + */ + Map properties(); +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/RelationBuilder.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/RelationBuilder.java new file mode 100644 index 0000000000000..e000ec7c3f5ac --- /dev/null +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/RelationBuilder.java @@ -0,0 +1,85 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.spark.sql.connector.catalog; + +import java.util.HashMap; +import java.util.Map; + +import org.apache.spark.sql.types.StructType; + +/** + * Shared builder state for {@link TableInfo.Builder} and {@link View.Builder} -- the fields and + * convenience setters common to building a table's metadata and a view: the columns and the + * string properties bag, plus the reserved-property setters that apply to both. Setters return + * {@code B} so subclass builders chain through their own type without a covariant override on each + * inherited setter. + *

+ * This is an internal implementation detail: tables and views do not share a value type (a + * {@link TableInfo} describes a table Spark will realize as a {@link Table}, while a {@link View} + * is itself a {@link Relation}), only this builder logic. + */ +abstract class RelationBuilder> { + protected Column[] columns = new Column[0]; + protected Map properties = new HashMap<>(); + + protected abstract B self(); + + public B withColumns(Column[] columns) { + this.columns = columns; + return self(); + } + + public B withSchema(StructType schema) { + this.columns = CatalogV2Util.structTypeToV2Columns(schema, true /* keep IDs */); + return self(); + } + + /** + * Replaces the current properties map with a defensive copy of the given map. Any reserved + * keys set earlier via convenience setters (e.g. {@link #withComment}) are discarded -- + * call those setters after this method, not before. + */ + public B withProperties(Map properties) { + this.properties = new HashMap<>(properties); + return self(); + } + + // Convenience setters below write reserved keys into the current `properties` map. Pair + // each with a preceding `withProperties(...)` call if you want to start from a user map; + // calling `withProperties` after a convenience setter discards the value the convenience + // setter wrote. + + public B withComment(String comment) { + properties.put(TableCatalog.PROP_COMMENT, comment); + return self(); + } + + public B withCollation(String collation) { + properties.put(TableCatalog.PROP_COLLATION, collation); + return self(); + } + + public B withOwner(String owner) { + properties.put(TableCatalog.PROP_OWNER, owner); + return self(); + } + + public B withTableType(String tableType) { + properties.put(TableCatalog.PROP_TABLE_TYPE, tableType); + return self(); + } +} diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableViewCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/RelationCatalog.java similarity index 73% rename from sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableViewCatalog.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/RelationCatalog.java index 45ec41d680d8b..e088e5d3a82c4 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableViewCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/RelationCatalog.java @@ -27,15 +27,15 @@ * Catalog API for connectors that expose both tables and views in a single shared identifier * namespace. *

- * Connectors that expose both tables and views must implement {@code TableViewCatalog}; + * Connectors that expose both tables and views must implement {@code RelationCatalog}; * implementing {@link TableCatalog} and {@link ViewCatalog} directly without - * {@code TableViewCatalog} is rejected at catalog initialization. Connectors that expose only + * {@code RelationCatalog} is rejected at catalog initialization. Connectors that expose only * tables implement just {@link TableCatalog}; connectors that expose only views implement just * {@link ViewCatalog}; this interface is not relevant to them. * *

Two principles

* - * A {@code TableViewCatalog} follows two rules that, taken together, define every cross-cutting + * A {@code RelationCatalog} follows two rules that, taken together, define every cross-cutting * subtlety: *
    *
  1. Orthogonal interfaces. Every {@link TableCatalog} method behaves as if views did @@ -102,14 +102,14 @@ *

    Single-RPC perf entry points

    * * The orthogonal {@link TableCatalog} and {@link ViewCatalog} answer two cross-cutting - * questions in two round trips each. {@code TableViewCatalog} adds dedicated methods so a + * questions in two round trips each. {@code RelationCatalog} adds dedicated methods so a * catalog can answer both in one round trip: *
      - *
    • {@link #loadTableOrView(Identifier)} -- the resolver's per-identifier read path. Returns - * a regular {@link Table} for a table, or a {@link MetadataTable} wrapping a - * {@link ViewInfo} for a view. Saves the {@code loadTable} -> {@code loadView} fallback - * on a cold cache.
    • - *
    • {@link #listTableAndViewSummaries(String[])} -- a unified listing of tables and views + *
    • {@link #loadRelation(Identifier)} -- the resolver's per-identifier read path. Returns a + * {@link Table} for a table or a {@link View} for a view; callers discriminate via + * {@code instanceof}. Saves the {@code loadTable} -> {@code loadView} fallback on a cold + * cache.
    • + *
    • {@link #listRelationSummaries(String[])} -- a unified listing of tables and views * with the kind preserved on each {@link TableSummary}. Default impl performs both * {@link TableCatalog#listTableSummaries} and {@link ViewCatalog#listViews}; override to * fetch in one round trip.
    • @@ -118,22 +118,21 @@ * @since 4.2.0 */ @Evolving -public interface TableViewCatalog extends TableCatalog, ViewCatalog { +public interface RelationCatalog extends TableCatalog, ViewCatalog { /** - * Load metadata for an identifier that may resolve to either a table or a view. + * Load the relation for an identifier that may resolve to either a table or a view. *

      - * For a table, returns the table's {@link Table}. For a view, returns a - * {@link MetadataTable} wrapping a {@link ViewInfo}; callers discriminate via - * {@code getTableInfo() instanceof ViewInfo}. This lets the resolver answer in a single RPC - * instead of falling back from {@link TableCatalog#loadTable} to {@link ViewCatalog#loadView}. + * Returns a {@link Table} for a table or a {@link View} for a view; callers discriminate via + * {@code instanceof Table} / {@code instanceof View}. This lets the resolver answer in a single + * RPC instead of falling back from {@link TableCatalog#loadTable} to + * {@link ViewCatalog#loadView}. * * @param ident the identifier - * @return a {@link Table} for tables, or a {@link MetadataTable} wrapping a - * {@link ViewInfo} for views + * @return a {@link Table} for tables, or a {@link View} for views * @throws NoSuchTableException if neither a table nor a view exists at {@code ident} */ - Table loadTableOrView(Identifier ident) throws NoSuchTableException; + Relation loadRelation(Identifier ident) throws NoSuchTableException; /** * List the tables and views in a namespace, returned as {@link TableSummary} entries with @@ -149,7 +148,7 @@ public interface TableViewCatalog extends TableCatalog, ViewCatalog { * @throws NoSuchTableException if a table listed by the underlying enumeration disappears * before its summary can be assembled (default impl only) */ - default TableSummary[] listTableAndViewSummaries(String[] namespace) + default TableSummary[] listRelationSummaries(String[] namespace) throws NoSuchNamespaceException, NoSuchTableException { TableSummary[] tableSummaries = listTableSummaries(namespace); Identifier[] viewIdentifiers = listViews(namespace); @@ -167,52 +166,47 @@ default TableSummary[] listTableAndViewSummaries(String[] namespace) /** * {@inheritDoc} *

      - * The default implementation derives from {@link #loadTableOrView}: a {@link MetadataTable} - * wrapping a {@link ViewInfo} is rejected as not-a-table; anything else is returned. Override - * only if a tables-only path is materially cheaper than the unified one. + * The default implementation derives from {@link #loadRelation}: a {@link View} is rejected as + * not-a-table; a {@link Table} is returned. Override only if a tables-only path is materially + * cheaper than the unified one. */ @Override default Table loadTable(Identifier ident) throws NoSuchTableException { - Table t = loadTableOrView(ident); - if (t instanceof MetadataTable mot && mot.getTableInfo() instanceof ViewInfo) { - throw new NoSuchTableException(ident); + if (loadRelation(ident) instanceof Table t) { + return t; } - return t; + throw new NoSuchTableException(ident); } /** * {@inheritDoc} *

      - * The default implementation derives from {@link #loadTableOrView}: a {@link MetadataTable} - * wrapping a {@link ViewInfo} is unwrapped and returned; anything else (table or absent) is - * surfaced as {@link NoSuchViewException}. Override only if a views-only path is materially - * cheaper than the unified one. + * The default implementation derives from {@link #loadRelation}: a {@link View} is returned; + * anything else (table or absent) is surfaced as {@link NoSuchViewException}. Override only if a + * views-only path is materially cheaper than the unified one. */ @Override - default ViewInfo loadView(Identifier ident) throws NoSuchViewException { - Table t; + default View loadView(Identifier ident) throws NoSuchViewException { try { - t = loadTableOrView(ident); + if (loadRelation(ident) instanceof View v) { + return v; + } } catch (NoSuchTableException e) { throw new NoSuchViewException(ident); } - if (t instanceof MetadataTable mot && mot.getTableInfo() instanceof ViewInfo vi) { - return vi; - } throw new NoSuchViewException(ident); } /** * {@inheritDoc} *

      - * The default implementation derives from {@link #loadTableOrView}: returns {@code true} only if - * the entry exists and is not a view. Override only if a cheaper existence-check path exists. + * The default implementation derives from {@link #loadRelation}: returns {@code true} only if + * the entry exists and is a table. Override only if a cheaper existence-check path exists. */ @Override default boolean tableExists(Identifier ident) { try { - Table t = loadTableOrView(ident); - return !(t instanceof MetadataTable mot && mot.getTableInfo() instanceof ViewInfo); + return loadRelation(ident) instanceof Table; } catch (NoSuchTableException e) { return false; } @@ -221,14 +215,13 @@ default boolean tableExists(Identifier ident) { /** * {@inheritDoc} *

      - * The default implementation derives from {@link #loadTableOrView}: returns {@code true} only if + * The default implementation derives from {@link #loadRelation}: returns {@code true} only if * the entry exists and is a view. Override only if a cheaper existence-check path exists. */ @Override default boolean viewExists(Identifier ident) { try { - Table t = loadTableOrView(ident); - return t instanceof MetadataTable mot && mot.getTableInfo() instanceof ViewInfo; + return loadRelation(ident) instanceof View; } catch (NoSuchTableException e) { return false; } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java index a298520760bc0..ec27bcf6c82e2 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/Table.java @@ -38,11 +38,13 @@ * The default implementation of {@link #partitioning()} returns an empty array of partitions, and * the default implementation of {@link #properties()} returns an empty map. These should be * overridden by implementations that support partitioning and table properties. + *

      + * A {@code Table} is one kind of {@link Relation}; the other is {@link View}. * * @since 3.0.0 */ @Evolving -public interface Table { +public interface Table extends Relation { /** * A name to identify this table. Implementations should provide a meaningful name, like the @@ -76,7 +78,7 @@ default StructType schema() { * empty array can be returned here. */ default Column[] columns() { - return CatalogV2Util.structTypeToV2Columns(schema()); + return CatalogV2Util.structTypeToV2Columns(schema(), true /* keep IDs */); } /** diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java index f64c34ee0e071..23e9499932c16 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableCatalog.java @@ -35,7 +35,7 @@ * Catalog API for connectors that expose tables. *

      * Connectors that expose only tables implement this interface. Connectors that expose - * both tables and views must implement {@link TableViewCatalog} (which extends both this + * both tables and views must implement {@link RelationCatalog} (which extends both this * interface and {@link ViewCatalog} and adds the cross-cutting contract for the combined * case); the methods on this interface remain table-only -- they do not interact with views. *

      diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableInfo.java index 89709c9f1c2f0..1a3261e3085af 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/TableInfo.java @@ -16,7 +16,6 @@ */ package org.apache.spark.sql.connector.catalog; -import java.util.HashMap; import java.util.Map; import java.util.Objects; @@ -24,6 +23,13 @@ import org.apache.spark.sql.connector.expressions.Transform; import org.apache.spark.sql.types.StructType; +/** + * Metadata describing a data-source table: its columns, properties, partitioning and constraints. + * Spark realizes a {@code TableInfo} into a {@link Table} via {@link DelegatingTable}; a catalog + * that has its own {@link Table} object returns that instead. Views are described by the sibling + * {@link View}, which -- unlike a table -- is itself a {@link Relation} because Spark never builds + * a view object. + */ public class TableInfo { private final Column[] columns; @@ -34,7 +40,7 @@ public class TableInfo { /** * Constructor for TableInfo used by the builder. */ - protected TableInfo(BaseBuilder builder) { + protected TableInfo(Builder builder) { this.columns = builder.columns; this.properties = builder.properties; this.partitions = builder.partitions; @@ -59,96 +65,37 @@ public Transform[] partitions() { public Constraint[] constraints() { return constraints; } - public static class Builder extends BaseBuilder { - @Override - protected Builder self() { return this; } - - @Override - public TableInfo build() { - Objects.requireNonNull(columns, "columns should not be null"); - return new TableInfo(this); - } - } - - /** - * Shared builder state for {@link TableInfo} and its subclasses. Setters return {@code B} so - * subclass builders (e.g. {@link ViewInfo.Builder}) chain through their own type without - * a covariant override on each inherited setter. - */ - protected abstract static class BaseBuilder> { - protected Column[] columns = new Column[0]; - protected Map properties = new HashMap<>(); + public static class Builder extends RelationBuilder { protected Transform[] partitions = new Transform[0]; protected Constraint[] constraints = new Constraint[0]; - protected abstract B self(); - - public B withColumns(Column[] columns) { - this.columns = columns; - return self(); - } - - public B withSchema(StructType schema) { - this.columns = CatalogV2Util.structTypeToV2Columns(schema); - return self(); - } - - /** - * Replaces the current properties map with a defensive copy of the given map. Any reserved - * keys set earlier via convenience setters (e.g. {@link #withProvider}) are discarded -- - * call those setters after this method, not before. - */ - public B withProperties(Map properties) { - this.properties = new HashMap<>(properties); - return self(); - } + @Override + protected Builder self() { return this; } - public B withPartitions(Transform[] partitions) { + public Builder withPartitions(Transform[] partitions) { this.partitions = partitions; - return self(); + return this; } - public B withConstraints(Constraint[] constraints) { + public Builder withConstraints(Constraint[] constraints) { this.constraints = constraints; - return self(); + return this; } - // Convenience setters below write reserved keys into the current `properties` map. Pair - // each with a preceding `withProperties(...)` call if you want to start from a user map; - // calling `withProperties` after a convenience setter discards the value the convenience - // setter wrote. - /** Writes {@link TableCatalog#PROP_PROVIDER} into the current properties map. */ - public B withProvider(String provider) { + public Builder withProvider(String provider) { properties.put(TableCatalog.PROP_PROVIDER, provider); - return self(); + return this; } - public B withLocation(String location) { + public Builder withLocation(String location) { properties.put(TableCatalog.PROP_LOCATION, location); - return self(); - } - - public B withComment(String comment) { - properties.put(TableCatalog.PROP_COMMENT, comment); - return self(); + return this; } - public B withCollation(String collation) { - properties.put(TableCatalog.PROP_COLLATION, collation); - return self(); - } - - public B withOwner(String owner) { - properties.put(TableCatalog.PROP_OWNER, owner); - return self(); - } - - public B withTableType(String tableType) { - properties.put(TableCatalog.PROP_TABLE_TYPE, tableType); - return self(); + public TableInfo build() { + Objects.requireNonNull(columns, "columns should not be null"); + return new TableInfo(this); } - - public abstract TableInfo build(); } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ViewInfo.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/View.java similarity index 79% rename from sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ViewInfo.java rename to sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/View.java index 0f46e915a9be2..d4adef9a98c85 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ViewInfo.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/View.java @@ -22,26 +22,28 @@ import java.util.Objects; import org.apache.spark.annotation.Evolving; +import org.apache.spark.sql.types.StructType; /** - * View metadata DTO -- the typed payload returned by {@link ViewCatalog#loadView} and accepted - * by {@link ViewCatalog#createView} / {@link ViewCatalog#replaceView}. Carries the + * A view in a catalog -- the typed payload returned by {@link ViewCatalog#loadView} and accepted + * by {@link ViewCatalog#createView} / {@link ViewCatalog#replaceView}. A {@code View} carries the * view-specific fields that cannot be represented as string table properties: the query text, * captured creation-time resolution context, captured SQL configs, schema-binding mode, and - * query output column names. Schema and user TBLPROPERTIES are inherited from {@link TableInfo} - * via the typed builder. + * query output column names. Columns and user TBLPROPERTIES are set via the typed builder. *

      - * {@code ViewInfo} extends {@link TableInfo} so that a {@link TableViewCatalog} can opt into the - * single-RPC perf path by returning a {@link MetadataTable} wrapping a {@code ViewInfo} - * from {@link TableViewCatalog#loadTableOrView} for a view identifier. Pure {@link ViewCatalog} - * implementations never see {@code TableInfo}; the typed setters on {@link Builder} cover - * everything they need to construct a {@code ViewInfo}. + * Unlike a {@link Table}, a {@code View} is itself a {@link Relation} rather than something Spark + * realizes into a {@code Table}: Spark expands the view's query text at read time and never builds + * a view object. A {@link RelationCatalog} returns a {@code View} directly from + * {@link RelationCatalog#loadRelation} for a view identifier, so it never has to smuggle a view + * through the {@code Table} surface. * * @since 4.2.0 */ @Evolving -public class ViewInfo extends TableInfo { +public class View implements Relation { + private final Column[] columns; + private final Map properties; private final String queryText; private final String currentCatalog; private final String[] currentNamespace; @@ -50,8 +52,9 @@ public class ViewInfo extends TableInfo { private final String[] queryColumnNames; private final DependencyList viewDependencies; - protected ViewInfo(Builder builder) { - super(builder); + protected View(Builder builder) { + this.columns = builder.columns; + this.properties = builder.properties; this.queryText = Objects.requireNonNull(builder.queryText, "queryText should not be null"); this.currentCatalog = builder.currentCatalog; this.currentNamespace = builder.currentNamespace; @@ -59,10 +62,24 @@ protected ViewInfo(Builder builder) { this.schemaMode = builder.schemaMode; this.queryColumnNames = builder.queryColumnNames; this.viewDependencies = builder.viewDependencies; - // Default PROP_TABLE_TYPE = VIEW so `properties()` reflects the typed ViewInfo - // classification. Callers can refine to a more specific view kind (for example, - // METRIC_VIEW) by calling BaseBuilder.withTableType(...) on the builder before build(). - properties().putIfAbsent(TableCatalog.PROP_TABLE_TYPE, TableSummary.VIEW_TABLE_TYPE); + // Default PROP_TABLE_TYPE = VIEW so `properties()` reflects the typed View classification. + // Callers can refine to a more specific view kind (for example, METRIC_VIEW) by calling + // RelationBuilder.withTableType(...) on the builder before build(). + properties.putIfAbsent(TableCatalog.PROP_TABLE_TYPE, TableSummary.VIEW_TABLE_TYPE); + } + + @Override + public Column[] columns() { + return columns; + } + + public StructType schema() { + return CatalogV2Util.v2ColumnsToStructType(columns); + } + + @Override + public Map properties() { + return properties; } /** The SQL text of the view. */ @@ -111,7 +128,7 @@ protected ViewInfo(Builder builder) { */ public DependencyList viewDependencies() { return viewDependencies; } - public static class Builder extends BaseBuilder { + public static class Builder extends RelationBuilder { private String queryText; private String currentCatalog; private String[] currentNamespace = new String[0]; @@ -163,10 +180,9 @@ public Builder withViewDependencies(DependencyList viewDependencies) { return this; } - @Override - public ViewInfo build() { + public View build() { Objects.requireNonNull(columns, "columns should not be null"); - return new ViewInfo(this); + return new View(this); } } } diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ViewCatalog.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ViewCatalog.java index 0e74b22079bfa..12782ad7f1314 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ViewCatalog.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/ViewCatalog.java @@ -25,7 +25,7 @@ * Catalog API for connectors that expose views. *

      * Connectors that expose only views implement this interface. Connectors that expose - * both tables and views must implement {@link TableViewCatalog} (which extends both this + * both tables and views must implement {@link RelationCatalog} (which extends both this * interface and {@link TableCatalog} and adds the cross-cutting contract for the combined * case); the methods on this interface remain view-only -- they do not interact with tables. *

      @@ -53,7 +53,7 @@ public interface ViewCatalog extends CatalogPlugin { * @return the view metadata * @throws NoSuchViewException if the view does not exist */ - ViewInfo loadView(Identifier ident) throws NoSuchViewException; + View loadView(Identifier ident) throws NoSuchViewException; /** * Test whether a view exists. @@ -93,7 +93,7 @@ default void invalidateView(Identifier ident) { * @throws ViewAlreadyExistsException if a view already exists at {@code ident} * @throws NoSuchNamespaceException if the identifier's namespace does not exist (optional) */ - ViewInfo createView(Identifier ident, ViewInfo info) + View createView(Identifier ident, View info) throws ViewAlreadyExistsException, NoSuchNamespaceException; /** @@ -108,7 +108,7 @@ ViewInfo createView(Identifier ident, ViewInfo info) * @return the metadata of the replaced view; may equal {@code info} * @throws NoSuchViewException if no view exists at {@code ident} */ - ViewInfo replaceView(Identifier ident, ViewInfo info) throws NoSuchViewException; + View replaceView(Identifier ident, View info) throws NoSuchViewException; /** * Create a view if one does not exist at {@code ident}, or atomically replace it if one does. @@ -126,10 +126,10 @@ ViewInfo createView(Identifier ident, ViewInfo info) * concurrent {@code CREATE VIEW} won the race in the * default impl's gap between {@link #replaceView} and * the fallback {@link #createView}, or, in a - * {@link TableViewCatalog}, a table sits at {@code ident} + * {@link RelationCatalog}, a table sits at {@code ident} * @throws NoSuchNamespaceException if the identifier's namespace does not exist (optional) */ - default ViewInfo createOrReplaceView(Identifier ident, ViewInfo info) + default View createOrReplaceView(Identifier ident, View info) throws ViewAlreadyExistsException, NoSuchNamespaceException { try { return replaceView(ident, info); @@ -152,12 +152,12 @@ default ViewInfo createOrReplaceView(Identifier ident, ViewInfo info) * If the catalog supports tables and contains a table at the new identifier, this must throw * {@link ViewAlreadyExistsException}. If the source identifier resolves to a table rather than * a view, this must throw {@link NoSuchViewException}. The cross-type contract for catalogs - * that expose both tables and views lives on {@link TableViewCatalog}. + * that expose both tables and views lives on {@link RelationCatalog}. * * @param oldIdent the view identifier of the existing view to rename * @param newIdent the new view identifier * @throws NoSuchViewException if no view exists at {@code oldIdent} - * @throws ViewAlreadyExistsException if a view (or, in a {@link TableViewCatalog}, a table) + * @throws ViewAlreadyExistsException if a view (or, in a {@link RelationCatalog}, a table) * already exists at {@code newIdent} */ void renameView(Identifier oldIdent, Identifier newIdent) diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java index ae005d946694a..5addd4b09842d 100644 --- a/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java +++ b/sql/catalyst/src/main/java/org/apache/spark/sql/connector/catalog/constraints/Check.java @@ -27,12 +27,13 @@ /** * A CHECK constraint. *

      - * A CHECK constraint defines a condition each row in a table must satisfy. Connectors can define - * such constraints either in SQL (Spark SQL dialect) or using a {@link Predicate predicate} if the - * condition can be expressed using a supported expression. A CHECK constraint can reference one or - * more columns. Such constraint is considered violated if its condition evaluates to {@code FALSE}, - * but not {@code NULL}. The search condition must be deterministic and cannot contain subqueries - * and certain functions like aggregates or UDFs. + * A CHECK constraint defines a condition each row in a table must satisfy. The condition is always + * represented as a SQL string (Spark SQL dialect), accessible via {@link #predicateSql()}, and is + * additionally exposed as a {@link Predicate predicate} via {@link #predicate()} whenever it can be + * expressed using supported expressions. A CHECK constraint can reference one or more columns. Such + * constraint is considered violated if its condition evaluates to {@code FALSE}, but not + * {@code NULL}. The search condition must be deterministic and cannot contain subqueries and + * certain functions like aggregates or UDFs. *

      * Spark supports enforced and not enforced CHECK constraints, allowing connectors to control * whether data modifications that violate the constraint must fail. Each constraint is either @@ -63,13 +64,19 @@ private Check( /** * Returns the SQL representation of the search condition (Spark SQL dialect). + *

      + * This is the canonical representation of the condition and is always present (never + * {@code null}). The optional {@link #predicate()} provides a structured form when the condition + * can be expressed using supported {@link Predicate} expressions. */ public String predicateSql() { return predicateSql; } /** - * Returns the search condition. + * Returns the search condition as a {@link Predicate}, or {@code null} if the condition cannot be + * expressed using supported predicate expressions. Use {@link #predicateSql()} for the canonical + * SQL representation, which is always present. */ public Predicate predicate() { return predicate; @@ -77,7 +84,7 @@ public Predicate predicate() { @Override protected String definition() { - return String.format("CHECK (%s)", predicateSql != null ? predicateSql : predicate); + return String.format("CHECK (%s)", predicateSql); } @Override @@ -123,10 +130,10 @@ public Builder predicate(Predicate predicate) { } public Check build() { - if (predicateSql == null && predicate == null) { + if (predicateSql == null) { throw new SparkIllegalArgumentException( "INTERNAL_ERROR", - Map.of("message", "Predicate SQL and expression can't be both null in CHECK")); + Map.of("message", "Predicate SQL can't be null in CHECK")); } return new Check(name(), predicateSql, predicate, enforced(), validationStatus(), rely()); } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala index 651fd06c898a4..1e3a150348efe 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala @@ -51,7 +51,9 @@ import org.apache.spark.sql.catalyst.trees.TreePattern._ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{toPrettySQL, trimTempResolvedColumn, CharVarcharUtils} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ -import org.apache.spark.sql.connector.catalog._ +// `View` is aliased to `V2View` to avoid clashing with the logical-plan `View` imported via +// `org.apache.spark.sql.catalyst.plans.logical._`. +import org.apache.spark.sql.connector.catalog.{View => V2View, _} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.connector.catalog.TableChange.{After, ColumnPosition} import org.apache.spark.sql.connector.catalog.functions.UnboundFunction @@ -1186,9 +1188,9 @@ class Analyzer( * so surfacing a downstream "view not found" would hide the real reason. * * Lookup order against a non-session catalog: - * 1. If the catalog is a [[TableViewCatalog]], [[TableViewCatalog.loadTableOrView]] is called - * once. A returned [[MetadataTable]] wrapping a [[ViewInfo]] is interpreted as a - * view; other results are tables. + * 1. If the catalog is a [[RelationCatalog]], [[RelationCatalog.loadRelation]] is called + * once. A returned [[org.apache.spark.sql.connector.catalog.View]] is interpreted as a + * view; a [[Table]] is a table. * 2. Otherwise, [[TableCatalog.loadTable]] is tried (when implemented), then * [[ViewCatalog.loadView]] as the fallback view-resolution path (when implemented). */ @@ -1205,17 +1207,18 @@ class Analyzer( throw QueryCompilationErrors.missingCatalogViewsAbilityError(catalog) } catalog match { - case mc: TableViewCatalog => - // Single-RPC perf path: loadTableOrView returns a Table for a table or a - // MetadataTable wrapping a ViewInfo for a view. NoSuchTable means - // neither exists. + case mc: RelationCatalog => + // Single-RPC perf path: loadRelation returns a Table for a table or a View + // for a view. NoSuchTable means neither exists. try { - Some(mc.loadTableOrView(ident) match { - case t: MetadataTable if t.getTableInfo.isInstanceOf[ViewInfo] => - ResolvedPersistentView( - catalog, ident, t.getTableInfo.asInstanceOf[ViewInfo]) - case table => + Some(mc.loadRelation(ident) match { + case v: V2View => + ResolvedPersistentView(catalog, ident, v) + case table: Table => ResolvedTable.create(catalog.asTableCatalog, ident, table) + case other => throw SparkException.internalError( + s"Catalog ${catalog.name} returned an unexpected relation type for " + + s"$ident: ${other.getClass.getName}. Expected a Table or a View.") }) } catch { case _: NoSuchTableException => None @@ -1234,7 +1237,7 @@ class Analyzer( val v1Ident = v1Table.catalogTable.identifier val v2Ident = Identifier.of(v1Ident.database.toArray, v1Ident.identifier) ResolvedPersistentView( - catalog, v2Ident, new V1ViewInfo(v1Table.catalogTable)) + catalog, v2Ident, new V1View(v1Table.catalogTable)) case table => ResolvedTable.create(catalog.asTableCatalog, ident, table) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala index 55a7ad10790ea..22bed3fbe7699 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/RelationResolution.scala @@ -34,16 +34,17 @@ import org.apache.spark.sql.connector.catalog.{ CatalogPlugin, CatalogV2Util, ChangelogContext, + DelegatingTable, Identifier, LookupCatalog, - MetadataTable, + Relation, + RelationCatalog, Table, TableCatalog, - TableViewCatalog, V1Table, V2TableWithV1Fallback, - ViewCatalog, - ViewInfo + View, + ViewCatalog } import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.{DataTypeErrorsBase, QueryCompilationErrors} @@ -243,8 +244,8 @@ class RelationResolution( .orElse { val writePrivileges = u.options.get(UnresolvedRelation.REQUIRED_WRITE_PRIVILEGES) val finalOptions = u.clearWritePrivileges.options - // For a `TableViewCatalog` with no time-travel / write privileges, the single-RPC - // `loadTableOrView` answers both "is there a table?" and "is there a view?" in one + // For a `RelationCatalog` with no time-travel / write privileges, the single-RPC + // `loadRelation` answers both "is there a table?" and "is there a view?" in one // call. Time-travel and write privileges apply to tables only, so for those the // lookup falls through to the table-only `loadTable` path below; views are not // reachable via the v2 fallback in those cases. @@ -252,10 +253,10 @@ class RelationResolution( // Skip the table-side lookup entirely for view-only catalogs (no `TableCatalog` // mixin): `CatalogV2Util.loadTable` would call `asTableCatalog` and throw // MISSING_CATALOG_ABILITY.TABLES, masking the legitimate view-resolution path. - val tableOrView: Option[Table] = catalog match { - case mc: TableViewCatalog if finalTimeTravelSpec.isEmpty && writePrivileges == null => + val relation: Option[Relation] = catalog match { + case mc: RelationCatalog if finalTimeTravelSpec.isEmpty && writePrivileges == null => try { - Some(mc.loadTableOrView(ident)) + Some(mc.loadRelation(ident)) } catch { case _: NoSuchTableException => None } @@ -280,7 +281,7 @@ class RelationResolution( catalog match { case vc: ViewCatalog => try { - Some(new MetadataTable(vc.loadView(ident), ident.toString)) + Some(vc.loadView(ident)) } catch { case _: NoSuchViewException => None } @@ -291,12 +292,9 @@ class RelationResolution( } } } - // `table` is `tableOrView` filtered to tables only -- used for cache lookup since + // `table` is `relation` filtered to tables only -- used for cache lookup since // we don't share-cache views. - val table: Option[Table] = tableOrView.filter { - case t: MetadataTable if t.getTableInfo.isInstanceOf[ViewInfo] => false - case _ => true - } + val table: Option[Table] = relation.collect { case t: Table => t } val sharedRelationCacheMatch = for { t <- table @@ -314,7 +312,7 @@ class RelationResolution( val loaded = createRelation( catalog, ident, - tableOrView, + relation, finalOptions, u.isStreaming, finalTimeTravelSpec) @@ -373,7 +371,7 @@ class RelationResolution( private def createRelation( catalog: CatalogPlugin, ident: Identifier, - table: Option[Table], + relation: Option[Relation], options: CaseInsensitiveStringMap, isStreaming: Boolean, timeTravelSpec: Option[TimeTravelSpec]): Option[LogicalPlan] = { @@ -393,7 +391,12 @@ class RelationResolution( } } - table.map { + relation.map { + // A view is interpreted via v1: project it to a `CatalogTable` and run the v1 scan path, + // which expands the view text. + case v: View => + createDataSourceV1Scan(V1Table.toCatalogTable(catalog, ident, v)) + // To utilize this code path to execute V1 commands, e.g. INSERT, // either it must be session catalog, or tracksPartitionsInCatalog // must be false so it does not require use catalog to manage partitions. @@ -405,13 +408,13 @@ class RelationResolution( || !v1Table.catalogTable.tracksPartitionsInCatalog => createDataSourceV1Scan(v1Table.v1Table) - // MetadataTable is a sentinel meaning "interpret via v1", so unlike the V1Table + // DelegatingTable is a sentinel meaning "interpret via v1", so unlike the V1Table // case above we apply no session-catalog / tracksPartitionsInCatalog guard -- any catalog - // returning MetadataTable has opted into v1 read semantics. - case t: MetadataTable => + // returning DelegatingTable has opted into v1 read semantics. + case t: DelegatingTable => createDataSourceV1Scan(V1Table.toCatalogTable(catalog, ident, t)) - case table => + case table: Table => if (isStreaming) { assert(timeTravelSpec.isEmpty, "time travel is not allowed in streaming") val v1Fallback = table match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala index 223e7012af6b6..7bad6d149c602 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/V2TableReference.scala @@ -151,15 +151,6 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { // Make sure the table was not dropped and recreated. ref.info.tableId.foreach(V2TableUtil.validateTableId(ref.name, _, table)) - // Detect columns that were dropped and re-added with the same name but a different - // column ID. This catches replacements that preserve the schema but change identity. - val colIdErrors = V2TableUtil.validateColumnIds( - table = table, - originalCapturedCols = ref.info.columns) - if (colIdErrors.nonEmpty) { - throw QueryCompilationErrors.columnIdMismatchAfterAnalysis(ref.name, colIdErrors) - } - // Do not allow schema evolution to pre-analysed dataframes that are later used in // transactional writes. This is because the entire plans was built based on the original schema // and any schema change would make the plan structurally invalid. This is inline with the @@ -167,12 +158,17 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { val dataErrors = V2TableUtil.validateCapturedColumns( table = table, originCols = ref.info.columns, - mode = PROHIBIT_CHANGES) + mode = PROHIBIT_CHANGES, + checkIds = true) if (dataErrors.nonEmpty) { - throw QueryCompilationErrors.columnsMissingOrAddedAfterAnalysis(ref.name, dataErrors) + throw QueryCompilationErrors.columnsChangedAfterAnalysis(ref.name, dataErrors) } - val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns) + val metaErrors = V2TableUtil.validateCapturedMetadataColumns( + table, + ref.info.metadataColumns, + mode = PROHIBIT_CHANGES, + checkIds = true) if (metaErrors.nonEmpty) { throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(ref.name, metaErrors) } @@ -187,7 +183,8 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { val dataErrors = V2TableUtil.validateCapturedColumns( table, ref.info.columns, - mode = ALLOW_NEW_TOP_LEVEL_FIELDS) + mode = ALLOW_NEW_TOP_LEVEL_FIELDS, + checkIds = false) if (dataErrors.nonEmpty) { throw QueryCompilationErrors.columnsChangedAfterViewWithPlanCreation( ctx.viewName, @@ -195,7 +192,11 @@ private[sql] object V2TableReferenceUtils extends SQLConfHelper { dataErrors) } - val metaErrors = V2TableUtil.validateCapturedMetadataColumns(table, ref.info.metadataColumns) + val metaErrors = V2TableUtil.validateCapturedMetadataColumns( + table, + ref.info.metadataColumns, + mode = PROHIBIT_CHANGES, // metadata columns are projected on demand + checkIds = false) if (metaErrors.nonEmpty) { throw QueryCompilationErrors.metadataColumnsChangedAfterViewWithPlanCreation( ctx.viewName, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala index a8f5f0688890c..43370748e2b3c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/v2ResolutionPlans.scala @@ -229,19 +229,19 @@ case class ResolvedProcedure( /** * A plan containing a resolved persistent view. * - * `info` is the typed v2 [[org.apache.spark.sql.connector.catalog.ViewInfo]] payload for the + * `info` is the typed v2 [[org.apache.spark.sql.connector.catalog.View]] payload for the * view. Session-catalog (v1) views are surfaced through the same channel via - * [[org.apache.spark.sql.connector.catalog.V1ViewInfo]], which extends `ViewInfo` and wraps + * [[org.apache.spark.sql.connector.catalog.V1View]], which extends `View` and wraps * the original [[CatalogTable]] -- mirroring the way * [[org.apache.spark.sql.connector.catalog.V1Table]] exposes a v1 `CatalogTable` through the * v2 [[org.apache.spark.sql.connector.catalog.Table]] surface for `ResolvedTable`. v1-only * paths (e.g. `DescribeTableCommand`, `ShowCreateTableCommand`) recover the original - * `CatalogTable` by pattern-matching `info` against `V1ViewInfo`. + * `CatalogTable` by pattern-matching `info` against `V1View`. */ case class ResolvedPersistentView( catalog: CatalogPlugin, identifier: Identifier, - info: org.apache.spark.sql.connector.catalog.ViewInfo) + info: org.apache.spark.sql.connector.catalog.View) extends LeafNodeWithoutStats { // Surface the view's schema as `output` so `ResolveReferences` can resolve column references // against it (e.g. `DescribeColumn(ResolvedPersistentView, UnresolvedAttribute, ...)`). The @@ -252,7 +252,7 @@ case class ResolvedPersistentView( toAttributes(CharVarcharUtils.replaceCharVarcharWithStringInSchema(info.schema)) // Render `info` in plan-tree output as the qualified view name. The default case-class - // `toString` would format `info` via `Object.toString`, which produces `V1ViewInfo@` + // `toString` would format `info` via `Object.toString`, which produces `V1View@` // for the v1 leg and a similarly opaque hash for the v2 leg -- non-deterministic and useless // in EXPLAIN / golden file output. Replace it with the multi-part `catalog.namespace.name` // form so EXPLAIN, plan-tree dumps, and `SQLQueryTestSuite` golden files remain stable. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala index e05a9bfcc66da..4eda28f2ff22b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/SessionCatalog.scala @@ -1125,7 +1125,7 @@ class SessionCatalog( // so the SubqueryAlias qualifier reflects the real catalog + multi-part namespace. // Fall back to the historical 3-part form for v1 session-catalog tables -- we intentionally // always include `SESSION_CATALOG_NAME` here and ignore - // `LEGACY_NON_IDENTIFIER_OUTPUT_CATALOG_NAME` to preserve pre-v2-MetadataTable behavior. + // `LEGACY_NON_IDENTIFIER_OUTPUT_CATALOG_NAME` to preserve pre-v2-DelegatingTable behavior. val multiParts = metadata.multipartIdentifier.getOrElse { val qualifiedIdent = qualifyIdentifier(metadata.identifier) Seq(CatalogManager.SESSION_CATALOG_NAME, qualifiedIdent.database.get, qualifiedIdent.table) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala index 6dda153985e56..247d4124dae4a 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala @@ -447,7 +447,7 @@ case class CatalogTable( ignoredProperties: Map[String, String] = Map.empty, viewOriginalText: Option[String] = None, // Multi-part identifier [catalog, namespace..., name] for tables synthesized from a v2 - // `MetadataTable` whose namespace has more than one part -- the v1 `identifier: + // `DelegatingTable` whose namespace has more than one part -- the v1 `identifier: // TableIdentifier` (single-string database) cannot carry that losslessly. `None` for // v1-native tables; callers should use `fullIdent` which falls back to `identifier.nameParts`. multipartIdentifier: Option[Seq[String]] = None) diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala index 04052dafb61ae..d0c7231d63604 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/ToStringBase.scala @@ -22,12 +22,12 @@ import java.time.ZoneOffset import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen._ import org.apache.spark.sql.catalyst.expressions.codegen.Block._ +import org.apache.spark.sql.catalyst.types.ops.TypeApiOps import org.apache.spark.sql.catalyst.util.{ArrayData, CharVarcharCodegenUtils, DateFormatter, FractionTimeFormatter, IntervalStringStyles, IntervalUtils, MapData, TimestampFormatter} import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.unsafe.UTF8StringBuilder import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String} import org.apache.spark.util.ArrayImplicits._ diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala index 7a4f04bf04f7a..30e00ac68004b 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/aggregate/Mode.scala @@ -20,17 +20,30 @@ package org.apache.spark.sql.catalyst.expressions.aggregate import org.apache.spark.SparkIllegalArgumentException import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedWithinGroup} -import org.apache.spark.sql.catalyst.expressions.{Ascending, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder} +import org.apache.spark.sql.catalyst.expressions.{Ascending, BoundReference, Descending, Expression, ExpressionDescription, ImplicitCastInputTypes, SortOrder, UnsafeProjection} import org.apache.spark.sql.catalyst.expressions.Cast.toSQLExpr +import org.apache.spark.sql.catalyst.optimizer.NormalizeFloatingNumbers import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.catalyst.types.PhysicalDataType import org.apache.spark.sql.catalyst.util.{ArrayData, CollationFactory, GenericArrayData, MapData, UnsafeRowUtils} import org.apache.spark.sql.errors.DataTypeErrors.toSQLType import org.apache.spark.sql.errors.QueryCompilationErrors -import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, MapType, StringType, StructField, StructType} +import org.apache.spark.sql.types.{AbstractDataType, AnyDataType, ArrayType, BooleanType, DataType, DoubleType, FloatType, MapType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String import org.apache.spark.util.collection.OpenHashMap +private[aggregate] object ModeKeyNormalizer { + def forType(dataType: DataType): Any => Any = dataType match { + case DoubleType => NormalizeFloatingNumbers.DOUBLE_NORMALIZER + case FloatType => NormalizeFloatingNumbers.FLOAT_NORMALIZER + case dt if NormalizeFloatingNumbers.needNormalize(dt) => + val ref = BoundReference(0, dt, nullable = true) + val proj = UnsafeProjection.create(NormalizeFloatingNumbers.normalize(ref)) + (value: Any) => InternalRow.copyValue(proj(InternalRow(value)).get(0, dt)) + case _ => (value: Any) => InternalRow.copyValue(value) + } +} + case class Mode( child: Expression, mutableAggBufferOffset: Int = 0, @@ -54,13 +67,16 @@ case class Mode( override def prettyName: String = "mode" + @transient private lazy val keyNormalizer: Any => Any = + ModeKeyNormalizer.forType(child.dataType) + override def update( buffer: OpenHashMap[AnyRef, Long], input: InternalRow): OpenHashMap[AnyRef, Long] = { val key = child.eval(input) if (key != null) { - buffer.changeValue(InternalRow.copyValue(key).asInstanceOf[AnyRef], 1L, _ + 1L) + buffer.changeValue(keyNormalizer(key).asInstanceOf[AnyRef], 1L, _ + 1L) } buffer } @@ -299,13 +315,16 @@ case class PandasMode( override def prettyName: String = "pandas_mode" + @transient private lazy val keyNormalizer: Any => Any = + ModeKeyNormalizer.forType(child.dataType) + override def update( buffer: OpenHashMap[AnyRef, Long], input: InternalRow): OpenHashMap[AnyRef, Long] = { val key = child.eval(input) if (key != null) { - buffer.changeValue(InternalRow.copyValue(key).asInstanceOf[AnyRef], 1L, _ + 1L) + buffer.changeValue(keyNormalizer(key).asInstanceOf[AnyRef], 1L, _ + 1L) } else if (!ignoreNA) { buffer.changeValue(null, 1L, _ + 1L) } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala index 7d5fd9fe57913..1d25788fdb6c2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NormalizeFloatingNumbers.scala @@ -28,36 +28,40 @@ import org.apache.spark.sql.types._ import org.apache.spark.util.ArrayImplicits._ /** - * We need to take care of special floating numbers (NaN and -0.0) in several places: - * 1. When compare values, different NaNs should be treated as same, `-0.0` and `0.0` should be - * treated as same. - * 2. In aggregate grouping keys, different NaNs should belong to the same group, `-0.0` and `0.0` - * should belong to the same group. - * 3. In join keys, different NaNs should be treated as same, `-0.0` and `0.0` should be - * treated as same. - * 4. In window partition keys, different NaNs should belong to the same partition, `-0.0` - * and `0.0` should belong to the same partition. - * 5. In hash-based array set operations, different NaNs should be treated as same, `-0.0` - * and `0.0` should be treated as same. + * Certain pairs of floating point numbers require special handling: + * 1. 0.0 / -0.0 + * 2. NaN / NaN + * That's because we want to treat each of these pairs of numbers as equal, even though they have + * different binary representations. (Note that IEEE 754 allows multiple distinct bit patterns for + * NaN.) * - * Case 1 is fine, as we handle NaN and `-0.0` well during comparison. For complex types, we - * recursively compare the fields/elements, so it's also fine. + * There are multiple ways we compare values that require careful handling of the above floating + * point pairs: + * 1. Directly, via `==` or via methods of `java.lang.{type}` + * 2. Via raw bytes in instances of [[org.apache.spark.sql.catalyst.expressions.UnsafeRow]] + * 3. Via hash sets * - * Case 2, 3 and 4 are problematic, as Spark SQL turns grouping/join/window partition keys into - * binary `UnsafeRow` and compare the binary data directly. Different NaNs have different binary - * representation, and the same thing happens for `-0.0` and `0.0`. + * This special handling is required in several places where we compare values via one of the above + * methods: + * 1. When comparing values (direct) + * 2. When grouping keys for aggregates (`UnsafeRow`) + * 3. When joining on keys (`UnsafeRow`) + * 4. When partitioning keys for windows (`UnsafeRow`) + * 5. When executing array set operations (hash sets) * - * Case 5 is problematic for a similar reason: hash-based array set operations compare elements by - * their binary representation via hash sets. + * Case 1 is handled in [[org.apache.spark.sql.catalyst.util.SQLOrderingUtil]] and in the + * `genEqual` method of [[org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext]]. + * Case 2 is handled during planning in the `Aggregation` and `StatefulAggregationStrategy` objects + * of [[org.apache.spark.sql.execution.SparkStrategies]]. + * Cases 3-5 are handled by this optimizer rule. * * This rule runs in two places: - * 1. Early in `FinishAnalysis` (right after `ReplaceExpressions` and before `EvalInlineTables`) - * so that array set-like operations are wrapped before optimizer rules that pre-evaluate - * expressions (e.g. `ConstantFolding`, `ConvertToLocalRelation`, `EvalInlineTables`). - * - * 2. As a late batch at the end of the optimizer, because rules like subquery rewrite and - * join reorder can create new joins or join conditions after `FinishAnalysis` that still - * need their keys to be normalized. + * 1. Early in `FinishAnalysis` (right after `ReplaceExpressions` and before `EvalInlineTables`) + * so that array set-like operations are wrapped before optimizer rules that pre-evaluate + * expressions (e.g. `ConstantFolding`, `ConvertToLocalRelation`, `EvalInlineTables`). + * 2. As a late batch at the end of the optimizer, because rules like subquery rewrite and + * join reorder can create new joins or join conditions after `FinishAnalysis` that still + * need their keys to be normalized. * * Ideally we should do the normalization in the physical operators that compare the * binary `UnsafeRow` directly. We don't need this normalization if the Spark SQL execution engine @@ -92,7 +96,8 @@ object NormalizeFloatingNumbers extends Rule[LogicalPlan] { // TODO: ideally Aggregate should also be handled here, but its grouping expressions are // mixed in its aggregate expressions. It's unreliable to change the grouping expressions - // here. For now we normalize grouping expressions in `AggUtils` during planning. + // here. For now we normalize grouping expressions during planning. See Case 2 in the + // Scaladoc just above. } .transformAllExpressionsWithPruning(_.containsAnyPattern( ARRAY_DISTINCT, ARRAY_UNION, ARRAY_INTERSECT, ARRAY_EXCEPT, ARRAYS_OVERLAP)) { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala index 3bef1658c803b..b814087ea792d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala @@ -1547,9 +1547,16 @@ object CollapseProject extends Rule[LogicalPlan] with AliasHelper { /** * Check if the given expression is cheap that we can inline it. + * + * This is consumed both by logical-stage callers (which only ever see `Attribute`) and by the + * `FilterExec` whole-stage-codegen CSE gate, which runs on predicates already bound for codegen + * and so sees `BoundReference` instead. The `BoundReference` branch therefore only fires on the + * codegen path -- logical plans never carry `BoundReference` -- and leaves the logical callers + * unaffected. */ def isCheap(e: Expression): Boolean = e match { - case _: Attribute | _: OuterReference => true + // `BoundReference` is the codegen-bound form of an `Attribute`; a slot read, equally cheap. + case _: Attribute | _: OuterReference | _: BoundReference => true case _ if e.foldable => true // PythonUDF is handled by the rule ExtractPythonUDFs case _: PythonUDF => diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala index ee8e61b457fa8..974c5191ad996 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoin.scala @@ -92,11 +92,26 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { // cross-product today -- should not silently bypass that choice. val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE) - val (aggInput, rankingForAgg) = if (!rankingExpression.deterministic) { - val rankingAlias = Alias(rankingExpression, "__ranking__")() + // A LEFT OUTER join widens the right-side columns to nullable. The synthesized Aggregate + // (and the optional `__ranking__` Project) below sit directly on top of this join, so every + // reference to a right-side column must carry that widened nullability. Otherwise the + // rewritten plan would declare a right column non-nullable while its child -- the join -- + // produces it as nullable, which plan-integrity validation flags as a nullability + // regression. INNER joins do not widen the right side, so this is a no-op there. + val rightAttrs = joinType match { + case LeftOuter => right.output.map(_.withNullability(true)) + case _ => right.output + } + val rightNullabilityMap = AttributeMap(right.output.zip(rightAttrs)) + val rankingInJoin = rankingExpression.transform { + case a: Attribute => rightNullabilityMap.getOrElse(a, a) + } + + val (aggInput, rankingForAgg) = if (!rankingInJoin.deterministic) { + val rankingAlias = Alias(rankingInJoin, "__ranking__")() Project(join.output :+ rankingAlias, join) -> rankingAlias.toAttribute } else { - join -> rankingExpression + join -> rankingInJoin } // 4. Aggregate grouped by `__qid`: @@ -104,7 +119,7 @@ object RewriteNearestByJoin extends Rule[LogicalPlan] { // - max_by/min_by(struct(right.*), ranking, k) as `_matches`. // The ranking expression references left and right columns directly; no outer // reference is needed because both sides are present in the joined input. - val rightStruct = CreateStruct(right.output) + val rightStruct = CreateStruct(rightAttrs) // reverse = true -> MIN_BY (smallest ranking value first, for DISTANCE) // reverse = false -> MAX_BY (largest ranking value first, for SIMILARITY) val reverse = direction match { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala index 74c36f4099d9b..ddde0c3b92e69 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/ColumnDefinition.scala @@ -23,6 +23,7 @@ import org.apache.spark.sql.catalyst.expressions.{AnalysisAwareExpression, Expre import org.apache.spark.sql.catalyst.parser.ParserInterface import org.apache.spark.sql.catalyst.trees.TreePattern.{ANALYSIS_AWARE_EXPRESSION, TreePattern} import org.apache.spark.sql.catalyst.util.{CharVarcharUtils, GeneratedColumn, IdentityColumn, V2ExpressionBuilder} +import org.apache.spark.sql.catalyst.util.FieldMetadataUtils.FIELD_ID_METADATA_KEY import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.validateDefaultValueExpr import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.{CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY} import org.apache.spark.sql.connector.catalog.{Column => V2Column, ColumnDefaultValue, DefaultValue, IdentityColumnSpec} @@ -31,6 +32,7 @@ import org.apache.spark.sql.connector.expressions.LiteralValue import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.internal.connector.ColumnImpl import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructField} +import org.apache.spark.sql.util.SchemaUtils /** * User-specified column definition for CREATE/REPLACE TABLE commands. This is an expression so that @@ -39,6 +41,9 @@ import org.apache.spark.sql.types.{DataType, Metadata, MetadataBuilder, StructFi * For CREATE/REPLACE TABLE commands, columns are created from scratch, so we store the * user-specified default value as both the current default and exists default, in methods * `toV1Column` and `toV2Column`. + * + * Note that ColumnDefinition is meant to be used in DDL statements like CREATE or REPLACE. + * That's why it does not have a notion of column IDs as they must be assigned by connectors. */ case class ColumnDefinition( name: String, @@ -69,7 +74,8 @@ case class ColumnDefinition( defaultValue.map(_.toV2(statement, name)).orNull, generationExpression.orNull, identityColumnSpec.orNull, - if (metadata == Metadata.empty) null else metadata.json) + if (metadata == Metadata.empty) null else metadata.json, + id = null /* must be assigned by connectors */) } def toV1Column: StructField = { @@ -128,6 +134,7 @@ object ColumnDefinition { def fromV1Column(col: StructField, parser: ParserInterface): ColumnDefinition = { val metadataBuilder = new MetadataBuilder().withMetadata(col.metadata) metadataBuilder.remove("comment") + metadataBuilder.remove(FIELD_ID_METADATA_KEY) metadataBuilder.remove(CURRENT_DEFAULT_COLUMN_METADATA_KEY) metadataBuilder.remove(EXISTS_DEFAULT_COLUMN_METADATA_KEY) metadataBuilder.remove(GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY) @@ -158,7 +165,7 @@ object ColumnDefinition { } ColumnDefinition( col.name, - col.dataType, + SchemaUtils.clearFieldIds(col.dataType), col.nullable, col.getComment(), defaultValue, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala index 0cf152079c520..a69581ddd201d 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/types/ops/TimeTypeOps.scala @@ -28,7 +28,6 @@ import org.apache.spark.sql.catalyst.types.{PhysicalDataType, PhysicalLongType} import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.execution.arrow.{ArrowFieldWriter, TimeWriter} import org.apache.spark.sql.types.{ObjectType, TimeType} -import org.apache.spark.sql.types.ops.TimeTypeApiOps /** * Server-side (catalyst) operations for TimeType. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala index 2e7e88633bfa0..5dc3962821a03 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/util/package.scala @@ -23,6 +23,7 @@ import java.nio.charset.StandardCharsets.UTF_8 import org.apache.spark.internal.Logging import org.apache.spark.sql.catalyst.analysis.TempResolvedColumn import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.util.FieldMetadataUtils.FIELD_ID_METADATA_KEY import org.apache.spark.sql.connector.catalog.MetadataColumn import org.apache.spark.sql.types.{MetadataBuilder, NumericType, StringType, StructType} import org.apache.spark.unsafe.types.UTF8String @@ -241,6 +242,7 @@ package object util extends Logging { AUTO_GENERATED_ALIAS, METADATA_COL_ATTR_KEY, QUALIFIED_ACCESS_ONLY, + FIELD_ID_METADATA_KEY, FileSourceMetadataAttribute.FILE_SOURCE_METADATA_COL_ATTR_KEY, FileSourceConstantMetadataStructField.FILE_SOURCE_CONSTANT_METADATA_COL_ATTR_KEY, FileSourceGeneratedMetadataStructField.FILE_SOURCE_GENERATED_METADATA_COL_ATTR_KEY, diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala index e42d5f3a84457..aaf8310933f87 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/CatalogV2Util.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql.catalyst.catalog.ClusterBySpec import org.apache.spark.sql.catalyst.expressions.{Expression, Literal, V2ExpressionUtils} import org.apache.spark.sql.catalyst.plans.logical.{SerdeInfo, TableSpec} import org.apache.spark.sql.catalyst.util.{GeneratedColumn, IdentityColumn} +import org.apache.spark.sql.catalyst.util.FieldMetadataUtils.FIELD_ID_METADATA_KEY import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns._ import org.apache.spark.sql.connector.catalog.TableChange._ import org.apache.spark.sql.connector.catalog.constraints.Constraint @@ -38,7 +39,7 @@ import org.apache.spark.sql.connector.expressions.{ClusterByTransform, LiteralVa import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{ArrayType, MapType, Metadata, MetadataBuilder, StructField, StructType} -import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.sql.util.{CaseInsensitiveStringMap, SchemaUtils} import org.apache.spark.util.ArrayImplicits._ private[sql] object CatalogV2Util { @@ -78,6 +79,20 @@ private[sql] object CatalogV2Util { SupportsNamespaces.PROP_LOCATION, SupportsNamespaces.PROP_OWNER) + private val COMMON_COL_METADATA_KEYS = Seq("comment", FIELD_ID_METADATA_KEY) + + private val DEFAULT_COL_METADATA_KEYS = Seq( + CURRENT_DEFAULT_COLUMN_METADATA_KEY, + EXISTS_DEFAULT_COLUMN_METADATA_KEY) + + private val GENERATED_COL_METADATA_KEYS = Seq( + GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY) + + private val IDENTITY_COL_METADATA_KEYS = Seq( + IdentityColumn.IDENTITY_INFO_START, + IdentityColumn.IDENTITY_INFO_STEP, + IdentityColumn.IDENTITY_INFO_ALLOW_EXPLICIT_INSERT) + /** * Apply properties changes to a map and return the result. */ @@ -534,14 +549,14 @@ private[sql] object CatalogV2Util { } /** - * Construct a [[ViewInfo.Builder]] seeded from an existing view's metadata. Used by ALTER + * Construct a [[View.Builder]] seeded from an existing view's metadata. Used by ALTER * VIEW execs (SET / UNSET TBLPROPERTIES, ALTER VIEW ... WITH SCHEMA BINDING) -- override * the one field that changes, then `build` to produce the replacement payload for * [[ViewCatalog#replaceView]]. Every other field flows through unchanged so a metadata-only * mutation does not perturb the view body. */ - def viewInfoBuilderFrom(existing: ViewInfo): ViewInfo.Builder = { - val builder = new ViewInfo.Builder() + def viewInfoBuilderFrom(existing: View): View.Builder = { + val builder = new View.Builder() builder .withSchema(existing.schema) .withProperties(existing.properties) @@ -636,9 +651,9 @@ private[sql] object CatalogV2Util { } /** - * Converts DS v2 columns to StructType, which encodes column comment and default value to - * StructField metadata. This is mainly used to define the schema of v2 scan, w.r.t. the columns - * of the v2 table. + * Converts DS v2 columns to StructType, which encodes column comment, default value, and + * column ID to StructField metadata. This is mainly used to define the schema of v2 scan, + * w.r.t. the columns of the v2 table. */ def v2ColumnsToStructType(columns: Seq[Column]): StructType = { StructType(columns.map(v2ColumnToStructField)) @@ -653,6 +668,9 @@ private[sql] object CatalogV2Util { Option(col.defaultValue()).foreach { default => f = encodeDefaultValue(default, f) } + Option(col.id()).foreach { id => + f = f.withId(id) + } f } @@ -699,15 +717,21 @@ private[sql] object CatalogV2Util { /** * Converts a StructType to DS v2 columns, which decodes the StructField metadata to v2 column - * comment and default value or generation expression. This is mainly used to generate DS v2 - * columns from table schema in DDL commands, so that Spark can pass DS v2 columns to DS v2 - * createTable and related APIs. + * comment, default value or generation expression, and column ID. This is mainly used to + * generate DS v2 columns from table schema in DDL commands, so that Spark can pass DS v2 + * columns to DS v2 createTable and related APIs. */ - def structTypeToV2Columns(schema: StructType): Array[Column] = { - schema.fields.map(structFieldToV2Column) + def structTypeToV2Columns( + schema: StructType, + keepIds: Boolean = true): Array[Column] = { + schema.fields.map(structFieldToV2Column(_, keepIds)) } - private def structFieldToV2Column(f: StructField): Column = { + def clearIds(columns: Array[Column]): Array[Column] = { + columns.map(col => Column.builderFrom(col).clearIds().build()) + } + + private def structFieldToV2Column(f: StructField, keepIds: Boolean): Column = { def metadataAsJson(metadata: Metadata): String = { if (metadata == Metadata.empty) { null @@ -721,6 +745,10 @@ private[sql] object CatalogV2Util { }.build() } + val id = if (keepIds) f.id.orNull else null + val dataType = if (keepIds) f.dataType else SchemaUtils.clearFieldIds(f.dataType) + val comment = f.getComment().orNull + val isDefaultColumn = f.getCurrentDefaultValue().isDefined val isGeneratedColumn = GeneratedColumn.isGeneratedColumn(f) val isIdentityColumn = IdentityColumn.isIdentityColumn(f) @@ -736,32 +764,44 @@ private[sql] object CatalogV2Util { assert(e.resolved && e.foldable, "The existence default value must be a simple SQL string that is resolved and " + "foldable, but got: " + f.getExistenceDefaultValue().get) - LiteralValue(e.eval(), f.dataType) + LiteralValue(e.eval(), dataType) } else { null } val defaultValue = new ColumnDefaultValue(f.getCurrentDefaultValue().get, existsDefault) - val cleanedMetadata = metadataWithKeysRemoved( - Seq("comment", CURRENT_DEFAULT_COLUMN_METADATA_KEY, EXISTS_DEFAULT_COLUMN_METADATA_KEY)) - Column.create(f.name, f.dataType, f.nullable, f.getComment().orNull, defaultValue, - metadataAsJson(cleanedMetadata)) + val removedMetaKeys = COMMON_COL_METADATA_KEYS ++ DEFAULT_COL_METADATA_KEYS + Column.builderFor(f.name, dataType) + .nullable(f.nullable) + .comment(comment) + .defaultValue(defaultValue) + .metadata(metadataAsJson(metadataWithKeysRemoved(removedMetaKeys))) + .id(id) + .build() } else if (isGeneratedColumn) { - val cleanedMetadata = metadataWithKeysRemoved( - Seq("comment", GeneratedColumn.GENERATION_EXPRESSION_METADATA_KEY)) - Column.create(f.name, f.dataType, f.nullable, f.getComment().orNull, - GeneratedColumn.getGenerationExpression(f).get, metadataAsJson(cleanedMetadata)) + val removedMetaKeys = COMMON_COL_METADATA_KEYS ++ GENERATED_COL_METADATA_KEYS + Column.builderFor(f.name, dataType) + .nullable(f.nullable) + .comment(comment) + .generationExpression(GeneratedColumn.getGenerationExpression(f).get) + .metadata(metadataAsJson(metadataWithKeysRemoved(removedMetaKeys))) + .id(id) + .build() } else if (isIdentityColumn) { - val cleanedMetadata = metadataWithKeysRemoved( - Seq("comment", - IdentityColumn.IDENTITY_INFO_START, - IdentityColumn.IDENTITY_INFO_STEP, - IdentityColumn.IDENTITY_INFO_ALLOW_EXPLICIT_INSERT)) - Column.create(f.name, f.dataType, f.nullable, f.getComment().orNull, - IdentityColumn.getIdentityInfo(f).get, metadataAsJson(cleanedMetadata)) + val removedMetaKeys = COMMON_COL_METADATA_KEYS ++ IDENTITY_COL_METADATA_KEYS + Column.builderFor(f.name, dataType) + .nullable(f.nullable) + .comment(comment) + .identityColumnSpec(IdentityColumn.getIdentityInfo(f).get) + .metadata(metadataAsJson(metadataWithKeysRemoved(removedMetaKeys))) + .id(id) + .build() } else { - val cleanedMetadata = metadataWithKeysRemoved(Seq("comment")) - Column.create(f.name, f.dataType, f.nullable, f.getComment().orNull, - metadataAsJson(cleanedMetadata)) + Column.builderFor(f.name, dataType) + .nullable(f.nullable) + .comment(comment) + .metadata(metadataAsJson(metadataWithKeysRemoved(COMMON_COL_METADATA_KEYS))) + .id(id) + .build() } } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala index c40d5ab679190..03addeb170697 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/Catalogs.scala @@ -64,7 +64,7 @@ private[sql] object Catalogs { } val plugin = pluginClass.getDeclaredConstructor().newInstance().asInstanceOf[CatalogPlugin] plugin.initialize(name, catalogOptions(name, conf)) - validateTableViewCatalog(name, plugin) + validateRelationCatalog(name, plugin) plugin } catch { case e: ClassNotFoundException => @@ -110,17 +110,17 @@ private[sql] object Catalogs { /** * Reject catalogs that implement both [[TableCatalog]] and [[ViewCatalog]] without - * extending [[TableViewCatalog]]. The combined case has cross-cutting rules (single namespace, - * cross-type collision rejection, perf opt-ins) that live on [[TableViewCatalog]]; implementing + * extending [[RelationCatalog]]. The combined case has cross-cutting rules (single namespace, + * cross-type collision rejection, perf opt-ins) that live on [[RelationCatalog]]; implementing * the two interfaces directly would skip that contract. */ - private def validateTableViewCatalog(name: String, plugin: CatalogPlugin): Unit = { + private def validateRelationCatalog(name: String, plugin: CatalogPlugin): Unit = { if (plugin.isInstanceOf[TableCatalog] && plugin.isInstanceOf[ViewCatalog] && - !plugin.isInstanceOf[TableViewCatalog]) { + !plugin.isInstanceOf[RelationCatalog]) { throw new IllegalArgumentException( s"Catalog '$name' (${plugin.getClass.getName}) implements both TableCatalog and " + s"ViewCatalog directly. Catalogs that expose both tables and views must implement " + - s"TableViewCatalog instead, which centralizes the cross-cutting rules (shared " + + s"RelationCatalog instead, which centralizes the cross-cutting rules (shared " + s"identifier namespace, cross-type collision rejection, single-RPC perf entry " + s"points).") } diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala index 8a47cac8e7962..f4b629e372cda 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1Table.scala @@ -113,10 +113,8 @@ private[sql] object V1Table { def toCatalogTable( catalog: CatalogPlugin, ident: Identifier, - t: MetadataTable): CatalogTable = t.getTableInfo match { - case viewInfo: ViewInfo => toCatalogTable(catalog, ident, viewInfo) - case tableInfo => toCatalogTable(catalog, ident, tableInfo) - } + t: DelegatingTable): CatalogTable = + toCatalogTable(catalog, ident, t.getTableInfo) private def toCatalogTable( catalog: CatalogPlugin, @@ -127,7 +125,7 @@ private[sql] object V1Table { // v1 mapping (e.g. TableSummary.FOREIGN_TABLE_TYPE). v1 only has EXTERNAL/MANAGED, so // anything other than the explicit MANAGED mapping falls back to EXTERNAL for the v1 // representation -- the same default v1 uses when the value is missing. VIEW is reached - // only through the ViewInfo branch above. + // only through the View branch above. val tableType = props.get(TableCatalog.PROP_TABLE_TYPE) match { case Some(TableSummary.MANAGED_TABLE_TYPE) => CatalogTableType.MANAGED case _ => CatalogTableType.EXTERNAL @@ -169,7 +167,7 @@ private[sql] object V1Table { def toCatalogTable( catalog: CatalogPlugin, ident: Identifier, - info: ViewInfo): CatalogTable = { + info: View): CatalogTable = { val props = info.properties.asScala.toMap val userProps = props -- CatalogV2Util.TABLE_RESERVED_PROPERTIES // Serde/OPTION properties only apply to data-source tables; views' user properties are a @@ -196,7 +194,7 @@ private[sql] object V1Table { val schemaModeProps = Option(info.schemaMode) .map(m => Map(CatalogTable.VIEW_SCHEMA_MODE -> m)) .getOrElse(Map.empty) - // ViewInfo always represents a view-like table, but PROP_TABLE_TYPE may further refine the + // View always represents a view-like table, but PROP_TABLE_TYPE may further refine the // kind (e.g. METRIC_VIEW). Default to plain VIEW when no refinement is supplied. val tableType = props.get(TableCatalog.PROP_TABLE_TYPE) match { case Some(TableSummary.METRIC_VIEW_TABLE_TYPE) => CatalogTableType.METRIC_VIEW diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1ViewInfo.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1View.scala similarity index 72% rename from sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1ViewInfo.scala rename to sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1View.scala index e18fe52385a1c..7a19d04dd4e4c 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1ViewInfo.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V1View.scala @@ -22,42 +22,42 @@ import org.apache.spark.sql.catalyst.catalog.CatalogTable /** * A v1 [[CatalogTable]] (representing a session-catalog view) exposed through the v2 - * [[ViewInfo]] surface, mirroring the way [[V1Table]] exposes a v1 table CatalogTable through + * [[View]] surface, mirroring the way [[V1Table]] exposes a v1 table CatalogTable through * the v2 [[Table]] surface. Holds the original [[CatalogTable]] in [[v1Table]] for v1-only * paths that need the full v1 metadata representation (e.g. `DescribeTableCommand`, * `ShowCreateTableCommand`, anything that calls `CatalogTable#toLinkedHashMap`). * - * Note on `properties()`: the inherited [[ViewInfo#properties]] bag is built from the entire + * Note on `properties()`: the inherited [[View#properties]] bag is built from the entire * `v1Table.properties` map, which intermixes user TBLPROPERTIES with v1-internal storage keys * (`view.sqlConfig.*`, `view.catalogAndNamespace.*`, `view.query.out.*`, `view.schemaMode`). * v2 view inspection / SET execs (`ShowV2ViewPropertiesExec`, `AlterV2ViewSetPropertiesExec`, - * etc.) never see a `V1ViewInfo` -- `ResolveSessionCatalog` rewrites session-catalog views to + * etc.) never see a `V1View` -- `ResolveSessionCatalog` rewrites session-catalog views to * v1 commands first -- so the bag stays internal to v1-only paths. Consumers that do receive - * a `V1ViewInfo` should prefer the typed accessors ([[ViewInfo#sqlConfigs]], - * [[ViewInfo#currentNamespace]], [[ViewInfo#currentCatalog]], [[ViewInfo#queryColumnNames]], - * [[ViewInfo#schemaMode]]) for the v1-internal fields rather than scraping `properties()` for + * a `V1View` should prefer the typed accessors ([[View#sqlConfigs]], + * [[View#currentNamespace]], [[View#currentCatalog]], [[View#queryColumnNames]], + * [[View#schemaMode]]) for the v1-internal fields rather than scraping `properties()` for * them. */ -private[sql] class V1ViewInfo(val v1Table: CatalogTable) - extends ViewInfo(V1ViewInfo.builderFrom(v1Table)) +private[sql] class V1View(val v1Table: CatalogTable) + extends View(V1View.builderFrom(v1Table)) -private[sql] object V1ViewInfo { +private[sql] object V1View { /** - * Convert a v1 [[CatalogTable]] view into a [[ViewInfo.Builder]] with the same fields. - * Used as the {@code super(builder)} argument when constructing a [[V1ViewInfo]]. + * Convert a v1 [[CatalogTable]] view into a [[View.Builder]] with the same fields. + * Used as the {@code super(builder)} argument when constructing a [[V1View]]. */ - private def builderFrom(v1Table: CatalogTable): ViewInfo.Builder = { - val builder = new ViewInfo.Builder() + private def builderFrom(v1Table: CatalogTable): View.Builder = { + val builder = new View.Builder() builder.withSchema(v1Table.schema) builder.withProperties(v1Table.properties.asJava) // v1 stores collation / comment in typed `CatalogTable` fields rather than in `properties`, - // but consumers reading off [[ViewInfo]] (`ApplyDefaultCollation.fetchDefaultCollation`, + // but consumers reading off [[View]] (`ApplyDefaultCollation.fetchDefaultCollation`, // `ShowCreateV2ViewExec`, etc.) expect them under `PROP_COLLATION` / `PROP_COMMENT`. Bridge // them through the typed setters so the v2 surface sees the same view metadata regardless // of which catalog produced it. v1Table.collation.foreach(builder.withCollation) v1Table.comment.foreach(builder.withComment) - // ViewInfo requires a non-null queryText; v1 views always have one, but defend against + // View requires a non-null queryText; v1 views always have one, but defend against // an old/corrupt CatalogTable with `viewText = None` by falling back to an empty string. builder.withQueryText(v1Table.viewText.getOrElse("")) val cn = v1Table.viewCatalogAndNamespace diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala index 348d7e96e7d46..ec98d2c0fbe99 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/connector/catalog/V2TableUtil.scala @@ -19,8 +19,6 @@ package org.apache.spark.sql.connector.catalog import java.util.Locale -import scala.collection.mutable - import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.analysis.Resolver import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, MetadataColumnHelper} @@ -29,7 +27,7 @@ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.datasources.v2.DataSourceV2Relation import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.sql.util.SchemaValidationMode -import org.apache.spark.sql.util.SchemaValidationMode.PROHIBIT_CHANGES +import org.apache.spark.sql.util.SchemaValidationMode.{ALLOW_NEW_FIELDS, PROHIBIT_CHANGES} import org.apache.spark.util.ArrayImplicits._ private[sql] object V2TableUtil extends SQLConfHelper { @@ -49,14 +47,16 @@ private[sql] object V2TableUtil extends SQLConfHelper { def validateCapturedColumns( table: Table, relation: DataSourceV2Relation, - mode: SchemaValidationMode): Seq[String] = { - validateCapturedColumns(table, relation.table.columns.toImmutableArraySeq, mode) + mode: SchemaValidationMode, + checkIds: Boolean): Seq[String] = { + validateCapturedColumns(table, relation.table.columns.toImmutableArraySeq, mode, checkIds) } /** * Validates that captured data columns match the current table schema. * * Checks for: + * - Column ID changes (top-level and nested field IDs) * - Column type or nullability changes * - Removed columns (missing from the current table schema) * - Added columns (new in the current table schema) @@ -64,15 +64,17 @@ private[sql] object V2TableUtil extends SQLConfHelper { * @param table the current table metadata * @param originCols the originally captured columns * @param mode validation mode that defines what changes are acceptable + * @param checkIds whether to check field IDs * @return validation errors, or empty sequence if valid */ def validateCapturedColumns( table: Table, originCols: Seq[Column], - mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = { + mode: SchemaValidationMode, + checkIds: Boolean): Seq[String] = { val originSchema = CatalogV2Util.v2ColumnsToStructType(originCols) val schema = CatalogV2Util.v2ColumnsToStructType(table.columns) - SchemaUtils.validateSchemaCompatibility(originSchema, schema, resolver, mode) + SchemaUtils.validateSchemaCompatibility(originSchema, schema, resolver, mode, checkIds) } /** @@ -86,8 +88,9 @@ private[sql] object V2TableUtil extends SQLConfHelper { def validateCapturedMetadataColumns( table: Table, relation: DataSourceV2Relation, - mode: SchemaValidationMode): Seq[String] = { - validateCapturedMetadataColumns(table, extractMetadataColumns(relation), mode) + mode: SchemaValidationMode, + checkIds: Boolean): Seq[String] = { + validateCapturedMetadataColumns(table, extractMetadataColumns(relation), mode, checkIds) } /** @@ -105,83 +108,29 @@ private[sql] object V2TableUtil extends SQLConfHelper { * Validates that captured metadata columns are consistent with the current table metadata. * * Checks for: + * - Column ID changes (top-level and nested field IDs) * - Metadata column type or nullability changes * - Removed metadata columns (missing from current table) * * @param table the current table metadata * @param originMetaCols the originally captured metadata columns * @param mode validation mode that defines what changes are acceptable + * @param checkIds whether to check IDs * @return validation errors, or empty sequence if valid */ def validateCapturedMetadataColumns( table: Table, originMetaCols: Seq[MetadataColumn], - mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = { + mode: SchemaValidationMode, + checkIds: Boolean): Seq[String] = { + require( + mode == PROHIBIT_CHANGES || mode == ALLOW_NEW_FIELDS, + s"Unsupported schema validation mode for metadata columns: $mode") val originMetaColNames = originMetaCols.map(_.name) val originMetaSchema = CatalogV2Util.toStructType(originMetaCols) val metaCols = filter(originMetaColNames, metadataColumns(table)) val metaSchema = CatalogV2Util.toStructType(metaCols) - SchemaUtils.validateSchemaCompatibility(originMetaSchema, metaSchema, resolver, mode) - } - - /** - * Validates that column IDs have not changed for columns that still exist in the table. - * - * Only validates columns where the original and current column both have non-null IDs. - * If the connector does not support column IDs (returns null), the validation is skipped. - * - * @param table the current table metadata - * @param relation the relation with captured columns - * @return validation errors, or empty sequence if valid - */ - def validateColumnIds( - table: Table, - relation: DataSourceV2Relation): Seq[String] = { - validateColumnIds( - table = table, - originalCapturedCols = relation.table.columns.toImmutableArraySeq) - } - - /** - * Validates that column IDs have not changed for columns that still exist in the table. - * - * Only validates columns where the original and current column both have non-null IDs. - * If the connector does not support column IDs (returns null), the validation is skipped. - * - * ID transition handling: - * - null to null: skipped (no ID to validate) - * - null to ID: skipped (connector enabled ID tracking after analysis) - * - ID to null: skipped (connector disabled ID tracking) - * - ID to ID (same): no error - * - ID to ID (different): error, same column name was replaced - * - * @param table the current table metadata - * @param originalCapturedCols the originally captured columns - * @return validation errors, or empty sequence if valid - */ - def validateColumnIds( - table: Table, - originalCapturedCols: Seq[Column]): Seq[String] = { - val currentColsByNormalizedName = table.columns.toImmutableArraySeq - .map(currentCol => normalize(currentCol.name()) -> currentCol).toMap - val errors = new mutable.ArrayBuffer[String]() - for (originalCol <- originalCapturedCols) { - if (originalCol.id() != null) { - currentColsByNormalizedName.get(normalize(originalCol.name())) match { - case Some(currentCol) - if currentCol.id() != null && currentCol.id() != originalCol.id() => - errors += s"`${originalCol.name()}` column ID has changed from " + - s"${originalCol.id()} to ${currentCol.id()}" - case _ => - // 1. Column exists in the original schema but not in the current table. - // 2. Column IDs have not changed. - // 3. The current column's ID is null (connector disabled ID tracking). - // Note that dropped columns are handled separately by - // [[columnsMissingOrAddedAfterAnalysis]]. - } - } - } - errors.toSeq + SchemaUtils.validateSchemaCompatibility(originMetaSchema, metaSchema, resolver, mode, checkIds) } private def filter(colNames: Seq[String], cols: Seq[MetadataColumn]): Seq[MetadataColumn] = { diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala index 7b4b9334f02da..885ff428fb1d2 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/errors/QueryCompilationErrors.scala @@ -2228,17 +2228,7 @@ private[sql] object QueryCompilationErrors extends QueryErrorsBase with Compilat "currentTableId" -> currentTableId)) } - def columnIdMismatchAfterAnalysis( - tableName: String, - errors: Seq[String]): Throwable = { - new AnalysisException( - errorClass = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", - messageParameters = Map( - "tableName" -> toSQLId(tableName), - "errors" -> errors.mkString("- ", "\n- ", ""))) - } - - def columnsMissingOrAddedAfterAnalysis( + def columnsChangedAfterAnalysis( tableName: String, errors: Seq[String]): Throwable = { new AnalysisException( diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala index f97f90b7eb590..5e500ddd5aeed 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/internal/connector/ColumnImpl.scala @@ -17,8 +17,6 @@ package org.apache.spark.sql.internal.connector -import java.util.Objects - import org.apache.spark.sql.connector.catalog.{Column, ColumnDefaultValue, IdentityColumnSpec} import org.apache.spark.sql.types.DataType @@ -32,36 +30,4 @@ case class ColumnImpl( generationExpression: String, identityColumnSpec: IdentityColumnSpec, metadataInJSON: String, - override val id: String = null) extends Column { - - // [[id]] is excluded from [[equals]] and [[hashCode]] because IDs only live on [[Column]], - // not on [[StructField]] metadata. Any code path that round-trips through [[StructType]] - // (e.g. [[CatalogV2Util.v2ColumnsToStructType]] followed by [[structTypeToV2Columns]]) - // drops the ID, producing a [[Column]] with id=null for the same logical column. Including - // [[id]] in equality would cause spurious mismatches across these round-trips. - // Column ID validation is performed separately by [[V2TableUtil.validateColumnIds]]. - override def equals(other: Any): Boolean = other match { - case that: ColumnImpl => - name == that.name && - dataType == that.dataType && - nullable == that.nullable && - comment == that.comment && - defaultValue == that.defaultValue && - generationExpression == that.generationExpression && - identityColumnSpec == that.identityColumnSpec && - metadataInJSON == that.metadataInJSON - case _ => false - } - - override def hashCode(): Int = { - Objects.hash( - name, - dataType, - Boolean.box(nullable), - comment, - defaultValue, - generationExpression, - identityColumnSpec, - metadataInJSON) - } -} + override val id: String = null) extends Column diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/metricview/serde/MetricViewCanonical.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/metricview/serde/MetricViewCanonical.scala index 1b4718ebd385e..3fa7ea5395974 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/metricview/serde/MetricViewCanonical.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/metricview/serde/MetricViewCanonical.scala @@ -179,7 +179,7 @@ private[sql] case class MetricView( * [[Constants.MAXIMUM_PROPERTY_SIZE]] characters, so these are descriptive values * for catalog browsers / lineage tooling -- not round-trippable representations * of the source. Consumers that need the full SQL or filter expression for - * re-execution should read [[ViewInfo#queryText]] (the YAML body) and re-parse it + * re-execution should read [[View#queryText]] (the YAML body) and re-parse it * rather than reconstruct the query from these properties; for any source whose * SQL exceeds the size limit, this property would silently return a truncated * string. diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala index 58ababa04739f..61c6b1397b9d3 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/util/SchemaUtils.scala @@ -394,19 +394,21 @@ private[spark] object SchemaUtils { } /** - * Validates schema compatibility by recursively checking type and nullability changes. + * Validates schema compatibility by recursively checking ID, type and nullability changes. * * @param schema the schema to validate against * @param otherSchema the other schema to check for compatibility * @param resolver the resolver that controls whether the validation is case sensitive * @param mode the validation mode that controls what changes are allowed + * @param checkFieldIds whether to validate top-level and nested field IDs * @return sequence of error messages describing incompatibilities, empty if fully compatible */ def validateSchemaCompatibility( schema: StructType, otherSchema: StructType, resolver: Resolver, - mode: SchemaValidationMode): Seq[String] = { + mode: SchemaValidationMode, + checkFieldIds: Boolean): Seq[String] = { checkSchemaColumnNameDuplication(schema, resolver) checkSchemaColumnNameDuplication(otherSchema, resolver) val errors = mutable.ArrayBuffer[String]() @@ -418,6 +420,7 @@ private[spark] object SchemaUtils { colPath = Seq.empty, resolver, mode, + checkFieldIds, errors) errors.toSeq } @@ -430,6 +433,7 @@ private[spark] object SchemaUtils { colPath: Seq[String], resolver: Resolver, mode: SchemaValidationMode, + checkFieldIds: Boolean, errors: mutable.ArrayBuffer[String]): Unit = { if (nullable && !otherNullable) { errors += s"${colPath.fullyQuoted} is no longer nullable" @@ -445,14 +449,21 @@ private[spark] object SchemaUtils { fieldsByName.foreach { case (normalizedName, field) => otherFieldsByName.get(normalizedName) match { case Some(otherField) => + val nameParts = colPath :+ field.name + if (checkFieldIds) { + for (id <- field.id; otherId <- otherField.id if id != otherId) { + errors += s"${nameParts.fullyQuoted} field ID has changed from $id to $otherId" + } + } validateTypeCompatibility( field.dataType, otherField.dataType, field.nullable, otherField.nullable, - colPath :+ field.name, + nameParts, resolver, mode, + checkFieldIds, errors) case None => errors += s"${formatField(colPath, field)} has been removed" @@ -476,6 +487,7 @@ private[spark] object SchemaUtils { colPath :+ "element", resolver, mode, + checkFieldIds, errors) case (MapType(keyType, valueType, valueContainsNull), @@ -488,6 +500,7 @@ private[spark] object SchemaUtils { colPath :+ "key", resolver, mode, + checkFieldIds, errors) validateTypeCompatibility( valueType, @@ -497,6 +510,7 @@ private[spark] object SchemaUtils { colPath :+ "value", resolver, mode, + checkFieldIds, errors) case _ if dataType != otherDataType => @@ -522,6 +536,38 @@ private[spark] object SchemaUtils { fields.map(field => field.name.toLowerCase(Locale.ROOT) -> field).toMap } } + + /** + * Recursively clears field IDs from a data type. + */ + def clearFieldIds(dataType: DataType): DataType = dataType match { + case s: StructType => + StructType(s.fields.map { field => + val fieldWithoutId = field.clearId() + val newDataType = clearFieldIds(field.dataType) + if (newDataType ne field.dataType) { + fieldWithoutId.copy(dataType = newDataType) + } else { + fieldWithoutId + } + }) + + case a: ArrayType => + val newElementType = clearFieldIds(a.elementType) + if (newElementType ne a.elementType) a.copy(elementType = newElementType) else a + + case m: MapType => + val newKeyType = clearFieldIds(m.keyType) + val newValueType = clearFieldIds(m.valueType) + if ((newKeyType ne m.keyType) || (newValueType ne m.valueType)) { + m.copy(keyType = newKeyType, valueType = newValueType) + } else { + m + } + + case other => + other + } } private[spark] sealed trait SchemaValidationMode diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala index 4a2b23fe059ba..4a85906dfc2b8 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/DateExpressionsSuite.scala @@ -862,8 +862,19 @@ class DateExpressionsSuite extends SparkFunSuite with ExpressionEvalHelper { val minTimestamp = Literal.create(Long.MinValue, TimestampType) Seq("YEAR", "QUARTER", "MONTH", "WEEK", "DAY", "HOUR", "MINUTE", "SECOND", "MILLISECOND").foreach { fmt => - checkExceptionInExpression[ArithmeticException]( - TruncTimestamp(Literal.create(fmt, StringType), minTimestamp), "") + // The overflow surfaces as a raw `ArithmeticException` from `Math.*Exact`, which + // `checkEvaluation` wraps in a scalatest `TestFailedException`. Unwrap it via `getCause` + // and assert the type. The exception message is not part of the API contract: JDK 25 may + // throw it with a `null` message in codegen mode (JIT "hot throw", see SPARK-55714 / + // SPARK-55755 / JDK-8367990), while the interpreter reports "long overflow", so tolerate + // a null message. Mirrors the `TimestampAddInterval` overflow check in this suite. + val e = intercept[Exception] { + checkEvaluation( + TruncTimestamp(Literal.create(fmt, StringType), minTimestamp), + null) + }.getCause + assert(e.isInstanceOf[ArithmeticException]) + assert(e.getMessage == null || e.getMessage.contains("overflow")) } } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala index 729b58394d4bc..90d6ad8065876 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/RewriteNearestByJoinSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.catalyst.optimizer import org.apache.spark.sql.catalyst.dsl.expressions._ import org.apache.spark.sql.catalyst.dsl.plans._ -import org.apache.spark.sql.catalyst.expressions.{Alias, AttributeReference, CreateStruct, Inline, Literal, Rand, Uuid} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeMap, AttributeReference, CreateStruct, Inline, Literal, Rand, Uuid} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, First, MaxMinByK} import org.apache.spark.sql.catalyst.plans.{Inner, JoinType, LeftOuter, NearestByDistance, NearestBySimilarity, PlanTest} import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, Generate, Join, JoinHint, LocalRelation, NearestByJoin, Project} @@ -46,9 +46,19 @@ class RewriteNearestByJoinSuite extends PlanTest { val taggedLeft = Project(left.output :+ qidAlias, left) val join = Join(taggedLeft, right, joinType, None, JoinHint.NONE) - val rightStruct = CreateStruct(right.output) + // Mirror the rewrite: a LEFT OUTER join widens right-side columns to nullable, so the + // struct and ranking that sit on top of the join must reference them with that nullability. + val rightAttrs = joinType match { + case LeftOuter => right.output.map(_.withNullability(true)) + case _ => right.output + } + val rightNullabilityMap = AttributeMap(right.output.zip(rightAttrs)) + val rankingInJoin = ranking.transform { + case a: Attribute => rightNullabilityMap.getOrElse(a, a) + } + val rightStruct = CreateStruct(rightAttrs) val topKAgg = MaxMinByK( - rightStruct, ranking, Literal(numResults), reverse = reverse) + rightStruct, rankingInJoin, Literal(numResults), reverse = reverse) .toAggregateExpression() val matchesAlias = Alias(topKAgg, "__nearest_matches__")() val firstLeftAggs = left.output.map { attr => @@ -145,6 +155,65 @@ class RewriteNearestByJoinSuite extends PlanTest { comparePlans(normalizeUuidSeed(rewritten), expected, checkAnalysis = false) } + test("SPARK-56395: LEFT OUTER rewrite keeps right-side nullability consistent with its child") { + // A LEFT OUTER NEAREST BY widens the synthetic join's right-side columns to nullable. Every + // operator built on top of that join that references those columns (the `_matches` struct, + // the ranking) must carry the widened nullability -- otherwise the rewritten plan declares a + // column non-nullable while its child produces it as nullable, an internal inconsistency that + // no framework check catches (`LogicalPlanIntegrity` compares types `asNullable` and schemas + // `equalsIgnoreNullability`), so this assertion is the only guard. INNER does not widen the + // right side, so it stays a no-op. + // + // The right-side columns are declared non-nullable here: that is what makes LEFT OUTER's + // widening observable (with nullable columns the widening is a no-op and the bug is hidden). + // The ranking is exercised both deterministic (the reference lands directly in the Aggregate) + // and nondeterministic (the rule pre-materializes it into a `__ranking__` Project above the + // Join), so the widening is checked wherever the reference ends up. + val left = LocalRelation($"a".int, $"b".int) + val right = LocalRelation( + AttributeReference("x", IntegerType, nullable = false)(), + AttributeReference("y", IntegerType, nullable = false)()) + val rankings = Seq( + "deterministic" -> (left.output(0) + right.output(0)), + "nondeterministic" -> (Rand(Literal(0L)) + right.output(0))) + for ((joinType, rightNullable) <- Seq(Inner -> false, LeftOuter -> true); + (label, ranking) <- rankings) { + val query = NearestByJoin( + left, right, joinType, approx = true, numResults = 1, + rankingExpression = ranking, + direction = NearestBySimilarity) + + val rewritten = RewriteNearestByJoin(query.analyze) + val join = rewritten.collect { case j: Join => j }.head + + // Sanity-check the fixture: the synthetic join widens its right-side output to nullable + // iff it is LEFT OUTER. (`join.right` is the right relation as it appears in the rewritten + // plan, so its ExprIds line up with the join's output.) + val rightExprIds = join.right.output.map(_.exprId).toSet + val joinRightOutput = join.output.filter(a => rightExprIds.contains(a.exprId)) + assert(joinRightOutput.nonEmpty) + assert(joinRightOutput.forall(_.nullable == rightNullable)) + + // Whole-plan integrity: at every operator, an attribute reference whose ExprId is produced + // by one of that operator's children must agree with the child on nullability -- this is + // exactly what the fix corrects for LEFT OUTER. Walking the whole plan (rather than just + // the Aggregate) also covers the `__ranking__` Project that the nondeterministic path + // inserts above the Join, where the widened ranking reference lands. + rewritten.foreach { node => + val childNullability = + node.children.flatMap(_.output).map(a => a.exprId -> a.nullable).toMap + node.expressions.foreach(_.foreach { + case ref: AttributeReference if childNullability.contains(ref.exprId) => + assert(ref.nullable == childNullability(ref.exprId), + s"$joinType/$label: ${ref.name}#${ref.exprId.id} declared " + + s"nullable=${ref.nullable} but its child produces " + + s"nullable=${childNullability(ref.exprId)}") + case _ => + }) + } + } + } + test("synthetic Join uses the user's joinType") { // Locks in that the rewrite's synthetic Join carries the user's `joinType` // (Inner or LeftOuter). diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala index ad9c2655023fc..8c59b2ef45988 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/CatalogSuite.scala @@ -151,7 +151,7 @@ class CatalogSuite extends SparkFunSuite { val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) assert(parsed == Seq("test", "`", ".", "test_table")) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) assert(table.properties.asScala == Map()) assert(catalog.tableExists(testIdent)) @@ -175,7 +175,7 @@ class CatalogSuite extends SparkFunSuite { val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) assert(parsed == Seq("test", "`", ".", "test_table")) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) assert(table.properties.asScala == Map()) assert(partCatalog.tableExists(testIdent)) @@ -198,7 +198,7 @@ class CatalogSuite extends SparkFunSuite { val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) assert(parsed == Seq("test", "`", ".", "test_table")) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) assert(table.properties == properties) assert(catalog.tableExists(testIdent)) @@ -220,7 +220,7 @@ class CatalogSuite extends SparkFunSuite { val parsed = CatalystSqlParser.parseMultipartIdentifier(table.name) assert(parsed == Seq("test", "`", ".", "test_table")) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) assert(table.constraints === constraints) assert(table.properties.asScala == Map()) @@ -419,11 +419,12 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType)) - assert(updated.columns === columns :+ Column.create("ts", TimestampType)) + assert(CatalogV2Util.clearIds(updated.columns) === + columns :+ Column.create("ts", TimestampType)) } test("alterTable: add required column") { @@ -436,12 +437,13 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType, false)) - assert(updated.columns === columns :+ Column.create("ts", TimestampType, false)) + assert(CatalogV2Util.clearIds(updated.columns) === + columns :+ Column.create("ts", TimestampType, false)) } test("alterTable: add column with comment") { @@ -454,13 +456,13 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("ts"), TimestampType, false, "comment text")) val tsColumn = Column.create("ts", TimestampType, false, "comment text", null) - assert(updated.columns === (columns :+ tsColumn)) + assert(CatalogV2Util.clearIds(updated.columns) === (columns :+ tsColumn)) } test("alterTable: add nested column") { @@ -476,14 +478,14 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === tableColumns) + assert(CatalogV2Util.clearIds(table.columns) === tableColumns) val updated = catalog.alterTable(testIdent, TableChange.addColumn(Array("point", "z"), DoubleType)) val expectedColumns = columns :+ Column.create("point", pointStruct.add("z", DoubleType)) - assert(updated.columns === expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) === expectedColumns) } test("alterTable: add column to primitive field fails") { @@ -496,7 +498,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) checkError( exception = intercept[SparkIllegalArgumentException] { @@ -506,7 +508,7 @@ class CatalogSuite extends SparkFunSuite { parameters = Map("name" -> "data")) // the table has not changed - assert(catalog.loadTable(testIdent).columns === columns) + assert(CatalogV2Util.clearIds(catalog.loadTable(testIdent).columns) === columns) } test("alterTable: add field to missing column fails") { @@ -519,7 +521,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) checkError( exception = intercept[SparkIllegalArgumentException] { @@ -540,12 +542,12 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) val updated = catalog.alterTable(testIdent, TableChange.updateColumnType(Array("id"), LongType)) val expectedColumns = Array(Column.create("id", LongType), Column.create("data", StringType)) - assert(updated.columns sameElements expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) sameElements expectedColumns) } test("alterTable: update column nullability") { @@ -561,14 +563,14 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === originalColumns) + assert(CatalogV2Util.clearIds(table.columns) === originalColumns) val updated = catalog.alterTable(testIdent, TableChange.updateColumnNullability(Array("id"), true)) val expectedColumns = Array( Column.create("id", IntegerType, true), Column.create("data", StringType)) - assert(updated.columns sameElements expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) sameElements expectedColumns) } test("alterTable: update missing column fails") { @@ -581,7 +583,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) checkError( exception = intercept[SparkIllegalArgumentException] { @@ -602,7 +604,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) val updated = catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "comment text")) @@ -611,7 +613,7 @@ class CatalogSuite extends SparkFunSuite { Column.create("id", IntegerType, true, "comment text", null), Column.create("data", StringType) ) - assert(updated.columns sameElements expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) sameElements expectedColumns) } test("alterTable: replace comment") { @@ -624,7 +626,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "comment text")) @@ -635,7 +637,7 @@ class CatalogSuite extends SparkFunSuite { val updated = catalog.alterTable(testIdent, TableChange.updateColumnComment(Array("id"), "replacement comment")) - assert(updated.columns sameElements expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) sameElements expectedColumns) } test("alterTable: add comment to missing column fails") { @@ -648,7 +650,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) checkError( exception = intercept[SparkIllegalArgumentException] { @@ -669,13 +671,13 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("id"), "some_id")) val expectedColumns = Array( Column.create("some_id", IntegerType), Column.create("data", StringType)) - assert(updated.columns sameElements expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) sameElements expectedColumns) } test("alterTable: rename nested column") { @@ -691,7 +693,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === tableColumns) + assert(CatalogV2Util.clearIds(table.columns) === tableColumns) val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("point", "x"), "first")) @@ -699,7 +701,7 @@ class CatalogSuite extends SparkFunSuite { val newPointStruct = new StructType().add("first", DoubleType).add("y", DoubleType) val expectedColumns = columns :+ Column.create("point", newPointStruct) - assert(updated.columns === expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) === expectedColumns) } test("alterTable: rename struct column") { @@ -715,7 +717,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === tableColumns) + assert(CatalogV2Util.clearIds(table.columns) === tableColumns) val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("point"), "p")) @@ -723,7 +725,7 @@ class CatalogSuite extends SparkFunSuite { val newPointStruct = new StructType().add("x", DoubleType).add("y", DoubleType) val expectedColumns = columns :+ Column.create("p", newPointStruct) - assert(updated.columns === expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) === expectedColumns) } test("alterTable: rename missing column fails") { @@ -736,7 +738,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) checkError( exception = intercept[SparkIllegalArgumentException] { @@ -760,7 +762,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === tableColumns) + assert(CatalogV2Util.clearIds(table.columns) === tableColumns) val updated = catalog.alterTable(testIdent, TableChange.renameColumn(Array("point", "x"), "first"), @@ -769,7 +771,7 @@ class CatalogSuite extends SparkFunSuite { val newPointStruct = new StructType().add("first", DoubleType).add("second", DoubleType) val expectedColumns = columns :+ Column.create("point", newPointStruct) - assert(updated.columns() === expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns()) === expectedColumns) } test("alterTable: delete top-level column") { @@ -782,13 +784,13 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) val updated = catalog.alterTable(testIdent, TableChange.deleteColumn(Array("id"), false)) val expectedColumns = Array(Column.create("data", StringType)) - assert(updated.columns sameElements expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) sameElements expectedColumns) } test("alterTable: delete nested column") { @@ -804,7 +806,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === tableColumns) + assert(CatalogV2Util.clearIds(table.columns) === tableColumns) val updated = catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "y"), false)) @@ -812,7 +814,7 @@ class CatalogSuite extends SparkFunSuite { val newPointStruct = new StructType().add("x", DoubleType) val expectedColumns = columns :+ Column.create("point", newPointStruct) - assert(updated.columns === expectedColumns) + assert(CatalogV2Util.clearIds(updated.columns) === expectedColumns) } test("alterTable: delete missing column fails") { @@ -825,7 +827,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) checkError( exception = intercept[SparkIllegalArgumentException] { @@ -836,7 +838,7 @@ class CatalogSuite extends SparkFunSuite { // with if exists it should pass catalog.alterTable(testIdent, TableChange.deleteColumn(Array("missing_col"), true)) - assert(table.columns === columns) + assert(CatalogV2Util.clearIds(table.columns) === columns) } test("alterTable: delete missing nested column fails") { @@ -852,7 +854,7 @@ class CatalogSuite extends SparkFunSuite { .build() val table = catalog.createTable(testIdent, tableInfo) - assert(table.columns === tableColumns) + assert(CatalogV2Util.clearIds(table.columns) === tableColumns) checkError( exception = intercept[SparkIllegalArgumentException] { @@ -863,7 +865,7 @@ class CatalogSuite extends SparkFunSuite { // with if exists it should pass catalog.alterTable(testIdent, TableChange.deleteColumn(Array("point", "z"), true)) - assert(table.columns === tableColumns) + assert(CatalogV2Util.clearIds(table.columns) === tableColumns) } test("alterTable: table does not exist") { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ColumnSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ColumnSuite.scala new file mode 100644 index 0000000000000..8296e708d026a --- /dev/null +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ColumnSuite.scala @@ -0,0 +1,236 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connector.catalog + +import org.apache.spark.SparkFunSuite +import org.apache.spark.SparkIllegalArgumentException +import org.apache.spark.sql.catalyst.parser.CatalystSqlParser +import org.apache.spark.sql.catalyst.plans.logical.ColumnDefinition +import org.apache.spark.sql.catalyst.util.FieldMetadataUtils.FIELD_ID_METADATA_KEY +import org.apache.spark.sql.connector.expressions.LiteralValue +import org.apache.spark.sql.types.{IntegerType, MetadataBuilder, StringType, StructField, StructType} + +class ColumnSuite extends SparkFunSuite { + + private val intLiteral = LiteralValue(42, IntegerType) + private val defaultValue = new ColumnDefaultValue("42", intLiteral) + private val identitySpec = new IdentityColumnSpec(1L, 1L, false) + + // --------------------------------------------------------------------------- + // Column.create factory overloads + // --------------------------------------------------------------------------- + + test("create(name, type) defaults: nullable, no comment, no generation expr, no metadata") { + val col = Column.create("c", IntegerType) + assert(col.name() == "c") + assert(col.dataType() == IntegerType) + assert(col.nullable()) + assert(col.comment() == null) + assert(col.defaultValue() == null) + assert(col.generationExpression() == null) + assert(col.identityColumnSpec() == null) + assert(col.metadataInJSON() == null) + assert(col.id() == null) + } + + test("create(name, type, nullable) controls nullable flag") { + assert(Column.create("c", IntegerType, false).nullable() == false) + assert(Column.create("c", IntegerType, true).nullable() == true) + } + + test("create(name, type, nullable, comment, metadataInJSON)") { + val col = Column.create("c", StringType, false, "a comment", """{"key":"val"}""") + assert(col.name() == "c") + assert(col.dataType() == StringType) + assert(col.nullable() == false) + assert(col.comment() == "a comment") + assert(col.metadataInJSON() == """{"key":"val"}""") + assert(col.defaultValue() == null) + assert(col.generationExpression() == null) + assert(col.identityColumnSpec() == null) + } + + test("create(name, type, nullable, comment, defaultValue, metadataInJSON)") { + val col = Column.create("c", IntegerType, true, "doc", defaultValue, null) + assert(col.defaultValue() == defaultValue) + assert(col.generationExpression() == null) + assert(col.identityColumnSpec() == null) + } + + test("create(name, type, nullable, comment, generationExpression, metadataInJSON)") { + val col = Column.create("c", IntegerType, true, null, "a + 1", null) + assert(col.generationExpression() == "a + 1") + assert(col.defaultValue() == null) + assert(col.identityColumnSpec() == null) + } + + test("create(name, type, nullable, comment, identityColumnSpec, metadataInJSON)") { + val col = Column.create("c", IntegerType, false, null, identitySpec, null) + assert(col.identityColumnSpec() == identitySpec) + assert(col.defaultValue() == null) + assert(col.generationExpression() == null) + } + + // --------------------------------------------------------------------------- + // Column.builderFor / Builder + // --------------------------------------------------------------------------- + + test("builder defaults: nullable=true, everything else null") { + val col = Column.builderFor("c", IntegerType).build() + assert(col.name() == "c") + assert(col.dataType() == IntegerType) + assert(col.nullable()) + assert(col.comment() == null) + assert(col.defaultValue() == null) + assert(col.generationExpression() == null) + assert(col.identityColumnSpec() == null) + assert(col.metadataInJSON() == null) + assert(col.id() == null) + } + + test("builder sets all fields") { + val col = Column.builderFor("c", IntegerType) + .nullable(false) + .comment("doc") + .defaultValue(defaultValue) + .metadata("""{"k":"v"}""") + .id("abc-123") + .build() + assert(col.name() == "c") + assert(col.dataType() == IntegerType) + assert(col.nullable() == false) + assert(col.comment() == "doc") + assert(col.defaultValue() == defaultValue) + assert(col.metadataInJSON() == """{"k":"v"}""") + assert(col.id() == "abc-123") + } + + test("builder with generationExpression") { + val col = Column.builderFor("c", IntegerType) + .generationExpression("a * 2") + .build() + assert(col.generationExpression() == "a * 2") + assert(col.defaultValue() == null) + assert(col.identityColumnSpec() == null) + } + + test("builder with identityColumnSpec") { + val col = Column.builderFor("c", IntegerType) + .identityColumnSpec(identitySpec) + .build() + assert(col.identityColumnSpec() == identitySpec) + assert(col.defaultValue() == null) + assert(col.generationExpression() == null) + } + + // --------------------------------------------------------------------------- + // Builder invariants: conflicting definitions are rejected + // --------------------------------------------------------------------------- + + test("builder rejects defaultValue + generationExpression") { + val ex = intercept[SparkIllegalArgumentException] { + Column.builderFor("c", IntegerType) + .defaultValue(defaultValue) + .generationExpression("a + 1") + .build() + } + assert(ex.getMessage.contains("cannot have more than one definition")) + } + + test("builder rejects generationExpression + identityColumnSpec") { + val ex = intercept[SparkIllegalArgumentException] { + Column.builderFor("c", IntegerType) + .generationExpression("a + 1") + .identityColumnSpec(identitySpec) + .build() + } + assert(ex.getMessage.contains("cannot have more than one definition")) + } + + test("builder rejects defaultValue + identityColumnSpec") { + val ex = intercept[SparkIllegalArgumentException] { + Column.builderFor("c", IntegerType) + .defaultValue(defaultValue) + .identityColumnSpec(identitySpec) + .build() + } + assert(ex.getMessage.contains("cannot have more than one definition")) + } + + test("builder rejects all three definitions set simultaneously") { + val ex = intercept[SparkIllegalArgumentException] { + Column.builderFor("c", IntegerType) + .defaultValue(defaultValue) + .generationExpression("a + 1") + .identityColumnSpec(identitySpec) + .build() + } + assert(ex.getMessage.contains("cannot have more than one definition")) + } + + test("builder error message names the column") { + val ex = intercept[SparkIllegalArgumentException] { + Column.builderFor("my_column", IntegerType) + .defaultValue(defaultValue) + .generationExpression("x") + .build() + } + assert(ex.getMessage.contains("my_column")) + } + + // --------------------------------------------------------------------------- + // newBuilder rejects null name / type + // --------------------------------------------------------------------------- + + test("newBuilder rejects null name") { + intercept[NullPointerException] { + Column.builderFor(null, IntegerType) + } + } + + test("newBuilder rejects null dataType") { + intercept[NullPointerException] { + Column.builderFor("c", null) + } + } + + // --------------------------------------------------------------------------- + // ColumnDefinition.fromV1Column - metadata cleaning + // --------------------------------------------------------------------------- + + test("fromV1Column strips FIELD_ID_METADATA_KEY from metadata") { + val metadata = new MetadataBuilder() + .putString(FIELD_ID_METADATA_KEY, "42") + .putString("custom_key", "custom_value") + .build() + val field = StructField("col", IntegerType, nullable = true, metadata) + val colDef = ColumnDefinition.fromV1Column(field, CatalystSqlParser) + assert(!colDef.metadata.contains(FIELD_ID_METADATA_KEY)) + assert(colDef.metadata.contains("custom_key")) + } + + test("fromV1Column strips nested field IDs from struct dataType") { + val nestedType = StructType(Array( + StructField("x", IntegerType).withId("source-id"), + StructField("y", IntegerType).withId("source-id"))) + val field = StructField("col", nestedType) + val colDef = ColumnDefinition.fromV1Column(field, CatalystSqlParser) + val resultType = colDef.dataType.asInstanceOf[StructType] + resultType.fields.foreach { f => assert(f.id.isEmpty) } + } +} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ComposedColumnIdTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ComposedColumnIdTableCatalog.scala deleted file mode 100644 index 64488a76db7f3..0000000000000 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ComposedColumnIdTableCatalog.scala +++ /dev/null @@ -1,290 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.sql.connector.catalog - -import java.util.Locale -import java.util.concurrent.ConcurrentHashMap - -import scala.collection.mutable - -import org.apache.spark.sql.internal.connector.ColumnImpl -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} - -/** - * An [[InMemoryTableCatalog]] that tracks IDs at every nesting level - * (struct fields, array elements, map keys/values) and encodes the full - * subtree into each top-level [[Column.id]] string. - * - * This demonstrates how a connector that wants to detect nested changes - * can encode nested IDs into the top-level [[Column.id]] string. - * Any nested change (drop+re-add a struct field, etc.) produces a - * different encoded top-level string, so [[V2TableUtil.validateColumnIds]] - * detects it without Spark needing to traverse below the top level. - * - * Nested positions are keyed by ordinal path (`Seq[Int]`), not by field - * name. This matches Delta/Iceberg semantics where rename preserves the - * column ID: a renamed field stays at the same ordinal position, so the - * composed string is unchanged and schema validation catches the rename - * via the differing [[StructType]]. - * - * Example: for a column `person STRUCT` with - * root ID 5 and nested field IDs position 0 (name) = 10, - * position 1 (age) = 11, the composed [[Column.id]] string is - * `"5[0:10,1:11]"`. If `age` is dropped and re-added, the new age gets - * ID 12, producing `"5[0:10,1:12]"`. Spark sees different strings and - * fires `COLUMN_ID_MISMATCH`. - */ -class ComposedColumnIdTableCatalog extends InMemoryTableCatalog { - - // Per-table nested ID maps. - // Structure: tableIdentifier -> (columnName -> nestedFieldIdMap) - // where nestedFieldIdMap maps an ordinal path to its assigned ID. - // - // For column `person STRUCT>`: - // "person" -> { - // Seq(0) -> 10, // name - // Seq(1) -> 11, // addr - // Seq(1, 0) -> 12 // addr.city - // } - private val nestedIdMaps = - new ConcurrentHashMap[Identifier, mutable.Map[String, mutable.Map[Seq[Int], Long]]]() - - // Bare (uncomposed) root IDs, tracked separately to avoid double-encoding. - // Structure: tableIdentifier -> (columnName -> bareRootIdString) - private val rootIds = - new ConcurrentHashMap[Identifier, mutable.Map[String, String]]() - - override def createTable( - ident: Identifier, info: TableInfo): Table = { - val table = super.createTable(ident, info).asInstanceOf[InMemoryTable] - val allColumnNestedIds = mutable.Map[String, mutable.Map[Seq[Int], Long]]() - val allRootIds = mutable.Map[String, String]() - - val composedColumns: Array[Column] = table.columns().map { column => - val nestedFieldIds = mutable.Map[Seq[Int], Long]() - assignNestedIds(column.dataType(), parentPath = Seq.empty, nestedFieldIds) - val columnName = column.name().toLowerCase(Locale.ROOT) - allColumnNestedIds(columnName) = nestedFieldIds - allRootIds(columnName) = column.id() - val composedId = encodeComposedId(column.id(), nestedFieldIds) - column.asInstanceOf[ColumnImpl].copy(id = composedId): Column - } - - nestedIdMaps.put(ident, allColumnNestedIds) - rootIds.put(ident, allRootIds) - - val composedTable = new InMemoryTable( - table.name, - composedColumns, - table.partitioning, - table.properties, - table.constraints, - id = table.id) - composedTable.alterTableWithData(table.data, table.schema) - composedTable.setVersionAndValidatedVersionFrom(table) - tables.put(ident, composedTable) - composedTable - } - - override def alterTable(ident: Identifier, changes: TableChange*): Table = { - val oldTable = loadTable(ident).asInstanceOf[InMemoryTable] - val oldColumnNestedIds = Option(nestedIdMaps.get(ident)) - .getOrElse(mutable.Map[String, mutable.Map[Seq[Int], Long]]()) - val oldRootIds = Option(rootIds.get(ident)) - .getOrElse(mutable.Map[String, String]()) - - val alteredTable = super.alterTable(ident, changes: _*).asInstanceOf[InMemoryTable] - - val allColumnNestedIds = mutable.Map[String, mutable.Map[Seq[Int], Long]]() - val allRootIds = mutable.Map[String, String]() - val composedColumns: Array[Column] = alteredTable.columns().map { newColumn => - val columnName = newColumn.name().toLowerCase(Locale.ROOT) - val oldNestedFieldIds = - oldColumnNestedIds.getOrElse(columnName, mutable.Map[Seq[Int], Long]()) - - // Find the old column to compare data types for merging nested IDs - val oldColumnOpt = oldTable.columns() - .find(oldCol => oldCol.name().toLowerCase(Locale.ROOT) == columnName) - - val newNestedFieldIds = oldColumnOpt match { - case Some(oldColumn) => - // Column existed before: preserve IDs for positions that still exist, - // assign fresh IDs for new positions (e.g. a re-added nested field) - mergeNestedIds(oldNestedFieldIds, oldColumn.dataType(), newColumn.dataType()) - case None => - // Brand new column: assign fresh IDs to all nested positions - val freshIds = mutable.Map[Seq[Int], Long]() - assignNestedIds(newColumn.dataType(), parentPath = Seq.empty, freshIds) - freshIds - } - - allColumnNestedIds(columnName) = newNestedFieldIds - - // super.alterTable preserves IDs by name, so newColumn.id() is - // the previously composed string (e.g. "5[0:10,1:11]"). Passing - // that to encodeComposedId would produce "5[0:10,1:11][0:10,1:12]" - // instead of "5[0:10,1:12]". Use the original root ID (e.g. "5") - // from rootIds instead; fall back to newColumn.id() only for - // genuinely new columns whose ID is a fresh numeric string. - val rootId = oldRootIds.getOrElse(columnName, newColumn.id()) - allRootIds(columnName) = rootId - val composedId = encodeComposedId(rootId, newNestedFieldIds) - newColumn.asInstanceOf[ColumnImpl].copy(id = composedId): Column - } - - nestedIdMaps.put(ident, allColumnNestedIds) - rootIds.put(ident, allRootIds) - - val composedTable = new InMemoryTable( - alteredTable.name, - composedColumns, - alteredTable.partitioning, - alteredTable.properties, - alteredTable.constraints, - id = alteredTable.id) - composedTable.alterTableWithData(alteredTable.data, alteredTable.schema) - composedTable.setVersionAndValidatedVersionFrom(alteredTable) - tables.put(ident, composedTable) - composedTable - } - - /** - * Recursively assigns fresh IDs to every nested position in a data type: - * struct fields, array elements, map keys, and map values. - * - * Each position is identified by an ordinal path from the column root: - * - * `STRUCT>` produces: - * - Seq(0) -> id1 (name, position 0) - * - Seq(1) -> id2 (addr, position 1) - * - Seq(1, 0) -> id3 (addr.city, position 0 within addr) - * - * `ARRAY>` produces: - * - Seq(0) -> id1 (element, position 0) - * - Seq(0, 0) -> id2 (element.x, position 0 within element) - * - * `MAP>` produces: - * - Seq(0) -> id1 (key, position 0) - * - Seq(1) -> id2 (value, position 1) - * - Seq(1, 0) -> id3 (value.v, position 0 within value) - */ - private def assignNestedIds( - dataType: DataType, - parentPath: Seq[Int], - nestedFieldIds: mutable.Map[Seq[Int], Long]): Unit = { - dataType match { - case structType: StructType => - structType.fields.zipWithIndex.foreach { case (field, idx) => - val fieldPath = parentPath :+ idx - nestedFieldIds(fieldPath) = InMemoryBaseTable.nextColumnId() - assignNestedIds(field.dataType, fieldPath, nestedFieldIds) - } - case ArrayType(elementType, _) => - val elementPath = parentPath :+ 0 - nestedFieldIds(elementPath) = InMemoryBaseTable.nextColumnId() - assignNestedIds(elementType, elementPath, nestedFieldIds) - case MapType(keyType, valueType, _) => - val keyPath = parentPath :+ 0 - nestedFieldIds(keyPath) = InMemoryBaseTable.nextColumnId() - assignNestedIds(keyType, keyPath, nestedFieldIds) - val valuePath = parentPath :+ 1 - nestedFieldIds(valuePath) = InMemoryBaseTable.nextColumnId() - assignNestedIds(valueType, valuePath, nestedFieldIds) - case _ => // primitive types have no nested structure - } - } - - /** - * Merges nested IDs from old to new: preserves IDs for ordinal positions - * that exist in both old and new types, assigns fresh IDs for new positions. - * - * For example, if the old type is `STRUCT` with - * IDs {Seq(0)->10, Seq(1)->11}, and the new type is - * `STRUCT` after drop+re-add of `age`, then `age` - * gets a fresh ID 12 because its position was removed and re-added, while - * `name` keeps ID 10. - */ - private def mergeNestedIds( - oldFieldIds: mutable.Map[Seq[Int], Long], - oldType: DataType, - newType: DataType): mutable.Map[Seq[Int], Long] = { - val mergedFieldIds = mutable.Map[Seq[Int], Long]() - walkAndMerge(newType, parentPath = Seq.empty, mergedFieldIds, oldFieldIds) - mergedFieldIds - } - - /** - * Walks the new data type and for each nested position, either preserves - * the old ID (if the ordinal path existed before) or assigns a fresh one. - */ - private def walkAndMerge( - dataType: DataType, - parentPath: Seq[Int], - mergedFieldIds: mutable.Map[Seq[Int], Long], - oldFieldIds: mutable.Map[Seq[Int], Long]): Unit = { - dataType match { - case structType: StructType => - structType.fields.zipWithIndex.foreach { case (field, idx) => - val fieldPath = parentPath :+ idx - mergedFieldIds(fieldPath) = - oldFieldIds.getOrElse(fieldPath, InMemoryBaseTable.nextColumnId()) - walkAndMerge(field.dataType, fieldPath, mergedFieldIds, oldFieldIds) - } - case ArrayType(elementType, _) => - val elementPath = parentPath :+ 0 - mergedFieldIds(elementPath) = - oldFieldIds.getOrElse(elementPath, InMemoryBaseTable.nextColumnId()) - walkAndMerge(elementType, elementPath, mergedFieldIds, oldFieldIds) - case MapType(keyType, valueType, _) => - val keyPath = parentPath :+ 0 - mergedFieldIds(keyPath) = - oldFieldIds.getOrElse(keyPath, InMemoryBaseTable.nextColumnId()) - walkAndMerge(keyType, keyPath, mergedFieldIds, oldFieldIds) - val valuePath = parentPath :+ 1 - mergedFieldIds(valuePath) = - oldFieldIds.getOrElse(valuePath, InMemoryBaseTable.nextColumnId()) - walkAndMerge(valueType, valuePath, mergedFieldIds, oldFieldIds) - case _ => - } - } - - /** - * Encodes a root ID and its nested field IDs into a single deterministic string. - * Format: `rootId[path1:id1,path2:id2,...]` with paths sorted - * lexicographically by their dot-joined ordinal representation. - * - * Example: column `person STRUCT` with root ID "5" - * and nested field IDs {Seq(0)->10, Seq(1)->11} encodes as: - * `"5[0:10,1:11]"` - * - * If the column has no nested fields (e.g. `INT`), returns just the root ID. - */ - private def encodeComposedId( - rootId: String, - nestedFieldIds: mutable.Map[Seq[Int], Long]): String = { - if (nestedFieldIds.isEmpty) { - rootId - } else { - val sortedEntries = nestedFieldIds.toSeq.sortBy(_._1.mkString(".")) - val encoded = sortedEntries.map { case (fieldPath, fieldId) => - s"${fieldPath.mkString(".")}:$fieldId" - }.mkString(",") - s"$rootId[$encoded]" - } - } -} diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala index d63e3095a2ef4..2902bef2cda01 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/ConstraintSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.connector.catalog -import org.apache.spark.SparkFunSuite +import org.apache.spark.{SparkFunSuite, SparkIllegalArgumentException} import org.apache.spark.sql.connector.catalog.constraints.Constraint import org.apache.spark.sql.connector.catalog.constraints.Constraint.ValidationStatus import org.apache.spark.sql.connector.expressions.{Expression, FieldReference, LiteralValue, NamedReference} @@ -37,12 +37,13 @@ class ConstraintSuite extends SparkFunSuite { assert(con1.validationStatus() == ValidationStatus.VALID) val con2 = Constraint.check("con2") - .predicate( - new Predicate( - "=", - Array[Expression]( - FieldReference(Seq("a", "b.c", "d")), - LiteralValue(1, IntegerType)))) + .predicateSql("a.`b.c`.d = 1") + .predicate( + new Predicate( + "=", + Array[Expression]( + FieldReference(Seq("a", "b.c", "d")), + LiteralValue(1, IntegerType)))) .enforced(false) .validationStatus(ValidationStatus.VALID) .rely(true) @@ -70,6 +71,22 @@ class ConstraintSuite extends SparkFunSuite { assert(con4.validationStatus() == ValidationStatus.UNVALIDATED) } + test("CHECK constraint requires predicateSql") { + // predicateSql is the canonical representation of a CHECK condition and must always be present, + // even when a structured predicate is provided. + val noCondition = Constraint.check("con1") + val predicateOnly = Constraint.check("con2").predicate( + new Predicate( + "=", + Array[Expression](FieldReference(Seq("a")), LiteralValue(1, IntegerType)))) + Seq(noCondition, predicateOnly).foreach { builder => + checkError( + exception = intercept[SparkIllegalArgumentException](builder.build()), + condition = "INTERNAL_ERROR", + parameters = Map("message" -> "Predicate SQL can't be null in CHECK")) + } + } + test("UNIQUE constraint toDDL") { val con1 = Constraint.unique( "con1", diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala index 8bfcfc020fa12..78775d80cfa6d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryBaseTable.scala @@ -20,7 +20,6 @@ package org.apache.spark.sql.connector.catalog import java.time.{Instant, ZoneId} import java.time.temporal.ChronoUnit import java.util -import java.util.Locale import java.util.Objects import java.util.OptionalLong import java.util.concurrent.atomic.AtomicLong @@ -75,15 +74,9 @@ abstract class InMemoryBaseTable( // Stores the table version validated during the last `ALTER TABLE ... ADD CONSTRAINT` operation. private var validatedTableVersion: String = null - // Assign column IDs to columns that do not have one. - // This simulates connectors that support column identity tracking. - private var tableColumns: Array[Column] = initialColumns.map { c => - if (c.id() == null) { - c.asInstanceOf[ColumnImpl].copy(id = InMemoryBaseTable.nextColumnId().toString) - } else { - c - } - } + // Assign column IDs to columns that do not have one, including nested struct fields within + // arrays and maps. This simulates connectors that support column identity tracking. + private var tableColumns: Array[Column] = InMemoryBaseTable.assignMissingIds(initialColumns) override def columns(): Array[Column] = tableColumns @@ -780,10 +773,8 @@ abstract class InMemoryBaseTable( val mergedSchema = mergeSchema( oldType = CatalogV2Util.v2ColumnsToStructType(columns()), newType = newSchema) - val newColumns = CatalogV2Util.structTypeToV2Columns(mergedSchema) tableColumns = InMemoryBaseTable.assignMissingIds( - oldColumns = columns(), - newColumns = newColumns) + CatalogV2Util.structTypeToV2Columns(mergedSchema)) writer } @@ -916,32 +907,64 @@ abstract class InMemoryBaseTable( object InMemoryBaseTable { private val columnIdGlobalCounter = new AtomicLong(0) def nextColumnId(): Long = columnIdGlobalCounter.incrementAndGet() + def nextColumnIdString(): String = nextColumnId().toString - private def normalize(name: String): String = name.toLowerCase(Locale.ROOT) + // SQL conf key that enables column ID assignment + val ASSIGN_COLUMN_IDS = "spark.sql.test.inMemoryTable.assignColumnIds" /** - * Preserves column IDs from `oldColumns` when the column name matches, - * and assigns new IDs to columns that do not already have one. + * Assigns fresh IDs to any top-level column or nested struct field that does not already + * have one. Recurses into struct fields within ArrayType and MapType so that every field + * at every depth gets an ID. * - * IDs are preserved across type changes, keeping the same column ID through type - * widening and nested field additions. [[TypeChangeResetsColIdTableCatalog]] overrides - * this behavior for testing scenarios where type changes should produce a new ID. + * Existing IDs are preserved: Column -> StructType -> Column round-trip encodes them in + * StructField metadata (see StructField.FIELD_ID_METADATA_KEY), so only genuinely new fields + * arrive here without an ID. */ - def assignMissingIds( - oldColumns: Array[Column], - newColumns: Array[Column]): Array[Column] = { - newColumns.map { newCol => - oldColumns.find(c => normalize(c.name()) == normalize(newCol.name())) match { - case Some(oldCol) if oldCol.id() != null => - newCol.asInstanceOf[ColumnImpl].copy(id = oldCol.id()) - case _ if newCol.id() == null => - newCol.asInstanceOf[ColumnImpl].copy(id = nextColumnId().toString) - case _ => - newCol + def assignMissingIds(columns: Array[Column]): Array[Column] = { + if (!SQLConf.get.getConfString(ASSIGN_COLUMN_IDS, "false").toBoolean) return columns + columns.map { col => + val impl = col.asInstanceOf[ColumnImpl] + val colWithId = if (col.id == null) impl.copy(id = nextColumnIdString()) else impl + val updatedType = assignFieldIds(colWithId.dataType) + if (updatedType ne colWithId.dataType) { + colWithId.copy(dataType = updatedType) + } else { + colWithId } } } + private def assignFieldIds(dataType: DataType): DataType = dataType match { + case s: StructType => + val newFields = s.fields.map { field => + val fieldWithId = if (field.id.isEmpty) field.withId(nextColumnIdString()) else field + val updatedType = assignFieldIds(fieldWithId.dataType) + if (updatedType ne fieldWithId.dataType) { + fieldWithId.copy(dataType = updatedType) + } else { + fieldWithId + } + } + if (newFields.zip(s.fields).forall { case (n, e) => n eq e }) s else StructType(newFields) + + case a: ArrayType => + val updatedElement = assignFieldIds(a.elementType) + if (updatedElement ne a.elementType) a.copy(elementType = updatedElement) else a + + case m: MapType => + val updatedKeyType = assignFieldIds(m.keyType) + val updatedValueType = assignFieldIds(m.valueType) + if ((updatedKeyType ne m.keyType) || (updatedValueType ne m.valueType)) { + m.copy(keyType = updatedKeyType, valueType = updatedValueType) } + else { + m + } + + case other => + other + } + val SIMULATE_FAILED_WRITE_OPTION = "spark.sql.test.simulateFailedWrite" def extractValue( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableViewCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRelationCatalog.scala similarity index 83% rename from sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableViewCatalog.scala rename to sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRelationCatalog.scala index a3506938dea7c..d6f526b30ce09 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableViewCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRelationCatalog.scala @@ -26,46 +26,46 @@ import org.apache.spark.sql.catalyst.analysis.{NamespaceAlreadyExistsException, import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * An in-memory [[TableViewCatalog]] for tests. Tables and views share a single keyspace per - * the [[TableViewCatalog]] contract; the stored value's runtime type ([[TableInfo]] vs - * [[ViewInfo]]) is the kind discriminator. Also implements [[SupportsNamespaces]] with a + * An in-memory [[RelationCatalog]] for tests. Tables and views share a single keyspace per + * the [[RelationCatalog]] contract; the stored [[Relation]]'s runtime type ([[Table]] vs + * [[View]]) is the kind discriminator. Tables are stored as a [[DelegatingTable]] wrapping the + * [[TableInfo]] passed to `createTable`. Also implements [[SupportsNamespaces]] with a * minimal namespace store, so analyzer rules that read namespace metadata (e.g. * `ApplyDefaultCollation` consulting `loadNamespaceMetadata` for `PROP_COLLATION`) work * uniformly with the v1 session catalog. Suitable for any test suite that wants to exercise * v2 view DDL or inspection commands against a non-session catalog. */ -class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces { +class InMemoryRelationCatalog extends RelationCatalog with SupportsNamespaces { private val store = - new ConcurrentHashMap[(Seq[String], String), TableInfo]() + new ConcurrentHashMap[(Seq[String], String), Relation]() private val namespaces = new ConcurrentHashMap[Seq[String], util.Map[String, String]]() - override def loadTableOrView(ident: Identifier): Table = { + override def loadRelation(ident: Identifier): Relation = { val key = (ident.namespace().toSeq, ident.name()) - Option(store.get(key)) - .map(new MetadataTable(_, ident.toString)) - .getOrElse(throw new NoSuchTableException(ident)) + Option(store.get(key)).getOrElse(throw new NoSuchTableException(ident)) } // ----- TableCatalog ----------------------------------------------------------------- override def createTable(ident: Identifier, info: TableInfo): Table = { val key = (ident.namespace().toSeq, ident.name()) - if (store.putIfAbsent(key, info) != null) { + val table = new DelegatingTable(info, ident.toString) + if (store.putIfAbsent(key, table) != null) { throw new TableAlreadyExistsException(ident) } - new MetadataTable(info, ident.toString) + table } override def alterTable(ident: Identifier, changes: TableChange*): Table = { - throw new UnsupportedOperationException("alterTable not supported on InMemoryTableViewCatalog") + throw new UnsupportedOperationException("alterTable not supported on InMemoryRelationCatalog") } override def dropTable(ident: Identifier): Boolean = { val key = (ident.namespace().toSeq, ident.name()) val existing = store.get(key) - if (existing == null || existing.isInstanceOf[ViewInfo]) return false + if (existing == null || existing.isInstanceOf[View]) return false store.remove(key) != null } @@ -73,7 +73,7 @@ class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces val oldKey = (oldIdent.namespace().toSeq, oldIdent.name()) val newKey = (newIdent.namespace().toSeq, newIdent.name()) val existing = store.get(oldKey) - if (existing == null || existing.isInstanceOf[ViewInfo]) { + if (existing == null || existing.isInstanceOf[View]) { throw new NoSuchTableException(oldIdent) } if (store.putIfAbsent(newKey, existing) != null) { @@ -86,7 +86,7 @@ class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces val target = namespace.toSeq val ids = new java.util.ArrayList[Identifier]() store.forEach { (key, info) => - if (key._1 == target && !info.isInstanceOf[ViewInfo]) { + if (key._1 == target && !info.isInstanceOf[View]) { ids.add(Identifier.of(key._1.toArray, key._2)) } } @@ -99,14 +99,14 @@ class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces val target = namespace.toSeq val ids = new java.util.ArrayList[Identifier]() store.forEach { (key, info) => - if (key._1 == target && info.isInstanceOf[ViewInfo]) { + if (key._1 == target && info.isInstanceOf[View]) { ids.add(Identifier.of(key._1.toArray, key._2)) } } ids.toArray(new Array[Identifier](0)) } - override def createView(ident: Identifier, info: ViewInfo): ViewInfo = { + override def createView(ident: Identifier, info: View): View = { val key = (ident.namespace().toSeq, ident.name()) if (store.putIfAbsent(key, info) != null) { throw new ViewAlreadyExistsException(ident) @@ -114,10 +114,10 @@ class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces info } - override def replaceView(ident: Identifier, info: ViewInfo): ViewInfo = { + override def replaceView(ident: Identifier, info: View): View = { val key = (ident.namespace().toSeq, ident.name()) val existing = store.get(key) - if (existing == null || !existing.isInstanceOf[ViewInfo]) { + if (existing == null || !existing.isInstanceOf[View]) { throw new NoSuchViewException(ident) } store.put(key, info) @@ -127,7 +127,7 @@ class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces override def dropView(ident: Identifier): Boolean = { val key = (ident.namespace().toSeq, ident.name()) val existing = store.get(key) - if (existing == null || !existing.isInstanceOf[ViewInfo]) return false + if (existing == null || !existing.isInstanceOf[View]) return false store.remove(key) != null } @@ -135,7 +135,7 @@ class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces val oldKey = (oldIdent.namespace().toSeq, oldIdent.name()) val newKey = (newIdent.namespace().toSeq, newIdent.name()) val existing = store.get(oldKey) - if (existing == null || !existing.isInstanceOf[ViewInfo]) { + if (existing == null || !existing.isInstanceOf[View]) { throw new NoSuchViewException(oldIdent) } if (store.putIfAbsent(newKey, existing) != null) { @@ -219,16 +219,16 @@ class InMemoryTableViewCatalog extends TableViewCatalog with SupportsNamespaces // Test-only accessors -------------------------------------------------------------- /** Returns the stored entry (table or view) for the identifier, or throws if missing. */ - def getStoredInfo(namespace: Array[String], name: String): TableInfo = { + def getStoredInfo(namespace: Array[String], name: String): Relation = { Option(store.get((namespace.toSeq, name))).getOrElse { throw new NoSuchTableException(Identifier.of(namespace, name)) } } - /** Returns the stored ViewInfo, or throws if the entry is missing or is not a view. */ - def getStoredView(namespace: Array[String], name: String): ViewInfo = { + /** Returns the stored View, or throws if the entry is missing or is not a view. */ + def getStoredView(namespace: Array[String], name: String): View = { getStoredInfo(namespace, name) match { - case v: ViewInfo => v + case v: View => v case _ => throw new IllegalStateException( s"stored entry at ${namespace.mkString(".")}.$name is not a view") } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala index cdc59ff637c0b..e420e84c76d36 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryRowLevelOperationTableCatalog.scala @@ -92,8 +92,7 @@ class InMemoryRowLevelOperationTableCatalog } val columnsWithIds = InMemoryBaseTable.assignMissingIds( - oldColumns = table.columns(), - newColumns = CatalogV2Util.structTypeToV2Columns(schema)) + CatalogV2Util.structTypeToV2Columns(schema)) val newTable = InMemoryRowLevelOperationTable.withColumns( name = table.name, diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala index c9a6c4acfa014..bb137ba4830df 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/InMemoryTableCatalog.scala @@ -207,8 +207,7 @@ class BasicInMemoryTableCatalog extends TableCatalog { table.increaseVersion() val currentVersion = table.version() val columnsWithIds = InMemoryBaseTable.assignMissingIds( - oldColumns = table.columns(), - newColumns = CatalogV2Util.structTypeToV2Columns(schema)) + CatalogV2Util.structTypeToV2Columns(schema)) val newTable = table match { case _: InMemoryTable => new InMemoryTable( diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullColumnIdInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullColumnIdInMemoryTableCatalog.scala index c26ce263c1f8b..fead2e6fba79d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullColumnIdInMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullColumnIdInMemoryTableCatalog.scala @@ -30,9 +30,9 @@ import org.apache.spark.sql.internal.connector.ColumnImpl * override [[columns]] to strip IDs. Data is copied from the table * created by the parent [[InMemoryTableCatalog]]. * - * When column IDs are null, [[V2TableUtil.validateColumnIds]] + * When field IDs are null, field ID validation in [[org.apache.spark.sql.util.SchemaUtils]] * skips validation entirely, meaning drop/re-add of a column is NOT - * detected via column IDs. + * detected via field IDs. */ class NullColumnIdInMemoryTableCatalog extends InMemoryTableCatalog { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdAndNullColumnIdInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdAndNullColumnIdInMemoryTableCatalog.scala index df7964f63b855..c1d45e42fc467 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdAndNullColumnIdInMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdAndNullColumnIdInMemoryTableCatalog.scala @@ -27,7 +27,7 @@ import org.apache.spark.sql.internal.connector.ColumnImpl * connectors that support neither table nor column identity tracking. * * When both IDs are null, neither the table identity check in [[V2TableRefreshUtil]] - * nor [[V2TableUtil.validateColumnIds]] fires, so drop/recreate of a table or + * nor the column schema check fires, so drop/recreate of a table or * drop/re-add of a column goes undetected. */ class NullTableIdAndNullColumnIdInMemoryTableCatalog extends InMemoryTableCatalog { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdInMemoryTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdInMemoryTableCatalog.scala index 391eb619535f5..ae079fac17141 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdInMemoryTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/NullTableIdInMemoryTableCatalog.scala @@ -29,7 +29,7 @@ package org.apache.spark.sql.connector.catalog * This is to test the scenario where connectors do not implement * table IDs but do implement column IDs. In this scenario, column * IDs assigned by [[InMemoryBaseTable]] still differ after recreate, - * so [[V2TableUtil.validateColumnIds]] catches the schema change. + * so field ID validation in [[org.apache.spark.sql.util.SchemaUtils]] catches the schema change. */ class NullTableIdInMemoryTableCatalog extends InMemoryTableCatalog { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TypeChangeResetsColIdTableCatalog.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TypeChangeResetsColIdTableCatalog.scala index d68f2e62b1365..83effcef2b266 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TypeChangeResetsColIdTableCatalog.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/TypeChangeResetsColIdTableCatalog.scala @@ -23,9 +23,8 @@ import org.apache.spark.sql.internal.connector.ColumnImpl /** * An [[InMemoryTableCatalog]] that assigns fresh column IDs when the - * column's data type changes. This is the inverse of the default - * [[InMemoryBaseTable.assignMissingIds]] behavior, which preserves IDs - * across type changes. + * column's data type changes, overriding the default behavior where type + * changes preserve the existing column ID. * * Use this catalog for tests that need a type change to produce a new * column ID (e.g., verifying that adding a nested field to a container diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2TableUtilSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2TableUtilSuite.scala index c02c517ff546b..278903299f056 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2TableUtilSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/V2TableUtilSuite.scala @@ -188,7 +188,11 @@ class V2TableUtilSuite extends SparkFunSuite { col("address", currentStructType, nullable = true)) val table = TestTableWithMetadataSupport("test", currentCols) - val errors = V2TableUtil.validateCapturedColumns(table, originCols.toSeq) + val errors = V2TableUtil.validateCapturedColumns( + table, + originCols.toSeq, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head.contains("`address`.`city` type has changed from STRING to INT")) } @@ -202,7 +206,11 @@ class V2TableUtilSuite extends SparkFunSuite { metaCol("index", IntegerType, nullable = false)) val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.isEmpty, "No changes should produce no errors") } @@ -213,7 +221,11 @@ class V2TableUtilSuite extends SparkFunSuite { metaCol("index", StringType, nullable = false)) // changed to StringType val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head == "`index` type has changed from INT to STRING") } @@ -225,7 +237,11 @@ class V2TableUtilSuite extends SparkFunSuite { metaCol("index", IntegerType, nullable = true)) // now nullable val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head == "`index` is nullable now") } @@ -235,7 +251,11 @@ class V2TableUtilSuite extends SparkFunSuite { val currentMetaCols = Array(metaCol("index", IntegerType, nullable = false)) // now NOT NULL val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head == "`index` is no longer nullable") } @@ -245,7 +265,11 @@ class V2TableUtilSuite extends SparkFunSuite { val currentMetaCols = Array.empty[MetadataColumn] // no metadata columns val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head == "`index` INT has been removed") } @@ -255,7 +279,11 @@ class V2TableUtilSuite extends SparkFunSuite { val table = TestTable("test", Array(col("id", LongType, nullable = true))) val originMetaCols = Seq(metaCol("index", IntegerType, nullable = false)) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head == "`index` INT NOT NULL has been removed") } @@ -268,7 +296,11 @@ class V2TableUtilSuite extends SparkFunSuite { metaCol("_partition", IntegerType, nullable = false)) // type changed from StringType val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 2) assert(errors.exists(e => e.contains("_partition") && e.contains("type has changed"))) assert(errors.exists(e => e.contains("index") && e.contains("removed"))) @@ -279,7 +311,11 @@ class V2TableUtilSuite extends SparkFunSuite { val currentMetaCols = Array(metaCol("INDEX", IntegerType, nullable = true)) // uppercase val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.isEmpty, "Case insensitive comparison should match") } @@ -294,7 +330,11 @@ class V2TableUtilSuite extends SparkFunSuite { val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) val e = intercept[AnalysisException] { - V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) } assert(e.message.contains("Choose another name or rename the existing column")) } @@ -304,7 +344,11 @@ class V2TableUtilSuite extends SparkFunSuite { val currentMetaCols = Array.empty[MetadataColumn] val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.isEmpty, "No metadata columns should produce no errors") } @@ -316,7 +360,11 @@ class V2TableUtilSuite extends SparkFunSuite { val currentMetaCols = Array(metaCol("_partition", structType, nullable = false)) val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.isEmpty) } @@ -331,7 +379,11 @@ class V2TableUtilSuite extends SparkFunSuite { val currentMetaCols = Array(metaCol("_partition", currentStructType, nullable = false)) val table = TestTableWithMetadataSupport("test", Array.empty, currentMetaCols) - val errors = V2TableUtil.validateCapturedMetadataColumns(table, originMetaCols) + val errors = V2TableUtil.validateCapturedMetadataColumns( + table, + originMetaCols, + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head.contains("`_partition`.`bucket` type has changed from INT to STRING")) } @@ -364,7 +416,8 @@ class V2TableUtilSuite extends SparkFunSuite { val errors = V2TableUtil.validateCapturedMetadataColumns( currentTable, relation, - mode = PROHIBIT_CHANGES) + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.size == 1) assert(errors.head.contains("`_partition` type has changed")) } @@ -389,7 +442,8 @@ class V2TableUtilSuite extends SparkFunSuite { val errors = V2TableUtil.validateCapturedMetadataColumns( currentTable, relation, - mode = PROHIBIT_CHANGES) + mode = PROHIBIT_CHANGES, + checkIds = true) assert(errors.isEmpty) } @@ -550,6 +604,152 @@ class V2TableUtilSuite extends SparkFunSuite { assert(errors.head == "`person`.`attrs`.`value` type has changed from INT to BIGINT") } + // --------------------------------------------------------------------------- + // Field ID change error messages + // --------------------------------------------------------------------------- + + test("validateCapturedColumns - top-level column ID changed") { + val originCols = Array( + colWithId("id", LongType, nullable = false, id = "1"), + colWithId("name", StringType, nullable = true, id = "2")) + val currentCols = Array( + colWithId("id", LongType, nullable = false, id = "99"), // ID changed + colWithId("name", StringType, nullable = true, id = "2")) + val table = TestTableWithMetadataSupport("test", currentCols) + + val errors = validateCapturedColumns(table, originCols) + assert(errors.size == 1) + assert(errors.head == "`id` field ID has changed from 1 to 99") + } + + test("validateCapturedColumns - nested struct field ID changed") { + val originStruct = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("age", IntegerType).withId("11"))) + val originCols = Array(col("person", originStruct, nullable = true)) + + val currentStruct = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("age", IntegerType).withId("99"))) // age ID changed + val currentCols = Array(col("person", currentStruct, nullable = true)) + val table = TestTableWithMetadataSupport("test", currentCols) + + val errors = validateCapturedColumns(table, originCols) + assert(errors.size == 1) + assert(errors.head == "`person`.`age` field ID has changed from 11 to 99") + } + + test("validateCapturedColumns - doubly-nested struct field ID changed") { + val originInner = StructType(Seq(StructField("age", IntegerType).withId("11"))) + val originOuter = StructType(Seq(StructField("info", originInner).withId("10"))) + val originCols = Array(col("person", originOuter, nullable = true)) + + val currentInner = StructType(Seq(StructField("age", IntegerType).withId("99"))) // changed + val currentOuter = StructType(Seq(StructField("info", currentInner).withId("10"))) + val currentCols = Array(col("person", currentOuter, nullable = true)) + val table = TestTableWithMetadataSupport("test", currentCols) + + val errors = validateCapturedColumns(table, originCols) + assert(errors.size == 1) + assert(errors.head == "`person`.`info`.`age` field ID has changed from 11 to 99") + } + + test("validateCapturedColumns - multiple nested struct field IDs changed") { + val originStruct = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("age", IntegerType).withId("11"), + StructField("score", DoubleType).withId("12"))) + val originCols = Array(col("person", originStruct, nullable = true)) + + val currentStruct = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("age", IntegerType).withId("98"), // changed + StructField("score", DoubleType).withId("99"))) // changed + val currentCols = Array(col("person", currentStruct, nullable = true)) + val table = TestTableWithMetadataSupport("test", currentCols) + + val errors = validateCapturedColumns(table, originCols) + assert(errors.size == 2) + assert(errors.exists(_ == "`person`.`age` field ID has changed from 11 to 98")) + assert(errors.exists(_ == "`person`.`score` field ID has changed from 12 to 99")) + } + + test("validateCapturedColumns - nested field ID changed in array element struct") { + val originElem = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("price", IntegerType).withId("11"))) + val originCols = Array(col("items", ArrayType(originElem), nullable = true)) + + val currentElem = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("price", IntegerType).withId("99"))) // price ID changed + val currentCols = Array(col("items", ArrayType(currentElem), nullable = true)) + val table = TestTableWithMetadataSupport("test", currentCols) + + val errors = validateCapturedColumns(table, originCols) + assert(errors.size == 1) + assert(errors.head == "`items`.`element`.`price` field ID has changed from 11 to 99") + } + + test("validateCapturedColumns - nested field ID changed in map value struct") { + val originValue = StructType(Seq( + StructField("count", IntegerType).withId("10"), + StructField("label", StringType).withId("11"))) + val originCols = Array(col("props", MapType(StringType, originValue), nullable = true)) + + val currentValue = StructType(Seq( + StructField("count", IntegerType).withId("10"), + StructField("label", StringType).withId("99"))) // label ID changed + val currentCols = Array(col("props", MapType(StringType, currentValue), nullable = true)) + val table = TestTableWithMetadataSupport("test", currentCols) + + val errors = validateCapturedColumns(table, originCols) + assert(errors.size == 1) + assert(errors.head == "`props`.`value`.`label` field ID has changed from 11 to 99") + } + + test("validateCapturedColumns - field ID unchanged produces no error") { + val struct = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("age", IntegerType).withId("11"))) + val cols = Array(col("person", struct, nullable = true)) + val table = TestTableWithMetadataSupport("test", cols) + + val errors = validateCapturedColumns(table, cols) + assert(errors.isEmpty) + } + + test("validateCapturedColumns - reordered nested fields produce no error") { + val originStruct = StructType(Seq( + StructField("name", StringType).withId("10"), + StructField("age", IntegerType).withId("11"))) + val currentStruct = StructType(Seq( + StructField("age", IntegerType).withId("11"), // reordered + StructField("name", StringType).withId("10"))) + val originCols = Array(col("person", originStruct, nullable = true)) + val table = TestTableWithMetadataSupport("test", + Array(col("person", currentStruct, nullable = true))) + + val errors = validateCapturedColumns(table, originCols) + assert(errors.isEmpty, "reordering nested fields should not produce errors") + } + + test("validateCapturedColumns - field ID check disabled") { + val originStruct = StructType(Seq(StructField("age", IntegerType).withId("11"))) + val originCols = Array(col("person", originStruct, nullable = true)) + + val currentStruct = StructType(Seq(StructField("age", IntegerType).withId("99"))) + val currentCols = Array(col("person", currentStruct, nullable = true)) + val table = TestTableWithMetadataSupport("test", currentCols) + + val errors = V2TableUtil.validateCapturedColumns( + table, + originCols.toImmutableArraySeq, + mode = PROHIBIT_CHANGES, + checkIds = false) + assert(errors.isEmpty, "Disabled field ID check should not report ID changes") + } + test("validateCapturedColumns - ALLOW_NEW_TOP_LEVEL_FIELDS allows top-level additions") { val originCols = Array( col("id", LongType, nullable = false), @@ -604,7 +804,8 @@ class V2TableUtilSuite extends SparkFunSuite { val errors = V2TableUtil.validateCapturedColumns( table, originCols.toImmutableArraySeq, - mode = ALLOW_NEW_TOP_LEVEL_FIELDS) + mode = ALLOW_NEW_TOP_LEVEL_FIELDS, + checkIds = true) assert(errors.size == 1) assert(errors.head.contains("`items`.`element`.`price` INT has been added")) } @@ -630,23 +831,6 @@ class V2TableUtilSuite extends SparkFunSuite { assert(errors.head.contains("`metadata`.`value`.`timestamp` BIGINT has been added")) } - test("validateColumnIds - multiple errors") { - val originalCols = Seq( - colWithId("salary", IntegerType, nullable = true, id = "id-1"), - colWithId("bonus", IntegerType, nullable = true, id = "id-2")) - val currentCols = Array( - colWithId("salary", IntegerType, nullable = true, id = "id-100"), - colWithId("bonus", IntegerType, nullable = true, id = "id-200")) - val table = TestTableWithMetadataSupport("test", currentCols) - - val errors = V2TableUtil.validateColumnIds( - table = table, - originalCapturedCols = originalCols) - assert(errors == Seq( - "`salary` column ID has changed from id-1 to id-100", - "`bonus` column ID has changed from id-2 to id-200")) - } - // simple table without metadata column support private case class TestTable( override val name: String, @@ -688,7 +872,11 @@ class V2TableUtilSuite extends SparkFunSuite { table: Table, originCols: Array[Column], mode: SchemaValidationMode = PROHIBIT_CHANGES): Seq[String] = { - V2TableUtil.validateCapturedColumns(table, originCols.toImmutableArraySeq, mode) + V2TableUtil.validateCapturedColumns( + table, + originCols.toImmutableArraySeq, + mode, + checkIds = true) } private def col(name: String, dataType: DataType, nullable: Boolean): Column = { diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala index 203aed450a5f7..0b912c449132d 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/connector/catalog/txns.scala @@ -119,13 +119,10 @@ class TxnTable( // The starting version should be the delegate version. setVersion(delegate.version()) - // Preserve column IDs from the delegate so that column ID validation can correctly detect - // drop-and-re-add scenarios (different IDs) and pass when columns are unchanged (same IDs). - // Uses assignMissingIds to keep the delegate's IDs for existing columns while assigning - // fresh IDs for any new columns added by schema evolution. - updateColumns(InMemoryBaseTable.assignMissingIds( - oldColumns = delegate.columns(), - newColumns = columns())) + // Column IDs for existing columns are preserved through the StructType round-trip via + // metadata encoding. assignMissingIds assigns fresh IDs to any new columns added by + // schema evolution. + updateColumns(InMemoryBaseTable.assignMissingIds(columns())) alterTableWithData(delegate.data, schema) diff --git a/sql/connect/client/jdbc/pom.xml b/sql/connect/client/jdbc/pom.xml index 20c921d661ec4..952e6c44aaf81 100644 --- a/sql/connect/client/jdbc/pom.xml +++ b/sql/connect/client/jdbc/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../../pom.xml diff --git a/sql/connect/client/jvm/pom.xml b/sql/connect/client/jvm/pom.xml index ba2c314d2799a..3581129b8d58e 100644 --- a/sql/connect/client/jvm/pom.xml +++ b/sql/connect/client/jvm/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../../pom.xml diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala index 22feaff1c77f1..ceb561f695961 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/SparkConnectClientSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.connect.client import java.nio.charset.StandardCharsets.UTF_8 import java.util.{Base64, UUID} -import java.util.concurrent.TimeUnit +import java.util.concurrent.{Executor, TimeUnit} import scala.collection.mutable import scala.jdk.CollectionConverters._ import com.google.protobuf.{Any => PAny, StringValue} -import io.grpc.{CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, MethodDescriptor, Server, ServerCall, ServerCallHandler, ServerInterceptor, Status, StatusRuntimeException} +import io.grpc.{CallCredentials, CallOptions, Channel, ClientCall, ClientInterceptor, Metadata, MethodDescriptor, Server, ServerCall, ServerCallHandler, ServerInterceptor, Status, StatusRuntimeException} import io.grpc.netty.NettyServerBuilder import io.grpc.stub.StreamObserver import org.scalatest.concurrent.Eventually @@ -824,6 +824,28 @@ class SparkConnectClientSuite extends ConnectFunSuite { assert(!headerInterceptor.headers.exists(_.containsKey(key))) } } + + test("SPARK-57336: access token is sent in the standard Authorization header") { + val token = "test-token-12345" + val creds = new SparkConnectClient.AccessTokenCallCredentials(token) + + var captured: Option[Metadata] = None + var failure: Option[Status] = None + val applier = new CallCredentials.MetadataApplier { + override def apply(headers: Metadata): Unit = captured = Some(headers) + override def fail(status: Status): Unit = failure = Some(status) + } + val sameThreadExecutor = new Executor { + override def execute(command: Runnable): Unit = command.run() + } + + creds.applyRequestMetadata(null, sameThreadExecutor, applier) + + val authorizationKey = Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER) + assert(failure.isEmpty, s"unexpected failure: ${failure.orNull}") + assert(captured.isDefined) + assert(captured.get.get(authorizationKey) === s"Bearer $token") + } } class DummySparkConnectService() extends SparkConnectServiceGrpc.SparkConnectServiceImplBase { diff --git a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala index ac83300ac20ac..9f171cda2ea23 100644 --- a/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala +++ b/sql/connect/client/jvm/src/test/scala/org/apache/spark/sql/connect/client/arrow/ArrowEncoderSuite.scala @@ -16,15 +16,20 @@ */ package org.apache.spark.sql.connect.client.arrow +import java.io.File import java.math.BigInteger +import java.net.URLClassLoader import java.time.{Duration, Period, ZoneOffset} import java.time.temporal.ChronoUnit import java.util import java.util.{Collections, Objects} +import java.util.concurrent.{ConcurrentLinkedQueue, CyclicBarrier} import scala.beans.BeanProperty import scala.collection.mutable +import scala.jdk.CollectionConverters._ import scala.reflect.classTag +import scala.reflect.runtime.{universe => ru} import org.apache.arrow.memory.{BufferAllocator, RootAllocator} import org.apache.arrow.vector.VarBinaryVector @@ -1160,6 +1165,76 @@ class ArrowEncoderSuite extends ConnectFunSuite { } } } + + // SPARK-57371: ArrowDeserializers resolves Scala collection companions and Enumeration modules + // via runtime reflection, which is not thread-safe (scala/bug#6240): a concurrent + // `mirror.classSymbol(cls).companion/.module.asModule` can observe the symbol as `NoSymbol` and + // throw `ScalaReflectionException: is not a module`. ArrowDeserializers serializes the + // reflection through a single monitor. The race only manifests while a mirror's symbol table is + // cold, so each repetition below builds a fresh mirror over a classloader parented at the + // platform loader (so `scala.*` is reloaded cold) and drives the real synchronized method from + // several threads released at once; without the lock it races red. + + private val collectionCompanionClassNames = Seq( + "scala.collection.immutable.List", + "scala.collection.immutable.Vector", + "scala.collection.immutable.Set", + "scala.collection.immutable.Map", + "scala.collection.mutable.ArrayBuffer", + "scala.collection.mutable.HashMap") + + /** A fresh classloader parented at the platform loader, so `scala.*` is reloaded cold. */ + private def newColdLoader(): URLClassLoader = { + val urls = System + .getProperty("java.class.path") + .split(File.pathSeparator) + .filter(_.nonEmpty) + .map(p => new File(p).toURI.toURL) + new URLClassLoader(urls, ClassLoader.getPlatformClassLoader) + } + + // Drive `resolve` against a fresh cold mirror from 8 threads, 50 times; fail on any error/hang. + private def hammerReflection( + names: Seq[String], + resolve: (ru.Mirror, Class[_]) => Any): Unit = { + val errors = new ConcurrentLinkedQueue[Throwable]() + for (_ <- 0 until 50) { + val loader = newColdLoader() + val mirror = ru.runtimeMirror(loader) + val classes = names.map(loader.loadClass) + val barrier = new CyclicBarrier(8) + val threads = (0 until 8).map { _ => + new Thread(() => { + barrier.await() // release all threads simultaneously onto the cold mirror + classes.foreach { cls => + try resolve(mirror, cls) + catch { case e: Throwable => errors.add(e) } + } + }) + } + threads.foreach(_.start()) + threads.foreach { t => + t.join(60000) + assert(!t.isAlive, "thread did not finish within 60s (possible deadlock)") + } + } + assert( + errors.isEmpty, + s"reflection raced under concurrent access (${errors.size} error(s)): " + + errors.asScala.map(e => s"${e.getClass.getName}: ${e.getMessage}").toSet.mkString("; ")) + } + + test("SPARK-57371: resolveCompanion is thread-safe under concurrent cold-mirror access") { + hammerReflection( + collectionCompanionClassNames, + (m, c) => ArrowDeserializers.resolveCompanionFromMirror(m, c)) + } + + test("SPARK-57371: resolveEnum is thread-safe under concurrent cold-mirror access") { + hammerReflection( + Seq(FooEnum.getClass.getName), + (m, c) => ArrowDeserializers.resolveEnumFromMirror(m, c)) + } } // TODO fix actual Null fields, e.g.: nullable: Null diff --git a/sql/connect/common/pom.xml b/sql/connect/common/pom.xml index 870a1aed08643..7cd10764d8f47 100644 --- a/sql/connect/common/pom.xml +++ b/sql/connect/common/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../pom.xml diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala index d9b9ba35b5e6c..de5e5e7744f24 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/SparkConnectClient.scala @@ -705,7 +705,7 @@ object SparkConnectClient { private val DEFAULT_USER_AGENT: String = "_SPARK_CONNECT_SCALA" private val AUTH_TOKEN_META_DATA_KEY: Metadata.Key[String] = - Metadata.Key.of("Authentication", Metadata.ASCII_STRING_MARSHALLER) + Metadata.Key.of("Authorization", Metadata.ASCII_STRING_MARSHALLER) // for internal tests private[sql] def apply(channel: ManagedChannel): SparkConnectClient = { diff --git a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala index f2786c61d1b59..4f3eab64e45a0 100644 --- a/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala +++ b/sql/connect/common/src/main/scala/org/apache/spark/sql/connect/client/arrow/ArrowDeserializer.scala @@ -151,9 +151,9 @@ object ArrowDeserializers { } } case (ScalaEnumEncoder(parent, _), v: FieldVector) => - val mirror = scala.reflect.runtime.currentMirror - val module = mirror.classSymbol(parent).module.asModule - val enumeration = mirror.reflectModule(module).instance.asInstanceOf[Enumeration] + // Scala runtime reflection is not thread-safe (scala/bug#6240). Synchronize to + // prevent races that surface as "... is not a module" under concurrent access. + val enumeration = resolveEnum(parent) new LeafFieldDeserializer[Enumeration#Value](encoder, v, timeZoneId) { override def value(i: Int): Enumeration#Value = { enumeration.withName(reader.getString(i)) @@ -445,11 +445,49 @@ object ArrowDeserializers { /** * Resolve the companion object for a scala class. In our particular case the class we pass in * is a Scala collection. We use the companion to create a builder for that collection. + * + * Scala runtime reflection is not thread-safe (scala/bug#6240): concurrent calls to + * `classSymbol(...).companion` can race, leaving the companion as `NoSymbol` so that + * `.asModule` throws `ScalaReflectionException: is not a module`. We serialize the + * reflection through this object's monitor (see [[resolveCompanionFromMirror]]) to prevent it. + */ + private[arrow] def resolveCompanion[T](tag: ClassTag[_]): T = + resolveCompanionFromMirror(scala.reflect.runtime.currentMirror, tag.runtimeClass) + .asInstanceOf[T] + + /** + * Synchronized reflection to resolve a companion object. The mirror is passed in rather than + * read from `currentMirror` so that the concurrency regression test can drive this exact + * (synchronized) method against a deliberately cold mirror, where the race would otherwise + * surface. Production always passes `currentMirror`. + */ + private[arrow] def resolveCompanionFromMirror( + mirror: scala.reflect.runtime.universe.Mirror, + cls: Class[_]): Any = synchronized { + val module = mirror.classSymbol(cls).companion.asModule + mirror.reflectModule(module).instance + } + + /** + * Resolve a Scala Enumeration parent class to its module instance. Reads `currentMirror` and + * delegates to [[resolveEnumFromMirror]], mirroring the [[resolveCompanion]] / + * [[resolveCompanionFromMirror]] split. + */ + private def resolveEnum(parent: Class[_]): Enumeration = + resolveEnumFromMirror(scala.reflect.runtime.currentMirror, parent).asInstanceOf[Enumeration] + + /** + * Synchronized reflection to resolve a Scala Enumeration's module instance. As with + * [[resolveCompanionFromMirror]], the mirror is passed in rather than read from `currentMirror` + * so the concurrency regression test can drive this exact (synchronized) method against a + * deliberately cold mirror. Synchronized on the same monitor as [[resolveCompanionFromMirror]] + * for the same thread-safety reasons. Production always passes `currentMirror`. */ - private[arrow] def resolveCompanion[T](tag: ClassTag[_]): T = { - val mirror = scala.reflect.runtime.currentMirror - val module = mirror.classSymbol(tag.runtimeClass).companion.asModule - mirror.reflectModule(module).instance.asInstanceOf[T] + private[arrow] def resolveEnumFromMirror( + mirror: scala.reflect.runtime.universe.Mirror, + parent: Class[_]): Any = synchronized { + val module = mirror.classSymbol(parent).module.asModule + mirror.reflectModule(module).instance } /** diff --git a/sql/connect/server/pom.xml b/sql/connect/server/pom.xml index f70eb91b0291c..0ec9a8312fba6 100644 --- a/sql/connect/server/pom.xml +++ b/sql/connect/server/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../pom.xml diff --git a/sql/connect/server/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin b/sql/connect/server/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin new file mode 100644 index 0000000000000..5091829841fed --- /dev/null +++ b/sql/connect/server/src/main/resources/META-INF/services/org.apache.spark.status.AppHistoryServerPlugin @@ -0,0 +1,18 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +org.apache.spark.sql.connect.ui.SparkConnectServerHistoryServerPlugin diff --git a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala index 98dccc6c9a6c8..01a775f897442 100644 --- a/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala +++ b/sql/connect/server/src/main/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListener.scala @@ -22,7 +22,7 @@ import java.util.concurrent.{ConcurrentHashMap, ConcurrentMap} import scala.collection.mutable import scala.collection.mutable.ArrayBuffer -import org.apache.spark.{SparkConf, SparkContext, SparkEnv} +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.internal.Logging import org.apache.spark.internal.LogKeys.{OP_ID, SESSION_ID} import org.apache.spark.internal.config.Status.LIVE_ENTITY_UPDATE_PERIOD @@ -45,9 +45,7 @@ private[connect] class SparkConnectServerListener( new ConcurrentHashMap[String, LiveExecutionData] private val (retainedStatements: Int, retainedSessions: Int) = { - ( - SparkEnv.get.conf.get(CONNECT_UI_STATEMENT_LIMIT), - SparkEnv.get.conf.get(CONNECT_UI_SESSION_LIMIT)) + (sparkConf.get(CONNECT_UI_STATEMENT_LIMIT), sparkConf.get(CONNECT_UI_SESSION_LIMIT)) } // How often to update live entities. -1 means "never update" when replaying applications, diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala index c9c110dd1e626..e9cd7c692373b 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/ui/SparkConnectServerListenerSuite.scala @@ -210,6 +210,30 @@ class SparkConnectServerListenerSuite listener.onOtherEvent(SparkListenerConnectOperationClosed(unknownJob, "operationId", 0)) } + test("SPARK-57601: listener can be created without an active SparkEnv (history server)") { + // The History Server replays event logs without a SparkContext/SparkEnv, so + // SparkEnv.get returns null there. The listener must read its config from the + // SparkConf passed to its constructor rather than from SparkEnv.get.conf. + val previousEnv = SparkEnv.get + try { + SparkEnv.set(null) + val sparkConf = new SparkConf() + .set(ASYNC_TRACKING_ENABLED, false) + .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) + .set(CONNECT_UI_SESSION_LIMIT, 1) + .set(CONNECT_UI_STATEMENT_LIMIT, 10) + val store = new ElementTrackingStore(new InMemoryStore, sparkConf) + try { + val listener = new SparkConnectServerListener(store, sparkConf, live = false) + assert(listener.noLiveData()) + } finally { + store.close(false) + } + } finally { + SparkEnv.set(previousEnv) + } + } + private def createProperties: Properties = { val properties = new Properties() properties.setProperty(SparkContext.SPARK_JOB_TAGS, jobTag) @@ -221,7 +245,6 @@ class SparkConnectServerListenerSuite sparkConf .set(ASYNC_TRACKING_ENABLED, false) .set(LIVE_ENTITY_UPDATE_PERIOD, 0L) - SparkEnv.get.conf .set(CONNECT_UI_SESSION_LIMIT, 1) .set(CONNECT_UI_STATEMENT_LIMIT, 10) kvstore = new ElementTrackingStore(new InMemoryStore, sparkConf) diff --git a/sql/connect/shims/pom.xml b/sql/connect/shims/pom.xml index 698b8129940a8..fdc41d850ded3 100644 --- a/sql/connect/shims/pom.xml +++ b/sql/connect/shims/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../pom.xml diff --git a/sql/core/pom.xml b/sql/core/pom.xml index 9bc4824b603a3..30b85c79a521d 100644 --- a/sql/core/pom.xml +++ b/sql/core/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala index f61e6da8583e1..ba574c091dae9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroDeserializer.scala @@ -204,9 +204,11 @@ private[sql] class AvroDeserializer( } case (LONG, _: TimeType) => avroType.getLogicalType match { + // The time-micros logical type stores microseconds-since-midnight, while TimeType + // is represented internally as nanoseconds-since-midnight, so convert micros to nanos. case _: LogicalTypes.TimeMicros => (updater, ordinal, value) => val micros = value.asInstanceOf[Long] - updater.setLong(ordinal, micros) + updater.setLong(ordinal, DateTimeUtils.microsToNanos(micros)) case other => throw new IncompatibleSchemaException(errorPrefix + s"Avro logical type $other cannot be converted to SQL type ${TimeType().sql}.") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala index 533aa6ee09afb..aacf8dc9f347c 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/avro/AvroSerializer.scala @@ -192,7 +192,11 @@ private[sql] class AvroSerializer( } case (_: TimeType, LONG) => avroType.getLogicalType match { - case _: LogicalTypes.TimeMicros => (getter, ordinal) => getter.getLong(ordinal) + // TimeType is stored internally as nanoseconds-since-midnight. The time-micros + // logical type stores microseconds-since-midnight, so convert nanos to micros + // to keep the on-disk value unit-correct for external Avro readers. + case _: LogicalTypes.TimeMicros => (getter, ordinal) => + DateTimeUtils.nanosToMicros(getter.getLong(ordinal)) case other => throw new IncompatibleSchemaException(errorPrefix + s"SQL type ${TimeType().sql} cannot be converted to Avro logical type $other") } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala index cfd52707bbc2c..c8f59b67e43d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/catalyst/analysis/ResolveSessionCatalog.scala @@ -816,7 +816,7 @@ class ResolveSessionCatalog(val catalogManager: CatalogManager) object ResolvedViewIdentifier { // Only matches session-catalog persistent views. Non-session-catalog persistent views - // (produced for `MetadataTable`) fall through and are picked up by dedicated v2 strategy + // (produced for `DelegatingTable`) fall through and are picked up by dedicated v2 strategy // cases in `DataSourceV2Strategy` -- AlterViewAs, SET/UNSET TBLPROPERTIES, ALTER VIEW ... // WITH SCHEMA, RENAME TO, SHOW CREATE TABLE, SHOW TBLPROPERTIES, SHOW COLUMNS, DESCRIBE // [COLUMN] all dispatch to v2 view execs that consume `ResolvedPersistentView.info` diff --git a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala index 5aec32c572dae..0ac868c297d77 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/classic/columnNodeSupport.scala @@ -22,7 +22,7 @@ import org.apache.spark.SparkException import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.{analysis, expressions, CatalystTypeConverters} import org.apache.spark.sql.catalyst.analysis.{MultiAlias, UnresolvedAlias} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Generator, NamedExpression, Unevaluable} +import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Generator, NamedExpression, Unevaluable, WindowFunction} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction} import org.apache.spark.sql.catalyst.parser.{ParserInterface, ParserUtils} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan @@ -186,7 +186,12 @@ private[sql] trait ColumnNodeToExpressionConverter extends (ColumnNode => Expres case ColumnNodeExpression(node) => apply(node) } transformed match { - case f: AggregateFunction => f.toAggregateExpression() + // A window function (e.g. an AggregateWindowFunction) is also an AggregateFunction, but + // it must not be wrapped in an AggregateExpression: it is used directly as the child of + // a WindowExpression. Wrapping it would later fail analysis with + // WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE. + case f: AggregateFunction if !f.isInstanceOf[WindowFunction] => + f.toAggregateExpression() case _ => transformed } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala index 6714510874351..9fb5b960dbf8a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/HiveResult.scala @@ -23,6 +23,7 @@ import java.time._ import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.SQLConfHelper import org.apache.spark.sql.catalyst.expressions.ToStringBase +import org.apache.spark.sql.catalyst.types.ops.TypeApiOps import org.apache.spark.sql.catalyst.util.{DateFormatter, DateTimeUtils, FractionTimeFormatter, STUtils, TimeFormatter, TimestampFormatter} import org.apache.spark.sql.catalyst.util.IntervalStringStyles.HIVE_STYLE import org.apache.spark.sql.catalyst.util.IntervalUtils.{durationToMicros, periodToMonths, toDayTimeIntervalString, toYearMonthIntervalString} @@ -31,7 +32,6 @@ import org.apache.spark.sql.execution.datasources.v2.{DescribeTableExec, ShowTab import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.internal.SQLConf.BinaryOutputStyle import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.unsafe.types.{CalendarInterval, VariantVal} import org.apache.spark.util.ArrayImplicits._ diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala index 88c74ab7adc41..8c96f1ff95790 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala @@ -31,6 +31,7 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences import org.apache.spark.sql.catalyst.expressions.codegen._ +import org.apache.spark.sql.catalyst.optimizer.CollapseProject import org.apache.spark.sql.catalyst.plans.physical._ import org.apache.spark.sql.execution.joins.{ShuffledHashJoinExec, SortMergeJoinExec} import org.apache.spark.sql.execution.metric.{SQLMetric, SQLMetrics} @@ -316,12 +317,26 @@ case class FilterExec(condition: Expression, child: SparkPlan) // (e.g. decoding a decimal column for rows a cheaper earlier predicate would reject), so we // fall back to `generatePredicateCode`. // + // A *cheap* common subexpression does not count either. Caching a cheap load saves nothing: + // the non-CSE path already loads each column lazily into a variable on demand, so taking the + // CSE path for it would only add the eager prologue that decodes every referenced column up + // front. Note bare columns never reach this point: `EquivalentExpressions` skips + // `LeafExpression`s (which includes `BoundReference`/`Attribute`), and + // `splitConjunctivePredicates` feeds each conjunct to a separate `addExprTree` call, so a + // column repeated across conjuncts (e.g. the `c >= lo` / `c <= hi` that `c BETWEEN lo AND hi` + // lowers to) is never recorded as a common subexpression. The cheap-but-recorded case is a + // shared *non-leaf* such as a struct field access -- `s.x > 5 AND s.x < 100` shares + // `GetStructField(s, x)` -- which is just a slot read. Require a non-cheap common + // subexpression (per `CollapseProject.isCheap`) so such filters keep the lazy, + // short-circuiting path and only genuine repeated computation takes the CSE path. + // // `subexpressionElimination.filterExec.enabled` additionally gates this path so it can be // turned off independently of subexpression elimination elsewhere. val (prologueCode, predicateCode) = if (conf.subexpressionEliminationEnabled && conf.subexpressionEliminationFilterExecEnabled && otherPreds.nonEmpty && - otherPredsEquivalentExpressions.getCommonSubexpressions.nonEmpty) { + otherPredsEquivalentExpressions.getCommonSubexpressions + .exists(!CollapseProject.isCheap(_))) { // Pre-evaluate input variables before CSE analysis: CSE clears // ctx.currentVars[i].code as a side effect; without this pre-evaluation, Janino // fails when otherPreds reference the same input columns that CSE already diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala index 9c012dbd58e12..f79742907779f 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryRelation.scala @@ -36,10 +36,11 @@ import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanExec import org.apache.spark.sql.execution.vectorized.{OffHeapColumnVector, OnHeapColumnVector, WritableColumnVector} import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.types._ +import org.apache.spark.sql.util.PartitionKeyedAccumulator import org.apache.spark.sql.vectorized.{ColumnarBatch, ColumnVector} import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{LongAccumulator, Utils} import org.apache.spark.util.ArrayImplicits._ +import org.apache.spark.util.Utils /** * The default implementation of CachedBatch. @@ -261,9 +262,20 @@ case class CachedRDDBuilder( @transient @volatile private var _cachedColumnBuffers: RDD[CachedBatch] = null @transient @volatile private var _cachedColumnBuffersAreLoaded: Boolean = false - val sizeInBytesStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator - val rowCountStats: LongAccumulator = cachedPlan.session.sparkContext.longAccumulator - private val materializedPartitions = cachedPlan.session.sparkContext.longAccumulator + // The cache's materialization bookkeeping: a partition-keyed accumulator storing + // (rowCount, sizeInBytes) per partition. AQE creates a separate cache scan stage per reference to + // the same cache and each submits its own build job, so the same partition can be computed by + // several concurrent jobs (and speculative tasks); Spark has no global cross-executor "compute + // this partition once" barrier (only a per-executor write lock). Keying by partition id + // (last-write-wins) means those duplicate completions cannot mark the cache loaded before every + // partition has been computed -- which otherwise let AQE read rowCount 0 on a non-empty cache and + // propagate an empty relation, silently dropping rows -- and also yields exact, de-duplicated row + // count / size. + private val partitionStats: PartitionKeyedAccumulator[(Long, Long)] = { + val acc = new PartitionKeyedAccumulator[(Long, Long)] + cachedPlan.session.sparkContext.register(acc) + acc + } val cachedName = tableName.map(n => s"In-memory table $n") .getOrElse(Utils.abbreviate(cachedPlan.toString, 1024)) @@ -284,6 +296,11 @@ case class CachedRDDBuilder( if (_cachedColumnBuffers != null) { _cachedColumnBuffers.unpersist(blocking) _cachedColumnBuffers = null + // The buffers no longer back a live RDD. Reset the one-way "loaded" latch and the keyed + // bookkeeping so a rebuild on this builder does not inherit a stale "loaded" state or stale + // statistics. Safe to reset in place: every read of the accumulator is under this monitor. + _cachedColumnBuffersAreLoaded = false + partitionStats.reset() } } @@ -296,9 +313,11 @@ case class CachedRDDBuilder( // We must make sure the statistics of `sizeInBytes` and `rowCount` are accurate if // `isCachedRDDLoaded` return true. Otherwise, AQE would do a wrong optimization, // e.g., convert a non-empty plan to empty local relation if `rowCount` is 0. - // Because the statistics is based on accumulator, here we use an extra accumulator to - // track if all partitions are materialized. - val rddLoaded = _cachedColumnBuffers.partitions.length == materializedPartitions.value + // Count DISTINCT materialized partitions (the keyed accumulator's key set), so the cache is + // only reported loaded once every partition has been computed -- sound even if a partition is + // computed more than once by concurrent or speculative tasks. + val numMaterialized = partitionStats.accumulatedNumPartitions + val rddLoaded = _cachedColumnBuffers.partitions.length.toLong == numMaterialized if (rddLoaded) { _cachedColumnBuffersAreLoaded = rddLoaded } @@ -306,6 +325,21 @@ case class CachedRDDBuilder( } } + // Reported row count / size for the cache's statistics: exact and de-duplicated, folded over the + // distinct materialized partitions. Synchronized so a fold never races a concurrent `clearCache` + // reset. + private[sql] def materializedRowCount: Long = synchronized { + partitionStats.foldValues(0L)((sum, v) => sum + v._1) + } + + private[sql] def materializedSizeInBytes: Long = synchronized { + partitionStats.foldValues(0L)((sum, v) => sum + v._2) + } + + // The id of the accumulator backing this cache's materialization bookkeeping. Exposed only so + // `CachedTableSuite`'s accumulator-cleanup test can verify it is cleared after uncache + GC. + private[sql] def materializationAccumulatorId: Long = partitionStats.id + private def buildBuffers(): RDD[CachedBatch] = { val cb = try { if (supportsColumnarInput) { @@ -330,18 +364,29 @@ case class CachedRDDBuilder( session.sharedState.cacheManager.recacheByPlan(session, logicalPlan) throw e } + // Records one successful partition materialization: this partition's (rows, bytes) keyed by its + // id. Bound to a local so the task closure below captures only the accumulator, not the + // enclosing CachedRDDBuilder (whose cachedPlan is not serializable). + val accumulator = partitionStats val cached = cb.mapPartitionsInternal { it => - TaskContext.get().addTaskCompletionListener[Unit] { context => + val taskContext = TaskContext.get() + val partitionId = taskContext.partitionId() + // This task computes exactly one partition. Tally its totals so the completion listener + // records them once, keyed by partition id (covering empty-output partitions, which produce + // no batches). + var localRows = 0L + var localBytes = 0L + taskContext.addTaskCompletionListener[Unit] { context => if (!context.isFailed() && !context.isInterrupted()) { - materializedPartitions.add(1L) + accumulator.add((partitionId, (localRows, localBytes))) } } new Iterator[CachedBatch] { override def hasNext: Boolean = it.hasNext override def next(): CachedBatch = { val batch = it.next() - sizeInBytesStats.add(batch.sizeInBytes) - rowCountStats.add(batch.numRows) + localBytes += batch.sizeInBytes + localRows += batch.numRows batch } } @@ -460,8 +505,8 @@ case class InMemoryRelation( statsOfPlanToCache } else { statsOfPlanToCache.copy( - sizeInBytes = cacheBuilder.sizeInBytesStats.value.longValue, - rowCount = Some(cacheBuilder.rowCountStats.value.longValue) + sizeInBytes = cacheBuilder.materializedSizeInBytes, + rowCount = Some(cacheBuilder.materializedRowCount) ) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala index 64317a04547a1..8746fdbc92499 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/DescribeRelationJsonCommand.scala @@ -33,7 +33,7 @@ import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeReference} import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, DateTimeUtils, TimestampFormatter} import org.apache.spark.sql.classic.ClassicConversions.castToImpl -import org.apache.spark.sql.connector.catalog.{V1Table, V1ViewInfo} +import org.apache.spark.sql.connector.catalog.{V1Table, V1View} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits._ import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.types._ @@ -73,11 +73,11 @@ case class DescribeRelationJsonCommand( } // Resolve `v.info` to a `CatalogTable` so the JSON renderer below can read v1-shaped // fields uniformly. Session-catalog views carry the original `CatalogTable` inside - // `V1ViewInfo`; non-session v2 views carry a plain `ViewInfo` and are projected to a + // `V1View`; non-session v2 views carry a plain `View` and are projected to a // `CatalogTable` via `V1Table.toCatalogTable`, the same conversion the // `CreateTableLike` strategy case in `DataSourceV2Strategy` uses. val metadata = v.info match { - case v1Info: V1ViewInfo => v1Info.v1Table + case v1Info: V1View => v1Info.v1Table case info => V1Table.toCatalogTable(v.catalog, v.identifier, info) } describeIdentifier(v.identifier.toQualifiedNameParts(v.catalog), jsonMap) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala index 9a1c9bfe1f59c..a1f36e2280df1 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/command/tables.scala @@ -39,7 +39,7 @@ import org.apache.spark.sql.catalyst.types.DataTypeUtils import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, CaseInsensitiveMap, CharVarcharUtils, DateTimeUtils, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumns.CURRENT_DEFAULT_COLUMN_METADATA_KEY import org.apache.spark.sql.classic.ClassicConversions.castToImpl -import org.apache.spark.sql.connector.catalog.{TableCatalog, V1Table, V1ViewInfo} +import org.apache.spark.sql.connector.catalog.{TableCatalog, V1Table, V1View} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.TableIdentifierHelper import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors} import org.apache.spark.sql.execution.CommandExecutionMode @@ -591,9 +591,9 @@ object ResolvedChildHelper { val catalog = sparkSession.sessionState.catalog child match { case ResolvedTempView(_, metadata) => metadata - // v1 inspection commands always see a v1 (`V1ViewInfo`) view here -- the v2 strategy + // v1 inspection commands always see a v1 (`V1View`) view here -- the v2 strategy // handles non-session views before this method is reached. - case ResolvedPersistentView(_, _, info: V1ViewInfo) => info.v1Table + case ResolvedPersistentView(_, _, info: V1View) => info.v1Table case ResolvedTable(_, _, t: V1Table, _) => t.v1Table case _ if (catalog.isTempView(table)) => catalog.getTempViewOrPermanentTableMetadata(table) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala index 59c103577e136..c75c8f046214b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/FileFormatWriter.scala @@ -148,52 +148,66 @@ object FileFormatWriter extends Logging { writerBucketSpec.map(_.bucketIdExpression) ++ sortColumns val writeFilesOpt = V1WritesUtils.getWriteFilesOpt(plan) - // SPARK-40588: when planned writing is disabled and AQE is enabled, - // plan contains an AdaptiveSparkPlanExec, which does not know - // its final plan's ordering, so we have to materialize that plan first - // it is fine to use plan further down as the final plan is cached in that plan - def materializeAdaptiveSparkPlan(plan: SparkPlan): SparkPlan = plan match { - case a: AdaptiveSparkPlanExec => a.finalPhysicalPlan - case p: SparkPlan => p.withNewChildren(p.children.map(materializeAdaptiveSparkPlan)) - } + // SPARK-56919: setupJob must run before materializeAdaptiveSparkPlan, which can throw. + // Otherwise INSERT OVERWRITE permanently loses the table path if AQE fails. + // setupJob is outside the try below because it only initializes the job; the try/catch + // calls abortJob on any later failure (e.g. materialize throwing), which cleans up the + // staging dir (_temporary / .spark-staging-*). + committer.setupJob(job) - // the sort order doesn't matter - val actualOrdering = writeFilesOpt.map(_.child) - .getOrElse(materializeAdaptiveSparkPlan(plan)) - .outputOrdering - val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering) - - SQLExecution.checkSQLExecutionId(sparkSession) - - // propagate the description UUID into the jobs, so that committers - // get an ID guaranteed to be unique. - job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid) - - // When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add logical sort - // operator based on the required ordering of the V1 write command. So the output - // ordering of the physical plan should always match the required ordering. Here - // we set the variable to verify this behavior in tests. - // There are two cases where FileFormatWriter still needs to add physical sort: - // 1) When the planned write config is disabled. - // 2) When the concurrent writers are enabled (in this case the required ordering of a - // V1 write command will be empty). - if (Utils.isTesting) outputOrderingMatched = orderingMatched - - if (writeFilesOpt.isDefined) { - // build `WriteFilesSpec` for `WriteFiles` - val concurrentOutputWriterSpecFunc = (plan: SparkPlan) => { - val sortPlan = createSortPlan(plan, requiredOrdering, outputSpec) - createConcurrentOutputWriterSpec(sparkSession, sortPlan, sortColumns) + try { + // SPARK-40588: when planned writing is disabled and AQE is enabled, + // plan contains an AdaptiveSparkPlanExec, which does not know + // its final plan's ordering, so we have to materialize that plan first + // it is fine to use plan further down as the final plan is cached in that plan + def materializeAdaptiveSparkPlan(plan: SparkPlan): SparkPlan = plan match { + case a: AdaptiveSparkPlanExec => a.finalPhysicalPlan + case p: SparkPlan => p.withNewChildren(p.children.map(materializeAdaptiveSparkPlan)) } - val writeSpec = WriteFilesSpec( - description = description, - committer = committer, - concurrentOutputWriterSpecFunc = concurrentOutputWriterSpecFunc - ) - executeWrite(sparkSession, plan, writeSpec, job) - } else { - executeWrite(sparkSession, plan, job, description, committer, outputSpec, - requiredOrdering, partitionColumns, sortColumns, orderingMatched) + + // the sort order doesn't matter + val actualOrdering = writeFilesOpt.map(_.child) + .getOrElse(materializeAdaptiveSparkPlan(plan)) + .outputOrdering + val orderingMatched = V1WritesUtils.isOrderingMatched(requiredOrdering, actualOrdering) + + SQLExecution.checkSQLExecutionId(sparkSession) + + // propagate the description UUID into the jobs, so that committers + // get an ID guaranteed to be unique. + job.getConfiguration.set("spark.sql.sources.writeJobUUID", description.uuid) + + // When `PLANNED_WRITE_ENABLED` is true, the optimizer rule V1Writes will add logical sort + // operator based on the required ordering of the V1 write command. So the output + // ordering of the physical plan should always match the required ordering. Here + // we set the variable to verify this behavior in tests. + // There are two cases where FileFormatWriter still needs to add physical sort: + // 1) When the planned write config is disabled. + // 2) When the concurrent writers are enabled (in this case the required ordering of a + // V1 write command will be empty). + if (Utils.isTesting) outputOrderingMatched = orderingMatched + + if (writeFilesOpt.isDefined) { + // build `WriteFilesSpec` for `WriteFiles` + val concurrentOutputWriterSpecFunc = (plan: SparkPlan) => { + val sortPlan = createSortPlan(plan, requiredOrdering, outputSpec) + createConcurrentOutputWriterSpec(sparkSession, sortPlan, sortColumns) + } + val writeSpec = WriteFilesSpec( + description = description, + committer = committer, + concurrentOutputWriterSpecFunc = concurrentOutputWriterSpecFunc + ) + executeWrite(sparkSession, plan, writeSpec, job) + } else { + executeWrite(sparkSession, plan, job, description, committer, outputSpec, + requiredOrdering, partitionColumns, sortColumns, orderingMatched) + } + } catch { + case cause: Throwable => + logError(log"Aborting job ${MDC(WRITE_JOB_UUID, description.uuid)}.", cause) + committer.abortJob(job) + throw cause } } // scalastyle:on argcount @@ -267,30 +281,21 @@ object FileFormatWriter extends Logging { job: Job, description: WriteJobDescription, committer: FileCommitProtocol)(f: => Array[WriteTaskResult]): Set[String] = { - // This call shouldn't be put into the `try` block below because it only initializes and - // prepares the job, any exception thrown from here shouldn't cause abortJob() to be called. - committer.setupJob(job) - try { - val ret = f - val commitMsgs = ret.map(_.commitMsg) - - logInfo(log"Start to commit write Job ${MDC(LogKeys.UUID, description.uuid)}.") - val (_, duration) = Utils - .timeTakenMs { committer.commitJob(job, commitMsgs.toImmutableArraySeq) } - logInfo(log"Write Job ${MDC(LogKeys.UUID, description.uuid)} committed. " + - log"Elapsed time: ${MDC(LogKeys.ELAPSED_TIME, duration)} ms.") - - processStats( - description.statsTrackers, ret.map(_.summary.stats).toImmutableArraySeq, duration) - logInfo(log"Finished processing stats for write job ${MDC(LogKeys.UUID, description.uuid)}.") - - // return a set of all the partition paths that were updated during this job - ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty) - } catch { case cause: Throwable => - logError(log"Aborting job ${MDC(WRITE_JOB_UUID, description.uuid)}.", cause) - committer.abortJob(job) - throw cause - } + val ret = f + val commitMsgs = ret.map(_.commitMsg) + + logInfo(log"Start to commit write Job ${MDC(LogKeys.UUID, description.uuid)}.") + val (_, duration) = Utils + .timeTakenMs { committer.commitJob(job, commitMsgs.toImmutableArraySeq) } + logInfo(log"Write Job ${MDC(LogKeys.UUID, description.uuid)} committed. " + + log"Elapsed time: ${MDC(LogKeys.ELAPSED_TIME, duration)} ms.") + + processStats( + description.statsTrackers, ret.map(_.summary.stats).toImmutableArraySeq, duration) + logInfo(log"Finished processing stats for write job ${MDC(LogKeys.UUID, description.uuid)}.") + + // return a set of all the partition paths that were updated during this job + ret.map(_.summary.updatedPartitions).reduceOption(_ ++ _).getOrElse(Set.empty) } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala index 3c85b6e65dee7..7188c2b8b2e8e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/jdbc/JDBCOptions.scala @@ -20,6 +20,8 @@ package org.apache.spark.sql.execution.datasources.jdbc import java.sql.{Connection, DriverManager} import java.util.{Locale, Properties} +import scala.util.matching.Regex + import org.apache.commons.io.FilenameUtils import org.apache.spark.SparkFiles @@ -268,7 +270,7 @@ class JDBCOptions( case _ => false } - def getRedactUrl(): String = Utils.redact(SQLConf.get.stringRedactionPattern, url) + def getRedactUrl(): String = JDBCOptions.redactUrl(url, SQLConf.get.stringRedactionPattern) } class JdbcOptionsInWrite( @@ -302,6 +304,37 @@ object JDBCOptions { name } + /** + * Redacts a JDBC URL so it is safe to surface in logs and error messages. + * + * A JDBC URL has the form `jdbc::`, where `` is a registered + * driver name (`mysql`, `oracle`, `postgresql`, ...) and `` is entirely driver-specific. + * Credentials can appear anywhere in `` and in arbitrary syntaxes -- userinfo in a + * `//user:pwd@host` authority, Oracle Thin's `user/pwd@host`, `?`/`;` connection properties, etc. + * Rather than enumerate every driver's syntax (and inevitably miss one and leak), we keep only + * the `jdbc::` prefix -- which is credential-free by construction -- and redact + * everything after it. The driver type stays visible for debugging; nothing else does. + * + * This redaction is unconditional, unlike the optional, user-configured + * `spark.sql.redaction.string.regex` (which is unset by default and would leave the URL in the + * clear). The configured `regex` is still applied on top, preserving existing behavior. + */ + def redactUrl(url: String, regex: Option[Regex]): String = { + if (url == null || url.isEmpty) { + url + } else { + // The second colon terminates the subprotocol: "jdbc" ':' "" ':' "". + val subprotocolEnd = url.indexOf(':', url.indexOf(':') + 1) + val redacted = if (subprotocolEnd < 0) { + // No subname delimiter -- the URL is malformed, so don't trust any of it. + Utils.REDACTION_REPLACEMENT_TEXT + } else { + url.substring(0, subprotocolEnd + 1) + Utils.REDACTION_REPLACEMENT_TEXT + } + Utils.redact(regex, redacted) + } + } + val JDBC_URL = newOption("url") val JDBC_TABLE_NAME = newOption("dbtable") val JDBC_QUERY_STRING = newOption("query") diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewExec.scala index 2309cb31b5ebe..3afed35d894bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/AlterV2ViewExec.scala @@ -23,12 +23,12 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, SchemaEvolution, ViewSchemaMode} import org.apache.spark.sql.catalyst.catalog.CatalogTable import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog, ViewCatalog, ViewInfo} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog, View, ViewCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.{IdentifierHelper, MultipartIdentifierHelper} import org.apache.spark.sql.execution.command.CommandUtils /** - * Shared bits for the v2 ALTER VIEW ... AS exec. The replacement [[ViewInfo]] is constructed by + * Shared bits for the v2 ALTER VIEW ... AS exec. The replacement [[View]] is constructed by * [[V2ViewPreparation.buildViewInfo]]; the existing view's payload is provided at analysis time * via the `existingView` field so we can preserve user-set TBLPROPERTIES, comment, collation, * owner, and schema binding mode without re-loading at runtime. @@ -40,7 +40,7 @@ import org.apache.spark.sql.execution.command.CommandUtils * propagates. */ private[v2] trait V2AlterViewPreparation extends V2ViewPreparation { - protected def existingView: ViewInfo + protected def existingView: View protected lazy val existingProps: Map[String, String] = existingView.properties.asScala.toMap @@ -53,13 +53,13 @@ private[v2] trait V2AlterViewPreparation extends V2ViewPreparation { override def collation: Option[String] = existingProp(TableCatalog.PROP_COLLATION) // Preserve the existing view's owner (v1-parity with AlterViewAsCommand's viewMeta.copy, // which leaves `owner` untouched). If the existing view has no PROP_OWNER, pass it through - // as None so the replacement ViewInfo also has no owner. + // as None so the replacement View also has no owner. override def owner: Option[String] = existingProp(TableCatalog.PROP_OWNER) override def userProperties: Map[String, String] = existingProps // Preserve the existing view's schema binding mode. Reuse `viewSchemaModeFromProperties` // for a v1-identical decode -- it honors `viewSchemaBindingEnabled` and defaults missing - // values to SchemaBinding. We feed the typed `ViewInfo.schemaMode` String in via a + // values to SchemaBinding. We feed the typed `View.schemaMode` String in via a // single-key map so the decode logic stays in one place. override def viewSchemaMode: ViewSchemaMode = CatalogTable.viewSchemaModeFromProperties( @@ -75,7 +75,7 @@ private[v2] trait V2AlterViewPreparation extends V2ViewPreparation { case class AlterV2ViewExec( catalog: ViewCatalog, identifier: Identifier, - existingView: ViewInfo, + existingView: View, originalText: String, query: LogicalPlan) extends V2AlterViewPreparation { @@ -96,7 +96,7 @@ case class AlterV2ViewExec( case class AlterV2ViewSetPropertiesExec( catalog: ViewCatalog, identifier: Identifier, - existingView: ViewInfo, + existingView: View, properties: Map[String, String]) extends LeafV2CommandExec { override def output: Seq[org.apache.spark.sql.catalyst.expressions.Attribute] = Seq.empty @@ -124,7 +124,7 @@ case class AlterV2ViewSetPropertiesExec( case class AlterV2ViewUnsetPropertiesExec( catalog: ViewCatalog, identifier: Identifier, - existingView: ViewInfo, + existingView: View, propertyKeys: Seq[String]) extends LeafV2CommandExec { override def output: Seq[org.apache.spark.sql.catalyst.expressions.Attribute] = Seq.empty @@ -152,7 +152,7 @@ case class AlterV2ViewUnsetPropertiesExec( case class AlterV2ViewSchemaBindingExec( catalog: ViewCatalog, identifier: Identifier, - existingView: ViewInfo, + existingView: View, viewSchemaMode: ViewSchemaMode) extends LeafV2CommandExec { override def output: Seq[org.apache.spark.sql.catalyst.expressions.Attribute] = Seq.empty diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableLikeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableLikeExec.scala index a472f6cf8e14c..fd6ac760a1ce8 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableLikeExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateTableLikeExec.scala @@ -81,15 +81,18 @@ case class CreateTableLikeExec( Seq.empty } - // Derive target columns from source; for V1Table sources apply CharVarcharUtils to preserve - // CHAR/VARCHAR types as declared rather than collapsed to StringType. - private def targetColumns: Array[Column] = + // Derive target columns from source without propagating source field IDs to the new table. + // For V1Table sources, apply CharVarcharUtils to preserve CHAR/VARCHAR types as declared + // rather than collapsed to StringType. + private def targetColumns: Array[Column] = { sourceTable match { case v1: V1Table => - CatalogV2Util.structTypeToV2Columns(CharVarcharUtils.getRawSchema(v1.catalogTable.schema)) + val rawSchema = CharVarcharUtils.getRawSchema(v1.catalogTable.schema) + CatalogV2Util.structTypeToV2Columns(rawSchema, keepIds = false) case _ => - sourceTable.columns + CatalogV2Util.clearIds(sourceTable.columns) } + } // Source table properties are intentionally excluded; connectors read sourceTable // to clone any additional format-specific or custom state they need. diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateV2ViewExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateV2ViewExec.scala index 4e10c7d3ab284..2781ca4528c53 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateV2ViewExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/CreateV2ViewExec.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.analysis.{ResolvedIdentifier, SchemaEvoluti import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.catalyst.util.CharVarcharUtils -import org.apache.spark.sql.connector.catalog.{DependencyList, Identifier, TableCatalog, ViewCatalog, ViewInfo} +import org.apache.spark.sql.connector.catalog.{DependencyList, Identifier, TableCatalog, View, ViewCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.IdentifierHelper import org.apache.spark.sql.errors.QueryCompilationErrors import org.apache.spark.sql.execution.command.{CommandUtils, ViewHelper} @@ -33,7 +33,7 @@ import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.util.ArrayImplicits._ /** - * Shared validation + ViewInfo construction for v2 CREATE VIEW / ALTER VIEW execs. + * Shared validation + View construction for v2 CREATE VIEW / ALTER VIEW execs. * * Mirrors the persistent-view portion of v1 [[ViewHelper.prepareTable]] + the execution-time * checks in [[org.apache.spark.sql.execution.command.CreateViewCommand.run]]. Post-analysis @@ -57,7 +57,7 @@ private[v2] trait V2ViewPreparation extends LeafV2CommandExec { protected lazy val fullNameParts: Seq[String] = (catalog.name() +: identifier.asMultipartIdentifier).toSeq - /** Optional structured dependency list to stamp on the built `ViewInfo`. */ + /** Optional structured dependency list to stamp on the built `View`. */ protected def viewDependencies: Option[DependencyList] = None /** Optional view sub-kind to stamp on `PROP_TABLE_TYPE`; defaults to `VIEW` when `None`. */ @@ -74,7 +74,7 @@ private[v2] trait V2ViewPreparation extends LeafV2CommandExec { override def output: Seq[Attribute] = Seq.empty - protected def buildViewInfo(): ViewInfo = { + protected def buildViewInfo(): View = { import ViewHelper._ if (userSpecifiedColumns.nonEmpty) { @@ -107,7 +107,7 @@ private[v2] trait V2ViewPreparation extends LeafV2CommandExec { query.output.map(_.name).toArray } - val builder = new ViewInfo.Builder() + val builder = new View.Builder() .withSchema(aliasedSchema) .withProperties(userProperties.asJava) .withQueryText(originalText) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala index 4fd7d993cc3d0..50130e5fb7770 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/DataSourceV2Strategy.scala @@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.catalyst.trees.TreePattern.SCALAR_SUBQUERY import org.apache.spark.sql.catalyst.util.{quoteIfNeeded, toPrettySQL, GeneratedColumn, IdentityColumn, ResolveDefaultColumns, ResolveTableConstraints, V2ExpressionBuilder} import org.apache.spark.sql.classic.SparkSession -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Dependency, DependencyList, Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, TableCapability, TableCatalog, TableSummary, TruncatableTable, V1Table, V1ViewInfo, ViewCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Dependency, DependencyList, Identifier, StagingTableCatalog, SupportsDeleteV2, SupportsNamespaces, SupportsPartitionManagement, SupportsWrite, TableCapability, TableCatalog, TableSummary, TruncatableTable, V1Table, V1View, ViewCatalog} import org.apache.spark.sql.connector.catalog.TableChange import org.apache.spark.sql.connector.catalog.index.SupportsIndex import org.apache.spark.sql.connector.expressions.{FieldReference, LiteralValue} @@ -105,10 +105,10 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat } // Strategy cases that target v2 views read `ResolvedPersistentView.info` directly. For - // session-catalog (v1) views the payload is a `V1ViewInfo` wrapping the original - // `CatalogTable`; v2 catalogs supply a regular `ViewInfo` from the catalog. + // session-catalog (v1) views the payload is a `V1View` wrapping the original + // `CatalogTable`; v2 catalogs supply a regular `View` from the catalog. // `ResolveSessionCatalog` rewrites session-catalog views to v1 commands before this strategy - // fires, so v2 cases that don't expect a `V1ViewInfo` won't see one. + // fires, so v2 cases that don't expect a `V1View` won't see one. private def qualifyLocInTableSpec(tableSpec: TableSpec): TableSpec = { val newLoc = tableSpec.location.map { loc => @@ -269,13 +269,13 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat // Views are wrapped in V1Table so the exec can extract schema and provider uniformly -- // session-catalog (v1) views unwrap to their original `CatalogTable`; non-session v2 // views go through `V1Table.toCatalogTable` to synthesize an equivalent `CatalogTable` - // from the resolved `ViewInfo`. + // from the resolved `View`. case CreateTableLike( ResolvedIdentifier(catalog, ident), source, locationStr, provider, serdeInfo, properties, ifNotExists) => val table = source match { case ResolvedTable(_, _, t, _) => t - case ResolvedPersistentView(_, _, info: V1ViewInfo) => V1Table(info.v1Table) + case ResolvedPersistentView(_, _, info: V1View) => V1Table(info.v1Table) case rpv @ ResolvedPersistentView(viewCatalog, viewIdent, _) => V1Table(V1Table.toCatalogTable(viewCatalog, viewIdent, rpv.info)) case ResolvedTempView(_, meta) => V1Table(meta) @@ -366,7 +366,7 @@ class DataSourceV2Strategy(session: SparkSession) extends Strategy with Predicat // View DDL / inspection on a non-session v2 catalog that the v1 rewrite in // `ResolveSessionCatalog` can't handle (its `ResolvedViewIdentifier` matcher is gated on - // `isSessionCatalog`). Routed to dedicated v2 execs that read the typed `ViewInfo` + // `isSessionCatalog`). Routed to dedicated v2 execs that read the typed `View` // resolved at analysis time directly from `ResolvedPersistentView.info` -- no re-loading // at exec time. case SetViewProperties(rpv @ ResolvedPersistentView(catalog, ident, _), props) => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala index 8680785e0815f..cb21e70569217 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/ShowTablesExec.scala @@ -22,15 +22,15 @@ import scala.collection.mutable.ArrayBuffer import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.util.StringUtils -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog, TableViewCatalog} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, RelationCatalog, TableCatalog} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.NamespaceHelper import org.apache.spark.sql.execution.LeafExecNode /** * Physical plan node for showing tables. * - * For a [[TableViewCatalog]] (one that exposes both tables and views in a shared identifier - * namespace), this routes through [[TableViewCatalog#listTableAndViewSummaries]] so that views are + * For a [[RelationCatalog]] (one that exposes both tables and views in a shared identifier + * namespace), this routes through [[RelationCatalog#listRelationSummaries]] so that views are * included in the listing -- matching the v1 `SHOW TABLES` semantics where views appear * alongside tables. Pure [[TableCatalog]] catalogs continue to use `listTables` and return * tables only. @@ -44,8 +44,8 @@ case class ShowTablesExec( val rows = new ArrayBuffer[InternalRow]() val identifiers: Array[Identifier] = catalog match { - case mc: TableViewCatalog => - mc.listTableAndViewSummaries(namespace.toArray).map(_.identifier()) + case mc: RelationCatalog => + mc.listRelationSummaries(namespace.toArray).map(_.identifier()) case _ => catalog.listTables(namespace.toArray) } identifiers.foreach { ident => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala index 46359f1fa8a2d..f1ff11b1a4a65 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2TableRefreshUtil.scala @@ -97,7 +97,6 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { } }) validateTableIdentity(currentTable, r) - validateColumnIds(currentTable, r) validateDataColumns(currentTable, r, schemaValidationMode) validateMetadataColumns(currentTable, r, schemaValidationMode) r.copy(table = currentTable) @@ -125,22 +124,13 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { V2TableUtil.validateTableId(relation.name, relation.table.id, currentTable) } - private def validateColumnIds( - currentTable: Table, - relation: DataSourceV2Relation): Unit = { - val errors = V2TableUtil.validateColumnIds(currentTable, relation) - if (errors.nonEmpty) { - throw QueryCompilationErrors.columnIdMismatchAfterAnalysis(relation.name, errors) - } - } - private def validateDataColumns( currentTable: Table, relation: DataSourceV2Relation, mode: SchemaValidationMode): Unit = { - val errors = V2TableUtil.validateCapturedColumns(currentTable, relation, mode) + val errors = V2TableUtil.validateCapturedColumns(currentTable, relation, mode, checkIds = true) if (errors.nonEmpty) { - throw QueryCompilationErrors.columnsMissingOrAddedAfterAnalysis(relation.name, errors) + throw QueryCompilationErrors.columnsChangedAfterAnalysis(relation.name, errors) } } @@ -148,7 +138,11 @@ private[sql] object V2TableRefreshUtil extends SQLConfHelper with Logging { currentTable: Table, relation: DataSourceV2Relation, mode: SchemaValidationMode): Unit = { - val errors = V2TableUtil.validateCapturedMetadataColumns(currentTable, relation, mode) + val errors = V2TableUtil.validateCapturedMetadataColumns( + currentTable, + relation, + mode, + checkIds = true) if (errors.nonEmpty) { throw QueryCompilationErrors.metadataColumnsChangedAfterAnalysis(relation.name, errors) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ViewInspectionExecs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ViewInspectionExecs.scala index d1ceeba833ea7..2bf4664bc4740 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ViewInspectionExecs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/V2ViewInspectionExecs.scala @@ -24,11 +24,11 @@ import org.apache.spark.sql.catalyst.{InternalRow, SQLConfHelper} import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.util.{escapeSingleQuotedString, quoteIfNeeded, ResolveDefaultColumns} import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils.CURRENT_DEFAULT_COLUMN_METADATA_KEY -import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog, TableSummary, ViewInfo} +import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog, TableSummary, View} import org.apache.spark.sql.errors.QueryCompilationErrors /** - * Read-side v2 view execs. Each receives the typed [[ViewInfo]] resolved at analysis time + * Read-side v2 view execs. Each receives the typed [[View]] resolved at analysis time * (carried on `ResolvedPersistentView.info`) and formats output rows directly from it -- * matching the way v2 table inspection execs (e.g. `ShowCreateTableExec`, `DescribeTableExec`) * consume the [[org.apache.spark.sql.connector.catalog.Table]] attached to `ResolvedTable`. @@ -40,15 +40,15 @@ import org.apache.spark.sql.errors.QueryCompilationErrors /** * Physical plan node for SHOW CREATE TABLE on a v2 view. Reconstructs the {@code CREATE VIEW} - * statement directly from the typed [[ViewInfo]] -- the column list comes from - * [[ViewInfo#schema]], the body from [[ViewInfo#queryText]], the binding mode from - * [[ViewInfo#schemaMode]], and the user TBLPROPERTIES from [[ViewInfo#properties]] (with the + * statement directly from the typed [[View]] -- the column list comes from + * [[View#schema]], the body from [[View#queryText]], the binding mode from + * [[View#schemaMode]], and the user TBLPROPERTIES from [[View#properties]] (with the * reserved-keys filter applied so internal entries don't leak into the rendered DDL). */ case class ShowCreateV2ViewExec( output: Seq[Attribute], quotedName: String, - viewInfo: ViewInfo) extends LeafV2CommandExec with SQLConfHelper { + viewInfo: View) extends LeafV2CommandExec with SQLConfHelper { override protected def run(): Seq[InternalRow] = { val builder = new StringBuilder @@ -99,7 +99,7 @@ case class ShowCreateV2ViewExec( /** * Physical plan node for SHOW TBLPROPERTIES on a v2 view. Returns the user-facing properties - * from [[ViewInfo#properties]] -- reserved first-class keys (PROP_COMMENT, PROP_COLLATION, + * from [[View#properties]] -- reserved first-class keys (PROP_COMMENT, PROP_COLLATION, * PROP_OWNER, PROP_TABLE_TYPE, ...) are filtered out so users see only what they (or the * catalog) explicitly set, matching v1 `SHOW TBLPROPERTIES` on a session-catalog view (which * hides these because v1 stores them in typed `CatalogTable` fields rather than `properties`). @@ -108,7 +108,7 @@ case class ShowCreateV2ViewExec( case class ShowV2ViewPropertiesExec( output: Seq[Attribute], quotedName: String, - viewInfo: ViewInfo, + viewInfo: View, propertyKey: Option[String]) extends LeafV2CommandExec with SQLConfHelper { override protected def run(): Seq[InternalRow] = { @@ -133,11 +133,11 @@ case class ShowV2ViewPropertiesExec( /** * Physical plan node for SHOW COLUMNS on a v2 view. Returns one row per top-level field in - * [[ViewInfo#schema]]. + * [[View#schema]]. */ case class ShowV2ViewColumnsExec( output: Seq[Attribute], - viewInfo: ViewInfo) extends LeafV2CommandExec { + viewInfo: View) extends LeafV2CommandExec { override protected def run(): Seq[InternalRow] = { viewInfo.schema.map(c => toCatalystRow(c.name)).toSeq @@ -157,7 +157,7 @@ case class DescribeV2ViewExec( output: Seq[Attribute], catalogName: String, identifier: Identifier, - viewInfo: ViewInfo, + viewInfo: View, isExtended: Boolean) extends DescribeIdentifierRows with SQLConfHelper { override protected def run(): Seq[InternalRow] = { @@ -170,7 +170,7 @@ case class DescribeV2ViewExec( result += toCatalystRow("# Detailed View Information", "", "") addIdentifierRows(result, catalogName, identifier, entityLabel = "View") // Surface the view sub-kind so users see whether they're looking at a plain VIEW - // or a sub-kind like METRIC_VIEW. `ViewInfo`'s constructor unconditionally stamps + // or a sub-kind like METRIC_VIEW. `View`'s constructor unconditionally stamps // `PROP_TABLE_TYPE` (defaulting to `VIEW`), so this row is always present and // matches v1 `CatalogTable.toJsonLinkedHashMap`'s `Type` row for parity. result += toCatalystRow( @@ -235,7 +235,7 @@ case class DescribeV2ViewExec( */ case class DescribeV2ViewColumnExec( output: Seq[Attribute], - viewInfo: ViewInfo, + viewInfo: View, colNameParts: Seq[String], isExtended: Boolean) extends LeafV2CommandExec with SQLConfHelper { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala index 33709fbd5f5a7..a1a4ca6196ebd 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/WriteToDataSourceV2Exec.scala @@ -967,7 +967,7 @@ private[v2] trait V2CreateTableAsSelectBaseExec extends LeafV2CommandExec { protected def getV2Columns(schema: StructType, forceNullable: Boolean): Array[Column] = { val rawSchema = CharVarcharUtils.getRawSchema(removeInternalMetadata(schema), conf) val tableSchema = if (forceNullable) rawSchema.asNullable else rawSchema - CatalogV2Util.structTypeToV2Columns(tableSchema) + CatalogV2Util.structTypeToV2Columns(tableSchema, keepIds = false) } protected def writeToTable( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala index c44a5d30cafe6..8543fa9ca1d5a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/orc/OrcPartitionReaderFactory.scala @@ -129,7 +129,7 @@ case class OrcPartitionReaderFactory( return buildColumnarReaderWithAggregates(file, conf) } val filePath = file.toPath - lazy val (reader, readerOptions) = createORCReader(filePath, conf) + val (reader, readerOptions) = createORCReader(filePath, conf) val orcSchema = Utils.tryWithResource(reader)(_.getSchema) val resultedColPruneInfo = OrcUtils.requestedColumnIds( isCaseSensitive, dataSchema, readDataSchema, orcSchema, conf) @@ -172,10 +172,14 @@ case class OrcPartitionReaderFactory( val fs = filePath.getFileSystem(conf) val readerOptions = OrcFile.readerOptions(conf).filesystem(fs) val reader = OrcFile.createReader(filePath, readerOptions) - - pushDownPredicates(reader.getSchema, conf) - - (reader, readerOptions) + try { + pushDownPredicates(reader.getSchema, conf) + (reader, readerOptions) + } catch { + case e: Throwable => + reader.close() + throw e + } } /** diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala index adee0b2ea19a1..ab5ad5d1270b7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala @@ -29,9 +29,9 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.types.ops.TypeApiOps import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, GenericArrayData, MapData, STUtils} import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.unsafe.types.{BinaryView, UTF8String, VariantVal} object EvaluatePython { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala index 966c5d14bc662..95bb1cbf99ba0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/H2Dialect.scala @@ -133,7 +133,7 @@ private[sql] case class H2Dialect() extends JdbcDialect with NoLegacyJDBCError { options: JDBCOptions): Boolean = { val sql = "SELECT * FROM INFORMATION_SCHEMA.INDEXES WHERE " + s"TABLE_SCHEMA = '${tableIdent.namespace().last}' AND " + - s"TABLE_NAME = '${tableIdent.name()}' AND INDEX_NAME = '$indexName'" + s"TABLE_NAME = '${tableIdent.name()}' AND INDEX_NAME = '${escapeSql(indexName)}'" JdbcUtils.checkIfIndexExists(conn, sql, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala index a34d23512e996..b0b10f1d09f27 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/JdbcDialects.scala @@ -716,11 +716,11 @@ abstract class JdbcDialect extends Serializable with Logging { } def getTableCommentQuery(table: String, comment: String): String = { - s"COMMENT ON TABLE $table IS '$comment'" + s"COMMENT ON TABLE $table IS '${escapeSql(comment)}'" } def getSchemaCommentQuery(schema: String, comment: String): String = { - s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS '$comment'" + s"COMMENT ON SCHEMA ${quoteIdentifier(schema)} IS '${escapeSql(comment)}'" } def removeSchemaCommentQuery(schema: String): String = { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala index a047085a35378..b301c0c0bd5bc 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/MySQLDialect.scala @@ -266,7 +266,7 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No // See https://dev.mysql.com/doc/refman/8.0/en/alter-table.html override def getTableCommentQuery(table: String, comment: String): String = { - s"ALTER TABLE $table COMMENT = '$comment'" + s"ALTER TABLE $table COMMENT = '${escapeSql(comment)}'" } override def getJDBCType(dt: DataType): Option[JdbcType] = dt match { @@ -318,7 +318,7 @@ private case class MySQLDialect() extends JdbcDialect with SQLConfHelper with No tableIdent: Identifier, options: JDBCOptions): Boolean = { val sql = s"SHOW INDEXES FROM ${quoteIdentifier(tableIdent.name())} " + - s"WHERE key_name = '$indexName'" + s"WHERE key_name = '${escapeSql(indexName)}'" JdbcUtils.checkIfIndexExists(conn, sql, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala index 6caf1f4b1ff6a..d3ef79fdf3f9a 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/OracleDialect.scala @@ -75,7 +75,22 @@ private case class OracleDialect() extends JdbcDialect with SQLConfHelper with N override def visitSQLFunction(funcName: String, inputs: Array[String]): String = { funcName match { case "TRUNC" => - s"TRUNC(${inputs(0)}, 'IW')" + // Map Spark's trunc format strings to Oracle equivalents. + // inputs(1) arrives quoted, e.g. "'MONTH'" (see JDBCSQLBuilder.visitLiteral). + // Case-insensitive: Spark's parseTruncLevel uppercases before matching. + val fmt = inputs(1).toUpperCase(Locale.ROOT) + val oracleFormat = fmt match { + case "'WEEK'" => "'IW'" + case "'MONTH'" | "'MM'" | "'MON'" => "'MM'" + case "'QUARTER'" => "'Q'" + case "'YEAR'" | "'YYYY'" | "'YY'" => "'YYYY'" + case _ => + // Unmapped formats: don't push down. compileExpression catches the + // exception and returns None, so Spark evaluates trunc locally. + throw new IllegalArgumentException( + s"Unsupported Oracle TRUNC format: ${inputs(1)}") + } + s"TRUNC(${inputs(0)}, $oracleFormat)" case _ => super.visitSQLFunction(funcName, inputs) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala index dd57c129179ef..8941767ec3573 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/jdbc/PostgresDialect.scala @@ -271,7 +271,7 @@ private case class PostgresDialect() tableIdent: Identifier, options: JDBCOptions): Boolean = { val sql = s"SELECT * FROM pg_indexes WHERE tablename = '${tableIdent.name()}' AND" + - s" indexname = '$indexName'" + s" indexname = '${escapeSql(indexName)}'" JdbcUtils.checkIfIndexExists(conn, sql, options) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala new file mode 100644 index 0000000000000..bb8f04a8a5565 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/util/PartitionKeyedAccumulator.scala @@ -0,0 +1,90 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import java.util.concurrent.ConcurrentHashMap + +import org.apache.spark.util.AccumulatorV2 + +/** + * An `AccumulatorV2` that records one value of type `T` per partition, keyed by partition id with + * LAST-WRITE-WINS merge. When the same partition is recorded more than once -- e.g. duplicate + * cross-executor computes, or speculative tasks -- the later value replaces the earlier one rather + * than aggregating, so each partition contributes exactly once. The key set is the set of recorded + * partitions, and callers fold the values (see [[foldValues]]) to derive de-duplicated aggregates; + * a plain summing accumulator would instead over-count under duplicate computes. + * + * `add` is expected to be called once per task (e.g. from a task completion listener) with that + * partition's value, so a partition is recorded even when it produced nothing. Updates from + * failed/interrupted tasks are dropped by the accumulator framework (it is not + * `countFailedValues`), so only complete per-partition values are ever merged. + * + * Backed by a `ConcurrentHashMap`, whose per-entry atomicity is sufficient here: `add` and the + * `putAll` in `merge` are last-write-wins per key, and the reads (`value`, + * `accumulatedNumPartitions`, `foldValues`) only require thread-safety and eventual consistency + * -- they are weakly consistent during concurrent updates but exact once all updates have been + * merged. This avoids any explicit locking (and the nested-lock pattern a two-map `merge` would + * otherwise need). + * + * @tparam T the per-partition value type. Must be non-null (`ConcurrentHashMap` forbids nulls). + */ +class PartitionKeyedAccumulator[T] extends AccumulatorV2[(Int, T), java.util.Map[Int, T]] { + + // partition id -> value. + private val byPartition = new ConcurrentHashMap[Int, T]() + + override def isZero: Boolean = byPartition.isEmpty + + override def copyAndReset(): PartitionKeyedAccumulator[T] = new PartitionKeyedAccumulator[T] + + override def copy(): PartitionKeyedAccumulator[T] = { + val newAcc = new PartitionKeyedAccumulator[T] + newAcc.byPartition.putAll(byPartition) + newAcc + } + + override def reset(): Unit = byPartition.clear() + + override def add(v: (Int, T)): Unit = byPartition.put(v._1, v._2) + + override def merge(other: AccumulatorV2[(Int, T), java.util.Map[Int, T]]): Unit = other match { + case o: PartitionKeyedAccumulator[T] => + // Last-write-wins per partition id: a partition recorded by more than one task replaces + // rather than accumulates, keeping any caller-derived aggregate exact. + byPartition.putAll(o.byPartition) + case _ => throw new UnsupportedOperationException( + s"Cannot merge ${this.getClass.getName} with ${other.getClass.getName}") + } + + // A read-only VIEW over the live map -- no copy. Only the accumulator framework calls `value` + // (event log / `toInfo` / `toString`); our own code reads via `accumulatedNumPartitions` / + // `foldValues`. The view is thread-safe (ConcurrentHashMap) and weakly consistent, which matches + // this accumulator's eventual-consistency contract. + override def value: java.util.Map[Int, T] = java.util.Collections.unmodifiableMap(byPartition) + + /** Number of distinct partitions that have been recorded. */ + def accumulatedNumPartitions: Long = byPartition.size().toLong + + /** Folds the per-partition values (each partition counted once) into a single aggregate. */ + def foldValues[A](zero: A)(op: (A, T) => A): A = { + var result = zero + val it = byPartition.values().iterator() + while (it.hasNext) result = op(result, it.next()) + result + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/types/JavaGeographyTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/types/JavaGeographyTypeSuite.java similarity index 100% rename from sql/core/src/test/scala/org/apache/spark/sql/types/JavaGeographyTypeSuite.java rename to sql/core/src/test/java/org/apache/spark/sql/types/JavaGeographyTypeSuite.java diff --git a/sql/core/src/test/scala/org/apache/spark/sql/types/JavaGeometryTypeSuite.java b/sql/core/src/test/java/org/apache/spark/sql/types/JavaGeometryTypeSuite.java similarity index 97% rename from sql/core/src/test/scala/org/apache/spark/sql/types/JavaGeometryTypeSuite.java rename to sql/core/src/test/java/org/apache/spark/sql/types/JavaGeometryTypeSuite.java index 867957884d230..9a19367a25c24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/types/JavaGeometryTypeSuite.java +++ b/sql/core/src/test/java/org/apache/spark/sql/types/JavaGeometryTypeSuite.java @@ -64,8 +64,8 @@ public void geometryTypeWithSpecifiedInvalidSridTest() { public void geometryTypeWithSpecifiedValidCrsTest() { // Valid CRS values for GEOMETRY (including Spark overrides) Stream.of( - "SRID:0", "EPSG:3857", "OGC:CRS84", "EPSG:4326", "OGC:CRS27", "EPSG:4267", "OGC:CRS83", "EPSG:4269", - "EPSG:2000", "ESRI:102100") + "SRID:0", "EPSG:3857", "OGC:CRS84", "EPSG:4326", "OGC:CRS27", "EPSG:4267", + "OGC:CRS83", "EPSG:4269", "EPSG:2000", "ESRI:102100") .forEach(crs -> { Integer srid = CartesianSpatialReferenceSystemMapper.getSrid(crs); DataType geometryType = DataTypes.createGeometryType(crs); diff --git a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala index 106ee36594b38..085dbcd804665 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/CachedTableSuite.scala @@ -474,12 +474,12 @@ class CachedTableSuite extends SharedSparkSession val toBeCleanedAccIds = new HashSet[Long] val accId1 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId }.head toBeCleanedAccIds += accId1 val accId2 = spark.table("t1").queryExecution.withCachedData.collect { - case i: InMemoryRelation => i.cacheBuilder.sizeInBytesStats.id + case i: InMemoryRelation => i.cacheBuilder.materializationAccumulatorId }.head toBeCleanedAccIds += accId2 @@ -509,6 +509,37 @@ class CachedTableSuite extends SharedSparkSession } } + test("SPARK-57547: clearCache resets materialization bookkeeping") { + val df = spark.range(0, 100, 1, numPartitions = 4).filter($"id" >= 0) + df.cache() + try { + val cacheRelations = df.queryExecution.withCachedData.collect { + case i: InMemoryRelation => i + } + assert(cacheRelations.length == 1) + val builder = cacheRelations.head.cacheBuilder + // Force the cache build directly (a plain df action can be served from the query-result + // cache and skip the rebuild after clearCache). + builder.cachedColumnBuffers.count() + assert(builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 100L) + + builder.clearCache() + // The loaded latch and the materialization stats must not survive clearCache, otherwise a + // rebuilt cache would inherit a stale "loaded" state with stale/zero statistics. + assert(!builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 0L) + assert(builder.materializedSizeInBytes == 0L) + + // Rebuilding works and reports correct stats again. + builder.cachedColumnBuffers.count() + assert(builder.isCachedColumnBuffersLoaded) + assert(builder.materializedRowCount == 100L) + } finally { + df.unpersist(blocking = true) + } + } + test("SPARK-10327 Cache Table is not working while subquery has alias in its project list") { withTempView("abc") { sparkContext.parallelize((1, 1) :: (2, 2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala index 36244071206b8..a49d74e6276a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameAggregateSuite.scala @@ -1088,6 +1088,56 @@ class DataFrameAggregateSuite extends SharedSparkSession ) } + test("SPARK-57329: mode normalizes -0.0/0.0 in the frequency buffer") { + checkAnswer( + Seq(0.0d, 0.0d, -0.0d, -0.0d, 9.0d, 9.0d, 9.0d).toDF("d").select(expr("mode(d)")), + Row(0.0d)) + checkAnswer( + Seq(0.0f, 0.0f, -0.0f, -0.0f, 9.0f, 9.0f, 9.0f).toDF("f").select(expr("mode(f)")), + Row(0.0f)) + + checkAnswer( + Seq(Array(-0.0d), Array(-0.0d), Array(0.0d), Array(0.0d), + Array(9.0d), Array(9.0d), Array(9.0d)).toDF("a").select(expr("mode(a)")), + Row(Seq(0.0d))) + + // pandas_mode shares the same normalization path; cover it explicitly. It is an + // internal expression, so invoke it via Column.internalFn rather than SQL. + checkAnswer( + Seq(0.0d, 0.0d, -0.0d, -0.0d, 9.0d).toDF("d") + .select(Column.internalFn("pandas_mode", col("d"), lit(true))), + Row(Seq(0.0d))) + checkAnswer( + Seq(0.0f, 0.0f, -0.0f, -0.0f, 9.0f).toDF("f") + .select(Column.internalFn("pandas_mode", col("f"), lit(true))), + Row(Seq(0.0f))) + checkAnswer( + Seq(Array(-0.0d), Array(-0.0d), Array(0.0d), Array(0.0d), Array(9.0d)).toDF("a") + .select(Column.internalFn("pandas_mode", col("a"), lit(true))), + Row(Seq(Seq(0.0d)))) + + // Struct complex type: same recursive NormalizeFloatingNumbers path as the array case + // above, but a different shape. -0.0/0.0 collapse to 4 occurrences, outvoting 9.0's 3. + checkAnswer( + sql("SELECT mode(named_struct('a', v)) FROM " + + "VALUES (-0.0D), (-0.0D), (0.0D), (0.0D), (9.0D), (9.0D), (9.0D) AS t(v)"), + Row(Row(0.0d))) + + // NaN with differing bit patterns nested in a complex type. The normalization lambda + // canonicalizes the NaN bits so the two patterns collapse to 4 occurrences, outvoting + // 9.0's 3. Without normalization each pattern forms its own group of 2 and 9.0 (3) wins, + // so the expected NaN result genuinely distinguishes fixed from buggy behavior. + val nan1 = java.lang.Double.longBitsToDouble(0x7ff8000000000000L) + val nan2 = java.lang.Double.longBitsToDouble(0x7ff8000000000001L) + assert(nan1.isNaN && nan2.isNaN && + java.lang.Double.doubleToRawLongBits(nan1) != + java.lang.Double.doubleToRawLongBits(nan2)) + checkAnswer( + Seq(nan1, nan1, nan2, nan2, 9.0d, 9.0d, 9.0d).toDF("v") + .select(struct(col("v")).as("s")).select(expr("mode(s)")), + Row(Row(Double.NaN))) + } + test("SPARK-27581: DataFrame count_distinct(\"*\") shouldn't fail with AnalysisException") { val df = sql("select id % 100 from range(100000)") val distinctCount1 = df.select(expr("count(distinct(*))")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala index 5192d5d66f243..3d6fd5c1bf5fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala @@ -2571,6 +2571,7 @@ class DataFrameSuite extends SharedSparkSession test("SPARK-41048: Improve output partitioning and ordering with AQE cache") { withSQLConf( SQLConf.CAN_CHANGE_CACHED_PLAN_OUTPUT_PARTITIONING.key -> "true", + SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key -> "0", SQLConf.AUTO_BROADCASTJOIN_THRESHOLD.key -> "-1") { val df1 = spark.range(10).selectExpr("cast(id as string) c1") val df2 = spark.range(10).selectExpr("cast(id as string) c2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala index c46cde0d0db13..f79824de8ff3b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameWindowFunctionsSuite.scala @@ -20,10 +20,11 @@ package org.apache.spark.sql import org.scalatest.matchers.must.Matchers.the import org.apache.spark.TestUtils.{assertNotSpilled, assertSpilled} -import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, Lag, Literal, NonFoldableLiteral} +import org.apache.spark.sql.catalyst.expressions.{Add, AggregateWindowFunction, AttributeReference, Expression, If, IsNotNull, Lag, Literal, NonFoldableLiteral, RowNumber} import org.apache.spark.sql.catalyst.optimizer.TransposeWindow import org.apache.spark.sql.catalyst.plans.logical.{Window => LogicalWindow} import org.apache.spark.sql.catalyst.plans.physical.HashPartitioning +import org.apache.spark.sql.catalyst.trees.UnaryLike import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper import org.apache.spark.sql.execution.exchange.{ENSURE_REQUIREMENTS, Exchange, ShuffleExchangeExec} import org.apache.spark.sql.execution.window.WindowExec @@ -880,6 +881,28 @@ class DataFrameWindowFunctionsSuite extends SharedSparkSession ) } + test("SPARK-57505: a window function expression wrapped into a Column works with over()") { + val df = Seq((1, "a"), (2, "a"), (3, "b")).toDF("value", "key") + val window = Window.partitionBy($"key").orderBy($"value") + // Wrapping a catalyst window function expression directly with Column(expr) used to box the + // AggregateWindowFunction (RowNumber is one) in an AggregateExpression, which then failed + // analysis with WINDOW_FUNCTION_WITHOUT_OVER_CLAUSE. It must now behave like by-name + // row_number(). + checkAnswer( + df.select($"value", Column(RowNumber()).over(window).as("rn")), + Seq(Row(1, 1), Row(2, 2), Row(3, 1))) + } + + test("SPARK-57505: a custom AggregateWindowFunction wrapped into a Column works with over()") { + val df = Seq((1, "a"), (2, "a"), (3, "b")).toDF("value", "key") + val window = Window.partitionBy($"key").orderBy($"value") + // Mirrors plugging in a user-defined AggregateWindowFunction through the Column API: + // Column(MyWindowFunction(inputColumn.expr)).over(window) + checkAnswer( + df.select($"value", Column(NonNullRunningCount($"value".expr)).over(window).as("cnt")), + Seq(Row(1, 1), Row(2, 2), Row(3, 1))) + } + test("SPARK-12989 ExtractWindowExpressions treats alias as regular attribute") { val src = Seq((0, 3, 5)).toDF("a", "b", "c") .withColumn("Data", struct("a", "b")) @@ -1808,3 +1831,25 @@ class DataFrameWindowFunctionsSuite extends SharedSparkSession } } } + +/** + * A minimal user-defined window function, it counts the non-null values of `child` from the start + * of the window frame up to and including the current row. + */ +case class NonNullRunningCount(child: Expression) + extends AggregateWindowFunction with UnaryLike[Expression] { + + private lazy val count = AttributeReference("count", IntegerType, nullable = false)() + + override lazy val aggBufferAttributes: Seq[AttributeReference] = count :: Nil + override lazy val initialValues: Seq[Expression] = Literal(0) :: Nil + override lazy val updateExpressions: Seq[Expression] = + If(IsNotNull(child), Add(count, Literal(1)), count) :: Nil + override lazy val evaluateExpression: Expression = count + + override def nullable: Boolean = false + override def prettyName: String = "non_null_running_count" + + override protected def withNewChildInternal(newChild: Expression): NonNullRunningCount = + copy(child = newChild) +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala index af52204dbb7ff..618f0ab675e19 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/ExplainSuite.scala @@ -552,6 +552,9 @@ class ExplainSuite extends ExplainSuiteHelper with DisableAdaptiveExecutionSuite class ExplainSuiteAE extends ExplainSuiteHelper with EnableAdaptiveExecutionSuite { import testImplicits._ + override protected def sparkConf = + super.sparkConf.set(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key, "0") + test("SPARK-35884: Explain Formatted") { val df1 = Seq((1, 2), (2, 3)).toDF("k", "v1") val df2 = Seq((2, 3), (1, 1)).toDF("k", "v2") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala index 3ea77446f268d..7f695e90df884 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/JoinSuite.scala @@ -46,6 +46,9 @@ class JoinSuite extends SharedSparkSession with AdaptiveSparkPlanHelper setupTestData() + override protected def sparkConf = + super.sparkConf.set(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key, "0") + def statisticSizeInByte(df: classic.DataFrame): BigInt = { df.queryExecution.optimizedPlan.stats.sizeInBytes } @@ -1834,6 +1837,10 @@ class ThreadLeakInSortMergeJoinSuite with AdaptiveSparkPlanHelper { setupTestData() + + override protected def sparkConf = + super.sparkConf.set(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key, "0") + override protected def createSparkSession: TestSparkSession = { classic.SparkSession.cleanupAnyExistingSession() new TestSparkSession( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala index 389d0d5a29d59..395cb67f44155 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQueryTestSuite.scala @@ -179,6 +179,8 @@ class SQLQueryTestSuite extends SharedSparkSession with SQLHelper // regex magic. .set("spark.test.noSerdeInExplain", "true") .set(SQLConf.SCHEMA_LEVEL_COLLATIONS_ENABLED, true) + // SPARK-57667: pin so AQE keeps SMJ regardless of the default + .set(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD, 0L) // SPARK-32106 Since we add SQL test 'transform.sql' will use `cat` command, // here we need to ignore it. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala index a6f4de1c80e3e..37684c7fce3b6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/collation/CollationSuite.scala @@ -49,6 +49,9 @@ class CollationSuite extends DatasourceV2SQLBase with AdaptiveSparkPlanHelper { private val fullyQualifiedPrefix = s"${CollationFactory.CATALOG}.${CollationFactory.SCHEMA}." private val collations = Seq("UTF8_BINARY", "UTF8_LCASE", "UNICODE", "UNICODE_CI") + override protected def sparkConf = + super.sparkConf.set(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key, "0") + @inline private def isSortMergeForced: Boolean = { SQLConf.get.getConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD) == -1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala index 1dbaad18e3e71..9753f7db20976 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DSv2IncrementallyConstructedQueryTests.scala @@ -234,7 +234,7 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas // --------------------------------------------------------------------------- // Scenario 4: external drop and recreate table. // 4a: table ID detects it, TABLE_ID_MISMATCH in classic, succeeds in Connect - // 4b: column IDs detect it, COLUMN_ID_MISMATCH in classic, succeeds in Connect + // 4b: column IDs detect it, COLUMNS_MISMATCH in classic, succeeds in Connect // 4c: no IDs, goes undetected, join succeeds (both modes) // --------------------------------------------------------------------------- @@ -321,7 +321,7 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas exception = intercept[AnalysisException] { df1.join(df2, df1("id") === df2("id")).collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*")) } @@ -364,7 +364,7 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas df1.join(df2, df1("id") === df2("id")), Seq(Row(2, 200, 2, 200))) } else { - // Classic: neither TABLE_ID_MISMATCH nor COLUMN_ID_MISMATCH fires, so the + // Classic: neither TABLE_ID_MISMATCH nor COLUMNS_MISMATCH fires, so the // drop and recreate goes undetected. df1 keeps its pre-drop snapshot // (1, 100) while df2 reads the recreated table (2, 200), so the join finds // no matching ids and returns no rows. @@ -378,7 +378,7 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas // --------------------------------------------------------------------------- // Scenario 5: external drop+re-add column. - // 5a: column IDs detect it, COLUMN_ID_MISMATCH in classic, succeeds in Connect + // 5a: column IDs detect it, COLUMNS_MISMATCH in classic, succeeds in Connect // 5b: no IDs, goes undetected, join succeeds (both modes) // --------------------------------------------------------------------------- @@ -411,7 +411,7 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas exception = intercept[AnalysisException] { df1.join(df2, df1("id") === df2("id")).collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*")) } @@ -438,7 +438,7 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas val df2 = session.table(nullBothT) - // Neither TABLE_ID_MISMATCH nor COLUMN_ID_MISMATCH fires. + // Neither TABLE_ID_MISMATCH nor COLUMNS_MISMATCH fires. // The change goes undetected and the join succeeds. checkRows( df1.join(df2, df1("id") === df2("id")), @@ -449,10 +449,9 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas // --------------------------------------------------------------------------- // Scenario 6: external type change (drop INT column, add STRING column). - // The delete removes the old column ID and the add assigns a fresh one, - // so the column ID check fires (COLUMN_ID_MISMATCH) in classic before schema - // validation gets a chance to compare data types. - // Connect re-resolves both sides with the new column ID. + // The drop removes the old column ID and the add assigns a fresh one, + // so schema validation reports both an ID change and a type change in COLUMNS_MISMATCH. + // Connect re-resolves both sides with the new column. // --------------------------------------------------------------------------- test(s"${testPrefix}SPARK-54157: join after external drop+re-add different-type column" + @@ -485,7 +484,7 @@ trait DSv2IncrementallyConstructedQueryTests extends DSv2ExternalMutationTestBas exception = intercept[AnalysisException] { df1.join(df2, df1("id") === df2("id")).collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*")) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala index f272f28a5f92f..4cbce9d9cb84c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2DataFrameSuite.scala @@ -28,10 +28,9 @@ import org.apache.spark.sql.{AnalysisException, DataFrame, Row, SaveMode, SparkS import org.apache.spark.sql.QueryTest.withQueryExecutionsCaptured import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.plans.logical.{AppendData, CreateTableAsSelect, LogicalPlan, ReplaceTableAsSelect} -import org.apache.spark.sql.connector.catalog.{CachingInMemoryTableCatalog, Column, ColumnDefaultValue, ComposedColumnIdTableCatalog, DefaultValue, Identifier, InMemoryTableCatalog, MixedColumnIdTableCatalog, NullColumnIdInMemoryTableCatalog, NullTableIdAndNullColumnIdInMemoryTableCatalog, NullTableIdInMemoryTableCatalog, SupportsV1OverwriteWithSaveAsTable, TableCatalog, TableInfo, TypeChangeResetsColIdTableCatalog} -import org.apache.spark.sql.connector.catalog.BasicInMemoryTableCatalog -import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, UpdateColumnDefaultValue} +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, CachingInMemoryTableCatalog, CatalogV2Util, Column, ColumnDefaultValue, DefaultValue, Identifier, InMemoryBaseTable, InMemoryTableCatalog, MixedColumnIdTableCatalog, NullColumnIdInMemoryTableCatalog, NullTableIdAndNullColumnIdInMemoryTableCatalog, NullTableIdInMemoryTableCatalog, SupportsV1OverwriteWithSaveAsTable, TableCatalog, TableInfo, TypeChangeResetsColIdTableCatalog} import org.apache.spark.sql.connector.catalog.TableChange +import org.apache.spark.sql.connector.catalog.TableChange.{AddColumn, UpdateColumnDefaultValue} import org.apache.spark.sql.connector.catalog.TableWritePrivilege import org.apache.spark.sql.connector.catalog.TruncatableTable import org.apache.spark.sql.connector.expressions.{ApplyTransform, GeneralScalarExpression, LiteralValue, Transform} @@ -43,6 +42,7 @@ import org.apache.spark.sql.functions.{lit, sum} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.{BooleanType, CalendarIntervalType, DoubleType, IntegerType, LongType, StringType, StructType, TimestampType} import org.apache.spark.sql.util.QueryExecutionListener +import org.apache.spark.sql.util.SchemaUtils import org.apache.spark.unsafe.types.UTF8String class DataSourceV2DataFrameSuite @@ -77,9 +77,7 @@ class DataSourceV2DataFrameSuite .set("spark.sql.catalog.mixedcolidcat", classOf[MixedColumnIdTableCatalog].getName) .set("spark.sql.catalog.mixedcolidcat.copyOnLoad", "true") - .set("spark.sql.catalog.composedidcat", - classOf[ComposedColumnIdTableCatalog].getName) - .set("spark.sql.catalog.composedidcat.copyOnLoad", "true") + .set(InMemoryBaseTable.ASSIGN_COLUMN_IDS, "true") after { catalog("cachingcat").asInstanceOf[CachingInMemoryTableCatalog].clearCache() @@ -1017,7 +1015,7 @@ class DataSourceV2DataFrameSuite Column.create("c1", IntegerType), Column.create("c2", StringType)) } - assert(cols === expectedCols) + assert(CatalogV2Util.clearIds(cols) === expectedCols) } } } @@ -1605,7 +1603,7 @@ class DataSourceV2DataFrameSuite // // Core behavior: when a DataFrame captures column IDs at analysis time, // and those IDs change before execution, the query is rejected with - // COLUMN_ID_MISMATCH. + // COLUMNS_MISMATCH. test("drop+re-add column with same name and type rejects stale DataFrame") { val t = "testcat.ns1.ns2.tbl" @@ -1620,7 +1618,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -1639,9 +1637,15 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, - parameters = Map("tableName" -> ".*", "errors" -> ".*")) + parameters = Map( + "tableName" -> ".*", + "errors" -> + """| + |- `salary` field ID has changed from \d+ to \d+ + |- `salary` type has changed from INT to STRING + |""".stripMargin.strip)) } } @@ -1658,7 +1662,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -1679,7 +1683,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*salary.*bonus.*")) @@ -1721,7 +1725,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -1756,7 +1760,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -1774,14 +1778,14 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } } test("drop+re-add nested struct field rejects stale DataFrame") { - val t = "composedidcat.ns1.ns2.tbl" + val t = "testcat.ns1.ns2.tbl" withTable(t) { sql(s"CREATE TABLE $t (id INT, person STRUCT) USING foo") sql(s"INSERT INTO $t VALUES (1, named_struct('name', 'Alice', 'age', 30))") @@ -1792,7 +1796,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -1831,7 +1835,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -1848,7 +1852,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -1899,54 +1903,7 @@ class DataSourceV2DataFrameSuite } } - // Column ID tests: Composed nested IDs - // - // ComposedColumnIdTableCatalog encodes nested field IDs into the - // top-level Column.id() string, modeling the recommended adoption - // pattern for connectors with nested IDs. Any nested - // change produces a different encoded string, so validateColumnIds - // detects it even though Spark only compares top-level strings. - - test("composed nested IDs detect drop+re-add of nested field") { - val t = "composedidcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id INT, person STRUCT) USING foo") - sql(s"INSERT INTO $t VALUES (1, named_struct('name', 'Alice', 'age', 30))") - val df = spark.table(t) - - sql(s"ALTER TABLE $t DROP COLUMN person.age") - sql(s"ALTER TABLE $t ADD COLUMN person.age INT") - - // The inner age field got a new nested ID on re-add. The composed - // top-level string changes, so COLUMN_ID_MISMATCH fires. - checkError( - exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", - matchPVals = true, - parameters = Map("tableName" -> ".*", "errors" -> ".*")) - } - } - - test("composed nested IDs tolerate same data inserted into nested column") { - val t = "composedidcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id INT, person STRUCT) USING foo") - sql(s"INSERT INTO $t VALUES (1, named_struct('name', 'Alice', 'age', 30))") - val df = spark.table(t) - - // pure data insert, no schema change: composed IDs stay the same - sql(s"INSERT INTO $t VALUES (2, named_struct('name', 'Bob', 'age', 25))") - - checkAnswer(df, Seq( - Row(1, Row("Alice", 30)), - Row(2, Row("Bob", 25)))) - } - } - // Column ID tests: Additional nested coverage - // - // These tests fill specific nested cells that are not covered by the - // coarse (testcat) or composed (composedidcat) groups above. // Nested type change with preserved top-level ID: the standard catalog // preserves the parent ID, so schema validation catches the incompatible @@ -1970,10 +1927,8 @@ class DataSourceV2DataFrameSuite } } - // Depth >= 3 nesting with composed IDs: drop+re-add at depth 3 produces - // a different composed ID at the top level. - test("depth 3 nesting with composed IDs detects deep field change") { - val t = "composedidcat.ns1.ns2.tbl" + test("depth 3 nesting detects deep nested field drop+re-add") { + val t = "testcat.ns1.ns2.tbl" withTable(t) { sql(s"CREATE TABLE $t (id INT, a STRUCT>) USING foo") sql(s"INSERT INTO $t VALUES (1, named_struct('b', named_struct('c', 42)))") @@ -1982,11 +1937,9 @@ class DataSourceV2DataFrameSuite sql(s"ALTER TABLE $t DROP COLUMN a.b.c") sql(s"ALTER TABLE $t ADD COLUMN a.b.c INT") - // The deep nested field c got a new ID on re-add, changing the - // composed top-level ID for column a. checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2021,14 +1974,8 @@ class DataSourceV2DataFrameSuite // nested fields are added or dropped (assignMissingIds matches by name // only). These tests verify that behavior using the catalog API. - // Column ID tests: Composed IDs for container types (arrays, maps) - // - // ComposedColumnIdTableCatalog encodes nested field IDs into the - // top-level string. These tests verify detection of nested drop+re-add - // inside array element structs and map value structs. - - test("composed nested IDs detect drop+re-add in array element struct") { - val t = "composedidcat.ns1.ns2.tbl" + test("drop+re-add in array element struct detected by field ID") { + val t = "testcat.ns1.ns2.tbl" withTable(t) { sql(s"CREATE TABLE $t (id INT, items ARRAY>) USING foo") sql(s"INSERT INTO $t VALUES (1, array(named_struct('name', 'x', 'price', 10)))") @@ -2037,18 +1984,16 @@ class DataSourceV2DataFrameSuite sql(s"ALTER TABLE $t DROP COLUMN items.element.price") sql(s"ALTER TABLE $t ADD COLUMN items.element.price INT") - // The nested price field got a new ID on re-add. The composed - // top-level ID for items changes, so COLUMN_ID_MISMATCH fires. checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } } - test("composed nested IDs detect drop+re-add in map value struct") { - val t = "composedidcat.ns1.ns2.tbl" + test("drop+re-add in map value struct detected by field ID") { + val t = "testcat.ns1.ns2.tbl" withTable(t) { sql(s"CREATE TABLE $t (id INT, props MAP>) USING foo") sql(s"INSERT INTO $t VALUES (1, map('k1', named_struct('x', 10, 'y', 20)))") @@ -2057,28 +2002,6 @@ class DataSourceV2DataFrameSuite sql(s"ALTER TABLE $t DROP COLUMN props.value.y") sql(s"ALTER TABLE $t ADD COLUMN props.value.y INT") - // The nested y field got a new ID on re-add. The composed - // top-level ID for props changes, so COLUMN_ID_MISMATCH fires. - checkError( - exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", - matchPVals = true, - parameters = Map("tableName" -> ".*", "errors" -> ".*")) - } - } - - test("composed nested IDs detect rename within struct") { - val t = "composedidcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id INT, person STRUCT) USING foo") - sql(s"INSERT INTO $t VALUES (1, named_struct('name', 'Alice', 'age', 30))") - val df = spark.table(t) - - sql(s"ALTER TABLE $t RENAME COLUMN person.name TO first_name") - - // With position-based keys, the renamed field stays at position 0 - // and keeps its nested ID. The composed string is unchanged, so - // schema validation catches the struct type difference instead. checkError( exception = intercept[AnalysisException] { df.collect() }, condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", @@ -2087,13 +2010,13 @@ class DataSourceV2DataFrameSuite } } - test("composed nested IDs: reorder preserves composed column ID") { - val t = "composedidcat.ns1.ns2.tbl" + test("reorder preserves top-level column ID") { + val t = "testcat.ns1.ns2.tbl" val ident = Identifier.of(Array("ns1", "ns2"), "tbl") withTable(t) { sql(s"CREATE TABLE $t (id INT, person STRUCT) USING foo") - val cat = catalog("composedidcat") + val cat = catalog("testcat") val personBefore = cat.loadTable(ident).columns().find(_.name() == "person").get val idBefore = personBefore.id() val typeBefore = personBefore.dataType() @@ -2112,32 +2035,14 @@ class DataSourceV2DataFrameSuite assert(typeAfter.toString.startsWith("StructType(StructField(age"), s"age should be first field after reorder, got: $typeAfter") - // Position-based keys: each ordinal position keeps its old ID after - // reorder, so the composed string is unchanged despite the schema change. + // The top-level column ID is preserved after reorder. assert(idBefore == idAfter, - s"Composed ID should be unchanged after reorder: $idBefore vs $idAfter") - } - } - - test("composed nested IDs tolerate nested field reorder end-to-end") { - val t = "composedidcat.ns1.ns2.tbl" - withTable(t) { - sql(s"CREATE TABLE $t (id INT, person STRUCT) USING foo") - sql(s"INSERT INTO $t VALUES (1, named_struct('name', 'Alice', 'age', 30))") - val df = spark.table(t) - - sql(s"ALTER TABLE $t ALTER COLUMN person.age FIRST") - - // InMemoryTable does not actually reorder nested struct fields in stored - // data, so the read still returns the original field order. This is fine - // because the purpose of this test is to verify that the column ID check - // passes (no COLUMN_ID_MISMATCH) after a nested field reorder. - checkAnswer(df, Seq(Row(1, Row("Alice", 30)))) + s"Top-level column ID should be unchanged after reorder: $idBefore vs $idAfter") } } - test("composed nested IDs detect drop+re-add in map key struct") { - val t = "composedidcat.ns1.ns2.tbl" + test("drop+re-add in map key struct detected by field ID") { + val t = "testcat.ns1.ns2.tbl" withTable(t) { sql(s"CREATE TABLE $t " + s"(id INT, coords MAP, STRING>) USING foo") @@ -2148,11 +2053,9 @@ class DataSourceV2DataFrameSuite sql(s"ALTER TABLE $t DROP COLUMN coords.key.y") sql(s"ALTER TABLE $t ADD COLUMN coords.key.y INT") - // The nested y field in the map key struct got a new ID on re-add. - // The composed top-level ID for coords changes. checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2253,7 +2156,7 @@ class DataSourceV2DataFrameSuite exception = intercept[AnalysisException] { df1.join(df2, df1("id") === df2("id")).collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2278,7 +2181,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.filter("salary > 50").collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2299,7 +2202,7 @@ class DataSourceV2DataFrameSuite exception = intercept[AnalysisException] { df.groupBy("id").agg(sum("salary")).collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2318,7 +2221,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.orderBy("salary").collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2337,7 +2240,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.select("salary").collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2361,7 +2264,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2413,7 +2316,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) } @@ -2434,7 +2337,7 @@ class DataSourceV2DataFrameSuite // stale DataFrame detects salary ID mismatch checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*salary.*")) @@ -2465,9 +2368,15 @@ class DataSourceV2DataFrameSuite // reset-id catalog assigns a new ID for the widened column checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, - parameters = Map("tableName" -> ".*", "errors" -> ".*")) + parameters = Map( + "tableName" -> ".*", + "errors" -> + """| + |- `salary` field ID has changed from \d+ to \d+ + |- `salary` type has changed from INT to BIGINT + |""".stripMargin.strip)) } } @@ -2568,7 +2477,7 @@ class DataSourceV2DataFrameSuite // [[commandExecuted]] phase, before the refresh phase runs. As a result, // column ID validation does not apply to the source DataFrame in a // [[writeTo]] path. The append succeeds without throwing a - // COLUMN_ID_MISMATCH error. + // COLUMNS_MISMATCH error. test("writeTo().append() does not throw column ID mismatch after drop+re-add column") { val t = "testcat.ns1.ns2.tbl" withTable(t) { @@ -2580,7 +2489,7 @@ class DataSourceV2DataFrameSuite sql(s"ALTER TABLE $t ADD COLUMN salary INT") // Command is eagerly executed before the refresh phase validates - // column IDs. No COLUMN_ID_MISMATCH exception is thrown. + // column IDs. No COLUMNS_MISMATCH exception is thrown. source.writeTo(t).append() } } @@ -2599,7 +2508,7 @@ class DataSourceV2DataFrameSuite exception = intercept[AnalysisException] { source.write.format(v2Format).insertInto(t) }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*")) } @@ -2627,7 +2536,7 @@ class DataSourceV2DataFrameSuite checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*")) } @@ -2670,7 +2579,7 @@ class DataSourceV2DataFrameSuite sql(s"ALTER TABLE $t ADD COLUMN bonus INT") // The stale DataFrame has only [id, salary] while the table now has - // [id, salary, bonus]. Since column IDs are null, no COLUMN_ID_MISMATCH + // [id, salary, bonus]. Since column IDs are null, no COLUMNS_MISMATCH // error is thrown; new columns are tolerated for read queries. checkAnswer(df, Seq(Row(1, 100))) } @@ -2739,6 +2648,22 @@ class DataSourceV2DataFrameSuite } } + test("SPARK-57544: V1 CTAS from DSv2 scan does not persist column IDs in catalog schema") { + val v2Src = "testcat.ns1.ns2.v2_src" + val v1Dst = "v1_dst" + withTable(v2Src, v1Dst) { + sql(s"CREATE TABLE $v2Src (id INT, salary INT) USING foo") + sql(s"INSERT INTO $v2Src VALUES (1, 100)") + + // DSv2 source carries column IDs on its output attributes. + assert(spark.table(v2Src).schema.fields.forall(_.id.isDefined)) + + // V1 CTAS from a DSv2 scan: column IDs must not be stored in the V1 catalog schema. + sql(s"CREATE TABLE $v1Dst USING parquet AS SELECT * FROM $v2Src") + assert(spark.table(v1Dst).schema.fields.forall(_.id.isEmpty)) + } + } + test("SPARK-53924: temp view on DSv2 table allows top-level column additions") { val t = "testcat.ns1.ns2.tbl" withTable(t) { @@ -3459,7 +3384,7 @@ class DataSourceV2DataFrameSuite df.write.mode("append").format(v2Format).withSchemaEvolution().saveAsTable(t) - assert(spark.table(t).schema === + assert(SchemaUtils.clearFieldIds(spark.table(t).schema) === new StructType().add("id", LongType).add("data", StringType)) checkAnswer(spark.table(t), Seq(Row(1L, "a"))) } @@ -3473,7 +3398,7 @@ class DataSourceV2DataFrameSuite df.write.format(v2Format).withSchemaEvolution().insertInto(t) - assert(spark.table(t).schema === + assert(SchemaUtils.clearFieldIds(spark.table(t).schema) === new StructType().add("id", LongType).add("data", StringType)) checkAnswer(spark.table(t), Seq(Row(1L, "a"))) } @@ -3487,7 +3412,7 @@ class DataSourceV2DataFrameSuite df.write.mode("overwrite").format(v2Format).withSchemaEvolution().insertInto(t) - assert(spark.table(t).schema === + assert(SchemaUtils.clearFieldIds(spark.table(t).schema) === new StructType().add("id", LongType).add("data", StringType)) checkAnswer(spark.table(t), Seq(Row(1L, "a"))) } @@ -3502,7 +3427,7 @@ class DataSourceV2DataFrameSuite Seq((1L, "a")).toDF("id", "data") .write.mode("overwrite").format(v2Format).withSchemaEvolution().insertInto(t) - assert(spark.table(t).schema === + assert(SchemaUtils.clearFieldIds(spark.table(t).schema) === new org.apache.spark.sql.types.StructType() .add("id", org.apache.spark.sql.types.LongType) .add("data", StringType)) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2ExtSessionColumnIdSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2ExtSessionColumnIdSuite.scala index ed46f33e7df01..521dc45f351aa 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2ExtSessionColumnIdSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2ExtSessionColumnIdSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row, SparkSession} -import org.apache.spark.sql.connector.catalog.SharedInMemoryTableCatalog +import org.apache.spark.sql.connector.catalog.{InMemoryBaseTable, SharedInMemoryTableCatalog} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -47,6 +47,7 @@ class DataSourceV2ExtSessionColumnIdSuite extends QueryTest with SharedSparkSess // copyOnLoad: each loadTable returns a fresh copy, simulating a real // catalog where metadata is reloaded from the metastore on each access .set("spark.sql.catalog.sharedcat.copyOnLoad", "true") + .set(InMemoryBaseTable.ASSIGN_COLUMN_IDS, "true") override def afterEach(): Unit = { try { @@ -133,7 +134,7 @@ class DataSourceV2ExtSessionColumnIdSuite extends QueryTest with SharedSparkSess exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map( "tableName" -> ".*", "errors" -> ".*")) @@ -156,7 +157,7 @@ class DataSourceV2ExtSessionColumnIdSuite extends QueryTest with SharedSparkSess // NullTableIdInMemoryTableCatalog), so column ID check catches it checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*")) } @@ -201,7 +202,7 @@ class DataSourceV2ExtSessionColumnIdSuite extends QueryTest with SharedSparkSess // both column ID mismatches are detected checkError( exception = intercept[AnalysisException] { df.collect() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> "(?s).*salary.*bonus.*")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataTableSuite.scala index 37acbf1e0442f..ef9a69eaf25ce 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataTableSuite.scala @@ -20,14 +20,14 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkConf import org.apache.spark.sql.{QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.NoSuchTableException -import org.apache.spark.sql.connector.catalog.{Identifier, MetadataTable, Table, TableCatalog, TableChange, TableInfo, TableSummary} +import org.apache.spark.sql.connector.catalog.{DelegatingTable, Identifier, Table, TableCatalog, TableChange, TableInfo, TableSummary} import org.apache.spark.sql.connector.expressions.LogicalExpressions import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * Tests for the data-source-table side of [[MetadataTable]]: a v2 catalog returns + * Tests for the data-source-table side of [[DelegatingTable]]: a v2 catalog returns * metadata-only tables and Spark reads / writes them via the V1 data-source path. * View-related paths live in [[DataSourceV2MetadataViewSuite]]. */ @@ -111,7 +111,7 @@ class DataSourceV2MetadataTableSuite extends QueryTest with SharedSparkSession { } /** - * A read-only [[TableCatalog]] that returns [[MetadataTable]] for a small set of canned + * A read-only [[TableCatalog]] that returns [[DelegatingTable]] for a small set of canned * table fixtures. Used to drive the data-source-table read path (file source + v2 provider) * through Spark's V1 data-source machinery. */ @@ -124,7 +124,7 @@ class TestingDataSourceTableCatalog extends TableCatalog { .withLocation(ident.namespace().head) .withTableType(TableSummary.EXTERNAL_TABLE_TYPE) .build() - new MetadataTable(info, ident.toString) + new DelegatingTable(info, ident.toString) case "test_partitioned_json" => val partitioning = LogicalExpressions.identity(LogicalExpressions.reference(Seq("c2"))) val info = new TableInfo.Builder() @@ -134,13 +134,13 @@ class TestingDataSourceTableCatalog extends TableCatalog { .withTableType(TableSummary.EXTERNAL_TABLE_TYPE) .withPartitions(Array(partitioning)) .build() - new MetadataTable(info, ident.toString) + new DelegatingTable(info, ident.toString) case "test_v2" => val info = new TableInfo.Builder() .withSchema(FakeV2Provider.schema) .withProvider(classOf[FakeV2Provider].getName) .build() - new MetadataTable(info, ident.toString) + new DelegatingTable(info, ident.toString) case _ => throw new NoSuchTableException(ident) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataViewSuite.scala index 163c0957e0d0e..b8aaf59c816da 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2MetadataViewSuite.scala @@ -20,21 +20,21 @@ package org.apache.spark.sql.connector import org.apache.spark.SparkConf import org.apache.spark.sql.{AnalysisException, QueryTest, Row} import org.apache.spark.sql.catalyst.analysis.{NoSuchTableException, NoSuchViewException, TableAlreadyExistsException, ViewAlreadyExistsException} -import org.apache.spark.sql.connector.catalog.{Identifier, MetadataTable, Table, TableCatalog, TableChange, TableInfo, TableSummary, TableViewCatalog, V1Table, ViewCatalog, ViewInfo} +import org.apache.spark.sql.connector.catalog.{DelegatingTable, Identifier, Relation, RelationCatalog, Table, TableCatalog, TableChange, TableInfo, TableSummary, V1Table, View, ViewCatalog} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap /** - * Tests for the view side of [[MetadataTable]]: view-text expansion on read, and + * Tests for the view side of [[DelegatingTable]]: view-text expansion on read, and * CREATE VIEW / ALTER VIEW ... AS going through the v2 write path * (`CreateV2ViewExec` / `AlterV2ViewExec`). View writes route through * [[ViewCatalog#createView]] / [[ViewCatalog#replaceView]]. * Data-source-table read paths live in * [[org.apache.spark.sql.connector.DataSourceV2MetadataTableSuite]]. * - * TODO: register a `MetadataTable`-backed `DelegatingCatalogExtension` as + * TODO: register a `DelegatingTable`-backed `DelegatingCatalogExtension` as * `spark.sql.catalog.spark_catalog` and run the shared * [[org.apache.spark.sql.execution.PersistedViewTestSuite]] body against the v2 path for full * parity with the v1 persisted-view coverage. @@ -43,7 +43,7 @@ class DataSourceV2MetadataViewSuite extends QueryTest with SharedSparkSession { import testImplicits._ override def sparkConf: SparkConf = super.sparkConf - .set("spark.sql.catalog.view_catalog", classOf[TestingTableViewCatalog].getName) + .set("spark.sql.catalog.view_catalog", classOf[TestingRelationCatalog].getName) // --- View read path ----------------------------------------------------- @@ -71,54 +71,52 @@ class DataSourceV2MetadataViewSuite extends QueryTest with SharedSparkSession { // End-to-end coverage of the v2 encoder -> parser round-trip: test_unqualified_multi is a // view whose captured catalog+namespace is view_catalog.ns1.ns2 (two-part namespace) and // whose body references `t` unqualified. At read time the unqualified `t` must expand to - // view_catalog.ns1.ns2.t via the captured context -- which TestingTableViewCatalog resolves to + // view_catalog.ns1.ns2.t via the captured context -- which TestingRelationCatalog resolves to // its own `t` fixture at that namespace. checkAnswer( spark.table("view_catalog.outer_ns.test_unqualified_multi"), Row("multi")) } - // --- ViewInfo unit tests ----------------------------------------------- + // --- View unit tests ----------------------------------------------- test("multi-part captured namespace round-trips through V1Table.toCatalogTable") { - // (a) ViewInfo.Builder stores (cat, Array(db1, db2)) as typed fields. + // (a) View.Builder stores (cat, Array(db1, db2)) as typed fields. // (b) V1Table.toCatalogTable reads them directly and emits v1's numbered // view.catalogAndNamespace.* keys so (c) the resulting CatalogTable's // `viewCatalogAndNamespace` exposes the full (cat, db1, db2), which is what the v1 // view-resolution path consumes to expand unqualified references in the view body. - val info = new ViewInfo.Builder() + val info = new View.Builder() .withSchema(new StructType().add("col", "string")) .withQueryText("SELECT col FROM t") .withCurrentCatalog("my_cat") .withCurrentNamespace(Array("db1", "db2")) .build() - val motTable = new MetadataTable(info, "v") // Any CatalogPlugin works here; toCatalogTable only reads `catalog.name()`. val catalog = spark.sessionState.catalogManager.catalog("view_catalog") val ct = V1Table.toCatalogTable( - catalog, Identifier.of(Array("ns"), "v"), motTable) + catalog, Identifier.of(Array("ns"), "v"), info) assert(ct.viewCatalogAndNamespace == Seq("my_cat", "db1", "db2")) // Namespace parts containing dots flow through structurally (no string encoding). - val infoWeird = new ViewInfo.Builder() + val infoWeird = new View.Builder() .withSchema(new StructType().add("col", "string")) .withQueryText("SELECT col FROM t") .withCurrentCatalog("my_cat") .withCurrentNamespace(Array("weird.db", "normal")) .build() val ctWeird = V1Table.toCatalogTable( - catalog, Identifier.of(Array("ns"), "v"), new MetadataTable(infoWeird, "v")) + catalog, Identifier.of(Array("ns"), "v"), infoWeird) assert(ctWeird.viewCatalogAndNamespace == Seq("my_cat", "weird.db", "normal")) } test("view with no captured catalog omits viewCatalogAndNamespace") { - val info = new ViewInfo.Builder() + val info = new View.Builder() .withSchema(new StructType().add("col", "string")) .withQueryText("SELECT * FROM spark_catalog.default.t") .build() - val motTable = new MetadataTable(info, "v") val catalog = spark.sessionState.catalogManager.catalog("view_catalog") - val ct = V1Table.toCatalogTable(catalog, Identifier.of(Array("ns"), "v"), motTable) + val ct = V1Table.toCatalogTable(catalog, Identifier.of(Array("ns"), "v"), info) assert(ct.viewCatalogAndNamespace.isEmpty) } @@ -209,7 +207,7 @@ class DataSourceV2MetadataViewSuite extends QueryTest with SharedSparkSession { // `unsupportedCreateOrReplaceViewOnTableError`. Pre-seed a non-view entry at a // multi-level-namespace identifier to exercise the rendering. val catalog = spark.sessionState.catalogManager.catalog("view_catalog") - .asInstanceOf[TestingTableViewCatalog] + .asInstanceOf[TestingRelationCatalog] val tblIdent = Identifier.of(Array("ns1", "inner"), "t_err") catalog.createTable( tblIdent, @@ -347,7 +345,7 @@ class DataSourceV2MetadataViewSuite extends QueryTest with SharedSparkSession { test("DESCRIBE TABLE EXTENDED ... AS JSON on a v2 view succeeds") { // `DescribeRelationJsonCommand` is a v1 runnable command that reads v1-shaped fields off // a `CatalogTable`. For non-session v2 views the resolved `ResolvedPersistentView.info` - // is a plain `ViewInfo`; the command projects it to a `CatalogTable` via + // is a plain `View`; the command projects it to a `CatalogTable` via // `V1Table.toCatalogTable` so DESC ... AS JSON works uniformly across session and // non-session view catalogs. seedV2View("v_desc_json") @@ -366,7 +364,7 @@ class DataSourceV2MetadataViewSuite extends QueryTest with SharedSparkSession { private def seedV2Table(name: String): Unit = { val catalog = spark.sessionState.catalogManager.catalog("view_catalog") - .asInstanceOf[TestingTableViewCatalog] + .asInstanceOf[TestingRelationCatalog] catalog.createTable( Identifier.of(Array("default"), name), new TableInfo.Builder() @@ -375,9 +373,9 @@ class DataSourceV2MetadataViewSuite extends QueryTest with SharedSparkSession { .build()) } - test("SHOW TABLES on a TableViewCatalog returns both tables and views (v1-parity)") { - // For a `TableViewCatalog` (a catalog exposing both tables and views in a shared - // identifier namespace), SHOW TABLES routes through `listTableAndViewSummaries` so views + test("SHOW TABLES on a RelationCatalog returns both tables and views (v1-parity)") { + // For a `RelationCatalog` (a catalog exposing both tables and views in a shared + // identifier namespace), SHOW TABLES routes through `listRelationSummaries` so views // appear alongside tables -- matching the v1 SHOW TABLES output. Pure `TableCatalog` // catalogs (no view mixin) continue to use `listTables` and return tables only. seedV2View("v_in_show_tables") @@ -394,29 +392,29 @@ class DataSourceV2MetadataViewSuite extends QueryTest with SharedSparkSession { } /** - * A [[TableViewCatalog]]: round-trips [[MetadataTable]] for created views and tables and + * A [[RelationCatalog]]: round-trips [[DelegatingTable]] for created views and tables and * exposes a few canned read-only view fixtures (`test_view`, `test_unqualified_view`, * `test_unqualified_multi`, plus an unqualified-target view at `ns1.ns2.t`) used by the * view-read tests. Entries created via `createTable` / `createView` are distinguished by the - * stored value's runtime type (ViewInfo vs TableInfo). The single-RPC perf entry point - * [[loadTableOrView]] returns either kind; [[loadTable]] is tables-only per the + * stored [[Relation]]'s runtime type (View vs Table). The single-RPC perf entry point + * [[loadRelation]] returns either kind; [[loadTable]] is tables-only per the * [[TableCatalog#loadTable]] contract. */ -class TestingTableViewCatalog extends TableViewCatalog { +class TestingRelationCatalog extends RelationCatalog { // Holds entries (views and tables) created via createTable / createView within the session. - // Keyed by (namespace, name); the stored value's runtime type (ViewInfo vs TableInfo) - // distinguishes views from tables. Mixed-catalog: shared identifier namespace per the - // TableViewCatalog contract. + // Keyed by (namespace, name); the stored [[Relation]]'s runtime type (View vs Table) + // distinguishes views from tables. Tables are stored as a DelegatingTable wrapping the + // TableInfo. Mixed-catalog: shared identifier namespace per the RelationCatalog contract. private val createdViews = - new java.util.concurrent.ConcurrentHashMap[(Seq[String], String), TableInfo]() + new java.util.concurrent.ConcurrentHashMap[(Seq[String], String), Relation]() - // Canned read-only view fixtures, exposed only via the perf path (loadTableOrView). loadView - // does not need to expose them because the resolver routes TableViewCatalog reads through - // loadTableOrView. - private def fixtureView(ident: Identifier): Option[ViewInfo] = ident.name() match { + // Canned read-only view fixtures, exposed only via the perf path (loadRelation). loadView + // does not need to expose them because the resolver routes RelationCatalog reads through + // loadRelation. + private def fixtureView(ident: Identifier): Option[View] = ident.name() match { case "test_view" => - Some(new ViewInfo.Builder() + Some(new View.Builder() .withSchema(new StructType().add("col", "string").add("i", "int")) .withQueryText( "SELECT col, col::int AS i FROM spark_catalog.default.t WHERE col = 'b'") @@ -424,7 +422,7 @@ class TestingTableViewCatalog extends TableViewCatalog { SQLConf.ANSI_ENABLED.key, (ident.namespace().head == "ansi").toString)) .build()) case "test_unqualified_view" => - Some(new ViewInfo.Builder() + Some(new View.Builder() .withSchema(new StructType().add("col", "string")) .withQueryText("SELECT col FROM t WHERE col = 'b'") .withCurrentCatalog("spark_catalog") @@ -434,7 +432,7 @@ class TestingTableViewCatalog extends TableViewCatalog { // View whose captured catalog+namespace is view_catalog.ns1.ns2 (two-part). The // unqualified `t` in the body must resolve via that captured context to // view_catalog.ns1.ns2.t, which this catalog also serves (see `t` case below). - Some(new ViewInfo.Builder() + Some(new View.Builder() .withSchema(new StructType().add("col", "string")) .withQueryText("SELECT col FROM t") .withCurrentCatalog("view_catalog") @@ -443,22 +441,21 @@ class TestingTableViewCatalog extends TableViewCatalog { case "t" if ident.namespace().toSeq == Seq("ns1", "ns2") => // Target of test_unqualified_multi's unqualified reference. Self-contained view so // the test doesn't need external data. - Some(new ViewInfo.Builder() + Some(new View.Builder() .withSchema(new StructType().add("col", "string")) .withQueryText("SELECT 'multi' AS col") .build()) case _ => None } - override def loadTableOrView(ident: Identifier): Table = { - // Single-RPC perf path: returns tables AND views (as MetadataTable). Stored entries - // win over fixture views (the fixture namespace is read-only and disjoint from - // createdViews in practice). loadTable, loadView, tableExists, viewExists all derive - // from this via the TableViewCatalog default impls. + override def loadRelation(ident: Identifier): Relation = { + // Single-RPC perf path: returns tables AND views. Stored entries win over fixture views + // (the fixture namespace is read-only and disjoint from createdViews in practice). + // loadTable, loadView, tableExists, viewExists all derive from this via the + // RelationCatalog default impls. val key = (ident.namespace().toSeq, ident.name()) Option(createdViews.get(key)) .orElse(fixtureView(ident)) - .map(new MetadataTable(_, ident.toString)) .getOrElse(throw new NoSuchTableException(ident)) } @@ -467,23 +464,24 @@ class TestingTableViewCatalog extends TableViewCatalog { // TableAlreadyExistsException. The shared `createdViews` keyspace makes `putIfAbsent` // throw uniformly for both table-at-ident and view-at-ident collisions. val key = (ident.namespace().toSeq, ident.name()) - if (createdViews.putIfAbsent(key, info) != null) { + val table = new DelegatingTable(info, ident.toString) + if (createdViews.putIfAbsent(key, table) != null) { throw new TableAlreadyExistsException(ident) } - new MetadataTable(info, ident.toString) + table } - /** Test-only accessor: returns the stored TableInfo (table or view) for the identifier. */ - def getStoredInfo(namespace: Array[String], name: String): TableInfo = { + /** Test-only accessor: returns the stored Relation (table or view) for the identifier. */ + def getStoredInfo(namespace: Array[String], name: String): Relation = { Option(createdViews.get((namespace.toSeq, name))).getOrElse { throw new NoSuchTableException(Identifier.of(namespace, name)) } } - /** Test-only accessor: returns the stored ViewInfo; fails if the entry is not a view. */ - def getStoredView(namespace: Array[String], name: String): ViewInfo = getStoredInfo( + /** Test-only accessor: returns the stored View; fails if the entry is not a view. */ + def getStoredView(namespace: Array[String], name: String): View = getStoredInfo( namespace, name) match { - case v: ViewInfo => v + case v: View => v case _ => throw new IllegalStateException( s"stored entry at ${namespace.mkString(".")}.$name is not a view") } @@ -494,7 +492,7 @@ class TestingTableViewCatalog extends TableViewCatalog { override def dropTable(ident: Identifier): Boolean = { val key = (ident.namespace().toSeq, ident.name()) val existing = createdViews.get(key) - if (existing == null || existing.isInstanceOf[ViewInfo]) return false + if (existing == null || existing.isInstanceOf[View]) return false createdViews.remove(key) != null } override def renameTable(oldIdent: Identifier, newIdent: Identifier): Unit = { @@ -505,7 +503,7 @@ class TestingTableViewCatalog extends TableViewCatalog { val targetNs = namespace.toSeq val ids = new java.util.ArrayList[Identifier]() createdViews.forEach { (key, info) => - if (key._1 == targetNs && !info.isInstanceOf[ViewInfo]) { + if (key._1 == targetNs && !info.isInstanceOf[View]) { ids.add(Identifier.of(key._1.toArray, key._2)) } } @@ -518,14 +516,14 @@ class TestingTableViewCatalog extends TableViewCatalog { val targetNs = namespace.toSeq val ids = new java.util.ArrayList[Identifier]() createdViews.forEach { (key, info) => - if (key._1 == targetNs && info.isInstanceOf[ViewInfo]) { + if (key._1 == targetNs && info.isInstanceOf[View]) { ids.add(Identifier.of(key._1.toArray, key._2)) } } ids.toArray(new Array[Identifier](0)) } - override def createView(ident: Identifier, info: ViewInfo): ViewInfo = { + override def createView(ident: Identifier, info: View): View = { val key = (ident.namespace().toSeq, ident.name()) if (createdViews.putIfAbsent(key, info) != null) { throw new ViewAlreadyExistsException(ident) @@ -533,10 +531,10 @@ class TestingTableViewCatalog extends TableViewCatalog { info } - override def replaceView(ident: Identifier, info: ViewInfo): ViewInfo = { + override def replaceView(ident: Identifier, info: View): View = { val key = (ident.namespace().toSeq, ident.name()) val existing = createdViews.get(key) - if (existing == null || !existing.isInstanceOf[ViewInfo]) { + if (existing == null || !existing.isInstanceOf[View]) { throw new NoSuchViewException(ident) } createdViews.put(key, info) @@ -546,7 +544,7 @@ class TestingTableViewCatalog extends TableViewCatalog { override def dropView(ident: Identifier): Boolean = { val key = (ident.namespace().toSeq, ident.name()) val existing = createdViews.get(key) - if (existing == null || !existing.isInstanceOf[ViewInfo]) return false + if (existing == null || !existing.isInstanceOf[View]) return false createdViews.remove(key) != null } @@ -554,7 +552,7 @@ class TestingTableViewCatalog extends TableViewCatalog { val oldKey = (oldIdent.namespace().toSeq, oldIdent.name()) val newKey = (newIdent.namespace().toSeq, newIdent.name()) val existing = createdViews.get(oldKey) - if (existing == null || !existing.isInstanceOf[ViewInfo]) { + if (existing == null || !existing.isInstanceOf[View]) { throw new NoSuchViewException(oldIdent) } if (createdViews.putIfAbsent(newKey, existing) != null) { @@ -600,14 +598,14 @@ class TestingTableOnlyCatalog extends TableCatalog { */ class TestingViewOnlyCatalog extends ViewCatalog { private val store = - new java.util.concurrent.ConcurrentHashMap[(Seq[String], String), ViewInfo]() + new java.util.concurrent.ConcurrentHashMap[(Seq[String], String), View]() // Seeded on first `initialize`. Filters `spark_catalog.default.t` so the read test can // assert deterministic output. ALTER VIEW tests overwrite it via `replaceView`. private def seedDefault(): Unit = { val key = (Seq("default"), "pure_v") if (!store.containsKey(key)) { - val info = new ViewInfo.Builder() + val info = new View.Builder() .withSchema(new StructType().add("x", "int")) .withQueryText("SELECT x FROM spark_catalog.default.t WHERE x > 1") .build() @@ -624,12 +622,12 @@ class TestingViewOnlyCatalog extends ViewCatalog { ids.toArray(new Array[Identifier](0)) } - override def loadView(ident: Identifier): ViewInfo = { + override def loadView(ident: Identifier): View = { val key = (ident.namespace().toSeq, ident.name()) Option(store.get(key)).getOrElse(throw new NoSuchViewException(ident)) } - override def createView(ident: Identifier, info: ViewInfo): ViewInfo = { + override def createView(ident: Identifier, info: View): View = { val key = (ident.namespace().toSeq, ident.name()) if (store.putIfAbsent(key, info) != null) { throw new ViewAlreadyExistsException(ident) @@ -637,7 +635,7 @@ class TestingViewOnlyCatalog extends ViewCatalog { info } - override def replaceView(ident: Identifier, info: ViewInfo): ViewInfo = { + override def replaceView(ident: Identifier, info: View): View = { val key = (ident.namespace().toSeq, ident.name()) if (!store.containsKey(key)) throw new NoSuchViewException(ident) store.put(key, info) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala index cb7531a0dbafd..150757dd07413 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/DataSourceV2SQLSuite.scala @@ -50,7 +50,7 @@ import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.internal.SQLConf.{PARTITION_OVERWRITE_MODE, PartitionOverwriteMode, V2_SESSION_CATALOG_IMPLEMENTATION} import org.apache.spark.sql.sources.SimpleScanSource -import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructType} +import org.apache.spark.sql.types.{IntegerType, LongType, StringType, StructField, StructType} import org.apache.spark.unsafe.types.UTF8String abstract class DataSourceV2SQLSuite @@ -457,6 +457,44 @@ class DataSourceV2SQLSuiteV1Filter } } + test("CreateTableAsSelect: field IDs in query schema are not propagated to table columns") { + val basicCatalog = catalog("testcat").asTableCatalog + val atomicCatalog = catalog("testcat_atomic").asTableCatalog + val basicIdentifier = "testcat.table_name" + val atomicIdentifier = "testcat_atomic.table_name" + + // Use a non-numeric marker so it can never clash with InMemoryTable's sequential numeric IDs. + val sourceId = "source-id" + val nestedType = new StructType(Array(StructField("value", LongType).withId(sourceId))) + val schema = new StructType(Array( + StructField("id", LongType).withId(sourceId), + StructField("nested", nestedType).withId(sourceId))) + val sourceWithIds = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(1L, Row(42L)), Row(2L, Row(43L)))), schema) + + withTempView("source_with_ids") { + sourceWithIds.createOrReplaceTempView("source_with_ids") + Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach { + case (cat, identifier) => + withTable(identifier) { + spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT * FROM source_with_ids") + val table = cat.loadTable(Identifier.of(Array(), "table_name")) + table.columns.foreach { col => + assert(col.metadataInJSON == null) + assert(col.id != sourceId) + col.dataType match { + case s: StructType => + s.fields.foreach { f => + assert(f.id.forall(_ != sourceId)) + } + case _ => + } + } + } + } + } + } + test("CreateTableAsSelect: do not double execute on collect(), take() and other queries") { val basicCatalog = catalog("testcat").asTableCatalog val atomicCatalog = catalog("testcat_atomic").asTableCatalog @@ -615,6 +653,44 @@ class DataSourceV2SQLSuiteV1Filter } } + test("ReplaceTableAsSelect: field IDs in query schema are not propagated to table columns") { + val basicCatalog = catalog("testcat").asTableCatalog + val atomicCatalog = catalog("testcat_atomic").asTableCatalog + val basicIdentifier = "testcat.table_name" + val atomicIdentifier = "testcat_atomic.table_name" + + val sourceId = "source-id" + val nestedType = new StructType(Array(StructField("value", LongType).withId(sourceId))) + val schema = new StructType(Array( + StructField("id", LongType).withId(sourceId), + StructField("nested", nestedType).withId(sourceId))) + val sourceWithIds = spark.createDataFrame( + spark.sparkContext.parallelize(Seq(Row(1L, Row(42L)), Row(2L, Row(43L)))), schema) + + withTempView("source_with_ids") { + sourceWithIds.createOrReplaceTempView("source_with_ids") + Seq((basicCatalog, basicIdentifier), (atomicCatalog, atomicIdentifier)).foreach { + case (cat, identifier) => + withTable(identifier) { + spark.sql(s"CREATE TABLE $identifier USING foo AS SELECT * FROM source_with_ids") + spark.sql(s"REPLACE TABLE $identifier USING foo AS SELECT * FROM source_with_ids") + val table = cat.loadTable(Identifier.of(Array(), "table_name")) + table.columns.foreach { col => + assert(col.metadataInJSON == null) + assert(col.id != sourceId) + col.dataType match { + case s: StructType => + s.fields.foreach { f => + assert(f.id.forall(_ != sourceId)) + } + case _ => + } + } + } + } + } + } + Seq("REPLACE", "CREATE OR REPLACE").foreach { cmd => test(s"ReplaceTableAsSelect: do not double execute $cmd on collect()") { val basicCatalog = catalog("testcat").asTableCatalog diff --git a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala index cb59ce80328d8..9c471031eeb94 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/connector/MergeIntoDataFrameSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.connector +import org.apache.spark.SparkConf import org.apache.spark.sql.{sources, Column, Row} import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.classic.MergeIntoWriter @@ -29,6 +30,9 @@ import org.apache.spark.sql.types.StringType class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { + override protected def sparkConf: SparkConf = super.sparkConf + .set(InMemoryBaseTable.ASSIGN_COLUMN_IDS, "true") + import testImplicits._ private def targetTableCol(colName: String): Column = { @@ -180,7 +184,7 @@ class MergeIntoDataFrameSuite extends RowLevelOperationSuiteBase { .update(Map("salary" -> targetTableCol("salary").plus(1))) .merge() }, - condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMN_ID_MISMATCH", + condition = "INCOMPATIBLE_TABLE_CHANGE_AFTER_ANALYSIS.COLUMNS_MISMATCH", matchPVals = true, parameters = Map("tableName" -> ".*", "errors" -> ".*")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/MetricViewV2CatalogSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/MetricViewV2CatalogSuite.scala index fedc2475f90ed..816590475ef62 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/MetricViewV2CatalogSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/MetricViewV2CatalogSuite.scala @@ -23,7 +23,7 @@ import scala.jdk.CollectionConverters._ import org.apache.spark.sql.{AnalysisException, QueryTest} import org.apache.spark.sql.catalyst.analysis.{NoSuchViewException, ViewAlreadyExistsException} -import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, MetadataTable, Table, TableCatalog, TableDependency, TableSummary, TableViewCatalog, ViewInfo} +import org.apache.spark.sql.connector.catalog.{Identifier, InMemoryTableCatalog, Relation, RelationCatalog, TableCatalog, TableDependency, TableSummary, View} import org.apache.spark.sql.metricview.serde.{AssetSource, Column, Constants, DimensionExpression, MeasureExpression, MetricView, MetricViewFactory, SQLSource} import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types.Metadata @@ -35,8 +35,8 @@ import org.apache.spark.sql.types.Metadata * [[org.apache.spark.sql.execution.datasources.v2.CreateV2MetricViewExec]]). * Metric views are persisted through the same [[ViewCatalog]] interface * as plain views; the only marker that distinguishes them is `PROP_TABLE_TYPE = METRIC_VIEW` - * plus the typed `viewDependencies` field on [[ViewInfo]]. The recording catalog used here is a - * minimal [[TableViewCatalog]] so the same instance can also host the source table referenced by + * plus the typed `viewDependencies` field on [[View]]. The recording catalog used here is a + * minimal [[RelationCatalog]] so the same instance can also host the source table referenced by * the metric view's YAML. */ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { @@ -121,11 +121,11 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { yaml } - private def capturedViewInfo(): ViewInfo = { + private def capturedViewInfo(): View = { val ident = Identifier.of(Array(testNamespace), metricViewName) val info = MetricViewRecordingCatalog.capturedViews.get(ident) assert(info != null, - s"Expected ViewInfo for $ident to be captured by the V2 catalog") + s"Expected View for $ident to be captured by the V2 catalog") info } @@ -134,7 +134,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { // ============================================================ - test("V2 catalog receives METRIC_VIEW table type and view text via ViewInfo") { + test("V2 catalog receives METRIC_VIEW table type and view text via View") { withTestCatalogTables { val metricView = MetricView( "0.1", @@ -144,7 +144,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { val yaml = createMetricView(fullMetricViewName, metricView) val info = capturedViewInfo() - // PROP_TABLE_TYPE is overwritten to METRIC_VIEW after `ViewInfo`'s constructor stamps it + // PROP_TABLE_TYPE is overwritten to METRIC_VIEW after `View`'s constructor stamps it // to VIEW; this is the marker `V1Table.toCatalogTable` reads to map the round-tripped row // back to `CatalogTableType.METRIC_VIEW`. assert(info.properties().get(TableCatalog.PROP_TABLE_TYPE) @@ -163,7 +163,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { } } - test("V2 catalog path populates metric_view.* + view context + sql configs on ViewInfo") { + test("V2 catalog path populates metric_view.* + view context + sql configs on View") { withTestCatalogTables { val metricView = MetricView( "0.1", @@ -182,7 +182,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { assert(props.get(MetricView.PROP_FROM_SQL) === null) assert(props.get(MetricView.PROP_WHERE) === "count > 0") - // SQL configs and current catalog/namespace are first-class typed fields on ViewInfo, no + // SQL configs and current catalog/namespace are first-class typed fields on View, no // longer encoded into properties for V2 catalogs. assert(info.sqlConfigs().size > 0, s"Expected at least one captured SQL config; got ${info.sqlConfigs()}") @@ -336,7 +336,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { // Now CREATE VIEW IF NOT EXISTS with a different YAML body. The catalog should not see // the second create at all (V2ViewPreparation's `viewExists` short-circuit fires before - // `buildViewInfo`), so the captured ViewInfo retains the original body. + // `buildViewInfo`), so the captured View retains the original body. val replacement = MetricView( "0.1", AssetSource(fullSourceTableName), @@ -471,7 +471,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { val deepIdent = Identifier.of(deepNamespace, deepMetricViewName) val info = MetricViewRecordingCatalog.capturedViews.get(deepIdent) - assert(info != null, s"Expected ViewInfo for $deepIdent to be captured") + assert(info != null, s"Expected View for $deepIdent to be captured") assert(info.properties().get(TableCatalog.PROP_TABLE_TYPE) === TableSummary.METRIC_VIEW_TABLE_TYPE) } finally { @@ -607,7 +607,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { val multiTable = "events_deep" val multiFull = s"$testCatalogName.${multiNamespace.mkString(".")}.$multiTable" withTestCatalogTables { - // The InMemoryTableCatalog (TableViewCatalog mixin) supports multi-level namespaces. + // The InMemoryTableCatalog (RelationCatalog mixin) supports multi-level namespaces. sql(s"CREATE NAMESPACE IF NOT EXISTS $testCatalogName.${multiNamespace.head}") sql(s"CREATE NAMESPACE IF NOT EXISTS " + s"$testCatalogName.${multiNamespace.mkString(".")}") @@ -650,7 +650,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { createMetricView(fullMetricViewName, mv) // The fixture's `events` source has rows ("region_1", 1, 5.0), ("region_2", 2, 10.0). // The metric view aggregates by `region` summing `count`. Resolution flows through - // loadTableOrView -> MetadataTable(ViewInfo) -> V1Table.toCatalogTable(ViewInfo) -> + // loadRelation -> DelegatingTable(View) -> V1Table.toCatalogTable(View) -> // CatalogTableType.METRIC_VIEW -> ResolveMetricView, which rewrites the view body // into Aggregate(Seq(region), Seq(sum(count) AS count_sum)) over `events`. The // `measure(...)` wrapper is required for measure columns -- selecting `count_sum` @@ -740,7 +740,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { // ============================================================ - test("DESCRIBE TABLE EXTENDED on a v2 metric view round-trips through loadTableOrView") { + test("DESCRIBE TABLE EXTENDED on a v2 metric view round-trips through loadRelation") { withTestCatalogTables { val mv = MetricView( "0.1", @@ -750,10 +750,10 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { val yaml = createMetricView(fullMetricViewName, mv) // DESCRIBE TABLE EXTENDED resolves the ident through `Analyzer.lookupTableOrView`, - // which calls `TableViewCatalog.loadTableOrView` once and gets back a - // `MetadataTable(ViewInfo)`. The analyzer wraps it as a `ResolvedPersistentView` and + // which calls `RelationCatalog.loadRelation` once and gets back a + // `DelegatingTable(View)`. The analyzer wraps it as a `ResolvedPersistentView` and // `DataSourceV2Strategy` routes through SPARK-56655's `DescribeV2ViewExec`, which - // reads the typed `ViewInfo` directly and emits the standard "Type" / "View Text" / + // reads the typed `View` directly and emits the standard "Type" / "View Text" / // "View Current Catalog" / "View Schema Mode" / etc. rows. Pins that `DescribeV2ViewExec` // emits a "Type" row for parity with v1 `CatalogTable.toJsonLinkedHashMap`, so users // can distinguish a plain VIEW from a sub-kind like METRIC_VIEW. @@ -918,7 +918,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { s"Expected TABLE_OR_VIEW_NOT_FOUND for the old ident, got " + s"${oldEx.getCondition}: ${oldEx.getMessage}") - // New ident loads through `TableViewCatalog.loadTableOrView` and surfaces the same + // New ident loads through `RelationCatalog.loadRelation` and surfaces the same // metric-view kind on `DESCRIBE TABLE EXTENDED`. val rows = sql(s"DESCRIBE TABLE EXTENDED $renamedFull").collect() val rowMap = rows.map(r => r.getString(0) -> r.getString(1)).toMap @@ -930,7 +930,7 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { } } - test("SHOW TABLES on a v2 TableViewCatalog lists both tables and metric views") { + test("SHOW TABLES on a v2 RelationCatalog lists both tables and metric views") { withTestCatalogTables { val mv = MetricView( "0.1", @@ -940,13 +940,13 @@ class MetricViewV2CatalogSuite extends QueryTest with SharedSparkSession { createMetricView(fullMetricViewName, mv) val tables = sql(s"SHOW TABLES IN $testCatalogName.$testNamespace") .collect().map(_.getString(1)).toSet - // SPARK-56655 routes SHOW TABLES on a `TableViewCatalog` through `listRelationSummaries` + // SPARK-56655 routes SHOW TABLES on a `RelationCatalog` through `listRelationSummaries` // so views appear alongside tables in the output (matching v1 SHOW TABLES on a session // catalog). Pure `TableCatalog` catalogs continue to return tables only. assert(tables.contains(sourceTableName), s"SHOW TABLES should list the source table, got: $tables") assert(tables.contains(metricViewName), - s"SHOW TABLES on a TableViewCatalog should also list metric views, got: $tables") + s"SHOW TABLES on a RelationCatalog should also list metric views, got: $tables") } } @@ -971,17 +971,17 @@ object MetricViewV2CatalogSuite { } /** - * Minimal [[TableViewCatalog]] used by [[MetricViewV2CatalogSuite]]. Layers `ViewCatalog` + * Minimal [[RelationCatalog]] used by [[MetricViewV2CatalogSuite]]. Layers `ViewCatalog` * methods over [[InMemoryTableCatalog]] (which provides table storage + namespace ops) and - * captures every [[ViewInfo]] passed to `createView` so tests can inspect the typed payload. + * captures every [[View]] passed to `createView` so tests can inspect the typed payload. * * The metric-view CREATE path goes via `ViewCatalog.createView`, so the captured map keys are * the view identifiers; the source table created by the test fixture is stored separately in * the inherited table catalog. */ -class MetricViewRecordingCatalog extends InMemoryTableCatalog with TableViewCatalog { +class MetricViewRecordingCatalog extends InMemoryTableCatalog with RelationCatalog { private val views = - new ConcurrentHashMap[(Seq[String], String), ViewInfo]() + new ConcurrentHashMap[(Seq[String], String), View]() // -- ViewCatalog methods -- @@ -994,11 +994,11 @@ class MetricViewRecordingCatalog extends InMemoryTableCatalog with TableViewCata out.asScala.toArray } - // `loadView`, `tableExists`, and `viewExists` are inherited from `TableViewCatalog`'s - // defaults, which derive from `loadTableOrView` -- a stored `ViewInfo` is wrapped in - // `MetadataTable` by `loadTableOrView` and the defaults unwrap it correctly. + // `loadView`, `tableExists`, and `viewExists` are inherited from `RelationCatalog`'s + // defaults, which derive from `loadRelation` -- a stored `View` is returned directly by + // `loadRelation` and the defaults discriminate it by type correctly. - // Bypasses `TableViewCatalog.tableExists` (whose default delegates to `loadTableOrView`, + // Bypasses `RelationCatalog.tableExists` (whose default delegates to `loadRelation`, // which checks our `views` map first); we want a tables-only check here so the cross-type // collision branches in `createView` / `replaceView` see only "is there a *table* at this // ident?". @@ -1006,8 +1006,8 @@ class MetricViewRecordingCatalog extends InMemoryTableCatalog with TableViewCata try { super[InMemoryTableCatalog].loadTable(ident); true } catch { case _: org.apache.spark.sql.catalyst.analysis.NoSuchTableException => false } - override def createView(ident: Identifier, info: ViewInfo): ViewInfo = { - // TableViewCatalog active-rejection contract: createView must throw + override def createView(ident: Identifier, info: View): View = { + // RelationCatalog active-rejection contract: createView must throw // ViewAlreadyExistsException when *either* a view *or* a table sits at the ident. if (tableExistsTablesOnly(ident)) { throw new ViewAlreadyExistsException(ident) @@ -1020,8 +1020,8 @@ class MetricViewRecordingCatalog extends InMemoryTableCatalog with TableViewCata info } - override def replaceView(ident: Identifier, info: ViewInfo): ViewInfo = { - // Per the TableViewCatalog contract, replaceView must surface NoSuchViewException + override def replaceView(ident: Identifier, info: View): View = { + // Per the RelationCatalog contract, replaceView must surface NoSuchViewException // when a *table* sits at the ident (not silently succeed and shadow the table). if (tableExistsTablesOnly(ident)) throw new NoSuchViewException(ident) val key = (ident.namespace().toSeq, ident.name()) @@ -1055,13 +1055,13 @@ class MetricViewRecordingCatalog extends InMemoryTableCatalog with TableViewCata } } - // -- TableViewCatalog single-RPC perf path -- + // -- RelationCatalog single-RPC perf path -- - override def loadTableOrView(ident: Identifier): Table = { + override def loadRelation(ident: Identifier): Relation = { val key = (ident.namespace().toSeq, ident.name()) Option(views.get(key)) match { - case Some(info) => new MetadataTable(info, ident.toString) - // Bypass `TableViewCatalog.loadTable` (whose default delegates back to `loadTableOrView`) + case Some(info) => info + // Bypass `RelationCatalog.loadTable` (whose default delegates back to `loadRelation`) // and call `InMemoryTableCatalog.loadTable` directly to avoid infinite recursion. case None => super[InMemoryTableCatalog].loadTable(ident) } @@ -1069,10 +1069,10 @@ class MetricViewRecordingCatalog extends InMemoryTableCatalog with TableViewCata } object MetricViewRecordingCatalog { - // Captures every ViewInfo that flows through createView / replaceView so individual tests + // Captures every View that flows through createView / replaceView so individual tests // can assert on it. Cleared between tests via `reset()`. - val capturedViews: ConcurrentHashMap[Identifier, ViewInfo] = - new ConcurrentHashMap[Identifier, ViewInfo]() + val capturedViews: ConcurrentHashMap[Identifier, View] = + new ConcurrentHashMap[Identifier, View]() def reset(): Unit = capturedViews.clear() } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala index 877e6970368b3..f7afdb5e6e537 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/QueryExecutionSuite.scala @@ -52,6 +52,9 @@ case class QueryExecutionTestRecord( class QueryExecutionSuite extends SharedSparkSession { import testImplicits._ + override protected def sparkConf = + super.sparkConf.set(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key, "0") + def checkDumpedPlans(path: String, expected: Int): Unit = Utils.tryWithResource( Source.fromFile(path)) { source => assert(source.getLines().toList diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala index a83d5c99bb5d1..d70bd71587971 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala @@ -1225,6 +1225,78 @@ class WholeStageCodegenSuite extends SharedSparkSession "CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting non-CSE path)") } + test("SPARK-56032: FilterExec skips CSE codegen when the common subexpression is cheap") { + // A column repeated across conjuncts never becomes a common subexpression -- a bare column is a + // `LeafExpression`, which `EquivalentExpressions` skips, and `splitConjunctivePredicates` feeds + // each conjunct to a separate `addExprTree` call. The realistic cheap-but-recorded case is a + // shared *non-leaf* slot read such as a struct field access: `s.x > 5 AND s.x < 100` shares + // `GetStructField(s, x)`. Caching that gains nothing over the non-CSE path's lazy load, so the + // gate must fall back. (Pre-`isCheap`-gate this took the CSE path, emitting the eager + // prologue.) + val schema = StructType(Seq( + StructField("s", StructType(Seq(StructField("x", IntegerType, nullable = true))), + nullable = true))) + val data = spark.sparkContext.parallelize(Seq( + Row(Row(10)), Row(Row(3)), Row(Row(200)), Row(Row(50)), Row(Row(null)), Row(null))) + val expected = Seq(Row(Row(10)), Row(Row(50))) + + def filterCode(cseEnabled: Boolean): String = { + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val df = spark.createDataFrame(data, schema) + // Both conjuncts share `GetStructField(s, x)`, a cheap non-leaf common subexpression. + val filtered = df.where("s.x > 5 AND s.x < 100") + val plan = filtered.queryExecution.executedPlan + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]), + "Filter should be in whole-stage codegen") + checkAnswer(filtered, expected) + codegenString(plan) + } + } + + def normalize(code: String): String = code.replaceAll("#\\d+", "#") + assert(normalize(filterCode(cseEnabled = true)) == normalize(filterCode(cseEnabled = false)), + "With only a cheap common subexpression, CSE-enabled FilterExec codegen should be " + + "identical to CSE-disabled codegen (i.e. fall back to the lazy, short-circuiting " + + "non-CSE path)") + } + + test("SPARK-56032: FilterExec takes CSE codegen when the common subexpression is non-cheap") { + // The dual of the cheap-subexpression test: when `otherPreds` share a genuinely non-cheap + // computation (`a + b`, whose `isCheap` is false), the gate must take the CSE path so the + // shared result is computed once. Verify the CSE-enabled code differs from CSE-disabled here, + // pinning down that the gate still fires for real repeated computation. + val schema = StructType(Seq( + StructField("a", IntegerType, nullable = true), + StructField("b", IntegerType, nullable = true))) + val data = spark.sparkContext.parallelize(Seq( + Row(1, 5), Row(60, 50), Row(10, 20), Row(0, 0), Row(null, 5))) + val expected = Seq(Row(1, 5), Row(10, 20)) + + def filterCode(cseEnabled: Boolean): String = { + withSQLConf( + SQLConf.SUBEXPRESSION_ELIMINATION_ENABLED.key -> cseEnabled.toString, + SQLConf.WHOLESTAGE_CODEGEN_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "false") { + val df = spark.createDataFrame(data, schema) + // Both conjuncts share `a + b`, a non-cheap common subexpression worth eliminating. + val filtered = df.where("(a + b) > 0 AND (a + b) < 100") + val plan = filtered.queryExecution.executedPlan + assert(plan.exists(_.isInstanceOf[WholeStageCodegenExec]), + "Filter should be in whole-stage codegen") + checkAnswer(filtered, expected) + codegenString(plan) + } + } + + def normalize(code: String): String = code.replaceAll("#\\d+", "#") + assert(normalize(filterCode(cseEnabled = true)) != normalize(filterCode(cseEnabled = false)), + "With a non-cheap common subexpression, CSE-enabled FilterExec codegen should differ from " + + "CSE-disabled codegen (i.e. take the CSE path that computes the shared result once)") + } + test("SPARK-56032: subexpressionElimination.filterExec.enabled gates FilterExec CSE " + "independently of subexpression elimination") { // The conf disables CSE specifically for FilterExec while leaving subexpression elimination diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala index 50322905f29f3..381305abec6a7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/adaptive/AdaptiveQueryExecSuite.scala @@ -69,6 +69,9 @@ class AdaptiveQueryExecSuite setupTestData() + override protected def sparkConf = + super.sparkConf.set(SQLConf.ADAPTIVE_MAX_SHUFFLE_HASH_JOIN_LOCAL_MAP_THRESHOLD.key, "0") + private def runAdaptiveAndVerifyResult(query: String, skipCheckAnswer: Boolean = false): (SparkPlan, SparkPlan) = { var finalPlanCnt = 0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala new file mode 100644 index 0000000000000..161f8ec647a22 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/ConcurrentInMemoryRelationSuite.scala @@ -0,0 +1,325 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.columnar + +import java.io.File +import java.util.concurrent.CountDownLatch + +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually +import org.scalatest.time.{Millis, Seconds, Span} + +import org.apache.spark.{LocalSparkContext, SparkConf, SparkContext, SparkFunSuite} +import org.apache.spark.sql.{DataFrame, Dataset, SparkSession} +import org.apache.spark.sql.columnar.CachedBatch +import org.apache.spark.sql.functions.{lit, when} +import org.apache.spark.util.{ThreadUtils, Utils} + +/** + * Regression test for SPARK-57547: concurrent first-touch of one cold table cache must not let + * duplicate partition computes silently drop rows. + * + * AQE creates a separate `TableCacheQueryStageExec` for every reference to the same cache (table + * cache stages are never reused), and each one submits its own build job over the shared cache RDD. + * A query that references a cached relation several times therefore first-touches the cold cache + * from several jobs at once. Spark has no global cross-executor "compute this partition once" + * barrier, so the same partition can be computed by multiple executors. If the cache decided it was + * "loaded" from a raw task-completion count (the legacy behavior), those duplicate completions + * could push the count to the partition count while a row-producing partition was still being + * computed, falsely marking the cache loaded with rowCount 0 -- which lets AQE propagate an empty + * relation and silently lose rows. + * + * The fix counts the DISTINCT set of materialized partitions instead, so duplicate computes can + * no longer mark the cache loaded early. These tests reproduce the race deterministically: a + * two-stage gate holds the row-producing partition while the empty-output partition's duplicate + * cross-executor completions accumulate. With distinct tracking the cache stays correctly + * not-loaded while a partition is still building, so the consumer observes every row; were the + * loaded check to fall back to a raw task-completion count it would latch the cache as loaded + * with rowCount 0 and let AQE propagate an empty relation, losing rows (which the repro detects + * as a row-count mismatch). A multi-executor `local-cluster` session is required so the duplicate + * computes land on different executors. + */ +class ConcurrentInMemoryRelationSuite extends SparkFunSuite with LocalSparkContext with Eventually { + + private def cacheBuilderOf(ds: Dataset[_]): CachedRDDBuilder = { + val relations = ds.queryExecution.withCachedData.collect { case i: InMemoryRelation => i } + assert(relations.length == 1) + relations.head.cacheBuilder + } + + private def withSession(numExecutors: Int = 4)( + f: SparkSession => Unit): Unit = { + val conf = new SparkConf() + .setMaster(s"local-cluster[$numExecutors,1,1024]") + .setAppName("ConcurrentInMemoryRelationSuite") + sc = new SparkContext(conf) + try { + // Wait for all executors to register so tasks spread one-per-executor as the tests assume. + eventually(timeout(Span(60, Seconds)), interval(Span(200, Millis))) { + assert(sc.getExecutorIds().size == numExecutors) + } + f(SparkSession.builder().sparkContext(sc).getOrCreate()) + } finally { + resetSparkContext() + } + } + + /** + * Drives the actual SPARK-57547 data loss deterministically. + * + * Caches a skewed join with two shuffle partitions: every partition has non-empty INPUT (so + * neither is pruned as an empty task), but only the `skewKey` bucket produces OUTPUT rows -- so + * one partition is row-producing and the other produces zero rows. A two-stage gate blocks every + * partition's build inside `mapPartitions` until released. `numReferences` threads each submit + * their own build job over the shared cache RDD (exactly as per-reference + * `TableCacheQueryStageExec`s do); on `local-cluster[4,1,...]` (= numReferences x cachePartitions + * task slots, one task per executor) the empty-output partition is computed by both references on + * two distinct executors. + * + * Sequence: (1) the threads first-touch the cold cache, gating all `numReferences x + * cachePartitions` tasks; (2) release only the empty-output partition, so its two + * cross-executor completions land while the row-producing partition is still gated; (3) poll + * `isCachedColumnBuffersLoaded` -- distinct-partition tracking keeps it false (a raw + * task-completion count would instead reach cachePartitions here and latch a poisoned + * "loaded" state with rowCount 0); (4) a consumer query (a GROUPED aggregate, where empty + * propagation could collapse the result) sees the cache not-loaded and plans against the + * real rows -- had it been poisoned, AQE would have propagated an empty relation and dropped + * rows; (5) release the producing partition. Returns (rows the consumer observed, expected + * rows), equal unless poisoned. + */ + private def runDataLossRepro(spark: SparkSession): (Long, Long) = { + import spark.implicits._ + val numKeys = 64 + val skewKey = 42 + val rowsPerKey = 500 + val numReferences = 2 + val cachePartitions = 2 // one row-producing, one empty-output (see shuffle.partitions below) + val expected = rowsPerKey.toLong * rowsPerKey.toLong // only the skewKey bucket joins + + // Exactly two shuffle partitions (one row-producing, one empty-output), no broadcast so the + // join shuffles, and build the cache with AQE off so the skewed producing partition is not + // rebalanced away (which would defuse the window). Consumers below run with AQE on so they go + // through TableCacheQueryStageExec + empty-relation propagation. + spark.conf.set("spark.sql.shuffle.partitions", "2") + spark.conf.set("spark.sql.autoBroadcastJoinThreshold", "-1") + spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "false") + spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "false") + spark.conf.set("spark.sql.adaptive.enabled", "false") + + val gateDir = Utils.createTempDir() + def file(name: String) = new File(gateDir, name) + val releaseEmpty = file("releaseEmpty").getAbsolutePath + val releaseProducing = file("releaseProducing").getAbsolutePath + val entryDir = gateDir.getAbsolutePath + + def side(matchSalt: Int, valueCol: String): DataFrame = + spark.range(0, numKeys.toLong * rowsPerKey).select( + ($"id" % numKeys).cast("int").as("k"), + when(($"id" % numKeys) === skewKey, lit(0)).otherwise(lit(matchSalt)).as("salt"), + $"id".as(valueCol)) + + val joined = side(1, "lv") + .join(side(2, "rv"), Seq("k", "salt")) + .select($"k", $"lv").as[(Int, Long)] + + // Two-stage gate: every partition signals it has entered (past the block-existence check) and + // waits for releaseEmpty; the row-producing partition (the only one with rows) waits longer. + val gated = joined.mapPartitions { iter => + val buffered = iter.buffered + val isProducing = buffered.hasNext + file(s"entered-${java.util.UUID.randomUUID()}").createNewFile() + def waitFor(path: String): Unit = { + val deadline = System.currentTimeMillis() + 60000 + while (!new File(path).exists() && System.currentTimeMillis() < deadline) Thread.sleep(50) + } + waitFor(releaseEmpty) + if (isProducing) waitFor(releaseProducing) + buffered + }.toDF("k", "lv") + + val cached = gated.cache() + try { + val builder = cacheBuilderOf(cached) + // Cache plan captured (static, 2 partitions); consumers from here on use AQE. + spark.conf.set("spark.sql.adaptive.enabled", "true") + // Every reference launches its own build job over the shared cache RDD (no dedup at this + // layer), so the empty partition is computed by every reference: numReferences x + // cachePartitions gated tasks. + val expectedEntries = numReferences * cachePartitions + val rdd = builder.cachedColumnBuffers + val submitted = new CountDownLatch(numReferences) + val pool = ThreadUtils.newDaemonFixedThreadPool(numReferences, "spark57547-dataloss") + try { + val firstTouch = (1 to numReferences).map { _ => + pool.submit(new java.util.concurrent.Callable[Unit] { + override def call(): Unit = { + val f = spark.sparkContext.submitJob( + rdd, + (_: Iterator[CachedBatch]) => (), + 0 until rdd.getNumPartitions, + (_: Int, _: Unit) => (), + ()) + submitted.countDown() + ThreadUtils.awaitResult(f, 120.seconds) + } + }) + } + assert(submitted.await(60, java.util.concurrent.TimeUnit.SECONDS)) + // Wait until every build task is parked at the gate (all have passed the block-existence + // check), so releasing the empty partition forces its cross-executor completions to run. + eventually(timeout(Span(60, Seconds)), interval(Span(100, Millis))) { + val entered = new File(entryDir).listFiles().count(_.getName.startsWith("entered-")) + assert(entered == expectedEntries, s"entered=$entered expected=$expectedEntries") + } + + // Stage 1: release ONLY the empty-output partition; the producing partition stays gated. + assert(new File(releaseEmpty).createNewFile()) + + // Were the loaded check to fall back to a raw task-completion count, the empty partition's + // duplicate cross-executor completions would push that count to cachePartitions even though + // the producing partition has not run, latching the cache as "loaded" with rowCount 0. We + // read it through the relation handle -- exactly what AQE's stats reads do in production -- + // and the one-way latch would make the poison permanent (the producing partition is still + // gated when the consumer runs below). With distinct-partition accounting the cache stays + // not loaded here, so this poll times out and we fall through to a normal (complete) build. + val poisoned = + try { + eventually(timeout(Span(30, Seconds)), interval(Span(100, Millis))) { + assert(builder.isCachedColumnBuffersLoaded) + } + true + } catch { + case _: org.scalatest.exceptions.TestFailedException => false + } + + // A GROUPED aggregate (not a global count): AQE empty-relation propagation collapses the + // whole result when the cache stage is (falsely) reported as a zero-row materialized stage; + // a global aggregate over empty would still emit one row and mask the loss. + val observed = if (poisoned) { + // The cache lied (loaded with rowCount 0 while the producing partition is still gated and + // unbuilt). The consumer plans against it and AQE propagates an empty relation, so the + // rows silently vanish. The producing partition stays gated, so this is deterministic. + val consumer = cached.groupBy("k").count() + val rows = consumer.collect() + assert(consumer.queryExecution.executedPlan.toString.contains("EmptyRelation"), + "expected AQE to propagate an empty relation from the poisoned cache stage") + assert(new File(releaseProducing).createNewFile()) // unblock the build for clean shutdown + rows.map(_.getLong(1)).sum + } else { + // The cache is correctly not loaded, so let the producing partition finish and the + // consumer observes every row. + assert(new File(releaseProducing).createNewFile()) + cached.groupBy("k").count().collect().map(_.getLong(1)).sum + } + firstTouch.foreach(_.get(120, java.util.concurrent.TimeUnit.SECONDS)) + (observed, expected) + } finally { + pool.shutdown() + } + } finally { + cached.unpersist(blocking = true) + Utils.deleteRecursively(gateDir) + } + } + + /** + * Builds a cold cache whose partitions all carry rows and first-touches it concurrently from + * `numReferences` jobs with every partition gated, so each partition is computed once per + * reference on a distinct executor (`numReferences` duplicate cross-executor computes per + * partition). Returns (reported materialized row count, expected rows); with distinct-partition + * tracking on, the keyed accumulator de-duplicates the duplicate computes so the count is exact. + */ + private def runDuplicateComputeStats(spark: SparkSession): (Long, Long) = { + import spark.implicits._ + val numReferences = 2 + val cachePartitions = 2 + val numRows = 200L // split evenly across the partitions; every partition is non-empty + + val gateDir = Utils.createTempDir() + def file(name: String) = new File(gateDir, name) + val release = file("release").getAbsolutePath + val entryDir = gateDir.getAbsolutePath + + // Every partition has rows and blocks at the gate until released, so all references' build + // tasks are in flight (past the block-existence check) before any completes -- forcing the + // duplicate cross-executor computes that the per-batch accumulator would over-count. + val cached = spark.range(0, numRows, 1, cachePartitions).as[Long].mapPartitions { iter => + file(s"entered-${java.util.UUID.randomUUID()}").createNewFile() + val deadline = System.currentTimeMillis() + 60000 + while (!new File(release).exists() && System.currentTimeMillis() < deadline) Thread.sleep(50) + iter + }.cache() + try { + val builder = cacheBuilderOf(cached) + val rdd = builder.cachedColumnBuffers + val submitted = new CountDownLatch(numReferences) + val pool = ThreadUtils.newDaemonFixedThreadPool(numReferences, "spark57547-stats") + try { + val futures = (1 to numReferences).map { _ => + pool.submit(new java.util.concurrent.Callable[Unit] { + override def call(): Unit = { + val f = spark.sparkContext.submitJob( + rdd, + (_: Iterator[CachedBatch]) => (), + 0 until rdd.getNumPartitions, + (_: Int, _: Unit) => (), + ()) + submitted.countDown() + ThreadUtils.awaitResult(f, 120.seconds) + } + }) + } + assert(submitted.await(60, java.util.concurrent.TimeUnit.SECONDS)) + // Wait until every reference's task for every partition is parked at the gate, then release + // them so each partition is computed once per reference. + eventually(timeout(Span(60, Seconds)), interval(Span(100, Millis))) { + val entered = new File(entryDir).listFiles().count(_.getName.startsWith("entered-")) + assert(entered == numReferences * cachePartitions, s"entered=$entered") + } + assert(new File(release).createNewFile()) + futures.foreach(_.get(120, java.util.concurrent.TimeUnit.SECONDS)) + assert(builder.isCachedColumnBuffersLoaded) + (builder.materializedRowCount, numRows) + } finally { + pool.shutdown() + } + } finally { + cached.unpersist(blocking = true) + Utils.deleteRecursively(gateDir) + } + } + + test("SPARK-57547: concurrent first-touch of a cold cache does not lose rows") { + withSession() { spark => + val (observed, expected) = runDataLossRepro(spark) + assert(observed == expected, s"consumer observed $observed rows, expected $expected") + } + } + + test("SPARK-57547: cache statistics are exact under duplicate cross-executor computes") { + // Every partition is computed by both references, so the partition-keyed accumulator sees a + // duplicate `add` per partition. Last-write-wins de-duplication keeps the reported row count + // exact -- a naive summing accumulator would over-count under these duplicate computes. + withSession() { spark => + val (rowCount, expected) = runDuplicateComputeStats(spark) + assert(rowCount == expected, + s"partition-keyed accumulator should report exact row count $expected, got $rowCount") + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala index 5cd62302861ae..57da12e87979a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarQuerySuite.scala @@ -361,7 +361,7 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl checkAnswer(cached, expectedAnswer) // Check that the right size was calculated. - assert(cached.cacheBuilder.sizeInBytesStats.value === expectedAnswer.length * INT.defaultSize) + assert(cached.cacheBuilder.materializedSizeInBytes === expectedAnswer.length * INT.defaultSize) } test("cached row count should be calculated") { @@ -375,7 +375,7 @@ class InMemoryColumnarQuerySuite extends SharedSparkSession with AdaptiveSparkPl checkAnswer(cached, expectedAnswer) // Check that the right row count was calculated. - assert(cached.cacheBuilder.rowCountStats.value === 6) + assert(cached.cacheBuilder.materializedRowCount === 6) } test("access primitive-type columns in CachedBatch without whole stage codegen") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterViewSetTblPropertiesSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterViewSetTblPropertiesSuiteBase.scala index b55ee4563c6eb..bf3931d7c1cba 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterViewSetTblPropertiesSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/AlterViewSetTblPropertiesSuiteBase.scala @@ -100,7 +100,7 @@ trait AlterViewSetTblPropertiesSuiteBase extends QueryTest with DDLCommandTestUt test("setting `comment` flows through to SHOW CREATE TABLE") { // v1 `AlterTableSetPropertiesCommand` updates the typed `CatalogTable.comment` field when // the user passes `'comment'` via SET TBLPROPERTIES, so SHOW CREATE TABLE renders the - // comment in the COMMENT clause. The v2 path uses `ViewInfo.properties` as the source of + // comment in the COMMENT clause. The v2 path uses `View.properties` as the source of // truth for `PROP_COMMENT` (see `AlterV2ViewSetPropertiesExec` and `ShowCreateV2ViewExec`), // so the same SET TBLPROPERTIES('comment' = ...) round-trips through SHOW CREATE TABLE. // Pin the cross-catalog parity here. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateViewSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateViewSuiteBase.scala index d046e74b68624..b0e42ed8cbe95 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateViewSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/CreateViewSuiteBase.scala @@ -40,7 +40,7 @@ trait CreateViewSuiteBase extends QueryTest with DDLCommandTestUtils { /** * Seed a non-view table at `qualified` (full `catalog.ns.name`) and run `body`. Same SQL - * for v1 and v2 -- `InMemoryTableViewCatalog.createTable` accepts the parquet TableInfo + * for v1 and v2 -- `InMemoryRelationCatalog.createTable` accepts the parquet TableInfo * the same way the session catalog does, so both legs share this implementation. */ protected final def withSeededTable(qualified: String)(body: => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropViewSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropViewSuiteBase.scala index 92a368e4155a6..9266f2dbca9d2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropViewSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/DropViewSuiteBase.scala @@ -37,7 +37,7 @@ trait DropViewSuiteBase extends QueryTest with DDLCommandTestUtils { /** * Seed a non-view table at `qualified` (full `catalog.ns.name`) and run `body`. Same SQL - * for v1 and v2 -- `InMemoryTableViewCatalog.createTable` accepts the parquet TableInfo + * for v1 and v2 -- `InMemoryRelationCatalog.createTable` accepts the parquet TableInfo * the same way the session catalog does, so both legs share this implementation. */ protected final def withSeededTable(qualified: String)(body: => Unit): Unit = { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala index cd917a817f7f0..6968a75f5eee3 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/PlanResolutionSuite.scala @@ -153,7 +153,7 @@ class PlanResolutionSuite extends SharedSparkSession with AnalysisTest { when(t.comment).thenReturn(None) when(t.collation).thenReturn(None) if (tableType == CatalogTableType.VIEW) { - // Stub the view-only fields that resolution reads through `V1ViewInfo.builderFrom`. + // Stub the view-only fields that resolution reads through `V1View.builderFrom`. // Mockito returns `null` for unstubbed Object methods, which would NPE the moment // builderFrom calls `.getOrElse` / `.asJava` / `.toArray` on a null Option/Seq/Map. when(t.viewText).thenReturn(None) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSchemaBindingSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSchemaBindingSuite.scala index 517880047d256..c9f769d9f5485 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSchemaBindingSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSchemaBindingSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.execution.command class AlterViewSchemaBindingSuite extends command.AlterViewSchemaBindingSuiteBase with ViewCommandSuiteBase { - test("V2: catalog stores the new schema mode on ViewInfo") { + test("V2: catalog stores the new schema mode on View") { val view = s"$catalog.$namespace.v2_schema_mode" createView(view) sql(s"ALTER VIEW $view WITH SCHEMA EVOLUTION") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSetTblPropertiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSetTblPropertiesSuite.scala index 642d8d46b7fac..46499b6b49693 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSetTblPropertiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewSetTblPropertiesSuite.scala @@ -26,7 +26,7 @@ import org.apache.spark.sql.execution.command class AlterViewSetTblPropertiesSuite extends command.AlterViewSetTblPropertiesSuiteBase with ViewCommandSuiteBase { - test("V2: catalog stores the property on ViewInfo") { + test("V2: catalog stores the property on View") { val view = s"$catalog.$namespace.v2_set_view_info" createView(view) sql(s"ALTER VIEW $view SET TBLPROPERTIES ('k' = 'v')") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewUnsetTblPropertiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewUnsetTblPropertiesSuite.scala index 0d7f13007e9f5..52871b1a04128 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewUnsetTblPropertiesSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/AlterViewUnsetTblPropertiesSuite.scala @@ -22,7 +22,7 @@ import org.apache.spark.sql.execution.command class AlterViewUnsetTblPropertiesSuite extends command.AlterViewUnsetTblPropertiesSuiteBase with ViewCommandSuiteBase { - test("V2: unset removes the entry from the stored ViewInfo") { + test("V2: unset removes the entry from the stored View") { val view = s"$catalog.$namespace.v2_unset_view_info" createViewWithProps(view, "k" -> "v") sql(s"ALTER VIEW $view UNSET TBLPROPERTIES ('k')") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateViewSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateViewSuite.scala index 1750a538b1d4c..7ff3344546647 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateViewSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/CreateViewSuite.scala @@ -18,13 +18,13 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.sql.AnalysisException -import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, TableCatalog, ViewInfo} +import org.apache.spark.sql.connector.catalog.{BasicInMemoryTableCatalog, TableCatalog, View} import org.apache.spark.sql.execution.command class CreateViewSuite extends command.CreateViewSuiteBase with ViewCommandSuiteBase { import testImplicits._ - test("V2: CREATE VIEW propagates DEFAULT COLLATION onto the stored ViewInfo") { + test("V2: CREATE VIEW propagates DEFAULT COLLATION onto the stored View") { val view = s"$catalog.$namespace.v2_create_collation" withTable("spark_catalog.default.src_coll") { Seq("a", "b").toDF("col").write.saveAsTable("spark_catalog.default.src_coll") @@ -56,13 +56,13 @@ class CreateViewSuite extends command.CreateViewSuiteBase with ViewCommandSuiteB // The Base version of this scenario asserts the SQL behavior (errors / no-op); // here we additionally pin the v2-only post-condition that the persisted entry under // the colliding identifier remains a `TableInfo` and is NOT silently swapped for a - // `ViewInfo` by the IF NOT EXISTS path. + // `View` by the IF NOT EXISTS path. val name = "v2_ifne_keeps_table" val view = s"$catalog.$namespace.$name" withSeededTable(view) { sql(s"CREATE VIEW IF NOT EXISTS $view AS SELECT 1 AS col") val stored = viewCatalog.getStoredInfo(Array(namespace), name) - assert(!stored.isInstanceOf[ViewInfo]) + assert(!stored.isInstanceOf[View]) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowCreateTableSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowCreateTableSuite.scala index e2ce378f21cc5..bbb26f0510187 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowCreateTableSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ShowCreateTableSuite.scala @@ -17,13 +17,20 @@ package org.apache.spark.sql.execution.command.v2 +import org.apache.spark.SparkConf import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.catalyst.util.FieldMetadataUtils +import org.apache.spark.sql.connector.catalog.InMemoryBaseTable import org.apache.spark.sql.execution.command /** * The class contains tests for the `SHOW CREATE TABLE` command to check V2 table catalogs. */ class ShowCreateTableSuite extends command.ShowCreateTableSuiteBase with CommandSuiteBase { + + override def sparkConf: SparkConf = super.sparkConf + .set(InMemoryBaseTable.ASSIGN_COLUMN_IDS, "true") + override def fullName: String = s"$catalog.$ns.$table" test("SPARK-33898: show create table as serde") { @@ -200,6 +207,20 @@ class ShowCreateTableSuite extends command.ShowCreateTableSuiteBase with Command } } + test("SPARK-57544: show create table does not expose column IDs; schema does") { + withNamespaceAndTable(ns, table) { t => + sql(s"CREATE TABLE $t (id INT, salary INT) $defaultUsing") + + // Column IDs assigned by the catalog must NOT appear in SHOW CREATE TABLE output. + val showDDL = getShowCreateDDL(t) + assert(!showDDL.exists(_.contains(FieldMetadataUtils.FIELD_ID_METADATA_KEY))) + + // Column IDs must be accessible via df.schema. + val fields = spark.table(t).schema.fields + assert(fields.forall(_.id.isDefined)) + } + } + test("show table constraints") { withNamespaceAndTable("ns", "tbl", nonPartitionCatalog) { t => withNamespaceAndTable("ns", "other_table", nonPartitionCatalog) { otherTable => diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ViewCommandSuiteBase.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ViewCommandSuiteBase.scala index 4dbbbc21afa9f..3f67a61bce4a4 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ViewCommandSuiteBase.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/command/v2/ViewCommandSuiteBase.scala @@ -18,21 +18,21 @@ package org.apache.spark.sql.execution.command.v2 import org.apache.spark.SparkConf -import org.apache.spark.sql.connector.catalog.InMemoryTableViewCatalog +import org.apache.spark.sql.connector.catalog.InMemoryRelationCatalog /** * Settings for v2 view command test suites. Extends v2 [[CommandSuiteBase]] (so view tests * inherit `checkLocation` and the standard v2 `test_catalog` configuration), and additionally - * wires `test_view_catalog` to [[InMemoryTableViewCatalog]] -- the catalog that the unified + * wires `test_view_catalog` to [[InMemoryRelationCatalog]] -- the catalog that the unified * `*SuiteBase` view tests under `command/` target via the `$catalog` placeholder. */ trait ViewCommandSuiteBase extends CommandSuiteBase { override def catalog: String = "test_view_catalog" override def sparkConf: SparkConf = super.sparkConf - .set(s"spark.sql.catalog.$catalog", classOf[InMemoryTableViewCatalog].getName) + .set(s"spark.sql.catalog.$catalog", classOf[InMemoryRelationCatalog].getName) - /** Helper: returns the configured `InMemoryTableViewCatalog`. */ - protected def viewCatalog: InMemoryTableViewCatalog = - spark.sessionState.catalogManager.catalog(catalog).asInstanceOf[InMemoryTableViewCatalog] + /** Helper: returns the configured `InMemoryRelationCatalog`. */ + protected def viewCatalog: InMemoryRelationCatalog = + spark.sessionState.catalogManager.catalog(catalog).asInstanceOf[InMemoryRelationCatalog] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala index 21860122244a1..195e7c7fce7ff 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/FileSourceStrategySuite.scala @@ -605,25 +605,27 @@ class FileSourceStrategySuite extends SharedSparkSession { } test(s"SPARK-44021: Test ${SQLConf.FILES_MAX_PARTITION_NUM.key} works as expected") { - val files = - Range(0, 300000).map(p => PartitionedFile(InternalRow.empty, sp(s"$p"), 0, 50000000)) - val maxPartitionBytes = conf.filesMaxPartitionBytes - val defaultPartitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) - assert(defaultPartitions.size === 150000) - - withSQLConf(SQLConf.FILES_MAX_PARTITION_NUM.key -> "20000") { - val partitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) - assert(partitions.size === 20000) - } + withSQLConf(SQLConf.FILES_MAX_PARTITION_BYTES.key -> "128MB") { + val files = + Range(0, 300000).map(p => PartitionedFile(InternalRow.empty, sp(s"$p"), 0, 50000000)) + val maxPartitionBytes = conf.filesMaxPartitionBytes + val defaultPartitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) + assert(defaultPartitions.size === 150000) + + withSQLConf(SQLConf.FILES_MAX_PARTITION_NUM.key -> "20000") { + val partitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) + assert(partitions.size === 20000) + } - withSQLConf(SQLConf.FILES_MAX_PARTITION_NUM.key -> "50000") { - val partitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) - assert(partitions.size === 50000) - } + withSQLConf(SQLConf.FILES_MAX_PARTITION_NUM.key -> "50000") { + val partitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) + assert(partitions.size === 50000) + } - withSQLConf(SQLConf.FILES_MAX_PARTITION_NUM.key -> "200000") { - val partitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) - assert(partitions.size === defaultPartitions.size) + withSQLConf(SQLConf.FILES_MAX_PARTITION_NUM.key -> "200000") { + val partitions = FilePartition.getFilePartitions(spark, files, maxPartitionBytes) + assert(partitions.size === defaultPartitions.size) + } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala index d3723881bfa24..0f03f06b7425a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/jdbc/JdbcUtilsSuite.scala @@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.sql.AnalysisException import org.apache.spark.sql.catalyst.parser.ParseException import org.apache.spark.sql.types._ +import org.apache.spark.util.Utils class JdbcUtilsSuite extends SparkFunSuite { @@ -69,4 +70,44 @@ class JdbcUtilsSuite extends SparkFunSuite { condition = "PARSE_SYNTAX_ERROR", parameters = Map("error" -> "'.'", "hint" -> "")) } + + test("redactUrl keeps only the jdbc:: prefix") { + val redaction = Utils.REDACTION_REPLACEMENT_TEXT + + // Only the "jdbc::" prefix is kept; everything after it (host, port, database, + // userinfo and connection properties) is redacted, regardless of the driver-specific syntax. + // This covers credentials embedded as a "//user:pwd@host" authority, ... + assert(JDBCOptions.redactUrl("jdbc:mysql://user:secret@host:3306/db", None) === + s"jdbc:mysql:$redaction") + assert(JDBCOptions.redactUrl("jdbc:mysql://user:p@ss@host:3306/db?password=other", None) === + s"jdbc:mysql:$redaction") + // ... as Oracle Thin's "user/pwd@host" form (no "//" authority), ... + assert(JDBCOptions.redactUrl("jdbc:oracle:thin:scott/tiger@host:1521/svc", None) === + s"jdbc:oracle:$redaction") + assert(JDBCOptions.redactUrl("jdbc:oracle:thin:scott/tiger@//host:1521/svc?x=1", None) === + s"jdbc:oracle:$redaction") + // ... and as "?"- or ";"-delimited connection properties. + assert(JDBCOptions.redactUrl( + "jdbc:postgresql://host/db?user=alice&password=secret", None) === + s"jdbc:postgresql:$redaction") + assert(JDBCOptions.redactUrl( + "jdbc:sqlserver://localhost:1433;databaseName=testdb;password=secret", None) === + s"jdbc:sqlserver:$redaction") + + // Even URLs that carry no credentials are reduced to the prefix -- nothing past the + // subprotocol is assumed safe. + assert(JDBCOptions.redactUrl("jdbc:mysql://localhost/db", None) === s"jdbc:mysql:$redaction") + assert(JDBCOptions.redactUrl("jdbc:h2:mem:testdb", None) === s"jdbc:h2:$redaction") + + // A URL with no subname delimiter (no second colon) is redacted wholesale. + assert(JDBCOptions.redactUrl("jdbc:weird-url", None) === redaction) + + // The user-configured regex is still applied on top of the kept prefix. + assert(JDBCOptions.redactUrl("jdbc:mysql://host/db", Some("mysql".r)) === + s"jdbc:$redaction:$redaction") + + // Null and empty inputs are passed through. + assert(JDBCOptions.redactUrl(null, None) === null) + assert(JDBCOptions.redactUrl("", None) === "") + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionReaderFactorySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionReaderFactorySuite.scala new file mode 100644 index 0000000000000..ae161bb73f10d --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/orc/OrcPartitionReaderFactorySuite.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.datasources.orc + +import org.apache.spark.DebugFilesystem +import org.apache.spark.memory.MemoryMode +import org.apache.spark.paths.SparkPath +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.datasources.PartitionedFile +import org.apache.spark.sql.execution.datasources.v2.orc.OrcPartitionReaderFactory +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sources.EqualTo +import org.apache.spark.sql.test.SharedSparkSession +import org.apache.spark.sql.types._ +import org.apache.spark.util.SerializableConfiguration + +class OrcPartitionReaderFactorySuite extends OrcTest with SharedSparkSession { + + import testImplicits._ + + test("SPARK-57529: Fix possible ORC reader leak in OrcPartitionReaderFactory") { + withTempPath { dir => + val dataSchema = StructType(Array(StructField("value", StringType))) + spark.range(10) + .select($"id".cast(StringType).as("value")) + .write.orc(dir.getCanonicalPath) + + val orcFile = dir.listFiles(_.getName.endsWith(".orc")).headOption + .getOrElse(fail("No ORC file written")) + + withSQLConf(SQLConf.ORC_FILTER_PUSHDOWN_ENABLED.key -> "true") { + val sqlConf = spark.sessionState.conf + val hadoopConf = spark.sessionState.newHadoopConf() + // Route file I/O through DebugFilesystem so we can assert no streams are leaked. + hadoopConf.set("fs.file.impl", classOf[DebugFilesystem].getName) + hadoopConf.set("fs.file.impl.disable.cache", "true") + val broadcastedConf = + spark.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) + + val factory = OrcPartitionReaderFactory( + sqlConf = sqlConf, + broadcastedConf = broadcastedConf, + dataSchema = dataSchema, + readDataSchema = dataSchema, + partitionSchema = StructType(Seq.empty), + // Integer literal on a STRING column triggers IllegalArgumentException in + // OrcFilters.createFilter -> buildLeafSearchArgument + filters = Array(EqualTo("value", 1)), + aggregation = None, + options = new OrcOptions(Map.empty[String, String], sqlConf), + memoryMode = MemoryMode.ON_HEAP) + + val partFile = PartitionedFile( + partitionValues = InternalRow.empty, + filePath = SparkPath.fromPathString(orcFile.getAbsolutePath), + start = 0, + length = orcFile.length()) + + DebugFilesystem.clearOpenStreams() + intercept[IllegalArgumentException] { + factory.buildReader(partFile) + } + DebugFilesystem.assertNoOpenStreams() + } + } + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala index d7f28b79acff4..3b0279b5e6f30 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/v2/state/StateDataSourceTransformWithStateSuite.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.functions.{col, explode, timestamp_seconds} import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.streaming.{InputMapRow, ListState, MapInputEvent, MapOutputEvent, MapStateTTLProcessor, MaxEventTimeStatefulProcessor, OutputMode, RunningCountStatefulProcessor, RunningCountStatefulProcessorWithProcTimeTimerUpdates, StatefulProcessor, StateStoreMetricsTest, TestMapStateProcessor, TimeMode, TimerValues, TransformWithStateSuiteUtils, Trigger, TTLConfig, ValueState} import org.apache.spark.sql.streaming.util.StreamManualClock -import org.apache.spark.tags.SlowSQLTest +import org.apache.spark.tags.{ExtendedSQLTest, SlowSQLTest} import org.apache.spark.util.Utils /** Stateful processor of single value state var with non-primitive type */ @@ -1173,6 +1173,7 @@ class StateDataSourceTransformWithStateSuite extends StateStoreMetricsTest } } +@ExtendedSQLTest class StateDataSourceTransformWithStateSuiteCheckpointV2 extends StateDataSourceTransformWithStateSuite { @@ -1185,6 +1186,6 @@ class StateDataSourceTransformWithStateSuiteCheckpointV2 extends /** * Test suite that runs all StateDataSourceTransformWithStateSuite tests with row checksum enabled. */ -@SlowSQLTest +@ExtendedSQLTest class StateDataSourceTransformWithStateSuiteWithRowChecksum extends StateDataSourceTransformWithStateSuite with EnableStateStoreRowChecksum diff --git a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala index ea730abf67d66..29e3e0e6fb68a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/jdbc/JDBCSuite.scala @@ -18,7 +18,7 @@ package org.apache.spark.sql.jdbc import java.math.BigDecimal -import java.sql.{Date, DriverManager, Timestamp} +import java.sql.{Connection, Date, DriverManager, ResultSet, Statement, Timestamp} import java.time.{Instant, LocalDate, LocalDateTime} import java.time.format.DateTimeFormatter import java.util.{Calendar, GregorianCalendar, Properties, TimeZone} @@ -26,6 +26,7 @@ import java.util.{Calendar, GregorianCalendar, Properties, TimeZone} import scala.jdk.CollectionConverters._ import scala.util.Random +import org.mockito.ArgumentCaptor import org.mockito.ArgumentMatchers._ import org.mockito.Mockito._ @@ -35,7 +36,8 @@ import org.apache.spark.sql.catalyst.{analysis, TableIdentifier} import org.apache.spark.sql.catalyst.parser.CatalystSqlParser import org.apache.spark.sql.catalyst.plans.logical.ShowCreateTable import org.apache.spark.sql.catalyst.util.{CaseInsensitiveMap, CharVarcharUtils, DateTimeTestUtils} -import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference} +import org.apache.spark.sql.connector.catalog.Identifier +import org.apache.spark.sql.connector.expressions.{Expression => V2Expression, FieldReference, GeneralScalarExpression, LiteralValue} import org.apache.spark.sql.connector.expressions.filter.{AlwaysFalse, AlwaysTrue, Predicate} import org.apache.spark.sql.execution.{DataSourceScanExec, ExtendedMode, ProjectExec} import org.apache.spark.sql.execution.command.{ExplainCommand, ShowCreateTableCommand} @@ -837,6 +839,31 @@ class JDBCSuite extends SharedSparkSession { assert(JdbcDialects.get("test.invalid") === NoopDialect) } + test("SPARK-57447: (H2|MySQL|Postgres)Dialect escape a single quote in indexExists") { + // indexExists builds a lookup query with the index name as a SQL string literal, so a single + // quote in the name must be escaped to keep the WHERE clause well-formed. + Seq( + "jdbc:h2:mem:testdb0" -> "INDEX_NAME = 'i''1'", + "jdbc:mysql://127.0.0.1/db" -> "key_name = 'i''1'", + "jdbc:postgresql://127.0.0.1/db" -> "indexname = 'i''1'" + ).foreach { case (jdbcUrl, expectedClause) => + val dialect = JdbcDialects.get(jdbcUrl) + val conn = mock(classOf[Connection]) + val stmt = mock(classOf[Statement]) + val rs = mock(classOf[ResultSet]) + when(conn.createStatement()).thenReturn(stmt) + when(stmt.executeQuery(anyString())).thenReturn(rs) + + val options = new JDBCOptions(jdbcUrl, "test.people", Map.empty[String, String]) + dialect.indexExists(conn, "i'1", Identifier.of(Array("test"), "people"), options) + + val sqlCaptor = ArgumentCaptor.forClass(classOf[String]) + verify(stmt).executeQuery(sqlCaptor.capture()) + assert(sqlCaptor.getValue.contains(expectedClause), + s"Unexpected lookup SQL for $jdbcUrl: ${sqlCaptor.getValue}") + } + } + test("quote column names by jdbc dialect") { val mySQLDialect = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") val postgresDialect = JdbcDialects.get("jdbc:postgresql://127.0.0.1/db") @@ -943,6 +970,19 @@ class JDBCSuite extends SharedSparkSession { assert(mySQLSQL(StringStartsWith("c", "a%b_")) === """`c` LIKE 'a\\%b\\_%' ESCAPE '\\'""") } + test("SPARK-57446: escape single quotes in JDBC comment queries") { + val defaultDialect = JdbcDialects.get("jdbc:") + assert(defaultDialect.getTableCommentQuery("t", "a'b") === + "COMMENT ON TABLE t IS 'a''b'") + assert(defaultDialect.getSchemaCommentQuery("s", "a'b") === + """COMMENT ON SCHEMA "s" IS 'a''b'""") + + // MySQL overrides getTableCommentQuery with its own ALTER TABLE syntax. + val mySQLDialect = JdbcDialects.get("jdbc:mysql://127.0.0.1/db") + assert(mySQLDialect.getTableCommentQuery("t", "a'b") === + "ALTER TABLE t COMMENT = 'a''b'") + } + test("Dialect unregister") { JdbcDialects.unregisterDialect(H2Dialect()) try { @@ -1463,6 +1503,50 @@ class JDBCSuite extends SharedSparkSession { assert(getJdbcType(oracleDialect, TimestampNTZType) == "TIMESTAMP") } + test("Oracle TRUNC pushdown should map Spark format strings to Oracle format") { + val oracleDialect = JdbcDialects.get("jdbc:oracle://127.0.0.1/db") + val dateRef = FieldReference("d") + + // LiteralValue for StringType must use UTF8String (Spark's internal string type) + // to match what V2ExpressionBuilder produces in the real pushdown path. + import org.apache.spark.unsafe.types.UTF8String + def truncExpr(fmt: String): GeneralScalarExpression = new GeneralScalarExpression("TRUNC", + Array[V2Expression](dateRef, LiteralValue(UTF8String.fromString(fmt), StringType))) + + val monthSql = oracleDialect.compileExpression(truncExpr("MONTH")).get + assert(monthSql.contains("'MM'"), + s"trunc(d, 'MONTH') should produce Oracle 'MM', got: $monthSql") + assert(!monthSql.contains("'IW'"), + s"trunc(d, 'MONTH') should NOT produce 'IW', got: $monthSql") + + val weekSql = oracleDialect.compileExpression(truncExpr("WEEK")).get + assert(weekSql.contains("'IW'"), + s"trunc(d, 'WEEK') should produce Oracle 'IW', got: $weekSql") + + val yearSql = oracleDialect.compileExpression(truncExpr("YEAR")).get + assert(yearSql.contains("'YYYY'"), + s"trunc(d, 'YEAR') should produce Oracle 'YYYY', got: $yearSql") + + val quarterSql = oracleDialect.compileExpression(truncExpr("QUARTER")).get + assert(quarterSql.contains("'Q'"), + s"trunc(d, 'QUARTER') should produce Oracle 'Q', got: $quarterSql") + + // Case-insensitive: lowercase formats must also map correctly + val weekLowerSql = oracleDialect.compileExpression(truncExpr("week")).get + assert(weekLowerSql.contains("'IW'"), + s"trunc(d, 'week') (lowercase) should produce Oracle 'IW', got: $weekLowerSql") + + // Unmapped formats should NOT be pushed down (compileExpression returns None) + assert(oracleDialect.compileExpression(truncExpr("DAY")).isEmpty, + "Unmapped format 'DAY' should not be pushed down (compileExpression should return None)") + + // Alias formats (MM, MON, YYYY, YY) should also map correctly + val mmSql = oracleDialect.compileExpression(truncExpr("MM")).get + assert(mmSql.contains("'MM'"), s"trunc(d, 'MM') should produce Oracle 'MM', got: $mmSql") + val yySql = oracleDialect.compileExpression(truncExpr("YY")).get + assert(yySql.contains("'YYYY'"), s"trunc(d, 'YY') should produce Oracle 'YYYY', got: $yySql") + } + private def assertEmptyQuery(sqlString: String): Unit = { assert(sql(sqlString).collect().isEmpty) } @@ -2059,7 +2143,8 @@ class JDBCSuite extends SharedSparkSession { spark.read.format("jdbc").options(opts).load() }, condition = "FAILED_JDBC.CONNECTION", - parameters = Map("url" -> url) + // getRedactUrl() keeps only the "jdbc::" prefix and redacts the rest. + parameters = Map("url" -> s"jdbc:mysql:${Utils.REDACTION_REPLACEMENT_TEXT}") ) } @@ -2412,7 +2497,8 @@ class JDBCSuite extends SharedSparkSession { } }, condition = "FAILED_JDBC.CONNECTION", - parameters = Map("url" -> url) + // getRedactUrl() keeps only the "jdbc::" prefix and redacts the rest. + parameters = Map("url" -> s"$connectionUrl:${Utils.REDACTION_REPLACEMENT_TEXT}") ) } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala index c0786c7384465..8d2418c5e0521 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/sources/InsertSuite.scala @@ -3089,6 +3089,37 @@ class InsertSuite extends DataSourceTest with SharedSparkSession { } } } + + test("SPARK-56919: INSERT OVERWRITE should not lose table path when AQE fails") { + withSQLConf( + SQLConf.ADAPTIVE_EXECUTION_ENABLED.key -> "true", + SQLConf.ADAPTIVE_EXECUTION_FORCE_APPLY.key -> "true", + SQLConf.PLANNED_WRITE_ENABLED.key -> "false") { + withTempPath { path => + val tablePath = path.getAbsolutePath + spark.range(10).toDF("id") + .write.mode("overwrite").parquet(tablePath) + assert(new java.io.File(tablePath).exists()) + + spark.udf.register("fail_udf", (i: Long) => { + throw new RuntimeException("SPARK-56919") + i + }) + + // The repartition forces a shuffle stage. With planned write disabled, + // materializeAdaptiveSparkPlan runs the stage, which fails via the UDF. + intercept[Exception] { + spark.sql(s"SELECT fail_udf(id) as id FROM parquet.`$tablePath`") + .repartition(2) + .write.mode("overwrite").parquet(tablePath) + } + + // The table path must survive a failed overwrite. + assert(new java.io.File(tablePath).exists(), + "Table path should not be permanently lost after a failed INSERT OVERWRITE") + } + } + } } class FileExistingTestFileSystem extends RawLocalFileSystem { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala new file mode 100644 index 0000000000000..19e499942e310 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/util/PartitionKeyedAccumulatorSuite.scala @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.util + +import org.apache.spark.SparkFunSuite + +class PartitionKeyedAccumulatorSuite extends SparkFunSuite { + + // The cache use case records (rowCount, sizeInBytes) per partition. + private type Stats = (Long, Long) + + private def sumRows(acc: PartitionKeyedAccumulator[Stats]): Long = + acc.foldValues(0L)((sum, v) => sum + v._1) + + private def sumBytes(acc: PartitionKeyedAccumulator[Stats]): Long = + acc.foldValues(0L)((sum, v) => sum + v._2) + + test("isZero, add, value and accumulatedNumPartitions") { + val acc = new PartitionKeyedAccumulator[Stats] + assert(acc.isZero) + assert(acc.accumulatedNumPartitions == 0) + assert(acc.value.isEmpty) + + acc.add((0, (10L, 100L))) + assert(!acc.isZero) + assert(acc.accumulatedNumPartitions == 1) + assert(acc.value.get(0) == ((10L, 100L))) + + acc.add((1, (5L, 50L))) + assert(acc.accumulatedNumPartitions == 2) + assert(sumRows(acc) == 15L) + assert(sumBytes(acc) == 150L) + } + + test("add is last-write-wins for the same partition id") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (1L, 1L))) + acc.add((0, (2L, 2L))) // re-records partition 0 (e.g. a recompute) + assert(acc.accumulatedNumPartitions == 1) + assert(sumRows(acc) == 2L) // the later value wins, not 1 + 2 + assert(sumBytes(acc) == 2L) + } + + test("merge is last-write-wins per partition id (de-duplicates, does not sum)") { + // Two references compute the same partitions; partition 0 is computed by both. + val a = new PartitionKeyedAccumulator[Stats] + a.add((0, (10L, 100L))) + + val b = new PartitionKeyedAccumulator[Stats] + b.add((0, (10L, 100L))) // duplicate compute of partition 0 + b.add((1, (5L, 50L))) + + a.merge(b) + assert(a.accumulatedNumPartitions == 2) // partitions {0, 1}, not 3 + assert(sumRows(a) == 15L) // 10 (partition 0, counted once) + 5, NOT 25 + assert(sumBytes(a) == 150L) + } + + test("copy is an independent snapshot") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (10L, 100L))) + val snapshot = acc.copy() + acc.add((1, (5L, 50L))) // mutate the original after copying + + assert(snapshot.accumulatedNumPartitions == 1) + assert(sumRows(snapshot) == 10L) + assert(acc.accumulatedNumPartitions == 2) + assert(sumRows(acc) == 15L) + } + + test("reset and copyAndReset") { + val acc = new PartitionKeyedAccumulator[Stats] + acc.add((0, (10L, 100L))) + assert(!acc.isZero) + + assert(acc.copyAndReset().isZero) + assert(!acc.isZero) // copyAndReset does not mutate the source + + acc.reset() + assert(acc.isZero) + assert(acc.accumulatedNumPartitions == 0) + } + + test("works for an arbitrary value type") { + val acc = new PartitionKeyedAccumulator[String] + acc.add((0, "a")) + acc.add((1, "b")) + acc.add((0, "c")) // last-write-wins + assert(acc.accumulatedNumPartitions == 2) + assert(acc.foldValues("")((s, v) => s + v).length == 2) // "c" + "b" (each partition once) + } +} diff --git a/sql/hive-thriftserver/pom.xml b/sql/hive-thriftserver/pom.xml index 7fb1e6ca6bed1..b07ccc60574a9 100644 --- a/sql/hive-thriftserver/pom.xml +++ b/sql/hive-thriftserver/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java index 35c54388a0555..86501504b6368 100644 --- a/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java +++ b/sql/hive-thriftserver/src/main/java/org/apache/hive/service/cli/thrift/ThriftHttpServlet.java @@ -101,9 +101,10 @@ public ThriftHttpServlet(TProcessor processor, TProtocolFactory protocolFactory, // Initialize the cookie based authentication related variables. if (isCookieAuthEnabled) { // Generate the signer with secret. - String secret = Long.toString(RAN.nextLong()); - LOG.debug("Using the random number as the secret for cookie generation " + secret); - this.signer = new CookieSigner(secret.getBytes()); + byte[] secret = new byte[32]; + RAN.nextBytes(secret); + LOG.debug("Using the random bytes as the secret for cookie generation"); + this.signer = new CookieSigner(secret); this.cookieMaxAge = (int) hiveConf.getTimeVar( ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_MAX_AGE, TimeUnit.SECONDS); this.cookieDomain = hiveConf.getVar(ConfVars.HIVE_SERVER2_THRIFT_HTTP_COOKIE_DOMAIN); diff --git a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala index 46302b316b757..591c2727c498e 100644 --- a/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala +++ b/sql/hive-thriftserver/src/main/scala/org/apache/spark/sql/hive/thriftserver/SparkExecuteStatementOperation.scala @@ -33,11 +33,11 @@ import org.apache.hive.service.rpc.thrift.{TCLIServiceConstants, TColumnDesc, TP import org.apache.spark.internal.{Logging, LogKeys} import org.apache.spark.internal.LogKeys._ import org.apache.spark.sql.{DataFrame, Row, SparkSession} +import org.apache.spark.sql.catalyst.types.ops.TypeApiOps import org.apache.spark.sql.catalyst.util.CharVarcharUtils import org.apache.spark.sql.catalyst.util.DateTimeConstants.MILLIS_PER_SECOND import org.apache.spark.sql.internal.{SQLConf, VariableSubstitution} import org.apache.spark.sql.types._ -import org.apache.spark.sql.types.ops.TypeApiOps import org.apache.spark.util.{Utils => SparkUtils} private[hive] class SparkExecuteStatementOperation( diff --git a/sql/hive/pom.xml b/sql/hive/pom.xml index a52842ab52a81..53b66493b1900 100644 --- a/sql/hive/pom.xml +++ b/sql/hive/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala index a459ef329755e..9f406cd3662e2 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/client/HiveClientSuite.scala @@ -37,8 +37,10 @@ import org.apache.spark.sql.connector.catalog.TableCatalog import org.apache.spark.sql.hive.HiveExternalCatalog import org.apache.spark.sql.hive.test.TestHiveVersion import org.apache.spark.sql.types.{IntegerType, StructType} +import org.apache.spark.tags.SlowHiveTest import org.apache.spark.util.{MutableURLClassLoader, Utils} +@SlowHiveTest class HiveClientSuite(version: String) extends HiveVersionSuite(version) { private var versionSpark: TestHiveVersion = null diff --git a/sql/pipelines/pom.xml b/sql/pipelines/pom.xml index e3dc230db89e5..3f7a62634d90e 100644 --- a/sql/pipelines/pom.xml +++ b/sql/pipelines/pom.xml @@ -22,7 +22,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../pom.xml spark-pipelines_2.13 diff --git a/streaming/pom.xml b/streaming/pom.xml index 6ca2ecb302f0d..825848589d3bb 100644 --- a/streaming/pom.xml +++ b/streaming/pom.xml @@ -21,7 +21,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/tools/pom.xml b/tools/pom.xml index 7630f7875ed21..da75410f58aac 100644 --- a/tools/pom.xml +++ b/tools/pom.xml @@ -20,7 +20,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../pom.xml diff --git a/udf/worker/core/pom.xml b/udf/worker/core/pom.xml index f09e8b722ec46..1a5bbff112c38 100644 --- a/udf/worker/core/pom.xml +++ b/udf/worker/core/pom.xml @@ -24,7 +24,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../pom.xml diff --git a/udf/worker/proto/pom.xml b/udf/worker/proto/pom.xml index 894dcce9fb55d..dbfbbc5d8c0d4 100644 --- a/udf/worker/proto/pom.xml +++ b/udf/worker/proto/pom.xml @@ -24,7 +24,7 @@ org.apache.spark spark-parent_2.13 - 4.2.0.1-4.3.0-1 + 4.2.0.1-4.3.0-2 ../../../pom.xml