From f345d2c585ad10169234f64785d2d203030ed453 Mon Sep 17 00:00:00 2001
From: Nuno Campos
Date: Tue, 1 Oct 2024 13:33:44 -0700
Subject: [PATCH 1/8] Implement "streaming" multipart requests to multipart endpoint

- Use streaming multipart encoder from requests_toolbelt
- Currently dump each part to json before sending the request as that's the only way to enforce the payload size limit
- When we lift the payload size limit we should implement true streaming encoding, where each part is only encoded immediately before being sent over the connection, and use transfer-encoding: chunked

---
 python/langsmith/client.py | 169 ++++++++++++++++++++++++++++++++++++-
 1 file changed, 168 insertions(+), 1 deletion(-)

diff --git a/python/langsmith/client.py b/python/langsmith/client.py
index 62331996f..5c0f8228e 100644
--- a/python/langsmith/client.py
+++ b/python/langsmith/client.py
@@ -63,6 +63,7 @@
 import orjson
 import requests
 from requests import adapters as requests_adapters
+from requests_toolbelt.multipart import MultipartEncoder
 from typing_extensions import TypeGuard
 from urllib3.util import Retry
 
@@ -92,6 +93,8 @@ class ZoneInfo:  # type: ignore[no-redef]
 X_API_KEY = "x-api-key"
 WARNED_ATTACHMENTS = False
 EMPTY_SEQ: tuple[Dict, ...] = ()
+BOUNDARY = uuid.uuid4().hex
+MultipartParts = List[Tuple[str, Tuple[None, bytes, str]]]
 
 
 def _parse_token_or_url(
@@ -1538,6 +1541,167 @@ def _post_batch_ingest_runs(self, body: bytes, *, _context: str):
         except Exception:
             logger.warning(f"Failed to batch ingest runs: {repr(e)}")
 
+    def multipart_ingest_runs(
+        self,
+        create: Optional[
+            Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
+        ] = None,
+        update: Optional[
+            Sequence[Union[ls_schemas.Run, ls_schemas.RunLikeDict, Dict]]
+        ] = None,
+        *,
+        pre_sampled: bool = False,
+    ):
+        """Batch ingest/upsert multiple runs in the Langsmith system.
+
+        Args:
+            create (Optional[Sequence[Union[ls_schemas.Run, RunLikeDict]]]):
+                A sequence of `Run` objects or equivalent dictionaries representing
+                runs to be created / posted.
+            update (Optional[Sequence[Union[ls_schemas.Run, RunLikeDict]]]):
+                A sequence of `Run` objects or equivalent dictionaries representing
+                runs that have already been created and should be updated / patched.
+            pre_sampled (bool, optional): Whether the runs have already been subject
+                to sampling, and therefore should not be sampled again.
+                Defaults to False.
+
+        Returns:
+            None: If both `create` and `update` are None.
+
+        Raises:
+            LangsmithAPIError: If there is an error in the API request.
+
+        Note:
+            - The run objects MUST contain the dotted_order and trace_id fields
+                to be accepted by the API.
+ """ + if not create and not update: + return + # transform and convert to dicts + all_attachments: Dict[str, ls_schemas.Attachments] = {} + create_dicts = [ + self._run_transform(run, attachments_collector=all_attachments) + for run in create or EMPTY_SEQ + ] + update_dicts = [ + self._run_transform(run, update=True, attachments_collector=all_attachments) + for run in update or EMPTY_SEQ + ] + # combine post and patch dicts where possible + if update_dicts and create_dicts: + create_by_id = {run["id"]: run for run in create_dicts} + standalone_updates: list[dict] = [] + for run in update_dicts: + if run["id"] in create_by_id: + for k, v in run.items(): + if v is not None: + create_by_id[run["id"]][k] = v + else: + standalone_updates.append(run) + update_dicts = standalone_updates + for run in create_dicts: + if not run.get("trace_id") or not run.get("dotted_order"): + raise ls_utils.LangSmithUserError( + "Batch ingest requires trace_id and dotted_order to be set." + ) + for run in update_dicts: + if not run.get("trace_id") or not run.get("dotted_order"): + raise ls_utils.LangSmithUserError( + "Batch ingest requires trace_id and dotted_order to be set." + ) + # filter out runs that are not sampled + if not pre_sampled: + create_dicts = self._filter_for_sampling(create_dicts) + update_dicts = self._filter_for_sampling(update_dicts, patch=True) + if not create_dicts and not update_dicts: + return + # insert runtime environment + self._insert_runtime_env(create_dicts) + self._insert_runtime_env(update_dicts) + # check size limit + size_limit_bytes = (self.info.batch_ingest_config or {}).get( + "size_limit_bytes" + ) or _SIZE_LIMIT_BYTES + # send the runs in multipart requests + acc_size = 0 + acc_context: List[str] = [] + acc_parts: MultipartParts = [] + for event, payloads in (("post", create_dicts), ("patch", update_dicts)): + for payload in payloads: + parts: MultipartParts = [] + # collect fields to be sent as separate parts + fields = [ + ("inputs", run.pop("inputs", None)), + ("outputs", run.pop("outputs", None)), + ("serialized", run.pop("serialized", None)), + ("events", run.pop("events", None)), + ] + # encode the main run payload + parts.append( + ( + f"{event}.{payload['id']}", + (None, _dumps_json(payload), "application/json"), + ) + ) + # encode the fields we collected + for key, value in fields: + if value is None: + continue + parts.append( + ( + f"{event}.{run['id']}.{key}", + (None, _dumps_json(value), "application/json"), + ), + ) + # encode the attachments + if attachments := all_attachments.pop(payload["id"], None): + for n, (ct, ba) in attachments.items(): + parts.append( + (f"attachment.{payload['id']}.{n}", (None, ba, ct)) + ) + # calculate the size of the parts + size = sum(len(p[1][1]) for p in parts) + # compute context + context = f"trace={payload.get('trace_id')},id={payload.get('id')}" + # if next size would exceed limit, send the current parts + if acc_size + size > size_limit_bytes: + self._send_multipart_req(acc_parts, _context="; ".join(acc_context)) + else: + # otherwise, accumulate the parts + acc_size += size + acc_parts.extend(parts) + acc_context.append(context) + # send the remaining parts + if acc_parts: + self._send_multipart_req(acc_parts, _context="; ".join(acc_context)) + + def _send_multipart_req(self, parts: MultipartParts, *, _context: str): + for api_url, api_key in self._write_api_urls.items(): + try: + encoder = MultipartEncoder(parts, boundary=BOUNDARY) + self.request_with_retries( + "POST", + f"{api_url}/runs/multipart", + request_kwargs={ 
+ "data": encoder, + "headers": { + **self._headers, + X_API_KEY: api_key, + "Content-Type": encoder.content_type, + }, + }, + to_ignore=(ls_utils.LangSmithConflictError,), + stop_after_attempt=3, + _context=_context, + ) + except Exception as e: + try: + exc_desc_lines = traceback.format_exception_only(type(e), e) + exc_desc = "".join(exc_desc_lines).rstrip() + logger.warning(f"Failed to multipart ingest runs: {exc_desc}") + except Exception: + logger.warning(f"Failed to multipart ingest runs: {repr(e)}") + def update_run( self, run_id: ID_TYPE, @@ -5593,7 +5757,10 @@ def _tracing_thread_handle_batch( create = [it.item for it in batch if it.action == "create"] update = [it.item for it in batch if it.action == "update"] try: - client.batch_ingest_runs(create=create, update=update, pre_sampled=True) + if use_multipart: + client.multipart_ingest_runs(create=create, update=update, pre_sampled=True) + else: + client.batch_ingest_runs(create=create, update=update, pre_sampled=True) except Exception: logger.error("Error in tracing queue", exc_info=True) # exceptions are logged elsewhere, but we need to make sure the From b515cdafc261d5bce3289d0771b5f5adc3fe056a Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 1 Oct 2024 13:44:41 -0700 Subject: [PATCH 2/8] Lint --- python/langsmith/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 5c0f8228e..228288279 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -63,7 +63,7 @@ import orjson import requests from requests import adapters as requests_adapters -from requests_toolbelt.multipart import MultipartEncoder +from requests_toolbelt.multipart import MultipartEncoder # type: ignore[import-untyped] from typing_extensions import TypeGuard from urllib3.util import Retry From 1cd973249bc8fcd1d914f8f1e07e7f9be70c55e1 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 1 Oct 2024 15:30:30 -0700 Subject: [PATCH 3/8] Fix up --- python/langsmith/client.py | 51 +++--- python/poetry.lock | 16 +- python/pyproject.toml | 1 + python/tests/unit_tests/test_client.py | 227 ++++++++++++++++++------- 4 files changed, 216 insertions(+), 79 deletions(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 228288279..275969ea3 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -1587,6 +1587,23 @@ def multipart_ingest_runs( self._run_transform(run, update=True, attachments_collector=all_attachments) for run in update or EMPTY_SEQ ] + # require trace_id and dotted_order + if create_dicts: + for run in create_dicts: + if not run.get("trace_id") or not run.get("dotted_order"): + raise ls_utils.LangSmithUserError( + "Batch ingest requires trace_id and dotted_order to be set." + ) + else: + del run + if update_dicts: + for run in update_dicts: + if not run.get("trace_id") or not run.get("dotted_order"): + raise ls_utils.LangSmithUserError( + "Batch ingest requires trace_id and dotted_order to be set." + ) + else: + del run # combine post and patch dicts where possible if update_dicts and create_dicts: create_by_id = {run["id"]: run for run in create_dicts} @@ -1598,17 +1615,9 @@ def multipart_ingest_runs( create_by_id[run["id"]][k] = v else: standalone_updates.append(run) + else: + del run update_dicts = standalone_updates - for run in create_dicts: - if not run.get("trace_id") or not run.get("dotted_order"): - raise ls_utils.LangSmithUserError( - "Batch ingest requires trace_id and dotted_order to be set." 
- ) - for run in update_dicts: - if not run.get("trace_id") or not run.get("dotted_order"): - raise ls_utils.LangSmithUserError( - "Batch ingest requires trace_id and dotted_order to be set." - ) # filter out runs that are not sampled if not pre_sampled: create_dicts = self._filter_for_sampling(create_dicts) @@ -1631,10 +1640,10 @@ def multipart_ingest_runs( parts: MultipartParts = [] # collect fields to be sent as separate parts fields = [ - ("inputs", run.pop("inputs", None)), - ("outputs", run.pop("outputs", None)), - ("serialized", run.pop("serialized", None)), - ("events", run.pop("events", None)), + ("inputs", payload.pop("inputs", None)), + ("outputs", payload.pop("outputs", None)), + ("serialized", payload.pop("serialized", None)), + ("events", payload.pop("events", None)), ] # encode the main run payload parts.append( @@ -1649,7 +1658,7 @@ def multipart_ingest_runs( continue parts.append( ( - f"{event}.{run['id']}.{key}", + f"{event}.{payload['id']}.{key}", (None, _dumps_json(value), "application/json"), ), ) @@ -1666,11 +1675,13 @@ def multipart_ingest_runs( # if next size would exceed limit, send the current parts if acc_size + size > size_limit_bytes: self._send_multipart_req(acc_parts, _context="; ".join(acc_context)) - else: - # otherwise, accumulate the parts - acc_size += size - acc_parts.extend(parts) - acc_context.append(context) + acc_parts.clear() + acc_context.clear() + acc_size = 0 + # accumulate the parts + acc_size += size + acc_parts.extend(parts) + acc_context.append(context) # send the remaining parts if acc_parts: self._send_multipart_req(acc_parts, _context="; ".join(acc_context)) diff --git a/python/poetry.lock b/python/poetry.lock index 4814514b9..9af73254e 100644 --- a/python/poetry.lock +++ b/python/poetry.lock @@ -685,6 +685,20 @@ files = [ [package.dependencies] typing-extensions = {version = ">=4.1.0", markers = "python_version < \"3.11\""} +[[package]] +name = "multipart" +version = "1.0.0" +description = "Parser for multipart/form-data" +optional = false +python-versions = ">=3.5" +files = [ + {file = "multipart-1.0.0-py3-none-any.whl", hash = "sha256:85824b3d48b63fe0b6f438feb2b39f9753512e889426fb339e96b6095d4239c8"}, + {file = "multipart-1.0.0.tar.gz", hash = "sha256:6ac937fe07cd4e11cf4ca199f3d8f668e6a37e0f477c5ee032673d45be7f7957"}, +] + +[package.extras] +dev = ["build", "pytest", "pytest-cov", "twine"] + [[package]] name = "mypy" version = "1.11.2" @@ -1903,4 +1917,4 @@ vcr = [] [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "244ff8da405f7ae894fcee75f851ef3b51eb556e5ac56a84fc3eb4930851e75c" +content-hash = "a20ea8a3bba074fc87b54139dfe7c4c4ffea37d5d5f4b874fb759baad4d443d0" diff --git a/python/pyproject.toml b/python/pyproject.toml index 8791fbdff..700a17372 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -60,6 +60,7 @@ pytest-rerunfailures = "^14.0" pytest-socket = "^0.7.0" pyperf = "^2.7.0" py-spy = "^0.3.14" +multipart = "^1.0.0" [tool.poetry.group.lint.dependencies] openai = "^1.10" diff --git a/python/tests/unit_tests/test_client.py b/python/tests/unit_tests/test_client.py index 2e8b2043a..4e9fa7344 100644 --- a/python/tests/unit_tests/test_client.py +++ b/python/tests/unit_tests/test_client.py @@ -14,7 +14,7 @@ from datetime import datetime, timezone from enum import Enum from io import BytesIO -from typing import Dict, NamedTuple, Optional, Type, Union +from typing import Dict, List, NamedTuple, Optional, Type, Union from unittest import mock from unittest.mock import MagicMock, patch @@ 
-22,8 +22,10 @@ import orjson import pytest import requests +from multipart import MultipartParser, MultipartPart, parse_options_header from pydantic import BaseModel from requests import HTTPError +from requests_toolbelt.multipart import MultipartEncoder import langsmith.env as ls_env import langsmith.utils as ls_utils @@ -63,7 +65,7 @@ def test_validate_api_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("LANGCHAIN_ENDPOINT", "https://api.smith.langchain-endpoint.com") monkeypatch.setenv("LANGSMITH_ENDPOINT", "https://api.smith.langsmith-endpoint.com") - client = Client() + client = Client(auto_batch_tracing=False) assert client.api_url == "https://api.smith.langsmith-endpoint.com" # Scenario 2: Both LANGCHAIN_ENDPOINT and LANGSMITH_ENDPOINT @@ -72,7 +74,11 @@ def test_validate_api_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("LANGCHAIN_ENDPOINT", "https://api.smith.langchain-endpoint.com") monkeypatch.setenv("LANGSMITH_ENDPOINT", "https://api.smith.langsmith-endpoint.com") - client = Client(api_url="https://api.smith.langchain.com", api_key="123") + client = Client( + api_url="https://api.smith.langchain.com", + api_key="123", + auto_batch_tracing=False, + ) assert client.api_url == "https://api.smith.langchain.com" # Scenario 3: LANGCHAIN_ENDPOINT is set, but LANGSMITH_ENDPOINT is not @@ -80,7 +86,7 @@ def test_validate_api_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setenv("LANGCHAIN_ENDPOINT", "https://api.smith.langchain-endpoint.com") monkeypatch.delenv("LANGSMITH_ENDPOINT", raising=False) - client = Client() + client = Client(auto_batch_tracing=False) assert client.api_url == "https://api.smith.langchain-endpoint.com" # Scenario 4: LANGCHAIN_ENDPOINT is not set, but LANGSMITH_ENDPOINT is set @@ -88,7 +94,7 @@ def test_validate_api_url(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("LANGCHAIN_ENDPOINT", raising=False) monkeypatch.setenv("LANGSMITH_ENDPOINT", "https://api.smith.langsmith-endpoint.com") - client = Client() + client = Client(auto_batch_tracing=False) assert client.api_url == "https://api.smith.langsmith-endpoint.com" @@ -152,12 +158,13 @@ def test_validate_multiple_urls(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.delenv("LANGCHAIN_ENDPOINT", raising=False) monkeypatch.delenv("LANGSMITH_ENDPOINT", raising=False) monkeypatch.setenv("LANGSMITH_RUNS_ENDPOINTS", json.dumps(data)) - client = Client() + client = Client(auto_batch_tracing=False) assert client._write_api_urls == data assert client.api_url == "https://api.smith.langsmith-endpoint_1.com" assert client.api_key == "123" +@mock.patch("langsmith.client.requests.Session") def test_headers(monkeypatch: pytest.MonkeyPatch) -> None: _clear_env_cache() monkeypatch.delenv("LANGCHAIN_API_KEY", raising=False) @@ -276,7 +283,8 @@ def test_create_run_unicode() -> None: client.update_run(id_, status="completed") -def test_create_run_mutate() -> None: +@pytest.mark.parametrize("use_multipart_endpoint", (True, False)) +def test_create_run_mutate(use_multipart_endpoint: bool) -> None: inputs = {"messages": ["hi"], "mygen": (i for i in range(10))} session = mock.Mock() session.request = mock.Mock() @@ -286,6 +294,7 @@ def test_create_run_mutate() -> None: session=session, info=ls_schemas.LangSmithInfo( batch_ingest_config=ls_schemas.BatchIngestConfig( + use_multipart_endpoint=use_multipart_endpoint, size_limit_bytes=None, # Note this field is not used here size_limit=100, scale_up_nthreads_limit=16, @@ -315,33 +324,91 @@ def test_create_run_mutate() -> 
None: trace_id=id_, dotted_order=run_dict["dotted_order"], ) - for _ in range(10): - time.sleep(0.1) # Give the background thread time to stop - payloads = [ - json.loads(call[2]["data"]) - for call in session.request.mock_calls - if call.args and call.args[1].endswith("runs/batch") + if use_multipart_endpoint: + for _ in range(10): + time.sleep(0.1) # Give the background thread time to stop + payloads = [ + (call[2]["headers"], call[2]["data"]) + for call in session.request.mock_calls + if call.args and call.args[1].endswith("runs/multipart") + ] + if payloads: + break + else: + assert False, "No payloads found" + + parts: List[MultipartPart] = [] + for payload in payloads: + headers, data = payload + assert headers["Content-Type"].startswith("multipart/form-data") + # this is a current implementation detail, if we change implementation + # we update this assertion + assert isinstance(data, MultipartEncoder) + boundary = parse_options_header(headers["Content-Type"])[1]["boundary"] + parser = MultipartParser(data, boundary) + parts.extend(parser.parts()) + + assert len(parts) == 3 + assert [p.name for p in parts] == [ + f"post.{id_}", + f"post.{id_}.inputs", + f"post.{id_}.outputs", ] - if payloads: - break - posts = [pr for payload in payloads for pr in payload.get("post", [])] - patches = [pr for payload in payloads for pr in payload.get("patch", [])] - inputs = next( - (pr["inputs"] for pr in itertools.chain(posts, patches) if pr.get("inputs")), - {}, - ) - outputs = next( - (pr["outputs"] for pr in itertools.chain(posts, patches) if pr.get("outputs")), - {}, - ) - # Check that the mutated value wasn't posted - assert "messages" in inputs - assert inputs["messages"] == ["hi"] - assert "mygen" in inputs - assert inputs["mygen"].startswith( # type: ignore - "." - ) - assert outputs == {"messages": ["hi", "there"]} + assert [p.headers.get("content-type") for p in parts] == [ + "application/json", + "application/json", + "application/json", + ] + outputs_parsed = json.loads(parts[2].value) + assert outputs_parsed == outputs + inputs_parsed = json.loads(parts[1].value) + assert inputs_parsed["messages"] == ["hi"] + assert inputs_parsed["mygen"].startswith( # type: ignore + "." + ) + run_parsed = json.loads(parts[0].value) + assert "inputs" not in run_parsed + assert "outputs" not in run_parsed + assert run_parsed["trace_id"] == str(id_) + assert run_parsed["dotted_order"] == run_dict["dotted_order"] + else: + for _ in range(10): + time.sleep(0.1) # Give the background thread time to stop + payloads = [ + json.loads(call[2]["data"]) + for call in session.request.mock_calls + if call.args and call.args[1].endswith("runs/batch") + ] + if payloads: + break + else: + assert False, "No payloads found" + posts = [pr for payload in payloads for pr in payload.get("post", [])] + patches = [pr for payload in payloads for pr in payload.get("patch", [])] + inputs = next( + ( + pr["inputs"] + for pr in itertools.chain(posts, patches) + if pr.get("inputs") + ), + {}, + ) + outputs = next( + ( + pr["outputs"] + for pr in itertools.chain(posts, patches) + if pr.get("outputs") + ), + {}, + ) + # Check that the mutated value wasn't posted + assert "messages" in inputs + assert inputs["messages"] == ["hi"] + assert "mygen" in inputs + assert inputs["mygen"].startswith( # type: ignore + "." 
+ ) + assert outputs == {"messages": ["hi", "there"]} class CallTracker: @@ -951,7 +1018,10 @@ def test_batch_ingest_run_retry_on_429(mock_raise_for_status): @pytest.mark.parametrize("payload_size", [MB, 5 * MB, 9 * MB, 21 * MB]) -def test_batch_ingest_run_splits_large_batches(payload_size: int): +@pytest.mark.parametrize("use_multipart_endpoint", (True, False)) +def test_batch_ingest_run_splits_large_batches( + payload_size: int, use_multipart_endpoint: bool +): mock_session = MagicMock() client = Client(api_key="test", session=mock_session) mock_response = MagicMock() @@ -981,36 +1051,76 @@ def test_batch_ingest_run_splits_large_batches(payload_size: int): } for run_id in patch_ids ] - client.batch_ingest_runs(create=posts, update=patches) - # we can support up to 20MB per batch, so we need to find the number of batches - # we should be sending - max_in_batch = max(1, (20 * MB) // (payload_size + 20)) + if use_multipart_endpoint: + client.multipart_ingest_runs(create=posts, update=patches) + # we can support up to 20MB per batch, so we need to find the number of batches + # we should be sending + max_in_batch = max(1, (20 * MB) // (payload_size + 20)) + + expected_num_requests = min(6, math.ceil((len(run_ids) * 2) / max_in_batch)) + # count the number of POST requests + assert sum( + [1 for call in mock_session.request.call_args_list if call[0][0] == "POST"] + ) in (expected_num_requests, expected_num_requests + 1) + request_bodies = [ + op + for call in mock_session.request.call_args_list + for op in ( + MultipartParser( + call[1]["data"], + parse_options_header(call[1]["headers"]["Content-Type"])[1][ + "boundary" + ], + ) + if call[0][0] == "POST" + else [] + ) + ] + all_run_ids = run_ids + patch_ids - expected_num_requests = min(6, math.ceil((len(run_ids) * 2) / max_in_batch)) - # count the number of POST requests - assert ( - sum([1 for call in mock_session.request.call_args_list if call[0][0] == "POST"]) - == expected_num_requests - ) - request_bodies = [ - op - for call in mock_session.request.call_args_list - for reqs in ( - orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else [] + # Check that all the run_ids are present in the request bodies + for run_id in all_run_ids: + assert any( + [body.name.split(".")[1] == run_id for body in request_bodies] + ), run_id + else: + client.batch_ingest_runs(create=posts, update=patches) + # we can support up to 20MB per batch, so we need to find the number of batches + # we should be sending + max_in_batch = max(1, (20 * MB) // (payload_size + 20)) + + expected_num_requests = min(6, math.ceil((len(run_ids) * 2) / max_in_batch)) + # count the number of POST requests + assert ( + sum( + [ + 1 + for call in mock_session.request.call_args_list + if call[0][0] == "POST" + ] + ) + == expected_num_requests ) - for op in reqs - ] - all_run_ids = run_ids + patch_ids + request_bodies = [ + op + for call in mock_session.request.call_args_list + for reqs in ( + orjson.loads(call[1]["data"]).values() if call[0][0] == "POST" else [] + ) + for op in reqs + ] + all_run_ids = run_ids + patch_ids - # Check that all the run_ids are present in the request bodies - for run_id in all_run_ids: - assert any([body["id"] == str(run_id) for body in request_bodies]) + # Check that all the run_ids are present in the request bodies + for run_id in all_run_ids: + assert any([body["id"] == str(run_id) for body in request_bodies]) - # Check that no duplicate run_ids are present in the request bodies - assert len(request_bodies) == len(set([body["id"] for body 
in request_bodies])) + # Check that no duplicate run_ids are present in the request bodies + assert len(request_bodies) == len(set([body["id"] for body in request_bodies])) -def test_select_eval_results(): +@mock.patch("langsmith.client.requests.Session") +def test_select_eval_results(mock_session_cls: mock.Mock): expected = EvaluationResult( key="foo", value="bar", @@ -1050,6 +1160,7 @@ def test_select_eval_results(): @pytest.mark.parametrize("client_cls", [Client, AsyncClient]) +@mock.patch("langsmith.client.requests.Session") def test_validate_api_key_if_hosted( monkeypatch: pytest.MonkeyPatch, client_cls: Union[Type[Client], Type[AsyncClient]] ) -> None: From a217db7cb2580f56337e6f385336f35a5711220f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 1 Oct 2024 15:49:03 -0700 Subject: [PATCH 4/8] Add FF --- python/langsmith/client.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 275969ea3..f64e62e43 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -5820,7 +5820,10 @@ def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None: size_limit: int = batch_ingest_config["size_limit"] scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"] scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"] - use_multipart: bool = batch_ingest_config.get("use_multipart_endpoint", False) + use_multipart: bool = os.getenv( + "LANGSMITH_FF_MULTIPART", + batch_ingest_config.get("use_multipart_endpoint", False), + ) sub_threads: List[threading.Thread] = [] # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached From 454622b8cf3f38eaddf4e03dcaf3af26974c57db Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 1 Oct 2024 15:49:11 -0700 Subject: [PATCH 5/8] Add integration test --- python/tests/integration_tests/test_client.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/tests/integration_tests/test_client.py b/python/tests/integration_tests/test_client.py index 87c2c6f94..be4f27147 100644 --- a/python/tests/integration_tests/test_client.py +++ b/python/tests/integration_tests/test_client.py @@ -615,7 +615,10 @@ def test_create_chat_example( langchain_client.delete_dataset(dataset_id=dataset.id) -def test_batch_ingest_runs(langchain_client: Client) -> None: +@pytest.mark.parametrize("use_multipart_endpoint", [True, False]) +def test_batch_ingest_runs( + langchain_client: Client, use_multipart_endpoint: bool +) -> None: _session = "__test_batch_ingest_runs" trace_id = uuid4() trace_id_2 = uuid4() @@ -669,7 +672,12 @@ def test_batch_ingest_runs(langchain_client: Client) -> None: "outputs": {"output1": 4, "output2": 5}, }, ] - langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update) + if use_multipart_endpoint: + langchain_client.multipart_ingest_runs( + create=runs_to_create, update=runs_to_update + ) + else: + langchain_client.batch_ingest_runs(create=runs_to_create, update=runs_to_update) runs = [] wait = 4 for _ in range(15): From d8d582faaecc6cd9e1021c525dc05e491d1eef0d Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 1 Oct 2024 16:01:38 -0700 Subject: [PATCH 6/8] Fix for urllib<2 --- python/langsmith/client.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index f64e62e43..0250fdc4a 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -39,6 +39,7 @@ import 
warnings import weakref from dataclasses import dataclass, field +from inspect import signature from queue import Empty, PriorityQueue, Queue from typing import ( TYPE_CHECKING, @@ -65,6 +66,7 @@ from requests import adapters as requests_adapters from requests_toolbelt.multipart import MultipartEncoder # type: ignore[import-untyped] from typing_extensions import TypeGuard +from urllib3.poolmanager import PoolKey from urllib3.util import Retry import langsmith @@ -95,6 +97,7 @@ class ZoneInfo: # type: ignore[no-redef] EMPTY_SEQ: tuple[Dict, ...] = () BOUNDARY = uuid.uuid4().hex MultipartParts = List[Tuple[str, Tuple[None, bytes, str]]] +URLLIB3_SUPPORTS_BLOCKSIZE = "key_blocksize" in signature(PoolKey).parameters def _parse_token_or_url( @@ -462,7 +465,9 @@ def __init__( super().__init__(pool_connections, pool_maxsize, max_retries, pool_block) def init_poolmanager(self, connections, maxsize, block=False, **pool_kwargs): - pool_kwargs["blocksize"] = self._blocksize + if URLLIB3_SUPPORTS_BLOCKSIZE: + # urllib3 before 2.0 doesn't support blocksize + pool_kwargs["blocksize"] = self._blocksize return super().init_poolmanager(connections, maxsize, block, **pool_kwargs) From 69dc1e27e3d4876835a3abe06763bf7e93c13064 Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 1 Oct 2024 16:04:21 -0700 Subject: [PATCH 7/8] Lint --- python/langsmith/client.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 0250fdc4a..8d5cde35e 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -66,7 +66,7 @@ from requests import adapters as requests_adapters from requests_toolbelt.multipart import MultipartEncoder # type: ignore[import-untyped] from typing_extensions import TypeGuard -from urllib3.poolmanager import PoolKey +from urllib3.poolmanager import PoolKey # type: ignore[attr-defined] from urllib3.util import Retry import langsmith From 08ec720e8f1de3d848aaed8656546646a3524c1f Mon Sep 17 00:00:00 2001 From: Nuno Campos Date: Tue, 1 Oct 2024 16:07:07 -0700 Subject: [PATCH 8/8] Lint --- python/langsmith/client.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/langsmith/client.py b/python/langsmith/client.py index 8d5cde35e..19a0d09dc 100644 --- a/python/langsmith/client.py +++ b/python/langsmith/client.py @@ -5825,10 +5825,10 @@ def _tracing_control_thread_func(client_ref: weakref.ref[Client]) -> None: size_limit: int = batch_ingest_config["size_limit"] scale_up_nthreads_limit: int = batch_ingest_config["scale_up_nthreads_limit"] scale_up_qsize_trigger: int = batch_ingest_config["scale_up_qsize_trigger"] - use_multipart: bool = os.getenv( - "LANGSMITH_FF_MULTIPART", - batch_ingest_config.get("use_multipart_endpoint", False), - ) + if multipart_override := os.getenv("LANGSMITH_FF_MULTIPART"): + use_multipart = multipart_override.lower() in ["1", "true"] + else: + use_multipart = batch_ingest_config.get("use_multipart_endpoint", False) sub_threads: List[threading.Thread] = [] # 1 for this func, 1 for getrefcount, 1 for _get_data_type_cached
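For orientation, the patches above revolve around one encoding convention: each run is flattened into several multipart fields, a main JSON part named "{event}.{run_id}" plus one part per large field ("inputs", "outputs", "serialized", "events") and per attachment, all wrapped in a streaming MultipartEncoder from requests_toolbelt. The sketch below is not part of the patch series; it is a minimal, self-contained illustration of that layout under stated assumptions: it uses plain json.dumps instead of the client's _dumps_json helper, placeholder run values, and omits the size-limit batching, attachment parts, and retry handling that the real client adds.

import json
import uuid

from requests_toolbelt.multipart import MultipartEncoder


def run_to_parts(event: str, run: dict) -> list:
    """Turn one run dict into multipart fields: a main JSON part named
    '{event}.{id}' plus one JSON part per large field that is present."""
    run = dict(run)  # copy so the caller's dict is not mutated
    extracted = [
        (key, run.pop(key, None))
        for key in ("inputs", "outputs", "serialized", "events")
    ]
    parts = [
        (
            f"{event}.{run['id']}",
            (None, json.dumps(run).encode("utf-8"), "application/json"),
        )
    ]
    for key, value in extracted:
        if value is not None:
            parts.append(
                (
                    f"{event}.{run['id']}.{key}",
                    (None, json.dumps(value).encode("utf-8"), "application/json"),
                )
            )
    return parts


run_id = str(uuid.uuid4())
example_run = {  # placeholder values, not taken from the patches
    "id": run_id,
    "trace_id": run_id,
    "dotted_order": f"20241001T000000000000Z{run_id}",
    "name": "example_chain",
    "inputs": {"question": "hi"},
    "outputs": {"answer": "hello"},
}

encoder = MultipartEncoder(run_to_parts("post", example_run), boundary=uuid.uuid4().hex)
print(encoder.content_type)  # multipart/form-data; boundary=<hex>

# A POST would then stream the encoder as the request body, e.g.:
# requests.post(f"{api_url}/runs/multipart", data=encoder,
#               headers={"x-api-key": api_key, "Content-Type": encoder.content_type})

Keying every part by the run id is what lets the server reassociate inputs, outputs, and attachments with their run, and it is also what the unit tests added in PATCH 3 rely on when they split part names on "." to recover each run id.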