@@ -24,7 +24,7 @@ import javax.annotation.concurrent.GuardedBy
import scala.collection.mutable.ListBuffer

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{Encoder, SQLContext}
import org.apache.spark.sql.{Encoder, SparkSession, SQLContext}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.encoders.{encoderFor, ExpressionEncoder}
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
@@ -43,32 +43,53 @@ import org.apache.spark.sql.internal.connector.SimpleTableProvider
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap

object MemoryStream {
object MemoryStream extends LowPriorityMemoryStreamImplicits {
protected val currentBlockId = new AtomicInteger(0)
protected val memoryStreamId = new AtomicInteger(0)

def apply[A : Encoder](implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
def apply[A : Encoder](implicit sparkSession: SparkSession): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)

def apply[A : Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, Some(numPartitions))
def apply[A : Encoder](numPartitions: Int)(implicit sparkSession: SparkSession): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, Some(numPartitions))
}

/**
* Provides lower-priority implicits for MemoryStream to prevent ambiguity when both
* SparkSession and SQLContext are in scope. The implicits in the companion object,
* which use SparkSession, take higher precedence.
*/
trait LowPriorityMemoryStreamImplicits {
this: MemoryStream.type =>

// Deprecated: Used when an implicit SQLContext is in scope
@deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0")
def apply[A: Encoder]()(implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext.sparkSession)

@deprecated("Use MemoryStream.apply with an implicit SparkSession instead of SQLContext", "4.1.0")
def apply[A: Encoder](numPartitions: Int)(implicit sqlContext: SQLContext): MemoryStream[A] =
new MemoryStream[A](
memoryStreamId.getAndIncrement(),
sqlContext.sparkSession,
Some(numPartitions))
}
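
For reviewers, a minimal sketch of how the overloads above are expected to resolve once this change lands. It assumes a spark-shell built from this branch (so `spark` already exists) and the current import path of MemoryStream in sql/core; names and paths may differ in your build.

```scala
// Resolution sketch; the import path is an assumption and may move.
import org.apache.spark.sql.{SparkSession, SQLContext}
import org.apache.spark.sql.execution.streaming.MemoryStream
import spark.implicits._   // provides Encoder[Int]

// Both kinds of implicits in scope, as in many existing test suites.
implicit val session: SparkSession = spark
implicit val sqlCtx: SQLContext = spark.sqlContext

// The companion-object overloads win: the SQLContext variants sit in the
// inherited, lower-priority trait, so these calls are not ambiguous.
val byImplicitSession = MemoryStream[Int]
val withPartitions    = MemoryStream[Int](3)

// An explicit empty argument list selects the deprecated SQLContext overload
// declared in LowPriorityMemoryStreamImplicits (emits a deprecation warning).
val legacy = MemoryStream[Int]()
```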

/**
* A base class for memory stream implementations. Supports adding data and resetting.
*/
abstract class MemoryStreamBase[A : Encoder](sqlContext: SQLContext) extends SparkDataStream {
abstract class MemoryStreamBase[A : Encoder](sparkSession: SparkSession) extends SparkDataStream {
val encoder = encoderFor[A]
protected val attributes = toAttributes(encoder.schema)

protected lazy val toRow: ExpressionEncoder.Serializer[A] = encoder.createSerializer()

def toDS(): Dataset[A] = {
Dataset[A](sqlContext.sparkSession, logicalPlan)
Dataset[A](sparkSession, logicalPlan)
}

def toDF(): DataFrame = {
Dataset.ofRows(sqlContext.sparkSession, logicalPlan)
Dataset.ofRows(sparkSession, logicalPlan)
}

def addData(data: A*): OffsetV2 = {
@@ -156,9 +177,9 @@ class MemoryStreamScanBuilder(stream: MemoryStreamBase[_]) extends ScanBuilder w
*/
case class MemoryStream[A : Encoder](
id: Int,
sqlContext: SQLContext,
sparkSession: SparkSession,
numPartitions: Option[Int] = None)
extends MemoryStreamBase[A](sqlContext)
extends MemoryStreamBase[A](sparkSession)
with MicroBatchStream
with SupportsTriggerAvailableNow
with Logging {
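
To illustrate the end-to-end pattern the SparkSession-based constructor serves, here is a hedged, self-contained sketch; the sink name and query shape are illustrative, not taken from this PR.

```scala
// Sketch: drive a MemoryStream through one micro-batch and read the result
// back from an in-memory sink. Import path of MemoryStream is an assumption.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.MemoryStream

val spark = SparkSession.builder().master("local[2]").appName("memstream-sketch").getOrCreate()
import spark.implicits._
implicit val session: SparkSession = spark

val input = MemoryStream[Int]            // SparkSession overload from this PR
input.addData(1, 2, 3)

val query = input.toDS()
  .map(_ * 2)
  .writeStream
  .format("memory")
  .queryName("doubled")                  // illustrative sink table name
  .outputMode("append")
  .start()

query.processAllAvailable()
spark.table("doubled").show()            // expect 2, 4, 6
query.stop()
spark.stop()
```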
@@ -27,7 +27,7 @@ import org.json4s.jackson.Serialization

import org.apache.spark.{SparkEnv, TaskContext}
import org.apache.spark.rpc.RpcEndpointRef
import org.apache.spark.sql.{Encoder, SQLContext}
import org.apache.spark.sql.{Encoder, SparkSession}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.connector.read.InputPartition
@@ -44,8 +44,11 @@ import org.apache.spark.util.RpcUtils
* ContinuousMemoryStreamInputPartitionReader instances to poll. It returns the record at
* the specified offset within the list, or null if that offset doesn't yet have a record.
*/
class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPartitions: Int = 2)
extends MemoryStreamBase[A](sqlContext) with ContinuousStream {
class ContinuousMemoryStream[A : Encoder](
id: Int,
sparkSession: SparkSession,
numPartitions: Int = 2)
extends MemoryStreamBase[A](sparkSession) with ContinuousStream {

private implicit val formats: Formats = Serialization.formats(NoTypeHints)

@@ -112,11 +115,11 @@ class ContinuousMemoryStream[A : Encoder](id: Int, sqlContext: SQLContext, numPa
object ContinuousMemoryStream {
[Contributor review comment] shall we do the same low priority implicit trick here?

protected val memoryStreamId = new AtomicInteger(0)

def apply[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext)
def apply[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession)

def singlePartition[A : Encoder](implicit sqlContext: SQLContext): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sqlContext, 1)
def singlePartition[A : Encoder](implicit sparkSession: SparkSession): ContinuousMemoryStream[A] =
new ContinuousMemoryStream[A](memoryStreamId.getAndIncrement(), sparkSession, 1)
}
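
Likewise, a hedged sketch of how the continuous variant above is typically driven; the trigger interval, sink, and checkpoint path are placeholders, and the import path is an assumption.

```scala
// Sketch: a ContinuousMemoryStream written out with a continuous-processing
// trigger and a console sink.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.streaming.sources.ContinuousMemoryStream
import org.apache.spark.sql.streaming.Trigger

val spark = SparkSession.builder().master("local[2]").appName("continuous-sketch").getOrCreate()
import spark.implicits._
implicit val session: SparkSession = spark

val input = ContinuousMemoryStream[String]   // SparkSession overload from this PR
input.addData("a", "b", "c")

val query = input.toDF()
  .writeStream
  .format("console")
  .option("checkpointLocation", "/tmp/continuous-memory-ckpt")  // placeholder path
  .trigger(Trigger.Continuous("1 second"))
  .start()

Thread.sleep(2000)   // let at least one epoch run
query.stop()
spark.stop()
```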

/**
@@ -1012,7 +1012,7 @@ class DatasetSuite extends QueryTest
assert(err.getMessage.contains("An Observation can be used with a Dataset only once"))

// streaming datasets are not supported
val streamDf = new MemoryStream[Int](0, sqlContext).toDF()
val streamDf = new MemoryStream[Int](0, spark).toDF()
val streamObservation = Observation("stream")
val streamErr = intercept[IllegalArgumentException] {
streamDf.observe(streamObservation, avg($"value").cast("int").as("avg_val"))
@@ -67,9 +67,9 @@ class AsyncProgressTrackingMicroBatchExecutionSuite

class MemoryStreamCapture[A: Encoder](
id: Int,
sqlContext: SQLContext,
sparkSession: SparkSession,
numPartitions: Option[Int] = None)
extends MemoryStream[A](id, sqlContext, numPartitions = numPartitions) {
extends MemoryStream[A](id, sparkSession, numPartitions = numPartitions) {

val commits = new ListBuffer[streaming.Offset]()
val commitThreads = new ListBuffer[Thread]()
@@ -136,7 +136,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
test("async WAL commits recovery") {
val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath

val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDF()

var index = 0
@@ -204,7 +204,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("async WAL commits turn on and off") {
val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDS()

val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -308,7 +308,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("Fail with once trigger") {
val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDF()

val e = intercept[IllegalArgumentException] {
@@ -323,7 +323,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite

test("Fail with available now trigger") {

val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDF()

val e = intercept[IllegalArgumentException] {
@@ -339,7 +339,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
test("switching between async wal commit enabled and trigger once") {
val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath

val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDF()

var index = 0
@@ -500,7 +500,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
test("switching between async wal commit enabled and available now") {
val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath

val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDF()

var index = 0
@@ -669,7 +669,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

def testAsyncWriteErrorsAlreadyExists(path: String): Unit = {
val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDS()
val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath

@@ -720,7 +720,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

def testAsyncWriteErrorsPermissionsIssue(path: String): Unit = {
val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDS()
val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
val commitDir = new File(checkpointLocation + path)
@@ -778,7 +778,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite

val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath

val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)

val ds = inputData.toDF()

@@ -852,7 +852,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("interval commits and recovery") {
val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)
val ds = inputData.toDS()

val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -934,7 +934,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("recovery when first offset is not zero and not commit log entries") {
val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)
val ds = inputData.toDS()

val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -961,7 +961,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
/**
* start new stream
*/
val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData2 = new MemoryStreamCapture[Int](id = 0, spark)
val ds2 = inputData2.toDS()
testStream(ds2, extraOptions = Map(
ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
@@ -995,7 +995,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("recovery non-contiguous log") {
val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)
val ds = inputData.toDS()

val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -1088,7 +1088,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("Fail on pipelines using unsupported sinks") {
val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDF()

val e = intercept[IllegalArgumentException] {
@@ -1109,7 +1109,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite

withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "false") {
withTempDir { checkpointLocation =>
val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)
val ds = inputData.toDS()

val clock = new StreamManualClock
@@ -1243,7 +1243,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
test("with async log purging") {
withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
withTempDir { checkpointLocation =>
val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)
val ds = inputData.toDS()

val clock = new StreamManualClock
@@ -1381,7 +1381,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("test multiple gaps in offset and commit logs") {
val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)
val ds = inputData.toDS()

val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -1427,7 +1427,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
/**
* start new stream
*/
val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData2 = new MemoryStreamCapture[Int](id = 0, spark)
val ds2 = inputData2.toDS()
testStream(ds2, extraOptions = Map(
ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
@@ -1460,7 +1460,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
}

test("recovery when gaps exist in offset and commit log") {
val inputData = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStreamCapture[Int](id = 0, spark)
val ds = inputData.toDS()

val checkpointLocation = Utils.createTempDir(namePrefix = "streaming.metadata").getCanonicalPath
@@ -1494,7 +1494,7 @@ class AsyncProgressTrackingMicroBatchExecutionSuite
/**
* start new stream
*/
val inputData2 = new MemoryStreamCapture[Int](id = 0, sqlContext = sqlContext)
val inputData2 = new MemoryStreamCapture[Int](id = 0, spark)
val ds2 = inputData2.toDS()
testStream(ds2, extraOptions = Map(
ASYNC_PROGRESS_TRACKING_ENABLED -> "true",
@@ -343,6 +343,28 @@ class MemorySinkSuite extends StreamTest with BeforeAndAfter {
intsToDF(expected)(schema))
}

test("LowPriorityMemoryStreamImplicits works with implicit sqlContext") {
// Test that MemoryStream can be created using implicit sqlContext
implicit val sqlContext: SQLContext = spark.sqlContext

// Test MemoryStream[A]() with implicit sqlContext
val stream1 = MemoryStream[Int]()
assert(stream1 != null)

// Test MemoryStream[A](numPartitions) with implicit sqlContext
val stream2 = MemoryStream[String](3)
assert(stream2 != null)

// Verify the streams work correctly
stream1.addData(1, 2, 3)
val df1 = stream1.toDF()
assert(df1.schema.fieldNames.contains("value"))

stream2.addData("a", "b", "c")
val df2 = stream2.toDF()
assert(df2.schema.fieldNames.contains("value"))
}

private implicit def intsToDF(seq: Seq[Int])(implicit schema: StructType): DataFrame = {
require(schema.fields.length === 1)
sqlContext.createDataset(seq).toDF(schema.fieldNames.head)
@@ -54,7 +54,7 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Match
test("async log purging") {
withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
withTempDir { checkpointLocation =>
val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDS()
testStream(ds)(
StartStream(checkpointLocation = checkpointLocation.getCanonicalPath),
@@ -99,7 +99,7 @@ class MicroBatchExecutionSuite extends StreamTest with BeforeAndAfter with Match
test("error notifier test") {
withSQLConf(SQLConf.MIN_BATCHES_TO_RETAIN.key -> "2", SQLConf.ASYNC_LOG_PURGE.key -> "true") {
withTempDir { checkpointLocation =>
val inputData = new MemoryStream[Int](id = 0, sqlContext = sqlContext)
val inputData = new MemoryStream[Int](id = 0, spark)
val ds = inputData.toDS()
val e = intercept[StreamingQueryException] {
