diff --git a/python/sparknlp/internal/__init__.py b/python/sparknlp/internal/__init__.py index a525a53ca2047a..21b48e7693b793 100644 --- a/python/sparknlp/internal/__init__.py +++ b/python/sparknlp/internal/__init__.py @@ -328,9 +328,9 @@ def __init__(self, covered, total, percentage): class _DownloadModelDirectly(ExtendedJavaWrapper): - def __init__(self, name, remote_loc="public/models"): + def __init__(self, name, remote_loc="public/models", unzip=True): super(_DownloadModelDirectly, self).__init__( - "com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModelDirectly", name, remote_loc) + "com.johnsnowlabs.nlp.pretrained.PythonResourceDownloader.downloadModelDirectly", name, remote_loc, unzip) class _DownloadModel(ExtendedJavaWrapper): diff --git a/python/sparknlp/pretrained/resource_downloader.py b/python/sparknlp/pretrained/resource_downloader.py index c42b35ce46698d..7755b9d0e5878c 100644 --- a/python/sparknlp/pretrained/resource_downloader.py +++ b/python/sparknlp/pretrained/resource_downloader.py @@ -104,17 +104,21 @@ def downloadModel(reader, name, language, remote_loc=None, j_dwn='PythonResource return reader(classname=None, java_model=j_obj) @staticmethod - def downloadModelDirectly(name, remote_loc="public/models"): + def downloadModelDirectly(name, remote_loc="public/models", unzip=True): """Downloads a model directly to the cache folder. - + You can copy-paste the S3 URI of a model from the Models Hub to download it. + For available models and their S3 URIs, please see the `Models Hub <https://sparknlp.org/models>`__. 
Parameters ---------- name : str - Name of the model + Name of the model or its S3 URI remote_loc : str, optional Directory of the remote Spark NLP Folder, by default "public/models" + unzip : bool, optional + Whether to unzip the downloaded model, by default True """ - _internal._DownloadModelDirectly(name, remote_loc).apply() + _internal._DownloadModelDirectly(name, remote_loc, unzip).apply() + @staticmethod def downloadPipeline(name, language, remote_loc=None): diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala index 88175340eee72a..4abafc4b99febb 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloader.scala @@ -84,7 +84,7 @@ trait ResourceDownloader { def downloadMetadataIfNeed(folder: String): List[ResourceMetadata] - def downloadAndUnzipFile(s3FilePath: String): Option[String] + def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String] val fileSystem: FileSystem = ResourceDownloader.fileSystem @@ -126,6 +126,14 @@ object ResourceDownloader { var communityDownloader: ResourceDownloader = new S3ResourceDownloader(s3BucketCommunity, s3Path, cacheFolder, "community") + def getResourceDownloader(folder: String): ResourceDownloader = { + folder match { + case this.publicLoc => publicDownloader + case loc if loc.startsWith("@") => communityDownloader + case _ => privateDownloader + } + } + /** Reset the cache and recreate ResourceDownloader S3 credentials */ def resetResourceDownloader(): Unit = { cache.empty @@ -412,11 +420,7 @@ object ResourceDownloader { } private def getResourceMetadata(location: String): List[ResourceMetadata] = { - location match { - case this.publicLoc => publicDownloader.downloadMetadataIfNeed(location) - case loc if loc.startsWith("@") => communityDownloader.downloadMetadataIfNeed(location) - case _ => 
privateDownloader.downloadMetadataIfNeed(location) - } + getResourceDownloader(location).downloadMetadataIfNeed(location) } def showAvailableAnnotators(folder: String = publicLoc): Unit = { @@ -450,20 +454,10 @@ object ResourceDownloader { */ def downloadResource(request: ResourceRequest): String = { val future = Future { - if (request.folder.equals(publicLoc)) { - publicDownloader.download(request) - } else if (request.folder.startsWith("@")) { - val actualLoc = request.folder.replace("@", "") - val updatedRequest = ResourceRequest( - request.name, - request.language, - folder = actualLoc, - request.libVersion, - request.sparkVersion) - communityDownloader.download(updatedRequest) - } else { - privateDownloader.download(request) - } + val updatedRequest: ResourceRequest = if (request.folder.startsWith("@")) { + request.copy(folder = request.folder.replace("@", "")) + } else request + getResourceDownloader(request.folder).download(updatedRequest) } var downloadFinished = false @@ -497,22 +491,19 @@ object ResourceDownloader { path.get } - /** Downloads a resource from the default S3 bucket in the cache pretrained folder. + /** Downloads a model from the default S3 bucket to the cache pretrained folder. 
* @param model - * the name of the key in the S3 bucket + * the name of the key in the S3 bucket or s3 URI * @param folder - * the language of the model - * @return - * the path to the downloaded file + * the folder of the model + * @param unzip + * used to unzip the model, by default true */ - def downloadModelDirectly(model: String, folder: String = publicLoc): Unit = { - if (folder.equals(publicLoc)) { - publicDownloader.downloadAndUnzipFile(model) - } else if (folder.startsWith("@")) { - communityDownloader.downloadAndUnzipFile(model) - } else { - privateDownloader.downloadAndUnzipFile(model) - } + def downloadModelDirectly( + model: String, + folder: String = publicLoc, + unzip: Boolean = true): Unit = { + getResourceDownloader(folder).downloadAndUnzipFile(model, unzip) } def downloadModel[TModel <: PipelineStage]( @@ -569,17 +560,14 @@ object ResourceDownloader { } def getDownloadSize(resourceRequest: ResourceRequest): String = { - var size: Option[Long] = None - val folder = resourceRequest.folder - if (folder.equals(publicLoc)) { - size = publicDownloader.getDownloadSize(resourceRequest) - } else if (folder.startsWith("@")) { - val actualLoc = folder.replace("@", "") - size = communityDownloader.getDownloadSize( - ResourceRequest(resourceRequest.name, resourceRequest.language, actualLoc)) - } else { - size = privateDownloader.getDownloadSize(resourceRequest) - } + + val updatedResourceRequest: ResourceRequest = if (resourceRequest.folder.startsWith("@")) { + resourceRequest.copy(folder = resourceRequest.folder.replace("@", "")) + } else resourceRequest + + val size = getResourceDownloader(resourceRequest.folder) + .getDownloadSize(updatedResourceRequest) + size match { case Some(downloadBytes) => FileHelper.getHumanReadableFileSize(downloadBytes) case None => "-1" @@ -772,9 +760,12 @@ object PythonResourceDownloader { ResourceDownloader.clearCache(name, Option(language), correctedFolder) } - def downloadModelDirectly(model: String, remoteLoc: String = 
null): Unit = { + def downloadModelDirectly( + model: String, + remoteLoc: String = null, + unzip: Boolean = true): Unit = { val correctedFolder = Option(remoteLoc).getOrElse(ResourceDownloader.publicLoc) - ResourceDownloader.downloadModelDirectly(model, correctedFolder) + ResourceDownloader.downloadModelDirectly(model, correctedFolder, unzip) } def showUnCategorizedResources(): String = { diff --git a/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala b/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala index 9888f2e7055ab8..f491e5e618832b 100644 --- a/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala +++ b/src/main/scala/com/johnsnowlabs/nlp/pretrained/S3ResourceDownloader.scala @@ -307,9 +307,17 @@ class S3ResourceDownloader( } } - def downloadAndUnzipFile(s3FilePath: String): Option[String] = { + def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean): Option[String] = { + // handle s3FilePath options: + // 1--> s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip + // 2--> public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip + + val newS3FilePath = if (s3FilePath.startsWith("s3")) { + ResourceHelper.parseS3URI(s3FilePath)._2 + } else s3FilePath + + val s3File = newS3FilePath.split("/").last - val s3File = s3FilePath.split("/").last val destinationFile = new Path(cachePath.toString + "/" + s3File) val splitPath = destinationFile.toString.substring(0, destinationFile.toString.length - 4) @@ -318,7 +326,7 @@ class S3ResourceDownloader( val tmpFileName = Files.createTempFile(s3File, "").toString val tmpFile = new File(tmpFileName) - val newStrfilePath: String = s3FilePath + val newStrfilePath: String = newS3FilePath val mybucket: String = bucket // 2. Download content to tmp file awsGateway.getS3Object(mybucket, newStrfilePath, tmpFile) @@ -326,32 +334,33 @@ class S3ResourceDownloader( // 4. 
Move tmp file to destination fileSystem.moveFromLocalFile(new Path(tmpFile.toString), destinationFile) } + if (unzip) { + if (!fileSystem.exists(new Path(splitPath))) { + val zis = new ZipInputStream(fileSystem.open(destinationFile)) + val buf = Array.ofDim[Byte](1024) + var entry = zis.getNextEntry + require( + destinationFile.toString.substring(destinationFile.toString.length - 4) == ".zip", + "Not a zip file.") - if (!fileSystem.exists(new Path(splitPath))) { - val zis = new ZipInputStream(fileSystem.open(destinationFile)) - val buf = Array.ofDim[Byte](1024) - var entry = zis.getNextEntry - require( - destinationFile.toString.substring(destinationFile.toString.length - 4) == ".zip", - "Not a zip file.") - - while (entry != null) { - if (!entry.isDirectory) { - val entryName = new Path(splitPath, entry.getName) - val outputStream = fileSystem.create(entryName) - var bytesRead = zis.read(buf, 0, 1024) - while (bytesRead > -1) { - outputStream.write(buf, 0, bytesRead) - bytesRead = zis.read(buf, 0, 1024) + while (entry != null) { + if (!entry.isDirectory) { + val entryName = new Path(splitPath, entry.getName) + val outputStream = fileSystem.create(entryName) + var bytesRead = zis.read(buf, 0, 1024) + while (bytesRead > -1) { + outputStream.write(buf, 0, bytesRead) + bytesRead = zis.read(buf, 0, 1024) + } + outputStream.close() } - outputStream.close() + zis.closeEntry() + entry = zis.getNextEntry } - zis.closeEntry() - entry = zis.getNextEntry + zis.close() + // delete the zip file + fileSystem.delete(destinationFile, true) } - zis.close() - // delete the zip file - fileSystem.delete(destinationFile, true) } Some(splitPath) diff --git a/src/test/scala/com/johnsnowlabs/nlp/pretrained/MockResourceDownloader.scala b/src/test/scala/com/johnsnowlabs/nlp/pretrained/MockResourceDownloader.scala index 7e86eb39843c5e..a8c03dd69fad2b 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/pretrained/MockResourceDownloader.scala +++ 
b/src/test/scala/com/johnsnowlabs/nlp/pretrained/MockResourceDownloader.scala @@ -34,6 +34,6 @@ class MockResourceDownloader(resourcePath: String) extends ResourceDownloader { val resources: List[ResourceMetadata] = ResourceMetadata.readResources(resourcePath) def downloadMetadataIfNeed(folder: String): List[ResourceMetadata] = resources - - override def downloadAndUnzipFile(s3FilePath: String): Option[String] = Some("model") + override def downloadAndUnzipFile(s3FilePath: String, unzip: Boolean = true): Option[String] = + Some("model") } diff --git a/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala b/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala index 0af8283a589726..77a928e4993504 100644 --- a/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala +++ b/src/test/scala/com/johnsnowlabs/nlp/pretrained/ResourceDownloaderMetaSpec.scala @@ -184,6 +184,16 @@ class ResourceDownloaderMetaSpec extends AnyFlatSpec with BeforeAndAfter { "public/models/bert_base_cased_es_3.2.2_3.0_1630999631885.zip") } + it should "download a model and keep it as zip" taggedAs SlowTest in { + ResourceDownloader.privateDownloader = realPrivateDownloader + ResourceDownloader.publicDownloader = realPublicDownloader + ResourceDownloader.communityDownloader = realCommunityDownloader + ResourceDownloader.downloadModelDirectly( + "s3://auxdata.johnsnowlabs.com/public/models/albert_base_sequence_classifier_ag_news_en_3.4.0_3.0_1639648298937.zip", + folder = "public/models", + unzip = false) + } + it should "be able to list from online metadata" in { ResourceDownloader.privateDownloader = realPrivateDownloader ResourceDownloader.publicDownloader = realPublicDownloader