Skip to content

Commit

Permalink
Add unzip param to downloadModelDirectly in ResourceDownloader (#13796)
Browse files Browse the repository at this point in the history
* added downloadModelDirectlyAsZip

* added unzip argument to downloadModelDirectly
  • Loading branch information
mehmetbutgul authored May 25, 2023
1 parent 468568e commit f24c85f
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 79 deletions.
4 changes: 2 additions & 2 deletions python/sparknlp/internal/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,9 +328,9 @@ def __init__(self, covered, total, percentage):


class _DownloadModelDirectly(ExtendedJavaWrapper):
def __init__(self, name, remote_loc="public/models"):
def __init__(self, name, remote_loc="public/models", unzip=True):
super(_DownloadModelDirectly, self).__init__(
"com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModelDirectly", name, remote_loc)
"com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModelDirectly", name, remote_loc, unzip)


class _DownloadModel(ExtendedJavaWrapper):
Expand Down
12 changes: 8 additions & 4 deletions python/sparknlp/pretrained/resource_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,21 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource
return reader(classname=None, java_model=j_obj)

@staticmethod
def downloadModelDirectly(name, remote_loc="public/models"):
def downloadModelDirectly(name, remote_loc="public/models", unzip=True):
"""Downloads a model directly to the cache folder.
You can use to copy-paste the s3 URI from the model hub and download the model.
For available s3 URI and models, please see the `Models Hub <https://sparknlp.org/models>`__.
Parameters
----------
name : str
Name of the model
Name of the model or s3 URI
remote_loc : str, optional
Directory of the remote Spark NLP Folder, by default "public/models"
unzip : Bool, optional
Used to unzip model, by default 'True'
"""
_internal._DownloadModelDirectly(name, remote_loc).apply()
_internal._DownloadModelDirectly(name, remote_loc, unzip).apply()


@staticmethod
def downloadPipeline(name, language, remote_loc=None):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ trait ResourceDownloader {

def downloadMetadataIfNeed(folder: String): List[ResourceMetadata]

def downloadAndUnzipFile(s3FilePath: String): Option[String]
def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String]

val fileSystem: FileSystem = ResourceDownloader.fileSystem

Expand Down Expand Up @@ -126,6 +126,14 @@ object ResourceDownloader {
var communityDownloader: ResourceDownloader =
new S3ResourceDownloader(s3BucketCommunity, s3Path, cacheFolder, "community")

def getResourceDownloader(folder: String): ResourceDownloader = {
folder match {
case this.publicLoc => publicDownloader
case loc if loc.startsWith("@") => communityDownloader
case _ => privateDownloader
}
}

/** Reset the cache and recreate ResourceDownloader S3 credentials */
def resetResourceDownloader(): Unit = {
cache.empty
Expand Down Expand Up @@ -412,11 +420,7 @@ object ResourceDownloader {
}

private def getResourceMetadata(location: String): List[ResourceMetadata] = {
location match {
case this.publicLoc => publicDownloader.downloadMetadataIfNeed(location)
case loc if loc.startsWith("@") => communityDownloader.downloadMetadataIfNeed(location)
case _ => privateDownloader.downloadMetadataIfNeed(location)
}
getResourceDownloader(location).downloadMetadataIfNeed(location)
}

def showAvailableAnnotators(folder: String = publicLoc): Unit = {
Expand Down Expand Up @@ -450,20 +454,10 @@ object ResourceDownloader {
*/
def downloadResource(request: ResourceRequest): String = {
val future = Future {
if (request.folder.equals(publicLoc)) {
publicDownloader.download(request)
} else if (request.folder.startsWith("@")) {
val actualLoc = request.folder.replace("@", "")
val updatedRequest = ResourceRequest(
request.name,
request.language,
folder = actualLoc,
request.libVersion,
request.sparkVersion)
communityDownloader.download(updatedRequest)
} else {
privateDownloader.download(request)
}
val updatedRequest: ResourceRequest = if (request.folder.startsWith("@")) {
request.copy(folder = request.folder.replace("@", ""))
} else request
getResourceDownloader(request.folder).download(updatedRequest)
}

var downloadFinished = false
Expand Down Expand Up @@ -497,22 +491,19 @@ object ResourceDownloader {
path.get
}

/** Downloads a resource from the default S3 bucket in the cache pretrained folder.
/** Downloads a model from the default S3 bucket to the cache pretrained folder.
* @param model
* the name of the key in the S3 bucket
* the name of the key in the S3 bucket or s3 URI
* @param folder
* the language of the model
* @return
* the path to the downloaded file
* the folder of the model
* @param unzip
* used to unzip the model, by default true
*/
def downloadModelDirectly(model: String, folder: String = publicLoc): Unit = {
if (folder.equals(publicLoc)) {
publicDownloader.downloadAndUnzipFile(model)
} else if (folder.startsWith("@")) {
communityDownloader.downloadAndUnzipFile(model)
} else {
privateDownloader.downloadAndUnzipFile(model)
}
def downloadModelDirectly(
model: String,
folder: String = publicLoc,
unzip: Boolean = true): Unit = {
getResourceDownloader(folder).downloadAndUnzipFile(model, unzip)
}

def downloadModel[TModel <: PipelineStage](
Expand Down Expand Up @@ -569,17 +560,14 @@ object ResourceDownloader {
}

def getDownloadSize(resourceRequest: ResourceRequest): String = {
var size: Option[Long] = None
val folder = resourceRequest.folder
if (folder.equals(publicLoc)) {
size = publicDownloader.getDownloadSize(resourceRequest)
} else if (folder.startsWith("@")) {
val actualLoc = folder.replace("@", "")
size = communityDownloader.getDownloadSize(
ResourceRequest(resourceRequest.name, resourceRequest.language, actualLoc))
} else {
size = privateDownloader.getDownloadSize(resourceRequest)
}

val updatedResourceRequest: ResourceRequest = if (resourceRequest.folder.startsWith("@")) {
resourceRequest.copy(folder = resourceRequest.folder.replace("@", ""))
} else resourceRequest

val size = getResourceDownloader(resourceRequest.folder)
.getDownloadSize(updatedResourceRequest)

size match {
case Some(downloadBytes) => FileHelper.getHumanReadableFileSize(downloadBytes)
case None => "-1"
Expand Down Expand Up @@ -772,9 +760,12 @@ object PythonResourceDownloader {
ResourceDownloader.clearCache(name, Option(language), correctedFolder)
}

def downloadModelDirectly(model: String, remoteLoc: String = null): Unit = {
def downloadModelDirectly(
model: String,
remoteLoc: String = null,
unzip: Boolean = true): Unit = {
val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc)
ResourceDownloader.downloadModelDirectly(model, correctedFolder)
ResourceDownloader.downloadModelDirectly(model, correctedFolder, unzip)
}

def showUnCategorizedResources(): String = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,9 +307,17 @@ class S3ResourceDownloader(
}
}

def downloadAndUnzipFile(s3FilePath: String): Option[String] = {
def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean): Option[String] = {
// handle s3FilePath options:
// 1--> s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip
// 2--> public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip

val newS3FilePath = if (s3FilePath.startsWith("s3")) {
ResourceHelper.parseS3URI(s3FilePath)._2
} else s3FilePath

val s3File = newS3FilePath.split("/").last

val s3File = s3FilePath.split("/").last
val destinationFile = new Path(cachePath.toString + "/" + s3File)
val splitPath = destinationFile.toString.substring(0, destinationFile.toString.length - 4)

Expand All @@ -318,40 +326,41 @@ class S3ResourceDownloader(
val tmpFileName = Files.createTempFile(s3File, "").toString
val tmpFile = new File(tmpFileName)

val newStrfilePath: String = s3FilePath
val newStrfilePath: String = newS3FilePath
val mybucket: String = bucket
// 2. Download content to tmp file
awsGateway.getS3Object(mybucket, newStrfilePath, tmpFile)

// 4. Move tmp file to destination
fileSystem.moveFromLocalFile(new Path(tmpFile.toString), destinationFile)
}
if (unzip) {
if (!fileSystem.exists(new Path(splitPath))) {
val zis = new ZipInputStream(fileSystem.open(destinationFile))
val buf = Array.ofDim[Byte](1024)
var entry = zis.getNextEntry
require(
destinationFile.toString.substring(destinationFile.toString.length - 4) == ".zip",
"Not a zip file.")

if (!fileSystem.exists(new Path(splitPath))) {
val zis = new ZipInputStream(fileSystem.open(destinationFile))
val buf = Array.ofDim[Byte](1024)
var entry = zis.getNextEntry
require(
destinationFile.toString.substring(destinationFile.toString.length - 4) == ".zip",
"Not a zip file.")

while (entry != null) {
if (!entry.isDirectory) {
val entryName = new Path(splitPath, entry.getName)
val outputStream = fileSystem.create(entryName)
var bytesRead = zis.read(buf, 0, 1024)
while (bytesRead > -1) {
outputStream.write(buf, 0, bytesRead)
bytesRead = zis.read(buf, 0, 1024)
while (entry != null) {
if (!entry.isDirectory) {
val entryName = new Path(splitPath, entry.getName)
val outputStream = fileSystem.create(entryName)
var bytesRead = zis.read(buf, 0, 1024)
while (bytesRead > -1) {
outputStream.write(buf, 0, bytesRead)
bytesRead = zis.read(buf, 0, 1024)
}
outputStream.close()
}
outputStream.close()
zis.closeEntry()
entry = zis.getNextEntry
}
zis.closeEntry()
entry = zis.getNextEntry
zis.close()
// delete the zip file
fileSystem.delete(destinationFile, true)
}
zis.close()
// delete the zip file
fileSystem.delete(destinationFile, true)
}
Some(splitPath)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,6 @@ class MockResourceDownloader(resourcePath: String) extends ResourceDownloader {
val resources: List[ResourceMetadata] = ResourceMetadata.readResources(resourcePath)

def downloadMetadataIfNeed(folder: String): List[ResourceMetadata] = resources

override def downloadAndUnzipFile(s3FilePath: String): Option[String] = Some("model")
override def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String] =
Some("model")
}
Original file line number Diff line number Diff line change
Expand Up @@ -184,6 +184,16 @@ class ResourceDownloaderMetaSpec extends AnyFlatSpec with BeforeAndAfter {
"public/models/bert_base_cased_es_3.2.2_3.0_1630999631885.zip")
}

it should "download a model and keep it as zip" taggedAs SlowTest in {
ResourceDownloader.privateDownloader = realPrivateDownloader
ResourceDownloader.publicDownloader = realPublicDownloader
ResourceDownloader.communityDownloader = realCommunityDownloader
ResourceDownloader.downloadModelDirectly(
"s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip",
folder = "public/models",
unzip = false)
}

it should "be able to list from online metadata" in {
ResourceDownloader.privateDownloader = realPrivateDownloader
ResourceDownloader.publicDownloader = realPublicDownloader
Expand Down

0 comments on commit f24c85f

Please sign in to comment.