Skip to content

Commit

Permalink
[SPARKNLP-1096] Adding support to Microsoft Fabric for WordEmbeddings…
Browse files Browse the repository at this point in the history
… storage index (#14467)
  • Loading branch information
danilojsl authored Dec 9, 2024
1 parent aefe88f commit 3596ded
Show file tree
Hide file tree
Showing 6 changed files with 147 additions and 31 deletions.
16 changes: 13 additions & 3 deletions src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/
package com.johnsnowlabs.client.util

import com.johnsnowlabs.nlp.util.io.CloudStorageType
import com.johnsnowlabs.nlp.util.io.CloudStorageType.CloudStorageType
import com.johnsnowlabs.nlp.util.io.{CloudStorageType, ResourceHelper}

import java.net.{URI, URL}

Expand Down Expand Up @@ -71,7 +71,8 @@ object CloudHelper {
}

/** Tells whether `uri` has to be treated as remote cloud storage.
  *
  * A URI qualifies when it matches the S3, GCP Storage or Azure checks performed by the
  * sibling predicates, unless `isIntraCloudPath` reports it as an intra-cloud path.
  *
  * @param uri
  *   location to classify
  * @return
  *   true for cloud URIs that are not intra-cloud paths
  */
def isCloudPath(uri: String): Boolean = {
  val matchesCloudScheme = isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)
  // Intra-cloud paths are excluded even though they match a cloud scheme.
  matchesCloudScheme && !isIntraCloudPath(uri)
}

def isS3Path(uri: String): Boolean = {
Expand All @@ -81,7 +82,16 @@ object CloudHelper {
/** Checks whether `uri` begins with the `gs://` prefix. */
private def isGCPStoragePath(uri: String): Boolean = {
  val gcsPrefix = "gs://"
  uri.startsWith(gcsPrefix)
}

/** Checks whether `uri` points to Azure storage.
  *
  * Two forms are accepted: an `https://` URL that contains `.blob.core.windows.net/`, or any
  * URI that starts with `abfss://`.
  *
  * @param uri
  *   location to classify
  * @return
  *   true when either form matches
  */
private def isAzureBlobPath(uri: String): Boolean = {
  val isBlobEndpoint = uri.startsWith("https://") && uri.contains(".blob.core.windows.net/")
  val isAbfss = uri.startsWith("abfss://")
  isBlobEndpoint || isAbfss
}

/** True when `uri` starts with `abfss://` and `isMicrosoftFabric` holds.
  *
  * The Fabric check is only evaluated for `abfss://` URIs.
  */
private def isIntraCloudPath(uri: String): Boolean = {
  if (uri.startsWith("abfss://")) isMicrosoftFabric else false
}

/** Reports whether any key of the active Spark configuration (read through
  * `ResourceHelper.spark`) starts with `spark.fabric`.
  */
def isMicrosoftFabric: Boolean = {
  val fabricPrefix = "spark.fabric"
  ResourceHelper.spark.conf.getAll.exists { case (key, _) => key.startsWith(fabricPrefix) }
}

def cloudType(uri: String): CloudStorageType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ final class RocksDBConnection private (path: String) extends AutoCloseable {

def findLocalIndex: String = {
val tmpIndexStorageLocalPath = RocksDBConnection.getTmpIndexStorageLocalPath(path)
if (new File(tmpIndexStorageLocalPath).exists()) {
val tmpIndexStorageLocalPathExists = new File(tmpIndexStorageLocalPath).exists()
val pathExist = new File(path.stripPrefix("file:")).exists()
if (tmpIndexStorageLocalPathExists) {
tmpIndexStorageLocalPath
} else if (new File(path).exists()) {
} else if (pathExist) {
path
} else {
val localFromClusterPath = SparkFiles.get(path)
Expand Down
27 changes: 22 additions & 5 deletions src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.johnsnowlabs.storage

import com.johnsnowlabs.client.CloudResources
import com.johnsnowlabs.client.util.CloudHelper
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkContext, SparkFiles}
Expand All @@ -34,7 +35,6 @@ object StorageHelper {
database: String,
storageRef: String,
withinStorage: Boolean): RocksDBConnection = {

val dbFolder = StorageHelper.resolveStorageName(database, storageRef)
val source = StorageLocator.getStorageSerializedPath(
storageSourcePath.replaceAllLiterally("\\", "/"),
Expand All @@ -49,7 +49,11 @@ object StorageHelper {
locator.destinationScheme,
spark.sparkContext)

RocksDBConnection.getOrCreate(locator.clusterFileName)
val storagePath = if (locator.clusterFilePath.toString.startsWith("file:")) {
locator.clusterFilePath.toString
} else locator.clusterFileName

RocksDBConnection.getOrCreate(storagePath)
}

def save(
Expand Down Expand Up @@ -96,9 +100,19 @@ object StorageHelper {
}
case "s3a" =>
copyIndexToLocal(source, new Path(tmpIndexStorageLocalPath), sparkContext)
case _ => copyIndexToCluster(source, clusterFilePath, sparkContext)
case _ => {
copyIndexToCluster(source, clusterFilePath, sparkContext)
}
}
}
case "abfss" =>
if (clusterFilePath.toString.startsWith("file:")) {
val tmpIndexStorageLocalPath =
RocksDBConnection.getTmpIndexStorageLocalPath(clusterFileName)
copyIndexToCluster(source, new Path("file://" + tmpIndexStorageLocalPath), sparkContext)
} else {
copyIndexToLocal(source, clusterFilePath, sparkContext)
}
case _ => {
copyIndexToCluster(source, clusterFilePath, sparkContext)
}
Expand All @@ -120,7 +134,8 @@ object StorageHelper {
sourcePath: Path,
dst: Path,
sparkContext: SparkContext): String = {
if (!new File(SparkFiles.get(dst.getName)).exists()) {
val destinationInSpark = new File(SparkFiles.get(dst.getName)).exists()
if (!destinationInSpark) {
val srcFS = sourcePath.getFileSystem(sparkContext.hadoopConfiguration)
val dstFS = dst.getFileSystem(sparkContext.hadoopConfiguration)

Expand All @@ -138,7 +153,9 @@ object StorageHelper {
sparkContext.hadoopConfiguration)
}

sparkContext.addFile(dst.toString, recursive = true)
if (!CloudHelper.isMicrosoftFabric) {
sparkContext.addFile(dst.toString, recursive = true)
}
}
dst.toString
}
Expand Down
46 changes: 32 additions & 14 deletions src/main/scala/com/johnsnowlabs/storage/StorageLocator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,47 @@ case class StorageLocator(database: String, storageRef: String, sparkSession: Sp
if (tmpLocation.matches("s3[a]?:/.*")) {
tmpLocation
} else {
val tmpLocationPath = new Path(tmpLocation)
fileSystem.mkdirs(tmpLocationPath)
fileSystem.deleteOnExit(tmpLocationPath)
tmpLocation
fileSystem.getScheme match {
case "abfss" =>
if (tmpLocation.startsWith("abfss:")) {
tmpLocation
} else {
"file:///" + tmpLocation
}
case _ =>
val tmpLocationPath = new Path(tmpLocation)
fileSystem.mkdirs(tmpLocationPath)
fileSystem.deleteOnExit(tmpLocationPath)
tmpLocation
}
}
}

/** Storage name resolved by `StorageHelper.resolveStorageName` from `database` and
  * `storageRef`.
  */
val clusterFileName: String = StorageHelper.resolveStorageName(database, storageRef)

/** Path of the storage under `clusterTmpLocation`.
  *
  * When `getTmpLocation` matches `s3[a]?:/.*` the path is built by plain concatenation.
  * Otherwise the scheme of `clusterTmpLocation` decides whether the file system URI is
  * prepended (see `mergePaths`):
  *   - `dbfs`, `hdfs`, `abfss`: never prepended
  *   - `file`: prepended unless the file system URI itself starts with `abfss:`
  *   - anything else (including no scheme): always prepended
  */
val clusterFilePath: Path = {
  val isS3TmpLocation = getTmpLocation.matches("s3[a]?:/.*")
  if (isS3TmpLocation) {
    new Path(clusterTmpLocation + "/" + clusterFileName)
  } else {
    Option(new Path(clusterTmpLocation).toUri.getScheme).getOrElse("") match {
      case "dbfs" | "hdfs" | "abfss" => mergePaths()
      case "file" =>
        val fileSystemIsAbfss = fileSystem.getUri.toString.startsWith("abfss:")
        mergePaths(withFileSystem = !fileSystemIsAbfss)
      case _ => mergePaths(withFileSystem = true)
    }
  }
}

/** Appends `/clusterFileName` to the cluster temporary location.
  *
  * @param withFileSystem
  *   when true, `fileSystem.getUri` is prepended to `clusterTmpLocation` before merging
  * @return
  *   the merged path
  */
private def mergePaths(withFileSystem: Boolean = false): Path = {
  val basePath =
    if (withFileSystem) new Path(fileSystem.getUri.toString + clusterTmpLocation)
    else new Path(clusterTmpLocation)
  Path.mergePaths(basePath, new Path("/" + clusterFileName))
}

val destinationScheme: String = fileSystem.getScheme
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/com/johnsnowlabs/util/Version.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ object Version {
def parse(str: String): Version = {
// Strip a release-candidate suffix with any number of digits. Matching a single digit would
// leave the rest glued to the patch number (e.g. "5.0.0-rc12" -> "5.0.02").
// Then split on '.' or '-', keep the leading numeric components and cap them at three.
val parts = str
  .replaceAll("-rc\\d+", "")
  .split("[.-]")
  .takeWhile(p => isInteger(p))
  .map(p => p.toInt)
  .take(3)
  .toList

Version(parts)
Expand Down
80 changes: 74 additions & 6 deletions src/test/scala/com/johnsnowlabs/nlp/util/VersionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

package com.johnsnowlabs.nlp.util

import com.johnsnowlabs.tags.FastTest
import com.johnsnowlabs.util.Version
import org.junit.Assert.{assertFalse, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec

class VersionTest extends AnyFlatSpec {

"Version" should "cast to float version of 1 digit" in {
"Version" should "cast to float version of 1 digit" taggedAs FastTest in {

val actualVersion1 = Version(1).toFloat
val actualVersion15 = Version(15).toFloat
Expand All @@ -32,15 +33,15 @@ class VersionTest extends AnyFlatSpec {

}

it should "cast to float version of 2 digits" in {
it should "cast to float version of 2 digits" taggedAs FastTest in {
val actualVersion1_2 = Version(List(1, 2)).toFloat
val actualVersion2_7 = Version(List(2, 7)).toFloat

assert(actualVersion1_2 == 1.2f)
assert(actualVersion2_7 == 2.7f)
}

it should "cast to float version of 3 digits" in {
it should "cast to float version of 3 digits" taggedAs FastTest in {
val actualVersion1_2_5 = Version(List(1, 2, 5)).toFloat
val actualVersion3_2_0 = Version(List(3, 2, 0)).toFloat
val actualVersion2_0_6 = Version(List(2, 0, 6)).toFloat
Expand All @@ -50,13 +51,13 @@ class VersionTest extends AnyFlatSpec {
assert(actualVersion2_0_6 == 2.06f)
}

it should "raise error when casting to float version > 3 digits" in {
it should "raise error when casting to float version > 3 digits" taggedAs FastTest in {
assertThrows[UnsupportedOperationException] {
Version(List(3, 0, 2, 5)).toFloat
}
}

it should "be compatible for latest versions" in {
it should "be compatible for latest versions" taggedAs FastTest in {
var currentVersion = Version(List(1, 2, 3))
var modelVersion = Version(List(1, 2))

Expand All @@ -80,7 +81,7 @@ class VersionTest extends AnyFlatSpec {

}

it should "be not compatible for latest versions" in {
it should "be not compatible for latest versions" taggedAs FastTest in {
var currentVersion = Version(List(1, 2))
var modelVersion = Version(List(1, 2, 3))

Expand All @@ -103,4 +104,71 @@ class VersionTest extends AnyFlatSpec {
assertFalse(isNotCompatible)
}

it should "parse a version with fewer than 3 numbers" taggedAs FastTest in {
  // A two-component input is expected to round-trip unchanged.
  val parsed = Version.parse("3.2")

  assert("3.2" == parsed.toString)
  assert(3.2f == parsed.toFloat)
}

it should "parse a version with 3 numbers" taggedAs FastTest in {
  // A three-component input is expected to round-trip unchanged.
  val rawVersion = "3.4.2"
  val parsed = Version.parse(rawVersion)

  assert(rawVersion == parsed.toString)
  assert(3.42f == parsed.toFloat)
}

it should "truncate a version to 3 digits when it has more than 3 digits" taggedAs FastTest in {
  // Only the first three numeric components are expected to survive parsing.
  val parsed = Version.parse("3.5.1.5.4.20241007.4")

  assert("3.5.1" == parsed.toString)
  assert(3.51f == parsed.toFloat)
}

it should "handle a version with missing parts" taggedAs FastTest in {
  // A single-component version is expected to parse to that one number.
  val parsed = Version.parse("3")

  assert("3" == parsed.toString)
  assert(3.0f == parsed.toFloat)
}

it should "handle a version with 3 digits and additional suffix" taggedAs FastTest in {
  // The non-numeric "-beta" suffix is expected to be dropped.
  val parsed = Version.parse("3.4.2-beta")

  assert("3.4.2" == parsed.toString)
  assert(3.42f == parsed.toFloat)
}

it should "raise exception with non-numeric and no valid parts" taggedAs FastTest in {
  // No component is numeric, so the float conversion is expected to fail.
  assertThrows[UnsupportedOperationException] {
    Version.parse("alpha.beta.gamma").toFloat
  }
}

it should "handle a version with mixed numeric and non-numeric parts" taggedAs FastTest in {
  // Parsing is expected to stop at the first non-numeric component ("alpha").
  val parsed = Version.parse("3.4-alpha.2")

  assert("3.4" == parsed.toString)
  assert(3.4f == parsed.toFloat)
}

}

0 comments on commit 3596ded

Please sign in to comment.