From 99edfd2a232ef1dc31ec085bfc83b76c92d8ed70 Mon Sep 17 00:00:00 2001 From: "marius.baseten" Date: Mon, 6 Jan 2025 10:23:23 -0800 Subject: [PATCH] Smoke WIP --- .github/actions/setup-python/action.yml | 2 +- .github/workflows/pr.yml | 4 +- .github/workflows/smoketests.yml | 56 +++++ .pre-commit-config.yaml | 4 +- poetry.lock | 16 +- pyproject.toml | 4 +- smoketests/test_chains.py | 236 ++++++++++++++++++ .../truss_chains/remote_chainlet/stub.py | 9 + truss/remote/baseten/api.py | 20 ++ 9 files changed, 344 insertions(+), 7 deletions(-) create mode 100644 .github/workflows/smoketests.yml create mode 100644 smoketests/test_chains.py diff --git a/.github/actions/setup-python/action.yml b/.github/actions/setup-python/action.yml index c856a8ebf..89a00779e 100644 --- a/.github/actions/setup-python/action.yml +++ b/.github/actions/setup-python/action.yml @@ -7,7 +7,7 @@ runs: steps: - uses: actions/setup-python@v5 with: - python-version: '3.9.9' + python-version: '3.9.21' - name: Get full Python version id: full-python-version diff --git a/.github/workflows/pr.yml b/.github/workflows/pr.yml index 963da7544..cb61a9046 100644 --- a/.github/workflows/pr.yml +++ b/.github/workflows/pr.yml @@ -36,11 +36,11 @@ jobs: - run: poetry install --with=dev,dev-server --extras=all - name: run tests run: poetry run pytest --durations=0 -m 'not integration' --junitxml=report.xml - - name: Publish Test Report # Not sure how to display this in the UI for non PRs. + - name: Publish Test Report uses: mikepenz/action-junit-report@v4 if: always() with: - commit: ${{github.event.workflow_run.head_sha}} # Doest that work outside PR? + commit: ${{ github.event.workflow_run.head_sha }} report_paths: "report.xml" markdown-link-check: diff --git a/.github/workflows/smoketests.yml b/.github/workflows/smoketests.yml new file mode 100644 index 000000000..8f4db1ddb --- /dev/null +++ b/.github/workflows/smoketests.yml @@ -0,0 +1,56 @@ +name: Truss CLI E2E tests (chains) + +on: + push: # Remove after testing. + workflow_dispatch: + inputs: + truss_version: + description: "The version of Truss to install" + required: false +jobs: + test-chains: + runs-on: ubuntu-22.04 + steps: + - name: Checkout + uses: actions/checkout@v4 + with: + lfs: true + - name: Setup Python + uses: ./.github/actions/setup-python/ + - name: Poetry Install + run: poetry install --with=dev,dev-server --extras=all + - name: Determine Truss Version + id: truss_version + run: | + if [ -n "${{ github.event.inputs.truss_version }}" ]; then + echo "Using workflow_dispatch input: ${{ github.event.inputs.truss_version }}" + echo "TRUSS_VERSION=${{ github.event.inputs.truss_version }}" >> $GITHUB_ENV + else + echo "Using default Truss version: latest" + echo "TRUSS_VERSION=latest" >> $GITHUB_ENV + fi + - name: Install Truss + run: | + python -m venv truss_env + if [ "${{ env.TRUSS_VERSION }}" = "latest" ]; then + echo "Installing the latest version of Truss" + truss_env/bin/pip install truss==0.9.58rc101 + else + echo "Installing Truss version ${{ env.TRUSS_VERSION }}" + truss_env/bin/pip install truss==${{ env.TRUSS_VERSION }} + fi + - name: Run tests + env: + TRUSS_ENV_PATH: ${{ github.workspace }}/truss_env + run: | + BASETEN_API_KEY_STAGING="${{ secrets.BASETEN_API_KEY_STAGING }}" \ + poetry run pytest smoketests \ + --durations=0 \ + --junitxml=report.xml \ + -s --log-cli-level=INFO + - name: Publish Test Report + uses: mikepenz/action-junit-report@v4 + if: always() + with: + commit: ${{ github.sha }} + report_paths: "report.xml" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index aa8753ff1..6b64e6d98 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -33,14 +33,14 @@ repos: entry: poetry run mypy language: python types: [python] - exclude: ^examples/|^truss/test.+/|model.py$|^truss-chains/.* + exclude: ^examples/|^truss/test.+/|model.py$|^truss-chains/.*|^smoketest/.* pass_filenames: true - id: mypy name: mypy-local (3.9) entry: poetry run mypy language: python types: [python] - files: ^truss-chains/.* + files: ^truss-chains/.*|^smoketest/.* args: - --python-version=3.9 pass_filenames: true diff --git a/poetry.lock b/poetry.lock index b405609e4..8b2a5a9ca 100644 --- a/poetry.lock +++ b/poetry.lock @@ -3012,6 +3012,20 @@ pytest = ">=7.0.0,<9" docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] testing = ["coverage (>=6.2)", "hypothesis (>=5.7.1)"] +[[package]] +name = "pytest-check" +version = "2.4.1" +description = "A pytest plugin that allows multiple failures per test." +optional = false +python-versions = ">=3.8" +files = [ + {file = "pytest_check-2.4.1-py3-none-any.whl", hash = "sha256:74f38938183880d9921aeb85662437d2b13e1e053e1bed7d186d54613d3068c7"}, + {file = "pytest_check-2.4.1.tar.gz", hash = "sha256:5224efcef059bf7f0cda253f8d0f62704b4819ff48c93f51c675aea6a014f650"}, +] + +[package.dependencies] +pytest = ">=7.0.0" + [[package]] name = "pytest-cov" version = "3.0.0" @@ -4317,4 +4331,4 @@ all = [] [metadata] lock-version = "2.0" python-versions = ">=3.8,<3.13" -content-hash = "18d0641fc35dd7d4e989aee99072d09ca82b708a0b360d2ed6f6f1a04a81348f" +content-hash = "1fedf4a848019ebf8e5a5fdc6d7a9387964adbf32ecf30fc34b95e3dfb0ed91e" diff --git a/pyproject.toml b/pyproject.toml index d922afa00..b1ae3f4ec 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "truss" -version = "0.9.58rc3" +version = "0.9.58rc101" description = "A seamless bridge from model development to model delivery" license = "MIT" readme = "README.md" @@ -148,6 +148,7 @@ types-PyYAML = "^6.0.12.12" types-aiofiles = ">=24.1.0" types-requests = "==2.31.0.2" types-setuptools = "^69.0.0.0" +pytest-check = "^2.4.1" [tool.poetry.group.dev-server.dependencies] # These packages are needed to run local tests of server components. Note that the actual @@ -178,6 +179,7 @@ markers = [ "integration: marks tests as integration (deselect with '-m \"not integration\"').", "asyncio: marks tests as async.", ] +addopts = "--ignore=smoketests" [tool.ruff] src = ["truss", "truss-chains", "truss-utils"] diff --git a/smoketests/test_chains.py b/smoketests/test_chains.py new file mode 100644 index 000000000..3e755942a --- /dev/null +++ b/smoketests/test_chains.py @@ -0,0 +1,236 @@ +import logging +import os +import pathlib +import re +import subprocess +import tempfile +import time +import uuid +from typing import Tuple + +import pytest +import pytest_check +from truss.remote.baseten import core +from truss.remote.baseten import remote as b10_remote +from truss.remote.baseten.utils import status as status_utils + +from truss_chains import definitions +from truss_chains.remote_chainlet import stub + +backend_env_domain = "staging.baseten.co" +BASETEN_API_KEY = os.environ["BASETEN_API_KEY_STAGING"] + +BASETEN_REMOTE_URL = f"https://app.{backend_env_domain}" +VENV_PATH = pathlib.Path(os.environ["TRUSS_ENV_PATH"]) +CHAINS_ROOT = pathlib.Path(__file__).parent.parent.resolve() / "truss-chains" +URL_RE = re.compile( + rf"https://chain-([a-zA-Z0-9]+)\.api\.{re.escape(backend_env_domain)}/deployment/([a-zA-Z0-9]+)/run_remote" +) +DEPLOY_TIMEOUT_SEC = 500 + + +def make_stub(url: str, options: definitions.RPCOptions) -> stub.StubBase: + context = definitions.DeploymentContext( + chainlet_to_service={}, + secrets={definitions.BASETEN_API_SECRET_NAME: BASETEN_API_KEY}, + ) + return stub.StubBase.from_url(url, context, options) + + +def write_trussrc(api_key: str, dir_path: pathlib.Path) -> pathlib.Path: + config = rf""" + [staging] + remote_provider = baseten + api_key = {api_key} + remote_url = {BASETEN_REMOTE_URL} + """ + truss_rc_path = dir_path / ".trussrc" + truss_rc_path.write_text(config) + return truss_rc_path + + +@pytest.fixture +def prepare(request): + temp_dir = pathlib.Path(tempfile.mkdtemp()) + truss_rc_path = write_trussrc(BASETEN_API_KEY, temp_dir) + remote = b10_remote.BasetenRemote(BASETEN_REMOTE_URL, BASETEN_API_KEY) + mutable_chain_deployment_id = [None] + + yield temp_dir, truss_rc_path, remote, mutable_chain_deployment_id + # if not test_failed: + # shutil.rmtree(temp_dir, ignore_errors=True) + + +def generate_traceparent(): + trace_id = uuid.uuid4().hex + span_id = uuid.uuid4().hex[:16] + trace_flags = "01" + traceparent = f"00-{trace_id}-{span_id}-{trace_flags}" + return traceparent + + +def run_command(truss_rc_path: pathlib.Path, command: str) -> Tuple[str, str]: + logging.info(f"Running command `{command}` in VENV `{VENV_PATH}` (subprocess).") + activate_script = VENV_PATH / "bin" / "activate" + env = os.environ.copy() + env["USER_TRUSSRC_PATH"] = str(truss_rc_path) + full_command = f"bash -c 'source {activate_script} && {command}'" + result = subprocess.run( + full_command, + shell=True, + env=env, + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) + stdout = result.stdout.strip() + stderr = result.stderr.strip() + logging.info("Command subprocess finished.") + return stdout, stderr + + +def wait_ready( + remote: b10_remote.BasetenRemote, chain_id: str, chain_deployment_id: str +) -> Tuple[bool, float]: + logging.info(f"Waiting for chain deployment `{chain_deployment_id}` to be ready.") + t0 = time.perf_counter() + success = False + wait_time_sec = 0.0 + while True: + chainlets = remote.get_chainlets(chain_deployment_id) + statuses = [ + status_utils.get_displayable_status(chainlet.status) + for chainlet in chainlets + ] + num_services = len(statuses) + num_ok = sum(s in [core.ACTIVE_STATUS, "SCALED_TO_ZERO"] for s in statuses) + num_deploying = sum(s in core.DEPLOYING_STATUSES for s in statuses) + if num_ok == num_services: + success = True + break + elif num_services - num_ok - num_deploying: + break + if (wait_time_sec := time.perf_counter() - t0) > DEPLOY_TIMEOUT_SEC: + break + + time.sleep(10) + + if success: + logging.info(f"Deployed ready in {wait_time_sec} sec.") + else: + overview_url = f"{BASETEN_REMOTE_URL}/chains/{chain_id}/overview" + raise Exception( + f"Could not be invoked within {DEPLOY_TIMEOUT_SEC} sec.\n{chainlets}\n" + f"Check deployment `{chain_deployment_id}` on {overview_url}." + ) + + return success, wait_time_sec + + +# Actual tests ######################################################################### + +# def test_truss_version(prepare): +# _, truss_rc_path = prepare +# result = run_command(truss_rc_path, "truss --version") +# assert result.stdout.strip() == "truss, version 0.9.57" + + +def test_itest_chain_publish(prepare) -> None: + remote: b10_remote.BasetenRemote + tmpdir, truss_rc_path, remote, mutable_chain_deployment_id = prepare + + chain_src = CHAINS_ROOT / "tests" / "itest_chain" / "itest_chain.py" + command = f"truss chains push {chain_src} --publish --name=itest_publish --no-wait" + # stdout = ( + # "https://chain-1lqzvkw4.api.staging.baseten.co/deployment/nwx4d0qy/run_remote" + # ) + stdout, stderr = run_command(truss_rc_path, command) + # Warning: Input is not a terminal (fd=0). + # assert not stderr + + matches = URL_RE.search(stdout) + assert matches, stdout + url = matches.group(0) + chain_id = matches.group(1) + chain_deployment_id = matches.group(2) + mutable_chain_deployment_id[0] = chain_deployment_id + + success, wait_time_sec = wait_ready(remote, chain_id, chain_deployment_id) + pytest_check.less(wait_time_sec, 220, "Deployment took too long.") + + # Test regular invocation. + chain_stub = make_stub(url, definitions.RPCOptions(timeout_sec=10)) + trace_parent = generate_traceparent() + with stub.trace_parent_raw(trace_parent): + result = chain_stub.predict_sync({"length": 30, "num_partitions": 3}) + + expected = [ + 6280, + "erodfderodfderodfderodfderodfd", + 123, + {"parts": [], "part_lens": [10]}, + ["a", "b"], + ] + pytest_check.equal(result, expected) + + # Test speed + invocation_times_sec = [] + for i in range(10): + t0 = time.perf_counter() + with stub.trace_parent_raw(trace_parent): + chain_stub.predict_sync({"length": 30, "num_partitions": 3}) + invocation_times_sec.append(time.perf_counter() - t0) + + invocation_times_sec.sort() + logging.info(f"Invocation times(sec): {invocation_times_sec}.") + pytest_check.less(invocation_times_sec[0], 0.32) # Best of 10, could be <0.30.... + + # Test binary invocation. + chain_stub_binary = make_stub( + url, definitions.RPCOptions(timeout_sec=10, use_binary=True) + ) + trace_parent = generate_traceparent() + with stub.trace_parent_raw(trace_parent): + result = chain_stub_binary.predict_sync({"length": 30, "num_partitions": 3}) + + expected = [ + 6280, + "erodfderodfderodfderodfderodfd", + 123, + {"parts": [], "part_lens": [10]}, + ["a", "b"], + ] + pytest_check.equal(result, expected) + + # Test speed + invocation_times_sec = [] + for i in range(10): + t0 = time.perf_counter() + with stub.trace_parent_raw(trace_parent): + chain_stub_binary.predict_sync({"length": 30, "num_partitions": 3}) + invocation_times_sec.append(time.perf_counter() - t0) + + invocation_times_sec.sort() + logging.info(f"Invocation times(sec): {invocation_times_sec}.") + pytest_check.less(invocation_times_sec[0], 0.32) # Best of 10, could be <0.30... + + if pytest_check.any_failures(): + logging.info( + f"There were failures, leaving deployment `{chain_deployment_id}` " + "undeleted for inspection." + ) + else: + logging.info(f"No failures. Deleting deployment `{chain_deployment_id}`.") + remote.api.delete_chain_deployment(chain_id, chain_deployment_id) + + +@pytest.mark.skip("Not Implemented.") +def test_itest_chain_development(prepare): + # 1. Push with watch. + # 2. Invoke. + # 3. Edit code. + # 4. Verify invocation is updated. + # 5. Start watch and edit code again. + # 6. Verify invocation is updated. + # 7. Delete. + ... diff --git a/truss-chains/truss_chains/remote_chainlet/stub.py b/truss-chains/truss_chains/remote_chainlet/stub.py index 33fe5af4b..95a3c43d8 100644 --- a/truss-chains/truss_chains/remote_chainlet/stub.py +++ b/truss-chains/truss_chains/remote_chainlet/stub.py @@ -56,6 +56,15 @@ def trace_parent(request: starlette.requests.Request) -> Iterator[None]: _trace_parent_context.reset(token) +@contextlib.contextmanager +def trace_parent_raw(trace_parent: str) -> Iterator[None]: + token = _trace_parent_context.set(trace_parent) + try: + yield + finally: + _trace_parent_context.reset(token) + + class BasetenSession: """Provides configured HTTP clients, retries rate limit warning etc.""" diff --git a/truss/remote/baseten/api.py b/truss/remote/baseten/api.py index fdc96a485..5b5fc7de9 100644 --- a/truss/remote/baseten/api.py +++ b/truss/remote/baseten/api.py @@ -326,6 +326,26 @@ def get_chainlets_by_deployment_id(self, chain_deployment_id: str): chainlet["chain"] = {"id": resp["data"]["chain_deployment"]["chain"]["id"]} return chainlets + def delete_chain(self, chain_id: str) -> Any: + url = f"{self._rest_api_url}/v1/chains/{chain_id}" + headers = self._auth_token.header() + resp = requests.delete(url, headers=headers) + if not resp.ok: + resp.raise_for_status() + + deployment = resp.json() + return deployment + + def delete_chain_deployment(self, chain_id: str, chain_deployment_id: str) -> Any: + url = f"{self._rest_api_url}/v1/chains/{chain_id}/deployments/{chain_deployment_id}" + headers = self._auth_token.header() + resp = requests.delete(url, headers=headers) + if not resp.ok: + resp.raise_for_status() + + deployment = resp.json() + return deployment + def models(self): query_string = """ {