Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[SPARKNLP-1096] Adding support to Microsoft Fabric for WordEmbeddings #14467

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions src/main/scala/com/johnsnowlabs/client/util/CloudHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
*/
package com.johnsnowlabs.client.util

import com.johnsnowlabs.nlp.util.io.CloudStorageType
import com.johnsnowlabs.nlp.util.io.CloudStorageType.CloudStorageType
import com.johnsnowlabs.nlp.util.io.{CloudStorageType, ResourceHelper}

import java.net.{URI, URL}

Expand Down Expand Up @@ -71,7 +71,8 @@ object CloudHelper {
}

/** Reports whether `uri` should be treated as a cloud storage path.
  *
  * The result is true when the URI passes the S3, GCP Storage or Azure Blob check and is
  * not an intra-cloud path as decided by `isIntraCloudPath`.
  *
  * @param uri the location string to classify
  * @return true for a cloud path that is not an intra-cloud path
  */
def isCloudPath(uri: String): Boolean = {
// NOTE(review): the value of this bare expression is discarded; it looks like the
// pre-change line of the diff rather than live code — confirm and remove.
isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)
// Evaluated up front so the final expression reads as "cloud scheme AND NOT intra-cloud".
val intraCloudPath = isIntraCloudPath(uri)
(isS3Path(uri) || isGCPStoragePath(uri) || isAzureBlobPath(uri)) && !intraCloudPath
}

def isS3Path(uri: String): Boolean = {
Expand All @@ -81,7 +82,16 @@ object CloudHelper {
private def isGCPStoragePath(uri: String): Boolean = uri.startsWith("gs://")

/** Checks whether `uri` looks like an Azure storage location.
  *
  * Matches either an `https://` URL containing `.blob.core.windows.net/` or any URI that
  * starts with `abfss://`.
  */
private def isAzureBlobPath(uri: String): Boolean = {
// NOTE(review): the value of this bare expression is discarded; it looks like the
// pre-change line of the diff rather than live code — confirm and remove.
uri.startsWith("https://") && uri.contains(".blob.core.windows.net/")
(uri.startsWith("https://") && uri.contains(".blob.core.windows.net/")) || uri.startsWith(
"abfss://")
}

/** An `abfss://` URI counts as intra-cloud only when `isMicrosoftFabric` is true.
  *
  * The Fabric check is consulted only for `abfss://` URIs.
  */
private def isIntraCloudPath(uri: String): Boolean =
  if (uri.startsWith("abfss://")) isMicrosoftFabric else false

/** Detects Microsoft Fabric by looking for any Spark configuration key that starts with
  * `spark.fabric` on the session exposed by `ResourceHelper.spark`.
  */
def isMicrosoftFabric: Boolean = {
  val confKeys = ResourceHelper.spark.conf.getAll.keys
  confKeys.exists(key => key.startsWith("spark.fabric"))
}

def cloudType(uri: String): CloudStorageType = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,11 @@ final class RocksDBConnection private (path: String) extends AutoCloseable {

def findLocalIndex: String = {
val tmpIndexStorageLocalPath = RocksDBConnection.getTmpIndexStorageLocalPath(path)
if (new File(tmpIndexStorageLocalPath).exists()) {
val tmpIndexStorageLocalPathExists = new File(tmpIndexStorageLocalPath).exists()
val pathExist = new File(path.stripPrefix("file:")).exists()
if (tmpIndexStorageLocalPathExists) {
tmpIndexStorageLocalPath
} else if (new File(path).exists()) {
} else if (pathExist) {
path
} else {
val localFromClusterPath = SparkFiles.get(path)
Expand Down
27 changes: 22 additions & 5 deletions src/main/scala/com/johnsnowlabs/storage/StorageHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
package com.johnsnowlabs.storage

import com.johnsnowlabs.client.CloudResources
import com.johnsnowlabs.client.util.CloudHelper
import org.apache.hadoop.fs.{FileSystem, FileUtil, Path}
import org.apache.spark.sql.SparkSession
import org.apache.spark.{SparkContext, SparkFiles}
Expand All @@ -34,7 +35,6 @@ object StorageHelper {
database: String,
storageRef: String,
withinStorage: Boolean): RocksDBConnection = {

val dbFolder = StorageHelper.resolveStorageName(database, storageRef)
val source = StorageLocator.getStorageSerializedPath(
storageSourcePath.replaceAllLiterally("\\", "/"),
Expand All @@ -49,7 +49,11 @@ object StorageHelper {
locator.destinationScheme,
spark.sparkContext)

RocksDBConnection.getOrCreate(locator.clusterFileName)
val storagePath = if (locator.clusterFilePath.toString.startsWith("file:")) {
locator.clusterFilePath.toString
} else locator.clusterFileName

RocksDBConnection.getOrCreate(storagePath)
}

def save(
Expand Down Expand Up @@ -96,9 +100,19 @@ object StorageHelper {
}
case "s3a" =>
copyIndexToLocal(source, new Path(tmpIndexStorageLocalPath), sparkContext)
case _ => copyIndexToCluster(source, clusterFilePath, sparkContext)
case _ => {
copyIndexToCluster(source, clusterFilePath, sparkContext)
}
}
}
case "abfss" =>
if (clusterFilePath.toString.startsWith("file:")) {
val tmpIndexStorageLocalPath =
RocksDBConnection.getTmpIndexStorageLocalPath(clusterFileName)
copyIndexToCluster(source, new Path("file://" + tmpIndexStorageLocalPath), sparkContext)
} else {
copyIndexToLocal(source, clusterFilePath, sparkContext)
}
case _ => {
copyIndexToCluster(source, clusterFilePath, sparkContext)
}
Expand All @@ -120,7 +134,8 @@ object StorageHelper {
sourcePath: Path,
dst: Path,
sparkContext: SparkContext): String = {
if (!new File(SparkFiles.get(dst.getName)).exists()) {
val destinationInSpark = new File(SparkFiles.get(dst.getName)).exists()
if (!destinationInSpark) {
val srcFS = sourcePath.getFileSystem(sparkContext.hadoopConfiguration)
val dstFS = dst.getFileSystem(sparkContext.hadoopConfiguration)

Expand All @@ -138,7 +153,9 @@ object StorageHelper {
sparkContext.hadoopConfiguration)
}

sparkContext.addFile(dst.toString, recursive = true)
if (!CloudHelper.isMicrosoftFabric) {
sparkContext.addFile(dst.toString, recursive = true)
}
}
dst.toString
}
Expand Down
46 changes: 32 additions & 14 deletions src/main/scala/com/johnsnowlabs/storage/StorageLocator.scala
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,47 @@ case class StorageLocator(database: String, storageRef: String, sparkSession: Sp
if (tmpLocation.matches("s3[a]?:/.*")) {
tmpLocation
} else {
val tmpLocationPath = new Path(tmpLocation)
fileSystem.mkdirs(tmpLocationPath)
fileSystem.deleteOnExit(tmpLocationPath)
tmpLocation
fileSystem.getScheme match {
case "abfss" =>
if (tmpLocation.startsWith("abfss:")) {
tmpLocation
} else {
"file:///" + tmpLocation
}
case _ =>
val tmpLocationPath = new Path(tmpLocation)
fileSystem.mkdirs(tmpLocationPath)
fileSystem.deleteOnExit(tmpLocationPath)
tmpLocation
}
}
}

val clusterFileName: String = {
StorageHelper.resolveStorageName(database, storageRef)
}
val clusterFileName: String = { StorageHelper.resolveStorageName(database, storageRef) }

val clusterFilePath: Path = {
if (!getTmpLocation.matches("s3[a]?:/.*")) {
val scheme = Option(new Path(clusterTmpLocation).toUri.getScheme).getOrElse("")
scheme match {
case "dbfs" | "hdfs" =>
Path.mergePaths(new Path(clusterTmpLocation), new Path("/" + clusterFileName))
case _ =>
Path.mergePaths(
new Path(fileSystem.getUri.toString + clusterTmpLocation),
new Path("/" + clusterFileName))
case "dbfs" | "hdfs" => mergePaths()
case "file" =>
val uri = fileSystem.getUri.toString
if (uri.startsWith("abfss:")) { mergePaths() }
else { mergePaths(withFileSystem = true) }
case "abfss" => mergePaths()
case _ => mergePaths(withFileSystem = true)
}
} else new Path(clusterTmpLocation + "/" + clusterFileName)
} else {
new Path(clusterTmpLocation + "/" + clusterFileName)
}
}

/** Builds the path of the cluster file underneath the cluster temp location.
  *
  * @param withFileSystem when true, the file system URI is prepended to the temp location
  *                       before the file name is appended; defaults to false
  * @return the temp location (optionally prefixed) merged with `/<clusterFileName>`
  */
private def mergePaths(withFileSystem: Boolean = false): Path = {
  val baseLocation =
    if (withFileSystem) new Path(fileSystem.getUri.toString + clusterTmpLocation)
    else new Path(clusterTmpLocation)
  Path.mergePaths(baseLocation, new Path("/" + clusterFileName))
}

val destinationScheme: String = fileSystem.getScheme
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/com/johnsnowlabs/util/Version.scala
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ object Version {
def parse(str: String): Version = {
val parts = str
.replaceAll("-rc\\d", "")
.split('.')
.split("[.-]")
.takeWhile(p => isInteger(p))
.map(p => p.toInt)
.take(3)
.toList

Version(parts)
Expand Down
80 changes: 74 additions & 6 deletions src/test/scala/com/johnsnowlabs/nlp/util/VersionTest.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,14 @@

package com.johnsnowlabs.nlp.util

import com.johnsnowlabs.tags.FastTest
import com.johnsnowlabs.util.Version
import org.junit.Assert.{assertFalse, assertTrue}
import org.scalatest.flatspec.AnyFlatSpec

class VersionTest extends AnyFlatSpec {

"Version" should "cast to float version of 1 digit" in {
"Version" should "cast to float version of 1 digit" taggedAs FastTest in {

val actualVersion1 = Version(1).toFloat
val actualVersion15 = Version(15).toFloat
Expand All @@ -32,15 +33,15 @@ class VersionTest extends AnyFlatSpec {

}

it should "cast to float version of 2 digits" in {
it should "cast to float version of 2 digits" taggedAs FastTest in {
val actualVersion1_2 = Version(List(1, 2)).toFloat
val actualVersion2_7 = Version(List(2, 7)).toFloat

assert(actualVersion1_2 == 1.2f)
assert(actualVersion2_7 == 2.7f)
}

it should "cast to float version of 3 digits" in {
it should "cast to float version of 3 digits" taggedAs FastTest in {
val actualVersion1_2_5 = Version(List(1, 2, 5)).toFloat
val actualVersion3_2_0 = Version(List(3, 2, 0)).toFloat
val actualVersion2_0_6 = Version(List(2, 0, 6)).toFloat
Expand All @@ -50,13 +51,13 @@ class VersionTest extends AnyFlatSpec {
assert(actualVersion2_0_6 == 2.06f)
}

it should "raise error when casting to float version > 3 digits" in {
it should "raise error when casting to float version > 3 digits" taggedAs FastTest in {
assertThrows[UnsupportedOperationException] {
Version(List(3, 0, 2, 5)).toFloat
}
}

it should "be compatible for latest versions" in {
it should "be compatible for latest versions" taggedAs FastTest in {
var currentVersion = Version(List(1, 2, 3))
var modelVersion = Version(List(1, 2))

Expand All @@ -80,7 +81,7 @@ class VersionTest extends AnyFlatSpec {

}

it should "be not compatible for latest versions" in {
it should "be not compatible for latest versions" taggedAs FastTest in {
var currentVersion = Version(List(1, 2))
var modelVersion = Version(List(1, 2, 3))

Expand All @@ -103,4 +104,71 @@ class VersionTest extends AnyFlatSpec {
assertFalse(isNotCompatible)
}

it should "parse a version with fewer than 3 numbers" taggedAs FastTest in {
  // Two-part input: string form and float form are both expected to round-trip.
  val parsed = Version.parse("3.2")

  assert("3.2" == parsed.toString)
  assert(3.2f == parsed.toFloat)
}

it should "parse a version with 3 numbers" taggedAs FastTest in {
  // Three-part input: the string form must equal the input itself.
  val rawVersion = "3.4.2"
  val parsed = Version.parse(rawVersion)

  assert(rawVersion == parsed.toString)
  assert(3.42f == parsed.toFloat)
}

it should "truncate a version to 3 digits when it has more than 3 digits" taggedAs FastTest in {
  // Seven dot-separated numbers in, only the first three expected out.
  val parsed = Version.parse("3.5.1.5.4.20241007.4")

  assert("3.5.1" == parsed.toString)
  assert(3.51f == parsed.toFloat)
}

it should "handle a version with missing parts" taggedAs FastTest in {
  // A lone major number is expected to parse to a one-part version.
  val parsed = Version.parse("3")

  assert("3" == parsed.toString)
  assert(3.0f == parsed.toFloat)
}

it should "handle a version with 3 digits and additional suffix" taggedAs FastTest in {
  // The trailing "-beta" suffix is expected to be dropped from the parsed version.
  val parsed = Version.parse("3.4.2-beta")

  assert("3.4.2" == parsed.toString)
  assert(3.42f == parsed.toFloat)
}

it should "raise exception with non-numeric and no valid parts" taggedAs FastTest in {
  // Parsing and the float conversion both run inside the assertion, so the expected
  // exception may come from either call.
  assertThrows[UnsupportedOperationException] {
    Version.parse("alpha.beta.gamma").toFloat
  }
}

it should "handle a version with mixed numeric and non-numeric parts" taggedAs FastTest in {
  // Only the leading numeric parts ("3" and "4") are expected to survive parsing.
  val parsed = Version.parse("3.4-alpha.2")

  assert("3.4" == parsed.toString)
  assert(3.4f == parsed.toFloat)
}

}