diff --git a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala
index bd2a4689a..2a70be585 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/expressions/generators/ExplodeTiles.scala
@@ -21,16 +21,15 @@
 
 package org.locationtech.rasterframes.expressions.generators
 
-import org.locationtech.rasterframes._
-import org.locationtech.rasterframes.encoders.CatalystSerializer._
-import org.locationtech.rasterframes.util._
 import geotrellis.raster._
 import org.apache.spark.sql._
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.codegen.CodegenFallback
 import org.apache.spark.sql.catalyst.expressions.{Expression, Generator, GenericInternalRow}
-import org.apache.spark.sql.rf.TileUDT
 import org.apache.spark.sql.types._
+import org.locationtech.rasterframes._
+import org.locationtech.rasterframes.expressions.DynamicExtractors
+import org.locationtech.rasterframes.util._
 import spire.syntax.cfor.cfor
 
 /**
@@ -67,8 +66,11 @@ case class ExplodeTiles(
   override def eval(input: InternalRow): TraversableOnce[InternalRow] = {
     val tiles = Array.ofDim[Tile](children.length)
     cfor(0)(_ < tiles.length, _ + 1) { index =>
-      val row = children(index).eval(input).asInstanceOf[InternalRow]
-      tiles(index) = if(row != null) row.to[Tile](TileUDT.tileSerializer) else null
+      val c = children(index)
+      val row = c.eval(input).asInstanceOf[InternalRow]
+      tiles(index) = if(row != null)
+        DynamicExtractors.tileExtractor(c.dataType)(row)._1
+      else null
     }
     val dims = tiles.filter(_ != null).map(_.dimensions)
     if(dims.isEmpty) Seq.empty[InternalRow]
diff --git a/core/src/main/scala/org/locationtech/rasterframes/ml/TileColumnSupport.scala b/core/src/main/scala/org/locationtech/rasterframes/ml/TileColumnSupport.scala
index d261f7e91..5e8a2537f 100644
--- a/core/src/main/scala/org/locationtech/rasterframes/ml/TileColumnSupport.scala
+++ b/core/src/main/scala/org/locationtech/rasterframes/ml/TileColumnSupport.scala
@@ -21,8 +21,8 @@
 
 package org.locationtech.rasterframes.ml
 
-import org.apache.spark.sql.rf.TileUDT
 import org.apache.spark.sql.types.{StructField, StructType}
+import org.locationtech.rasterframes.expressions.DynamicExtractors
 
 /**
  * Utility mix-in for separating out tile columns from non-tile columns.
@@ -31,13 +31,11 @@ import org.apache.spark.sql.types.{StructField, StructType}
  */
 trait TileColumnSupport {
   protected def isTile(field: StructField) =
-    field.dataType.typeName.equalsIgnoreCase(TileUDT.typeName)
+    DynamicExtractors.tileExtractor.isDefinedAt(field.dataType)
 
   type TileFields = Array[StructField]
   type NonTileFields = Array[StructField]
   protected def selectTileAndNonTileFields(schema: StructType): (TileFields, NonTileFields) = {
-    val tiles = schema.fields.filter(isTile)
-    val nonTiles = schema.fields.filterNot(isTile)
-    (tiles, nonTiles)
+    schema.fields.partition(f => DynamicExtractors.tileExtractor.isDefinedAt(f.dataType))
   }
 }
diff --git a/core/src/test/scala/org/locationtech/rasterframes/ml/TileExploderSpec.scala b/core/src/test/scala/org/locationtech/rasterframes/ml/TileExploderSpec.scala
index 2d9e2d04c..b79f1bdf8 100644
--- a/core/src/test/scala/org/locationtech/rasterframes/ml/TileExploderSpec.scala
+++ b/core/src/test/scala/org/locationtech/rasterframes/ml/TileExploderSpec.scala
@@ -21,28 +21,48 @@
 
 package org.locationtech.rasterframes.ml
 
-import org.locationtech.rasterframes.TestData
-import geotrellis.raster.Tile
-import org.apache.spark.sql.functions.lit
-import org.locationtech.rasterframes.TestEnvironment
+import geotrellis.proj4.LatLng
+import geotrellis.raster.{IntCellType, Tile}
+import org.apache.spark.sql.functions.{avg, lit}
+import org.locationtech.rasterframes.{TestData, TestEnvironment}
 
 /**
  *
  * @since 2/16/18
  */
 class TileExploderSpec extends TestEnvironment with TestData {
   describe("Tile explode transformer") {
-    it("should explode tiles") {
-      import spark.implicits._
+    import spark.implicits._
+    it("should explode tile") {
       val df = Seq[(Tile, Tile)]((byteArrayTile, byteArrayTile)).toDF("tile1", "tile2").withColumn("other", lit("stuff"))
 
       val exploder = new TileExploder()
       val newSchema = exploder.transformSchema(df.schema)
 
       val exploded = exploder.transform(df)
+      assert(newSchema === exploded.schema)
       assert(exploded.columns.length === 5)
       assert(exploded.count() === 9)
       write(exploded)
+      exploded.agg(avg($"tile1")).as[Double].first() should be (byteArrayTile.statisticsDouble.get.mean)
+    }
+
+    it("should explode proj_raster") {
+      val randPRT = TestData.projectedRasterTile(10, 10, scala.util.Random.nextInt(), extent, LatLng, IntCellType)
+
+      val df = Seq(randPRT).toDF("proj_raster").withColumn("other", lit("stuff"))
+
+      val exploder = new TileExploder()
+      val newSchema = exploder.transformSchema(df.schema)
+
+      val exploded = exploder.transform(df)
+
+      assert(newSchema === exploded.schema)
+      assert(exploded.columns.length === 4)
+      assert(exploded.count() === randPRT.size)
+      write(exploded)
+
+      exploded.agg(avg($"proj_raster")).as[Double].first() should be (randPRT.statisticsDouble.get.mean)
     }
   }
 }
diff --git a/docs/src/main/paradox/release-notes.md b/docs/src/main/paradox/release-notes.md
index 2a108de99..e750c6d76 100644
--- a/docs/src/main/paradox/release-notes.md
+++ b/docs/src/main/paradox/release-notes.md
@@ -2,6 +2,10 @@
 
 ## 0.8.x
 
+### 0.8.2
+
+* Fixed `TileExploder` to support `proj_raster` struct [(#287)](https://github.com/locationtech/rasterframes/issues/287).
+
 ### 0.8.1
 
 * Added `rf_local_no_data`, `rf_local_data` and `rf_interpret_cell_type_as` raster functions.
diff --git a/pyrasterframes/src/main/python/docs/vector-data.pymd b/pyrasterframes/src/main/python/docs/vector-data.pymd
index 8d50f20db..1762565ad 100644
--- a/pyrasterframes/src/main/python/docs/vector-data.pymd
+++ b/pyrasterframes/src/main/python/docs/vector-data.pymd
@@ -93,7 +93,7 @@ l8_filtered = l8 \
     .filter(st_intersects(l8.geom, st_bufferPoint(l8.paducah, lit(50000.0)))) \
     .filter(l8.acquisition_date > '2018-02-01') \
     .filter(l8.acquisition_date < '2018-04-01')
-l8_filtered.select('product_id', 'entity_id', 'acquisition_date', 'cloud_cover_pct').toPandas()
+l8_filtered.select('product_id', 'entity_id', 'acquisition_date', 'cloud_cover_pct')
 ```
 
 [GeoPandas]: http://geopandas.org
diff --git a/pyrasterframes/src/main/python/tests/ExploderTests.py b/pyrasterframes/src/main/python/tests/ExploderTests.py
index 05ff958ae..f7635ad7a 100644
--- a/pyrasterframes/src/main/python/tests/ExploderTests.py
+++ b/pyrasterframes/src/main/python/tests/ExploderTests.py
@@ -33,7 +33,6 @@
 
 class ExploderTests(TestEnvironment):
 
-    @unittest.skip("See issue https://github.com/locationtech/rasterframes/issues/163")
    def test_tile_exploder_pipeline_for_prt(self):
         # NB the tile is a Projected Raster Tile
         df = self.spark.read.raster(self.img_uri)
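
For context, below is a minimal, self-contained sketch of the behavior the new `should explode proj_raster` test above exercises. It is illustrative only, not part of the patch: it assumes a local SparkSession with RasterFrames on the classpath, builds a `ProjectedRasterTile` directly instead of using the `TestData.projectedRasterTile` test helper, and relies on the encoder brought in by `import org.locationtech.rasterframes._` (as the Scala test does via its test environment). The object name, tile contents, and extent values are made up for the example.

```scala
import geotrellis.proj4.LatLng
import geotrellis.raster.IntArrayTile
import geotrellis.vector.Extent
import org.apache.spark.sql.SparkSession
import org.locationtech.rasterframes._
import org.locationtech.rasterframes.ml.TileExploder
import org.locationtech.rasterframes.tiles.ProjectedRasterTile

object ExplodeProjRasterExample extends App {
  // Local session with RasterFrames types and functions registered.
  val spark = SparkSession.builder()
    .master("local[*]")
    .appName("explode-proj_raster")
    .getOrCreate()
    .withRasterFrames
  import spark.implicits._

  // A 10x10 random tile plus an extent and CRS -- the pieces that serialize
  // as the `proj_raster` struct once placed in a DataFrame column.
  val tile = IntArrayTile(Array.fill(100)(scala.util.Random.nextInt(255)), 10, 10)
  val prt = ProjectedRasterTile(tile, Extent(-10.0, -10.0, 10.0, 10.0), LatLng)

  val df = Seq(prt).toDF("proj_raster")

  // With the change in this diff, TileExploder recognizes the proj_raster
  // struct as well as plain tile columns, emitting one row per cell.
  val exploded = new TileExploder().transform(df)
  exploded.printSchema() // column_index, row_index, proj_raster (cell value as double)
  exploded.show(5)

  spark.stop()
}
```

As in the Scala test above, the exploded frame carries `column_index`, `row_index`, and the original column name holding each cell value as a double, so downstream ML stages can treat pixels as ordinary rows.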