Skip to content

Commit

Permalink
Performance optimization for aggregates: do not save context of varia…
Browse files Browse the repository at this point in the history
…bles in state
  • Loading branch information
arkadius committed Jul 6, 2021
1 parent 5a0fa95 commit fa186b9
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 185 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -6,5 +6,6 @@ object VariableConstants {
final val InputMetaVariableName = "inputMeta"
final val MetaVariableName = "meta"
final val OutputVariableName = "output"
final val KeyVariableName = "key"

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,12 @@ case class ValidationContext(localVariables: Map[String, TypingResult] = Map.emp
// Adds `value` under the name carried by `outputVar`: delegates to the name-based `withVariable`
// overload with `outputVar.outputName` as the variable name and `outputVar.fieldName` wrapped in Some
// as the optional parameter name.
def withVariable(outputVar: OutputVar, value: TypingResult)(implicit nodeId: NodeId): ValidatedNel[PartSubGraphCompilationError, ValidationContext] =
withVariable(outputVar.outputName, value, Some(outputVar.fieldName))

/**
 * Binds `name` to `value` in `localVariables`, replacing whatever was stored there under that name.
 *
 * Only `validateVariableFormat` is consulted here; when it reports errors they are returned as-is
 * and no new context is built.
 */
def withVariableOverriden(name: String, value: TypingResult, paramName: Option[String])
(implicit nodeId: NodeId): ValidatedNel[PartSubGraphCompilationError, ValidationContext] = {
  val formatCheck = validateVariableFormat(name, paramName)
  formatCheck.map { _ =>
    val updatedLocals = localVariables + (name -> value)
    copy(localVariables = updatedLocals)
  }
}

// Despite its name, this check passes (returning `name`) only when `name` is NOT yet present in
// `variables`; a name that is already present is reported as OverwrittenVariable.
private def validateVariableExists(name: String, paramName: Option[String])(implicit nodeId: NodeId): ValidatedNel[PartSubGraphCompilationError, String] = {
  val alreadyDefined = variables.contains(name)
  if (!alreadyDefined) Valid(name)
  else Invalid(OverwrittenVariable(name, paramName)).toValidatedNel
}

Expand Down Expand Up @@ -90,4 +96,4 @@ object OutputVar {

// Factory for an OutputVar describing a custom node's output: passes CustomNodeFieldName as the
// first constructor argument and the given `outputName` as the second.
def customNode(outputName: String): OutputVar =
OutputVar(CustomNodeFieldName, outputName)
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import pl.touk.nussknacker.engine.api.LazyParameter
import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.{CannotCreateObjectError, NodeId}
import pl.touk.nussknacker.engine.api.context.{ProcessCompilationError, ValidationContext}
import pl.touk.nussknacker.engine.api.typed.typing.TypingResult
import pl.touk.nussknacker.engine.flink.util.keyed.KeyEnricher

/*
This class serves two purposes:
Expand Down Expand Up @@ -49,9 +50,14 @@ abstract class Aggregator extends AggregateFunction[AnyRef, AnyRef, AnyRef] {

// Untyped merge required by AggregateFunction: both accumulators are cast to Aggregate
// and the actual work is delegated to mergeAggregates.
override final def merge(a: AnyRef, b: AnyRef): AnyRef = {
  val left = a.asInstanceOf[Aggregate]
  val right = b.asInstanceOf[Aggregate]
  mergeAggregates(left, right)
}

final def toContextTransformation(variableName: String, aggregateBy: LazyParameter[_])(implicit nodeId: NodeId):
ValidationContext => ValidatedNel[ProcessCompilationError, ValidationContext] = validationCtx => computeOutputType(aggregateBy.returnType)
final def toContextTransformation(variableName: String, emitContext: Boolean, aggregateBy: LazyParameter[_])(implicit nodeId: NodeId):
ValidationContext => ValidatedNel[ProcessCompilationError, ValidationContext] = validationCtx =>
computeOutputType(aggregateBy.returnType)
//TODO: better error?
.leftMap(message => NonEmptyList.of(CannotCreateObjectError(message, nodeId.id)))
.andThen(validationCtx.withVariable(variableName, _, paramName = None))
.andThen { outputType =>
val ctx = if (emitContext) validationCtx else ValidationContext.empty
ctx.withVariable(variableName, outputType, paramName = None)
}.andThen(KeyEnricher.contextTransformation)

}
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package pl.touk.nussknacker.engine.flink.util.transformer.aggregate

import java.util.concurrent.TimeUnit

import cats.data.NonEmptyList
import com.codahale.metrics.{Histogram, SlidingTimeWindowReservoir}
import org.apache.flink.api.common.functions.RuntimeContext
Expand All @@ -17,7 +16,7 @@ import pl.touk.nussknacker.engine.api.context.ProcessCompilationError.NodeId
import pl.touk.nussknacker.engine.api.typed.typing.TypingResult
import pl.touk.nussknacker.engine.api.{ValueWithContext, Context => NkContext}
import pl.touk.nussknacker.engine.flink.api.state.{LatelyEvictableStateFunction, StateHolder}
import pl.touk.nussknacker.engine.flink.util.keyed.StringKeyedValue
import pl.touk.nussknacker.engine.flink.util.keyed.{KeyEnricher, StringKeyedValue}
import pl.touk.nussknacker.engine.flink.util.metrics.MetricUtils
import pl.touk.nussknacker.engine.flink.util.orderedmap.FlinkRangeMap
import pl.touk.nussknacker.engine.flink.util.orderedmap.FlinkRangeMap._
Expand Down Expand Up @@ -45,7 +44,7 @@ class AggregatorFunction[MapT[K,V]](protected val aggregator: Aggregator, protec

}

trait AggregatorFunctionMixin[MapT[K,V]] { self: StateHolder[MapT[Long, AnyRef]] =>
trait AggregatorFunctionMixin[MapT[K,V]] extends KeyEnricher { self: StateHolder[MapT[Long, AnyRef]] =>

def getRuntimeContext: RuntimeContext

Expand Down Expand Up @@ -90,7 +89,7 @@ trait AggregatorFunctionMixin[MapT[K,V]] { self: StateHolder[MapT[Long, AnyRef]]
val newState = addElementToState(value, timestamp, timeService, out)
val finalVal = computeFinalValue(newState, timestamp)
timeHistogram.update(System.nanoTime() - start)
out.collect(ValueWithContext(finalVal, value.context))
out.collect(ValueWithContext(finalVal, enrichWithKey(value.context, value.value)))
}

protected def addElementToState(value: ValueWithContext[StringKeyedValue[AnyRef]],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class EmitExtraWindowWhenNoDataTumblingAggregatorFunction[MapT[K,V]](protected v
(implicit override val rangeMap: FlinkRangeMap[MapT])
extends KeyedProcessFunction[String, ValueWithContext[StringKeyedValue[AnyRef]], ValueWithContext[AnyRef]]
with StateHolder[MapT[Long, AnyRef]]
with AggregatorFunctionMixin[MapT] with AddedElementContextStateHolder[MapT] {
with AggregatorFunctionMixin[MapT] {

type FlinkCtx = KeyedProcessFunction[String, ValueWithContext[StringKeyedValue[AnyRef]], ValueWithContext[AnyRef]]#Context
type FlinkOnTimerCtx = KeyedProcessFunction[String, ValueWithContext[StringKeyedValue[AnyRef]], ValueWithContext[AnyRef]]#OnTimerContext
Expand All @@ -38,7 +38,6 @@ class EmitExtraWindowWhenNoDataTumblingAggregatorFunction[MapT[K,V]](protected v

override def open(parameters: Configuration): Unit = {
state = getRuntimeContext.getState(stateDescriptor)
addedElementContext = getRuntimeContext.getState(addedElementContextDescriptor)
}

override protected val minimalResolutionMs: Long = timeWindowLengthMillis
Expand All @@ -49,20 +48,15 @@ class EmitExtraWindowWhenNoDataTumblingAggregatorFunction[MapT[K,V]](protected v

override protected def handleElementAddedToState(newElementInStateTimestamp: Long, newElement: aggregator.Element, nkCtx: NkContext,
timerService: TimerService, out: Collector[ValueWithContext[AnyRef]]): Unit = {
addedElementContext.update(readAddedElementContextOrInitial().updated(newElementInStateTimestamp, nkCtx))
timerService.registerEventTimeTimer(newElementInStateTimestamp + timeWindowLengthMillis)
}

override def onTimer(timestamp: Long, ctx: FlinkOnTimerCtx, out: Collector[ValueWithContext[AnyRef]]): Unit = {
val currentStateValue = readStateOrInitial()
val previousTimestamp = timestamp - timeWindowLengthMillis
val currentStateValue = readStateOrInitial()
val finalVal = computeFinalValue(currentStateValue, previousTimestamp)
out.collect(ValueWithContext(finalVal, enrichWithKey(NkContext(""), ctx.getCurrentKey)))

readAddedElementContextOrInitial().toRO(previousTimestamp).toScalaMapRO.lastOption.foreach {
case (_, nkCtx) =>
val finalVal = computeFinalValue(currentStateValue, previousTimestamp)
out.collect(ValueWithContext(finalVal, nkCtx))
}

val previousTimestampStateAndRest = stateForTimestampToReadUntilEnd(currentStateValue, previousTimestamp)
if (previousTimestampStateAndRest.toScalaMapRO.isEmpty) {
evictStates()
Expand All @@ -73,7 +67,6 @@ class EmitExtraWindowWhenNoDataTumblingAggregatorFunction[MapT[K,V]](protected v

override protected def updateState(stateValue: MapT[Long, AnyRef], stateValidity: Long, timeService: TimerService): Unit = {
state.update(stateValue)
invalidateAddedElementContextState(stateValue)
}

override protected def doMoveEvictionTime(time: Long, timeService: TimerService): Unit = {
Expand All @@ -82,7 +75,6 @@ class EmitExtraWindowWhenNoDataTumblingAggregatorFunction[MapT[K,V]](protected v

protected def evictStates(): Unit = {
state.clear()
addedElementContext.clear()
}

// Returns whatever `state.value()` currently yields.
// NOTE(review): presumably this can be null when nothing has been stored yet for the current key — confirm against callers.
override protected def readState(): MapT[Long, AnyRef] = state.value()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,14 +23,13 @@ class EmitWhenEventLeftAggregatorFunction[MapT[K,V]](protected val aggregator: A
override protected val aggregateTypeInformation: TypeInformation[AnyRef])
(implicit override val rangeMap: FlinkRangeMap[MapT])
extends LatelyEvictableStateFunction[ValueWithContext[StringKeyedValue[AnyRef]], ValueWithContext[AnyRef], MapT[Long, AnyRef]]
with AggregatorFunctionMixin[MapT] with AddedElementContextStateHolder[MapT] {
with AggregatorFunctionMixin[MapT] {

type FlinkCtx = KeyedProcessFunction[String, ValueWithContext[StringKeyedValue[AnyRef]], ValueWithContext[AnyRef]]#Context
type FlinkOnTimerCtx = KeyedProcessFunction[String, ValueWithContext[StringKeyedValue[AnyRef]], ValueWithContext[AnyRef]]#OnTimerContext

override def open(parameters: Configuration): Unit = {
super.open(parameters)
addedElementContext = getRuntimeContext.getState(addedElementContextDescriptor)
}

override def processElement(value: ValueWithContext[StringKeyedValue[AnyRef]], ctx: FlinkCtx, out: Collector[ValueWithContext[AnyRef]]): Unit = {
Expand All @@ -39,39 +38,24 @@ class EmitWhenEventLeftAggregatorFunction[MapT[K,V]](protected val aggregator: A

override protected def handleElementAddedToState(newElementInStateTimestamp: Long, newElement: aggregator.Element, nkCtx: NkContext,
timerService: TimerService, out: Collector[ValueWithContext[AnyRef]]): Unit = {
addedElementContext.update(readAddedElementContextOrInitial().updated(newElementInStateTimestamp, nkCtx))
timerService.registerEventTimeTimer(newElementInStateTimestamp + timeWindowLengthMillis)
}

override def onTimer(timestamp: Long, ctx: FlinkOnTimerCtx, out: Collector[ValueWithContext[AnyRef]]): Unit = {
val currentStateValue = readStateOrInitial()
handleElementLeftSlide(currentStateValue, timestamp, ctx.timerService(), out)
handleElementLeftSlide(currentStateValue, timestamp, ctx, out)
super.onTimer(timestamp, ctx, out)
}

protected def handleElementLeftSlide(currentStateValue: MapT[Long, aggregator.Aggregate], timestamp: Long,
timerService: TimerService, out: Collector[ValueWithContext[AnyRef]]): Unit = {
ctx: FlinkOnTimerCtx, out: Collector[ValueWithContext[AnyRef]]): Unit = {
val stateForRecentlySentEvent = currentStateValue.toScalaMapRO.lastOption.map {
case (lastTimestamp, _) => stateForTimestampToReadUntilEnd(currentStateValue, lastTimestamp) // shouldn't we save somewhere recently sent timestamp?
case (lastTimestamp, _) => stateForTimestampToReadUntilEnd(currentStateValue, lastTimestamp) // shouldn't we save somewhere recently sent timestamp?
}.getOrElse(currentStateValue)
for {
lastEntryToRemove <- stateForRecentlySentEvent.toRO(timestamp - timeWindowLengthMillis).toScalaMapRO.lastOption
(lastTimestampToRemove, _) = lastEntryToRemove
matchingContext <- readAddedElementContextOrInitial().toScalaMapRO.get(lastTimestampToRemove)
} {
if (stateForRecentlySentEvent.toRO(timestamp - timeWindowLengthMillis).toScalaMapRO.nonEmpty) {
val finalVal = computeFinalValue(currentStateValue, timestamp)
out.collect(ValueWithContext(finalVal, matchingContext))
out.collect(ValueWithContext(finalVal, enrichWithKey(NkContext(""), ctx.getCurrentKey)))
}
}

override protected def updateState(stateValue: MapT[Long, AnyRef], stateValidity: Long, timeService: TimerService): Unit = {
super.updateState(stateValue, stateValidity, timeService)
invalidateAddedElementContextState(stateValue)
}

override protected def evictStates(): Unit = {
super.evictStates()
addedElementContext.clear()
}

}

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,54 +1,40 @@
package pl.touk.nussknacker.engine.flink.util.transformer.aggregate

import org.apache.flink.api.common.functions.AggregateFunction
import org.apache.flink.api.java.tuple
import org.apache.flink.api.java.tuple.Tuple2
import pl.touk.nussknacker.engine.api.typed.typing.TypingResult
import pl.touk.nussknacker.engine.api.{Context, ValueWithContext}
import pl.touk.nussknacker.engine.flink.util.transformer.aggregate.UnwrappingAggregateFunction.AccumulatorWithContext
import pl.touk.nussknacker.engine.flink.util.keyed.{KeyEnricher, StringKeyedValue}

/**
* This class unwraps the value from the input's KeyedValue. It also accumulates the first Nussknacker context, which will be passed in the output at the end.
* This class unwraps the value from the input's KeyedValue. It also accumulates the key, which will be passed in the output at the end.
*
* NOTE: it would be much cleaner if we evaluated aggregateBy here. However, FLINK-10250 prevents us from doing this and we *have* to compute it beforehand
*
* When using this class it's important that the aggregator, passedType and unwrap function match: the unwrap result must be of passedType and must be processable by the aggregator
*/
object UnwrappingAggregateFunction {
//We use Tuple2 here, to create TypeInformation more easily
type AccumulatorWithContext = Tuple2[AnyRef, Context]
}

class UnwrappingAggregateFunction[T](aggregator: Aggregator,
passedType: TypingResult,
unwrap: T => AnyRef,
outputContextStrategy: OutputContextStrategy)
extends AggregateFunction[ValueWithContext[T], Tuple2[AnyRef, Context], ValueWithContext[AnyRef]] {
class UnwrappingAggregateFunction[Input](aggregator: Aggregator,
passedType: TypingResult,
unwrapAggregatedValue: Input => AnyRef)
extends AggregateFunction[ValueWithContext[StringKeyedValue[Input]], StringKeyedValue[AnyRef], ValueWithContext[AnyRef]] with KeyEnricher {

private val expectedType = aggregator.computeOutputType(passedType)
.valueOr(msg => throw new IllegalArgumentException(msg))

override def createAccumulator(): AccumulatorWithContext = new Tuple2(aggregator.createAccumulator(), null)
override def createAccumulator(): StringKeyedValue[AnyRef] = StringKeyedValue(null, aggregator.createAccumulator())

override def add(value: ValueWithContext[T], accumulator: AccumulatorWithContext): AccumulatorWithContext = {
val underlyingAcc = aggregator.add(unwrap(value.value), accumulator.f0)
val contextToUse = outputContextStrategy.transform(Option(accumulator.f1), value.context)
new Tuple2(underlyingAcc, contextToUse.orNull)
override def add(wrappedInput: ValueWithContext[StringKeyedValue[Input]], accumulator: StringKeyedValue[AnyRef]): StringKeyedValue[AnyRef] = {
wrappedInput.value.mapValue(input => aggregator.add(unwrapAggregatedValue(input), accumulator.value))
}

override def getResult(accumulator: AccumulatorWithContext): ValueWithContext[AnyRef] = {
val accCtx = Option(accumulator.f1).getOrElse(outputContextStrategy.empty)
val finalResult = aggregator.alignToExpectedType(aggregator.getResult(accumulator.f0), expectedType)
ValueWithContext(finalResult, accCtx)
override def getResult(accumulator: StringKeyedValue[AnyRef]): ValueWithContext[AnyRef] = {
val finalResult = aggregator.alignToExpectedType(aggregator.getResult(accumulator.value), expectedType)
ValueWithContext(finalResult, enrichWithKey(Context(""), accumulator))
}

override def merge(a: AccumulatorWithContext, b: AccumulatorWithContext): AccumulatorWithContext = {
val underlyingAcc = aggregator.merge(a.f0, b.f0)
val firstContext = Option(a.f1).getOrElse(b.f1)
new tuple.Tuple2[AnyRef, Context](underlyingAcc, firstContext)
override def merge(a: StringKeyedValue[AnyRef], b: StringKeyedValue[AnyRef]): StringKeyedValue[AnyRef] = {
val mergedKey = Option(a.key).getOrElse(b.key)
val mergedValue = aggregator.merge(a.value, b.value)
StringKeyedValue(mergedKey, mergedValue)
}

}



Loading

0 comments on commit fa186b9

Please sign in to comment.