spark MLSerDe 源码

  • 2022-10-20
  • 浏览 (242)

spark MLSerDe 代码

文件路径:/mllib/src/main/scala/org/apache/spark/ml/python/MLSerDe.scala

/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.python

import java.io.OutputStream
import java.nio.{ByteBuffer, ByteOrder}

import net.razorvine.pickle._

import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.ml.linalg._
import org.apache.spark.mllib.api.python.SerDeBase

/**
 * SerDe utility functions for pyspark.ml.
 */
private[spark] object MLSerDe extends SerDeBase with Serializable {

  override val PYSPARK_PACKAGE = "pyspark.ml"

  /**
   * Pickler for [[DenseVector]].
   *
   * The pickled state is a 1-tuple holding a single byte string: the vector's
   * values encoded as 8-byte doubles in the platform's native byte order.
   */
  private[python] class DenseVectorPickler extends BasePickler[DenseVector] {

    /**
     * Writes the state of a dense vector to `out`.
     *
     * @param obj the object to pickle; it is cast to [[DenseVector]]
     * @param out the stream receiving the pickle opcodes and payload
     * @param pickler the active pickler (not used here)
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val dense = obj.asInstanceOf[DenseVector]
      // 8 bytes per double; the buffer is sized from the vector's size.
      val payload = new Array[Byte](8 * dense.size)
      ByteBuffer.wrap(payload)
        .order(ByteOrder.nativeOrder())
        .asDoubleBuffer()
        .put(dense.values)

      // BINSTRING is followed by the payload length and then the raw bytes.
      out.write(Opcodes.BINSTRING)
      out.write(PickleUtils.integer_to_bytes(payload.length))
      out.write(payload)
      out.write(Opcodes.TUPLE1)
    }

    /**
     * Rebuilds a dense vector from its pickled state.
     *
     * @param args the unpickled arguments; exactly one element, convertible to bytes by `getBytes`
     * @return the vector produced by `Vectors.dense` from the decoded doubles
     * @throws PickleException if `args` does not have exactly one element
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 1) {
        throw new PickleException("length of args should be 1")
      }
      val raw = getBytes(args(0))
      val doubles = new Array[Double](raw.length / 8)
      ByteBuffer.wrap(raw)
        .order(ByteOrder.nativeOrder())
        .asDoubleBuffer()
        .get(doubles)
      Vectors.dense(doubles)
    }
  }

  /**
   * Pickler for [[DenseMatrix]].
   *
   * The pickled state is written between MARK and TUPLE opcodes as four items:
   * numRows, numCols, the values as a byte string of native-order doubles, and
   * an integer flag that is 1 when the matrix is transposed and 0 otherwise.
   */
  private[python] class DenseMatrixPickler extends BasePickler[DenseMatrix] {

    /**
     * Writes the state of a dense matrix to `out`.
     *
     * @param obj the object to pickle; it is cast to [[DenseMatrix]]
     * @param out the stream receiving the pickle opcodes and payload
     * @param pickler the active pickler (not used here)
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val matrix = obj.asInstanceOf[DenseMatrix]

      // Emits one integer item: the BININT opcode followed by the encoded value.
      def writeInt(value: Int): Unit = {
        out.write(Opcodes.BININT)
        out.write(PickleUtils.integer_to_bytes(value))
      }

      // 8 bytes per double.
      val payload = new Array[Byte](8 * matrix.values.length)
      ByteBuffer.wrap(payload)
        .order(ByteOrder.nativeOrder())
        .asDoubleBuffer()
        .put(matrix.values)
      val transposedFlag = if (matrix.isTransposed) 1 else 0

      out.write(Opcodes.MARK)
      writeInt(matrix.numRows)
      writeInt(matrix.numCols)
      out.write(Opcodes.BINSTRING)
      out.write(PickleUtils.integer_to_bytes(payload.length))
      out.write(payload)
      writeInt(transposedFlag)
      out.write(Opcodes.TUPLE)
    }

    /**
     * Rebuilds a dense matrix from its pickled state.
     *
     * @param args the unpickled arguments, in order: numRows, numCols, value bytes,
     *             transposed flag
     * @return a new [[DenseMatrix]] built from the decoded arguments
     * @throws PickleException if `args` does not have exactly four elements
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 4) {
        throw new PickleException("length of args should be 4")
      }
      val raw = getBytes(args(2))
      val values = new Array[Double](raw.length / 8)
      ByteBuffer.wrap(raw)
        .order(ByteOrder.nativeOrder())
        .asDoubleBuffer()
        .get(values)
      // The flag is compared with 1; any other value means "not transposed".
      val transposed = args(3).asInstanceOf[Int] == 1
      val numRows = args(0).asInstanceOf[Int]
      val numCols = args(1).asInstanceOf[Int]
      new DenseMatrix(numRows, numCols, values, transposed)
    }
  }

  /**
   * Pickler for [[SparseMatrix]].
   *
   * The pickled state is written between MARK and TUPLE opcodes as six items:
   * numRows, numCols, colPtrs (byte string of 4-byte ints), rowIndices (byte
   * string of 4-byte ints), values (byte string of 8-byte doubles) and an
   * integer flag that is 1 when the matrix is transposed and 0 otherwise.
   * All binary payloads use the platform's native byte order.
   */
  private[python] class SparseMatrixPickler extends BasePickler[SparseMatrix] {

    /**
     * Writes the state of a sparse matrix to `out`.
     *
     * @param obj the object to pickle; it is cast to [[SparseMatrix]]
     * @param out the stream receiving the pickle opcodes and payloads
     * @param pickler the active pickler (not used here)
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val sparse = obj.asInstanceOf[SparseMatrix]
      val nativeOrder = ByteOrder.nativeOrder()

      // Emits one integer item: the BININT opcode followed by the encoded value.
      def writeInt(value: Int): Unit = {
        out.write(Opcodes.BININT)
        out.write(PickleUtils.integer_to_bytes(value))
      }

      // Emits one byte-string item: the BINSTRING opcode, the length, then the raw bytes.
      def writeBytes(payload: Array[Byte]): Unit = {
        out.write(Opcodes.BINSTRING)
        out.write(PickleUtils.integer_to_bytes(payload.length))
        out.write(payload)
      }

      // 4 bytes per int, 8 bytes per double.
      val colPtrsPayload = new Array[Byte](4 * sparse.colPtrs.length)
      val rowIndicesPayload = new Array[Byte](4 * sparse.rowIndices.length)
      val valuesPayload = new Array[Byte](8 * sparse.values.length)
      val transposedFlag = if (sparse.isTransposed) 1 else 0
      ByteBuffer.wrap(colPtrsPayload).order(nativeOrder).asIntBuffer().put(sparse.colPtrs)
      ByteBuffer.wrap(rowIndicesPayload).order(nativeOrder).asIntBuffer().put(sparse.rowIndices)
      ByteBuffer.wrap(valuesPayload).order(nativeOrder).asDoubleBuffer().put(sparse.values)

      out.write(Opcodes.MARK)
      writeInt(sparse.numRows)
      writeInt(sparse.numCols)
      writeBytes(colPtrsPayload)
      writeBytes(rowIndicesPayload)
      writeBytes(valuesPayload)
      writeInt(transposedFlag)
      out.write(Opcodes.TUPLE)
    }

    /**
     * Rebuilds a sparse matrix from its pickled state.
     *
     * @param args the unpickled arguments, in order: numRows, numCols, colPtrs bytes,
     *             rowIndices bytes, value bytes, transposed flag
     * @return a new [[SparseMatrix]] built from the decoded arguments
     * @throws PickleException if `args` does not have exactly six elements
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 6) {
        throw new PickleException("length of args should be 6")
      }
      val nativeOrder = ByteOrder.nativeOrder()

      // Decodes a byte array into 4-byte ints; trailing bytes that do not fill an int are ignored.
      def toInts(raw: Array[Byte]): Array[Int] = {
        val decoded = new Array[Int](raw.length / 4)
        ByteBuffer.wrap(raw).order(nativeOrder).asIntBuffer().get(decoded)
        decoded
      }

      // Decodes a byte array into 8-byte doubles; trailing bytes that do not fill a double
      // are ignored.
      def toDoubles(raw: Array[Byte]): Array[Double] = {
        val decoded = new Array[Double](raw.length / 8)
        ByteBuffer.wrap(raw).order(nativeOrder).asDoubleBuffer().get(decoded)
        decoded
      }

      // Convert all three payloads to bytes before decoding any of them.
      val colPtrsRaw = getBytes(args(2))
      val rowIndicesRaw = getBytes(args(3))
      val valuesRaw = getBytes(args(4))

      val colPtrs = toInts(colPtrsRaw)
      val rowIndices = toInts(rowIndicesRaw)
      val values = toDoubles(valuesRaw)

      // The flag is compared with 1; any other value means "not transposed".
      val transposed = args(5).asInstanceOf[Int] == 1
      val numRows = args(0).asInstanceOf[Int]
      val numCols = args(1).asInstanceOf[Int]
      new SparseMatrix(numRows, numCols, colPtrs, rowIndices, values, transposed)
    }
  }

  /**
   * Pickler for [[SparseVector]].
   *
   * The pickled state is a 3-tuple: the vector size, the indices as a byte
   * string of 4-byte ints, and the values as a byte string of 8-byte doubles.
   * Both byte strings use the platform's native byte order.
   */
  private[python] class SparseVectorPickler extends BasePickler[SparseVector] {

    /**
     * Writes the state of a sparse vector to `out`.
     *
     * @param obj the object to pickle; it is cast to [[SparseVector]]
     * @param out the stream receiving the pickle opcodes and payloads
     * @param pickler the active pickler (not used here)
     */
    def saveState(obj: Object, out: OutputStream, pickler: Pickler): Unit = {
      val sparse = obj.asInstanceOf[SparseVector]
      val nativeOrder = ByteOrder.nativeOrder()
      // Both payloads are sized from the number of indices.
      val entryCount = sparse.indices.length

      // Emits one byte-string item: the BINSTRING opcode, the length, then the raw bytes.
      def writeBytes(payload: Array[Byte]): Unit = {
        out.write(Opcodes.BINSTRING)
        out.write(PickleUtils.integer_to_bytes(payload.length))
        out.write(payload)
      }

      val indicesPayload = new Array[Byte](4 * entryCount)
      ByteBuffer.wrap(indicesPayload).order(nativeOrder).asIntBuffer().put(sparse.indices)
      val valuesPayload = new Array[Byte](8 * entryCount)
      ByteBuffer.wrap(valuesPayload).order(nativeOrder).asDoubleBuffer().put(sparse.values)

      out.write(Opcodes.BININT)
      out.write(PickleUtils.integer_to_bytes(sparse.size))
      writeBytes(indicesPayload)
      writeBytes(valuesPayload)
      out.write(Opcodes.TUPLE3)
    }

    /**
     * Rebuilds a sparse vector from its pickled state.
     *
     * @param args the unpickled arguments, in order: size, index bytes, value bytes
     * @return a new [[SparseVector]] built from the decoded arguments
     * @throws PickleException if `args` does not have exactly three elements
     */
    def construct(args: Array[Object]): Object = {
      if (args.length != 3) {
        throw new PickleException("length of args should be 3")
      }
      val size = args(0).asInstanceOf[Int]
      val indicesRaw = getBytes(args(1))
      val valuesRaw = getBytes(args(2))
      // The entry count comes from the index payload; the values array is sized to match.
      val entryCount = indicesRaw.length / 4
      val indices = new Array[Int](entryCount)
      val values = new Array[Double](entryCount)
      // With no entries the byte payloads are not read at all.
      if (entryCount > 0) {
        val nativeOrder = ByteOrder.nativeOrder()
        ByteBuffer.wrap(indicesRaw).order(nativeOrder).asIntBuffer().get(indices)
        ByteBuffer.wrap(valuesRaw).order(nativeOrder).asDoubleBuffer().get(values)
      }
      new SparseVector(size, indices, values)
    }
  }

  // Set to true once the four picklers below have been registered by initialize().
  var initialized = false
  // Registers the picklers for DenseVector, DenseMatrix, SparseMatrix and SparseVector.
  // This should be called before trying to serialize any of the above classes.
  // In cluster mode, this call should be put in the closure.
  // SerDeUtil.initialize() is invoked on every call; the pickler registration itself
  // runs at most once, guarded by the `initialized` flag inside the synchronized block.
  override def initialize(): Unit = {
    SerDeUtil.initialize()
    synchronized {
      if (!initialized) {
        new DenseVectorPickler().register()
        new DenseMatrixPickler().register()
        new SparseMatrixPickler().register()
        new SparseVectorPickler().register()
        initialized = true
      }
    }
  }
  // Run the registration as part of this object's own initialization.
  // Per the original note, this will not be called automatically on executors.
  initialize()
}

相关信息

spark 源码目录

相关文章

spark ArrayWrappers 源码

spark InMemoryStore 源码

spark KVIndex 源码

spark KVStore 源码

spark KVStoreIterator 源码

spark KVStoreSerializer 源码

spark KVStoreView 源码

spark KVTypeInfo 源码

spark LevelDB 源码

spark LevelDBIterator 源码

0  赞