From 5e059bae2638d8937566132302f45525c099b77f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Sun, 24 May 2026 10:52:30 +0000 Subject: [PATCH 01/12] Add classic.SparkSessionProvider --- .../sql/classic/SparkSessionProvider.scala | 27 +++++++++++++++++++ 1 file changed, 27 insertions(+) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala new file mode 100644 index 0000000000000..e459250f2d3f4 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala @@ -0,0 +1,27 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +/** + * A common trait for test suites that require a classic [[SparkSession]]. 
+ */ +trait SparkSessionProvider extends sql.SparkSessionProvider { + override protected def spark: SparkSession +} From be85c952e263747b047305923a8c3fe14d291095 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Thu, 28 May 2026 18:35:32 +0000 Subject: [PATCH 02/12] Move test.SharedSparkSessionBase functionality to sql.SharedSparkSession --- .../apache/spark/sql/SharedSparkSession.scala | 147 ++++++++++++++++++ .../spark/sql/test/SharedSparkSession.scala | 126 +-------------- 2 files changed, 151 insertions(+), 122 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala new file mode 100644 index 0000000000000..2d34e6829ddd6 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala @@ -0,0 +1,147 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf, SparkFunSuite} +import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.TestSparkSession + +trait SharedSparkSession + extends SparkFunSuite + with SparkSessionProvider + with BeforeAndAfterEach + with Eventually { + + protected def sparkConf = { + val conf = new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) + // Disable ConvertToLocalRelation for better test coverage. Test cases built on + // LocalRelation will exercise the optimization rules better by disabling it as + // this rule may potentially block testing of other optimization rules such as + // ConstantPropagation etc. 
+ .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) + conf.set( + StaticSQLConf.WAREHOUSE_PATH, + conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) + conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false) + conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, + sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD", + StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) + conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, + sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD", + StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: classic.SparkSession = null + + protected override def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: classic.SparkSession = { + classic.SparkSession.cleanupAnyExistingSession() + new TestSparkSession(sparkConf) + } + + protected def sqlConf: SQLConf = _spark.sessionState.conf + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. 
+ */ + protected def initializeSession(): Unit = { + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + _spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds), interval(2.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index fb26d3311ebef..eccadab760665 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -19,15 +19,8 @@ package org.apache.spark.sql.test import scala.concurrent.duration._ -import org.scalatest.{BeforeAndAfterEach, Suite} -import org.scalatest.concurrent.Eventually - -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK +import org.apache.spark.sql import 
org.apache.spark.sql.{classic, QueryTest, QueryTestBase, SparkSession, SparkSessionProvider, SQLContext} -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode -import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { @@ -87,119 +80,8 @@ trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { */ trait SharedSparkSessionBase extends QueryTestBase - with SparkSessionProvider - with BeforeAndAfterEach - with Eventually { self: Suite => - - protected def sparkConf = { - val conf = new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) - // Disable ConvertToLocalRelation for better test coverage. Test cases built on - // LocalRelation will exercise the optimization rules better by disabling it as - // this rule may potentially block testing of other optimization rules such as - // ConstantPropagation etc. 
- .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) - conf.set( - StaticSQLConf.WAREHOUSE_PATH, - conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) - conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false) - conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected override def spark: classic.SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. - */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - classic.SparkSession.cleanupAnyExistingSession() - new TestSparkSession(sparkConf) - } - - protected def sqlConf: SQLConf = _spark.sessionState.conf - - /** - * Initialize the [[TestSparkSession]]. Generally, this is just called from - * beforeAll; however, in test using styles other than FunSuite, there is - * often code that relies on the session between test group constructs and - * the actual tests, which may need this session. It is purely a semantic - * difference, but semantically, it makes more sense to call - * 'initializeSession' between a 'describe' and an 'it' call than it does to - * call 'beforeAll'. 
- */ - protected def initializeSession(): Unit = { - if (_spark == null) { - _spark = createSparkSession - } - } + with sql.SharedSparkSession { - /** - * Make sure the [[TestSparkSession]] is initialized before any tests are run. - */ - protected override def beforeAll(): Unit = { - initializeSession() - - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. - */ - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - try { - if (_spark != null) { - try { - _spark.sessionState.catalog.reset() - } finally { - _spark.stop() - _spark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } - - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds), interval(2.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } + protected override def spark: classic.SparkSession = + super.spark.asInstanceOf[classic.SparkSession] } From 73a8535de0e902a20921c156cd85a46220b9e079 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Thu, 28 May 2026 18:49:56 +0000 Subject: [PATCH 03/12] [API CHANGE]: Move doThreadPreAudit, doThreadPostAudit to sql.SharedSparkSession This is technically an 'api change' as it moves the thread audit stuff from `test.SharedSparkSession` to `test.SharedSparkSessionBase`. This breaks code that implements `SharedSparkSessionBase` to circumvent the thread audit stuff. 
--- .../apache/spark/sql/SharedSparkSession.scala | 20 ++++++++++++++-- .../spark/sql/test/SharedSparkSession.scala | 23 ------------------- 2 files changed, 18 insertions(+), 25 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala index 2d34e6829ddd6..8cab22ac97182 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala @@ -95,10 +95,22 @@ trait SharedSparkSession } } + /** + * Suites extending [[SharedSparkSession]] are sharing resources (e.g. SparkSession) in their + * tests. That trait initializes the spark session in its [[beforeAll()]] implementation before + * the automatic thread snapshot is performed, so the audit code could fail to report threads + * leaked by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. + */ + override protected val enableAutoThreadAudit = false + /** * Make sure the [[TestSparkSession]] is initialized before any tests are run. 
*/ protected override def beforeAll(): Unit = { + doThreadPreAudit() + initializeSession() // Ensure we have initialized the context before calling parent code @@ -122,8 +134,12 @@ trait SharedSparkSession } } } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() + try { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } finally { + doThreadPostAudit() + } } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index eccadab760665..283ef23bfba16 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -24,29 +24,6 @@ import org.apache.spark.sql.{classic, QueryTest, QueryTestBase, SparkSession, Sp trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { - /** - * Suites extending [[SharedSparkSession]] are sharing resources (e.g. SparkSession) in their - * tests. That trait initializes the spark session in its [[beforeAll()]] implementation before - * the automatic thread snapshot is performed, so the audit code could fail to report threads - * leaked by that shared session. - * - * The behavior is overridden here to take the snapshot before the spark session is initialized. - */ - override protected val enableAutoThreadAudit = false - - protected override def beforeAll(): Unit = { - doThreadPreAudit() - super.beforeAll() - } - - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - doThreadPostAudit() - } - } - // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. 
def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = { From c361709f75136d0752303a96d2b643d9284df909 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Thu, 28 May 2026 19:13:25 +0000 Subject: [PATCH 04/12] Rename sql.SharedSparkSession to sql.SparkSessionBinder to prevent shadowing --- .../sql/{SharedSparkSession.scala => SparkSessionBinder.scala} | 2 +- .../scala/org/apache/spark/sql/test/SharedSparkSession.scala | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) rename sql/core/src/test/scala/org/apache/spark/sql/{SharedSparkSession.scala => SparkSessionBinder.scala} (99%) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala similarity index 99% rename from sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala rename to sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala index 8cab22ac97182..f1f70e9b2b109 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala @@ -29,7 +29,7 @@ import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.TestSparkSession -trait SharedSparkSession +trait SparkSessionBinder extends SparkFunSuite with SparkSessionProvider with BeforeAndAfterEach diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 283ef23bfba16..65bdd029aa65b 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -57,7 +57,7 @@ trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { */ trait SharedSparkSessionBase extends 
QueryTestBase - with sql.SharedSparkSession { + with sql.SparkSessionBinder { protected override def spark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession] From 032a6847f2719fb21986078bdd7fa1c1ccb7745a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Thu, 28 May 2026 19:31:39 +0000 Subject: [PATCH 05/12] Deprecate test.SharedSparkSession --- .../scala/org/apache/spark/sql/test/SharedSparkSession.scala | 2 ++ 1 file changed, 2 insertions(+) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 65bdd029aa65b..ad764d99c059f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -22,6 +22,7 @@ import scala.concurrent.duration._ import org.apache.spark.sql import org.apache.spark.sql.{classic, QueryTest, QueryTestBase, SparkSession, SparkSessionProvider, SQLContext} +@deprecated("Use SparkSessionBinder and QueryTest instead") trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that @@ -55,6 +56,7 @@ trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. 
*/ +@deprecated("Use SparkSessionBinder and QueryTest instead") trait SharedSparkSessionBase extends QueryTestBase with sql.SparkSessionBinder { From c9ea7547df660fdf0823c0ee6cddbcfc1ac167d9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Mon, 25 May 2026 16:44:56 +0000 Subject: [PATCH 06/12] Add connect.SparkSession{Provider,Binder}, connect.QueryTest and demo --- .../apache/spark/sql/connect/QueryTest.scala | 38 +++++++ .../connect/QueryTestWithConnectSuite.scala | 32 ++++++ .../sql/connect/SparkSessionBinder.scala | 98 +++++++++++++++++++ .../sql/connect/SparkSessionProvider.scala | 28 ++++++ .../org/apache/spark/sql/QueryTest.scala | 2 +- 5 files changed, 197 insertions(+), 1 deletion(-) create mode 100644 sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala create mode 100644 sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala create mode 100644 sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala create mode 100644 sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala new file mode 100644 index 0000000000000..e107eb01f5700 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala @@ -0,0 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import org.apache.spark.{sql => sqlApi} + +/** + * Extends [[sqlApi.QueryTest]] for use with Connect sessions. + * + * Overrides [[checkAnswer]] to avoid classic-only code paths (e.g. `queryExecution`, + * `logicalPlan`, `materializedRdd`) that are not available on Connect DataFrames. + */ +trait QueryTest extends sqlApi.QueryTest with SparkSessionProvider { + + override protected def checkAnswer( + df: => sqlApi.DataFrame, expectedAnswer: Seq[sqlApi.Row]): Unit = { + val sparkAnswer = df.collect().toSeq + sqlApi.QueryTest.sameRows(expectedAnswer, sparkAnswer) match { + case Some(errorMessage) => fail(errorMessage) + case None => + } + } +} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala new file mode 100644 index 0000000000000..488335daa8ded --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import org.apache.spark.sql.QueryTestSuite + +/** + * Runs [[QueryTestSuite]] tests through a Connect session. + * + * This validates the `FooSuite with connect.SparkSessionBinder` pattern: the existing + * [[QueryTestSuite]] tests are inherited unchanged, but execute against a + * [[org.apache.spark.sql.connect.SparkSession connect.SparkSession]] instead of a classic one. + */ +class QueryTestWithConnectSuite + extends QueryTestSuite + with SparkSessionBinder + with QueryTest diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala new file mode 100644 index 0000000000000..0e077cdcdf33e --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala @@ -0,0 +1,98 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import java.util.UUID + +import scala.concurrent.duration._ + +import org.scalatest.concurrent.Eventually + +import org.apache.spark.DebugFilesystem +import org.apache.spark.sql +import org.apache.spark.sql.classic +import org.apache.spark.sql.connect.client.SparkConnectClient +import org.apache.spark.sql.connect.common.config.ConnectCommon +import org.apache.spark.sql.connect.config.Connect +import org.apache.spark.sql.connect.service.SparkConnectService + +/** + * A test trait that provides a Connect [[SparkSession]] backed by an in-process gRPC server. + * Extends [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a + * [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client + * session on top by starting the gRPC service in-process. + * + * Mix in this trait to exercise existing sql/core test suites through the Connect path: + * {{{ + * class FooWithConnectSuite + * extends FooSuite + * with connect.SparkSessionBinder + * with connect.QueryTest + * }}} + */ +trait SparkSessionBinder extends sql.SparkSessionBinder { + + private val serverPort: Int = + ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) + + @volatile private var _connectSpark: SparkSession = _ + + protected override def spark: SparkSession = _connectSpark + + /** The underlying classic session used by the in-process server. 
*/ + private def classicSpark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession] + + override def beforeAll(): Unit = { + super.beforeAll() + withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) { + SparkConnectService.start(classicSpark.sparkContext) + } + val client = SparkConnectClient + .builder() + .port(serverPort) + .sessionId(UUID.randomUUID().toString) + .userId("test") + .build() + _connectSpark = SparkSession + .builder() + .client(client) + .create() + } + + override def afterAll(): Unit = { + try { + if (_connectSpark != null) { + _connectSpark.close() + _connectSpark = null + } + SparkConnectService.stop() + } finally { + super.afterAll() + } + } + + // sharedState is not supported on Connect, so per-test cleanup uses the classic session that + // backs the in-process server. + protected override def afterEach(): Unit = { + // NOTE: super.afterEach() is not called; sql.SparkSessionBinder.afterEach is skipped. + classicSpark.sharedState.cacheManager.clearCache() + Eventually.eventually(Eventually.timeout(10.seconds), Eventually.interval(2.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } +} diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala new file mode 100644 index 0000000000000..d9e456c0fd706 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import org.apache.spark.sql + +/** + * A common trait for test suites or utils that require a connect [[SparkSession]]. + * Use together with e.g. [[SparkSessionBinder]]. + */ +trait SparkSessionProvider extends sql.SparkSessionProvider { + protected override def spark: SparkSession +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 5a1ea3d9f53cf..5b02cbbd95d43 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -1211,7 +1211,7 @@ object QueryTest extends Assertions { } -class QueryTestSuite extends test.SharedSparkSession { +class QueryTestSuite extends QueryTest with SparkSessionBinder { test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") { intercept[org.scalatest.exceptions.TestFailedException] { checkAnswer(sql("SELECT 1"), Row(2) :: Nil) From c9c8538ec2b4425a485da4c21bd1c3282f24fba8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Thu, 28 May 2026 19:40:10 +0000 Subject: [PATCH 07/12] Add classic.SparkSessionBinder with usage demonstration --- .../sql/classic/SparkSessionBinder.scala | 26 +++++++++++++++++++ .../parquet/ParquetQuerySuite.scala | 6 +++-- 2 files changed, 30 insertions(+), 2 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala diff --git 
a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala new file mode 100644 index 0000000000000..9eea2c7cf6965 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala @@ -0,0 +1,26 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +trait SparkSessionBinder + extends sql.SparkSessionBinder + with SparkSessionProvider { + override protected def spark: SparkSession = super.spark.asInstanceOf[SparkSession] +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b20b6d397fd17..5c2bc6829ea59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.classic import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupportedException, SQLHadoopMapReduceCommitProtocol} import org.apache.spark.sql.execution.datasources.parquet.TestingUDT._ @@ -37,14 +38,15 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.functions.struct import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. 
*/ -abstract class ParquetQuerySuite extends ParquetTest with SharedSparkSession { +abstract class ParquetQuerySuite extends ParquetTest + with QueryTest + with classic.SparkSessionBinder { import testImplicits._ test("simple select queries") { From 4c35b220e4e303e7aaf45c438955d71839ac3eb6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Thu, 28 May 2026 20:53:24 +0000 Subject: [PATCH 08/12] fixup: fix compile error --- .../org/apache/spark/sql/test/SharedSparkSession.scala | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index ad764d99c059f..1c506df20634d 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.test import scala.concurrent.duration._ import org.apache.spark.sql -import org.apache.spark.sql.{classic, QueryTest, QueryTestBase, SparkSession, SparkSessionProvider, SQLContext} +import org.apache.spark.sql.{classic, QueryTest, QueryTestBase} @deprecated("Use SparkSessionBinder and QueryTest instead") trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { @@ -57,9 +57,7 @@ trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. 
*/ @deprecated("Use SparkSessionBinder and QueryTest instead") -trait SharedSparkSessionBase - extends QueryTestBase - with sql.SparkSessionBinder { +trait SharedSparkSessionBase extends sql.SparkSessionBinder with QueryTestBase { protected override def spark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession] From c52b50fb5cb961f5757f0b9a0adbde378d7f8f5c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Fri, 29 May 2026 17:32:44 +0000 Subject: [PATCH 09/12] Restructure so that SparkSessionBinder implements QueryTest, address nits --- .../apache/spark/sql/connect/QueryTest.scala | 9 ++- .../connect/QueryTestWithConnectSuite.scala | 5 +- .../sql/connect/SparkSessionBinder.scala | 36 ++++-------- .../org/apache/spark/sql/QueryTest.scala | 4 +- .../apache/spark/sql/SparkSessionBinder.scala | 54 ++++++++++-------- .../apache/spark/sql/classic/QueryTest.scala | 55 +++++++++++++++++++ .../sql/classic/SparkSessionBinder.scala | 7 ++- .../sql/classic/SparkSessionProvider.scala | 3 - .../spark/sql/test/SharedSparkSession.scala | 15 +++-- 9 files changed, 123 insertions(+), 65 deletions(-) create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala index e107eb01f5700..ab3bd2c494311 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala @@ -20,10 +20,13 @@ package org.apache.spark.sql.connect import org.apache.spark.{sql => sqlApi} /** - * Extends [[sqlApi.QueryTest]] for use with Connect sessions. + * Extends [[sqlApi.QueryTest]] to provide connect-specific overrides to helpers like + * [[checkAnswer]] that avoid classic-only APIs. * - * Overrides [[checkAnswer]] to avoid classic-only code paths (e.g. 
`queryExecution`, - * `logicalPlan`, `materializedRdd`) that are not available on Connect DataFrames. + * Can be used together with [[SparkSessionBinder connect.SparkSessionBinder]] to create a + * 'connect variant' of a test. + * + * Note: broader use will require more overrides. */ trait QueryTest extends sqlApi.QueryTest with SparkSessionProvider { diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala index 488335daa8ded..f13765dc03aa1 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala @@ -22,11 +22,10 @@ import org.apache.spark.sql.QueryTestSuite /** * Runs [[QueryTestSuite]] tests through a Connect session. * - * This validates the `FooSuite with connect.SharedSparkSession` pattern: the existing + * This validates the `FooSuite with connect.QueryTest` pattern: the existing * [[QueryTestSuite]] tests are inherited unchanged, but execute against a - * [[org.apache.spark.sql.connect.SparkSession connect.SparkSession]] instead of a classic one. + * [[SparkSession connect.SparkSession]] instead of a classic one. 
*/ class QueryTestWithConnectSuite extends QueryTestSuite - with SparkSessionBinder with QueryTest diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala index 0e077cdcdf33e..d60f26a22fee5 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala @@ -19,20 +19,15 @@ package org.apache.spark.sql.connect import java.util.UUID -import scala.concurrent.duration._ - -import org.scalatest.concurrent.Eventually - -import org.apache.spark.DebugFilesystem +import org.apache.spark.SparkEnv import org.apache.spark.sql import org.apache.spark.sql.classic import org.apache.spark.sql.connect.client.SparkConnectClient -import org.apache.spark.sql.connect.common.config.ConnectCommon import org.apache.spark.sql.connect.config.Connect import org.apache.spark.sql.connect.service.SparkConnectService /** - * A test trait that provides a Connect [[SparkSession]] backed by an in-process gRPC server. + * Provides a [[SparkSession connect.SparkSession]] backed by an in-process gRPC server. * Extends [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a * [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client * session on top by starting the gRPC service in-process. 
@@ -42,15 +37,11 @@ import org.apache.spark.sql.connect.service.SparkConnectService * class FooWithConnectSuite * extends FooSuite * with connect.SparkSessionBinder - * with connect.QueryTest * }}} */ -trait SparkSessionBinder extends sql.SparkSessionBinder { +trait SparkSessionBinder extends sql.SparkSessionBinder with QueryTest { - private val serverPort: Int = - ConnectCommon.CONNECT_GRPC_BINDING_PORT + util.Random.nextInt(1000) - - @volatile private var _connectSpark: SparkSession = _ + private var _connectSpark: SparkSession = _ protected override def spark: SparkSession = _connectSpark @@ -59,12 +50,17 @@ trait SparkSessionBinder extends sql.SparkSessionBinder { override def beforeAll(): Unit = { super.beforeAll() - withSparkEnvConfs((Connect.CONNECT_GRPC_BINDING_PORT.key, serverPort.toString)) { + val prevPort = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_BINDING_PORT) + try { + // set GRPC_BINDING_PORT to 0 so that the server picks a random, freely available port. + SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, 0) SparkConnectService.start(classicSpark.sparkContext) + } finally { + SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, prevPort) } val client = SparkConnectClient .builder() - .port(serverPort) + .port(SparkConnectService.localPort) .sessionId(UUID.randomUUID().toString) .userId("test") .build() @@ -85,14 +81,4 @@ trait SparkSessionBinder extends sql.SparkSessionBinder { super.afterAll() } } - - // The base SharedSparkSessionBase.afterEach calls spark.sharedState which is not supported - // on Connect. Override to use the classic session for cleanup. 
- protected override def afterEach(): Unit = { - // super.afterEach() from BeforeAndAfterEach (skipping SharedSparkSessionBase) - classicSpark.sharedState.cacheManager.clearCache() - Eventually.eventually(Eventually.timeout(10.seconds), Eventually.interval(2.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 5b02cbbd95d43..8f7a0b8f5ddc2 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -510,6 +510,7 @@ trait QueryTestBase /** * Strip Spark-side filtering in order to check if a datasource filters rows correctly. */ + @deprecated("Classic-only method, use classic.QueryTest", "4.2.0") protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema val withoutFilters = df.queryExecution.executedPlan.transform { @@ -524,6 +525,7 @@ trait QueryTestBase * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier * way to construct `DataFrame` directly out of local data without relying on implicits. 
*/ + @deprecated("Classic-only method, use classic.QueryTest", "4.2.0") protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): classic.DataFrame = { classic.Dataset.ofRows(spark.asInstanceOf[classic.SparkSession], plan) } @@ -1211,7 +1213,7 @@ object QueryTest extends Assertions { } -class QueryTestSuite extends QueryTest with SparkSessionBinder { +class QueryTestSuite extends QueryTest { test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") { intercept[org.scalatest.exceptions.TestFailedException] { checkAnswer(sql("SELECT 1"), Row(2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala index f1f70e9b2b109..a3ca244ca3718 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala @@ -22,18 +22,44 @@ import scala.concurrent.duration._ import org.scalatest.{BeforeAndAfterEach, Suite} import org.scalatest.concurrent.Eventually -import org.apache.spark.{DebugFilesystem, SparkConf, SparkFunSuite} +import org.apache.spark.{DebugFilesystem, SparkConf} import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} import org.apache.spark.sql.test.TestSparkSession -trait SparkSessionBinder - extends SparkFunSuite +trait SparkSessionBinder extends QueryTest with SparkSessionBinderBase { + + /** + * Suites extending this trait are sharing resources (e.g. SparkSession) in their + * tests. This trait initializes the spark session in its [[beforeAll()]] implementation before + * the automatic thread snapshot is performed, so the audit code could fail to report threads + * leaked by that shared session. 
+ * + * The behavior is overridden here to take the snapshot before the spark session is initialized. + */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + doThreadPostAudit() + } + } +} + +trait SparkSessionBinderBase + extends QueryTestBase with SparkSessionProvider with BeforeAndAfterEach - with Eventually { + with Eventually { self: Suite => protected def sparkConf = { val conf = new SparkConf() @@ -95,22 +121,10 @@ trait SparkSessionBinder } } - /** - * Suites extending [[SharedSparkSession]] are sharing resources (e.g. SparkSession) in their - * tests. That trait initializes the spark session in its [[beforeAll()]] implementation before - * the automatic thread snapshot is performed, so the audit code could fail to report threads - * leaked by that shared session. - * - * The behavior is overridden here to take the snapshot before the spark session is initialized. - */ - override protected val enableAutoThreadAudit = false - /** * Make sure the [[TestSparkSession]] is initialized before any tests are run. 
*/ protected override def beforeAll(): Unit = { - doThreadPreAudit() - initializeSession() // Ensure we have initialized the context before calling parent code @@ -134,12 +148,8 @@ trait SparkSessionBinder } } } finally { - try { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } finally { - doThreadPostAudit() - } + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala new file mode 100644 index 0000000000000..20941dd0c549b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.FilterExec + +/** + * Extends [[org.apache.spark.sql.QueryTest sql.QueryTest]] to explicitly provide + * a [[SparkSession classic.SparkSession]] and corresponding helpers. 
+ * + * Use this trait to indicate that a test is classic-only, + * i.e. it is not intended to be run with a + * [[org.apache.spark.sql.connect.QueryTest connect.QueryTest]] override. + */ +trait QueryTest extends sql.QueryTest with SparkSessionProvider { + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.executedPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. + */ + protected implicit override def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, plan) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala index 9eea2c7cf6965..e0b4a794d2bb7 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala @@ -19,8 +19,9 @@ package org.apache.spark.sql.classic import org.apache.spark.sql -trait SparkSessionBinder - extends sql.SparkSessionBinder - with SparkSessionProvider { +/** + * Overrides [[spark]] to provide a [[SparkSession classic.SparkSession]] + */ +trait SparkSessionBinder extends sql.SparkSessionBinder with QueryTest { override protected def spark: SparkSession = super.spark.asInstanceOf[SparkSession] } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala index e459250f2d3f4..77de0db4bf68b 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala @@ -19,9 +19,6 @@ package org.apache.spark.sql.classic import org.apache.spark.sql -/** - * A common trait for test suites that require a classic [[SparkSession]]. - */ trait SparkSessionProvider extends sql.SparkSessionProvider { override protected def spark: SparkSession } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 1c506df20634d..6a176805a349a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -19,11 +19,16 @@ package org.apache.spark.sql.test import scala.concurrent.duration._ +import org.scalatest.Suite + import org.apache.spark.sql -import org.apache.spark.sql.{classic, QueryTest, QueryTestBase} +import org.apache.spark.sql.{classic, QueryTest} + +@deprecated("Use SparkSessionBinder (or classic.SparkSessionBinder if required) instead", "4.2.0") +trait SharedSparkSession extends sql.SparkSessionBinder { -@deprecated("Use SparkSessionBinder and QueryTest instead") -trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { + protected override def spark: classic.SparkSession = + super.spark.asInstanceOf[classic.SparkSession] // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. @@ -56,8 +61,8 @@ trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. 
*/ -@deprecated("Use SparkSessionBinder and QueryTest instead") -trait SharedSparkSessionBase extends sql.SparkSessionBinder with QueryTestBase { +@deprecated("Use SparkSessionBinder (or classic.SparkSessionBinder if required) instead", "4.2.0") +trait SharedSparkSessionBase extends sql.SparkSessionBinderBase { self: Suite => protected override def spark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession] From 642166b3019b6e918eb3bc54c1535013ed1f6cf6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Fri, 29 May 2026 17:34:06 +0000 Subject: [PATCH 10/12] fixup --- .../apache/spark/sql/connect/QueryTestWithConnectSuite.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala index f13765dc03aa1..013acba63b80f 100644 --- a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala @@ -22,10 +22,10 @@ import org.apache.spark.sql.QueryTestSuite /** * Runs [[QueryTestSuite]] tests through a Connect session. * - * This validates the `FooSuite with connect.QueryTest` pattern: the existing + * This validates the `FooSuite with connect.SparkSessionBinder` pattern: the existing * [[QueryTestSuite]] tests are inherited unchanged, but execute against a * [[SparkSession connect.SparkSession]] instead of a classic one. 
*/ class QueryTestWithConnectSuite extends QueryTestSuite - with QueryTest + with SparkSessionBinder From d51a967ee894ea958c03509deda94985079fdab4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Fri, 29 May 2026 18:04:28 +0000 Subject: [PATCH 11/12] Have SharedSparkSession as empty alias of classic.SparkSessionBinder --- .../sql/classic/SparkSessionBinder.scala | 29 +++++++++++++++ .../spark/sql/test/SharedSparkSession.scala | 36 ++----------------- 2 files changed, 31 insertions(+), 34 deletions(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala index e0b4a794d2bb7..68920a445e5fc 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala @@ -17,6 +17,8 @@ package org.apache.spark.sql.classic +import scala.concurrent.duration._ + import org.apache.spark.sql /** @@ -24,4 +26,31 @@ import org.apache.spark.sql */ trait SparkSessionBinder extends sql.SparkSessionBinder with QueryTest { override protected def spark: SparkSession = super.spark.asInstanceOf[SparkSession] + + // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that + // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. + def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = { + val statusStore = spark.sharedState.statusStore + val oldCount = statusStore.executionsList().size + + func + + // Wait until the new execution is started and being tracked. + eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsCount() >= oldCount) + } + + // Wait for listener to finish computing the metrics for the execution. 
+ eventually(timeout(10.seconds), interval(10.milliseconds)) { + assert(statusStore.executionsList().nonEmpty && + statusStore.executionsList().last.metricValues != null) + } + + val exec = statusStore.executionsList().last + val execId = exec.executionId + val sqlMetrics = statusStore.planGraph(execId).allNodes + .flatMap(n => n.metrics.map(m => (m.accumulatorId, (n.id, n.name, m.name)))) + .toMap + statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index 6a176805a349a..c52bcd4aa9c2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -17,46 +17,14 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ - import org.scalatest.Suite import org.apache.spark.sql -import org.apache.spark.sql.{classic, QueryTest} +import org.apache.spark.sql.classic @deprecated("Use SparkSessionBinder (or classic.SparkSessionBinder if required) instead", "4.2.0") -trait SharedSparkSession extends sql.SparkSessionBinder { - - protected override def spark: classic.SparkSession = - super.spark.asInstanceOf[classic.SparkSession] - - // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that - // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. - def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = { - val statusStore = spark.sharedState.statusStore - val oldCount = statusStore.executionsList().size +trait SharedSparkSession extends classic.SparkSessionBinder - func - - // Wait until the new execution is started and being tracked. 
- eventually(timeout(10.seconds), interval(10.milliseconds)) { - assert(statusStore.executionsCount() >= oldCount) - } - - // Wait for listener to finish computing the metrics for the execution. - eventually(timeout(10.seconds), interval(10.milliseconds)) { - assert(statusStore.executionsList().nonEmpty && - statusStore.executionsList().last.metricValues != null) - } - - val exec = statusStore.executionsList().last - val execId = exec.executionId - val sqlMetrics = statusStore.planGraph(execId).allNodes - .flatMap(n => n.metrics.map(m => (m.accumulatorId, (n.id, n.name, m.name)))) - .toMap - statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v } - } -} /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. From 6b5dba77015378c074f5bfbd6c26304a1675abff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Matthis=20G=C3=B6rdel?= Date: Fri, 29 May 2026 22:19:48 +0000 Subject: [PATCH 12/12] fixup --- sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 8f7a0b8f5ddc2..17212fa30b954 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -1213,7 +1213,7 @@ object QueryTest extends Assertions { } -class QueryTestSuite extends QueryTest { +class QueryTestSuite extends QueryTest with SparkSessionBinder { test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") { intercept[org.scalatest.exceptions.TestFailedException] { checkAnswer(sql("SELECT 1"), Row(2) :: Nil)