diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
index 4e26e2c2f5e..587bffe7ebc 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/HostAlloc.scala
@@ -284,7 +284,7 @@ private class HostAlloc(nonPinnedLimit: Long) {
   private def canNeverSucceed(amount: Long, preferPinned: Boolean): Boolean = {
     val pinnedFailed = (isPinnedOnly || preferPinned) && (amount > pinnedLimit)
     val nonPinnedFailed = isPinnedOnly || (amount > nonPinnedLimit)
-    pinnedFailed && nonPinnedFailed
+    !isUnlimited && pinnedFailed && nonPinnedFailed
   }
 
   private def checkSize(amount: Long, preferPinned: Boolean): Unit = {
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
index 28adb9301f7..508b869f0bd 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBuffer.scala
@@ -206,14 +206,11 @@ class RapidsBufferCopyIterator(buffer: RapidsBuffer)
   }
 
   override def close(): Unit = {
-    val hasNextBeforeClose = hasNext
     val toClose = new ArrayBuffer[AutoCloseable]()
     toClose.appendAll(chunkedPacker)
     toClose.appendAll(Option(singleShotBuffer))
     toClose.safeClose()
-    require(!hasNextBeforeClose,
-      "RapidsBufferCopyIterator was closed before exhausting")
   }
 }
 
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
index 1b49374f9fb..61a636c1708 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsBufferCatalog.scala
@@ -775,11 +775,14 @@ object RapidsBufferCatalog extends Logging {
       rapidsConf.chunkedPackBounceBufferSize,
       rapidsConf.spillToDiskBounceBufferSize)
     diskBlockManager = new RapidsDiskBlockManager(conf)
-    val hostSpillStorageSize = if (rapidsConf.hostSpillStorageSize == -1) {
+    val hostSpillStorageSize = if (rapidsConf.offHeapLimitEnabled) {
+      // Disable the limit because it is handled by the RapidsHostMemoryStore
+      None
+    } else if (rapidsConf.hostSpillStorageSize == -1) {
       // + 1 GiB by default to match backwards compatibility
-      rapidsConf.pinnedPoolSize + (1024 * 1024 * 1024)
+      Some(rapidsConf.pinnedPoolSize + (1024 * 1024 * 1024))
     } else {
-      rapidsConf.hostSpillStorageSize
+      Some(rapidsConf.hostSpillStorageSize)
     }
     hostStorage = new RapidsHostMemoryStore(hostSpillStorageSize)
     diskStorage = new RapidsDiskStore(diskBlockManager)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
index d61f6061116..63f1b723ff1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsDiskStore.scala
@@ -21,7 +21,7 @@ import java.nio.channels.FileChannel.MapMode
 import java.util.concurrent.ConcurrentHashMap
 
 import ai.rapids.cudf.{Cuda, HostMemoryBuffer, MemoryBuffer}
-import com.nvidia.spark.rapids.Arm.{closeOnExcept, withResource}
+import com.nvidia.spark.rapids.Arm.withResource
 import com.nvidia.spark.rapids.StorageTier.StorageTier
 import com.nvidia.spark.rapids.format.TableMeta
 
@@ -192,7 +192,7 @@ class RapidsDiskStore(diskBlockManager: RapidsDiskBlockManager)
     val path = id.getDiskPath(diskBlockManager)
     withResource(new FileInputStream(path)) { fis =>
       val (header, hostBuffer) = SerializedHostTableUtils.readTableHeaderAndBuffer(fis)
-      val hostCols = closeOnExcept(hostBuffer) { _ =>
+      val hostCols = withResource(hostBuffer) { _ =>
         SerializedHostTableUtils.buildHostColumns(header, hostBuffer, sparkTypes)
       }
       new ColumnarBatch(hostCols.toArray, header.getNumRows)
diff --git a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala
index dbdbb38f13c..05958a7e4b1 100644
--- a/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala
+++ b/sql-plugin/src/main/scala/com/nvidia/spark/rapids/RapidsHostMemoryStore.scala
@@ -22,7 +22,7 @@ import java.util.concurrent.ConcurrentHashMap
 
 import scala.collection.mutable
 
-import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, MemoryBuffer, NvtxColor, NvtxRange, PinnedMemoryPool}
+import ai.rapids.cudf.{Cuda, DeviceMemoryBuffer, HostColumnVector, HostMemoryBuffer, JCudfSerialization, MemoryBuffer, NvtxColor, NvtxRange}
 import com.nvidia.spark.rapids.Arm.{closeOnExcept, freeOnExcept, withResource}
 import com.nvidia.spark.rapids.SpillPriorities.{applyPriorityOffset, HOST_MEMORY_BUFFER_SPILL_OFFSET}
 import com.nvidia.spark.rapids.StorageTier.StorageTier
@@ -36,29 +36,14 @@ import org.apache.spark.sql.vectorized.ColumnarBatch
 /**
  * A buffer store using host memory.
  * @param maxSize maximum size in bytes for all buffers in this store
- * @param pageableMemoryPoolSize maximum size in bytes for the internal pageable memory pool
  */
 class RapidsHostMemoryStore(
-    maxSize: Long)
+    maxSize: Option[Long])
   extends RapidsBufferStore(StorageTier.HOST) {
 
   override protected def spillableOnAdd: Boolean = false
 
-  override def getMaxSize: Option[Long] = Some(maxSize)
-
-  private def allocateHostBuffer(
-      size: Long,
-      preferPinned: Boolean = true): HostMemoryBuffer = {
-    var buffer: HostMemoryBuffer = null
-    if (preferPinned) {
-      buffer = PinnedMemoryPool.tryAllocate(size)
-      if (buffer != null) {
-        return buffer
-      }
-    }
-
-    HostMemoryBuffer.allocate(size, false)
-  }
+  override def getMaxSize: Option[Long] = maxSize
 
   def addBuffer(
       id: RapidsBufferId,
@@ -102,21 +87,23 @@ class RapidsHostMemoryStore(
       buffer: RapidsBuffer,
       catalog: RapidsBufferCatalog,
       stream: Cuda.Stream): Boolean = {
-    // this spillStore has a maximum size requirement (host only). We need to spill from it
-    // in order to make room for `buffer`.
-    val targetTotalSize = maxSize - buffer.memoryUsedBytes
-    if (targetTotalSize <= 0) {
-      // lets not spill to host when the buffer we are about
-      // to spill is larger than our limit
-      false
-    } else {
-      val amountSpilled = synchronousSpill(targetTotalSize, catalog, stream)
-      if (amountSpilled != 0) {
-        logDebug(s"Spilled $amountSpilled bytes from ${name} to make room for ${buffer.id}")
-        TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
+    maxSize.forall { ms =>
+      // this spillStore has a maximum size requirement (host only). We need to spill from it
+      // in order to make room for `buffer`.
+      val targetTotalSize = ms - buffer.memoryUsedBytes
+      if (targetTotalSize < 0) {
+        // lets not spill to host when the buffer we are about
+        // to spill is larger than our limit
+        false
+      } else {
+        val amountSpilled = synchronousSpill(targetTotalSize, catalog, stream)
+        if (amountSpilled != 0) {
+          logDebug(s"Spilled $amountSpilled bytes from ${name} to make room for ${buffer.id}")
+          TrampolineUtil.incTaskMetricsDiskBytesSpilled(amountSpilled)
+        }
+        // if after spill we can fit the new buffer, return true
+        buffer.memoryUsedBytes <= (ms - currentSize)
       }
-      // if after spill we can fit the new buffer, return true
-      buffer.memoryUsedBytes <= (maxSize - currentSize)
     }
   }
 
@@ -125,53 +112,58 @@ class RapidsHostMemoryStore(
       catalog: RapidsBufferCatalog,
      stream: Cuda.Stream): Option[RapidsBufferBase] = {
     val wouldFit = trySpillToMaximumSize(other, catalog, stream)
-    // TODO: this is disabled for now since subsequent work will tie this into
-    // our host allocator apis.
-    if (false && !wouldFit) {
+    if (!wouldFit) {
       // skip host
-      logWarning(s"Buffer ${other} with size ${other.memoryUsedBytes} does not fit " +
+      logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " +
          s"in the host store, skipping tier.")
       None
     } else {
       withResource(other.getCopyIterator) { otherBufferIterator =>
         val isChunked = otherBufferIterator.isChunked
         val totalCopySize = otherBufferIterator.getTotalCopySize
-        closeOnExcept(allocateHostBuffer(totalCopySize)) { hostBuffer =>
-          withResource(new NvtxRange("spill to host", NvtxColor.BLUE)) { _ =>
-            var hostOffset = 0L
-            val start = System.nanoTime()
-            while (otherBufferIterator.hasNext) {
-              val otherBuffer = otherBufferIterator.next()
-              withResource(otherBuffer) { _ =>
-                otherBuffer match {
-                  case devBuffer: DeviceMemoryBuffer =>
-                    hostBuffer.copyFromMemoryBufferAsync(
-                      hostOffset, devBuffer, 0, otherBuffer.getLength, stream)
-                    hostOffset += otherBuffer.getLength
-                  case _ =>
-                    throw new IllegalStateException("copying from buffer without device memory")
+        closeOnExcept(HostAlloc.allocHighPriority(totalCopySize)) { hb =>
+          hb.map { hostBuffer =>
+            withResource(new NvtxRange("spill to host", NvtxColor.BLUE)) { _ =>
+              var hostOffset = 0L
+              val start = System.nanoTime()
+              while (otherBufferIterator.hasNext) {
+                val otherBuffer = otherBufferIterator.next()
+                withResource(otherBuffer) { _ =>
+                  otherBuffer match {
+                    case devBuffer: DeviceMemoryBuffer =>
+                      hostBuffer.copyFromMemoryBufferAsync(
+                        hostOffset, devBuffer, 0, otherBuffer.getLength, stream)
+                      hostOffset += otherBuffer.getLength
+                    case _ =>
+                      throw new IllegalStateException("copying from buffer without device memory")
+                  }
                 }
              }
+              stream.sync()
+              val end = System.nanoTime()
+              val szMB = (totalCopySize.toDouble / 1024.0 / 1024.0).toLong
+              val bw = (szMB.toDouble / ((end - start).toDouble / 1000000000.0)).toLong
+              logDebug(s"Spill to host (chunked=$isChunked) " +
+                  s"size=$szMB MiB bandwidth=$bw MiB/sec")
             }
-            stream.sync()
-            val end = System.nanoTime()
-            val szMB = (totalCopySize.toDouble / 1024.0 / 1024.0).toLong
-            val bw = (szMB.toDouble / ((end - start).toDouble / 1000000000.0)).toLong
-            logDebug(s"Spill to host (chunked=$isChunked) " +
-              s"size=$szMB MiB bandwidth=$bw MiB/sec")
+            new RapidsHostMemoryBuffer(
+              other.id,
+              totalCopySize,
+              other.meta,
+              applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET),
+              hostBuffer)
+          }.orElse {
+            // skip host
+            logWarning(s"Buffer $other with size ${other.memoryUsedBytes} does not fit " +
+              s"in the host store, skipping tier.")
+            None
          }
-          Some(new RapidsHostMemoryBuffer(
-            other.id,
-            totalCopySize,
-            other.meta,
-            applyPriorityOffset(other.getSpillPriority, HOST_MEMORY_BUFFER_SPILL_OFFSET),
-            hostBuffer))
        }
      }
    }
  }
 
-  def numBytesFree: Long = maxSize - currentSize
+  def numBytesFree: Option[Long] = maxSize.map(_ - currentSize)
 
   class RapidsHostMemoryBuffer(
       id: RapidsBufferId,
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala
index 3666b85458e..61940ffd463 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsBufferCatalogSuite.scala
@@ -215,7 +215,7 @@ class RapidsBufferCatalogSuite extends AnyFunSuite with MockitoSugar {
     withResource(spy(new RapidsDeviceMemoryStore)) { deviceStore =>
       val mockStore = mock[RapidsBufferStore]
       withResource(
-        new RapidsHostMemoryStore(10000)) { hostStore =>
+        new RapidsHostMemoryStore(Some(10000))) { hostStore =>
         deviceStore.setSpillStore(hostStore)
         hostStore.setSpillStore(mockStore)
         val catalog = new RapidsBufferCatalog(deviceStore)
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala
index 6adcbcc1909..b911bccb286 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsDiskStoreSuite.scala
@@ -62,7 +62,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = spy(new RapidsBufferCatalog(devStore))
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) {
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) {
         hostStore =>
           devStore.setSpillStore(hostStore)
           withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore =>
@@ -102,7 +102,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = new RapidsBufferCatalog(devStore)
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) {
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) {
         hostStore =>
           devStore.setSpillStore(hostStore)
           withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) {
@@ -144,7 +144,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = new RapidsBufferCatalog(devStore)
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) {
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) {
         hostStore =>
           devStore.setSpillStore(hostStore)
           withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore =>
@@ -288,7 +288,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = new RapidsBufferCatalog(devStore)
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
         devStore.setSpillStore(hostStore)
         withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore =>
           hostStore.setSpillStore(diskStore)
@@ -340,7 +340,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
     testBufferFileDeletion(canShareDiskPaths = true)
   }
 
-  class AlwaysFailingRapidsHostMemoryStore extends RapidsHostMemoryStore(0L){
+  class AlwaysFailingRapidsHostMemoryStore extends RapidsHostMemoryStore(Some(0L)){
     override def createBuffer(
         other: RapidsBuffer,
         catalog: RapidsBufferCatalog,
         stream: Cuda.Stream):
@@ -357,7 +357,7 @@ class RapidsDiskStoreSuite extends FunSuiteWithTempDir with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = new RapidsBufferCatalog(devStore)
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) {
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) {
         hostStore =>
           devStore.setSpillStore(hostStore)
           withResource(new RapidsDiskStore(mock[RapidsDiskBlockManager])) { diskStore =>
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala
index 2d028f0cf7b..153b8da6556 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RapidsHostMemoryStoreSuite.scala
@@ -22,8 +22,7 @@ import java.math.RoundingMode
 
 import ai.rapids.cudf.{ContiguousTable, Cuda, HostColumnVector, HostMemoryBuffer, Table}
 import com.nvidia.spark.rapids.Arm._
 import org.mockito.{ArgumentCaptor, ArgumentMatchers}
-import org.mockito.ArgumentMatchers.any
-import org.mockito.Mockito.{never, spy, times, verify, when}
+import org.mockito.Mockito.{spy, times, verify}
 import org.scalatest.funsuite.AnyFunSuite
 import org.scalatestplus.mockito.MockitoSugar
@@ -92,10 +91,10 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val mockStore = mock[RapidsHostMemoryStore]
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = spy(new RapidsBufferCatalog(devStore))
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) {
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) {
         hostStore =>
           assertResult(0)(hostStore.currentSize)
-          assertResult(hostStoreMaxSize)(hostStore.numBytesFree)
+          assertResult(hostStoreMaxSize)(hostStore.numBytesFree.get)
           devStore.setSpillStore(hostStore)
           hostStore.setSpillStore(mockStore)
 
@@ -110,7 +109,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
 
          catalog.synchronousSpill(devStore, 0)
          assertResult(bufferSize)(hostStore.currentSize)
-          assertResult(hostStoreMaxSize - bufferSize)(hostStore.numBytesFree)
+          assertResult(hostStoreMaxSize - bufferSize)(hostStore.numBytesFree.get)
          verify(catalog, times(2)).registerNewBuffer(ArgumentMatchers.any[RapidsBuffer])
          verify(catalog).removeBufferTier(
            ArgumentMatchers.eq(handle.id), ArgumentMatchers.eq(StorageTier.DEVICE))
@@ -130,7 +129,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val mockStore = mock[RapidsHostMemoryStore]
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = new RapidsBufferCatalog(devStore)
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) {
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) {
         hostStore =>
           devStore.setSpillStore(hostStore)
           hostStore.setSpillStore(mockStore)
@@ -165,7 +164,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val mockStore = mock[RapidsHostMemoryStore]
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = new RapidsBufferCatalog(devStore)
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) {
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) {
         hostStore =>
           devStore.setSpillStore(hostStore)
           hostStore.setSpillStore(mockStore)
@@ -204,7 +203,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val (catalog, devStore, hostStore, diskStore) =
       closeOnExcept(new RapidsDiskStore(bm)) { diskStore =>
         closeOnExcept(new RapidsDeviceMemoryStore()) { devStore =>
-          closeOnExcept(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+          closeOnExcept(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
             devStore.setSpillStore(hostStore)
             hostStore.setSpillStore(diskStore)
             val catalog = closeOnExcept(
@@ -257,7 +256,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val spillPriority = -10
     val hostStoreMaxSize = 1L * 1024 * 1024
     val mockStore = mock[RapidsDiskStore]
-    withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+    withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
       withResource(new RapidsDeviceMemoryStore) { devStore =>
         val catalog = new RapidsBufferCatalog(devStore, hostStore)
         devStore.setSpillStore(hostStore)
@@ -283,7 +282,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     val bm = new RapidsDiskBlockManager(new SparkConf())
     withResource(new RapidsDiskStore(bm)) { diskStore =>
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
         withResource(new RapidsDeviceMemoryStore) { devStore =>
           val catalog = new RapidsBufferCatalog(devStore, hostStore)
           devStore.setSpillStore(hostStore)
@@ -314,7 +313,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     val bm = new RapidsDiskBlockManager(new SparkConf())
     withResource(new RapidsDiskStore(bm)) { diskStore =>
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
         withResource(new RapidsDeviceMemoryStore) { devStore =>
           val catalog = new RapidsBufferCatalog(devStore, hostStore)
           devStore.setSpillStore(hostStore)
@@ -347,7 +346,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     val bm = new RapidsDiskBlockManager(new SparkConf())
     withResource(new RapidsDiskStore(bm)) { diskStore =>
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
         withResource(new RapidsDeviceMemoryStore) { devStore =>
           val catalog = new RapidsBufferCatalog(devStore, hostStore)
           devStore.setSpillStore(hostStore)
@@ -395,7 +394,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     val bm = new RapidsDiskBlockManager(new SparkConf())
     withResource(new RapidsDiskStore(bm)) { diskStore =>
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
        withResource(new RapidsDeviceMemoryStore) { devStore =>
          val catalog = new RapidsBufferCatalog(devStore, hostStore)
          devStore.setSpillStore(hostStore)
@@ -437,7 +436,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     val bm = new RapidsDiskBlockManager(new SparkConf())
     withResource(new RapidsDiskStore(bm)) { diskStore =>
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
        withResource(new RapidsDeviceMemoryStore) { devStore =>
          val catalog = new RapidsBufferCatalog(devStore, hostStore)
          devStore.setSpillStore(hostStore)
@@ -480,7 +479,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     val bm = new RapidsDiskBlockManager(new SparkConf())
     withResource(new RapidsDiskStore(bm)) { diskStore =>
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
        withResource(new RapidsDeviceMemoryStore) { devStore =>
          val catalog = new RapidsBufferCatalog(devStore, hostStore)
          val hostBatch = buildHostBatch()
@@ -514,7 +513,7 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 1L * 1024 * 1024
     val bm = new RapidsDiskBlockManager(new SparkConf())
     withResource(new RapidsDiskStore(bm)) { diskStore =>
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
        withResource(new RapidsDeviceMemoryStore) { devStore =>
          val catalog = new RapidsBufferCatalog(devStore, hostStore)
          val hostBatch = buildHostBatchWithDuplicate()
@@ -549,62 +548,46 @@ class RapidsHostMemoryStoreSuite extends AnyFunSuite with MockitoSugar {
     val hostStoreMaxSize = 256
     withResource(new RapidsDeviceMemoryStore) { devStore =>
       val catalog = new RapidsBufferCatalog(devStore)
-      val mockStore = mock[RapidsBufferStore]
-      val mockBuff = mock[mockStore.RapidsBufferBase]
-      when(mockBuff.id).thenReturn(new RapidsBufferId {
-        override val tableId: Int = 0
-        override def getDiskPath(diskBlockManager: RapidsDiskBlockManager): File = null
-      })
-      when(mockStore.getMaxSize).thenAnswer(_ => None)
-      when(mockStore.copyBuffer(any(), any(), any())).thenReturn(Some(mockBuff))
-      when(mockStore.tier) thenReturn (StorageTier.DISK)
-      withResource(new RapidsHostMemoryStore(hostStoreMaxSize)) { hostStore =>
+      val spyStore = spy(new RapidsDiskStore(new RapidsDiskBlockManager(new SparkConf())))
+      withResource(new RapidsHostMemoryStore(Some(hostStoreMaxSize))) { hostStore =>
         devStore.setSpillStore(hostStore)
-        hostStore.setSpillStore(mockStore)
+        hostStore.setSpillStore(spyStore)
         var bigHandle: RapidsBufferHandle = null
         var bigTable = buildContiguousTable(1024 * 1024)
-        var smallTable = buildContiguousTable(1)
         closeOnExcept(bigTable) { _ =>
-          closeOnExcept(smallTable) { _ =>
-            // make a copy of the table so we can compare it later to the
-            // one reconstituted after the spill
-            val expectedBatch =
-              withResource(bigTable.getTable.contiguousSplit()) { expectedTable =>
-                GpuColumnVector.from(expectedTable(0).getTable, sparkTypes)
-              }
-            withResource(expectedBatch) { _ =>
-              bigHandle = withResource(bigTable) { _ =>
-                catalog.addContiguousTable(
-                  bigTable,
-                  spillPriority)
-              } // close the bigTable so it can be spilled
-              bigTable = null
-              catalog.synchronousSpill(devStore, 0)
-              verify(mockStore, never()).copyBuffer(
-                ArgumentMatchers.any[RapidsBuffer],
-                ArgumentMatchers.any[RapidsBufferCatalog],
-                ArgumentMatchers.any[Cuda.Stream])
-              withResource(catalog.acquireBuffer(bigHandle)) { buffer =>
-                assertResult(StorageTier.HOST)(buffer.storageTier)
-                withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch =>
-                  TestUtils.compareBatches(expectedBatch, actualBatch)
-                }
-              }
+          // make a copy of the table so we can compare it later to the
+          // one reconstituted after the spill
+          val expectedBatch =
+            withResource(bigTable.getTable.contiguousSplit()) { expectedTable =>
+              GpuColumnVector.from(expectedTable(0).getTable, sparkTypes)
            }
-            withResource(smallTable) { _ =>
+          withResource(expectedBatch) { _ =>
+            bigHandle = withResource(bigTable) { _ =>
              catalog.addContiguousTable(
-                smallTable, spillPriority,
-                false)
-            } // close the smallTable so it can be spilled
-            smallTable = null
+                bigTable,
+                spillPriority)
+            } // close the bigTable so it can be spilled
+            bigTable = null
+            withResource(catalog.acquireBuffer(bigHandle)) { buffer =>
+              assertResult(StorageTier.DEVICE)(buffer.storageTier)
+              withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch =>
+                TestUtils.compareBatches(expectedBatch, actualBatch)
+              }
+            }
            catalog.synchronousSpill(devStore, 0)
            val rapidsBufferCaptor: ArgumentCaptor[RapidsBuffer] =
              ArgumentCaptor.forClass(classOf[RapidsBuffer])
-            verify(mockStore).copyBuffer(
+            verify(spyStore).copyBuffer(
              rapidsBufferCaptor.capture(),
              ArgumentMatchers.any[RapidsBufferCatalog],
              ArgumentMatchers.any[Cuda.Stream])
            assertResult(bigHandle.id)(rapidsBufferCaptor.getValue.id)
+            withResource(catalog.acquireBuffer(bigHandle)) { buffer =>
+              assertResult(StorageTier.DISK)(buffer.storageTier)
+              withResource(buffer.getColumnarBatch(sparkTypes)) { actualBatch =>
+                TestUtils.compareBatches(expectedBatch, actualBatch)
+              }
+            }
          }
        }
      }
diff --git a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala
index d0db74d2e9d..a8e0ad550ea 100644
--- a/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala
+++ b/tests/src/test/scala/com/nvidia/spark/rapids/RmmSparkRetrySuiteBase.scala
@@ -37,7 +37,7 @@ class RmmSparkRetrySuiteBase extends AnyFunSuite with BeforeAndAfterEach {
       Rmm.initialize(RmmAllocationMode.CUDA_DEFAULT, null, 512 * 1024 * 1024)
     }
     deviceStorage = spy(new RapidsDeviceMemoryStore())
-    val hostStore = new RapidsHostMemoryStore(1L * 1024 * 1024)
+    val hostStore = new RapidsHostMemoryStore(Some(1L * 1024 * 1024))
     deviceStorage.setSpillStore(hostStore)
     val catalog = new RapidsBufferCatalog(deviceStorage, hostStore)
     // set these against the singleton so we close them later
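
Usage note (a minimal sketch, not part of the patch): the wiring below mirrors the test setup above and shows how the Option-typed host limit behaves after this change; the 1 MiB figure is only an illustrative value.

    import com.nvidia.spark.rapids.{RapidsBufferCatalog, RapidsDeviceMemoryStore, RapidsHostMemoryStore}

    // A bounded host store: trySpillToMaximumSize spills to the next tier to stay under 1 MiB.
    val boundedHostStore = new RapidsHostMemoryStore(Some(1L * 1024 * 1024))
    // An unbounded host store: passed when rapidsConf.offHeapLimitEnabled is set, since the
    // limit is then enforced by HostAlloc rather than by the store (see RapidsBufferCatalog above).
    val unboundedHostStore = new RapidsHostMemoryStore(None)

    // Same tier wiring as the suites above: device spills into the bounded host store.
    val devStore = new RapidsDeviceMemoryStore()
    val catalog = new RapidsBufferCatalog(devStore, boundedHostStore)
    devStore.setSpillStore(boundedHostStore)

    // numBytesFree is now Option[Long]; None means there is no host-side limit to track.
    assert(boundedHostStore.numBytesFree.contains(1L * 1024 * 1024))
    assert(unboundedHostStore.numBytesFree.isEmpty)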