From cea5a28db18f4a58cf903d025f0c229cc0b36b08 Mon Sep 17 00:00:00 2001 From: Brandon Hancock Date: Mon, 6 Jan 2025 13:12:30 -0500 Subject: [PATCH] adding tests. cleaing up pr. --- src/crewai/cli/cli.py | 1 - src/crewai/cli/crew_chat.py | 116 +++++++++++++++----------- src/crewai/cli/fetch_chat_llm.py | 81 ------------------ src/crewai/cli/templates/crew/main.py | 103 +---------------------- src/crewai/llm.py | 30 ------- src/crewai/types/crew_chat.py | 10 +-- tests/utilities/test_llm_utils.py | 96 +++++++++++++++++++++ 7 files changed, 166 insertions(+), 271 deletions(-) delete mode 100644 src/crewai/cli/fetch_chat_llm.py create mode 100644 tests/utilities/test_llm_utils.py diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 8e21265609..334759a6d4 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -8,7 +8,6 @@ from crewai.cli.create_crew import create_crew from crewai.cli.create_flow import create_flow from crewai.cli.crew_chat import run_chat -from crewai.cli.fetch_chat_llm import fetch_chat_llm from crewai.memory.storage.kickoff_task_outputs_storage import ( KickoffTaskOutputsSQLiteStorage, ) diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index 9721bd6c2a..5348b648fa 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -7,11 +7,7 @@ import click import tomli -from crewai.agents.agent_builder.base_agent import BaseAgent -from crewai.cli.fetch_chat_llm import fetch_chat_llm -from crewai.cli.fetch_crew_inputs import fetch_crew_inputs from crewai.crew import Crew -from crewai.task import Task from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.utilities.llm_utils import create_llm @@ -23,25 +19,44 @@ def run_chat(): Exits if crew_name or crew_description are missing. """ crew, crew_name = load_crew_and_name() - click.secho("\nFetching the Chat LLM...", fg="cyan") - try: - chat_llm = create_llm(crew.chat_llm) - except Exception as e: - click.secho(f"Failed to retrieve Chat LLM: {e}", fg="red") - return + chat_llm = initialize_chat_llm(crew) if not chat_llm: - click.secho("No valid Chat LLM returned. Exiting.", fg="red") return - # Generate crew chat inputs automatically crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm) - print("crew_inputs:", crew_chat_inputs) - - # Generate a tool schema from the crew inputs crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs) - print("crew_tool_schema:", crew_tool_schema) + system_message = build_system_message(crew_chat_inputs) + + messages = [ + {"role": "system", "content": system_message}, + ] + available_functions = { + crew_chat_inputs.crew_name: create_tool_function(crew, messages), + } + + click.secho( + "\nEntering an interactive chat loop with function-calling.\n" + "Type 'exit' or Ctrl+C to quit.\n", + fg="cyan", + ) + + chat_loop(chat_llm, messages, crew_tool_schema, available_functions) + + +def initialize_chat_llm(crew: Crew) -> Any: + """Initializes the chat LLM and handles exceptions.""" + try: + return create_llm(crew.chat_llm) + except Exception as e: + click.secho( + f"Unable to find a Chat LLM. Please make sure you set chat_llm on the crew: {e}", + fg="red", + ) + return None + - # Build initial system message +def build_system_message(crew_chat_inputs: ChatInputs) -> str: + """Builds the initial system message for the chat.""" required_fields_str = ( ", ".join( f"{field.name} (desc: {field.description or 'n/a'})" @@ -50,7 +65,7 @@ def run_chat(): or "(No required fields detected)" ) - system_message = ( + return ( "You are a helpful AI assistant for the CrewAI platform. " "Your primary purpose is to assist users with the crew's specific tasks. " "You can answer general questions, but should guide users back to the crew's purpose afterward. " @@ -66,26 +81,18 @@ def run_chat(): f"\nCrew Description: {crew_chat_inputs.crew_description}" ) - messages = [ - {"role": "system", "content": system_message}, - ] - # Create a wrapper function that captures 'crew' and 'messages' from the enclosing scope +def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: + """Creates a wrapper function for running the crew tool with messages.""" + def run_crew_tool_with_messages(**kwargs): return run_crew_tool(crew, messages, **kwargs) - # Prepare available_functions with the wrapper function - available_functions = { - crew_chat_inputs.crew_name: run_crew_tool_with_messages, - } + return run_crew_tool_with_messages - click.secho( - "\nEntering an interactive chat loop with function-calling.\n" - "Type 'exit' or Ctrl+C to quit.\n", - fg="cyan", - ) - # Main chat loop +def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): + """Main chat loop for interacting with the user.""" while True: try: user_input = click.prompt("You", type=str) @@ -93,20 +100,14 @@ def run_crew_tool_with_messages(**kwargs): click.echo("Exiting chat. Goodbye!") break - # Append user message messages.append({"role": "user", "content": user_input}) - - # Invoke the LLM, passing tools and available_functions final_response = chat_llm.call( messages=messages, tools=[crew_tool_schema], available_functions=available_functions, ) - # Append assistant's reply messages.append({"role": "assistant", "content": final_response}) - - # Display assistant's reply click.secho(f"\nAssistant: {final_response}\n", fg="green") except KeyboardInterrupt: @@ -165,7 +166,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): """ try: # Serialize 'messages' to JSON string before adding to kwargs - kwargs['crew_chat_messages'] = json.dumps(messages) + kwargs["crew_chat_messages"] = json.dumps(messages) # Run the crew with the provided inputs crew_output = crew.kickoff(inputs=kwargs) @@ -184,7 +185,7 @@ def run_crew_tool(crew: Crew, messages: List[Dict[str, str]], **kwargs): def load_crew_and_name() -> Tuple[Crew, str]: """ Loads the crew by importing the crew class from the user's project. - + Returns: Tuple[Crew, str]: A tuple containing the Crew instance and the name of the crew. """ @@ -258,9 +259,7 @@ def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInput crew_description = generate_crew_description_with_ai(crew, chat_llm) return ChatInputs( - crew_name=crew_name, - crew_description=crew_description, - inputs=input_fields + crew_name=crew_name, crew_description=crew_description, inputs=input_fields ) @@ -307,18 +306,31 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> placeholder_pattern = re.compile(r"\{(.+?)\}") for task in crew.tasks: - if f"{{{input_name}}}" in task.description or f"{{{input_name}}}" in task.expected_output: + if ( + f"{{{input_name}}}" in task.description + or f"{{{input_name}}}" in task.expected_output + ): # Replace placeholders with input names - task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description) - expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output) + task_description = placeholder_pattern.sub( + lambda m: m.group(1), task.description + ) + expected_output = placeholder_pattern.sub( + lambda m: m.group(1), task.expected_output + ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") for agent in crew.agents: - if f"{{{input_name}}}" in agent.role or f"{{{input_name}}}" in agent.goal or f"{{{input_name}}}" in agent.backstory: + if ( + f"{{{input_name}}}" in agent.role + or f"{{{input_name}}}" in agent.goal + or f"{{{input_name}}}" in agent.backstory + ): # Replace placeholders with input names agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role) agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal) - agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory) + agent_backstory = placeholder_pattern.sub( + lambda m: m.group(1), agent.backstory + ) context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Goal: {agent_goal}") context_texts.append(f"Agent Backstory: {agent_backstory}") @@ -357,8 +369,12 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: for task in crew.tasks: # Replace placeholders with input names - task_description = placeholder_pattern.sub(lambda m: m.group(1), task.description) - expected_output = placeholder_pattern.sub(lambda m: m.group(1), task.expected_output) + task_description = placeholder_pattern.sub( + lambda m: m.group(1), task.description + ) + expected_output = placeholder_pattern.sub( + lambda m: m.group(1), task.expected_output + ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") for agent in crew.agents: diff --git a/src/crewai/cli/fetch_chat_llm.py b/src/crewai/cli/fetch_chat_llm.py deleted file mode 100644 index ff28bb9393..0000000000 --- a/src/crewai/cli/fetch_chat_llm.py +++ /dev/null @@ -1,81 +0,0 @@ -import json -import subprocess - -import click -from packaging import version - -from crewai.cli.utils import read_toml -from crewai.cli.version import get_crewai_version -from crewai.llm import LLM - - -def fetch_chat_llm() -> LLM: - """ - Fetch the chat LLM by running "uv run fetch_chat_llm" (or your chosen script name), - parsing its JSON stdout, and returning an LLM instance. - - This expects the script "fetch_chat_llm" to print out JSON that represents the - LLM parameters (e.g., by calling something like: print(json.dumps(llm.to_dict()))). - - Any error, whether from the subprocess or JSON parsing, will raise a RuntimeError. - """ - - # You may change this command to match whatever's in your pyproject.toml [project.scripts]. - command = ["uv", "run", "fetch_chat_llm"] - - crewai_version = get_crewai_version() - min_required_version = "0.87.0" # Adjust as needed - - pyproject_data = read_toml() - - # If old poetry-based setup is detected and version is below min_required_version - if pyproject_data.get("tool", {}).get("poetry") and ( - version.parse(crewai_version) < version.parse(min_required_version) - ): - click.secho( - f"You are running an older version of crewAI ({crewai_version}) that uses poetry pyproject.toml.\n" - f"Please run `crewai update` to transition your pyproject.toml to use uv.", - fg="red", - ) - - # Initialize a reference to your LLM - llm_instance = None - - try: - result = subprocess.run(command, capture_output=True, text=True, check=True) - stdout_lines = result.stdout.strip().splitlines() - - # Find the line that contains the JSON data - json_line = next( - ( - line - for line in stdout_lines - if line.startswith("{") and line.endswith("}") - ), - None, - ) - - if not json_line: - raise RuntimeError( - "No valid JSON output received from `fetch_chat_llm` command." - ) - - try: - llm_data = json.loads(json_line) - llm_instance = LLM.from_dict(llm_data) - except json.JSONDecodeError as e: - raise RuntimeError( - f"Unable to parse JSON from `fetch_chat_llm` output: {e}\nOutput: {repr(json_line)}" - ) from e - - except subprocess.CalledProcessError as e: - raise RuntimeError(f"An error occurred while fetching chat LLM: {e}") from e - except Exception as e: - raise RuntimeError( - f"An unexpected error occurred while fetching chat LLM: {e}" - ) from e - - if not llm_instance: - raise RuntimeError("Failed to create a valid LLM from `fetch_chat_llm` output.") - - return llm_instance diff --git a/src/crewai/cli/templates/crew/main.py b/src/crewai/cli/templates/crew/main.py index 64b86a8eca..9344372b01 100644 --- a/src/crewai/cli/templates/crew/main.py +++ b/src/crewai/cli/templates/crew/main.py @@ -19,22 +19,11 @@ def run(): Usage example: uv run run_crew -- --topic="New Topic" --some_other_field="Value" """ - # Default inputs inputs = { 'topic': 'AI LLMs' # Add any other default fields here } - - # 1) Gather overrides from sys.argv - # sys.argv might look like: ['run_crew', '--topic=NewTopic'] - # But be aware that if you're calling "uv run run_crew", sys.argv might have - # additional items. So we typically skip the first 1 or 2 items to get only overrides. - overrides = parse_cli_overrides(sys.argv[1:]) - - # 2) Merge the overrides into defaults - inputs.update(overrides) - - # 3) Kick off the crew with final inputs + try: {{crew_name}}().crew().kickoff(inputs=inputs) except Exception as e: @@ -76,93 +65,3 @@ def test(): except Exception as e: raise Exception(f"An error occurred while testing the crew: {e}") - -def fetch_inputs(): - """ - Command that gathers required placeholders/inputs from the Crew, then - prints them as JSON to stdout so external scripts can parse them easily. - """ - try: - crew = {{crew_name}}().crew() - crew_inputs = crew.fetch_inputs() - json_string = json.dumps(list(crew_inputs)) - print(json_string) - except Exception as e: - raise Exception(f"An error occurred while fetching inputs: {e}") - -def fetch_chat_llm(): - """ - Command that fetches the 'chat_llm' property from the Crew, - instantiates it via create_llm(), - and prints the resulting LLM as JSON (using LLM.to_dict()) to stdout. - """ - try: - crew = {{crew_name}}().crew() - raw_chat_llm = getattr(crew, "chat_llm", None) - - if not raw_chat_llm: - # If the crew doesn't have chat_llm, fallback to create_llm(None) - final_llm = create_llm(None) - else: - # raw_chat_llm might be a dict, or an LLM, or something else - final_llm = create_llm(raw_chat_llm) - - if final_llm: - # Print the final LLM as JSON, so fetch_chat_llm.py can parse it - from crewai.llm import LLM # Import here to avoid circular references - - # Make sure it's an instance of the LLM class: - if isinstance(final_llm, LLM): - print(json.dumps(final_llm.to_dict())) - else: - # If somehow it's not an LLM, try to interpret as a dict - # or revert to an empty fallback - if isinstance(final_llm, dict): - print(json.dumps(final_llm)) - else: - print(json.dumps({})) - else: - print(json.dumps({})) - except Exception as e: - raise Exception(f"An error occurred while fetching chat LLM: {e}") - -# TODO: Talk to Joao about making using LLM calls to analyze the crew -# and generate all of this information automatically -def fetch_chat_inputs(): - """ - Command that fetches the 'chat_inputs' property from the Crew, - and prints it as JSON to stdout. - """ - try: - crew = {{crew_name}}().crew() - raw_chat_inputs = getattr(crew, "chat_inputs", None) - - if raw_chat_inputs: - # Convert to dictionary to print JSON - print(json.dumps(raw_chat_inputs.model_dump())) - else: - # If crew.chat_inputs is None or empty, print an empty JSON - print(json.dumps({})) - except Exception as e: - raise Exception(f"An error occurred while fetching chat inputs: {e}") - - -def parse_cli_overrides(args_list) -> dict: - """ - Parse arguments in the form of --key=value from a list of CLI arguments. - Return them as a dict. For example: - ['--topic=AI LLMs', '--username=John'] => {'topic': 'AI LLMs', 'username': 'John'} - """ - overrides = {} - for arg in args_list: - if arg.startswith("--"): - # remove the leading -- - trimmed = arg[2:] - if "=" in trimmed: - key, val = trimmed.split("=", 1) - overrides[key] = val - else: - # If someone passed something like --topic (no =), - # either handle differently or ignore - pass - return overrides diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 69780af8ce..77bec3355e 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -5,7 +5,6 @@ import threading import warnings from contextlib import contextmanager -from importlib import resources from typing import Any, Dict, List, Optional, Union, cast from dotenv import load_dotenv @@ -179,35 +178,6 @@ def to_dict(self) -> dict: "callbacks": self.callbacks, } - @classmethod - def from_dict(cls, data: dict) -> "LLM": - """ - Create an LLM instance from a dict. - We assume the dict has all relevant keys that match what's in the constructor. - """ - known_fields = {} - known_fields["model"] = data.pop("model", None) - known_fields["timeout"] = data.pop("timeout", None) - known_fields["temperature"] = data.pop("temperature", None) - known_fields["top_p"] = data.pop("top_p", None) - known_fields["n"] = data.pop("n", None) - known_fields["stop"] = data.pop("stop", None) - known_fields["max_completion_tokens"] = data.pop("max_completion_tokens", None) - known_fields["max_tokens"] = data.pop("max_tokens", None) - known_fields["presence_penalty"] = data.pop("presence_penalty", None) - known_fields["frequency_penalty"] = data.pop("frequency_penalty", None) - known_fields["logit_bias"] = data.pop("logit_bias", None) - known_fields["response_format"] = data.pop("response_format", None) - known_fields["seed"] = data.pop("seed", None) - known_fields["logprobs"] = data.pop("logprobs", None) - known_fields["top_logprobs"] = data.pop("top_logprobs", None) - known_fields["base_url"] = data.pop("base_url", None) - known_fields["api_version"] = data.pop("api_version", None) - known_fields["api_key"] = data.pop("api_key", None) - known_fields["callbacks"] = data.pop("callbacks", None) - - return cls(**known_fields, **data) - def call( self, messages: List[Dict[str, str]], diff --git a/src/crewai/types/crew_chat.py b/src/crewai/types/crew_chat.py index e687d8efd1..354642442a 100644 --- a/src/crewai/types/crew_chat.py +++ b/src/crewai/types/crew_chat.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List from pydantic import BaseModel, Field @@ -14,10 +14,7 @@ class ChatInputField(BaseModel): """ name: str = Field(..., description="The name of the input field") - description: str = Field( - ..., - description="A short description of the input field", - ) + description: str = Field(..., description="A short description of the input field") class ChatInputs(BaseModel): @@ -36,8 +33,7 @@ class ChatInputs(BaseModel): crew_name: str = Field(..., description="The name of the crew") crew_description: str = Field( - ..., - description="A description of the crew's purpose", + ..., description="A description of the crew's purpose" ) inputs: List[ChatInputField] = Field( default_factory=list, description="A list of input fields for the crew" diff --git a/tests/utilities/test_llm_utils.py b/tests/utilities/test_llm_utils.py new file mode 100644 index 0000000000..5aa4f1a1a4 --- /dev/null +++ b/tests/utilities/test_llm_utils.py @@ -0,0 +1,96 @@ +import os +from unittest.mock import patch + +import pytest +from litellm.exceptions import BadRequestError + +from crewai.llm import LLM +from crewai.utilities.llm_utils import create_llm + + +def test_create_llm_with_llm_instance(): + existing_llm = LLM(model="gpt-4o") + llm = create_llm(llm_value=existing_llm) + assert llm is existing_llm + + +def test_create_llm_with_valid_model_string(): + llm = create_llm(llm_value="gpt-4o") + assert isinstance(llm, LLM) + assert llm.model == "gpt-4o" + + +def test_create_llm_with_invalid_model_string(): + with pytest.raises(BadRequestError, match="LLM Provider NOT provided"): + llm = create_llm(llm_value="invalid-model") + llm.call(messages=[{"role": "user", "content": "Hello, world!"}]) + + +def test_create_llm_with_unknown_object_missing_attributes(): + class UnknownObject: + pass + + unknown_obj = UnknownObject() + llm = create_llm(llm_value=unknown_obj) + + # Attempt to call the LLM and expect it to raise an error due to missing attributes + with pytest.raises(BadRequestError, match="LLM Provider NOT provided"): + llm.call(messages=[{"role": "user", "content": "Hello, world!"}]) + + +def test_create_llm_with_none_uses_default_model(): + with patch.dict(os.environ, {}, clear=True): + with patch("crewai.cli.constants.DEFAULT_LLM_MODEL", "gpt-4o"): + llm = create_llm(llm_value=None) + assert isinstance(llm, LLM) + assert llm.model == "gpt-4o-mini" + + +def test_create_llm_with_unknown_object(): + class UnknownObject: + model_name = "gpt-4o" + temperature = 0.7 + max_tokens = 1500 + + unknown_obj = UnknownObject() + llm = create_llm(llm_value=unknown_obj) + assert isinstance(llm, LLM) + assert llm.model == "gpt-4o" + assert llm.temperature == 0.7 + assert llm.max_tokens == 1500 + + +def test_create_llm_from_env_with_unaccepted_attributes(): + with patch.dict( + os.environ, + { + "OPENAI_MODEL_NAME": "gpt-3.5-turbo", + "AWS_ACCESS_KEY_ID": "fake-access-key", + "AWS_SECRET_ACCESS_KEY": "fake-secret-key", + "AWS_REGION_NAME": "us-west-2", + }, + ): + llm = create_llm(llm_value=None) + assert isinstance(llm, LLM) + assert llm.model == "gpt-3.5-turbo" + assert not hasattr(llm, "AWS_ACCESS_KEY_ID") + assert not hasattr(llm, "AWS_SECRET_ACCESS_KEY") + assert not hasattr(llm, "AWS_REGION_NAME") + + +def test_create_llm_with_partial_attributes(): + class PartialAttributes: + model_name = "gpt-4o" + # temperature is missing + + obj = PartialAttributes() + llm = create_llm(llm_value=obj) + assert isinstance(llm, LLM) + assert llm.model == "gpt-4o" + assert llm.temperature is None # Should handle missing attributes gracefully + + +def test_create_llm_with_invalid_type(): + with pytest.raises(BadRequestError, match="LLM Provider NOT provided"): + llm = create_llm(llm_value=42) + llm.call(messages=[{"role": "user", "content": "Hello, world!"}])