spark randomExpressions 源码
spark randomExpressions 代码
文件路径:/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/randomExpressions.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.catalyst.expressions
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, CodeGenerator, ExprCode, FalseLiteral}
import org.apache.spark.sql.catalyst.expressions.codegen.Block._
import org.apache.spark.sql.catalyst.trees.TreePattern.{EXPRESSION_WITH_RANDOM_SEED, TreePattern}
import org.apache.spark.sql.types._
import org.apache.spark.util.random.XORShiftRandom
/**
* A Random distribution generating expression.
* TODO: This can be made generic to generate any type of random distribution, or any type of
* StructType.
*
* Since this expression is stateful, it cannot be a case object.
*/
abstract class RDG
  extends UnaryExpression
  with ExpectsInputTypes
  with Stateful
  with ExpressionWithRandomSeed {

  // Static shape of the expression: a single Int/Long seed argument in, a non-null double out.
  override def inputTypes: Seq[AbstractDataType] = Seq(TypeCollection(IntegerType, LongType))
  override def dataType: DataType = DoubleType
  override def nullable: Boolean = false

  /** The seed is supplied by the single child expression. */
  override def seedExpression: Expression = child

  /**
   * Seed value, computed on first access by evaluating [[seedExpression]] without an input row.
   *
   * Only seed expressions whose data type is `IntegerType` or `LongType` are handled; any other
   * data type fails with a `MatchError` carrying the seed expression.
   */
  @transient protected lazy val seed: Long = {
    val seedExpr = seedExpression
    if (seedExpr.dataType == IntegerType) {
      // The Int result is widened to Long by the declared type of `seed`.
      seedExpr.eval().asInstanceOf[Int]
    } else if (seedExpr.dataType == LongType) {
      seedExpr.eval().asInstanceOf[Long]
    } else {
      throw new MatchError(seedExpr)
    }
  }

  /**
   * The random number generator backing this expression.
   *
   * It is marked transient and starts out unset; `initializeInternal` assigns a new generator,
   * so it has to run before any value is drawn from `rng`.
   */
  @transient protected var rng: XORShiftRandom = _

  /** Creates the generator for a partition, seeding it with `seed + partitionIndex`. */
  override protected def initializeInternal(partitionIndex: Int): Unit = {
    val partitionSeed = seed + partitionIndex
    rng = new XORShiftRandom(partitionSeed)
  }
}
/**
* Represents the behavior of expressions which have a random seed and can renew the seed.
* Usually the random seed needs to be renewed at each execution under streaming queries.
*/
trait ExpressionWithRandomSeed extends Expression {
// Declares EXPRESSION_WITH_RANDOM_SEED as the tree pattern of every expression mixing in this trait.
override val nodePatterns: Seq[TreePattern] = Seq(EXPRESSION_WITH_RANDOM_SEED)
/** The expression that supplies the random seed. */
def seedExpression: Expression
/**
* Returns an expression that uses the given seed.
* The implementations in this file build a new expression of the same kind whose seed
* child is the long literal `seed`.
*/
def withNewSeed(seed: Long): Expression
}
/** Generate a random column with i.i.d. uniformly distributed values in [0, 1). */
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = "_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) uniformly distributed values in [0, 1).",
examples = """
Examples:
> SELECT _FUNC_();
0.9629742951434543
> SELECT _FUNC_(0);
0.7604953758285915
> SELECT _FUNC_(null);
0.7604953758285915
""",
note = """
The function is non-deterministic in general case.
""",
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Rand(child: Expression, hideSeed: Boolean = false) extends RDG {

  // No-argument form: the seed is the `UnresolvedSeed` placeholder and is hidden in `sql`.
  def this() = this(UnresolvedSeed, true)

  // Explicit-seed form: the seed expression stays visible in `sql`.
  def this(child: Expression) = this(child, false)

  /** Same expression with the seed child replaced by the long literal `seed`. */
  override def withNewSeed(seed: Long): Rand = copy(child = Literal(seed, LongType))

  /** Draws the next double from the generator; the input row is not consulted. */
  override protected def evalInternal(input: InternalRow): Double = rng.nextDouble()

  /**
   * Code-generation counterpart of `evalInternal`.
   *
   * Obtains a mutable-state term for an `XORShiftRandom`, adds a partition initialization
   * statement that constructs it from `seed + partitionIndex`, and emits code that assigns
   * `nextDouble()` to the result value. The result is marked as never null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val generatorClass = classOf[XORShiftRandom].getName
    val generatorTerm = ctx.addMutableState(generatorClass, "rng")
    // The seed is embedded as a Java long literal in the generated statement.
    val initStatement = s"$generatorTerm = new $generatorClass(${seed}L + partitionIndex);"
    ctx.addPartitionInitializationStatement(initStatement)
    val resultType = CodeGenerator.javaType(dataType)
    // The text of this literal is kept verbatim, including the leading line break.
    val drawCode = code"""
final $resultType ${ev.value} = $generatorTerm.nextDouble();"""
    ev.copy(code = drawCode, isNull = FalseLiteral)
  }

  /** A new instance with the same arguments. */
  override def freshCopy(): Rand = copy()

  // Only the seed child is listed; `hideSeed` is left out.
  override def flatArguments: Iterator[Any] = Iterator(child)

  /** SQL text `rand(<seed>)`; the seed is omitted when `hideSeed` is set. */
  override def sql: String = {
    val seedText = if (hideSeed) "" else child.sql
    s"rand($seedText)"
  }

  override protected def withNewChildInternal(newChild: Expression): Rand =
    Rand(newChild, hideSeed)
}
object Rand {
  /** Builds a `Rand` whose seed child is the long literal `seed`, with the seed shown in SQL. */
  def apply(seed: Long): Rand = Rand(child = Literal(seed, LongType), hideSeed = false)
}
/** Generate a random column with i.i.d. values drawn from the standard normal distribution. */
// scalastyle:off line.size.limit
@ExpressionDescription(
usage = """_FUNC_([seed]) - Returns a random value with independent and identically distributed (i.i.d.) values drawn from the standard normal distribution.""",
examples = """
Examples:
> SELECT _FUNC_();
-0.3254147983080288
> SELECT _FUNC_(0);
1.6034991609278433
> SELECT _FUNC_(null);
1.6034991609278433
""",
note = """
The function is non-deterministic in general case.
""",
since = "1.5.0",
group = "math_funcs")
// scalastyle:on line.size.limit
case class Randn(child: Expression, hideSeed: Boolean = false) extends RDG {

  // No-argument form: the seed is the `UnresolvedSeed` placeholder and is hidden in `sql`.
  def this() = this(UnresolvedSeed, true)

  // Explicit-seed form: the seed expression stays visible in `sql`.
  def this(child: Expression) = this(child, false)

  /** Same expression with the seed child replaced by the long literal `seed`. */
  override def withNewSeed(seed: Long): Randn = copy(child = Literal(seed, LongType))

  /** Draws the next Gaussian value from the generator; the input row is not consulted. */
  override protected def evalInternal(input: InternalRow): Double = rng.nextGaussian()

  /**
   * Code-generation counterpart of `evalInternal`.
   *
   * Obtains a mutable-state term for an `XORShiftRandom`, adds a partition initialization
   * statement that constructs it from `seed + partitionIndex`, and emits code that assigns
   * `nextGaussian()` to the result value. The result is marked as never null.
   */
  override def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
    val generatorClass = classOf[XORShiftRandom].getName
    val generatorTerm = ctx.addMutableState(generatorClass, "rng")
    // The seed is embedded as a Java long literal in the generated statement.
    val initStatement = s"$generatorTerm = new $generatorClass(${seed}L + partitionIndex);"
    ctx.addPartitionInitializationStatement(initStatement)
    val resultType = CodeGenerator.javaType(dataType)
    // The text of this literal is kept verbatim, including the leading line break.
    val drawCode = code"""
final $resultType ${ev.value} = $generatorTerm.nextGaussian();"""
    ev.copy(code = drawCode, isNull = FalseLiteral)
  }

  /** A new instance with the same arguments. */
  override def freshCopy(): Randn = copy()

  // Only the seed child is listed; `hideSeed` is left out.
  override def flatArguments: Iterator[Any] = Iterator(child)

  /** SQL text `randn(<seed>)`; the seed is omitted when `hideSeed` is set. */
  override def sql: String = {
    val seedText = if (hideSeed) "" else child.sql
    s"randn($seedText)"
  }

  override protected def withNewChildInternal(newChild: Expression): Randn =
    Randn(newChild, hideSeed)
}
object Randn {
  /** Builds a `Randn` whose seed child is the long literal `seed`, with the seed shown in SQL. */
  def apply(seed: Long): Randn = Randn(child = Literal(seed, LongType), hideSeed = false)
}
相关信息
相关文章
spark ApplyFunctionExpression 源码
spark BloomFilterMightContain 源码
spark CallMethodViaReflection 源码
0
赞
- 所属分类: 前端技术
- 本文标签:
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
7、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦