Skip to content

Commit

Permalink
Fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
aorwall committed Jan 18, 2025
1 parent cd3c41b commit 078314a
Show file tree
Hide file tree
Showing 67 changed files with 1,942 additions and 2,894 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ Before running the full evaluation, you can verify your setup using the integrat

```bash
# Run a single model test
poetry run scripts/run_integration_tests.py --model claude-3-5-sonnet-20241022
poetry run python -m moatless.validation.validate_simple_code_flow --model claude-3-5-sonnet-20241022
```

The script will run the model against a sample SWE-Bench instance
Expand Down
19 changes: 13 additions & 6 deletions moatless/actions/action.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from pydantic import BaseModel, ConfigDict

from moatless.actions.schema import ActionArguments, Observation, RewardScaleEntry, FewShotExample
from moatless.completion.base import BaseCompletionModel
from moatless.file_context import FileContext
from moatless.index import CodeIndex
from moatless.repository.repository import Repository
Expand All @@ -22,9 +23,6 @@ class Action(BaseModel, ABC):

model_config = ConfigDict(arbitrary_types_allowed=True)

def __init__(self, **data):
super().__init__(**data)

def execute(
self,
args: ActionArguments,
Expand Down Expand Up @@ -182,9 +180,13 @@ def model_validate(
) -> "Action":
if isinstance(obj, dict):
obj = obj.copy()
action_class_path = obj.pop("action_class", None)

if action_class_path:
if obj.get("action_class"):
action_class_path = obj["action_class"]

if action_class_path == "moatless.actions.edit":
action_class_path = "moatless.actions.claude_text_editor"

module_name, class_name = action_class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
action_class = getattr(module, class_name)
Expand All @@ -196,9 +198,14 @@ def model_validate(
if runtime and hasattr(action_class, "_runtime"):
obj["runtime"] = runtime

if "completion_model" in obj:
obj["completion_model"] = BaseCompletionModel.model_validate(obj["completion_model"])

return action_class(**obj)
else:
raise ValueError(f"action_class is required in {obj}")

return cls(**obj)
return super().model_validate(obj)

def model_dump(self, **kwargs) -> Dict[str, Any]:
dump = super().model_dump(**kwargs)
Expand Down
2 changes: 1 addition & 1 deletion moatless/actions/append_string.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,7 +112,7 @@ def execute(
# Normal append logic
if file_text:
file_text = file_text.rstrip("\n")
new_str = f"\n\n{new_str.lstrip('\n')}"
new_str = "\n\n" + new_str.lstrip("\n")
else:
new_str = new_str.lstrip("\n")

Expand Down
3 changes: 1 addition & 2 deletions moatless/actions/claude_text_editor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,7 @@
from pathlib import Path
from typing import Literal, Optional, List

from litellm import ConfigDict
from pydantic import Field, PrivateAttr, field_validator
from pydantic import Field, PrivateAttr, field_validator, ConfigDict

from moatless.actions import RunTests, CreateFile, ViewCode
from moatless.actions.action import Action
Expand Down
12 changes: 10 additions & 2 deletions moatless/actions/run_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,16 @@ def model_validate(cls, obj: Any) -> "RunTests":
if isinstance(obj, dict):
obj = obj.copy()
repository = obj.pop("repository")
code_index = obj.pop("code_index")
runtime = obj.pop("runtime")
if "code_index" in obj:
code_index = obj.pop("code_index")
else:
code_index = None

if "runtime" in obj:
runtime = obj.pop("runtime")
else:
runtime = None

return cls(code_index=code_index, repository=repository, runtime=runtime, **obj)
return super().model_validate(obj)

Expand Down
8 changes: 6 additions & 2 deletions moatless/actions/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -97,10 +97,14 @@ def model_validate(cls, obj: Any) -> "ActionArguments":
if isinstance(obj, dict):
obj = obj.copy()
action_args_class_path = obj.pop("action_args_class", None)
if action_args_class_path == "moatless.actions.request_context.RequestMoreContextArgs":
action_args_class_path = "moatless.actions.view_code.ViewCodeArgs"

if action_args_class_path:
if action_args_class_path == "moatless.actions.request_context.RequestMoreContextArgs":
action_args_class_path = "moatless.actions.view_code.ViewCodeArgs"

if action_args_class_path.startswith("moatless.actions.edit"):
action_args_class_path = "moatless.actions.claude_text_editor.EditActionArguments"

module_name, class_name = action_args_class_path.rsplit(".", 1)
module = importlib.import_module(module_name)
action_args_class = getattr(module, class_name)
Expand Down
7 changes: 4 additions & 3 deletions moatless/actions/search_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,9 @@ def __init__(
self,
repository: Repository = None,
code_index: CodeIndex | None = None,
completion_model: BaseCompletionModel = None,
**data,
):
super().__init__(completion_model=completion_model, **data)
super().__init__(**data)
self._repository = repository
self._code_index = code_index

Expand Down Expand Up @@ -413,5 +412,7 @@ def model_validate(cls, obj: Any) -> "SearchBaseAction":
obj = obj.copy()
repository = obj.pop("repository")
code_index = obj.pop("code_index")
return cls(code_index=code_index, repository=repository, **obj)
completion_model = BaseCompletionModel.model_validate(obj.pop("completion_model"))

return cls(code_index=code_index, repository=repository, completion_model=completion_model, **obj)
return super().model_validate(obj)
5 changes: 3 additions & 2 deletions moatless/actions/string_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -583,9 +583,10 @@ def find_potential_matches(old_str, new_content):

differences = []
if window.count("\n") != old_str.count("\n"):
window_lines = window.count("\n") + 1
old_str_lines = old_str.count("\n") + 1
differences.append(
f"Line break count differs: found {window.count('\n') + 1} lines, "
f"expected {old_str.count('\n') + 1} lines"
f"Line break count differs: found {window_lines} lines, " f"expected {old_str_lines} lines"
)

# Check for character differences
Expand Down
5 changes: 1 addition & 4 deletions moatless/actions/verified_finish.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,7 @@ class VerifiedFinishArgs(ActionArguments):
model_config = ConfigDict(title="Finish")

def to_prompt(self):
return (
f"Finish with reason: {self.finish_reason}\n"
f"Test verification: {self.test_verification}\n"
)
return f"Finish with reason: {self.finish_reason}\n" f"Test verification: {self.test_verification}\n"

def equals(self, other: "ActionArguments") -> bool:
return isinstance(other, VerifiedFinishArgs)
Expand Down
10 changes: 5 additions & 5 deletions moatless/agent/agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import traceback
from typing import List, Type, Dict, Any

from pydantic import BaseModel, Field, PrivateAttr, model_validator, ValidationError
from pydantic import BaseModel, Field, PrivateAttr, model_validator

from moatless.actions.action import Action
from moatless.actions.schema import (
Expand Down Expand Up @@ -59,7 +59,7 @@ def from_agent_settings(cls, agent_settings: AgentSettings, actions: List[Action
actions = [action for action in actions if action.__class__.__name__ in agent_settings.actions]

return cls(
completion=agent_settings.completion_model,
completion_model=agent_settings.completion_model,
system_prompt=agent_settings.system_prompt,
actions=actions,
)
Expand All @@ -75,9 +75,9 @@ def set_actions(self, actions: List[Action]):
def verify_actions(self) -> "ActionAgent":
for action in self.actions:
if not isinstance(action, Action):
raise ValidationError(f"Invalid action type: {type(action)}. Expected Action subclass.")
raise ValueError(f"Invalid action type: {type(action)}. Expected Action subclass.")
if not hasattr(action, "args_schema"):
raise ValidationError(f"Action {action.__class__.__name__} is missing args_schema attribute")
raise ValueError(f"Action {action.__class__.__name__} is missing args_schema attribute")
return self

def run(self, node: Node):
Expand Down Expand Up @@ -115,7 +115,7 @@ def run(self, node: Node):
usage=e.accumulated_usage if hasattr(e, "accumulated_usage") else None,
)
else:
logger.error(f"Node{node.node_id}: Build action failed with error: {e}")
logger.exception(f"Node{node.node_id}: Build action failed with error")

if isinstance(e, CompletionRejectError):
return
Expand Down
25 changes: 14 additions & 11 deletions moatless/agent/code_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def create(

# Create action completion model by cloning the input model with JSON response format
action_completion_model = completion_model.clone(response_format=action_completion_format)
action_completion_model.message_cache = False

supports_anthropic_computer_use = completion_model.model.startswith("claude-3-5-sonnet")

Expand Down Expand Up @@ -164,24 +165,24 @@ def create_base_actions(
SemanticSearch(
code_index=code_index,
repository=repository,
completion_model=completion_model,
completion_model=completion_model.clone(),
),
FindClass(
code_index=code_index,
repository=repository,
completion_model=completion_model,
completion_model=completion_model.clone(),
),
FindFunction(
code_index=code_index,
repository=repository,
completion_model=completion_model,
completion_model=completion_model.clone(),
),
FindCodeSnippet(
code_index=code_index,
repository=repository,
completion_model=completion_model,
completion_model=completion_model.clone(),
),
ViewCode(repository=repository, completion_model=completion_model),
ViewCode(repository=repository, completion_model=completion_model.clone()),
]


Expand All @@ -192,7 +193,7 @@ def create_edit_code_actions(
runtime: RuntimeEnvironment | None = None,
) -> List[Action]:
"""Create a list of simple code modification actions."""
actions = create_base_actions(repository, code_index, completion_model)
actions = create_base_actions(repository, code_index, completion_model.clone())

edit_actions = [
StringReplace(repository=repository, code_index=code_index),
Expand All @@ -215,12 +216,12 @@ def create_claude_coding_actions(
completion_model: BaseCompletionModel | None = None,
runtime: RuntimeEnvironment | None = None,
) -> List[Action]:
actions = create_base_actions(repository, code_index, completion_model)
actions = create_base_actions(repository, code_index, completion_model.clone())
actions.append(
ClaudeEditTool(
code_index=code_index,
repository=repository,
completion_model=completion_model,
completion_model=completion_model.clone(),
),
)
actions.append(ListFiles())
Expand All @@ -235,7 +236,9 @@ def create_all_actions(
code_index: CodeIndex | None = None,
completion_model: BaseCompletionModel | None = None,
) -> List[Action]:
actions = create_base_actions(repository, code_index, completion_model)
actions.extend(create_edit_code_actions(repository, code_index, completion_model))
actions.append(ClaudeEditTool(code_index=code_index, repository=repository))
actions = create_base_actions(repository, code_index, completion_model.clone())
actions.extend(create_edit_code_actions(repository, code_index, completion_model.clone()))
actions.append(
ClaudeEditTool(code_index=code_index, repository=repository, completion_model=completion_model.clone())
)
return actions
4 changes: 2 additions & 2 deletions moatless/agent/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,14 @@

from pydantic import BaseModel, Field

from moatless.completion import BaseCompletionModel
from moatless.schema import CompletionModelSettings
from moatless.schema import MessageHistoryType


class AgentSettings(BaseModel):
model_config = {"frozen": True}

completion_model: BaseCompletionModel = Field(
completion_model: CompletionModelSettings = Field(
..., description="Completion model to be used for generating completions"
)
system_prompt: Optional[str] = Field(None, description="System prompt to be used for generating completions")
Expand Down
Loading

0 comments on commit 078314a

Please sign in to comment.