Skip to content

Commit

Permalink
Fix integration test
Browse files: browse the repository at this point in the history
  • Loading branch information
aorwall committed Jul 31, 2024
1 parent 93f14db commit cb2be6b
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 58 deletions.
4 changes: 3 additions & 1 deletion moatless/benchmark/swebench/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,9 @@ def create_workspace(
git_repo_url=repo_url, repo_path=repo_dir, commit=instance["base_commit"]
)

code_index = CodeIndex.from_index_name(instance["instance_id"], index_store_dir=index_store_dir, file_repo=repo)
code_index = CodeIndex.from_index_name(
instance["instance_id"], index_store_dir=index_store_dir, file_repo=repo
)

return Workspace(
file_repo=repo,
Expand Down
11 changes: 7 additions & 4 deletions moatless/index/code_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,12 @@ def from_url(cls, url: str, persist_dir: str, file_repo: FileRepository):
)

@classmethod
def from_index_name(cls, index_name: str, file_repo: FileRepository, index_store_dir: Optional[str] = None):
def from_index_name(
cls,
index_name: str,
file_repo: FileRepository,
index_store_dir: Optional[str] = None,
):
if not index_store_dir:
index_store_dir = os.getenv("INDEX_STORE_DIR")

Expand All @@ -167,9 +172,7 @@ def from_index_name(cls, index_name: str, file_repo: FileRepository, index_store
return cls.from_url(store_url, persist_dir, file_repo)

def dict(self):
return {
"index_name": self._index_name
}
return {"index_name": self._index_name}

def search(
self,
Expand Down
77 changes: 45 additions & 32 deletions moatless/loop.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,9 +102,7 @@ def __init__(
self._metadata = metadata

@classmethod
def from_trajectory_file(
cls, trajectory_path: str, **kwargs
):
def from_trajectory_file(cls, trajectory_path: str, **kwargs):
trajectory = Trajectory.load(trajectory_path)
transitions = trajectory.transitions
workspace = Workspace.from_dict(trajectory.workspace)
Expand All @@ -126,18 +124,22 @@ def retry_from_transition(
):
self.revert_to_transition(transition_id)
# TODO: I'm using only state params as an easy way to test out changes. Need to think about a better way to do this.
self._transition_rules._state_params.update(state_params)
self._transition_rules.state_params.update(state_params)

# TODO: DRY
while self.is_running():
try:
self._run()
except Exception as e:
logger.warning(f"Failed to run loop. Error: {e}")
logger.warning(
f"{self.transition_name}: Failed to run loop. Error: {e}"
)
raise

if self.retries() > self._max_retries:
logger.warning(f"Max retries reached ({self._max_retries}). Exiting.")
logger.warning(
f"{self.transition_name}: Max retries reached ({self._max_retries}). Exiting."
)
self.trajectory.save_info({"error": "Max retries reached."})
return Response(
status="rejected",
Expand All @@ -147,7 +149,7 @@ def retry_from_transition(
total_cost = self.total_cost()
if total_cost > self._max_cost:
logger.warning(
f"Max cost reached ({total_cost} > {self._max_cost}). Exiting."
f"{self.transition_name}: Max cost reached ({total_cost} > {self._max_cost}). Exiting."
)
self.trajectory.save_info({"error": "Max cost reached."})
raise RuntimeError(
Expand Down Expand Up @@ -201,11 +203,15 @@ def run(
try:
self._run()
except Exception as e:
logger.warning(f"Failed to run loop. Error: {e}")
logger.warning(
f"{self.transition_name}: Failed to run loop. Error: {e}"
)
raise

if self.retries() > self._max_retries:
logger.warning(f"Max retries reached ({self._max_retries}). Exiting.")
logger.warning(
f"{self.transition_name}: Max retries reached ({self._max_retries}). Exiting."
)
self.trajectory.save_info({"error": "Max retries reached."})
return Response(
status="rejected",
Expand All @@ -215,7 +221,7 @@ def run(
total_cost = self.total_cost()
if total_cost > self._max_cost:
logger.warning(
f"Max cost reached ({total_cost} > {self._max_cost}). Exiting."
f"{self.transition_name}: Max cost reached ({total_cost} > {self._max_cost}). Exiting."
)
self.trajectory.save_info({"error": "Max cost reached."})
raise RuntimeError(
Expand Down Expand Up @@ -319,7 +325,7 @@ def revert_to_transition(self, transition_id: int):
raise ValueError("Invalid state index for reversion")

def transition_to(self, new_state: AgenticState):
logger.info(f"Transitioning from {self.state.name} to {new_state.name}")
self.log_info(f"Transitioning from {self.state.name} to {new_state.name}")

if self.transition_count() > self._max_transitions:
new_state = Rejected(message="Max transitions exceeded.")
Expand Down Expand Up @@ -365,7 +371,7 @@ def get_previous_transitions(self, state: AgenticState | None):

parent_transition = parent_transition.parent

logger.info(
self.log_info(
f"Found {len(previous_transitions)} previous transitions for {state.name if state else 'all states'}"
)

Expand Down Expand Up @@ -472,12 +478,12 @@ def _to_completion_messages(self) -> list[dict]:

def _run(self):
if not self.is_running():
logger.info("Loop is not running.")
self.log_info("Loop is not running.")
return

action, cost, input_tokens, output_tokens = self._next_action()

logger.info(f"{self.state.name}: Received new action {action.action_name}.")
self.log_info(f"Received new action {action.action_name}.")
response = self.state.handle_action(action)

self._current_transition.actions.append(
Expand All @@ -493,13 +499,13 @@ def _run(self):
self.trajectory.save_transition(self._current_transition)

if not response.trigger:
logger.info(
self.log_info(
f"{self.state.name}: No transition found. Staying in the same state."
)
return

if response.trigger == "retry":
logger.info(f"{self.state.name}: Retry requested. {response.retry_message}")
self.log_info(f"Retry requested. {response.retry_message}")
return

try:
Expand Down Expand Up @@ -527,7 +533,7 @@ def _run(self):
else:
self._rejections = 0

logger.info(f"{self.state.name}: Transitioning to {next_state.name}")
self.log_info(f"Transitioning to {next_state.name}")
self.transition_to(next_state)

@property
Expand All @@ -553,41 +559,36 @@ def _next_mock_action(
return None

if self._reset_mocks_at_state and self.state.name == self._reset_mocks_at_state:
logger.info(f"Resetting mocked actions at state {self.state.name}")
self.log_info(f"Resetting mocked actions at state {self.state.name}")
self._mocked_actions = []
return None

action = self._mocked_actions.pop(0)

if self.state.action_type():
try:
logger.info(
f"{self.state.name} Return mocked response with type {self.state.action_type().__name__} ({len(self._mocked_actions)} left)."
self.log_info(
f"Return mocked response with type {self.state.action_type().__name__} ({len(self._mocked_actions)} left)."
)
return self.state.action_type().model_validate(action)

except Exception:
logger.error(
f"Failed to parse {action} to {self.state.action_type().__name__} in state {self.state.name}"
f"{self.transition_name}: Failed to parse {action} to {self.state.action_type().__name__} in state {self.state.name}"
)
raise
elif "content" in action:
logger.info(
f"{self.state.name} Return mocked response ({len(self._mocked_actions)} left)."
)
self.log_info(f"Return mocked response ({len(self._mocked_actions)} left).")
return Content(content=action["content"])


else:
raise ValueError(f"Mocked action {action} does not have 'content' field.")

def _next_action(
self,
) -> tuple[ActionRequest, Optional[float], Optional[int], Optional[int]]:
messages = self._to_completion_messages()
logger.info(
f"{self.state.name} Create completion with {len(messages)} messages"
)
self.log_info(f"Create completion with {len(messages)} messages")

if self._verify_state_func:
self._verify_state_func(self.state)
Expand All @@ -605,7 +606,7 @@ def _next_action(
if self._max_message_tokens and tokens > self._max_message_tokens:
raise ValueError(f"Too many tokens in the new message: {tokens}")

logger.info(f"{self.state.name}: Do completion request to {self.state.model}")
self.log_info(f"Do completion request to {self.state.model}")

if self.state.model.startswith("claude") and self.state.action_type():
try:
Expand All @@ -625,8 +626,8 @@ def _next_action(
)
)

logger.info(
f"{self.state.name}: Input tokens: {completion_response.usage.input_tokens}, Output tokens: {completion_response.usage.output_tokens}"
self.log_info(
f"Input tokens: {completion_response.usage.input_tokens}, Output tokens: {completion_response.usage.output_tokens}"
)
(
prompt_tokens_cost_usd_dollar,
Expand Down Expand Up @@ -690,7 +691,7 @@ def _next_action(
model="claude-3-5-sonnet-20240620",
)
except Exception as e:
logger.info(f"Error calculating completion cost: {e}")
self.log_info(f"Error calculating completion cost: {e}")
cost = 0

self._log_prompt(
Expand Down Expand Up @@ -763,6 +764,18 @@ def _log_prompt(
f.write("\n\n# Error\n")
f.write(f"\n```\n{error}\n```\n")

def log_info(self, message: str):
    """Log ``message`` at INFO level, prefixed with the current transition name.

    Args:
        message: Text to log; it is emitted as ``"<transition_name>: <message>"``.
    """
    prefix = self.transition_name
    logger.info(f"{prefix}: {message}")

@property
def transition_name(self):
    """Label for the current transition, used as a prefix in log messages.

    Returns:
        ``"<state name>:<transition id>"`` built from ``_current_transition``,
        or the literal ``"No transition"`` when ``_current_transition`` is falsy.
    """
    current = self._current_transition
    # Guard clause: nothing to describe until a transition has been set.
    if not current:
        return "No transition"
    return f"{current.state.name}:{current.id}"


def generate_call_id():
prefix = "call_"
Expand Down
13 changes: 2 additions & 11 deletions moatless/repository/git.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,7 @@

from moatless.repository.file import FileRepository
from moatless.settings import Settings
from moatless.utils.repo import (
maybe_clone,
checkout_commit
)
from moatless.utils.repo import maybe_clone, checkout_commit

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -40,14 +37,8 @@ def __init__(
self._current_commit = self._repo.head.commit.hexsha
self._initial_commit = self._current_commit


@classmethod
def from_repo(
cls,
git_repo_url: str,
repo_path: str,
commit: Optional[str] = None
):
def from_repo(cls, git_repo_url: str, repo_path: str, commit: Optional[str] = None):
logger.info(
f"Create GitRepository for {git_repo_url} with commit {commit} on path {repo_path} "
)
Expand Down
8 changes: 4 additions & 4 deletions moatless/workspace.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,9 +88,7 @@ def from_dirs(
)

@classmethod
def from_dict(cls,
data: dict,
**kwargs):
def from_dict(cls, data: dict, **kwargs):
if "repository" not in data:
raise ValueError("Missing repository key")

Expand All @@ -111,7 +109,9 @@ def from_dict(cls,
file_context.load_files_from_dict(data["file_context"].get("files", []))

if data.get("code_index", {}).get("index_name"):
code_index = CodeIndex.from_index_name(data["code_index"].get("index_name"), file_repo=file_repo)
code_index = CodeIndex.from_index_name(
data["code_index"].get("index_name"), file_repo=file_repo
)
else:
code_index = None

Expand Down
13 changes: 7 additions & 6 deletions tests/integration_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@


@pytest.mark.llm_integration
def test_save_and_load_trajectory():
def test_run_and_reload_django_16379():
instance = load_instance("django__django-16379")
workspace = create_workspace(instance)

Expand All @@ -48,7 +48,6 @@ def test_save_and_load_trajectory():
)

response = loop.run(message=instance["problem_statement"])

print("Response")
print(response)

Expand All @@ -60,10 +59,7 @@ def test_save_and_load_trajectory():
"django/core/cache/backends/filebased.py", "FileBasedCache.has_key"
)

saved_loop = AgenticLoop.from_trajectory_file(
trajectory_path=trajectory_path,
transitions=search_and_code_transitions(global_params=global_params),
)
saved_loop = AgenticLoop.from_trajectory_file(trajectory_path=trajectory_path)

saved_response = saved_loop.run(message=instance["problem_statement"])

Expand Down Expand Up @@ -121,6 +117,9 @@ def test_different_edit_models():
"django/core/cache/backends/filebased.py", "FileBasedCache.has_key"
)

first_commit = loop.workspace.file_repo._current_commit
assert first_commit != loop.workspace.file_repo._initial_commit

# Reverts to PlanToCode state and set LLM to GPT-4o-mini in the EditCode state
response_mini = loop.retry_from_transition(
transition_id=4, # PlanToCode
Expand All @@ -137,3 +136,5 @@ def test_different_edit_models():
diff = loop.workspace.file_repo.diff()
print("Diff")
print(diff)

assert loop.workspace.file_repo._current_commit != first_commit

0 comments on commit cb2be6b

Please sign in to comment.