From 1481b482f246039a024736d46bab86a9f1051cdf Mon Sep 17 00:00:00 2001 From: Eran Ifrah Date: Mon, 14 Oct 2024 17:29:53 +0000 Subject: [PATCH 1/3] Make redis-rs part of this repo Signed-off-by: Eran Ifrah --- .github/workflows/python.yml | 6 +- .github/workflows/semgrep.yml | 2 +- .gitignore | 2 + .gitmodules | 3 - Makefile | 76 + benchmarks/rust/Cargo.toml | 2 +- csharp/lib/Cargo.toml | 2 +- glide-core/Cargo.toml | 4 +- glide-core/redis-rs/Cargo.toml | 3 + glide-core/redis-rs/LICENSE | 33 + glide-core/redis-rs/Makefile | 96 + glide-core/redis-rs/README.md | 233 + glide-core/redis-rs/afl/.gitignore | 2 + glide-core/redis-rs/afl/parser/Cargo.toml | 17 + glide-core/redis-rs/afl/parser/in/array | 5 + glide-core/redis-rs/afl/parser/in/array-null | 1 + glide-core/redis-rs/afl/parser/in/bulkstring | 2 + .../redis-rs/afl/parser/in/bulkstring-null | 1 + glide-core/redis-rs/afl/parser/in/error | 1 + glide-core/redis-rs/afl/parser/in/integer | 1 + .../redis-rs/afl/parser/in/invalid-string | 2 + glide-core/redis-rs/afl/parser/in/string | 1 + glide-core/redis-rs/afl/parser/src/main.rs | 9 + .../redis-rs/afl/parser/src/reproduce.rs | 13 + glide-core/redis-rs/appveyor.yml | 23 + glide-core/redis-rs/redis-test/CHANGELOG.md | 44 + glide-core/redis-rs/redis-test/Cargo.toml | 26 + glide-core/redis-rs/redis-test/LICENSE | 33 + glide-core/redis-rs/redis-test/README.md | 4 + glide-core/redis-rs/redis-test/release.toml | 1 + glide-core/redis-rs/redis-test/src/lib.rs | 426 ++ glide-core/redis-rs/redis/CHANGELOG.md | 828 ++++ glide-core/redis-rs/redis/Cargo.toml | 227 + glide-core/redis-rs/redis/LICENSE | 33 + .../redis-rs/redis/benches/bench_basic.rs | 277 ++ .../redis-rs/redis/benches/bench_cluster.rs | 108 + .../redis/benches/bench_cluster_async.rs | 88 + .../redis-rs/redis/examples/async-await.rs | 24 + .../redis/examples/async-connection-loss.rs | 97 + .../redis/examples/async-multiplexed.rs | 46 + .../redis-rs/redis/examples/async-pub-sub.rs | 22 + .../redis-rs/redis/examples/async-scan.rs | 25 + glide-core/redis-rs/redis/examples/basic.rs | 169 + .../redis-rs/redis/examples/geospatial.rs | 68 + glide-core/redis-rs/redis/examples/streams.rs | 270 ++ glide-core/redis-rs/redis/release.toml | 2 + glide-core/redis-rs/redis/src/acl.rs | 312 ++ .../redis-rs/redis/src/aio/async_std.rs | 269 ++ .../redis-rs/redis/src/aio/connection.rs | 543 +++ .../redis/src/aio/connection_manager.rs | 310 ++ glide-core/redis-rs/redis/src/aio/mod.rs | 286 ++ .../redis/src/aio/multiplexed_connection.rs | 656 +++ glide-core/redis-rs/redis/src/aio/runtime.rs | 82 + glide-core/redis-rs/redis/src/aio/tokio.rs | 204 + glide-core/redis-rs/redis/src/client.rs | 855 ++++ glide-core/redis-rs/redis/src/cluster.rs | 1076 +++++ .../redis-rs/redis/src/cluster_async/LICENSE | 7 + .../cluster_async/connections_container.rs | 881 ++++ .../src/cluster_async/connections_logic.rs | 481 ++ .../redis-rs/redis/src/cluster_async/mod.rs | 2656 +++++++++++ .../redis-rs/redis/src/cluster_client.rs | 752 +++ .../redis-rs/redis/src/cluster_pipeline.rs | 151 + .../redis-rs/redis/src/cluster_routing.rs | 1374 ++++++ .../redis-rs/redis/src/cluster_slotmap.rs | 435 ++ .../redis-rs/redis/src/cluster_topology.rs | 645 +++ glide-core/redis-rs/redis/src/cmd.rs | 663 +++ .../redis/src/commands/cluster_scan.rs | 720 +++ .../redis-rs/redis/src/commands/json.rs | 390 ++ .../redis-rs/redis/src/commands/macros.rs | 275 ++ glide-core/redis-rs/redis/src/commands/mod.rs | 2190 +++++++++ glide-core/redis-rs/redis/src/connection.rs | 1997 ++++++++ glide-core/redis-rs/redis/src/geo.rs | 361 ++ glide-core/redis-rs/redis/src/lib.rs | 506 ++ glide-core/redis-rs/redis/src/macros.rs | 7 + glide-core/redis-rs/redis/src/parser.rs | 658 +++ glide-core/redis-rs/redis/src/pipeline.rs | 324 ++ glide-core/redis-rs/redis/src/push_manager.rs | 234 + glide-core/redis-rs/redis/src/r2d2.rs | 36 + glide-core/redis-rs/redis/src/script.rs | 255 + glide-core/redis-rs/redis/src/sentinel.rs | 778 +++ glide-core/redis-rs/redis/src/streams.rs | 670 +++ glide-core/redis-rs/redis/src/tls.rs | 142 + glide-core/redis-rs/redis/src/types.rs | 2460 ++++++++++ glide-core/redis-rs/redis/tests/parser.rs | 195 + .../redis-rs/redis/tests/support/cluster.rs | 792 +++ .../redis/tests/support/mock_cluster.rs | 487 ++ .../redis-rs/redis/tests/support/mod.rs | 887 ++++ .../redis-rs/redis/tests/support/sentinel.rs | 404 ++ .../redis-rs/redis/tests/support/util.rs | 23 + glide-core/redis-rs/redis/tests/test_acl.rs | 156 + glide-core/redis-rs/redis/tests/test_async.rs | 1132 +++++ .../redis/tests/test_async_async_std.rs | 328 ++ .../test_async_cluster_connections_logic.rs | 563 +++ glide-core/redis-rs/redis/tests/test_basic.rs | 1581 ++++++ .../redis-rs/redis/tests/test_bignum.rs | 61 + .../redis-rs/redis/tests/test_cluster.rs | 1093 +++++ .../redis/tests/test_cluster_async.rs | 4245 +++++++++++++++++ .../redis-rs/redis/tests/test_cluster_scan.rs | 849 ++++ .../redis-rs/redis/tests/test_geospatial.rs | 197 + .../redis-rs/redis/tests/test_module_json.rs | 540 +++ .../redis-rs/redis/tests/test_sentinel.rs | 496 ++ .../redis-rs/redis/tests/test_streams.rs | 627 +++ glide-core/redis-rs/redis/tests/test_types.rs | 606 +++ glide-core/redis-rs/release.sh | 15 + glide-core/redis-rs/rustfmt.toml | 2 + .../redis-rs/scripts/get_command_info.py | 227 + .../redis-rs/scripts/update-versions.sh | 20 + glide-core/redis-rs/upload-docs.sh | 26 + go/Cargo.toml | 2 +- go/DEVELOPER.md | 4 +- java/Cargo.toml | 2 +- node/DEVELOPER.md | 2 + node/rust-client/Cargo.toml | 2 +- python/Cargo.toml | 2 +- python/DEVELOPER.md | 175 +- python/python/tests/test_async_client.py | 1 + submodules/redis-rs | 1 - 117 files changed, 43641 insertions(+), 101 deletions(-) create mode 100644 Makefile create mode 100644 glide-core/redis-rs/Cargo.toml create mode 100644 glide-core/redis-rs/LICENSE create mode 100644 glide-core/redis-rs/Makefile create mode 100644 glide-core/redis-rs/README.md create mode 100644 glide-core/redis-rs/afl/.gitignore create mode 100644 glide-core/redis-rs/afl/parser/Cargo.toml create mode 100644 glide-core/redis-rs/afl/parser/in/array create mode 100644 glide-core/redis-rs/afl/parser/in/array-null create mode 100644 glide-core/redis-rs/afl/parser/in/bulkstring create mode 100644 glide-core/redis-rs/afl/parser/in/bulkstring-null create mode 100644 glide-core/redis-rs/afl/parser/in/error create mode 100644 glide-core/redis-rs/afl/parser/in/integer create mode 100644 glide-core/redis-rs/afl/parser/in/invalid-string create mode 100644 glide-core/redis-rs/afl/parser/in/string create mode 100644 glide-core/redis-rs/afl/parser/src/main.rs create mode 100644 glide-core/redis-rs/afl/parser/src/reproduce.rs create mode 100644 glide-core/redis-rs/appveyor.yml create mode 100644 glide-core/redis-rs/redis-test/CHANGELOG.md create mode 100644 glide-core/redis-rs/redis-test/Cargo.toml create mode 100644 glide-core/redis-rs/redis-test/LICENSE create mode 100644 glide-core/redis-rs/redis-test/README.md create mode 100644 glide-core/redis-rs/redis-test/release.toml create mode 100644 glide-core/redis-rs/redis-test/src/lib.rs create mode 100644 glide-core/redis-rs/redis/CHANGELOG.md create mode 100644 glide-core/redis-rs/redis/Cargo.toml create mode 100644 glide-core/redis-rs/redis/LICENSE create mode 100644 glide-core/redis-rs/redis/benches/bench_basic.rs create mode 100644 glide-core/redis-rs/redis/benches/bench_cluster.rs create mode 100644 glide-core/redis-rs/redis/benches/bench_cluster_async.rs create mode 100644 glide-core/redis-rs/redis/examples/async-await.rs create mode 100644 glide-core/redis-rs/redis/examples/async-connection-loss.rs create mode 100644 glide-core/redis-rs/redis/examples/async-multiplexed.rs create mode 100644 glide-core/redis-rs/redis/examples/async-pub-sub.rs create mode 100644 glide-core/redis-rs/redis/examples/async-scan.rs create mode 100644 glide-core/redis-rs/redis/examples/basic.rs create mode 100644 glide-core/redis-rs/redis/examples/geospatial.rs create mode 100644 glide-core/redis-rs/redis/examples/streams.rs create mode 100644 glide-core/redis-rs/redis/release.toml create mode 100644 glide-core/redis-rs/redis/src/acl.rs create mode 100644 glide-core/redis-rs/redis/src/aio/async_std.rs create mode 100644 glide-core/redis-rs/redis/src/aio/connection.rs create mode 100644 glide-core/redis-rs/redis/src/aio/connection_manager.rs create mode 100644 glide-core/redis-rs/redis/src/aio/mod.rs create mode 100644 glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs create mode 100644 glide-core/redis-rs/redis/src/aio/runtime.rs create mode 100644 glide-core/redis-rs/redis/src/aio/tokio.rs create mode 100644 glide-core/redis-rs/redis/src/client.rs create mode 100644 glide-core/redis-rs/redis/src/cluster.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_async/LICENSE create mode 100644 glide-core/redis-rs/redis/src/cluster_async/connections_container.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_async/mod.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_client.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_pipeline.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_routing.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_slotmap.rs create mode 100644 glide-core/redis-rs/redis/src/cluster_topology.rs create mode 100644 glide-core/redis-rs/redis/src/cmd.rs create mode 100644 glide-core/redis-rs/redis/src/commands/cluster_scan.rs create mode 100644 glide-core/redis-rs/redis/src/commands/json.rs create mode 100644 glide-core/redis-rs/redis/src/commands/macros.rs create mode 100644 glide-core/redis-rs/redis/src/commands/mod.rs create mode 100644 glide-core/redis-rs/redis/src/connection.rs create mode 100644 glide-core/redis-rs/redis/src/geo.rs create mode 100644 glide-core/redis-rs/redis/src/lib.rs create mode 100644 glide-core/redis-rs/redis/src/macros.rs create mode 100644 glide-core/redis-rs/redis/src/parser.rs create mode 100644 glide-core/redis-rs/redis/src/pipeline.rs create mode 100644 glide-core/redis-rs/redis/src/push_manager.rs create mode 100644 glide-core/redis-rs/redis/src/r2d2.rs create mode 100644 glide-core/redis-rs/redis/src/script.rs create mode 100644 glide-core/redis-rs/redis/src/sentinel.rs create mode 100644 glide-core/redis-rs/redis/src/streams.rs create mode 100644 glide-core/redis-rs/redis/src/tls.rs create mode 100644 glide-core/redis-rs/redis/src/types.rs create mode 100644 glide-core/redis-rs/redis/tests/parser.rs create mode 100644 glide-core/redis-rs/redis/tests/support/cluster.rs create mode 100644 glide-core/redis-rs/redis/tests/support/mock_cluster.rs create mode 100644 glide-core/redis-rs/redis/tests/support/mod.rs create mode 100644 glide-core/redis-rs/redis/tests/support/sentinel.rs create mode 100644 glide-core/redis-rs/redis/tests/support/util.rs create mode 100644 glide-core/redis-rs/redis/tests/test_acl.rs create mode 100644 glide-core/redis-rs/redis/tests/test_async.rs create mode 100644 glide-core/redis-rs/redis/tests/test_async_async_std.rs create mode 100644 glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs create mode 100644 glide-core/redis-rs/redis/tests/test_basic.rs create mode 100644 glide-core/redis-rs/redis/tests/test_bignum.rs create mode 100644 glide-core/redis-rs/redis/tests/test_cluster.rs create mode 100644 glide-core/redis-rs/redis/tests/test_cluster_async.rs create mode 100644 glide-core/redis-rs/redis/tests/test_cluster_scan.rs create mode 100644 glide-core/redis-rs/redis/tests/test_geospatial.rs create mode 100644 glide-core/redis-rs/redis/tests/test_module_json.rs create mode 100644 glide-core/redis-rs/redis/tests/test_sentinel.rs create mode 100644 glide-core/redis-rs/redis/tests/test_streams.rs create mode 100644 glide-core/redis-rs/redis/tests/test_types.rs create mode 100755 glide-core/redis-rs/release.sh create mode 100644 glide-core/redis-rs/rustfmt.toml create mode 100644 glide-core/redis-rs/scripts/get_command_info.py create mode 100755 glide-core/redis-rs/scripts/update-versions.sh create mode 100755 glide-core/redis-rs/upload-docs.sh delete mode 160000 submodules/redis-rs diff --git a/.github/workflows/python.yml b/.github/workflows/python.yml index c85045df07..45d2c0cf0d 100644 --- a/.github/workflows/python.yml +++ b/.github/workflows/python.yml @@ -216,9 +216,9 @@ jobs: - name: Install dependencies if: always() - uses: threeal/pipx-install-action@latest - with: - packages: flake8 isort black + working-directory: ./python + run: | + sudo apt install -y python3-pip python3 flake8 isort black - name: Lint python with isort if: always() diff --git a/.github/workflows/semgrep.yml b/.github/workflows/semgrep.yml index 6e4235abdb..4bfd9e12ac 100644 --- a/.github/workflows/semgrep.yml +++ b/.github/workflows/semgrep.yml @@ -33,4 +33,4 @@ jobs: # Fetch project source with GitHub Actions Checkout. - uses: actions/checkout@v3 # Run the "semgrep ci" command on the command line of the docker image. - - run: semgrep ci --config auto --no-suppress-errors + - run: semgrep ci --config auto --no-suppress-errors --exclude-rule generic.secrets.security.detected-private-key.detected-private-key diff --git a/.gitignore b/.gitignore index 573bfc218d..6799f31ea6 100644 --- a/.gitignore +++ b/.gitignore @@ -43,6 +43,8 @@ logger-rs.linux-x64-gnu.node utils/clusters/ utils/tls_crts/ utils/TestUtils.js +.build/ +.project # OSS Review Toolkit (ORT) files **/ort*/** diff --git a/.gitmodules b/.gitmodules index 87a3d9b855..e69de29bb2 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,3 +0,0 @@ -[submodule "submodules/redis-rs"] - path = submodules/redis-rs - url = https://github.com/amazon-contributing/redis-rs diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..438cd4d08d --- /dev/null +++ b/Makefile @@ -0,0 +1,76 @@ +.PHONY: all java java-test python python-test node node-test check-redis-server go go-test + +BLUE=\033[34m +YELLOW=\033[33m +GREEN=\033[32m +RESET=\033[0m +ROOT_DIR=$(shell pwd) +PYENV_DIR=$(shell pwd)/python/.env +PY_PATH=$(shell find python/.env -name "site-packages"|xargs readlink -f) +PY_GLIDE_PATH=$(shell pwd)/python/python/ + +all: java java-test python python-test node node-test go go-test + +java: + @echo "$(GREEN)Building for Java (release)$(RESET)" + @cd java && ./gradlew :client:buildAllRelease + +java-test: check-redis-server + @echo "$(GREEN)Running spotlessCheck$(RESET)" + @cd java && ./gradlew :spotlessCheck + @echo "$(GREEN)Running spotlessApply$(RESET)" + @cd java && ./gradlew :spotlessApply + @echo "$(GREEN)Running integration tests$(RESET)" + @cd java && ./gradlew :integTest:test + +python: .build/python_deps + @echo "$(GREEN)Building for Python (release)$(RESET)" + @cd python && VIRTUAL_ENV=$(PYENV_DIR) .env/bin/maturin develop --release --strip + +# Python dependencies +.build/python_deps: + @echo "$(GREEN)Generating protobuf files...$(RESET)" + @protoc -Iprotobuf=$(ROOT_DIR)/glide-core/src/protobuf/ \ + --python_out=$(ROOT_DIR)/python/python/glide $(ROOT_DIR)/glide-core/src/protobuf/*.proto + @echo "$(GREEN)Building environment...$(RESET)" + @cd python && python3 -m venv .env + @echo "$(GREEN)Installing requirements...$(RESET)" + @cd python && .env/bin/pip install -r requirements.txt + @mkdir -p .build/ && touch .build/python_deps + +python-test: check-redis-server + cd python && PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH) .env/bin/pytest --asyncio-mode=auto + +node: .build/node_deps + @echo "$(GREEN)Building for NodeJS (release)...$(RESET)" + @cd node && npm run build:release + +# NodeJS dependencies +.build/node_deps: + @echo "$(GREEN)Installing NodeJS dependencies...$(RESET)" + @cd node && npm i + @cd node/rust-client && npm i + @mkdir -p .build/ && touch .build/node_deps + +node-test: check-redis-server + @echo "$(GREEN)Running tests for NodeJS$(RESET)" + @cd node && npm run build + cd node && npm test + +# Check for the existence of redis-server by simply calling which shell command +check-redis-server: + which redis-server + +go: .build/go_deps + $(MAKE) -C go build + +go-test: + $(MAKE) -C go test + +.build/go_deps: + @echo "$(GREEN)Installing GO dependencies...$(RESET)" + $(MAKE) -C go install-build-tools + @mkdir -p .build/ && touch .build/go_deps + +clean: + rm -fr .build/ diff --git a/benchmarks/rust/Cargo.toml b/benchmarks/rust/Cargo.toml index 6f0849d505..d63bc98e57 100644 --- a/benchmarks/rust/Cargo.toml +++ b/benchmarks/rust/Cargo.toml @@ -11,7 +11,7 @@ authors = ["Valkey GLIDE Maintainers"] tokio = { version = "1", features = ["macros", "time", "rt-multi-thread"] } glide-core = { path = "../../glide-core" } logger_core = {path = "../../logger_core"} -redis = { path = "../../submodules/redis-rs/redis", features = ["aio"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio"] } futures = "0.3.28" rand = "0.8.5" itoa = "1.0.6" diff --git a/csharp/lib/Cargo.toml b/csharp/lib/Cargo.toml index 95981480b2..b49e098bf7 100644 --- a/csharp/lib/Cargo.toml +++ b/csharp/lib/Cargo.toml @@ -12,7 +12,7 @@ name = "glide_rs" crate-type = ["cdylib"] [dependencies] -redis = { path = "../../submodules/redis-rs/redis", features = ["aio", "tokio-comp","tokio-native-tls-comp"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio", "tokio-comp","tokio-native-tls-comp"] } glide-core = { path = "../../glide-core" } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } logger_core = {path = "../../logger_core"} diff --git a/glide-core/Cargo.toml b/glide-core/Cargo.toml index e0a1b05368..51de808bd2 100644 --- a/glide-core/Cargo.toml +++ b/glide-core/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Valkey GLIDE Maintainers"] [dependencies] bytes = "1" futures = "^0.3" -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp", "connection-manager","cluster", "cluster-async"] } +redis = { path = "./redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp", "connection-manager","cluster", "cluster-async"] } tokio = { version = "1", features = ["macros", "time"] } logger_core = {path = "../logger_core"} dispose = "0.5.0" @@ -42,7 +42,7 @@ serial_test = "3" criterion = { version = "^0.5", features = ["html_reports", "async_tokio"] } which = "5" ctor = "0.2.2" -redis = { path = "../submodules/redis-rs/redis", features = ["tls-rustls-insecure"] } +redis = { path = "./redis-rs/redis", features = ["tls-rustls-insecure"] } iai-callgrind = "0.9" tokio = { version = "1", features = ["rt-multi-thread"] } glide-core = { path = ".", features = ["socket-layer"] } # always enable this feature in tests. diff --git a/glide-core/redis-rs/Cargo.toml b/glide-core/redis-rs/Cargo.toml new file mode 100644 index 0000000000..2f4ebbcbbe --- /dev/null +++ b/glide-core/redis-rs/Cargo.toml @@ -0,0 +1,3 @@ +[workspace] +members = ["redis", "redis-test"] +resolver = "2" diff --git a/glide-core/redis-rs/LICENSE b/glide-core/redis-rs/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * The names of the contributors may not be used to endorse or + promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/glide-core/redis-rs/Makefile b/glide-core/redis-rs/Makefile new file mode 100644 index 0000000000..9f177a6c22 --- /dev/null +++ b/glide-core/redis-rs/Makefile @@ -0,0 +1,96 @@ +build: + @cargo build + +test: + @echo "====================================================================" + @echo "Build all features with lock file" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" cargo build --locked --all-features + + @echo "====================================================================" + @echo "Testing Connection Type TCP without features" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --no-default-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and RESP2" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --all-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and RESP3" + @echo "====================================================================" + @REDISRS_SERVER_TYPE=tcp PROTOCOL=RESP3 cargo test -p redis --all-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and Rustls support" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp+tls RUST_BACKTRACE=1 cargo test --locked -p redis --all-features -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type TCP with all features and native-TLS support" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp+tls RUST_BACKTRACE=1 cargo test --locked -p redis --features=json,tokio-native-tls-comp,connection-manager,cluster-async -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type UNIX" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=unix RUST_BACKTRACE=1 cargo test --locked -p redis --test parser --test test_basic --test test_types --all-features -- --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing Connection Type UNIX SOCKETS" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=unix RUST_BACKTRACE=1 cargo test --locked -p redis --all-features -- --test-threads=1 --skip test_cluster --skip test_async_cluster --skip test_module --skip test_cluster_scan + + @echo "====================================================================" + @echo "Testing async-std with Rustls" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --features=async-std-rustls-comp,cluster-async -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing async-std with native-TLS" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked -p redis --features=async-std-native-tls-comp,cluster-async -- --nocapture --test-threads=1 --skip test_module + + @echo "====================================================================" + @echo "Testing redis-test" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" RUST_BACKTRACE=1 cargo test --locked -p redis-test + + +test-module: + @echo "====================================================================" + @echo "Testing RESP2 with module support enabled (currently only RedisJSON)" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 cargo test --locked --all-features test_module -- --test-threads=1 + + @echo "====================================================================" + @echo "Testing RESP3 with module support enabled (currently only RedisJSON)" + @echo "====================================================================" + @RUSTFLAGS="-D warnings" REDISRS_SERVER_TYPE=tcp RUST_BACKTRACE=1 RESP3=true cargo test --all-features test_module -- --test-threads=1 + +test-single: test + +bench: + cargo bench --all-features + +docs: + @RUSTFLAGS="-D warnings" RUSTDOCFLAGS="--cfg docsrs" cargo +nightly doc --all-features --no-deps + +upload-docs: docs + @./upload-docs.sh + +style-check: + @rustup component add rustfmt 2> /dev/null + cargo fmt --all -- --check + +lint: + @rustup component add clippy 2> /dev/null + cargo clippy --all-features --all --tests --examples -- -D clippy::all -D warnings + +fuzz: + cd afl/parser/ && \ + cargo afl build --bin fuzz-target && \ + cargo afl fuzz -i in -o out target/debug/fuzz-target + +.PHONY: build test bench docs upload-docs style-check lint fuzz diff --git a/glide-core/redis-rs/README.md b/glide-core/redis-rs/README.md new file mode 100644 index 0000000000..34cdfe4778 --- /dev/null +++ b/glide-core/redis-rs/README.md @@ -0,0 +1,233 @@ +# redis-rs + +[![Rust](https://github.com/redis-rs/redis-rs/actions/workflows/rust.yml/badge.svg)](https://github.com/redis-rs/redis-rs/actions/workflows/rust.yml) +[![crates.io](https://img.shields.io/crates/v/redis.svg)](https://crates.io/crates/redis) +[![Chat](https://img.shields.io/discord/976380008299917365?logo=discord)](https://discord.gg/WHKcJK9AKP) + +Redis-rs is a high level redis library for Rust. It provides convenient access +to all Redis functionality through a very flexible but low-level API. It +uses a customizable type conversion trait so that any operation can return +results in just the type you are expecting. This makes for a very pleasant +development experience. + +The crate is called `redis` and you can depend on it via cargo: + +```ini +[dependencies] +redis = "0.25.2" +``` + +Documentation on the library can be found at +[docs.rs/redis](https://docs.rs/redis). + +**Note: redis-rs requires at least Rust 1.60.** + +## Basic Operation + +To open a connection you need to create a client and then to fetch a +connection from it. In the future there will be a connection pool for +those, currently each connection is separate and not pooled. + +Many commands are implemented through the `Commands` trait but manual +command creation is also possible. + +```rust +use redis::Commands; + +fn fetch_an_integer() -> redis::RedisResult { + // connect to redis + let client = redis::Client::open("redis://127.0.0.1/")?; + let mut con = client.get_connection(None)?; + // throw away the result, just make sure it does not fail + let _ : () = con.set("my_key", 42)?; + // read back the key and return it. Because the return value + // from the function is a result for integer this will automatically + // convert into one. + con.get("my_key") +} +``` + +Variables are converted to and from the Redis format for a wide variety of types +(`String`, num types, tuples, `Vec`). If you want to use it with your own types, +you can implement the `FromRedisValue` and `ToRedisArgs` traits, or derive it with the +[redis-macros](https://github.com/daniel7grant/redis-macros/#json-wrapper-with-redisjson) crate. + +## Async support + +To enable asynchronous clients, enable the relevant feature in your Cargo.toml, +`tokio-comp` for tokio users or `async-std-comp` for async-std users. + +``` +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-comp"] } +``` + +## TLS Support + +To enable TLS support, you need to use the relevant feature entry in your Cargo.toml. +Currently, `native-tls` and `rustls` are supported. + +To use `native-tls`: + +``` +redis = { version = "0.25.2", features = ["tls-native-tls"] } + +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-native-tls-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-native-tls-comp"] } +``` + +To use `rustls`: + +``` +redis = { version = "0.25.2", features = ["tls-rustls"] } + +# if you use tokio +redis = { version = "0.25.2", features = ["tokio-rustls-comp"] } + +# if you use async-std +redis = { version = "0.25.2", features = ["async-std-rustls-comp"] } +``` + +With `rustls`, you can add the following feature flags on top of other feature flags to enable additional features: + +- `tls-rustls-insecure`: Allow insecure TLS connections +- `tls-rustls-webpki-roots`: Use `webpki-roots` (Mozilla's root certificates) instead of native root certificates + +then you should be able to connect to a redis instance using the `rediss://` URL scheme: + +```rust +let client = redis::Client::open("rediss://127.0.0.1/")?; +``` + +To enable insecure mode, append `#insecure` at the end of the URL: + +```rust +let client = redis::Client::open("rediss://127.0.0.1/#insecure")?; +``` + +**Deprecation Notice:** If you were using the `tls` or `async-std-tls-comp` features, please use the `tls-native-tls` or `async-std-native-tls-comp` features respectively. + +## Cluster Support + +Support for Redis Cluster can be enabled by enabling the `cluster` feature in your Cargo.toml: + +`redis = { version = "0.25.2", features = [ "cluster"] }` + +Then you can simply use the `ClusterClient`, which accepts a list of available nodes. Note +that only one node in the cluster needs to be specified when instantiating the client, though +you can specify multiple. + +```rust +use redis::cluster::ClusterClient; +use redis::Commands; + +fn fetch_an_integer() -> String { + let nodes = vec!["redis://127.0.0.1/"]; + let client = ClusterClient::new(nodes).unwrap(); + let mut connection = client.get_connection(None).unwrap(); + let _: () = connection.set("test", "test_data").unwrap(); + let rv: String = connection.get("test").unwrap(); + return rv; +} +``` + +Async Redis Cluster support can be enabled by enabling the `cluster-async` feature, along +with your preferred async runtime, e.g.: + +`redis = { version = "0.25.2", features = [ "cluster-async", "tokio-std-comp" ] }` + +```rust +use redis::cluster::ClusterClient; +use redis::AsyncCommands; + +async fn fetch_an_integer() -> String { + let nodes = vec!["redis://127.0.0.1/"]; + let client = ClusterClient::new(nodes).unwrap(); + let mut connection = client.get_async_connection().await.unwrap(); + let _: () = connection.set("test", "test_data").await.unwrap(); + let rv: String = connection.get("test").await.unwrap(); + return rv; +} +``` + +## JSON Support + +Support for the RedisJSON Module can be enabled by specifying "json" as a feature in your Cargo.toml. + +`redis = { version = "0.25.2", features = ["json"] }` + +Then you can simply import the `JsonCommands` trait which will add the `json` commands to all Redis Connections (not to be confused with just `Commands` which only adds the default commands) + +```rust +use redis::Client; +use redis::JsonCommands; +use redis::RedisResult; +use redis::ToRedisArgs; + +// Result returns Ok(true) if the value was set +// Result returns Err(e) if there was an error with the server itself OR serde_json was unable to serialize the boolean +fn set_json_bool(key: P, path: P, b: bool) -> RedisResult { + let client = Client::open("redis://127.0.0.1").unwrap(); + let connection = client.get_connection(None).unwrap(); + + // runs `JSON.SET {key} {path} {b}` + connection.json_set(key, path, b)? +} + +``` + +To parse the results, you'll need to use `serde_json` (or some other json lib) to deserialize +the results from the bytes. It will always be a `Vec`, if no results were found at the path it'll +be an empty `Vec`. If you want to handle deserialization and `Vec` unwrapping automatically, +you can use the `Json` wrapper from the +[redis-macros](https://github.com/daniel7grant/redis-macros/#json-wrapper-with-redisjson) crate. + +## Development + +To test `redis` you're going to need to be able to test with the Redis Modules, to do this +you must set the following environment variable before running the test script + +- `REDIS_RS_REDIS_JSON_PATH` = The absolute path to the RedisJSON module (Either `librejson.so` for Linux or `librejson.dylib` for MacOS). + +- Please refer to this [link](https://github.com/RedisJSON/RedisJSON) to access the RedisJSON module: + + + +If you want to develop on the library there are a few commands provided +by the makefile: + +To build: + + $ make + +To test: + + $ make test + +To run benchmarks: + + $ make bench + +To build the docs (require nightly compiler, see [rust-lang/rust#43781](https://github.com/rust-lang/rust/issues/43781)): + + $ make docs + +We encourage you to run `clippy` prior to seeking a merge for your work. The lints can be quite strict. Running this on your own workstation can save you time, since Travis CI will fail any build that doesn't satisfy `clippy`: + + $ cargo clippy --all-features --all --tests --examples -- -D clippy::all -D warnings + +To run fuzz tests with afl, first install cargo-afl (`cargo install -f afl`), +then run: + + $ make fuzz + +If the fuzzer finds a crash, in order to reproduce it, run: + + $ cd afl// + $ cargo run --bin reproduce -- out/crashes/ diff --git a/glide-core/redis-rs/afl/.gitignore b/glide-core/redis-rs/afl/.gitignore new file mode 100644 index 0000000000..1776e13233 --- /dev/null +++ b/glide-core/redis-rs/afl/.gitignore @@ -0,0 +1,2 @@ +out/ +core.* diff --git a/glide-core/redis-rs/afl/parser/Cargo.toml b/glide-core/redis-rs/afl/parser/Cargo.toml new file mode 100644 index 0000000000..9f5202d86a --- /dev/null +++ b/glide-core/redis-rs/afl/parser/Cargo.toml @@ -0,0 +1,17 @@ +[package] +name = "fuzz-target-parser" +version = "0.1.0" +authors = ["redis-rs developers"] +edition = "2018" + +[[bin]] +name = "fuzz-target" +path = "src/main.rs" + +[[bin]] +name = "reproduce" +path = "src/reproduce.rs" + +[dependencies] +afl = "0.4" +redis = { path = "../../redis" } diff --git a/glide-core/redis-rs/afl/parser/in/array b/glide-core/redis-rs/afl/parser/in/array new file mode 100644 index 0000000000..c92e405790 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/array @@ -0,0 +1,5 @@ +*3 +:1 +$-1 +$2 +hi diff --git a/glide-core/redis-rs/afl/parser/in/array-null b/glide-core/redis-rs/afl/parser/in/array-null new file mode 100644 index 0000000000..e0f619c1b3 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/array-null @@ -0,0 +1 @@ +*-1 diff --git a/glide-core/redis-rs/afl/parser/in/bulkstring b/glide-core/redis-rs/afl/parser/in/bulkstring new file mode 100644 index 0000000000..930878abea --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/bulkstring @@ -0,0 +1,2 @@ +$6 +foobar diff --git a/glide-core/redis-rs/afl/parser/in/bulkstring-null b/glide-core/redis-rs/afl/parser/in/bulkstring-null new file mode 100644 index 0000000000..f4280bede5 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/bulkstring-null @@ -0,0 +1 @@ +$-1 diff --git a/glide-core/redis-rs/afl/parser/in/error b/glide-core/redis-rs/afl/parser/in/error new file mode 100644 index 0000000000..7cfd9a521a --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/error @@ -0,0 +1 @@ +-ERR unknown command diff --git a/glide-core/redis-rs/afl/parser/in/integer b/glide-core/redis-rs/afl/parser/in/integer new file mode 100644 index 0000000000..49525f0d45 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/integer @@ -0,0 +1 @@ +:1337 diff --git a/glide-core/redis-rs/afl/parser/in/invalid-string b/glide-core/redis-rs/afl/parser/in/invalid-string new file mode 100644 index 0000000000..604dd3e85a --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/invalid-string @@ -0,0 +1,2 @@ +$6 +foo diff --git a/glide-core/redis-rs/afl/parser/in/string b/glide-core/redis-rs/afl/parser/in/string new file mode 100644 index 0000000000..054430c700 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/in/string @@ -0,0 +1 @@ ++OK diff --git a/glide-core/redis-rs/afl/parser/src/main.rs b/glide-core/redis-rs/afl/parser/src/main.rs new file mode 100644 index 0000000000..6dc674edff --- /dev/null +++ b/glide-core/redis-rs/afl/parser/src/main.rs @@ -0,0 +1,9 @@ +use afl::fuzz; + +use redis::parse_redis_value; + +fn main() { + fuzz!(|data: &[u8]| { + let _ = parse_redis_value(data); + }); +} diff --git a/glide-core/redis-rs/afl/parser/src/reproduce.rs b/glide-core/redis-rs/afl/parser/src/reproduce.rs new file mode 100644 index 0000000000..086dfffb50 --- /dev/null +++ b/glide-core/redis-rs/afl/parser/src/reproduce.rs @@ -0,0 +1,13 @@ +use redis::parse_redis_value; + +fn main() { + let args: Vec = std::env::args().collect(); + if args.len() != 2 { + println!("Usage: {} ", args[0]); + std::process::exit(1); + } + + let data = std::fs::read(&args[1]).expect(&format!("Could not open file {}", args[1])); + let v = parse_redis_value(&data); + println!("Result: {:?}", v); +} diff --git a/glide-core/redis-rs/appveyor.yml b/glide-core/redis-rs/appveyor.yml new file mode 100644 index 0000000000..8310b8def5 --- /dev/null +++ b/glide-core/redis-rs/appveyor.yml @@ -0,0 +1,23 @@ +os: Visual Studio 2015 + +environment: + REDISRS_SERVER_TYPE: tcp + RUST_BACKTRACE: 1 + matrix: + - channel: stable + target: x86_64-pc-windows-msvc + - channel: stable + target: x86_64-pc-windows-gnu +install: + - appveyor DownloadFile https://win.rustup.rs/ -FileName rustup-init.exe + - rustup-init -yv --default-toolchain %channel% --default-host %target% + - set PATH=%PATH%;%USERPROFILE%\.cargo\bin + - rustc -vV + - cargo -vV + - cmd: nuget install redis-64 -excludeversion + - set PATH=%PATH%;%APPVEYOR_BUILD_FOLDER%\redis-64\tools\ + +build: false + +test_script: + - cargo test --verbose --no-default-features --features tokio-comp %cargoflags% diff --git a/glide-core/redis-rs/redis-test/CHANGELOG.md b/glide-core/redis-rs/redis-test/CHANGELOG.md new file mode 100644 index 0000000000..83d3ab3dc4 --- /dev/null +++ b/glide-core/redis-rs/redis-test/CHANGELOG.md @@ -0,0 +1,44 @@ +### 0.4.0 (2023-03-08) +* Track redis 0.25.0 release + +### 0.3.0 (2023-12-05) +* Track redis 0.24.0 release + +### 0.2.3 (2023-09-01) + +* Track redis 0.23.3 release + +### 0.2.2 (2023-08-10) + +* Track redis 0.23.2 release + +### 0.2.1 (2023-07-28) + +* Track redis 0.23.1 release + + +### 0.2.0 (2023-04-05) + +* Track redis 0.23.0 release + + +### 0.2.0-beta.1 (2023-03-28) + +* Track redis 0.23.0-beta.1 release + + +### 0.1.1 (2022-10-18) + +#### Changes +* Add README +* Update LICENSE file / symlink from parent directory + + + +### 0.1.0 (2022-10-05) + +This is the initial release of the redis-test crate, which aims to provide mocking +for connections and commands. Thanks @tdyas! + +#### Features +* Testing module with support for mocking redis connections and commands ([#465](https://github.com/redis-rs/redis-rs/pull/465) @tdyas) \ No newline at end of file diff --git a/glide-core/redis-rs/redis-test/Cargo.toml b/glide-core/redis-rs/redis-test/Cargo.toml new file mode 100644 index 0000000000..6e0bcc3a9f --- /dev/null +++ b/glide-core/redis-rs/redis-test/Cargo.toml @@ -0,0 +1,26 @@ +[package] +name = "redis-test" +version = "0.4.0" +edition = "2021" +description = "Testing helpers for the `redis` crate" +homepage = "https://github.com/redis-rs/redis-rs" +repository = "https://github.com/redis-rs/redis-rs" +documentation = "https://docs.rs/redis-test" +license = "BSD-3-Clause" +rust-version = "1.65" + +[lib] +bench = false + +[dependencies] +redis = { version = "0.25.0", path = "../redis" } + +bytes = { version = "1", optional = true } +futures = { version = "0.3", optional = true } + +[features] +aio = ["futures", "redis/aio"] + +[dev-dependencies] +redis = { version = "0.25.0", path = "../redis", features = ["aio", "tokio-comp"] } +tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } diff --git a/glide-core/redis-rs/redis-test/LICENSE b/glide-core/redis-rs/redis-test/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/redis-test/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * The names of the contributors may not be used to endorse or + promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/glide-core/redis-rs/redis-test/README.md b/glide-core/redis-rs/redis-test/README.md new file mode 100644 index 0000000000..b89bfc4edb --- /dev/null +++ b/glide-core/redis-rs/redis-test/README.md @@ -0,0 +1,4 @@ +# redis-test + +Testing utilities for the redis-rs crate. + diff --git a/glide-core/redis-rs/redis-test/release.toml b/glide-core/redis-rs/redis-test/release.toml new file mode 100644 index 0000000000..7dc5b7a0a6 --- /dev/null +++ b/glide-core/redis-rs/redis-test/release.toml @@ -0,0 +1 @@ +tag-name = "redis-test-{{version}}" diff --git a/glide-core/redis-rs/redis-test/src/lib.rs b/glide-core/redis-rs/redis-test/src/lib.rs new file mode 100644 index 0000000000..cafe8a347b --- /dev/null +++ b/glide-core/redis-rs/redis-test/src/lib.rs @@ -0,0 +1,426 @@ +//! Testing support +//! +//! This module provides `MockRedisConnection` which implements ConnectionLike and can be +//! used in the same place as any other type that behaves like a Redis connection. This is useful +//! for writing unit tests without needing a Redis server. +//! +//! # Example +//! +//! ```rust +//! use redis::{ConnectionLike, RedisError}; +//! use redis_test::{MockCmd, MockRedisConnection}; +//! +//! fn my_exists(conn: &mut C, key: &str) -> Result { +//! let exists: bool = redis::cmd("EXISTS").arg(key).query(conn)?; +//! Ok(exists) +//! } +//! +//! let mut mock_connection = MockRedisConnection::new(vec![ +//! MockCmd::new(redis::cmd("EXISTS").arg("foo"), Ok("1")), +//! ]); +//! +//! let result = my_exists(&mut mock_connection, "foo").unwrap(); +//! assert_eq!(result, true); +//! ``` + +use std::collections::VecDeque; +use std::sync::{Arc, Mutex}; + +use redis::{Cmd, ConnectionLike, ErrorKind, Pipeline, RedisError, RedisResult, Value}; + +#[cfg(feature = "aio")] +use futures::{future, FutureExt}; + +#[cfg(feature = "aio")] +use redis::{aio::ConnectionLike as AioConnectionLike, RedisFuture}; + +/// Helper trait for converting test values into a `redis::Value` returned from a +/// `MockRedisConnection`. This is necessary because neither `redis::types::ToRedisArgs` +/// nor `redis::types::FromRedisValue` performs the precise conversion needed. +pub trait IntoRedisValue { + /// Convert a value into `redis::Value`. + fn into_redis_value(self) -> Value; +} + +impl IntoRedisValue for String { + fn into_redis_value(self) -> Value { + Value::BulkString(self.as_bytes().to_vec()) + } +} + +impl IntoRedisValue for &str { + fn into_redis_value(self) -> Value { + Value::BulkString(self.as_bytes().to_vec()) + } +} + +#[cfg(feature = "bytes")] +impl IntoRedisValue for bytes::Bytes { + fn into_redis_value(self) -> Value { + Value::BulkString(self.to_vec()) + } +} + +impl IntoRedisValue for Vec { + fn into_redis_value(self) -> Value { + Value::BulkString(self) + } +} + +impl IntoRedisValue for Value { + fn into_redis_value(self) -> Value { + self + } +} + +impl IntoRedisValue for i64 { + fn into_redis_value(self) -> Value { + Value::Int(self) + } +} + +/// Helper trait for converting `redis::Cmd` and `redis::Pipeline` instances into +/// encoded byte vectors. +pub trait IntoRedisCmdBytes { + /// Convert a command into an encoded byte vector. + fn into_redis_cmd_bytes(self) -> Vec; +} + +impl IntoRedisCmdBytes for Cmd { + fn into_redis_cmd_bytes(self) -> Vec { + self.get_packed_command() + } +} + +impl IntoRedisCmdBytes for &Cmd { + fn into_redis_cmd_bytes(self) -> Vec { + self.get_packed_command() + } +} + +impl IntoRedisCmdBytes for &mut Cmd { + fn into_redis_cmd_bytes(self) -> Vec { + self.get_packed_command() + } +} + +impl IntoRedisCmdBytes for Pipeline { + fn into_redis_cmd_bytes(self) -> Vec { + self.get_packed_pipeline() + } +} + +impl IntoRedisCmdBytes for &Pipeline { + fn into_redis_cmd_bytes(self) -> Vec { + self.get_packed_pipeline() + } +} + +impl IntoRedisCmdBytes for &mut Pipeline { + fn into_redis_cmd_bytes(self) -> Vec { + self.get_packed_pipeline() + } +} + +/// Represents a command to be executed against a `MockConnection`. +pub struct MockCmd { + cmd_bytes: Vec, + responses: Result, RedisError>, +} + +impl MockCmd { + /// Create a new `MockCmd` given a Redis command and either a value convertible to + /// a `redis::Value` or a `RedisError`. + pub fn new(cmd: C, response: Result) -> Self + where + C: IntoRedisCmdBytes, + V: IntoRedisValue, + { + MockCmd { + cmd_bytes: cmd.into_redis_cmd_bytes(), + responses: response.map(|r| vec![r.into_redis_value()]), + } + } + + /// Create a new `MockCommand` given a Redis command/pipeline and a vector of value convertible + /// to a `redis::Value` or a `RedisError`. + pub fn with_values(cmd: C, responses: Result, RedisError>) -> Self + where + C: IntoRedisCmdBytes, + V: IntoRedisValue, + { + MockCmd { + cmd_bytes: cmd.into_redis_cmd_bytes(), + responses: responses.map(|xs| xs.into_iter().map(|x| x.into_redis_value()).collect()), + } + } +} + +/// A mock Redis client for testing without a server. `MockRedisConnection` checks whether the +/// client submits a specific sequence of commands and generates an error if it does not. +#[derive(Clone)] +pub struct MockRedisConnection { + commands: Arc>>, +} + +impl MockRedisConnection { + /// Construct a new from the given sequence of commands. + pub fn new(commands: I) -> Self + where + I: IntoIterator, + { + MockRedisConnection { + commands: Arc::new(Mutex::new(VecDeque::from_iter(commands))), + } + } +} + +impl ConnectionLike for MockRedisConnection { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + let mut commands = self.commands.lock().unwrap(); + let next_cmd = commands.pop_front().ok_or_else(|| { + RedisError::from(( + ErrorKind::ClientError, + "TEST", + "unexpected command".to_owned(), + )) + })?; + + if cmd != next_cmd.cmd_bytes { + return Err(RedisError::from(( + ErrorKind::ClientError, + "TEST", + format!( + "unexpected command: expected={}, actual={}", + String::from_utf8(next_cmd.cmd_bytes) + .unwrap_or_else(|_| "decode error".to_owned()), + String::from_utf8(Vec::from(cmd)).unwrap_or_else(|_| "decode error".to_owned()), + ), + ))); + } + + next_cmd + .responses + .and_then(|values| match values.as_slice() { + [value] => Ok(value.clone()), + [] => Err(RedisError::from(( + ErrorKind::ClientError, + "no value configured as response", + ))), + _ => Err(RedisError::from(( + ErrorKind::ClientError, + "multiple values configured as response for command expecting a single value", + ))), + }) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + _offset: usize, + _count: usize, + ) -> RedisResult> { + let mut commands = self.commands.lock().unwrap(); + let next_cmd = commands.pop_front().ok_or_else(|| { + RedisError::from(( + ErrorKind::ClientError, + "TEST", + "unexpected command".to_owned(), + )) + })?; + + if cmd != next_cmd.cmd_bytes { + return Err(RedisError::from(( + ErrorKind::ClientError, + "TEST", + format!( + "unexpected command: expected={}, actual={}", + String::from_utf8(next_cmd.cmd_bytes) + .unwrap_or_else(|_| "decode error".to_owned()), + String::from_utf8(Vec::from(cmd)).unwrap_or_else(|_| "decode error".to_owned()), + ), + ))); + } + + next_cmd.responses + } + + fn get_db(&self) -> i64 { + 0 + } + + fn check_connection(&mut self) -> bool { + true + } + + fn is_open(&self) -> bool { + true + } +} + +#[cfg(feature = "aio")] +impl AioConnectionLike for MockRedisConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + let packed_cmd = cmd.get_packed_command(); + let response = ::req_packed_command( + self, + packed_cmd.as_slice(), + ); + future::ready(response).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + let packed_cmd = cmd.get_packed_pipeline(); + let response = ::req_packed_commands( + self, + packed_cmd.as_slice(), + offset, + count, + ); + future::ready(response).boxed() + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_closed(&self) -> bool { + false + } +} + +#[cfg(test)] +mod tests { + use super::{MockCmd, MockRedisConnection}; + use redis::{cmd, pipe, ErrorKind, Value}; + + #[test] + fn sync_basic_test() { + let mut conn = MockRedisConnection::new(vec![ + MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")), + MockCmd::new(cmd("GET").arg("foo"), Ok(42)), + MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")), + MockCmd::new(cmd("GET").arg("bar"), Ok("foo")), + ]); + + cmd("SET").arg("foo").arg(42).execute(&mut conn); + assert_eq!(cmd("GET").arg("foo").query(&mut conn), Ok(42)); + + cmd("SET").arg("bar").arg("foo").execute(&mut conn); + assert_eq!( + cmd("GET").arg("bar").query(&mut conn), + Ok(Value::BulkString(b"foo".as_ref().into())) + ); + } + + #[cfg(feature = "aio")] + #[tokio::test] + async fn async_basic_test() { + let mut conn = MockRedisConnection::new(vec![ + MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")), + MockCmd::new(cmd("GET").arg("foo"), Ok(42)), + MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")), + MockCmd::new(cmd("GET").arg("bar"), Ok("foo")), + ]); + + cmd("SET") + .arg("foo") + .arg("42") + .query_async::<_, ()>(&mut conn) + .await + .unwrap(); + let result: Result = cmd("GET").arg("foo").query_async(&mut conn).await; + assert_eq!(result, Ok(42)); + + cmd("SET") + .arg("bar") + .arg("foo") + .query_async::<_, ()>(&mut conn) + .await + .unwrap(); + let result: Result, _> = cmd("GET").arg("bar").query_async(&mut conn).await; + assert_eq!(result.as_deref(), Ok(&b"foo"[..])); + } + + #[test] + fn errors_for_unexpected_commands() { + let mut conn = MockRedisConnection::new(vec![ + MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")), + MockCmd::new(cmd("GET").arg("foo"), Ok(42)), + ]); + + cmd("SET").arg("foo").arg(42).execute(&mut conn); + assert_eq!(cmd("GET").arg("foo").query(&mut conn), Ok(42)); + + let err = cmd("SET") + .arg("bar") + .arg("foo") + .query::<()>(&mut conn) + .unwrap_err(); + assert_eq!(err.kind(), ErrorKind::ClientError); + assert_eq!(err.detail(), Some("unexpected command")); + } + + #[test] + fn errors_for_mismatched_commands() { + let mut conn = MockRedisConnection::new(vec![ + MockCmd::new(cmd("SET").arg("foo").arg(42), Ok("")), + MockCmd::new(cmd("GET").arg("foo"), Ok(42)), + MockCmd::new(cmd("SET").arg("bar").arg("foo"), Ok("")), + ]); + + cmd("SET").arg("foo").arg(42).execute(&mut conn); + let err = cmd("SET") + .arg("bar") + .arg("foo") + .query::<()>(&mut conn) + .unwrap_err(); + assert_eq!(err.kind(), ErrorKind::ClientError); + assert!(err.detail().unwrap().contains("unexpected command")); + } + + #[test] + fn pipeline_basic_test() { + let mut conn = MockRedisConnection::new(vec![MockCmd::with_values( + pipe().cmd("GET").arg("foo").cmd("GET").arg("bar"), + Ok(vec!["hello", "world"]), + )]); + + let results: Vec = pipe() + .cmd("GET") + .arg("foo") + .cmd("GET") + .arg("bar") + .query(&mut conn) + .expect("success"); + assert_eq!(results, vec!["hello", "world"]); + } + + #[test] + fn pipeline_atomic_test() { + let mut conn = MockRedisConnection::new(vec![MockCmd::with_values( + pipe().atomic().cmd("GET").arg("foo").cmd("GET").arg("bar"), + Ok(vec![Value::Array( + vec!["hello", "world"] + .into_iter() + .map(|x| Value::BulkString(x.as_bytes().into())) + .collect(), + )]), + )]); + + let results: Vec = pipe() + .atomic() + .cmd("GET") + .arg("foo") + .cmd("GET") + .arg("bar") + .query(&mut conn) + .expect("success"); + assert_eq!(results, vec!["hello", "world"]); + } +} diff --git a/glide-core/redis-rs/redis/CHANGELOG.md b/glide-core/redis-rs/redis/CHANGELOG.md new file mode 100644 index 0000000000..9c3dd18524 --- /dev/null +++ b/glide-core/redis-rs/redis/CHANGELOG.md @@ -0,0 +1,828 @@ +### 0.25.2 (2024-03-15) + +* MultiplexedConnection: Separate response handling for pipeline. ([#1078](https://github.com/redis-rs/redis-rs/pull/1078)) + +### 0.25.1 (2024-03-12) + +* Fix small disambiguity in examples ([#1072](https://github.com/redis-rs/redis-rs/pull/1072) @sunhuachuang) +* Upgrade to socket2 0.5 ([#1073](https://github.com/redis-rs/redis-rs/pull/1073) @djc) +* Avoid library dependency on futures-time ([#1074](https://github.com/redis-rs/redis-rs/pull/1074) @djc) + + +### 0.25.0 (2024-03-08) + +#### Features + +* **Breaking change**: Add connection timeout to the cluster client ([#834](https://github.com/redis-rs/redis-rs/pull/834)) +* **Breaking change**: Deprecate aio::Connection ([#889](https://github.com/redis-rs/redis-rs/pull/889)) +* Cluster: fix read from replica & missing slots ([#965](https://github.com/redis-rs/redis-rs/pull/965)) +* Async cluster connection: Improve handling of missing connections ([#968](https://github.com/redis-rs/redis-rs/pull/968)) +* Add support for parsing to/from any sized arrays ([#981](https://github.com/redis-rs/redis-rs/pull/981)) +* Upgrade to rustls 0.22 ([#1000](https://github.com/redis-rs/redis-rs/pull/1000) @djc) +* add SMISMEMBER command ([#1002](https://github.com/redis-rs/redis-rs/pull/1002) @Zacaria) +* Add support for some big number types ([#1014](https://github.com/redis-rs/redis-rs/pull/1014) @AkiraMiyakoda) +* Add Support for UUIDs ([#1029](https://github.com/redis-rs/redis-rs/pull/1029) @Rabbitminers) +* Add FromRedisValue::from_owned_redis_value to reduce copies while parsing response ([#1030](https://github.com/redis-rs/redis-rs/pull/1030) @Nathan-Fenner) +* Save reconnected connections during retries ([#1033](https://github.com/redis-rs/redis-rs/pull/1033)) +* Avoid panic on connection failure ([#1035](https://github.com/redis-rs/redis-rs/pull/1035)) +* add disable client setinfo feature and its default mode is off ([#1036](https://github.com/redis-rs/redis-rs/pull/1036) @Ggiggle) +* Reconnect on parsing errors ([#1051](https://github.com/redis-rs/redis-rs/pull/1051)) +* preallocate buffer for evalsha in Script ([#1044](https://github.com/redis-rs/redis-rs/pull/1044) @framlog) + +#### Changes + +* Align more commands routings ([#938](https://github.com/redis-rs/redis-rs/pull/938)) +* Fix HashMap conversion ([#977](https://github.com/redis-rs/redis-rs/pull/977) @mxbrt) +* MultiplexedConnection: Remove unnecessary allocation in send ([#990](https://github.com/redis-rs/redis-rs/pull/990)) +* Tests: Reduce cluster setup flakiness ([#999](https://github.com/redis-rs/redis-rs/pull/999)) +* Remove the unwrap_or! macro ([#1010](https://github.com/redis-rs/redis-rs/pull/1010)) +* Remove allocation from command function ([#1008](https://github.com/redis-rs/redis-rs/pull/1008)) +* Catch panics from task::spawn in tests ([#1015](https://github.com/redis-rs/redis-rs/pull/1015)) +* Fix lint errors from new Rust version ([#1016](https://github.com/redis-rs/redis-rs/pull/1016)) +* Fix warnings that appear only with native-TLS ([#1018](https://github.com/redis-rs/redis-rs/pull/1018)) +* Hide the req_packed_commands from docs ([#1020](https://github.com/redis-rs/redis-rs/pull/1020)) +* Fix documentaion error ([#1022](https://github.com/redis-rs/redis-rs/pull/1022) @rcl-viveksharma) +* Fixes minor grammar mistake in json.rs file ([#1026](https://github.com/redis-rs/redis-rs/pull/1026) @RScrusoe) +* Enable ignored pipe test ([#1027](https://github.com/redis-rs/redis-rs/pull/1027)) +* Fix names of existing async cluster tests ([#1028](https://github.com/redis-rs/redis-rs/pull/1028)) +* Add lock file to keep MSRV constant ([#1039](https://github.com/redis-rs/redis-rs/pull/1039)) +* Fail CI if lock file isn't updated ([#1042](https://github.com/redis-rs/redis-rs/pull/1042)) +* impl Clone/Copy for SetOptions ([#1046](https://github.com/redis-rs/redis-rs/pull/1046) @ahmadbky) +* docs: add "connection-manager" cfg attr ([#1048](https://github.com/redis-rs/redis-rs/pull/1048) @DCNick3) +* Remove the usage of aio::Connection in tests ([#1049](https://github.com/redis-rs/redis-rs/pull/1049)) +* Fix new clippy lints ([#1052](https://github.com/redis-rs/redis-rs/pull/1052)) +* Handle server errors in array response ([#1056](https://github.com/redis-rs/redis-rs/pull/1056)) +* Appease Clippy ([#1061](https://github.com/redis-rs/redis-rs/pull/1061)) +* make Pipeline handle returned bulks correctly ([#1063](https://github.com/redis-rs/redis-rs/pull/1063) @framlog) +* Update mio dependency due to vulnerability ([#1064](https://github.com/redis-rs/redis-rs/pull/1064)) +* Simplify Sink polling logic ([#1065](https://github.com/redis-rs/redis-rs/pull/1065)) +* Separate parsing errors from general response errors ([#1069](https://github.com/redis-rs/redis-rs/pull/1069)) + +### 0.24.0 (2023-12-05) + +#### Features +* **Breaking change**: Support Mutual TLS ([#858](https://github.com/redis-rs/redis-rs/pull/858) @sp-angel) +* Implement `FromRedisValue` for `Box<[T]>` and `Arc<[T]>` ([#799](https://github.com/redis-rs/redis-rs/pull/799) @JOT85) +* Sync Cluster: support multi-slot operations. ([#967](https://github.com/redis-rs/redis-rs/pull/967)) +* Execute multi-node requests using try_request. ([#919](https://github.com/redis-rs/redis-rs/pull/919)) +* Sorted set blocking commands ([#962](https://github.com/redis-rs/redis-rs/pull/962) @gheorghitamutu) +* Allow passing routing information to cluster. ([#899](https://github.com/redis-rs/redis-rs/pull/899)) +* Add `tcp_nodelay` feature ([#941](https://github.com/redis-rs/redis-rs/pull/941) @PureWhiteWu) +* Add support for multi-shard commands. ([#900](https://github.com/redis-rs/redis-rs/pull/900)) + +#### Changes +* Order in usage of ClusterParams. ([#997](https://github.com/redis-rs/redis-rs/pull/997)) +* **Breaking change**: Fix StreamId::contains_key signature ([#783](https://github.com/redis-rs/redis-rs/pull/783) @Ayush1325) +* **Breaking change**: Update Command expiration values to be an appropriate type ([#589](https://github.com/redis-rs/redis-rs/pull/589) @joshleeb) +* **Breaking change**: Bump aHash to v0.8.6 ([#966](https://github.com/redis-rs/redis-rs/pull/966) @aumetra) +* Fix features for `load_native_certs`. ([#996](https://github.com/redis-rs/redis-rs/pull/996)) +* Revert redis-test versioning changes ([#993](https://github.com/redis-rs/redis-rs/pull/993)) +* Tests: Add retries to test cluster creation ([#994](https://github.com/redis-rs/redis-rs/pull/994)) +* Fix sync cluster behavior with transactions. ([#983](https://github.com/redis-rs/redis-rs/pull/983)) +* Sync Pub/Sub - cache received pub/sub messages. ([#910](https://github.com/redis-rs/redis-rs/pull/910)) +* Prefer routing to primary in a transaction. ([#986](https://github.com/redis-rs/redis-rs/pull/986)) +* Accept iterator at `ClusterClient` initialization ([#987](https://github.com/redis-rs/redis-rs/pull/987) @ruanpetterson) +* **Breaking change**: Change timeouts from usize and isize to f64 ([#988](https://github.com/redis-rs/redis-rs/pull/988) @eythorhel19) +* Update minimal rust version to 1.6.5 ([#982](https://github.com/redis-rs/redis-rs/pull/982)) +* Disable JSON module tests for redis 6.2.4. ([#980](https://github.com/redis-rs/redis-rs/pull/980)) +* Add connection string examples ([#976](https://github.com/redis-rs/redis-rs/pull/976) @NuclearOreo) +* Move response policy into multi-node routing. ([#952](https://github.com/redis-rs/redis-rs/pull/952)) +* Added functions that allow tests to check version. ([#963](https://github.com/redis-rs/redis-rs/pull/963)) +* Fix XREADGROUP command ordering as per Redis Docs, and compatibility with Upstash Redis ([#960](https://github.com/redis-rs/redis-rs/pull/960) @prabhpreet) +* Optimize make_pipeline_results by pre-allocate memory ([#957](https://github.com/redis-rs/redis-rs/pull/957) @PureWhiteWu) +* Run module tests sequentially. ([#956](https://github.com/redis-rs/redis-rs/pull/956)) +* Log cluster creation output in tests. ([#955](https://github.com/redis-rs/redis-rs/pull/955)) +* CI: Update and use better maintained github actions. ([#954](https://github.com/redis-rs/redis-rs/pull/954)) +* Call CLIENT SETINFO on new connections. ([#945](https://github.com/redis-rs/redis-rs/pull/945)) +* Deprecate functions that erroneously use `tokio` in their name. ([#913](https://github.com/redis-rs/redis-rs/pull/913)) +* CI: Increase timeouts and use newer redis. ([#949](https://github.com/redis-rs/redis-rs/pull/949)) +* Remove redis version from redis-test. ([#943](https://github.com/redis-rs/redis-rs/pull/943)) + +### 0.23.4 (2023-11-26) +**Yanked** -- Inadvertently introduced breaking changes (sorry!). The changes in this tag +have been pushed to 0.24.0. + +### 0.23.3 (2023-09-01) + +Note that this release fixes a small regression in async Redis Cluster handling of the `PING` command. +Based on updated response aggregation logic in [#888](https://github.com/redis-rs/redis-rs/pull/888), it +will again return a single response instead of an array. + +#### Features +* Add `key_type` command ([#933](https://github.com/redis-rs/redis-rs/pull/933) @bruaba) +* Async cluster: Group responses by response_policy. ([#888](https://github.com/redis-rs/redis-rs/pull/888)) + + +#### Fixes +* Remove unnecessary heap allocation ([#939](https://github.com/redis-rs/redis-rs/pull/939) @thechampagne) +* Sentinel tests: Ensure no ports are used twice ([#915](https://github.com/redis-rs/redis-rs/pull/915)) +* Fix lint issues ([#937](https://github.com/redis-rs/redis-rs/pull/937)) +* Fix JSON serialization error test ([#928](https://github.com/redis-rs/redis-rs/pull/928)) +* Remove unused dependencies ([#916](https://github.com/redis-rs/redis-rs/pull/916)) + + +### 0.23.2 (2023-08-10) + +#### Fixes +* Fix sentinel tests flakiness ([#912](https://github.com/redis-rs/redis-rs/pull/912)) +* Rustls: Remove usage of deprecated method ([#921](https://github.com/redis-rs/redis-rs/pull/921)) +* Fix compiling with sentinel feature, without aio feature ([#922](https://github.com/redis-rs/redis-rs/pull/923) @brocaar) +* Add timeouts to tests github action ([#911](https://github.com/redis-rs/redis-rs/pull/911)) + +### 0.23.1 (2023-07-28) + +#### Features +* Add basic Sentinel functionality ([#836](https://github.com/redis-rs/redis-rs/pull/836) @felipou) +* Enable keep alive on tcp connections via feature ([#886](https://github.com/redis-rs/redis-rs/pull/886) @DoumanAsh) +* Support fan-out commands in cluster-async ([#843](https://github.com/redis-rs/redis-rs/pull/843) @nihohit) +* connection_manager: retry and backoff on reconnect ([#804](https://github.com/redis-rs/redis-rs/pull/804) @nihohit) + +#### Changes +* Tests: Wait for all servers ([#901](https://github.com/redis-rs/redis-rs/pull/901) @barshaul) +* Pin `tempfile` dependency ([#902](https://github.com/redis-rs/redis-rs/pull/902)) +* Update routing data for commands. ([#887](https://github.com/redis-rs/redis-rs/pull/887) @nihohit) +* Add basic benchmark reporting to CI ([#880](https://github.com/redis-rs/redis-rs/pull/880)) +* Add `set_options` cmd ([#879](https://github.com/redis-rs/redis-rs/pull/879) @RokasVaitkevicius) +* Move random connection creation to when needed. ([#882](https://github.com/redis-rs/redis-rs/pull/882) @nihohit) +* Clean up existing benchmarks ([#881](https://github.com/redis-rs/redis-rs/pull/881)) +* Improve async cluster client performance. ([#877](https://github.com/redis-rs/redis-rs/pull/877) @nihohit) +* Allow configuration of cluster retry wait duration ([#859](https://github.com/redis-rs/redis-rs/pull/859) @nihohit) +* Fix async connect when ns resolved to multi ip ([#872](https://github.com/redis-rs/redis-rs/pull/872) @hugefiver) +* Reduce the number of unnecessary clones. ([#874](https://github.com/redis-rs/redis-rs/pull/874) @nihohit) +* Remove connection checking on every request. ([#873](https://github.com/redis-rs/redis-rs/pull/873) @nihohit) +* cluster_async: Wrap internal state with Arc. ([#864](https://github.com/redis-rs/redis-rs/pull/864) @nihohit) +* Fix redirect routing on request with no route. ([#870](https://github.com/redis-rs/redis-rs/pull/870) @nihohit) +* Amend README for macOS users ([#869](https://github.com/redis-rs/redis-rs/pull/869) @sarisssa) +* Improved redirection error handling ([#857](https://github.com/redis-rs/redis-rs/pull/857)) +* Fix minor async client bug. ([#862](https://github.com/redis-rs/redis-rs/pull/862) @nihohit) +* Split aio.rs to separate files. ([#821](https://github.com/redis-rs/redis-rs/pull/821) @nihohit) +* Add time feature to tokio dependency ([#855](https://github.com/redis-rs/redis-rs/pull/855) @robjtede) +* Refactor cluster error handling ([#844](https://github.com/redis-rs/redis-rs/pull/844)) +* Fix unnecessarily mutable variable ([#849](https://github.com/redis-rs/redis-rs/pull/849) @kamulos) +* Newtype SlotMap ([#845](https://github.com/redis-rs/redis-rs/pull/845)) +* Bump MSRV to 1.60 ([#846](https://github.com/redis-rs/redis-rs/pull/846)) +* Improve error logging. ([#838](https://github.com/redis-rs/redis-rs/pull/838) @nihohit) +* Improve documentation, add references to `redis-macros` ([#769](https://github.com/redis-rs/redis-rs/pull/769) @daniel7grant) +* Allow creating Cmd with capacity. ([#817](https://github.com/redis-rs/redis-rs/pull/817) @nihohit) + + +### 0.23.0 (2023-04-05) +In addition to *everything mentioned in 0.23.0-beta.1 notes*, this release adds support for Rustls, a long- +sought feature. Thanks to @rharish101 and @LeoRowan for getting this in! + +#### Changes +* Update Rustls to v0.21.0 ([#820](https://github.com/redis-rs/redis-rs/pull/820) @rharish101) +* Implement support for Rustls ([#725](https://github.com/redis-rs/redis-rs/pull/725) @rharish101, @LeoRowan) + + +### 0.23.0-beta.1 (2023-03-28) + +This release adds the `cluster_async` module, which introduces async Redis Cluster support. The code therein +is largely taken from @Marwes's [redis-cluster-async crate](https://github.com/redis-rs/redis-cluster-async), which itself +appears to have started from a sync Redis Cluster implementation started by @atuk721. In any case, thanks to @Marwes and @atuk721 +for the great work, and we hope to keep development moving forward in `redis-rs`. + +Though async Redis Cluster functionality for the time being has been kept as close to the originating crate as possible, previous users of +`redis-cluster-async` should note the following changes: +* Retries, while still configurable, can no longer be set to `None`/infinite retries +* Routing and slot parsing logic has been removed and merged with existing `redis-rs` functionality +* The client has been removed and superceded by common `ClusterClient` +* Renamed `Connection` to `ClusterConnection` +* Added support for reading from replicas +* Added support for insecure TLS +* Added support for setting both username and password + +#### Breaking Changes +* Fix long-standing bug related to `AsyncIter`'s stream implementation in which polling the server + for additional data yielded broken data in most cases. Type bounds for `AsyncIter` have changed slightly, + making this a potentially breaking change. ([#597](https://github.com/redis-rs/redis-rs/pull/597) @roger) + +#### Changes +* Commands: Add additional generic args for key arguments ([#795](https://github.com/redis-rs/redis-rs/pull/795) @MaxOhn) +* Add `mset` / deprecate `set_multiple` ([#766](https://github.com/redis-rs/redis-rs/pull/766) @randomairborne) +* More efficient interfaces for `MultiplexedConnection` and `ConnectionManager` ([#811](https://github.com/redis-rs/redis-rs/pull/811) @nihohit) +* Refactor / remove flaky test ([#810](https://github.com/redis-rs/redis-rs/pull/810)) +* `cluster_async`: rename `Connection` to `ClusterConnection`, `Pipeline` to `ClusterConnInner` ([#808](https://github.com/redis-rs/redis-rs/pull/808)) +* Support parsing IPV6 cluster nodes ([#796](https://github.com/redis-rs/redis-rs/pull/796) @socs) +* Common client for sync/async cluster connections ([#798](https://github.com/redis-rs/redis-rs/pull/798)) + * `cluster::ClusterConnection` underlying connection type is now generic (with existing type as default) + * Support `read_from_replicas` in cluster_async + * Set retries in `ClusterClientBuilder` + * Add mock tests for `cluster` +* cluster-async common slot parsing([#793](https://github.com/redis-rs/redis-rs/pull/793)) +* Support async-std in cluster_async module ([#790](https://github.com/redis-rs/redis-rs/pull/790)) +* Async-Cluster use same routing as Sync-Cluster ([#789](https://github.com/redis-rs/redis-rs/pull/789)) +* Add Async Cluster Support ([#696](https://github.com/redis-rs/redis-rs/pull/696)) +* Fix broken json-module tests ([#786](https://github.com/redis-rs/redis-rs/pull/786)) +* `cluster`: Tls Builder support / simplify cluster connection map ([#718](https://github.com/redis-rs/redis-rs/pull/718) @0xWOF, @utkarshgupta137) + + +### 0.22.3 (2023-01-23) + +#### Changes +* Restore inherent `ClusterConnection::check_connection()` method ([#758](https://github.com/redis-rs/redis-rs/pull/758) @robjtede) + + + +### 0.22.2 (2023-01-07) + +This release adds various incremental improvements and fixes a few long-standing bugs. Thanks to all our +contributors for making this release happen. + +#### Features +* Implement ToRedisArgs for HashMap ([#722](https://github.com/redis-rs/redis-rs/pull/722) @gibranamparan) +* Add explicit `MGET` command ([#729](https://github.com/redis-rs/redis-rs/pull/729) @vamshiaruru-virgodesigns) + +#### Bug fixes +* Enable single-item-vector `get` responses ([#507](https://github.com/redis-rs/redis-rs/pull/507) @hank121314) +* Fix empty result from xread_options with deleted entries ([#712](https://github.com/redis-rs/redis-rs/pull/712) @Quiwin) +* Limit Parser Recursion ([#724](https://github.com/redis-rs/redis-rs/pull/724)) +* Improve MultiplexedConnection Error Handling ([#699](https://github.com/redis-rs/redis-rs/pull/699)) + +#### Changes +* Add test case for atomic pipeline ([#702](https://github.com/redis-rs/redis-rs/pull/702) @CNLHC) +* Capture subscribe result error in PubSub doc example ([#739](https://github.com/redis-rs/redis-rs/pull/739) @baoyachi) +* Use async-std name resolution when necessary ([#701](https://github.com/redis-rs/redis-rs/pull/701) @UgnilJoZ) +* Add Script::invoke_async method ([#711](https://github.com/redis-rs/redis-rs/pull/711) @r-bk) +* Cluster Refactorings ([#717](https://github.com/redis-rs/redis-rs/pull/717), [#716](https://github.com/redis-rs/redis-rs/pull/716), [#709](https://github.com/redis-rs/redis-rs/pull/709), [#707](https://github.com/redis-rs/redis-rs/pull/707), [#706](https://github.com/redis-rs/redis-rs/pull/706) @0xWOF, @utkarshgupta137) +* Fix intermitent test failure ([#714](https://github.com/redis-rs/redis-rs/pull/714) @0xWOF, @utkarshgupta137) +* Doc changes ([#705](https://github.com/redis-rs/redis-rs/pull/705) @0xWOF, @utkarshgupta137) +* Lint fixes ([#704](https://github.com/redis-rs/redis-rs/pull/704) @0xWOF) + + + +### 0.22.1 (2022-10-18) + +#### Changes +* Add README attribute to Cargo.toml +* Update LICENSE file / symlink from parent directory + + +### 0.22.0 (2022-10-05) + +This release adds various incremental improvements, including +additional convenience commands, improved Cluster APIs, and various other bug +fixes/library improvements. + +Although the changes here are incremental, this is a major release due to the +breaking changes listed below. + +This release would not be possible without our many wonderful +contributors -- thank you! + +#### Breaking changes +* Box all large enum variants; changes enum signature ([#667](https://github.com/redis-rs/redis-rs/pull/667) @nihohit) +* Support ACL commands by adding Rule::Other to cover newly defined flags; adds new enum variant ([#685](https://github.com/redis-rs/redis-rs/pull/685) @garyhai) +* Switch from sha1 to sha1_smol; renames `sha1` feature ([#576](https://github.com/redis-rs/redis-rs/pull/576)) + +#### Features +* Add support for RedisJSON ([#657](https://github.com/redis-rs/redis-rs/pull/657) @d3rpp) +* Add support for weights in zunionstore and zinterstore ([#641](https://github.com/redis-rs/redis-rs/pull/641) @ndd7xv) +* Cluster: Create read_from_replicas option ([#635](https://github.com/redis-rs/redis-rs/pull/635) @utkarshgupta137) +* Make Direction a public enum to use with Commands like BLMOVE ([#646](https://github.com/redis-rs/redis-rs/pull/646) @thorbadour) +* Add `ahash` feature for using ahash internally & for redis values ([#636](https://github.com/redis-rs/redis-rs/pull/636) @utkarshgupta137) +* Add Script::load function ([#603](https://github.com/redis-rs/redis-rs/pull/603) @zhiburt) +* Add support for OBJECT ([[#610]](https://github.com/redis-rs/redis-rs/pull/610) @roger) +* Add GETEX and GETDEL support ([#582](https://github.com/redis-rs/redis-rs/pull/582) @arpandaze) +* Add support for ZMPOP ([#605](https://github.com/redis-rs/redis-rs/pull/605) @gkorland) + +#### Changes +* Rust 2021 Edition / MSRV 1.59.0 +* Fix: Support IPV6 link-local address parsing ([#679](https://github.com/redis-rs/redis-rs/pull/679) @buaazp) +* Derive Clone and add Deref trait to InfoDict ([#661](https://github.com/redis-rs/redis-rs/pull/661) @danni-m) +* ClusterClient: add handling for empty initial_nodes, use ClusterParams to store cluster parameters, improve builder pattern ([#669](https://github.com/redis-rs/redis-rs/pull/669) @utkarshgupta137) +* Implement Debug for MultiplexedConnection & Pipeline ([#664](https://github.com/redis-rs/redis-rs/pull/664) @elpiel) +* Add support for casting RedisResult to CString ([#660](https://github.com/redis-rs/redis-rs/pull/660) @nihohit) +* Move redis crate to subdirectory to support multiple crates in project ([#465](https://github.com/redis-rs/redis-rs/pull/465) @tdyas) +* Stop versioning Cargo.lock ([#620](https://github.com/redis-rs/redis-rs/pull/620)) +* Auto-implement ConnectionLike for DerefMut ([#567](https://github.com/redis-rs/redis-rs/pull/567) @holmesmr) +* Return errors from parsing inner items ([#608](https://github.com/redis-rs/redis-rs/pull/608)) +* Make dns resolution async, in async runtime ([#606](https://github.com/redis-rs/redis-rs/pull/606) @roger) +* Make async_trait dependency optional ([#572](https://github.com/redis-rs/redis-rs/pull/572) @kamulos) +* Add username to ClusterClient and ClusterConnection ([#596](https://github.com/redis-rs/redis-rs/pull/596) @gildaf) + + + +### 0.21.6 (2022-08-24) + +* Update dependencies ([#588](https://github.com/mitsuhiko/redis-rs/pull/588)) + + +### 0.21.5 (2022-01-10) + +#### Features + +* Add new list commands ([#570](https://github.com/mitsuhiko/redis-rs/pull/570)) + + + +### 0.21.4 (2021-11-04) + +#### Features + +* Add convenience command: zrandbember ([#556](https://github.com/mitsuhiko/redis-rs/pull/556)) + + + + +### 0.21.3 (2021-10-15) + +#### Features + +* Add support for TLS with cluster mode ([#548](https://github.com/mitsuhiko/redis-rs/pull/548)) + +#### Changes + +* Remove stunnel as a dep, use redis native tls ([#542](https://github.com/mitsuhiko/redis-rs/pull/542)) + + + + + +### 0.21.2 (2021-09-02) + + +#### Bug Fixes + +* Compile with tokio-comp and up-to-date dependencies ([282f997e](https://github.com/mitsuhiko/redis-rs/commit/282f997e41cc0de2a604c0f6a96d82818dacc373), closes [#531](https://github.com/mitsuhiko/redis-rs/issues/531), breaks [#](https://github.com/mitsuhiko/redis-rs/issues/)) + +#### Breaking Changes + +* Compile with tokio-comp and up-to-date dependencies ([282f997e](https://github.com/mitsuhiko/redis-rs/commit/282f997e41cc0de2a604c0f6a96d82818dacc373), closes [#531](https://github.com/mitsuhiko/redis-rs/issues/531), breaks [#](https://github.com/mitsuhiko/redis-rs/issues/)) + + + + +### 0.21.1 (2021-08-25) + + +#### Bug Fixes + +* pin futures dependency to required version ([9be392bc](https://github.com/mitsuhiko/redis-rs/commit/9be392bc5b22326a8a0eafc7aa18cc04c5d79e0e)) + + + + +### 0.21.0 (2021-07-16) + + +#### Performance + +* Don't enqueue multiplexed commands if the receiver is dropped ([ca5019db](https://github.com/mitsuhiko/redis-rs/commit/ca5019dbe76cc56c93eaecb5721de8fcf74d1641)) + +#### Features + +* Refactor ConnectionAddr to remove boxing and clarify fields + + +### 0.20.2 (2021-06-17) + +#### Features + +* Provide a new_async_std function ([c3716d15](https://github.com/mitsuhiko/redis-rs/commit/c3716d154f067b71acdd5bd927e118305cd0830b)) + +#### Bug Fixes + +* Return Ready(Ok(())) when we have flushed all messages ([ca319c06](https://github.com/mitsuhiko/redis-rs/commit/ca319c06ad80fc37f1f701aecebbd5dabb0dceb0)) +* Don't loop forever on shutdown of the multiplexed connection ([ddecce9e](https://github.com/mitsuhiko/redis-rs/commit/ddecce9e10b8ab626f41409aae289d62b4fb74be)) + + + + +### 0.20.1 (2021-05-18) + + +#### Bug Fixes + +* Error properly if eof is reached in the decoder ([306797c3](https://github.com/mitsuhiko/redis-rs/commit/306797c3c55ab24e0a29b6517356af794731d326)) + + + + +## 0.20.0 (2021-02-17) + + +#### Features + +* Make ErrorKind non_exhaustive for forwards compatibility ([ac5e1a60](https://github.com/mitsuhiko/redis-rs/commit/ac5e1a60d398814b18ed1b579fe3f6b337e545e9)) +* **aio:** Allow the underlying IO stream to be customized ([6d2fc8fa](https://github.com/mitsuhiko/redis-rs/commit/6d2fc8faa707fbbbaae9fe092bbc90ce01224523)) + + + + +## 0.19.0 (2020-12-26) + + +#### Features + +* Update to tokio 1.0 ([41960194](https://github.com/mitsuhiko/redis-rs/commit/4196019494aafc2bab718bafd1fdfd5e8c195ffa)) +* use the node specified in the MOVED error ([8a53e269](https://github.com/mitsuhiko/redis-rs/commit/8a53e2699d7d7bd63f222de778ed6820b0a65665)) + + + + +## 0.18.0 (2020-12-03) + + +#### Bug Fixes + +* Don't require tokio for the connection manager ([46be86f3](https://github.com/mitsuhiko/redis-rs/commit/46be86f3f07df4900559bf9a4dfd0b5138c3ac52)) + +* Make ToRedisArgs and FromRedisValue consistent for booleans + +BREAKING CHANGE + +bool are now written as 0 and 1 instead of true and false. Parsing a bool still accept true and false so this should not break anything for most users however if you are reading something out that was stored as a bool you may see different results. + +#### Features + +* Update tokio dependency to 0.3 ([bf5e0af3](https://github.com/mitsuhiko/redis-rs/commit/bf5e0af31c08be1785656031ffda96c355ee83c4), closes [#396](https://github.com/mitsuhiko/redis-rs/issues/396)) +* add doc_cfg for Makefile and docs.rs config ([1bf79517](https://github.com/mitsuhiko/redis-rs/commit/1bf795174521160934f3695326897458246e4978)) +* Impl FromRedisValue for i128 and u128 + + +# Changelog + +## [0.18.0](https://github.com/mitsuhiko/redis-rs/compare/0.17.0...0.18.0) - 2020-12-03 + +## [0.17.0](https://github.com/mitsuhiko/redis-rs/compare/0.16.0...0.17.0) - 2020-07-29 + +**Fixes and improvements** + +* Added Redis Streams commands ([#162](https://github.com/mitsuhiko/redis-rs/pull/319)) +* Added support for zpopmin and zpopmax ([#351](https://github.com/mitsuhiko/redis-rs/pull/351)) +* Added TLS support, gated by a feature flag ([#305](https://github.com/mitsuhiko/redis-rs/pull/305)) +* Added Debug and Clone implementations to redis::Script ([#365](https://github.com/mitsuhiko/redis-rs/pull/365)) +* Added FromStr for ConnectionInfo ([#368](https://github.com/mitsuhiko/redis-rs/pull/368)) +* Support SCAN methods on async connections ([#326](https://github.com/mitsuhiko/redis-rs/pull/326)) +* Removed unnecessary overhead around `Value` conversions ([#327](https://github.com/mitsuhiko/redis-rs/pull/327)) +* Support for Redis 6 auth ([#341](https://github.com/mitsuhiko/redis-rs/pull/341)) +* BUGFIX: Make aio::Connection Sync again ([#321](https://github.com/mitsuhiko/redis-rs/pull/321)) +* BUGFIX: Return UnexpectedEof if we try to decode at eof ([#322](https://github.com/mitsuhiko/redis-rs/pull/322)) +* Added support to create a connection from a (host, port) tuple ([#370](https://github.com/mitsuhiko/redis-rs/pull/370)) + +## [0.16.0](https://github.com/mitsuhiko/redis-rs/compare/0.15.1...0.16.0) - 2020-05-10 + +**Fixes and improvements** + +* Reduce dependencies without async IO ([#266](https://github.com/mitsuhiko/redis-rs/pull/266)) +* Add an afl fuzz target ([#274](https://github.com/mitsuhiko/redis-rs/pull/274)) +* Updated to combine 4 and avoid async dependencies for sync-only ([#272](https://github.com/mitsuhiko/redis-rs/pull/272)) + * **BREAKING CHANGE**: The parser type now only persists the buffer and takes the Read instance in `parse_value` +* Implement a connection manager for automatic reconnection ([#278](https://github.com/mitsuhiko/redis-rs/pull/278)) +* Add async-std support ([#281](https://github.com/mitsuhiko/redis-rs/pull/281)) +* Fix key extraction for some stream commands ([#283](https://github.com/mitsuhiko/redis-rs/pull/283)) +* Add asynchronous PubSub support ([#287](https://github.com/mitsuhiko/redis-rs/pull/287)) + +### Breaking changes + +#### Changes to the `Parser` type ([#272](https://github.com/mitsuhiko/redis-rs/pull/272)) + +The parser type now only persists the buffer and takes the Read instance in `parse_value`. +`redis::parse_redis_value` is unchanged and continues to work. + + +Old: + +```rust +let mut parser = Parser::new(bytes); +let result = parser.parse_value(); +``` + +New: + +```rust +let mut parser = Parser::new(); +let result = parser.pase_value(bytes); +``` + +## [0.15.1](https://github.com/mitsuhiko/redis-rs/compare/0.15.0...0.15.1) - 2020-01-15 + +**Fixes and improvements** + +* Fixed the `r2d2` feature (re-added it) ([#265](https://github.com/mitsuhiko/redis-rs/pull/265)) + +## [0.15.0](https://github.com/mitsuhiko/redis-rs/compare/0.14.0...0.15.0) - 2020-01-15 + +**Fixes and improvements** + +* Added support for redis cluster ([#239](https://github.com/mitsuhiko/redis-rs/pull/239)) + +## [0.14.0](https://github.com/mitsuhiko/redis-rs/compare/0.13.0...0.14.0) - 2020-01-08 + +**Fixes and improvements** + +* Fix the command verb being sent to redis for `zremrangebyrank` ([#240](https://github.com/mitsuhiko/redis-rs/pull/240)) +* Add `get_connection_with_timeout` to Client ([#243](https://github.com/mitsuhiko/redis-rs/pull/243)) +* **Breaking change:** Add Cmd::get, Cmd::set and remove PipelineCommands ([#253](https://github.com/mitsuhiko/redis-rs/pull/253)) +* Async-ify the API ([#232](https://github.com/mitsuhiko/redis-rs/pull/232)) +* Bump minimal required Rust version to 1.39 (required for the async/await API) +* Add async/await examples ([#261](https://github.com/mitsuhiko/redis-rs/pull/261), [#263](https://github.com/mitsuhiko/redis-rs/pull/263)) +* Added support for PSETEX and PTTL commands. ([#259](https://github.com/mitsuhiko/redis-rs/pull/259)) + +### Breaking changes + +#### Add Cmd::get, Cmd::set and remove PipelineCommands ([#253](https://github.com/mitsuhiko/redis-rs/pull/253)) + +If you are using pipelines and were importing the `PipelineCommands` trait you can now remove that import +and only use the `Commands` trait. + +Old: + +```rust +use redis::{Commands, PipelineCommands}; +``` + +New: + +```rust +use redis::Commands; +``` + +## [0.13.0](https://github.com/mitsuhiko/redis-rs/compare/0.12.0...0.13.0) - 2019-10-14 + +**Fixes and improvements** + +* **Breaking change:** rename `parse_async` to `parse_redis_value_async` for consistency ([ce59cecb](https://github.com/mitsuhiko/redis-rs/commit/ce59cecb830d4217115a4e74e38891e76cf01474)). +* Run clippy over the entire codebase ([#238](https://github.com/mitsuhiko/redis-rs/pull/238)) +* **Breaking change:** Make `Script#invoke_async` generic over `aio::ConnectionLike` ([#242](https://github.com/mitsuhiko/redis-rs/pull/242)) + +### BREAKING CHANGES + +#### Rename `parse_async` to `parse_redis_value_async` for consistency ([ce59cecb](https://github.com/mitsuhiko/redis-rs/commit/ce59cecb830d4217115a4e74e38891e76cf01474)). + +If you used `redis::parse_async` before, you now need to change this to `redis::parse_redis_value_async` +or import the method under the new name: `use redis::parse_redis_value_async`. + +#### Make `Script#invoke_async` generic over `aio::ConnectionLike` ([#242](https://github.com/mitsuhiko/redis-rs/pull/242)) + +`Script#invoke_async` was changed to be generic over `aio::ConnectionLike` in order to support wrapping a `SharedConnection` in user code. +This required adding a new generic parameter to the method, causing an error when the return type is defined using the turbofish syntax. + +Old: + +```rust +redis::Script::new("return ...") + .key("key1") + .arg("an argument") + .invoke_async::() +``` + +New: + +```rust +redis::Script::new("return ...") + .key("key1") + .arg("an argument") + .invoke_async::<_, String>() +``` + +## [0.12.0](https://github.com/mitsuhiko/redis-rs/compare/0.11.0...0.12.0) - 2019-08-26 + +**Fixes and improvements** + +* **Breaking change:** Use `dyn` keyword to avoid deprecation warning ([#223](https://github.com/mitsuhiko/redis-rs/pull/223)) +* **Breaking change:** Update url dependency to v2 ([#234](https://github.com/mitsuhiko/redis-rs/pull/234)) +* **Breaking change:** (async) Fix `Script::invoke_async` return type error ([#233](https://github.com/mitsuhiko/redis-rs/pull/233)) +* Add `GETRANGE` and `SETRANGE` commands ([#202](https://github.com/mitsuhiko/redis-rs/pull/202)) +* Fix `SINTERSTORE` wrapper name, it's now correctly `sinterstore` ([#225](https://github.com/mitsuhiko/redis-rs/pull/225)) +* Allow running `SharedConnection` with any other runtime ([#229](https://github.com/mitsuhiko/redis-rs/pull/229)) +* Reformatted as Edition 2018 code ([#235](https://github.com/mitsuhiko/redis-rs/pull/235)) + +### BREAKING CHANGES + +#### Use `dyn` keyword to avoid deprecation warning ([#223](https://github.com/mitsuhiko/redis-rs/pull/223)) + +Rust nightly deprecated bare trait objects. +This PR adds the `dyn` keyword to all trait objects in order to get rid of the warning. +This bumps the minimal supported Rust version to [Rust 1.27](https://blog.rust-lang.org/2018/06/21/Rust-1.27.html). + +#### Update url dependency to v2 ([#234](https://github.com/mitsuhiko/redis-rs/pull/234)) + +We updated the `url` dependency to v2. We do expose this on our public API on the `redis::parse_redis_url` function. If you depend on that, make sure to also upgrade your direct dependency. + +#### (async) Fix Script::invoke_async return type error ([#233](https://github.com/mitsuhiko/redis-rs/pull/233)) + +Previously, invoking a script with a complex return type would cause the following error: + +``` +Response was of incompatible type: "Not a bulk response" (response was string data('"4b98bef92b171357ddc437b395c7c1a5145ca2bd"')) +``` + +This was because the Future returned when loading the script into the database returns the hash of the script, and thus the return type of `String` would not match the intended return type. + +This commit adds an enum to account for the different Future return types. + + +## [0.11.0](https://github.com/mitsuhiko/redis-rs/compare/0.11.0-beta.2...0.11.0) - 2019-07-19 + +This release includes all fixes & improvements from the two beta releases listed below. +This release contains breaking changes. + +**Fixes and improvements** + +* (async) Fix performance problem for SharedConnection ([#222](https://github.com/mitsuhiko/redis-rs/pull/222)) + +## [0.11.0-beta.2](https://github.com/mitsuhiko/redis-rs/compare/0.11.0-beta.1...0.11.0-beta.2) - 2019-07-14 + +**Fixes and improvements** + +* (async) Don't block the executor from shutting down ([#217](https://github.com/mitsuhiko/redis-rs/pull/217)) + +## [0.11.0-beta.1](https://github.com/mitsuhiko/redis-rs/compare/0.10.0...0.11.0-beta.1) - 2019-05-30 + +**Fixes and improvements** + +* (async) Simplify implicit pipeline handling ([#182](https://github.com/mitsuhiko/redis-rs/pull/182)) +* (async) Use `tokio_sync`'s channels instead of futures ([#195](https://github.com/mitsuhiko/redis-rs/pull/195)) +* (async) Only allocate one oneshot per request ([#194](https://github.com/mitsuhiko/redis-rs/pull/194)) +* Remove redundant BufReader when parsing ([#197](https://github.com/mitsuhiko/redis-rs/pull/197)) +* Hide actual type returned from async parser ([#193](https://github.com/mitsuhiko/redis-rs/pull/193)) +* Use more performant operations for line parsing ([#198](https://github.com/mitsuhiko/redis-rs/pull/198)) +* Optimize the command encoding, see below for **breaking changes** ([#165](https://github.com/mitsuhiko/redis-rs/pull/165)) +* Add support for geospatial commands ([#130](https://github.com/mitsuhiko/redis-rs/pull/130)) +* (async) Add support for async Script invocation ([#206](https://github.com/mitsuhiko/redis-rs/pull/206)) + +### BREAKING CHANGES + +#### Renamed the async module to aio ([#189](https://github.com/mitsuhiko/redis-rs/pull/189)) + +`async` is a reserved keyword in Rust 2018, so this avoids the need to write `r#async` in it. + +Old code: + +```rust +use redis::async::SharedConnection; +``` + +New code: + +```rust +use redis::aio::SharedConnection; +``` + +#### The trait `ToRedisArgs` was changed ([#165](https://github.com/mitsuhiko/redis-rs/pull/165)) + +`ToRedisArgs` has been changed to take take an instance of `RedisWrite` instead of `Vec>`. Use the `write_arg` method instead of `Vec::push`. + +#### Minimum Rust version is now 1.26 ([#165](https://github.com/mitsuhiko/redis-rs/pull/165)) + +Upgrade your compiler. +`impl Iterator` is used, requiring a more recent version of the Rust compiler. + +#### `iter` now takes `self` by value ([#165](https://github.com/mitsuhiko/redis-rs/pull/165)) + +`iter` now takes `self` by value instead of cloning `self` inside the method. + +Old code: + +```rust +let mut iter : redis::Iter = cmd.arg("my_set").cursor_arg(0).iter(&con).unwrap(); +``` + +New code: + +```rust +let mut iter : redis::Iter = cmd.arg("my_set").cursor_arg(0).clone().iter(&con).unwrap(); +``` + +(The above line calls `clone()`.) + +#### A mutable connection object is now required ([#148](https://github.com/mitsuhiko/redis-rs/pull/148)) + +We removed the internal usage of `RefCell` and `Cell` and instead require a mutable reference, `&mut ConnectionLike`, +on all command calls. + +Old code: + +```rust +let client = redis::Client::open("redis://127.0.0.1/")?; +let con = client.get_connection(None)?; +redis::cmd("SET").arg("my_key").arg(42).execute(&con); +``` + +New code: + +```rust +let client = redis::Client::open("redis://127.0.0.1/")?; +let mut con = client.get_connection(None)?; +redis::cmd("SET").arg("my_key").arg(42).execute(&mut con); +``` + +Due to this, `transaction` has changed. The callback now also receives a mutable reference to the used connection. + +Old code: + +```rust +let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +let con = client.get_connection(None).unwrap(); +let key = "the_key"; +let (new_val,) : (isize,) = redis::transaction(&con, &[key], |pipe| { + let old_val : isize = con.get(key)?; + pipe + .set(key, old_val + 1).ignore() + .get(key).query(&con) +})?; +``` + +New code: + +```rust +let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +let mut con = client.get_connection(None).unwrap(); +let key = "the_key"; +let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { + let old_val : isize = con.get(key)?; + pipe + .set(key, old_val + 1).ignore() + .get(key).query(&con) +})?; +``` + +#### Remove `rustc-serialize` feature ([#200](https://github.com/mitsuhiko/redis-rs/pull/200)) + +We removed serialization to/from JSON. The underlying library is deprecated for a long time. + +Old code in `Cargo.toml`: + +``` +[dependencies.redis] +version = "0.9.1" +features = ["with-rustc-json"] +``` + +There's no replacement for the feature. +Use [serde](https://serde.rs/) and handle the serialization/deserialization in your own code. + +#### Remove `with-unix-sockets` feature ([#201](https://github.com/mitsuhiko/redis-rs/pull/201)) + +We removed the Unix socket feature. It is now always enabled. +We also removed auto-detection. + +Old code in `Cargo.toml`: + +``` +[dependencies.redis] +version = "0.9.1" +features = ["with-unix-sockets"] +``` + +There's no replacement for the feature. Unix sockets will continue to work by default. + +## [0.10.0](https://github.com/mitsuhiko/redis-rs/compare/0.9.1...0.10.0) - 2019-02-19 + +* Fix handling of passwords with special characters (#163) +* Better performance for async code due to less boxing (#167) + * CAUTION: redis-rs will now require Rust 1.26 +* Add `clear` method to the pipeline (#176) +* Better benchmarking (#179) +* Fully formatted source code (#181) + +## [0.9.1](https://github.com/mitsuhiko/redis-rs/compare/0.9.0...0.9.1) (2018-09-10) + +* Add ttl command + +## [0.9.0](https://github.com/mitsuhiko/redis-rs/compare/0.8.0...0.9.0) (2018-08-08) + +Some time has passed since the last release. +This new release will bring less bugs, more commands, experimental async support and better performance. + +Highlights: + +* Implement flexible PubSub API (#136) +* Avoid allocating some redundant Vec's during encoding (#140) +* Add an async interface using futures-rs (#141) +* Allow the async connection to have multiple in flight requests (#143) + +The async support is currently experimental. + +## [0.8.0](https://github.com/mitsuhiko/redis-rs/compare/0.7.1...0.8.0) (2016-12-26) + +* Add publish command + +## [0.7.1](https://github.com/mitsuhiko/redis-rs/compare/0.7.0...0.7.1) (2016-12-17) + +* Fix unix socket builds +* Relax lifetimes for scripts + +## [0.7.0](https://github.com/mitsuhiko/redis-rs/compare/0.6.0...0.7.0) (2016-07-23) + +* Add support for built-in unix sockets + +## [0.6.0](https://github.com/mitsuhiko/redis-rs/compare/0.5.4...0.6.0) (2016-07-14) + +* feat: Make rustc-serialize an optional feature (#96) + +## [0.5.4](https://github.com/mitsuhiko/redis-rs/compare/0.5.3...0.5.4) (2016-06-25) + +* fix: Improved single arg handling (#95) +* feat: Implement ToRedisArgs for &String (#89) +* feat: Faster command encoding (#94) + +## [0.5.3](https://github.com/mitsuhiko/redis-rs/compare/0.5.2...0.5.3) (2016-05-03) + +* fix: Use explicit versions for dependencies +* fix: Send `AUTH` command before other commands +* fix: Shutdown connection upon protocol error +* feat: Add `keys` method +* feat: Possibility to set read and write timeouts for the connection diff --git a/glide-core/redis-rs/redis/Cargo.toml b/glide-core/redis-rs/redis/Cargo.toml new file mode 100644 index 0000000000..fd79ff079e --- /dev/null +++ b/glide-core/redis-rs/redis/Cargo.toml @@ -0,0 +1,227 @@ +[package] +name = "redis" +version = "0.25.2" +keywords = ["redis", "database"] +description = "Redis driver for Rust." +homepage = "https://github.com/redis-rs/redis-rs" +repository = "https://github.com/redis-rs/redis-rs" +documentation = "https://docs.rs/redis" +license = "BSD-3-Clause" +edition = "2021" +rust-version = "1.65" +readme = "../README.md" + +[package.metadata.docs.rs] +all-features = true +rustdoc-args = ["--cfg", "docsrs"] + +[lib] +bench = false + +[dependencies] +# These two are generally really common simple dependencies so it does not seem +# much of a point to optimize these, but these could in theory be removed for +# an indirection through std::Formatter. +ryu = "1.0" +itoa = "1.0" + +# Strum is a set of macros and traits for working with enums and strings easier in Rust. +strum = "0.26" +strum_macros = "0.26" + +# This is a dependency that already exists in url +percent-encoding = "2.1" + +# We need this for redis url parsing +url = "= 2.5.0" + +# We need this for script support +sha1_smol = { version = "1.0", optional = true } + +combine = { version = "4.6", default-features = false, features = ["std"] } + +# Only needed for AIO +bytes = { version = "1", optional = true } +futures-util = { version = "0.3.15", default-features = false, optional = true } +pin-project-lite = { version = "0.2", optional = true } +tokio-util = { version = "0.7", optional = true } +tokio = { version = "1", features = ["rt", "net", "time", "sync"] } +socket2 = { version = "0.5", features = ["all"], optional = true } +fast-math = { version = "0.1.1", optional = true } +dispose = { version = "0.5.0", optional = true } + +# Only needed for the connection manager +arc-swap = { version = "1.7.1" } +futures = { version = "0.3.3", optional = true } +tokio-retry = { version = "0.3.0", optional = true } + +# Only needed for the r2d2 feature +r2d2 = { version = "0.8.8", optional = true } + +# Only needed for cluster +crc16 = { version = "0.4", optional = true } +rand = { version = "0.8", optional = true } +derivative = { version = "2.2.0", optional = true } + +# Only needed for async cluster +dashmap = { version = "6.0", optional = true } + +# Only needed for async_std support +async-std = { version = "1.8.0", optional = true } +async-trait = { version = "0.1.24", optional = true } +# To avoid conflicts, backoff-std-async.version != backoff-tokio.version so we could run tests with --all-features +backoff-std-async = { package = "backoff", version = "0.3.0", optional = true, features = ["async-std"] } + +# Only needed for tokio support +backoff-tokio = { package = "backoff", version = "0.4.0", optional = true, features = ["tokio"] } + +# Only needed for native tls +native-tls = { version = "0.2", optional = true } +tokio-native-tls = { version = "0.3", optional = true } +async-native-tls = { version = "0.4", optional = true } + +# Only needed for rustls +rustls = { version = "0.22", optional = true } +webpki-roots = { version = "0.26", optional = true } +rustls-native-certs = { version = "0.7", optional = true } +tokio-rustls = { version = "0.25", optional = true } +futures-rustls = { version = "0.25", optional = true } +rustls-pemfile = { version = "2", optional = true } +rustls-pki-types = { version = "1", optional = true } + +# Only needed for RedisJSON Support +serde = { version = "1.0.82", optional = true } +serde_json = { version = "1.0.82", optional = true } + +# Only needed for bignum Support +rust_decimal = { version = "1.33.1", optional = true } +bigdecimal = { version = "0.4.2", optional = true } +num-bigint = "0.4.4" + +# Optional aHash support +ahash = { version = "0.8.11", optional = true } + +tracing = "0.1" +arcstr = "1.1.5" + +# Optional uuid support +uuid = { version = "1.6.1", optional = true } + +[features] +default = ["acl", "streams", "geospatial", "script", "keep-alive"] +acl = [] +aio = ["bytes", "pin-project-lite", "futures-util", "futures-util/alloc", "futures-util/sink", "tokio/io-util", "tokio-util", "tokio-util/codec", "combine/tokio", "async-trait", "fast-math", "dispose"] +geospatial = [] +json = ["serde", "serde/derive", "serde_json"] +cluster = ["crc16", "rand", "derivative"] +script = ["sha1_smol"] +tls-native-tls = ["native-tls"] +tls-rustls = ["rustls", "rustls-native-certs", "rustls-pemfile", "rustls-pki-types"] +tls-rustls-insecure = ["tls-rustls"] +tls-rustls-webpki-roots = ["tls-rustls", "webpki-roots"] +async-std-comp = ["aio", "async-std", "backoff-std-async"] +async-std-native-tls-comp = ["async-std-comp", "async-native-tls", "tls-native-tls"] +async-std-rustls-comp = ["async-std-comp", "futures-rustls", "tls-rustls"] +tokio-comp = ["aio", "tokio/net", "backoff-tokio"] +tokio-native-tls-comp = ["tokio-comp", "tls-native-tls", "tokio-native-tls"] +tokio-rustls-comp = ["tokio-comp", "tls-rustls", "tokio-rustls"] +connection-manager = ["futures", "aio", "tokio-retry"] +streams = [] +cluster-async = ["cluster", "futures", "futures-util", "dashmap"] +keep-alive = ["socket2"] +sentinel = ["rand"] +tcp_nodelay = [] +rust_decimal = ["dep:rust_decimal"] +bigdecimal = ["dep:bigdecimal"] +num-bigint = [] +uuid = ["dep:uuid"] +disable-client-setinfo = [] + +# Deprecated features +tls = ["tls-native-tls"] # use "tls-native-tls" instead +async-std-tls-comp = ["async-std-native-tls-comp"] # use "async-std-native-tls-comp" instead + +[dev-dependencies] +rand = "0.8" +socket2 = "0.5" +assert_approx_eq = "1.0" +fnv = "1.0.5" +futures = "0.3" +futures-time = "3" +criterion = "0.4" +partial-io = { version = "0.5", features = ["tokio", "quickcheck1"] } +quickcheck = "1.0.3" +tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } +tempfile = "=3.6.0" +once_cell = "1" +anyhow = "1" +sscanf = "0.4.1" + +[[test]] +name = "test_async" +required-features = ["tokio-comp"] + +[[test]] +name = "test_async_async_std" +required-features = ["async-std-comp"] + +[[test]] +name = "parser" +required-features = ["aio"] + +[[test]] +name = "test_acl" + +[[test]] +name = "test_module_json" +required-features = ["json", "serde/derive"] + +[[test]] +name = "test_cluster_async" +required-features = ["cluster-async"] + +[[test]] +name = "test_async_cluster_connections_logic" +required-features = ["cluster-async"] + +[[test]] +name = "test_bignum" + +[[bench]] +name = "bench_basic" +harness = false +required-features = ["tokio-comp"] + +[[bench]] +name = "bench_cluster" +harness = false +required-features = ["cluster"] + +[[bench]] +name = "bench_cluster_async" +harness = false +required-features = ["cluster-async", "tokio-comp"] + +[[example]] +name = "async-multiplexed" +required-features = ["tokio-comp"] + +[[example]] +name = "async-await" +required-features = ["aio"] + +[[example]] +name = "async-pub-sub" +required-features = ["aio"] + +[[example]] +name = "async-scan" +required-features = ["aio"] + +[[example]] +name = "async-connection-loss" +required-features = ["connection-manager"] + +[[example]] +name = "streams" +required-features = ["streams"] diff --git a/glide-core/redis-rs/redis/LICENSE b/glide-core/redis-rs/redis/LICENSE new file mode 100644 index 0000000000..533ac4e5a2 --- /dev/null +++ b/glide-core/redis-rs/redis/LICENSE @@ -0,0 +1,33 @@ +Copyright (c) 2022 by redis-rs contributors + +Redis cluster code in parts copyright (c) 2018 by Atsushi Koge. + +Some rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are +met: + + * Redistributions of source code must retain the above copyright + notice, this list of conditions and the following disclaimer. + + * Redistributions in binary form must reproduce the above + copyright notice, this list of conditions and the following + disclaimer in the documentation and/or other materials provided + with the distribution. + + * The names of the contributors may not be used to endorse or + promote products derived from this software without specific + prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/glide-core/redis-rs/redis/benches/bench_basic.rs b/glide-core/redis-rs/redis/benches/bench_basic.rs new file mode 100644 index 0000000000..356f74217e --- /dev/null +++ b/glide-core/redis-rs/redis/benches/bench_basic.rs @@ -0,0 +1,277 @@ +use criterion::{criterion_group, criterion_main, Bencher, Criterion, Throughput}; +use futures::prelude::*; +use redis::{RedisError, Value}; + +use support::*; + +#[path = "../tests/support/mod.rs"] +mod support; + +fn bench_simple_getsetdel(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + b.iter(|| { + let key = "test_key"; + redis::cmd("SET").arg(key).arg(42).execute(&mut con); + let _: isize = redis::cmd("GET").arg(key).query(&mut con).unwrap(); + redis::cmd("DEL").arg(key).execute(&mut con); + }); +} + +fn bench_simple_getsetdel_async(b: &mut Bencher) { + let ctx = TestContext::new(); + let runtime = current_thread_runtime(); + let mut con = runtime.block_on(ctx.async_connection()).unwrap(); + + b.iter(|| { + runtime + .block_on(async { + let key = "test_key"; + () = redis::cmd("SET") + .arg(key) + .arg(42) + .query_async(&mut con) + .await?; + let _: isize = redis::cmd("GET").arg(key).query_async(&mut con).await?; + () = redis::cmd("DEL").arg(key).query_async(&mut con).await?; + Ok::<_, RedisError>(()) + }) + .unwrap() + }); +} + +fn bench_simple_getsetdel_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + b.iter(|| { + let key = "test_key"; + let _: (usize,) = redis::pipe() + .cmd("SET") + .arg(key) + .arg(42) + .ignore() + .cmd("GET") + .arg(key) + .cmd("DEL") + .arg(key) + .ignore() + .query(&mut con) + .unwrap(); + }); +} + +fn bench_simple_getsetdel_pipeline_precreated(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let key = "test_key"; + let mut pipe = redis::pipe(); + pipe.cmd("SET") + .arg(key) + .arg(42) + .ignore() + .cmd("GET") + .arg(key) + .cmd("DEL") + .arg(key) + .ignore(); + + b.iter(|| { + let _: (usize,) = pipe.query(&mut con).unwrap(); + }); +} + +const PIPELINE_QUERIES: usize = 1_000; + +fn long_pipeline() -> redis::Pipeline { + let mut pipe = redis::pipe(); + + for i in 0..PIPELINE_QUERIES { + pipe.set(format!("foo{i}"), "bar").ignore(); + } + pipe +} + +fn bench_long_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let pipe = long_pipeline(); + + b.iter(|| { + pipe.query::<()>(&mut con).unwrap(); + }); +} + +fn bench_async_long_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let runtime = current_thread_runtime(); + let mut con = runtime.block_on(ctx.async_connection()).unwrap(); + + let pipe = long_pipeline(); + + b.iter(|| { + runtime + .block_on(async { pipe.query_async::<_, ()>(&mut con).await }) + .unwrap(); + }); +} + +fn bench_multiplexed_async_long_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let runtime = current_thread_runtime(); + let mut con = runtime + .block_on(ctx.multiplexed_async_connection_tokio()) + .unwrap(); + + let pipe = long_pipeline(); + + b.iter(|| { + runtime + .block_on(async { pipe.query_async::<_, ()>(&mut con).await }) + .unwrap(); + }); +} + +fn bench_multiplexed_async_implicit_pipeline(b: &mut Bencher) { + let ctx = TestContext::new(); + let runtime = current_thread_runtime(); + let con = runtime + .block_on(ctx.multiplexed_async_connection_tokio()) + .unwrap(); + + let cmds: Vec<_> = (0..PIPELINE_QUERIES) + .map(|i| redis::cmd("SET").arg(format!("foo{i}")).arg(i).clone()) + .collect(); + + let mut connections = (0..PIPELINE_QUERIES) + .map(|_| con.clone()) + .collect::>(); + + b.iter(|| { + runtime + .block_on(async { + cmds.iter() + .zip(&mut connections) + .map(|(cmd, con)| cmd.query_async::<_, ()>(con)) + .collect::>() + .try_for_each(|()| async { Ok(()) }) + .await + }) + .unwrap(); + }); +} + +fn bench_query(c: &mut Criterion) { + let mut group = c.benchmark_group("query"); + group + .bench_function("simple_getsetdel", bench_simple_getsetdel) + .bench_function("simple_getsetdel_async", bench_simple_getsetdel_async) + .bench_function("simple_getsetdel_pipeline", bench_simple_getsetdel_pipeline) + .bench_function( + "simple_getsetdel_pipeline_precreated", + bench_simple_getsetdel_pipeline_precreated, + ); + group.finish(); + + let mut group = c.benchmark_group("query_pipeline"); + group + .bench_function( + "multiplexed_async_implicit_pipeline", + bench_multiplexed_async_implicit_pipeline, + ) + .bench_function( + "multiplexed_async_long_pipeline", + bench_multiplexed_async_long_pipeline, + ) + .bench_function("async_long_pipeline", bench_async_long_pipeline) + .bench_function("long_pipeline", bench_long_pipeline) + .throughput(Throughput::Elements(PIPELINE_QUERIES as u64)); + group.finish(); +} + +fn bench_encode_small(b: &mut Bencher) { + b.iter(|| { + let mut cmd = redis::cmd("HSETX"); + + cmd.arg("ABC:1237897325302:878241asdyuxpioaswehqwu") + .arg("some hash key") + .arg(124757920); + + cmd.get_packed_command() + }); +} + +fn bench_encode_integer(b: &mut Bencher) { + b.iter(|| { + let mut pipe = redis::pipe(); + + for _ in 0..1_000 { + pipe.set(123, 45679123).ignore(); + } + pipe.get_packed_pipeline() + }); +} + +fn bench_encode_pipeline(b: &mut Bencher) { + b.iter(|| { + let mut pipe = redis::pipe(); + + for _ in 0..1_000 { + pipe.set("foo", "bar").ignore(); + } + pipe.get_packed_pipeline() + }); +} + +fn bench_encode_pipeline_nested(b: &mut Bencher) { + b.iter(|| { + let mut pipe = redis::pipe(); + + for _ in 0..200 { + pipe.set( + "foo", + ("bar", 123, b"1231279712", &["test", "test", "test"][..]), + ) + .ignore(); + } + pipe.get_packed_pipeline() + }); +} + +fn bench_encode(c: &mut Criterion) { + let mut group = c.benchmark_group("encode"); + group + .bench_function("pipeline", bench_encode_pipeline) + .bench_function("pipeline_nested", bench_encode_pipeline_nested) + .bench_function("integer", bench_encode_integer) + .bench_function("small", bench_encode_small); + group.finish(); +} + +fn bench_decode_simple(b: &mut Bencher, input: &[u8]) { + b.iter(|| redis::parse_redis_value(input).unwrap()); +} +fn bench_decode(c: &mut Criterion) { + let value = Value::Array(vec![ + Value::Okay, + Value::SimpleString("testing".to_string()), + Value::Array(vec![]), + Value::Nil, + Value::BulkString(vec![b'a'; 10]), + Value::Int(7512182390), + ]); + + let mut group = c.benchmark_group("decode"); + { + let mut input = Vec::new(); + support::encode_value(&value, &mut input).unwrap(); + assert_eq!(redis::parse_redis_value(&input).unwrap(), value); + group.bench_function("decode", move |b| bench_decode_simple(b, &input)); + } + group.finish(); +} + +criterion_group!(bench, bench_query, bench_encode, bench_decode); +criterion_main!(bench); diff --git a/glide-core/redis-rs/redis/benches/bench_cluster.rs b/glide-core/redis-rs/redis/benches/bench_cluster.rs new file mode 100644 index 0000000000..da854474ae --- /dev/null +++ b/glide-core/redis-rs/redis/benches/bench_cluster.rs @@ -0,0 +1,108 @@ +#![allow(clippy::unit_arg)] // want to allow this for `black_box()` +#![cfg(feature = "cluster")] +use criterion::{black_box, criterion_group, criterion_main, Criterion, Throughput}; +use redis::cluster::cluster_pipe; + +use support::*; + +#[path = "../tests/support/mod.rs"] +mod support; + +const PIPELINE_QUERIES: usize = 100; + +fn bench_set_get_and_del(c: &mut Criterion, con: &mut redis::cluster::ClusterConnection) { + let key = "test_key"; + + let mut group = c.benchmark_group("cluster_basic"); + + group.bench_function("set", |b| { + b.iter(|| { + redis::cmd("SET").arg(key).arg(42).execute(con); + black_box(()) + }) + }); + + group.bench_function("get", |b| { + b.iter(|| black_box(redis::cmd("GET").arg(key).query::(con).unwrap())) + }); + + let mut set_and_del = || { + redis::cmd("SET").arg(key).arg(42).execute(con); + redis::cmd("DEL").arg(key).execute(con); + }; + group.bench_function("set_and_del", |b| { + b.iter(|| { + set_and_del(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_pipeline(c: &mut Criterion, con: &mut redis::cluster::ClusterConnection) { + let mut group = c.benchmark_group("cluster_pipeline"); + group.throughput(Throughput::Elements(PIPELINE_QUERIES as u64)); + + let mut queries = Vec::new(); + for i in 0..PIPELINE_QUERIES { + queries.push(format!("foo{i}")); + } + + let build_pipeline = || { + let mut pipe = cluster_pipe(); + for q in &queries { + pipe.set(q, "bar").ignore(); + } + }; + group.bench_function("build_pipeline", |b| { + b.iter(|| { + build_pipeline(); + black_box(()) + }) + }); + + let mut pipe = cluster_pipe(); + for q in &queries { + pipe.set(q, "bar").ignore(); + } + group.bench_function("query_pipeline", |b| { + b.iter(|| { + pipe.query::<()>(con).unwrap(); + black_box(()) + }) + }); + + group.finish(); +} + +fn bench_cluster_setup(c: &mut Criterion) { + let cluster = TestClusterContext::new(6, 1); + cluster.wait_for_cluster_up(); + + let mut con = cluster.connection(); + bench_set_get_and_del(c, &mut con); + bench_pipeline(c, &mut con); +} + +#[allow(dead_code)] +fn bench_cluster_read_from_replicas_setup(c: &mut Criterion) { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| builder.read_from_replicas(), + false, + ); + cluster.wait_for_cluster_up(); + + let mut con = cluster.connection(); + bench_set_get_and_del(c, &mut con); + bench_pipeline(c, &mut con); +} + +criterion_group!( + cluster_bench, + bench_cluster_setup, + // bench_cluster_read_from_replicas_setup +); +criterion_main!(cluster_bench); diff --git a/glide-core/redis-rs/redis/benches/bench_cluster_async.rs b/glide-core/redis-rs/redis/benches/bench_cluster_async.rs new file mode 100644 index 0000000000..28c3b83c87 --- /dev/null +++ b/glide-core/redis-rs/redis/benches/bench_cluster_async.rs @@ -0,0 +1,88 @@ +#![allow(clippy::unit_arg)] // want to allow this for `black_box()` +#![cfg(feature = "cluster")] +use criterion::{black_box, criterion_group, criterion_main, Criterion}; +use futures_util::{stream, TryStreamExt}; +use redis::RedisError; + +use support::*; +use tokio::runtime::Runtime; + +#[path = "../tests/support/mod.rs"] +mod support; + +fn bench_cluster_async( + c: &mut Criterion, + con: &mut redis::cluster_async::ClusterConnection, + runtime: &Runtime, +) { + let mut group = c.benchmark_group("cluster_async"); + group.bench_function("set_get_and_del", |b| { + b.iter(|| { + runtime + .block_on(async { + let key = "test_key"; + () = redis::cmd("SET").arg(key).arg(42).query_async(con).await?; + let _: isize = redis::cmd("GET").arg(key).query_async(con).await?; + () = redis::cmd("DEL").arg(key).query_async(con).await?; + + Ok::<_, RedisError>(()) + }) + .unwrap(); + black_box(()) + }) + }); + + group.bench_function("parallel_requests", |b| { + let num_parallel = 100; + let cmds: Vec<_> = (0..num_parallel) + .map(|i| redis::cmd("SET").arg(format!("foo{i}")).arg(i).clone()) + .collect(); + + let mut connections = (0..num_parallel).map(|_| con.clone()).collect::>(); + + b.iter(|| { + runtime + .block_on(async { + cmds.iter() + .zip(&mut connections) + .map(|(cmd, con)| cmd.query_async::<_, ()>(con)) + .collect::>() + .try_for_each(|()| async { Ok(()) }) + .await + }) + .unwrap(); + black_box(()) + }); + }); + + group.bench_function("pipeline", |b| { + let num_queries = 100; + + let mut pipe = redis::pipe(); + + for _ in 0..num_queries { + pipe.set("foo".to_string(), "bar").ignore(); + } + + b.iter(|| { + runtime + .block_on(async { pipe.query_async::<_, ()>(con).await }) + .unwrap(); + black_box(()) + }); + }); + + group.finish(); +} + +fn bench_cluster_setup(c: &mut Criterion) { + let cluster = TestClusterContext::new(6, 1); + cluster.wait_for_cluster_up(); + let runtime = current_thread_runtime(); + let mut con = runtime.block_on(cluster.async_connection(None)); + + bench_cluster_async(c, &mut con, &runtime); +} + +criterion_group!(cluster_async_bench, bench_cluster_setup,); +criterion_main!(cluster_async_bench); diff --git a/glide-core/redis-rs/redis/examples/async-await.rs b/glide-core/redis-rs/redis/examples/async-await.rs new file mode 100644 index 0000000000..2d829c7d60 --- /dev/null +++ b/glide-core/redis-rs/redis/examples/async-await.rs @@ -0,0 +1,24 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use redis::{AsyncCommands, GlideConnectionOptions}; + +#[tokio::main] +async fn main() -> redis::RedisResult<()> { + let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + let mut con = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + con.set("key1", b"foo").await?; + + redis::cmd("SET") + .arg(&["key2", "bar"]) + .query_async(&mut con) + .await?; + + let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; + assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + Ok(()) +} diff --git a/glide-core/redis-rs/redis/examples/async-connection-loss.rs b/glide-core/redis-rs/redis/examples/async-connection-loss.rs new file mode 100644 index 0000000000..a7dba3ab89 --- /dev/null +++ b/glide-core/redis-rs/redis/examples/async-connection-loss.rs @@ -0,0 +1,97 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +//! This example will connect to Redis in one of three modes: +//! +//! - Regular async connection +//! - Async multiplexed connection +//! - Async connection manager +//! +//! It will then send a PING every 100 ms and print the result. + +use std::env; +use std::process; +use std::time::Duration; + +use futures::future; +use redis::aio::ConnectionLike; +use redis::GlideConnectionOptions; +use redis::RedisResult; +use tokio::time::interval; + +enum Mode { + Deprecated, + Default, + Reconnect, +} + +async fn run_single(mut con: C) -> RedisResult<()> { + let mut interval = interval(Duration::from_millis(100)); + loop { + interval.tick().await; + println!(); + println!("> PING"); + let result: RedisResult = redis::cmd("PING").query_async(&mut con).await; + println!("< {result:?}"); + } +} + +async fn run_multi(mut con: C) -> RedisResult<()> { + let mut interval = interval(Duration::from_millis(100)); + loop { + interval.tick().await; + println!(); + println!("> PING"); + println!("> PING"); + println!("> PING"); + let results: ( + RedisResult, + RedisResult, + RedisResult, + ) = future::join3( + redis::cmd("PING").query_async(&mut con.clone()), + redis::cmd("PING").query_async(&mut con.clone()), + redis::cmd("PING").query_async(&mut con), + ) + .await; + println!("< {:?}", results.0); + println!("< {:?}", results.1); + println!("< {:?}", results.2); + } +} + +#[tokio::main] +async fn main() -> RedisResult<()> { + let mode = match env::args().nth(1).as_deref() { + Some("default") => { + println!("Using default connection mode\n"); + Mode::Default + } + Some("reconnect") => { + println!("Using reconnect manager mode\n"); + Mode::Reconnect + } + Some("deprecated") => { + println!("Using deprecated connection mode\n"); + Mode::Deprecated + } + Some(_) | None => { + println!("Usage: reconnect-manager (default|multiplexed|reconnect)"); + process::exit(1); + } + }; + + let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + match mode { + Mode::Default => { + run_multi( + client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await?, + ) + .await? + } + Mode::Reconnect => run_multi(client.get_connection_manager().await?).await?, + #[allow(deprecated)] + Mode::Deprecated => run_single(client.get_async_connection(None).await?).await?, + }; + Ok(()) +} diff --git a/glide-core/redis-rs/redis/examples/async-multiplexed.rs b/glide-core/redis-rs/redis/examples/async-multiplexed.rs new file mode 100644 index 0000000000..2e5332359b --- /dev/null +++ b/glide-core/redis-rs/redis/examples/async-multiplexed.rs @@ -0,0 +1,46 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use futures::prelude::*; +use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult}; + +async fn test_cmd(con: &MultiplexedConnection, i: i32) -> RedisResult<()> { + let mut con = con.clone(); + + let key = format!("key{i}"); + let key2 = format!("key{i}_2"); + let value = format!("foo{i}"); + + redis::cmd("SET") + .arg(&key) + .arg(&value) + .query_async(&mut con) + .await?; + + redis::cmd("SET") + .arg(&[&key2, "bar"]) + .query_async(&mut con) + .await?; + + redis::cmd("MGET") + .arg(&[&key, &key2]) + .query_async(&mut con) + .map(|result| { + assert_eq!(Ok((value, b"bar".to_vec())), result); + Ok(()) + }) + .await +} + +#[tokio::main] +async fn main() { + let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + + let con = client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await + .unwrap(); + + let cmds = (0..100).map(|i| test_cmd(&con, i)); + let result = future::try_join_all(cmds).await.unwrap(); + + assert_eq!(100, result.len()); +} diff --git a/glide-core/redis-rs/redis/examples/async-pub-sub.rs b/glide-core/redis-rs/redis/examples/async-pub-sub.rs new file mode 100644 index 0000000000..fe84b44fb9 --- /dev/null +++ b/glide-core/redis-rs/redis/examples/async-pub-sub.rs @@ -0,0 +1,22 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use futures_util::StreamExt as _; +use redis::{AsyncCommands, GlideConnectionOptions}; + +#[tokio::main] +async fn main() -> redis::RedisResult<()> { + let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + let mut publish_conn = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + let mut pubsub_conn = client.get_async_pubsub().await?; + + pubsub_conn.subscribe("wavephone").await?; + let mut pubsub_stream = pubsub_conn.on_message(); + + publish_conn.publish("wavephone", "banana").await?; + + let pubsub_msg: String = pubsub_stream.next().await.unwrap().get_payload()?; + assert_eq!(&pubsub_msg, "banana"); + + Ok(()) +} diff --git a/glide-core/redis-rs/redis/examples/async-scan.rs b/glide-core/redis-rs/redis/examples/async-scan.rs new file mode 100644 index 0000000000..06a66fe83e --- /dev/null +++ b/glide-core/redis-rs/redis/examples/async-scan.rs @@ -0,0 +1,25 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use futures::stream::StreamExt; +use redis::{AsyncCommands, AsyncIter, GlideConnectionOptions}; + +#[tokio::main] +async fn main() -> redis::RedisResult<()> { + let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + let mut con = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + con.set("async-key1", b"foo").await?; + con.set("async-key2", b"foo").await?; + + let iter: AsyncIter = con.scan().await?; + let mut keys: Vec<_> = iter.collect().await; + + keys.sort(); + + assert_eq!( + keys, + vec!["async-key1".to_string(), "async-key2".to_string()] + ); + Ok(()) +} diff --git a/glide-core/redis-rs/redis/examples/basic.rs b/glide-core/redis-rs/redis/examples/basic.rs new file mode 100644 index 0000000000..622dc36e59 --- /dev/null +++ b/glide-core/redis-rs/redis/examples/basic.rs @@ -0,0 +1,169 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use redis::{transaction, Commands}; + +use std::collections::HashMap; +use std::env; + +/// This function demonstrates how a return value can be coerced into a +/// hashmap of tuples. This is particularly useful for responses like +/// CONFIG GET or all most H functions which will return responses in +/// such list of implied tuples. +fn do_print_max_entry_limits(con: &mut redis::Connection) -> redis::RedisResult<()> { + // since rust cannot know what format we actually want we need to be + // explicit here and define the type of our response. In this case + // String -> int fits all the items we query for. + let config: HashMap = redis::cmd("CONFIG") + .arg("GET") + .arg("*-max-*-entries") + .query(con)?; + + println!("Max entry limits:"); + + println!( + " max-intset: {}", + config.get("set-max-intset-entries").unwrap_or(&0) + ); + println!( + " hash-max-ziplist: {}", + config.get("hash-max-ziplist-entries").unwrap_or(&0) + ); + println!( + " list-max-ziplist: {}", + config.get("list-max-ziplist-entries").unwrap_or(&0) + ); + println!( + " zset-max-ziplist: {}", + config.get("zset-max-ziplist-entries").unwrap_or(&0) + ); + + Ok(()) +} + +/// This is a pretty stupid example that demonstrates how to create a large +/// set through a pipeline and how to iterate over it through implied +/// cursors. +fn do_show_scanning(con: &mut redis::Connection) -> redis::RedisResult<()> { + // This makes a large pipeline of commands. Because the pipeline is + // modified in place we can just ignore the return value upon the end + // of each iteration. + let mut pipe = redis::pipe(); + for num in 0..1000 { + pipe.cmd("SADD").arg("my_set").arg(num).ignore(); + } + + // since we don't care about the return value of the pipeline we can + // just cast it into the unit type. + pipe.query(con)?; + + // since rust currently does not track temporaries for us, we need to + // store it in a local variable. + let mut cmd = redis::cmd("SSCAN"); + cmd.arg("my_set").cursor_arg(0); + + // as a simple exercise we just sum up the iterator. Since the fold + // method carries an initial value we do not need to define the + // type of the iterator, rust will figure "int" out for us. + let sum: i32 = cmd.iter::(con)?.sum(); + + println!("The sum of all numbers in the set 0-1000: {sum}"); + + Ok(()) +} + +/// Demonstrates how to do an atomic increment in a very low level way. +fn do_atomic_increment_lowlevel(con: &mut redis::Connection) -> redis::RedisResult<()> { + let key = "the_key"; + println!("Run low-level atomic increment:"); + + // set the initial value so we have something to test with. + redis::cmd("SET").arg(key).arg(42).query(con)?; + + loop { + // we need to start watching the key we care about, so that our + // exec fails if the key changes. + redis::cmd("WATCH").arg(key).query(con)?; + + // load the old value, so we know what to increment. + let val: isize = redis::cmd("GET").arg(key).query(con)?; + + // at this point we can go into an atomic pipe (a multi block) + // and set up the keys. + let response: Option<(isize,)> = redis::pipe() + .atomic() + .cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(con)?; + + match response { + None => { + continue; + } + Some(response) => { + let (new_val,) = response; + println!(" New value: {new_val}"); + break; + } + } + } + + Ok(()) +} + +/// Demonstrates how to do an atomic increment with transaction support. +fn do_atomic_increment(con: &mut redis::Connection) -> redis::RedisResult<()> { + let key = "the_key"; + println!("Run high-level atomic increment:"); + + // set the initial value so we have something to test with. + con.set(key, 42)?; + + // run the transaction block. + let (new_val,): (isize,) = transaction(con, &[key], |con, pipe| { + // load the old value, so we know what to increment. + let val: isize = con.get(key)?; + // increment + pipe.set(key, val + 1).ignore().get(key).query(con) + })?; + + // and print the result + println!("New value: {new_val}"); + + Ok(()) +} + +/// Runs all the examples and propagates errors up. +fn do_redis_code(url: &str) -> redis::RedisResult<()> { + // general connection handling + let client = redis::Client::open(url)?; + let mut con = client.get_connection(None)?; + + // read some config and print it. + do_print_max_entry_limits(&mut con)?; + + // demonstrate how scanning works. + do_show_scanning(&mut con)?; + + // shows an atomic increment. + do_atomic_increment_lowlevel(&mut con)?; + do_atomic_increment(&mut con)?; + + Ok(()) +} + +fn main() { + // at this point the errors are fatal, let's just fail hard. + let url = if env::args().nth(1) == Some("--tls".into()) { + "rediss://127.0.0.1:6380/#insecure" + } else { + "redis://127.0.0.1:6379/" + }; + + if let Err(err) = do_redis_code(url) { + println!("Could not execute example:"); + println!(" {}: {}", err.category(), err); + } +} diff --git a/glide-core/redis-rs/redis/examples/geospatial.rs b/glide-core/redis-rs/redis/examples/geospatial.rs new file mode 100644 index 0000000000..5033b6c775 --- /dev/null +++ b/glide-core/redis-rs/redis/examples/geospatial.rs @@ -0,0 +1,68 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use std::process::exit; + +use redis::RedisResult; + +#[cfg(feature = "geospatial")] +fn run() -> RedisResult<()> { + use redis::{geo, Commands}; + use std::env; + use std::f64; + + let redis_url = match env::var("REDIS_URL") { + Ok(url) => url, + Err(..) => "redis://127.0.0.1/".to_string(), + }; + + let client = redis::Client::open(redis_url.as_str())?; + let mut con = client.get_connection(None)?; + + // Add some members to the geospatial index. + + let added: isize = con.geo_add( + "gis", + &[ + (geo::Coord::lon_lat("13.361389", "38.115556"), "Palermo"), + (geo::Coord::lon_lat("15.087269", "37.502669"), "Catania"), + (geo::Coord::lon_lat("13.5833332", "37.316667"), "Agrigento"), + ], + )?; + + println!("[geo_add] Added {added} members."); + + // Get the position of one of them. + + let position: Vec> = con.geo_pos("gis", "Palermo")?; + println!("[geo_pos] Position for Palermo: {position:?}"); + + // Search members near (13.5, 37.75) + + let options = geo::RadiusOptions::default() + .order(geo::RadiusOrder::Asc) + .with_dist() + .limit(2); + let items: Vec = + con.geo_radius("gis", 13.5, 37.75, 150.0, geo::Unit::Kilometers, options)?; + + for item in items { + println!( + "[geo_radius] {}, dist = {} Km", + item.name, + item.dist.unwrap_or(f64::NAN) + ); + } + + Ok(()) +} + +#[cfg(not(feature = "geospatial"))] +fn run() -> RedisResult<()> { + Ok(()) +} + +fn main() { + if let Err(e) = run() { + println!("{e:?}"); + exit(1); + } +} diff --git a/glide-core/redis-rs/redis/examples/streams.rs b/glide-core/redis-rs/redis/examples/streams.rs new file mode 100644 index 0000000000..8c40ea487d --- /dev/null +++ b/glide-core/redis-rs/redis/examples/streams.rs @@ -0,0 +1,270 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "streams")] + +use redis::streams::{StreamId, StreamKey, StreamMaxlen, StreamReadOptions, StreamReadReply}; + +use redis::{Commands, RedisResult, Value}; + +use std::thread; +use std::time::Duration; +use std::time::{SystemTime, UNIX_EPOCH}; + +const DOG_STREAM: &str = "example-dog"; +const CAT_STREAM: &str = "example-cat"; +const DUCK_STREAM: &str = "example-duck"; + +const STREAMS: &[&str] = &[DOG_STREAM, CAT_STREAM, DUCK_STREAM]; + +const SLOWNESSES: &[u8] = &[2, 3, 4]; + +/// This program generates an arbitrary set of records across three +/// different streams. It then reads the data back in such a way +/// that demonstrates basic usage of both the XREAD and XREADGROUP +/// commands. +fn main() { + let client = redis::Client::open("redis://127.0.0.1/").expect("client"); + + println!("Demonstrating XADD followed by XREAD, single threaded\n"); + + add_records(&client).expect("contrived record generation"); + + read_records(&client).expect("simple read"); + + demo_group_reads(&client); + + clean_up(&client) +} + +fn demo_group_reads(client: &redis::Client) { + println!("\n\nDemonstrating a longer stream of data flowing\nin over time, consumed by multiple threads using XREADGROUP\n"); + + let mut handles = vec![]; + + let cc = client.clone(); + // Launch a producer thread which repeatedly adds records, + // with only a small delay between writes. + handles.push(thread::spawn(move || { + let repeat = 30; + let slowness = 1; + for _ in 0..repeat { + add_records(&cc).expect("add"); + thread::sleep(Duration::from_millis(random_wait_millis(slowness))) + } + })); + + // Launch consumer threads which repeatedly read from the + // streams at various speeds. They'll effectively compete + // to consume the stream. + // + // Consumer groups are only appropriate for cases where you + // do NOT want each consumer to read ALL of the data. This + // example is a contrived scenario so that each consumer + // receives its own, specific chunk of data. + // + // Once the data is read, the redis-rs lib will automatically + // acknowledge its receipt via XACK. + // + // Read more about reading with consumer groups here: + // https://redis.io/commands/xreadgroup + for slowness in SLOWNESSES { + let repeat = 5; + let ca = client.clone(); + handles.push(thread::spawn(move || { + let mut con = ca.get_connection(None).expect("con"); + + // We must create each group and each consumer + // See https://redis.io/commands/xreadgroup#differences-between-xread-and-xreadgroup + + for key in STREAMS { + let created: Result<(), _> = con.xgroup_create_mkstream(*key, GROUP_NAME, "$"); + if let Err(e) = created { + println!("Group already exists: {e:?}") + } + } + + for _ in 0..repeat { + let read_reply = read_group_records(&ca, *slowness).expect("group read"); + + // fake some expensive work + for StreamKey { key, ids } in read_reply.keys { + for StreamId { id, map: _ } in &ids { + thread::sleep(Duration::from_millis(random_wait_millis(*slowness))); + println!( + "Stream {} ID {} Consumer slowness {} SysTime {}", + key, + id, + slowness, + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("time") + .as_millis() + ); + } + + // acknowledge each stream and message ID once all messages are + // correctly processed + let id_strs: Vec<&String> = + ids.iter().map(|StreamId { id, map: _ }| id).collect(); + con.xack(key, GROUP_NAME, &id_strs).expect("ack") + } + } + })) + } + + for h in handles { + h.join().expect("Join") + } +} + +/// Generate some contrived records and add them to various +/// streams. +fn add_records(client: &redis::Client) -> RedisResult<()> { + let mut con = client.get_connection(None).expect("conn"); + + let maxlen = StreamMaxlen::Approx(1000); + + // a stream whose records have two fields + for _ in 0..thrifty_rand() { + con.xadd_maxlen( + DOG_STREAM, + maxlen, + "*", + &[("bark", arbitrary_value()), ("groom", arbitrary_value())], + )?; + } + + // a streams whose records have three fields + for _ in 0..thrifty_rand() { + con.xadd_maxlen( + CAT_STREAM, + maxlen, + "*", + &[ + ("meow", arbitrary_value()), + ("groom", arbitrary_value()), + ("hunt", arbitrary_value()), + ], + )?; + } + + // a streams whose records have four fields + for _ in 0..thrifty_rand() { + con.xadd_maxlen( + DUCK_STREAM, + maxlen, + "*", + &[ + ("quack", arbitrary_value()), + ("waddle", arbitrary_value()), + ("splash", arbitrary_value()), + ("flap", arbitrary_value()), + ], + )?; + } + + Ok(()) +} + +/// An approximation of randomness, without leaving the stdlib. +fn thrifty_rand() -> u8 { + let penultimate_num = 2; + (SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time travel") + .as_nanos() + % penultimate_num) as u8 + + 1 +} + +const MAGIC: u64 = 11; +fn random_wait_millis(slowness: u8) -> u64 { + thrifty_rand() as u64 * thrifty_rand() as u64 * MAGIC * slowness as u64 +} + +/// Generate a potentially unique value. +fn arbitrary_value() -> String { + format!( + "{}", + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("Time travel") + .as_nanos() + ) +} + +/// Block the thread for this many milliseconds while +/// waiting for data to arrive on the stream. +const BLOCK_MILLIS: usize = 5000; + +/// Read back records from all three streams, if they're available. +/// Doesn't bother with consumer groups. Generally the user +/// would be responsible for keeping track of the most recent +/// ID from which they need to read, but in this example, we +/// just go back to the beginning of time and ask for all the +/// records in the stream. +fn read_records(client: &redis::Client) -> RedisResult<()> { + let mut con = client.get_connection(None).expect("conn"); + + let opts = StreamReadOptions::default().block(BLOCK_MILLIS); + + // Oldest known time index + let starting_id = "0-0"; + // Same as above + let another_form = "0"; + + let srr: StreamReadReply = con + .xread_options(STREAMS, &[starting_id, another_form, starting_id], &opts) + .expect("read"); + + for StreamKey { key, ids } in srr.keys { + println!("Stream {key}"); + for StreamId { id, map } in ids { + println!("\tID {id}"); + for (n, s) in map { + if let Value::BulkString(bytes) = s { + println!("\t\t{}: {}", n, String::from_utf8(bytes).expect("utf8")) + } else { + panic!("Weird data") + } + } + } + } + + Ok(()) +} + +fn consumer_name(slowness: u8) -> String { + format!("example-consumer-{slowness}") +} + +const GROUP_NAME: &str = "example-group-aaa"; + +fn read_group_records(client: &redis::Client, slowness: u8) -> RedisResult { + let mut con = client.get_connection(None).expect("conn"); + + let opts = StreamReadOptions::default() + .block(BLOCK_MILLIS) + .count(3) + .group(GROUP_NAME, consumer_name(slowness)); + + let srr: StreamReadReply = con + .xread_options( + &[DOG_STREAM, CAT_STREAM, DUCK_STREAM], + &[">", ">", ">"], + &opts, + ) + .expect("records"); + + Ok(srr) +} + +fn clean_up(client: &redis::Client) { + let mut con = client.get_connection(None).expect("con"); + for k in STREAMS { + let trimmed: RedisResult<()> = con.xtrim(*k, StreamMaxlen::Equals(0)); + trimmed.expect("trim"); + + let destroyed: RedisResult<()> = con.xgroup_destroy(*k, GROUP_NAME); + destroyed.expect("xgroup destroy"); + } +} diff --git a/glide-core/redis-rs/redis/release.toml b/glide-core/redis-rs/redis/release.toml new file mode 100644 index 0000000000..942730e0b6 --- /dev/null +++ b/glide-core/redis-rs/redis/release.toml @@ -0,0 +1,2 @@ +pre-release-hook = "../scripts/update-versions.sh" +tag-name = "{{version}}" diff --git a/glide-core/redis-rs/redis/src/acl.rs b/glide-core/redis-rs/redis/src/acl.rs new file mode 100644 index 0000000000..ef85877ba6 --- /dev/null +++ b/glide-core/redis-rs/redis/src/acl.rs @@ -0,0 +1,312 @@ +//! Defines types to use with the ACL commands. + +use crate::types::{ + ErrorKind, FromRedisValue, RedisError, RedisResult, RedisWrite, ToRedisArgs, Value, +}; + +macro_rules! not_convertible_error { + ($v:expr, $det:expr) => { + RedisError::from(( + ErrorKind::TypeError, + "Response type not convertible", + format!("{:?} (response was {:?})", $det, $v), + )) + }; +} + +/// ACL rules are used in order to activate or remove a flag, or to perform a +/// given change to the user ACL, which under the hood are just single words. +#[derive(Debug, Eq, PartialEq)] +pub enum Rule { + /// Enable the user: it is possible to authenticate as this user. + On, + /// Disable the user: it's no longer possible to authenticate with this + /// user, however the already authenticated connections will still work. + Off, + + /// Add the command to the list of commands the user can call. + AddCommand(String), + /// Remove the command to the list of commands the user can call. + RemoveCommand(String), + /// Add all the commands in such category to be called by the user. + AddCategory(String), + /// Remove the commands from such category the client can call. + RemoveCategory(String), + /// Alias for `+@all`. Note that it implies the ability to execute all the + /// future commands loaded via the modules system. + AllCommands, + /// Alias for `-@all`. + NoCommands, + + /// Add this password to the list of valid password for the user. + AddPass(String), + /// Remove this password from the list of valid passwords. + RemovePass(String), + /// Add this SHA-256 hash value to the list of valid passwords for the user. + AddHashedPass(String), + /// Remove this hash value from from the list of valid passwords + RemoveHashedPass(String), + /// All the set passwords of the user are removed, and the user is flagged + /// as requiring no password: it means that every password will work + /// against this user. + NoPass, + /// Flush the list of allowed passwords. Moreover removes the _nopass_ status. + ResetPass, + + /// Add a pattern of keys that can be mentioned as part of commands. + Pattern(String), + /// Alias for `~*`. + AllKeys, + /// Flush the list of allowed keys patterns. + ResetKeys, + + /// Performs the following actions: `resetpass`, `resetkeys`, `off`, `-@all`. + /// The user returns to the same state it has immediately after its creation. + Reset, + + /// Raw text of [`ACL rule`][1] that not enumerated above. + /// + /// [1]: https://redis.io/docs/manual/security/acl + Other(String), +} + +impl ToRedisArgs for Rule { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + use self::Rule::*; + + match self { + On => out.write_arg(b"on"), + Off => out.write_arg(b"off"), + + AddCommand(cmd) => out.write_arg_fmt(format_args!("+{cmd}")), + RemoveCommand(cmd) => out.write_arg_fmt(format_args!("-{cmd}")), + AddCategory(cat) => out.write_arg_fmt(format_args!("+@{cat}")), + RemoveCategory(cat) => out.write_arg_fmt(format_args!("-@{cat}")), + AllCommands => out.write_arg(b"allcommands"), + NoCommands => out.write_arg(b"nocommands"), + + AddPass(pass) => out.write_arg_fmt(format_args!(">{pass}")), + RemovePass(pass) => out.write_arg_fmt(format_args!("<{pass}")), + AddHashedPass(pass) => out.write_arg_fmt(format_args!("#{pass}")), + RemoveHashedPass(pass) => out.write_arg_fmt(format_args!("!{pass}")), + NoPass => out.write_arg(b"nopass"), + ResetPass => out.write_arg(b"resetpass"), + + Pattern(pat) => out.write_arg_fmt(format_args!("~{pat}")), + AllKeys => out.write_arg(b"allkeys"), + ResetKeys => out.write_arg(b"resetkeys"), + + Reset => out.write_arg(b"reset"), + + Other(rule) => out.write_arg(rule.as_bytes()), + }; + } +} + +/// An info dictionary type storing Redis ACL information as multiple `Rule`. +/// This type collects key/value data returned by the [`ACL GETUSER`][1] command. +/// +/// [1]: https://redis.io/commands/acl-getuser +#[derive(Debug, Eq, PartialEq)] +pub struct AclInfo { + /// Describes flag rules for the user. Represented by [`Rule::On`][1], + /// [`Rule::Off`][2], [`Rule::AllKeys`][3], [`Rule::AllCommands`][4] and + /// [`Rule::NoPass`][5]. + /// + /// [1]: ./enum.Rule.html#variant.On + /// [2]: ./enum.Rule.html#variant.Off + /// [3]: ./enum.Rule.html#variant.AllKeys + /// [4]: ./enum.Rule.html#variant.AllCommands + /// [5]: ./enum.Rule.html#variant.NoPass + pub flags: Vec, + /// Describes the user's passwords. Represented by [`Rule::AddHashedPass`][1]. + /// + /// [1]: ./enum.Rule.html#variant.AddHashedPass + pub passwords: Vec, + /// Describes capabilities of which commands the user can call. + /// Represented by [`Rule::AddCommand`][1], [`Rule::AddCategory`][2], + /// [`Rule::RemoveCommand`][3] and [`Rule::RemoveCategory`][4]. + /// + /// [1]: ./enum.Rule.html#variant.AddCommand + /// [2]: ./enum.Rule.html#variant.AddCategory + /// [3]: ./enum.Rule.html#variant.RemoveCommand + /// [4]: ./enum.Rule.html#variant.RemoveCategory + pub commands: Vec, + /// Describes patterns of keys which the user can access. Represented by + /// [`Rule::Pattern`][1]. + /// + /// [1]: ./enum.Rule.html#variant.Pattern + pub keys: Vec, +} + +impl FromRedisValue for AclInfo { + fn from_redis_value(v: &Value) -> RedisResult { + let mut it = v + .as_sequence() + .ok_or_else(|| not_convertible_error!(v, ""))? + .iter() + .skip(1) + .step_by(2); + + let (flags, passwords, commands, keys) = match (it.next(), it.next(), it.next(), it.next()) + { + (Some(flags), Some(passwords), Some(commands), Some(keys)) => { + // Parse flags + // Ref: https://github.com/redis/redis/blob/0cabe0cfa7290d9b14596ec38e0d0a22df65d1df/src/acl.c#L83-L90 + let flags = flags + .as_sequence() + .ok_or_else(|| { + not_convertible_error!(flags, "Expect an array response of ACL flags") + })? + .iter() + .map(|flag| match flag { + Value::BulkString(flag) => match flag.as_slice() { + b"on" => Ok(Rule::On), + b"off" => Ok(Rule::Off), + b"allkeys" => Ok(Rule::AllKeys), + b"allcommands" => Ok(Rule::AllCommands), + b"nopass" => Ok(Rule::NoPass), + other => Ok(Rule::Other(String::from_utf8_lossy(other).into_owned())), + }, + _ => Err(not_convertible_error!( + flag, + "Expect an arbitrary binary data" + )), + }) + .collect::>()?; + + let passwords = passwords + .as_sequence() + .ok_or_else(|| { + not_convertible_error!(flags, "Expect an array response of ACL flags") + })? + .iter() + .map(|pass| Ok(Rule::AddHashedPass(String::from_redis_value(pass)?))) + .collect::>()?; + + let commands = match commands { + Value::BulkString(cmd) => std::str::from_utf8(cmd)?, + _ => { + return Err(not_convertible_error!( + commands, + "Expect a valid UTF8 string" + )) + } + } + .split_terminator(' ') + .map(|cmd| match cmd { + x if x.starts_with("+@") => Ok(Rule::AddCategory(x[2..].to_owned())), + x if x.starts_with("-@") => Ok(Rule::RemoveCategory(x[2..].to_owned())), + x if x.starts_with('+') => Ok(Rule::AddCommand(x[1..].to_owned())), + x if x.starts_with('-') => Ok(Rule::RemoveCommand(x[1..].to_owned())), + _ => Err(not_convertible_error!( + cmd, + "Expect a command addition/removal" + )), + }) + .collect::>()?; + + let keys = keys + .as_sequence() + .ok_or_else(|| not_convertible_error!(keys, ""))? + .iter() + .map(|pat| Ok(Rule::Pattern(String::from_redis_value(pat)?))) + .collect::>()?; + + (flags, passwords, commands, keys) + } + _ => { + return Err(not_convertible_error!( + v, + "Expect a resposne from `ACL GETUSER`" + )) + } + }; + + Ok(Self { + flags, + passwords, + commands, + keys, + }) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! assert_args { + ($rule:expr, $arg:expr) => { + assert_eq!($rule.to_redis_args(), vec![$arg.to_vec()]); + }; + } + + #[test] + fn test_rule_to_arg() { + use self::Rule::*; + + assert_args!(On, b"on"); + assert_args!(Off, b"off"); + assert_args!(AddCommand("set".to_owned()), b"+set"); + assert_args!(RemoveCommand("set".to_owned()), b"-set"); + assert_args!(AddCategory("hyperloglog".to_owned()), b"+@hyperloglog"); + assert_args!(RemoveCategory("hyperloglog".to_owned()), b"-@hyperloglog"); + assert_args!(AllCommands, b"allcommands"); + assert_args!(NoCommands, b"nocommands"); + assert_args!(AddPass("mypass".to_owned()), b">mypass"); + assert_args!(RemovePass("mypass".to_owned()), b" io::Result { + let socket = TcpStream::connect(addr).await?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let mut std_socket = std::net::TcpStream::try_from(socket)?; + let socket2: socket2::Socket = std_socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + std_socket = socket2.into(); + Ok(std_socket.into()) + } + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use crate::connection::TlsConnParams; + +pin_project_lite::pin_project! { + /// Wraps the async_std `AsyncRead/AsyncWrite` in order to implement the required the tokio traits + /// for it + pub struct AsyncStdWrapped { #[pin] inner: T } +} + +impl AsyncStdWrapped { + pub(super) fn new(inner: T) -> Self { + Self { inner } + } +} + +impl AsyncWrite for AsyncStdWrapped +where + T: async_std::io::Write, +{ + fn poll_write( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + buf: &[u8], + ) -> std::task::Poll> { + async_std::io::Write::poll_write(self.project().inner, cx, buf) + } + + fn poll_flush( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + ) -> std::task::Poll> { + async_std::io::Write::poll_flush(self.project().inner, cx) + } + + fn poll_shutdown( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + ) -> std::task::Poll> { + async_std::io::Write::poll_close(self.project().inner, cx) + } +} + +impl AsyncRead for AsyncStdWrapped +where + T: async_std::io::Read, +{ + fn poll_read( + self: Pin<&mut Self>, + cx: &mut core::task::Context, + buf: &mut ReadBuf<'_>, + ) -> std::task::Poll> { + let n = ready!(async_std::io::Read::poll_read( + self.project().inner, + cx, + buf.initialize_unfilled() + ))?; + buf.advance(n); + std::task::Poll::Ready(Ok(())) + } +} + +/// Represents an AsyncStd connectable +pub enum AsyncStd { + /// Represents an Async_std TCP connection. + Tcp(AsyncStdWrapped), + /// Represents an Async_std TLS encrypted TCP connection. + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] + TcpTls(AsyncStdWrapped>>), + /// Represents an Async_std Unix connection. + #[cfg(unix)] + Unix(AsyncStdWrapped), +} + +impl AsyncWrite for AsyncStd { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &[u8], + ) -> Poll> { + match &mut *self { + AsyncStd::Tcp(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] + AsyncStd::TcpTls(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(unix)] + AsyncStd::Unix(r) => Pin::new(r).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + AsyncStd::Tcp(r) => Pin::new(r).poll_flush(cx), + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] + AsyncStd::TcpTls(r) => Pin::new(r).poll_flush(cx), + #[cfg(unix)] + AsyncStd::Unix(r) => Pin::new(r).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + AsyncStd::Tcp(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] + AsyncStd::TcpTls(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(unix)] + AsyncStd::Unix(r) => Pin::new(r).poll_shutdown(cx), + } + } +} + +impl AsyncRead for AsyncStd { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + AsyncStd::Tcp(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] + AsyncStd::TcpTls(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(unix)] + AsyncStd::Unix(r) => Pin::new(r).poll_read(cx, buf), + } + } +} + +#[async_trait] +impl RedisRuntime for AsyncStd { + async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult { + Ok(connect_tcp(&socket_addr) + .await + .map(|con| Self::Tcp(AsyncStdWrapped::new(con)))?) + } + + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + _tls_params: &Option, + ) -> RedisResult { + let tcp_stream = connect_tcp(&socket_addr).await?; + let tls_connector = if insecure { + TlsConnector::new() + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true) + .use_sni(false) + } else { + TlsConnector::new() + }; + Ok(tls_connector + .connect(hostname, tcp_stream) + .await + .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?) + } + + #[cfg(feature = "tls-rustls")] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + ) -> RedisResult { + let tcp_stream = connect_tcp(&socket_addr).await?; + + let config = create_rustls_config(insecure, tls_params.clone())?; + let tls_connector = TlsConnector::from(Arc::new(config)); + + Ok(tls_connector + .connect( + rustls_pki_types::ServerName::try_from(hostname)?.to_owned(), + tcp_stream, + ) + .await + .map(|con| Self::TcpTls(AsyncStdWrapped::new(Box::new(con))))?) + } + + #[cfg(unix)] + async fn connect_unix(path: &Path) -> RedisResult { + Ok(UnixStream::connect(path) + .await + .map(|con| Self::Unix(AsyncStdWrapped::new(con)))?) + } + + fn spawn(f: impl Future + Send + 'static) { + async_std::task::spawn(f); + } + + fn boxed(self) -> Pin> { + match self { + AsyncStd::Tcp(x) => Box::pin(x), + #[cfg(any( + feature = "async-std-native-tls-comp", + feature = "async-std-rustls-comp" + ))] + AsyncStd::TcpTls(x) => Box::pin(x), + #[cfg(unix)] + AsyncStd::Unix(x) => Box::pin(x), + } + } +} diff --git a/glide-core/redis-rs/redis/src/aio/connection.rs b/glide-core/redis-rs/redis/src/aio/connection.rs new file mode 100644 index 0000000000..6b1f6e657a --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/connection.rs @@ -0,0 +1,543 @@ +#![allow(deprecated)] + +#[cfg(feature = "async-std-comp")] +use super::async_std; +use super::ConnectionLike; +use super::{setup_connection, AsyncStream, RedisRuntime}; +use crate::cmd::{cmd, Cmd}; +use crate::connection::{ + resp2_is_pub_sub_state_cleared, resp3_is_pub_sub_state_cleared, ConnectionAddr, ConnectionInfo, + Msg, RedisConnectionInfo, +}; +#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +use crate::parser::ValueCodec; +use crate::types::{ErrorKind, FromRedisValue, RedisError, RedisFuture, RedisResult, Value}; +use crate::{from_owned_redis_value, ProtocolVersion, ToRedisArgs}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use ::async_std::net::ToSocketAddrs; +use ::tokio::io::{AsyncRead, AsyncWrite, AsyncWriteExt}; +#[cfg(feature = "tokio-comp")] +use ::tokio::net::lookup_host; +use combine::{parser::combinator::AnySendSyncPartialState, stream::PointerOffset}; +use futures_util::future::select_ok; +use futures_util::{ + future::FutureExt, + stream::{Stream, StreamExt}, +}; +use std::net::{IpAddr, SocketAddr}; +use std::pin::Pin; +#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +use tokio_util::codec::Decoder; +use tracing::info; + +/// Represents a stateful redis TCP connection. +#[deprecated(note = "aio::Connection is deprecated. Use aio::MultiplexedConnection instead.")] +pub struct Connection>> { + con: C, + buf: Vec, + decoder: combine::stream::Decoder>, + db: i64, + + // Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`. + // + // This flag is checked when attempting to send a command, and if it's raised, we attempt to + // exit the pubsub state before executing the new request. + pubsub: bool, + + // Field indicating which protocol to use for server communications. + protocol: ProtocolVersion, +} + +fn assert_sync() {} + +#[allow(unused)] +fn test() { + assert_sync::(); +} + +impl Connection { + pub(crate) fn map(self, f: impl FnOnce(C) -> D) -> Connection { + let Self { + con, + buf, + decoder, + db, + pubsub, + protocol, + } = self; + Connection { + con: f(con), + buf, + decoder, + db, + pubsub, + protocol, + } + } +} + +impl Connection +where + C: Unpin + AsyncRead + AsyncWrite + Send, +{ + /// Constructs a new `Connection` out of a `AsyncRead + AsyncWrite` object + /// and a `RedisConnectionInfo` + pub async fn new(connection_info: &RedisConnectionInfo, con: C) -> RedisResult { + let mut rv = Connection { + con, + buf: Vec::new(), + decoder: combine::stream::Decoder::new(), + db: connection_info.db, + pubsub: false, + protocol: connection_info.protocol, + }; + setup_connection(connection_info, &mut rv).await?; + Ok(rv) + } + + /// Converts this [`Connection`] into [`PubSub`]. + pub fn into_pubsub(self) -> PubSub { + PubSub::new(self) + } + + /// Converts this [`Connection`] into [`Monitor`] + pub fn into_monitor(self) -> Monitor { + Monitor::new(self) + } + + /// Fetches a single response from the connection. + async fn read_response(&mut self) -> RedisResult { + crate::parser::parse_redis_value_async(&mut self.decoder, &mut self.con).await + } + + /// Brings [`Connection`] out of `PubSub` mode. + /// + /// This will unsubscribe this [`Connection`] from all subscriptions. + /// + /// If this function returns error then on all command send tries will be performed attempt + /// to exit from `PubSub` mode until it will be successful. + async fn exit_pubsub(&mut self) -> RedisResult<()> { + let res = self.clear_active_subscriptions().await; + if res.is_ok() { + self.pubsub = false; + } else { + // Raise the pubsub flag to indicate the connection is "stuck" in that state. + self.pubsub = true; + } + + res + } + + /// Get the inner connection out of a PubSub + /// + /// Any active subscriptions are unsubscribed. In the event of an error, the connection is + /// dropped. + async fn clear_active_subscriptions(&mut self) -> RedisResult<()> { + // Responses to unsubscribe commands return in a 3-tuple with values + // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs). + // The "count of remaining subs" includes both pattern subscriptions and non pattern + // subscriptions. Thus, to accurately drain all unsubscribe messages received from the + // server, both commands need to be executed at once. + { + // Prepare both unsubscribe commands + let unsubscribe = crate::Pipeline::new() + .add_command(cmd("UNSUBSCRIBE")) + .add_command(cmd("PUNSUBSCRIBE")) + .get_packed_pipeline(); + + // Execute commands + self.con.write_all(&unsubscribe).await?; + } + + // Receive responses + // + // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe + // commands. There may be more responses if there are active subscriptions. In this case, + // messages are received until the _subscription count_ in the responses reach zero. + let mut received_unsub = false; + let mut received_punsub = false; + if self.protocol != ProtocolVersion::RESP2 { + while let Value::Push { kind, data } = + from_owned_redis_value(self.read_response().await?)? + { + if data.len() >= 2 { + if let Value::Int(num) = data[1] { + if resp3_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &kind, + num as isize, + ) { + break; + } + } + } + } + } else { + loop { + let res: (Vec, (), isize) = + from_owned_redis_value(self.read_response().await?)?; + if resp2_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &res.0, + res.2, + ) { + break; + } + } + } + + // Finally, the connection is back in its normal state since all subscriptions were + // cancelled *and* all unsubscribe messages were received. + Ok(()) + } +} + +#[cfg(feature = "async-std-comp")] +#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] +impl Connection> +where + C: Unpin + ::async_std::io::Read + ::async_std::io::Write + Send, +{ + /// Constructs a new `Connection` out of a `async_std::io::AsyncRead + async_std::io::AsyncWrite` object + /// and a `RedisConnectionInfo` + pub async fn new_async_std(connection_info: &RedisConnectionInfo, con: C) -> RedisResult { + Connection::new(connection_info, async_std::AsyncStdWrapped::new(con)).await + } +} + +pub(crate) async fn connect( + connection_info: &ConnectionInfo, + socket_addr: Option, +) -> RedisResult> +where + C: Unpin + RedisRuntime + AsyncRead + AsyncWrite + Send, +{ + let (con, _ip) = connect_simple::(connection_info, socket_addr).await?; + Connection::new(&connection_info.redis, con).await +} + +impl ConnectionLike for Connection +where + C: Unpin + AsyncRead + AsyncWrite + Send, +{ + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + (async move { + if self.pubsub { + self.exit_pubsub().await?; + } + self.buf.clear(); + cmd.write_packed_command(&mut self.buf); + self.con.write_all(&self.buf).await?; + if cmd.is_no_response() { + return Ok(Value::Nil); + } + loop { + match self.read_response().await? { + Value::Push { .. } => continue, + val => return Ok(val), + } + } + }) + .boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + (async move { + if self.pubsub { + self.exit_pubsub().await?; + } + + self.buf.clear(); + cmd.write_packed_pipeline(&mut self.buf); + self.con.write_all(&self.buf).await?; + + let mut first_err = None; + + for _ in 0..offset { + let response = self.read_response().await; + if let Err(err) = response { + if first_err.is_none() { + first_err = Some(err); + } + } + } + + let mut rv = Vec::with_capacity(count); + let mut count = count; + let mut idx = 0; + while idx < count { + let response = self.read_response().await; + match response { + Ok(item) => { + // RESP3 can insert push data between command replies + if let Value::Push { .. } = item { + // if that is the case we have to extend the loop and handle push data + count += 1; + } else { + rv.push(item); + } + } + Err(err) => { + if first_err.is_none() { + first_err = Some(err); + } + } + } + idx += 1; + } + + if let Some(err) = first_err { + Err(err) + } else { + Ok(rv) + } + }) + .boxed() + } + + fn get_db(&self) -> i64 { + self.db + } + + fn is_closed(&self) -> bool { + // always false for AsyncRead + AsyncWrite (cant do better) + false + } +} + +/// Represents a `PubSub` connection. +pub struct PubSub>>(Connection); + +/// Represents a `Monitor` connection. +pub struct Monitor>>(Connection); + +impl PubSub +where + C: Unpin + AsyncRead + AsyncWrite + Send, +{ + fn new(con: Connection) -> Self { + Self(con) + } + + /// Subscribes to a new channel. + pub async fn subscribe(&mut self, channel: T) -> RedisResult<()> { + let mut cmd = cmd("SUBSCRIBE"); + cmd.arg(channel); + if self.0.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + cmd.query_async(&mut self.0).await + } + + /// Subscribes to a new channel with a pattern. + pub async fn psubscribe(&mut self, pchannel: T) -> RedisResult<()> { + let mut cmd = cmd("PSUBSCRIBE"); + cmd.arg(pchannel); + if self.0.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + cmd.query_async(&mut self.0).await + } + + /// Unsubscribes from a channel. + pub async fn unsubscribe(&mut self, channel: T) -> RedisResult<()> { + let mut cmd = cmd("UNSUBSCRIBE"); + cmd.arg(channel); + if self.0.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + cmd.query_async(&mut self.0).await + } + + /// Unsubscribes from a channel with a pattern. + pub async fn punsubscribe(&mut self, pchannel: T) -> RedisResult<()> { + let mut cmd = cmd("PUNSUBSCRIBE"); + cmd.arg(pchannel); + if self.0.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + cmd.query_async(&mut self.0).await + } + + /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions. + /// + /// The message itself is still generic and can be converted into an appropriate type through + /// the helper methods on it. + pub fn on_message(&mut self) -> impl Stream + '_ { + ValueCodec::default() + .framed(&mut self.0.con) + .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) })) + } + + /// Returns [`Stream`] of [`Msg`]s from this [`PubSub`]s subscriptions consuming it. + /// + /// The message itself is still generic and can be converted into an appropriate type through + /// the helper methods on it. + /// This can be useful in cases where the stream needs to be returned or held by something other + /// than the [`PubSub`]. + pub fn into_on_message(self) -> impl Stream { + ValueCodec::default() + .framed(self.0.con) + .filter_map(|msg| Box::pin(async move { Msg::from_value(&msg.ok()?.ok()?) })) + } + + /// Exits from `PubSub` mode and converts [`PubSub`] into [`Connection`]. + #[deprecated(note = "aio::Connection is deprecated")] + pub async fn into_connection(mut self) -> Connection { + self.0.exit_pubsub().await.ok(); + + self.0 + } +} + +impl Monitor +where + C: Unpin + AsyncRead + AsyncWrite + Send, +{ + /// Create a [`Monitor`] from a [`Connection`] + pub fn new(con: Connection) -> Self { + Self(con) + } + + /// Deliver the MONITOR command to this [`Monitor`]ing wrapper. + pub async fn monitor(&mut self) -> RedisResult<()> { + cmd("MONITOR").query_async(&mut self.0).await + } + + /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection + pub fn on_message(&mut self) -> impl Stream + '_ { + ValueCodec::default() + .framed(&mut self.0.con) + .filter_map(|value| { + Box::pin(async move { T::from_owned_redis_value(value.ok()?.ok()?).ok() }) + }) + } + + /// Returns [`Stream`] of [`FromRedisValue`] values from this [`Monitor`]ing connection + pub fn into_on_message(self) -> impl Stream { + ValueCodec::default() + .framed(self.0.con) + .filter_map(|value| { + Box::pin(async move { T::from_owned_redis_value(value.ok()?.ok()?).ok() }) + }) + } +} + +pub(crate) async fn get_socket_addrs( + host: &str, + port: u16, +) -> RedisResult + Send + '_> { + #[cfg(feature = "tokio-comp")] + let socket_addrs = lookup_host((host, port)).await?; + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + let socket_addrs = (host, port).to_socket_addrs().await?; + + let mut socket_addrs = socket_addrs.peekable(); + match socket_addrs.peek() { + Some(_) => Ok(socket_addrs), + None => Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "No address found for host", + ))), + } +} + +/// Logs the creation of a connection, including its type, the node, and optionally its IP address. +fn log_conn_creation(conn_type: &str, node: T, ip: Option) +where + T: std::fmt::Debug, +{ + info!( + "Creating {conn_type} connection for node: {node:?}{}", + ip.map(|ip| format!(", IP: {:?}", ip)).unwrap_or_default() + ); +} + +pub(crate) async fn connect_simple( + connection_info: &ConnectionInfo, + _socket_addr: Option, +) -> RedisResult<(T, Option)> { + Ok(match connection_info.addr { + ConnectionAddr::Tcp(ref host, port) => { + if let Some(socket_addr) = _socket_addr { + return Ok::<_, RedisError>(( + ::connect_tcp(socket_addr).await?, + Some(socket_addr.ip()), + )); + } + let socket_addrs = get_socket_addrs(host, port).await?; + select_ok(socket_addrs.map(|socket_addr| { + log_conn_creation("TCP", format!("{host}:{port}"), Some(socket_addr.ip())); + Box::pin(async move { + Ok::<_, RedisError>(( + ::connect_tcp(socket_addr).await?, + Some(socket_addr.ip()), + )) + }) + })) + .await? + .0 + } + + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + ref tls_params, + } => { + if let Some(socket_addr) = _socket_addr { + return Ok::<_, RedisError>(( + ::connect_tcp_tls(host, socket_addr, insecure, tls_params).await?, + Some(socket_addr.ip()), + )); + } + let socket_addrs = get_socket_addrs(host, port).await?; + select_ok(socket_addrs.map(|socket_addr| { + log_conn_creation( + "TCP with TLS", + format!("{host}:{port}"), + Some(socket_addr.ip()), + ); + Box::pin(async move { + Ok::<_, RedisError>(( + ::connect_tcp_tls(host, socket_addr, insecure, tls_params).await?, + Some(socket_addr.ip()), + )) + }) + })) + .await? + .0 + } + + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] + ConnectionAddr::TcpTls { .. } => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to TCP with TLS without the tls feature" + )); + } + + #[cfg(unix)] + ConnectionAddr::Unix(ref path) => { + log_conn_creation("UDS", path, None); + (::connect_unix(path).await?, None) + } + + #[cfg(not(unix))] + ConnectionAddr::Unix(_) => { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot connect to unix sockets \ + on this platform", + ))) + } + }) +} diff --git a/glide-core/redis-rs/redis/src/aio/connection_manager.rs b/glide-core/redis-rs/redis/src/aio/connection_manager.rs new file mode 100644 index 0000000000..61df9bc31a --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/connection_manager.rs @@ -0,0 +1,310 @@ +use super::RedisFuture; +use crate::client::GlideConnectionOptions; +use crate::cmd::Cmd; +use crate::push_manager::PushManager; +use crate::types::{RedisError, RedisResult, Value}; +use crate::{ + aio::{ConnectionLike, MultiplexedConnection, Runtime}, + Client, +}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use ::async_std::net::ToSocketAddrs; +use arc_swap::ArcSwap; +use futures::{ + future::{self, Shared}, + FutureExt, +}; +use futures_util::future::BoxFuture; +use std::sync::Arc; +use tokio_retry::strategy::{jitter, ExponentialBackoff}; +use tokio_retry::Retry; + +/// A `ConnectionManager` is a proxy that wraps a [multiplexed +/// connection][multiplexed-connection] and automatically reconnects to the +/// server when necessary. +/// +/// Like the [`MultiplexedConnection`][multiplexed-connection], this +/// manager can be cloned, allowing requests to be be sent concurrently on +/// the same underlying connection (tcp/unix socket). +/// +/// ## Behavior +/// +/// - When creating an instance of the `ConnectionManager`, an initial +/// connection will be established and awaited. Connection errors will be +/// returned directly. +/// - When a command sent to the server fails with an error that represents +/// a "connection dropped" condition, that error will be passed on to the +/// user, but it will trigger a reconnection in the background. +/// - The reconnect code will atomically swap the current (dead) connection +/// with a future that will eventually resolve to a `MultiplexedConnection` +/// or to a `RedisError` +/// - All commands that are issued after the reconnect process has been +/// initiated, will have to await the connection future. +/// - If reconnecting fails, all pending commands will be failed as well. A +/// new reconnection attempt will be triggered if the error is an I/O error. +/// +/// [multiplexed-connection]: struct.MultiplexedConnection.html +#[derive(Clone)] +pub struct ConnectionManager { + /// Information used for the connection. This is needed to be able to reconnect. + client: Client, + /// The connection future. + /// + /// The `ArcSwap` is required to be able to replace the connection + /// without making the `ConnectionManager` mutable. + connection: Arc>>, + + runtime: Runtime, + retry_strategy: ExponentialBackoff, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + push_manager: PushManager, +} + +/// A `RedisResult` that can be cloned because `RedisError` is behind an `Arc`. +type CloneableRedisResult = Result>; + +/// Type alias for a shared boxed future that will resolve to a `CloneableRedisResult`. +type SharedRedisFuture = Shared>>; + +/// Handle a command result. If the connection was dropped, reconnect. +macro_rules! reconnect_if_dropped { + ($self:expr, $result:expr, $current:expr) => { + if let Err(ref e) = $result { + if e.is_unrecoverable_error() { + $self.reconnect($current); + } + } + }; +} + +/// Handle a connection result. If there's an I/O error, reconnect. +/// Propagate any error. +macro_rules! reconnect_if_io_error { + ($self:expr, $result:expr, $current:expr) => { + if let Err(e) = $result { + if e.is_io_error() { + $self.reconnect($current); + } + return Err(e); + } + }; +} + +impl ConnectionManager { + const DEFAULT_CONNECTION_RETRY_EXPONENT_BASE: u64 = 2; + const DEFAULT_CONNECTION_RETRY_FACTOR: u64 = 100; + const DEFAULT_NUMBER_OF_CONNECTION_RETRIESE: usize = 6; + + /// Connect to the server and store the connection inside the returned `ConnectionManager`. + /// + /// This requires the `connection-manager` feature, which will also pull in + /// the Tokio executor. + pub async fn new(client: Client) -> RedisResult { + Self::new_with_backoff( + client, + Self::DEFAULT_CONNECTION_RETRY_EXPONENT_BASE, + Self::DEFAULT_CONNECTION_RETRY_FACTOR, + Self::DEFAULT_NUMBER_OF_CONNECTION_RETRIESE, + ) + .await + } + + /// Connect to the server and store the connection inside the returned `ConnectionManager`. + /// + /// This requires the `connection-manager` feature, which will also pull in + /// the Tokio executor. + /// + /// In case of reconnection issues, the manager will retry reconnection + /// number_of_retries times, with an exponentially increasing delay, calculated as + /// rand(0 .. factor * (exponent_base ^ current-try)). + pub async fn new_with_backoff( + client: Client, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + ) -> RedisResult { + Self::new_with_backoff_and_timeouts( + client, + exponent_base, + factor, + number_of_retries, + std::time::Duration::MAX, + std::time::Duration::MAX, + ) + .await + } + + /// Connect to the server and store the connection inside the returned `ConnectionManager`. + /// + /// This requires the `connection-manager` feature, which will also pull in + /// the Tokio executor. + /// + /// In case of reconnection issues, the manager will retry reconnection + /// number_of_retries times, with an exponentially increasing delay, calculated as + /// rand(0 .. factor * (exponent_base ^ current-try)). + /// + /// The new connection will timeout operations after `response_timeout` has passed. + /// Each connection attempt to the server will timeout after `connection_timeout`. + pub async fn new_with_backoff_and_timeouts( + client: Client, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + // Create a MultiplexedConnection and wait for it to be established + let push_manager = PushManager::default(); + let runtime = Runtime::locate(); + let retry_strategy = ExponentialBackoff::from_millis(exponent_base).factor(factor); + let mut connection = Self::new_connection( + client.clone(), + retry_strategy.clone(), + number_of_retries, + response_timeout, + connection_timeout, + ) + .await?; + + // Wrap the connection in an `ArcSwap` instance for fast atomic access + connection.set_push_manager(push_manager.clone()).await; + Ok(Self { + client, + connection: Arc::new(ArcSwap::from_pointee( + future::ok(connection).boxed().shared(), + )), + runtime, + number_of_retries, + retry_strategy, + response_timeout, + connection_timeout, + push_manager, + }) + } + + async fn new_connection( + client: Client, + exponential_backoff: ExponentialBackoff, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + let retry_strategy = exponential_backoff.map(jitter).take(number_of_retries); + Retry::spawn(retry_strategy, || { + client.get_multiplexed_async_connection_with_timeouts( + response_timeout, + connection_timeout, + GlideConnectionOptions::default(), + ) + }) + .await + } + + /// Reconnect and overwrite the old connection. + /// + /// The `current` guard points to the shared future that was active + /// when the connection loss was detected. + fn reconnect(&self, current: arc_swap::Guard>>) { + let client = self.client.clone(); + let retry_strategy = self.retry_strategy.clone(); + let number_of_retries = self.number_of_retries; + let response_timeout = self.response_timeout; + let connection_timeout = self.connection_timeout; + let pmc = self.push_manager.clone(); + let new_connection: SharedRedisFuture = async move { + let mut con = Self::new_connection( + client, + retry_strategy, + number_of_retries, + response_timeout, + connection_timeout, + ) + .await?; + con.set_push_manager(pmc).await; + Ok(con) + } + .boxed() + .shared(); + + // Update the connection in the connection manager + let new_connection_arc = Arc::new(new_connection.clone()); + let prev = self + .connection + .compare_and_swap(¤t, new_connection_arc); + + // If the swap happened... + if Arc::ptr_eq(&prev, ¤t) { + // ...start the connection attempt immediately but do not wait on it. + self.runtime.spawn(new_connection.map(|_| ())); + } + } + + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult { + // Clone connection to avoid having to lock the ArcSwap in write mode + let guard = self.connection.load(); + let connection_result = (**guard) + .clone() + .await + .map_err(|e| e.clone_mostly("Reconnecting failed")); + reconnect_if_io_error!(self, connection_result, guard); + let result = connection_result?.send_packed_command(cmd).await; + reconnect_if_dropped!(self, &result, guard); + result + } + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + pub async fn send_packed_commands( + &mut self, + cmd: &crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + // Clone shared connection future to avoid having to lock the ArcSwap in write mode + let guard = self.connection.load(); + let connection_result = (**guard) + .clone() + .await + .map_err(|e| e.clone_mostly("Reconnecting failed")); + reconnect_if_io_error!(self, connection_result, guard); + let result = connection_result? + .send_packed_commands(cmd, offset, count) + .await; + reconnect_if_dropped!(self, &result, guard); + result + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } +} + +impl ConnectionLike for ConnectionManager { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + (async move { self.send_packed_command(cmd).await }).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + } + + fn get_db(&self) -> i64 { + self.client.connection_info().redis.db + } + + fn is_closed(&self) -> bool { + // always return false due to automatic reconnect + false + } +} diff --git a/glide-core/redis-rs/redis/src/aio/mod.rs b/glide-core/redis-rs/redis/src/aio/mod.rs new file mode 100644 index 0000000000..ffe2c9e3a2 --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/mod.rs @@ -0,0 +1,286 @@ +//! Adds async IO support to redis. +use crate::cmd::{cmd, Cmd}; +use crate::connection::{ + get_resp3_hello_command_error, PubSubSubscriptionKind, RedisConnectionInfo, +}; +use crate::types::{ErrorKind, ProtocolVersion, RedisFuture, RedisResult, Value}; +use crate::PushKind; +use ::tokio::io::{AsyncRead, AsyncWrite}; +use async_trait::async_trait; +use futures_util::Future; +use std::net::SocketAddr; +#[cfg(unix)] +use std::path::Path; +use std::pin::Pin; +use std::time::Duration; + +/// Enables the async_std compatibility +#[cfg(feature = "async-std-comp")] +#[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] +pub mod async_std; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use crate::connection::TlsConnParams; + +/// Enables the tokio compatibility +#[cfg(feature = "tokio-comp")] +#[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] +pub mod tokio; + +/// Represents the ability of connecting via TCP or via Unix socket +#[async_trait] +pub(crate) trait RedisRuntime: AsyncStream + Send + Sync + Sized + 'static { + /// Performs a TCP connection + async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult; + + // Performs a TCP TLS connection + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + ) -> RedisResult; + + /// Performs a UNIX connection + #[cfg(unix)] + async fn connect_unix(path: &Path) -> RedisResult; + + fn spawn(f: impl Future + Send + 'static); + + fn boxed(self) -> Pin> { + Box::pin(self) + } +} + +/// Trait for objects that implements `AsyncRead` and `AsyncWrite` +pub trait AsyncStream: AsyncRead + AsyncWrite {} +impl AsyncStream for S where S: AsyncRead + AsyncWrite {} + +/// An async abstraction over connections. +pub trait ConnectionLike { + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value>; + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + /// Important - this function is meant for internal usage, since it's + /// easy to pass incorrect `offset` & `count` parameters, which might + /// cause the connection to enter an erroneous state. Users shouldn't + /// call it, instead using the Pipeline::query_async function. + #[doc(hidden)] + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec>; + + /// Returns the database this connection is bound to. Note that this + /// information might be unreliable because it's initially cached and + /// also might be incorrect if the connection like object is not + /// actually connected. + fn get_db(&self) -> i64; + + /// Returns the state of the connection + fn is_closed(&self) -> bool; +} + +/// Implements ability to notify about disconnection events +#[async_trait] +pub trait DisconnectNotifier: Send + Sync { + /// Notify about disconnect event + fn notify_disconnect(&mut self); + + /// Wait for disconnect event with timeout + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration); + + /// Intended to be used with Box + fn clone_box(&self) -> Box; +} + +impl Clone for Box { + fn clone(&self) -> Box { + self.clone_box() + } +} + +// Initial setup for every connection. +async fn setup_connection(connection_info: &RedisConnectionInfo, con: &mut C) -> RedisResult<()> +where + C: ConnectionLike, +{ + if connection_info.protocol != ProtocolVersion::RESP2 { + let hello_cmd = resp3_hello(connection_info); + let val: RedisResult = hello_cmd.query_async(con).await; + if let Err(err) = val { + return Err(get_resp3_hello_command_error(err)); + } + } else if let Some(password) = &connection_info.password { + let mut command = cmd("AUTH"); + if let Some(username) = &connection_info.username { + command.arg(username); + } + match command.arg(password).query_async(con).await { + Ok(Value::Okay) => (), + Err(e) => { + let err_msg = e.detail().ok_or(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + ))?; + + if !err_msg.contains("wrong number of arguments for 'auth' command") { + fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + )); + } + + let mut command = cmd("AUTH"); + match command.arg(password).query_async(con).await { + Ok(Value::Okay) => (), + _ => { + fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed" + )); + } + } + } + _ => { + fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed" + )); + } + } + } + + if connection_info.db != 0 { + match cmd("SELECT").arg(connection_info.db).query_async(con).await { + Ok(Value::Okay) => (), + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to switch database" + )), + } + } + + if let Some(client_name) = &connection_info.client_name { + match cmd("CLIENT") + .arg("SETNAME") + .arg(client_name) + .query_async(con) + .await + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to set client name" + )), + } + } + + // result is ignored, as per the command's instructions. + // https://redis.io/commands/client-setinfo/ + #[cfg(not(feature = "disable-client-setinfo"))] + let _: RedisResult<()> = crate::connection::client_set_info_pipeline() + .query_async(con) + .await; + + // resubscribe + if connection_info.protocol != ProtocolVersion::RESP3 { + return Ok(()); + } + static KIND_TO_COMMAND: [(PubSubSubscriptionKind, &str); 3] = [ + (PubSubSubscriptionKind::Exact, "SUBSCRIBE"), + (PubSubSubscriptionKind::Pattern, "PSUBSCRIBE"), + (PubSubSubscriptionKind::Sharded, "SSUBSCRIBE"), + ]; + + if connection_info.pubsub_subscriptions.is_none() { + return Ok(()); + } + + for (subscription_kind, channels_patterns) in + connection_info.pubsub_subscriptions.as_ref().unwrap() + { + for channel_pattern in channels_patterns.iter() { + let mut subscribe_command = + cmd(KIND_TO_COMMAND[Into::::into(*subscription_kind)].1); + subscribe_command.arg(channel_pattern); + + // This is a quite intricate code - Per RESP3, subscriptions commands do not return anything. + // Instead, push messages will be pushed for each channel. Thus, this is not a typycal request-response pattern. + // The act of pushing is asyncronous with the regard to the subscription command, and might be delayed for some time after the server state was already updated. + // (i.e. the behaviour is implementation defined). + // We will assume the configured time out is enough for the server to push the notifications. + match subscribe_command.query_async(con).await { + Ok(Value::Push { kind, data }) => { + match *subscription_kind { + PubSubSubscriptionKind::Exact => { + if kind != PushKind::Subscribe + || Value::BulkString(channel_pattern.clone()) != data[0] + { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to restore Exact subscription channels" + )); + } + } + PubSubSubscriptionKind::Pattern => { + if kind != PushKind::PSubscribe + || Value::BulkString(channel_pattern.clone()) != data[0] + { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to restore Pattern subscription channels" + )); + } + } + PubSubSubscriptionKind::Sharded => { + if kind != PushKind::SSubscribe + || Value::BulkString(channel_pattern.clone()) != data[0] + { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to restore Sharded subscription channels" + )); + } + } + } + } + _ => { + fail!(( + ErrorKind::ResponseError, + // TODO: Consider printing the exact command + "Failed to receive subscription notification while restoring subscription channels" + )); + } + } + } + } + + Ok(()) +} + +mod connection; +pub use connection::*; +mod multiplexed_connection; +pub use multiplexed_connection::*; +#[cfg(feature = "connection-manager")] +mod connection_manager; +#[cfg(feature = "connection-manager")] +#[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] +pub use connection_manager::*; +mod runtime; +use crate::commands::resp3_hello; +pub(super) use runtime::*; diff --git a/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs new file mode 100644 index 0000000000..1067bc2df5 --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/multiplexed_connection.rs @@ -0,0 +1,656 @@ +use super::{ConnectionLike, Runtime}; +use crate::aio::setup_connection; +use crate::aio::DisconnectNotifier; +use crate::client::GlideConnectionOptions; +use crate::cmd::Cmd; +#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +use crate::parser::ValueCodec; +use crate::push_manager::PushManager; +use crate::types::{RedisError, RedisFuture, RedisResult, Value}; +use crate::{cmd, ConnectionInfo, ProtocolVersion, PushKind}; +use ::tokio::{ + io::{AsyncRead, AsyncWrite}, + sync::{mpsc, oneshot}, +}; +use arc_swap::ArcSwap; +use futures_util::{ + future::{Future, FutureExt}, + ready, + sink::Sink, + stream::{self, Stream, StreamExt, TryStreamExt as _}, +}; +use pin_project_lite::pin_project; +use std::collections::VecDeque; +use std::fmt; +use std::fmt::Debug; +use std::io; +use std::pin::Pin; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::sync::Arc; +use std::task::{self, Poll}; +use std::time::Duration; +#[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] +use tokio_util::codec::Decoder; + +// Senders which the result of a single request are sent through +type PipelineOutput = oneshot::Sender>; + +enum ResponseAggregate { + SingleCommand, + Pipeline { + expected_response_count: usize, + current_response_count: usize, + buffer: Vec, + first_err: Option, + }, +} + +impl ResponseAggregate { + fn new(pipeline_response_count: Option) -> Self { + match pipeline_response_count { + Some(response_count) => ResponseAggregate::Pipeline { + expected_response_count: response_count, + current_response_count: 0, + buffer: Vec::new(), + first_err: None, + }, + None => ResponseAggregate::SingleCommand, + } + } +} + +struct InFlight { + output: PipelineOutput, + response_aggregate: ResponseAggregate, +} + +// A single message sent through the pipeline +struct PipelineMessage { + input: S, + output: PipelineOutput, + // If `None`, this is a single request, not a pipeline of multiple requests. + pipeline_response_count: Option, +} + +/// Wrapper around a `Stream + Sink` where each item sent through the `Sink` results in one or more +/// items being output by the `Stream` (the number is specified at time of sending). With the +/// interface provided by `Pipeline` an easy interface of request to response, hiding the `Stream` +/// and `Sink`. +#[derive(Clone)] +struct Pipeline { + sender: mpsc::Sender>, + push_manager: Arc>, + is_stream_closed: Arc, +} + +impl Debug for Pipeline +where + SinkItem: Debug, +{ + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_tuple("Pipeline").field(&self.sender).finish() + } +} + +pin_project! { + struct PipelineSink { + #[pin] + sink_stream: T, + in_flight: VecDeque, + error: Option, + push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, + } +} + +impl PipelineSink +where + T: Stream> + 'static, +{ + fn new( + sink_stream: T, + push_manager: Arc>, + disconnect_notifier: Option>, + is_stream_closed: Arc, + ) -> Self + where + T: Sink + Stream> + 'static, + { + PipelineSink { + sink_stream, + in_flight: VecDeque::new(), + error: None, + push_manager, + disconnect_notifier, + is_stream_closed, + } + } + + // Read messages from the stream and send them back to the caller + fn poll_read(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + loop { + let item = match ready!(self.as_mut().project().sink_stream.poll_next(cx)) { + Some(result) => result, + // The redis response stream is not going to produce any more items so we `Err` + // to break out of the `forward` combinator and stop handling requests + None => { + // this is the right place to notify about the passive TCP disconnect + // In other places we cannot distinguish between the active destruction of MultiplexedConnection and passive disconnect + if let Some(disconnect_notifier) = self.as_mut().project().disconnect_notifier { + disconnect_notifier.notify_disconnect(); + } + self.is_stream_closed.store(true, Ordering::Relaxed); + return Poll::Ready(Err(())); + } + }; + self.as_mut().send_result(item); + } + } + + fn send_result(self: Pin<&mut Self>, result: RedisResult) { + let self_ = self.project(); + let mut skip_value = false; + if let Ok(res) = &result { + if let Value::Push { kind, data: _data } = res { + self_.push_manager.load().try_send_raw(res); + if !kind.has_reply() { + // If it's not true then push kind is converted to reply of a command + skip_value = true; + } + } + } + + let mut entry = match self_.in_flight.pop_front() { + Some(entry) => entry, + None => return, + }; + + if skip_value { + self_.in_flight.push_front(entry); + return; + } + + match &mut entry.response_aggregate { + ResponseAggregate::SingleCommand => { + entry.output.send(result).ok(); + } + ResponseAggregate::Pipeline { + expected_response_count, + current_response_count, + buffer, + first_err, + } => { + match result { + Ok(item) => { + buffer.push(item); + } + Err(err) => { + if first_err.is_none() { + *first_err = Some(err); + } + } + } + + *current_response_count += 1; + if current_response_count < expected_response_count { + // Need to gather more response values + self_.in_flight.push_front(entry); + return; + } + + let response = match first_err.take() { + Some(err) => Err(err), + None => Ok(Value::Array(std::mem::take(buffer))), + }; + + // `Err` means that the receiver was dropped in which case it does not + // care about the output and we can continue by just dropping the value + // and sender + entry.output.send(response).ok(); + } + } + } +} + +impl Sink> for PipelineSink +where + T: Sink + Stream> + 'static, +{ + type Error = (); + + // Retrieve incoming messages and write them to the sink + fn poll_ready( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + match ready!(self.as_mut().project().sink_stream.poll_ready(cx)) { + Ok(()) => Ok(()).into(), + Err(err) => { + *self.project().error = Some(err); + Ok(()).into() + } + } + } + + fn start_send( + mut self: Pin<&mut Self>, + PipelineMessage { + input, + output, + pipeline_response_count, + }: PipelineMessage, + ) -> Result<(), Self::Error> { + // If there is nothing to receive our output we do not need to send the message as it is + // ambiguous whether the message will be sent anyway. Helps shed some load on the + // connection. + if output.is_closed() { + return Ok(()); + } + + let self_ = self.as_mut().project(); + + if let Some(err) = self_.error.take() { + let _ = output.send(Err(err)); + return Err(()); + } + + match self_.sink_stream.start_send(input) { + Ok(()) => { + let response_aggregate = ResponseAggregate::new(pipeline_response_count); + let entry = InFlight { + output, + response_aggregate, + }; + + self_.in_flight.push_back(entry); + Ok(()) + } + Err(err) => { + let _ = output.send(Err(err)); + Err(()) + } + } + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + ready!(self + .as_mut() + .project() + .sink_stream + .poll_flush(cx) + .map_err(|err| { + self.as_mut().send_result(Err(err)); + }))?; + self.poll_read(cx) + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + // No new requests will come in after the first call to `close` but we need to complete any + // in progress requests before closing + if !self.in_flight.is_empty() { + ready!(self.as_mut().poll_flush(cx))?; + } + let this = self.as_mut().project(); + this.sink_stream.poll_close(cx).map_err(|err| { + self.send_result(Err(err)); + }) + } +} + +impl Pipeline +where + SinkItem: Send + 'static, +{ + fn new( + sink_stream: T, + disconnect_notifier: Option>, + ) -> (Self, impl Future) + where + T: Sink + Stream> + 'static, + T: Send + 'static, + T::Item: Send, + T::Error: Send, + T::Error: ::std::fmt::Debug, + { + const BUFFER_SIZE: usize = 50; + let (sender, mut receiver) = mpsc::channel(BUFFER_SIZE); + let push_manager: Arc> = + Arc::new(ArcSwap::new(Arc::new(PushManager::default()))); + let is_stream_closed = Arc::new(AtomicBool::new(false)); + let sink = PipelineSink::new::( + sink_stream, + push_manager.clone(), + disconnect_notifier, + is_stream_closed.clone(), + ); + let f = stream::poll_fn(move |cx| receiver.poll_recv(cx)) + .map(Ok) + .forward(sink) + .map(|_| ()); + ( + Pipeline { + sender, + push_manager, + is_stream_closed, + }, + f, + ) + } + + // `None` means that the stream was out of items causing that poll loop to shut down. + async fn send_single( + &mut self, + item: SinkItem, + timeout: Duration, + ) -> Result> { + self.send_recv(item, None, timeout).await + } + + async fn send_recv( + &mut self, + input: SinkItem, + // If `None`, this is a single request, not a pipeline of multiple requests. + pipeline_response_count: Option, + timeout: Duration, + ) -> Result> { + let (sender, receiver) = oneshot::channel(); + + self.sender + .send(PipelineMessage { + input, + pipeline_response_count, + output: sender, + }) + .await + .map_err(|_| None)?; + match Runtime::locate().timeout(timeout, receiver).await { + Ok(Ok(result)) => result.map_err(Some), + Ok(Err(_)) => { + // The `sender` was dropped which likely means that the stream part + // failed for one reason or another + Err(None) + } + Err(elapsed) => Err(Some(elapsed.into())), + } + } + + /// Sets `PushManager` of Pipeline + async fn set_push_manager(&mut self, push_manager: PushManager) { + self.push_manager.store(Arc::new(push_manager)); + } + + pub fn is_closed(&self) -> bool { + self.is_stream_closed.load(Ordering::Relaxed) + } +} + +/// A connection object which can be cloned, allowing requests to be be sent concurrently +/// on the same underlying connection (tcp/unix socket). +#[derive(Clone)] +pub struct MultiplexedConnection { + pipeline: Pipeline>, + db: i64, + response_timeout: Duration, + protocol: ProtocolVersion, + push_manager: PushManager, +} + +impl Debug for MultiplexedConnection { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("MultiplexedConnection") + .field("pipeline", &self.pipeline) + .field("db", &self.db) + .finish() + } +} + +impl MultiplexedConnection { + /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object + /// and a `ConnectionInfo` + pub async fn new( + connection_info: &ConnectionInfo, + stream: C, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(Self, impl Future)> + where + C: Unpin + AsyncRead + AsyncWrite + Send + 'static, + { + Self::new_with_response_timeout( + connection_info, + stream, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Constructs a new `MultiplexedConnection` out of a `AsyncRead + AsyncWrite` object + /// and a `ConnectionInfo`. The new object will wait on operations for the given `response_timeout`. + pub async fn new_with_response_timeout( + connection_info: &ConnectionInfo, + stream: C, + response_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(Self, impl Future)> + where + C: Unpin + AsyncRead + AsyncWrite + Send + 'static, + { + fn boxed( + f: impl Future + Send + 'static, + ) -> Pin + Send>> { + Box::pin(f) + } + + #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] + compile_error!("tokio-comp or async-std-comp features required for aio feature"); + + let redis_connection_info = &connection_info.redis; + let codec = ValueCodec::default() + .framed(stream) + .and_then(|msg| async move { msg }); + let (mut pipeline, driver) = + Pipeline::new(codec, glide_connection_options.disconnect_notifier); + let driver = boxed(driver); + let pm = PushManager::default(); + if let Some(sender) = glide_connection_options.push_sender { + pm.replace_sender(sender); + } + + pipeline.set_push_manager(pm.clone()).await; + let mut con = MultiplexedConnection { + pipeline, + db: connection_info.redis.db, + response_timeout, + push_manager: pm, + protocol: redis_connection_info.protocol, + }; + let driver = { + let auth = setup_connection(&connection_info.redis, &mut con); + + futures_util::pin_mut!(auth); + + match futures_util::future::select(auth, driver).await { + futures_util::future::Either::Left((result, driver)) => { + result?; + driver + } + futures_util::future::Either::Right(((), _)) => { + return Err(RedisError::from(( + crate::ErrorKind::IoError, + "Multiplexed connection driver unexpectedly terminated", + ))); + } + } + }; + Ok((con, driver)) + } + + /// Sets the time that the multiplexer will wait for responses on operations before failing. + pub fn set_response_timeout(&mut self, timeout: std::time::Duration) { + self.response_timeout = timeout; + } + + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + pub async fn send_packed_command(&mut self, cmd: &Cmd) -> RedisResult { + let result = self + .pipeline + .send_single(cmd.get_packed_command(), self.response_timeout) + .await + .map_err(|err| { + err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + }); + if self.protocol != ProtocolVersion::RESP2 { + if let Err(e) = &result { + if e.is_connection_dropped() { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + } + } + } + result + } + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + pub async fn send_packed_commands( + &mut self, + cmd: &crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisResult> { + let result = self + .pipeline + .send_recv( + cmd.get_packed_pipeline(), + Some(offset + count), + self.response_timeout, + ) + .await + .map_err(|err| { + err.unwrap_or_else(|| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + }); + + if self.protocol != ProtocolVersion::RESP2 { + if let Err(e) = &result { + if e.is_connection_dropped() { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + } + } + } + let value = result?; + match value { + Value::Array(mut values) => { + values.drain(..offset); + Ok(values) + } + _ => Ok(vec![value]), + } + } + + /// Sets `PushManager` of connection + pub async fn set_push_manager(&mut self, push_manager: PushManager) { + self.push_manager = push_manager.clone(); + self.pipeline.set_push_manager(push_manager).await; + } +} + +impl ConnectionLike for MultiplexedConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + (async move { self.send_packed_command(cmd).await }).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + cmd: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + (async move { self.send_packed_commands(cmd, offset, count).await }).boxed() + } + + fn get_db(&self) -> i64 { + self.db + } + + fn is_closed(&self) -> bool { + self.pipeline.is_closed() + } +} +impl MultiplexedConnection { + /// Subscribes to a new channel. + pub async fn subscribe(&mut self, channel_name: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("SUBSCRIBE"); + cmd.arg(channel_name.clone()); + cmd.query_async(self).await?; + Ok(()) + } + + /// Unsubscribes from channel. + pub async fn unsubscribe(&mut self, channel_name: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("UNSUBSCRIBE"); + cmd.arg(channel_name); + cmd.query_async(self).await?; + Ok(()) + } + + /// Subscribes to a new channel with pattern. + pub async fn psubscribe(&mut self, channel_pattern: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("PSUBSCRIBE"); + cmd.arg(channel_pattern.clone()); + cmd.query_async(self).await?; + Ok(()) + } + + /// Unsubscribes from channel pattern. + pub async fn punsubscribe(&mut self, channel_pattern: String) -> RedisResult<()> { + if self.protocol == ProtocolVersion::RESP2 { + return Err(RedisError::from(( + crate::ErrorKind::InvalidClientConfig, + "RESP3 is required for this command", + ))); + } + let mut cmd = cmd("PUNSUBSCRIBE"); + cmd.arg(channel_pattern); + cmd.query_async(self).await?; + Ok(()) + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } +} diff --git a/glide-core/redis-rs/redis/src/aio/runtime.rs b/glide-core/redis-rs/redis/src/aio/runtime.rs new file mode 100644 index 0000000000..5755f62c9f --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/runtime.rs @@ -0,0 +1,82 @@ +use std::{io, time::Duration}; + +use futures_util::Future; + +#[cfg(feature = "async-std-comp")] +use super::async_std; +#[cfg(feature = "tokio-comp")] +use super::tokio; +use super::RedisRuntime; +use crate::types::RedisError; + +#[derive(Clone, Debug)] +pub(crate) enum Runtime { + #[cfg(feature = "tokio-comp")] + Tokio, + #[cfg(feature = "async-std-comp")] + AsyncStd, +} + +impl Runtime { + pub(crate) fn locate() -> Self { + #[cfg(all(feature = "tokio-comp", not(feature = "async-std-comp")))] + { + Runtime::Tokio + } + + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + { + Runtime::AsyncStd + } + + #[cfg(all(feature = "tokio-comp", feature = "async-std-comp"))] + { + if ::tokio::runtime::Handle::try_current().is_ok() { + Runtime::Tokio + } else { + Runtime::AsyncStd + } + } + + #[cfg(all(not(feature = "tokio-comp"), not(feature = "async-std-comp")))] + { + compile_error!("tokio-comp or async-std-comp features required for aio feature") + } + } + + #[allow(dead_code)] + pub(super) fn spawn(&self, f: impl Future + Send + 'static) { + match self { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => tokio::Tokio::spawn(f), + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => async_std::AsyncStd::spawn(f), + } + } + + pub(crate) async fn timeout( + &self, + duration: Duration, + future: F, + ) -> Result { + match self { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => ::tokio::time::timeout(duration, future) + .await + .map_err(|_| Elapsed(())), + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => ::async_std::future::timeout(duration, future) + .await + .map_err(|_| Elapsed(())), + } + } +} + +#[derive(Debug)] +pub(crate) struct Elapsed(()); + +impl From for RedisError { + fn from(_: Elapsed) -> Self { + io::Error::from(io::ErrorKind::TimedOut).into() + } +} diff --git a/glide-core/redis-rs/redis/src/aio/tokio.rs b/glide-core/redis-rs/redis/src/aio/tokio.rs new file mode 100644 index 0000000000..3a68c0ebfc --- /dev/null +++ b/glide-core/redis-rs/redis/src/aio/tokio.rs @@ -0,0 +1,204 @@ +use super::{AsyncStream, RedisResult, RedisRuntime, SocketAddr}; +use async_trait::async_trait; +#[allow(unused_imports)] // fixes "Duration" unused when built with non-default feature set +use std::{ + future::Future, + io, + pin::Pin, + task::{self, Poll}, + time::Duration, +}; +#[cfg(unix)] +use tokio::net::UnixStream as UnixStreamTokio; +use tokio::{ + io::{AsyncRead, AsyncWrite, ReadBuf}, + net::TcpStream as TcpStreamTokio, +}; + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use native_tls::TlsConnector; + +#[cfg(feature = "tls-rustls")] +use crate::connection::create_rustls_config; +#[cfg(feature = "tls-rustls")] +use std::sync::Arc; +#[cfg(feature = "tls-rustls")] +use tokio_rustls::{client::TlsStream, TlsConnector}; + +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tokio-rustls-comp")))] +use tokio_native_tls::TlsStream; + +#[cfg(feature = "tokio-rustls-comp")] +use crate::tls::TlsConnParams; + +#[cfg(all(feature = "tokio-native-tls-comp", not(feature = "tls-rustls")))] +use crate::connection::TlsConnParams; + +#[cfg(unix)] +use super::Path; + +#[inline(always)] +async fn connect_tcp(addr: &SocketAddr) -> io::Result { + let socket = TcpStreamTokio::connect(addr).await?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let std_socket = socket.into_std()?; + let socket2: socket2::Socket = std_socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + // TCP_USER_TIMEOUT configuration isn't supported across all operation systems + #[cfg(any(target_os = "android", target_os = "fuchsia", target_os = "linux"))] + { + // TODO: Replace this hardcoded timeout with a configurable timeout when https://github.com/redis-rs/redis-rs/issues/1147 is resolved + const DFEAULT_USER_TCP_TIMEOUT: Duration = Duration::from_secs(5); + socket2.set_tcp_user_timeout(Some(DFEAULT_USER_TCP_TIMEOUT))?; + } + TcpStreamTokio::from_std(socket2.into()) + } + + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} + +pub(crate) enum Tokio { + /// Represents a Tokio TCP connection. + Tcp(TcpStreamTokio), + /// Represents a Tokio TLS encrypted TCP connection + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + TcpTls(Box>), + /// Represents a Tokio Unix connection. + #[cfg(unix)] + Unix(UnixStreamTokio), +} + +impl AsyncWrite for Tokio { + fn poll_write( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &[u8], + ) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_write(cx, buf), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_write(cx, buf), + } + } + + fn poll_flush(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_flush(cx), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_flush(cx), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_flush(cx), + } + } + + fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_shutdown(cx), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_shutdown(cx), + } + } +} + +impl AsyncRead for Tokio { + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match &mut *self { + Tokio::Tcp(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(r) => Pin::new(r).poll_read(cx, buf), + #[cfg(unix)] + Tokio::Unix(r) => Pin::new(r).poll_read(cx, buf), + } + } +} + +#[async_trait] +impl RedisRuntime for Tokio { + async fn connect_tcp(socket_addr: SocketAddr) -> RedisResult { + Ok(connect_tcp(&socket_addr).await.map(Tokio::Tcp)?) + } + + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + _: &Option, + ) -> RedisResult { + let tls_connector: tokio_native_tls::TlsConnector = if insecure { + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true) + .use_sni(false) + .build()? + } else { + TlsConnector::new()? + } + .into(); + Ok(tls_connector + .connect(hostname, connect_tcp(&socket_addr).await?) + .await + .map(|con| Tokio::TcpTls(Box::new(con)))?) + } + + #[cfg(feature = "tls-rustls")] + async fn connect_tcp_tls( + hostname: &str, + socket_addr: SocketAddr, + insecure: bool, + tls_params: &Option, + ) -> RedisResult { + let config = create_rustls_config(insecure, tls_params.clone())?; + let tls_connector = TlsConnector::from(Arc::new(config)); + + Ok(tls_connector + .connect( + rustls_pki_types::ServerName::try_from(hostname)?.to_owned(), + connect_tcp(&socket_addr).await?, + ) + .await + .map(|con| Tokio::TcpTls(Box::new(con)))?) + } + + #[cfg(unix)] + async fn connect_unix(path: &Path) -> RedisResult { + Ok(UnixStreamTokio::connect(path).await.map(Tokio::Unix)?) + } + + #[cfg(feature = "tokio-comp")] + fn spawn(f: impl Future + Send + 'static) { + tokio::spawn(f); + } + + #[cfg(not(feature = "tokio-comp"))] + fn spawn(_: impl Future + Send + 'static) { + unreachable!() + } + + fn boxed(self) -> Pin> { + match self { + Tokio::Tcp(x) => Box::pin(x), + #[cfg(any(feature = "tokio-native-tls-comp", feature = "tokio-rustls-comp"))] + Tokio::TcpTls(x) => Box::pin(x), + #[cfg(unix)] + Tokio::Unix(x) => Box::pin(x), + } + } +} diff --git a/glide-core/redis-rs/redis/src/client.rs b/glide-core/redis-rs/redis/src/client.rs new file mode 100644 index 0000000000..5e3f144e71 --- /dev/null +++ b/glide-core/redis-rs/redis/src/client.rs @@ -0,0 +1,855 @@ +use std::time::Duration; + +#[cfg(feature = "aio")] +use crate::aio::DisconnectNotifier; + +use crate::{ + connection::{connect, Connection, ConnectionInfo, ConnectionLike, IntoConnectionInfo}, + push_manager::PushInfo, + types::{RedisResult, Value}, +}; +#[cfg(feature = "aio")] +use std::net::IpAddr; +#[cfg(feature = "aio")] +use std::net::SocketAddr; +#[cfg(feature = "aio")] +use std::pin::Pin; +use tokio::sync::mpsc; + +#[cfg(feature = "tls-rustls")] +use crate::tls::{inner_build_with_tls, TlsCertificates}; + +/// The client type. +#[derive(Debug, Clone)] +pub struct Client { + pub(crate) connection_info: ConnectionInfo, +} + +/// The client acts as connector to the redis server. By itself it does not +/// do much other than providing a convenient way to fetch a connection from +/// it. In the future the plan is to provide a connection pool in the client. +/// +/// When opening a client a URL in the following format should be used: +/// +/// ```plain +/// redis://host:port/db +/// ``` +/// +/// Example usage:: +/// +/// ```rust,no_run +/// let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// let con = client.get_connection(None).unwrap(); +/// ``` +impl Client { + /// Connects to a redis server and returns a client. This does not + /// actually open a connection yet but it does perform some basic + /// checks on the URL that might make the operation fail. + pub fn open(params: T) -> RedisResult { + Ok(Client { + connection_info: params.into_connection_info()?, + }) + } + + /// Instructs the client to actually connect to redis and returns a + /// connection object. The connection object can be used to send + /// commands to the server. This can fail with a variety of errors + /// (like unreachable host) so it's important that you handle those + /// errors. + pub fn get_connection( + &self, + _push_sender: Option>, + ) -> RedisResult { + connect(&self.connection_info, None) + } + + /// Instructs the client to actually connect to redis with specified + /// timeout and returns a connection object. The connection object + /// can be used to send commands to the server. This can fail with + /// a variety of errors (like unreachable host) so it's important + /// that you handle those errors. + pub fn get_connection_with_timeout(&self, timeout: Duration) -> RedisResult { + connect(&self.connection_info, Some(timeout)) + } + + /// Returns a reference of client connection info object. + pub fn get_connection_info(&self) -> &ConnectionInfo { + &self.connection_info + } +} + +/// Glide-specific connection options +#[derive(Clone, Default)] +pub struct GlideConnectionOptions { + /// Queue for RESP3 notifications + pub push_sender: Option>, + #[cfg(feature = "aio")] + /// Passive disconnect notifier + pub disconnect_notifier: Option>, +} + +/// To enable async support you need to chose one of the supported runtimes and active its +/// corresponding feature: `tokio-comp` or `async-std-comp` +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +impl Client { + /// Returns an async connection from the client. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[deprecated( + note = "aio::Connection is deprecated. Use client::get_multiplexed_async_connection instead." + )] + #[allow(deprecated)] + pub async fn get_async_connection( + &self, + _push_sender: Option>, + ) -> RedisResult { + let (con, _ip) = match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => { + self.get_simple_async_connection::(None) + .await? + } + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => { + self.get_simple_async_connection::(None) + .await? + } + }; + + crate::aio::Connection::new(&self.connection_info.redis, con).await + } + + /// Returns an async connection from the client. + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + #[deprecated( + note = "aio::Connection is deprecated. Use client::get_multiplexed_tokio_connection instead." + )] + #[allow(deprecated)] + pub async fn get_tokio_connection(&self) -> RedisResult { + use crate::aio::RedisRuntime; + Ok( + crate::aio::connect::(&self.connection_info, None) + .await? + .map(RedisRuntime::boxed), + ) + } + + /// Returns an async connection from the client. + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + #[deprecated( + note = "aio::Connection is deprecated. Use client::get_multiplexed_async_std_connection instead." + )] + #[allow(deprecated)] + pub async fn get_async_std_connection(&self) -> RedisResult { + use crate::aio::RedisRuntime; + Ok( + crate::aio::connect::(&self.connection_info, None) + .await? + .map(RedisRuntime::boxed), + ) + } + + /// Returns an async connection from the client. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) + )] + pub async fn get_multiplexed_async_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + self.get_multiplexed_async_connection_with_timeouts( + std::time::Duration::MAX, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async connection from the client. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) + )] + pub async fn get_multiplexed_async_connection_with_timeouts( + &self, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + let result = match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + rt @ Runtime::Tokio => { + rt.timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await + } + #[cfg(feature = "async-std-comp")] + rt @ Runtime::AsyncStd => { + rt.timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await + } + }; + + match result { + Ok(Ok(connection)) => Ok(connection), + Ok(Err(e)) => Err(e), + Err(elapsed) => Err(elapsed.into()), + } + .map(|(conn, _ip)| conn) + } + + /// For TCP connections: returns (async connection, Some(the direct IP address)) + /// For Unix connections, returns (async connection, None) + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + #[cfg_attr( + docsrs, + doc(cfg(any(feature = "tokio-comp", feature = "async-std-comp"))) + )] + pub async fn get_multiplexed_async_connection_and_ip( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> { + match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + Runtime::Tokio => { + self.get_multiplexed_async_connection_inner::( + Duration::MAX, + None, + glide_connection_options, + ) + .await + } + #[cfg(feature = "async-std-comp")] + Runtime::AsyncStd => { + self.get_multiplexed_async_connection_inner::( + Duration::MAX, + None, + glide_connection_options, + ) + .await + } + } + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn get_multiplexed_tokio_connection_with_response_timeouts( + &self, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + let result = Runtime::locate() + .timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await; + + match result { + Ok(Ok((connection, _ip))) => Ok(connection), + Ok(Err(e)) => Err(e), + Err(elapsed) => Err(elapsed.into()), + } + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn get_multiplexed_tokio_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + self.get_multiplexed_tokio_connection_with_response_timeouts( + std::time::Duration::MAX, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn get_multiplexed_async_std_connection_with_timeouts( + &self, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + let result = Runtime::locate() + .timeout( + connection_timeout, + self.get_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ), + ) + .await; + + match result { + Ok(Ok((connection, _ip))) => Ok(connection), + Ok(Err(e)) => Err(e), + Err(elapsed) => Err(elapsed.into()), + } + } + + /// Returns an async multiplexed connection from the client. + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn get_multiplexed_async_std_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult { + self.get_multiplexed_async_std_connection_with_timeouts( + std::time::Duration::MAX, + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// The multiplexer will return a timeout error on any request that takes longer then `response_timeout`. + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn create_multiplexed_tokio_connection_with_response_timeout( + &self, + response_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ) + .await + .map(|(conn, driver, _ip)| (conn, driver)) + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "tokio-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "tokio-comp")))] + pub async fn create_multiplexed_tokio_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_tokio_connection_with_response_timeout( + std::time::Duration::MAX, + glide_connection_options, + ) + .await + .map(|conn_res| (conn_res.0, conn_res.1)) + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// The multiplexer will return a timeout error on any request that takes longer then `response_timeout`. + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn create_multiplexed_async_std_connection_with_response_timeout( + &self, + response_timeout: std::time::Duration, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_async_connection_inner::( + response_timeout, + None, + glide_connection_options, + ) + .await + .map(|(conn, driver, _ip)| (conn, driver)) + } + + /// Returns an async multiplexed connection from the client and a future which must be polled + /// to drive any requests submitted to it (see `get_multiplexed_tokio_connection`). + /// + /// A multiplexed connection can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + #[cfg(feature = "async-std-comp")] + #[cfg_attr(docsrs, doc(cfg(feature = "async-std-comp")))] + pub async fn create_multiplexed_async_std_connection( + &self, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + )> { + self.create_multiplexed_async_std_connection_with_response_timeout( + std::time::Duration::MAX, + glide_connection_options, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + #[deprecated(note = "use get_connection_manager instead")] + pub async fn get_tokio_connection_manager(&self) -> RedisResult { + crate::aio::ConnectionManager::new(self.clone()).await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + pub async fn get_connection_manager(&self) -> RedisResult { + crate::aio::ConnectionManager::new(self.clone()).await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + #[deprecated(note = "use get_connection_manager_with_backoff instead")] + pub async fn get_tokio_connection_manager_with_backoff( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + ) -> RedisResult { + self.get_connection_manager_with_backoff_and_timeouts( + exponent_base, + factor, + number_of_retries, + std::time::Duration::MAX, + std::time::Duration::MAX, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + #[deprecated(note = "use get_connection_manager_with_backoff_and_timeouts instead")] + pub async fn get_tokio_connection_manager_with_backoff_and_timeouts( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + crate::aio::ConnectionManager::new_with_backoff_and_timeouts( + self.clone(), + exponent_base, + factor, + number_of_retries, + response_timeout, + connection_timeout, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + pub async fn get_connection_manager_with_backoff_and_timeouts( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + ) -> RedisResult { + crate::aio::ConnectionManager::new_with_backoff_and_timeouts( + self.clone(), + exponent_base, + factor, + number_of_retries, + response_timeout, + connection_timeout, + ) + .await + } + + /// Returns an async [`ConnectionManager`][connection-manager] from the client. + /// + /// The connection manager wraps a + /// [`MultiplexedConnection`][multiplexed-connection]. If a command to that + /// connection fails with a connection error, then a new connection is + /// established in the background and the error is returned to the caller. + /// + /// This means that on connection loss at least one command will fail, but + /// the connection will be re-established automatically if possible. Please + /// refer to the [`ConnectionManager`][connection-manager] docs for + /// detailed reconnecting behavior. + /// + /// A connection manager can be cloned, allowing requests to be be sent concurrently + /// on the same underlying connection (tcp/unix socket). + /// + /// [connection-manager]: aio/struct.ConnectionManager.html + /// [multiplexed-connection]: aio/struct.MultiplexedConnection.html + #[cfg(feature = "connection-manager")] + #[cfg_attr(docsrs, doc(cfg(feature = "connection-manager")))] + pub async fn get_connection_manager_with_backoff( + &self, + exponent_base: u64, + factor: u64, + number_of_retries: usize, + ) -> RedisResult { + crate::aio::ConnectionManager::new_with_backoff( + self.clone(), + exponent_base, + factor, + number_of_retries, + ) + .await + } + + pub(crate) async fn get_multiplexed_async_connection_inner( + &self, + response_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<(crate::aio::MultiplexedConnection, Option)> + where + T: crate::aio::RedisRuntime, + { + let (connection, driver, ip) = self + .create_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + ) + .await?; + T::spawn(driver); + Ok((connection, ip)) + } + + async fn create_multiplexed_async_connection_inner( + &self, + response_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult<( + crate::aio::MultiplexedConnection, + impl std::future::Future, + Option, + )> + where + T: crate::aio::RedisRuntime, + { + let (con, ip) = self.get_simple_async_connection::(socket_addr).await?; + crate::aio::MultiplexedConnection::new_with_response_timeout( + &self.connection_info, + con, + response_timeout, + glide_connection_options, + ) + .await + .map(|res| (res.0, res.1, ip)) + } + + async fn get_simple_async_connection( + &self, + socket_addr: Option, + ) -> RedisResult<( + Pin>, + Option, + )> + where + T: crate::aio::RedisRuntime, + { + let (conn, ip) = + crate::aio::connect_simple::(&self.connection_info, socket_addr).await?; + Ok((conn.boxed(), ip)) + } + + #[cfg(feature = "connection-manager")] + pub(crate) fn connection_info(&self) -> &ConnectionInfo { + &self.connection_info + } + + /// Constructs a new `Client` with parameters necessary to create a TLS connection. + /// + /// - `conn_info` - URL using the `rediss://` scheme. + /// - `tls_certs` - `TlsCertificates` structure containing: + /// -- `client_tls` - Optional `ClientTlsConfig` containing byte streams for + /// --- `client_cert` - client's byte stream containing client certificate in PEM format + /// --- `client_key` - client's byte stream containing private key in PEM format + /// -- `root_cert` - Optional byte stream yielding PEM formatted file for root certificates. + /// + /// If `ClientTlsConfig` ( cert+key pair ) is not provided, then client-side authentication is not enabled. + /// If `root_cert` is not provided, then system root certificates are used instead. + /// + /// # Examples + /// + /// ```no_run + /// use std::{fs::File, io::{BufReader, Read}}; + /// + /// use redis::{Client, AsyncCommands as _, TlsCertificates, ClientTlsConfig}; + /// + /// async fn do_redis_code( + /// url: &str, + /// root_cert_file: &str, + /// cert_file: &str, + /// key_file: &str + /// ) -> redis::RedisResult<()> { + /// let root_cert_file = File::open(root_cert_file).expect("cannot open private cert file"); + /// let mut root_cert_vec = Vec::new(); + /// BufReader::new(root_cert_file) + /// .read_to_end(&mut root_cert_vec) + /// .expect("Unable to read ROOT cert file"); + /// + /// let cert_file = File::open(cert_file).expect("cannot open private cert file"); + /// let mut client_cert_vec = Vec::new(); + /// BufReader::new(cert_file) + /// .read_to_end(&mut client_cert_vec) + /// .expect("Unable to read client cert file"); + /// + /// let key_file = File::open(key_file).expect("cannot open private key file"); + /// let mut client_key_vec = Vec::new(); + /// BufReader::new(key_file) + /// .read_to_end(&mut client_key_vec) + /// .expect("Unable to read client key file"); + /// + /// let client = Client::build_with_tls( + /// url, + /// TlsCertificates { + /// client_tls: Some(ClientTlsConfig{ + /// client_cert: client_cert_vec, + /// client_key: client_key_vec, + /// }), + /// root_cert: Some(root_cert_vec), + /// } + /// ) + /// .expect("Unable to build client"); + /// + /// let connection_info = client.get_connection_info(); + /// + /// println!(">>> connection info: {connection_info:?}"); + /// + /// let mut con = client.get_async_connection(None).await?; + /// + /// con.set("key1", b"foo").await?; + /// + /// redis::cmd("SET") + /// .arg(&["key2", "bar"]) + /// .query_async(&mut con) + /// .await?; + /// + /// let result = redis::cmd("MGET") + /// .arg(&["key1", "key2"]) + /// .query_async(&mut con) + /// .await; + /// assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + /// println!("Result from MGET: {result:?}"); + /// + /// Ok(()) + /// } + /// ``` + #[cfg(feature = "tls-rustls")] + pub fn build_with_tls( + conn_info: C, + tls_certs: TlsCertificates, + ) -> RedisResult { + let connection_info = conn_info.into_connection_info()?; + + inner_build_with_tls(connection_info, tls_certs) + } + + /// Returns an async receiver for pub-sub messages. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + // TODO - do we want to type-erase pubsub using a trait, to allow us to replace it with a different implementation later? + pub async fn get_async_pubsub(&self) -> RedisResult { + #[allow(deprecated)] + self.get_async_connection(None) + .await + .map(|connection| connection.into_pubsub()) + } + + /// Returns an async receiver for monitor messages. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + // TODO - do we want to type-erase monitor using a trait, to allow us to replace it with a different implementation later? + pub async fn get_async_monitor(&self) -> RedisResult { + #[allow(deprecated)] + self.get_async_connection(None) + .await + .map(|connection| connection.into_monitor()) + } +} + +#[cfg(feature = "aio")] +use crate::aio::Runtime; + +impl ConnectionLike for Client { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + self.get_connection(None)?.req_packed_command(cmd) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + self.get_connection(None)? + .req_packed_commands(cmd, offset, count) + } + + fn get_db(&self) -> i64 { + self.connection_info.redis.db + } + + fn check_connection(&mut self) -> bool { + if let Ok(mut conn) = self.get_connection(None) { + conn.check_connection() + } else { + false + } + } + + fn is_open(&self) -> bool { + if let Ok(conn) = self.get_connection(None) { + conn.is_open() + } else { + false + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn regression_293_parse_ipv6_with_interface() { + assert!(Client::open(("fe80::cafe:beef%eno1", 6379)).is_ok()); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster.rs b/glide-core/redis-rs/redis/src/cluster.rs new file mode 100644 index 0000000000..f9c76f5161 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster.rs @@ -0,0 +1,1076 @@ +//! This module extends the library to support Redis Cluster. +//! +//! Note that this module does not currently provide pubsub +//! functionality. +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::cluster::ClusterClient; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_connection(None).unwrap(); +//! +//! let _: () = connection.set("test", "test_data").unwrap(); +//! let rv: String = connection.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! # Pipelining +//! ```rust,no_run +//! use redis::Commands; +//! use redis::cluster::{cluster_pipe, ClusterClient}; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_connection(None).unwrap(); +//! +//! let key = "test"; +//! +//! let _: () = cluster_pipe() +//! .rpush(key, "123").ignore() +//! .ltrim(key, -10, -1).ignore() +//! .expire(key, 60).ignore() +//! .query(&mut connection).unwrap(); +//! ``` +use std::cell::RefCell; +use std::collections::HashSet; +use std::str::FromStr; +use std::thread; +use std::time::Duration; + +use rand::{seq::IteratorRandom, thread_rng}; + +use crate::cluster_pipeline::UNROUTABLE_ERROR; +use crate::cluster_routing::{ + MultipleNodeRoutingInfo, ResponsePolicy, Routable, SingleNodeRoutingInfo, +}; +use crate::cluster_slotmap::SlotMap; +use crate::cluster_topology::parse_and_count_slots; +use crate::cmd::{cmd, Cmd}; +use crate::connection::{ + connect, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, RedisConnectionInfo, +}; +use crate::parser::parse_redis_value; +use crate::types::{ErrorKind, HashMap, RedisError, RedisResult, Value}; +pub use crate::TlsMode; // Pub for backwards compatibility +use crate::{ + cluster_client::ClusterParams, + cluster_routing::{Redirect, Route, RoutingInfo}, + IntoConnectionInfo, PushInfo, +}; + +pub use crate::cluster_client::{ClusterClient, ClusterClientBuilder}; +pub use crate::cluster_pipeline::{cluster_pipe, ClusterPipeline}; + +use tokio::sync::mpsc; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(not(feature = "tls-rustls"))] +use crate::connection::TlsConnParams; + +#[derive(Clone)] +enum Input<'a> { + Slice { + cmd: &'a [u8], + routable: Value, + }, + Cmd(&'a Cmd), + Commands { + cmd: &'a [u8], + route: SingleNodeRoutingInfo, + offset: usize, + count: usize, + }, +} + +impl<'a> Input<'a> { + fn send(&'a self, connection: &mut impl ConnectionLike) -> RedisResult { + match self { + Input::Slice { cmd, routable: _ } => { + connection.req_packed_command(cmd).map(Output::Single) + } + Input::Cmd(cmd) => connection.req_command(cmd).map(Output::Single), + Input::Commands { + cmd, + route: _, + offset, + count, + } => connection + .req_packed_commands(cmd, *offset, *count) + .map(Output::Multi), + } + } +} + +impl<'a> Routable for Input<'a> { + fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + match self { + Input::Slice { cmd: _, routable } => routable.arg_idx(idx), + Input::Cmd(cmd) => cmd.arg_idx(idx), + Input::Commands { .. } => None, + } + } + + fn position(&self, candidate: &[u8]) -> Option { + match self { + Input::Slice { cmd: _, routable } => routable.position(candidate), + Input::Cmd(cmd) => cmd.position(candidate), + Input::Commands { .. } => None, + } + } +} + +enum Output { + Single(Value), + Multi(Vec), +} + +impl From for Value { + fn from(value: Output) -> Self { + match value { + Output::Single(value) => value, + Output::Multi(values) => Value::Array(values), + } + } +} + +impl From for Vec { + fn from(value: Output) -> Self { + match value { + Output::Single(value) => vec![value], + Output::Multi(values) => values, + } + } +} + +/// Implements the process of connecting to a Redis server +/// and obtaining and configuring a connection handle. +pub trait Connect: Sized { + /// Connect to a node, returning handle for command execution. + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo; + + /// Sends an already encoded (packed) command into the TCP socket and + /// does not read a response. This is useful for commands like + /// `MONITOR` which yield multiple items. This needs to be used with + /// care because it changes the state of the connection. + fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()>; + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + fn set_write_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + fn set_read_timeout(&self, dur: Option) -> RedisResult<()>; + + /// Fetches a single response from the connection. This is useful + /// if used in combination with `send_packed_command`. + fn recv_response(&mut self) -> RedisResult; +} + +impl Connect for Connection { + fn connect(info: T, timeout: Option) -> RedisResult + where + T: IntoConnectionInfo, + { + connect(&info.into_connection_info()?, timeout) + } + + fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> { + Self::send_packed_command(self, cmd) + } + + fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_write_timeout(self, dur) + } + + fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + Self::set_read_timeout(self, dur) + } + + fn recv_response(&mut self) -> RedisResult { + Self::recv_response(self) + } +} + +/// This represents a Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. +pub struct ClusterConnection { + initial_nodes: Vec, + connections: RefCell>, + slots: RefCell, + auto_reconnect: RefCell, + read_timeout: RefCell>, + write_timeout: RefCell>, + cluster_params: ClusterParams, +} + +impl ClusterConnection +where + C: ConnectionLike + Connect, +{ + pub(crate) fn new( + cluster_params: ClusterParams, + initial_nodes: Vec, + _push_sender: Option>, + ) -> RedisResult { + let connection = Self { + connections: RefCell::new(HashMap::new()), + slots: RefCell::new(SlotMap::new(vec![], cluster_params.read_from_replicas)), + auto_reconnect: RefCell::new(true), + cluster_params, + read_timeout: RefCell::new(None), + write_timeout: RefCell::new(None), + initial_nodes: initial_nodes.to_vec(), + }; + connection.create_initial_connections()?; + + Ok(connection) + } + + /// Set an auto reconnect attribute. + /// Default value is true; + pub fn set_auto_reconnect(&self, value: bool) { + let mut auto_reconnect = self.auto_reconnect.borrow_mut(); + *auto_reconnect = value; + } + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + // Check if duration is valid before updating local value. + if dur.is_some() && dur.unwrap().is_zero() { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Duration should be None or non-zero.", + ))); + } + + let mut t = self.write_timeout.borrow_mut(); + *t = dur; + let connections = self.connections.borrow(); + for conn in connections.values() { + conn.set_write_timeout(dur)?; + } + Ok(()) + } + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + // Check if duration is valid before updating local value. + if dur.is_some() && dur.unwrap().is_zero() { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Duration should be None or non-zero.", + ))); + } + + let mut t = self.read_timeout.borrow_mut(); + *t = dur; + let connections = self.connections.borrow(); + for conn in connections.values() { + conn.set_read_timeout(dur)?; + } + Ok(()) + } + + /// Check that all connections it has are available (`PING` internally). + #[doc(hidden)] + pub fn check_connection(&mut self) -> bool { + ::check_connection(self) + } + + pub(crate) fn execute_pipeline(&mut self, pipe: &ClusterPipeline) -> RedisResult> { + self.send_recv_and_retry_cmds(pipe.commands()) + } + + /// Returns the connection status. + /// + /// The connection is open until any `read_response` call recieved an + /// invalid response from the server (most likely a closed or dropped + /// connection, otherwise a Redis protocol error). When using unix + /// sockets the connection is open until writing a command failed with a + /// `BrokenPipe` error. + fn create_initial_connections(&self) -> RedisResult<()> { + let mut connections = HashMap::with_capacity(self.initial_nodes.len()); + + for info in self.initial_nodes.iter() { + let addr = info.addr.to_string(); + + if let Ok(mut conn) = self.connect(&addr) { + if conn.check_connection() { + connections.insert(addr, conn); + break; + } + } + } + + if connections.is_empty() { + return Err(RedisError::from(( + ErrorKind::IoError, + "It failed to check startup nodes.", + ))); + } + + *self.connections.borrow_mut() = connections; + self.refresh_slots()?; + Ok(()) + } + + // Query a node to discover slot-> master mappings. + fn refresh_slots(&self) -> RedisResult<()> { + let mut slots = self.slots.borrow_mut(); + *slots = self.create_new_slots()?; + + let mut nodes = slots.values().flatten().collect::>(); + nodes.sort_unstable(); + nodes.dedup(); + + let mut connections = self.connections.borrow_mut(); + *connections = nodes + .into_iter() + .filter_map(|addr| { + if connections.contains_key(addr) { + let mut conn = connections.remove(addr).unwrap(); + if conn.check_connection() { + return Some((addr.to_string(), conn)); + } + } + + if let Ok(mut conn) = self.connect(addr) { + if conn.check_connection() { + return Some((addr.to_string(), conn)); + } + } + + None + }) + .collect(); + + Ok(()) + } + + fn create_new_slots(&self) -> RedisResult { + let mut connections = self.connections.borrow_mut(); + let mut rng = thread_rng(); + let len = connections.len(); + let samples = connections.iter_mut().choose_multiple(&mut rng, len); + let mut result = Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error.", + "didn't get any slots from server".to_string(), + ))); + for (addr, conn) in samples { + let value = conn.req_command(&slot_cmd())?; + let addr = addr.split(':').next().ok_or(RedisError::from(( + ErrorKind::ClientError, + "can't parse node address", + )))?; + match parse_and_count_slots(&value, self.cluster_params.tls, addr).map(|slots_data| { + SlotMap::new(slots_data.1, self.cluster_params.read_from_replicas) + }) { + Ok(new_slots) => { + result = Ok(new_slots); + break; + } + Err(err) => result = Err(err), + } + } + result + } + + fn connect(&self, node: &str) -> RedisResult { + let info = get_connection_info(node, self.cluster_params.clone())?; + + let mut conn = C::connect(info, Some(self.cluster_params.connection_timeout))?; + if self.cluster_params.read_from_replicas + != crate::cluster_slotmap::ReadFromReplicaStrategy::AlwaysFromPrimary + { + // If READONLY is sent to primary nodes, it will have no effect + cmd("READONLY").query(&mut conn)?; + } + conn.set_read_timeout(*self.read_timeout.borrow())?; + conn.set_write_timeout(*self.write_timeout.borrow())?; + Ok(conn) + } + + fn get_connection<'a>( + &self, + connections: &'a mut HashMap, + route: &Route, + ) -> RedisResult<(String, &'a mut C)> { + let slots = self.slots.borrow(); + if let Some(addr) = slots.slot_addr_for_route(route) { + Ok(( + addr.to_string(), + self.get_connection_by_addr(connections, addr)?, + )) + } else { + // try a random node next. This is safe if slots are involved + // as a wrong node would reject the request. + Ok(get_random_connection(connections)) + } + } + + fn get_connection_by_addr<'a>( + &self, + connections: &'a mut HashMap, + addr: &str, + ) -> RedisResult<&'a mut C> { + if connections.contains_key(addr) { + Ok(connections.get_mut(addr).unwrap()) + } else { + // Create new connection. + // TODO: error handling + let conn = self.connect(addr)?; + Ok(connections.entry(addr.to_string()).or_insert(conn)) + } + } + + fn get_addr_for_cmd(&self, cmd: &Cmd) -> RedisResult { + let slots = self.slots.borrow(); + + let addr_for_slot = |route: Route| -> RedisResult { + let slot_addr = slots + .slot_addr_for_route(&route) + .ok_or((ErrorKind::ClusterDown, "Missing slot coverage"))?; + Ok(slot_addr.to_string()) + }; + + match RoutingInfo::for_routable(cmd) { + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + | Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::RandomPrimary)) => { + Ok(addr_for_slot(Route::new_random_primary())?) + } + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(route))) => { + Ok(addr_for_slot(route)?) + } + _ => fail!(UNROUTABLE_ERROR), + } + } + + fn map_cmds_to_nodes(&self, cmds: &[Cmd]) -> RedisResult> { + let mut cmd_map: HashMap = HashMap::new(); + + for (idx, cmd) in cmds.iter().enumerate() { + let addr = self.get_addr_for_cmd(cmd)?; + let nc = cmd_map + .entry(addr.clone()) + .or_insert_with(|| NodeCmd::new(addr)); + nc.indexes.push(idx); + cmd.write_packed_command(&mut nc.pipe); + } + + let mut result = Vec::new(); + for (_, v) in cmd_map.drain() { + result.push(v); + } + Ok(result) + } + + fn execute_on_all<'a>( + &'a self, + input: Input, + addresses: HashSet<&'a str>, + connections: &'a mut HashMap, + ) -> Vec> { + addresses + .into_iter() + .map(|addr| { + let connection = self.get_connection_by_addr(connections, addr)?; + match input { + Input::Slice { cmd, routable: _ } => connection.req_packed_command(cmd), + Input::Cmd(cmd) => connection.req_command(cmd), + Input::Commands { + cmd: _, + route: _, + offset: _, + count: _, + } => Err(( + ErrorKind::ClientError, + "req_packed_commands isn't supported with multiple nodes", + ) + .into()), + } + .map(|res| (addr, res)) + }) + .collect() + } + + fn execute_on_all_nodes<'a>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + ) -> Vec> { + self.execute_on_all(input, slots.addresses_for_all_nodes(), connections) + } + + fn execute_on_all_primaries<'a>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + ) -> Vec> { + self.execute_on_all(input, slots.addresses_for_all_primaries(), connections) + } + + fn execute_multi_slot<'a, 'b>( + &'a self, + input: Input, + slots: &'a mut SlotMap, + connections: &'a mut HashMap, + routes: &'b [(Route, Vec)], + ) -> Vec> + where + 'b: 'a, + { + slots + .addresses_for_multi_slot(routes) + .enumerate() + .map(|(index, addr)| { + let addr = addr.ok_or(RedisError::from(( + ErrorKind::IoError, + "Couldn't find connection", + )))?; + let connection = self.get_connection_by_addr(connections, addr)?; + let (_, indices) = routes.get(index).unwrap(); + let cmd = + crate::cluster_routing::command_for_multi_slot_indices(&input, indices.iter()); + connection.req_command(&cmd).map(|res| (addr, res)) + }) + .collect() + } + + fn execute_on_multiple_nodes( + &self, + input: Input, + routing: MultipleNodeRoutingInfo, + response_policy: Option, + ) -> RedisResult { + let mut connections = self.connections.borrow_mut(); + let mut slots = self.slots.borrow_mut(); + + let results = match &routing { + MultipleNodeRoutingInfo::MultiSlot(routes) => { + self.execute_multi_slot(input, &mut slots, &mut connections, routes) + } + MultipleNodeRoutingInfo::AllMasters => { + self.execute_on_all_primaries(input, &mut slots, &mut connections) + } + MultipleNodeRoutingInfo::AllNodes => { + self.execute_on_all_nodes(input, &mut slots, &mut connections) + } + }; + + match response_policy { + Some(ResponsePolicy::AllSucceeded) => { + for result in results { + result?; + } + + Ok(Value::Okay) + } + Some(ResponsePolicy::OneSucceeded) => { + let mut last_failure = None; + + for result in results { + match result { + Ok((_, val)) => return Ok(val), + Err(err) => last_failure = Some(err), + } + } + + Err(last_failure + .unwrap_or_else(|| (ErrorKind::IoError, "Couldn't find a connection").into())) + } + Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => { + // Attempt to return the first result that isn't `Nil` or an error. + // If no such response is found and all servers returned `Nil`, it indicates that all shards are empty, so return `Nil`. + // If we received only errors, return the last received error. + // If we received a mix of errors and `Nil`s, we can't determine if all shards are empty, + // thus we return the last received error instead of `Nil`. + let mut last_failure = None; + let num_of_results = results.len(); + let mut nil_counter = 0; + for result in results { + match result.map(|(_, res)| res) { + Ok(Value::Nil) => nil_counter += 1, + Ok(val) => return Ok(val), + Err(err) => last_failure = Some(err), + } + } + if nil_counter == num_of_results { + Ok(Value::Nil) + } else { + Err(last_failure.unwrap_or_else(|| { + (ErrorKind::IoError, "Couldn't find a connection").into() + })) + } + } + Some(ResponsePolicy::Aggregate(op)) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::aggregate(results, op) + } + Some(ResponsePolicy::AggregateLogical(op)) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::logical_aggregate(results, op) + } + Some(ResponsePolicy::CombineArrays) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + match routing { + MultipleNodeRoutingInfo::MultiSlot(vec) => { + crate::cluster_routing::combine_and_sort_array_results( + results, + vec.iter().map(|(_, indices)| indices), + ) + } + _ => crate::cluster_routing::combine_array_results(results), + } + } + Some(ResponsePolicy::CombineMaps) => { + let results = results + .into_iter() + .map(|res| res.map(|(_, val)| val)) + .collect::>>()?; + crate::cluster_routing::combine_map_results(results) + } + Some(ResponsePolicy::Special) | None => { + // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. + // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. + let results = results + .into_iter() + .map(|result| { + result.map(|(addr, val)| (Value::BulkString(addr.as_bytes().to_vec()), val)) + }) + .collect::>>()?; + Ok(Value::Map(results)) + } + } + } + + #[allow(clippy::unnecessary_unwrap)] + fn request(&self, input: Input) -> RedisResult { + let route_option = match &input { + Input::Slice { cmd: _, routable } => RoutingInfo::for_routable(routable), + Input::Cmd(cmd) => RoutingInfo::for_routable(*cmd), + Input::Commands { + cmd: _, + route, + offset: _, + count: _, + } => Some(RoutingInfo::SingleNode(route.clone())), + }; + let single_node_routing = match route_option { + Some(RoutingInfo::SingleNode(single_node_routing)) => single_node_routing, + Some(RoutingInfo::MultiNode((multi_node_routing, response_policy))) => { + return self + .execute_on_multiple_nodes(input, multi_node_routing, response_policy) + .map(Output::Single); + } + None => fail!(UNROUTABLE_ERROR), + }; + + let mut retries = 0; + let mut redirected = None::; + + loop { + // Get target address and response. + let (addr, rv) = { + let mut connections = self.connections.borrow_mut(); + let (addr, conn) = if let Some(redirected) = redirected.take() { + let (addr, is_asking) = match redirected { + Redirect::Moved(addr) => (addr, false), + Redirect::Ask(addr) => (addr, true), + }; + let conn = self.get_connection_by_addr(&mut connections, &addr)?; + if is_asking { + // if we are in asking mode we want to feed a single + // ASKING command into the connection before what we + // actually want to execute. + conn.req_packed_command(&b"*1\r\n$6\r\nASKING\r\n"[..])?; + } + (addr.to_string(), conn) + } else { + match &single_node_routing { + SingleNodeRoutingInfo::Random => get_random_connection(&mut connections), + SingleNodeRoutingInfo::SpecificNode(route) => { + self.get_connection(&mut connections, route)? + } + SingleNodeRoutingInfo::RandomPrimary => { + self.get_connection(&mut connections, &Route::new_random_primary())? + } + SingleNodeRoutingInfo::ByAddress { host, port } => { + let address = format!("{host}:{port}"); + let conn = self.get_connection_by_addr(&mut connections, &address)?; + (address, conn) + } + } + }; + (addr, input.send(conn)) + }; + + match rv { + Ok(rv) => return Ok(rv), + Err(err) => { + if retries == self.cluster_params.retry_params.number_of_retries { + return Err(err); + } + retries += 1; + + match err.retry_method() { + crate::types::RetryMethod::AskRedirect => { + redirected = err + .redirect_node() + .map(|(node, _slot)| Redirect::Ask(node.to_string())); + } + crate::types::RetryMethod::MovedRedirect => { + // Refresh slots. + self.refresh_slots()?; + // Request again. + redirected = err + .redirect_node() + .map(|(node, _slot)| Redirect::Moved(node.to_string())); + } + crate::types::RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica + | crate::types::RetryMethod::WaitAndRetry => { + // Sleep and retry. + let sleep_time = self + .cluster_params + .retry_params + .wait_time_for_retry(retries); + thread::sleep(sleep_time); + } + crate::types::RetryMethod::Reconnect => { + if *self.auto_reconnect.borrow() { + if let Ok(mut conn) = self.connect(&addr) { + if conn.check_connection() { + self.connections.borrow_mut().insert(addr, conn); + } + } + } + } + crate::types::RetryMethod::NoRetry => { + return Err(err); + } + crate::types::RetryMethod::RetryImmediately => {} + } + } + } + } + } + + fn send_recv_and_retry_cmds(&self, cmds: &[Cmd]) -> RedisResult> { + // Vector to hold the results, pre-populated with `Nil` values. This allows the original + // cmd ordering to be re-established by inserting the response directly into the result + // vector (e.g., results[10] = response). + let mut results = vec![Value::Nil; cmds.len()]; + + let to_retry = self + .send_all_commands(cmds) + .and_then(|node_cmds| self.recv_all_commands(&mut results, &node_cmds))?; + + if to_retry.is_empty() { + return Ok(results); + } + + // Refresh the slots to ensure that we have a clean slate for the retry attempts. + self.refresh_slots()?; + + // Given that there are commands that need to be retried, it means something in the cluster + // topology changed. Execute each command seperately to take advantage of the existing + // retry logic that handles these cases. + for retry_idx in to_retry { + let cmd = &cmds[retry_idx]; + results[retry_idx] = self.request(Input::Cmd(cmd))?.into(); + } + Ok(results) + } + + // Build up a pipeline per node, then send it + fn send_all_commands(&self, cmds: &[Cmd]) -> RedisResult> { + let mut connections = self.connections.borrow_mut(); + + let node_cmds = self.map_cmds_to_nodes(cmds)?; + for nc in &node_cmds { + self.get_connection_by_addr(&mut connections, &nc.addr)? + .send_packed_command(&nc.pipe)?; + } + Ok(node_cmds) + } + + // Receive from each node, keeping track of which commands need to be retried. + fn recv_all_commands( + &self, + results: &mut [Value], + node_cmds: &[NodeCmd], + ) -> RedisResult> { + let mut to_retry = Vec::new(); + let mut connections = self.connections.borrow_mut(); + let mut first_err = None; + + for nc in node_cmds { + for cmd_idx in &nc.indexes { + match self + .get_connection_by_addr(&mut connections, &nc.addr)? + .recv_response() + { + Ok(item) => results[*cmd_idx] = item, + Err(err) if err.is_cluster_error() => to_retry.push(*cmd_idx), + Err(err) => first_err = first_err.or(Some(err)), + } + } + } + match first_err { + Some(err) => Err(err), + None => Ok(to_retry), + } + } +} + +const MULTI: &[u8] = "*1\r\n$5\r\nMULTI\r\n".as_bytes(); +impl ConnectionLike for ClusterConnection { + fn supports_pipelining(&self) -> bool { + false + } + + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + self.request(Input::Cmd(cmd)).map(|res| res.into()) + } + + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + let actual_cmd = if cmd.starts_with(MULTI) { + &cmd[MULTI.len()..] + } else { + cmd + }; + let value = parse_redis_value(actual_cmd)?; + self.request(Input::Slice { + cmd, + routable: value, + }) + .map(|res| res.into()) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + let actual_cmd = if cmd.starts_with(MULTI) { + &cmd[MULTI.len()..] + } else { + cmd + }; + let value = parse_redis_value(actual_cmd)?; + let route = match RoutingInfo::for_routable(&value) { + Some(RoutingInfo::MultiNode(_)) => None, + Some(RoutingInfo::SingleNode(route)) => Some(route), + None => None, + } + .unwrap_or(SingleNodeRoutingInfo::Random); + self.request(Input::Commands { + cmd, + offset, + count, + route, + }) + .map(|res| res.into()) + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_open(&self) -> bool { + let connections = self.connections.borrow(); + for conn in connections.values() { + if !conn.is_open() { + return false; + } + } + true + } + + fn check_connection(&mut self) -> bool { + let mut connections = self.connections.borrow_mut(); + for conn in connections.values_mut() { + if !conn.check_connection() { + return false; + } + } + true + } +} + +#[derive(Debug)] +struct NodeCmd { + // The original command indexes + indexes: Vec, + pipe: Vec, + addr: String, +} + +impl NodeCmd { + fn new(a: String) -> NodeCmd { + NodeCmd { + indexes: vec![], + pipe: vec![], + addr: a, + } + } +} + +// TODO: This function can panic and should probably +// return an Option instead: +fn get_random_connection( + connections: &mut HashMap, +) -> (String, &mut C) { + let addr = connections + .keys() + .choose(&mut thread_rng()) + .expect("Connections is empty") + .to_string(); + let con = connections.get_mut(&addr).expect("Connections is empty"); + (addr, con) +} + +// The node string passed to this function will always be in the format host:port as it is either: +// - Created by calling ConnectionAddr::to_string (unix connections are not supported in cluster mode) +// - Returned from redis via the ASK/MOVED response +pub(crate) fn get_connection_info( + node: &str, + cluster_params: ClusterParams, +) -> RedisResult { + let invalid_error = || (ErrorKind::InvalidClientConfig, "Invalid node string"); + + let (host, port) = node + .rsplit_once(':') + .and_then(|(host, port)| { + Some(host.trim_start_matches('[').trim_end_matches(']')) + .filter(|h| !h.is_empty()) + .zip(u16::from_str(port).ok()) + }) + .ok_or_else(invalid_error)?; + + Ok(ConnectionInfo { + addr: get_connection_addr( + host.to_string(), + port, + cluster_params.tls, + cluster_params.tls_params, + ), + redis: RedisConnectionInfo { + password: cluster_params.password, + username: cluster_params.username, + client_name: cluster_params.client_name, + protocol: cluster_params.protocol, + db: 0, + pubsub_subscriptions: cluster_params.pubsub_subscriptions, + }, + }) +} + +pub(crate) fn get_connection_addr( + host: String, + port: u16, + tls: Option, + tls_params: Option, +) -> ConnectionAddr { + match tls { + Some(TlsMode::Secure) => ConnectionAddr::TcpTls { + host, + port, + insecure: false, + tls_params, + }, + Some(TlsMode::Insecure) => ConnectionAddr::TcpTls { + host, + port, + insecure: true, + tls_params, + }, + _ => ConnectionAddr::Tcp(host, port), + } +} + +pub(crate) fn slot_cmd() -> Cmd { + let mut cmd = Cmd::new(); + cmd.arg("CLUSTER").arg("SLOTS"); + cmd +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn parse_cluster_node_host_port() { + let cases = vec![ + ( + "127.0.0.1:6379", + ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379u16), + ), + ( + "localhost.localdomain:6379", + ConnectionAddr::Tcp("localhost.localdomain".to_string(), 6379u16), + ), + ( + "dead::cafe:beef:30001", + ConnectionAddr::Tcp("dead::cafe:beef".to_string(), 30001u16), + ), + ( + "[fe80::cafe:beef%en1]:30001", + ConnectionAddr::Tcp("fe80::cafe:beef%en1".to_string(), 30001u16), + ), + ]; + + for (input, expected) in cases { + let res = get_connection_info(input, ClusterParams::default()); + assert_eq!(res.unwrap().addr, expected); + } + + let cases = vec![":0", "[]:6379"]; + for input in cases { + let res = get_connection_info(input, ClusterParams::default()); + assert_eq!( + res.err(), + Some(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Invalid node string", + ))), + ); + } + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/LICENSE b/glide-core/redis-rs/redis/src/cluster_async/LICENSE new file mode 100644 index 0000000000..aaa71a1638 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/LICENSE @@ -0,0 +1,7 @@ +Copyright 2019 Atsushi Koge, Markus Westerlind + +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs new file mode 100644 index 0000000000..2bfbb8b934 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_container.rs @@ -0,0 +1,881 @@ +use crate::cluster_async::ConnectionFuture; +use crate::cluster_routing::{Route, SlotAddr}; +use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap, SlotMapValue}; +use crate::cluster_topology::TopologyHash; +use dashmap::DashMap; +use futures::FutureExt; +use rand::seq::IteratorRandom; +use std::net::IpAddr; + +/// A struct that encapsulates a network connection along with its associated IP address. +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct ConnectionWithIp { + /// The actual connection + pub conn: Connection, + /// The IP associated with the connection + pub ip: Option, +} + +impl ConnectionWithIp +where + Connection: Clone + Send + 'static, +{ + /// Consumes the current instance and returns a new `ConnectionWithIp` + /// where the connection is wrapped in a future. + #[doc(hidden)] + pub fn into_future(self) -> ConnectionWithIp> { + ConnectionWithIp { + conn: async { self.conn }.boxed().shared(), + ip: self.ip, + } + } +} + +impl From<(Connection, Option)> for ConnectionWithIp { + fn from(val: (Connection, Option)) -> Self { + ConnectionWithIp { + conn: val.0, + ip: val.1, + } + } +} + +impl From> for (Connection, Option) { + fn from(val: ConnectionWithIp) -> Self { + (val.conn, val.ip) + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] +pub struct ClusterNode { + pub user_connection: ConnectionWithIp, + pub management_connection: Option>, +} + +impl ClusterNode +where + Connection: Clone, +{ + pub fn new( + user_connection: ConnectionWithIp, + management_connection: Option>, + ) -> Self { + Self { + user_connection, + management_connection, + } + } + + pub(crate) fn get_connection(&self, conn_type: &ConnectionType) -> Connection { + match conn_type { + ConnectionType::User => self.user_connection.conn.clone(), + ConnectionType::PreferManagement => self.management_connection.as_ref().map_or_else( + || self.user_connection.conn.clone(), + |management_conn| management_conn.conn.clone(), + ), + } + } +} + +#[derive(Clone, Eq, PartialEq, Debug)] + +pub(crate) enum ConnectionType { + User, + PreferManagement, +} + +pub(crate) struct ConnectionsMap(pub(crate) DashMap>); + +impl std::fmt::Display for ConnectionsMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + for item in self.0.iter() { + let (address, node) = (item.key(), item.value()); + match node.user_connection.ip { + Some(ip) => writeln!(f, "{address} - {ip}")?, + None => writeln!(f, "{address}")?, + }; + } + Ok(()) + } +} + +pub(crate) struct ConnectionsContainer { + connection_map: DashMap>, + pub(crate) slot_map: SlotMap, + read_from_replica_strategy: ReadFromReplicaStrategy, + topology_hash: TopologyHash, +} + +impl Default for ConnectionsContainer { + fn default() -> Self { + Self { + connection_map: Default::default(), + slot_map: Default::default(), + read_from_replica_strategy: ReadFromReplicaStrategy::AlwaysFromPrimary, + topology_hash: 0, + } + } +} + +pub(crate) type ConnectionAndAddress = (String, Connection); + +impl ConnectionsContainer +where + Connection: Clone, +{ + pub(crate) fn new( + slot_map: SlotMap, + connection_map: ConnectionsMap, + read_from_replica_strategy: ReadFromReplicaStrategy, + topology_hash: TopologyHash, + ) -> Self { + Self { + connection_map: connection_map.0, + slot_map, + read_from_replica_strategy, + topology_hash, + } + } + + // Extends the current connection map with the provided one + pub(crate) fn extend_connection_map( + &mut self, + other_connection_map: ConnectionsMap, + ) { + self.connection_map.extend(other_connection_map.0); + } + + /// Returns true if the address represents a known primary node. + pub(crate) fn is_primary(&self, address: &String) -> bool { + self.connection_for_address(address).is_some() + && self + .slot_map + .values() + .any(|slot_addrs| slot_addrs.primary.as_str() == address) + } + + fn round_robin_read_from_replica( + &self, + slot_map_value: &SlotMapValue, + ) -> Option> { + let addrs = &slot_map_value.addrs; + let initial_index = slot_map_value + .latest_used_replica + .load(std::sync::atomic::Ordering::Relaxed); + let mut check_count = 0; + loop { + check_count += 1; + + // Looped through all replicas, no connected replica was found. + if check_count > addrs.replicas.len() { + return self.connection_for_address(addrs.primary.as_str()); + } + let index = (initial_index + check_count) % addrs.replicas.len(); + if let Some(connection) = self.connection_for_address(addrs.replicas[index].as_str()) { + let _ = slot_map_value.latest_used_replica.compare_exchange_weak( + initial_index, + index, + std::sync::atomic::Ordering::Relaxed, + std::sync::atomic::Ordering::Relaxed, + ); + return Some(connection); + } + } + } + + fn lookup_route(&self, route: &Route) -> Option> { + let slot_map_value = self.slot_map.slot_value_for_route(route)?; + let addrs = &slot_map_value.addrs; + if addrs.replicas.is_empty() { + return self.connection_for_address(addrs.primary.as_str()); + } + + match route.slot_addr() { + SlotAddr::Master => self.connection_for_address(addrs.primary.as_str()), + SlotAddr::ReplicaOptional => match self.read_from_replica_strategy { + ReadFromReplicaStrategy::AlwaysFromPrimary => { + self.connection_for_address(addrs.primary.as_str()) + } + ReadFromReplicaStrategy::RoundRobin => { + self.round_robin_read_from_replica(slot_map_value) + } + }, + SlotAddr::ReplicaRequired => self.round_robin_read_from_replica(slot_map_value), + } + } + + pub(crate) fn connection_for_route( + &self, + route: &Route, + ) -> Option> { + self.lookup_route(route).or_else(|| { + if route.slot_addr() != SlotAddr::Master { + self.lookup_route(&Route::new(route.slot(), SlotAddr::Master)) + } else { + None + } + }) + } + + pub(crate) fn all_node_connections( + &self, + ) -> impl Iterator> + '_ { + self.connection_map.iter().map(move |item| { + let (node, address) = (item.key(), item.value()); + (node.clone(), address.user_connection.conn.clone()) + }) + } + + pub(crate) fn all_primary_connections( + &self, + ) -> impl Iterator> + '_ { + self.slot_map + .addresses_for_all_primaries() + .into_iter() + .flat_map(|addr| self.connection_for_address(addr)) + } + + pub(crate) fn node_for_address(&self, address: &str) -> Option> { + self.connection_map + .get(address) + .map(|item| item.value().clone()) + } + + pub(crate) fn connection_for_address( + &self, + address: &str, + ) -> Option> { + self.connection_map.get(address).map(|item| { + let (address, conn) = (item.key(), item.value()); + (address.clone(), conn.user_connection.conn.clone()) + }) + } + + pub(crate) fn random_connections( + &self, + amount: usize, + conn_type: ConnectionType, + ) -> impl Iterator> + '_ { + self.connection_map + .iter() + .choose_multiple(&mut rand::thread_rng(), amount) + .into_iter() + .map(move |item| { + let (address, node) = (item.key(), item.value()); + let conn = node.get_connection(&conn_type); + (address.clone(), conn) + }) + } + + pub(crate) fn replace_or_add_connection_for_address( + &self, + address: impl Into, + node: ClusterNode, + ) -> String { + let address = address.into(); + self.connection_map.insert(address.clone(), node); + address + } + + pub(crate) fn remove_node(&self, address: &String) -> Option> { + self.connection_map + .remove(address) + .map(|(_key, value)| value) + } + + pub(crate) fn len(&self) -> usize { + self.connection_map.len() + } + + pub(crate) fn get_current_topology_hash(&self) -> TopologyHash { + self.topology_hash + } + + /// Returns true if the connections container contains no connections. + pub(crate) fn is_empty(&self) -> bool { + self.connection_map.is_empty() + } +} + +#[cfg(test)] +mod tests { + use std::collections::HashSet; + + use crate::cluster_routing::Slot; + + use super::*; + impl ClusterNode + where + Connection: Clone, + { + pub(crate) fn new_only_with_user_conn(user_connection: Connection) -> Self { + let ip = None; + Self { + user_connection: (user_connection, ip).into(), + management_connection: None, + } + } + } + fn remove_nodes(container: &ConnectionsContainer, addresses: &[&str]) { + for address in addresses { + container.remove_node(&(*address).into()); + } + } + + fn remove_all_connections(container: &ConnectionsContainer) { + remove_nodes( + container, + &[ + "primary1", + "primary2", + "primary3", + "replica2-1", + "replica3-1", + "replica3-2", + ], + ); + } + + fn one_of( + connection: Option>, + expected_connections: &[usize], + ) -> bool { + let found = connection.unwrap().1; + expected_connections.contains(&found) + } + fn create_cluster_node( + connection: usize, + use_management_connections: bool, + ) -> ClusterNode { + let ip = None; + ClusterNode::new( + (connection, ip).into(), + if use_management_connections { + Some((connection * 10, ip).into()) + } else { + None + }, + ) + } + + fn create_container_with_strategy( + stragey: ReadFromReplicaStrategy, + use_management_connections: bool, + ) -> ConnectionsContainer { + let slot_map = SlotMap::new( + vec![ + Slot::new(1, 1000, "primary1".to_owned(), Vec::new()), + Slot::new( + 1002, + 2000, + "primary2".to_owned(), + vec!["replica2-1".to_owned()], + ), + Slot::new( + 2001, + 3000, + "primary3".to_owned(), + vec!["replica3-1".to_owned(), "replica3-2".to_owned()], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, // this argument shouldn't matter, since we overload the RFR strategy. + ); + let connection_map = DashMap::new(); + connection_map.insert( + "primary1".into(), + create_cluster_node(1, use_management_connections), + ); + connection_map.insert( + "primary2".into(), + create_cluster_node(2, use_management_connections), + ); + connection_map.insert( + "primary3".into(), + create_cluster_node(3, use_management_connections), + ); + connection_map.insert( + "replica2-1".into(), + create_cluster_node(21, use_management_connections), + ); + connection_map.insert( + "replica3-1".into(), + create_cluster_node(31, use_management_connections), + ); + connection_map.insert( + "replica3-2".into(), + create_cluster_node(32, use_management_connections), + ); + + ConnectionsContainer { + slot_map, + connection_map, + read_from_replica_strategy: stragey, + topology_hash: 0, + } + } + + fn create_container() -> ConnectionsContainer { + create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, false) + } + + #[test] + fn get_connection_for_primary_route() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(0, SlotAddr::Master)) + .is_none()); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(500, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(1000, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert!(container + .connection_for_route(&Route::new(1001, SlotAddr::Master)) + .is_none()); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1002, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1500, SlotAddr::Master)) + .unwrap() + .1 + ); + + assert_eq!( + 3, + container + .connection_for_route(&Route::new(2001, SlotAddr::Master)) + .unwrap() + .1 + ); + } + + #[test] + fn get_connection_for_replica_route() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(1001, SlotAddr::ReplicaOptional)) + .is_none()); + + assert_eq!( + 21, + container + .connection_for_route(&Route::new(1002, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 21, + container + .connection_for_route(&Route::new(1500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert!(one_of( + container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)), + &[31, 32], + )); + } + + #[test] + fn get_primary_connection_for_replica_route_if_no_replicas_were_added() { + let container = create_container(); + + assert!(container + .connection_for_route(&Route::new(0, SlotAddr::ReplicaOptional)) + .is_none()); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 1, + container + .connection_for_route(&Route::new(1000, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + } + + #[test] + fn get_replica_connection_for_replica_route_if_some_but_not_all_replicas_were_removed() { + let container = create_container(); + container.remove_node(&"replica3-2".into()); + + assert_eq!( + 31, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)) + .unwrap() + .1 + ); + } + + #[test] + fn get_replica_connection_for_replica_route_if_replica_is_required_even_if_strategy_is_always_from_primary( + ) { + let container = + create_container_with_strategy(ReadFromReplicaStrategy::AlwaysFromPrimary, false); + + assert!(one_of( + container.connection_for_route(&Route::new(2001, SlotAddr::ReplicaRequired)), + &[31, 32], + )); + } + + #[test] + fn get_primary_connection_for_replica_route_if_all_replicas_were_removed() { + let container = create_container(); + remove_nodes(&container, &["replica2-1", "replica3-1", "replica3-2"]); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1002, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 2, + container + .connection_for_route(&Route::new(1500, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + + assert_eq!( + 3, + container + .connection_for_route(&Route::new(2001, SlotAddr::ReplicaOptional)) + .unwrap() + .1 + ); + } + + #[test] + fn get_connection_by_address() { + let container = create_container(); + + assert!(container.connection_for_address("foobar").is_none()); + + assert_eq!(1, container.connection_for_address("primary1").unwrap().1); + assert_eq!(2, container.connection_for_address("primary2").unwrap().1); + assert_eq!(3, container.connection_for_address("primary3").unwrap().1); + assert_eq!( + 21, + container.connection_for_address("replica2-1").unwrap().1 + ); + assert_eq!( + 31, + container.connection_for_address("replica3-1").unwrap().1 + ); + assert_eq!( + 32, + container.connection_for_address("replica3-2").unwrap().1 + ); + } + + #[test] + fn get_connection_by_address_returns_none_if_connection_was_removed() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + assert!(container.connection_for_address("primary1").is_none()); + } + + #[test] + fn get_connection_by_address_returns_added_connection() { + let container = create_container(); + let address = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + + assert_eq!( + (address, 4), + container.connection_for_address("foobar").unwrap() + ); + } + + #[test] + fn get_random_connections_without_repetitions() { + let container = create_container(); + + let random_connections: HashSet<_> = container + .random_connections(3, ConnectionType::User) + .map(|pair| pair.1) + .collect(); + + assert_eq!(random_connections.len(), 3); + assert!(random_connections + .iter() + .all(|connection| [1, 2, 3, 21, 31, 32].contains(connection))); + } + + #[test] + fn get_random_connections_returns_none_if_all_connections_were_removed() { + let container = create_container(); + remove_all_connections(&container); + + assert_eq!( + 0, + container + .random_connections(1, ConnectionType::User) + .count() + ); + } + + #[test] + fn get_random_connections_returns_added_connection() { + let container = create_container(); + remove_all_connections(&container); + let address = container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + let random_connections: Vec<_> = container + .random_connections(1, ConnectionType::User) + .collect(); + + assert_eq!(vec![(address, 4)], random_connections); + } + + #[test] + fn get_random_connections_is_bound_by_the_number_of_connections_in_the_map() { + let container = create_container(); + let mut random_connections: Vec<_> = container + .random_connections(1000, ConnectionType::User) + .map(|pair| pair.1) + .collect(); + random_connections.sort(); + + assert_eq!(random_connections, vec![1, 2, 3, 21, 31, 32]); + } + + #[test] + fn get_random_management_connections() { + let container = create_container_with_strategy(ReadFromReplicaStrategy::RoundRobin, true); + let mut random_connections: Vec<_> = container + .random_connections(1000, ConnectionType::PreferManagement) + .map(|pair| pair.1) + .collect(); + random_connections.sort(); + + assert_eq!(random_connections, vec![10, 20, 30, 210, 310, 320]); + } + + #[test] + fn get_all_user_connections() { + let container = create_container(); + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3, 21, 31, 32], connections); + } + + #[test] + fn get_all_user_connections_returns_added_connection() { + let container = create_container(); + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3, 4, 21, 31, 32], connections); + } + + #[test] + fn get_all_user_connections_does_not_return_removed_connection() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + let mut connections: Vec<_> = container + .all_node_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![2, 3, 21, 31, 32], connections); + } + + #[test] + fn get_all_primaries() { + let container = create_container(); + + let mut connections: Vec<_> = container + .all_primary_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![1, 2, 3], connections); + } + + #[test] + fn get_all_primaries_does_not_return_removed_connection() { + let container = create_container(); + container.remove_node(&"primary1".into()); + + let mut connections: Vec<_> = container + .all_primary_connections() + .map(|conn| conn.1) + .collect(); + connections.sort(); + + assert_eq!(vec![2, 3], connections); + } + + #[test] + fn len_is_adjusted_on_removals_and_additions() { + let container = create_container(); + + assert_eq!(container.len(), 6); + + container.remove_node(&"primary1".into()); + assert_eq!(container.len(), 5); + + container.replace_or_add_connection_for_address( + "foobar", + ClusterNode::new_only_with_user_conn(4), + ); + assert_eq!(container.len(), 6); + } + + #[test] + fn len_is_not_adjusted_on_removals_of_nonexisting_connections_or_additions_of_existing_connections( + ) { + let container = create_container(); + + assert_eq!(container.len(), 6); + + container.remove_node(&"foobar".into()); + assert_eq!(container.len(), 6); + + container.replace_or_add_connection_for_address( + "primary1", + ClusterNode::new_only_with_user_conn(4), + ); + assert_eq!(container.len(), 6); + } + + #[test] + fn remove_node_returns_connection_if_it_exists() { + let container = create_container(); + + let connection = container.remove_node(&"primary1".into()); + assert_eq!(connection, Some(ClusterNode::new_only_with_user_conn(1))); + + let non_connection = container.remove_node(&"foobar".into()); + assert_eq!(non_connection, None); + } + + #[test] + fn test_is_empty() { + let container = create_container(); + + assert!(!container.is_empty()); + container.remove_node(&"primary1".into()); + assert!(!container.is_empty()); + container.remove_node(&"primary2".into()); + container.remove_node(&"primary3".into()); + assert!(!container.is_empty()); + + container.remove_node(&"replica2-1".into()); + container.remove_node(&"replica3-1".into()); + assert!(!container.is_empty()); + + container.remove_node(&"replica3-2".into()); + assert!(container.is_empty()); + } + + #[test] + fn is_primary_returns_true_for_known_primary() { + let container = create_container(); + + assert!(container.is_primary(&"primary1".into())); + } + + #[test] + fn is_primary_returns_false_for_known_replica() { + let container = create_container(); + + assert!(!container.is_primary(&"replica2-1".into())); + } + + #[test] + fn is_primary_returns_false_for_removed_node() { + let container = create_container(); + let address = "primary1".into(); + container.remove_node(&address); + + assert!(!container.is_primary(&address)); + } + + #[test] + fn test_extend_connection_map() { + let mut container = create_container(); + let mut current_addresses: Vec<_> = container + .all_node_connections() + .map(|conn| conn.0) + .collect(); + + let new_node = "new_primary1".to_string(); + // Check that `new_node` not exists in the current + assert!(container.connection_for_address(&new_node).is_none()); + // Create new connection map + let new_connection_map = DashMap::new(); + new_connection_map.insert(new_node.clone(), create_cluster_node(1, false)); + + // Extend the current connection map + container.extend_connection_map(ConnectionsMap(new_connection_map)); + + // Check that the new addresses vector contains both the new node and all previous nodes + let mut new_addresses: Vec<_> = container + .all_node_connections() + .map(|conn| conn.0) + .collect(); + current_addresses.push(new_node); + current_addresses.sort(); + new_addresses.sort(); + assert_eq!(current_addresses, new_addresses); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs new file mode 100644 index 0000000000..7de2493000 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/connections_logic.rs @@ -0,0 +1,481 @@ +use std::net::SocketAddr; + +use super::{ + connections_container::{ClusterNode, ConnectionWithIp}, + Connect, +}; +use crate::{ + aio::{ConnectionLike, DisconnectNotifier, Runtime}, + client::GlideConnectionOptions, + cluster::get_connection_info, + cluster_client::ClusterParams, + ErrorKind, RedisError, RedisResult, +}; + +use futures::prelude::*; +use futures_util::{future::BoxFuture, join}; +use tracing::warn; + +pub(crate) type ConnectionFuture = futures::future::Shared>; +/// Cluster node for async connections +#[doc(hidden)] +pub type AsyncClusterNode = ClusterNode>; + +#[doc(hidden)] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum RefreshConnectionType { + // Refresh only user connections + OnlyUserConnection, + // Refresh only management connections + OnlyManagementConnection, + // Refresh all connections: both management and user connections. + AllConnections, +} + +fn failed_management_connection( + addr: &str, + user_conn: ConnectionWithIp>, + err: RedisError, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Send + Clone + Sync + Connect + 'static, +{ + warn!( + "Failed to create management connection for node `{:?}`. Error: `{:?}`", + addr, err + ); + ConnectAndCheckResult::ManagementConnectionFailed { + node: AsyncClusterNode::new(user_conn, None), + err, + } +} + +pub(crate) async fn get_or_create_conn( + addr: &str, + node: Option>, + params: &ClusterParams, + conn_type: RefreshConnectionType, + glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Send + Clone + Sync + Connect + 'static, +{ + if let Some(node) = node { + // We won't check whether the DNS address of this node has changed and now points to a new IP. + // Instead, we depend on managed Redis services to close the connection for refresh if the node has changed. + match check_node_connections(&node, params, conn_type, addr).await { + None => Ok(node), + Some(conn_type) => connect_and_check( + addr, + params.clone(), + None, + conn_type, + Some(node), + glide_connection_options, + ) + .await + .get_node(), + } + } else { + connect_and_check( + addr, + params.clone(), + None, + conn_type, + None, + glide_connection_options, + ) + .await + .get_node() + } +} + +fn create_async_node( + user_conn: ConnectionWithIp, + management_conn: Option>, +) -> AsyncClusterNode +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + AsyncClusterNode::new( + user_conn.into_future(), + management_conn.map(|conn| conn.into_future()), + ) +} + +pub(crate) async fn connect_and_check_all_connections( + addr: &str, + params: ClusterParams, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + match future::join( + create_connection( + addr, + params.clone(), + socket_addr, + false, + glide_connection_options.clone(), + ), + create_connection( + addr, + params.clone(), + socket_addr, + true, + glide_connection_options, + ), + ) + .await + { + (Ok(conn_1), Ok(conn_2)) => { + // Both connections were successfully established + let mut user_conn: ConnectionWithIp = conn_1; + let mut management_conn: ConnectionWithIp = conn_2; + if let Err(err) = setup_user_connection(&mut user_conn.conn, params).await { + return err.into(); + } + match setup_management_connection(&mut management_conn.conn).await { + Ok(_) => ConnectAndCheckResult::Success(create_async_node( + user_conn, + Some(management_conn), + )), + Err(err) => failed_management_connection(addr, user_conn.into_future(), err), + } + } + (Ok(mut connection), Err(err)) | (Err(err), Ok(mut connection)) => { + // Only a single connection was successfully established. Use it for the user connection + match setup_user_connection(&mut connection.conn, params).await { + Ok(_) => failed_management_connection(addr, connection.into_future(), err), + Err(err) => err.into(), + } + } + (Err(err_1), Err(err_2)) => { + // Neither of the connections succeeded. + RedisError::from(( + ErrorKind::IoError, + "Failed to refresh both connections", + format!( + "Node: {:?} received errors: `{:?}`, `{:?}`", + addr, err_1, err_2 + ), + )) + .into() + } + } +} + +async fn connect_and_check_only_management_conn( + addr: &str, + params: ClusterParams, + socket_addr: Option, + prev_node: AsyncClusterNode, + disconnect_notifier: Option>, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + match create_connection::( + addr, + params.clone(), + socket_addr, + true, + GlideConnectionOptions { + push_sender: None, + disconnect_notifier, + }, + ) + .await + { + Err(conn_err) => failed_management_connection(addr, prev_node.user_connection, conn_err), + + Ok(mut connection) => { + if let Err(err) = setup_management_connection(&mut connection.conn).await { + return failed_management_connection(addr, prev_node.user_connection, err); + } + + ConnectAndCheckResult::Success(ClusterNode { + user_connection: prev_node.user_connection, + management_connection: Some(connection.into_future()), + }) + } + } +} + +#[doc(hidden)] +#[must_use] +pub enum ConnectAndCheckResult { + // Returns a node that was fully connected according to the request. + Success(AsyncClusterNode), + // Returns a node that failed to create a management connection, but has a working user connection. + ManagementConnectionFailed { + node: AsyncClusterNode, + err: RedisError, + }, + // Request failed completely, could not return a node with any working connection. + Failed(RedisError), +} + +impl ConnectAndCheckResult { + pub fn get_node(self) -> RedisResult> { + match self { + ConnectAndCheckResult::Success(node) => Ok(node), + ConnectAndCheckResult::ManagementConnectionFailed { node, .. } => Ok(node), + ConnectAndCheckResult::Failed(err) => Err(err), + } + } + + pub fn get_error(self) -> Option { + match self { + ConnectAndCheckResult::Success(_) => None, + ConnectAndCheckResult::ManagementConnectionFailed { err, .. } => Some(err), + ConnectAndCheckResult::Failed(err) => Some(err), + } + } +} + +impl From for ConnectAndCheckResult { + fn from(value: RedisError) -> Self { + ConnectAndCheckResult::Failed(value) + } +} + +impl From> for ConnectAndCheckResult { + fn from(value: AsyncClusterNode) -> Self { + ConnectAndCheckResult::Success(value) + } +} + +impl From>> for ConnectAndCheckResult { + fn from(value: RedisResult>) -> Self { + match value { + Ok(value) => value.into(), + Err(err) => err.into(), + } + } +} + +#[doc(hidden)] +pub async fn connect_and_check( + addr: &str, + params: ClusterParams, + socket_addr: Option, + conn_type: RefreshConnectionType, + node: Option>, + glide_connection_options: GlideConnectionOptions, +) -> ConnectAndCheckResult +where + C: ConnectionLike + Connect + Send + Sync + 'static + Clone, +{ + match conn_type { + RefreshConnectionType::OnlyUserConnection => { + let user_conn = match create_and_setup_user_connection( + addr, + params.clone(), + socket_addr, + glide_connection_options, + ) + .await + { + Ok(tuple) => tuple, + Err(err) => return err.into(), + }; + let management_conn = node.and_then(|node| node.management_connection); + AsyncClusterNode::new(user_conn.into_future(), management_conn).into() + } + RefreshConnectionType::OnlyManagementConnection => { + // Refreshing only the management connection requires the node to exist alongside a user connection. Otherwise, refresh all connections. + match node { + Some(node) => { + connect_and_check_only_management_conn( + addr, + params, + socket_addr, + node, + glide_connection_options.disconnect_notifier, + ) + .await + } + None => { + connect_and_check_all_connections( + addr, + params, + socket_addr, + glide_connection_options, + ) + .await + } + } + } + RefreshConnectionType::AllConnections => { + connect_and_check_all_connections(addr, params, socket_addr, glide_connection_options) + .await + } + } +} + +async fn create_and_setup_user_connection( + node: &str, + params: ClusterParams, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let mut connection: ConnectionWithIp = create_connection( + node, + params.clone(), + socket_addr, + false, + glide_connection_options, + ) + .await?; + setup_user_connection(&mut connection.conn, params).await?; + Ok(connection) +} + +async fn setup_user_connection(conn: &mut C, params: ClusterParams) -> RedisResult<()> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let read_from_replicas = params.read_from_replicas + != crate::cluster_slotmap::ReadFromReplicaStrategy::AlwaysFromPrimary; + let connection_timeout = params.connection_timeout; + check_connection(conn, connection_timeout).await?; + if read_from_replicas { + // If READONLY is sent to primary nodes, it will have no effect + crate::cmd("READONLY").query_async(conn).await?; + } + Ok(()) +} + +#[doc(hidden)] +pub const MANAGEMENT_CONN_NAME: &str = "glide_management_connection"; + +async fn setup_management_connection(conn: &mut C) -> RedisResult<()> +where + C: ConnectionLike + Connect + Send + 'static, +{ + crate::cmd("CLIENT") + .arg(&["SETNAME", MANAGEMENT_CONN_NAME]) + .query_async(conn) + .await?; + Ok(()) +} + +async fn create_connection( + node: &str, + mut params: ClusterParams, + socket_addr: Option, + is_management: bool, + mut glide_connection_options: GlideConnectionOptions, +) -> RedisResult> +where + C: ConnectionLike + Connect + Send + 'static, +{ + let connection_timeout = params.connection_timeout; + let response_timeout = params.response_timeout; + // ignore pubsub subscriptions and push notifications for management connections + if is_management { + params.pubsub_subscriptions = None; + } + let info = get_connection_info(node, params)?; + // management connection does not require notifications or disconnect notifications + if is_management { + glide_connection_options.disconnect_notifier = None; + } + C::connect( + info, + response_timeout, + connection_timeout, + socket_addr, + glide_connection_options, + ) + .await + .map(|conn| conn.into()) +} + +/// The function returns None if the checked connection/s are healthy. Otherwise, it returns the type of the unhealthy connection/s. +#[allow(dead_code)] +#[doc(hidden)] +pub async fn check_node_connections( + node: &AsyncClusterNode, + params: &ClusterParams, + conn_type: RefreshConnectionType, + address: &str, +) -> Option +where + C: ConnectionLike + Send + 'static + Clone, +{ + let timeout = params.connection_timeout; + let (check_mgmt_connection, check_user_connection) = match conn_type { + RefreshConnectionType::OnlyUserConnection => (false, true), + RefreshConnectionType::OnlyManagementConnection => (true, false), + RefreshConnectionType::AllConnections => (true, true), + }; + let check = |conn, timeout, conn_type| async move { + match check_connection(&mut conn.await, timeout).await { + Ok(_) => false, + Err(err) => { + warn!( + "The {} connection for node {} is unhealthy. Error: {:?}", + conn_type, address, err + ); + true + } + } + }; + let (mgmt_failed, user_failed) = join!( + async { + if !check_mgmt_connection { + return false; + } + match node.management_connection.clone() { + Some(connection) => check(connection.conn, timeout, "management").await, + None => { + warn!("The management connection for node {} isn't set", address); + true + } + } + }, + async { + if !check_user_connection { + return false; + } + let conn = node.user_connection.conn.clone(); + check(conn, timeout, "user").await + }, + ); + + match (mgmt_failed, user_failed) { + (true, true) => Some(RefreshConnectionType::AllConnections), + (true, false) => Some(RefreshConnectionType::OnlyManagementConnection), + (false, true) => Some(RefreshConnectionType::OnlyUserConnection), + (false, false) => None, + } +} + +async fn check_connection(conn: &mut C, timeout: std::time::Duration) -> RedisResult<()> +where + C: ConnectionLike + Send + 'static, +{ + Runtime::locate() + .timeout(timeout, crate::cmd("PING").query_async::<_, String>(conn)) + .await??; + Ok(()) +} + +/// Splits a string address into host and port. If the passed address cannot be parsed, None is returned. +/// [addr] should be in the following format: ":". +pub(crate) fn get_host_and_port_from_addr(addr: &str) -> Option<(&str, u16)> { + let parts: Vec<&str> = addr.split(':').collect(); + if parts.len() != 2 { + return None; + } + let host = parts.first().unwrap(); + let port = parts.get(1).unwrap(); + port.parse::().ok().map(|port| (*host, port)) +} diff --git a/glide-core/redis-rs/redis/src/cluster_async/mod.rs b/glide-core/redis-rs/redis/src/cluster_async/mod.rs new file mode 100644 index 0000000000..be7beb79b7 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_async/mod.rs @@ -0,0 +1,2656 @@ +//! This module provides async functionality for Redis Cluster. +//! +//! By default, [`ClusterConnection`] makes use of [`MultiplexedConnection`] and maintains a pool +//! of connections to each node in the cluster. While it generally behaves similarly to +//! the sync cluster module, certain commands do not route identically, due most notably to +//! a current lack of support for routing commands to multiple nodes. +//! +//! Also note that pubsub functionality is not currently provided by this module. +//! +//! # Example +//! ```rust,no_run +//! use redis::cluster::ClusterClient; +//! use redis::AsyncCommands; +//! +//! async fn fetch_an_integer() -> String { +//! let nodes = vec!["redis://127.0.0.1/"]; +//! let client = ClusterClient::new(nodes).unwrap(); +//! let mut connection = client.get_async_connection(None).await.unwrap(); +//! let _: () = connection.set("test", "test_data").await.unwrap(); +//! let rv: String = connection.get("test").await.unwrap(); +//! return rv; +//! } +//! ``` + +mod connections_container; +mod connections_logic; +/// Exposed only for testing. +pub mod testing { + pub use super::connections_container::ConnectionWithIp; + pub use super::connections_logic::*; +} +use crate::{ + client::GlideConnectionOptions, + cluster_routing::{Routable, RoutingInfo}, + cluster_slotmap::SlotMap, + cluster_topology::SLOT_SIZE, + cmd, + commands::cluster_scan::{cluster_scan, ClusterScanArgs, ObjectType, ScanStateRC}, + FromRedisValue, InfoDict, ToRedisArgs, +}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use async_std::task::{spawn, JoinHandle}; +use dashmap::DashMap; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use futures::executor::block_on; +use std::{ + collections::{HashMap, HashSet}, + fmt, io, mem, + net::{IpAddr, SocketAddr}, + pin::Pin, + sync::{ + atomic::{self, AtomicUsize, Ordering}, + Arc, Mutex, + }, + task::{self, Poll}, + time::SystemTime, +}; +#[cfg(feature = "tokio-comp")] +use tokio::task::JoinHandle; + +#[cfg(feature = "tokio-comp")] +use crate::aio::DisconnectNotifier; + +use crate::{ + aio::{get_socket_addrs, ConnectionLike, MultiplexedConnection, Runtime}, + cluster::slot_cmd, + cluster_async::connections_logic::{ + get_host_and_port_from_addr, get_or_create_conn, ConnectionFuture, RefreshConnectionType, + }, + cluster_client::{ClusterParams, RetryParams}, + cluster_routing::{ + self, MultipleNodeRoutingInfo, Redirect, ResponsePolicy, Route, SingleNodeRoutingInfo, + SlotAddr, + }, + cluster_topology::{ + calculate_topology, get_slot, SlotRefreshState, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES, + DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL, DEFAULT_REFRESH_SLOTS_RETRY_MAX_INTERVAL, + }, + connection::{PubSubSubscriptionInfo, PubSubSubscriptionKind}, + push_manager::PushInfo, + Cmd, ConnectionInfo, ErrorKind, IntoConnectionInfo, RedisError, RedisFuture, RedisResult, + Value, +}; +use futures::stream::{FuturesUnordered, StreamExt}; +use std::time::Duration; + +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use crate::aio::{async_std::AsyncStd, RedisRuntime}; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use backoff_std_async::future::retry; +#[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] +use backoff_std_async::{Error as BackoffError, ExponentialBackoff}; + +#[cfg(feature = "tokio-comp")] +use async_trait::async_trait; +#[cfg(feature = "tokio-comp")] +use backoff_tokio::future::retry; +#[cfg(feature = "tokio-comp")] +use backoff_tokio::{Error as BackoffError, ExponentialBackoff}; +#[cfg(feature = "tokio-comp")] +use tokio::{sync::Notify, time::timeout}; + +use dispose::{Disposable, Dispose}; +use futures::{future::BoxFuture, prelude::*, ready}; +use pin_project_lite::pin_project; +use tokio::sync::{ + mpsc, + oneshot::{self, Receiver}, + RwLock, +}; +use tracing::{debug, info, trace, warn}; + +use self::{ + connections_container::{ConnectionAndAddress, ConnectionType, ConnectionsMap}, + connections_logic::connect_and_check, +}; + +/// This represents an async Redis Cluster connection. It stores the +/// underlying connections maintained for each node in the cluster, as well +/// as common parameters for connecting to nodes and executing commands. +#[derive(Clone)] +pub struct ClusterConnection(mpsc::Sender>); + +impl ClusterConnection +where + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, +{ + pub(crate) async fn new( + initial_nodes: &[ConnectionInfo], + cluster_params: ClusterParams, + push_sender: Option>, + ) -> RedisResult> { + ClusterConnInner::new(initial_nodes, cluster_params, push_sender) + .await + .map(|inner| { + let (tx, mut rx) = mpsc::channel::>(100); + let stream = async move { + let _ = stream::poll_fn(move |cx| rx.poll_recv(cx)) + .map(Ok) + .forward(inner) + .await; + }; + #[cfg(feature = "tokio-comp")] + tokio::spawn(stream); + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + AsyncStd::spawn(stream); + + ClusterConnection(tx) + }) + } + + /// Special handling for `SCAN` command, using `cluster_scan`. + /// If you wish to use a match pattern, use [`cluster_scan_with_pattern`]. + /// Perform a `SCAN` command on a Redis cluster, using scan state object in order to handle changes in topology + /// and make sure that all keys that were in the cluster from start to end of the scan are scanned. + /// In order to make sure all keys in the cluster scanned, topology refresh occurs more frequently and may affect performance. + /// + /// # Arguments + /// + /// * `scan_state_rc` - A reference to the scan state, For initiating new scan send [`ScanStateRC::new()`], + /// for each subsequent iteration use the returned [`ScanStateRC`]. + /// * `count` - An optional count of keys requested, + /// the amount returned can vary and not obligated to return exactly count. + /// * `object_type` - An optional [`ObjectType`] enum of requested key redis type. + /// + /// # Returns + /// + /// A [`ScanStateRC`] for the updated state of the scan and the vector of keys that were found in the scan. + /// structure of returned value: + /// `Ok((ScanStateRC, Vec))` + /// + /// When the scan is finished [`ScanStateRC`] will be None, and can be checked by calling `scan_state_wrapper.is_finished()`. + /// + /// # Example + /// ```rust,no_run + /// use redis::cluster::ClusterClient; + /// use redis::{ScanStateRC, FromRedisValue, from_redis_value, Value, ObjectType}; + /// + /// async fn scan_all_cluster() -> Vec { + /// let nodes = vec!["redis://127.0.0.1/"]; + /// let client = ClusterClient::new(nodes).unwrap(); + /// let mut connection = client.get_async_connection(None).await.unwrap(); + /// let mut scan_state_rc = ScanStateRC::new(); + /// let mut keys: Vec = vec![]; + /// loop { + /// let (next_cursor, scan_keys): (ScanStateRC, Vec) = + /// connection.cluster_scan(scan_state_rc, None, None).await.unwrap(); + /// scan_state_rc = next_cursor; + /// let mut scan_keys = scan_keys + /// .into_iter() + /// .map(|v| from_redis_value(&v).unwrap()) + /// .collect::>(); // Change the type of `keys` to `Vec` + /// keys.append(&mut scan_keys); + /// if scan_state_rc.is_finished() { + /// break; + /// } + /// } + /// keys + /// } + /// ``` + pub async fn cluster_scan( + &mut self, + scan_state_rc: ScanStateRC, + count: Option, + object_type: Option, + ) -> RedisResult<(ScanStateRC, Vec)> { + let cluster_scan_args = ClusterScanArgs::new(scan_state_rc, None, count, object_type); + self.route_cluster_scan(cluster_scan_args).await + } + + /// Special handling for `SCAN` command, using `cluster_scan_with_pattern`. + /// It is a special case of [`cluster_scan`], with an additional match pattern. + /// Perform a `SCAN` command on a Redis cluster, using scan state object in order to handle changes in topology + /// and make sure that all keys that were in the cluster from start to end of the scan are scanned. + /// In order to make sure all keys in the cluster scanned, topology refresh occurs more frequently and may affect performance. + /// + /// # Arguments + /// + /// * `scan_state_rc` - A reference to the scan state, For initiating new scan send [`ScanStateRC::new()`], + /// for each subsequent iteration use the returned [`ScanStateRC`]. + /// * `match_pattern` - A match pattern of requested keys. + /// * `count` - An optional count of keys requested, + /// the amount returned can vary and not obligated to return exactly count. + /// * `object_type` - An optional [`ObjectType`] enum of requested key redis type. + /// + /// # Returns + /// + /// A [`ScanStateRC`] for the updated state of the scan and the vector of keys that were found in the scan. + /// structure of returned value: + /// `Ok((ScanStateRC, Vec))` + /// + /// When the scan is finished [`ScanStateRC`] will be None, and can be checked by calling `scan_state_wrapper.is_finished()`. + /// + /// # Example + /// ```rust,no_run + /// use redis::cluster::ClusterClient; + /// use redis::{ScanStateRC, FromRedisValue, from_redis_value, Value, ObjectType}; + /// + /// async fn scan_all_cluster() -> Vec { + /// let nodes = vec!["redis://127.0.0.1/"]; + /// let client = ClusterClient::new(nodes).unwrap(); + /// let mut connection = client.get_async_connection(None).await.unwrap(); + /// let mut scan_state_rc = ScanStateRC::new(); + /// let mut keys: Vec = vec![]; + /// loop { + /// let (next_cursor, scan_keys): (ScanStateRC, Vec) = + /// connection.cluster_scan_with_pattern(scan_state_rc, b"my_key", None, None).await.unwrap(); + /// scan_state_rc = next_cursor; + /// let mut scan_keys = scan_keys + /// .into_iter() + /// .map(|v| from_redis_value(&v).unwrap()) + /// .collect::>(); // Change the type of `keys` to `Vec` + /// keys.append(&mut scan_keys); + /// if scan_state_rc.is_finished() { + /// break; + /// } + /// } + /// keys + /// } + /// ``` + pub async fn cluster_scan_with_pattern( + &mut self, + scan_state_rc: ScanStateRC, + match_pattern: K, + count: Option, + object_type: Option, + ) -> RedisResult<(ScanStateRC, Vec)> { + let cluster_scan_args = ClusterScanArgs::new( + scan_state_rc, + Some(match_pattern.to_redis_args().concat()), + count, + object_type, + ); + self.route_cluster_scan(cluster_scan_args).await + } + + /// Route cluster scan to be handled by internal cluster_scan command + async fn route_cluster_scan( + &mut self, + cluster_scan_args: ClusterScanArgs, + ) -> RedisResult<(ScanStateRC, Vec)> { + let (sender, receiver) = oneshot::channel(); + self.0 + .send(Message { + cmd: CmdArg::ClusterScan { cluster_scan_args }, + sender, + }) + .await + .map_err(|_| { + RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to send command", + )) + })?; + receiver + .await + .unwrap_or_else(|_| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to receive command", + ))) + }) + .map(|response| match response { + Response::ClusterScanResult(new_scan_state_ref, key) => (new_scan_state_ref, key), + Response::Single(_) => unreachable!(), + Response::Multiple(_) => unreachable!(), + }) + } + + /// Send a command to the given `routing`. If `routing` is [None], it will be computed from `cmd`. + pub async fn route_command( + &mut self, + cmd: &Cmd, + routing: cluster_routing::RoutingInfo, + ) -> RedisResult { + trace!("route_command"); + let (sender, receiver) = oneshot::channel(); + self.0 + .send(Message { + cmd: CmdArg::Cmd { + cmd: Arc::new(cmd.clone()), + routing: routing.into(), + }, + sender, + }) + .await + .map_err(|_| { + RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to send command", + )) + })?; + receiver + .await + .unwrap_or_else(|_| { + Err(RedisError::from(io::Error::new( + io::ErrorKind::BrokenPipe, + "redis_cluster: Unable to receive command", + ))) + }) + .map(|response| match response { + Response::Single(value) => value, + Response::Multiple(_) => unreachable!(), + Response::ClusterScanResult(_, _) => unreachable!(), + }) + } + + /// Send commands in `pipeline` to the given `route`. If `route` is [None], it will be computed from `pipeline`. + pub async fn route_pipeline<'a>( + &'a mut self, + pipeline: &'a crate::Pipeline, + offset: usize, + count: usize, + route: SingleNodeRoutingInfo, + ) -> RedisResult> { + let (sender, receiver) = oneshot::channel(); + self.0 + .send(Message { + cmd: CmdArg::Pipeline { + pipeline: Arc::new(pipeline.clone()), + offset, + count, + route: route.into(), + }, + sender, + }) + .await + .map_err(|_| RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))?; + + receiver + .await + .unwrap_or_else(|_| Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe)))) + .map(|response| match response { + Response::Multiple(values) => values, + Response::Single(_) => unreachable!(), + Response::ClusterScanResult(_, _) => unreachable!(), + }) + } +} + +#[cfg(feature = "tokio-comp")] +#[derive(Clone)] +struct TokioDisconnectNotifier { + disconnect_notifier: Arc, +} + +#[cfg(feature = "tokio-comp")] +#[async_trait] +impl DisconnectNotifier for TokioDisconnectNotifier { + fn notify_disconnect(&mut self) { + self.disconnect_notifier.notify_one(); + } + + async fn wait_for_disconnect_with_timeout(&self, max_wait: &Duration) { + let _ = timeout(*max_wait, async { + self.disconnect_notifier.notified().await; + }) + .await; + } + + fn clone_box(&self) -> Box { + Box::new(self.clone()) + } +} + +#[cfg(feature = "tokio-comp")] +impl TokioDisconnectNotifier { + fn new() -> TokioDisconnectNotifier { + TokioDisconnectNotifier { + disconnect_notifier: Arc::new(Notify::new()), + } + } +} + +type ConnectionMap = connections_container::ConnectionsMap>; +type ConnectionsContainer = + self::connections_container::ConnectionsContainer>; + +pub(crate) struct InnerCore { + pub(crate) conn_lock: RwLock>, + cluster_params: ClusterParams, + pending_requests: Mutex>>, + slot_refresh_state: SlotRefreshState, + initial_nodes: Vec, + subscriptions_by_address: RwLock>, + unassigned_subscriptions: RwLock, + glide_connection_options: GlideConnectionOptions, +} + +pub(crate) type Core = Arc>; + +impl InnerCore +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + // return address of node for slot + pub(crate) async fn get_address_from_slot( + &self, + slot: u16, + slot_addr: SlotAddr, + ) -> Option { + self.conn_lock + .read() + .await + .slot_map + .get_node_address_for_slot(slot, slot_addr) + } + + // return epoch of node + pub(crate) async fn get_address_epoch(&self, node_address: &str) -> Result { + let command = cmd("CLUSTER").arg("INFO").to_owned(); + let node_conn = self + .conn_lock + .read() + .await + .connection_for_address(node_address) + .ok_or(RedisError::from(( + ErrorKind::ResponseError, + "Failed to parse cluster info", + )))?; + + let cluster_info = node_conn.1.await.req_packed_command(&command).await; + match cluster_info { + Ok(value) => { + let info_dict: Result = + FromRedisValue::from_redis_value(&value); + if let Ok(info_dict) = info_dict { + let epoch = info_dict.get("cluster_my_epoch"); + if let Some(epoch) = epoch { + Ok(epoch) + } else { + Err(RedisError::from(( + ErrorKind::ResponseError, + "Failed to get epoch from cluster info", + ))) + } + } else { + Err(RedisError::from(( + ErrorKind::ResponseError, + "Failed to parse cluster info", + ))) + } + } + Err(redis_error) => Err(redis_error), + } + } + + // return slots of node + pub(crate) async fn get_slots_of_address(&self, node_address: &str) -> Vec { + self.conn_lock + .read() + .await + .slot_map + .get_slots_of_node(node_address) + } +} + +pub(crate) struct ClusterConnInner { + pub(crate) inner: Core, + state: ConnectionState, + #[allow(clippy::complexity)] + in_flight_requests: stream::FuturesUnordered>>>, + refresh_error: Option, + // Handler of the periodic check task. + periodic_checks_handler: Option>, + // Handler of fast connection validation task + connections_validation_handler: Option>, +} + +impl Dispose for ClusterConnInner { + fn dispose(self) { + if let Some(handle) = self.periodic_checks_handler { + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + block_on(handle.cancel()); + #[cfg(feature = "tokio-comp")] + handle.abort() + } + if let Some(handle) = self.connections_validation_handler { + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + block_on(handle.cancel()); + #[cfg(feature = "tokio-comp")] + handle.abort() + } + } +} + +#[derive(Clone)] +pub(crate) enum InternalRoutingInfo { + SingleNode(InternalSingleNodeRouting), + MultiNode((MultipleNodeRoutingInfo, Option)), +} + +#[derive(PartialEq, Clone, Debug)] +/// Represents different policies for refreshing the cluster slots. +pub(crate) enum RefreshPolicy { + /// `Throttable` indicates that the refresh operation can be throttled, + /// meaning it can be delayed or rate-limited if necessary. + Throttable, + /// `NotThrottable` indicates that the refresh operation should not be throttled, + /// meaning it should be executed immediately without any delay or rate-limiting. + NotThrottable, +} + +impl From for InternalRoutingInfo { + fn from(value: cluster_routing::RoutingInfo) -> Self { + match value { + cluster_routing::RoutingInfo::SingleNode(route) => { + InternalRoutingInfo::SingleNode(route.into()) + } + cluster_routing::RoutingInfo::MultiNode(routes) => { + InternalRoutingInfo::MultiNode(routes) + } + } + } +} + +impl From> for InternalRoutingInfo { + fn from(value: InternalSingleNodeRouting) -> Self { + InternalRoutingInfo::SingleNode(value) + } +} + +#[derive(Clone)] +pub(crate) enum InternalSingleNodeRouting { + Random, + SpecificNode(Route), + ByAddress(String), + Connection { + address: String, + conn: ConnectionFuture, + }, + Redirect { + redirect: Redirect, + previous_routing: Box>, + }, +} + +impl Default for InternalSingleNodeRouting { + fn default() -> Self { + Self::Random + } +} + +impl From for InternalSingleNodeRouting { + fn from(value: SingleNodeRoutingInfo) -> Self { + match value { + SingleNodeRoutingInfo::Random => InternalSingleNodeRouting::Random, + SingleNodeRoutingInfo::SpecificNode(route) => { + InternalSingleNodeRouting::SpecificNode(route) + } + SingleNodeRoutingInfo::RandomPrimary => { + InternalSingleNodeRouting::SpecificNode(Route::new_random_primary()) + } + SingleNodeRoutingInfo::ByAddress { host, port } => { + InternalSingleNodeRouting::ByAddress(format!("{host}:{port}")) + } + } + } +} + +#[derive(Clone)] +enum CmdArg { + Cmd { + cmd: Arc, + routing: InternalRoutingInfo, + }, + Pipeline { + pipeline: Arc, + offset: usize, + count: usize, + route: InternalSingleNodeRouting, + }, + ClusterScan { + // struct containing the arguments for the cluster scan command - scan state cursor, match pattern, count and object type. + cluster_scan_args: ClusterScanArgs, + }, +} + +fn route_for_pipeline(pipeline: &crate::Pipeline) -> RedisResult> { + fn route_for_command(cmd: &Cmd) -> Option { + match cluster_routing::RoutingInfo::for_routable(cmd) { + Some(cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) => None, + Some(cluster_routing::RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(route), + )) => Some(route), + Some(cluster_routing::RoutingInfo::SingleNode( + SingleNodeRoutingInfo::RandomPrimary, + )) => Some(Route::new_random_primary()), + Some(cluster_routing::RoutingInfo::MultiNode(_)) => None, + Some(cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + .. + })) => None, + None => None, + } + } + + // Find first specific slot and send to it. There's no need to check If later commands + // should be routed to a different slot, since the server will return an error indicating this. + pipeline.cmd_iter().map(route_for_command).try_fold( + None, + |chosen_route, next_cmd_route| match (chosen_route, next_cmd_route) { + (None, _) => Ok(next_cmd_route), + (_, None) => Ok(chosen_route), + (Some(chosen_route), Some(next_cmd_route)) => { + if chosen_route.slot() != next_cmd_route.slot() { + Err((ErrorKind::CrossSlot, "Received crossed slots in pipeline").into()) + } else if chosen_route.slot_addr() == SlotAddr::ReplicaOptional { + Ok(Some(next_cmd_route)) + } else { + Ok(Some(chosen_route)) + } + } + }, + ) +} + +fn boxed_sleep(duration: Duration) -> BoxFuture<'static, ()> { + #[cfg(feature = "tokio-comp")] + return Box::pin(tokio::time::sleep(duration)); + + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + return Box::pin(async_std::task::sleep(duration)); +} + +pub(crate) enum Response { + Single(Value), + ClusterScanResult(ScanStateRC, Vec), + Multiple(Vec), +} + +pub(crate) enum OperationTarget { + Node { address: String }, + FanOut, + NotFound, +} +type OperationResult = Result; + +impl From for OperationTarget { + fn from(address: String) -> Self { + OperationTarget::Node { address } + } +} + +struct Message { + cmd: CmdArg, + sender: oneshot::Sender>, +} + +enum RecoverFuture { + RecoverSlots(BoxFuture<'static, RedisResult<()>>), + Reconnect(BoxFuture<'static, ()>), +} + +enum ConnectionState { + PollComplete, + Recover(RecoverFuture), +} + +impl fmt::Debug for ConnectionState { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!( + f, + "{}", + match self { + ConnectionState::PollComplete => "PollComplete", + ConnectionState::Recover(_) => "Recover", + } + ) + } +} + +#[derive(Clone)] +struct RequestInfo { + cmd: CmdArg, +} + +impl RequestInfo { + fn set_redirect(&mut self, redirect: Option) { + if let Some(redirect) = redirect { + match &mut self.cmd { + CmdArg::Cmd { routing, .. } => match routing { + InternalRoutingInfo::SingleNode(route) => { + let redirect = InternalSingleNodeRouting::Redirect { + redirect, + previous_routing: Box::new(std::mem::take(route)), + } + .into(); + *routing = redirect; + } + InternalRoutingInfo::MultiNode(_) => { + panic!("Cannot redirect multinode requests") + } + }, + CmdArg::Pipeline { route, .. } => { + let redirect = InternalSingleNodeRouting::Redirect { + redirect, + previous_routing: Box::new(std::mem::take(route)), + }; + *route = redirect; + } + // cluster_scan is sent as a normal command internally so we will not reach that point. + CmdArg::ClusterScan { .. } => { + unreachable!() + } + } + } + } + + fn reset_routing(&mut self) { + let fix_route = |route: &mut InternalSingleNodeRouting| { + match route { + InternalSingleNodeRouting::Redirect { + previous_routing, .. + } => { + let previous_routing = std::mem::take(previous_routing.as_mut()); + *route = previous_routing; + } + // If a specific connection is specified, then reconnecting without resetting the routing + // will mean that the request is still routed to the old connection. + InternalSingleNodeRouting::Connection { address, .. } => { + *route = InternalSingleNodeRouting::ByAddress(address.to_string()); + } + _ => {} + } + }; + match &mut self.cmd { + CmdArg::Cmd { routing, .. } => { + if let InternalRoutingInfo::SingleNode(route) = routing { + fix_route(route); + } + } + CmdArg::Pipeline { route, .. } => { + fix_route(route); + } + // cluster_scan is sent as a normal command internally so we will not reach that point. + CmdArg::ClusterScan { .. } => { + unreachable!() + } + } + } +} + +pin_project! { + #[project = RequestStateProj] + enum RequestState { + None, + Future { + #[pin] + future: F, + }, + Sleep { + #[pin] + sleep: BoxFuture<'static, ()>, + }, + } +} + +struct PendingRequest { + retry: u32, + sender: oneshot::Sender>, + info: RequestInfo, +} + +pin_project! { + struct Request { + retry_params: RetryParams, + request: Option>, + #[pin] + future: RequestState>, + } +} + +#[must_use] +enum Next { + Retry { + request: PendingRequest, + }, + RetryBusyLoadingError { + request: PendingRequest, + address: String, + }, + Reconnect { + // if not set, then a reconnect should happen without sending a request afterwards + request: Option>, + target: String, + }, + RefreshSlots { + // if not set, then a slot refresh should happen without sending a request afterwards + request: Option>, + sleep_duration: Option, + }, + ReconnectToInitialNodes { + // if not set, then a reconnect should happen without sending a request afterwards + request: Option>, + }, + Done, +} + +impl Future for Request { + type Output = Next; + + fn poll(mut self: Pin<&mut Self>, cx: &mut task::Context) -> Poll { + let mut this = self.as_mut().project(); + // If the sender is closed, the caller is no longer waiting for the reply, and it is ambiguous + // whether they expect the side-effect of the request to happen or not. + if this.request.is_none() || this.request.as_ref().unwrap().sender.is_closed() { + return Poll::Ready(Next::Done); + } + let future = match this.future.as_mut().project() { + RequestStateProj::Future { future } => future, + RequestStateProj::Sleep { sleep } => { + ready!(sleep.poll(cx)); + return Next::Retry { + request: self.project().request.take().unwrap(), + } + .into(); + } + _ => panic!("Request future must be Some"), + }; + match ready!(future.poll(cx)) { + Ok(item) => { + self.respond(Ok(item)); + Next::Done.into() + } + Err((target, err)) => { + let request = this.request.as_mut().unwrap(); + // TODO - would be nice if we didn't need to repeat this code twice, with & without retries. + if request.retry >= this.retry_params.number_of_retries { + let next = if err.kind() == ErrorKind::AllConnectionsUnavailable { + Next::ReconnectToInitialNodes { request: None }.into() + } else if matches!(err.retry_method(), crate::types::RetryMethod::MovedRedirect) + || matches!(target, OperationTarget::NotFound) + { + Next::RefreshSlots { + request: None, + sleep_duration: None, + } + .into() + } else if matches!(err.retry_method(), crate::types::RetryMethod::Reconnect) { + if let OperationTarget::Node { address } = target { + Next::Reconnect { + request: None, + target: address, + } + .into() + } else { + Next::Done.into() + } + } else { + Next::Done.into() + }; + self.respond(Err(err)); + return next; + } + request.retry = request.retry.saturating_add(1); + + if err.kind() == ErrorKind::AllConnectionsUnavailable { + return Next::ReconnectToInitialNodes { + request: Some(this.request.take().unwrap()), + } + .into(); + } + + let sleep_duration = this.retry_params.wait_time_for_retry(request.retry); + + let address = match target { + OperationTarget::Node { address } => address, + OperationTarget::FanOut => { + trace!("Request error `{}` multi-node request", err); + + // Fanout operation are retried per internal request, and don't need additional retries. + self.respond(Err(err)); + return Next::Done.into(); + } + OperationTarget::NotFound => { + // TODO - this is essentially a repeat of the retirable error. probably can remove duplication. + let mut request = this.request.take().unwrap(); + request.info.reset_routing(); + return Next::RefreshSlots { + request: Some(request), + sleep_duration: Some(sleep_duration), + } + .into(); + } + }; + + warn!("Received request error {} on node {:?}.", err, address); + + match err.retry_method() { + crate::types::RetryMethod::AskRedirect => { + let mut request = this.request.take().unwrap(); + request.info.set_redirect( + err.redirect_node() + .map(|(node, _slot)| Redirect::Ask(node.to_string())), + ); + Next::Retry { request }.into() + } + crate::types::RetryMethod::MovedRedirect => { + let mut request = this.request.take().unwrap(); + request.info.set_redirect( + err.redirect_node() + .map(|(node, _slot)| Redirect::Moved(node.to_string())), + ); + Next::RefreshSlots { + request: Some(request), + sleep_duration: None, + } + .into() + } + crate::types::RetryMethod::WaitAndRetry => { + let sleep_duration = this.retry_params.wait_time_for_retry(request.retry); + // Sleep and retry. + this.future.set(RequestState::Sleep { + sleep: boxed_sleep(sleep_duration), + }); + self.poll(cx) + } + crate::types::RetryMethod::Reconnect => { + let mut request = this.request.take().unwrap(); + // TODO should we reset the redirect here? + request.info.reset_routing(); + warn!("disconnected from {:?}", address); + Next::Reconnect { + request: Some(request), + target: address, + } + .into() + } + crate::types::RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => { + Next::RetryBusyLoadingError { + request: this.request.take().unwrap(), + address, + } + .into() + } + crate::types::RetryMethod::RetryImmediately => Next::Retry { + request: this.request.take().unwrap(), + } + .into(), + crate::types::RetryMethod::NoRetry => { + self.respond(Err(err)); + Next::Done.into() + } + } + } + } + } +} + +impl Request { + fn respond(self: Pin<&mut Self>, msg: RedisResult) { + // If `send` errors the receiver has dropped and thus does not care about the message + let _ = self + .project() + .request + .take() + .expect("Result should only be sent once") + .sender + .send(msg); + } +} + +enum ConnectionCheck { + Found((String, ConnectionFuture)), + OnlyAddress(String), + RandomConnection, +} + +impl ClusterConnInner +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + async fn new( + initial_nodes: &[ConnectionInfo], + cluster_params: ClusterParams, + push_sender: Option>, + ) -> RedisResult> { + let disconnect_notifier = { + #[cfg(feature = "tokio-comp")] + { + Some::>(Box::new(TokioDisconnectNotifier::new())) + } + #[cfg(not(feature = "tokio-comp"))] + None + }; + + let glide_connection_options = GlideConnectionOptions { + push_sender, + disconnect_notifier, + }; + + let connections = Self::create_initial_connections( + initial_nodes, + &cluster_params, + glide_connection_options.clone(), + ) + .await?; + + let topology_checks_interval = cluster_params.topology_checks_interval; + let slots_refresh_rate_limiter = cluster_params.slots_refresh_rate_limit; + let inner = Arc::new(InnerCore { + conn_lock: RwLock::new(ConnectionsContainer::new( + Default::default(), + connections, + cluster_params.read_from_replicas, + 0, + )), + cluster_params: cluster_params.clone(), + pending_requests: Mutex::new(Vec::new()), + slot_refresh_state: SlotRefreshState::new(slots_refresh_rate_limiter), + initial_nodes: initial_nodes.to_vec(), + unassigned_subscriptions: RwLock::new( + if let Some(subs) = cluster_params.pubsub_subscriptions { + subs.clone() + } else { + PubSubSubscriptionInfo::new() + }, + ), + subscriptions_by_address: RwLock::new(Default::default()), + glide_connection_options, + }); + let mut connection = ClusterConnInner { + inner, + in_flight_requests: Default::default(), + refresh_error: None, + state: ConnectionState::PollComplete, + periodic_checks_handler: None, + connections_validation_handler: None, + }; + Self::refresh_slots_and_subscriptions_with_retries( + connection.inner.clone(), + &RefreshPolicy::NotThrottable, + ) + .await?; + + if let Some(duration) = topology_checks_interval { + let periodic_task = + ClusterConnInner::periodic_topology_check(connection.inner.clone(), duration); + #[cfg(feature = "tokio-comp")] + { + connection.periodic_checks_handler = Some(tokio::spawn(periodic_task)); + } + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + { + connection.periodic_checks_handler = Some(spawn(periodic_task)); + } + } + + let connections_validation_interval = cluster_params.connections_validation_interval; + if let Some(duration) = connections_validation_interval { + let connections_validation_handler = + ClusterConnInner::connections_validation_task(connection.inner.clone(), duration); + #[cfg(feature = "tokio-comp")] + { + connection.connections_validation_handler = + Some(tokio::spawn(connections_validation_handler)); + } + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + { + connection.connections_validation_handler = + Some(spawn(connections_validation_handler)); + } + } + + Ok(Disposable::new(connection)) + } + + /// Go through each of the initial nodes and attempt to retrieve all IP entries from them. + /// If there's a DNS endpoint that directs to several IP addresses, add all addresses to the initial nodes list. + /// Returns a vector of tuples, each containing a node's address (including the hostname) and its corresponding SocketAddr if retrieved. + pub(crate) async fn try_to_expand_initial_nodes( + initial_nodes: &[ConnectionInfo], + ) -> Vec<(String, Option)> { + stream::iter(initial_nodes) + .fold( + Vec::with_capacity(initial_nodes.len()), + |mut acc, info| async { + let (host, port) = match &info.addr { + crate::ConnectionAddr::Tcp(host, port) => (host, port), + crate::ConnectionAddr::TcpTls { + host, + port, + insecure: _, + tls_params: _, + } => (host, port), + crate::ConnectionAddr::Unix(_) => { + // We don't support multiple addresses for a Unix address. Store the initial node address and continue + acc.push((info.addr.to_string(), None)); + return acc; + } + }; + match get_socket_addrs(host, *port).await { + Ok(socket_addrs) => { + for addr in socket_addrs { + acc.push((info.addr.to_string(), Some(addr))); + } + } + Err(_) => { + // Couldn't find socket addresses, store the initial node address and continue + acc.push((info.addr.to_string(), None)); + } + }; + acc + }, + ) + .await + } + + async fn create_initial_connections( + initial_nodes: &[ConnectionInfo], + params: &ClusterParams, + glide_connection_options: GlideConnectionOptions, + ) -> RedisResult> { + let initial_nodes: Vec<(String, Option)> = + Self::try_to_expand_initial_nodes(initial_nodes).await; + let connections = stream::iter(initial_nodes.iter().cloned()) + .map(|(node_addr, socket_addr)| { + let mut params: ClusterParams = params.clone(); + let glide_connection_options = glide_connection_options.clone(); + // set subscriptions to none, they will be applied upon the topology discovery + params.pubsub_subscriptions = None; + + async move { + let result = connect_and_check( + &node_addr, + params, + socket_addr, + RefreshConnectionType::AllConnections, + None, + glide_connection_options, + ) + .await + .get_node(); + let node_address = if let Some(socket_addr) = socket_addr { + socket_addr.to_string() + } else { + node_addr + }; + result.map(|node| (node_address, node)) + } + }) + .buffer_unordered(initial_nodes.len()) + .fold( + ( + ConnectionsMap(DashMap::with_capacity(initial_nodes.len())), + None, + ), + |connections: (ConnectionMap, Option), addr_conn_res| async move { + match addr_conn_res { + Ok((addr, node)) => { + connections.0 .0.insert(addr, node); + (connections.0, None) + } + Err(e) => (connections.0, Some(e.to_string())), + } + }, + ) + .await; + if connections.0 .0.is_empty() { + return Err(RedisError::from(( + ErrorKind::IoError, + "Failed to create initial connections", + connections.1.unwrap_or("".to_string()), + ))); + } + info!("Connected to initial nodes:\n{}", connections.0); + Ok(connections.0) + } + + fn reconnect_to_initial_nodes(&mut self) -> impl Future { + let inner = self.inner.clone(); + async move { + let connection_map = match Self::create_initial_connections( + &inner.initial_nodes, + &inner.cluster_params, + inner.glide_connection_options.clone(), + ) + .await + { + Ok(map) => map, + Err(err) => { + warn!("Can't reconnect to initial nodes: `{err}`"); + return; + } + }; + let mut write_lock = inner.conn_lock.write().await; + write_lock.extend_connection_map(connection_map); + drop(write_lock); + if let Err(err) = Self::refresh_slots_and_subscriptions_with_retries( + inner.clone(), + &RefreshPolicy::Throttable, + ) + .await + { + warn!("Can't refresh slots with initial nodes: `{err}`"); + }; + } + } + + // Validate all existing user connections and try to reconnect if nessesary. + // In addition, as a safety measure, drop nodes that do not have any assigned slots. + // This function serves as a cheap alternative to slot_refresh() and thus can be used much more frequently. + // The function does not discover the topology from the cluster and assumes the cached topology is valid. + // In addition, the validation is done by peeking at the state of the underlying transport w/o overhead of additional commands to server. + async fn validate_all_user_connections(inner: Arc>) { + let mut all_valid_conns = HashMap::new(); + // prep connections and clean out these w/o assigned slots, as we might have established connections to unwanted hosts + let mut nodes_to_delete = Vec::new(); + let connections_container = inner.conn_lock.read().await; + + let all_nodes_with_slots: HashSet = connections_container + .slot_map + .addresses_for_all_nodes() + .iter() + .map(|addr| String::from(*addr)) + .collect(); + + connections_container + .all_node_connections() + .for_each(|(addr, con)| { + if all_nodes_with_slots.contains(&addr) { + all_valid_conns.insert(addr.clone(), con.clone()); + } else { + nodes_to_delete.push(addr.clone()); + } + }); + + for addr in &nodes_to_delete { + connections_container.remove_node(addr); + } + + drop(connections_container); + + // identify nodes with closed connection + let mut addrs_to_refresh = Vec::new(); + for (addr, con_fut) in &all_valid_conns { + let con = con_fut.clone().await; + // connection object might be present despite the transport being closed + if con.is_closed() { + // transport is closed, need to refresh + addrs_to_refresh.push(addr.clone()); + } + } + + // identify missing nodes + addrs_to_refresh.extend( + all_nodes_with_slots + .iter() + .filter(|addr| !all_valid_conns.contains_key(*addr)) + .cloned(), + ); + + if !addrs_to_refresh.is_empty() { + // dont try existing nodes since we know a. it does not exist. b. exist but its connection is closed + Self::refresh_connections( + inner.clone(), + addrs_to_refresh, + RefreshConnectionType::AllConnections, + false, + ) + .await; + } + } + + async fn refresh_connections( + inner: Arc>, + addresses: Vec, + conn_type: RefreshConnectionType, + check_existing_conn: bool, + ) { + info!("Started refreshing connections to {:?}", addresses); + let connections_container = inner.conn_lock.read().await; + let cluster_params = &inner.cluster_params; + let subscriptions_by_address = &inner.subscriptions_by_address; + let glide_connection_optons = &inner.glide_connection_options; + + stream::iter(addresses.into_iter()) + .fold( + &*connections_container, + |connections_container, address| async move { + let node_option = if check_existing_conn { + connections_container.remove_node(&address) + } else { + None + }; + + // override subscriptions for this connection + let mut cluster_params = cluster_params.clone(); + let subs_guard = subscriptions_by_address.read().await; + cluster_params.pubsub_subscriptions = subs_guard.get(&address).cloned(); + drop(subs_guard); + let node = get_or_create_conn( + &address, + node_option, + &cluster_params, + conn_type, + glide_connection_optons.clone(), + ) + .await; + match node { + Ok(node) => { + connections_container + .replace_or_add_connection_for_address(address, node); + } + Err(err) => { + warn!( + "Failed to refresh connection for node {}. Error: `{:?}`", + address, err + ); + } + } + connections_container + }, + ) + .await; + info!("refresh connections completed"); + } + + async fn aggregate_results( + receivers: Vec<(Option, oneshot::Receiver>)>, + routing: &MultipleNodeRoutingInfo, + response_policy: Option, + ) -> RedisResult { + let extract_result = |response| match response { + Response::Single(value) => value, + Response::Multiple(_) => unreachable!(), + Response::ClusterScanResult(_, _) => unreachable!(), + }; + + let convert_result = |res: Result, _>| { + res.map_err(|_| RedisError::from((ErrorKind::ResponseError, "request wasn't handled due to internal failure"))) // this happens only if the result sender is dropped before usage. + .and_then(|res| res.map(extract_result)) + }; + + let get_receiver = |(_, receiver): (_, oneshot::Receiver>)| async { + convert_result(receiver.await) + }; + + // TODO - once Value::Error will be merged, these will need to be updated to handle this new value. + match response_policy { + Some(ResponsePolicy::AllSucceeded) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .map(|mut results| results.pop().unwrap()) // unwrap is safe, since at least one function succeeded + } + Some(ResponsePolicy::OneSucceeded) => future::select_ok( + receivers + .into_iter() + .map(|tuple| Box::pin(get_receiver(tuple))), + ) + .await + .map(|(result, _)| result), + Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty) => { + // Attempt to return the first result that isn't `Nil` or an error. + // If no such response is found and all servers returned `Nil`, it indicates that all shards are empty, so return `Nil`. + // If we received only errors, return the last received error. + // If we received a mix of errors and `Nil`s, we can't determine if all shards are empty, + // thus we return the last received error instead of `Nil`. + let num_of_results: usize = receivers.len(); + let mut futures = receivers + .into_iter() + .map(get_receiver) + .collect::>(); + let mut nil_counter = 0; + let mut last_err = None; + while let Some(result) = futures.next().await { + match result { + Ok(Value::Nil) => nil_counter += 1, + Ok(val) => return Ok(val), + Err(e) => last_err = Some(e), + } + } + + if nil_counter == num_of_results { + // All received results are `Nil` + Ok(Value::Nil) + } else { + Err(last_err.unwrap_or_else(|| { + ( + ErrorKind::AllConnectionsUnavailable, + "Couldn't find any connection", + ) + .into() + })) + } + } + Some(ResponsePolicy::Aggregate(op)) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(|results| crate::cluster_routing::aggregate(results, op)) + } + Some(ResponsePolicy::AggregateLogical(op)) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(|results| crate::cluster_routing::logical_aggregate(results, op)) + } + Some(ResponsePolicy::CombineArrays) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(|results| match routing { + MultipleNodeRoutingInfo::MultiSlot(vec) => { + crate::cluster_routing::combine_and_sort_array_results( + results, + vec.iter().map(|(_, indices)| indices), + ) + } + _ => crate::cluster_routing::combine_array_results(results), + }) + } + Some(ResponsePolicy::CombineMaps) => { + future::try_join_all(receivers.into_iter().map(get_receiver)) + .await + .and_then(crate::cluster_routing::combine_map_results) + } + Some(ResponsePolicy::Special) | None => { + // This is our assumption - if there's no coherent way to aggregate the responses, we just map each response to the sender, and pass it to the user. + // TODO - once Value::Error is merged, we can use join_all and report separate errors and also pass successes. + future::try_join_all(receivers.into_iter().map(|(addr, receiver)| async move { + let result = convert_result(receiver.await)?; + // The unwrap here is possible, because if `addr` is None, an error should have been sent on the receiver. + Ok((Value::BulkString(addr.unwrap().as_bytes().to_vec()), result)) + })) + .await + .map(Value::Map) + } + } + } + + // Query a node to discover slot-> master mappings with retries + async fn refresh_slots_and_subscriptions_with_retries( + inner: Arc>, + policy: &RefreshPolicy, + ) -> RedisResult<()> { + let SlotRefreshState { + in_progress, + last_run, + rate_limiter, + } = &inner.slot_refresh_state; + // Ensure only a single slot refresh operation occurs at a time + if in_progress + .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) + .is_err() + { + return Ok(()); + } + let mut skip_slots_refresh = false; + if *policy == RefreshPolicy::Throttable { + // Check if the current slot refresh is triggered before the wait duration has passed + let last_run_rlock = last_run.read().await; + if let Some(last_run_time) = *last_run_rlock { + let passed_time = SystemTime::now() + .duration_since(last_run_time) + .unwrap_or_else(|err| { + warn!( + "Failed to get the duration since the last slot refresh, received error: {:?}", + err + ); + // Setting the passed time to 0 will force the current refresh to continue and reset the stored last_run timestamp with the current one + Duration::from_secs(0) + }); + let wait_duration = rate_limiter.wait_duration(); + if passed_time <= wait_duration { + debug!("Skipping slot refresh as the wait duration hasn't yet passed. Passed time = {:?}, + Wait duration = {:?}", passed_time, wait_duration); + skip_slots_refresh = true; + } + } + } + + let mut res = Ok(()); + if !skip_slots_refresh { + let retry_strategy = ExponentialBackoff { + initial_interval: DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL, + max_interval: DEFAULT_REFRESH_SLOTS_RETRY_MAX_INTERVAL, + max_elapsed_time: None, + ..Default::default() + }; + let retries_counter = AtomicUsize::new(0); + res = retry(retry_strategy, || { + let curr_retry = retries_counter.fetch_add(1, atomic::Ordering::Relaxed); + Self::refresh_slots(inner.clone(), curr_retry) + }) + .await; + } + in_progress.store(false, Ordering::Relaxed); + + Self::refresh_pubsub_subscriptions(inner).await; + + res + } + + pub(crate) async fn check_topology_and_refresh_if_diff( + inner: Arc>, + policy: &RefreshPolicy, + ) -> bool { + let topology_changed = Self::check_for_topology_diff(inner.clone()).await; + if topology_changed { + let _ = Self::refresh_slots_and_subscriptions_with_retries(inner.clone(), policy).await; + } + topology_changed + } + + async fn periodic_topology_check(inner: Arc>, interval_duration: Duration) { + loop { + let _ = boxed_sleep(interval_duration).await; + let topology_changed = + Self::check_topology_and_refresh_if_diff(inner.clone(), &RefreshPolicy::Throttable) + .await; + if !topology_changed { + // This serves as a safety measure for validating pubsub subsctiptions state in case it has drifted + // while topology stayed the same. + // For example, a failed attempt to refresh a connection which is triggered from refresh_pubsub_subscriptions(), + // might leave a node unconnected indefinitely in case topology is stable and no request are attempted to this node. + Self::refresh_pubsub_subscriptions(inner.clone()).await; + } + } + } + + async fn connections_validation_task(inner: Arc>, interval_duration: Duration) { + loop { + if let Some(disconnect_notifier) = + inner.glide_connection_options.disconnect_notifier.clone() + { + disconnect_notifier + .wait_for_disconnect_with_timeout(&interval_duration) + .await; + } else { + let _ = boxed_sleep(interval_duration).await; + } + + Self::validate_all_user_connections(inner.clone()).await; + } + } + + async fn refresh_pubsub_subscriptions(inner: Arc>) { + if inner.cluster_params.protocol != crate::types::ProtocolVersion::RESP3 { + return; + } + + let mut addrs_to_refresh: HashSet = HashSet::new(); + let mut subs_by_address_guard = inner.subscriptions_by_address.write().await; + let mut unassigned_subs_guard = inner.unassigned_subscriptions.write().await; + let conns_read_guard = inner.conn_lock.read().await; + + // validate active subscriptions location + subs_by_address_guard.retain(|current_address, address_subs| { + address_subs.retain(|kind, channels_patterns| { + channels_patterns.retain(|channel_pattern| { + let new_slot = get_slot(channel_pattern); + let mut valid = false; + if let Some((new_address, _)) = conns_read_guard + .connection_for_route(&Route::new(new_slot, SlotAddr::Master)) + { + if *new_address == *current_address { + valid = true; + } + } + // no new address or new address differ - move to unassigned and store this address for connection reset + if !valid { + // need to drop the original connection for clearing the subscription in the server, avoiding possible double-receivers + if conns_read_guard + .connection_for_address(current_address) + .is_some() + { + addrs_to_refresh.insert(current_address.clone()); + } + + unassigned_subs_guard + .entry(*kind) + .and_modify(|channels_patterns| { + channels_patterns.insert(channel_pattern.clone()); + }) + .or_insert(HashSet::from([channel_pattern.clone()])); + } + valid + }); + !channels_patterns.is_empty() + }); + !address_subs.is_empty() + }); + + // try to assign new addresses + unassigned_subs_guard.retain(|kind: &PubSubSubscriptionKind, channels_patterns| { + channels_patterns.retain(|channel_pattern| { + let new_slot = get_slot(channel_pattern); + if let Some((new_address, _)) = + conns_read_guard.connection_for_route(&Route::new(new_slot, SlotAddr::Master)) + { + // need to drop the new connection so the subscription will be picked up in setup_connection() + addrs_to_refresh.insert(new_address.clone()); + + let e = subs_by_address_guard + .entry(new_address.clone()) + .or_insert(PubSubSubscriptionInfo::new()); + + e.entry(*kind) + .or_insert(HashSet::new()) + .insert(channel_pattern.clone()); + + return false; + } + true + }); + !channels_patterns.is_empty() + }); + + drop(conns_read_guard); + drop(unassigned_subs_guard); + drop(subs_by_address_guard); + + if !addrs_to_refresh.is_empty() { + // immediately trigger connection reestablishment + Self::refresh_connections( + inner.clone(), + addrs_to_refresh.into_iter().collect(), + RefreshConnectionType::AllConnections, + false, + ) + .await; + } + } + + /// Queries log2n nodes (where n represents the number of cluster nodes) to determine whether their + /// topology view differs from the one currently stored in the connection manager. + /// Returns true if change was detected, otherwise false. + async fn check_for_topology_diff(inner: Arc>) -> bool { + let read_guard = inner.conn_lock.read().await; + let num_of_nodes: usize = read_guard.len(); + // TODO: Starting from Rust V1.67, integers has logarithms support. + // When we no longer need to support Rust versions < 1.67, remove fast_math and transition to the ilog2 function. + let num_of_nodes_to_query = + std::cmp::max(fast_math::log2_raw(num_of_nodes as f32) as usize, 1); + let (res, failed_connections) = calculate_topology_from_random_nodes( + &inner, + num_of_nodes_to_query, + &read_guard, + DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES, + ) + .await; + + if let Ok((_, found_topology_hash)) = res { + if read_guard.get_current_topology_hash() != found_topology_hash { + return true; + } + } + drop(read_guard); + + if !failed_connections.is_empty() { + Self::refresh_connections( + inner, + failed_connections, + RefreshConnectionType::OnlyManagementConnection, + true, + ) + .await; + } + + false + } + + async fn refresh_slots( + inner: Arc>, + curr_retry: usize, + ) -> Result<(), BackoffError> { + // Update the slot refresh last run timestamp + let now = SystemTime::now(); + let mut last_run_wlock = inner.slot_refresh_state.last_run.write().await; + *last_run_wlock = Some(now); + drop(last_run_wlock); + Self::refresh_slots_inner(inner, curr_retry) + .await + .map_err(|err| { + if curr_retry > DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES { + BackoffError::Permanent(err) + } else { + BackoffError::from(err) + } + }) + } + + pub(crate) fn check_if_all_slots_covered(slot_map: &SlotMap) -> bool { + let mut slots_covered = 0; + for (end, slots) in slot_map.slots.iter() { + slots_covered += end.saturating_sub(slots.start).saturating_add(1); + } + slots_covered == SLOT_SIZE + } + + // Query a node to discover slot-> master mappings + async fn refresh_slots_inner(inner: Arc>, curr_retry: usize) -> RedisResult<()> { + let read_guard = inner.conn_lock.read().await; + let num_of_nodes = read_guard.len(); + const MAX_REQUESTED_NODES: usize = 10; + let num_of_nodes_to_query = std::cmp::min(num_of_nodes, MAX_REQUESTED_NODES); + let (new_slots, topology_hash) = calculate_topology_from_random_nodes( + &inner, + num_of_nodes_to_query, + &read_guard, + curr_retry, + ) + .await + .0?; + let connections = &*read_guard; + // Create a new connection vector of the found nodes + let mut nodes = new_slots.values().flatten().collect::>(); + nodes.sort_unstable(); + nodes.dedup(); + let nodes_len = nodes.len(); + let addresses_and_connections_iter = stream::iter(nodes) + .fold( + Vec::with_capacity(nodes_len), + |mut addrs_and_conns, addr| async move { + if let Some(node) = connections.node_for_address(addr.as_str()) { + addrs_and_conns.push((addr, Some(node))); + return addrs_and_conns; + } + // If it's a DNS endpoint, it could have been stored in the existing connections vector using the resolved IP address instead of the DNS endpoint's name. + // We shall check if a connection is already exists under the resolved IP name. + let (host, port) = match get_host_and_port_from_addr(addr) { + Some((host, port)) => (host, port), + None => { + addrs_and_conns.push((addr, None)); + return addrs_and_conns; + } + }; + let conn = get_socket_addrs(host, port) + .await + .ok() + .map(|mut socket_addresses| { + socket_addresses + .find_map(|addr| connections.node_for_address(&addr.to_string())) + }) + .unwrap_or(None); + addrs_and_conns.push((addr, conn)); + addrs_and_conns + }, + ) + .await; + let new_connections: ConnectionMap = stream::iter(addresses_and_connections_iter) + .fold( + ConnectionsMap(DashMap::with_capacity(nodes_len)), + |connections, (addr, node)| async { + let mut cluster_params = inner.cluster_params.clone(); + let subs_guard = inner.subscriptions_by_address.read().await; + cluster_params.pubsub_subscriptions = subs_guard.get(addr).cloned(); + drop(subs_guard); + let node = get_or_create_conn( + addr, + node, + &cluster_params, + RefreshConnectionType::AllConnections, + inner.glide_connection_options.clone(), + ) + .await; + if let Ok(node) = node { + connections.0.insert(addr.into(), node); + } + connections + }, + ) + .await; + + drop(read_guard); + info!("refresh_slots found nodes:\n{new_connections}"); + // Replace the current slot map and connection vector with the new ones + let mut write_guard = inner.conn_lock.write().await; + *write_guard = ConnectionsContainer::new( + new_slots, + new_connections, + inner.cluster_params.read_from_replicas, + topology_hash, + ); + Ok(()) + } + + async fn execute_on_multiple_nodes<'a>( + cmd: &'a Arc, + routing: &'a MultipleNodeRoutingInfo, + core: Core, + response_policy: Option, + ) -> OperationResult { + trace!("execute_on_multiple_nodes"); + let connections_container = core.conn_lock.read().await; + if connections_container.is_empty() { + return OperationResult::Err(( + OperationTarget::FanOut, + ( + ErrorKind::AllConnectionsUnavailable, + "No connections found for multi-node operation", + ) + .into(), + )); + } + + // This function maps the connections to senders & receivers of one-shot channels, and the receivers are mapped to `PendingRequest`s. + // This allows us to pass the new `PendingRequest`s to `try_request`, while letting `execute_on_multiple_nodes` wait on the receivers + // for all of the individual requests to complete. + #[allow(clippy::type_complexity)] // The return value is complex, but indentation and linebreaks make it human readable. + fn into_channels( + iterator: impl Iterator< + Item = Option<(Arc, ConnectionAndAddress>)>, + >, + ) -> ( + Vec<(Option, Receiver>)>, + Vec>>, + ) { + iterator + .map(|tuple_opt| { + let (sender, receiver) = oneshot::channel(); + if let Some((cmd, conn, address)) = + tuple_opt.map(|(cmd, (address, conn))| (cmd, conn, address)) + { + ( + (Some(address.clone()), receiver), + Some(PendingRequest { + retry: 0, + sender, + info: RequestInfo { + cmd: CmdArg::Cmd { + cmd, + routing: InternalSingleNodeRouting::Connection { + address, + conn, + } + .into(), + }, + }, + }), + ) + } else { + let _ = sender.send(Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Connection not found", + ) + .into())); + ((None, receiver), None) + } + }) + .unzip() + } + + let (receivers, requests): (Vec<_>, Vec<_>) = match routing { + MultipleNodeRoutingInfo::AllNodes => into_channels( + connections_container + .all_node_connections() + .map(|tuple| Some((cmd.clone(), tuple))), + ), + MultipleNodeRoutingInfo::AllMasters => into_channels( + connections_container + .all_primary_connections() + .map(|tuple| Some((cmd.clone(), tuple))), + ), + MultipleNodeRoutingInfo::MultiSlot(slots) => { + into_channels(slots.iter().map(|(route, indices)| { + connections_container + .connection_for_route(route) + .map(|tuple| { + let new_cmd = crate::cluster_routing::command_for_multi_slot_indices( + cmd.as_ref(), + indices.iter(), + ); + (Arc::new(new_cmd), tuple) + }) + })) + } + }; + + drop(connections_container); + core.pending_requests + .lock() + .unwrap() + .extend(requests.into_iter().flatten()); + + Self::aggregate_results(receivers, routing, response_policy) + .await + .map(Response::Single) + .map_err(|err| (OperationTarget::FanOut, err)) + } + + pub(crate) async fn try_cmd_request( + cmd: Arc, + routing: InternalRoutingInfo, + core: Core, + ) -> OperationResult { + let routing = match routing { + // commands that are sent to multiple nodes are handled here. + InternalRoutingInfo::MultiNode((multi_node_routing, response_policy)) => { + return Self::execute_on_multiple_nodes( + &cmd, + &multi_node_routing, + core, + response_policy, + ) + .await; + } + + InternalRoutingInfo::SingleNode(routing) => routing, + }; + trace!("route request to single node"); + + // if we reached this point, we're sending the command only to single node, and we need to find the + // right connection to the node. + let (address, mut conn) = Self::get_connection(routing, core, Some(cmd.clone())) + .await + .map_err(|err| (OperationTarget::NotFound, err))?; + conn.req_packed_command(&cmd) + .await + .map(Response::Single) + .map_err(|err| (address.into(), err)) + } + + async fn try_pipeline_request( + pipeline: Arc, + offset: usize, + count: usize, + conn: impl Future>, + ) -> OperationResult { + trace!("try_pipeline_request"); + let (address, mut conn) = conn.await.map_err(|err| (OperationTarget::NotFound, err))?; + conn.req_packed_commands(&pipeline, offset, count) + .await + .map(Response::Multiple) + .map_err(|err| (OperationTarget::Node { address }, err)) + } + + async fn try_request(info: RequestInfo, core: Core) -> OperationResult { + match info.cmd { + CmdArg::Cmd { cmd, routing } => Self::try_cmd_request(cmd, routing, core).await, + CmdArg::Pipeline { + pipeline, + offset, + count, + route, + } => { + Self::try_pipeline_request( + pipeline, + offset, + count, + Self::get_connection(route, core, None), + ) + .await + } + CmdArg::ClusterScan { + cluster_scan_args, .. + } => { + let core = core; + let scan_result = cluster_scan(core, cluster_scan_args).await; + match scan_result { + Ok((scan_state_ref, values)) => { + Ok(Response::ClusterScanResult(scan_state_ref, values)) + } + // TODO: After routing issues with sending to random node on not-key based commands are resolved, + // this error should be handled in the same way as other errors and not fan-out. + Err(err) => Err((OperationTarget::FanOut, err)), + } + } + } + } + + async fn get_connection( + routing: InternalSingleNodeRouting, + core: Core, + cmd: Option>, + ) -> RedisResult<(String, C)> { + let read_guard = core.conn_lock.read().await; + let mut asking = false; + + let conn_check = match routing { + InternalSingleNodeRouting::Redirect { + redirect: Redirect::Moved(moved_addr), + .. + } => read_guard + .connection_for_address(moved_addr.as_str()) + .map_or( + ConnectionCheck::OnlyAddress(moved_addr), + ConnectionCheck::Found, + ), + InternalSingleNodeRouting::Redirect { + redirect: Redirect::Ask(ask_addr), + .. + } => { + asking = true; + read_guard.connection_for_address(ask_addr.as_str()).map_or( + ConnectionCheck::OnlyAddress(ask_addr), + ConnectionCheck::Found, + ) + } + InternalSingleNodeRouting::SpecificNode(route) => { + match read_guard.connection_for_route(&route) { + Some((conn, address)) => ConnectionCheck::Found((conn, address)), + None => { + // No connection is found for the given route: + // - For key-based commands, attempt redirection to a random node, + // hopefully to be redirected afterwards by a MOVED error. + // - For non-key-based commands, avoid attempting redirection to a random node + // as it wouldn't result in MOVED hints and can lead to unwanted results + // (e.g., sending management command to a different node than the user asked for); instead, raise the error. + let routable_cmd = cmd.and_then(|cmd| Routable::command(&*cmd)); + if routable_cmd.is_some() + && !RoutingInfo::is_key_routing_command(&routable_cmd.unwrap()) + { + return Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Requested connection not found for route", + format!("{route:?}"), + ) + .into()); + } else { + warn!("No connection found for route `{route:?}`. Attempting redirection to a random node."); + ConnectionCheck::RandomConnection + } + } + } + } + InternalSingleNodeRouting::Random => ConnectionCheck::RandomConnection, + InternalSingleNodeRouting::Connection { address, conn } => { + return Ok((address, conn.await)); + } + InternalSingleNodeRouting::ByAddress(address) => { + if let Some((address, conn)) = read_guard.connection_for_address(&address) { + return Ok((address, conn.await)); + } else { + return Err(( + ErrorKind::ConnectionNotFoundForRoute, + "Requested connection not found", + address, + ) + .into()); + } + } + }; + drop(read_guard); + + let (address, mut conn) = match conn_check { + ConnectionCheck::Found((address, connection)) => (address, connection.await), + ConnectionCheck::OnlyAddress(addr) => { + let mut this_conn_params = core.cluster_params.clone(); + let subs_guard = core.subscriptions_by_address.read().await; + this_conn_params.pubsub_subscriptions = subs_guard.get(addr.as_str()).cloned(); + drop(subs_guard); + match connect_and_check::( + &addr, + this_conn_params, + None, + RefreshConnectionType::AllConnections, + None, + core.glide_connection_options.clone(), + ) + .await + .get_node() + { + Ok(node) => { + let connection_clone = node.user_connection.conn.clone().await; + let connections = core.conn_lock.read().await; + let address = connections.replace_or_add_connection_for_address(addr, node); + drop(connections); + (address, connection_clone) + } + Err(err) => { + return Err(err); + } + } + } + ConnectionCheck::RandomConnection => { + let read_guard = core.conn_lock.read().await; + let (random_address, random_conn_future) = read_guard + .random_connections(1, ConnectionType::User) + .next() + .ok_or(RedisError::from(( + ErrorKind::AllConnectionsUnavailable, + "No random connection found", + )))?; + return Ok((random_address, random_conn_future.await)); + } + }; + + if asking { + let _ = conn.req_packed_command(&crate::cmd::cmd("ASKING")).await; + } + Ok((address, conn)) + } + + fn poll_recover(&mut self, cx: &mut task::Context<'_>) -> Poll> { + let recover_future = match &mut self.state { + ConnectionState::PollComplete => return Poll::Ready(Ok(())), + ConnectionState::Recover(future) => future, + }; + match recover_future { + RecoverFuture::RecoverSlots(ref mut future) => match ready!(future.as_mut().poll(cx)) { + Ok(_) => { + trace!("Recovered!"); + self.state = ConnectionState::PollComplete; + Poll::Ready(Ok(())) + } + Err(err) => { + trace!("Recover slots failed!"); + *future = Box::pin(Self::refresh_slots_and_subscriptions_with_retries( + self.inner.clone(), + &RefreshPolicy::Throttable, + )); + Poll::Ready(Err(err)) + } + }, + RecoverFuture::Reconnect(ref mut future) => { + ready!(future.as_mut().poll(cx)); + trace!("Reconnected connections"); + self.state = ConnectionState::PollComplete; + Poll::Ready(Ok(())) + } + } + } + + async fn handle_loading_error( + core: Core, + info: RequestInfo, + address: String, + retry: u32, + ) -> OperationResult { + let is_primary = core.conn_lock.read().await.is_primary(&address); + + if !is_primary { + // If the connection is a replica, remove the connection and retry. + // The connection will be established again on the next call to refresh slots once the replica is no longer in loading state. + core.conn_lock.read().await.remove_node(&address); + } else { + // If the connection is primary, just sleep and retry + let sleep_duration = core.cluster_params.retry_params.wait_time_for_retry(retry); + boxed_sleep(sleep_duration).await; + } + + Self::try_request(info, core).await + } + + fn poll_complete(&mut self, cx: &mut task::Context<'_>) -> Poll { + let mut poll_flush_action = PollFlushAction::None; + + let mut pending_requests_guard = self.inner.pending_requests.lock().unwrap(); + if !pending_requests_guard.is_empty() { + let mut pending_requests = mem::take(&mut *pending_requests_guard); + for request in pending_requests.drain(..) { + // Drop the request if none is waiting for a response to free up resources for + // requests callers care about (load shedding). It will be ambiguous whether the + // request actually goes through regardless. + if request.sender.is_closed() { + continue; + } + + let future = Self::try_request(request.info.clone(), self.inner.clone()).boxed(); + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future: RequestState::Future { future }, + })); + } + *pending_requests_guard = pending_requests; + } + drop(pending_requests_guard); + + loop { + let result = match Pin::new(&mut self.in_flight_requests).poll_next(cx) { + Poll::Ready(Some(result)) => result, + Poll::Ready(None) | Poll::Pending => break, + }; + match result { + Next::Done => {} + Next::Retry { request } => { + let future = Self::try_request(request.info.clone(), self.inner.clone()); + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future: RequestState::Future { + future: Box::pin(future), + }, + })); + } + Next::RetryBusyLoadingError { request, address } => { + // TODO - do we also want to try and reconnect to replica if it is loading? + let future = Self::handle_loading_error( + self.inner.clone(), + request.info.clone(), + address, + request.retry, + ); + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future: RequestState::Future { + future: Box::pin(future), + }, + })); + } + Next::RefreshSlots { + request, + sleep_duration, + } => { + poll_flush_action = + poll_flush_action.change_state(PollFlushAction::RebuildSlots); + if let Some(request) = request { + let future: RequestState< + Pin + Send>>, + > = match sleep_duration { + Some(sleep_duration) => RequestState::Sleep { + sleep: boxed_sleep(sleep_duration), + }, + None => RequestState::Future { + future: Box::pin(Self::try_request( + request.info.clone(), + self.inner.clone(), + )), + }, + }; + self.in_flight_requests.push(Box::pin(Request { + retry_params: self.inner.cluster_params.retry_params.clone(), + request: Some(request), + future, + })); + } + } + Next::Reconnect { + request, target, .. + } => { + poll_flush_action = + poll_flush_action.change_state(PollFlushAction::Reconnect(vec![target])); + if let Some(request) = request { + self.inner.pending_requests.lock().unwrap().push(request); + } + } + Next::ReconnectToInitialNodes { request } => { + poll_flush_action = poll_flush_action + .change_state(PollFlushAction::ReconnectFromInitialConnections); + if let Some(request) = request { + self.inner.pending_requests.lock().unwrap().push(request); + } + } + } + } + + if matches!(poll_flush_action, PollFlushAction::None) { + if self.in_flight_requests.is_empty() { + Poll::Ready(poll_flush_action) + } else { + Poll::Pending + } + } else { + Poll::Ready(poll_flush_action) + } + } + + fn send_refresh_error(&mut self) { + if self.refresh_error.is_some() { + if let Some(mut request) = Pin::new(&mut self.in_flight_requests) + .iter_pin_mut() + .find(|request| request.request.is_some()) + { + (*request) + .as_mut() + .respond(Err(self.refresh_error.take().unwrap())); + } else if let Some(request) = self.inner.pending_requests.lock().unwrap().pop() { + let _ = request.sender.send(Err(self.refresh_error.take().unwrap())); + } + } + } +} + +enum PollFlushAction { + None, + RebuildSlots, + Reconnect(Vec), + ReconnectFromInitialConnections, +} + +impl PollFlushAction { + fn change_state(self, next_state: PollFlushAction) -> PollFlushAction { + match (self, next_state) { + (PollFlushAction::None, next_state) => next_state, + (next_state, PollFlushAction::None) => next_state, + (PollFlushAction::ReconnectFromInitialConnections, _) + | (_, PollFlushAction::ReconnectFromInitialConnections) => { + PollFlushAction::ReconnectFromInitialConnections + } + + (PollFlushAction::RebuildSlots, _) | (_, PollFlushAction::RebuildSlots) => { + PollFlushAction::RebuildSlots + } + + (PollFlushAction::Reconnect(mut addrs), PollFlushAction::Reconnect(new_addrs)) => { + addrs.extend(new_addrs); + Self::Reconnect(addrs) + } + } + } +} + +impl Sink> for Disposable> +where + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, +{ + type Error = (); + + fn poll_ready(self: Pin<&mut Self>, _cx: &mut task::Context) -> Poll> { + Poll::Ready(Ok(())) + } + + fn start_send(self: Pin<&mut Self>, msg: Message) -> Result<(), Self::Error> { + let Message { cmd, sender } = msg; + + let info = RequestInfo { cmd }; + + self.inner + .pending_requests + .lock() + .unwrap() + .push(PendingRequest { + retry: 0, + sender, + info, + }); + Ok(()) + } + + fn poll_flush( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + trace!("poll_flush: {:?}", self.state); + loop { + self.send_refresh_error(); + + if let Err(err) = ready!(self.as_mut().poll_recover(cx)) { + // We failed to reconnect, while we will try again we will report the + // error if we can to avoid getting trapped in an infinite loop of + // trying to reconnect + self.refresh_error = Some(err); + + // Give other tasks a chance to progress before we try to recover + // again. Since the future may not have registered a wake up we do so + // now so the task is not forgotten + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + match ready!(self.poll_complete(cx)) { + PollFlushAction::None => return Poll::Ready(Ok(())), + PollFlushAction::RebuildSlots => { + self.state = ConnectionState::Recover(RecoverFuture::RecoverSlots(Box::pin( + ClusterConnInner::refresh_slots_and_subscriptions_with_retries( + self.inner.clone(), + &RefreshPolicy::Throttable, + ), + ))); + } + PollFlushAction::Reconnect(addresses) => { + self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + ClusterConnInner::refresh_connections( + self.inner.clone(), + addresses, + RefreshConnectionType::OnlyUserConnection, + true, + ), + ))); + } + PollFlushAction::ReconnectFromInitialConnections => { + self.state = ConnectionState::Recover(RecoverFuture::Reconnect(Box::pin( + self.reconnect_to_initial_nodes(), + ))); + } + } + } + } + + fn poll_close( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + ) -> Poll> { + // Try to drive any in flight requests to completion + match self.poll_complete(cx) { + Poll::Ready(PollFlushAction::None) => (), + Poll::Ready(_) => Err(())?, + Poll::Pending => (), + }; + // If we no longer have any requests in flight we are done (skips any reconnection + // attempts) + if self.in_flight_requests.is_empty() { + return Poll::Ready(Ok(())); + } + + self.poll_flush(cx) + } +} + +async fn calculate_topology_from_random_nodes<'a, C>( + inner: &Core, + num_of_nodes_to_query: usize, + read_guard: &tokio::sync::RwLockReadGuard<'a, ConnectionsContainer>, + curr_retry: usize, +) -> ( + RedisResult<( + crate::cluster_slotmap::SlotMap, + crate::cluster_topology::TopologyHash, + )>, + Vec, +) +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + let requested_nodes = + read_guard.random_connections(num_of_nodes_to_query, ConnectionType::PreferManagement); + let topology_join_results = + futures::future::join_all(requested_nodes.map(|(addr, conn)| async move { + let mut conn: C = conn.await; + let res = conn.req_packed_command(&slot_cmd()).await; + (addr, res) + })) + .await; + let failed_addresses = topology_join_results + .iter() + .filter_map(|(address, res)| match res { + Err(err) if err.is_unrecoverable_error() => Some(address.clone()), + _ => None, + }) + .collect(); + let topology_values = topology_join_results.iter().filter_map(|(addr, res)| { + res.as_ref() + .ok() + .and_then(|value| get_host_and_port_from_addr(addr).map(|(host, _)| (host, value))) + }); + ( + calculate_topology( + topology_values, + curr_retry, + inner.cluster_params.tls, + num_of_nodes_to_query, + inner.cluster_params.read_from_replicas, + ), + failed_addresses, + ) +} + +impl ConnectionLike for ClusterConnection +where + C: ConnectionLike + Send + Clone + Unpin + Sync + Connect + 'static, +{ + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + let routing = cluster_routing::RoutingInfo::for_routable(cmd).unwrap_or( + cluster_routing::RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random), + ); + self.route_command(cmd, routing).boxed() + } + + fn req_packed_commands<'a>( + &'a mut self, + pipeline: &'a crate::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + async move { + let route = route_for_pipeline(pipeline)?; + self.route_pipeline(pipeline, offset, count, route.into()) + .await + } + .boxed() + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_closed(&self) -> bool { + false + } +} + +/// Implements the process of connecting to a Redis server +/// and obtaining a connection handle. +pub trait Connect: Sized { + /// Connect to a node. + /// For TCP connections, returning a tuple of handle for command execution and the node's IP address. + /// For UNIX connections, returning a tuple of handle for command execution and None. + fn connect<'a, T>( + info: T, + response_timeout: Duration, + connection_timeout: Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (Self, Option)> + where + T: IntoConnectionInfo + Send + 'a; +} + +impl Connect for MultiplexedConnection { + fn connect<'a, T>( + info: T, + response_timeout: Duration, + connection_timeout: Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (MultiplexedConnection, Option)> + where + T: IntoConnectionInfo + Send + 'a, + { + async move { + let connection_info = info.into_connection_info()?; + let client = crate::Client::open(connection_info)?; + + match Runtime::locate() { + #[cfg(feature = "tokio-comp")] + rt @ Runtime::Tokio => { + rt.timeout( + connection_timeout, + client.get_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + ), + ) + .await? + } + #[cfg(feature = "async-std-comp")] + rt @ Runtime::AsyncStd => { + rt.timeout(connection_timeout,client + .get_multiplexed_async_connection_inner::( + response_timeout, + socket_addr, + glide_connection_options, + )) + .await? + } + } + } + .boxed() + } +} + +#[cfg(test)] +mod pipeline_routing_tests { + use super::route_for_pipeline; + use crate::{ + cluster_routing::{Route, SlotAddr}, + cmd, + }; + + #[test] + fn test_first_route_is_found() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .get("foo") // route to slot 12182 + .add_command(cmd("EVAL")); // route randomly + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::ReplicaOptional))) + ); + } + + #[test] + fn test_return_none_if_no_route_is_found() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .add_command(cmd("EVAL")); // route randomly + + assert_eq!(route_for_pipeline(&pipeline), Ok(None)); + } + + #[test] + fn test_prefer_primary_route_over_replica() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .get("foo") // route to replica of slot 12182 + .add_command(cmd("FLUSHALL")) // route to all masters + .add_command(cmd("EVAL"))// route randomly + .cmd("CONFIG").arg("GET").arg("timeout") // unkeyed command + .set("foo", "bar"); // route to primary of slot 12182 + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::Master))) + ); + } + + #[test] + fn test_raise_cross_slot_error_on_conflicting_slots() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .add_command(cmd("FLUSHALL")) // route to all masters + .set("baz", "bar") // route to slot 4813 + .get("foo"); // route to slot 12182 + + assert_eq!( + route_for_pipeline(&pipeline).unwrap_err().kind(), + crate::ErrorKind::CrossSlot + ); + } + + #[test] + fn unkeyed_commands_dont_affect_route() { + let mut pipeline = crate::Pipeline::new(); + + pipeline + .set("{foo}bar", "baz") // route to primary of slot 12182 + .cmd("CONFIG").arg("GET").arg("timeout") // unkeyed command + .set("foo", "bar") // route to primary of slot 12182 + .cmd("DEBUG").arg("PAUSE").arg("100") // unkeyed command + .cmd("ECHO").arg("hello world"); // unkeyed command + + assert_eq!( + route_for_pipeline(&pipeline), + Ok(Some(Route::new(12182, SlotAddr::Master))) + ); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_client.rs b/glide-core/redis-rs/redis/src/cluster_client.rs new file mode 100644 index 0000000000..5815bede1e --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_client.rs @@ -0,0 +1,752 @@ +use crate::cluster_slotmap::ReadFromReplicaStrategy; +#[cfg(feature = "cluster-async")] +use crate::cluster_topology::{ + DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION, +}; +use crate::connection::{ConnectionAddr, ConnectionInfo, IntoConnectionInfo}; +use crate::types::{ErrorKind, ProtocolVersion, RedisError, RedisResult}; +use crate::{cluster, cluster::TlsMode}; +use crate::{PubSubSubscriptionInfo, PushInfo}; +use rand::Rng; +#[cfg(feature = "cluster-async")] +use std::ops::Add; +use std::time::Duration; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +#[cfg(not(feature = "tls-rustls"))] +use crate::connection::TlsConnParams; + +#[cfg(feature = "cluster-async")] +use crate::cluster_async; + +#[cfg(feature = "tls-rustls")] +use crate::tls::{retrieve_tls_certificates, TlsCertificates}; + +use tokio::sync::mpsc; + +/// Parameters specific to builder, so that +/// builder parameters may have different types +/// than final ClusterParams +#[derive(Default)] +struct BuilderParams { + password: Option, + username: Option, + read_from_replicas: ReadFromReplicaStrategy, + tls: Option, + #[cfg(feature = "tls-rustls")] + certs: Option, + retries_configuration: RetryParams, + connection_timeout: Option, + #[cfg(feature = "cluster-async")] + topology_checks_interval: Option, + #[cfg(feature = "cluster-async")] + connections_validation_interval: Option, + #[cfg(feature = "cluster-async")] + slots_refresh_rate_limit: SlotsRefreshRateLimit, + client_name: Option, + response_timeout: Option, + protocol: ProtocolVersion, + pubsub_subscriptions: Option, +} + +#[derive(Clone)] +pub(crate) struct RetryParams { + pub(crate) number_of_retries: u32, + max_wait_time: u64, + min_wait_time: u64, + exponent_base: u64, + factor: u64, +} + +impl Default for RetryParams { + fn default() -> Self { + const DEFAULT_RETRIES: u32 = 16; + const DEFAULT_MAX_RETRY_WAIT_TIME: u64 = 655360; + const DEFAULT_MIN_RETRY_WAIT_TIME: u64 = 1280; + const DEFAULT_EXPONENT_BASE: u64 = 2; + const DEFAULT_FACTOR: u64 = 10; + Self { + number_of_retries: DEFAULT_RETRIES, + max_wait_time: DEFAULT_MAX_RETRY_WAIT_TIME, + min_wait_time: DEFAULT_MIN_RETRY_WAIT_TIME, + exponent_base: DEFAULT_EXPONENT_BASE, + factor: DEFAULT_FACTOR, + } + } +} + +impl RetryParams { + pub(crate) fn wait_time_for_retry(&self, retry: u32) -> Duration { + let base_wait = self.exponent_base.pow(retry) * self.factor; + let clamped_wait = base_wait + .min(self.max_wait_time) + .max(self.min_wait_time + 1); + let jittered_wait = rand::thread_rng().gen_range(self.min_wait_time..clamped_wait); + Duration::from_millis(jittered_wait) + } +} + +/// Configuration for rate limiting slot refresh operations in a Redis cluster. +/// +/// This struct defines the interval duration between consecutive slot refresh +/// operations and an additional jitter to introduce randomness in the refresh intervals. +/// +/// # Fields +/// +/// * `interval_duration`: The minimum duration to wait between consecutive slot refresh operations. +/// * `max_jitter_milli`: The maximum jitter in milliseconds to add to the interval duration. +#[cfg(feature = "cluster-async")] +#[derive(Clone, Copy)] +pub(crate) struct SlotsRefreshRateLimit { + pub(crate) interval_duration: Duration, + pub(crate) max_jitter_milli: u64, +} + +#[cfg(feature = "cluster-async")] +impl Default for SlotsRefreshRateLimit { + fn default() -> Self { + Self { + interval_duration: DEFAULT_SLOTS_REFRESH_WAIT_DURATION, + max_jitter_milli: DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, + } + } +} + +#[cfg(feature = "cluster-async")] +impl SlotsRefreshRateLimit { + pub(crate) fn wait_duration(&self) -> Duration { + let duration_jitter = match self.max_jitter_milli { + 0 => Duration::from_millis(0), + _ => Duration::from_millis(rand::thread_rng().gen_range(0..self.max_jitter_milli)), + }; + self.interval_duration.add(duration_jitter) + } +} +/// Redis cluster specific parameters. +#[derive(Default, Clone)] +#[doc(hidden)] +pub struct ClusterParams { + pub(crate) password: Option, + pub(crate) username: Option, + pub(crate) read_from_replicas: ReadFromReplicaStrategy, + /// tls indicates tls behavior of connections. + /// When Some(TlsMode), connections use tls and verify certification depends on TlsMode. + /// When None, connections do not use tls. + pub(crate) tls: Option, + pub(crate) retry_params: RetryParams, + #[cfg(feature = "cluster-async")] + pub(crate) topology_checks_interval: Option, + #[cfg(feature = "cluster-async")] + pub(crate) slots_refresh_rate_limit: SlotsRefreshRateLimit, + #[cfg(feature = "cluster-async")] + pub(crate) connections_validation_interval: Option, + pub(crate) tls_params: Option, + pub(crate) client_name: Option, + pub(crate) connection_timeout: Duration, + pub(crate) response_timeout: Duration, + pub(crate) protocol: ProtocolVersion, + pub(crate) pubsub_subscriptions: Option, +} + +impl ClusterParams { + fn from(value: BuilderParams) -> RedisResult { + #[cfg(not(feature = "tls-rustls"))] + let tls_params = None; + + #[cfg(feature = "tls-rustls")] + let tls_params = { + let retrieved_tls_params = value.certs.clone().map(retrieve_tls_certificates); + + retrieved_tls_params.transpose()? + }; + + Ok(Self { + password: value.password, + username: value.username, + read_from_replicas: value.read_from_replicas, + tls: value.tls, + retry_params: value.retries_configuration, + connection_timeout: value.connection_timeout.unwrap_or(Duration::MAX), + #[cfg(feature = "cluster-async")] + topology_checks_interval: value.topology_checks_interval, + #[cfg(feature = "cluster-async")] + slots_refresh_rate_limit: value.slots_refresh_rate_limit, + #[cfg(feature = "cluster-async")] + connections_validation_interval: value.connections_validation_interval, + tls_params, + client_name: value.client_name, + response_timeout: value.response_timeout.unwrap_or(Duration::MAX), + protocol: value.protocol, + pubsub_subscriptions: value.pubsub_subscriptions, + }) + } +} + +/// Used to configure and build a [`ClusterClient`]. +pub struct ClusterClientBuilder { + initial_nodes: RedisResult>, + builder_params: BuilderParams, +} + +impl ClusterClientBuilder { + /// Creates a new `ClusterClientBuilder` with the provided initial_nodes. + /// + /// This is the same as `ClusterClient::builder(initial_nodes)`. + pub fn new( + initial_nodes: impl IntoIterator, + ) -> ClusterClientBuilder { + ClusterClientBuilder { + initial_nodes: initial_nodes + .into_iter() + .map(|x| x.into_connection_info()) + .collect(), + builder_params: Default::default(), + } + } + + /// Creates a new [`ClusterClient`] from the parameters. + /// + /// This does not create connections to the Redis Cluster, but only performs some basic checks + /// on the initial nodes' URLs and passwords/usernames. + /// + /// When the `tls-rustls` feature is enabled and TLS credentials are provided, they are set for + /// each cluster connection. + /// + /// # Errors + /// + /// Upon failure to parse initial nodes or if the initial nodes have different passwords or + /// usernames, an error is returned. + pub fn build(self) -> RedisResult { + let initial_nodes = self.initial_nodes?; + + let first_node = match initial_nodes.first() { + Some(node) => node, + None => { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Initial nodes can't be empty.", + ))) + } + }; + + let mut cluster_params = ClusterParams::from(self.builder_params)?; + let password = if cluster_params.password.is_none() { + cluster_params + .password + .clone_from(&first_node.redis.password); + &cluster_params.password + } else { + &None + }; + let username = if cluster_params.username.is_none() { + cluster_params + .username + .clone_from(&first_node.redis.username); + &cluster_params.username + } else { + &None + }; + if cluster_params.tls.is_none() { + cluster_params.tls = match first_node.addr { + ConnectionAddr::TcpTls { + host: _, + port: _, + insecure, + tls_params: _, + } => Some(match insecure { + false => TlsMode::Secure, + true => TlsMode::Insecure, + }), + _ => None, + }; + } + + let mut nodes = Vec::with_capacity(initial_nodes.len()); + for mut node in initial_nodes { + if let ConnectionAddr::Unix(_) = node.addr { + return Err(RedisError::from((ErrorKind::InvalidClientConfig, + "This library cannot use unix socket because Redis's cluster command returns only cluster's IP and port."))); + } + + if password.is_some() && node.redis.password != *password { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot use different password among initial nodes.", + ))); + } + + if username.is_some() && node.redis.username != *username { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot use different username among initial nodes.", + ))); + } + + if node.redis.client_name.is_some() + && node.redis.client_name != cluster_params.client_name + { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Cannot use different client_name among initial nodes.", + ))); + } + + node.redis.protocol = cluster_params.protocol; + nodes.push(node); + } + + Ok(ClusterClient { + initial_nodes: nodes, + cluster_params, + }) + } + + /// Sets client name for the new ClusterClient. + pub fn client_name(mut self, client_name: String) -> ClusterClientBuilder { + self.builder_params.client_name = Some(client_name); + self + } + + /// Sets password for the new ClusterClient. + pub fn password(mut self, password: String) -> ClusterClientBuilder { + self.builder_params.password = Some(password); + self + } + + /// Sets username for the new ClusterClient. + pub fn username(mut self, username: String) -> ClusterClientBuilder { + self.builder_params.username = Some(username); + self + } + + /// Sets number of retries for the new ClusterClient. + pub fn retries(mut self, retries: u32) -> ClusterClientBuilder { + self.builder_params.retries_configuration.number_of_retries = retries; + self + } + + /// Sets maximal wait time in millisceonds between retries for the new ClusterClient. + pub fn max_retry_wait(mut self, max_wait: u64) -> ClusterClientBuilder { + self.builder_params.retries_configuration.max_wait_time = max_wait; + self + } + + /// Sets minimal wait time in millisceonds between retries for the new ClusterClient. + pub fn min_retry_wait(mut self, min_wait: u64) -> ClusterClientBuilder { + self.builder_params.retries_configuration.min_wait_time = min_wait; + self + } + + /// Sets the factor and exponent base for the retry wait time. + /// The formula for the wait is rand(min_wait_retry .. min(max_retry_wait , factor * exponent_base ^ retry))ms. + pub fn retry_wait_formula(mut self, factor: u64, exponent_base: u64) -> ClusterClientBuilder { + self.builder_params.retries_configuration.factor = factor; + self.builder_params.retries_configuration.exponent_base = exponent_base; + self + } + + /// Sets TLS mode for the new ClusterClient. + /// + /// It is extracted from the first node of initial_nodes if not set. + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + pub fn tls(mut self, tls: TlsMode) -> ClusterClientBuilder { + self.builder_params.tls = Some(tls); + self + } + + /// Sets raw TLS certificates for the new ClusterClient. + /// + /// When set, enforces the connection must be TLS secured. + /// + /// All certificates must be provided as byte streams loaded from PEM files their consistency is + /// checked during `build()` call. + /// + /// - `certificates` - `TlsCertificates` structure containing: + /// -- `client_tls` - Optional `ClientTlsConfig` containing byte streams for + /// --- `client_cert` - client's byte stream containing client certificate in PEM format + /// --- `client_key` - client's byte stream containing private key in PEM format + /// -- `root_cert` - Optional byte stream yielding PEM formatted file for root certificates. + /// + /// If `ClientTlsConfig` ( cert+key pair ) is not provided, then client-side authentication is not enabled. + /// If `root_cert` is not provided, then system root certificates are used instead. + #[cfg(feature = "tls-rustls")] + pub fn certs(mut self, certificates: TlsCertificates) -> ClusterClientBuilder { + self.builder_params.tls = Some(TlsMode::Secure); + self.builder_params.certs = Some(certificates); + self + } + + /// Enables reading from replicas for all new connections (default is disabled). + /// + /// If enabled, then read queries will go to the replica nodes & write queries will go to the + /// primary nodes. If there are no replica nodes, then all queries will go to the primary nodes. + pub fn read_from_replicas(mut self) -> ClusterClientBuilder { + self.builder_params.read_from_replicas = ReadFromReplicaStrategy::RoundRobin; + self + } + + /// Enables periodic topology checks for this client. + /// + /// If enabled, periodic topology checks will be executed at the configured intervals to examine whether there + /// have been any changes in the cluster's topology. If a change is detected, it will trigger a slot refresh. + /// Unlike slot refreshments, the periodic topology checks only examine a limited number of nodes to query their + /// topology, ensuring that the check remains quick and efficient. + #[cfg(feature = "cluster-async")] + pub fn periodic_topology_checks(mut self, interval: Duration) -> ClusterClientBuilder { + self.builder_params.topology_checks_interval = Some(interval); + self + } + + /// Enables periodic connections checks for this client. + /// If enabled, the conenctions to the cluster nodes will be validated periodicatly, per configured interval. + /// In addition, for tokio runtime, passive disconnections could be detected instantly, + /// triggering reestablishemnt, w/o waiting for the next periodic check. + #[cfg(feature = "cluster-async")] + pub fn periodic_connections_checks(mut self, interval: Duration) -> ClusterClientBuilder { + self.builder_params.connections_validation_interval = Some(interval); + self + } + + /// Sets the rate limit for slot refresh operations in the cluster. + /// + /// This method configures the interval duration between consecutive slot + /// refresh operations and an additional jitter to introduce randomness + /// in the refresh intervals. + /// + /// # Parameters + /// + /// * `interval_duration`: The minimum duration to wait between consecutive slot refresh operations. + /// * `max_jitter_milli`: The maximum jitter in milliseconds to add to the interval duration. + /// + /// # Defaults + /// + /// If not set, the slots refresh rate limit configurations will be set with the default values: + /// ``` + /// #[cfg(feature = "cluster-async")] + /// use redis::cluster_topology::{DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION}; + /// ``` + /// + /// - `interval_duration`: `DEFAULT_SLOTS_REFRESH_WAIT_DURATION` + /// - `max_jitter_milli`: `DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI` + /// + #[cfg(feature = "cluster-async")] + pub fn slots_refresh_rate_limit( + mut self, + interval_duration: Duration, + max_jitter_milli: u64, + ) -> ClusterClientBuilder { + self.builder_params.slots_refresh_rate_limit = SlotsRefreshRateLimit { + interval_duration, + max_jitter_milli, + }; + self + } + + /// Enables timing out on slow connection time. + /// + /// If enabled, the cluster will only wait the given time on each connection attempt to each node. + pub fn connection_timeout(mut self, connection_timeout: Duration) -> ClusterClientBuilder { + self.builder_params.connection_timeout = Some(connection_timeout); + self + } + + /// Enables timing out on slow responses. + /// + /// If enabled, the cluster will only wait the given time to each response from each node. + pub fn response_timeout(mut self, response_timeout: Duration) -> ClusterClientBuilder { + self.builder_params.response_timeout = Some(response_timeout); + self + } + + /// Sets the protocol with which the client should communicate with the server. + pub fn use_protocol(mut self, protocol: ProtocolVersion) -> ClusterClientBuilder { + self.builder_params.protocol = protocol; + self + } + + /// Use `build()`. + #[deprecated(since = "0.22.0", note = "Use build()")] + pub fn open(self) -> RedisResult { + self.build() + } + + /// Use `read_from_replicas()`. + #[deprecated(since = "0.22.0", note = "Use read_from_replicas()")] + pub fn readonly(mut self, read_from_replicas: bool) -> ClusterClientBuilder { + self.builder_params.read_from_replicas = if read_from_replicas { + ReadFromReplicaStrategy::RoundRobin + } else { + ReadFromReplicaStrategy::AlwaysFromPrimary + }; + self + } + + /// Sets the pubsub configuration for the new ClusterClient. + pub fn pubsub_subscriptions( + mut self, + pubsub_subscriptions: PubSubSubscriptionInfo, + ) -> ClusterClientBuilder { + self.builder_params.pubsub_subscriptions = Some(pubsub_subscriptions); + self + } +} + +/// This is a Redis Cluster client. +#[derive(Clone)] +pub struct ClusterClient { + initial_nodes: Vec, + cluster_params: ClusterParams, +} + +impl ClusterClient { + /// Creates a `ClusterClient` with the default parameters. + /// + /// This does not create connections to the Redis Cluster, but only performs some basic checks + /// on the initial nodes' URLs and passwords/usernames. + /// + /// # Errors + /// + /// Upon failure to parse initial nodes or if the initial nodes have different passwords or + /// usernames, an error is returned. + pub fn new( + initial_nodes: impl IntoIterator, + ) -> RedisResult { + Self::builder(initial_nodes).build() + } + + /// Creates a [`ClusterClientBuilder`] with the provided initial_nodes. + pub fn builder( + initial_nodes: impl IntoIterator, + ) -> ClusterClientBuilder { + ClusterClientBuilder::new(initial_nodes) + } + + /// Creates new connections to Redis Cluster nodes and returns a + /// [`cluster::ClusterConnection`]. + /// + /// # Errors + /// + /// An error is returned if there is a failure while creating connections or slots. + pub fn get_connection( + &self, + push_sender: Option>, + ) -> RedisResult { + cluster::ClusterConnection::new( + self.cluster_params.clone(), + self.initial_nodes.clone(), + push_sender, + ) + } + + /// Creates new connections to Redis Cluster nodes and returns a + /// [`cluster_async::ClusterConnection`]. + /// + /// # Errors + /// + /// An error is returned if there is a failure while creating connections or slots. + #[cfg(feature = "cluster-async")] + pub async fn get_async_connection( + &self, + push_sender: Option>, + ) -> RedisResult { + cluster_async::ClusterConnection::new( + &self.initial_nodes, + self.cluster_params.clone(), + push_sender, + ) + .await + } + + #[doc(hidden)] + pub fn get_generic_connection( + &self, + push_sender: Option>, + ) -> RedisResult> + where + C: crate::ConnectionLike + crate::cluster::Connect + Send, + { + cluster::ClusterConnection::new( + self.cluster_params.clone(), + self.initial_nodes.clone(), + push_sender, + ) + } + + #[doc(hidden)] + #[cfg(feature = "cluster-async")] + pub async fn get_async_generic_connection( + &self, + ) -> RedisResult> + where + C: crate::aio::ConnectionLike + + cluster_async::Connect + + Clone + + Send + + Sync + + Unpin + + 'static, + { + cluster_async::ClusterConnection::new( + &self.initial_nodes, + self.cluster_params.clone(), + None, + ) + .await + } + + /// Use `new()`. + #[deprecated(since = "0.22.0", note = "Use new()")] + pub fn open(initial_nodes: Vec) -> RedisResult { + Self::new(initial_nodes) + } +} + +#[cfg(test)] +mod tests { + #[cfg(feature = "cluster-async")] + use crate::cluster_topology::{ + DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI, DEFAULT_SLOTS_REFRESH_WAIT_DURATION, + }; + + use super::{ClusterClient, ClusterClientBuilder, ConnectionInfo, IntoConnectionInfo}; + + fn get_connection_data() -> Vec { + vec![ + "redis://127.0.0.1:6379".into_connection_info().unwrap(), + "redis://127.0.0.1:6378".into_connection_info().unwrap(), + "redis://127.0.0.1:6377".into_connection_info().unwrap(), + ] + } + + fn get_connection_data_with_password() -> Vec { + vec![ + "redis://:password@127.0.0.1:6379" + .into_connection_info() + .unwrap(), + "redis://:password@127.0.0.1:6378" + .into_connection_info() + .unwrap(), + "redis://:password@127.0.0.1:6377" + .into_connection_info() + .unwrap(), + ] + } + + fn get_connection_data_with_username_and_password() -> Vec { + vec![ + "redis://user1:password@127.0.0.1:6379" + .into_connection_info() + .unwrap(), + "redis://user1:password@127.0.0.1:6378" + .into_connection_info() + .unwrap(), + "redis://user1:password@127.0.0.1:6377" + .into_connection_info() + .unwrap(), + ] + } + + #[test] + fn give_no_password() { + let client = ClusterClient::new(get_connection_data()).unwrap(); + assert_eq!(client.cluster_params.password, None); + } + + #[test] + fn give_password_by_initial_nodes() { + let client = ClusterClient::new(get_connection_data_with_password()).unwrap(); + assert_eq!(client.cluster_params.password, Some("password".to_string())); + } + + #[test] + fn give_username_and_password_by_initial_nodes() { + let client = ClusterClient::new(get_connection_data_with_username_and_password()).unwrap(); + assert_eq!(client.cluster_params.password, Some("password".to_string())); + assert_eq!(client.cluster_params.username, Some("user1".to_string())); + } + + #[test] + fn give_different_password_by_initial_nodes() { + let result = ClusterClient::new(vec![ + "redis://:password1@127.0.0.1:6379", + "redis://:password2@127.0.0.1:6378", + "redis://:password3@127.0.0.1:6377", + ]); + assert!(result.is_err()); + } + + #[test] + fn give_different_username_by_initial_nodes() { + let result = ClusterClient::new(vec![ + "redis://user1:password@127.0.0.1:6379", + "redis://user2:password@127.0.0.1:6378", + "redis://user1:password@127.0.0.1:6377", + ]); + assert!(result.is_err()); + } + + #[test] + fn give_username_password_by_method() { + let client = ClusterClientBuilder::new(get_connection_data_with_password()) + .password("pass".to_string()) + .username("user1".to_string()) + .build() + .unwrap(); + assert_eq!(client.cluster_params.password, Some("pass".to_string())); + assert_eq!(client.cluster_params.username, Some("user1".to_string())); + } + + #[test] + fn give_empty_initial_nodes() { + let client = ClusterClient::new(Vec::::new()); + assert!(client.is_err()) + } + + #[cfg(feature = "cluster-async")] + #[test] + fn give_slots_refresh_rate_limit_configurations() { + let interval_dur = std::time::Duration::from_secs(20); + let client = ClusterClientBuilder::new(get_connection_data()) + .slots_refresh_rate_limit(interval_dur, 500) + .build() + .unwrap(); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .interval_duration, + interval_dur + ); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .max_jitter_milli, + 500 + ); + } + + #[cfg(feature = "cluster-async")] + #[test] + fn dont_give_slots_refresh_rate_limit_configurations_uses_defaults() { + let client = ClusterClientBuilder::new(get_connection_data()) + .build() + .unwrap(); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .interval_duration, + DEFAULT_SLOTS_REFRESH_WAIT_DURATION + ); + assert_eq!( + client + .cluster_params + .slots_refresh_rate_limit + .max_jitter_milli, + DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI + ); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_pipeline.rs b/glide-core/redis-rs/redis/src/cluster_pipeline.rs new file mode 100644 index 0000000000..9da1fee781 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_pipeline.rs @@ -0,0 +1,151 @@ +use crate::cluster::ClusterConnection; +use crate::cmd::{cmd, Cmd}; +use crate::types::{ + from_owned_redis_value, ErrorKind, FromRedisValue, HashSet, RedisResult, ToRedisArgs, Value, +}; + +pub(crate) const UNROUTABLE_ERROR: (ErrorKind, &str) = ( + ErrorKind::ClientError, + "This command cannot be safely routed in cluster mode", +); + +fn is_illegal_cmd(cmd: &str) -> bool { + matches!( + cmd, + "BGREWRITEAOF" | "BGSAVE" | "BITOP" | "BRPOPLPUSH" | + // All commands that start with "CLIENT" + "CLIENT" | "CLIENT GETNAME" | "CLIENT KILL" | "CLIENT LIST" | "CLIENT SETNAME" | + // All commands that start with "CONFIG" + "CONFIG" | "CONFIG GET" | "CONFIG RESETSTAT" | "CONFIG REWRITE" | "CONFIG SET" | + "DBSIZE" | + "ECHO" | "EVALSHA" | + "FLUSHALL" | "FLUSHDB" | + "INFO" | + "KEYS" | + "LASTSAVE" | + "MGET" | "MOVE" | "MSET" | "MSETNX" | + "PFMERGE" | "PFCOUNT" | "PING" | "PUBLISH" | + "RANDOMKEY" | "RENAME" | "RENAMENX" | "RPOPLPUSH" | + "SAVE" | "SCAN" | + // All commands that start with "SCRIPT" + "SCRIPT" | "SCRIPT EXISTS" | "SCRIPT FLUSH" | "SCRIPT KILL" | "SCRIPT LOAD" | + "SDIFF" | "SDIFFSTORE" | + // All commands that start with "SENTINEL" + "SENTINEL" | "SENTINEL GET MASTER ADDR BY NAME" | "SENTINEL MASTER" | "SENTINEL MASTERS" | + "SENTINEL MONITOR" | "SENTINEL REMOVE" | "SENTINEL SENTINELS" | "SENTINEL SET" | + "SENTINEL SLAVES" | "SHUTDOWN" | "SINTER" | "SINTERSTORE" | "SLAVEOF" | + // All commands that start with "SLOWLOG" + "SLOWLOG" | "SLOWLOG GET" | "SLOWLOG LEN" | "SLOWLOG RESET" | + "SMOVE" | "SORT" | "SUNION" | "SUNIONSTORE" | + "TIME" + ) +} + +/// Represents a Redis Cluster command pipeline. +#[derive(Clone)] +pub struct ClusterPipeline { + commands: Vec, + ignored_commands: HashSet, +} + +/// A cluster pipeline is almost identical to a normal [Pipeline](crate::pipeline::Pipeline), with two exceptions: +/// * It does not support transactions +/// * The following commands can not be used in a cluster pipeline: +/// ```text +/// BGREWRITEAOF, BGSAVE, BITOP, BRPOPLPUSH +/// CLIENT GETNAME, CLIENT KILL, CLIENT LIST, CLIENT SETNAME, CONFIG GET, +/// CONFIG RESETSTAT, CONFIG REWRITE, CONFIG SET +/// DBSIZE +/// ECHO, EVALSHA +/// FLUSHALL, FLUSHDB +/// INFO +/// KEYS +/// LASTSAVE +/// MGET, MOVE, MSET, MSETNX +/// PFMERGE, PFCOUNT, PING, PUBLISH +/// RANDOMKEY, RENAME, RENAMENX, RPOPLPUSH +/// SAVE, SCAN, SCRIPT EXISTS, SCRIPT FLUSH, SCRIPT KILL, SCRIPT LOAD, SDIFF, SDIFFSTORE, +/// SENTINEL GET MASTER ADDR BY NAME, SENTINEL MASTER, SENTINEL MASTERS, SENTINEL MONITOR, +/// SENTINEL REMOVE, SENTINEL SENTINELS, SENTINEL SET, SENTINEL SLAVES, SHUTDOWN, SINTER, +/// SINTERSTORE, SLAVEOF, SLOWLOG GET, SLOWLOG LEN, SLOWLOG RESET, SMOVE, SORT, SUNION, SUNIONSTORE +/// TIME +/// ``` +impl ClusterPipeline { + /// Create an empty pipeline. + pub fn new() -> ClusterPipeline { + Self::with_capacity(0) + } + + /// Creates an empty pipeline with pre-allocated capacity. + pub fn with_capacity(capacity: usize) -> ClusterPipeline { + ClusterPipeline { + commands: Vec::with_capacity(capacity), + ignored_commands: HashSet::new(), + } + } + + pub(crate) fn commands(&self) -> &Vec { + &self.commands + } + + /// Executes the pipeline and fetches the return values: + /// + /// ```rust,no_run + /// # let nodes = vec!["redis://127.0.0.1:6379/"]; + /// # let client = redis::cluster::ClusterClient::new(nodes).unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let mut pipe = redis::cluster::cluster_pipe(); + /// let (k1, k2) : (i32, i32) = pipe + /// .cmd("SET").arg("key_1").arg(42).ignore() + /// .cmd("SET").arg("key_2").arg(43).ignore() + /// .cmd("GET").arg("key_1") + /// .cmd("GET").arg("key_2").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn query(&self, con: &mut ClusterConnection) -> RedisResult { + for cmd in &self.commands { + let cmd_name = std::str::from_utf8(cmd.arg_idx(0).unwrap_or(b"")) + .unwrap_or("") + .trim() + .to_ascii_uppercase(); + + if is_illegal_cmd(&cmd_name) { + fail!(( + UNROUTABLE_ERROR.0, + UNROUTABLE_ERROR.1, + format!("Command '{cmd_name}' can't be executed in a cluster pipeline.") + )) + } + } + + from_owned_redis_value(if self.commands.is_empty() { + Value::Array(vec![]) + } else { + self.make_pipeline_results(con.execute_pipeline(self)?) + }) + } + + /// This is a shortcut to `query()` that does not return a value and + /// will fail the task if the query of the pipeline fails. + /// + /// This is equivalent to a call to query like this: + /// + /// ```rust,no_run + /// # let nodes = vec!["redis://127.0.0.1:6379/"]; + /// # let client = redis::cluster::ClusterClient::new(nodes).unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let mut pipe = redis::cluster::cluster_pipe(); + /// let _ : () = pipe.cmd("SET").arg("key_1").arg(42).ignore().query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn execute(&self, con: &mut ClusterConnection) { + self.query::<()>(con).unwrap(); + } +} + +/// Shortcut for creating a new cluster pipeline. +pub fn cluster_pipe() -> ClusterPipeline { + ClusterPipeline::new() +} + +implement_pipeline_commands!(ClusterPipeline); diff --git a/glide-core/redis-rs/redis/src/cluster_routing.rs b/glide-core/redis-rs/redis/src/cluster_routing.rs new file mode 100644 index 0000000000..bfe6ae2039 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_routing.rs @@ -0,0 +1,1374 @@ +use rand::Rng; +use std::cmp::min; +use std::collections::HashMap; + +use crate::cluster_topology::get_slot; +use crate::cmd::{Arg, Cmd}; +use crate::types::Value; +use crate::{ErrorKind, RedisResult}; +use std::iter::Once; + +#[derive(Clone)] +pub(crate) enum Redirect { + Moved(String), + Ask(String), +} + +/// Logical bitwise aggregating operators. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum LogicalAggregateOp { + /// Aggregate by bitwise && + And, + // Or, omitted due to dead code warnings. ATM this value isn't constructed anywhere +} + +/// Numerical aggreagting operators. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum AggregateOp { + /// Choose minimal value + Min, + /// Sum all values + Sum, + // Max, omitted due to dead code warnings. ATM this value isn't constructed anywhere +} + +/// Policy defining how to combine multiple responses into one. +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum ResponsePolicy { + /// Wait for one request to succeed and return its results. Return error if all requests fail. + OneSucceeded, + /// Returns the first succeeded non-empty result; if all results are empty, returns `Nil`; otherwise, returns the last received error. + FirstSucceededNonEmptyOrAllEmpty, + /// Waits for all requests to succeed, and the returns one of the successes. Returns the error on the first received error. + AllSucceeded, + /// Aggregate success results according to a logical bitwise operator. Return error on any failed request or on a response that doesn't conform to 0 or 1. + AggregateLogical(LogicalAggregateOp), + /// Aggregate success results according to a numeric operator. Return error on any failed request or on a response that isn't an integer. + Aggregate(AggregateOp), + /// Aggregate array responses into a single array. Return error on any failed request or on a response that isn't an array. + CombineArrays, + /// Handling is not defined by the Redis standard. Will receive a special case + Special, + /// Combines multiple map responses into a single map. + CombineMaps, +} + +/// Defines whether a request should be routed to a single node, or multiple ones. +#[derive(Debug, Clone, PartialEq)] +pub enum RoutingInfo { + /// Route to single node + SingleNode(SingleNodeRoutingInfo), + /// Route to multiple nodes + MultiNode((MultipleNodeRoutingInfo, Option)), +} + +/// Defines which single node should receive a request. +#[derive(Debug, Clone, PartialEq)] +pub enum SingleNodeRoutingInfo { + /// Route to any node at random + Random, + /// Route to any *primary* node + RandomPrimary, + /// Route to the node that matches the [Route] + SpecificNode(Route), + /// Route to the node with the given address. + ByAddress { + /// DNS hostname of the node + host: String, + /// port of the node + port: u16, + }, +} + +impl From> for SingleNodeRoutingInfo { + fn from(value: Option) -> Self { + value + .map(SingleNodeRoutingInfo::SpecificNode) + .unwrap_or(SingleNodeRoutingInfo::Random) + } +} + +/// Defines which collection of nodes should receive a request +#[derive(Debug, Clone, PartialEq)] +pub enum MultipleNodeRoutingInfo { + /// Route to all nodes in the clusters + AllNodes, + /// Route to all primaries in the cluster + AllMasters, + /// Instructions for how to split a multi-slot command (e.g. MGET, MSET) into sub-commands. Each tuple is the route for each subcommand, and the indices of the arguments from the original command that should be copied to the subcommand. + MultiSlot(Vec<(Route, Vec)>), +} + +/// Takes a routable and an iterator of indices, which is assued to be created from`MultipleNodeRoutingInfo::MultiSlot`, +/// and returns a command with the arguments matching the indices. +pub fn command_for_multi_slot_indices<'a, 'b>( + original_cmd: &'a impl Routable, + indices: impl Iterator + 'a, +) -> Cmd +where + 'b: 'a, +{ + let mut new_cmd = Cmd::new(); + let command_length = 1; // TODO - the +1 should change if we have multi-slot commands with 2 command words. + new_cmd.arg(original_cmd.arg_idx(0)); + for index in indices { + new_cmd.arg(original_cmd.arg_idx(index + command_length)); + } + new_cmd +} + +/// Aggreagte numeric responses. +pub fn aggregate(values: Vec, op: AggregateOp) -> RedisResult { + let initial_value = match op { + AggregateOp::Min => i64::MAX, + AggregateOp::Sum => 0, + }; + let result = values.into_iter().try_fold(initial_value, |acc, curr| { + let int = match curr { + Value::Int(int) => int, + _ => { + return RedisResult::Err( + ( + ErrorKind::TypeError, + "expected array of integers as response", + ) + .into(), + ); + } + }; + let acc = match op { + AggregateOp::Min => min(acc, int), + AggregateOp::Sum => acc + int, + }; + Ok(acc) + })?; + Ok(Value::Int(result)) +} + +/// Aggreagte numeric responses by a boolean operator. +pub fn logical_aggregate(values: Vec, op: LogicalAggregateOp) -> RedisResult { + let initial_value = match op { + LogicalAggregateOp::And => true, + }; + let results = values.into_iter().try_fold(Vec::new(), |acc, curr| { + let values = match curr { + Value::Array(values) => values, + _ => { + return RedisResult::Err( + ( + ErrorKind::TypeError, + "expected array of integers as response", + ) + .into(), + ); + } + }; + let mut acc = if acc.is_empty() { + vec![initial_value; values.len()] + } else { + acc + }; + for (index, value) in values.into_iter().enumerate() { + let int = match value { + Value::Int(int) => int, + _ => { + return Err(( + ErrorKind::TypeError, + "expected array of integers as response", + ) + .into()); + } + }; + acc[index] = match op { + LogicalAggregateOp::And => acc[index] && (int > 0), + }; + } + Ok(acc) + })?; + Ok(Value::Array( + results + .into_iter() + .map(|result| Value::Int(result as i64)) + .collect(), + )) +} +/// Aggregate array responses into a single map. +pub fn combine_map_results(values: Vec) -> RedisResult { + let mut map: HashMap, i64> = HashMap::new(); + + for value in values { + match value { + Value::Array(elements) => { + let mut iter = elements.into_iter(); + + while let Some(key) = iter.next() { + if let Value::BulkString(key_bytes) = key { + if let Some(Value::Int(value)) = iter.next() { + *map.entry(key_bytes).or_insert(0) += value; + } else { + return Err((ErrorKind::TypeError, "expected integer value").into()); + } + } else { + return Err((ErrorKind::TypeError, "expected string key").into()); + } + } + } + _ => { + return Err((ErrorKind::TypeError, "expected array of values as response").into()); + } + } + } + + let result_vec: Vec<(Value, Value)> = map + .into_iter() + .map(|(k, v)| (Value::BulkString(k), Value::Int(v))) + .collect(); + + Ok(Value::Map(result_vec)) +} + +/// Aggregate array responses into a single array. +pub fn combine_array_results(values: Vec) -> RedisResult { + let mut results = Vec::new(); + + for value in values { + match value { + Value::Array(values) => results.extend(values), + _ => { + return Err((ErrorKind::TypeError, "expected array of values as response").into()); + } + } + } + + Ok(Value::Array(results)) +} + +/// Combines multiple call results in the `values` field, each assume to be an array of results, +/// into a single array. `sorting_order` defines the order of the results in the returned array - +/// for each array of results, `sorting_order` should contain a matching array with the indices of +/// the results in the final array. +pub(crate) fn combine_and_sort_array_results<'a>( + values: Vec, + sorting_order: impl ExactSizeIterator>, +) -> RedisResult { + let mut results = Vec::new(); + results.resize( + values.iter().fold(0, |acc, value| match value { + Value::Array(values) => values.len() + acc, + _ => 0, + }), + Value::Nil, + ); + assert_eq!(values.len(), sorting_order.len()); + + for (key_indices, value) in sorting_order.into_iter().zip(values) { + match value { + Value::Array(values) => { + assert_eq!(values.len(), key_indices.len()); + for (index, value) in key_indices.iter().zip(values) { + results[*index] = value; + } + } + _ => { + return Err((ErrorKind::TypeError, "expected array of values as response").into()); + } + } + } + + Ok(Value::Array(results)) +} + +fn get_route(is_readonly: bool, key: &[u8]) -> Route { + let slot = get_slot(key); + if is_readonly { + Route::new(slot, SlotAddr::ReplicaOptional) + } else { + Route::new(slot, SlotAddr::Master) + } +} + +/// Takes the given `routable` and creates a multi-slot routing info. +/// This is used for commands like MSET & MGET, where if the command's keys +/// are hashed to multiple slots, the command should be split into sub-commands, +/// each targetting a single slot. The results of these sub-commands are then +/// usually reassembled using `combine_and_sort_array_results`. In order to do this, +/// `MultipleNodeRoutingInfo::MultiSlot` contains the routes for each sub-command, and +/// the indices in the final combined result for each result from the sub-command. +/// +/// If all keys are routed to the same slot, there's no need to split the command, +/// so a single node routing info will be returned. +fn multi_shard( + routable: &R, + cmd: &[u8], + first_key_index: usize, + has_values: bool, +) -> Option +where + R: Routable + ?Sized, +{ + let is_readonly = is_readonly_cmd(cmd); + let mut routes = HashMap::new(); + let mut key_index = 0; + while let Some(key) = routable.arg_idx(first_key_index + key_index) { + let route = get_route(is_readonly, key); + let entry = routes.entry(route); + let keys = entry.or_insert(Vec::new()); + keys.push(key_index); + + if has_values { + key_index += 1; + routable.arg_idx(first_key_index + key_index)?; // check that there's a value for the key + keys.push(key_index); + } + key_index += 1; + } + + let mut routes: Vec<(Route, Vec)> = routes.into_iter().collect(); + Some(if routes.len() == 1 { + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(routes.pop().unwrap().0)) + } else { + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::MultiSlot(routes), + ResponsePolicy::for_command(cmd), + )) + }) +} + +impl ResponsePolicy { + /// Parse the command for the matching response policy. + pub fn for_command(cmd: &[u8]) -> Option { + match cmd { + b"SCRIPT EXISTS" => Some(ResponsePolicy::AggregateLogical(LogicalAggregateOp::And)), + + b"DBSIZE" | b"DEL" | b"EXISTS" | b"SLOWLOG LEN" | b"TOUCH" | b"UNLINK" + | b"LATENCY RESET" | b"PUBSUB NUMPAT" => { + Some(ResponsePolicy::Aggregate(AggregateOp::Sum)) + } + + b"WAIT" => Some(ResponsePolicy::Aggregate(AggregateOp::Min)), + + b"ACL SETUSER" | b"ACL DELUSER" | b"ACL SAVE" | b"CLIENT SETNAME" + | b"CLIENT SETINFO" | b"CONFIG SET" | b"CONFIG RESETSTAT" | b"CONFIG REWRITE" + | b"FLUSHALL" | b"FLUSHDB" | b"FUNCTION DELETE" | b"FUNCTION FLUSH" + | b"FUNCTION LOAD" | b"FUNCTION RESTORE" | b"MEMORY PURGE" | b"MSET" | b"PING" + | b"SCRIPT FLUSH" | b"SCRIPT LOAD" | b"SLOWLOG RESET" | b"UNWATCH" | b"WATCH" => { + Some(ResponsePolicy::AllSucceeded) + } + + b"KEYS" | b"MGET" | b"SLOWLOG GET" | b"PUBSUB CHANNELS" | b"PUBSUB SHARDCHANNELS" => { + Some(ResponsePolicy::CombineArrays) + } + b"PUBSUB NUMSUB" | b"PUBSUB SHARDNUMSUB" => Some(ResponsePolicy::CombineMaps), + + b"FUNCTION KILL" | b"SCRIPT KILL" => Some(ResponsePolicy::OneSucceeded), + + // This isn't based on response_tips, but on the discussion here - https://github.com/redis/redis/issues/12410 + b"RANDOMKEY" => Some(ResponsePolicy::FirstSucceededNonEmptyOrAllEmpty), + + b"LATENCY GRAPH" | b"LATENCY HISTOGRAM" | b"LATENCY HISTORY" | b"LATENCY DOCTOR" + | b"LATENCY LATEST" => Some(ResponsePolicy::Special), + + b"FUNCTION STATS" => Some(ResponsePolicy::Special), + + b"MEMORY MALLOC-STATS" | b"MEMORY DOCTOR" | b"MEMORY STATS" => { + Some(ResponsePolicy::Special) + } + + b"INFO" => Some(ResponsePolicy::Special), + + _ => None, + } + } +} + +enum RouteBy { + AllNodes, + AllPrimaries, + FirstKey, + MultiShardNoValues, + MultiShardWithValues, + Random, + SecondArg, + SecondArgAfterKeyCount, + SecondArgSlot, + StreamsIndex, + ThirdArgAfterKeyCount, + Undefined, +} + +fn base_routing(cmd: &[u8]) -> RouteBy { + match cmd { + b"ACL SETUSER" + | b"ACL DELUSER" + | b"ACL SAVE" + | b"CLIENT SETNAME" + | b"CLIENT SETINFO" + | b"SLOWLOG GET" + | b"SLOWLOG LEN" + | b"SLOWLOG RESET" + | b"CONFIG SET" + | b"CONFIG RESETSTAT" + | b"CONFIG REWRITE" + | b"SCRIPT FLUSH" + | b"SCRIPT LOAD" + | b"LATENCY RESET" + | b"LATENCY GRAPH" + | b"LATENCY HISTOGRAM" + | b"LATENCY HISTORY" + | b"LATENCY DOCTOR" + | b"LATENCY LATEST" + | b"PUBSUB NUMPAT" + | b"PUBSUB CHANNELS" + | b"PUBSUB NUMSUB" + | b"PUBSUB SHARDCHANNELS" + | b"PUBSUB SHARDNUMSUB" + | b"SCRIPT KILL" + | b"FUNCTION KILL" + | b"FUNCTION STATS" => RouteBy::AllNodes, + + b"DBSIZE" + | b"FLUSHALL" + | b"FLUSHDB" + | b"FUNCTION DELETE" + | b"FUNCTION FLUSH" + | b"FUNCTION LOAD" + | b"FUNCTION RESTORE" + | b"INFO" + | b"KEYS" + | b"MEMORY DOCTOR" + | b"MEMORY MALLOC-STATS" + | b"MEMORY PURGE" + | b"MEMORY STATS" + | b"PING" + | b"SCRIPT EXISTS" + | b"UNWATCH" + | b"WAIT" + | b"RANDOMKEY" + | b"WAITAOF" => RouteBy::AllPrimaries, + + b"MGET" | b"DEL" | b"EXISTS" | b"UNLINK" | b"TOUCH" | b"WATCH" => { + RouteBy::MultiShardNoValues + } + b"MSET" => RouteBy::MultiShardWithValues, + + // TODO - special handling - b"SCAN" + b"SCAN" | b"SHUTDOWN" | b"SLAVEOF" | b"REPLICAOF" => RouteBy::Undefined, + + b"BLMPOP" | b"BZMPOP" | b"EVAL" | b"EVALSHA" | b"EVALSHA_RO" | b"EVAL_RO" | b"FCALL" + | b"FCALL_RO" => RouteBy::ThirdArgAfterKeyCount, + + b"BITOP" + | b"MEMORY USAGE" + | b"PFDEBUG" + | b"XGROUP CREATE" + | b"XGROUP CREATECONSUMER" + | b"XGROUP DELCONSUMER" + | b"XGROUP DESTROY" + | b"XGROUP SETID" + | b"XINFO CONSUMERS" + | b"XINFO GROUPS" + | b"XINFO STREAM" + | b"OBJECT ENCODING" + | b"OBJECT FREQ" + | b"OBJECT IDLETIME" + | b"OBJECT REFCOUNT" => RouteBy::SecondArg, + + b"LMPOP" | b"SINTERCARD" | b"ZDIFF" | b"ZINTER" | b"ZINTERCARD" | b"ZMPOP" | b"ZUNION" => { + RouteBy::SecondArgAfterKeyCount + } + + b"XREAD" | b"XREADGROUP" => RouteBy::StreamsIndex, + + // keyless commands with more arguments, whose arguments might be wrongly taken to be keys. + // TODO - double check these, in order to find better ways to route some of them. + b"ACL DRYRUN" + | b"ACL GENPASS" + | b"ACL GETUSER" + | b"ACL HELP" + | b"ACL LIST" + | b"ACL LOG" + | b"ACL USERS" + | b"ACL WHOAMI" + | b"AUTH" + | b"BGSAVE" + | b"CLIENT GETNAME" + | b"CLIENT GETREDIR" + | b"CLIENT ID" + | b"CLIENT INFO" + | b"CLIENT KILL" + | b"CLIENT PAUSE" + | b"CLIENT REPLY" + | b"CLIENT TRACKINGINFO" + | b"CLIENT UNBLOCK" + | b"CLIENT UNPAUSE" + | b"CLUSTER COUNT-FAILURE-REPORTS" + | b"CLUSTER INFO" + | b"CLUSTER KEYSLOT" + | b"CLUSTER MEET" + | b"CLUSTER MYSHARDID" + | b"CLUSTER NODES" + | b"CLUSTER REPLICAS" + | b"CLUSTER RESET" + | b"CLUSTER SET-CONFIG-EPOCH" + | b"CLUSTER SHARDS" + | b"CLUSTER SLOTS" + | b"COMMAND COUNT" + | b"COMMAND GETKEYS" + | b"COMMAND LIST" + | b"COMMAND" + | b"CONFIG GET" + | b"DEBUG" + | b"ECHO" + | b"FUNCTION LIST" + | b"LASTSAVE" + | b"LOLWUT" + | b"MODULE LIST" + | b"MODULE LOAD" + | b"MODULE LOADEX" + | b"MODULE UNLOAD" + | b"READONLY" + | b"READWRITE" + | b"SAVE" + | b"SCRIPT SHOW" + | b"TFCALL" + | b"TFCALLASYNC" + | b"TFUNCTION DELETE" + | b"TFUNCTION LIST" + | b"TFUNCTION LOAD" + | b"TIME" => RouteBy::Random, + + b"CLUSTER ADDSLOTS" + | b"CLUSTER COUNTKEYSINSLOT" + | b"CLUSTER DELSLOTS" + | b"CLUSTER DELSLOTSRANGE" + | b"CLUSTER GETKEYSINSLOT" + | b"CLUSTER SETSLOT" => RouteBy::SecondArgSlot, + + _ => RouteBy::FirstKey, + } +} + +impl RoutingInfo { + /// Returns true if the `cmd` should be routed to all nodes. + pub fn is_all_nodes(cmd: &[u8]) -> bool { + matches!(base_routing(cmd), RouteBy::AllNodes) + } + + /// Returns true if the `cmd` is a key-based command that triggers MOVED errors. + /// A key-based command is one that will be accepted only by the slot owner, + /// while other nodes will respond with a MOVED error redirecting to the relevant primary owner. + pub fn is_key_routing_command(cmd: &[u8]) -> bool { + match base_routing(cmd) { + RouteBy::FirstKey + | RouteBy::SecondArg + | RouteBy::SecondArgAfterKeyCount + | RouteBy::ThirdArgAfterKeyCount + | RouteBy::SecondArgSlot + | RouteBy::StreamsIndex + | RouteBy::MultiShardNoValues + | RouteBy::MultiShardWithValues => { + if matches!(cmd, b"SPUBLISH") { + // SPUBLISH does not return MOVED errors within the slot's shard. This means that even if READONLY wasn't sent to a replica, + // executing SPUBLISH FOO BAR on that replica will succeed. This behavior differs from true key-based commands, + // such as SET FOO BAR, where a non-readonly replica would return a MOVED error if READONLY is off. + // Consequently, SPUBLISH does not meet the requirement of being a command that triggers MOVED errors. + // TODO: remove this when PRIMARY_PREFERRED route for SPUBLISH is added + false + } else { + true + } + } + RouteBy::AllNodes | RouteBy::AllPrimaries | RouteBy::Random | RouteBy::Undefined => { + false + } + } + } + + /// Returns the routing info for `r`. + pub fn for_routable(r: &R) -> Option + where + R: Routable + ?Sized, + { + let cmd = &r.command()?[..]; + match base_routing(cmd) { + RouteBy::AllNodes => Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllNodes, + ResponsePolicy::for_command(cmd), + ))), + + RouteBy::AllPrimaries => Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + ResponsePolicy::for_command(cmd), + ))), + + RouteBy::MultiShardWithValues => multi_shard(r, cmd, 1, true), + + RouteBy::MultiShardNoValues => multi_shard(r, cmd, 1, false), + + RouteBy::Random => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)), + + RouteBy::ThirdArgAfterKeyCount => { + let key_count = r + .arg_idx(2) + .and_then(|x| std::str::from_utf8(x).ok()) + .and_then(|x| x.parse::().ok())?; + if key_count == 0 { + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + } else { + r.arg_idx(3).map(|key| RoutingInfo::for_key(cmd, key)) + } + } + + RouteBy::SecondArg => r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)), + + RouteBy::SecondArgAfterKeyCount => { + let key_count = r + .arg_idx(1) + .and_then(|x| std::str::from_utf8(x).ok()) + .and_then(|x| x.parse::().ok())?; + if key_count == 0 { + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + } else { + r.arg_idx(2).map(|key| RoutingInfo::for_key(cmd, key)) + } + } + + RouteBy::StreamsIndex => { + let streams_position = r.position(b"STREAMS")?; + r.arg_idx(streams_position + 1) + .map(|key| RoutingInfo::for_key(cmd, key)) + } + + RouteBy::SecondArgSlot => r + .arg_idx(2) + .and_then(|arg| std::str::from_utf8(arg).ok()) + .and_then(|slot| slot.parse::().ok()) + .map(|slot| { + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + slot, + SlotAddr::Master, + ))) + }), + + RouteBy::FirstKey => match r.arg_idx(1) { + Some(key) => Some(RoutingInfo::for_key(cmd, key)), + None => Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)), + }, + + RouteBy::Undefined => None, + } + } + + fn for_key(cmd: &[u8], key: &[u8]) -> RoutingInfo { + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(get_route( + is_readonly_cmd(cmd), + key, + ))) + } +} + +/// Returns true if the given `routable` represents a readonly command. +pub fn is_readonly(routable: &impl Routable) -> bool { + match routable.command() { + Some(cmd) => is_readonly_cmd(cmd.as_slice()), + None => false, + } +} + +/// Returns `true` if the given `cmd` is a readonly command. +pub fn is_readonly_cmd(cmd: &[u8]) -> bool { + matches!( + cmd, + b"BITCOUNT" + | b"BITFIELD_RO" + | b"BITPOS" + | b"DBSIZE" + | b"DUMP" + | b"EVAL_RO" + | b"EVALSHA_RO" + | b"EXISTS" + | b"EXPIRETIME" + | b"FCALL_RO" + | b"FUNCTION DUMP" + | b"FUNCTION KILL" + | b"FUNCTION LIST" + | b"FUNCTION STATS" + | b"GEODIST" + | b"GEOHASH" + | b"GEOPOS" + | b"GEORADIUSBYMEMBER_RO" + | b"GEORADIUS_RO" + | b"GEOSEARCH" + | b"GET" + | b"GETBIT" + | b"GETRANGE" + | b"HEXISTS" + | b"HGET" + | b"HGETALL" + | b"HKEYS" + | b"HLEN" + | b"HMGET" + | b"HRANDFIELD" + | b"HSCAN" + | b"HSTRLEN" + | b"HVALS" + | b"KEYS" + | b"LCS" + | b"LINDEX" + | b"LLEN" + | b"LOLWUT" + | b"LPOS" + | b"LRANGE" + | b"MEMORY USAGE" + | b"MGET" + | b"OBJECT ENCODING" + | b"OBJECT FREQ" + | b"OBJECT IDLETIME" + | b"OBJECT REFCOUNT" + | b"PEXPIRETIME" + | b"PFCOUNT" + | b"PTTL" + | b"RANDOMKEY" + | b"SCAN" + | b"SCARD" + | b"SCRIPT DEBUG" + | b"SCRIPT EXISTS" + | b"SCRIPT FLUSH" + | b"SCRIPT KILL" + | b"SCRIPT LOAD" + | b"SCRIPT SHOW" + | b"SDIFF" + | b"SINTER" + | b"SINTERCARD" + | b"SISMEMBER" + | b"SMEMBERS" + | b"SMISMEMBER" + | b"SORT_RO" + | b"SRANDMEMBER" + | b"SSCAN" + | b"STRLEN" + | b"SUBSTR" + | b"SUNION" + | b"TOUCH" + | b"TTL" + | b"TYPE" + | b"XINFO CONSUMERS" + | b"XINFO GROUPS" + | b"XINFO STREAM" + | b"XLEN" + | b"XPENDING" + | b"XRANGE" + | b"XREAD" + | b"XREVRANGE" + | b"ZCARD" + | b"ZCOUNT" + | b"ZDIFF" + | b"ZINTER" + | b"ZINTERCARD" + | b"ZLEXCOUNT" + | b"ZMSCORE" + | b"ZRANDMEMBER" + | b"ZRANGE" + | b"ZRANGEBYLEX" + | b"ZRANGEBYSCORE" + | b"ZRANK" + | b"ZREVRANGE" + | b"ZREVRANGEBYLEX" + | b"ZREVRANGEBYSCORE" + | b"ZREVRANK" + | b"ZSCAN" + | b"ZSCORE" + | b"ZUNION" + ) +} + +/// Objects that implement this trait define a request that can be routed by a cluster client to different nodes in the cluster. +pub trait Routable { + /// Convenience function to return ascii uppercase version of the + /// the first argument (i.e., the command). + fn command(&self) -> Option> { + let primary_command = self.arg_idx(0).map(|x| x.to_ascii_uppercase())?; + let mut primary_command = match primary_command.as_slice() { + b"XGROUP" | b"OBJECT" | b"SLOWLOG" | b"FUNCTION" | b"MODULE" | b"COMMAND" + | b"PUBSUB" | b"CONFIG" | b"MEMORY" | b"XINFO" | b"CLIENT" | b"ACL" | b"SCRIPT" + | b"CLUSTER" | b"LATENCY" => primary_command, + _ => { + return Some(primary_command); + } + }; + + Some(match self.arg_idx(1) { + Some(secondary_command) => { + let previous_len = primary_command.len(); + primary_command.reserve(secondary_command.len() + 1); + primary_command.extend(b" "); + primary_command.extend(secondary_command); + let current_len = primary_command.len(); + primary_command[previous_len + 1..current_len].make_ascii_uppercase(); + primary_command + } + None => primary_command, + }) + } + + /// Returns a reference to the data for the argument at `idx`. + fn arg_idx(&self, idx: usize) -> Option<&[u8]>; + + /// Returns index of argument that matches `candidate`, if it exists + fn position(&self, candidate: &[u8]) -> Option; +} + +impl Routable for Cmd { + fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + self.arg_idx(idx) + } + + fn position(&self, candidate: &[u8]) -> Option { + self.args_iter().position(|a| match a { + Arg::Simple(d) => d.eq_ignore_ascii_case(candidate), + _ => false, + }) + } +} + +impl Routable for Value { + fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + match self { + Value::Array(args) => match args.get(idx) { + Some(Value::BulkString(ref data)) => Some(&data[..]), + _ => None, + }, + _ => None, + } + } + + fn position(&self, candidate: &[u8]) -> Option { + match self { + Value::Array(args) => args.iter().position(|a| match a { + Value::BulkString(d) => d.eq_ignore_ascii_case(candidate), + _ => false, + }), + _ => None, + } + } +} + +#[derive(Debug, Hash)] +pub(crate) struct Slot { + pub(crate) start: u16, + pub(crate) end: u16, + pub(crate) master: String, + pub(crate) replicas: Vec, +} + +impl Slot { + pub fn new(s: u16, e: u16, m: String, r: Vec) -> Self { + Self { + start: s, + end: e, + master: m, + replicas: r, + } + } + + pub fn start(&self) -> u16 { + self.start + } + + pub fn end(&self) -> u16 { + self.end + } + + #[allow(dead_code)] // used in tests + pub(crate) fn master(&self) -> &str { + self.master.as_str() + } + + #[allow(dead_code)] // used in tests + pub fn replicas(&self) -> Vec { + self.replicas.clone() + } +} + +/// What type of node should a request be routed to, assuming read from replica is enabled. +#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)] +pub enum SlotAddr { + /// The request must be routed to primary node + Master, + /// The request may be routed to a replica node. + /// For example, a GET command can be routed either to replica or primary. + ReplicaOptional, + /// The request must be routed to replica node, if one exists. + /// For example, by user requested routing. + ReplicaRequired, +} + +/// This is just a simplified version of [`Slot`], +/// which stores only the master and [optional] replica +/// to avoid the need to choose a replica each time +/// a command is executed +#[derive(Debug, Eq, PartialEq)] +pub(crate) struct SlotAddrs { + pub(crate) primary: String, + pub(crate) replicas: Vec, +} + +impl SlotAddrs { + pub(crate) fn new(primary: String, replicas: Vec) -> Self { + Self { primary, replicas } + } + + pub(crate) fn from_slot(slot: Slot) -> Self { + SlotAddrs::new(slot.master, slot.replicas) + } +} + +impl<'a> IntoIterator for &'a SlotAddrs { + type Item = &'a String; + type IntoIter = std::iter::Chain, std::slice::Iter<'a, String>>; + + fn into_iter(self) -> Self::IntoIter { + std::iter::once(&self.primary).chain(self.replicas.iter()) + } +} + +/// Defines the slot and the [`SlotAddr`] to which +/// a command should be sent +#[derive(Eq, PartialEq, Clone, Copy, Debug, Hash)] +pub struct Route(u16, SlotAddr); + +impl Route { + /// Returns a new Route. + pub fn new(slot: u16, slot_addr: SlotAddr) -> Self { + Self(slot, slot_addr) + } + + /// Returns the slot number of the route. + pub fn slot(&self) -> u16 { + self.0 + } + + /// Returns the slot address of the route. + pub fn slot_addr(&self) -> SlotAddr { + self.1 + } + + /// Returns a new Route for a random primary node + pub fn new_random_primary() -> Self { + Self::new(random_slot(), SlotAddr::Master) + } +} + +/// Choose a random slot from `0..SLOT_SIZE` (excluding) +fn random_slot() -> u16 { + let mut rng = rand::thread_rng(); + rng.gen_range(0..crate::cluster_topology::SLOT_SIZE) +} + +#[cfg(test)] +mod tests { + use super::{ + command_for_multi_slot_indices, AggregateOp, MultipleNodeRoutingInfo, ResponsePolicy, + Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, + }; + use crate::{cluster_topology::slot, cmd, parser::parse_redis_value, Value}; + use core::panic; + + #[test] + fn test_routing_info_mixed_capatalization() { + let mut upper = cmd("XREAD"); + upper.arg("STREAMS").arg("foo").arg(0); + + let mut lower = cmd("xread"); + lower.arg("streams").arg("foo").arg(0); + + assert_eq!( + RoutingInfo::for_routable(&upper).unwrap(), + RoutingInfo::for_routable(&lower).unwrap() + ); + + let mut mixed = cmd("xReAd"); + mixed.arg("StReAmS").arg("foo").arg(0); + + assert_eq!( + RoutingInfo::for_routable(&lower).unwrap(), + RoutingInfo::for_routable(&mixed).unwrap() + ); + } + + #[test] + fn test_routing_info() { + let mut test_cmds = vec![]; + + // RoutingInfo::AllMasters + let mut test_cmd = cmd("FLUSHALL"); + test_cmd.arg(""); + test_cmds.push(test_cmd); + + // RoutingInfo::AllNodes + test_cmd = cmd("ECHO"); + test_cmd.arg(""); + test_cmds.push(test_cmd); + + // Routing key is 2nd arg ("42") + test_cmd = cmd("SET"); + test_cmd.arg("42"); + test_cmds.push(test_cmd); + + // Routing key is 3rd arg ("FOOBAR") + test_cmd = cmd("XINFO"); + test_cmd.arg("GROUPS").arg("FOOBAR"); + test_cmds.push(test_cmd); + + // Routing key is 3rd or 4th arg (3rd = "0" == RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + test_cmd = cmd("EVAL"); + test_cmd.arg("FOO").arg("0").arg("BAR"); + test_cmds.push(test_cmd); + + // Routing key is 3rd or 4th arg (3rd != "0" == RoutingInfo::Slot) + test_cmd = cmd("EVAL"); + test_cmd.arg("FOO").arg("4").arg("BAR"); + test_cmds.push(test_cmd); + + // Routing key position is variable, 3rd arg + test_cmd = cmd("XREAD"); + test_cmd.arg("STREAMS").arg("4"); + test_cmds.push(test_cmd); + + // Routing key position is variable, 4th arg + test_cmd = cmd("XREAD"); + test_cmd.arg("FOO").arg("STREAMS").arg("4"); + test_cmds.push(test_cmd); + + for cmd in test_cmds { + let value = parse_redis_value(&cmd.get_packed_command()).unwrap(); + assert_eq!( + RoutingInfo::for_routable(&value).unwrap(), + RoutingInfo::for_routable(&cmd).unwrap(), + ); + } + + // Assert expected RoutingInfo explicitly: + + for cmd in [cmd("FLUSHALL"), cmd("FLUSHDB"), cmd("PING")] { + assert_eq!( + RoutingInfo::for_routable(&cmd), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::AllSucceeded) + ))) + ); + } + + assert_eq!( + RoutingInfo::for_routable(&cmd("DBSIZE")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::Aggregate(AggregateOp::Sum)) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("SCRIPT KILL")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllNodes, + Some(ResponsePolicy::OneSucceeded) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("INFO")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::Special) + ))) + ); + + assert_eq!( + RoutingInfo::for_routable(&cmd("KEYS")), + Some(RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(ResponsePolicy::CombineArrays) + ))) + ); + + for cmd in vec![ + cmd("SCAN"), + cmd("SHUTDOWN"), + cmd("SLAVEOF"), + cmd("REPLICAOF"), + ] { + assert_eq!( + RoutingInfo::for_routable(&cmd), + None, + "{}", + std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() + ); + } + + for cmd in [ + cmd("EVAL").arg(r#"redis.call("PING");"#).arg(0), + cmd("EVALSHA").arg(r#"redis.call("PING");"#).arg(0), + ] { + assert_eq!( + RoutingInfo::for_routable(cmd), + Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + ); + } + + // While FCALL with N keys is expected to be routed to a specific node + assert_eq!( + RoutingInfo::for_routable(cmd("FCALL").arg("foo").arg(1).arg("mykey")), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"mykey"), SlotAddr::Master)) + )) + ); + + for (cmd, expected) in [ + ( + cmd("EVAL") + .arg(r#"redis.call("GET, KEYS[1]");"#) + .arg(1) + .arg("foo"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(slot(b"foo"), SlotAddr::Master)), + )), + ), + ( + cmd("XGROUP") + .arg("CREATE") + .arg("mystream") + .arg("workers") + .arg("$") + .arg("MKSTREAM"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::Master, + )), + )), + ), + ( + cmd("XINFO").arg("GROUPS").arg("foo"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"foo"), + SlotAddr::ReplicaOptional, + )), + )), + ), + ( + cmd("XREADGROUP") + .arg("GROUP") + .arg("wkrs") + .arg("consmrs") + .arg("STREAMS") + .arg("mystream"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::Master, + )), + )), + ), + ( + cmd("XREAD") + .arg("COUNT") + .arg("2") + .arg("STREAMS") + .arg("mystream") + .arg("writers") + .arg("0-0") + .arg("0-0"), + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new( + slot(b"mystream"), + SlotAddr::ReplicaOptional, + )), + )), + ), + ] { + assert_eq!( + RoutingInfo::for_routable(cmd), + expected, + "{}", + std::str::from_utf8(cmd.arg_idx(0).unwrap()).unwrap() + ); + } + } + + #[test] + fn test_slot_for_packed_cmd() { + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 50, 13, 10, 36, 54, 13, 10, 69, 88, 73, 83, 84, 83, 13, 10, 36, 49, 54, 13, 10, + 244, 93, 23, 40, 126, 127, 253, 33, 89, 47, 185, 204, 171, 249, 96, 139, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::ReplicaOptional)))) if slot == 964)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 36, 241, + 197, 111, 180, 254, 5, 175, 143, 146, 171, 39, 172, 23, 164, 145, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 8352)); + + assert!(matches!(RoutingInfo::for_routable(&parse_redis_value(&[ + 42, 54, 13, 10, 36, 51, 13, 10, 83, 69, 84, 13, 10, 36, 49, 54, 13, 10, 169, 233, + 247, 59, 50, 247, 100, 232, 123, 140, 2, 101, 125, 221, 66, 170, 13, 10, 36, 52, + 13, 10, 116, 114, 117, 101, 13, 10, 36, 50, 13, 10, 78, 88, 13, 10, 36, 50, 13, 10, + 80, 88, 13, 10, 36, 55, 13, 10, 49, 56, 48, 48, 48, 48, 48, 13, 10 + ]).unwrap()), Some(RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route(slot, SlotAddr::Master)))) if slot == 5210)); + } + + #[test] + fn test_multi_shard() { + let mut cmd = cmd("DEL"); + cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz"); + let routing = RoutingInfo::for_routable(&cmd); + let mut expected = std::collections::HashMap::new(); + expected.insert(Route(4813, SlotAddr::Master), vec![2]); + expected.insert(Route(5061, SlotAddr::Master), vec![1, 3]); + expected.insert(Route(12182, SlotAddr::Master), vec![0]); + + assert!( + matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot(vec), Some(ResponsePolicy::Aggregate(AggregateOp::Sum))))) if { + let routes = vec.clone().into_iter().collect(); + expected == routes + }), + "{routing:?}" + ); + + let mut cmd = crate::cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz").arg("{bar}vaz"); + let routing = RoutingInfo::for_routable(&cmd); + let mut expected = std::collections::HashMap::new(); + expected.insert(Route(4813, SlotAddr::ReplicaOptional), vec![2]); + expected.insert(Route(5061, SlotAddr::ReplicaOptional), vec![1, 3]); + expected.insert(Route(12182, SlotAddr::ReplicaOptional), vec![0]); + + assert!( + matches!(routing.clone(), Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot(vec), Some(ResponsePolicy::CombineArrays)))) if { + let routes = vec.clone().into_iter().collect(); + expected ==routes + }), + "{routing:?}" + ); + } + + #[test] + fn test_command_creation_for_multi_shard() { + let mut original_cmd = cmd("DEL"); + original_cmd + .arg("foo") + .arg("bar") + .arg("baz") + .arg("{bar}vaz"); + let routing = RoutingInfo::for_routable(&original_cmd); + let expected = [vec![0], vec![1, 3], vec![2]]; + + let mut indices: Vec<_> = match routing { + Some(RoutingInfo::MultiNode((MultipleNodeRoutingInfo::MultiSlot(vec), _))) => { + vec.into_iter().map(|(_, indices)| indices).collect() + } + _ => panic!("unexpected routing: {routing:?}"), + }; + indices.sort_by(|prev, next| prev.iter().next().unwrap().cmp(next.iter().next().unwrap())); // sorting because the `for_routable` doesn't return values in a consistent order between runs. + + for (index, indices) in indices.into_iter().enumerate() { + let cmd = command_for_multi_slot_indices(&original_cmd, indices.iter()); + let expected_indices = &expected[index]; + assert_eq!(original_cmd.arg_idx(0), cmd.arg_idx(0)); + for (index, target_index) in expected_indices.iter().enumerate() { + let target_index = target_index + 1; + assert_eq!(original_cmd.arg_idx(target_index), cmd.arg_idx(index + 1)); + } + } + } + + #[test] + fn test_combine_multi_shard_to_single_node_when_all_keys_are_in_same_slot() { + let mut cmd = cmd("DEL"); + cmd.arg("foo").arg("{foo}bar").arg("{foo}baz"); + let routing = RoutingInfo::for_routable(&cmd); + + assert!( + matches!( + routing, + Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route(12182, SlotAddr::Master)) + )) + ), + "{routing:?}" + ); + } + + #[test] + fn test_combining_results_into_single_array() { + let res1 = Value::Array(vec![Value::Nil, Value::Okay]); + let res2 = Value::Array(vec![ + Value::BulkString("1".as_bytes().to_vec()), + Value::BulkString("4".as_bytes().to_vec()), + ]); + let res3 = Value::Array(vec![Value::SimpleString("2".to_string()), Value::Int(3)]); + let results = super::combine_and_sort_array_results( + vec![res1, res2, res3], + [vec![0, 5], vec![1, 4], vec![2, 3]].iter(), + ); + + assert_eq!( + results.unwrap(), + Value::Array(vec![ + Value::Nil, + Value::BulkString("1".as_bytes().to_vec()), + Value::SimpleString("2".to_string()), + Value::Int(3), + Value::BulkString("4".as_bytes().to_vec()), + Value::Okay, + ]) + ); + } + + #[test] + fn test_combine_map_results() { + let input = vec![]; + let result = super::combine_map_results(input).unwrap(); + assert_eq!(result, Value::Map(vec![])); + + let input = vec![ + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(5), + Value::BulkString(b"key2".to_vec()), + Value::Int(10), + ]), + Value::Array(vec![ + Value::BulkString(b"key1".to_vec()), + Value::Int(3), + Value::BulkString(b"key3".to_vec()), + Value::Int(15), + ]), + ]; + let result = super::combine_map_results(input).unwrap(); + let mut expected = vec![ + (Value::BulkString(b"key1".to_vec()), Value::Int(8)), + (Value::BulkString(b"key2".to_vec()), Value::Int(10)), + (Value::BulkString(b"key3".to_vec()), Value::Int(15)), + ]; + expected.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + let mut result_vec = match result { + Value::Map(v) => v, + _ => panic!("Expected Map"), + }; + result_vec.sort_unstable_by(|a, b| match (&a.0, &b.0) { + (Value::BulkString(a_bytes), Value::BulkString(b_bytes)) => a_bytes.cmp(b_bytes), + _ => std::cmp::Ordering::Equal, + }); + assert_eq!(result_vec, expected); + + let input = vec![Value::Int(5)]; + let result = super::combine_map_results(input); + assert!(result.is_err()); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_slotmap.rs b/glide-core/redis-rs/redis/src/cluster_slotmap.rs new file mode 100644 index 0000000000..7f1f70af98 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_slotmap.rs @@ -0,0 +1,435 @@ +use std::{ + collections::{BTreeMap, HashSet}, + fmt::Display, + sync::atomic::AtomicUsize, +}; + +use crate::cluster_routing::{Route, Slot, SlotAddr, SlotAddrs}; + +#[derive(Debug)] +pub(crate) struct SlotMapValue { + pub(crate) start: u16, + pub(crate) addrs: SlotAddrs, + pub(crate) latest_used_replica: AtomicUsize, +} + +impl SlotMapValue { + fn from_slot(slot: Slot) -> Self { + Self { + start: slot.start(), + addrs: SlotAddrs::from_slot(slot), + latest_used_replica: AtomicUsize::new(0), + } + } +} + +#[derive(Debug, Default, Clone, PartialEq, Copy)] +pub(crate) enum ReadFromReplicaStrategy { + #[default] + AlwaysFromPrimary, + RoundRobin, +} + +#[derive(Debug, Default)] +pub(crate) struct SlotMap { + pub(crate) slots: BTreeMap, + read_from_replica: ReadFromReplicaStrategy, +} + +fn get_address_from_slot( + slot: &SlotMapValue, + read_from_replica: ReadFromReplicaStrategy, + slot_addr: SlotAddr, +) -> &str { + if slot_addr == SlotAddr::Master || slot.addrs.replicas.is_empty() { + return slot.addrs.primary.as_str(); + } + match read_from_replica { + ReadFromReplicaStrategy::AlwaysFromPrimary => slot.addrs.primary.as_str(), + ReadFromReplicaStrategy::RoundRobin => { + let index = slot + .latest_used_replica + .fetch_add(1, std::sync::atomic::Ordering::Relaxed) + % slot.addrs.replicas.len(); + slot.addrs.replicas[index].as_str() + } + } +} + +impl SlotMap { + pub(crate) fn new(slots: Vec, read_from_replica: ReadFromReplicaStrategy) -> Self { + let mut this = Self { + slots: BTreeMap::new(), + read_from_replica, + }; + this.slots.extend( + slots + .into_iter() + .map(|slot| (slot.end(), SlotMapValue::from_slot(slot))), + ); + this + } + + pub fn slot_value_for_route(&self, route: &Route) -> Option<&SlotMapValue> { + let slot = route.slot(); + self.slots + .range(slot..) + .next() + .and_then(|(end, slot_value)| { + if slot <= *end && slot_value.start <= slot { + Some(slot_value) + } else { + None + } + }) + } + + pub fn slot_addr_for_route(&self, route: &Route) -> Option<&str> { + self.slot_value_for_route(route).map(|slot_value| { + get_address_from_slot(slot_value, self.read_from_replica, route.slot_addr()) + }) + } + + pub fn values(&self) -> impl Iterator { + self.slots.values().map(|slot_value| &slot_value.addrs) + } + + fn all_unique_addresses(&self, only_primaries: bool) -> HashSet<&str> { + let mut addresses = HashSet::new(); + for slot in self.values() { + addresses.insert(slot.primary.as_str()); + if !only_primaries { + addresses.extend(slot.replicas.iter().map(|str| str.as_str())); + } + } + + addresses + } + + pub fn addresses_for_all_primaries(&self) -> HashSet<&str> { + self.all_unique_addresses(true) + } + + pub fn addresses_for_all_nodes(&self) -> HashSet<&str> { + self.all_unique_addresses(false) + } + + pub fn addresses_for_multi_slot<'a, 'b>( + &'a self, + routes: &'b [(Route, Vec)], + ) -> impl Iterator> + 'a + where + 'b: 'a, + { + routes + .iter() + .map(|(route, _)| self.slot_addr_for_route(route)) + } + + // Returns the slots that are assigned to the given address. + pub(crate) fn get_slots_of_node(&self, node_address: &str) -> Vec { + let node_address = node_address.to_string(); + self.slots + .iter() + .filter_map(|(end, slot_value)| { + if slot_value.addrs.primary == node_address + || slot_value.addrs.replicas.contains(&node_address) + { + Some(slot_value.start..(*end + 1)) + } else { + None + } + }) + .flatten() + .collect() + } + + pub(crate) fn get_node_address_for_slot( + &self, + slot: u16, + slot_addr: SlotAddr, + ) -> Option { + self.slots.range(slot..).next().and_then(|(_, slot_value)| { + if slot_value.start <= slot { + Some( + get_address_from_slot(slot_value, self.read_from_replica, slot_addr) + .to_string(), + ) + } else { + None + } + }) + } +} + +impl Display for SlotMap { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + writeln!(f, "Strategy: {:?}. Slot mapping:", self.read_from_replica)?; + for (end, slot_map_value) in self.slots.iter() { + writeln!( + f, + "({}-{}): primary: {}, replicas: {:?}", + slot_map_value.start, + end, + slot_map_value.addrs.primary, + slot_map_value.addrs.replicas + )?; + } + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_slot_map_retrieve_routes() { + let slot_map = SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned()], + ), + ], + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + + assert!(slot_map + .slot_addr_for_route(&Route::new(0, SlotAddr::Master)) + .is_none()); + assert_eq!( + "node1:6379", + slot_map + .slot_addr_for_route(&Route::new(1, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node1:6379", + slot_map + .slot_addr_for_route(&Route::new(500, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node1:6379", + slot_map + .slot_addr_for_route(&Route::new(1000, SlotAddr::Master)) + .unwrap() + ); + assert!(slot_map + .slot_addr_for_route(&Route::new(1001, SlotAddr::Master)) + .is_none()); + + assert_eq!( + "node2:6379", + slot_map + .slot_addr_for_route(&Route::new(1002, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node2:6379", + slot_map + .slot_addr_for_route(&Route::new(1500, SlotAddr::Master)) + .unwrap() + ); + assert_eq!( + "node2:6379", + slot_map + .slot_addr_for_route(&Route::new(2000, SlotAddr::Master)) + .unwrap() + ); + assert!(slot_map + .slot_addr_for_route(&Route::new(2001, SlotAddr::Master)) + .is_none()); + } + + fn get_slot_map(read_from_replica: ReadFromReplicaStrategy) -> SlotMap { + SlotMap::new( + vec![ + Slot::new( + 1, + 1000, + "node1:6379".to_owned(), + vec!["replica1:6379".to_owned()], + ), + Slot::new( + 1002, + 2000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + Slot::new( + 2001, + 3000, + "node3:6379".to_owned(), + vec![ + "replica4:6379".to_owned(), + "replica5:6379".to_owned(), + "replica6:6379".to_owned(), + ], + ), + Slot::new( + 3001, + 4000, + "node2:6379".to_owned(), + vec!["replica2:6379".to_owned(), "replica3:6379".to_owned()], + ), + ], + read_from_replica, + ) + } + + #[test] + fn test_slot_map_get_all_primaries() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + let addresses = slot_map.addresses_for_all_primaries(); + assert_eq!( + addresses, + HashSet::from_iter(["node1:6379", "node2:6379", "node3:6379"]) + ); + } + + #[test] + fn test_slot_map_get_all_nodes() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + let addresses = slot_map.addresses_for_all_nodes(); + assert_eq!( + addresses, + HashSet::from_iter([ + "node1:6379", + "node2:6379", + "node3:6379", + "replica1:6379", + "replica2:6379", + "replica3:6379", + "replica4:6379", + "replica5:6379", + "replica6:6379" + ]) + ); + } + + #[test] + fn test_slot_map_get_multi_node() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::Master), vec![]), + (Route::new(2001, SlotAddr::ReplicaOptional), vec![]), + ]; + let addresses = slot_map + .addresses_for_multi_slot(&routes) + .collect::>(); + assert!(addresses.contains(&Some("node1:6379"))); + assert!( + addresses.contains(&Some("replica4:6379")) + || addresses.contains(&Some("replica5:6379")) + || addresses.contains(&Some("replica6:6379")) + ); + } + + /// This test is needed in order to verify that if the MultiSlot route finds the same node for more than a single route, + /// that node's address will appear multiple times, in the same order. + #[test] + fn test_slot_map_get_repeating_addresses_when_the_same_node_is_found_in_multi_slot() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2001, SlotAddr::Master), vec![]), + (Route::new(2, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2002, SlotAddr::Master), vec![]), + (Route::new(3, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2003, SlotAddr::Master), vec![]), + ]; + let addresses = slot_map + .addresses_for_multi_slot(&routes) + .collect::>(); + assert_eq!( + addresses, + vec![ + Some("replica1:6379"), + Some("node3:6379"), + Some("replica1:6379"), + Some("node3:6379"), + Some("replica1:6379"), + Some("node3:6379") + ] + ); + } + + #[test] + fn test_slot_map_get_none_when_slot_is_missing_from_multi_slot() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let routes = vec![ + (Route::new(1, SlotAddr::ReplicaOptional), vec![]), + (Route::new(5000, SlotAddr::Master), vec![]), + (Route::new(6000, SlotAddr::ReplicaOptional), vec![]), + (Route::new(2002, SlotAddr::Master), vec![]), + ]; + let addresses = slot_map + .addresses_for_multi_slot(&routes) + .collect::>(); + assert_eq!( + addresses, + vec![Some("replica1:6379"), None, None, Some("node3:6379")] + ); + } + + #[test] + fn test_slot_map_rotate_read_replicas() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::RoundRobin); + let route = Route::new(2001, SlotAddr::ReplicaOptional); + let mut addresses = vec![ + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + slot_map.slot_addr_for_route(&route).unwrap(), + ]; + addresses.sort(); + assert_eq!( + addresses, + vec!["replica4:6379", "replica5:6379", "replica6:6379"] + ); + } + + #[test] + fn test_get_slots_of_node() { + let slot_map = get_slot_map(ReadFromReplicaStrategy::AlwaysFromPrimary); + assert_eq!( + slot_map.get_slots_of_node("node1:6379"), + (1..1001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("node2:6379"), + vec![1002..2001, 3001..4001] + .into_iter() + .flatten() + .collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica3:6379"), + vec![1002..2001, 3001..4001] + .into_iter() + .flatten() + .collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica4:6379"), + (2001..3001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica5:6379"), + (2001..3001).collect::>() + ); + assert_eq!( + slot_map.get_slots_of_node("replica6:6379"), + (2001..3001).collect::>() + ); + } +} diff --git a/glide-core/redis-rs/redis/src/cluster_topology.rs b/glide-core/redis-rs/redis/src/cluster_topology.rs new file mode 100644 index 0000000000..a2ce9ea078 --- /dev/null +++ b/glide-core/redis-rs/redis/src/cluster_topology.rs @@ -0,0 +1,645 @@ +//! This module provides the functionality to refresh and calculate the cluster topology for Redis Cluster. + +use crate::cluster::get_connection_addr; +#[cfg(feature = "cluster-async")] +use crate::cluster_client::SlotsRefreshRateLimit; +use crate::cluster_routing::Slot; +use crate::cluster_slotmap::{ReadFromReplicaStrategy, SlotMap}; +use crate::{cluster::TlsMode, ErrorKind, RedisError, RedisResult, Value}; +#[cfg(all(feature = "cluster-async", not(feature = "tokio-comp")))] +use async_std::sync::RwLock; +use derivative::Derivative; +use std::collections::{hash_map::DefaultHasher, HashMap}; +use std::hash::{Hash, Hasher}; +use std::sync::atomic::AtomicBool; +use std::sync::Arc; +use std::time::{Duration, SystemTime}; +#[cfg(all(feature = "cluster-async", feature = "tokio-comp"))] +use tokio::sync::RwLock; +use tracing::info; + +// Exponential backoff constants for retrying a slot refresh +/// The default number of refresh topology retries in the same call +pub const DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES: usize = 3; +/// The default maximum interval between two retries of the same call for topology refresh +pub const DEFAULT_REFRESH_SLOTS_RETRY_MAX_INTERVAL: Duration = Duration::from_secs(1); +/// The default initial interval for retrying topology refresh +pub const DEFAULT_REFRESH_SLOTS_RETRY_INITIAL_INTERVAL: Duration = Duration::from_millis(500); + +// Constants for the intervals between two independent consecutive refresh slots calls +/// The default wait duration between two consecutive refresh slots calls +#[cfg(feature = "cluster-async")] +pub const DEFAULT_SLOTS_REFRESH_WAIT_DURATION: Duration = Duration::from_secs(15); +/// The default maximum jitter duration to add to the refresh slots wait duration +#[cfg(feature = "cluster-async")] +pub const DEFAULT_SLOTS_REFRESH_MAX_JITTER_MILLI: u64 = 15 * 1000; // 15 seconds + +pub(crate) const SLOT_SIZE: u16 = 16384; +pub(crate) type TopologyHash = u64; + +/// Represents the state of slot refresh operations. +#[cfg(feature = "cluster-async")] +pub(crate) struct SlotRefreshState { + /// Indicates if a slot refresh is currently in progress + pub(crate) in_progress: AtomicBool, + /// The last slot refresh run timestamp + pub(crate) last_run: Arc>>, + pub(crate) rate_limiter: SlotsRefreshRateLimit, +} + +#[cfg(feature = "cluster-async")] +impl SlotRefreshState { + pub(crate) fn new(rate_limiter: SlotsRefreshRateLimit) -> Self { + Self { + in_progress: AtomicBool::new(false), + last_run: Arc::new(RwLock::new(None)), + rate_limiter, + } + } +} + +#[derive(Derivative)] +#[derivative(PartialEq, Eq)] +#[derive(Debug)] +pub(crate) struct TopologyView { + pub(crate) hash_value: TopologyHash, + #[derivative(PartialEq = "ignore")] + pub(crate) nodes_count: u16, + #[derivative(PartialEq = "ignore")] + slots_and_count: (u16, Vec), +} + +pub(crate) fn slot(key: &[u8]) -> u16 { + crc16::State::::calculate(key) % SLOT_SIZE +} + +fn get_hashtag(key: &[u8]) -> Option<&[u8]> { + let open = key.iter().position(|v| *v == b'{'); + let open = match open { + Some(open) => open, + None => return None, + }; + + let close = key[open..].iter().position(|v| *v == b'}'); + let close = match close { + Some(close) => close, + None => return None, + }; + + let rv = &key[open + 1..open + close]; + if rv.is_empty() { + None + } else { + Some(rv) + } +} + +/// Returns the slot that matches `key`. +pub fn get_slot(key: &[u8]) -> u16 { + let key = match get_hashtag(key) { + Some(tag) => tag, + None => key, + }; + + slot(key) +} + +// Parse slot data from raw redis value. +pub(crate) fn parse_and_count_slots( + raw_slot_resp: &Value, + tls: Option, + // The DNS address of the node from which `raw_slot_resp` was received. + addr_of_answering_node: &str, +) -> RedisResult<(u16, Vec)> { + // Parse response. + let mut slots = Vec::with_capacity(2); + let mut count = 0; + + if let Value::Array(items) = raw_slot_resp { + let mut iter = items.iter(); + while let Some(Value::Array(item)) = iter.next() { + if item.len() < 3 { + continue; + } + + let start = if let Value::Int(start) = item[0] { + start as u16 + } else { + continue; + }; + + let end = if let Value::Int(end) = item[1] { + end as u16 + } else { + continue; + }; + + let mut nodes: Vec = item + .iter() + .skip(2) + .filter_map(|node| { + if let Value::Array(node) = node { + if node.len() < 2 { + return None; + } + // According to the CLUSTER SLOTS documentation: + // If the received hostname is an empty string or NULL, clients should utilize the hostname of the responding node. + // However, if the received hostname is "?", it should be regarded as an indication of an unknown node. + let hostname = if let Value::BulkString(ref ip) = node[0] { + let hostname = String::from_utf8_lossy(ip); + if hostname.is_empty() { + addr_of_answering_node.into() + } else if hostname == "?" { + return None; + } else { + hostname + } + } else if let Value::Nil = node[0] { + addr_of_answering_node.into() + } else { + return None; + }; + if hostname.is_empty() { + return None; + } + + let port = if let Value::Int(port) = node[1] { + port as u16 + } else { + return None; + }; + Some( + get_connection_addr(hostname.into_owned(), port, tls, None).to_string(), + ) + } else { + None + } + }) + .collect(); + + if nodes.is_empty() { + continue; + } + count += end - start; + + let mut replicas = nodes.split_off(1); + // we sort the replicas, because different nodes in a cluster might return the same slot view + // with different order of the replicas, which might cause the views to be considered evaluated as not equal. + replicas.sort_unstable(); + slots.push(Slot::new(start, end, nodes.pop().unwrap(), replicas)); + } + } + if slots.is_empty() { + return Err(RedisError::from(( + ErrorKind::ResponseError, + "Error parsing slots: No healthy node found", + format!("Raw slot map response: {:?}", raw_slot_resp), + ))); + } + + Ok((count, slots)) +} + +fn calculate_hash(t: &T) -> u64 { + let mut s = DefaultHasher::new(); + t.hash(&mut s); + s.finish() +} + +pub(crate) fn calculate_topology<'a>( + topology_views: impl Iterator, + curr_retry: usize, + tls_mode: Option, + num_of_queried_nodes: usize, + read_from_replica: ReadFromReplicaStrategy, +) -> RedisResult<(SlotMap, TopologyHash)> { + let mut hash_view_map = HashMap::new(); + for (host, view) in topology_views { + if let Ok(slots_and_count) = parse_and_count_slots(view, tls_mode, host) { + let hash_value = calculate_hash(&slots_and_count); + let topology_entry = hash_view_map.entry(hash_value).or_insert(TopologyView { + hash_value, + nodes_count: 0, + slots_and_count, + }); + topology_entry.nodes_count += 1; + } + } + let mut non_unique_max_node_count = false; + let mut vec_iter = hash_view_map.into_values(); + let mut most_frequent_topology = match vec_iter.next() { + Some(view) => view, + None => { + return Err(RedisError::from(( + ErrorKind::ResponseError, + "No topology views found", + ))); + } + }; + // Find the most frequent topology view + for curr_view in vec_iter { + match most_frequent_topology + .nodes_count + .cmp(&curr_view.nodes_count) + { + std::cmp::Ordering::Less => { + most_frequent_topology = curr_view; + non_unique_max_node_count = false; + } + std::cmp::Ordering::Greater => continue, + std::cmp::Ordering::Equal => { + non_unique_max_node_count = true; + let seen_slot_count = most_frequent_topology.slots_and_count.0; + + // We choose as the greater view the one with higher slot coverage. + if let std::cmp::Ordering::Less = seen_slot_count.cmp(&curr_view.slots_and_count.0) + { + most_frequent_topology = curr_view; + } + } + } + } + + let parse_and_built_result = |most_frequent_topology: TopologyView| { + info!( + "calculate_topology found topology map:\n{:?}", + most_frequent_topology + ); + let slots_data = most_frequent_topology.slots_and_count.1; + Ok(( + SlotMap::new(slots_data, read_from_replica), + most_frequent_topology.hash_value, + )) + }; + + if non_unique_max_node_count { + // More than a single most frequent view was found + // If we reached the last retry, or if we it's a 2-nodes cluster, we'll return a view with the highest slot coverage, and that is one of most agreed on views. + if curr_retry >= DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES || num_of_queried_nodes < 3 { + return parse_and_built_result(most_frequent_topology); + } + return Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error: Failed to obtain a majority in topology views", + ))); + } + + // The rate of agreement of the topology view is determined by assessing the number of nodes that share this view out of the total number queried + let agreement_rate = most_frequent_topology.nodes_count as f32 / num_of_queried_nodes as f32; + const MIN_AGREEMENT_RATE: f32 = 0.2; + if agreement_rate >= MIN_AGREEMENT_RATE { + parse_and_built_result(most_frequent_topology) + } else { + Err(RedisError::from(( + ErrorKind::ResponseError, + "Slot refresh error: The accuracy of the topology view is too low", + ))) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::cluster_routing::SlotAddrs; + + #[test] + fn test_get_hashtag() { + assert_eq!(get_hashtag(&b"foo{bar}baz"[..]), Some(&b"bar"[..])); + assert_eq!(get_hashtag(&b"foo{}{baz}"[..]), None); + assert_eq!(get_hashtag(&b"foo{{bar}}zap"[..]), Some(&b"{bar"[..])); + } + + fn slot_value_with_replicas(start: u16, end: u16, nodes: Vec<(&str, u16)>) -> Value { + let mut node_values: Vec = nodes + .iter() + .map(|(host, port)| { + Value::Array(vec![ + Value::BulkString(host.as_bytes().to_vec()), + Value::Int(*port as i64), + ]) + }) + .collect(); + let mut slot_vec = vec![Value::Int(start as i64), Value::Int(end as i64)]; + slot_vec.append(&mut node_values); + Value::Array(slot_vec) + } + + fn slot_value(start: u16, end: u16, node: &str, port: u16) -> Value { + slot_value_with_replicas(start, end, vec![(node, port)]) + } + + #[test] + fn parse_slots_with_different_replicas_order_returns_the_same_view() { + let view1 = Value::Array(vec![ + slot_value_with_replicas( + 0, + 4000, + vec![ + ("primary1", 6379), + ("replica1_1", 6379), + ("replica1_2", 6379), + ("replica1_3", 6379), + ], + ), + slot_value_with_replicas( + 4001, + 8000, + vec![ + ("primary2", 6379), + ("replica2_1", 6379), + ("replica2_2", 6379), + ("replica2_3", 6379), + ], + ), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("primary3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ("replica3_3", 6379), + ], + ), + ]); + + let view2 = Value::Array(vec![ + slot_value_with_replicas( + 0, + 4000, + vec![ + ("primary1", 6379), + ("replica1_1", 6379), + ("replica1_3", 6379), + ("replica1_2", 6379), + ], + ), + slot_value_with_replicas( + 4001, + 8000, + vec![ + ("primary2", 6379), + ("replica2_2", 6379), + ("replica2_3", 6379), + ("replica2_1", 6379), + ], + ), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("primary3", 6379), + ("replica3_3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ], + ), + ]); + + let res1 = parse_and_count_slots(&view1, None, "foo").unwrap(); + let res2 = parse_and_count_slots(&view2, None, "foo").unwrap(); + assert_eq!(calculate_hash(&res1), calculate_hash(&res2)); + assert_eq!(res1.0, res2.0); + assert_eq!(res1.1.len(), res2.1.len()); + let check = res1 + .1 + .into_iter() + .zip(res2.1) + .all(|(first, second)| first.replicas() == second.replicas()); + assert!(check); + } + + #[test] + fn parse_slots_returns_slots_with_host_name_if_missing() { + let view = Value::Array(vec![slot_value(0, 4000, "", 6379)]); + + let (slot_count, slots) = parse_and_count_slots(&view, None, "node").unwrap(); + assert_eq!(slot_count, 4000); + assert_eq!(slots[0].master(), "node:6379"); + } + + #[test] + fn should_parse_and_hash_regardless_of_missing_host_name_and_replicas_order() { + let view1 = Value::Array(vec![ + slot_value(0, 4000, "", 6379), + slot_value(4001, 8000, "node2", 6380), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("node3", 6379), + ("replica3_1", 6379), + ("replica3_2", 6379), + ("replica3_3", 6379), + ], + ), + ]); + + let view2 = Value::Array(vec![ + slot_value(0, 4000, "node1", 6379), + slot_value(4001, 8000, "node2", 6380), + slot_value_with_replicas( + 8001, + 16383, + vec![ + ("", 6379), + ("replica3_3", 6379), + ("replica3_2", 6379), + ("replica3_1", 6379), + ], + ), + ]); + + let res1 = parse_and_count_slots(&view1, None, "node1").unwrap(); + let res2 = parse_and_count_slots(&view2, None, "node3").unwrap(); + + assert_eq!(calculate_hash(&res1), calculate_hash(&res2)); + assert_eq!(res1.0, res2.0); + assert_eq!(res1.1.len(), res2.1.len()); + let equality_check = + res1.1.iter().zip(&res2.1).all(|(first, second)| { + first.start() == second.start() && first.end() == second.end() + }); + assert!(equality_check); + let replicas_check = res1 + .1 + .iter() + .zip(res2.1) + .all(|(first, second)| first.replicas() == second.replicas()); + assert!(replicas_check); + } + + enum ViewType { + SingleNodeViewFullCoverage, + SingleNodeViewMissingSlots, + TwoNodesViewFullCoverage, + TwoNodesViewMissingSlots, + } + fn get_view(view_type: &ViewType) -> (&str, Value) { + match view_type { + ViewType::SingleNodeViewFullCoverage => ( + "first", + Value::Array(vec![slot_value(0, 16383, "node1", 6379)]), + ), + ViewType::SingleNodeViewMissingSlots => ( + "second", + Value::Array(vec![slot_value(0, 4000, "node1", 6379)]), + ), + ViewType::TwoNodesViewFullCoverage => ( + "third", + Value::Array(vec![ + slot_value(0, 4000, "node1", 6379), + slot_value(4001, 16383, "node2", 6380), + ]), + ), + ViewType::TwoNodesViewMissingSlots => ( + "fourth", + Value::Array(vec![ + slot_value(0, 3000, "node3", 6381), + slot_value(4001, 16383, "node4", 6382), + ]), + ), + } + } + + fn get_node_addr(name: &str, port: u16) -> SlotAddrs { + SlotAddrs::new(format!("{name}:{port}"), Vec::new()) + } + + #[test] + fn test_topology_calculator_4_nodes_queried_has_a_majority_success() { + // 4 nodes queried (1 error): Has a majority, single_node_view should be chosen + let queried_nodes: usize = 4; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::TwoNodesViewFullCoverage), + ]; + + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res: Vec<_> = topology_view.values().collect(); + let node_1 = get_node_addr("node1", 6379); + let expected: Vec<&SlotAddrs> = vec![&node_1]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_3_nodes_queried_no_majority_has_more_retries_raise_error() { + // 3 nodes queried: No majority, should return an error + let queried_nodes = 3; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewFullCoverage), + get_view(&ViewType::TwoNodesViewFullCoverage), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let topology_view = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ); + assert!(topology_view.is_err()); + } + + #[test] + fn test_topology_calculator_3_nodes_queried_no_majority_last_retry_success() { + // 3 nodes queried:: No majority, last retry, should get the view that has a full slot coverage + let queried_nodes = 3; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewMissingSlots), + get_view(&ViewType::TwoNodesViewFullCoverage), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 3, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res: Vec<_> = topology_view.values().collect(); + let node_1 = get_node_addr("node1", 6379); + let node_2 = get_node_addr("node2", 6380); + let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_2_nodes_queried_no_majority_return_full_slot_coverage_view() { + // 2 nodes queried: No majority, should get the view that has a full slot coverage + let queried_nodes = 2; + let topology_results = [ + get_view(&ViewType::TwoNodesViewFullCoverage), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res: Vec<_> = topology_view.values().collect(); + let node_1 = get_node_addr("node1", 6379); + let node_2 = get_node_addr("node2", 6380); + let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_2_nodes_queried_no_majority_no_full_coverage_prefer_fuller_coverage( + ) { + // 2 nodes queried: No majority, no full slot coverage, should return error + let queried_nodes = 2; + let topology_results = [ + get_view(&ViewType::SingleNodeViewMissingSlots), + get_view(&ViewType::TwoNodesViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res: Vec<_> = topology_view.values().collect(); + let node_1 = get_node_addr("node3", 6381); + let node_2 = get_node_addr("node4", 6382); + let expected: Vec<&SlotAddrs> = vec![&node_1, &node_2]; + assert_eq!(res, expected); + } + + #[test] + fn test_topology_calculator_3_nodes_queried_no_full_coverage_prefer_majority() { + // 2 nodes queried: No majority, no full slot coverage, should return error + let queried_nodes = 2; + let topology_results = vec![ + get_view(&ViewType::SingleNodeViewMissingSlots), + get_view(&ViewType::TwoNodesViewMissingSlots), + get_view(&ViewType::SingleNodeViewMissingSlots), + ]; + let (topology_view, _) = calculate_topology( + topology_results.iter().map(|(addr, value)| (*addr, value)), + 1, + None, + queried_nodes, + ReadFromReplicaStrategy::AlwaysFromPrimary, + ) + .unwrap(); + let res: Vec<_> = topology_view.values().collect(); + let node_1 = get_node_addr("node1", 6379); + let expected: Vec<&SlotAddrs> = vec![&node_1]; + assert_eq!(res, expected); + } +} diff --git a/glide-core/redis-rs/redis/src/cmd.rs b/glide-core/redis-rs/redis/src/cmd.rs new file mode 100644 index 0000000000..979bc7987b --- /dev/null +++ b/glide-core/redis-rs/redis/src/cmd.rs @@ -0,0 +1,663 @@ +#[cfg(feature = "aio")] +use futures_util::{ + future::BoxFuture, + task::{Context, Poll}, + Stream, StreamExt, +}; +#[cfg(feature = "aio")] +use std::pin::Pin; +use std::{fmt, io}; + +use crate::connection::ConnectionLike; +use crate::pipeline::Pipeline; +use crate::types::{from_owned_redis_value, FromRedisValue, RedisResult, RedisWrite, ToRedisArgs}; + +/// An argument to a redis command +#[derive(Clone)] +pub enum Arg { + /// A normal argument + Simple(D), + /// A cursor argument created from `cursor_arg()` + Cursor, +} + +/// Represents redis commands. +#[derive(Clone)] +pub struct Cmd { + data: Vec, + // Arg::Simple contains the offset that marks the end of the argument + args: Vec>, + cursor: Option, + // If it's true command's response won't be read from socket. Useful for Pub/Sub. + no_response: bool, +} + +/// Represents a redis iterator. +pub struct Iter<'a, T: FromRedisValue> { + batch: std::vec::IntoIter, + cursor: u64, + con: &'a mut (dyn ConnectionLike + 'a), + cmd: Cmd, +} + +impl<'a, T: FromRedisValue> Iterator for Iter<'a, T> { + type Item = T; + + #[inline] + fn next(&mut self) -> Option { + // we need to do this in a loop until we produce at least one item + // or we find the actual end of the iteration. This is necessary + // because with filtering an iterator it is possible that a whole + // chunk is not matching the pattern and thus yielding empty results. + loop { + if let Some(v) = self.batch.next() { + return Some(v); + }; + if self.cursor == 0 { + return None; + } + + let pcmd = self.cmd.get_packed_command_with_cursor(self.cursor)?; + let rv = self.con.req_packed_command(&pcmd).ok()?; + let (cur, batch): (u64, Vec) = from_owned_redis_value(rv).ok()?; + + self.cursor = cur; + self.batch = batch.into_iter(); + } + } +} + +#[cfg(feature = "aio")] +use crate::aio::ConnectionLike as AsyncConnection; + +/// The inner future of AsyncIter +#[cfg(feature = "aio")] +struct AsyncIterInner<'a, T: FromRedisValue + 'a> { + batch: std::vec::IntoIter, + con: &'a mut (dyn AsyncConnection + Send + 'a), + cmd: Cmd, +} + +/// Represents the state of AsyncIter +#[cfg(feature = "aio")] +enum IterOrFuture<'a, T: FromRedisValue + 'a> { + Iter(AsyncIterInner<'a, T>), + Future(BoxFuture<'a, (AsyncIterInner<'a, T>, Option)>), + Empty, +} + +/// Represents a redis iterator that can be used with async connections. +#[cfg(feature = "aio")] +pub struct AsyncIter<'a, T: FromRedisValue + 'a> { + inner: IterOrFuture<'a, T>, +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + 'a> AsyncIterInner<'a, T> { + #[inline] + pub async fn next_item(&mut self) -> Option { + // we need to do this in a loop until we produce at least one item + // or we find the actual end of the iteration. This is necessary + // because with filtering an iterator it is possible that a whole + // chunk is not matching the pattern and thus yielding empty results. + loop { + if let Some(v) = self.batch.next() { + return Some(v); + }; + if let Some(cursor) = self.cmd.cursor { + if cursor == 0 { + return None; + } + } else { + return None; + } + + let rv = self.con.req_packed_command(&self.cmd).await.ok()?; + let (cur, batch): (u64, Vec) = from_owned_redis_value(rv).ok()?; + + self.cmd.cursor = Some(cur); + self.batch = batch.into_iter(); + } + } +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + 'a + Unpin + Send> AsyncIter<'a, T> { + /// ```rust,no_run + /// # use redis::AsyncCommands; + /// # async fn scan_set() -> redis::RedisResult<()> { + /// # let client = redis::Client::open("redis://127.0.0.1/")?; + /// # let mut con = client.get_async_connection(None).await?; + /// con.sadd("my_set", 42i32).await?; + /// con.sadd("my_set", 43i32).await?; + /// let mut iter: redis::AsyncIter = con.sscan("my_set").await?; + /// while let Some(element) = iter.next_item().await { + /// assert!(element == 42 || element == 43); + /// } + /// # Ok(()) + /// # } + /// ``` + #[inline] + pub async fn next_item(&mut self) -> Option { + StreamExt::next(self).await + } +} + +#[cfg(feature = "aio")] +impl<'a, T: FromRedisValue + Unpin + Send + 'a> Stream for AsyncIter<'a, T> { + type Item = T; + + fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { + let this = self.get_mut(); + let inner = std::mem::replace(&mut this.inner, IterOrFuture::Empty); + match inner { + IterOrFuture::Iter(mut iter) => { + let fut = async move { + let next_item = iter.next_item().await; + (iter, next_item) + }; + this.inner = IterOrFuture::Future(Box::pin(fut)); + Pin::new(this).poll_next(cx) + } + IterOrFuture::Future(mut fut) => match fut.as_mut().poll(cx) { + Poll::Pending => { + this.inner = IterOrFuture::Future(fut); + Poll::Pending + } + Poll::Ready((iter, value)) => { + this.inner = IterOrFuture::Iter(iter); + Poll::Ready(value) + } + }, + IterOrFuture::Empty => unreachable!(), + } + } +} + +fn countdigits(mut v: usize) -> usize { + let mut result = 1; + loop { + if v < 10 { + return result; + } + if v < 100 { + return result + 1; + } + if v < 1000 { + return result + 2; + } + if v < 10000 { + return result + 3; + } + + v /= 10000; + result += 4; + } +} + +#[inline] +fn bulklen(len: usize) -> usize { + 1 + countdigits(len) + 2 + len + 2 +} + +fn args_len<'a, I>(args: I, cursor: u64) -> usize +where + I: IntoIterator> + ExactSizeIterator, +{ + let mut totlen = 1 + countdigits(args.len()) + 2; + for item in args { + totlen += bulklen(match item { + Arg::Cursor => countdigits(cursor as usize), + Arg::Simple(val) => val.len(), + }); + } + totlen +} + +pub(crate) fn cmd_len(cmd: &Cmd) -> usize { + args_len(cmd.args_iter(), cmd.cursor.unwrap_or(0)) +} + +fn encode_command<'a, I>(args: I, cursor: u64) -> Vec +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let mut cmd = Vec::new(); + write_command_to_vec(&mut cmd, args, cursor); + cmd +} + +fn write_command_to_vec<'a, I>(cmd: &mut Vec, args: I, cursor: u64) +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let totlen = args_len(args.clone(), cursor); + + cmd.reserve(totlen); + + write_command(cmd, args, cursor).unwrap() +} + +fn write_command<'a, I>(cmd: &mut (impl ?Sized + io::Write), args: I, cursor: u64) -> io::Result<()> +where + I: IntoIterator> + Clone + ExactSizeIterator, +{ + let mut buf = ::itoa::Buffer::new(); + + cmd.write_all(b"*")?; + let s = buf.format(args.len()); + cmd.write_all(s.as_bytes())?; + cmd.write_all(b"\r\n")?; + + let mut cursor_bytes = itoa::Buffer::new(); + for item in args { + let bytes = match item { + Arg::Cursor => cursor_bytes.format(cursor).as_bytes(), + Arg::Simple(val) => val, + }; + + cmd.write_all(b"$")?; + let s = buf.format(bytes.len()); + cmd.write_all(s.as_bytes())?; + cmd.write_all(b"\r\n")?; + + cmd.write_all(bytes)?; + cmd.write_all(b"\r\n")?; + } + Ok(()) +} + +impl RedisWrite for Cmd { + fn write_arg(&mut self, arg: &[u8]) { + self.data.extend_from_slice(arg); + self.args.push(Arg::Simple(self.data.len())); + } + + fn write_arg_fmt(&mut self, arg: impl fmt::Display) { + use std::io::Write; + write!(self.data, "{arg}").unwrap(); + self.args.push(Arg::Simple(self.data.len())); + } +} + +impl Default for Cmd { + fn default() -> Cmd { + Cmd::new() + } +} + +/// A command acts as a builder interface to creating encoded redis +/// requests. This allows you to easiy assemble a packed command +/// by chaining arguments together. +/// +/// Basic example: +/// +/// ```rust +/// redis::Cmd::new().arg("SET").arg("my_key").arg(42); +/// ``` +/// +/// There is also a helper function called `cmd` which makes it a +/// tiny bit shorter: +/// +/// ```rust +/// redis::cmd("SET").arg("my_key").arg(42); +/// ``` +/// +/// Because Rust currently does not have an ideal system +/// for lifetimes of temporaries, sometimes you need to hold on to +/// the initially generated command: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let mut cmd = redis::cmd("SMEMBERS"); +/// let mut iter : redis::Iter = cmd.arg("my_set").clone().iter(&mut con).unwrap(); +/// ``` +impl Cmd { + /// Creates a new empty command. + pub fn new() -> Cmd { + Cmd { + data: vec![], + args: vec![], + cursor: None, + no_response: false, + } + } + + /// Creates a new empty command, with at least the requested capcity. + pub fn with_capacity(arg_count: usize, size_of_data: usize) -> Cmd { + Cmd { + data: Vec::with_capacity(size_of_data), + args: Vec::with_capacity(arg_count), + cursor: None, + no_response: false, + } + } + + /// Get the capacities for the internal buffers. + #[cfg(test)] + #[allow(dead_code)] + pub(crate) fn capacity(&self) -> (usize, usize) { + (self.args.capacity(), self.data.capacity()) + } + + /// Appends an argument to the command. The argument passed must + /// be a type that implements `ToRedisArgs`. Most primitive types as + /// well as vectors of primitive types implement it. + /// + /// For instance all of the following are valid: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// redis::cmd("SET").arg(&["my_key", "my_value"]); + /// redis::cmd("SET").arg("my_key").arg(42); + /// redis::cmd("SET").arg("my_key").arg(b"my_value"); + /// ``` + #[inline] + pub fn arg(&mut self, arg: T) -> &mut Cmd { + arg.write_redis_args(self); + self + } + + /// Works similar to `arg` but adds a cursor argument. This is always + /// an integer and also flips the command implementation to support a + /// different mode for the iterators where the iterator will ask for + /// another batch of items when the local data is exhausted. + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let mut cmd = redis::cmd("SSCAN"); + /// let mut iter : redis::Iter = + /// cmd.arg("my_set").cursor_arg(0).clone().iter(&mut con).unwrap(); + /// for x in iter { + /// // do something with the item + /// } + /// ``` + #[inline] + pub fn cursor_arg(&mut self, cursor: u64) -> &mut Cmd { + assert!(!self.in_scan_mode()); + self.cursor = Some(cursor); + self.args.push(Arg::Cursor); + self + } + + /// Returns the packed command as a byte vector. + #[inline] + pub fn get_packed_command(&self) -> Vec { + let mut cmd = Vec::new(); + self.write_packed_command(&mut cmd); + cmd + } + + pub(crate) fn write_packed_command(&self, cmd: &mut Vec) { + write_command_to_vec(cmd, self.args_iter(), self.cursor.unwrap_or(0)) + } + + pub(crate) fn write_packed_command_preallocated(&self, cmd: &mut Vec) { + write_command(cmd, self.args_iter(), self.cursor.unwrap_or(0)).unwrap() + } + + /// Like `get_packed_command` but replaces the cursor with the + /// provided value. If the command is not in scan mode, `None` + /// is returned. + #[inline] + fn get_packed_command_with_cursor(&self, cursor: u64) -> Option> { + if !self.in_scan_mode() { + None + } else { + Some(encode_command(self.args_iter(), cursor)) + } + } + + /// Returns true if the command is in scan mode. + #[inline] + pub fn in_scan_mode(&self) -> bool { + self.cursor.is_some() + } + + /// Sends the command as query to the connection and converts the + /// result to the target redis value. This is the general way how + /// you can retrieve data. + #[inline] + pub fn query(&self, con: &mut dyn ConnectionLike) -> RedisResult { + match con.req_command(self) { + Ok(val) => from_owned_redis_value(val), + Err(e) => Err(e), + } + } + + /// Async version of `query`. + #[inline] + #[cfg(feature = "aio")] + pub async fn query_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let val = con.req_packed_command(self).await?; + from_owned_redis_value(val) + } + + /// Similar to `query()` but returns an iterator over the items of the + /// bulk result or iterator. In normal mode this is not in any way more + /// efficient than just querying into a `Vec` as it's internally + /// implemented as buffering into a vector. This however is useful when + /// `cursor_arg` was used in which case the iterator will query for more + /// items until the server side cursor is exhausted. + /// + /// This is useful for commands such as `SSCAN`, `SCAN` and others. + /// + /// One speciality of this function is that it will check if the response + /// looks like a cursor or not and always just looks at the payload. + /// This way you can use the function the same for responses in the + /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a + /// tuple of cursor and list). + #[inline] + pub fn iter(self, con: &mut dyn ConnectionLike) -> RedisResult> { + let rv = con.req_command(&self)?; + + let (cursor, batch) = if rv.looks_like_cursor() { + from_owned_redis_value::<(u64, Vec)>(rv)? + } else { + (0, from_owned_redis_value(rv)?) + }; + + Ok(Iter { + batch: batch.into_iter(), + cursor, + con, + cmd: self, + }) + } + + /// Similar to `iter()` but returns an AsyncIter over the items of the + /// bulk result or iterator. A [futures::Stream](https://docs.rs/futures/0.3.3/futures/stream/trait.Stream.html) + /// is implemented on AsyncIter. In normal mode this is not in any way more + /// efficient than just querying into a `Vec` as it's internally + /// implemented as buffering into a vector. This however is useful when + /// `cursor_arg` was used in which case the stream will query for more + /// items until the server side cursor is exhausted. + /// + /// This is useful for commands such as `SSCAN`, `SCAN` and others in async contexts. + /// + /// One speciality of this function is that it will check if the response + /// looks like a cursor or not and always just looks at the payload. + /// This way you can use the function the same for responses in the + /// format of `KEYS` (just a list) as well as `SSCAN` (which returns a + /// tuple of cursor and list). + #[cfg(feature = "aio")] + #[inline] + pub async fn iter_async<'a, T: FromRedisValue + 'a>( + mut self, + con: &'a mut (dyn AsyncConnection + Send), + ) -> RedisResult> { + let rv = con.req_packed_command(&self).await?; + + let (cursor, batch) = if rv.looks_like_cursor() { + from_owned_redis_value::<(u64, Vec)>(rv)? + } else { + (0, from_owned_redis_value(rv)?) + }; + if cursor == 0 { + self.cursor = None; + } else { + self.cursor = Some(cursor); + } + + Ok(AsyncIter { + inner: IterOrFuture::Iter(AsyncIterInner { + batch: batch.into_iter(), + con, + cmd: self, + }), + }) + } + + /// This is a shortcut to `query()` that does not return a value and + /// will fail the task if the query fails because of an error. This is + /// mainly useful in examples and for simple commands like setting + /// keys. + /// + /// This is equivalent to a call of query like this: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let _ : () = redis::cmd("PING").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn execute(&self, con: &mut dyn ConnectionLike) { + self.query::<()>(con).unwrap(); + } + + /// Returns an iterator over the arguments in this command (including the command name itself) + pub fn args_iter(&self) -> impl Clone + ExactSizeIterator> { + let mut prev = 0; + self.args.iter().map(move |arg| match *arg { + Arg::Simple(i) => { + let arg = Arg::Simple(&self.data[prev..i]); + prev = i; + arg + } + + Arg::Cursor => Arg::Cursor, + }) + } + + // Get a reference to the argument at `idx` + #[cfg(feature = "cluster")] + pub(crate) fn arg_idx(&self, idx: usize) -> Option<&[u8]> { + if idx >= self.args.len() { + return None; + } + + let start = if idx == 0 { + 0 + } else { + match self.args[idx - 1] { + Arg::Simple(n) => n, + _ => 0, + } + }; + let end = match self.args[idx] { + Arg::Simple(n) => n, + _ => 0, + }; + if start == 0 && end == 0 { + return None; + } + Some(&self.data[start..end]) + } + + /// Client won't read and wait for results. Currently only used for Pub/Sub commands in RESP3. + #[inline] + pub fn set_no_response(&mut self, nr: bool) -> &mut Cmd { + self.no_response = nr; + self + } + + /// Check whether command's result will be waited for. + #[inline] + pub fn is_no_response(&self) -> bool { + self.no_response + } +} + +impl fmt::Debug for Cmd { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + let res = self + .args_iter() + .map(|arg| { + let bytes = match arg { + Arg::Cursor => b"", + Arg::Simple(val) => val, + }; + std::str::from_utf8(bytes).unwrap_or_default() + }) + .collect::>(); + f.debug_struct("Cmd").field("args", &res).finish() + } +} + +/// Shortcut function to creating a command with a single argument. +/// +/// The first argument of a redis command is always the name of the command +/// which needs to be a string. This is the recommended way to start a +/// command pipe. +/// +/// ```rust +/// redis::cmd("PING"); +/// ``` +pub fn cmd(name: &str) -> Cmd { + let mut rv = Cmd::new(); + rv.arg(name); + rv +} + +/// Packs a bunch of commands into a request. This is generally a quite +/// useless function as this functionality is nicely wrapped through the +/// `Cmd` object, but in some cases it can be useful. The return value +/// of this can then be send to the low level `ConnectionLike` methods. +/// +/// Example: +/// +/// ```rust +/// # use redis::ToRedisArgs; +/// let mut args = vec![]; +/// args.extend("SET".to_redis_args()); +/// args.extend("my_key".to_redis_args()); +/// args.extend(42.to_redis_args()); +/// let cmd = redis::pack_command(&args); +/// assert_eq!(cmd, b"*3\r\n$3\r\nSET\r\n$6\r\nmy_key\r\n$2\r\n42\r\n".to_vec()); +/// ``` +pub fn pack_command(args: &[Vec]) -> Vec { + encode_command(args.iter().map(|x| Arg::Simple(&x[..])), 0) +} + +/// Shortcut for creating a new pipeline. +pub fn pipe() -> Pipeline { + Pipeline::new() +} + +#[cfg(test)] +#[cfg(feature = "cluster")] +mod tests { + use super::Cmd; + + #[test] + fn test_cmd_arg_idx() { + let mut c = Cmd::new(); + assert_eq!(c.arg_idx(0), None); + + c.arg("SET"); + assert_eq!(c.arg_idx(0), Some(&b"SET"[..])); + assert_eq!(c.arg_idx(1), None); + + c.arg("foo").arg("42"); + assert_eq!(c.arg_idx(1), Some(&b"foo"[..])); + assert_eq!(c.arg_idx(2), Some(&b"42"[..])); + assert_eq!(c.arg_idx(3), None); + assert_eq!(c.arg_idx(4), None); + } +} diff --git a/glide-core/redis-rs/redis/src/commands/cluster_scan.rs b/glide-core/redis-rs/redis/src/commands/cluster_scan.rs new file mode 100644 index 0000000000..97f10577ac --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/cluster_scan.rs @@ -0,0 +1,720 @@ +use crate::aio::ConnectionLike; +use crate::cluster_async::{ + ClusterConnInner, Connect, Core, InternalRoutingInfo, InternalSingleNodeRouting, RefreshPolicy, + Response, +}; +use crate::cluster_routing::SlotAddr; +use crate::cluster_topology::SLOT_SIZE; +use crate::{cmd, from_redis_value, Cmd, ErrorKind, RedisError, RedisResult, Value}; +use async_trait::async_trait; +use std::sync::Arc; +use strum_macros::Display; + +/// This module contains the implementation of scanning operations in a Redis cluster. +/// +/// The [`ClusterScanArgs`] struct represents the arguments for a cluster scan operation, +/// including the scan state reference, match pattern, count, and object type. +/// +/// The [[`ScanStateRC`]] struct is a wrapper for managing the state of a scan operation in a cluster. +/// It holds a reference to the scan state and provides methods for accessing the state. +/// +/// The [[`ClusterInScan`]] trait defines the methods for interacting with a Redis cluster during scanning, +/// including retrieving address information, refreshing slot mapping, and routing commands to specific address. +/// +/// The [[`ScanState`]] struct represents the state of a scan operation in a Redis cluster. +/// It holds information about the current scan state, including the cursor position, scanned slots map, +/// address being scanned, and address's epoch. + +const BITS_PER_U64: usize = u64::BITS as usize; +const NUM_OF_SLOTS: usize = SLOT_SIZE as usize; +const BITS_ARRAY_SIZE: usize = NUM_OF_SLOTS / BITS_PER_U64; +const END_OF_SCAN: u16 = NUM_OF_SLOTS as u16 + 1; +type SlotsBitsArray = [u64; BITS_ARRAY_SIZE]; + +#[derive(Clone)] +pub(crate) struct ClusterScanArgs { + pub(crate) scan_state_cursor: ScanStateRC, + match_pattern: Option>, + count: Option, + object_type: Option, +} + +#[derive(Debug, Clone, Display)] +/// Represents the type of an object in Redis. +pub enum ObjectType { + /// Represents a string object in Redis. + String, + /// Represents a list object in Redis. + List, + /// Represents a set object in Redis. + Set, + /// Represents a sorted set object in Redis. + ZSet, + /// Represents a hash object in Redis. + Hash, + /// Represents a stream object in Redis. + Stream, +} + +impl ClusterScanArgs { + pub(crate) fn new( + scan_state_cursor: ScanStateRC, + match_pattern: Option>, + count: Option, + object_type: Option, + ) -> Self { + Self { + scan_state_cursor, + match_pattern, + count, + object_type, + } + } +} + +#[derive(PartialEq, Debug, Clone, Default)] +pub enum ScanStateStage { + #[default] + Initiating, + InProgress, + Finished, +} + +#[derive(Debug, Clone, Default)] +/// A wrapper struct for managing the state of a scan operation in a cluster. +/// It holds a reference to the scan state and provides methods for accessing the state. +/// The `status` field indicates the status of the scan operation. +pub struct ScanStateRC { + scan_state_rc: Arc>, + status: ScanStateStage, +} + +impl ScanStateRC { + /// Creates a new instance of [`ScanStateRC`] from a given [`ScanState`]. + fn from_scan_state(scan_state: ScanState) -> Self { + Self { + scan_state_rc: Arc::new(Some(scan_state)), + status: ScanStateStage::InProgress, + } + } + + /// Creates a new instance of [`ScanStateRC`]. + /// + /// This method initializes the [`ScanStateRC`] with a reference to a [`ScanState`] that is initially set to `None`. + /// An empty ScanState is equivalent to a 0 cursor. + pub fn new() -> Self { + Self { + scan_state_rc: Arc::new(None), + status: ScanStateStage::Initiating, + } + } + /// create a new instance of [`ScanStateRC`] with finished state and empty scan state. + fn create_finished() -> Self { + Self { + scan_state_rc: Arc::new(None), + status: ScanStateStage::Finished, + } + } + /// Returns `true` if the scan state is finished. + pub fn is_finished(&self) -> bool { + self.status == ScanStateStage::Finished + } + + /// Returns a clone of the scan state, if it exist. + pub(crate) fn get_state_from_wrapper(&self) -> Option { + if self.status == ScanStateStage::Initiating || self.status == ScanStateStage::Finished { + None + } else { + self.scan_state_rc.as_ref().clone() + } + } +} + +/// This trait defines the methods for interacting with a Redis cluster during scanning. +#[async_trait] +pub(crate) trait ClusterInScan { + /// Retrieves the address associated with a given slot in the cluster. + async fn get_address_by_slot(&self, slot: u16) -> RedisResult; + + /// Retrieves the epoch of a given address in the cluster. + /// The epoch represents the version of the address, which is updated when a failover occurs or slots migrate in. + async fn get_address_epoch(&self, address: &str) -> Result; + + /// Retrieves the slots assigned to a given address in the cluster. + async fn get_slots_of_address(&self, address: &str) -> Vec; + + /// Routes a Redis command to a specific address in the cluster. + async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult; + + /// Check if all slots are covered by the cluster + async fn are_all_slots_covered(&self) -> bool; + + /// Check if the topology of the cluster has changed and refresh the slots if needed + async fn refresh_if_topology_changed(&self); +} + +/// Represents the state of a scan operation in a Redis cluster. +/// +/// This struct holds information about the current scan state, including the cursor position, +/// the scanned slots map, the address being scanned, and the address's epoch. +#[derive(PartialEq, Debug, Clone)] +pub(crate) struct ScanState { + // the real cursor in the scan operation + cursor: u64, + // a map of the slots that have been scanned + scanned_slots_map: SlotsBitsArray, + // the address that is being scanned currently, based on the next slot set to 0 in the scanned_slots_map, and the address that "owns" the slot + // in the SlotMap + pub(crate) address_in_scan: String, + // epoch represent the version of the address, when a failover happens or slots migrate in the epoch will be updated to +1 + address_epoch: u64, + // the status of the scan operation + scan_status: ScanStateStage, +} + +impl ScanState { + /// Create a new instance of ScanState. + /// + /// # Arguments + /// + /// * `cursor` - The cursor position. + /// * `scanned_slots_map` - The scanned slots map. + /// * `address_in_scan` - The address being scanned. + /// * `address_epoch` - The epoch of the address being scanned. + /// * `scan_status` - The status of the scan operation. + /// + /// # Returns + /// + /// A new instance of ScanState. + pub fn new( + cursor: u64, + scanned_slots_map: SlotsBitsArray, + address_in_scan: String, + address_epoch: u64, + scan_status: ScanStateStage, + ) -> Self { + Self { + cursor, + scanned_slots_map, + address_in_scan, + address_epoch, + scan_status, + } + } + + fn create_finished_state() -> Self { + Self { + cursor: 0, + scanned_slots_map: [0; BITS_ARRAY_SIZE], + address_in_scan: String::new(), + address_epoch: 0, + scan_status: ScanStateStage::Finished, + } + } + + /// Initialize a new scan operation. + /// This method creates a new scan state with the cursor set to 0, the scanned slots map initialized to 0, + /// and the address set to the address associated with slot 0. + /// The address epoch is set to the epoch of the address. + /// If the address epoch cannot be retrieved, the method returns an error. + async fn initiate_scan(connection: &C) -> RedisResult { + let new_scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE]; + let new_cursor = 0; + let address = connection.get_address_by_slot(0).await?; + let address_epoch = connection.get_address_epoch(&address).await.unwrap_or(0); + Ok(ScanState::new( + new_cursor, + new_scanned_slots_map, + address, + address_epoch, + ScanStateStage::InProgress, + )) + } + + /// Get the next slot to be scanned based on the scanned slots map. + /// If all slots have been scanned, the method returns [`END_OF_SCAN`]. + fn get_next_slot(&self, scanned_slots_map: &SlotsBitsArray) -> Option { + let all_slots_scanned = scanned_slots_map.iter().all(|&word| word == u64::MAX); + if all_slots_scanned { + return Some(END_OF_SCAN); + } + for (i, slot) in scanned_slots_map.iter().enumerate() { + let mut mask = 1; + for j in 0..BITS_PER_U64 { + if (slot & mask) == 0 { + return Some((i * BITS_PER_U64 + j) as u16); + } + mask <<= 1; + } + } + None + } + + /// Update the scan state without updating the scanned slots map. + /// This method is used when the address epoch has changed, and we can't determine which slots are new. + /// In this case, we skip updating the scanned slots map and only update the address and cursor. + async fn creating_state_without_slot_changes( + &self, + connection: &C, + ) -> RedisResult { + let next_slot = self.get_next_slot(&self.scanned_slots_map).unwrap_or(0); + let new_address = if next_slot == END_OF_SCAN { + return Ok(ScanState::create_finished_state()); + } else { + connection.get_address_by_slot(next_slot).await + }; + match new_address { + Ok(address) => { + let new_epoch = connection.get_address_epoch(&address).await.unwrap_or(0); + Ok(ScanState::new( + 0, + self.scanned_slots_map, + address, + new_epoch, + ScanStateStage::InProgress, + )) + } + Err(err) => Err(err), + } + } + + /// Update the scan state and get the next address to scan. + /// This method is called when the cursor reaches 0, indicating that the current address has been scanned. + /// This method updates the scan state based on the scanned slots map and retrieves the next address to scan. + /// If the address epoch has changed, the method skips updating the scanned slots map and only updates the address and cursor. + /// If the address epoch has not changed, the method updates the scanned slots map with the slots owned by the address. + /// The method returns the new scan state with the updated cursor, scanned slots map, address, and epoch. + async fn create_updated_scan_state_for_completed_address( + &mut self, + connection: &C, + ) -> RedisResult { + let _ = connection.refresh_if_topology_changed().await; + let mut scanned_slots_map = self.scanned_slots_map; + // If the address epoch changed it mean that some slots in the address are new, so we cant know which slots been there from the beginning and which are new, or out and in later. + // In this case we will skip updating the scanned_slots_map and will just update the address and the cursor + let new_address_epoch = connection + .get_address_epoch(&self.address_in_scan) + .await + .unwrap_or(0); + if new_address_epoch != self.address_epoch { + return self.creating_state_without_slot_changes(connection).await; + } + // If epoch wasn't changed, the slots owned by the address after the refresh are all valid as slots that been scanned + // So we will update the scanned_slots_map with the slots owned by the address + let slots_scanned = connection.get_slots_of_address(&self.address_in_scan).await; + for slot in slots_scanned { + let slot_index = slot as usize / BITS_PER_U64; + let slot_bit = slot as usize % BITS_PER_U64; + scanned_slots_map[slot_index] |= 1 << slot_bit; + } + // Get the next address to scan and its param base on the next slot set to 0 in the scanned_slots_map + let next_slot = self.get_next_slot(&scanned_slots_map).unwrap_or(0); + let new_address = if next_slot == END_OF_SCAN { + return Ok(ScanState::create_finished_state()); + } else { + connection.get_address_by_slot(next_slot).await + }; + match new_address { + Ok(new_address) => { + let new_epoch = connection + .get_address_epoch(&new_address) + .await + .unwrap_or(0); + let new_cursor = 0; + Ok(ScanState::new( + new_cursor, + scanned_slots_map, + new_address, + new_epoch, + ScanStateStage::InProgress, + )) + } + Err(err) => Err(err), + } + } +} + +// Implement the [`ClusterInScan`] trait for [`InnerCore`] of async cluster connection. +#[async_trait] +impl ClusterInScan for Core +where + C: ConnectionLike + Connect + Clone + Send + Sync + 'static, +{ + async fn get_address_by_slot(&self, slot: u16) -> RedisResult { + let address = self + .get_address_from_slot(slot, SlotAddr::ReplicaRequired) + .await; + match address { + Some(addr) => Ok(addr), + None => { + if self.are_all_slots_covered().await { + Err(RedisError::from(( + ErrorKind::IoError, + "Failed to get connection to the node cover the slot, please check the cluster configuration ", + ))) + } else { + Err(RedisError::from(( + ErrorKind::NotAllSlotsCovered, + "All slots are not covered by the cluster, please check the cluster configuration ", + ))) + } + } + } + } + + async fn get_address_epoch(&self, address: &str) -> Result { + self.as_ref().get_address_epoch(address).await + } + async fn get_slots_of_address(&self, address: &str) -> Vec { + self.as_ref().get_slots_of_address(address).await + } + async fn route_command(&self, cmd: Cmd, address: &str) -> RedisResult { + let routing = InternalRoutingInfo::SingleNode(InternalSingleNodeRouting::ByAddress( + address.to_string(), + )); + let core = self.to_owned(); + let response = ClusterConnInner::::try_cmd_request(Arc::new(cmd), routing, core) + .await + .map_err(|err| err.1)?; + match response { + Response::Single(value) => Ok(value), + _ => Err(RedisError::from(( + ErrorKind::ClientError, + "Expected single response, got unexpected response", + ))), + } + } + async fn are_all_slots_covered(&self) -> bool { + ClusterConnInner::::check_if_all_slots_covered(&self.conn_lock.read().await.slot_map) + } + async fn refresh_if_topology_changed(&self) { + ClusterConnInner::check_topology_and_refresh_if_diff( + self.to_owned(), + // The cluster SCAN implementation must refresh the slots when a topology change is found + // to ensure the scan logic is correct. + &RefreshPolicy::NotThrottable, + ) + .await; + } +} + +/// Perform a cluster scan operation. +/// This function performs a scan operation in a Redis cluster using the given [`ClusterInScan`] connection. +/// It scans the cluster for keys based on the given `ClusterScanArgs` arguments. +/// The function returns a tuple containing the new scan state cursor and the keys found in the scan operation. +/// If the scan operation fails, an error is returned. +/// +/// # Arguments +/// * `core` - The connection to the Redis cluster. +/// * `cluster_scan_args` - The arguments for the cluster scan operation. +/// +/// # Returns +/// A tuple containing the new scan state cursor and the keys found in the scan operation. +/// If the scan operation fails, an error is returned. +pub(crate) async fn cluster_scan( + core: C, + cluster_scan_args: ClusterScanArgs, +) -> RedisResult<(ScanStateRC, Vec)> +where + C: ClusterInScan, +{ + let ClusterScanArgs { + scan_state_cursor, + match_pattern, + count, + object_type, + } = cluster_scan_args; + // If scan_state is None, meaning we start a new scan + let scan_state = match scan_state_cursor.get_state_from_wrapper() { + Some(state) => state, + None => match ScanState::initiate_scan(&core).await { + Ok(state) => state, + Err(err) => { + return Err(err); + } + }, + }; + // Send the actual scan command to the address in the scan_state + let scan_result = send_scan( + &scan_state, + &core, + match_pattern.clone(), + count, + object_type.clone(), + ) + .await; + let ((new_cursor, new_keys), mut scan_state): ((u64, Vec), ScanState) = match scan_result + { + Ok(scan_result) => (from_redis_value(&scan_result)?, scan_state.clone()), + Err(err) => match err.kind() { + // If the scan command failed to route to the address because the address is not found in the cluster or + // the connection to the address cant be reached from different reasons, we will check we want to check if + // the problem is problem that we can recover from like failover or scale down or some network issue + // that we can retry the scan command to an address that own the next slot we are at. + ErrorKind::IoError + | ErrorKind::AllConnectionsUnavailable + | ErrorKind::ConnectionNotFoundForRoute => { + let retry = + retry_scan(&scan_state, &core, match_pattern, count, object_type).await?; + (from_redis_value(&retry.0?)?, retry.1) + } + _ => return Err(err), + }, + }; + + // If the cursor is 0, meaning we finished scanning the address + // we will update the scan state to get the next address to scan + if new_cursor == 0 { + scan_state = scan_state + .create_updated_scan_state_for_completed_address(&core) + .await?; + } + + // If the address is empty, meaning we finished scanning all the address + if scan_state.scan_status == ScanStateStage::Finished { + return Ok((ScanStateRC::create_finished(), new_keys)); + } + + scan_state = ScanState::new( + new_cursor, + scan_state.scanned_slots_map, + scan_state.address_in_scan, + scan_state.address_epoch, + ScanStateStage::InProgress, + ); + Ok((ScanStateRC::from_scan_state(scan_state), new_keys)) +} + +// Send the scan command to the address in the scan_state +async fn send_scan( + scan_state: &ScanState, + core: &C, + match_pattern: Option>, + count: Option, + object_type: Option, +) -> RedisResult +where + C: ClusterInScan, +{ + let mut scan_command = cmd("SCAN"); + scan_command.arg(scan_state.cursor); + if let Some(match_pattern) = match_pattern { + scan_command.arg("MATCH").arg(match_pattern); + } + if let Some(count) = count { + scan_command.arg("COUNT").arg(count); + } + if let Some(object_type) = object_type { + scan_command.arg("TYPE").arg(object_type.to_string()); + } + + core.route_command(scan_command, &scan_state.address_in_scan) + .await +} + +// If the scan command failed to route to the address we will check we will first refresh the slots, we will check if all slots are covered by cluster, +// and if so we will try to get a new address to scan for handling case of failover. +// if all slots are not covered by the cluster we will return an error indicating that the cluster is not well configured. +// if all slots are covered by cluster but we failed to get a new address to scan we will return an error indicating that we failed to get a new address to scan. +// if we got a new address to scan but the scan command failed to route to the address we will return an error indicating that we failed to route the command. +async fn retry_scan( + scan_state: &ScanState, + core: &C, + match_pattern: Option>, + count: Option, + object_type: Option, +) -> RedisResult<(RedisResult, ScanState)> +where + C: ClusterInScan, +{ + // TODO: This mechanism of refreshing on failure to route to address should be part of the routing mechanism + // After the routing mechanism is updated to handle this case, this refresh in the case bellow should be removed + core.refresh_if_topology_changed().await; + if !core.are_all_slots_covered().await { + return Err(RedisError::from(( + ErrorKind::NotAllSlotsCovered, + "Not all slots are covered by the cluster, please check the cluster configuration", + ))); + } + // If for some reason we failed to reach the address we don't know if its a scale down or a failover. + // Since it might be scale down we cant just keep going with the current state we the same cursor as we are at + // the same point in the new address, so we need to get the new address own the next slot that haven't been scanned + // and start from the beginning of this address. + let next_slot = scan_state + .get_next_slot(&scan_state.scanned_slots_map) + .unwrap_or(0); + let address = core.get_address_by_slot(next_slot).await?; + + let new_epoch = core.get_address_epoch(&address).await.unwrap_or(0); + let scan_state = &ScanState::new( + 0, + scan_state.scanned_slots_map, + address, + new_epoch, + ScanStateStage::InProgress, + ); + let res = ( + send_scan(scan_state, core, match_pattern, count, object_type).await, + scan_state.clone(), + ); + Ok(res) +} + +#[cfg(test)] +mod tests { + + use super::*; + + #[test] + fn test_creation_of_empty_scan_wrapper() { + let scan_state_wrapper = ScanStateRC::new(); + assert!(scan_state_wrapper.status == ScanStateStage::Initiating); + } + + #[test] + fn test_creation_of_scan_state_wrapper_from() { + let scan_state = ScanState { + cursor: 0, + scanned_slots_map: [0; BITS_ARRAY_SIZE], + address_in_scan: String::from("address1"), + address_epoch: 1, + scan_status: ScanStateStage::InProgress, + }; + + let scan_state_wrapper = ScanStateRC::from_scan_state(scan_state); + assert!(!scan_state_wrapper.is_finished()); + } + + #[test] + // Test the get_next_slot method + fn test_scan_state_get_next_slot() { + let scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE]; + let scan_state = ScanState { + cursor: 0, + scanned_slots_map, + address_in_scan: String::from("address1"), + address_epoch: 1, + scan_status: ScanStateStage::InProgress, + }; + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(0)); + // Set the first slot to 1 + let mut scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE]; + scanned_slots_map[0] = 1; + let scan_state = ScanState { + cursor: 0, + scanned_slots_map, + address_in_scan: String::from("address1"), + address_epoch: 1, + scan_status: ScanStateStage::InProgress, + }; + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(1)); + } + // Create a mock connection + struct MockConnection; + #[async_trait] + impl ClusterInScan for MockConnection { + async fn refresh_if_topology_changed(&self) {} + async fn get_address_by_slot(&self, _slot: u16) -> RedisResult { + Ok("mock_address".to_string()) + } + async fn get_address_epoch(&self, _address: &str) -> Result { + Ok(0) + } + async fn get_slots_of_address(&self, address: &str) -> Vec { + if address == "mock_address" { + vec![3, 4, 5] + } else { + vec![0, 1, 2] + } + } + async fn route_command(&self, _: Cmd, _: &str) -> RedisResult { + unimplemented!() + } + async fn are_all_slots_covered(&self) -> bool { + true + } + } + // Test the initiate_scan function + #[tokio::test] + async fn test_initiate_scan() { + let connection = MockConnection; + let scan_state = ScanState::initiate_scan(&connection).await.unwrap(); + + // Assert that the scan state is initialized correctly + assert_eq!(scan_state.cursor, 0); + assert_eq!(scan_state.scanned_slots_map, [0; BITS_ARRAY_SIZE]); + assert_eq!(scan_state.address_in_scan, "mock_address"); + assert_eq!(scan_state.address_epoch, 0); + } + + // Test the get_next_slot function + #[test] + fn test_get_next_slot() { + let scan_state = ScanState { + cursor: 0, + scanned_slots_map: [0; BITS_ARRAY_SIZE], + address_in_scan: "".to_string(), + address_epoch: 0, + scan_status: ScanStateStage::InProgress, + }; + // Test when all first bits of each u6 are set to 1, the next slots should be 1 + let scanned_slots_map: SlotsBitsArray = [1; BITS_ARRAY_SIZE]; + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(1)); + + // Test when all slots are scanned, the next slot should be 0 + let scanned_slots_map: SlotsBitsArray = [u64::MAX; BITS_ARRAY_SIZE]; + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(16385)); + + // Test when first, second, fourth, sixth and eighth slots scanned, the next slot should be 2 + let mut scanned_slots_map: SlotsBitsArray = [0; BITS_ARRAY_SIZE]; + scanned_slots_map[0] = 171; // 10101011 + let next_slot = scan_state.get_next_slot(&scanned_slots_map); + assert_eq!(next_slot, Some(2)); + } + + // Test the update_scan_state_and_get_next_address function + #[tokio::test] + async fn test_update_scan_state_and_get_next_address() { + let connection = MockConnection; + let scan_state = ScanState::initiate_scan(&connection).await; + let updated_scan_state = scan_state + .unwrap() + .create_updated_scan_state_for_completed_address(&connection) + .await + .unwrap(); + + // cursor should be reset to 0 + assert_eq!(updated_scan_state.cursor, 0); + + // address_in_scan should be updated to the new address + assert_eq!(updated_scan_state.address_in_scan, "mock_address"); + + // address_epoch should be updated to the new address epoch + assert_eq!(updated_scan_state.address_epoch, 0); + } + + #[tokio::test] + async fn test_update_scan_state_without_updating_scanned_map() { + let connection = MockConnection; + let scan_state = ScanState::new( + 0, + [0; BITS_ARRAY_SIZE], + "address".to_string(), + 0, + ScanStateStage::InProgress, + ); + let scanned_slots_map = scan_state.scanned_slots_map; + let updated_scan_state = scan_state + .creating_state_without_slot_changes(&connection) + .await + .unwrap(); + assert_eq!(updated_scan_state.scanned_slots_map, scanned_slots_map); + assert_eq!(updated_scan_state.cursor, 0); + assert_eq!(updated_scan_state.address_in_scan, "mock_address"); + assert_eq!(updated_scan_state.address_epoch, 0); + } +} diff --git a/glide-core/redis-rs/redis/src/commands/json.rs b/glide-core/redis-rs/redis/src/commands/json.rs new file mode 100644 index 0000000000..d63f70c86f --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/json.rs @@ -0,0 +1,390 @@ +use crate::cmd::{cmd, Cmd}; +use crate::connection::ConnectionLike; +use crate::pipeline::Pipeline; +use crate::types::{FromRedisValue, RedisResult, ToRedisArgs}; +use crate::RedisError; + +#[cfg(feature = "cluster")] +use crate::commands::ClusterPipeline; + +use serde::ser::Serialize; + +macro_rules! implement_json_commands { + ( + $lifetime: lifetime + $( + $(#[$attr:meta])+ + fn $name:ident<$($tyargs:ident : $ty:ident),*>( + $($argname:ident: $argty:ty),*) $body:block + )* + ) => ( + + /// Implements RedisJSON commands for connection like objects. This + /// allows you to send commands straight to a connection or client. It + /// is also implemented for redis results of clients which makes for + /// very convenient access in some basic cases. + /// + /// This allows you to use nicer syntax for some common operations. + /// For instance this code: + /// + /// ```rust,no_run + /// use redis::JsonCommands; + /// use serde_json::json; + /// # fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// redis::cmd("JSON.SET").arg("my_key").arg("$").arg(&json!({"item": 42i32}).to_string()).execute(&mut con); + /// assert_eq!(redis::cmd("JSON.GET").arg("my_key").arg("$").query(&mut con), Ok(String::from(r#"[{"item":42}]"#))); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// use redis::JsonCommands; + /// use serde_json::json; + /// # fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// con.json_set("my_key", "$", &json!({"item": 42i32}).to_string())?; + /// assert_eq!(con.json_get("my_key", "$"), Ok(String::from(r#"[{"item":42}]"#))); + /// assert_eq!(con.json_get("my_key", "$.item"), Ok(String::from(r#"[42]"#))); + /// # Ok(()) } + /// ``` + /// + /// With RedisJSON commands, you have to note that all results will be wrapped + /// in square brackets (or empty brackets if not found). If you want to deserialize it + /// with e.g. `serde_json` you have to use `Vec` for your output type instead of `T`. + pub trait JsonCommands : ConnectionLike + Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty, )* RV: FromRedisValue>( + &mut self $(, $argname: $argty)*) -> RedisResult + { Cmd::$name($($argname),*)?.query(self) } + )* + } + + impl Cmd { + $( + $(#[$attr])* + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>($($argname: $argty),*) -> RedisResult { + $body + } + )* + } + + /// Implements RedisJSON commands over asynchronous connections. This + /// allows you to send commands straight to a connection or client. + /// + /// This allows you to use nicer syntax for some common operations. + /// For instance this code: + /// + /// ```rust,no_run + /// use redis::JsonAsyncCommands; + /// use serde_json::json; + /// # async fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// redis::cmd("JSON.SET").arg("my_key").arg("$").arg(&json!({"item": 42i32}).to_string()).query_async(&mut con).await?; + /// assert_eq!(redis::cmd("JSON.GET").arg("my_key").arg("$").query_async(&mut con).await, Ok(String::from(r#"[{"item":42}]"#))); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// use redis::JsonAsyncCommands; + /// use serde_json::json; + /// # async fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// con.json_set("my_key", "$", &json!({"item": 42i32}).to_string()).await?; + /// assert_eq!(con.json_get("my_key", "$").await, Ok(String::from(r#"[{"item":42}]"#))); + /// assert_eq!(con.json_get("my_key", "$.item").await, Ok(String::from(r#"[42]"#))); + /// # Ok(()) } + /// ``` + /// + /// With RedisJSON commands, you have to note that all results will be wrapped + /// in square brackets (or empty brackets if not found). If you want to deserialize it + /// with e.g. `serde_json` you have to use `Vec` for your output type instead of `T`. + /// + #[cfg(feature = "aio")] + pub trait JsonAsyncCommands : crate::aio::ConnectionLike + Send + Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty + Send + Sync + $lifetime,)* RV>( + & $lifetime mut self + $(, $argname: $argty)* + ) -> $crate::types::RedisFuture<'a, RV> + where + RV: FromRedisValue, + { + Box::pin(async move { + $body?.query_async(self).await + }) + } + )* + } + + /// Implements RedisJSON commands for pipelines. Unlike the regular + /// commands trait, this returns the pipeline rather than a result + /// directly. Other than that it works the same however. + impl Pipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> RedisResult<&mut Self> { + self.add_command($body?); + Ok(self) + } + )* + } + + /// Implements RedisJSON commands for cluster pipelines. Unlike the regular + /// commands trait, this returns the cluster pipeline rather than a result + /// directly. Other than that it works the same however. + #[cfg(feature = "cluster")] + impl ClusterPipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> RedisResult<&mut Self> { + self.add_command($body?); + Ok(self) + } + )* + } + + ) +} + +implement_json_commands! { + 'a + + /// Append the JSON `value` to the array at `path` after the last element in it. + fn json_arr_append(key: K, path: P, value: &'a V) { + let mut cmd = cmd("JSON.ARRAPPEND"); + + cmd.arg(key) + .arg(path) + .arg(serde_json::to_string(value)?); + + Ok::<_, RedisError>(cmd) + } + + /// Index array at `path`, returns first occurance of `value` + fn json_arr_index(key: K, path: P, value: &'a V) { + let mut cmd = cmd("JSON.ARRINDEX"); + + cmd.arg(key) + .arg(path) + .arg(serde_json::to_string(value)?); + + Ok::<_, RedisError>(cmd) + } + + /// Same as `json_arr_index` except takes a `start` and a `stop` value, setting these to `0` will mean + /// they make no effect on the query + /// + /// The default values for `start` and `stop` are `0`, so pass those in if you want them to take no effect + fn json_arr_index_ss(key: K, path: P, value: &'a V, start: &'a isize, stop: &'a isize) { + let mut cmd = cmd("JSON.ARRINDEX"); + + cmd.arg(key) + .arg(path) + .arg(serde_json::to_string(value)?) + .arg(start) + .arg(stop); + + Ok::<_, RedisError>(cmd) + } + + /// Inserts the JSON `value` in the array at `path` before the `index` (shifts to the right). + /// + /// `index` must be withing the array's range. + fn json_arr_insert(key: K, path: P, index: i64, value: &'a V) { + let mut cmd = cmd("JSON.ARRINSERT"); + + cmd.arg(key) + .arg(path) + .arg(index) + .arg(serde_json::to_string(value)?); + + Ok::<_, RedisError>(cmd) + + } + + /// Reports the length of the JSON Array at `path` in `key`. + fn json_arr_len(key: K, path: P) { + let mut cmd = cmd("JSON.ARRLEN"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Removes and returns an element from the `index` in the array. + /// + /// `index` defaults to `-1` (the end of the array). + fn json_arr_pop(key: K, path: P, index: i64) { + let mut cmd = cmd("JSON.ARRPOP"); + + cmd.arg(key) + .arg(path) + .arg(index); + + Ok::<_, RedisError>(cmd) + } + + /// Trims an array so that it contains only the specified inclusive range of elements. + /// + /// This command is extremely forgiving and using it with out-of-range indexes will not produce an error. + /// There are a few differences between how RedisJSON v2.0 and legacy versions handle out-of-range indexes. + fn json_arr_trim(key: K, path: P, start: i64, stop: i64) { + let mut cmd = cmd("JSON.ARRTRIM"); + + cmd.arg(key) + .arg(path) + .arg(start) + .arg(stop); + + Ok::<_, RedisError>(cmd) + } + + /// Clears container values (Arrays/Objects), and sets numeric values to 0. + fn json_clear(key: K, path: P) { + let mut cmd = cmd("JSON.CLEAR"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Deletes a value at `path`. + fn json_del(key: K, path: P) { + let mut cmd = cmd("JSON.DEL"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Gets JSON Value(s) at `path`. + /// + /// Runs `JSON.GET` if key is singular, `JSON.MGET` if there are multiple keys. + /// + /// With RedisJSON commands, you have to note that all results will be wrapped + /// in square brackets (or empty brackets if not found). If you want to deserialize it + /// with e.g. `serde_json` you have to use `Vec` for your output type instead of `T`. + fn json_get(key: K, path: P) { + let mut cmd = cmd(if key.is_single_arg() { "JSON.GET" } else { "JSON.MGET" }); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Increments the number value stored at `path` by `number`. + fn json_num_incr_by(key: K, path: P, value: i64) { + let mut cmd = cmd("JSON.NUMINCRBY"); + + cmd.arg(key) + .arg(path) + .arg(value); + + Ok::<_, RedisError>(cmd) + } + + /// Returns the keys in the object that's referenced by `path`. + fn json_obj_keys(key: K, path: P) { + let mut cmd = cmd("JSON.OBJKEYS"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the number of keys in the JSON Object at `path` in `key`. + fn json_obj_len(key: K, path: P) { + let mut cmd = cmd("JSON.OBJLEN"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Sets the JSON Value at `path` in `key`. + fn json_set(key: K, path: P, value: &'a V) { + let mut cmd = cmd("JSON.SET"); + + cmd.arg(key) + .arg(path) + .arg(serde_json::to_string(value)?); + + Ok::<_, RedisError>(cmd) + } + + /// Appends the `json-string` values to the string at `path`. + fn json_str_append(key: K, path: P, value: V) { + let mut cmd = cmd("JSON.STRAPPEND"); + + cmd.arg(key) + .arg(path) + .arg(value); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the length of the JSON String at `path` in `key`. + fn json_str_len(key: K, path: P) { + let mut cmd = cmd("JSON.STRLEN"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Toggle a `boolean` value stored at `path`. + fn json_toggle(key: K, path: P) { + let mut cmd = cmd("JSON.TOGGLE"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } + + /// Reports the type of JSON value at `path`. + fn json_type(key: K, path: P) { + let mut cmd = cmd("JSON.TYPE"); + + cmd.arg(key) + .arg(path); + + Ok::<_, RedisError>(cmd) + } +} + +impl JsonCommands for T where T: ConnectionLike {} + +#[cfg(feature = "aio")] +impl JsonAsyncCommands for T where T: crate::aio::ConnectionLike + Send + Sized {} diff --git a/glide-core/redis-rs/redis/src/commands/macros.rs b/glide-core/redis-rs/redis/src/commands/macros.rs new file mode 100644 index 0000000000..9e7d4373c0 --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/macros.rs @@ -0,0 +1,275 @@ +macro_rules! implement_commands { + ( + $lifetime: lifetime + $( + $(#[$attr:meta])+ + fn $name:ident<$($tyargs:ident : $ty:ident),*>( + $($argname:ident: $argty:ty),*) $body:block + )* + ) => + ( + /// Implements common redis commands for connection like objects. This + /// allows you to send commands straight to a connection or client. It + /// is also implemented for redis results of clients which makes for + /// very convenient access in some basic cases. + /// + /// This allows you to use nicer syntax for some common operations. + /// For instance this code: + /// + /// ```rust,no_run + /// # fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// redis::cmd("SET").arg("my_key").arg(42).execute(&mut con); + /// assert_eq!(redis::cmd("GET").arg("my_key").query(&mut con), Ok(42)); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// # fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_connection(None)?; + /// con.set("my_key", 42)?; + /// assert_eq!(con.get("my_key"), Ok(42)); + /// # Ok(()) } + /// ``` + pub trait Commands : ConnectionLike+Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty, )* RV: FromRedisValue>( + &mut self $(, $argname: $argty)*) -> RedisResult + { Cmd::$name($($argname),*).query(self) } + )* + + /// Incrementally iterate the keys space. + #[inline] + fn scan(&mut self) -> RedisResult> { + let mut c = cmd("SCAN"); + c.cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate the keys space for keys matching a pattern. + #[inline] + fn scan_match(&mut self, pattern: P) -> RedisResult> { + let mut c = cmd("SCAN"); + c.cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate hash fields and associated values. + #[inline] + fn hscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate hash fields and associated values for + /// field names matching a pattern. + #[inline] + fn hscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate set elements. + #[inline] + fn sscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn sscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + + /// Incrementally iterate sorted set elements. + #[inline] + fn zscan(&mut self, key: K) -> RedisResult> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0); + c.iter(self) + } + + /// Incrementally iterate sorted set elements for elements matching a pattern. + #[inline] + fn zscan_match + (&mut self, key: K, pattern: P) -> RedisResult> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + c.iter(self) + } + } + + impl Cmd { + $( + $(#[$attr])* + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>($($argname: $argty),*) -> Self { + ::std::mem::replace($body, Cmd::new()) + } + )* + } + + /// Implements common redis commands over asynchronous connections. This + /// allows you to send commands straight to a connection or client. + /// + /// This allows you to use nicer syntax for some common operations. + /// For instance this code: + /// + /// ```rust,no_run + /// use redis::AsyncCommands; + /// # async fn do_something() -> redis::RedisResult<()> { + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// redis::cmd("SET").arg("my_key").arg(42i32).query_async(&mut con).await?; + /// assert_eq!(redis::cmd("GET").arg("my_key").query_async(&mut con).await, Ok(42i32)); + /// # Ok(()) } + /// ``` + /// + /// Will become this: + /// + /// ```rust,no_run + /// use redis::AsyncCommands; + /// # async fn do_something() -> redis::RedisResult<()> { + /// use redis::Commands; + /// let client = redis::Client::open("redis://127.0.0.1/")?; + /// let mut con = client.get_async_connection(None).await?; + /// con.set("my_key", 42i32).await?; + /// assert_eq!(con.get("my_key").await, Ok(42i32)); + /// # Ok(()) } + /// ``` + #[cfg(feature = "aio")] + pub trait AsyncCommands : crate::aio::ConnectionLike + Send + Sized { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + fn $name<$lifetime, $($tyargs: $ty + Send + Sync + $lifetime,)* RV>( + & $lifetime mut self + $(, $argname: $argty)* + ) -> crate::types::RedisFuture<'a, RV> + where + RV: FromRedisValue, + { + Box::pin(async move { ($body).query_async(self).await }) + } + )* + + /// Incrementally iterate the keys space. + #[inline] + fn scan(&mut self) -> crate::types::RedisFuture> { + let mut c = cmd("SCAN"); + c.cursor_arg(0); + Box::pin(async move { c.iter_async(self).await }) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn scan_match(&mut self, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("SCAN"); + c.cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move { c.iter_async(self).await }) + } + + /// Incrementally iterate hash fields and associated values. + #[inline] + fn hscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate hash fields and associated values for + /// field names matching a pattern. + #[inline] + fn hscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("HSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate set elements. + #[inline] + fn sscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate set elements for elements matching a pattern. + #[inline] + fn sscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("SSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate sorted set elements. + #[inline] + fn zscan(&mut self, key: K) -> crate::types::RedisFuture> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0); + Box::pin(async move {c.iter_async(self).await }) + } + + /// Incrementally iterate sorted set elements for elements matching a pattern. + #[inline] + fn zscan_match + (&mut self, key: K, pattern: P) -> crate::types::RedisFuture> { + let mut c = cmd("ZSCAN"); + c.arg(key).cursor_arg(0).arg("MATCH").arg(pattern); + Box::pin(async move {c.iter_async(self).await }) + } + } + + /// Implements common redis commands for pipelines. Unlike the regular + /// commands trait, this returns the pipeline rather than a result + /// directly. Other than that it works the same however. + impl Pipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> &mut Self { + self.add_command(::std::mem::replace($body, Cmd::new())) + } + )* + } + + // Implements common redis commands for cluster pipelines. Unlike the regular + // commands trait, this returns the cluster pipeline rather than a result + // directly. Other than that it works the same however. + #[cfg(feature = "cluster")] + impl ClusterPipeline { + $( + $(#[$attr])* + #[inline] + #[allow(clippy::extra_unused_lifetimes, clippy::needless_lifetimes)] + pub fn $name<$lifetime, $($tyargs: $ty),*>( + &mut self $(, $argname: $argty)* + ) -> &mut Self { + self.add_command(::std::mem::replace($body, Cmd::new())) + } + )* + } + ) +} diff --git a/glide-core/redis-rs/redis/src/commands/mod.rs b/glide-core/redis-rs/redis/src/commands/mod.rs new file mode 100644 index 0000000000..d5c937fa70 --- /dev/null +++ b/glide-core/redis-rs/redis/src/commands/mod.rs @@ -0,0 +1,2190 @@ +use crate::cmd::{cmd, Cmd, Iter}; +use crate::connection::{Connection, ConnectionLike, Msg}; +use crate::pipeline::Pipeline; +use crate::types::{ + ExistenceCheck, Expiry, FromRedisValue, NumericBehavior, RedisResult, RedisWrite, SetExpiry, + ToRedisArgs, +}; + +#[macro_use] +mod macros; + +#[cfg(feature = "json")] +#[cfg_attr(docsrs, doc(cfg(feature = "json")))] +mod json; + +#[cfg(feature = "cluster-async")] +pub use cluster_scan::ScanStateRC; + +#[cfg(feature = "cluster-async")] +pub(crate) mod cluster_scan; + +#[cfg(feature = "cluster-async")] +pub use cluster_scan::ObjectType; + +#[cfg(feature = "json")] +pub use json::JsonCommands; + +#[cfg(all(feature = "json", feature = "aio"))] +pub use json::JsonAsyncCommands; + +#[cfg(feature = "cluster")] +use crate::cluster_pipeline::ClusterPipeline; + +#[cfg(feature = "geospatial")] +use crate::geo; + +#[cfg(feature = "streams")] +use crate::streams; + +#[cfg(feature = "acl")] +use crate::acl; +use crate::RedisConnectionInfo; + +implement_commands! { + 'a + // most common operations + + /// Get the value of a key. If key is a vec this becomes an `MGET`. + fn get(key: K) { + cmd(if key.is_single_arg() { "GET" } else { "MGET" }).arg(key) + } + + /// Get values of keys + fn mget(key: K){ + cmd("MGET").arg(key) + } + + /// Gets all keys matching pattern + fn keys(key: K) { + cmd("KEYS").arg(key) + } + + /// Set the string value of a key. + fn set(key: K, value: V) { + cmd("SET").arg(key).arg(value) + } + + /// Set the string value of a key with options. + fn set_options(key: K, value: V, options: SetOptions) { + cmd("SET").arg(key).arg(value).arg(options) + } + + /// Sets multiple keys to their values. + #[allow(deprecated)] + #[deprecated(since = "0.22.4", note = "Renamed to mset() to reflect Redis name")] + fn set_multiple(items: &'a [(K, V)]) { + cmd("MSET").arg(items) + } + + /// Sets multiple keys to their values. + fn mset(items: &'a [(K, V)]) { + cmd("MSET").arg(items) + } + + /// Set the value and expiration of a key. + fn set_ex(key: K, value: V, seconds: u64) { + cmd("SETEX").arg(key).arg(seconds).arg(value) + } + + /// Set the value and expiration in milliseconds of a key. + fn pset_ex(key: K, value: V, milliseconds: u64) { + cmd("PSETEX").arg(key).arg(milliseconds).arg(value) + } + + /// Set the value of a key, only if the key does not exist + fn set_nx(key: K, value: V) { + cmd("SETNX").arg(key).arg(value) + } + + /// Sets multiple keys to their values failing if at least one already exists. + fn mset_nx(items: &'a [(K, V)]) { + cmd("MSETNX").arg(items) + } + + /// Set the string value of a key and return its old value. + fn getset(key: K, value: V) { + cmd("GETSET").arg(key).arg(value) + } + + /// Get a range of bytes/substring from the value of a key. Negative values provide an offset from the end of the value. + fn getrange(key: K, from: isize, to: isize) { + cmd("GETRANGE").arg(key).arg(from).arg(to) + } + + /// Overwrite the part of the value stored in key at the specified offset. + fn setrange(key: K, offset: isize, value: V) { + cmd("SETRANGE").arg(key).arg(offset).arg(value) + } + + /// Delete one or more keys. + fn del(key: K) { + cmd("DEL").arg(key) + } + + /// Determine if a key exists. + fn exists(key: K) { + cmd("EXISTS").arg(key) + } + + /// Determine the type of a key. + fn key_type(key: K) { + cmd("TYPE").arg(key) + } + + /// Set a key's time to live in seconds. + fn expire(key: K, seconds: i64) { + cmd("EXPIRE").arg(key).arg(seconds) + } + + /// Set the expiration for a key as a UNIX timestamp. + fn expire_at(key: K, ts: i64) { + cmd("EXPIREAT").arg(key).arg(ts) + } + + /// Set a key's time to live in milliseconds. + fn pexpire(key: K, ms: i64) { + cmd("PEXPIRE").arg(key).arg(ms) + } + + /// Set the expiration for a key as a UNIX timestamp in milliseconds. + fn pexpire_at(key: K, ts: i64) { + cmd("PEXPIREAT").arg(key).arg(ts) + } + + /// Remove the expiration from a key. + fn persist(key: K) { + cmd("PERSIST").arg(key) + } + + /// Get the expiration time of a key. + fn ttl(key: K) { + cmd("TTL").arg(key) + } + + /// Get the expiration time of a key in milliseconds. + fn pttl(key: K) { + cmd("PTTL").arg(key) + } + + /// Get the value of a key and set expiration + fn get_ex(key: K, expire_at: Expiry) { + let (option, time_arg) = match expire_at { + Expiry::EX(sec) => ("EX", Some(sec)), + Expiry::PX(ms) => ("PX", Some(ms)), + Expiry::EXAT(timestamp_sec) => ("EXAT", Some(timestamp_sec)), + Expiry::PXAT(timestamp_ms) => ("PXAT", Some(timestamp_ms)), + Expiry::PERSIST => ("PERSIST", None), + }; + + cmd("GETEX").arg(key).arg(option).arg(time_arg) + } + + /// Get the value of a key and delete it + fn get_del(key: K) { + cmd("GETDEL").arg(key) + } + + /// Rename a key. + fn rename(key: K, new_key: N) { + cmd("RENAME").arg(key).arg(new_key) + } + + /// Rename a key, only if the new key does not exist. + fn rename_nx(key: K, new_key: N) { + cmd("RENAMENX").arg(key).arg(new_key) + } + + /// Unlink one or more keys. + fn unlink(key: K) { + cmd("UNLINK").arg(key) + } + + // common string operations + + /// Append a value to a key. + fn append(key: K, value: V) { + cmd("APPEND").arg(key).arg(value) + } + + /// Increment the numeric value of a key by the given amount. This + /// issues a `INCRBY` or `INCRBYFLOAT` depending on the type. + fn incr(key: K, delta: V) { + cmd(if delta.describe_numeric_behavior() == NumericBehavior::NumberIsFloat { + "INCRBYFLOAT" + } else { + "INCRBY" + }).arg(key).arg(delta) + } + + /// Decrement the numeric value of a key by the given amount. + fn decr(key: K, delta: V) { + cmd("DECRBY").arg(key).arg(delta) + } + + /// Sets or clears the bit at offset in the string value stored at key. + fn setbit(key: K, offset: usize, value: bool) { + cmd("SETBIT").arg(key).arg(offset).arg(i32::from(value)) + } + + /// Returns the bit value at offset in the string value stored at key. + fn getbit(key: K, offset: usize) { + cmd("GETBIT").arg(key).arg(offset) + } + + /// Count set bits in a string. + fn bitcount(key: K) { + cmd("BITCOUNT").arg(key) + } + + /// Count set bits in a string in a range. + fn bitcount_range(key: K, start: usize, end: usize) { + cmd("BITCOUNT").arg(key).arg(start).arg(end) + } + + /// Perform a bitwise AND between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_and(dstkey: D, srckeys: S) { + cmd("BITOP").arg("AND").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise OR between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_or(dstkey: D, srckeys: S) { + cmd("BITOP").arg("OR").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise XOR between multiple keys (containing string values) + /// and store the result in the destination key. + fn bit_xor(dstkey: D, srckeys: S) { + cmd("BITOP").arg("XOR").arg(dstkey).arg(srckeys) + } + + /// Perform a bitwise NOT of the key (containing string values) + /// and store the result in the destination key. + fn bit_not(dstkey: D, srckey: S) { + cmd("BITOP").arg("NOT").arg(dstkey).arg(srckey) + } + + /// Get the length of the value stored in a key. + fn strlen(key: K) { + cmd("STRLEN").arg(key) + } + + // hash operations + + /// Gets a single (or multiple) fields from a hash. + fn hget(key: K, field: F) { + cmd(if field.is_single_arg() { "HGET" } else { "HMGET" }).arg(key).arg(field) + } + + /// Deletes a single (or multiple) fields from a hash. + fn hdel(key: K, field: F) { + cmd("HDEL").arg(key).arg(field) + } + + /// Sets a single field in a hash. + fn hset(key: K, field: F, value: V) { + cmd("HSET").arg(key).arg(field).arg(value) + } + + /// Sets a single field in a hash if it does not exist. + fn hset_nx(key: K, field: F, value: V) { + cmd("HSETNX").arg(key).arg(field).arg(value) + } + + /// Sets a multiple fields in a hash. + fn hset_multiple(key: K, items: &'a [(F, V)]) { + cmd("HMSET").arg(key).arg(items) + } + + /// Increments a value. + fn hincr(key: K, field: F, delta: D) { + cmd(if delta.describe_numeric_behavior() == NumericBehavior::NumberIsFloat { + "HINCRBYFLOAT" + } else { + "HINCRBY" + }).arg(key).arg(field).arg(delta) + } + + /// Checks if a field in a hash exists. + fn hexists(key: K, field: F) { + cmd("HEXISTS").arg(key).arg(field) + } + + /// Gets all the keys in a hash. + fn hkeys(key: K) { + cmd("HKEYS").arg(key) + } + + /// Gets all the values in a hash. + fn hvals(key: K) { + cmd("HVALS").arg(key) + } + + /// Gets all the fields and values in a hash. + fn hgetall(key: K) { + cmd("HGETALL").arg(key) + } + + /// Gets the length of a hash. + fn hlen(key: K) { + cmd("HLEN").arg(key) + } + + // list operations + + /// Pop an element from a list, push it to another list + /// and return it; or block until one is available + fn blmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction, timeout: f64) { + cmd("BLMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir).arg(timeout) + } + + /// Pops `count` elements from the first non-empty list key from the list of + /// provided key names; or blocks until one is available. + fn blmpop(timeout: f64, numkeys: usize, key: K, dir: Direction, count: usize){ + cmd("BLMPOP").arg(timeout).arg(numkeys).arg(key).arg(dir).arg("COUNT").arg(count) + } + + /// Remove and get the first element in a list, or block until one is available. + fn blpop(key: K, timeout: f64) { + cmd("BLPOP").arg(key).arg(timeout) + } + + /// Remove and get the last element in a list, or block until one is available. + fn brpop(key: K, timeout: f64) { + cmd("BRPOP").arg(key).arg(timeout) + } + + /// Pop a value from a list, push it to another list and return it; + /// or block until one is available. + fn brpoplpush(srckey: S, dstkey: D, timeout: f64) { + cmd("BRPOPLPUSH").arg(srckey).arg(dstkey).arg(timeout) + } + + /// Get an element from a list by its index. + fn lindex(key: K, index: isize) { + cmd("LINDEX").arg(key).arg(index) + } + + /// Insert an element before another element in a list. + fn linsert_before( + key: K, pivot: P, value: V) { + cmd("LINSERT").arg(key).arg("BEFORE").arg(pivot).arg(value) + } + + /// Insert an element after another element in a list. + fn linsert_after( + key: K, pivot: P, value: V) { + cmd("LINSERT").arg(key).arg("AFTER").arg(pivot).arg(value) + } + + /// Returns the length of the list stored at key. + fn llen(key: K) { + cmd("LLEN").arg(key) + } + + /// Pop an element a list, push it to another list and return it + fn lmove(srckey: S, dstkey: D, src_dir: Direction, dst_dir: Direction) { + cmd("LMOVE").arg(srckey).arg(dstkey).arg(src_dir).arg(dst_dir) + } + + /// Pops `count` elements from the first non-empty list key from the list of + /// provided key names. + fn lmpop( numkeys: usize, key: K, dir: Direction, count: usize) { + cmd("LMPOP").arg(numkeys).arg(key).arg(dir).arg("COUNT").arg(count) + } + + /// Removes and returns the up to `count` first elements of the list stored at key. + /// + /// If `count` is not specified, then defaults to first element. + fn lpop(key: K, count: Option) { + cmd("LPOP").arg(key).arg(count) + } + + /// Returns the index of the first matching value of the list stored at key. + fn lpos(key: K, value: V, options: LposOptions) { + cmd("LPOS").arg(key).arg(value).arg(options) + } + + /// Insert all the specified values at the head of the list stored at key. + fn lpush(key: K, value: V) { + cmd("LPUSH").arg(key).arg(value) + } + + /// Inserts a value at the head of the list stored at key, only if key + /// already exists and holds a list. + fn lpush_exists(key: K, value: V) { + cmd("LPUSHX").arg(key).arg(value) + } + + /// Returns the specified elements of the list stored at key. + fn lrange(key: K, start: isize, stop: isize) { + cmd("LRANGE").arg(key).arg(start).arg(stop) + } + + /// Removes the first count occurrences of elements equal to value + /// from the list stored at key. + fn lrem(key: K, count: isize, value: V) { + cmd("LREM").arg(key).arg(count).arg(value) + } + + /// Trim an existing list so that it will contain only the specified + /// range of elements specified. + fn ltrim(key: K, start: isize, stop: isize) { + cmd("LTRIM").arg(key).arg(start).arg(stop) + } + + /// Sets the list element at index to value + fn lset(key: K, index: isize, value: V) { + cmd("LSET").arg(key).arg(index).arg(value) + } + + /// Removes and returns the up to `count` last elements of the list stored at key + /// + /// If `count` is not specified, then defaults to last element. + fn rpop(key: K, count: Option) { + cmd("RPOP").arg(key).arg(count) + } + + /// Pop a value from a list, push it to another list and return it. + fn rpoplpush(key: K, dstkey: D) { + cmd("RPOPLPUSH").arg(key).arg(dstkey) + } + + /// Insert all the specified values at the tail of the list stored at key. + fn rpush(key: K, value: V) { + cmd("RPUSH").arg(key).arg(value) + } + + /// Inserts value at the tail of the list stored at key, only if key + /// already exists and holds a list. + fn rpush_exists(key: K, value: V) { + cmd("RPUSHX").arg(key).arg(value) + } + + // set commands + + /// Add one or more members to a set. + fn sadd(key: K, member: M) { + cmd("SADD").arg(key).arg(member) + } + + /// Get the number of members in a set. + fn scard(key: K) { + cmd("SCARD").arg(key) + } + + /// Subtract multiple sets. + fn sdiff(keys: K) { + cmd("SDIFF").arg(keys) + } + + /// Subtract multiple sets and store the resulting set in a key. + fn sdiffstore(dstkey: D, keys: K) { + cmd("SDIFFSTORE").arg(dstkey).arg(keys) + } + + /// Intersect multiple sets. + fn sinter(keys: K) { + cmd("SINTER").arg(keys) + } + + /// Intersect multiple sets and store the resulting set in a key. + fn sinterstore(dstkey: D, keys: K) { + cmd("SINTERSTORE").arg(dstkey).arg(keys) + } + + /// Determine if a given value is a member of a set. + fn sismember(key: K, member: M) { + cmd("SISMEMBER").arg(key).arg(member) + } + + /// Determine if given values are members of a set. + fn smismember(key: K, members: M) { + cmd("SMISMEMBER").arg(key).arg(members) + } + + /// Get all the members in a set. + fn smembers(key: K) { + cmd("SMEMBERS").arg(key) + } + + /// Move a member from one set to another. + fn smove(srckey: S, dstkey: D, member: M) { + cmd("SMOVE").arg(srckey).arg(dstkey).arg(member) + } + + /// Remove and return a random member from a set. + fn spop(key: K) { + cmd("SPOP").arg(key) + } + + /// Get one random member from a set. + fn srandmember(key: K) { + cmd("SRANDMEMBER").arg(key) + } + + /// Get multiple random members from a set. + fn srandmember_multiple(key: K, count: usize) { + cmd("SRANDMEMBER").arg(key).arg(count) + } + + /// Remove one or more members from a set. + fn srem(key: K, member: M) { + cmd("SREM").arg(key).arg(member) + } + + /// Add multiple sets. + fn sunion(keys: K) { + cmd("SUNION").arg(keys) + } + + /// Add multiple sets and store the resulting set in a key. + fn sunionstore(dstkey: D, keys: K) { + cmd("SUNIONSTORE").arg(dstkey).arg(keys) + } + + // sorted set commands + + /// Add one member to a sorted set, or update its score if it already exists. + fn zadd(key: K, member: M, score: S) { + cmd("ZADD").arg(key).arg(score).arg(member) + } + + /// Add multiple members to a sorted set, or update its score if it already exists. + fn zadd_multiple(key: K, items: &'a [(S, M)]) { + cmd("ZADD").arg(key).arg(items) + } + + /// Get the number of members in a sorted set. + fn zcard(key: K) { + cmd("ZCARD").arg(key) + } + + /// Count the members in a sorted set with scores within the given values. + fn zcount(key: K, min: M, max: MM) { + cmd("ZCOUNT").arg(key).arg(min).arg(max) + } + + /// Increments the member in a sorted set at key by delta. + /// If the member does not exist, it is added with delta as its score. + fn zincr(key: K, member: M, delta: D) { + cmd("ZINCRBY").arg(key).arg(delta).arg(member) + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using SUM as aggregation function. + fn zinterstore(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys) + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using MIN as aggregation function. + fn zinterstore_min(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") + } + + /// Intersect multiple sorted sets and store the resulting sorted set in + /// a new key using MAX as aggregation function. + fn zinterstore_max(dstkey: D, keys: &'a [K]) { + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") + } + + /// [`Commands::zinterstore`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zinterstore_min`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zinterstore_max`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zinterstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZINTERSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) + } + + /// Count the number of members in a sorted set between a given lexicographical range. + fn zlexcount(key: K, min: M, max: MM) { + cmd("ZLEXCOUNT").arg(key).arg(min).arg(max) + } + + /// Removes and returns the member with the highest score in a sorted set. + /// Blocks until a member is available otherwise. + fn bzpopmax(key: K, timeout: f64) { + cmd("BZPOPMAX").arg(key).arg(timeout) + } + + /// Removes and returns up to count members with the highest scores in a sorted set + fn zpopmax(key: K, count: isize) { + cmd("ZPOPMAX").arg(key).arg(count) + } + + /// Removes and returns the member with the lowest score in a sorted set. + /// Blocks until a member is available otherwise. + fn bzpopmin(key: K, timeout: f64) { + cmd("BZPOPMIN").arg(key).arg(timeout) + } + + /// Removes and returns up to count members with the lowest scores in a sorted set + fn zpopmin(key: K, count: isize) { + cmd("ZPOPMIN").arg(key).arg(count) + } + + /// Removes and returns up to count members with the highest scores, + /// from the first non-empty sorted set in the provided list of key names. + /// Blocks until a member is available otherwise. + fn bzmpop_max(timeout: f64, keys: &'a [K], count: isize) { + cmd("BZMPOP").arg(timeout).arg(keys.len()).arg(keys).arg("MAX").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the highest scores, + /// from the first non-empty sorted set in the provided list of key names. + fn zmpop_max(keys: &'a [K], count: isize) { + cmd("ZMPOP").arg(keys.len()).arg(keys).arg("MAX").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the lowest scores, + /// from the first non-empty sorted set in the provided list of key names. + /// Blocks until a member is available otherwise. + fn bzmpop_min(timeout: f64, keys: &'a [K], count: isize) { + cmd("BZMPOP").arg(timeout).arg(keys.len()).arg(keys).arg("MIN").arg("COUNT").arg(count) + } + + /// Removes and returns up to count members with the lowest scores, + /// from the first non-empty sorted set in the provided list of key names. + fn zmpop_min(keys: &'a [K], count: isize) { + cmd("ZMPOP").arg(keys.len()).arg(keys).arg("MIN").arg("COUNT").arg(count) + } + + /// Return up to count random members in a sorted set (or 1 if `count == None`) + fn zrandmember(key: K, count: Option) { + cmd("ZRANDMEMBER").arg(key).arg(count) + } + + /// Return up to count random members in a sorted set with scores + fn zrandmember_withscores(key: K, count: isize) { + cmd("ZRANDMEMBER").arg(key).arg(count).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by index + fn zrange(key: K, start: isize, stop: isize) { + cmd("ZRANGE").arg(key).arg(start).arg(stop) + } + + /// Return a range of members in a sorted set, by index with scores. + fn zrange_withscores(key: K, start: isize, stop: isize) { + cmd("ZRANGE").arg(key).arg(start).arg(stop).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by lexicographical range. + fn zrangebylex(key: K, min: M, max: MM) { + cmd("ZRANGEBYLEX").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by lexicographical + /// range with offset and limit. + fn zrangebylex_limit( + key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYLEX").arg(key).arg(min).arg(max).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by lexicographical range. + fn zrevrangebylex(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYLEX").arg(key).arg(max).arg(min) + } + + /// Return a range of members in a sorted set, by lexicographical + /// range with offset and limit. + fn zrevrangebylex_limit( + key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYLEX").arg(key).arg(max).arg(min).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score. + fn zrangebyscore(key: K, min: M, max: MM) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by score with scores. + fn zrangebyscore_withscores(key: K, min: M, max: MM) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score with limit. + fn zrangebyscore_limit + (key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score with limit with scores. + fn zrangebyscore_limit_withscores + (key: K, min: M, max: MM, offset: isize, count: isize) { + cmd("ZRANGEBYSCORE").arg(key).arg(min).arg(max).arg("WITHSCORES") + .arg("LIMIT").arg(offset).arg(count) + } + + /// Determine the index of a member in a sorted set. + fn zrank(key: K, member: M) { + cmd("ZRANK").arg(key).arg(member) + } + + /// Remove one or more members from a sorted set. + fn zrem(key: K, members: M) { + cmd("ZREM").arg(key).arg(members) + } + + /// Remove all members in a sorted set between the given lexicographical range. + fn zrembylex(key: K, min: M, max: MM) { + cmd("ZREMRANGEBYLEX").arg(key).arg(min).arg(max) + } + + /// Remove all members in a sorted set within the given indexes. + fn zremrangebyrank(key: K, start: isize, stop: isize) { + cmd("ZREMRANGEBYRANK").arg(key).arg(start).arg(stop) + } + + /// Remove all members in a sorted set within the given scores. + fn zrembyscore(key: K, min: M, max: MM) { + cmd("ZREMRANGEBYSCORE").arg(key).arg(min).arg(max) + } + + /// Return a range of members in a sorted set, by index, with scores + /// ordered from high to low. + fn zrevrange(key: K, start: isize, stop: isize) { + cmd("ZREVRANGE").arg(key).arg(start).arg(stop) + } + + /// Return a range of members in a sorted set, by index, with scores + /// ordered from high to low. + fn zrevrange_withscores(key: K, start: isize, stop: isize) { + cmd("ZREVRANGE").arg(key).arg(start).arg(stop).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score. + fn zrevrangebyscore(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min) + } + + /// Return a range of members in a sorted set, by score with scores. + fn zrevrangebyscore_withscores(key: K, max: MM, min: M) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("WITHSCORES") + } + + /// Return a range of members in a sorted set, by score with limit. + fn zrevrangebyscore_limit + (key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("LIMIT").arg(offset).arg(count) + } + + /// Return a range of members in a sorted set, by score with limit with scores. + fn zrevrangebyscore_limit_withscores + (key: K, max: MM, min: M, offset: isize, count: isize) { + cmd("ZREVRANGEBYSCORE").arg(key).arg(max).arg(min).arg("WITHSCORES") + .arg("LIMIT").arg(offset).arg(count) + } + + /// Determine the index of a member in a sorted set, with scores ordered from high to low. + fn zrevrank(key: K, member: M) { + cmd("ZREVRANK").arg(key).arg(member) + } + + /// Get the score associated with the given member in a sorted set. + fn zscore(key: K, member: M) { + cmd("ZSCORE").arg(key).arg(member) + } + + /// Get the scores associated with multiple members in a sorted set. + fn zscore_multiple(key: K, members: &'a [M]) { + cmd("ZMSCORE").arg(key).arg(members) + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using SUM as aggregation function. + fn zunionstore(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys) + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using MIN as aggregation function. + fn zunionstore_min(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN") + } + + /// Unions multiple sorted sets and store the resulting sorted set in + /// a new key using MAX as aggregation function. + fn zunionstore_max(dstkey: D, keys: &'a [K]) { + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX") + } + + /// [`Commands::zunionstore`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zunionstore_min`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_min_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MIN").arg("WEIGHTS").arg(weights) + } + + /// [`Commands::zunionstore_max`], but with the ability to specify a + /// multiplication factor for each sorted set by pairing one with each key + /// in a tuple. + fn zunionstore_max_weights(dstkey: D, keys: &'a [(K, W)]) { + let (keys, weights): (Vec<&K>, Vec<&W>) = keys.iter().map(|(key, weight):&(K, W)| -> (&K, &W) {(key, weight)}).unzip(); + cmd("ZUNIONSTORE").arg(dstkey).arg(keys.len()).arg(keys).arg("AGGREGATE").arg("MAX").arg("WEIGHTS").arg(weights) + } + + // hyperloglog commands + + /// Adds the specified elements to the specified HyperLogLog. + fn pfadd(key: K, element: E) { + cmd("PFADD").arg(key).arg(element) + } + + /// Return the approximated cardinality of the set(s) observed by the + /// HyperLogLog at key(s). + fn pfcount(key: K) { + cmd("PFCOUNT").arg(key) + } + + /// Merge N different HyperLogLogs into a single one. + fn pfmerge(dstkey: D, srckeys: S) { + cmd("PFMERGE").arg(dstkey).arg(srckeys) + } + + /// Posts a message to the given channel. + fn publish(channel: K, message: E) { + cmd("PUBLISH").arg(channel).arg(message) + } + + // Object commands + + /// Returns the encoding of a key. + fn object_encoding(key: K) { + cmd("OBJECT").arg("ENCODING").arg(key) + } + + /// Returns the time in seconds since the last access of a key. + fn object_idletime(key: K) { + cmd("OBJECT").arg("IDLETIME").arg(key) + } + + /// Returns the logarithmic access frequency counter of a key. + fn object_freq(key: K) { + cmd("OBJECT").arg("FREQ").arg(key) + } + + /// Returns the reference count of a key. + fn object_refcount(key: K) { + cmd("OBJECT").arg("REFCOUNT").arg(key) + } + + // ACL commands + + /// When Redis is configured to use an ACL file (with the aclfile + /// configuration option), this command will reload the ACLs from the file, + /// replacing all the current ACL rules with the ones defined in the file. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_load<>() { + cmd("ACL").arg("LOAD") + } + + /// When Redis is configured to use an ACL file (with the aclfile + /// configuration option), this command will save the currently defined + /// ACLs from the server memory to the ACL file. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_save<>() { + cmd("ACL").arg("SAVE") + } + + /// Shows the currently active ACL rules in the Redis server. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_list<>() { + cmd("ACL").arg("LIST") + } + + /// Shows a list of all the usernames of the currently configured users in + /// the Redis ACL system. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_users<>() { + cmd("ACL").arg("USERS") + } + + /// Returns all the rules defined for an existing ACL user. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_getuser(username: K) { + cmd("ACL").arg("GETUSER").arg(username) + } + + /// Creates an ACL user without any privilege. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_setuser(username: K) { + cmd("ACL").arg("SETUSER").arg(username) + } + + /// Creates an ACL user with the specified rules or modify the rules of + /// an existing user. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_setuser_rules(username: K, rules: &'a [acl::Rule]) { + cmd("ACL").arg("SETUSER").arg(username).arg(rules) + } + + /// Delete all the specified ACL users and terminate all the connections + /// that are authenticated with such users. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_deluser(usernames: &'a [K]) { + cmd("ACL").arg("DELUSER").arg(usernames) + } + + /// Shows the available ACL categories. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_cat<>() { + cmd("ACL").arg("CAT") + } + + /// Shows all the Redis commands in the specified category. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_cat_categoryname(categoryname: K) { + cmd("ACL").arg("CAT").arg(categoryname) + } + + /// Generates a 256-bits password starting from /dev/urandom if available. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_genpass<>() { + cmd("ACL").arg("GENPASS") + } + + /// Generates a 1-to-1024-bits password starting from /dev/urandom if available. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_genpass_bits<>(bits: isize) { + cmd("ACL").arg("GENPASS").arg(bits) + } + + /// Returns the username the current connection is authenticated with. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_whoami<>() { + cmd("ACL").arg("WHOAMI") + } + + /// Shows a list of recent ACL security events + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_log<>(count: isize) { + cmd("ACL").arg("LOG").arg(count) + + } + + /// Clears the ACL log. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_log_reset<>() { + cmd("ACL").arg("LOG").arg("RESET") + } + + /// Returns a helpful text describing the different subcommands. + #[cfg(feature = "acl")] + #[cfg_attr(docsrs, doc(cfg(feature = "acl")))] + fn acl_help<>() { + cmd("ACL").arg("HELP") + } + + // + // geospatial commands + // + + /// Adds the specified geospatial items to the specified key. + /// + /// Every member has to be written as a tuple of `(longitude, latitude, + /// member_name)`. It can be a single tuple, or a vector of tuples. + /// + /// `longitude, latitude` can be set using [`redis::geo::Coord`][1]. + /// + /// [1]: ./geo/struct.Coord.html + /// + /// Returns the number of elements added to the sorted set, not including + /// elements already existing for which the score was updated. + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, Connection, RedisResult}; + /// use redis::geo::Coord; + /// + /// fn add_point(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", (Coord::lon_lat(13.361389, 38.115556), "Palermo")) + /// } + /// + /// fn add_point_with_tuples(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", ("13.361389", "38.115556", "Palermo")) + /// } + /// + /// fn add_many_points(con: &mut Connection) -> RedisResult { + /// con.geo_add("my_gis", &[ + /// ("13.361389", "38.115556", "Palermo"), + /// ("15.087269", "37.502669", "Catania") + /// ]) + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_add(key: K, members: M) { + cmd("GEOADD").arg(key).arg(members) + } + + /// Return the distance between two members in the geospatial index + /// represented by the sorted set. + /// + /// If one or both the members are missing, the command returns NULL, so + /// it may be convenient to parse its response as either `Option` or + /// `Option`. + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::Unit; + /// + /// fn get_dists(con: &mut redis::Connection) { + /// let x: RedisResult = con.geo_dist( + /// "my_gis", + /// "Palermo", + /// "Catania", + /// Unit::Kilometers + /// ); + /// // x is Ok(166.2742) + /// + /// let x: RedisResult> = con.geo_dist( + /// "my_gis", + /// "Palermo", + /// "Atlantis", + /// Unit::Meters + /// ); + /// // x is Ok(None) + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_dist( + key: K, + member1: M1, + member2: M2, + unit: geo::Unit + ) { + cmd("GEODIST") + .arg(key) + .arg(member1) + .arg(member2) + .arg(unit) + } + + /// Return valid [Geohash][1] strings representing the position of one or + /// more members of the geospatial index represented by the sorted set at + /// key. + /// + /// [1]: https://en.wikipedia.org/wiki/Geohash + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// + /// fn get_hash(con: &mut redis::Connection) { + /// let x: RedisResult> = con.geo_hash("my_gis", "Palermo"); + /// // x is vec!["sqc8b49rny0"] + /// + /// let x: RedisResult> = con.geo_hash("my_gis", &["Palermo", "Catania"]); + /// // x is vec!["sqc8b49rny0", "sqdtr74hyu0"] + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_hash(key: K, members: M) { + cmd("GEOHASH").arg(key).arg(members) + } + + /// Return the positions of all the specified members of the geospatial + /// index represented by the sorted set at key. + /// + /// Every position is a pair of `(longitude, latitude)`. [`redis::geo::Coord`][1] + /// can be used to convert these value in a struct. + /// + /// [1]: ./geo/struct.Coord.html + /// + /// # Example + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::Coord; + /// + /// fn get_position(con: &mut redis::Connection) { + /// let x: RedisResult>> = con.geo_pos("my_gis", &["Palermo", "Catania"]); + /// // x is [ [ 13.361389, 38.115556 ], [ 15.087269, 37.502669 ] ]; + /// + /// let x: Vec> = con.geo_pos("my_gis", "Palermo").unwrap(); + /// // x[0].longitude is 13.361389 + /// // x[0].latitude is 38.115556 + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_pos(key: K, members: M) { + cmd("GEOPOS").arg(key).arg(members) + } + + /// Return the members of a sorted set populated with geospatial information + /// using [`geo_add`](#method.geo_add), which are within the borders of the area + /// specified with the center location and the maximum distance from the center + /// (the radius). + /// + /// Every item in the result can be read with [`redis::geo::RadiusSearchResult`][1], + /// which support the multiple formats returned by `GEORADIUS`. + /// + /// [1]: ./geo/struct.RadiusSearchResult.html + /// + /// ```rust,no_run + /// use redis::{Commands, RedisResult}; + /// use redis::geo::{RadiusOptions, RadiusSearchResult, RadiusOrder, Unit}; + /// + /// fn radius(con: &mut redis::Connection) -> Vec { + /// let opts = RadiusOptions::default().with_dist().order(RadiusOrder::Asc); + /// con.geo_radius("my_gis", 15.90, 37.21, 51.39, Unit::Kilometers, opts).unwrap() + /// } + /// ``` + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_radius( + key: K, + longitude: f64, + latitude: f64, + radius: f64, + unit: geo::Unit, + options: geo::RadiusOptions + ) { + cmd("GEORADIUS") + .arg(key) + .arg(longitude) + .arg(latitude) + .arg(radius) + .arg(unit) + .arg(options) + } + + /// Retrieve members selected by distance with the center of `member`. The + /// member itself is always contained in the results. + #[cfg(feature = "geospatial")] + #[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] + fn geo_radius_by_member( + key: K, + member: M, + radius: f64, + unit: geo::Unit, + options: geo::RadiusOptions + ) { + cmd("GEORADIUSBYMEMBER") + .arg(key) + .arg(member) + .arg(radius) + .arg(unit) + .arg(options) + } + + // + // streams commands + // + + /// Ack pending stream messages checked out by a consumer. + /// + /// ```text + /// XACK ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xack( + key: K, + group: G, + ids: &'a [I]) { + cmd("XACK") + .arg(key) + .arg(group) + .arg(ids) + } + + + /// Add a stream message by `key`. Use `*` as the `id` for the current timestamp. + /// + /// ```text + /// XADD key [field value] [field value] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd( + key: K, + id: ID, + items: &'a [(F, V)] + ) { + cmd("XADD").arg(key).arg(id).arg(items) + } + + + /// BTreeMap variant for adding a stream message by `key`. + /// Use `*` as the `id` for the current timestamp. + /// + /// ```text + /// XADD key [rust BTreeMap] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_map( + key: K, + id: ID, + map: BTM + ) { + cmd("XADD").arg(key).arg(id).arg(map) + } + + /// Add a stream message while capping the stream at a maxlength. + /// + /// ```text + /// XADD key [MAXLEN [~|=] ] [field value] [field value] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_maxlen< + K: ToRedisArgs, + ID: ToRedisArgs, + F: ToRedisArgs, + V: ToRedisArgs + >( + key: K, + maxlen: streams::StreamMaxlen, + id: ID, + items: &'a [(F, V)] + ) { + cmd("XADD") + .arg(key) + .arg(maxlen) + .arg(id) + .arg(items) + } + + + /// BTreeMap variant for adding a stream message while capping the stream at a maxlength. + /// + /// ```text + /// XADD key [MAXLEN [~|=] ] [rust BTreeMap] ... + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xadd_maxlen_map( + key: K, + maxlen: streams::StreamMaxlen, + id: ID, + map: BTM + ) { + cmd("XADD") + .arg(key) + .arg(maxlen) + .arg(id) + .arg(map) + } + + + + /// Claim pending, unacked messages, after some period of time, + /// currently checked out by another consumer. + /// + /// This method only accepts the must-have arguments for claiming messages. + /// If optional arguments are required, see `xclaim_options` below. + /// + /// ```text + /// XCLAIM [ ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xclaim( + key: K, + group: G, + consumer: C, + min_idle_time: MIT, + ids: &'a [ID] + ) { + cmd("XCLAIM") + .arg(key) + .arg(group) + .arg(consumer) + .arg(min_idle_time) + .arg(ids) + } + + /// This is the optional arguments version for claiming unacked, pending messages + /// currently checked out by another consumer. + /// + /// ```no_run + /// use redis::{Connection,Commands,RedisResult}; + /// use redis::streams::{StreamClaimOptions,StreamClaimReply}; + /// let client = redis::Client::open("redis://127.0.0.1/0").unwrap(); + /// let mut con = client.get_connection(None).unwrap(); + /// + /// // Claim all pending messages for key "k1", + /// // from group "g1", checked out by consumer "c1" + /// // for 10ms with RETRYCOUNT 2 and FORCE + /// + /// let opts = StreamClaimOptions::default() + /// .with_force() + /// .retry(2); + /// let results: RedisResult = + /// con.xclaim_options("k1", "g1", "c1", 10, &["0"], opts); + /// + /// // All optional arguments return a `Result` with one exception: + /// // Passing JUSTID returns only the message `id` and omits the HashMap for each message. + /// + /// let opts = StreamClaimOptions::default() + /// .with_justid(); + /// let results: RedisResult> = + /// con.xclaim_options("k1", "g1", "c1", 10, &["0"], opts); + /// ``` + /// + /// ```text + /// XCLAIM + /// [IDLE ] [TIME ] [RETRYCOUNT ] + /// [FORCE] [JUSTID] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xclaim_options< + K: ToRedisArgs, + G: ToRedisArgs, + C: ToRedisArgs, + MIT: ToRedisArgs, + ID: ToRedisArgs + >( + key: K, + group: G, + consumer: C, + min_idle_time: MIT, + ids: &'a [ID], + options: streams::StreamClaimOptions + ) { + cmd("XCLAIM") + .arg(key) + .arg(group) + .arg(consumer) + .arg(min_idle_time) + .arg(ids) + .arg(options) + } + + + /// Deletes a list of `id`s for a given stream `key`. + /// + /// ```text + /// XDEL [ ... ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xdel( + key: K, + ids: &'a [ID] + ) { + cmd("XDEL").arg(key).arg(ids) + } + + + /// This command is used for creating a consumer `group`. It expects the stream key + /// to already exist. Otherwise, use `xgroup_create_mkstream` if it doesn't. + /// The `id` is the starting message id all consumers should read from. Use `$` If you want + /// all consumers to read from the last message added to stream. + /// + /// ```text + /// XGROUP CREATE + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xgroup_create( + key: K, + group: G, + id: ID + ) { + cmd("XGROUP") + .arg("CREATE") + .arg(key) + .arg(group) + .arg(id) + } + + + /// This is the alternate version for creating a consumer `group` + /// which makes the stream if it doesn't exist. + /// + /// ```text + /// XGROUP CREATE [MKSTREAM] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xgroup_create_mkstream< + K: ToRedisArgs, + G: ToRedisArgs, + ID: ToRedisArgs + >( + key: K, + group: G, + id: ID + ) { + cmd("XGROUP") + .arg("CREATE") + .arg(key) + .arg(group) + .arg(id) + .arg("MKSTREAM") + } + + + /// Alter which `id` you want consumers to begin reading from an existing + /// consumer `group`. + /// + /// ```text + /// XGROUP SETID + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xgroup_setid( + key: K, + group: G, + id: ID + ) { + cmd("XGROUP") + .arg("SETID") + .arg(key) + .arg(group) + .arg(id) + } + + + /// Destroy an existing consumer `group` for a given stream `key` + /// + /// ```text + /// XGROUP SETID + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xgroup_destroy( + key: K, + group: G + ) { + cmd("XGROUP").arg("DESTROY").arg(key).arg(group) + } + + /// This deletes a `consumer` from an existing consumer `group` + /// for given stream `key. + /// + /// ```text + /// XGROUP DELCONSUMER + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xgroup_delconsumer( + key: K, + group: G, + consumer: C + ) { + cmd("XGROUP") + .arg("DELCONSUMER") + .arg(key) + .arg(group) + .arg(consumer) + } + + + /// This returns all info details about + /// which consumers have read messages for given consumer `group`. + /// Take note of the StreamInfoConsumersReply return type. + /// + /// *It's possible this return value might not contain new fields + /// added by Redis in future versions.* + /// + /// ```text + /// XINFO CONSUMERS + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xinfo_consumers( + key: K, + group: G + ) { + cmd("XINFO") + .arg("CONSUMERS") + .arg(key) + .arg(group) + } + + + /// Returns all consumer `group`s created for a given stream `key`. + /// Take note of the StreamInfoGroupsReply return type. + /// + /// *It's possible this return value might not contain new fields + /// added by Redis in future versions.* + /// + /// ```text + /// XINFO GROUPS + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xinfo_groups(key: K) { + cmd("XINFO").arg("GROUPS").arg(key) + } + + + /// Returns info about high-level stream details + /// (first & last message `id`, length, number of groups, etc.) + /// Take note of the StreamInfoStreamReply return type. + /// + /// *It's possible this return value might not contain new fields + /// added by Redis in future versions.* + /// + /// ```text + /// XINFO STREAM + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xinfo_stream(key: K) { + cmd("XINFO").arg("STREAM").arg(key) + } + + /// Returns the number of messages for a given stream `key`. + /// + /// ```text + /// XLEN + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xlen(key: K) { + cmd("XLEN").arg(key) + } + + + /// This is a basic version of making XPENDING command calls which only + /// passes a stream `key` and consumer `group` and it + /// returns details about which consumers have pending messages + /// that haven't been acked. + /// + /// You can use this method along with + /// `xclaim` or `xclaim_options` for determining which messages + /// need to be retried. + /// + /// Take note of the StreamPendingReply return type. + /// + /// ```text + /// XPENDING [ []] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending( + key: K, + group: G + ) { + cmd("XPENDING").arg(key).arg(group) + } + + + /// This XPENDING version returns a list of all messages over the range. + /// You can use this for paginating pending messages (but without the message HashMap). + /// + /// Start and end follow the same rules `xrange` args. Set start to `-` + /// and end to `+` for the entire stream. + /// + /// Take note of the StreamPendingCountReply return type. + /// + /// ```text + /// XPENDING + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending_count< + K: ToRedisArgs, + G: ToRedisArgs, + S: ToRedisArgs, + E: ToRedisArgs, + C: ToRedisArgs + >( + key: K, + group: G, + start: S, + end: E, + count: C + ) { + cmd("XPENDING") + .arg(key) + .arg(group) + .arg(start) + .arg(end) + .arg(count) + } + + + /// An alternate version of `xpending_count` which filters by `consumer` name. + /// + /// Start and end follow the same rules `xrange` args. Set start to `-` + /// and end to `+` for the entire stream. + /// + /// Take note of the StreamPendingCountReply return type. + /// + /// ```text + /// XPENDING + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xpending_consumer_count< + K: ToRedisArgs, + G: ToRedisArgs, + S: ToRedisArgs, + E: ToRedisArgs, + C: ToRedisArgs, + CN: ToRedisArgs + >( + key: K, + group: G, + start: S, + end: E, + count: C, + consumer: CN + ) { + cmd("XPENDING") + .arg(key) + .arg(group) + .arg(start) + .arg(end) + .arg(count) + .arg(consumer) + } + + /// Returns a range of messages in a given stream `key`. + /// + /// Set `start` to `-` to begin at the first message. + /// Set `end` to `+` to end the most recent message. + /// You can pass message `id` to both `start` and `end`. + /// + /// Take note of the StreamRangeReply return type. + /// + /// ```text + /// XRANGE key start end + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange( + key: K, + start: S, + end: E + ) { + cmd("XRANGE").arg(key).arg(start).arg(end) + } + + + /// A helper method for automatically returning all messages in a stream by `key`. + /// **Use with caution!** + /// + /// ```text + /// XRANGE key - + + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange_all(key: K) { + cmd("XRANGE").arg(key).arg("-").arg("+") + } + + + /// A method for paginating a stream by `key`. + /// + /// ```text + /// XRANGE key start end [COUNT ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrange_count( + key: K, + start: S, + end: E, + count: C + ) { + cmd("XRANGE") + .arg(key) + .arg(start) + .arg(end) + .arg("COUNT") + .arg(count) + } + + + /// Read a list of `id`s for each stream `key`. + /// This is the basic form of reading streams. + /// For more advanced control, like blocking, limiting, or reading by consumer `group`, + /// see `xread_options`. + /// + /// ```text + /// XREAD STREAMS key_1 key_2 ... key_N ID_1 ID_2 ... ID_N + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xread( + keys: &'a [K], + ids: &'a [ID] + ) { + cmd("XREAD").arg("STREAMS").arg(keys).arg(ids) + } + + /// This method handles setting optional arguments for + /// `XREAD` or `XREADGROUP` Redis commands. + /// ```no_run + /// use redis::{Connection,RedisResult,Commands}; + /// use redis::streams::{StreamReadOptions,StreamReadReply}; + /// let client = redis::Client::open("redis://127.0.0.1/0").unwrap(); + /// let mut con = client.get_connection(None).unwrap(); + /// + /// // Read 10 messages from the start of the stream, + /// // without registering as a consumer group. + /// + /// let opts = StreamReadOptions::default() + /// .count(10); + /// let results: RedisResult = + /// con.xread_options(&["k1"], &["0"], &opts); + /// + /// // Read all undelivered messages for a given + /// // consumer group. Be advised: the consumer group must already + /// // exist before making this call. Also note: we're passing + /// // '>' as the id here, which means all undelivered messages. + /// + /// let opts = StreamReadOptions::default() + /// .group("group-1", "consumer-1"); + /// let results: RedisResult = + /// con.xread_options(&["k1"], &[">"], &opts); + /// ``` + /// + /// ```text + /// XREAD [BLOCK ] [COUNT ] + /// STREAMS key_1 key_2 ... key_N + /// ID_1 ID_2 ... ID_N + /// + /// XREADGROUP [GROUP group-name consumer-name] [BLOCK ] [COUNT ] [NOACK] + /// STREAMS key_1 key_2 ... key_N + /// ID_1 ID_2 ... ID_N + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xread_options( + keys: &'a [K], + ids: &'a [ID], + options: &'a streams::StreamReadOptions + ) { + cmd(if options.read_only() { + "XREAD" + } else { + "XREADGROUP" + }) + .arg(options) + .arg("STREAMS") + .arg(keys) + .arg(ids) + } + + /// This is the reverse version of `xrange`. + /// The same rules apply for `start` and `end` here. + /// + /// ```text + /// XREVRANGE key end start + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrevrange( + key: K, + end: E, + start: S + ) { + cmd("XREVRANGE").arg(key).arg(end).arg(start) + } + + /// This is the reverse version of `xrange_all`. + /// The same rules apply for `start` and `end` here. + /// + /// ```text + /// XREVRANGE key + - + /// ``` + fn xrevrange_all(key: K) { + cmd("XREVRANGE").arg(key).arg("+").arg("-") + } + + /// This is the reverse version of `xrange_count`. + /// The same rules apply for `start` and `end` here. + /// + /// ```text + /// XREVRANGE key end start [COUNT ] + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xrevrange_count( + key: K, + end: E, + start: S, + count: C + ) { + cmd("XREVRANGE") + .arg(key) + .arg(end) + .arg(start) + .arg("COUNT") + .arg(count) + } + + + /// Trim a stream `key` to a MAXLEN count. + /// + /// ```text + /// XTRIM MAXLEN [~|=] (Same as XADD MAXLEN option) + /// ``` + #[cfg(feature = "streams")] + #[cfg_attr(docsrs, doc(cfg(feature = "streams")))] + fn xtrim( + key: K, + maxlen: streams::StreamMaxlen + ) { + cmd("XTRIM").arg(key).arg(maxlen) + } +} + +/// Allows pubsub callbacks to stop receiving messages. +/// +/// Arbitrary data may be returned from `Break`. +pub enum ControlFlow { + /// Continues. + Continue, + /// Breaks with a value. + Break(U), +} + +/// The PubSub trait allows subscribing to one or more channels +/// and receiving a callback whenever a message arrives. +/// +/// Each method handles subscribing to the list of keys, waiting for +/// messages, and unsubscribing from the same list of channels once +/// a ControlFlow::Break is encountered. +/// +/// Once (p)subscribe returns Ok(U), the connection is again safe to use +/// for calling other methods. +/// +/// # Examples +/// +/// ```rust,no_run +/// # fn do_something() -> redis::RedisResult<()> { +/// use redis::{PubSubCommands, ControlFlow}; +/// let client = redis::Client::open("redis://127.0.0.1/")?; +/// let mut con = client.get_connection(None)?; +/// let mut count = 0; +/// con.subscribe(&["foo"], |msg| { +/// // do something with message +/// assert_eq!(msg.get_channel(), Ok(String::from("foo"))); +/// +/// // increment messages seen counter +/// count += 1; +/// match count { +/// // stop after receiving 10 messages +/// 10 => ControlFlow::Break(()), +/// _ => ControlFlow::Continue, +/// } +/// })?; +/// # Ok(()) } +/// ``` +// TODO In the future, it would be nice to implement Try such that `?` will work +// within the closure. +pub trait PubSubCommands: Sized { + /// Subscribe to a list of channels using SUBSCRIBE and run the provided + /// closure for each message received. + /// + /// For every `Msg` passed to the provided closure, either + /// `ControlFlow::Break` or `ControlFlow::Continue` must be returned. This + /// method will not return until `ControlFlow::Break` is observed. + fn subscribe(&mut self, _: C, _: F) -> RedisResult + where + F: FnMut(Msg) -> ControlFlow, + C: ToRedisArgs; + + /// Subscribe to a list of channels using PSUBSCRIBE and run the provided + /// closure for each message received. + /// + /// For every `Msg` passed to the provided closure, either + /// `ControlFlow::Break` or `ControlFlow::Continue` must be returned. This + /// method will not return until `ControlFlow::Break` is observed. + fn psubscribe(&mut self, _: P, _: F) -> RedisResult + where + F: FnMut(Msg) -> ControlFlow, + P: ToRedisArgs; +} + +impl Commands for T where T: ConnectionLike {} + +#[cfg(feature = "aio")] +impl AsyncCommands for T where T: crate::aio::ConnectionLike + Send + Sized {} + +impl PubSubCommands for Connection { + fn subscribe(&mut self, channels: C, mut func: F) -> RedisResult + where + F: FnMut(Msg) -> ControlFlow, + C: ToRedisArgs, + { + let mut pubsub = self.as_pubsub(); + pubsub.subscribe(channels)?; + + loop { + let msg = pubsub.get_message()?; + match func(msg) { + ControlFlow::Continue => continue, + ControlFlow::Break(value) => return Ok(value), + } + } + } + + fn psubscribe(&mut self, patterns: P, mut func: F) -> RedisResult + where + F: FnMut(Msg) -> ControlFlow, + P: ToRedisArgs, + { + let mut pubsub = self.as_pubsub(); + pubsub.psubscribe(patterns)?; + + loop { + let msg = pubsub.get_message()?; + match func(msg) { + ControlFlow::Continue => continue, + ControlFlow::Break(value) => return Ok(value), + } + } + } +} + +/// Options for the [LPOS](https://redis.io/commands/lpos) command +/// +/// # Example +/// +/// ```rust,no_run +/// use redis::{Commands, RedisResult, LposOptions}; +/// fn fetch_list_position( +/// con: &mut redis::Connection, +/// key: &str, +/// value: &str, +/// count: usize, +/// rank: isize, +/// maxlen: usize, +/// ) -> RedisResult> { +/// let opts = LposOptions::default() +/// .count(count) +/// .rank(rank) +/// .maxlen(maxlen); +/// con.lpos(key, value, opts) +/// } +/// ``` +#[derive(Default)] +pub struct LposOptions { + count: Option, + maxlen: Option, + rank: Option, +} + +impl LposOptions { + /// Limit the results to the first N matching items. + pub fn count(mut self, n: usize) -> Self { + self.count = Some(n); + self + } + + /// Return the value of N from the matching items. + pub fn rank(mut self, n: isize) -> Self { + self.rank = Some(n); + self + } + + /// Limit the search to N items in the list. + pub fn maxlen(mut self, n: usize) -> Self { + self.maxlen = Some(n); + self + } +} + +impl ToRedisArgs for LposOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(n) = self.count { + out.write_arg(b"COUNT"); + out.write_arg_fmt(n); + } + + if let Some(n) = self.rank { + out.write_arg(b"RANK"); + out.write_arg_fmt(n); + } + + if let Some(n) = self.maxlen { + out.write_arg(b"MAXLEN"); + out.write_arg_fmt(n); + } + } + + fn is_single_arg(&self) -> bool { + false + } +} + +/// Enum for the LEFT | RIGHT args used by some commands +pub enum Direction { + /// Targets the first element (head) of the list + Left, + /// Targets the last element (tail) of the list + Right, +} + +impl ToRedisArgs for Direction { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let s: &[u8] = match self { + Direction::Left => b"LEFT", + Direction::Right => b"RIGHT", + }; + out.write_arg(s); + } +} + +/// Options for the [SET](https://redis.io/commands/set) command +/// +/// # Example +/// ```rust,no_run +/// use redis::{Commands, RedisResult, SetOptions, SetExpiry, ExistenceCheck}; +/// fn set_key_value( +/// con: &mut redis::Connection, +/// key: &str, +/// value: &str, +/// ) -> RedisResult> { +/// let opts = SetOptions::default() +/// .conditional_set(ExistenceCheck::NX) +/// .get(true) +/// .with_expiration(SetExpiry::EX(60)); +/// con.set_options(key, value, opts) +/// } +/// ``` +#[derive(Clone, Copy, Default)] +pub struct SetOptions { + conditional_set: Option, + get: bool, + expiration: Option, +} + +impl SetOptions { + /// Set the existence check for the SET command + pub fn conditional_set(mut self, existence_check: ExistenceCheck) -> Self { + self.conditional_set = Some(existence_check); + self + } + + /// Set the GET option for the SET command + pub fn get(mut self, get: bool) -> Self { + self.get = get; + self + } + + /// Set the expiration for the SET command + pub fn with_expiration(mut self, expiration: SetExpiry) -> Self { + self.expiration = Some(expiration); + self + } +} + +impl ToRedisArgs for SetOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref conditional_set) = self.conditional_set { + match conditional_set { + ExistenceCheck::NX => { + out.write_arg(b"NX"); + } + ExistenceCheck::XX => { + out.write_arg(b"XX"); + } + } + } + if self.get { + out.write_arg(b"GET"); + } + if let Some(ref expiration) = self.expiration { + match expiration { + SetExpiry::EX(secs) => { + out.write_arg(b"EX"); + out.write_arg(format!("{}", secs).as_bytes()); + } + SetExpiry::PX(millis) => { + out.write_arg(b"PX"); + out.write_arg(format!("{}", millis).as_bytes()); + } + SetExpiry::EXAT(unix_time) => { + out.write_arg(b"EXAT"); + out.write_arg(format!("{}", unix_time).as_bytes()); + } + SetExpiry::PXAT(unix_time) => { + out.write_arg(b"PXAT"); + out.write_arg(format!("{}", unix_time).as_bytes()); + } + SetExpiry::KEEPTTL => { + out.write_arg(b"KEEPTTL"); + } + } + } + } +} + +/// Creates HELLO command for RESP3 with RedisConnectionInfo +pub fn resp3_hello(connection_info: &RedisConnectionInfo) -> Cmd { + let mut hello_cmd = cmd("HELLO"); + hello_cmd.arg("3"); + if connection_info.password.is_some() { + let username: &str = match connection_info.username.as_ref() { + None => "default", + Some(username) => username, + }; + hello_cmd + .arg("AUTH") + .arg(username) + .arg(connection_info.password.as_ref().unwrap()); + } + hello_cmd +} diff --git a/glide-core/redis-rs/redis/src/connection.rs b/glide-core/redis-rs/redis/src/connection.rs new file mode 100644 index 0000000000..f75b9df494 --- /dev/null +++ b/glide-core/redis-rs/redis/src/connection.rs @@ -0,0 +1,1997 @@ +use std::collections::{HashSet, VecDeque}; +use std::fmt; +use std::io::{self, Write}; +use std::net::{self, SocketAddr, TcpStream, ToSocketAddrs}; +use std::ops::DerefMut; +use std::path::PathBuf; +use std::str::{from_utf8, FromStr}; +use std::time::Duration; + +use crate::cmd::{cmd, pipe, Cmd}; +use crate::parser::Parser; +use crate::pipeline::Pipeline; +use crate::types::{ + from_redis_value, ErrorKind, FromRedisValue, HashMap, PushKind, RedisError, RedisResult, + ToRedisArgs, Value, +}; +use crate::{from_owned_redis_value, ProtocolVersion}; + +#[cfg(unix)] +use std::os::unix::net::UnixStream; +use std::vec::IntoIter; + +use crate::commands::resp3_hello; +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +use native_tls::{TlsConnector, TlsStream}; + +#[cfg(feature = "tls-rustls")] +use rustls::{RootCertStore, StreamOwned}; +#[cfg(feature = "tls-rustls")] +use std::sync::Arc; + +use crate::push_manager::PushManager; +use crate::PushInfo; + +#[cfg(all( + feature = "tls-rustls", + not(feature = "tls-native-tls"), + not(feature = "tls-rustls-webpki-roots") +))] +use rustls_native_certs::load_native_certs; + +#[cfg(feature = "tls-rustls")] +use crate::tls::TlsConnParams; + +// Non-exhaustive to prevent construction outside this crate +#[cfg(not(feature = "tls-rustls"))] +#[derive(Clone, Debug)] +#[non_exhaustive] +pub struct TlsConnParams; + +static DEFAULT_PORT: u16 = 6379; + +#[inline(always)] +fn connect_tcp(addr: (&str, u16)) -> io::Result { + let socket = TcpStream::connect(addr)?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let socket2: socket2::Socket = socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + Ok(socket2.into()) + } + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} + +#[inline(always)] +fn connect_tcp_timeout(addr: &SocketAddr, timeout: Duration) -> io::Result { + let socket = TcpStream::connect_timeout(addr, timeout)?; + #[cfg(feature = "tcp_nodelay")] + socket.set_nodelay(true)?; + #[cfg(feature = "keep-alive")] + { + //For now rely on system defaults + const KEEP_ALIVE: socket2::TcpKeepalive = socket2::TcpKeepalive::new(); + //these are useless error that not going to happen + let socket2: socket2::Socket = socket.into(); + socket2.set_tcp_keepalive(&KEEP_ALIVE)?; + Ok(socket2.into()) + } + #[cfg(not(feature = "keep-alive"))] + { + Ok(socket) + } +} + +/// This function takes a redis URL string and parses it into a URL +/// as used by rust-url. This is necessary as the default parser does +/// not understand how redis URLs function. +pub fn parse_redis_url(input: &str) -> Option { + match url::Url::parse(input) { + Ok(result) => match result.scheme() { + "redis" | "rediss" | "redis+unix" | "unix" => Some(result), + _ => None, + }, + Err(_) => None, + } +} + +/// TlsMode indicates use or do not use verification of certification. +/// Check [ConnectionAddr](ConnectionAddr::TcpTls::insecure) for more. +#[derive(Clone, Copy)] +pub enum TlsMode { + /// Secure verify certification. + Secure, + /// Insecure do not verify certification. + Insecure, +} + +/// Defines the connection address. +/// +/// Not all connection addresses are supported on all platforms. For instance +/// to connect to a unix socket you need to run this on an operating system +/// that supports them. +#[derive(Clone, Debug)] +pub enum ConnectionAddr { + /// Format for this is `(host, port)`. + Tcp(String, u16), + /// Format for this is `(host, port)`. + TcpTls { + /// Hostname + host: String, + /// Port + port: u16, + /// Disable hostname verification when connecting. + /// + /// # Warning + /// + /// You should think very carefully before you use this method. If hostname + /// verification is not used, any valid certificate for any site will be + /// trusted for use from any other. This introduces a significant + /// vulnerability to man-in-the-middle attacks. + insecure: bool, + + /// TLS certificates and client key. + tls_params: Option, + }, + /// Format for this is the path to the unix socket. + Unix(PathBuf), +} + +impl PartialEq for ConnectionAddr { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (ConnectionAddr::Tcp(host1, port1), ConnectionAddr::Tcp(host2, port2)) => { + host1 == host2 && port1 == port2 + } + ( + ConnectionAddr::TcpTls { + host: host1, + port: port1, + insecure: insecure1, + tls_params: _, + }, + ConnectionAddr::TcpTls { + host: host2, + port: port2, + insecure: insecure2, + tls_params: _, + }, + ) => port1 == port2 && host1 == host2 && insecure1 == insecure2, + (ConnectionAddr::Unix(path1), ConnectionAddr::Unix(path2)) => path1 == path2, + _ => false, + } + } +} + +impl Eq for ConnectionAddr {} + +impl ConnectionAddr { + /// Checks if this address is supported. + /// + /// Because not all platforms support all connection addresses this is a + /// quick way to figure out if a connection method is supported. Currently + /// this only affects unix connections which are only supported on unix + /// platforms and on older versions of rust also require an explicit feature + /// to be enabled. + pub fn is_supported(&self) -> bool { + match *self { + ConnectionAddr::Tcp(_, _) => true, + ConnectionAddr::TcpTls { .. } => { + cfg!(any(feature = "tls-native-tls", feature = "tls-rustls")) + } + ConnectionAddr::Unix(_) => cfg!(unix), + } + } +} + +impl fmt::Display for ConnectionAddr { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + // Cluster::get_connection_info depends on the return value from this function + match *self { + ConnectionAddr::Tcp(ref host, port) => write!(f, "{host}:{port}"), + ConnectionAddr::TcpTls { ref host, port, .. } => write!(f, "{host}:{port}"), + ConnectionAddr::Unix(ref path) => write!(f, "{}", path.display()), + } + } +} + +/// Holds the connection information that redis should use for connecting. +#[derive(Clone, Debug)] +pub struct ConnectionInfo { + /// A connection address for where to connect to. + pub addr: ConnectionAddr, + + /// A boxed connection address for where to connect to. + pub redis: RedisConnectionInfo, +} + +/// Types of pubsub subscriptions +/// See for more details +#[derive(Clone, Debug, PartialEq, Eq, Hash, Copy)] +pub enum PubSubSubscriptionKind { + /// Exact channel name. + /// Receives messages which are published to a specific channel using PUBLISH command. + Exact = 0, + /// Pattern-based channel name. + /// Receives messages which are published to channels matched by glob pattern using PUBLISH command. + Pattern = 1, + /// Sharded pubsub mode. + /// Receives messages which are published to a specific channel using SPUBLISH command. + Sharded = 2, +} + +impl From for usize { + fn from(val: PubSubSubscriptionKind) -> Self { + val as usize + } +} + +/// Type for pubsub channels/patterns +pub type PubSubChannelOrPattern = Vec; + +/// Type for pubsub channels/patterns +pub type PubSubSubscriptionInfo = HashMap>; + +/// Redis specific/connection independent information used to establish a connection to redis. +#[derive(Clone, Debug, Default)] +pub struct RedisConnectionInfo { + /// The database number to use. This is usually `0`. + pub db: i64, + /// Optionally a username that should be used for connection. + pub username: Option, + /// Optionally a password that should be used for connection. + pub password: Option, + /// Version of the protocol to use. + pub protocol: ProtocolVersion, + /// Optionally a client name that should be used for connection + pub client_name: Option, + /// Optionally a pubsub subscriptions that should be used for connection + pub pubsub_subscriptions: Option, +} + +impl FromStr for ConnectionInfo { + type Err = RedisError; + + fn from_str(s: &str) -> Result { + s.into_connection_info() + } +} + +/// Converts an object into a connection info struct. This allows the +/// constructor of the client to accept connection information in a +/// range of different formats. +pub trait IntoConnectionInfo { + /// Converts the object into a connection info object. + fn into_connection_info(self) -> RedisResult; +} + +impl IntoConnectionInfo for ConnectionInfo { + fn into_connection_info(self) -> RedisResult { + Ok(self) + } +} + +/// URL format: `{redis|rediss}://[][:@][:port][/]` +/// +/// - Basic: `redis://127.0.0.1:6379` +/// - Username & Password: `redis://user:password@127.0.0.1:6379` +/// - Password only: `redis://:password@127.0.0.1:6379` +/// - Specifying DB: `redis://127.0.0.1:6379/0` +/// - Enabling TLS: `rediss://127.0.0.1:6379` +/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure` +impl<'a> IntoConnectionInfo for &'a str { + fn into_connection_info(self) -> RedisResult { + match parse_redis_url(self) { + Some(u) => u.into_connection_info(), + None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")), + } + } +} + +impl IntoConnectionInfo for (T, u16) +where + T: Into, +{ + fn into_connection_info(self) -> RedisResult { + Ok(ConnectionInfo { + addr: ConnectionAddr::Tcp(self.0.into(), self.1), + redis: RedisConnectionInfo::default(), + }) + } +} + +/// URL format: `{redis|rediss}://[][:@][:port][/]` +/// +/// - Basic: `redis://127.0.0.1:6379` +/// - Username & Password: `redis://user:password@127.0.0.1:6379` +/// - Password only: `redis://:password@127.0.0.1:6379` +/// - Specifying DB: `redis://127.0.0.1:6379/0` +/// - Enabling TLS: `rediss://127.0.0.1:6379` +/// - Enabling Insecure TLS: `rediss://127.0.0.1:6379/#insecure` +impl IntoConnectionInfo for String { + fn into_connection_info(self) -> RedisResult { + match parse_redis_url(&self) { + Some(u) => u.into_connection_info(), + None => fail!((ErrorKind::InvalidClientConfig, "Redis URL did not parse")), + } + } +} + +fn url_to_tcp_connection_info(url: url::Url) -> RedisResult { + let host = match url.host() { + Some(host) => { + // Here we manually match host's enum arms and call their to_string(). + // Because url.host().to_string() will add `[` and `]` for ipv6: + // https://docs.rs/url/latest/src/url/host.rs.html#170 + // And these brackets will break host.parse::() when + // `client.open()` - `ActualConnection::new()` - `addr.to_socket_addrs()`: + // https://doc.rust-lang.org/src/std/net/addr.rs.html#963 + // https://doc.rust-lang.org/src/std/net/parser.rs.html#158 + // IpAddr string with brackets can ONLY parse to SocketAddrV6: + // https://doc.rust-lang.org/src/std/net/parser.rs.html#255 + // But if we call Ipv6Addr.to_string directly, it follows rfc5952 without brackets: + // https://doc.rust-lang.org/src/std/net/ip.rs.html#1755 + match host { + url::Host::Domain(path) => path.to_string(), + url::Host::Ipv4(v4) => v4.to_string(), + url::Host::Ipv6(v6) => v6.to_string(), + } + } + None => fail!((ErrorKind::InvalidClientConfig, "Missing hostname")), + }; + let port = url.port().unwrap_or(DEFAULT_PORT); + let addr = if url.scheme() == "rediss" { + #[cfg(any(feature = "tls-native-tls", feature = "tls-rustls"))] + { + match url.fragment() { + Some("insecure") => ConnectionAddr::TcpTls { + host, + port, + insecure: true, + tls_params: None, + }, + Some(_) => fail!(( + ErrorKind::InvalidClientConfig, + "only #insecure is supported as URL fragment" + )), + _ => ConnectionAddr::TcpTls { + host, + port, + insecure: false, + tls_params: None, + }, + } + } + + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] + fail!(( + ErrorKind::InvalidClientConfig, + "can't connect with TLS, the feature is not enabled" + )); + } else { + ConnectionAddr::Tcp(host, port) + }; + let query: HashMap<_, _> = url.query_pairs().collect(); + Ok(ConnectionInfo { + addr, + redis: RedisConnectionInfo { + db: match url.path().trim_matches('/') { + "" => 0, + path => path.parse::().map_err(|_| -> RedisError { + (ErrorKind::InvalidClientConfig, "Invalid database number").into() + })?, + }, + username: if url.username().is_empty() { + None + } else { + match percent_encoding::percent_decode(url.username().as_bytes()).decode_utf8() { + Ok(decoded) => Some(decoded.into_owned()), + Err(_) => fail!(( + ErrorKind::InvalidClientConfig, + "Username is not valid UTF-8 string" + )), + } + }, + password: match url.password() { + Some(pw) => match percent_encoding::percent_decode(pw.as_bytes()).decode_utf8() { + Ok(decoded) => Some(decoded.into_owned()), + Err(_) => fail!(( + ErrorKind::InvalidClientConfig, + "Password is not valid UTF-8 string" + )), + }, + None => None, + }, + protocol: match query.get("resp3") { + Some(v) => { + if v == "true" { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::RESP2 + } + } + _ => ProtocolVersion::RESP2, + }, + client_name: None, + pubsub_subscriptions: None, + }, + }) +} + +#[cfg(unix)] +fn url_to_unix_connection_info(url: url::Url) -> RedisResult { + let query: HashMap<_, _> = url.query_pairs().collect(); + Ok(ConnectionInfo { + addr: ConnectionAddr::Unix(url.to_file_path().map_err(|_| -> RedisError { + (ErrorKind::InvalidClientConfig, "Missing path").into() + })?), + redis: RedisConnectionInfo { + db: match query.get("db") { + Some(db) => db.parse::().map_err(|_| -> RedisError { + (ErrorKind::InvalidClientConfig, "Invalid database number").into() + })?, + + None => 0, + }, + username: query.get("user").map(|username| username.to_string()), + password: query.get("pass").map(|password| password.to_string()), + protocol: match query.get("resp3") { + Some(v) => { + if v == "true" { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::RESP2 + } + } + _ => ProtocolVersion::RESP2, + }, + client_name: None, + pubsub_subscriptions: None, + }, + }) +} + +#[cfg(not(unix))] +fn url_to_unix_connection_info(_: url::Url) -> RedisResult { + fail!(( + ErrorKind::InvalidClientConfig, + "Unix sockets are not available on this platform." + )); +} + +impl IntoConnectionInfo for url::Url { + fn into_connection_info(self) -> RedisResult { + match self.scheme() { + "redis" | "rediss" => url_to_tcp_connection_info(self), + "unix" | "redis+unix" => url_to_unix_connection_info(self), + _ => fail!(( + ErrorKind::InvalidClientConfig, + "URL provided is not a redis URL" + )), + } + } +} + +struct TcpConnection { + reader: TcpStream, + open: bool, +} + +#[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] +struct TcpNativeTlsConnection { + reader: TlsStream, + open: bool, +} + +#[cfg(feature = "tls-rustls")] +struct TcpRustlsConnection { + reader: StreamOwned, + open: bool, +} + +#[cfg(unix)] +struct UnixConnection { + sock: UnixStream, + open: bool, +} + +enum ActualConnection { + Tcp(TcpConnection), + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + TcpNativeTls(Box), + #[cfg(feature = "tls-rustls")] + TcpRustls(Box), + #[cfg(unix)] + Unix(UnixConnection), +} + +#[cfg(feature = "tls-rustls-insecure")] +struct NoCertificateVerification { + supported: rustls::crypto::WebPkiSupportedAlgorithms, +} + +#[cfg(feature = "tls-rustls-insecure")] +impl rustls::client::danger::ServerCertVerifier for NoCertificateVerification { + fn verify_server_cert( + &self, + _end_entity: &rustls_pki_types::CertificateDer<'_>, + _intermediates: &[rustls_pki_types::CertificateDer<'_>], + _server_name: &rustls_pki_types::ServerName<'_>, + _ocsp_response: &[u8], + _now: rustls_pki_types::UnixTime, + ) -> Result { + Ok(rustls::client::danger::ServerCertVerified::assertion()) + } + + fn verify_tls12_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn verify_tls13_signature( + &self, + _message: &[u8], + _cert: &rustls_pki_types::CertificateDer<'_>, + _dss: &rustls::DigitallySignedStruct, + ) -> Result { + Ok(rustls::client::danger::HandshakeSignatureValid::assertion()) + } + + fn supported_verify_schemes(&self) -> Vec { + self.supported.supported_schemes() + } +} + +#[cfg(feature = "tls-rustls-insecure")] +impl fmt::Debug for NoCertificateVerification { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.debug_struct("NoCertificateVerification").finish() + } +} + +/// Represents a stateful redis TCP connection. +pub struct Connection { + con: ActualConnection, + parser: Parser, + db: i64, + + /// Flag indicating whether the connection was left in the PubSub state after dropping `PubSub`. + /// + /// This flag is checked when attempting to send a command, and if it's raised, we attempt to + /// exit the pubsub state before executing the new request. + pubsub: bool, + + // Field indicating which protocol to use for server communications. + protocol: ProtocolVersion, + + /// `PushManager` instance for the connection. + /// This is used to manage Push messages in RESP3 mode. + push_manager: PushManager, +} + +/// Represents a pubsub connection. +pub struct PubSub<'a> { + con: &'a mut Connection, + waiting_messages: VecDeque, +} + +/// Represents a pubsub message. +#[derive(Debug)] +pub struct Msg { + payload: Value, + channel: Value, + pattern: Option, +} + +impl ActualConnection { + pub fn new(addr: &ConnectionAddr, timeout: Option) -> RedisResult { + Ok(match *addr { + ConnectionAddr::Tcp(ref host, ref port) => { + let addr = (host.as_str(), *port); + let tcp = match timeout { + None => connect_tcp(addr)?, + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in addr.to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => tcp, + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + ActualConnection::Tcp(TcpConnection { + reader: tcp, + open: true, + }) + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + .. + } => { + let tls_connector = if insecure { + TlsConnector::builder() + .danger_accept_invalid_certs(true) + .danger_accept_invalid_hostnames(true) + .use_sni(false) + .build()? + } else { + TlsConnector::new()? + }; + let addr = (host.as_str(), port); + let tls = match timeout { + None => { + let tcp = connect_tcp(addr)?; + match tls_connector.connect(host, tcp) { + Ok(res) => res, + Err(e) => { + fail!((ErrorKind::IoError, "SSL Handshake error", e.to_string())); + } + } + } + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in (host.as_str(), port).to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => tls_connector.connect(host, tcp).unwrap(), + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + ActualConnection::TcpNativeTls(Box::new(TcpNativeTlsConnection { + reader: tls, + open: true, + })) + } + #[cfg(feature = "tls-rustls")] + ConnectionAddr::TcpTls { + ref host, + port, + insecure, + ref tls_params, + } => { + let host: &str = host; + let config = create_rustls_config(insecure, tls_params.clone())?; + let conn = rustls::ClientConnection::new( + Arc::new(config), + rustls_pki_types::ServerName::try_from(host)?.to_owned(), + )?; + let reader = match timeout { + None => { + let tcp = connect_tcp((host, port))?; + StreamOwned::new(conn, tcp) + } + Some(timeout) => { + let mut tcp = None; + let mut last_error = None; + for addr in (host, port).to_socket_addrs()? { + match connect_tcp_timeout(&addr, timeout) { + Ok(l) => { + tcp = Some(l); + break; + } + Err(e) => { + last_error = Some(e); + } + }; + } + match (tcp, last_error) { + (Some(tcp), _) => StreamOwned::new(conn, tcp), + (None, Some(e)) => { + fail!(e); + } + (None, None) => { + fail!(( + ErrorKind::InvalidClientConfig, + "could not resolve to any addresses" + )); + } + } + } + }; + + ActualConnection::TcpRustls(Box::new(TcpRustlsConnection { reader, open: true })) + } + #[cfg(not(any(feature = "tls-native-tls", feature = "tls-rustls")))] + ConnectionAddr::TcpTls { .. } => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to TCP with TLS without the tls feature" + )); + } + #[cfg(unix)] + ConnectionAddr::Unix(ref path) => ActualConnection::Unix(UnixConnection { + sock: UnixStream::connect(path)?, + open: true, + }), + #[cfg(not(unix))] + ConnectionAddr::Unix(ref _path) => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot connect to unix sockets \ + on this platform" + )); + } + }) + } + + pub fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult { + match *self { + ActualConnection::Tcp(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { + let res = connection.reader.write_all(bytes).map_err(RedisError::from); + match res { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + #[cfg(unix)] + ActualConnection::Unix(ref mut connection) => { + let result = connection.sock.write_all(bytes).map_err(RedisError::from); + match result { + Err(e) => { + if e.is_unrecoverable_error() { + connection.open = false; + } + Err(e) + } + Ok(_) => Ok(Value::Okay), + } + } + } + } + + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + match *self { + ActualConnection::Tcp(TcpConnection { ref reader, .. }) => { + reader.set_write_timeout(dur)?; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_write_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_write_timeout(dur)?; + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref sock, .. }) => { + sock.set_write_timeout(dur)?; + } + } + Ok(()) + } + + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + match *self { + ActualConnection::Tcp(TcpConnection { ref reader, .. }) => { + reader.set_read_timeout(dur)?; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_read_timeout(dur)?; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => { + let reader = &(boxed_tls_connection.reader); + reader.get_ref().set_read_timeout(dur)?; + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref sock, .. }) => { + sock.set_read_timeout(dur)?; + } + } + Ok(()) + } + + pub fn is_open(&self) -> bool { + match *self { + ActualConnection::Tcp(TcpConnection { open, .. }) => open, + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref boxed_tls_connection) => boxed_tls_connection.open, + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { open, .. }) => open, + } + } +} + +#[cfg(feature = "tls-rustls")] +pub(crate) fn create_rustls_config( + insecure: bool, + tls_params: Option, +) -> RedisResult { + use crate::tls::ClientTlsParams; + + #[allow(unused_mut)] + let mut root_store = RootCertStore::empty(); + #[cfg(feature = "tls-rustls-webpki-roots")] + root_store.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned()); + #[cfg(all( + feature = "tls-rustls", + not(feature = "tls-native-tls"), + not(feature = "tls-rustls-webpki-roots") + ))] + for cert in load_native_certs()? { + root_store.add(cert)?; + } + + let config = rustls::ClientConfig::builder(); + let config = if let Some(tls_params) = tls_params { + let config_builder = + config.with_root_certificates(tls_params.root_cert_store.unwrap_or(root_store)); + + if let Some(ClientTlsParams { + client_cert_chain: client_cert, + client_key, + }) = tls_params.client_tls_params + { + config_builder + .with_client_auth_cert(client_cert, client_key) + .map_err(|err| { + RedisError::from(( + ErrorKind::InvalidClientConfig, + "Unable to build client with TLS parameters provided.", + err.to_string(), + )) + })? + } else { + config_builder.with_no_client_auth() + } + } else { + config + .with_root_certificates(root_store) + .with_no_client_auth() + }; + + match (insecure, cfg!(feature = "tls-rustls-insecure")) { + #[cfg(feature = "tls-rustls-insecure")] + (true, true) => { + let mut config = config; + config.enable_sni = false; + // nosemgrep + config + .dangerous() + .set_certificate_verifier(Arc::new(NoCertificateVerification { + supported: rustls::crypto::ring::default_provider() + .signature_verification_algorithms, + })); + + Ok(config) + } + (true, false) => { + fail!(( + ErrorKind::InvalidClientConfig, + "Cannot create insecure client without tls-rustls-insecure feature" + )); + } + _ => Ok(config), + } +} + +fn connect_auth(con: &mut Connection, connection_info: &RedisConnectionInfo) -> RedisResult<()> { + let mut command = cmd("AUTH"); + if let Some(username) = &connection_info.username { + command.arg(username); + } + let password = connection_info.password.as_ref().unwrap(); + let err = match command.arg(password).query::(con) { + Ok(Value::Okay) => return Ok(()), + Ok(_) => { + fail!(( + ErrorKind::ResponseError, + "Redis server refused to authenticate, returns Ok() != Value::Okay" + )); + } + Err(e) => e, + }; + let err_msg = err.detail().ok_or(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + ))?; + if !err_msg.contains("wrong number of arguments for 'auth' command") { + fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + )); + } + + // fallback to AUTH version <= 5 + let mut command = cmd("AUTH"); + match command.arg(password).query::(con) { + Ok(Value::Okay) => Ok(()), + _ => fail!(( + ErrorKind::AuthenticationFailed, + "Password authentication failed", + )), + } +} + +pub fn connect( + connection_info: &ConnectionInfo, + timeout: Option, +) -> RedisResult { + let con = ActualConnection::new(&connection_info.addr, timeout)?; + setup_connection(con, &connection_info.redis) +} + +#[cfg(not(feature = "disable-client-setinfo"))] +pub(crate) fn client_set_info_pipeline() -> Pipeline { + let mut pipeline = crate::pipe(); + pipeline + .cmd("CLIENT") + .arg("SETINFO") + .arg("LIB-NAME") + .arg(std::env!("GLIDE_NAME")) + .ignore(); + pipeline + .cmd("CLIENT") + .arg("SETINFO") + .arg("LIB-VER") + .arg(std::env!("GLIDE_VERSION")) + .ignore(); + pipeline +} + +fn setup_connection( + con: ActualConnection, + connection_info: &RedisConnectionInfo, +) -> RedisResult { + let mut rv = Connection { + con, + parser: Parser::new(), + db: connection_info.db, + pubsub: false, + protocol: connection_info.protocol, + push_manager: PushManager::new(), + }; + + if connection_info.protocol != ProtocolVersion::RESP2 { + let hello_cmd = resp3_hello(connection_info); + let val: RedisResult = hello_cmd.query(&mut rv); + if let Err(err) = val { + return Err(get_resp3_hello_command_error(err)); + } + } else if connection_info.password.is_some() { + connect_auth(&mut rv, connection_info)?; + } + if connection_info.db != 0 { + match cmd("SELECT") + .arg(connection_info.db) + .query::(&mut rv) + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to switch database" + )), + } + } + + if connection_info.client_name.is_some() { + match cmd("CLIENT") + .arg("SETNAME") + .arg(connection_info.client_name.as_ref().unwrap()) + .query::(&mut rv) + { + Ok(Value::Okay) => {} + _ => fail!(( + ErrorKind::ResponseError, + "Redis server refused to set client name" + )), + } + } + + // result is ignored, as per the command's instructions. + // https://redis.io/commands/client-setinfo/ + #[cfg(not(feature = "disable-client-setinfo"))] + let _: RedisResult<()> = client_set_info_pipeline().query(&mut rv); + + Ok(rv) +} + +/// Implements the "stateless" part of the connection interface that is used by the +/// different objects in redis-rs. Primarily it obviously applies to `Connection` +/// object but also some other objects implement the interface (for instance +/// whole clients or certain redis results). +/// +/// Generally clients and connections (as well as redis results of those) implement +/// this trait. Actual connections provide more functionality which can be used +/// to implement things like `PubSub` but they also can modify the intrinsic +/// state of the TCP connection. This is not possible with `ConnectionLike` +/// implementors because that functionality is not exposed. +pub trait ConnectionLike { + /// Sends an already encoded (packed) command into the TCP socket and + /// reads the single response from it. + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult; + + /// Sends multiple already encoded (packed) command into the TCP socket + /// and reads `count` responses from it. This is used to implement + /// pipelining. + /// Important - this function is meant for internal usage, since it's + /// easy to pass incorrect `offset` & `count` parameters, which might + /// cause the connection to enter an erroneous state. Users shouldn't + /// call it, instead using the Pipeline::query function. + #[doc(hidden)] + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult>; + + /// Sends a [Cmd] into the TCP socket and reads a single response from it. + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + let pcmd = cmd.get_packed_command(); + self.req_packed_command(&pcmd) + } + + /// Returns the database this connection is bound to. Note that this + /// information might be unreliable because it's initially cached and + /// also might be incorrect if the connection like object is not + /// actually connected. + fn get_db(&self) -> i64; + + /// Does this connection support pipelining? + #[doc(hidden)] + fn supports_pipelining(&self) -> bool { + true + } + + /// Check that all connections it has are available (`PING` internally). + fn check_connection(&mut self) -> bool; + + /// Returns the connection status. + /// + /// The connection is open until any `read_response` call recieved an + /// invalid response from the server (most likely a closed or dropped + /// connection, otherwise a Redis protocol error). When using unix + /// sockets the connection is open until writing a command failed with a + /// `BrokenPipe` error. + fn is_open(&self) -> bool; +} + +/// A connection is an object that represents a single redis connection. It +/// provides basic support for sending encoded commands into a redis connection +/// and to read a response from it. It's bound to a single database and can +/// only be created from the client. +/// +/// You generally do not much with this object other than passing it to +/// `Cmd` objects. +impl Connection { + /// Sends an already encoded (packed) command into the TCP socket and + /// does not read a response. This is useful for commands like + /// `MONITOR` which yield multiple items. This needs to be used with + /// care because it changes the state of the connection. + pub fn send_packed_command(&mut self, cmd: &[u8]) -> RedisResult<()> { + self.send_bytes(cmd)?; + Ok(()) + } + + /// Fetches a single response from the connection. This is useful + /// if used in combination with `send_packed_command`. + pub fn recv_response(&mut self) -> RedisResult { + self.read_response() + } + + /// Sets the write timeout for the connection. + /// + /// If the provided value is `None`, then `send_packed_command` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_write_timeout(&self, dur: Option) -> RedisResult<()> { + self.con.set_write_timeout(dur) + } + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `recv_response` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + self.con.set_read_timeout(dur) + } + + /// Creates a [`PubSub`] instance for this connection. + pub fn as_pubsub(&mut self) -> PubSub<'_> { + // NOTE: The pubsub flag is intentionally not raised at this time since + // running commands within the pubsub state should not try and exit from + // the pubsub state. + PubSub::new(self) + } + + fn exit_pubsub(&mut self) -> RedisResult<()> { + let res = self.clear_active_subscriptions(); + if res.is_ok() { + self.pubsub = false; + } else { + // Raise the pubsub flag to indicate the connection is "stuck" in that state. + self.pubsub = true; + } + + res + } + + /// Get the inner connection out of a PubSub + /// + /// Any active subscriptions are unsubscribed. In the event of an error, the connection is + /// dropped. + fn clear_active_subscriptions(&mut self) -> RedisResult<()> { + // Responses to unsubscribe commands return in a 3-tuple with values + // ("unsubscribe" or "punsubscribe", name of subscription removed, count of remaining subs). + // The "count of remaining subs" includes both pattern subscriptions and non pattern + // subscriptions. Thus, to accurately drain all unsubscribe messages received from the + // server, both commands need to be executed at once. + { + // Prepare both unsubscribe commands + let unsubscribe = cmd("UNSUBSCRIBE").get_packed_command(); + let punsubscribe = cmd("PUNSUBSCRIBE").get_packed_command(); + + // Execute commands + self.send_bytes(&unsubscribe)?; + self.send_bytes(&punsubscribe)?; + } + + // Receive responses + // + // There will be at minimum two responses - 1 for each of punsubscribe and unsubscribe + // commands. There may be more responses if there are active subscriptions. In this case, + // messages are received until the _subscription count_ in the responses reach zero. + let mut received_unsub = false; + let mut received_punsub = false; + if self.protocol != ProtocolVersion::RESP2 { + while let Value::Push { kind, data } = from_owned_redis_value(self.recv_response()?)? { + if data.len() >= 2 { + if let Value::Int(num) = data[1] { + if resp3_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &kind, + num as isize, + ) { + break; + } + } + } + } + } else { + loop { + let res: (Vec, (), isize) = from_owned_redis_value(self.recv_response()?)?; + if resp2_is_pub_sub_state_cleared( + &mut received_unsub, + &mut received_punsub, + &res.0, + res.2, + ) { + break; + } + } + } + + // Finally, the connection is back in its normal state since all subscriptions were + // cancelled *and* all unsubscribe messages were received. + Ok(()) + } + + /// Fetches a single response from the connection. + fn read_response(&mut self) -> RedisResult { + let result = match self.con { + ActualConnection::Tcp(TcpConnection { ref mut reader, .. }) => { + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut boxed_tls_connection) => { + let reader = &mut boxed_tls_connection.reader; + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut boxed_tls_connection) => { + let reader = &mut boxed_tls_connection.reader; + let result = self.parser.parse_value(reader); + self.push_manager.try_send(&result); + result + } + #[cfg(unix)] + ActualConnection::Unix(UnixConnection { ref mut sock, .. }) => { + let result = self.parser.parse_value(sock); + self.push_manager.try_send(&result); + result + } + }; + // shutdown connection on protocol error + if let Err(e) = &result { + let shutdown = match e.as_io_error() { + Some(e) => e.kind() == io::ErrorKind::UnexpectedEof, + None => false, + }; + if shutdown { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + match self.con { + ActualConnection::Tcp(ref mut connection) => { + let _ = connection.reader.shutdown(net::Shutdown::Both); + connection.open = false; + } + #[cfg(all(feature = "tls-native-tls", not(feature = "tls-rustls")))] + ActualConnection::TcpNativeTls(ref mut connection) => { + let _ = connection.reader.shutdown(); + connection.open = false; + } + #[cfg(feature = "tls-rustls")] + ActualConnection::TcpRustls(ref mut connection) => { + let _ = connection.reader.get_mut().shutdown(net::Shutdown::Both); + connection.open = false; + } + #[cfg(unix)] + ActualConnection::Unix(ref mut connection) => { + let _ = connection.sock.shutdown(net::Shutdown::Both); + connection.open = false; + } + } + } + } + result + } + + /// Returns `PushManager` of Connection, this method is used to subscribe/unsubscribe from Push types + pub fn get_push_manager(&self) -> PushManager { + self.push_manager.clone() + } + + fn send_bytes(&mut self, bytes: &[u8]) -> RedisResult { + let result = self.con.send_bytes(bytes); + if self.protocol != ProtocolVersion::RESP2 { + if let Err(e) = &result { + if e.is_connection_dropped() { + // Notify the PushManager that the connection was lost + self.push_manager.try_send_raw(&Value::Push { + kind: PushKind::Disconnection, + data: vec![], + }); + } + } + } + result + } +} + +impl ConnectionLike for Connection { + /// Sends a [Cmd] into the TCP socket and reads a single response from it. + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + let pcmd = cmd.get_packed_command(); + if self.pubsub { + self.exit_pubsub()?; + } + + self.send_bytes(&pcmd)?; + if cmd.is_no_response() { + return Ok(Value::Nil); + } + loop { + match self.read_response()? { + Value::Push { + kind: _kind, + data: _data, + } => continue, + val => return Ok(val), + } + } + } + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + if self.pubsub { + self.exit_pubsub()?; + } + + self.send_bytes(cmd)?; + loop { + match self.read_response()? { + Value::Push { + kind: _kind, + data: _data, + } => continue, + val => return Ok(val), + } + } + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + if self.pubsub { + self.exit_pubsub()?; + } + self.send_bytes(cmd)?; + let mut rv = vec![]; + let mut first_err = None; + let mut count = count; + let mut idx = 0; + while idx < (offset + count) { + // When processing a transaction, some responses may be errors. + // We need to keep processing the rest of the responses in that case, + // so bailing early with `?` would not be correct. + // See: https://github.com/redis-rs/redis-rs/issues/436 + let response = self.read_response(); + match response { + Ok(item) => { + // RESP3 can insert push data between command replies + if let Value::Push { + kind: _kind, + data: _data, + } = item + { + // if that is the case we have to extend the loop and handle push data + count += 1; + } else if idx >= offset { + rv.push(item); + } + } + Err(err) => { + if first_err.is_none() { + first_err = Some(err); + } + } + } + idx += 1; + } + + first_err.map_or(Ok(rv), Err) + } + + fn get_db(&self) -> i64 { + self.db + } + + fn check_connection(&mut self) -> bool { + cmd("PING").query::(self).is_ok() + } + + fn is_open(&self) -> bool { + self.con.is_open() + } +} + +impl ConnectionLike for T +where + C: ConnectionLike, + T: DerefMut, +{ + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + self.deref_mut().req_packed_command(cmd) + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + count: usize, + ) -> RedisResult> { + self.deref_mut().req_packed_commands(cmd, offset, count) + } + + fn req_command(&mut self, cmd: &Cmd) -> RedisResult { + self.deref_mut().req_command(cmd) + } + + fn get_db(&self) -> i64 { + self.deref().get_db() + } + + fn supports_pipelining(&self) -> bool { + self.deref().supports_pipelining() + } + + fn check_connection(&mut self) -> bool { + self.deref_mut().check_connection() + } + + fn is_open(&self) -> bool { + self.deref().is_open() + } +} + +/// The pubsub object provides convenient access to the redis pubsub +/// system. Once created you can subscribe and unsubscribe from channels +/// and listen in on messages. +/// +/// Example: +/// +/// ```rust,no_run +/// # fn do_something() -> redis::RedisResult<()> { +/// let client = redis::Client::open("redis://127.0.0.1/")?; +/// let mut con = client.get_connection(None)?; +/// let mut pubsub = con.as_pubsub(); +/// pubsub.subscribe("channel_1")?; +/// pubsub.subscribe("channel_2")?; +/// +/// loop { +/// let msg = pubsub.get_message()?; +/// let payload : String = msg.get_payload()?; +/// println!("channel '{}': {}", msg.get_channel_name(), payload); +/// } +/// # } +/// ``` +impl<'a> PubSub<'a> { + fn new(con: &'a mut Connection) -> Self { + Self { + con, + waiting_messages: VecDeque::new(), + } + } + + fn cache_messages_until_received_response(&mut self, cmd: &mut Cmd) -> RedisResult<()> { + if self.con.protocol != ProtocolVersion::RESP2 { + cmd.set_no_response(true); + } + let mut response = cmd.query(self.con)?; + loop { + if let Some(msg) = Msg::from_value(&response) { + self.waiting_messages.push_back(msg); + } else { + return Ok(()); + } + response = self.con.recv_response()?; + } + } + + /// Subscribes to a new channel. + pub fn subscribe(&mut self, channel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("SUBSCRIBE").arg(channel)) + } + + /// Subscribes to a new channel with a pattern. + pub fn psubscribe(&mut self, pchannel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("PSUBSCRIBE").arg(pchannel)) + } + + /// Unsubscribes from a channel. + pub fn unsubscribe(&mut self, channel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("UNSUBSCRIBE").arg(channel)) + } + + /// Unsubscribes from a channel with a pattern. + pub fn punsubscribe(&mut self, pchannel: T) -> RedisResult<()> { + self.cache_messages_until_received_response(cmd("PUNSUBSCRIBE").arg(pchannel)) + } + + /// Fetches the next message from the pubsub connection. Blocks until + /// a message becomes available. This currently does not provide a + /// wait not to block :( + /// + /// The message itself is still generic and can be converted into an + /// appropriate type through the helper methods on it. + pub fn get_message(&mut self) -> RedisResult { + if let Some(msg) = self.waiting_messages.pop_front() { + return Ok(msg); + } + loop { + if let Some(msg) = Msg::from_value(&self.con.recv_response()?) { + return Ok(msg); + } else { + continue; + } + } + } + + /// Sets the read timeout for the connection. + /// + /// If the provided value is `None`, then `get_message` call will + /// block indefinitely. It is an error to pass the zero `Duration` to this + /// method. + pub fn set_read_timeout(&self, dur: Option) -> RedisResult<()> { + self.con.set_read_timeout(dur) + } +} + +impl<'a> Drop for PubSub<'a> { + fn drop(&mut self) { + let _ = self.con.exit_pubsub(); + } +} + +/// This holds the data that comes from listening to a pubsub +/// connection. It only contains actual message data. +impl Msg { + /// Tries to convert provided [`Value`] into [`Msg`]. + #[allow(clippy::unnecessary_to_owned)] + pub fn from_value(value: &Value) -> Option { + let mut pattern = None; + let payload; + let channel; + + if let Value::Push { kind, data } = value { + let mut iter: IntoIter = data.to_vec().into_iter(); + if kind == &PushKind::Message || kind == &PushKind::SMessage { + channel = iter.next()?; + payload = iter.next()?; + } else if kind == &PushKind::PMessage { + pattern = Some(iter.next()?); + channel = iter.next()?; + payload = iter.next()?; + } else { + return None; + } + } else { + let raw_msg: Vec = from_redis_value(value).ok()?; + let mut iter = raw_msg.into_iter(); + let msg_type: String = from_owned_redis_value(iter.next()?).ok()?; + if msg_type == "message" { + channel = iter.next()?; + payload = iter.next()?; + } else if msg_type == "pmessage" { + pattern = Some(iter.next()?); + channel = iter.next()?; + payload = iter.next()?; + } else { + return None; + } + }; + Some(Msg { + payload, + channel, + pattern, + }) + } + + /// Tries to convert provided [`PushInfo`] into [`Msg`]. + pub fn from_push_info(push_info: &PushInfo) -> Option { + let mut pattern = None; + let payload; + let channel; + + let mut iter = push_info.data.iter().cloned(); + if push_info.kind == PushKind::Message || push_info.kind == PushKind::SMessage { + channel = iter.next()?; + payload = iter.next()?; + } else if push_info.kind == PushKind::PMessage { + pattern = Some(iter.next()?); + channel = iter.next()?; + payload = iter.next()?; + } else { + return None; + } + + Some(Msg { + payload, + channel, + pattern, + }) + } + + /// Returns the channel this message came on. + pub fn get_channel(&self) -> RedisResult { + from_redis_value(&self.channel) + } + + /// Convenience method to get a string version of the channel. Unless + /// your channel contains non utf-8 bytes you can always use this + /// method. If the channel is not a valid string (which really should + /// not happen) then the return value is `"?"`. + pub fn get_channel_name(&self) -> &str { + match self.channel { + Value::BulkString(ref bytes) => from_utf8(bytes).unwrap_or("?"), + _ => "?", + } + } + + /// Returns the message's payload in a specific format. + pub fn get_payload(&self) -> RedisResult { + from_redis_value(&self.payload) + } + + /// Returns the bytes that are the message's payload. This can be used + /// as an alternative to the `get_payload` function if you are interested + /// in the raw bytes in it. + pub fn get_payload_bytes(&self) -> &[u8] { + match self.payload { + Value::BulkString(ref bytes) => bytes, + _ => b"", + } + } + + /// Returns true if the message was constructed from a pattern + /// subscription. + #[allow(clippy::wrong_self_convention)] + pub fn from_pattern(&self) -> bool { + self.pattern.is_some() + } + + /// If the message was constructed from a message pattern this can be + /// used to find out which one. It's recommended to match against + /// an `Option` so that you do not need to use `from_pattern` + /// to figure out if a pattern was set. + pub fn get_pattern(&self) -> RedisResult { + match self.pattern { + None => from_redis_value(&Value::Nil), + Some(ref x) => from_redis_value(x), + } + } +} + +/// This function simplifies transaction management slightly. What it +/// does is automatically watching keys and then going into a transaction +/// loop util it succeeds. Once it goes through the results are +/// returned. +/// +/// To use the transaction two pieces of information are needed: a list +/// of all the keys that need to be watched for modifications and a +/// closure with the code that should be execute in the context of the +/// transaction. The closure is invoked with a fresh pipeline in atomic +/// mode. To use the transaction the function needs to return the result +/// from querying the pipeline with the connection. +/// +/// The end result of the transaction is then available as the return +/// value from the function call. +/// +/// Example: +/// +/// ```rust,no_run +/// use redis::Commands; +/// # fn do_something() -> redis::RedisResult<()> { +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let key = "the_key"; +/// let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { +/// let old_val : isize = con.get(key)?; +/// pipe +/// .set(key, old_val + 1).ignore() +/// .get(key).query(con) +/// })?; +/// println!("The incremented number is: {}", new_val); +/// # Ok(()) } +/// ``` +pub fn transaction< + C: ConnectionLike, + K: ToRedisArgs, + T, + F: FnMut(&mut C, &mut Pipeline) -> RedisResult>, +>( + con: &mut C, + keys: &[K], + func: F, +) -> RedisResult { + let mut func = func; + loop { + cmd("WATCH").arg(keys).query::<()>(con)?; + let mut p = pipe(); + let response: Option = func(con, p.atomic())?; + match response { + None => { + continue; + } + Some(response) => { + // make sure no watch is left in the connection, even if + // someone forgot to use the pipeline. + cmd("UNWATCH").query::<()>(con)?; + return Ok(response); + } + } + } +} +//TODO: for both clearing logic support sharded channels. + +/// Common logic for clearing subscriptions in RESP2 async/sync +pub fn resp2_is_pub_sub_state_cleared( + received_unsub: &mut bool, + received_punsub: &mut bool, + kind: &[u8], + num: isize, +) -> bool { + match kind.first() { + Some(&b'u') => *received_unsub = true, + Some(&b'p') => *received_punsub = true, + _ => (), + }; + *received_unsub && *received_punsub && num == 0 +} + +/// Common logic for clearing subscriptions in RESP3 async/sync +pub fn resp3_is_pub_sub_state_cleared( + received_unsub: &mut bool, + received_punsub: &mut bool, + kind: &PushKind, + num: isize, +) -> bool { + match kind { + PushKind::Unsubscribe => *received_unsub = true, + PushKind::PUnsubscribe => *received_punsub = true, + _ => (), + }; + *received_unsub && *received_punsub && num == 0 +} + +/// Common logic for checking real cause of hello3 command error +pub fn get_resp3_hello_command_error(err: RedisError) -> RedisError { + if let Some(detail) = err.detail() { + if detail.starts_with("unknown command `HELLO`") { + return ( + ErrorKind::RESP3NotSupported, + "Redis Server doesn't support HELLO command therefore resp3 cannot be used", + ) + .into(); + } + } + err +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_parse_redis_url() { + let cases = vec![ + ("redis://127.0.0.1", true), + ("redis://[::1]", true), + ("redis+unix:///run/redis.sock", true), + ("unix:///run/redis.sock", true), + ("http://127.0.0.1", false), + ("tcp://127.0.0.1", false), + ]; + for (url, expected) in cases.into_iter() { + let res = parse_redis_url(url); + assert_eq!( + res.is_some(), + expected, + "Parsed result of `{url}` is not expected", + ); + } + } + + #[test] + fn test_url_to_tcp_connection_info() { + let cases = vec![ + ( + url::Url::parse("redis://127.0.0.1").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("127.0.0.1".to_string(), 6379), + redis: Default::default(), + }, + ), + ( + url::Url::parse("redis://[::1]").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("::1".to_string(), 6379), + redis: Default::default(), + }, + ), + ( + url::Url::parse("redis://%25johndoe%25:%23%40%3C%3E%24@example.com/2").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Tcp("example.com".to_string(), 6379), + redis: RedisConnectionInfo { + db: 2, + username: Some("%johndoe%".to_string()), + password: Some("#@<>$".to_string()), + ..Default::default() + }, + }, + ), + ]; + for (url, expected) in cases.into_iter() { + let res = url_to_tcp_connection_info(url.clone()).unwrap(); + assert_eq!(res.addr, expected.addr, "addr of {url} is not expected"); + assert_eq!( + res.redis.db, expected.redis.db, + "db of {url} is not expected", + ); + assert_eq!( + res.redis.username, expected.redis.username, + "username of {url} is not expected", + ); + assert_eq!( + res.redis.password, expected.redis.password, + "password of {url} is not expected", + ); + } + } + + #[test] + fn test_url_to_tcp_connection_info_failed() { + let cases = vec![ + (url::Url::parse("redis://").unwrap(), "Missing hostname"), + ( + url::Url::parse("redis://127.0.0.1/db").unwrap(), + "Invalid database number", + ), + ( + url::Url::parse("redis://C3%B0@127.0.0.1").unwrap(), + "Username is not valid UTF-8 string", + ), + ( + url::Url::parse("redis://:C3%B0@127.0.0.1").unwrap(), + "Password is not valid UTF-8 string", + ), + ]; + for (url, expected) in cases.into_iter() { + let res = url_to_tcp_connection_info(url).unwrap_err(); + assert_eq!( + res.kind(), + crate::ErrorKind::InvalidClientConfig, + "{}", + &res, + ); + #[allow(deprecated)] + let desc = std::error::Error::description(&res); + assert_eq!(desc, expected, "{}", &res); + assert_eq!(res.detail(), None, "{}", &res); + } + } + + #[test] + #[cfg(unix)] + fn test_url_to_unix_connection_info() { + let cases = vec![ + ( + url::Url::parse("unix:///var/run/redis.sock").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Unix("/var/run/redis.sock".into()), + redis: RedisConnectionInfo { + db: 0, + username: None, + password: None, + protocol: ProtocolVersion::RESP2, + client_name: None, + pubsub_subscriptions: None, + }, + }, + ), + ( + url::Url::parse("redis+unix:///var/run/redis.sock?db=1").unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Unix("/var/run/redis.sock".into()), + redis: RedisConnectionInfo { + db: 1, + ..Default::default() + }, + }, + ), + ( + url::Url::parse( + "unix:///example.sock?user=%25johndoe%25&pass=%23%40%3C%3E%24&db=2", + ) + .unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Unix("/example.sock".into()), + redis: RedisConnectionInfo { + db: 2, + username: Some("%johndoe%".to_string()), + password: Some("#@<>$".to_string()), + ..Default::default() + }, + }, + ), + ( + url::Url::parse( + "redis+unix:///example.sock?pass=%26%3F%3D+%2A%2B&db=2&user=%25johndoe%25", + ) + .unwrap(), + ConnectionInfo { + addr: ConnectionAddr::Unix("/example.sock".into()), + redis: RedisConnectionInfo { + db: 2, + username: Some("%johndoe%".to_string()), + password: Some("&?= *+".to_string()), + ..Default::default() + }, + }, + ), + ]; + for (url, expected) in cases.into_iter() { + assert_eq!( + ConnectionAddr::Unix(url.to_file_path().unwrap()), + expected.addr, + "addr of {url} is not expected", + ); + let res = url_to_unix_connection_info(url.clone()).unwrap(); + assert_eq!(res.addr, expected.addr, "addr of {url} is not expected"); + assert_eq!( + res.redis.db, expected.redis.db, + "db of {url} is not expected", + ); + assert_eq!( + res.redis.username, expected.redis.username, + "username of {url} is not expected", + ); + assert_eq!( + res.redis.password, expected.redis.password, + "password of {url} is not expected", + ); + } + } +} diff --git a/glide-core/redis-rs/redis/src/geo.rs b/glide-core/redis-rs/redis/src/geo.rs new file mode 100644 index 0000000000..6195264a7c --- /dev/null +++ b/glide-core/redis-rs/redis/src/geo.rs @@ -0,0 +1,361 @@ +//! Defines types to use with the geospatial commands. + +use super::{ErrorKind, RedisResult}; +use crate::types::{FromRedisValue, RedisWrite, ToRedisArgs, Value}; + +macro_rules! invalid_type_error { + ($v:expr, $det:expr) => {{ + fail!(( + ErrorKind::TypeError, + "Response was of incompatible type", + format!("{:?} (response was {:?})", $det, $v) + )); + }}; +} + +/// Units used by [`geo_dist`][1] and [`geo_radius`][2]. +/// +/// [1]: ../trait.Commands.html#method.geo_dist +/// [2]: ../trait.Commands.html#method.geo_radius +pub enum Unit { + /// Represents meters. + Meters, + /// Represents kilometers. + Kilometers, + /// Represents miles. + Miles, + /// Represents feed. + Feet, +} + +impl ToRedisArgs for Unit { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let unit = match *self { + Unit::Meters => "m", + Unit::Kilometers => "km", + Unit::Miles => "mi", + Unit::Feet => "ft", + }; + out.write_arg(unit.as_bytes()); + } +} + +/// A coordinate (longitude, latitude). Can be used with [`geo_pos`][1] +/// to parse response from Redis. +/// +/// [1]: ../trait.Commands.html#method.geo_pos +/// +/// `T` is the type of the every value. +/// +/// * You may want to use either `f64` or `f32` if you want to perform mathematical operations. +/// * To keep the raw value from Redis, use `String`. +#[allow(clippy::derive_partial_eq_without_eq)] // allow f32/f64 here, which don't implement Eq +#[derive(Debug, PartialEq)] +pub struct Coord { + /// Longitude + pub longitude: T, + /// Latitude + pub latitude: T, +} + +impl Coord { + /// Create a new Coord with the (longitude, latitude) + pub fn lon_lat(longitude: T, latitude: T) -> Coord { + Coord { + longitude, + latitude, + } + } +} + +impl FromRedisValue for Coord { + fn from_redis_value(v: &Value) -> RedisResult { + let values: Vec = FromRedisValue::from_redis_value(v)?; + let mut values = values.into_iter(); + let (longitude, latitude) = match (values.next(), values.next(), values.next()) { + (Some(longitude), Some(latitude), None) => (longitude, latitude), + _ => invalid_type_error!(v, "Expect a pair of numbers"), + }; + Ok(Coord { + longitude, + latitude, + }) + } +} + +impl ToRedisArgs for Coord { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_redis_args(&self.longitude, out); + ToRedisArgs::write_redis_args(&self.latitude, out); + } + + fn is_single_arg(&self) -> bool { + false + } +} + +/// Options to sort results from [GEORADIUS][1] and [GEORADIUSBYMEMBER][2] commands +/// +/// [1]: https://redis.io/commands/georadius +/// [2]: https://redis.io/commands/georadiusbymember +#[derive(Default)] +pub enum RadiusOrder { + /// Don't sort the results + #[default] + Unsorted, + + /// Sort returned items from the nearest to the farthest, relative to the center. + Asc, + + /// Sort returned items from the farthest to the nearest, relative to the center. + Desc, +} + +/// Options for the [GEORADIUS][1] and [GEORADIUSBYMEMBER][2] commands +/// +/// [1]: https://redis.io/commands/georadius +/// [2]: https://redis.io/commands/georadiusbymember +/// +/// # Example +/// +/// ```rust,no_run +/// use redis::{Commands, RedisResult}; +/// use redis::geo::{RadiusSearchResult, RadiusOptions, RadiusOrder, Unit}; +/// fn nearest_in_radius( +/// con: &mut redis::Connection, +/// key: &str, +/// longitude: f64, +/// latitude: f64, +/// meters: f64, +/// limit: usize, +/// ) -> RedisResult> { +/// let opts = RadiusOptions::default() +/// .order(RadiusOrder::Asc) +/// .limit(limit); +/// con.geo_radius(key, longitude, latitude, meters, Unit::Meters, opts) +/// } +/// ``` +#[derive(Default)] +pub struct RadiusOptions { + with_coord: bool, + with_dist: bool, + count: Option, + order: RadiusOrder, + store: Option>>, + store_dist: Option>>, +} + +impl RadiusOptions { + /// Limit the results to the first N matching items. + pub fn limit(mut self, n: usize) -> Self { + self.count = Some(n); + self + } + + /// Return the distance of the returned items from the specified center. + /// The distance is returned in the same unit as the unit specified as the + /// radius argument of the command. + pub fn with_dist(mut self) -> Self { + self.with_dist = true; + self + } + + /// Return the `longitude, latitude` coordinates of the matching items. + pub fn with_coord(mut self) -> Self { + self.with_coord = true; + self + } + + /// Sort the returned items + pub fn order(mut self, o: RadiusOrder) -> Self { + self.order = o; + self + } + + /// Store the results in a sorted set at `key`, instead of returning them. + /// + /// This feature can't be used with any `with_*` method. + pub fn store(mut self, key: K) -> Self { + self.store = Some(ToRedisArgs::to_redis_args(&key)); + self + } + + /// Store the results in a sorted set at `key`, with the distance from the + /// center as its score. This feature can't be used with any `with_*` method. + pub fn store_dist(mut self, key: K) -> Self { + self.store_dist = Some(ToRedisArgs::to_redis_args(&key)); + self + } +} + +impl ToRedisArgs for RadiusOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if self.with_coord { + out.write_arg(b"WITHCOORD"); + } + + if self.with_dist { + out.write_arg(b"WITHDIST"); + } + + if let Some(n) = self.count { + out.write_arg(b"COUNT"); + out.write_arg_fmt(n); + } + + match self.order { + RadiusOrder::Asc => out.write_arg(b"ASC"), + RadiusOrder::Desc => out.write_arg(b"DESC"), + _ => (), + }; + + if let Some(ref store) = self.store { + out.write_arg(b"STORE"); + for i in store { + out.write_arg(i); + } + } + + if let Some(ref store_dist) = self.store_dist { + out.write_arg(b"STOREDIST"); + for i in store_dist { + out.write_arg(i); + } + } + } + + fn is_single_arg(&self) -> bool { + false + } +} + +/// Contain an item returned by [`geo_radius`][1] and [`geo_radius_by_member`][2]. +/// +/// [1]: ../trait.Commands.html#method.geo_radius +/// [2]: ../trait.Commands.html#method.geo_radius_by_member +pub struct RadiusSearchResult { + /// The name that was found. + pub name: String, + /// The coordinate if available. + pub coord: Option>, + /// The distance if available. + pub dist: Option, +} + +impl FromRedisValue for RadiusSearchResult { + fn from_redis_value(v: &Value) -> RedisResult { + // If we receive only the member name, it will be a plain string + if let Ok(name) = FromRedisValue::from_redis_value(v) { + return Ok(RadiusSearchResult { + name, + coord: None, + dist: None, + }); + } + + // Try to parse the result from multitple values + if let Value::Array(ref items) = *v { + if let Some(result) = RadiusSearchResult::parse_multi_values(items) { + return Ok(result); + } + } + + invalid_type_error!(v, "Response type not RadiusSearchResult compatible."); + } +} + +impl RadiusSearchResult { + fn parse_multi_values(items: &[Value]) -> Option { + let mut iter = items.iter(); + + // First item is always the member name + let name: String = match iter.next().map(FromRedisValue::from_redis_value) { + Some(Ok(n)) => n, + _ => return None, + }; + + let mut next = iter.next(); + + // Next element, if present, will be the distance. + let dist = match next.map(FromRedisValue::from_redis_value) { + Some(Ok(c)) => { + next = iter.next(); + Some(c) + } + _ => None, + }; + + // Finally, if present, the last item will be the coordinates + + let coord = match next.map(FromRedisValue::from_redis_value) { + Some(Ok(c)) => Some(c), + _ => None, + }; + + Some(RadiusSearchResult { name, coord, dist }) + } +} + +#[cfg(test)] +mod tests { + use super::{Coord, RadiusOptions, RadiusOrder}; + use crate::types::ToRedisArgs; + use std::str; + + macro_rules! assert_args { + ($value:expr, $($args:expr),+) => { + let args = $value.to_redis_args(); + let strings: Vec<_> = args.iter() + .map(|a| str::from_utf8(a.as_ref()).unwrap()) + .collect(); + assert_eq!(strings, vec![$($args),+]); + } + } + + #[test] + fn test_coord_to_args() { + let member = ("Palermo", Coord::lon_lat("13.361389", "38.115556")); + assert_args!(&member, "Palermo", "13.361389", "38.115556"); + } + + #[test] + fn test_radius_options() { + // Without options, should not generate any argument + let empty = RadiusOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + // Some combinations with WITH* options + let opts = RadiusOptions::default; + + assert_args!(opts().with_coord().with_dist(), "WITHCOORD", "WITHDIST"); + + assert_args!(opts().limit(50), "COUNT", "50"); + + assert_args!(opts().limit(50).store("x"), "COUNT", "50", "STORE", "x"); + + assert_args!( + opts().limit(100).store_dist("y"), + "COUNT", + "100", + "STOREDIST", + "y" + ); + + assert_args!( + opts().order(RadiusOrder::Asc).limit(10).with_dist(), + "WITHDIST", + "COUNT", + "10", + "ASC" + ); + } +} diff --git a/glide-core/redis-rs/redis/src/lib.rs b/glide-core/redis-rs/redis/src/lib.rs new file mode 100644 index 0000000000..4f138c2bb6 --- /dev/null +++ b/glide-core/redis-rs/redis/src/lib.rs @@ -0,0 +1,506 @@ +//! redis-rs is a Rust implementation of a Redis client library. It exposes +//! a general purpose interface to Redis and also provides specific helpers for +//! commonly used functionality. +//! +//! The crate is called `redis` and you can depend on it via cargo: +//! +//! ```ini +//! [dependencies.redis] +//! version = "*" +//! ``` +//! +//! If you want to use the git version: +//! +//! ```ini +//! [dependencies.redis] +//! git = "https://github.com/redis-rs/redis-rs.git" +//! ``` +//! +//! # Basic Operation +//! +//! redis-rs exposes two API levels: a low- and a high-level part. +//! The high-level part does not expose all the functionality of redis and +//! might take some liberties in how it speaks the protocol. The low-level +//! part of the API allows you to express any request on the redis level. +//! You can fluently switch between both API levels at any point. +//! +//! ## Connection Handling +//! +//! For connecting to redis you can use a client object which then can produce +//! actual connections. Connections and clients as well as results of +//! connections and clients are considered `ConnectionLike` objects and +//! can be used anywhere a request is made. +//! +//! The full canonical way to get a connection is to create a client and +//! to ask for a connection from it: +//! +//! ```rust,no_run +//! extern crate redis; +//! +//! fn do_something() -> redis::RedisResult<()> { +//! let client = redis::Client::open("redis://127.0.0.1/")?; +//! let mut con = client.get_connection(None)?; +//! +//! /* do something here */ +//! +//! Ok(()) +//! } +//! ``` +//! +//! ## Optional Features +//! +//! There are a few features defined that can enable additional functionality +//! if so desired. Some of them are turned on by default. +//! +//! * `acl`: enables acl support (enabled by default) +//! * `aio`: enables async IO support (enabled by default) +//! * `geospatial`: enables geospatial support (enabled by default) +//! * `script`: enables script support (enabled by default) +//! * `r2d2`: enables r2d2 connection pool support (optional) +//! * `ahash`: enables ahash map/set support & uses ahash internally (+7-10% performance) (optional) +//! * `cluster`: enables redis cluster support (optional) +//! * `cluster-async`: enables async redis cluster support (optional) +//! * `tokio-comp`: enables support for tokio (optional) +//! * `connection-manager`: enables support for automatic reconnection (optional) +//! * `keep-alive`: enables keep-alive option on socket by means of `socket2` crate (optional) +//! +//! ## Connection Parameters +//! +//! redis-rs knows different ways to define where a connection should +//! go. The parameter to `Client::open` needs to implement the +//! `IntoConnectionInfo` trait of which there are three implementations: +//! +//! * string slices in `redis://` URL format. +//! * URL objects from the redis-url crate. +//! * `ConnectionInfo` objects. +//! +//! The URL format is `redis://[][:@][:port][/]` +//! +//! If Unix socket support is available you can use a unix URL in this format: +//! +//! `redis+unix:///[?db=[&pass=][&user=]]` +//! +//! For compatibility with some other redis libraries, the "unix" scheme +//! is also supported: +//! +//! `unix:///[?db=][&pass=][&user=]]` +//! +//! ## Executing Low-Level Commands +//! +//! To execute low-level commands you can use the `cmd` function which allows +//! you to build redis requests. Once you have configured a command object +//! to your liking you can send a query into any `ConnectionLike` object: +//! +//! ```rust,no_run +//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<()> { +//! let _ : () = redis::cmd("SET").arg("my_key").arg(42).query(con)?; +//! Ok(()) +//! } +//! ``` +//! +//! Upon querying the return value is a result object. If you do not care +//! about the actual return value (other than that it is not a failure) +//! you can always type annotate it to the unit type `()`. +//! +//! Note that commands with a sub-command (like "MEMORY USAGE", "ACL WHOAMI", +//! "LATENCY HISTORY", etc) must specify the sub-command as a separate `arg`: +//! +//! ```rust,no_run +//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult { +//! // This will result in a server error: "unknown command `MEMORY USAGE`" +//! // because "USAGE" is technically a sub-command of "MEMORY". +//! redis::cmd("MEMORY USAGE").arg("my_key").query(con)?; +//! +//! // However, this will work as you'd expect +//! redis::cmd("MEMORY").arg("USAGE").arg("my_key").query(con) +//! } +//! ``` +//! +//! ## Executing High-Level Commands +//! +//! The high-level interface is similar. For it to become available you +//! need to use the `Commands` trait in which case all `ConnectionLike` +//! objects the library provides will also have high-level methods which +//! make working with the protocol easier: +//! +//! ```rust,no_run +//! extern crate redis; +//! use redis::Commands; +//! +//! fn do_something(con: &mut redis::Connection) -> redis::RedisResult<()> { +//! let _ : () = con.set("my_key", 42)?; +//! Ok(()) +//! } +//! ``` +//! +//! Note that high-level commands are work in progress and many are still +//! missing! +//! +//! ## Type Conversions +//! +//! Because redis inherently is mostly type-less and the protocol is not +//! exactly friendly to developers, this library provides flexible support +//! for casting values to the intended results. This is driven through the `FromRedisValue` and `ToRedisArgs` traits. +//! +//! The `arg` method of the command will accept a wide range of types through +//! the `ToRedisArgs` trait and the `query` method of a command can convert the +//! value to what you expect the function to return through the `FromRedisValue` +//! trait. This is quite flexible and allows vectors, tuples, hashsets, hashmaps +//! as well as optional values: +//! +//! ```rust,no_run +//! # use redis::Commands; +//! # use std::collections::{HashMap, HashSet}; +//! # fn do_something() -> redis::RedisResult<()> { +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let count : i32 = con.get("my_counter")?; +//! let count = con.get("my_counter").unwrap_or(0i32); +//! let k : Option = con.get("missing_key")?; +//! let name : String = con.get("my_name")?; +//! let bin : Vec = con.get("my_binary")?; +//! let map : HashMap = con.hgetall("my_hash")?; +//! let keys : Vec = con.hkeys("my_hash")?; +//! let mems : HashSet = con.smembers("my_set")?; +//! let (k1, k2) : (String, String) = con.get(&["k1", "k2"])?; +//! # Ok(()) +//! # } +//! ``` +//! +//! # Iteration Protocol +//! +//! In addition to sending a single query, iterators are also supported. When +//! used with regular bulk responses they don't give you much over querying and +//! converting into a vector (both use a vector internally) but they can also +//! be used with `SCAN` like commands in which case iteration will send more +//! queries until the cursor is exhausted: +//! +//! ```rust,ignore +//! # fn do_something() -> redis::RedisResult<()> { +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let mut iter : redis::Iter = redis::cmd("SSCAN").arg("my_set") +//! .cursor_arg(0).clone().iter(&mut con)?; +//! for x in iter { +//! // do something with the item +//! } +//! # Ok(()) } +//! ``` +//! +//! As you can see the cursor argument needs to be defined with `cursor_arg` +//! instead of `arg` so that the library knows which argument needs updating +//! as the query is run for more items. +//! +//! # Pipelining +//! +//! In addition to simple queries you can also send command pipelines. This +//! is provided through the `pipe` function. It works very similar to sending +//! individual commands but you can send more than one in one go. This also +//! allows you to ignore individual results so that matching on the end result +//! is easier: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let (k1, k2) : (i32, i32) = redis::pipe() +//! .cmd("SET").arg("key_1").arg(42).ignore() +//! .cmd("SET").arg("key_2").arg(43).ignore() +//! .cmd("GET").arg("key_1") +//! .cmd("GET").arg("key_2").query(&mut con)?; +//! # Ok(()) } +//! ``` +//! +//! If you want the pipeline to be wrapped in a `MULTI`/`EXEC` block you can +//! easily do that by switching the pipeline into `atomic` mode. From the +//! caller's point of view nothing changes, the pipeline itself will take +//! care of the rest for you: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let (k1, k2) : (i32, i32) = redis::pipe() +//! .atomic() +//! .cmd("SET").arg("key_1").arg(42).ignore() +//! .cmd("SET").arg("key_2").arg(43).ignore() +//! .cmd("GET").arg("key_1") +//! .cmd("GET").arg("key_2").query(&mut con)?; +//! # Ok(()) } +//! ``` +//! +//! You can also use high-level commands on pipelines: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let (k1, k2) : (i32, i32) = redis::pipe() +//! .atomic() +//! .set("key_1", 42).ignore() +//! .set("key_2", 43).ignore() +//! .get("key_1") +//! .get("key_2").query(&mut con)?; +//! # Ok(()) } +//! ``` +//! +//! # Transactions +//! +//! Transactions are available through atomic pipelines. In order to use +//! them in a more simple way you can use the `transaction` function of a +//! connection: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! use redis::Commands; +//! # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +//! # let mut con = client.get_connection(None).unwrap(); +//! let key = "the_key"; +//! let (new_val,) : (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { +//! let old_val : isize = con.get(key)?; +//! pipe +//! .set(key, old_val + 1).ignore() +//! .get(key).query(con) +//! })?; +//! println!("The incremented number is: {}", new_val); +//! # Ok(()) } +//! ``` +//! +//! For more information see the `transaction` function. +//! +//! # PubSub +//! +//! Pubsub is currently work in progress but provided through the `PubSub` +//! connection object. Due to the fact that Rust does not have support +//! for async IO in libnative yet, the API does not provide a way to +//! read messages with any form of timeout yet. +//! +//! Example usage: +//! +//! ```rust,no_run +//! # fn do_something() -> redis::RedisResult<()> { +//! let client = redis::Client::open("redis://127.0.0.1/")?; +//! let mut con = client.get_connection(None)?; +//! let mut pubsub = con.as_pubsub(); +//! pubsub.subscribe("channel_1")?; +//! pubsub.subscribe("channel_2")?; +//! +//! loop { +//! let msg = pubsub.get_message()?; +//! let payload : String = msg.get_payload()?; +//! println!("channel '{}': {}", msg.get_channel_name(), payload); +//! } +//! # } +//! ``` +//! +#![cfg_attr( + feature = "script", + doc = r##" +# Scripts + +Lua scripts are supported through the `Script` type in a convenient +way (it does not support pipelining currently). It will automatically +load the script if it does not exist and invoke it. + +Example: + +```rust,no_run +# fn do_something() -> redis::RedisResult<()> { +# let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +# let mut con = client.get_connection(None).unwrap(); +let script = redis::Script::new(r" + return tonumber(ARGV[1]) + tonumber(ARGV[2]); +"); +let result : isize = script.arg(1).arg(2).invoke(&mut con)?; +assert_eq!(result, 3); +# Ok(()) } +``` +"## +)] +//! +#![cfg_attr( + feature = "aio", + doc = r##" +# Async + +In addition to the synchronous interface that's been explained above there also exists an +asynchronous interface based on [`futures`][] and [`tokio`][]. + +This interface exists under the `aio` (async io) module (which requires that the `aio` feature +is enabled) and largely mirrors the synchronous with a few concessions to make it fit the +constraints of `futures`. + +```rust,no_run +use futures::prelude::*; +use redis::AsyncCommands; + +# #[tokio::main] +# async fn main() -> redis::RedisResult<()> { +let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +let mut con = client.get_async_connection(None).await?; + +con.set("key1", b"foo").await?; + +redis::cmd("SET").arg(&["key2", "bar"]).query_async(&mut con).await?; + +let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; +assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); +# Ok(()) } +``` +"## +)] +//! +//! [`futures`]:https://crates.io/crates/futures +//! [`tokio`]:https://tokio.rs + +#![deny(non_camel_case_types)] +#![warn(missing_docs)] +#![cfg_attr(docsrs, warn(rustdoc::broken_intra_doc_links))] +#![cfg_attr(docsrs, feature(doc_cfg))] +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] + +// public api +pub use crate::client::Client; +pub use crate::client::GlideConnectionOptions; +pub use crate::cmd::{cmd, pack_command, pipe, Arg, Cmd, Iter}; +pub use crate::commands::{ + Commands, ControlFlow, Direction, LposOptions, PubSubCommands, SetOptions, +}; +pub use crate::connection::{ + parse_redis_url, transaction, Connection, ConnectionAddr, ConnectionInfo, ConnectionLike, + IntoConnectionInfo, Msg, PubSub, PubSubChannelOrPattern, PubSubSubscriptionInfo, + PubSubSubscriptionKind, RedisConnectionInfo, TlsMode, +}; +pub use crate::parser::{parse_redis_value, Parser}; +pub use crate::pipeline::Pipeline; +pub use push_manager::{PushInfo, PushManager}; + +#[cfg(feature = "script")] +#[cfg_attr(docsrs, doc(cfg(feature = "script")))] +pub use crate::script::{Script, ScriptInvocation}; + +// preserve grouping and order +#[rustfmt::skip] +pub use crate::types::{ + // utility functions + from_redis_value, + from_owned_redis_value, + + // error kinds + ErrorKind, + + // conversion traits + FromRedisValue, + + // utility types + InfoDict, + NumericBehavior, + Expiry, + SetExpiry, + ExistenceCheck, + + // error and result types + RedisError, + RedisResult, + RedisWrite, + ToRedisArgs, + + // low level values + Value, + PushKind, + VerbatimFormat, + ProtocolVersion +}; + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub use crate::{ + cmd::AsyncIter, commands::AsyncCommands, parser::parse_redis_value_async, types::RedisFuture, +}; + +mod macros; +mod pipeline; + +#[cfg(feature = "acl")] +#[cfg_attr(docsrs, doc(cfg(feature = "acl")))] +pub mod acl; + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub mod aio; + +#[cfg(feature = "json")] +pub use crate::commands::JsonCommands; + +#[cfg(all(feature = "json", feature = "aio"))] +pub use crate::commands::JsonAsyncCommands; + +#[cfg(feature = "geospatial")] +#[cfg_attr(docsrs, doc(cfg(feature = "geospatial")))] +pub mod geo; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +pub mod cluster; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +mod cluster_slotmap; + +#[cfg(feature = "cluster-async")] +pub use crate::commands::ScanStateRC; + +#[cfg(feature = "cluster-async")] +pub use crate::commands::ObjectType; + +#[cfg(feature = "cluster")] +mod cluster_client; + +/// for testing purposes +pub mod testing { + #[cfg(feature = "cluster")] + pub use crate::cluster_client::ClusterParams; +} + +#[cfg(feature = "cluster")] +mod cluster_pipeline; + +/// Routing information for cluster commands. +#[cfg(feature = "cluster")] +pub mod cluster_routing; + +#[cfg(feature = "cluster")] +#[cfg_attr(docsrs, doc(cfg(feature = "cluster")))] +pub mod cluster_topology; + +#[cfg(feature = "r2d2")] +#[cfg_attr(docsrs, doc(cfg(feature = "r2d2")))] +mod r2d2; + +#[cfg(feature = "streams")] +#[cfg_attr(docsrs, doc(cfg(feature = "streams")))] +pub mod streams; + +#[cfg(feature = "cluster-async")] +pub mod cluster_async; + +#[cfg(feature = "sentinel")] +pub mod sentinel; + +#[cfg(feature = "tls-rustls")] +mod tls; + +#[cfg(feature = "tls-rustls")] +pub use crate::tls::{ClientTlsConfig, TlsCertificates}; + +mod client; +mod cmd; +mod commands; +mod connection; +mod parser; +mod push_manager; +mod script; +mod types; diff --git a/glide-core/redis-rs/redis/src/macros.rs b/glide-core/redis-rs/redis/src/macros.rs new file mode 100644 index 0000000000..b8886cc759 --- /dev/null +++ b/glide-core/redis-rs/redis/src/macros.rs @@ -0,0 +1,7 @@ +#![macro_use] + +macro_rules! fail { + ($expr:expr) => { + return Err(::std::convert::From::from($expr)) + }; +} diff --git a/glide-core/redis-rs/redis/src/parser.rs b/glide-core/redis-rs/redis/src/parser.rs new file mode 100644 index 0000000000..96e0bcd8f1 --- /dev/null +++ b/glide-core/redis-rs/redis/src/parser.rs @@ -0,0 +1,658 @@ +use std::{ + io::{self, Read}, + str, +}; + +use crate::types::{ + ErrorKind, InternalValue, PushKind, RedisError, RedisResult, ServerError, ServerErrorKind, + Value, VerbatimFormat, +}; + +use combine::{ + any, + error::StreamError, + opaque, + parser::{ + byte::{crlf, take_until_bytes}, + combinator::{any_send_sync_partial_state, AnySendSyncPartialState}, + range::{recognize, take}, + }, + stream::{PointerOffset, RangeStream, StreamErrorFor}, + ParseError, Parser as _, +}; +use num_bigint::BigInt; + +const MAX_RECURSE_DEPTH: usize = 100; + +fn err_parser(line: &str) -> ServerError { + let mut pieces = line.splitn(2, ' '); + let kind = match pieces.next().unwrap() { + "ERR" => ServerErrorKind::ResponseError, + "EXECABORT" => ServerErrorKind::ExecAbortError, + "LOADING" => ServerErrorKind::BusyLoadingError, + "NOSCRIPT" => ServerErrorKind::NoScriptError, + "MOVED" => ServerErrorKind::Moved, + "ASK" => ServerErrorKind::Ask, + "TRYAGAIN" => ServerErrorKind::TryAgain, + "CLUSTERDOWN" => ServerErrorKind::ClusterDown, + "CROSSSLOT" => ServerErrorKind::CrossSlot, + "MASTERDOWN" => ServerErrorKind::MasterDown, + "READONLY" => ServerErrorKind::ReadOnly, + "NOTBUSY" => ServerErrorKind::NotBusy, + code => { + return ServerError::ExtensionError { + code: code.to_string(), + detail: pieces.next().map(|str| str.to_string()), + } + } + }; + let detail = pieces.next().map(|str| str.to_string()); + ServerError::KnownError { kind, detail } +} + +pub fn get_push_kind(kind: String) -> PushKind { + match kind.as_str() { + "invalidate" => PushKind::Invalidate, + "message" => PushKind::Message, + "pmessage" => PushKind::PMessage, + "smessage" => PushKind::SMessage, + "unsubscribe" => PushKind::Unsubscribe, + "punsubscribe" => PushKind::PUnsubscribe, + "sunsubscribe" => PushKind::SUnsubscribe, + "subscribe" => PushKind::Subscribe, + "psubscribe" => PushKind::PSubscribe, + "ssubscribe" => PushKind::SSubscribe, + _ => PushKind::Other(kind), + } +} + +fn value<'a, I>( + count: Option, +) -> impl combine::Parser +where + I: RangeStream, + I::Error: combine::ParseError, +{ + let count = count.unwrap_or(1); + + opaque!(any_send_sync_partial_state( + any() + .then_partial(move |&mut b| { + if b == b'*' && count > MAX_RECURSE_DEPTH { + combine::unexpected_any("Maximum recursion depth exceeded").left() + } else { + combine::value(b).right() + } + }) + .then_partial(move |&mut b| { + let line = || { + recognize(take_until_bytes(&b"\r\n"[..]).with(take(2).map(|_| ()))).and_then( + |line: &[u8]| { + str::from_utf8(&line[..line.len() - 2]) + .map_err(StreamErrorFor::::other) + }, + ) + }; + + let simple_string = || { + line().map(|line| { + if line == "OK" { + InternalValue::Okay + } else { + InternalValue::SimpleString(line.into()) + } + }) + }; + + let int = || { + line().and_then(|line| { + line.trim().parse::().map_err(|_| { + StreamErrorFor::::message_static_message( + "Expected integer, got garbage", + ) + }) + }) + }; + + let bulk_string = || { + int().then_partial(move |size| { + if *size < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + take(*size as usize) + .map(|bs: &[u8]| InternalValue::BulkString(bs.to_vec())) + .skip(crlf()) + .right() + } + }) + }; + let blob = || { + int().then_partial(move |size| { + take(*size as usize) + .map(|bs: &[u8]| String::from_utf8_lossy(bs).to_string()) + .skip(crlf()) + }) + }; + + let array = || { + int().then_partial(move |&mut length| { + if length < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .map(InternalValue::Array) + .right() + } + }) + }; + + let error = || line().map(err_parser); + let map = || { + int().then_partial(move |&mut kv_length| { + let length = kv_length as usize * 2; + combine::count_min_max(length, length, value(Some(count + 1))).map( + move |result: Vec| { + let mut it = result.into_iter(); + let mut x = vec![]; + for _ in 0..kv_length { + if let (Some(k), Some(v)) = (it.next(), it.next()) { + x.push((k, v)) + } + } + InternalValue::Map(x) + }, + ) + }) + }; + let attribute = || { + int().then_partial(move |&mut kv_length| { + // + 1 is for data! + let length = kv_length as usize * 2 + 1; + combine::count_min_max(length, length, value(Some(count + 1))).map( + move |result: Vec| { + let mut it = result.into_iter(); + let mut attributes = vec![]; + for _ in 0..kv_length { + if let (Some(k), Some(v)) = (it.next(), it.next()) { + attributes.push((k, v)) + } + } + InternalValue::Attribute { + data: Box::new(it.next().unwrap()), + attributes, + } + }, + ) + }) + }; + let set = || { + int().then_partial(move |&mut length| { + if length < 0 { + combine::produce(|| InternalValue::Nil).left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .map(InternalValue::Set) + .right() + } + }) + }; + let push = || { + int().then_partial(move |&mut length| { + if length <= 0 { + combine::produce(|| InternalValue::Push { + kind: PushKind::Other("".to_string()), + data: vec![], + }) + .left() + } else { + let length = length as usize; + combine::count_min_max(length, length, value(Some(count + 1))) + .and_then(|result: Vec| { + let mut it = result.into_iter(); + let first = it.next().unwrap_or(InternalValue::Nil); + if let InternalValue::BulkString(kind) = first { + let push_kind = String::from_utf8(kind) + .map_err(StreamErrorFor::::other)?; + Ok(InternalValue::Push { + kind: get_push_kind(push_kind), + data: it.collect(), + }) + } else if let InternalValue::SimpleString(kind) = first { + Ok(InternalValue::Push { + kind: get_push_kind(kind), + data: it.collect(), + }) + } else { + Err(StreamErrorFor::::message_static_message( + "parse error when decoding push", + )) + } + }) + .right() + } + }) + }; + let null = || line().map(|_| InternalValue::Nil); + let double = || { + line().and_then(|line| { + line.trim() + .parse::() + .map_err(StreamErrorFor::::other) + }) + }; + let boolean = || { + line().and_then(|line: &str| match line { + "t" => Ok(true), + "f" => Ok(false), + _ => Err(StreamErrorFor::::message_static_message( + "Expected boolean, got garbage", + )), + }) + }; + let blob_error = || blob().map(|line| err_parser(&line)); + let verbatim = || { + blob().and_then(|line| { + if let Some((format, text)) = line.split_once(':') { + let format = match format { + "txt" => VerbatimFormat::Text, + "mkd" => VerbatimFormat::Markdown, + x => VerbatimFormat::Unknown(x.to_string()), + }; + Ok(InternalValue::VerbatimString { + format, + text: text.to_string(), + }) + } else { + Err(StreamErrorFor::::message_static_message( + "parse error when decoding verbatim string", + )) + } + }) + }; + let big_number = || { + line().and_then(|line| { + BigInt::parse_bytes(line.as_bytes(), 10).ok_or_else(|| { + StreamErrorFor::::message_static_message( + "Expected bigint, got garbage", + ) + }) + }) + }; + combine::dispatch!(b; + b'+' => simple_string(), + b':' => int().map(InternalValue::Int), + b'$' => bulk_string(), + b'*' => array(), + b'%' => map(), + b'|' => attribute(), + b'~' => set(), + b'-' => error().map(InternalValue::ServerError), + b'_' => null(), + b',' => double().map(InternalValue::Double), + b'#' => boolean().map(InternalValue::Boolean), + b'!' => blob_error().map(InternalValue::ServerError), + b'=' => verbatim(), + b'(' => big_number().map(InternalValue::BigNumber), + b'>' => push(), + b => combine::unexpected_any(combine::error::Token(b)) + ) + }) + )) +} + +#[cfg(feature = "aio")] +mod aio_support { + use super::*; + + use bytes::{Buf, BytesMut}; + use tokio::io::AsyncRead; + use tokio_util::codec::{Decoder, Encoder}; + + #[derive(Default)] + pub struct ValueCodec { + state: AnySendSyncPartialState, + } + + impl ValueCodec { + fn decode_stream( + &mut self, + bytes: &mut BytesMut, + eof: bool, + ) -> RedisResult>> { + let (opt, removed_len) = { + let buffer = &bytes[..]; + let mut stream = + combine::easy::Stream(combine::stream::MaybePartialStream(buffer, !eof)); + match combine::stream::decode_tokio(value(None), &mut stream, &mut self.state) { + Ok(x) => x, + Err(err) => { + let err = err + .map_position(|pos| pos.translate_position(buffer)) + .map_range(|range| format!("{range:?}")) + .to_string(); + return Err(RedisError::from(( + ErrorKind::ParseError, + "parse error", + err, + ))); + } + } + }; + + bytes.advance(removed_len); + match opt { + Some(result) => Ok(Some(result.try_into())), + None => Ok(None), + } + } + } + + impl Encoder> for ValueCodec { + type Error = RedisError; + fn encode(&mut self, item: Vec, dst: &mut BytesMut) -> Result<(), Self::Error> { + dst.extend_from_slice(item.as_ref()); + Ok(()) + } + } + + impl Decoder for ValueCodec { + type Item = RedisResult; + type Error = RedisError; + + fn decode(&mut self, bytes: &mut BytesMut) -> Result, Self::Error> { + self.decode_stream(bytes, false) + } + + fn decode_eof(&mut self, bytes: &mut BytesMut) -> Result, Self::Error> { + self.decode_stream(bytes, true) + } + } + + /// Parses a redis value asynchronously. + pub async fn parse_redis_value_async( + decoder: &mut combine::stream::Decoder>, + read: &mut R, + ) -> RedisResult + where + R: AsyncRead + std::marker::Unpin, + { + let result = combine::decode_tokio!(*decoder, *read, value(None), |input, _| { + combine::stream::easy::Stream::from(input) + }); + match result { + Err(err) => Err(match err { + combine::stream::decoder::Error::Io { error, .. } => error.into(), + combine::stream::decoder::Error::Parse(err) => { + if err.is_unexpected_end_of_input() { + RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + let err = err + .map_range(|range| format!("{range:?}")) + .map_position(|pos| pos.translate_position(decoder.buffer())) + .to_string(); + RedisError::from((ErrorKind::ParseError, "parse error", err)) + } + } + }), + Ok(result) => result.try_into(), + } + } +} + +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +pub use self::aio_support::*; + +/// The internal redis response parser. +pub struct Parser { + decoder: combine::stream::decoder::Decoder>, +} + +impl Default for Parser { + fn default() -> Self { + Parser::new() + } +} + +/// The parser can be used to parse redis responses into values. Generally +/// you normally do not use this directly as it's already done for you by +/// the client but in some more complex situations it might be useful to be +/// able to parse the redis responses. +impl Parser { + /// Creates a new parser that parses the data behind the reader. More + /// than one value can be behind the reader in which case the parser can + /// be invoked multiple times. In other words: the stream does not have + /// to be terminated. + pub fn new() -> Parser { + Parser { + decoder: combine::stream::decoder::Decoder::new(), + } + } + + // public api + + /// Parses synchronously into a single value from the reader. + pub fn parse_value(&mut self, mut reader: T) -> RedisResult { + let mut decoder = &mut self.decoder; + let result = combine::decode!(decoder, reader, value(None), |input, _| { + combine::stream::easy::Stream::from(input) + }); + match result { + Err(err) => Err(match err { + combine::stream::decoder::Error::Io { error, .. } => error.into(), + combine::stream::decoder::Error::Parse(err) => { + if err.is_unexpected_end_of_input() { + RedisError::from(io::Error::from(io::ErrorKind::UnexpectedEof)) + } else { + let err = err + .map_range(|range| format!("{range:?}")) + .map_position(|pos| pos.translate_position(decoder.buffer())) + .to_string(); + RedisError::from((ErrorKind::ParseError, "parse error", err)) + } + } + }), + Ok(result) => result.try_into(), + } + } +} + +/// Parses bytes into a redis value. +/// +/// This is the most straightforward way to parse something into a low +/// level redis value instead of having to use a whole parser. +pub fn parse_redis_value(bytes: &[u8]) -> RedisResult { + let mut parser = Parser::new(); + parser.parse_value(bytes) +} + +#[cfg(test)] +mod tests { + use crate::types::make_extension_error; + + use super::*; + + #[cfg(feature = "aio")] + #[test] + fn decode_eof_returns_none_at_eof() { + use tokio_util::codec::Decoder; + let mut codec = ValueCodec::default(); + + let mut bytes = bytes::BytesMut::from(&b"+GET 123\r\n"[..]); + assert_eq!( + codec.decode_eof(&mut bytes), + Ok(Some(Ok(parse_redis_value(b"+GET 123\r\n").unwrap()))) + ); + assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); + assert_eq!(codec.decode_eof(&mut bytes), Ok(None)); + } + + #[cfg(feature = "aio")] + #[test] + fn decode_eof_returns_error_inside_array_and_can_parse_more_inputs() { + use tokio_util::codec::Decoder; + let mut codec = ValueCodec::default(); + + let mut bytes = + bytes::BytesMut::from(b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n".as_slice()); + let result = codec.decode_eof(&mut bytes).unwrap().unwrap(); + + assert_eq!( + result, + Err(RedisError::from(( + ErrorKind::BusyLoadingError, + "An error was signalled by the server", + "server is loading".to_string() + ))) + ); + + let mut bytes = bytes::BytesMut::from(b"+OK\r\n".as_slice()); + let result = codec.decode_eof(&mut bytes).unwrap().unwrap(); + + assert_eq!(result, Ok(Value::Okay)); + } + + #[test] + fn parse_nested_error_and_handle_more_inputs() { + // from https://redis.io/docs/interact/transactions/ - + // "EXEC returned two-element bulk string reply where one is an OK code and the other an error reply. It's up to the client library to find a sensible way to provide the error to the user." + + let bytes = b"*3\r\n+OK\r\n-LOADING server is loading\r\n+OK\r\n"; + let result = parse_redis_value(bytes); + + assert_eq!( + result, + Err(RedisError::from(( + ErrorKind::BusyLoadingError, + "An error was signalled by the server", + "server is loading".to_string() + ))) + ); + + let result = parse_redis_value(b"+OK\r\n").unwrap(); + + assert_eq!(result, Value::Okay); + } + + #[test] + fn decode_resp3_double() { + let val = parse_redis_value(b",1.23\r\n").unwrap(); + assert_eq!(val, Value::Double(1.23)); + let val = parse_redis_value(b",nan\r\n").unwrap(); + if let Value::Double(val) = val { + assert!(val.is_sign_positive()); + assert!(val.is_nan()); + } else { + panic!("expected double"); + } + // -nan is supported prior to redis 7.2 + let val = parse_redis_value(b",-nan\r\n").unwrap(); + if let Value::Double(val) = val { + assert!(val.is_sign_negative()); + assert!(val.is_nan()); + } else { + panic!("expected double"); + } + //Allow doubles in scientific E notation + let val = parse_redis_value(b",2.67923e+8\r\n").unwrap(); + assert_eq!(val, Value::Double(267923000.0)); + let val = parse_redis_value(b",2.67923E+8\r\n").unwrap(); + assert_eq!(val, Value::Double(267923000.0)); + let val = parse_redis_value(b",-2.67923E+8\r\n").unwrap(); + assert_eq!(val, Value::Double(-267923000.0)); + let val = parse_redis_value(b",2.1E-2\r\n").unwrap(); + assert_eq!(val, Value::Double(0.021)); + + let val = parse_redis_value(b",-inf\r\n").unwrap(); + assert_eq!(val, Value::Double(-f64::INFINITY)); + let val = parse_redis_value(b",inf\r\n").unwrap(); + assert_eq!(val, Value::Double(f64::INFINITY)); + } + + #[test] + fn decode_resp3_map() { + let val = parse_redis_value(b"%2\r\n+first\r\n:1\r\n+second\r\n:2\r\n").unwrap(); + let mut v = val.as_map_iter().unwrap(); + assert_eq!( + (&Value::SimpleString("first".to_string()), &Value::Int(1)), + v.next().unwrap() + ); + assert_eq!( + (&Value::SimpleString("second".to_string()), &Value::Int(2)), + v.next().unwrap() + ); + } + + #[test] + fn decode_resp3_boolean() { + let val = parse_redis_value(b"#t\r\n").unwrap(); + assert_eq!(val, Value::Boolean(true)); + let val = parse_redis_value(b"#f\r\n").unwrap(); + assert_eq!(val, Value::Boolean(false)); + let val = parse_redis_value(b"#x\r\n"); + assert!(val.is_err()); + let val = parse_redis_value(b"#\r\n"); + assert!(val.is_err()); + } + + #[test] + fn decode_resp3_blob_error() { + let val = parse_redis_value(b"!21\r\nSYNTAX invalid syntax\r\n"); + assert_eq!( + val.err(), + Some(make_extension_error( + "SYNTAX".to_string(), + Some("invalid syntax".to_string()) + )) + ) + } + + #[test] + fn decode_resp3_big_number() { + let val = parse_redis_value(b"(3492890328409238509324850943850943825024385\r\n").unwrap(); + assert_eq!( + val, + Value::BigNumber( + BigInt::parse_bytes(b"3492890328409238509324850943850943825024385", 10).unwrap() + ) + ); + } + + #[test] + fn decode_resp3_set() { + let val = parse_redis_value(b"~5\r\n+orange\r\n+apple\r\n#t\r\n:100\r\n:999\r\n").unwrap(); + let v = val.as_sequence().unwrap(); + assert_eq!(Value::SimpleString("orange".to_string()), v[0]); + assert_eq!(Value::SimpleString("apple".to_string()), v[1]); + assert_eq!(Value::Boolean(true), v[2]); + assert_eq!(Value::Int(100), v[3]); + assert_eq!(Value::Int(999), v[4]); + } + + #[test] + fn decode_resp3_push() { + let val = parse_redis_value(b">3\r\n+message\r\n+somechannel\r\n+this is the message\r\n") + .unwrap(); + if let Value::Push { ref kind, ref data } = val { + assert_eq!(&PushKind::Message, kind); + assert_eq!(Value::SimpleString("somechannel".to_string()), data[0]); + assert_eq!( + Value::SimpleString("this is the message".to_string()), + data[1] + ); + } else { + panic!("Expected Value::Push") + } + } + + #[test] + fn test_max_recursion_depth() { + let bytes = b"*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n*1\r\n"; + match parse_redis_value(bytes) { + Ok(_) => panic!("Expected Err"), + Err(e) => assert!(matches!(e.kind(), ErrorKind::ParseError)), + } + } +} diff --git a/glide-core/redis-rs/redis/src/pipeline.rs b/glide-core/redis-rs/redis/src/pipeline.rs new file mode 100644 index 0000000000..babb57a1ff --- /dev/null +++ b/glide-core/redis-rs/redis/src/pipeline.rs @@ -0,0 +1,324 @@ +#![macro_use] + +use crate::cmd::{cmd, cmd_len, Cmd}; +use crate::connection::ConnectionLike; +use crate::types::{ + from_owned_redis_value, ErrorKind, FromRedisValue, HashSet, RedisResult, ToRedisArgs, Value, +}; + +/// Represents a redis command pipeline. +#[derive(Clone)] +pub struct Pipeline { + commands: Vec, + transaction_mode: bool, + ignored_commands: HashSet, +} + +/// A pipeline allows you to send multiple commands in one go to the +/// redis server. API wise it's very similar to just using a command +/// but it allows multiple commands to be chained and some features such +/// as iteration are not available. +/// +/// Basic example: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let ((k1, k2),) : ((i32, i32),) = redis::pipe() +/// .cmd("SET").arg("key_1").arg(42).ignore() +/// .cmd("SET").arg("key_2").arg(43).ignore() +/// .cmd("MGET").arg(&["key_1", "key_2"]).query(&mut con).unwrap(); +/// ``` +/// +/// As you can see with `cmd` you can start a new command. By default +/// each command produces a value but for some you can ignore them by +/// calling `ignore` on the command. That way it will be skipped in the +/// return value which is useful for `SET` commands and others, which +/// do not have a useful return value. +impl Pipeline { + /// Creates an empty pipeline. For consistency with the `cmd` + /// api a `pipe` function is provided as alias. + pub fn new() -> Pipeline { + Self::with_capacity(0) + } + + /// Creates an empty pipeline with pre-allocated capacity. + pub fn with_capacity(capacity: usize) -> Pipeline { + Pipeline { + commands: Vec::with_capacity(capacity), + transaction_mode: false, + ignored_commands: HashSet::new(), + } + } + + /// This enables atomic mode. In atomic mode the whole pipeline is + /// enclosed in `MULTI`/`EXEC`. From the user's point of view nothing + /// changes however. This is easier than using `MULTI`/`EXEC` yourself + /// as the format does not change. + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let (k1, k2) : (i32, i32) = redis::pipe() + /// .atomic() + /// .cmd("GET").arg("key_1") + /// .cmd("GET").arg("key_2").query(&mut con).unwrap(); + /// ``` + #[inline] + pub fn atomic(&mut self) -> &mut Pipeline { + self.transaction_mode = true; + self + } + + /// Returns the encoded pipeline commands. + pub fn get_packed_pipeline(&self) -> Vec { + encode_pipeline(&self.commands, self.transaction_mode) + } + + #[cfg(feature = "aio")] + pub(crate) fn write_packed_pipeline(&self, out: &mut Vec) { + write_pipeline(out, &self.commands, self.transaction_mode) + } + + fn execute_pipelined(&self, con: &mut dyn ConnectionLike) -> RedisResult { + Ok(self.make_pipeline_results(con.req_packed_commands( + &encode_pipeline(&self.commands, false), + 0, + self.commands.len(), + )?)) + } + + fn execute_transaction(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let mut resp = con.req_packed_commands( + &encode_pipeline(&self.commands, true), + self.commands.len() + 1, + 1, + )?; + match resp.pop() { + Some(Value::Nil) => Ok(Value::Nil), + Some(Value::Array(items)) => Ok(self.make_pipeline_results(items)), + _ => fail!(( + ErrorKind::ResponseError, + "Invalid response when parsing multi response" + )), + } + } + + /// Executes the pipeline and fetches the return values. Since most + /// pipelines return different types it's recommended to use tuple + /// matching to process the results: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let (k1, k2) : (i32, i32) = redis::pipe() + /// .cmd("SET").arg("key_1").arg(42).ignore() + /// .cmd("SET").arg("key_2").arg(43).ignore() + /// .cmd("GET").arg("key_1") + /// .cmd("GET").arg("key_2").query(&mut con).unwrap(); + /// ``` + /// + /// NOTE: A Pipeline object may be reused after `query()` with all the commands as were inserted + /// to them. In order to clear a Pipeline object with minimal memory released/allocated, + /// it is necessary to call the `clear()` before inserting new commands. + #[inline] + pub fn query(&self, con: &mut dyn ConnectionLike) -> RedisResult { + if !con.supports_pipelining() { + fail!(( + ErrorKind::ResponseError, + "This connection does not support pipelining." + )); + } + from_owned_redis_value(if self.commands.is_empty() { + Value::Array(vec![]) + } else if self.transaction_mode { + self.execute_transaction(con)? + } else { + self.execute_pipelined(con)? + }) + } + + #[cfg(feature = "aio")] + async fn execute_pipelined_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let value = con + .req_packed_commands(self, 0, self.commands.len()) + .await?; + Ok(self.make_pipeline_results(value)) + } + + #[cfg(feature = "aio")] + async fn execute_transaction_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let mut resp = con + .req_packed_commands(self, self.commands.len() + 1, 1) + .await?; + match resp.pop() { + Some(Value::Nil) => Ok(Value::Nil), + Some(Value::Array(items)) => Ok(self.make_pipeline_results(items)), + _ => Err(( + ErrorKind::ResponseError, + "Invalid response when parsing multi response", + ) + .into()), + } + } + + /// Async version of `query`. + #[inline] + #[cfg(feature = "aio")] + pub async fn query_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let v = if self.commands.is_empty() { + return from_owned_redis_value(Value::Array(vec![])); + } else if self.transaction_mode { + self.execute_transaction_async(con).await? + } else { + self.execute_pipelined_async(con).await? + }; + from_owned_redis_value(v) + } + + /// This is a shortcut to `query()` that does not return a value and + /// will fail the task if the query of the pipeline fails. + /// + /// This is equivalent to a call of query like this: + /// + /// ```rust,no_run + /// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); + /// # let mut con = client.get_connection(None).unwrap(); + /// let _ : () = redis::pipe().cmd("PING").query(&mut con).unwrap(); + /// ``` + /// + /// NOTE: A Pipeline object may be reused after `query()` with all the commands as were inserted + /// to them. In order to clear a Pipeline object with minimal memory released/allocated, + /// it is necessary to call the `clear()` before inserting new commands. + #[inline] + pub fn execute(&self, con: &mut dyn ConnectionLike) { + self.query::<()>(con).unwrap(); + } +} + +fn encode_pipeline(cmds: &[Cmd], atomic: bool) -> Vec { + let mut rv = vec![]; + write_pipeline(&mut rv, cmds, atomic); + rv +} + +fn write_pipeline(rv: &mut Vec, cmds: &[Cmd], atomic: bool) { + let cmds_len = cmds.iter().map(cmd_len).sum(); + + if atomic { + let multi = cmd("MULTI"); + let exec = cmd("EXEC"); + rv.reserve(cmd_len(&multi) + cmd_len(&exec) + cmds_len); + + multi.write_packed_command_preallocated(rv); + for cmd in cmds { + cmd.write_packed_command_preallocated(rv); + } + exec.write_packed_command_preallocated(rv); + } else { + rv.reserve(cmds_len); + + for cmd in cmds { + cmd.write_packed_command_preallocated(rv); + } + } +} + +// Macro to implement shared methods between Pipeline and ClusterPipeline +macro_rules! implement_pipeline_commands { + ($struct_name:ident) => { + impl $struct_name { + /// Adds a command to the cluster pipeline. + #[inline] + pub fn add_command(&mut self, cmd: Cmd) -> &mut Self { + self.commands.push(cmd); + self + } + + /// Starts a new command. Functions such as `arg` then become + /// available to add more arguments to that command. + #[inline] + pub fn cmd(&mut self, name: &str) -> &mut Self { + self.add_command(cmd(name)) + } + + /// Returns an iterator over all the commands currently in this pipeline + pub fn cmd_iter(&self) -> impl Iterator { + self.commands.iter() + } + + /// Instructs the pipeline to ignore the return value of this command. + /// It will still be ensured that it is not an error, but any successful + /// result is just thrown away. This makes result processing through + /// tuples much easier because you do not need to handle all the items + /// you do not care about. + #[inline] + pub fn ignore(&mut self) -> &mut Self { + match self.commands.len() { + 0 => true, + x => self.ignored_commands.insert(x - 1), + }; + self + } + + /// Adds an argument to the last started command. This works similar + /// to the `arg` method of the `Cmd` object. + /// + /// Note that this function fails the task if executed on an empty pipeline. + #[inline] + pub fn arg(&mut self, arg: T) -> &mut Self { + { + let cmd = self.get_last_command(); + cmd.arg(arg); + } + self + } + + /// Clear a pipeline object's internal data structure. + /// + /// This allows reusing a pipeline object as a clear object while performing a minimal + /// amount of memory released/reallocated. + #[inline] + pub fn clear(&mut self) { + self.commands.clear(); + self.ignored_commands.clear(); + } + + #[inline] + fn get_last_command(&mut self) -> &mut Cmd { + let idx = match self.commands.len() { + 0 => panic!("No command on stack"), + x => x - 1, + }; + &mut self.commands[idx] + } + + fn make_pipeline_results(&self, resp: Vec) -> Value { + let mut rv = Vec::with_capacity(resp.len() - self.ignored_commands.len()); + for (idx, result) in resp.into_iter().enumerate() { + if !self.ignored_commands.contains(&idx) { + rv.push(result); + } + } + Value::Array(rv) + } + } + + impl Default for $struct_name { + fn default() -> Self { + Self::new() + } + } + }; +} + +implement_pipeline_commands!(Pipeline); diff --git a/glide-core/redis-rs/redis/src/push_manager.rs b/glide-core/redis-rs/redis/src/push_manager.rs new file mode 100644 index 0000000000..8a22e06a57 --- /dev/null +++ b/glide-core/redis-rs/redis/src/push_manager.rs @@ -0,0 +1,234 @@ +use crate::{PushKind, RedisResult, Value}; +use arc_swap::ArcSwap; +use std::sync::Arc; +use tokio::sync::mpsc; + +/// Holds information about received Push data +#[derive(Debug, Clone)] +pub struct PushInfo { + /// Push Kind + pub kind: PushKind, + /// Data from push message + pub data: Vec, +} + +/// Manages Push messages for single tokio channel +#[derive(Clone, Default)] +pub struct PushManager { + sender: Arc>>>, +} +impl PushManager { + /// It checks if value's type is Push + /// then invokes `try_send_raw` method + pub(crate) fn try_send(&self, value: &RedisResult) { + if let Ok(value) = &value { + self.try_send_raw(value); + } + } + + /// It checks if value's type is Push and there is a provided sender + /// then creates PushInfo and invokes `send` method of sender + pub(crate) fn try_send_raw(&self, value: &Value) { + if let Value::Push { kind, data } = value { + let guard = self.sender.load(); + if let Some(sender) = guard.as_ref() { + let push_info = PushInfo { + kind: kind.clone(), + data: data.clone(), + }; + if sender.send(push_info).is_err() { + self.sender.compare_and_swap(guard, Arc::new(None)); + } + } + } + } + /// Replace mpsc channel of `PushManager` with provided sender. + pub fn replace_sender(&self, sender: mpsc::UnboundedSender) { + self.sender.store(Arc::new(Some(sender))); + } + + /// Creates new `PushManager` + pub fn new() -> Self { + PushManager { + sender: Arc::from(ArcSwap::from(Arc::new(None))), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_send_and_receive_push_info() { + let push_manager = PushManager::new(); + let (tx, mut rx) = mpsc::unbounded_channel(); + push_manager.replace_sender(tx); + + let value = Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::BulkString("hello".to_string().into_bytes())], + }); + + push_manager.try_send(&value); + + let push_info = rx.try_recv().unwrap(); + assert_eq!(push_info.kind, PushKind::Message); + assert_eq!( + push_info.data, + vec![Value::BulkString("hello".to_string().into_bytes())] + ); + } + #[test] + fn test_push_manager_receiver_dropped() { + let push_manager = PushManager::new(); + let (tx, rx) = mpsc::unbounded_channel(); + push_manager.replace_sender(tx); + + let value = Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::BulkString("hello".to_string().into_bytes())], + }); + + drop(rx); + + push_manager.try_send(&value); + push_manager.try_send(&value); + push_manager.try_send(&value); + } + #[test] + fn test_push_manager_without_sender() { + let push_manager = PushManager::new(); + + push_manager.try_send(&Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::BulkString("hello".to_string().into_bytes())], + })); // nothing happens! + + let (tx, mut rx) = mpsc::unbounded_channel(); + push_manager.replace_sender(tx); + push_manager.try_send(&Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::BulkString("hello2".to_string().into_bytes())], + })); + + assert_eq!( + rx.try_recv().unwrap().data, + vec![Value::BulkString("hello2".to_string().into_bytes())] + ); + } + #[test] + fn test_push_manager_multiple_channels_and_messages() { + let push_manager = PushManager::new(); + let (tx1, mut rx1) = mpsc::unbounded_channel(); + let (tx2, mut rx2) = mpsc::unbounded_channel(); + push_manager.replace_sender(tx1); + + let value1 = Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::Int(1)], + }); + + let value2 = Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::Int(2)], + }); + + push_manager.try_send(&value1); + push_manager.try_send(&value2); + + assert_eq!(rx1.try_recv().unwrap().data, vec![Value::Int(1)]); + assert_eq!(rx1.try_recv().unwrap().data, vec![Value::Int(2)]); + + push_manager.replace_sender(tx2); + // make sure rx1 is disconnected after replacing tx1 with tx2. + assert_eq!( + rx1.try_recv().err().unwrap(), + mpsc::error::TryRecvError::Disconnected + ); + + push_manager.try_send(&value1); + push_manager.try_send(&value2); + + assert_eq!(rx2.try_recv().unwrap().data, vec![Value::Int(1)]); + assert_eq!(rx2.try_recv().unwrap().data, vec![Value::Int(2)]); + } + + #[tokio::test] + async fn test_push_manager_multi_threaded() { + // In this test we create 4 channels and send 1000 message, it switchs channels for each message we sent. + // Then we check if all messages are received and sum of messages are equal to expected sum. + // We also check if all channels are used. + let push_manager = PushManager::new(); + let (tx1, mut rx1) = mpsc::unbounded_channel(); + let (tx2, mut rx2) = mpsc::unbounded_channel(); + let (tx3, mut rx3) = mpsc::unbounded_channel(); + let (tx4, mut rx4) = mpsc::unbounded_channel(); + + let mut handles = vec![]; + let txs = [tx1, tx2, tx3, tx4]; + let mut expected_sum = 0; + for i in 0..1000 { + expected_sum += i; + let push_manager_clone = push_manager.clone(); + let new_tx = txs[(i % 4) as usize].clone(); + let value = Ok(Value::Push { + kind: PushKind::Message, + data: vec![Value::Int(i)], + }); + let handle = tokio::spawn(async move { + push_manager_clone.replace_sender(new_tx); + push_manager_clone.try_send(&value); + }); + handles.push(handle); + } + + for handle in handles { + handle.await.unwrap(); + } + + let mut count1 = 0; + let mut count2 = 0; + let mut count3 = 0; + let mut count4 = 0; + let mut received_sum = 0; + while let Ok(push_info) = rx1.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count1 += 1; + } + while let Ok(push_info) = rx2.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count2 += 1; + } + + while let Ok(push_info) = rx3.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count3 += 1; + } + + while let Ok(push_info) = rx4.try_recv() { + assert_eq!(push_info.kind, PushKind::Message); + if let Value::Int(i) = push_info.data[0] { + received_sum += i; + } + count4 += 1; + } + + assert_ne!(count1, 0); + assert_ne!(count2, 0); + assert_ne!(count3, 0); + assert_ne!(count4, 0); + + assert_eq!(count1 + count2 + count3 + count4, 1000); + assert_eq!(received_sum, expected_sum); + } +} diff --git a/glide-core/redis-rs/redis/src/r2d2.rs b/glide-core/redis-rs/redis/src/r2d2.rs new file mode 100644 index 0000000000..e34d2c7bb9 --- /dev/null +++ b/glide-core/redis-rs/redis/src/r2d2.rs @@ -0,0 +1,36 @@ +use std::io; + +use crate::{ConnectionLike, RedisError}; + +macro_rules! impl_manage_connection { + ($client:ty, $connection:ty) => { + impl r2d2::ManageConnection for $client { + type Connection = $connection; + type Error = RedisError; + + fn connect(&self) -> Result { + self.get_connection(None) + } + + fn is_valid(&self, conn: &mut Self::Connection) -> Result<(), Self::Error> { + if conn.check_connection() { + Ok(()) + } else { + Err(RedisError::from(io::Error::from(io::ErrorKind::BrokenPipe))) + } + } + + fn has_broken(&self, conn: &mut Self::Connection) -> bool { + !conn.is_open() + } + } + }; +} + +impl_manage_connection!(crate::Client, crate::Connection); + +#[cfg(feature = "cluster")] +impl_manage_connection!( + crate::cluster::ClusterClient, + crate::cluster::ClusterConnection +); diff --git a/glide-core/redis-rs/redis/src/script.rs b/glide-core/redis-rs/redis/src/script.rs new file mode 100644 index 0000000000..c62d2344ae --- /dev/null +++ b/glide-core/redis-rs/redis/src/script.rs @@ -0,0 +1,255 @@ +#![cfg(feature = "script")] +use sha1_smol::Sha1; + +use crate::cmd::cmd; +use crate::connection::ConnectionLike; +use crate::types::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs}; +use crate::Cmd; + +/// Represents a lua script. +#[derive(Debug, Clone)] +pub struct Script { + code: String, + hash: String, +} + +/// The script object represents a lua script that can be executed on the +/// redis server. The object itself takes care of automatic uploading and +/// execution. The script object itself can be shared and is immutable. +/// +/// Example: +/// +/// ```rust,no_run +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let script = redis::Script::new(r" +/// return tonumber(ARGV[1]) + tonumber(ARGV[2]); +/// "); +/// let result = script.arg(1).arg(2).invoke(&mut con); +/// assert_eq!(result, Ok(3)); +/// ``` +impl Script { + /// Creates a new script object. + pub fn new(code: &str) -> Script { + let mut hash = Sha1::new(); + hash.update(code.as_bytes()); + Script { + code: code.to_string(), + hash: hash.digest().to_string(), + } + } + + /// Returns the script's SHA1 hash in hexadecimal format. + pub fn get_hash(&self) -> &str { + &self.hash + } + + /// Creates a script invocation object with a key filled in. + #[inline] + pub fn key(&self, key: T) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: vec![], + keys: key.to_redis_args(), + } + } + + /// Creates a script invocation object with an argument filled in. + #[inline] + pub fn arg(&self, arg: T) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: arg.to_redis_args(), + keys: vec![], + } + } + + /// Returns an empty script invocation object. This is primarily useful + /// for programmatically adding arguments and keys because the type will + /// not change. Normally you can use `arg` and `key` directly. + #[inline] + pub fn prepare_invoke(&self) -> ScriptInvocation<'_> { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + } + + /// Invokes the script directly without arguments. + #[inline] + pub fn invoke(&self, con: &mut dyn ConnectionLike) -> RedisResult { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + .invoke(con) + } + + /// Asynchronously invokes the script without arguments. + #[inline] + #[cfg(feature = "aio")] + pub async fn invoke_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + T: FromRedisValue, + { + ScriptInvocation { + script: self, + args: vec![], + keys: vec![], + } + .invoke_async(con) + .await + } +} + +/// Represents a prepared script call. +pub struct ScriptInvocation<'a> { + script: &'a Script, + args: Vec>, + keys: Vec>, +} + +/// This type collects keys and other arguments for the script so that it +/// can be then invoked. While the `Script` type itself holds the script, +/// the `ScriptInvocation` holds the arguments that should be invoked until +/// it's sent to the server. +impl<'a> ScriptInvocation<'a> { + /// Adds a regular argument to the invocation. This ends up as `ARGV[i]` + /// in the script. + #[inline] + pub fn arg<'b, T: ToRedisArgs>(&'b mut self, arg: T) -> &'b mut ScriptInvocation<'a> + where + 'a: 'b, + { + arg.write_redis_args(&mut self.args); + self + } + + /// Adds a key argument to the invocation. This ends up as `KEYS[i]` + /// in the script. + #[inline] + pub fn key<'b, T: ToRedisArgs>(&'b mut self, key: T) -> &'b mut ScriptInvocation<'a> + where + 'a: 'b, + { + key.write_redis_args(&mut self.keys); + self + } + + /// Invokes the script and returns the result. + #[inline] + pub fn invoke(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let eval_cmd = self.eval_cmd(); + match eval_cmd.query(con) { + Ok(val) => Ok(val), + Err(err) => { + if err.kind() == ErrorKind::NoScriptError { + self.load_cmd().query(con)?; + eval_cmd.query(con) + } else { + Err(err) + } + } + } + } + + /// Asynchronously invokes the script and returns the result. + #[inline] + #[cfg(feature = "aio")] + pub async fn invoke_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + T: FromRedisValue, + { + let eval_cmd = self.eval_cmd(); + match eval_cmd.query_async(con).await { + Ok(val) => { + // Return the value from the script evaluation + Ok(val) + } + Err(err) => { + // Load the script into Redis if the script hash wasn't there already + if err.kind() == ErrorKind::NoScriptError { + self.load_cmd().query_async(con).await?; + eval_cmd.query_async(con).await + } else { + Err(err) + } + } + } + } + + /// Loads the script and returns the SHA1 of it. + #[inline] + pub fn load(&self, con: &mut dyn ConnectionLike) -> RedisResult { + let hash: String = self.load_cmd().query(con)?; + + debug_assert_eq!(hash, self.script.hash); + + Ok(hash) + } + + /// Asynchronously loads the script and returns the SHA1 of it. + #[inline] + #[cfg(feature = "aio")] + pub async fn load_async(&self, con: &mut C) -> RedisResult + where + C: crate::aio::ConnectionLike, + { + let hash: String = self.load_cmd().query_async(con).await?; + + debug_assert_eq!(hash, self.script.hash); + + Ok(hash) + } + + fn load_cmd(&self) -> Cmd { + let mut cmd = cmd("SCRIPT"); + cmd.arg("LOAD").arg(self.script.code.as_bytes()); + cmd + } + + fn estimate_buflen(&self) -> usize { + self + .keys + .iter() + .chain(self.args.iter()) + .fold(0, |acc, e| acc + e.len()) + + 7 /* "EVALSHA".len() */ + + self.script.hash.len() + + 4 /* Slots reserved for the length of keys. */ + } + + fn eval_cmd(&self) -> Cmd { + let args_len = 3 + self.keys.len() + self.args.len(); + let mut cmd = Cmd::with_capacity(args_len, self.estimate_buflen()); + cmd.arg("EVALSHA") + .arg(self.script.hash.as_bytes()) + .arg(self.keys.len()) + .arg(&*self.keys) + .arg(&*self.args); + cmd + } +} + +#[cfg(test)] +mod tests { + use super::Script; + + #[test] + fn script_eval_should_work() { + let script = Script::new("return KEYS[1]"); + let invocation = script.key("dummy"); + let estimated_buflen = invocation.estimate_buflen(); + let cmd = invocation.eval_cmd(); + assert!(estimated_buflen >= cmd.capacity().1); + let expected = "*4\r\n$7\r\nEVALSHA\r\n$40\r\n4a2267357833227dd98abdedb8cf24b15a986445\r\n$1\r\n1\r\n$5\r\ndummy\r\n"; + assert_eq!( + expected, + std::str::from_utf8(cmd.get_packed_command().as_slice()).unwrap() + ); + } +} diff --git a/glide-core/redis-rs/redis/src/sentinel.rs b/glide-core/redis-rs/redis/src/sentinel.rs new file mode 100644 index 0000000000..ac6aac65cc --- /dev/null +++ b/glide-core/redis-rs/redis/src/sentinel.rs @@ -0,0 +1,778 @@ +//! Defines a Sentinel type that connects to Redis sentinels and creates clients to +//! master or replica nodes. +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::sentinel::Sentinel; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut sentinel = Sentinel::build(nodes).unwrap(); +//! let mut master = sentinel.master_for("master_name", None).unwrap().get_connection(None).unwrap(); +//! let mut replica = sentinel.replica_for("master_name", None).unwrap().get_connection(None).unwrap(); +//! +//! let _: () = master.set("test", "test_data").unwrap(); +//! let rv: String = replica.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! There is also a SentinelClient which acts like a regular Client, providing the +//! `get_connection` and `get_async_connection` methods, internally using the Sentinel +//! type to create clients on demand for the desired node type (Master or Replica). +//! +//! # Example +//! ```rust,no_run +//! use redis::Commands; +//! use redis::sentinel::{ SentinelServerType, SentinelClient }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut master_client = SentinelClient::build(nodes.clone(), String::from("master_name"), None, SentinelServerType::Master).unwrap(); +//! let mut replica_client = SentinelClient::build(nodes, String::from("master_name"), None, SentinelServerType::Replica).unwrap(); +//! let mut master_conn = master_client.get_connection().unwrap(); +//! let mut replica_conn = replica_client.get_connection().unwrap(); +//! +//! let _: () = master_conn.set("test", "test_data").unwrap(); +//! let rv: String = replica_conn.get("test").unwrap(); +//! +//! assert_eq!(rv, "test_data"); +//! ``` +//! +//! If the sentinel's nodes are using TLS or require authentication, a full +//! SentinelNodeConnectionInfo struct may be used instead of just the master's name: +//! +//! # Example +//! ```rust,no_run +//! use redis::{ Commands, RedisConnectionInfo }; +//! use redis::sentinel::{ Sentinel, SentinelNodeConnectionInfo }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut sentinel = Sentinel::build(nodes).unwrap(); +//! +//! let mut master_with_auth = sentinel +//! .master_for( +//! "master_name", +//! Some(&SentinelNodeConnectionInfo { +//! tls_mode: None, +//! redis_connection_info: Some(RedisConnectionInfo { +//! db: 1, +//! username: Some(String::from("foo")), +//! password: Some(String::from("bar")), +//! ..Default::default() +//! }), +//! }), +//! ) +//! .unwrap() +//! .get_connection(None) +//! .unwrap(); +//! +//! let mut replica_with_tls = sentinel +//! .master_for( +//! "master_name", +//! Some(&SentinelNodeConnectionInfo { +//! tls_mode: Some(redis::TlsMode::Secure), +//! redis_connection_info: None, +//! }), +//! ) +//! .unwrap() +//! .get_connection(None) +//! .unwrap(); +//! ``` +//! +//! # Example +//! ```rust,no_run +//! use redis::{ Commands, RedisConnectionInfo }; +//! use redis::sentinel::{ SentinelServerType, SentinelClient, SentinelNodeConnectionInfo }; +//! +//! let nodes = vec!["redis://127.0.0.1:6379/", "redis://127.0.0.1:6378/", "redis://127.0.0.1:6377/"]; +//! let mut master_client = SentinelClient::build( +//! nodes, +//! String::from("master1"), +//! Some(SentinelNodeConnectionInfo { +//! tls_mode: Some(redis::TlsMode::Insecure), +//! redis_connection_info: Some(RedisConnectionInfo { +//! username: Some(String::from("user")), +//! password: Some(String::from("pass")), +//! ..Default::default() +//! }), +//! }), +//! redis::sentinel::SentinelServerType::Master, +//! ) +//! .unwrap(); +//! ``` +//! + +use std::{collections::HashMap, num::NonZeroUsize}; + +#[cfg(feature = "aio")] +use futures_util::StreamExt; +use rand::Rng; + +#[cfg(feature = "aio")] +use crate::aio::MultiplexedConnection as AsyncConnection; + +use crate::{ + client::GlideConnectionOptions, connection::ConnectionInfo, types::RedisResult, Client, Cmd, + Connection, ErrorKind, FromRedisValue, IntoConnectionInfo, RedisConnectionInfo, TlsMode, Value, +}; + +/// The Sentinel type, serves as a special purpose client which builds other clients on +/// demand. +pub struct Sentinel { + sentinels_connection_info: Vec, + connections_cache: Vec>, + #[cfg(feature = "aio")] + async_connections_cache: Vec>, + replica_start_index: usize, +} + +/// Holds the connection information that a sentinel should use when connecting to the +/// servers (masters and replicas) belonging to it. +#[derive(Clone, Default)] +pub struct SentinelNodeConnectionInfo { + /// The TLS mode of the connection, or None if we do not want to connect using TLS + /// (just a plain TCP connection). + pub tls_mode: Option, + + /// The Redis specific/connection independent information to be used. + pub redis_connection_info: Option, +} + +impl SentinelNodeConnectionInfo { + fn create_connection_info(&self, ip: String, port: u16) -> ConnectionInfo { + let addr = match self.tls_mode { + None => crate::ConnectionAddr::Tcp(ip, port), + Some(TlsMode::Secure) => crate::ConnectionAddr::TcpTls { + host: ip, + port, + insecure: false, + tls_params: None, + }, + Some(TlsMode::Insecure) => crate::ConnectionAddr::TcpTls { + host: ip, + port, + insecure: true, + tls_params: None, + }, + }; + + ConnectionInfo { + addr, + redis: self.redis_connection_info.clone().unwrap_or_default(), + } + } +} + +impl Default for &SentinelNodeConnectionInfo { + fn default() -> Self { + static DEFAULT_VALUE: SentinelNodeConnectionInfo = SentinelNodeConnectionInfo { + tls_mode: None, + redis_connection_info: None, + }; + &DEFAULT_VALUE + } +} + +fn sentinel_masters_cmd() -> crate::Cmd { + let mut cmd = crate::cmd("SENTINEL"); + cmd.arg("MASTERS"); + cmd +} + +fn sentinel_replicas_cmd(master_name: &str) -> crate::Cmd { + let mut cmd = crate::cmd("SENTINEL"); + cmd.arg("SLAVES"); // For compatibility with older redis versions + cmd.arg(master_name); + cmd +} + +fn is_master_valid(master_info: &HashMap, service_name: &str) -> bool { + master_info.get("name").map(|s| s.as_str()) == Some(service_name) + && master_info.contains_key("ip") + && master_info.contains_key("port") + && master_info.get("flags").map_or(false, |flags| { + flags.contains("master") && !flags.contains("s_down") && !flags.contains("o_down") + }) + && master_info["port"].parse::().is_ok() +} + +fn is_replica_valid(replica_info: &HashMap) -> bool { + replica_info.contains_key("ip") + && replica_info.contains_key("port") + && replica_info.get("flags").map_or(false, |flags| { + !flags.contains("s_down") && !flags.contains("o_down") + }) + && replica_info["port"].parse::().is_ok() +} + +/// Generates a random value in the 0..max range. +fn random_replica_index(max: NonZeroUsize) -> usize { + rand::thread_rng().gen_range(0..max.into()) +} + +fn try_connect_to_first_replica( + addresses: &[ConnectionInfo], + start_index: Option, +) -> Result { + if addresses.is_empty() { + fail!(( + ErrorKind::NoValidReplicasFoundBySentinel, + "No valid replica found in sentinel for given name", + )); + } + + let start_index = start_index.unwrap_or(0); + + let mut last_err = None; + for i in 0..addresses.len() { + let index = (i + start_index) % addresses.len(); + match Client::open(addresses[index].clone()) { + Ok(client) => return Ok(client), + Err(err) => last_err = Some(err), + } + } + + // We can unwrap here because we know there is at least one error, since there is at + // least one address, so we'll either return a client for it or store an error in + // last_err. + Err(last_err.expect("There should be an error because there is should be at least one address")) +} + +fn valid_addrs<'a>( + servers_info: Vec>, + validate: impl Fn(&HashMap) -> bool + 'a, +) -> impl Iterator { + servers_info + .into_iter() + .filter(move |info| validate(info)) + .map(|mut info| { + // We can unwrap here because we already checked everything + let ip = info.remove("ip").unwrap(); + let port = info["port"].parse::().unwrap(); + (ip, port) + }) +} + +fn check_role_result(result: &RedisResult>, target_role: &str) -> bool { + if let Ok(values) = result { + if !values.is_empty() { + if let Ok(role) = String::from_redis_value(&values[0]) { + return role.to_ascii_lowercase() == target_role; + } + } + } + false +} + +fn check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool { + if let Ok(client) = Client::open(connection_info.clone()) { + if let Ok(mut conn) = client.get_connection(None) { + let result: RedisResult> = crate::cmd("ROLE").query(&mut conn); + return check_role_result(&result, target_role); + } + } + false +} + +/// Searches for a valid master with the given name in the list of masters returned by +/// a sentinel. A valid master is one which has a role of "master" (checked by running +/// the `ROLE` command and by seeing if its flags contains the "master" flag) and which +/// does not have the flags s_down or o_down set to it (these flags are returned by the +/// `SENTINEL MASTERS` command, and we expect the `masters` parameter to be the result of +/// that command). +fn find_valid_master( + masters: Vec>, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, +) -> RedisResult { + for (ip, port) in valid_addrs(masters, |m| is_master_valid(m, service_name)) { + let connection_info = node_connection_info.create_connection_info(ip, port); + if check_role(&connection_info, "master") { + return Ok(connection_info); + } + } + + fail!(( + ErrorKind::MasterNameNotFoundBySentinel, + "Master with given name not found in sentinel", + )) +} + +#[cfg(feature = "aio")] +async fn async_check_role(connection_info: &ConnectionInfo, target_role: &str) -> bool { + if let Ok(client) = Client::open(connection_info.clone()) { + if let Ok(mut conn) = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + { + let result: RedisResult> = crate::cmd("ROLE").query_async(&mut conn).await; + return check_role_result(&result, target_role); + } + } + false +} + +/// Async version of [find_valid_master]. +#[cfg(feature = "aio")] +async fn async_find_valid_master( + masters: Vec>, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, +) -> RedisResult { + for (ip, port) in valid_addrs(masters, |m| is_master_valid(m, service_name)) { + let connection_info = node_connection_info.create_connection_info(ip, port); + if async_check_role(&connection_info, "master").await { + return Ok(connection_info); + } + } + + fail!(( + ErrorKind::MasterNameNotFoundBySentinel, + "Master with given name not found in sentinel", + )) +} + +fn get_valid_replicas_addresses( + replicas: Vec>, + node_connection_info: &SentinelNodeConnectionInfo, +) -> Vec { + valid_addrs(replicas, is_replica_valid) + .map(|(ip, port)| node_connection_info.create_connection_info(ip, port)) + .filter(|connection_info| check_role(connection_info, "slave")) + .collect() +} + +#[cfg(feature = "aio")] +async fn async_get_valid_replicas_addresses<'a>( + replicas: Vec>, + node_connection_info: &SentinelNodeConnectionInfo, +) -> Vec { + async fn is_replica_role_valid(connection_info: ConnectionInfo) -> Option { + if async_check_role(&connection_info, "slave").await { + Some(connection_info) + } else { + None + } + } + + futures_util::stream::iter(valid_addrs(replicas, is_replica_valid)) + .map(|(ip, port)| node_connection_info.create_connection_info(ip, port)) + .filter_map(is_replica_role_valid) + .collect() + .await +} + +#[cfg(feature = "aio")] +async fn async_reconnect( + connection: &mut Option, + connection_info: &ConnectionInfo, +) -> RedisResult<()> { + let sentinel_client = Client::open(connection_info.clone())?; + let new_connection = sentinel_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + connection.replace(new_connection); + Ok(()) +} + +#[cfg(feature = "aio")] +async fn async_try_single_sentinel( + cmd: Cmd, + connection_info: &ConnectionInfo, + cached_connection: &mut Option, +) -> RedisResult { + if cached_connection.is_none() { + async_reconnect(cached_connection, connection_info).await?; + } + + let result = cmd.query_async(cached_connection.as_mut().unwrap()).await; + + if let Err(err) = result { + if err.is_unrecoverable_error() || err.is_io_error() { + async_reconnect(cached_connection, connection_info).await?; + cmd.query_async(cached_connection.as_mut().unwrap()).await + } else { + Err(err) + } + } else { + result + } +} + +fn reconnect( + connection: &mut Option, + connection_info: &ConnectionInfo, +) -> RedisResult<()> { + let sentinel_client = Client::open(connection_info.clone())?; + let new_connection = sentinel_client.get_connection(None)?; + connection.replace(new_connection); + Ok(()) +} + +fn try_single_sentinel( + cmd: Cmd, + connection_info: &ConnectionInfo, + cached_connection: &mut Option, +) -> RedisResult { + if cached_connection.is_none() { + reconnect(cached_connection, connection_info)?; + } + + let result = cmd.query(cached_connection.as_mut().unwrap()); + + if let Err(err) = result { + if err.is_unrecoverable_error() || err.is_io_error() { + reconnect(cached_connection, connection_info)?; + cmd.query(cached_connection.as_mut().unwrap()) + } else { + Err(err) + } + } else { + result + } +} + +// non-async methods +impl Sentinel { + /// Creates a Sentinel client performing some basic + /// checks on the URLs that might make the operation fail. + pub fn build(params: Vec) -> RedisResult { + if params.is_empty() { + fail!(( + ErrorKind::EmptySentinelList, + "At least one sentinel is required", + )) + } + + let sentinels_connection_info = params + .into_iter() + .map(|p| p.into_connection_info()) + .collect::>>()?; + + let mut connections_cache = vec![]; + connections_cache.resize_with(sentinels_connection_info.len(), Default::default); + + #[cfg(feature = "aio")] + { + let mut async_connections_cache = vec![]; + async_connections_cache.resize_with(sentinels_connection_info.len(), Default::default); + + Ok(Sentinel { + sentinels_connection_info, + connections_cache, + async_connections_cache, + replica_start_index: random_replica_index(NonZeroUsize::new(1000000).unwrap()), + }) + } + + #[cfg(not(feature = "aio"))] + { + Ok(Sentinel { + sentinels_connection_info, + connections_cache, + replica_start_index: random_replica_index(NonZeroUsize::new(1000000).unwrap()), + }) + } + } + + /// Try to execute the given command in each sentinel, returning the result of the + /// first one that executes without errors. If all return errors, we return the + /// error of the last attempt. + /// + /// For each sentinel, we first check if there is a cached connection, and if not + /// we attempt to connect to it (skipping that sentinel if there is an error during + /// the connection). Then, we attempt to execute the given command with the cached + /// connection. If there is an error indicating that the connection is invalid, we + /// reconnect and try one more time in the new connection. + /// + fn try_all_sentinels(&mut self, cmd: Cmd) -> RedisResult { + let mut last_err = None; + for (connection_info, cached_connection) in self + .sentinels_connection_info + .iter() + .zip(self.connections_cache.iter_mut()) + { + match try_single_sentinel(cmd.clone(), connection_info, cached_connection) { + Ok(result) => { + return Ok(result); + } + Err(err) => { + last_err = Some(err); + } + } + } + + // We can unwrap here because we know there is at least one connection info. + Err(last_err.expect("There should be at least one connection info")) + } + + /// Get a list of all masters (using the command SENTINEL MASTERS) from the + /// sentinels. + fn get_sentinel_masters(&mut self) -> RedisResult>> { + self.try_all_sentinels(sentinel_masters_cmd()) + } + + fn get_sentinel_replicas( + &mut self, + service_name: &str, + ) -> RedisResult>> { + self.try_all_sentinels(sentinel_replicas_cmd(service_name)) + } + + fn find_master_address( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult { + let masters = self.get_sentinel_masters()?; + find_valid_master(masters, service_name, node_connection_info) + } + + fn find_valid_replica_addresses( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult> { + let replicas = self.get_sentinel_replicas(service_name)?; + Ok(get_valid_replicas_addresses(replicas, node_connection_info)) + } + + /// Determines the masters address for the given name, and returns a client for that + /// master. + pub fn master_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let connection_info = + self.find_master_address(service_name, node_connection_info.unwrap_or_default())?; + Client::open(connection_info) + } + + /// Connects to a randomly chosen replica of the given master name. + pub fn replica_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .find_valid_replica_addresses(service_name, node_connection_info.unwrap_or_default())?; + let start_index = NonZeroUsize::new(addresses.len()).map(random_replica_index); + try_connect_to_first_replica(&addresses, start_index) + } + + /// Attempts to connect to a different replica of the given master name each time. + /// There is no guarantee that we'll actually be connecting to a different replica + /// in the next call, but in a static set of replicas (no replicas added or + /// removed), on average we'll choose each replica the same number of times. + pub fn replica_rotate_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .find_valid_replica_addresses(service_name, node_connection_info.unwrap_or_default())?; + if !addresses.is_empty() { + self.replica_start_index = (self.replica_start_index + 1) % addresses.len(); + } + try_connect_to_first_replica(&addresses, Some(self.replica_start_index)) + } +} + +// Async versions of the public methods above, along with async versions of private +// methods required for the public methods. +#[cfg(feature = "aio")] +impl Sentinel { + async fn async_try_all_sentinels(&mut self, cmd: Cmd) -> RedisResult { + let mut last_err = None; + for (connection_info, cached_connection) in self + .sentinels_connection_info + .iter() + .zip(self.async_connections_cache.iter_mut()) + { + match async_try_single_sentinel(cmd.clone(), connection_info, cached_connection).await { + Ok(result) => { + return Ok(result); + } + Err(err) => { + last_err = Some(err); + } + } + } + + // We can unwrap here because we know there is at least one connection info. + Err(last_err.expect("There should be at least one connection info")) + } + + async fn async_get_sentinel_masters(&mut self) -> RedisResult>> { + self.async_try_all_sentinels(sentinel_masters_cmd()).await + } + + async fn async_get_sentinel_replicas<'a>( + &mut self, + service_name: &'a str, + ) -> RedisResult>> { + self.async_try_all_sentinels(sentinel_replicas_cmd(service_name)) + .await + } + + async fn async_find_master_address<'a>( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult { + let masters = self.async_get_sentinel_masters().await?; + async_find_valid_master(masters, service_name, node_connection_info).await + } + + async fn async_find_valid_replica_addresses<'a>( + &mut self, + service_name: &str, + node_connection_info: &SentinelNodeConnectionInfo, + ) -> RedisResult> { + let replicas = self.async_get_sentinel_replicas(service_name).await?; + Ok(async_get_valid_replicas_addresses(replicas, node_connection_info).await) + } + + /// Determines the masters address for the given name, and returns a client for that + /// master. + pub async fn async_master_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let address = self + .async_find_master_address(service_name, node_connection_info.unwrap_or_default()) + .await?; + Client::open(address) + } + + /// Connects to a randomly chosen replica of the given master name. + pub async fn async_replica_for( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .async_find_valid_replica_addresses( + service_name, + node_connection_info.unwrap_or_default(), + ) + .await?; + let start_index = NonZeroUsize::new(addresses.len()).map(random_replica_index); + try_connect_to_first_replica(&addresses, start_index) + } + + /// Attempts to connect to a different replica of the given master name each time. + /// There is no guarantee that we'll actually be connecting to a different replica + /// in the next call, but in a static set of replicas (no replicas added or + /// removed), on average we'll choose each replica the same number of times. + pub async fn async_replica_rotate_for<'a>( + &mut self, + service_name: &str, + node_connection_info: Option<&SentinelNodeConnectionInfo>, + ) -> RedisResult { + let addresses = self + .async_find_valid_replica_addresses( + service_name, + node_connection_info.unwrap_or_default(), + ) + .await?; + if !addresses.is_empty() { + self.replica_start_index = (self.replica_start_index + 1) % addresses.len(); + } + try_connect_to_first_replica(&addresses, Some(self.replica_start_index)) + } +} + +/// Enum defining the server types from a sentinel's point of view. +#[derive(Debug, Clone)] +pub enum SentinelServerType { + /// Master connections only + Master, + /// Replica connections only + Replica, +} + +/// An alternative to the Client type which creates connections from clients created +/// on-demand based on information fetched from the sentinels. Uses the Sentinel type +/// internally. This is basic an utility to help make it easier to use sentinels but +/// with an interface similar to the client (`get_connection` and +/// `get_async_connection`). The type of server (master or replica) and name of the +/// desired master are specified when constructing an instance, so it will always +/// return connections to the same target (for example, always to the master with name +/// "mymaster123", or always to replicas of the master "another-master-abc"). +pub struct SentinelClient { + sentinel: Sentinel, + service_name: String, + node_connection_info: SentinelNodeConnectionInfo, + server_type: SentinelServerType, +} + +impl SentinelClient { + /// Creates a SentinelClient performing some basic checks on the URLs that might + /// result in an error. + pub fn build( + params: Vec, + service_name: String, + node_connection_info: Option, + server_type: SentinelServerType, + ) -> RedisResult { + Ok(SentinelClient { + sentinel: Sentinel::build(params)?, + service_name, + node_connection_info: node_connection_info.unwrap_or_default(), + server_type, + }) + } + + fn get_client(&mut self) -> RedisResult { + match self.server_type { + SentinelServerType::Master => self + .sentinel + .master_for(self.service_name.as_str(), Some(&self.node_connection_info)), + SentinelServerType::Replica => self + .sentinel + .replica_for(self.service_name.as_str(), Some(&self.node_connection_info)), + } + } + + /// Creates a new connection to the desired type of server (based on the + /// service/master name, and the server type). We use a Sentinel to create a client + /// for the target type of server, and then create a connection using that client. + pub fn get_connection(&mut self) -> RedisResult { + let client = self.get_client()?; + client.get_connection(None) + } +} + +/// To enable async support you need to chose one of the supported runtimes and active its +/// corresponding feature: `tokio-comp` or `async-std-comp` +#[cfg(feature = "aio")] +#[cfg_attr(docsrs, doc(cfg(feature = "aio")))] +impl SentinelClient { + async fn async_get_client(&mut self) -> RedisResult { + match self.server_type { + SentinelServerType::Master => { + self.sentinel + .async_master_for(self.service_name.as_str(), Some(&self.node_connection_info)) + .await + } + SentinelServerType::Replica => { + self.sentinel + .async_replica_for(self.service_name.as_str(), Some(&self.node_connection_info)) + .await + } + } + } + + /// Returns an async connection from the client, using the same logic from + /// `SentinelClient::get_connection`. + #[cfg(any(feature = "tokio-comp", feature = "async-std-comp"))] + pub async fn get_async_connection(&mut self) -> RedisResult { + let client = self.async_get_client().await?; + client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + } +} diff --git a/glide-core/redis-rs/redis/src/streams.rs b/glide-core/redis-rs/redis/src/streams.rs new file mode 100644 index 0000000000..62505d6d75 --- /dev/null +++ b/glide-core/redis-rs/redis/src/streams.rs @@ -0,0 +1,670 @@ +//! Defines types to use with the streams commands. + +use crate::{ + from_redis_value, types::HashMap, FromRedisValue, RedisResult, RedisWrite, ToRedisArgs, Value, +}; + +use std::io::{Error, ErrorKind}; + +// Stream Maxlen Enum + +/// Utility enum for passing `MAXLEN [= or ~] [COUNT]` +/// arguments into `StreamCommands`. +/// The enum value represents the count. +#[derive(PartialEq, Eq, Clone, Debug, Copy)] +pub enum StreamMaxlen { + /// Match an exact count + Equals(usize), + /// Match an approximate count + Approx(usize), +} + +impl ToRedisArgs for StreamMaxlen { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let (ch, val) = match *self { + StreamMaxlen::Equals(v) => ("=", v), + StreamMaxlen::Approx(v) => ("~", v), + }; + out.write_arg(b"MAXLEN"); + out.write_arg(ch.as_bytes()); + val.write_redis_args(out); + } +} + +/// Builder options for [`xclaim_options`] command. +/// +/// [`xclaim_options`]: ../trait.Commands.html#method.xclaim_options +/// +#[derive(Default, Debug)] +pub struct StreamClaimOptions { + /// Set `IDLE ` cmd arg. + idle: Option, + /// Set `TIME ` cmd arg. + time: Option, + /// Set `RETRYCOUNT ` cmd arg. + retry: Option, + /// Set `FORCE` cmd arg. + force: bool, + /// Set `JUSTID` cmd arg. Be advised: the response + /// type changes with this option. + justid: bool, +} + +impl StreamClaimOptions { + /// Set `IDLE ` cmd arg. + pub fn idle(mut self, ms: usize) -> Self { + self.idle = Some(ms); + self + } + + /// Set `TIME ` cmd arg. + pub fn time(mut self, ms_time: usize) -> Self { + self.time = Some(ms_time); + self + } + + /// Set `RETRYCOUNT ` cmd arg. + pub fn retry(mut self, count: usize) -> Self { + self.retry = Some(count); + self + } + + /// Set `FORCE` cmd arg to true. + pub fn with_force(mut self) -> Self { + self.force = true; + self + } + + /// Set `JUSTID` cmd arg to true. Be advised: the response + /// type changes with this option. + pub fn with_justid(mut self) -> Self { + self.justid = true; + self + } +} + +impl ToRedisArgs for StreamClaimOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref ms) = self.idle { + out.write_arg(b"IDLE"); + out.write_arg(format!("{ms}").as_bytes()); + } + if let Some(ref ms_time) = self.time { + out.write_arg(b"TIME"); + out.write_arg(format!("{ms_time}").as_bytes()); + } + if let Some(ref count) = self.retry { + out.write_arg(b"RETRYCOUNT"); + out.write_arg(format!("{count}").as_bytes()); + } + if self.force { + out.write_arg(b"FORCE"); + } + if self.justid { + out.write_arg(b"JUSTID"); + } + } +} + +/// Argument to `StreamReadOptions` +/// Represents the Redis `GROUP ` cmd arg. +/// This option will toggle the cmd from `XREAD` to `XREADGROUP` +type SRGroup = Option<(Vec>, Vec>)>; +/// Builder options for [`xread_options`] command. +/// +/// [`xread_options`]: ../trait.Commands.html#method.xread_options +/// +#[derive(Default, Debug)] +pub struct StreamReadOptions { + /// Set the `BLOCK ` cmd arg. + block: Option, + /// Set the `COUNT ` cmd arg. + count: Option, + /// Set the `NOACK` cmd arg. + noack: Option, + /// Set the `GROUP ` cmd arg. + /// This option will toggle the cmd from XREAD to XREADGROUP. + group: SRGroup, +} + +impl StreamReadOptions { + /// Indicates whether the command is participating in a group + /// and generating ACKs + pub fn read_only(&self) -> bool { + self.group.is_none() + } + + /// Sets the command so that it avoids adding the message + /// to the PEL in cases where reliability is not a requirement + /// and the occasional message loss is acceptable. + pub fn noack(mut self) -> Self { + self.noack = Some(true); + self + } + + /// Sets the block time in milliseconds. + pub fn block(mut self, ms: usize) -> Self { + self.block = Some(ms); + self + } + + /// Sets the maximum number of elements to return per stream. + pub fn count(mut self, n: usize) -> Self { + self.count = Some(n); + self + } + + /// Sets the name of a consumer group associated to the stream. + pub fn group( + mut self, + group_name: GN, + consumer_name: CN, + ) -> Self { + self.group = Some(( + ToRedisArgs::to_redis_args(&group_name), + ToRedisArgs::to_redis_args(&consumer_name), + )); + self + } +} + +impl ToRedisArgs for StreamReadOptions { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref group) = self.group { + out.write_arg(b"GROUP"); + for i in &group.0 { + out.write_arg(i); + } + for i in &group.1 { + out.write_arg(i); + } + } + + if let Some(ref ms) = self.block { + out.write_arg(b"BLOCK"); + out.write_arg(format!("{ms}").as_bytes()); + } + + if let Some(ref n) = self.count { + out.write_arg(b"COUNT"); + out.write_arg(format!("{n}").as_bytes()); + } + + if self.group.is_some() { + // noack is only available w/ xreadgroup + if self.noack == Some(true) { + out.write_arg(b"NOACK"); + } + } + } +} + +/// Reply type used with [`xread`] or [`xread_options`] commands. +/// +/// [`xread`]: ../trait.Commands.html#method.xread +/// [`xread_options`]: ../trait.Commands.html#method.xread_options +/// +#[derive(Default, Debug, Clone)] +pub struct StreamReadReply { + /// Complex data structure containing a payload for each key in this array + pub keys: Vec, +} + +/// Reply type used with [`xrange`], [`xrange_count`], [`xrange_all`], [`xrevrange`], [`xrevrange_count`], [`xrevrange_all`] commands. +/// +/// Represents stream entries matching a given range of `id`'s. +/// +/// [`xrange`]: ../trait.Commands.html#method.xrange +/// [`xrange_count`]: ../trait.Commands.html#method.xrange_count +/// [`xrange_all`]: ../trait.Commands.html#method.xrange_all +/// [`xrevrange`]: ../trait.Commands.html#method.xrevrange +/// [`xrevrange_count`]: ../trait.Commands.html#method.xrevrange_count +/// [`xrevrange_all`]: ../trait.Commands.html#method.xrevrange_all +/// +#[derive(Default, Debug, Clone)] +pub struct StreamRangeReply { + /// Complex data structure containing a payload for each ID in this array + pub ids: Vec, +} + +/// Reply type used with [`xclaim`] command. +/// +/// Represents that ownership of the specified messages was changed. +/// +/// [`xclaim`]: ../trait.Commands.html#method.xclaim +/// +#[derive(Default, Debug, Clone)] +pub struct StreamClaimReply { + /// Complex data structure containing a payload for each ID in this array + pub ids: Vec, +} + +/// Reply type used with [`xpending`] command. +/// +/// Data returned here were fetched from the stream without +/// having been acknowledged. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +/// +#[derive(Debug, Clone, Default)] +pub enum StreamPendingReply { + /// The stream is empty. + #[default] + Empty, + /// Data with payload exists in the stream. + Data(StreamPendingData), +} + +impl StreamPendingReply { + /// Returns how many records are in the reply. + pub fn count(&self) -> usize { + match self { + StreamPendingReply::Empty => 0, + StreamPendingReply::Data(x) => x.count, + } + } +} + +/// Inner reply type when an [`xpending`] command has data. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +#[derive(Default, Debug, Clone)] +pub struct StreamPendingData { + /// Limit on the number of messages to return per call. + pub count: usize, + /// ID for the first pending record. + pub start_id: String, + /// ID for the final pending record. + pub end_id: String, + /// Every consumer in the consumer group with at + /// least one pending message, + /// and the number of pending messages it has. + pub consumers: Vec, +} + +/// Reply type used with [`xpending_count`] and +/// [`xpending_consumer_count`] commands. +/// +/// Data returned here have been fetched from the stream without +/// any acknowledgement. +/// +/// [`xpending_count`]: ../trait.Commands.html#method.xpending_count +/// [`xpending_consumer_count`]: ../trait.Commands.html#method.xpending_consumer_count +/// +#[derive(Default, Debug, Clone)] +pub struct StreamPendingCountReply { + /// An array of structs containing information about + /// message IDs yet to be acknowledged by various consumers, + /// time since last ack, and total number of acks by that consumer. + pub ids: Vec, +} + +/// Reply type used with [`xinfo_stream`] command, containing +/// general information about the stream stored at the specified key. +/// +/// The very first and last IDs in the stream are shown, +/// in order to give some sense about what is the stream content. +/// +/// [`xinfo_stream`]: ../trait.Commands.html#method.xinfo_stream +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoStreamReply { + /// The last generated ID that may not be the same as the last + /// entry ID in case some entry was deleted. + pub last_generated_id: String, + /// Details about the radix tree representing the stream mostly + /// useful for optimization and debugging tasks. + pub radix_tree_keys: usize, + /// The number of consumer groups associated with the stream. + pub groups: usize, + /// Number of elements of the stream. + pub length: usize, + /// The very first entry in the stream. + pub first_entry: StreamId, + /// The very last entry in the stream. + pub last_entry: StreamId, +} + +/// Reply type used with [`xinfo_consumer`] command, an array of every +/// consumer in a specific consumer group. +/// +/// [`xinfo_consumer`]: ../trait.Commands.html#method.xinfo_consumer +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoConsumersReply { + /// An array of every consumer in a specific consumer group. + pub consumers: Vec, +} + +/// Reply type used with [`xinfo_groups`] command. +/// +/// This output represents all the consumer groups associated with +/// the stream. +/// +/// [`xinfo_groups`]: ../trait.Commands.html#method.xinfo_groups +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoGroupsReply { + /// All the consumer groups associated with the stream. + pub groups: Vec, +} + +/// A consumer parsed from [`xinfo_consumers`] command. +/// +/// [`xinfo_consumers`]: ../trait.Commands.html#method.xinfo_consumers +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoConsumer { + /// Name of the consumer group. + pub name: String, + /// Number of pending messages for this specific consumer. + pub pending: usize, + /// This consumer's idle time in milliseconds. + pub idle: usize, +} + +/// A group parsed from [`xinfo_groups`] command. +/// +/// [`xinfo_groups`]: ../trait.Commands.html#method.xinfo_groups +/// +#[derive(Default, Debug, Clone)] +pub struct StreamInfoGroup { + /// The group name. + pub name: String, + /// Number of consumers known in the group. + pub consumers: usize, + /// Number of pending messages (delivered but not yet acknowledged) in the group. + pub pending: usize, + /// Last ID delivered to this group. + pub last_delivered_id: String, +} + +/// Represents a pending message parsed from [`xpending`] methods. +/// +/// [`xpending`]: ../trait.Commands.html#method.xpending +#[derive(Default, Debug, Clone)] +pub struct StreamPendingId { + /// The ID of the message. + pub id: String, + /// The name of the consumer that fetched the message and has + /// still to acknowledge it. We call it the current owner + /// of the message. + pub consumer: String, + /// The number of milliseconds that elapsed since the + /// last time this message was delivered to this consumer. + pub last_delivered_ms: usize, + /// The number of times this message was delivered. + pub times_delivered: usize, +} + +/// Represents a stream `key` and its `id`'s parsed from `xread` methods. +#[derive(Default, Debug, Clone)] +pub struct StreamKey { + /// The stream `key`. + pub key: String, + /// The parsed stream `id`'s. + pub ids: Vec, +} + +/// Represents a stream `id` and its field/values as a `HashMap` +#[derive(Default, Debug, Clone)] +pub struct StreamId { + /// The stream `id` (entry ID) of this particular message. + pub id: String, + /// All fields in this message, associated with their respective values. + pub map: HashMap, +} + +impl StreamId { + /// Converts a `Value::Array` into a `StreamId`. + fn from_array_value(v: &Value) -> RedisResult { + let mut stream_id = StreamId::default(); + if let Value::Array(ref values) = *v { + if let Some(v) = values.first() { + stream_id.id = from_redis_value(v)?; + } + if let Some(v) = values.get(1) { + stream_id.map = from_redis_value(v)?; + } + } + + Ok(stream_id) + } + + /// Fetches value of a given field and converts it to the specified + /// type. + pub fn get(&self, key: &str) -> Option { + match self.map.get(key) { + Some(x) => from_redis_value(x).ok(), + None => None, + } + } + + /// Does the message contain a particular field? + pub fn contains_key(&self, key: &str) -> bool { + self.map.contains_key(key) + } + + /// Returns how many field/value pairs exist in this message. + pub fn len(&self) -> usize { + self.map.len() + } + + /// Returns true if there are no field/value pairs in this message. + pub fn is_empty(&self) -> bool { + self.len() == 0 + } +} + +type SRRows = Vec>>>>; +impl FromRedisValue for StreamReadReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: SRRows = from_redis_value(v)?; + let keys = rows + .into_iter() + .flat_map(|row| { + row.into_iter().map(|(key, entry)| { + let ids = entry + .into_iter() + .flat_map(|id_row| id_row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + StreamKey { key, ids } + }) + }) + .collect(); + Ok(StreamReadReply { keys }) + } +} + +impl FromRedisValue for StreamRangeReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: Vec>> = from_redis_value(v)?; + let ids: Vec = rows + .into_iter() + .flat_map(|row| row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + Ok(StreamRangeReply { ids }) + } +} + +impl FromRedisValue for StreamClaimReply { + fn from_redis_value(v: &Value) -> RedisResult { + let rows: Vec>> = from_redis_value(v)?; + let ids: Vec = rows + .into_iter() + .flat_map(|row| row.into_iter().map(|(id, map)| StreamId { id, map })) + .collect(); + Ok(StreamClaimReply { ids }) + } +} + +type SPRInner = ( + usize, + Option, + Option, + Vec>, +); +impl FromRedisValue for StreamPendingReply { + fn from_redis_value(v: &Value) -> RedisResult { + let (count, start, end, consumer_data): SPRInner = from_redis_value(v)?; + + if count == 0 { + Ok(StreamPendingReply::Empty) + } else { + let mut result = StreamPendingData::default(); + + let start_id = start.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "IllegalState: Non-zero pending expects start id", + ) + })?; + + let end_id = end.ok_or_else(|| { + Error::new( + ErrorKind::Other, + "IllegalState: Non-zero pending expects end id", + ) + })?; + + result.count = count; + result.start_id = start_id; + result.end_id = end_id; + + result.consumers = consumer_data + .into_iter() + .flatten() + .map(|(name, pending)| StreamInfoConsumer { + name, + pending: pending.parse().unwrap_or_default(), + ..Default::default() + }) + .collect(); + + Ok(StreamPendingReply::Data(result)) + } + } +} + +impl FromRedisValue for StreamPendingCountReply { + fn from_redis_value(v: &Value) -> RedisResult { + let mut reply = StreamPendingCountReply::default(); + match v { + Value::Array(outer_tuple) => { + for outer in outer_tuple { + match outer { + Value::Array(inner_tuple) => match &inner_tuple[..] { + [Value::BulkString(id_bytes), Value::BulkString(consumer_bytes), Value::Int(last_delivered_ms_u64), Value::Int(times_delivered_u64)] => + { + let id = String::from_utf8(id_bytes.to_vec())?; + let consumer = String::from_utf8(consumer_bytes.to_vec())?; + let last_delivered_ms = *last_delivered_ms_u64 as usize; + let times_delivered = *times_delivered_u64 as usize; + reply.ids.push(StreamPendingId { + id, + consumer, + last_delivered_ms, + times_delivered, + }); + } + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (3)" + )), + }, + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (2)" + )), + } + } + } + _ => fail!(( + crate::types::ErrorKind::TypeError, + "Cannot parse redis data (1)" + )), + }; + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoStreamReply { + fn from_redis_value(v: &Value) -> RedisResult { + let map: HashMap = from_redis_value(v)?; + let mut reply = StreamInfoStreamReply::default(); + if let Some(v) = &map.get("last-generated-id") { + reply.last_generated_id = from_redis_value(v)?; + } + if let Some(v) = &map.get("radix-tree-nodes") { + reply.radix_tree_keys = from_redis_value(v)?; + } + if let Some(v) = &map.get("groups") { + reply.groups = from_redis_value(v)?; + } + if let Some(v) = &map.get("length") { + reply.length = from_redis_value(v)?; + } + if let Some(v) = &map.get("first-entry") { + reply.first_entry = StreamId::from_array_value(v)?; + } + if let Some(v) = &map.get("last-entry") { + reply.last_entry = StreamId::from_array_value(v)?; + } + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoConsumersReply { + fn from_redis_value(v: &Value) -> RedisResult { + let consumers: Vec> = from_redis_value(v)?; + let mut reply = StreamInfoConsumersReply::default(); + for map in consumers { + let mut c = StreamInfoConsumer::default(); + if let Some(v) = &map.get("name") { + c.name = from_redis_value(v)?; + } + if let Some(v) = &map.get("pending") { + c.pending = from_redis_value(v)?; + } + if let Some(v) = &map.get("idle") { + c.idle = from_redis_value(v)?; + } + reply.consumers.push(c); + } + + Ok(reply) + } +} + +impl FromRedisValue for StreamInfoGroupsReply { + fn from_redis_value(v: &Value) -> RedisResult { + let groups: Vec> = from_redis_value(v)?; + let mut reply = StreamInfoGroupsReply::default(); + for map in groups { + let mut g = StreamInfoGroup::default(); + if let Some(v) = &map.get("name") { + g.name = from_redis_value(v)?; + } + if let Some(v) = &map.get("pending") { + g.pending = from_redis_value(v)?; + } + if let Some(v) = &map.get("consumers") { + g.consumers = from_redis_value(v)?; + } + if let Some(v) = &map.get("last-delivered-id") { + g.last_delivered_id = from_redis_value(v)?; + } + reply.groups.push(g); + } + Ok(reply) + } +} diff --git a/glide-core/redis-rs/redis/src/tls.rs b/glide-core/redis-rs/redis/src/tls.rs new file mode 100644 index 0000000000..6886efb836 --- /dev/null +++ b/glide-core/redis-rs/redis/src/tls.rs @@ -0,0 +1,142 @@ +use std::io::{BufRead, Error, ErrorKind as IOErrorKind}; + +use rustls::RootCertStore; +use rustls_pki_types::{CertificateDer, PrivateKeyDer}; + +use crate::{Client, ConnectionAddr, ConnectionInfo, ErrorKind, RedisError, RedisResult}; + +/// Structure to hold mTLS client _certificate_ and _key_ binaries in PEM format +/// +#[derive(Clone)] +pub struct ClientTlsConfig { + /// client certificate byte stream in PEM format + pub client_cert: Vec, + /// client key byte stream in PEM format + pub client_key: Vec, +} + +/// Structure to hold TLS certificates +/// - `client_tls`: binaries of clientkey and certificate within a `ClientTlsConfig` structure if mTLS is used +/// - `root_cert`: binary CA certificate in PEM format if CA is not in local truststore +/// +#[derive(Clone)] +pub struct TlsCertificates { + /// 'ClientTlsConfig' containing client certificate and key if mTLS is to be used + pub client_tls: Option, + /// root certificate byte stream in PEM format if the local truststore is *not* to be used + pub root_cert: Option>, +} + +pub(crate) fn inner_build_with_tls( + mut connection_info: ConnectionInfo, + certificates: TlsCertificates, +) -> RedisResult { + let tls_params = retrieve_tls_certificates(certificates)?; + + connection_info.addr = if let ConnectionAddr::TcpTls { + host, + port, + insecure, + .. + } = connection_info.addr + { + ConnectionAddr::TcpTls { + host, + port, + insecure, + tls_params: Some(tls_params), + } + } else { + return Err(RedisError::from(( + ErrorKind::InvalidClientConfig, + "Constructing a TLS client requires a URL with the `rediss://` scheme", + ))); + }; + + Ok(Client { connection_info }) +} + +pub(crate) fn retrieve_tls_certificates( + certificates: TlsCertificates, +) -> RedisResult { + let TlsCertificates { + client_tls, + root_cert, + } = certificates; + + let client_tls_params = if let Some(ClientTlsConfig { + client_cert, + client_key, + }) = client_tls + { + let buf = &mut client_cert.as_slice() as &mut dyn BufRead; + let certs = rustls_pemfile::certs(buf); + let client_cert_chain = certs.collect::, _>>()?; + + let client_key = + rustls_pemfile::private_key(&mut client_key.as_slice() as &mut dyn BufRead)? + .ok_or_else(|| { + Error::new( + IOErrorKind::Other, + "Unable to extract private key from PEM file", + ) + })?; + + Some(ClientTlsParams { + client_cert_chain, + client_key, + }) + } else { + None + }; + + let root_cert_store = if let Some(root_cert) = root_cert { + let buf = &mut root_cert.as_slice() as &mut dyn BufRead; + let certs = rustls_pemfile::certs(buf); + let mut root_cert_store = RootCertStore::empty(); + for result in certs { + if root_cert_store.add(result?.to_owned()).is_err() { + return Err( + Error::new(IOErrorKind::Other, "Unable to parse TLS trust anchors").into(), + ); + } + } + + Some(root_cert_store) + } else { + None + }; + + Ok(TlsConnParams { + client_tls_params, + root_cert_store, + }) +} + +#[derive(Debug)] +pub struct ClientTlsParams { + pub(crate) client_cert_chain: Vec>, + pub(crate) client_key: PrivateKeyDer<'static>, +} + +/// [`PrivateKeyDer`] does not implement `Clone` so we need to implement it manually. +impl Clone for ClientTlsParams { + fn clone(&self) -> Self { + use PrivateKeyDer::*; + Self { + client_cert_chain: self.client_cert_chain.clone(), + client_key: match &self.client_key { + Pkcs1(key) => Pkcs1(key.secret_pkcs1_der().to_vec().into()), + Pkcs8(key) => Pkcs8(key.secret_pkcs8_der().to_vec().into()), + Sec1(key) => Sec1(key.secret_sec1_der().to_vec().into()), + _ => unreachable!(), + }, + } + } +} + +#[derive(Debug, Clone)] +pub struct TlsConnParams { + pub(crate) client_tls_params: Option, + pub(crate) root_cert_store: Option, +} diff --git a/glide-core/redis-rs/redis/src/types.rs b/glide-core/redis-rs/redis/src/types.rs new file mode 100644 index 0000000000..a024f16a7d --- /dev/null +++ b/glide-core/redis-rs/redis/src/types.rs @@ -0,0 +1,2460 @@ +use std::collections::{BTreeMap, BTreeSet}; +use std::default::Default; +use std::error; +use std::ffi::{CString, NulError}; +use std::fmt; +use std::hash::{BuildHasher, Hash}; +use std::io; +use std::str::{from_utf8, Utf8Error}; +use std::string::FromUtf8Error; + +#[cfg(feature = "ahash")] +pub(crate) use ahash::{AHashMap as HashMap, AHashSet as HashSet}; +use num_bigint::BigInt; +#[cfg(not(feature = "ahash"))] +pub(crate) use std::collections::{HashMap, HashSet}; +use std::ops::Deref; + +macro_rules! invalid_type_error { + ($v:expr, $det:expr) => {{ + fail!(invalid_type_error_inner!($v, $det)) + }}; +} + +macro_rules! invalid_type_error_inner { + ($v:expr, $det:expr) => { + RedisError::from(( + ErrorKind::TypeError, + "Response was of incompatible type", + format!("{:?} (response was {:?})", $det, $v), + )) + }; +} + +/// Helper enum that is used to define expiry time +pub enum Expiry { + /// EX seconds -- Set the specified expire time, in seconds. + EX(usize), + /// PX milliseconds -- Set the specified expire time, in milliseconds. + PX(usize), + /// EXAT timestamp-seconds -- Set the specified Unix time at which the key will expire, in seconds. + EXAT(usize), + /// PXAT timestamp-milliseconds -- Set the specified Unix time at which the key will expire, in milliseconds. + PXAT(usize), + /// PERSIST -- Remove the time to live associated with the key. + PERSIST, +} + +/// Helper enum that is used to define expiry time for SET command +#[derive(Clone, Copy)] +pub enum SetExpiry { + /// EX seconds -- Set the specified expire time, in seconds. + EX(usize), + /// PX milliseconds -- Set the specified expire time, in milliseconds. + PX(usize), + /// EXAT timestamp-seconds -- Set the specified Unix time at which the key will expire, in seconds. + EXAT(usize), + /// PXAT timestamp-milliseconds -- Set the specified Unix time at which the key will expire, in milliseconds. + PXAT(usize), + /// KEEPTTL -- Retain the time to live associated with the key. + KEEPTTL, +} + +/// Helper enum that is used to define existence checks +#[derive(Clone, Copy)] +pub enum ExistenceCheck { + /// NX -- Only set the key if it does not already exist. + NX, + /// XX -- Only set the key if it already exists. + XX, +} + +/// Helper enum that is used in some situations to describe +/// the behavior of arguments in a numeric context. +#[derive(PartialEq, Eq, Clone, Debug, Copy)] +pub enum NumericBehavior { + /// This argument is not numeric. + NonNumeric, + /// This argument is an integer. + NumberIsInteger, + /// This argument is a floating point value. + NumberIsFloat, +} + +/// An enum of all error kinds. +#[derive(PartialEq, Eq, Copy, Clone, Debug)] +#[non_exhaustive] +pub enum ErrorKind { + /// The server generated an invalid response. + ResponseError, + /// The parser failed to parse the server response. + ParseError, + /// The authentication with the server failed. + AuthenticationFailed, + /// Operation failed because of a type mismatch. + TypeError, + /// A script execution was aborted. + ExecAbortError, + /// The server cannot response because it's loading a dump. + BusyLoadingError, + /// A script that was requested does not actually exist. + NoScriptError, + /// An error that was caused because the parameter to the + /// client were wrong. + InvalidClientConfig, + /// Raised if a key moved to a different node. + Moved, + /// Raised if a key moved to a different node but we need to ask. + Ask, + /// Raised if a request needs to be retried. + TryAgain, + /// Raised if a redis cluster is down. + ClusterDown, + /// A request spans multiple slots + CrossSlot, + /// A cluster master is unavailable. + MasterDown, + /// This kind is returned if the redis error is one that is + /// not native to the system. This is usually the case if + /// the cause is another error. + IoError, + /// An error raised that was identified on the client before execution. + ClientError, + /// An extension error. This is an error created by the server + /// that is not directly understood by the library. + ExtensionError, + /// Attempt to write to a read-only server + ReadOnly, + /// Requested name not found among masters returned by the sentinels + MasterNameNotFoundBySentinel, + /// No valid replicas found in the sentinels, for a given master name + NoValidReplicasFoundBySentinel, + /// At least one sentinel connection info is required + EmptySentinelList, + /// Attempted to kill a script/function while they werent' executing + NotBusy, + /// Used when no valid node connections remain in the cluster connection + AllConnectionsUnavailable, + /// Used when a connection is not found for the specified route. + ConnectionNotFoundForRoute, + + #[cfg(feature = "json")] + /// Error Serializing a struct to JSON form + Serialize, + + /// Redis Servers prior to v6.0.0 doesn't support RESP3. + /// Try disabling resp3 option + RESP3NotSupported, + + /// Not all slots are covered by the cluster + NotAllSlotsCovered, +} + +#[derive(PartialEq, Debug)] +pub(crate) enum ServerErrorKind { + ResponseError, + ExecAbortError, + BusyLoadingError, + NoScriptError, + Moved, + Ask, + TryAgain, + ClusterDown, + CrossSlot, + MasterDown, + ReadOnly, + NotBusy, +} + +#[derive(PartialEq, Debug)] +pub(crate) enum ServerError { + ExtensionError { + code: String, + detail: Option, + }, + KnownError { + kind: ServerErrorKind, + detail: Option, + }, +} + +impl From for RedisError { + fn from(value: ServerError) -> Self { + // TODO - Consider changing RedisError to explicitly represent whether an error came from the server or not. Today it is only implied. + match value { + ServerError::ExtensionError { code, detail } => make_extension_error(code, detail), + ServerError::KnownError { kind, detail } => { + let desc = "An error was signalled by the server"; + let kind = match kind { + ServerErrorKind::ResponseError => ErrorKind::ResponseError, + ServerErrorKind::ExecAbortError => ErrorKind::ExecAbortError, + ServerErrorKind::BusyLoadingError => ErrorKind::BusyLoadingError, + ServerErrorKind::NoScriptError => ErrorKind::NoScriptError, + ServerErrorKind::Moved => ErrorKind::Moved, + ServerErrorKind::Ask => ErrorKind::Ask, + ServerErrorKind::TryAgain => ErrorKind::TryAgain, + ServerErrorKind::ClusterDown => ErrorKind::ClusterDown, + ServerErrorKind::CrossSlot => ErrorKind::CrossSlot, + ServerErrorKind::MasterDown => ErrorKind::MasterDown, + ServerErrorKind::ReadOnly => ErrorKind::ReadOnly, + ServerErrorKind::NotBusy => ErrorKind::NotBusy, + }; + match detail { + Some(detail) => RedisError::from((kind, desc, detail)), + None => RedisError::from((kind, desc)), + } + } + } + } +} + +/// Internal low-level redis value enum. +#[derive(PartialEq, Debug)] +pub(crate) enum InternalValue { + /// A nil response from the server. + Nil, + /// An integer response. Note that there are a few situations + /// in which redis actually returns a string for an integer which + /// is why this library generally treats integers and strings + /// the same for all numeric responses. + Int(i64), + /// An arbitrary binary data, usually represents a binary-safe string. + BulkString(Vec), + /// A response containing an array with more data. This is generally used by redis + /// to express nested structures. + Array(Vec), + /// A simple string response, without line breaks and not binary safe. + SimpleString(String), + /// A status response which represents the string "OK". + Okay, + /// Unordered key,value list from the server. Use `as_map_iter` function. + Map(Vec<(InternalValue, InternalValue)>), + /// Attribute value from the server. Client will give data instead of whole Attribute type. + Attribute { + /// Data that attributes belong to. + data: Box, + /// Key,Value list of attributes. + attributes: Vec<(InternalValue, InternalValue)>, + }, + /// Unordered set value from the server. + Set(Vec), + /// A floating number response from the server. + Double(f64), + /// A boolean response from the server. + Boolean(bool), + /// First String is format and other is the string + VerbatimString { + /// Text's format type + format: VerbatimFormat, + /// Remaining string check format before using! + text: String, + }, + /// Very large number that out of the range of the signed 64 bit numbers + BigNumber(BigInt), + /// Push data from the server. + Push { + /// Push Kind + kind: PushKind, + /// Remaining data from push message + data: Vec, + }, + ServerError(ServerError), +} + +impl InternalValue { + pub(crate) fn try_into(self) -> RedisResult { + match self { + InternalValue::Nil => Ok(Value::Nil), + InternalValue::Int(val) => Ok(Value::Int(val)), + InternalValue::BulkString(val) => Ok(Value::BulkString(val)), + InternalValue::Array(val) => Ok(Value::Array(Self::try_into_vec(val)?)), + InternalValue::SimpleString(val) => Ok(Value::SimpleString(val)), + InternalValue::Okay => Ok(Value::Okay), + InternalValue::Map(map) => Ok(Value::Map(Self::try_into_map(map)?)), + InternalValue::Attribute { data, attributes } => { + let data = Box::new((*data).try_into()?); + let attributes = Self::try_into_map(attributes)?; + Ok(Value::Attribute { data, attributes }) + } + InternalValue::Set(set) => Ok(Value::Set(Self::try_into_vec(set)?)), + InternalValue::Double(double) => Ok(Value::Double(double)), + InternalValue::Boolean(boolean) => Ok(Value::Boolean(boolean)), + InternalValue::VerbatimString { format, text } => { + Ok(Value::VerbatimString { format, text }) + } + InternalValue::BigNumber(number) => Ok(Value::BigNumber(number)), + InternalValue::Push { kind, data } => Ok(Value::Push { + kind, + data: Self::try_into_vec(data)?, + }), + + InternalValue::ServerError(err) => Err(err.into()), + } + } + + fn try_into_vec(vec: Vec) -> RedisResult> { + vec.into_iter() + .map(InternalValue::try_into) + .collect::>>() + } + + fn try_into_map(map: Vec<(InternalValue, InternalValue)>) -> RedisResult> { + let mut vec = Vec::with_capacity(map.len()); + for (key, value) in map.into_iter() { + vec.push((key.try_into()?, value.try_into()?)); + } + Ok(vec) + } +} + +/// Internal low-level redis value enum. +#[derive(PartialEq, Clone)] +pub enum Value { + /// A nil response from the server. + Nil, + /// An integer response. Note that there are a few situations + /// in which redis actually returns a string for an integer which + /// is why this library generally treats integers and strings + /// the same for all numeric responses. + Int(i64), + /// An arbitrary binary data, usually represents a binary-safe string. + BulkString(Vec), + /// A response containing an array with more data. This is generally used by redis + /// to express nested structures. + Array(Vec), + /// A simple string response, without line breaks and not binary safe. + SimpleString(String), + /// A status response which represents the string "OK". + Okay, + /// Unordered key,value list from the server. Use `as_map_iter` function. + Map(Vec<(Value, Value)>), + /// Attribute value from the server. Client will give data instead of whole Attribute type. + Attribute { + /// Data that attributes belong to. + data: Box, + /// Key,Value list of attributes. + attributes: Vec<(Value, Value)>, + }, + /// Unordered set value from the server. + Set(Vec), + /// A floating number response from the server. + Double(f64), + /// A boolean response from the server. + Boolean(bool), + /// First String is format and other is the string + VerbatimString { + /// Text's format type + format: VerbatimFormat, + /// Remaining string check format before using! + text: String, + }, + /// Very large number that out of the range of the signed 64 bit numbers + BigNumber(BigInt), + /// Push data from the server. + Push { + /// Push Kind + kind: PushKind, + /// Remaining data from push message + data: Vec, + }, +} + +/// `VerbatimString`'s format types defined by spec +#[derive(PartialEq, Clone, Debug)] +pub enum VerbatimFormat { + /// Unknown type to catch future formats. + Unknown(String), + /// `mkd` format + Markdown, + /// `txt` format + Text, +} + +/// `Push` type's currently known kinds. +#[derive(PartialEq, Clone, Debug)] +pub enum PushKind { + /// `Disconnection` is sent from the **library** when connection is closed. + Disconnection, + /// Other kind to catch future kinds. + Other(String), + /// `invalidate` is received when a key is changed/deleted. + Invalidate, + /// `message` is received when pubsub message published by another client. + Message, + /// `pmessage` is received when pubsub message published by another client and client subscribed to topic via pattern. + PMessage, + /// `smessage` is received when pubsub message published by another client and client subscribed to it with sharding. + SMessage, + /// `unsubscribe` is received when client unsubscribed from a channel. + Unsubscribe, + /// `punsubscribe` is received when client unsubscribed from a pattern. + PUnsubscribe, + /// `sunsubscribe` is received when client unsubscribed from a shard channel. + SUnsubscribe, + /// `subscribe` is received when client subscribed to a channel. + Subscribe, + /// `psubscribe` is received when client subscribed to a pattern. + PSubscribe, + /// `ssubscribe` is received when client subscribed to a shard channel. + SSubscribe, +} + +impl PushKind { + #[cfg(feature = "aio")] + pub(crate) fn has_reply(&self) -> bool { + matches!( + self, + &PushKind::Unsubscribe + | &PushKind::PUnsubscribe + | &PushKind::SUnsubscribe + | &PushKind::Subscribe + | &PushKind::PSubscribe + | &PushKind::SSubscribe + ) + } +} + +impl fmt::Display for VerbatimFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VerbatimFormat::Markdown => write!(f, "mkd"), + VerbatimFormat::Unknown(val) => write!(f, "{val}"), + VerbatimFormat::Text => write!(f, "txt"), + } + } +} + +impl fmt::Display for PushKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PushKind::Other(kind) => write!(f, "{}", kind), + PushKind::Invalidate => write!(f, "invalidate"), + PushKind::Message => write!(f, "message"), + PushKind::PMessage => write!(f, "pmessage"), + PushKind::SMessage => write!(f, "smessage"), + PushKind::Unsubscribe => write!(f, "unsubscribe"), + PushKind::PUnsubscribe => write!(f, "punsubscribe"), + PushKind::SUnsubscribe => write!(f, "sunsubscribe"), + PushKind::Subscribe => write!(f, "subscribe"), + PushKind::PSubscribe => write!(f, "psubscribe"), + PushKind::SSubscribe => write!(f, "ssubscribe"), + PushKind::Disconnection => write!(f, "disconnection"), + } + } +} + +pub enum MapIter<'a> { + Array(std::slice::Iter<'a, Value>), + Map(std::slice::Iter<'a, (Value, Value)>), +} + +impl<'a> Iterator for MapIter<'a> { + type Item = (&'a Value, &'a Value); + + fn next(&mut self) -> Option { + match self { + MapIter::Array(iter) => Some((iter.next()?, iter.next()?)), + MapIter::Map(iter) => { + let (k, v) = iter.next()?; + Some((k, v)) + } + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + MapIter::Array(iter) => iter.size_hint(), + MapIter::Map(iter) => iter.size_hint(), + } + } +} + +pub enum OwnedMapIter { + Array(std::vec::IntoIter), + Map(std::vec::IntoIter<(Value, Value)>), +} + +impl Iterator for OwnedMapIter { + type Item = (Value, Value); + + fn next(&mut self) -> Option { + match self { + OwnedMapIter::Array(iter) => Some((iter.next()?, iter.next()?)), + OwnedMapIter::Map(iter) => iter.next(), + } + } + + fn size_hint(&self) -> (usize, Option) { + match self { + OwnedMapIter::Array(iter) => { + let (low, high) = iter.size_hint(); + (low / 2, high.map(|h| h / 2)) + } + OwnedMapIter::Map(iter) => iter.size_hint(), + } + } +} + +/// Values are generally not used directly unless you are using the +/// more low level functionality in the library. For the most part +/// this is hidden with the help of the `FromRedisValue` trait. +/// +/// While on the redis protocol there is an error type this is already +/// separated at an early point so the value only holds the remaining +/// types. +impl Value { + /// Checks if the return value looks like it fulfils the cursor + /// protocol. That means the result is an array item of length + /// two with the first one being a cursor and the second an + /// array response. + pub fn looks_like_cursor(&self) -> bool { + match *self { + Value::Array(ref items) => { + if items.len() != 2 { + return false; + } + matches!(items[0], Value::BulkString(_)) && matches!(items[1], Value::Array(_)) + } + _ => false, + } + } + + /// Returns an `&[Value]` if `self` is compatible with a sequence type + pub fn as_sequence(&self) -> Option<&[Value]> { + match self { + Value::Array(items) => Some(&items[..]), + Value::Set(items) => Some(&items[..]), + Value::Nil => Some(&[]), + _ => None, + } + } + + /// Returns a `Vec` if `self` is compatible with a sequence type, + /// otherwise returns `Err(self)`. + pub fn into_sequence(self) -> Result, Value> { + match self { + Value::Array(items) => Ok(items), + Value::Set(items) => Ok(items), + Value::Nil => Ok(vec![]), + _ => Err(self), + } + } + + /// Returns an iterator of `(&Value, &Value)` if `self` is compatible with a map type + pub fn as_map_iter(&self) -> Option> { + match self { + Value::Array(items) => { + if items.len() % 2 == 0 { + Some(MapIter::Array(items.iter())) + } else { + None + } + } + Value::Map(items) => Some(MapIter::Map(items.iter())), + _ => None, + } + } + + /// Returns an iterator of `(Value, Value)` if `self` is compatible with a map type. + /// If not, returns `Err(self)`. + pub fn into_map_iter(self) -> Result { + match self { + Value::Array(items) => { + if items.len() % 2 == 0 { + Ok(OwnedMapIter::Array(items.into_iter())) + } else { + Err(Value::Array(items)) + } + } + Value::Map(items) => Ok(OwnedMapIter::Map(items.into_iter())), + _ => Err(self), + } + } +} + +impl fmt::Debug for Value { + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + match *self { + Value::Nil => write!(fmt, "nil"), + Value::Int(val) => write!(fmt, "int({val:?})"), + Value::BulkString(ref val) => match from_utf8(val) { + Ok(x) => write!(fmt, "bulk-string('{x:?}')"), + Err(_) => write!(fmt, "binary-data({val:?})"), + }, + Value::Array(ref values) => write!(fmt, "array({values:?})"), + Value::Push { ref kind, ref data } => write!(fmt, "push({kind:?}, {data:?})"), + Value::Okay => write!(fmt, "ok"), + Value::SimpleString(ref s) => write!(fmt, "simple-string({s:?})"), + Value::Map(ref values) => write!(fmt, "map({values:?})"), + Value::Attribute { + ref data, + attributes: _, + } => write!(fmt, "attribute({data:?})"), + Value::Set(ref values) => write!(fmt, "set({values:?})"), + Value::Double(ref d) => write!(fmt, "double({d:?})"), + Value::Boolean(ref b) => write!(fmt, "boolean({b:?})"), + Value::VerbatimString { + ref format, + ref text, + } => { + write!(fmt, "verbatim-string({:?},{:?})", format, text) + } + Value::BigNumber(ref m) => write!(fmt, "big-number({:?})", m), + } + } +} + +/// Represents a redis error. For the most part you should be using +/// the Error trait to interact with this rather than the actual +/// struct. +pub struct RedisError { + repr: ErrorRepr, +} + +#[cfg(feature = "json")] +impl From for RedisError { + fn from(serde_err: serde_json::Error) -> RedisError { + RedisError::from(( + ErrorKind::Serialize, + "Serialization Error", + format!("{serde_err}"), + )) + } +} + +#[derive(Debug)] +enum ErrorRepr { + WithDescription(ErrorKind, &'static str), + WithDescriptionAndDetail(ErrorKind, &'static str, String), + ExtensionError(String, String), + IoError(io::Error), +} + +impl PartialEq for RedisError { + fn eq(&self, other: &RedisError) -> bool { + match (&self.repr, &other.repr) { + (&ErrorRepr::WithDescription(kind_a, _), &ErrorRepr::WithDescription(kind_b, _)) => { + kind_a == kind_b + } + ( + &ErrorRepr::WithDescriptionAndDetail(kind_a, _, _), + &ErrorRepr::WithDescriptionAndDetail(kind_b, _, _), + ) => kind_a == kind_b, + (ErrorRepr::ExtensionError(a, _), ErrorRepr::ExtensionError(b, _)) => *a == *b, + _ => false, + } + } +} + +impl From for RedisError { + fn from(err: io::Error) -> RedisError { + RedisError { + repr: ErrorRepr::IoError(err), + } + } +} + +impl From for RedisError { + fn from(_: Utf8Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(ErrorKind::TypeError, "Invalid UTF-8"), + } + } +} + +impl From for RedisError { + fn from(err: NulError) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::TypeError, + "Value contains interior nul terminator", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-native-tls")] +impl From for RedisError { + fn from(err: native_tls::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "tls-rustls")] +impl From for RedisError { + fn from(err: rustls_pki_types::InvalidDnsNameError) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::IoError, + "TLS Error", + err.to_string(), + ), + } + } +} + +#[cfg(feature = "uuid")] +impl From for RedisError { + fn from(err: uuid::Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail( + ErrorKind::TypeError, + "Value is not a valid UUID", + err.to_string(), + ), + } + } +} + +impl From for RedisError { + fn from(_: FromUtf8Error) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(ErrorKind::TypeError, "Cannot convert from UTF-8"), + } + } +} + +impl From<(ErrorKind, &'static str)> for RedisError { + fn from((kind, desc): (ErrorKind, &'static str)) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescription(kind, desc), + } + } +} + +impl From<(ErrorKind, &'static str, String)> for RedisError { + fn from((kind, desc, detail): (ErrorKind, &'static str, String)) -> RedisError { + RedisError { + repr: ErrorRepr::WithDescriptionAndDetail(kind, desc, detail), + } + } +} + +impl error::Error for RedisError { + #[allow(deprecated)] + fn description(&self) -> &str { + match self.repr { + ErrorRepr::WithDescription(_, desc) => desc, + ErrorRepr::WithDescriptionAndDetail(_, desc, _) => desc, + ErrorRepr::ExtensionError(_, _) => "extension error", + ErrorRepr::IoError(ref err) => err.description(), + } + } + + fn cause(&self) -> Option<&dyn error::Error> { + match self.repr { + ErrorRepr::IoError(ref err) => Some(err as &dyn error::Error), + _ => None, + } + } +} + +impl fmt::Display for RedisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + match self.repr { + ErrorRepr::WithDescription(kind, desc) => { + desc.fmt(f)?; + f.write_str("- ")?; + fmt::Debug::fmt(&kind, f) + } + ErrorRepr::WithDescriptionAndDetail(kind, desc, ref detail) => { + desc.fmt(f)?; + f.write_str(" - ")?; + fmt::Debug::fmt(&kind, f)?; + f.write_str(": ")?; + detail.fmt(f) + } + ErrorRepr::ExtensionError(ref code, ref detail) => { + code.fmt(f)?; + f.write_str(": ")?; + detail.fmt(f) + } + ErrorRepr::IoError(ref err) => err.fmt(f), + } + } +} + +impl fmt::Debug for RedisError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + fmt::Display::fmt(self, f) + } +} + +pub(crate) enum RetryMethod { + Reconnect, + NoRetry, + RetryImmediately, + WaitAndRetry, + AskRedirect, + MovedRedirect, + WaitAndRetryOnPrimaryRedirectOnReplica, +} + +/// Indicates a general failure in the library. +impl RedisError { + /// Returns the kind of the error. + pub fn kind(&self) -> ErrorKind { + match self.repr { + ErrorRepr::WithDescription(kind, _) + | ErrorRepr::WithDescriptionAndDetail(kind, _, _) => kind, + ErrorRepr::ExtensionError(_, _) => ErrorKind::ExtensionError, + ErrorRepr::IoError(_) => ErrorKind::IoError, + } + } + + /// Returns the error detail. + pub fn detail(&self) -> Option<&str> { + match self.repr { + ErrorRepr::WithDescriptionAndDetail(_, _, ref detail) + | ErrorRepr::ExtensionError(_, ref detail) => Some(detail.as_str()), + _ => None, + } + } + + /// Returns the raw error code if available. + pub fn code(&self) -> Option<&str> { + match self.kind() { + ErrorKind::ResponseError => Some("ERR"), + ErrorKind::ExecAbortError => Some("EXECABORT"), + ErrorKind::BusyLoadingError => Some("LOADING"), + ErrorKind::NoScriptError => Some("NOSCRIPT"), + ErrorKind::Moved => Some("MOVED"), + ErrorKind::Ask => Some("ASK"), + ErrorKind::TryAgain => Some("TRYAGAIN"), + ErrorKind::ClusterDown => Some("CLUSTERDOWN"), + ErrorKind::CrossSlot => Some("CROSSSLOT"), + ErrorKind::MasterDown => Some("MASTERDOWN"), + ErrorKind::ReadOnly => Some("READONLY"), + ErrorKind::NotBusy => Some("NOTBUSY"), + _ => match self.repr { + ErrorRepr::ExtensionError(ref code, _) => Some(code), + _ => None, + }, + } + } + + /// Returns the name of the error category for display purposes. + pub fn category(&self) -> &str { + match self.kind() { + ErrorKind::ResponseError => "response error", + ErrorKind::AuthenticationFailed => "authentication failed", + ErrorKind::TypeError => "type error", + ErrorKind::ExecAbortError => "script execution aborted", + ErrorKind::BusyLoadingError => "busy loading", + ErrorKind::NoScriptError => "no script", + ErrorKind::InvalidClientConfig => "invalid client config", + ErrorKind::Moved => "key moved", + ErrorKind::Ask => "key moved (ask)", + ErrorKind::TryAgain => "try again", + ErrorKind::ClusterDown => "cluster down", + ErrorKind::CrossSlot => "cross-slot", + ErrorKind::MasterDown => "master down", + ErrorKind::IoError => "I/O error", + ErrorKind::ExtensionError => "extension error", + ErrorKind::ClientError => "client error", + ErrorKind::ReadOnly => "read-only", + ErrorKind::MasterNameNotFoundBySentinel => "master name not found by sentinel", + ErrorKind::NoValidReplicasFoundBySentinel => "no valid replicas found by sentinel", + ErrorKind::EmptySentinelList => "empty sentinel list", + ErrorKind::NotBusy => "not busy", + ErrorKind::AllConnectionsUnavailable => "no valid connections remain in the cluster", + ErrorKind::ConnectionNotFoundForRoute => "No connection found for the requested route", + #[cfg(feature = "json")] + ErrorKind::Serialize => "serializing", + ErrorKind::RESP3NotSupported => "resp3 is not supported by server", + ErrorKind::ParseError => "parse error", + ErrorKind::NotAllSlotsCovered => "not all slots are covered", + } + } + + /// Indicates that this failure is an IO failure. + pub fn is_io_error(&self) -> bool { + self.as_io_error().is_some() + } + + pub(crate) fn as_io_error(&self) -> Option<&io::Error> { + match &self.repr { + ErrorRepr::IoError(e) => Some(e), + _ => None, + } + } + + /// Indicates that this is a cluster error. + pub fn is_cluster_error(&self) -> bool { + matches!( + self.kind(), + ErrorKind::Moved | ErrorKind::Ask | ErrorKind::TryAgain | ErrorKind::ClusterDown + ) + } + + /// Returns true if this error indicates that the connection was + /// refused. You should generally not rely much on this function + /// unless you are writing unit tests that want to detect if a + /// local server is available. + pub fn is_connection_refusal(&self) -> bool { + match self.repr { + ErrorRepr::IoError(ref err) => { + #[allow(clippy::match_like_matches_macro)] + match err.kind() { + io::ErrorKind::ConnectionRefused => true, + // if we connect to a unix socket and the file does not + // exist yet, then we want to treat this as if it was a + // connection refusal. + io::ErrorKind::NotFound => cfg!(unix), + _ => false, + } + } + _ => false, + } + } + + /// Returns true if error was caused by I/O time out. + /// Note that this may not be accurate depending on platform. + pub fn is_timeout(&self) -> bool { + match self.repr { + ErrorRepr::IoError(ref err) => matches!( + err.kind(), + io::ErrorKind::TimedOut | io::ErrorKind::WouldBlock + ), + _ => false, + } + } + + /// Returns true if error was caused by a dropped connection. + pub fn is_connection_dropped(&self) -> bool { + match self.repr { + ErrorRepr::IoError(ref err) => matches!( + err.kind(), + io::ErrorKind::BrokenPipe + | io::ErrorKind::ConnectionReset + | io::ErrorKind::UnexpectedEof + ), + _ => false, + } + } + + /// Returns true if the error is likely to not be recoverable, and the connection must be replaced. + pub fn is_unrecoverable_error(&self) -> bool { + match self.retry_method() { + RetryMethod::Reconnect => true, + + RetryMethod::NoRetry => false, + RetryMethod::RetryImmediately => false, + RetryMethod::WaitAndRetry => false, + RetryMethod::AskRedirect => false, + RetryMethod::MovedRedirect => false, + RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica => false, + } + } + + /// Returns the node the error refers to. + /// + /// This returns `(addr, slot_id)`. + pub fn redirect_node(&self) -> Option<(&str, u16)> { + match self.kind() { + ErrorKind::Ask | ErrorKind::Moved => (), + _ => return None, + } + let mut iter = self.detail()?.split_ascii_whitespace(); + let slot_id: u16 = iter.next()?.parse().ok()?; + let addr = iter.next()?; + Some((addr, slot_id)) + } + + /// Returns the extension error code. + /// + /// This method should not be used because every time the redis library + /// adds support for a new error code it would disappear form this method. + /// `code()` always returns the code. + #[deprecated(note = "use code() instead")] + pub fn extension_error_code(&self) -> Option<&str> { + match self.repr { + ErrorRepr::ExtensionError(ref code, _) => Some(code), + _ => None, + } + } + + /// Clone the `RedisError`, throwing away non-cloneable parts of an `IoError`. + /// + /// Deriving `Clone` is not possible because the wrapped `io::Error` is not + /// cloneable. + /// + /// The `ioerror_description` parameter will be prepended to the message in + /// case an `IoError` is found. + #[cfg(feature = "connection-manager")] // Used to avoid "unused method" warning + pub(crate) fn clone_mostly(&self, ioerror_description: &'static str) -> Self { + let repr = match self.repr { + ErrorRepr::WithDescription(kind, desc) => ErrorRepr::WithDescription(kind, desc), + ErrorRepr::WithDescriptionAndDetail(kind, desc, ref detail) => { + ErrorRepr::WithDescriptionAndDetail(kind, desc, detail.clone()) + } + ErrorRepr::ExtensionError(ref code, ref detail) => { + ErrorRepr::ExtensionError(code.clone(), detail.clone()) + } + ErrorRepr::IoError(ref e) => ErrorRepr::IoError(io::Error::new( + e.kind(), + format!("{ioerror_description}: {e}"), + )), + }; + Self { repr } + } + + pub(crate) fn retry_method(&self) -> RetryMethod { + match self.kind() { + ErrorKind::Moved => RetryMethod::MovedRedirect, + ErrorKind::Ask => RetryMethod::AskRedirect, + + ErrorKind::TryAgain => RetryMethod::WaitAndRetry, + ErrorKind::MasterDown => RetryMethod::WaitAndRetry, + ErrorKind::ClusterDown => RetryMethod::WaitAndRetry, + ErrorKind::MasterNameNotFoundBySentinel => RetryMethod::WaitAndRetry, + ErrorKind::NoValidReplicasFoundBySentinel => RetryMethod::WaitAndRetry, + + ErrorKind::BusyLoadingError => RetryMethod::WaitAndRetryOnPrimaryRedirectOnReplica, + + ErrorKind::ResponseError => RetryMethod::NoRetry, + ErrorKind::ReadOnly => RetryMethod::NoRetry, + ErrorKind::ExtensionError => RetryMethod::NoRetry, + ErrorKind::ExecAbortError => RetryMethod::NoRetry, + ErrorKind::TypeError => RetryMethod::NoRetry, + ErrorKind::NoScriptError => RetryMethod::NoRetry, + ErrorKind::InvalidClientConfig => RetryMethod::NoRetry, + ErrorKind::CrossSlot => RetryMethod::NoRetry, + ErrorKind::ClientError => RetryMethod::NoRetry, + ErrorKind::EmptySentinelList => RetryMethod::NoRetry, + ErrorKind::NotBusy => RetryMethod::NoRetry, + #[cfg(feature = "json")] + ErrorKind::Serialize => RetryMethod::NoRetry, + ErrorKind::RESP3NotSupported => RetryMethod::NoRetry, + + ErrorKind::ParseError => RetryMethod::Reconnect, + ErrorKind::AuthenticationFailed => RetryMethod::Reconnect, + ErrorKind::AllConnectionsUnavailable => RetryMethod::Reconnect, + ErrorKind::ConnectionNotFoundForRoute => RetryMethod::Reconnect, + + ErrorKind::IoError => match &self.repr { + ErrorRepr::IoError(err) => match err.kind() { + io::ErrorKind::ConnectionRefused => RetryMethod::Reconnect, + io::ErrorKind::NotFound => RetryMethod::Reconnect, + io::ErrorKind::ConnectionReset => RetryMethod::Reconnect, + io::ErrorKind::ConnectionAborted => RetryMethod::Reconnect, + io::ErrorKind::NotConnected => RetryMethod::Reconnect, + io::ErrorKind::BrokenPipe => RetryMethod::Reconnect, + io::ErrorKind::UnexpectedEof => RetryMethod::Reconnect, + + io::ErrorKind::PermissionDenied => RetryMethod::NoRetry, + io::ErrorKind::Unsupported => RetryMethod::NoRetry, + + _ => RetryMethod::RetryImmediately, + }, + _ => RetryMethod::RetryImmediately, + }, + ErrorKind::NotAllSlotsCovered => RetryMethod::NoRetry, + } + } +} + +pub fn make_extension_error(code: String, detail: Option) -> RedisError { + RedisError { + repr: ErrorRepr::ExtensionError( + code, + match detail { + Some(x) => x, + None => "Unknown extension error encountered".to_string(), + }, + ), + } +} + +/// Library generic result type. +pub type RedisResult = Result; + +/// Library generic future type. +#[cfg(feature = "aio")] +pub type RedisFuture<'a, T> = futures_util::future::BoxFuture<'a, RedisResult>; + +/// An info dictionary type. +#[derive(Debug, Clone)] +pub struct InfoDict { + map: HashMap, +} + +/// This type provides convenient access to key/value data returned by +/// the "INFO" command. It acts like a regular mapping but also has +/// a convenience method `get` which can return data in the appropriate +/// type. +/// +/// For instance this can be used to query the server for the role it's +/// in (master, slave) etc: +/// +/// ```rust,no_run +/// # fn do_something() -> redis::RedisResult<()> { +/// # let client = redis::Client::open("redis://127.0.0.1/").unwrap(); +/// # let mut con = client.get_connection(None).unwrap(); +/// let info : redis::InfoDict = redis::cmd("INFO").query(&mut con)?; +/// let role : Option = info.get("role"); +/// # Ok(()) } +/// ``` +impl InfoDict { + /// Creates a new info dictionary from a string in the response of + /// the INFO command. Each line is a key, value pair with the + /// key and value separated by a colon (`:`). Lines starting with a + /// hash (`#`) are ignored. + pub fn new(kvpairs: &str) -> InfoDict { + let mut map = HashMap::new(); + for line in kvpairs.lines() { + if line.is_empty() || line.starts_with('#') { + continue; + } + let mut p = line.splitn(2, ':'); + let (k, v) = match (p.next(), p.next()) { + (Some(k), Some(v)) => (k.to_string(), v.to_string()), + _ => continue, + }; + map.insert(k, Value::SimpleString(v)); + } + InfoDict { map } + } + + /// Fetches a value by key and converts it into the given type. + /// Typical types are `String`, `bool` and integer types. + pub fn get(&self, key: &str) -> Option { + match self.find(&key) { + Some(x) => from_redis_value(x).ok(), + None => None, + } + } + + /// Looks up a key in the info dict. + pub fn find(&self, key: &&str) -> Option<&Value> { + self.map.get(*key) + } + + /// Checks if a key is contained in the info dicf. + pub fn contains_key(&self, key: &&str) -> bool { + self.find(key).is_some() + } + + /// Returns the size of the info dict. + pub fn len(&self) -> usize { + self.map.len() + } + + /// Checks if the dict is empty. + pub fn is_empty(&self) -> bool { + self.map.is_empty() + } +} + +impl Deref for InfoDict { + type Target = HashMap; + + fn deref(&self) -> &Self::Target { + &self.map + } +} + +/// Abstraction trait for redis command abstractions. +pub trait RedisWrite { + /// Accepts a serialized redis command. + fn write_arg(&mut self, arg: &[u8]); + + /// Accepts a serialized redis command. + fn write_arg_fmt(&mut self, arg: impl fmt::Display) { + self.write_arg(arg.to_string().as_bytes()) + } +} + +impl RedisWrite for Vec> { + fn write_arg(&mut self, arg: &[u8]) { + self.push(arg.to_owned()); + } + + fn write_arg_fmt(&mut self, arg: impl fmt::Display) { + self.push(arg.to_string().into_bytes()) + } +} + +/// Used to convert a value into one or multiple redis argument +/// strings. Most values will produce exactly one item but in +/// some cases it might make sense to produce more than one. +pub trait ToRedisArgs: Sized { + /// This converts the value into a vector of bytes. Each item + /// is a single argument. Most items generate a vector of a + /// single item. + /// + /// The exception to this rule currently are vectors of items. + fn to_redis_args(&self) -> Vec> { + let mut out = Vec::new(); + self.write_redis_args(&mut out); + out + } + + /// This writes the value into a vector of bytes. Each item + /// is a single argument. Most items generate a single item. + /// + /// The exception to this rule currently are vectors of items. + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite; + + /// Returns an information about the contained value with regards + /// to it's numeric behavior in a redis context. This is used in + /// some high level concepts to switch between different implementations + /// of redis functions (for instance `INCR` vs `INCRBYFLOAT`). + fn describe_numeric_behavior(&self) -> NumericBehavior { + NumericBehavior::NonNumeric + } + + /// Returns an indiciation if the value contained is exactly one + /// argument. It returns false if it's zero or more than one. This + /// is used in some high level functions to intelligently switch + /// between `GET` and `MGET` variants. + fn is_single_arg(&self) -> bool { + true + } + + /// This only exists internally as a workaround for the lack of + /// specialization. + #[doc(hidden)] + fn write_args_from_slice(items: &[Self], out: &mut W) + where + W: ?Sized + RedisWrite, + { + Self::make_arg_iter_ref(items.iter(), out) + } + + /// This only exists internally as a workaround for the lack of + /// specialization. + #[doc(hidden)] + fn make_arg_iter_ref<'a, I, W>(items: I, out: &mut W) + where + W: ?Sized + RedisWrite, + I: Iterator, + Self: 'a, + { + for item in items { + item.write_redis_args(out); + } + } + + #[doc(hidden)] + fn is_single_vec_arg(items: &[Self]) -> bool { + items.len() == 1 && items[0].is_single_arg() + } +} + +macro_rules! itoa_based_to_redis_impl { + ($t:ty, $numeric:expr) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::itoa::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + $numeric + } + } + }; +} + +macro_rules! non_zero_itoa_based_to_redis_impl { + ($t:ty, $numeric:expr) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::itoa::Buffer::new(); + let s = buf.format(self.get()); + out.write_arg(s.as_bytes()) + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + $numeric + } + } + }; +} + +macro_rules! ryu_based_to_redis_impl { + ($t:ty, $numeric:expr) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::ryu::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + $numeric + } + } + }; +} + +impl ToRedisArgs for u8 { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + let mut buf = ::itoa::Buffer::new(); + let s = buf.format(*self); + out.write_arg(s.as_bytes()) + } + + fn write_args_from_slice(items: &[u8], out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(items); + } + + fn is_single_vec_arg(_items: &[u8]) -> bool { + true + } +} + +itoa_based_to_redis_impl!(i8, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i16, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u16, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i32, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u32, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(i64, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(u64, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(isize, NumericBehavior::NumberIsInteger); +itoa_based_to_redis_impl!(usize, NumericBehavior::NumberIsInteger); + +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU8, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI8, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU16, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI16, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU32, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI32, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroU64, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroI64, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroUsize, NumericBehavior::NumberIsInteger); +non_zero_itoa_based_to_redis_impl!(core::num::NonZeroIsize, NumericBehavior::NumberIsInteger); + +ryu_based_to_redis_impl!(f32, NumericBehavior::NumberIsFloat); +ryu_based_to_redis_impl!(f64, NumericBehavior::NumberIsFloat); + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! bignum_to_redis_impl { + ($t:ty) => { + impl ToRedisArgs for $t { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(&self.to_string().into_bytes()) + } + } + }; +} + +#[cfg(feature = "rust_decimal")] +bignum_to_redis_impl!(rust_decimal::Decimal); +#[cfg(feature = "bigdecimal")] +bignum_to_redis_impl!(bigdecimal::BigDecimal); +#[cfg(feature = "num-bigint")] +bignum_to_redis_impl!(num_bigint::BigInt); +#[cfg(feature = "num-bigint")] +bignum_to_redis_impl!(num_bigint::BigUint); + +impl ToRedisArgs for bool { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(if *self { b"1" } else { b"0" }) + } +} + +impl ToRedisArgs for String { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()) + } +} + +impl<'a> ToRedisArgs for &'a str { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()) + } +} + +impl ToRedisArgs for Vec { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self, out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(&self[..]) + } +} + +impl<'a, T: ToRedisArgs> ToRedisArgs for &'a [T] { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self, out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(self) + } +} + +impl ToRedisArgs for Option { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + if let Some(ref x) = *self { + x.write_redis_args(out); + } + } + + fn describe_numeric_behavior(&self) -> NumericBehavior { + match *self { + Some(ref x) => x.describe_numeric_behavior(), + None => NumericBehavior::NonNumeric, + } + } + + fn is_single_arg(&self) -> bool { + match *self { + Some(ref x) => x.is_single_arg(), + None => false, + } + } +} + +impl ToRedisArgs for &T { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + (*self).write_redis_args(out) + } + + fn is_single_arg(&self) -> bool { + (*self).is_single_arg() + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs + for std::collections::HashSet +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +#[cfg(feature = "ahash")] +impl ToRedisArgs for ahash::AHashSet { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs for BTreeSet { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::make_arg_iter_ref(self.iter(), out) + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +/// this flattens BTreeMap into something that goes well with HMSET +/// @note: Redis cannot store empty sets so the application has to +/// check whether the set is empty and if so, not attempt to use that +/// result +impl ToRedisArgs for BTreeMap { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + for (key, value) in self { + // otherwise things like HMSET will simply NOT work + assert!(key.is_single_arg() && value.is_single_arg()); + + key.write_redis_args(out); + value.write_redis_args(out); + } + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +impl ToRedisArgs + for std::collections::HashMap +{ + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + for (key, value) in self { + assert!(key.is_single_arg() && value.is_single_arg()); + + key.write_redis_args(out); + value.write_redis_args(out); + } + } + + fn is_single_arg(&self) -> bool { + self.len() <= 1 + } +} + +macro_rules! to_redis_args_for_tuple { + () => (); + ($($name:ident,)+) => ( + #[doc(hidden)] + impl<$($name: ToRedisArgs),*> ToRedisArgs for ($($name,)*) { + // we have local variables named T1 as dummies and those + // variables are unused. + #[allow(non_snake_case, unused_variables)] + fn write_redis_args(&self, out: &mut W) where W: ?Sized + RedisWrite { + let ($(ref $name,)*) = *self; + $($name.write_redis_args(out);)* + } + + #[allow(non_snake_case, unused_variables)] + fn is_single_arg(&self) -> bool { + let mut n = 0u32; + $(let $name = (); n += 1;)* + n == 1 + } + } + to_redis_args_for_tuple_peel!($($name,)*); + ) +} + +/// This chips of the leading one and recurses for the rest. So if the first +/// iteration was T1, T2, T3 it will recurse to T2, T3. It stops for tuples +/// of size 1 (does not implement down to unit). +macro_rules! to_redis_args_for_tuple_peel { + ($name:ident, $($other:ident,)*) => (to_redis_args_for_tuple!($($other,)*);) +} + +to_redis_args_for_tuple! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, } + +impl ToRedisArgs for &[T; N] { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + ToRedisArgs::write_args_from_slice(self.as_slice(), out) + } + + fn is_single_arg(&self) -> bool { + ToRedisArgs::is_single_vec_arg(self.as_slice()) + } +} + +fn vec_to_array(items: Vec, original_value: &Value) -> RedisResult<[T; N]> { + match items.try_into() { + Ok(array) => Ok(array), + Err(items) => { + let msg = format!( + "Response has wrong dimension, expected {N}, got {}", + items.len() + ); + invalid_type_error!(original_value, msg) + } + } +} + +impl FromRedisValue for [T; N] { + fn from_redis_value(value: &Value) -> RedisResult<[T; N]> { + match *value { + Value::BulkString(ref bytes) => match FromRedisValue::from_byte_vec(bytes) { + Some(items) => vec_to_array(items, value), + None => { + let msg = format!( + "Conversion to Array[{}; {N}] failed", + std::any::type_name::() + ); + invalid_type_error!(value, msg) + } + }, + Value::Array(ref items) => { + let items = FromRedisValue::from_redis_values(items)?; + vec_to_array(items, value) + } + Value::Nil => vec_to_array(vec![], value), + _ => invalid_type_error!(value, "Response type not array compatible"), + } + } +} + +/// This trait is used to convert a redis value into a more appropriate +/// type. While a redis `Value` can represent any response that comes +/// back from the redis server, usually you want to map this into something +/// that works better in rust. For instance you might want to convert the +/// return value into a `String` or an integer. +/// +/// This trait is well supported throughout the library and you can +/// implement it for your own types if you want. +/// +/// In addition to what you can see from the docs, this is also implemented +/// for tuples up to size 12 and for `Vec`. +pub trait FromRedisValue: Sized { + /// Given a redis `Value` this attempts to convert it into the given + /// destination type. If that fails because it's not compatible an + /// appropriate error is generated. + fn from_redis_value(v: &Value) -> RedisResult; + + /// Given a redis `Value` this attempts to convert it into the given + /// destination type. If that fails because it's not compatible an + /// appropriate error is generated. + fn from_owned_redis_value(v: Value) -> RedisResult { + // By default, fall back to `from_redis_value`. + // This function only needs to be implemented if it can benefit + // from taking `v` by value. + Self::from_redis_value(&v) + } + + /// Similar to `from_redis_value` but constructs a vector of objects + /// from another vector of values. This primarily exists internally + /// to customize the behavior for vectors of tuples. + fn from_redis_values(items: &[Value]) -> RedisResult> { + items.iter().map(FromRedisValue::from_redis_value).collect() + } + + /// The same as `from_redis_values`, but takes a `Vec` instead + /// of a `&[Value]`. + fn from_owned_redis_values(items: Vec) -> RedisResult> { + items + .into_iter() + .map(FromRedisValue::from_owned_redis_value) + .collect() + } + + /// Convert bytes to a single element vector. + fn from_byte_vec(_vec: &[u8]) -> Option> { + Self::from_owned_redis_value(Value::BulkString(_vec.into())) + .map(|rv| vec![rv]) + .ok() + } + + /// Convert bytes to a single element vector. + fn from_owned_byte_vec(_vec: Vec) -> RedisResult> { + Self::from_owned_redis_value(Value::BulkString(_vec)).map(|rv| vec![rv]) + } +} + +fn get_inner_value(v: &Value) -> &Value { + if let Value::Attribute { + data, + attributes: _, + } = v + { + data.as_ref() + } else { + v + } +} + +fn get_owned_inner_value(v: Value) -> Value { + if let Value::Attribute { + data, + attributes: _, + } = v + { + *data + } else { + v + } +} + +macro_rules! from_redis_value_for_num_internal { + ($t:ty, $v:expr) => {{ + let v = if let Value::Attribute { + data, + attributes: _, + } = $v + { + data + } else { + $v + }; + match *v { + Value::Int(val) => Ok(val as $t), + Value::SimpleString(ref s) => match s.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + Value::BulkString(ref bytes) => match from_utf8(bytes)?.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + Value::Double(val) => Ok(val as $t), + _ => invalid_type_error!(v, "Response type not convertible to numeric."), + } + }}; +} + +macro_rules! from_redis_value_for_num { + ($t:ty) => { + impl FromRedisValue for $t { + fn from_redis_value(v: &Value) -> RedisResult<$t> { + from_redis_value_for_num_internal!($t, v) + } + } + }; +} + +impl FromRedisValue for u8 { + fn from_redis_value(v: &Value) -> RedisResult { + from_redis_value_for_num_internal!(u8, v) + } + + // this hack allows us to specialize Vec to work with binary data. + fn from_byte_vec(vec: &[u8]) -> Option> { + Some(vec.to_vec()) + } + fn from_owned_byte_vec(vec: Vec) -> RedisResult> { + Ok(vec) + } +} + +from_redis_value_for_num!(i8); +from_redis_value_for_num!(i16); +from_redis_value_for_num!(u16); +from_redis_value_for_num!(i32); +from_redis_value_for_num!(u32); +from_redis_value_for_num!(i64); +from_redis_value_for_num!(u64); +from_redis_value_for_num!(i128); +from_redis_value_for_num!(u128); +from_redis_value_for_num!(f32); +from_redis_value_for_num!(f64); +from_redis_value_for_num!(isize); +from_redis_value_for_num!(usize); + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! from_redis_value_for_bignum_internal { + ($t:ty, $v:expr) => {{ + let v = $v; + match *v { + Value::Int(val) => <$t>::try_from(val) + .map_err(|_| invalid_type_error_inner!(v, "Could not convert from integer.")), + Value::SimpleString(ref s) => match s.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + Value::BulkString(ref bytes) => match from_utf8(bytes)?.parse::<$t>() { + Ok(rv) => Ok(rv), + Err(_) => invalid_type_error!(v, "Could not convert from string."), + }, + _ => invalid_type_error!(v, "Response type not convertible to numeric."), + } + }}; +} + +#[cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +macro_rules! from_redis_value_for_bignum { + ($t:ty) => { + impl FromRedisValue for $t { + fn from_redis_value(v: &Value) -> RedisResult<$t> { + from_redis_value_for_bignum_internal!($t, v) + } + } + }; +} + +#[cfg(feature = "rust_decimal")] +from_redis_value_for_bignum!(rust_decimal::Decimal); +#[cfg(feature = "bigdecimal")] +from_redis_value_for_bignum!(bigdecimal::BigDecimal); +#[cfg(feature = "num-bigint")] +from_redis_value_for_bignum!(num_bigint::BigInt); +#[cfg(feature = "num-bigint")] +from_redis_value_for_bignum!(num_bigint::BigUint); + +impl FromRedisValue for bool { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(false), + Value::Int(val) => Ok(val != 0), + Value::SimpleString(ref s) => { + if &s[..] == "1" { + Ok(true) + } else if &s[..] == "0" { + Ok(false) + } else { + invalid_type_error!(v, "Response status not valid boolean"); + } + } + Value::BulkString(ref bytes) => { + if bytes == b"1" { + Ok(true) + } else if bytes == b"0" { + Ok(false) + } else { + invalid_type_error!(v, "Response type not bool compatible."); + } + } + Value::Boolean(b) => Ok(b), + Value::Okay => Ok(true), + _ => invalid_type_error!(v, "Response type not bool compatible."), + } + } +} + +impl FromRedisValue for CString { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::BulkString(ref bytes) => Ok(CString::new(bytes.as_slice())?), + Value::Okay => Ok(CString::new("OK")?), + Value::SimpleString(ref val) => Ok(CString::new(val.as_bytes())?), + _ => invalid_type_error!(v, "Response type not CString compatible."), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes) => Ok(CString::new(bytes)?), + Value::Okay => Ok(CString::new("OK")?), + Value::SimpleString(val) => Ok(CString::new(val)?), + _ => invalid_type_error!(v, "Response type not CString compatible."), + } + } +} + +impl FromRedisValue for String { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match *v { + Value::BulkString(ref bytes) => Ok(from_utf8(bytes)?.to_string()), + Value::Okay => Ok("OK".to_string()), + Value::SimpleString(ref val) => Ok(val.to_string()), + Value::VerbatimString { + format: _, + ref text, + } => Ok(text.to_string()), + Value::Double(ref val) => Ok(val.to_string()), + Value::Int(val) => Ok(val.to_string()), + _ => invalid_type_error!(v, "Response type not string compatible."), + } + } + + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes) => Ok(String::from_utf8(bytes)?), + Value::Okay => Ok("OK".to_string()), + Value::SimpleString(val) => Ok(val), + Value::VerbatimString { format: _, text } => Ok(text), + Value::Double(val) => Ok(val.to_string()), + Value::Int(val) => Ok(val.to_string()), + _ => invalid_type_error!(v, "Response type not string compatible."), + } + } +} + +/// Implement `FromRedisValue` for `$Type` (which should use the generic parameter `$T`). +/// +/// The implementation parses the value into a vec, and then passes the value through `$convert`. +/// If `$convert` is ommited, it defaults to `Into::into`. +macro_rules! from_vec_from_redis_value { + (<$T:ident> $Type:ty) => { + from_vec_from_redis_value!(<$T> $Type; Into::into); + }; + + (<$T:ident> $Type:ty; $convert:expr) => { + impl<$T: FromRedisValue> FromRedisValue for $Type { + fn from_redis_value(v: &Value) -> RedisResult<$Type> { + match v { + // All binary data except u8 will try to parse into a single element vector. + // u8 has its own implementation of from_byte_vec. + Value::BulkString(bytes) => match FromRedisValue::from_byte_vec(bytes) { + Some(x) => Ok($convert(x)), + None => invalid_type_error!( + v, + format!("Conversion to {} failed.", std::any::type_name::<$Type>()) + ), + }, + Value::Array(items) => FromRedisValue::from_redis_values(items).map($convert), + Value::Set(ref items) => FromRedisValue::from_redis_values(items).map($convert), + Value::Map(ref items) => { + let mut n: Vec = vec![]; + for item in items { + match FromRedisValue::from_redis_value(&Value::Map(vec![item.clone()])) { + Ok(v) => { + n.push(v); + } + Err(e) => { + return Err(e); + } + } + } + Ok($convert(n)) + } + Value::Nil => Ok($convert(Vec::new())), + _ => invalid_type_error!(v, "Response type not vector compatible."), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult<$Type> { + match v { + // Binary data is parsed into a single-element vector, except + // for the element type `u8`, which directly consumes the entire + // array of bytes. + Value::BulkString(bytes) => FromRedisValue::from_owned_byte_vec(bytes).map($convert), + Value::Array(items) => FromRedisValue::from_owned_redis_values(items).map($convert), + Value::Set(items) => FromRedisValue::from_owned_redis_values(items).map($convert), + Value::Map(items) => { + let mut n: Vec = vec![]; + for item in items { + match FromRedisValue::from_owned_redis_value(Value::Map(vec![item])) { + Ok(v) => { + n.push(v); + } + Err(e) => { + return Err(e); + } + } + } + Ok($convert(n)) + } + Value::Nil => Ok($convert(Vec::new())), + _ => invalid_type_error!(v, "Response type not vector compatible."), + } + } + } + }; +} + +from_vec_from_redis_value!( Vec); +from_vec_from_redis_value!( std::sync::Arc<[T]>); +from_vec_from_redis_value!( Box<[T]>; Vec::into_boxed_slice); + +impl FromRedisValue + for std::collections::HashMap +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(Default::default()), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + match v { + Value::Nil => Ok(Default::default()), + _ => v + .into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? + .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?))) + .collect(), + } + } +} + +#[cfg(feature = "ahash")] +impl FromRedisValue for ahash::AHashMap { + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + match *v { + Value::Nil => Ok(ahash::AHashMap::with_hasher(Default::default())), + _ => v + .as_map_iter() + .ok_or_else(|| { + invalid_type_error_inner!(v, "Response type not hashmap compatible") + })? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect(), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + match v { + Value::Nil => Ok(ahash::AHashMap::with_hasher(Default::default())), + _ => v + .into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashmap compatible"))? + .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?))) + .collect(), + } + } +} + +impl FromRedisValue for BTreeMap +where + K: Ord, +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + v.as_map_iter() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not btreemap compatible"))? + .map(|(k, v)| Ok((from_redis_value(k)?, from_redis_value(v)?))) + .collect() + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + v.into_map_iter() + .map_err(|v| invalid_type_error_inner!(v, "Response type not btreemap compatible"))? + .map(|(k, v)| Ok((from_owned_redis_value(k)?, from_owned_redis_value(v)?))) + .collect() + } +} + +impl FromRedisValue + for std::collections::HashSet +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + let items = v + .as_sequence() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items.iter().map(|item| from_redis_value(item)).collect() + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + let items = v + .into_sequence() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items + .into_iter() + .map(|item| from_owned_redis_value(item)) + .collect() + } +} + +#[cfg(feature = "ahash")] +impl FromRedisValue for ahash::AHashSet { + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + let items = v + .as_sequence() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items.iter().map(|item| from_redis_value(item)).collect() + } + + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + let items = v + .into_sequence() + .map_err(|v| invalid_type_error_inner!(v, "Response type not hashset compatible"))?; + items + .into_iter() + .map(|item| from_owned_redis_value(item)) + .collect() + } +} + +impl FromRedisValue for BTreeSet +where + T: Ord, +{ + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + let items = v + .as_sequence() + .ok_or_else(|| invalid_type_error_inner!(v, "Response type not btreeset compatible"))?; + items.iter().map(|item| from_redis_value(item)).collect() + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + let items = v + .into_sequence() + .map_err(|v| invalid_type_error_inner!(v, "Response type not btreeset compatible"))?; + items + .into_iter() + .map(|item| from_owned_redis_value(item)) + .collect() + } +} + +impl FromRedisValue for Value { + fn from_redis_value(v: &Value) -> RedisResult { + Ok(v.clone()) + } + fn from_owned_redis_value(v: Value) -> RedisResult { + Ok(v) + } +} + +impl FromRedisValue for () { + fn from_redis_value(_v: &Value) -> RedisResult<()> { + Ok(()) + } +} + +macro_rules! from_redis_value_for_tuple { + () => (); + ($($name:ident,)+) => ( + #[doc(hidden)] + impl<$($name: FromRedisValue),*> FromRedisValue for ($($name,)*) { + // we have local variables named T1 as dummies and those + // variables are unused. + #[allow(non_snake_case, unused_variables)] + fn from_redis_value(v: &Value) -> RedisResult<($($name,)*)> { + let v = get_inner_value(v); + match *v { + Value::Array(ref items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if items.len() != n { + invalid_type_error!(v, "Array response of wrong dimension") + } + + // this is pretty ugly too. The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_redis_value( + &items[{ i += 1; i - 1 }])?},)*)) + } + + Value::Map(ref items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if n != 2 { + invalid_type_error!(v, "Map response of wrong dimension") + } + + let mut flatten_items = vec![]; + for (k,v) in items { + flatten_items.push(k); + flatten_items.push(v); + } + + // this is pretty ugly too. The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_redis_value( + &flatten_items[{ i += 1; i - 1 }])?},)*)) + } + + _ => invalid_type_error!(v, "Not a Array response") + } + } + + // we have local variables named T1 as dummies and those + // variables are unused. + #[allow(non_snake_case, unused_variables)] + fn from_owned_redis_value(v: Value) -> RedisResult<($($name,)*)> { + let v = get_owned_inner_value(v); + match v { + Value::Array(mut items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if items.len() != n { + invalid_type_error!(Value::Array(items), "Array response of wrong dimension") + } + + // this is pretty ugly too. The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_owned_redis_value( + ::std::mem::replace(&mut items[{ i += 1; i - 1 }], Value::Nil) + )?},)*)) + } + + Value::Map(items) => { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + if n != 2 { + invalid_type_error!(Value::Map(items), "Map response of wrong dimension") + } + + let mut flatten_items = vec![]; + for (k,v) in items { + flatten_items.push(k); + flatten_items.push(v); + } + + // this is pretty ugly too. The { i += 1; i - 1} is rust's + // postfix increment :) + let mut i = 0; + Ok(($({let $name = (); from_redis_value( + &flatten_items[{ i += 1; i - 1 }])?},)*)) + } + + _ => invalid_type_error!(v, "Not a Array response") + } + } + + #[allow(non_snake_case, unused_variables)] + fn from_redis_values(items: &[Value]) -> RedisResult> { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + let mut rv = vec![]; + if items.len() == 0 { + return Ok(rv) + } + //It's uglier then before! + for item in items { + match item { + Value::Array(ch) => { + if let [$($name),*] = &ch[..] { + rv.push(($(from_redis_value(&$name)?),*),) + } else { + unreachable!() + }; + }, + _ => {}, + + } + } + if !rv.is_empty(){ + return Ok(rv); + } + + if let [$($name),*] = items{ + rv.push(($(from_redis_value($name)?),*),); + return Ok(rv); + } + for chunk in items.chunks_exact(n) { + match chunk { + [$($name),*] => rv.push(($(from_redis_value($name)?),*),), + _ => {}, + } + } + Ok(rv) + } + + #[allow(non_snake_case, unused_variables)] + fn from_owned_redis_values(mut items: Vec) -> RedisResult> { + // hacky way to count the tuple size + let mut n = 0; + $(let $name = (); n += 1;)* + + let mut rv = vec![]; + if items.len() == 0 { + return Ok(rv) + } + //It's uglier then before! + for item in items.iter() { + match item { + Value::Array(ch) => { + // TODO - this copies when we could've used the owned value. need to find out how to do this. + if let [$($name),*] = &ch[..] { + rv.push(($(from_redis_value($name)?),*),) + } else { + unreachable!() + }; + }, + _ => {}, + } + } + if !rv.is_empty(){ + return Ok(rv); + } + + let mut rv = Vec::with_capacity(items.len() / n); + if items.len() == 0 { + return Ok(rv) + } + for chunk in items.chunks_mut(n) { + match chunk { + // Take each element out of the chunk with `std::mem::replace`, leaving a `Value::Nil` + // in its place. This allows each `Value` to be parsed without being copied. + // Since `items` is consume by this function and not used later, this replacement + // is not observable to the rest of the code. + [$($name),*] => rv.push(($(from_owned_redis_value(std::mem::replace($name, Value::Nil))?),*),), + _ => unreachable!(), + } + } + Ok(rv) + } + } + from_redis_value_for_tuple_peel!($($name,)*); + ) +} + +/// This chips of the leading one and recurses for the rest. So if the first +/// iteration was T1, T2, T3 it will recurse to T2, T3. It stops for tuples +/// of size 1 (does not implement down to unit). +macro_rules! from_redis_value_for_tuple_peel { + ($name:ident, $($other:ident,)*) => (from_redis_value_for_tuple!($($other,)*);) +} + +from_redis_value_for_tuple! { T1, T2, T3, T4, T5, T6, T7, T8, T9, T10, T11, T12, } + +impl FromRedisValue for InfoDict { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + let s: String = from_redis_value(v)?; + Ok(InfoDict::new(&s)) + } + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + let s: String = from_owned_redis_value(v)?; + Ok(InfoDict::new(&s)) + } +} + +impl FromRedisValue for Option { + fn from_redis_value(v: &Value) -> RedisResult> { + let v = get_inner_value(v); + if *v == Value::Nil { + return Ok(None); + } + Ok(Some(from_redis_value(v)?)) + } + fn from_owned_redis_value(v: Value) -> RedisResult> { + let v = get_owned_inner_value(v); + if v == Value::Nil { + return Ok(None); + } + Ok(Some(from_owned_redis_value(v)?)) + } +} + +#[cfg(feature = "bytes")] +impl FromRedisValue for bytes::Bytes { + fn from_redis_value(v: &Value) -> RedisResult { + let v = get_inner_value(v); + match v { + Value::BulkString(bytes_vec) => Ok(bytes::Bytes::copy_from_slice(bytes_vec.as_ref())), + _ => invalid_type_error!(v, "Not a bulk string"), + } + } + fn from_owned_redis_value(v: Value) -> RedisResult { + let v = get_owned_inner_value(v); + match v { + Value::BulkString(bytes_vec) => Ok(bytes_vec.into()), + _ => invalid_type_error!(v, "Not a bulk string"), + } + } +} + +#[cfg(feature = "uuid")] +impl FromRedisValue for uuid::Uuid { + fn from_redis_value(v: &Value) -> RedisResult { + match *v { + Value::BulkString(ref bytes) => Ok(uuid::Uuid::from_slice(bytes)?), + _ => invalid_type_error!(v, "Response type not uuid compatible."), + } + } +} + +#[cfg(feature = "uuid")] +impl ToRedisArgs for uuid::Uuid { + fn write_redis_args(&self, out: &mut W) + where + W: ?Sized + RedisWrite, + { + out.write_arg(self.as_bytes()); + } +} + +/// A shortcut function to invoke `FromRedisValue::from_redis_value` +/// to make the API slightly nicer. +pub fn from_redis_value(v: &Value) -> RedisResult { + FromRedisValue::from_redis_value(v) +} + +/// A shortcut function to invoke `FromRedisValue::from_owned_redis_value` +/// to make the API slightly nicer. +pub fn from_owned_redis_value(v: Value) -> RedisResult { + FromRedisValue::from_owned_redis_value(v) +} + +/// Enum representing the communication protocol with the server. This enum represents the types +/// of data that the server can send to the client, and the capabilities that the client can use. +#[derive(Clone, Eq, PartialEq, Default, Debug, Copy)] +pub enum ProtocolVersion { + /// + #[default] + RESP2, + /// + RESP3, +} diff --git a/glide-core/redis-rs/redis/tests/parser.rs b/glide-core/redis-rs/redis/tests/parser.rs new file mode 100644 index 0000000000..c4083f44bd --- /dev/null +++ b/glide-core/redis-rs/redis/tests/parser.rs @@ -0,0 +1,195 @@ +use std::{io, pin::Pin}; + +use redis::Value; +use { + futures::{ + ready, + task::{self, Poll}, + }, + partial_io::{quickcheck_types::GenWouldBlock, quickcheck_types::PartialWithErrors, PartialOp}, + quickcheck::{quickcheck, Gen}, + tokio::io::{AsyncRead, ReadBuf}, +}; + +mod support; +use crate::support::{block_on_all, encode_value}; + +#[derive(Clone, Debug)] +struct ArbitraryValue(Value); + +impl ::quickcheck::Arbitrary for ArbitraryValue { + fn arbitrary(g: &mut Gen) -> Self { + let size = g.size(); + ArbitraryValue(arbitrary_value(g, size)) + } + + fn shrink(&self) -> Box> { + match self.0 { + Value::Nil | Value::Okay => Box::new(None.into_iter()), + Value::Int(i) => Box::new(i.shrink().map(Value::Int).map(ArbitraryValue)), + Value::BulkString(ref xs) => { + Box::new(xs.shrink().map(Value::BulkString).map(ArbitraryValue)) + } + Value::Array(ref xs) | Value::Set(ref xs) => { + let ys = xs + .iter() + .map(|x| ArbitraryValue(x.clone())) + .collect::>(); + Box::new( + ys.shrink() + .map(|xs| xs.into_iter().map(|x| x.0).collect()) + .map(Value::Array) + .map(ArbitraryValue), + ) + } + Value::Map(ref _xs) => Box::new(vec![ArbitraryValue(Value::Map(vec![]))].into_iter()), + Value::Attribute { + ref data, + ref attributes, + } => Box::new( + vec![ArbitraryValue(Value::Attribute { + data: data.clone(), + attributes: attributes.clone(), + })] + .into_iter(), + ), + Value::Push { ref kind, ref data } => { + let mut ys = data + .iter() + .map(|x| ArbitraryValue(x.clone())) + .collect::>(); + ys.insert(0, ArbitraryValue(Value::SimpleString(kind.to_string()))); + Box::new( + ys.shrink() + .map(|xs| xs.into_iter().map(|x| x.0).collect()) + .map(Value::Array) + .map(ArbitraryValue), + ) + } + Value::SimpleString(ref status) => { + Box::new(status.shrink().map(Value::SimpleString).map(ArbitraryValue)) + } + Value::Double(i) => Box::new(i.shrink().map(Value::Double).map(ArbitraryValue)), + Value::Boolean(i) => Box::new(i.shrink().map(Value::Boolean).map(ArbitraryValue)), + Value::BigNumber(ref i) => { + Box::new(vec![ArbitraryValue(Value::BigNumber(i.clone()))].into_iter()) + } + Value::VerbatimString { + ref format, + ref text, + } => Box::new( + vec![ArbitraryValue(Value::VerbatimString { + format: format.clone(), + text: text.clone(), + })] + .into_iter(), + ), + } + } +} + +fn arbitrary_value(g: &mut Gen, recursive_size: usize) -> Value { + use quickcheck::Arbitrary; + if recursive_size == 0 { + Value::Nil + } else { + match u8::arbitrary(g) % 6 { + 0 => Value::Nil, + 1 => Value::Int(Arbitrary::arbitrary(g)), + 2 => Value::BulkString(Arbitrary::arbitrary(g)), + 3 => { + let size = { + let s = g.size(); + usize::arbitrary(g) % s + }; + Value::Array( + (0..size) + .map(|_| arbitrary_value(g, recursive_size / size)) + .collect(), + ) + } + 4 => { + let size = { + let s = g.size(); + usize::arbitrary(g) % s + }; + + let mut string = String::with_capacity(size); + for _ in 0..size { + let c = char::arbitrary(g); + if c.is_ascii_alphabetic() { + string.push(c); + } + } + + if string == "OK" { + Value::Okay + } else { + Value::SimpleString(string) + } + } + 5 => Value::Okay, + _ => unreachable!(), + } + } +} + +struct PartialAsyncRead { + inner: R, + ops: Box + Send>, +} + +impl AsyncRead for PartialAsyncRead +where + R: AsyncRead + Unpin, +{ + fn poll_read( + mut self: Pin<&mut Self>, + cx: &mut task::Context, + buf: &mut ReadBuf<'_>, + ) -> Poll> { + match self.ops.next() { + Some(PartialOp::Limited(n)) => { + let len = std::cmp::min(n, buf.remaining()); + buf.initialize_unfilled(); + let mut sub_buf = buf.take(len); + ready!(Pin::new(&mut self.inner).poll_read(cx, &mut sub_buf))?; + let filled = sub_buf.filled().len(); + buf.advance(filled); + Poll::Ready(Ok(())) + } + Some(PartialOp::Err(err)) => { + if err == io::ErrorKind::WouldBlock { + cx.waker().wake_by_ref(); + Poll::Pending + } else { + Err(io::Error::new( + err, + "error during read, generated by partial-io", + )) + .into() + } + } + Some(PartialOp::Unlimited) | None => Pin::new(&mut self.inner).poll_read(cx, buf), + } + } +} + +quickcheck! { + fn partial_io_parse(input: ArbitraryValue, seq: PartialWithErrors) -> () { + + let mut encoded_input = Vec::new(); + encode_value(&input.0, &mut encoded_input).unwrap(); + + let mut reader = &encoded_input[..]; + let mut partial_reader = PartialAsyncRead { inner: &mut reader, ops: Box::new(seq.into_iter()) }; + let mut decoder = combine::stream::Decoder::new(); + + let result = block_on_all(redis::parse_redis_value_async(&mut decoder, &mut partial_reader)); + assert!(result.as_ref().is_ok(), "{}", result.unwrap_err()); + assert_eq!( + result.unwrap(), + input.0, + ); + } +} diff --git a/glide-core/redis-rs/redis/tests/support/cluster.rs b/glide-core/redis-rs/redis/tests/support/cluster.rs new file mode 100644 index 0000000000..991331cfca --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/cluster.rs @@ -0,0 +1,792 @@ +#![cfg(feature = "cluster")] +#![allow(dead_code)] + +use std::convert::identity; +use std::env; +use std::process; +use std::thread::sleep; +use std::time::Duration; + +use redis::cluster_routing::RoutingInfo; +use redis::cluster_routing::SingleNodeRoutingInfo; +use redis::from_redis_value; + +#[cfg(feature = "cluster-async")] +use redis::aio::ConnectionLike; +#[cfg(feature = "cluster-async")] +use redis::cluster_async::Connect; +use redis::ConnectionInfo; +use redis::ProtocolVersion; +use redis::PushInfo; +use redis::RedisResult; +use redis::Value; +use tempfile::TempDir; + +use crate::support::{build_keys_and_certs_for_tls, Module}; + +#[cfg(feature = "tls-rustls")] +use super::{build_single_client, load_certs_from_file}; + +use super::use_protocol; +use super::RedisServer; +use super::TlsFilePaths; +use tokio::sync::mpsc; + +const LOCALHOST: &str = "127.0.0.1"; + +enum ClusterType { + Tcp, + TcpTls, +} + +impl ClusterType { + fn get_intended() -> ClusterType { + match env::var("REDISRS_SERVER_TYPE") + .ok() + .as_ref() + .map(|x| &x[..]) + { + Some("tcp") => ClusterType::Tcp, + Some("tcp+tls") => ClusterType::TcpTls, + Some(val) => { + panic!("Unknown server type {val:?}"); + } + None => ClusterType::Tcp, + } + } + + fn build_addr(port: u16) -> redis::ConnectionAddr { + match ClusterType::get_intended() { + ClusterType::Tcp => redis::ConnectionAddr::Tcp("127.0.0.1".into(), port), + ClusterType::TcpTls => redis::ConnectionAddr::TcpTls { + host: "127.0.0.1".into(), + port, + insecure: true, + tls_params: None, + }, + } + } +} + +fn port_in_use(addr: &str) -> bool { + let socket_addr: std::net::SocketAddr = addr.parse().expect("Invalid address"); + let socket = socket2::Socket::new( + socket2::Domain::for_address(socket_addr), + socket2::Type::STREAM, + None, + ) + .expect("Failed to create socket"); + + socket.connect(&socket_addr.into()).is_ok() +} + +pub struct RedisCluster { + pub servers: Vec, + pub folders: Vec, + pub tls_paths: Option, +} + +impl RedisCluster { + pub fn username() -> &'static str { + "hello" + } + + pub fn password() -> &'static str { + "world" + } + + pub fn client_name() -> &'static str { + "test_cluster_client" + } + + pub fn new(nodes: u16, replicas: u16) -> RedisCluster { + RedisCluster::with_modules(nodes, replicas, &[], false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls(nodes: u16, replicas: u16) -> RedisCluster { + RedisCluster::with_modules(nodes, replicas, &[], true) + } + + pub fn with_modules( + nodes: u16, + replicas: u16, + modules: &[Module], + mtls_enabled: bool, + ) -> RedisCluster { + let mut servers = vec![]; + let mut folders = vec![]; + let mut addrs = vec![]; + let start_port = 7000; + let mut tls_paths = None; + + let mut is_tls = false; + + if let ClusterType::TcpTls = ClusterType::get_intended() { + // Create a shared set of keys in cluster mode + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let files = build_keys_and_certs_for_tls(&tempdir); + folders.push(tempdir); + tls_paths = Some(files); + is_tls = true; + } + + let max_attempts = 5; + + for node in 0..nodes { + let port = start_port + node; + + servers.push(RedisServer::new_with_addr_tls_modules_and_spawner( + ClusterType::build_addr(port), + None, + tls_paths.clone(), + mtls_enabled, + modules, + |cmd| { + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let acl_path = tempdir.path().join("users.acl"); + let acl_content = format!( + "user {} on allcommands allkeys >{}", + Self::username(), + Self::password() + ); + std::fs::write(&acl_path, acl_content).expect("failed to write acl file"); + cmd.arg("--cluster-enabled") + .arg("yes") + .arg("--cluster-config-file") + .arg(tempdir.path().join("nodes.conf")) + .arg("--cluster-node-timeout") + .arg("5000") + .arg("--appendonly") + .arg("yes") + .arg("--aclfile") + .arg(&acl_path); + if is_tls { + cmd.arg("--tls-cluster").arg("yes"); + if replicas > 0 { + cmd.arg("--tls-replication").arg("yes"); + } + } + let addr = format!("127.0.0.1:{port}"); + cmd.current_dir(tempdir.path()); + folders.push(tempdir); + addrs.push(addr.clone()); + + let mut cur_attempts = 0; + loop { + let mut process = cmd.spawn().unwrap(); + sleep(Duration::from_millis(100)); + + match process.try_wait() { + Ok(Some(status)) => { + let err = + format!("redis server creation failed with status {status:?}"); + if cur_attempts == max_attempts { + panic!("{err}"); + } + eprintln!("Retrying: {err}"); + cur_attempts += 1; + } + Ok(None) => { + let max_attempts = 20; + let mut cur_attempts = 0; + loop { + if cur_attempts == max_attempts { + panic!("redis server creation failed: Port {port} closed") + } + if port_in_use(&addr) { + return process; + } + eprintln!("Waiting for redis process to initialize"); + sleep(Duration::from_millis(50)); + cur_attempts += 1; + } + } + Err(e) => { + panic!("Unexpected error in redis server creation {e}"); + } + } + } + }, + )); + } + + let mut cmd = process::Command::new("redis-cli"); + cmd.stdout(process::Stdio::null()) + .arg("--cluster") + .arg("create") + .args(&addrs); + if replicas > 0 { + cmd.arg("--cluster-replicas").arg(replicas.to_string()); + } + cmd.arg("--cluster-yes"); + + if is_tls { + if mtls_enabled { + if let Some(TlsFilePaths { + redis_crt, + redis_key, + ca_crt, + }) = &tls_paths + { + cmd.arg("--cert"); + cmd.arg(redis_crt); + cmd.arg("--key"); + cmd.arg(redis_key); + cmd.arg("--cacert"); + cmd.arg(ca_crt); + cmd.arg("--tls"); + } + } else { + cmd.arg("--tls").arg("--insecure"); + } + } + + let mut cur_attempts = 0; + loop { + let output = cmd.output().unwrap(); + if output.status.success() { + break; + } else { + let err = format!("Cluster creation failed: {output:?}"); + if cur_attempts == max_attempts { + panic!("{err}"); + } + eprintln!("Retrying: {err}"); + sleep(Duration::from_millis(50)); + cur_attempts += 1; + } + } + + let cluster = RedisCluster { + servers, + folders, + tls_paths, + }; + if replicas > 0 { + cluster.wait_for_replicas(replicas, mtls_enabled); + } + + wait_for_status_ok(&cluster); + cluster + } + + // parameter `_mtls_enabled` can only be used if `feature = tls-rustls` is active + #[allow(dead_code)] + fn wait_for_replicas(&self, replicas: u16, _mtls_enabled: bool) { + 'server: for server in &self.servers { + let conn_info = server.connection_info(); + eprintln!( + "waiting until {:?} knows required number of replicas", + conn_info.addr + ); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &self.tls_paths, _mtls_enabled) + .unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + let mut con = client.get_connection(None).unwrap(); + + // retry 500 times + for _ in 1..500 { + let value = redis::cmd("CLUSTER").arg("SLOTS").query(&mut con).unwrap(); + let slots: Vec> = redis::from_owned_redis_value(value).unwrap(); + + // all slots should have following items: + // [start slot range, end slot range, master's IP, replica1's IP, replica2's IP,... ] + if slots.iter().all(|slot| slot.len() >= 3 + replicas as usize) { + continue 'server; + } + + sleep(Duration::from_millis(100)); + } + + panic!("failed to create enough replicas"); + } + } + + pub fn stop(&mut self) { + for server in &mut self.servers { + server.stop(); + } + } + + pub fn iter_servers(&self) -> impl Iterator { + self.servers.iter() + } +} + +fn wait_for_status_ok(cluster: &RedisCluster) { + 'server: for server in &cluster.servers { + let log_file = RedisServer::log_file(&server.tempdir); + + for _ in 1..500 { + let contents = + std::fs::read_to_string(&log_file).expect("Should have been able to read the file"); + + if contents.contains("Cluster state changed: ok") { + continue 'server; + } + sleep(Duration::from_millis(20)); + } + panic!("failed to reach state change: OK"); + } +} + +impl Drop for RedisCluster { + fn drop(&mut self) { + self.stop() + } +} + +pub struct TestClusterContext { + pub cluster: RedisCluster, + pub client: redis::cluster::ClusterClient, + pub mtls_enabled: bool, + pub nodes: Vec, + pub protocol: ProtocolVersion, +} + +impl TestClusterContext { + pub fn new(nodes: u16, replicas: u16) -> TestClusterContext { + Self::new_with_cluster_client_builder(nodes, replicas, identity, false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls(nodes: u16, replicas: u16) -> TestClusterContext { + Self::new_with_cluster_client_builder(nodes, replicas, identity, true) + } + + pub fn new_with_cluster_client_builder( + nodes: u16, + replicas: u16, + initializer: F, + mtls_enabled: bool, + ) -> TestClusterContext + where + F: FnOnce(redis::cluster::ClusterClientBuilder) -> redis::cluster::ClusterClientBuilder, + { + let cluster = RedisCluster::new(nodes, replicas); + let initial_nodes: Vec = cluster + .iter_servers() + .map(RedisServer::connection_info) + .collect(); + let mut builder = redis::cluster::ClusterClientBuilder::new(initial_nodes.clone()) + .use_protocol(use_protocol()); + + #[cfg(feature = "tls-rustls")] + if mtls_enabled { + if let Some(tls_file_paths) = &cluster.tls_paths { + builder = builder.certs(load_certs_from_file(tls_file_paths)); + } + } + + builder = initializer(builder); + + let client = builder.build().unwrap(); + + TestClusterContext { + cluster, + client, + mtls_enabled, + nodes: initial_nodes, + protocol: use_protocol(), + } + } + + pub fn connection(&self) -> redis::cluster::ClusterConnection { + self.client.get_connection(None).unwrap() + } + + #[cfg(feature = "cluster-async")] + pub async fn async_connection( + &self, + push_sender: Option>, + ) -> redis::cluster_async::ClusterConnection { + self.client.get_async_connection(push_sender).await.unwrap() + } + + #[cfg(feature = "cluster-async")] + pub async fn async_generic_connection< + C: ConnectionLike + Connect + Clone + Send + Sync + Unpin + 'static, + >( + &self, + ) -> redis::cluster_async::ClusterConnection { + self.client + .get_async_generic_connection::() + .await + .unwrap() + } + + pub fn wait_for_cluster_up(&self) { + let mut con = self.connection(); + let mut c = redis::cmd("CLUSTER"); + c.arg("INFO"); + + for _ in 0..100 { + let r: String = c.query::(&mut con).unwrap(); + if r.starts_with("cluster_state:ok") { + return; + } + + sleep(Duration::from_millis(25)); + } + + panic!("failed waiting for cluster to be ready"); + } + + pub fn disable_default_user(&self) { + for server in &self.cluster.servers { + #[cfg(feature = "tls-rustls")] + let client = build_single_client( + server.connection_info(), + &self.cluster.tls_paths, + self.mtls_enabled, + ) + .unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + let mut con = client.get_connection(None).unwrap(); + let _: () = redis::cmd("ACL") + .arg("SETUSER") + .arg("default") + .arg("off") + .query(&mut con) + .unwrap(); + + // subsequent unauthenticated command should fail: + if let Ok(mut con) = client.get_connection(None) { + assert!(redis::cmd("PING").query::<()>(&mut con).is_err()); + } + } + } + + pub fn get_version(&self) -> super::Version { + let mut conn = self.connection(); + super::get_version(&mut conn) + } + + pub fn get_node_ids(&self) -> Vec { + let mut conn = self.connection(); + let nodes: Vec = redis::cmd("CLUSTER") + .arg("NODES") + .query::(&mut conn) + .unwrap() + .split('\n') + .map(|s| s.to_string()) + .collect(); + let node_ids: Vec = nodes + .iter() + .map(|node| node.split(' ').next().unwrap().to_string()) + .collect(); + node_ids + .iter() + .filter(|id| !id.is_empty()) + .cloned() + .collect() + } + + // Migrate half the slots from one node to another + pub async fn migrate_slots_from_node_to_another( + &self, + slot_distribution: Vec<(String, String, String, Vec>)>, + ) { + let slots_ranges_of_node_id = slot_distribution[0].3.clone(); + + let mut conn = self.async_connection(None).await; + + let from = slot_distribution[0].clone(); + let target = slot_distribution[1].clone(); + + let from_node_id = from.0.clone(); + let target_node_id = target.0.clone(); + + let from_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: from.1.clone(), + port: from.2.clone().parse::().unwrap(), + }); + let target_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: target.1.clone(), + port: target.2.clone().parse::().unwrap(), + }); + + // Migrate the slots + for range in slots_ranges_of_node_id { + let mut slots_of_nodes: std::ops::Range = range[0]..range[1]; + let number_of_slots = range[1] - range[0] + 1; + // Migrate half the slots + for _i in 0..(number_of_slots as f64 / 2.0).floor() as usize { + let slot = slots_of_nodes.next().unwrap(); + // Set the nodes to MIGRATING and IMPORTING + let mut set_cmd = redis::cmd("CLUSTER"); + set_cmd + .arg("SETSLOT") + .arg(slot) + .arg("IMPORTING") + .arg(from_node_id.clone()); + let result: RedisResult = + conn.route_command(&set_cmd, target_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to IMPORTING with error {}", + slot, err + ); + } + } + let mut set_cmd = redis::cmd("CLUSTER"); + set_cmd + .arg("SETSLOT") + .arg(slot) + .arg("MIGRATING") + .arg(target_node_id.clone()); + let result: RedisResult = + conn.route_command(&set_cmd, from_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to MIGRATING with error {}", + slot, err + ); + } + } + // Get a key from the slot + let mut get_key_cmd = redis::cmd("CLUSTER"); + get_key_cmd.arg("GETKEYSINSLOT").arg(slot).arg(1); + let result: RedisResult = + conn.route_command(&get_key_cmd, from_route.clone()).await; + let vec_string_result: Vec = match result { + Ok(val) => { + let val: Vec = from_redis_value(&val).unwrap(); + val + } + Err(err) => { + println!("Failed to get keys in slot {}: {:?}", slot, err); + continue; + } + }; + if vec_string_result.is_empty() { + continue; + } + let key = vec_string_result[0].clone(); + // Migrate the key, which will make the whole slot to move + let mut migrate_cmd = redis::cmd("MIGRATE"); + migrate_cmd + .arg(target.1.clone()) + .arg(target.2.clone()) + .arg(key.clone()) + .arg(0) + .arg(5000); + let result: RedisResult = + conn.route_command(&migrate_cmd, from_route.clone()).await; + + match result { + Ok(Value::Okay) => {} + Ok(Value::SimpleString(str)) => { + if str != "NOKEY" { + println!( + "Failed to migrate key {} to target node with status {}", + key, str + ); + } else { + println!("Key {} does not exist", key); + } + } + Ok(_) => {} + Err(err) => { + println!( + "Failed to migrate key {} to target node with error {}", + key, err + ); + } + } + // Tell the source and target nodes to propagate the slot change to the cluster + let mut setslot_cmd = redis::cmd("CLUSTER"); + setslot_cmd + .arg("SETSLOT") + .arg(slot) + .arg("NODE") + .arg(target_node_id.clone()); + let result: RedisResult = + conn.route_command(&setslot_cmd, target_route.clone()).await; + match result { + Ok(_) => {} + Err(err) => { + println!( + "Failed to set slot {} to target NODE with error {}", + slot, err + ); + } + }; + self.wait_for_connection_is_ready(&from_route) + .await + .unwrap(); + self.wait_for_connection_is_ready(&target_route) + .await + .unwrap(); + self.wait_for_cluster_up(); + } + } + } + + // Return the slots distribution of the cluster as a vector of tuples + // where the first element is the node id, seconed is host, third is port and the last element is a vector of slots ranges + pub fn get_slots_ranges_distribution( + &self, + cluster_nodes: &str, + ) -> Vec<(String, String, String, Vec>)> { + let nodes_string: Vec = cluster_nodes + .split('\n') + .map(|s| s.to_string()) + .filter(|s| !s.is_empty()) + .collect(); + let mut nodes: Vec> = vec![]; + for node in nodes_string { + let node_vec: Vec = node.split(' ').map(|s| s.to_string()).collect(); + if node_vec.last().unwrap() == "connected" || node_vec.last().unwrap() == "disconnected" + { + continue; + } else { + nodes.push(node_vec); + } + } + let mut slot_distribution = vec![]; + for node in &nodes { + let mut slots_ranges: Vec> = vec![]; + let mut slots_ranges_vec: Vec = vec![]; + let node_id = node[0].clone(); + let host_and_port: Vec = node[1].split(':').map(|s| s.to_string()).collect(); + let host = host_and_port[0].clone(); + let port = host_and_port[1].split('@').next().unwrap().to_string(); + let slots = node[8..].to_vec(); + for slot in slots { + if slot.contains("->") || slot.contains("<-") { + continue; + } + if slot.contains('-') { + let range: Vec = + slot.split('-').map(|s| s.parse::().unwrap()).collect(); + slots_ranges_vec.push(range[0]); + slots_ranges_vec.push(range[1]); + slots_ranges.push(slots_ranges_vec.clone()); + slots_ranges_vec.clear(); + } else { + let slot: u16 = slot.parse::().unwrap(); + slots_ranges_vec.push(slot); + slots_ranges_vec.push(slot); + slots_ranges.push(slots_ranges_vec.clone()); + slots_ranges_vec.clear(); + } + } + let parsed_node: (String, String, String, Vec>) = + (node_id, host, port, slots_ranges); + slot_distribution.push(parsed_node); + } + slot_distribution + } + + pub async fn get_masters(&self, cluster_nodes: &str) -> Vec> { + let mut masters = vec![]; + for line in cluster_nodes.lines() { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 3 { + continue; + } + if parts[2] == "master" || parts[2] == "myself,master" { + let id = parts[0]; + let host_and_port = parts[1].split(':'); + let host = host_and_port.clone().next().unwrap(); + let port = host_and_port + .clone() + .last() + .unwrap() + .split('@') + .next() + .unwrap(); + masters.push(vec![id.to_string(), host.to_string(), port.to_string()]); + } + } + masters + } + + pub async fn get_replicas(&self, cluster_nodes: &str) -> Vec> { + let mut replicas = vec![]; + for line in cluster_nodes.lines() { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() < 3 { + continue; + } + if parts[2] == "slave" || parts[2] == "myself,slave" { + let id = parts[0]; + let host_and_port = parts[1].split(':'); + let host = host_and_port.clone().next().unwrap(); + let port = host_and_port + .clone() + .last() + .unwrap() + .split('@') + .next() + .unwrap(); + replicas.push(vec![id.to_string(), host.to_string(), port.to_string()]); + } + } + replicas + } + + pub async fn get_cluster_nodes(&self) -> String { + let mut conn = self.async_connection(None).await; + let mut cmd = redis::cmd("CLUSTER"); + cmd.arg("NODES"); + let res: RedisResult = conn + .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + .await; + let res: String = from_redis_value(&res.unwrap()).unwrap(); + res + } + + pub async fn wait_for_fail_to_finish(&self, route: &RoutingInfo) -> RedisResult<()> { + for _ in 0..500 { + let mut conn = self.async_connection(None).await; + let cmd = redis::cmd("PING"); + let res: RedisResult = conn.route_command(&cmd, route.clone()).await; + if res.is_err() { + return Ok(()); + } + sleep(Duration::from_millis(50)); + } + Err(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Failed to get connection", + ))) + } + + pub async fn wait_for_connection_is_ready(&self, route: &RoutingInfo) -> RedisResult<()> { + let mut i = 1; + while i < 1000 { + let mut conn = self.async_connection(None).await; + let cmd = redis::cmd("PING"); + let res: RedisResult = conn.route_command(&cmd, route.clone()).await; + if res.is_ok() { + return Ok(()); + } + sleep(Duration::from_millis(i * 10)); + i += 10; + } + Err(redis::RedisError::from(( + redis::ErrorKind::IoError, + "Failed to get connection", + ))) + } +} diff --git a/glide-core/redis-rs/redis/tests/support/mock_cluster.rs b/glide-core/redis-rs/redis/tests/support/mock_cluster.rs new file mode 100644 index 0000000000..ce91988cef --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/mock_cluster.rs @@ -0,0 +1,487 @@ +use redis::{ + cluster::{self, ClusterClient, ClusterClientBuilder}, + ErrorKind, FromRedisValue, GlideConnectionOptions, RedisError, +}; + +use std::{ + collections::HashMap, + net::{IpAddr, Ipv4Addr, SocketAddr}, + sync::{ + atomic::{AtomicUsize, Ordering}, + Arc, RwLock, + }, + time::Duration, +}; + +use { + once_cell::sync::Lazy, + redis::{IntoConnectionInfo, RedisResult, Value}, +}; + +#[cfg(feature = "cluster-async")] +use redis::{aio, cluster_async, RedisFuture}; + +#[cfg(feature = "cluster-async")] +use futures::future; + +#[cfg(feature = "cluster-async")] +use tokio::runtime::Runtime; + +type Handler = Arc Result<(), RedisResult> + Send + Sync>; + +pub struct MockConnectionBehavior { + pub id: String, + pub handler: Handler, + pub connection_id_provider: AtomicUsize, + pub returned_ip_type: ConnectionIPReturnType, + pub return_connection_err: ShouldReturnConnectionError, +} + +impl MockConnectionBehavior { + fn new(id: &str, handler: Handler) -> Self { + Self { + id: id.to_string(), + handler, + connection_id_provider: AtomicUsize::new(0), + returned_ip_type: ConnectionIPReturnType::default(), + return_connection_err: ShouldReturnConnectionError::default(), + } + } + + #[must_use] + pub fn register_new(id: &str, handler: Handler) -> RemoveHandler { + get_behaviors().insert(id.to_string(), Self::new(id, handler)); + RemoveHandler(vec![id.to_string()]) + } + + fn get_handler(&self) -> Handler { + self.handler.clone() + } +} + +pub fn modify_mock_connection_behavior(name: &str, func: impl FnOnce(&mut MockConnectionBehavior)) { + func( + get_behaviors() + .get_mut(name) + .expect("Handler `{name}` was not installed"), + ); +} + +pub fn get_mock_connection_handler(name: &str) -> Handler { + MOCK_CONN_BEHAVIORS + .read() + .unwrap() + .get(name) + .expect("Handler `{name}` was not installed") + .get_handler() +} + +pub fn get_mock_connection(name: &str, id: usize) -> MockConnection { + get_mock_connection_with_port(name, id, 6379) +} + +pub fn get_mock_connection_with_port(name: &str, id: usize, port: u16) -> MockConnection { + MockConnection { + id, + handler: get_mock_connection_handler(name), + port, + } +} + +static MOCK_CONN_BEHAVIORS: Lazy>> = + Lazy::new(Default::default); + +fn get_behaviors() -> std::sync::RwLockWriteGuard<'static, HashMap> +{ + MOCK_CONN_BEHAVIORS.write().unwrap() +} + +#[derive(Default)] +pub enum ConnectionIPReturnType { + /// New connections' IP will be returned as None + #[default] + None, + /// Creates connections with the specified IP + Specified(IpAddr), + /// Each new connection will be created with a different IP based on the passed atomic integer + Different(AtomicUsize), +} + +#[derive(Default)] +pub enum ShouldReturnConnectionError { + /// Don't return a connection error + #[default] + No, + /// Always return a connection error + Yes, + /// Return connection error when the internal index is an odd number + OnOddIdx(AtomicUsize), +} + +#[derive(Clone)] +pub struct MockConnection { + pub id: usize, + pub handler: Handler, + pub port: u16, +} + +#[cfg(feature = "cluster-async")] +impl cluster_async::Connect for MockConnection { + fn connect<'a, T>( + info: T, + _response_timeout: Duration, + _connection_timeout: Duration, + _socket_addr: Option, + _glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (Self, Option)> + where + T: IntoConnectionInfo + Send + 'a, + { + let info = info.into_connection_info().unwrap(); + + let (name, port) = match &info.addr { + redis::ConnectionAddr::Tcp(addr, port) => (addr, *port), + _ => unreachable!(), + }; + let binding = MOCK_CONN_BEHAVIORS.read().unwrap(); + let conn_utils = binding + .get(name) + .unwrap_or_else(|| panic!("MockConnectionUtils for `{name}` were not installed")); + let conn_err = Box::pin(future::err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "mock-io-error", + )))); + match &conn_utils.return_connection_err { + ShouldReturnConnectionError::No => {} + ShouldReturnConnectionError::Yes => return conn_err, + ShouldReturnConnectionError::OnOddIdx(curr_idx) => { + if curr_idx.fetch_add(1, Ordering::SeqCst) % 2 != 0 { + // raise an error on each odd number + return conn_err; + } + } + } + + let ip = match &conn_utils.returned_ip_type { + ConnectionIPReturnType::Specified(ip) => Some(*ip), + ConnectionIPReturnType::Different(ip_getter) => { + let first_ip_num = ip_getter.fetch_add(1, Ordering::SeqCst) as u8; + Some(IpAddr::V4(Ipv4Addr::new(first_ip_num, 0, 0, 0))) + } + ConnectionIPReturnType::None => None, + }; + + Box::pin(future::ok(( + MockConnection { + id: conn_utils + .connection_id_provider + .fetch_add(1, Ordering::SeqCst), + handler: conn_utils.get_handler(), + port, + }, + ip, + ))) + } +} + +impl cluster::Connect for MockConnection { + fn connect<'a, T>(info: T, _timeout: Option) -> RedisResult + where + T: IntoConnectionInfo, + { + let info = info.into_connection_info().unwrap(); + + let (name, port) = match &info.addr { + redis::ConnectionAddr::Tcp(addr, port) => (addr, *port), + _ => unreachable!(), + }; + let binding = MOCK_CONN_BEHAVIORS.read().unwrap(); + let conn_utils = binding + .get(name) + .unwrap_or_else(|| panic!("MockConnectionUtils for `{name}` were not installed")); + Ok(MockConnection { + id: conn_utils + .connection_id_provider + .fetch_add(1, Ordering::SeqCst), + handler: conn_utils.get_handler(), + port, + }) + } + + fn send_packed_command(&mut self, _cmd: &[u8]) -> RedisResult<()> { + Ok(()) + } + + fn set_write_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn set_read_timeout(&self, _dur: Option) -> RedisResult<()> { + Ok(()) + } + + fn recv_response(&mut self) -> RedisResult { + Ok(Value::Nil) + } +} + +pub fn contains_slice(xs: &[u8], ys: &[u8]) -> bool { + for i in 0..xs.len() { + if xs[i..].starts_with(ys) { + return true; + } + } + false +} + +pub fn respond_startup(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else if contains_slice(cmd, b"READONLY") { + Err(Ok(Value::SimpleString("OK".into()))) + } else { + Ok(()) + } +} + +#[derive(Clone, Debug)] +pub struct MockSlotRange { + pub primary_port: u16, + pub replica_ports: Vec, + pub slot_range: std::ops::Range, +} + +pub fn respond_startup_with_replica(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + respond_startup_with_replica_using_config(name, cmd, None) +} + +pub fn respond_startup_two_nodes(name: &str, cmd: &[u8]) -> Result<(), RedisResult> { + respond_startup_with_config(name, cmd, None, false) +} + +pub fn create_topology_from_config(name: &str, slots_config: Vec) -> Value { + let slots_vec = slots_config + .into_iter() + .map(|slot_config| { + let mut config = vec![ + Value::Int(slot_config.slot_range.start as i64), + Value::Int(slot_config.slot_range.end as i64), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(slot_config.primary_port as i64), + ]), + ]; + config.extend(slot_config.replica_ports.into_iter().map(|replica_port| { + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(replica_port as i64), + ]) + })); + Value::Array(config) + }) + .collect(); + Value::Array(slots_vec) +} + +pub fn respond_startup_with_replica_using_config( + name: &str, + cmd: &[u8], + slots_config: Option>, +) -> Result<(), RedisResult> { + respond_startup_with_config(name, cmd, slots_config, true) +} + +/// If the configuration isn't provided, a configuration with two primary nodes, with or without replicas, will be used. +pub fn respond_startup_with_config( + name: &str, + cmd: &[u8], + slots_config: Option>, + with_replicas: bool, +) -> Result<(), RedisResult> { + let slots_config = slots_config.unwrap_or(if with_replicas { + vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8192..16383), + }, + ] + } else { + vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + slot_range: (8192..16383), + }, + ] + }); + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + let slots = create_topology_from_config(name, slots_config); + Err(Ok(slots)) + } else if contains_slice(cmd, b"READONLY") { + Err(Ok(Value::SimpleString("OK".into()))) + } else { + Ok(()) + } +} + +#[cfg(feature = "cluster-async")] +impl aio::ConnectionLike for MockConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a redis::Cmd) -> RedisFuture<'a, Value> { + Box::pin(future::ready( + (self.handler)(&cmd.get_packed_command(), self.port) + .expect_err("Handler did not specify a response"), + )) + } + + fn req_packed_commands<'a>( + &'a mut self, + _pipeline: &'a redis::Pipeline, + _offset: usize, + _count: usize, + ) -> RedisFuture<'a, Vec> { + Box::pin(future::ok(vec![])) + } + + fn get_db(&self) -> i64 { + 0 + } + + fn is_closed(&self) -> bool { + false + } +} + +impl redis::ConnectionLike for MockConnection { + fn req_packed_command(&mut self, cmd: &[u8]) -> RedisResult { + (self.handler)(cmd, self.port).expect_err("Handler did not specify a response") + } + + fn req_packed_commands( + &mut self, + cmd: &[u8], + offset: usize, + _count: usize, + ) -> RedisResult> { + let res = (self.handler)(cmd, self.port).expect_err("Handler did not specify a response"); + match res { + Err(err) => Err(err), + Ok(res) => { + if let Value::Array(results) = res { + match results.into_iter().nth(offset) { + Some(Value::Array(res)) => Ok(res), + _ => Err((ErrorKind::ResponseError, "non-array response").into()), + } + } else { + Err(( + ErrorKind::ResponseError, + "non-array response", + String::from_owned_redis_value(res).unwrap(), + ) + .into()) + } + } + } + } + + fn get_db(&self) -> i64 { + 0 + } + + fn check_connection(&mut self) -> bool { + true + } + + fn is_open(&self) -> bool { + true + } +} + +pub struct MockEnv { + #[cfg(feature = "cluster-async")] + pub runtime: Runtime, + pub client: redis::cluster::ClusterClient, + pub connection: redis::cluster::ClusterConnection, + #[cfg(feature = "cluster-async")] + pub async_connection: redis::cluster_async::ClusterConnection, + #[allow(unused)] + pub handler: RemoveHandler, +} + +pub struct RemoveHandler(Vec); + +impl Drop for RemoveHandler { + fn drop(&mut self) { + for id in &self.0 { + get_behaviors().remove(id); + } + } +} + +impl MockEnv { + pub fn new( + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync + 'static, + ) -> Self { + Self::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{id}")]), + id, + handler, + ) + } + + pub fn with_client_builder( + client_builder: ClusterClientBuilder, + id: &str, + handler: impl Fn(&[u8], u16) -> Result<(), RedisResult> + Send + Sync + 'static, + ) -> Self { + #[cfg(feature = "cluster-async")] + let runtime = tokio::runtime::Builder::new_current_thread() + .enable_io() + .enable_time() + .build() + .unwrap(); + + let id = id.to_string(); + let handler = MockConnectionBehavior::register_new( + &id, + Arc::new(move |cmd, port| handler(cmd, port)), + ); + let client = client_builder.build().unwrap(); + let connection = client.get_generic_connection(None).unwrap(); + #[cfg(feature = "cluster-async")] + let async_connection = runtime + .block_on(client.get_async_generic_connection()) + .unwrap(); + MockEnv { + #[cfg(feature = "cluster-async")] + runtime, + client, + connection, + #[cfg(feature = "cluster-async")] + async_connection, + handler, + } + } +} diff --git a/glide-core/redis-rs/redis/tests/support/mod.rs b/glide-core/redis-rs/redis/tests/support/mod.rs new file mode 100644 index 0000000000..335cd045de --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/mod.rs @@ -0,0 +1,887 @@ +#![allow(dead_code)] + +use std::path::Path; +use std::{ + env, fs, io, net::SocketAddr, net::TcpListener, path::PathBuf, process, thread::sleep, + time::Duration, +}; +#[cfg(feature = "tls-rustls")] +use std::{ + fs::File, + io::{BufReader, Read}, +}; + +#[cfg(feature = "aio")] +use futures::Future; +use redis::{ConnectionAddr, InfoDict, Pipeline, ProtocolVersion, RedisConnectionInfo, Value}; + +#[cfg(feature = "tls-rustls")] +use redis::{ClientTlsConfig, TlsCertificates}; + +use socket2::{Domain, Socket, Type}; +use tempfile::TempDir; + +#[cfg(feature = "aio")] +use redis::GlideConnectionOptions; + +pub fn use_protocol() -> ProtocolVersion { + if env::var("PROTOCOL").unwrap_or_default() == "RESP3" { + ProtocolVersion::RESP3 + } else { + ProtocolVersion::RESP2 + } +} + +pub fn current_thread_runtime() -> tokio::runtime::Runtime { + let mut builder = tokio::runtime::Builder::new_current_thread(); + + #[cfg(feature = "aio")] + builder.enable_io(); + + builder.enable_time(); + + builder.build().unwrap() +} + +#[cfg(feature = "aio")] +pub fn block_on_all(f: F) -> F::Output +where + F: Future>, +{ + use std::panic; + use std::sync::atomic::{AtomicBool, Ordering}; + + static CHECK: AtomicBool = AtomicBool::new(false); + + // TODO - this solution is purely single threaded, and won't work on multiple threads at the same time. + // This is needed because Tokio's Runtime silently ignores panics - https://users.rust-lang.org/t/tokio-runtime-what-happens-when-a-thread-panics/95819 + // Once Tokio stabilizes the `unhandled_panic` field on the runtime builder, it should be used instead. + panic::set_hook(Box::new(|panic| { + println!("Panic: {panic}"); + CHECK.store(true, Ordering::Relaxed); + })); + + // This continuously query the flag, in order to abort ASAP after a panic. + let check_future = futures_util::FutureExt::fuse(async { + loop { + if CHECK.load(Ordering::Relaxed) { + return Err((redis::ErrorKind::IoError, "panic was caught").into()); + } + futures_time::task::sleep(futures_time::time::Duration::from_millis(1)).await; + } + }); + let f = futures_util::FutureExt::fuse(f); + futures::pin_mut!(f, check_future); + + let res = current_thread_runtime().block_on(async { + futures::select! {res = f => res, err = check_future => err} + }); + + let _ = panic::take_hook(); + if CHECK.swap(false, Ordering::Relaxed) { + panic!("Internal thread panicked"); + } + + res +} + +#[cfg(feature = "async-std-comp")] +pub fn block_on_all_using_async_std(f: F) -> F::Output +where + F: Future, +{ + async_std::task::block_on(f) +} + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +mod cluster; + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +mod mock_cluster; + +mod util; +#[allow(unused_imports)] +pub use self::util::*; + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +#[allow(unused_imports)] +pub use self::cluster::*; + +#[cfg(any(feature = "cluster", feature = "cluster-async"))] +#[allow(unused_imports)] +pub use self::mock_cluster::*; + +#[cfg(feature = "sentinel")] +mod sentinel; + +#[cfg(feature = "sentinel")] +#[allow(unused_imports)] +pub use self::sentinel::*; + +#[derive(PartialEq)] +enum ServerType { + Tcp { tls: bool }, + Unix, +} + +pub enum Module { + Json, +} + +pub struct RedisServer { + pub process: process::Child, + pub(crate) tempdir: tempfile::TempDir, + pub(crate) addr: redis::ConnectionAddr, + pub(crate) tls_paths: Option, +} + +impl ServerType { + fn get_intended() -> ServerType { + match env::var("REDISRS_SERVER_TYPE") + .ok() + .as_ref() + .map(|x| &x[..]) + { + Some("tcp") => ServerType::Tcp { tls: false }, + Some("tcp+tls") => ServerType::Tcp { tls: true }, + Some("unix") => ServerType::Unix, + Some(val) => { + panic!("Unknown server type {val:?}"); + } + None => ServerType::Tcp { tls: false }, + } + } +} + +impl RedisServer { + pub fn new() -> RedisServer { + RedisServer::with_modules(&[], false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls() -> RedisServer { + RedisServer::with_modules(&[], true) + } + + pub fn get_addr(port: u16) -> ConnectionAddr { + let server_type = ServerType::get_intended(); + match server_type { + ServerType::Tcp { tls } => { + if tls { + redis::ConnectionAddr::TcpTls { + host: "127.0.0.1".to_string(), + port, + insecure: true, + tls_params: None, + } + } else { + redis::ConnectionAddr::Tcp("127.0.0.1".to_string(), port) + } + } + ServerType::Unix => { + let (a, b) = rand::random::<(u64, u64)>(); + let path = format!("/tmp/redis-rs-test-{a}-{b}.sock"); + redis::ConnectionAddr::Unix(PathBuf::from(&path)) + } + } + } + + pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> RedisServer { + // this is technically a race but we can't do better with + // the tools that redis gives us :( + let redis_port = get_random_available_port(); + let addr = RedisServer::get_addr(redis_port); + + RedisServer::new_with_addr_tls_modules_and_spawner( + addr, + None, + None, + mtls_enabled, + modules, + |cmd| { + cmd.spawn() + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) + }, + ) + } + + pub fn new_with_addr_and_modules( + addr: redis::ConnectionAddr, + modules: &[Module], + mtls_enabled: bool, + ) -> RedisServer { + RedisServer::new_with_addr_tls_modules_and_spawner( + addr, + None, + None, + mtls_enabled, + modules, + |cmd| { + cmd.spawn() + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) + }, + ) + } + + pub fn new_with_addr_tls_modules_and_spawner< + F: FnOnce(&mut process::Command) -> process::Child, + >( + addr: redis::ConnectionAddr, + config_file: Option<&Path>, + tls_paths: Option, + mtls_enabled: bool, + modules: &[Module], + spawner: F, + ) -> RedisServer { + let mut redis_cmd = process::Command::new("redis-server"); + + if let Some(config_path) = config_file { + redis_cmd.arg(config_path); + } + + // Load Redis Modules + for module in modules { + match module { + Module::Json => { + redis_cmd + .arg("--loadmodule") + .arg(env::var("REDIS_RS_REDIS_JSON_PATH").expect( + "Unable to find path to RedisJSON at REDIS_RS_REDIS_JSON_PATH, is it set?", + )); + } + }; + } + + redis_cmd + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + redis_cmd.arg("--logfile").arg(Self::log_file(&tempdir)); + match addr { + redis::ConnectionAddr::Tcp(ref bind, server_port) => { + redis_cmd + .arg("--port") + .arg(server_port.to_string()) + .arg("--bind") + .arg(bind); + + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: None, + } + } + redis::ConnectionAddr::TcpTls { ref host, port, .. } => { + let tls_paths = tls_paths.unwrap_or_else(|| build_keys_and_certs_for_tls(&tempdir)); + + let auth_client = if mtls_enabled { "yes" } else { "no" }; + + // prepare redis with TLS + redis_cmd + .arg("--tls-port") + .arg(port.to_string()) + .arg("--port") + .arg("0") + .arg("--tls-cert-file") + .arg(&tls_paths.redis_crt) + .arg("--tls-key-file") + .arg(&tls_paths.redis_key) + .arg("--tls-ca-cert-file") + .arg(&tls_paths.ca_crt) + .arg("--tls-auth-clients") + .arg(auth_client) + .arg("--bind") + .arg(host); + + // Insecure only disabled if `mtls` is enabled + let insecure = !mtls_enabled; + + let addr = redis::ConnectionAddr::TcpTls { + host: host.clone(), + port, + insecure, + tls_params: None, + }; + + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: Some(tls_paths), + } + } + redis::ConnectionAddr::Unix(ref path) => { + redis_cmd + .arg("--port") + .arg("0") + .arg("--unixsocket") + .arg(path); + RedisServer { + process: spawner(&mut redis_cmd), + tempdir, + addr, + tls_paths: None, + } + } + } + } + + pub fn client_addr(&self) -> &redis::ConnectionAddr { + &self.addr + } + + pub fn connection_info(&self) -> redis::ConnectionInfo { + redis::ConnectionInfo { + addr: self.client_addr().clone(), + redis: RedisConnectionInfo { + protocol: use_protocol(), + ..Default::default() + }, + } + } + + pub fn stop(&mut self) { + let _ = self.process.kill(); + let _ = self.process.wait(); + if let redis::ConnectionAddr::Unix(ref path) = *self.client_addr() { + fs::remove_file(path).ok(); + } + } + + pub fn log_file(tempdir: &TempDir) -> PathBuf { + tempdir.path().join("redis.log") + } +} + +/// Finds a random open port available for listening at, by spawning a TCP server with +/// port "zero" (which prompts the OS to just use any available port). Between calling +/// this function and trying to bind to this port, the port may be given to another +/// process, so this must be used with care (since here we only use it for tests, it's +/// mostly okay). +pub fn get_random_available_port() -> u16 { + let addr = &"127.0.0.1:0".parse::().unwrap().into(); + let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); + socket.set_reuse_address(true).unwrap(); + socket.bind(addr).unwrap(); + socket.listen(1).unwrap(); + let listener = TcpListener::from(socket); + listener.local_addr().unwrap().port() +} + +impl Drop for RedisServer { + fn drop(&mut self) { + self.stop() + } +} + +pub struct TestContext { + pub server: RedisServer, + pub client: redis::Client, + pub protocol: ProtocolVersion, +} + +pub(crate) fn is_tls_enabled() -> bool { + cfg!(all(feature = "tls-rustls", not(feature = "tls-native-tls"))) +} + +impl TestContext { + pub fn new() -> TestContext { + TestContext::with_modules(&[], false) + } + + #[cfg(feature = "tls-rustls")] + pub fn new_with_mtls() -> TestContext { + Self::with_modules(&[], true) + } + + fn connect_with_retries(client: &redis::Client) { + let mut con; + + let millisecond = Duration::from_millis(1); + let mut retries = 0; + loop { + match client.get_connection(None) { + Err(err) => { + if err.is_connection_refusal() { + sleep(millisecond); + retries += 1; + if retries > 100000 { + panic!("Tried to connect too many times, last error: {err}"); + } + } else { + panic!("Could not connect: {err}"); + } + } + Ok(x) => { + con = x; + break; + } + } + } + redis::cmd("FLUSHDB").execute(&mut con); + } + + pub fn with_tls(tls_files: TlsFilePaths, mtls_enabled: bool) -> TestContext { + let redis_port = get_random_available_port(); + let addr: ConnectionAddr = RedisServer::get_addr(redis_port); + + let server = RedisServer::new_with_addr_tls_modules_and_spawner( + addr, + None, + Some(tls_files), + mtls_enabled, + &[], + |cmd| { + cmd.spawn() + .unwrap_or_else(|err| panic!("Failed to run {cmd:?}: {err}")) + }, + ); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + Self::connect_with_retries(&client); + + TestContext { + server, + client, + protocol: use_protocol(), + } + } + + pub fn with_modules(modules: &[Module], mtls_enabled: bool) -> TestContext { + let server = RedisServer::with_modules(modules, mtls_enabled); + + #[cfg(feature = "tls-rustls")] + let client = + build_single_client(server.connection_info(), &server.tls_paths, mtls_enabled).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()).unwrap(); + + Self::connect_with_retries(&client); + + TestContext { + server, + client, + protocol: use_protocol(), + } + } + + pub fn with_client_name(clientname: &str) -> TestContext { + let server = RedisServer::with_modules(&[], false); + let con_info = redis::ConnectionInfo { + addr: server.client_addr().clone(), + redis: redis::RedisConnectionInfo { + client_name: Some(clientname.to_string()), + ..Default::default() + }, + }; + + #[cfg(feature = "tls-rustls")] + let client = build_single_client(con_info, &server.tls_paths, false).unwrap(); + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(con_info).unwrap(); + + Self::connect_with_retries(&client); + + TestContext { + server, + client, + protocol: use_protocol(), + } + } + + pub fn connection(&self) -> redis::Connection { + self.client.get_connection(None).unwrap() + } + + #[cfg(feature = "aio")] + pub async fn async_connection(&self) -> redis::RedisResult { + self.client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + } + + #[cfg(feature = "aio")] + pub async fn async_pubsub(&self) -> redis::RedisResult { + self.client.get_async_pubsub().await + } + + #[cfg(feature = "async-std-comp")] + pub async fn async_connection_async_std( + &self, + ) -> redis::RedisResult { + self.client + .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) + .await + } + + pub fn stop_server(&mut self) { + self.server.stop(); + } + + #[cfg(feature = "tokio-comp")] + pub async fn multiplexed_async_connection( + &self, + ) -> redis::RedisResult { + self.multiplexed_async_connection_tokio().await + } + + #[cfg(feature = "tokio-comp")] + pub async fn multiplexed_async_connection_tokio( + &self, + ) -> redis::RedisResult { + self.client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await + } + + #[cfg(feature = "async-std-comp")] + pub async fn multiplexed_async_connection_async_std( + &self, + ) -> redis::RedisResult { + self.client + .get_multiplexed_async_std_connection(GlideConnectionOptions::default()) + .await + } + + pub fn get_version(&self) -> Version { + let mut conn = self.connection(); + get_version(&mut conn) + } +} + +fn encode_iter(values: &[Value], writer: &mut W, prefix: &str) -> io::Result<()> +where + W: io::Write, +{ + write!(writer, "{}{}\r\n", prefix, values.len())?; + for val in values.iter() { + encode_value(val, writer)?; + } + Ok(()) +} +fn encode_map(values: &[(Value, Value)], writer: &mut W, prefix: &str) -> io::Result<()> +where + W: io::Write, +{ + write!(writer, "{}{}\r\n", prefix, values.len())?; + for (k, v) in values.iter() { + encode_value(k, writer)?; + encode_value(v, writer)?; + } + Ok(()) +} +pub fn encode_value(value: &Value, writer: &mut W) -> io::Result<()> +where + W: io::Write, +{ + #![allow(clippy::write_with_newline)] + match *value { + Value::Nil => write!(writer, "$-1\r\n"), + Value::Int(val) => write!(writer, ":{val}\r\n"), + Value::BulkString(ref val) => { + write!(writer, "${}\r\n", val.len())?; + writer.write_all(val)?; + writer.write_all(b"\r\n") + } + Value::Array(ref values) => encode_iter(values, writer, "*"), + Value::Okay => write!(writer, "+OK\r\n"), + Value::SimpleString(ref s) => write!(writer, "+{s}\r\n"), + Value::Map(ref values) => encode_map(values, writer, "%"), + Value::Attribute { + ref data, + ref attributes, + } => { + encode_map(attributes, writer, "|")?; + encode_value(data, writer)?; + Ok(()) + } + Value::Set(ref values) => encode_iter(values, writer, "~"), + Value::Double(val) => write!(writer, ",{}\r\n", val), + Value::Boolean(v) => { + if v { + write!(writer, "#t\r\n") + } else { + write!(writer, "#f\r\n") + } + } + Value::VerbatimString { + ref format, + ref text, + } => { + // format is always 3 bytes + write!(writer, "={}\r\n{}:{}\r\n", 3 + text.len(), format, text) + } + Value::BigNumber(ref val) => write!(writer, "({}\r\n", val), + Value::Push { ref kind, ref data } => { + write!(writer, ">{}\r\n+{kind}\r\n", data.len() + 1)?; + for val in data.iter() { + encode_value(val, writer)?; + } + Ok(()) + } + } +} + +#[derive(Clone, Debug)] +pub struct TlsFilePaths { + pub(crate) redis_crt: PathBuf, + pub(crate) redis_key: PathBuf, + pub(crate) ca_crt: PathBuf, +} + +pub fn build_keys_and_certs_for_tls(tempdir: &TempDir) -> TlsFilePaths { + // Based on shell script in redis's server tests + // https://github.com/redis/redis/blob/8c291b97b95f2e011977b522acf77ead23e26f55/utils/gen-test-certs.sh + let ca_crt = tempdir.path().join("ca.crt"); + let ca_key = tempdir.path().join("ca.key"); + let ca_serial = tempdir.path().join("ca.txt"); + let redis_crt = tempdir.path().join("redis.crt"); + let redis_key = tempdir.path().join("redis.key"); + let ext_file = tempdir.path().join("openssl.cnf"); + + fn make_key>(name: S, size: usize) { + process::Command::new("openssl") + .arg("genrsa") + .arg("-out") + .arg(name) + .arg(format!("{size}")) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl") + .wait() + .expect("failed to create key"); + } + + // Build CA Key + make_key(&ca_key, 4096); + + // Build redis key + make_key(&redis_key, 2048); + + // Build CA Cert + process::Command::new("openssl") + .arg("req") + .arg("-x509") + .arg("-new") + .arg("-nodes") + .arg("-sha256") + .arg("-key") + .arg(&ca_key) + .arg("-days") + .arg("3650") + .arg("-subj") + .arg("/O=Redis Test/CN=Certificate Authority") + .arg("-out") + .arg(&ca_crt) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl") + .wait() + .expect("failed to create CA cert"); + + // Build x509v3 extensions file + fs::write( + &ext_file, + b"keyUsage = digitalSignature, keyEncipherment\n\ + subjectAltName = @alt_names\n\ + [alt_names]\n\ + IP.1 = 127.0.0.1\n", + ) + .expect("failed to create x509v3 extensions file"); + + // Read redis key + let mut key_cmd = process::Command::new("openssl") + .arg("req") + .arg("-new") + .arg("-sha256") + .arg("-subj") + .arg("/O=Redis Test/CN=Generic-cert") + .arg("-key") + .arg(&redis_key) + .stdout(process::Stdio::piped()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl"); + + // build redis cert + process::Command::new("openssl") + .arg("x509") + .arg("-req") + .arg("-sha256") + .arg("-CA") + .arg(&ca_crt) + .arg("-CAkey") + .arg(&ca_key) + .arg("-CAserial") + .arg(&ca_serial) + .arg("-CAcreateserial") + .arg("-days") + .arg("365") + .arg("-extfile") + .arg(&ext_file) + .arg("-out") + .arg(&redis_crt) + .stdin(key_cmd.stdout.take().expect("should have stdout")) + .stdout(process::Stdio::null()) + .stderr(process::Stdio::null()) + .spawn() + .expect("failed to spawn openssl") + .wait() + .expect("failed to create redis cert"); + + key_cmd.wait().expect("failed to create redis key"); + + TlsFilePaths { + redis_crt, + redis_key, + ca_crt, + } +} + +pub type Version = (u16, u16, u16); + +fn get_version(conn: &mut impl redis::ConnectionLike) -> Version { + let info: InfoDict = redis::Cmd::new().arg("INFO").query(conn).unwrap(); + let version: String = info.get("redis_version").unwrap(); + let versions: Vec = version + .split('.') + .map(|version| version.parse::().unwrap()) + .collect(); + assert_eq!(versions.len(), 3); + (versions[0], versions[1], versions[2]) +} + +pub fn is_major_version(expected_version: u16, version: Version) -> bool { + expected_version <= version.0 +} + +pub fn is_version(expected_major_minor: (u16, u16), version: Version) -> bool { + expected_major_minor.0 < version.0 + || (expected_major_minor.0 == version.0 && expected_major_minor.1 <= version.1) +} + +#[cfg(feature = "tls-rustls")] +fn load_certs_from_file(tls_file_paths: &TlsFilePaths) -> TlsCertificates { + let ca_file = File::open(&tls_file_paths.ca_crt).expect("Cannot open CA cert file"); + let mut root_cert_vec = Vec::new(); + BufReader::new(ca_file) + .read_to_end(&mut root_cert_vec) + .expect("Unable to read CA cert file"); + + let cert_file = File::open(&tls_file_paths.redis_crt).expect("cannot open private cert file"); + let mut client_cert_vec = Vec::new(); + BufReader::new(cert_file) + .read_to_end(&mut client_cert_vec) + .expect("Unable to read client cert file"); + + let key_file = File::open(&tls_file_paths.redis_key).expect("Cannot open private key file"); + let mut client_key_vec = Vec::new(); + BufReader::new(key_file) + .read_to_end(&mut client_key_vec) + .expect("Unable to read client key file"); + + TlsCertificates { + client_tls: Some(ClientTlsConfig { + client_cert: client_cert_vec, + client_key: client_key_vec, + }), + root_cert: Some(root_cert_vec), + } +} + +#[cfg(feature = "tls-rustls")] +pub(crate) fn build_single_client( + connection_info: T, + tls_file_params: &Option, + mtls_enabled: bool, +) -> redis::RedisResult { + if mtls_enabled && tls_file_params.is_some() { + redis::Client::build_with_tls( + connection_info, + load_certs_from_file( + tls_file_params + .as_ref() + .expect("Expected certificates when `tls-rustls` feature is enabled"), + ), + ) + } else { + redis::Client::open(connection_info) + } +} + +#[cfg(feature = "tls-rustls")] +pub(crate) mod mtls_test { + use super::*; + use redis::{cluster::ClusterClient, ConnectionInfo, RedisError}; + + fn clean_node_info(nodes: &[ConnectionInfo]) -> Vec { + let nodes = nodes + .iter() + .map(|node| match node { + ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { host, port, .. }, + redis, + } => ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { + host: host.to_owned(), + port: *port, + insecure: false, + tls_params: None, + }, + redis: redis.clone(), + }, + _ => node.clone(), + }) + .collect(); + nodes + } + + pub(crate) fn create_cluster_client_from_cluster( + cluster: &TestClusterContext, + mtls_enabled: bool, + ) -> Result { + let server = cluster + .cluster + .servers + .first() + .expect("Expected at least 1 server"); + let tls_paths = server.tls_paths.as_ref(); + let nodes = clean_node_info(&cluster.nodes); + let builder = redis::cluster::ClusterClientBuilder::new(nodes); + if let Some(tls_paths) = tls_paths { + // server-side TLS available + if mtls_enabled { + builder.certs(load_certs_from_file(tls_paths)) + } else { + builder + } + } else { + // server-side TLS NOT available + builder + } + .build() + } +} + +pub fn build_simple_pipeline_for_invalidation() -> Pipeline { + let mut pipe = redis::pipe(); + pipe.cmd("GET") + .arg("key_1") + .ignore() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore(); + pipe +} diff --git a/glide-core/redis-rs/redis/tests/support/sentinel.rs b/glide-core/redis-rs/redis/tests/support/sentinel.rs new file mode 100644 index 0000000000..d34d3dc88b --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/sentinel.rs @@ -0,0 +1,404 @@ +use std::fs::File; +use std::io::Write; +use std::thread::sleep; +use std::time::Duration; + +use redis::sentinel::SentinelNodeConnectionInfo; +use redis::Client; +use redis::ConnectionAddr; +use redis::ConnectionInfo; +use redis::FromRedisValue; +use redis::RedisResult; +use redis::TlsMode; +use tempfile::TempDir; + +use crate::support::build_single_client; + +use super::build_keys_and_certs_for_tls; +use super::get_random_available_port; +use super::Module; +use super::RedisServer; +use super::TlsFilePaths; + +const LOCALHOST: &str = "127.0.0.1"; +const MTLS_NOT_ENABLED: bool = false; + +pub struct RedisSentinelCluster { + pub servers: Vec, + pub sentinel_servers: Vec, + pub folders: Vec, +} + +fn get_addr(port: u16) -> ConnectionAddr { + let addr = RedisServer::get_addr(port); + if let ConnectionAddr::Unix(_) = addr { + ConnectionAddr::Tcp(String::from("127.0.0.1"), port) + } else { + addr + } +} + +fn spawn_master_server( + port: u16, + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + None, + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + // Minimize startup delay + cmd.arg("--repl-diskless-sync-delay").arg("0"); + cmd.arg("--appendonly").arg("yes"); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn spawn_replica_server( + port: u16, + master_port: u16, + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + let config_file_path = dir.path().join("redis_config.conf"); + File::create(&config_file_path).unwrap(); + + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + Some(&config_file_path), + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + cmd.arg("--replicaof") + .arg("127.0.0.1") + .arg(master_port.to_string()); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.arg("--appendonly").arg("yes"); + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn spawn_sentinel_server( + port: u16, + master_ports: &[u16], + dir: &TempDir, + tlspaths: &TlsFilePaths, + modules: &[Module], +) -> RedisServer { + let config_file_path = dir.path().join("redis_config.conf"); + let mut file = File::create(&config_file_path).unwrap(); + for (i, master_port) in master_ports.iter().enumerate() { + file.write_all( + format!("sentinel monitor master{} 127.0.0.1 {} 1\n", i, master_port).as_bytes(), + ) + .unwrap(); + } + file.flush().unwrap(); + + RedisServer::new_with_addr_tls_modules_and_spawner( + get_addr(port), + Some(&config_file_path), + Some(tlspaths.clone()), + MTLS_NOT_ENABLED, + modules, + |cmd| { + cmd.arg("--sentinel"); + cmd.arg("--appendonly").arg("yes"); + if let ConnectionAddr::TcpTls { .. } = get_addr(port) { + cmd.arg("--tls-replication").arg("yes"); + } + cmd.current_dir(dir.path()); + cmd.spawn().unwrap() + }, + ) +} + +fn wait_for_master_server( + mut get_client_fn: impl FnMut() -> RedisResult, +) -> Result<(), ()> { + let rolecmd = redis::cmd("ROLE"); + for _ in 0..100 { + let master_client = get_client_fn(); + match master_client { + Ok(client) => match client.get_connection(None) { + Ok(mut conn) => { + let r: Vec = rolecmd.query(&mut conn).unwrap(); + let role = String::from_redis_value(r.first().unwrap()).unwrap(); + if role.starts_with("master") { + return Ok(()); + } else { + println!("failed check for master role - current role: {r:?}") + } + } + Err(err) => { + println!("failed to get master connection: {:?}", err) + } + }, + Err(err) => { + println!("failed to get master client: {:?}", err) + } + } + + sleep(Duration::from_millis(25)); + } + + Err(()) +} + +fn wait_for_replica(mut get_client_fn: impl FnMut() -> RedisResult) -> Result<(), ()> { + let rolecmd = redis::cmd("ROLE"); + for _ in 0..200 { + let replica_client = get_client_fn(); + match replica_client { + Ok(client) => match client.get_connection(None) { + Ok(mut conn) => { + let r: Vec = rolecmd.query(&mut conn).unwrap(); + let role = String::from_redis_value(r.first().unwrap()).unwrap(); + let state = String::from_redis_value(r.get(3).unwrap()).unwrap(); + if role.starts_with("slave") && state == "connected" { + return Ok(()); + } else { + println!("failed check for replica role - current role: {:?}", r) + } + } + Err(err) => { + println!("failed to get replica connection: {:?}", err) + } + }, + Err(err) => { + println!("failed to get replica client: {:?}", err) + } + } + + sleep(Duration::from_millis(25)); + } + + Err(()) +} + +fn wait_for_replicas_to_sync(servers: &[RedisServer], masters: u16) { + let cluster_size = servers.len() / (masters as usize); + let clusters = servers.len() / cluster_size; + let replicas = cluster_size - 1; + + for cluster_index in 0..clusters { + let master_addr = servers[cluster_index * cluster_size].connection_info(); + let tls_paths = &servers.first().unwrap().tls_paths; + let r = wait_for_master_server(|| { + Ok(build_single_client(master_addr.clone(), tls_paths, MTLS_NOT_ENABLED).unwrap()) + }); + if r.is_err() { + panic!("failed waiting for master to be ready"); + } + + for replica_index in 0..replicas { + let replica_addr = + servers[(cluster_index * cluster_size) + 1 + replica_index].connection_info(); + let r = wait_for_replica(|| { + Ok(build_single_client(replica_addr.clone(), tls_paths, MTLS_NOT_ENABLED).unwrap()) + }); + if r.is_err() { + panic!("failed waiting for replica to be ready and in sync"); + } + } + } +} + +impl RedisSentinelCluster { + pub fn new(masters: u16, replicas_per_master: u16, sentinels: u16) -> RedisSentinelCluster { + RedisSentinelCluster::with_modules(masters, replicas_per_master, sentinels, &[]) + } + + pub fn with_modules( + masters: u16, + replicas_per_master: u16, + sentinels: u16, + modules: &[Module], + ) -> RedisSentinelCluster { + let mut servers = vec![]; + let mut folders = vec![]; + let mut master_ports = vec![]; + + let tempdir = tempfile::Builder::new() + .prefix("redistls") + .tempdir() + .expect("failed to create tempdir"); + let tlspaths = build_keys_and_certs_for_tls(&tempdir); + folders.push(tempdir); + + let required_number_of_sockets = masters * (replicas_per_master + 1) + sentinels; + let mut available_ports = std::collections::HashSet::new(); + while available_ports.len() < required_number_of_sockets as usize { + available_ports.insert(get_random_available_port()); + } + let mut available_ports: Vec<_> = available_ports.into_iter().collect(); + + for _ in 0..masters { + let port = available_ports.pop().unwrap(); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + servers.push(spawn_master_server(port, &tempdir, &tlspaths, modules)); + folders.push(tempdir); + master_ports.push(port); + + for _ in 0..replicas_per_master { + let replica_port = available_ports.pop().unwrap(); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + servers.push(spawn_replica_server( + replica_port, + port, + &tempdir, + &tlspaths, + modules, + )); + folders.push(tempdir); + } + } + + // Wait for replicas to sync so that the sentinels discover them on the first try + wait_for_replicas_to_sync(&servers, masters); + + let mut sentinel_servers = vec![]; + for _ in 0..sentinels { + let port = available_ports.pop().unwrap(); + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + + sentinel_servers.push(spawn_sentinel_server( + port, + &master_ports, + &tempdir, + &tlspaths, + modules, + )); + folders.push(tempdir); + } + + RedisSentinelCluster { + servers, + sentinel_servers, + folders, + } + } + + pub fn stop(&mut self) { + for server in &mut self.servers { + server.stop(); + } + for server in &mut self.sentinel_servers { + server.stop(); + } + } + + pub fn iter_sentinel_servers(&self) -> impl Iterator { + self.sentinel_servers.iter() + } +} + +impl Drop for RedisSentinelCluster { + fn drop(&mut self) { + self.stop() + } +} + +pub struct TestSentinelContext { + pub cluster: RedisSentinelCluster, + pub sentinel: redis::sentinel::Sentinel, + pub sentinels_connection_info: Vec, + mtls_enabled: bool, // for future tests +} + +impl TestSentinelContext { + pub fn new(nodes: u16, replicas: u16, sentinels: u16) -> TestSentinelContext { + Self::new_with_cluster_client_builder(nodes, replicas, sentinels) + } + + pub fn new_with_cluster_client_builder( + nodes: u16, + replicas: u16, + sentinels: u16, + ) -> TestSentinelContext { + let cluster = RedisSentinelCluster::new(nodes, replicas, sentinels); + let initial_nodes: Vec = cluster + .iter_sentinel_servers() + .map(RedisServer::connection_info) + .collect(); + let sentinel = redis::sentinel::Sentinel::build(initial_nodes.clone()); + let sentinel = sentinel.unwrap(); + + let mut context = TestSentinelContext { + cluster, + sentinel, + sentinels_connection_info: initial_nodes, + mtls_enabled: MTLS_NOT_ENABLED, + }; + context.wait_for_cluster_up(); + context + } + + pub fn sentinel(&self) -> &redis::sentinel::Sentinel { + &self.sentinel + } + + pub fn sentinel_mut(&mut self) -> &mut redis::sentinel::Sentinel { + &mut self.sentinel + } + + pub fn sentinels_connection_info(&self) -> &Vec { + &self.sentinels_connection_info + } + + pub fn sentinel_node_connection_info(&self) -> SentinelNodeConnectionInfo { + SentinelNodeConnectionInfo { + tls_mode: if let ConnectionAddr::TcpTls { insecure, .. } = + self.cluster.servers[0].client_addr() + { + if *insecure { + Some(TlsMode::Insecure) + } else { + Some(TlsMode::Secure) + } + } else { + None + }, + redis_connection_info: None, + } + } + + pub fn wait_for_cluster_up(&mut self) { + let node_conn_info = self.sentinel_node_connection_info(); + let con = self.sentinel_mut(); + + let r = wait_for_master_server(|| con.master_for("master1", Some(&node_conn_info))); + if r.is_err() { + panic!("failed waiting for sentinel master1 to be ready"); + } + + let r = wait_for_replica(|| con.replica_for("master1", Some(&node_conn_info))); + if r.is_err() { + panic!("failed waiting for sentinel master1 replica to be ready"); + } + } +} diff --git a/glide-core/redis-rs/redis/tests/support/util.rs b/glide-core/redis-rs/redis/tests/support/util.rs new file mode 100644 index 0000000000..8026b83fb5 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/support/util.rs @@ -0,0 +1,23 @@ +use std::collections::HashMap; + +#[macro_export] +macro_rules! assert_args { + ($value:expr, $($args:expr),+) => { + let args = $value.to_redis_args(); + let strings: Vec<_> = args.iter() + .map(|a| std::str::from_utf8(a.as_ref()).unwrap()) + .collect(); + assert_eq!(strings, vec![$($args),+]); + } +} + +pub fn parse_client_info(client_info: &str) -> HashMap { + let mut res = HashMap::new(); + + for line in client_info.split(' ') { + let this_attr: Vec<&str> = line.split('=').collect(); + res.insert(this_attr[0].to_string(), this_attr[1].to_string()); + } + + res +} diff --git a/glide-core/redis-rs/redis/tests/test_acl.rs b/glide-core/redis-rs/redis/tests/test_acl.rs new file mode 100644 index 0000000000..093774f3bc --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_acl.rs @@ -0,0 +1,156 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "acl")] + +use std::collections::HashSet; + +use redis::acl::{AclInfo, Rule}; +use redis::{Commands, Value}; + +mod support; +use crate::support::*; + +#[test] +fn test_acl_whoami() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + assert_eq!(con.acl_whoami(), Ok("default".to_owned())); +} + +#[test] +fn test_acl_help() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let res: Vec = con.acl_help().expect("Got help manual"); + assert!(!res.is_empty()); +} + +//TODO: do we need this test? +#[test] +#[ignore] +fn test_acl_getsetdel_users() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + assert_eq!( + con.acl_list(), + Ok(vec!["user default on nopass ~* +@all".to_owned()]) + ); + assert_eq!(con.acl_users(), Ok(vec!["default".to_owned()])); + // bob + assert_eq!(con.acl_setuser("bob"), Ok(())); + assert_eq!( + con.acl_users(), + Ok(vec!["bob".to_owned(), "default".to_owned()]) + ); + + // ACL SETUSER bob on ~redis:* +set + assert_eq!( + con.acl_setuser_rules( + "bob", + &[ + Rule::On, + Rule::AddHashedPass( + "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned() + ), + Rule::Pattern("redis:*".to_owned()), + Rule::AddCommand("set".to_owned()) + ], + ), + Ok(()) + ); + let acl_info: AclInfo = con.acl_getuser("bob").expect("Got user"); + assert_eq!( + acl_info, + AclInfo { + flags: vec![Rule::On], + passwords: vec![Rule::AddHashedPass( + "c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2".to_owned() + )], + commands: vec![ + Rule::RemoveCategory("all".to_owned()), + Rule::AddCommand("set".to_owned()) + ], + keys: vec![Rule::Pattern("redis:*".to_owned())], + } + ); + assert_eq!( + con.acl_list(), + Ok(vec![ + "user bob on #c3ab8ff13720e8ad9047dd39466b3c8974e592c2fa383d4a3960714caef0c4f2 ~redis:* -@all +set".to_owned(), + "user default on nopass ~* +@all".to_owned(), + ]) + ); + + // ACL SETUSER eve + assert_eq!(con.acl_setuser("eve"), Ok(())); + assert_eq!( + con.acl_users(), + Ok(vec![ + "bob".to_owned(), + "default".to_owned(), + "eve".to_owned() + ]) + ); + assert_eq!(con.acl_deluser(&["bob", "eve"]), Ok(2)); + assert_eq!(con.acl_users(), Ok(vec!["default".to_owned()])); +} + +#[test] +fn test_acl_cat() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let res: HashSet = con.acl_cat().expect("Got categories"); + let expects = vec![ + "keyspace", + "read", + "write", + "set", + "sortedset", + "list", + "hash", + "string", + "bitmap", + "hyperloglog", + "geo", + "stream", + "pubsub", + "admin", + "fast", + "slow", + "blocking", + "dangerous", + "connection", + "transaction", + "scripting", + ]; + for cat in expects.iter() { + assert!(res.contains(*cat), "Category `{cat}` does not exist"); + } + + let expects = ["pfmerge", "pfcount", "pfselftest", "pfadd"]; + let res: HashSet = con + .acl_cat_categoryname("hyperloglog") + .expect("Got commands of a category"); + for cmd in expects.iter() { + assert!(res.contains(*cmd), "Command `{cmd}` does not exist"); + } +} + +#[test] +fn test_acl_genpass() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let pass: String = con.acl_genpass().expect("Got password"); + assert_eq!(pass.len(), 64); + + let pass: String = con.acl_genpass_bits(1024).expect("Got password"); + assert_eq!(pass.len(), 256); +} + +#[test] +fn test_acl_log() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let logs: Vec = con.acl_log(1).expect("Got logs"); + assert_eq!(logs.len(), 0); + assert_eq!(con.acl_log_reset(), Ok(())); +} diff --git a/glide-core/redis-rs/redis/tests/test_async.rs b/glide-core/redis-rs/redis/tests/test_async.rs new file mode 100644 index 0000000000..d16f1e0694 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_async.rs @@ -0,0 +1,1132 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] + +mod support; + +#[cfg(test)] +mod basic_async { + use std::collections::HashMap; + + use futures::{prelude::*, StreamExt}; + use redis::{ + aio::{ConnectionLike, MultiplexedConnection}, + cmd, pipe, AsyncCommands, ErrorKind, GlideConnectionOptions, PushInfo, PushKind, + RedisResult, Value, + }; + use tokio::sync::mpsc::error::TryRecvError; + + use crate::support::*; + + #[test] + fn test_args() { + let ctx = TestContext::new(); + let connect = ctx.async_connection(); + + block_on_all(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&["key2", "bar"]) + .query_async(&mut con) + .await?; + let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; + assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + result + })) + .unwrap(); + } + + #[test] + fn test_nice_hash_api() { + let ctx = TestContext::new(); + + block_on_all(async move { + let mut connection = ctx.async_connection().await.unwrap(); + + assert_eq!( + connection + .hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)]) + .await, + Ok(()) + ); + + let hm: HashMap = connection.hgetall("my_hash").await.unwrap(); + assert_eq!(hm.len(), 4); + assert_eq!(hm.get("f1"), Some(&1)); + assert_eq!(hm.get("f2"), Some(&2)); + assert_eq!(hm.get("f3"), Some(&4)); + assert_eq!(hm.get("f4"), Some(&8)); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_nice_hash_api_in_pipe() { + let ctx = TestContext::new(); + + block_on_all(async move { + let mut connection = ctx.async_connection().await.unwrap(); + + assert_eq!( + connection + .hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)]) + .await, + Ok(()) + ); + + let mut pipe = redis::pipe(); + pipe.cmd("HGETALL").arg("my_hash"); + let mut vec: Vec> = + pipe.query_async(&mut connection).await.unwrap(); + assert_eq!(vec.len(), 1); + let hash = vec.pop().unwrap(); + assert_eq!(hash.len(), 4); + assert_eq!(hash.get("f1"), Some(&1)); + assert_eq!(hash.get("f2"), Some(&2)); + assert_eq!(hash.get("f3"), Some(&4)); + assert_eq!(hash.get("f4"), Some(&8)); + + Ok(()) + }) + .unwrap(); + } + + #[test] + fn dont_panic_on_closed_multiplexed_connection() { + let ctx = TestContext::new(); + let client = ctx.client.clone(); + let connect = client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + drop(ctx); + + block_on_all(async move { + connect + .and_then(|con| async move { + let cmd = move || { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await + } + }; + let result: RedisResult<()> = cmd().await; + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + cmd().await + }) + .map(|result| { + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + }) + .await; + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_pipeline_transaction() { + let ctx = TestContext::new(); + block_on_all(async move { + let mut con = ctx.async_connection().await?; + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]); + pipe.query_async(&mut con) + .map_ok(|((k1, k2),): ((i32, i32),)| { + assert_eq!(k1, 42); + assert_eq!(k2, 43); + }) + .await + }) + .unwrap(); + } + + #[test] + fn test_client_tracking_doesnt_block_execution() { + //It checks if the library distinguish a push-type message from the others and continues its normal operation. + let ctx = TestContext::new(); + block_on_all(async move { + let mut con = ctx.async_connection().await.unwrap(); + let mut pipe = redis::pipe(); + pipe.cmd("CLIENT") + .arg("TRACKING") + .arg("ON") + .ignore() + .cmd("GET") + .arg("key_1") + .ignore() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore(); + let _: RedisResult<()> = pipe.query_async(&mut con).await; + let num: i32 = con.get("key_1").await.unwrap(); + assert_eq!(num, 42); + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_pipeline_transaction_with_errors() { + use redis::RedisError; + let ctx = TestContext::new(); + + block_on_all(async move { + let mut con = ctx.async_connection().await?; + con.set::<_, _, ()>("x", 42).await.unwrap(); + + // Make Redis a replica of a nonexistent master, thereby making it read-only. + redis::cmd("slaveof") + .arg("1.1.1.1") + .arg("1") + .query_async::<_, ()>(&mut con) + .await + .unwrap(); + + // Ensure that a write command fails with a READONLY error + let err: RedisResult<()> = redis::pipe() + .atomic() + .set("x", 142) + .ignore() + .get("x") + .query_async(&mut con) + .await; + + assert_eq!(err.unwrap_err().kind(), ErrorKind::ReadOnly); + + let x: i32 = con.get("x").await.unwrap(); + assert_eq!(x, 42); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + fn test_cmd( + con: &MultiplexedConnection, + i: i32, + ) -> impl Future> + Send { + let mut con = con.clone(); + async move { + let key = format!("key{i}"); + let key_2 = key.clone(); + let key2 = format!("key{i}_2"); + let key2_2 = key2.clone(); + + let foo_val = format!("foo{i}"); + + redis::cmd("SET") + .arg(&key[..]) + .arg(foo_val.as_bytes()) + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&[&key2, "bar"]) + .query_async(&mut con) + .await?; + redis::cmd("MGET") + .arg(&[&key_2, &key2_2]) + .query_async(&mut con) + .map(|result| { + assert_eq!(Ok((foo_val, b"bar".to_vec())), result); + Ok(()) + }) + .await + } + } + + fn test_error(con: &MultiplexedConnection) -> impl Future> { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .query_async(&mut con) + .map(|result| match result { + Ok(()) => panic!("Expected redis to return an error"), + Err(_) => Ok(()), + }) + .await + } + } + + #[test] + fn test_pipe_over_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + let mut con = ctx.multiplexed_async_connection().await?; + let mut pipe = pipe(); + pipe.zrange("zset", 0, 0); + pipe.zrange("zset", 0, 0); + let frames = con.send_packed_commands(&pipe, 0, 2).await?; + assert_eq!(frames.len(), 2); + assert!(matches!(frames[0], redis::Value::Array(_))); + assert!(matches!(frames[1], redis::Value::Array(_))); + RedisResult::Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_args_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|con| { + let cmds = (0..100).map(move |i| test_cmd(&con, i)); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + #[test] + fn test_args_with_errors_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let con = con.clone(); + async move { + if i % 2 == 0 { + test_cmd(&con, i).await + } else { + test_error(&con).await + } + } + }); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + #[test] + fn test_transaction_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let mut con = con.clone(); + async move { + let foo_val = i; + let bar_val = format!("bar{i}"); + + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key") + .arg(foo_val) + .ignore() + .cmd("SET") + .arg(&["key2", &bar_val[..]]) + .ignore() + .cmd("MGET") + .arg(&["key", "key2"]); + + pipe.query_async(&mut con) + .map(move |result| { + assert_eq!(Ok(((foo_val, bar_val.into_bytes()),)), result); + result + }) + .await + } + }); + future::try_join_all(cmds) + }) + .map_ok(|results| { + assert_eq!(results.len(), 100); + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + fn test_async_scanning(batch_size: usize) { + let ctx = TestContext::new(); + block_on_all(async move { + ctx.multiplexed_async_connection() + .and_then(|mut con| { + async move { + let mut unseen = std::collections::HashSet::new(); + + for x in 0..batch_size { + redis::cmd("SADD") + .arg("foo") + .arg(x) + .query_async(&mut con) + .await?; + unseen.insert(x); + } + + let mut iter = redis::cmd("SSCAN") + .arg("foo") + .cursor_arg(0) + .clone() + .iter_async(&mut con) + .await + .unwrap(); + + while let Some(x) = iter.next_item().await { + // type inference limitations + let x: usize = x; + // if this assertion fails, too many items were returned by the iterator. + assert!(unseen.remove(&x)); + } + + assert_eq!(unseen.len(), 0); + Ok(()) + } + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); + } + + #[test] + fn test_async_scanning_big_batch() { + test_async_scanning(1000) + } + + #[test] + fn test_async_scanning_small_batch() { + test_async_scanning(2) + } + + #[test] + fn test_response_timeout_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all(async move { + let mut connection = ctx.multiplexed_async_connection().await.unwrap(); + connection.set_response_timeout(std::time::Duration::from_millis(1)); + let mut cmd = redis::Cmd::new(); + cmd.arg("BLPOP").arg("foo").arg(0); // 0 timeout blocks indefinitely + let result = connection.req_packed_command(&cmd).await; + assert!(result.is_err()); + assert!(result.unwrap_err().is_timeout()); + Ok(()) + }) + .unwrap(); + } + + #[test] + #[cfg(feature = "script")] + fn test_script() { + use redis::RedisError; + + // Note this test runs both scripts twice to test when they have already been loaded + // into Redis and when they need to be loaded in + let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])"); + let script2 = redis::Script::new("return redis.call('GET', KEYS[1])"); + let script3 = redis::Script::new("return redis.call('KEYS', '*')"); + + let ctx = TestContext::new(); + + block_on_all(async move { + let mut con = ctx.multiplexed_async_connection().await?; + script1 + .key("key1") + .arg("foo") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "foo"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + script1 + .key("key1") + .arg("bar") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "bar"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[cfg(feature = "script")] + fn test_script_load() { + let ctx = TestContext::new(); + let script = redis::Script::new("return 'Hello World'"); + + block_on_all(async move { + let mut con = ctx.multiplexed_async_connection().await.unwrap(); + + let hash = script.prepare_invoke().load_async(&mut con).await.unwrap(); + assert_eq!(hash, script.get_hash().to_string()); + Ok(()) + }) + .unwrap(); + } + + #[test] + #[cfg(feature = "script")] + fn test_script_returning_complex_type() { + let ctx = TestContext::new(); + block_on_all(async { + let mut con = ctx.multiplexed_async_connection().await?; + redis::Script::new("return {1, ARGV[1], true}") + .arg("hello") + .invoke_async(&mut con) + .map_ok(|(i, s, b): (i32, String, bool)| { + assert_eq!(i, 1); + assert_eq!(s, "hello"); + assert!(b); + }) + .await + }) + .unwrap(); + } + + // Allowing `nth(0)` for similarity with the following `nth(1)`. + // Allowing `let ()` as `query_async` requries the type it converts the result to. + #[allow(clippy::let_unit_value, clippy::iter_nth_zero)] + #[tokio::test] + async fn io_error_on_kill_issue_320() { + let ctx = TestContext::new(); + + let mut conn_to_kill = ctx.async_connection().await.unwrap(); + cmd("CLIENT") + .arg("SETNAME") + .arg("to-kill") + .query_async::<_, ()>(&mut conn_to_kill) + .await + .unwrap(); + + let client_list: String = cmd("CLIENT") + .arg("LIST") + .query_async(&mut conn_to_kill) + .await + .unwrap(); + + eprintln!("{client_list}"); + let client_to_kill = client_list + .split('\n') + .find(|line| line.contains("to-kill")) + .expect("line") + .split(' ') + .nth(0) + .expect("id") + .split('=') + .nth(1) + .expect("id value"); + + let mut killer_conn = ctx.async_connection().await.unwrap(); + let () = cmd("CLIENT") + .arg("KILL") + .arg("ID") + .arg(client_to_kill) + .query_async(&mut killer_conn) + .await + .unwrap(); + let mut killed_client = conn_to_kill; + + let err = loop { + match killed_client.get::<_, Option>("a").await { + // We are racing against the server being shutdown so try until we a get an io error + Ok(_) => tokio::time::sleep(std::time::Duration::from_millis(50)).await, + Err(err) => break err, + } + }; + assert_eq!(err.kind(), ErrorKind::IoError); // Shouldn't this be IoError? + } + + #[tokio::test] + async fn invalid_password_issue_343() { + let ctx = TestContext::new(); + let coninfo = redis::ConnectionInfo { + addr: ctx.server.client_addr().clone(), + redis: redis::RedisConnectionInfo { + password: Some("asdcasc".to_string()), + ..Default::default() + }, + }; + let client = redis::Client::open(coninfo).unwrap(); + + let err = client + .get_multiplexed_tokio_connection(GlideConnectionOptions::default()) + .await + .err() + .unwrap(); + assert_eq!( + err.kind(), + ErrorKind::AuthenticationFailed, + "Unexpected error: {err}", + ); + } + + // Test issue of Stream trait blocking if we try to iterate more than 10 items + // https://github.com/mitsuhiko/redis-rs/issues/537 and https://github.com/mitsuhiko/redis-rs/issues/583 + #[tokio::test] + async fn test_issue_stream_blocks() { + let ctx = TestContext::new(); + let mut con = ctx.multiplexed_async_connection().await.unwrap(); + for i in 0..20usize { + let _: () = con.append(format!("test/{i}"), i).await.unwrap(); + } + let values = con.scan_match::<&str, String>("test/*").await.unwrap(); + tokio::time::timeout(std::time::Duration::from_millis(100), async move { + let values: Vec<_> = values.collect().await; + assert_eq!(values.len(), 20); + }) + .await + .unwrap(); + } + + // Test issue of AsyncCommands::scan returning the wrong number of keys + // https://github.com/redis-rs/redis-rs/issues/759 + #[tokio::test] + async fn test_issue_async_commands_scan_broken() { + let ctx = TestContext::new(); + let mut con = ctx.async_connection().await.unwrap(); + let mut keys: Vec = (0..100).map(|k| format!("async-key{k}")).collect(); + keys.sort(); + for key in &keys { + let _: () = con.set(key, b"foo").await.unwrap(); + } + + let iter: redis::AsyncIter = con.scan().await.unwrap(); + let mut keys_from_redis: Vec<_> = iter.collect().await; + keys_from_redis.sort(); + assert_eq!(keys, keys_from_redis); + assert_eq!(keys.len(), 100); + } + + mod pub_sub { + use std::time::Duration; + + use redis::ProtocolVersion; + + use super::*; + + #[test] + fn pub_sub_subscription() { + use redis::RedisError; + + let ctx = TestContext::new(); + block_on_all(async move { + let mut pubsub_conn = ctx.async_pubsub().await?; + pubsub_conn.subscribe("phonewave").await?; + let mut pubsub_stream = pubsub_conn.on_message(); + let mut publish_conn = ctx.async_connection().await?; + publish_conn.publish("phonewave", "banana").await?; + + let msg_payload: String = pubsub_stream.next().await.unwrap().get_payload()?; + assert_eq!("banana".to_string(), msg_payload); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn pub_sub_unsubscription() { + use redis::RedisError; + + const SUBSCRIPTION_KEY: &str = "phonewave-pub-sub-unsubscription"; + + let ctx = TestContext::new(); + block_on_all(async move { + let mut pubsub_conn = ctx.async_pubsub().await?; + pubsub_conn.subscribe(SUBSCRIPTION_KEY).await?; + pubsub_conn.unsubscribe(SUBSCRIPTION_KEY).await?; + + let mut conn = ctx.async_connection().await?; + let subscriptions_counts: HashMap = redis::cmd("PUBSUB") + .arg("NUMSUB") + .arg(SUBSCRIPTION_KEY) + .query_async(&mut conn) + .await?; + let subscription_count = *subscriptions_counts.get(SUBSCRIPTION_KEY).unwrap(); + assert_eq!(subscription_count, 0); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn automatic_unsubscription() { + use redis::RedisError; + + const SUBSCRIPTION_KEY: &str = "phonewave-automatic-unsubscription"; + + let ctx = TestContext::new(); + block_on_all(async move { + let mut pubsub_conn = ctx.async_pubsub().await?; + pubsub_conn.subscribe(SUBSCRIPTION_KEY).await?; + drop(pubsub_conn); + + let mut conn = ctx.async_connection().await?; + let mut subscription_count = 1; + // Allow for the unsubscription to occur within 5 seconds + for _ in 0..100 { + let subscriptions_counts: HashMap = redis::cmd("PUBSUB") + .arg("NUMSUB") + .arg(SUBSCRIPTION_KEY) + .query_async(&mut conn) + .await?; + subscription_count = *subscriptions_counts.get(SUBSCRIPTION_KEY).unwrap(); + if subscription_count == 0 { + break; + } + + std::thread::sleep(Duration::from_millis(50)); + } + assert_eq!(subscription_count, 0); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn pub_sub_conn_reuse() { + use redis::RedisError; + + let ctx = TestContext::new(); + block_on_all(async move { + let mut pubsub_conn = ctx.async_pubsub().await?; + pubsub_conn.subscribe("phonewave").await?; + pubsub_conn.psubscribe("*").await?; + + #[allow(deprecated)] + let mut conn = pubsub_conn.into_connection().await; + redis::cmd("SET") + .arg("foo") + .arg("bar") + .query_async(&mut conn) + .await?; + + let res: String = redis::cmd("GET").arg("foo").query_async(&mut conn).await?; + assert_eq!(&res, "bar"); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn pipe_errors_do_not_affect_subsequent_commands() { + use redis::RedisError; + + let ctx = TestContext::new(); + block_on_all(async move { + let mut conn = ctx.multiplexed_async_connection().await?; + + conn.lpush::<&str, &str, ()>("key", "value").await?; + + let res: Result<(String, usize), redis::RedisError> = redis::pipe() + .get("key") // WRONGTYPE + .llen("key") + .query_async(&mut conn) + .await; + + assert!(res.is_err()); + + let list: Vec = conn.lrange("key", 0, -1).await?; + + assert_eq!(list, vec!["value".to_owned()]); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn pub_sub_multiple() { + use redis::RedisError; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + block_on_all(async move { + let mut conn = ctx.multiplexed_async_connection().await?; + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let pub_count = 10; + let channel_name = "phonewave".to_string(); + conn.get_push_manager().replace_sender(tx.clone()); + conn.subscribe(channel_name.clone()).await?; + rx.recv().await.unwrap(); //PASS SUBSCRIBE + + let mut publish_conn = ctx.async_connection().await?; + for i in 0..pub_count { + publish_conn + .publish(channel_name.clone(), format!("banana {i}")) + .await?; + } + for _ in 0..pub_count { + rx.recv().await.unwrap(); + } + assert!(rx.try_recv().is_err()); + + { + //Lets test if unsubscribing from individual channel subscription works + publish_conn + .publish(channel_name.clone(), "banana!") + .await?; + rx.recv().await.unwrap(); + } + { + //Giving none for channel id should unsubscribe all subscriptions from that channel and send unsubcribe command to server. + conn.unsubscribe(channel_name.clone()).await?; + rx.recv().await.unwrap(); //PASS UNSUBSCRIBE + publish_conn + .publish(channel_name.clone(), "banana!") + .await?; + //Let's wait for 100ms to make sure there is nothing in channel. + tokio::time::sleep(Duration::from_millis(100)).await; + assert!(rx.try_recv().is_err()); + } + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn push_manager_active_context() { + use redis::RedisError; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + block_on_all(async move { + let mut sub_conn = ctx.multiplexed_async_connection().await?; + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + let channel_name = "test_channel".to_string(); + sub_conn.get_push_manager().replace_sender(tx.clone()); + sub_conn.subscribe(channel_name.clone()).await?; + + let rcv_msg = rx.recv().await.unwrap(); + println!("Received PushInfo: {:?}", rcv_msg); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn push_manager_disconnection() { + use redis::RedisError; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + block_on_all(async move { + let mut conn = ctx.multiplexed_async_connection().await?; + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + conn.get_push_manager().replace_sender(tx.clone()); + + conn.set("A", "1").await?; + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); + drop(ctx); + let x: RedisResult<()> = conn.set("A", "1").await; + assert!(x.is_err()); + assert_eq!(rx.recv().await.unwrap().kind, PushKind::Disconnection); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + } + + #[test] + fn test_async_basic_pipe_with_parsing_error() { + // Tests a specific case involving repeated errors in transactions. + let ctx = TestContext::new(); + + block_on_all(async move { + let mut conn = ctx.multiplexed_async_connection().await?; + + // create a transaction where 2 errors are returned. + // we call EVALSHA twice with no loaded script, thus triggering 2 errors. + redis::pipe() + .atomic() + .cmd("EVALSHA") + .arg("foobar") + .arg(0) + .cmd("EVALSHA") + .arg("foobar") + .arg(0) + .query_async::<_, ((), ())>(&mut conn) + .await + .expect_err("should return an error"); + + assert!( + // Arbitrary Redis command that should not return an error. + redis::cmd("SMEMBERS") + .arg("nonexistent_key") + .query_async::<_, Vec>(&mut conn) + .await + .is_ok(), + "Failed transaction should not interfere with future calls." + ); + + Ok::<_, redis::RedisError>(()) + }) + .unwrap() + } + + #[cfg(feature = "connection-manager")] + async fn wait_for_server_to_become_ready(client: redis::Client) { + let millisecond = std::time::Duration::from_millis(1); + let mut retries = 0; + loop { + match client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + { + Err(err) => { + if err.is_connection_refusal() { + tokio::time::sleep(millisecond).await; + retries += 1; + if retries > 100000 { + panic!("Tried to connect too many times, last error: {err}"); + } + } else { + panic!("Could not connect: {err}"); + } + } + Ok(mut con) => { + let _: RedisResult<()> = redis::cmd("FLUSHDB").query_async(&mut con).await; + break; + } + } + } + } + + #[test] + #[cfg(feature = "connection-manager")] + fn test_connection_manager_reconnect_after_delay() { + use redis::ProtocolVersion; + + let tempdir = tempfile::Builder::new() + .prefix("redis") + .tempdir() + .expect("failed to create tempdir"); + let tls_files = build_keys_and_certs_for_tls(&tempdir); + + let ctx = TestContext::with_tls(tls_files.clone(), false); + block_on_all(async move { + let mut manager = redis::aio::ConnectionManager::new(ctx.client.clone()) + .await + .unwrap(); + let server = ctx.server; + let addr = server.client_addr().clone(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(tx.clone()); + drop(server); + + let _result: RedisResult = manager.set("foo", "bar").await; // one call is ignored because it's required to trigger the connection manager's reconnect. + if ctx.protocol != ProtocolVersion::RESP2 { + assert_eq!(rx.recv().await.unwrap().kind, PushKind::Disconnection); + } + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + let _new_server = RedisServer::new_with_addr_and_modules(addr.clone(), &[], false); + wait_for_server_to_become_ready(ctx.client.clone()).await; + + let result: redis::Value = manager.set("foo", "bar").await.unwrap(); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); + assert_eq!(result, redis::Value::Okay); + Ok(()) + }) + .unwrap(); + } + + #[cfg(feature = "tls-rustls")] + mod mtls_test { + use super::*; + + #[test] + fn test_should_connect_mtls() { + let ctx = TestContext::new_with_mtls(); + + let client = + build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, true) + .unwrap(); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + block_on_all(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + let result = redis::cmd("GET").arg(&["key1"]).query_async(&mut con).await; + assert_eq!(result, Ok("foo".to_string())); + result + })) + .unwrap(); + } + + #[test] + fn test_should_not_connect_if_tls_active() { + let ctx = TestContext::new_with_mtls(); + + let client = + build_single_client(ctx.server.connection_info(), &ctx.server.tls_paths, false) + .unwrap(); + let connect = + client.get_multiplexed_async_connection(GlideConnectionOptions::default()); + let result = block_on_all(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + let result = redis::cmd("GET").arg(&["key1"]).query_async(&mut con).await; + assert_eq!(result, Ok("foo".to_string())); + result + })); + + // depends on server type set (REDISRS_SERVER_TYPE) + match ctx.server.connection_info() { + redis::ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { .. }, + .. + } => { + if result.is_ok() { + panic!("Must NOT be able to connect without client credentials if server accepts TLS"); + } + } + _ => { + if result.is_err() { + panic!("Must be able to connect without client credentials if server does NOT accept TLS"); + } + } + } + } + } + + #[test] + fn test_set_client_name_by_config() { + const CLIENT_NAME: &str = "TEST_CLIENT_NAME"; + use redis::RedisError; + let ctx = TestContext::with_client_name(CLIENT_NAME); + + block_on_all(async move { + let mut con = ctx.async_connection().await?; + + let client_info: String = redis::cmd("CLIENT") + .arg("INFO") + .query_async(&mut con) + .await + .unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], CLIENT_NAME, + "Incorrect client name, expecting: {}, got {}", + CLIENT_NAME, client_attrs["name"] + ); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[cfg(feature = "connection-manager")] + fn test_push_manager_cm() { + use redis::ProtocolVersion; + + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + + block_on_all(async move { + let mut manager = redis::aio::ConnectionManager::new(ctx.client.clone()) + .await + .unwrap(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(tx.clone()); + manager + .send_packed_command(cmd("CLIENT").arg("TRACKING").arg("ON")) + .await + .unwrap(); + let pipe = build_simple_pipeline_for_invalidation(); + let _: RedisResult<()> = pipe.query_async(&mut manager).await; + let _: i32 = manager.get("key_1").await.unwrap(); + let PushInfo { kind, data } = rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + let (new_tx, mut new_rx) = tokio::sync::mpsc::unbounded_channel(); + manager.get_push_manager().replace_sender(new_tx); + drop(rx); + let _: RedisResult<()> = pipe.query_async(&mut manager).await; + let _: i32 = manager.get("key_1").await.unwrap(); + let PushInfo { kind, data } = new_rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + assert_eq!(TryRecvError::Empty, new_rx.try_recv().err().unwrap()); + Ok(()) + }) + .unwrap(); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_async_async_std.rs b/glide-core/redis-rs/redis/tests/test_async_async_std.rs new file mode 100644 index 0000000000..656d1979f6 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_async_async_std.rs @@ -0,0 +1,328 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +use futures::prelude::*; + +use crate::support::*; + +use redis::{aio::MultiplexedConnection, GlideConnectionOptions, RedisResult}; + +mod support; + +#[test] +fn test_args() { + let ctx = TestContext::new(); + let connect = ctx.async_connection_async_std(); + + block_on_all_using_async_std(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&["key2", "bar"]) + .query_async(&mut con) + .await?; + let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; + assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + result + })) + .unwrap(); +} + +#[test] +fn test_args_async_std() { + let ctx = TestContext::new(); + let connect = ctx.async_connection_async_std(); + + block_on_all_using_async_std(connect.and_then(|mut con| async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&["key2", "bar"]) + .query_async(&mut con) + .await?; + let result = redis::cmd("MGET") + .arg(&["key1", "key2"]) + .query_async(&mut con) + .await; + assert_eq!(result, Ok(("foo".to_string(), b"bar".to_vec()))); + result + })) + .unwrap(); +} + +#[test] +fn dont_panic_on_closed_multiplexed_connection() { + let ctx = TestContext::new(); + let client = ctx.client.clone(); + let connect = client.get_multiplexed_async_std_connection(GlideConnectionOptions::default()); + drop(ctx); + + block_on_all_using_async_std(async move { + connect + .and_then(|con| async move { + let cmd = move || { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .arg("key1") + .arg(b"foo") + .query_async(&mut con) + .await + } + }; + let result: RedisResult<()> = cmd().await; + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + cmd().await + }) + .map(|result| { + assert_eq!( + result.as_ref().unwrap_err().kind(), + redis::ErrorKind::IoError, + "{}", + result.as_ref().unwrap_err() + ); + }) + .await + }); +} + +#[test] +fn test_pipeline_transaction() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + let mut con = ctx.async_connection_async_std().await?; + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]); + pipe.query_async(&mut con) + .map_ok(|((k1, k2),): ((i32, i32),)| { + assert_eq!(k1, 42); + assert_eq!(k2, 43); + }) + .await + }) + .unwrap(); +} + +fn test_cmd(con: &MultiplexedConnection, i: i32) -> impl Future> + Send { + let mut con = con.clone(); + async move { + let key = format!("key{i}"); + let key_2 = key.clone(); + let key2 = format!("key{i}_2"); + let key2_2 = key2.clone(); + + let foo_val = format!("foo{i}"); + + redis::cmd("SET") + .arg(&key[..]) + .arg(foo_val.as_bytes()) + .query_async(&mut con) + .await?; + redis::cmd("SET") + .arg(&[&key2, "bar"]) + .query_async(&mut con) + .await?; + redis::cmd("MGET") + .arg(&[&key_2, &key2_2]) + .query_async(&mut con) + .map(|result| { + assert_eq!(Ok((foo_val, b"bar".to_vec())), result); + Ok(()) + }) + .await + } +} + +fn test_error(con: &MultiplexedConnection) -> impl Future> { + let mut con = con.clone(); + async move { + redis::cmd("SET") + .query_async(&mut con) + .map(|result| match result { + Ok(()) => panic!("Expected redis to return an error"), + Err(_) => Ok(()), + }) + .await + } +} + +#[test] +fn test_args_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + ctx.multiplexed_async_connection_async_std() + .and_then(|con| { + let cmds = (0..100).map(move |i| test_cmd(&con, i)); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); +} + +#[test] +fn test_args_with_errors_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + ctx.multiplexed_async_connection_async_std() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let con = con.clone(); + async move { + if i % 2 == 0 { + test_cmd(&con, i).await + } else { + test_error(&con).await + } + } + }); + future::try_join_all(cmds).map_ok(|results| { + assert_eq!(results.len(), 100); + }) + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); +} + +#[test] +fn test_transaction_multiplexed_connection() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async move { + ctx.multiplexed_async_connection_async_std() + .and_then(|con| { + let cmds = (0..100).map(move |i| { + let mut con = con.clone(); + async move { + let foo_val = i; + let bar_val = format!("bar{i}"); + + let mut pipe = redis::pipe(); + pipe.atomic() + .cmd("SET") + .arg("key") + .arg(foo_val) + .ignore() + .cmd("SET") + .arg(&["key2", &bar_val[..]]) + .ignore() + .cmd("MGET") + .arg(&["key", "key2"]); + + pipe.query_async(&mut con) + .map(move |result| { + assert_eq!(Ok(((foo_val, bar_val.into_bytes()),)), result); + result + }) + .await + } + }); + future::try_join_all(cmds) + }) + .map_ok(|results| { + assert_eq!(results.len(), 100); + }) + .map_err(|err| panic!("{}", err)) + .await + }) + .unwrap(); +} + +#[test] +#[cfg(feature = "script")] +fn test_script() { + use redis::RedisError; + + // Note this test runs both scripts twice to test when they have already been loaded + // into Redis and when they need to be loaded in + let script1 = redis::Script::new("return redis.call('SET', KEYS[1], ARGV[1])"); + let script2 = redis::Script::new("return redis.call('GET', KEYS[1])"); + let script3 = redis::Script::new("return redis.call('KEYS', '*')"); + + let ctx = TestContext::new(); + + block_on_all_using_async_std(async move { + let mut con = ctx.multiplexed_async_connection_async_std().await?; + script1 + .key("key1") + .arg("foo") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "foo"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + script1 + .key("key1") + .arg("bar") + .invoke_async(&mut con) + .await?; + let val: String = script2.key("key1").invoke_async(&mut con).await?; + assert_eq!(val, "bar"); + let keys: Vec = script3.invoke_async(&mut con).await?; + assert_eq!(keys, ["key1"]); + Ok::<_, RedisError>(()) + }) + .unwrap(); +} + +#[test] +#[cfg(feature = "script")] +fn test_script_load() { + let ctx = TestContext::new(); + let script = redis::Script::new("return 'Hello World'"); + + block_on_all(async move { + let mut con = ctx.multiplexed_async_connection_async_std().await.unwrap(); + + let hash = script.prepare_invoke().load_async(&mut con).await.unwrap(); + assert_eq!(hash, script.get_hash().to_string()); + Ok(()) + }) + .unwrap(); +} + +#[test] +#[cfg(feature = "script")] +fn test_script_returning_complex_type() { + let ctx = TestContext::new(); + block_on_all_using_async_std(async { + let mut con = ctx.multiplexed_async_connection_async_std().await?; + redis::Script::new("return {1, ARGV[1], true}") + .arg("hello") + .invoke_async(&mut con) + .map_ok(|(i, s, b): (i32, String, bool)| { + assert_eq!(i, 1); + assert_eq!(s, "hello"); + assert!(b); + }) + .await + }) + .unwrap(); +} diff --git a/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs b/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs new file mode 100644 index 0000000000..0230d1de17 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_async_cluster_connections_logic.rs @@ -0,0 +1,563 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "cluster-async")] +mod support; + +use redis::{ + cluster_async::testing::{AsyncClusterNode, RefreshConnectionType}, + testing::ClusterParams, + ErrorKind, GlideConnectionOptions, +}; +use std::net::{IpAddr, Ipv4Addr}; +use std::sync::Arc; +use support::{ + get_mock_connection, get_mock_connection_with_port, modify_mock_connection_behavior, + respond_startup, ConnectionIPReturnType, MockConnection, MockConnectionBehavior, +}; + +mod test_connect_and_check { + use std::sync::atomic::AtomicUsize; + + use super::*; + use crate::support::{get_mock_connection_handler, ShouldReturnConnectionError}; + use redis::cluster_async::testing::{ + connect_and_check, ConnectAndCheckResult, ConnectionWithIp, + }; + + fn assert_partial_result( + result: ConnectAndCheckResult, + ) -> (AsyncClusterNode, redis::RedisError) { + match result { + ConnectAndCheckResult::ManagementConnectionFailed { node, err } => (node, err), + ConnectAndCheckResult::Success(_) => { + panic!("Expected partial result, got full success") + } + ConnectAndCheckResult::Failed(_) => panic!("Expected partial result, got a failure"), + } + } + + fn assert_full_success( + result: ConnectAndCheckResult, + ) -> AsyncClusterNode { + match result { + ConnectAndCheckResult::Success(node) => node, + ConnectAndCheckResult::ManagementConnectionFailed { .. } => { + panic!("Expected full success, got partial success") + } + ConnectAndCheckResult::Failed(_) => panic!("Expected partial result, got a failure"), + } + } + + #[tokio::test] + async fn test_connect_and_check_connect_successfully() { + // Test that upon refreshing all connections, if both connections were successful, + // the returned node contains both user and management connection + let name = "test_connect_and_check_connect_successfully"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Specified(ip) + }); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + assert!(node.management_connection.is_some()); + assert_eq!(node.user_connection.ip, Some(ip)); + assert_eq!(node.management_connection.unwrap().ip, Some(ip)); + } + + #[tokio::test] + async fn test_connect_and_check_all_connections_one_connection_err_returns_only_user_conn() { + // Test that upon refreshing all connections, if only one of the new connections fail, + // the other successful connection will be used as the user connection, as a partial success. + let name = "all_connections_one_connection_err"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + // The second connection will fail + behavior.return_connection_err = + ShouldReturnConnectionError::OnOddIdx(AtomicUsize::new(0)) + }); + + let params = ClusterParams::default(); + + let result = connect_and_check::( + &format!("{name}:6379"), + params.clone(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let (node, _) = assert_partial_result(result); + assert!(node.management_connection.is_none()); + + modify_mock_connection_behavior(name, |behavior| { + // The first connection will fail + behavior.return_connection_err = + ShouldReturnConnectionError::OnOddIdx(AtomicUsize::new(1)); + }); + + let result = connect_and_check::( + &format!("{name}:6379"), + params, + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let (node, _) = assert_partial_result(result); + assert!(node.management_connection.is_none()); + } + + #[tokio::test] + async fn test_connect_and_check_all_connections_different_ip_returns_both_connections() { + // Test that node's connections (e.g. user and management) can have different IPs for the same DNS endpoint. + // It is relevant for cases where the DNS entry holds multiple IPs that routes to the same node, for example with load balancers. + // The test verifies that upon refreshing all connections, if the IPs of the new connections differ, + // the function uses all connections. + let name = "all_connections_different_ip"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Different(AtomicUsize::new(0)); + }); + + // The first connection will have 0.0.0.0 IP, the second 1.0.0.0 + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + assert!(node.management_connection.is_some()); + assert_eq!( + node.user_connection.ip, + Some(IpAddr::V4(Ipv4Addr::new(0, 0, 0, 0))) + ); + assert_eq!( + node.management_connection.unwrap().ip, + Some(IpAddr::V4(Ipv4Addr::new(1, 0, 0, 0))) + ); + } + + #[tokio::test] + async fn test_connect_and_check_all_connections_both_conn_error_returns_err() { + // Test that when trying to refresh all connections and both connections fail, the function returns with an error + let name = "both_conn_error_returns_err"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + behavior.return_connection_err = ShouldReturnConnectionError::Yes + }); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::AllConnections, + None, + GlideConnectionOptions::default(), + ) + .await; + let err = result.get_error().unwrap(); + assert!( + err.to_string() + .contains("Failed to refresh both connections") + && err.kind() == ErrorKind::IoError + ); + } + + #[tokio::test] + async fn test_connect_and_check_only_management_same_ip() { + // Test that when we refresh only the management connection and the new connection returned with the same IP as the user's, + // the returned node contains a new management connection and the user connection remains unchanged + let name = "only_management_same_ip"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Specified(ip) + }); + + let user_conn_id: usize = 1000; + let user_conn = MockConnection { + id: user_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + let node = AsyncClusterNode::new( + ConnectionWithIp { + conn: user_conn, + ip: Some(ip), + } + .into_future(), + None, + ); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::OnlyManagementConnection, + Some(node), + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + assert!(node.management_connection.is_some()); + // Confirm that the user connection remains unchanged + assert_eq!(node.user_connection.conn.await.id, user_conn_id); + } + + #[tokio::test] + async fn test_connect_and_check_only_management_connection_err() { + // Test that when we try the refresh only the management connection and it fails, we receive a partial success with the same node. + let name = "only_management_connection_err"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + modify_mock_connection_behavior(name, |behavior| { + behavior.return_connection_err = ShouldReturnConnectionError::Yes; + }); + + let user_conn_id: usize = 1000; + let user_conn = MockConnection { + id: user_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + let prev_ip = Some(IpAddr::V4(Ipv4Addr::new(1, 1, 1, 1))); + let node = AsyncClusterNode::new( + ConnectionWithIp { + conn: user_conn, + ip: prev_ip, + } + .into_future(), + None, + ); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::OnlyManagementConnection, + Some(node), + GlideConnectionOptions::default(), + ) + .await; + let (node, _) = assert_partial_result(result); + assert!(node.management_connection.is_none()); + // Confirm that the user connection was changed + assert_eq!(node.user_connection.conn.await.id, user_conn_id); + assert_eq!(node.user_connection.ip, prev_ip); + } + + #[tokio::test] + async fn test_connect_and_check_only_user_connection_same_ip() { + // Test that upon refreshing only the user connection, if the newly created connection share the same IP as the existing management connection, + // the managament connection remains unchanged + let name = "only_user_connection_same_ip"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let prev_ip = IpAddr::V4(Ipv4Addr::new(1, 2, 3, 4)); + modify_mock_connection_behavior(name, |behavior| { + behavior.returned_ip_type = ConnectionIPReturnType::Specified(prev_ip); + }); + let old_user_conn_id: usize = 1000; + let management_conn_id: usize = 2000; + let old_user_conn = MockConnection { + id: old_user_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + let management_conn = MockConnection { + id: management_conn_id, + handler: get_mock_connection_handler(name), + port: 6379, + }; + + let node = AsyncClusterNode::new( + ConnectionWithIp { + conn: old_user_conn, + ip: Some(prev_ip), + } + .into_future(), + Some( + ConnectionWithIp { + conn: management_conn, + ip: Some(prev_ip), + } + .into_future(), + ), + ); + + let result = connect_and_check::( + &format!("{name}:6379"), + ClusterParams::default(), + None, + RefreshConnectionType::OnlyUserConnection, + Some(node), + GlideConnectionOptions::default(), + ) + .await; + let node = assert_full_success(result); + // Confirm that a new user connection was created + assert_ne!(node.user_connection.conn.await.id, old_user_conn_id); + // Confirm that the management connection remains unchanged + assert_eq!( + node.management_connection.unwrap().conn.await.id, + management_conn_id + ); + } +} + +mod test_check_node_connections { + + use super::*; + use redis::cluster_async::testing::{check_node_connections, ConnectionWithIp}; + fn create_node_with_all_connections(name: &str) -> AsyncClusterNode { + let ip = None; + AsyncClusterNode::new( + ConnectionWithIp { + conn: get_mock_connection_with_port(name, 1, 6380), + ip, + } + .into_future(), + Some( + ConnectionWithIp { + conn: get_mock_connection_with_port(name, 2, 6381), + ip, + } + .into_future(), + ), + ) + } + + #[tokio::test] + async fn test_check_node_connections_find_no_problem() { + // Test that upon when checking both connections, if both connections are healthy no issue is returned. + let name = "test_check_node_connections_find_no_problem"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + let node = create_node_with_all_connections(name); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!(response, None); + } + + #[tokio::test] + async fn test_check_node_connections_find_management_connection_issue() { + // Test that upon checking both connections, if management connection isn't responding to pings, `OnlyManagementConnection` will be returned. + let name = "test_check_node_connections_find_management_connection_issue"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, port| { + if port == 6381 { + return Err(Err((ErrorKind::ClientError, "some error").into())); + } + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let node = create_node_with_all_connections(name); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!( + response, + Some(RefreshConnectionType::OnlyManagementConnection) + ); + } + + #[tokio::test] + async fn test_check_node_connections_find_missing_management_connection() { + // Test that upon checking both connections, if management connection isn't present, `OnlyManagementConnection` will be returned. + let name = "test_check_node_connections_find_missing_management_connection"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let ip = None; + let node = AsyncClusterNode::new( + ConnectionWithIp { + conn: get_mock_connection(name, 1), + ip, + } + .into_future(), + None, + ); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!( + response, + Some(RefreshConnectionType::OnlyManagementConnection) + ); + } + + #[tokio::test] + async fn test_check_node_connections_find_both_connections_issue() { + // Test that upon checking both connections, if management connection isn't responding to pings, `OnlyManagementConnection` will be returned. + let name = "test_check_node_connections_find_both_connections_issue"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|_, _| Err(Err((ErrorKind::ClientError, "some error").into()))), + ); + + let node = create_node_with_all_connections(name); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!(response, Some(RefreshConnectionType::AllConnections)); + } + + #[tokio::test] + async fn test_check_node_connections_find_user_connection_issue() { + // Test that upon checking both connections, if user connection isn't responding to pings, `OnlyUserConnection` will be returned. + let name = "test_check_node_connections_find_user_connection_issue"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, port| { + if port == 6380 { + return Err(Err((ErrorKind::ClientError, "some error").into())); + } + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let node = create_node_with_all_connections(name); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::AllConnections, + name, + ) + .await; + assert_eq!(response, Some(RefreshConnectionType::OnlyUserConnection)); + } + + #[tokio::test] + async fn test_check_node_connections_ignore_missing_management_connection_when_refreshing_user() + { + // Test that upon checking only user connection, issues with management connection won't affect the result. + let name = + "test_check_node_connections_ignore_management_connection_issue_when_refreshing_user"; + + let _handle = MockConnectionBehavior::register_new( + name, + Arc::new(|cmd, _| { + respond_startup(name, cmd)?; + Ok(()) + }), + ); + + let node = AsyncClusterNode::new( + ConnectionWithIp { + conn: get_mock_connection(name, 1), + ip: None, + } + .into_future(), + None, + ); + let response = check_node_connections::( + &node, + &ClusterParams::default(), + RefreshConnectionType::OnlyUserConnection, + name, + ) + .await; + assert_eq!(response, None); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_basic.rs b/glide-core/redis-rs/redis/tests/test_basic.rs new file mode 100644 index 0000000000..e31c33384c --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_basic.rs @@ -0,0 +1,1581 @@ +#![allow(clippy::let_unit_value)] + +mod support; + +#[cfg(test)] +mod basic { + use redis::{cmd, ProtocolVersion, PushInfo}; + use redis::{ + Commands, ConnectionInfo, ConnectionLike, ControlFlow, ErrorKind, ExistenceCheck, Expiry, + PubSubCommands, PushKind, RedisResult, SetExpiry, SetOptions, ToRedisArgs, Value, + }; + use std::collections::{BTreeMap, BTreeSet}; + use std::collections::{HashMap, HashSet}; + use std::thread::{sleep, spawn}; + use std::time::Duration; + use std::vec; + use tokio::sync::mpsc::error::TryRecvError; + + use crate::{assert_args, support::*}; + + #[test] + fn test_parse_redis_url() { + let redis_url = "redis://127.0.0.1:1234/0".to_string(); + redis::parse_redis_url(&redis_url).unwrap(); + redis::parse_redis_url("unix:/var/run/redis/redis.sock").unwrap(); + assert!(redis::parse_redis_url("127.0.0.1").is_none()); + } + + #[test] + fn test_redis_url_fromstr() { + let _info: ConnectionInfo = "redis://127.0.0.1:1234/0".parse().unwrap(); + } + + #[test] + fn test_args() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("key1").arg(b"foo").execute(&mut con); + redis::cmd("SET").arg(&["key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET").arg(&["key1", "key2"]).query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_getset() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("GET").arg("foo").query(&mut con), Ok(42)); + + redis::cmd("SET").arg("bar").arg("foo").execute(&mut con); + assert_eq!( + redis::cmd("GET").arg("bar").query(&mut con), + Ok(b"foo".to_vec()) + ); + } + + //unit test for key_type function + #[test] + fn test_key_type() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + //The key is a simple value + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + let string_key_type: String = con.key_type("foo").unwrap(); + assert_eq!(string_key_type, "string"); + + //The key is a list + redis::cmd("LPUSH") + .arg("list_bar") + .arg("foo") + .execute(&mut con); + let list_key_type: String = con.key_type("list_bar").unwrap(); + assert_eq!(list_key_type, "list"); + + //The key is a set + redis::cmd("SADD") + .arg("set_bar") + .arg("foo") + .execute(&mut con); + let set_key_type: String = con.key_type("set_bar").unwrap(); + assert_eq!(set_key_type, "set"); + + //The key is a sorted set + redis::cmd("ZADD") + .arg("sorted_set_bar") + .arg("1") + .arg("foo") + .execute(&mut con); + let zset_key_type: String = con.key_type("sorted_set_bar").unwrap(); + assert_eq!(zset_key_type, "zset"); + + //The key is a hash + redis::cmd("HSET") + .arg("hset_bar") + .arg("hset_key_1") + .arg("foo") + .execute(&mut con); + let hash_key_type: String = con.key_type("hset_bar").unwrap(); + assert_eq!(hash_key_type, "hash"); + } + + #[test] + fn test_client_tracking_doesnt_block_execution() { + //It checks if the library distinguish a push-type message from the others and continues its normal operation. + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let (k1, k2): (i32, i32) = redis::pipe() + .cmd("CLIENT") + .arg("TRACKING") + .arg("ON") + .ignore() + .cmd("GET") + .arg("key_1") + .ignore() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("GET") + .arg("key_1") + .cmd("GET") + .arg("key_2") + .cmd("SET") + .arg("key_1") + .arg(45) + .ignore() + .query(&mut con) + .unwrap(); + assert_eq!(k1, 42); + assert_eq!(k2, 43); + let num: i32 = con.get("key_1").unwrap(); + assert_eq!(num, 45); + } + + #[test] + fn test_incr() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("INCR").arg("foo").query(&mut con), Ok(43usize)); + } + + #[test] + fn test_getdel() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + + assert_eq!(con.get_del("foo"), Ok(42usize)); + + assert_eq!( + redis::cmd("GET").arg("foo").query(&mut con), + Ok(None::) + ); + } + + #[test] + fn test_getex() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42usize).execute(&mut con); + + // Return of get_ex must match set value + let ret_value = con.get_ex::<_, usize>("foo", Expiry::EX(1)).unwrap(); + assert_eq!(ret_value, 42usize); + + // Get before expiry time must also return value + sleep(Duration::from_millis(100)); + let delayed_get = con.get::<_, usize>("foo").unwrap(); + assert_eq!(delayed_get, 42usize); + + // Get after expiry time mustn't return value + sleep(Duration::from_secs(1)); + let after_expire_get = con.get::<_, Option>("foo").unwrap(); + assert_eq!(after_expire_get, None); + + // Persist option test prep + redis::cmd("SET").arg("foo").arg(420usize).execute(&mut con); + + // Return of get_ex with persist option must match set value + let ret_value = con.get_ex::<_, usize>("foo", Expiry::PERSIST).unwrap(); + assert_eq!(ret_value, 420usize); + + // Get after persist get_ex must return value + sleep(Duration::from_millis(200)); + let delayed_get = con.get::<_, usize>("foo").unwrap(); + assert_eq!(delayed_get, 420usize); + } + + #[test] + fn test_info() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let info: redis::InfoDict = redis::cmd("INFO").query(&mut con).unwrap(); + assert_eq!( + info.find(&"role"), + Some(&redis::Value::SimpleString("master".to_string())) + ); + assert_eq!(info.get("role"), Some("master".to_string())); + assert_eq!(info.get("loading"), Some(false)); + assert!(!info.is_empty()); + assert!(info.contains_key(&"role")); + } + + #[test] + fn test_hash_ops() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("HSET") + .arg("foo") + .arg("key_1") + .arg(1) + .execute(&mut con); + redis::cmd("HSET") + .arg("foo") + .arg("key_2") + .arg(2) + .execute(&mut con); + + let h: HashMap = redis::cmd("HGETALL").arg("foo").query(&mut con).unwrap(); + assert_eq!(h.len(), 2); + assert_eq!(h.get("key_1"), Some(&1i32)); + assert_eq!(h.get("key_2"), Some(&2i32)); + + let h: BTreeMap = redis::cmd("HGETALL").arg("foo").query(&mut con).unwrap(); + assert_eq!(h.len(), 2); + assert_eq!(h.get("key_1"), Some(&1i32)); + assert_eq!(h.get("key_2"), Some(&2i32)); + } + + // Requires redis-server >= 4.0.0. + // Not supported with the current appveyor/windows binary deployed. + #[cfg(not(target_os = "windows"))] + #[test] + fn test_unlink() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + assert_eq!(redis::cmd("GET").arg("foo").query(&mut con), Ok(42)); + assert_eq!(con.unlink("foo"), Ok(1)); + + redis::cmd("SET").arg("foo").arg(42).execute(&mut con); + redis::cmd("SET").arg("bar").arg(42).execute(&mut con); + assert_eq!(con.unlink(&["foo", "bar"]), Ok(2)); + } + + #[test] + fn test_set_ops() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd("foo", &[1, 2, 3]), Ok(3)); + + let mut s: Vec = con.smembers("foo").unwrap(); + s.sort_unstable(); + assert_eq!(s.len(), 3); + assert_eq!(&s, &[1, 2, 3]); + + let set: HashSet = con.smembers("foo").unwrap(); + assert_eq!(set.len(), 3); + assert!(set.contains(&1i32)); + assert!(set.contains(&2i32)); + assert!(set.contains(&3i32)); + + let set: BTreeSet = con.smembers("foo").unwrap(); + assert_eq!(set.len(), 3); + assert!(set.contains(&1i32)); + assert!(set.contains(&2i32)); + assert!(set.contains(&3i32)); + } + + #[test] + fn test_scan() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd("foo", &[1, 2, 3]), Ok(3)); + + let (cur, mut s): (i32, Vec) = redis::cmd("SSCAN") + .arg("foo") + .arg(0) + .query(&mut con) + .unwrap(); + s.sort_unstable(); + assert_eq!(cur, 0i32); + assert_eq!(s.len(), 3); + assert_eq!(&s, &[1, 2, 3]); + } + + #[test] + fn test_optionals() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("SET").arg("foo").arg(1).execute(&mut con); + + let (a, b): (Option, Option) = redis::cmd("MGET") + .arg("foo") + .arg("missing") + .query(&mut con) + .unwrap(); + assert_eq!(a, Some(1i32)); + assert_eq!(b, None); + + let a = redis::cmd("GET") + .arg("missing") + .query(&mut con) + .unwrap_or(0i32); + assert_eq!(a, 0i32); + } + + #[test] + fn test_scanning() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let mut unseen = HashSet::new(); + + for x in 0..1000 { + redis::cmd("SADD").arg("foo").arg(x).execute(&mut con); + unseen.insert(x); + } + + let iter = redis::cmd("SSCAN") + .arg("foo") + .cursor_arg(0) + .clone() + .iter(&mut con) + .unwrap(); + + for x in iter { + // type inference limitations + let x: usize = x; + unseen.remove(&x); + } + + assert_eq!(unseen.len(), 0); + } + + #[test] + fn test_filtered_scanning() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let mut unseen = HashSet::new(); + + for x in 0..3000 { + let _: () = con + .hset("foo", format!("key_{}_{}", x % 100, x), x) + .unwrap(); + if x % 100 == 0 { + unseen.insert(x); + } + } + + let iter = con + .hscan_match::<&str, &str, (String, usize)>("foo", "key_0_*") + .unwrap(); + + for (_field, value) in iter { + unseen.remove(&value); + } + + assert_eq!(unseen.len(), 0); + } + + #[test] + fn test_pipeline() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let ((k1, k2),): ((i32, i32),) = redis::pipe() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 43); + } + + #[test] + fn test_pipeline_with_err() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = redis::cmd("SET") + .arg("x") + .arg("x-value") + .query(&mut con) + .unwrap(); + let _: () = redis::cmd("SET") + .arg("y") + .arg("y-value") + .query(&mut con) + .unwrap(); + + let _: () = redis::cmd("SLAVEOF") + .arg("1.1.1.1") + .arg("99") + .query(&mut con) + .unwrap(); + + let res = redis::pipe() + .set("x", "another-x-value") + .ignore() + .get("y") + .query::<()>(&mut con); + assert!(res.is_err() && res.unwrap_err().kind() == ErrorKind::ReadOnly); + + // Make sure we don't get leftover responses from the pipeline ("y-value"). See #436. + let res = redis::cmd("GET") + .arg("x") + .query::(&mut con) + .unwrap(); + assert_eq!(res, "x-value"); + } + + #[test] + fn test_empty_pipeline() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = redis::pipe().cmd("PING").ignore().query(&mut con).unwrap(); + + let _: () = redis::pipe().query(&mut con).unwrap(); + } + + #[test] + fn test_pipeline_transaction() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let ((k1, k2),): ((i32, i32),) = redis::pipe() + .atomic() + .cmd("SET") + .arg("key_1") + .arg(42) + .ignore() + .cmd("SET") + .arg("key_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["key_1", "key_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 43); + } + + #[test] + fn test_pipeline_transaction_with_errors() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set("x", 42).unwrap(); + + // Make Redis a replica of a nonexistent master, thereby making it read-only. + let _: () = redis::cmd("slaveof") + .arg("1.1.1.1") + .arg("1") + .query(&mut con) + .unwrap(); + + // Ensure that a write command fails with a READONLY error + let err: RedisResult<()> = redis::pipe() + .atomic() + .set("x", 142) + .ignore() + .get("x") + .query(&mut con); + + assert_eq!(err.unwrap_err().kind(), ErrorKind::ReadOnly); + + let x: i32 = con.get("x").unwrap(); + assert_eq!(x, 42); + } + + #[test] + fn test_pipeline_reuse_query() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let mut pl = redis::pipe(); + + let ((k1,),): ((i32,),) = pl + .cmd("SET") + .arg("pkey_1") + .arg(42) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + + redis::cmd("DEL").arg("pkey_1").execute(&mut con); + + // The internal commands vector of the pipeline still contains the previous commands. + let ((k1,), (k2, k3)): ((i32,), (i32, i32)) = pl + .cmd("SET") + .arg("pkey_2") + .arg(43) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .arg(&["pkey_2"]) + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 42); + assert_eq!(k3, 43); + } + + #[test] + fn test_pipeline_reuse_query_clear() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let mut pl = redis::pipe(); + + let ((k1,),): ((i32,),) = pl + .cmd("SET") + .arg("pkey_1") + .arg(44) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .query(&mut con) + .unwrap(); + pl.clear(); + + assert_eq!(k1, 44); + + redis::cmd("DEL").arg("pkey_1").execute(&mut con); + + let ((k1, k2),): ((bool, i32),) = pl + .cmd("SET") + .arg("pkey_2") + .arg(45) + .ignore() + .cmd("MGET") + .arg(&["pkey_1"]) + .arg(&["pkey_2"]) + .query(&mut con) + .unwrap(); + pl.clear(); + + assert!(!k1); + assert_eq!(k2, 45); + } + + #[test] + fn test_real_transaction() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let key = "the_key"; + let _: () = redis::cmd("SET").arg(key).arg(42).query(&mut con).unwrap(); + + loop { + let _: () = redis::cmd("WATCH").arg(key).query(&mut con).unwrap(); + let val: isize = redis::cmd("GET").arg(key).query(&mut con).unwrap(); + let response: Option<(isize,)> = redis::pipe() + .atomic() + .cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(&mut con) + .unwrap(); + + match response { + None => { + continue; + } + Some(response) => { + assert_eq!(response, (43,)); + break; + } + } + } + } + + #[test] + fn test_real_transaction_highlevel() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let key = "the_key"; + let _: () = redis::cmd("SET").arg(key).arg(42).query(&mut con).unwrap(); + + let response: (isize,) = redis::transaction(&mut con, &[key], |con, pipe| { + let val: isize = redis::cmd("GET").arg(key).query(con)?; + pipe.cmd("SET") + .arg(key) + .arg(val + 1) + .ignore() + .cmd("GET") + .arg(key) + .query(con) + }) + .unwrap(); + + assert_eq!(response, (43,)); + } + + #[test] + fn test_pubsub() { + use std::sync::{Arc, Barrier}; + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // Connection for subscriber api + let mut pubsub_con = ctx.connection(); + + // Barrier is used to make test thread wait to publish + // until after the pubsub thread has subscribed. + let barrier = Arc::new(Barrier::new(2)); + let pubsub_barrier = barrier.clone(); + + let thread = spawn(move || { + let mut pubsub = pubsub_con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + + let _ = pubsub_barrier.wait(); + + let msg = pubsub.get_message().unwrap(); + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(42)); + + let msg = pubsub.get_message().unwrap(); + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(23)); + }); + + let _ = barrier.wait(); + redis::cmd("PUBLISH").arg("foo").arg(42).execute(&mut con); + // We can also call the command directly + assert_eq!(con.publish("foo", 23), Ok(1)); + + thread.join().expect("Something went wrong"); + } + + #[test] + fn test_pubsub_unsubscribe() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + { + let mut pubsub = con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + pubsub.subscribe("bar").unwrap(); + pubsub.subscribe("baz").unwrap(); + pubsub.psubscribe("foo*").unwrap(); + pubsub.psubscribe("bar*").unwrap(); + pubsub.psubscribe("baz*").unwrap(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + fn test_pubsub_subscribe_while_messages_are_sent() { + let ctx = TestContext::new(); + let mut conn_external = ctx.connection(); + let mut conn_internal = ctx.connection(); + let received = std::sync::Arc::new(std::sync::Mutex::new(Vec::new())); + let received_clone = received.clone(); + let (sender, receiver) = std::sync::mpsc::channel(); + // receive message from foo channel + let thread = std::thread::spawn(move || { + let mut pubsub = conn_internal.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + sender.send(()).unwrap(); + loop { + let msg = pubsub.get_message().unwrap(); + let channel = msg.get_channel_name(); + let content: i32 = msg.get_payload().unwrap(); + received + .lock() + .unwrap() + .push(format!("{channel}:{content}")); + if content == -1 { + return; + } + if content == 5 { + // subscribe bar channel using the same pubsub + pubsub.subscribe("bar").unwrap(); + sender.send(()).unwrap(); + } + } + }); + receiver.recv().unwrap(); + + // send message to foo channel after channel is ready. + for index in 0..10 { + println!("publishing on foo {index}"); + redis::cmd("PUBLISH") + .arg("foo") + .arg(index) + .query::(&mut conn_external) + .unwrap(); + } + receiver.recv().unwrap(); + redis::cmd("PUBLISH") + .arg("bar") + .arg(-1) + .query::(&mut conn_external) + .unwrap(); + thread.join().unwrap(); + assert_eq!( + *received_clone.lock().unwrap(), + (0..10) + .map(|index| format!("foo:{}", index)) + .chain(std::iter::once("bar:-1".to_string())) + .collect::>() + ); + } + + #[test] + fn test_pubsub_unsubscribe_no_subs() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + { + let _pubsub = con.as_pubsub(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + fn test_pubsub_unsubscribe_one_sub() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + { + let mut pubsub = con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + fn test_pubsub_unsubscribe_one_sub_one_psub() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + { + let mut pubsub = con.as_pubsub(); + pubsub.subscribe("foo").unwrap(); + pubsub.psubscribe("foo*").unwrap(); + } + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = con.set("foo", "bar").unwrap(); + let value: String = con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + fn scoped_pubsub() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // Connection for subscriber api + let mut pubsub_con = ctx.connection(); + + let thread = spawn(move || { + let mut count = 0; + pubsub_con + .subscribe(&["foo", "bar"], |msg| { + count += 1; + match count { + 1 => { + assert_eq!(msg.get_channel(), Ok("foo".to_string())); + assert_eq!(msg.get_payload(), Ok(42)); + ControlFlow::Continue + } + 2 => { + assert_eq!(msg.get_channel(), Ok("bar".to_string())); + assert_eq!(msg.get_payload(), Ok(23)); + ControlFlow::Break(()) + } + _ => ControlFlow::Break(()), + } + }) + .unwrap(); + + pubsub_con + }); + + // Can't use a barrier in this case since there's no opportunity to run code + // between channel subscription and blocking for messages. + sleep(Duration::from_millis(100)); + + redis::cmd("PUBLISH").arg("foo").arg(42).execute(&mut con); + assert_eq!(con.publish("bar", 23), Ok(1)); + + // Wait for thread + let mut pubsub_con = thread.join().expect("pubsub thread terminates ok"); + + // Connection should be usable again for non-pubsub commands + let _: redis::Value = pubsub_con.set("foo", "bar").unwrap(); + let value: String = pubsub_con.get("foo").unwrap(); + assert_eq!(&value[..], "bar"); + } + + #[test] + #[cfg(feature = "script")] + fn test_script() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let script = redis::Script::new( + r" + return {redis.call('GET', KEYS[1]), ARGV[1]} + ", + ); + + let _: () = redis::cmd("SET") + .arg("my_key") + .arg("foo") + .query(&mut con) + .unwrap(); + let response = script.key("my_key").arg(42).invoke(&mut con); + + assert_eq!(response, Ok(("foo".to_string(), 42))); + } + + #[test] + #[cfg(feature = "script")] + fn test_script_load() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let script = redis::Script::new("return 'Hello World'"); + + let hash = script.prepare_invoke().load(&mut con); + + assert_eq!(hash, Ok(script.get_hash().to_string())); + } + + #[test] + fn test_tuple_args() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + redis::cmd("HMSET") + .arg("my_key") + .arg(&[("field_1", 42), ("field_2", 23)]) + .execute(&mut con); + + assert_eq!( + redis::cmd("HGET") + .arg("my_key") + .arg("field_1") + .query(&mut con), + Ok(42) + ); + assert_eq!( + redis::cmd("HGET") + .arg("my_key") + .arg("field_2") + .query(&mut con), + Ok(23) + ); + } + + #[test] + fn test_nice_api() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.set("my_key", 42), Ok(())); + assert_eq!(con.get("my_key"), Ok(42)); + + let (k1, k2): (i32, i32) = redis::pipe() + .atomic() + .set("key_1", 42) + .ignore() + .set("key_2", 43) + .ignore() + .get("key_1") + .get("key_2") + .query(&mut con) + .unwrap(); + + assert_eq!(k1, 42); + assert_eq!(k2, 43); + } + + #[test] + fn test_auto_m_versions() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.mset(&[("key1", 1), ("key2", 2)]), Ok(())); + assert_eq!(con.get(&["key1", "key2"]), Ok((1, 2))); + assert_eq!(con.get(vec!["key1", "key2"]), Ok((1, 2))); + assert_eq!(con.get(vec!["key1", "key2"]), Ok((1, 2))); + } + + #[test] + fn test_nice_hash_api() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!( + con.hset_multiple("my_hash", &[("f1", 1), ("f2", 2), ("f3", 4), ("f4", 8)]), + Ok(()) + ); + + let hm: HashMap = con.hgetall("my_hash").unwrap(); + assert_eq!(hm.get("f1"), Some(&1)); + assert_eq!(hm.get("f2"), Some(&2)); + assert_eq!(hm.get("f3"), Some(&4)); + assert_eq!(hm.get("f4"), Some(&8)); + assert_eq!(hm.len(), 4); + + let hm: BTreeMap = con.hgetall("my_hash").unwrap(); + assert_eq!(hm.get("f1"), Some(&1)); + assert_eq!(hm.get("f2"), Some(&2)); + assert_eq!(hm.get("f3"), Some(&4)); + assert_eq!(hm.get("f4"), Some(&8)); + assert_eq!(hm.len(), 4); + + let v: Vec<(String, isize)> = con.hgetall("my_hash").unwrap(); + assert_eq!( + v, + vec![ + ("f1".to_string(), 1), + ("f2".to_string(), 2), + ("f3".to_string(), 4), + ("f4".to_string(), 8), + ] + ); + + assert_eq!(con.hget("my_hash", &["f2", "f4"]), Ok((2, 8))); + assert_eq!(con.hincr("my_hash", "f1", 1), Ok(2)); + assert_eq!(con.hincr("my_hash", "f2", 1.5f32), Ok(3.5f32)); + assert_eq!(con.hexists("my_hash", "f2"), Ok(true)); + assert_eq!(con.hdel("my_hash", &["f1", "f2"]), Ok(())); + assert_eq!(con.hexists("my_hash", "f2"), Ok(false)); + + let iter: redis::Iter<'_, (String, isize)> = con.hscan("my_hash").unwrap(); + let mut found = HashSet::new(); + for item in iter { + found.insert(item); + } + + assert_eq!(found.len(), 2); + assert!(found.contains(&("f3".to_string(), 4))); + assert!(found.contains(&("f4".to_string(), 8))); + } + + #[test] + fn test_nice_list_api() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.rpush("my_list", &[1, 2, 3, 4]), Ok(4)); + assert_eq!(con.rpush("my_list", &[5, 6, 7, 8]), Ok(8)); + assert_eq!(con.llen("my_list"), Ok(8)); + + assert_eq!(con.lpop("my_list", Default::default()), Ok(1)); + assert_eq!(con.llen("my_list"), Ok(7)); + + assert_eq!(con.lrange("my_list", 0, 2), Ok((2, 3, 4))); + + assert_eq!(con.lset("my_list", 0, 4), Ok(true)); + assert_eq!(con.lrange("my_list", 0, 2), Ok((4, 3, 4))); + + #[cfg(not(windows))] + //Windows version of redis is limited to v3.x + { + let my_list: Vec = con.lrange("my_list", 0, 10).expect("To get range"); + assert_eq!( + con.lpop("my_list", core::num::NonZeroUsize::new(10)), + Ok(my_list) + ); + } + } + + #[test] + fn test_tuple_decoding_regression() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.del("my_zset"), Ok(())); + assert_eq!(con.zadd("my_zset", "one", 1), Ok(1)); + assert_eq!(con.zadd("my_zset", "two", 2), Ok(1)); + + let vec: Vec<(String, u32)> = con.zrangebyscore_withscores("my_zset", 0, 10).unwrap(); + assert_eq!(vec.len(), 2); + + assert_eq!(con.del("my_zset"), Ok(1)); + + let vec: Vec<(String, u32)> = con.zrangebyscore_withscores("my_zset", 0, 10).unwrap(); + assert_eq!(vec.len(), 0); + } + + #[test] + fn test_bit_operations() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.setbit("bitvec", 10, true), Ok(false)); + assert_eq!(con.getbit("bitvec", 10), Ok(true)); + } + + #[test] + fn test_redis_server_down() { + let mut ctx = TestContext::new(); + let mut con = ctx.connection(); + + let ping = redis::cmd("PING").query::(&mut con); + assert_eq!(ping, Ok("PONG".into())); + + ctx.stop_server(); + + let ping = redis::cmd("PING").query::(&mut con); + + assert!(ping.is_err()); + eprintln!("{}", ping.unwrap_err()); + assert!(!con.is_open()); + } + + #[test] + fn test_zinterstore_weights() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con + .zadd_multiple("zset1", &[(1, "one"), (2, "two"), (4, "four")]) + .unwrap(); + let _: () = con + .zadd_multiple("zset2", &[(1, "one"), (2, "two"), (3, "three")]) + .unwrap(); + + // zinterstore_weights + assert_eq!( + con.zinterstore_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(2) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "5".to_string()), + ("two".to_string(), "10".to_string()) + ]) + ); + + // zinterstore_min_weights + assert_eq!( + con.zinterstore_min_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(2) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "2".to_string()), + ("two".to_string(), "4".to_string()), + ]) + ); + + // zinterstore_max_weights + assert_eq!( + con.zinterstore_max_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(2) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "3".to_string()), + ("two".to_string(), "6".to_string()), + ]) + ); + } + + #[test] + fn test_zunionstore_weights() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con + .zadd_multiple("zset1", &[(1, "one"), (2, "two")]) + .unwrap(); + let _: () = con + .zadd_multiple("zset2", &[(1, "one"), (2, "two"), (3, "three")]) + .unwrap(); + + // zunionstore_weights + assert_eq!( + con.zunionstore_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(3) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "5".to_string()), + ("three".to_string(), "9".to_string()), + ("two".to_string(), "10".to_string()) + ]) + ); + // test converting to double + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), 5.0), + ("three".to_string(), 9.0), + ("two".to_string(), 10.0) + ]) + ); + + // zunionstore_min_weights + assert_eq!( + con.zunionstore_min_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(3) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "2".to_string()), + ("two".to_string(), "4".to_string()), + ("three".to_string(), "9".to_string()) + ]) + ); + + // zunionstore_max_weights + assert_eq!( + con.zunionstore_max_weights("out", &[("zset1", 2), ("zset2", 3)]), + Ok(3) + ); + + assert_eq!( + con.zrange_withscores("out", 0, -1), + Ok(vec![ + ("one".to_string(), "3".to_string()), + ("two".to_string(), "6".to_string()), + ("three".to_string(), "9".to_string()) + ]) + ); + } + + #[test] + fn test_zrembylex() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myzset"; + assert_eq!( + con.zadd_multiple( + setname, + &[ + (0, "apple"), + (0, "banana"), + (0, "carrot"), + (0, "durian"), + (0, "eggplant"), + (0, "grapes"), + ], + ), + Ok(6) + ); + + // Will remove "banana", "carrot", "durian" and "eggplant" + let num_removed: u32 = con.zrembylex(setname, "[banana", "[eggplant").unwrap(); + assert_eq!(4, num_removed); + + let remaining: Vec = con.zrange(setname, 0, -1).unwrap(); + assert_eq!(remaining, vec!["apple".to_string(), "grapes".to_string()]); + } + + // Requires redis-server >= 6.2.0. + // Not supported with the current appveyor/windows binary deployed. + #[cfg(not(target_os = "windows"))] + #[test] + fn test_zrandmember() { + use redis::ProtocolVersion; + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myzrandset"; + let () = con.zadd(setname, "one", 1).unwrap(); + + let result: String = con.zrandmember(setname, None).unwrap(); + assert_eq!(result, "one".to_string()); + + let result: Vec = con.zrandmember(setname, Some(1)).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0], "one".to_string()); + + let result: Vec = con.zrandmember(setname, Some(2)).unwrap(); + assert_eq!(result.len(), 1); + assert_eq!(result[0], "one".to_string()); + + assert_eq!( + con.zadd_multiple( + setname, + &[(2, "two"), (3, "three"), (4, "four"), (5, "five")] + ), + Ok(4) + ); + + let results: Vec = con.zrandmember(setname, Some(5)).unwrap(); + assert_eq!(results.len(), 5); + + let results: Vec = con.zrandmember(setname, Some(-5)).unwrap(); + assert_eq!(results.len(), 5); + + if ctx.protocol == ProtocolVersion::RESP2 { + let results: Vec = con.zrandmember_withscores(setname, 5).unwrap(); + assert_eq!(results.len(), 10); + + let results: Vec = con.zrandmember_withscores(setname, -5).unwrap(); + assert_eq!(results.len(), 10); + } + + let results: Vec<(String, f64)> = con.zrandmember_withscores(setname, 5).unwrap(); + assert_eq!(results.len(), 5); + + let results: Vec<(String, f64)> = con.zrandmember_withscores(setname, -5).unwrap(); + assert_eq!(results.len(), 5); + } + + #[test] + fn test_sismember() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myset"; + assert_eq!(con.sadd(setname, &["a"]), Ok(1)); + + let result: bool = con.sismember(setname, &["a"]).unwrap(); + assert!(result); + + let result: bool = con.sismember(setname, &["b"]).unwrap(); + assert!(!result); + } + + // Requires redis-server >= 6.2.0. + // Not supported with the current appveyor/windows binary deployed. + #[cfg(not(target_os = "windows"))] + #[test] + fn test_smismember() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let setname = "myset"; + assert_eq!(con.sadd(setname, &["a", "b", "c"]), Ok(3)); + let results: Vec = con.smismember(setname, &["0", "a", "b", "c", "x"]).unwrap(); + assert_eq!(results, vec![false, true, true, true, false]); + } + + #[test] + fn test_object_commands() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set("object_key_str", "object_value_str").unwrap(); + let _: () = con.set("object_key_int", 42).unwrap(); + + assert_eq!( + con.object_encoding::<_, String>("object_key_str").unwrap(), + "embstr" + ); + + assert_eq!( + con.object_encoding::<_, String>("object_key_int").unwrap(), + "int" + ); + + assert!(con.object_idletime::<_, i32>("object_key_str").unwrap() <= 1); + assert_eq!(con.object_refcount::<_, i32>("object_key_str").unwrap(), 1); + + // Needed for OBJECT FREQ and can't be set before object_idletime + // since that will break getting the idletime before idletime adjuts + redis::cmd("CONFIG") + .arg("SET") + .arg(b"maxmemory-policy") + .arg("allkeys-lfu") + .execute(&mut con); + + let _: () = con.get("object_key_str").unwrap(); + // since maxmemory-policy changed, freq should reset to 1 since we only called + // get after that + assert_eq!(con.object_freq::<_, i32>("object_key_str").unwrap(), 1); + } + + #[test] + fn test_mget() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set(1, "1").unwrap(); + let data: Vec = con.mget(&[1]).unwrap(); + assert_eq!(data, vec!["1"]); + + let _: () = con.set(2, "2").unwrap(); + let data: Vec = con.mget(&[1, 2]).unwrap(); + assert_eq!(data, vec!["1", "2"]); + + let data: Vec> = con.mget(&[4]).unwrap(); + assert_eq!(data, vec![None]); + + let data: Vec> = con.mget(&[2, 4]).unwrap(); + assert_eq!(data, vec![Some("2".to_string()), None]); + } + + #[test] + fn test_variable_length_get() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let _: () = con.set(1, "1").unwrap(); + let keys = vec![1]; + assert_eq!(keys.len(), 1); + let data: Vec = con.get(&keys).unwrap(); + assert_eq!(data, vec!["1"]); + } + + #[test] + fn test_multi_generics() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.sadd(b"set1", vec![5, 42]), Ok(2)); + assert_eq!(con.sadd(999_i64, vec![42, 123]), Ok(2)); + let _: () = con.rename(999_i64, b"set2").unwrap(); + assert_eq!(con.sunionstore("res", &[b"set1", b"set2"]), Ok(3)); + } + + #[test] + fn test_set_options_with_get() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + let opts = SetOptions::default().get(true); + let data: Option = con.set_options(1, "1", opts).unwrap(); + assert_eq!(data, None); + + let opts = SetOptions::default().get(true); + let data: Option = con.set_options(1, "1", opts).unwrap(); + assert_eq!(data, Some("1".to_string())); + } + + #[test] + fn test_set_options_options() { + let empty = SetOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let opts = SetOptions::default() + .conditional_set(ExistenceCheck::NX) + .get(true) + .with_expiration(SetExpiry::PX(1000)); + + assert_args!(&opts, "NX", "GET", "PX", "1000"); + + let opts = SetOptions::default() + .conditional_set(ExistenceCheck::XX) + .get(true) + .with_expiration(SetExpiry::PX(1000)); + + assert_args!(&opts, "XX", "GET", "PX", "1000"); + + let opts = SetOptions::default() + .conditional_set(ExistenceCheck::XX) + .with_expiration(SetExpiry::KEEPTTL); + + assert_args!(&opts, "XX", "KEEPTTL"); + + let opts = SetOptions::default() + .conditional_set(ExistenceCheck::XX) + .with_expiration(SetExpiry::EXAT(100)); + + assert_args!(&opts, "XX", "EXAT", "100"); + + let opts = SetOptions::default().with_expiration(SetExpiry::EX(1000)); + + assert_args!(&opts, "EX", "1000"); + } + + #[test] + fn test_blocking_sorted_set_api() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // setup version & input data followed by assertions that take into account Redis version + // BZPOPMIN & BZPOPMAX are available from Redis version 5.0.0 + // BZMPOP is available from Redis version 7.0.0 + + let redis_version = ctx.get_version(); + assert!(redis_version.0 >= 5); + + assert_eq!(con.zadd("a", "1a", 1), Ok(())); + assert_eq!(con.zadd("b", "2b", 2), Ok(())); + assert_eq!(con.zadd("c", "3c", 3), Ok(())); + assert_eq!(con.zadd("d", "4d", 4), Ok(())); + assert_eq!(con.zadd("a", "5a", 5), Ok(())); + assert_eq!(con.zadd("b", "6b", 6), Ok(())); + assert_eq!(con.zadd("c", "7c", 7), Ok(())); + assert_eq!(con.zadd("d", "8d", 8), Ok(())); + + let min = con.bzpopmin::<&str, (String, String, String)>("b", 0.0); + let max = con.bzpopmax::<&str, (String, String, String)>("b", 0.0); + + assert_eq!( + min.unwrap(), + (String::from("b"), String::from("2b"), String::from("2")) + ); + assert_eq!( + max.unwrap(), + (String::from("b"), String::from("6b"), String::from("6")) + ); + + if redis_version.0 >= 7 { + let min = con.bzmpop_min::<&str, (String, Vec>)>( + 0.0, + vec!["a", "b", "c", "d"].as_slice(), + 1, + ); + let max = con.bzmpop_max::<&str, (String, Vec>)>( + 0.0, + vec!["a", "b", "c", "d"].as_slice(), + 1, + ); + + assert_eq!( + min.unwrap().1[0][0], + (String::from("1a"), String::from("1")) + ); + assert_eq!( + max.unwrap().1[0][0], + (String::from("5a"), String::from("5")) + ); + } + } + + #[test] + fn test_set_client_name_by_config() { + const CLIENT_NAME: &str = "TEST_CLIENT_NAME"; + + let ctx = TestContext::with_client_name(CLIENT_NAME); + let mut con = ctx.connection(); + + let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], CLIENT_NAME, + "Incorrect client name, expecting: {}, got {}", + CLIENT_NAME, client_attrs["name"] + ); + } + + #[test] + fn test_push_manager() { + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + let mut con = ctx.connection(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + con.get_push_manager().replace_sender(tx.clone()); + let _ = cmd("CLIENT") + .arg("TRACKING") + .arg("ON") + .query::<()>(&mut con) + .unwrap(); + let pipe = build_simple_pipeline_for_invalidation(); + for _ in 0..10 { + let _: RedisResult<()> = pipe.query(&mut con); + let _: i32 = con.get("key_1").unwrap(); + let PushInfo { kind, data } = rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + } + let (new_tx, mut new_rx) = tokio::sync::mpsc::unbounded_channel(); + con.get_push_manager().replace_sender(new_tx.clone()); + drop(rx); + let _: RedisResult<()> = pipe.query(&mut con); + let _: i32 = con.get("key_1").unwrap(); + let PushInfo { kind, data } = new_rx.try_recv().unwrap(); + assert_eq!( + ( + PushKind::Invalidate, + vec![Value::Array(vec![Value::BulkString( + "key_1".as_bytes().to_vec() + )])] + ), + (kind, data) + ); + + { + drop(new_rx); + for _ in 0..10 { + let _: RedisResult<()> = pipe.query(&mut con); + let v: i32 = con.get("key_1").unwrap(); + assert_eq!(v, 42); + } + } + } + + #[test] + fn test_push_manager_disconnection() { + let ctx = TestContext::new(); + if ctx.protocol == ProtocolVersion::RESP2 { + return; + } + let mut con = ctx.connection(); + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel(); + con.get_push_manager().replace_sender(tx.clone()); + + let _: () = con.set("A", "1").unwrap(); + assert_eq!(rx.try_recv().unwrap_err(), TryRecvError::Empty); + drop(ctx); + let x: RedisResult<()> = con.set("A", "1"); + assert!(x.is_err()); + assert_eq!(rx.try_recv().unwrap().kind, PushKind::Disconnection); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_bignum.rs b/glide-core/redis-rs/redis/tests/test_bignum.rs new file mode 100644 index 0000000000..20beefbc66 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_bignum.rs @@ -0,0 +1,61 @@ +#![cfg(any( + feature = "rust_decimal", + feature = "bigdecimal", + feature = "num-bigint" +))] +use redis::{ErrorKind, FromRedisValue, RedisResult, ToRedisArgs, Value}; +use std::str::FromStr; + +fn test(content: &str) +where + T: FromRedisValue + + ToRedisArgs + + std::str::FromStr + + std::convert::From + + std::cmp::PartialEq + + std::fmt::Debug, + ::Err: std::fmt::Debug, +{ + let v: RedisResult = + FromRedisValue::from_redis_value(&Value::BulkString(Vec::from(content))); + assert_eq!(v, Ok(T::from_str(content).unwrap())); + + let arg = ToRedisArgs::to_redis_args(&v.unwrap()); + assert_eq!(arg[0], Vec::from(content)); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(0)); + assert_eq!(v.unwrap(), T::from(0u32)); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(42)); + assert_eq!(v.unwrap(), T::from(42u32)); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Okay); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); +} + +#[test] +#[cfg(feature = "rust_decimal")] +fn test_rust_decimal() { + test::("-79228162514264.337593543950335"); +} + +#[test] +#[cfg(feature = "bigdecimal")] +fn test_bigdecimal() { + test::("-14272476927059598810582859.69449495136382746623"); +} + +#[test] +#[cfg(feature = "num-bigint")] +fn test_bigint() { + test::("-1427247692705959881058285969449495136382746623"); +} + +#[test] +#[cfg(feature = "num-bigint")] +fn test_biguint() { + test::("1427247692705959881058285969449495136382746623"); +} diff --git a/glide-core/redis-rs/redis/tests/test_cluster.rs b/glide-core/redis-rs/redis/tests/test_cluster.rs new file mode 100644 index 0000000000..cbeddd2fe4 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_cluster.rs @@ -0,0 +1,1093 @@ +#![cfg(feature = "cluster")] +mod support; + +#[cfg(test)] +mod cluster { + use std::sync::{ + atomic::{self, AtomicI32, Ordering}, + Arc, + }; + + use crate::support::*; + use redis::{ + cluster::{cluster_pipe, ClusterClient}, + cmd, parse_redis_value, Commands, ConnectionLike, ErrorKind, ProtocolVersion, RedisError, + Value, + }; + + #[test] + fn test_cluster_basics() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_cluster_with_username_and_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .username(RedisCluster::username().to_string()) + .password(RedisCluster::password().to_string()) + }, + false, + ); + cluster.disable_default_user(); + + let mut con = cluster.connection(); + + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_cluster_with_bad_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .username(RedisCluster::username().to_string()) + .password("not the right password".to_string()) + }, + false, + ); + assert!(cluster.client.get_connection(None).is_err()); + } + + #[test] + fn test_cluster_read_from_replicas() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| builder.read_from_replicas(), + false, + ); + let mut con = cluster.connection(); + + // Write commands would go to the primary nodes + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + // Read commands would go to the replica nodes + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_cluster_eval() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + let rv = redis::cmd("EVAL") + .arg( + r#" + redis.call("SET", KEYS[1], "1"); + redis.call("SET", KEYS[2], "2"); + return redis.call("MGET", KEYS[1], KEYS[2]); + "#, + ) + .arg("2") + .arg("{x}a") + .arg("{x}b") + .query(&mut con); + + assert_eq!(rv, Ok(("1".to_string(), "2".to_string()))); + } + + #[test] + fn test_cluster_resp3() { + if use_protocol() == ProtocolVersion::RESP2 { + return; + } + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.connection(); + + let hello: std::collections::HashMap = + redis::cmd("HELLO").query(&mut connection).unwrap(); + assert_eq!(hello.get("proto").unwrap(), &Value::Int(3)); + + let _: () = connection.hset("hash", "foo", "baz").unwrap(); + let _: () = connection.hset("hash", "bar", "foobar").unwrap(); + let result: Value = connection.hgetall("hash").unwrap(); + + assert_eq!( + result, + Value::Map(vec![ + ( + Value::BulkString("foo".as_bytes().to_vec()), + Value::BulkString("baz".as_bytes().to_vec()) + ), + ( + Value::BulkString("bar".as_bytes().to_vec()), + Value::BulkString("foobar".as_bytes().to_vec()) + ) + ]) + ); + } + + #[test] + fn test_cluster_multi_shard_commands() { + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.connection(); + + let res: String = connection + .mset(&[("foo", "bar"), ("bar", "foo"), ("baz", "bazz")]) + .unwrap(); + assert_eq!(res, "OK"); + let res: Vec = connection.mget(&["baz", "foo", "bar"]).unwrap(); + assert_eq!(res, vec!["bazz", "bar", "foo"]); + } + + #[test] + #[cfg(feature = "script")] + fn test_cluster_script() { + let cluster = TestClusterContext::new(3, 0); + let mut con = cluster.connection(); + + let script = redis::Script::new( + r#" + redis.call("SET", KEYS[1], "1"); + redis.call("SET", KEYS[2], "2"); + return redis.call("MGET", KEYS[1], KEYS[2]); + "#, + ); + + let rv = script.key("{x}a").key("{x}b").invoke(&mut con); + assert_eq!(rv, Ok(("1".to_string(), "2".to_string()))); + } + + #[test] + fn test_cluster_pipeline() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + + let resp = cluster_pipe() + .cmd("SET") + .arg("key_1") + .arg(42) + .query::>(&mut con) + .unwrap(); + + assert_eq!(resp, vec!["OK".to_string()]); + } + + #[test] + fn test_cluster_pipeline_multiple_keys() { + use redis::FromRedisValue; + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + + let resp = cluster_pipe() + .cmd("HSET") + .arg("hash_1") + .arg("key_1") + .arg("value_1") + .cmd("ZADD") + .arg("zset") + .arg(1) + .arg("zvalue_2") + .query::>(&mut con) + .unwrap(); + + assert_eq!(resp, vec![1i64, 1i64]); + + let resp = cluster_pipe() + .cmd("HGET") + .arg("hash_1") + .arg("key_1") + .cmd("ZCARD") + .arg("zset") + .query::>(&mut con) + .unwrap(); + + let resp_1: String = FromRedisValue::from_redis_value(&resp[0]).unwrap(); + assert_eq!(resp_1, "value_1".to_string()); + + let resp_2: usize = FromRedisValue::from_redis_value(&resp[1]).unwrap(); + assert_eq!(resp_2, 1); + } + + #[test] + fn test_cluster_pipeline_invalid_command() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + + let err = cluster_pipe() + .cmd("SET") + .arg("foo") + .arg(42) + .ignore() + .cmd(" SCRIPT kill ") + .query::<()>(&mut con) + .unwrap_err(); + + assert_eq!( + err.to_string(), + "This command cannot be safely routed in cluster mode - ClientError: Command 'SCRIPT KILL' can't be executed in a cluster pipeline." + ); + + let err = cluster_pipe().keys("*").query::<()>(&mut con).unwrap_err(); + + assert_eq!( + err.to_string(), + "This command cannot be safely routed in cluster mode - ClientError: Command 'KEYS' can't be executed in a cluster pipeline." + ); + } + + #[test] + fn test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() { + let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name"; + + let MockEnv { mut connection, .. } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("".as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = cmd("GET").arg("test").query::(&mut connection); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() { + let name = + "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name"; + + let MockEnv { mut connection, .. } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![Value::Nil, Value::Int(6379)]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = cmd("GET").arg("test").query::(&mut connection); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + fn test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name( + ) { + let name = "test_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name"; + + let MockEnv { mut connection, .. } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![ + Value::Array(vec![ + Value::Int(0), + Value::Int(7000), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Array(vec![ + Value::Int(7001), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("?".as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = cmd("GET").arg("test").query::(&mut connection); + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + fn test_cluster_pipeline_command_ordering() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + let mut pipe = cluster_pipe(); + + let mut queries = Vec::new(); + let mut expected = Vec::new(); + for i in 0..100 { + queries.push(format!("foo{i}")); + expected.push(format!("bar{i}")); + pipe.set(&queries[i], &expected[i]).ignore(); + } + pipe.execute(&mut con); + + pipe.clear(); + for q in &queries { + pipe.get(q); + } + + let got = pipe.query::>(&mut con).unwrap(); + assert_eq!(got, expected); + } + + #[test] + #[ignore] // Flaky + fn test_cluster_pipeline_ordering_with_improper_command() { + let cluster = TestClusterContext::new(3, 0); + cluster.wait_for_cluster_up(); + let mut con = cluster.connection(); + let mut pipe = cluster_pipe(); + + let mut queries = Vec::new(); + let mut expected = Vec::new(); + for i in 0..10 { + if i == 5 { + pipe.cmd("hset").arg("foo").ignore(); + } else { + let query = format!("foo{i}"); + let r = format!("bar{i}"); + pipe.set(&query, &r).ignore(); + queries.push(query); + expected.push(r); + } + } + pipe.query::<()>(&mut con).unwrap_err(); + + std::thread::sleep(std::time::Duration::from_secs(5)); + + pipe.clear(); + for q in &queries { + pipe.get(q); + } + + let got = pipe.query::>(&mut con).unwrap(); + assert_eq!(got, expected); + } + + #[test] + fn test_cluster_retries() { + let name = "tryagain"; + + let requests = atomic::AtomicUsize::new(0); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5), + name, + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + + match requests.fetch_add(1, atomic::Ordering::SeqCst) { + 0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")), + _ => Err(Ok(Value::BulkString(b"123".to_vec()))), + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_cluster_exhaust_retries() { + let name = "tryagain_exhaust_retries"; + + let requests = Arc::new(atomic::AtomicUsize::new(0)); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + { + let requests = requests.clone(); + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + requests.fetch_add(1, atomic::Ordering::SeqCst); + Err(parse_redis_value(b"-TRYAGAIN mock\r\n")) + } + }, + ); + + let result = cmd("GET").arg("test").query::>(&mut connection); + + match result { + Ok(_) => panic!("result should be an error"), + Err(e) => match e.kind() { + ErrorKind::TryAgain => {} + _ => panic!("Expected TryAgain but got {:?}", e.kind()), + }, + } + assert_eq!(requests.load(atomic::Ordering::SeqCst), 3); + } + + #[test] + fn test_cluster_move_error_when_new_node_is_added() { + let name = "rebuild_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + match i { + // Respond that the key exists on a node that does not yet have a connection: + 0 => Err(parse_redis_value(b"-MOVED 123\r\n")), + // Respond with the new masters + 1 => Err(Ok(Value::Array(vec![ + Value::Array(vec![ + Value::Int(0), + Value::Int(1), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Array(vec![ + Value::Int(2), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))), + _ => { + // Check that the correct node receives the request after rebuilding + assert_eq!(port, 6380); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + } + }); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_cluster_ask_redirect() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + { + move |cmd: &[u8], port| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + let count = completed.fetch_add(1, Ordering::SeqCst); + match port { + 6379 => match count { + 0 => Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")), + _ => panic!("Node should not be called now"), + }, + 6380 => match count { + 1 => { + assert!(contains_slice(cmd, b"ASKING")); + Err(Ok(Value::Okay)) + } + 2 => { + assert!(contains_slice(cmd, b"GET")); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + _ => panic!("Node should not be called now"), + }, + _ => panic!("Wrong node"), + } + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_cluster_ask_error_when_new_node_is_added() { + let name = "ask_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + match i { + // Respond that the key exists on a node that does not yet have a connection: + 0 => Err(parse_redis_value( + format!("-ASK 123 {name}:6380\r\n").as_bytes(), + )), + 1 => { + assert_eq!(port, 6380); + assert!(contains_slice(cmd, b"ASKING")); + Err(Ok(Value::Okay)) + } + 2 => { + assert_eq!(port, 6380); + assert!(contains_slice(cmd, b"GET")); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + _ => { + panic!("Unexpected request: {:?}", cmd); + } + } + }); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_cluster_replica_read() { + let name = "node"; + + // requests should route to replica + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + + match port { + 6380 => Err(Ok(Value::BulkString(b"123".to_vec()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + assert_eq!(value, Ok(Some(123))); + + // requests should route to primary + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + match port { + 6379 => Err(Ok(Value::SimpleString("OK".into()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = cmd("SET") + .arg("test") + .arg("123") + .query::>(&mut connection); + assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned())))); + } + + #[test] + fn test_cluster_io_error() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + move |cmd: &[u8], port| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + match port { + 6380 => panic!("Node should not be called"), + _ => match completed.fetch_add(1, Ordering::SeqCst) { + 0..=1 => Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "mock-io-error", + )))), + _ => Err(Ok(Value::BulkString(b"123".to_vec()))), + }, + } + }, + ); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_cluster_non_retryable_error_should_not_retry() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { mut connection, .. } = MockEnv::new(name, { + let completed = completed.clone(); + move |cmd: &[u8], _| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + completed.fetch_add(1, Ordering::SeqCst); + Err(Err((ErrorKind::ReadOnly, "").into())) + } + }); + + let value = cmd("GET").arg("test").query::>(&mut connection); + + match value { + Ok(_) => panic!("result should be an error"), + Err(e) => match e.kind() { + ErrorKind::ReadOnly => {} + _ => panic!("Expected ReadOnly but got {:?}", e.kind()), + }, + } + assert_eq!(completed.load(Ordering::SeqCst), 1); + } + + fn test_cluster_fan_out( + command: &'static str, + expected_ports: Vec, + slots_config: Option>, + ) { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let mut cmd = redis::Cmd::new(); + for arg in command.split_whitespace() { + cmd.arg(arg); + } + let packed_cmd = cmd.get_packed_command(); + // requests should route to replica + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + if received_cmd == packed_cmd { + ports_clone.lock().unwrap().push(port); + return Err(Ok(Value::SimpleString("OK".into()))); + } + Ok(()) + }, + ); + + let _ = cmd.query::>(&mut connection); + found_ports.lock().unwrap().sort(); + // MockEnv creates 2 mock connections. + assert_eq!(*found_ports.lock().unwrap(), expected_ports); + } + + #[test] + fn test_cluster_fan_out_to_all_primaries() { + test_cluster_fan_out("FLUSHALL", vec![6379, 6381], None); + } + + #[test] + fn test_cluster_fan_out_to_all_nodes() { + test_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None); + } + + #[test] + fn test_cluster_fan_out_out_once_to_each_primary_when_no_replicas_are_available() { + test_cluster_fan_out( + "CONFIG SET", + vec![6379, 6381], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: Vec::new(), + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: Vec::new(), + slot_range: (8192..16383), + }, + ]), + ); + } + + #[test] + fn test_cluster_fan_out_out_once_even_if_primary_has_multiple_slot_ranges() { + test_cluster_fan_out( + "CONFIG SET", + vec![6379, 6380, 6381, 6382], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (0..4000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (4001..8191), + }, + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (8192..8200), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8201..16383), + }, + ]), + ); + } + + #[test] + fn test_cluster_split_multi_shard_command_and_combine_arrays_of_values() { + let name = "test_cluster_split_multi_shard_command_and_combine_arrays_of_values"; + let mut cmd = cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz"); + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + let cmd_str = std::str::from_utf8(received_cmd).unwrap(); + let results = ["foo", "bar", "baz"] + .iter() + .filter_map(|expected_key| { + if cmd_str.contains(expected_key) { + Some(Value::BulkString( + format!("{expected_key}-{port}").into_bytes(), + )) + } else { + None + } + }) + .collect(); + Err(Ok(Value::Array(results))) + }, + ); + + let result = cmd.query::>(&mut connection).unwrap(); + assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]); + } + + #[test] + fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests() { + let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests"; + let mut pipeline = redis::pipe(); + pipeline.atomic().set("foo", "bar").get("foo"); + let packed_pipeline = pipeline.get_packed_pipeline(); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + let results = vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::Array(vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ]), + ]; + return Err(Ok(Value::Array(results))); + } + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + format!("wrong port: {port}"), + )))) + }, + ); + + let result = connection + .req_packed_commands(&packed_pipeline, 3, 1) + .unwrap(); + assert_eq!( + result, + vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ] + ); + } + + #[test] + fn test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2() { + let name = "test_cluster_route_correctly_on_packed_transaction_with_single_node_requests2"; + let mut pipeline = redis::pipe(); + pipeline.atomic().set("foo", "bar").get("foo"); + let packed_pipeline = pipeline.get_packed_pipeline(); + let results = vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::BulkString("QUEUED".as_bytes().to_vec()), + Value::Array(vec![ + Value::BulkString("OK".as_bytes().to_vec()), + Value::BulkString("bar".as_bytes().to_vec()), + ]), + ]; + let expected_result = Value::Array(results); + let cloned_result = expected_result.clone(); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + return Err(Ok(cloned_result.clone())); + } + Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + format!("wrong port: {port}"), + )))) + }, + ); + + let result = connection.req_packed_command(&packed_pipeline).unwrap(); + assert_eq!(result, expected_result); + } + + #[test] + fn test_cluster_with_client_name() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.client_name(RedisCluster::client_name().to_string()), + false, + ); + let mut con = cluster.connection(); + let client_info: String = redis::cmd("CLIENT").arg("INFO").query(&mut con).unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], + RedisCluster::client_name(), + "Incorrect client name, expecting: {}, got {}", + RedisCluster::client_name(), + client_attrs["name"] + ); + } + + #[test] + fn test_cluster_can_be_created_with_partial_slot_coverage() { + let name = "test_cluster_can_be_created_with_partial_slot_coverage"; + let slots_config = Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0..8000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![], + slot_range: (8201..16380), + }, + ]); + + let MockEnv { + mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + Err(Ok(Value::SimpleString("PONG".into()))) + }, + ); + + let res = connection.req_command(&redis::cmd("PING")); + assert!(res.is_ok()); + } + + #[cfg(feature = "tls-rustls")] + mod mtls_test { + use super::*; + use crate::support::mtls_test::create_cluster_client_from_cluster; + use redis::ConnectionInfo; + + #[test] + fn test_cluster_basics_with_mtls() { + let cluster = TestClusterContext::new_with_mtls(3, 0); + + let client = create_cluster_client_from_cluster(&cluster, true).unwrap(); + let mut con = client.get_connection(None).unwrap(); + + redis::cmd("SET") + .arg("{x}key1") + .arg(b"foo") + .execute(&mut con); + redis::cmd("SET").arg(&["{x}key2", "bar"]).execute(&mut con); + + assert_eq!( + redis::cmd("MGET") + .arg(&["{x}key1", "{x}key2"]) + .query(&mut con), + Ok(("foo".to_string(), b"bar".to_vec())) + ); + } + + #[test] + fn test_cluster_should_not_connect_without_mtls() { + let cluster = TestClusterContext::new_with_mtls(3, 0); + + let client = create_cluster_client_from_cluster(&cluster, false).unwrap(); + let connection = client.get_connection(None); + + match cluster.cluster.servers.first().unwrap().connection_info() { + ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { .. }, + .. + } => { + if connection.is_ok() { + panic!("Must NOT be able to connect without client credentials if server accepts TLS"); + } + } + _ => { + if let Err(e) = connection { + panic!("Must be able to connect without client credentials if server does NOT accept TLS: {e:?}"); + } + } + } + } + } +} diff --git a/glide-core/redis-rs/redis/tests/test_cluster_async.rs b/glide-core/redis-rs/redis/tests/test_cluster_async.rs new file mode 100644 index 0000000000..b690ed87b5 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_cluster_async.rs @@ -0,0 +1,4245 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "cluster-async")] +mod support; + +#[cfg(test)] +mod cluster_async { + use std::{ + collections::{HashMap, HashSet}, + net::{IpAddr, SocketAddr}, + str::from_utf8, + sync::{ + atomic::{self, AtomicBool, AtomicI32, AtomicU16, AtomicU32, Ordering}, + Arc, + }, + time::Duration, + }; + + use futures::prelude::*; + use futures_time::{future::FutureExt, task::sleep}; + use once_cell::sync::Lazy; + use std::ops::Add; + + use redis::{ + aio::{ConnectionLike, MultiplexedConnection}, + cluster::ClusterClient, + cluster_async::{testing::MANAGEMENT_CONN_NAME, ClusterConnection, Connect}, + cluster_routing::{ + MultipleNodeRoutingInfo, Route, RoutingInfo, SingleNodeRoutingInfo, SlotAddr, + }, + cluster_topology::{get_slot, DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES}, + cmd, from_owned_redis_value, parse_redis_value, AsyncCommands, Cmd, ErrorKind, + FromRedisValue, GlideConnectionOptions, InfoDict, IntoConnectionInfo, ProtocolVersion, + PubSubChannelOrPattern, PubSubSubscriptionInfo, PubSubSubscriptionKind, PushInfo, PushKind, + RedisError, RedisFuture, RedisResult, Script, Value, + }; + + use crate::support::*; + + use tokio::sync::mpsc; + fn broken_pipe_error() -> RedisError { + RedisError::from(std::io::Error::new( + std::io::ErrorKind::BrokenPipe, + "mock-io-error", + )) + } + + fn validate_subscriptions( + pubsub_subs: &PubSubSubscriptionInfo, + notifications_rx: &mut mpsc::UnboundedReceiver, + allow_disconnects: bool, + ) { + let mut subscribe_cnt = + if let Some(exact_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Exact) { + exact_subs.len() + } else { + 0 + }; + + let mut psubscribe_cnt = + if let Some(pattern_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Pattern) { + pattern_subs.len() + } else { + 0 + }; + + let mut ssubscribe_cnt = + if let Some(sharded_subs) = pubsub_subs.get(&PubSubSubscriptionKind::Sharded) { + sharded_subs.len() + } else { + 0 + }; + + for _ in 0..(subscribe_cnt + psubscribe_cnt + ssubscribe_cnt) { + let result = notifications_rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data: _ } = result.unwrap(); + assert!( + kind == PushKind::Subscribe + || kind == PushKind::PSubscribe + || kind == PushKind::SSubscribe + || if allow_disconnects { + kind == PushKind::Disconnection + } else { + false + } + ); + if kind == PushKind::Subscribe { + subscribe_cnt -= 1; + } else if kind == PushKind::PSubscribe { + psubscribe_cnt -= 1; + } else if kind == PushKind::SSubscribe { + ssubscribe_cnt -= 1; + } + } + + assert!(subscribe_cnt == 0); + assert!(psubscribe_cnt == 0); + assert!(ssubscribe_cnt == 0); + } + + #[test] + fn test_async_cluster_basic_cmd() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + cmd("SET") + .arg("test") + .arg("test_data") + .query_async(&mut connection) + .await?; + let res: String = cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, "test_data"); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_basic_eval() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let res: String = cmd("EVAL") + .arg(r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#) + .arg(1) + .arg("key") + .arg("test") + .query_async(&mut connection) + .await?; + assert_eq!(res, "test"); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_basic_script() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let res: String = Script::new( + r#"redis.call("SET", KEYS[1], ARGV[1]); return redis.call("GET", KEYS[1])"#, + ) + .key("key") + .arg("test") + .invoke_async(&mut connection) + .await?; + assert_eq!(res, "test"); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_route_flush_to_specific_node() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let _: () = connection.set("foo", "bar").await.unwrap(); + let _: () = connection.set("bar", "foo").await.unwrap(); + + let res: String = connection.get("foo").await.unwrap(); + assert_eq!(res, "bar".to_string()); + let res2: Option = connection.get("bar").await.unwrap(); + assert_eq!(res2, Some("foo".to_string())); + + let route = + redis::cluster_routing::Route::new(1, redis::cluster_routing::SlotAddr::Master); + let single_node_route = + redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route); + let routing = RoutingInfo::SingleNode(single_node_route); + assert_eq!( + connection + .route_command(&redis::cmd("FLUSHALL"), routing) + .await + .unwrap(), + Value::Okay + ); + let res: String = connection.get("foo").await.unwrap(); + assert_eq!(res, "bar".to_string()); + let res2: Option = connection.get("bar").await.unwrap(); + assert_eq!(res2, None); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_route_flush_to_node_by_address() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let mut cmd = redis::cmd("INFO"); + // The other sections change with time. + // TODO - after we remove support of redis 6, we can add more than a single section - .arg("Persistence").arg("Memory").arg("Replication") + cmd.arg("Clients"); + let value = connection + .route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + ) + .await + .unwrap(); + + let info_by_address = from_owned_redis_value::>(value).unwrap(); + // find the info of the first returned node + let (address, info) = info_by_address.into_iter().next().unwrap(); + let mut split_address = address.split(':'); + let host = split_address.next().unwrap().to_string(); + let port = split_address.next().unwrap().parse().unwrap(); + + let value = connection + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { host, port }), + ) + .await + .unwrap(); + let new_info = from_owned_redis_value::(value).unwrap(); + + assert_eq!(new_info, info); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_route_info_to_nodes() { + let cluster = TestClusterContext::new(12, 1); + + let split_to_addresses_and_info = |res| -> (Vec, Vec) { + if let Value::Map(values) = res { + let mut pairs: Vec<_> = values + .into_iter() + .map(|(key, value)| { + ( + redis::from_redis_value::(&key).unwrap(), + redis::from_redis_value::(&value).unwrap(), + ) + }) + .collect(); + pairs.sort_by(|(address1, _), (address2, _)| address1.cmp(address2)); + pairs.into_iter().unzip() + } else { + unreachable!("{:?}", res); + } + }; + + block_on_all(async move { + let cluster_addresses: Vec<_> = cluster + .cluster + .servers + .iter() + .map(|server| server.connection_info()) + .collect(); + let client = ClusterClient::builder(cluster_addresses.clone()) + .read_from_replicas() + .build()?; + let mut connection = client.get_async_connection(None).await?; + + let route_to_all_nodes = redis::cluster_routing::MultipleNodeRoutingInfo::AllNodes; + let routing = RoutingInfo::MultiNode((route_to_all_nodes, None)); + let res = connection + .route_command(&redis::cmd("INFO"), routing) + .await + .unwrap(); + let (addresses, infos) = split_to_addresses_and_info(res); + + let mut cluster_addresses: Vec<_> = cluster_addresses + .into_iter() + .map(|info| info.addr.to_string()) + .collect(); + cluster_addresses.sort(); + + assert_eq!(addresses.len(), 12); + assert_eq!(addresses, cluster_addresses); + assert_eq!(infos.len(), 12); + for i in 0..12 { + let split: Vec<_> = addresses[i].split(':').collect(); + assert!(infos[i].contains(&format!("tcp_port:{}", split[1]))); + } + + let route_to_all_primaries = + redis::cluster_routing::MultipleNodeRoutingInfo::AllMasters; + let routing = RoutingInfo::MultiNode((route_to_all_primaries, None)); + let res = connection + .route_command(&redis::cmd("INFO"), routing) + .await + .unwrap(); + let (addresses, infos) = split_to_addresses_and_info(res); + assert_eq!(addresses.len(), 6); + assert_eq!(infos.len(), 6); + // verify that all primaries have the correct port & host, and are marked as primaries. + for i in 0..6 { + assert!(cluster_addresses.contains(&addresses[i])); + let split: Vec<_> = addresses[i].split(':').collect(); + assert!(infos[i].contains(&format!("tcp_port:{}", split[1]))); + assert!(infos[i].contains("role:primary") || infos[i].contains("role:master")); + } + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_resp3() { + if use_protocol() == ProtocolVersion::RESP2 { + return; + } + block_on_all(async move { + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.async_connection(None).await; + + let hello: HashMap = redis::cmd("HELLO") + .query_async(&mut connection) + .await + .unwrap(); + assert_eq!(hello.get("proto").unwrap(), &Value::Int(3)); + + let _: () = connection.hset("hash", "foo", "baz").await.unwrap(); + let _: () = connection.hset("hash", "bar", "foobar").await.unwrap(); + let result: Value = connection.hgetall("hash").await.unwrap(); + + assert_eq!( + result, + Value::Map(vec![ + ( + Value::BulkString("foo".as_bytes().to_vec()), + Value::BulkString("baz".as_bytes().to_vec()) + ), + ( + Value::BulkString("bar".as_bytes().to_vec()), + Value::BulkString("foobar".as_bytes().to_vec()) + ) + ]) + ); + + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_basic_pipe() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let mut pipe = redis::pipe(); + pipe.add_command(cmd("SET").arg("test").arg("test_data").clone()); + pipe.add_command(cmd("SET").arg("{test}3").arg("test_data3").clone()); + pipe.query_async(&mut connection).await?; + let res: String = connection.get("test").await?; + assert_eq!(res, "test_data"); + let res: String = connection.get("{test}3").await?; + assert_eq!(res, "test_data3"); + Ok::<_, RedisError>(()) + }) + .unwrap() + } + + #[test] + fn test_async_cluster_multi_shard_commands() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + + let res: String = connection + .mset(&[("foo", "bar"), ("bar", "foo"), ("baz", "bazz")]) + .await?; + assert_eq!(res, "OK"); + let res: Vec = connection.mget(&["baz", "foo", "bar"]).await?; + assert_eq!(res, vec!["bazz", "bar", "foo"]); + Ok::<_, RedisError>(()) + }) + .unwrap() + } + + #[test] + fn test_async_cluster_basic_failover() { + block_on_all(async move { + test_failover(&TestClusterContext::new(6, 1), 10, 123, false).await; + Ok::<_, RedisError>(()) + }) + .unwrap() + } + + async fn do_failover( + redis: &mut redis::aio::MultiplexedConnection, + ) -> Result<(), anyhow::Error> { + cmd("CLUSTER").arg("FAILOVER").query_async(redis).await?; + Ok(()) + } + + // parameter `_mtls_enabled` can only be used if `feature = tls-rustls` is active + #[allow(dead_code)] + async fn test_failover( + env: &TestClusterContext, + requests: i32, + value: i32, + _mtls_enabled: bool, + ) { + let completed = Arc::new(AtomicI32::new(0)); + + let connection = env.async_connection(None).await; + let mut node_conns: Vec = Vec::new(); + + 'outer: loop { + node_conns.clear(); + let cleared_nodes = async { + for server in env.cluster.iter_servers() { + let addr = server.client_addr(); + + #[cfg(feature = "tls-rustls")] + let client = build_single_client( + server.connection_info(), + &server.tls_paths, + _mtls_enabled, + ) + .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}")); + + #[cfg(not(feature = "tls-rustls"))] + let client = redis::Client::open(server.connection_info()) + .unwrap_or_else(|e| panic!("Failed to connect to '{addr}': {e}")); + + let mut conn = client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + .unwrap_or_else(|e| panic!("Failed to get connection: {e}")); + + let info: InfoDict = redis::Cmd::new() + .arg("INFO") + .query_async(&mut conn) + .await + .expect("INFO"); + let role: String = info.get("role").expect("cluster role"); + + if role == "master" { + tokio::time::timeout(std::time::Duration::from_secs(3), async { + Ok(redis::Cmd::new() + .arg("FLUSHALL") + .query_async(&mut conn) + .await?) + }) + .await + .unwrap_or_else(|err| Err(anyhow::Error::from(err)))?; + } + + node_conns.push(conn); + } + Ok::<_, anyhow::Error>(()) + } + .await; + match cleared_nodes { + Ok(()) => break 'outer, + Err(err) => { + // Failed to clear the databases, retry + tracing::warn!("{}", err); + } + } + } + + (0..requests + 1) + .map(|i| { + let mut connection = connection.clone(); + let mut node_conns = node_conns.clone(); + let completed = completed.clone(); + async move { + if i == requests / 2 { + // Failover all the nodes, error only if all the failover requests error + let mut results = future::join_all( + node_conns + .iter_mut() + .map(|conn| Box::pin(do_failover(conn))), + ) + .await; + if results.iter().all(|res| res.is_err()) { + results.pop().unwrap() + } else { + Ok::<_, anyhow::Error>(()) + } + } else { + let key = format!("test-{value}-{i}"); + cmd("SET") + .arg(&key) + .arg(i) + .clone() + .query_async(&mut connection) + .await?; + let res: i32 = cmd("GET") + .arg(key) + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, i); + completed.fetch_add(1, Ordering::SeqCst); + Ok::<_, anyhow::Error>(()) + } + } + }) + .collect::>() + .try_collect() + .await + .unwrap_or_else(|e| panic!("{e}")); + + assert_eq!( + completed.load(Ordering::SeqCst), + requests, + "Some requests never completed!" + ); + } + + static ERROR: Lazy = Lazy::new(Default::default); + + #[derive(Clone)] + struct ErrorConnection { + inner: MultiplexedConnection, + } + + impl Connect for ErrorConnection { + fn connect<'a, T>( + info: T, + response_timeout: std::time::Duration, + connection_timeout: std::time::Duration, + socket_addr: Option, + glide_connection_options: GlideConnectionOptions, + ) -> RedisFuture<'a, (Self, Option)> + where + T: IntoConnectionInfo + Send + 'a, + { + Box::pin(async move { + let (inner, _ip) = MultiplexedConnection::connect( + info, + response_timeout, + connection_timeout, + socket_addr, + glide_connection_options, + ) + .await?; + Ok((ErrorConnection { inner }, None)) + }) + } + } + + impl ConnectionLike for ErrorConnection { + fn req_packed_command<'a>(&'a mut self, cmd: &'a Cmd) -> RedisFuture<'a, Value> { + if ERROR.load(Ordering::SeqCst) { + Box::pin(async move { Err(RedisError::from((redis::ErrorKind::Moved, "ERROR"))) }) + } else { + self.inner.req_packed_command(cmd) + } + } + + fn req_packed_commands<'a>( + &'a mut self, + pipeline: &'a redis::Pipeline, + offset: usize, + count: usize, + ) -> RedisFuture<'a, Vec> { + self.inner.req_packed_commands(pipeline, offset, count) + } + + fn get_db(&self) -> i64 { + self.inner.get_db() + } + + fn is_closed(&self) -> bool { + true + } + } + + #[test] + fn test_async_cluster_error_in_inner_connection() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all(async move { + let mut con = cluster.async_generic_connection::().await; + + ERROR.store(false, Ordering::SeqCst); + let r: Option = con.get("test").await?; + assert_eq!(r, None::); + + ERROR.store(true, Ordering::SeqCst); + + let result: RedisResult<()> = con.get("test").await; + assert_eq!( + result, + Err(RedisError::from((redis::ErrorKind::Moved, "ERROR"))) + ); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + fn test_async_cluster_async_std_basic_cmd() { + let cluster = TestClusterContext::new(3, 0); + + block_on_all_using_async_std(async { + let mut connection = cluster.async_connection(None).await; + redis::cmd("SET") + .arg("test") + .arg("test_data") + .query_async(&mut connection) + .await?; + redis::cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .map_ok(|res: String| { + assert_eq!(res, "test_data"); + }) + .await + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name() { + let name = + "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_without_host_name"; + + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("".as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Value>(&mut connection), + ); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name() { + let name = + "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_null_host_name"; + + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![Value::Nil, Value::Int(6379)]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Value>(&mut connection), + ); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + fn test_async_cluster_cannot_connect_to_server_with_unknown_host_name() { + let name = "test_async_cluster_cannot_connect_to_server_with_unknown_host_name"; + let handler = move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![Value::Array(vec![ + Value::Int(0), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("?".as_bytes().to_vec()), + Value::Int(6379), + ]), + ])]))) + } else { + Err(Ok(Value::Nil)) + } + }; + let client_builder = ClusterClient::builder(vec![&*format!("redis://{name}")]); + let client: ClusterClient = client_builder.build().unwrap(); + let _handler = MockConnectionBehavior::register_new(name, Arc::new(handler)); + let connection = client.get_generic_connection::(None); + assert!(connection.is_err()); + let err = connection.err().unwrap(); + assert!(err + .to_string() + .contains("Error parsing slots: No healthy node found")) + } + + #[test] + fn test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name( + ) { + let name = "test_async_cluster_can_connect_to_server_that_sends_cluster_slots_with_partial_nodes_with_unknown_host_name"; + + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::new(name, move |cmd: &[u8], _| { + if contains_slice(cmd, b"PING") { + Err(Ok(Value::SimpleString("OK".into()))) + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + Err(Ok(Value::Array(vec![ + Value::Array(vec![ + Value::Int(0), + Value::Int(7000), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Array(vec![ + Value::Int(7001), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString("?".as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))) + } else { + Err(Ok(Value::Nil)) + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Value>(&mut connection), + ); + + assert_eq!(value, Ok(Value::Nil)); + } + + #[test] + fn test_async_cluster_retries() { + let name = "tryagain"; + + let requests = atomic::AtomicUsize::new(0); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(5), + name, + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + + match requests.fetch_add(1, atomic::Ordering::SeqCst) { + 0..=4 => Err(parse_redis_value(b"-TRYAGAIN mock\r\n")), + _ => Err(Ok(Value::BulkString(b"123".to_vec()))), + } + }, + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_async_cluster_tryagain_exhaust_retries() { + let name = "tryagain_exhaust_retries"; + + let requests = Arc::new(atomic::AtomicUsize::new(0)); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + { + let requests = requests.clone(); + move |cmd: &[u8], _| { + respond_startup(name, cmd)?; + requests.fetch_add(1, atomic::Ordering::SeqCst); + Err(parse_redis_value(b"-TRYAGAIN mock\r\n")) + } + }, + ); + + let result = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + match result { + Ok(_) => panic!("result should be an error"), + Err(e) => match e.kind() { + ErrorKind::TryAgain => {} + _ => panic!("Expected TryAgain but got {:?}", e.kind()), + }, + } + assert_eq!(requests.load(atomic::Ordering::SeqCst), 3); + } + + // Obtain the view index associated with the node with [called_port] port + fn get_node_view_index(num_of_views: usize, ports: &Vec, called_port: u16) -> usize { + let port_index = ports + .iter() + .position(|&p| p == called_port) + .unwrap_or_else(|| { + panic!( + "CLUSTER SLOTS was called with unknown port: {called_port}; Known ports: {:?}", + ports + ) + }); + // If we have less views than nodes, use the last view + if port_index < num_of_views { + port_index + } else { + num_of_views - 1 + } + } + #[test] + fn test_async_cluster_move_error_when_new_node_is_added() { + let name = "rebuild_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + let refreshed_map = HashMap::from([ + (6379, atomic::AtomicBool::new(false)), + (6380, atomic::AtomicBool::new(false)), + ]); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + let is_get_cmd = contains_slice(cmd, b"GET"); + let get_response = Err(Ok(Value::BulkString(b"123".to_vec()))); + match i { + // Respond that the key exists on a node that does not yet have a connection: + 0 => Err(parse_redis_value( + format!("-MOVED 123 {name}:6380\r\n").as_bytes(), + )), + _ => { + if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + // Should not attempt to refresh slots more than once, + // so we expect a single CLUSTER NODES request for each node + assert!(!refreshed_map + .get(&port) + .unwrap() + .swap(true, Ordering::SeqCst)); + Err(Ok(Value::Array(vec![ + Value::Array(vec![ + Value::Int(0), + Value::Int(1), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Array(vec![ + Value::Int(2), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))) + } else { + assert_eq!(port, 6380); + assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd)); + get_response + } + } + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + fn test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries( + slots_config_vec: Vec>, + ports: Vec, + has_a_majority: bool, + ) { + assert!(!ports.is_empty() && !slots_config_vec.is_empty()); + let name = "refresh_topology_moved"; + let num_of_nodes = ports.len(); + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + let refresh_calls = Arc::new(atomic::AtomicUsize::new(0)); + let refresh_calls_cloned = refresh_calls.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + // Disable the rate limiter to refresh slots immediately on all MOVED errors. + .slots_refresh_rate_limit(Duration::from_secs(0), 0), + name, + move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup_with_replica_using_config( + name, + cmd, + Some(slots_config_vec[0].clone()), + )?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + let is_get_cmd = contains_slice(cmd, b"GET"); + let get_response = Err(Ok(Value::BulkString(b"123".to_vec()))); + let moved_node = ports[0]; + match i { + // Respond that the key exists on a node that does not yet have a connection: + 0 => Err(parse_redis_value( + format!("-MOVED 123 {name}:{moved_node}\r\n").as_bytes(), + )), + _ => { + if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + refresh_calls_cloned.fetch_add(1, atomic::Ordering::SeqCst); + let view_index = + get_node_view_index(slots_config_vec.len(), &ports, port); + Err(Ok(create_topology_from_config( + name, + slots_config_vec[view_index].clone(), + ))) + } else { + assert_eq!(port, moved_node); + assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd)); + get_response + } + } + } + }, + ); + runtime.block_on(async move { + let res = cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection) + .await; + assert_eq!(res, Ok(Some(123))); + // If there is a majority in the topology views, or if it's a 2-nodes cluster, we shall be able to calculate the topology on the first try, + // so each node will be queried only once with CLUSTER SLOTS. + // Otherwise, if we don't have a majority, we expect to see the refresh_slots function being called with the maximum retry number. + let expected_calls = if has_a_majority || num_of_nodes == 2 {num_of_nodes} else {DEFAULT_NUMBER_OF_REFRESH_SLOTS_RETRIES * num_of_nodes}; + let mut refreshed_calls = 0; + for _ in 0..100 { + refreshed_calls = refresh_calls.load(atomic::Ordering::Relaxed); + if refreshed_calls == expected_calls { + return; + } else { + let sleep_duration = core::time::Duration::from_millis(100); + #[cfg(feature = "tokio-comp")] + tokio::time::sleep(sleep_duration).await; + + #[cfg(all(not(feature = "tokio-comp"), feature = "async-std-comp"))] + async_std::task::sleep(sleep_duration).await; + } + } + panic!("Failed to reach to the expected topology refresh retries. Found={refreshed_calls}, Expected={expected_calls}") + }); + } + + fn test_async_cluster_refresh_slots_rate_limiter_helper( + slots_config_vec: Vec>, + ports: Vec, + should_skip: bool, + ) { + // This test queries GET, which returns a MOVED error. If `should_skip` is true, + // it indicates that we should skip refreshing slots because the specified time + // duration since the last refresh slots call has not yet passed. In this case, + // we expect CLUSTER SLOTS not to be called on the nodes after receiving the + // MOVED error. + + // If `should_skip` is false, we verify that if the MOVED error occurs after the + // time duration of the rate limiter has passed, the refresh slots operation + // should not be skipped. We assert this by expecting calls to CLUSTER SLOTS on + // all nodes. + let test_name = format!( + "test_async_cluster_refresh_slots_rate_limiter_helper_{}", + if should_skip { + "should_skip" + } else { + "not_skipping_waiting_time_passed" + } + ); + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + let refresh_calls = Arc::new(atomic::AtomicUsize::new(0)); + let refresh_calls_cloned = Arc::clone(&refresh_calls); + let wait_duration = Duration::from_millis(10); + let num_of_nodes = ports.len(); + + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{test_name}")]) + .slots_refresh_rate_limit(wait_duration, 0), + test_name.clone().as_str(), + move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup_with_replica_using_config( + test_name.as_str(), + cmd, + Some(slots_config_vec[0].clone()), + )?; + started.store(true, atomic::Ordering::SeqCst); + } + + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + let is_get_cmd = contains_slice(cmd, b"GET"); + let get_response = Err(Ok(Value::BulkString(b"123".to_vec()))); + let moved_node = ports[0]; + match i { + // The first request calls are the starting calls for each GET command where we want to respond with MOVED error + 0 => { + if !should_skip { + // Wait for the wait duration to pass + std::thread::sleep(wait_duration.add(Duration::from_millis(10))); + } + Err(parse_redis_value( + format!("-MOVED 123 {test_name}:{moved_node}\r\n").as_bytes(), + )) + } + _ => { + if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + refresh_calls_cloned.fetch_add(1, atomic::Ordering::SeqCst); + let view_index = + get_node_view_index(slots_config_vec.len(), &ports, port); + Err(Ok(create_topology_from_config( + test_name.as_str(), + slots_config_vec[view_index].clone(), + ))) + } else { + // Even if the slots weren't refreshed we still expect the command to be + // routed by the redirect host and port it received in the moved error + assert_eq!(port, moved_node); + assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd)); + get_response + } + } + } + }, + ); + + runtime.block_on(async move { + // First GET request should raise MOVED error and then refresh slots + let res = cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection) + .await; + assert_eq!(res, Ok(Some(123))); + + // We should skip is false, we should call CLUSTER SLOTS once per node + let expected_calls = if should_skip { + 0 + } else { + num_of_nodes + }; + for _ in 0..4 { + if refresh_calls.load(atomic::Ordering::Relaxed) == expected_calls { + return Ok::<_, RedisError>(()); + } + let _ = sleep(Duration::from_millis(50).into()).await; + } + panic!("Refresh slots wasn't called as expected!\nExpected CLUSTER SLOTS calls: {}, actual calls: {:?}", expected_calls, refresh_calls.load(atomic::Ordering::Relaxed)); + }).unwrap() + } + + fn test_async_cluster_refresh_topology_in_client_init_get_succeed( + slots_config_vec: Vec>, + ports: Vec, + ) { + assert!(!ports.is_empty() && !slots_config_vec.is_empty()); + let name = "refresh_topology_client_init"; + let started = atomic::AtomicBool::new(false); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder::( + ports + .iter() + .map(|port| format!("redis://{name}:{port}")) + .collect::>(), + ), + name, + move |cmd: &[u8], port| { + let is_started = started.load(atomic::Ordering::SeqCst); + if !is_started { + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + return Err(Ok(Value::SimpleString("OK".into()))); + } else if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + let view_index = get_node_view_index(slots_config_vec.len(), &ports, port); + return Err(Ok(create_topology_from_config( + name, + slots_config_vec[view_index].clone(), + ))); + } else if contains_slice(cmd, b"READONLY") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + } + started.store(true, atomic::Ordering::SeqCst); + if contains_slice(cmd, b"PING") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let is_get_cmd = contains_slice(cmd, b"GET"); + let get_response = Err(Ok(Value::BulkString(b"123".to_vec()))); + { + assert!(is_get_cmd, "{:?}", std::str::from_utf8(cmd)); + get_response + } + }, + ); + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + fn generate_topology_view( + ports: &[u16], + interval: usize, + full_slot_coverage: bool, + ) -> Vec { + let mut slots_res = vec![]; + let mut start_pos: usize = 0; + for (idx, port) in ports.iter().enumerate() { + let end_pos: usize = if idx == ports.len() - 1 && full_slot_coverage { + 16383 + } else { + start_pos + interval + }; + let mock_slot = MockSlotRange { + primary_port: *port, + replica_ports: vec![], + slot_range: (start_pos as u16..end_pos as u16), + }; + slots_res.push(mock_slot); + start_pos = end_pos + 1; + } + slots_res + } + + fn get_ports(num_of_nodes: usize) -> Vec { + (6379_u16..6379 + num_of_nodes as u16).collect() + } + + fn get_no_majority_topology_view(ports: &[u16]) -> Vec> { + let mut result = vec![]; + let mut full_coverage = true; + for i in 0..ports.len() { + result.push(generate_topology_view(ports, i + 1, full_coverage)); + full_coverage = !full_coverage; + } + result + } + + fn get_topology_with_majority(ports: &[u16]) -> Vec> { + let view: Vec = generate_topology_view(ports, 10, true); + let result: Vec<_> = ports.iter().map(|_| view.clone()).collect(); + result + } + + #[test] + fn test_async_cluster_refresh_topology_after_moved_error_all_nodes_agree_get_succeed() { + let ports = get_ports(3); + test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries( + get_topology_with_majority(&ports), + ports, + true, + ); + } + + #[test] + fn test_async_cluster_refresh_topology_in_client_init_all_nodes_agree_get_succeed() { + let ports = get_ports(3); + test_async_cluster_refresh_topology_in_client_init_get_succeed( + get_topology_with_majority(&ports), + ports, + ); + } + + #[test] + fn test_async_cluster_refresh_topology_after_moved_error_with_no_majority_get_succeed() { + for num_of_nodes in 2..4 { + let ports = get_ports(num_of_nodes); + test_async_cluster_refresh_topology_after_moved_assert_get_succeed_and_expected_retries( + get_no_majority_topology_view(&ports), + ports, + false, + ); + } + } + + #[test] + fn test_async_cluster_refresh_topology_in_client_init_with_no_majority_get_succeed() { + for num_of_nodes in 2..4 { + let ports = get_ports(num_of_nodes); + test_async_cluster_refresh_topology_in_client_init_get_succeed( + get_no_majority_topology_view(&ports), + ports, + ); + } + } + + #[test] + fn test_async_cluster_refresh_topology_even_with_zero_retries() { + let name = "test_async_cluster_refresh_topology_even_with_zero_retries"; + + let should_refresh = atomic::AtomicBool::new(false); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0) + // Disable the rate limiter to refresh slots immediately on the MOVED error. + .slots_refresh_rate_limit(Duration::from_secs(0), 0), + name, + move |cmd: &[u8], port| { + if !should_refresh.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + if contains_slice(cmd, b"CLUSTER") && contains_slice(cmd, b"SLOTS") { + return Err(Ok(Value::Array(vec![ + Value::Array(vec![ + Value::Int(0), + Value::Int(1), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6379), + ]), + ]), + Value::Array(vec![ + Value::Int(2), + Value::Int(16383), + Value::Array(vec![ + Value::BulkString(name.as_bytes().to_vec()), + Value::Int(6380), + ]), + ]), + ]))); + } + + if contains_slice(cmd, b"GET") { + let get_response = Err(Ok(Value::BulkString(b"123".to_vec()))); + match port { + 6380 => get_response, + // Respond that the key exists on a node that does not yet have a connection: + _ => { + // Should not attempt to refresh slots more than once: + assert!(!should_refresh.swap(true, Ordering::SeqCst)); + Err(parse_redis_value( + format!("-MOVED 123 {name}:6380\r\n").as_bytes(), + )) + } + } + } else { + panic!("unexpected command {cmd:?}") + } + }, + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + // The user should receive an initial error, because there are no retries and the first request failed. + assert_eq!( + value, + Err(RedisError::from(( + ErrorKind::Moved, + "An error was signalled by the server", + "test_async_cluster_refresh_topology_even_with_zero_retries:6380".to_string() + ))) + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_async_cluster_reconnect_even_with_zero_retries() { + let name = "test_async_cluster_reconnect_even_with_zero_retries"; + + let should_reconnect = atomic::AtomicBool::new(true); + let connection_count = Arc::new(atomic::AtomicU16::new(0)); + let connection_count_clone = connection_count.clone(); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(0), + name, + move |cmd: &[u8], port| { + match respond_startup(name, cmd) { + Ok(_) => {} + Err(err) => { + connection_count.fetch_add(1, Ordering::Relaxed); + return Err(err); + } + } + + if contains_slice(cmd, b"ECHO") && port == 6379 { + // Should not attempt to refresh slots more than once: + if should_reconnect.swap(false, Ordering::SeqCst) { + Err(Err(broken_pipe_error())) + } else { + Err(Ok(Value::BulkString(b"PONG".to_vec()))) + } + } else { + panic!("unexpected command {cmd:?}") + } + }, + ); + + // We expect 6 calls in total. MockEnv creates both synchronous and asynchronous connections, which make the following calls: + // - 1 call by the sync connection to `CLUSTER SLOTS` for initializing the client's topology map. + // - 3 calls by the async connection to `PING`: one for the user connection when creating the node from initial addresses, + // and two more for checking the user and management connections during client initialization in `refresh_slots`. + // - 1 call by the async connection to `CLIENT SETNAME` for setting up the management connection name. + // - 1 call by the async connection to `CLUSTER SLOTS` for initializing the client's topology map. + // Note: If additional nodes or setup calls are added, this number should increase. + let expected_init_calls = 6; + assert_eq!( + connection_count_clone.load(Ordering::Relaxed), + expected_init_calls + ); + + let value = runtime.block_on(connection.route_command( + &cmd("ECHO"), + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: name.to_string(), + port: 6379, + }), + )); + + // The user should receive an initial error, because there are no retries and the first request failed. + assert_eq!( + value.unwrap_err().to_string(), + broken_pipe_error().to_string() + ); + + let value = runtime.block_on(connection.route_command( + &cmd("ECHO"), + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: name.to_string(), + port: 6379, + }), + )); + + assert_eq!(value, Ok(Value::BulkString(b"PONG".to_vec()))); + // `expected_init_calls` plus another PING for a new user connection created from refresh_connections + assert_eq!( + connection_count_clone.load(Ordering::Relaxed), + expected_init_calls + 1 + ); + } + + #[test] + fn test_async_cluster_refresh_slots_rate_limiter_skips_refresh() { + let ports = get_ports(3); + test_async_cluster_refresh_slots_rate_limiter_helper( + get_topology_with_majority(&ports), + ports, + true, + ); + } + + #[test] + fn test_async_cluster_refresh_slots_rate_limiter_does_refresh_when_wait_duration_passed() { + let ports = get_ports(3); + test_async_cluster_refresh_slots_rate_limiter_helper( + get_topology_with_majority(&ports), + ports, + false, + ); + } + + #[test] + fn test_async_cluster_ask_redirect() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + { + move |cmd: &[u8], port| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + let count = completed.fetch_add(1, Ordering::SeqCst); + match port { + 6379 => match count { + 0 => Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")), + _ => panic!("Node should not be called now"), + }, + 6380 => match count { + 1 => { + assert!(contains_slice(cmd, b"ASKING")); + Err(Ok(Value::Okay)) + } + 2 => { + assert!(contains_slice(cmd, b"GET")); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + _ => panic!("Node should not be called now"), + }, + _ => panic!("Wrong node"), + } + } + }, + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_async_cluster_ask_save_new_connection() { + let name = "node"; + let ping_attempts = Arc::new(AtomicI32::new(0)); + let ping_attempts_clone = ping_attempts.clone(); + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + { + move |cmd: &[u8], port| { + if port != 6391 { + respond_startup_two_nodes(name, cmd)?; + return Err(parse_redis_value(b"-ASK 14000 node:6391\r\n")); + } + + if contains_slice(cmd, b"PING") { + ping_attempts_clone.fetch_add(1, Ordering::Relaxed); + } + respond_startup_two_nodes(name, cmd)?; + Err(Ok(Value::Okay)) + } + }, + ); + + for _ in 0..4 { + runtime + .block_on( + cmd("GET") + .arg("test") + .query_async::<_, Value>(&mut connection), + ) + .unwrap(); + } + + assert_eq!(ping_attempts.load(Ordering::Relaxed), 1); + } + + #[test] + fn test_async_cluster_reset_routing_if_redirect_fails() { + let name = "test_async_cluster_reset_routing_if_redirect_fails"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if port != 6379 && port != 6380 { + return Err(Err(broken_pipe_error())); + } + respond_startup_two_nodes(name, cmd)?; + let count = completed.fetch_add(1, Ordering::SeqCst); + match (port, count) { + // redirect once to non-existing node + (6379, 0) => Err(parse_redis_value( + format!("-ASK 14000 {name}:9999\r\n").as_bytes(), + )), + // accept the next request + (6379, 1) => { + assert!(contains_slice(cmd, b"GET")); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + _ => panic!("Wrong node. port: {port}, received count: {count}"), + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_async_cluster_ask_redirect_even_if_original_call_had_no_route() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + { + move |cmd: &[u8], port| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + let count = completed.fetch_add(1, Ordering::SeqCst); + if count == 0 { + return Err(parse_redis_value(b"-ASK 14000 node:6380\r\n")); + } + match port { + 6380 => match count { + 1 => { + assert!( + contains_slice(cmd, b"ASKING"), + "{:?}", + std::str::from_utf8(cmd) + ); + Err(Ok(Value::Okay)) + } + 2 => { + assert!(contains_slice(cmd, b"EVAL")); + Err(Ok(Value::Okay)) + } + _ => panic!("Node should not be called now"), + }, + _ => panic!("Wrong node"), + } + } + }, + ); + + let value = runtime.block_on( + cmd("EVAL") // Eval command has no directed, and so is redirected randomly + .query_async::<_, Value>(&mut connection), + ); + + assert_eq!(value, Ok(Value::Okay)); + } + + #[test] + fn test_async_cluster_ask_error_when_new_node_is_added() { + let name = "ask_with_extra_nodes"; + + let requests = atomic::AtomicUsize::new(0); + let started = atomic::AtomicBool::new(false); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::new(name, move |cmd: &[u8], port| { + if !started.load(atomic::Ordering::SeqCst) { + respond_startup(name, cmd)?; + } + started.store(true, atomic::Ordering::SeqCst); + + if contains_slice(cmd, b"PING") || contains_slice(cmd, b"SETNAME") { + return Err(Ok(Value::SimpleString("OK".into()))); + } + + let i = requests.fetch_add(1, atomic::Ordering::SeqCst); + + match i { + // Respond that the key exists on a node that does not yet have a connection: + 0 => Err(parse_redis_value( + format!("-ASK 123 {name}:6380\r\n").as_bytes(), + )), + 1 => { + assert_eq!(port, 6380); + assert!(contains_slice(cmd, b"ASKING")); + Err(Ok(Value::Okay)) + } + 2 => { + assert_eq!(port, 6380); + assert!(contains_slice(cmd, b"GET")); + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + _ => { + panic!("Unexpected request: {:?}", cmd); + } + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_async_cluster_replica_read() { + let name = "node"; + + // requests should route to replica + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + match port { + 6380 => Err(Ok(Value::BulkString(b"123".to_vec()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + assert_eq!(value, Ok(Some(123))); + + // requests should route to primary + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + match port { + 6379 => Err(Ok(Value::SimpleString("OK".into()))), + _ => panic!("Wrong node"), + } + }, + ); + + let value = runtime.block_on( + cmd("SET") + .arg("test") + .arg("123") + .query_async::<_, Option>(&mut connection), + ); + assert_eq!(value, Ok(Some(Value::SimpleString("OK".to_owned())))); + } + + fn test_async_cluster_fan_out( + command: &'static str, + expected_ports: Vec, + slots_config: Option>, + ) { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let mut cmd = Cmd::new(); + for arg in command.split_whitespace() { + cmd.arg(arg); + } + let packed_cmd = cmd.get_packed_command(); + // requests should route to replica + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + if received_cmd == packed_cmd { + ports_clone.lock().unwrap().push(port); + return Err(Ok(Value::SimpleString("OK".into()))); + } + Ok(()) + }, + ); + + let _ = runtime.block_on(cmd.query_async::<_, Option<()>>(&mut connection)); + found_ports.lock().unwrap().sort(); + // MockEnv creates 2 mock connections. + assert_eq!(*found_ports.lock().unwrap(), expected_ports); + } + + #[test] + fn test_async_cluster_fan_out_to_all_primaries() { + test_async_cluster_fan_out("FLUSHALL", vec![6379, 6381], None); + } + + #[test] + fn test_async_cluster_fan_out_to_all_nodes() { + test_async_cluster_fan_out("CONFIG SET", vec![6379, 6380, 6381, 6382], None); + } + + #[test] + fn test_async_cluster_fan_out_once_to_each_primary_when_no_replicas_are_available() { + test_async_cluster_fan_out( + "CONFIG SET", + vec![6379, 6381], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: Vec::new(), + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: Vec::new(), + slot_range: (8192..16383), + }, + ]), + ); + } + + #[test] + fn test_async_cluster_fan_out_once_even_if_primary_has_multiple_slot_ranges() { + test_async_cluster_fan_out( + "CONFIG SET", + vec![6379, 6380, 6381, 6382], + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (0..4000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (4001..8191), + }, + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380], + slot_range: (8192..8200), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8201..16383), + }, + ]), + ); + } + + #[test] + fn test_async_cluster_route_according_to_passed_argument() { + let name = "test_async_cluster_route_according_to_passed_argument"; + + let touched_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let cloned_ports = touched_ports.clone(); + + // requests should route to replica + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica(name, cmd)?; + cloned_ports.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }, + ); + + let mut cmd = cmd("GET"); + cmd.arg("test"); + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllMasters, None)), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6379, 6381]); + touched_ports.clear(); + } + + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::MultiNode((MultipleNodeRoutingInfo::AllNodes, None)), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6379, 6380, 6381, 6382]); + touched_ports.clear(); + } + + let _ = runtime.block_on(connection.route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: name.to_string(), + port: 6382, + }), + )); + { + let mut touched_ports = touched_ports.lock().unwrap(); + touched_ports.sort(); + assert_eq!(*touched_ports, vec![6382]); + touched_ports.clear(); + } + } + + #[test] + fn test_async_cluster_fan_out_and_aggregate_numeric_response_with_min() { + let name = "test_async_cluster_fan_out_and_aggregate_numeric_response"; + let mut cmd = Cmd::new(); + cmd.arg("SLOWLOG").arg("LEN"); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + + let res = 6383 - port as i64; + Err(Ok(Value::Int(res))) // this results in 1,2,3,4 + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, i64>(&mut connection)) + .unwrap(); + assert_eq!(result, 10, "{result}"); + } + + #[test] + fn test_async_cluster_fan_out_and_aggregate_logical_array_response() { + let name = "test_async_cluster_fan_out_and_aggregate_logical_array_response"; + let mut cmd = Cmd::new(); + cmd.arg("SCRIPT") + .arg("EXISTS") + .arg("foo") + .arg("bar") + .arg("baz") + .arg("barvaz"); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + + if port == 6381 { + return Err(Ok(Value::Array(vec![ + Value::Int(0), + Value::Int(0), + Value::Int(1), + Value::Int(1), + ]))); + } else if port == 6379 { + return Err(Ok(Value::Array(vec![ + Value::Int(0), + Value::Int(1), + Value::Int(0), + Value::Int(1), + ]))); + } + + panic!("unexpected port {port}"); + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Vec>(&mut connection)) + .unwrap(); + assert_eq!(result, vec![0, 0, 0, 1], "{result:?}"); + } + + #[test] + fn test_async_cluster_fan_out_and_return_one_succeeded_response() { + let name = "test_async_cluster_fan_out_and_return_one_succeeded_response"; + let mut cmd = Cmd::new(); + cmd.arg("SCRIPT").arg("KILL"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + return Err(Ok(Value::Okay)); + } + Err(Err(( + ErrorKind::NotBusy, + "No scripts in execution right now", + ) + .into())) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap(); + assert_eq!(result, Value::Okay, "{result:?}"); + } + + #[test] + fn test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes() { + let name = "test_async_cluster_fan_out_and_fail_one_succeeded_if_there_are_no_successes"; + let mut cmd = Cmd::new(); + cmd.arg("SCRIPT").arg("KILL"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + + Err(Err(( + ErrorKind::NotBusy, + "No scripts in execution right now", + ) + .into())) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::NotBusy, "{:?}", result.kind()); + } + + #[test] + fn test_async_cluster_fan_out_and_return_all_succeeded_response() { + let name = "test_async_cluster_fan_out_and_return_all_succeeded_response"; + let cmd = cmd("FLUSHALL"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + Err(Ok(Value::Okay)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap(); + assert_eq!(result, Value::Okay, "{result:?}"); + } + + #[test] + fn test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure() { + let name = "test_async_cluster_fan_out_and_fail_all_succeeded_if_there_is_a_single_failure"; + let cmd = cmd("FLUSHALL"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + if port == 6381 { + return Err(Err(( + ErrorKind::NotBusy, + "No scripts in execution right now", + ) + .into())); + } + Err(Ok(Value::Okay)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::NotBusy, "{:?}", result.kind()); + } + + #[test] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_value_ignoring_nil_and_err_resps( + ) { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_value_ignoring_nil_and_err_resps"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + let ports = vec![6379, 6380, 6381]; + let slots_config_vec = generate_topology_view(&ports, 1000, true); + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + if port == 6380 { + return Err(Ok(Value::BulkString("foo".as_bytes().to_vec()))); + } else if port == 6381 { + return Err(Err(RedisError::from(( + redis::ErrorKind::ResponseError, + "ERROR", + )))); + } + Err(Ok(Value::Nil)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, String>(&mut connection)) + .unwrap(); + assert_eq!(result, "foo", "{result:?}"); + } + + #[test] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_err_if_all_resps_are_nil_and_errors( + ) { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_err_if_all_resps_are_nil_and_errors"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_config(name, received_cmd, None, false)?; + if port == 6380 { + return Err(Ok(Value::Nil)); + } + Err(Err(RedisError::from(( + redis::ErrorKind::ResponseError, + "ERROR", + )))) + }, + ); + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::ResponseError); + } + + #[test] + fn test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil() { + let name = + "test_async_cluster_first_succeeded_non_empty_or_all_empty_return_nil_if_all_resp_nil"; + let cmd = cmd("RANDOMKEY"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _port| { + respond_startup_with_config(name, received_cmd, None, false)?; + Err(Ok(Value::Nil)) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Value>(&mut connection)) + .unwrap(); + assert_eq!(result, Value::Nil, "{result:?}"); + } + + #[test] + fn test_async_cluster_fan_out_and_return_map_of_results_for_special_response_policy() { + let name = "foo"; + let mut cmd = Cmd::new(); + cmd.arg("LATENCY").arg("LATEST"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + Err(Ok(Value::BulkString( + format!("latency: {port}").into_bytes(), + ))) + }, + ); + + // TODO once RESP3 is in, return this as a map + let mut result = runtime + .block_on(cmd.query_async::<_, Vec<(String, String)>>(&mut connection)) + .unwrap(); + result.sort(); + assert_eq!( + result, + vec![ + (format!("{name}:6379"), "latency: 6379".to_string()), + (format!("{name}:6380"), "latency: 6380".to_string()), + (format!("{name}:6381"), "latency: 6381".to_string()), + (format!("{name}:6382"), "latency: 6382".to_string()) + ], + "{result:?}" + ); + } + + #[test] + fn test_async_cluster_fan_out_and_combine_arrays_of_values() { + let name = "foo"; + let cmd = cmd("KEYS"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + Err(Ok(Value::Array(vec![Value::BulkString( + format!("key:{port}").into_bytes(), + )]))) + }, + ); + + let mut result = runtime + .block_on(cmd.query_async::<_, Vec>(&mut connection)) + .unwrap(); + result.sort(); + assert_eq!( + result, + vec!["key:6379".to_string(), "key:6381".to_string(),], + "{result:?}" + ); + } + + #[test] + fn test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values() { + let name = "test_async_cluster_split_multi_shard_command_and_combine_arrays_of_values"; + let mut cmd = cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz"); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + let cmd_str = std::str::from_utf8(received_cmd).unwrap(); + let results = ["foo", "bar", "baz"] + .iter() + .filter_map(|expected_key| { + if cmd_str.contains(expected_key) { + Some(Value::BulkString( + format!("{expected_key}-{port}").into_bytes(), + )) + } else { + None + } + }) + .collect(); + Err(Ok(Value::Array(results))) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Vec>(&mut connection)) + .unwrap(); + assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6380"]); + } + + #[test] + fn test_async_cluster_handle_asking_error_in_split_multi_shard_command() { + let name = "test_async_cluster_handle_asking_error_in_split_multi_shard_command"; + let mut cmd = cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz"); + let asking_called = Arc::new(AtomicU16::new(0)); + let asking_called_cloned = asking_called.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + let cmd_str = std::str::from_utf8(received_cmd).unwrap(); + if cmd_str.contains("ASKING") && port == 6382 { + asking_called_cloned.fetch_add(1, Ordering::Relaxed); + } + if port == 6380 && cmd_str.contains("baz") { + return Err(parse_redis_value( + format!("-ASK 14000 {name}:6382\r\n").as_bytes(), + )); + } + let results = ["foo", "bar", "baz"] + .iter() + .filter_map(|expected_key| { + if cmd_str.contains(expected_key) { + Some(Value::BulkString( + format!("{expected_key}-{port}").into_bytes(), + )) + } else { + None + } + }) + .collect(); + Err(Ok(Value::Array(results))) + }, + ); + + let result = runtime + .block_on(cmd.query_async::<_, Vec>(&mut connection)) + .unwrap(); + assert_eq!(result, vec!["foo-6382", "bar-6380", "baz-6382"]); + assert_eq!(asking_called.load(Ordering::Relaxed), 1); + } + + #[test] + fn test_async_cluster_pass_errors_from_split_multi_shard_command() { + let name = "test_async_cluster_pass_errors_from_split_multi_shard_command"; + let mut cmd = cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz"); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::new(name, move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config(name, received_cmd, None)?; + let cmd_str = std::str::from_utf8(received_cmd).unwrap(); + if cmd_str.contains("foo") || cmd_str.contains("baz") { + Err(Err((ErrorKind::IoError, "error").into())) + } else { + Err(Ok(Value::Array(vec![Value::BulkString( + format!("{port}").into_bytes(), + )]))) + } + }); + + let result = runtime + .block_on(cmd.query_async::<_, Vec>(&mut connection)) + .unwrap_err(); + assert_eq!(result.kind(), ErrorKind::IoError); + } + + #[test] + fn test_async_cluster_handle_missing_slots_in_split_multi_shard_command() { + let name = "test_async_cluster_handle_missing_slots_in_split_multi_shard_command"; + let mut cmd = cmd("MGET"); + cmd.arg("foo").arg("bar").arg("baz"); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::new(name, move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + Some(vec![MockSlotRange { + primary_port: 6381, + replica_ports: vec![6382], + slot_range: (8192..16383), + }]), + )?; + Err(Ok(Value::Array(vec![Value::BulkString( + format!("{port}").into_bytes(), + )]))) + }); + + let result = runtime + .block_on(cmd.query_async::<_, Vec>(&mut connection)) + .unwrap_err(); + assert!( + matches!(result.kind(), ErrorKind::ConnectionNotFoundForRoute) + || result.is_connection_dropped() + ); + } + + #[test] + fn test_async_cluster_with_username_and_password() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .username(RedisCluster::username().to_string()) + .password(RedisCluster::password().to_string()) + }, + false, + ); + cluster.disable_default_user(); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + cmd("SET") + .arg("test") + .arg("test_data") + .query_async(&mut connection) + .await?; + let res: String = cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, "test_data"); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_io_error() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(2), + name, + move |cmd: &[u8], port| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + match port { + 6380 => panic!("Node should not be called"), + _ => match completed.fetch_add(1, Ordering::SeqCst) { + 0..=1 => Err(Err(RedisError::from(std::io::Error::new( + std::io::ErrorKind::ConnectionReset, + "mock-io-error", + )))), + _ => Err(Ok(Value::BulkString(b"123".to_vec()))), + }, + } + }, + ); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + + #[test] + fn test_async_cluster_non_retryable_error_should_not_retry() { + let name = "node"; + let completed = Arc::new(AtomicI32::new(0)); + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. + } = MockEnv::new(name, { + let completed = completed.clone(); + move |cmd: &[u8], _| { + respond_startup_two_nodes(name, cmd)?; + // Error twice with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + completed.fetch_add(1, Ordering::SeqCst); + Err(Err((ErrorKind::ReadOnly, "").into())) + } + }); + + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + match value { + Ok(_) => panic!("result should be an error"), + Err(e) => match e.kind() { + ErrorKind::ReadOnly => {} + _ => panic!("Expected ReadOnly but got {:?}", e.kind()), + }, + } + assert_eq!(completed.load(Ordering::SeqCst), 1); + } + + #[test] + fn test_async_cluster_read_from_primary() { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::new(name, move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6382, + replica_ports: vec![6383, 6384], + slot_range: (8192..16383), + }, + ]), + )?; + ports_clone.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }); + + runtime.block_on(async { + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + }); + + found_ports.lock().unwrap().sort(); + assert_eq!(*found_ports.lock().unwrap(), vec![6379, 6379, 6382, 6382]); + } + + #[test] + fn test_async_cluster_round_robin_read_from_replica() { + let name = "node"; + let found_ports = Arc::new(std::sync::Mutex::new(Vec::new())); + let ports_clone = found_ports.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(), + name, + move |received_cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + received_cmd, + Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..8191), + }, + MockSlotRange { + primary_port: 6382, + replica_ports: vec![6383, 6384], + slot_range: (8192..16383), + }, + ]), + )?; + ports_clone.lock().unwrap().push(port); + Err(Ok(Value::Nil)) + }, + ); + + runtime.block_on(async { + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("foo") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + cmd("GET") + .arg("bar") + .query_async::<_, ()>(&mut connection) + .await + .unwrap(); + }); + + found_ports.lock().unwrap().sort(); + assert_eq!(*found_ports.lock().unwrap(), vec![6380, 6381, 6383, 6384]); + } + + fn get_queried_node_id_if_master(cluster_nodes_output: Value) -> Option { + // Returns the node ID of the connection that was queried for CLUSTER NODES (using the 'myself' flag), if it's a master. + // Otherwise, returns None. + let get_node_id = |str: &str| { + let parts: Vec<&str> = str.split('\n').collect(); + for node_entry in parts { + if node_entry.contains("myself") && node_entry.contains("master") { + let node_entry_parts: Vec<&str> = node_entry.split(' ').collect(); + let node_id = node_entry_parts[0]; + return Some(node_id.to_string()); + } + } + None + }; + + match cluster_nodes_output { + Value::BulkString(val) => match from_utf8(&val) { + Ok(str_res) => get_node_id(str_res), + Err(e) => panic!("failed to decode INFO response: {:?}", e), + }, + Value::VerbatimString { format: _, text } => get_node_id(&text), + _ => panic!("Recieved unexpected response: {:?}", cluster_nodes_output), + } + } + + #[test] + fn test_async_cluster_handle_complete_server_disconnect_without_panicking() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(2), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + drop(cluster); + for _ in 0..5 { + let cmd = cmd("PING"); + let result = connection + .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + .await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + // This will route to all nodes - different path through the code. + let result = connection.req_packed_command(&cmd).await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + } + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_test_fast_reconnect() { + // Note the 3 seconds connection check to differentiate between notifications and periodic + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(0) + .periodic_connections_checks(Duration::from_secs(3)) + }, + false, + ); + + // For tokio-comp, do 3 consequtive disconnects and ensure reconnects succeeds in less than 100ms, + // which is more than enough for local connections even with TLS. + // More than 1 run is done to ensure it is the fast reconnect notification that trigger the reconnect + // and not the periodic interval. + // For other async implementation, only periodic connection check is available, hence, + // do 1 run sleeping for periodic connection check interval, allowing it to reestablish connections + block_on_all(async move { + let mut disconnecting_con = cluster.async_connection(None).await; + let mut monitoring_con = cluster.async_connection(None).await; + + #[cfg(feature = "tokio-comp")] + let tries = 0..3; + #[cfg(not(feature = "tokio-comp"))] + let tries = 0..1; + + for _ in tries { + // get connection id + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("ID"); + let res = disconnecting_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + let id = { + match res { + Value::Int(id) => id, + _ => { + panic!("Wrong return value for CLIENT ID command: {:?}", res); + } + } + }; + + // ask server to kill the connection + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("KILL").arg("ID").arg(id).arg("SKIPME").arg("NO"); + let res = disconnecting_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + // assert server has closed connection + assert_eq!(res, Ok(Value::Int(1))); + + #[cfg(feature = "tokio-comp")] + // ensure reconnect happened in less than 100ms + sleep(futures_time::time::Duration::from_millis(100)).await; + + #[cfg(not(feature = "tokio-comp"))] + // no fast notification is available, wait for 1 periodic check + overhead + sleep(futures_time::time::Duration::from_secs(3 + 1)).await; + + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("LIST").arg("TYPE").arg("NORMAL"); + let res = monitoring_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + let client_list: String = { + match res { + // RESP2 + Value::BulkString(client_info) => { + // ensure 4 connections - 2 for each client, its save to unwrap here + String::from_utf8(client_info).unwrap() + } + // RESP3 + Value::VerbatimString { format: _, text } => text, + _ => { + panic!("Wrong return type for CLIENT LIST command: {:?}", res); + } + } + }; + assert_eq!(client_list.chars().filter(|&x| x == '\n').count(), 4); + } + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_restore_resp3_pubsub_state_passive_disconnect() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut client_subscriptions = PubSubSubscriptionInfo::from([( + PubSubSubscriptionKind::Exact, + HashSet::from([PubSubChannelOrPattern::from("test_channel".as_bytes())]), + )]); + + if use_sharded { + client_subscriptions.insert( + PubSubSubscriptionKind::Sharded, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ); + } + + // note topology change detection is not activated since no topology change is expected + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(3) + .use_protocol(ProtocolVersion::RESP3) + .pubsub_subscriptions(client_subscriptions.clone()) + .periodic_connections_checks(Duration::from_secs(1)) + }, + false, + ); + + block_on_all(async move { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let mut _listening_con = cluster.async_connection(Some(tx.clone())).await; + // Note, publishing connection has the same pubsub config + let mut publishing_con = cluster.async_connection(None).await; + + // short sleep to allow the server to push subscription notification + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate subscriptions + validate_subscriptions(&client_subscriptions, &mut rx, false); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + // simulate passive disconnect + drop(cluster); + + // recreate the cluster, the assumtion is that the cluster is built with exactly the same params (ports, slots map...) + let _cluster = + TestClusterContext::new_with_cluster_client_builder(3, 0, |builder| builder, false); + + // sleep for 1 periodic_connections_checks + overhead + sleep(futures_time::time::Duration::from_secs(1 + 1)).await; + + // new subscription notifications due to resubscriptions + validate_subscriptions(&client_subscriptions, &mut rx, true); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_restore_resp3_pubsub_state_after_scale_out() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut client_subscriptions = PubSubSubscriptionInfo::from([ + // test_channel_? is used as it maps to 14212 slot, which is the last node in both 3 and 6 node config + // (assuming slots allocation is monotonicaly increasing starting from node 0) + ( + PubSubSubscriptionKind::Exact, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ), + ]); + + if use_sharded { + client_subscriptions.insert( + PubSubSubscriptionKind::Sharded, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ); + } + + let slot_14212 = get_slot(b"test_channel_?"); + assert_eq!(slot_14212, 14212); + + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(3) + .use_protocol(ProtocolVersion::RESP3) + .pubsub_subscriptions(client_subscriptions.clone()) + // periodic connection check is required to detect the disconnect from the last node + .periodic_connections_checks(Duration::from_secs(1)) + // periodic topology check is required to detect topology change + .periodic_topology_checks(Duration::from_secs(1)) + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let mut _listening_con = cluster.async_connection(Some(tx.clone())).await; + // Note, publishing connection has the same pubsub config + let mut publishing_con = cluster.async_connection(None).await; + + // short sleep to allow the server to push subscription notification + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate subscriptions + validate_subscriptions(&client_subscriptions, &mut rx, false); + + // validate PUBLISH + let result = cmd("PUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::Message, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + // drop and recreate a cluster with more nodes + drop(cluster); + + // recreate the cluster, the assumtion is that the cluster is built with exactly the same params (ports, slots map...) + let cluster = + TestClusterContext::new_with_cluster_client_builder(6, 0, |builder| builder, false); + + // assume slot 14212 will reside in the last node + let last_server_port = { + let addr = cluster.cluster.servers.last().unwrap().addr.clone(); + match addr { + redis::ConnectionAddr::TcpTls { + host: _, + port, + insecure: _, + tls_params: _, + } => port, + redis::ConnectionAddr::Tcp(_, port) => port, + _ => { + panic!("Wrong server address type: {:?}", addr); + } + } + }; + + // wait for new topology discovery + loop { + let mut cmd = redis::cmd("INFO"); + cmd.arg("SERVER"); + let res = publishing_con + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + slot_14212, + SlotAddr::Master, + ))), + ) + .await; + assert!(res.is_ok()); + let res = res.unwrap(); + match res { + Value::VerbatimString { format: _, text } => { + if text.contains(format!("tcp_port:{}", last_server_port).as_str()) { + // new topology rediscovered + break; + } + } + _ => { + panic!("Wrong return type for INFO SERVER command: {:?}", res); + } + } + sleep(futures_time::time::Duration::from_secs(1)).await; + } + + // sleep for one one cycle of topology refresh + sleep(futures_time::time::Duration::from_secs(1)).await; + + // validate PUBLISH + let result = redis::cmd("PUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + // allow message to propagate + sleep(futures_time::time::Duration::from_secs(1)).await; + + loop { + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + // ignore disconnection and subscription notifications due to resubscriptions + if kind == PushKind::Message { + assert_eq!( + data, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ); + break; + } + } + + if use_sharded { + // validate SPUBLISH + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut publishing_con) + .await; + assert_eq!( + result, + Ok(Value::Int(2)) // 2 connections with the same pubsub config + ); + + // allow message to propagate + sleep(futures_time::time::Duration::from_secs(1)).await; + + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + drop(publishing_con); + drop(_listening_con); + + Ok(()) + }) + .unwrap(); + + block_on_all(async move { + sleep(futures_time::time::Duration::from_secs(10)).await; + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_resp3_pubsub() { + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + let use_sharded = redis_ver.starts_with("7."); + + let mut client_subscriptions = PubSubSubscriptionInfo::from([ + ( + PubSubSubscriptionKind::Exact, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ), + ( + PubSubSubscriptionKind::Pattern, + HashSet::from([ + PubSubChannelOrPattern::from("test_*".as_bytes()), + PubSubChannelOrPattern::from("*".as_bytes()), + ]), + ), + ]); + + if use_sharded { + client_subscriptions.insert( + PubSubSubscriptionKind::Sharded, + HashSet::from([PubSubChannelOrPattern::from("test_channel_?".as_bytes())]), + ); + } + + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder + .retries(3) + .use_protocol(ProtocolVersion::RESP3) + .pubsub_subscriptions(client_subscriptions.clone()) + }, + false, + ); + + block_on_all(async move { + let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); + let mut connection = cluster.async_connection(Some(tx.clone())).await; + + // short sleep to allow the server to push subscription notification + sleep(futures_time::time::Duration::from_secs(1)).await; + + validate_subscriptions(&client_subscriptions, &mut rx, false); + + let slot_14212 = get_slot(b"test_channel_?"); + assert_eq!(slot_14212, 14212); + + let slot_0_route = + redis::cluster_routing::Route::new(0, redis::cluster_routing::SlotAddr::Master); + let node_0_route = + redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(slot_0_route); + + // node 0 route is used to ensure that the publish is propagated correctly + let result = connection + .route_command( + redis::Cmd::new() + .arg("PUBLISH") + .arg("test_channel_?") + .arg("test_message"), + RoutingInfo::SingleNode(node_0_route.clone()), + ) + .await; + assert!(result.is_ok()); + + sleep(futures_time::time::Duration::from_secs(1)).await; + + let mut pmsg_cnt = 0; + let mut msg_cnt = 0; + for _ in 0..3 { + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data: _ } = result.unwrap(); + assert!(kind == PushKind::Message || kind == PushKind::PMessage); + if kind == PushKind::Message { + msg_cnt += 1; + } else { + pmsg_cnt += 1; + } + } + assert_eq!(msg_cnt, 1); + assert_eq!(pmsg_cnt, 2); + + if use_sharded { + let result = cmd("SPUBLISH") + .arg("test_channel_?") + .arg("test_message") + .query_async(&mut connection) + .await; + assert_eq!(result, Ok(Value::Int(1))); + + sleep(futures_time::time::Duration::from_secs(1)).await; + let result = rx.try_recv(); + assert!(result.is_ok()); + let PushInfo { kind, data } = result.unwrap(); + assert_eq!( + (kind, data), + ( + PushKind::SMessage, + vec![ + Value::BulkString("test_channel_?".into()), + Value::BulkString("test_message".into()), + ] + ) + ); + } + + Ok(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_periodic_checks_update_topology_after_failover() { + // This test aims to validate the functionality of periodic topology checks by detecting and updating topology changes. + // We will repeatedly execute CLUSTER NODES commands against the primary node responsible for slot 0, recording its node ID. + // Once we've successfully completed commands with the current primary, we will initiate a failover within the same shard. + // Since we are not executing key-based commands, we won't encounter MOVED errors that trigger a slot refresh. + // Consequently, we anticipate that only the periodic topology check will detect this change and trigger topology refresh. + // If successful, the node to which we route the CLUSTER NODES command should be the newly promoted node with a different node ID. + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| { + builder + .periodic_topology_checks(Duration::from_millis(10)) + // Disable the rate limiter to refresh slots immediately on all MOVED errors + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let mut prev_master_id = "".to_string(); + let max_requests = 5000; + let mut i = 0; + loop { + if i == 10 { + let mut cmd = redis::cmd("CLUSTER"); + cmd.arg("FAILOVER"); + cmd.arg("TAKEOVER"); + let res = connection + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode( + Route::new(0, SlotAddr::ReplicaRequired), + )), + ) + .await; + assert!(res.is_ok()); + } else if i == max_requests { + break; + } else { + let mut cmd = redis::cmd("CLUSTER"); + cmd.arg("NODES"); + let res = connection + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode( + Route::new(0, SlotAddr::Master), + )), + ) + .await + .expect("Failed executing CLUSTER NODES"); + let node_id = get_queried_node_id_if_master(res); + if let Some(current_master_id) = node_id { + if prev_master_id.is_empty() { + prev_master_id = current_master_id; + } else if prev_master_id != current_master_id { + return Ok::<_, RedisError>(()); + } + } + } + i += 1; + let _ = sleep(futures_time::time::Duration::from_millis(10)).await; + } + panic!("Topology change wasn't found!"); + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_recover_disconnected_management_connections() { + // This test aims to verify that the management connections used for periodic checks are reconnected, in case that they get killed. + // In order to test this, we choose a single node, kill all connections to it which aren't user connections, and then wait until new + // connections are created. + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder.periodic_topology_checks(Duration::from_millis(10)) + // Disable the rate limiter to refresh slots immediately + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let routing = RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 1, + SlotAddr::Master, + ))); + + let mut connection = cluster.async_connection(None).await; + let max_requests = 5000; + + let connections = + get_clients_names_to_ids(&mut connection, routing.clone().into()).await; + assert!(connections.contains_key(MANAGEMENT_CONN_NAME)); + let management_conn_id = connections.get(MANAGEMENT_CONN_NAME).unwrap(); + + // Get the connection ID of the management connection + kill_connection(&mut connection, management_conn_id).await; + + let connections = + get_clients_names_to_ids(&mut connection, routing.clone().into()).await; + assert!(!connections.contains_key(MANAGEMENT_CONN_NAME)); + + for _ in 0..max_requests { + let _ = sleep(futures_time::time::Duration::from_millis(10)).await; + + let connections = + get_clients_names_to_ids(&mut connection, routing.clone().into()).await; + if connections.contains_key(MANAGEMENT_CONN_NAME) { + return Ok(()); + } + } + + panic!("Topology connection didn't reconnect!"); + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_with_client_name() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.client_name(RedisCluster::client_name().to_string()), + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let client_info: String = cmd("CLIENT") + .arg("INFO") + .query_async(&mut connection) + .await + .unwrap(); + + let client_attrs = parse_client_info(&client_info); + + assert!( + client_attrs.contains_key("name"), + "Could not detect the 'name' attribute in CLIENT INFO output" + ); + + assert_eq!( + client_attrs["name"], + RedisCluster::client_name(), + "Incorrect client name, expecting: {}, got {}", + RedisCluster::client_name(), + client_attrs["name"] + ); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_reroute_from_replica_if_in_loading_state() { + /* Test replica in loading state. The expected behaviour is that the request will be directed to a different replica or the primary. + depends on the read from replica policy. */ + let name = "test_async_cluster_reroute_from_replica_if_in_loading_state"; + + let load_errors: Arc<_> = Arc::new(std::sync::Mutex::new(vec![])); + let load_errors_clone = load_errors.clone(); + + // requests should route to replica + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).read_from_replicas(), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + cmd, + Some(vec![MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..16383), + }]), + )?; + match port { + 6380 | 6381 => { + load_errors_clone.lock().unwrap().push(port); + Err(parse_redis_value(b"-LOADING\r\n")) + } + 6379 => Err(Ok(Value::BulkString(b"123".to_vec()))), + _ => panic!("Wrong node"), + } + }, + ); + for _n in 0..3 { + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + assert_eq!(value, Ok(Some(123))); + } + + let mut load_errors_guard = load_errors.lock().unwrap(); + load_errors_guard.sort(); + + // We expected to get only 2 loading error since the 2 replicas are in loading state. + // The third iteration will be directed to the primary since the connections of the replicas were removed. + assert_eq!(*load_errors_guard, vec![6380, 6381]); + } + + #[test] + fn test_async_cluster_read_from_primary_when_primary_loading() { + // Test primary in loading state. The expected behaviour is that the request will be retried until the primary is no longer in loading state. + let name = "test_async_cluster_read_from_primary_when_primary_loading"; + + const RETRIES: u32 = 3; + const ITERATIONS: u32 = 2; + let load_errors = Arc::new(AtomicU32::new(0)); + let load_errors_clone = load_errors.clone(); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + move |cmd: &[u8], port| { + respond_startup_with_replica_using_config( + name, + cmd, + Some(vec![MockSlotRange { + primary_port: 6379, + replica_ports: vec![6380, 6381], + slot_range: (0..16383), + }]), + )?; + match port { + 6379 => { + let attempts = load_errors_clone.fetch_add(1, Ordering::Relaxed) + 1; + if attempts % RETRIES == 0 { + Err(Ok(Value::BulkString(b"123".to_vec()))) + } else { + Err(parse_redis_value(b"-LOADING\r\n")) + } + } + _ => panic!("Wrong node"), + } + }, + ); + for _n in 0..ITERATIONS { + runtime + .block_on( + cmd("GET") + .arg("test") + .query_async::<_, Value>(&mut connection), + ) + .unwrap(); + } + + assert_eq!(load_errors.load(Ordering::Relaxed), ITERATIONS * RETRIES); + } + + #[test] + fn test_async_cluster_can_be_created_with_partial_slot_coverage() { + let name = "test_async_cluster_can_be_created_with_partial_slot_coverage"; + let slots_config = Some(vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0..8000), + }, + MockSlotRange { + primary_port: 6381, + replica_ports: vec![], + slot_range: (8201..16380), + }, + ]); + + let MockEnv { + async_connection: mut connection, + handler: _handler, + runtime, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(0) + .read_from_replicas(), + name, + move |received_cmd: &[u8], _| { + respond_startup_with_replica_using_config( + name, + received_cmd, + slots_config.clone(), + )?; + Err(Ok(Value::SimpleString("PONG".into()))) + }, + ); + + let res = runtime.block_on(connection.req_packed_command(&redis::cmd("PING"))); + assert!(res.is_ok()); + } + + #[test] + fn test_async_cluster_reconnect_after_complete_server_disconnect() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder.retries(2) + // Disable the rate limiter to refresh slots immediately + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + drop(cluster); + let cmd = cmd("PING"); + + let result = connection + .route_command(&cmd, RoutingInfo::SingleNode(SingleNodeRoutingInfo::Random)) + .await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + + // This will route to all nodes - different path through the code. + let result = connection.req_packed_command(&cmd).await; + // TODO - this should be a NoConnectionError, but ATM we get the errors from the failing + assert!(result.is_err()); + + let _cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(2), + false, + ); + + let result = connection.req_packed_command(&cmd).await.unwrap(); + assert_eq!(result, Value::SimpleString("PONG".to_string())); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_reconnect_after_complete_server_disconnect_route_to_many() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(3), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + drop(cluster); + + // recreate cluster + let _cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(2), + false, + ); + + let cmd = cmd("PING"); + // explicitly route to all primaries and request all succeeded + let result = connection + .route_command( + &cmd, + RoutingInfo::MultiNode(( + MultipleNodeRoutingInfo::AllMasters, + Some(redis::cluster_routing::ResponsePolicy::AllSucceeded), + )), + ) + .await; + assert!(result.is_ok()); + + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_blocking_command_when_cluster_drops() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(3), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + futures::future::join( + async { + let res = connection.blpop::<&str, f64>("foo", 0.0).await; + assert!(res.is_err()); + println!("blpop returned error {:?}", res.map_err(|e| e.to_string())); + }, + async { + let _ = sleep(futures_time::time::Duration::from_secs(3)).await; + drop(cluster); + }, + ) + .await; + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_saves_reconnected_connection() { + let name = "test_async_cluster_saves_reconnected_connection"; + let ping_attempts = Arc::new(AtomicI32::new(0)); + let ping_attempts_clone = ping_attempts.clone(); + let get_attempts = AtomicI32::new(0); + + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(1), + name, + move |cmd: &[u8], port| { + if port == 6380 { + respond_startup_two_nodes(name, cmd)?; + return Err(parse_redis_value( + format!("-MOVED 123 {name}:6379\r\n").as_bytes(), + )); + } + + if contains_slice(cmd, b"PING") { + let connect_attempt = ping_attempts_clone.fetch_add(1, Ordering::Relaxed); + let past_get_attempts = get_attempts.load(Ordering::Relaxed); + // We want connection checks to fail after the first GET attempt, until it retries. Hence, we wait for 5 PINGs - + // 1. initial connection, + // 2. refresh slots on client creation, + // 3. refresh_connections `check_connection` after first GET failed, + // 4. refresh_connections `connect_and_check` after first GET failed, + // 5. reconnect on 2nd GET attempt. + // more than 5 attempts mean that the server reconnects more than once, which is the behavior we're testing against. + if past_get_attempts != 1 || connect_attempt > 3 { + respond_startup_two_nodes(name, cmd)?; + } + if connect_attempt > 5 { + panic!("Too many pings!"); + } + Err(Err(broken_pipe_error())) + } else { + respond_startup_two_nodes(name, cmd)?; + let past_get_attempts = get_attempts.fetch_add(1, Ordering::Relaxed); + // we fail the initial GET request, and after that we'll fail the first reconnect attempt, in the `refresh_connections` attempt. + if past_get_attempts == 0 { + // Error once with io-error, ensure connection is reestablished w/out calling + // other node (i.e., not doing a full slot rebuild) + Err(Err(broken_pipe_error())) + } else { + Err(Ok(Value::BulkString(b"123".to_vec()))) + } + } + }, + ); + + for _ in 0..4 { + let value = runtime.block_on( + cmd("GET") + .arg("test") + .query_async::<_, Option>(&mut connection), + ); + + assert_eq!(value, Ok(Some(123))); + } + // If you need to change the number here due to a change in the cluster, you probably also need to adjust the test. + // See the PING counts above to explain why 5 is the target number. + assert_eq!(ping_attempts.load(Ordering::Acquire), 5); + } + + #[test] + fn test_async_cluster_periodic_checks_use_management_connection() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| { + builder.periodic_topology_checks(Duration::from_millis(10)) + // Disable the rate limiter to refresh slots immediately on the periodic checks + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + }, + false, + ); + + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let mut client_list = "".to_string(); + let max_requests = 1000; + let mut i = 0; + loop { + if i == max_requests { + break; + } else { + client_list = cmd("CLIENT") + .arg("LIST") + .query_async::<_, String>(&mut connection) + .await + .expect("Failed executing CLIENT LIST"); + let mut client_list_parts = client_list.split('\n'); + if client_list_parts + .any(|line| line.contains(MANAGEMENT_CONN_NAME) && line.contains("cmd=cluster")) + && client_list.matches(MANAGEMENT_CONN_NAME).count() == 1 { + return Ok::<_, RedisError>(()); + } + } + i += 1; + let _ = sleep(futures_time::time::Duration::from_millis(10)).await; + } + panic!("Couldn't find a management connection or the connection wasn't used to execute CLUSTER SLOTS {:?}", client_list); + }) + .unwrap(); + } + + async fn get_clients_names_to_ids( + connection: &mut ClusterConnection, + routing: Option, + ) -> HashMap { + let mut client_list_cmd = redis::cmd("CLIENT"); + client_list_cmd.arg("LIST"); + let value = match routing { + Some(routing) => connection.route_command(&client_list_cmd, routing).await, + None => connection.req_packed_command(&client_list_cmd).await, + } + .unwrap(); + let string = String::from_owned_redis_value(value).unwrap(); + string + .split('\n') + .filter_map(|line| { + if line.is_empty() { + return None; + } + let key_values = line + .split(' ') + .filter_map(|value| { + let mut split = value.split('='); + match (split.next(), split.next()) { + (Some(key), Some(val)) => Some((key, val)), + _ => None, + } + }) + .collect::>(); + match (key_values.get("name"), key_values.get("id")) { + (Some(key), Some(val)) if !val.is_empty() => { + Some((key.to_string(), val.to_string())) + } + _ => None, + } + }) + .collect() + } + + async fn kill_connection(killer_connection: &mut ClusterConnection, connection_to_kill: &str) { + let mut cmd = redis::cmd("CLIENT"); + cmd.arg("KILL"); + cmd.arg("ID"); + cmd.arg(connection_to_kill); + // Kill the management connection in the primary node that holds slot 0 + assert!(killer_connection + .route_command( + &cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::SpecificNode(Route::new( + 0, + SlotAddr::Master, + )),), + ) + .await + .is_ok()); + } + + #[test] + fn test_async_cluster_only_management_connection_is_reconnected_after_connection_failure() { + // This test will check two aspects: + // 1. Ensuring that after a disconnection in the management connection, a new management connection is established. + // 2. Confirming that a failure in the management connection does not impact the user connection, which should remain intact. + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.periodic_topology_checks(Duration::from_millis(10)), + false, + ); + block_on_all(async move { + let mut connection = cluster.async_connection(None).await; + let _client_list = "".to_string(); + let max_requests = 500; + let mut i = 0; + // Set the name of the client connection to 'user-connection', so we'll be able to identify it later on + assert!(cmd("CLIENT") + .arg("SETNAME") + .arg("user-connection") + .query_async::<_, Value>(&mut connection) + .await + .is_ok()); + // Get the client list + let names_to_ids = get_clients_names_to_ids(&mut connection, Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(0, SlotAddr::Master))))).await; + + // Get the connection ID of 'user-connection' + let user_conn_id = names_to_ids.get("user-connection").unwrap(); + // Get the connection ID of the management connection + let management_conn_id = names_to_ids.get(MANAGEMENT_CONN_NAME).unwrap(); + // Get another connection that will be used to kill the management connection + let mut killer_connection = cluster.async_connection(None).await; + kill_connection(&mut killer_connection, management_conn_id).await; + loop { + // In this loop we'll wait for the new management connection to be established + if i == max_requests { + break; + } else { + let names_to_ids = get_clients_names_to_ids(&mut connection, Some(RoutingInfo::SingleNode( + SingleNodeRoutingInfo::SpecificNode(Route::new(0, SlotAddr::Master))))).await; + if names_to_ids.contains_key(MANAGEMENT_CONN_NAME) { + // A management connection is found + let curr_management_conn_id = + names_to_ids.get(MANAGEMENT_CONN_NAME).unwrap(); + let curr_user_conn_id = + names_to_ids.get("user-connection").unwrap(); + // Confirm that the management connection has a new connection ID, and verify that the user connection remains unaffected. + if (curr_management_conn_id != management_conn_id) + && (curr_user_conn_id == user_conn_id) + { + return Ok::<_, RedisError>(()); + } + } else { + i += 1; + let _ = sleep(futures_time::time::Duration::from_millis(50)).await; + continue; + } + } + } + panic!( + "No reconnection of the management connection found, or there was an unwantedly reconnection of the user connections. + \nprev_management_conn_id={:?},prev_user_conn_id={:?}\nclient list={:?}", + management_conn_id, user_conn_id, names_to_ids + ); + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_dont_route_to_a_random_on_non_key_based_cmd() { + // This test verifies that non-key-based commands do not get routed to a random node + // when no connection is found for the given route. Instead, the appropriate error + // should be raised. + let name = "test_async_cluster_dont_route_to_a_random_on_non_key_based_cmd"; + let request_counter = Arc::new(AtomicU32::new(0)); + let cloned_req_counter = request_counter.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]).retries(1), + name, + move |received_cmd: &[u8], _| { + let slots_config_vec = vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0_u16..8000_u16), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + // Don't cover all slots + slot_range: (8001_u16..12000_u16), + }, + ]; + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + // If requests are sent to random nodes, they will be caught and counted here. + request_counter.fetch_add(1, Ordering::Relaxed); + Err(Ok(Value::Nil)) + }, + ); + + runtime + .block_on(async move { + let uncovered_slot = 16000; + let route = redis::cluster_routing::Route::new( + uncovered_slot, + redis::cluster_routing::SlotAddr::Master, + ); + let single_node_route = + redis::cluster_routing::SingleNodeRoutingInfo::SpecificNode(route); + let routing = RoutingInfo::SingleNode(single_node_route); + let res = connection + .route_command(&redis::cmd("FLUSHALL"), routing) + .await; + assert!(res.is_err()); + let res_err = res.unwrap_err(); + assert_eq!( + res_err.kind(), + ErrorKind::ConnectionNotFoundForRoute, + "{:?}", + res_err + ); + assert_eq!(cloned_req_counter.load(Ordering::Relaxed), 0); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_route_to_random_on_key_based_cmd() { + // This test verifies that key-based commands get routed to a random node + // when no connection is found for the given route. The command should + // then be redirected correctly by the server's MOVED error. + let name = "test_async_cluster_route_to_random_on_key_based_cmd"; + let request_counter = Arc::new(AtomicU32::new(0)); + let cloned_req_counter = request_counter.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + handler: _handler, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]), + name, + move |received_cmd: &[u8], _| { + let slots_config_vec = vec![ + MockSlotRange { + primary_port: 6379, + replica_ports: vec![], + slot_range: (0_u16..8000_u16), + }, + MockSlotRange { + primary_port: 6380, + replica_ports: vec![], + // Don't cover all slots + slot_range: (8001_u16..12000_u16), + }, + ]; + respond_startup_with_config(name, received_cmd, Some(slots_config_vec), false)?; + if contains_slice(received_cmd, b"GET") { + if request_counter.fetch_add(1, Ordering::Relaxed) == 0 { + return Err(parse_redis_value( + format!("-MOVED 12182 {name}:6380\r\n").as_bytes(), + )); + } else { + return Err(Ok(Value::SimpleString("bar".into()))); + } + } + panic!("unexpected command {:?}", received_cmd); + }, + ); + + runtime + .block_on(async move { + // The keyslot of "foo" is 12182 and it isn't covered by any node, so we expect the + // request to be routed to a random node and then to be redirected to the MOVED node (2 requests in total) + let res: String = connection.get("foo").await.unwrap(); + assert_eq!(res, "bar".to_string()); + assert_eq!(cloned_req_counter.load(Ordering::Relaxed), 2); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_do_not_retry_when_receiver_was_dropped() { + let name = "test_async_cluster_do_not_retry_when_receiver_was_dropped"; + let cmd = cmd("FAKE_COMMAND"); + let packed_cmd = cmd.get_packed_command(); + let request_counter = Arc::new(AtomicU32::new(0)); + let cloned_req_counter = request_counter.clone(); + let MockEnv { + runtime, + async_connection: mut connection, + .. + } = MockEnv::with_client_builder( + ClusterClient::builder(vec![&*format!("redis://{name}")]) + .retries(5) + .max_retry_wait(2) + .min_retry_wait(2), + name, + move |received_cmd: &[u8], _| { + respond_startup(name, received_cmd)?; + + if received_cmd == packed_cmd { + cloned_req_counter.fetch_add(1, Ordering::Relaxed); + return Err(Err((ErrorKind::TryAgain, "seriously, try again").into())); + } + + Err(Ok(Value::Okay)) + }, + ); + + runtime.block_on(async move { + let err = cmd + .query_async::<_, Value>(&mut connection) + .timeout(futures_time::time::Duration::from_millis(1)) + .await + .unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::TimedOut); + + // we sleep here, to allow the cluster connection time to retry. We expect it won't, but without this + // sleep the test will complete before the the runtime gave the connection time to retry, which would've made the + // test pass regardless of whether the connection tries retrying or not. + sleep(Duration::from_millis(10).into()).await; + }); + + assert_eq!(request_counter.load(Ordering::Relaxed), 1); + } + + #[cfg(feature = "tls-rustls")] + mod mtls_test { + use crate::support::mtls_test::create_cluster_client_from_cluster; + use redis::ConnectionInfo; + + use super::*; + + #[test] + fn test_async_cluster_basic_cmd_with_mtls() { + let cluster = TestClusterContext::new_with_mtls(3, 0); + block_on_all(async move { + let client = create_cluster_client_from_cluster(&cluster, true).unwrap(); + let mut connection = client.get_async_connection(None).await.unwrap(); + cmd("SET") + .arg("test") + .arg("test_data") + .query_async(&mut connection) + .await?; + let res: String = cmd("GET") + .arg("test") + .clone() + .query_async(&mut connection) + .await?; + assert_eq!(res, "test_data"); + Ok::<_, RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_async_cluster_should_not_connect_without_mtls_enabled() { + let cluster = TestClusterContext::new_with_mtls(3, 0); + block_on_all(async move { + let client = create_cluster_client_from_cluster(&cluster, false).unwrap(); + let connection = client.get_async_connection(None).await; + match cluster.cluster.servers.first().unwrap().connection_info() { + ConnectionInfo { + addr: redis::ConnectionAddr::TcpTls { .. }, + .. + } => { + if connection.is_ok() { + panic!("Must NOT be able to connect without client credentials if server accepts TLS"); + } + } + _ => { + if let Err(e) = connection { + panic!("Must be able to connect without client credentials if server does NOT accept TLS: {e:?}"); + } + } + } + Ok::<_, RedisError>(()) + }).unwrap(); + } + } +} diff --git a/glide-core/redis-rs/redis/tests/test_cluster_scan.rs b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs new file mode 100644 index 0000000000..29a3c87b48 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_cluster_scan.rs @@ -0,0 +1,849 @@ +#![cfg(feature = "cluster-async")] +mod support; + +#[cfg(test)] +mod test_cluster_scan_async { + use crate::support::*; + use rand::Rng; + use redis::cluster_routing::{RoutingInfo, SingleNodeRoutingInfo}; + use redis::{cmd, from_redis_value, ObjectType, RedisResult, ScanStateRC, Value}; + use std::time::Duration; + + async fn kill_one_node( + cluster: &TestClusterContext, + slot_distribution: Vec<(String, String, String, Vec>)>, + ) -> RoutingInfo { + let mut cluster_conn = cluster.async_connection(None).await; + let distribution_clone = slot_distribution.clone(); + let index_of_random_node = rand::thread_rng().gen_range(0..slot_distribution.len()); + let random_node = distribution_clone.get(index_of_random_node).unwrap(); + let random_node_route_info = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: random_node.1.clone(), + port: random_node.2.parse::().unwrap(), + }); + let random_node_id = &random_node.0; + // Create connections to all nodes + for node in &distribution_clone { + if random_node_id == &node.0 { + continue; + } + let node_route = RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: node.1.clone(), + port: node.2.parse::().unwrap(), + }); + + let mut forget_cmd = cmd("CLUSTER"); + forget_cmd.arg("FORGET").arg(random_node_id); + let _: RedisResult = cluster_conn + .route_command(&forget_cmd, node_route.clone()) + .await; + } + let mut shutdown_cmd = cmd("SHUTDOWN"); + shutdown_cmd.arg("NOSAVE"); + let _: RedisResult = cluster_conn + .route_command(&shutdown_cmd, random_node_route_info.clone()) + .await; + random_node_route_info + } + + #[tokio::test] + async fn test_async_cluster_scan() { + let cluster = TestClusterContext::new(3, 0); + let mut connection = cluster.async_connection(None).await; + + // Set some keys + for i in 0..10 { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + } + + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = vec![]; + loop { + let (next_cursor, scan_keys): (ScanStateRC, Vec) = connection + .cluster_scan(scan_state_rc, None, None) + .await + .unwrap(); + scan_state_rc = next_cursor; + let mut scan_keys = scan_keys + .into_iter() + .map(|v| from_redis_value(&v).unwrap()) + .collect::>(); // Change the type of `keys` to `Vec` + keys.append(&mut scan_keys); + if scan_state_rc.is_finished() { + break; + } + } + // Check if all keys were scanned + keys.sort(); + keys.dedup(); + for (i, key) in keys.iter().enumerate() { + assert_eq!(key.to_owned(), format!("key{}", i)); + } + } + + #[tokio::test] // test cluster scan with slot migration in the middle + async fn test_async_cluster_scan_with_migration() { + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.async_connection(None).await; + // Set some keys + let mut expected_keys: Vec = Vec::new(); + + for i in 0..1000 { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + } + + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + loop { + count += 1; + let (next_cursor, scan_keys): (ScanStateRC, Vec) = connection + .cluster_scan(scan_state_rc, None, None) + .await + .unwrap(); + scan_state_rc = next_cursor; + let scan_keys = scan_keys + .into_iter() + .map(|v| from_redis_value(&v).unwrap()) + .collect::>(); + keys.extend(scan_keys); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + let mut cluster_nodes = cluster.get_cluster_nodes().await; + let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + cluster + .migrate_slots_from_node_to_another(slot_distribution.clone()) + .await; + for node in &slot_distribution { + let ready = cluster + .wait_for_connection_is_ready(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { + host: node.1.clone(), + port: node.2.parse::().unwrap(), + }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + + cluster_nodes = cluster.get_cluster_nodes().await; + // Compare slot distribution before and after migration + let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + assert_ne!(slot_distribution, new_slot_distribution); + } + } + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + assert_eq!(keys, expected_keys); + } + + #[tokio::test] // test cluster scan with node fail in the middle + async fn test_async_cluster_scan_with_fail() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(1), + false, + ); + let mut connection = cluster.async_connection(None).await; + // Set some keys + for i in 0..1000 { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + } + + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + let mut result: RedisResult = Ok(Value::Nil); + loop { + count += 1; + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + let (next_cursor, scan_keys) = match scan_response { + Ok((cursor, keys)) => (cursor, keys), + Err(e) => { + result = Err(e); + break; + } + }; + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + let cluster_nodes = cluster.get_cluster_nodes().await; + let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + // simulate node failure + let killed_node_routing = kill_one_node(&cluster, slot_distribution.clone()).await; + let ready = cluster.wait_for_fail_to_finish(&killed_node_routing).await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + let cluster_nodes = cluster.get_cluster_nodes().await; + let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + assert_ne!(slot_distribution, new_slot_distribution); + } + } + // We expect an error of finding address + assert!(result.is_err()); + } + + #[tokio::test] // Test cluster scan with killing all masters during scan + async fn test_async_cluster_scan_with_all_masters_down() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| { + builder + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + .retries(1) + }, + false, + ); + + let mut connection = cluster.async_connection(None).await; + + let mut expected_keys: Vec = Vec::new(); + + cluster.wait_for_cluster_up(); + + let mut cluster_nodes = cluster.get_cluster_nodes().await; + + let slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + let masters = cluster.get_masters(&cluster_nodes).await; + let replicas = cluster.get_replicas(&cluster_nodes).await; + + for i in 0..1000 { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + } + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + loop { + count += 1; + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + if scan_response.is_err() { + println!("error: {:?}", scan_response); + } + let (next_cursor, scan_keys) = scan_response.unwrap(); + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + for replica in replicas.iter() { + let mut failover_cmd = cmd("CLUSTER"); + let _: RedisResult = connection + .route_command( + failover_cmd.arg("FAILOVER").arg("TAKEOVER"), + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }), + ) + .await; + let ready = cluster + .wait_for_connection_is_ready(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + + for master in masters.iter() { + for replica in replicas.clone() { + let mut forget_cmd = cmd("CLUSTER"); + forget_cmd.arg("FORGET").arg(master[0].clone()); + let _: RedisResult = connection + .route_command( + &forget_cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }), + ) + .await; + } + } + for master in masters.iter() { + let mut shut_cmd = cmd("SHUTDOWN"); + shut_cmd.arg("NOSAVE"); + let _ = connection + .route_command( + &shut_cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: master[1].clone(), + port: master[2].parse::().unwrap(), + }), + ) + .await; + let ready = cluster + .wait_for_fail_to_finish(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { + host: master[1].clone(), + port: master[2].parse::().unwrap(), + }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + for replica in replicas.iter() { + let ready = cluster + .wait_for_connection_is_ready(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + cluster_nodes = cluster.get_cluster_nodes().await; + let new_slot_distribution = cluster.get_slots_ranges_distribution(&cluster_nodes); + assert_ne!(slot_distribution, new_slot_distribution); + } + } + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + assert_eq!(keys, expected_keys); + } + + #[tokio::test] + // Test cluster scan with killing all replicas during scan + async fn test_async_cluster_scan_with_all_replicas_down() { + let cluster = TestClusterContext::new_with_cluster_client_builder( + 6, + 1, + |builder| { + builder + .slots_refresh_rate_limit(Duration::from_secs(0), 0) + .retries(1) + }, + false, + ); + + let mut connection = cluster.async_connection(None).await; + + let mut expected_keys: Vec = Vec::new(); + + for server in cluster.cluster.servers.iter() { + let address = server.addr.clone().to_string(); + let host_and_port = address.split(':'); + let host = host_and_port.clone().next().unwrap().to_string(); + let port = host_and_port + .clone() + .last() + .unwrap() + .parse::() + .unwrap(); + let ready = cluster + .wait_for_connection_is_ready(&RoutingInfo::SingleNode( + SingleNodeRoutingInfo::ByAddress { host, port }, + )) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + + let cluster_nodes = cluster.get_cluster_nodes().await; + + let replicas = cluster.get_replicas(&cluster_nodes).await; + + for i in 0..1000 { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + } + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + loop { + count += 1; + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + if scan_response.is_err() { + println!("error: {:?}", scan_response); + } + let (next_cursor, scan_keys) = scan_response.unwrap(); + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + for replica in replicas.iter() { + let mut shut_cmd = cmd("SHUTDOWN"); + shut_cmd.arg("NOSAVE"); + let ready: RedisResult = connection + .route_command( + &shut_cmd, + RoutingInfo::SingleNode(SingleNodeRoutingInfo::ByAddress { + host: replica[1].clone(), + port: replica[2].parse::().unwrap(), + }), + ) + .await; + match ready { + Ok(_) => {} + Err(e) => { + println!("error: {:?}", e); + break; + } + } + } + let new_cluster_nodes = cluster.get_cluster_nodes().await; + assert_ne!(cluster_nodes, new_cluster_nodes); + } + } + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + assert_eq!(keys, expected_keys); + } + #[tokio::test] + // Test cluster scan with setting keys for each iteration + async fn test_async_cluster_scan_set_in_the_middle() { + let cluster = TestClusterContext::new(3, 0); + let mut connection = cluster.async_connection(None).await; + let mut expected_keys: Vec = Vec::new(); + let mut i = 0; + // Set some keys + loop { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + i += 1; + if i == 1000 { + break; + } + } + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = vec![]; + loop { + let (next_cursor, scan_keys): (ScanStateRC, Vec) = connection + .cluster_scan(scan_state_rc, None, None) + .await + .unwrap(); + scan_state_rc = next_cursor; + let mut scan_keys = scan_keys + .into_iter() + .map(|v| from_redis_value(&v).unwrap()) + .collect::>(); // Change the type of `keys` to `Vec` + keys.append(&mut scan_keys); + if scan_state_rc.is_finished() { + break; + } + let key = format!("key{}", i); + i += 1; + let res: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + assert!(res.is_ok()); + } + // Check if all keys were scanned + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + for key in expected_keys.iter() { + assert!(keys.contains(key)); + } + assert!(keys.len() >= expected_keys.len()); + } + + #[tokio::test] + // Test cluster scan with deleting keys for each iteration + async fn test_async_cluster_scan_dell_in_the_middle() { + let cluster = TestClusterContext::new(3, 0); + + let mut connection = cluster.async_connection(None).await; + let mut expected_keys: Vec = Vec::new(); + let mut i = 0; + // Set some keys + loop { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + i += 1; + if i == 1000 { + break; + } + } + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = vec![]; + loop { + let (next_cursor, scan_keys): (ScanStateRC, Vec) = connection + .cluster_scan(scan_state_rc, None, None) + .await + .unwrap(); + scan_state_rc = next_cursor; + let mut scan_keys = scan_keys + .into_iter() + .map(|v| from_redis_value(&v).unwrap()) + .collect::>(); // Change the type of `keys` to `Vec` + keys.append(&mut scan_keys); + if scan_state_rc.is_finished() { + break; + } + i -= 1; + let key = format!("key{}", i); + + let res: Result<(), redis::RedisError> = redis::cmd("del") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + assert!(res.is_ok()); + expected_keys.remove(i as usize); + } + // Check if all keys were scanned + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + for key in expected_keys.iter() { + assert!(keys.contains(key)); + } + assert!(keys.len() >= expected_keys.len()); + } + + #[tokio::test] + // Testing cluster scan with Pattern option + async fn test_async_cluster_scan_with_pattern() { + let cluster = TestClusterContext::new(3, 0); + let mut connection = cluster.async_connection(None).await; + let mut expected_keys: Vec = Vec::new(); + let mut i = 0; + // Set some keys + loop { + let key = format!("key:pattern:{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + let non_relevant_key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&non_relevant_key) + .arg("value") + .query_async(&mut connection) + .await; + i += 1; + if i == 500 { + break; + } + } + + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = vec![]; + loop { + let (next_cursor, scan_keys): (ScanStateRC, Vec) = connection + .cluster_scan_with_pattern(scan_state_rc, "key:pattern:*", None, None) + .await + .unwrap(); + scan_state_rc = next_cursor; + let mut scan_keys = scan_keys + .into_iter() + .map(|v| from_redis_value(&v).unwrap()) + .collect::>(); // Change the type of `keys` to `Vec` + keys.append(&mut scan_keys); + if scan_state_rc.is_finished() { + break; + } + } + // Check if all keys were scanned + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + for key in expected_keys.iter() { + assert!(keys.contains(key)); + } + assert!(keys.len() == expected_keys.len()); + } + + #[tokio::test] + // Testing cluster scan with TYPE option + async fn test_async_cluster_scan_with_type() { + let cluster = TestClusterContext::new(3, 0); + let mut connection = cluster.async_connection(None).await; + let mut expected_keys: Vec = Vec::new(); + let mut i = 0; + // Set some keys + loop { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SADD") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + let key = format!("key-that-is-not-set{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + i += 1; + if i == 500 { + break; + } + } + + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = vec![]; + loop { + let (next_cursor, scan_keys): (ScanStateRC, Vec) = connection + .cluster_scan(scan_state_rc, None, Some(ObjectType::Set)) + .await + .unwrap(); + scan_state_rc = next_cursor; + let mut scan_keys = scan_keys + .into_iter() + .map(|v| from_redis_value(&v).unwrap()) + .collect::>(); // Change the type of `keys` to `Vec` + keys.append(&mut scan_keys); + if scan_state_rc.is_finished() { + break; + } + } + // Check if all keys were scanned + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + for key in expected_keys.iter() { + assert!(keys.contains(key)); + } + assert!(keys.len() == expected_keys.len()); + } + + #[tokio::test] + // Testing cluster scan with COUNT option + async fn test_async_cluster_scan_with_count() { + let cluster = TestClusterContext::new(3, 0); + let mut connection = cluster.async_connection(None).await; + let mut expected_keys: Vec = Vec::new(); + let mut i = 0; + // Set some keys + loop { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + expected_keys.push(key); + i += 1; + if i == 1000 { + break; + } + } + + // Scan the keys + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = vec![]; + let mut comparing_times = 0; + loop { + let (next_cursor, scan_keys): (ScanStateRC, Vec) = connection + .cluster_scan(scan_state_rc.clone(), Some(100), None) + .await + .unwrap(); + let (_, scan_without_count_keys): (ScanStateRC, Vec) = connection + .cluster_scan(scan_state_rc, Some(100), None) + .await + .unwrap(); + if !scan_keys.is_empty() && !scan_without_count_keys.is_empty() { + assert!(scan_keys.len() >= scan_without_count_keys.len()); + + comparing_times += 1; + } + scan_state_rc = next_cursor; + let mut scan_keys = scan_keys + .into_iter() + .map(|v| from_redis_value(&v).unwrap()) + .collect::>(); // Change the type of `keys` to `Vec` + keys.append(&mut scan_keys); + if scan_state_rc.is_finished() { + break; + } + } + assert!(comparing_times > 0); + // Check if all keys were scanned + keys.sort(); + keys.dedup(); + expected_keys.sort(); + expected_keys.dedup(); + // check if all keys were scanned + for key in expected_keys.iter() { + assert!(keys.contains(key)); + } + assert!(keys.len() == expected_keys.len()); + } + + #[tokio::test] + // Testing cluster scan when connection fails in the middle and we get an error + // then cluster up again and scanning can continue without any problem + async fn test_async_cluster_scan_failover() { + let mut cluster = TestClusterContext::new_with_cluster_client_builder( + 3, + 0, + |builder| builder.retries(1), + false, + ); + let mut connection = cluster.async_connection(None).await; + let mut i = 0; + loop { + let key = format!("key{}", i); + let _: Result<(), redis::RedisError> = redis::cmd("SET") + .arg(&key) + .arg("value") + .query_async(&mut connection) + .await; + i += 1; + if i == 1000 { + break; + } + } + let mut scan_state_rc = ScanStateRC::new(); + let mut keys: Vec = Vec::new(); + let mut count = 0; + loop { + count += 1; + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + if scan_response.is_err() { + println!("error: {:?}", scan_response); + } + let (next_cursor, scan_keys) = scan_response.unwrap(); + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + if count == 5 { + drop(cluster); + let scan_response: RedisResult<(ScanStateRC, Vec)> = connection + .cluster_scan(scan_state_rc.clone(), None, None) + .await; + assert!(scan_response.is_err()); + break; + }; + } + cluster = TestClusterContext::new(3, 0); + connection = cluster.async_connection(None).await; + loop { + let scan_response: RedisResult<(ScanStateRC, Vec)> = + connection.cluster_scan(scan_state_rc, None, None).await; + if scan_response.is_err() { + println!("error: {:?}", scan_response); + } + let (next_cursor, scan_keys) = scan_response.unwrap(); + scan_state_rc = next_cursor; + keys.extend(scan_keys.into_iter().map(|v| from_redis_value(&v).unwrap())); + if scan_state_rc.is_finished() { + break; + } + } + } +} diff --git a/glide-core/redis-rs/redis/tests/test_geospatial.rs b/glide-core/redis-rs/redis/tests/test_geospatial.rs new file mode 100644 index 0000000000..8bec9a1d73 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_geospatial.rs @@ -0,0 +1,197 @@ +#![cfg(feature = "geospatial")] + +use assert_approx_eq::assert_approx_eq; + +use redis::geo::{Coord, RadiusOptions, RadiusOrder, RadiusSearchResult, Unit}; +use redis::{Commands, RedisResult}; + +mod support; +use crate::support::*; + +const PALERMO: (&str, &str, &str) = ("13.361389", "38.115556", "Palermo"); +const CATANIA: (&str, &str, &str) = ("15.087269", "37.502669", "Catania"); +const AGRIGENTO: (&str, &str, &str) = ("13.5833332", "37.316667", "Agrigento"); + +#[test] +fn test_geoadd_single_tuple() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", PALERMO), Ok(1)); +} + +#[test] +fn test_geoadd_multiple_tuples() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2)); +} + +#[test] +fn test_geodist_existing_members() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2)); + + let dist: f64 = con + .geo_dist("my_gis", PALERMO.2, CATANIA.2, Unit::Kilometers) + .unwrap(); + assert_approx_eq!(dist, 166.2742, 0.001); +} + +#[test] +fn test_geodist_support_option() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2)); + + // We should be able to extract the value as an Option<_>, so we can detect + // if a member is missing + + let result: RedisResult> = con.geo_dist("my_gis", PALERMO.2, "none", Unit::Meters); + assert_eq!(result, Ok(None)); + + let result: RedisResult> = + con.geo_dist("my_gis", PALERMO.2, CATANIA.2, Unit::Meters); + assert_ne!(result, Ok(None)); + + let dist = result.unwrap().unwrap(); + assert_approx_eq!(dist, 166_274.151_6, 0.01); +} + +#[test] +fn test_geohash() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2)); + let result: RedisResult> = con.geo_hash("my_gis", PALERMO.2); + assert_eq!(result, Ok(vec![String::from("sqc8b49rny0")])); + + let result: RedisResult> = con.geo_hash("my_gis", &[PALERMO.2, CATANIA.2]); + assert_eq!( + result, + Ok(vec![ + String::from("sqc8b49rny0"), + String::from("sqdtr74hyu0"), + ]) + ); +} + +#[test] +fn test_geopos() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2)); + + let result: Vec> = con.geo_pos("my_gis", &[PALERMO.2]).unwrap(); + assert_eq!(result.len(), 1); + + assert_approx_eq!(result[0][0], 13.36138, 0.0001); + assert_approx_eq!(result[0][1], 38.11555, 0.0001); + + // Using the Coord struct + let result: Vec> = con.geo_pos("my_gis", &[PALERMO.2, CATANIA.2]).unwrap(); + assert_eq!(result.len(), 2); + + assert_approx_eq!(result[0].longitude, 13.36138, 0.0001); + assert_approx_eq!(result[0].latitude, 38.11555, 0.0001); + + assert_approx_eq!(result[1].longitude, 15.08726, 0.0001); + assert_approx_eq!(result[1].latitude, 37.50266, 0.0001); +} + +#[test] +fn test_use_coord_struct() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!( + con.geo_add( + "my_gis", + (Coord::lon_lat(13.361_389, 38.115_556), "Palermo") + ), + Ok(1) + ); + + let result: Vec> = con.geo_pos("my_gis", "Palermo").unwrap(); + assert_eq!(result.len(), 1); + + assert_approx_eq!(result[0].longitude, 13.36138, 0.0001); + assert_approx_eq!(result[0].latitude, 38.11555, 0.0001); +} + +#[test] +fn test_georadius() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA]), Ok(2)); + + let mut geo_radius = |opts: RadiusOptions| -> Vec { + con.geo_radius("my_gis", 15.0, 37.0, 200.0, Unit::Kilometers, opts) + .unwrap() + }; + + // Simple request, without extra data + let mut result = geo_radius(RadiusOptions::default()); + result.sort_by(|a, b| Ord::cmp(&a.name, &b.name)); + + assert_eq!(result.len(), 2); + + assert_eq!(result[0].name.as_str(), "Catania"); + assert_eq!(result[0].coord, None); + assert_eq!(result[0].dist, None); + + assert_eq!(result[1].name.as_str(), "Palermo"); + assert_eq!(result[1].coord, None); + assert_eq!(result[1].dist, None); + + // Get data with multiple fields + let result = geo_radius(RadiusOptions::default().with_dist().order(RadiusOrder::Asc)); + + assert_eq!(result.len(), 2); + + assert_eq!(result[0].name.as_str(), "Catania"); + assert_eq!(result[0].coord, None); + assert_approx_eq!(result[0].dist.unwrap(), 56.4413, 0.001); + + assert_eq!(result[1].name.as_str(), "Palermo"); + assert_eq!(result[1].coord, None); + assert_approx_eq!(result[1].dist.unwrap(), 190.4424, 0.001); + + let result = geo_radius( + RadiusOptions::default() + .with_coord() + .order(RadiusOrder::Desc) + .limit(1), + ); + + assert_eq!(result.len(), 1); + + assert_eq!(result[0].name.as_str(), "Palermo"); + assert_approx_eq!(result[0].coord.as_ref().unwrap().longitude, 13.361_389); + assert_approx_eq!(result[0].coord.as_ref().unwrap().latitude, 38.115_556); + assert_eq!(result[0].dist, None); +} + +#[test] +fn test_georadius_by_member() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + assert_eq!(con.geo_add("my_gis", &[PALERMO, CATANIA, AGRIGENTO]), Ok(3)); + + // Simple request, without extra data + let opts = RadiusOptions::default().order(RadiusOrder::Asc); + let result: Vec = con + .geo_radius_by_member("my_gis", AGRIGENTO.2, 100.0, Unit::Kilometers, opts) + .unwrap(); + let names: Vec<_> = result.iter().map(|c| c.name.as_str()).collect(); + + assert_eq!(names, vec!["Agrigento", "Palermo"]); +} diff --git a/glide-core/redis-rs/redis/tests/test_module_json.rs b/glide-core/redis-rs/redis/tests/test_module_json.rs new file mode 100644 index 0000000000..08fed23930 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_module_json.rs @@ -0,0 +1,540 @@ +#![cfg(feature = "json")] + +use std::assert_eq; +use std::collections::HashMap; + +use redis::{JsonCommands, ProtocolVersion}; + +use redis::{ + ErrorKind, RedisError, RedisResult, + Value::{self, *}, +}; + +use crate::support::*; +mod support; + +use serde::Serialize; +// adds json! macro for quick json generation on the fly. +use serde_json::json; + +const TEST_KEY: &str = "my_json"; + +const MTLS_NOT_ENABLED: bool = false; + +#[test] +fn test_module_json_serialize_error() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + #[derive(Debug, Serialize)] + struct InvalidSerializedStruct { + // Maps in serde_json must have string-like keys + // so numbers and strings, anything else will cause the serialization to fail + // this is basically the only way to make a serialization fail at runtime + // since rust doesnt provide the necessary ability to enforce this + pub invalid_json: HashMap, i64>, + } + + let mut test_invalid_value: InvalidSerializedStruct = InvalidSerializedStruct { + invalid_json: HashMap::new(), + }; + + test_invalid_value.invalid_json.insert(None, 2i64); + + let set_invalid: RedisResult = con.json_set(TEST_KEY, "$", &test_invalid_value); + + assert_eq!( + set_invalid, + Err(RedisError::from(( + ErrorKind::Serialize, + "Serialization Error", + String::from("key must be string") + ))) + ); +} + +#[test] +fn test_module_json_arr_append() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":[1i64], "nested": {"a": [1i64, 2i64]}, "nested2": {"a": 42i64}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_append: RedisResult = con.json_arr_append(TEST_KEY, "$..a", &3i64); + + assert_eq!(json_append, Ok(Array(vec![Int(2i64), Int(3i64), Nil]))); +} + +#[test] +fn test_module_json_arr_index() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":[1i64, 2i64, 3i64, 2i64], "nested": {"a": [3i64, 4i64]}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_arrindex: RedisResult = con.json_arr_index(TEST_KEY, "$..a", &2i64); + + assert_eq!(json_arrindex, Ok(Array(vec![Int(1i64), Int(-1i64)]))); + + let update_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":[1i64, 2i64, 3i64, 2i64], "nested": {"a": false}}), + ); + + assert_eq!(update_initial, Ok(true)); + + let json_arrindex_2: RedisResult = + con.json_arr_index_ss(TEST_KEY, "$..a", &2i64, &0, &0); + + assert_eq!(json_arrindex_2, Ok(Array(vec![Int(1i64), Nil]))); +} + +#[test] +fn test_module_json_arr_insert() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":[3i64], "nested": {"a": [3i64 ,4i64]}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_arrinsert: RedisResult = con.json_arr_insert(TEST_KEY, "$..a", 0, &1i64); + + assert_eq!(json_arrinsert, Ok(Array(vec![Int(2), Int(3)]))); + + let update_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":[1i64 ,2i64 ,3i64 ,2i64], "nested": {"a": false}}), + ); + + assert_eq!(update_initial, Ok(true)); + + let json_arrinsert_2: RedisResult = con.json_arr_insert(TEST_KEY, "$..a", 0, &1i64); + + assert_eq!(json_arrinsert_2, Ok(Array(vec![Int(5), Nil]))); +} + +#[test] +fn test_module_json_arr_len() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a": [3i64], "nested": {"a": [3i64, 4i64]}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_arrlen: RedisResult = con.json_arr_len(TEST_KEY, "$..a"); + + assert_eq!(json_arrlen, Ok(Array(vec![Int(1), Int(2)]))); + + let update_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a": [1i64, 2i64, 3i64, 2i64], "nested": {"a": false}}), + ); + + assert_eq!(update_initial, Ok(true)); + + let json_arrlen_2: RedisResult = con.json_arr_len(TEST_KEY, "$..a"); + + assert_eq!(json_arrlen_2, Ok(Array(vec![Int(4), Nil]))); +} + +#[test] +fn test_module_json_arr_pop() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a": [3i64], "nested": {"a": [3i64, 4i64]}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_arrpop: RedisResult = con.json_arr_pop(TEST_KEY, "$..a", -1); + + assert_eq!( + json_arrpop, + Ok(Array(vec![ + // convert string 3 to its ascii value as bytes + BulkString(Vec::from("3".as_bytes())), + BulkString(Vec::from("4".as_bytes())) + ])) + ); + + let update_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":["foo", "bar"], "nested": {"a": false}, "nested2": {"a":[]}}), + ); + + assert_eq!(update_initial, Ok(true)); + + let json_arrpop_2: RedisResult = con.json_arr_pop(TEST_KEY, "$..a", -1); + + assert_eq!( + json_arrpop_2, + Ok(Array(vec![ + BulkString(Vec::from("\"bar\"".as_bytes())), + Nil, + Nil + ])) + ); +} + +#[test] +fn test_module_json_arr_trim() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a": [], "nested": {"a": [1i64, 4u64]}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_arrtrim: RedisResult = con.json_arr_trim(TEST_KEY, "$..a", 1, 1); + + assert_eq!(json_arrtrim, Ok(Array(vec![Int(0), Int(1)]))); + + let update_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a": [1i64, 2i64, 3i64, 4i64], "nested": {"a": false}}), + ); + + assert_eq!(update_initial, Ok(true)); + + let json_arrtrim_2: RedisResult = con.json_arr_trim(TEST_KEY, "$..a", 1, 1); + + assert_eq!(json_arrtrim_2, Ok(Array(vec![Int(1), Nil]))); +} + +#[test] +fn test_module_json_clear() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set(TEST_KEY, "$", &json!({"obj": {"a": 1i64, "b": 2i64}, "arr": [1i64, 2i64, 3i64], "str": "foo", "bool": true, "int": 42i64, "float": std::f64::consts::PI})); + + assert_eq!(set_initial, Ok(true)); + + let json_clear: RedisResult = con.json_clear(TEST_KEY, "$.*"); + + assert_eq!(json_clear, Ok(4)); + + let checking_value: RedisResult = con.json_get(TEST_KEY, "$"); + + // float is set to 0 and serde_json serializes 0f64 to 0.0, which is a different string + assert_eq!( + checking_value, + // i found it changes the order? + // its not reallt a problem if you're just deserializing it anyway but still + // kinda weird + Ok("[{\"arr\":[],\"bool\":true,\"float\":0,\"int\":0,\"obj\":{},\"str\":\"foo\"}]".into()) + ); +} + +#[test] +fn test_module_json_del() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a": 1i64, "nested": {"a": 2i64, "b": 3i64}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_del: RedisResult = con.json_del(TEST_KEY, "$..a"); + + assert_eq!(json_del, Ok(2)); +} + +#[test] +fn test_module_json_get() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":2i64, "b": 3i64, "nested": {"a": 4i64, "b": null}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_get: RedisResult = con.json_get(TEST_KEY, "$..b"); + + assert_eq!(json_get, Ok("[3,null]".into())); + + let json_get_multi: RedisResult = con.json_get(TEST_KEY, vec!["..a", "$..b"]); + + if json_get_multi != Ok("{\"$..b\":[3,null],\"..a\":[2,4]}".into()) + && json_get_multi != Ok("{\"..a\":[2,4],\"$..b\":[3,null]}".into()) + { + panic!("test_error: incorrect response from json_get_multi"); + } +} + +#[test] +fn test_module_json_mget() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial_a: RedisResult = con.json_set( + format!("{TEST_KEY}-a"), + "$", + &json!({"a":1i64, "b": 2i64, "nested": {"a": 3i64, "b": null}}), + ); + let set_initial_b: RedisResult = con.json_set( + format!("{TEST_KEY}-b"), + "$", + &json!({"a":4i64, "b": 5i64, "nested": {"a": 6i64, "b": null}}), + ); + + assert_eq!(set_initial_a, Ok(true)); + assert_eq!(set_initial_b, Ok(true)); + + let json_mget: RedisResult = con.json_get( + vec![format!("{TEST_KEY}-a"), format!("{TEST_KEY}-b")], + "$..a", + ); + + assert_eq!( + json_mget, + Ok(Array(vec![ + BulkString(Vec::from("[1,3]".as_bytes())), + BulkString(Vec::from("[4,6]".as_bytes())) + ])) + ); +} + +#[test] +fn test_module_json_num_incr_by() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":"b","b":[{"a":2i64}, {"a":5i64}, {"a":"c"}]}), + ); + + assert_eq!(set_initial, Ok(true)); + + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + if ctx.protocol != ProtocolVersion::RESP2 && redis_ver.starts_with("7.") { + // cannot increment a string + let json_numincrby_a: RedisResult> = con.json_num_incr_by(TEST_KEY, "$.a", 2); + assert_eq!(json_numincrby_a, Ok(vec![Nil])); + + let json_numincrby_b: RedisResult> = con.json_num_incr_by(TEST_KEY, "$..a", 2); + + // however numbers can be incremented + assert_eq!(json_numincrby_b, Ok(vec![Nil, Int(4), Int(7), Nil])); + } else { + // cannot increment a string + let json_numincrby_a: RedisResult = con.json_num_incr_by(TEST_KEY, "$.a", 2); + assert_eq!(json_numincrby_a, Ok("[null]".into())); + + let json_numincrby_b: RedisResult = con.json_num_incr_by(TEST_KEY, "$..a", 2); + + // however numbers can be incremented + assert_eq!(json_numincrby_b, Ok("[null,4,7,null]".into())); + } +} + +#[test] +fn test_module_json_obj_keys() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":[3i64], "nested": {"a": {"b":2i64, "c": 1i64}}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_objkeys: RedisResult = con.json_obj_keys(TEST_KEY, "$..a"); + + assert_eq!( + json_objkeys, + Ok(Array(vec![ + Nil, + Array(vec![ + BulkString(Vec::from("b".as_bytes())), + BulkString(Vec::from("c".as_bytes())) + ]) + ])) + ); +} + +#[test] +fn test_module_json_obj_len() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":[3i64], "nested": {"a": {"b":2i64, "c": 1i64}}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_objlen: RedisResult = con.json_obj_len(TEST_KEY, "$..a"); + + assert_eq!(json_objlen, Ok(Array(vec![Nil, Int(2)]))); +} + +#[test] +fn test_module_json_set() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set: RedisResult = con.json_set(TEST_KEY, "$", &json!({"key": "value"})); + + assert_eq!(set, Ok(true)); +} + +#[test] +fn test_module_json_str_append() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31i64}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_strappend: RedisResult = con.json_str_append(TEST_KEY, "$..a", "\"baz\""); + + assert_eq!(json_strappend, Ok(Array(vec![Int(6), Int(8), Nil]))); + + let json_get_check: RedisResult = con.json_get(TEST_KEY, "$"); + + assert_eq!( + json_get_check, + Ok("[{\"a\":\"foobaz\",\"nested\":{\"a\":\"hellobaz\"},\"nested2\":{\"a\":31}}]".into()) + ); +} + +#[test] +fn test_module_json_str_len() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":"foo", "nested": {"a": "hello"}, "nested2": {"a": 31i32}}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_strlen: RedisResult = con.json_str_len(TEST_KEY, "$..a"); + + assert_eq!(json_strlen, Ok(Array(vec![Int(3), Int(5), Nil]))); +} + +#[test] +fn test_module_json_toggle() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set(TEST_KEY, "$", &json!({"bool": true})); + + assert_eq!(set_initial, Ok(true)); + + let json_toggle_a: RedisResult = con.json_toggle(TEST_KEY, "$.bool"); + assert_eq!(json_toggle_a, Ok(Array(vec![Int(0)]))); + + let json_toggle_b: RedisResult = con.json_toggle(TEST_KEY, "$.bool"); + assert_eq!(json_toggle_b, Ok(Array(vec![Int(1)]))); +} + +#[test] +fn test_module_json_type() { + let ctx = TestContext::with_modules(&[Module::Json], MTLS_NOT_ENABLED); + let mut con = ctx.connection(); + + let set_initial: RedisResult = con.json_set( + TEST_KEY, + "$", + &json!({"a":2i64, "nested": {"a": true}, "foo": "bar"}), + ); + + assert_eq!(set_initial, Ok(true)); + + let json_type_a: RedisResult = con.json_type(TEST_KEY, "$..foo"); + let json_type_b: RedisResult = con.json_type(TEST_KEY, "$..a"); + let json_type_c: RedisResult = con.json_type(TEST_KEY, "$..dummy"); + + let redis_ver = std::env::var("REDIS_VERSION").unwrap_or_default(); + if ctx.protocol != ProtocolVersion::RESP2 && redis_ver.starts_with("7.") { + // In RESP3 current RedisJSON always gives response in an array. + assert_eq!( + json_type_a, + Ok(Array(vec![Array(vec![BulkString(Vec::from( + "string".as_bytes() + ))])])) + ); + + assert_eq!( + json_type_b, + Ok(Array(vec![Array(vec![ + BulkString(Vec::from("integer".as_bytes())), + BulkString(Vec::from("boolean".as_bytes())) + ])])) + ); + assert_eq!(json_type_c, Ok(Array(vec![Array(vec![])]))); + } else { + assert_eq!( + json_type_a, + Ok(Array(vec![BulkString(Vec::from("string".as_bytes()))])) + ); + + assert_eq!( + json_type_b, + Ok(Array(vec![ + BulkString(Vec::from("integer".as_bytes())), + BulkString(Vec::from("boolean".as_bytes())) + ])) + ); + assert_eq!(json_type_c, Ok(Array(vec![]))); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_sentinel.rs b/glide-core/redis-rs/redis/tests/test_sentinel.rs new file mode 100644 index 0000000000..24cd13bd67 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_sentinel.rs @@ -0,0 +1,496 @@ +#![allow(unknown_lints, dependency_on_unit_never_type_fallback)] +#![cfg(feature = "sentinel")] +mod support; + +use std::collections::HashMap; + +use redis::{ + sentinel::{Sentinel, SentinelClient, SentinelNodeConnectionInfo}, + Client, Connection, ConnectionAddr, ConnectionInfo, +}; + +use crate::support::*; + +fn parse_replication_info(value: &str) -> HashMap<&str, &str> { + let info_map: std::collections::HashMap<&str, &str> = value + .split("\r\n") + .filter(|line| !line.trim_start().starts_with('#')) + .filter_map(|line| line.split_once(':')) + .collect(); + info_map +} + +fn assert_is_master_role(replication_info: String) { + let info_map = parse_replication_info(&replication_info); + assert_eq!(info_map.get("role"), Some(&"master")); +} + +fn assert_replica_role_and_master_addr(replication_info: String, expected_master: &ConnectionInfo) { + let info_map = parse_replication_info(&replication_info); + + assert_eq!(info_map.get("role"), Some(&"slave")); + + let (master_host, master_port) = match &expected_master.addr { + ConnectionAddr::Tcp(host, port) => (host, port), + ConnectionAddr::TcpTls { + host, + port, + insecure: _, + tls_params: _, + } => (host, port), + ConnectionAddr::Unix(..) => panic!("Unexpected master connection type"), + }; + + assert_eq!(info_map.get("master_host"), Some(&master_host.as_str())); + assert_eq!( + info_map.get("master_port"), + Some(&master_port.to_string().as_str()) + ); +} + +fn assert_is_connection_to_master(conn: &mut Connection) { + let info: String = redis::cmd("INFO").arg("REPLICATION").query(conn).unwrap(); + assert_is_master_role(info); +} + +fn assert_connection_is_replica_of_correct_master(conn: &mut Connection, master_client: &Client) { + let info: String = redis::cmd("INFO").arg("REPLICATION").query(conn).unwrap(); + assert_replica_role_and_master_addr(info, master_client.get_connection_info()); +} + +/// Get replica clients from the sentinel in a rotating fashion, asserting that they are +/// indeed replicas of the given master, and returning a list of their addresses. +fn connect_to_all_replicas( + sentinel: &mut Sentinel, + master_name: &str, + master_client: &Client, + node_conn_info: &SentinelNodeConnectionInfo, + number_of_replicas: u16, +) -> Vec { + let mut replica_conn_infos = vec![]; + + for _ in 0..number_of_replicas { + let replica_client = sentinel + .replica_rotate_for(master_name, Some(node_conn_info)) + .unwrap(); + let mut replica_con = replica_client.get_connection(None).unwrap(); + + assert!(!replica_conn_infos.contains(&replica_client.get_connection_info().addr)); + replica_conn_infos.push(replica_client.get_connection_info().addr.clone()); + + assert_connection_is_replica_of_correct_master(&mut replica_con, master_client); + } + + replica_conn_infos +} + +fn assert_connect_to_known_replicas( + sentinel: &mut Sentinel, + replica_conn_infos: &[ConnectionAddr], + master_name: &str, + master_client: &Client, + node_conn_info: &SentinelNodeConnectionInfo, + number_of_connections: u32, +) { + for _ in 0..number_of_connections { + let replica_client = sentinel + .replica_rotate_for(master_name, Some(node_conn_info)) + .unwrap(); + let mut replica_con = replica_client.get_connection(None).unwrap(); + + assert!(replica_conn_infos.contains(&replica_client.get_connection_info().addr)); + + assert_connection_is_replica_of_correct_master(&mut replica_con, master_client); + } +} + +#[test] +fn test_sentinel_connect_to_random_replica() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let node_conn_info: SentinelNodeConnectionInfo = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + + let master_client = sentinel + .master_for(master_name, Some(&node_conn_info)) + .unwrap(); + let mut master_con = master_client.get_connection(None).unwrap(); + + let mut replica_con = sentinel + .replica_for(master_name, Some(&node_conn_info)) + .unwrap() + .get_connection(None) + .unwrap(); + + assert_is_connection_to_master(&mut master_con); + assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client); +} + +#[test] +fn test_sentinel_connect_to_multiple_replicas() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = cluster.sentinel_node_connection_info(); + let sentinel = cluster.sentinel_mut(); + + let master_client = sentinel + .master_for(master_name, Some(&node_conn_info)) + .unwrap(); + let mut master_con = master_client.get_connection(None).unwrap(); + + assert_is_connection_to_master(&mut master_con); + + let replica_conn_infos = connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ); + + assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ); +} + +#[test] +fn test_sentinel_server_down() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + + let master_client = sentinel + .master_for(master_name, Some(&node_conn_info)) + .unwrap(); + let mut master_con = master_client.get_connection(None).unwrap(); + + assert_is_connection_to_master(&mut master_con); + + context.cluster.sentinel_servers[0].stop(); + std::thread::sleep(std::time::Duration::from_millis(25)); + + let sentinel = context.sentinel_mut(); + + let replica_conn_infos = connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ); + + assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ); +} + +#[test] +fn test_sentinel_client() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let mut master_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Master, + ) + .unwrap(); + + let mut replica_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Replica, + ) + .unwrap(); + + let mut master_con = master_client.get_connection().unwrap(); + + assert_is_connection_to_master(&mut master_con); + + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + let master_client = sentinel + .master_for(master_name, Some(&node_conn_info)) + .unwrap(); + + for _ in 0..20 { + let mut replica_con = replica_client.get_connection().unwrap(); + + assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client); + } +} + +#[cfg(feature = "aio")] +pub mod async_tests { + use redis::{ + aio::MultiplexedConnection, + sentinel::{Sentinel, SentinelClient, SentinelNodeConnectionInfo}, + Client, ConnectionAddr, GlideConnectionOptions, RedisError, + }; + + use crate::{assert_is_master_role, assert_replica_role_and_master_addr, support::*}; + + async fn async_assert_is_connection_to_master(conn: &mut MultiplexedConnection) { + let info: String = redis::cmd("INFO") + .arg("REPLICATION") + .query_async(conn) + .await + .unwrap(); + + assert_is_master_role(info); + } + + async fn async_assert_connection_is_replica_of_correct_master( + conn: &mut MultiplexedConnection, + master_client: &Client, + ) { + let info: String = redis::cmd("INFO") + .arg("REPLICATION") + .query_async(conn) + .await + .unwrap(); + + assert_replica_role_and_master_addr(info, master_client.get_connection_info()); + } + + /// Async version of connect_to_all_replicas + async fn async_connect_to_all_replicas( + sentinel: &mut Sentinel, + master_name: &str, + master_client: &Client, + node_conn_info: &SentinelNodeConnectionInfo, + number_of_replicas: u16, + ) -> Vec { + let mut replica_conn_infos = vec![]; + + for _ in 0..number_of_replicas { + let replica_client = sentinel + .async_replica_rotate_for(master_name, Some(node_conn_info)) + .await + .unwrap(); + let mut replica_con = replica_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + .unwrap(); + + assert!( + !replica_conn_infos.contains(&replica_client.get_connection_info().addr), + "pushing {:?} into {:?}", + replica_client.get_connection_info().addr, + replica_conn_infos + ); + replica_conn_infos.push(replica_client.get_connection_info().addr.clone()); + + async_assert_connection_is_replica_of_correct_master(&mut replica_con, master_client) + .await; + } + + replica_conn_infos + } + + async fn async_assert_connect_to_known_replicas( + sentinel: &mut Sentinel, + replica_conn_infos: &[ConnectionAddr], + master_name: &str, + master_client: &Client, + node_conn_info: &SentinelNodeConnectionInfo, + number_of_connections: u32, + ) { + for _ in 0..number_of_connections { + let replica_client = sentinel + .async_replica_rotate_for(master_name, Some(node_conn_info)) + .await + .unwrap(); + let mut replica_con = replica_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await + .unwrap(); + + assert!(replica_conn_infos.contains(&replica_client.get_connection_info().addr)); + + async_assert_connection_is_replica_of_correct_master(&mut replica_con, master_client) + .await; + } + } + + #[test] + fn test_sentinel_connect_to_random_replica_async() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + + block_on_all(async move { + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + let mut replica_con = sentinel + .async_replica_for(master_name, Some(&node_conn_info)) + .await? + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + async_assert_connection_is_replica_of_correct_master(&mut replica_con, &master_client) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_connect_to_multiple_replicas_async() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut cluster = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = cluster.sentinel_node_connection_info(); + let sentinel = cluster.sentinel_mut(); + + block_on_all(async move { + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + let replica_conn_infos = async_connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ) + .await; + + async_assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_server_down_async() { + let number_of_replicas = 3; + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, number_of_replicas, 3); + let node_conn_info = context.sentinel_node_connection_info(); + + block_on_all(async move { + let sentinel = context.sentinel_mut(); + + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + let mut master_con = master_client + .get_multiplexed_async_connection(GlideConnectionOptions::default()) + .await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + context.cluster.sentinel_servers[0].stop(); + std::thread::sleep(std::time::Duration::from_millis(25)); + + let sentinel = context.sentinel_mut(); + + let replica_conn_infos = async_connect_to_all_replicas( + sentinel, + master_name, + &master_client, + &node_conn_info, + number_of_replicas, + ) + .await; + + async_assert_connect_to_known_replicas( + sentinel, + &replica_conn_infos, + master_name, + &master_client, + &node_conn_info, + 10, + ) + .await; + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } + + #[test] + fn test_sentinel_client_async() { + let master_name = "master1"; + let mut context = TestSentinelContext::new(2, 3, 3); + let mut master_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Master, + ) + .unwrap(); + + let mut replica_client = SentinelClient::build( + context.sentinels_connection_info().clone(), + String::from(master_name), + Some(context.sentinel_node_connection_info()), + redis::sentinel::SentinelServerType::Replica, + ) + .unwrap(); + + block_on_all(async move { + let mut master_con = master_client.get_async_connection().await?; + + async_assert_is_connection_to_master(&mut master_con).await; + + let node_conn_info = context.sentinel_node_connection_info(); + let sentinel = context.sentinel_mut(); + let master_client = sentinel + .async_master_for(master_name, Some(&node_conn_info)) + .await?; + + // Read commands to the replica node + for _ in 0..20 { + let mut replica_con = replica_client.get_async_connection().await?; + + async_assert_connection_is_replica_of_correct_master( + &mut replica_con, + &master_client, + ) + .await; + } + + Ok::<(), RedisError>(()) + }) + .unwrap(); + } +} diff --git a/glide-core/redis-rs/redis/tests/test_streams.rs b/glide-core/redis-rs/redis/tests/test_streams.rs new file mode 100644 index 0000000000..bf06028b95 --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_streams.rs @@ -0,0 +1,627 @@ +#![cfg(feature = "streams")] + +use redis::streams::*; +use redis::{Commands, Connection, RedisResult, ToRedisArgs}; + +mod support; +use crate::support::*; + +use std::collections::BTreeMap; +use std::str; +use std::thread::sleep; +use std::time::Duration; + +fn xadd(con: &mut Connection) { + let _: RedisResult = + con.xadd("k1", "1000-0", &[("hello", "world"), ("redis", "streams")]); + let _: RedisResult = con.xadd("k1", "1000-1", &[("hello", "world2")]); + let _: RedisResult = con.xadd("k2", "2000-0", &[("hello", "world")]); + let _: RedisResult = con.xadd("k2", "2000-1", &[("hello", "world2")]); +} + +fn xadd_keyrange(con: &mut Connection, key: &str, start: i32, end: i32) { + for _i in start..end { + let _: RedisResult = con.xadd(key, "*", &[("h", "w")]); + } +} + +#[test] +fn test_cmd_options() { + // Tests the following command option builders.... + // xclaim_options + // xread_options + // maxlen enum + + // test read options + + let empty = StreamClaimOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let empty = StreamReadOptions::default(); + assert_eq!(ToRedisArgs::to_redis_args(&empty).len(), 0); + + let opts = StreamClaimOptions::default() + .idle(50) + .time(500) + .retry(3) + .with_force() + .with_justid(); + + assert_args!( + &opts, + "IDLE", + "50", + "TIME", + "500", + "RETRYCOUNT", + "3", + "FORCE", + "JUSTID" + ); + + // test maxlen options + + assert_args!(StreamMaxlen::Approx(10), "MAXLEN", "~", "10"); + assert_args!(StreamMaxlen::Equals(10), "MAXLEN", "=", "10"); + + // test read options + + let opts = StreamReadOptions::default() + .noack() + .block(100) + .count(200) + .group("group-name", "consumer-name"); + + assert_args!( + &opts, + "GROUP", + "group-name", + "consumer-name", + "BLOCK", + "100", + "COUNT", + "200", + "NOACK" + ); + + // should skip noack because of missing group(,) + let opts = StreamReadOptions::default().noack().block(100).count(200); + + assert_args!(&opts, "BLOCK", "100", "COUNT", "200"); +} + +#[test] +fn test_assorted_1() { + // Tests the following commands.... + // xadd + // xadd_map (skip this for now) + // xadd_maxlen + // xread + // xlen + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // smoke test that we get the same id back + let result: RedisResult = con.xadd("k0", "1000-0", &[("x", "y")]); + assert_eq!(result.unwrap(), "1000-0"); + + // xread reply + let reply: StreamReadReply = con.xread(&["k1", "k2", "k3"], &["0", "0", "0"]).unwrap(); + + // verify reply contains 2 keys even though we asked for 3 + assert_eq!(&reply.keys.len(), &2usize); + + // verify first key & first id exist + assert_eq!(&reply.keys[0].key, "k1"); + assert_eq!(&reply.keys[0].ids.len(), &2usize); + assert_eq!(&reply.keys[0].ids[0].id, "1000-0"); + + // lookup the key in StreamId map + let hello: Option = reply.keys[0].ids[0].get("hello"); + assert_eq!(hello, Some("world".to_string())); + + // verify the second key was written + assert_eq!(&reply.keys[1].key, "k2"); + assert_eq!(&reply.keys[1].ids.len(), &2usize); + assert_eq!(&reply.keys[1].ids[0].id, "2000-0"); + + // test xadd_map + let mut map: BTreeMap<&str, &str> = BTreeMap::new(); + map.insert("ab", "cd"); + map.insert("ef", "gh"); + map.insert("ij", "kl"); + let _: RedisResult = con.xadd_map("k3", "3000-0", map); + + let reply: StreamRangeReply = con.xrange_all("k3").unwrap(); + assert!(reply.ids[0].contains_key("ab")); + assert!(reply.ids[0].contains_key("ef")); + assert!(reply.ids[0].contains_key("ij")); + + // test xadd w/ maxlength below... + + // add 100 things to k4 + xadd_keyrange(&mut con, "k4", 0, 100); + + // test xlen.. should have 100 items + let result: RedisResult = con.xlen("k4"); + assert_eq!(result, Ok(100)); + + // test xadd_maxlen + let _: RedisResult = + con.xadd_maxlen("k4", StreamMaxlen::Equals(10), "*", &[("h", "w")]); + let result: RedisResult = con.xlen("k4"); + assert_eq!(result, Ok(10)); +} + +#[test] +fn test_xgroup_create() { + // Tests the following commands.... + // xadd + // xinfo_stream + // xgroup_create + // xinfo_groups + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // no key exists... this call breaks the connection pipe for some reason + let reply: RedisResult = con.xinfo_stream("k10"); + assert!(reply.is_err()); + + // redo the connection because the above error + con = ctx.connection(); + + // key should exist + let reply: StreamInfoStreamReply = con.xinfo_stream("k1").unwrap(); + assert_eq!(&reply.first_entry.id, "1000-0"); + assert_eq!(&reply.last_entry.id, "1000-1"); + assert_eq!(&reply.last_generated_id, "1000-1"); + + // xgroup create (existing stream) + let result: RedisResult = con.xgroup_create("k1", "g1", "$"); + assert!(result.is_ok()); + + // xinfo groups (existing stream) + let result: RedisResult = con.xinfo_groups("k1"); + assert!(result.is_ok()); + let reply = result.unwrap(); + assert_eq!(&reply.groups.len(), &1); + assert_eq!(&reply.groups[0].name, &"g1"); +} + +#[test] +fn test_assorted_2() { + // Tests the following commands.... + // xadd + // xinfo_stream + // xinfo_groups + // xinfo_consumer + // xgroup_create_mkstream + // xread_options + // xack + // xpending + // xpending_count + // xpending_consumer_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // test xgroup create w/ mkstream @ 0 + let result: RedisResult = con.xgroup_create_mkstream("k99", "g99", "0"); + assert!(result.is_ok()); + + // Since nothing exists on this stream yet, + // it should have the defaults returned by the client + let result: RedisResult = con.xinfo_groups("k99"); + assert!(result.is_ok()); + let reply = result.unwrap(); + assert_eq!(&reply.groups.len(), &1); + assert_eq!(&reply.groups[0].name, &"g99"); + assert_eq!(&reply.groups[0].last_delivered_id, &"0-0"); + + // call xadd on k99 just so we can read from it + // using consumer g99 and test xinfo_consumers + let _: RedisResult = con.xadd("k99", "1000-0", &[("a", "b"), ("c", "d")]); + let _: RedisResult = con.xadd("k99", "1000-1", &[("e", "f"), ("g", "h")]); + + // test empty PEL + let empty_reply: StreamPendingReply = con.xpending("k99", "g99").unwrap(); + + assert_eq!(empty_reply.count(), 0); + if let StreamPendingReply::Empty = empty_reply { + // looks good + } else { + panic!("Expected StreamPendingReply::Empty but got Data"); + } + + // passing options w/ group triggers XREADGROUP + // using ID=">" means all undelivered ids + // otherwise, ID="0 | ms-num" means all pending already + // sent to this client + let reply: StreamReadReply = con + .xread_options( + &["k99"], + &[">"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + assert_eq!(reply.keys[0].ids.len(), 2); + + // read xinfo consumers again, should have 2 messages for the c99 consumer + let reply: StreamInfoConsumersReply = con.xinfo_consumers("k99", "g99").unwrap(); + assert_eq!(reply.consumers[0].pending, 2); + + // ack one of these messages + let result: RedisResult = con.xack("k99", "g99", &["1000-0"]); + assert_eq!(result, Ok(1)); + + // get pending messages already seen by this client + // we should only have one now.. + let reply: StreamReadReply = con + .xread_options( + &["k99"], + &["0"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + assert_eq!(reply.keys.len(), 1); + + // we should also have one pending here... + let reply: StreamInfoConsumersReply = con.xinfo_consumers("k99", "g99").unwrap(); + assert_eq!(reply.consumers[0].pending, 1); + + // add more and read so we can test xpending + let _: RedisResult = con.xadd("k99", "1001-0", &[("i", "j"), ("k", "l")]); + let _: RedisResult = con.xadd("k99", "1001-1", &[("m", "n"), ("o", "p")]); + let _: StreamReadReply = con + .xread_options( + &["k99"], + &[">"], + &StreamReadOptions::default().group("g99", "c99"), + ) + .unwrap(); + + // call xpending here... + // this has a different reply from what the count variations return + let data_reply: StreamPendingReply = con.xpending("k99", "g99").unwrap(); + + assert_eq!(data_reply.count(), 3); + + if let StreamPendingReply::Data(data) = data_reply { + assert_stream_pending_data(data) + } else { + panic!("Expected StreamPendingReply::Data but got Empty"); + } + + // both count variations have the same reply types + let reply: StreamPendingCountReply = con.xpending_count("k99", "g99", "-", "+", 10).unwrap(); + assert_eq!(reply.ids.len(), 3); + + let reply: StreamPendingCountReply = con + .xpending_consumer_count("k99", "g99", "-", "+", 10, "c99") + .unwrap(); + assert_eq!(reply.ids.len(), 3); + + for StreamPendingId { + id, + consumer, + times_delivered, + last_delivered_ms: _, + } in reply.ids + { + assert!(!id.is_empty()); + assert!(!consumer.is_empty()); + assert!(times_delivered > 0); + } +} + +fn assert_stream_pending_data(data: StreamPendingData) { + assert_eq!(data.start_id, "1000-1"); + assert_eq!(data.end_id, "1001-1"); + assert_eq!(data.consumers.len(), 1); + assert_eq!(data.consumers[0].name, "c99"); +} + +#[test] +fn test_xadd_maxlen_map() { + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + for i in 0..10 { + let mut map: BTreeMap<&str, &str> = BTreeMap::new(); + let idx = i.to_string(); + map.insert("idx", &idx); + let _: RedisResult = + con.xadd_maxlen_map("maxlen_map", StreamMaxlen::Equals(3), "*", map); + } + + let result: RedisResult = con.xlen("maxlen_map"); + assert_eq!(result, Ok(3)); + let reply: StreamRangeReply = con.xrange_all("maxlen_map").unwrap(); + + assert_eq!(reply.ids[0].get("idx"), Some("7".to_string())); + assert_eq!(reply.ids[1].get("idx"), Some("8".to_string())); + assert_eq!(reply.ids[2].get("idx"), Some("9".to_string())); +} + +#[test] +fn test_xread_options_deleted_pel_entry() { + // Test xread_options behaviour with deleted entry + let ctx = TestContext::new(); + let mut con = ctx.connection(); + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "$"); + assert!(result.is_ok()); + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h1", "w1")]); + // read the pending items for this key & group + let result: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + + let _: RedisResult = + con.xadd_maxlen("k1", StreamMaxlen::Equals(1), "*", &[("h2", "w2")]); + let result_deleted_entry: StreamReadReply = con + .xread_options( + &["k1"], + &["0"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + assert_eq!( + result.keys[0].ids.len(), + result_deleted_entry.keys[0].ids.len() + ); + assert_eq!( + result.keys[0].ids[0].id, + result_deleted_entry.keys[0].ids[0].id + ); +} +#[test] +fn test_xclaim() { + // Tests the following commands.... + // xclaim + // xclaim_options + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // xclaim test basic idea: + // 1. we need to test adding messages to a group + // 2. then xreadgroup needs to define a consumer and read pending + // messages without acking them + // 3. then we need to sleep 5ms and call xpending + // 4. from here we should be able to claim message + // past the idle time and read them from a different consumer + + // create the group + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "$"); + assert!(result.is_ok()); + + // add some keys + xadd_keyrange(&mut con, "k1", 0, 10); + + // read the pending items for this key & group + let reply: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + // verify we have 10 ids + assert_eq!(reply.keys[0].ids.len(), 10); + + // save this StreamId for later + let claim = &reply.keys[0].ids[0]; + let _claim_1 = &reply.keys[0].ids[1]; + let claim_justids = &reply.keys[0] + .ids + .iter() + .map(|msg| &msg.id) + .collect::>(); + + // sleep for 5ms + sleep(Duration::from_millis(5)); + + // grab this id if > 4ms + let reply: StreamClaimReply = con + .xclaim("k1", "g1", "c2", 4, &[claim.id.clone()]) + .unwrap(); + assert_eq!(reply.ids.len(), 1); + assert_eq!(reply.ids[0].id, claim.id); + + // grab all pending ids for this key... + // we should 9 in c1 and 1 in c2 + let reply: StreamPendingReply = con.xpending("k1", "g1").unwrap(); + if let StreamPendingReply::Data(data) = reply { + assert_eq!(data.consumers[0].name, "c1"); + assert_eq!(data.consumers[0].pending, 9); + assert_eq!(data.consumers[1].name, "c2"); + assert_eq!(data.consumers[1].pending, 1); + } + + // sleep for 5ms + sleep(Duration::from_millis(5)); + + // lets test some of the xclaim_options + // call force on the same claim.id + let _: StreamClaimReply = con + .xclaim_options( + "k1", + "g1", + "c3", + 4, + &[claim.id.clone()], + StreamClaimOptions::default().with_force(), + ) + .unwrap(); + + let reply: StreamPendingReply = con.xpending("k1", "g1").unwrap(); + // we should have 9 w/ c1 and 1 w/ c3 now + if let StreamPendingReply::Data(data) = reply { + assert_eq!(data.consumers[1].name, "c3"); + assert_eq!(data.consumers[1].pending, 1); + } + + // sleep for 5ms + sleep(Duration::from_millis(5)); + + // claim and only return JUSTID + let claimed: Vec = con + .xclaim_options( + "k1", + "g1", + "c5", + 4, + claim_justids, + StreamClaimOptions::default().with_force().with_justid(), + ) + .unwrap(); + // we just claimed the original 10 ids + // and only returned the ids + assert_eq!(claimed.len(), 10); +} + +#[test] +fn test_xdel() { + // Tests the following commands.... + // xdel + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // add some keys + xadd(&mut con); + + // delete the first stream item for this key + let result: RedisResult = con.xdel("k1", &["1000-0"]); + // returns the number of items deleted + assert_eq!(result, Ok(1)); + + let result: RedisResult = con.xdel("k2", &["2000-0", "2000-1", "2000-2"]); + // should equal 2 since the last id doesn't exist + assert_eq!(result, Ok(2)); +} + +#[test] +fn test_xtrim() { + // Tests the following commands.... + // xtrim + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // add some keys + xadd_keyrange(&mut con, "k1", 0, 100); + + // trim key to 50 + // returns the number of items remaining in the stream + let result: RedisResult = con.xtrim("k1", StreamMaxlen::Equals(50)); + assert_eq!(result, Ok(50)); + // we should end up with 40 after this call + let result: RedisResult = con.xtrim("k1", StreamMaxlen::Equals(10)); + assert_eq!(result, Ok(40)); +} + +#[test] +fn test_xgroup() { + // Tests the following commands.... + // xgroup_create_mkstream + // xgroup_destroy + // xgroup_delconsumer + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + // test xgroup create w/ mkstream @ 0 + let result: RedisResult = con.xgroup_create_mkstream("k1", "g1", "0"); + assert!(result.is_ok()); + + // destroy this new stream group + let result: RedisResult = con.xgroup_destroy("k1", "g1"); + assert_eq!(result, Ok(1)); + + // add some keys + xadd(&mut con); + + // create the group again using an existing stream + let result: RedisResult = con.xgroup_create("k1", "g1", "0"); + assert!(result.is_ok()); + + // read from the group so we can register the consumer + let reply: StreamReadReply = con + .xread_options( + &["k1"], + &[">"], + &StreamReadOptions::default().group("g1", "c1"), + ) + .unwrap(); + assert_eq!(reply.keys[0].ids.len(), 2); + + let result: RedisResult = con.xgroup_delconsumer("k1", "g1", "c1"); + // returns the number of pending message this client had open + assert_eq!(result, Ok(2)); + + let result: RedisResult = con.xgroup_destroy("k1", "g1"); + assert_eq!(result, Ok(1)); +} + +#[test] +fn test_xrange() { + // Tests the following commands.... + // xrange (-/+ variations) + // xrange_all + // xrange_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // xrange replies + let reply: StreamRangeReply = con.xrange_all("k1").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrange("k1", "1000-1", "+").unwrap(); + assert_eq!(reply.ids.len(), 1); + + let reply: StreamRangeReply = con.xrange("k1", "-", "1000-0").unwrap(); + assert_eq!(reply.ids.len(), 1); + + let reply: StreamRangeReply = con.xrange_count("k1", "-", "+", 1).unwrap(); + assert_eq!(reply.ids.len(), 1); +} + +#[test] +fn test_xrevrange() { + // Tests the following commands.... + // xrevrange (+/- variations) + // xrevrange_all + // xrevrange_count + + let ctx = TestContext::new(); + let mut con = ctx.connection(); + + xadd(&mut con); + + // xrange replies + let reply: StreamRangeReply = con.xrevrange_all("k1").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrevrange("k1", "1000-1", "-").unwrap(); + assert_eq!(reply.ids.len(), 2); + + let reply: StreamRangeReply = con.xrevrange("k1", "+", "1000-1").unwrap(); + assert_eq!(reply.ids.len(), 1); + + let reply: StreamRangeReply = con.xrevrange_count("k1", "+", "-", 1).unwrap(); + assert_eq!(reply.ids.len(), 1); +} diff --git a/glide-core/redis-rs/redis/tests/test_types.rs b/glide-core/redis-rs/redis/tests/test_types.rs new file mode 100644 index 0000000000..d5df513efb --- /dev/null +++ b/glide-core/redis-rs/redis/tests/test_types.rs @@ -0,0 +1,606 @@ +mod support; + +#[cfg(test)] +mod types { + use redis::{FromRedisValue, ToRedisArgs, Value}; + #[test] + fn test_is_single_arg() { + let sslice: &[_] = &["foo"][..]; + let nestslice: &[_] = &[sslice][..]; + let nestvec = vec![nestslice]; + let bytes = b"Hello World!"; + let twobytesslice: &[_] = &[bytes, bytes][..]; + let twobytesvec = vec![bytes, bytes]; + + assert!("foo".is_single_arg()); + assert!(sslice.is_single_arg()); + assert!(nestslice.is_single_arg()); + assert!(nestvec.is_single_arg()); + assert!(bytes.is_single_arg()); + + assert!(!twobytesslice.is_single_arg()); + assert!(!twobytesvec.is_single_arg()); + } + + /// The `FromRedisValue` trait provides two methods for parsing: + /// - `fn from_redis_value(&Value) -> Result` + /// - `fn from_owned_redis_value(Value) -> Result` + /// The `RedisParseMode` below allows choosing between the two + /// so that test logic does not need to be duplicated for each. + enum RedisParseMode { + Owned, + Ref, + } + + impl RedisParseMode { + /// Calls either `FromRedisValue::from_owned_redis_value` or + /// `FromRedisValue::from_redis_value`. + fn parse_redis_value( + &self, + value: redis::Value, + ) -> Result { + match self { + Self::Owned => redis::FromRedisValue::from_owned_redis_value(value), + Self::Ref => redis::FromRedisValue::from_redis_value(&value), + } + } + } + + #[test] + fn test_info_dict() { + use redis::{InfoDict, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let d: InfoDict = parse_mode + .parse_redis_value(Value::SimpleString( + "# this is a comment\nkey1:foo\nkey2:42\n".into(), + )) + .unwrap(); + + assert_eq!(d.get("key1"), Some("foo".to_string())); + assert_eq!(d.get("key2"), Some(42i64)); + assert_eq!(d.get::("key3"), None); + } + } + + #[test] + fn test_i32() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let i = parse_mode.parse_redis_value(Value::SimpleString("42".into())); + assert_eq!(i, Ok(42i32)); + + let i = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(i, Ok(42i32)); + + let i = parse_mode.parse_redis_value(Value::BulkString("42".into())); + assert_eq!(i, Ok(42i32)); + + let bad_i: Result = + parse_mode.parse_redis_value(Value::SimpleString("42x".into())); + assert_eq!(bad_i.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_u32() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let i = parse_mode.parse_redis_value(Value::SimpleString("42".into())); + assert_eq!(i, Ok(42u32)); + + let bad_i: Result = + parse_mode.parse_redis_value(Value::SimpleString("-1".into())); + assert_eq!(bad_i.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(vec![1i32, 2, 3])); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(content_vec)); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(vec![b'1'])); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(vec![1_u16])); + } + } + + #[test] + fn test_box_slice() { + use redis::{FromRedisValue, Value}; + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(vec![1i32, 2, 3].into_boxed_slice())); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(content_vec.into_boxed_slice())); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(vec![b'1'].into_boxed_slice())); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(vec![1_u16].into_boxed_slice())); + + assert_eq!( + Box::<[i32]>::from_redis_value( + &Value::BulkString("just a string".into()) + ).unwrap_err().to_string(), + "Response was of incompatible type - TypeError: \"Conversion to alloc::boxed::Box<[i32]> failed.\" (response was bulk-string('\"just a string\"'))", + ); + } + } + + #[test] + fn test_arc_slice() { + use redis::{FromRedisValue, Value}; + use std::sync::Arc; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])); + assert_eq!(v, Ok(Arc::from(vec![1i32, 2, 3]))); + + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(Arc::from(content_vec))); + + let content: &[u8] = b"1"; + let content_vec: Vec = Vec::from(content); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec.clone())); + assert_eq!(v, Ok(Arc::from(vec![b'1']))); + let v = parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(Arc::from(vec![1_u16]))); + + assert_eq!( + Arc::<[i32]>::from_redis_value( + &Value::BulkString("just a string".into()) + ).unwrap_err().to_string(), + "Response was of incompatible type - TypeError: \"Conversion to alloc::sync::Arc<[i32]> failed.\" (response was bulk-string('\"just a string\"'))", + ); + } + } + + #[test] + fn test_single_bool_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + + assert_eq!(v, Ok(vec![true])); + } + } + + #[test] + fn test_single_i32_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + + assert_eq!(v, Ok(vec![1i32])); + } + } + + #[test] + fn test_single_u32_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("42".into())); + + assert_eq!(v, Ok(vec![42u32])); + } + } + + #[test] + fn test_single_string_vec() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + assert_eq!(v, Ok(vec!["1".to_string()])); + } + } + + #[test] + fn test_tuple() { + use redis::Value; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::Array(vec![Value::Array(vec![ + Value::BulkString("1".into()), + Value::BulkString("2".into()), + Value::BulkString("3".into()), + ])])); + + assert_eq!(v, Ok(((1i32, 2, 3,),))); + } + } + + #[test] + fn test_hashmap() { + use fnv::FnvHasher; + use redis::{ErrorKind, Value}; + use std::collections::HashMap; + use std::hash::BuildHasherDefault; + + type Hm = HashMap; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v: Result = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("a".into()), + Value::BulkString("1".into()), + Value::BulkString("b".into()), + Value::BulkString("2".into()), + Value::BulkString("c".into()), + Value::BulkString("3".into()), + ])); + let mut e: Hm = HashMap::new(); + e.insert("a".into(), 1); + e.insert("b".into(), 2); + e.insert("c".into(), 3); + assert_eq!(v, Ok(e)); + + type Hasher = BuildHasherDefault; + type HmHasher = HashMap; + let v: Result = parse_mode.parse_redis_value(Value::Array(vec![ + Value::BulkString("a".into()), + Value::BulkString("1".into()), + Value::BulkString("b".into()), + Value::BulkString("2".into()), + Value::BulkString("c".into()), + Value::BulkString("3".into()), + ])); + + let fnv = Hasher::default(); + let mut e: HmHasher = HashMap::with_hasher(fnv); + e.insert("a".into(), 1); + e.insert("b".into(), 2); + e.insert("c".into(), 3); + assert_eq!(v, Ok(e)); + + let v: Result = + parse_mode.parse_redis_value(Value::Array(vec![Value::BulkString("a".into())])); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_bool() { + use redis::{ErrorKind, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let v = parse_mode.parse_redis_value(Value::BulkString("1".into())); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::BulkString("0".into())); + assert_eq!(v, Ok(false)); + + let v: Result = + parse_mode.parse_redis_value(Value::BulkString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v = parse_mode.parse_redis_value(Value::SimpleString("1".into())); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::SimpleString("0".into())); + assert_eq!(v, Ok(false)); + + let v: Result = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v, Ok(true)); + + let v = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v, Ok(false)); + + let v = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v, Ok(false)); + + let v = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v, Ok(true)); + } + } + + #[cfg(feature = "bytes")] + #[test] + fn test_bytes() { + use bytes::Bytes; + use redis::{ErrorKind, RedisResult, Value}; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + let content_bytes = Bytes::from_static(content); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(content_bytes)); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[cfg(feature = "uuid")] + #[test] + fn test_uuid() { + use std::str::FromStr; + + use redis::{ErrorKind, FromRedisValue, RedisResult, Value}; + use uuid::Uuid; + + let uuid = Uuid::from_str("abab64b7-e265-4052-a41b-23e1e28674bf").unwrap(); + let bytes = uuid.as_bytes().to_vec(); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::BulkString(bytes)); + assert_eq!(v, Ok(uuid)); + + let v: RedisResult = + FromRedisValue::from_redis_value(&Value::SimpleString("garbage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Okay); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = FromRedisValue::from_redis_value(&Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + + #[test] + fn test_cstring() { + use redis::{ErrorKind, RedisResult, Value}; + use std::ffi::CString; + + for parse_mode in [RedisParseMode::Owned, RedisParseMode::Ref] { + let content: &[u8] = b"\x01\x02\x03\x04"; + let content_vec: Vec = Vec::from(content); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::BulkString(content_vec)); + assert_eq!(v, Ok(CString::new(content).unwrap())); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("garbage".into())); + assert_eq!(v, Ok(CString::new("garbage").unwrap())); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Okay); + assert_eq!(v, Ok(CString::new("OK").unwrap())); + + let v: RedisResult = + parse_mode.parse_redis_value(Value::SimpleString("gar\0bage".into())); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Nil); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(0)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + + let v: RedisResult = parse_mode.parse_redis_value(Value::Int(42)); + assert_eq!(v.unwrap_err().kind(), ErrorKind::TypeError); + } + } + + #[test] + fn test_types_to_redis_args() { + use redis::ToRedisArgs; + use std::collections::BTreeMap; + use std::collections::BTreeSet; + use std::collections::HashMap; + use std::collections::HashSet; + + assert!(!5i32.to_redis_args().is_empty()); + assert!(!"abc".to_redis_args().is_empty()); + assert!(!"abc".to_redis_args().is_empty()); + assert!(!String::from("x").to_redis_args().is_empty()); + + assert!(![5, 4] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + assert!(![5, 4] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + // this can be used on something HMSET + assert!(![("a", 5), ("b", 6), ("C", 7)] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + + // this can also be used on something HMSET + assert!(![("d", 8), ("e", 9), ("f", 10)] + .iter() + .cloned() + .collect::>() + .to_redis_args() + .is_empty()); + } + + #[test] + fn test_large_usize_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array = [0; 1000]; + for (i, item) in array.iter_mut().enumerate() { + *item = i; + } + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [usize; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_large_u8_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array: [u8; 1000] = [0; 1000]; + for (i, item) in array.iter_mut().enumerate() { + *item = (i % 256) as u8; + } + + let vec = (&array).to_redis_args(); + assert_eq!(vec.len(), 1); + assert_eq!(array.len(), vec[0].len()); + + let value = Value::Array(vec[0].iter().map(|val| Value::Int(*val as i64)).collect()); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [u8; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_large_string_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let mut array: [String; 1000] = [(); 1000].map(|_| String::new()); + for (i, item) in array.iter_mut().enumerate() { + *item = format!("{i}"); + } + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [String; 1000] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_0_length_usize_array_to_redis_args_and_back() { + use crate::support::encode_value; + use redis::ToRedisArgs; + + let array: [usize; 0] = [0; 0]; + + let vec = (&array).to_redis_args(); + assert_eq!(array.len(), vec.len()); + + let value = Value::Array( + vec.iter() + .map(|val| Value::BulkString(val.clone())) + .collect(), + ); + let mut encoded_input = Vec::new(); + encode_value(&value, &mut encoded_input).unwrap(); + + let new_array: [usize; 0] = FromRedisValue::from_redis_value(&value).unwrap(); + assert_eq!(new_array, array); + + let new_array: [usize; 0] = FromRedisValue::from_redis_value(&Value::Nil).unwrap(); + assert_eq!(new_array, array); + } + + #[test] + fn test_attributes() { + use redis::{parse_redis_value, FromRedisValue, Value}; + let bytes: &[u8] = b"*3\r\n:1\r\n:2\r\n|1\r\n+ttl\r\n:3600\r\n:3\r\n"; + let val = parse_redis_value(bytes).unwrap(); + { + // The case user doesn't expect attributes from server + let x: Vec = redis::FromRedisValue::from_redis_value(&val).unwrap(); + assert_eq!(x, vec![1, 2, 3]); + } + { + // The case user wants raw value from server + let x: Value = FromRedisValue::from_redis_value(&val).unwrap(); + assert_eq!( + x, + Value::Array(vec![ + Value::Int(1), + Value::Int(2), + Value::Attribute { + data: Box::new(Value::Int(3)), + attributes: vec![( + Value::SimpleString("ttl".to_string()), + Value::Int(3600) + )] + } + ]) + ) + } + } +} diff --git a/glide-core/redis-rs/release.sh b/glide-core/redis-rs/release.sh new file mode 100755 index 0000000000..f01241c382 --- /dev/null +++ b/glide-core/redis-rs/release.sh @@ -0,0 +1,15 @@ +#!/bin/sh +set -ex + +LEVEL=$1 +if [ -z "$LEVEL" ]; then + echo "Expected patch, minor or major" + exit 1 +fi + +clog --$LEVEL + +git add CHANGELOG.md +git commit -m "Update changelog" + +cargo release --execute $LEVEL diff --git a/glide-core/redis-rs/rustfmt.toml b/glide-core/redis-rs/rustfmt.toml new file mode 100644 index 0000000000..0d564415cb --- /dev/null +++ b/glide-core/redis-rs/rustfmt.toml @@ -0,0 +1,2 @@ +use_try_shorthand = true +edition = "2018" diff --git a/glide-core/redis-rs/scripts/get_command_info.py b/glide-core/redis-rs/scripts/get_command_info.py new file mode 100644 index 0000000000..dcba666bff --- /dev/null +++ b/glide-core/redis-rs/scripts/get_command_info.py @@ -0,0 +1,227 @@ +import argparse +import json +import os +from os.path import join + +"""Valkey command categorizer + +This script analyzes command info json files and categorizes the commands based on their routing. The output can be used +to map commands in the cluster_routing.rs#base_routing function to their RouteBy category. Commands that cannot be +categorized by the script will be listed under the "Uncategorized" section. These commands will need to be manually +categorized. + +To use the script: +1. Clone https://github.com/valkey-io/valkey +2. cd into the cloned valkey repository and checkout the desired version of the code, eg 7.2.5 +3. cd into the directory containing this script +4. run: + python get_command_info.py --commands-dir=/valkey/src/commands +""" + + +class CommandCategory: + def __init__(self, name, description): + self.name = name + self.description = description + self.commands = [] + + def add_command(self, command_name): + self.commands.append(command_name) + + +def main(): + parser = argparse.ArgumentParser( + description="Analyzes command info json and categorizes commands into their RouteBy categories") + parser.add_argument( + "--commands-dir", + type=str, + help="Path to the directory containing the command info json files (example: ../../valkey/src/commands)", + required=True, + ) + + args = parser.parse_args() + commands_dir = args.commands_dir + if not os.path.exists(commands_dir): + raise parser.error("The command info directory passed to the '--commands-dir' argument does not exist") + + all_nodes = CommandCategory("AllNodes", "Commands with an ALL_NODES request policy") + all_primaries = CommandCategory("AllPrimaries", "Commands with an ALL_SHARDS request policy") + multi_shard = CommandCategory("MultiShardNoValues or MultiShardWithValues", + "Commands with a MULTI_SHARD request policy") + first_arg = CommandCategory("FirstKey", "Commands with their first key argument at position 1") + second_arg = CommandCategory("SecondArg", "Commands with their first key argument at position 2") + second_arg_numkeys = ( + CommandCategory("SecondArgAfterKeyCount", + "Commands with their first key argument at position 2, after a numkeys argument")) + # all commands with their first key argument at position 3 have a numkeys argument at position 2, + # so there is a ThirdArgAfterKeyCount category but no ThirdArg category + third_arg_numkeys = ( + CommandCategory("ThirdArgAfterKeyCount", + "Commands with their first key argument at position 3, after a numkeys argument")) + streams_index = CommandCategory("StreamsIndex", "Commands that include a STREAMS token") + second_arg_slot = CommandCategory("SecondArgSlot", "Commands with a slot argument at position 2") + uncategorized = ( + CommandCategory( + "Uncategorized", + "Commands that don't fall into the other categories. These commands will have to be manually categorized.")) + + categories = [all_nodes, all_primaries, multi_shard, first_arg, second_arg, second_arg_numkeys, third_arg_numkeys, + streams_index, second_arg_slot, uncategorized] + + print("Gathering command info...\n") + + for filename in os.listdir(commands_dir): + file_path = join(commands_dir, filename) + _, file_extension = os.path.splitext(file_path) + if file_extension != ".json": + print(f"Note: {filename} is not a json file and will thus be ignored") + continue + + file = open(file_path) + command_json = json.load(file) + if len(command_json) == 0: + raise Exception( + f"The json for {filename} was empty. A json object with information about the command was expected.") + + command_name = next(iter(command_json)) + command_info = command_json[command_name] + if "container" in command_info: + # for two-word commands like 'XINFO GROUPS', the `next(iter(command_json))` statement above returns 'GROUPS' + # and `command_info['container']` returns 'XINFO' + command_name = f"{command_info['container']} {command_name}" + + if "command_tips" in command_info: + request_policy = get_request_policy(command_info["command_tips"]) + if request_policy == "ALL_NODES": + all_nodes.add_command(command_name) + continue + elif request_policy == "ALL_SHARDS": + all_primaries.add_command(command_name) + continue + elif request_policy == "MULTI_SHARD": + multi_shard.add_command(command_name) + continue + + if "arguments" not in command_info: + uncategorized.add_command(command_name) + continue + + command_args = command_info["arguments"] + split_name = command_name.split() + if len(split_name) == 0: + raise Exception(f"Encountered json with an empty command name in file '{filename}'") + + json_key_index, is_key_optional = get_first_key_info(command_args) + # cluster_routing.rs can handle optional keys if a keycount of 0 is provided, otherwise the command should + # fall under the "Uncategorized" section to indicate it will need to be manually inspected + if is_key_optional and not is_after_numkeys(command_args, json_key_index): + uncategorized.add_command(command_name) + continue + + if json_key_index == -1: + # the command does not have a key argument, check for a slot argument + json_slot_index, is_slot_optional = get_first_slot_info(command_args) + if is_slot_optional: + uncategorized.add_command(command_name) + continue + + # cluster_routing.rs considers each word in the command name to be an argument, but the json does not + cluster_routing_slot_index = -1 if json_slot_index == -1 else len(split_name) + json_slot_index + if cluster_routing_slot_index == 2: + second_arg_slot.add_command(command_name) + continue + + # the command does not have a slot argument, check for a "STREAMS" token + if has_streams_token(command_args): + streams_index.add_command(command_name) + continue + + uncategorized.add_command(command_name) + continue + + # cluster_routing.rs considers each word in the command name to be an argument, but the json does not + cluster_routing_key_index = -1 if json_key_index == -1 else len(split_name) + json_key_index + if cluster_routing_key_index == 1: + first_arg.add_command(command_name) + continue + elif cluster_routing_key_index == 2: + if is_after_numkeys(command_args, json_key_index): + second_arg_numkeys.add_command(command_name) + continue + else: + second_arg.add_command(command_name) + continue + # there aren't any commands that fall into a ThirdArg category, + # but there are commands that fall under ThirdArgAfterKeyCount category + elif cluster_routing_key_index == 3 and is_after_numkeys(command_args, json_key_index): + third_arg_numkeys.add_command(command_name) + continue + + uncategorized.add_command(command_name) + + print("\nNote: the following information considers each word in the command name to be an argument") + print("For example, for 'XGROUP DESTROY key group':") + print("'XGROUP' is arg0, 'DESTROY' is arg1, 'key' is arg2, and 'group' is arg3.\n") + + for category in categories: + print_category(category) + + +def get_request_policy(command_tips): + for command_tip in command_tips: + if command_tip.startswith("REQUEST_POLICY:"): + return command_tip[len("REQUEST_POLICY:"):] + + return None + + +def get_first_key_info(args_info_json) -> tuple[int, bool]: + for i in range(len(args_info_json)): + info = args_info_json[i] + if info["type"].lower() == "key": + is_optional = "optional" in info and info["optional"] + return i, is_optional + + return -1, False + + +def get_first_slot_info(args_info_json) -> tuple[int, bool]: + for i in range(len(args_info_json)): + info = args_info_json[i] + if info["name"].lower() == "slot": + is_optional = "optional" in info and info["optional"] + return i, is_optional + + return -1, False + + +def is_after_numkeys(args_info_json, json_index): + return json_index > 0 and args_info_json[json_index - 1]["name"].lower() == "numkeys" + + +def has_streams_token(args_info_json): + for arg_info in args_info_json: + if "token" in arg_info and arg_info["token"].upper() == "STREAMS": + return True + + return False + + +def print_category(category): + print("============================") + print(f"Category: {category.name} commands") + print(f"Description: {category.description}") + print("List of commands in this category:\n") + + if len(category.commands) == 0: + print("(No commands found for this category)") + else: + category.commands.sort() + for command_name in category.commands: + print(f"{command_name}") + + print("\n") + + +if __name__ == "__main__": + main() diff --git a/glide-core/redis-rs/scripts/update-versions.sh b/glide-core/redis-rs/scripts/update-versions.sh new file mode 100755 index 0000000000..f2800985f0 --- /dev/null +++ b/glide-core/redis-rs/scripts/update-versions.sh @@ -0,0 +1,20 @@ +#!/usr/bin/env bash +# This script is pretty low tech, but it helps keep the doc version numbers +# up to date. It should be run as a `pre-release-hook` from cargo-release. + +set -eo pipefail + +if [ -z "$PREV_VERSION" ] || [ -z "$NEW_VERSION" ]; then + echo "Missing PREV_VERSION or NEW_VERSION." + echo "This script needs to run as a 'pre-release-hook' from cargo-release." + exit 1 +fi + +for file in README.md; do + sed -i.bak -E \ + -e "s|version=[0-9.]+|version=${NEW_VERSION}|g" \ + -e "s|redis/[0-9.]+|redis/${NEW_VERSION}|g" \ + -e "s|redis = \"[0-9.]+\"|redis = \"${NEW_VERSION}\"|g" \ + "${CRATE_ROOT}/$file" + rm "${CRATE_ROOT}/$file.bak" +done diff --git a/glide-core/redis-rs/upload-docs.sh b/glide-core/redis-rs/upload-docs.sh new file mode 100755 index 0000000000..4f6d01cd0f --- /dev/null +++ b/glide-core/redis-rs/upload-docs.sh @@ -0,0 +1,26 @@ +#!/bin/bash + +# Make a new repo for the gh-pages branch +rm -rf .gh-pages +mkdir .gh-pages +cd .gh-pages +git init + +# Copy over the documentation +cp -r ../target/doc/* . +cat < index.html + +redis + +EOF + +# Add, commit and push files +git add -f --all . +git commit -m "Built documentation" +git checkout -b gh-pages +git remote add origin git@github.com:mitsuhiko/redis-rs.git +git push -qf origin gh-pages + +# Cleanup +cd .. +rm -rf .gh-pages diff --git a/go/Cargo.toml b/go/Cargo.toml index 62872578da..05d34e7108 100644 --- a/go/Cargo.toml +++ b/go/Cargo.toml @@ -9,7 +9,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tls", "tokio-native-tls-comp", "tls-rustls-insecure"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "tls", "tokio-native-tls-comp", "tls-rustls-insecure"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } protobuf = { version = "3.3.0", features = [] } diff --git a/go/DEVELOPER.md b/go/DEVELOPER.md index 023828a0cf..ad3ded8e57 100644 --- a/go/DEVELOPER.md +++ b/go/DEVELOPER.md @@ -105,7 +105,7 @@ Before starting this step, make sure you've installed all software requirements. git clone --branch ${VERSION} https://github.com/valkey-io/valkey-glide.git cd valkey-glide ``` -2. Initialize git submodule: +2. Initialize git submodules: ```bash git submodule update --init --recursive ``` @@ -163,7 +163,7 @@ go test -race ./... -run TestConnectionRequestProtobufGeneration_allFieldsSet -v After pulling new changes, ensure that you update the submodules by running the following command: ```bash -git submodule update +git submodule update --init --recursive ``` ### Generate protobuf files diff --git a/java/Cargo.toml b/java/Cargo.toml index 6428f67fa6..c8fa49fe3f 100644 --- a/java/Cargo.toml +++ b/java/Cargo.toml @@ -10,7 +10,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager", "tokio-rustls-comp"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager", "tokio-rustls-comp"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } tokio = { version = "^1", features = ["rt", "macros", "rt-multi-thread", "time"] } logger_core = {path = "../logger_core"} diff --git a/node/DEVELOPER.md b/node/DEVELOPER.md index f71966862e..a3391c3282 100644 --- a/node/DEVELOPER.md +++ b/node/DEVELOPER.md @@ -70,6 +70,7 @@ Before starting this step, make sure you've installed all software requirments. git submodule update --init --recursive ``` 3. Install all node dependencies: + ```bash cd node npm i @@ -77,6 +78,7 @@ Before starting this step, make sure you've installed all software requirments. npm i cd .. ``` + 4. Build the Node wrapper (Choose a build option from the following and run it from the `node` folder): 1. Build in release mode, stripped from all debug symbols (optimized and minimized binary size): diff --git a/node/rust-client/Cargo.toml b/node/rust-client/Cargo.toml index e9e2af8851..f9baaf6cc2 100644 --- a/node/rust-client/Cargo.toml +++ b/node/rust-client/Cargo.toml @@ -11,7 +11,7 @@ authors = ["Valkey GLIDE Maintainers"] crate-type = ["cdylib"] [dependencies] -redis = { path = "../../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp"] } +redis = { path = "../../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "tokio-rustls-comp"] } glide-core = { path = "../../glide-core", features = ["socket-layer"] } tokio = { version = "1", features = ["rt", "macros", "rt-multi-thread", "time"] } napi = {version = "2.14", features = ["napi4", "napi6"] } diff --git a/python/Cargo.toml b/python/Cargo.toml index 16632945bb..3945322cd2 100644 --- a/python/Cargo.toml +++ b/python/Cargo.toml @@ -13,7 +13,7 @@ crate-type = ["cdylib"] [dependencies] pyo3 = { version = "^0.20", features = ["extension-module", "num-bigint"] } bytes = { version = "1.6.0" } -redis = { path = "../submodules/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager","tokio-rustls-comp"] } +redis = { path = "../glide-core/redis-rs/redis", features = ["aio", "tokio-comp", "connection-manager","tokio-rustls-comp"] } glide-core = { path = "../glide-core", features = ["socket-layer"] } logger_core = {path = "../logger_core"} diff --git a/python/DEVELOPER.md b/python/DEVELOPER.md index abf12dc9a3..a3e5b07237 100644 --- a/python/DEVELOPER.md +++ b/python/DEVELOPER.md @@ -2,15 +2,13 @@ This document describes how to set up your development environment to build and test the Valkey GLIDE Python wrapper. -### Development Overview - The Valkey GLIDE Python wrapper consists of both Python and Rust code. Rust bindings for Python are implemented using [PyO3](https://github.com/PyO3/pyo3), and the Python package is built using [maturin](https://github.com/PyO3/maturin). The Python and Rust components communicate using the [protobuf](https://github.com/protocolbuffers/protobuf) protocol. -### Build from source +# Prerequisites +--- -#### Prerequisites +Before building the package from source, make sure that you have installed the listed dependencies below: -Software Dependencies - python3 virtualenv - git @@ -21,7 +19,10 @@ Software Dependencies - openssl-dev - rustup -**Dependencies installation for Ubuntu** +For your convenience, we wrapped the steps in a "copy-paste" code blocks for common operating systems: + +
+Ubuntu / Debian ```bash sudo apt update -y @@ -42,7 +43,10 @@ export PATH="$PATH:$HOME/.local/bin" protoc --version ``` -**Dependencies installation for CentOS** +
+ +
+CentOS ```bash sudo yum update -y @@ -62,7 +66,10 @@ export PATH="$PATH:$HOME/.local/bin" protoc --version ``` -**Dependencies installation for MacOS** +
+ +
+MacOS ```bash brew update @@ -80,112 +87,108 @@ source /Users/$USER/.bash_profile protoc --version ``` -#### Building and installation steps - -Before starting this step, make sure you've installed all software requirments. +
-1. Clone the repository: - ```bash - git clone https://github.com/valkey-io/valkey-glide.git - cd valkey-glide - ``` -2. Initialize git submodule: - ```bash - git submodule update --init --recursive - ``` -3. Generate protobuf files: - ```bash - GLIDE_ROOT_FOLDER_PATH=. - protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto - ``` -4. Create a virtual environment: - ```bash - cd python - python3 -m venv .env - ``` -5. Activate the virtual environment: - ```bash - source .env/bin/activate - ``` -6. Install requirements: - ```bash - pip install -r requirements.txt - ``` -7. Build the Python wrapper in release mode: - ``` - maturin develop --release --strip - ``` - > **Note:** To build the wrapper binary with debug symbols remove the --strip flag. -8. Run tests: - 1. Ensure that you have installed redis-server or valkey-server and redis-cli or valkey-cli on your host. You can find the Redis installation guide at the following link: [Redis Installation Guide](https://redis.io/docs/install/install-redis/install-redis-on-linux/). You can get Valkey from the following link: [Valkey Download](https://valkey.io/download/). - 2. Validate the activation of the virtual environment from step 4 by ensuring its name (`.env`) is displayed next to your command prompt. - 3. Execute the following command from the python folder: - ```bash - pytest --asyncio-mode=auto - ``` - > **Note:** To run Valkey modules tests, add -k "test_server_modules.py". - -- Install Python development requirements with: +# Building +--- - ```bash - pip install -r python/dev_requirements.txt - ``` +Before starting this step, make sure you've installed all software requirements. -- For a fast build, execute `maturin develop` without the release flag. This will perform an unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to include the "--release" flag when measuring performance. +## Prepare your environment -### Test +```bash +mkdir -p $HOME/src +cd $_ +git clone https://github.com/valkey-io/valkey-glide.git +cd valkey-glide +GLIDE_ROOT=$(pwd) +protoc -Iprotobuf=${GLIDE_ROOT}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT}/python/python/glide \ + ${GLIDE_ROOT}/glide-core/src/protobuf/*.proto +cd python +python3 -m venv .env +source .env/bin/activate +pip install -r requirements.txt +pip install -r python/dev_requirements.txt +``` -To run tests, use the following command: +## Build the package (in release mode): ```bash -pytest --asyncio-mode=auto +maturin develop --release --strip ``` + +> **Note:** to build the wrapper binary with debug symbols remove the `--strip` flag. + +> **Note 2:** for a faster build time, execute `maturin develop` without the release flag. This will perform an unoptimized build, which is suitable for developing tests. Keep in mind that performance is significantly affected in an unoptimized build, so it's required to include the `--release` flag when measuring performance. -To execute a specific test, include the `-k ` option. For example: +# Running tests +--- + +Ensure that you have installed `redis-server` or `valkey-server` along with `redis-cli` or `valkey-cli` on your host. You can find the Redis installation guide at the following link: [Redis Installation Guide](https://redis.io/docs/install/install-redis/install-redis-on-linux/). You can get Valkey from the following link: [Valkey Download](https://valkey.io/download/). + +From a terminal, change directory to the GLIDE source folder and type: ```bash -pytest --asyncio-mode=auto -k test_socket_set_and_get +cd $HOME/src/valkey-glide +cd python +source .env/bin/activate +pytest --asyncio-mode=auto ``` -IT suite starts the server for testing - standalone and cluster installation using `cluster_manager` script. -If you want IT to use already started servers, use the following command line from `python/python` dir: +To run modules tests: ```bash -pytest --asyncio-mode=auto --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 +cd $HOME/src/valkey-glide +cd python +source .env/bin/activate +pytest --asyncio-mode=auto -k "test_server_modules.py" ``` -### Submodules +**TIP:** to run a specific test, append `-k ` to the `pytest` execution line -After pulling new changes, ensure that you update the submodules by running the following command: +To run tests against an already running servers, change the `pytest` line above to this: ```bash -git submodule update +pytest --asyncio-mode=auto --cluster-endpoints=localhost:7000 --standalone-endpoints=localhost:6379 ``` -### Generate protobuf files +# Generate protobuf files +--- -During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made to the protobuf definition files (.proto files located in `glide-core/src/protofuf`), it becomes necessary to regenerate the Python protobuf files. To do so, run: +During the initial build, Python protobuf files were created in `python/python/glide/protobuf`. If modifications are made +to the protobuf definition files (`.proto` files located in `glide-core/src/protofuf`), it becomes necessary to +regenerate the Python protobuf files. To do so, run: ```bash -GLIDE_ROOT_FOLDER_PATH=. # e.g. /home/ubuntu/valkey-glide -protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto +cd $HOME/src/valkey-glide +GLIDE_ROOT_FOLDER_PATH=. +protoc -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto ``` -#### Protobuf interface files +## Protobuf interface files To generate the protobuf files with Python Interface files (pyi) for type-checking purposes, ensure you have installed `mypy-protobuf` with pip, and then execute the following command: ```bash -GLIDE_ROOT_FOLDER_PATH=. # e.g. /home/ubuntu/valkey-glide +cd $HOME/src/valkey-glide +GLIDE_ROOT_FOLDER_PATH=. MYPY_PROTOC_PATH=`which protoc-gen-mypy` -protoc --plugin=protoc-gen-mypy=${MYPY_PROTOC_PATH} -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide --mypy_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto +protoc --plugin=protoc-gen-mypy=${MYPY_PROTOC_PATH} \ + -Iprotobuf=${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/ \ + --python_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + --mypy_out=${GLIDE_ROOT_FOLDER_PATH}/python/python/glide \ + ${GLIDE_ROOT_FOLDER_PATH}/glide-core/src/protobuf/*.proto ``` -### Linters +# Linters +--- Development on the Python wrapper may involve changes in either the Python or Rust code. Each language has distinct linter tests that must be passed before committing changes. -#### Language-specific Linters +## Language-specific Linters **Python:** @@ -199,31 +202,37 @@ Development on the Python wrapper may involve changes in either the Python or Ru - clippy - fmt -#### Running the linters +## Running the linters Run from the main `/python` folder 1. Python - > Note: make sure to [generate protobuf with interface files]("#protobuf-interface-files") before running mypy linter + > Note: make sure to [generate protobuf with interface files]("#protobuf-interface-files") before running `mypy` linter ```bash + cd $HOME/src/valkey-glide/python + source .env/bin/activate pip install -r dev_requirements.txt isort . --profile black --skip-glob python/glide/protobuf --skip-glob .env black . --exclude python/glide/protobuf --exclude .env - flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 - flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 --statistics --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics \ + --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 + flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 \ + --statistics --exclude=python/glide/protobuf,.env/* \ + --extend-ignore=E230 # run type check mypy . ``` + 2. Rust ```bash rustup component add clippy rustfmt cargo clippy --all-features --all-targets -- -D warnings cargo fmt --manifest-path ./Cargo.toml --all - ``` -### Recommended extensions for VS Code +# Recommended extensions for VS Code +--- - [Python](https://marketplace.visualstudio.com/items?itemName=ms-python.python) - [isort](https://marketplace.visualstudio.com/items?itemName=ms-python.isort) diff --git a/python/python/tests/test_async_client.py b/python/python/tests/test_async_client.py index 7566194dcc..c9744157d6 100644 --- a/python/python/tests/test_async_client.py +++ b/python/python/tests/test_async_client.py @@ -1,4 +1,5 @@ # Copyright Valkey GLIDE Project Contributors - SPDX Identifier: Apache-2.0 +# mypy: disable_error_code="arg-type" from __future__ import annotations diff --git a/submodules/redis-rs b/submodules/redis-rs deleted file mode 160000 index 396536db31..0000000000 --- a/submodules/redis-rs +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 396536db31fbf2de0f272d8179d68286329fa70e From b7ff364f01175aa59077e4ca9e6d4151163771e2 Mon Sep 17 00:00:00 2001 From: Eran Ifrah Date: Mon, 14 Oct 2024 21:24:29 +0300 Subject: [PATCH 2/3] Makefile improvements - Separate the "test" target from the "lint" - Added new "help" target for listing all possible targets Signed-off-by: Eran Ifrah --- Makefile | 66 +++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 54 insertions(+), 12 deletions(-) diff --git a/Makefile b/Makefile index 438cd4d08d..e2172fd768 100644 --- a/Makefile +++ b/Makefile @@ -9,25 +9,49 @@ PYENV_DIR=$(shell pwd)/python/.env PY_PATH=$(shell find python/.env -name "site-packages"|xargs readlink -f) PY_GLIDE_PATH=$(shell pwd)/python/python/ -all: java java-test python python-test node node-test go go-test +all: java java-test python python-test node node-test go go-test python-lint java-lint +## +## Java targets +## java: @echo "$(GREEN)Building for Java (release)$(RESET)" @cd java && ./gradlew :client:buildAllRelease -java-test: check-redis-server +java-lint: @echo "$(GREEN)Running spotlessCheck$(RESET)" @cd java && ./gradlew :spotlessCheck @echo "$(GREEN)Running spotlessApply$(RESET)" @cd java && ./gradlew :spotlessApply + +java-test: check-redis-server @echo "$(GREEN)Running integration tests$(RESET)" @cd java && ./gradlew :integTest:test +## +## Python targets +## python: .build/python_deps @echo "$(GREEN)Building for Python (release)$(RESET)" @cd python && VIRTUAL_ENV=$(PYENV_DIR) .env/bin/maturin develop --release --strip -# Python dependencies +python-lint: .build/python_deps + @echo "$(GREEN)Building Linters for python$(RESET)" + cd python && \ + export VIRTUAL_ENV=$(PYENV_DIR); \ + export PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH); \ + export PATH=$(PYENV_DIR)/bin:$(PATH); \ + isort . --profile black --skip-glob python/glide/protobuf --skip-glob .env && \ + black . --exclude python/glide/protobuf --exclude .env && \ + flake8 . --count --select=E9,F63,F7,F82 --show-source --statistics \ + --exclude=python/glide/protobuf,.env/* --extend-ignore=E230 && \ + flake8 . --count --exit-zero --max-complexity=12 --max-line-length=127 \ + --statistics --exclude=python/glide/protobuf,.env/* \ + --extend-ignore=E230 + +python-test: .build/python_deps check-redis-server + cd python && PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH) .env/bin/pytest --asyncio-mode=auto + .build/python_deps: @echo "$(GREEN)Generating protobuf files...$(RESET)" @protoc -Iprotobuf=$(ROOT_DIR)/glide-core/src/protobuf/ \ @@ -36,41 +60,59 @@ python: .build/python_deps @cd python && python3 -m venv .env @echo "$(GREEN)Installing requirements...$(RESET)" @cd python && .env/bin/pip install -r requirements.txt + @cd python && .env/bin/pip install -r dev_requirements.txt @mkdir -p .build/ && touch .build/python_deps -python-test: check-redis-server - cd python && PYTHONPATH=$(PY_PATH):$(PY_GLIDE_PATH) .env/bin/pytest --asyncio-mode=auto - +## +## NodeJS targets +## node: .build/node_deps @echo "$(GREEN)Building for NodeJS (release)...$(RESET)" @cd node && npm run build:release -# NodeJS dependencies .build/node_deps: @echo "$(GREEN)Installing NodeJS dependencies...$(RESET)" @cd node && npm i @cd node/rust-client && npm i @mkdir -p .build/ && touch .build/node_deps -node-test: check-redis-server +node-test: .build/node_deps check-redis-server @echo "$(GREEN)Running tests for NodeJS$(RESET)" @cd node && npm run build cd node && npm test -# Check for the existence of redis-server by simply calling which shell command -check-redis-server: - which redis-server +node-lint: .build/node_deps + @echo "$(GREEN)Running linters for NodeJS$(RESET)" + @cd node && npx run lint:fix + +## +## Go targets +## + go: .build/go_deps $(MAKE) -C go build -go-test: +go-test: .build/go_deps $(MAKE) -C go test +go-lint: .build/go_deps + $(MAKE) -C go lint + .build/go_deps: @echo "$(GREEN)Installing GO dependencies...$(RESET)" $(MAKE) -C go install-build-tools @mkdir -p .build/ && touch .build/go_deps +## +## Common targets +## +check-redis-server: + which redis-server + clean: rm -fr .build/ + +help: + @echo "$(GREEN)Listing Makefile targets:$(RESET)" + @echo $(shell grep '^[^#[:space:]].*:' Makefile|cut -d":" -f1|grep -v PHONY|grep -v "^.build"|sort) From a0617b8fd1e7458726be3c8819e33a4df1803844 Mon Sep 17 00:00:00 2001 From: Eran Ifrah Date: Mon, 14 Oct 2024 18:37:06 +0000 Subject: [PATCH 3/3] Added target `go-lint` Signed-off-by: Eran Ifrah --- Makefile | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Makefile b/Makefile index e2172fd768..92bcf7acc5 100644 --- a/Makefile +++ b/Makefile @@ -101,7 +101,7 @@ go-lint: .build/go_deps .build/go_deps: @echo "$(GREEN)Installing GO dependencies...$(RESET)" - $(MAKE) -C go install-build-tools + $(MAKE) -C go install-build-tools install-dev-tools @mkdir -p .build/ && touch .build/go_deps ##