Skip to content

Commit

Permalink
Replaced TileDimensions with Dimensions[Int].
Browse files Browse the repository at this point in the history
  • Loading branch information
metasim committed Nov 14, 2019
1 parent 38d503e commit 9a4f156
Show file tree
Hide file tree
Showing 32 changed files with 146 additions and 175 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@ import org.apache.spark.sql._
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.expressions.generators.RasterSourceToRasterRefs
import org.locationtech.rasterframes.expressions.transformers.RasterRefToTile
import org.locationtech.rasterframes.model.TileDimensions
import org.locationtech.rasterframes.ref.RFRasterSource
import org.openjdk.jmh.annotations._

Expand All @@ -47,7 +46,7 @@ class RasterRefBench extends SparkEnv with LazyLogging {
val r2 = RFRasterSource(remoteCOGSingleband2)

singleDF = Seq((r1, r2)).toDF("B1", "B2")
.select(RasterRefToTile(RasterSourceToRasterRefs(Some(TileDimensions(r1.dimensions)), Seq(0), $"B1", $"B2")))
.select(RasterRefToTile(RasterSourceToRasterRefs(Some(r1.dimensions), Seq(0), $"B1", $"B2")))

expandedDF = Seq((r1, r2)).toDF("B1", "B2")
.select(RasterRefToTile(RasterSourceToRasterRefs($"B1", $"B2")))
Expand Down
7 changes: 5 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,11 @@ lazy val pyrasterframes = project
spark("core").value % Provided,
spark("mllib").value % Provided,
spark("sql").value % Provided
)
),
Test / test := (Test / test).dependsOn(experimental / Test / test).value
)


lazy val datasource = project
.configs(IntegrationTest)
.settings(Defaults.itSettings)
Expand All @@ -105,6 +107,7 @@ lazy val datasource = project
spark("mllib").value % Provided,
spark("sql").value % Provided
),
Test / test := (Test / test).dependsOn(core / Test / test).value,
initialCommands in console := (initialCommands in console).value +
"""
|import org.locationtech.rasterframes.datasource.geotrellis._
Expand All @@ -127,7 +130,7 @@ lazy val experimental = project
),
fork in IntegrationTest := true,
javaOptions in IntegrationTest := Seq("-Xmx2G"),
parallelExecution in IntegrationTest := false
Test / test := (Test / test).dependsOn(datasource / Test / test).value
)

lazy val docs = project
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package org.apache.spark.sql.rf

import java.lang.reflect.Constructor

import org.apache.spark.sql.AnalysisException
import org.apache.spark.sql.catalyst.FunctionIdentifier
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry
import org.apache.spark.sql.catalyst.analysis.FunctionRegistry.FunctionBuilder
Expand All @@ -12,7 +13,6 @@ import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.execution.datasources.LogicalRelation
import org.apache.spark.sql.sources.BaseRelation
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{AnalysisException, DataFrame, Dataset, SQLContext}

import scala.reflect._
import scala.util.{Failure, Success, Try}
Expand All @@ -23,11 +23,6 @@ import scala.util.{Failure, Success, Try}
* @since 2/13/18
*/
object VersionShims {
def readJson(sqlContext: SQLContext, rows: Dataset[String]): DataFrame = {
// NB: Will get a deprecation warning for Spark 2.2.x
sqlContext.read.json(rows.rdd) // <-- deprecation warning expected
}

def updateRelation(lr: LogicalRelation, base: BaseRelation): LogicalPlan = {
val lrClazz = classOf[LogicalRelation]
val ctor = lrClazz.getConstructors.head.asInstanceOf[Constructor[LogicalRelation]]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ package org.locationtech.rasterframes
import geotrellis.proj4.CRS
import geotrellis.raster.mapalgebra.local.LocalTileBinaryOp
import geotrellis.raster.render.ColorRamp
import geotrellis.raster.{CellType, Tile}
import geotrellis.raster.{CellType, Dimensions, Tile}
import geotrellis.vector.Extent
import org.apache.spark.annotation.Experimental
import org.apache.spark.sql.functions.{lit, udf}
Expand All @@ -35,9 +35,8 @@ import org.locationtech.rasterframes.expressions.aggregates._
import org.locationtech.rasterframes.expressions.generators._
import org.locationtech.rasterframes.expressions.localops._
import org.locationtech.rasterframes.expressions.tilestats._
import org.locationtech.rasterframes.expressions.transformers.RenderPNG.{RenderCompositePNG, RenderColorRampPNG}
import org.locationtech.rasterframes.expressions.transformers.RenderPNG.{RenderColorRampPNG, RenderCompositePNG}
import org.locationtech.rasterframes.expressions.transformers._
import org.locationtech.rasterframes.model.TileDimensions
import org.locationtech.rasterframes.stats._
import org.locationtech.rasterframes.{functions => F}

Expand All @@ -51,7 +50,7 @@ trait RasterFunctions {

// format: off
/** Query the number of (cols, rows) in a Tile. */
def rf_dimensions(col: Column): TypedColumn[Any, TileDimensions] = GetDimensions(col)
def rf_dimensions(col: Column): TypedColumn[Any, Dimensions[Int]] = GetDimensions(col)

/** Extracts the bounding box of a geometry as an Extent */
def st_extent(col: Column): TypedColumn[Any, Extent] = GeometryToExtent(col)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import java.sql.Timestamp
import org.locationtech.rasterframes.stats.{CellHistogram, CellStatistics, LocalCellStatistics}
import org.locationtech.jts.geom.Envelope
import geotrellis.proj4.CRS
import geotrellis.raster.{CellSize, CellType, Raster, Tile, TileLayout}
import geotrellis.raster.{CellSize, CellType, Dimensions, Raster, Tile, TileLayout}
import geotrellis.layer._
import geotrellis.vector.{Extent, ProjectedExtent}
import org.apache.spark.sql.{Encoder, Encoders}
Expand Down Expand Up @@ -70,8 +70,7 @@ trait StandardEncoders extends SpatialEncoders {
implicit def tileContextEncoder: ExpressionEncoder[TileContext] = TileContext.encoder
implicit def tileDataContextEncoder: ExpressionEncoder[TileDataContext] = TileDataContext.encoder
implicit def extentTilePairEncoder: Encoder[(ProjectedExtent, Tile)] = Encoders.tuple(projectedExtentEncoder, singlebandTileEncoder)


implicit def tileDimensionsEncoder: Encoder[Dimensions[Int]] = CatalystSerializerEncoder[Dimensions[Int]](true)
}

object StandardEncoders extends StandardEncoders
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ import com.github.blemale.scaffeine.Scaffeine
import geotrellis.proj4.CRS
import geotrellis.raster._
import geotrellis.layer._

import geotrellis.vector._
import org.apache.spark.sql.types._
import org.locationtech.jts.geom.Envelope
Expand Down Expand Up @@ -60,9 +59,11 @@ trait StandardSerializers {
StructField("xmax", DoubleType, false),
StructField("ymax", DoubleType, false)
))

/** Writes the extent's four bounds in the order xmin, ymin, xmax, ymax. */
override def to[R](t: Extent, io: CatalystIO[R]): R =
  io.create(t.xmin, t.ymin, t.xmax, t.ymax)

override def from[R](row: R, io: CatalystIO[R]): Extent = Extent(
io.getDouble(row, 0),
io.getDouble(row, 1),
Expand Down Expand Up @@ -95,25 +96,31 @@ trait StandardSerializers {
override val schema: StructType = StructType(Seq(
StructField("crsProj4", StringType, false)
))

/** Encodes the CRS as a single-field row holding its proj4 string. */
override def to[R](t: CRS, io: CatalystIO[R]): R = {
  // Always emit the full proj4 string. The shorter alternative,
  //   t.epsgCode.map(c => "EPSG:" + c).getOrElse(t.toProj4String)
  // was deliberately avoided: it is noted as ~1000x slower to decode.
  val proj4 = t.toProj4String
  io.create(io.encode(proj4))
}

/** Reads the string stored at ordinal 0 and wraps it in a `LazyCRS`. */
override def from[R](row: R, io: CatalystIO[R]): CRS = {
  val encoded = io.getString(row, 0)
  LazyCRS(encoded)
}
}

/**
 * Catalyst serializer for `CellType`, stored as a one-field struct containing a string.
 * Conversion in both directions goes through the `ct2sCache` / `s2ctCache` lookups
 * brought into scope by the `StandardSerializers._` import.
 */
implicit val cellTypeSerializer: CatalystSerializer[CellType] = new CatalystSerializer[CellType] {

  import StandardSerializers._

  // Single non-nullable string column; `from` reads it back at ordinal 0.
  override val schema: StructType =
    StructType(Seq(StructField("cellTypeName", StringType, false)))

  /** Looks up the string form of the cell type and writes it as the row's only field. */
  override def to[R](t: CellType, io: CatalystIO[R]): R = {
    val cellTypeName = ct2sCache.get(t)
    io.create(io.encode(cellTypeName))
  }

  /** Reads the string at ordinal 0 and resolves it to a `CellType` via the cache. */
  override def from[R](row: R, io: CatalystIO[R]): CellType = {
    val cellTypeName = io.getString(row, 0)
    s2ctCache.get(cellTypeName)
  }
}
Expand Down Expand Up @@ -229,7 +236,7 @@ trait StandardSerializers {
)
}

implicit def boundsSerializer[T >: Null: CatalystSerializer]: CatalystSerializer[KeyBounds[T]] = new CatalystSerializer[KeyBounds[T]] {
implicit def boundsSerializer[T >: Null : CatalystSerializer]: CatalystSerializer[KeyBounds[T]] = new CatalystSerializer[KeyBounds[T]] {
override val schema: StructType = StructType(Seq(
StructField("minKey", schemaOf[T], true),
StructField("maxKey", schemaOf[T], true)
Expand All @@ -246,7 +253,7 @@ trait StandardSerializers {
)
}

def tileLayerMetadataSerializer[T >: Null: CatalystSerializer]: CatalystSerializer[TileLayerMetadata[T]] = new CatalystSerializer[TileLayerMetadata[T]] {
def tileLayerMetadataSerializer[T >: Null : CatalystSerializer]: CatalystSerializer[TileLayerMetadata[T]] = new CatalystSerializer[TileLayerMetadata[T]] {
override val schema: StructType = StructType(Seq(
StructField("cellType", schemaOf[CellType], false),
StructField("layout", schemaOf[LayoutDefinition], false),
Expand All @@ -273,6 +280,7 @@ trait StandardSerializers {
}

implicit def rasterSerializer: CatalystSerializer[Raster[Tile]] = new CatalystSerializer[Raster[Tile]] {

import org.apache.spark.sql.rf.TileUDT.tileSerializer

override val schema: StructType = StructType(Seq(
Expand All @@ -294,6 +302,22 @@ trait StandardSerializers {
implicit val spatialKeyTLMSerializer = tileLayerMetadataSerializer[SpatialKey]
implicit val spaceTimeKeyTLMSerializer = tileLayerMetadataSerializer[SpaceTimeKey]

/**
 * Catalyst serializer for `Dimensions[Int]`, stored as a struct of two
 * non-nullable integer fields: `cols` followed by `rows`.
 */
implicit val tileDimensionsSerializer: CatalystSerializer[Dimensions[Int]] = new CatalystSerializer[Dimensions[Int]] {
  // The field order here fixes the ordinals used by `to` and `from` below.
  override val schema: StructType =
    StructType(Seq(StructField("cols", IntegerType, false), StructField("rows", IntegerType, false)))

  /** Writes the dimensions as (cols, rows), matching the schema's field order. */
  override protected def to[R](t: Dimensions[Int], io: CatalystIO[R]): R =
    io.create(t.cols, t.rows)

  /** Rebuilds the dimensions from ordinal 0 (cols) and ordinal 1 (rows). */
  override protected def from[R](t: R, io: CatalystIO[R]): Dimensions[Int] = {
    val cols = io.getInt(t, 0)
    val rows = io.getInt(t, 1)
    Dimensions[Int](cols, rows)
  }
}
}

object StandardSerializers {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ package org.locationtech.rasterframes.expressions.accessors

import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.expressions.OnCellGridExpression
import geotrellis.raster.CellGrid
import geotrellis.raster.{CellGrid, Dimensions}
import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.locationtech.rasterframes.model.TileDimensions

/**
* Extract a raster's dimensions
Expand All @@ -43,12 +42,13 @@ import org.locationtech.rasterframes.model.TileDimensions
case class GetDimensions(child: Expression) extends OnCellGridExpression with CodegenFallback {
override def nodeName: String = "rf_dimensions"

def dataType = schemaOf[TileDimensions]
def dataType = schemaOf[Dimensions[Int]]

override def eval(grid: CellGrid[Int]): Any = TileDimensions(grid.cols, grid.rows).toInternalRow
override def eval(grid: CellGrid[Int]): Any = Dimensions[Int](grid.cols, grid.rows).toInternalRow
}

object GetDimensions {
def apply(col: Column): TypedColumn[Any, TileDimensions] =
new Column(new GetDimensions(col.expr)).as[TileDimensions]
import org.locationtech.rasterframes.encoders.StandardEncoders.tileDimensionsEncoder
def apply(col: Column): TypedColumn[Any, Dimensions[Int]] =
new Column(new GetDimensions(col.expr)).as[Dimensions[Int]]
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ package org.locationtech.rasterframes.expressions.aggregates
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.encoders.CatalystSerializer
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.model.TileDimensions
import geotrellis.proj4.{CRS, Transform}
import geotrellis.raster._
import geotrellis.raster.reproject.{Reproject, ReprojectRasterExtent}
Expand All @@ -34,7 +33,7 @@ import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAg
import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.{Column, Row, TypedColumn}

class ProjectedLayerMetadataAggregate(destCRS: CRS, destDims: TileDimensions) extends UserDefinedAggregateFunction {
class ProjectedLayerMetadataAggregate(destCRS: CRS, destDims: Dimensions[Int]) extends UserDefinedAggregateFunction {
import ProjectedLayerMetadataAggregate._

override def inputSchema: StructType = CatalystSerializer[InputRecord].schema
Expand Down Expand Up @@ -94,14 +93,14 @@ object ProjectedLayerMetadataAggregate {
/** Primary user facing constructor */
def apply(destCRS: CRS, extent: Column, crs: Column, cellType: Column, tileSize: Column): TypedColumn[Any, TileLayerMetadata[SpatialKey]] =
// Ordering must match InputRecord schema
new ProjectedLayerMetadataAggregate(destCRS, TileDimensions(NOMINAL_TILE_SIZE, NOMINAL_TILE_SIZE))(extent, crs, cellType, tileSize).as[TileLayerMetadata[SpatialKey]]
new ProjectedLayerMetadataAggregate(destCRS, Dimensions(NOMINAL_TILE_SIZE, NOMINAL_TILE_SIZE))(extent, crs, cellType, tileSize).as[TileLayerMetadata[SpatialKey]]

def apply(destCRS: CRS, destDims: TileDimensions, extent: Column, crs: Column, cellType: Column, tileSize: Column): TypedColumn[Any, TileLayerMetadata[SpatialKey]] =
def apply(destCRS: CRS, destDims: Dimensions[Int], extent: Column, crs: Column, cellType: Column, tileSize: Column): TypedColumn[Any, TileLayerMetadata[SpatialKey]] =
// Ordering must match InputRecord schema
new ProjectedLayerMetadataAggregate(destCRS, destDims)(extent, crs, cellType, tileSize).as[TileLayerMetadata[SpatialKey]]

private[expressions]
case class InputRecord(extent: Extent, crs: CRS, cellType: CellType, tileSize: TileDimensions) {
case class InputRecord(extent: Extent, crs: CRS, cellType: CellType, tileSize: Dimensions[Int]) {
def toBufferRecord(destCRS: CRS): BufferRecord = {
val transform = Transform(crs, destCRS)

Expand All @@ -125,7 +124,7 @@ object ProjectedLayerMetadataAggregate {
StructField("extent", CatalystSerializer[Extent].schema, false),
StructField("crs", CatalystSerializer[CRS].schema, false),
StructField("cellType", CatalystSerializer[CellType].schema, false),
StructField("tileSize", CatalystSerializer[TileDimensions].schema, false)
StructField("tileSize", CatalystSerializer[Dimensions[Int]].schema, false)
))

override protected def to[R](t: InputRecord, io: CatalystIO[R]): R =
Expand All @@ -135,7 +134,7 @@ object ProjectedLayerMetadataAggregate {
io.get[Extent](t, 0),
io.get[CRS](t, 1),
io.get[CellType](t, 2),
io.get[TileDimensions](t, 3)
io.get[Dimensions[Int]](t, 3)
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ package org.locationtech.rasterframes.expressions.aggregates
import geotrellis.proj4.CRS
import geotrellis.raster.reproject.Reproject
import geotrellis.raster.resample.ResampleMethod
import geotrellis.raster.{ArrayTile, CellType, MultibandTile, ProjectedRaster, Raster, Tile}
import geotrellis.raster.{ArrayTile, CellType, Dimensions, MultibandTile, ProjectedRaster, Raster, Tile}
import geotrellis.layer._
import geotrellis.vector.Extent
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
Expand All @@ -34,7 +34,6 @@ import org.locationtech.rasterframes._
import org.locationtech.rasterframes.util._
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.expressions.aggregates.TileRasterizerAggregate.ProjectedRasterDefinition
import org.locationtech.rasterframes.model.TileDimensions
import org.slf4j.LoggerFactory

/**
Expand Down Expand Up @@ -119,7 +118,7 @@ object TileRasterizerAggregate {
new TileRasterizerAggregate(prd)(crsCol, extentCol, tileCol).as(nodeName).as[Raster[Tile]]
}

def collect(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[TileDimensions]): ProjectedRaster[MultibandTile] = {
def collect(df: DataFrame, destCRS: CRS, destExtent: Option[Extent], rasterDims: Option[Dimensions[Int]]): ProjectedRaster[MultibandTile] = {
val tileCols = WithDataFrameMethods(df).tileColumns
require(tileCols.nonEmpty, "need at least one tile column")
// Select the anchoring Tile, Extent and CRS columns
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@

package org.locationtech.rasterframes.expressions.generators

import geotrellis.raster.GridBounds
import geotrellis.raster.{Dimensions, GridBounds}
import geotrellis.vector.Extent
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
Expand All @@ -30,8 +30,7 @@ import org.apache.spark.sql.types.{DataType, StructField, StructType}
import org.apache.spark.sql.{Column, TypedColumn}
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.expressions.generators.RasterSourceToRasterRefs.bandNames
import org.locationtech.rasterframes.model.TileDimensions
import org.locationtech.rasterframes.ref.{RasterRef, RFRasterSource}
import org.locationtech.rasterframes.ref.{RFRasterSource, RasterRef}
import org.locationtech.rasterframes.util._
import org.locationtech.rasterframes.RasterSourceType

Expand All @@ -43,7 +42,7 @@ import scala.util.control.NonFatal
*
* @since 9/6/18
*/
case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[Int], subtileDims: Option[TileDimensions] = None) extends Expression
case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[Int], subtileDims: Option[Dimensions[Int]] = None) extends Expression
with Generator with CodegenFallback with ExpectsInputTypes {

override def inputTypes: Seq[DataType] = Seq.fill(children.size)(RasterSourceType)
Expand Down Expand Up @@ -86,7 +85,7 @@ case class RasterSourceToRasterRefs(children: Seq[Expression], bandIndexes: Seq[

object RasterSourceToRasterRefs {
def apply(rrs: Column*): TypedColumn[Any, RasterRef] = apply(None, Seq(0), rrs: _*)
def apply(subtileDims: Option[TileDimensions], bandIndexes: Seq[Int], rrs: Column*): TypedColumn[Any, RasterRef] =
def apply(subtileDims: Option[Dimensions[Int]], bandIndexes: Seq[Int], rrs: Column*): TypedColumn[Any, RasterRef] =
new Column(new RasterSourceToRasterRefs(rrs.map(_.expr), bandIndexes, subtileDims)).as[RasterRef]

private[rasterframes] def bandNames(basename: String, bandIndexes: Seq[Int]): Seq[String] = bandIndexes match {
Expand Down
Loading

0 comments on commit 9a4f156

Please sign in to comment.