diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index fc895d60fad9f..5f6c1a3949eab 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -678,6 +678,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
      * Construct [[Aggregate]] operator from Cube/Rollup/GroupingSets.
      */
     private def constructAggregate(
+        operator: LogicalPlan,
         selectedGroupByExprs: Seq[Seq[Expression]],
         groupByExprs: Seq[Expression],
         aggregationExprs: Seq[NamedExpression],
@@ -687,6 +688,10 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         throw QueryCompilationErrors.groupingSizeTooLargeError(GroupingID.dataType.defaultSize * 8)
       }
 
+      if (groupByExprs.exists(_.exists(_.isInstanceOf[Generator]))) {
+        throw QueryCompilationErrors.generatorOutsideSelectError(operator)
+      }
+
       // Expand works by setting grouping expressions to null as determined by the
       // `selectedGroupByExprs`. To prevent these null values from being used in an aggregate
       // instead of the original value we need to create new aliases for all group by expressions
@@ -748,7 +753,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
         ResolveAggregateFunctions.resolveExprsWithAggregate(Seq(cond), aggForResolving)
 
       // Push the aggregate expressions into the aggregate (if any).
-      val newChild = constructAggregate(selectedGroupByExprs, groupByExprs,
+      val newChild = constructAggregate(h, selectedGroupByExprs, groupByExprs,
         aggregate.aggregateExpressions ++ extraAggExprs, aggregate.child)
 
       // Since the output exprId will be changed in the constructed aggregate, here we build an
@@ -783,9 +788,9 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
       case a if !a.childrenResolved => a
 
       // Ensure group by expressions and aggregate expressions have been resolved.
-      case Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, child, _)
+      case a @ Aggregate(GroupingAnalytics(selectedGroupByExprs, groupByExprs), aggExprs, child, _)
           if aggExprs.forall(_.resolved) =>
-        constructAggregate(selectedGroupByExprs, groupByExprs, aggExprs, child)
+        constructAggregate(a, selectedGroupByExprs, groupByExprs, aggExprs, child)
 
       // We should make sure all expressions in condition have been resolved.
       case f @ Filter(cond, child) if hasGroupingFunction(cond) && cond.resolved =>
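The new guard relies on Catalyst's tree-recursive `exists` (inherited from `TreeNode`), so a generator is rejected even when it is nested inside a larger grouping expression; the outer `exists` is the ordinary `Seq.exists` over the grouping expressions. A minimal sketch of that shape, using a toy expression tree rather than Spark's classes (`Expr`, `Attr`, `Call`, `Generate`, and `GuardSketch` are illustrative names, not Catalyst's):

```scala
// Toy stand-ins for Catalyst's Expression/Generator hierarchy (illustrative only).
sealed trait Expr {
  def children: Seq[Expr]
  // Mirrors TreeNode.exists: true if `p` holds for this node or any descendant.
  def exists(p: Expr => Boolean): Boolean = p(this) || children.exists(_.exists(p))
}
final case class Attr(name: String) extends Expr { val children: Seq[Expr] = Nil }
final case class Call(fn: String, children: Seq[Expr]) extends Expr
final case class Generate(child: Expr) extends Expr { val children: Seq[Expr] = Seq(child) }

object GuardSketch extends App {
  // GROUP BY GROUPING SETS (f(explode(xs))): a generator nested inside a grouping expression.
  val groupByExprs: Seq[Expr] = Seq(Call("f", Seq(Generate(Attr("xs")))))
  // Same shape as the new check: Seq.exists on the outside, tree-recursive exists inside.
  val hasGenerator = groupByExprs.exists(_.exists(_.isInstanceOf[Generate]))
  assert(hasGenerator) // here the analyzer would throw generatorOutsideSelectError(operator)
}
```

Threading the `operator` plan into `constructAggregate` lets the error message quote the offending `Aggregate` node, which is what the test expectations below match against.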
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
index 4121ba67bd3bf..138a29c6ae804 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/SQLQuerySuite.scala
@@ -4934,6 +4934,34 @@ class SQLQuerySuite extends QueryTest with SharedSparkSession with AdaptiveSpark
 
     checkAnswer(df, Row("a"))
   }
+
+  test("SPARK-51901: Disallow generator functions in grouping sets") {
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("select * group by grouping sets (inline(array(struct('col'))))")
+      },
+      condition = "UNSUPPORTED_GENERATOR.OUTSIDE_SELECT",
+      parameters = Map(
+        "plan" -> "'Aggregate [groupingsets(Vector(0), inline(array(struct(col1, col))))]"
+      )
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("select * group by grouping sets (explode(array('col')))")
+      },
+      condition = "UNSUPPORTED_GENERATOR.OUTSIDE_SELECT",
+      parameters = Map("plan" -> "'Aggregate [groupingsets(Vector(0), explode(array(col)))]")
+    )
+
+    checkError(
+      exception = intercept[AnalysisException] {
+        sql("select * group by grouping sets (posexplode(array('col')))")
+      },
+      condition = "UNSUPPORTED_GENERATOR.OUTSIDE_SELECT",
+      parameters = Map("plan" -> "'Aggregate [groupingsets(Vector(0), posexplode(array(col)))]")
+    )
+  }
 }
 
 case class Foo(bar: Option[String])
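For context, a rough end-to-end sketch of how the new check surfaces to a user. This is not part of the patch; it assumes a build containing the fix on the classpath, and the object name and local-mode settings are illustrative:

```scala
import org.apache.spark.sql.{AnalysisException, SparkSession}

object GroupingSetsGeneratorCheck {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[1]")
      .appName("SPARK-51901 sketch")
      .getOrCreate()
    try {
      // With the patch applied, analysis fails fast with
      // UNSUPPORTED_GENERATOR.OUTSIDE_SELECT, matching the tests above.
      spark.sql("select * group by grouping sets (explode(array('col')))").collect()
    } catch {
      case e: AnalysisException => println(e.getMessage)
    } finally {
      spark.stop()
    }
  }
}
```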