From 2783e2483bd790ac20a69fbf8696099a7f19de15 Mon Sep 17 00:00:00 2001 From: chavacava Date: Sat, 28 Dec 2024 18:09:40 +0100 Subject: [PATCH] feature: add @invariant on struct types --- internal/contract/contract.go | 64 +++++++++-- internal/contract/generator/generator.go | 135 +++++++++++++++++++---- internal/contract/parser/parser.go | 40 ++++++- internal/contract/parser/parser_test.go | 2 +- 4 files changed, 209 insertions(+), 32 deletions(-) diff --git a/internal/contract/contract.go b/internal/contract/contract.go index d9f85fd..da96e99 100644 --- a/internal/contract/contract.go +++ b/internal/contract/contract.go @@ -9,6 +9,50 @@ import ( "strings" ) +// TypeContract represents a contract associated to a type. +// Typically a @invariant contract +type TypeContract struct { + ensures []Ensures + imports map[string]struct{} + targetTypeName string +} + +// NewTypeContract creates a TypeContract +// @requires target != "" +// @ensures c.targetTypeName == target +// @ensures len(c.ensures) == 0 +// @ensures len(c.imports) == 0 +func NewTypeContract(target string) (c *TypeContract) { + return &TypeContract{ + ensures: []Ensures{}, + targetTypeName: target, + imports: map[string]struct{}{}, + } +} + +// AddEnsures adds a ensures to this contract +// ensures len(c.ensures) == len(@old(c.ensures)) + 1 +// @ensures c.ensures[len(c.ensures)-1] == e +func (c *TypeContract) AddEnsures(e Ensures) { + c.ensures = append(c.ensures, e) +} + +// Ensures yields ensures clauses of this contract +// @ensures len(r) == len(c.ensures) +func (c *TypeContract) Ensures() (r []Ensures) { + return c.ensures +} + +// AddImport adds an import to this contract +func (c *TypeContract) AddImport(path string) { + c.imports[strings.Trim(path, "\"")] = struct{}{} +} + +// Imports returns imports required by this contract +func (c *TypeContract) Imports() map[string]struct{} { + return c.imports +} + // FuncContract represents a contract associated to a function type FuncContract struct { requires []Requires @@ -59,6 +103,16 @@ func (c *FuncContract) Ensures() (r []Ensures) { return c.ensures } +// AddImport adds an import to this contract +func (c *FuncContract) AddImport(path string) { + c.imports[strings.Trim(path, "\"")] = struct{}{} +} + +// Imports returns imports required by this contract +func (c *FuncContract) Imports() map[string]struct{} { + return c.imports +} + // Requires is a @requires clause of a contract type Requires struct { expr string @@ -139,13 +193,3 @@ func rewriteImpliesExpr(expr string) string { return "!(" + p + ") || (" + q + ")" } - -// AddImport adds an import to this contract -func (c *FuncContract) AddImport(path string) { - c.imports[strings.Trim(path, "\"")] = struct{}{} -} - -// Imports returns imports required by this contract -func (c *FuncContract) Imports() map[string]struct{} { - return c.imports -} diff --git a/internal/contract/generator/generator.go b/internal/contract/generator/generator.go index dcd0ce5..5521e4d 100644 --- a/internal/contract/generator/generator.go +++ b/internal/contract/generator/generator.go @@ -54,7 +54,11 @@ func analyzeCode(src io.Reader) (r bytes.Buffer, err error) { return bytes.Buffer{}, fmt.Errorf("unable to decorate AST: %v", err) } - fileAnalyzer := fileAnalyzer{decorator: astDecorator, imports: importsContainer{}} + fileAnalyzer := fileAnalyzer{ + decorator: astDecorator, + imports: importsContainer{}, + typeInvariantsCode: map[string][]string{}, + } // walk the AST with the analyzer to find contracts and generate their contracts ast.Walk(fileAnalyzer, astFile) @@ -101,8 +105,9 @@ func analyzeCode(src io.Reader) (r bytes.Buffer, err error) { type importsContainer map[string]struct{} type fileAnalyzer struct { - decorator *decorator.Decorator - imports importsContainer + decorator *decorator.Decorator + imports importsContainer + typeInvariantsCode map[string][]string } func (fa fileAnalyzer) Visit(node ast.Node) ast.Visitor { @@ -110,11 +115,47 @@ func (fa fileAnalyzer) Visit(node ast.Node) ast.Visitor { case *ast.FuncDecl: fa.rewriteFuncDecl(n) return nil //skip visiting function body + case *ast.GenDecl: + if n.Tok != token.TYPE { + return nil // not a type declaration + } + if len(n.Specs) <= 0 { + return nil // no specs in the type declaration + } + typeSpec, ok := (n.Specs[0]).(*ast.TypeSpec) + if !ok { + return nil // not a type declaration + } + + fa.analyzeTypeContract(typeSpec.Name.Name, n.Doc) + return nil // skip visiting the type fields } return fa } +func (fa fileAnalyzer) analyzeTypeContract(typeName string, doc *ast.CommentGroup) { + if doc == nil { + return // nothing to do, the type does not have associated documentation + } + + contractParser := contractParser.NewParser() + contract := contract.NewTypeContract(typeName) + for _, commentLine := range doc.List { + err := contractParser.ParseTypeContract(contract, commentLine.Text) + if err != nil { + log.Printf("%s: Warning: %s", fa.positionAsString(commentLine.Pos()), err.Error()) + continue + } + } + + fa.addCodeForTypeInvariant(typeName, contract) +} + +func (fa fileAnalyzer) addCodeForTypeInvariant(typeName string, contract *contract.TypeContract) { + fa.typeInvariantsCode[typeName] = fa.generateInvariantCode(contract) +} + // positionAsString returns a string representation of the given token position // @requires fa.decorator.Fset != nil func (fa fileAnalyzer) positionAsString(pos token.Pos) string { @@ -126,28 +167,60 @@ func (fa fileAnalyzer) positionAsString(pos token.Pos) string { // rewriteFuncDecl is in charge of generating contract-enforcing code for functions // @requires fd != nil func (fa *fileAnalyzer) rewriteFuncDecl(fd *ast.FuncDecl) { - if fd.Doc == nil { - return // nothing to do, the function does not have a comment - } + dstFuncDecl := fa.decorator.Dst.Nodes[fd].(*dst.FuncDecl) + if fd.Doc != nil { + contractParser := contractParser.NewParser() + contract := contract.NewFuncContract(fd) + comments := fd.Doc.List + for _, commentLine := range comments { + err := contractParser.ParseFuncContract(contract, commentLine.Text) + if err != nil { + log.Printf("%s: Warning: %s", fa.positionAsString(commentLine.Pos()), err.Error()) + continue + } + } - contractParser := contractParser.NewParser() - contract := contract.NewFuncContract(fd) - comments := fd.Doc.List - for _, commentLine := range comments { - err := contractParser.Parse(contract, commentLine.Text) - if err != nil { - log.Printf("%s: Warning: %s", fa.positionAsString(commentLine.Pos()), err.Error()) - continue + contractStmts, errs := fa.generateCode(contract) + for _, err := range errs { + log.Printf("Warning: %v", err) } + + dstFuncDecl.Body.Decorations().Start.Append(contractStmts...) } - contractStmts, errs := fa.generateCode(contract) - for _, err := range errs { - log.Printf("Warning: %v", err) + // Also add code for enforce invariants if available + if fd.Recv == nil || len(fd.Recv.List) < 1 { + return // not a method thus no invariants } - dstFuncDecl := fa.decorator.Dst.Nodes[fd].(*dst.FuncDecl) - dstFuncDecl.Body.Decorations().Start.Append(contractStmts...) + receiverType := fa.getReceiverTypeName(fd.Recv) + invariantCode, ok := fa.typeInvariantsCode[receiverType] + if !ok { + return // did not found invariant code associated to this method's receiver + } + + if len(fd.Recv.List[0].Names) < 1 || fd.Recv.List[0].Names[0].Name == "_" { + // anonymous receiver + log.Printf("Warning: can not enforce invariants on method %s because it has an anonymous receiver", fd.Name.Name) + return + // TODO: insert a receiver name to enable checks + } + + receiverName := fd.Recv.List[0].Names[0].Name + invariantCodeForMethod := make([]string, len(invariantCode)) + for i, code := range invariantCode { + invariantCodeForMethod[i] = strings.ReplaceAll(code, receiverType+".", receiverName+".") + } + dstFuncDecl.Body.Decorations().Start.Append(invariantCodeForMethod...) +} + +func (fa fileAnalyzer) getReceiverTypeName(receiver *ast.FieldList) string { + if len(receiver.List) < 1 { + return "UNKNOWN" + } + + recType := receiver.List[0].Type + return strings.Replace(fa.typeAsString(recType), "*", "", 1) } // generateCode yields the list of GO statements that enforce the given contract @@ -184,6 +257,30 @@ func (fa fileAnalyzer) generateCode(c *contract.FuncContract) (stmts []string, e const commentPrefix = "//dbc4go " +func (fa fileAnalyzer) generateInvariantCode(c *contract.TypeContract) (stmts []string) { + result := []string{} + + const templateEnsure = commentPrefix + `if !(%cond%) { panic("type invariant %contract% not satisfied") }` + clauses := c.Ensures() + ensuresCode := make([]string, len(clauses)) + for _, clause := range clauses { + exp, _ := clause.ExpandedExpression() + ensure := strings.Replace(templateEnsure, "%cond%", exp, 1) + ensure = strings.Replace(ensure, "%contract%", escapeDoubleQuotes(clause.String()), 1) + ensuresCode = append(ensuresCode, ensure) + } + const templateDeferredFunction = commentPrefix + `defer func(){%checks%}()` + r := strings.Replace(templateDeferredFunction, "%checks%", strings.Join(ensuresCode, "\n"), 1) + result = append(result, r) + + // merge new imports into imports list + for k, v := range c.Imports() { + fa.imports[k] = v + } + + return result +} + // @ensures r == "" ==> e != nil func (fileAnalyzer) generateRequiresCode(req contract.Requires) (r string, e error) { const templateRequire = commentPrefix + `if !(%cond%) { panic("precondition %contract% not satisfied") }` diff --git a/internal/contract/parser/parser.go b/internal/contract/parser/parser.go index e3d849b..c84dc7e 100644 --- a/internal/contract/parser/parser.go +++ b/internal/contract/parser/parser.go @@ -23,9 +23,45 @@ func NewParser() Parser { var reContracts = regexp.MustCompile(`\s*//\s*@(?P[a-z]+)(?:[\t ]+(?P\[[\w\s\d,]+\]))?[\t ]+(?P[^$]+)`) -// Parse enrich the Contract with the clause if present in the given comment line +// ParseTypeContract enrich the contract with the clause if present in the given comment line +// @requires typeContract != nil +func (p Parser) ParseTypeContract(typeContract *contract.TypeContract, line string) error { + kind, description, expr, matched := parseLine(line) + if !matched { + return nil // nothing to do, there is no contract in this comment line + } + + switch kind { + case "invariant": + if contract.Re4old.MatchString(expr) { + return fmt.Errorf("@old can not be used in @invariant expressions: %s", expr) + } + + clause, err := p.parseEnsures(expr, description) // invariants are ensures that apply to all methods of the type + if err != nil { + return fmt.Errorf("invalid @invariant clause: %w", err) + } + + typeContract.AddEnsures(clause) + case "import": + clause, err := p.parseImport(expr) + if err != nil { + return fmt.Errorf("invalid @import clause: %w", err) + } + + typeContract.AddImport(clause) + case "ensures", "requires", "unmodified": + return fmt.Errorf("@%s can not be used in type contracts: %s", kind, expr) + default: + return errors.Errorf("unknown contract kind %s", kind) + } + + return nil +} + +// ParseFuncContract enrich the Contract with the clause if present in the given comment line // @requires funcContract != nil -func (p Parser) Parse(funcContract *contract.FuncContract, line string) error { +func (p Parser) ParseFuncContract(funcContract *contract.FuncContract, line string) error { kind, description, expr, matched := parseLine(line) if !matched { return nil // nothing to do, there is no contract in this comment line diff --git a/internal/contract/parser/parser_test.go b/internal/contract/parser/parser_test.go index d92c6a9..df24a6b 100644 --- a/internal/contract/parser/parser_test.go +++ b/internal/contract/parser/parser_test.go @@ -128,7 +128,7 @@ func TestParse(t *testing.T) { } p := NewParser() for _, tc := range tests { - err := p.Parse(tc.contract, tc.line) + err := p.ParseFuncContract(tc.contract, tc.line) if tc.err { assert.NotEqual(t, err, nil, "line %s", tc.line) }