spark arithmetic 源码

  • 2022-10-20
  • 浏览 (181)

spark arithmetic 代码

文件路径:/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/arithmetic.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import scala.math.{max, min}

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{FunctionRegistry, TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.Cast.{toSQLId, toSQLType}
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.SQLQueryContext
import org.apache.spark.sql.catalyst.trees.TreePattern.{BINARY_ARITHMETIC, TreePattern, UNARY_POSITIVE}
import org.apache.spark.sql.catalyst.util.{IntervalMathUtils, IntervalUtils, MathUtils, TypeUtils}
import org.apache.spark.sql.errors.QueryExecutionErrors
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.CalendarInterval

/**
 * Arithmetic negation of a numeric or interval expression.
 *
 * With `failOnError` enabled, byte/short inputs equal to their type's `MIN_VALUE` raise an
 * overflow error, and int/long/calendar-interval inputs are negated through the `negateExact`
 * helpers. ANSI interval inputs always go through `IntervalMathUtils.negateExact`, regardless of
 * `failOnError`.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the negated value of `expr`.",
  examples = """
    Examples:
      > SELECT _FUNC_(1);
       -1
  """,
  since = "1.0.0",
  group = "math_funcs")
case class UnaryMinus(
    child: Expression,
    failOnError: Boolean = SQLConf.get.ansiEnabled)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  def this(child: Expression) = this(child, SQLConf.get.ansiEnabled)

  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

  override def dataType: DataType = child.dataType

  override def toString: String = "-" + child

  // Numeric operations for `dataType`; only the interpreted fallback branch consults it.
  private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError)

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
    case _: DecimalType =>
      defineCodeGen(ctx, ev, operand => s"$operand.unary_$$minus()")

    case ByteType | ShortType if failOnError =>
      nullSafeCodeGen(ctx, ev, operand => {
        val boxedName = CodeGenerator.boxedType(dataType)
        val primitiveName = CodeGenerator.javaType(dataType)
        val origin = ctx.freshName("origin")
        // Negating MIN_VALUE does not fit in the same type, so it is rejected explicitly.
        s"""
           |$primitiveName $origin = ($primitiveName)($operand);
           |if ($origin == $boxedName.MIN_VALUE) {
           |  throw QueryExecutionErrors.unaryMinusCauseOverflowError($origin);
           |}
           |${ev.value} = ($primitiveName)(-($origin));
           """.stripMargin
      })

    case IntegerType | LongType if failOnError =>
      val utilsClass = MathUtils.getClass.getCanonicalName.stripSuffix("$")
      nullSafeCodeGen(ctx, ev, operand => s"${ev.value} = $utilsClass.negateExact($operand);")

    case numericType: NumericType =>
      nullSafeCodeGen(ctx, ev, operand => {
        val origin = ctx.freshName("origin")
        val primitiveName = CodeGenerator.javaType(numericType)
        // Emitting (-($operand)) directly may not compile: a literal operand could produce
        // text such as --9223372036854775808L. Copy the value into a local variable first.
        s"""
        $primitiveName $origin = ($primitiveName)($operand);
        ${ev.value} = ($primitiveName)(-($origin));
      """
      })

    case _: CalendarIntervalType =>
      val utilsClass = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      val negation = if (failOnError) "negateExact" else "negate"
      defineCodeGen(ctx, ev, operand => s"$utilsClass.$negation($operand)")

    case _: AnsiIntervalType =>
      nullSafeCodeGen(ctx, ev, operand => {
        val utilsClass = IntervalMathUtils.getClass.getCanonicalName.stripSuffix("$")
        s"${ev.value} = $utilsClass.negateExact($operand);"
      })
  }

  protected override def nullSafeEval(input: Any): Any = dataType match {
    case CalendarIntervalType =>
      val interval = input.asInstanceOf[CalendarInterval]
      if (failOnError) IntervalUtils.negateExact(interval) else IntervalUtils.negate(interval)
    case _: DayTimeIntervalType =>
      IntervalMathUtils.negateExact(input.asInstanceOf[Long])
    case _: YearMonthIntervalType =>
      IntervalMathUtils.negateExact(input.asInstanceOf[Int])
    case _ =>
      numeric.negate(input)
  }

  // Renders as `(- child)` unless the expression was registered under a function alias.
  override def sql: String = getTagValue(FunctionRegistry.FUNC_ALIAS) match {
    case Some(alias) if alias != "-" => s"$alias(${child.sql})"
    case _ => s"(- ${child.sql})"
  }

  override protected def withNewChildInternal(newChild: Expression): UnaryMinus =
    copy(child = newChild)
}

/**
 * Unary plus: passes the value of a numeric or interval child through unchanged.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the value of `expr`.",
  examples = """
    Examples:
      > SELECT _FUNC_(1);
       1
  """,
  since = "1.5.0",
  group = "math_funcs")
case class UnaryPositive(child: Expression)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  override def prettyName: String = "positive"

  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndInterval)

  override def dataType: DataType = child.dataType

  final override val nodePatterns: Seq[TreePattern] = Seq(UNARY_POSITIVE)

  // The generated expression is just the child's value expression.
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
    defineCodeGen(ctx, ev, childValue => childValue)

  // Interpreted evaluation returns the (non-null) input as is.
  protected override def nullSafeEval(input: Any): Any = input

  override def sql: String = "(+ " + child.sql + ")"

  override protected def withNewChildInternal(newChild: Expression): UnaryPositive =
    copy(child = newChild)
}

/**
 * A function that gets the absolute value of a numeric or ANSI interval value.
 *
 * With `failOnError` enabled, byte/short inputs equal to their type's `MIN_VALUE` raise an
 * overflow error and negative int/long inputs are negated through `MathUtils.negateExact`.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr) - Returns the absolute value of the numeric or interval value.",
  examples = """
    Examples:
      > SELECT _FUNC_(-1);
       1
      > SELECT _FUNC_(INTERVAL -'1-1' YEAR TO MONTH);
       1-1
  """,
  since = "1.2.0",
  group = "math_funcs")
case class Abs(child: Expression, failOnError: Boolean = SQLConf.get.ansiEnabled)
  extends UnaryExpression with ImplicitCastInputTypes with NullIntolerant {

  def this(child: Expression) = this(child, SQLConf.get.ansiEnabled)

  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection.NumericAndAnsiInterval)

  override def dataType: DataType = child.dataType

  // Interval types always use the exact numerics; other types honour `failOnError`.
  private lazy val numeric: Numeric[Any] = {
    val ops = dataType match {
      case _: DayTimeIntervalType => LongExactNumeric
      case _: YearMonthIntervalType => IntegerExactNumeric
      case _ => TypeUtils.getNumeric(dataType, failOnError)
    }
    ops.asInstanceOf[Numeric[Any]]
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
    case _: DecimalType =>
      defineCodeGen(ctx, ev, value => s"$value.abs()")

    case ByteType | ShortType if failOnError =>
      val boxedName = CodeGenerator.boxedType(dataType)
      val primitiveName = CodeGenerator.javaType(dataType)
      // The absolute value of MIN_VALUE does not fit in the same type, so it is rejected.
      nullSafeCodeGen(ctx, ev, value =>
        s"""
          |if ($value == $boxedName.MIN_VALUE) {
          |  throw QueryExecutionErrors.unaryMinusCauseOverflowError($value);
          |} else if ($value < 0) {
          |  ${ev.value} = ($primitiveName)-$value;
          |} else {
          |  ${ev.value} = $value;
          |}
          |""".stripMargin)

    case IntegerType | LongType if failOnError =>
      genNegateWhenNegative(ctx, ev)

    case _: AnsiIntervalType =>
      genNegateWhenNegative(ctx, ev)

    case numericType: NumericType =>
      val primitiveName = CodeGenerator.javaType(numericType)
      defineCodeGen(ctx, ev, value => s"($primitiveName)(java.lang.Math.abs($value))")
  }

  // Emits a conditional expression that negates the child through `MathUtils.negateExact`
  // when it is negative and otherwise yields it unchanged.
  private def genNegateWhenNegative(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val utilsClass = MathUtils.getClass.getCanonicalName.stripSuffix("$")
    defineCodeGen(ctx, ev, value => s"$value < 0 ? $utilsClass.negateExact($value) : $value")
  }

  protected override def nullSafeEval(input: Any): Any = numeric.abs(input)

  // Only the child is exposed as an argument; `failOnError` is left out.
  override def flatArguments: Iterator[Any] = Iterator(child)

  override protected def withNewChildInternal(newChild: Expression): Abs = copy(child = newChild)
}

/**
 * Base class for binary arithmetic expressions.
 *
 * It provides the shared result-type rule, nullability rule, code generation and interpreted
 * evaluation. Concrete expressions supply `evalMode`, the operator `symbol`, and, depending on
 * the operand types they accept, `decimalMethod`, `calendarIntervalMethod`, `exactMathMethod`
 * and `resultDecimalType`.
 */
abstract class BinaryArithmetic extends BinaryOperator
  with NullIntolerant with SupportQueryContext {

  /** Evaluation mode of this expression; supplied by each concrete subclass. */
  protected val evalMode: EvalMode.Value

  /** Whether the operation takes the error-raising code path (true for ANSI and TRY modes). */
  protected def failOnError: Boolean = evalMode match {
    // The TRY mode executes as if it would fail on errors, except that it would capture the errors
    // and return null results.
    case EvalMode.ANSI | EvalMode.TRY => true
    case _ => false
  }

  /**
   * Accepts two decimal operands even when their precision/scale differ, provided `inputType`
   * accepts both; every other combination is delegated to the default check.
   */
  override def checkInputDataTypes(): TypeCheckResult = (left.dataType, right.dataType) match {
    case (l: DecimalType, r: DecimalType) if inputType.acceptsType(l) && inputType.acceptsType(r) =>
      // We allow decimal type inputs with different precision and scale, and use special formulas
      // to calculate the result precision and scale.
      TypeCheckResult.TypeCheckSuccess
    case _ => super.checkInputDataTypes()
  }

  /**
   * Result type: for two fixed decimals it is computed by `resultDecimalType`, otherwise it is
   * the type of the left operand.
   */
  override def dataType: DataType = (left.dataType, right.dataType) match {
    case (DecimalType.Fixed(p1, s1), DecimalType.Fixed(p2, s2)) =>
      resultDecimalType(p1, s1, p2, s2)
    case _ => left.dataType
  }

  // When `spark.sql.decimalOperations.allowPrecisionLoss` is set to true, if the precision / scale
  // needed are out of the range of available values, the scale is reduced up to 6, in order to
  // prevent the truncation of the integer part of the decimals.
  protected def allowPrecisionLoss: Boolean = SQLConf.get.decimalOperationsAllowPrecisionLoss

  /**
   * Computes the decimal result type from the operand precisions (`p1`, `p2`) and scales
   * (`s1`, `s2`). This default always throws; subclasses that accept decimals must override it.
   */
  protected def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    throw new IllegalStateException(
      s"${getClass.getSimpleName} must override `resultDecimalType`.")
  }

  /**
   * Nullable when the parent says so, always in TRY mode, and for decimal operands whenever the
   * mode is not ANSI.
   */
  override def nullable: Boolean = super.nullable || evalMode == EvalMode.TRY || {
    if (left.dataType.isInstanceOf[DecimalType]) {
      // For decimal arithmetic, we may return null even if both inputs are not null, if overflow
      // happens and this `failOnError` flag is false.
      evalMode != EvalMode.ANSI
    } else {
      // For non-decimal arithmetic, the calculation always returns a non-null result when inputs
      // are not null. If overflow happens, we return either the overflowed value or fail.
      false
    }
  }

  final override val nodePatterns: Seq[TreePattern] = Seq(BINARY_ARITHMETIC)

  // Resolved once all children are resolved and the input type check passes.
  override lazy val resolved: Boolean = childrenResolved && checkInputDataTypes().isSuccess

  /** The query context is only captured when this expression can raise errors. */
  override def initQueryContext(): Option[SQLQueryContext] = {
    if (failOnError) {
      Some(origin.context)
    } else {
      None
    }
  }

  /**
   * Converts `value` to the given precision and scale using half-up rounding. The fourth
   * argument passed to `toPrecision` is `!failOnError`.
   * NOTE(review): presumably that flag selects "return null on overflow" over raising an error
   * (the generated code below null-checks the result when `failOnError` is false) -- confirm
   * against `Decimal.toPrecision`.
   */
  protected def checkDecimalOverflow(value: Decimal, precision: Int, scale: Int): Decimal = {
    value.toPrecision(precision, scale, Decimal.ROUND_HALF_UP, !failOnError, getContextOrNull())
  }

  /** Name of the function for this expression on a [[Decimal]] type. */
  def decimalMethod: String =
    throw QueryExecutionErrors.notOverrideExpectedMethodsError("BinaryArithmetics",
      "decimalMethod", "genCode")

  /** Name of the function for this expression on a [[CalendarInterval]] type. */
  def calendarIntervalMethod: String =
    throw QueryExecutionErrors.notOverrideExpectedMethodsError("BinaryArithmetics",
      "calendarIntervalMethod", "genCode")

  /** True when the result type is one of the ANSI interval types. */
  protected def isAnsiInterval: Boolean = dataType.isInstanceOf[AnsiIntervalType]

  // Name of the function for the exact version of this expression in [[Math]].
  // If the option "spark.sql.ansi.enabled" is enabled and there is corresponding
  // function in [[Math]], the exact function will be called instead of evaluation with [[symbol]].
  def exactMathMethod: Option[String] = None

  /**
   * Generates Java code for the operation, selecting a strategy from the result type:
   * decimals call `decimalMethod` followed by `toPrecision`; calendar intervals call
   * `calendarIntervalMethod` on `IntervalUtils`; ANSI intervals call `exactMathMethod` on
   * `IntervalMathUtils`; byte/short compute in an `int` temporary with an optional range check;
   * int/long use `exactMathMethod` on `MathUtils` when `failOnError` is set and one is defined;
   * everything else applies `symbol` directly.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = dataType match {
    case DecimalType.Fixed(precision, scale) =>
      val errorContextCode = getContextOrNullCode(ctx, failOnError)
      // Without failOnError the converted value may be null, so isNull is derived from it.
      val updateIsNull = if (failOnError) {
        ""
      } else {
        s"${ev.isNull} = ${ev.value} == null;"
      }
      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
        s"""
           |${ev.value} = $eval1.$decimalMethod($eval2).toPrecision(
           |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode);
           |$updateIsNull
       """.stripMargin
      })
    case CalendarIntervalType =>
      val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      defineCodeGen(ctx, ev, (eval1, eval2) => s"$iu.$calendarIntervalMethod($eval1, $eval2)")
    case _: AnsiIntervalType =>
      assert(exactMathMethod.isDefined,
        s"The expression '$nodeName' must override the exactMathMethod() method " +
        "if it is supposed to operate over interval types.")
      val mathUtils = IntervalMathUtils.getClass.getCanonicalName.stripSuffix("$")
      defineCodeGen(ctx, ev, (eval1, eval2) => s"$mathUtils.${exactMathMethod.get}($eval1, $eval2)")
    // byte and short are cast to int when add, minus, times or divide
    case ByteType | ShortType =>
      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
        val tmpResult = ctx.freshName("tmpResult")
        val overflowCheck = if (failOnError) {
          // Despite its name, this holds the boxed type name, used for MIN_VALUE / MAX_VALUE.
          val javaType = CodeGenerator.boxedType(dataType)
          s"""
             |if ($tmpResult < $javaType.MIN_VALUE || $tmpResult > $javaType.MAX_VALUE) {
             |  throw QueryExecutionErrors.binaryArithmeticCauseOverflowError(
             |  $eval1, "$symbol", $eval2);
             |}
           """.stripMargin
        } else {
          ""
        }
        s"""
           |${CodeGenerator.JAVA_INT} $tmpResult = $eval1 $symbol $eval2;
           |$overflowCheck
           |${ev.value} = (${CodeGenerator.javaType(dataType)})($tmpResult);
         """.stripMargin
      })
    case IntegerType | LongType if failOnError && exactMathMethod.isDefined =>
      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
        val errorContext = getContextOrNullCode(ctx)
        val mathUtils = MathUtils.getClass.getCanonicalName.stripSuffix("$")
        s"""
           |${ev.value} = $mathUtils.${exactMathMethod.get}($eval1, $eval2, $errorContext);
         """.stripMargin
      })

    case IntegerType | LongType | DoubleType | FloatType =>
      // When Double/Float overflows, there can be 2 cases:
      // - precision loss: according to SQL standard, the number is truncated;
      // - returns (+/-)Infinite: same behavior also other DBs have (e.g. Postgres)
      nullSafeCodeGen(ctx, ev, (eval1, eval2) => {
        s"""
           |${ev.value} = $eval1 $symbol $eval2;
         """.stripMargin
      })
  }

  /**
   * In TRY mode the generated operation is wrapped in a Java try/catch that sets the result to
   * null on any `Exception`; in other modes the generated operation is emitted unchanged.
   */
  override def nullSafeCodeGen(
      ctx: CodegenContext,
      ev: ExprCode,
      f: (String, String) => String): ExprCode = {
    if (evalMode == EvalMode.TRY) {
      val tryBlock: (String, String) => String = (eval1, eval2) => {
        s"""
           |try {
           | ${f(eval1, eval2)}
           |} catch (Exception e) {
           | ${ev.isNull} = true;
           |}
           |""".stripMargin
      }
      super.nullSafeCodeGen(ctx, ev, tryBlock)
    } else {
      super.nullSafeCodeGen(ctx, ev, f)
    }
  }

  /**
   * Interpreted evaluation. Returns null when either operand evaluates to null (the right
   * operand is not evaluated if the left one is null). In TRY mode any `Exception` thrown by
   * `nullSafeEval` is turned into a null result.
   */
  override def eval(input: InternalRow): Any = {
    val value1 = left.eval(input)
    if (value1 == null) {
      null
    } else {
      val value2 = right.eval(input)
      if (value2 == null) {
        null
      } else {
        if (evalMode == EvalMode.TRY) {
          try {
            nullSafeEval(value1, value2)
          } catch {
            case _: Exception =>
              null
          }
        } else {
          nullSafeEval(value1, value2)
        }
      }
    }
  }
}

/** Extractor exposing the two operands of any [[BinaryArithmetic]] expression. */
object BinaryArithmetic {
  def unapply(e: BinaryArithmetic): Option[(Expression, Expression)] =
    Some(e.left -> e.right)
}

/**
 * Binary addition over numeric and interval operands.
 */
@ExpressionDescription(
  usage = "expr1 _FUNC_ expr2 - Returns `expr1`+`expr2`.",
  examples = """
    Examples:
      > SELECT 1 _FUNC_ 2;
       3
  """,
  since = "1.0.0",
  group = "math_funcs")
case class Add(
    left: Expression,
    right: Expression,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic
  with CommutativeExpression {

  def this(left: Expression, right: Expression) =
    this(left, right, EvalMode.fromSQLConf(SQLConf.get))

  override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

  override def symbol: String = "+"

  override def decimalMethod: String = "$plus"

  // scalastyle:off
  // The formula follows Hive which is based on the SQL standard and MS SQL:
  // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
  // https://msdn.microsoft.com/en-us/library/ms190476.aspx
  // Result Precision: max(s1, s2) + max(p1-s1, p2-s2) + 1
  // Result Scale:     max(s1, s2)
  // scalastyle:on
  override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    val scale = max(s1, s2)
    val integralDigits = max(p1 - s1, p2 - s2)
    val precision = integralDigits + scale + 1
    if (allowPrecisionLoss) {
      DecimalType.adjustPrecisionScale(precision, scale)
    } else {
      DecimalType.bounded(precision, scale)
    }
  }

  override def calendarIntervalMethod: String = if (failOnError) "addExact" else "add"

  private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError)

  protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case DecimalType.Fixed(precision, scale) =>
      val sum = numeric.plus(input1, input2).asInstanceOf[Decimal]
      checkDecimalOverflow(sum, precision, scale)

    case CalendarIntervalType =>
      val lhs = input1.asInstanceOf[CalendarInterval]
      val rhs = input2.asInstanceOf[CalendarInterval]
      if (failOnError) IntervalUtils.addExact(lhs, rhs) else IntervalUtils.add(lhs, rhs)

    // ANSI intervals always use the exact variant, independent of `failOnError`.
    case _: DayTimeIntervalType =>
      IntervalMathUtils.addExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
    case _: YearMonthIntervalType =>
      IntervalMathUtils.addExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])

    case _: IntegerType if failOnError =>
      val lhs = input1.asInstanceOf[Int]
      val rhs = input2.asInstanceOf[Int]
      MathUtils.addExact(lhs, rhs, getContextOrNull())
    case _: LongType if failOnError =>
      val lhs = input1.asInstanceOf[Long]
      val rhs = input2.asInstanceOf[Long]
      MathUtils.addExact(lhs, rhs, getContextOrNull())

    case _ =>
      numeric.plus(input1, input2)
  }

  override def exactMathMethod: Option[String] = Some("addExact")

  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Add =
    copy(left = newLeft, right = newRight)

  override lazy val canonicalized: Expression = {
    // TODO: do not reorder consecutive `Add`s with different `evalMode`
    val operands = orderCommutative { case Add(l, r, _) => Seq(l, r) }
    operands.reduce[Expression]((acc, next) => Add(acc, next, evalMode))
  }
}

/**
 * Binary subtraction over numeric and interval operands.
 */
@ExpressionDescription(
  usage = "expr1 _FUNC_ expr2 - Returns `expr1`-`expr2`.",
  examples = """
    Examples:
      > SELECT 2 _FUNC_ 1;
       1
  """,
  since = "1.0.0",
  group = "math_funcs")
case class Subtract(
    left: Expression,
    right: Expression,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic {

  def this(left: Expression, right: Expression) =
    this(left, right, EvalMode.fromSQLConf(SQLConf.get))

  override def inputType: AbstractDataType = TypeCollection.NumericAndInterval

  override def symbol: String = "-"

  override def decimalMethod: String = "$minus"

  // scalastyle:off
  // The formula follows Hive which is based on the SQL standard and MS SQL:
  // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
  // https://msdn.microsoft.com/en-us/library/ms190476.aspx
  // Result Precision: max(s1, s2) + max(p1-s1, p2-s2) + 1
  // Result Scale:     max(s1, s2)
  // scalastyle:on
  override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    val scale = max(s1, s2)
    val integralDigits = max(p1 - s1, p2 - s2)
    val precision = integralDigits + scale + 1
    if (allowPrecisionLoss) {
      DecimalType.adjustPrecisionScale(precision, scale)
    } else {
      DecimalType.bounded(precision, scale)
    }
  }

  override def calendarIntervalMethod: String = if (failOnError) "subtractExact" else "subtract"

  private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError)

  protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case DecimalType.Fixed(precision, scale) =>
      val difference = numeric.minus(input1, input2).asInstanceOf[Decimal]
      checkDecimalOverflow(difference, precision, scale)

    case CalendarIntervalType =>
      val lhs = input1.asInstanceOf[CalendarInterval]
      val rhs = input2.asInstanceOf[CalendarInterval]
      if (failOnError) IntervalUtils.subtractExact(lhs, rhs) else IntervalUtils.subtract(lhs, rhs)

    // ANSI intervals always use the exact variant, independent of `failOnError`.
    case _: DayTimeIntervalType =>
      IntervalMathUtils.subtractExact(input1.asInstanceOf[Long], input2.asInstanceOf[Long])
    case _: YearMonthIntervalType =>
      IntervalMathUtils.subtractExact(input1.asInstanceOf[Int], input2.asInstanceOf[Int])

    case _: IntegerType if failOnError =>
      val lhs = input1.asInstanceOf[Int]
      val rhs = input2.asInstanceOf[Int]
      MathUtils.subtractExact(lhs, rhs, getContextOrNull())
    case _: LongType if failOnError =>
      val lhs = input1.asInstanceOf[Long]
      val rhs = input2.asInstanceOf[Long]
      MathUtils.subtractExact(lhs, rhs, getContextOrNull())

    case _ =>
      numeric.minus(input1, input2)
  }

  override def exactMathMethod: Option[String] = Some("subtractExact")

  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): Subtract = {
    copy(left = newLeft, right = newRight)
  }
}

/**
 * Binary multiplication over numeric operands.
 */
@ExpressionDescription(
  usage = "expr1 _FUNC_ expr2 - Returns `expr1`*`expr2`.",
  examples = """
    Examples:
      > SELECT 2 _FUNC_ 3;
       6
  """,
  since = "1.0.0",
  group = "math_funcs")
case class Multiply(
    left: Expression,
    right: Expression,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic
  with CommutativeExpression {

  def this(left: Expression, right: Expression) =
    this(left, right, EvalMode.fromSQLConf(SQLConf.get))

  override def inputType: AbstractDataType = NumericType

  override def symbol: String = "*"

  override def decimalMethod: String = "$times"

  // scalastyle:off
  // The formula follows Hive which is based on the SQL standard and MS SQL:
  // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
  // https://msdn.microsoft.com/en-us/library/ms190476.aspx
  // Result Precision: p1 + p2 + 1
  // Result Scale:     s1 + s2
  // scalastyle:on
  override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    val precision = p1 + p2 + 1
    val scale = s1 + s2
    if (allowPrecisionLoss) {
      DecimalType.adjustPrecisionScale(precision, scale)
    } else {
      DecimalType.bounded(precision, scale)
    }
  }

  private lazy val numeric = TypeUtils.getNumeric(dataType, failOnError)

  protected override def nullSafeEval(input1: Any, input2: Any): Any = dataType match {
    case DecimalType.Fixed(precision, scale) =>
      val product = numeric.times(input1, input2).asInstanceOf[Decimal]
      checkDecimalOverflow(product, precision, scale)

    case _: IntegerType if failOnError =>
      val lhs = input1.asInstanceOf[Int]
      val rhs = input2.asInstanceOf[Int]
      MathUtils.multiplyExact(lhs, rhs, getContextOrNull())

    case _: LongType if failOnError =>
      val lhs = input1.asInstanceOf[Long]
      val rhs = input2.asInstanceOf[Long]
      MathUtils.multiplyExact(lhs, rhs, getContextOrNull())

    case _ =>
      numeric.times(input1, input2)
  }

  override def exactMathMethod: Option[String] = Some("multiplyExact")

  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): Multiply = {
    copy(left = newLeft, right = newRight)
  }

  override lazy val canonicalized: Expression = {
    // TODO: do not reorder consecutive `Multiply`s with different `evalMode`
    val operands = orderCommutative { case Multiply(l, r, _) => Seq(l, r) }
    operands.reduce[Expression]((acc, next) => Multiply(acc, next, evalMode))
  }
}

/**
 * Common base trait for Divide and Remainder, since these two classes are almost identical.
 *
 * Both the interpreted and the generated code evaluate the right operand first. A zero divisor
 * yields null when `failOnError` is false and a divide-by-zero error otherwise.
 */
trait DivModLike extends BinaryArithmetic {

  /**
   * Hook to convert the generated decimal result expression into the expression's result type.
   * The default returns the decimal expression unchanged.
   */
  protected def decimalToDataTypeCodeGen(decimalResult: String): String = decimalResult

  // Whether we should check overflow or not in ANSI mode.
  protected def checkDivideOverflow: Boolean = false

  // Always nullable: a zero divisor can produce null even for non-null inputs.
  override def nullable: Boolean = true

  // Zero test for the divisor, selected by the right operand's data type.
  private lazy val isZero: Any => Boolean = right.dataType match {
    case _: DecimalType => x => x.asInstanceOf[Decimal].isZero
    case _ => x => x == 0
  }

  /**
   * Interpreted evaluation.
   *
   * Returns null if either operand is null, or if the divisor is zero and `failOnError` is
   * false. Throws a divide-by-zero error for a zero divisor when `failOnError` is true, and an
   * integral-divide overflow error when `checkDivideOverflow` is set and the operands are
   * `Long.MinValue` and `-1`. Otherwise delegates to `evalOperation`.
   */
  final override def eval(input: InternalRow): Any = {
    // evaluate right first as we have a chance to skip left if right is 0
    val input2 = right.eval(input)
    if (input2 == null || (!failOnError && isZero(input2))) {
      null
    } else {
      val input1 = left.eval(input)
      if (input1 == null) {
        null
      } else {
        if (isZero(input2)) {
          // when we reach here, failOnError must be true.
          throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
        }
        if (checkDivideOverflow && input1 == Long.MinValue && input2 == -1) {
          throw QueryExecutionErrors.overflowInIntegralDivideError(getContextOrNull())
        }
        evalOperation(input1, input2)
      }
    }
  }

  /** Performs the actual operation on two non-null operands with a non-zero `right`. */
  def evalOperation(left: Any, right: Any): Any

  /**
   * Special case handling due to division/remainder by 0 => null or ArithmeticException.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval1 = left.genCode(ctx)
    val eval2 = right.genCode(ctx)
    val operandsDataType = left.dataType
    // Java expression testing whether the divisor is zero.
    val isZero = if (operandsDataType.isInstanceOf[DecimalType]) {
      s"${eval2.value}.isZero()"
    } else {
      s"${eval2.value} == 0"
    }
    val javaType = CodeGenerator.javaType(dataType)
    val errorContextCode = getContextOrNullCode(ctx, failOnError)
    // `super.dataType` is the operand-derived type computed by BinaryArithmetic; a subclass may
    // override `dataType` itself with a different result type.
    val operation = super.dataType match {
      case DecimalType.Fixed(precision, scale) =>
        val decimalValue = ctx.freshName("decimalValue")
        s"""
           |Decimal $decimalValue = ${eval1.value}.$decimalMethod(${eval2.value}).toPrecision(
           |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContextCode);
           |if ($decimalValue != null) {
           |  ${ev.value} = ${decimalToDataTypeCodeGen(s"$decimalValue")};
           |} else {
           |  ${ev.isNull} = true;
           |}
           |""".stripMargin
      case _ => s"${ev.value} = ($javaType)(${eval1.value} $symbol ${eval2.value});"
    }
    // Optional guard for Long.MinValue divided by -1.
    val checkIntegralDivideOverflow = if (checkDivideOverflow) {
      s"""
        |if (${eval1.value} == ${Long.MinValue}L && ${eval2.value} == -1)
        |  throw QueryExecutionErrors.overflowInIntegralDivideError($errorContextCode);
        |""".stripMargin
    } else {
      ""
    }

    // evaluate right first as we have a chance to skip left if right is 0
    if (!left.nullable && !right.nullable) {
      // Neither operand can be null, so only the zero-divisor case needs handling.
      val divByZero = if (failOnError) {
        s"throw QueryExecutionErrors.divideByZeroError($errorContextCode);"
      } else {
        s"${ev.isNull} = true;"
      }
      ev.copy(code = code"""
        ${eval2.code}
        boolean ${ev.isNull} = false;
        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if ($isZero) {
          $divByZero
        } else {
          ${eval1.code}
          $checkIntegralDivideOverflow
          $operation
        }""")
    } else {
      // Without failOnError a zero divisor is folded into the null check on the right operand;
      // with failOnError it is checked after both operands are known to be non-null.
      val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
      val failOnErrorBranch = if (failOnError) {
        s"if ($isZero) throw QueryExecutionErrors.divideByZeroError($errorContextCode);"
      } else {
        ""
      }
      ev.copy(code = code"""
        ${eval2.code}
        boolean ${ev.isNull} = false;
        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if (${eval2.isNull}$nullOnErrorCondition) {
          ${ev.isNull} = true;
        } else {
          ${eval1.code}
          if (${eval1.isNull}) {
            ${ev.isNull} = true;
          } else {
            $failOnErrorBranch
            $checkIntegralDivideOverflow
            $operation
          }
        }""")
    }
  }
}

/**
 * Division (`/`) over double and decimal operands.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "expr1 _FUNC_ expr2 - Returns `expr1`/`expr2`. It always performs floating point division.",
  examples = """
    Examples:
      > SELECT 3 _FUNC_ 2;
       1.5
      > SELECT 2L _FUNC_ 2L;
       1.0
  """,
  since = "1.0.0",
  group = "math_funcs")
// scalastyle:on line.size.limit
case class Divide(
    left: Expression,
    right: Expression,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike {

  def this(left: Expression, right: Expression) =
    this(left, right, EvalMode.fromSQLConf(SQLConf.get))

  // `try_divide` behaves exactly like the legacy divide, so the error code path is only taken
  // when `evalMode` is `ANSI`.
  protected override def failOnError: Boolean = evalMode == EvalMode.ANSI

  override def inputType: AbstractDataType = TypeCollection(DoubleType, DecimalType)

  override def symbol: String = "/"

  override def decimalMethod: String = "$div"

  // scalastyle:off
  // The formula follows Hive which is based on the SQL standard and MS SQL:
  // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
  // https://msdn.microsoft.com/en-us/library/ms190476.aspx
  // Result Precision: p1 - s1 + s2 + max(6, s1 + p2 + 1)
  // Result Scale:     max(6, s1 + p2 + 1)
  // scalastyle:on
  override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    if (allowPrecisionLoss) {
      val scale = max(DecimalType.MINIMUM_ADJUSTED_SCALE, s1 + p2 + 1)
      val precision = (p1 - s1 + s2) + scale
      DecimalType.adjustPrecisionScale(precision, scale)
    } else {
      val limit = DecimalType.MAX_SCALE
      val wantedIntDigits = min(limit, p1 - s1 + s2)
      val wantedFracDigits = min(limit, max(6, s1 + p2 + 1))
      val excess = (wantedIntDigits + wantedFracDigits) - limit
      // When both parts do not fit, shrink the fractional part and give the rest to the
      // integral part.
      val (intDigits, fracDigits) =
        if (excess > 0) {
          val reducedFrac = wantedFracDigits - (excess / 2 + 1)
          (limit - reducedFrac, reducedFrac)
        } else {
          (wantedIntDigits, wantedFracDigits)
        }
      DecimalType.bounded(intDigits + fracDigits, fracDigits)
    }
  }

  // Division function chosen once from the result type; decimals additionally go through the
  // overflow check.
  private lazy val div: (Any, Any) => Any = dataType match {
    case d @ DecimalType.Fixed(precision, scale) =>
      (dividend, divisor) => {
        val fractional = d.fractional.asInstanceOf[Fractional[Any]]
        val quotient = fractional.div(dividend, divisor).asInstanceOf[Decimal]
        checkDecimalOverflow(quotient, precision, scale)
      }
    case ft: FractionalType =>
      val fractional = ft.fractional.asInstanceOf[Fractional[Any]]
      (dividend, divisor) => fractional.div(dividend, divisor)
  }

  override def evalOperation(left: Any, right: Any): Any = div(left, right)

  override protected def withNewChildrenInternal(
      newLeft: Expression,
      newRight: Expression): Divide = {
    copy(left = newLeft, right = newRight)
  }
}

/**
 * Integral division (`div`): divides `left` by `right` and always reports LONG as its type.
 * Accepts long, decimal, year-month interval and day-time interval operands.
 */
// scalastyle:off line.size.limit
@ExpressionDescription(
  usage = "expr1 _FUNC_ expr2 - Divide `expr1` by `expr2`. It returns NULL if an operand is NULL or `expr2` is 0. The result is casted to long.",
  examples = """
    Examples:
      > SELECT 3 _FUNC_ 2;
       1
      > SELECT INTERVAL '1-1' YEAR TO MONTH _FUNC_ INTERVAL '-1' MONTH;
       -13
  """,
  since = "3.0.0",
  group = "math_funcs")
// scalastyle:on line.size.limit
case class IntegralDivide(
    left: Expression,
    right: Expression,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike {

  def this(left: Expression, right: Expression) =
    this(left, right, EvalMode.fromSQLConf(SQLConf.get))

  // Overflow checking is requested only for LONG dividends, and only when failOnError is set.
  override def checkDivideOverflow: Boolean = left.dataType == LongType && failOnError

  override def inputType: AbstractDataType =
    TypeCollection(LongType, DecimalType, YearMonthIntervalType, DayTimeIntervalType)

  override def dataType: DataType = LongType

  override def symbol: String = "/"
  override def decimalMethod: String = "quot"
  override def decimalToDataTypeCodeGen(decimalResult: String): String =
    s"$decimalResult.toLong()"

  // Same integral-digit count as the division rule; with scale 0 no precision can be lost.
  override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType =
    DecimalType.bounded(p1 - s1 + s2, 0)

  override def sqlOperator: String = "div"

  // Interpreted quotient function: picks the Integral instance matching the left operand's
  // type once, then converts every non-null quotient to Long.
  private lazy val div: (Any, Any) => Any = {
    def untyped(ops: Integral[_]): Integral[Any] = ops.asInstanceOf[Integral[Any]]

    val ops = left.dataType match {
      case integralType: IntegralType => untyped(integralType.integral)
      case decimalType: DecimalType => untyped(decimalType.asIntegral)
      case _: YearMonthIntervalType => untyped(IntegerType.integral)
      case _: DayTimeIntervalType => untyped(LongType.integral)
    }

    (dividend, divisor) => {
      // `dataType` is overridden to LongType in this class, so the parent's `dataType` is
      // consulted to find out whether the quotient is a decimal needing an overflow check.
      val quotient = super.dataType match {
        case DecimalType.Fixed(p, s) =>
          checkDecimalOverflow(ops.quot(dividend, divisor).asInstanceOf[Decimal], p, s)
        case _ =>
          ops.quot(dividend, divisor)
      }
      // A null quotient (presumably produced by the decimal overflow check) is passed through.
      if (quotient != null) ops.toLong(quotient) else null
    }
  }

  override def evalOperation(left: Any, right: Any): Any = div(left, right)

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): IntegralDivide =
    copy(left = newLeft, right = newRight)
}

/**
 * Remainder of dividing `left` by `right`. Rendered as an infix operator unless a function
 * alias tag is attached, in which case the alias is printed in call form.
 */
@ExpressionDescription(
  usage = "expr1 _FUNC_ expr2 - Returns the remainder after `expr1`/`expr2`.",
  examples = """
    Examples:
      > SELECT 2 % 1.8;
       0.2
      > SELECT MOD(2, 1.8);
       0.2
  """,
  since = "1.0.0",
  group = "math_funcs")
case class Remainder(
    left: Expression,
    right: Expression,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends DivModLike {

  def this(left: Expression, right: Expression) =
    this(left, right, EvalMode.fromSQLConf(SQLConf.get))

  override def inputType: AbstractDataType = NumericType

  override def symbol: String = "%"
  override def decimalMethod: String = "remainder"

  // scalastyle:off
  // The formula follows Hive which is based on the SQL standard and MS SQL:
  // https://cwiki.apache.org/confluence/download/attachments/27362075/Hive_Decimal_Precision_Scale_Support.pdf
  // https://msdn.microsoft.com/en-us/library/ms190476.aspx
  // Result Precision: min(p1-s1, p2-s2) + max(s1, s2)
  // Result Scale:     max(s1, s2)
  // scalastyle:on
  override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    val scale = max(s1, s2)
    val precision = min(p1 - s1, p2 - s2) + scale
    if (!allowPrecisionLoss) {
      DecimalType.bounded(precision, scale)
    } else {
      DecimalType.adjustPrecisionScale(precision, scale)
    }
  }

  // Infix form when no alias tag is set (or the alias equals the operator), call form otherwise.
  override def toString: String = {
    val alias = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(sqlOperator)
    if (alias == sqlOperator) {
      s"($left $sqlOperator $right)"
    } else {
      s"$alias($left, $right)"
    }
  }

  // Same rendering rule as `toString`, but built from the children's SQL text.
  override def sql: String = {
    val alias = getTagValue(FunctionRegistry.FUNC_ALIAS).getOrElse(sqlOperator)
    if (alias == sqlOperator) {
      s"(${left.sql} $sqlOperator ${right.sql})"
    } else {
      s"$alias(${left.sql}, ${right.sql})"
    }
  }

  // Interpreted remainder function, selected lazily from this expression's result type.
  private lazy val mod: (Any, Any) => Any = dataType match {
    // Float and double use the primitive operator directly instead of a type-class instance.
    case DoubleType =>
      (dividend, divisor) => dividend.asInstanceOf[Double] % divisor.asInstanceOf[Double]
    case FloatType =>
      (dividend, divisor) => dividend.asInstanceOf[Float] % divisor.asInstanceOf[Float]

    // Remaining integral types go through their Integral instance.
    case integralType: IntegralType =>
      val ops = integralType.integral.asInstanceOf[Integral[Any]]
      (dividend, divisor) => ops.rem(dividend, divisor)

    // Decimal remainders are additionally checked against the result precision/scale.
    case decimalType @ DecimalType.Fixed(p, s) =>
      val ops = decimalType.asIntegral.asInstanceOf[Integral[Any]]
      (dividend, divisor) =>
        checkDecimalOverflow(ops.rem(dividend, divisor).asInstanceOf[Decimal], p, s)
  }

  override def evalOperation(left: Any, right: Any): Any = mod(left, right)

  override protected def withNewChildrenInternal(
      newLeft: Expression, newRight: Expression): Remainder =
    copy(left = newLeft, right = newRight)
}

/**
 * Binary arithmetic expression behind the SQL `pmod` function.
 *
 * The result is null when either operand is null. When the right operand is zero the result
 * is null unless `failOnError` is set; in that case a divide-by-zero error is raised once the
 * left operand is known to be non-null. Both the interpreted path (`eval`) and the generated
 * code (`doGenCode`) look at the right operand first so the left one can be skipped.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr1, expr2) - Returns the positive value of `expr1` mod `expr2`.",
  examples = """
    Examples:
      > SELECT _FUNC_(10, 3);
       1
      > SELECT _FUNC_(-10, 3);
       2
  """,
  since = "1.5.0",
  group = "math_funcs")
case class Pmod(
    left: Expression,
    right: Expression,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get)) extends BinaryArithmetic {

  // Two-argument constructor; the eval mode is taken from the current SQLConf.
  def this(left: Expression, right: Expression) =
    this(left, right, EvalMode.fromSQLConf(SQLConf.get))

  override def toString: String = s"pmod($left, $right)"

  override def symbol: String = "pmod"

  override def inputType: AbstractDataType = NumericType

  // Always nullable: a zero divisor can produce null even for non-null inputs (see `eval`).
  override def nullable: Boolean = true

  // Name of the Decimal method invoked by the generated code for the decimal case.
  override def decimalMethod: String = "remainder"

  // This follows Remainder rule
  override def resultDecimalType(p1: Int, s1: Int, p2: Int, s2: Int): DecimalType = {
    val resultScale = max(s1, s2)
    val resultPrecision = min(p1 - s1, p2 - s2) + resultScale
    if (allowPrecisionLoss) {
      DecimalType.adjustPrecisionScale(resultPrecision, resultScale)
    } else {
      DecimalType.bounded(resultPrecision, resultScale)
    }
  }

  // Zero test for the interpreted path, chosen from the right operand's data type.
  private lazy val isZero: Any => Boolean = right.dataType match {
    case _: DecimalType => x => x.asInstanceOf[Decimal].isZero
    case _ => x => x == 0
  }

  // Interpreted pmod function: dispatches on the result type to the matching `pmod` overload.
  // Decimal results are additionally passed through the precision/scale overflow check.
  private lazy val pmodFunc: (Any, Any) => Any = dataType match {
    case _: IntegerType => (l, r) => pmod(l.asInstanceOf[Int], r.asInstanceOf[Int])
    case _: LongType => (l, r) => pmod(l.asInstanceOf[Long], r.asInstanceOf[Long])
    case _: ShortType => (l, r) => pmod(l.asInstanceOf[Short], r.asInstanceOf[Short])
    case _: ByteType => (l, r) => pmod(l.asInstanceOf[Byte], r.asInstanceOf[Byte])
    case _: FloatType => (l, r) => pmod(l.asInstanceOf[Float], r.asInstanceOf[Float])
    case _: DoubleType => (l, r) => pmod(l.asInstanceOf[Double], r.asInstanceOf[Double])
    case DecimalType.Fixed(precision, scale) => (l, r) => checkDecimalOverflow(
      pmod(l.asInstanceOf[Decimal], r.asInstanceOf[Decimal]), precision, scale)
  }

  /**
   * Interpreted evaluation.
   *
   * Returns null if the right value is null, if it is zero while `failOnError` is off, or if
   * the left value is null. Throws a divide-by-zero error when the right value is zero,
   * `failOnError` is on and the left value is non-null.
   */
  final override def eval(input: InternalRow): Any = {
    // evaluate right first as we have a chance to skip left if right is 0
    val input2 = right.eval(input)
    if (input2 == null || (!failOnError && isZero(input2))) {
      null
    } else {
      val input1 = left.eval(input)
      if (input1 == null) {
        null
      } else {
        if (isZero(input2)) {
          // when we reach here, failOnError must be true.
          throw QueryExecutionErrors.divideByZeroError(getContextOrNull())
        }
        pmodFunc(input1, input2)
      }
    }
  }

  /**
   * Generates Java code mirroring `eval`: the right child's code runs first, the zero/null
   * checks follow, and the left child's code is emitted only inside the branch that needs it.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval1 = left.genCode(ctx)
    val eval2 = right.genCode(ctx)
    // Java expression testing the divisor for zero, chosen from this expression's dataType.
    val isZero = if (dataType.isInstanceOf[DecimalType]) {
      s"${eval2.value}.isZero()"
    } else {
      s"${eval2.value} == 0"
    }
    val remainder = ctx.freshName("remainder")
    val javaType = CodeGenerator.javaType(dataType)
    val errorContext = getContextOrNullCode(ctx)
    // Code computing the result once both operands are known to be usable: take the
    // remainder and, if it is negative, add the divisor and take the remainder again.
    val result = dataType match {
      case DecimalType.Fixed(precision, scale) =>
        // "$plus" is the JVM-encoded name of Decimal's `+` method.
        val decimalAdd = "$plus"
        s"""
           |$javaType $remainder = ${eval1.value}.$decimalMethod(${eval2.value});
           |if ($remainder.compare(new org.apache.spark.sql.types.Decimal().set(0)) < 0) {
           |  ${ev.value}=($remainder.$decimalAdd(${eval2.value})).$decimalMethod(${eval2.value});
           |} else {
           |  ${ev.value}=$remainder;
           |}
           |${ev.value} = ${ev.value}.toPrecision(
           |  $precision, $scale, Decimal.ROUND_HALF_UP(), ${!failOnError}, $errorContext);
           |${ev.isNull} = ${ev.value} == null;
           |""".stripMargin

      // Java promotes byte and short operands to int in arithmetic, so the generated code
      // casts each intermediate result back to the narrower type.
      case ByteType | ShortType =>
        s"""
          $javaType $remainder = ($javaType)(${eval1.value} % ${eval2.value});
          if ($remainder < 0) {
            ${ev.value}=($javaType)(($remainder + ${eval2.value}) % ${eval2.value});
          } else {
            ${ev.value}=$remainder;
          }
        """
      case _ =>
        s"""
          $javaType $remainder = ${eval1.value} % ${eval2.value};
          if ($remainder < 0) {
            ${ev.value}=($remainder + ${eval2.value}) % ${eval2.value};
          } else {
            ${ev.value}=$remainder;
          }
        """
    }

    // evaluate right first as we have a chance to skip left if right is 0
    if (!left.nullable && !right.nullable) {
      // Neither child can be null, so only the zero-divisor case needs handling.
      val divByZero = if (failOnError) {
        s"throw QueryExecutionErrors.divideByZeroError($errorContext);"
      } else {
        s"${ev.isNull} = true;"
      }
      ev.copy(code = code"""
        ${eval2.code}
        boolean ${ev.isNull} = false;
        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if ($isZero) {
          $divByZero
        } else {
          ${eval1.code}
          $result
        }""")
    } else {
      // With failOnError off, a zero divisor is folded into the null check on the right
      // operand; with it on, the zero check is emitted after the left operand's null check.
      val nullOnErrorCondition = if (failOnError) "" else s" || $isZero"
      val failOnErrorBranch = if (failOnError) {
        s"if ($isZero) throw QueryExecutionErrors.divideByZeroError($errorContext);"
      } else {
        ""
      }
      ev.copy(code = code"""
        ${eval2.code}
        boolean ${ev.isNull} = false;
        $javaType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
        if (${eval2.isNull}$nullOnErrorCondition) {
          ${ev.isNull} = true;
        } else {
          ${eval1.code}
          if (${eval1.isNull}) {
            ${ev.isNull} = true;
          } else {
            $failOnErrorBranch
            $result
          }
        }""")
    }
  }

  // The `pmod` overloads below share one shape: take `a % n`, and when that remainder is
  // negative return `(r + n) % n` instead. Byte and Short narrow the Int result back.

  private def pmod(a: Int, n: Int): Int = {
    val r = a % n
    if (r < 0) {(r + n) % n} else r
  }

  private def pmod(a: Long, n: Long): Long = {
    val r = a % n
    if (r < 0) {(r + n) % n} else r
  }

  private def pmod(a: Byte, n: Byte): Byte = {
    val r = a % n
    if (r < 0) {((r + n) % n).toByte} else r.toByte
  }

  private def pmod(a: Double, n: Double): Double = {
    val r = a % n
    if (r < 0) {(r + n) % n} else r
  }

  private def pmod(a: Short, n: Short): Short = {
    val r = a % n
    if (r < 0) {((r + n) % n).toShort} else r.toShort
  }

  private def pmod(a: Float, n: Float): Float = {
    val r = a % n
    if (r < 0) {(r + n) % n} else r
  }

  // Decimal variant: a null remainder is returned unchanged rather than compared with zero.
  private def pmod(a: Decimal, n: Decimal): Decimal = {
    val r = a % n
    if (r != null && r.compare(Decimal.ZERO) < 0) {(r + n) % n} else r
  }

  // Always rendered in function-call form using `prettyName`.
  override def sql: String = s"$prettyName(${left.sql}, ${right.sql})"

  override protected def withNewChildrenInternal(newLeft: Expression, newRight: Expression): Pmod =
    copy(left = newLeft, right = newRight)
}

/**
 * Returns the smallest non-null value among its arguments.
 * At least two arguments are required; the result is null only if every argument is null.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, ...) - Returns the least value of all parameters, skipping null values.",
  examples = """
    Examples:
      > SELECT _FUNC_(10, 9, 2, 4, 3);
       2
  """,
  since = "1.5.0",
  group = "math_funcs")
case class Least(children: Seq[Expression]) extends ComplexTypeMergingExpression
  with CommutativeExpression {

  override def nullable: Boolean = children.forall(_.nullable)
  override def foldable: Boolean = children.forall(_.foldable)

  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

  // Checks, in order: argument count, that all arguments share one type, and that the type
  // supports ordering.
  override def checkInputDataTypes(): TypeCheckResult = children.length match {
    case argCount if argCount <= 1 =>
      DataTypeMismatch(
        errorSubClass = "WRONG_NUM_PARAMS",
        messageParameters = Map("actualNum" -> argCount.toString))
    case _ if !TypeCoercion.haveSameType(inputTypesForMerging) =>
      DataTypeMismatch(
        errorSubClass = "DATA_DIFF_TYPES",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "dataType" -> children.map(_.dataType).map(toSQLType).mkString("[", ", ", "]")
        )
      )
    case _ =>
      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
  }

  // Evaluates the children left to right, keeping the smallest non-null value seen so far.
  override def eval(input: InternalRow): Any = {
    var smallest: Any = null
    children.foreach { child =>
      val candidate = child.eval(input)
      if (candidate != null && (smallest == null || ordering.lt(candidate, smallest))) {
        smallest = candidate
      }
    }
    smallest
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childCodes = children.map(_.genCode(ctx))
    // The null flag is registered as mutable state so that split-out functions can update it.
    ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
    val updates = childCodes.map { childCode =>
      s"""
         |${childCode.code}
         |${ctx.reassignIfSmaller(dataType, ev, childCode)}
      """.stripMargin
    }

    val javaResultType = CodeGenerator.javaType(dataType)
    // The running result is threaded through each split function as an argument and
    // assigned back from its return value.
    val body = ctx.splitExpressionsWithCurrentInputs(
      expressions = updates,
      funcName = "least",
      extraArguments = Seq(javaResultType -> ev.value),
      returnType = javaResultType,
      makeSplitFunction = chunk =>
        s"""
          |$chunk
          |return ${ev.value};
        """.stripMargin,
      foldFunctions = calls => calls.map(call => s"${ev.value} = $call;").mkString("\n"))
    ev.copy(code =
      code"""
         |${ev.isNull} = true;
         |$javaResultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         |$body
      """.stripMargin)
  }

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Least =
    copy(children = newChildren)

  override lazy val canonicalized: Expression = {
    Least(orderCommutative({ case Least(children) => children }))
  }
}

/**
 * Returns the largest non-null value among its arguments.
 * At least two arguments are required; the result is null only if every argument is null.
 */
@ExpressionDescription(
  usage = "_FUNC_(expr, ...) - Returns the greatest value of all parameters, skipping null values.",
  examples = """
    Examples:
      > SELECT _FUNC_(10, 9, 2, 4, 3);
       10
  """,
  since = "1.5.0",
  group = "math_funcs")
case class Greatest(children: Seq[Expression]) extends ComplexTypeMergingExpression
  with CommutativeExpression {

  override def nullable: Boolean = children.forall(_.nullable)
  override def foldable: Boolean = children.forall(_.foldable)

  private lazy val ordering = TypeUtils.getInterpretedOrdering(dataType)

  // Checks, in order: argument count, that all arguments share one type, and that the type
  // supports ordering.
  override def checkInputDataTypes(): TypeCheckResult = children.length match {
    case argCount if argCount <= 1 =>
      DataTypeMismatch(
        errorSubClass = "WRONG_NUM_PARAMS",
        messageParameters = Map("actualNum" -> argCount.toString))
    case _ if !TypeCoercion.haveSameType(inputTypesForMerging) =>
      DataTypeMismatch(
        errorSubClass = "DATA_DIFF_TYPES",
        messageParameters = Map(
          "functionName" -> toSQLId(prettyName),
          "dataType" -> children.map(_.dataType).map(toSQLType).mkString("[", ", ", "]")
        )
      )
    case _ =>
      TypeUtils.checkForOrderingExpr(dataType, s"function $prettyName")
  }

  // Evaluates the children left to right, keeping the largest non-null value seen so far.
  override def eval(input: InternalRow): Any = {
    var largest: Any = null
    children.foreach { child =>
      val candidate = child.eval(input)
      if (candidate != null && (largest == null || ordering.gt(candidate, largest))) {
        largest = candidate
      }
    }
    largest
  }

  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val childCodes = children.map(_.genCode(ctx))
    // The null flag is registered as mutable state so that split-out functions can update it.
    ev.isNull = JavaCode.isNullGlobal(ctx.addMutableState(CodeGenerator.JAVA_BOOLEAN, ev.isNull))
    val updates = childCodes.map { childCode =>
      s"""
         |${childCode.code}
         |${ctx.reassignIfGreater(dataType, ev, childCode)}
      """.stripMargin
    }

    val javaResultType = CodeGenerator.javaType(dataType)
    // The running result is threaded through each split function as an argument and
    // assigned back from its return value.
    val body = ctx.splitExpressionsWithCurrentInputs(
      expressions = updates,
      funcName = "greatest",
      extraArguments = Seq(javaResultType -> ev.value),
      returnType = javaResultType,
      makeSplitFunction = chunk =>
        s"""
           |$chunk
           |return ${ev.value};
        """.stripMargin,
      foldFunctions = calls => calls.map(call => s"${ev.value} = $call;").mkString("\n"))
    ev.copy(code =
      code"""
         |${ev.isNull} = true;
         |$javaResultType ${ev.value} = ${CodeGenerator.defaultValue(dataType)};
         |$body
      """.stripMargin)
  }

  override protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]): Greatest =
    copy(children = newChildren)

  override lazy val canonicalized: Expression = {
    Greatest(orderCommutative({ case Greatest(children) => children }))
  }
}

相关信息

spark 源码目录

相关文章

spark AliasHelper 源码

spark ApplyFunctionExpression 源码

spark AttributeSet 源码

spark BloomFilterMightContain 源码

spark BoundAttribute 源码

spark CallMethodViaReflection 源码

spark Cast 源码

spark CodeGeneratorWithInterpretedFallback 源码

spark DynamicPruning 源码

spark EquivalentExpressions 源码

0  赞