spark KafkaWriter 源码

  • 2022-10-20
  • 浏览 (302)

spark KafkaWriter 代码

文件路径:/connector/kafka-0-10-sql/src/main/scala/org/apache/spark/sql/kafka010/KafkaWriter.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.sql.kafka010

import java.{util => ju}

import org.apache.spark.internal.Logging
import org.apache.spark.sql.{AnalysisException, SparkSession}
import org.apache.spark.sql.catalyst.CatalystTypeConverters
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.QueryExecution
import org.apache.spark.sql.types.{BinaryType, DataType, IntegerType, StringType}
import org.apache.spark.util.Utils

/**
 * Writes the output of a batch query or structured streaming query, described by a
 * [[QueryExecution]], to Kafka.
 *
 * Columns are picked out of the query's output schema by name:
 *  - `value` is mandatory and has to be a string or binary column.
 *  - `topic` has to be a string column. It is only looked at when no topic is handed in
 *    explicitly (the 'topic' configuration option); one of the two must be available.
 *  - `key` (string or binary), `headers` and `partition` (int) are optional. When one of
 *    them is absent a typed null literal takes its place, e.g. a missing key yields a null
 *    valued key field in the [[org.apache.kafka.clients.producer.ProducerRecord]].
 */
private[kafka010] object KafkaWriter extends Logging {
  val TOPIC_ATTRIBUTE_NAME: String = "topic"
  val KEY_ATTRIBUTE_NAME: String = "key"
  val VALUE_ATTRIBUTE_NAME: String = "value"
  val HEADERS_ATTRIBUTE_NAME: String = "headers"
  val PARTITION_ATTRIBUTE_NAME: String = "partition"

  override def toString: String = "KafkaWriter"

  /**
   * Checks that `schema` provides every column needed for a Kafka write, with a
   * supported type.
   *
   * @param schema output attributes of the query that is about to be written
   * @param kafkaParameters producer configuration; not consulted by the checks below
   * @param topic explicitly configured topic, if any
   * @throws AnalysisException carrying the message of the first failed check
   */
  def validateQuery(
      schema: Seq[Attribute],
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    // The checks run strictly in this order, so the first offending column decides the error.
    val checks: Seq[() => Expression] = Seq(
      () => topicExpression(schema, topic),
      () => keyExpression(schema),
      () => valueExpression(schema),
      () => headersExpression(schema),
      () => partitionExpression(schema))

    try {
      checks.foreach(check => check())
    } catch {
      case failure: IllegalStateException =>
        throw new AnalysisException(failure.getMessage)
    }
  }

  /**
   * Validates the query's output schema and then writes every partition of its RDD
   * through a dedicated [[KafkaWriteTask]].
   *
   * @param sparkSession session the query belongs to
   * @param queryExecution query whose rows are written
   * @param kafkaParameters configuration handed to each [[KafkaWriteTask]]
   * @param topic explicitly configured topic, if any
   */
  def write(
      sparkSession: SparkSession,
      queryExecution: QueryExecution,
      kafkaParameters: ju.Map[String, Object],
      topic: Option[String] = None): Unit = {
    // Resolved outside the closure so that only the attributes travel with the task.
    val outputAttributes = queryExecution.analyzed.output
    validateQuery(outputAttributes, kafkaParameters, topic)

    queryExecution.toRdd.foreachPartition { rows =>
      val task = new KafkaWriteTask(kafkaParameters, outputAttributes, topic)
      // The task is closed whether or not executing it succeeds.
      Utils.tryWithSafeFinally {
        task.execute(rows)
      } {
        task.close()
      }
    }
  }

  /**
   * Expression yielding the destination topic: a literal when `topic` is defined,
   * otherwise the string typed `topic` attribute of `schema`.
   *
   * @throws IllegalStateException if neither source is available, or the attribute
   *                               is not a string
   */
  def topicExpression(schema: Seq[Attribute], topic: Option[String] = None): Expression = {
    topic match {
      case Some(configuredTopic) =>
        // An explicit topic wins; the schema is not inspected in this case.
        Literal(configuredTopic)
      case None =>
        expression(schema, TOPIC_ATTRIBUTE_NAME, Seq(StringType)) {
          throw new IllegalStateException(
            s"topic option required when no '${TOPIC_ATTRIBUTE_NAME}' attribute is present. " +
              s"Use the ${KafkaSourceProvider.TOPIC_OPTION_KEY} option for setting a topic.")
        }
    }
  }

  /**
   * Expression yielding the record key: the string or binary `key` attribute, or a
   * null binary literal when the schema has none.
   */
  def keyExpression(schema: Seq[Attribute]): Expression = {
    val acceptedTypes = Seq(StringType, BinaryType)
    expression(schema, KEY_ATTRIBUTE_NAME, acceptedTypes)(Literal(null, BinaryType))
  }

  /**
   * Expression yielding the record value: the string or binary `value` attribute.
   *
   * @throws IllegalStateException if the attribute is missing or has another type
   */
  def valueExpression(schema: Seq[Attribute]): Expression = {
    val acceptedTypes = Seq(StringType, BinaryType)
    expression(schema, VALUE_ATTRIBUTE_NAME, acceptedTypes) {
      throw new IllegalStateException(s"Required attribute '${VALUE_ATTRIBUTE_NAME}' not found")
    }
  }

  /**
   * Expression yielding the record headers: the `headers` attribute, or a null literal
   * of [[KafkaRecordToRowConverter.headersType]] when the schema has none.
   */
  def headersExpression(schema: Seq[Attribute]): Expression = {
    val headersType = KafkaRecordToRowConverter.headersType
    expression(schema, HEADERS_ATTRIBUTE_NAME, Seq(headersType)) {
      Literal(CatalystTypeConverters.convertToCatalyst(null), headersType)
    }
  }

  /**
   * Expression yielding the target partition: the int `partition` attribute, or a null
   * int literal when the schema has none.
   */
  def partitionExpression(schema: Seq[Attribute]): Expression = {
    expression(schema, PARTITION_ATTRIBUTE_NAME, Seq(IntegerType))(Literal(null, IntegerType))
  }

  /**
   * Looks up the first attribute called `attrName`, falling back to `default`, and checks
   * that the outcome has one of the `desired` types.
   *
   * @param default only evaluated when no attribute matches; it may throw instead of
   *                producing an expression
   * @throws IllegalStateException if the type of the chosen expression is not accepted
   */
  private def expression(
      schema: Seq[Attribute],
      attrName: String,
      desired: Seq[DataType])(
      default: => Expression): Expression = {
    val chosen: Expression = schema.find(attr => attr.name == attrName) match {
      case Some(attribute) => attribute
      case None => default
    }

    val actualType = chosen.dataType
    val noTypeMatches = desired.forall(candidate => !candidate.sameType(actualType))
    if (noTypeMatches) {
      val acceptedNames = desired.map(_.catalogString).mkString(" or ")
      throw new IllegalStateException(s"$attrName attribute unsupported type " +
        s"${actualType.catalogString}. $attrName must be a(n) " +
        s"${acceptedNames}")
    }
    chosen
  }
}

相关信息

spark 源码目录

相关文章

spark ConsumerStrategy 源码

spark JsonUtils 源码

spark KafkaBatch 源码

spark KafkaBatchPartitionReader 源码

spark KafkaBatchWrite 源码

spark KafkaContinuousStream 源码

spark KafkaDataWriter 源码

spark KafkaMicroBatchStream 源码

spark KafkaOffsetRangeCalculator 源码

spark KafkaOffsetRangeLimit 源码

0  赞