Skip to content

Commit

Permalink
Add ListState implementation.
Browse files Browse the repository at this point in the history
  • Loading branch information
sahnib committed Jan 3, 2024
1 parent b29edf5 commit 3bcdac6
Show file tree
Hide file tree
Showing 18 changed files with 709 additions and 63 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.streaming

import org.apache.spark.annotation.{Evolving, Experimental}

@Experimental
@Evolving
/**
 * Interface used for arbitrary stateful operations with the v2 API to capture
 * a list value state. All operations are scoped to the implicit grouping key of
 * the currently processed input row.
 *
 * @tparam S - data type of the elements stored in the list
 */
trait ListState[S] extends Serializable {

/** Whether state exists or not. */
def exists(): Boolean

/** Get the state value if it exists; implementations return an empty iterator when absent. */
def get(): Iterator[S]

/** Get the list value as an option if it exists and None otherwise. */
def getOption(): Option[Iterator[S]]

/** Update (overwrite) the value of the list with the given sequence. */
def put(newState: Seq[S]): Unit

/** Append a single entry to the list. */
def appendValue(newState: S): Unit

/** Append an entire list to the existing value. */
def appendList(newState: Seq[S]): Unit

/** Remove this state entirely for the current grouping key. */
def remove(): Unit
}
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,10 @@ trait StatefulProcessorHandle extends Serializable {

/** Function to return the queryInfo (query metadata) for the currently running task. */
def getQueryInfo(): QueryInfo

/**
 * Creates new or returns existing list state associated with stateName.
 * The ListState persists values of type T.
 *
 * @param stateName - name of the state variable
 * @tparam T - type of the values stored in the list
 * @return handle to the new or pre-existing list state
 */
def getListState[T](stateName: String): ListState[T]
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,119 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.execution.streaming

import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.ListState

/**
 * Provides concrete implementation for list of values associated with a state variable
 * used in the streaming transformWithState operator.
 *
 * @param store - reference to the StateStore instance to be used for storing state
 * @param stateName - name of logical state partition
 * @tparam S - data type of object that will be stored in the list
 */
class ListStateImpl[S](store: StateStore,
    stateName: String) extends ListState[S] with Logging {

  /** Whether state exists or not for the current implicit grouping key. */
  override def exists(): Boolean = {
    store.get(StateEncoder.encodeGroupingKey(stateName), stateName) != null
  }

  /**
   * Get the state value if it exists. If the state does not exist in the state store, an
   * empty iterator is returned. Elements are deserialized lazily as the caller advances
   * the iterator.
   */
  override def get(): Iterator[S] = {
    val encodedKey = StateEncoder.encodeGroupingKey(stateName)
    val unsafeRowValuesIterator = store.valuesIterator(encodedKey, stateName)
    new Iterator[S] {
      override def hasNext: Boolean = unsafeRowValuesIterator.hasNext

      override def next(): S = {
        // Decode on demand so callers only pay deserialization cost for rows they consume.
        StateEncoder.decodeValue(unsafeRowValuesIterator.next())
      }
    }
  }

  /** Get the list value as an option if it exists and None otherwise. */
  override def getOption(): Option[Iterator[S]] = {
    // Bug fix: the previous implementation returned Option(get()), which is never null,
    // so it always produced Some even when no state existed. Honor the documented
    // "None otherwise" contract by checking for existence first.
    if (exists()) Some(get()) else None
  }

  /** Update (overwrite) the value of the list. An empty sequence removes the state. */
  override def put(newState: Seq[S]): Unit = {
    validateNewState(newState)

    if (newState.isEmpty) {
      // Storing an empty list is represented as having no state at all.
      this.remove()
    } else {
      val encodedKey = StateEncoder.encodeGroupingKey(stateName)
      // The first element overwrites any existing value; the remaining elements are
      // merged (appended), so the store ends up holding exactly the new list.
      store.put(encodedKey, StateEncoder.encodeValue(newState.head), stateName)
      newState.tail.foreach { v =>
        store.merge(encodedKey, StateEncoder.encodeValue(v), stateName)
      }
    }
  }

  /** Append a single entry to the list. */
  override def appendValue(newState: S): Unit = {
    if (newState == null) {
      throw new IllegalArgumentException("value added to ListState should be non-null")
    }
    store.merge(StateEncoder.encodeGroupingKey(stateName),
      StateEncoder.encodeValue(newState), stateName)
  }

  /** Append an entire list to the existing value. */
  override def appendList(newState: Seq[S]): Unit = {
    validateNewState(newState)

    val encodedKey = StateEncoder.encodeGroupingKey(stateName)
    newState.foreach { v =>
      store.merge(encodedKey, StateEncoder.encodeValue(v), stateName)
    }
  }

  /** Remove this state entirely for the current grouping key. */
  override def remove(): Unit = {
    store.remove(StateEncoder.encodeGroupingKey(stateName), stateName)
  }

  /** Rejects null sequences and sequences containing null elements. */
  private def validateNewState(newState: Seq[S]): Unit = {
    if (newState == null) {
      throw new IllegalArgumentException("newState list should be non-null")
    }

    if (newState.contains(null)) {
      throw new IllegalArgumentException("value added to ListState should be non-null")
    }
  }
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.sql.execution.streaming

import org.apache.commons.lang3.SerializationUtils

import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.{UnsafeProjection, UnsafeRow}
import org.apache.spark.sql.types.{BinaryType, StructType}

/**
 * Helper object providing APIs to encode the grouping key and user provided values
 * into single-column binary UnsafeRows, and to decode them back.
 */
object StateEncoder {

  // TODO: validate places that are trying to encode the key and check if we can eliminate/
  // add caching for some of these calls.
  /**
   * Encodes the implicit grouping key (tracked by ImplicitKeyTracker) as a single
   * BinaryType-column UnsafeRow via Java serialization.
   *
   * @throws UnsupportedOperationException if no implicit key is set for the current thread
   */
  def encodeGroupingKey(stateName: String): UnsafeRow = {
    val keyOption = ImplicitKeyTracker.getImplicitKeyOption
    if (keyOption.isEmpty) {
      // Bug fix: added trailing space so the message does not render as "...onstateName=".
      throw new UnsupportedOperationException("Implicit key not found for operation on " +
        s"stateName=$stateName")
    }

    val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
    val keyByteArr = SerializationUtils.serialize(keyOption.get.asInstanceOf[Serializable])
    // NOTE(review): the projection is re-created on every call; presumably because
    // UnsafeProjection instances are stateful/not shareable across threads — confirm
    // before caching (see TODO above).
    val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
    keyEncoder(InternalRow(keyByteArr))
  }

  /** Encodes a user-provided value as a single BinaryType-column UnsafeRow. */
  def encodeValue[S](value: S): UnsafeRow = {
    val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
    val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable])
    val valueEncoder = UnsafeProjection.create(schemaForValueRow)
    valueEncoder(InternalRow(valueByteArr))
  }

  /** Decodes a value previously written by [[encodeValue]] back into its original type. */
  def decodeValue[S](row: UnsafeRow): S = {
    SerializationUtils
      .deserialize(row.getBinary(0))
      .asInstanceOf[S]
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import java.util.UUID
import org.apache.spark.TaskContext
import org.apache.spark.internal.Logging
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.{QueryInfo, StatefulProcessorHandle, ValueState}
import org.apache.spark.sql.streaming.{ListState, QueryInfo, StatefulProcessorHandle, ValueState}
import org.apache.spark.util.Utils

/**
Expand Down Expand Up @@ -123,5 +123,14 @@ class StatefulProcessorHandleImpl(store: StateStore, runId: UUID)
resultState
}


override def getQueryInfo(): QueryInfo = currQueryInfo

/**
 * Creates new or returns existing list state associated with stateName.
 * The ListState persists values of type T. Only allowed before the processor
 * finishes initialization.
 */
override def getListState[T](stateName: String): ListState[T] = {
  verify(currState == CREATED, s"Cannot create state variable with name=$stateName after " +
    "initialization is complete")
  // Ensure the backing column family exists before handing out the state handle.
  store.createColFamilyIfAbsent(stateName)
  new ListStateImpl[T](store, stateName)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,8 @@ case class TransformWithStateExec(
numColsPrefixKey = 0,
session.sqlContext.sessionState,
Some(session.sqlContext.streams.stateStoreCoordinator),
useColumnFamilies = true
useColumnFamilies = true,
useStatefulProcessorEncoder = true
) {
case (store: StateStore, singleIterator: Iterator[InternalRow]) =>
val processorHandle = new StatefulProcessorHandleImpl(store, getStateInfo.queryRunId)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,16 +16,10 @@
*/
package org.apache.spark.sql.execution.streaming

import java.io.Serializable

import org.apache.commons.lang3.SerializationUtils

import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.streaming.state.StateStore
import org.apache.spark.sql.streaming.ValueState
import org.apache.spark.sql.types._

/**
* Class that provides a concrete implementation for a single value state associated with state
Expand All @@ -38,30 +32,6 @@ class ValueStateImpl[S](
store: StateStore,
stateName: String) extends ValueState[S] with Logging {

// TODO: validate places that are trying to encode the key and check if we can eliminate/
// add caching for some of these calls.
private def encodeKey(): UnsafeRow = {
val keyOption = ImplicitKeyTracker.getImplicitKeyOption
if (!keyOption.isDefined) {
throw new UnsupportedOperationException("Implicit key not found for operation on" +
s"stateName=$stateName")
}

val schemaForKeyRow: StructType = new StructType().add("key", BinaryType)
val keyByteArr = SerializationUtils.serialize(keyOption.get.asInstanceOf[Serializable])
val keyEncoder = UnsafeProjection.create(schemaForKeyRow)
val keyRow = keyEncoder(InternalRow(keyByteArr))
keyRow
}

private def encodeValue(value: S): UnsafeRow = {
val schemaForValueRow: StructType = new StructType().add("value", BinaryType)
val valueByteArr = SerializationUtils.serialize(value.asInstanceOf[Serializable])
val valueEncoder = UnsafeProjection.create(schemaForValueRow)
val valueRow = valueEncoder(InternalRow(valueByteArr))
valueRow
}

/** Function to check if state exists. Returns true if present and false otherwise */
override def exists(): Boolean = {
getImpl() != null
Expand All @@ -71,9 +41,7 @@ class ValueStateImpl[S](
override def getOption(): Option[S] = {
val retRow = getImpl()
if (retRow != null) {
val resState = SerializationUtils
.deserialize(retRow.getBinary(0))
.asInstanceOf[S]
val resState = StateEncoder.decodeValue[S](retRow)
Some(resState)
} else {
None
Expand All @@ -84,26 +52,25 @@ class ValueStateImpl[S](
/** Returns the value for the current implicit grouping key, or null if no state exists. */
override def get(): S = {
  val retRow = getImpl()
  if (retRow != null) {
    StateEncoder.decodeValue[S](retRow)
  } else {
    // No state stored; callers needing to distinguish absence should use exists()/getOption().
    null.asInstanceOf[S]
  }
}

/** Fetches the raw stored UnsafeRow for the current implicit grouping key. */
private def getImpl(): UnsafeRow = {
  // Stale pre-diff line calling the removed encodeKey() helper dropped; the shared
  // StateEncoder is now the single source of key encoding.
  store.get(StateEncoder.encodeGroupingKey(stateName), stateName)
}

/** Function to update and overwrite state associated with given key */
override def update(newState: S): Unit = {
  // Stale pre-diff line calling the removed encodeKey()/encodeValue() helpers dropped.
  store.put(StateEncoder.encodeGroupingKey(stateName),
    StateEncoder.encodeValue(newState), stateName)
}

/** Function to remove state for given key */
override def remove(): Unit = {
  // Stale pre-diff line calling the removed encodeKey() helper dropped.
  store.remove(StateEncoder.encodeGroupingKey(stateName), stateName)
}
}
Loading

0 comments on commit 3bcdac6

Please sign in to comment.