From 7a2698fc407c09ad098f7c04ed7748e209310f67 Mon Sep 17 00:00:00 2001 From: Danilo Burbano Date: Wed, 27 Nov 2024 15:38:05 -0500 Subject: [PATCH] [SPARKNLP-1096] Adding support to Microsoft Fabric for WordEmbeddings storage index --- .../client/util/CloudHelper.scala | 16 +++- .../storage/RocksDBConnection.scala | 6 +- .../johnsnowlabs/storage/StorageHelper.scala | 27 +++++-- .../johnsnowlabs/storage/StorageLocator.scala | 46 +++++++---- .../scala/com/johnsnowlabs/util/Version.scala | 3 +- .../johnsnowlabs/nlp/util/VersionTest.scala | 80 +++++++++++++++++-- 6 files changed, 147 insertions(+), 31 deletions(-) diff --git a/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala b/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala index ece74507b86003..8ab06e13e08c2e 100644 --- a/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala +++ b/src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala @@ -15,8 +15,8 @@ */ package com.johnsnowlabs.client.util -import com.johnsnowlabs.nlp.util.io.CloudStorageType import com.johnsnowlabs.nlp.util.io.CloudStorageType.CloudStorageType +import com.johnsnowlabs.nlp.util.io.{CloudStorageType, ResourceHelper} import java.net.{URI, URL} @@ -71,7 +71,8 @@ object CloudHelper { } def isCloudPath(uri: String): Boolean = { - isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri) + val intraCloudPath = isIntraCloudPath(uri) + (isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)) && !intraCloudPath } def isS3Path(uri: String): Boolean = { @@ -81,7 +82,16 @@ object CloudHelper { private def isGCPStoragePath(uri: String): Boolean = uri.startsWith("gs://") private def isAzureBlobPath(uri: String): Boolean = { - uri.startsWith("https://") && uri.contains(".blob.core.windows.net/") + (uri.startsWith("https://") && uri.contains(".blob.core.windows.net/")) || uri.startsWith( + "abfss://") + } + + private def isIntraCloudPath(uri: String): Boolean = { + uri.startsWith("abfss://") && isMicrosoftFabric + } + + def isMicrosoftFabric: Boolean = { + ResourceHelper.spark.conf.getAll.keys.exists(_.startsWith("spark.fabric")) } def cloudType(uri: String): CloudStorageType = { diff --git a/src/main/scala/com/johnsnowlabs/storage/RocksDBConnection.scala b/src/main/scala/com/johnsnowlabs/storage/RocksDBConnection.scala index 412a2b377134fc..010be6d2da11b9 100644 --- a/src/main/scala/com/johnsnowlabs/storage/RocksDBConnection.scala +++ b/src/main/scala/com/johnsnowlabs/storage/RocksDBConnection.scala @@ -43,9 +43,11 @@ final class RocksDBConnection private (path: String) extends AutoCloseable { def findLocalIndex: String = { val tmpIndexStorageLocalPath = RocksDBConnection.getTmpIndexStorageLocalPath(path) - if (new File(tmpIndexStorageLocalPath).exists()) { + val tmpIndexStorageLocalPathExists = new File(tmpIndexStorageLocalPath).exists() + val pathExist = new File(path.stripPrefix("file:")).exists() + if (tmpIndexStorageLocalPathExists) { tmpIndexStorageLocalPath - } else if (new File(path).exists()) { + } else if (pathExist) { path } else { val localFromClusterPath = SparkFiles.get(path) diff --git a/src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala b/src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala index 3d40733637c18d..453ea4ed6bbda2 100644 --- a/src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala +++ b/src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala @@ -17,6 +17,7 @@ package com.johnsnowlabs.storage import com.johnsnowlabs.client.CloudResources +import com.johnsnowlabs.client.util.CloudHelper import org.apache.hadoop.fs.{FileSystem, FileUtil, Path} import org.apache.spark.sql.SparkSession import org.apache.spark.{SparkContext, SparkFiles} @@ -34,7 +35,6 @@ object StorageHelper { database: String, storageRef: String, withinStorage: Boolean): RocksDBConnection = { - val dbFolder = StorageHelper.resolveStorageName(database, storageRef) val source = StorageLocator.getStorageSerializedPath( storageSourcePath.replaceAllLiterally("\\", "/"), @@ -49,7 +49,11 @@ object StorageHelper { locator.destinationScheme, spark.sparkContext) - RocksDBConnection.getOrCreate(locator.clusterFileName) + val storagePath = if (locator.clusterFilePath.toString.startsWith("file:")) { + locator.clusterFilePath.toString + } else locator.clusterFileName + + RocksDBConnection.getOrCreate(storagePath) } def save( @@ -96,9 +100,19 @@ object StorageHelper { } case "s3a" => copyIndexToLocal(source, new Path(tmpIndexStorageLocalPath), sparkContext) - case _ => copyIndexToCluster(source, clusterFilePath, sparkContext) + case _ => { + copyIndexToCluster(source, clusterFilePath, sparkContext) + } } } + case "abfss" => + if (clusterFilePath.toString.startsWith("file:")) { + val tmpIndexStorageLocalPath = + RocksDBConnection.getTmpIndexStorageLocalPath(clusterFileName) + copyIndexToCluster(source, new Path("file://" + tmpIndexStorageLocalPath), sparkContext) + } else { + copyIndexToLocal(source, clusterFilePath, sparkContext) + } case _ => { copyIndexToCluster(source, clusterFilePath, sparkContext) } @@ -120,7 +134,8 @@ object StorageHelper { sourcePath: Path, dst: Path, sparkContext: SparkContext): String = { - if (!new File(SparkFiles.get(dst.getName)).exists()) { + val destinationInSpark = new File(SparkFiles.get(dst.getName)).exists() + if (!destinationInSpark) { val srcFS = sourcePath.getFileSystem(sparkContext.hadoopConfiguration) val dstFS = dst.getFileSystem(sparkContext.hadoopConfiguration) @@ -138,7 +153,9 @@ object StorageHelper { sparkContext.hadoopConfiguration) } - sparkContext.addFile(dst.toString, recursive = true) + if (!CloudHelper.isMicrosoftFabric) { + sparkContext.addFile(dst.toString, recursive = true) + } } dst.toString } diff --git a/src/main/scala/com/johnsnowlabs/storage/StorageLocator.scala b/src/main/scala/com/johnsnowlabs/storage/StorageLocator.scala index e651bac31b524f..fcf2de7150abe7 100644 --- a/src/main/scala/com/johnsnowlabs/storage/StorageLocator.scala +++ b/src/main/scala/com/johnsnowlabs/storage/StorageLocator.scala @@ -29,29 +29,47 @@ case class StorageLocator(database: String, storageRef: String, sparkSession: Sp if (tmpLocation.matches("s3[a]?:/.*")) { tmpLocation } else { - val tmpLocationPath = new Path(tmpLocation) - fileSystem.mkdirs(tmpLocationPath) - fileSystem.deleteOnExit(tmpLocationPath) - tmpLocation + fileSystem.getScheme match { + case "abfss" => + if (tmpLocation.startsWith("abfss:")) { + tmpLocation + } else { + "file:///" + tmpLocation + } + case _ => + val tmpLocationPath = new Path(tmpLocation) + fileSystem.mkdirs(tmpLocationPath) + fileSystem.deleteOnExit(tmpLocationPath) + tmpLocation + } } } - val clusterFileName: String = { - StorageHelper.resolveStorageName(database, storageRef) - } + val clusterFileName: String = { StorageHelper.resolveStorageName(database, storageRef) } val clusterFilePath: Path = { if (!getTmpLocation.matches("s3[a]?:/.*")) { val scheme = Option(new Path(clusterTmpLocation).toUri.getScheme).getOrElse("") scheme match { - case "dbfs" | "hdfs" => - Path.mergePaths(new Path(clusterTmpLocation), new Path("/" + clusterFileName)) - case _ => - Path.mergePaths( - new Path(fileSystem.getUri.toString + clusterTmpLocation), - new Path("/" + clusterFileName)) + case "dbfs" | "hdfs" => mergePaths() + case "file" => + val uri = fileSystem.getUri.toString + if (uri.startsWith("abfss:")) { mergePaths() } + else { mergePaths(withFileSystem = true) } + case "abfss" => mergePaths() + case _ => mergePaths(withFileSystem = true) } - } else new Path(clusterTmpLocation + "/" + clusterFileName) + } else { + new Path(clusterTmpLocation + "/" + clusterFileName) + } + } + + private def mergePaths(withFileSystem: Boolean = false): Path = { + if (withFileSystem) { + Path.mergePaths( + new Path(fileSystem.getUri.toString + clusterTmpLocation), + new Path("/" + clusterFileName)) + } else Path.mergePaths(new Path(clusterTmpLocation), new Path("/" + clusterFileName)) } val destinationScheme: String = fileSystem.getScheme diff --git a/src/main/scala/com/johnsnowlabs/util/Version.scala b/src/main/scala/com/johnsnowlabs/util/Version.scala index 7e10c7da8807ed..83a637283a5416 100644 --- a/src/main/scala/com/johnsnowlabs/util/Version.scala +++ b/src/main/scala/com/johnsnowlabs/util/Version.scala @@ -48,9 +48,10 @@ object Version { def parse(str: String): Version = { val parts = str .replaceAll("-rc\\d", "") - .split('.') + .split("[.-]") .takeWhile(p => isInteger(p)) .map(p => p.toInt) + .take(3) .toList Version(parts) diff --git a/src/test/scala/com/johnsnowlabs/nlp/util/VersionTest.scala b/src/test/scala/com/johnsnowlabs/nlp/util/VersionTest.scala index 0e947b4e567a44..e4c16f15c2214a 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/util/VersionTest.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/util/VersionTest.scala @@ -16,13 +16,14 @@ package com.johnsnowlabs.nlp.util +import com.johnsnowlabs.tags.FastTest import com.johnsnowlabs.util.Version import org.junit.Assert.{assertFalse, assertTrue} import org.scalatest.flatspec.AnyFlatSpec class VersionTest extends AnyFlatSpec { - "Version" should "cast to float version of 1 digit" in { + "Version" should "cast to float version of 1 digit" taggedAs FastTest in { val actualVersion1 = Version(1).toFloat val actualVersion15 = Version(15).toFloat @@ -32,7 +33,7 @@ class VersionTest extends AnyFlatSpec { } - it should "cast to float version of 2 digits" in { + it should "cast to float version of 2 digits" taggedAs FastTest in { val actualVersion1_2 = Version(List(1, 2)).toFloat val actualVersion2_7 = Version(List(2, 7)).toFloat @@ -40,7 +41,7 @@ class VersionTest extends AnyFlatSpec { assert(actualVersion2_7 == 2.7f) } - it should "cast to float version of 3 digits" in { + it should "cast to float version of 3 digits" taggedAs FastTest in { val actualVersion1_2_5 = Version(List(1, 2, 5)).toFloat val actualVersion3_2_0 = Version(List(3, 2, 0)).toFloat val actualVersion2_0_6 = Version(List(2, 0, 6)).toFloat @@ -50,13 +51,13 @@ class VersionTest extends AnyFlatSpec { assert(actualVersion2_0_6 == 2.06f) } - it should "raise error when casting to float version > 3 digits" in { + it should "raise error when casting to float version > 3 digits" taggedAs FastTest in { assertThrows[UnsupportedOperationException] { Version(List(3, 0, 2, 5)).toFloat } } - it should "be compatible for latest versions" in { + it should "be compatible for latest versions" taggedAs FastTest in { var currentVersion = Version(List(1, 2, 3)) var modelVersion = Version(List(1, 2)) @@ -80,7 +81,7 @@ class VersionTest extends AnyFlatSpec { } - it should "be not compatible for latest versions" in { + it should "be not compatible for latest versions" taggedAs FastTest in { var currentVersion = Version(List(1, 2)) var modelVersion = Version(List(1, 2, 3)) @@ -103,4 +104,71 @@ class VersionTest extends AnyFlatSpec { assertFalse(isNotCompatible) } + it should "parse a version with fewer than 3 numbers" taggedAs FastTest in { + val someVersion = "3.2" + val expectedVersion = "3.2" + val expectedFloatVersion = 3.2f + val actualVersion = Version.parse(someVersion) + + assert(expectedVersion == actualVersion.toString) + assert(expectedFloatVersion == actualVersion.toFloat) + } + + it should "parse a version with 3 numbers" taggedAs FastTest in { + val someVersion = "3.4.2" + val expectedFloatVersion = 3.42f + val actualVersion = Version.parse(someVersion) + + assert(someVersion == actualVersion.toString) + assert(expectedFloatVersion == actualVersion.toFloat) + } + + it should "truncate a version to 3 digits when it has more than 3 digits" taggedAs FastTest in { + val someVersion = "3.5.1.5.4.20241007.4" + val expectedVersion = "3.5.1" + val expectedFloatVersion = 3.51f + val actualVersion = Version.parse(someVersion) + + assert(expectedVersion == actualVersion.toString) + assert(expectedFloatVersion == actualVersion.toFloat) + } + + it should "handle a version with missing parts" taggedAs FastTest in { + val someVersion = "3" + val expectedVersion = "3" + val expectedFloatVersion = 3.0f + val actualVersion = Version.parse(someVersion) + + assert(expectedVersion == actualVersion.toString) + assert(expectedFloatVersion == actualVersion.toFloat) + } + + it should "handle a version with 3 digits and additional suffix" taggedAs FastTest in { + val someVersion = "3.4.2-beta" + val expectedVersion = "3.4.2" + val expectedFloatVersion = 3.42f + val actualVersion = Version.parse(someVersion) + + assert(expectedVersion == actualVersion.toString) + assert(expectedFloatVersion == actualVersion.toFloat) + } + + it should "raise exception with non-numeric and no valid parts" taggedAs FastTest in { + val someVersion = "alpha.beta.gamma" + + assertThrows[UnsupportedOperationException] { + Version.parse(someVersion).toFloat + } + } + + it should "handle a version with mixed numeric and non-numeric parts" taggedAs FastTest in { + val someVersion = "3.4-alpha.2" + val expectedVersion = "3.4" + val expectedFloatVersion = 3.4f + val actualVersion = Version.parse(someVersion) + + assert(expectedVersion == actualVersion.toString) + assert(expectedFloatVersion == actualVersion.toFloat) + } + }