Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SPARKNLP-88 Adding support for S3 in CoNLL, POS, CoNLLU #13596

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 32 additions & 2 deletions src/main/scala/com/johnsnowlabs/client/aws/AWSGateway.scala
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,8 @@ import com.amazonaws.services.s3.model.{
GetObjectRequest,
ObjectMetadata,
PutObjectResult,
S3Object
S3Object,
S3ObjectSummary
}
import com.amazonaws.services.s3.transfer.{Transfer, TransferManagerBuilder}
import com.amazonaws.services.s3.{AmazonS3, AmazonS3ClientBuilder}
Expand All @@ -34,7 +35,9 @@ import com.johnsnowlabs.util.{ConfigHelper, ConfigLoader}
import org.apache.hadoop.fs.{FileSystem, Path}
import org.slf4j.{Logger, LoggerFactory}

import scala.jdk.CollectionConverters._
import java.io.File
import scala.util.control.NonFatal

class AWSGateway(
accessKeyId: String = ConfigLoader.getConfigStringValue(ConfigHelper.awsExternalAccessKeyId),
Expand Down Expand Up @@ -111,6 +114,10 @@ class AWSGateway(
} catch {
case exception: AmazonServiceException =>
if (exception.getStatusCode == 404) false else throw exception
case NonFatal(unexpectedException) =>
val methodName = Thread.currentThread.getStackTrace()(1).getMethodName
throw new Exception(
s"Unexpected error in ${this.getClass.getName}.$methodName: $unexpectedException")
}
}

Expand All @@ -121,6 +128,10 @@ class AWSGateway(
} catch {
case exception: AmazonServiceException =>
if (exception.getStatusCode == 404) false else throw exception
case NonFatal(unexpectedException) =>
val methodName = Thread.currentThread.getStackTrace()(1).getMethodName
throw new Exception(
s"Unexpected error in ${this.getClass.getName}.$methodName: $unexpectedException")
}

}
Expand All @@ -145,7 +156,12 @@ class AWSGateway(
val meta = client.getObjectMetadata(bucket, s3FilePath)
Some(meta.getContentLength)
} catch {
case e: AmazonServiceException => if (e.getStatusCode == 404) None else throw e
case exception: AmazonServiceException =>
if (exception.getStatusCode == 404) None else throw exception
case NonFatal(unexpectedException) =>
val methodName = Thread.currentThread.getStackTrace()(1).getMethodName
throw new Exception(
s"Unexpected error in ${this.getClass.getName}.$methodName: $unexpectedException")
}
}

Expand Down Expand Up @@ -199,6 +215,20 @@ class AWSGateway(
}
}

/** Lists the object summaries stored under the given prefix in an S3 bucket.
  *
  * @param bucket
  *   name of the S3 bucket to query
  * @param s3Path
  *   key prefix to list objects under
  * @return
  *   one [[S3ObjectSummary]] per object found under the prefix (empty array if none match)
  */
def listS3Files(bucket: String, s3Path: String): Array[S3ObjectSummary] = {
  try {
    val listObjects = client.listObjectsV2(bucket, s3Path)
    listObjects.getObjectSummaries.asScala.toArray
  } catch {
    case e: AmazonServiceException =>
      // Rethrow the original exception: wrapping it in a fresh AmazonServiceException
      // discarded the HTTP status code, AWS error type and original stack trace,
      // which callers (and logs) need to diagnose the failure.
      throw e
    case NonFatal(unexpectedException) =>
      // Same unexpected-error reporting convention as the other AWSGateway methods.
      val methodName = Thread.currentThread.getStackTrace()(1).getMethodName
      throw new Exception(
        s"Unexpected error in ${this.getClass.getName}.$methodName: $unexpectedException")
  }
}

override def close(): Unit = {
client.shutdown()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -680,7 +680,6 @@ trait WithGraphResolver {

if (localGraphPath.isDefined && OutputHelper
.getFileSystem(localGraphPath.get)
._1
.getScheme == "dbfs") {
files =
ResourceHelper.listLocalFiles(localGraphPath.get).map(file => file.getAbsolutePath)
Expand Down
4 changes: 2 additions & 2 deletions src/main/scala/com/johnsnowlabs/nlp/training/CoNLL.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ package com.johnsnowlabs.nlp.training

import com.johnsnowlabs.nlp.annotators.common.Annotated.{NerTaggedSentence, PosTaggedSentence}
import com.johnsnowlabs.nlp.annotators.common._
import com.johnsnowlabs.nlp.util.io.{ExternalResource, ReadAs, ResourceHelper}
import com.johnsnowlabs.nlp.util.io.{ExternalResource, OutputHelper, ReadAs, ResourceHelper}
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType, DocumentAssembler}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{Dataset, SparkSession}
Expand Down Expand Up @@ -352,7 +352,7 @@ case class CoNLL(
storageLevel: StorageLevel = StorageLevel.DISK_ONLY): Dataset[_] = {
if (path.endsWith("*")) {
val rdd = spark.sparkContext
.wholeTextFiles(path, minPartitions = parallelism)
.wholeTextFiles(OutputHelper.parsePath(path), minPartitions = parallelism)
.flatMap { case (_, content) =>
val lines = content.split(System.lineSeparator)
readLines(lines)
Expand Down
3 changes: 2 additions & 1 deletion src/main/scala/com/johnsnowlabs/nlp/training/POS.scala
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

package com.johnsnowlabs.nlp.training

import com.johnsnowlabs.nlp.util.io.OutputHelper
import com.johnsnowlabs.nlp.{Annotation, AnnotatorType}
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.functions.{col, concat_ws, udf}
Expand Down Expand Up @@ -166,7 +167,7 @@ case class POS() {
require(delimiter.length == 1, s"Delimiter must be one character long. Received $delimiter")

val dataset = sparkSession.read
.textFile(path)
.textFile(OutputHelper.parsePath(path))
.filter(_.nonEmpty)
.map(line => lineToTaggedDocument(line, delimiter))
.map { case TaggedDocument(sentence, taggedTokens) =>
Expand Down
85 changes: 79 additions & 6 deletions src/main/scala/com/johnsnowlabs/nlp/util/io/OutputHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ import org.apache.spark.SparkFiles

import java.io.{File, FileWriter, PrintWriter}
import java.nio.charset.StandardCharsets
import scala.language.existentials
import scala.util.{Failure, Success, Try}

object OutputHelper {

Expand All @@ -34,13 +34,86 @@ object OutputHelper {
/** File system backed by the active Spark session's Hadoop configuration. */
def getFileSystem: FileSystem =
  FileSystem.get(sparkSession.sparkContext.hadoopConfiguration)
/** Resolves the file system that serves `resource`, after normalizing its
  * scheme via `parsePath` (e.g. s3 -> s3a).
  */
def getFileSystem(resource: String): FileSystem = {
  val normalizedUri = new Path(parsePath(resource)).toUri
  FileSystem.get(normalizedUri, sparkSession.sparkContext.hadoopConfiguration)
}

def getFileSystem(resource: String): (FileSystem, Path) = {
val resourcePath = new Path(resource)
val fileSystem =
FileSystem.get(resourcePath.toUri, sparkSession.sparkContext.hadoopConfiguration)
/** Normalizes a resource path for Hadoop/Spark consumption.
  *
  * `s3://` URIs are rewritten to the `s3a://` connector scheme, and `file:`
  * URIs are normalized to exactly three slashes; every other path is returned
  * unchanged.
  *
  * @param path
  *   raw resource path or URI
  * @return
  *   normalized path
  */
def parsePath(path: String): String = {
  val pathPrefix = path.split("://").head
  pathPrefix match {
    case "s3" =>
      // Rewrite only the scheme. The previous `path.replace("s3", "s3a")`
      // replaced EVERY "s3" occurrence, corrupting bucket names or keys that
      // contain "s3" (e.g. s3://bucket/s3data -> s3a://bucket/s3adata).
      path.replaceFirst("^s3://", "s3a://")
    case "file" =>
      // Collapse any number of slashes after "file:" into the canonical three.
      val pattern = """^file:(/+)""".r
      pattern.replaceAllIn(path, "file:///")
    case _ => path
  }
}

/** Checks whether `resource` exists on its file system.
  *
  * For local (`file`) paths, several slash-count variants of the path are
  * probed in order, because callers pass `file:` URIs with inconsistent
  * numbers of slashes. For remote schemes the path is normalized once via
  * `parsePath`.
  *
  * @param resource
  *   raw resource path or URI
  * @return
  *   (exists, Some(path-variant-that-exists)) or (false, None)
  */
def doesPathExists(resource: String): (Boolean, Option[Path]) = {
  val fileSystem = OutputHelper.getFileSystem(resource)

  // fileSystem.exists may throw on malformed paths; treat any failure as "not found",
  // matching the previous Try/Failure => false behavior.
  def existsQuietly(candidate: String): Boolean =
    Try(fileSystem.exists(new Path(candidate))).getOrElse(false)

  fileSystem.getScheme match {
    case "file" =>
      // Probe the same fallback variants, in the same order, as before.
      val candidates = Seq(
        resource,
        resource.replaceFirst("//+", "///"),
        resource.replaceFirst("/+", "//"),
        """^file:/*""".r.replaceAllIn(resource, ""))
      candidates.find(existsQuietly) match {
        case Some(found) => (true, Some(new Path(found)))
        case None => (false, None)
      }
    case _ =>
      // Remote schemes (s3a, hdfs, dbfs, ...): normalize once and probe.
      // Bug fix: the previous version declared `val modifiedPath` INSIDE the
      // Try block, shadowing the outer var, so the returned Path was built
      // from the unparsed resource (e.g. still "s3://" instead of "s3a://").
      val parsed = parsePath(resource)
      if (existsQuietly(parsed)) (true, Some(new Path(parsed)))
      else (false, None)
  }
}

private def getLogsFolder: String =
Expand Down
56 changes: 34 additions & 22 deletions src/main/scala/com/johnsnowlabs/nlp/util/io/ResourceHelper.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,14 +17,15 @@
package com.johnsnowlabs.nlp.util.io

import com.amazonaws.AmazonServiceException
import com.johnsnowlabs.client.aws.AWSGateway
import com.johnsnowlabs.nlp.annotators.Tokenizer
import com.johnsnowlabs.nlp.annotators.common.{TaggedSentence, TaggedWord}
import com.johnsnowlabs.nlp.pretrained.ResourceDownloader
import com.johnsnowlabs.nlp.util.io.ReadAs._
import com.johnsnowlabs.nlp.{DocumentAssembler, Finisher}
import com.johnsnowlabs.util.ConfigHelper
import org.apache.commons.io.{FileUtils, IOUtils}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.ml.{Pipeline, PipelineModel}
import org.apache.spark.sql.{DataFrame, Dataset, SparkSession}

Expand Down Expand Up @@ -104,26 +105,37 @@ object ResourceHelper {
/** Structure for a SourceStream coming from compiled content */
case class SourceStream(resource: String) {

val (fileSystem, path) = OutputHelper.getFileSystem(resource)
if (!fileSystem.exists(path)) {
var fileSystem: Option[FileSystem] = None
private val (pathExists, path) = OutputHelper.doesPathExists(resource)
if (!pathExists) {
throw new FileNotFoundException(s"file or folder: $resource not found")
} else {
fileSystem = Some(OutputHelper.getFileSystem(resource))
}

val pipe: Seq[InputStream] = {

/** Check whether it exists in file system */
val files = fileSystem.listFiles(path, true)
val buffer = ArrayBuffer.empty[InputStream]
while (files.hasNext) buffer.append(fileSystem.open(files.next().getPath))
buffer
}

val openBuffers: Seq[BufferedSource] = pipe.map(pp => {
val pipe: Seq[InputStream] = getPipe(fileSystem.get)
private val openBuffers: Seq[BufferedSource] = pipe.map(pp => {
new BufferedSource(pp)("UTF-8")
})

val content: Seq[Iterator[String]] = openBuffers.map(c => c.getLines())

/** Opens one InputStream per file found under `path`.
  *
  * For the `s3a` scheme the objects are listed and opened through the AWS SDK
  * via [[AWSGateway]]; for every other scheme the files are listed recursively
  * through the Hadoop file system and opened directly.
  */
private def getPipe(fileSystem: FileSystem): Seq[InputStream] = {
  fileSystem.getScheme match {
    case "s3a" =>
      val awsGateway = new AWSGateway()
      val (bucket, s3Path) = parseS3URI(path.get.toString)
      awsGateway
        .listS3Files(bucket, s3Path)
        .map(summary => awsGateway.getS3Object(bucket, summary.getKey).getObjectContent)
        .toSeq
    case _ =>
      val remoteFiles = fileSystem.listFiles(path.get, true)
      val streams = ArrayBuffer.empty[InputStream]
      while (remoteFiles.hasNext) streams += fileSystem.open(remoteFiles.next().getPath)
      streams
  }
}

/** Copies the resource into a local temporary folder and returns the folders URI.
*
* @param prefix
Expand All @@ -132,16 +144,16 @@ object ResourceHelper {
* URI of the created temporary folder with the resource
*/
def copyToLocal(prefix: String = "sparknlp_tmp_"): URI = {
if (fileSystem.getScheme == "file")
if (fileSystem.get.getScheme == "file")
return URI.create(resource)

val destination: file.Path = Files.createTempDirectory(prefix)

val destinationUri = fileSystem.getScheme match {
val destinationUri = fileSystem.get.getScheme match {
case "hdfs" =>
fileSystem.copyToLocalFile(false, path, new Path(destination.toUri), true)
if (fileSystem.getFileStatus(path).isDirectory)
Paths.get(destination.toString, path.getName).toUri
fileSystem.get.copyToLocalFile(false, path.get, new Path(destination.toUri), true)
if (fileSystem.get.getFileStatus(path.get).isDirectory)
Paths.get(destination.toString, path.get.getName).toUri
else destination.toUri
case "dbfs" =>
val dbfsPath = path.toString.replace("dbfs:/", "/dbfs/")
Expand All @@ -151,9 +163,9 @@ object ResourceHelper {
else FileUtils.copyDirectory(sourceFile, targetFile)
targetFile.toURI
case _ =>
val files = fileSystem.listFiles(path, false)
val files = fileSystem.get.listFiles(path.get, false)
while (files.hasNext) {
fileSystem.copyFromLocalFile(files.next.getPath, new Path(destination.toUri))
fileSystem.get.copyFromLocalFile(files.next.getPath, new Path(destination.toUri))
}
destination.toUri
}
Expand Down Expand Up @@ -719,7 +731,7 @@ object ResourceHelper {

def moveFile(sourceFile: String, destinationFile: String): Unit = {

val (sourceFileSystem, _) = OutputHelper.getFileSystem(sourceFile)
val sourceFileSystem = OutputHelper.getFileSystem(sourceFile)

if (destinationFile.startsWith("s3:")) {
val s3Bucket = destinationFile.replace("s3://", "").split("/").head
Expand Down
Loading