From ae36d8712ba11510467513598c77dc3c77caf35d Mon Sep 17 00:00:00 2001 From: Ganesha S Date: Sun, 21 Sep 2025 13:05:05 +0530 Subject: [PATCH] [SPARK-53656][SS] Refactor MemoryStream to use SparkSession instead of SQLContext --- .../execution/streaming/runtime/memory.scala | 43 +++++++++++++----- .../sources/ContinuousMemoryStream.scala | 17 ++++--- .../org/apache/spark/sql/DatasetSuite.scala | 2 +- ...ressTrackingMicroBatchExecutionSuite.scala | 44 +++++++++---------- .../execution/streaming/MemorySinkSuite.scala | 22 ++++++++++ .../streaming/MicroBatchExecutionSuite.scala | 4 +- .../state/StateStoreCoordinatorSuite.scala | 13 +++--- .../streaming/state/StateStoreSuite.scala | 3 +- .../AcceptsLatestSeenOffsetSuite.scala | 16 +++---- .../spark/sql/streaming/StreamSuite.scala | 2 +- .../StreamingQueryListenerSuite.scala | 12 ++--- .../StreamingQueryManagerSuite.scala | 12 ++--- .../sql/streaming/StreamingQuerySuite.scala | 2 +- .../TransformWithStateClusterSuite.scala | 4 +- .../streaming/TransformWithStateSuite.scala | 3 +- .../streaming/TriggerAvailableNowSuite.scala | 2 +- .../test/DataStreamReaderWriterSuite.scala | 4 +- .../sql/hive/execution/HiveDDLSuite.scala | 3 +- .../spark/sql/pipelines/graph/elements.scala | 2 +- .../graph/ConnectInvalidPipelineSuite.scala | 5 ++- .../graph/ConnectValidPipelineSuite.scala | 5 +++ .../graph/MaterializeTablesSuite.scala | 7 +-- .../graph/TriggeredGraphExecutionSuite.scala | 6 ++- 23 files changed, 146 insertions(+), 87 deletions(-) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala index 68eb3cc7688d2..dc7224e209da9 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/runtime/memory.scala @@ -24,7 +24,7 @@ import javax.annotation.concurrent.GuardedBy import 
scala.collection.mutable.ListBuffer import org.apache.spark.internal.Logging -import org.apache.spark.sql.{Encoder, SQLContext} +import org.apache.spark.sql.{Encoder, SparkSession, SQLContext} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder} import org.apache.spark.sql.catalyst.expressions.UnsafeRow @@ -43,32 +43,53 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap -object MemoryStream { +object MemoryStream extends LowPriorityMemoryStreamImplicits { protected val currentBlockId = new AtomicInteger(0) protected val memoryStreamId = new AtomicInteger(0) - def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] = - new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - def apply[A : Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] = - new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, Some(numPartitions)) + def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions)) +} + +/** + * Provides lower-priority `apply` overloads for MemoryStream that take an implicit SQLContext, + * avoiding ambiguity when both SparkSession and SQLContext are in scope. The overloads defined + * directly in the companion object, which take a SparkSession, take precedence. 
+ */ +trait LowPriorityMemoryStreamImplicits { + this: MemoryStream.type => + + // Deprecated: Used when an implicit SQLContext is in scope + @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0") + def apply[A: Encoder]()(implicit sqlContext: SQLContext): MemoryStream[A] = + new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession) + + @deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0") + def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] = + new MemoryStream[A]( + memoryStreamId.getAndIncrement(), + sqlContext.sparkSession, + Some(numPartitions)) } /** * A base class for memory stream implementations. Supports adding data and resetting. */ -abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream { +abstract class MemoryStreamBase[A : Encoder](sparkSession: SparkSession) extends SparkDataStream { val encoder = encoderFor[A] protected val attributes = toAttributes(encoder.schema) protected lazy val toRow: ExpressionEncoder.Serializer[A] = encoder.createSerializer() def toDS(): Dataset[A] = { - Dataset[A](sqlContext.sparkSession, logicalPlan) + Dataset[A](sparkSession, logicalPlan) } def toDF(): DataFrame = { - Dataset.ofRows(sqlContext.sparkSession, logicalPlan) + Dataset.ofRows(sparkSession, logicalPlan) } def addData(data: A*): OffsetV2 = { @@ -156,9 +177,9 @@ class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder w */ case class MemoryStream[A : Encoder]( id: Int, - sqlContext: SQLContext, + sparkSession: SparkSession, numPartitions: Option[Int] = None) - extends MemoryStreamBase[A](sqlContext) + extends MemoryStreamBase[A](sparkSession) with MicroBatchStream with SupportsTriggerAvailableNow with Logging { diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala index 03884d02faeb7..fd588235c2536 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/streaming/sources/ContinuousMemoryStream.scala @@ -27,7 +27,7 @@ import org.json4s.jackson.Serialization import org.apache.spark.{SparkEnv, TaskContext} import org.apache.spark.rpc.RpcEndpointRef -import org.apache.spark.sql.{Encoder, SQLContext} +import org.apache.spark.sql.{Encoder, SparkSession} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.UnsafeRow import org.apache.spark.sql.connector.read.InputPartition @@ -44,8 +44,11 @@ import org.apache.spark.util.RpcUtils * ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at * the specified offset within the list, or null if that offset doesn't yet have a record. 
*/ -class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2) - extends MemoryStreamBase[A](sqlContext) with ContinuousStream { +class ContinuousMemoryStream[A : Encoder]( + id: Int, + sparkSession: SparkSession, + numPartitions: Int = 2) + extends MemoryStreamBase[A](sparkSession) with ContinuousStream { private implicit val formats: Formats = Serialization.formats(NoTypeHints) @@ -112,11 +115,11 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa object ContinuousMemoryStream { protected val memoryStreamId = new AtomicInteger(0) - def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext) + def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession) - def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] = - new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1) + def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] = + new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1) } /** diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala index 653ad7bc34332..941fd22054242 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala @@ -1012,7 +1012,7 @@ class DatasetSuite extends QueryTest assert(err.getMessage.contains("An Observation can be used with a Dataset only once")) // streaming datasets are not supported - val streamDf = new MemoryStream[Int](0, sqlContext).toDF() + val streamDf = new MemoryStream[Int](0, spark).toDF() val streamObservation = Observation("stream") val 
streamErr = intercept[IllegalArgumentException] { streamDf.observe(streamObservation, avg($"value").cast("int").as("avg_val")) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala index 218b66b779463..e31e0e70cf39c 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/AsyncProgressTrackingMicroBatchExecutionSuite.scala @@ -67,9 +67,9 @@ class AsyncProgressTrackingMicroBatchExecutionSuite class MemoryStreamCapture[A: Encoder]( id: Int, - sqlContext: SQLContext, + sparkSession: SparkSession, numPartitions: Option[Int] = None) - extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) { + extends MemoryStream[A](id, sparkSession, numPartitions = numPartitions) { val commits = new ListBuffer[streaming.Offset]() val commitThreads = new ListBuffer[Thread]() @@ -136,7 +136,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite test("async WAL commits recovery") { val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDF() var index = 0 @@ -204,7 +204,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } test("async WAL commits turn on and off") { - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDS() val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -308,7 +308,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } test("Fail with once trigger") { - val inputData = new 
MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDF() val e = intercept[IllegalArgumentException] { @@ -323,7 +323,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite test("Fail with available now trigger") { - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDF() val e = intercept[IllegalArgumentException] { @@ -339,7 +339,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite test("switching between async wal commit enabled and trigger once") { val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDF() var index = 0 @@ -500,7 +500,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite test("switching between async wal commit enabled and available now") { val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDF() var index = 0 @@ -669,7 +669,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } def testAsyncWriteErrorsAlreadyExists(path: String): Unit = { - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDS() val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -720,7 +720,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } def testAsyncWriteErrorsPermissionsIssue(path: String): Unit = { - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDS() val 
checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath val commitDir = new File(checkpointLocation + path) @@ -778,7 +778,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDF() @@ -852,7 +852,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } test("interval commits and recovery") { - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDS() val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -934,7 +934,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } test("recovery when first offset is not zero and not commit log entries") { - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDS() val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -961,7 +961,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite /** * start new stream */ - val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData2 = new MemoryStreamCapture[Int](id = 0, spark) val ds2 = inputData2.toDS() testStream(ds2, extraOptions = Map( ASYNC_PROGRESS_TRACKING_ENABLED -> "true", @@ -995,7 +995,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } test("recovery non-contiguous log") { - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDS() val checkpointLocation = Utils.createTempDir(namePrefix = 
"streaming.metadata").getCanonicalPath @@ -1088,7 +1088,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } test("Fail on pipelines using unsupported sinks") { - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDF() val e = intercept[IllegalArgumentException] { @@ -1109,7 +1109,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") { withTempDir { checkpointLocation => - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDS() val clock = new StreamManualClock @@ -1243,7 +1243,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite test("with async log purging") { withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") { withTempDir { checkpointLocation => - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDS() val clock = new StreamManualClock @@ -1381,7 +1381,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite } test("test multiple gaps in offset and commit logs") { - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDS() val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -1427,7 +1427,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite /** * start new stream */ - val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData2 = new MemoryStreamCapture[Int](id = 0, spark) val ds2 = inputData2.toDS() testStream(ds2, extraOptions = Map( ASYNC_PROGRESS_TRACKING_ENABLED -> "true", @@ -1460,7 +1460,7 @@ class 
AsyncProgressTrackingMicroBatchExecutionSuite } test("recovery when gaps exist in offset and commit log") { - val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStreamCapture[Int](id = 0, spark) val ds = inputData.toDS() val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath @@ -1494,7 +1494,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite /** * start new stream */ - val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext) + val inputData2 = new MemoryStreamCapture[Int](id = 0, spark) val ds2 = inputData2.toDS() testStream(ds2, extraOptions = Map( ASYNC_PROGRESS_TRACKING_ENABLED -> "true", diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala index 4ec44eac22e36..7246053a296d6 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MemorySinkSuite.scala @@ -343,6 +343,28 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter { intsToDF(expected)(schema)) } + test("LowPriorityMemoryStreamImplicits works with implicit sqlContext") { + // Test that MemoryStream can be created using implicit sqlContext + implicit val sqlContext: SQLContext = spark.sqlContext + + // Test MemoryStream[A]() with implicit sqlContext + val stream1 = MemoryStream[Int]() + assert(stream1 != null) + + // Test MemoryStream[A](numPartitions) with implicit sqlContext + val stream2 = MemoryStream[String](3) + assert(stream2 != null) + + // Verify the streams work correctly + stream1.addData(1, 2, 3) + val df1 = stream1.toDF() + assert(df1.schema.fieldNames.contains("value")) + + stream2.addData("a", "b", "c") + val df2 = stream2.toDF() + assert(df2.schema.fieldNames.contains("value")) + } + private implicit def 
intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = { require(schema.fields.length === 1) sqlContext.createDataset(seq).toDF(schema.fieldNames.head) diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala index 3fec6e816b839..bd5dc846fd58e 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/MicroBatchExecutionSuite.scala @@ -54,7 +54,7 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Match test("async log purging") { withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") { withTempDir { checkpointLocation => - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDS() testStream(ds)( StartStream(checkpointLocation = checkpointLocation.getCanonicalPath), @@ -99,7 +99,7 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Match test("error notifier test") { withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") { withTempDir { checkpointLocation => - val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext) + val inputData = new MemoryStream[Int](id = 0, spark) val ds = inputData.toDS() val e = intercept[StreamingQueryException] { diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala index 723bb0a876234..79bcdbca9ec69 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala +++ 
b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreCoordinatorSuite.scala @@ -123,11 +123,10 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { test("query stop deactivates related store providers") { var coordRef: StateStoreCoordinatorRef = null try { - val spark = SparkSession.builder().sparkContext(sc).getOrCreate() + implicit val spark: SparkSession = SparkSession.builder().sparkContext(sc).getOrCreate() SparkSession.setActiveSession(spark) import spark.implicits._ coordRef = spark.streams.stateStoreCoordinator - implicit val sqlContext = spark.sqlContext spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") // Start a query and run a batch to load state stores @@ -254,7 +253,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sqlContext = spark.sqlContext + implicit val sparkSession: SparkSession = spark val inputData = MemoryStream[Int] val query = setUpStatefulQuery(inputData, "query") // Add, commit, and wait multiple times to force snapshot versions and time difference @@ -290,7 +289,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sqlContext = spark.sqlContext + implicit val sparkSession: SparkSession = spark // Start a join query and run some data to force snapshot uploads val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] @@ -333,7 +332,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sqlContext = spark.sqlContext + implicit val sparkSession: SparkSession = spark // Start and run two queries together with some data to force snapshot uploads val input1 = MemoryStream[Int] val input2 = MemoryStream[Int] @@ -400,7 +399,7 @@ class StateStoreCoordinatorSuite extends 
SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sqlContext = spark.sqlContext + implicit val sparkSession: SparkSession = spark // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] val query = setUpStatefulQuery(inputData, "query") @@ -444,7 +443,7 @@ class StateStoreCoordinatorSuite extends SparkFunSuite with SharedSparkContext { ) { case (coordRef, spark) => import spark.implicits._ - implicit val sqlContext = spark.sqlContext + implicit val sparkSession: SparkSession = spark // Start a query and run some data to force snapshot uploads val inputData = MemoryStream[Int] val query = setUpStatefulQuery(inputData, "query") diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala index 0b1483241b922..1acf239df85b5 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/streaming/state/StateStoreSuite.scala @@ -1206,9 +1206,8 @@ class StateStoreSuite extends StateStoreSuiteBase[HDFSBackedStateStoreProvider] test("SPARK-21145: Restarted queries create new provider instances") { try { val checkpointLocation = Utils.createTempDir().getAbsoluteFile - val spark = SparkSession.builder().master("local[2]").getOrCreate() + implicit val spark: SparkSession = SparkSession.builder().master("local[2]").getOrCreate() SparkSession.setActiveSession(spark) - implicit val sqlContext = spark.sqlContext spark.conf.set(SQLConf.SHUFFLE_PARTITIONS.key, "1") import spark.implicits._ val inputData = MemoryStream[Int] diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala index 
2a4abd99f6c19..6a89d39d1e279 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/AcceptsLatestSeenOffsetSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.streaming import org.scalatest.BeforeAndAfter -import org.apache.spark.sql.{Encoder, SQLContext} +import org.apache.spark.sql.{Encoder} import org.apache.spark.sql.catalyst.plans.logical.Range import org.apache.spark.sql.classic.{DataFrame, Dataset, SparkSession} import org.apache.spark.sql.connector.read.streaming @@ -62,7 +62,7 @@ class AcceptsLatestSeenOffsetSuite extends StreamTest with BeforeAndAfter { } test("DataSource V2 source with micro-batch") { - val inputData = new TestMemoryStream[Long](0, spark.sqlContext) + val inputData = new TestMemoryStream[Long](0, spark) val df = inputData.toDF().select("value") /** Add data to this test source by incrementing its available offset */ @@ -110,7 +110,7 @@ class AcceptsLatestSeenOffsetSuite extends StreamTest with BeforeAndAfter { // Test case: when the query is restarted, we expect the execution to call `latestSeenOffset` // first. Later as part of the execution, execution may call `initialOffset` if the previous // run of the query had no committed batches. 
- val inputData = new TestMemoryStream[Long](0, spark.sqlContext) + val inputData = new TestMemoryStream[Long](0, spark) val df = inputData.toDF().select("value") /** Add data to this test source by incrementing its available offset */ @@ -152,7 +152,7 @@ class AcceptsLatestSeenOffsetSuite extends StreamTest with BeforeAndAfter { } test("DataSource V2 source with continuous mode") { - val inputData = new TestContinuousMemoryStream[Long](0, spark.sqlContext, 1) + val inputData = new TestContinuousMemoryStream[Long](0, spark, 1) val df = inputData.toDF().select("value") /** Add data to this test source by incrementing its available offset */ @@ -233,9 +233,9 @@ class AcceptsLatestSeenOffsetSuite extends StreamTest with BeforeAndAfter { class TestMemoryStream[A : Encoder]( _id: Int, - _sqlContext: SQLContext, + _sparkSession: SparkSession, _numPartitions: Option[Int] = None) - extends MemoryStream[A](_id, _sqlContext, _numPartitions) + extends MemoryStream[A](_id, _sparkSession, _numPartitions) with AcceptsLatestSeenOffset { @volatile var latestSeenOffset: streaming.Offset = null @@ -260,9 +260,9 @@ class AcceptsLatestSeenOffsetSuite extends StreamTest with BeforeAndAfter { class TestContinuousMemoryStream[A : Encoder]( _id: Int, - _sqlContext: SQLContext, + _sparkSession: SparkSession, _numPartitions: Int = 2) - extends ContinuousMemoryStream[A](_id, _sqlContext, _numPartitions) + extends ContinuousMemoryStream[A](_id, _sparkSession, _numPartitions) with AcceptsLatestSeenOffset { @volatile var latestSeenOffset: streaming.Offset = _ diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala index cbb2eba7ecc89..2ae0de640aaf0 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamSuite.scala @@ -116,7 +116,7 @@ class StreamSuite extends StreamTest { val memoryStream = 
MemoryStream[Int] val executionRelation = StreamingExecutionRelation( memoryStream, toAttributes(memoryStream.encoder.schema), None)( - memoryStream.sqlContext.sparkSession) + memoryStream.sparkSession) assert(executionRelation.computeStats().sizeInBytes == spark.sessionState.conf.defaultSizeInBytes) } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala index e1d44efc172ea..4eabc82281e14 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryListenerSuite.scala @@ -63,7 +63,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { private def testSingleListenerBasic(listener: EventCollector): Unit = { val clock = new StreamManualClock - val inputData = new MemoryStream[Int](0, sqlContext) + val inputData = new MemoryStream[Int](0, spark) val df = inputData.toDS().as[Long].map { 10 / _ } case class AssertStreamExecThreadToWaitForClock() @@ -333,7 +333,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { spark.streams.addListener(listener) try { var numTriggers = 0 - val input = new MemoryStream[Int](0, sqlContext) { + val input = new MemoryStream[Int](0, spark) { override def latestOffset(startOffset: OffsetV2, limit: ReadLimit): OffsetV2 = { numTriggers += 1 super.latestOffset(startOffset, limit) @@ -375,7 +375,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { collector.reset() session.sparkContext.addJobTag(jobTag1) session.sparkContext.addJobTag(jobTag2) - val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext) + val mem = MemoryStream[Int](implicitly[Encoder[Int]], session) testStream(mem.toDS())( AddData(mem, 1, 2, 3), CheckAnswer(1, 2, 3) @@ -400,7 +400,7 @@ class StreamingQueryListenerSuite extends 
StreamTest with BeforeAndAfter { def runQuery(session: SparkSession): Unit = { collector1.reset() collector2.reset() - val mem = MemoryStream[Int](implicitly[Encoder[Int]], session.sqlContext) + val mem = MemoryStream[Int](implicitly[Encoder[Int]], session) testStream(mem.toDS())( AddData(mem, 1, 2, 3), CheckAnswer(1, 2, 3) @@ -468,7 +468,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { test("listener propagates observable metrics") { import org.apache.spark.sql.functions._ val clock = new StreamManualClock - val inputData = new MemoryStream[Int](0, sqlContext) + val inputData = new MemoryStream[Int](0, spark) val df = inputData.toDF() .observe( name = "my_event", @@ -564,7 +564,7 @@ class StreamingQueryListenerSuite extends StreamTest with BeforeAndAfter { } try { - val input = new MemoryStream[Int](0, sqlContext) + val input = new MemoryStream[Int](0, spark) val clock = new StreamManualClock() val result = input.toDF().select("value") testStream(result)( diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala index c0a123a2895cc..e42050e088a28 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQueryManagerSuite.scala @@ -273,8 +273,8 @@ class StreamingQueryManagerSuite extends StreamTest { testQuietly("can start a streaming query with the same name in a different session") { val session2 = spark.cloneSession() - val ds1 = MemoryStream(Encoders.INT, spark.sqlContext).toDS() - val ds2 = MemoryStream(Encoders.INT, session2.sqlContext).toDS() + val ds1 = MemoryStream(Encoders.INT, spark).toDS() + val ds2 = MemoryStream(Encoders.INT, session2).toDS() val queryName = "abc" val query1 = ds1.writeStream.format("noop").queryName(queryName).start() @@ -347,8 +347,8 @@ class 
StreamingQueryManagerSuite extends StreamTest { withTempDir { dir => val session2 = spark.cloneSession() - val ms1 = MemoryStream(Encoders.INT, spark.sqlContext) - val ds2 = MemoryStream(Encoders.INT, session2.sqlContext).toDS() + val ms1 = MemoryStream(Encoders.INT, spark) + val ds2 = MemoryStream(Encoders.INT, session2).toDS() val chkLocation = new File(dir, "_checkpoint").getCanonicalPath val dataLocation = new File(dir, "data").getCanonicalPath @@ -376,8 +376,8 @@ class StreamingQueryManagerSuite extends StreamTest { withTempDir { dir => val session2 = spark.cloneSession() - val ms1 = MemoryStream(Encoders.INT, spark.sqlContext) - val ds2 = MemoryStream(Encoders.INT, session2.sqlContext).toDS() + val ms1 = MemoryStream(Encoders.INT, spark) + val ds2 = MemoryStream(Encoders.INT, session2).toDS() val chkLocation = new File(dir, "_checkpoint").getCanonicalPath val dataLocation = new File(dir, "data").getCanonicalPath diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala index 7ea53d41a150b..82c6f18955afd 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/StreamingQuerySuite.scala @@ -230,7 +230,7 @@ class StreamingQuerySuite extends StreamTest with BeforeAndAfter with Logging wi clock = new StreamManualClock /** Custom MemoryStream that waits for manual clock to reach a time */ - val inputData = new MemoryStream[Int](0, sqlContext) { + val inputData = new MemoryStream[Int](0, spark) { private def dataAdded: Boolean = currentOffset.offset != -1 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateClusterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateClusterSuite.scala index f6f3b2bd8b795..414e8e418f952 100644 --- 
a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateClusterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateClusterSuite.scala @@ -138,7 +138,7 @@ class TransformWithStateClusterSuite extends StreamTest with TransformWithStateC testWithAndWithoutImplicitEncoders("streaming with transformWithState - " + "without initial state") { (spark, useImplicits) => import spark.implicits._ - val input = MemoryStream(Encoders.STRING, spark.sqlContext) + val input = MemoryStream(Encoders.STRING, spark) val agg = input.toDS() .groupByKey(x => x) .transformWithState(new FruitCountStatefulProcessor(useImplicits), @@ -180,7 +180,7 @@ class TransformWithStateClusterSuite extends StreamTest with TransformWithStateC val fruitCountInitial = fruitCountInitialDS .groupByKey(x => x) - val input = MemoryStream(Encoders.STRING, spark.sqlContext) + val input = MemoryStream(Encoders.STRING, spark) val agg = input.toDS() .groupByKey(x => x) .transformWithState(new FruitCountStatefulProcessorWithInitialState(useImplicits), diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala index 2a1ec4c7ab611..aac8587191962 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TransformWithStateSuite.scala @@ -30,6 +30,7 @@ import org.apache.spark.{SparkException, SparkRuntimeException, SparkUnsupported import org.apache.spark.internal.Logging import org.apache.spark.sql.{Dataset, Encoders, Row} import org.apache.spark.sql.catalyst.util.stringToFile +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.execution.datasources.v2.state.StateSourceOptions import org.apache.spark.sql.execution.exchange.ShuffleExchangeExec import 
org.apache.spark.sql.execution.streaming.checkpointing.CheckpointFileManager @@ -1008,7 +1009,7 @@ abstract class TransformWithStateSuite extends StateStoreMetricsTest } test("transformWithState - lazy iterators can properly get/set keyed state") { - val spark = this.spark + implicit val spark: SparkSession = this.spark import spark.implicits._ class ProcessorWithLazyIterators diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala index 3741ee8ab1feb..7b4338dff6b24 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/TriggerAvailableNowSuite.scala @@ -98,7 +98,7 @@ class TriggerAvailableNowSuite extends FileStreamSourceTest { } class TestMicroBatchStream extends TestDataFrameProvider { - private lazy val memoryStream = MemoryStream[Long](0, spark.sqlContext) + private lazy val memoryStream = MemoryStream[Long](0, spark) override def toDF: DataFrame = memoryStream.toDF() diff --git a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala index 74db2a3843d76..2a186a9296f4a 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/streaming/test/DataStreamReaderWriterSuite.scala @@ -526,7 +526,7 @@ class DataStreamReaderWriterSuite extends StreamTest with BeforeAndAfter { } private def testMemorySinkCheckpointRecovery(chkLoc: String, provideInWriter: Boolean): Unit = { - val ms = new MemoryStream[Int](0, sqlContext) + val ms = new MemoryStream[Int](0, spark) val df = ms.toDF().toDF("a") val tableName = "test" def startQuery: StreamingQuery = { @@ -585,7 +585,7 @@ class DataStreamReaderWriterSuite 
extends StreamTest with BeforeAndAfter { test("append mode memory sink's do not support checkpoint recovery") { import testImplicits._ - val ms = new MemoryStream[Int](0, sqlContext) + val ms = new MemoryStream[Int](0, spark) val df = ms.toDF().toDF("a") val checkpointLoc = newMetadataDir val checkpointDir = new File(checkpointLoc, "offsets") diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala index 769b633a9c525..caa4ca4581b4d 100644 --- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala +++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveDDLSuite.scala @@ -37,6 +37,7 @@ import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.TableAlreadyExistsException import org.apache.spark.sql.catalyst.catalog._ import org.apache.spark.sql.catalyst.parser.{CatalystSqlParser, ParseException} +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{CatalogManager, CatalogV2Util, Identifier, TableChange, TableInfo} import org.apache.spark.sql.connector.catalog.CatalogManager.SESSION_CATALOG_NAME import org.apache.spark.sql.connector.catalog.SupportsNamespaces.PROP_OWNER @@ -2654,7 +2655,7 @@ class HiveDDLSuite import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import testImplicits._ - implicit val _sqlContext = spark.sqlContext + implicit val sparkSession: SparkSession = spark withTempView("t1") { Seq((1, "one"), (2, "two"), (4, "four")).toDF("number", "word").createOrReplaceTempView("t1") diff --git a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala index 95a57dcc4495c..bb534404f5656 100644 --- a/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala +++ 
b/sql/pipelines/src/main/scala/org/apache/spark/sql/pipelines/graph/elements.scala @@ -199,7 +199,7 @@ case class VirtualTableInput( // create empty streaming/batch df based on input type. def createEmptyDF(schema: StructType): DataFrame = readOptions match { case _: StreamingReadOptions => - MemoryStream[Row](ExpressionEncoder(schema, lenient = false), spark.sqlContext) + MemoryStream[Row](ExpressionEncoder(schema, lenient = false), spark) .toDF() case _ => spark.createDataFrame(new util.ArrayList[Row](), schema) } diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala index f37716b4a24d3..7c8181b5b72a5 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectInvalidPipelineSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.pipelines.graph -import org.apache.spark.sql.AnalysisException +import org.apache.spark.sql.{AnalysisException, SparkSession} import org.apache.spark.sql.execution.streaming.runtime.MemoryStream import org.apache.spark.sql.pipelines.utils.{PipelineTest, TestGraphRegistrationContext} import org.apache.spark.sql.test.SharedSparkSession @@ -423,6 +423,7 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val p = new TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1) registerPersistedView("a", query = dfFlowFunc(mem.toDF())) @@ -466,6 +467,7 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark registerMaterializedView("a", query = 
dfFlowFunc(MemoryStream[Int].toDF())) }.resolveToDataflowGraph() @@ -489,6 +491,7 @@ class ConnectInvalidPipelineSuite extends PipelineTest with SharedSparkSession { val graph = new TestGraphRegistrationContext(spark) { registerTable("a") + implicit val sparkSession: SparkSession = spark registerFlow( destinationName = "a", name = "once_flow", diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala index 2c0e2a728c69f..da1ef2c907295 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/ConnectValidPipelineSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark.sql.pipelines.graph +import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.catalyst.analysis.UnresolvedRelation import org.apache.spark.sql.catalyst.plans.logical.Union @@ -158,6 +159,7 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark val ints = MemoryStream[Int] ints.addData(1, 2, 3, 4) registerPersistedView("a", query = dfFlowFunc(ints.toDF())) @@ -199,6 +201,7 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark val ints1 = MemoryStream[Int] ints1.addData(1, 2, 3, 4) val ints2 = MemoryStream[Int] @@ -359,6 +362,7 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ class P extends TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark val mem = 
MemoryStream[Int] registerPersistedView("a", query = dfFlowFunc(mem.toDF())) registerTable("b") @@ -402,6 +406,7 @@ class ConnectValidPipelineSuite extends PipelineTest with SharedSparkSession { import session.implicits._ val graph = new TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark val mem = MemoryStream[Int] mem.addData(1, 2) registerPersistedView("complete-view", query = dfFlowFunc(Seq(1, 2).toDF("x"))) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala index 37c32a3498661..c030553c04078 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/MaterializeTablesSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.sql.pipelines.graph import scala.jdk.CollectionConverters._ import org.apache.spark.SparkThrowable +import org.apache.spark.sql.classic.SparkSession import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} import org.apache.spark.sql.connector.expressions.Expressions import org.apache.spark.sql.execution.streaming.runtime.MemoryStream @@ -261,8 +262,8 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { test("invalid schema merge") { val session = spark + implicit val sparkSession: SparkSession = spark import session.implicits._ - implicit def sqlContext: org.apache.spark.sql.classic.SQLContext = spark.sqlContext val streamInts = MemoryStream[Int] streamInts.addData(1, 2) @@ -330,7 +331,6 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { test("specified schema incompatible with existing table") { val session = spark import session.implicits._ - implicit def sqlContext: org.apache.spark.sql.classic.SQLContext = spark.sqlContext sql(s"CREATE TABLE 
${TestGraphRegistrationContext.DEFAULT_DATABASE}.t6(x BOOLEAN)") val catalog = spark.sessionState.catalogManager.currentCatalog.asInstanceOf[TableCatalog] @@ -344,6 +344,7 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { val ex = intercept[TableMaterializationException] { materializeGraph(new TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark val source: MemoryStream[Int] = MemoryStream[Int] source.addData(1, 2) registerTable( @@ -628,8 +629,8 @@ abstract class MaterializeTablesSuite extends BaseCoreExecutionTest { s"Streaming tables should evolve schema only if not full refresh = $isFullRefresh" ) { val session = spark + implicit val sparkSession: SparkSession = spark import session.implicits._ - implicit def sqlContext: org.apache.spark.sql.classic.SQLContext = spark.sqlContext val streamInts = MemoryStream[Int] streamInts.addData(1 until 5: _*) diff --git a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala index 4fcd9dad93fe7..0f20c0506ba1c 100644 --- a/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala +++ b/sql/pipelines/src/test/scala/org/apache/spark/sql/pipelines/graph/TriggeredGraphExecutionSuite.scala @@ -19,7 +19,7 @@ package org.apache.spark.sql.pipelines.graph import org.scalatest.time.{Seconds, Span} -import org.apache.spark.sql.{functions, Row} +import org.apache.spark.sql.{functions, Row, SparkSession} import org.apache.spark.sql.catalyst.TableIdentifier import org.apache.spark.sql.classic.{DataFrame, Dataset} import org.apache.spark.sql.connector.catalog.{CatalogV2Util, Identifier, TableCatalog} @@ -183,6 +183,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { + implicit 
val sparkSession: SparkSession = spark private val ints = MemoryStream[Int] ints.addData(1 until 10: _*) registerView("input", query = dfFlowFunc(ints.toDF())) @@ -259,6 +260,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark private val ints = MemoryStream[Int] registerView("input", query = dfFlowFunc(ints.toDF())) registerTable( @@ -309,6 +311,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession }) val pipelineDef = new TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark private val memoryStream = MemoryStream[Int] memoryStream.addData(1, 2) registerView("input_view", query = dfFlowFunc(memoryStream.toDF())) @@ -547,6 +550,7 @@ class TriggeredGraphExecutionSuite extends ExecutionTest with SharedSparkSession // Construct pipeline val pipelineDef = new TestGraphRegistrationContext(spark) { + implicit val sparkSession: SparkSession = spark private val memoryStream = MemoryStream[Int] memoryStream.addData(1, 2) registerView("input_view", query = dfFlowFunc(memoryStream.toDF()))