diff --git a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala
index 4192c339a..6768f1742 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/RasterFunctions.scala
@@ -22,6 +22,7 @@
package org.locationtech.rasterframes
import geotrellis.proj4.CRS
import geotrellis.raster.mapalgebra.local.LocalTileBinaryOp
+import geotrellis.raster.render.ColorRamp
import geotrellis.raster.{CellType, Tile}
import geotrellis.vector.Extent
import org.apache.spark.annotation.Experimental
@@ -34,6 +35,7 @@ import org.locationtech.rasterframes.expressions.aggregates._
import org.locationtech.rasterframes.expressions.generators._
import org.locationtech.rasterframes.expressions.localops._
import org.locationtech.rasterframes.expressions.tilestats._
+import org.locationtech.rasterframes.expressions.transformers.RenderPNG.{RenderCompositePNG, RenderColorRampPNG}
import org.locationtech.rasterframes.expressions.transformers._
import org.locationtech.rasterframes.model.TileDimensions
import org.locationtech.rasterframes.stats._
@@ -81,6 +83,10 @@ trait RasterFunctions {
def rf_assemble_tile(columnIndex: Column, rowIndex: Column, cellData: Column, tileCols: Int, tileRows: Int, ct: CellType): TypedColumn[Any, Tile] =
rf_convert_cell_type(TileAssembler(columnIndex, rowIndex, cellData, lit(tileCols), lit(tileRows)), ct).as(cellData.columnName).as[Tile](singlebandTileEncoder)
+  /** Create a Tile from a column of cell data with location indexes. */
+ def rf_assemble_tile(columnIndex: Column, rowIndex: Column, cellData: Column, tileCols: Int, tileRows: Int): TypedColumn[Any, Tile] =
+ TileAssembler(columnIndex, rowIndex, cellData, lit(tileCols), lit(tileRows))
+
/** Create a Tile from a column of cell data with location indexes. */
def rf_assemble_tile(columnIndex: Column, rowIndex: Column, cellData: Column, tileCols: Column, tileRows: Column): TypedColumn[Any, Tile] =
TileAssembler(columnIndex, rowIndex, cellData, tileCols, tileRows)
@@ -317,12 +323,24 @@ trait RasterFunctions {
ReprojectGeometry(sourceGeom, srcCRSCol, dstCRSCol)
/** Render Tile as ASCII string, for debugging purposes. */
- def rf_render_ascii(col: Column): TypedColumn[Any, String] =
- DebugRender.RenderAscii(col)
+ def rf_render_ascii(tile: Column): TypedColumn[Any, String] =
+ DebugRender.RenderAscii(tile)
/** Render Tile cell values as numeric values, for debugging purposes. */
- def rf_render_matrix(col: Column): TypedColumn[Any, String] =
- DebugRender.RenderMatrix(col)
+ def rf_render_matrix(tile: Column): TypedColumn[Any, String] =
+ DebugRender.RenderMatrix(tile)
+
+ /** Converts tiles in a column into PNG encoded byte array, using given ColorRamp to assign values to colors. */
+ def rf_render_png(tile: Column, colors: ColorRamp): TypedColumn[Any, Array[Byte]] =
+ RenderColorRampPNG(tile, colors)
+
+ /** Converts columns of tiles representing RGB channels into a PNG encoded byte array. */
+ def rf_render_png(red: Column, green: Column, blue: Column): TypedColumn[Any, Array[Byte]] =
+ RenderCompositePNG(red, green, blue)
+
+ /** Converts columns of tiles representing RGB channels into a single RGB packaged tile. */
+ def rf_rgb_composite(red: Column, green: Column, blue: Column): Column =
+ RGBComposite(red, green, blue)
/** Cellwise less than value comparison between two tiles. */
def rf_local_less(left: Column, right: Column): Column = Less(left, right)
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala
index ea502b3a3..66517c57c 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/package.scala
@@ -128,6 +128,9 @@ package object expressions {
registry.registerExpression[DebugRender.RenderAscii]("rf_render_ascii")
registry.registerExpression[DebugRender.RenderMatrix]("rf_render_matrix")
+ registry.registerExpression[RenderPNG.RenderCompositePNG]("rf_render_png")
+ registry.registerExpression[RGBComposite]("rf_rgb_composite")
+
registry.registerExpression[transformers.ReprojectGeometry]("st_reproject")
}
}
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala
index babb9c7b7..54201152e 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/DebugRender.scala
@@ -21,28 +21,29 @@
package org.locationtech.rasterframes.expressions.transformers
-import org.locationtech.rasterframes.expressions.UnaryRasterOp
-import org.locationtech.rasterframes.util.TileAsMatrix
-import geotrellis.raster.Tile
import geotrellis.raster.render.ascii.AsciiArtEncoder
+import geotrellis.raster.{Tile, isNoData}
import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
import org.apache.spark.sql.types.{DataType, StringType}
import org.apache.spark.sql.{Column, TypedColumn}
import org.apache.spark.unsafe.types.UTF8String
+import org.locationtech.rasterframes.expressions.UnaryRasterOp
import org.locationtech.rasterframes.model.TileContext
+import spire.syntax.cfor.cfor
abstract class DebugRender(asciiArt: Boolean) extends UnaryRasterOp
- with CodegenFallback with Serializable {
- override def dataType: DataType = StringType
+ with CodegenFallback with Serializable {
+ import org.locationtech.rasterframes.expressions.transformers.DebugRender.TileAsMatrix
+ override def dataType: DataType = StringType
- override protected def eval(tile: Tile, ctx: Option[TileContext]): Any = {
- UTF8String.fromString(if (asciiArt)
- s"\n${tile.renderAscii(AsciiArtEncoder.Palette.NARROW)}\n"
- else
- s"\n${tile.renderMatrix(6)}\n"
- )
- }
+ override protected def eval(tile: Tile, ctx: Option[TileContext]): Any = {
+ UTF8String.fromString(if (asciiArt)
+ s"\n${tile.renderAscii(AsciiArtEncoder.Palette.NARROW)}\n"
+ else
+ s"\n${tile.renderMatrix(6)}\n"
+ )
+ }
}
object DebugRender {
@@ -75,4 +76,29 @@ object DebugRender {
def apply(tile: Column): TypedColumn[Any, String] =
new Column(RenderMatrix(tile.expr)).as[String]
}
+
+ implicit class TileAsMatrix(val tile: Tile) extends AnyVal {
+ def renderMatrix(significantDigits: Int): String = {
+ val ND = s"%${significantDigits+5}s".format(Double.NaN)
+ val fmt = s"% ${significantDigits+5}.${significantDigits}g"
+ val buf = new StringBuilder("[")
+ cfor(0)(_ < tile.rows, _ + 1) { row =>
+ if(row > 0) buf.append(' ')
+ buf.append('[')
+ cfor(0)(_ < tile.cols, _ + 1) { col =>
+ val v = tile.getDouble(col, row)
+ if (isNoData(v)) buf.append(ND)
+ else buf.append(fmt.format(v))
+
+ if (col < tile.cols - 1)
+ buf.append(',')
+ }
+ buf.append(']')
+ if (row < tile.rows - 1)
+ buf.append(",\n")
+ }
+ buf.append("]")
+ buf.toString()
+ }
+ }
}
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala
new file mode 100644
index 000000000..9f0a9c808
--- /dev/null
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RGBComposite.scala
@@ -0,0 +1,102 @@
+/*
+ * This software is licensed under the Apache 2 license, quoted below.
+ *
+ * Copyright 2019 Astraea, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * [http://www.apache.org/licenses/LICENSE-2.0]
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.locationtech.rasterframes.expressions.transformers
+
+import geotrellis.raster.ArrayMultibandTile
+import org.apache.spark.sql.Column
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
+import org.apache.spark.sql.catalyst.analysis.TypeCheckResult.{TypeCheckFailure, TypeCheckSuccess}
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription, TernaryExpression}
+import org.apache.spark.sql.rf.TileUDT
+import org.apache.spark.sql.types.DataType
+import org.locationtech.rasterframes._
+import org.locationtech.rasterframes.encoders.CatalystSerializer._
+import org.locationtech.rasterframes.expressions.DynamicExtractors.tileExtractor
+import org.locationtech.rasterframes.expressions.row
+import org.locationtech.rasterframes.tiles.ProjectedRasterTile
+
+/**
+ * Expression to combine the given tile columns into a 32-bit RGB composite.
+ * Each tile is first rescaled to the range 0 to 255; cell values are then and-ed with 0xFF, bit shifted, and or-ed into a single 32-bit word.
+ * @param red tile column to represent red channel
+ * @param green tile column to represent green channel
+ * @param blue tile column to represent blue channel
+ */
+@ExpressionDescription(
+ usage = "_FUNC_(red, green, blue) - Combines the given tile columns into an 32-bit RGB composite.",
+ arguments = """
+ Arguments:
+ * red - tile column representing the red channel
+ * green - tile column representing the green channel
+ * blue - tile column representing the blue channel"""
+)
+case class RGBComposite(red: Expression, green: Expression, blue: Expression) extends TernaryExpression
+ with CodegenFallback {
+
+ override def nodeName: String = "rf_rgb_composite"
+
+  // Evaluation emits a ProjectedRasterTile row whenever ANY channel supplies a TileContext,
+  // so declare that schema if any input conforms to it (not merely whatever type `red` has).
+  override def dataType: DataType =
+    if (Seq(red, green, blue).exists(_.dataType.conformsTo[ProjectedRasterTile]))
+      schemaOf[ProjectedRasterTile]
+    else TileType
+
+ override def children: Seq[Expression] = Seq(red, green, blue)
+
+  override def checkInputDataTypes(): TypeCheckResult = {
+    // Channels are validated in red, green, blue order; the first offender is reported.
+    val channels = Seq("Red" -> red, "Green" -> green, "Blue" -> blue)
+    channels
+      .collectFirst {
+        case (label, expr) if !tileExtractor.isDefinedAt(expr.dataType) =>
+          TypeCheckFailure(
+            s"$label channel input type '${expr.dataType}' does not conform to a raster type."
+          )
+      }
+      .getOrElse(TypeCheckSuccess)
+  }
+
+ override protected def nullSafeEval(input1: Any, input2: Any, input3: Any): Any = {
+ val (r, rc) = tileExtractor(red.dataType)(row(input1))
+ val (g, gc) = tileExtractor(green.dataType)(row(input2))
+ val (b, bc) = tileExtractor(blue.dataType)(row(input3))
+
+ // Pick the first available TileContext, if any, and reassociate with the result
+ val ctx = Seq(rc, gc, bc).flatten.headOption
+ val composite = ArrayMultibandTile(
+ r.rescale(0, 255), g.rescale(0, 255), b.rescale(0, 255)
+ ).color()
+ ctx match {
+ case Some(c) => c.toProjectRasterTile(composite).toInternalRow
+ case None =>
+ implicit val tileSer = TileUDT.tileSerializer
+ composite.toInternalRow
+ }
+ }
+}
+
+object RGBComposite {
+ def apply(red: Column, green: Column, blue: Column): Column =
+ new Column(RGBComposite(red.expr, green.expr, blue.expr))
+}
diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala
new file mode 100644
index 000000000..144a4abb6
--- /dev/null
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/transformers/RenderPNG.scala
@@ -0,0 +1,79 @@
+/*
+ * This software is licensed under the Apache 2 license, quoted below.
+ *
+ * Copyright 2019 Astraea, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * [http://www.apache.org/licenses/LICENSE-2.0]
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.locationtech.rasterframes.expressions.transformers
+
+import geotrellis.raster.Tile
+import geotrellis.raster.render.ColorRamp
+import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
+import org.apache.spark.sql.catalyst.expressions.{Expression, ExpressionDescription}
+import org.apache.spark.sql.types.{BinaryType, DataType}
+import org.apache.spark.sql.{Column, TypedColumn}
+import org.locationtech.rasterframes.expressions.UnaryRasterOp
+import org.locationtech.rasterframes.model.TileContext
+
+/**
+ * Converts a tile into a PNG encoded byte array.
+ * @param child tile column
+ * @param ramp color ramp to use for non-composite tiles.
+ */
+abstract class RenderPNG(child: Expression, ramp: Option[ColorRamp])
+    extends UnaryRasterOp with CodegenFallback with Serializable {
+  // The rendered PNG is handed back to Spark as raw bytes.
+  override def dataType: DataType = BinaryType
+  // With a ramp the tile is rendered through it; without one the no-argument renderer is used.
+  override protected def eval(tile: Tile, ctx: Option[TileContext]): Any =
+    ramp.fold(tile.renderPng())(r => tile.renderPng(r)).bytes
+}
+
+object RenderPNG {
+ import org.locationtech.rasterframes.encoders.SparkBasicEncoders._
+
+ @ExpressionDescription(
+ usage = "_FUNC_(tile) - Encode the given tile into a RGB composite PNG. Assumes the red, green, and " +
+ "blue channels are encoded as 8-bit channels within the 32-bit word.",
+ arguments = """
+ Arguments:
+ * tile - tile to render"""
+ )
+ case class RenderCompositePNG(child: Expression) extends RenderPNG(child, None) {
+ override def nodeName: String = "rf_render_png"
+ }
+
+ object RenderCompositePNG {
+ def apply(red: Column, green: Column, blue: Column): TypedColumn[Any, Array[Byte]] =
+ new Column(RenderCompositePNG(RGBComposite(red.expr, green.expr, blue.expr))).as[Array[Byte]]
+ }
+
+ @ExpressionDescription(
+    usage = "_FUNC_(tile) - Encode the given tile as a PNG using a color ramp with assignments from quantile computation",
+ arguments = """
+ Arguments:
+ * tile - tile to render"""
+ )
+ case class RenderColorRampPNG(child: Expression, colors: ColorRamp) extends RenderPNG(child, Some(colors)) {
+ override def nodeName: String = "rf_render_png"
+ }
+
+ object RenderColorRampPNG {
+ def apply(tile: Column, colors: ColorRamp): TypedColumn[Any, Array[Byte]] =
+ new Column(RenderColorRampPNG(tile.expr, colors)).as[Array[Byte]]
+ }
+}
diff --git a/core/src/main/scala/org/locationtech/rasterframes/extensions/SinglebandGeoTiffMethods.scala b/core/src/main/scala/org/locationtech/rasterframes/extensions/SinglebandGeoTiffMethods.scala
index 833ba80e3..168444efe 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/extensions/SinglebandGeoTiffMethods.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/extensions/SinglebandGeoTiffMethods.scala
@@ -30,6 +30,7 @@ import org.apache.spark.sql.{DataFrame, Row, SparkSession}
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.encoders.CatalystSerializer._
import org.locationtech.rasterframes.model.TileDimensions
+import org.locationtech.rasterframes.tiles.ProjectedRasterTile
trait SinglebandGeoTiffMethods extends MethodExtensions[SinglebandGeoTiff] {
def toDF(dims: TileDimensions = NOMINAL_TILE_DIMS)(implicit spark: SparkSession): DataFrame = {
@@ -56,4 +57,6 @@ trait SinglebandGeoTiffMethods extends MethodExtensions[SinglebandGeoTiff] {
spark.createDataFrame(spark.sparkContext.makeRDD(rows, 1), schema)
}
+
+ def toProjectedRasterTile: ProjectedRasterTile = ProjectedRasterTile(self.projectedRaster)
}
diff --git a/core/src/main/scala/org/locationtech/rasterframes/util/DataFrameRenderers.scala b/core/src/main/scala/org/locationtech/rasterframes/util/DataFrameRenderers.scala
new file mode 100644
index 000000000..ae57edcf3
--- /dev/null
+++ b/core/src/main/scala/org/locationtech/rasterframes/util/DataFrameRenderers.scala
@@ -0,0 +1,101 @@
+/*
+ * This software is licensed under the Apache 2 license, quoted below.
+ *
+ * Copyright 2019 Astraea, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License"); you may not
+ * use this file except in compliance with the License. You may obtain a copy of
+ * the License at
+ *
+ * [http://www.apache.org/licenses/LICENSE-2.0]
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations under
+ * the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ */
+
+package org.locationtech.rasterframes.util
+
+import geotrellis.raster.render.ColorRamps
+import org.apache.spark.sql.Dataset
+import org.apache.spark.sql.functions.{base64, concat, concat_ws, length, lit, substring, when}
+import org.apache.spark.sql.types.{StringType, StructField}
+import org.locationtech.rasterframes.expressions.DynamicExtractors
+import org.locationtech.rasterframes.{rfConfig, rf_render_png, rf_resample}
+
+/**
+ * DataFrame extensions for rendering sample content in a number of ways
+ */
+trait DataFrameRenderers {
+ private val truncateWidth = rfConfig.getInt("max-truncate-row-element-length")
+
+ implicit class DFWithPrettyPrint(val df: Dataset[_]) {
+
+ private def stringifyRowElements(cols: Seq[StructField], truncate: Boolean, renderTiles: Boolean) = {
+ cols
+ .map(c => {
+ val resolved = df.col(s"`${c.name}`")
+ if (renderTiles && DynamicExtractors.tileExtractor.isDefinedAt(c.dataType))
+ concat(
+ lit("
")
+ )
+ else {
+ val str = resolved.cast(StringType)
+ if (truncate)
+ when(length(str) > lit(truncateWidth),
+ concat(substring(str, 1, truncateWidth), lit("..."))
+ )
+ .otherwise(str)
+ else str
+ }
+ })
+ }
+
+ def toMarkdown(numRows: Int = 5, truncate: Boolean = false, renderTiles: Boolean = true): String = {
+ import df.sqlContext.implicits._
+ val cols = df.schema.fields
+ val header = cols.map(_.name).mkString("| ", " | ", " |") + "\n" + ("|---" * cols.length) + "|\n"
+ val stringifiers = stringifyRowElements(cols, truncate, renderTiles)
+ val cat = concat_ws(" | ", stringifiers: _*)
+ val rows = df
+ .select(cat)
+ .limit(numRows)
+ .as[String]
+ .collect()
+ .map(_.replaceAll("\\[", "\\\\["))
+ .map(_.replace('\n', '↩'))
+
+ val body = rows
+ .mkString("| ", " |\n| ", " |")
+
+ val caption = if (rows.length >= numRows) s"\n_Showing only top $numRows rows_.\n\n" else ""
+ caption + header + body
+ }
+
+ def toHTML(numRows: Int = 5, truncate: Boolean = false, renderTiles: Boolean = true): String = {
+ import df.sqlContext.implicits._
+ val cols = df.schema.fields
+ val header = "\n" + cols.map(_.name).mkString("", " | ", " |
\n") + "\n"
+ val stringifiers = stringifyRowElements(cols, truncate, renderTiles)
+ val cat = concat_ws("
", stringifiers: _*)
+ val rows = df
+ .select(cat).limit(numRows)
+ .as[String]
+ .collect()
+
+ val body = rows
+ .mkString(" | ", " |
\n", " |
\n")
+
+ val caption = if (rows.length >= numRows) s"Showing only top $numRows rows\n" else ""
+
+ "\n" + caption + header + "\n" + body + "\n" + "
"
+ }
+ }
+}
diff --git a/core/src/main/scala/org/locationtech/rasterframes/util/package.scala b/core/src/main/scala/org/locationtech/rasterframes/util/package.scala
index d1b967af0..f4c6854ab 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/util/package.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/util/package.scala
@@ -22,26 +22,24 @@
package org.locationtech.rasterframes
import com.typesafe.scalalogging.Logger
+import geotrellis.raster.CellGrid
import geotrellis.raster.crop.TileCropMethods
import geotrellis.raster.io.geotiff.reader.GeoTiffReader
import geotrellis.raster.mapalgebra.local.LocalTileBinaryOp
import geotrellis.raster.mask.TileMaskMethods
import geotrellis.raster.merge.TileMergeMethods
import geotrellis.raster.prototype.TilePrototypeMethods
-import geotrellis.raster.{CellGrid, Tile, isNoData}
import geotrellis.spark.Bounds
import geotrellis.spark.tiling.TilerKeyMethods
import geotrellis.util.{ByteReader, GetComponent}
+import org.apache.spark.sql._
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
import org.apache.spark.sql.catalyst.expressions.{Alias, Expression, NamedExpression}
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.rules.Rule
-import org.apache.spark.sql.functions._
import org.apache.spark.sql.rf._
-import org.apache.spark.sql.types.{StringType, StructField}
-import org.apache.spark.sql._
+import org.apache.spark.sql.types.StringType
import org.slf4j.LoggerFactory
-import spire.syntax.cfor._
import scala.Boolean.box
@@ -50,7 +48,7 @@ import scala.Boolean.box
*
* @since 12/18/17
*/
-package object util {
+package object util extends DataFrameRenderers {
@transient
protected lazy val logger: Logger =
Logger(LoggerFactory.getLogger("org.locationtech.rasterframes"))
@@ -159,86 +157,6 @@ package object util {
analyzer(sqlContext).extendedResolutionRules
}
- implicit class TileAsMatrix(val tile: Tile) extends AnyVal {
- def renderMatrix(significantDigits: Int): String = {
- val ND = s"%${significantDigits+5}s".format(Double.NaN)
- val fmt = s"% ${significantDigits+5}.${significantDigits}g"
- val buf = new StringBuilder("[")
- cfor(0)(_ < tile.rows, _ + 1) { row =>
- if(row > 0) buf.append(' ')
- buf.append('[')
- cfor(0)(_ < tile.cols, _ + 1) { col =>
- val v = tile.getDouble(col, row)
- if (isNoData(v)) buf.append(ND)
- else buf.append(fmt.format(v))
-
- if (col < tile.cols - 1)
- buf.append(',')
- }
- buf.append(']')
- if (row < tile.rows - 1)
- buf.append(",\n")
- }
- buf.append("]")
- buf.toString()
- }
- }
-
- private val truncateWidth = rfConfig.getInt("max-truncate-row-element-length")
-
- implicit class DFWithPrettyPrint(val df: Dataset[_]) extends AnyVal {
-
- def stringifyRowElements(cols: Seq[StructField], truncate: Boolean) = {
- cols
- .map(c => s"`${c.name}`")
- .map(c => df.col(c).cast(StringType))
- .map(c => if (truncate) {
- when(length(c) > lit(truncateWidth), concat(substring(c, 1, truncateWidth), lit("...")))
- .otherwise(c)
- } else c)
- }
-
- def toMarkdown(numRows: Int = 5, truncate: Boolean = false): String = {
- import df.sqlContext.implicits._
- val cols = df.schema.fields
- val header = cols.map(_.name).mkString("| ", " | ", " |") + "\n" + ("|---" * cols.length) + "|\n"
- val stringifiers = stringifyRowElements(cols, truncate)
- val cat = concat_ws(" | ", stringifiers: _*)
- val rows = df
- .select(cat)
- .limit(numRows)
- .as[String]
- .collect()
- .map(_.replaceAll("\\[", "\\\\["))
- .map(_.replace('\n', '↩'))
-
- val body = rows
- .mkString("| ", " |\n| ", " |")
-
- val caption = if (rows.length >= numRows) s"\n_Showing only top $numRows rows_.\n\n" else ""
- caption + header + body
- }
-
- def toHTML(numRows: Int = 5, truncate: Boolean = false): String = {
- import df.sqlContext.implicits._
- val cols = df.schema.fields
- val header = "\n" + cols.map(_.name).mkString("", " | ", " |
\n") + "\n"
- val stringifiers = stringifyRowElements(cols, truncate)
- val cat = concat_ws("", stringifiers: _*)
- val rows = df
- .select(cat).limit(numRows)
- .as[String]
- .collect()
-
- val body = rows
- .mkString(" | ", " |
\n", " |
\n")
-
- val caption = if (rows.length >= numRows) s"Showing only top $numRows rows\n" else ""
-
- "\n" + caption + header + "\n" + body + "\n" + "
"
- }
- }
-
object Shims {
// GT 1.2.1 to 2.0.0
def toArrayTile[T <: CellGrid](tile: T): T =
@@ -281,5 +199,4 @@ package object util {
result.asInstanceOf[GeoTiffReader.GeoTiffInfo]
}
}
-
}
diff --git a/core/src/test/scala/org/locationtech/rasterframes/ExtensionMethodSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/ExtensionMethodSpec.scala
index 2359e88fb..4f5fe3591 100644
--- a/core/src/test/scala/org/locationtech/rasterframes/ExtensionMethodSpec.scala
+++ b/core/src/test/scala/org/locationtech/rasterframes/ExtensionMethodSpec.scala
@@ -113,17 +113,27 @@ class ExtensionMethodSpec extends TestEnvironment with TestData with SubdivideSu
}
it("should render Markdown") {
+ import org.apache.spark.sql.functions.lit
+
val md = rf.toMarkdown()
md.count(_ == '|') shouldBe >=(3 * 5)
- md.count(_ == '\n') should be >=(6)
+ md.count(_ == '\n') should be >= 6
- val md2 = rf.toMarkdown(truncate=true)
+ val md2 = rf.withColumn("long_string", lit("p" * 42)).toMarkdown(truncate=true, renderTiles = false)
md2 should include ("...")
+
+ val md3 = rf.toMarkdown(truncate=true, renderTiles = false)
+ md3 shouldNot include("
-
-
-
-
- |
- extent |
- tile |
-
-
-
-
- 0 |
- (-7783653.637667, 1052646.4919514267, -7724349.609951426, 1111950.519667) |
-  |
-
-
- 1 |
- (-7724349.609951427, 1052646.4919514267, -7665045.582235852, 1111950.519667) |
-  |
-
-
- 2 |
- (-7665045.582235853, 1052646.4919514267, -7605741.554520279, 1111950.519667) |
-  |
-
-
- 3 |
- (-7605741.55452028, 1052646.4919514267, -7546437.526804706, 1111950.519667) |
-  |
-
-
- 4 |
- (-7546437.526804707, 1052646.4919514267, -7487133.499089133, 1111950.519667) |
-  |
-
-
-
-
diff --git a/pyrasterframes/src/main/python/docs/static/rasterframe-sample.md b/pyrasterframes/src/main/python/docs/static/rasterframe-sample.md
index c7ec4ed5d..2d850a31f 100644
--- a/pyrasterframes/src/main/python/docs/static/rasterframe-sample.md
+++ b/pyrasterframes/src/main/python/docs/static/rasterframe-sample.md
@@ -14,7 +14,7 @@
2019-02-28 |
(+proj=sinu +lon_0=0.0 +x_0=0.0 +y_0=0.0 +a=6371007.181 +b=6371007.181 +units=m ,) |
(-7783653.637667, 993342.4642358534, -7665045.582235852, 1111950.519667) |
-  | \
+  |
1 |
@@ -35,14 +35,14 @@
2019-02-28 |
(+proj=sinu +lon_0=0.0 +x_0=0.0 +y_0=0.0 +a=6371007.181 +b=6371007.181 +units=m ,) |
(-7427829.47137356, 993342.4642358534, -7309221.415942413, 1111950.519667) |
-  | \
+  |
4 |
2019-02-28 |
(+proj=sinu +lon_0=0.0 +x_0=0.0 +y_0=0.0 +a=6371007.181 +b=6371007.181 +units=m ,) |
(-7309221.415942414, 993342.4642358534, -7190613.360511266, 1111950.519667) |
-  | \
+  |
diff --git a/pyrasterframes/src/main/python/docs/supervised-learning.pymd b/pyrasterframes/src/main/python/docs/supervised-learning.pymd
index f66fcbdf9..0a3f8c0ef 100644
--- a/pyrasterframes/src/main/python/docs/supervised-learning.pymd
+++ b/pyrasterframes/src/main/python/docs/supervised-learning.pymd
@@ -72,17 +72,18 @@ crses = df.select('crs.crsProj4').distinct().collect()
print('Found ', len(crses), 'distinct CRS.')
crs = crses[0][0]
-label_df = spark.read.geojson(os.path.join(resource_dir_uri(), 'luray-labels.geojson')) \
- .select('id', st_reproject('geometry', lit('EPSG:4326'), lit(crs)).alias('geometry')) \
- .hint('broadcast')
+label_df = spark.read.geojson(
+ os.path.join(resource_dir_uri(), 'luray-labels.geojson')) \
+ .select('id', st_reproject('geometry', lit('EPSG:4326'), lit(crs)).alias('geometry')) \
+ .hint('broadcast')
-df_joined = df.join(label_df, st_intersects(st_geometry('extent'), 'geometry'))
+df_joined = df.join(label_df, st_intersects(st_geometry('extent'), 'geometry')) \
+ .withColumn('dims', rf_dimensions('B01'))
+
+df_labeled = df_joined.withColumn('label',
+ rf_rasterize('geometry', st_geometry('extent'), 'id', 'dims.cols', 'dims.rows')
+)
-df_joined.createOrReplaceTempView('df_joined')
-df_labeled = spark.sql("""
-SELECT *, rf_rasterize(geometry, st_geometry(extent), id, rf_dimensions(B01).cols, rf_dimensions(B01).rows) AS label
-FROM df_joined
-""")
```
## Masking Poor Quality Cells
@@ -92,17 +93,20 @@ To filter only for good quality pixels, we follow roughly the same procedure as
```python, make_mask
from pyspark.sql.functions import lit
-mask_part = df_labeled.withColumn('nodata', rf_local_equal('scl', lit(0))) \
- .withColumn('defect', rf_local_equal('scl', lit(1))) \
- .withColumn('cloud8', rf_local_equal('scl', lit(8))) \
- .withColumn('cloud9', rf_local_equal('scl', lit(9))) \
- .withColumn('cirrus', rf_local_equal('scl', lit(10)))
-
-df_mask_inv = mask_part.withColumn('mask', rf_local_add('nodata', 'defect')) \
- .withColumn('mask', rf_local_add('mask', 'cloud8')) \
- .withColumn('mask', rf_local_add('mask', 'cloud9')) \
- .withColumn('mask', rf_local_add('mask', 'cirrus')) \
- .drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
+mask_part = df_labeled \
+ .withColumn('nodata', rf_local_equal('scl', lit(0))) \
+ .withColumn('defect', rf_local_equal('scl', lit(1))) \
+ .withColumn('cloud8', rf_local_equal('scl', lit(8))) \
+ .withColumn('cloud9', rf_local_equal('scl', lit(9))) \
+ .withColumn('cirrus', rf_local_equal('scl', lit(10)))
+
+df_mask_inv = mask_part \
+ .withColumn('mask', rf_local_add('nodata', 'defect')) \
+ .withColumn('mask', rf_local_add('mask', 'cloud8')) \
+ .withColumn('mask', rf_local_add('mask', 'cloud9')) \
+ .withColumn('mask', rf_local_add('mask', 'cirrus')) \
+ .drop('nodata', 'defect', 'cloud8', 'cloud9', 'cirrus')
+
# at this point the mask contains 0 for good cells and 1 for defect, etc
# convert cell type and set value 1 to NoData
df_mask = df_mask_inv.withColumn('mask',
@@ -159,7 +163,8 @@ pipeline.getStages()
The next step is to actually run each step of the Pipeline we created, including fitting the decision tree model. We filter the DataFrame for only _tiles_ intersecting the label raster because the label shapes are relatively sparse over the imagery. It would be logically equivalent to either include or exclude thi step, but it is more efficient to filter because it will mean less data going into the pipeline.
```python, train
-model = pipeline.fit(df_mask.filter(rf_tile_sum('label') > 0).cache())
+model_input = df_mask.filter(rf_tile_sum('label') > 0).cache()
+model = pipeline.fit(model_input)
```
## Model Evaluation
@@ -171,9 +176,11 @@ prediction_df = model.transform(df_mask) \
.drop(assembler.getOutputCol()).cache()
prediction_df.printSchema()
-eval = MulticlassClassificationEvaluator(predictionCol=classifier.getPredictionCol(),
- labelCol=classifier.getLabelCol(),
- metricName='accuracy')
+eval = MulticlassClassificationEvaluator(
+ predictionCol=classifier.getPredictionCol(),
+ labelCol=classifier.getLabelCol(),
+ metricName='accuracy'
+)
accuracy = eval.evaluate(prediction_df)
print("\nAccuracy:", accuracy)
@@ -185,7 +192,7 @@ As an example of using the flexibility provided by DataFrames, the code below co
cnf_mtrx = prediction_df.groupBy(classifier.getPredictionCol()) \
.pivot(classifier.getLabelCol()) \
.count() \
- .sort(classifier.getPredictionCol())
+ .sort(classifier.getPredictionCol())
cnf_mtrx
```
@@ -195,40 +202,33 @@ Because the pipeline included a `TileExploder`, we will recreate the tiled data
```python, assemble_prediction
scored = model.transform(df_mask.drop('label'))
-scored.createOrReplaceTempView('scored')
-
-retiled = spark.sql("""
-SELECT extent, crs,
- rf_assemble_tile(column_index, row_index, prediction, 128, 128) as prediction,
- rf_assemble_tile(column_index, row_index, B04, 128, 128) as red,
- rf_assemble_tile(column_index, row_index, B03, 128, 128) as grn,
- rf_assemble_tile(column_index, row_index, B02, 128, 128) as blu
-FROM scored
-GROUP BY extent, crs
-""")
+retiled = scored \
+ .groupBy('extent', 'crs') \
+ .agg(
+ rf_assemble_tile('column_index', 'row_index', 'prediction', 128, 128).alias('prediction'),
+ rf_assemble_tile('column_index', 'row_index', 'B04', 128, 128).alias('red'),
+ rf_assemble_tile('column_index', 'row_index', 'B03', 128, 128).alias('grn'),
+ rf_assemble_tile('column_index', 'row_index', 'B02', 128, 128).alias('blu')
+ )
retiled.printSchema()
```
Take a look at a sample of the resulting output and the corresponding area's red-green-blue composite image.
```python, display_rgb
-sample = retiled.select('prediction', 'red', 'grn', 'blu') \
+sample = retiled \
+ .select('prediction', rf_rgb_composite('red', 'grn', 'blu').alias('rgb')) \
.sort(-rf_tile_sum(rf_local_equal('prediction', lit(3.0)))) \
.first()
-sample_prediction = sample['prediction']
-
-red = sample['red'].cells
-grn = sample['grn'].cells
-blu = sample['blu'].cells
-sample_rgb = np.concatenate([red[ :, :, None], grn[:, :, None] , blu[ :, :, None]], axis=2)
-mins = np.nanmin(sample_rgb, axis=(0,1))
-plt.imshow((sample_rgb - mins)/ (np.nanmax(sample_rgb, axis=(0,1)) - mins))
+sample_rgb = sample['rgb']
+mins = np.nanmin(sample_rgb.cells, axis=(0,1))
+plt.imshow((sample_rgb.cells - mins) / (np.nanmax(sample_rgb.cells, axis=(0,1)) - mins))
```
Recall the label coding: 1 is forest (purple), 2 is cropland (green) and 3 is developed areas (yellow).
```python, display_prediction
-display(sample_prediction)
+display(sample['prediction'])
```
diff --git a/pyrasterframes/src/main/python/docs/vector-data.pymd b/pyrasterframes/src/main/python/docs/vector-data.pymd
index cc8e761d1..2c66b1562 100644
--- a/pyrasterframes/src/main/python/docs/vector-data.pymd
+++ b/pyrasterframes/src/main/python/docs/vector-data.pymd
@@ -86,7 +86,7 @@ As documented in the @ref:[function reference](reference.md), various user-defin
```python, native_centroid
from pyrasterframes.rasterfunctions import st_centroid
df = df.withColumn('centroid', st_centroid(df.geometry))
-centroids = df.select('name', 'geometry', 'naive_centroid', 'centroid')
+centroids = df.select('geometry', 'name', 'naive_centroid', 'centroid')
centroids.limit(3)
```
@@ -101,14 +101,9 @@ l8 = l8.withColumn('geom', st_geometry(l8.bounds_wgs84))
l8 = l8.withColumn('paducah', st_point(lit(-88.6275), lit(37.072222)))
l8_filtered = l8.filter(st_intersects(l8.geom, st_bufferPoint(l8.paducah, lit(500000.0))))
+l8_filtered.select('product_id', 'entity_id', 'acquisition_date', 'cloud_cover_pct')
```
-```python, evaluate=False, echo=False
-# suppressed due to run time.
-l8_filtered.count()
-```
-
-
[GeoPandas]: http://geopandas.org
[OGR]: https://gdal.org/drivers/vector/index.html
[Shapely]: https://shapely.readthedocs.io/en/latest/manual.html
diff --git a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py
index 6545b5a72..86250b83a 100644
--- a/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py
+++ b/pyrasterframes/src/main/python/pyrasterframes/rasterfunctions.py
@@ -53,13 +53,27 @@ def rf_cell_types():
return [CellType(str(ct)) for ct in _context_call('rf_cell_types')]
-def rf_assemble_tile(col_index, row_index, cell_data_col, num_cols, num_rows, cell_type):
+def rf_assemble_tile(col_index, row_index, cell_data_col, num_cols, num_rows, cell_type=None):
"""Create a Tile from a column of cell data with location indices"""
+ # Looks up the JVM-side `rf_assemble_tile` and calls it with five arguments when `cell_type`
+ # is None, or six arguments (the last being the parsed cell type) otherwise.
jfcn = RFContext.active().lookup('rf_assemble_tile')
- return Column(
- jfcn(_to_java_column(col_index), _to_java_column(row_index), _to_java_column(cell_data_col), num_cols, num_rows,
- _parse_cell_type(cell_type)))
+ # Dimensions given as Column instances are converted to Java columns; any other value
+ # is handed to the JVM function unchanged.
+ # NOTE(review): only Column instances are converted here, so a column *name* string passed
+ # as a dimension would be forwarded as a plain string — confirm callers never do that.
+ # NOTE(review): presumably the JVM overloads expect both dimensions to be of the same kind
+ # (both plain ints or both Columns), and Column dimensions combined with `cell_type` may
+ # not resolve to any overload — confirm against RasterFunctions.scala.
+ if isinstance(num_cols, Column):
+ num_cols = _to_java_column(num_cols)
+
+ if isinstance(num_rows, Column):
+ num_rows = _to_java_column(num_rows)
+
+ # No cell type requested: call the five-argument form.
+ if cell_type is None:
+ return Column(jfcn(
+ _to_java_column(col_index), _to_java_column(row_index), _to_java_column(cell_data_col),
+ num_cols, num_rows
+ ))
+
+ # Cell type requested: parse it with `_parse_cell_type` and pass it as the sixth argument.
+ else:
+ return Column(jfcn(
+ _to_java_column(col_index), _to_java_column(row_index), _to_java_column(cell_data_col),
+ num_cols, num_rows, _parse_cell_type(cell_type)
+ ))
def rf_array_to_tile(array_col, num_cols, num_rows):
"""Convert array in `array_col` into a Tile of dimensions `num_cols` and `num_rows'"""
@@ -348,6 +362,16 @@ def rf_render_matrix(tile_col):
return _apply_column_function('rf_render_matrix', tile_col)
+def rf_render_png(red_tile_col, green_tile_col, blue_tile_col):
+ """Converts columns of tiles representing RGB channels into a PNG encoded byte array.
+
+ Thin wrapper: forwards the three columns, in red, green, blue order, to the
+ `rf_render_png` column function through `_apply_column_function`.
+ """
+ return _apply_column_function('rf_render_png', red_tile_col, green_tile_col, blue_tile_col)
+
+
+def rf_rgb_composite(red_tile_col, green_tile_col, blue_tile_col):
+ """Converts columns of tiles representing RGB channels into a single RGB packaged tile.
+
+ Thin wrapper: forwards the three columns, in red, green, blue order, to the
+ `rf_rgb_composite` column function through `_apply_column_function`.
+ """
+ return _apply_column_function('rf_rgb_composite', red_tile_col, green_tile_col, blue_tile_col)
+
+
def rf_no_data_cells(tile_col):
"""Count of NODATA cells"""
return _apply_column_function('rf_no_data_cells', tile_col)
diff --git a/pyrasterframes/src/main/python/tests/GeoTiffWriterTests.py b/pyrasterframes/src/main/python/tests/GeoTiffWriterTests.py
index 8fa30a7b2..ef28c6562 100644
--- a/pyrasterframes/src/main/python/tests/GeoTiffWriterTests.py
+++ b/pyrasterframes/src/main/python/tests/GeoTiffWriterTests.py
@@ -61,7 +61,7 @@ def test_unstructured_write(self):
os.remove(dest_file)
- def test_unstructured_write_schemeless(self):
+ def test_unstructured_write_schemaless(self):
# should be able to write a projected raster tile column to path like '/data/foo/file.tif'
from pyrasterframes.rasterfunctions import rf_agg_stats, rf_crs
rf = self.spark.read.raster(self.img_uri)
diff --git a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
index f682e2609..feee746eb 100644
--- a/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
+++ b/pyrasterframes/src/main/python/tests/PyRasterFramesTests.py
@@ -410,156 +410,6 @@ def test_raster_join(self):
self.rf.raster_join(rf_prime, join_exprs=self.rf.extent)
-class RasterSource(TestEnvironment):
-
- def test_handle_lazy_eval(self):
- df = self.spark.read.raster(self.img_uri)
- ltdf = df.select('proj_raster')
- self.assertGreater(ltdf.count(), 0)
- self.assertIsNotNone(ltdf.first())
-
- tdf = df.select(rf_tile('proj_raster'))
- self.assertGreater(tdf.count(), 0)
- self.assertIsNotNone(tdf.first())
-
- def test_strict_eval(self):
- df_lazy = self.spark.read.raster(self.img_uri, lazy_tiles=True)
- # when doing Show on a lazy tile we will see something like RasterRefTile(RasterRef(JVMGeoTiffRasterSource(...
- # use this trick to get the `show` string
- show_str_lazy = df_lazy.select('proj_raster')._jdf.showString(1, -1, False)
- self.assertTrue('RasterRef' in show_str_lazy)
-
- # again for strict
- df_strict = self.spark.read.raster(self.img_uri, lazy_tiles=False)
- show_str_strict = df_strict.select('proj_raster')._jdf.showString(1, -1, False)
- self.assertTrue('RasterRef' not in show_str_strict)
-
-
- def test_prt_functions(self):
- df = self.spark.read.raster(self.img_uri) \
- .withColumn('crs', rf_crs('proj_raster')) \
- .withColumn('ext', rf_extent('proj_raster')) \
- .withColumn('geom', rf_geometry('proj_raster'))
- df.select('crs', 'ext', 'geom').first()
-
- def test_raster_source_reader(self):
- # much the same as RasterSourceDataSourceSpec here; but using https PDS. Takes about 30s to run
-
- def l8path(b):
- assert b in range(1, 12)
- base = "https://s3-us-west-2.amazonaws.com/landsat-pds/c1/L8/199/026/LC08_L1TP_199026_20180919_20180928_01_T1/LC08_L1TP_199026_20180919_20180928_01_T1_B{}.TIF"
- return base.format(b)
-
- path_param = '\n'.join([l8path(b) for b in [1, 2, 3]]) # "http://foo.com/file1.tif,http://foo.com/file2.tif"
- tile_size = 512
-
- df = self.spark.read.raster(
- tile_dimensions=(tile_size, tile_size),
- paths=path_param,
- lazy_tiles=True,
- ).cache()
-
- # schema is tile_path and tile
- # df.printSchema()
- self.assertTrue(len(df.columns) == 2 and 'proj_raster_path' in df.columns and 'proj_raster' in df.columns)
-
- # the most common tile dimensions should be as passed to `options`, showing that options are correctly applied
- tile_size_df = df.select(rf_dimensions(df.proj_raster).rows.alias('r'), rf_dimensions(df.proj_raster).cols.alias('c')) \
- .groupby(['r', 'c']).count().toPandas()
- most_common_size = tile_size_df.loc[tile_size_df['count'].idxmax()]
- self.assertTrue(most_common_size.r == tile_size and most_common_size.c == tile_size)
-
- # all rows are from a single source URI
- path_count = df.groupby(df.proj_raster_path).count()
- print(path_count.toPandas())
- self.assertTrue(path_count.count() == 3)
-
- def test_raster_source_reader_schemeless(self):
- import os.path
- path = os.path.join(self.resource_dir, "L8-B8-Robinson-IL.tiff")
- self.assertTrue(not path.startswith('file://'))
- df = self.spark.read.raster(path)
- self.assertTrue(df.count() > 0)
-
- def test_raster_source_catalog_reader(self):
- import pandas as pd
-
- scene_dict = {
- 1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
- 2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
- 3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
- }
-
- def path(scene, band):
- assert band in range(1, 12)
- p = scene_dict[scene]
- return p.format(band)
-
- # Create a pandas dataframe (makes it easy to create spark df)
- path_pandas = pd.DataFrame([
- {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3)},
- {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3)},
- {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3)},
- ])
- # comma separated list of column names containing URI's to read.
- catalog_columns = ','.join(path_pandas.columns.tolist()) # 'b1,b2,b3'
- path_table = self.spark.createDataFrame(path_pandas)
-
- path_df = self.spark.read.raster(
- tile_dimensions=(512, 512),
- catalog=path_table,
- catalog_col_names=catalog_columns,
- lazy_tiles=True # We'll get an OOM error if we try to read 9 scenes all at once!
- )
-
- self.assertTrue(len(path_df.columns) == 6) # three bands times {path, tile}
- self.assertTrue(path_df.select('b1_path').distinct().count() == 3) # as per scene_dict
- b1_paths_maybe = path_df.select('b1_path').distinct().collect()
- b1_paths = [s.format('1') for s in scene_dict.values()]
- self.assertTrue(all([row.b1_path in b1_paths for row in b1_paths_maybe]))
-
- def test_raster_source_catalog_reader_with_pandas(self):
- import pandas as pd
- import geopandas
- from shapely.geometry import Point
-
- scene_dict = {
- 1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
- 2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
- 3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
- }
-
- def path(scene, band):
- assert band in range(1, 12)
- p = scene_dict[scene]
- return p.format(band)
-
- # Create a pandas dataframe (makes it easy to create spark df)
- path_pandas = pd.DataFrame([
- {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3), 'geo': Point(1, 1)},
- {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3), 'geo': Point(2, 2)},
- {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3), 'geo': Point(3, 3)},
- ])
-
- # here a subtle difference with the test_raster_source_catalog_reader test, feed the DataFrame not a CSV and not an already created spark DF.
- df = self.spark.read.raster(
- catalog=path_pandas,
- catalog_col_names=['b1', 'b2', 'b3']
- )
- self.assertEqual(len(df.columns), 7) # three path cols, three tile cols, and geo
- self.assertTrue('geo' in df.columns)
- self.assertTrue(df.select('b1_path').distinct().count() == 3)
-
-
- # Same test with geopandas
- geo_df = geopandas.GeoDataFrame(path_pandas, crs={'init': 'EPSG:4326'}, geometry='geo')
- df2 = self.spark.read.raster(
- catalog=geo_df,
- catalog_col_names=['b1', 'b2', 'b3']
- )
- self.assertEqual(len(df2.columns), 7) # three path cols, three tile cols, and geo
- self.assertTrue('geo' in df2.columns)
- self.assertTrue(df2.select('b1_path').distinct().count() == 3)
def suite():
diff --git a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
index 0c71363df..96fa4d1c9 100644
--- a/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
+++ b/pyrasterframes/src/main/python/tests/RasterFunctionsTests.py
@@ -20,8 +20,11 @@
from pyrasterframes.rasterfunctions import *
from pyrasterframes.utils import gdal_version
+from pyrasterframes.rf_types import Tile
+from pyspark import Row
from pyspark.sql.functions import *
+
from . import TestEnvironment
@@ -265,7 +268,6 @@ def test_cell_type_in_functions(self):
self.assertEqual(result['ct'].cell_type, ct)
self.assertEqual(result['ct_str'].cell_type, ct)
self.assertEqual(result['make'].cell_type, CellType.int8())
- self.assertEqual(result['make2'].cell_type, CellType.int8().with_no_data_value(99))
counts = df.select(
rf_no_data_cells('make').alias("nodata1"),
@@ -278,4 +280,25 @@ def test_cell_type_in_functions(self):
self.assertEqual(counts["nodata1"], 0)
self.assertEqual(counts["data2"], 0)
self.assertEqual(counts["nodata2"], 3 * 4)
+ self.assertEqual(result['make2'].cell_type, CellType.int8().with_no_data_value(99))
+
+    def test_render_composite(self):
+        """Build an RGB composite tile and a PNG rendering from three band rasters read via a catalog."""
+        band_columns = ['red', 'green', 'blue']
+        catalog = self.spark.createDataFrame([
+            Row(red=self.l8band_uri(4), green=self.l8band_uri(3), blue=self.l8band_uri(2))
+        ])
+        rf = self.spark.read.raster(catalog=catalog, catalog_col_names=band_columns)
+
+        # Composite construction: the result should come back as a Tile of the expected size.
+        composite = rf_tile(rf_rgb_composite(*band_columns)).alias('rgb')
+        rgb = rf.select(composite).first()['rgb']
+
+        # TODO: how to better test this?
+        self.assertIsInstance(rgb, Tile)
+        self.assertEqual(rgb.dimensions(), [186, 169])
+
+        # PNG generation: the first eight bytes must be the PNG file signature.
+        png_signature = bytearray(b'\x89PNG\r\n\x1a\n')
+        rendered = rf_render_png(*band_columns).alias('png')
+        png_bytes = rf.select(rendered).first()['png']
+        self.assertEqual(png_bytes[:8], png_signature)
+
+
diff --git a/pyrasterframes/src/main/python/tests/RasterSourceTests.py b/pyrasterframes/src/main/python/tests/RasterSourceTests.py
new file mode 100644
index 000000000..08ebe078c
--- /dev/null
+++ b/pyrasterframes/src/main/python/tests/RasterSourceTests.py
@@ -0,0 +1,174 @@
+#
+# This software is licensed under the Apache 2 license, quoted below.
+#
+# Copyright 2019 Astraea, Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License"); you may not
+# use this file except in compliance with the License. You may obtain a copy of
+# the License at
+#
+# [http://www.apache.org/licenses/LICENSE-2.0]
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+# License for the specific language governing permissions and limitations under
+# the License.
+#
+# SPDX-License-Identifier: Apache-2.0
+#
+
+
+from pyrasterframes.rasterfunctions import *
+from . import TestEnvironment
+
+class RasterSource(TestEnvironment):
+ """Tests for loading rasters through `spark.read.raster`."""
+
+ def test_handle_lazy_eval(self):
+ """Selecting `proj_raster` directly, or through `rf_tile`, yields rows and a non-null first row."""
+ df = self.spark.read.raster(self.img_uri)
+ ltdf = df.select('proj_raster')
+ self.assertGreater(ltdf.count(), 0)
+ self.assertIsNotNone(ltdf.first())
+
+ tdf = df.select(rf_tile('proj_raster'))
+ self.assertGreater(tdf.count(), 0)
+ self.assertIsNotNone(tdf.first())
+
+ def test_strict_eval(self):
+ """The shown `proj_raster` value mentions 'RasterRef' with `lazy_tiles=True` and not with `lazy_tiles=False`."""
+ df_lazy = self.spark.read.raster(self.img_uri, lazy_tiles=True)
+ # when doing Show on a lazy tile we will see something like RasterRefTile(RasterRef(JVMGeoTiffRasterSource(...
+ # use this trick to get the `show` string
+ show_str_lazy = df_lazy.select('proj_raster')._jdf.showString(1, -1, False)
+ self.assertTrue('RasterRef' in show_str_lazy)
+
+ # again for strict
+ df_strict = self.spark.read.raster(self.img_uri, lazy_tiles=False)
+ show_str_strict = df_strict.select('proj_raster')._jdf.showString(1, -1, False)
+ self.assertTrue('RasterRef' not in show_str_strict)
+
+
+ def test_prt_functions(self):
+ """Smoke test: `rf_crs`, `rf_extent` and `rf_geometry` evaluate on `proj_raster`; nothing is asserted on the values."""
+ df = self.spark.read.raster(self.img_uri) \
+ .withColumn('crs', rf_crs('proj_raster')) \
+ .withColumn('ext', rf_extent('proj_raster')) \
+ .withColumn('geom', rf_geometry('proj_raster'))
+ df.select('crs', 'ext', 'geom').first()
+
+ def test_raster_source_reader(self):
+ """Read three band files passed as a newline-joined `paths` string; check columns, tile size and path count."""
+ # much the same as RasterSourceDataSourceSpec here; but using https PDS. Takes about 30s to run
+
+ def l8path(b):
+ assert b in range(1, 12)
+ base = "https://s3-us-west-2.amazonaws.com/landsat-pds/c1/L8/199/026/LC08_L1TP_199026_20180919_20180928_01_T1/LC08_L1TP_199026_20180919_20180928_01_T1_B{}.TIF"
+ return base.format(b)
+
+ path_param = '\n'.join([l8path(b) for b in [1, 2, 3]]) # newline-delimited, e.g. "http://foo.com/file1.tif\nhttp://foo.com/file2.tif"
+ tile_size = 512
+
+ df = self.spark.read.raster(
+ tile_dimensions=(tile_size, tile_size),
+ paths=path_param,
+ lazy_tiles=True,
+ ).cache()
+
+ # schema is exactly two columns: `proj_raster_path` and `proj_raster`
+ # df.printSchema()
+ self.assertTrue(len(df.columns) == 2 and 'proj_raster_path' in df.columns and 'proj_raster' in df.columns)
+
+ # the most common tile dimensions should be as passed via `tile_dimensions`, showing that the setting is applied
+ tile_size_df = df.select(rf_dimensions(df.proj_raster).rows.alias('r'), rf_dimensions(df.proj_raster).cols.alias('c')) \
+ .groupby(['r', 'c']).count().toPandas()
+ most_common_size = tile_size_df.loc[tile_size_df['count'].idxmax()]
+ self.assertTrue(most_common_size.r == tile_size and most_common_size.c == tile_size)
+
+ # rows are grouped by source URI; expect one group for each of the three input paths
+ path_count = df.groupby(df.proj_raster_path).count()
+ print(path_count.toPandas())
+ self.assertTrue(path_count.count() == 3)
+
+ def test_raster_source_reader_schemeless(self):
+ """A plain filesystem path without a `file://` prefix can be read."""
+ import os.path
+ path = os.path.join(self.resource_dir, "L8-B8-Robinson-IL.tiff")
+ self.assertTrue(not path.startswith('file://'))
+ df = self.spark.read.raster(path)
+ self.assertTrue(df.count() > 0)
+
+ def test_raster_source_catalog_reader(self):
+ """Read from a Spark DataFrame catalog, with `catalog_col_names` given as a comma-separated string."""
+ import pandas as pd
+
+ scene_dict = {
+ 1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
+ 2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
+ 3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
+ }
+
+ def path(scene, band):
+ assert band in range(1, 12)
+ p = scene_dict[scene]
+ return p.format(band)
+
+ # Create a pandas dataframe (makes it easy to create spark df)
+ path_pandas = pd.DataFrame([
+ {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3)},
+ {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3)},
+ {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3)},
+ ])
+ # comma separated list of column names containing URI's to read.
+ catalog_columns = ','.join(path_pandas.columns.tolist()) # 'b1,b2,b3'
+ path_table = self.spark.createDataFrame(path_pandas)
+
+ path_df = self.spark.read.raster(
+ tile_dimensions=(512, 512),
+ catalog=path_table,
+ catalog_col_names=catalog_columns,
+ lazy_tiles=True # We'll get an OOM error if we try to read 9 scenes all at once!
+ )
+
+ self.assertTrue(len(path_df.columns) == 6) # three bands times {path, tile}
+ self.assertTrue(path_df.select('b1_path').distinct().count() == 3) # as per scene_dict
+ b1_paths_maybe = path_df.select('b1_path').distinct().collect()
+ b1_paths = [s.format('1') for s in scene_dict.values()]
+ self.assertTrue(all([row.b1_path in b1_paths for row in b1_paths_maybe]))
+
+ def test_raster_source_catalog_reader_with_pandas(self):
+ """Read from a pandas and then a geopandas DataFrame catalog; the extra `geo` column appears in the result."""
+ import pandas as pd
+ import geopandas
+ from shapely.geometry import Point
+
+ scene_dict = {
+ 1: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/041/LC08_L1TP_015041_20190305_20190309_01_T1/LC08_L1TP_015041_20190305_20190309_01_T1_B{}.TIF',
+ 2: 'http://landsat-pds.s3.amazonaws.com/c1/L8/015/042/LC08_L1TP_015042_20190305_20190309_01_T1/LC08_L1TP_015042_20190305_20190309_01_T1_B{}.TIF',
+ 3: 'http://landsat-pds.s3.amazonaws.com/c1/L8/016/041/LC08_L1TP_016041_20190224_20190309_01_T1/LC08_L1TP_016041_20190224_20190309_01_T1_B{}.TIF',
+ }
+
+ def path(scene, band):
+ assert band in range(1, 12)
+ p = scene_dict[scene]
+ return p.format(band)
+
+ # Create a pandas dataframe (makes it easy to create spark df)
+ path_pandas = pd.DataFrame([
+ {'b1': path(1, 1), 'b2': path(1, 2), 'b3': path(1, 3), 'geo': Point(1, 1)},
+ {'b1': path(2, 1), 'b2': path(2, 2), 'b3': path(2, 3), 'geo': Point(2, 2)},
+ {'b1': path(3, 1), 'b2': path(3, 2), 'b3': path(3, 3), 'geo': Point(3, 3)},
+ ])
+
+ # subtle difference from test_raster_source_catalog_reader: the pandas DataFrame itself is passed as the catalog, not a CSV and not an already created Spark DataFrame.
+ df = self.spark.read.raster(
+ catalog=path_pandas,
+ catalog_col_names=['b1', 'b2', 'b3']
+ )
+ self.assertEqual(len(df.columns), 7) # three path cols, three tile cols, and geo
+ self.assertTrue('geo' in df.columns)
+ self.assertTrue(df.select('b1_path').distinct().count() == 3)
+
+
+ # Same test with geopandas
+ geo_df = geopandas.GeoDataFrame(path_pandas, crs={'init': 'EPSG:4326'}, geometry='geo')
+ df2 = self.spark.read.raster(
+ catalog=geo_df,
+ catalog_col_names=['b1', 'b2', 'b3']
+ )
+ self.assertEqual(len(df2.columns), 7) # three path cols, three tile cols, and geo
+ self.assertTrue('geo' in df2.columns)
+ self.assertTrue(df2.select('b1_path').distinct().count() == 3)
diff --git a/pyrasterframes/src/main/python/tests/__init__.py b/pyrasterframes/src/main/python/tests/__init__.py
index 177c9a8c7..152859fb0 100644
--- a/pyrasterframes/src/main/python/tests/__init__.py
+++ b/pyrasterframes/src/main/python/tests/__init__.py
@@ -73,6 +73,10 @@ def setUpClass(cls):
cls.img_uri = 'file://' + os.path.join(cls.resource_dir, 'L8-B8-Robinson-IL.tiff')
+    @classmethod
+    def l8band_uri(cls, band_index):
+        """Return a `file://` URI for the `L8-B<band_index>-Elkton-VA.tiff` file under `cls.resource_dir`."""
+        file_name = 'L8-B{}-Elkton-VA.tiff'.format(band_index)
+        local_path = os.path.join(cls.resource_dir, file_name)
+        return 'file://' + local_path
+
def create_layer(self):
from pyrasterframes.rasterfunctions import rf_convert_cell_type
# load something into a rasterframe
diff --git a/pyrasterframes/src/main/scala/org/locationtech/rasterframes/py/PyRFContext.scala b/pyrasterframes/src/main/scala/org/locationtech/rasterframes/py/PyRFContext.scala
index 4baa50d40..30d612fdb 100644
--- a/pyrasterframes/src/main/scala/org/locationtech/rasterframes/py/PyRFContext.scala
+++ b/pyrasterframes/src/main/scala/org/locationtech/rasterframes/py/PyRFContext.scala
@@ -234,12 +234,11 @@ class PyRFContext(implicit sparkSession: SparkSession) extends RasterFunctions
+ /** Renders up to `numRows` rows of `df` as a Markdown string by delegating to `toMarkdown` with `renderTiles = true`. */
def _dfToMarkdown(df: DataFrame, numRows: Int, truncate: Boolean): String = {
+ // Presumably supplies the `toMarkdown` extension method on DataFrame — confirm in rasterframes.util.
import rasterframes.util.DFWithPrettyPrint
- df.toMarkdown(numRows, truncate)
+ df.toMarkdown(numRows, truncate, renderTiles = true)
}
+ /** Renders up to `numRows` rows of `df` as an HTML string by delegating to `toHTML` with `renderTiles = true`. */
def _dfToHTML(df: DataFrame, numRows: Int, truncate: Boolean): String = {
+ // Presumably supplies the `toHTML` extension method on DataFrame — confirm in rasterframes.util.
import rasterframes.util.DFWithPrettyPrint
- df.toHTML(numRows, truncate)
+ df.toHTML(numRows, truncate, renderTiles = true)
}
-
}