Support Java in ingestion flow
aorwall committed Aug 2, 2024
1 parent 7ccd42a commit 83d56b1
Showing 3 changed files with 84 additions and 44 deletions.
5 changes: 4 additions & 1 deletion moatless/codeblocks/parser/create.py
@@ -1,13 +1,16 @@
 from moatless.codeblocks.parser.parser import CodeParser
 from moatless.codeblocks.parser.python import PythonParser
+from moatless.codeblocks.parser.java import JavaParser


 def is_supported(language: str) -> bool:
-    return language and language in ["python", "java", "typescript", "javascript"]
+    return language and language in ["python", "java"]


 def create_parser(language: str, **kwargs) -> CodeParser | None:
     if language == "python":
         return PythonParser(**kwargs)
+    elif language == "java":
+        return JavaParser(**kwargs)

     raise NotImplementedError(f"Language {language} is not supported.")
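
Note: a minimal usage sketch of the updated factory; the Java snippet and file name are illustrative only, and parse(content, file_path=...) follows the call site in epic_split.py below.

from moatless.codeblocks.parser.create import create_parser, is_supported

# "java" is now a supported language; anything else still raises
# NotImplementedError from create_parser.
assert is_supported("java")
parser = create_parser("java")
codeblock = parser.parse("class Greeter { void hello() {} }", file_path="Greeter.java")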
98 changes: 64 additions & 34 deletions moatless/index/code_index.py
@@ -5,6 +5,7 @@
 import os
 import shutil
 import tempfile
+from typing import Optional

 import requests
 from llama_index.core import SimpleDirectoryReader
@@ -54,16 +55,18 @@ class CodeIndex:
     def __init__(
         self,
         file_repo: FileRepository,
+        index_name: Optional[str] = None,
         vector_store: BasePydanticVectorStore | None = None,
         docstore: DocumentStore | None = None,
         embed_model: BaseEmbedding | None = None,
-        blocks_by_class_name: dict | None = None,
-        blocks_by_function_name: dict | None = None,
+        blocks_by_class_name: Optional[dict] = None,
+        blocks_by_function_name: Optional[dict] = None,
         settings: IndexSettings | None = None,
         max_results: int = 25,
         max_hits_without_exact_match: int = 100,
         max_exact_results: int = 5,
     ):
+        self._index_name = index_name
         self._settings = settings or IndexSettings()

         self.max_results = max_results
@@ -79,6 +82,11 @@ def __init__(
         self._vector_store = vector_store or default_vector_store(self._settings)
         self._docstore = docstore or SimpleDocumentStore()

+        logger.info(f"Initiated CodeIndex {self._index_name} with:\n"
+                    f" * {len(self._blocks_by_class_name)} classes\n"
+                    f" * {len(self._blocks_by_function_name)} functions\n"
+                    f" * {len(self._docstore.docs)} vectors\n")
+
     @classmethod
     def from_persist_dir(cls, persist_dir: str, file_repo: FileRepository, **kwargs):
         vector_store = SimpleFaissVectorStore.from_persist_dir(persist_dir)
@@ -131,30 +139,42 @@ def from_url(cls, url: str, persist_dir: str, file_repo: FileRepository):
             raise e

         logger.info(f"Downloaded existing index from {url}.")
+        return cls.from_persist_dir(persist_dir, file_repo)

-        vector_store = SimpleFaissVectorStore.from_persist_dir(persist_dir)
-        docstore = SimpleDocumentStore.from_persist_dir(persist_dir)
+    @classmethod
+    def from_index_name(
+        cls,
+        index_name: str,
+        file_repo: FileRepository,
+        index_store_dir: Optional[str] = None,
+    ):
+        if not index_store_dir:
+            index_store_dir = os.getenv("INDEX_STORE_DIR")

-        if not os.path.exists(os.path.join(persist_dir, "settings.json")):
-            # TODO: Remove this when new indexes are uploaded
-            settings = IndexSettings(embed_model="voyage-code-2")
+        persist_dir = os.path.join(index_store_dir, index_name)
+        if os.path.exists(persist_dir):
+            logger.info(f"Loading existing index {index_name} from {persist_dir}.")
+            return cls.from_persist_dir(persist_dir, file_repo=file_repo)

+        if os.getenv("INDEX_STORE_URL"):
+            index_store_url = os.getenv("INDEX_STORE_URL")
         else:
-            settings = IndexSettings.from_persist_dir(persist_dir)
+            index_store_url = "https://stmoatless.blob.core.windows.net/indexstore/20240522-voyage-code-2"

-        return cls(
-            file_repo=file_repo,
-            vector_store=vector_store,
-            docstore=docstore,
-            settings=settings,
-        )
+        store_url = os.path.join(index_store_url, f"{index_name}.zip")
+        logger.info(f"Downloading existing index {index_name} from {store_url}.")
+        return cls.from_url(store_url, persist_dir, file_repo)
+
+    def dict(self):
+        return {"index_name": self._index_name}

     def search(
         self,
-        query: str | None = None,
-        code_snippet: str | None = None,
+        query: Optional[str] = None,
+        code_snippet: Optional[str] = None,
         class_names: list[str] = None,
         function_names: list[str] = None,
-        file_pattern: str | None = None,
+        file_pattern: Optional[str] = None,
         max_results: int = 25,
     ) -> SearchCodeResponse:
         if class_names or function_names:
@@ -199,16 +219,16 @@ def search(

     def semantic_search(
         self,
-        query: str | None = None,
-        code_snippet: str | None = None,
+        query: Optional[str] = None,
+        code_snippet: Optional[str] = None,
         class_names: list[str] = None,
         function_names: list[str] = None,
-        file_pattern: str | None = None,
+        file_pattern: Optional[str] = None,
         category: str = "implementation",
         max_results: int = 25,
         max_hits_without_exact_match: int = 100,
         max_exact_results: int = 5,
-        max_spans_per_file: int | None = None,
+        max_spans_per_file: Optional[int] = None,
         exact_match_if_possible: bool = False,
     ) -> SearchCodeResponse:
         if query is None:
@@ -362,7 +382,7 @@ def find_by_name(
         self,
         class_names: list[str] = None,
         function_names: list[str] = None,
-        file_pattern: str | None = None,
+        file_pattern: Optional[str] = None,
         include_functions_in_class: bool = True,
         category: str = "implementation",
     ) -> SearchCodeResponse:
@@ -531,8 +551,8 @@ def _vector_search(
         query: str = "",
         exact_query_match: bool = False,
         category: str = "implementation",
-        file_pattern: str | None = None,
-        exact_content_match: str | None = None,
+        file_pattern: Optional[str] = None,
+        exact_content_match: Optional[str] = None,
     ):
         if file_pattern:
             query += f" file:{file_pattern}"
@@ -647,9 +667,9 @@ def _vector_search(

     def run_ingestion(
         self,
-        repo_path: str | None = None,
+        repo_path: Optional[str] = None,
         input_files: list[str] | None = None,
-        num_workers: int | None = None,
+        num_workers: Optional[int] = None,
     ):
         repo_path = repo_path or self._file_repo.path

@@ -678,14 +698,23 @@ def file_metadata_func(file_path: str) -> dict:
                 "category": category,
             }

-        reader = SimpleDirectoryReader(
-            input_dir=repo_path,
-            file_metadata=file_metadata_func,
-            input_files=input_files,
-            filename_as_id=True,
-            required_exts=[".py"],  # TODO: Shouldn't be hardcoded and filtered
-            recursive=True,
-        )
+        if self._settings and self._settings.language == "java":
+            required_exts = [".java"]
+        else:
+            required_exts = [".py"]
+
+        try:
+            reader = SimpleDirectoryReader(
+                input_dir=repo_path,
+                file_metadata=file_metadata_func,
+                input_files=input_files,
+                filename_as_id=True,
+                required_exts=required_exts,
+                recursive=True,
+            )
+        except Exception as e:
+            logger.exception(f"Failed to create reader with input_dir {repo_path}, input_files {input_files} and required_exts {required_exts}.")
+            raise e

         embed_pipeline = IngestionPipeline(
             transformations=[self._embed_model],
@@ -716,6 +745,7 @@ def index_callback(codeblock: CodeBlock):
         )

         splitter = EpicSplitter(
+            language=self._settings.language,
             min_chunk_size=self._settings.min_chunk_size,
             chunk_size=self._settings.chunk_size,
             hard_token_limit=self._settings.hard_token_limit,
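
Note: taken together, the code_index.py changes let a Java repository be ingested and a persisted index be resolved by name. A rough sketch follows; the import paths for IndexSettings and FileRepository, the repository path, and the language field on IndexSettings are assumptions based on this diff.

import os

from moatless.index.code_index import CodeIndex
from moatless.index.settings import IndexSettings  # assumed module path
from moatless.repository import FileRepository  # assumed module path

repo = FileRepository("/tmp/my-java-repo")  # hypothetical repository

# Build a fresh index; language="java" makes run_ingestion read .java files
# instead of the previously hardcoded .py extension.
index = CodeIndex(
    file_repo=repo,
    index_name="my-java-index",
    settings=IndexSettings(language="java"),
)
index.run_ingestion()

# Resolve a persisted index by name: from_index_name loads it from
# INDEX_STORE_DIR when present, otherwise downloads <name>.zip from
# INDEX_STORE_URL (or the default blob store) via from_url.
os.environ["INDEX_STORE_DIR"] = "/tmp/index-store"
index = CodeIndex.from_index_name("my-java-index", file_repo=repo)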
25 changes: 16 additions & 9 deletions moatless/index/epic_split.py
@@ -1,7 +1,7 @@
 import re
 import time
 from collections.abc import Callable, Sequence
-from typing import Any
+from typing import Any, Optional

 from llama_index.core.bridge.pydantic import Field
 from llama_index.core.callbacks import CallbackManager
@@ -10,6 +10,7 @@
 from llama_index.core.schema import BaseNode, TextNode
 from llama_index.core.utils import get_tokenizer, get_tqdm_iterable

+from moatless.codeblocks import create_parser
 from moatless.codeblocks.codeblocks import CodeBlock, CodeBlockType, PathTree
 from moatless.codeblocks.parser.python import PythonParser
 from moatless.index.code_node import CodeNode
@@ -39,6 +40,10 @@ def count_parent_tokens(codeblock: CodeBlock) -> int:


 class EpicSplitter(NodeParser):
+    language: str = Field(
+        default="python", description="Language of the code blocks to parse."
+    )
+
     text_splitter: TextSplitter = Field(
         description="Text splitter to use for splitting non code documents into nodes."
     )
@@ -74,14 +79,15 @@ class EpicSplitter(NodeParser):

     repo_path: str = Field(default=None, description="Path to the repository.")

-    index_callback: Callable | None = Field(
+    index_callback: Optional[Callable] = Field(
         default=None, description="Callback to call when indexing a code block."
     )

     # _fallback_code_splitter: Optional[TextSplitter] = PrivateAttr() TODO: Implement fallback when tree sitter fails

     def __init__(
         self,
+        language: str = "python",
         chunk_size: int = 750,
         min_chunk_size: int = 100,
         max_chunk_size: int = 1500,
@@ -90,12 +96,12 @@ def __init__(
         include_metadata: bool = True,
         include_prev_next_rel: bool = True,
         text_splitter: TextSplitter | None = None,
-        index_callback: Callable[[CodeBlock], None] | None = None,
-        repo_path: str | None = None,
+        index_callback: Optional[Callable[[CodeBlock], None]] = None,
+        repo_path: Optional[str] = None,
         comment_strategy: CommentStrategy = CommentStrategy.ASSOCIATE,
         # fallback_code_splitter: Optional[TextSplitter] = None,
         include_non_code_files: bool = True,
-        tokenizer: Callable | None = None,
+        tokenizer: Optional[Callable] = None,
         non_code_file_extensions: list[str] | None = None,
         callback_manager: CallbackManager | None = None,
     ) -> None:
@@ -106,6 +112,7 @@ def __init__(
         # self._fallback_code_splitter = fallback_code_splitter

         super().__init__(
+            language=language,
             chunk_size=chunk_size,
             chunk_overlap=0,
             text_splitter=text_splitter or TokenTextSplitter(),
@@ -142,10 +149,10 @@ def _parse_nodes(
             content = node.get_content()

             try:
-                # TODO: Derive language from file extension
                 starttime = time.time_ns()

-                parser = PythonParser(index_callback=self.index_callback)
+                # TODO: Derive language from file extension
+                parser = create_parser(language=self.language, index_callback=self.index_callback)
                 codeblock = parser.parse(content, file_path=file_path)

                 parse_time = time.time_ns() - starttime
@@ -186,7 +193,7 @@ def _parse_nodes(
         return all_nodes

     def _chunk_contents(
-        self, codeblock: CodeBlock | None = None, file_path: str | None = None
+        self, codeblock: CodeBlock | None = None, file_path: Optional[str] = None
     ) -> list[CodeBlockChunk]:
         tokens = codeblock.sum_tokens()
         if tokens == 0:
@@ -221,7 +228,7 @@ def _chunk_contents(
         return self._chunk_block(codeblock, file_path)

     def _chunk_block(
-        self, codeblock: CodeBlock, file_path: str | None = None
+        self, codeblock: CodeBlock, file_path: Optional[str] = None
     ) -> list[CodeBlockChunk]:
         chunks: list[CodeBlockChunk] = []
         current_chunk = []
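
Note: with the new language field, EpicSplitter routes parsing through create_parser instead of always constructing PythonParser. A minimal sketch, assuming the remaining constructor defaults suffice:

from moatless.index.epic_split import EpicSplitter

# language defaults to "python", so existing callers are unaffected;
# passing "java" routes parsing through the new JavaParser.
splitter = EpicSplitter(language="java", chunk_size=750, min_chunk_size=100)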
