diff --git a/contrib/msggen/msggen/model.py b/contrib/msggen/msggen/model.py index 2acc589a484f..1d887e018d1e 100644 --- a/contrib/msggen/msggen/model.py +++ b/contrib/msggen/msggen/model.py @@ -39,6 +39,7 @@ def __init__( self.added = added self.deprecated = deprecated self.required = False + self.parent: Optional["Field"] = None @property def name(self): @@ -117,6 +118,8 @@ def __init__( ) self.typename = typename self.fields = fields + for f in self.fields: + f.parent = self @classmethod def from_js(cls, js, path): @@ -315,11 +318,12 @@ def __str__(self): class ArrayField(Field): - def __init__(self, itemtype, dims, path, description, added, deprecated): + def __init__(self, itemtype: Field, dims, path, description, added, deprecated): Field.__init__(self, path, description, added=added, deprecated=deprecated) self.itemtype = itemtype self.dims = dims self.path = path + self.itemtype.parent = self @classmethod def from_js(cls, path, js): @@ -357,9 +361,11 @@ def from_js(cls, path, js): class Command: - def __init__(self, name, fields): + def __init__(self, name, fields: List[Field]): self.name = name self.fields = fields + for f in fields: + f.parent = self def __str__(self): fieldnames = ",".join([f.path.split(".")[-1] for f in self.fields]) diff --git a/contrib/msggen/msggen/patch.py b/contrib/msggen/msggen/patch.py index 63250e037aba..c224eea33bc4 100644 --- a/contrib/msggen/msggen/patch.py +++ b/contrib/msggen/msggen/patch.py @@ -1,5 +1,7 @@ from abc import ABC from msggen import model +import logging + class Patch(ABC): """A patch that can be applied to an in-memory model @@ -130,3 +132,77 @@ def visit(self, f: model.Field) -> None: if f.deprecated and self.versions.index(f.deprecated) < idx[1]: f.optional = True + +class ParentAnnotation(Patch): + """Annotate each field with its parent if there is one. + """ + def visit(self, field: model.Field) -> None: + if isinstance(field, model.CompositeField): + for f in field.fields: + f.parent = field + + +class GrpcNumberingPatch(Patch): + """GRPC uses fixed numberings for its field, annotate the model + """ + + def __init__(self, meta) -> None: + """Create a patch that can annotate `added` and `deprecated` + """ + self.meta = meta + self.logger = logging.getLogger("msggen.patch.GrpcNumberingPatch") + + def visit(self, f: model.Field) -> None: + if not f.parent: + return + + # Our number is stored in the context of our parent, hence + # this lookup here. + typename = f.parent.typename + if isinstance(f, model.EnumField): + for v in f.variants: + self.visit_variant(typename, v) + else: + f.grpc_id = self.field2number(typename, f) + + def visit_variant(self, typename, v): + v.grpc_id = self.enumvar2number(typename, v) + + def field2number(self, message_name, field): + m = self.meta['grpc-field-map'] + # Ensure new types are actually added to the map + if message_name not in m: + m[message_name] = {} + m = m[message_name] + + # Simple case first: if we've already assigned a number let's + # reuse that + if field.path in m: + return m[field.path] + + # Now let's find the highest number we have in the parent + # context + maxnum = max([0] + list(m.values())) + m[field.path] = maxnum + 1 + self.logger.warn(f"Assigning new field number to {field.path} => {m[field.path]}") + return m[field.path] + + def enumvar2number(self, typename, variant): + """Find an existing variant number of generate a new one. + + If we don't have a variant number yet we'll just take the + largest one assigned so far and increment it by 1. """ + m = self.meta['grpc-enum-map'] + variant = str(variant) + if typename not in m: + m[typename] = {} + + variants = m[typename] + if variant in variants: + return variants[variant] + + # Now find the maximum and increment once + n = max(variants.values()) if len(variants) else -1 + + m[typename][variant] = n + 1 + return m[typename][variant]