Skip to content
Open
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect

import org.apache.spark.{sql => sqlApi}

/**
* Extends [[sqlApi.QueryTest]] to provide connect-specific overrides to helpers like
* [[checkAnswer]] that avoid classic-only APIs.
*
* Can be used together with [[SparkSessionBinder connect.SparkSessionBinder]] to create a
* 'connect variant' of a test.
*
* Note: broader use will require more overrides.
*/
trait QueryTest extends sqlApi.QueryTest with SparkSessionProvider {

/**
* Checks that collecting `df` yields the rows in `expectedAnswer`.
*
* The comparison is delegated to `sqlApi.QueryTest.sameRows`; when it returns an error
* message the test is failed with that message.
*
* @param df the DataFrame under test; passed by-name, so it is evaluated inside this method
* @param expectedAnswer the rows the query is expected to produce
*/
override protected def checkAnswer(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This overrides only the checkAnswer(df, Seq[Row]) variant, which is enough for QueryTestSuite. But the stated goal is re-running arbitrary sql/core suites over Connect, and the other QueryTest helpers (other checkAnswer overloads, checkDataset, ...) still reach classic-only paths like queryExecution/logicalPlan. Worth a line in the trait doc noting that broader reuse will need more overrides.

df: => sqlApi.DataFrame, expectedAnswer: Seq[sqlApi.Row]): Unit = {
// Only the public `collect()` API is used to obtain the actual rows.
val sparkAnswer = df.collect().toSeq
sqlApi.QueryTest.sameRows(expectedAnswer, sparkAnswer) match {
case Some(errorMessage) => fail(errorMessage)
// No error message was returned: nothing to report.
case None =>
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect

import org.apache.spark.sql.QueryTestSuite

/**
 * Re-runs the tests inherited from [[QueryTestSuite]] against a Connect session.
 *
 * The suite adds no tests of its own. Mixing in
 * [[SparkSessionBinder connect.SparkSessionBinder]] is what validates the
 * `FooSuite with connect.SparkSessionBinder` pattern: the inherited tests stay unchanged but
 * see a [[SparkSession connect.SparkSession]] instead of a classic one.
 */
class QueryTestWithConnectSuite extends QueryTestSuite with SparkSessionBinder
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect

import java.util.UUID

import org.apache.spark.SparkEnv
import org.apache.spark.sql
import org.apache.spark.sql.classic
import org.apache.spark.sql.connect.client.SparkConnectClient
import org.apache.spark.sql.connect.config.Connect
import org.apache.spark.sql.connect.service.SparkConnectService

/**
 * Provides a [[SparkSession connect.SparkSession]] backed by an in-process gRPC server.
 * Extends [[sql.SparkSessionBinder sql.SparkSessionBinder]] (which creates a
 * [[classic.SparkSession classic.SparkSession]] and SparkContext), then layers a Connect client
 * session on top by starting the gRPC service in-process.
 *
 * Lifecycle:
 *  - `beforeAll` first lets the parent create the classic session, then starts
 *    [[SparkConnectService]] on that session's SparkContext and connects a client to it.
 *  - `afterAll` closes the client session, stops the service and finally runs the parent's
 *    teardown. Each of these steps runs even if an earlier one throws.
 *
 * Mix in this trait to exercise existing sql/core test suites through the Connect path:
 * {{{
 *   class FooWithConnectSuite
 *     extends FooSuite
 *     with connect.SparkSessionBinder
 * }}}
 */
trait SparkSessionBinder extends sql.SparkSessionBinder with QueryTest {

  // The Connect client session; null until `beforeAll` has run and again after `afterAll`.
  private var _connectSpark: SparkSession = _

  /** The Connect session that tests in the suite observe as `spark`. */
  protected override def spark: SparkSession = _connectSpark

  /** The underlying classic session used by the in-process server. */
  private def classicSpark: classic.SparkSession = super.spark.asInstanceOf[classic.SparkSession]

  /**
   * Creates the classic session (parent), starts the in-process Connect service and connects
   * a fresh client session to it.
   */
  override def beforeAll(): Unit = {
    super.beforeAll()
    val prevPort = SparkEnv.get.conf.get(Connect.CONNECT_GRPC_BINDING_PORT)
    try {
      // set GRPC_BINDING_PORT to 0 so that the server picks a random, freely available port.
      SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, 0)
      SparkConnectService.start(classicSpark.sparkContext)
    } finally {
      // The port override is only needed while starting; restore the previous value.
      SparkEnv.get.conf.set(Connect.CONNECT_GRPC_BINDING_PORT, prevPort)
    }
    // Connect to whichever port the server actually bound to, with a unique session id.
    val client = SparkConnectClient
      .builder()
      .port(SparkConnectService.localPort)
      .sessionId(UUID.randomUUID().toString)
      .userId("test")
      .build()
    _connectSpark = SparkSession
      .builder()
      .client(client)
      .create()
  }

  /**
   * Closes the client session, stops the Connect service and runs the parent's teardown.
   *
   * The steps are nested in try/finally so that a failure while closing the client does not
   * leave the service running, and a failure while stopping the service does not skip the
   * parent's teardown.
   */
  override def afterAll(): Unit = {
    try {
      if (_connectSpark != null) {
        try {
          _connectSpark.close()
        } finally {
          // Drop the reference even when close() throws.
          _connectSpark = null
        }
      }
    } finally {
      try {
        SparkConnectService.stop()
      } finally {
        super.afterAll()
      }
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.connect

import org.apache.spark.sql

/**
* A common trait for test suites or utils that require a connect [[SparkSession]].
* Use together with e.g. [[SparkSessionBinder]].
*
* The trait only re-declares the inherited `spark` accessor with this package's
* [[SparkSession]] as its result type. It does not create or hold a session itself.
*/
trait SparkSessionProvider extends sql.SparkSessionProvider {
// Declared without a body: the member stays abstract and must be implemented elsewhere.
protected override def spark: SparkSession
}
4 changes: 3 additions & 1 deletion sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,7 @@ trait QueryTestBase
/**
* Strip Spark-side filtering in order to check if a datasource filters rows correctly.
*/
@deprecated("Classic-only method, use classic.QueryTest", "4.2.0")
protected def stripSparkFilter(df: DataFrame): DataFrame = {
val schema = df.schema
val withoutFilters = df.queryExecution.executedPlan.transform {
Expand All @@ -524,6 +525,7 @@ trait QueryTestBase
* Turn a logical plan into a `DataFrame`. This should be removed once we have an easier
* way to construct `DataFrame` directly out of local data without relying on implicits.
*/
@deprecated("Classic-only method, use classic.QueryTest", "4.2.0")
protected implicit def logicalPlanToSparkQuery(plan: LogicalPlan): classic.DataFrame = {
// `spark` is cast to a classic session here, so this conversion only works when the suite's
// session is a classic.SparkSession; with any other session the cast fails at runtime.
classic.Dataset.ofRows(spark.asInstanceOf[classic.SparkSession], plan)
}
Expand Down Expand Up @@ -1211,7 +1213,7 @@ object QueryTest extends Assertions {

}

class QueryTestSuite extends test.SharedSparkSession {
class QueryTestSuite extends QueryTest with SparkSessionBinder {
test("SPARK-16940: checkAnswer should raise TestFailedException for wrong results") {
intercept[org.scalatest.exceptions.TestFailedException] {
checkAnswer(sql("SELECT 1"), Row(2) :: Nil)
Expand Down
173 changes: 173 additions & 0 deletions sql/core/src/test/scala/org/apache/spark/sql/SparkSessionBinder.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,173 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql

import scala.concurrent.duration._

import org.scalatest.{BeforeAndAfterEach, Suite}
import org.scalatest.concurrent.Eventually

import org.apache.spark.{DebugFilesystem, SparkConf}
import org.apache.spark.internal.config.UNSAFE_EXCEPTION_ON_MEMORY_LEAK
import org.apache.spark.sql.catalyst.expressions.CodegenObjectFactoryMode
import org.apache.spark.sql.catalyst.optimizer.ConvertToLocalRelation
import org.apache.spark.sql.internal.{SQLConf, StaticSQLConf}
import org.apache.spark.sql.test.TestSparkSession

/**
* Combines [[QueryTest]] with [[SparkSessionBinderBase]] and places the thread audit calls
* around the base trait's session setup and teardown: `doThreadPreAudit()` is called before
* `super.beforeAll()` and `doThreadPostAudit()` after `super.afterAll()`.
*/
trait SparkSessionBinder extends QueryTest with SparkSessionBinderBase {

/**
* Suites extending this trait are sharing resources (e.g. SparkSession) in their
* tests. This trait initializes the spark session in its [[beforeAll()]] implementation before
* the automatic thread snapshot is performed, so the audit code could fail to report threads
* leaked by that shared session.
*
* The behavior is overridden here to take the snapshot before the spark session is initialized.
*/
override protected val enableAutoThreadAudit = false

protected override def beforeAll(): Unit = {
// Run the pre-audit first, before super.beforeAll() initializes the session.
doThreadPreAudit()
super.beforeAll()
}

protected override def afterAll(): Unit = {
try {
super.afterAll()
} finally {
// Run the post-audit even if the teardown above throws.
doThreadPostAudit()
}
}
}

/**
* Manages a [[classic.SparkSession]] for the whole suite.
*
* The session is created in `beforeAll` (through `initializeSession`) and stopped in
* `afterAll`. Around each test, `beforeEach` resets the open-stream tracking of
* [[DebugFilesystem]] and `afterEach` clears the cache and asserts that no streams remain open.
*/
trait SparkSessionBinderBase
extends QueryTestBase
with SparkSessionProvider
with BeforeAndAfterEach
with Eventually { self: Suite =>

/**
* Builds the [[SparkConf]] passed to the session created by `createSparkSession`.
*
* The warehouse path gets the suite's canonical class name appended. The two thread
* thresholds are read from the environment variables named below and fall back to the
* configs' default values.
*/
protected def sparkConf = {
val conf = new SparkConf()
.set("spark.hadoop.fs.file.impl", classOf[DebugFilesystem].getName)
.set(UNSAFE_EXCEPTION_ON_MEMORY_LEAK, true)
.set(SQLConf.CODEGEN_FALLBACK.key, "false")
.set(SQLConf.CODEGEN_FACTORY_MODE.key, CodegenObjectFactoryMode.CODEGEN_ONLY.toString)
// Disable ConvertToLocalRelation for better test coverage. Test cases built on
// LocalRelation will exercise the optimization rules better by disabling it as
// this rule may potentially block testing of other optimization rules such as
// ConstantPropagation etc.
.set(SQLConf.OPTIMIZER_EXCLUDED_RULES.key, ConvertToLocalRelation.ruleName)
conf.set(
StaticSQLConf.WAREHOUSE_PATH,
conf.get(StaticSQLConf.WAREHOUSE_PATH) + "/" + getClass.getCanonicalName)
conf.set(StaticSQLConf.LOAD_SESSION_EXTENSIONS_FROM_CLASSPATH, false)
conf.set(StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD,
sys.env.getOrElse("SPARK_TEST_SQL_SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD",
StaticSQLConf.SHUFFLE_EXCHANGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt)
conf.set(StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD,
sys.env.getOrElse("SPARK_TEST_SQL_RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD",
StaticSQLConf.RESULT_QUERY_STAGE_MAX_THREAD_THRESHOLD.defaultValueString).toInt)
}

/**
* The [[TestSparkSession]] to use for all tests in this suite.
*
* By default, the underlying [[org.apache.spark.SparkContext]] will be run in local
* mode with the default test configurations.
*/
private var _spark: classic.SparkSession = null

/** Returns the session held in `_spark`; this is null until `initializeSession` has run. */
protected override def spark: SparkSession = _spark

/**
* The [[SQLContext]] of the session held in `_spark`.
*/
protected implicit def sqlContext: SQLContext = _spark.sqlContext

/**
* Calls `classic.SparkSession.cleanupAnyExistingSession()` and then creates a new
* [[TestSparkSession]] from `sparkConf`.
*/
protected def createSparkSession: classic.SparkSession = {
classic.SparkSession.cleanupAnyExistingSession()
new TestSparkSession(sparkConf)
}

/** The [[SQLConf]] taken from the session state of the session held in `_spark`. */
protected def sqlConf: SQLConf = _spark.sessionState.conf

/**
* Initialize the [[TestSparkSession]]. Generally, this is just called from
* beforeAll; however, in test using styles other than FunSuite, there is
* often code that relies on the session between test group constructs and
* the actual tests, which may need this session. It is purely a semantic
* difference, but semantically, it makes more sense to call
* 'initializeSession' between a 'describe' and an 'it' call than it does to
* call 'beforeAll'.
*/
protected def initializeSession(): Unit = {
// Only create a session when none is held yet, so repeated calls are harmless.
if (_spark == null) {
_spark = createSparkSession
}
}

/**
* Make sure the [[TestSparkSession]] is initialized before any tests are run.
*/
protected override def beforeAll(): Unit = {
initializeSession()

// Ensure we have initialized the context before calling parent code
super.beforeAll()
}

/**
* Stop the underlying [[org.apache.spark.SparkContext]], if any.
*
* The nested try/finally blocks make each step run even if the previous one throws:
* the parent's teardown, the catalog reset, stopping the session, and finally clearing the
* active and default sessions.
*/
protected override def afterAll(): Unit = {
try {
super.afterAll()
} finally {
try {
if (_spark != null) {
try {
_spark.sessionState.catalog.reset()
} finally {
_spark.stop()
_spark = null
}
}
} finally {
SparkSession.clearActiveSession()
SparkSession.clearDefaultSession()
}
}
}

/** Runs the parent's per-test setup, then clears the open streams tracked by [[DebugFilesystem]]. */
protected override def beforeEach(): Unit = {
super.beforeEach()
DebugFilesystem.clearOpenStreams()
}

/**
* Runs the parent's per-test teardown, clears the cache of the session held in `_spark`,
* and then retries for up to 10 seconds until [[DebugFilesystem]] reports no open streams.
*/
protected override def afterEach(): Unit = {
super.afterEach()
// Clear all persistent datasets after each test
_spark.sharedState.cacheManager.clearCache()
// files can be closed from other threads, so wait a bit
// normally this doesn't take more than 1s
eventually(timeout(10.seconds), interval(2.seconds)) {
DebugFilesystem.assertNoOpenStreams()
}
}

}
Loading