spark JoinedRow 源码

  • 2022-10-20
  • 浏览 (231)

spark JoinedRow 代码

文件路径:/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/JoinedRow.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.catalyst.expressions

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.util.{ArrayData, MapData}
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.{CalendarInterval, UTF8String}

/**
 * A mutable wrapper that makes two rows appear as a single concatenated row.  Designed to
 * be instantiated once per thread and reused.
 */
/**
 * A mutable wrapper that presents two rows as a single concatenated row.  Intended to be
 * created once per thread and then reused via the `apply`/`withLeft`/`withRight` mutators,
 * so joins avoid allocating a fresh row per output record.
 *
 * Field ordinals `[0, leftRow.numFields)` resolve against the left row; larger ordinals are
 * shifted down by `leftRow.numFields` and resolve against the right row.
 */
class JoinedRow extends InternalRow {
  private[this] var leftRow: InternalRow = _
  private[this] var rightRow: InternalRow = _

  def this(left: InternalRow, right: InternalRow) = {
    this()
    leftRow = left
    rightRow = right
  }

  /** Points this JoinedRow at two new base rows.  Returns itself. */
  def apply(r1: InternalRow, r2: InternalRow): JoinedRow = {
    leftRow = r1
    rightRow = r2
    this
  }

  /** Replaces only the left base row.  Returns itself. */
  def withLeft(newLeft: InternalRow): JoinedRow = {
    leftRow = newLeft
    this
  }

  /** Replaces only the right base row.  Returns itself. */
  def withRight(newRight: InternalRow): JoinedRow = {
    rightRow = newRight
    this
  }

  /** The current left base row. */
  def getLeft: InternalRow = leftRow

  /** The current right base row. */
  def getRight: InternalRow = rightRow

  override def toSeq(fieldTypes: Seq[DataType]): Seq[Any] = {
    val split = leftRow.numFields
    assert(fieldTypes.length == split + rightRow.numFields)
    leftRow.toSeq(fieldTypes.take(split)) ++ rightRow.toSeq(fieldTypes.drop(split))
  }

  override def numFields: Int = leftRow.numFields + rightRow.numFields

  override def get(i: Int, dt: DataType): AnyRef = {
    val split = leftRow.numFields
    if (i < split) leftRow.get(i, dt) else rightRow.get(i - split, dt)
  }

  override def isNullAt(i: Int): Boolean = {
    val split = leftRow.numFields
    if (i < split) leftRow.isNullAt(i) else rightRow.isNullAt(i - split)
  }

  override def getBoolean(i: Int): Boolean = {
    val split = leftRow.numFields
    if (i < split) leftRow.getBoolean(i) else rightRow.getBoolean(i - split)
  }

  override def getByte(i: Int): Byte = {
    val split = leftRow.numFields
    if (i < split) leftRow.getByte(i) else rightRow.getByte(i - split)
  }

  override def getShort(i: Int): Short = {
    val split = leftRow.numFields
    if (i < split) leftRow.getShort(i) else rightRow.getShort(i - split)
  }

  override def getInt(i: Int): Int = {
    val split = leftRow.numFields
    if (i < split) leftRow.getInt(i) else rightRow.getInt(i - split)
  }

  override def getLong(i: Int): Long = {
    val split = leftRow.numFields
    if (i < split) leftRow.getLong(i) else rightRow.getLong(i - split)
  }

  override def getFloat(i: Int): Float = {
    val split = leftRow.numFields
    if (i < split) leftRow.getFloat(i) else rightRow.getFloat(i - split)
  }

  override def getDouble(i: Int): Double = {
    val split = leftRow.numFields
    if (i < split) leftRow.getDouble(i) else rightRow.getDouble(i - split)
  }

  override def getDecimal(i: Int, precision: Int, scale: Int): Decimal = {
    val split = leftRow.numFields
    if (i < split) {
      leftRow.getDecimal(i, precision, scale)
    } else {
      rightRow.getDecimal(i - split, precision, scale)
    }
  }

  override def getUTF8String(i: Int): UTF8String = {
    val split = leftRow.numFields
    if (i < split) leftRow.getUTF8String(i) else rightRow.getUTF8String(i - split)
  }

  override def getBinary(i: Int): Array[Byte] = {
    val split = leftRow.numFields
    if (i < split) leftRow.getBinary(i) else rightRow.getBinary(i - split)
  }

  override def getArray(i: Int): ArrayData = {
    val split = leftRow.numFields
    if (i < split) leftRow.getArray(i) else rightRow.getArray(i - split)
  }

  override def getInterval(i: Int): CalendarInterval = {
    val split = leftRow.numFields
    if (i < split) leftRow.getInterval(i) else rightRow.getInterval(i - split)
  }

  override def getMap(i: Int): MapData = {
    val split = leftRow.numFields
    if (i < split) leftRow.getMap(i) else rightRow.getMap(i - split)
  }

  override def getStruct(i: Int, numFields: Int): InternalRow = {
    val split = leftRow.numFields
    if (i < split) {
      leftRow.getStruct(i, numFields)
    } else {
      rightRow.getStruct(i - split, numFields)
    }
  }

  override def anyNull: Boolean = leftRow.anyNull || rightRow.anyNull

  override def setNullAt(i: Int): Unit = {
    val split = leftRow.numFields
    if (i < split) {
      leftRow.setNullAt(i)
    } else {
      rightRow.setNullAt(i - split)
    }
  }

  override def update(i: Int, value: Any): Unit = {
    val split = leftRow.numFields
    if (i < split) {
      leftRow.update(i, value)
    } else {
      rightRow.update(i - split, value)
    }
  }

  override def copy(): InternalRow = new JoinedRow(leftRow.copy(), rightRow.copy())

  override def toString: String = {
    // Either base row may still be unset; render gracefully so toString never NPEs.
    (leftRow, rightRow) match {
      case (null, null) => "[ empty row ]"
      case (null, r)    => r.toString
      case (l, null)    => l.toString
      case (l, r)       => s"{$l + $r}"
    }
  }
}

相关信息

spark 源码目录

相关文章

spark AliasHelper 源码

spark ApplyFunctionExpression 源码

spark AttributeSet 源码

spark BloomFilterMightContain 源码

spark BoundAttribute 源码

spark CallMethodViaReflection 源码

spark Cast 源码

spark CodeGeneratorWithInterpretedFallback 源码

spark DynamicPruning 源码

spark EquivalentExpressions 源码

0  赞