diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala b/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala index 398a0e4572c..e053ddce4d4 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/commands/OptimizeTableCommand.scala @@ -391,7 +391,8 @@ class OptimizeExecutor( MultiDimClustering.cluster( input, approxNumFiles, - zOrderByColumns) + zOrderByColumns, + "zorder") } else { val useRepartition = sparkSession.sessionState.conf.getConf( DeltaSQLConf.DELTA_OPTIMIZE_REPARTITION_ENABLED) diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertIndex.scala b/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertIndex.scala new file mode 100644 index 00000000000..fd5f605b540 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertIndex.scala @@ -0,0 +1,403 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.expressions + +import java.util + +import scala.collection.mutable + +import org.apache.spark.sql.delta.expressions.HilbertUtils._ + +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{ExpectsInputTypes, Expression} +import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback +import org.apache.spark.sql.types.{AbstractDataType, DataType, DataTypes} + +/** + * Represents a hilbert index built from the provided columns. + * The columns are expected to all be Ints and to have at most numBits individually. + * The points along the hilbert curve are represented by Longs. + */ +private[sql] case class HilbertLongIndex(numBits: Int, children: Seq[Expression]) + extends Expression with ExpectsInputTypes with CodegenFallback { + + private val n: Int = children.size + private val nullValue: Int = 0 + + override def nullable: Boolean = false + + // pre-initialize working set array + private val ints = new Array[Int](n) + + override def eval(input: InternalRow): Any = { + var i = 0 + while (i < n) { + ints(i) = children(i).eval(input) match { + case null => nullValue + case int: Integer => int + case any => throw new IllegalArgumentException( + s"${this.getClass.getSimpleName} expects only inputs of type Int, but got: " + + s"$any of type${any.getClass.getSimpleName}") + } + i += 1 + } + + HilbertStates.getStateList(n).translateNPointToDKey(ints, numBits) + } + + override def dataType: DataType = DataTypes.LongType + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(n)(DataTypes.IntegerType) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): HilbertLongIndex = copy(children = newChildren) +} + +/** + * Represents a hilbert index built from the provided columns. + * The columns are expected to all be Ints and to have at most numBits. 
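+ * It is selected by hilbert_index (see MultiDimClusteringFunctions below) when the number of
+ * columns times numBits is 64 or more, i.e. when the resulting key may not fit in a Long.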
+ * The points along the hilbert curve are represented by Byte arrays. + */ +private[sql] case class HilbertByteArrayIndex(numBits: Int, children: Seq[Expression]) + extends Expression with ExpectsInputTypes with CodegenFallback { + + private val n: Int = children.size + private val nullValue: Int = 0 + + override def nullable: Boolean = false + + // pre-initialize working set array + private val ints = new Array[Int](n) + + override def eval(input: InternalRow): Any = { + var i = 0 + while (i < n) { + ints(i) = children(i).eval(input) match { + case null => nullValue + case int: Integer => int + case any => throw new IllegalArgumentException( + s"${this.getClass.getSimpleName} expects only inputs of type Int, but got: " + + s"$any of type${any.getClass.getSimpleName}") + } + i += 1 + } + + HilbertStates.getStateList(n).translateNPointToDKeyArray(ints, numBits) + } + + override def dataType: DataType = DataTypes.BinaryType + + override def inputTypes: Seq[AbstractDataType] = Seq.fill(n)(DataTypes.IntegerType) + + override protected def withNewChildrenInternal( + newChildren: IndexedSeq[Expression]): HilbertByteArrayIndex = copy(children = newChildren) +} + +// scalastyle:off line.size.limit +/** + * The following code is based on this paper: + * https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=bfd6d94c98627756989b0147a68b7ab1f881a0d6 + * with optimizations around matrix manipulation taken from this one: + * https://pdfs.semanticscholar.org/4043/1c5c43a2121e1bc071fc035e90b8f4bb7164.pdf + * + * At a high level you construct a GeneratorTable with the getStateGenerator method. + * That represents the information necessary to construct a state list for a given number + * of dimension, N. + * Once you have the generator table for your dimension you can construct a state list. + * You can then turn those state lists into compact state lists that store all the information + * in one large array of longs. + */ +// scalastyle:on line.size.limit +object HilbertIndex { + + private type CompactStateList = HilbertCompactStateList + + val SIZE_OF_INT = 32 + + /** + * Construct the generator table for a space of dimension n. + * This table consists of 2^n rows, each row containing Y, X1, and TY. + * Y The index in the array representing the table. (0 to (2^n - 1)) + * X1 A coordinate representing points on the curve expressed as an n-point. + * These are arranged such that if two rows differ by 1 in Y then the binary + * representation of their X1 values differ by exactly one bit. + * These are the "Gray-codes" of their Y value. + * TY A transformation matrix that transforms X2(1) to the X1 value where Y is zero and + * transforms X2(2) to the X1 value where Y is (2^n - 1) + */ + def getStateGenerator(n: Int): GeneratorTable = { + val x2s = getX2GrayCodes(n) + + val len = 1 << n + val rows = (0 until len).map { i => + // A pair of n-points corresponding to the first and last points on the first order curve to + // which X1 transforms in the construction of a second order curve. + val x21 = x2s(i << 1) + val x22 = x2s((i << 1) + 1) + // Represents the magnitude of difference between X2 values in this row. 
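+ // Because x21 and x22 are consecutive entries of the X2 gray-code sequence, dy is expected
+ // to have exactly one set bit; getSetColumn (HilbertUtils) relies on that assumption.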
+ val dy = x21 ^ x22 + + Row( + y = i, + x1 = i ^ (i >>> 1), + m = HilbertMatrix(n, x21, getSetColumn(n, dy)) + ) + } + + new GeneratorTable(n, rows) + } + + // scalastyle:off line.size.limit + /** + * This will construct an x2-gray-codes sequence of order n as described in + * https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=bfd6d94c98627756989b0147a68b7ab1f881a0d6 + * + * Each pair of values corresponds to the first and last coordinates of points on a first + * order curve to which a point taken from column X1 transforms to at the second order. + */ + // scalastyle:on line.size.limit + private[this] def getX2GrayCodes(n: Int) : Array[Int] = { + if (n == 1) { + // hard code the base case + return Array(0, 1, 0, 1) + } + val mask = 1 << (n - 1) + val base = getX2GrayCodes(n - 1) + base(base.length - 1) = base(base.length - 2) + mask + val result = Array.fill(base.length * 2)(0) + base.indices.foreach { i => + result(i) = base(i) + result(result.length - 1 - i) = base(i) ^ mask + } + result + } + + private[this] case class Row(y: Int, x1: Int, m: HilbertMatrix) + + private[this] case class PointState(y: Int, var x1: Int = 0, var state: Int = 0) + + private[this] case class State(id: Int, matrix: HilbertMatrix, var pointStates: Seq[PointState]) + + private[sql] class StateList(n: Int, states: Map[Int, State]) { + def getNPointToDKeyStateMap: CompactStateList = { + val numNPoints = 1 << n + val array = new Array[Long](numNPoints * states.size) + + states.foreach { case (stateIdx, state) => + val stateStartIdx = stateIdx * numNPoints + + state.pointStates.foreach { ps => + val psLong = (ps.y.toLong << SIZE_OF_INT) | ps.state.toLong + array(stateStartIdx + ps.x1) = psLong + } + } + new CompactStateList(n, array) + } + def getDKeyToNPointStateMap: CompactStateList = { + val numNPoints = 1 << n + val array = new Array[Long](numNPoints * states.size) + + states.foreach { case (stateIdx, state) => + val stateStartIdx = stateIdx * numNPoints + + state.pointStates.foreach { ps => + val psLong = (ps.x1.toLong << SIZE_OF_INT) | ps.state.toLong + array(stateStartIdx + ps.y) = psLong + } + } + new CompactStateList(n, array) + } + } + + private[sql] class GeneratorTable(n: Int, rows: Seq[Row]) { + def generateStateList(): StateList = { + val result = mutable.Map[Int, State]() + val list = new util.LinkedList[State]() + + var nextStateNum = 1 + + val initialState = State(0, HilbertMatrix.identity(n), rows.map(r => PointState(r.y, r.x1))) + result.put(0, initialState) + + rows.foreach { row => + val matrix = row.m + result.find { case (_, s) => s.matrix == matrix } match { + case Some((_, s)) => + initialState.pointStates(row.y).state = s.id + case _ => + initialState.pointStates(row.y).state = nextStateNum + val newState = State(nextStateNum, matrix, Seq()) + result.put(nextStateNum, newState) + list.addLast(newState) + nextStateNum += 1 + } + } + + while (!list.isEmpty) { + val currentState = list.removeFirst() + currentState.pointStates = rows.indices.map(r => PointState(r)) + + rows.indices.foreach { i => + val j = currentState.matrix.transform(i) + val p = initialState.pointStates.find(_.x1 == j).get + val currentPointState = currentState.pointStates(p.y) + currentPointState.x1 = i + val tm = result(p.state).matrix.multiply(currentState.matrix) + + result.find { case (_, s) => s.matrix == tm } match { + case Some((_, s)) => + currentPointState.state = s.id + case _ => + currentPointState.state = nextStateNum + val newState = State(nextStateNum, tm, Seq()) + result.put(nextStateNum, 
newState) + list.addLast(newState) + nextStateNum += 1 + } + } + } + + new StateList(n, result.toMap) + } + } +} + +/** + * Represents a compact state map. This is used in the mapping between n-points and d-keys. + * [[array]] is treated as a Map(Int -> Map(Int -> (Int, Int))) + * + * Each values in the array will be a combination of two things, a point and the index of the + * next state, in the most- and least- significant bits, respectively. + * state -> coord -> [point + nextState] + */ +private[sql] class HilbertCompactStateList(n: Int, array: Array[Long]) { + private val maxNumN = 1 << n + private val mask = maxNumN - 1 + private val intMask = (1L << HilbertIndex.SIZE_OF_INT) - 1 + + // point and nextState + @inline def transform(nPoint: Int, state: Int): (Int, Int) = { + val value = array(state * maxNumN + nPoint) + ( + (value >>> HilbertIndex.SIZE_OF_INT).toInt, + (value & intMask).toInt + ) + } + + // These while loops are to minimize overhead. + // This method exists only for testing + private[expressions] def translateDKeyToNPoint(key: Long, k: Int): Array[Int] = { + val result = new Array[Int](n) + var currentState = 0 + var i = 0 + while (i < k) { + val h = (key >> ((k - 1 - i) * n)) & mask + + val (z, nextState) = transform(h.toInt, currentState) + + var j = 0 + while (j < n) { + val v = (z >> (n - 1 - j)) & 1 + result(j) = (result(j) << 1) | v + j += 1 + } + + currentState = nextState + i += 1 + } + result + } + + // These while loops are to minimize overhead. + // This method exists only for testing + private[expressions] def translateDKeyArrayToNPoint(key: Array[Byte], k: Int): Array[Int] = { + val result = new Array[Int](n) + val initialOffset = (key.length * 8) - (k * n) + var currentState = 0 + var i = 0 + while (i < k) { + val offset = initialOffset + (i * n) + val h = getBits(key, offset, n) + + val (z, nextState) = transform(h, currentState) + + var j = 0 + while (j < n) { + val v = (z >> (n - 1 - j)) & 1 + result(j) = (result(j) << 1) | v + j += 1 + } + + currentState = nextState + i += 1 + } + result + } + + /** + * Translate an n-dimensional point into it's corresponding position on the n-dimensional + * hilbert curve. + * @param point An n-dimensional point. (assumed to have n elements) + * @param k The number of meaningful bits in each value of the point. + */ + def translateNPointToDKey(point: Array[Int], k: Int): Long = { + var result = 0L + var currentState = 0 + var i = 0 + while (i < k) { + var z = 0 + var j = 0 + while (j < n) { + z = (z << 1) | ((point(j) >> (k - 1 - i)) & 1) + j += 1 + } + val (h, nextState) = transform(z, currentState) + result = (result << n) | h + currentState = nextState + i += 1 + } + result + } + + /** + * Translate an n-dimensional point into it's corresponding position on the n-dimensional + * hilbert curve. Returns the resulting integer as an array of bytes. + * @param point An n-dimensional point. (assumed to have n elements) + * @param k The number of meaningful bits in each value of the point. 
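+ * @return A big-endian byte array of ceil(n * k / 8) bytes, with the key right-aligned and
+ *         padded with leading zero bits.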
+ */ + def translateNPointToDKeyArray(point: Array[Int], k: Int): Array[Byte] = { + val numBits = k * n + val numBytes = (numBits + 7) / 8 + val result = new Array[Byte](numBytes) + val initialOffset = (numBytes * 8) - numBits + var currentState = 0 + var i = 0 + while (i < k) { + var z = 0 + var j = 0 + while (j < n) { + z = (z << 1) | ((point(j) >> (k - 1 - i)) & 1) + j += 1 + } + val (h, nextState) = transform(z, currentState) + setBits(result, initialOffset + (i * n), h, n) + currentState = nextState + i += 1 + } + result + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertStates.java b/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertStates.java new file mode 100644 index 00000000000..d7c69a7f7e4 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertStates.java @@ -0,0 +1,92 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.expressions; + +import org.apache.spark.SparkException; + +public class HilbertStates { + + /** + * Constructs a hilbert state for the given arity, [[n]]. + * This state list can be used to map n-points to their corresponding d-key value. + * + * @param n The number of bits in this space (we assert 2 <= n <= 9 for simplicity) + * @return The CompactStateList for mapping from n-point to hilbert distance key. 
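+ *
+ * Note: n is the arity, i.e. the number of clustering dimensions. Building a state list is
+ * expensive, so getStateList below caches one instance per dimension in lazily initialized
+ * holder classes.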
+ */ + private static HilbertCompactStateList constructHilbertState(int n) { + HilbertIndex.GeneratorTable generator = HilbertIndex.getStateGenerator(n); + return generator.generateStateList().getNPointToDKeyStateMap(); + } + + private HilbertStates() { } + + private static class HilbertIndex2 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(2); + } + + private static class HilbertIndex3 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(3); + } + + private static class HilbertIndex4 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(4); + } + + private static class HilbertIndex5 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(5); + } + + private static class HilbertIndex6 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(6); + } + + private static class HilbertIndex7 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(7); + } + + private static class HilbertIndex8 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(8); + } + + private static class HilbertIndex9 { + static final HilbertCompactStateList STATE_LIST = constructHilbertState(9); + } + + public static HilbertCompactStateList getStateList(int n) throws SparkException { + switch (n) { + case 2: + return HilbertIndex2.STATE_LIST; + case 3: + return HilbertIndex3.STATE_LIST; + case 4: + return HilbertIndex4.STATE_LIST; + case 5: + return HilbertIndex5.STATE_LIST; + case 6: + return HilbertIndex6.STATE_LIST; + case 7: + return HilbertIndex7.STATE_LIST; + case 8: + return HilbertIndex8.STATE_LIST; + case 9: + return HilbertIndex9.STATE_LIST; + default: + throw new SparkException(String.format("Cannot perform hilbert clustering on " + + "fewer than 2 or more than 9 dimensions; got %d dimensions", n)); + } + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertUtils.scala b/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertUtils.scala new file mode 100644 index 00000000000..ebffce4b901 --- /dev/null +++ b/spark/src/main/scala/org/apache/spark/sql/delta/expressions/HilbertUtils.scala @@ -0,0 +1,165 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.expressions + +object HilbertUtils { + + /** + * Returns the column number that is set. We assume that a bit is set. 
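+ * For example, getSetColumn(4, 4) returns 1: 4 is 0100 in binary, and columns are counted
+ * from the left starting at 0.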
+ */ + @inline def getSetColumn(n: Int, i: Int): Int = { + n - 1 - Integer.numberOfTrailingZeros(i) + } + + @inline def circularLeftShift(n: Int, i: Int, shift: Int): Int = { + ((i << shift) | (i >>> (n - shift))) & ((1 << n) - 1) + } + + @inline def circularRightShift(n: Int, i: Int, shift: Int): Int = { + ((i >>> shift) | (i << (n - shift))) & ((1 << n) - 1) + } + + @inline + private[expressions] def getBits(key: Array[Byte], offset: Int, n: Int): Int = { + // [ ][ ][ ][ ][ ] + // <---offset---> [ n-bits ] <- this is the result + var result = 0 + + var remainingBits = n + var keyIndex = offset / 8 + // initial key offset + var keyOffset = offset - (keyIndex * 8) + while (remainingBits > 0) { + val bitsFromIdx = math.min(remainingBits, 8 - keyOffset) + val newInt = if (remainingBits >= 8) { + java.lang.Byte.toUnsignedInt(key(keyIndex)) + } else { + java.lang.Byte.toUnsignedInt(key(keyIndex)) >>> (8 - keyOffset - bitsFromIdx) + } + result = (result << bitsFromIdx) | (newInt & ((1 << bitsFromIdx) - 1)) + + remainingBits -= (8 - keyOffset) + keyOffset = 0 + keyIndex += 1 + } + + result + } + + @inline + private[expressions] def setBits( + key: Array[Byte], + offset: Int, + newBits: Int, + n: Int): Array[Byte] = { + // bits: [ meaningless bits ][ n meaningful bits ] + // + // [ ][ ][ ][ ][ ] + // <---offset---> [ n-bits ] + + // move meaningful bits to the far left + var bits = newBits << (32 - n) + var remainingBits = n + + // initial key index + var keyIndex = offset / 8 + // initial key offset + var keyOffset = offset - (keyIndex * 8) + while (remainingBits > 0) { + key(keyIndex) = (key(keyIndex) | (bits >>> (24 + keyOffset))).toByte + remainingBits -= (8 - keyOffset) + bits = bits << (8 - keyOffset) + keyOffset = 0 + keyIndex += 1 + } + key + } + + /** + * treats `key` as an Integer and adds 1 + */ + @inline def addOne(key: Array[Byte]): Array[Byte] = { + var idx = key.length - 1 + var overflow = true + while (overflow && idx >= 0) { + key(idx) = (key(idx) + 1.toByte).toByte + overflow = key(idx) == 0 + idx -= 1 + } + key + } + + def manhattanDist(p1: Array[Int], p2: Array[Int]): Int = { + assert(p1.length == p2.length) + p1.zip(p2).map { case (a, b) => math.abs(a - b) }.sum + } + + + /** + * This is not really a matrix, but a representation of one. Due to the constraints of this + * system the necessary matrices can be defined by two values: dY and X2. DY is the amount + * of right shifting of the identity matrix, and X2 is a bitmask for which column values are + * negative. The [[toString]] method is overridden to construct and print the matrix to aid + * in debugging. + * Instead of constructing the matrix directly we store and manipulate these values. + */ + case class HilbertMatrix(n: Int, x2: Int, dy: Int) { + override def toString(): String = { + val sb = new StringBuilder() + + val base = 1 << (n - 1 - dy) + (0 until n).foreach { i => + sb.append('\n') + val row = circularRightShift(n, base, i) + (0 until n).foreach { j => + if (isColumnSet(row, j)) { + if (isColumnSet(x2, j)) { + sb.append('-') + } else { + sb.append(' ') + } + sb.append('1') + } else { + sb.append(" 0") + } + } + } + sb.append('\n') + sb.toString + } + + // columns count from the left: 0, 1, 2 ... 
, n + @inline def isColumnSet(i: Int, column: Int): Boolean = { + val mask = 1 << (n - 1 - column) + (i & mask) > 0 + } + + def transform(e: Int): Int = { + circularLeftShift(n, e ^ x2, dy) + } + + def multiply(other: HilbertMatrix): HilbertMatrix = { + HilbertMatrix(n, circularRightShift(n, x2, other.dy) ^ other.x2, (dy + other.dy) % n) + } + } + + object HilbertMatrix { + def identity(n: Int): HilbertMatrix = { + HilbertMatrix(n, 0, 0) + } + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClustering.scala b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClustering.scala index 04483db5b0b..46c3e0132bc 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClustering.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClustering.scala @@ -21,6 +21,7 @@ import java.util.UUID import org.apache.spark.sql.delta.skipping.MultiDimClusteringFunctions._ import org.apache.spark.sql.delta.sources.DeltaSQLConf +import org.apache.spark.SparkException import org.apache.spark.internal.Logging import org.apache.spark.sql._ import org.apache.spark.sql.functions._ @@ -38,13 +39,24 @@ trait MultiDimClustering extends Logging { } object MultiDimClustering { - /** Repartition the given dataframe `df` into `approxNumPartitions` on the given `colNames`. */ + /** + * Repartition the given dataframe `df` based on the given `curve` type into + * `approxNumPartitions` on the given `colNames`. + */ def cluster( df: DataFrame, approxNumPartitions: Int, - colNames: Seq[String]): DataFrame = { + colNames: Seq[String], + curve: String): DataFrame = { assert(colNames.nonEmpty, "Cannot cluster by zero columns!") - ZOrderClustering.cluster(df, colNames, approxNumPartitions, randomizationExpressionOpt = None) + val clusteringImpl = curve match { + case "hilbert" => HilbertClustering + case "zorder" => ZOrderClustering + case unknownCurve => + throw new SparkException(s"Unknown curve ($unknownCurve), unable to perform multi " + + "dimensional clustering.") + } + clusteringImpl.cluster(df, colNames, approxNumPartitions, randomizationExpressionOpt = None) } } @@ -90,3 +102,12 @@ object ZOrderClustering extends SpaceFillingCurveClustering { interleave_bits(rangeIdCols: _*).cast(StringType) } } + +object HilbertClustering extends SpaceFillingCurveClustering with Logging { + override protected def getClusteringExpression(cols: Seq[Column], numRanges: Int): Column = { + assert(cols.size > 1, "Cannot do Hilbert clustering by zero or one column!") + val rangeIdCols = cols.map(range_partition_id(_, numRanges)) + val numBits = Integer.numberOfTrailingZeros(Integer.highestOneBit(numRanges)) + 1 + hilbert_index(numBits, rangeIdCols: _*) + } +} diff --git a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctions.scala b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctions.scala index a81b63c433f..68496dbdcae 100644 --- a/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctions.scala +++ b/spark/src/main/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctions.scala @@ -17,10 +17,12 @@ package org.apache.spark.sql.delta.skipping // scalastyle:off import.ordering.noEmptyLine -import org.apache.spark.sql.delta.expressions.{InterleaveBits, RangePartitionId} +import org.apache.spark.sql.delta.expressions.{HilbertByteArrayIndex, HilbertLongIndex, InterleaveBits, RangePartitionId} +import org.apache.spark.SparkException import 
org.apache.spark.sql.Column -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Cast, Expression} +import org.apache.spark.sql.types.StringType /** Functions for multi-dimensional clustering of the data */ object MultiDimClusteringFunctions { @@ -54,4 +56,26 @@ object MultiDimClusteringFunctions { def interleave_bits(cols: Column*): Column = withExpr { InterleaveBits(cols.map(_.expr)) } + + // scalastyle:off line.size.limit + /** + * Transforms the provided integer columns into their corresponding position in the hilbert + * curve for the given dimension. + * @see https://citeseerx.ist.psu.edu/document?repid=rep1&type=pdf&doi=bfd6d94c98627756989b0147a68b7ab1f881a0d6 + * @see https://en.wikipedia.org/wiki/Hilbert_curve + * @param numBits The number of bits to consider in each column. + * @param cols The integer columns to map to the curve. + */ + // scalastyle:on line.size.limit + def hilbert_index(numBits: Int, cols: Column*): Column = withExpr { + if (cols.size > 9) { + throw new SparkException("Hilbert indexing can only be used on 9 or fewer columns.") + } + val hilbertBits = cols.length * numBits + if (hilbertBits < 64) { + HilbertLongIndex(numBits, cols.map(_.expr)) + } else { + Cast(HilbertByteArrayIndex(numBits, cols.map(_.expr)), StringType) + } + } } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/expressions/HilbertIndexSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/expressions/HilbertIndexSuite.scala new file mode 100644 index 00000000000..a377301e5b3 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/expressions/HilbertIndexSuite.scala @@ -0,0 +1,200 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.delta.expressions + +import java.util + +import org.scalatest.Tag +import org.apache.spark.SparkFunSuite + +class HilbertIndexSuite extends SparkFunSuite { + + /** + * Represents a test case. Each n-k pair will verify the continuity of the mapping, + * and the reversibility of it. 
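+ * Continuity means that consecutive d-keys map to n-points at Manhattan distance 1;
+ * reversibility means that mapping the resulting n-point back yields the original d-key.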
+ * @param n The number of dimensions + * @param k The number of bits in each dimension + */ + case class TestCase(n: Int, k: Int) + val testCases = Seq( + TestCase(2, 10), + TestCase(3, 6), + TestCase(4, 5), + TestCase(5, 4), + TestCase(6, 3) + ) + + def gridTest[A](testNamePrefix: String, testTags: Tag*)(params: Seq[A])( + testFun: A => Unit): Unit = { + for (param <- params) { + test(testNamePrefix + s" ($param)", testTags: _*)(testFun(param)) + } + } + + gridTest("HilbertStates caches states")(2 to 9) { n => + val start = System.nanoTime() + HilbertStates.getStateList(n) + val end = System.nanoTime() + + HilbertStates.getStateList(n) + val end2 = System.nanoTime() + assert(end2 - end < end - start) + } + + gridTest("Hilbert Mapping is continuous (long keys)")(testCases) { case TestCase(n, k) => + val generator = HilbertIndex.getStateGenerator(n) + + val stateList = generator.generateStateList() + + val states = stateList.getDKeyToNPointStateMap + + val maxDKeys = 1L << (k * n) + var d = 0 + var lastPoint = new Array[Int](n) + while (d < maxDKeys) { + val point = states.translateDKeyToNPoint(d, k) + if (d != 0) { + assert(HilbertUtils.manhattanDist(lastPoint, point) == 1) + } + + lastPoint = point + d += 1 + } + + } + + gridTest("Hilbert Mapping is 1 to 1 (long keys)")(testCases) { case TestCase(n, k) => + val generator = HilbertIndex.getStateGenerator(n) + val stateList = generator.generateStateList() + + val d2p = stateList.getDKeyToNPointStateMap + val p2d = stateList.getNPointToDKeyStateMap + + val maxDKeys = 1L << (k * n) + var d = 0 + while (d < maxDKeys) { + val point = d2p.translateDKeyToNPoint(d, k) + val d2 = p2d.translateNPointToDKey(point, k) + assert(d == d2) + d += 1 + } + } + + gridTest("Hilbert Mapping is continuous (array keys)")(testCases) { case TestCase(n, k) => + val generator = HilbertIndex.getStateGenerator(n) + + val stateList = generator.generateStateList() + + val states = stateList.getDKeyToNPointStateMap + + val maxDKeys = 1L << (k * n) + val d = new Array[Byte](((k * n) / 8) + 1) + var lastPoint = new Array[Int](n) + var i = 0 + while (i < maxDKeys) { + val point = states.translateDKeyArrayToNPoint(d, k) + if (i != 0) { + assert(HilbertUtils.manhattanDist(lastPoint, point) == 1, + s"$i ${d.toSeq.map(_.toBinaryString.takeRight(8))} ${lastPoint.toSeq} to ${point.toSeq}") + } + + lastPoint = point + i += 1 + HilbertUtils.addOne(d) + } + + } + + gridTest("Hilbert Mapping is 1 to 1 (array keys)")(testCases) { case TestCase(n, k) => + val generator = HilbertIndex.getStateGenerator(n) + val stateList = generator.generateStateList() + + val d2p = stateList.getDKeyToNPointStateMap + val p2d = stateList.getNPointToDKeyStateMap + + val maxDKeys = 1L << (k * n) + val d = new Array[Byte](((k * n) / 8) + 1) + var i = 0 + while (i < maxDKeys) { + val point = d2p.translateDKeyArrayToNPoint(d, k) + val d2 = p2d.translateNPointToDKeyArray(point, k) + assert(util.Arrays.equals(d, d2), s"$i ${d.toSeq}, ${d2.toSeq}") + i += 1 + HilbertUtils.addOne(d) + } + } + + gridTest("continuous and 1 to 1 for all spaces")((2 to 9).map(n => TestCase(n, 15 - n))) { + case TestCase(n, k) => + val generator = HilbertIndex.getStateGenerator(n) + val stateList = generator.generateStateList() + + val d2p = stateList.getDKeyToNPointStateMap + val p2d = stateList.getNPointToDKeyStateMap + + val numBits = k * n + val numBytes = (numBits + 7) / 8 + + // test 1000 contiguous 1000 point blocks to make sure the mapping is continuous and one to one + + val maxDKeys = 1L << (k * n) + val step = maxDKeys / 
1000 + var x = 0L + for (_ <- 0 until 1000) { + var dLong = x + val bigIntArray = BigInt(dLong).toByteArray + val dArray = new Array[Byte](numBytes) + + System.arraycopy( + bigIntArray, + math.max(0, bigIntArray.length - dArray.length), + dArray, + math.max(0, dArray.length - bigIntArray.length), + math.min(bigIntArray.length, dArray.length) + ) + + var lastPoint: Array[Int] = null + + for (_ <- 0 until 1000) { + val pArray = d2p.translateDKeyArrayToNPoint(dArray, k) + val pLong = d2p.translateDKeyToNPoint(dLong, k) + assert(util.Arrays.equals(pArray, pLong), s"points should be the same at $dLong") + + if (lastPoint != null) { + assert(HilbertUtils.manhattanDist(lastPoint, pLong) == 1, + s"distance between point and last point should be the same at $dLong") + } + + val dArray2 = p2d.translateNPointToDKeyArray(pArray, k) + val dLong2 = p2d.translateNPointToDKey(pLong, k) + + assert(dLong == dLong2, s"reversing the points should map correctly at $dLong != $dLong2") + + assert(util.Arrays.equals(dArray, dArray2), + s"reversing the points should map correctly at $dLong") + + lastPoint = pLong + + dLong += 1 + HilbertUtils.addOne(dArray) + } + + x += step + } + + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/expressions/HilbertUtilsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/expressions/HilbertUtilsSuite.scala new file mode 100644 index 00000000000..342af67d338 --- /dev/null +++ b/spark/src/test/scala/org/apache/spark/sql/delta/expressions/HilbertUtilsSuite.scala @@ -0,0 +1,129 @@ +/* + * Copyright (2021) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.sql.delta.expressions + +import java.util + +import org.apache.spark.sql.delta.expressions.HilbertUtils.HilbertMatrix + +import org.apache.spark.SparkFunSuite + +class HilbertUtilsSuite extends SparkFunSuite { + + test("circularLeftShift") { + assert( + (0 until (1 << 10) by 7).forall(i => HilbertUtils.circularLeftShift(10, i, 0) == i), + "Shift by 0 should be a no op" + ) + assert( + (0 until (1 << 10) by 7).forall(i => HilbertUtils.circularLeftShift(10, i, 10) == i), + "Shift by n should be a no op" + ) + // 0111 (<< 2) => 1101 + assert( + HilbertUtils.circularLeftShift(4, 7, 2) == 13, + "handle wrapping" + ) + assert( + (0 until (1 << 5)).forall(HilbertUtils.circularLeftShift(5, _, 5) <= (1 << 5)), + "always mask values based on n" + ) + } + + test("circularRightShift") { + assert( + (0 until (1 << 10) by 7).forall(i => HilbertUtils.circularRightShift(10, i, 0) == i), + "Shift by 0 should be a no op" + ) + assert( + (0 until (1 << 10) by 7).forall(i => HilbertUtils.circularRightShift(10, i, 10) == i), + "Shift by n should be a no op" + ) + // 0111 (>> 2) => 1101 + assert( + HilbertUtils.circularRightShift(4, 7, 2) == 13, + "handle wrapping" + ) + assert( + (0 until (1 << 5)).forall(HilbertUtils.circularRightShift(5, _, 5) <= (1 << 5)), + "always mask values based on n" + ) + } + + test("getSetColumn should return the column that is set") { + (0 until 16) foreach { i => + assert(HilbertUtils.getSetColumn(16, 1 << i) == 16 - 1 - i) + } + } + + test("HilbertMatrix makes sense") { + val identityMatrix = HilbertMatrix.identity(10) + (0 until (1 << 10) by 7) foreach { i => + assert(identityMatrix.transform(i) == i, s"$i transformed by the identity should be $i") + } + + identityMatrix.multiply(HilbertMatrix.identity(10)) == identityMatrix + + val shift5 = HilbertMatrix(10, 0, 5) + assert(shift5.multiply(shift5) == identityMatrix, "shift by 5 twice should equal identity") + } + + test("HilbertUtils.getBits") { + assert(HilbertUtils.getBits(Array(0, 0, 1), 22, 2) == 1) + val array = Array[Byte](0, 0, -1, 0) + assert(HilbertUtils.getBits(array, 16, 4) == 15) + assert(HilbertUtils.getBits(array, 18, 3) == 7) + assert(HilbertUtils.getBits(array, 23, 1) == 1) + assert(HilbertUtils.getBits(array, 23, 2) == 2) + assert(HilbertUtils.getBits(array, 23, 8) == 128) + assert(HilbertUtils.getBits(array, 16, 3) == 7) + assert(HilbertUtils.getBits(array, 16, 2) == 3) + assert(HilbertUtils.getBits(array, 16, 1) == 1) + assert(HilbertUtils.getBits(array, 15, 2) == 1) + assert(HilbertUtils.getBits(array, 15, 1) == 0) + assert(HilbertUtils.getBits(array, 12, 8) == 15) + assert(HilbertUtils.getBits(array, 12, 12) == 255) + assert(HilbertUtils.getBits(array, 12, 13) == (255 << 1)) + + assert(HilbertUtils.getBits(Array(0, 1, 0), 6, 6) == 0) + assert(HilbertUtils.getBits(Array(0, 1, 0), 12, 6) == 4) + assert(HilbertUtils.getBits(Array(0, 1, 0), 18, 6) == 0) + } + + def check(received: Array[Byte], expected: Array[Byte]): Unit = { + assert(util.Arrays.equals(expected, received), + s"${expected.toSeq.map(_.toBinaryString.takeRight(8))} " + + s"${received.toSeq.map(_.toBinaryString.takeRight(8))}") + } + + test("HilbertUtils.setBits") { + check(HilbertUtils.setBits(Array(0, 0, 0), 7, 8, 4), Array(1, 0, 0)) + check(HilbertUtils.setBits(Array(0, 0, 0), 7, 12, 4), Array(1, (1.toByte << 7).toByte, 0)) + check(HilbertUtils.setBits(Array(8, 0, 5), 7, 12, 4), Array(9, (1.toByte << 7).toByte, 5)) + check(HilbertUtils.setBits(Array(8, 0, 2), 7, -1, 12), + Array(9, -1, ((7.toByte << 5).toByte | 
2).toByte)) + check(HilbertUtils.setBits(Array(8, 14, 2), 15, 1, 1), Array(8, 15, 2)) + } + + test("addOne") { + check(HilbertUtils.addOne(Array(0, 0, 0)), Array(0, 0, 1)) + check(HilbertUtils.addOne(Array(0, 0, -1)), Array(0, 1, 0)) + check(HilbertUtils.addOne(Array(0, 0, -2)), Array(0, 0, -1)) + check(HilbertUtils.addOne(Array(0, -1, -1)), Array(1, 0, 0)) + check(HilbertUtils.addOne(Array(-1, -1, -1)), Array(0, 0, 0)) + } +} diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctionsSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctionsSuite.scala index 43347b7f2f8..7c4e9039799 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctionsSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringFunctionsSuite.scala @@ -20,10 +20,13 @@ import java.nio.ByteBuffer import scala.util.Random +import org.apache.spark.sql.delta.expressions.{HilbertByteArrayIndex, HilbertLongIndex} import org.apache.spark.sql.delta.skipping.MultiDimClusteringFunctions._ import org.apache.spark.sql.delta.test.DeltaSQLCommandTest +import org.apache.spark.SparkException import org.apache.spark.sql.{AnalysisException, DataFrame, QueryTest, Row} +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.functions.lit import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.test.SharedSparkSession @@ -229,6 +232,22 @@ class MultiDimClusteringFunctionsSuite extends QueryTest ) } + test("hilbert_index selects underlying expression correctly") { + assert(hilbert_index(10, Seq($"c1", $"c2", $"c3", $"c4", $"c5", $"c6"): _*).expr + .isInstanceOf[HilbertLongIndex]) + assert( + hilbert_index( + 10, + Seq($"c1", $"c2", $"c3", $"c4", $"c5", $"c6", $"c7", $"c8", $"c9"): _*) + .expr.asInstanceOf[Cast].child.isInstanceOf[HilbertByteArrayIndex]) + val e = intercept[SparkException]( + hilbert_index( + 11, + Seq($"c1", $"c2", $"c3", $"c4", $"c5", $"c6", $"c7", $"c8", $"c9", $"c10"): _*) + .expr.isInstanceOf[HilbertByteArrayIndex]) + assert(e.getMessage.contains("Hilbert indexing can only be used on 9 or fewer columns.")) + } + private def intToBinary(x: Int): Array[Byte] = { ByteBuffer.allocate(4).putInt(x).array() } diff --git a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringSuite.scala b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringSuite.scala index e586ad13207..28baf076870 100644 --- a/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringSuite.scala +++ b/spark/src/test/scala/org/apache/spark/sql/delta/skipping/MultiDimClusteringSuite.scala @@ -79,7 +79,8 @@ class MultiDimClusteringSuite extends QueryTest val outputDf = MultiDimClustering.cluster( inputDf, approxNumPartitions = 4, - colNames = Seq("c1", "c2")) + colNames = Seq("c1", "c2"), + curve = "zorder") outputDf.write.parquet(new File(tempDir, "source").getCanonicalPath) // Load the partition 0 and verify that it contains (a, 20), (a, 20), (b, 20) @@ -105,35 +106,80 @@ class MultiDimClusteringSuite extends QueryTest } } + test("ensure records with close Hilbert curve values are close in the output") { + withTempDir { tempDir => + withSQLConf(MDC_NUM_RANGE_IDS.key -> "4", MDC_ADD_NOISE.key -> "false") { + val data = Seq( + // "c1" -> "c2", // (rangeId_c1, rangeId_c2) -> Decimal Hilbert index + "a" -> 20, "a" -> 20, // (0, 0) -> 0 + "b" -> 20, // (0, 0) -> 0 + "c" -> 30, // (1, 1) -> 2 + "d" -> 70, // (1, 2) 
-> 13 + "e" -> 90, "e" -> 90, "e" -> 90, // (1, 2) -> 13 + "f" -> 200, // (2, 3) -> 11 + "g" -> 10, // (3, 0) -> 5 + "h" -> 20) // (3, 0) -> 5 + + // Randomize the data. Use seed for deterministic input. + val inputDf = new Random(seed = 101) + .shuffle(data) + .toDF("c1", "c2") + + // Cluster the data and range partition into four partitions + val outputDf = MultiDimClustering.cluster( + inputDf, + approxNumPartitions = 2, + colNames = Seq("c1", "c2"), + curve = "hilbert") + outputDf.write.parquet(new File(tempDir, "source").getCanonicalPath) + + // Load the partition 0 and verify its records. + checkAnswer( + Seq("a" -> 20, "a" -> 20, "b" -> 20, "c" -> 30, "g" -> 10, "h" -> 20).toDF("c1", "c2"), + sparkSession.read.parquet(new File(tempDir, "source/part-00000*").getCanonicalPath) + ) + + // partition 1 + checkAnswer( + Seq("d" -> 70, "e" -> 90, "e" -> 90, "e" -> 90, "f" -> 200).toDF("c1", "c2"), + sparkSession.read.parquet(new File(tempDir, "source/part-00001*").getCanonicalPath) + ) + } + } + } + test("noise is helpful in skew handling") { - Seq("true", "false").foreach { addNoise => - withTempDir { tempDir => - withSQLConf( - MDC_NUM_RANGE_IDS.key -> "4", - MDC_ADD_NOISE.key -> addNoise) { - val data = Array.fill(100)(20, 20) // all records have the same values - val inputDf = data.toSeq.toDF("c1", "c2") - - // Cluster the data and range partition into four partitions - val outputDf = MultiDimClustering.cluster( - inputDf, - approxNumPartitions = 4, - colNames = Seq("c1", "c2")) - - outputDf.write.parquet(new File(tempDir, "source").getCanonicalPath) - - // If there is no noise added, expect only one partition, otherwise four partition - // as mentioned in the cluster command above. - val partCount = new File(tempDir, "source").listFiles(new FilenameFilter { - override def accept(dir: File, name: String): Boolean = { - name.startsWith("part-0000") + Seq("zorder", "hilbert").foreach { curve => + Seq("true", "false").foreach { addNoise => + withTempDir { tempDir => + withSQLConf( + MDC_NUM_RANGE_IDS.key -> "4", + MDC_ADD_NOISE.key -> addNoise) { + val data = Array.fill(100)(20, 20) // all records have the same values + val inputDf = data.toSeq.toDF("c1", "c2") + + // Cluster the data and range partition into four partitions + val outputDf = MultiDimClustering.cluster( + inputDf, + approxNumPartitions = 4, + colNames = Seq("c1", "c2"), + curve) + + outputDf.write.parquet(new File(tempDir, "source").getCanonicalPath) + + // If there is no noise added, expect only one partition, otherwise four partition + // as mentioned in the cluster command above. 
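+ // (With every record holding the same clustering key, only the added noise column gives
+ // the range partitioner distinct values to split on.)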
+ val partCount = new File(tempDir, "source").listFiles(new FilenameFilter { + override def accept(dir: File, name: String): Boolean = { + name.startsWith("part-0000") + } + }).length + + if ("true".equals(addNoise)) { + assert(4 === partCount, s"Incorrect number of partitions when addNoise=$addNoise") + } else { + assert(1 === partCount, s"Incorrect number of partitions when addNoise=$addNoise") } - }).length - - if ("true".equals(addNoise)) { - assert(4 === partCount, s"Incorrect number of partitions when addNoise=$addNoise") - } else { - assert(1 === partCount, s"Incorrect number of partitions when addNoise=$addNoise") } } } @@ -141,19 +187,20 @@ class MultiDimClusteringSuite extends QueryTest } test(s"try clustering with different ranges and noise flag on/off") { - Seq("true", "false").foreach { addNoise => - Seq("20", "100", "200", "1000").foreach { numRanges => - withSQLConf( - MDC_NUM_RANGE_IDS.key -> numRanges, - MDC_ADD_NOISE.key -> addNoise) { - val data = Seq.range(0, 100) - val inputDf = Random.shuffle(data).map(x => (x, x * 113 % 101)).toDF("col1", "col2") - val outputDf = MultiDimClustering.cluster( - inputDf, - approxNumPartitions = 10, - colNames = Seq("col1", "col2")) - // Underlying data shouldn't change - checkAnswer(outputDf, inputDf) + Seq("zorder", "hilbert").foreach { curve => + Seq("true", "false").foreach { addNoise => + Seq("20", "100", "200", "1000").foreach { numRanges => + withSQLConf(MDC_NUM_RANGE_IDS.key -> numRanges, MDC_ADD_NOISE.key -> addNoise) { + val data = Seq.range(0, 100) + val inputDf = Random.shuffle(data).map(x => (x, x * 113 % 101)).toDF("col1", "col2") + val outputDf = MultiDimClustering.cluster( + inputDf, + approxNumPartitions = 10, + colNames = Seq("col1", "col2"), + curve) + // Underlying data shouldn't change + checkAnswer(outputDf, inputDf) + } } } }
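Illustrative usage (not part of the patch): a minimal sketch of how the new curve parameter is driven; the DataFrame df and its columns c1 and c2 are assumed.

  // Hypothetical caller, assuming a DataFrame df with columns c1 and c2.
  import org.apache.spark.sql.delta.skipping.MultiDimClustering

  val clustered = MultiDimClustering.cluster(
    df,
    approxNumPartitions = 8,
    colNames = Seq("c1", "c2"),
    curve = "hilbert") // pass "zorder" to keep the previous Z-order behavior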