diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala index 81e0e5297683b..b28e15a6840d9 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala @@ -78,7 +78,7 @@ abstract class Broadcast[T](val id: Long) extends Serializable { */ protected def assertValid() { if (!_isValid) { - throw new SparkException("Attempted to use %s when is no longer valid!".format(toString)) + throw new SparkException("Attempted to use %s after it has been destroyed!".format(toString)) } } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 1721818c212f9..5c239329588d8 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -21,7 +21,7 @@ import java.util.Set import java.util.Map.Entry import java.util.concurrent.ConcurrentHashMap -import scala.collection.{immutable, JavaConversions, mutable} +import scala.collection.{JavaConversions, mutable} import org.apache.spark.Logging @@ -50,11 +50,11 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } def iterator: Iterator[(A, B)] = { - val jIterator = getEntrySet.iterator() + val jIterator = getEntrySet.iterator JavaConversions.asScalaIterator(jIterator).map(kv => (kv.getKey, kv.getValue.value)) } - def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet() + def getEntrySet: Set[Entry[A, TimeStampedValue[B]]] = internalMap.entrySet override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedHashMap[A, B1] @@ -86,8 +86,7 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa } override def apply(key: A): B = { - val value = internalMap.get(key) - Option(value).map(_.value).getOrElse { throw new NoSuchElementException() } + get(key).getOrElse { throw new NoSuchElementException() } } override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = { @@ -101,9 +100,9 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa override def size: Int = internalMap.size override def foreach[U](f: ((A, B)) => U) { - val iterator = getEntrySet.iterator() - while(iterator.hasNext) { - val entry = iterator.next() + val it = getEntrySet.iterator + while(it.hasNext) { + val entry = it.next() val kv = (entry.getKey, entry.getValue.value) f(kv) } @@ -115,27 +114,39 @@ private[spark] class TimeStampedHashMap[A, B](updateTimeStampOnGet: Boolean = fa Option(prev).map(_.value) } - def toMap: immutable.Map[A, B] = iterator.toMap + def putAll(map: Map[A, B]) { + map.foreach { case (k, v) => update(k, v) } + } + + def toMap: Map[A, B] = iterator.toMap def clearOldValues(threshTime: Long, f: (A, B) => Unit) { - val iterator = getEntrySet.iterator() - while (iterator.hasNext) { - val entry = iterator.next() + val it = getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() if (entry.getValue.timestamp < threshTime) { f(entry.getKey, entry.getValue.value) logDebug("Removing key " + entry.getKey) - iterator.remove() + it.remove() } } } - /** - * Removes old key-value pairs that have timestamp earlier than `threshTime`. - */ + /** Removes old key-value pairs that have timestamp earlier than `threshTime`. */ def clearOldValues(threshTime: Long) { clearOldValues(threshTime, (_, _) => ()) } private def currentTime: Long = System.currentTimeMillis + // For testing + + def getTimeStampedValue(key: A): Option[TimeStampedValue[B]] = { + Option(internalMap.get(key)) + } + + def getTimestamp(key: A): Option[Long] = { + getTimeStampedValue(key).map(_.timestamp) + } + } diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala index f814f58261bf3..b65017d6806c6 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedWeakValueHashMap.scala @@ -18,47 +18,61 @@ package org.apache.spark.util import java.lang.ref.WeakReference +import java.util.concurrent.atomic.AtomicInteger -import scala.collection.{immutable, mutable} +import scala.collection.mutable + +import org.apache.spark.Logging /** * A wrapper of TimeStampedHashMap that ensures the values are weakly referenced and timestamped. * - * If the value is garbage collected and the weak reference is null, get() operation returns - * a non-existent value. However, the corresponding key is actually not removed in the current - * implementation. Key-value pairs whose timestamps are older than a particular threshold time - * can then be removed using the clearOldValues method. It exposes a scala.collection.mutable.Map - * interface to allow it to be a drop-in replacement for Scala HashMaps. + * If the value is garbage collected and the weak reference is null, get() will return a + * non-existent value. These entries are removed from the map periodically (every N inserts), as + * their values are no longer strongly reachable. Further, key-value pairs whose timestamps are + * older than a particular threshold can be removed using the clearOldValues method. * - * Internally, it uses a Java ConcurrentHashMap, so all operations on this HashMap are thread-safe. + * TimeStampedWeakValueHashMap exposes a scala.collection.mutable.Map interface, which allows it + * to be a drop-in replacement for Scala HashMaps. Internally, it uses a Java ConcurrentHashMap, + * so all operations on this HashMap are thread-safe. * * @param updateTimeStampOnGet Whether timestamp of a pair will be updated when it is accessed. */ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boolean = false) - extends mutable.Map[A, B]() { + extends mutable.Map[A, B]() with Logging { import TimeStampedWeakValueHashMap._ private val internalMap = new TimeStampedHashMap[A, WeakReference[B]](updateTimeStampOnGet) + private val insertCount = new AtomicInteger(0) + + /** Return a map consisting only of entries whose values are still strongly reachable. */ + private def nonNullReferenceMap = internalMap.filter { case (_, ref) => ref.get != null } def get(key: A): Option[B] = internalMap.get(key) - def iterator: Iterator[(A, B)] = internalMap.iterator + def iterator: Iterator[(A, B)] = nonNullReferenceMap.iterator override def + [B1 >: B](kv: (A, B1)): mutable.Map[A, B1] = { val newMap = new TimeStampedWeakValueHashMap[A, B1] + val oldMap = nonNullReferenceMap.asInstanceOf[mutable.Map[A, WeakReference[B1]]] + newMap.internalMap.putAll(oldMap.toMap) newMap.internalMap += kv newMap } override def - (key: A): mutable.Map[A, B] = { val newMap = new TimeStampedWeakValueHashMap[A, B] + newMap.internalMap.putAll(nonNullReferenceMap.toMap) newMap.internalMap -= key newMap } override def += (kv: (A, B)): this.type = { internalMap += kv + if (insertCount.incrementAndGet() % CLEAR_NULL_VALUES_INTERVAL == 0) { + clearNullValues() + } this } @@ -71,31 +85,53 @@ private[spark] class TimeStampedWeakValueHashMap[A, B](updateTimeStampOnGet: Boo override def apply(key: A): B = internalMap.apply(key) - override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = internalMap.filter(p) + override def filter(p: ((A, B)) => Boolean): mutable.Map[A, B] = nonNullReferenceMap.filter(p) override def empty: mutable.Map[A, B] = new TimeStampedWeakValueHashMap[A, B]() override def size: Int = internalMap.size - override def foreach[U](f: ((A, B)) => U) = internalMap.foreach(f) + override def foreach[U](f: ((A, B)) => U) = nonNullReferenceMap.foreach(f) def putIfAbsent(key: A, value: B): Option[B] = internalMap.putIfAbsent(key, value) - def toMap: immutable.Map[A, B] = iterator.toMap + def toMap: Map[A, B] = iterator.toMap - /** - * Remove old key-value pairs that have timestamp earlier than `threshTime`. - */ + /** Remove old key-value pairs with timestamps earlier than `threshTime`. */ def clearOldValues(threshTime: Long) = internalMap.clearOldValues(threshTime) + /** Remove entries with values that are no longer strongly reachable. */ + def clearNullValues() { + val it = internalMap.getEntrySet.iterator + while (it.hasNext) { + val entry = it.next() + if (entry.getValue.value.get == null) { + logDebug("Removing key " + entry.getKey + " because it is no longer strongly reachable.") + it.remove() + } + } + } + + // For testing + + def getTimestamp(key: A): Option[Long] = { + internalMap.getTimeStampedValue(key).map(_.timestamp) + } + + def getReference(key: A): Option[WeakReference[B]] = { + internalMap.getTimeStampedValue(key).map(_.value) + } } /** * Helper methods for converting to and from WeakReferences. */ -private[spark] object TimeStampedWeakValueHashMap { +private object TimeStampedWeakValueHashMap { - /* Implicit conversion methods to WeakReferences */ + // Number of inserts after which entries with null references are removed + val CLEAR_NULL_VALUES_INTERVAL = 100 + + /* Implicit conversion methods to WeakReferences. */ implicit def toWeakReference[V](v: V): WeakReference[V] = new WeakReference[V](v) @@ -107,12 +143,15 @@ private[spark] object TimeStampedWeakValueHashMap { (kv: (K, WeakReference[V])) => p(kv) } - /* Implicit conversion methods from WeakReferences */ + /* Implicit conversion methods from WeakReferences. */ implicit def fromWeakReference[V](ref: WeakReference[V]): V = ref.get implicit def fromWeakReferenceOption[V](v: Option[WeakReference[V]]): Option[V] = { - v.map(fromWeakReference) + v match { + case Some(ref) => Option(fromWeakReference(ref)) + case None => None + } } implicit def fromWeakReferenceTuple[K, V](kv: (K, WeakReference[V])): (K, V) = { @@ -128,5 +167,4 @@ private[spark] object TimeStampedWeakValueHashMap { map: mutable.Map[K, WeakReference[V]]) : mutable.Map[K, V] = { mutable.Map(map.mapValues(fromWeakReference).toSeq: _*) } - } diff --git a/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala new file mode 100644 index 0000000000000..6a5653ed2fb54 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/TimeStampedHashMapSuite.scala @@ -0,0 +1,264 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.lang.ref.WeakReference + +import scala.collection.mutable +import scala.collection.mutable.ArrayBuffer +import scala.util.Random + +import org.scalatest.FunSuite + +class TimeStampedHashMapSuite extends FunSuite { + + // Test the testMap function - a Scala HashMap should obviously pass + testMap(new mutable.HashMap[String, String]()) + + // Test TimeStampedHashMap basic functionality + testMap(new TimeStampedHashMap[String, String]()) + testMapThreadSafety(new TimeStampedHashMap[String, String]()) + + // Test TimeStampedWeakValueHashMap basic functionality + testMap(new TimeStampedWeakValueHashMap[String, String]()) + testMapThreadSafety(new TimeStampedWeakValueHashMap[String, String]()) + + test("TimeStampedHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing by timestamp") { + // clearing by insertion time + val map = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = false) + map("k1") = "v1" + assert(map("k1") === "v1") + Thread.sleep(10) + val threshTime = System.currentTimeMillis + assert(map.getTimestamp("k1").isDefined) + assert(map.getTimestamp("k1").get < threshTime) + map.clearOldValues(threshTime) + assert(map.get("k1") === None) + + // clearing by modification time + val map1 = new TimeStampedWeakValueHashMap[String, String](updateTimeStampOnGet = true) + map1("k1") = "v1" + map1("k2") = "v2" + assert(map1("k1") === "v1") + Thread.sleep(10) + val threshTime1 = System.currentTimeMillis + Thread.sleep(10) + assert(map1("k2") === "v2") // access k2 to update its access time to > threshTime + assert(map1.getTimestamp("k1").isDefined) + assert(map1.getTimestamp("k1").get < threshTime1) + assert(map1.getTimestamp("k2").isDefined) + assert(map1.getTimestamp("k2").get >= threshTime1) + map1.clearOldValues(threshTime1) //should only clear k1 + assert(map1.get("k1") === None) + assert(map1.get("k2").isDefined) + } + + test("TimeStampedWeakValueHashMap - clearing weak references") { + var strongRef = new Object + val weakRef = new WeakReference(strongRef) + val map = new TimeStampedWeakValueHashMap[String, Object] + map("k1") = strongRef + map("k2") = "v2" + map("k3") = "v3" + assert(map("k1") === strongRef) + + // clear strong reference to "k1" + strongRef = null + val startTime = System.currentTimeMillis + System.gc() // Make a best effort to run the garbage collection. It *usually* runs GC. + System.runFinalization() // Make a best effort to call finalizer on all cleaned objects. + while(System.currentTimeMillis - startTime < 10000 && weakRef.get != null) { + System.gc() + System.runFinalization() + Thread.sleep(100) + } + assert(map.getReference("k1").isDefined) + val ref = map.getReference("k1").get + assert(ref.get === null) + assert(map.get("k1") === None) + + // operations should only display non-null entries + assert(map.iterator.forall { case (k, v) => k != "k1" }) + assert(map.filter { case (k, v) => k != "k2" }.size === 1) + assert(map.filter { case (k, v) => k != "k2" }.head._1 === "k3") + assert(map.toMap.size === 2) + assert(map.toMap.forall { case (k, v) => k != "k1" }) + val buffer = new ArrayBuffer[String] + map.foreach { case (k, v) => buffer += v.toString } + assert(buffer.size === 2) + assert(buffer.forall(_ != "k1")) + val plusMap = map + (("k4", "v4")) + assert(plusMap.size === 3) + assert(plusMap.forall { case (k, v) => k != "k1" }) + val minusMap = map - "k2" + assert(minusMap.size === 1) + assert(minusMap.head._1 == "k3") + + // clear null values - should only clear k1 + map.clearNullValues() + assert(map.getReference("k1") === None) + assert(map.get("k1") === None) + assert(map.get("k2").isDefined) + assert(map.get("k2").get === "v2") + } + + /** Test basic operations of a Scala mutable Map. */ + def testMap(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val testMap1 = newMap() + val testMap2 = newMap() + val name = testMap1.getClass.getSimpleName + + test(name + " - basic test") { + // put, get, and apply + testMap1 += (("k1", "v1")) + assert(testMap1.get("k1").isDefined) + assert(testMap1.get("k1").get === "v1") + testMap1("k2") = "v2" + assert(testMap1.get("k2").isDefined) + assert(testMap1.get("k2").get === "v2") + assert(testMap1("k2") === "v2") + testMap1.update("k3", "v3") + assert(testMap1.get("k3").isDefined) + assert(testMap1.get("k3").get === "v3") + + // remove + testMap1.remove("k1") + assert(testMap1.get("k1").isEmpty) + testMap1.remove("k2") + intercept[NoSuchElementException] { + testMap1("k2") // Map.apply() causes exception + } + testMap1 -= "k3" + assert(testMap1.get("k3").isEmpty) + + // multi put + val keys = (1 to 100).map(_.toString) + val pairs = keys.map(x => (x, x * 2)) + assert((testMap2 ++ pairs).iterator.toSet === pairs.toSet) + testMap2 ++= pairs + + // iterator + assert(testMap2.iterator.toSet === pairs.toSet) + + // filter + val filtered = testMap2.filter { case (_, v) => v.toInt % 2 == 0 } + val evenPairs = pairs.filter { case (_, v) => v.toInt % 2 == 0 } + assert(filtered.iterator.toSet === evenPairs.toSet) + + // foreach + val buffer = new ArrayBuffer[(String, String)] + testMap2.foreach(x => buffer += x) + assert(testMap2.toSet === buffer.toSet) + + // multi remove + testMap2("k1") = "v1" + testMap2 --= keys + assert(testMap2.size === 1) + assert(testMap2.iterator.toSeq.head === ("k1", "v1")) + + // + + val testMap3 = testMap2 + (("k0", "v0")) + assert(testMap3.size === 2) + assert(testMap3.get("k1").isDefined) + assert(testMap3.get("k1").get === "v1") + assert(testMap3.get("k0").isDefined) + assert(testMap3.get("k0").get === "v0") + + // - + val testMap4 = testMap3 - "k0" + assert(testMap4.size === 1) + assert(testMap4.get("k1").isDefined) + assert(testMap4.get("k1").get === "v1") + } + } + + /** Test thread safety of a Scala mutable map. */ + def testMapThreadSafety(hashMapConstructor: => mutable.Map[String, String]) { + def newMap() = hashMapConstructor + val name = newMap().getClass.getSimpleName + val testMap = newMap() + @volatile var error = false + + def getRandomKey(m: mutable.Map[String, String]): Option[String] = { + val keys = testMap.keysIterator.toSeq + if (keys.nonEmpty) { + Some(keys(Random.nextInt(keys.size))) + } else { + None + } + } + + val threads = (1 to 25).map(i => new Thread() { + override def run() { + try { + for (j <- 1 to 1000) { + Random.nextInt(3) match { + case 0 => + testMap(Random.nextString(10)) = Random.nextDouble().toString // put + case 1 => + getRandomKey(testMap).map(testMap.get) // get + case 2 => + getRandomKey(testMap).map(testMap.remove) // remove + } + } + } catch { + case t: Throwable => + error = true + throw t + } + } + }) + + test(name + " - threading safety test") { + threads.map(_.start) + threads.map(_.join) + assert(!error) + } + } +}