diff --git a/Cargo.lock b/Cargo.lock index 2c1bc94c0..2727f1116 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -11,7 +11,7 @@ dependencies = [ "data-encoding", "hyper 1.4.1", "openssl", - "reqwest", + "reqwest 0.11.27", "serde", "serde_json", "thiserror", @@ -127,9 +127,9 @@ dependencies = [ [[package]] name = "anyhow" -version = "1.0.82" +version = "1.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +checksum = "86fdf8605db99b54d3cd748a44c6d04df638eb5dafb219b135d0149bd0db01f6" [[package]] name = "asn1-rs" @@ -212,12 +212,45 @@ dependencies = [ "num-traits", ] +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + [[package]] name = "autocfg" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +[[package]] +name = "aws-lc-rs" +version = "1.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2f95446d919226d587817a7d21379e6eb099b97b45110a7f272a444ca5c54070" +dependencies = [ + "aws-lc-sys", + "mirai-annotations", + "paste", + "zeroize", +] + +[[package]] +name = "aws-lc-sys" +version = "0.21.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3ddc4a5b231dd6958b140ff3151b6412b3f4321fab354f399eec8f14b06df62" +dependencies = [ + "bindgen", + "cc", + "cmake", + "dunce", + "fs_extra", + "libc", + "paste", +] + [[package]] name = "axum" version = "0.6.20" @@ -247,20 +280,21 @@ dependencies = [ "sha1", "sync_wrapper 0.1.2", "tokio", - "tokio-tungstenite", - "tower", + "tokio-tungstenite 0.20.1", + "tower 0.4.13", "tower-layer", "tower-service", ] [[package]] name = "axum" -version = "0.7.5" +version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf" +checksum = "8f43644eed690f5374f1af436ecd6aea01cd201f6fbdf0178adaf6907afb2cec" dependencies = [ "async-trait", - "axum-core 0.4.3", + "axum-core 0.4.4", + "base64 0.21.7", "bytes", "futures-util", "http 1.1.0", @@ -279,9 +313,11 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper 1.0.1", "tokio", - "tower", + "tokio-tungstenite 0.23.1", + "tower 0.5.1", "tower-layer", "tower-service", "tracing", @@ -306,9 +342,9 @@ dependencies = [ [[package]] name = "axum-core" -version = "0.4.3" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15c63fd72d41492dc4f497196f5da1fb04fb7529e631d73630d1b491e47a2e3" +checksum = "5e6b8ba012a258d63c9adfa28b9ddcf66149da6f986c5b5452e629d5ee64bf00" dependencies = [ "async-trait", "bytes", @@ -319,7 +355,7 @@ dependencies = [ "mime", "pin-project-lite", "rustversion", - "sync_wrapper 0.1.2", + "sync_wrapper 1.0.1", "tower-layer", "tower-service", "tracing", @@ -358,6 +394,29 @@ version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" +[[package]] +name = "bindgen" +version = "0.69.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" +dependencies = [ + "bitflags 2.5.0", + "cexpr", + "clang-sys", + "itertools", + "lazy_static", + "lazycell", + "log", + "prettyplease", + "proc-macro2", + "quote", + "regex", + "rustc-hash", + "shlex", + "syn 2.0.60", + "which", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -440,15 +499,29 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "bytes" -version = "1.6.0" +version = "1.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "514de17de45fdb8dc022b1a7975556c53c86f9f0aa5f534b98977b171857c2c9" +checksum = "428d9aa8fbc0670b7b8d6030a7fadd0f86151cae55e4dbbece15f3780a3dfaf3" [[package]] name = "cc" -version = "1.0.95" +version = "1.1.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "07b1695e2c7e8fc85310cde85aeaab7e3097f593c91d209d3f9df76c928100f0" +dependencies = [ + "jobserver", + "libc", + "shlex", +] + +[[package]] +name = "cexpr" +version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d32a725bc159af97c3e629873bb9f88fb8cf8a4867175f76dc987815ea07c83b" +checksum = "6fac387a98bb7c37292057cffc56d62ecb629900026402633ae9160df93a8766" +dependencies = [ + "nom", +] [[package]] name = "cfg-if" @@ -468,7 +541,18 @@ dependencies = [ "num-traits", "serde", "wasm-bindgen", - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "clang-sys" +version = "1.8.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +dependencies = [ + "glob", + "libc", + "libloading", ] [[package]] @@ -511,6 +595,15 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +[[package]] +name = "cmake" +version = "0.1.51" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fb1e43aa7fd152b1f968787f7dbcdeb306d1867ff373c69955211876c053f91a" +dependencies = [ + "cc", +] + [[package]] name = "colorchoice" version = "1.0.0" @@ -719,9 +812,9 @@ dependencies = [ [[package]] name = "displaydoc" -version = "0.2.4" +version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "487585f4d0c6655fe74905e2504d8ad6908e4db67f744eb140876906c2f3175d" +checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", @@ -743,6 +836,38 @@ dependencies = [ "tokio", ] +[[package]] +name = "dunce" +version = "1.0.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92773504d58c093f6de2459af4af33faa518c13451eb8f2b5698ed3d36e7c813" + +[[package]] +name = "dynamic-proxy" +version = "0.1.0" +dependencies = [ + "anyhow", + "axum 0.7.6", + "bytes", + "futures-util", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "pin-project-lite", + "rcgen", + "reqwest 0.12.7", + "rustls 0.23.13", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-rustls 0.26.0", + "tokio-tungstenite 0.24.0", + "tracing", +] + [[package]] name = "either" version = "1.13.0" @@ -870,6 +995,12 @@ dependencies = [ "percent-encoding", ] +[[package]] +name = "fs_extra" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "42703706b716c37f96a77aea830392ad231f44c9e9a67872fa5548707e11b11c" + [[package]] name = "futures-channel" version = "0.3.30" @@ -981,6 +1112,12 @@ version = "0.28.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4271d37baee1b8c7e4b708028c57d816cf9d2434acb33a549475f78c181f6253" +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "h2" version = "0.3.26" @@ -1000,6 +1137,25 @@ dependencies = [ "tracing", ] +[[package]] +name = "h2" +version = "0.4.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "524e8ac6999421f49a846c2d4411f337e53497d8ec55d67753beffa43c5d9205" +dependencies = [ + "atomic-waker", + "bytes", + "fnv", + "futures-core", + "futures-sink", + "http 1.1.0", + "indexmap 2.2.6", + "slab", + "tokio", + "tokio-util", + "tracing", +] + [[package]] name = "hashbrown" version = "0.12.3" @@ -1160,7 +1316,7 @@ dependencies = [ "futures-channel", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "httparse", @@ -1183,6 +1339,7 @@ dependencies = [ "bytes", "futures-channel", "futures-util", + "h2 0.4.6", "http 1.1.0", "http-body 1.0.1", "httparse", @@ -1218,16 +1375,49 @@ dependencies = [ "futures-util", "http 0.2.12", "hyper 0.14.28", - "rustls", + "rustls 0.21.11", + "tokio", + "tokio-rustls 0.24.1", +] + +[[package]] +name = "hyper-rustls" +version = "0.27.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08afdbb5c31130e3034af566421053ab03787c640246a446327f550d11bcb333" +dependencies = [ + "futures-util", + "http 1.1.0", + "hyper 1.4.1", + "hyper-util", + "rustls 0.23.13", + "rustls-pki-types", + "tokio", + "tokio-rustls 0.26.0", + "tower-service", +] + +[[package]] +name = "hyper-tls" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70206fc6890eaca9fde8a0bf71caa2ddfc9fe045ac9e5c70df101a7dbde866e0" +dependencies = [ + "bytes", + "http-body-util", + "hyper 1.4.1", + "hyper-util", + "native-tls", "tokio", - "tokio-rustls", + "tokio-native-tls", + "tower-service", ] [[package]] name = "hyper-util" -version = "0.1.6" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ab92f4f49ee4fb4f997c784b7a2e0fa70050211e0b6a287f898c3c9785ca956" +checksum = "da62f120a8a37763efb0cf8fdf264b884c7b8b9ac8660b900c8661030c00e6ba" dependencies = [ "bytes", "futures-channel", @@ -1238,7 +1428,7 @@ dependencies = [ "pin-project-lite", "socket2", "tokio", - "tower", + "tower 0.4.13", "tower-service", "tracing", ] @@ -1344,12 +1534,30 @@ dependencies = [ "serde", ] +[[package]] +name = "itertools" +version = "0.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ba291022dbbd398a455acf126c1e341954079855bc60dfdda641363bd6922569" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "48d1dbcbbeb6a7fec7e059840aa538bd62aaccf972c7346c4d9d2059312853d0" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.69" @@ -1368,12 +1576,28 @@ dependencies = [ "spin 0.5.2", ] +[[package]] +name = "lazycell" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" + [[package]] name = "libc" version = "0.2.153" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" +[[package]] +name = "libloading" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +dependencies = [ + "cfg-if", + "windows-targets 0.52.6", +] + [[package]] name = "libm" version = "0.2.8" @@ -1486,6 +1710,29 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "mirai-annotations" +version = "1.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c9be0862c1b3f26a88803c4a49de6889c10e608b3ee9344e6ef5b45fb37ad3d1" + +[[package]] +name = "native-tls" +version = "0.2.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8614eb2c83d59d1c8cc974dd3f920198647674a0a035e1af1fa58707e317466" +dependencies = [ + "libc", + "log", + "openssl", + "openssl-probe", + "openssl-sys", + "schannel", + "security-framework", + "security-framework-sys", + "tempfile", +] + [[package]] name = "nom" version = "7.1.3" @@ -1508,11 +1755,10 @@ dependencies = [ [[package]] name = "num-bigint" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ - "autocfg", "num-integer", "num-traits", ] @@ -1620,6 +1866,12 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "openssl-probe" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" + [[package]] name = "openssl-sys" version = "0.9.103" @@ -1767,11 +2019,13 @@ dependencies = [ "async-trait", "axum 0.6.20", "bollard", + "bytes", "chrono", "clap", "colored", "dashmap", "data-encoding", + "dynamic-proxy", "futures-util", "http-body 0.4.6", "hyper 0.14.28", @@ -1779,8 +2033,7 @@ dependencies = [ "openssl", "pem", "rand", - "reqwest", - "ring", + "reqwest 0.11.27", "rusqlite", "rustls-pemfile 2.1.2", "rustls-pki-types", @@ -1791,15 +2044,14 @@ dependencies = [ "thiserror", "time", "tokio", - "tokio-rustls", "tokio-stream", - "tokio-tungstenite", - "tower", + "tokio-tungstenite 0.20.1", + "tower 0.4.13", "tower-http", "tracing", "tracing-subscriber", "trust-dns-server", - "tungstenite", + "tungstenite 0.20.1", "url", "valuable", "x509-parser", @@ -1827,17 +2079,22 @@ version = "0.4.12" dependencies = [ "anyhow", "async-trait", - "axum 0.7.5", + "axum 0.7.6", "bollard", "chrono", + "dynamic-proxy", "futures-util", + "http 1.1.0", + "http-body-util", "hyper 0.14.28", "plane-dynamic", "plane-test-macro", - "reqwest", + "reqwest 0.11.27", + "serde", "serde_json", "thiserror", "tokio", + "tokio-tungstenite 0.24.0", "tracing", "tracing-appender", "tracing-subscriber", @@ -1856,6 +2113,16 @@ version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" +[[package]] +name = "prettyplease" +version = "0.2.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +dependencies = [ + "proc-macro2", + "syn 2.0.60", +] + [[package]] name = "proc-macro2" version = "1.0.81" @@ -1904,6 +2171,19 @@ dependencies = [ "getrandom", ] +[[package]] +name = "rcgen" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "54077e1872c46788540de1ea3d7f4ccb1983d12f9aa909b234468676c1a36779" +dependencies = [ + "pem", + "ring", + "rustls-pki-types", + "time", + "yasna", +] + [[package]] name = "redox_syscall" version = "0.4.1" @@ -1968,11 +2248,11 @@ dependencies = [ "encoding_rs", "futures-core", "futures-util", - "h2", + "h2 0.3.26", "http 0.2.12", "http-body 0.4.6", "hyper 0.14.28", - "hyper-rustls", + "hyper-rustls 0.24.2", "ipnet", "js-sys", "log", @@ -1980,15 +2260,15 @@ dependencies = [ "once_cell", "percent-encoding", "pin-project-lite", - "rustls", + "rustls 0.21.11", "rustls-pemfile 1.0.4", "serde", "serde_json", "serde_urlencoded", "sync_wrapper 0.1.2", - "system-configuration", + "system-configuration 0.5.1", "tokio", - "tokio-rustls", + "tokio-rustls 0.24.1", "tower-service", "url", "wasm-bindgen", @@ -1998,6 +2278,51 @@ dependencies = [ "winreg", ] +[[package]] +name = "reqwest" +version = "0.12.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8f4955649ef5c38cc7f9e8aa41761d48fb9677197daea9984dc54f56aad5e63" +dependencies = [ + "base64 0.22.0", + "bytes", + "encoding_rs", + "futures-core", + "futures-util", + "h2 0.4.6", + "http 1.1.0", + "http-body 1.0.1", + "http-body-util", + "hyper 1.4.1", + "hyper-rustls 0.27.3", + "hyper-tls", + "hyper-util", + "ipnet", + "js-sys", + "log", + "mime", + "native-tls", + "once_cell", + "percent-encoding", + "pin-project-lite", + "rustls-pemfile 2.1.2", + "serde", + "serde_json", + "serde_urlencoded", + "sync_wrapper 1.0.1", + "system-configuration 0.6.1", + "tokio", + "tokio-native-tls", + "tokio-util", + "tower-service", + "url", + "wasm-bindgen", + "wasm-bindgen-futures", + "wasm-streams", + "web-sys", + "windows-registry", +] + [[package]] name = "ring" version = "0.17.8" @@ -2054,6 +2379,12 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc-hash" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" + [[package]] name = "rusticata-macros" version = "4.1.0" @@ -2084,10 +2415,26 @@ checksum = "7fecbfb7b1444f477b345853b1fce097a2c6fb637b2bfb87e6bc5db0f043fae4" dependencies = [ "log", "ring", - "rustls-webpki", + "rustls-webpki 0.101.7", "sct", ] +[[package]] +name = "rustls" +version = "0.23.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2dabaac7466917e566adb06783a81ca48944c6898a1b08b9374106dd671f4c8" +dependencies = [ + "aws-lc-rs", + "log", + "once_cell", + "ring", + "rustls-pki-types", + "rustls-webpki 0.102.8", + "subtle", + "zeroize", +] + [[package]] name = "rustls-pemfile" version = "1.0.4" @@ -2109,9 +2456,9 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.5.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" +checksum = "fc0a2ce646f8655401bb81e7927b812614bd5d91dbc968696be50603510fcaf0" [[package]] name = "rustls-webpki" @@ -2123,6 +2470,18 @@ dependencies = [ "untrusted", ] +[[package]] +name = "rustls-webpki" +version = "0.102.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "64ca1bc8749bd4cf37b5ce386cc146580777b4e8572c7b97baf22c83f444bee9" +dependencies = [ + "aws-lc-rs", + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "rustversion" version = "1.0.15" @@ -2135,6 +2494,15 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +[[package]] +name = "schannel" +version = "0.1.24" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9aaafd5a2b6e3d657ff009d82fbd630b6bd54dd4eb06f21693925cdf80f9b8b" +dependencies = [ + "windows-sys 0.59.0", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2151,20 +2519,43 @@ dependencies = [ "untrusted", ] +[[package]] +name = "security-framework" +version = "2.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c627723fd09706bacdb5cf41499e95098555af3c3c29d014dc3c458ef6be11c0" +dependencies = [ + "bitflags 2.5.0", + "core-foundation", + "core-foundation-sys", + "libc", + "security-framework-sys", +] + +[[package]] +name = "security-framework-sys" +version = "2.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ea4a292869320c0272d7bc55a5a6aafaff59b4f63404a003887b679a2e05b4b6" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "serde" -version = "1.0.198" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +checksum = "c8e3592472072e6e22e0a54d5904d9febf8508f65fb8552499a1abc7d1078c3a" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.198" +version = "1.0.210" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +checksum = "243902eda00fad750862fc144cea25caca5e20d615af0a81bee94ca738f1df1f" dependencies = [ "proc-macro2", "quote", @@ -2173,11 +2564,12 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.116" +version = "1.0.128" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +checksum = "6ff5456707a1de34e7e37f2a6fd3d3f808c318259cbd01ab6377795054b483d8" dependencies = [ "itoa", + "memchr", "ryu", "serde", ] @@ -2285,6 +2677,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shlex" +version = "1.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" + [[package]] name = "signal-hook-registry" version = "1.4.2" @@ -2409,7 +2807,7 @@ dependencies = [ "once_cell", "paste", "percent-encoding", - "rustls", + "rustls 0.21.11", "rustls-pemfile 1.0.4", "serde", "serde_json", @@ -2632,6 +3030,9 @@ name = "sync_wrapper" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a7065abeca94b6a8a577f9bd45aa0867a2238b74e8eb67cf10d492bc39351394" +dependencies = [ + "futures-core", +] [[package]] name = "synstructure" @@ -2653,7 +3054,18 @@ checksum = "ba3a3adc5c275d719af8cb4272ea1c4a6d668a777f37e115f6d11ddbc1c8e0e7" dependencies = [ "bitflags 1.3.2", "core-foundation", - "system-configuration-sys", + "system-configuration-sys 0.5.0", +] + +[[package]] +name = "system-configuration" +version = "0.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c879d448e9d986b661742763247d3693ed13609438cf3d006f51f5368a5ba6b" +dependencies = [ + "bitflags 2.5.0", + "core-foundation", + "system-configuration-sys 0.6.0", ] [[package]] @@ -2666,6 +3078,16 @@ dependencies = [ "libc", ] +[[package]] +name = "system-configuration-sys" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e1d1b10ced5ca923a1fcb8d03e96b8d3268065d724548c0211415ff6ac6bac4" +dependencies = [ + "core-foundation-sys", + "libc", +] + [[package]] name = "tempfile" version = "3.10.1" @@ -2680,18 +3102,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.59" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +checksum = "d50af8abc119fb8bb6dbabcfa89656f46f84aa0ac7688088608076ad2b459a84" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.59" +version = "1.0.64" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +checksum = "08904e7672f5eb876eaaf87e0ce17857500934f4981c4a0ab2b4aa98baac7fc3" dependencies = [ "proc-macro2", "quote", @@ -2756,9 +3178,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "tokio" -version = "1.39.2" +version = "1.40.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "daa4fb1bc778bd6f04cbfc4bb2d06a7396a8f299dc33ea1900cedaa316f467b1" +checksum = "e2b070231665d27ad9ec9b8df639893f46727666c6767db40317fbe920a5d998" dependencies = [ "backtrace", "bytes", @@ -2782,13 +3204,34 @@ dependencies = [ "syn 2.0.60", ] +[[package]] +name = "tokio-native-tls" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bbae76ab933c85776efabc971569dd6119c580d8f5d448769dec1764bf796ef2" +dependencies = [ + "native-tls", + "tokio", +] + [[package]] name = "tokio-rustls" version = "0.24.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c28327cf380ac148141087fbfb9de9d7bd4e84ab5d2c28fbc911d753de8a7081" dependencies = [ - "rustls", + "rustls 0.21.11", + "tokio", +] + +[[package]] +name = "tokio-rustls" +version = "0.26.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c7bc40d0e5a97695bb96e27995cd3a08538541b0a846f65bba7a359f36700d4" +dependencies = [ + "rustls 0.23.13", + "rustls-pki-types", "tokio", ] @@ -2812,13 +3255,37 @@ checksum = "212d5dcb2a1ce06d81107c3d0ffa3121fe974b73f068c8282cb1c32328113b6c" dependencies = [ "futures-util", "log", - "rustls", + "rustls 0.21.11", "tokio", - "tokio-rustls", - "tungstenite", + "tokio-rustls 0.24.1", + "tungstenite 0.20.1", "webpki-roots", ] +[[package]] +name = "tokio-tungstenite" +version = "0.23.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c6989540ced10490aaf14e6bad2e3d33728a2813310a0c71d1574304c49631cd" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.23.0", +] + +[[package]] +name = "tokio-tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "edc5f74e248dc973e0dbb7b74c7e0d6fcc301c694ff50049504004ef4d0cdcd9" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite 0.24.0", +] + [[package]] name = "tokio-util" version = "0.7.10" @@ -2883,6 +3350,22 @@ dependencies = [ "tracing", ] +[[package]] +name = "tower" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2873938d487c3cfb9aed7546dc9f2711d867c9f90c46b889989a2cb84eba6b4f" +dependencies = [ + "futures-core", + "futures-util", + "pin-project-lite", + "sync_wrapper 0.1.2", + "tokio", + "tower-layer", + "tower-service", + "tracing", +] + [[package]] name = "tower-http" version = "0.4.4" @@ -2904,15 +3387,15 @@ dependencies = [ [[package]] name = "tower-layer" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c20c8dbed6283a09604c3e69b4b7eeb54e298b8a600d4d5ecb5ad39de609f1d0" +checksum = "121c2a6cda46980bb0fcd1647ffaf6cd3fc79a013de288782836f6df9c48780e" [[package]] name = "tower-service" -version = "0.3.2" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b6bc1c9ce2b5135ac7f93c72918fc37feb872bdc6a5533a8b85eb4b86bfdae52" +checksum = "8df9b6e13f2d32c91b9bd719c00d1958837bc7dec474d94952798cc8e69eeec3" [[package]] name = "tracing" @@ -3081,13 +3564,49 @@ dependencies = [ "httparse", "log", "rand", - "rustls", + "rustls 0.21.11", "sha1", "thiserror", "url", "utf-8", ] +[[package]] +name = "tungstenite" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e2e2ce1e47ed2994fd43b04c8f618008d4cabdd5ee34027cf14f9d918edd9c8" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + +[[package]] +name = "tungstenite" +version = "0.24.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18e5b8366ee7a95b16d32197d0b2604b43a0be89dc5fac9f8e96ccafbaedda8a" +dependencies = [ + "byteorder", + "bytes", + "data-encoding", + "http 1.1.0", + "httparse", + "log", + "rand", + "sha1", + "thiserror", + "utf-8", +] + [[package]] name = "typenum" version = "1.17.0" @@ -3123,9 +3642,9 @@ checksum = "e4259d9d4425d9f0661581b804cb85fe66a4c631cadd8f490d1c13a35d5d9291" [[package]] name = "unicode-xid" -version = "0.2.4" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f962df74c8c05a667b5ee8bcf162993134c104e96440b663c8daa176dc772d8c" +checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" [[package]] name = "unicode_categories" @@ -3292,6 +3811,19 @@ version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" +[[package]] +name = "wasm-streams" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b65dc4c90b63b118468cf747d8bf3566c1913ef60be765b5730ead9e0a3ba129" +dependencies = [ + "futures-util", + "js-sys", + "wasm-bindgen", + "wasm-bindgen-futures", + "web-sys", +] + [[package]] name = "web-sys" version = "0.3.69" @@ -3308,6 +3840,18 @@ version = "0.25.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f20c57d8d7db6d3b86154206ae5d8fba62dd39573114de97c2cb0578251f8e1" +[[package]] +name = "which" +version = "4.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87ba24419a2078cd2b0f2ede2691b6c66d8e47836da3b6db8265ebad47afbfc7" +dependencies = [ + "either", + "home", + "once_cell", + "rustix", +] + [[package]] name = "whoami" version = "1.5.1" @@ -3346,7 +3890,37 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-registry" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0" +dependencies = [ + "windows-result", + "windows-strings", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-result" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e" +dependencies = [ + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-strings" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10" +dependencies = [ + "windows-result", + "windows-targets 0.52.6", ] [[package]] @@ -3364,7 +3938,16 @@ version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ - "windows-targets 0.52.5", + "windows-targets 0.52.6", +] + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets 0.52.6", ] [[package]] @@ -3384,18 +3967,18 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ - "windows_aarch64_gnullvm 0.52.5", - "windows_aarch64_msvc 0.52.5", - "windows_i686_gnu 0.52.5", + "windows_aarch64_gnullvm 0.52.6", + "windows_aarch64_msvc 0.52.6", + "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", - "windows_i686_msvc 0.52.5", - "windows_x86_64_gnu 0.52.5", - "windows_x86_64_gnullvm 0.52.5", - "windows_x86_64_msvc 0.52.5", + "windows_i686_msvc 0.52.6", + "windows_x86_64_gnu 0.52.6", + "windows_x86_64_gnullvm 0.52.6", + "windows_x86_64_msvc 0.52.6", ] [[package]] @@ -3406,9 +3989,9 @@ checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" @@ -3418,9 +4001,9 @@ checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" @@ -3430,15 +4013,15 @@ checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" @@ -3448,9 +4031,9 @@ checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" @@ -3460,9 +4043,9 @@ checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" @@ -3472,9 +4055,9 @@ checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" @@ -3484,9 +4067,9 @@ checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" @@ -3524,6 +4107,15 @@ dependencies = [ "time", ] +[[package]] +name = "yasna" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e17bb3549cc1321ae1296b9cdc2698e2b6cb1992adfa19a8c72e5b7a738f44cd" +dependencies = [ + "time", +] + [[package]] name = "zerocopy" version = "0.7.32" diff --git a/Cargo.toml b/Cargo.toml index 9114f4fc5..0e88a01ed 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,7 @@ [workspace] resolver = "2" members = [ + "dynamic-proxy", "plane/plane-tests", "plane/plane-dynamic", "plane", @@ -12,5 +13,6 @@ members = [ # https://github.com/rust-lang/cargo/pull/9252/files default-members = [ "plane", - "plane/plane-tests" + "plane/plane-tests", + "dynamic-proxy", ] diff --git a/LICENSE b/LICENSE index ca5f511a5..7e1646724 100644 --- a/LICENSE +++ b/LICENSE @@ -19,3 +19,27 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +--- + +Contains code from hyperium/hyper-util, licensed under the MIT license: + +Copyright (c) 2023 Sean McArthur + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. diff --git a/dynamic-proxy/.gitignore b/dynamic-proxy/.gitignore new file mode 100644 index 000000000..ea8c4bf7f --- /dev/null +++ b/dynamic-proxy/.gitignore @@ -0,0 +1 @@ +/target diff --git a/dynamic-proxy/Cargo.toml b/dynamic-proxy/Cargo.toml new file mode 100644 index 000000000..9429139c7 --- /dev/null +++ b/dynamic-proxy/Cargo.toml @@ -0,0 +1,29 @@ +[package] +name = "dynamic-proxy" +version = "0.1.0" +edition = "2021" + +[dependencies] +anyhow = "1.0.89" +bytes = "1.7.2" +http = "1.1.0" +http-body = "1.0.1" +http-body-util = "0.1.2" +hyper = "1.4.1" +hyper-util = { version = "0.1.8", features = ["http1", "http2", "server", "server-graceful", "server-auto", "client", "client-legacy"] } +pin-project-lite = "0.2.14" +rustls = { version = "0.23.13", features = ["ring"] } +thiserror = "1.0.63" +serde = { version = "1.0.210", features = ["derive"] } +tokio = { version = "1.40.0", features = ["macros", "rt-multi-thread"] } +tokio-rustls = "0.26.0" +tracing = "0.1.40" + +[dev-dependencies] +axum = { version = "0.7.6", features = ["http2", "ws"] } +futures-util = "0.3.30" +http = "1.1.0" +rcgen = "0.13.1" +reqwest = { version = "0.12.7", features = ["http2", "stream"] } +serde_json = "1.0.128" +tokio-tungstenite = "0.24.0" diff --git a/dynamic-proxy/src/body.rs b/dynamic-proxy/src/body.rs new file mode 100644 index 000000000..c899ddfed --- /dev/null +++ b/dynamic-proxy/src/body.rs @@ -0,0 +1,24 @@ +//! Provides a concrete, boxed body and error type. + +use bytes::Bytes; +use http_body::Body; +use http_body_util::combinators::BoxBody; +use http_body_util::{BodyExt, Empty}; + +pub type BoxedError = Box; + +pub type SimpleBody = BoxBody; + +pub fn to_simple_body(body: B) -> SimpleBody +where + B: Body + Send + Sync + 'static, + B::Error: Into, +{ + body.map_err(|e| e.into() as BoxedError).boxed() +} + +pub fn simple_empty_body() -> SimpleBody { + Empty::::new() + .map_err(|_| unreachable!("Infallable")) + .boxed() +} diff --git a/dynamic-proxy/src/graceful_shutdown.rs b/dynamic-proxy/src/graceful_shutdown.rs new file mode 100644 index 000000000..0bd6fa5dc --- /dev/null +++ b/dynamic-proxy/src/graceful_shutdown.rs @@ -0,0 +1,101 @@ +//! Near-identical copy of hyper_util::server::graceful::GracefulShutdown +//! that derives `Clone` and adds a `subscribe` method. +//! https://github.com/hyperium/hyper-util/blob/master/src/server/graceful.rs + +use hyper_util::server::graceful::GracefulConnection; +use pin_project_lite::pin_project; +use std::{ + fmt::{self, Debug}, + future::Future, + pin::Pin, + task::{self, Poll}, +}; +use tokio::sync::watch; + +#[derive(Clone)] // Added in Plane +pub struct GracefulShutdown { + tx: watch::Sender<()>, +} + +impl GracefulShutdown { + /// Create a new graceful shutdown helper. + pub fn new() -> Self { + let (tx, _) = watch::channel(()); + Self { tx } + } + + /// Wrap a future for graceful shutdown watching. + pub fn watch(&self, conn: C) -> impl Future { + let mut rx = self.tx.subscribe(); + GracefulConnectionFuture::new(conn, async move { + let _ = rx.changed().await; + // hold onto the rx until the watched future is completed + rx + }) + } + + // Added in Plane + pub fn subscribe(&self) -> watch::Receiver<()> { + self.tx.subscribe() + } + + /// Signal shutdown for all watched connections. + /// + /// This returns a `Future` which will complete once all watched + /// connections have shutdown. + pub async fn shutdown(self) { + let Self { tx } = self; + + // signal all the watched futures about the change + let _ = tx.send(()); + // and then wait for all of them to complete + tx.closed().await; + } +} + +pin_project! { + struct GracefulConnectionFuture { + #[pin] + conn: C, + #[pin] + cancel: F, + #[pin] + // If cancelled, this is held until the inner conn is done. + cancelled_guard: Option, + } +} + +impl GracefulConnectionFuture { + fn new(conn: C, cancel: F) -> Self { + Self { + conn, + cancel, + cancelled_guard: None, + } + } +} + +impl Debug for GracefulConnectionFuture { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("GracefulConnectionFuture").finish() + } +} + +impl Future for GracefulConnectionFuture +where + C: GracefulConnection, + F: Future, +{ + type Output = C::Output; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let mut this = self.project(); + if this.cancelled_guard.is_none() { + if let Poll::Ready(guard) = this.cancel.poll(cx) { + this.cancelled_guard.set(Some(guard)); + this.conn.as_mut().graceful_shutdown(); + } + } + this.conn.poll(cx) + } +} diff --git a/dynamic-proxy/src/https_redirect.rs b/dynamic-proxy/src/https_redirect.rs new file mode 100644 index 000000000..c15af625f --- /dev/null +++ b/dynamic-proxy/src/https_redirect.rs @@ -0,0 +1,70 @@ +use crate::body::{simple_empty_body, BoxedError, SimpleBody}; +use http::{ + header, + uri::{Authority, Scheme}, + Request, Response, StatusCode, Uri, +}; +use hyper::{body::Incoming, service::Service}; +use std::{future::ready, pin::Pin, str::FromStr}; + +/// A hyper service that redirects HTTP requests to HTTPS. +#[derive(Debug, Clone)] +pub struct HttpsRedirectService; + +impl HttpsRedirectService { + fn call_inner(request: Request) -> Result, StatusCode> { + // Get the host header. + let hostname = request + .headers() + .get(header::HOST) + .ok_or(StatusCode::BAD_REQUEST)?; + // Parse the host header into an authority. + let authority = + Authority::from_str(hostname.to_str().map_err(|_| StatusCode::BAD_REQUEST)?) + .map_err(|_| StatusCode::BAD_REQUEST)?; + // Strip the port. + let authority = + Authority::from_str(authority.host()).expect("Valid host is always valid authority."); + + let request_uri = request.uri().clone(); + + // Set the scheme to HTTPS + let mut parts = request_uri.into_parts(); + parts.scheme = Some(Scheme::HTTPS); + + parts.authority = Some(authority); + + // Build the new URI + let new_uri = Uri::from_parts(parts).expect("URI is always valid"); + + let response = Response::builder() + .status(StatusCode::MOVED_PERMANENTLY) + .header(header::LOCATION, new_uri.to_string()) + .body(simple_empty_body()); + + response.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + } +} + +impl Service> for HttpsRedirectService { + type Response = Response; + type Error = BoxedError; + type Future = Pin, BoxedError>>>>; + + fn call(&self, request: Request) -> Self::Future { + let result = Self::call_inner(request); + + let result = match result { + Ok(response) => response, + Err(status) => { + tracing::error!("Error redirecting to HTTPS: {}", status); + Response::builder() + .status(status) + .body(simple_empty_body()) + .expect("Response is always valid") + } + }; + + Box::pin(ready(Ok(result))) + } +} diff --git a/dynamic-proxy/src/lib.rs b/dynamic-proxy/src/lib.rs new file mode 100644 index 000000000..d42c6e56b --- /dev/null +++ b/dynamic-proxy/src/lib.rs @@ -0,0 +1,11 @@ +pub mod body; +mod graceful_shutdown; +pub mod https_redirect; +pub mod proxy; +pub mod request; +pub mod server; +mod upgrade; + +pub use hyper; +pub use rustls; +pub use tokio_rustls; diff --git a/dynamic-proxy/src/proxy.rs b/dynamic-proxy/src/proxy.rs new file mode 100644 index 000000000..5d6561445 --- /dev/null +++ b/dynamic-proxy/src/proxy.rs @@ -0,0 +1,141 @@ +use crate::{ + body::{simple_empty_body, to_simple_body, SimpleBody}, + request::should_upgrade, + upgrade::{split_request, split_response, UpgradeHandler}, +}; +use http::StatusCode; +use hyper::{Request, Response}; +use hyper_util::{ + client::legacy::{connect::HttpConnector, Client}, + rt::TokioExecutor, +}; +use std::{convert::Infallible, time::Duration}; + +const DEFAULT_TIMEOUT: Duration = Duration::from_secs(10); + +/// A client for proxying HTTP requests to an upstream server. +#[derive(Clone)] +pub struct ProxyClient { + client: Client, + timeout: Duration, +} + +impl Default for ProxyClient { + fn default() -> Self { + Self::new() + } +} + +impl ProxyClient { + pub fn new() -> Self { + let client = Client::builder(TokioExecutor::new()).build(HttpConnector::new()); + Self { + client, + timeout: DEFAULT_TIMEOUT, + } + } + + /// Sends an HTTP request to the upstream server and returns the response. + /// If the request establishes a websocket connection, an upgrade handler is returned. + /// In this case, you must call and await `.run()` on the upgrade handler (i.e. in a tokio task) + /// to ensure that messages are properly sent and received. + pub async fn request( + &self, + request: Request, + ) -> Result<(Response, Option), Infallible> { + let url = request.uri().to_string(); + + let res = self.handle_request(request).await; + + let res = match res { + Ok(res) => res, + Err(ProxyError::Timeout) => { + tracing::warn!(url, "Upstream request failed"); + return Ok(( + Response::builder() + .status(StatusCode::GATEWAY_TIMEOUT) + .body(simple_empty_body()) + .expect("Failed to build response"), + None, + )); + } + Err(e) => { + tracing::warn!(url, ?e, "Upstream request failed"); + return Ok(( + Response::builder() + .status(StatusCode::BAD_GATEWAY) + .body(simple_empty_body()) + .expect("Failed to build response"), + None, + )); + } + }; + + let (res, upgrade_handler) = res; + let (parts, body) = res.into_parts(); + let res = Response::from_parts(parts, to_simple_body(body)); + + Ok((res, upgrade_handler)) + } + + async fn handle_request( + &self, + request: Request, + ) -> Result<(Response, Option), ProxyError> { + if should_upgrade(&request) { + let (response, upgrade_handler) = self.handle_upgrade(request).await?; + Ok((response, Some(upgrade_handler))) + } else { + let result = self.upstream_request(request).await?; + Ok((result, None)) + } + } + + async fn handle_upgrade( + &self, + request: Request, + ) -> Result<(Response, UpgradeHandler), ProxyError> { + let (upstream_request, request_with_body) = split_request(request); + let res = self.upstream_request(upstream_request).await?; + let (upstream_response, response_with_body) = split_response(res); + + let upgrade_handler = UpgradeHandler::new(request_with_body, response_with_body); + + Ok((upstream_response, upgrade_handler)) + } + + async fn upstream_request( + &self, + request: Request, + ) -> Result, ProxyError> { + let res = match tokio::time::timeout(self.timeout, self.client.request(request)).await { + Ok(Ok(res)) => res, + Err(_) => { + return Err(ProxyError::Timeout); + } + Ok(Err(e)) => { + return Err(ProxyError::RequestFailed(e.into())); + } + }; + + let (parts, body) = res.into_parts(); + let res = Response::from_parts(parts, to_simple_body(body)); + + Ok(res) + } +} + +#[derive(thiserror::Error, Debug)] +pub enum ProxyError { + #[error("Upstream request timed out.")] + Timeout, + + #[error("Upstream request failed: {0}")] + RequestFailed(#[from] Box), + + #[error("Failed to upgrade response: {0}")] + UpgradeError(#[from] hyper::Error), + + #[error("IO error: {0}")] + IoError(#[from] tokio::io::Error), +} diff --git a/dynamic-proxy/src/request.rs b/dynamic-proxy/src/request.rs new file mode 100644 index 000000000..e897efe13 --- /dev/null +++ b/dynamic-proxy/src/request.rs @@ -0,0 +1,85 @@ +use crate::body::{to_simple_body, SimpleBody}; +use bytes::Bytes; +use http::{ + request::Parts, + uri::{Authority, Scheme}, + HeaderMap, HeaderName, HeaderValue, Request, Uri, +}; +use http_body::Body; +use std::{net::SocketAddr, str::FromStr}; + +/// Represents an HTTP request (from hyper) with helpers for mutating it. +pub struct MutableRequest +where + T: Body + Send + Sync + 'static, + T::Error: Into>, +{ + pub parts: Parts, + pub body: T, +} + +impl MutableRequest +where + T: Body + Send + Sync + 'static, + T::Error: Into>, +{ + pub fn from_request(request: Request) -> Self { + let (parts, body) = request.into_parts(); + Self { parts, body } + } + + pub fn into_request(self) -> Request { + Request::from_parts(self.parts, self.body) + } + + pub fn into_request_with_simple_body(self) -> Request { + Request::from_parts(self.parts, to_simple_body(self.body)) + } + + /// Rewrite the request so that it points to the given upstream address. + pub fn set_upstream_address(&mut self, address: SocketAddr) { + let uri = std::mem::take(&mut self.parts.uri); + let mut uri_parts = uri.into_parts(); + uri_parts.scheme = Some(Scheme::HTTP); + uri_parts.authority = Some( + Authority::try_from(address.to_string()) + .expect("SocketAddr should always be a valid authority."), + ); + self.parts.uri = Uri::from_parts(uri_parts).expect("URI should always be valid."); + } + + /// Add a header to the request. + /// + /// If the header is invalid, it will be ignored and logged. + pub fn add_header(&mut self, key: &str, value: &str) { + let Ok(key) = HeaderName::from_str(key) else { + tracing::error!("Attempted to set invalid header name: {}", key); + return; + }; + let Ok(value) = HeaderValue::from_str(value) else { + // Not logging the value, which could be sensitive. + tracing::error!("Attempted to set invalid header value with key: {}", key); + return; + }; + self.parts.headers.append(key, value); + } + + pub fn headers_mut(&mut self) -> &mut HeaderMap { + &mut self.parts.headers + } +} + +pub fn should_upgrade(request: &Request) -> bool { + let Some(conn_header) = request.headers().get("connection") else { + return false; + }; + + let Ok(conn_header) = conn_header.to_str() else { + return false; + }; + + conn_header + .to_lowercase() + .split(',') + .any(|s| s.trim() == "upgrade") +} diff --git a/dynamic-proxy/src/server.rs b/dynamic-proxy/src/server.rs new file mode 100644 index 000000000..1aa323373 --- /dev/null +++ b/dynamic-proxy/src/server.rs @@ -0,0 +1,349 @@ +use crate::{ + body::SimpleBody, graceful_shutdown::GracefulShutdown, https_redirect::HttpsRedirectService, +}; +use anyhow::Result; +use http::HeaderValue; +use hyper::{body::Incoming, service::Service, Request, Response}; +use hyper_util::{ + rt::{TokioExecutor, TokioIo}, + server::conn::auto::Builder as ServerBuilder, +}; +use rustls::{server::ResolvesServerCert, ServerConfig}; +use std::{ + net::{IpAddr, SocketAddr}, + sync::Arc, + time::Duration, +}; +use tokio::{net::TcpListener, select, task::JoinSet}; +use tokio_rustls::TlsAcceptor; + +/// Header which passes the client's IP address to the backend. +const X_FORWARDED_FOR: &str = "x-forwarded-for"; + +/// Header which passes the client's protocol (http or https) to the backend. +const X_FORWARDED_PROTO: &str = "x-forwarded-proto"; + +/// A simple server that wraps a hyper service and handles requests. +/// The server can be configured to listen for either HTTP and HTTPS, +/// and supports graceful shutdown and x-forwarded-* headers. +pub struct SimpleHttpServer { + handle: tokio::task::JoinHandle>, + graceful_shutdown: Option, +} + +#[must_use] // Otherwise, the tasks we started would be stopped as soon as the graceful shutdown is initiated. +async fn listen_loop( + listener: TcpListener, + service: S, + graceful_shutdown: GracefulShutdown, +) -> JoinSet<()> +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into>, +{ + let mut recv = graceful_shutdown.subscribe(); + let mut join_set = JoinSet::new(); + + loop { + let stream = select! { + stream = listener.accept() => stream, + _ = recv.changed() => break, + }; + + let (stream, remote_addr) = match stream { + Ok((stream, remote_addr)) => (stream, remote_addr), + Err(e) => { + tracing::warn!(?e, "Failed to accept connection."); + continue; + } + }; + let remote_ip = remote_addr.ip(); + let service = WrappedService::new(service.clone(), remote_ip, "http"); + + let server = ServerBuilder::new(TokioExecutor::new()); + let io = TokioIo::new(stream); + let conn = server.serve_connection_with_upgrades(io, service); + + let conn = graceful_shutdown.watch(conn.into_owned()); + join_set.spawn(async { + if let Err(e) = conn.await { + tracing::warn!(?e, "Failed to serve connection."); + } + }); + } + + // Even though join_set is never used, we return it to keep it from being dropped + // until the graceful shutdown (or timeout) is complete. Otherwise, the tasks we started + // would be stopped as soon as the graceful shutdown is initiated. + join_set +} + +#[must_use] // Otherwise, the tasks we started would be stopped as soon as the graceful shutdown is initiated. +async fn listen_loop_tls( + listener: TcpListener, + service: S, + resolver: Arc, + graceful_shutdown: GracefulShutdown, +) -> JoinSet<()> +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into>, +{ + let server_config = ServerConfig::builder() + .with_no_client_auth() + .with_cert_resolver(resolver); + let tls_acceptor = TlsAcceptor::from(Arc::new(server_config)); + let mut recv = graceful_shutdown.subscribe(); + let mut join_set = JoinSet::new(); + + loop { + let stream = select! { + stream = listener.accept() => stream, + _ = recv.changed() => break, + }; + + let (stream, remote_addr) = match stream { + Ok((stream, remote_addr)) => (stream, remote_addr), + Err(e) => { + tracing::warn!(?e, "Failed to accept connection."); + continue; + } + }; + let remote_ip = remote_addr.ip(); + let service = WrappedService::new(service.clone(), remote_ip, "https"); + let tls_acceptor = tls_acceptor.clone(); + + let graceful_shutdown = graceful_shutdown.clone(); + join_set.spawn(async move { + let server = ServerBuilder::new(TokioExecutor::new()); + + let stream = match tls_acceptor.accept(stream).await { + Ok(stream) => stream, + Err(e) => { + tracing::warn!(?e, "Failed to accept TLS connection."); + return; + } + }; + let io = TokioIo::new(stream); + + let conn = server.serve_connection_with_upgrades(io, service); + let conn = graceful_shutdown.watch(conn.into_owned()); + + if let Err(e) = conn.await { + tracing::warn!(?e, "Failed to serve connection."); + } + }); + } + + // Even though join_set is never used, we return it to keep it from being dropped + // until the graceful shutdown (or timeout) is complete. Otherwise, the tasks we started + // would be stopped as soon as the graceful shutdown is initiated. + join_set +} + +pub enum HttpsConfig { + Http, + Https { + resolver: Arc, + }, +} + +impl HttpsConfig { + pub fn from_resolver(resolver: R) -> Self { + Self::Https { + resolver: Arc::new(resolver), + } + } + + pub fn http() -> Self { + Self::Http + } +} + +impl SimpleHttpServer { + pub fn new(service: S, listener: TcpListener, https_config: HttpsConfig) -> Result + where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + S::Error: Into>, + { + let graceful_shutdown = GracefulShutdown::new(); + + let handle = match https_config { + HttpsConfig::Http => { + tokio::spawn(listen_loop(listener, service, graceful_shutdown.clone())) + } + HttpsConfig::Https { resolver } => { + if rustls::crypto::ring::default_provider() + .install_default() + .is_err() + { + tracing::info!("Using already-installed crypto provider.") + } + + tokio::spawn(listen_loop_tls( + listener, + service, + resolver, + graceful_shutdown.clone(), + )) + } + }; + + Ok(Self { + handle, + graceful_shutdown: Some(graceful_shutdown), + }) + } + + pub async fn graceful_shutdown(mut self) { + println!("Shutting down"); + let graceful_shutdown = self + .graceful_shutdown + .take() + .expect("self.graceful_shutdown is always set"); + graceful_shutdown.shutdown().await; + } + + pub async fn graceful_shutdown_with_timeout(mut self, timeout: Duration) { + let graceful_shutdown = self + .graceful_shutdown + .take() + .expect("self.graceful_shutdown is always set"); + let result = tokio::time::timeout(timeout, graceful_shutdown.shutdown()).await; + + if let Err(e) = result { + tracing::warn!(?e, "Timed out waiting for graceful shutdown, aborting."); + } + } +} + +impl Drop for SimpleHttpServer { + fn drop(&mut self) { + if self.graceful_shutdown.is_some() { + tracing::warn!("Shutting down SimpleHttpServer without a call to graceful_shutdown. Connections will be dropped abruptly!"); + } + + self.handle.abort(); + } +} + +pub struct ServerWithHttpRedirect { + http_server: SimpleHttpServer, + https_server: Option, +} + +pub struct ServerWithHttpRedirectHttpsConfig { + pub https_port: u16, + pub resolver: Arc, +} + +pub struct ServerWithHttpRedirectConfig { + pub http_port: u16, + pub https_config: Option, +} + +impl ServerWithHttpRedirect { + pub async fn new(service: S, server_config: ServerWithHttpRedirectConfig) -> Result + where + S: Service, Response = Response> + + Clone + + Send + + Sync + + 'static, + S::Future: Send + 'static, + S::Error: Into>, + { + if let Some(https_config) = server_config.https_config { + // Serve HTTPS + let https_listener = + TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], https_config.https_port))) + .await?; + let https_server = SimpleHttpServer::new( + service, + https_listener, + HttpsConfig::Https { + resolver: https_config.resolver, + }, + )?; + + // Redirect HTTP to HTTPS + let http_listener = + TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port))) + .await?; + let http_server = + SimpleHttpServer::new(HttpsRedirectService, http_listener, HttpsConfig::Http)?; + + Ok(Self { + http_server, + https_server: Some(https_server), + }) + } else { + let listener = + TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], server_config.http_port))) + .await?; + let http_server = SimpleHttpServer::new(service, listener, HttpsConfig::Http)?; + + Ok(Self { + http_server, + https_server: None, + }) + } + } + + pub async fn graceful_shutdown_with_timeout(self, timeout: Duration) { + if let Some(https_server) = self.https_server { + tokio::join!( + self.http_server.graceful_shutdown_with_timeout(timeout), + https_server.graceful_shutdown_with_timeout(timeout) + ); + } else { + self.http_server + .graceful_shutdown_with_timeout(timeout) + .await; + } + } +} + +/// A service that wraps another service and sets +/// X-Forwarded-For and X-Forwarded-Proto headers. +struct WrappedService { + inner: S, + forwarded_for: IpAddr, + forwarded_proto: &'static str, +} + +impl WrappedService { + pub fn new(inner: S, forwarded_for: IpAddr, forwarded_proto: &'static str) -> Self { + Self { + inner, + forwarded_for, + forwarded_proto, + } + } +} + +impl Service> for WrappedService +where + S: Service, Response = Response>, +{ + type Response = S::Response; + type Error = S::Error; + type Future = S::Future; + + fn call(&self, request: Request) -> Self::Future { + let mut request = request; + request.headers_mut().insert( + X_FORWARDED_FOR, + HeaderValue::from_str(&format!("{}", self.forwarded_for)) + .expect("X-Forwarded-For is always valid"), + ); + request.headers_mut().insert( + X_FORWARDED_PROTO, + HeaderValue::from_str(self.forwarded_proto).expect("X-Forwarded-Proto is always valid"), + ); + self.inner.call(request) + } +} diff --git a/dynamic-proxy/src/upgrade.rs b/dynamic-proxy/src/upgrade.rs new file mode 100644 index 000000000..c90195c36 --- /dev/null +++ b/dynamic-proxy/src/upgrade.rs @@ -0,0 +1,64 @@ +use crate::{ + body::{simple_empty_body, SimpleBody}, + proxy::ProxyError, +}; +use http::{Request, Response}; +use hyper_util::rt::TokioIo; +use tokio::io::copy_bidirectional; + +/// Split a request into two requests. The first request has an empty body, +/// and the second request has the original body. +/// +/// The first request is forwarded on upstream. The second is used for bidirectional +/// communication after the connection has been upgraded. +pub fn split_request(request: Request) -> (Request, Request) { + let (parts, body) = request.into_parts(); + + let request1 = Request::from_parts(parts.clone(), simple_empty_body()); + let request2 = Request::from_parts(parts, body); + + (request1, request2) +} + +/// Clone a response, using an empty body. +pub fn split_response(response: Response) -> (Response, Response) { + let (parts, body) = response.into_parts(); + + let response1 = Response::from_parts(parts.clone(), simple_empty_body()); + let response2 = Response::from_parts(parts, body); + + (response1, response2) +} + +/// Wraps connection state that is needed to upgrade a connection so it can be passed back +/// from the connection handler. +/// +/// The receiver should call `.run` to turn this into a future, and then await it. +pub struct UpgradeHandler { + pub request: Request, + pub response: Response, +} + +impl UpgradeHandler { + pub fn new(request: Request, response: Response) -> Self { + Self { request, response } + } + + pub async fn run(self) -> Result<(), ProxyError> { + let response = hyper::upgrade::on(self.response) + .await + .map_err(ProxyError::UpgradeError)?; + let mut response = TokioIo::new(response); + + let request = hyper::upgrade::on(self.request) + .await + .map_err(ProxyError::UpgradeError)?; + let mut request = TokioIo::new(request); + + copy_bidirectional(&mut request, &mut response) + .await + .map_err(ProxyError::IoError)?; + + Ok(()) + } +} diff --git a/dynamic-proxy/tests/common/cert.rs b/dynamic-proxy/tests/common/cert.rs new file mode 100644 index 000000000..208834628 --- /dev/null +++ b/dynamic-proxy/tests/common/cert.rs @@ -0,0 +1,50 @@ +use rcgen::generate_simple_self_signed; +use rustls::crypto::ring::sign::any_supported_type; +use rustls::server::{ClientHello, ResolvesServerCert}; +use rustls::{pki_types::PrivateKeyDer, sign::CertifiedKey}; +use std::sync::Arc; + +const CERTIFICATE_SUBJECT_ALT_NAME: &str = "plane.test"; + +/// A certificate resolver that generates its own certificate on creation, +/// and uses that for all requests. +#[derive(Debug)] +pub struct StaticCertificateResolver { + certified_key: Arc, +} + +#[allow(unused)] +impl StaticCertificateResolver { + pub fn new() -> Self { + let subject_alt_names = vec![CERTIFICATE_SUBJECT_ALT_NAME.to_string()]; + + let rcgen::CertifiedKey { cert, key_pair } = + generate_simple_self_signed(subject_alt_names).unwrap(); + + let key = PrivateKeyDer::try_from(key_pair.serialized_der()) + .expect("Could not convert key pair to der"); + + let key = any_supported_type(&key).expect("Could not convert key to supported type"); + + let cert = cert.der().clone(); + + let certified_key = Arc::new(CertifiedKey::new(vec![cert], key)); + + Self { certified_key } + } + + pub fn certificate(&self) -> reqwest::Certificate { + let der = &self.certified_key.cert[0]; + reqwest::Certificate::from_der(der).unwrap() + } + + pub fn hostname(&self) -> String { + CERTIFICATE_SUBJECT_ALT_NAME.to_string() + } +} + +impl ResolvesServerCert for StaticCertificateResolver { + fn resolve(&self, _client_hello: ClientHello) -> Option> { + Some(self.certified_key.clone()) + } +} diff --git a/dynamic-proxy/tests/common/hello_world_service.rs b/dynamic-proxy/tests/common/hello_world_service.rs new file mode 100644 index 000000000..666136e7f --- /dev/null +++ b/dynamic-proxy/tests/common/hello_world_service.rs @@ -0,0 +1,46 @@ +use bytes::Bytes; +use dynamic_proxy::body::{to_simple_body, SimpleBody}; +use http_body_util::{BodyExt, Full}; +use hyper::{body::Incoming, service::Service, Request, Response}; +use std::{convert::Infallible, future::Future, pin::Pin}; + +/// A service that returns a greeting with the X-Forwarded-For and X-Forwarded-Proto headers. +#[derive(Clone)] +pub struct HelloWorldService; + +impl Service> for HelloWorldService { + type Response = Response; + type Error = Infallible; + type Future = Pin, Infallible>> + Send>>; + + fn call(&self, request: Request) -> Self::Future { + Box::pin(async { + let x_forwarded_for = request + .headers() + .get("x-forwarded-for") + .map(|h| h.to_str().unwrap().to_string()) + .unwrap_or_default(); + let x_forwarded_proto = request + .headers() + .get("x-forwarded-proto") + .map(|h| h.to_str().unwrap().to_string()) + .unwrap_or_default(); + + let _ = request.collect().await.unwrap().to_bytes(); + + let body = format!( + "Hello, world! X-Forwarded-For: {}, X-Forwarded-Proto: {}", + x_forwarded_for, x_forwarded_proto + ); + + let response = Response::builder() + .status(200) + .body(to_simple_body(Full::new(Bytes::from( + body.as_bytes().to_vec(), + )))) + .unwrap(); + + Ok(response) + }) + } +} diff --git a/dynamic-proxy/tests/common/mod.rs b/dynamic-proxy/tests/common/mod.rs new file mode 100644 index 000000000..aff08bdab --- /dev/null +++ b/dynamic-proxy/tests/common/mod.rs @@ -0,0 +1,5 @@ +pub mod cert; +pub mod hello_world_service; +pub mod simple_axum_server; +pub mod simple_upgrade_service; +pub mod websocket_echo_server; diff --git a/dynamic-proxy/tests/common/simple_axum_server.rs b/dynamic-proxy/tests/common/simple_axum_server.rs new file mode 100644 index 000000000..b6a89cb1e --- /dev/null +++ b/dynamic-proxy/tests/common/simple_axum_server.rs @@ -0,0 +1,77 @@ +use axum::{body::Body, extract::Request, routing::any, Json, Router}; +use http::Method; +use http_body_util::BodyExt; +use serde::{Deserialize, Serialize}; +use std::{collections::HashMap, net::SocketAddr}; +use tokio::net::TcpListener; + +/// A simple server that returns the request info as json. +pub struct SimpleAxumServer { + handle: tokio::task::JoinHandle<()>, + addr: SocketAddr, +} + +#[allow(unused)] +impl SimpleAxumServer { + pub async fn new() -> Self { + let app = Router::new() + .route("/*path", any(return_request_info)) + .route("/", any(return_request_info)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let tcp_listener = TcpListener::bind(addr).await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + let handle = tokio::spawn(async { + axum::serve(tcp_listener, app.into_make_service()) + .await + .unwrap(); + }); + + Self { handle, addr } + } + + pub fn addr(&self) -> SocketAddr { + self.addr + } +} + +impl Drop for SimpleAxumServer { + fn drop(&mut self) { + self.handle.abort(); + } +} + +// Handler function for the root route +async fn return_request_info(method: Method, request: Request) -> Json { + let method = method.to_string(); + + let path = request.uri().path().to_string(); + let query = request.uri().query().unwrap_or("").to_string(); + + let headers: HashMap = request + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string())) + .collect(); + + let body = request.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + + Json(RequestInfo { + path, + query, + method, + headers, + body, + }) +} + +#[derive(Serialize, Deserialize)] +pub struct RequestInfo { + pub path: String, + pub query: String, + pub method: String, + pub headers: HashMap, + pub body: String, +} diff --git a/dynamic-proxy/tests/common/simple_upgrade_service.rs b/dynamic-proxy/tests/common/simple_upgrade_service.rs new file mode 100644 index 000000000..dee216099 --- /dev/null +++ b/dynamic-proxy/tests/common/simple_upgrade_service.rs @@ -0,0 +1,73 @@ +use bytes::Bytes; +use dynamic_proxy::body::{to_simple_body, SimpleBody}; +use http::header::CONNECTION; +use http_body_util::{Empty, Full}; +use hyper::{ + body::Incoming, + header::{HeaderValue, UPGRADE}, + service::Service, + upgrade::Upgraded, + Request, Response, StatusCode, +}; +use hyper_util::rt::TokioIo; +use std::{convert::Infallible, future::Future, pin::Pin}; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +/// A service that upgrades the connection and echos messages back to the client. +/// (Note: this does not use the actual websocket protocol on the wire, but is sufficient to +/// test the upgrade path.) +#[derive(Clone)] +pub struct SimpleUpgradeService; + +impl Service> for SimpleUpgradeService { + type Response = Response; + type Error = Infallible; + type Future = Pin, Infallible>> + Send>>; + + fn call(&self, mut req: Request) -> Self::Future { + Box::pin(async move { + if req.headers().contains_key(UPGRADE) { + // Handle upgrade + let mut res = Response::new(to_simple_body(Empty::new())); + *res.status_mut() = StatusCode::SWITCHING_PROTOCOLS; + res.headers_mut() + .insert(CONNECTION, HeaderValue::from_static("upgrade")); + res.headers_mut() + .insert(UPGRADE, HeaderValue::from_static("websocket")); + tokio::task::spawn(async move { + if let Ok(upgraded) = hyper::upgrade::on(&mut req).await { + if let Err(e) = handle_upgraded_connection(upgraded).await { + tracing::error!("Error handling upgraded connection: {}", e); + } + } + }); + + Ok(res) + } else { + // Regular response + let response = Response::builder() + .status(200) + .body(to_simple_body(Full::new(Bytes::from("Hello, world!")))) + .unwrap(); + + Ok(response) + } + }) + } +} + +async fn handle_upgraded_connection(upgraded: Upgraded) -> std::io::Result<()> { + let mut upgraded = TokioIo::new(upgraded); + + // echo message back to client + loop { + let mut buf = vec![0; 1024]; + let n = upgraded.read(&mut buf).await?; + if n == 0 { + break; + } + upgraded.write_all(&buf[..n]).await?; + } + + Ok(()) +} diff --git a/dynamic-proxy/tests/common/websocket_echo_server.rs b/dynamic-proxy/tests/common/websocket_echo_server.rs new file mode 100644 index 000000000..a406eb766 --- /dev/null +++ b/dynamic-proxy/tests/common/websocket_echo_server.rs @@ -0,0 +1,64 @@ +use axum::{ + extract::ws::{Message, WebSocket, WebSocketUpgrade}, + response::IntoResponse, + routing::get, + Router, +}; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +/// A websocket echo server that echos messages back to the client. +pub struct WebSocketEchoServer { + handle: tokio::task::JoinHandle<()>, + addr: SocketAddr, +} + +#[allow(unused)] +impl WebSocketEchoServer { + pub async fn new() -> Self { + let app = Router::new().route("/ws", get(ws_handler)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let tcp_listener = TcpListener::bind(addr).await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { + axum::serve(tcp_listener, app.into_make_service()) + .await + .unwrap(); + }); + + Self { handle, addr } + } + + pub fn addr(&self) -> SocketAddr { + self.addr + } +} + +impl Drop for WebSocketEchoServer { + fn drop(&mut self) { + self.handle.abort(); + } +} + +async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse { + ws.on_upgrade(handle_socket) +} + +async fn handle_socket(mut socket: WebSocket) { + while let Some(msg) = socket.recv().await { + match msg { + Ok(msg) => { + if let Message::Text(text) = msg { + if socket.send(Message::Text(text)).await.is_err() { + panic!("WebSocket connection closed."); + } + } + } + Err(e) => { + panic!("Error receiving message: {}", e); + } + } + } +} diff --git a/dynamic-proxy/tests/graceful.rs b/dynamic-proxy/tests/graceful.rs new file mode 100644 index 000000000..5ca779c85 --- /dev/null +++ b/dynamic-proxy/tests/graceful.rs @@ -0,0 +1,64 @@ +use bytes::Bytes; +use dynamic_proxy::body::to_simple_body; +use dynamic_proxy::server::{HttpsConfig, SimpleHttpServer}; +use hyper::StatusCode; +use std::convert::Infallible; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tokio::time::Duration; + +// Ref: https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs + +async fn slow_hello_world( + _: hyper::Request, +) -> Result, Infallible> { + tokio::time::sleep(Duration::from_secs(1)).await; // emulate slow request + let body = http_body_util::Full::::from("Hello, world!".to_owned()); + let body = to_simple_body(body); + Ok(hyper::Response::new(body)) +} + +#[tokio::test] +async fn test_graceful_shutdown() { + // Start the server + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = SimpleHttpServer::new( + hyper::service::service_fn(slow_hello_world), + listener, + HttpsConfig::Http, + ) + .unwrap(); + + let url = format!("http://{}", addr); + + // Create a client and start a POST request without finishing the body + let client = reqwest::Client::new(); + + let response_handle = { + let client = client.clone(); + let url = url.clone(); + tokio::spawn(async move { client.get(&url).send().await.unwrap() }) + }; + + tokio::time::sleep(Duration::from_millis(100)).await; + + // Call server.graceful_shutdown() + let shutdown_task = tokio::spawn(async move { server.graceful_shutdown().await }); + + tokio::time::sleep(Duration::from_millis(100)).await; + + let response = response_handle.await.unwrap(); + + // Wait for the shutdown task to complete. + shutdown_task.await.unwrap(); + + // Ensure that the result is as expected + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text().await.unwrap(), "Hello, world!"); + + // Attempt to make another request, which should fail due to the server shutting down + let result = client.get(&url).send().await; + assert!(result.is_err()); +} diff --git a/dynamic-proxy/tests/graceful_https.rs b/dynamic-proxy/tests/graceful_https.rs new file mode 100644 index 000000000..c36cb680f --- /dev/null +++ b/dynamic-proxy/tests/graceful_https.rs @@ -0,0 +1,74 @@ +use bytes::Bytes; +use common::cert::StaticCertificateResolver; +use dynamic_proxy::body::to_simple_body; +use dynamic_proxy::server::HttpsConfig; +use dynamic_proxy::server::SimpleHttpServer; +use hyper::StatusCode; +use std::convert::Infallible; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tokio::time::Duration; + +mod common; + +// Ref: https://github.com/hyperium/hyper-util/blob/master/examples/server_graceful.rs + +#[tokio::test] +async fn test_graceful_shutdown_https() { + // Set up HTTPS configuration + let resolver = StaticCertificateResolver::new(); + let cert = resolver.certificate(); + let hostname = resolver.hostname(); + + // Start the server + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let server = SimpleHttpServer::new( + hyper::service::service_fn(|_| async move { + tokio::time::sleep(Duration::from_secs(1)).await; // emulate slow request + let body = http_body_util::Full::::from("Hello, world!".to_owned()); + let body = to_simple_body(body); + Ok::<_, Infallible>(hyper::Response::new(body)) + }), + listener, + HttpsConfig::from_resolver(resolver), + ) + .unwrap(); + + let url = format!("https://{}:{}", hostname, addr.port()); + + // Create a client with HTTPS configuration + let client = reqwest::Client::builder() + .https_only(true) + .add_root_certificate(cert) + .resolve(&hostname, addr /* port is ignored */) + .build() + .unwrap(); + + let response_handle = { + let client = client.clone(); + let url = url.clone(); + tokio::spawn(async move { client.get(&url).send().await.unwrap() }) + }; + + tokio::time::sleep(Duration::from_millis(600)).await; + + // Call server.graceful_shutdown() + let shutdown_task = tokio::spawn(async move { server.graceful_shutdown().await }); + + tokio::time::sleep(Duration::from_millis(200)).await; + + let response = response_handle.await.unwrap(); + + // Wait for the shutdown task to complete. + shutdown_task.await.unwrap(); + + // Ensure that the result is as expected + assert_eq!(response.status(), StatusCode::OK); + assert_eq!(response.text().await.unwrap(), "Hello, world!"); + + // Attempt to make another request, which should fail due to the server shutting down + let result = client.get(&url).send().await; + assert!(result.is_err()); +} diff --git a/dynamic-proxy/tests/hello_world_http.rs b/dynamic-proxy/tests/hello_world_http.rs new file mode 100644 index 000000000..d9cb4ab2e --- /dev/null +++ b/dynamic-proxy/tests/hello_world_http.rs @@ -0,0 +1,26 @@ +use common::hello_world_service::HelloWorldService; +use dynamic_proxy::server::{HttpsConfig, SimpleHttpServer}; +use hyper::StatusCode; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +mod common; + +#[tokio::test] +async fn test_hello_world_http() { + let service = HelloWorldService; + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let _server = SimpleHttpServer::new(service, listener, HttpsConfig::Http).unwrap(); + + let url = format!("http://{}", addr); + + let client = reqwest::Client::new(); + let res = client.get(url).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!( + res.text().await.unwrap(), + "Hello, world! X-Forwarded-For: 127.0.0.1, X-Forwarded-Proto: http" + ); +} diff --git a/dynamic-proxy/tests/https_test.rs b/dynamic-proxy/tests/https_test.rs new file mode 100644 index 000000000..95f06902a --- /dev/null +++ b/dynamic-proxy/tests/https_test.rs @@ -0,0 +1,39 @@ +use common::{cert::StaticCertificateResolver, hello_world_service::HelloWorldService}; +use dynamic_proxy::server::{HttpsConfig, SimpleHttpServer}; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +mod common; + +#[tokio::test] +async fn test_https() { + let resolver = StaticCertificateResolver::new(); + let cert = resolver.certificate(); + let hostname = resolver.hostname(); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let _server = SimpleHttpServer::new( + HelloWorldService, + listener, + HttpsConfig::from_resolver(resolver), + ) + .unwrap(); + + let client = reqwest::Client::builder() + .https_only(true) + .add_root_certificate(cert) + .resolve(&hostname, addr /* port is ignored */) + .build() + .unwrap(); + + let url = format!("https://{}:{}", hostname, addr.port()); + + let res = client.get(&url).send().await.unwrap(); + assert!(res.status().is_success()); + assert_eq!( + res.text().await.unwrap(), + "Hello, world! X-Forwarded-For: 127.0.0.1, X-Forwarded-Proto: https" + ); +} diff --git a/dynamic-proxy/tests/test_http_redirect.rs b/dynamic-proxy/tests/test_http_redirect.rs new file mode 100644 index 000000000..7639bd2e7 --- /dev/null +++ b/dynamic-proxy/tests/test_http_redirect.rs @@ -0,0 +1,77 @@ +use dynamic_proxy::{ + https_redirect::HttpsRedirectService, + server::{HttpsConfig, SimpleHttpServer}, +}; +use http::{header, StatusCode}; +use reqwest::{Response, Url}; +use std::net::{IpAddr, SocketAddr}; +use tokio::net::TcpListener; + +const DOMAIN: &str = "foo.bar.baz"; + +fn get_client() -> reqwest::Client { + reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .resolve(DOMAIN, SocketAddr::new(IpAddr::from([127, 0, 0, 1]), 0)) + .build() + .unwrap() +} + +async fn do_request(url: &str) -> Response { + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let port = listener.local_addr().unwrap().port(); + let https_config = HttpsConfig::http(); + let _server = SimpleHttpServer::new(HttpsRedirectService, listener, https_config); + + // url needs to have port + let mut url = Url::parse(url).unwrap(); + url.set_port(Some(port)).unwrap(); + + // Request to http://foo.bar.baz should redirect to https://foo.bar.baz + + get_client().get(url).send().await.unwrap() +} + +#[tokio::test] +async fn test_https_redirect() { + let response = do_request("http://foo.bar.baz").await; + + assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY); + assert_eq!( + response.headers().get(header::LOCATION).unwrap(), + "https://foo.bar.baz/" + ); +} + +#[tokio::test] +async fn test_https_redirect_with_slash_path() { + let response = do_request("http://foo.bar.baz/").await; + + assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY); + assert_eq!( + response.headers().get(header::LOCATION).unwrap(), + "https://foo.bar.baz/" + ); +} + +#[tokio::test] +async fn test_https_redirect_with_path() { + let response = do_request("http://foo.bar.baz/abc/123").await; + + assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY); + assert_eq!( + response.headers().get(header::LOCATION).unwrap(), + "https://foo.bar.baz/abc/123" + ); +} + +#[tokio::test] +async fn test_https_redirect_with_query_params() { + let response = do_request("http://foo.bar.baz/?a=1&b=2").await; + + assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY); + assert_eq!( + response.headers().get(header::LOCATION).unwrap(), + "https://foo.bar.baz/?a=1&b=2" + ); +} diff --git a/dynamic-proxy/tests/test_http_versions.rs b/dynamic-proxy/tests/test_http_versions.rs new file mode 100644 index 000000000..3734d6f4e --- /dev/null +++ b/dynamic-proxy/tests/test_http_versions.rs @@ -0,0 +1,50 @@ +use common::hello_world_service::HelloWorldService; +use dynamic_proxy::server::{HttpsConfig, SimpleHttpServer}; +use hyper::StatusCode; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +mod common; + +#[tokio::test] +async fn test_http1() { + let service = HelloWorldService; + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let _server = SimpleHttpServer::new(service, listener, HttpsConfig::Http).unwrap(); + + let url = format!("http://{}", addr); + + let client = reqwest::Client::builder().http1_only().build().unwrap(); + let res = client.get(url).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.version(), reqwest::Version::HTTP_11); + assert_eq!( + res.text().await.unwrap(), + "Hello, world! X-Forwarded-For: 127.0.0.1, X-Forwarded-Proto: http" + ); +} + +#[tokio::test] +async fn test_http2() { + let service = HelloWorldService; + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let _server = SimpleHttpServer::new(service, listener, HttpsConfig::Http).unwrap(); + + let url = format!("http://{}", addr); + + let client = reqwest::Client::builder() + .http2_prior_knowledge() + .build() + .unwrap(); + let res = client.get(url).send().await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert_eq!(res.version(), reqwest::Version::HTTP_2); + assert_eq!( + res.text().await.unwrap(), + "Hello, world! X-Forwarded-For: 127.0.0.1, X-Forwarded-Proto: http" + ); +} diff --git a/dynamic-proxy/tests/test_proxy_request.rs b/dynamic-proxy/tests/test_proxy_request.rs new file mode 100644 index 000000000..410361552 --- /dev/null +++ b/dynamic-proxy/tests/test_proxy_request.rs @@ -0,0 +1,137 @@ +use crate::common::simple_axum_server::SimpleAxumServer; +use anyhow::Result; +use bytes::Bytes; +use common::simple_axum_server::RequestInfo; +use dynamic_proxy::{ + body::{simple_empty_body, to_simple_body, BoxedError}, + proxy::ProxyClient, + request::MutableRequest, +}; +use http::{Method, Request, StatusCode}; +use http_body_util::{combinators::BoxBody, BodyExt, Full}; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +mod common; + +async fn make_request(req: Request>) -> Result { + let server = SimpleAxumServer::new().await; + let proxy_client = ProxyClient::new(); + + let mut req = MutableRequest::from_request(req); + req.set_upstream_address(server.addr()); + + let (res, upgrade_handler) = proxy_client.request(req.into_request()).await.unwrap(); + assert_eq!(res.status(), StatusCode::OK); + assert!(upgrade_handler.is_none()); + + let body = res.into_body().collect().await.unwrap().to_bytes(); + let result: RequestInfo = serde_json::from_slice(&body).unwrap(); + + Ok(result) +} + +#[tokio::test] +async fn test_proxy_simple_request() { + let req = Request::builder() + .method(Method::GET) + .uri("http://foo.bar".to_string()) + .body(simple_empty_body()) + .unwrap(); + + let result = make_request(req).await.unwrap(); + + assert_eq!(result.path, "/"); + assert_eq!(result.method, "GET"); + assert_eq!(result.headers.len(), 1); + assert!(result.headers.contains_key("host")); +} + +#[tokio::test] +async fn test_proxy_simple_post_request() { + let req = Request::builder() + .method(Method::POST) + .uri("http://foo.bar".to_string()) + .body(simple_empty_body()) + .unwrap(); + + let result = make_request(req).await.unwrap(); + + assert_eq!(result.path, "/"); + assert_eq!(result.method, "POST"); + assert_eq!(result.headers.len(), 1); + assert!(result.headers.contains_key("host")); +} + +#[tokio::test] +async fn test_proxy_request_with_path_and_query_params() { + let req = Request::builder() + .method(Method::POST) + .uri("http://foo.bar/foo/bar?baz=1&qux=2".to_string()) + .body(simple_empty_body()) + .unwrap(); + + let result = make_request(req).await.unwrap(); + + assert_eq!(result.path, "/foo/bar"); + assert_eq!(result.query, "baz=1&qux=2"); + assert_eq!(result.method, "POST"); + assert_eq!(result.headers.len(), 1); + assert!(result.headers.contains_key("host")); +} + +#[tokio::test] +async fn test_proxy_request_with_headers() { + let req = Request::builder() + .method(Method::GET) + .uri("http://foo.bar/foo".to_string()) + .header("X-Test", "test") + .body(simple_empty_body()) + .unwrap(); + + let result = make_request(req).await.unwrap(); + + assert_eq!(result.path, "/foo"); + assert_eq!(result.method, "GET"); + assert_eq!(result.headers.len(), 2); + assert!(result.headers.contains_key("host")); + assert_eq!(result.headers.get("x-test").unwrap(), "test"); +} + +#[tokio::test] +async fn test_proxy_body() { + let req = Request::builder() + .method(Method::POST) + .uri("http://foo.bar/foo".to_string()) + .body(to_simple_body(Full::new("test".into()))) + .unwrap(); + + let result = make_request(req).await.unwrap(); + + assert_eq!(result.path, "/foo"); + assert_eq!(result.method, "POST"); + assert_eq!(result.headers.len(), 2); + assert!(result.headers.contains_key("host")); + assert!(result.headers.contains_key("content-length")); + assert_eq!(result.body, "test"); +} + +#[tokio::test] +async fn test_proxy_no_upstream() { + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let tcp_listener = TcpListener::bind(addr).await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + let req = Request::builder() + .method(Method::GET) + .uri(format!("http://{}", addr)) + .body(simple_empty_body()) + .unwrap(); + + let client = ProxyClient::new(); + let (result, upgrade_handler) = client.request(req).await.unwrap(); + + // expect error HTTP 502 after timeout + assert_eq!(result.status(), StatusCode::GATEWAY_TIMEOUT); + assert!(upgrade_handler.is_none()); +} diff --git a/dynamic-proxy/tests/test_proxy_websocket.rs b/dynamic-proxy/tests/test_proxy_websocket.rs new file mode 100644 index 000000000..0af9a3657 --- /dev/null +++ b/dynamic-proxy/tests/test_proxy_websocket.rs @@ -0,0 +1,105 @@ +use common::websocket_echo_server::WebSocketEchoServer; +use dynamic_proxy::{ + body::SimpleBody, + proxy::ProxyClient, + request::MutableRequest, + server::{HttpsConfig, SimpleHttpServer}, +}; +use futures_util::{SinkExt, StreamExt}; +use http::{Request, Response}; +use hyper::{body::Incoming, service::Service}; +use std::{future::Future, net::SocketAddr, pin::Pin}; +use tokio::net::TcpListener; +use tokio_tungstenite::{connect_async, tungstenite::protocol::Message}; + +mod common; + +#[derive(Clone)] +pub struct SimpleProxyService { + upstream: SocketAddr, + client: ProxyClient, +} + +impl SimpleProxyService { + pub fn new(upstream: SocketAddr) -> Self { + let client = ProxyClient::new(); + Self { upstream, client } + } +} + +impl Service> for SimpleProxyService { + type Response = Response; + type Error = Box; + type Future = Pin< + Box< + dyn Future< + Output = Result, Box>, + > + Send, + >, + >; + + fn call(&self, request: Request) -> Self::Future { + let mut request = MutableRequest::from_request(request); + request.set_upstream_address(self.upstream); + let request = request.into_request_with_simple_body(); + let client = self.client.clone(); + + Box::pin(async move { + let (res, upgrade_handler) = client.request(request).await.unwrap(); + + let upgrade_handler = upgrade_handler.unwrap(); + tokio::spawn(async move { + upgrade_handler.run().await.unwrap(); + }); + + Ok(res) + }) + } +} + +#[tokio::test] +async fn test_websocket_echo() { + // Start the WebSocket echo server + let server = WebSocketEchoServer::new().await; + let server_addr = server.addr(); + + // Start the proxy + let proxy_service = SimpleProxyService::new(server_addr); + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Failed to bind listener"); + let proxy_addr = listener.local_addr().expect("Failed to get proxy address"); + let _server = SimpleHttpServer::new(proxy_service, listener, HttpsConfig::Http).unwrap(); + + // Connect to the WebSocket server + let url = format!("ws://{}/ws", proxy_addr); + let (mut ws_stream, _) = connect_async(&url).await.expect("Failed to connect"); + + // Send a message + let message = "Hello, WebSocket!"; + ws_stream + .send(Message::Text(message.to_string())) + .await + .expect("Failed to send message"); + + // Receive the echoed message + if let Some(Ok(msg)) = ws_stream.next().await { + match msg { + Message::Text(received_text) => { + assert_eq!( + received_text, message, + "Received message doesn't match sent message" + ); + } + _ => panic!("Unexpected message type received"), + } + } else { + panic!("Failed to receive message"); + } + + // Close the connection + ws_stream + .close(None) + .await + .expect("Failed to close WebSocket"); +} diff --git a/dynamic-proxy/tests/test_upgrade.rs b/dynamic-proxy/tests/test_upgrade.rs new file mode 100644 index 000000000..fd874af30 --- /dev/null +++ b/dynamic-proxy/tests/test_upgrade.rs @@ -0,0 +1,71 @@ +use bytes::Bytes; +use common::simple_upgrade_service::SimpleUpgradeService; +use dynamic_proxy::server::{HttpsConfig, SimpleHttpServer}; +use http_body_util::Empty; +use hyper::{ + header::{HeaderValue, CONNECTION, UPGRADE}, + Request, StatusCode, +}; +use hyper_util::rt::TokioIo; +use std::net::SocketAddr; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::{TcpListener, TcpStream}; + +mod common; + +#[tokio::test] +async fn test_upgrade() { + let service = SimpleUpgradeService; + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let listener = TcpListener::bind(addr).await.unwrap(); + let addr = listener.local_addr().unwrap(); + let _server = SimpleHttpServer::new(service, listener, HttpsConfig::Http).unwrap(); + + let url = format!("http://{}", addr); + + let req = Request::builder() + .uri(url) + .header(UPGRADE, "websocket") + .body(Empty::::new()) + .unwrap(); + + let stream = TcpStream::connect(addr).await.unwrap(); + let io = TokioIo::new(stream); + + let (mut sender, conn) = hyper::client::conn::http1::handshake(io).await.unwrap(); + + let handle = tokio::task::spawn(async move { + // conn.with_upgrades() will block until sender.send_request() is called. + // It's not clear to me why, but the example to run it in its own task + // comes from this example: + // https://github.com/hyperium/hyper/blob/master/examples/upgrades.rs + if let Err(err) = conn.with_upgrades().await { + Err(anyhow::anyhow!("Connection failed: {:?}", err)) + } else { + Ok(()) + } + }); + + let res = sender.send_request(req).await.unwrap(); + handle.await.unwrap().unwrap(); + + assert_eq!(res.status(), StatusCode::SWITCHING_PROTOCOLS); + assert_eq!( + res.headers().get(UPGRADE).unwrap(), + &HeaderValue::from_static("websocket") + ); + assert_eq!( + res.headers().get(CONNECTION).unwrap(), + &HeaderValue::from_static("upgrade") + ); + + let upgraded = hyper::upgrade::on(res).await.unwrap(); + let mut upgraded = TokioIo::new(upgraded); + upgraded.write_all(b"Hello from the client!").await.unwrap(); + + let mut buf = vec![0; 1024]; + let n = upgraded.read(&mut buf).await.unwrap(); + assert_eq!(&buf[..n], b"Hello from the client!"); + + upgraded.flush().await.unwrap(); +} diff --git a/plane/Cargo.toml b/plane/Cargo.toml index 4451f07ca..9bbdad6df 100644 --- a/plane/Cargo.toml +++ b/plane/Cargo.toml @@ -16,11 +16,13 @@ async-stream = "0.3.5" async-trait = "0.1.74" axum = { version = "0.6.20", features = ["ws"] } bollard = "0.17.0" +bytes = "1.7.2" chrono = { version = "0.4.31", features = ["serde"] } clap = { version = "4.4.10", features = ["derive"] } colored = "2.0.4" dashmap = "5.5.3" data-encoding = "2.4.0" +dynamic-proxy = { path="../dynamic-proxy" } futures-util = "0.3.29" http-body = "0.4.6" hyper = { version = "0.14.27", features = ["server"] } @@ -29,7 +31,6 @@ openssl = "0.10.66" pem = "3.0.2" rand = "0.8.5" reqwest = { version = "0.11.22", features = ["json", "rustls-tls"], default-features = false } -ring = "0.17.5" rusqlite = { version = "0.31.0", features = ["bundled", "serde_json"] } rustls-pemfile = "2.0.0" rustls-pki-types = "1.0.0" @@ -40,7 +41,6 @@ sqlx = { version = "0.8.0", features = ["runtime-tokio", "tls-rustls", "postgres thiserror = "1.0.50" time = "0.3.30" tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread", "signal"] } -tokio-rustls = "0.24.1" tokio-stream = { version="0.1.14", features=["sync"] } tokio-tungstenite = { version = "0.20.1", features = ["rustls-tls-webpki-roots"] } tower = "0.4.13" diff --git a/plane/plane-tests/Cargo.toml b/plane/plane-tests/Cargo.toml index d9dfa1dac..a6dd0b68e 100644 --- a/plane/plane-tests/Cargo.toml +++ b/plane/plane-tests/Cargo.toml @@ -6,17 +6,22 @@ edition = "2021" [dependencies] anyhow = "1.0.75" async-trait = "0.1.74" -axum = "0.7.5" +axum = { version = "0.7.5", features = ["ws"] } bollard = "0.17.0" chrono = { version = "0.4.31", features = ["serde"] } +dynamic-proxy = { path = "../../dynamic-proxy" } futures-util = "0.3.29" +http = "1.1.0" +http-body-util = "0.1.2" hyper = { version = "0.14.27", features = ["server"] } plane = { path = "../plane-dynamic", package = "plane-dynamic" } plane-test-macro = { path = "plane-test-macro" } reqwest = { version = "0.11.22", features = ["json", "rustls-tls"], default-features = false } +serde = "1.0.210" serde_json = "1.0.107" thiserror = "1.0.50" tokio = { version = "1.33.0", features = ["macros", "rt-multi-thread", "signal"] } +tokio-tungstenite = "0.24.0" tracing = "0.1.40" tracing-appender = "0.2.2" tracing-subscriber = { version = "0.3.17", features = ["env-filter"] } diff --git a/plane/plane-tests/tests/cert_manager.rs b/plane/plane-tests/tests/cert_manager.rs index 36d91e58f..fdc1e60ae 100644 --- a/plane/plane-tests/tests/cert_manager.rs +++ b/plane/plane-tests/tests/cert_manager.rs @@ -1,10 +1,12 @@ +use std::sync::Arc; + use crate::common::timeout::WithTimeout; use common::test_env::TestEnvironment; use plane::{ names::{Name, ProxyName}, proxy::{ - cert_manager::watcher_manager_pair, proxy_connection::ProxyConnection, AcmeConfig, - AcmeEabConfiguration, + cert_manager::watcher_manager_pair, proxy_connection::ProxyConnection, + proxy_server::ProxyState, AcmeConfig, AcmeEabConfiguration, }, }; use plane_test_macro::plane_test; @@ -39,11 +41,14 @@ async fn cert_manager_does_refresh(env: TestEnvironment) { .await .unwrap(); + let state = Arc::new(ProxyState::new(None)); + let _proxy_connection = ProxyConnection::new( ProxyName::new_random(), controller.client(), env.cluster.clone(), cert_manager, + state.clone(), ); cert_watcher .wait_for_initial_cert() @@ -86,11 +91,14 @@ async fn cert_manager_does_refresh_eab(env: TestEnvironment) { .await .unwrap(); + let state = Arc::new(ProxyState::new(None)); + let _proxy_connection = ProxyConnection::new( ProxyName::new_random(), controller.client(), env.cluster.clone(), cert_manager, + state.clone(), ); cert_watcher .wait_for_initial_cert() diff --git a/plane/plane-tests/tests/common/localhost_resolver.rs b/plane/plane-tests/tests/common/localhost_resolver.rs new file mode 100644 index 000000000..4348c1b3b --- /dev/null +++ b/plane/plane-tests/tests/common/localhost_resolver.rs @@ -0,0 +1,22 @@ +use hyper::client::connect::dns::Name; +use reqwest::dns::{Resolve, Resolving}; +use std::{future::ready, net::SocketAddr, sync::Arc}; + +/// A reqwest-compatible DNS resolver that resolves all requests to localhost. +struct LocalhostResolver; + +impl Resolve for LocalhostResolver { + fn resolve(&self, _name: Name) -> Resolving { + let addrs = vec![SocketAddr::from(([127, 0, 0, 1], 0))]; + let addrs: Box + Send> = Box::new(addrs.into_iter()); + Box::pin(ready(Ok(addrs))) + } +} + +#[allow(unused)] +pub fn localhost_client() -> reqwest::Client { + reqwest::Client::builder() + .dns_resolver(Arc::new(LocalhostResolver)) + .build() + .unwrap() +} diff --git a/plane/plane-tests/tests/common/mod.rs b/plane/plane-tests/tests/common/mod.rs index f452b0b47..8ea45bdce 100644 --- a/plane/plane-tests/tests/common/mod.rs +++ b/plane/plane-tests/tests/common/mod.rs @@ -7,9 +7,13 @@ use tokio::time::timeout; pub mod async_drop; pub mod auth_mock; pub mod docker; +pub mod localhost_resolver; +pub mod proxy_mock; pub mod resources; +pub mod simple_axum_server; // TODO: copied from dynamic-proxy (until we merge them back) pub mod test_env; pub mod timeout; +pub mod websocket_echo_server; // TODO: copied from dynamic-proxy (until we merge them back) pub fn run_test(name: &str, time_limit: Duration, test_function: F) where diff --git a/plane/plane-tests/tests/common/proxy_mock.rs b/plane/plane-tests/tests/common/proxy_mock.rs new file mode 100644 index 000000000..72a7c8ddc --- /dev/null +++ b/plane/plane-tests/tests/common/proxy_mock.rs @@ -0,0 +1,88 @@ +use dynamic_proxy::server::{HttpsConfig, SimpleHttpServer}; +use plane::{ + names::BackendName, + protocol::{RouteInfoRequest, RouteInfoResponse}, + proxy::{connection_monitor::BackendEntry, proxy_server::ProxyState}, +}; +use std::net::SocketAddr; +use tokio::{net::TcpListener, sync::mpsc}; + +pub struct MockProxy { + proxy_state: ProxyState, + route_info_request_receiver: mpsc::Receiver, + addr: SocketAddr, + _server: SimpleHttpServer, +} + +#[allow(unused)] +impl MockProxy { + pub async fn new() -> Self { + Self::new_inner(None).await + } + + pub async fn new_with_root_redirect(root_redirect_url: String) -> Self { + Self::new_inner(Some(root_redirect_url)).await + } + + pub fn backend_entry(&self, backend_id: &BackendName) -> Option { + self.proxy_state.inner.monitor.get_backend_entry(backend_id) + } + + async fn new_inner(root_redirect_url: Option) -> Self { + let proxy_state = ProxyState::new(root_redirect_url); + let (route_info_request_sender, route_info_request_receiver) = mpsc::channel(1); + + proxy_state.inner.route_map.set_sender(move |m| { + route_info_request_sender + .try_send(m) + .expect("Failed to send route info request"); + }); + + let listener = TcpListener::bind("127.0.0.1:0") + .await + .expect("Failed to bind listener"); + let addr = listener.local_addr().expect("Failed to get local address"); + + let server = SimpleHttpServer::new(proxy_state.clone(), listener, HttpsConfig::http()) + .expect("Failed to create server"); + + Self { + proxy_state, + route_info_request_receiver, + addr, + _server: server, + } + } + + pub fn addr(&self) -> SocketAddr { + self.addr + } + + pub fn port(&self) -> u16 { + self.addr.port() + } + + pub fn set_ready(&self, ready: bool) { + self.proxy_state.set_ready(ready); + } + + pub async fn recv_route_info_request(&mut self) -> RouteInfoRequest { + self.route_info_request_receiver + .recv() + .await + .expect("Failed to receive route info request") + } + + pub async fn expect_no_route_info_request(&mut self) { + tokio::time::sleep(tokio::time::Duration::from_millis(500)).await; + assert!( + self.route_info_request_receiver.is_empty(), + "Expected no route info request, but got: {}", + self.route_info_request_receiver.len() + ); + } + + pub async fn send_route_info_response(&mut self, response: RouteInfoResponse) { + self.proxy_state.inner.route_map.receive(response); + } +} diff --git a/plane/plane-tests/tests/common/resources/pebble.rs b/plane/plane-tests/tests/common/resources/pebble.rs index 0bed017c2..908d07673 100644 --- a/plane/plane-tests/tests/common/resources/pebble.rs +++ b/plane/plane-tests/tests/common/resources/pebble.rs @@ -8,7 +8,6 @@ use plane::proxy::AcmeEabConfiguration; use reqwest::Client; use serde_json::json; use std::os::unix::fs::PermissionsExt; -use std::path::Path; use std::time::{Duration, SystemTime}; use url::Url; @@ -96,12 +95,12 @@ impl Pebble { ) -> Result { let scratch_dir = env.scratch_dir.clone(); - #[cfg(target_os = "macos")] - avoid_weird_mac_bug(&env.run_name, &scratch_dir).await?; - let pebble_dir = scratch_dir.canonicalize()?.join("pebble"); std::fs::create_dir_all(&pebble_dir)?; + #[cfg(target_os = "macos")] + avoid_weird_mac_bug(&env.run_name, &scratch_dir).await?; + let mut pebble_config = json!({ "pebble": { "listenAddress": "0.0.0.0:14000", @@ -191,7 +190,7 @@ impl Pebble { /// the scratch directory itself (i.e. the parent of the pebble directory) seems to prevent this /// from happening. #[cfg(target_os = "macos")] -pub async fn avoid_weird_mac_bug(name: &str, scratch_dir: &Path) -> Result<()> { +pub async fn avoid_weird_mac_bug(name: &str, scratch_dir: &std::path::Path) -> Result<()> { println!( "Creating dummy container for macos {}", scratch_dir.to_str().unwrap() diff --git a/plane/plane-tests/tests/common/simple_axum_server.rs b/plane/plane-tests/tests/common/simple_axum_server.rs new file mode 100644 index 000000000..999e0ca74 --- /dev/null +++ b/plane/plane-tests/tests/common/simple_axum_server.rs @@ -0,0 +1,76 @@ +use axum::{body::Body, extract::Request, routing::any, Json, Router}; +use http::Method; +use http_body_util::BodyExt; +use serde::{Deserialize, Serialize}; +use std::{collections::HashMap, net::SocketAddr}; +use tokio::net::TcpListener; + +pub struct SimpleAxumServer { + handle: tokio::task::JoinHandle<()>, + addr: SocketAddr, +} + +#[allow(unused)] +impl SimpleAxumServer { + pub async fn new() -> Self { + let app = Router::new() + .route("/*path", any(return_request_info)) + .route("/", any(return_request_info)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let tcp_listener = TcpListener::bind(addr).await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + let handle = tokio::spawn(async { + axum::serve(tcp_listener, app.into_make_service()) + .await + .unwrap(); + }); + + Self { handle, addr } + } + + pub fn addr(&self) -> SocketAddr { + self.addr + } +} + +impl Drop for SimpleAxumServer { + fn drop(&mut self) { + self.handle.abort(); + } +} + +// Handler function for the root route +async fn return_request_info(method: Method, request: Request) -> Json { + let method = method.to_string(); + + let path = request.uri().path().to_string(); + let query = request.uri().query().unwrap_or("").to_string(); + + let headers: HashMap = request + .headers() + .iter() + .map(|(k, v)| (k.to_string(), v.to_str().unwrap().to_string())) + .collect(); + + let body = request.into_body().collect().await.unwrap().to_bytes(); + let body = String::from_utf8(body.to_vec()).unwrap(); + + Json(RequestInfo { + path, + query, + method, + headers, + body, + }) +} + +#[derive(Serialize, Deserialize)] +pub struct RequestInfo { + pub path: String, + pub query: String, + pub method: String, + pub headers: HashMap, + pub body: String, +} diff --git a/plane/plane-tests/tests/common/test_env.rs b/plane/plane-tests/tests/common/test_env.rs index 74c800bf7..6b44d95f6 100644 --- a/plane/plane-tests/tests/common/test_env.rs +++ b/plane/plane-tests/tests/common/test_env.rs @@ -3,7 +3,9 @@ use super::{ resources::{database::DevDatabase, pebble::Pebble}, }; use chrono::Duration; +use dynamic_proxy::server::{HttpsConfig, SimpleHttpServer}; use plane::{ + client::PlaneClient, controller::ControllerServer, database::PlaneDatabase, dns::run_dns_with_listener, @@ -14,17 +16,21 @@ use plane::{ }, Drone, DroneConfig, ExecutorConfig, }, - names::{AcmeDnsServerName, ControllerName, DroneName, Name}, - proxy::AcmeEabConfiguration, + names::{AcmeDnsServerName, ControllerName, DroneName, Name, ProxyName}, + proxy::{ + cert_manager::watcher_manager_pair, proxy_connection::ProxyConnection, + proxy_server::ProxyState, AcmeEabConfiguration, + }, typed_unix_socket::{server::TypedUnixSocketServer, WrappedMessage}, types::{ClusterName, DronePoolName}, util::random_string, }; use std::{ - net::{IpAddr, Ipv4Addr}, + net::{IpAddr, Ipv4Addr, SocketAddr}, path::{Path, PathBuf}, sync::{Arc, Mutex}, }; +use tokio::net::TcpListener; use tokio::sync::broadcast::Receiver; use tracing::subscriber::DefaultGuard; use tracing_appender::non_blocking::WorkerGuard; @@ -118,6 +124,43 @@ impl TestEnvironment { .expect("Unable to construct controller.") } + pub async fn proxy( + &mut self, + controller: &ControllerServer, + ) -> Result> { + let cluster: ClusterName = "localhost:9090".parse().unwrap(); + + let client = PlaneClient::new(controller.url().clone()); + + let state = Arc::new(ProxyState::new(Some("https://plane.test".to_string()))); + + let (_, cert_manager) = watcher_manager_pair(cluster.clone(), None, None) + .await + .unwrap(); + + let proxy_connection = ProxyConnection::new( + ProxyName::new_random(), + client, + cluster, + cert_manager, + state.clone(), + ); + + let addr: SocketAddr = ([0, 0, 0, 0], 0).into(); + tracing::info!(%addr, "Listening for HTTP connections."); + let tcp_listener = TcpListener::bind(addr).await.unwrap(); + let port = tcp_listener.local_addr().unwrap().port(); + + // Spawn the server on a separate task + let server = SimpleHttpServer::new(state, tcp_listener, HttpsConfig::Http)?; + + Ok(Proxy { + port, + _server: server, + _connection: proxy_connection, + }) + } + pub async fn controller_with_forward_auth(&mut self, forward_auth: &Url) -> ControllerServer { let db = self.db().await; let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); @@ -258,6 +301,13 @@ impl TestEnvironment { } } +pub struct Proxy { + #[allow(dead_code)] // Used in tests. + pub port: u16, + _server: SimpleHttpServer, + _connection: ProxyConnection, +} + #[allow(dead_code)] // Used in tests. pub struct DroneWithSocket { pub socket_server: TypedUnixSocketServer, diff --git a/plane/plane-tests/tests/common/websocket_echo_server.rs b/plane/plane-tests/tests/common/websocket_echo_server.rs new file mode 100644 index 000000000..a406eb766 --- /dev/null +++ b/plane/plane-tests/tests/common/websocket_echo_server.rs @@ -0,0 +1,64 @@ +use axum::{ + extract::ws::{Message, WebSocket, WebSocketUpgrade}, + response::IntoResponse, + routing::get, + Router, +}; +use std::net::SocketAddr; +use tokio::net::TcpListener; + +/// A websocket echo server that echos messages back to the client. +pub struct WebSocketEchoServer { + handle: tokio::task::JoinHandle<()>, + addr: SocketAddr, +} + +#[allow(unused)] +impl WebSocketEchoServer { + pub async fn new() -> Self { + let app = Router::new().route("/ws", get(ws_handler)); + + let addr = SocketAddr::from(([127, 0, 0, 1], 0)); + let tcp_listener = TcpListener::bind(addr).await.unwrap(); + let addr = tcp_listener.local_addr().unwrap(); + + let handle = tokio::spawn(async move { + axum::serve(tcp_listener, app.into_make_service()) + .await + .unwrap(); + }); + + Self { handle, addr } + } + + pub fn addr(&self) -> SocketAddr { + self.addr + } +} + +impl Drop for WebSocketEchoServer { + fn drop(&mut self) { + self.handle.abort(); + } +} + +async fn ws_handler(ws: WebSocketUpgrade) -> impl IntoResponse { + ws.on_upgrade(handle_socket) +} + +async fn handle_socket(mut socket: WebSocket) { + while let Some(msg) = socket.recv().await { + match msg { + Ok(msg) => { + if let Message::Text(text) = msg { + if socket.send(Message::Text(text)).await.is_err() { + panic!("WebSocket connection closed."); + } + } + } + Err(e) => { + panic!("Error receiving message: {}", e); + } + } + } +} diff --git a/plane/plane-tests/tests/proxy.rs b/plane/plane-tests/tests/proxy.rs new file mode 100644 index 000000000..58b3d71c0 --- /dev/null +++ b/plane/plane-tests/tests/proxy.rs @@ -0,0 +1,383 @@ +use common::{ + localhost_resolver::localhost_client, + proxy_mock::MockProxy, + simple_axum_server::{RequestInfo, SimpleAxumServer}, + test_env::TestEnvironment, +}; +use plane::{ + log_types::BackendAddr, + names::{BackendName, Name}, + protocol::{RouteInfo, RouteInfoResponse}, + types::{BearerToken, ClusterName, SecretToken, Subdomain}, +}; +use plane_test_macro::plane_test; +use reqwest::StatusCode; +use std::{net::SocketAddr, str::FromStr}; +use tokio::net::TcpListener; + +mod common; + +#[plane_test] +async fn proxy_root_no_redirect(env: TestEnvironment) { + let mut proxy = MockProxy::new().await; + let url = format!("http://{}", proxy.addr()); + let handle = tokio::spawn(async { reqwest::get(url).await.expect("Failed to send request") }); + + proxy.expect_no_route_info_request().await; + + let response = handle.await.unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_REQUEST); + assert!(response.headers().get("location").is_none()); +} + +#[plane_test] +async fn proxy_root_redirect(env: TestEnvironment) { + let proxy = MockProxy::new_with_root_redirect("https://plane.test/".to_string()).await; + let url = format!("http://{}", proxy.addr()); + + let client = reqwest::Client::builder() + .redirect(reqwest::redirect::Policy::none()) + .build() + .unwrap(); + + let response = client.get(url).send().await.unwrap(); + + assert_eq!(response.status(), StatusCode::MOVED_PERMANENTLY); + assert_eq!( + response.headers().get("location").unwrap(), + "https://plane.test/" + ); +} + +#[plane_test] +async fn proxy_bad_bearer_token(env: TestEnvironment) { + let mut proxy = MockProxy::new().await; + let url = format!("http://{}/abc123/", proxy.addr()); + let handle = tokio::spawn(async { reqwest::get(url).await.expect("Failed to send request") }); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: None, + }) + .await; + + let response = handle.await.unwrap(); + + assert_eq!(response.status(), StatusCode::GONE); +} + +#[plane_test] +async fn proxy_backend_unreachable(env: TestEnvironment) { + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(SocketAddr::from(([123, 234, 123, 234], 12345))), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + + assert_eq!(response.status(), StatusCode::BAD_GATEWAY); +} + +#[plane_test] +async fn proxy_backend_timeout(env: TestEnvironment) { + // We will start a listener, but never respond on it, to simulate a timeout. + let listener = TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], 0))) + .await + .unwrap(); + let addr = listener.local_addr().unwrap(); + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(addr), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + + assert_eq!(response.status(), StatusCode::GATEWAY_TIMEOUT); +} + +#[plane_test] +async fn proxy_backend_accepts(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let request_info: RequestInfo = response.json().await.unwrap(); + assert_eq!(request_info.path, "/"); + assert_eq!(request_info.method, "GET"); +} + +#[plane_test] +async fn proxy_static_token(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/s.abc123/foobar"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("s.abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("s.abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let request_info: RequestInfo = response.json().await.unwrap(); + assert_eq!(request_info.path, "/s.abc123/foobar"); // With static tokens, we pass along the original path. + assert_eq!(request_info.method, "GET"); +} + +#[plane_test] +async fn proxy_expected_subdomain_not_present(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: Some(Subdomain::from_str("missing-subdomain").unwrap()), + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::FORBIDDEN); +} + +#[plane_test] +async fn proxy_expected_subdomain_is_present(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://mysubdomain.plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: Some(Subdomain::from_str("mysubdomain").unwrap()), + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); +} + +#[plane_test] +async fn proxy_backend_passes_forwarded_headers(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let request_info: RequestInfo = response.json().await.unwrap(); + let headers = request_info.headers; + assert_eq!(headers.get("x-forwarded-for").unwrap(), "127.0.0.1"); + assert_eq!(headers.get("x-forwarded-proto").unwrap(), "http"); +} + +#[plane_test] +async fn proxy_returns_backend_id_in_header(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let backend_id = BackendName::new_random(); + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: backend_id.clone(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let headers = response.headers(); + assert_eq!( + headers.get("x-plane-backend-id").unwrap().to_str().unwrap(), + &backend_id.to_string() + ); +} diff --git a/plane/plane-tests/tests/proxy_connection_monitor.rs b/plane/plane-tests/tests/proxy_connection_monitor.rs new file mode 100644 index 000000000..fb659da69 --- /dev/null +++ b/plane/plane-tests/tests/proxy_connection_monitor.rs @@ -0,0 +1,114 @@ +use common::{ + localhost_resolver::localhost_client, proxy_mock::MockProxy, + simple_axum_server::SimpleAxumServer, test_env::TestEnvironment, + websocket_echo_server::WebSocketEchoServer, +}; +use plane::{ + log_types::BackendAddr, + names::{BackendName, Name}, + protocol::{RouteInfo, RouteInfoResponse}, + types::{BearerToken, ClusterName, SecretToken}, +}; +use plane_test_macro::plane_test; +use reqwest::StatusCode; +use std::str::FromStr; +use tokio_tungstenite::connect_async; + +mod common; + +#[plane_test] +async fn proxy_marks_backend_as_recently_active(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + let backend_name = BackendName::new_random(); + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + + let backend_entry = proxy.backend_entry(&backend_name); + assert!(backend_entry.is_none()); + + let handle = tokio::spawn(client.get(url).send()); + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + println!("received route info request"); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: backend_name.clone(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + + let Some(backend_entry) = proxy.backend_entry(&backend_name) else { + panic!("Backend entry not found"); + }; + assert_eq!(backend_entry.active_connections, 0); + assert!(backend_entry.had_recent_connection); +} + +#[plane_test] +async fn proxy_marks_websocket_backend_as_active(env: TestEnvironment) { + let server = WebSocketEchoServer::new().await; + let backend_name = BackendName::new_random(); + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("localhost:{}", port)).unwrap(); + let url = format!("ws://localhost:{port}/abc123/ws"); + + let handle = tokio::spawn(connect_async(url)); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: backend_name.clone(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let (mut ws_stream, _) = handle.await.unwrap().unwrap(); + + let Some(backend_entry) = proxy.backend_entry(&backend_name) else { + panic!("Backend entry not found"); + }; + assert_eq!(backend_entry.active_connections, 1); + assert!(backend_entry.had_recent_connection); + + ws_stream.close(None).await.unwrap(); + drop(ws_stream); + + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + let backend_entry = proxy.backend_entry(&backend_name).unwrap(); + assert_eq!(backend_entry.active_connections, 0); + assert!(backend_entry.had_recent_connection); +} diff --git a/plane/plane-tests/tests/proxy_cors.rs b/plane/plane-tests/tests/proxy_cors.rs new file mode 100644 index 000000000..f4a5d5ff9 --- /dev/null +++ b/plane/plane-tests/tests/proxy_cors.rs @@ -0,0 +1,131 @@ +use common::{ + localhost_resolver::localhost_client, proxy_mock::MockProxy, + simple_axum_server::SimpleAxumServer, test_env::TestEnvironment, +}; +use plane::{ + log_types::BackendAddr, + names::{BackendName, Name}, + protocol::{RouteInfo, RouteInfoResponse}, + types::{BearerToken, ClusterName, SecretToken}, +}; +use plane_test_macro::plane_test; +use reqwest::StatusCode; +use std::str::FromStr; + +mod common; + +#[plane_test] +async fn proxy_gone_request_has_cors_headers(env: TestEnvironment) { + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: None, + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::GONE); + assert_eq!( + response + .headers() + .get("access-control-allow-origin") + .unwrap() + .to_str() + .unwrap(), + "*" + ); + assert_eq!( + response + .headers() + .get("access-control-allow-methods") + .unwrap() + .to_str() + .unwrap(), + "*" + ); + assert_eq!( + response + .headers() + .get("access-control-allow-headers") + .unwrap() + .to_str() + .unwrap(), + "*" + ); +} + +#[plane_test] +async fn proxy_valid_request_has_cors_headers(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + let mut proxy = MockProxy::new().await; + let cluster = ClusterName::from_str(&format!("plane.test:{}", proxy.port())).unwrap(); + let port = proxy.port(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert_eq!( + response + .headers() + .get("access-control-allow-origin") + .unwrap() + .to_str() + .unwrap(), + "*" + ); + assert_eq!( + response + .headers() + .get("access-control-allow-methods") + .unwrap() + .to_str() + .unwrap(), + "*" + ); + assert_eq!( + response + .headers() + .get("access-control-allow-headers") + .unwrap() + .to_str() + .unwrap(), + "*" + ); +} diff --git a/plane/plane-tests/tests/proxy_headers.rs b/plane/plane-tests/tests/proxy_headers.rs new file mode 100644 index 000000000..e2e32c226 --- /dev/null +++ b/plane/plane-tests/tests/proxy_headers.rs @@ -0,0 +1,134 @@ +use common::{ + localhost_resolver::localhost_client, + proxy_mock::MockProxy, + simple_axum_server::{RequestInfo, SimpleAxumServer}, + test_env::TestEnvironment, +}; +use plane::{ + log_types::BackendAddr, + names::{BackendName, Name}, + protocol::{RouteInfo, RouteInfoResponse}, + types::{BearerToken, ClusterName, SecretToken}, +}; +use plane_test_macro::plane_test; +use reqwest::StatusCode; +use serde_json::json; +use std::str::FromStr; + +mod common; + +#[plane_test] +async fn proxy_fake_verified_headers_are_stripped(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + + let handle = tokio::spawn( + client + .get(url) + .header("x-verified-blah", "foobar") // this header should be removed + .header("this-header-is-ok", "blah") // this header should be preserved + .send(), + ); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let request_info: RequestInfo = response.json().await.unwrap(); + assert_eq!(request_info.path, "/"); + assert_eq!(request_info.method, "GET"); + + assert_eq!(request_info.headers.get("x-verified-blah"), None); + assert_eq!( + request_info.headers.get("this-header-is-ok"), + Some(&"blah".to_string()) + ); +} + +#[plane_test] +async fn proxy_plane_headers_are_set(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/a/b/c/"); + let client = localhost_client(); + + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::try_from("backend123".to_string()).unwrap(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret987".to_string()), + cluster, + user: Some("auser123".to_string()), + user_data: Some(json!({ + "access": "readonly", + "email": "a@example.com", + })), + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + let request_info: RequestInfo = response.json().await.unwrap(); + assert_eq!(request_info.path, "/a/b/c/"); + assert_eq!(request_info.method, "GET"); + + assert_eq!( + request_info.headers.get("x-verified-username"), + Some(&"auser123".to_string()) + ); + assert_eq!( + request_info.headers.get("x-verified-user-data"), + Some(&r#"{"access":"readonly","email":"a@example.com"}"#.to_string()) + ); + assert_eq!( + request_info.headers.get("x-verified-secret"), + Some(&"secret987".to_string()) + ); + assert_eq!( + request_info.headers.get("x-verified-path"), + Some(&"/abc123/a/b/c/".to_string()) + ); + assert_eq!( + request_info.headers.get("x-verified-backend"), + Some(&"backend123".to_string()) + ); +} diff --git a/plane/plane-tests/tests/proxy_ready.rs b/plane/plane-tests/tests/proxy_ready.rs new file mode 100644 index 000000000..145eb5151 --- /dev/null +++ b/plane/plane-tests/tests/proxy_ready.rs @@ -0,0 +1,46 @@ +use common::{proxy_mock::MockProxy, test_env::TestEnvironment}; +use plane_test_macro::plane_test; +use reqwest::StatusCode; + +mod common; + +#[plane_test] +async fn proxy_not_ready(env: TestEnvironment) { + let proxy = MockProxy::new().await; + let url = format!("http://{}/ready", proxy.addr()); + let response = reqwest::get(url).await.expect("Failed to send request"); + + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); +} + +#[plane_test] +async fn proxy_ready(env: TestEnvironment) { + let proxy = MockProxy::new().await; + proxy.set_ready(true); + let url = format!("http://{}/ready", proxy.addr()); + let response = reqwest::get(url).await.expect("Failed to send request"); + + assert_eq!(response.status(), StatusCode::OK); +} + +/// Tests that the proxy becomes ready when it connects to the controller. +/// It's surprisingly hard to test that the proxy becomes non-ready when the +/// controller shuts down, because we don't actually drop the connection +/// task (we rely on the process exiting to do that). This is kind of a bug, +/// at least for the purpose of testing. +#[plane_test] +async fn proxy_becomes_ready(env: TestEnvironment) { + let controller = env.controller().await; + let proxy = env.proxy(&controller).await.unwrap(); + + let url = format!("http://127.0.0.1:{}/ready", proxy.port); + // NB. This is a race, we can remove it if this test starts to flake. + let response = reqwest::get(&url).await.expect("Failed to send request"); + assert_eq!(response.status(), StatusCode::SERVICE_UNAVAILABLE); + + // Wait for the proxy to become ready. + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + + let response = reqwest::get(&url).await.expect("Failed to send request"); + assert_eq!(response.status(), StatusCode::OK); +} diff --git a/plane/plane-tests/tests/proxy_server_header.rs b/plane/plane-tests/tests/proxy_server_header.rs new file mode 100644 index 000000000..00c22567e --- /dev/null +++ b/plane/plane-tests/tests/proxy_server_header.rs @@ -0,0 +1,83 @@ +use common::{ + localhost_resolver::localhost_client, proxy_mock::MockProxy, + simple_axum_server::SimpleAxumServer, test_env::TestEnvironment, +}; +use plane::{ + log_types::BackendAddr, + names::{BackendName, Name}, + protocol::{RouteInfo, RouteInfoResponse}, + types::{BearerToken, ClusterName, SecretToken}, +}; +use plane_test_macro::plane_test; +use reqwest::StatusCode; +use std::str::FromStr; + +mod common; + +#[plane_test] +async fn proxy_error_response_includes_server_header(env: TestEnvironment) { + let mut proxy = MockProxy::new().await; + let url = format!("http://{}/abc/", proxy.addr()); + let handle = tokio::spawn(async { reqwest::get(url).await.expect("Failed to send request") }); + + let _ = proxy.recv_route_info_request().await; + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc".to_string()), + route_info: None, + }) + .await; + + let response = handle.await.unwrap(); + assert_eq!(response.status(), StatusCode::GONE); + assert!(response + .headers() + .get("server") + .unwrap() + .to_str() + .unwrap() + .starts_with("Plane/"),); +} + +#[plane_test] +async fn proxy_valid_response_includes_server_header(env: TestEnvironment) { + let server = SimpleAxumServer::new().await; + + let mut proxy = MockProxy::new().await; + let port = proxy.port(); + let cluster = ClusterName::from_str(&format!("plane.test:{}", port)).unwrap(); + let url = format!("http://plane.test:{port}/abc123/"); + let client = localhost_client(); + let handle = tokio::spawn(client.get(url).send()); + + let route_info_request = proxy.recv_route_info_request().await; + assert_eq!( + route_info_request.token, + BearerToken::from("abc123".to_string()) + ); + + proxy + .send_route_info_response(RouteInfoResponse { + token: BearerToken::from("abc123".to_string()), + route_info: Some(RouteInfo { + backend_id: BackendName::new_random(), + address: BackendAddr(server.addr()), + secret_token: SecretToken::from("secret".to_string()), + cluster, + user: None, + user_data: None, + subdomain: None, + }), + }) + .await; + + let response = handle.await.unwrap().unwrap(); + assert_eq!(response.status(), StatusCode::OK); + assert!(response + .headers() + .get("server") + .unwrap() + .to_str() + .unwrap() + .starts_with("Plane/"),); +} diff --git a/plane/src/client/sse.rs b/plane/src/client/sse.rs index 81b97d88c..c62636175 100644 --- a/plane/src/client/sse.rs +++ b/plane/src/client/sse.rs @@ -1,7 +1,9 @@ use super::PlaneClientError; use crate::util::ExponentialBackoff; -use hyper::header::{ACCEPT, CONNECTION}; -use reqwest::{Client, Response}; +use reqwest::{ + header::{ACCEPT, CONNECTION}, + Client, Response, +}; use serde::de::DeserializeOwned; use std::marker::PhantomData; use tungstenite::http::HeaderValue; @@ -180,7 +182,7 @@ mod tests { Router, }; use futures_util::stream::Stream; - use hyper::HeaderMap; + use reqwest::header::HeaderMap; use serde::{Deserialize, Serialize}; use std::{convert::Infallible, time::Duration}; use tokio::{sync::broadcast, task::JoinHandle, time::timeout}; @@ -192,7 +194,7 @@ mod tests { struct DemoSseServer { port: u16, - handle: Option>>, + handle: Option>>, disconnect_sender: broadcast::Sender<()>, } @@ -247,7 +249,7 @@ mod tests { let server = axum::Server::from_tcp(listener) .unwrap() .serve(app.into_make_service()); - let handle = tokio::spawn(server); + let handle = tokio::spawn(async move { server.await.map_err(anyhow::Error::new) }); Self { port, diff --git a/plane/src/controller/backend_state.rs b/plane/src/controller/backend_state.rs index ae169836d..3bff18389 100644 --- a/plane/src/controller/backend_state.rs +++ b/plane/src/controller/backend_state.rs @@ -5,6 +5,7 @@ use crate::{ }; use axum::{ extract::{Path, State}, + http::HeaderMap, response::{ sse::{Event, KeepAlive}, Response, Sse, @@ -12,7 +13,6 @@ use axum::{ Json, }; use futures_util::{Stream, StreamExt}; -use hyper::HeaderMap; use std::convert::Infallible; async fn backend_status( diff --git a/plane/src/drone/runtime/unix_socket/mod.rs b/plane/src/drone/runtime/unix_socket/mod.rs index 5b75773f1..f1a1b9b02 100644 --- a/plane/src/drone/runtime/unix_socket/mod.rs +++ b/plane/src/drone/runtime/unix_socket/mod.rs @@ -142,7 +142,8 @@ impl Runtime for UnixSocketRuntime { impl UnixSocketRuntime { pub async fn new(config: UnixSocketRuntimeConfig) -> Result { - let client = TypedUnixSocketClient::new(&config.socket_path).await?; + let client: TypedUnixSocketClient = + TypedUnixSocketClient::new(&config.socket_path).await?; Ok(Self { client }) } } diff --git a/plane/src/proxy/cert_manager.rs b/plane/src/proxy/cert_manager.rs index 3fc733511..b7e74cea5 100644 --- a/plane/src/proxy/cert_manager.rs +++ b/plane/src/proxy/cert_manager.rs @@ -10,6 +10,10 @@ use acme2_eab::{ }; use anyhow::{anyhow, Context, Result}; use chrono::Utc; +use dynamic_proxy::tokio_rustls::rustls::{ + server::{ClientHello, ResolvesServerCert}, + sign::CertifiedKey, +}; use std::{ ops::Sub, path::{Path, PathBuf}, @@ -20,10 +24,6 @@ use tokio::sync::{ broadcast, watch::{Receiver, Sender}, }; -use tokio_rustls::rustls::{ - server::{ClientHello, ResolvesServerCert}, - sign::CertifiedKey, -}; use valuable::Valuable; const DNS_01: &str = "dns-01"; @@ -87,6 +87,12 @@ impl CertWatcher { } } +impl std::fmt::Debug for CertWatcher { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(f, "CertWatcher") + } +} + impl ResolvesServerCert for CertWatcher { fn resolve(&self, _client_hello: ClientHello<'_>) -> Option> { if self diff --git a/plane/src/proxy/cert_pair.rs b/plane/src/proxy/cert_pair.rs index 0066a8d40..bee43ea4b 100644 --- a/plane/src/proxy/cert_pair.rs +++ b/plane/src/proxy/cert_pair.rs @@ -1,13 +1,13 @@ use crate::log_types::LoggableTime; use anyhow::{anyhow, Result}; +use dynamic_proxy::rustls::{ + crypto::aws_lc_rs::sign::any_supported_type, pki_types::PrivateKeyDer, +}; +use dynamic_proxy::tokio_rustls::rustls::sign::CertifiedKey; use pem::Pem; -use rustls_pki_types::{CertificateDer, PrivateKeyDer, PrivatePkcs1KeyDer}; +use rustls_pki_types::CertificateDer; use serde::{Deserialize, Serialize}; use std::{fs::Permissions, io, os::unix::fs::PermissionsExt, path::Path}; -use tokio_rustls::rustls::{ - sign::{any_supported_type, CertifiedKey}, - Certificate, PrivateKey, -}; use x509_parser::{certificate::X509Certificate, oid_registry::asn1_rs::FromDer}; #[derive(Debug, Clone, Serialize, Deserialize)] @@ -68,13 +68,14 @@ impl CertificatePair { let validity_start = validity.not_before.to_datetime(); let validity_end = validity.not_after.to_datetime(); + // Convert the Vec> to a Vec> + // by copying the certificate data. let certs = certs .into_iter() - .map(|cert| Certificate(cert.to_vec())) + .map(|cert| CertificateDer::from(cert.to_vec())) .collect(); - let private_key = PrivateKey(key.secret_der().to_vec()); // NB. rustls 0.22 gets rid of this; the PrivateKeyDer is passed to any_supported_type directly. - let key = any_supported_type(&private_key)?; + let key = any_supported_type(key)?; let certified_key = CertifiedKey::new(certs, key); @@ -93,8 +94,8 @@ impl CertificatePair { .map(|cert_der| CertificateDer::from(cert_der.to_vec())) .collect(); - let key = PrivatePkcs1KeyDer::from(pkey_der.to_vec()); - let key: PrivateKeyDer = key.into(); + let key = PrivateKeyDer::try_from(pkey_der.to_vec()) + .map_err(|e| anyhow!("Error converting private key to der: {}", e))?; Self::new(&key, certs) } diff --git a/plane/src/proxy/connection_monitor.rs b/plane/src/proxy/connection_monitor.rs index 64357270f..83ca54769 100644 --- a/plane/src/proxy/connection_monitor.rs +++ b/plane/src/proxy/connection_monitor.rs @@ -8,13 +8,13 @@ use tokio::task::JoinHandle; type BackendNameListener = Box; -#[derive(Debug)] +#[derive(Debug, Clone)] pub struct BackendEntry { /// The current number of active connections to the backend. - active_connections: u32, + pub active_connections: u32, /// Whether the backend has had a recent connection (since this value was last checked). - had_recent_connection: bool, + pub had_recent_connection: bool, } #[derive(Default)] @@ -149,6 +149,15 @@ impl ConnectionMonitorHandle { Self { monitor, handle } } + pub fn get_backend_entry(&self, backend_id: &BackendName) -> Option { + self.monitor + .lock() + .expect("Monitor lock was poisoned.") + .backends + .get(backend_id) + .cloned() + } + pub fn monitor(&self) -> Arc> { self.monitor.clone() } diff --git a/plane/src/proxy/mod.rs b/plane/src/proxy/mod.rs index 4e2a33cc6..bcd2f5b99 100644 --- a/plane/src/proxy/mod.rs +++ b/plane/src/proxy/mod.rs @@ -1,26 +1,26 @@ use self::proxy_connection::ProxyConnection; use crate::names::ProxyName; use crate::proxy::cert_manager::watcher_manager_pair; -use crate::proxy::proxy_service::ProxyMakeService; -use crate::proxy::shutdown_signal::ShutdownSignal; use crate::{client::PlaneClient, signals::wait_for_shutdown_signal, types::ClusterName}; use anyhow::Result; +use dynamic_proxy::server::{ + ServerWithHttpRedirect, ServerWithHttpRedirectConfig, ServerWithHttpRedirectHttpsConfig, +}; +use proxy_server::ProxyState; use serde::{Deserialize, Serialize}; -use std::net::IpAddr; use std::path::PathBuf; +use std::sync::Arc; +use std::time::Duration; use url::Url; pub mod cert_manager; mod cert_pair; pub mod command; -mod connection_monitor; +pub mod connection_monitor; pub mod proxy_connection; -mod proxy_service; -mod rewriter; +pub mod proxy_server; +mod request; mod route_map; -mod shutdown_signal; -mod subdomain; -mod tls; #[derive(Debug, Clone, Copy)] pub enum Protocol { @@ -37,22 +37,13 @@ impl Protocol { } } -/// Information about the incoming request that is forwarded to the request in -/// X-Forwarded-* headers. -#[derive(Debug, Clone, Copy)] -pub struct ForwardableRequestInfo { - /// The IP address of the client that made the request. - /// Forwarded as X-Forwarded-For. - ip: IpAddr, - - /// The protocol of the incoming request. - /// Forwarded as X-Forwarded-Proto. - protocol: Protocol, -} - #[derive(Debug, Copy, Clone, Serialize, Deserialize)] pub struct ServerPortConfig { + /// The port to listen on for HTTP requests. + /// If https_port is provided, this port will only serve a redirect to HTTPS. pub http_port: u16, + + /// The port to listen on for HTTPS requests. pub https_port: Option, } @@ -119,45 +110,49 @@ pub async fn run_proxy(config: ProxyConfig) -> Result<()> { ) .await?; - let proxy_connection = ProxyConnection::new(config.name, client, config.cluster, cert_manager); - let shutdown_signal = ShutdownSignal::new(); + let state = Arc::new(ProxyState::new( + config.root_redirect_url.map(|u| u.to_string()), + )); - let https_redirect = config.port_config.https_port.is_some(); + // This returns a guard, we need to keep it in scope so that the connection is not terminated. + let _proxy_connection = ProxyConnection::new( + config.name, + client, + config.cluster, + cert_manager, + state.clone(), + ); - if config.port_config.https_port.is_some() { + let server = if let Some(https_port) = config.port_config.https_port { + tracing::info!("Waiting for initial certificate."); cert_watcher.wait_for_initial_cert().await?; - } - let http_handle = ProxyMakeService { - state: proxy_connection.state(), - https_redirect, - root_redirect_url: config.root_redirect_url.clone(), - } - .serve_http(config.port_config.http_port, shutdown_signal.subscribe())?; - - let https_handle = if let Some(https_port) = config.port_config.https_port { - tracing::info!("Waiting for initial certificate."); + let https_config = ServerWithHttpRedirectHttpsConfig { + https_port, + resolver: Arc::new(cert_watcher), + }; - let https_handle = ProxyMakeService { - state: proxy_connection.state(), - https_redirect: false, - root_redirect_url: config.root_redirect_url, - } - .serve_https(https_port, cert_watcher, shutdown_signal.subscribe())?; + let server_config = ServerWithHttpRedirectConfig { + http_port: config.port_config.http_port, + https_config: Some(https_config), + }; - Some(https_handle) + ServerWithHttpRedirect::new(state, server_config).await? } else { - None + let server_config = ServerWithHttpRedirectConfig { + http_port: config.port_config.http_port, + https_config: None, + }; + + ServerWithHttpRedirect::new(state, server_config).await? }; wait_for_shutdown_signal().await; - shutdown_signal.shutdown(); tracing::info!("Shutting down proxy server."); - http_handle.await?; - if let Some(https_handle) = https_handle { - https_handle.await?; - } + server + .graceful_shutdown_with_timeout(Duration::from_secs(10)) + .await; Ok(()) } diff --git a/plane/src/proxy/proxy_connection.rs b/plane/src/proxy/proxy_connection.rs index 18aac7646..4ad653469 100644 --- a/plane/src/proxy/proxy_connection.rs +++ b/plane/src/proxy/proxy_connection.rs @@ -1,4 +1,4 @@ -use super::{cert_manager::CertManager, proxy_service::ProxyState}; +use super::{cert_manager::CertManager, proxy_server::ProxyState}; use crate::{ client::PlaneClient, names::ProxyName, @@ -20,8 +20,9 @@ impl ProxyConnection { client: PlaneClient, cluster: ClusterName, mut cert_manager: CertManager, + state: Arc, ) -> Self { - let state = Arc::new(ProxyState::new()); + tracing::info!("Creating proxy connection"); let handle = { let state = state.clone(); @@ -30,8 +31,9 @@ impl ProxyConnection { let mut proxy_connection = client.proxy_connection(&cluster); loop { + state.set_ready(false); let mut conn = proxy_connection.connect_with_retry(&name).await; - state.set_connected(true); + state.set_ready(true); let sender = conn.sender(MessageFromProxy::CertManagerRequest); cert_manager.set_request_sender(move |m| { @@ -41,13 +43,16 @@ impl ProxyConnection { }); let sender = conn.sender(MessageFromProxy::RouteInfoRequest); - state.route_map.set_sender(move |m: RouteInfoRequest| { - if let Err(e) = sender.send(m) { - tracing::error!(?e, "Error sending route info request."); - } - }); + state + .inner + .route_map + .set_sender(move |m: RouteInfoRequest| { + if let Err(e) = sender.send(m) { + tracing::error!(?e, "Error sending route info request."); + } + }); let sender = conn.sender(MessageFromProxy::KeepAlive); - state.monitor.set_listener(move |backend| { + state.inner.monitor.set_listener(move |backend| { if let Err(err) = sender.send(backend.clone()) { tracing::error!(?err, "Error sending keepalive."); } @@ -56,7 +61,7 @@ impl ProxyConnection { while let Some(message) = conn.recv().await { match message { MessageToProxy::RouteInfoResponse(response) => { - state.route_map.receive(response); + state.inner.route_map.receive(response); } MessageToProxy::CertManagerResponse(response) => { tracing::info!( @@ -66,12 +71,10 @@ impl ProxyConnection { cert_manager.receive(response); } MessageToProxy::BackendRemoved { backend } => { - state.route_map.remove_backend(&backend); + state.inner.route_map.remove_backend(&backend); } } } - - state.set_connected(false); } }) }; diff --git a/plane/src/proxy/proxy_server.rs b/plane/src/proxy/proxy_server.rs new file mode 100644 index 000000000..789e263aa --- /dev/null +++ b/plane/src/proxy/proxy_server.rs @@ -0,0 +1,280 @@ +use super::{ + connection_monitor::ConnectionMonitorHandle, + request::{get_and_maybe_remove_bearer_token, subdomain_from_host}, + route_map::RouteMap, +}; +use crate::{names::Name, protocol::RouteInfo, SERVER_NAME}; +use bytes::Bytes; +use dynamic_proxy::{ + body::{simple_empty_body, SimpleBody}, + hyper::{ + body::{Body, Incoming}, + header::{self, HeaderValue}, + service::Service, + Request, Response, StatusCode, Uri, + }, + proxy::ProxyClient, + request::MutableRequest, +}; +use std::{ + future::{ready, Future}, + sync::atomic::{AtomicBool, Ordering}, +}; +use std::{pin::Pin, sync::Arc}; + +pub struct ProxyStateInner { + pub route_map: RouteMap, + pub proxy_client: ProxyClient, + pub monitor: ConnectionMonitorHandle, + pub connected: AtomicBool, + + /// If set, the "root" path (/) will redirect to this URL. + pub root_redirect_url: Option, +} + +#[derive(Clone)] +pub struct ProxyState { + pub inner: Arc, +} + +impl Default for ProxyState { + fn default() -> Self { + Self::new(None) + } +} + +impl ProxyState { + pub fn new(root_redirect_url: Option) -> Self { + let inner = ProxyStateInner { + route_map: RouteMap::new(), + proxy_client: ProxyClient::new(), + monitor: ConnectionMonitorHandle::new(), + connected: AtomicBool::new(false), + root_redirect_url, + }; + + Self { + inner: Arc::new(inner), + } + } + + pub fn set_ready(&self, ready: bool) { + self.inner.connected.store(ready, Ordering::Relaxed); + } + + pub fn is_ready(&self) -> bool { + self.inner.connected.load(Ordering::Relaxed) + } +} + +impl Service> for ProxyState { + type Response = Response; + type Error = Box; + type Future = Pin< + Box< + dyn Future< + Output = Result, Box>, + > + Send, + >, + >; + + fn call(&self, request: Request) -> Self::Future { + // Handle "/ready" + if request.uri().path() == "/ready" { + if self.is_ready() { + return Box::pin(ready(status_code_to_response(StatusCode::OK))); + } else { + return Box::pin(ready(status_code_to_response( + StatusCode::SERVICE_UNAVAILABLE, + ))); + } + } + + if request.uri().path() == "/" { + if let Some(root_redirect_url) = &self.inner.root_redirect_url { + let mut response = Response::builder() + .status(StatusCode::MOVED_PERMANENTLY) + .header(header::LOCATION, root_redirect_url) + .body(simple_empty_body()) + .expect("Failed to build response"); + + apply_general_headers(&mut response); + + return Box::pin(ready(Ok(response))); + } else { + return Box::pin(ready(status_code_to_response(StatusCode::BAD_REQUEST))); + } + } + + let mut request = MutableRequest::from_request(request); + + // extract the bearer token from the request + let mut uri_parts = request.parts.uri.clone().into_parts(); + let original_path = request.parts.uri.path().to_string(); + let bearer_token = get_and_maybe_remove_bearer_token(&mut uri_parts); + + let Some(bearer_token) = bearer_token else { + // This should have already been handled by the root redirect above. + return Box::pin(ready(status_code_to_response(StatusCode::BAD_REQUEST))); + }; + + let Ok(uri) = Uri::from_parts(uri_parts) else { + return Box::pin(ready(status_code_to_response(StatusCode::BAD_REQUEST))); + }; + request.parts.uri = uri; + + let inner = self.inner.clone(); + + Box::pin(async move { + // look up the route info for the bearer token + let route_info = inner.route_map.lookup(&bearer_token).await; + + let Some(route_info) = route_info else { + return status_code_to_response(StatusCode::GONE); + }; + + if let Err(status_code) = prepare_request(&mut request, &route_info, &original_path) { + return status_code_to_response(status_code); + } + + let request = request.into_request_with_simple_body(); + + let result = inner.proxy_client.request(request).await; + + let (mut res, upgrade_handler) = match result { + Ok((res, upgrade_handler)) => (res, upgrade_handler), + Err(e) => { + tracing::error!(?e, "Error proxying request"); + return status_code_to_response(StatusCode::INTERNAL_SERVER_ERROR); + } + }; + + if let Some(upgrade_handler) = upgrade_handler { + let monitor = inner.monitor.monitor(); + monitor + .lock() + .expect("Monitor lock poisoned") + .inc_connection(&route_info.backend_id); + let backend_id = route_info.backend_id.clone(); + tokio::spawn(async move { + if let Err(err) = upgrade_handler.run().await { + tracing::error!("Error running upgrade handler: {}", err); + }; + + monitor + .lock() + .expect("Monitor lock poisoned") + .dec_connection(&backend_id); + }); + } else { + inner.monitor.touch_backend(&route_info.backend_id); + } + + apply_general_headers(&mut res); + res.headers_mut().insert( + "x-plane-backend-id", + HeaderValue::from_str(&route_info.backend_id.to_string()) + .expect("Backend ID is always a valid header value"), + ); + + Ok(res) + }) + } +} + +fn prepare_request( + request: &mut MutableRequest, + route_info: &RouteInfo, + original_path: &str, +) -> Result<(), StatusCode> +where + T: Body + Send + Sync, + T::Error: Into>, +{ + // Check cluster and subdomain. + let Some(host) = request + .parts + .headers + .get(header::HOST) + .and_then(|h| h.to_str().ok()) + else { + return Err(StatusCode::BAD_REQUEST); + }; + + let Ok(request_subdomain) = subdomain_from_host(host, &route_info.cluster) else { + // The host header does not match the expected cluster. + return Err(StatusCode::FORBIDDEN); + }; + + if let Some(subdomain) = &route_info.subdomain { + if request_subdomain != Some(subdomain) { + return Err(StatusCode::FORBIDDEN); + } + } + + request.set_upstream_address(route_info.address.0); + + // Remove x-verified-* headers from inbound request. + { + let headers = request.headers_mut(); + let mut headers_to_remove = Vec::new(); + headers.iter_mut().for_each(|(name, _)| { + if name.as_str().starts_with("x-verified-") { + headers_to_remove.push(name.clone()); + } + }); + + for header in headers_to_remove { + headers.remove(&header); + } + } + + // Set special Plane headers. + if let Some(username) = &route_info.user { + request.add_header("x-verified-username", username); + } + + if let Some(user_data) = &route_info.user_data { + let user_data_str = + serde_json::to_string(user_data).expect("User data is always serializable"); + request.add_header("x-verified-user-data", &user_data_str); + } + + request.add_header("x-verified-path", original_path); + request.add_header("x-verified-backend", route_info.backend_id.as_str()); + request.add_header("x-verified-secret", route_info.secret_token.as_str()); + + Ok(()) +} + +fn status_code_to_response( + status_code: StatusCode, +) -> Result, Box> { + let mut response = Response::builder() + .status(status_code) + .body(simple_empty_body()) + .expect("Failed to build response"); + + apply_general_headers(&mut response); + + Ok(response) +} + +/// Mutates a request to add static headers present on all responses +/// (error or valid). +fn apply_general_headers(response: &mut Response) { + let headers = response.headers_mut(); + headers.insert( + header::ACCESS_CONTROL_ALLOW_ORIGIN, + HeaderValue::from_static("*"), + ); + headers.insert( + header::ACCESS_CONTROL_ALLOW_METHODS, + HeaderValue::from_static("*"), + ); + headers.insert( + header::ACCESS_CONTROL_ALLOW_HEADERS, + HeaderValue::from_static("*"), + ); + headers.insert(header::SERVER, HeaderValue::from_static(SERVER_NAME)); +} diff --git a/plane/src/proxy/proxy_service.rs b/plane/src/proxy/proxy_service.rs deleted file mode 100644 index 1dc6daca2..000000000 --- a/plane/src/proxy/proxy_service.rs +++ /dev/null @@ -1,500 +0,0 @@ -use super::connection_monitor::ConnectionMonitorHandle; -use super::rewriter::RequestRewriterError; -use super::route_map::RouteMap; -use super::tls::TlsStream; -use super::{ForwardableRequestInfo, Protocol}; -use crate::names::BackendName; -use crate::proxy::cert_manager::CertWatcher; -use crate::proxy::rewriter::RequestRewriter; -use crate::proxy::tls::TlsAcceptor; -use crate::SERVER_NAME; -use axum::http::uri::PathAndQuery; -use futures_util::{Future, FutureExt}; -use hyper::server::conn::AddrIncoming; -use hyper::{ - client::HttpConnector, server::conn::AddrStream, service::Service, Body, Request, Response, -}; -use std::convert::Infallible; -use std::net::SocketAddr; -use std::pin::Pin; -use std::sync::{atomic::AtomicBool, Arc}; -use std::{ - future::ready, - io::ErrorKind, - task::{self, Poll}, -}; -use tokio::io::copy_bidirectional; -use tokio::task::JoinHandle; -use tokio_rustls::rustls::ServerConfig; -use url::Url; - -const PLANE_BACKEND_ID_HEADER: &str = "x-plane-backend-id"; - -const DEFAULT_CORS_HEADERS: &[(&str, &str)] = &[ - ("Access-Control-Allow-Origin", "*"), - ( - "Access-Control-Allow-Methods", - "GET, POST, PUT, DELETE, OPTIONS", - ), - ( - "Access-Control-Allow-Headers", - "Content-Type, Authorization", - ), - ("Access-Control-Allow-Credentials", "true"), -]; - -fn response_builder() -> hyper::http::response::Builder { - let mut request = hyper::Response::builder(); - request = request.header("Access-Control-Allow-Origin", "*"); - request = request.header( - "Access-Control-Allow-Methods", - "GET, POST, PUT, DELETE, OPTIONS", - ); - request = request.header( - "Access-Control-Allow-Headers", - "Content-Type, Authorization", - ); - request = request.header("Access-Control-Allow-Credentials", "true"); - request -} - -#[derive(Debug, thiserror::Error)] -pub enum ProxyError { - #[error("Invalid or expired connection token")] - InvalidConnectionToken, - - #[error("Missing `host` header")] - MissingHostHeader, - - #[error("Bad request")] - BadRequest, - - #[error("Invalid subdomain")] - InvalidSubdomain, - - #[error("HTTP error: {0}")] - HttpError(#[from] hyper::http::Error), - - #[error("Error binding server: {0}")] - BindError(hyper::Error), - - #[error("Error upgrading request: {0}")] - UpgradeError(hyper::Error), - - #[error("Error making request: {0} (backend: {1})")] - RequestError(hyper::Error, BackendName), - - #[error("Error making upgradable request: {0}")] - UpgradableRequestError(hyper::Error), -} - -impl From for ProxyError { - fn from(err: RequestRewriterError) -> Self { - match err { - RequestRewriterError::InvalidHostHeader => ProxyError::BadRequest, - } - } -} - -pub struct ProxyState { - pub route_map: RouteMap, - http_client: hyper::Client, - pub monitor: ConnectionMonitorHandle, - connected: AtomicBool, -} - -impl Default for ProxyState { - fn default() -> Self { - Self::new() - } -} - -impl ProxyState { - pub fn new() -> Self { - Self { - route_map: RouteMap::new(), - http_client: hyper::Client::builder().build_http::(), - monitor: ConnectionMonitorHandle::new(), - connected: AtomicBool::new(false), - } - } - - pub fn set_connected(&self, connected: bool) { - self.connected - .store(connected, std::sync::atomic::Ordering::Relaxed); - } - - pub fn connected(&self) -> bool { - self.connected.load(std::sync::atomic::Ordering::Relaxed) - } -} - -struct RequestHandler { - state: Arc, - https_redirect: bool, - remote_meta: ForwardableRequestInfo, - root_redirect_url: Option, -} - -impl RequestHandler { - async fn handle_request( - self: Arc, - req: hyper::Request, - ) -> Result, Infallible> { - let result = self.handle_request_inner(req).await; - match result { - Ok(response) => Ok(response), - Err(err) => { - let (status_code, body) = match err { - ProxyError::InvalidConnectionToken => ( - hyper::StatusCode::GONE, - "The backend is no longer available or the connection token is invalid.", - ), - ProxyError::MissingHostHeader => { - (hyper::StatusCode::BAD_REQUEST, "Bad request") - } - ProxyError::InvalidSubdomain => { - (hyper::StatusCode::UNAUTHORIZED, "Invalid subdomain") - } - ProxyError::BadRequest => (hyper::StatusCode::BAD_REQUEST, "Bad request"), - ProxyError::RequestError(err, backend) => { - tracing::warn!(?err, %backend, "Error proxying request to backend."); - (hyper::StatusCode::BAD_GATEWAY, "Connect error") - } - err => { - tracing::error!(?err, "Unhandled error handling request."); - (hyper::StatusCode::INTERNAL_SERVER_ERROR, "Internal error") - } - }; - Ok(response_builder() - .status(status_code) - .header(hyper::header::SERVER, SERVER_NAME) - .body(hyper::Body::from(body.to_string())) - .expect("Static response is always valid")) - } - } - } - - async fn handle_request_inner( - self: Arc, - req: hyper::Request, - ) -> Result, ProxyError> { - // Handle "/ready" - if req.uri().path() == "/ready" { - if self.state.connected() { - return Ok(response_builder() - .status(hyper::StatusCode::OK) - .header(hyper::header::SERVER, SERVER_NAME) - .body("Plane Proxy server (ready)".into())?); - } else { - return Ok(response_builder() - .status(hyper::StatusCode::SERVICE_UNAVAILABLE) - .header(hyper::header::SERVER, SERVER_NAME) - .body("Plane Proxy server (not ready)".into())?); - } - } - - if self.https_redirect { - let Some(host) = req - .headers() - .get(hyper::header::HOST) - .and_then(|value| value.to_str().ok()) - else { - return Err(ProxyError::MissingHostHeader); - }; - - let host = match host.parse() { - Ok(host) => host, - Err(err) => { - tracing::warn!(?err, ?host, "Invalid host header."); - return Err(ProxyError::BadRequest); - } - }; - - let mut uri_parts = req.uri().clone().into_parts(); - uri_parts.scheme = Some("https".parse().expect("https is a valid scheme.")); - uri_parts.authority = Some(host); - uri_parts.path_and_query = uri_parts - .path_and_query - .or_else(|| Some(PathAndQuery::from_static(""))); - let uri = hyper::Uri::from_parts(uri_parts).expect("URI parts are valid."); - return Ok(response_builder() - .status(hyper::StatusCode::MOVED_PERMANENTLY) - .header(hyper::header::LOCATION, uri.to_string()) - .header(hyper::header::SERVER, SERVER_NAME) - .body(hyper::Body::empty())?); - } - - if req.uri().path() == "/" { - if let Some(root_redirect_url) = &self.root_redirect_url { - return Ok(response_builder() - .status(hyper::StatusCode::MOVED_PERMANENTLY) - .header(hyper::header::LOCATION, root_redirect_url.to_string()) - .header(hyper::header::SERVER, SERVER_NAME) - .body(hyper::Body::empty())?); - } - } - - self.handle_proxy_request(req).await - } - - async fn handle_proxy_request( - self: Arc, - req: hyper::Request, - ) -> Result, ProxyError> { - let Some(mut request_rewriter) = RequestRewriter::new(req, self.remote_meta) else { - tracing::warn!("Request rewriter failed to create."); - return Err(ProxyError::BadRequest); - }; - - let route_info = self - .state - .route_map - .lookup(request_rewriter.bearer_token()) - .await; - - let Some(route_info) = route_info else { - return Err(ProxyError::InvalidConnectionToken); - }; - - let subdomain = match request_rewriter.get_subdomain(&route_info.cluster) { - Ok(subdomain) => subdomain, - Err(err) => { - tracing::warn!(?err, "Subdomain not found in request rewriter."); - return Err(ProxyError::InvalidSubdomain); - } - }; - if subdomain != route_info.subdomain.as_deref() { - tracing::warn!( - "Subdomain mismatch! subdomain in header: {:?}, subdomain in backend: {:?}", - subdomain, - route_info.subdomain - ); - return Err(ProxyError::InvalidSubdomain); - } - - let backend_id = route_info.backend_id.clone(); - request_rewriter.set_authority(route_info.address.0); - - let mut response = if request_rewriter.should_upgrade() { - let (req, req_clone) = request_rewriter.into_request_pair(&route_info); - let response = self - .state - .http_client - .request(req_clone) - .await - .map_err(ProxyError::UpgradableRequestError)?; - let response_clone = clone_response_empty_body(&response); - - let mut response_upgrade = hyper::upgrade::on(response) - .await - .map_err(ProxyError::UpgradeError)?; - let monitor = self.state.monitor.monitor(); - let backend_id = backend_id.clone(); - - tokio::spawn(async move { - let mut req_upgrade = match hyper::upgrade::on(req).await { - Ok(req) => req, - Err(error) => { - tracing::error!(?error, "Error upgrading connection."); - return; - } - }; - - monitor - .lock() - .expect("Monitor lock was poisoned.") - .inc_connection(&backend_id); - - match copy_bidirectional(&mut req_upgrade, &mut response_upgrade).await { - Ok(_) => (), - Err(error) if error.kind() == ErrorKind::UnexpectedEof => { - tracing::info!("Upgraded connection closed with UnexpectedEof."); - } - Err(error) if error.kind() == ErrorKind::TimedOut => { - tracing::info!("Upgraded connection timed out."); - } - Err(error) if error.kind() == ErrorKind::ConnectionReset => { - tracing::info!("Connection reset by peer."); - } - Err(error) if error.kind() == ErrorKind::BrokenPipe => { - tracing::info!("Broken pipe."); - } - Err(error) => { - tracing::error!(?error, "Error with upgraded connection."); - } - } - - monitor - .lock() - .expect("Monitor lock was poisoned.") - .dec_connection(&backend_id); - }); - - response_clone - } else { - let req = request_rewriter.into_request(&route_info); - self.state.monitor.touch_backend(&backend_id); - self.state - .http_client - .request(req) - .await - .map_err(|e| ProxyError::RequestError(e, backend_id.clone()))? - }; - - let headers = response.headers_mut(); - headers.insert( - PLANE_BACKEND_ID_HEADER, - backend_id - .to_string() - .parse() - .expect("Backend ID is a valid header value."), - ); - - for (key, value) in DEFAULT_CORS_HEADERS { - if !headers.contains_key(*key) { - headers.insert( - *key, - value.parse().expect("CORS header is a valid header value."), - ); - } - } - - Ok(response) - } -} - -fn clone_response_empty_body(response: &Response) -> Response { - let mut builder = Response::builder(); - - builder - .headers_mut() - .expect("Builder::headers_mut should always work on a new builder.") - .extend(response.headers().clone()); - - builder = builder.status(response.status()); - - builder - .body(Body::empty()) - .expect("Response is always valid.") -} - -pub struct ProxyService { - handler: Arc, -} - -impl Service> for ProxyService { - type Response = Response; - type Error = Infallible; - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: Request) -> Self::Future { - Box::pin(self.handler.clone().handle_request(req)) - } -} - -pub struct ProxyMakeService { - pub state: Arc, - pub https_redirect: bool, - pub root_redirect_url: Option, -} - -impl ProxyMakeService { - pub fn serve_http(self, port: u16, shutdown_future: F) -> Result, ProxyError> - where - F: Future + Send + 'static, - { - let addr: SocketAddr = ([0, 0, 0, 0], port).into(); - tracing::info!(%addr, "Listening for HTTP connections."); - let server = hyper::Server::bind(&addr) - .serve(self) - .with_graceful_shutdown(shutdown_future); - let handle = tokio::spawn(async { - let _ = server.await; - }); - - Ok(handle) - } - - pub fn serve_https( - self, - port: u16, - cert_watcher: CertWatcher, - shutdown_future: F, - ) -> Result, ProxyError> - where - F: Future + Send + 'static, - { - let server_config = ServerConfig::builder() - .with_safe_defaults() - .with_no_client_auth() - .with_cert_resolver(Arc::new(cert_watcher)); - - let addr: SocketAddr = ([0, 0, 0, 0], port).into(); - let incoming = AddrIncoming::bind(&addr).map_err(ProxyError::BindError)?; - tracing::info!(%addr, "Listening for HTTPS connections."); - - let tls_acceptor = TlsAcceptor::new(Arc::new(server_config), incoming); - - let server = hyper::Server::builder(tls_acceptor) - .serve(self) - .with_graceful_shutdown(shutdown_future); - let handle = tokio::spawn(async { - let _ = server.await; - }); - - Ok(handle) - } -} - -impl<'a> Service<&'a AddrStream> for ProxyMakeService { - type Response = ProxyService; - type Error = ProxyError; - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: &'a AddrStream) -> Self::Future { - let remote_ip = req.remote_addr().ip(); - let handler = Arc::new(RequestHandler { - state: self.state.clone(), - https_redirect: self.https_redirect, - remote_meta: ForwardableRequestInfo { - ip: remote_ip, - protocol: Protocol::Http, - }, - root_redirect_url: self.root_redirect_url.clone(), - }); - ready(Ok(ProxyService { handler })).boxed() - } -} - -impl<'a> Service<&'a TlsStream> for ProxyMakeService { - type Response = ProxyService; - type Error = ProxyError; - type Future = Pin> + Send>>; - - fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { - Poll::Ready(Ok(())) - } - - fn call(&mut self, req: &'a TlsStream) -> Self::Future { - let remote_ip = req.remote_ip; - let handler = Arc::new(RequestHandler { - state: self.state.clone(), - https_redirect: false, - remote_meta: ForwardableRequestInfo { - ip: remote_ip, - protocol: Protocol::Https, - }, - root_redirect_url: self.root_redirect_url.clone(), - }); - ready(Ok(ProxyService { handler })).boxed() - } -} diff --git a/plane/src/proxy/request.rs b/plane/src/proxy/request.rs new file mode 100644 index 000000000..9928d2179 --- /dev/null +++ b/plane/src/proxy/request.rs @@ -0,0 +1,206 @@ +use crate::types::{BearerToken, ClusterName}; +use dynamic_proxy::hyper::http::uri::{self, PathAndQuery}; +use std::str::FromStr; + +// If a cluster name does not specify a port, :443 is implied. +// Most browsers will not specify it, but some (e.g. the `ws` websocket client in Node.js) +// will, so we strip it. +const HTTPS_PORT_SUFFIX: &str = ":443"; + +/// Returns Ok(Some(subdomain)) if a subdomain is found. +/// Returns Ok(None) if no subdomain is found, but the host header matches the cluster name. +/// Returns Err(()) if the host header does not +/// match the cluster name. +pub fn subdomain_from_host<'a>( + host: &'a str, + cluster: &ClusterName, +) -> Result, ()> { + let host = if let Some(host) = host.strip_suffix(HTTPS_PORT_SUFFIX) { + host + } else { + host + }; + + if let Some(subdomain) = host.strip_suffix(cluster.as_str()) { + if subdomain.is_empty() { + // Subdomain exactly matches cluster name. + Ok(None) + } else if let Some(subdomain) = subdomain.strip_suffix('.') { + Ok(Some(subdomain)) + } else { + Err(()) + } + } else { + tracing::warn!(host, "Host header does not end in cluster name."); + Err(()) + } +} + +/// Removes a connection string from the URI and returns it. +/// If no connection string is found, returns None. +pub fn get_and_maybe_remove_bearer_token(parts: &mut uri::Parts) -> Option { + let path_and_query = parts.path_and_query.clone()?; + + let full_path = path_and_query.path().strip_prefix('/')?; + + // Split the incoming path into the token and the path to proxy to. If there is no slash, the token is + // the full incoming path, and the path to proxy to is just `/`. + let (token, path) = match full_path.split_once('/') { + Some((token, path)) => (token, path), + None => (full_path, ""), + }; + + if token.is_empty() { + return None; + } + + let token = BearerToken::from(token.to_string()); + + if token.is_static() { + // We don't rewrite the URL if using a static token. + return Some(token); + } + + let query = path_and_query + .query() + .map(|query| format!("?{}", query)) + .unwrap_or_default(); + + parts.path_and_query = Some( + PathAndQuery::from_str(format!("/{}{}", path, query).as_str()) + .expect("Path and query is valid."), + ); + + Some(token) +} + +#[cfg(test)] +mod tests { + use uri::Uri; + + use super::*; + use std::str::FromStr; + + #[test] + fn no_subdomains() { + let host = "foo.bar.baz"; + let cluster = ClusterName::from_str("foo.bar.baz").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Ok(None)); + } + + #[test] + fn valid_subdomain() { + let host = "foobar.example.com"; + let cluster = ClusterName::from_str("example.com").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Ok(Some("foobar"))); + } + + #[test] + fn valid_suffix_no_dot() { + let host = "foobarexample.com"; + let cluster = ClusterName::from_str("example.com").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Err(())); + } + + #[test] + fn invalid_suffix() { + let host = "abc.abc.com"; + let cluster = ClusterName::from_str("example.com").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Err(())); + } + + #[test] + fn allowed_port() { + let host = "foobar.myhost:8080"; + let cluster = ClusterName::from_str("myhost:8080").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Ok(Some("foobar"))); + } + + #[test] + fn port_required() { + let host = "foobar.myhost"; + let cluster = ClusterName::from_str("myhost:8080").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Err(())); + } + + #[test] + fn port_must_match() { + let host = "foobar.myhost:8080"; + let cluster = ClusterName::from_str("myhost").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Err(())); + } + + #[test] + fn port_443_optional() { + let host = "foobar.myhost:443"; + let cluster = ClusterName::from_str("myhost").unwrap(); + assert_eq!(subdomain_from_host(host, &cluster), Ok(Some("foobar"))); + } + + #[test] + fn test_get_and_maybe_remove_bearer_token() { + let url = Uri::from_str("https://example.com/foo/bar").unwrap(); + let mut parts = url.into_parts(); + assert_eq!( + get_and_maybe_remove_bearer_token(&mut parts), + Some(BearerToken::from("foo".to_string())) + ); + assert_eq!( + parts.path_and_query, + Some(PathAndQuery::from_str("/bar").unwrap()) + ); + } + + #[test] + fn test_get_and_maybe_remove_bearer_token_ends_no_slash() { + let url = Uri::from_str("https://example.com/foo").unwrap(); + let mut parts = url.into_parts(); + assert_eq!( + get_and_maybe_remove_bearer_token(&mut parts), + Some(BearerToken::from("foo".to_string())) + ); + assert_eq!( + parts.path_and_query, + Some(PathAndQuery::from_str("/").unwrap()) + ); + } + + #[test] + fn test_get_and_maybe_remove_bearer_token_ends_in_slash() { + let url = Uri::from_str("https://example.com/foo/").unwrap(); + let mut parts = url.into_parts(); + assert_eq!( + get_and_maybe_remove_bearer_token(&mut parts), + Some(BearerToken::from("foo".to_string())) + ); + assert_eq!( + parts.path_and_query, + Some(PathAndQuery::from_str("/").unwrap()) + ); + } + + #[test] + fn test_get_and_maybe_remove_bearer_token_no_token() { + let url = Uri::from_str("https://example.com/").unwrap(); + let mut parts = url.into_parts(); + assert_eq!(get_and_maybe_remove_bearer_token(&mut parts), None); + assert_eq!( + parts.path_and_query, + Some(PathAndQuery::from_str("/").unwrap()) + ); + } + + #[test] + fn test_get_and_maybe_remove_bearer_token_static_token() { + let url = Uri::from_str("https://example.com/s.foo/bar").unwrap(); + let mut parts = url.into_parts(); + assert_eq!( + get_and_maybe_remove_bearer_token(&mut parts), + Some(BearerToken::from("s.foo".to_string())) + ); + assert_eq!( + parts.path_and_query, + Some(PathAndQuery::from_str("/s.foo/bar").unwrap()) + ); + } +} diff --git a/plane/src/proxy/rewriter.rs b/plane/src/proxy/rewriter.rs deleted file mode 100644 index 74810e920..000000000 --- a/plane/src/proxy/rewriter.rs +++ /dev/null @@ -1,286 +0,0 @@ -use super::{subdomain::subdomain_from_host, ForwardableRequestInfo}; -use crate::{ - protocol::RouteInfo, - types::{BearerToken, ClusterName}, -}; -use hyper::{ - header::HOST, - http::{request, uri}, - Body, HeaderMap, Request, Uri, -}; -use reqwest::header::HeaderValue; -use std::{borrow::BorrowMut, net::SocketAddr, str::FromStr}; -use tungstenite::http::uri::PathAndQuery; - -const VERIFIED_HEADER_PREFIX: &str = "x-verified-"; -const USERNAME_HEADER: &str = "x-verified-username"; -const AUTH_SECRET_HEADER: &str = "x-verified-secret"; -const AUTH_USER_DATA_HEADER: &str = "x-verified-user-data"; -const PATH_PREFIX_HEADER: &str = "x-verified-path"; -const BACKEND_ID_HEADER: &str = "x-verified-backend"; -const X_FORWARDED_FOR_HEADER: &str = "x-forwarded-for"; -const X_FORWARDED_PROTO_HEADER: &str = "x-forwarded-proto"; - -#[derive(Debug, thiserror::Error, PartialEq, Eq)] -pub enum RequestRewriterError { - #[error("Invalid `host` header")] - InvalidHostHeader, -} - -impl From for RequestRewriterError { - fn from(_: hyper::header::ToStrError) -> Self { - RequestRewriterError::InvalidHostHeader - } -} - -pub struct RequestRewriter { - parts: request::Parts, - uri_parts: uri::Parts, - body: Body, - bearer_token: BearerToken, - prefix_uri: Uri, - remote_meta: ForwardableRequestInfo, -} - -impl RequestRewriter { - pub fn new(request: Request, remote_meta: ForwardableRequestInfo) -> Option { - let (parts, body) = request.into_parts(); - - let mut uri_parts = parts.uri.clone().into_parts(); - uri_parts.scheme = Some("http".parse().expect("Scheme is valid.")); - - let bearer_token = match extract_bearer_token(&mut uri_parts) { - Some(bearer_token) => bearer_token, - None => { - tracing::warn!(uri=?parts.uri, "Bearer token not found in URI."); - return None; - } - }; - - let mut prefix_uri_parts = parts.uri.clone().into_parts(); - prefix_uri_parts.path_and_query = - Some(PathAndQuery::from_str(&format!("/{}/", bearer_token)).expect("Path is valid.")); - let prefix_uri = Uri::from_parts(prefix_uri_parts).expect("URI parts are valid."); - - Some(Self { - parts, - uri_parts, - body, - bearer_token, - prefix_uri, - remote_meta, - }) - } - - pub fn set_authority(&mut self, addr: SocketAddr) { - self.uri_parts.authority = Some( - addr.to_string() - .parse() - .expect("SocketAddr is a valid authority."), - ); - } - - pub fn bearer_token(&self) -> &BearerToken { - &self.bearer_token - } - - /// Returns the subdomain of the request's host header, after stripping the cluster name. - /// Returns Ok(Some(subdomain)) if a subdomain is found. - /// Returns Ok(None) if no subdomain is found, but the host header matches the cluster name. - /// Returns Err(RequestRewriterError::InvalidHostHeader) if the host header does not - /// match the cluster name, or no host header is found. - pub fn get_subdomain( - &self, - cluster: &ClusterName, - ) -> Result, RequestRewriterError> { - let Some(hostname) = self.parts.headers.get(HOST) else { - return Err(RequestRewriterError::InvalidHostHeader); - }; - - let hostname = match hostname.to_str() { - Ok(hostname) => hostname, - Err(err) => { - tracing::warn!(?hostname, ?err, "Host header is not valid UTF-8."); - return Err(RequestRewriterError::InvalidHostHeader); - } - }; - - subdomain_from_host(hostname, cluster) - } - - fn into_parts(self) -> (request::Parts, Body, Uri, ForwardableRequestInfo) { - let Self { - mut parts, - uri_parts, - body, - prefix_uri, - remote_meta, - .. - } = self; - - let uri = Uri::from_parts(uri_parts).expect("URI parts are valid."); - parts.uri = uri; - - (parts, body, prefix_uri, remote_meta) - } - - pub fn into_request(self, route_info: &RouteInfo) -> Request { - let (mut parts, body, prefix_uri, remote_meta) = self.into_parts(); - - let headers = parts.headers.borrow_mut(); - set_headers_from_route_info(headers, route_info, &prefix_uri, remote_meta); - - Request::from_parts(parts, body) - } - - pub fn into_request_pair(self, route_info: &RouteInfo) -> (Request, Request) { - let (parts, body, prefix_uri, remote_meta) = self.into_parts(); - let req2 = clone_request_with_empty_body(&parts, route_info, &prefix_uri, remote_meta); - let req1 = Request::from_parts(parts, body); - - (req1, req2) - } - - pub fn should_upgrade(&self) -> bool { - let Some(conn_header) = self.parts.headers.get("connection") else { - return false; - }; - - let Ok(conn_header) = conn_header.to_str() else { - return false; - }; - - conn_header - .to_lowercase() - .split(',') - .any(|s| s.trim() == "upgrade") - } -} - -fn clone_request_with_empty_body( - parts: &request::Parts, - route_info: &RouteInfo, - prefix_uri: &Uri, - remote_meta: ForwardableRequestInfo, -) -> request::Request { - let mut builder = request::Builder::new() - .method(parts.method.clone()) - .uri(parts.uri.clone()); - - let headers = builder - .headers_mut() - .expect("Can always call headers_mut() on a new builder."); - - headers.extend(parts.headers.clone()); - set_headers_from_route_info(headers, route_info, prefix_uri, remote_meta); - - builder - .body(Body::empty()) - .expect("Request is always valid.") -} - -fn extract_bearer_token(parts: &mut uri::Parts) -> Option { - let Some(path_and_query) = parts.path_and_query.clone() else { - panic!("No path and query"); - }; - - let full_path = path_and_query.path().strip_prefix('/')?; - - // Split the incoming path into the token and the path to proxy to. If there is no slash, the token is - // the full incoming path, and the path to proxy to is just `/`. - let (token, path) = match full_path.split_once('/') { - Some((token, path)) => (token, path), - None => (full_path, "/"), - }; - - let token = BearerToken::from(token.to_string()); - - if token.is_static() { - // We don't rewrite the URL if using a static token. - return Some(token); - } - - let query = path_and_query - .query() - .map(|query| format!("?{}", query)) - .unwrap_or_default(); - - parts.path_and_query = Some( - PathAndQuery::from_str(format!("/{}{}", path, query).as_str()) - .expect("Path and query is valid."), - ); - - Some(token) -} - -fn set_headers_from_route_info( - headers: &mut HeaderMap, - route_info: &RouteInfo, - prefix_uri: &Uri, - remote_meta: ForwardableRequestInfo, -) { - let mut headers_to_remove = Vec::new(); - for header_name in headers.keys() { - if header_name.as_str().starts_with(VERIFIED_HEADER_PREFIX) { - headers_to_remove.push(header_name.clone()); - } - } - - for header_name in headers_to_remove { - headers.remove(header_name); - } - - if let Some(user) = &route_info.user { - headers.insert( - USERNAME_HEADER, - HeaderValue::from_str(user.as_str()).expect("User is valid."), - ); - } - - let forwards = if let Some(forwards) = headers.get(X_FORWARDED_FOR_HEADER) { - let forwards = forwards.to_str().unwrap_or("").to_string(); - format!("{}, {}", forwards, remote_meta.ip) - } else { - remote_meta.ip.to_string() - }; - - headers.insert( - X_FORWARDED_FOR_HEADER, - HeaderValue::from_str(forwards.as_str()).expect("Forwards are valid."), - ); - - if headers.get(X_FORWARDED_PROTO_HEADER).is_none() { - headers.insert( - "x-forwarded-proto", - HeaderValue::from_static(remote_meta.protocol.as_str()), - ); - } - - headers.insert( - AUTH_SECRET_HEADER, - HeaderValue::from_str(&route_info.secret_token.to_string()).expect("Secret is valid."), - ); - - headers.insert( - AUTH_USER_DATA_HEADER, - HeaderValue::from_str( - &serde_json::to_string(&route_info.user_data) - .expect("JSON value should always serialize."), - ) - .expect("User data is valid"), - ); - - headers.insert( - PATH_PREFIX_HEADER, - HeaderValue::from_str(&prefix_uri.to_string()).expect("Path is valid."), - ); - - headers.insert( - BACKEND_ID_HEADER, - route_info - .backend_id - .to_string() - .parse() - .expect("Backend ID is a valid header value."), - ); -} diff --git a/plane/src/proxy/shutdown_signal.rs b/plane/src/proxy/shutdown_signal.rs deleted file mode 100644 index 3a8142920..000000000 --- a/plane/src/proxy/shutdown_signal.rs +++ /dev/null @@ -1,24 +0,0 @@ -use futures_util::Future; -use tokio::sync::broadcast; - -pub struct ShutdownSignal { - send_shutdown: broadcast::Sender<()>, -} - -impl ShutdownSignal { - pub fn new() -> Self { - let (send_shutdown, _) = broadcast::channel::<()>(1); - Self { send_shutdown } - } - - pub fn shutdown(&self) { - let _ = self.send_shutdown.send(()); - } - - pub fn subscribe(&self) -> impl Future + Send + 'static { - let mut receiver = self.send_shutdown.subscribe(); - async move { - let _ = receiver.recv().await; - } - } -} diff --git a/plane/src/proxy/subdomain.rs b/plane/src/proxy/subdomain.rs deleted file mode 100644 index cf0db47b6..000000000 --- a/plane/src/proxy/subdomain.rs +++ /dev/null @@ -1,110 +0,0 @@ -use super::rewriter::RequestRewriterError; -use crate::types::ClusterName; - -// If a cluster name does not specify a port, :443 is implied. -// Most browsers will not specify it, but some (e.g. the `ws` websocket client in Node.js) -// will, so we strip it. -const HTTPS_PORT_SUFFIX: &str = ":443"; - -/// Returns Ok(Some(subdomain)) if a subdomain is found. -/// Returns Ok(None) if no subdomain is found, but the host header matches the cluster name. -/// Returns Err(RequestRewriterError::InvalidHostHeader) if the host header does not -/// match the cluster name. -pub fn subdomain_from_host<'a>( - host: &'a str, - cluster: &ClusterName, -) -> Result, RequestRewriterError> { - let host = if let Some(host) = host.strip_suffix(HTTPS_PORT_SUFFIX) { - host - } else { - host - }; - - if let Some(subdomain) = host.strip_suffix(cluster.as_str()) { - if subdomain.is_empty() { - // Subdomain exactly matches cluster name. - Ok(None) - } else if let Some(subdomain) = subdomain.strip_suffix('.') { - Ok(Some(subdomain)) - } else { - Err(RequestRewriterError::InvalidHostHeader) - } - } else { - tracing::warn!(host, "Host header does not end in cluster name."); - Err(RequestRewriterError::InvalidHostHeader) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::str::FromStr; - - #[test] - fn no_subdomains() { - let host = "foo.bar.baz"; - let cluster = ClusterName::from_str("foo.bar.baz").unwrap(); - assert_eq!(subdomain_from_host(host, &cluster), Ok(None)); - } - - #[test] - fn valid_subdomain() { - let host = "foobar.example.com"; - let cluster = ClusterName::from_str("example.com").unwrap(); - assert_eq!(subdomain_from_host(host, &cluster), Ok(Some("foobar"))); - } - - #[test] - fn valid_suffix_no_dot() { - let host = "foobarexample.com"; - let cluster = ClusterName::from_str("example.com").unwrap(); - assert_eq!( - subdomain_from_host(host, &cluster), - Err(RequestRewriterError::InvalidHostHeader) - ); - } - - #[test] - fn invalid_suffix() { - let host = "abc.abc.com"; - let cluster = ClusterName::from_str("example.com").unwrap(); - assert_eq!( - subdomain_from_host(host, &cluster), - Err(RequestRewriterError::InvalidHostHeader) - ); - } - - #[test] - fn allowed_port() { - let host = "foobar.myhost:8080"; - let cluster = ClusterName::from_str("myhost:8080").unwrap(); - assert_eq!(subdomain_from_host(host, &cluster), Ok(Some("foobar"))); - } - - #[test] - fn port_required() { - let host = "foobar.myhost"; - let cluster = ClusterName::from_str("myhost:8080").unwrap(); - assert_eq!( - subdomain_from_host(host, &cluster), - Err(RequestRewriterError::InvalidHostHeader) - ); - } - - #[test] - fn port_must_match() { - let host = "foobar.myhost:8080"; - let cluster = ClusterName::from_str("myhost").unwrap(); - assert_eq!( - subdomain_from_host(host, &cluster), - Err(RequestRewriterError::InvalidHostHeader) - ); - } - - #[test] - fn port_443_optional() { - let host = "foobar.myhost:443"; - let cluster = ClusterName::from_str("myhost").unwrap(); - assert_eq!(subdomain_from_host(host, &cluster), Ok(Some("foobar"))); - } -} diff --git a/plane/src/proxy/tls.rs b/plane/src/proxy/tls.rs deleted file mode 100644 index bc954a7a4..000000000 --- a/plane/src/proxy/tls.rs +++ /dev/null @@ -1,120 +0,0 @@ -use core::task::Context; -use futures_util::{ready, Future}; -use hyper::server::accept::Accept; -use hyper::server::conn::{AddrIncoming, AddrStream}; -use std::io; -use std::net::IpAddr; -use std::pin::Pin; -use std::sync::Arc; -use std::task::Poll; -use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; -use tokio_rustls::rustls::ServerConfig; - -// From: https://github.com/rustls/hyper-rustls/blob/main/examples/server.rs -pub struct TlsAcceptor { - config: Arc, - incoming: AddrIncoming, -} - -impl TlsAcceptor { - pub fn new(config: Arc, incoming: AddrIncoming) -> TlsAcceptor { - TlsAcceptor { config, incoming } - } -} - -impl Accept for TlsAcceptor { - type Conn = TlsStream; - type Error = io::Error; - - fn poll_accept( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - ) -> Poll>> { - let pin = self.get_mut(); - match ready!(Pin::new(&mut pin.incoming).poll_accept(cx)) { - Some(Ok(sock)) => Poll::Ready(Some(Ok(TlsStream::new(sock, pin.config.clone())))), - Some(Err(e)) => Poll::Ready(Some(Err(e))), - None => Poll::Ready(None), - } - } -} - -impl AsyncRead for TlsStream { - fn poll_read( - self: Pin<&mut Self>, - cx: &mut Context, - buf: &mut ReadBuf, - ) -> Poll> { - let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_read(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - }, - State::Streaming(ref mut stream) => Pin::new(stream).poll_read(cx, buf), - } - } -} - -impl AsyncWrite for TlsStream { - fn poll_write( - self: Pin<&mut Self>, - cx: &mut Context<'_>, - buf: &[u8], - ) -> Poll> { - let pin = self.get_mut(); - match pin.state { - State::Handshaking(ref mut accept) => match ready!(Pin::new(accept).poll(cx)) { - Ok(mut stream) => { - let result = Pin::new(&mut stream).poll_write(cx, buf); - pin.state = State::Streaming(stream); - result - } - Err(err) => Poll::Ready(Err(err)), - }, - State::Streaming(ref mut stream) => Pin::new(stream).poll_write(cx, buf), - } - } - - fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { - State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_flush(cx), - } - } - - fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { - match self.state { - State::Handshaking(_) => Poll::Ready(Ok(())), - State::Streaming(ref mut stream) => Pin::new(stream).poll_shutdown(cx), - } - } -} - -enum State { - Handshaking(tokio_rustls::Accept), - Streaming(tokio_rustls::server::TlsStream), -} - -// tokio_rustls::server::TlsStream doesn't expose constructor methods, -// so we have to TlsAcceptor::accept and handshake to have access to it -// TlsStream implements AsyncRead/AsyncWrite handshaking tokio_rustls::Accept first -pub struct TlsStream { - state: State, - pub remote_ip: IpAddr, -} - -impl TlsStream { - fn new(stream: AddrStream, config: Arc) -> TlsStream { - let remote_ip = stream.remote_addr().ip(); - let accept = tokio_rustls::TlsAcceptor::from(config).accept(stream); - TlsStream { - state: State::Handshaking(accept), - remote_ip, - } - } -} diff --git a/plane/src/typed_socket/client.rs b/plane/src/typed_socket/client.rs index 22bdc8727..10393c026 100644 --- a/plane/src/typed_socket/client.rs +++ b/plane/src/typed_socket/client.rs @@ -8,6 +8,10 @@ use std::marker::PhantomData; use tokio::net::TcpStream; use tokio_tungstenite::{MaybeTlsStream, WebSocketStream}; use tungstenite::handshake::client::generate_key; +use tungstenite::http::{ + header::{HeaderValue, AUTHORIZATION}, + Method, Request, +}; use tungstenite::{error::ProtocolError, Message}; type Socket = WebSocketStream>; @@ -88,9 +92,9 @@ impl TypedSocketConnector { } /// Creates a WebSocket request from an AuthorizedAddress. -fn auth_url_to_request(addr: &AuthorizedAddress) -> Result, PlaneClientError> { - let mut request = hyper::Request::builder() - .method(hyper::Method::GET) +fn auth_url_to_request(addr: &AuthorizedAddress) -> Result, PlaneClientError> { + let mut request = Request::builder() + .method(Method::GET) .uri(addr.url.as_str()) .header( "Host", @@ -108,8 +112,8 @@ fn auth_url_to_request(addr: &AuthorizedAddress) -> Result, P if let Some(bearer_header) = addr.bearer_header() { request = request.header( - hyper::header::AUTHORIZATION, - hyper::header::HeaderValue::from_str(&bearer_header).expect("Bearer header is valid"), + AUTHORIZATION, + HeaderValue::from_str(&bearer_header).expect("Bearer header is valid"), ); } diff --git a/plane/src/types/mod.rs b/plane/src/types/mod.rs index a48875bb1..66f072f4b 100644 --- a/plane/src/types/mod.rs +++ b/plane/src/types/mod.rs @@ -372,6 +372,12 @@ impl Display for SecretToken { } } +impl SecretToken { + pub fn as_str(&self) -> &str { + &self.0 + } +} + #[derive(Clone, Serialize, Deserialize, Debug)] pub struct ConnectResponse { pub backend_id: BackendName,