spark Cast 源码

  • 2022-10-20
  • 浏览 (242)

spark Cast 代码

文件路径:/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/Cast.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import java.time.{ZoneId, ZoneOffset}
import java.util.Locale
import java.util.concurrent.TimeUnit._

import org.apache.spark.SparkArithmeticException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.{TypeCheckResult, TypeCoercion}
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.DataTypeMismatch
import org.apache.spark.sql.catalyst.expressions.codegen._
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.{SQLQueryContext, TreeNodeTag}
import org.apache.spark.sql.catalyst.trees.TreePattern._
import org.apache.spark.sql.catalyst.util._
import org.apache.spark.sql.catalyst.util.DateTimeConstants._
import org.apache.spark.sql.catalyst.util.DateTimeUtils._
import org.apache.spark.sql.catalyst.util.IntervalStringStyles.ANSI_STYLE
import org.apache.spark.sql.catalyst.util.IntervalUtils.{dayTimeIntervalToByte, dayTimeIntervalToDecimal, dayTimeIntervalToInt, dayTimeIntervalToLong, dayTimeIntervalToShort, yearMonthIntervalToByte, yearMonthIntervalToInt, yearMonthIntervalToShort}
import org.apache.spark.sql.errors.{QueryErrorsBase, QueryExecutionErrors}
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.UTF8StringBuilder
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}
import org.apache.spark.unsafe.types.UTF8String.{IntWrapper, LongWrapper}

object Cast extends QueryErrorsBase {
  /**
   * As per section 6.13 "cast specification" in "Information technology — Database languages
   * — SQL — Part 2: Foundation (SQL/Foundation)":
   * If the <cast operand> is a <value expression>, then the valid combinations of TD and SD
   * in a <cast specification> are given by the following table. “Y” indicates that the
   * combination is syntactically valid without restriction; “M” indicates that the combination
   * is valid subject to other Syntax Rules in this Sub- clause being satisfied; and “N” indicates
   * that the combination is not valid:
   * SD                   TD
   *     EN AN C D T TS YM DT BO UDT B RT CT RW
   * EN  Y  Y  Y N N  N  M  M  N   M N  M  N N
   * AN  Y  Y  Y N N  N  N  N  N   M N  M  N N
   * C   Y  Y  Y Y Y  Y  Y  Y  Y   M N  M  N N
   * D   N  N  Y Y N  Y  N  N  N   M N  M  N N
   * T   N  N  Y N Y  Y  N  N  N   M N  M  N N
   * TS  N  N  Y Y Y  Y  N  N  N   M N  M  N N
   * YM  M  N  Y N N  N  Y  N  N   M N  M  N N
   * DT  M  N  Y N N  N  N  Y  N   M N  M  N N
   * BO  N  N  Y N N  N  N  N  Y   M N  M  N N
   * UDT M  M  M M M  M  M  M  M   M M  M  M N
   * B   N  N  N N N  N  N  N  N   M Y  M  N N
   * RT  M  M  M M M  M  M  M  M   M M  M  N N
   * CT  N  N  N N N  N  N  N  N   M N  N  M N
   * RW  N  N  N N N  N  N  N  N   N N  N  N M
   *
   * Where:
   *   EN  = Exact Numeric
   *   AN  = Approximate Numeric
   *   C   = Character (Fixed- or Variable-Length, or Character Large Object)
   *   D   = Date
   *   T   = Time
   *   TS  = Timestamp
   *   YM  = Year-Month Interval
   *   DT  = Day-Time Interval
   *   BO  = Boolean
   *   UDT  = User-Defined Type
   *   B   = Binary (Fixed- or Variable-Length or Binary Large Object)
   *   RT  = Reference type
   *   CT  = Collection type
   *   RW  = Row type
   *
   * Spark's ANSI mode follows the syntax rules, except that it additionally allows the
   * following straightforward type conversions, which the SQL standard disallows:
   *   - Numeric <=> Boolean
   *   - String <=> Binary
   */
  def canAnsiCast(from: DataType, to: DataType): Boolean = (from, to) match {
    // Casting a type to itself is always allowed.
    case (fromType, toType) if fromType == toType => true

    // NullType can be cast to any target type.
    case (NullType, _) => true

    // Any source type can be cast to string.
    case (_, StringType) => true

    case (StringType, _: BinaryType) => true

    // To boolean: from string or any numeric type.
    case (StringType, BooleanType) => true
    case (_: NumericType, BooleanType) => true

    // To timestamp: from string, date, timestamp_ntz or any numeric type.
    case (StringType, TimestampType) => true
    case (DateType, TimestampType) => true
    case (TimestampNTZType, TimestampType) => true
    case (_: NumericType, TimestampType) => true

    // To timestamp_ntz: from string, date or timestamp.
    case (StringType, TimestampNTZType) => true
    case (DateType, TimestampNTZType) => true
    case (TimestampType, TimestampNTZType) => true

    // To interval types: from string.
    case (StringType, _: CalendarIntervalType) => true
    case (StringType, _: AnsiIntervalType) => true

    // ANSI intervals <=> integral or decimal types, in both directions.
    case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
    case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true

    // Between intervals of the same family (the `fromType == toType` case above already
    // covered identical types, so these cover differing instances of the same class).
    case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
    case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true

    // To date: from string, timestamp or timestamp_ntz.
    case (StringType, DateType) => true
    case (TimestampType, DateType) => true
    case (TimestampNTZType, DateType) => true

    // To numeric: from numeric, string, boolean or timestamp.
    case (_: NumericType, _: NumericType) => true
    case (StringType, _: NumericType) => true
    case (BooleanType, _: NumericType) => true
    case (TimestampType, _: NumericType) => true

    // Arrays: the element cast must be allowed, and a source that may contain nulls cannot
    // be cast to a target that may not (see `resolvableNullability`).
    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
      canAnsiCast(fromType, toType) && resolvableNullability(fn, tn)

    // Maps: both the key cast and the value cast must be allowed; the nullability flags
    // carried by the two MapTypes must be resolvable.
    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
      canAnsiCast(fromKey, toKey) && canAnsiCast(fromValue, toValue) &&
        resolvableNullability(fn, tn)

    // Structs: same number of fields, compared position by position (field names are not
    // consulted here).
    case (StructType(fromFields), StructType(toFields)) =>
      fromFields.length == toFields.length &&
        fromFields.zip(toFields).forall {
          case (fromField, toField) =>
            canAnsiCast(fromField.dataType, toField.dataType) &&
              resolvableNullability(fromField.nullable, toField.nullable)
        }

    // UDT to UDT is allowed only when the target UDT reports that it accepts the source.
    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true

    // Everything else is rejected.
    case _ => false
  }

  // Type check used for `try_cast`. Complex types are checked structurally: nested types must
  // pass `canCast`, and a nested position that may become null (either because the source
  // already allows nulls or because `forceNullable` says the cast can introduce them) is only
  // accepted when the target declares that position nullable. Any other type pair is decided
  // by `canAnsiCast`.
  // NOTE(review): nested types are validated with `canCast` rather than `canTryCast` or
  // `canAnsiCast` -- confirm that this asymmetry is intended.
  def canTryCast(from: DataType, to: DataType): Boolean = (from, to) match {
    case (ArrayType(srcElem, srcNullable), ArrayType(dstElem, dstNullable)) =>
      def elemMayBeNull: Boolean = srcNullable || forceNullable(srcElem, dstElem)
      canCast(srcElem, dstElem) && resolvableNullability(elemMayBeNull, dstNullable)

    case (
        MapType(srcKey, srcValue, srcNullable),
        MapType(dstKey, dstValue, dstNullable)) =>
      def valueMayBeNull: Boolean = srcNullable || forceNullable(srcValue, dstValue)
      // A key cast that could introduce nulls is rejected outright.
      canCast(srcKey, dstKey) &&
        !forceNullable(srcKey, dstKey) &&
        canCast(srcValue, dstValue) &&
        resolvableNullability(valueMayBeNull, dstNullable)

    case (StructType(srcFields), StructType(dstFields)) =>
      // Fields are paired by position; the field counts must match.
      srcFields.length == dstFields.length &&
        srcFields.zip(dstFields).forall { case (src, dst) =>
          def fieldMayBeNull: Boolean =
            src.nullable || forceNullable(src.dataType, dst.dataType)
          canCast(src.dataType, dst.dataType) &&
            resolvableNullability(fieldMayBeNull, dst.nullable)
        }

    case _ =>
      Cast.canAnsiCast(from, to)
  }

  /**
   * A tag identifying a CAST that was added by the table insertion resolver.
   * When the tag is present, `typeCheckFailureInCast` suggests the store assignment policy
   * config instead of the ANSI config in its ANSI-mode error message.
   */
  val BY_TABLE_INSERTION = TreeNodeTag[Unit]("by_table_insertion")

  /**
   * A tag recording whether a CAST was explicitly specified by the user.
   */
  val USER_SPECIFIED_CAST = new TreeNodeTag[Boolean]("user_specified_cast")

  /**
   * Returns true iff we can cast `from` type to `to` type.
   *
   * This is the rule set that `checkInputDataTypes` applies for `EvalMode.LEGACY`. It is also
   * used by `canTryCast` for nested element/field types and by `typeCheckFailureMessage` to
   * decide whether a fallback config can be suggested.
   */
  def canCast(from: DataType, to: DataType): Boolean = (from, to) match {
    // Casting a type to itself is always allowed.
    case (fromType, toType) if fromType == toType => true

    // NullType can be cast to any target type.
    case (NullType, _) => true

    // Any source type can be cast to string.
    case (_, StringType) => true

    // To binary: from string or any integral type.
    case (StringType, BinaryType) => true
    case (_: IntegralType, BinaryType) => true

    // To boolean: from string, date, timestamp or any numeric type.
    case (StringType, BooleanType) => true
    case (DateType, BooleanType) => true
    case (TimestampType, BooleanType) => true
    case (_: NumericType, BooleanType) => true

    // To timestamp: from string, boolean, date, any numeric type or timestamp_ntz.
    case (StringType, TimestampType) => true
    case (BooleanType, TimestampType) => true
    case (DateType, TimestampType) => true
    case (_: NumericType, TimestampType) => true
    case (TimestampNTZType, TimestampType) => true

    // To timestamp_ntz: from string, date or timestamp.
    case (StringType, TimestampNTZType) => true
    case (DateType, TimestampNTZType) => true
    case (TimestampType, TimestampNTZType) => true

    // To date: from string, timestamp or timestamp_ntz.
    case (StringType, DateType) => true
    case (TimestampType, DateType) => true
    case (TimestampNTZType, DateType) => true

    // To interval types: from string, or from an integral type when the target interval has
    // a single field (start == end).
    // NOTE(review): the two integral cases look subsumed by the broader
    // `(_: IntegralType | _: DecimalType, _: AnsiIntervalType)` case below, assuming both
    // interval types extend AnsiIntervalType -- confirm.
    case (StringType, CalendarIntervalType) => true
    case (StringType, _: DayTimeIntervalType) => true
    case (StringType, _: YearMonthIntervalType) => true
    case (_: IntegralType, DayTimeIntervalType(s, e)) if s == e => true
    case (_: IntegralType, YearMonthIntervalType(s, e)) if s == e => true

    // Within an interval family, and ANSI intervals <=> integral or decimal types.
    case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
    case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true
    case (_: AnsiIntervalType, _: IntegralType | _: DecimalType) => true
    case (_: IntegralType | _: DecimalType, _: AnsiIntervalType) => true

    // To numeric: from string, boolean, date, timestamp or another numeric type.
    case (StringType, _: NumericType) => true
    case (BooleanType, _: NumericType) => true
    case (DateType, _: NumericType) => true
    case (TimestampType, _: NumericType) => true
    case (_: NumericType, _: NumericType) => true

    // Arrays: the element cast must be allowed. If the element cast can introduce nulls
    // (`forceNullable`), the elements are treated as nullable for the nullability check.
    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
      canCast(fromType, toType) &&
        resolvableNullability(fn || forceNullable(fromType, toType), tn)

    // Maps: key and value casts must be allowed, and the key cast must not be one that can
    // introduce nulls. Values follow the same nullability rule as array elements.
    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
      canCast(fromKey, toKey) &&
        (!forceNullable(fromKey, toKey)) &&
        canCast(fromValue, toValue) &&
        resolvableNullability(fn || forceNullable(fromValue, toValue), tn)

    // Structs: same number of fields, compared position by position.
    case (StructType(fromFields), StructType(toFields)) =>
      fromFields.length == toFields.length &&
        fromFields.zip(toFields).forall {
          case (fromField, toField) =>
            canCast(fromField.dataType, toField.dataType) &&
              resolvableNullability(
                fromField.nullable || forceNullable(fromField.dataType, toField.dataType),
                toField.nullable)
        }

    // UDT to UDT is allowed only when the target UDT reports that it accepts the source.
    case (udt1: UserDefinedType[_], udt2: UserDefinedType[_]) if udt2.acceptsType(udt1) => true

    // Everything else is rejected.
    case _ => false
  }

  /**
   * Return true if we need to use the `timeZone` information casting `from` type to `to` type.
   * The patterns matched reflect the current implementation in the Cast node.
   * c.f. usage of `timeZone` in:
   * * Cast.castToString
   * * Cast.castToDate
   * * Cast.castToTimestamp
   * * Cast.castToTimestampNTZ
   *
   * The result is consulted by `Cast.resolved` (a cast that needs a time zone is unresolved
   * until `timeZoneId` is set) and by `Cast.canonicalized` (a time zone that is not needed
   * is dropped), so every conversion that reads `zoneId` must be listed here.
   *
   * @param from the source data type
   * @param to the target data type
   * @return true if the conversion, or any nested conversion inside arrays, maps or structs,
   *         depends on the time zone
   */
  def needsTimeZone(from: DataType, to: DataType): Boolean = (from, to) match {
    case (StringType, TimestampType | DateType) => true
    case (TimestampType | DateType, StringType) => true
    case (DateType, TimestampType) => true
    case (TimestampType, DateType) => true
    // Both directions are implemented with `convertTz` using the expression's `zoneId`
    // (see `castToTimestamp` and `castToTimestampNTZ`), so they depend on the time zone.
    case (TimestampType, TimestampNTZType) => true
    case (TimestampNTZType, TimestampType) => true
    case (ArrayType(fromType, _), ArrayType(toType, _)) => needsTimeZone(fromType, toType)
    case (MapType(fromKey, fromValue, _), MapType(toKey, toValue, _)) =>
      needsTimeZone(fromKey, toKey) || needsTimeZone(fromValue, toValue)
    case (StructType(fromFields), StructType(toFields)) =>
      // Fields are paired by position; structs of different arity never need a time zone
      // here (such casts are rejected by the type check anyway).
      fromFields.length == toFields.length &&
        fromFields.zip(toFields).exists {
          case (fromField, toField) =>
            needsTimeZone(fromField.dataType, toField.dataType)
        }
    case _ => false
  }

  /**
   * Returns true iff we can safely up-cast the `from` type to `to` type without any truncation,
   * precision loss or possible runtime failures. For example, long -> int, string -> int are
   * not up-casts.
   */
  def canUpCast(from: DataType, to: DataType): Boolean = (from, to) match {
    case _ if from == to => true
    // Numeric <=> decimal: allowed only when the decimal side is wide/tight enough, as
    // reported by `isWiderThan` / `isTighterThan`.
    case (from: NumericType, to: DecimalType) if to.isWiderThan(from) => true
    case (from: DecimalType, to: NumericType) if from.isTighterThan(to) => true
    // Other numeric pairs: `from` must come strictly before `to` in
    // `TypeCoercion.numericPrecedence` (see `legalNumericPrecedence`).
    case (f, t) if legalNumericPrecedence(f, t) => true
    case (DateType, TimestampType) => true
    case (_: AtomicType, StringType) => true
    case (_: CalendarIntervalType, StringType) => true
    case (NullType, _) => true

    // Spark supports casting between long and timestamp, please see `longToTimestamp` and
    // `timestampToLong` for details.
    case (TimestampType, LongType) => true
    case (LongType, TimestampType) => true

    // Complex types: nullability must be resolvable as declared (no `forceNullable`
    // adjustment here) and every nested type must itself be up-castable.
    case (ArrayType(fromType, fn), ArrayType(toType, tn)) =>
      resolvableNullability(fn, tn) && canUpCast(fromType, toType)

    case (MapType(fromKey, fromValue, fn), MapType(toKey, toValue, tn)) =>
      resolvableNullability(fn, tn) && canUpCast(fromKey, toKey) && canUpCast(fromValue, toValue)

    case (StructType(fromFields), StructType(toFields)) =>
      fromFields.length == toFields.length &&
        fromFields.zip(toFields).forall {
          case (f1, f2) =>
            resolvableNullability(f1.nullable, f2.nullable) && canUpCast(f1.dataType, f2.dataType)
        }

    case (_: DayTimeIntervalType, _: DayTimeIntervalType) => true
    case (_: YearMonthIntervalType, _: YearMonthIntervalType) => true

    // UDT to UDT is allowed only when the target UDT reports that it accepts the source.
    case (from: UserDefinedType[_], to: UserDefinedType[_]) if to.acceptsType(from) => true

    case _ => false
  }

  /**
   * Decides whether a value of type `from` may be stored into a column of type `to` under the
   * ANSI store assignment rules. As noted by the original authors, the behavior is in practice
   * mostly the same as PostgreSQL. Conversions not listed below are rejected, e.g. `string` to
   * `int` or `double` to `boolean`.
   *
   * Accepted: identical types, NullType to anything, numeric to numeric, atomic or calendar
   * interval to string, datetime to datetime, and arrays/maps/structs whose nested types are
   * assignable and whose nullability is resolvable.
   */
  def canANSIStoreAssign(from: DataType, to: DataType): Boolean = (from, to) match {
    case (src, dst) if src == dst => true
    case (NullType, _) => true
    case (_: NumericType, _: NumericType) => true
    case (_: AtomicType | _: CalendarIntervalType, StringType) => true
    case (_: DatetimeType, _: DatetimeType) => true

    case (ArrayType(srcElem, srcNullable), ArrayType(dstElem, dstNullable)) =>
      resolvableNullability(srcNullable, dstNullable) && canANSIStoreAssign(srcElem, dstElem)

    case (
        MapType(srcKey, srcValue, srcNullable),
        MapType(dstKey, dstValue, dstNullable)) =>
      resolvableNullability(srcNullable, dstNullable) &&
        canANSIStoreAssign(srcKey, dstKey) &&
        canANSIStoreAssign(srcValue, dstValue)

    case (StructType(srcFields), StructType(dstFields)) =>
      // Fields are paired by position; the field counts must match.
      srcFields.length == dstFields.length &&
        srcFields.indices.forall { i =>
          val src = srcFields(i)
          val dst = dstFields(i)
          resolvableNullability(src.nullable, dst.nullable) &&
            canANSIStoreAssign(src.dataType, dst.dataType)
        }

    case _ => false
  }

  // True when `from` is listed in `TypeCoercion.numericPrecedence` and appears strictly
  // before `to` in that list. A `to` that is not listed has index -1 and therefore fails.
  private def legalNumericPrecedence(from: DataType, to: DataType): Boolean = {
    val precedence = TypeCoercion.numericPrecedence
    val fromRank = precedence.indexOf(from)
    fromRank >= 0 && fromRank < precedence.indexOf(to)
  }

  // Whether casting `from` to the decimal type `to` is treated as null-safe; `forceNullable`
  // marks a cast to decimal as possibly producing null when this returns false.
  def canNullSafeCastToDecimal(from: DataType, to: DecimalType): Boolean = from match {
    case _: BooleanType if to.isWiderThan(DecimalType.BooleanDecimal) => true
    case numeric: NumericType if to.isWiderThan(numeric) => true
    case decimal: DecimalType =>
      // Reached only when `to` is not wider than the source decimal: the target must then
      // have strictly more digits before the decimal point than the source.
      val srcIntegralDigits = decimal.precision - decimal.scale
      val dstIntegralDigits = to.precision - to.scale
      dstIntegralDigits > srcIntegralDigits
    case _ => false
  }

  /**
   * Returns `true` if casting non-nullable values from `from` type to `to` type
   * may return null. Note that the caller side should take care of input nullability
   * first and only call this method if the input is not nullable.
   */
  def forceNullable(from: DataType, to: DataType): Boolean = (from, to) match {
    // The cases below are order-sensitive: the first matching case wins, so the specific
    // exceptions are listed before the broader wildcard rules they carve out of.
    case (NullType, _) => false // empty array or map case
    case (_, _) if from == to => false

    // From string: only string -> binary is not forced nullable; every other cast from
    // string is.
    case (StringType, BinaryType) => false
    case (StringType, _) => true
    // To string (from a non-string type, given the cases above): not forced nullable.
    case (_, StringType) => false

    case (FloatType | DoubleType, TimestampType) => true
    // To date: only timestamp -> date is exempt.
    case (TimestampType, DateType) => false
    case (_, DateType) => true
    // From date: only date -> timestamp is exempt.
    case (DateType, TimestampType) => false
    case (DateType, _) => true
    case (_, CalendarIntervalType) => true

    // To decimal: forced nullable unless `canNullSafeCastToDecimal` accepts the pair.
    case (_, to: DecimalType) if !canNullSafeCastToDecimal(from, to) => true
    case (_: FractionalType, _: IntegralType) => true  // NaN, infinity
    case _ => false
  }

  // A nullability change is resolvable unless a nullable source feeds a non-nullable target.
  def resolvableNullability(from: Boolean, to: Boolean): Boolean = if (from) to else true

  /**
   * Recognizes the special floating point literals 'Infinity', 'Inf', '-Infinity', 'NaN' etc.
   * The input is trimmed and compared case-insensitively, to be compatible with other database
   * systems such as PostgreSQL and DB2.
   *
   * @param v the string to inspect
   * @param isFloat whether to return the Float (true) or the Double (false) constant
   * @return the matching infinity/NaN constant, or null if `v` is not a special literal
   */
  def processFloatingPointSpecialLiterals(v: String, isFloat: Boolean): Any = {
    // The declared `Any` result keeps the Float branch from being widened to Double.
    def ofWidth(asFloat: Float, asDouble: Double): Any = if (isFloat) asFloat else asDouble

    val normalized = v.trim.toLowerCase(Locale.ROOT)
    if (Set("inf", "+inf", "infinity", "+infinity").contains(normalized)) {
      ofWidth(Float.PositiveInfinity, Double.PositiveInfinity)
    } else if (normalized == "-inf" || normalized == "-infinity") {
      ofWidth(Float.NegativeInfinity, Double.NegativeInfinity)
    } else if (normalized == "nan") {
      ofWidth(Float.NaN, Double.NaN)
    } else {
      null
    }
  }

  // Builds the DataTypeMismatch reported when casting `from` to `to` is not allowed.
  // - Numeric <=> timestamp/date pairs get a suggestion listing conversion functions.
  // - Otherwise, if `fallbackConf` (config key -> value) is given and the legacy `canCast`
  //   rules allow the cast, the error suggests setting that config.
  // - Otherwise the error carries no suggestion.
  def typeCheckFailureMessage(
      from: DataType,
      to: DataType,
      fallbackConf: Option[(String, String)]): DataTypeMismatch = {
    // Every variant of the error reports the source and target SQL types.
    val typeParams = Map(
      "srcType" -> toSQLType(from),
      "targetType" -> toSQLType(to))

    def suggestFunctions(names: String*): DataTypeMismatch =
      DataTypeMismatch(
        errorSubClass = "CAST_WITH_FUN_SUGGESTION",
        messageParameters =
          typeParams + ("functionNames" -> names.map(toSQLId).mkString("/")))

    (from, to) match {
      case (_: NumericType, TimestampType) =>
        suggestFunctions("TIMESTAMP_SECONDS", "TIMESTAMP_MILLIS", "TIMESTAMP_MICROS")

      case (TimestampType, _: NumericType) =>
        suggestFunctions("UNIX_SECONDS", "UNIX_MILLIS", "UNIX_MICROS")

      case (_: NumericType, DateType) =>
        suggestFunctions("DATE_FROM_UNIX_DATE")

      case (DateType, _: NumericType) =>
        suggestFunctions("UNIX_DATE")

      case _ =>
        fallbackConf match {
          case Some((confKey, confValue)) if Cast.canCast(from, to) =>
            DataTypeMismatch(
              errorSubClass = "CAST_WITH_CONF_SUGGESTION",
              messageParameters = typeParams ++ Map(
                "config" -> toSQLConf(confKey),
                "configVal" -> toSQLValue(confValue, StringType)))

          case _ =>
            DataTypeMismatch(
              errorSubClass = "CAST_WITHOUT_SUGGESTION",
              messageParameters = typeParams)
        }
    }
  }

  // Convenience factory: no time zone, eval mode derived from the `ansiEnabled` flag.
  def apply(
      child: Expression,
      dataType: DataType,
      ansiEnabled: Boolean): Cast =
    apply(child, dataType, None, ansiEnabled)

  // Convenience factory: explicit time zone, eval mode derived from the `ansiEnabled` flag
  // via `EvalMode.fromBoolean`.
  def apply(
      child: Expression,
      dataType: DataType,
      timeZoneId: Option[String],
      ansiEnabled: Boolean): Cast = {
    val mode = EvalMode.fromBoolean(ansiEnabled)
    Cast(child, dataType, timeZoneId, mode)
  }
}

/**
 * Cast the child expression to the target data type.
 *
 * When cast from/to timezone related types, we need timeZoneId, which will be resolved with
 * session local timezone by an analyzer [[ResolveTimeZone]].
 */
@ExpressionDescription(
  usage = "_FUNC_(expr AS type) - Casts the value `expr` to the target data type `type`.",
  examples = """
    Examples:
      > SELECT _FUNC_('10' as int);
       10
  """,
  since = "1.0.0",
  group = "conversion_funcs")
case class Cast(
    child: Expression,
    dataType: DataType,
    timeZoneId: Option[String] = None,
    evalMode: EvalMode.Value = EvalMode.fromSQLConf(SQLConf.get))
  extends UnaryExpression
  with TimeZoneAwareExpression
  with NullIntolerant
  with SupportQueryContext
  with QueryErrorsBase {

  // Auxiliary constructor: the eval mode is derived from the current SQLConf.
  def this(child: Expression, dataType: DataType, timeZoneId: Option[String]) = {
    this(child, dataType, timeZoneId, EvalMode.fromSQLConf(SQLConf.get))
  }

  // Returns a copy of this cast with the given time zone id. `Option(...)` maps a null id to
  // None, which `canonicalized` relies on to clear the time zone.
  override def withTimeZone(timeZoneId: String): TimeZoneAwareExpression = {
    val zone: Option[String] = Option(timeZoneId)
    copy(timeZoneId = zone)
  }

  // Returns a copy of this cast that reads from `newChild`; all other fields are kept.
  override protected def withNewChildInternal(newChild: Expression): Cast = {
    copy(child = newChild)
  }

  // This node advertises exactly one tree pattern: CAST.
  final override def nodePatternsInternal(): Seq[TreePattern] = CAST :: Nil

  // True for both the ANSI and the TRY eval modes; false otherwise.
  def ansiEnabled: Boolean = evalMode match {
    case EvalMode.ANSI | EvalMode.TRY => true
    case _ => false
  }

  // True when this expression was created for `try_cast()`, i.e. the eval mode is TRY.
  def isTryCast: Boolean = EvalMode.TRY == evalMode

  // Picks the type check error for this cast according to the eval mode:
  // - ANSI: suggests a fallback config. For casts tagged as added by table insertion that is
  //   the store assignment policy; otherwise it is the ANSI flag.
  // - TRY: delegates to `Cast.typeCheckFailureMessage` without a fallback config.
  // - any other mode: a plain mismatch without suggestion.
  private def typeCheckFailureInCast: DataTypeMismatch = {
    def failure(fallbackConf: Option[(String, String)]): DataTypeMismatch =
      Cast.typeCheckFailureMessage(child.dataType, dataType, fallbackConf)

    evalMode match {
      case EvalMode.ANSI if getTagValue(Cast.BY_TABLE_INSERTION).isDefined =>
        failure(Some(
          SQLConf.STORE_ASSIGNMENT_POLICY.key -> SQLConf.StoreAssignmentPolicy.LEGACY.toString))
      case EvalMode.ANSI =>
        failure(Some(SQLConf.ANSI_ENABLED.key -> "false"))
      case EvalMode.TRY =>
        failure(None)
      case _ =>
        DataTypeMismatch(
          errorSubClass = "CAST_WITHOUT_SUGGESTION",
          messageParameters = Map(
            "srcType" -> toSQLType(child.dataType),
            "targetType" -> toSQLType(dataType)))
    }
  }

  // Validates the (child type, target type) pair against the rule set of the current eval
  // mode: `canCast` for LEGACY, `canAnsiCast` for ANSI and `canTryCast` for TRY.
  // Throws IllegalArgumentException for an eval mode other than these three.
  override def checkInputDataTypes(): TypeCheckResult = {
    val srcType = child.dataType
    val castAllowed = evalMode match {
      case EvalMode.LEGACY => Cast.canCast(srcType, dataType)
      case EvalMode.ANSI => Cast.canAnsiCast(srcType, dataType)
      case EvalMode.TRY => Cast.canTryCast(srcType, dataType)
      case other => throw new IllegalArgumentException(s"Unknown EvalMode value: $other")
    }
    if (!castAllowed) typeCheckFailureInCast else TypeCheckResult.TypeCheckSuccess
  }

  // Nullability of the cast result.
  // - try_cast: string -> binary keeps the child's nullability; decimal -> integral is always
  //   nullable; anything else is nullable if the child is or if the cast is not an up-cast.
  // - otherwise: nullable if the child is, or if `Cast.forceNullable` says the conversion
  //   itself may produce null.
  override def nullable: Boolean = {
    if (isTryCast) {
      (child.dataType, dataType) match {
        case (StringType, BinaryType) =>
          child.nullable
        // TODO: Implement a more accurate method for checking whether a decimal value can be
        //       cast as integral types without overflow. Currently, the cast can overflow
        //       even if "Cast.canUpCast" method returns true.
        case (_: DecimalType, _: IntegralType) =>
          true
        case (srcType, dstType) =>
          child.nullable || !Cast.canUpCast(srcType, dstType)
      }
    } else {
      child.nullable || Cast.forceNullable(child.dataType, dataType)
    }
  }

  // A query context (taken from this node's origin) is only attached when `ansiEnabled`
  // is true, i.e. in the ANSI and TRY eval modes.
  override def initQueryContext(): Option[SQLQueryContext] = {
    if (!ansiEnabled) None else Some(origin.context)
  }

  // Resolved once the children are resolved, the type check passes and, for casts that
  // depend on a time zone, the time zone id has been set.
  override lazy val resolved: Boolean = {
    def timeZoneSatisfied: Boolean = !needsTimeZone || timeZoneId.isDefined
    childrenResolved && checkInputDataTypes().isSuccess && timeZoneSatisfied
  }

  // Canonical form: the child is canonicalized, and a time zone id that the cast does not
  // need is cleared (`withTimeZone(null)` yields `timeZoneId = None`).
  override lazy val canonicalized: Expression = {
    val withCanonicalChild = withNewChildren(Seq(child.canonicalized)).asInstanceOf[Cast]
    val hasRedundantTimeZone = timeZoneId.isDefined && !needsTimeZone
    if (hasRedundantTimeZone) withCanonicalChild.withTimeZone(null) else withCanonicalChild
  }

  // Whether this particular cast (child type -> target type) depends on the time zone.
  def needsTimeZone: Boolean = {
    Cast.needsTimeZone(child.dataType, dataType)
  }

  // Applies `func` to `a` viewed as a `T`. [[func]] assumes the input is no longer null
  // because eval already does the null check.
  @inline protected[this] def buildCast[T](a: Any, func: T => Any): Any = {
    val typedInput = a.asInstanceOf[T]
    func(typedInput)
  }

  // Formatters used by `castToString`. They are lazy, so each one is only built when a cast
  // first needs it.
  private lazy val dateFormatter = DateFormatter()
  // Formats TimestampType values using this expression's `zoneId`.
  private lazy val timestampFormatter = TimestampFormatter.getFractionFormatter(zoneId)
  // Formats TimestampNTZType values; always built with UTC rather than `zoneId`.
  private lazy val timestampNTZFormatter =
    TimestampFormatter.getFractionFormatter(ZoneOffset.UTC)

  // Value of the LEGACY_COMPLEX_TYPES_TO_STRING conf, read once when this expression is
  // constructed (a plain `val`, not lazy). It controls the brackets below and whether null
  // elements are printed in `castToString`.
  private val legacyCastToStr = SQLConf.get.getConf(SQLConf.LEGACY_COMPLEX_TYPES_TO_STRING)
  // The brackets that are used in casting structs and maps to strings
  private val (leftBracket, rightBracket) = if (legacyCastToStr) ("[", "]") else ("{", "}")

  // The class name of `DateTimeUtils`, with the trailing "$" of the object's runtime class
  // name removed.
  protected def dateTimeUtilsCls: String = {
    val objectClassName = DateTimeUtils.getClass.getName
    objectClassName.stripSuffix("$")
  }

  // UDFToString
  // Builds the interpreted conversion function from a value of type `from` to a UTF8String.
  // The returned function expects a non-null input (see `buildCast`). Nested values of
  // arrays, maps and structs are converted by recursively building converters for their
  // element/key/value/field types.
  private[this] def castToString(from: DataType): Any => Any = from match {
    case CalendarIntervalType =>
      buildCast[CalendarInterval](_, i => UTF8String.fromString(i.toString))
    case BinaryType => buildCast[Array[Byte]](_, UTF8String.fromBytes)
    case DateType => buildCast[Int](_, d => UTF8String.fromString(dateFormatter.format(d)))
    case TimestampType => buildCast[Long](_,
      t => UTF8String.fromString(timestampFormatter.format(t)))
    case TimestampNTZType => buildCast[Long](_,
      t => UTF8String.fromString(timestampNTZFormatter.format(t)))
    // Arrays are rendered as "[e0, e1, ...]". A null element prints as "null", except in
    // legacy mode where nothing is printed for it (the "," separator is still emitted).
    case ArrayType(et, _) =>
      buildCast[ArrayData](_, array => {
        val builder = new UTF8StringBuilder
        builder.append("[")
        if (array.numElements > 0) {
          val toUTF8String = castToString(et)
          // The first element is written without a leading separator or space.
          if (array.isNullAt(0)) {
            if (!legacyCastToStr) builder.append("null")
          } else {
            builder.append(toUTF8String(array.get(0, et)).asInstanceOf[UTF8String])
          }
          var i = 1
          while (i < array.numElements) {
            builder.append(",")
            if (array.isNullAt(i)) {
              if (!legacyCastToStr) builder.append(" null")
            } else {
              builder.append(" ")
              builder.append(toUTF8String(array.get(i, et)).asInstanceOf[UTF8String])
            }
            i += 1
          }
        }
        builder.append("]")
        builder.build()
      })
    // Maps are rendered as "<left>k0 -> v0, k1 -> v1<right>" using the configured brackets.
    // Keys are converted without a null check; null values follow the same legacy rule as
    // array elements.
    case MapType(kt, vt, _) =>
      buildCast[MapData](_, map => {
        val builder = new UTF8StringBuilder
        builder.append(leftBracket)
        if (map.numElements > 0) {
          val keyArray = map.keyArray()
          val valueArray = map.valueArray()
          val keyToUTF8String = castToString(kt)
          val valueToUTF8String = castToString(vt)
          builder.append(keyToUTF8String(keyArray.get(0, kt)).asInstanceOf[UTF8String])
          builder.append(" ->")
          if (valueArray.isNullAt(0)) {
            if (!legacyCastToStr) builder.append(" null")
          } else {
            builder.append(" ")
            builder.append(valueToUTF8String(valueArray.get(0, vt)).asInstanceOf[UTF8String])
          }
          var i = 1
          while (i < map.numElements) {
            builder.append(", ")
            builder.append(keyToUTF8String(keyArray.get(i, kt)).asInstanceOf[UTF8String])
            builder.append(" ->")
            if (valueArray.isNullAt(i)) {
              if (!legacyCastToStr) builder.append(" null")
            } else {
              builder.append(" ")
              builder.append(valueToUTF8String(valueArray.get(i, vt))
                .asInstanceOf[UTF8String])
            }
            i += 1
          }
        }
        builder.append(rightBracket)
        builder.build()
      })
    // Structs are rendered like arrays but with the configured brackets; each field uses the
    // converter built for its own data type. Field names are not printed.
    case StructType(fields) =>
      buildCast[InternalRow](_, row => {
        val builder = new UTF8StringBuilder
        builder.append(leftBracket)
        if (row.numFields > 0) {
          val st = fields.map(_.dataType)
          val toUTF8StringFuncs = st.map(castToString)
          if (row.isNullAt(0)) {
            if (!legacyCastToStr) builder.append("null")
          } else {
            builder.append(toUTF8StringFuncs(0)(row.get(0, st(0))).asInstanceOf[UTF8String])
          }
          var i = 1
          while (i < row.numFields) {
            builder.append(",")
            if (row.isNullAt(i)) {
              if (!legacyCastToStr) builder.append(" null")
            } else {
              builder.append(" ")
              builder.append(toUTF8StringFuncs(i)(row.get(i, st(i))).asInstanceOf[UTF8String])
            }
            i += 1
          }
        }
        builder.append(rightBracket)
        builder.build()
      })
    // Python UDTs are formatted through their underlying SQL type; this case must stay
    // before the generic UDT case below to take effect.
    case pudt: PythonUserDefinedType => castToString(pudt.sqlType)
    // Other UDTs are deserialized and formatted with the user object's `toString`.
    case udt: UserDefinedType[_] =>
      buildCast[Any](_, o => UTF8String.fromString(udt.deserialize(o).toString))
    case YearMonthIntervalType(startField, endField) =>
      buildCast[Int](_, i => UTF8String.fromString(
        IntervalUtils.toYearMonthIntervalString(i, ANSI_STYLE, startField, endField)))
    case DayTimeIntervalType(startField, endField) =>
      buildCast[Long](_, i => UTF8String.fromString(
        IntervalUtils.toDayTimeIntervalString(i, ANSI_STYLE, startField, endField)))
    // In ANSI mode, Spark always use plain string representation on casting Decimal values
    // as strings. Otherwise, the casting is using `BigDecimal.toString` which may use scientific
    // notation if an exponent is needed.
    case _: DecimalType if ansiEnabled =>
      buildCast[Decimal](_, d => UTF8String.fromString(d.toPlainString))
    // Fallback for all remaining types: the value's own `toString`.
    case _ => buildCast[Any](_, o => UTF8String.fromString(o.toString))
  }

  // BinaryConverter
  // Builds the interpreted conversion function from `from` to binary. Only string and the
  // four integral types are handled; any other source type fails the match.
  private[this] def castToBinary(from: DataType): Any => Any = from match {
    case StringType =>
      buildCast[UTF8String](_, str => str.getBytes)
    case ByteType =>
      buildCast[Byte](_, b => NumberConverter.toBinary(b))
    case ShortType =>
      buildCast[Short](_, s => NumberConverter.toBinary(s))
    case IntegerType =>
      buildCast[Int](_, i => NumberConverter.toBinary(i))
    case LongType =>
      buildCast[Long](_, l => NumberConverter.toBinary(l))
  }

  // UDFToBoolean
  // Builds the interpreted conversion function from `from` to boolean.
  // - Strings are matched against the true/false spellings known to `StringUtils`; any other
  //   string raises an error when `ansiEnabled` and yields null otherwise.
  // - Timestamps and numbers are true iff they are non-zero.
  // - Dates always yield null.
  private[this] def castToBoolean(from: DataType): Any => Any = from match {
    case StringType =>
      buildCast[UTF8String](_, str => {
        if (StringUtils.isTrueString(str)) {
          true
        } else if (StringUtils.isFalseString(str)) {
          false
        } else if (ansiEnabled) {
          throw QueryExecutionErrors.invalidInputSyntaxForBooleanError(str, getContextOrNull())
        } else {
          null
        }
      })
    case TimestampType =>
      buildCast[Long](_, micros => micros != 0)
    case DateType =>
      // Hive would return null when cast from date to boolean
      buildCast[Int](_, _ => null)
    case LongType =>
      buildCast[Long](_, n => n != 0)
    case IntegerType =>
      buildCast[Int](_, n => n != 0)
    case ShortType =>
      buildCast[Short](_, n => n != 0)
    case ByteType =>
      buildCast[Byte](_, n => n != 0)
    case DecimalType() =>
      buildCast[Decimal](_, dec => !dec.isZero)
    case DoubleType =>
      buildCast[Double](_, n => n != 0)
    case FloatType =>
      buildCast[Float](_, n => n != 0)
  }

  // TimestampConverter
  // Builds the interpreted conversion function from `from` to TimestampType. The result is
  // the timestamp as a Long number of microseconds (integral inputs are interpreted as
  // seconds, see `longToTimestamp`), or null where the non-ANSI helpers return null.
  // The returned function expects a non-null input (see `buildCast`).
  private[this] def castToTimestamp(from: DataType): Any => Any = from match {
    case StringType =>
      buildCast[UTF8String](_, utfs => {
        if (ansiEnabled) {
          DateTimeUtils.stringToTimestampAnsi(utfs, zoneId, getContextOrNull())
        } else {
          DateTimeUtils.stringToTimestamp(utfs, zoneId).orNull
        }
      })
    case BooleanType =>
      // Both branches must be Long literals: the lambda's result type is `Any`, so no
      // numeric widening is applied and a bare `0` would be boxed as an Int.
      buildCast[Boolean](_, b => if (b) 1L else 0L)
    case LongType =>
      buildCast[Long](_, l => longToTimestamp(l))
    case IntegerType =>
      buildCast[Int](_, i => longToTimestamp(i.toLong))
    case ShortType =>
      buildCast[Short](_, s => longToTimestamp(s.toLong))
    case ByteType =>
      buildCast[Byte](_, b => longToTimestamp(b.toLong))
    case DateType =>
      buildCast[Int](_, d => daysToMicros(d, zoneId))
    case TimestampNTZType =>
      // Time zone dependent: converted with `convertTz` using this expression's `zoneId`.
      buildCast[Long](_, ts => convertTz(ts, zoneId, ZoneOffset.UTC))
    // TimestampWritable.decimalToTimestamp
    case DecimalType() =>
      buildCast[Decimal](_, d => decimalToTimestamp(d))
    // TimestampWritable.doubleToTimestamp
    case DoubleType =>
      // `ansiEnabled` is checked once, when the converter is built, not per value.
      if (ansiEnabled) {
        buildCast[Double](_, d => doubleToTimestampAnsi(d, getContextOrNull()))
      } else {
        buildCast[Double](_, d => doubleToTimestamp(d))
      }
    // TimestampWritable.floatToTimestamp
    case FloatType =>
      if (ansiEnabled) {
        buildCast[Float](_, f => doubleToTimestampAnsi(f.toDouble, getContextOrNull()))
      } else {
        buildCast[Float](_, f => doubleToTimestamp(f.toDouble))
      }
  }

  /**
   * Builds the interpreted function that casts a value of type `from` to a
   * timestamp-without-time-zone value.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the converted value, or to null when a
   *         string input is not accepted in non-ANSI mode
   */
  private[this] def castToTimestampNTZ(from: DataType): Any => Any = from match {
    case TimestampType =>
      input => buildCast[Long](input, micros => convertTz(micros, ZoneOffset.UTC, zoneId))
    case DateType =>
      input => buildCast[Int](input, days => daysToMicros(days, ZoneOffset.UTC))
    case StringType =>
      input => buildCast[UTF8String](input, str =>
        if (ansiEnabled) {
          DateTimeUtils.stringToTimestampWithoutTimeZoneAnsi(str, getContextOrNull())
        } else {
          DateTimeUtils.stringToTimestampWithoutTimeZone(str).orNull
        })
  }

  /** Scales a decimal number of seconds by `MICROS_PER_SECOND` and takes its `longValue`. */
  private[this] def decimalToTimestamp(d: Decimal): Long = {
    val micros = d.toBigDecimal * MICROS_PER_SECOND
    micros.longValue
  }
  /**
   * Scales a floating-point number of seconds by `MICROS_PER_SECOND` and converts it to a
   * `Long`. NaN and infinite inputs produce null.
   */
  private[this] def doubleToTimestamp(d: Double): Any = {
    val isFinite = !(d.isNaN || d.isInfinite)
    if (isFinite) (d * MICROS_PER_SECOND).toLong else null
  }

  // Seconds -> microseconds.
  private[this] def longToTimestamp(t: Long): Long = {
    SECONDS.toMicros(t)
  }
  // Microseconds -> whole seconds, using floor division.
  private[this] def timestampToLong(ts: Long): Long =
    Math.floorDiv(ts, MICROS_PER_SECOND)
  // Microseconds -> seconds as a double, keeping the fractional part.
  private[this] def timestampToDouble(ts: Long): Double = {
    val microsPerSecond = MICROS_PER_SECOND.toDouble
    ts / microsPerSecond
  }

  /**
   * DateConverter: builds the interpreted function that casts a value of type `from` to a
   * date.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the date value, or to null when a string
   *         input is not accepted in non-ANSI mode
   */
  private[this] def castToDate(from: DataType): Any => Any = from match {
    case TimestampNTZType =>
      input => buildCast[Long](input, micros => microsToDays(micros, ZoneOffset.UTC))
    case TimestampType =>
      // Only the day is kept; the original note attributes dropping the finer precision
      // to Hive's behavior.
      input => buildCast[Long](input, micros => microsToDays(micros, zoneId))
    case StringType if ansiEnabled =>
      input => buildCast[UTF8String](input, str =>
        DateTimeUtils.stringToDateAnsi(str, getContextOrNull()))
    case StringType =>
      input => buildCast[UTF8String](input, str => DateTimeUtils.stringToDate(str).orNull)
  }

  /**
   * IntervalConverter: builds the interpreted function that casts a value of type `from`
   * to a calendar interval. Only string input is handled, via `safeStringToInterval`.
   */
  private[this] def castToInterval(from: DataType): Any => Any = from match {
    case StringType =>
      input => buildCast[UTF8String](input, str => IntervalUtils.safeStringToInterval(str))
  }

  /**
   * Builds the interpreted function that casts a value of type `from` to the day-time
   * interval type `it`.
   *
   * @param from the source type of the cast
   * @param it the target interval type, which supplies the start and end fields
   * @return a function mapping a non-null input to the interval value
   */
  private[this] def castToDayTimeInterval(
      from: DataType,
      it: DayTimeIntervalType): Any => Any = from match {
    case StringType =>
      input => buildCast[UTF8String](input, str =>
        IntervalUtils.castStringToDTInterval(str, it.startField, it.endField))
    case _: DayTimeIntervalType =>
      // micros -> duration -> micros, passing the target end field to the second step.
      input => buildCast[Long](input, micros =>
        IntervalUtils.durationToMicros(IntervalUtils.microsToDuration(micros), it.endField))
    case x: IntegralType =>
      val integral = x.integral.asInstanceOf[Integral[Any]]
      if (x == LongType) {
        v => IntervalUtils.longToDayTimeInterval(integral.toLong(v), it.startField, it.endField)
      } else {
        v => IntervalUtils.intToDayTimeInterval(integral.toInt(v), it.startField, it.endField)
      }
    case DecimalType.Fixed(p, s) =>
      input => buildCast[Decimal](input, dec =>
        IntervalUtils.decimalToDayTimeInterval(dec, p, s, it.startField, it.endField))
  }

  /**
   * Builds the interpreted function that casts a value of type `from` to the year-month
   * interval type `it`.
   *
   * @param from the source type of the cast
   * @param it the target interval type, which supplies the start and end fields
   * @return a function mapping a non-null input to the interval value
   */
  private[this] def castToYearMonthInterval(
      from: DataType,
      it: YearMonthIntervalType): Any => Any = from match {
    case StringType =>
      input => buildCast[UTF8String](input, str =>
        IntervalUtils.castStringToYMInterval(str, it.startField, it.endField))
    case _: YearMonthIntervalType =>
      // months -> period -> months, passing the target end field to the second step.
      input => buildCast[Int](input, months =>
        IntervalUtils.periodToMonths(IntervalUtils.monthsToPeriod(months), it.endField))
    case x: IntegralType =>
      val integral = x.integral.asInstanceOf[Integral[Any]]
      if (x == LongType) {
        v => IntervalUtils.longToYearMonthInterval(
          integral.toLong(v), it.startField, it.endField)
      } else {
        v => IntervalUtils.intToYearMonthInterval(
          integral.toInt(v), it.startField, it.endField)
      }
    case DecimalType.Fixed(p, s) =>
      input => buildCast[Decimal](input, dec =>
        IntervalUtils.decimalToYearMonthInterval(dec, p, s, it.startField, it.endField))
  }

  /**
   * LongConverter: builds the interpreted function that casts a value of type `from` to a
   * `Long`.
   *
   * In ANSI mode strings go through `toLongExact` and numeric inputs use the type's
   * `exactNumeric`; otherwise an unparsable string produces null and numeric inputs use the
   * type's plain `numeric`. Dates always produce null.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the long value, or to null
   */
  private[this] def castToLong(from: DataType): Any => Any = from match {
    case StringType =>
      if (ansiEnabled) {
        buildCast[UTF8String](_, str => UTF8StringUtils.toLongExact(str, getContextOrNull()))
      } else {
        // The wrapper is allocated once and reused by every call of the returned function.
        val parsed = new LongWrapper()
        buildCast[UTF8String](_, str => if (str.toLong(parsed)) parsed.value else null)
      }
    case BooleanType =>
      buildCast[Boolean](_, flag => if (flag) 1L else 0L)
    case DateType =>
      input => buildCast[Int](input, _ => null)
    case TimestampType =>
      input => buildCast[Long](input, micros => timestampToLong(micros))
    case x: NumericType =>
      if (ansiEnabled) {
        v => x.exactNumeric.asInstanceOf[Numeric[Any]].toLong(v)
      } else {
        v => x.numeric.asInstanceOf[Numeric[Any]].toLong(v)
      }
    case x: DayTimeIntervalType =>
      input => buildCast[Long](input, micros =>
        dayTimeIntervalToLong(micros, x.startField, x.endField))
    case x: YearMonthIntervalType =>
      input => buildCast[Int](input, months =>
        yearMonthIntervalToInt(months, x.startField, x.endField).toLong)
  }

  /**
   * IntConverter: builds the interpreted function that casts a value of type `from` to an
   * `Int`.
   *
   * In ANSI mode strings go through `toIntExact`, numeric inputs use the type's
   * `exactNumeric`, and a timestamp whose seconds value does not fit in an `Int` raises an
   * overflow error. Otherwise an unparsable string produces null and out-of-range values are
   * narrowed with `toInt`. Dates always produce null.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the int value, or to null
   */
  private[this] def castToInt(from: DataType): Any => Any = from match {
    case StringType =>
      if (ansiEnabled) {
        buildCast[UTF8String](_, str => UTF8StringUtils.toIntExact(str, getContextOrNull()))
      } else {
        // The wrapper is allocated once and reused by every call of the returned function.
        val parsed = new IntWrapper()
        buildCast[UTF8String](_, str => if (str.toInt(parsed)) parsed.value else null)
      }
    case BooleanType =>
      buildCast[Boolean](_, flag => if (flag) 1 else 0)
    case DateType =>
      input => buildCast[Int](input, _ => null)
    case TimestampType =>
      if (ansiEnabled) {
        buildCast[Long](_, micros => {
          val seconds = timestampToLong(micros)
          val narrowed = seconds.toInt
          // A round trip through Int that changes the value means it did not fit.
          if (seconds != narrowed) {
            throw QueryExecutionErrors.castingCauseOverflowError(micros, from, IntegerType)
          }
          narrowed
        })
      } else {
        buildCast[Long](_, micros => timestampToLong(micros).toInt)
      }
    case x: NumericType =>
      if (ansiEnabled) {
        v => x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(v)
      } else {
        v => x.numeric.asInstanceOf[Numeric[Any]].toInt(v)
      }
    case x: DayTimeIntervalType =>
      input => buildCast[Long](input, micros =>
        dayTimeIntervalToInt(micros, x.startField, x.endField))
    case x: YearMonthIntervalType =>
      input => buildCast[Int](input, months =>
        yearMonthIntervalToInt(months, x.startField, x.endField))
  }

  /**
   * ShortConverter: builds the interpreted function that casts a value of type `from` to a
   * `Short`.
   *
   * In ANSI mode strings go through `toShortExact`, and timestamp or numeric inputs that do
   * not fit in a `Short` raise an overflow error. Otherwise an unparsable string produces
   * null and out-of-range values are narrowed with `toShort`. Dates always produce null.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the short value, or to null
   */
  private[this] def castToShort(from: DataType): Any => Any = from match {
    case StringType =>
      if (ansiEnabled) {
        buildCast[UTF8String](_, str => UTF8StringUtils.toShortExact(str, getContextOrNull()))
      } else {
        // The wrapper is allocated once and reused by every call of the returned function.
        val parsed = new IntWrapper()
        buildCast[UTF8String](_, str =>
          if (str.toShort(parsed)) parsed.value.toShort else null)
      }
    case BooleanType =>
      buildCast[Boolean](_, flag => if (flag) 1.toShort else 0.toShort)
    case DateType =>
      input => buildCast[Int](input, _ => null)
    case TimestampType =>
      if (ansiEnabled) {
        buildCast[Long](_, micros => {
          val seconds = timestampToLong(micros)
          val narrowed = seconds.toShort
          // A round trip through Short that changes the value means it did not fit.
          if (seconds != narrowed) {
            throw QueryExecutionErrors.castingCauseOverflowError(micros, from, ShortType)
          }
          narrowed
        })
      } else {
        buildCast[Long](_, micros => timestampToLong(micros).toShort)
      }
    case x: NumericType =>
      if (ansiEnabled) {
        v => {
          // Convert to Int first; an ArithmeticException from the exact conversion is
          // reported as an overflow of the Short target.
          val asInt = try {
            x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(v)
          } catch {
            case _: ArithmeticException =>
              throw QueryExecutionErrors.castingCauseOverflowError(v, from, ShortType)
          }
          val narrowed = asInt.toShort
          if (asInt != narrowed) {
            throw QueryExecutionErrors.castingCauseOverflowError(v, from, ShortType)
          }
          narrowed
        }
      } else {
        v => x.numeric.asInstanceOf[Numeric[Any]].toInt(v).toShort
      }
    case x: DayTimeIntervalType =>
      input => buildCast[Long](input, micros =>
        dayTimeIntervalToShort(micros, x.startField, x.endField))
    case x: YearMonthIntervalType =>
      input => buildCast[Int](input, months =>
        yearMonthIntervalToShort(months, x.startField, x.endField))
  }

  /**
   * ByteConverter: builds the interpreted function that casts a value of type `from` to a
   * `Byte`.
   *
   * In ANSI mode strings go through `toByteExact`, and timestamp or numeric inputs that do
   * not fit in a `Byte` raise an overflow error. Otherwise an unparsable string produces
   * null and out-of-range values are narrowed with `toByte`. Dates always produce null.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the byte value, or to null
   */
  private[this] def castToByte(from: DataType): Any => Any = from match {
    case StringType =>
      if (ansiEnabled) {
        buildCast[UTF8String](_, str => UTF8StringUtils.toByteExact(str, getContextOrNull()))
      } else {
        // The wrapper is allocated once and reused by every call of the returned function.
        val parsed = new IntWrapper()
        buildCast[UTF8String](_, str =>
          if (str.toByte(parsed)) parsed.value.toByte else null)
      }
    case BooleanType =>
      buildCast[Boolean](_, flag => if (flag) 1.toByte else 0.toByte)
    case DateType =>
      input => buildCast[Int](input, _ => null)
    case TimestampType =>
      if (ansiEnabled) {
        buildCast[Long](_, micros => {
          val seconds = timestampToLong(micros)
          val narrowed = seconds.toByte
          // A round trip through Byte that changes the value means it did not fit.
          if (seconds != narrowed) {
            throw QueryExecutionErrors.castingCauseOverflowError(micros, from, ByteType)
          }
          narrowed
        })
      } else {
        buildCast[Long](_, micros => timestampToLong(micros).toByte)
      }
    case x: NumericType =>
      if (ansiEnabled) {
        v => {
          // Convert to Int first; an ArithmeticException from the exact conversion is
          // reported as an overflow of the Byte target.
          val asInt = try {
            x.exactNumeric.asInstanceOf[Numeric[Any]].toInt(v)
          } catch {
            case _: ArithmeticException =>
              throw QueryExecutionErrors.castingCauseOverflowError(v, from, ByteType)
          }
          val narrowed = asInt.toByte
          if (asInt != narrowed) {
            throw QueryExecutionErrors.castingCauseOverflowError(v, from, ByteType)
          }
          narrowed
        }
      } else {
        v => x.numeric.asInstanceOf[Numeric[Any]].toInt(v).toByte
      }
    case x: DayTimeIntervalType =>
      input => buildCast[Long](input, micros =>
        dayTimeIntervalToByte(micros, x.startField, x.endField))
    case x: YearMonthIntervalType =>
      input => buildCast[Int](input, months =>
        yearMonthIntervalToByte(months, x.startField, x.endField))
  }

  /**
   * Adjusts `value` to the precision and scale of `decimalType`, mutating it in place and
   * returning it on success. On overflow the result is null when ANSI mode is off, and an
   * exception is thrown when ANSI mode is on.
   *
   * NOTE: `value` is modified in place, so never pass in a decimal owned by external data.
   */
  private[this] def changePrecision(value: Decimal, decimalType: DecimalType): Decimal = {
    val nullOnOverflow = !ansiEnabled
    changePrecision(value, decimalType, nullOnOverflow)
  }

  /**
   * Adjusts `value` in place to the precision and scale of `decimalType`.
   *
   * @param value the decimal to modify
   * @param decimalType the target precision and scale
   * @param nullOnOverflow whether an overflow yields null instead of an exception
   * @return `value` itself when the change succeeds, or null on overflow when
   *         `nullOnOverflow` is set
   */
  private[this] def changePrecision(
      value: Decimal,
      decimalType: DecimalType,
      nullOnOverflow: Boolean): Decimal = {
    val fits = value.changePrecision(decimalType.precision, decimalType.scale)
    if (fits) {
      value
    } else if (nullOnOverflow) {
      null
    } else {
      throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
        value, decimalType.precision, decimalType.scale, getContextOrNull())
    }
  }

  /**
   * Delegates to `Decimal.toPrecision` with the precision and scale of `decimalType`,
   * `ROUND_HALF_UP` rounding, and null-on-overflow enabled exactly when ANSI mode is off.
   * Per the original documentation, overflow yields null when `spark.sql.ansi.enabled` is
   * false and an `ArithmeticException` otherwise.
   *
   * @param context the query context forwarded to `Decimal.toPrecision`
   */
  private[this] def toPrecision(
      value: Decimal,
      decimalType: DecimalType,
      context: SQLQueryContext): Decimal = {
    val nullOnOverflow = !ansiEnabled
    value.toPrecision(
      decimalType.precision, decimalType.scale, Decimal.ROUND_HALF_UP, nullOnOverflow, context)
  }


  /**
   * Builds the interpreted function that casts a value of type `from` to the decimal type
   * `target`.
   *
   * Interval inputs always throw on overflow (`nullOnOverflow = false`); the other inputs
   * follow the ANSI setting through `changePrecision` / `toPrecision`. Dates always produce
   * null.
   *
   * @param from the source type of the cast
   * @param target the decimal type whose precision and scale the result must have
   * @return a function mapping a non-null input to the decimal value, or to null
   */
  private[this] def castToDecimal(from: DataType, target: DecimalType): Any => Any = from match {
    case StringType =>
      if (ansiEnabled) {
        buildCast[UTF8String](_, str =>
          changePrecision(Decimal.fromStringANSI(str, target, getContextOrNull()), target))
      } else {
        buildCast[UTF8String](_, str => {
          val parsed = Decimal.fromString(str)
          if (parsed == null) null else changePrecision(parsed, target)
        })
      }
    case BooleanType =>
      buildCast[Boolean](_, flag => {
        val unit = if (flag) Decimal.ONE else Decimal.ZERO
        toPrecision(unit, target, getContextOrNull())
      })
    case DateType =>
      // date can't cast to decimal in Hive
      input => buildCast[Int](input, _ => null)
    case TimestampType =>
      // Note that we lose precision here.
      buildCast[Long](_, micros => changePrecision(Decimal(timestampToDouble(micros)), target))
    case _: DecimalType =>
      v => toPrecision(v.asInstanceOf[Decimal], target, getContextOrNull())
    case t: IntegralType =>
      v => changePrecision(Decimal(t.integral.asInstanceOf[Integral[Any]].toLong(v)), target)
    case x: FractionalType =>
      v => {
        try {
          val asDouble = x.fractional.asInstanceOf[Fractional[Any]].toDouble(v)
          changePrecision(Decimal(asDouble), target)
        } catch {
          case _: NumberFormatException => null
        }
      }
    case x: DayTimeIntervalType =>
      buildCast[Long](_, micros =>
        changePrecision(dayTimeIntervalToDecimal(micros, x.endField), target, false))
    case x: YearMonthIntervalType =>
      buildCast[Int](_, months =>
        changePrecision(
          Decimal(yearMonthIntervalToInt(months, x.startField, x.endField)), target, false))
  }

  /**
   * DoubleConverter: builds the interpreted function that casts a value of type `from` to
   * a `Double`.
   *
   * A string that `toDouble` rejects is retried through
   * `Cast.processFloatingPointSpecialLiterals`; if that also yields null, ANSI mode throws
   * an invalid-input error and non-ANSI mode returns null. Dates always produce null.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the double value, or to null
   */
  private[this] def castToDouble(from: DataType): Any => Any = from match {
    case StringType =>
      buildCast[UTF8String](_, str => {
        val text = str.toString
        try {
          text.toDouble
        } catch {
          case _: NumberFormatException =>
            val special = Cast.processFloatingPointSpecialLiterals(text, false)
            if (ansiEnabled && special == null) {
              throw QueryExecutionErrors.invalidInputInCastToNumberError(
                DoubleType, str, getContextOrNull())
            }
            special
        }
      })
    case BooleanType =>
      buildCast[Boolean](_, flag => if (flag) 1d else 0d)
    case DateType =>
      input => buildCast[Int](input, _ => null)
    case TimestampType =>
      input => buildCast[Long](input, micros => timestampToDouble(micros))
    case x: NumericType =>
      v => x.numeric.asInstanceOf[Numeric[Any]].toDouble(v)
  }

  /**
   * FloatConverter: builds the interpreted function that casts a value of type `from` to a
   * `Float`.
   *
   * A string that `toFloat` rejects is retried through
   * `Cast.processFloatingPointSpecialLiterals`; if that also yields null, ANSI mode throws
   * an invalid-input error and non-ANSI mode returns null. Dates always produce null.
   *
   * @param from the source type of the cast
   * @return a function mapping a non-null input to the float value, or to null
   */
  private[this] def castToFloat(from: DataType): Any => Any = from match {
    case StringType =>
      buildCast[UTF8String](_, str => {
        val text = str.toString
        try {
          text.toFloat
        } catch {
          case _: NumberFormatException =>
            val special = Cast.processFloatingPointSpecialLiterals(text, true)
            if (ansiEnabled && special == null) {
              throw QueryExecutionErrors.invalidInputInCastToNumberError(
                FloatType, str, getContextOrNull())
            }
            special
        }
      })
    case BooleanType =>
      buildCast[Boolean](_, flag => if (flag) 1f else 0f)
    case DateType =>
      input => buildCast[Int](input, _ => null)
    case TimestampType =>
      input => buildCast[Long](input, micros => timestampToDouble(micros).toFloat)
    case x: NumericType =>
      v => x.numeric.asInstanceOf[Numeric[Any]].toFloat(v)
  }

  /**
   * Builds the interpreted function that casts every element of an array from `fromType`
   * to `toType`. Null elements stay null; the result is a new `GenericArrayData`.
   *
   * @param fromType element type of the input array
   * @param toType element type of the output array
   */
  private[this] def castArray(fromType: DataType, toType: DataType): Any => Any = {
    val elementCast = cast(fromType, toType)
    // TODO: Could be faster?
    buildCast[ArrayData](_, array => {
      val converted = new Array[Any](array.numElements())
      array.foreach(fromType, (idx, elem) => {
        converted(idx) = if (elem == null) null else elementCast(elem)
      })
      new GenericArrayData(converted)
    })
  }

  /**
   * Builds the interpreted function that casts a map by casting its key array and its value
   * array independently and pairing the results in a new `ArrayBasedMapData`.
   *
   * @param from the source map type
   * @param to the target map type
   */
  private[this] def castMap(from: MapType, to: MapType): Any => Any = {
    val keyCast = castArray(from.keyType, to.keyType)
    val valueCast = castArray(from.valueType, to.valueType)
    buildCast[MapData](_, map => {
      new ArrayBasedMapData(
        keyCast(map.keyArray()).asInstanceOf[ArrayData],
        valueCast(map.valueArray()).asInstanceOf[ArrayData])
    })
  }

  /**
   * Builds the interpreted function that casts a struct field by field. The fields of
   * `from` and `to` are paired by position; null fields stay null.
   *
   * @param from the source struct type
   * @param to the target struct type
   */
  private[this] def castStruct(from: StructType, to: StructType): Any => Any = {
    // One cast function per field position, built once up front.
    val castFuncs: Array[(Any) => Any] = from.fields.zip(to.fields).map {
      case (fromField, toField) => cast(fromField.dataType, toField.dataType)
    }
    // TODO: Could be faster?
    buildCast[InternalRow](_, row => {
      val newRow = new GenericInternalRow(from.fields.length)
      var ordinal = 0
      while (ordinal < row.numFields) {
        val converted =
          if (row.isNullAt(ordinal)) {
            null
          } else {
            castFuncs(ordinal)(row.get(ordinal, from.apply(ordinal).dataType))
          }
        newRow.update(ordinal, converted)
        ordinal += 1
      }
      newRow
    })
  }

  /**
   * Selects the interpreted conversion function for a cast from `from` to `to`.
   *
   * The cases are tried in order: structurally equal types and identical types return the
   * identity function, a `NullType` source returns a placeholder that throws if invoked, and
   * every other target type dispatches to its dedicated `castToXxx` builder.
   *
   * @param from the source type of the cast
   * @param to the target type of the cast
   * @return the function that converts a non-null input value
   */
  private def castInternal(from: DataType, to: DataType): Any => Any = to match {
    // If the cast does not change the structure, then we don't really need to cast anything.
    // We can return what the children return. Same thing should happen in the codegen path.
    case _ if DataType.equalsStructurally(from, to) => identity[Any]
    // According to `canCast`, NullType can be casted to any type.
    // For primitive types, we don't reach here because the guard of `nullSafeEval`.
    // But for nested types like struct, we might reach here for nested null type field.
    // We won't call the returned function actually, but returns a placeholder.
    case _ if from == NullType =>
      (_: Any) => throw QueryExecutionErrors.cannotCastFromNullTypeError(to)
    case dt if dt == from => identity[Any]
    case StringType => castToString(from)
    case BinaryType => castToBinary(from)
    case DateType => castToDate(from)
    case decimal: DecimalType => castToDecimal(from, decimal)
    case TimestampType => castToTimestamp(from)
    case TimestampNTZType => castToTimestampNTZ(from)
    case CalendarIntervalType => castToInterval(from)
    case it: DayTimeIntervalType => castToDayTimeInterval(from, it)
    case it: YearMonthIntervalType => castToYearMonthInterval(from, it)
    case BooleanType => castToBoolean(from)
    case ByteType => castToByte(from)
    case ShortType => castToShort(from)
    case IntegerType => castToInt(from)
    case FloatType => castToFloat(from)
    case LongType => castToLong(from)
    case DoubleType => castToDouble(from)
    case array: ArrayType =>
      castArray(from.asInstanceOf[ArrayType].elementType, array.elementType)
    case map: MapType => castMap(from.asInstanceOf[MapType], map)
    case struct: StructType => castStruct(from.asInstanceOf[StructType], struct)
    case udt: UserDefinedType[_] if udt.acceptsType(from) =>
      identity[Any]
    case _: UserDefinedType[_] =>
      throw QueryExecutionErrors.cannotCastError(from, to)
  }

  /**
   * Returns the interpreted conversion function for a cast from `from` to `to`.
   *
   * For a regular cast this is the function built by `castInternal`. For a try-cast the
   * function is wrapped so that any `Exception` raised while building or applying the
   * conversion produces null instead of propagating.
   *
   * @param from the source type of the cast
   * @param to the target type of the cast
   * @return the function that converts a non-null input value
   */
  private def cast(from: DataType, to: DataType): Any => Any = {
    if (!isTryCast) {
      castInternal(from, to)
    } else {
      // Build the conversion function once rather than once per input value. It is a lazy
      // val so that it is first forced inside the `try` below: an exception thrown while
      // building it still yields null, and the initialization is retried on the next call.
      lazy val castFunc = castInternal(from, to)
      (input: Any) =>
        try {
          castFunc(input)
        } catch {
          case _: Exception =>
            null
        }
    }
  }

  // The conversion function for this expression, built on first use from the child's data
  // type to this expression's data type by the two-argument `cast` method.
  protected[this] lazy val cast: Any => Any = cast(child.dataType, dataType)

  // Applies the lazily built conversion function (the `cast` value) to the input.
  protected override def nullSafeEval(input: Any): Any = cast(input)

  /**
   * Generates code for this cast. When the child's type and the target type are
   * structurally equal no conversion is needed, so the child's generated code is returned
   * as is; the interpreted path makes the same shortcut.
   */
  override def genCode(ctx: CodegenContext): ExprCode = {
    val structurallySame = DataType.equalsStructurally(child.dataType, dataType)
    if (structurallySame) child.genCode(ctx) else super.genCode(ctx)
  }

  /**
   * Generates the child's code followed by the null-safe cast code that writes into
   * `ev.value` and `ev.isNull`.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val eval = child.genCode(ctx)
    val nullSafeCast = nullSafeCastFunction(child.dataType, dataType, ctx)
    val childCode = eval.code
    val castBlock =
      castCode(ctx, eval.value, eval.isNull, ev.value, ev.isNull, dataType, nullSafeCast)
    ev.copy(code = childCode + castBlock)
  }

  // Type of the code-generation callbacks that emit the Java code for one cast. The function
  // arguments are: `input`, `result` and `resultIsNull`. `inputIsNull` is not needed in the
  // parameter list, because the returned code is placed in a null-safe evaluation region.
  protected[this] type CastFunction = (ExprValue, ExprValue, ExprValue) => Block

  /**
   * Selects the code-generation callback for a cast from `from` to `to`.
   *
   * A `NullType` source always sets the result to null, identical types copy the input, and
   * every other target type dispatches to its dedicated `castToXxxCode` builder.
   *
   * @param from the source type of the cast
   * @param to the target type of the cast
   * @param ctx the codegen context passed on to the builders that need it
   */
  private[this] def nullSafeCastFunction(
      from: DataType,
      to: DataType,
      ctx: CodegenContext): CastFunction = {
    if (from == NullType) {
      (_, _, resultIsNull) => code"$resultIsNull = true;"
    } else if (to == from) {
      (input, result, _) => code"$result = $input;"
    } else {
      to match {
        case StringType => castToStringCode(from, ctx)
        case BinaryType => castToBinaryCode(from)
        case DateType => castToDateCode(from, ctx)
        case decimal: DecimalType => castToDecimalCode(from, decimal, ctx)
        case TimestampType => castToTimestampCode(from, ctx)
        case TimestampNTZType => castToTimestampNTZCode(from, ctx)
        case CalendarIntervalType => castToIntervalCode(from)
        case it: DayTimeIntervalType => castToDayTimeIntervalCode(from, it)
        case it: YearMonthIntervalType => castToYearMonthIntervalCode(from, it)
        case BooleanType => castToBooleanCode(from, ctx)
        case ByteType => castToByteCode(from, ctx)
        case ShortType => castToShortCode(from, ctx)
        case IntegerType => castToIntCode(from, ctx)
        case FloatType => castToFloatCode(from, ctx)
        case LongType => castToLongCode(from, ctx)
        case DoubleType => castToDoubleCode(from, ctx)

        case array: ArrayType =>
          castArrayCode(from.asInstanceOf[ArrayType].elementType, array.elementType, ctx)
        case map: MapType => castMapCode(from.asInstanceOf[MapType], map, ctx)
        case struct: StructType => castStructCode(from.asInstanceOf[StructType], struct, ctx)
        case udt: UserDefinedType[_] if udt.acceptsType(from) =>
          (input, result, _) => code"$result = $input;"
        case _: UserDefinedType[_] =>
          throw QueryExecutionErrors.cannotCastError(from, to)
      }
    }
  }

  // Emits the Java code that declares `result` / `resultIsNull` and runs `cast` on `input`
  // only when `inputIsNull` is false.
  // Since casts are applied recursively inside complex types, such as a Map's key and value
  // or a Struct's fields, every variable name involved in a cast is passed in explicitly.
  protected[this] def castCode(ctx: CodegenContext, input: ExprValue, inputIsNull: ExprValue,
    result: ExprValue, resultIsNull: ExprValue, resultType: DataType, cast: CastFunction): Block = {
    val javaType = JavaCode.javaType(resultType)
    // For a try-cast the generated conversion is wrapped in a Java try/catch that turns any
    // Exception into a null result.
    val castCodeWithTryCatchIfNeeded = if (!isTryCast) {
      s"${cast(input, result, resultIsNull)}"
    } else {
      s"""
         |try {
         |  ${cast(input, result, resultIsNull)}
         |} catch (Exception e) {
         |  $resultIsNull = true;
         |}
         |""".stripMargin
    }
    // The result starts out as the type's default value and inherits the input's nullness.
    code"""
      boolean $resultIsNull = $inputIsNull;
      $javaType $result = ${CodeGenerator.defaultValue(resultType)};
      if (!$inputIsNull) {
        $castCodeWithTryCatchIfNeeded
      }
    """
  }

  /**
   * Returns code that appends the literal `s` to `buffer`, or an empty block when
   * `legacyCastToStr` is set.
   */
  private def appendIfNotLegacyCastToStr(buffer: ExprValue, s: String): Block = {
    if (legacyCastToStr) {
      EmptyBlock
    } else {
      code"""$buffer.append("$s");"""
    }
  }

  /**
   * Generates code that writes the string form of `array` into `buffer`.
   *
   * The output is wrapped in `[` and `]`. Elements after the first are preceded by `,` and a
   * space. A null element is written as `null` unless `legacyCastToStr` is set, in which
   * case nothing is written for it.
   *
   * @param et the element type of the array
   * @param array the generated-code value holding the array
   * @param buffer the generated-code value holding the string builder
   * @param ctx the codegen context, used to register the element-to-string helper function
   */
  private def writeArrayToStringBuilder(
      et: DataType,
      array: ExprValue,
      buffer: ExprValue,
      ctx: CodegenContext): Block = {
    val elementToStringCode = castToStringCode(et, ctx)
    val funcName = ctx.freshName("elementToString")
    val element = JavaCode.variable("element", et)
    val elementStr = JavaCode.variable("elementStr", StringType)
    // The helper returns the same variable it declares; interpolating `$elementStr` in the
    // return statement keeps the two in sync if the variable is ever renamed.
    val elementToStringFunc = inline"${ctx.addNewFunction(funcName,
      s"""
         |private UTF8String $funcName(${CodeGenerator.javaType(et)} $element) {
         |  UTF8String $elementStr = null;
         |  ${elementToStringCode(element, elementStr, null /* resultIsNull won't be used */)}
         |  return $elementStr;
         |}
       """.stripMargin)}"

    val loopIndex = ctx.freshVariable("loopIndex", IntegerType)
    // The first element is handled outside the loop because it has no leading separator.
    code"""
       |$buffer.append("[");
       |if ($array.numElements() > 0) {
       |  if ($array.isNullAt(0)) {
       |    ${appendIfNotLegacyCastToStr(buffer, "null")}
       |  } else {
       |    $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, "0")}));
       |  }
       |  for (int $loopIndex = 1; $loopIndex < $array.numElements(); $loopIndex++) {
       |    $buffer.append(",");
       |    if ($array.isNullAt($loopIndex)) {
       |      ${appendIfNotLegacyCastToStr(buffer, " null")}
       |    } else {
       |      $buffer.append(" ");
       |      $buffer.append($elementToStringFunc(${CodeGenerator.getValue(array, et, loopIndex)}));
       |    }
       |  }
       |}
       |$buffer.append("]");
     """.stripMargin
  }

  /**
   * Generates code that writes the string form of `map` into `buffer`.
   *
   * The output is wrapped in `leftBracket` and `rightBracket`. Each entry is written as the
   * key, ` ->`, a space and the value, and entries are separated by `, `. A null value is
   * written as `null` unless `legacyCastToStr` is set, in which case nothing is written
   * after the arrow.
   *
   * @param kt the key type of the map
   * @param vt the value type of the map
   * @param map the generated-code value holding the map
   * @param buffer the generated-code value holding the string builder
   * @param ctx the codegen context, used to register the key/value-to-string helpers
   */
  private def writeMapToStringBuilder(
      kt: DataType,
      vt: DataType,
      map: ExprValue,
      buffer: ExprValue,
      ctx: CodegenContext): Block = {

    // Registers a helper function that converts one key or value to a UTF8String and
    // returns an inline reference to it.
    def dataToStringFunc(func: String, dataType: DataType) = {
      val funcName = ctx.freshName(func)
      val dataToStringCode = castToStringCode(dataType, ctx)
      val data = JavaCode.variable("data", dataType)
      val dataStr = JavaCode.variable("dataStr", StringType)
      // The helper returns the same variable it declares; interpolating `$dataStr` in the
      // return statement keeps the two in sync if the variable is ever renamed.
      val functionCall = ctx.addNewFunction(funcName,
        s"""
           |private UTF8String $funcName(${CodeGenerator.javaType(dataType)} $data) {
           |  UTF8String $dataStr = null;
           |  ${dataToStringCode(data, dataStr, null /* resultIsNull won't be used */)}
           |  return $dataStr;
           |}
         """.stripMargin)
      inline"$functionCall"
    }

    val keyToStringFunc = dataToStringFunc("keyToString", kt)
    val valueToStringFunc = dataToStringFunc("valueToString", vt)
    val loopIndex = ctx.freshVariable("loopIndex", IntegerType)
    val mapKeyArray = JavaCode.expression(s"$map.keyArray()", classOf[ArrayData])
    val mapValueArray = JavaCode.expression(s"$map.valueArray()", classOf[ArrayData])
    val getMapFirstKey = CodeGenerator.getValue(mapKeyArray, kt, JavaCode.literal("0", IntegerType))
    val getMapFirstValue = CodeGenerator.getValue(mapValueArray, vt,
      JavaCode.literal("0", IntegerType))
    val getMapKeyArray = CodeGenerator.getValue(mapKeyArray, kt, loopIndex)
    val getMapValueArray = CodeGenerator.getValue(mapValueArray, vt, loopIndex)
    // The first entry is handled outside the loop because it has no leading separator.
    code"""
       |$buffer.append("$leftBracket");
       |if ($map.numElements() > 0) {
       |  $buffer.append($keyToStringFunc($getMapFirstKey));
       |  $buffer.append(" ->");
       |  if ($map.valueArray().isNullAt(0)) {
       |    ${appendIfNotLegacyCastToStr(buffer, " null")}
       |  } else {
       |    $buffer.append(" ");
       |    $buffer.append($valueToStringFunc($getMapFirstValue));
       |  }
       |  for (int $loopIndex = 1; $loopIndex < $map.numElements(); $loopIndex++) {
       |    $buffer.append(", ");
       |    $buffer.append($keyToStringFunc($getMapKeyArray));
       |    $buffer.append(" ->");
       |    if ($map.valueArray().isNullAt($loopIndex)) {
       |      ${appendIfNotLegacyCastToStr(buffer, " null")}
       |    } else {
       |      $buffer.append(" ");
       |      $buffer.append($valueToStringFunc($getMapValueArray));
       |    }
       |  }
       |}
       |$buffer.append("$rightBracket");
     """.stripMargin
  }

  // Generates code that writes the string form of the struct `row` into `buffer`.
  // The output is wrapped in `leftBracket` and `rightBracket`; fields after the first are
  // preceded by "," and a space. A null field is written as "null" unless `legacyCastToStr`
  // is set, in which case nothing is written for it.
  // `st` holds the field types in positional order.
  private def writeStructToStringBuilder(
      st: Seq[DataType],
      row: ExprValue,
      buffer: ExprValue,
      ctx: CodegenContext): Block = {
    // One code block per field; each reads the field, converts it to a string and appends it.
    val structToStringCode = st.zipWithIndex.map { case (ft, i) =>
      val fieldToStringCode = castToStringCode(ft, ctx)
      val field = ctx.freshVariable("field", ft)
      val fieldStr = ctx.freshVariable("fieldStr", StringType)
      val javaType = JavaCode.javaType(ft)
      code"""
         |${if (i != 0) code"""$buffer.append(",");""" else EmptyBlock}
         |if ($row.isNullAt($i)) {
         |  ${appendIfNotLegacyCastToStr(buffer, if (i == 0) "null" else " null")}
         |} else {
         |  ${if (i != 0) code"""$buffer.append(" ");""" else EmptyBlock}
         |
         |  // Append $i field into the string buffer
         |  $javaType $field = ${CodeGenerator.getValue(row, ft, s"$i")};
         |  UTF8String $fieldStr = null;
         |  ${fieldToStringCode(field, fieldStr, null /* resultIsNull won't be used */)}
         |  $buffer.append($fieldStr);
         |}
       """.stripMargin
    }

    // The per-field blocks are handed to `splitExpressions`, with the row and the buffer
    // declared as the arguments of the generated "fieldToString" functions.
    val writeStructCode = ctx.splitExpressions(
      expressions = structToStringCode.map(_.code),
      funcName = "fieldToString",
      arguments = ("InternalRow", row.code) ::
        (classOf[UTF8StringBuilder].getName, buffer.code) :: Nil)

    code"""
       |$buffer.append("$leftBracket");
       |$writeStructCode
       |$buffer.append("$rightBracket");
     """.stripMargin
  }

  @scala.annotation.tailrec
  private[this] def castToStringCode(from: DataType, ctx: CodegenContext): CastFunction = {
    from match {
      case BinaryType =>
        (c, evPrim, evNull) => code"$evPrim = UTF8String.fromBytes($c);"
      case DateType =>
        val df = JavaCode.global(
          ctx.addReferenceObj("dateFormatter", dateFormatter),
          dateFormatter.getClass)
        (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString(${df}.format($c));"""
      case TimestampType =>
        val tf = JavaCode.global(
          ctx.addReferenceObj("timestampFormatter", timestampFormatter),
          timestampFormatter.getClass)
        (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString($tf.format($c));"""
      case TimestampNTZType =>
        val tf = JavaCode.global(
          ctx.addReferenceObj("timestampNTZFormatter", timestampNTZFormatter),
          timestampNTZFormatter.getClass)
        (c, evPrim, evNull) => code"""$evPrim = UTF8String.fromString($tf.format($c));"""
      case CalendarIntervalType =>
        (c, evPrim, _) => code"""$evPrim = UTF8String.fromString($c.toString());"""
      case ArrayType(et, _) =>
        (c, evPrim, evNull) => {
          val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder])
          val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder])
          val writeArrayElemCode = writeArrayToStringBuilder(et, c, buffer, ctx)
          code"""
             |$bufferClass $buffer = new $bufferClass();
             |$writeArrayElemCode;
             |$evPrim = $buffer.build();
           """.stripMargin
        }
      case MapType(kt, vt, _) =>
        (c, evPrim, evNull) => {
          val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder])
          val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder])
          val writeMapElemCode = writeMapToStringBuilder(kt, vt, c, buffer, ctx)
          code"""
             |$bufferClass $buffer = new $bufferClass();
             |$writeMapElemCode;
             |$evPrim = $buffer.build();
           """.stripMargin
        }
      case StructType(fields) =>
        (c, evPrim, evNull) => {
          val row = ctx.freshVariable("row", classOf[InternalRow])
          val buffer = ctx.freshVariable("buffer", classOf[UTF8StringBuilder])
          val bufferClass = JavaCode.javaType(classOf[UTF8StringBuilder])
          val writeStructCode = writeStructToStringBuilder(fields.map(_.dataType), row, buffer, ctx)
          code"""
             |InternalRow $row = $c;
             |$bufferClass $buffer = new $bufferClass();
             |$writeStructCode
             |$evPrim = $buffer.build();
           """.stripMargin
        }
      case pudt: PythonUserDefinedType => castToStringCode(pudt.sqlType, ctx)
      case udt: UserDefinedType[_] =>
        val udtRef = JavaCode.global(ctx.addReferenceObj("udt", udt), udt.sqlType)
        (c, evPrim, evNull) => {
          code"$evPrim = UTF8String.fromString($udtRef.deserialize($c).toString());"
        }
      case i: YearMonthIntervalType =>
        val iu = IntervalUtils.getClass.getName.stripSuffix("$")
        val iss = IntervalStringStyles.getClass.getName.stripSuffix("$")
        val style = s"$iss$$.MODULE$$.ANSI_STYLE()"
        (c, evPrim, _) =>
          code"""
            $evPrim = UTF8String.fromString($iu.toYearMonthIntervalString($c, $style,
              (byte)${i.startField}, (byte)${i.endField}));
          """
      case i: DayTimeIntervalType =>
        val iu = IntervalUtils.getClass.getName.stripSuffix("$")
        val iss = IntervalStringStyles.getClass.getName.stripSuffix("$")
        val style = s"$iss$$.MODULE$$.ANSI_STYLE()"
        (c, evPrim, _) =>
          code"""
            $evPrim = UTF8String.fromString($iu.toDayTimeIntervalString($c, $style,
              (byte)${i.startField}, (byte)${i.endField}));
          """
      // In ANSI mode, Spark always uses the plain string representation when casting Decimal
      // values to strings. Otherwise, the cast uses `BigDecimal.toString`, which may use
      // scientific notation if an exponent is needed.
      case _: DecimalType if ansiEnabled =>
        (c, evPrim, _) => code"$evPrim = UTF8String.fromString($c.toPlainString());"
      case _ =>
        (c, evPrim, evNull) => code"$evPrim = UTF8String.fromString(String.valueOf($c));"
    }
  }

  /**
   * Returns the code generator for casting a value of type `from` to binary.
   *
   * Strings are converted with `getBytes()`; integral values are converted through
   * `NumberConverter.toBinary`. No other source type is handled by this match.
   */
  private[this] def castToBinaryCode(from: DataType): CastFunction = from match {
    case StringType =>
      (c, evPrim, _) => code"$evPrim = $c.getBytes();"
    case _: IntegralType =>
      // Java-visible name of the NumberConverter object (trailing `$` removed).
      val converterCls = NumberConverter.getClass.getName.stripSuffix("$")
      (c, evPrim, _) => code"$evPrim = $converterCls.toBinary($c);"
  }

  /**
   * Returns the code generator for casting a value of type `from` to a date.
   *
   *  - String: when `ansiEnabled`, the generated code calls `stringToDateAnsi` with the error
   *    context; otherwise it calls `stringToDate` and marks the result null when the returned
   *    `Option` is empty.
   *  - Timestamp: `microsToDays` using this cast's `zoneId`.
   *  - TimestampNTZ: `microsToDays` using UTC.
   *  - Any other type: the result is always null.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToDateCode(
      from: DataType,
      ctx: CodegenContext): CastFunction = {
    from match {
      case StringType =>
        val intOpt = ctx.freshVariable("intOpt", classOf[Option[Integer]])
        (c, evPrim, evNull) =>
          if (ansiEnabled) {
            val errorContext = getContextOrNullCode(ctx)
            code"""
              $evPrim = $dateTimeUtilsCls.stringToDateAnsi($c, $errorContext);
            """
          } else {
            code"""
              scala.Option<Integer> $intOpt =
                org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToDate($c);
              if ($intOpt.isDefined()) {
                $evPrim = ((Integer) $intOpt.get()).intValue();
              } else {
                $evNull = true;
              }
            """
          }

      case TimestampType =>
        // The zone id is registered as a reference object so the generated code can use it.
        val zidClass = classOf[ZoneId]
        val zid = JavaCode.global(ctx.addReferenceObj("zoneId", zoneId, zidClass.getName), zidClass)
        (c, evPrim, evNull) =>
          code"""$evPrim =
            org.apache.spark.sql.catalyst.util.DateTimeUtils.microsToDays($c, $zid);"""
      case TimestampNTZType =>
        (c, evPrim, evNull) =>
          code"$evPrim = $dateTimeUtilsCls.microsToDays($c, java.time.ZoneOffset.UTC);"
      case _ =>
        // Every remaining source type casts to null.
        (c, evPrim, evNull) => code"$evNull = true;"
    }
  }

  /**
   * Generates code that changes the precision and scale of the decimal variable `d` to those of
   * `decimalType` and assigns it to `evPrim`.
   *
   * When `canNullSafeCast` is true the generated code ignores the boolean result of
   * `changePrecision`. Otherwise, a `false` result either sets `evNull` (when `nullOnOverflow`)
   * or throws `QueryExecutionErrors.cannotChangeDecimalPrecisionError`.
   *
   * @param d the decimal variable to adjust in place
   * @param decimalType the target decimal type supplying precision and scale
   * @param evPrim the variable that receives the result
   * @param evNull the variable that flags a null result
   * @param canNullSafeCast whether the result of `changePrecision` can be left unchecked
   * @param ctx the codegen context
   * @param nullOnOverflow whether a failed precision change yields null instead of an error
   */
  private[this] def changePrecision(
      d: ExprValue,
      decimalType: DecimalType,
      evPrim: ExprValue,
      evNull: ExprValue,
      canNullSafeCast: Boolean,
      ctx: CodegenContext,
      nullOnOverflow: Boolean): Block = {
    if (canNullSafeCast) {
      code"""
         |$d.changePrecision(${decimalType.precision}, ${decimalType.scale});
         |$evPrim = $d;
       """.stripMargin
    } else {
      // NOTE(review): the second argument presumably requests a real query context only when an
      // error can actually be thrown -- confirm against getContextOrNullCode.
      val errorContextCode = getContextOrNullCode(ctx, !nullOnOverflow)
      val overflowCode = if (nullOnOverflow) {
        s"$evNull = true;"
      } else {
        s"""
           |throw QueryExecutionErrors.cannotChangeDecimalPrecisionError(
           |  $d, ${decimalType.precision}, ${decimalType.scale}, $errorContextCode);
         """.stripMargin
      }
      code"""
         |if ($d.changePrecision(${decimalType.precision}, ${decimalType.scale})) {
         |  $evPrim = $d;
         |} else {
         |  $overflowCode
         |}
       """.stripMargin
    }
  }

  /**
   * Overload of `changePrecision` whose overflow handling follows the ANSI setting: a failed
   * precision change yields null when ANSI mode is off and an error when it is on.
   */
  private[this] def changePrecision(
      d: ExprValue,
      decimalType: DecimalType,
      evPrim: ExprValue,
      evNull: ExprValue,
      canNullSafeCast: Boolean,
      ctx: CodegenContext): Block = {
    val overflowYieldsNull = !ansiEnabled
    changePrecision(d, decimalType, evPrim, evNull, canNullSafeCast, ctx, overflowYieldsNull)
  }

  /**
   * Returns the code generator for casting a value of type `from` to the decimal type `target`.
   *
   * Every branch except `DateType` first builds a temporary `Decimal` from the input and then
   * applies `changePrecision` to fit it into `target`. `DateType` always yields null.
   *
   * @param from the source data type
   * @param target the decimal type to cast to
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToDecimalCode(
      from: DataType,
      target: DecimalType,
      ctx: CodegenContext): CastFunction = {
    val tmp = ctx.freshVariable("tmpDecimal", classOf[Decimal])
    val canNullSafeCast = Cast.canNullSafeCastToDecimal(from, target)
    from match {
      case StringType if !ansiEnabled =>
        // Non-ANSI: a null result from `Decimal.fromString` becomes a null cast result.
        (c, evPrim, evNull) =>
          code"""
              Decimal $tmp = Decimal.fromString($c);
              if ($tmp == null) {
                $evNull = true;
              } else {
                ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
              }
          """
      case StringType if ansiEnabled =>
        // ANSI: `Decimal.fromStringANSI` receives the target type and the error context.
        val errorContext = getContextOrNullCode(ctx)
        val toType = ctx.addReferenceObj("toType", target)
        (c, evPrim, evNull) =>
          code"""
              Decimal $tmp = Decimal.fromStringANSI($c, $toType, $errorContext);
              ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
          """
      case BooleanType =>
        (c, evPrim, evNull) =>
          code"""
            Decimal $tmp = $c ? Decimal.apply(1) : Decimal.apply(0);
            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
          """
      case DateType =>
        // date can't cast to decimal in Hive
        (c, evPrim, evNull) => code"$evNull = true;"
      case TimestampType =>
        // Note that we lose precision here.
        (c, evPrim, evNull) =>
          code"""
            Decimal $tmp = Decimal.apply(
              scala.math.BigDecimal.valueOf(${timestampToDoubleCode(c)}));
            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
          """
      case DecimalType() =>
        // Work on a clone because `changePrecision` is invoked on the temporary itself.
        (c, evPrim, evNull) =>
          code"""
            Decimal $tmp = $c.clone();
            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
          """
      case x: IntegralType =>
        (c, evPrim, evNull) =>
          code"""
            Decimal $tmp = Decimal.apply((long) $c);
            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
          """
      case x: FractionalType =>
        // All other numeric types can be represented precisely as Doubles
        // A NumberFormatException raised by the generated code is turned into a null result.
        (c, evPrim, evNull) =>
          code"""
            try {
              Decimal $tmp = Decimal.apply(scala.math.BigDecimal.valueOf((double) $c));
              ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx)}
            } catch (java.lang.NumberFormatException e) {
              $evNull = true;
            }
          """
      case x: DayTimeIntervalType =>
        // nullOnOverflow is passed as false, so overflow is not turned into null here,
        // independent of `ansiEnabled`.
        (c, evPrim, evNull) =>
          val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
          code"""
            Decimal $tmp = $u.dayTimeIntervalToDecimal($c, (byte)${x.endField});
            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
          """
      case x: YearMonthIntervalType =>
        // Same as above: nullOnOverflow is false for interval sources.
        (c, evPrim, evNull) =>
          val u = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
          val tmpYm = ctx.freshVariable("tmpYm", classOf[Int])
          code"""
            int $tmpYm = $u.yearMonthIntervalToInt($c, (byte)${x.startField}, (byte)${x.endField});
            Decimal $tmp = Decimal.apply($tmpYm);
            ${changePrecision(tmp, target, evPrim, evNull, canNullSafeCast, ctx, false)}
          """
    }
  }

  /**
   * Returns the code generator for casting a value of type `from` to a timestamp.
   *
   *  - String: `stringToTimestampAnsi` when `ansiEnabled`; otherwise `stringToTimestamp`, with an
   *    empty `Option` producing null. Both use this cast's `zoneId`.
   *  - Boolean: true becomes `1L` and false becomes `0L` in the raw timestamp value.
   *  - Integral: the value is multiplied by `MICROS_PER_SECOND`.
   *  - Date: `daysToMicros` with `zoneId`.
   *  - TimestampNTZ: `convertTz` from `zoneId` to UTC (argument order as passed below).
   *  - Decimal: multiplied by `MICROS_PER_SECOND` and narrowed with `longValue()`.
   *  - Double / Float: `doubleToTimestampAnsi` when `ansiEnabled`; otherwise NaN and infinite
   *    inputs produce null and other values are multiplied by `MICROS_PER_SECOND` and cast to
   *    long.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToTimestampCode(
      from: DataType,
      ctx: CodegenContext): CastFunction = from match {
    case StringType =>
      val zoneIdClass = classOf[ZoneId]
      val zid = JavaCode.global(
        ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName),
        zoneIdClass)
      val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
      (c, evPrim, evNull) =>
        if (ansiEnabled) {
          val errorContext = getContextOrNullCode(ctx)
          code"""
            $evPrim = $dateTimeUtilsCls.stringToTimestampAnsi($c, $zid, $errorContext);
           """
        } else {
          code"""
            scala.Option<Long> $longOpt =
              org.apache.spark.sql.catalyst.util.DateTimeUtils.stringToTimestamp($c, $zid);
            if ($longOpt.isDefined()) {
              $evPrim = ((Long) $longOpt.get()).longValue();
            } else {
              $evNull = true;
            }
           """
        }
    case BooleanType =>
      (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;"
    case _: IntegralType =>
      (c, evPrim, evNull) => code"$evPrim = ${longToTimeStampCode(c)};"
    case DateType =>
      val zoneIdClass = classOf[ZoneId]
      val zid = JavaCode.global(
        ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName),
        zoneIdClass)
      (c, evPrim, evNull) =>
        code"""$evPrim =
          org.apache.spark.sql.catalyst.util.DateTimeUtils.daysToMicros($c, $zid);"""
    case TimestampNTZType =>
      val zoneIdClass = classOf[ZoneId]
      val zid = JavaCode.global(
        ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName),
        zoneIdClass)
      (c, evPrim, evNull) =>
        code"$evPrim = $dateTimeUtilsCls.convertTz($c, $zid, java.time.ZoneOffset.UTC);"
    case DecimalType() =>
      (c, evPrim, evNull) => code"$evPrim = ${decimalToTimestampCode(c)};"
    case DoubleType =>
      (c, evPrim, evNull) =>
        if (ansiEnabled) {
          val errorContext = getContextOrNullCode(ctx)
          code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi($c, $errorContext);"
        } else {
          code"""
            if (Double.isNaN($c) || Double.isInfinite($c)) {
              $evNull = true;
            } else {
              $evPrim = (long)($c * $MICROS_PER_SECOND);
            }
          """
        }
    case FloatType =>
      // Floats are widened to double before conversion in both modes.
      (c, evPrim, evNull) =>
        if (ansiEnabled) {
          val errorContext = getContextOrNullCode(ctx)
          code"$evPrim = $dateTimeUtilsCls.doubleToTimestampAnsi((double)$c, $errorContext);"
        } else {
          code"""
            if (Float.isNaN($c) || Float.isInfinite($c)) {
              $evNull = true;
            } else {
              $evPrim = (long)((double)$c * $MICROS_PER_SECOND);
            }
          """
        }
  }

  /**
   * Returns the code generator for casting a value of type `from` to a timestamp without time
   * zone.
   *
   *  - String: `stringToTimestampWithoutTimeZoneAnsi` when `ansiEnabled`; otherwise
   *    `stringToTimestampWithoutTimeZone`, with an empty `Option` producing null.
   *  - Date: `daysToMicros` with UTC.
   *  - Timestamp: `convertTz` from UTC to this cast's `zoneId` (argument order as passed below).
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToTimestampNTZCode(
      from: DataType,
      ctx: CodegenContext): CastFunction = from match {
    case StringType =>
      val longOpt = ctx.freshVariable("longOpt", classOf[Option[Long]])
      (c, evPrim, evNull) =>
        if (ansiEnabled) {
          val errorContext = getContextOrNullCode(ctx)
          code"""
            $evPrim = $dateTimeUtilsCls.stringToTimestampWithoutTimeZoneAnsi($c, $errorContext);
           """
        } else {
          code"""
            scala.Option<Long> $longOpt = $dateTimeUtilsCls.stringToTimestampWithoutTimeZone($c);
            if ($longOpt.isDefined()) {
              $evPrim = ((Long) $longOpt.get()).longValue();
            } else {
              $evNull = true;
            }
           """
        }
    case DateType =>
      (c, evPrim, evNull) =>
        code"$evPrim = $dateTimeUtilsCls.daysToMicros($c, java.time.ZoneOffset.UTC);"
    case TimestampType =>
      val zoneIdClass = classOf[ZoneId]
      val zid = JavaCode.global(
        ctx.addReferenceObj("zoneId", zoneId, zoneIdClass.getName),
        zoneIdClass)
      (c, evPrim, evNull) =>
        code"$evPrim = $dateTimeUtilsCls.convertTz($c, java.time.ZoneOffset.UTC, $zid);"
  }

  /**
   * Returns the code generator for casting a string to an interval. The generated code calls
   * `IntervalUtils.safeStringToInterval` and flags the result as null when that call returns
   * null. Only `StringType` sources are handled.
   */
  private[this] def castToIntervalCode(from: DataType): CastFunction = from match {
    case StringType =>
      // Java-visible name of the IntervalUtils object (trailing `$` removed).
      val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      (c, evPrim, evNull) =>
        code"""$evPrim = $util.safeStringToInterval($c);
           if(${evPrim} == null) {
             ${evNull} = true;
           }
         """.stripMargin

  }

  /**
   * Returns the code generator for casting a value of type `from` to the day-time interval type
   * `it`. Each branch delegates to an `IntervalUtils` method, passing the start and/or end field
   * of `it` as byte arguments.
   *
   * @param from the source data type (string, day-time interval, integral or fixed decimal)
   * @param it the target day-time interval type
   */
  private[this] def castToDayTimeIntervalCode(
      from: DataType,
      it: DayTimeIntervalType): CastFunction = from match {
    case StringType =>
      val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      (c, evPrim, _) =>
        code"""
          $evPrim = $util.castStringToDTInterval($c, (byte)${it.startField}, (byte)${it.endField});
        """
    case _: DayTimeIntervalType =>
      // Interval-to-interval: round-trips through a duration using the target end field.
      val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      (c, evPrim, _) =>
        code"""
          $evPrim = $util.durationToMicros($util.microsToDuration($c), (byte)${it.endField});
        """
    case x: IntegralType =>
      val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      // Long inputs use the long-specific helper; every other integral type uses the int one.
      if (x == LongType) {
        (c, evPrim, _) =>
          code"""
            $evPrim = $iu.longToDayTimeInterval($c, (byte)${it.startField}, (byte)${it.endField});
          """
      } else {
        (c, evPrim, _) =>
          code"""
            $evPrim = $iu.intToDayTimeInterval($c, (byte)${it.startField}, (byte)${it.endField});
          """
      }
    case DecimalType.Fixed(p, s) =>
      // The source precision and scale are passed through to the helper.
      val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      (c, evPrim, _) =>
        code"""
          $evPrim = $iu.decimalToDayTimeInterval(
            $c, $p, $s, (byte)${it.startField}, (byte)${it.endField});
        """
  }

  /**
   * Returns the code generator for casting a value of type `from` to the year-month interval
   * type `it`. Each branch delegates to an `IntervalUtils` method, passing the start and/or end
   * field of `it` as byte arguments.
   *
   * @param from the source data type (string, year-month interval, integral or fixed decimal)
   * @param it the target year-month interval type
   */
  private[this] def castToYearMonthIntervalCode(
      from: DataType,
      it: YearMonthIntervalType): CastFunction = from match {
    case StringType =>
      val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      (c, evPrim, _) =>
        code"""
          $evPrim = $util.castStringToYMInterval($c, (byte)${it.startField}, (byte)${it.endField});
        """
    case _: YearMonthIntervalType =>
      // Interval-to-interval: round-trips through a period using the target end field.
      val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      (c, evPrim, _) =>
        code"""
          $evPrim = $util.periodToMonths($util.monthsToPeriod($c), (byte)${it.endField});
        """
    case x: IntegralType =>
      val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      // Long inputs use the long-specific helper; every other integral type uses the int one.
      if (x == LongType) {
        (c, evPrim, _) =>
          code"""
            $evPrim = $iu.longToYearMonthInterval($c, (byte)${it.startField}, (byte)${it.endField});
          """
      } else {
        (c, evPrim, _) =>
          code"""
            $evPrim = $iu.intToYearMonthInterval($c, (byte)${it.startField}, (byte)${it.endField});
          """
      }
    case DecimalType.Fixed(p, s) =>
      // The source precision and scale are passed through to the helper.
      val iu = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
      (c, evPrim, _) =>
        code"""
          $evPrim = $iu.decimalToYearMonthInterval(
            $c, $p, $s, (byte)${it.startField}, (byte)${it.endField});
        """
  }

  /**
   * Generates a Java expression that multiplies the decimal `d` (as a `java.math.BigDecimal`)
   * by `MICROS_PER_SECOND` and narrows the product with `longValue()`.
   */
  private[this] def decimalToTimestampCode(d: ExprValue): Block = {
    val block = inline"new java.math.BigDecimal($MICROS_PER_SECOND)"
    code"($d.toBigDecimal().bigDecimal().multiply($block)).longValue()"
  }
  // Generates a Java expression multiplying the long `l` by MICROS_PER_SECOND.
  private[this] def longToTimeStampCode(l: ExprValue): Block = code"$l * (long)$MICROS_PER_SECOND"
  // Generates a Java expression dividing the timestamp `ts` by MICROS_PER_SECOND with
  // `Math.floorDiv`, i.e. rounding toward negative infinity.
  private[this] def timestampToLongCode(ts: ExprValue): Block =
    code"java.lang.Math.floorDiv($ts, $MICROS_PER_SECOND)"
  // Generates a Java expression dividing the timestamp `ts` by MICROS_PER_SECOND using
  // floating-point (double) division.
  private[this] def timestampToDoubleCode(ts: ExprValue): Block =
    code"$ts / (double)$MICROS_PER_SECOND"

  /**
   * Returns the code generator for casting a value of type `from` to boolean.
   *
   *  - String: true/false according to `StringUtils.isTrueString` / `isFalseString`; any other
   *    input throws `invalidInputSyntaxForBooleanError` when `ansiEnabled` and yields null
   *    otherwise.
   *  - Timestamp and numeric types: true when the value is non-zero.
   *  - Decimal: true when `isZero()` is false.
   *  - Date: always null.
   *
   * @param from the source data type
   * @param ctx the codegen context, used to obtain the error context in ANSI mode
   */
  private[this] def castToBooleanCode(
      from: DataType,
      ctx: CodegenContext): CastFunction = from match {
    case StringType =>
      val stringUtils = inline"${StringUtils.getClass.getName.stripSuffix("$")}"
      (c, evPrim, evNull) =>
        // Code emitted for strings that are neither a true-string nor a false-string.
        val castFailureCode = if (ansiEnabled) {
          val errorContext = getContextOrNullCode(ctx)
          s"throw QueryExecutionErrors.invalidInputSyntaxForBooleanError($c, $errorContext);"
        } else {
          s"$evNull = true;"
        }
        code"""
          if ($stringUtils.isTrueString($c)) {
            $evPrim = true;
          } else if ($stringUtils.isFalseString($c)) {
            $evPrim = false;
          } else {
            $castFailureCode
          }
        """
    case TimestampType =>
      (c, evPrim, evNull) => code"$evPrim = $c != 0;"
    case DateType =>
      // Hive would return null when cast from date to boolean
      (c, evPrim, evNull) => code"$evNull = true;"
    case DecimalType() =>
      (c, evPrim, evNull) => code"$evPrim = !$c.isZero();"
    case n: NumericType =>
      (c, evPrim, evNull) => code"$evPrim = $c != 0;"
  }

  /**
   * Returns the code generator for casting a timestamp to the Java integral type named by
   * `integralType`.
   *
   * The timestamp is first floor-divided by `MICROS_PER_SECOND` (see `timestampToLongCode`).
   * When `ansiEnabled`, the generated code checks that the quotient survives a round trip
   * through `integralType` and throws `castingCauseOverflowError` otherwise; when not, the
   * quotient is narrowed with a plain Java cast.
   *
   * @param ctx the codegen context
   * @param integralType the Java primitive type name used in the generated cast
   * @param from the source data type, reported in the overflow error
   * @param to the target data type, reported in the overflow error
   */
  private[this] def castTimestampToIntegralTypeCode(
      ctx: CodegenContext,
      integralType: String,
      from: DataType,
      to: DataType): CastFunction = {
    if (ansiEnabled) {
      val longValue = ctx.freshName("longValue")
      val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
      val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
      (c, evPrim, _) =>
        code"""
          long $longValue = ${timestampToLongCode(c)};
          if ($longValue == ($integralType) $longValue) {
            $evPrim = ($integralType) $longValue;
          } else {
            throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
          }
        """
    } else {
      (c, evPrim, _) => code"$evPrim = ($integralType) ${timestampToLongCode(c)};"
    }
  }

  /**
   * Returns the code generator for casting a day-time interval to an integral type by calling
   * `IntervalUtils.dayTimeIntervalTo<integralType>`.
   *
   * @param startField the start field of the source interval type
   * @param endField the end field of the source interval type
   * @param integralType the method-name suffix; callers in this file pass "Byte", "Short",
   *                     "Int" or "Long"
   */
  private[this] def castDayTimeIntervalToIntegralTypeCode(
      startField: Byte,
      endField: Byte,
      integralType: String): CastFunction = {
    val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
    (c, evPrim, _) =>
      code"""
        $evPrim = $util.dayTimeIntervalTo$integralType($c, (byte)$startField, (byte)$endField);
      """
  }

  /**
   * Returns the code generator for casting a year-month interval to an integral type by calling
   * `IntervalUtils.yearMonthIntervalTo<integralType>`.
   *
   * @param startField the start field of the source interval type
   * @param endField the end field of the source interval type
   * @param integralType the method-name suffix; callers in this file pass "Byte", "Short" or
   *                     "Int"
   */
  private[this] def castYearMonthIntervalToIntegralTypeCode(
      startField: Byte,
      endField: Byte,
      integralType: String): CastFunction = {
    val util = IntervalUtils.getClass.getCanonicalName.stripSuffix("$")
    (c, evPrim, _) =>
      code"""
        $evPrim = $util.yearMonthIntervalTo$integralType($c, (byte)$startField, (byte)$endField);
      """
  }

  /**
   * Returns the code generator for casting a decimal to the integral type named by
   * `integralType`. The generated code calls `roundTo<Type>()` on the decimal when ANSI mode is
   * enabled and `to<Type>()` otherwise, where `<Type>` is `integralType` capitalized.
   */
  private[this] def castDecimalToIntegralTypeCode(integralType: String): CastFunction = {
    val typeSuffix = integralType.capitalize
    val conversion = if (ansiEnabled) s"roundTo$typeSuffix" else s"to$typeSuffix"
    (c, evPrim, _) => code"$evPrim = $c.$conversion();"
  }

  /**
   * Returns the code generator for an overflow-checked narrowing cast between integral types.
   * Must only be called when `ansiEnabled` (asserted).
   *
   * The generated code accepts the value when it is unchanged by a cast to `integralType` and
   * throws `castingCauseOverflowError` otherwise.
   *
   * @param ctx the codegen context used to register the data types as reference objects
   * @param integralType the Java primitive type name used in the generated cast
   * @param from the source data type, reported in the overflow error
   * @param to the target data type, reported in the overflow error
   */
  private[this] def castIntegralTypeToIntegralTypeExactCode(
      ctx: CodegenContext,
      integralType: String,
      from: DataType,
      to: DataType): CastFunction = {
    assert(ansiEnabled)
    val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
    val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
    (c, evPrim, _) =>
      code"""
        if ($c == ($integralType) $c) {
          $evPrim = ($integralType) $c;
        } else {
          throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
        }
      """
  }


  /**
   * Returns the minimum and maximum values of the given integral type as Java literal strings.
   * The type name is matched case-insensitively; `long` bounds carry an `L` suffix. Names other
   * than long/int/short/byte are not handled by the match.
   */
  private[this] def lowerAndUpperBound(integralType: String): (String, String) = {
    integralType.toLowerCase(Locale.ROOT) match {
      case "long" => (s"${Long.MinValue}L", s"${Long.MaxValue}L")
      case "int" => (Int.MinValue.toString, Int.MaxValue.toString)
      case "short" => (Short.MinValue.toString, Short.MaxValue.toString)
      case "byte" => (Byte.MinValue.toString, Byte.MaxValue.toString)
    }
  }

  /**
   * Returns the code generator for an overflow-checked cast from a floating-point value to the
   * integral type named by `integralType`. Must only be called when `ansiEnabled` (asserted).
   *
   * The generated code throws `castingCauseOverflowError` when the value falls outside the
   * bounds returned by `lowerAndUpperBound`; otherwise it applies a plain Java cast.
   *
   * @param ctx the codegen context used to register the data types as reference objects
   * @param integralType the Java primitive type name used in the generated cast
   * @param from the source data type, reported in the overflow error
   * @param to the target data type, reported in the overflow error
   */
  private[this] def castFractionToIntegralTypeCode(
      ctx: CodegenContext,
      integralType: String,
      from: DataType,
      to: DataType): CastFunction = {
    assert(ansiEnabled)
    val (min, max) = lowerAndUpperBound(integralType)
    val mathClass = classOf[Math].getName
    val fromDt = ctx.addReferenceObj("from", from, from.getClass.getName)
    val toDt = ctx.addReferenceObj("to", to, to.getClass.getName)
    // When casting floating-point values to integral types, Spark uses the method
    // `Numeric.toInt` or `Numeric.toLong` directly. For positive floating-point values this is
    // equivalent to `Math.floor`; for negative ones it is equivalent to `Math.ceil`.
    // So we can use the condition `Math.floor(x) <= upperBound && Math.ceil(x) >= lowerBound`
    // to check whether the floating-point value x is in the range of an integral type after
    // rounding.
    (c, evPrim, _) =>
      code"""
        if ($mathClass.floor($c) <= $max && $mathClass.ceil($c) >= $min) {
          $evPrim = ($integralType) $c;
        } else {
          throw QueryExecutionErrors.castingCauseOverflowError($c, $fromDt, $toDt);
        }
      """
  }

  /**
   * Returns the code generator for casting a value of type `from` to byte.
   *
   *  - String: `UTF8StringUtils.toByteExact` when `ansiEnabled`; otherwise `UTF8String.toByte`
   *    with an `IntWrapper`, yielding null when parsing fails.
   *  - Boolean: 1 for true, 0 for false.
   *  - Date: always null.
   *  - Timestamp, decimal and interval types: delegated to the dedicated helpers.
   *  - Short/int/long and float/double under ANSI: overflow-checked casts.
   *  - Remaining numeric cases: a plain Java `(byte)` cast.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToByteCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
    case StringType if ansiEnabled =>
      val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
      val errorContext = getContextOrNullCode(ctx)
      (c, evPrim, evNull) => code"$evPrim = $stringUtils.toByteExact($c, $errorContext);"
    case StringType =>
      val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
      (c, evPrim, evNull) =>
        code"""
          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
          if ($c.toByte($wrapper)) {
            $evPrim = (byte) $wrapper.value;
          } else {
            $evNull = true;
          }
          $wrapper = null;
        """
    case BooleanType =>
      (c, evPrim, evNull) => code"$evPrim = $c ? (byte) 1 : (byte) 0;"
    case DateType =>
      (c, evPrim, evNull) => code"$evNull = true;"
    case TimestampType => castTimestampToIntegralTypeCode(ctx, "byte", from, ByteType)
    case DecimalType() => castDecimalToIntegralTypeCode("byte")
    // The ANSI-guarded cases must precede the generic NumericType case below.
    case ShortType | IntegerType | LongType if ansiEnabled =>
      castIntegralTypeToIntegralTypeExactCode(ctx, "byte", from, ByteType)
    case FloatType | DoubleType if ansiEnabled =>
      castFractionToIntegralTypeCode(ctx, "byte", from, ByteType)
    case x: NumericType =>
      (c, evPrim, evNull) => code"$evPrim = (byte) $c;"
    case x: DayTimeIntervalType =>
      castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Byte")
    case x: YearMonthIntervalType =>
      castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Byte")
  }

  /**
   * Returns the code generator for casting a value of type `from` to short.
   *
   *  - String: `UTF8StringUtils.toShortExact` when `ansiEnabled`; otherwise `UTF8String.toShort`
   *    with an `IntWrapper`, yielding null when parsing fails.
   *  - Boolean: 1 for true, 0 for false.
   *  - Date: always null.
   *  - Timestamp, decimal and interval types: delegated to the dedicated helpers.
   *  - Int/long and float/double under ANSI: overflow-checked casts.
   *  - Remaining numeric cases: a plain Java `(short)` cast.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToShortCode(
      from: DataType,
      ctx: CodegenContext): CastFunction = from match {
    case StringType if ansiEnabled =>
      val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
      val errorContext = getContextOrNullCode(ctx)
      (c, evPrim, evNull) => code"$evPrim = $stringUtils.toShortExact($c, $errorContext);"
    case StringType =>
      val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
      (c, evPrim, evNull) =>
        code"""
          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
          if ($c.toShort($wrapper)) {
            $evPrim = (short) $wrapper.value;
          } else {
            $evNull = true;
          }
          $wrapper = null;
        """
    case BooleanType =>
      (c, evPrim, evNull) => code"$evPrim = $c ? (short) 1 : (short) 0;"
    case DateType =>
      (c, evPrim, evNull) => code"$evNull = true;"
    case TimestampType => castTimestampToIntegralTypeCode(ctx, "short", from, ShortType)
    case DecimalType() => castDecimalToIntegralTypeCode("short")
    // The ANSI-guarded cases must precede the generic NumericType case below.
    case IntegerType | LongType if ansiEnabled =>
      castIntegralTypeToIntegralTypeExactCode(ctx, "short", from, ShortType)
    case FloatType | DoubleType if ansiEnabled =>
      castFractionToIntegralTypeCode(ctx, "short", from, ShortType)
    case x: NumericType =>
      (c, evPrim, evNull) => code"$evPrim = (short) $c;"
    case x: DayTimeIntervalType =>
      castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Short")
    case x: YearMonthIntervalType =>
      castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Short")
  }

  /**
   * Returns the code generator for casting a value of type `from` to int.
   *
   *  - String: `UTF8StringUtils.toIntExact` when `ansiEnabled`; otherwise `UTF8String.toInt`
   *    with an `IntWrapper`, yielding null when parsing fails.
   *  - Boolean: 1 for true, 0 for false.
   *  - Date: always null.
   *  - Timestamp, decimal and interval types: delegated to the dedicated helpers.
   *  - Long and float/double under ANSI: overflow-checked casts.
   *  - Remaining numeric cases: a plain Java `(int)` cast.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToIntCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
    case StringType if ansiEnabled =>
      val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
      val errorContext = getContextOrNullCode(ctx)
      (c, evPrim, evNull) => code"$evPrim = $stringUtils.toIntExact($c, $errorContext);"
    case StringType =>
      val wrapper = ctx.freshVariable("intWrapper", classOf[UTF8String.IntWrapper])
      (c, evPrim, evNull) =>
        code"""
          UTF8String.IntWrapper $wrapper = new UTF8String.IntWrapper();
          if ($c.toInt($wrapper)) {
            $evPrim = $wrapper.value;
          } else {
            $evNull = true;
          }
          $wrapper = null;
        """
    case BooleanType =>
      (c, evPrim, evNull) => code"$evPrim = $c ? 1 : 0;"
    case DateType =>
      (c, evPrim, evNull) => code"$evNull = true;"
    case TimestampType => castTimestampToIntegralTypeCode(ctx, "int", from, IntegerType)
    case DecimalType() => castDecimalToIntegralTypeCode("int")
    // The ANSI-guarded cases must precede the generic NumericType case below.
    case LongType if ansiEnabled =>
      castIntegralTypeToIntegralTypeExactCode(ctx, "int", from, IntegerType)
    case FloatType | DoubleType if ansiEnabled =>
      castFractionToIntegralTypeCode(ctx, "int", from, IntegerType)
    case x: NumericType =>
      (c, evPrim, evNull) => code"$evPrim = (int) $c;"
    case x: DayTimeIntervalType =>
      castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Int")
    case x: YearMonthIntervalType =>
      castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Int")
  }

  /**
   * Returns the code generator for casting a value of type `from` to long.
   *
   *  - String: `UTF8StringUtils.toLongExact` when `ansiEnabled`; otherwise `UTF8String.toLong`
   *    with a `LongWrapper`, yielding null when parsing fails.
   *  - Boolean: 1L for true, 0L for false.
   *  - Date: always null.
   *  - Timestamp: floor division by `MICROS_PER_SECOND`, with no overflow check.
   *  - Decimal and interval types: delegated to the dedicated helpers.
   *  - Float/double under ANSI: overflow-checked cast.
   *  - Remaining numeric cases: a plain Java `(long)` cast.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and object references
   */
  private[this] def castToLongCode(from: DataType, ctx: CodegenContext): CastFunction = from match {
    case StringType if ansiEnabled =>
      val stringUtils = UTF8StringUtils.getClass.getCanonicalName.stripSuffix("$")
      val errorContext = getContextOrNullCode(ctx)
      (c, evPrim, evNull) => code"$evPrim = $stringUtils.toLongExact($c, $errorContext);"
    case StringType =>
      val wrapper = ctx.freshVariable("longWrapper", classOf[UTF8String.LongWrapper])
      (c, evPrim, evNull) =>
        code"""
          UTF8String.LongWrapper $wrapper = new UTF8String.LongWrapper();
          if ($c.toLong($wrapper)) {
            $evPrim = $wrapper.value;
          } else {
            $evNull = true;
          }
          $wrapper = null;
        """
    case BooleanType =>
      (c, evPrim, evNull) => code"$evPrim = $c ? 1L : 0L;"
    case DateType =>
      (c, evPrim, evNull) => code"$evNull = true;"
    case TimestampType =>
      (c, evPrim, evNull) => code"$evPrim = (long) ${timestampToLongCode(c)};"
    case DecimalType() => castDecimalToIntegralTypeCode("long")
    // The ANSI-guarded case must precede the generic NumericType case below.
    case FloatType | DoubleType if ansiEnabled =>
      castFractionToIntegralTypeCode(ctx, "long", from, LongType)
    case x: NumericType =>
      (c, evPrim, evNull) => code"$evPrim = (long) $c;"
    case x: DayTimeIntervalType =>
      castDayTimeIntervalToIntegralTypeCode(x.startField, x.endField, "Long")
    case x: YearMonthIntervalType =>
      // "Int" is passed on purpose: the generated `yearMonthIntervalToInt` result is assigned
      // to the long result variable. NOTE(review): presumably IntervalUtils has no Long variant
      // for year-month intervals (only Byte/Int/Short are imported) -- confirm.
      castYearMonthIntervalToIntegralTypeCode(x.startField, x.endField, "Int")
  }

  /**
   * Returns the code generator for casting a value of type `from` to float.
   *
   * For strings the generated code first tries `Float.valueOf`; on `NumberFormatException` it
   * falls back to `Cast.processFloatingPointSpecialLiterals`. If that also returns null, the
   * cast throws `invalidInputInCastToNumberError` when `ansiEnabled` and yields null otherwise.
   * Dates always cast to null; the remaining branches are direct conversions.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and the error context
   */
  private[this] def castToFloatCode(from: DataType, ctx: CodegenContext): CastFunction = {
    from match {
      case StringType =>
        val floatStr = ctx.freshVariable("floatStr", StringType)
        (c, evPrim, evNull) =>
          // Code emitted when the string is neither a parsable float nor a special literal.
          val handleNull = if (ansiEnabled) {
            val errorContext = getContextOrNullCode(ctx)
            "throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
              s"org.apache.spark.sql.types.FloatType$$.MODULE$$,$c, $errorContext);"
          } else {
            s"$evNull = true;"
          }
          code"""
          final String $floatStr = $c.toString();
          try {
            $evPrim = Float.valueOf($floatStr);
          } catch (java.lang.NumberFormatException e) {
            final Float f = (Float) Cast.processFloatingPointSpecialLiterals($floatStr, true);
            if (f == null) {
              $handleNull
            } else {
              $evPrim = f.floatValue();
            }
          }
        """
      case BooleanType =>
        (c, evPrim, evNull) => code"$evPrim = $c ? 1.0f : 0.0f;"
      case DateType =>
        (c, evPrim, evNull) => code"$evNull = true;"
      case TimestampType =>
        (c, evPrim, evNull) => code"$evPrim = (float) (${timestampToDoubleCode(c)});"
      case DecimalType() =>
        (c, evPrim, evNull) => code"$evPrim = $c.toFloat();"
      case x: NumericType =>
        (c, evPrim, evNull) => code"$evPrim = (float) $c;"
    }
  }

  /**
   * Returns the code generator for casting a value of type `from` to double.
   *
   * For strings the generated code first tries `Double.valueOf`; on `NumberFormatException` it
   * falls back to `Cast.processFloatingPointSpecialLiterals`. If that also returns null, the
   * cast throws `invalidInputInCastToNumberError` when `ansiEnabled` and yields null otherwise.
   * Dates always cast to null; the remaining branches are direct conversions.
   *
   * @param from the source data type
   * @param ctx the codegen context used for fresh variable names and the error context
   */
  private[this] def castToDoubleCode(from: DataType, ctx: CodegenContext): CastFunction = {
    from match {
      case StringType =>
        val doubleStr = ctx.freshVariable("doubleStr", StringType)
        (c, evPrim, evNull) =>
          // Code emitted when the string is neither a parsable double nor a special literal.
          val handleNull = if (ansiEnabled) {
            val errorContext = getContextOrNullCode(ctx)
            "throw QueryExecutionErrors.invalidInputInCastToNumberError(" +
              s"org.apache.spark.sql.types.DoubleType$$.MODULE$$, $c, $errorContext);"
          } else {
            s"$evNull = true;"
          }
          code"""
          final String $doubleStr = $c.toString();
          try {
            $evPrim = Double.valueOf($doubleStr);
          } catch (java.lang.NumberFormatException e) {
            final Double d = (Double) Cast.processFloatingPointSpecialLiterals($doubleStr, false);
            if (d == null) {
              $handleNull
            } else {
              $evPrim = d.doubleValue();
            }
          }
        """
      case BooleanType =>
        (c, evPrim, evNull) => code"$evPrim = $c ? 1.0d : 0.0d;"
      case DateType =>
        (c, evPrim, evNull) => code"$evNull = true;"
      case TimestampType =>
        (c, evPrim, evNull) => code"$evPrim = ${timestampToDoubleCode(c)};"
      case DecimalType() =>
        (c, evPrim, evNull) => code"$evPrim = $c.toDouble();"
      case x: NumericType =>
        (c, evPrim, evNull) => code"$evPrim = (double) $c;"
    }
  }

  /**
   * Returns the code generator for casting an array whose elements have type `fromType` into an
   * array whose elements have type `toType`.
   *
   * The generated code loops over the input, keeps null elements as null, casts every other
   * element with the element-level cast function, stores null for elements whose cast result
   * is null, and finally wraps the collected values in a `GenericArrayData`.
   *
   * @param fromType the element type of the source array
   * @param toType the element type of the target array
   * @param ctx the codegen context used for fresh variable names
   */
  private[this] def castArrayCode(
      fromType: DataType, toType: DataType, ctx: CodegenContext): CastFunction = {
    val elementCast = nullSafeCastFunction(fromType, toType, ctx)
    val arrayClass = JavaCode.javaType(classOf[GenericArrayData])
    // Per-element input/output variables used inside the generated loop.
    val fromElementNull = ctx.freshVariable("feNull", BooleanType)
    val fromElementPrim = ctx.freshVariable("fePrim", fromType)
    val toElementNull = ctx.freshVariable("teNull", BooleanType)
    val toElementPrim = ctx.freshVariable("tePrim", toType)
    val size = ctx.freshVariable("n", IntegerType)
    val j = ctx.freshVariable("j", IntegerType)
    val values = ctx.freshVariable("values", classOf[Array[Object]])
    val javaType = JavaCode.javaType(fromType)

    (c, evPrim, evNull) =>
      code"""
        final int $size = $c.numElements();
        final Object[] $values = new Object[$size];
        for (int $j = 0; $j < $size; $j ++) {
          if ($c.isNullAt($j)) {
            $values[$j] = null;
          } else {
            boolean $fromElementNull = false;
            $javaType $fromElementPrim =
              ${CodeGenerator.getValue(c, fromType, j)};
            ${castCode(ctx, fromElementPrim,
              fromElementNull, toElementPrim, toElementNull, toType, elementCast)}
            if ($toElementNull) {
              $values[$j] = null;
            } else {
              $values[$j] = $toElementPrim;
            }
          }
        }
        $evPrim = new $arrayClass($values);
      """
  }

  /**
   * Returns the code generator for casting a map of type `from` to a map of type `to`.
   *
   * The generated code extracts the key and value arrays, casts each with `castArrayCode`, and
   * builds an `ArrayBasedMapData` from the two converted arrays. The null flags of the
   * converted arrays are handed to `castCode` but are not consulted when building the map.
   *
   * @param from the source map type
   * @param to the target map type
   * @param ctx the codegen context used for fresh variable names
   */
  private[this] def castMapCode(from: MapType, to: MapType, ctx: CodegenContext): CastFunction = {
    val keysCast = castArrayCode(from.keyType, to.keyType, ctx)
    val valuesCast = castArrayCode(from.valueType, to.valueType, ctx)

    val mapClass = JavaCode.javaType(classOf[ArrayBasedMapData])

    val keys = ctx.freshVariable("keys", ArrayType(from.keyType))
    val convertedKeys = ctx.freshVariable("convertedKeys", ArrayType(to.keyType))
    val convertedKeysNull = ctx.freshVariable("convertedKeysNull", BooleanType)

    val values = ctx.freshVariable("values", ArrayType(from.valueType))
    val convertedValues = ctx.freshVariable("convertedValues", ArrayType(to.valueType))
    val convertedValuesNull = ctx.freshVariable("convertedValuesNull", BooleanType)

    (c, evPrim, evNull) =>
      // The key and value arrays themselves are passed as non-null (FalseLiteral).
      code"""
        final ArrayData $keys = $c.keyArray();
        final ArrayData $values = $c.valueArray();
        ${castCode(ctx, keys, FalseLiteral,
          convertedKeys, convertedKeysNull, ArrayType(to.keyType), keysCast)}
        ${castCode(ctx, values, FalseLiteral,
          convertedValues, convertedValuesNull, ArrayType(to.valueType), valuesCast)}

        $evPrim = new $mapClass($convertedKeys, $convertedValues);
      """
  }

  /**
   * Builds the code-generating cast function for a struct-to-struct cast.
   *
   * Source and target fields are paired by position. For every pair a snippet is generated
   * that copies a null through unchanged, or otherwise reads the source field, casts it and
   * writes the result (or null) into the same ordinal of a new `GenericInternalRow`. The
   * per-field snippets are handed to `ctx.splitExpressions` before being embedded in the
   * final template.
   *
   * @param from source struct type
   * @param to target struct type; fields are matched with `from` by index
   * @param ctx codegen context used for fresh names and for splitting the field code
   * @return a cast function that emits code assigning the converted row to the result variable
   */
  private[this] def castStructCode(
      from: StructType, to: StructType, ctx: CodegenContext): CastFunction = {

    // One cast function per (source field, target field) pair, in field order.
    val fieldsCasts = from.fields.zip(to.fields).map {
      case (fromField, toField) => nullSafeCastFunction(fromField.dataType, toField.dataType, ctx)
    }
    // The output row and the input row are referenced through fresh local variable names.
    val tmpResult = ctx.freshVariable("tmpResult", classOf[GenericInternalRow])
    val rowClass = JavaCode.javaType(classOf[GenericInternalRow])
    val tmpInput = ctx.freshVariable("tmpInput", classOf[InternalRow])

    // Java snippet for each field: null check, read, cast, then write into the output row.
    val fieldsEvalCode = fieldsCasts.zipWithIndex.map { case (cast, i) =>
      val fromFieldPrim = ctx.freshVariable("ffp", from.fields(i).dataType)
      val fromFieldNull = ctx.freshVariable("ffn", BooleanType)
      val toFieldPrim = ctx.freshVariable("tfp", to.fields(i).dataType)
      val toFieldNull = ctx.freshVariable("tfn", BooleanType)
      val fromType = JavaCode.javaType(from.fields(i).dataType)
      val setColumn = CodeGenerator.setColumn(tmpResult, to.fields(i).dataType, i, toFieldPrim)
      code"""
        boolean $fromFieldNull = $tmpInput.isNullAt($i);
        if ($fromFieldNull) {
          $tmpResult.setNullAt($i);
        } else {
          $fromType $fromFieldPrim =
            ${CodeGenerator.getValue(tmpInput, from.fields(i).dataType, i.toString)};
          ${castCode(ctx, fromFieldPrim,
            fromFieldNull, toFieldPrim, toFieldNull, to.fields(i).dataType, cast)}
          if ($toFieldNull) {
            $tmpResult.setNullAt($i);
          } else {
            $setColumn;
          }
        }
       """
    }
    // The input and output rows are passed as arguments so that the field snippets can still
    // refer to them if the code is moved into separate "castStruct" functions.
    val fieldsEvalCodes = ctx.splitExpressions(
      expressions = fieldsEvalCode.map(_.code),
      funcName = "castStruct",
      arguments = ("InternalRow", tmpInput.code) :: (rowClass.code, tmpResult.code) :: Nil)

    // Final template: allocate the output row, bind the input, run the field code, publish.
    // `resultIsNull` is not assigned by this template.
    (input, result, resultIsNull) =>
      code"""
        final $rowClass $tmpResult = new $rowClass(${fieldsCasts.length});
        final InternalRow $tmpInput = $input;
        $fieldsEvalCodes
        $result = $tmpResult;
      """
  }

  /** Name of this expression: "try_cast" when `isTryCast` is set, "cast" otherwise. */
  override def prettyName: String = if (isTryCast) "try_cast" else "cast"

  /** Renders the expression as `name(child as type)` using the type's simple string. */
  override def toString: String = {
    val targetType = dataType.simpleString
    s"$prettyName($child as $targetType)"
  }

  /**
   * SQL text of this cast. Casts to complex types (array, map, struct) are rendered as the
   * bare child SQL; every other cast is rendered as `NAME(child AS type)` with the upper-cased
   * pretty name.
   */
  override def sql: String = {
    val childSql = child.sql
    dataType match {
      // HiveQL has no syntax for casting to complex types. In logical plans translated from
      // HiveQL such casts can only have been added by the analyzer, so they are left out when
      // converting back to a SQL query string.
      case _: ArrayType | _: MapType | _: StructType => childSql
      case targetType => s"${prettyName.toUpperCase(Locale.ROOT)}($childSql AS ${targetType.sql})"
    }
  }
}

/**
 * Cast the child expression to the target data type, but will throw error if the cast might
 * truncate, e.g. long -> int, timestamp -> date.
 *
 * Note: `target` is declared as `AbstractDataType` so that `object DecimalType` can be passed,
 * which stands for `DecimalType` with any valid precision/scale.
 */
case class UpCast(child: Expression, target: AbstractDataType, walkedTypePath: Seq[String] = Nil)
  extends UnaryExpression with Unevaluable {

  final override val nodePatterns: Seq[TreePattern] = Seq(UP_CAST)

  // Always reported as unresolved.
  override lazy val resolved: Boolean = false

  /**
   * The concrete type this expression produces: the system default decimal when `target` is
   * the `DecimalType` object itself, otherwise `target` viewed as a `DataType`.
   */
  def dataType: DataType = {
    if (DecimalType == target) {
      DecimalType.SYSTEM_DEFAULT
    } else {
      target.asInstanceOf[DataType]
    }
  }

  override protected def withNewChildInternal(newChild: Expression): UpCast = {
    copy(child = newChild)
  }
}

/**
 * Wraps the cast of a numeric value to another numeric type performed during store assignment.
 * A `SparkArithmeticException` raised while evaluating the cast is replaced by the dedicated
 * table-insert overflow error, which also carries the source type, the target type and the
 * name of the column being written.
 */
case class CheckOverflowInTableInsert(child: Cast, columnName: String) extends UnaryExpression {

  // Type, SQL text and string form are all taken from the wrapped cast.
  override def dataType: DataType = child.dataType

  override def sql: String = child.sql

  override def toString: String = child.toString

  override protected def withNewChildInternal(newChild: Expression): Expression = {
    copy(child = newChild.asInstanceOf[Cast])
  }

  /** Evaluates the wrapped cast, translating an arithmetic failure into the insert error. */
  override def eval(input: InternalRow): Any = {
    try {
      child.eval(input)
    } catch {
      case _: SparkArithmeticException =>
        val sourceType = child.child.dataType
        val targetType = child.dataType
        throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert(
          sourceType, targetType, columnName)
    }
  }

  /**
   * Emits Java code that runs the generated code of the wrapped cast inside a try block and
   * rethrows a caught `SparkArithmeticException` as the table-insert overflow error.
   */
  override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val castGen = child.genCode(ctx)
    val arithmeticErrorClass = classOf[SparkArithmeticException].getCanonicalName
    // The two data types and the column name are handed to the generated code as references.
    val sourceType = child.child.dataType
    val targetType = child.dataType
    val fromRef = ctx.addReferenceObj("from", sourceType, sourceType.getClass.getName)
    val toRef = ctx.addReferenceObj("to", targetType, targetType.getClass.getName)
    val columnRef = ctx.addReferenceObj("colName", columnName, "java.lang.String")
    val resultJavaType = CodeGenerator.javaType(dataType)
    val resultDefault = CodeGenerator.defaultValue(dataType)
    // scalastyle:off line.size.limit
    val wrapped = code"""
      boolean ${ev.isNull} = true;
      $resultJavaType ${ev.value} = $resultDefault;
      try {
        ${castGen.code}
        ${ev.isNull} = ${castGen.isNull};
        ${ev.value} = ${castGen.value};
      } catch ($arithmeticErrorClass e) {
        throw QueryExecutionErrors.castingCauseOverflowErrorInTableInsert($fromRef, $toRef, $columnRef);
      }"""
    // scalastyle:on line.size.limit
    ev.copy(code = wrapped)
  }
}

相关信息

spark 源码目录

相关文章

spark AliasHelper 源码

spark ApplyFunctionExpression 源码

spark AttributeSet 源码

spark BloomFilterMightContain 源码

spark BoundAttribute 源码

spark CallMethodViaReflection 源码

spark CodeGeneratorWithInterpretedFallback 源码

spark DynamicPruning 源码

spark EquivalentExpressions 源码

spark EvalMode 源码

0  赞