From 783d1f87832da2e70edcc69fbc0523a12d9c6ad6 Mon Sep 17 00:00:00 2001 From: hellozepp Date: Wed, 10 Aug 2022 09:43:47 +0800 Subject: [PATCH] Support callback header (#1808) * Support runscript callbackHeader --- .../java/tech/mlsql/crawler/RestUtils.scala | 14 ++++--- .../tech/mlsql/it/ByzerScriptTestSuite.scala | 42 ++++++++++++++++--- .../java/streaming/rest/RestController.scala | 14 ++++++- 3 files changed, 58 insertions(+), 12 deletions(-) diff --git a/streamingpro-core/src/main/java/tech/mlsql/crawler/RestUtils.scala b/streamingpro-core/src/main/java/tech/mlsql/crawler/RestUtils.scala index 40b47cbe4..57169b2ac 100644 --- a/streamingpro-core/src/main/java/tech/mlsql/crawler/RestUtils.scala +++ b/streamingpro-core/src/main/java/tech/mlsql/crawler/RestUtils.scala @@ -4,13 +4,13 @@ import net.csdn.common.path.Url import net.csdn.modules.transport.HttpTransportService.SResponse import net.csdn.modules.transport.{DefaultHttpTransportService, HttpTransportService} import org.apache.commons.lang3.exception.ExceptionUtils -import org.apache.http.{HttpEntity, HttpResponse} import org.apache.http.client.entity.UrlEncodedFormEntity import org.apache.http.client.fluent.{Form, Request} import org.apache.http.entity.ContentType import org.apache.http.entity.mime.{HttpMultipartMode, MultipartEntityBuilder} import org.apache.http.message.BasicNameValuePair import org.apache.http.util.EntityUtils +import org.apache.http.{HttpEntity, HttpResponse} import streaming.dsl.ScriptSQLExec import streaming.log.WowLog import tech.mlsql.common.JsonUtils @@ -22,17 +22,21 @@ import tech.mlsql.tool.{HDFSOperatorV2, Templates2} import java.nio.charset.Charset import scala.annotation.tailrec import scala.collection.JavaConversions._ -import scala.util.control.Breaks.{break, breakable} object RestUtils extends Logging with WowLog { - def httpClientPost(urlString: String, data: Map[String, String]): HttpResponse = { + def httpClientPost(urlString: String, data: Map[String, String], headers: Map[String, String]): HttpResponse = { val nameValuePairs = data.map { case (name, value) => new BasicNameValuePair(name, value) }.toList - Request.Post(urlString) + val req = Request.Post(urlString) .addHeader("Content-Type", "application/x-www-form-urlencoded") - .body(new UrlEncodedFormEntity(nameValuePairs, DefaultHttpTransportService.charset)) + + headers foreach { case (name, value) => + req.setHeader(name, value) + } + + req.body(new UrlEncodedFormEntity(nameValuePairs, DefaultHttpTransportService.charset)) .execute() .returnResponse() } diff --git a/streamingpro-it/src/test/scala/tech/mlsql/it/ByzerScriptTestSuite.scala b/streamingpro-it/src/test/scala/tech/mlsql/it/ByzerScriptTestSuite.scala index 5df8cd2b5..ec4efa09c 100644 --- a/streamingpro-it/src/test/scala/tech/mlsql/it/ByzerScriptTestSuite.scala +++ b/streamingpro-it/src/test/scala/tech/mlsql/it/ByzerScriptTestSuite.scala @@ -1,5 +1,8 @@ package tech.mlsql.it +import net.csdn.modules.transport.DefaultHttpTransportService +import org.apache.http.HttpEntity +import org.apache.http.util.EntityUtils import tech.mlsql.common.utils.log.Logging import tech.mlsql.crawler.RestUtils import tech.mlsql.it.contiainer.ByzerCluster @@ -8,6 +11,7 @@ import tech.mlsql.it.utils.DockerUtils.getCurProjectRootPath import java.io.File import java.util.UUID +import scala.collection.mutable /** * 23/02/2022 hellozepp(lisheng.zhanglin@163.com) @@ -54,11 +58,27 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging { }) } - def runScript(url: String, user: String, code: String): (Int, String) = { + def runScript(url: String, user: String, code: String, callbackHeader: String = ""): (Int, String) = { val jobName = UUID.randomUUID().toString + val params = mutable.Map("sql" -> code, "owner" -> user, + "jobName" -> jobName, "sessionPerUser" -> "true", "sessionPerRequest" -> "true") + if (callbackHeader != "") params.put("callbackHeader", callbackHeader) logInfo(s"The test submits a script to the container through Rest, url:$url, sql:$code") - val (status, result) = RestUtils.rest_request_string(url, "post", Map("sql" -> code, "owner" -> user, - "jobName" -> jobName, "sessionPerUser" -> "true", "sessionPerRequest" -> "true"), + val (status, result) = RestUtils.rest_request_string(url, "post", params.toMap, + Map("Content-Type" -> "application/x-www-form-urlencoded"), Map("socket-timeout" -> "1800s", + "connect-timeout" -> "1800s", "retry" -> "1") + ) + logInfo(s"status:$status,result:$result") + (status, result) + } + + def runScriptWithHeader(url: String, user: String, code: String, callbackHeader: String = ""): (Int, HttpEntity) = { + val jobName = UUID.randomUUID().toString + val params = mutable.Map("sql" -> code, "owner" -> user, + "jobName" -> jobName, "sessionPerUser" -> "true", "sessionPerRequest" -> "true") + if (callbackHeader != "") params.put("callbackHeader", callbackHeader) + logInfo(s"The test submits a script to the container through Rest, url:$url, sql:$code") + val (status, result) = RestUtils.rest_request(url, "post", params.toMap, Map("Content-Type" -> "application/x-www-form-urlencoded"), Map("socket-timeout" -> "1800s", "connect-timeout" -> "1800s", "retry" -> "1") ) @@ -82,7 +102,7 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging { val cluster: ByzerCluster = setupCluster() val hadoopContainer = cluster.hadoopContainer val byzerLangContainer = cluster.byzerLangContainer - val javaContainer = cluster.byzerLangContainer.container + val javaContainer = byzerLangContainer.container url = s"http://${javaContainer.getHost}:${javaContainer.getMappedPort(9003)}/run/script" test("javaContainer") { @@ -101,6 +121,19 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging { } test("Execute yarn sql file") { + try { + val (_, result) = runScriptWithHeader(url, user, "select 1 as a,'jack' as b as bbc;", + """{"Authorization":"Bearer acc"}""") + val _result = EntityUtils.toString(result, DefaultHttpTransportService.charset) + println("With callbackHeader result:" + _result) + assert(_result === "[{\"a\":1,\"b\":\"jack\"}]") + } catch { + case _: Exception => + val res = "callbackHeader should be returned normally in the byzer callback!" + logError(res) + throw new RuntimeException(res) + } + TestManager.testCases.foreach(testCase => { try { val (status, result) = runScript(url, user, testCase.sql) @@ -110,7 +143,6 @@ class ByzerScriptTestSuite extends LocalBaseTestSuite with Logging { TestManager.acceptRest(testCase, 500, null, e) } }) - TestManager.report() } diff --git a/streamingpro-mlsql/src/main/java/streaming/rest/RestController.scala b/streamingpro-mlsql/src/main/java/streaming/rest/RestController.scala index 4b902dcc3..d89fde2d2 100644 --- a/streamingpro-mlsql/src/main/java/streaming/rest/RestController.scala +++ b/streamingpro-mlsql/src/main/java/streaming/rest/RestController.scala @@ -38,6 +38,7 @@ import org.apache.spark.sql.mlsql.session.{MLSQLSparkSession, SparkSessionCacheM import org.apache.spark.{MLSQLConf, SparkInstanceService} import tech.mlsql.MLSQLEnvKey import tech.mlsql.app.{CustomController, ResultResp} +import tech.mlsql.common.JsonUtils import tech.mlsql.common.utils.log.Logging import tech.mlsql.common.utils.serder.json.JSONTool import tech.mlsql.crawler.RestUtils @@ -105,6 +106,7 @@ class RestController extends ApplicationController with WowLog with Logging { new Parameter(name = "sessionPerRequest", required = false, description = "by default false", `type` = "boolean", allowEmptyValue = false), new Parameter(name = "async", required = false, description = "If set true ,please also provide a callback url use `callback` parameter and the job will run in background and the API will return. default: false", `type` = "boolean", allowEmptyValue = false), new Parameter(name = "callback", required = false, description = "Used when async is set true. callback is a url. default: false", `type` = "string", allowEmptyValue = false), + new Parameter(name = "callbackHeader", required = false, description = "Provide a jsonString parameter to set the header parameter of the callback request. default: false", `type` = "string", allowEmptyValue = false), new Parameter(name = "maxRetries", required = false, description = "Max retries of request callback.", `type` = "int", allowEmptyValue = false), new Parameter(name = "skipInclude", required = false, description = "disable include statement. default: false", `type` = "boolean", allowEmptyValue = false), new Parameter(name = "skipAuth", required = false, description = "disable table authorize . default: true", `type` = "boolean", allowEmptyValue = false), @@ -147,6 +149,12 @@ class RestController extends ApplicationController with WowLog with Logging { if (paramAsBoolean("async", false)) { JobManager.asyncRun(sparkSession, jobInfo, () => { val urlString = param("callback") + val callbackHeaderString = param("callbackHeader") + var callbackHeader = Map[String,String]() + if (callbackHeaderString != null && callbackHeaderString.nonEmpty){ + callbackHeader = JsonUtils.fromJson[Map[String,String]](callbackHeaderString) + } + val maxTries = Math.max(0, paramAsInt("maxRetries", -1)) + 1 try { ScriptSQLExec.parse(param("sql"), context, @@ -161,7 +169,8 @@ class RestController extends ApplicationController with WowLog with Logging { RestUtils.httpClientPost(urlString, Map("stat" -> s"""succeeded""", "res" -> outputResult, - "jobInfo" -> JSONTool.toJsonStr(jobInfo))), + "jobInfo" -> JSONTool.toJsonStr(jobInfo)), + callbackHeader), HttpStatus.SC_OK == _.getStatusLine.getStatusCode, response => logger.error(s"Succeeded SQL callback request failed after ${maxTries} attempts, " + s"the last response status is: ${response.getStatusLine.getStatusCode}.") @@ -178,7 +187,8 @@ class RestController extends ApplicationController with WowLog with Logging { RestUtils.httpClientPost(urlString, Map("stat" -> s"""failed""", "msg" -> (e.getMessage + "\n" + msgBuffer.mkString("\n")), - "jobInfo" -> JSONTool.toJsonStr(jobInfo))), + "jobInfo" -> JSONTool.toJsonStr(jobInfo)), + callbackHeader), HttpStatus.SC_OK == _.getStatusLine.getStatusCode, response => logger.error(s"Fail SQL callback request failed after ${maxTries} attempts, " + s"the last response status is: ${response.getStatusLine.getStatusCode}.")