Skip to content

Commit

Permalink
Use local chains code
Browse files Browse the repository at this point in the history
  • Loading branch information
marius-baseten committed Nov 14, 2024
1 parent 51afc15 commit 471b1c5
Show file tree
Hide file tree
Showing 11 changed files with 47 additions and 16 deletions.
4 changes: 3 additions & 1 deletion truss-chains/tests/chains_e2e_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@ def test_chain():
root = Path(__file__).parent.resolve()
chain_root = root / "itest_chain" / "itest_chain.py"
with framework.import_target(chain_root, "ItestChain") as entrypoint:
options = definitions.PushOptionsLocalDocker(chain_name="integration-test")
options = definitions.PushOptionsLocalDocker(
chain_name="integration-test", use_local_chains_src=True
)
service = remote.push(entrypoint, options)

url = service.run_remote_url.replace("host.docker.internal", "localhost")
Expand Down
8 changes: 7 additions & 1 deletion truss-chains/tests/itest_chain/itest_chain.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,13 @@ def run_remote(self, length: int) -> str:
class TextReplicator(chains.ChainletBase):
remote_config = chains.RemoteConfig(docker_image=IMAGE_CUSTOM)

def __init__(self, context=chains.depends_context()):
def __init__(self):
try:
import pytzdata

print(f"Could import {pytzdata} is present")
except ModuleNotFoundError:
print("Could not import pytzdata is present")
self.multiplier = 2

def run_remote(self, data: str) -> str:
Expand Down
2 changes: 1 addition & 1 deletion truss-chains/tests/itest_chain/requirements.txt
Original file line number Diff line number Diff line change
@@ -1 +1 @@
git+https://github.com/basetenlabs/truss.git
pytzdata
12 changes: 10 additions & 2 deletions truss-chains/truss_chains/code_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,7 +511,9 @@ def _gen_truss_chainlet_file(
# Truss Gen ############################################################################


def _make_requirements(image: definitions.DockerImage) -> list[str]:
def _make_requirements(
image: definitions.DockerImage, use_local_chains_src: bool
) -> list[str]:
"""Merges file- and list-based requirements and adds truss git if not present."""
pip_requirements: set[str] = set()
if image.pip_requirements_file:
Expand Down Expand Up @@ -563,6 +565,7 @@ def _make_truss_config(
chains_config: definitions.RemoteConfig,
chainlet_to_service: Mapping[str, definitions.ServiceDescriptor],
model_name: str,
use_local_chains_src: bool,
) -> truss_config.TrussConfig:
"""Generate a truss config for a Chainlet."""
config = truss_config.TrussConfig()
Expand All @@ -580,7 +583,9 @@ def _make_truss_config(
config.runtime.predict_concurrency = compute.predict_concurrency
# Image.
_inplace_fill_base_image(chains_config.docker_image, config)
pip_requirements = _make_requirements(chains_config.docker_image)
pip_requirements = _make_requirements(
chains_config.docker_image, use_local_chains_src
)
# TODO: `pip_requirements` will add server requirements which give version
# conflicts. Check if that's still the case after relaxing versions.
# config.requirements = pip_requirements
Expand All @@ -592,6 +597,7 @@ def _make_truss_config(
if chains_config.docker_image.external_package_dirs:
for ext_dir in chains_config.docker_image.external_package_dirs:
config.external_package_dirs.append(ext_dir.abs_path)
config.use_local_chains_src = use_local_chains_src
# Assets.
assets = chains_config.get_asset_spec()
config.secrets = assets.secrets
Expand Down Expand Up @@ -624,6 +630,7 @@ def gen_truss_chainlet(
chainlet_descriptor: definitions.ChainletAPIDescriptor,
model_name: str,
chainlet_display_name_to_url: Mapping[str, str],
use_local_chains_src: bool,
) -> pathlib.Path:
# Filter needed services and customize options.
dep_services = {}
Expand All @@ -641,6 +648,7 @@ def gen_truss_chainlet(
chainlet_descriptor.chainlet_cls.remote_config,
dep_services,
model_name,
use_local_chains_src,
)
# TODO This assumes all imports are absolute w.r.t chain root (or site-packages).
truss_path.copy_tree_path(
Expand Down
1 change: 1 addition & 0 deletions truss-chains/truss_chains/definitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -620,6 +620,7 @@ class GenericRemoteException(Exception): ...
class PushOptions(SafeModelNonSerializable):
chain_name: str
only_generate_trusses: bool = False
use_local_chains_src: bool = False


class PushOptionsBaseten(PushOptions):
Expand Down
16 changes: 9 additions & 7 deletions truss-chains/truss_chains/remote.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from truss.remote.baseten import custom_types as b10_types
from truss.remote.baseten import remote as b10_remote
from truss.remote.baseten import service as b10_service
from truss.truss_handle import build as truss_build
from truss.truss_handle import truss_handle
from truss.util import log_utils
from truss.util import path as truss_path

Expand All @@ -45,8 +45,8 @@
def _push_to_baseten(
truss_dir: pathlib.Path, options: definitions.PushOptionsBaseten, chainlet_name: str
) -> b10_service.BasetenService:
truss_handle = truss_build.load(str(truss_dir))
model_name = truss_handle.spec.config.model_name
th = truss_handle.TrussHandle(truss_dir)
model_name = th.spec.config.model_name
assert model_name is not None
assert bool(_MODEL_NAME_RE.match(model_name))
logging.info(
Expand All @@ -55,7 +55,7 @@ def _push_to_baseten(
)
# Models must be trusted to use the API KEY secret.
service = options.remote_provider.push(
truss_handle,
th,
model_name=model_name,
trusted=True,
publish=options.publish,
Expand Down Expand Up @@ -111,11 +111,11 @@ def _push_service(
f"Running in docker container `{chainlet_descriptor.display_name}` "
)
port = utils.get_free_port()
truss_handle = truss_build.load(str(truss_dir))
truss_handle.add_secret(
th = truss_handle.TrussHandle(truss_dir)
th.add_secret(
definitions.BASETEN_API_SECRET_NAME, options.baseten_chain_api_key
)
truss_handle.docker_run(
th.docker_run(
local_port=port,
detach=True,
wait_for_server_ready=True,
Expand Down Expand Up @@ -392,6 +392,7 @@ def push(
chainlet_descriptor,
model_name,
chainlet_display_name_to_url,
self._options.use_local_chains_src,
)
if self._options.only_generate_trusses:
chainlet_display_name_to_url[chainlet_descriptor.display_name] = (
Expand Down Expand Up @@ -557,6 +558,7 @@ def _code_gen_and_patch_thread(
descr,
self._chainlet_data[descr.display_name].oracle_name,
self._chainlet_display_name_to_url,
use_local_chains_src=False,
)
patch_result = self._remote_provider.patch_for_chainlet(
chainlet_dir, self._ignore_patterns
Expand Down
1 change: 1 addition & 0 deletions truss/base/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
TEMPLATES_DIR / SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME
)
CONTROL_SERVER_CODE_DIR: pathlib.Path = TEMPLATES_DIR / "control"
CHAINS_CODE_DIR: pathlib.Path = _TRUSS_ROOT.parent / "truss-chains" / "truss_chains"

SUPPORTED_PYTHON_VERSIONS = {"3.8", "3.9", "3.10", "3.11"}
MAX_SUPPORTED_PYTHON_VERSION_IN_CUSTOM_BASE_IMAGE = "3.12"
Expand Down
2 changes: 2 additions & 0 deletions truss/base/truss_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,7 @@ class TrussConfig:
model_cache: ModelCache = field(default_factory=ModelCache)
trt_llm: Optional[TRTLLMConfiguration] = None
build_commands: List[str] = field(default_factory=list)
use_local_chains_src: bool = False

@property
def canonical_python_version(self) -> str:
Expand Down Expand Up @@ -619,6 +620,7 @@ def from_dict(d):
d.get("trt_llm"), lambda x: TRTLLMConfiguration(**x)
),
build_commands=d.get("build_commands", []),
use_local_chains_src=d.get("use_local_chains_src", False),
)
config.validate()
return config
Expand Down
8 changes: 6 additions & 2 deletions truss/contexts/image_builder/serving_image_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
AUDIO_MODEL_TRTLLM_TRUSS_DIR,
BASE_SERVER_REQUIREMENTS_TXT_FILENAME,
BASE_TRTLLM_REQUIREMENTS,
CHAINS_CODE_DIR,
CONTROL_SERVER_CODE_DIR,
DOCKER_SERVER_TEMPLATES_DIR,
FILENAME_CONSTANTS_MAP,
Expand Down Expand Up @@ -68,6 +69,7 @@
BUILD_SERVER_DIR_NAME = "server"
BUILD_CONTROL_SERVER_DIR_NAME = "control"
BUILD_SERVER_EXTENSIONS_PATH = "extensions"
BUILD_CHAINS_DIR_NAME = "truss_chains"

CONFIG_FILE = "config.yaml"
USER_TRUSS_IGNORE_FILE = ".truss_ignore"
Expand Down Expand Up @@ -356,8 +358,6 @@ def prepare_image_build_dir(
# TODO(pankaj) We probably don't need model framework specific directory.
build_dir = build_truss_target_directory(model_framework_name)

data_dir = build_dir / config.data_dir # type: ignore[operator]

def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
copy_tree_or_file(from_path, build_dir / path_in_build_dir) # type: ignore[operator]

Expand Down Expand Up @@ -464,6 +464,9 @@ def copy_into_build_dir(from_path: Path, path_in_build_dir: str):
+ SHARED_SERVING_AND_TRAINING_CODE_DIR_NAME,
)

if config.use_local_chains_src:
copy_into_build_dir(CHAINS_CODE_DIR, BUILD_CHAINS_DIR_NAME)

# Copy base TrussServer requirements if supplied custom base image
base_truss_server_reqs_filepath = SERVER_CODE_DIR / REQUIREMENTS_TXT_FILENAME
if config.base_image:
Expand Down Expand Up @@ -604,6 +607,7 @@ def _render_dockerfile(
hf_access_token_file_name=HF_ACCESS_TOKEN_FILE_NAME,
external_data_files=external_data_files,
build_commands=build_commands,
use_local_chains_src=config.use_local_chains_src,
**FILENAME_CONSTANTS_MAP,
)
docker_file_path = build_dir / MODEL_DOCKERFILE_NAME
Expand Down
4 changes: 4 additions & 0 deletions truss/templates/server.Dockerfile.jinja
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ COPY ./{{config.data_dir}} /app/data

COPY ./server /app

{%- if use_local_chains_src %}
COPY ./truss_chains /app/truss_chains
{%- endif %}

COPY ./config.yaml /app/config.yaml
{%- if config.live_reload and not config.docker_server%}
COPY ./control /control
Expand Down
5 changes: 3 additions & 2 deletions truss/truss_handle/truss_handle.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import uuid
from dataclasses import replace
from pathlib import Path
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, Union
from urllib.error import HTTPError

import requests
Expand Down Expand Up @@ -46,6 +46,7 @@
ServingImageBuilderContext,
)
from truss.contexts.local_loader.load_model_local import LoadModelLocal
from truss.contexts.truss_context import TrussContext
from truss.local.local_config_handler import LocalConfigHandler
from truss.templates.shared.serialization import (
truss_msgpack_deserialize,
Expand Down Expand Up @@ -950,7 +951,7 @@ def _get_serving_lookup_labels(self) -> Dict[str, Any]:

def _build_image(
self,
builder_context,
builder_context: Type[TrussContext],
labels: Dict[str, str],
build_dir: Optional[Path] = None,
tag: Optional[str] = None,
Expand Down

0 comments on commit 471b1c5

Please sign in to comment.