From 28b22425d474bbe8c4b8f25ab4b6fa6315124e1a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Albert=20=C3=96rwall?= Date: Mon, 5 Aug 2024 06:46:17 +0200 Subject: [PATCH] Refactor --- moatless/__init__.py | 3 +- moatless/benchmark/evaluation.py | 220 +++++++++++++---------- moatless/codeblocks/parser/create.py | 12 ++ moatless/edit/clarify.py | 2 +- moatless/edit/edit.py | 2 +- moatless/edit/plan.py | 7 +- moatless/edit/plan_lines.py | 2 +- moatless/edit/review.py | 2 +- moatless/find/decide.py | 2 +- moatless/find/identify.py | 2 +- moatless/find/search.py | 4 +- moatless/index/code_index.py | 37 ++-- moatless/index/epic_split.py | 11 +- moatless/loop.py | 257 ++++++++++++++++----------- moatless/repository/file.py | 28 ++- moatless/repository/git.py | 6 + moatless/state.py | 57 +++++- moatless/trajectory.py | 14 +- moatless/transition_rules.py | 11 +- tests/edit/test_clarify.py | 6 +- tests/integration_test.py | 4 +- tests/loop/test_loop.py | 9 + tests/test_state.py | 83 ++++++++- tests/test_transition_rules.py | 12 +- 24 files changed, 542 insertions(+), 251 deletions(-) diff --git a/moatless/__init__.py b/moatless/__init__.py index 85ffe998..bdf45ec7 100644 --- a/moatless/__init__.py +++ b/moatless/__init__.py @@ -1,3 +1,4 @@ -from moatless.loop import AgenticLoop, TransitionRules from moatless.repository import FileRepository from moatless.workspace import Workspace +from moatless.transition_rules import TransitionRules +from moatless.loop import AgenticLoop diff --git a/moatless/benchmark/evaluation.py b/moatless/benchmark/evaluation.py index 214467d3..a4f12bc9 100644 --- a/moatless/benchmark/evaluation.py +++ b/moatless/benchmark/evaluation.py @@ -14,6 +14,7 @@ import pandas as pd from tqdm.auto import tqdm +from moatless.transition_rules import TransitionRules from moatless.benchmark.swebench import ( found_in_alternative_spans, found_in_expected_spans, @@ -27,11 +28,11 @@ trace_metadata, ) from moatless.file_context import FileContext -from moatless.loop import AgenticLoop, Transitions -from moatless.repository import FileRepository +from moatless.loop import AgenticLoop +from moatless.repository import FileRepository, GitRepository from moatless.workspace import Workspace -logger = logging.getLogger("Evaluator") +logger = logging.getLogger(__name__) TEST_SUBSET = [ "astropy__astropy-14995", @@ -74,9 +75,11 @@ def __init__( repo_base_dir: str, evaluations_dir: str, evaluation_name: str, - transitions: Transitions, + transitions: TransitionRules, instructor_mode: instructor.Mode | None = None, max_cost: float = 0.5, + max_transitions: int = 25, + max_expansions: int = 2, max_file_context_tokens: int = 16000, litellm_callback: Optional[str] = None, previous_trajectory_dir: Optional[str] = None, @@ -93,6 +96,8 @@ def __init__( self.evaluation_name = evaluation_name self.max_file_context_tokens = max_file_context_tokens self.max_cost = max_cost + self.max_expansions = max_expansions + self.max_transitions = max_transitions self.instructor_mode = instructor_mode self.transitions = transitions @@ -107,9 +112,11 @@ def __init__( self.previous_trajectory_dir = previous_trajectory_dir self.retry_state = retry_state + logger.info(f"Save trajectories to directory: {self.trajectory_dir}") if not os.path.exists(self.trajectory_dir): os.makedirs(self.trajectory_dir) + logger.info(f"Save logs to directory: {self.logs_dir}") if not os.path.exists(self.logs_dir): os.makedirs(self.logs_dir) @@ -124,7 +131,7 @@ def __init__( with open(os.path.join(result_file)) as f: self.report = json.load(f) else: - self.report = {"resolved": []} + self.report = {"resolved_ids": []} def run_evaluation_with_moatless_dataset( self, @@ -234,6 +241,8 @@ def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict: trajectory_path=trajectory_path, prompt_log_dir=prompt_log_dir, max_cost=self.max_cost, + max_transitions=self.max_transitions, + max_actions=self.max_expansions, instructor_mode=self.instructor_mode, ) @@ -254,16 +263,25 @@ def _evaluate_instance(self, instance: dict, retry: bool = False) -> dict: info["duration"] = time.time() - start_time info["total_cost"] = loop.total_cost() - workspace.save() + if isinstance(workspace.file_repo, GitRepository): + diff = workspace.file_repo.diff() + else: + workspace.save() - output = subprocess.run( - ["git", "diff"], - capture_output=True, - text=True, - cwd=repo_dir, - ) + output = subprocess.run( + ["git", "diff"], + capture_output=True, + text=True, + cwd=repo_dir, + ) + + if output: + diff = output.stdout + else: + diff = None + + info["submission"] = diff - info["submission"] = output.stdout loop.trajectory.save_info(info) return loop.trajectory.to_dict() @@ -352,6 +370,8 @@ def _run_evaluation_detailed(self, instances: list[dict]): results = [] transition_results = [] + logger.info(f"Processing {len(instances)} instances with {len(repo_groups)} repos with {self.num_workers} workers") + with concurrent.futures.ProcessPoolExecutor( max_workers=self.num_workers ) as executor: @@ -470,10 +490,15 @@ def _run_evaluation_simple(self, instances: list[dict]): json_string = json.dumps(prediction) file.write(json_string + "\n") - def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: + def to_result(self, instance: dict, trajectory: dict) -> tuple[dict, list]: info = trajectory["info"] - resolved = info.get("instance_id", "") in self.report["resolved"] + if "resolved_ids" in self.report and instance["instance_id"] in self.report["resolved_ids"]: + result_status = "resolved" + else: + result_status = info.get("status") + + resolved = result_status == "resolved" try: transitions = [] @@ -483,6 +508,7 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: "total_cost": info.get("total_cost", 0), "resolved_by": (len(instance.get("resolved_by", []))), "status": None, + "result_status": result_status, "transitions": len(trajectory["transitions"]), "edited": False, "planned": False, @@ -513,13 +539,32 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: id_iterations = 0 search_iterations = 0 + selected_transition_ids = [] + if "current_transition_id" in trajectory: + transitions_map = {t["id"]: t for t in trajectory["transitions"]} + + transition = transitions_map.get(trajectory["current_transition_id"]) + while transition: + selected_transition_ids.append(transition["id"]) + if "parent_id" in transition: + transition = transitions_map.get(transition["parent_id"]) + else: + break + + logger.info(f"Selected transitions: {selected_transition_ids}") + if instance.get("expected_spans"): for transition in trajectory["transitions"]: - if transition["name"] not in result: - result[transition["name"]] = 0 - result[f"{transition['name']}_cost"] = 0 + if selected_transition_ids and transition["id"] not in selected_transition_ids: + continue + + state_name = transition["state"]["name"] - result[transition["name"]] += 1 + if state_name not in result: + result[state_name] = 0 + result[f"{state_name}_cost"] = 0 + + result[state_name] += 1 expected_span_str = "" for file_path, span_ids in instance["expected_spans"].items(): @@ -528,7 +573,7 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: transition_result = { "instance_id": instance["instance_id"], "resolved": resolved, - "name": transition["name"], + "name": state_name, "cost": 0, "expected_spans": expected_span_str, "actual_spans": "", @@ -538,14 +583,14 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: continue for traj_action in transition["actions"]: - result[f"{transition['name']}_cost"] += traj_action.get( + result[f"{state_name}_cost"] += traj_action.get( "completion_cost", 0 ) transition_result["cost"] += traj_action.get( "completion_cost", 0 ) - if transition["name"] == "SearchCode": + if state_name == "SearchCode": search_iterations += 1 action = transition["actions"][-1] @@ -571,57 +616,40 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: ) or search_request.get("function_names"): result["p_function"] += 1 - if "output" in action and action.get("output"): - output = action["output"] - - if "query" in output: - result["p_query"] += 1 - - if "file_pattern" in output: - result["p_file"] += 1 - - if "code_snippet" in output: - result["p_code"] += 1 - - if "class_name" in output or "class_names" in output: - result["p_class"] += 1 - - if "function_name" in output or "function_names" in output: - result["p_function"] += 1 + if state_name == "IdentifyCode": + id_iterations += 1 - if output.get("ranked_spans"): - for ranked_span in output["ranked_spans"]: - if ( + state = transition["state"] + if state.get("ranked_spans"): + for ranked_span in state["ranked_spans"]: + if ( ranked_span["file_path"] not in search_results_spans - ): - search_results_spans[ - ranked_span["file_path"] - ] = [] + ): search_results_spans[ ranked_span["file_path"] - ].append(ranked_span["span_id"]) + ] = [] + search_results_spans[ + ranked_span["file_path"] + ].append(ranked_span["span_id"]) - if not result["found_in_search"] and ( + if not result["found_in_search"] and ( found_in_expected_spans( instance, search_results_spans ) or found_in_alternative_spans( - instance, search_results_spans - ) - ): - result["found_in_search"] = search_iterations - - if not result["file_in_search"]: - missing_files = get_missing_files( - instance["expected_spans"], - search_results_spans, - ) - if not missing_files: - result["file_in_search"] = search_iterations + instance, search_results_spans + ) + ): + result["found_in_search"] = search_iterations - if transition["name"] == "IdentifyCode": - id_iterations += 1 + if not result["file_in_search"]: + missing_files = get_missing_files( + instance["expected_spans"], + search_results_spans, + ) + if not missing_files: + result["file_in_search"] = search_iterations action = transition["actions"][-1] if action.get("action"): @@ -673,7 +701,7 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: result.get("expected_identified") or 1000, ) - if transition["name"] == "PlanToCode": + if state_name == "PlanToCode": action = transition["actions"][-1]["action"] if action.get("action") == "review": result["review"] = True @@ -701,41 +729,41 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: ): result["planned"] = True - if transition["name"] == "EditCode": + if state_name == "EditCode": result["edit_retries"] = len(transition["actions"]) - 1 action = transition["actions"][-1] - output = action.get("output", {}) + edited = action.get("trigger") == "finish" + + if edited and "file_path" in transition["state"]: + file_path = transition["state"]["file_path"] + if file_path not in edited_spans: + edited_spans[file_path] = [] + edited_spans[file_path].append( + transition["state"]["span_id"] + ) + transition_result["actual_spans"] = ( + f"{file_path}: {transition['state']['span_id']} " + ) + + if not result.get("edited") and ( + found_in_expected_spans( + instance, + edited_spans, + ) + or found_in_alternative_spans(instance, edited_spans) + ): + result["edited"] = True - if output: - edited = output.get("diff") + output = action.get("output", {}) + if output: if edited: result["has_diff"] = True for lint in output.get("verification_errors", []): lint_codes.add(lint["code"]) - if edited and "file_path" in transition["state"]: - file_path = transition["state"]["file_path"] - if file_path not in edited_spans: - edited_spans[file_path] = [] - edited_spans[file_path].append( - transition["state"]["span_id"] - ) - transition_result["actual_spans"] = ( - f"{file_path}: {transition['state']['span_id']} " - ) - - if not result.get("edited") and ( - found_in_expected_spans( - instance, - edited_spans, - ) - or found_in_alternative_spans(instance, edited_spans) - ): - result["edited"] = True - transitions.append(transition_result) if result.get("alt_identified") or result.get("expected_identified"): @@ -752,9 +780,8 @@ def to_result(self, instance: dict, trajectory: dict) -> tuple[list, list]: result["lints"] = ",".join(lint_codes) - if info.get("instance_id", "") in self.report["resolved"]: - result["status"] = "resolved" - elif result["edited"]: + + if result["edited"]: result["status"] = "edited" elif result["identified"]: result["status"] = "identified" @@ -823,13 +850,14 @@ def generate_md_report(trajectory: dict, instance: dict): for j, step in enumerate(trajectory["transitions"]): for i, traj_action in enumerate(step["actions"]): - markdown += f"### {j+1} {step['name']} ({i+1})\n\n" + state_name = step['state'] + markdown += f"### {j+1} {state_name} ({i+1})\n\n" if not traj_action.get("action"): continue action = traj_action["action"] - if step["name"] == "PlanToCode": + if state_name == "PlanToCode": if action.get("scratch_pad"): markdown += "*" + action["scratch_pad"] + "*" @@ -856,7 +884,7 @@ def generate_md_report(trajectory: dict, instance: dict): except Exception as e: logger.error(e) - if step["name"] == "EditCode": + if state_name == "EditCode": markdown += "#### LLM Response\n\n" markdown += f"```\n{action.get('content', '')}\n```\n" @@ -874,7 +902,7 @@ def generate_md_report(trajectory: dict, instance: dict): markdown += "#### Message\n\n" markdown += f"{output['message']}\n\n" - if step["name"] == "ClarifyCodeChange": + if state_name == "ClarifyCodeChange": if action.get("thoughts"): markdown += "*" + action["thoughts"] + "*" @@ -882,10 +910,10 @@ def generate_md_report(trajectory: dict, instance: dict): markdown += f"\n* Start Line: {action['output']['start_line']}\n" markdown += f"\n* End Line: {action['output']['end_line']}\n" - if step["name"] == "Finished": + if state_name == "Finished": markdown += f"*{action['properties']['message']}*\n" - if step["name"] == "Rejected": + if state_name == "Rejected": markdown += f"*{action['properties']['message']}*\n" markdown += "## Alternative patches\n" diff --git a/moatless/codeblocks/parser/create.py b/moatless/codeblocks/parser/create.py index 61315c33..a729fbd0 100644 --- a/moatless/codeblocks/parser/create.py +++ b/moatless/codeblocks/parser/create.py @@ -1,13 +1,25 @@ from moatless.codeblocks.parser.parser import CodeParser from moatless.codeblocks.parser.python import PythonParser +from moatless.codeblocks.parser.java import JavaParser def is_supported(language: str) -> bool: return language and language in ["python", "java", "typescript", "javascript"] +def create_parser_by_ext(ext: str, **kwargs) -> CodeParser | None: + if ext == ".py": + return PythonParser(**kwargs) + elif ext == ".java": + return JavaParser(**kwargs) + + raise NotImplementedError(f"Extension {ext} is not supported.") + + def create_parser(language: str, **kwargs) -> CodeParser | None: if language == "python": return PythonParser(**kwargs) + elif language == "java": + return JavaParser(**kwargs) raise NotImplementedError(f"Language {language} is not supported.") diff --git a/moatless/edit/clarify.py b/moatless/edit/clarify.py index d1ad66c7..5614b190 100644 --- a/moatless/edit/clarify.py +++ b/moatless/edit/clarify.py @@ -82,7 +82,7 @@ def init(self): outcomment_code_comment="... other code", ) - def handle_action(self, request: LineNumberClarification) -> ActionResponse: + def _execute_action(self, request: LineNumberClarification) -> ActionResponse: logger.info( f"{self}: Got line number clarification: {request.start_line} - {request.end_line}" ) diff --git a/moatless/edit/edit.py b/moatless/edit/edit.py index 78630025..bff74ac5 100644 --- a/moatless/edit/edit.py +++ b/moatless/edit/edit.py @@ -148,7 +148,7 @@ def init(self): lines_to_replace = code_lines[self.start_line - 1 : self.end_line] self._code_to_replace = "\n".join(lines_to_replace) - def handle_action(self, content: Content) -> ActionResponse: + def _execute_action(self, content: Content) -> ActionResponse: self._messages.append(AssistantMessage(content=content.content)) scratch_pad = None diff --git a/moatless/edit/plan.py b/moatless/edit/plan.py index 9fca4628..c300f4a6 100644 --- a/moatless/edit/plan.py +++ b/moatless/edit/plan.py @@ -139,7 +139,7 @@ def init(self): ) self.file_context.expand_small_classes(max_tokens=1000) - def handle_action(self, action: ApplyChange) -> ActionResponse: + def _execute_action(self, action: ApplyChange) -> ActionResponse: if action.action == "review": if self.diff and self.finish_on_review: logger.info("Review suggested after diff, will finish") @@ -177,6 +177,11 @@ def _request_for_change(self, rfc: ApplyChange) -> ActionResponse: f"request_for_change(file_path={rfc.file_path}, span_id={rfc.span_id})" ) + if not rfc.instructions: + return ActionResponse.retry( + f"Please provide instructions for the code change." + ) + context_file = self.file_context.get_file(rfc.file_path) if not context_file: logger.warning( diff --git a/moatless/edit/plan_lines.py b/moatless/edit/plan_lines.py index fa78d908..5a3c4271 100644 --- a/moatless/edit/plan_lines.py +++ b/moatless/edit/plan_lines.py @@ -118,7 +118,7 @@ def init(self): ): self.file_context.expand_context_with_related_spans(max_tokens=4000) - def handle_action(self, action: ApplyChange) -> ActionResponse: + def _execute_action(self, action: ApplyChange) -> ActionResponse: if action.finish: self.file_context.save() diff --git a/moatless/edit/review.py b/moatless/edit/review.py index a2c9ea1e..cdf5a9ca 100644 --- a/moatless/edit/review.py +++ b/moatless/edit/review.py @@ -170,7 +170,7 @@ def init(self) -> Optional[ActionResponse]: return None - def handle_action(self, action: ApplyChange) -> ActionResponse: + def _execute_action(self, action: ApplyChange) -> ActionResponse: if action.action == "review": if self.diff and self.finish_on_review: logger.info(f"Review suggested after diff, will finish") diff --git a/moatless/find/decide.py b/moatless/find/decide.py index 26920509..61bf0fbd 100644 --- a/moatless/find/decide.py +++ b/moatless/find/decide.py @@ -92,7 +92,7 @@ def __init__( **data, ) - def handle_action(self, action: Decision) -> ActionResponse: + def _execute_action(self, action: Decision) -> ActionResponse: if action.complete and action.relevant: return ActionResponse.transition("finish") diff --git a/moatless/find/identify.py b/moatless/find/identify.py index cc1b35cd..df3dff15 100644 --- a/moatless/find/identify.py +++ b/moatless/find/identify.py @@ -98,7 +98,7 @@ def __init__( def model_dump(self, **kwargs): return super().model_dump(**kwargs) - def handle_action(self, action: Identify) -> ActionResponse: + def _execute_action(self, action: Identify) -> ActionResponse: if action.identified_spans: self.file_context.add_files_with_spans(action.identified_spans) diff --git a/moatless/find/search.py b/moatless/find/search.py index b5bd8227..fdc54c5d 100644 --- a/moatless/find/search.py +++ b/moatless/find/search.py @@ -332,7 +332,7 @@ def __init__( **data, ) - def handle_action(self, action: Search) -> ActionResponse: + def _execute_action(self, action: Search) -> ActionResponse: if action.complete: return ActionResponse.transition( "finish", @@ -433,7 +433,7 @@ def messages(self) -> list[Message]: query=self.loop.trajectory.initial_message, exact_match_if_possible=False, max_spans_per_file=5, - max_results=50, + max_results=100, ) file_context = self.create_file_context(max_tokens=4000) diff --git a/moatless/index/code_index.py b/moatless/index/code_index.py index 383180e4..63afdd60 100644 --- a/moatless/index/code_index.py +++ b/moatless/index/code_index.py @@ -67,7 +67,6 @@ def __init__( max_exact_results: int = 5, ): self._index_name = index_name - self._settings = settings or IndexSettings() self.max_results = max_results @@ -157,12 +156,12 @@ def from_index_name( logger.info(f"Loading existing index {index_name} from {persist_dir}.") return cls.from_persist_dir(persist_dir, file_repo=file_repo) - if not os.getenv("INDEX_STORE_URL"): - raise ValueError( - "INDEX_STORE_URL environment variable must be set to a index store URL to download the index." - ) + if os.getenv("INDEX_STORE_URL"): + index_store_url = os.getenv("INDEX_STORE_URL") + else: + index_store_url = "https://stmoatless.blob.core.windows.net/indexstore/20240522-voyage-code-2" - store_url = os.path.join(os.getenv("INDEX_STORE_URL"), f"{index_name}.zip") + store_url = os.path.join(index_store_url, f"{index_name}.zip") logger.info(f"Downloading existing index {index_name} from {store_url}.") return cls.from_url(store_url, persist_dir, file_repo) @@ -699,14 +698,23 @@ def file_metadata_func(file_path: str) -> dict: "category": category, } - reader = SimpleDirectoryReader( - input_dir=repo_path, - file_metadata=file_metadata_func, - input_files=input_files, - filename_as_id=True, - required_exts=[".py"], # TODO: Shouldn't be hardcoded and filtered - recursive=True, - ) + if self._settings and self._settings.language == "java": + required_exts = [".java"] + else: + required_exts = [".py"] + + try: + reader = SimpleDirectoryReader( + input_dir=repo_path, + file_metadata=file_metadata_func, + input_files=input_files, + filename_as_id=True, + required_exts=required_exts, + recursive=True, + ) + except Exception as e: + logger.exception(f"Failed to create reader with input_dir {repo_path}, input_files {input_files} and required_exts {required_exts}.") + raise e embed_pipeline = IngestionPipeline( transformations=[self._embed_model], @@ -737,6 +745,7 @@ def index_callback(codeblock: CodeBlock): ) splitter = EpicSplitter( + language=self._settings.language, min_chunk_size=self._settings.min_chunk_size, chunk_size=self._settings.chunk_size, hard_token_limit=self._settings.hard_token_limit, diff --git a/moatless/index/epic_split.py b/moatless/index/epic_split.py index a8fb04e5..ce41d4c7 100644 --- a/moatless/index/epic_split.py +++ b/moatless/index/epic_split.py @@ -10,6 +10,7 @@ from llama_index.core.schema import BaseNode, TextNode from llama_index.core.utils import get_tokenizer, get_tqdm_iterable +from moatless.codeblocks import create_parser from moatless.codeblocks.codeblocks import CodeBlock, CodeBlockType, PathTree from moatless.codeblocks.parser.python import PythonParser from moatless.index.code_node import CodeNode @@ -39,6 +40,10 @@ def count_parent_tokens(codeblock: CodeBlock) -> int: class EpicSplitter(NodeParser): + language: str = Field( + default="python", description="Language of the code blocks to parse." + ) + text_splitter: TextSplitter = Field( description="Text splitter to use for splitting non code documents into nodes." ) @@ -82,6 +87,7 @@ class EpicSplitter(NodeParser): def __init__( self, + language: str = "python", chunk_size: int = 750, min_chunk_size: int = 100, max_chunk_size: int = 1500, @@ -106,6 +112,7 @@ def __init__( # self._fallback_code_splitter = fallback_code_splitter super().__init__( + language=language, chunk_size=chunk_size, chunk_overlap=0, text_splitter=text_splitter or TokenTextSplitter(), @@ -142,10 +149,10 @@ def _parse_nodes( content = node.get_content() try: - # TODO: Derive language from file extension starttime = time.time_ns() - parser = PythonParser(index_callback=self.index_callback) + # TODO: Derive language from file extension + parser = create_parser(language=self.language, index_callback=self.index_callback) codeblock = parser.parse(content, file_path=file_path) parse_time = time.time_ns() - starttime diff --git a/moatless/loop.py b/moatless/loop.py index 91495852..a6880d07 100644 --- a/moatless/loop.py +++ b/moatless/loop.py @@ -8,6 +8,7 @@ from collections.abc import Callable from datetime import datetime from typing import Any, Optional, Type, Tuple +import subprocess import instructor import litellm @@ -15,6 +16,7 @@ from litellm import completion_cost, cost_per_token, token_counter from pydantic import BaseModel, Field, PrivateAttr +from moatless.repository import GitRepository from moatless.state import ( AgenticState, Finished, @@ -43,12 +45,14 @@ def __init__( self, transition_rules: TransitionRules, workspace: Workspace, + input_data: dict[str, Any] | None = None, trajectory: Trajectory | None = None, mocked_actions: list[dict] | None = None, expected_states: list[Type[AgenticState]] | None = None, reset_mocks_at_state: Optional[str] = None, verify_state_func: Optional[Callable] = None, max_cost: float = 0.25, + max_actions: int = 2, max_transitions: int = 25, max_message_tokens: Optional[int] = None, max_retries: int = 2, @@ -57,6 +61,7 @@ def __init__( metadata: dict[str, Any] | None = None, trajectory_path: Optional[str] = None, prompt_log_dir: Optional[str] = None, + **kwargs, ): """ Initialize the Loop instance. @@ -64,11 +69,12 @@ def __init__( Args: """ - self._trajectory = trajectory self._workspace = workspace + self._input_data = input_data + if trajectory_path: parent_dir = os.path.dirname(trajectory_path) if not os.path.exists(parent_dir): @@ -82,13 +88,20 @@ def __init__( self._mocked_actions = mocked_actions if expected_states and not verify_state_func: + def verify_state_func(state: AgenticState): nonlocal expected_states if not expected_states: - raise ValueError(f"No more expected states, but got {state.__class__}") + raise ValueError( + f"No more expected states, but got {state.__class__}" + ) expected_state = expected_states.pop(0) - if not (state.name == expected_state or isinstance(state, expected_state)): - raise ValueError(f"Expected state {expected_state} but got {state.__class__.__name__}") + if not ( + state.name == expected_state or isinstance(state, expected_state) + ): + raise ValueError( + f"Expected state {expected_state} but got {state.__class__.__name__}" + ) self.log_info(f"Verified expected next state {expected_state}") @@ -99,6 +112,7 @@ def verify_state_func(state: AgenticState): self._max_cost = max_cost self._max_message_tokens = max_message_tokens self._max_transitions = max_transitions + self._max_actions = max_actions self._max_retries = max_retries self._max_rejections = max_rejections self._instructor_mode = instructor_mode @@ -110,11 +124,15 @@ def verify_state_func(state: AgenticState): self._initial_message = "" self._transitions: dict[int, TrajectoryTransition] = {} - self._current_state: AgenticState = Pending() self._current_transition: TrajectoryTransition | None = None self._metadata = metadata + self._type = "standard" + + for k, v in kwargs.items(): + setattr(self, k, v) + @classmethod def from_trajectory_file(cls, trajectory_path: str, **kwargs): trajectory = Trajectory.load(trajectory_path) @@ -136,39 +154,12 @@ def retry_from_transition( transition_id: int, state_params: dict[Type[AgenticState], Any] = None, ): - self.revert_to_transition(transition_id) + self.clone_transition(transition_id) # TODO: I'm using only state params as an easy way test out changes. Need to think about a better way to do this. self._transition_rules.state_params.update(state_params) - # TODO: DRY - while self.is_running(): - try: - self._run() - except Exception as e: - logger.warning( - f"{self.transition_name}: Failed to run loop. Error: {e}" - ) - raise - - if self.retries() > self._max_retries: - logger.warning( - f"{self.transition_name}: Max retries reached ({self._max_retries}). Exiting." - ) - self.trajectory.save_info({"error": "Max retries reached."}) - return Response( - status="rejected", - message="The loop was aborted because the number of retries exceeded the limit.", - ) - - total_cost = self.total_cost() - if total_cost > self._max_cost: - logger.warning( - f"{self.transition_name}: Max cost reached ({total_cost} > {self._max_cost}). Exiting." - ) - self.trajectory.save_info({"error": "Max cost reached."}) - raise RuntimeError( - "The loop was aborted because the cost exceeded the limit.", - ) + while not self.is_finished(): + self.run_until_transition() if isinstance(self.state, Finished): return Response(status="finished", message=self.state.message or "") @@ -177,16 +168,7 @@ def retry_from_transition( raise RuntimeError(f"Loop exited with unknown state {self.state.name}.") - def run( - self, message: Optional[str] = None, input_data: dict[str, Any] | None = None - ) -> Response: - """ - Run the loop and handle exceptions and cost checking. - """ - - if self.is_running(): - raise Exception("Loop is already running.") - + def initialize_or_load_trajectory(self, message: Optional[str] = None) -> None: if not self._trajectory: self._trajectory = Trajectory( "MoatlessTools", @@ -195,16 +177,15 @@ def run( workspace=self._workspace, transition_rules=self._transition_rules, ) - initial_state = self._transition_rules.create_initial_state( - **input_data or {} + pending_transition = self._create_transition( + state=Pending(), + snapshot=self._workspace.snapshot() ) - self.transition_to(initial_state) + self._set_current_transition(pending_transition) else: for transition in self._trajectory.transitions: self.set_current_transition_from_dict(transition) - self._transitions[self._current_transition.id] = ( - self._current_transition - ) + self.workspace.restore_from_snapshot(transition.get("snapshot")) for transition_data in self._trajectory.transitions: transition = self._transitions[transition_data["id"]] @@ -213,25 +194,28 @@ def run( transition.parent = parent parent.children.append(transition) - while self.is_running(): - try: - self._run() - except Exception as e: - logger.warning( - f"{self.transition_name}: Failed to run loop. Error: {e}" - ) - raise + def run(self, message: Optional[str] = None) -> Response: + """ + Run the loop and handle exceptions and cost checking. + """ - if self.retries() > self._max_retries: - logger.warning( - f"{self.transition_name}: Max retries reached ({self._max_retries}). Exiting." - ) - self.trajectory.save_info({"error": "Max retries reached."}) - return Response( - status="rejected", - message="The loop was aborted because the number of retries exceeded the limit.", - ) + if self.is_running(): + raise Exception("Loop is already running.") + + self.initialize_or_load_trajectory(message) + while not self.is_finished(): + self.run_until_transition() + + if isinstance(self.state, Finished): + return Response(status="finished", message=self.state.message or "") + elif isinstance(self.state, Rejected): + return Response(status="rejected", message=self.state.message or "") + + raise RuntimeError(f"Loop exited with unknown state {self.state.name}.") + + def run_until_transition(self) -> TrajectoryTransition: + while not self.is_finished(): total_cost = self.total_cost() if total_cost > self._max_cost: logger.warning( @@ -241,13 +225,29 @@ def run( raise RuntimeError( "The loop was aborted because the cost exceeded the limit.", ) + else: + self.log_info( + f"Running transition {len(self._transitions)}. Current total cost: {total_cost}" + ) - if isinstance(self.state, Finished): - return Response(status="finished", message=self.state.message or "") - elif isinstance(self.state, Rejected): - return Response(status="rejected", message=self.state.message or "") + try: + transition = self._run() + if transition: + return transition + except Exception as e: + logger.warning( + f"{self.transition_name}: Failed to run loop. Error: {e}" + ) + raise - raise RuntimeError(f"Loop exited with unknown state {self.state.name}.") + if self.retries() > self._max_retries: + logger.warning( + f"{self.transition_name}: Max retries reached ({self._max_retries}). Exiting." + ) + self.trajectory.save_info({"error": "Max retries reached."}) + return self.transition_to(Rejected(message="Max retries reached.")) + + raise RuntimeError("Loop exited without a transition.") def total_cost(self): total_cost = 0 @@ -261,6 +261,9 @@ def total_cost(self): def is_running(self) -> bool: return not isinstance(self.state, NoopState) + def is_finished(self) -> bool: + return isinstance(self.state, (Finished, Rejected)) + def _set_state_loop(self, state: AgenticState): state._set_loop(self) @@ -299,13 +302,17 @@ def retry_messages(self, state: AgenticState) -> list[Message]: return messages + def _set_current_transition(self, transition: TrajectoryTransition): + self._current_transition = transition + self._transitions[transition.id] = transition + self._trajectory.set_current_transition_id(transition.id) + def set_current_transition_from_dict(self, transition_data: dict): state_data = transition_data.get("state", {}) name = state_data.get("name") try: state_class = get_state_class(name) state = state_class(**state_data) - self.workspace.restore_from_snapshot(transition_data.get("snapshot")) transition = TrajectoryTransition( id=transition_data["id"], @@ -317,28 +324,56 @@ def set_current_transition_from_dict(self, transition_data: dict): timestamp=datetime.fromisoformat(transition_data["timestamp"]), ) - state._set_loop(self) + self._set_current_transition(transition) + self._set_state_loop(state) state.init() - self._current_state = state - self._current_transition = transition except Exception as e: logger.exception(f"Failed to load state {name}") raise e def set_current_transition(self, transition: TrajectoryTransition): - self.workspace.restore_from_snapshot(transition.snapshot) - self._current_state = transition.state - self._current_transition = transition + self._set_current_transition(transition) - def revert_to_transition(self, transition_id: int): + def revert_to_transition(self, transition_id: int) -> TrajectoryTransition: transition = self._transitions.get(transition_id) if transition: - self.set_current_transition(transition) + self.log_info(f"Reverting to transition {transition_id}") + self._set_current_transition(transition) + self.workspace.restore_from_snapshot(transition.snapshot) + return transition else: - raise ValueError("Invalid state index for reversion") + logger.warning( + f"Tried to revert to transition {transition_id} but it does not exist. Existing transition ids: {self._transitions.keys()}" + ) + raise ValueError( + f"Could not revert to transition {transition_id} as it does not exist." + ) - def transition_to(self, new_state: AgenticState): + def _create_transition( + self, + state: AgenticState, + snapshot: dict | None = None, + parent: TrajectoryTransition | None = None, + ): + transition = TrajectoryTransition( + id=len(self._transitions) + 1, state=state, snapshot=snapshot, parent=parent + ) + self.trajectory.create_transition(transition) + self._transitions[transition.id] = transition + return transition + + def clone_current_transition(self): + cloned_state = self.state.clone() + cloned_transition = self._create_transition( + state=cloned_state, + snapshot=self._current_transition.snapshot, + parent=self._current_transition.parent, + ) + self._set_current_transition(cloned_transition) + return cloned_transition + + def transition_to(self, new_state: AgenticState) -> TrajectoryTransition: self.log_info(f"Transitioning from {self.state.name} to {new_state.name}") if self.transition_count() > self._max_transitions: @@ -352,7 +387,7 @@ def transition_to(self, new_state: AgenticState): message=f"Max transitions exceeded for state {new_state.name}." ) - transition = TrajectoryTransition( + transition = self._create_transition( state=new_state, snapshot=self.workspace.snapshot(), parent=self._current_transition, @@ -361,12 +396,10 @@ def transition_to(self, new_state: AgenticState): if self._current_transition: self._current_transition.children.append(transition) - transition = self.trajectory.create_transition(transition) + self._set_current_transition(transition) + self._set_state_loop(new_state) - self._transitions[transition.id] = transition - self._current_transition = transition - self._current_state = new_state - self._set_state_loop(self.state) + return transition def transition_count(self, state: AgenticState | None = None) -> int: if not state: @@ -393,7 +426,7 @@ def get_previous_transitions(self, state: AgenticState | None): @property def state(self): - return self._current_state + return self._current_transition.state if self._current_transition else Pending() @property def workspace(self) -> Workspace: @@ -490,10 +523,23 @@ def _to_completion_messages(self) -> list[dict]: return messages - def _run(self): - if not self.is_running(): - self.log_info("Loop is not running.") - return + def _run(self) -> TrajectoryTransition | None: + """ + Run the loop for one iteration. + + Returns: + + """ + if self.is_finished(): + self.log_info("Loop already finished.") + return None + + if isinstance(self.state, Pending): + logger.info("Initializing first state.") + initial_state = self._transition_rules.create_initial_state( + **(self._input_data or {}) + ) + return self.transition_to(initial_state) action, cost, input_tokens, output_tokens = self._next_action() @@ -510,19 +556,19 @@ def _run(self): output_tokens=output_tokens, ) ) - self.trajectory.save_transition(self._current_transition) + self.trajectory.update_transition(self._current_transition) if not response.trigger: self.log_info( f"{self.state.name}: No trigger in action response. Staying in the same state." ) - return + return None self.log_info(f"Received response with trigger {response.trigger}") if response.trigger == "retry": self.log_info(f"Retry requested. {response.retry_message}") - return + return None try: next_state = self._transition_rules.next_state( @@ -549,7 +595,7 @@ def _run(self): else: self._rejections = 0 - self.transition_to(next_state) + return self.transition_to(next_state) @property def instructor_mode(self): @@ -703,7 +749,7 @@ def _next_action( try: cost = completion_cost( completion_response=completion_response, - model="claude-3-5-sonnet-20240620", + model=self.state.model, ) except Exception as e: self.log_info(f"Error calculating completion cost: {e}") @@ -727,8 +773,15 @@ def _log_prompt( if not self._prompt_log_dir: return - transition_no = self.transition_count() - prompt_path = f"{self._prompt_log_dir}/{transition_no:02d}_{self.state.name}.md" + time_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S") + + prompt_path = ( + f"{self._prompt_log_dir}/{self._current_transition.id}_{self.state.name}" + ) + if self.retries() > 0: + prompt_path += f"_retry_{self.retries()}" + + prompt_path += f"_{time_str}.md" with open(prompt_path, "w") as f: f.write("\n\n# Completion\n") diff --git a/moatless/repository/file.py b/moatless/repository/file.py index ddb783b9..5460281a 100644 --- a/moatless/repository/file.py +++ b/moatless/repository/file.py @@ -176,6 +176,10 @@ def __init__(self, repo_path: str): self._repo_path = repo_path self._files: dict[str, CodeFile] = {} + @property + def repo_dir(self): + return self._repo_path + def dict(self): return {"type": "file", "path": self._repo_path} @@ -185,6 +189,10 @@ def snapshot(self) -> dict: def restore_from_snapshot(self, snapshot: dict): pass + def restore_from_disk(self): + for file_path in self._files.keys(): + self.get_file(file_path, refresh=True) + @property def path(self): return self._repo_path @@ -198,8 +206,8 @@ def get_file( Args: """ - file = self._files.get(file_path) - if not file or refresh or from_origin: + existing_file = self._files.get(file_path) + if not existing_file or refresh or from_origin: full_file_path = os.path.join(self._repo_path, file_path) if not os.path.exists(full_file_path): logger.warning(f"File not found: {full_file_path}") @@ -213,13 +221,19 @@ def get_file( if parser: content = f.read() module = parser.parse(content) - file = CodeFile(file_path=file_path, content=content, module=module) + found_file = CodeFile(file_path=file_path, content=content, module=module) else: - file = CodeFile(file_path=file_path, content=f.read()) + found_file = CodeFile(file_path=file_path, content=f.read()) + + if not existing_file: + existing_file = found_file + self._files[file_path] = existing_file + elif refresh or not from_origin: + existing_file.content = found_file.content + existing_file.module = found_file.module + existing_file.dirty = False - if refresh or not from_origin: - self._files[file_path] = file - return file + return existing_file def save_file(self, file_path: str, updated_content: Optional[str] = None): file = self._files.get(file_path) diff --git a/moatless/repository/git.py b/moatless/repository/git.py index dae30ddf..488f4512 100644 --- a/moatless/repository/git.py +++ b/moatless/repository/git.py @@ -57,8 +57,14 @@ def from_dict(cls, data: dict): def restore_from_snapshot(self, snapshot: dict): self._current_commit = snapshot["commit"] + + self._repo.git.checkout(self._current_commit) + # TODO: Check diff and only reset changed files + + self.restore_from_disk() + def dict(self): return { "type": "git", diff --git a/moatless/state.py b/moatless/state.py index 866761ad..1d477e83 100644 --- a/moatless/state.py +++ b/moatless/state.py @@ -3,6 +3,7 @@ import importlib from abc import ABC, abstractmethod from typing import Any, Optional +from copy import deepcopy from pydantic import BaseModel, Field, PrivateAttr, ConfigDict @@ -37,14 +38,31 @@ class AgenticState(ABC, BaseModel): _loop: Optional["AgenticLoop"] = PrivateAttr(None) # noqa: F821 + _executed: bool = PrivateAttr(False) + _last_action: Optional[ActionRequest] = PrivateAttr(None) + _response: Optional[ActionResponse] = PrivateAttr(None) + # model_config = ConfigDict(extra='allow') def __init__(self, **data): super().__init__(**data) self._loop = None - @abstractmethod def handle_action(self, action: ActionRequest) -> ActionResponse: + if self._executed: + raise ValueError(f"State has already been executed") + + self._last_action = action + response = self._execute_action(action) + + if response.trigger and response.trigger != "retry": + self._executed = True + self._response = response + + return response + + @abstractmethod + def _execute_action(self, action: ActionRequest) -> ActionResponse: raise NotImplementedError def _set_loop(self, loop: "AgenticLoop"): # noqa: F821 @@ -55,6 +73,18 @@ def _set_loop(self, loop: "AgenticLoop"): # noqa: F821 def name(self): return self.__class__.__name__ + @property + def executed(self): + return self._executed + + @property + def last_action(self) -> Optional[ActionRequest]: + return self._last_action + + @property + def response(self) -> Optional[ActionResponse]: + return self._response + @property def loop(self) -> "AgenticLoop": # noqa: F821 assert self._loop is not None, "Loop has not been set" @@ -127,13 +157,34 @@ def model_dump(self, **kwargs): data = super().model_dump(**kwargs) return {"name": self.name, **data} + def clone(self) -> "AgenticState": + data = self.model_dump(exclude={"_executed", "_last_action", "_response"}) + new_state = self.__class__(**data) + new_state._loop = self._loop + return new_state + + def __eq__(self, other): + if not isinstance(other, AgenticState): + return NotImplemented + + if self.model_dump() != other.model_dump(): + return False + + if self._loop and other._loop: + self_context = self._loop.workspace.file_context + other_context = other._loop.workspace.file_context + + return self_context.model_dump() == other_context.model_dump() + + return True + class NoopState(AgenticState): def __init__(self, **data): super().__init__(**data) - def handle_action(self, action: ActionRequest): - raise NotImplementedError + def _execute_action(self, action: ActionRequest): + raise ValueError("NoopState cannot handle actions") class Finished(NoopState): diff --git a/moatless/trajectory.py b/moatless/trajectory.py index b88ad2e2..971a3336 100644 --- a/moatless/trajectory.py +++ b/moatless/trajectory.py @@ -29,7 +29,7 @@ def model_dump(self, **kwargs): class TrajectoryTransition(BaseModel): - id: Optional[int] = None + id: int parent: Optional["TrajectoryTransition"] = None children: list["TrajectoryTransition"] = Field(default_factory=list) state: AgenticState @@ -57,6 +57,7 @@ def model_dump(self, **kwargs): return data + class Trajectory: def __init__( self, @@ -74,6 +75,8 @@ def __init__( self._transitions: list[dict[str, Any]] = [] + self._current_transition_id = 0 + self._info: dict[str, Any] = {} @classmethod @@ -114,14 +117,13 @@ def workspace(self) -> dict[str, Any] | None: return self._workspace def create_transition(self, transition: TrajectoryTransition): - transition.id = len(self._transitions) + 1 self._transitions.append( transition.model_dump(exclude_none=True, exclude_unset=True) ) self._maybe_persist() return transition - def save_transition(self, transition: TrajectoryTransition): + def update_transition(self, transition: TrajectoryTransition): for i, t in enumerate(self._transitions): if t["id"] == transition.id: self._transitions[i] = transition.model_dump( @@ -132,6 +134,10 @@ def save_transition(self, transition: TrajectoryTransition): raise ValueError(f"Transition with id {transition.id} not found") + def set_current_transition_id(self, transition_id: int): + self._current_transition_id = transition_id + self._maybe_persist() + def save_info(self, info: dict): self._info = info self._maybe_persist() @@ -160,7 +166,6 @@ def get_expected_states(self) -> list[str]: states.append(transition["state"]["name"]) return states - def to_dict(self): return { "name": self._name, @@ -171,6 +176,7 @@ def to_dict(self): else None, "workspace": self._workspace, "initial_message": self._initial_message, + "current_transition_id": self._current_transition_id, "transitions": self._transitions, "info": self._info, } diff --git a/moatless/transition_rules.py b/moatless/transition_rules.py index 395aa1c2..5b836f56 100644 --- a/moatless/transition_rules.py +++ b/moatless/transition_rules.py @@ -2,6 +2,8 @@ from pydantic import BaseModel, Field, PrivateAttr, model_validator from typing import Any, Type, Optional + +from moatless.settings import Settings from moatless.state import AgenticState, get_state_class @@ -89,7 +91,7 @@ def model_dump(self, **kwargs): @model_validator(mode="before") @classmethod - def validate_state_classes(cls, data: Any) -> Any: + def validate_before_init(cls, data: Any) -> Any: if isinstance(data, dict): if isinstance(data.get("initial_state"), str): data["initial_state"] = get_state_class(data["initial_state"]) @@ -100,6 +102,13 @@ def validate_state_classes(cls, data: Any) -> Any: for k, v in data["state_params"].items() } + if "global_params" not in data: + data["global_params"] = {} + + if "model" not in data["global_params"]: + logger.info(f"No model specified in global_params. Using default model: {Settings.default_model}") + data["global_params"]["model"] = Settings.default_model + return data def _build_source_trigger_index(self): diff --git a/tests/edit/test_clarify.py b/tests/edit/test_clarify.py index 162db6ba..eaba3a74 100644 --- a/tests/edit/test_clarify.py +++ b/tests/edit/test_clarify.py @@ -34,7 +34,7 @@ def test_line_span_in_end_of_class(mocker): request = LineNumberClarification(start_line=562, end_line=563, thoughts="") - response = coding.handle_action(request) + response = coding._execute_action(request) assert response == ActionResponse( trigger="edit_code", output={ @@ -63,7 +63,7 @@ def test_impl_span(mocker): start_line=start_line, end_line=end_line, thoughts="" ) - response = coding.handle_action(request) + response = coding._execute_action(request) assert response == ActionResponse( trigger="edit_code", output={ @@ -92,7 +92,7 @@ def test_line_span_in_class(mocker): start_line=start_line, end_line=end_line, thoughts="" ) - response = coding.handle_action(request) + response = coding._execute_action(request) assert response == ActionResponse( trigger="edit_code", output={ diff --git a/tests/integration_test.py b/tests/integration_test.py index eb2d3c9b..272faabf 100644 --- a/tests/integration_test.py +++ b/tests/integration_test.py @@ -19,8 +19,8 @@ moatless_dir = os.getenv("MOATLESS_DIR", "/tmp/moatless") global_params = { - "model": "azure/gpt-4o", - "temperature": 0.2, + "model": "gpt-4o-mini-2024-07-18", # "azure/gpt-4o", + "temperature": 0.5, "max_tokens": 2000, "max_prompt_file_tokens": 8000, } diff --git a/tests/loop/test_loop.py b/tests/loop/test_loop.py index dab7c33f..beb871ac 100644 --- a/tests/loop/test_loop.py +++ b/tests/loop/test_loop.py @@ -1,12 +1,21 @@ +import os import tempfile +import pytest + from moatless import AgenticLoop from moatless.benchmark.swebench import create_workspace, load_instance from moatless.repository import GitRepository from moatless.settings import Settings from moatless.trajectory import Trajectory +pytest.mark.api_keys_required = pytest.mark.skipif( + "VOYAGE_API_KEY" not in os.environ or os.environ["VOYAGE_API_KEY"] == "", + reason="VOYAGE_API_KEY environment variable is required" +) + +@pytest.mark.api_keys_required def test_rerun_save_and_load_trajectory(): trajectory = Trajectory.load("tests/trajectories/django__django_16379.json") Settings.cheap_model = None # To not use an LLM when generating commit messages diff --git a/tests/test_state.py b/tests/test_state.py index ba70d415..62cd3d5e 100644 --- a/tests/test_state.py +++ b/tests/test_state.py @@ -8,7 +8,7 @@ class ConcreteAgenticState(AgenticState): - def handle_action(self, action: ActionRequest) -> ActionResponse: + def _execute_action(self, action: ActionRequest) -> ActionResponse: return ActionResponse(content="Test response") @@ -79,3 +79,84 @@ def test_agentic_state_model_dump(test_state): dump = test_state.model_dump() assert "name" in dump assert dump["name"] == "ConcreteAgenticState" + +def test_agentic_state_equality_same_state(): + state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + state2 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + assert state1 == state2 + +def test_agentic_state_equality_different_state(): + state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + state2 = ConcreteAgenticState(temperature=0.7, max_tokens=500) + assert state1 != state2 + +def test_agentic_state_equality_different_types(): + state1 = ConcreteAgenticState() + state2 = NoopState() + assert state1 != state2 + +def test_agentic_state_equality_with_file_context(): + mock_loop1 = MagicMock() + mock_workspace1 = MagicMock(spec=Workspace) + mock_file_context1 = MagicMock(spec=FileContext) + mock_loop1.workspace = mock_workspace1 + mock_workspace1.file_context = mock_file_context1 + mock_file_context1.model_dump.return_value = {"files": [{"file_path": "test.py", "spans": []}]} + + mock_loop2 = MagicMock() + mock_workspace2 = MagicMock(spec=Workspace) + mock_file_context2 = MagicMock(spec=FileContext) + mock_loop2.workspace = mock_workspace2 + mock_workspace2.file_context = mock_file_context2 + mock_file_context2.model_dump.return_value = {"files": [{"file_path": "test.py", "spans": []}]} + + state1 = ConcreteAgenticState() + state2 = ConcreteAgenticState() + + state1._set_loop(mock_loop1) + state2._set_loop(mock_loop2) + + assert state1 == state2 + +def test_agentic_state_inequality_with_different_file_context(): + mock_loop1 = MagicMock() + mock_workspace1 = MagicMock(spec=Workspace) + mock_file_context1 = MagicMock(spec=FileContext) + mock_loop1.workspace = mock_workspace1 + mock_workspace1.file_context = mock_file_context1 + mock_file_context1.model_dump.return_value = {"files": [{"file_path": "test1.py", "spans": ["foo"]}]} + + mock_loop2 = MagicMock() + mock_workspace2 = MagicMock(spec=Workspace) + mock_file_context2 = MagicMock(spec=FileContext) + mock_loop2.workspace = mock_workspace2 + mock_workspace2.file_context = mock_file_context2 + mock_file_context2.model_dump.return_value = {"files": [{"file_path": "test1.py", "spans": ["bar"]}]} + + state1 = ConcreteAgenticState() + state2 = ConcreteAgenticState() + + state1._set_loop(mock_loop1) + state2._set_loop(mock_loop2) + + assert state1 != state2 + +def test_agentic_state_equality_without_loop(): + state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + state2 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + assert state1 == state2 + +def test_agentic_state_equality_one_with_loop(): + mock_loop = MagicMock() + mock_workspace = MagicMock(spec=Workspace) + mock_file_context = MagicMock(spec=FileContext) + mock_loop.workspace = mock_workspace + mock_workspace.file_context = mock_file_context + mock_file_context.model_dump.return_value = {"files": [{"file_path": "test.py", "spans": []}]} + + state1 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + state2 = ConcreteAgenticState(temperature=0.5, max_tokens=500) + + state1._set_loop(mock_loop) + + assert state1 == state2 \ No newline at end of file diff --git a/tests/test_transition_rules.py b/tests/test_transition_rules.py index 7c920a58..a8af1359 100644 --- a/tests/test_transition_rules.py +++ b/tests/test_transition_rules.py @@ -10,7 +10,7 @@ class MockStateA(AgenticState): value: int = 0 - def handle_action(self, action: str, **kwargs): + def _execute_action(self, action: str, **kwargs): if action == "to_b": return ActionResponse(output={"message": "Moving to B"}, trigger="to_b") return ActionResponse(output={"message": "Staying in A"}, trigger=None) @@ -19,7 +19,7 @@ def handle_action(self, action: str, **kwargs): class MockStateB(AgenticState): default_name: str = "" - def handle_action(self, action: str, **kwargs): + def _execute_action(self, action: str, **kwargs): if action == "finish": return ActionResponse(output={"message": "Finishing"}, trigger="finish") elif action == "reject": @@ -177,25 +177,25 @@ def test_next_state(): # Test successful transition source_state = MockStateA(value=5) - action_response = source_state.handle_action("to_b") + action_response = source_state._execute_action("to_b") next_state = rules.next_state(source_state, action_response.trigger, {"value": 5}) assert isinstance(next_state, MockStateB) assert next_state.name == "MockStateB" assert next_state.model == "claude-3.5-sonnet" # Test transition with missing required fields - action_response = source_state.handle_action("to_b") + action_response = source_state._execute_action("to_b") next_state = rules.next_state(source_state, action_response.trigger, {}) assert next_state is None # Test transition to Finished state source_state = MockStateB(default_name="TestB") - action_response = source_state.handle_action("finish") + action_response = source_state._execute_action("finish") next_state = rules.next_state(source_state, action_response.trigger, {}) assert isinstance(next_state, Finished) # Test transition to Rejected state - action_response = source_state.handle_action("reject") + action_response = source_state._execute_action("reject") next_state = rules.next_state( source_state, action_response.trigger, {"message": "Custom rejection message"} )