Skip to content

Commit

Permalink
msggen: use field numbers from .msggen.json for rust model if available
Browse files Browse the repository at this point in the history
  • Loading branch information
daywalker90 committed Apr 14, 2024
1 parent 2f54add commit 79c68e9
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
6 changes: 3 additions & 3 deletions contrib/msggen/msggen/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,10 +44,10 @@ def add_handler_get_grpc2py(generator_chain: GeneratorChain):
generator_chain.add_generator(Grpc2PyGenerator(dest))


def add_handler_gen_rust_jsonrpc(generator_chain: GeneratorChain):
def add_handler_gen_rust_jsonrpc(generator_chain: GeneratorChain, meta):
fname = Path("cln-rpc") / "src" / "model.rs"
dest = open(fname, "w")
generator_chain.add_generator(RustGenerator(dest))
generator_chain.add_generator(RustGenerator(dest, meta))


def load_msggen_meta():
Expand Down Expand Up @@ -81,7 +81,7 @@ def run():
generator_chain = GeneratorChain()

add_handler_gen_grpc(generator_chain, meta)
add_handler_gen_rust_jsonrpc(generator_chain)
add_handler_gen_rust_jsonrpc(generator_chain, meta)
add_handler_get_grpc2py(generator_chain)

generator_chain.generate(service)
Expand Down
56 changes: 38 additions & 18 deletions contrib/msggen/msggen/gen/rust.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import TextIO
from typing import Tuple
from typing import TextIO, Tuple, Dict, Any
from textwrap import dedent, indent
import logging
import sys
Expand Down Expand Up @@ -65,22 +64,22 @@ def normalize_varname(field):
return field


def gen_field(field):
def gen_field(field, meta):
if field.omit():
return ("", "")
if isinstance(field, CompositeField):
return gen_composite(field)
return gen_composite(field, meta)
elif isinstance(field, EnumField):
return gen_enum(field)
return gen_enum(field, meta)
elif isinstance(field, ArrayField):
return gen_array(field)
return gen_array(field, meta)
elif isinstance(field, PrimitiveField):
return gen_primitive(field)
else:
raise TypeError(f"Unmanaged type {field}")


def gen_enum(e):
def gen_enum(e, meta):
defi, decl = "", ""

if e.omit():
Expand Down Expand Up @@ -108,10 +107,30 @@ def gen_enum(e):
fn try_from(c: i32) -> Result<{e.typename}, anyhow::Error> {{
match c {{
""")
for i, v in enumerate(e.variants):
norm = v.normalized()
# decl += f" #[serde(rename = \"{v}\")]\n"
decl += f" {i} => Ok({e.typename}::{norm}),\n"

m = meta['grpc-field-map']
m2 = meta['grpc-enum-map']

message_name = e.typename.name
assert not (message_name in m and message_name in m2)
if message_name in m:
m = m[message_name]
elif message_name in m2:
m = m2[message_name]
else:
m = {}

if m != {}:
for v in e.variants:
norm = v.normalized()
# decl += f" #[serde(rename = \"{v}\")]\n"
decl += f" {m[str(v)]} => Ok({e.typename}::{norm}),\n"
else:
for i, v in enumerate(e.variants):
norm = v.normalized()
# decl += f" #[serde(rename = \"{v}\")]\n"
decl += f" {i} => Ok({e.typename}::{norm}),\n"

decl += dedent(f"""\
o => Err(anyhow::anyhow!("Unknown variant {{}} for enum {e.typename}", o)),
}}
Expand Down Expand Up @@ -178,10 +197,10 @@ def rename_if_necessary(original, name):
return f""


def gen_array(a):
def gen_array(a, meta):
name = a.name.normalized().replace("[]", "")
logger.debug(f"Generating array field {a.name} -> {name} ({a.path})")
_, decl = gen_field(a.itemtype)
_, decl = gen_field(a.itemtype, meta)

if a.override():
decl = "" # No declaration if we have an override
Expand Down Expand Up @@ -210,11 +229,11 @@ def gen_array(a):
return (defi, decl)


def gen_composite(c) -> Tuple[str, str]:
def gen_composite(c, meta) -> Tuple[str, str]:
logger.debug(f"Generating composite field {c.name} ({c.path})")
fields = []
for f in c.fields:
fields.append(gen_field(f))
fields.append(gen_field(f, meta))

r = "".join([f[1] for f in fields])

Expand All @@ -236,8 +255,9 @@ def gen_composite(c) -> Tuple[str, str]:


class RustGenerator(IGenerator):
def __init__(self, dest: TextIO):
def __init__(self, dest: TextIO, meta: Dict[str, Any]):
self.dest = dest
self.meta = meta

def write(self, text: str, numindent: int = 0) -> None:
raw = dedent(text)
Expand All @@ -258,7 +278,7 @@ def generate_requests(self, service: Service):

for meth in service.methods:
req = meth.request
_, decl = gen_composite(req)
_, decl = gen_composite(req, self.meta)
self.write(decl, numindent=1)
self.generate_request_trait_impl(meth)

Expand Down Expand Up @@ -298,7 +318,7 @@ def generate_responses(self, service: Service):

for meth in service.methods:
res = meth.response
_, decl = gen_composite(res)
_, decl = gen_composite(res, self.meta)
self.write(decl, numindent=1)
self.generate_response_trait_impl(meth)

Expand Down

0 comments on commit 79c68e9

Please sign in to comment.