Skip to content

Commit

Permalink
Use local chains code
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Nov 14, 2024
1 parent 51afc15 commit 50f05a9
Show file tree
Hide file tree
Showing 11 changed files with 41 additions and 14 deletions.
4 changes: 3 additions & 1 deletion truss-chains/tests/chains_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def test_chain():
root = Path(__file__).parent.resolve()
chain_root = root / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
options = definitions.PushOptionsLocalDocker(chain_name="integration-test")
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
service = remote.push(entrypoint, options)

url = service.run_remote_url.replace("host.docker.internal", "localhost")
Expand Down
8 changes: 7 additions & 1 deletion truss-chains/tests/itest_chain/itest_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ def run_remote(self, length: int) -> str:
class TextReplicator(chains.ChainletBase):
remote_config = chains.RemoteConfig(docker_image=IMAGE_CUSTOM)

def __init__(self, context=chains.depends_context()):
def __init__(self):
try:
import pytzdata

print(f"Could import {pytzdata} is present")
except ModuleNotFoundError:
print("Could not import pytzdata is present")
self.multiplier = 2

def run_remote(self, data: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/tests/itest_chain/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
git+https://github.com/basetenlabs/truss.git
pytzdata
4 changes: 4 additions & 0 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,6 +563,7 @@ def _make_truss_config(
chains_config: definitions.RemoteConfig,
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
model_name: str,
use_local_chains_src: bool,
) -> truss_config.TrussConfig:
"""Generate a truss config for a Chainlet."""
config = truss_config.TrussConfig()
Expand Down Expand Up @@ -592,6 +593,7 @@ def _make_truss_config(
if chains_config.docker_image.external_package_dirs:
for ext_dir in chains_config.docker_image.external_package_dirs:
config.external_package_dirs.append(ext_dir.abs_path)
config.use_local_chains_src = use_local_chains_src
# Assets.
assets = chains_config.get_asset_spec()
config.secrets = assets.secrets
Expand Down Expand Up @@ -624,6 +626,7 @@ def gen_truss_chainlet(
chainlet_descriptor: definitions.ChainletAPIDescriptor,
model_name: str,
chainlet_display_name_to_url: Mapping[str, str],
use_local_chains_src: bool,
) -> pathlib.Path:
# Filter needed services and customize options.
dep_services = {}
Expand All @@ -641,6 +644,7 @@ def gen_truss_chainlet(
chainlet_descriptor.chainlet_cls.remote_config,
dep_services,
model_name,
use_local_chains_src,
)
# TODO This assumes all imports are absolute w.r.t chain root (or site-packages).
truss_path.copy_tree_path(
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ class GenericRemoteException(Exception): ...
class PushOptions(SafeModelNonSerializable):
chain_name: str
only_generate_trusses: bool = False
use_local_chains_src: bool = False


class PushOptionsBaseten(PushOptions):
Expand Down
16 changes: 9 additions & 7 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten import remote as b10_remote
from truss.remote.baseten import service as b10_service
from truss.truss_handle import build as truss_build
from truss.truss_handle import truss_handle
from truss.util import log_utils
from truss.util import path as truss_path

Expand All @@ -45,8 +45,8 @@
def _push_to_baseten(
truss_dir: pathlib.Path, options: definitions.PushOptionsBaseten, chainlet_name: str
) -> b10_service.BasetenService:
truss_handle = truss_build.load(str(truss_dir))
model_name = truss_handle.spec.config.model_name
th = truss_handle.TrussHandle(truss_dir)
model_name = th.spec.config.model_name
assert model_name is not None
assert bool(_MODEL_NAME_RE.match(model_name))
logging.info(
Expand All @@ -55,7 +55,7 @@ def _push_to_baseten(
)
# Models must be trusted to use the API KEY secret.
service = options.remote_provider.push(
truss_handle,
th,
model_name=model_name,
trusted=True,
publish=options.publish,
Expand Down Expand Up @@ -111,11 +111,11 @@ def _push_service(
f"Running in docker container `{chainlet_descriptor.display_name}` "
)
port = utils.get_free_port()
truss_handle = truss_build.load(str(truss_dir))
truss_handle.add_secret(
th = truss_handle.TrussHandle(truss_dir)
th.add_secret(
definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key
)
truss_handle.docker_run(
th.docker_run(
local_port=port,
detach=True,
wait_for_server_ready=True,
Expand Down Expand Up @@ -392,6 +392,7 @@ def push(
chainlet_descriptor,
model_name,
chainlet_display_name_to_url,
self._options.use_local_chains_src,
)
if self._options.only_generate_trusses:
chainlet_display_name_to_url[chainlet_descriptor.display_name] = (
Expand Down Expand Up @@ -557,6 +558,7 @@ def _code_gen_and_patch_thread(
descr,
self._chainlet_data[descr.display_name].oracle_name,
self._chainlet_display_name_to_url,
use_local_chains_src=False,
)
patch_result = self._remote_provider.patch_for_chainlet(
chainlet_dir, self._ignore_patterns
Expand Down
1 change: 1 addition & 0 deletions truss/base/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TEMPLATES_DIR / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME
)
CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "control"
CHAINS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent / "truss-chains" / "truss_chains"

SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"}
MAX_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.12"
Expand Down
2 changes: 2 additions & 0 deletions truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ class TrussConfig:
model_cache: ModelCache = field(default_factory=ModelCache)
trt_llm: Optional[TRTLLMConfiguration] = None
build_commands: List[str] = field(default_factory=list)
use_local_chains_src: bool = False

@property
def canonical_python_version(self) -> str:
Expand Down Expand Up @@ -619,6 +620,7 @@ def from_dict(d):
d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x)
),
build_commands=d.get("build_commands", []),
use_local_chains_src=d.get("use_local_chains_src", False),
)
config.validate()
return config
Expand Down
8 changes: 6 additions & 2 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AUDIO_MODEL_TRTLLM_TRUSS_DIR,
BASE_SERVER_REQUIREMENTS_TXT_FILENAME,
BASE_TRTLLM_REQUIREMENTS,
CHAINS_CODE_DIR,
CONTROL_SERVER_CODE_DIR,
DOCKER_SERVER_TEMPLATES_DIR,
FILENAME_CONSTANTS_MAP,
Expand Down Expand Up @@ -68,6 +69,7 @@
BUILD_SERVER_DIR_NAME = "server"
BUILD_CONTROL_SERVER_DIR_NAME = "control"
BUILD_SERVER_EXTENSIONS_PATH = "extensions"
BUILD_CHAINS_DIR_NAME = "truss_chains"

CONFIG_FILE = "config.yaml"
USER_TRUSS_IGNORE_FILE = ".truss_ignore"
Expand Down Expand Up @@ -356,8 +358,6 @@ def prepare_image_build_dir(
# TODO(pankaj) We probably don't need model framework specific directory.
build_dir = build_truss_target_directory(model_framework_name)

data_dir = build_dir / config.data_dir # type: ignore[operator]

def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator]

Expand Down Expand Up @@ -464,6 +464,9 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
+ SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
)

if config.use_local_chains_src:
copy_into_build_dir(CHAINS_CODE_DIR, BUILD_CHAINS_DIR_NAME)

# Copy base TrussServer requirements if supplied custom base image
base_truss_server_reqs_filepath = SERVER_CODE_DIR / REQUIREMENTS_TXT_FILENAME
if config.base_image:
Expand Down Expand Up @@ -604,6 +607,7 @@ def _render_dockerfile(
hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME,
external_data_files=external_data_files,
build_commands=build_commands,
use_local_chains_src=config.use_local_chains_src,
**FILENAME_CONSTANTS_MAP,
)
docker_file_path = build_dir / MODEL_DOCKERFILE_NAME
Expand Down
4 changes: 4 additions & 0 deletions truss/templates/server.Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ COPY ./{{config.data_dir}} /app/data

COPY ./server /app

{%- if use_local_chains_src %}
COPY ./truss_chains /app/truss_chains
{%- endif %}

COPY ./config.yaml /app/config.yaml
{%- if config.live_reload and not config.docker_server%}
COPY ./control /control
Expand Down
5 changes: 3 additions & 2 deletions truss/truss_handle/truss_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from dataclasses import replace
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from urllib.error import HTTPError

import requests
Expand Down Expand Up @@ -46,6 +46,7 @@
ServingImageBuilderContext,
)
from truss.contexts.local_loader.load_model_local import LoadModelLocal
from truss.contexts.truss_context import TrussContext
from truss.local.local_config_handler import LocalConfigHandler
from truss.templates.shared.serialization import (
truss_msgpack_deserialize,
Expand Down Expand Up @@ -950,7 +951,7 @@ def _get_serving_lookup_labels(self) -> Dict[str, Any]:

def _build_image(
self,
builder_context,
builder_context: Type[TrussContext],
labels: Dict[str, str],
build_dir: Optional[Path] = None,
tag: Optional[str] = None,
Expand Down

0 comments on commit 50f05a9

Please sign in to comment.