Skip to content

Commit

Permalink
[SPARK-5484][GRAPHX] Periodically do checkpoint in Pregel
Browse files Browse the repository at this point in the history
## What changes were proposed in this pull request?

Pregel-based iterative algorithms with more than ~50 iterations begin to slow down and eventually fail with a StackOverflowError due to Spark's lack of support for long lineage chains.

This PR causes Pregel to checkpoint the graph periodically if the checkpoint directory is set.
This PR moves PeriodicGraphCheckpointer.scala from mllib to graphx, moves PeriodicRDDCheckpointer.scala, PeriodicCheckpointer.scala from mllib to core
## How was this patch tested?

unit tests, manual tests
(Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests)

(If this patch involves UI changes, please attach a screenshot; otherwise, remove this)

Author: ding <ding@localhost.localdomain>
Author: dding3 <ding.ding@intel.com>
Author: Michael Allman <michael@videoamp.com>

Closes #15125 from dding3/cp2_pregel.
  • Loading branch information
ding authored and Felix Cheung committed Apr 25, 2017
1 parent 67eef47 commit 0a7f5f2
Show file tree
Hide file tree
Showing 13 changed files with 128 additions and 76 deletions.
4 changes: 2 additions & 2 deletions core/src/main/scala/org/apache/spark/rdd/RDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import org.apache.spark.partial.GroupedCountEvaluator
import org.apache.spark.partial.PartialResult
import org.apache.spark.storage.{RDDBlockId, StorageLevel}
import org.apache.spark.util.{BoundedPriorityQueue, Utils}
import org.apache.spark.util.collection.OpenHashMap
import org.apache.spark.util.collection.{OpenHashMap, Utils => collectionUtils}
import org.apache.spark.util.random.{BernoulliCellSampler, BernoulliSampler, PoissonSampler,
SamplingUtils}

Expand Down Expand Up @@ -1420,7 +1420,7 @@ abstract class RDD[T: ClassTag](
val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
queue ++= collectionUtils.takeOrdered(items, num)(ord)
Iterator.single(queue)
}
if (mapRDDs.partitions.length == 0) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
* limitations under the License.
*/

package org.apache.spark.mllib.impl
package org.apache.spark.rdd.util

import org.apache.spark.SparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.PeriodicCheckpointer


/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* limitations under the License.
*/

package org.apache.spark.mllib.impl
package org.apache.spark.util

import scala.collection.mutable

Expand Down Expand Up @@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
* @param sc SparkContext for the Datasets given to this checkpointer
* @tparam T Dataset type, such as RDD[Double]
*/
private[mllib] abstract class PeriodicCheckpointer[T](
private[spark] abstract class PeriodicCheckpointer[T](
val checkpointInterval: Int,
val sc: SparkContext) extends Logging {

Expand Down Expand Up @@ -127,6 +127,16 @@ private[mllib] abstract class PeriodicCheckpointer[T](
/** Get list of checkpoint files for this given Dataset */
protected def getCheckpointFiles(data: T): Iterable[String]

/**
* Call this to unpersist the Dataset.
*/
def unpersistDataSet(): Unit = {
while (persistedQueue.nonEmpty) {
val dataToUnpersist = persistedQueue.dequeue()
unpersist(dataToUnpersist)
}
}

/**
* Call this at the end to delete any remaining checkpoint files.
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,7 @@ class SortingSuite extends SparkFunSuite with SharedSparkContext with Matchers w
}

test("get a range of elements in an array not partitioned by a range partitioner") {
val pairArr = util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
val pairArr = scala.util.Random.shuffle((1 to 1000).toList).map(x => (x, x))
val pairs = sc.parallelize(pairArr, 10)
val range = pairs.filterByRange(200, 800).collect()
assert((800 to 200 by -1).toArray.sorted === range.map(_._1).sorted)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,18 @@
* limitations under the License.
*/

package org.apache.spark.mllib.impl
package org.apache.spark.utils

import org.apache.hadoop.fs.Path

import org.apache.spark.{SparkContext, SparkFunSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.{SharedSparkContext, SparkContext, SparkFunSuite}
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils


class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {
class PeriodicRDDCheckpointerSuite extends SparkFunSuite with SharedSparkContext {

import PeriodicRDDCheckpointerSuite._

Expand Down
14 changes: 14 additions & 0 deletions docs/configuration.md
Original file line number Diff line number Diff line change
Expand Up @@ -2149,6 +2149,20 @@ showDF(properties, numRows = 200, truncate = FALSE)

</table>

### GraphX

<table class="table">
<tr><th>Property Name</th><th>Default</th><th>Meaning</th></tr>
<tr>
<td><code>spark.graphx.pregel.checkpointInterval</code></td>
<td>-1</td>
<td>
Checkpoint interval for graph and message in Pregel. It used to avoid stackOverflowError due to long lineage chains
after lots of iterations. The checkpoint is disabled by default.
</td>
</tr>
</table>

### Deploy

<table class="table">
Expand Down
9 changes: 6 additions & 3 deletions docs/graphx-programming-guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,9 @@ messages remaining.
> messaging function. These constraints allow additional optimization within GraphX.
The following is the type signature of the [Pregel operator][GraphOps.pregel] as well as a *sketch*
of its implementation (note calls to graph.cache have been removed):
of its implementation (note: to avoid stackOverflowError due to long lineage chains, pregel support periodcally
checkpoint graph and messages by setting "spark.graphx.pregel.checkpointInterval" to a positive number,
say 10. And set checkpoint directory as well using SparkContext.setCheckpointDir(directory: String)):

{% highlight scala %}
class GraphOps[VD, ED] {
Expand All @@ -722,6 +724,7 @@ class GraphOps[VD, ED] {
: Graph[VD, ED] = {
// Receive the initial message at each vertex
var g = mapVertices( (vid, vdata) => vprog(vid, vdata, initialMsg) ).cache()

// compute the messages
var messages = g.mapReduceTriplets(sendMsg, mergeMsg)
var activeMessages = messages.count()
Expand All @@ -734,8 +737,8 @@ class GraphOps[VD, ED] {
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
messages = g.mapReduceTriplets(
sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
messages = GraphXUtils.mapReduceTriplets(
g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
activeMessages = messages.count()
i += 1
}
Expand Down
25 changes: 21 additions & 4 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,10 @@ package org.apache.spark.graphx

import scala.reflect.ClassTag

import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
import org.apache.spark.internal.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.rdd.util.PeriodicRDDCheckpointer

/**
* Implements a Pregel-like bulk-synchronous message-passing API.
Expand Down Expand Up @@ -122,27 +125,39 @@ object Pregel extends Logging {
require(maxIterations > 0, s"Maximum number of iterations must be greater than 0," +
s" but got ${maxIterations}")

var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
val checkpointInterval = graph.vertices.sparkContext.getConf
.getInt("spark.graphx.pregel.checkpointInterval", -1)
var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg))
val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
checkpointInterval, graph.vertices.sparkContext)
graphCheckpointer.update(g)

// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
checkpointInterval, graph.vertices.sparkContext)
messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
var activeMessages = messages.count()

// Loop
var prevG: Graph[VD, ED] = null
var i = 0
while (activeMessages > 0 && i < maxIterations) {
// Receive the messages and update the vertices.
prevG = g
g = g.joinVertices(messages)(vprog).cache()
g = g.joinVertices(messages)(vprog)
graphCheckpointer.update(g)

val oldMessages = messages
// Send new messages, skipping edges where neither side received a message. We must cache
// messages so it can be materialized on the next line, allowing us to uncache the previous
// iteration.
messages = GraphXUtils.mapReduceTriplets(
g, sendMsg, mergeMsg, Some((oldMessages, activeDirection))).cache()
g, sendMsg, mergeMsg, Some((oldMessages, activeDirection)))
// The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages
// (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
// and the vertices of g).
messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
activeMessages = messages.count()

logInfo("Pregel finished iteration " + i)
Expand All @@ -154,7 +169,9 @@ object Pregel extends Logging {
// count the iteration
i += 1
}
messages.unpersist(blocking = false)
messageCheckpointer.unpersistDataSet()
graphCheckpointer.deleteAllCheckpoints()
messageCheckpointer.deleteAllCheckpoints()
g
} // end of apply

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,12 @@
* limitations under the License.
*/

package org.apache.spark.mllib.impl
package org.apache.spark.graphx.util

import org.apache.spark.SparkContext
import org.apache.spark.graphx.Graph
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.PeriodicCheckpointer


/**
Expand Down Expand Up @@ -74,9 +75,8 @@ import org.apache.spark.storage.StorageLevel
* @tparam VD Vertex descriptor type
* @tparam ED Edge descriptor type
*
* TODO: Move this out of MLlib?
*/
private[mllib] class PeriodicGraphCheckpointer[VD, ED](
private[spark] class PeriodicGraphCheckpointer[VD, ED](
checkpointInterval: Int,
sc: SparkContext)
extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {
Expand All @@ -87,10 +87,13 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](

override protected def persist(data: Graph[VD, ED]): Unit = {
if (data.vertices.getStorageLevel == StorageLevel.NONE) {
data.vertices.persist()
/* We need to use cache because persist does not honor the default storage level requested
* when constructing the graph. Only cache does that.
*/
data.vertices.cache()
}
if (data.edges.getStorageLevel == StorageLevel.NONE) {
data.edges.persist()
data.edges.cache()
}
}

Expand Down
Loading

0 comments on commit 0a7f5f2

Please sign in to comment.