Skip to content

Commit f0fb8bf

Browse files
utkarsh39cloud-fan
authored andcommitted
[SPARK-50749][SQL] Fix ordering bug in CommutativeExpression.gatherCommutative method
### What changes were proposed in this pull request? [SPARK-49977](https://issues.apache.org/jira/browse/SPARK-49977) introduced a bug in the `CommutativeExpression.gatherCommutative()` method, changing the function's output order. Consider the following concrete example: ``` val addExpression = Add(   Literal(1),   Add(     Literal(2),     Literal(3)   ) ) val commutativeExpressions = addExpression.gatherCommutative(addExpression,   { case Add(l, r, _) => Seq(l, r)}) ``` Consider the output of the `gatherCommutative` method. [SPARK-49977](https://issues.apache.org/jira/browse/SPARK-49977) introduced a bug that reversed the output order. This PR fixes the bug in `gatherCommutative()` to restore the original correct ordered output. ``` // Prior to [SPARK-49977](https://issues.apache.org/jira/browse/SPARK-49977) and after this fix // commutativeExpressions -> Seq(Literal(1), Literal(2), Literal(3))) // Post [SPARK-49977](https://issues.apache.org/jira/browse/SPARK-49977) and before this fix // commutativeExpressions -> Seq(Literal(3), Literal(2), Literal(1))) ``` ### Why are the changes needed? Fixing a bug ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Added a test ### Was this patch authored or co-authored using generative AI tooling? Generated-by: ChatGPT Closes #49392 from utkarsh39/SPARK-50749. Authored-by: utkarsh39 <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent ee868b9 commit f0fb8bf

File tree

2 files changed

+24
-6
lines changed

2 files changed

+24
-6
lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Expression.scala

Lines changed: 11 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1354,19 +1354,24 @@ trait UserDefinedExpression {
13541354
}
13551355

13561356
trait CommutativeExpression extends Expression {
1357-
/** Collects adjacent commutative operations. */
1358-
private def gatherCommutative(
1357+
/**
1358+
* Collects adjacent commutative operations.
1359+
*
1360+
* Exposed for testing
1361+
*/
1362+
private[spark] def gatherCommutative(
13591363
e: Expression,
13601364
f: PartialFunction[CommutativeExpression, Seq[Expression]]): Seq[Expression] = {
13611365
val resultBuffer = scala.collection.mutable.Buffer[Expression]()
1362-
val stack = scala.collection.mutable.Stack[Expression](e)
1366+
val queue = scala.collection.mutable.Queue[Expression](e)
13631367

13641368
// [SPARK-49977]: Use iterative approach to avoid creating many temporary List objects
13651369
// for deep expression trees through recursion.
1366-
while (stack.nonEmpty) {
1367-
stack.pop() match {
1370+
while (queue.nonEmpty) {
1371+
val current = queue.dequeue()
1372+
current match {
13681373
case c: CommutativeExpression if f.isDefinedAt(c) =>
1369-
stack.pushAll(f(c))
1374+
queue ++= f(c)
13701375
case other =>
13711376
resultBuffer += other.canonicalized
13721377
}

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CanonicalizeSuite.scala

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -479,4 +479,17 @@ class CanonicalizeSuite extends SparkFunSuite {
479479
}
480480
}
481481
}
482+
483+
test("unit test for gatherCommutative()") {
484+
val addExpression = Add(
485+
Literal(1),
486+
Add(
487+
Literal(2),
488+
Literal(3)
489+
)
490+
)
491+
val commutativeExpressions = addExpression.gatherCommutative(addExpression,
492+
{ case Add(l, r, _) => Seq(l, r)})
493+
assert(commutativeExpressions == Seq(Literal(1), Literal(2), Literal(3)))
494+
}
482495
}

0 commit comments

Comments
 (0)