Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix compression and tests therefore #108

Merged
merged 7 commits into from
Jun 8, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 6 additions & 3 deletions build.sc
Original file line number Diff line number Diff line change
@@ -1,17 +1,20 @@
import mill._
import mill.scalalib.publish.{Developer, License, PomSettings, VersionControl}
import scalalib._
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version_mill0.9:0.1.1`
import $ivy.`de.tototec::de.tobiasroeser.mill.vcs.version::0.1.4`
import de.tobiasroeser.mill.vcs.version.VcsVersion
import $ivy.`com.github.lolgab::mill-mima_mill0.9:0.0.4`
import $ivy.`com.github.lolgab::mill-mima::0.0.10`
import com.github.lolgab.mill.mima._

val dottyVersion = Option(sys.props("dottyVersion"))

object requests extends Cross[RequestsModule]((List("2.12.13", "2.13.5", "2.11.12", "3.0.0") ++ dottyVersion): _*)
class RequestsModule(val crossScalaVersion: String) extends CrossScalaModule with PublishModule with Mima {
def publishVersion = VcsVersion.vcsState().format()
def mimaPreviousVersions = VcsVersion.vcsState().lastTag.toSeq
def mimaPreviousVersions = Seq("0.7.0") ++ VcsVersion.vcsState().lastTag.toSeq
override def mimaBinaryIssueFilters = Seq(
ProblemFilter.exclude[ReversedMissingMethodProblem]("requests.BaseSession.send")
)
def artifactName = "requests"
def pomSettings = PomSettings(
description = "Scala port of the popular Python Requests HTTP client",
Expand Down
2 changes: 1 addition & 1 deletion mill
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# This is a wrapper script, that automatically download mill from GitHub release pages
# You can give the required mill version with MILL_VERSION env variable
# If no version is given, it falls back to the value of DEFAULT_MILL_VERSION
DEFAULT_MILL_VERSION=0.9.7
DEFAULT_MILL_VERSION=0.9.12

set -e

Expand Down
20 changes: 11 additions & 9 deletions requests/src/requests/Requester.scala
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import scala.collection.mutable

trait BaseSession{
def headers: Map[String, String]

def cookies: mutable.Map[String, HttpCookie]
def readTimeout: Int
def connectTimeout: Int
Expand Down Expand Up @@ -54,7 +53,7 @@ object Requester{
}
case class Requester(verb: String,
sess: BaseSession){

private val upperCaseVerb = verb.toUpperCase

/**
* Makes a single HTTP request, and returns a [[Response]] object. Requires
Expand Down Expand Up @@ -204,7 +203,6 @@ case class Requester(verb: String,
}

connection.setInstanceFollowRedirects(false)
val upperCaseVerb = verb.toUpperCase
if (Requester.officialHttpMethods.contains(upperCaseVerb)) {
connection.setRequestMethod(upperCaseVerb)
} else {
Expand Down Expand Up @@ -250,17 +248,18 @@ case class Requester(verb: String,
.map{case (k, v) => s"""$k="$v""""}
.mkString("; ")
)
}
if (verb.toUpperCase == "POST" || verb.toUpperCase == "PUT" || verb.toUpperCase == "PATCH" || verb.toUpperCase == "DELETE") {
}

if (upperCaseVerb == "POST" || upperCaseVerb == "PUT" || upperCaseVerb == "PATCH" || upperCaseVerb == "DELETE") {
if (!chunkedUpload) {
val bytes = new ByteArrayOutputStream()
data.write(compress.wrap(bytes))
usingOutputStream(compress.wrap(bytes)) { os => data.write(os) }
val byteArray = bytes.toByteArray
connection.setFixedLengthStreamingMode(byteArray.length)
if (byteArray.nonEmpty) connection.getOutputStream.write(byteArray)
usingOutputStream(connection.getOutputStream) { os => os.write(byteArray) }
} else {
connection.setChunkedStreamingMode(0)
data.write(compress.wrap(connection.getOutputStream))
usingOutputStream(compress.wrap(connection.getOutputStream)) { os => data.write(os) }
}
}

Expand Down Expand Up @@ -333,7 +332,7 @@ case class Requester(verb: String,
// The HEAD method is identical to GET except that the server
// MUST NOT return a message-body in the response.
// https://www.w3.org/Protocols/rfc2616/rfc2616-sec9.html section 9.4
if (verb == "HEAD") f(new ByteArrayInputStream(Array()))
if (upperCaseVerb == "HEAD") f(new ByteArrayInputStream(Array()))
else if (stream != null) {
try f(
if (deGzip) new GZIPInputStream(stream)
Expand Down Expand Up @@ -366,6 +365,9 @@ case class Requester(verb: String,
}
}
}

private def usingOutputStream[T](os: OutputStream)(fn: OutputStream => T): Unit =
try fn(os) finally os.close()

/**
* Overload of [[Requester.apply]] that takes a [[Request]] object as configuration
Expand Down
2 changes: 2 additions & 0 deletions requests/test/src-2/requests/Scala2RequestTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ object Scala2RequestTests extends TestSuite{
assert(read(res1).obj("form") == Obj("foo" -> "baz", "hello" -> "world"))
}
}

test("put") {
for (chunkedUpload <- Seq(true, false)) {
val res1 = requests.put(
Expand All @@ -28,6 +29,7 @@ object Scala2RequestTests extends TestSuite{
assert(read(res1).obj("form") == Obj("foo" -> "baz", "hello" -> "world"))
}
}

test("send"){
requests.send("get")("https://httpbin.org/get?hello=world&foo=baz")

Expand Down
86 changes: 58 additions & 28 deletions requests/test/src/requests/RequestTests.scala
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ object RequestTests extends TestSuite{
}
}
}

test("params"){
test("get"){
// All in URL
Expand Down Expand Up @@ -58,6 +59,7 @@ object RequestTests extends TestSuite{
assert(read(res4).obj("args") == Obj("++-- lol" -> " !@#$%", "hello" -> "world"))
}
}

test("multipart"){
for(chunkedUpload <- Seq(true, false)) {
val response = requests.post(
Expand All @@ -73,8 +75,8 @@ object RequestTests extends TestSuite{
assert(read(response).obj("form") == Obj("file2" -> "Goodbye!"))
}
}
test("cookies"){

test("cookies"){
test("session"){
val s = requests.Session(cookieValues = Map("hello" -> "world"))
val res1 = s.get("https://httpbin.org/cookies").text().trim
Expand All @@ -99,35 +101,37 @@ object RequestTests extends TestSuite{
assert(read(res2) == Obj("cookies" -> Obj("freeform" -> "test test", "hello" -> "hello, world")))
}
}
// Tests fail with 'Request to https://httpbin.org/absolute-redirect/4 failed with status code 404'
// test("redirects"){
// test("max"){
// val res1 = requests.get("https://httpbin.org/absolute-redirect/4")
// assert(res1.statusCode == 200)
// val res2 = requests.get("https://httpbin.org/absolute-redirect/5")
// assert(res2.statusCode == 200)
// val res3 = requests.get("https://httpbin.org/absolute-redirect/6", check = false)
// assert(res3.statusCode == 302)
// val res4 = requests.get("https://httpbin.org/absolute-redirect/6", maxRedirects = 10)
// assert(res4.statusCode == 200)
// }
// test("maxRelative"){
// val res1 = requests.get("https://httpbin.org/relative-redirect/4")
// assert(res1.statusCode == 200)
// val res2 = requests.get("https://httpbin.org/relative-redirect/5")
// assert(res2.statusCode == 200)
// val res3 = requests.get("https://httpbin.org/relative-redirect/6", check = false)
// assert(res3.statusCode == 302)
// val res4 = requests.get("https://httpbin.org/relative-redirect/6", maxRedirects = 10)
// assert(res4.statusCode == 200)
// }
// }

test("redirects"){
test("max"){
val res1 = requests.get("https://httpbin.org/absolute-redirect/4")
assert(res1.statusCode == 200)
val res2 = requests.get("https://httpbin.org/absolute-redirect/5")
assert(res2.statusCode == 200)
val res3 = requests.get("https://httpbin.org/absolute-redirect/6", check = false)
assert(res3.statusCode == 302)
val res4 = requests.get("https://httpbin.org/absolute-redirect/6", maxRedirects = 10)
assert(res4.statusCode == 200)
}
test("maxRelative"){
val res1 = requests.get("https://httpbin.org/relative-redirect/4")
assert(res1.statusCode == 200)
val res2 = requests.get("https://httpbin.org/relative-redirect/5")
assert(res2.statusCode == 200)
val res3 = requests.get("https://httpbin.org/relative-redirect/6", check = false)
assert(res3.statusCode == 302)
val res4 = requests.get("https://httpbin.org/relative-redirect/6", maxRedirects = 10)
assert(res4.statusCode == 200)
}
}

test("streaming"){
val res1 = requests.get("http://httpbin.org/stream/5").text()
assert(res1.linesIterator.length == 5)
val res2 = requests.get("http://httpbin.org/stream/52").text()
assert(res2.linesIterator.length == 52)
}

test("timeouts"){
test("read"){
intercept[TimeoutException] {
Expand All @@ -144,6 +148,7 @@ object RequestTests extends TestSuite{
}
}
}

test("failures"){
intercept[UnknownHostException]{
requests.get("https://doesnt-exist-at-all.com/")
Expand All @@ -156,6 +161,7 @@ object RequestTests extends TestSuite{
requests.get("://doesnt-exist.com/")
}
}

test("decompress"){
val res1 = requests.get("https://httpbin.org/gzip")
assert(read(res1.text()).obj("headers").obj("Host").str == "httpbin.org")
Expand All @@ -171,6 +177,7 @@ object RequestTests extends TestSuite{

(res1.bytes.length, res2.bytes.length, res3.bytes.length, res4.bytes.length)
}

test("compression"){
val res1 = requests.post(
"https://httpbin.org/post",
Expand All @@ -184,16 +191,18 @@ object RequestTests extends TestSuite{
compress = requests.Compress.Gzip,
data = new RequestBlob.ByteSourceRequestBlob("I am cow")
)
assert(res2.text().contains("data:application/octet-stream;base64,H4sIAAAAAAAAAA=="))
assert(read(new String(res2.bytes))("data").toString ==
""""data:application/octet-stream;base64,H4sIAAAAAAAAAPNUSMxVSM4vBwCAGeD4CAAAAA=="""")

val res3 = requests.post(
"https://httpbin.org/post",
compress = requests.Compress.Deflate,
data = new RequestBlob.ByteSourceRequestBlob("Hear me moo")
)
assert(res3.text().contains("data:application/octet-stream;base64,eJw="))
res3.text()
}
assert(read(new String(res2.bytes))("data").toString ==
""""data:application/octet-stream;base64,H4sIAAAAAAAAAPNUSMxVSM4vBwCAGeD4CAAAAA=="""")
}

test("headers"){
test("default"){
val res = requests.get("https://httpbin.org/headers").text()
Expand All @@ -207,6 +216,7 @@ object RequestTests extends TestSuite{
}
}
}

test("clientCertificate"){
val base = "./requests/test/resources"
val url = "https://client.badssl.com"
Expand Down Expand Up @@ -253,13 +263,15 @@ object RequestTests extends TestSuite{
assert(res.statusCode == 400)
}
}

test("selfSignedCertificate"){
val res = requests.get(
"https://self-signed.badssl.com",
verifySslCerts = false
)
assert(res.statusCode == 200)
}

test("gzipError"){
val response = requests.head("https://api.github.com/users/lihaoyi")
assert(response.statusCode == 200)
Expand All @@ -268,5 +280,23 @@ object RequestTests extends TestSuite{
assert(response.headers.keySet.map(_.toLowerCase).contains("content-length"))
assert(response.headers.keySet.map(_.toLowerCase).contains("content-type"))
}

/**
* Compress with each compression mode and call server. Server expands
* and passes it back so we can compare
*/
test("compressionData") {
import requests.Compress._
val str = "I am deflater mouse"
Seq(None, Gzip, Deflate).foreach(c =>
ServerUtils.usingEchoServer { port =>
assert(str == requests.post(
s"http://localhost:$port/echo",
compress = c,
data = new RequestBlob.ByteSourceRequestBlob(str)
).data.toString)
}
)
}
}
}
77 changes: 77 additions & 0 deletions requests/test/src/requests/ServerUtils.scala
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package requests

import com.sun.net.httpserver.{HttpExchange, HttpHandler, HttpServer}
import java.io._
import java.net.InetSocketAddress
import java.util.zip.{GZIPInputStream, InflaterInputStream}
import requests.Compress._
import scala.annotation.tailrec
import scala.collection.mutable.StringBuilder

object ServerUtils {
def usingEchoServer(f: Int => Unit): Unit = {
val server = new EchoServer
try f(server.getPort())
finally server.stop()
}

private class EchoServer extends HttpHandler {
private val server: HttpServer = HttpServer.create(new InetSocketAddress(0), 0)
server.createContext("/echo", this)
server.setExecutor(null); // default executor
server.start()

def getPort(): Int = server.getAddress.getPort

def stop(): Unit = server.stop(0)

override def handle(t: HttpExchange): Unit = {
val h: java.util.List[String] =
t.getRequestHeaders.get("Content-encoding")
val c: Compress =
if (h == null) None
else if (h.contains("gzip")) Gzip
else if (h.contains("deflate")) Deflate
else None
val msg = new Plumper(c).decompress(t.getRequestBody)
t.sendResponseHeaders(200, msg.length)
t.getResponseBody.write(msg.getBytes())
}
}

/** Stream uncompresser
* @param c
* Compression mode
*/
private class Plumper(c: Compress) {

private def wrap(is: InputStream): InputStream =
c match {
case None => is
case Gzip => new GZIPInputStream(is)
case Deflate => new InflaterInputStream(is)
}

def decompress(compressed: InputStream): String = {
val gis = wrap(compressed)
val br = new BufferedReader(new InputStreamReader(gis, "UTF-8"))
val sb = new StringBuilder()

@tailrec
def read(): Unit = {
val line = br.readLine
if (line != null) {
sb.append(line)
read()
}
}

read()
br.close()
gis.close()
compressed.close()
sb.toString()
}
}

}