diff --git a/contrib/pyln-client/pyln/client/__init__.py b/contrib/pyln-client/pyln/client/__init__.py index f397fc4bcb66..a89917fd66c4 100644 --- a/contrib/pyln-client/pyln/client/__init__.py +++ b/contrib/pyln-client/pyln/client/__init__.py @@ -1,5 +1,5 @@ from .lightning import LightningRpc, RpcError, Millisatoshi -from .plugin import Plugin, monkey_patch +from .plugin import Plugin, monkey_patch, RPCException __version__ = "0.8.0" @@ -9,6 +9,7 @@ "LightningRpc", "Plugin", "RpcError", + "RPCException", "Millisatoshi", "__version__", "monkey_patch" diff --git a/contrib/pyln-client/pyln/client/plugin.py b/contrib/pyln-client/pyln/client/plugin.py index cc2ce1bd3530..20d15cf4df8d 100644 --- a/contrib/pyln-client/pyln/client/plugin.py +++ b/contrib/pyln-client/pyln/client/plugin.py @@ -61,6 +61,14 @@ def __init__(self, name: str, func: Callable[..., JSONType], self.after: List[str] = [] +class RPCException(Exception): + # -32600 == "Invalid Request" + def __init__(self, message: str, code: int = -32600): + self.code = code + self.message = message + super().__init__("RPCException: {}".format(message)) + + class Request(dict): """A request object that wraps params and allows async return """ @@ -102,7 +110,7 @@ def set_result(self, result: Any) -> None: self.state = RequestState.FINISHED self.termination_tb = "".join(traceback.extract_stack().format()[:-1]) - def set_exception(self, exc: Exception) -> None: + def set_exception(self, exc: Union[Exception, RPCException]) -> None: if self.state != RequestState.PENDING: assert(self.termination_tb is not None) raise ValueError( @@ -110,13 +118,19 @@ def set_exception(self, exc: Exception) -> None: "current state is {state}. Request previously terminated at\n" "{tb}".format(state=self.state, tb=self.termination_tb)) self.exc = exc + if isinstance(exc, RPCException): + code = exc.code + message = exc.message + else: + code = -32600 # "Invalid Request" + message = ("Error while processing {method}: {exc}" + .format(method=self.method, exc=str(exc))) self._write_result({ 'jsonrpc': '2.0', 'id': self.id, "error": { - "code": -32600, # "Invalid Request" - "message": "Error while processing {method}: {exc}" - .format(method=self.method, exc=str(exc)), + "code": code, + "message": message, # 'data' field "may be omitted." "traceback": traceback.format_exc(), }, diff --git a/contrib/pyln-client/tests/test_plugin.py b/contrib/pyln-client/tests/test_plugin.py index 5a95ca616f4f..a659fb7628e0 100644 --- a/contrib/pyln-client/tests/test_plugin.py +++ b/contrib/pyln-client/tests/test_plugin.py @@ -1,5 +1,5 @@ from pyln.client import Plugin -from pyln.client.plugin import Request, Millisatoshi +from pyln.client.plugin import Request, Millisatoshi, RPCException import itertools import pytest # type: ignore @@ -172,6 +172,39 @@ def test1(name): assert call_list == [] +def test_method_exceptions(): + """A bunch of tests that should fail calling the methods.""" + p = Plugin(autopatch=False) + + def fake_write_result(resultdict): + global result_dict + result_dict = resultdict + + @p.method("test_raise") + def test_raise(): + raise RPCException("testing RPCException", code=-1000) + + req = Request(p, 1, "test_raise", {}) + req._write_result = fake_write_result + p._dispatch_request(req) + assert result_dict['jsonrpc'] == '2.0' + assert result_dict['id'] == 1 + assert result_dict['error']['code'] == -1000 + assert result_dict['error']['message'] == "testing RPCException" + + @p.method("test_raise2") + def test_raise2(): + raise Exception("normal exception") + + req = Request(p, 1, "test_raise2", {}) + req._write_result = fake_write_result + p._dispatch_request(req) + assert result_dict['jsonrpc'] == '2.0' + assert result_dict['id'] == 1 + assert result_dict['error']['code'] == -32600 + assert result_dict['error']['message'] == "Error while processing test_raise2: normal exception" + + def test_positional_inject(): p = Plugin() rdict = Request(