diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala new file mode 100644 index 0000000000000..ab3bd2c494311 --- /dev/null +++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTest.scala @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import org.apache.spark.{sql => sqlApi} + +/** + * Extends [[sqlApi.QueryTest]] to provide connect-specific overrides to helpers like + * [[checkAnswer]] that avoid classic-only APIs. + * + * Can be used together with [[SparkSessionBinder connect.SparkSessionBinder]] to create a + * 'connect variant' of a test. + * + * Note: broader use will require more overrides. 
+ */
+trait QueryTest extends sqlApi.QueryTest with SparkSessionProvider {
+
+  override protected def checkAnswer(
+      df: => sqlApi.DataFrame, expectedAnswer: Seq[sqlApi.Row]): Unit = {
+    // Only `collect()` is needed here; no classic-only API is touched.
+    val actualRows = df.collect().toSeq
+    // `sameRows` yields a message only on mismatch; surface it as a test failure.
+    val mismatch = sqlApi.QueryTest.sameRows(expectedAnswer, actualRows)
+    mismatch.foreach(message => fail(message))
+  }
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala
new file mode 100644
index 0000000000000..013acba63b80f
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/QueryTestWithConnectSuite.scala
@@ -0,0 +1,31 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import org.apache.spark.sql.QueryTestSuite
+
+/**
+ * Runs [[QueryTestSuite]] tests through a Connect session.
+ *
+ * This validates the `FooSuite with connect.SparkSessionBinder` pattern: the existing
+ * [[QueryTestSuite]] tests are inherited unchanged, but execute against a
+ * [[SparkSession connect.SparkSession]] instead of a classic one.
+ */
+class QueryTestWithConnectSuite
+  extends QueryTestSuite
+  with SparkSessionBinder
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala
new file mode 100644
index 0000000000000..d60f26a22fee5
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionBinder.scala
@@ -0,0 +1,99 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.connect
+
+import java.util.UUID
+
+import org.apache.spark.SparkEnv
+import org.apache.spark.sql
+import org.apache.spark.sql.classic
+import org.apache.spark.sql.connect.client.SparkConnectClient
+import org.apache.spark.sql.connect.config.Connect
+import org.apache.spark.sql.connect.service.SparkConnectService
+
+/**
+ * Provides a [[SparkSession connect.SparkSession]] backed by an in-process gRPC server.
+ * Extends [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a
+ * [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client
+ * session on top by starting the gRPC service in-process.
+ *
+ * Mix in this trait to exercise existing sql/core test suites through the Connect path:
+ * {{{
+ *   class FooWithConnectSuite
+ *     extends FooSuite
+ *     with connect.SparkSessionBinder
+ * }}}
+ */
+trait SparkSessionBinder extends sql.SparkSessionBinder with QueryTest {
+
+  // Connect client session handed to tests; assigned in beforeAll and cleared in afterAll.
+  private var _connectSpark: SparkSession = _
+
+  protected override def spark: SparkSession = _connectSpark
+
+  /** The underlying classic session used by the in-process server. */
+  private def classicSpark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession]
+
+  /**
+   * Runs the parent setup, starts the Connect gRPC service on the classic session's
+   * SparkContext, and builds the Connect client session returned by [[spark]].
+   */
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    val prevPort = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_BINDING_PORT)
+    try {
+      // set GRPC_BINDING_PORT to 0 so that the server picks a random, freely available port.
+      SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, 0)
+      SparkConnectService.start(classicSpark.sparkContext)
+    } finally {
+      // Put the previous port back whether or not the service started successfully.
+      SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, prevPort)
+    }
+    val client = SparkConnectClient
+      .builder()
+      .port(SparkConnectService.localPort)
+      .sessionId(UUID.randomUUID().toString)
+      .userId("test")
+      .build()
+    _connectSpark = SparkSession
+      .builder()
+      .client(client)
+      .create()
+  }
+
+  /**
+   * Closes the Connect client, stops the in-process gRPC service, then runs the parent teardown.
+   * The steps are nested in `finally` blocks so each one runs even if an earlier one throws.
+   */
+  override def afterAll(): Unit = {
+    try {
+      if (_connectSpark != null) {
+        _connectSpark.close()
+      }
+    } finally {
+      _connectSpark = null
+      try {
+        // Must not be skipped when closing the client fails; otherwise the in-process
+        // service would be left running after this suite finishes.
+        SparkConnectService.stop()
+      } finally {
+        super.afterAll()
+      }
+    }
+  }
+}
diff --git a/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala
new file mode 100644
index 0000000000000..d9e456c0fd706
--- /dev/null
+++ b/sql/connect/server/src/test/scala/org/apache/spark/sql/connect/SparkSessionProvider.scala
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.connect + +import org.apache.spark.sql + +/** + * A common trait for test suites or utils that require a connect [[SparkSession]]. + * Use together with e.g. [[SparkSessionBinder]]. + */ +trait SparkSessionProvider extends sql.SparkSessionProvider { + protected override def spark: SparkSession +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala index 5a1ea3d9f53cf..17212fa30b954 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala @@ -510,6 +510,7 @@ trait QueryTestBase /** * Strip Spark-side filtering in order to check if a datasource filters rows correctly. */ + @deprecated("Classic-only method, use classic.QueryTest", "4.2.0") protected def stripSparkFilter(df: DataFrame): DataFrame = { val schema = df.schema val withoutFilters = df.queryExecution.executedPlan.transform { @@ -524,6 +525,7 @@ trait QueryTestBase * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier * way to construct `DataFrame` directly out of local data without relying on implicits. 
*/ + @deprecated("Classic-only method, use classic.QueryTest", "4.2.0") protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): classic.DataFrame = { classic.Dataset.ofRows(spark.asInstanceOf[classic.SparkSession], plan) } @@ -1211,7 +1213,7 @@ object QueryTest extends Assertions { } -class QueryTestSuite extends test.SharedSparkSession { +class QueryTestSuite extends QueryTest with SparkSessionBinder { test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") { intercept[org.scalatest.exceptions.TestFailedException] { checkAnswer(sql("SELECT 1"), Row(2) :: Nil) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala new file mode 100644 index 0000000000000..a3ca244ca3718 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala @@ -0,0 +1,173 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql + +import scala.concurrent.duration._ + +import org.scalatest.{BeforeAndAfterEach, Suite} +import org.scalatest.concurrent.Eventually + +import org.apache.spark.{DebugFilesystem, SparkConf} +import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK +import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode +import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation +import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +import org.apache.spark.sql.test.TestSparkSession + +trait SparkSessionBinder extends QueryTest with SparkSessionBinderBase { + + /** + * Suites extending this trait are sharing resources (e.g. SparkSession) in their + * tests. This trait initializes the spark session in its [[beforeAll()]] implementation before + * the automatic thread snapshot is performed, so the audit code could fail to report threads + * leaked by that shared session. + * + * The behavior is overridden here to take the snapshot before the spark session is initialized. + */ + override protected val enableAutoThreadAudit = false + + protected override def beforeAll(): Unit = { + doThreadPreAudit() + super.beforeAll() + } + + protected override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + doThreadPostAudit() + } + } +} + +trait SparkSessionBinderBase + extends QueryTestBase + with SparkSessionProvider + with BeforeAndAfterEach + with Eventually { self: Suite => + + protected def sparkConf = { + val conf = new SparkConf() + .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) + .set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) + .set(SQLConf.CODEGEN_FALLBACK.key, "false") + .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) + // Disable ConvertToLocalRelation for better test coverage. 
Test cases built on + // LocalRelation will exercise the optimization rules better by disabling it as + // this rule may potentially block testing of other optimization rules such as + // ConstantPropagation etc. + .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) + conf.set( + StaticSQLConf.WAREHOUSE_PATH, + conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) + conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false) + conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, + sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD", + StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) + conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, + sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD", + StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) + } + + /** + * The [[TestSparkSession]] to use for all tests in this suite. + * + * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local + * mode with the default test configurations. + */ + private var _spark: classic.SparkSession = null + + protected override def spark: SparkSession = _spark + + /** + * The [[TestSQLContext]] to use for all tests in this suite. + */ + protected implicit def sqlContext: SQLContext = _spark.sqlContext + + protected def createSparkSession: classic.SparkSession = { + classic.SparkSession.cleanupAnyExistingSession() + new TestSparkSession(sparkConf) + } + + protected def sqlConf: SQLConf = _spark.sessionState.conf + + /** + * Initialize the [[TestSparkSession]]. Generally, this is just called from + * beforeAll; however, in test using styles other than FunSuite, there is + * often code that relies on the session between test group constructs and + * the actual tests, which may need this session. 
It is purely a semantic + * difference, but semantically, it makes more sense to call + * 'initializeSession' between a 'describe' and an 'it' call than it does to + * call 'beforeAll'. + */ + protected def initializeSession(): Unit = { + if (_spark == null) { + _spark = createSparkSession + } + } + + /** + * Make sure the [[TestSparkSession]] is initialized before any tests are run. + */ + protected override def beforeAll(): Unit = { + initializeSession() + + // Ensure we have initialized the context before calling parent code + super.beforeAll() + } + + /** + * Stop the underlying [[org.apache.spark.SparkContext]], if any. + */ + protected override def afterAll(): Unit = { + try { + super.afterAll() + } finally { + try { + if (_spark != null) { + try { + _spark.sessionState.catalog.reset() + } finally { + _spark.stop() + _spark = null + } + } + } finally { + SparkSession.clearActiveSession() + SparkSession.clearDefaultSession() + } + } + } + + protected override def beforeEach(): Unit = { + super.beforeEach() + DebugFilesystem.clearOpenStreams() + } + + protected override def afterEach(): Unit = { + super.afterEach() + // Clear all persistent datasets after each test + _spark.sharedState.cacheManager.clearCache() + // files can be closed from other threads, so wait a bit + // normally this doesn't take more than 1s + eventually(timeout(10.seconds), interval(2.seconds)) { + DebugFilesystem.assertNoOpenStreams() + } + } + +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala new file mode 100644 index 0000000000000..20941dd0c549b --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/QueryTest.scala @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import scala.language.implicitConversions + +import org.apache.spark.sql +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.FilterExec + +/** + * Extends [[org.apache.spark.sql.QueryTest sql.QueryTest]] to explicitly provide + * a [[SparkSession classic.SparkSession]] and corresponding helpers. + * + * Use this trait to indicate that a test is classic-only, + * i.e it is not intended to run this test with a + * [[org.apache.spark.sql.connect.QueryTest connect.QueryTest]] override. + */ +trait QueryTest extends sql.QueryTest with SparkSessionProvider { + + /** + * Strip Spark-side filtering in order to check if a datasource filters rows correctly. + */ + protected def stripSparkFilter(df: DataFrame): DataFrame = { + val schema = df.schema + val withoutFilters = df.queryExecution.executedPlan.transform { + case FilterExec(_, child) => child + } + + spark.internalCreateDataFrame(withoutFilters.execute(), schema) + } + + /** + * Turn a logical plan into a `DataFrame`. This should be removed once we have an easier + * way to construct `DataFrame` directly out of local data without relying on implicits. 
+ */ + protected implicit override def logicalPlanToSparkQuery(plan: LogicalPlan): DataFrame = { + Dataset.ofRows(spark, plan) + } +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala new file mode 100644 index 0000000000000..68920a445e5fc --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionBinder.scala @@ -0,0 +1,56 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import scala.concurrent.duration._ + +import org.apache.spark.sql + +/** + * Overrides [[spark]] to provide a [[SparkSession classic.SparkSession]] + */ +trait SparkSessionBinder extends sql.SparkSessionBinder with QueryTest { + override protected def spark: SparkSession = super.spark.asInstanceOf[SparkSession] + + // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that + // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. 
+  def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = {
+    val statusStore = spark.sharedState.statusStore
+    val oldCount = statusStore.executionsList().size
+
+    func
+
+    // Wait until the new execution is tracked. NOTE(review): `>=` may hold before func ran; `>`?
+    eventually(timeout(10.seconds), interval(10.milliseconds)) {
+      assert(statusStore.executionsCount() >= oldCount)
+    }
+
+    // Wait for listener to finish computing the metrics for the execution.
+    eventually(timeout(10.seconds), interval(10.milliseconds)) {
+      assert(statusStore.executionsList().nonEmpty &&
+        statusStore.executionsList().last.metricValues != null)
+    }
+
+    val exec = statusStore.executionsList().last
+    val execId = exec.executionId
+    val sqlMetrics = statusStore.planGraph(execId).allNodes
+      .flatMap(n => n.metrics.map(m => (m.accumulatorId, (n.id, n.name, m.name))))
+      .toMap
+    statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v }
+  }
+}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala
new file mode 100644
index 0000000000000..77de0db4bf68b
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/classic/SparkSessionProvider.scala
@@ -0,0 +1,24 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.classic + +import org.apache.spark.sql + +trait SparkSessionProvider extends sql.SparkSessionProvider { + override protected def spark: SparkSession +} diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala index b20b6d397fd17..5c2bc6829ea59 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/parquet/ParquetQuerySuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.sql._ import org.apache.spark.sql.catalyst.{InternalRow, TableIdentifier} import org.apache.spark.sql.catalyst.expressions.SpecificInternalRow import org.apache.spark.sql.catalyst.util.ArrayData +import org.apache.spark.sql.classic import org.apache.spark.sql.execution.FileSourceScanExec import org.apache.spark.sql.execution.datasources.{SchemaColumnConvertNotSupportedException, SQLHadoopMapReduceCommitProtocol} import org.apache.spark.sql.execution.datasources.parquet.TestingUDT._ @@ -37,14 +38,15 @@ import org.apache.spark.sql.execution.datasources.v2.BatchScanExec import org.apache.spark.sql.execution.datasources.v2.parquet.ParquetScan import org.apache.spark.sql.functions.struct import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.test.SharedSparkSession import org.apache.spark.sql.types._ import org.apache.spark.util.Utils /** * A test suite that tests various Parquet queries. 
*/ -abstract class ParquetQuerySuite extends ParquetTest with SharedSparkSession { +abstract class ParquetQuerySuite extends ParquetTest + with QueryTest + with classic.SparkSessionBinder { import testImplicits._ test("simple select queries") { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala index fb26d3311ebef..c52bcd4aa9c2f 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/test/SharedSparkSession.scala @@ -17,189 +17,21 @@ package org.apache.spark.sql.test -import scala.concurrent.duration._ +import org.scalatest.Suite -import org.scalatest.{BeforeAndAfterEach, Suite} -import org.scalatest.concurrent.Eventually +import org.apache.spark.sql +import org.apache.spark.sql.classic -import org.apache.spark.{DebugFilesystem, SparkConf} -import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK -import org.apache.spark.sql.{classic, QueryTest, QueryTestBase, SparkSession, SparkSessionProvider, SQLContext} -import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode -import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation -import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf} +@deprecated("Use SparkSessionBinder (or classic.SparkSessionBinder if required) instead", "4.2.0") +trait SharedSparkSession extends classic.SparkSessionBinder -trait SharedSparkSession extends QueryTest with SharedSparkSessionBase { - - /** - * Suites extending [[SharedSparkSession]] are sharing resources (e.g. SparkSession) in their - * tests. That trait initializes the spark session in its [[beforeAll()]] implementation before - * the automatic thread snapshot is performed, so the audit code could fail to report threads - * leaked by that shared session. 
- * - * The behavior is overridden here to take the snapshot before the spark session is initialized. - */ - override protected val enableAutoThreadAudit = false - - protected override def beforeAll(): Unit = { - doThreadPreAudit() - super.beforeAll() - } - - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - doThreadPostAudit() - } - } - - // Runs func (which must trigger exactly one SQL execution) and returns the SQL metrics of that - // execution as a map keyed by (planNodeId, planNodeName, metricName) -> metricValue. - def runAndFetchMetrics(func: => Unit): Map[(Long, String, String), String] = { - val statusStore = spark.sharedState.statusStore - val oldCount = statusStore.executionsList().size - - func - - // Wait until the new execution is started and being tracked. - eventually(timeout(10.seconds), interval(10.milliseconds)) { - assert(statusStore.executionsCount() >= oldCount) - } - - // Wait for listener to finish computing the metrics for the execution. - eventually(timeout(10.seconds), interval(10.milliseconds)) { - assert(statusStore.executionsList().nonEmpty && - statusStore.executionsList().last.metricValues != null) - } - - val exec = statusStore.executionsList().last - val execId = exec.executionId - val sqlMetrics = statusStore.planGraph(execId).allNodes - .flatMap(n => n.metrics.map(m => (m.accumulatorId, (n.id, n.name, m.name)))) - .toMap - statusStore.executionMetrics(execId).map { case (k, v) => sqlMetrics(k) -> v } - } -} /** * Helper trait for SQL test suites where all tests share a single [[TestSparkSession]]. 
*/ -trait SharedSparkSessionBase - extends QueryTestBase - with SparkSessionProvider - with BeforeAndAfterEach - with Eventually { self: Suite => - - protected def sparkConf = { - val conf = new SparkConf() - .set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName) - .set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true) - .set(SQLConf.CODEGEN_FALLBACK.key, "false") - .set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString) - // Disable ConvertToLocalRelation for better test coverage. Test cases built on - // LocalRelation will exercise the optimization rules better by disabling it as - // this rule may potentially block testing of other optimization rules such as - // ConstantPropagation etc. - .set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName) - conf.set( - StaticSQLConf.WAREHOUSE_PATH, - conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName) - conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false) - conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD, - sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD", - StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt) - } - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - * - * By default, the underlying [[org.apache.spark.SparkContext]] will be run in local - * mode with the default test configurations. - */ - private var _spark: TestSparkSession = null - - /** - * The [[TestSparkSession]] to use for all tests in this suite. - */ - protected override def spark: classic.SparkSession = _spark - - /** - * The [[TestSQLContext]] to use for all tests in this suite. 
- */ - protected implicit def sqlContext: SQLContext = _spark.sqlContext - - protected def createSparkSession: TestSparkSession = { - classic.SparkSession.cleanupAnyExistingSession() - new TestSparkSession(sparkConf) - } - - protected def sqlConf: SQLConf = _spark.sessionState.conf - - /** - * Initialize the [[TestSparkSession]]. Generally, this is just called from - * beforeAll; however, in test using styles other than FunSuite, there is - * often code that relies on the session between test group constructs and - * the actual tests, which may need this session. It is purely a semantic - * difference, but semantically, it makes more sense to call - * 'initializeSession' between a 'describe' and an 'it' call than it does to - * call 'beforeAll'. - */ - protected def initializeSession(): Unit = { - if (_spark == null) { - _spark = createSparkSession - } - } - - /** - * Make sure the [[TestSparkSession]] is initialized before any tests are run. - */ - protected override def beforeAll(): Unit = { - initializeSession() - - // Ensure we have initialized the context before calling parent code - super.beforeAll() - } - - /** - * Stop the underlying [[org.apache.spark.SparkContext]], if any. 
- */ - protected override def afterAll(): Unit = { - try { - super.afterAll() - } finally { - try { - if (_spark != null) { - try { - _spark.sessionState.catalog.reset() - } finally { - _spark.stop() - _spark = null - } - } - } finally { - SparkSession.clearActiveSession() - SparkSession.clearDefaultSession() - } - } - } - - protected override def beforeEach(): Unit = { - super.beforeEach() - DebugFilesystem.clearOpenStreams() - } +@deprecated("Use SparkSessionBinder (or classic.SparkSessionBinder if required) instead", "4.2.0") +trait SharedSparkSessionBase extends sql.SparkSessionBinderBase { self: Suite => - protected override def afterEach(): Unit = { - super.afterEach() - // Clear all persistent datasets after each test - spark.sharedState.cacheManager.clearCache() - // files can be closed from other threads, so wait a bit - // normally this doesn't take more than 1s - eventually(timeout(10.seconds), interval(2.seconds)) { - DebugFilesystem.assertNoOpenStreams() - } - } + protected override def spark: classic.SparkSession = + super.spark.asInstanceOf[classic.SparkSession] }