From 081ebe425524adc4c9a97effaf2dcb4e975b679b Mon Sep 17 00:00:00 2001
From: Enias Cailliau
Date: Wed, 14 Jun 2023 15:52:08 +0200
Subject: [PATCH 1/7] Support functions

---
 requirements.txt                              |   4 +-
 .../chat_models/__init__.py                   |   3 +
 src/steamship_langchain/chat_models/openai.py | 113 +++++++++++++++---
 3 files changed, 100 insertions(+), 20 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index ece754e..b5f23cd 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-steamship~=2.17.0
-langchain==0.0.168
\ No newline at end of file
+steamship==2.17.8
+langchain==0.0.200
\ No newline at end of file
diff --git a/src/steamship_langchain/chat_models/__init__.py b/src/steamship_langchain/chat_models/__init__.py
index e69de29..fccc818 100644
--- a/src/steamship_langchain/chat_models/__init__.py
+++ b/src/steamship_langchain/chat_models/__init__.py
@@ -0,0 +1,3 @@
+from steamship_langchain.chat_models.openai import ChatOpenAI
+
+__all__ = ["ChatOpenAI"]
diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py
index 50b24b3..7225931 100644
--- a/src/steamship_langchain/chat_models/openai.py
+++ b/src/steamship_langchain/chat_models/openai.py
@@ -1,34 +1,43 @@
 """OpenAI chat wrapper."""
 from __future__ import annotations
 
+import json
 import logging
 from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple
 
 import tiktoken
 from langchain.chat_models.base import BaseChatModel
+from langchain.chat_models.openai import ChatOpenAI
 from langchain.schema import (
     AIMessage,
     BaseMessage,
     ChatGeneration,
     ChatMessage,
     ChatResult,
+    FunctionMessage,
     HumanMessage,
     LLMResult,
     SystemMessage,
 )
-from pydantic import Extra, Field, root_validator
+from pydantic import Extra, Field, ValidationError, root_validator
 from steamship import Block, File, MimeTypes, PluginInstance, Steamship, Tag
 from steamship.data.tags.tag_constants import RoleTag, TagKind
 
 logger = logging.getLogger(__file__)
 
 
-def _convert_dict_to_message(_dict: dict) -> BaseMessage:
+def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage:
     role = _dict["role"]
     if role == "user":
         return HumanMessage(content=_dict["content"])
     elif role == "assistant":
-        return AIMessage(content=_dict["content"])
+        content = _dict["content"]
+        if "function_call" in content:
+            try:
+                return AIMessage(content=content, additional_kwargs=json.loads(content))
+            except Exception as e:
+                pass
+        return AIMessage(content=content)
     elif role == "system":
         return SystemMessage(content=_dict["content"])
     else:
@@ -42,8 +51,16 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict = {"role": "user", "content": message.content}
     elif isinstance(message, AIMessage):
         message_dict = {"role": "assistant", "content": message.content}
+        if "function_call" in message.additional_kwargs:
+            message_dict["function_call"] = message.additional_kwargs["function_call"]
     elif isinstance(message, SystemMessage):
         message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, FunctionMessage):
+        message_dict = {
+            "role": "function",
+            "content": message.content,
+            "name": message.name,
+        }
     else:
         raise ValueError(f"Got unknown type {message}")
     if "name" in message.additional_kwargs:
@@ -51,7 +68,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
         message_dict["name"] = message.additional_kwargs["name"]
     return message_dict
 
 
-class ChatOpenAI(BaseChatModel):
+class ChatOpenAI(ChatOpenAI, BaseChatModel):
     """Wrapper around OpenAI Chat large language models.
     To use, you should have the ``openai`` python package installed, and the
@@ -68,7 +85,7 @@ class ChatOpenAI(BaseChatModel):
     """
 
     client: Any  #: :meta private:
-    model_name: str = "gpt-3.5-turbo"
+    model_name: str = "gpt-3.5-turbo-0613"
    """Model name to use."""
     temperature: float = 0.7
     """What sampling temperature to use."""
@@ -92,14 +109,59 @@ class Config:
 
         extra = Extra.allow
 
+    @root_validator(allow_reuse=True)
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Validate that api key and python package exists in environment."""
+        try:
+            import openai
+
+        except ImportError:
+            raise ValueError(
+                "Could not import openai python package. "
+                "Please install it with `pip install openai`."
+            )
+        try:
+            values["client"] = openai.ChatCompletion
+        except AttributeError:
+            raise ValueError(
+                "`openai` has no `ChatCompletion` attribute, this is likely "
+                "due to an old version of the openai package. Try upgrading it "
+                "with `pip install --upgrade openai`."
+            )
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
     def __init__(
-        self,
-        client: Steamship,
-        model_name: str = "gpt-3.5-turbo",
-        moderate_output: bool = True,
-        **kwargs,
+            self,
+            client: Steamship,
+            model_name: str = "gpt-3.5-turbo-0613",
+            moderate_output: bool = True,
+            **kwargs,
     ):
-        super().__init__(client=client, model_name=model_name, **kwargs)
+        try:
+
+            class OpenAI(object):
+                class ChatCompletion:
+                    pass
+
+            import sys
+
+            sys.modules["openai"] = OpenAI
+
+            dummy_api_key = False
+            if "openai_api_key" not in kwargs:
+                kwargs["openai_api_key"] = "DUMMY"
+                dummy_api_key = True
+            super().__init__(client=client, model_name=model_name, **kwargs)
+            if dummy_api_key:
+                self.openai_api_key = None
+        except ValidationError as e:
+            print(e)
+            pass
+        self.client = client
         plugin_config = {"model": self.model_name, "moderate_output": moderate_output}
         if self.openai_api_key:
             plugin_config["openai_api_key"] = self.openai_api_key
@@ -119,7 +181,7 @@ def __init__(
         self._llm_plugin = self.client.use_plugin(
             plugin_handle="gpt-4",
             config=plugin_config,
-            fetch_if_exists=True,
+            fetch_if_exists=False,
         )
 
     @classmethod
@@ -154,12 +216,17 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
         for msg in messages:
             role = msg.get("role", "user")
             content = msg.get("content", "")
+            name = msg.get("name", "")
             if len(content) > 0:
                 role_tag = RoleTag(role)
+                tags = [Tag(kind=TagKind.ROLE, name=role_tag)]
+                if name:
+                    tags.append(Tag(kind="name", name=name))
+
                 blocks.append(
                     Block(
                         text=content,
-                        tags=[Tag(kind=TagKind.ROLE, name=role_tag)],
+                        tags=tags,
                         mime_type=MimeTypes.TXT,
                     )
                 )
@@ -169,14 +236,24 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
         generate_task.wait()
 
         return [
-            _convert_dict_to_message({"content": block.text, "role": RoleTag.USER.value})
+            _convert_dict_to_message(
+                {
+                    "content": block.text,
+                    "role": [tag for tag in block.tags if tag.kind == TagKind.ROLE.value][0].name,
+                }
+            )
             for block in generate_task.output.blocks
         ]
 
     def _generate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self,
+            messages: List[BaseMessage],
+            stop: Optional[List[str]] = None,
+            run_manager: Optional[CallbackManagerForLLMRun] = None,
+            **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
         messages = self._complete(messages=message_dicts, **params)
         return ChatResult(
             generations=[ChatGeneration(message=message) for message in messages],
@@ -184,12 +261,12 @@ def _generate(
         )
 
     async def _agenerate(
-        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
         raise NotImplementedError("Support for async is not provided yet.")
 
     def _create_message_dicts(
-        self, messages: List[BaseMessage], stop: Optional[List[str]]
+            self, messages: List[BaseMessage], stop: Optional[List[str]]
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
         if stop is not None:
@@ -221,7 +298,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
         }
 
     async def agenerate(
-        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         raise NotImplementedError("Support for async is not provided yet.")
 

From 11aee5ae716777f4aaa0ea22684563be6fc2cb9f Mon Sep 17 00:00:00 2001
From: Enias Cailliau
Date: Wed, 14 Jun 2023 15:57:47 +0200
Subject: [PATCH 2/7] Temporarily remove RoleTag validation

---
 src/steamship_langchain/chat_models/openai.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py
index 7225931..5373234 100644
--- a/src/steamship_langchain/chat_models/openai.py
+++ b/src/steamship_langchain/chat_models/openai.py
@@ -218,7 +218,6 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
             content = msg.get("content", "")
             name = msg.get("name", "")
             if len(content) > 0:
-                role_tag = RoleTag(role)
                 tags = [Tag(kind=TagKind.ROLE, name=role_tag)]
                 if name:
                     tags.append(Tag(kind="name", name=name))

From e196f82e74d3fb43f76b13167567643d6bda4d7c Mon Sep 17 00:00:00 2001
From: Enias Cailliau
Date: Wed, 14 Jun 2023 15:58:07 +0200
Subject: [PATCH 3/7] Correct role assignment

---
 src/steamship_langchain/chat_models/openai.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py
index 5373234..ae95136 100644
--- a/src/steamship_langchain/chat_models/openai.py
+++ b/src/steamship_langchain/chat_models/openai.py
@@ -218,7 +218,7 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
             content = msg.get("content", "")
             name = msg.get("name", "")
             if len(content) > 0:
-                tags = [Tag(kind=TagKind.ROLE, name=role_tag)]
+                tags = [Tag(kind=TagKind.ROLE, name=role)]
                 if name:
                     tags.append(Tag(kind="name", name=name))

From d1a7d00185d39f22001df132b46aa81e118a2551 Mon Sep 17 00:00:00 2001
From: Enias Cailliau
Date: Wed, 14 Jun 2023 15:58:43 +0200
Subject: [PATCH 4/7] Remove duplicate environment validator

---
 src/steamship_langchain/chat_models/openai.py | 25 -------------------
 1 file changed, 25 deletions(-)

diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py
index ae95136..f213f34 100644
--- a/src/steamship_langchain/chat_models/openai.py
+++ b/src/steamship_langchain/chat_models/openai.py
@@ -109,31 +109,6 @@ class Config:
 
         extra = Extra.allow
 
-    @root_validator(allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        try:
-            import openai
-
-        except ImportError:
-            raise ValueError(
-                "Could not import openai python package. "
" - "Please install it with `pip install openai`." - ) - try: - values["client"] = openai.ChatCompletion - except AttributeError: - raise ValueError( - "`openai` has no `ChatCompletion` attribute, this is likely " - "due to an old version of the openai package. Try upgrading it " - "with `pip install --upgrade openai`." - ) - if values["n"] < 1: - raise ValueError("n must be at least 1.") - if values["n"] > 1 and values["streaming"]: - raise ValueError("n must be 1 when streaming.") - return values - def __init__( self, client: Steamship, From 3c398ecd278afc6194c256e3549d22b500fbfae2 Mon Sep 17 00:00:00 2001 From: Enias Cailliau Date: Wed, 14 Jun 2023 16:39:26 +0200 Subject: [PATCH 5/7] Lint --- src/steamship_langchain/chat_models/openai.py | 67 ++++++++++++------- 1 file changed, 42 insertions(+), 25 deletions(-) diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py index f213f34..82d6229 100644 --- a/src/steamship_langchain/chat_models/openai.py +++ b/src/steamship_langchain/chat_models/openai.py @@ -6,6 +6,7 @@ from typing import Any, Dict, Generator, List, Mapping, Optional, Tuple import tiktoken +from langchain.callbacks.manager import CallbackManagerForLLMRun from langchain.chat_models.base import BaseChatModel from langchain.chat_models.openai import ChatOpenAI from langchain.schema import ( @@ -21,7 +22,7 @@ ) from pydantic import Extra, Field, ValidationError, root_validator from steamship import Block, File, MimeTypes, PluginInstance, Steamship, Tag -from steamship.data.tags.tag_constants import RoleTag, TagKind +from steamship.data.tags.tag_constants import TagKind logger = logging.getLogger(__file__) @@ -34,8 +35,8 @@ def _convert_dict_to_message(_dict: Mapping[str, Any]) -> BaseMessage: content = _dict["content"] if "function_call" in content: try: - return AIMessage(content=content, additional_kwargs=json.loads(content)) - except Exception as e: + return AIMessage(content="", additional_kwargs=json.loads(content)) + except Exception: pass return AIMessage(content=content) elif role == "system": @@ -109,12 +110,37 @@ class Config: extra = Extra.allow + @root_validator(allow_reuse=True) + def validate_environment(cls, values: Dict) -> Dict: + """Validate that api key and python package exists in environment.""" + try: + import openai + + except ImportError: + raise ValueError( + "Could not import openai python package. " + "Please install it with `pip install openai`." + ) + try: + values["client"] = openai.ChatCompletion + except AttributeError: + raise ValueError( + "`openai` has no `ChatCompletion` attribute, this is likely " + "due to an old version of the openai package. Try upgrading it " + "with `pip install --upgrade openai`." 
+            )
+        if values["n"] < 1:
+            raise ValueError("n must be at least 1.")
+        if values["n"] > 1 and values["streaming"]:
+            raise ValueError("n must be 1 when streaming.")
+        return values
+
     def __init__(
-            self,
-            client: Steamship,
-            model_name: str = "gpt-3.5-turbo-0613",
-            moderate_output: bool = True,
-            **kwargs,
+        self,
+        client: Steamship,
+        model_name: str = "gpt-3.5-turbo-0613",
+        moderate_output: bool = True,
+        **kwargs,
     ):
         try:
@@ -135,7 +161,6 @@ class ChatCompletion:
             self.openai_api_key = None
         except ValidationError as e:
             print(e)
-            pass
         self.client = client
         plugin_config = {"model": self.model_name, "moderate_output": moderate_output}
         if self.openai_api_key:
             plugin_config["openai_api_key"] = self.openai_api_key
@@ -159,14 +184,6 @@ class ChatCompletion:
             fetch_if_exists=False,
         )
 
-    @classmethod
-    @root_validator()
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        if values["n"] < 1:
-            raise ValueError("n must be at least 1.")
-        return values
-
     @property
     def _default_params(self) -> Dict[str, Any]:
         """Get the default parameters for calling OpenAI API."""
@@ -220,11 +237,11 @@ def _complete(self, messages: [Dict[str, str]], **params) -> List[BaseMessage]:
         ]
 
     def _generate(
-            self,
-            messages: List[BaseMessage],
-            stop: Optional[List[str]] = None,
-            run_manager: Optional[CallbackManagerForLLMRun] = None,
-            **kwargs: Any,
+        self,
+        messages: List[BaseMessage],
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {**params, **kwargs}
@@ -235,12 +252,12 @@ def _generate(
         )
 
     async def _agenerate(
-            self, messages: List[BaseMessage], stop: Optional[List[str]] = None
+        self, messages: List[BaseMessage], stop: Optional[List[str]] = None
     ) -> ChatResult:
         raise NotImplementedError("Support for async is not provided yet.")
 
     def _create_message_dicts(
-            self, messages: List[BaseMessage], stop: Optional[List[str]]
+        self, messages: List[BaseMessage], stop: Optional[List[str]]
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
         if stop is not None:
@@ -272,7 +289,7 @@ def _identifying_params(self) -> Mapping[str, Any]:
         }
 
     async def agenerate(
-            self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
+        self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
     ) -> LLMResult:
         raise NotImplementedError("Support for async is not provided yet.")

From 25d3bd431caf78c735dd30ef7e593f3093ca4b49 Mon Sep 17 00:00:00 2001
From: Enias Cailliau
Date: Wed, 14 Jun 2023 17:11:16 +0200
Subject: [PATCH 6/7] More flexible dependencies

---
 requirements.txt                              |  2 +-
 src/steamship_langchain/chat_models/openai.py | 27 +------------------
 2 files changed, 2 insertions(+), 27 deletions(-)

diff --git a/requirements.txt b/requirements.txt
index b5f23cd..c6fd132 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,2 +1,2 @@
-steamship==2.17.8
+steamship~=2.17.4
 langchain==0.0.200
\ No newline at end of file
diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py
index 82d6229..f5d557a 100644
--- a/src/steamship_langchain/chat_models/openai.py
+++ b/src/steamship_langchain/chat_models/openai.py
@@ -20,7 +20,7 @@
     LLMResult,
     SystemMessage,
 )
-from pydantic import Extra, Field, ValidationError, root_validator
+from pydantic import Extra, Field, ValidationError
 from steamship import Block, File, MimeTypes, PluginInstance, Steamship, Tag
 from steamship.data.tags.tag_constants import TagKind
@@ -110,31 +110,6 @@ class Config:
 
         extra = Extra.allow
 
-    @root_validator(allow_reuse=True)
-    def validate_environment(cls, values: Dict) -> Dict:
-        """Validate that api key and python package exists in environment."""
-        try:
-            import openai
-
-        except ImportError:
-            raise ValueError(
-                "Could not import openai python package. "
-                "Please install it with `pip install openai`."
-            )
-        try:
-            values["client"] = openai.ChatCompletion
-        except AttributeError:
-            raise ValueError(
-                "`openai` has no `ChatCompletion` attribute, this is likely "
-                "due to an old version of the openai package. Try upgrading it "
-                "with `pip install --upgrade openai`."
-            )
-        if values["n"] < 1:
-            raise ValueError("n must be at least 1.")
-        if values["n"] > 1 and values["streaming"]:
-            raise ValueError("n must be 1 when streaming.")
-        return values
-
     def __init__(
         self,
         client: Steamship,

From 97ff28a36736e08c7ff482e5b15d2f0c30c37e7c Mon Sep 17 00:00:00 2001
From: Enias Cailliau
Date: Wed, 14 Jun 2023 17:16:57 +0200
Subject: [PATCH 7/7] Reuse enabled

---
 src/steamship_langchain/chat_models/openai.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/steamship_langchain/chat_models/openai.py b/src/steamship_langchain/chat_models/openai.py
index f5d557a..93f1808 100644
--- a/src/steamship_langchain/chat_models/openai.py
+++ b/src/steamship_langchain/chat_models/openai.py
@@ -156,7 +156,7 @@ class ChatCompletion:
         self._llm_plugin = self.client.use_plugin(
             plugin_handle="gpt-4",
             config=plugin_config,
-            fetch_if_exists=False,
+            fetch_if_exists=True,
         )
 
     @property
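
Usage note (not part of the patches above): the series threads OpenAI function calling through the Steamship-backed ChatOpenAI. Because _generate now forwards extra keyword arguments into the plugin parameters (the `params = {**params, **kwargs}` line in PATCH 1), a `functions` payload can be passed the same way as with upstream langchain. Below is a minimal sketch, assuming a configured Steamship workspace and langchain 0.0.200; the `get_current_weather` spec is purely illustrative and not part of the patched code.

    from langchain.schema import HumanMessage
    from steamship import Steamship

    from steamship_langchain.chat_models import ChatOpenAI

    # Assumes a Steamship API key is configured in the environment.
    client = Steamship()
    chat = ChatOpenAI(client=client, model_name="gpt-3.5-turbo-0613")

    # An illustrative function spec in OpenAI's function-calling schema.
    functions = [
        {
            "name": "get_current_weather",
            "description": "Get the current weather in a given city",
            "parameters": {
                "type": "object",
                "properties": {
                    "location": {"type": "string", "description": "City name"}
                },
                "required": ["location"],
            },
        }
    ]

    # predict_messages forwards extra kwargs to _generate, which merges them
    # into the plugin params per the patched code path.
    message = chat.predict_messages(
        [HumanMessage(content="What is the weather in Ghent?")],
        functions=functions,
    )

    # Per the patched _convert_dict_to_message, a function call is surfaced
    # via additional_kwargs rather than as plain text content.
    if "function_call" in message.additional_kwargs:
        print(message.additional_kwargs["function_call"])
    else:
        print(message.content)

Note that the `-0613` model snapshots are used because, at the time of these patches, they were the OpenAI chat models that accepted a `functions` argument.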