Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pass around a context object to all plugin callback functions #24

Merged
merged 4 commits into from
Jan 19, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 20 additions & 16 deletions nextline/fsm/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from logging import getLogger
from typing import Any, Optional

import apluggy
from transitions import EventData

from nextline.plugin import Context
from nextline.spawned import Command
from nextline.types import ResetOptions

Expand All @@ -14,8 +14,9 @@
class Machine:
'''The finite state machine of the nextline states.'''

def __init__(self, hook: apluggy.PluginManager) -> None:
self._hook = hook
def __init__(self, context: Context) -> None:
self._context = context
self._hook = context.hook
self._machine = build_state_machine(model=self)
self._machine.after_state_change = self.after_state_change.__name__ # type: ignore
assert self.state # type: ignore
Expand All @@ -28,39 +29,42 @@ async def after_state_change(self, event: EventData) -> None:
if not (event.transition and event.transition.dest):
# internal transition
return
await self._hook.ahook.on_change_state(state_name=self.state) # type: ignore
await self._hook.ahook.on_change_state(
context=self._context, state_name=self.state # type: ignore
)

async def on_exit_created(self, _: EventData) -> None:
await self._hook.ahook.start()
await self._hook.ahook.start(context=self._context)

async def on_enter_initialized(self, _: EventData) -> None:
self._run_arg = self._hook.hook.compose_run_arg()
await self._hook.ahook.on_initialize_run(run_arg=self._run_arg)
self._context.run_arg = self._hook.hook.compose_run_arg(context=self._context)
await self._hook.ahook.on_initialize_run(context=self._context)

async def on_enter_running(self, _: EventData) -> None:
self.run_finished = asyncio.Event()
run_started = asyncio.Event()

async def run() -> None:
async with self._hook.awith.run():
async with self._hook.awith.run(context=self._context):
run_started.set()
self._context.run_arg = None
await self.finish() # type: ignore
self.run_finished.set()

self._task = asyncio.create_task(run())
await run_started.wait()

async def send_command(self, command: Command) -> None:
await self._hook.ahook.send_command(command=command)
await self._hook.ahook.send_command(context=self._context, command=command)

async def interrupt(self) -> None:
await self._hook.ahook.interrupt()
await self._hook.ahook.interrupt(context=self._context)

async def terminate(self) -> None:
await self._hook.ahook.terminate()
await self._hook.ahook.terminate(context=self._context)

async def kill(self) -> None:
await self._hook.ahook.kill()
await self._hook.ahook.kill(context=self._context)

async def on_close_while_running(self, _: EventData) -> None:
await self.run_finished.wait()
Expand All @@ -72,13 +76,13 @@ async def on_exit_finished(self, _: EventData) -> None:
await self._task

def exception(self) -> Optional[BaseException]:
return self._hook.hook.exception()
return self._hook.hook.exception(context=self._context)

def result(self) -> Any:
return self._hook.hook.result()
return self._hook.hook.result(context=self._context)

async def on_enter_closed(self, _: EventData) -> None:
await self._hook.ahook.close()
await self._hook.ahook.close(context=self._context)

async def on_reset(self, event: EventData) -> None:
logger = getLogger(__name__)
Expand All @@ -88,7 +92,7 @@ async def on_reset(self, event: EventData) -> None:
reset_options: ResetOptions = kwargs.pop('reset_options')
if kwargs:
logger.warning(f'Unexpected kwargs: {kwargs!r}')
await self._hook.ahook.reset(reset_options=reset_options)
await self._hook.ahook.reset(context=self._context, reset_options=reset_options)

async def __aenter__(self) -> 'Machine':
await self.initialize() # type: ignore
Expand Down
12 changes: 4 additions & 8 deletions nextline/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .continuous import Continuous
from .fsm import Machine
from .plugin import build_hook
from .plugin import Context, build_hook
from .spawned import PdbCommand
from .types import (
InitOptions,
Expand Down Expand Up @@ -86,13 +86,9 @@ async def start(self) -> None:
self._started = True
logger = getLogger(__name__)
logger.debug(f'self._init_options: {self._init_options}')
self._hook.hook.init(
nextline=self,
hook=self._hook,
registry=self._pubsub,
init_options=self._init_options,
)
self._machine = Machine(hook=self._hook)
context = Context(nextline=self, hook=self._hook, pubsub=self._pubsub)
self._hook.hook.init(context=context, init_options=self._init_options)
self._machine = Machine(context=context)
await self._continuous.start()
await self._machine.initialize() # type: ignore

Expand Down
3 changes: 2 additions & 1 deletion nextline/plugin/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
__all__ = ['build_hook']
__all__ = ['build_hook', 'Context']

from .hook import build_hook
from .spec import Context
31 changes: 8 additions & 23 deletions nextline/plugin/plugins/argument.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,32 @@
from apluggy import PluginManager

from nextline.count import RunNoCounter
from nextline.plugin.spec import hookimpl
from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import RunArg
from nextline.types import InitOptions, ResetOptions
from nextline.utils.pubsub.broker import PubSub

SCRIPT_FILE_NAME = "<string>"


class RunArgComposer:
@hookimpl
def init(
self,
hook: PluginManager,
registry: PubSub,
init_options: InitOptions,
) -> None:
self._hook = hook
self._registry = registry
def init(self, init_options: InitOptions) -> None:
self._run_no_count = RunNoCounter(init_options.run_no_start_from)
self._statement = init_options.statement
self._filename = SCRIPT_FILE_NAME
self._trace_threads = init_options.trace_threads
self._trace_modules = init_options.trace_modules

@hookimpl
async def start(self) -> None:
await self._hook.ahook.on_change_script(
script=self._statement,
filename=self._filename,
async def start(self, context: Context) -> None:
await context.hook.ahook.on_change_script(
context=context, script=self._statement, filename=self._filename
)

@hookimpl
async def reset(
self,
reset_options: ResetOptions,
) -> None:
async def reset(self, context: Context, reset_options: ResetOptions) -> None:
if (statement := reset_options.statement) is not None:
self._statement = statement
await self._hook.ahook.on_change_script(
script=self._statement,
filename=self._filename,
await context.hook.ahook.on_change_script(
context=context, script=self._statement, filename=self._filename
)
if (run_no_start_from := reset_options.run_no_start_from) is not None:
self._run_no_count = RunNoCounter(run_no_start_from)
Expand Down
59 changes: 24 additions & 35 deletions nextline/plugin/plugins/registrars/prompt_info.py
Original file line number Diff line number Diff line change
@@ -1,92 +1,81 @@
import asyncio
import dataclasses
from logging import getLogger
from typing import Optional

from nextline.plugin.spec import hookimpl
from nextline.plugin.spec import Context, hookimpl
from nextline.spawned import (
OnEndPrompt,
OnEndTrace,
OnEndTraceCall,
OnStartPrompt,
OnStartTrace,
OnStartTraceCall,
RunArg,
)
from nextline.types import PromptInfo, PromptNo, RunNo, TraceNo
from nextline.utils.pubsub.broker import PubSub
from nextline.types import PromptInfo, PromptNo, TraceNo


class PromptInfoRegistrar:
def __init__(self) -> None:
self._run_no: Optional[RunNo] = None
self._last_prompt_frame_map = dict[TraceNo, int]()
self._trace_call_map = dict[TraceNo, OnStartTraceCall]()
self._prompt_info_map = dict[PromptNo, PromptInfo]()
self._keys = set[str]()
self._logger = getLogger(__name__)

@hookimpl
def init(self, registry: PubSub) -> None:
self._registry = registry

@hookimpl
async def start(self) -> None:
self._lock = asyncio.Lock()
pass

@hookimpl
async def on_initialize_run(self, run_arg: RunArg) -> None:
self._run_no = run_arg.run_no
async def on_initialize_run(self) -> None:
self._last_prompt_frame_map.clear()
self._trace_call_map.clear()
self._prompt_info_map.clear()
self._keys.clear()

@hookimpl
async def on_end_run(self) -> None:
async def on_end_run(self, context: Context) -> None:
async with self._lock:
while self._keys:
# the process might have been killed.
key = self._keys.pop()
await self._registry.end(key)

self._run_no = None
await context.pubsub.end(key)

@hookimpl
async def on_start_trace(self, event: OnStartTrace) -> None:
assert self._run_no is not None
async def on_start_trace(self, context: Context, event: OnStartTrace) -> None:
assert context.run_arg
trace_no = event.trace_no

# TODO: Putting a prompt info for now because otherwise tests get stuck
# sometimes for an unknown reason. Need to investigate
prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=PromptNo(-1),
open=False,
)
key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info)
await context.pubsub.publish(key, prompt_info)

@hookimpl
async def on_end_trace(self, event: OnEndTrace) -> None:
async def on_end_trace(self, context: Context, event: OnEndTrace) -> None:
trace_no = event.trace_no
key = f"prompt_info_{trace_no}"
async with self._lock:
if key in self._keys:
self._keys.remove(key)
await self._registry.end(key)
await context.pubsub.end(key)

@hookimpl
async def on_start_trace_call(self, event: OnStartTraceCall) -> None:
self._trace_call_map[event.trace_no] = event

@hookimpl
async def on_end_trace_call(self, event: OnEndTraceCall) -> None:
assert self._run_no is not None
async def on_end_trace_call(self, context: Context, event: OnEndTraceCall) -> None:
assert context.run_arg
trace_no = event.trace_no
trace_call = self._trace_call_map.pop(event.trace_no, None)
if trace_call is None:
Expand All @@ -102,7 +91,7 @@ async def on_end_trace_call(self, event: OnEndTraceCall) -> None:
# prompt info.

prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=PromptNo(-1),
open=False,
Expand All @@ -111,21 +100,21 @@ async def on_end_trace_call(self, event: OnEndTraceCall) -> None:
line_no=trace_call.line_no,
trace_call_end=True,
)
await self._registry.publish('prompt_info', prompt_info)
await context.pubsub.publish('prompt_info', prompt_info)

key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info)
await context.pubsub.publish(key, prompt_info)

@hookimpl
async def on_start_prompt(self, event: OnStartPrompt) -> None:
assert self._run_no is not None
async def on_start_prompt(self, context: Context, event: OnStartPrompt) -> None:
assert context.run_arg
trace_no = event.trace_no
prompt_no = event.prompt_no
trace_call = self._trace_call_map[trace_no]
prompt_info = PromptInfo(
run_no=self._run_no,
run_no=context.run_arg.run_no,
trace_no=trace_no,
prompt_no=prompt_no,
open=True,
Expand All @@ -138,15 +127,15 @@ async def on_start_prompt(self, event: OnStartPrompt) -> None:
self._prompt_info_map[prompt_no] = prompt_info
self._last_prompt_frame_map[trace_no] = trace_call.frame_object_id

await self._registry.publish('prompt_info', prompt_info)
await context.pubsub.publish('prompt_info', prompt_info)

key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info)
await context.pubsub.publish(key, prompt_info)

@hookimpl
async def on_end_prompt(self, event: OnEndPrompt) -> None:
async def on_end_prompt(self, context: Context, event: OnEndPrompt) -> None:
trace_no = event.trace_no
prompt_no = event.prompt_no
prompt_info = self._prompt_info_map.pop(prompt_no)
Expand All @@ -157,9 +146,9 @@ async def on_end_prompt(self, event: OnEndPrompt) -> None:
ended_at=event.ended_at,
)

await self._registry.publish('prompt_info', prompt_info_end)
await context.pubsub.publish('prompt_info', prompt_info_end)

key = f"prompt_info_{trace_no}"
async with self._lock:
self._keys.add(key)
await self._registry.publish(key, prompt_info_end)
await context.pubsub.publish(key, prompt_info_end)
Loading
Loading