Skip to content

Commit

Permalink
Merge pull request #429 from s22s/feature/tile-quantile
Browse files Browse the repository at this point in the history
Add rf_agg_approx_quantiles function
  • Loading branch information
vpipkt authored Jan 13, 2020
2 parents 73a52e6 + d73e255 commit 1730af9
Show file tree
Hide file tree
Showing 10 changed files with 259 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -21,17 +21,21 @@

package org.locationtech.rasterframes.encoders

import java.nio.ByteBuffer

import com.github.blemale.scaffeine.Scaffeine
import geotrellis.proj4.CRS
import geotrellis.raster._
import geotrellis.spark._
import geotrellis.spark.tiling.LayoutDefinition
import geotrellis.vector._
import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Envelope
import org.locationtech.rasterframes.TileType
import org.locationtech.rasterframes.encoders.CatalystSerializer.{CatalystIO, _}
import org.locationtech.rasterframes.model.LazyCRS
import org.locationtech.rasterframes.util.KryoSupport

/** Collection of CatalystSerializers for third-party types. */
trait StandardSerializers {
Expand Down Expand Up @@ -294,9 +298,23 @@ trait StandardSerializers {
implicit val spatialKeyTLMSerializer = tileLayerMetadataSerializer[SpatialKey]
implicit val spaceTimeKeyTLMSerializer = tileLayerMetadataSerializer[SpaceTimeKey]

implicit val quantileSerializer: CatalystSerializer[QuantileSummaries] = new CatalystSerializer[QuantileSummaries] {
override val schema: StructType = StructType(Seq(
StructField("quantile_serializer_kryo", BinaryType, false)
))

override protected def to[R](t: QuantileSummaries, io: CatalystSerializer.CatalystIO[R]): R = {
val buf = KryoSupport.serialize(t)
io.create(buf.array())
}

override protected def from[R](t: R, io: CatalystSerializer.CatalystIO[R]): QuantileSummaries = {
KryoSupport.deserialize[QuantileSummaries](ByteBuffer.wrap(io.getByteArray(t, 0)))
}
}
}

object StandardSerializers {
object StandardSerializers extends StandardSerializers {
private val s2ctCache = Scaffeine().build[String, CellType](
(s: String) => CellType.fromName(s)
)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,88 @@
/*
* This software is licensed under the Apache 2 license, quoted below.
*
* Copyright 2019 Astraea, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* [http://www.apache.org/licenses/LICENSE-2.0]
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.locationtech.rasterframes.expressions.aggregates

import geotrellis.raster.{Tile, isNoData}
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.catalyst.util.QuantileSummaries
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.{Column, Encoder, Row, TypedColumn, types}
import org.apache.spark.sql.types.{DataTypes, StructField, StructType}
import org.locationtech.rasterframes.TileType
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.expressions.accessors.ExtractTile


case class ApproxCellQuantilesAggregate(probabilities: Seq[Double], relativeError: Double) extends UserDefinedAggregateFunction {
import org.locationtech.rasterframes.encoders.StandardSerializers.quantileSerializer

override def inputSchema: StructType = StructType(Seq(
StructField("value", TileType, true)
))

override def bufferSchema: StructType = StructType(Seq(
StructField("buffer", schemaOf[QuantileSummaries], false)
))

override def dataType: types.DataType = DataTypes.createArrayType(DataTypes.DoubleType)

override def deterministic: Boolean = true

override def initialize(buffer: MutableAggregationBuffer): Unit =
buffer.update(0, new QuantileSummaries(QuantileSummaries.defaultCompressThreshold, relativeError).toRow)

override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
val qs = buffer.getStruct(0).to[QuantileSummaries]
if (!input.isNullAt(0)) {
val tile = input.getAs[Tile](0)
var result = qs
tile.foreachDouble(d => if (!isNoData(d)) result = result.insert(d))
buffer.update(0, result.toRow)
}
else buffer
}

override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
val left = buffer1.getStruct(0).to[QuantileSummaries]
val right = buffer2.getStruct(0).to[QuantileSummaries]
val merged = left.compress().merge(right.compress())
buffer1.update(0, merged.toRow)
}

override def evaluate(buffer: Row): Seq[Double] = {
val summaries = buffer.getStruct(0).to[QuantileSummaries]
probabilities.flatMap(summaries.query)
}
}

object ApproxCellQuantilesAggregate {
private implicit def doubleSeqEncoder: Encoder[Seq[Double]] = ExpressionEncoder()

def apply(
tile: Column,
probabilities: Seq[Double],
relativeError: Double = 0.00001): TypedColumn[Any, Seq[Double]] = {
new ApproxCellQuantilesAggregate(probabilities, relativeError)(ExtractTile(tile))
.as(s"rf_agg_approx_quantiles")
.as[Seq[Double]]
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,10 @@ object HistogramAggregate {
import org.locationtech.rasterframes.encoders.StandardEncoders.cellHistEncoder

def apply(col: Column): TypedColumn[Any, CellHistogram] =
new HistogramAggregate()(ExtractTile(col))
apply(col, StreamingHistogram.DEFAULT_NUM_BUCKETS)

def apply(col: Column, numBuckets: Int): TypedColumn[Any, CellHistogram] =
new HistogramAggregate(numBuckets)(ExtractTile(col))
.as(s"rf_agg_approx_histogram($col)")
.as[CellHistogram]

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,32 @@ trait AggregateFunctions {
/** Compute the cellwise/local count of NoData cells for all Tiles in a column. */
def rf_agg_local_no_data_cells(tile: Column): TypedColumn[Any, Tile] = LocalCountAggregate.LocalNoDataCellsUDAF(tile)

/** Compute the full column aggregate floating point histogram. */
/** Compute the approximate aggregate floating point histogram using a streaming algorithm, with the default of 80 buckets. */
def rf_agg_approx_histogram(tile: Column): TypedColumn[Any, CellHistogram] = HistogramAggregate(tile)

/** Compute the approximate aggregate floating point histogram using a streaming algorithm, with the given number of buckets. */
def rf_agg_approx_histogram(col: Column, numBuckets: Int): TypedColumn[Any, CellHistogram] = {
require(numBuckets > 0, "Must provide a positive number of buckets")
HistogramAggregate(col, numBuckets)
}

/**
* Calculates the approximate quantiles of a tile column of a DataFrame.
* @param tile tile column to extract cells from.
* @param probabilities a list of quantile probabilities
* Each number must belong to [0, 1].
* For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
* @param relativeError The relative target precision to achieve (greater than or equal to 0).
* @return the approximate quantiles at the given probabilities of each column
*/
def rf_agg_approx_quantiles(
tile: Column,
probabilities: Seq[Double],
relativeError: Double = 0.00001): TypedColumn[Any, Seq[Double]] = {
require(probabilities.nonEmpty, "at least one quantile probability is required")
ApproxCellQuantilesAggregate(tile, probabilities, relativeError)
}

/** Compute the full column aggregate floating point statistics. */
def rf_agg_stats(tile: Column): TypedColumn[Any, CellStatistics] = CellStatsAggregate(tile)

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
/*
* This software is licensed under the Apache 2 license, quoted below.
*
* Copyright 2018 Astraea, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not
* use this file except in compliance with the License. You may obtain a copy of
* the License at
*
* [http://www.apache.org/licenses/LICENSE-2.0]
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
* License for the specific language governing permissions and limitations under
* the License.
*
* SPDX-License-Identifier: Apache-2.0
*
*/

package org.locationtech.rasterframes

import org.locationtech.rasterframes.RasterFunctions
import org.apache.spark.sql.functions.{col, explode}

class RasterFramesStatsSpec extends TestEnvironment with TestData {

import spark.implicits._

val df = TestData.sampleGeoTiff
.toDF()
.withColumn("tilePlus2", rf_local_add(col("tile"), 2))


describe("Tile quantiles through built-in functions") {

it("should compute approx percentiles for a single tile col") {
// Use "explode"
val result = df
.select(rf_explode_tiles($"tile"))
.stat
.approxQuantile("tile", Array(0.10, 0.50, 0.90), 0.00001)

result.length should be(3)

// computing externally with numpy we arrive at 7963, 10068, 12160 for these quantiles
result should contain inOrderOnly(7963.0, 10068.0, 12160.0)

// Use "to_array" and built-in explode
val result2 = df
.select(explode(rf_tile_to_array_double($"tile")) as "tile")
.stat
.approxQuantile("tile", Array(0.10, 0.50, 0.90), 0.00001)

result2.length should be(3)

// computing externally with numpy we arrive at 7963, 10068, 12160 for these quantiles
result2 should contain inOrderOnly(7963.0, 10068.0, 12160.0)

}
}

describe("Tile quantiles through custom aggregate") {
it("should compute approx percentiles for a single tile col") {
val result = df
.select(rf_agg_approx_quantiles($"tile", Seq(0.1, 0.5, 0.9)))
.first()

result.length should be(3)

// computing externally with numpy we arrive at 7963, 10068, 12160 for these quantiles
result should contain inOrderOnly(7963.0, 10068.0, 12160.0)
}

}
}

8 changes: 8 additions & 0 deletions docs/src/main/paradox/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -634,6 +634,14 @@ Aggregates over the `tile` and returns statistical summaries of cell values: num

Aggregates over all of the rows in DataFrame of `tile` and returns a count of each cell value to create a histogram with values are plotted on the x-axis and counts on the y-axis. Related is the @ref:[`rf_tile_histogram`](reference.md#rf-tile-histogram) function which operates on a single row at a time.

### rf_agg_approx_quantiles

Array[Double] rf_agg_approx_quantiles(Tile tile, List[float] probabilities, float relative_error)

__Not supported in SQL.__

Calculates the approximate quantiles of a tile column of a DataFrame. `probabilities` is a list of float values at which to compute the quantiles. These must belong to [0, 1]. For example 0 is the minimum, 0.5 is the median, 1 is the maximum. Returns an array of values approximately at the specified `probabilities`.

### rf_agg_extent

Extent rf_agg_extent(Extent extent)
Expand Down
1 change: 1 addition & 0 deletions docs/src/main/paradox/release-notes.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* Added `rf_render_color_ramp_png` to compute PNG byte array for a single tile column, with specified color ramp.
* In `rf_ipython`, improved rendering of dataframe binary contents with PNG preamble.
* Throw an `IllegalArgumentException` when attempting to apply a mask to a `Tile` whose `CellType` has no NoData defined. ([#409](https://github.com/locationtech/rasterframes/issues/384))
* Add `rf_agg_approx-quantiles` function to compute cell quantiles across an entire column.

### 0.8.4

Expand Down
16 changes: 16 additions & 0 deletions pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,6 +313,22 @@ def rf_agg_approx_histogram(tile_col):
return _apply_column_function('rf_agg_approx_histogram', tile_col)


def rf_agg_approx_quantiles(tile_col, probabilities, relative_error=0.00001):
"""
Calculates the approximate quantiles of a tile column of a DataFrame.
:param tile_col: column to extract cells from.
:param probabilities: a list of quantile probabilities. Each number must belong to [0, 1].
For example 0 is the minimum, 0.5 is the median, 1 is the maximum.
:param relative_error: The relative target precision to achieve (greater than or equal to 0). Default is 0.00001
:return: An array of values approximately at the specified `probabilities`
"""

_jfn = RFContext.active().lookup('rf_agg_approx_quantiles')
_tile_col = _to_java_column(tile_col)
return Column(_jfn(_tile_col, probabilities, relative_error))


def rf_agg_stats(tile_col):
"""Compute the full column aggregate floating point statistics"""
return _apply_column_function('rf_agg_stats', tile_col)
Expand Down
19 changes: 14 additions & 5 deletions pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,25 @@

from unittest import skip

import numpy as np
import sys
from numpy.testing import assert_equal
from pyspark import Row
from pyspark.sql.functions import *

import pyrasterframes
from pyrasterframes.rasterfunctions import *
from pyrasterframes.rf_types import *
from pyrasterframes.utils import gdal_version
from pyspark import Row
from pyspark.sql.functions import *

import numpy as np
from numpy.testing import assert_equal, assert_allclose

from unittest import skip
from . import TestEnvironment


class RasterFunctions(TestEnvironment):

def setUp(self):
import sys
if not sys.warnoptions:
import warnings
warnings.simplefilter("ignore")
Expand Down Expand Up @@ -138,6 +141,12 @@ def test_aggregations(self):
self.assertEqual(row['rf_agg_no_data_cells(tile)'], 1000)
self.assertEqual(row['rf_agg_stats(tile)'].data_cells, row['rf_agg_data_cells(tile)'])

def test_agg_approx_quantiles(self):
agg = self.rf.agg(rf_agg_approx_quantiles('tile', [0.1, 0.5, 0.9, 0.98]))
result = agg.first()[0]
# expected result from computing in external python process; c.f. scala tests
assert_allclose(result, np.array([7963., 10068., 12160., 14366.]))

def test_sql(self):

self.rf.createOrReplaceTempView("rf_test_sql")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,13 @@ class PyRFContext(implicit sparkSession: SparkSession) extends RasterFunctions

def rf_local_unequal_int(col: Column, scalar: Int): Column = rf_local_unequal[Int](col, scalar)

// other function support
/** py4j friendly version of this function */
def rf_agg_approx_quantiles(tile: Column, probabilities: java.util.List[Double], relativeError: Double): TypedColumn[Any, Seq[Double]] = {
import scala.collection.JavaConverters._
rf_agg_approx_quantiles(tile, probabilities.asScala, relativeError)
}

def _make_crs_literal(crsText: String): Column = {
rasterframes.encoders.serialized_literal[CRS](LazyCRS(crsText))
}
Expand Down

0 comments on commit 1730af9

Please sign in to comment.