Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unzip param to downloadModelDirectly in ResourceDownloader #13796

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions python/sparknlp/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ def __init__(self, covered, total, percentage):


class _DownloadModelDirectly(ExtendedJavaWrapper):
def __init__(self, name, remote_loc="public/models"):
def __init__(self, name, remote_loc="public/models", unzip=True):
super(_DownloadModelDirectly, self).__init__(
"com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModelDirectly", name, remote_loc)
"com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModelDirectly", name, remote_loc, unzip)


class _DownloadModel(ExtendedJavaWrapper):
Expand Down
12 changes: 8 additions & 4 deletions python/sparknlp/pretrained/resource_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,21 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource
return reader(classname=None, java_model=j_obj)

@staticmethod
def downloadModelDirectly(name, remote_loc="public/models"):
def downloadModelDirectly(name, remote_loc="public/models", unzip=True):
"""Downloads a model directly to the cache folder.

You can use to copy-paste the s3 URI from the model hub and download the model.
For available s3 URI and models, please see the `Models Hub <https://sparknlp.org/models>`__.
Parameters
----------
name : str
Name of the model
Name of the model or s3 URI
remote_loc : str, optional
Directory of the remote Spark NLP Folder, by default "public/models"
unzip : Bool, optional
Used to unzip model, by default 'True'
"""
_internal._DownloadModelDirectly(name, remote_loc).apply()
_internal._DownloadModelDirectly(name, remote_loc, unzip).apply()


@staticmethod
def downloadPipeline(name, language, remote_loc=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ trait ResourceDownloader {

def downloadMetadataIfNeed(folder: String): List[ResourceMetadata]

def downloadAndUnzipFile(s3FilePath: String): Option[String]
def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String]

val fileSystem: FileSystem = ResourceDownloader.fileSystem

Expand Down Expand Up @@ -126,6 +126,14 @@ object ResourceDownloader {
var communityDownloader: ResourceDownloader =
new S3ResourceDownloader(s3BucketCommunity, s3Path, cacheFolder, "community")

def getResourceDownloader(folder: String): ResourceDownloader = {
folder match {
case this.publicLoc => publicDownloader
case loc if loc.startsWith("@") => communityDownloader
case _ => privateDownloader
}
}

/** Reset the cache and recreate ResourceDownloader S3 credentials */
def resetResourceDownloader(): Unit = {
cache.empty
Expand Down Expand Up @@ -412,11 +420,7 @@ object ResourceDownloader {
}

private def getResourceMetadata(location: String): List[ResourceMetadata] = {
location match {
case this.publicLoc => publicDownloader.downloadMetadataIfNeed(location)
case loc if loc.startsWith("@") => communityDownloader.downloadMetadataIfNeed(location)
case _ => privateDownloader.downloadMetadataIfNeed(location)
}
getResourceDownloader(location).downloadMetadataIfNeed(location)
}

def showAvailableAnnotators(folder: String = publicLoc): Unit = {
Expand Down Expand Up @@ -450,20 +454,10 @@ object ResourceDownloader {
*/
def downloadResource(request: ResourceRequest): String = {
val future = Future {
if (request.folder.equals(publicLoc)) {
publicDownloader.download(request)
} else if (request.folder.startsWith("@")) {
val actualLoc = request.folder.replace("@", "")
val updatedRequest = ResourceRequest(
request.name,
request.language,
folder = actualLoc,
request.libVersion,
request.sparkVersion)
communityDownloader.download(updatedRequest)
} else {
privateDownloader.download(request)
}
val updatedRequest: ResourceRequest = if (request.folder.startsWith("@")) {
request.copy(folder = request.folder.replace("@", ""))
} else request
getResourceDownloader(request.folder).download(updatedRequest)
}

var downloadFinished = false
Expand Down Expand Up @@ -497,22 +491,19 @@ object ResourceDownloader {
path.get
}

/** Downloads a resource from the default S3 bucket in the cache pretrained folder.
/** Downloads a model from the default S3 bucket to the cache pretrained folder.
* @param model
* the name of the key in the S3 bucket
* the name of the key in the S3 bucket or s3 URI
* @param folder
* the language of the model
* @return
* the path to the downloaded file
* the folder of the model
* @param unzip
* used to unzip the model, by default true
*/
def downloadModelDirectly(model: String, folder: String = publicLoc): Unit = {
if (folder.equals(publicLoc)) {
publicDownloader.downloadAndUnzipFile(model)
} else if (folder.startsWith("@")) {
communityDownloader.downloadAndUnzipFile(model)
} else {
privateDownloader.downloadAndUnzipFile(model)
}
def downloadModelDirectly(
model: String,
folder: String = publicLoc,
unzip: Boolean = true): Unit = {
getResourceDownloader(folder).downloadAndUnzipFile(model, unzip)
}

def downloadModel[TModel <: PipelineStage](
Expand Down Expand Up @@ -569,17 +560,14 @@ object ResourceDownloader {
}

def getDownloadSize(resourceRequest: ResourceRequest): String = {
var size: Option[Long] = None
val folder = resourceRequest.folder
if (folder.equals(publicLoc)) {
size = publicDownloader.getDownloadSize(resourceRequest)
} else if (folder.startsWith("@")) {
val actualLoc = folder.replace("@", "")
size = communityDownloader.getDownloadSize(
ResourceRequest(resourceRequest.name, resourceRequest.language, actualLoc))
} else {
size = privateDownloader.getDownloadSize(resourceRequest)
}

val updatedResourceRequest: ResourceRequest = if (resourceRequest.folder.startsWith("@")) {
resourceRequest.copy(folder = resourceRequest.folder.replace("@", ""))
} else resourceRequest

val size = getResourceDownloader(resourceRequest.folder)
.getDownloadSize(updatedResourceRequest)

size match {
case Some(downloadBytes) => FileHelper.getHumanReadableFileSize(downloadBytes)
case None => "-1"
Expand Down Expand Up @@ -771,9 +759,12 @@ object PythonResourceDownloader {
ResourceDownloader.clearCache(name, Option(language), correctedFolder)
}

def downloadModelDirectly(model: String, remoteLoc: String = null): Unit = {
def downloadModelDirectly(
model: String,
remoteLoc: String = null,
unzip: Boolean = true): Unit = {
val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
ResourceDownloader.downloadModelDirectly(model, correctedFolder)
ResourceDownloader.downloadModelDirectly(model, correctedFolder, unzip)
}

def showUnCategorizedResources(): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,17 @@ class S3ResourceDownloader(
}
}

def downloadAndUnzipFile(s3FilePath: String): Option[String] = {
def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean): Option[String] = {
// handle s3FilePath options:
// 1--> s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip
// 2--> public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip

val newS3FilePath = if (s3FilePath.startsWith("s3")) {
ResourceHelper.parseS3URI(s3FilePath)._2
} else s3FilePath

val s3File = newS3FilePath.split("/").last

val s3File = s3FilePath.split("/").last
val destinationFile = new Path(cachePath.toString + "/" + s3File)
val splitPath = destinationFile.toString.substring(0, destinationFile.toString.length - 4)

Expand All @@ -318,40 +326,41 @@ class S3ResourceDownloader(
val tmpFileName = Files.createTempFile(s3File, "").toString
val tmpFile = new File(tmpFileName)

val newStrfilePath: String = s3FilePath
val newStrfilePath: String = newS3FilePath
val mybucket: String = bucket
// 2. Download content to tmp file
awsGateway.getS3Object(mybucket, newStrfilePath, tmpFile)

// 4. Move tmp file to destination
fileSystem.moveFromLocalFile(new Path(tmpFile.toString), destinationFile)
}
if (unzip) {
if (!fileSystem.exists(new Path(splitPath))) {
val zis = new ZipInputStream(fileSystem.open(destinationFile))
val buf = Array.ofDim[Byte](1024)
var entry = zis.getNextEntry
require(
destinationFile.toString.substring(destinationFile.toString.length - 4) == ".zip",
"Not a zip file.")

if (!fileSystem.exists(new Path(splitPath))) {
val zis = new ZipInputStream(fileSystem.open(destinationFile))
val buf = Array.ofDim[Byte](1024)
var entry = zis.getNextEntry
require(
destinationFile.toString.substring(destinationFile.toString.length - 4) == ".zip",
"Not a zip file.")

while (entry != null) {
if (!entry.isDirectory) {
val entryName = new Path(splitPath, entry.getName)
val outputStream = fileSystem.create(entryName)
var bytesRead = zis.read(buf, 0, 1024)
while (bytesRead > -1) {
outputStream.write(buf, 0, bytesRead)
bytesRead = zis.read(buf, 0, 1024)
while (entry != null) {
if (!entry.isDirectory) {
val entryName = new Path(splitPath, entry.getName)
val outputStream = fileSystem.create(entryName)
var bytesRead = zis.read(buf, 0, 1024)
while (bytesRead > -1) {
outputStream.write(buf, 0, bytesRead)
bytesRead = zis.read(buf, 0, 1024)
}
outputStream.close()
}
outputStream.close()
zis.closeEntry()
entry = zis.getNextEntry
}
zis.closeEntry()
entry = zis.getNextEntry
zis.close()
// delete the zip file
fileSystem.delete(destinationFile, true)
}
zis.close()
// delete the zip file
fileSystem.delete(destinationFile, true)
}
Some(splitPath)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ class MockResourceDownloader(resourcePath: String) extends ResourceDownloader {
val resources: List[ResourceMetadata] = ResourceMetadata.readResources(resourcePath)

def downloadMetadataIfNeed(folder: String): List[ResourceMetadata] = resources

override def downloadAndUnzipFile(s3FilePath: String): Option[String] = Some("model")
override def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String] =
Some("model")
}
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,16 @@ class ResourceDownloaderMetaSpec extends AnyFlatSpec with BeforeAndAfter {
"public/models/bert_base_cased_es_3.2.2_3.0_1630999631885.zip")
}

it should "download a model and keep it as zip" taggedAs SlowTest in {
ResourceDownloader.privateDownloader = realPrivateDownloader
ResourceDownloader.publicDownloader = realPublicDownloader
ResourceDownloader.communityDownloader = realCommunityDownloader
ResourceDownloader.downloadModelDirectly(
"s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip",
folder = "public/models",
unzip = false)
}

it should "be able to list from online metadata" in {
ResourceDownloader.privateDownloader = realPrivateDownloader
ResourceDownloader.publicDownloader = realPublicDownloader
Expand Down