Skip to content

Commit

Permalink
Fix code actions to work at cell level
Browse files Browse the repository at this point in the history
  • Loading branch information
dhruvmanila committed Nov 7, 2023
1 parent 2b4d6ab commit 09587d4
Showing 1 changed file with 112 additions and 88 deletions.
200 changes: 112 additions & 88 deletions ruff_lsp/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@
MarkupContent,
MarkupKind,
MessageType,
NotebookCell,
NotebookCellKind,
NotebookDocument,
NotebookDocumentSyncOptions,
Expand Down Expand Up @@ -223,11 +224,14 @@ class DocumentKind(enum.Enum):
"""The kind of document."""

Text = enum.auto()
"""A Python file or a cell in a Notebook Document."""
"""A Python file."""

Notebook = enum.auto()
"""A Notebook Document."""

Cell = enum.auto()
"""A cell in a Notebook Document."""


@dataclass(frozen=True)
class Document:
Expand Down Expand Up @@ -261,14 +265,52 @@ def from_notebook_document(cls, notebook_document: NotebookDocument) -> Self:
version=notebook_document.version,
)

@classmethod
def from_notebook_cell(cls, notebook_cell: NotebookCell) -> Self:
"""Create a `Document` from the given Notebook cell."""
return cls(
uri=notebook_cell.document,
path=_uri_to_fs_path(notebook_cell.document),
kind=DocumentKind.Cell,
source=_create_single_cell_notebook_json(
LSP_SERVER.workspace.get_text_document(notebook_cell.document).source
),
version=None,
)

@classmethod
def from_cell_or_text_uri(cls, uri: str) -> Self:
"""Create a `Document` representing either a Python file or a Notebook cell from
the given URI.
The function will try to get the Notebook cell first, and if there's no cell
with the given URI, it will fallback to the text document.
"""
notebook_document = LSP_SERVER.workspace.get_notebook_document(cell_uri=uri)
if notebook_document is not None:
notebook_cell = next(
(
notebook_cell
for notebook_cell in notebook_document.cells
if notebook_cell.document == uri
),
None,
)
if notebook_cell is not None:
return cls.from_notebook_cell(notebook_cell)

# Fall back to the Text Document representing a Python file.
text_document = LSP_SERVER.workspace.get_text_document(uri)
return cls.from_text_document(text_document)

@classmethod
def from_uri(cls, uri: str) -> Self:
"""Create a `Document` representing either a Python file or a Notebook from
the given URI.
The URI can be a file URI, a notebook URI, or a cell URI. The function will
try to get the notebook document first, and if that fails, it will fallback
to the text document.
try to get the notebook document first, and if there's no notebook document
with the given URI, it will fallback to the text document.
"""
# First, try to get the Notebook Document assuming the URI is a Cell URI.
notebook_document = LSP_SERVER.workspace.get_notebook_document(cell_uri=uri)
Expand All @@ -289,14 +331,6 @@ def is_stdlib_file(self) -> bool:
"""Return True if the document belongs to standard library."""
return utils.is_stdlib_file(self.path)

def is_notebook_file(self) -> bool:
"""Return True if the document belongs to a Notebook or a cell in a Notebook."""
return self.kind is DocumentKind.Notebook or self.path.endswith(".ipynb")

def is_notebook_cell(self) -> bool:
"""Return True if the document belongs to a cell in a Notebook."""
return self.kind is DocumentKind.Text and self.path.endswith(".ipynb")


SourceValue = Union[str, List[str]]

Expand Down Expand Up @@ -521,9 +555,7 @@ async def _did_change_or_save_notebook(


async def _lint_document_impl(document: Document) -> list[Diagnostic]:
result = await _run_check_on_document_source(
DocumentSource(path=document.path, text=document.source)
)
result = await _run_check_on_document(document)
if result is None:
return []
return _parse_output(result.stdout) if result.stdout else []
Expand Down Expand Up @@ -775,18 +807,10 @@ class LegacyFix(TypedDict):
async def code_action(params: CodeActionParams) -> list[CodeAction] | None:
"""LSP handler for textDocument/codeAction request.
There are two scopes for code actions:
- Source: The whole file.
- QuickFix: A single diagnostic.
For source level code actions, we create a `Document` from the URI such that it
represents either a Python file or a Notebook. On the other hand, for quick fix
code actions, we use the `TextDocument` abstraction from `pygls` which represents
either a Python file or a cell in a Notebook.
For a Notebook cell, the `TextDocument` works because we don't need to know the
content of the entire Notebook. The fix is already available in the code action
context.
Code actions work at a text document level which is either a Python file or a
cell in a Notebook document. The function will try to get the Notebook cell
first, and if there's no cell with the given URI, it will fallback to the text
document.
"""
document_path = _uri_to_fs_path(params.text_document.uri)
if utils.is_stdlib_file(document_path):
Expand All @@ -807,7 +831,8 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None:
and kind in params.context.only
):
workspace_edit = await _fix_document_impl(
Document.from_uri(params.text_document.uri), only="I001"
Document.from_cell_or_text_uri(params.text_document.uri),
only="I001",
)
if workspace_edit:
return [
Expand All @@ -834,7 +859,7 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None:
and kind in params.context.only
):
workspace_edit = await _fix_document_impl(
Document.from_uri(params.text_document.uri)
Document.from_cell_or_text_uri(params.text_document.uri)
)
if workspace_edit:
return [
Expand Down Expand Up @@ -968,7 +993,8 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None:
)
else:
workspace_edit = await _fix_document_impl(
Document.from_uri(params.text_document.uri), only="I001"
Document.from_cell_or_text_uri(params.text_document.uri),
only="I001",
)
if workspace_edit:
actions.append(
Expand Down Expand Up @@ -998,7 +1024,7 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None:
)
else:
workspace_edit = await _fix_document_impl(
Document.from_uri(params.text_document.uri)
Document.from_cell_or_text_uri(params.text_document.uri)
)
if workspace_edit:
actions.append(
Expand All @@ -1024,7 +1050,7 @@ async def code_action(params: CodeActionParams) -> list[CodeAction] | None:
async def resolve_code_action(params: CodeAction) -> CodeAction:
"""LSP handler for codeAction/resolve request."""
# We set the `data` field to the document URI during codeAction request.
document = Document.from_uri(cast(str, params.data))
document = Document.from_cell_or_text_uri(cast(str, params.data))

settings = _get_settings_by_document(document.path)

Expand Down Expand Up @@ -1069,9 +1095,7 @@ async def apply_organize_imports(arguments: tuple[TextDocument]):
async def apply_format(arguments: tuple[TextDocument]):
uri = arguments[0]["uri"]
document = Document.from_uri(uri)
results = await _run_format_on_document_source(
DocumentSource(path=document.path, text=document.source)
)
results = await _run_format_on_document(document)
workspace_edit = _result_to_workspace_edit(document, results)
if workspace_edit is None:
return
Expand All @@ -1083,37 +1107,14 @@ async def format_document(params: DocumentFormattingParams) -> list[TextEdit] |
# For a Jupyter Notebook, this request can only format a single cell as the
# request itself can only act on a text document. A cell in a Notebook is
# represented as a text document.
document = Document.from_text_document(
LSP_SERVER.workspace.get_text_document(params.text_document.uri)
)
if document.is_notebook_cell():
result = await _run_format_on_document_source(
DocumentSource(
path=document.path,
text=_create_single_cell_notebook_json(document.source),
)
)
if result is None:
return None
document = Document.from_cell_or_text_uri(params.text_document.uri)
result = await _run_format_on_document(document)
if result is None:
return None

output_notebook = cast(Notebook, json.loads(result.stdout.decode("utf-8")))
# The input notebook contained only one cell, so the output notebook should
# also contain only one cell.
output_cell = next(iter(output_notebook["cells"]), None)
if output_cell is None or output_cell["cell_type"] != "code":
log_warning(
f"Unexpected output when formatting a notebook cell: {output_notebook}"
)
return None
return _fixed_source_to_edits(
original_source=document.source, fixed_source=output_cell["source"]
)
if document.kind is DocumentKind.Cell:
return _result_single_cell_notebook_to_edits(document, result)
else:
result = await _run_format_on_document_source(
DocumentSource(path=document.path, text=document.source)
)
if result is None:
return None
return _fixed_source_to_edits(
original_source=document.source, fixed_source=result.stdout.decode("utf-8")
)
Expand All @@ -1124,8 +1125,8 @@ async def _fix_document_impl(
*,
only: str | None = None,
) -> WorkspaceEdit | None:
result = await _run_check_on_document_source(
DocumentSource(path=document.path, text=document.source),
result = await _run_check_on_document(
document,
extra_args=["--fix"],
only=only,
)
Expand Down Expand Up @@ -1185,10 +1186,43 @@ def _result_to_workspace_edit(
)

return WorkspaceEdit(document_changes=list(cell_document_changes))
elif document.kind is DocumentKind.Cell:
text_edits = _result_single_cell_notebook_to_edits(document, result)
if text_edits is None:
return None
return WorkspaceEdit(
document_changes=[
_create_text_document_edit(document.uri, document.version, text_edits)
]
)
else:
assert_never(document.kind)


def _result_single_cell_notebook_to_edits(
document: Document, result: RunResult
) -> list[TextEdit] | None:
"""Converts a run result to a list of TextEdits.
The result is expected to be a single cell Notebook Document.
"""
output_notebook = cast(Notebook, json.loads(result.stdout.decode("utf-8")))
# The input notebook contained only one cell, so the output notebook should
# also contain only one cell.
output_cell = next(iter(output_notebook["cells"]), None)
if output_cell is None or output_cell["cell_type"] != "code":
log_warning(
f"Unexpected output working with a notebook cell: {output_notebook}"
)
return None
# We can't use the `document.source` here because it's in the Notebook format
# i.e., it's a JSON string containing a single cell with the source.
original_source = LSP_SERVER.workspace.get_text_document(document.uri).source
return _fixed_source_to_edits(
original_source=original_source, fixed_source=output_cell["source"]
)


def _fixed_source_to_edits(
*, original_source: str, fixed_source: str | list[str]
) -> list[TextEdit]:
Expand Down Expand Up @@ -1562,28 +1596,18 @@ def _executable_version(executable: str) -> Version:
return EXECUTABLE_VERSIONS[executable].version


class DocumentSource(NamedTuple):
"""The source of a document."""

path: str
"""The path to the document."""

text: str
"""The text of the document."""


async def _run_check_on_document_source(
source: DocumentSource,
async def _run_check_on_document(
document: Document,
*,
extra_args: Sequence[str] = [],
only: str | None = None,
) -> RunResult | None:
"""Runs the Ruff `check` subcommand on the given document source."""
if utils.is_stdlib_file(source.path):
log_warning(f"Skipping standard library file: {source.path}")
if document.is_stdlib_file():
log_warning(f"Skipping standard library file: {document.path}")
return None

settings = _get_settings_by_document(source.path)
settings = _get_settings_by_document(document.path)

executable = _find_ruff_binary(settings, VERSION_REQUIREMENT_LINTER)
argv: list[str] = CHECK_ARGS + list(extra_args)
Expand Down Expand Up @@ -1613,30 +1637,30 @@ async def _run_check_on_document_source(
argv += ["--extend-select", only]

# Provide the document filename.
argv += ["--stdin-filename", source.path]
argv += ["--stdin-filename", document.path]

return await run_path(
executable.path,
argv,
cwd=settings["cwd"],
source=source.text,
source=document.source,
)


async def _run_format_on_document_source(source: DocumentSource) -> RunResult | None:
async def _run_format_on_document(document: Document) -> RunResult | None:
"""Runs the Ruff `format` subcommand on the given document source."""
if utils.is_stdlib_file(source.path):
log_warning(f"Skipping standard library file: {source.path}")
if document.is_stdlib_file():
log_warning(f"Skipping standard library file: {document.path}")
return None

settings = _get_settings_by_document(source.path)
settings = _get_settings_by_document(document.path)
executable = _find_ruff_binary(settings, VERSION_REQUIREMENT_FORMATTER)
argv: list[str] = [
"format",
"--force-exclude",
"--quiet",
"--stdin-filename",
source.path,
document.path,
]

for arg in settings.get("format", {}).get("args", []):
Expand All @@ -1649,7 +1673,7 @@ async def _run_format_on_document_source(source: DocumentSource) -> RunResult |
executable.path,
argv,
cwd=settings["cwd"],
source=source.text,
source=document.source,
)


Expand Down

0 comments on commit 09587d4

Please sign in to comment.