spark functions 源码
spark functions 代码
文件路径:/mllib/src/main/scala/org/apache/spark/ml/functions.scala
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.ml
import org.apache.spark.annotation.Since
import org.apache.spark.ml.linalg.{SparseVector, Vector, Vectors}
import org.apache.spark.mllib.linalg.{Vector => OldVector}
import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.udf
// scalastyle:off
@Since("3.0.0")
object functions {
// scalastyle:on

  /**
   * Builds the exception thrown when `vector_to_array` is given a value that is neither an
   * `org.apache.spark.ml.linalg.Vector` nor an `org.apache.spark.mllib.linalg.Vector`
   * (this includes `null`). Shared by both conversion UDFs so that the message stays identical
   * for the float64 and float32 code paths.
   *
   * @param value the offending input value, possibly `null`
   * @return the exception to throw; its message names the runtime class of `value`
   */
  private def unsupportedVectorInput(value: Any): IllegalArgumentException = {
    val actual = if (value == null) "null" else value.getClass.getName
    new IllegalArgumentException(
      "function vector_to_array requires a non-null input argument and input type must be " +
      "`org.apache.spark.ml.linalg.Vector` or `org.apache.spark.mllib.linalg.Vector`, " +
      s"but got $actual.")
  }

  /**
   * UDF converting an ML or MLlib vector into an `Array[Double]`.
   * Any other input, including `null`, raises an `IllegalArgumentException`.
   */
  private val vectorToArrayUdf = udf { vec: Any =>
    vec match {
      case v: Vector => v.toArray
      case v: OldVector => v.toArray
      case other => throw unsupportedVectorInput(other)
    }
  }.asNonNullable()

  /**
   * UDF converting an ML or MLlib vector into an `Array[Float]`.
   * Any other input, including `null`, raises an `IllegalArgumentException`.
   */
  private val vectorToArrayFloatUdf = udf { vec: Any =>
    vec match {
      case v: SparseVector =>
        // Write only the active entries into a zero-initialised float array instead of
        // building an intermediate dense Array[Double] first.
        val data = new Array[Float](v.size)
        v.foreachActive { (index, value) => data(index) = value.toFloat }
        data
      case v: Vector => v.toArray.map(_.toFloat)
      case v: OldVector => v.toArray.map(_.toFloat)
      case other => throw unsupportedVectorInput(other)
    }
  }.asNonNullable()

  /**
   * Converts a column of MLlib sparse/dense vectors into a column of dense arrays.
   *
   * An `IllegalArgumentException` is thrown immediately if `dtype` is neither "float64" nor
   * "float32".
   *
   * @param v the column of MLlib sparse/dense vectors
   * @param dtype the desired underlying data type in the returned array,
   *              either "float64" (the default) or "float32"
   * @return an array&lt;float&gt; if dtype is float32, or array&lt;double&gt; if dtype is float64
   * @since 3.0.0
   */
  def vector_to_array(v: Column, dtype: String = "float64"): Column = dtype match {
    case "float64" => vectorToArrayUdf(v)
    case "float32" => vectorToArrayFloatUdf(v)
    case _ =>
      throw new IllegalArgumentException(
        s"Unsupported dtype: $dtype. Valid values: float64, float32."
      )
  }

  /**
   * UDF converting a sequence of doubles into a dense ML vector.
   * A `null` array raises an `IllegalArgumentException`.
   */
  private val arrayToVectorUdf = udf { array: Seq[Double] =>
    // Without this guard a null array would fail inside `array.toArray` with a bare
    // NullPointerException; fail with a message that names the function instead, in line
    // with how vector_to_array reports null input.
    if (array == null) {
      throw new IllegalArgumentException(
        "function array_to_vector requires a non-null input argument, but got null.")
    }
    Vectors.dense(array.toArray)
  }

  /**
   * Converts a column of array of numeric type into a column of dense vectors in MLlib.
   *
   * @param v the column of array&lt;NumericType&gt; type
   * @return a column of type `org.apache.spark.ml.linalg.Vector`
   * @since 3.1.0
   */
  def array_to_vector(v: Column): Column = {
    arrayToVectorUdf(v)
  }
}
相关信息
相关文章
0
赞
- 所属分类: 前端技术
- 本文标签:
热门推荐
-
2、 - 优质文章
-
3、 gate.io
-
8、 golang
-
9、 openharmony
-
10、 Vue中input框自动聚焦