Skip to content

Commit

Permalink
pythongh-111201: auto-indentation in _pyrepl (python#119348)
Browse files Browse the repository at this point in the history
Co-authored-by: Łukasz Langa <lukasz@langa.pl>
  • Loading branch information
wiggin15 and ambv authored May 22, 2024
1 parent e9875ec commit cd516cd
Show file tree
Hide file tree
Showing 2 changed files with 179 additions and 58 deletions.
43 changes: 40 additions & 3 deletions Lib/_pyrepl/readline.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ class ReadlineAlikeReader(historical_reader.HistoricalReader, CompletingReader):
# Instance fields
config: ReadlineConfig
more_lines: MoreLinesCallable | None = None
last_used_indentation: str | None = None

def __post_init__(self) -> None:
super().__post_init__()
Expand Down Expand Up @@ -157,6 +158,11 @@ def get_trimmed_history(self, maxlength: int) -> list[str]:
cut = 0
return self.history[cut:]

def update_last_used_indentation(self) -> None:
indentation = _get_first_indentation(self.buffer)
if indentation is not None:
self.last_used_indentation = indentation

# --- simplified support for reading multiline Python statements ---

def collect_keymap(self) -> tuple[tuple[KeySpec, CommandName], ...]:
Expand Down Expand Up @@ -211,6 +217,28 @@ def _get_previous_line_indent(buffer: list[str], pos: int) -> tuple[int, int | N
return prevlinestart, indent


def _get_first_indentation(buffer: list[str]) -> str | None:
indented_line_start = None
for i in range(len(buffer)):
if (i < len(buffer) - 1
and buffer[i] == "\n"
and buffer[i + 1] in " \t"
):
indented_line_start = i + 1
elif indented_line_start is not None and buffer[i] not in " \t\n":
return ''.join(buffer[indented_line_start : i])
return None


def _is_last_char_colon(buffer: list[str]) -> bool:
i = len(buffer)
while i > 0:
i -= 1
if buffer[i] not in " \t\n": # ignore whitespaces
return buffer[i] == ":"
return False


class maybe_accept(commands.Command):
def do(self) -> None:
r: ReadlineAlikeReader
Expand All @@ -227,9 +255,18 @@ def do(self) -> None:
# auto-indent the next line like the previous line
prevlinestart, indent = _get_previous_line_indent(r.buffer, r.pos)
r.insert("\n")
if not self.reader.paste_mode and indent:
for i in range(prevlinestart, prevlinestart + indent):
r.insert(r.buffer[i])
if not self.reader.paste_mode:
if indent:
for i in range(prevlinestart, prevlinestart + indent):
r.insert(r.buffer[i])
r.update_last_used_indentation()
if _is_last_char_colon(r.buffer):
if r.last_used_indentation is not None:
indentation = r.last_used_indentation
else:
# default
indentation = " " * 4
r.insert(indentation)
elif not self.reader.paste_mode:
self.finish = True
else:
Expand Down
194 changes: 139 additions & 55 deletions Lib/test/test_pyrepl/test_pyrepl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,31 @@
from unittest import TestCase
from unittest.mock import patch

from .support import FakeConsole, handle_all_events, handle_events_narrow_console
from .support import more_lines, multiline_input, code_to_events
from .support import (
FakeConsole,
handle_all_events,
handle_events_narrow_console,
more_lines,
multiline_input,
code_to_events,
)
from _pyrepl.console import Event
from _pyrepl.readline import ReadlineAlikeReader, ReadlineConfig
from _pyrepl.readline import multiline_input as readline_multiline_input


class TestCursorPosition(TestCase):
def prepare_reader(self, events):
console = FakeConsole(events)
config = ReadlineConfig(readline_completer=None)
reader = ReadlineAlikeReader(console=console, config=config)
return reader

def test_up_arrow_simple(self):
# fmt: off
code = (
'def f():\n'
' ...\n'
"def f():\n"
" ...\n"
)
# fmt: on
events = itertools.chain(
Expand All @@ -34,8 +46,8 @@ def test_up_arrow_simple(self):
def test_down_arrow_end_of_input(self):
# fmt: off
code = (
'def f():\n'
' ...\n'
"def f():\n"
" ...\n"
)
# fmt: on
events = itertools.chain(
Expand Down Expand Up @@ -300,6 +312,79 @@ def test_cursor_position_after_wrap_and_move_up(self):
self.assertEqual(reader.pos, 10)
self.assertEqual(reader.cxy, (1, 1))

def test_auto_indent_default(self):
# fmt: off
input_code = (
"def f():\n"
"pass\n\n"
)

output_code = (
"def f():\n"
" pass\n"
" "
)
# fmt: on

def test_auto_indent_continuation(self):
# auto indenting according to previous user indentation
# fmt: off
events = itertools.chain(
code_to_events("def f():\n"),
# add backspace to delete default auto-indent
[
Event(evt="key", data="backspace", raw=bytearray(b"\x7f")),
],
code_to_events(
" pass\n"
"pass\n\n"
),
)

output_code = (
"def f():\n"
" pass\n"
" pass\n"
" "
)
# fmt: on

reader = self.prepare_reader(events)
output = multiline_input(reader)
self.assertEqual(output, output_code)

def test_auto_indent_prev_block(self):
# auto indenting according to indentation in different block
# fmt: off
events = itertools.chain(
code_to_events("def f():\n"),
# add backspace to delete default auto-indent
[
Event(evt="key", data="backspace", raw=bytearray(b"\x7f")),
],
code_to_events(
" pass\n"
"pass\n\n"
),
code_to_events(
"def g():\n"
"pass\n\n"
),
)


output_code = (
"def g():\n"
" pass\n"
" "
)
# fmt: on

reader = self.prepare_reader(events)
output1 = multiline_input(reader)
output2 = multiline_input(reader)
self.assertEqual(output2, output_code)


class TestPyReplOutput(TestCase):
def prepare_reader(self, events):
Expand All @@ -316,14 +401,12 @@ def test_basic(self):

def test_multiline_edit(self):
events = itertools.chain(
code_to_events("def f():\n ...\n\n"),
code_to_events("def f():\n...\n\n"),
[
Event(evt="key", data="up", raw=bytearray(b"\x1bOA")),
Event(evt="key", data="up", raw=bytearray(b"\x1bOA")),
Event(evt="key", data="up", raw=bytearray(b"\x1bOA")),
Event(evt="key", data="right", raw=bytearray(b"\x1bOC")),
Event(evt="key", data="right", raw=bytearray(b"\x1bOC")),
Event(evt="key", data="right", raw=bytearray(b"\x1bOC")),
Event(evt="key", data="backspace", raw=bytearray(b"\x7f")),
Event(evt="key", data="g", raw=bytearray(b"g")),
Event(evt="key", data="down", raw=bytearray(b"\x1bOB")),
Expand All @@ -334,9 +417,9 @@ def test_multiline_edit(self):
reader = self.prepare_reader(events)

output = multiline_input(reader)
self.assertEqual(output, "def f():\n ...\n ")
self.assertEqual(output, "def f():\n ...\n ")
output = multiline_input(reader)
self.assertEqual(output, "def g():\n ...\n ")
self.assertEqual(output, "def g():\n ...\n ")

def test_history_navigation_with_up_arrow(self):
events = itertools.chain(
Expand Down Expand Up @@ -485,6 +568,7 @@ class Dummy:
@property
def test_func(self):
import warnings

warnings.warn("warnings\n")
return None

Expand All @@ -508,12 +592,12 @@ def prepare_reader(self, events):
def test_paste(self):
# fmt: off
code = (
'def a():\n'
' for x in range(10):\n'
' if x%2:\n'
' print(x)\n'
' else:\n'
' pass\n'
"def a():\n"
" for x in range(10):\n"
" if x%2:\n"
" print(x)\n"
" else:\n"
" pass\n"
)
# fmt: on

Expand All @@ -534,10 +618,10 @@ def test_paste(self):
def test_paste_mid_newlines(self):
# fmt: off
code = (
'def f():\n'
' x = y\n'
' \n'
' y = z\n'
"def f():\n"
" x = y\n"
" \n"
" y = z\n"
)
# fmt: on

Expand All @@ -558,16 +642,16 @@ def test_paste_mid_newlines(self):
def test_paste_mid_newlines_not_in_paste_mode(self):
# fmt: off
code = (
'def f():\n'
' x = y\n'
' \n'
' y = z\n\n'
"def f():\n"
"x = y\n"
"\n"
"y = z\n\n"
)

expected = (
'def f():\n'
' x = y\n'
' '
"def f():\n"
" x = y\n"
" "
)
# fmt: on

Expand All @@ -579,20 +663,20 @@ def test_paste_mid_newlines_not_in_paste_mode(self):
def test_paste_not_in_paste_mode(self):
# fmt: off
input_code = (
'def a():\n'
' for x in range(10):\n'
' if x%2:\n'
' print(x)\n'
' else:\n'
' pass\n\n'
"def a():\n"
"for x in range(10):\n"
"if x%2:\n"
"print(x)\n"
"else:\n"
"pass\n\n"
)

output_code = (
'def a():\n'
' for x in range(10):\n'
' if x%2:\n'
' print(x)\n'
' else:'
"def a():\n"
" for x in range(10):\n"
" if x%2:\n"
" print(x)\n"
" else:"
)
# fmt: on

Expand All @@ -605,25 +689,25 @@ def test_bracketed_paste(self):
"""Test that bracketed paste using \x1b[200~ and \x1b[201~ works."""
# fmt: off
input_code = (
'def a():\n'
' for x in range(10):\n'
'\n'
' if x%2:\n'
' print(x)\n'
'\n'
' else:\n'
' pass\n'
"def a():\n"
" for x in range(10):\n"
"\n"
" if x%2:\n"
" print(x)\n"
"\n"
" else:\n"
" pass\n"
)

output_code = (
'def a():\n'
' for x in range(10):\n'
'\n'
' if x%2:\n'
' print(x)\n'
'\n'
' else:\n'
' pass\n'
"def a():\n"
" for x in range(10):\n"
"\n"
" if x%2:\n"
" print(x)\n"
"\n"
" else:\n"
" pass\n"
)
# fmt: on

Expand Down

0 comments on commit cd516cd

Please sign in to comment.