Skip to content

Commit

Permalink
fix compression and tests therefore (#108)
Browse files Browse the repository at this point in the history
* fix compression and tests therefore

* Comply with recommended changes

* Revert some formatting changes and rename variable

* Move echo server to a separate file

* Remove unused import

* Fix Mima

* Fix compilation error in Scala 3

Co-authored-by: Tom Ballard <tomballard@redangus.org>
Co-authored-by: Lorenzo Gabriele <lorenzolespaul@gmail.com>
  • Loading branch information
3 people authored Jun 8, 2022
1 parent 11a9787 commit 7e18f64
Show file tree
Hide file tree
Showing 6 changed files with 155 additions and 41 deletions.
9 changes: 6 additions & 3 deletions build.sc
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import mill._
import mill.scalalib.publish.{Developer, License, PomSettings, VersionControl}
import scalalib._
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version_mill0.9:0.1.1`
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version::0.1.4`
import de.tobiasroeser.mill.vcs.version.VcsVersion
import $ivy.`com.github.lolgab::mill-mima_mill0.9:0.0.4`
import $ivy.`com.github.lolgab::mill-mima::0.0.10`
import com.github.lolgab.mill.mima._

val dottyVersion = Option(sys.props("dottyVersion"))

object requests extends Cross[RequestsModule]((List("2.12.13", "2.13.5", "2.11.12", "3.0.0") ++ dottyVersion): _*)
class RequestsModule(val crossScalaVersion: String) extends CrossScalaModule with PublishModule with Mima {
def publishVersion = VcsVersion.vcsState().format()
def mimaPreviousVersions = VcsVersion.vcsState().lastTag.toSeq
def mimaPreviousVersions = Seq("0.7.0") ++ VcsVersion.vcsState().lastTag.toSeq
override def mimaBinaryIssueFilters = Seq(
ProblemFilter.exclude[ReversedMissingMethodProblem]("requests.BaseSession.send")
)
def artifactName = "requests"
def pomSettings = PomSettings(
description = "Scala port of the popular Python Requests HTTP client",
Expand Down
2 changes: 1 addition & 1 deletion mill
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This is a wrapper script, that automatically download mill from GitHub release pages
# You can give the required mill version with MILL_VERSION env variable
# If no version is given, it falls back to the value of DEFAULT_MILL_VERSION
DEFAULT_MILL_VERSION=0.9.7
DEFAULT_MILL_VERSION=0.9.12

set -e

Expand Down
20 changes: 11 additions & 9 deletions requests/src/requests/Requester.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import scala.collection.mutable

trait BaseSession{
def headers: Map[String, String]

def cookies: mutable.Map[String, HttpCookie]
def readTimeout: Int
def connectTimeout: Int
Expand Down Expand Up @@ -54,7 +53,7 @@ object Requester{
}
case class Requester(verb: String,
sess: BaseSession){

private val upperCaseVerb = verb.toUpperCase

/**
* Makes a single HTTP request, and returns a [[Response]] object. Requires
Expand Down Expand Up @@ -204,7 +203,6 @@ case class Requester(verb: String,
}

connection.setInstanceFollowRedirects(false)
val upperCaseVerb = verb.toUpperCase
if (Requester.officialHttpMethods.contains(upperCaseVerb)) {
connection.setRequestMethod(upperCaseVerb)
} else {
Expand Down Expand Up @@ -250,17 +248,18 @@ case class Requester(verb: String,
.map{case (k, v) => s"""$k="$v""""}
.mkString("; ")
)
}
if (verb.toUpperCase == "POST" || verb.toUpperCase == "PUT" || verb.toUpperCase == "PATCH" || verb.toUpperCase == "DELETE") {
}

if (upperCaseVerb == "POST" || upperCaseVerb == "PUT" || upperCaseVerb == "PATCH" || upperCaseVerb == "DELETE") {
if (!chunkedUpload) {
val bytes = new ByteArrayOutputStream()
data.write(compress.wrap(bytes))
usingOutputStream(compress.wrap(bytes)) { os => data.write(os) }
val byteArray = bytes.toByteArray
connection.setFixedLengthStreamingMode(byteArray.length)
if (byteArray.nonEmpty) connection.getOutputStream.write(byteArray)
usingOutputStream(connection.getOutputStream) { os => os.write(byteArray) }
} else {
connection.setChunkedStreamingMode(0)
data.write(compress.wrap(connection.getOutputStream))
usingOutputStream(compress.wrap(connection.getOutputStream)) { os => data.write(os) }
}
}

Expand Down Expand Up @@ -333,7 +332,7 @@ case class Requester(verb: String,
// The HEAD method is identical to GET except that the server
// MUST NOT return a message-body in the response.
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html section 9.4
if (verb == "HEAD") f(new ByteArrayInputStream(Array()))
if (upperCaseVerb == "HEAD") f(new ByteArrayInputStream(Array()))
else if (stream != null) {
try f(
if (deGzip) new GZIPInputStream(stream)
Expand Down Expand Up @@ -366,6 +365,9 @@ case class Requester(verb: String,
}
}
}

private def usingOutputStream[T](os: OutputStream)(fn: OutputStream => T): Unit =
try fn(os) finally os.close()

/**
* Overload of [[Requester.apply]] that takes a [[Request]] object as configuration
Expand Down
2 changes: 2 additions & 0 deletions requests/test/src-2/requests/Scala2RequestTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ object Scala2RequestTests extends TestSuite{
assert(read(res1).obj("form") == Obj("foo" -> "baz", "hello" -> "world"))
}
}

test("put") {
for (chunkedUpload <- Seq(true, false)) {
val res1 = requests.put(
Expand All @@ -28,6 +29,7 @@ object Scala2RequestTests extends TestSuite{
assert(read(res1).obj("form") == Obj("foo" -> "baz", "hello" -> "world"))
}
}

test("send"){
requests.send("get")("https://httpbin.org/get?hello=world&foo=baz")

Expand Down
86 changes: 58 additions & 28 deletions requests/test/src/requests/RequestTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ object RequestTests extends TestSuite{
}
}
}

test("params"){
test("get"){
// All in URL
Expand Down Expand Up @@ -58,6 +59,7 @@ object RequestTests extends TestSuite{
assert(read(res4).obj("args") == Obj("++-- lol" -> " !@#$%", "hello" -> "world"))
}
}

test("multipart"){
for(chunkedUpload <- Seq(true, false)) {
val response = requests.post(
Expand All @@ -73,8 +75,8 @@ object RequestTests extends TestSuite{
assert(read(response).obj("form") == Obj("file2" -> "Goodbye!"))
}
}
test("cookies"){

test("cookies"){
test("session"){
val s = requests.Session(cookieValues = Map("hello" -> "world"))
val res1 = s.get("https://httpbin.org/cookies").text().trim
Expand All @@ -99,35 +101,37 @@ object RequestTests extends TestSuite{
assert(read(res2) == Obj("cookies" -> Obj("freeform" -> "test test", "hello" -> "hello, world")))
}
}
// Tests fail with 'Request to https://httpbin.org/absolute-redirect/4 failed with status code 404'
// test("redirects"){
// test("max"){
// val res1 = requests.get("https://httpbin.org/absolute-redirect/4")
// assert(res1.statusCode == 200)
// val res2 = requests.get("https://httpbin.org/absolute-redirect/5")
// assert(res2.statusCode == 200)
// val res3 = requests.get("https://httpbin.org/absolute-redirect/6", check = false)
// assert(res3.statusCode == 302)
// val res4 = requests.get("https://httpbin.org/absolute-redirect/6", maxRedirects = 10)
// assert(res4.statusCode == 200)
// }
// test("maxRelative"){
// val res1 = requests.get("https://httpbin.org/relative-redirect/4")
// assert(res1.statusCode == 200)
// val res2 = requests.get("https://httpbin.org/relative-redirect/5")
// assert(res2.statusCode == 200)
// val res3 = requests.get("https://httpbin.org/relative-redirect/6", check = false)
// assert(res3.statusCode == 302)
// val res4 = requests.get("https://httpbin.org/relative-redirect/6", maxRedirects = 10)
// assert(res4.statusCode == 200)
// }
// }

test("redirects"){
test("max"){
val res1 = requests.get("https://httpbin.org/absolute-redirect/4")
assert(res1.statusCode == 200)
val res2 = requests.get("https://httpbin.org/absolute-redirect/5")
assert(res2.statusCode == 200)
val res3 = requests.get("https://httpbin.org/absolute-redirect/6", check = false)
assert(res3.statusCode == 302)
val res4 = requests.get("https://httpbin.org/absolute-redirect/6", maxRedirects = 10)
assert(res4.statusCode == 200)
}
test("maxRelative"){
val res1 = requests.get("https://httpbin.org/relative-redirect/4")
assert(res1.statusCode == 200)
val res2 = requests.get("https://httpbin.org/relative-redirect/5")
assert(res2.statusCode == 200)
val res3 = requests.get("https://httpbin.org/relative-redirect/6", check = false)
assert(res3.statusCode == 302)
val res4 = requests.get("https://httpbin.org/relative-redirect/6", maxRedirects = 10)
assert(res4.statusCode == 200)
}
}

test("streaming"){
val res1 = requests.get("http://httpbin.org/stream/5").text()
assert(res1.linesIterator.length == 5)
val res2 = requests.get("http://httpbin.org/stream/52").text()
assert(res2.linesIterator.length == 52)
}

test("timeouts"){
test("read"){
intercept[TimeoutException] {
Expand All @@ -144,6 +148,7 @@ object RequestTests extends TestSuite{
}
}
}

test("failures"){
intercept[UnknownHostException]{
requests.get("https://doesnt-exist-at-all.com/")
Expand All @@ -156,6 +161,7 @@ object RequestTests extends TestSuite{
requests.get("://doesnt-exist.com/")
}
}

test("decompress"){
val res1 = requests.get("https://httpbin.org/gzip")
assert(read(res1.text()).obj("headers").obj("Host").str == "httpbin.org")
Expand All @@ -171,6 +177,7 @@ object RequestTests extends TestSuite{

(res1.bytes.length, res2.bytes.length, res3.bytes.length, res4.bytes.length)
}

test("compression"){
val res1 = requests.post(
"https://httpbin.org/post",
Expand All @@ -184,16 +191,18 @@ object RequestTests extends TestSuite{
compress = requests.Compress.Gzip,
data = new RequestBlob.ByteSourceRequestBlob("I am cow")
)
assert(res2.text().contains("data:application/octet-stream;base64,H4sIAAAAAAAAAA=="))
assert(read(new String(res2.bytes))("data").toString ==
""""data:application/octet-stream;base64,H4sIAAAAAAAAAPNUSMxVSM4vBwCAGeD4CAAAAA=="""")

val res3 = requests.post(
"https://httpbin.org/post",
compress = requests.Compress.Deflate,
data = new RequestBlob.ByteSourceRequestBlob("Hear me moo")
)
assert(res3.text().contains("data:application/octet-stream;base64,eJw="))
res3.text()
}
assert(read(new String(res2.bytes))("data").toString ==
""""data:application/octet-stream;base64,H4sIAAAAAAAAAPNUSMxVSM4vBwCAGeD4CAAAAA=="""")
}

test("headers"){
test("default"){
val res = requests.get("https://httpbin.org/headers").text()
Expand All @@ -207,6 +216,7 @@ object RequestTests extends TestSuite{
}
}
}

test("clientCertificate"){
val base = "./requests/test/resources"
val url = "https://client.badssl.com"
Expand Down Expand Up @@ -253,13 +263,15 @@ object RequestTests extends TestSuite{
assert(res.statusCode == 400)
}
}

test("selfSignedCertificate"){
val res = requests.get(
"https://self-signed.badssl.com",
verifySslCerts = false
)
assert(res.statusCode == 200)
}

test("gzipError"){
val response = requests.head("https://api.github.com/users/lihaoyi")
assert(response.statusCode == 200)
Expand All @@ -268,5 +280,23 @@ object RequestTests extends TestSuite{
assert(response.headers.keySet.map(_.toLowerCase).contains("content-length"))
assert(response.headers.keySet.map(_.toLowerCase).contains("content-type"))
}

/**
* Compress with each compression mode and call server. Server expands
* and passes it back so we can compare
*/
test("compressionData") {
import requests.Compress._
val str = "I am deflater mouse"
Seq(None, Gzip, Deflate).foreach(c =>
ServerUtils.usingEchoServer { port =>
assert(str == requests.post(
s"http://localhost:$port/echo",
compress = c,
data = new RequestBlob.ByteSourceRequestBlob(str)
).data.toString)
}
)
}
}
}
77 changes: 77 additions & 0 deletions requests/test/src/requests/ServerUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package requests

import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer}
import java.io._
import java.net.InetSocketAddress
import java.util.zip.{GZIPInputStream, InflaterInputStream}
import requests.Compress._
import scala.annotation.tailrec
import scala.collection.mutable.StringBuilder

object ServerUtils {
def usingEchoServer(f: Int => Unit): Unit = {
val server = new EchoServer
try f(server.getPort())
finally server.stop()
}

private class EchoServer extends HttpHandler {
private val server: HttpServer = HttpServer.create(new InetSocketAddress(0), 0)
server.createContext("/echo", this)
server.setExecutor(null); // default executor
server.start()

def getPort(): Int = server.getAddress.getPort

def stop(): Unit = server.stop(0)

override def handle(t: HttpExchange): Unit = {
val h: java.util.List[String] =
t.getRequestHeaders.get("Content-encoding")
val c: Compress =
if (h == null) None
else if (h.contains("gzip")) Gzip
else if (h.contains("deflate")) Deflate
else None
val msg = new Plumper(c).decompress(t.getRequestBody)
t.sendResponseHeaders(200, msg.length)
t.getResponseBody.write(msg.getBytes())
}
}

/** Stream uncompresser
* @param c
* Compression mode
*/
private class Plumper(c: Compress) {

private def wrap(is: InputStream): InputStream =
c match {
case None => is
case Gzip => new GZIPInputStream(is)
case Deflate => new InflaterInputStream(is)
}

def decompress(compressed: InputStream): String = {
val gis = wrap(compressed)
val br = new BufferedReader(new InputStreamReader(gis, "UTF-8"))
val sb = new StringBuilder()

@tailrec
def read(): Unit = {
val line = br.readLine
if (line != null) {
sb.append(line)
read()
}
}

read()
br.close()
gis.close()
compressed.close()
sb.toString()
}
}

}

0 comments on commit 7e18f64

Please sign in to comment.