[SPARK-5484][GraphX] Periodically do checkpoint in Pregel #15125

Status: Closed · wants to merge 20 commits

Changes from 7 commits

@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.mllib.impl
+package org.apache.spark.util

 import scala.collection.mutable

@@ -58,7 +58,7 @@ import org.apache.spark.storage.StorageLevel
  * @param sc SparkContext for the Datasets given to this checkpointer
  * @tparam T Dataset type, such as RDD[Double]
  */
-private[mllib] abstract class PeriodicCheckpointer[T](
+private[spark] abstract class PeriodicCheckpointer[T](
     val checkpointInterval: Int,
     val sc: SparkContext) extends Logging {
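
For a rough sense of how the relocated checkpointer is used, here is a minimal sketch (illustrative only; these classes are private[spark], so code like this can only live inside Spark's own source tree, and the RDD contents are made up):

    import org.apache.spark.SparkContext
    import org.apache.spark.util.PeriodicRDDCheckpointer

    // Checkpoint an iteratively rebuilt RDD every 10 updates to keep its lineage short.
    val sc = new SparkContext("local[2]", "periodic-checkpointer-sketch")
    sc.setCheckpointDir("/tmp/checkpoints")

    val checkpointer = new PeriodicRDDCheckpointer[Double](10, sc)
    var data = sc.parallelize(1 to 1000).map(_.toDouble)
    for (_ <- 1 to 50) {
      data = data.map(_ + 1.0)            // lineage grows by one map per iteration
      checkpointer.update(data)           // persists data; checkpoints every 10th call
    }
    checkpointer.deleteAllCheckpoints()   // remove checkpoint files once finished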

@@ -15,7 +15,7 @@
  * limitations under the License.
  */

-package org.apache.spark.mllib.impl
+package org.apache.spark.util

felixcheung (Member), Apr 19, 2017:
It's PeriodicRDDCheckpointer; shouldn't this be in the org.apache.spark.rdd.util namespace?

 import org.apache.spark.SparkContext
 import org.apache.spark.rdd.RDD

4 changes: 3 additions & 1 deletion graphx/src/main/scala/org/apache/spark/graphx/GraphOps.scala

@@ -362,12 +362,14 @@ class GraphOps[VD: ClassTag, ED: ClassTag](graph: Graph[VD, ED]) extends Serializable {
   def pregel[A: ClassTag](
       initialMsg: A,
       maxIterations: Int = Int.MaxValue,
+      checkpointInterval: Int = 25,

Member:
Isn't it a breaking change to add a parameter into the middle of the list? Also, this should be documented above, right?

Member:
Also, why is this 25?

Contributor:
Good point in both cases. I'm wondering if the periodic Pregel checkpointing should be controlled by a config value instead. Suppose, for example, we create a config key spark.graphx.pregel.checkpointInterval. If it's set to the default value (0, to retain existing behavior), checkpointing is disabled. If it's set to a positive integer, checkpointing is performed at that interval (in iterations). Other values are invalid.

Contributor (PR author):
25 is the value I used in my tests. I checked this value in LDA/ALS; the default there looks to be 10. Should I change it to 10?

Contributor (PR author):
I agree with @mallman: we don't need to change the API if we use a config value to control the checkpoint interval. I will update the PR soon.

Contributor (PR author):
About the default value, I think we should set it to a positive value so that checkpointing is enabled by default, to avoid stack overflow errors. To align with other implementations in Spark, I would like to set 10 as the default. Please let me know if there are any thoughts on this.

       activeDirection: EdgeDirection = EdgeDirection.Either)(
       vprog: (VertexId, VD, A) => VD,
       sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       mergeMsg: (A, A) => A)
     : Graph[VD, ED] = {
-    Pregel(graph, initialMsg, maxIterations, activeDirection)(vprog, sendMsg, mergeMsg)
+    Pregel(graph, initialMsg, maxIterations, activeDirection,
+      checkpointInterval)(vprog, sendMsg, mergeMsg)
   }

   /**
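
A minimal sketch of the config-based alternative proposed in the thread above (the key spark.graphx.pregel.checkpointInterval and its semantics are the reviewer's suggestion, not something this revision implements):

    import org.apache.spark.SparkConf

    // 0 (the proposed default) disables checkpointing; a positive value checkpoints
    // every that many Pregel iterations; anything else is rejected.
    def pregelCheckpointInterval(conf: SparkConf): Int = {
      val interval = conf.getInt("spark.graphx.pregel.checkpointInterval", 0)
      require(interval >= 0,
        s"spark.graphx.pregel.checkpointInterval must be 0 (disabled) or positive, got $interval")
      interval
    }
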
20 changes: 18 additions & 2 deletions graphx/src/main/scala/org/apache/spark/graphx/Pregel.scala

@@ -19,7 +19,10 @@ package org.apache.spark.graphx

 import scala.reflect.ClassTag

+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
 import org.apache.spark.internal.Logging
+import org.apache.spark.rdd.RDD
+import org.apache.spark.util.PeriodicRDDCheckpointer

 /**
  * Implements a Pregel-like bulk-synchronous message-passing API.

@@ -113,7 +116,8 @@ object Pregel extends Logging {
      (graph: Graph[VD, ED],
       initialMsg: A,
       maxIterations: Int = Int.MaxValue,
-      activeDirection: EdgeDirection = EdgeDirection.Either)
+      activeDirection: EdgeDirection = EdgeDirection.Either,
+      checkpointInterval: Int = 25)

Member:
Ditto: why 25?

      (vprog: (VertexId, VD, A) => VD,
       sendMsg: EdgeTriplet[VD, ED] => Iterator[(VertexId, A)],
       mergeMsg: (A, A) => A)

@@ -123,16 +127,25 @@
s" but got ${maxIterations}")

var g = graph.mapVertices((vid, vdata) => vprog(vid, vdata, initialMsg)).cache()
val graphCheckpointer = new PeriodicGraphCheckpointer[VD, ED](
checkpointInterval, graph.vertices.sparkContext)
graphCheckpointer.update(g)

// compute the messages
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg).cache()
val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
checkpointInterval, graph.vertices.sparkContext)
messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])

Member:
Do we need to update it before the loop? I think it should be enough to do the update in the loop.

Contributor:
I agree. What do you think, @dding3?

Contributor:
Actually, for the sake of simplicity and consistency, I'm going to suggest we keep the checkpointer update calls but remove all .cache() calls. The update calls persist the underlying data, making the calls to .cache() unnecessary.

Contributor (PR author):
I thought we needed to cache graph/messages here so they don't have to be computed again in the loop, but I agree with you: I will keep the checkpointer update calls and remove all .cache() calls.
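
As a concrete illustration of the agreed-upon pattern (a sketch reusing the names from the diff above, not the final committed code), the pre-loop setup would rely on the checkpointers alone for persistence:

    // No explicit .cache(): update() persists `messages` and, on every
    // checkpointInterval-th call, also checkpoints it to truncate the lineage.
    var messages = GraphXUtils.mapReduceTriplets(g, sendMsg, mergeMsg)
    val messageCheckpointer = new PeriodicRDDCheckpointer[(VertexId, A)](
      checkpointInterval, graph.vertices.sparkContext)
    messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
    var activeMessages = messages.count()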

     var activeMessages = messages.count()

     // Loop
     var prevG: Graph[VD, ED] = null
     var i = 0
     while (activeMessages > 0 && i < maxIterations) {
       // Receive the messages and update the vertices.
       prevG = g
       g = g.joinVertices(messages)(vprog).cache()

Member:
The checkpointer will do the persist, so I think we don't need to call cache here.

Contributor:
I agree.

Contributor:
Actually, this may be more subtle than I thought. I'm going to think through this again.

Contributor:
OK. I agree with your observation here, @viirya.

Contributor (PR author):
OK

+      graphCheckpointer.update(g)

       val oldMessages = messages
       // Send new messages, skipping edges where neither side received a message. We must cache

@@ -143,6 +156,7 @@
       // The call to count() materializes `messages` and the vertices of `g`. This hides oldMessages

Member:
Should the comment here be updated?

       // (depended on by the vertices of g) and the vertices of prevG (depended on by oldMessages
       // and the vertices of g).
+      messageCheckpointer.update(messages.asInstanceOf[RDD[(VertexId, A)]])
       activeMessages = messages.count()

       logInfo("Pregel finished iteration " + i)

@@ -155,6 +169,8 @@
       i += 1
     }
     messages.unpersist(blocking = false)

viirya (Member), Feb 15, 2017:
Nit: let's do messageCheckpointer.unpersist(messages).

Member:
Yes.

Contributor (PR author):
It looks like unpersist is a protected method, so we cannot access it from Pregel. I added a new public method to unpersist the dataset as a workaround.
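
One way such a workaround could look (a sketch only; the wrapper name and exact shape are illustrative, not necessarily what the PR ends up committing):

    // Inside PeriodicCheckpointer[T]: `unpersist` is the existing protected hook,
    // and a thin public wrapper lets callers such as Pregel release the final
    // dataset without widening the visibility of unpersist itself.
    def unpersistDataSet(data: T): Unit = {
      unpersist(data)
    }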

+    graphCheckpointer.deleteAllCheckpoints()
+    messageCheckpointer.deleteAllCheckpoints()

Member:
Shouldn't this be inside a finally clause, to make sure checkpoint data is cleaned up even in error cases?

Contributor (PR author):
I think that if we keep the checkpoints when there is an exception during training, the user has a chance to recover from it. I checked RandomForest/GBT in Spark; it looks like they only delete the checkpoints when training finishes successfully.
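
For reference, the cleanup-in-finally variant the reviewer describes would look roughly like this (a sketch; as argued above, the PR intentionally does not do this so that checkpoints survive a failed run):

    try {
      // ... the Pregel iteration loop from the diff above ...
    } finally {
      graphCheckpointer.deleteAllCheckpoints()
      messageCheckpointer.deleteAllCheckpoints()
    }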

     g
   } // end of apply

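
To see the new parameter from the caller's side, here is a sketch of a standard single-source shortest-paths Pregel invocation using the GraphOps.pregel signature from this revision (the graph, source vertex, and 10-iteration interval are made-up inputs; a checkpoint directory is assumed to have been set via SparkContext.setCheckpointDir so checkpointing actually takes effect):

    import org.apache.spark.graphx._

    // Compute shortest-path distances from sourceId, checkpointing every 10 supersteps
    // so the vertex/message lineage cannot grow without bound on long-running jobs.
    def shortestPaths(graph: Graph[Long, Double], sourceId: VertexId): Graph[Double, Double] = {
      val initial = graph.mapVertices((id, _) =>
        if (id == sourceId) 0.0 else Double.PositiveInfinity)
      initial.pregel(Double.PositiveInfinity, checkpointInterval = 10)(
        (id, dist, newDist) => math.min(dist, newDist),
        triplet =>
          if (triplet.srcAttr + triplet.attr < triplet.dstAttr) {
            Iterator((triplet.dstId, triplet.srcAttr + triplet.attr))
          } else {
            Iterator.empty
          },
        (a, b) => math.min(a, b))
    }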

@@ -15,11 +15,12 @@
  * limitations under the License.
  */

-package org.apache.spark.mllib.impl
+package org.apache.spark.graphx.util

 import org.apache.spark.SparkContext
 import org.apache.spark.graphx.Graph
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicCheckpointer


 /**

@@ -76,7 +77,7 @@
  *
  * TODO: Move this out of MLlib?

Contributor:
This comment should be removed.

  */
-private[mllib] class PeriodicGraphCheckpointer[VD, ED](
+private[spark] class PeriodicGraphCheckpointer[VD, ED](
     checkpointInterval: Int,
     sc: SparkContext)
   extends PeriodicCheckpointer[Graph[VD, ED]](checkpointInterval, sc) {

@@ -87,10 +88,7 @@ private[mllib] class PeriodicGraphCheckpointer[VD, ED](

   override protected def persist(data: Graph[VD, ED]): Unit = {
     if (data.vertices.getStorageLevel == StorageLevel.NONE) {

Member:
Isn't persist better? It could potentially support different storage levels later.

Contributor:
We need to use cache because persist does not honor the default storage level requested when constructing the graph. Only cache does that. It's confusing, but true. To verify this for yourself, change these calls to persist and run the PeriodicGraphCheckpointerSuite tests.

Member:
@mallman @dding3 it would be good to add a comment about that here.

-      data.vertices.persist()
-    }
-    if (data.edges.getStorageLevel == StorageLevel.NONE) {
-      data.edges.persist()
+      data.persist()

Member:
Hmm, GraphImpl.persist actually persists its vertices and replicatedVertexView.edges (i.e., edges). But the problem is that we only check the storage level of the vertices. Maybe we should keep the original.

Contributor:
I enhanced the persistence tests in PeriodicGraphCheckpointerSuite to check that the storage level requested in the graph construction is the storage level seen after persistence. Both this version and the original version of this method failed that unit test.

The graph's vertex and edge RDDs are somewhat peculiar in that .cache() and .persist() do not do the same thing, unlike other RDDs. While .cache() honors the default storage level specified at graph construction time, .persist() always caches with the MEMORY_ONLY storage level.

At any rate, getting the PeriodicGraphCheckpointer to honor the default storage level specified at graph construction time requires changing these method calls from persist() to cache().
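
Applying that reasoning to the original shape of the method gives roughly the following (a sketch of the reviewer's suggestion, not necessarily the final committed code):

    override protected def persist(data: Graph[VD, ED]): Unit = {
      // cache() honors the storage level requested when the graph was constructed;
      // on these RDDs, persist() always falls back to MEMORY_ONLY.
      if (data.vertices.getStorageLevel == StorageLevel.NONE) {
        data.vertices.cache()
      }
      if (data.edges.getStorageLevel == StorageLevel.NONE) {
        data.edges.cache()
      }
    }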

     }
   }

@@ -34,7 +34,6 @@ import org.apache.spark.mllib.clustering.{DistributedLDAModel => OldDistributedLDAModel,
   EMLDAOptimizer => OldEMLDAOptimizer, LDA => OldLDA, LDAModel => OldLDAModel,
   LDAOptimizer => OldLDAOptimizer, LocalLDAModel => OldLocalLDAModel,
   OnlineLDAOptimizer => OldOnlineLDAOptimizer}
-import org.apache.spark.mllib.impl.PeriodicCheckpointer
 import org.apache.spark.mllib.linalg.{Vector => OldVector, Vectors => OldVectors}
 import org.apache.spark.mllib.linalg.MatrixImplicits._
 import org.apache.spark.mllib.linalg.VectorImplicits._

@@ -43,9 +42,9 @@ import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.{DataFrame, Dataset, Row, SparkSession}
 import org.apache.spark.sql.functions.{col, monotonically_increasing_id, udf}
 import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.PeriodicCheckpointer
 import org.apache.spark.util.VersionUtils


 private[clustering] trait LDAParams extends Params with HasFeaturesCol with HasMaxIter
     with HasSeed with HasCheckpointInterval {

@@ -21,13 +21,13 @@ import org.apache.spark.internal.Logging
 import org.apache.spark.ml.feature.LabeledPoint
 import org.apache.spark.ml.linalg.Vector
 import org.apache.spark.ml.regression.{DecisionTreeRegressionModel, DecisionTreeRegressor}
-import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer
 import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo}
 import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy}
 import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance}
 import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
+import org.apache.spark.util.PeriodicRDDCheckpointer


 private[spark] object GradientBoostedTrees extends Logging {

@@ -25,7 +25,7 @@ import breeze.stats.distributions.{Gamma, RandBasis}

 import org.apache.spark.annotation.{DeveloperApi, Since}
 import org.apache.spark.graphx._
-import org.apache.spark.mllib.impl.PeriodicGraphCheckpointer
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.linalg.{DenseVector, Matrices, SparseVector, Vector, Vectors}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel

@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path

Contributor:
Is there a reason this test suite isn't moved into the GraphX codebase?

 import org.apache.spark.{SparkContext, SparkFunSuite}
 import org.apache.spark.graphx.{Edge, Graph}
+import org.apache.spark.graphx.util.PeriodicGraphCheckpointer
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.util.Utils

@@ -23,7 +23,7 @@ import org.apache.spark.{SparkContext, SparkFunSuite}
 import org.apache.spark.mllib.util.MLlibTestSparkContext

Contributor:
Is there a reason this file isn't moved into the core codebase?

 import org.apache.spark.rdd.RDD
 import org.apache.spark.storage.StorageLevel
-import org.apache.spark.util.Utils
+import org.apache.spark.util.{PeriodicRDDCheckpointer, Utils}


 class PeriodicRDDCheckpointerSuite extends SparkFunSuite with MLlibTestSparkContext {

9 changes: 7 additions & 2 deletions project/MimaExcludes.scala

@@ -46,7 +46,12 @@ object MimaExcludes {
     ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.streaming.scheduler.StreamingListener.onStreamingStarted"),

     // [SPARK-19148][SQL] do not expose the external table concept in Catalog
-    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable")
+    ProblemFilters.exclude[ReversedMissingMethodProblem]("org.apache.spark.sql.catalog.Catalog.createTable"),
+
+    // SPARK-5484 Periodically do checkpoint in Pregel
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.graphx.Pregel.apply"),
+    ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.graphx.GraphOps.pregel"),
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.graphx.GraphOps.pregel$default$3")
   )

   // Exclude rules for 2.1.x

@@ -932,7 +937,7 @@ object MimaExcludes {
     ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.numTrees"),
     ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.classification.RandomForestClassificationModel.setFeatureSubsetStrategy"),
     ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.numTrees"),
-    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy")
+    ProblemFilters.exclude[IncompatibleResultTypeProblem]("org.apache.spark.ml.regression.RandomForestRegressionModel.setFeatureSubsetStrategy") 

Member:
Remove the trailing whitespace.

viirya (Member), Feb 17, 2017:
Huh, why doesn't the Scala style checker complain about this?

)
}
