RowTransformation.scala

package diamond.transform.row

import diamond.transform.{Transformation, TransformationContext}
import org.apache.spark.sql.Row
import org.apache.spark.sql.types._

import scala.collection.mutable

/**
  * A general row transformation that takes a Row and returns a new Row.
  *
  * The new Row may conform to a different schema. The new Row may be
  * computed with reference to any values in the original Row or to any
  * values in the TransformationContext.
  *
  * Created by markmo on 12/12/2015.
  */
trait RowTransformation extends Transformation {

  val dependencies = mutable.Set[RowTransformation]()

  def apply(row: Row, ctx: TransformationContext): Row

  def addDependencies(dependencies: RowTransformation*) = {
    this.dependencies ++= dependencies
    this
  }

  def edges: Traversable[(RowTransformation, RowTransformation)] = dependencies.map((_, this))
}
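
// A minimal usage sketch (hypothetical caller code, not part of this file).
// It relies on the companion object's apply factory defined below;
// `someOtherTransformation` is a made-up placeholder for another
// RowTransformation instance.
//
//   val upperFirst = RowTransformation("upperFirst") { (row, ctx) =>
//     Row(row.getString(0).toUpperCase)   // build a new single-column Row
//   }
//   upperFirst.addDependencies(someOtherTransformation)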

/*
 * experimental feature only implemented from 2.11
 * http://stackoverflow.com/questions/25234682/in-scala-can-you-make-an-anonymous-function-have-a-default-argument
 *
trait TransformFunc {
  def apply(row: Row, ctx: TransformationContext, fieldLocator: (String => Any) = { _ => None }): Row
}*/

object RowTransformation {

  val SCHEMA_KEY = "schema"

  //def apply(name: String)(op: TransformFunc) =
  def apply(name: String)(op: (Row, TransformationContext) => Row) = {
    val myName = name
    new RowTransformation {
      val name = myName
      def apply(row: Row, ctx: TransformationContext): Row = op(row, ctx)
    }
  }

  /**
    * Given a row, returns a function that looks up fields by name
    * and returns a value of the appropriate type.
    *
    * Depends on the row schema being set in the context under
    * RowTransformation.SCHEMA_KEY.
    *
    * @param row a spark.sql.Row
    * @param ctx a TransformationContext
    * @return a function from a field name to that field's typed value
    */
  def fieldLocator(row: Row, ctx: TransformationContext) =
    if (ctx.contains(SCHEMA_KEY)) {
      val schema = ctx(SCHEMA_KEY).asInstanceOf[StructType]
      (name: String) => {
        val field = schema(name)
        val index = schema.fieldIndex(name)
        // dispatch on the declared column type to the matching typed getter
        field.dataType match {
          case IntegerType => row.getInt(index)
          case DoubleType => row.getDouble(index)
          case BooleanType => row.getBoolean(index)
          case DateType => row.getDate(index)
          case TimestampType => row.getTimestamp(index)
          case _ => row.getString(index) // fallback assumes a string column
        }
      }
    } else {
      throw new RuntimeException("No schema in context")
    }
}
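
// A sketch of fieldLocator in use (hypothetical names throughout). Only
// `contains` and `apply` on TransformationContext are visible in this file,
// so the exact call for storing the schema is an assumption.
//
//   // ctx.put(RowTransformation.SCHEMA_KEY, df.schema)  // assumed setter API
//   val fullName = RowTransformation("fullName") { (row, ctx) =>
//     val field = RowTransformation.fieldLocator(row, ctx)
//     Row(s"${field("first_name")} ${field("last_name")}")
//   }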