Skip to content

Commit

Permalink
Merge pull request #1 from ankurdave/add_project_to_graph
Browse files Browse the repository at this point in the history
Merge current master and reimplement Graph.mask using innerJoin
  • Loading branch information
amatsukawa committed Dec 18, 2013
2 parents cb20175 + 0f137e8 commit d7ebff0
Show file tree
Hide file tree
Showing 38 changed files with 2,448 additions and 1,985 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
if (getKeyClass().isArray && partitioner.isInstanceOf[HashPartitioner]) {
throw new SparkException("Default partitioner cannot partition array keys.")
}
new ShuffledRDD[K, V, (K, V)](self, partitioner)
if (self.partitioner == partitioner) self else new ShuffledRDD[K, V, (K, V)](self, partitioner)
}

/**
Expand Down
27 changes: 11 additions & 16 deletions core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import java.io.{ObjectOutputStream, IOException}

private[spark] class ZippedPartitionsPartition(
idx: Int,
@transient rdds: Seq[RDD[_]])
@transient rdds: Seq[RDD[_]],
@transient val preferredLocations: Seq[String])
extends Partition {

override val index: Int = idx
Expand All @@ -47,27 +48,21 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
if (preservesPartitioning) firstParent[Any].partitioner else None

override def getPartitions: Array[Partition] = {
val sizes = rdds.map(x => x.partitions.size)
if (!sizes.forall(x => x == sizes(0))) {
val numParts = rdds.head.partitions.size
if (!rdds.forall(rdd => rdd.partitions.size == numParts)) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
val array = new Array[Partition](sizes(0))
for (i <- 0 until sizes(0)) {
array(i) = new ZippedPartitionsPartition(i, rdds)
Array.tabulate[Partition](numParts) { i =>
val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
// Check whether there are any hosts that match all RDDs; otherwise return the union
val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct
new ZippedPartitionsPartition(i, rdds, locs)
}
array
}

override def getPreferredLocations(s: Partition): Seq[String] = {
val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions
val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) }
// Check whether there are any hosts that match all RDDs; otherwise return the union
val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
if (!exactMatchLocations.isEmpty) {
exactMatchLocations
} else {
prefs.flatten.distinct
}
s.asInstanceOf[ZippedPartitionsPartition].preferredLocations
}

override def clearDependencies() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -102,6 +102,10 @@ class BitSet(numBits: Int) {
words(index >> 6) |= bitmask // div by 64 and mask
}

def unset(index: Int) {
val bitmask = 1L << (index & 0x3f) // mod 64 and shift
words(index >> 6) &= ~bitmask // div by 64 and mask
}

/**
* Return the value of the bit with the specified index. The value is true if the bit with
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
/** Return the value at the specified position. */
def getValue(pos: Int): T = _data(pos)

def iterator() = new Iterator[T] {
def iterator = new Iterator[T] {
var pos = nextPos(0)
override def hasNext: Boolean = pos != INVALID_POS
override def next(): T = {
Expand Down Expand Up @@ -249,8 +249,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* in the lower bits, similar to java.util.HashMap
*/
private def hashcode(h: Int): Int = {
val r = h ^ (h >>> 20) ^ (h >>> 12)
r ^ (r >>> 7) ^ (r >>> 4)
it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
}

private def nextPowerOf2(n: Int): Int = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,


/** Set the value for a key */
def setMerge(k: K, v: V, mergeF: (V,V) => V) {
def setMerge(k: K, v: V, mergeF: (V, V) => V) {
val pos = keySet.addWithoutResize(k)
val ind = pos & OpenHashSet.POSITION_MASK
if ((pos & OpenHashSet.NONEXISTENCE_MASK) != 0) { // if first add
Expand Down
Loading

0 comments on commit d7ebff0

Please sign in to comment.