From 7e18f643a4979394bd4a68379812991284dcc832 Mon Sep 17 00:00:00 2001 From: tballard Date: Wed, 8 Jun 2022 14:51:42 -0600 Subject: [PATCH] fix compression and tests therefore (#108) * fix compression and tests therefore * Comply with recommended changes * Revert some formatting changes and rename variable * Move echo server to a separate file * Remove unused import * Fix Mima * Fix compilation error in Scala 3 Co-authored-by: Tom Ballard Co-authored-by: Lorenzo Gabriele --- build.sc | 9 +- mill | 2 +- requests/src/requests/Requester.scala | 20 +++-- .../src-2/requests/Scala2RequestTests.scala | 2 + requests/test/src/requests/RequestTests.scala | 86 +++++++++++++------ requests/test/src/requests/ServerUtils.scala | 77 +++++++++++++++++ 6 files changed, 155 insertions(+), 41 deletions(-) create mode 100644 requests/test/src/requests/ServerUtils.scala diff --git a/build.sc b/build.sc index b084e44..96a8241 100644 --- a/build.sc +++ b/build.sc @@ -1,9 +1,9 @@ import mill._ import mill.scalalib.publish.{Developer, License, PomSettings, VersionControl} import scalalib._ -import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version_mill0.9:0.1.1` +import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version::0.1.4` import de.tobiasroeser.mill.vcs.version.VcsVersion -import $ivy.`com.github.lolgab::mill-mima_mill0.9:0.0.4` +import $ivy.`com.github.lolgab::mill-mima::0.0.10` import com.github.lolgab.mill.mima._ val dottyVersion = Option(sys.props("dottyVersion")) @@ -11,7 +11,10 @@ val dottyVersion = Option(sys.props("dottyVersion")) object requests extends Cross[RequestsModule]((List("2.12.13", "2.13.5", "2.11.12", "3.0.0") ++ dottyVersion): _*) class RequestsModule(val crossScalaVersion: String) extends CrossScalaModule with PublishModule with Mima { def publishVersion = VcsVersion.vcsState().format() - def mimaPreviousVersions = VcsVersion.vcsState().lastTag.toSeq + def mimaPreviousVersions = Seq("0.7.0") ++ VcsVersion.vcsState().lastTag.toSeq + override def mimaBinaryIssueFilters = Seq( + ProblemFilter.exclude[ReversedMissingMethodProblem]("requests.BaseSession.send") + ) def artifactName = "requests" def pomSettings = PomSettings( description = "Scala port of the popular Python Requests HTTP client", diff --git a/mill b/mill index d66be9e..c379d48 100755 --- a/mill +++ b/mill @@ -3,7 +3,7 @@ # This is a wrapper script, that automatically download mill from GitHub release pages # You can give the required mill version with MILL_VERSION env variable # If no version is given, it falls back to the value of DEFAULT_MILL_VERSION -DEFAULT_MILL_VERSION=0.9.7 +DEFAULT_MILL_VERSION=0.9.12 set -e diff --git a/requests/src/requests/Requester.scala b/requests/src/requests/Requester.scala index 54f345a..0534768 100644 --- a/requests/src/requests/Requester.scala +++ b/requests/src/requests/Requester.scala @@ -9,7 +9,6 @@ import scala.collection.mutable trait BaseSession{ def headers: Map[String, String] - def cookies: mutable.Map[String, HttpCookie] def readTimeout: Int def connectTimeout: Int @@ -54,7 +53,7 @@ object Requester{ } case class Requester(verb: String, sess: BaseSession){ - + private val upperCaseVerb = verb.toUpperCase /** * Makes a single HTTP request, and returns a [[Response]] object. Requires @@ -204,7 +203,6 @@ case class Requester(verb: String, } connection.setInstanceFollowRedirects(false) - val upperCaseVerb = verb.toUpperCase if (Requester.officialHttpMethods.contains(upperCaseVerb)) { connection.setRequestMethod(upperCaseVerb) } else { @@ -250,17 +248,18 @@ case class Requester(verb: String, .map{case (k, v) => s"""$k="$v""""} .mkString("; ") ) - } - if (verb.toUpperCase == "POST" || verb.toUpperCase == "PUT" || verb.toUpperCase == "PATCH" || verb.toUpperCase == "DELETE") { + } + + if (upperCaseVerb == "POST" || upperCaseVerb == "PUT" || upperCaseVerb == "PATCH" || upperCaseVerb == "DELETE") { if (!chunkedUpload) { val bytes = new ByteArrayOutputStream() - data.write(compress.wrap(bytes)) + usingOutputStream(compress.wrap(bytes)) { os => data.write(os) } val byteArray = bytes.toByteArray connection.setFixedLengthStreamingMode(byteArray.length) - if (byteArray.nonEmpty) connection.getOutputStream.write(byteArray) + usingOutputStream(connection.getOutputStream) { os => os.write(byteArray) } } else { connection.setChunkedStreamingMode(0) - data.write(compress.wrap(connection.getOutputStream)) + usingOutputStream(compress.wrap(connection.getOutputStream)) { os => data.write(os) } } } @@ -333,7 +332,7 @@ case class Requester(verb: String, // The HEAD method is identical to GET except that the server // MUST NOT return a message-body in the response. // https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html section 9.4 - if (verb == "HEAD") f(new ByteArrayInputStream(Array())) + if (upperCaseVerb == "HEAD") f(new ByteArrayInputStream(Array())) else if (stream != null) { try f( if (deGzip) new GZIPInputStream(stream) @@ -366,6 +365,9 @@ case class Requester(verb: String, } } } + + private def usingOutputStream[T](os: OutputStream)(fn: OutputStream => T): Unit = + try fn(os) finally os.close() /** * Overload of [[Requester.apply]] that takes a [[Request]] object as configuration diff --git a/requests/test/src-2/requests/Scala2RequestTests.scala b/requests/test/src-2/requests/Scala2RequestTests.scala index 63d73de..9a9d787 100644 --- a/requests/test/src-2/requests/Scala2RequestTests.scala +++ b/requests/test/src-2/requests/Scala2RequestTests.scala @@ -18,6 +18,7 @@ object Scala2RequestTests extends TestSuite{ assert(read(res1).obj("form") == Obj("foo" -> "baz", "hello" -> "world")) } } + test("put") { for (chunkedUpload <- Seq(true, false)) { val res1 = requests.put( @@ -28,6 +29,7 @@ object Scala2RequestTests extends TestSuite{ assert(read(res1).obj("form") == Obj("foo" -> "baz", "hello" -> "world")) } } + test("send"){ requests.send("get")("https://httpbin.org/get?hello=world&foo=baz") diff --git a/requests/test/src/requests/RequestTests.scala b/requests/test/src/requests/RequestTests.scala index 98b94b4..6f63f75 100644 --- a/requests/test/src/requests/RequestTests.scala +++ b/requests/test/src/requests/RequestTests.scala @@ -30,6 +30,7 @@ object RequestTests extends TestSuite{ } } } + test("params"){ test("get"){ // All in URL @@ -58,6 +59,7 @@ object RequestTests extends TestSuite{ assert(read(res4).obj("args") == Obj("++-- lol" -> " !@#$%", "hello" -> "world")) } } + test("multipart"){ for(chunkedUpload <- Seq(true, false)) { val response = requests.post( @@ -73,8 +75,8 @@ object RequestTests extends TestSuite{ assert(read(response).obj("form") == Obj("file2" -> "Goodbye!")) } } - test("cookies"){ + test("cookies"){ test("session"){ val s = requests.Session(cookieValues = Map("hello" -> "world")) val res1 = s.get("https://httpbin.org/cookies").text().trim @@ -99,35 +101,37 @@ object RequestTests extends TestSuite{ assert(read(res2) == Obj("cookies" -> Obj("freeform" -> "test test", "hello" -> "hello, world"))) } } - // Tests fail with 'Request to https://httpbin.org/absolute-redirect/4 failed with status code 404' - // test("redirects"){ - // test("max"){ - // val res1 = requests.get("https://httpbin.org/absolute-redirect/4") - // assert(res1.statusCode == 200) - // val res2 = requests.get("https://httpbin.org/absolute-redirect/5") - // assert(res2.statusCode == 200) - // val res3 = requests.get("https://httpbin.org/absolute-redirect/6", check = false) - // assert(res3.statusCode == 302) - // val res4 = requests.get("https://httpbin.org/absolute-redirect/6", maxRedirects = 10) - // assert(res4.statusCode == 200) - // } - // test("maxRelative"){ - // val res1 = requests.get("https://httpbin.org/relative-redirect/4") - // assert(res1.statusCode == 200) - // val res2 = requests.get("https://httpbin.org/relative-redirect/5") - // assert(res2.statusCode == 200) - // val res3 = requests.get("https://httpbin.org/relative-redirect/6", check = false) - // assert(res3.statusCode == 302) - // val res4 = requests.get("https://httpbin.org/relative-redirect/6", maxRedirects = 10) - // assert(res4.statusCode == 200) - // } - // } + + test("redirects"){ + test("max"){ + val res1 = requests.get("https://httpbin.org/absolute-redirect/4") + assert(res1.statusCode == 200) + val res2 = requests.get("https://httpbin.org/absolute-redirect/5") + assert(res2.statusCode == 200) + val res3 = requests.get("https://httpbin.org/absolute-redirect/6", check = false) + assert(res3.statusCode == 302) + val res4 = requests.get("https://httpbin.org/absolute-redirect/6", maxRedirects = 10) + assert(res4.statusCode == 200) + } + test("maxRelative"){ + val res1 = requests.get("https://httpbin.org/relative-redirect/4") + assert(res1.statusCode == 200) + val res2 = requests.get("https://httpbin.org/relative-redirect/5") + assert(res2.statusCode == 200) + val res3 = requests.get("https://httpbin.org/relative-redirect/6", check = false) + assert(res3.statusCode == 302) + val res4 = requests.get("https://httpbin.org/relative-redirect/6", maxRedirects = 10) + assert(res4.statusCode == 200) + } + } + test("streaming"){ val res1 = requests.get("http://httpbin.org/stream/5").text() assert(res1.linesIterator.length == 5) val res2 = requests.get("http://httpbin.org/stream/52").text() assert(res2.linesIterator.length == 52) } + test("timeouts"){ test("read"){ intercept[TimeoutException] { @@ -144,6 +148,7 @@ object RequestTests extends TestSuite{ } } } + test("failures"){ intercept[UnknownHostException]{ requests.get("https://doesnt-exist-at-all.com/") @@ -156,6 +161,7 @@ object RequestTests extends TestSuite{ requests.get("://doesnt-exist.com/") } } + test("decompress"){ val res1 = requests.get("https://httpbin.org/gzip") assert(read(res1.text()).obj("headers").obj("Host").str == "httpbin.org") @@ -171,6 +177,7 @@ object RequestTests extends TestSuite{ (res1.bytes.length, res2.bytes.length, res3.bytes.length, res4.bytes.length) } + test("compression"){ val res1 = requests.post( "https://httpbin.org/post", @@ -184,16 +191,18 @@ object RequestTests extends TestSuite{ compress = requests.Compress.Gzip, data = new RequestBlob.ByteSourceRequestBlob("I am cow") ) - assert(res2.text().contains("data:application/octet-stream;base64,H4sIAAAAAAAAAA==")) + assert(read(new String(res2.bytes))("data").toString == + """"data:application/octet-stream;base64,H4sIAAAAAAAAAPNUSMxVSM4vBwCAGeD4CAAAAA=="""") val res3 = requests.post( "https://httpbin.org/post", compress = requests.Compress.Deflate, data = new RequestBlob.ByteSourceRequestBlob("Hear me moo") ) - assert(res3.text().contains("data:application/octet-stream;base64,eJw=")) - res3.text() - } + assert(read(new String(res2.bytes))("data").toString == + """"data:application/octet-stream;base64,H4sIAAAAAAAAAPNUSMxVSM4vBwCAGeD4CAAAAA=="""") + } + test("headers"){ test("default"){ val res = requests.get("https://httpbin.org/headers").text() @@ -207,6 +216,7 @@ object RequestTests extends TestSuite{ } } } + test("clientCertificate"){ val base = "./requests/test/resources" val url = "https://client.badssl.com" @@ -253,6 +263,7 @@ object RequestTests extends TestSuite{ assert(res.statusCode == 400) } } + test("selfSignedCertificate"){ val res = requests.get( "https://self-signed.badssl.com", @@ -260,6 +271,7 @@ object RequestTests extends TestSuite{ ) assert(res.statusCode == 200) } + test("gzipError"){ val response = requests.head("https://api.github.com/users/lihaoyi") assert(response.statusCode == 200) @@ -268,5 +280,23 @@ object RequestTests extends TestSuite{ assert(response.headers.keySet.map(_.toLowerCase).contains("content-length")) assert(response.headers.keySet.map(_.toLowerCase).contains("content-type")) } + + /** + * Compress with each compression mode and call server. Server expands + * and passes it back so we can compare + */ + test("compressionData") { + import requests.Compress._ + val str = "I am deflater mouse" + Seq(None, Gzip, Deflate).foreach(c => + ServerUtils.usingEchoServer { port => + assert(str == requests.post( + s"http://localhost:$port/echo", + compress = c, + data = new RequestBlob.ByteSourceRequestBlob(str) + ).data.toString) + } + ) + } } } diff --git a/requests/test/src/requests/ServerUtils.scala b/requests/test/src/requests/ServerUtils.scala new file mode 100644 index 0000000..d9ceebd --- /dev/null +++ b/requests/test/src/requests/ServerUtils.scala @@ -0,0 +1,77 @@ +package requests + +import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer} +import java.io._ +import java.net.InetSocketAddress +import java.util.zip.{GZIPInputStream, InflaterInputStream} +import requests.Compress._ +import scala.annotation.tailrec +import scala.collection.mutable.StringBuilder + +object ServerUtils { + def usingEchoServer(f: Int => Unit): Unit = { + val server = new EchoServer + try f(server.getPort()) + finally server.stop() + } + + private class EchoServer extends HttpHandler { + private val server: HttpServer = HttpServer.create(new InetSocketAddress(0), 0) + server.createContext("/echo", this) + server.setExecutor(null); // default executor + server.start() + + def getPort(): Int = server.getAddress.getPort + + def stop(): Unit = server.stop(0) + + override def handle(t: HttpExchange): Unit = { + val h: java.util.List[String] = + t.getRequestHeaders.get("Content-encoding") + val c: Compress = + if (h == null) None + else if (h.contains("gzip")) Gzip + else if (h.contains("deflate")) Deflate + else None + val msg = new Plumper(c).decompress(t.getRequestBody) + t.sendResponseHeaders(200, msg.length) + t.getResponseBody.write(msg.getBytes()) + } + } + + /** Stream uncompresser + * @param c + * Compression mode + */ + private class Plumper(c: Compress) { + + private def wrap(is: InputStream): InputStream = + c match { + case None => is + case Gzip => new GZIPInputStream(is) + case Deflate => new InflaterInputStream(is) + } + + def decompress(compressed: InputStream): String = { + val gis = wrap(compressed) + val br = new BufferedReader(new InputStreamReader(gis, "UTF-8")) + val sb = new StringBuilder() + + @tailrec + def read(): Unit = { + val line = br.readLine + if (line != null) { + sb.append(line) + read() + } + } + + read() + br.close() + gis.close() + compressed.close() + sb.toString() + } + } + +}