diff --git a/README.md b/README.md
index b20ab61e4..53ff909ca 100644
--- a/README.md
+++ b/README.md
@@ -370,6 +370,149 @@ async def main() -> None:
     print(message.content)
+await main()
+```
+
+
+### Mistral 〽️
+
+Track agents built with the Mistral Python SDK (>=1.0.1).
+
+- [AgentOps integration example](./examples/mistral_examples/mistral_example.ipynb)
+- [Official Mistral documentation](https://docs.mistral.ai)
+
+
+  Installation
+
+```bash
+pip install mistralai
+```
+
+Sync
+
+```python python
+from mistralai import Mistral
+import agentops
+import os
+
+# Beginning of program's code (i.e. main.py, __init__.py)
+agentops.init()
+
+client = Mistral(
+    # This is the default and can be omitted
+    api_key=os.environ.get("MISTRAL_API_KEY"),
+)
+
+message = client.chat.complete(
+    messages=[
+        {
+            "role": "user",
+            "content": "Tell me a cool fact about AgentOps",
+        }
+    ],
+    model="open-mistral-nemo",
+)
+print(message.choices[0].message.content)
+
+agentops.end_session('Success')
+```
+
+Streaming
+
+```python python
+from mistralai import Mistral
+import agentops
+import os
+
+# Beginning of program's code (i.e. main.py, __init__.py)
+agentops.init()
+
+client = Mistral(
+    # This is the default and can be omitted
+    api_key=os.environ.get("MISTRAL_API_KEY"),
+)
+
+message = client.chat.stream(
+    messages=[
+        {
+            "role": "user",
+            "content": "Tell me something cool about streaming agents",
+        }
+    ],
+    model="open-mistral-nemo",
+)
+
+response = ""
+for event in message:
+    if event.data.choices[0].finish_reason == "stop":
+        print("\n")
+        print(response)
+        print("\n")
+    else:
+        response += event.data.choices[0].delta.content
+
+agentops.end_session('Success')
+```
+
+Async
+
+```python python
+import asyncio
+import os
+from mistralai import Mistral
+
+client = Mistral(
+    # This is the default and can be omitted
+    api_key=os.environ.get("MISTRAL_API_KEY"),
+)
+
+
+async def main() -> None:
+    message = await client.chat.complete_async(
+        messages=[
+            {
+                "role": "user",
+                "content": "Tell me something interesting about async agents",
+            }
+        ],
+        model="open-mistral-nemo",
+    )
+    print(message.choices[0].message.content)
+
+
+await main()
+```
+
+Async Streaming
+
+```python python
+import asyncio
+import os
+from mistralai import Mistral
+
+client = Mistral(
+    # This is the default and can be omitted
+    api_key=os.environ.get("MISTRAL_API_KEY"),
+)
+
+
+async def main() -> None:
+    message = await client.chat.stream_async(
+        messages=[
+            {
+                "role": "user",
+                "content": "Tell me something interesting about async streaming agents",
+            }
+        ],
+        model="open-mistral-nemo",
+    )
+
+    response = ""
+    async for event in message:
+        if event.data.choices[0].finish_reason == "stop":
+            print("\n")
+            print(response)
+            print("\n")
+        else:
+            response += event.data.choices[0].delta.content
+
+
+await main()
+```
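The README snippets above always close the session with `agentops.end_session('Success')`. For completeness, here is a minimal sketch of the same sync call with basic error handling; the `'Fail'` end state is an assumption here and is not shown anywhere in this diff, so check the AgentOps docs for the end states your version accepts.

```python
import os

import agentops
from mistralai import Mistral

agentops.init()
client = Mistral(api_key=os.environ.get("MISTRAL_API_KEY"))

try:
    message = client.chat.complete(
        messages=[{"role": "user", "content": "Tell me a cool fact about AgentOps"}],
        model="open-mistral-nemo",
    )
    print(message.choices[0].message.content)
    agentops.end_session("Success")
except Exception:
    # "Fail" is an assumed end state; adjust to what your AgentOps version supports.
    agentops.end_session("Fail")
```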
diff --git a/agentops/llms/__init__.py b/agentops/llms/__init__.py index 380970af8..1bd8c3b7b 100644 --- a/agentops/llms/__init__.py +++ b/agentops/llms/__init__.py @@ -13,6 +13,7 @@ from .ollama import OllamaProvider from .openai import OpenAiProvider from .anthropic import AnthropicProvider +from .mistral import MistralProvider from .ai21 import AI21Provider original_func = {} @@ -40,6 +41,9 @@ class LlmTracker: "anthropic": { "0.32.0": ("completions.create",), }, + "mistralai": { + "1.0.1": ("chat.complete", "chat.stream"), + }, "ai21": { "2.0.0": ( "chat.completions.create", @@ -142,6 +146,17 @@ def override_api(self): f"Only Anthropic>=0.32.0 supported. v{module_version} found." ) + if api == "mistralai": + module_version = version(api) + + if Version(module_version) >= parse("1.0.1"): + provider = MistralProvider(self.client) + provider.override() + else: + logger.warning( + f"Only MistralAI>=1.0.1 supported. v{module_version} found." + ) + if api == "ai21": module_version = version(api) @@ -165,4 +180,5 @@ def stop_instrumenting(self): LiteLLMProvider(self.client).undo_override() OllamaProvider(self.client).undo_override() AnthropicProvider(self.client).undo_override() + MistralProvider(self.client).undo_override() AI21Provider(self.client).undo_override() diff --git a/agentops/llms/mistral.py b/agentops/llms/mistral.py new file mode 100644 index 000000000..8be219469 --- /dev/null +++ b/agentops/llms/mistral.py @@ -0,0 +1,220 @@ +import inspect +import pprint +import sys +from typing import Optional + +from ..event import LLMEvent, ErrorEvent +from ..session import Session +from ..log_config import logger +from agentops.helpers import get_ISO_time, check_call_stack_for_agent_id +from .instrumented_provider import InstrumentedProvider + + +class MistralProvider(InstrumentedProvider): + + original_complete = None + original_complete_async = None + original_stream = None + original_stream_async = None + + def __init__(self, client): + super().__init__(client) + self._provider_name = "Mistral" + + def handle_response( + self, response, kwargs, init_timestamp, session: Optional[Session] = None + ) -> dict: + """Handle responses for Mistral""" + from mistralai import Chat + from mistralai.types import UNSET, UNSET_SENTINEL + + llm_event = LLMEvent(init_timestamp=init_timestamp, params=kwargs) + if session is not None: + llm_event.session_id = session.session_id + + def handle_stream_chunk(chunk: dict): + # NOTE: prompt/completion usage not returned in response when streaming + # We take the first ChatCompletionChunk and accumulate the deltas from all subsequent chunks to build one full chat completion + if llm_event.returns is None: + llm_event.returns = chunk.data + + try: + accumulated_delta = llm_event.returns.choices[0].delta + llm_event.agent_id = check_call_stack_for_agent_id() + llm_event.model = "mistral/" + chunk.data.model + llm_event.prompt = kwargs["messages"] + + # NOTE: We assume for completion only choices[0] is relevant + choice = chunk.data.choices[0] + + if choice.delta.content: + accumulated_delta.content += choice.delta.content + + if choice.delta.role: + accumulated_delta.role = choice.delta.role + + # Check if tool_calls is Unset and set to None if it is + if choice.delta.tool_calls in (UNSET, UNSET_SENTINEL): + accumulated_delta.tool_calls = None + elif choice.delta.tool_calls: + accumulated_delta.tool_calls = choice.delta.tool_calls + + if choice.finish_reason: + # Streaming is done. 
Record LLMEvent + llm_event.returns.choices[0].finish_reason = choice.finish_reason + llm_event.completion = { + "role": accumulated_delta.role, + "content": accumulated_delta.content, + "tool_calls": accumulated_delta.tool_calls, + } + llm_event.prompt_tokens = chunk.data.usage.prompt_tokens + llm_event.completion_tokens = chunk.data.usage.completion_tokens + llm_event.end_timestamp = get_ISO_time() + self._safe_record(session, llm_event) + + except Exception as e: + self._safe_record( + session, ErrorEvent(trigger_event=llm_event, exception=e) + ) + + kwargs_str = pprint.pformat(kwargs) + chunk = pprint.pformat(chunk) + logger.warning( + f"Unable to parse a chunk for LLM call. Skipping upload to AgentOps\n" + f"chunk:\n {chunk}\n" + f"kwargs:\n {kwargs_str}\n" + ) + + # if the response is a generator, decorate the generator + if inspect.isgenerator(response): + + def generator(): + for chunk in response: + handle_stream_chunk(chunk) + yield chunk + + return generator() + + elif inspect.isasyncgen(response): + + async def async_generator(): + async for chunk in response: + handle_stream_chunk(chunk) + yield chunk + + return async_generator() + + try: + llm_event.returns = response + llm_event.agent_id = check_call_stack_for_agent_id() + llm_event.model = "mistral/" + response.model + llm_event.prompt = kwargs["messages"] + llm_event.prompt_tokens = response.usage.prompt_tokens + llm_event.completion = response.choices[0].message.model_dump() + llm_event.completion_tokens = response.usage.completion_tokens + llm_event.end_timestamp = get_ISO_time() + + self._safe_record(session, llm_event) + except Exception as e: + self._safe_record(session, ErrorEvent(trigger_event=llm_event, exception=e)) + kwargs_str = pprint.pformat(kwargs) + response = pprint.pformat(response) + logger.warning( + f"Unable to parse response for LLM call. 
Skipping upload to AgentOps\n"
+                f"response:\n {response}\n"
+                f"kwargs:\n {kwargs_str}\n"
+            )
+
+        return response
+
+    def _override_complete(self):
+        from mistralai import Chat
+
+        self.original_complete = Chat.complete
+
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            session = kwargs.get("session", None)
+            if "session" in kwargs.keys():
+                del kwargs["session"]
+            result = self.original_complete(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        # Override the original method with the patched one
+        Chat.complete = patched_function
+
+    def _override_complete_async(self):
+        from mistralai import Chat
+
+        self.original_complete_async = Chat.complete_async
+
+        async def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            session = kwargs.get("session", None)
+            if "session" in kwargs.keys():
+                del kwargs["session"]
+            result = await self.original_complete_async(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        # Override the original method with the patched one
+        Chat.complete_async = patched_function
+
+    def _override_stream(self):
+        from mistralai import Chat
+
+        self.original_stream = Chat.stream
+
+        def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            session = kwargs.get("session", None)
+            if "session" in kwargs.keys():
+                del kwargs["session"]
+            result = self.original_stream(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        # Override the original method with the patched one
+        Chat.stream = patched_function
+
+    def _override_stream_async(self):
+        from mistralai import Chat
+
+        self.original_stream_async = Chat.stream_async
+
+        async def patched_function(*args, **kwargs):
+            # Call the original function with its original arguments
+            init_timestamp = get_ISO_time()
+            session = kwargs.get("session", None)
+            if "session" in kwargs.keys():
+                del kwargs["session"]
+            result = await self.original_stream_async(*args, **kwargs)
+            return self.handle_response(result, kwargs, init_timestamp, session=session)
+
+        # Override the original method with the patched one
+        Chat.stream_async = patched_function
+
+    def override(self):
+        self._override_complete()
+        self._override_complete_async()
+        self._override_stream()
+        self._override_stream_async()
+
+    def undo_override(self):
+        if (
+            self.original_complete is not None
+            and self.original_complete_async is not None
+            and self.original_stream is not None
+            and self.original_stream_async is not None
+        ):
+            from mistralai import Chat
+
+            Chat.complete = self.original_complete
+            Chat.complete_async = self.original_complete_async
+            Chat.stream = self.original_stream
+            Chat.stream_async = self.original_stream_async
diff --git a/examples/mistral_examples/mistral_example.ipynb b/examples/mistral_examples/mistral_example.ipynb
new file mode 100644
index 000000000..228f6827c
--- /dev/null
+++ b/examples/mistral_examples/mistral_example.ipynb
@@ -0,0 +1,226 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Mistral Examples\n",
+    "Uses the mistralai library to interact with Mistral"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First let's 
install the required packages" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -U mistralai\n", + "%pip install -U agentops" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Then import them" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from mistralai import Mistral\n", + "from dotenv import load_dotenv\n", + "import os\n", + "import agentops" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll grab our API keys. You can use dotenv like below or however else you like to load environment variables" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "load_dotenv()\n", + "MISTRAL_API_KEY = os.getenv(\"MISTRAL_API_KEY\") or \"\"\n", + "AGENTOPS_API_KEY = os.getenv(\"AGENTOPS_API_KEY\") or \"\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agentops.init(AGENTOPS_API_KEY, default_tags=[\"mistral-example\"])\n", + "client = Mistral(MISTRAL_API_KEY)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sync Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = client.chat.complete(\n", + " model=\"open-mistral-nemo\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is 2+2?\",\n", + " },\n", + " ],\n", + ")\n", + "print(response.choices[0].message.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def main() -> None:\n", + " response = await client.chat.complete_async(\n", + " model=\"open-mistral-nemo\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"Explain step by step what is 2+2?\",\n", + " }\n", + " ],\n", + " )\n", + " print(response.choices[0].message.content)\n", + "\n", + "\n", + "await main()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Sync Stream Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "response = client.chat.stream(\n", + " model=\"open-mistral-nemo\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is the Leibniz theorem?\",\n", + " }\n", + " ],\n", + ")\n", + "\n", + "result = \"\"\n", + "for event in response:\n", + " if event.data.choices[0].finish_reason == \"stop\":\n", + " print(result)\n", + " else:\n", + " result += event.data.choices[0].delta.content" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async Stream Example" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "async def main() -> None:\n", + " response = await client.chat.stream_async(\n", + " model=\"open-mistral-nemo\",\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": \"What is the meaning of life?\",\n", + " }\n", + " ],\n", + " )\n", + "\n", + " result = \"\"\n", + " async for event in response:\n", + " if event.data.choices[0].finish_reason == \"stop\":\n", + " print(result)\n", + " else:\n", + " result += event.data.choices[0].delta.content\n", + "\n", + 
"\n", + "await main()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "agentops.end_session(\"Success\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "ops", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tests/core_manual_tests/providers/anthropic_canary.py b/tests/core_manual_tests/providers/anthropic_canary.py index bc505606c..52acf0ab3 100644 --- a/tests/core_manual_tests/providers/anthropic_canary.py +++ b/tests/core_manual_tests/providers/anthropic_canary.py @@ -27,7 +27,7 @@ messages=[ { "role": "user", - "content": "asy hi 2", + "content": "say hi 2", } ], stream=True, diff --git a/tests/core_manual_tests/providers/mistral_canary.py b/tests/core_manual_tests/providers/mistral_canary.py new file mode 100644 index 000000000..9fdfd134e --- /dev/null +++ b/tests/core_manual_tests/providers/mistral_canary.py @@ -0,0 +1,86 @@ +import asyncio + +import agentops +import os +from dotenv import load_dotenv +from mistralai import Mistral + +load_dotenv() +MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY") + +agentops.init(default_tags=["mistral-provider-test"]) +client = Mistral(MISTRAL_API_KEY) + +response = client.chat.complete( + model="open-mistral-nemo", + messages=[ + { + "role": "user", + "content": "Say Hello", + }, + ], +) + + +stream_response = client.chat.stream( + model="open-mistral-nemo", + messages=[ + { + "role": "user", + "content": "Say Hello again", + } + ], +) + +response = "" +for event in stream_response: + if event.data.choices[0].finish_reason == "stop": + print(response) + else: + response += event.data.choices[0].delta.content + + +async def async_test(): + async_response = await client.chat.complete_async( + model="open-mistral-nemo", + messages=[ + { + "role": "user", + "content": "Say Hello in the Hindi language", + } + ], + ) + print(async_response.choices[0].message.content) + + +async def async_stream_test(): + async_stream_response = await client.chat.stream_async( + model="open-mistral-nemo", + messages=[ + { + "role": "user", + "content": "Say Hello in the Japanese language", + } + ], + ) + + response = "" + async for event in async_stream_response: + if event.data.choices[0].finish_reason == "stop": + print(response) + else: + response += event.data.choices[0].delta.content + + +async def main(): + await async_test() + await async_stream_test() + + +asyncio.run(main()) + +agentops.end_session(end_state="Success") + +### +# Used to verify that one session is created with one LLM event +###