forked from locationtech/rasterframes
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Initial implementation with quantiles Signed-off-by: Jason T. Brown <jason@astraea.earth>
- Loading branch information
Showing
4 changed files
with
204 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
76 changes: 76 additions & 0 deletions
76
core/src/main/scala/org/locationtech/rasterframes/extensions/RasterFrameStatFunctions.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,76 @@ | ||
package org.locationtech.rasterframes.extensions | ||
|
||
import org.locationtech.rasterframes.stats._ | ||
import org.apache.spark.sql.DataFrame | ||
import org.apache.spark.sql.functions.col | ||
|
||
/**
 * Statistical helpers over the tile columns of a DataFrame.
 *
 * All work is delegated to `multipleApproxQuantiles` in the
 * `org.locationtech.rasterframes.stats` package object.
 */
final class RasterFrameStatFunctions private[rasterframes](df: DataFrame) {

  /**
   * Calculates the approximate quantiles of the cell values held in a tile column.
   *
   * The result of this algorithm has the following deterministic bound:
   * if the column contributes N samples and the quantile at probability `p` is requested up to
   * error `err`, then the algorithm returns a sample `x` whose *exact* rank is close to (p * N).
   * More precisely,
   *
   * {{{
   *   floor((p - err) * N) <= rank(x) <= ceil((p + err) * N)
   * }}}
   *
   * This method relies on a variation of the Greenwald-Khanna algorithm (with some speed
   * optimizations). The algorithm was first presented in
   * <a href="http://dx.doi.org/10.1145/375663.375670">
   * Space-efficient Online Computation of Quantile Summaries</a> by Greenwald and Khanna.
   *
   * @param col the name of the tile column
   * @param probabilities a list of quantile probabilities.
   *   Each number must belong to [0, 1].
   *   For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
   * @param relativeError the relative target precision to achieve (greater than or equal to 0).
   *   If set to zero, the exact quantiles are computed, which could be very expensive.
   *   Note that values greater than 1 are accepted but give the same result as 1.
   * @return the approximate quantiles at the given probabilities
   *
   * @note null tiles and NaN cell values are skipped. If the dataframe is empty or the column
   *       only contains null or NaN, an empty array is returned.
   */
  def approxTileQuantile(
    col: String,
    probabilities: Array[Double],
    relativeError: Double): Array[Double] = {
    // Reuse the multi-column overload; a single requested column yields a single result row.
    val perColumn = approxTileQuantile(Array(col), probabilities, relativeError)
    perColumn.head
  }

  /**
   * Calculates the approximate quantiles of the cell values held in several tile columns.
   * See the single-column overload of `approxTileQuantile` for a detailed description of the
   * algorithm and its error bound.
   *
   * @param cols the names of the tile columns
   * @param probabilities a list of quantile probabilities.
   *   Each number must belong to [0, 1].
   *   For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
   * @param relativeError the relative target precision to achieve (greater than or equal to 0).
   *   If set to zero, the exact quantiles are computed, which could be very expensive.
   *   Note that values greater than 1 are accepted but give the same result as 1.
   * @return the approximate quantiles at the given probabilities, one array per column and in
   *         the same order as `cols`
   *
   * @note null tiles and NaN cell values are ignored. For columns only containing null or NaN
   *       values, an empty array is returned.
   */
  def approxTileQuantile(
    cols: Array[String],
    probabilities: Array[Double],
    relativeError: Double): Array[Array[Double]] = {
    // Narrow the frame to the requested columns before handing it to the aggregation.
    val selected = df.select(cols.map(col): _*)
    val quantilesPerColumn =
      multipleApproxQuantiles(selected, cols, probabilities, relativeError)
    quantilesPerColumn.map(_.toArray).toArray
  }

}
59 changes: 59 additions & 0 deletions
59
core/src/main/scala/org/locationtech/rasterframes/stats/package.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
package org.locationtech.rasterframes | ||
|
||
import geotrellis.raster.Tile | ||
import org.locationtech.rasterframes.TileType | ||
import org.locationtech.rasterframes.expressions.DynamicExtractors._ | ||
import org.apache.spark.sql.{Column, DataFrame, Row} | ||
import org.apache.spark.sql.catalyst.expressions.Cast | ||
import org.apache.spark.sql.catalyst.util.QuantileSummaries | ||
import org.apache.spark.sql.types.{DoubleType, NumericType} | ||
import org.locationtech.rasterframes.expressions.accessors.ExtractTile | ||
|
||
|
||
package object stats {

  /**
   * Computes approximate quantiles over the cell values of one or more tile columns.
   *
   * Every non-NaN cell of every non-null tile in a column is inserted into that column's
   * `QuantileSummaries`; the per-partition summaries are then combined with `treeAggregate`
   * and queried once per requested probability.
   *
   * @param df            the DataFrame holding the tile columns
   * @param cols          names of the columns to summarize; each must be accepted by
   *                      `tileExtractor`
   * @param probabilities the quantile probabilities to query from each summary
   * @param relativeError the target precision handed to `QuantileSummaries`; must be >= 0
   * @return one sequence of quantiles per entry of `cols`, in the same order
   * @throws IllegalArgumentException if `relativeError` is negative or a column is not tile-like
   */
  def multipleApproxQuantiles(df: DataFrame,
    cols: Seq[String],
    probabilities: Seq[Double],
    relativeError: Double): Seq[Seq[Double]] = {
    require(relativeError >= 0,
      s"Relative Error must be non-negative but got $relativeError")

    // Validate each requested column against the schema and wrap it in a tile extraction.
    val tileColumns: Seq[Column] = for (colName <- cols) yield {
      val field = df.schema(colName)

      require(tileExtractor.isDefinedAt(field.dataType),
        s"Quantile calculation for column $colName with data type ${field.dataType}" +
          " is not supported; it must be Tile-like.")
      ExtractTile(new Column(colName))
    }

    // One empty summary per column; this is the zero value of the aggregation.
    val zero = Array.fill(cols.size)(
      new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError))

    // Folds one row into the running summaries, updating the array in place.
    def foldRow(acc: Array[QuantileSummaries], row: Row): Array[QuantileSummaries] = {
      for (idx <- acc.indices if !row.isNullAt(idx)) {
        val tile: Tile = row.getAs[Tile](idx)
        // insert every non-NaN cell of the tile into the summary for this column
        tile.foreachDouble { cell =>
          if (!cell.isNaN) acc(idx) = acc(idx).insert(cell)
        }
      }
      acc
    }

    // Combines two partial results column by column; both sides are compressed before merging.
    def combine(
      left: Array[QuantileSummaries],
      right: Array[QuantileSummaries]): Array[QuantileSummaries] =
      left.zip(right).map { case (l, r) => l.compress().merge(r.compress()) }

    val aggregated = df.select(tileColumns: _*).rdd.treeAggregate(zero)(foldRow, combine)

    // NOTE(review): the summaries are queried exactly as returned by treeAggregate, without an
    // explicit final compress() -- confirm this is valid when no merge step ran.
    aggregated.map { summary => probabilities.flatMap(summary.query) }
  }

}
66 changes: 66 additions & 0 deletions
66
core/src/test/scala/org/locationtech/rasterframes/RasterFramesStatsSpec.scala
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
/* | ||
* This software is licensed under the Apache 2 license, quoted below. | ||
* | ||
* Copyright 2018 Astraea, Inc. | ||
* | ||
* Licensed under the Apache License, Version 2.0 (the "License"); you may not | ||
* use this file except in compliance with the License. You may obtain a copy of | ||
* the License at | ||
* | ||
* [http://www.apache.org/licenses/LICENSE-2.0] | ||
* | ||
* Unless required by applicable law or agreed to in writing, software | ||
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT | ||
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the | ||
* License for the specific language governing permissions and limitations under | ||
* the License. | ||
* | ||
* SPDX-License-Identifier: Apache-2.0 | ||
* | ||
*/ | ||
|
||
package org.locationtech.rasterframes | ||
|
||
import org.apache.spark.sql.functions.col | ||
|
||
import org.locationtech.rasterframes._ | ||
import org.locationtech.rasterframes.RasterFunctions | ||
|
||
/** Exercises the `approxTileQuantile` overloads reached through `df.tileStat()`. */
class RasterFramesStatsSpec extends TestEnvironment with TestData {

  describe("DataFrame.tileStats extension methods") {

    // Sample GeoTIFF tiles plus a second column derived from `tile` via rf_local_add(..., 2)
    val df = TestData.sampleGeoTiff.toDF()
      .withColumn("tilePlus2", rf_local_add(col("tile"), 2))

    it("should compute approx percentiles for a single tile col"){
      val probabilities = Array(0.10, 0.50, 0.90)

      val result = df.tileStat().approxTileQuantile("tile", probabilities, 0.00001)

      // one quantile per requested probability
      result.length should be (3)

      // computing externally with numpy we arrive at 7963, 10068, 12160 for these quantiles
      result should contain inOrderOnly (7963.0, 10068.0, 12160.0)
    }

    it("should compute approx percentiles for many tile cols"){
      val columnNames = Array("tile", "tilePlus2")
      val probabilities = Array(0.25, 0.75)

      val result = df.tileStat().approxTileQuantile(columnNames, probabilities, 0.00001)

      // one entry per column ...
      result.length should be (2)
      // ... each holding one quantile per requested probability
      for (perColumn <- result) perColumn.length should be (2)

      result(0) should contain inOrderOnly (8701, 11261)
      result(1) should contain inOrderOnly (8703, 11263)
    }

  }

}