spark DecorrelateInnerQuery 源码

  • 2022-10-20
  • 浏览 (265)

spark DecorrelateInnerQuery 代码

文件路径:/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/DecorrelateInnerQuery.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.optimizer

import scala.collection.mutable.ArrayBuffer

import org.apache.spark.SparkException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.SubExprUtils._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.trees.TreePattern.OUTER_REFERENCE
import org.apache.spark.sql.errors.{QueryCompilationErrors, QueryExecutionErrors}
import org.apache.spark.util.collection.Utils

/**
 * Decorrelate the inner query by eliminating outer references and create domain joins.
 * The implementation is based on the paper: Unnesting Arbitrary Queries by Thomas Neumann
 * and Alfons Kemper. https://dl.gi.de/handle/20.500.12116/2418.
 *
 * A correlated subquery can be viewed as a "dependent" nested loop join between the outer and
 * the inner query. For each row produced by the outer query, we bind the [[OuterReference]]s in
 * in the inner query with the corresponding values in the row, and then evaluate the inner query.
 *
 * Dependent Join
 * :- Outer Query
 * +- Inner Query
 *
 * If the [[OuterReference]]s are bound to the same value, the inner query will return the same
 * result. Based on this, we can reduce the times to evaluate the inner query by first getting
 * all distinct values of the [[OuterReference]]s.
 *
 * Normal Join
 * :- Outer Query
 * +- Dependent Join
 *    :- Inner Query
 *    +- Distinct Aggregate (outer_ref1, outer_ref2, ...)
 *       +- Outer Query
 *
 * The distinct aggregate of the outer references is called a "domain", and the dependent join
 * between the inner query and the domain is called a "domain join". We need to push down the
 * domain join through the inner query until there is no outer reference in the sub-tree and
 * the domain join will turn into a normal join.
 *
 * The decorrelation function returns a new query plan with optional placeholder [[DomainJoins]]s
 * added and a list of join conditions with the outer query. [[DomainJoin]]s need to be rewritten
 * into actual inner join between the inner query sub-tree and the outer query.
 *
 * E.g. decorrelate an inner query with equality predicates:
 *
 * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c = t1.a) FROM t2
 *
 * Aggregate [] [min(b)]            Aggregate [a] [min(b), a]
 * +- Filter (outer(c) = a)   =>   +- Relation [t1]
 *    +- Relation [t1]
 *
 * Join conditions: [c = a]
 *
 * E.g. decorrelate an inner query with non-equality predicates:
 *
 * SELECT (SELECT MIN(b) FROM t1 WHERE t2.c > t1.a) FROM t2
 *
 * Aggregate [] [min(b)]            Aggregate [c'] [min(b), c']
 * +- Filter (outer(c) > a)   =>   +- Filter (c' > a)
 *    +- Relation [t1]                  +- DomainJoin [c']
 *                                         +- Relation [t1]
 *
 * Join conditions: [c <=> c']
 */
object DecorrelateInnerQuery extends PredicateHelper {

  /**
   * Check if an expression contains any attribute. Note OuterReference is a
   * leaf node and will not be found here.
   */
  private def containsAttribute(expression: Expression): Boolean = {
    expression.exists(_.isInstanceOf[Attribute])
  }

  /**
   * Check if an expression can be pulled up over an [[Aggregate]] without changing the
   * semantics of the plan. The expression must be an equality predicate that guarantees
   * one-to-one mapping between inner and outer attributes.
   * For example:
   *   (a = outer(c)) -> true
   *   (a > outer(c)) -> false
   *   (a + b = outer(c)) -> false
   *   (a = outer(c) - b) -> false
   */
  def canPullUpOverAgg(expression: Expression): Boolean = {
    def isSupported(e: Expression): Boolean = e match {
      case _: Attribute => true
      // Allow Cast expressions that guarantee 1:1 mapping.
      case Cast(a: Attribute, dataType, _, _) => Cast.canUpCast(a.dataType, dataType)
      case _ => false
    }

    // Only allow equality condition with one side being an attribute or an expression that
    // guarantees 1:1 mapping and another side being an expression without attributes from
    // the inner query.
    expression match {
      case Equality(a, b) if isSupported(a) => !containsAttribute(b)
      case Equality(a, b) if isSupported(b) => !containsAttribute(a)
      case o => !containsAttribute(o)
    }
  }

  /**
   * Collect outer references in an expressions that are in the output attributes of the outer plan.
   */
  private def collectOuterReferences(expression: Expression): AttributeSet = {
    AttributeSet(expression.collect { case o: OuterReference => o.toAttribute })
  }

  /**
   * Collect outer references in a sequence of expressions that are in the output attributes
   * of the outer plan.
   */
  private def collectOuterReferences(expressions: Seq[Expression]): AttributeSet = {
    AttributeSet.fromAttributeSets(expressions.map(collectOuterReferences))
  }

  /**
   * Build a mapping between outer references with equivalent inner query attributes.
   * E.g. [outer(a) = x, y = outer(b), outer(c) = z + 1] => {a -> x, b -> y}
   */
  private def collectEquivalentOuterReferences(
      expressions: Seq[Expression]): AttributeMap[Attribute] = {
    AttributeMap(expressions.collect {
      case Equality(o: OuterReference, a: Attribute) => (o.toAttribute, a.toAttribute)
      case Equality(a: Attribute, o: OuterReference) => (o.toAttribute, a.toAttribute)
    })
  }

  /**
   * Replace all outer references using the expressions in the given outer reference map.
   */
  private def replaceOuterReference[E <: Expression](
      expression: E,
      outerReferenceMap: AttributeMap[Attribute]): E = {
    expression.transformWithPruning(_.containsPattern(OUTER_REFERENCE)) {
      case o: OuterReference => outerReferenceMap.getOrElse(o.toAttribute, o)
    }.asInstanceOf[E]
  }

  /**
   * Replace all outer references in the given expressions using the expressions in the
   * outer reference map.
   */
  private def replaceOuterReferences[E <: Expression](
      expressions: Seq[E],
      outerReferenceMap: AttributeMap[Attribute]): Seq[E] = {
    expressions.map(replaceOuterReference(_, outerReferenceMap))
  }

  /**
   * Replace all outer references in the given named expressions and keep the output
   * attributes unchanged.
   */
  private def replaceOuterInNamedExpressions(
      expressions: Seq[NamedExpression],
      outerReferenceMap: AttributeMap[Attribute]): Seq[NamedExpression] = {
    expressions.map { expr =>
      val newExpr = replaceOuterReference(expr, outerReferenceMap)
      if (!newExpr.toAttribute.semanticEquals(expr.toAttribute)) {
        Alias(newExpr, expr.name)(expr.exprId)
      } else {
        newExpr
      }
    }
  }

  /**
   * Return all references that are presented in the join conditions but not in the output
   * of the given named expressions.
   */
  private def missingReferences(
      namedExpressions: Seq[NamedExpression],
      joinCond: Seq[Expression]): AttributeSet = {
    val output = namedExpressions.map(_.toAttribute)
    AttributeSet(joinCond.flatMap(_.references)) -- AttributeSet(output)
  }

  /**
   * Deduplicate the inner and the outer query attributes and return an aliased
   * subquery plan and join conditions if duplicates are found. Duplicated attributes
   * can break the structural integrity when joining the inner and outer plan together.
   */
  def deduplicate(
      innerPlan: LogicalPlan,
      conditions: Seq[Expression],
      outerOutputSet: AttributeSet): (LogicalPlan, Seq[Expression]) = {
    val duplicates = innerPlan.outputSet.intersect(outerOutputSet)
    if (duplicates.nonEmpty) {
      val aliasMap = AttributeMap(duplicates.map { dup =>
        dup -> Alias(dup, dup.toString)()
      }.toSeq)
      val aliasedExpressions = innerPlan.output.map { ref =>
        aliasMap.getOrElse(ref, ref)
      }
      val aliasedProjection = Project(aliasedExpressions, innerPlan)
      val aliasedConditions = conditions.map(_.transform {
        case ref: Attribute => aliasMap.getOrElse(ref, ref).toAttribute
      })
      (aliasedProjection, aliasedConditions)
    } else {
      (innerPlan, conditions)
    }
  }

  /**
   * Build a mapping between domain attributes and corresponding outer query expressions
   * using the join conditions.
   */
  private def buildDomainAttrMap(
      conditions: Seq[Expression],
      domainAttrs: Seq[Attribute]): Map[Attribute, Expression] = {
    val domainAttrSet = AttributeSet(domainAttrs)
    conditions.collect {
      // When we build the join conditions between the domain attributes and outer references,
      // the left hand side is always the domain attribute used in the inner query and the right
      // hand side is the attribute from the outer query. Note here the right hand side of a
      // condition is not necessarily an attribute, for example it can be a literal (if foldable)
      // or a cast expression after the optimization.
      case EqualNullSafe(left: Attribute, right: Expression) if domainAttrSet.contains(left) =>
        left -> right
    }.toMap
  }

  /**
   * Rewrite all [[DomainJoin]]s in the inner query to actual joins with the outer query.
   */
  def rewriteDomainJoins(
      outerPlan: LogicalPlan,
      innerPlan: LogicalPlan,
      conditions: Seq[Expression]): LogicalPlan = innerPlan match {
    case d @ DomainJoin(domainAttrs, child, joinType, condition) =>
      val domainAttrMap = buildDomainAttrMap(conditions, domainAttrs)

      val newChild = joinType match {
        // Left outer domain joins are used to handle the COUNT bug.
        case LeftOuter =>
          // Replace the attributes in the domain join condition with the actual outer expressions
          // and use the new join conditions to rewrite domain joins in its child. For example:
          // DomainJoin [c'] LeftOuter (a = c') with domainAttrMap: { c' -> _1 }.
          // Then the new conditions to use will be [(a = _1)].
          assert(condition.isDefined,
            s"LeftOuter domain join should always have the join condition defined:\n$d")
          val newCond = condition.get.transform {
            case a: Attribute => domainAttrMap.getOrElse(a, a)
          }
          // Recursively rewrite domain joins using the new conditions.
          rewriteDomainJoins(outerPlan, child, splitConjunctivePredicates(newCond))
        case Inner =>
          // The decorrelation framework adds domain inner joins by traversing down the plan tree
          // recursively until it reaches a node that is not correlated with the outer query.
          // So the child node of a domain inner join shouldn't contain another domain join.
          assert(!child.exists(_.isInstanceOf[DomainJoin]),
            s"Child of a domain inner join shouldn't contain another domain join.\n$child")
          child
        case o =>
          throw new IllegalStateException(s"Unexpected domain join type $o")
      }

      // We should only rewrite a domain join when all corresponding outer plan attributes
      // can be found from the join condition.
      if (domainAttrMap.size == domainAttrs.size) {
        val groupingExprs = domainAttrs.map(domainAttrMap)
        val aggregateExprs = groupingExprs.zip(domainAttrs).map {
          // Rebuild the aliases.
          case (inputAttr, outputAttr) => Alias(inputAttr, outputAttr.name)(outputAttr.exprId)
        }
        // Construct a domain with the outer query plan.
        // DomainJoin [a', b']  =>  Aggregate [a, b] [a AS a', b AS b']
        //                          +- Relation [a, b]
        val domain = Aggregate(groupingExprs, aggregateExprs, outerPlan)
        newChild match {
          // A special optimization for OneRowRelation.
          // TODO: add a more general rule to optimize join with OneRowRelation.
          case _: OneRowRelation => domain
          // Construct a domain join.
          // Join joinType condition
          // :- Domain
          // +- Inner Query
          case _ => Join(domain, newChild, joinType, condition, JoinHint.NONE)
        }
      } else {
        throw new IllegalStateException(
          s"Unable to rewrite domain join with conditions: $conditions\n$d.")
      }
    case p: LogicalPlan =>
      p.mapChildren(rewriteDomainJoins(outerPlan, _, conditions))
  }

  def apply(
      innerPlan: LogicalPlan,
      outerPlan: LogicalPlan,
      handleCountBug: Boolean = false): (LogicalPlan, Seq[Expression]) = {
    val outputPlanInputAttrs = outerPlan.inputSet

    // The return type of the recursion.
    // The first parameter is a new logical plan with correlation eliminated.
    // The second parameter is a list of join conditions with the outer query.
    // The third parameter is a mapping between the outer references and equivalent
    // expressions from the inner query that is used to replace outer references.
    type ReturnType = (LogicalPlan, Seq[Expression], AttributeMap[Attribute])

    // Decorrelate the input plan with a set of parent outer references and a boolean flag
    // indicating whether the result of the plan will be aggregated. Steps:
    // 1. Recursively collects outer references from the inner query until it reaches a node
    //    that does not contain correlated value.
    // 2. Inserts an optional [[DomainJoin]] node to indicate whether a domain (inner) join is
    //    needed between the outer query and the specific sub-tree of the inner query.
    // 3. Returns a list of join conditions with the outer query and a mapping between outer
    //    references with references inside the inner query. The parent nodes need to preserve
    //    the references inside the join conditions and substitute all outer references using
    //    the mapping.
    def decorrelate(
        plan: LogicalPlan,
        parentOuterReferences: AttributeSet,
        aggregated: Boolean = false): ReturnType = {
      val isCorrelated = hasOuterReferences(plan)
      if (!isCorrelated) {
        // We have reached a plan without correlation to the outer plan.
        if (parentOuterReferences.isEmpty) {
          // If there is no outer references from the parent nodes, it means all outer
          // attributes can be substituted by attributes from the inner plan. So no
          // domain join is needed.
          (plan, Nil, AttributeMap.empty[Attribute])
        } else {
          // Build the domain join with the parent outer references.
          val attributes = parentOuterReferences.toSeq
          val domains = attributes.map(_.newInstance())
          // A placeholder to be rewritten into domain join.
          val domainJoin = DomainJoin(domains, plan)
          val outerReferenceMap = Utils.toMap(attributes, domains)
          // Build join conditions between domain attributes and outer references.
          // EqualNullSafe is used to make sure null key can be joined together. Note
          // outer referenced attributes can be changed during the outer query optimization.
          // The equality conditions will also serve as an attribute mapping between new
          // outer references and domain attributes when rewriting the domain joins.
          // E.g. if the attribute a is changed to a1, the join condition a' <=> outer(a)
          // will become a' <=> a1, and we can construct the aliases based on the condition:
          // DomainJoin [a']        Join Inner
          // +- InnerQuery     =>   :- InnerQuery
          //                        +- Aggregate [a1] [a1 AS a']
          //                           +- OuterQuery
          val conditions = outerReferenceMap.map {
            case (o, a) =>
              val cond = EqualNullSafe(a, OuterReference(o))
              // SPARK-40615: Certain data types (e.g. MapType) do not support ordering, so
              // the EqualNullSafe join condition can become unresolved.
              if (!cond.resolved) {
                if (!RowOrdering.isOrderable(a.dataType)) {
                  throw QueryCompilationErrors.unsupportedCorrelatedReferenceDataTypeError(
                    o, a.dataType, plan.origin)
                } else {
                  throw SparkException.internalError(s"Unable to decorrelate subquery: " +
                    s"join condition '${cond.sql}' cannot be resolved.")
                }
              }
              cond
          }
          (domainJoin, conditions.toSeq, AttributeMap(outerReferenceMap))
        }
      } else {
        plan match {
          case Filter(condition, child) =>
            val conditions = splitConjunctivePredicates(condition)
            val (correlated, uncorrelated) = conditions.partition(containsOuter)
            // Find outer references that can be substituted by attributes from the inner
            // query using the equality predicates.
            val equivalences = collectEquivalentOuterReferences(correlated)
            // Correlated predicates can be removed from the Filter's condition and used as
            // join conditions with the outer query. However, if the results of the sub-tree
            // is aggregated, only certain correlated equality predicates can be used, because
            // the references in the join conditions need to be preserved in both the grouping
            // and aggregate expressions of an Aggregate, which may change the semantics of the
            // plan and lead to incorrect results. Here is an example:
            // Relations:
            //   t1(a, b): [(1, 1)]
            //   t2(c, d): [(1, 1), (2, 0)]
            //
            // Query:
            //   SELECT * FROM t1 WHERE a = (SELECT MAX(c) FROM t2 WHERE b >= d)
            //
            // Subquery plan transformation if correlated predicates are used as join conditions:
            //   Aggregate [max(c)]               Aggregate [d] [max(c), d]
            //   +- Filter (outer(b) >= d)   =>   +- Relation [c, d]
            //      +- Relation [c, d]
            //
            // Plan after rewrite:
            //   Project [a, b]                                   -- [(1, 1)]
            //   +- Join LeftOuter (b >= d AND a = max(c))
            //      :- Relation [a, b]
            //      +- Aggregate [d] [max(c), d]                  -- [(1, 1), (2, 0)]
            //         +- Relation [c, d]
            //
            // The result of the original query should be an empty set but the transformed
            // query will output an incorrect result of (1, 1). The correct transformation
            // with domain join is illustrated below:
            //   Aggregate [max(c)]               Aggregate [b'] [max(c), b']
            //   +- Filter (outer(b) >= d)   =>   +- Filter (b' >= d)
            //      +- Relation [c, d]               +- DomainJoin [b']
            //                                          +- Relation [c, d]
            // Plan after rewrite:
            //   Project [a, b]
            //   +- Join LeftOuter (b <=> b' AND a = max(c))  -- []
            //      :- Relation [a, b]
            //      +- Aggregate [b'] [max(c), b']            -- [(2, 1)]
            //         +- Join Inner (b' >= d)                -- [(1, 1, 1), (2, 0, 1)] (DomainJoin)
            //            :- Relation [c, d]
            //            +- Aggregate [b] [b AS b']          -- [(1)] (Domain)
            //               +- Relation [a, b]
            if (aggregated) {
              // Split the correlated predicates into predicates that can and cannot be directly
              // used as join conditions with the outer query depending on whether they can
              // be pulled up over an Aggregate without changing the semantics of the plan.
              val (equalityCond, predicates) = correlated.partition(canPullUpOverAgg)
              val outerReferences = collectOuterReferences(predicates)
              val newOuterReferences =
                parentOuterReferences ++ outerReferences -- equivalences.keySet
              val (newChild, joinCond, outerReferenceMap) =
                decorrelate(child, newOuterReferences, aggregated)
              // Add the outer references mapping collected from the equality conditions.
              val newOuterReferenceMap = outerReferenceMap ++ equivalences
              // Replace all outer references in the non-equality predicates.
              val newCorrelated = replaceOuterReferences(predicates, newOuterReferenceMap)
              // The new filter condition is the original filter condition with correlated
              // equality predicates removed.
              val newFilterCond = newCorrelated ++ uncorrelated
              val newFilter = newFilterCond match {
                case Nil => newChild
                case conditions => Filter(conditions.reduce(And), newChild)
              }
              // Equality predicates are used as join conditions with the outer query.
              val newJoinCond = joinCond ++ equalityCond
              (newFilter, newJoinCond, newOuterReferenceMap)
            } else {
              // Results of this sub-tree is not aggregated, so all correlated predicates
              // can be directly used as outer query join conditions.
              val newOuterReferences = parentOuterReferences -- equivalences.keySet
              val (newChild, joinCond, outerReferenceMap) =
                decorrelate(child, newOuterReferences, aggregated)
              // Add the outer references mapping collected from the equality conditions.
              val newOuterReferenceMap = outerReferenceMap ++ equivalences
              val newFilter = uncorrelated match {
                case Nil => newChild
                case conditions => Filter(conditions.reduce(And), newChild)
              }
              val newJoinCond = joinCond ++ correlated
              (newFilter, newJoinCond, newOuterReferenceMap)
            }

          case Project(projectList, child) =>
            val outerReferences = collectOuterReferences(projectList)
            val newOuterReferences = parentOuterReferences ++ outerReferences
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(child, newOuterReferences, aggregated)
            // Replace all outer references in the original project list and keep the output
            // attributes unchanged.
            val newProjectList = replaceOuterInNamedExpressions(projectList, outerReferenceMap)
            // Preserve required domain attributes in the join condition by adding the missing
            // references to the new project list.
            val referencesToAdd = missingReferences(newProjectList, joinCond)
            val newProject = Project(newProjectList ++ referencesToAdd, newChild)
            (newProject, joinCond, outerReferenceMap)

          case a @ Aggregate(groupingExpressions, aggregateExpressions, child) =>
            val outerReferences = collectOuterReferences(a.expressions)
            val newOuterReferences = parentOuterReferences ++ outerReferences
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(child, newOuterReferences, aggregated = true)
            // Replace all outer references in grouping and aggregate expressions, and keep
            // the output attributes unchanged.
            val newGroupingExpr = replaceOuterReferences(groupingExpressions, outerReferenceMap)
            val newAggExpr = replaceOuterInNamedExpressions(aggregateExpressions, outerReferenceMap)
            // Add all required domain attributes to both grouping and aggregate expressions.
            val referencesToAdd = missingReferences(newAggExpr, joinCond)
            val newAggregate = a.copy(
              groupingExpressions = newGroupingExpr ++ referencesToAdd,
              aggregateExpressions = newAggExpr ++ referencesToAdd,
              child = newChild)

            // Preserving domain attributes over an Aggregate with an empty grouping expression
            // is subject to the "COUNT bug" that can lead to wrong answer:
            //
            // Suppose the original query is:
            //   SELECT a, (SELECT COUNT(*) cnt FROM t2 WHERE t1.a = t2.c) FROM t1
            //
            // Decorrelated plan:
            //   Project [a, scalar-subquery [a = c]]
            //   :  +- Aggregate [c] [count(*) AS cnt, c]
            //   :     +- Relation [c, d]
            //   +- Relation [a, b]
            //
            // After rewrite:
            //   Project [a, cnt]
            //   +- Join LeftOuter (a = c)
            //      :- Relation [a, b]
            //      +- Aggregate [c] [count(*) AS cnt, c]
            //         +- Relation [c, d]
            //
            //     T1            T2          T2' (GROUP BY c)
            // +---+---+     +---+---+     +---+-----+
            // | a | b |     | c | d |     | c | cnt |
            // +---+---+     +---+---+     +---+-----+
            // | 0 | 1 |     | 0 | 2 |     | 0 | 2   |
            // | 1 | 2 |     | 0 | 3 |     +---+-----+
            // +---+---+     +---+---+
            //
            // T1 nested loop join T2     T1 left outer join T2'
            // on (a = c):                on (a = c):
            // +---+-----+                +---+-----++
            // | a | cnt |                | a | cnt  |
            // +---+-----+                +---+------+
            // | 0 | 2   |                | 0 | 2    |
            // | 1 | 0   | <--- correct   | 1 | null | <--- wrong result
            // +---+-----+                +---+------+
            //
            // If an aggregate is subject to the COUNT bug:
            // 1) add a column `true AS alwaysTrue` to the result of the aggregate
            // 2) insert a left outer domain join between the outer query and this aggregate
            // 3) rewrite the original aggregate's output column using the default value of the
            //    aggregate function and the alwaysTrue column.
            //
            // For example, T1 left outer join T2' with `alwaysTrue` marker:
            // +---+------+------------+--------------------------------+
            // | c | cnt  | alwaysTrue | if(isnull(alwaysTrue), 0, cnt) |
            // +---+------+------------+--------------------------------+
            // | 0 | 2    | true       | 2                              |
            // | 0 | null | null       | 0                              |  <--- correct result
            // +---+------+------------+--------------------------------+
            if (groupingExpressions.isEmpty && handleCountBug) {
              // Evaluate the aggregate expressions with zero tuples.
              val resultMap = RewriteCorrelatedScalarSubquery.evalAggregateOnZeroTups(newAggregate)
              val alwaysTrue = Alias(Literal.TrueLiteral, "alwaysTrue")()
              val alwaysTrueRef = alwaysTrue.toAttribute.withNullability(true)
              val expressions = ArrayBuffer.empty[NamedExpression]
              // Create new aliases for aggregate expressions that have non-null default
              // values and reconstruct the output with the `alwaysTrue` marker.
              val projectList = newAggregate.aggregateExpressions.map { a =>
                resultMap.get(a.exprId) match {
                  // Aggregate expression is not subject to the count bug.
                  case Some(Literal(null, _)) | None =>
                    expressions += a
                    // The attribute is nullable since it is from the right-hand side of a
                    // left outer join.
                    a.toAttribute.withNullability(true)
                  case Some(default) =>
                    assert(a.isInstanceOf[Alias], s"Cannot have non-aliased expression $a in " +
                      s"aggregate that evaluates to non-null value with zero tuples.")
                    val newAttr = a.newInstance()
                    val ref = newAttr.toAttribute.withNullability(true)
                    expressions += newAttr
                    Alias(If(IsNull(alwaysTrueRef), default, ref), a.name)(a.exprId)
                }
              }
              // Insert a placeholder left outer domain join between the outer query and
              // and aggregate node and use the current collected join conditions as the
              // left outer join condition.
              //
              // Original subquery:
              //   Aggregate [count(1) AS cnt]
              //   +- Filter (a = outer(c))
              //      +- Relation [a, b]
              //
              // After decorrelation and before COUNT bug handling:
              //   Aggregate [a] [count(1) AS cnt, a]
              //   +- Relation [a, b]
              //
              // joinCond with the outer query: (a = outer(c))
              //
              // Handle the COUNT bug:
              //   Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
              //   +- DomainJoin [c'] LeftOuter (a = c')
              //      +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
              //         +- Relation [a, b]
              //
              // New joinCond with the outer query: (c' <=> outer(c)), and the DomainJoin
              // will be written as:
              //   Project [if(isnull(alwaysTrue), 0, cnt') AS cnt, c']
              //   +- Join LeftOuter (a = c')
              //      :- Aggregate [c] [c AS c']
              //      :  +- OuterQuery [c, d]
              //      +- Aggregate [a] [count(1) AS cnt', a, true AS alwaysTrue]
              //         +- Relation [a, b]
              //
              val agg = newAggregate.copy(aggregateExpressions = expressions.toSeq :+ alwaysTrue)
              // Find all outer references that are used in the join conditions.
              val outerAttrs = collectOuterReferences(joinCond).toSeq
              // Create new instance of the outer attributes as if they are generated inside
              // the subquery by a left outer join with the outer query. Use new instance here
              // to avoid conflicting join attributes with the inner query.
              val domainAttrs = outerAttrs.map(_.newInstance())
              val mapping = AttributeMap(outerAttrs.zip(domainAttrs))
              // Use the current join conditions returned from the recursive call as the join
              // conditions for the left outer join. All outer references in the join
              // conditions are replaced by the newly created domain attributes.
              val condition = replaceOuterReferences(joinCond, mapping).reduceOption(And)
              val domainJoin = DomainJoin(domainAttrs, agg, LeftOuter, condition)
              // Original domain attributes preserved through Aggregate are no longer needed.
              val newProjectList = projectList.filter(!referencesToAdd.contains(_))
              val project = Project(newProjectList ++ domainAttrs, domainJoin)
              val newJoinCond = outerAttrs.zip(domainAttrs).map { case (outer, inner) =>
                EqualNullSafe(inner, OuterReference(outer))
              }
              (project, newJoinCond, mapping)
            } else {
              (newAggregate, joinCond, outerReferenceMap)
            }

          case d: Distinct =>
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(d.child, parentOuterReferences, aggregated = true)
            (d.copy(child = newChild), joinCond, outerReferenceMap)

          case j @ Join(left, right, joinType, condition, _) =>
            val outerReferences = collectOuterReferences(j.expressions)
            // Join condition containing outer references is not supported.
            assert(outerReferences.isEmpty, s"Correlated column is not allowed in join: $j")
            val newOuterReferences = parentOuterReferences ++ outerReferences
            val shouldPushToLeft = joinType match {
              case LeftOuter | LeftSemiOrAnti(_) | FullOuter => true
              case _ => hasOuterReferences(left)
            }
            val shouldPushToRight = joinType match {
              case RightOuter | FullOuter => true
              case _ => hasOuterReferences(right)
            }
            val (newLeft, leftJoinCond, leftOuterReferenceMap) = if (shouldPushToLeft) {
              decorrelate(left, newOuterReferences, aggregated)
            } else {
              (left, Nil, AttributeMap.empty[Attribute])
            }
            val (newRight, rightJoinCond, rightOuterReferenceMap) = if (shouldPushToRight) {
              decorrelate(right, newOuterReferences, aggregated)
            } else {
              (right, Nil, AttributeMap.empty[Attribute])
            }
            val newOuterReferenceMap = leftOuterReferenceMap ++ rightOuterReferenceMap
            val newJoinCond = leftJoinCond ++ rightJoinCond
            // If we push the dependent join to both sides, we can augment the join condition
            // such that both sides are matched on the domain attributes. For example,
            // - Left Map: {outer(c1) = c1}
            // - Right Map: {outer(c1) = 10 - c1}
            // Then the join condition can be augmented with (c1 <=> 10 - c1).
            val augmentedConditions = leftOuterReferenceMap.flatMap {
              case (outer, inner) => rightOuterReferenceMap.get(outer).map(EqualNullSafe(inner, _))
            }
            val newCondition = (condition ++ augmentedConditions).reduceOption(And)
            val newJoin = j.copy(left = newLeft, right = newRight, condition = newCondition)
            (newJoin, newJoinCond, newOuterReferenceMap)

          case u: UnaryNode =>
            val outerReferences = collectOuterReferences(u.expressions)
            assert(outerReferences.isEmpty, s"Correlated column is not allowed in $u")
            val (newChild, joinCond, outerReferenceMap) =
              decorrelate(u.child, parentOuterReferences, aggregated)
            (u.withNewChildren(newChild :: Nil), joinCond, outerReferenceMap)

          case o =>
            throw QueryExecutionErrors.decorrelateInnerQueryThroughPlanUnsupportedError(o)
        }
      }
    }
    val (newChild, joinCond, _) = decorrelate(BooleanSimplification(innerPlan), AttributeSet.empty)
    val (plan, conditions) = deduplicate(newChild, joinCond, outputPlanInputAttrs)
    (plan, stripOuterReferences(conditions))
  }
}

相关信息

spark 源码目录

相关文章

spark ComplexTypes 源码

spark CostBasedJoinReorder 源码

spark EliminateResolvedHint 源码

spark InjectRuntimeFilter 源码

spark InlineCTE 源码

spark LimitPushDownThroughWindow 源码

spark MergeScalarSubqueries 源码

spark NestedColumnAliasing 源码

spark NormalizeFloatingNumbers 源码

spark OptimizeCsvJsonExprs 源码

0  赞