SPARK-1100: prevent Spark from overwriting directory silently #11

Closed · wants to merge 4 commits
19 changes: 12 additions & 7 deletions core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -34,14 +34,11 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat}
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob}
import org.apache.hadoop.mapreduce.{RecordWriter => NewRecordWriter}
import org.apache.hadoop.mapreduce.{OutputFormat => NewOutputFormat, Job => NewAPIHadoopJob, RecordWriter => NewRecordWriter, JobContext, SparkHadoopMapReduceUtil}
import org.apache.hadoop.mapreduce.lib.output.{FileOutputFormat => NewFileOutputFormat}

// SparkHadoopWriter and SparkHadoopMapReduceUtil are actually source files defined in Spark.
import org.apache.hadoop.mapred.SparkHadoopWriter
import org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil

import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
@@ -604,8 +601,12 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
val job = new NewAPIHadoopJob(conf)
job.setOutputKeyClass(keyClass)
job.setOutputValueClass(valueClass)

val wrappedConf = new SerializableWritable(job.getConfiguration)
NewFileOutputFormat.setOutputPath(job, new Path(path))
val outpath = new Path(path)
NewFileOutputFormat.setOutputPath(job, outpath)
val jobFormat = outputFormatClass.newInstance
jobFormat.checkOutputSpecs(new JobContext(wrappedConf.value, job.getJobID))
val formatter = new SimpleDateFormat("yyyyMMddHHmm")
val jobtrackerID = formatter.format(new Date())
val stageId = self.id
@@ -633,7 +634,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
committer.commitTask(hadoopContext)
return 1
}
val jobFormat = outputFormatClass.newInstance

/* apparently we need a TaskAttemptID to construct an OutputCommitter;
* however we're only going to use this local OutputCommitter for
* setupJob/commitJob, so we just use a dummy "map" task.
@@ -642,7 +643,7 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
val jobTaskContext = newTaskAttemptContext(wrappedConf.value, jobAttemptId)
val jobCommitter = jobFormat.getOutputCommitter(jobTaskContext)
jobCommitter.setupJob(jobTaskContext)
val count = self.context.runJob(self, writeShard _).sum
self.context.runJob(self, writeShard _).sum
Contributor:
sorry just noticed this - is there any need for .sum here now?
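A minimal sketch of that simplification, assuming the per-partition counts returned by runJob are not consumed anywhere else (this follows the reviewer's suggestion and is not code from the PR):

```scala
// Run the write job purely for its side effects; the per-partition counts
// returned by runJob can be discarded, so .sum is unnecessary.
self.context.runJob(self, writeShard _)
```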

jobCommitter.commitJob(jobTaskContext)
}

@@ -712,6 +713,10 @@ class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")

val path = new Path(conf.get("mapred.output.dir"))
Contributor:
I'm not totally sure mapred.output.dir is always going to be set here. It might be safer to only do this when the format is a FileOutputFormat (e.g. check if the output format is instanceof FileOutputFormat), then get the path by casting it to a FileOutputFormat and calling getOutputPath. That seems a bit safer to me.

Contributor (author):
Yes, you are right!
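A rough sketch of that safer approach (hypothetical, not the code in this PR; it assumes the old-API org.apache.hadoop.mapred.FileOutputFormat already imported at the top of this file, and simply skips the check for non-file output formats):

```scala
// Only check for a pre-existing output directory when the job writes to a file path.
conf.getOutputFormat match {
  case _: FileOutputFormat[_, _] =>
    val outputPath = FileOutputFormat.getOutputPath(conf)
    if (outputPath != null) {
      val fs = outputPath.getFileSystem(conf)
      // Throws FileAlreadyExistsException if the output directory already exists.
      conf.getOutputFormat.checkOutputSpecs(fs, conf)
    }
  case _ => // not a file-based output format, nothing to verify here
}
```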

val fs = path.getFileSystem(conf)
conf.getOutputFormat.checkOutputSpecs(fs, conf)

val writer = new SparkHadoopWriter(conf)
writer.preSetup()

22 changes: 22 additions & 0 deletions core/src/test/scala/org/apache/spark/FileSuite.scala
@@ -24,6 +24,7 @@ import scala.io.Source
import com.google.common.io.Files
import org.apache.hadoop.io._
import org.apache.hadoop.io.compress.DefaultCodec
import org.apache.hadoop.mapred.FileAlreadyExistsException
import org.scalatest.FunSuite

import org.apache.spark.SparkContext._
@@ -208,4 +209,25 @@ class FileSuite extends FunSuite with LocalSparkContext {
assert(rdd.count() === 3)
assert(rdd.count() === 3)
}

test ("prevent user from overwriting the empty directory") {
Contributor:
These tests are looking good, but they don't cover the new-API code path you added. Would you mind adding tests for that too? You could either create two new tests or just have these tests each use both the old and the new APIs.

sc = new SparkContext("local", "test")
val tempdir = Files.createTempDir()
var randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
intercept[FileAlreadyExistsException] {
randomRDD.saveAsTextFile(tempdir.getPath)
}
}

test ("prevent user from overwriting the non-empty directory") {
sc = new SparkContext("local", "test")
val tempdir = Files.createTempDir()
var randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
randomRDD.saveAsTextFile(tempdir.getPath + "/output")
assert(new File(tempdir.getPath + "/output/part-00000").exists() === true)
randomRDD = sc.parallelize(Array((1, "a"), (1, "a"), (2, "b"), (3, "c")), 1)
intercept[FileAlreadyExistsException] {
randomRDD.saveAsTextFile(tempdir.getPath + "/output")
}
}
}
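Per the review comment above asking for coverage of the new-API code path, a test along these lines might work (a sketch, not part of this PR; it assumes an import alias such as import org.apache.hadoop.mapreduce.lib.output.{TextOutputFormat => NewTextOutputFormat}, that java.io.File is already imported in FileSuite, and that the new-API check also surfaces as a mapred FileAlreadyExistsException):

```scala
test ("prevent user from overwriting the non-empty directory - new Hadoop API") {
  sc = new SparkContext("local", "test")
  val tempdir = Files.createTempDir()
  val randomRDD = sc.parallelize(Array(("key1", "a"), ("key2", "b"), ("key3", "c")), 1)
  // First save succeeds and creates part-r-00000 under the output directory.
  randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath + "/output")
  assert(new File(tempdir.getPath + "/output/part-r-00000").exists() === true)
  // A second save into the same directory should now fail fast.
  intercept[FileAlreadyExistsException] {
    randomRDD.saveAsNewAPIHadoopFile[NewTextOutputFormat[String, String]](tempdir.getPath + "/output")
  }
}
```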