Skip to content

Commit

Permalink
feature: add @invariant on struct types
Browse files Browse the repository at this point in the history
  • Loading branch information
chavacava committed Dec 28, 2024
1 parent 89cc8c9 commit 2783e24
Show file tree
Hide file tree
Showing 4 changed files with 209 additions and 32 deletions.
64 changes: 54 additions & 10 deletions internal/contract/contract.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,50 @@ import (
"strings"
)

// TypeContract represents a contract associated to a type.
// Typically a @invariant contract
type TypeContract struct {
ensures []Ensures
imports map[string]struct{}
targetTypeName string
}

// NewTypeContract creates a TypeContract
// @requires target != ""
// @ensures c.targetTypeName == target
// @ensures len(c.ensures) == 0
// @ensures len(c.imports) == 0
func NewTypeContract(target string) (c *TypeContract) {
return &TypeContract{
ensures: []Ensures{},
targetTypeName: target,
imports: map[string]struct{}{},
}
}

// AddEnsures adds a ensures to this contract
// ensures len(c.ensures) == len(@old(c.ensures)) + 1
// @ensures c.ensures[len(c.ensures)-1] == e
func (c *TypeContract) AddEnsures(e Ensures) {
c.ensures = append(c.ensures, e)
}

// Ensures yields ensures clauses of this contract
// @ensures len(r) == len(c.ensures)
func (c *TypeContract) Ensures() (r []Ensures) {
return c.ensures
}

// AddImport adds an import to this contract
func (c *TypeContract) AddImport(path string) {
c.imports[strings.Trim(path, "\"")] = struct{}{}
}

// Imports returns imports required by this contract
func (c *TypeContract) Imports() map[string]struct{} {
return c.imports
}

// FuncContract represents a contract associated to a function
type FuncContract struct {
requires []Requires
Expand Down Expand Up @@ -59,6 +103,16 @@ func (c *FuncContract) Ensures() (r []Ensures) {
return c.ensures
}

// AddImport adds an import to this contract
func (c *FuncContract) AddImport(path string) {
c.imports[strings.Trim(path, "\"")] = struct{}{}
}

// Imports returns imports required by this contract
func (c *FuncContract) Imports() map[string]struct{} {
return c.imports
}

// Requires is a @requires clause of a contract
type Requires struct {
expr string
Expand Down Expand Up @@ -139,13 +193,3 @@ func rewriteImpliesExpr(expr string) string {

return "!(" + p + ") || (" + q + ")"
}

// AddImport adds an import to this contract
func (c *FuncContract) AddImport(path string) {
c.imports[strings.Trim(path, "\"")] = struct{}{}
}

// Imports returns imports required by this contract
func (c *FuncContract) Imports() map[string]struct{} {
return c.imports
}
135 changes: 116 additions & 19 deletions internal/contract/generator/generator.go
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,11 @@ func analyzeCode(src io.Reader) (r bytes.Buffer, err error) {
return bytes.Buffer{}, fmt.Errorf("unable to decorate AST: %v", err)
}

fileAnalyzer := fileAnalyzer{decorator: astDecorator, imports: importsContainer{}}
fileAnalyzer := fileAnalyzer{
decorator: astDecorator,
imports: importsContainer{},
typeInvariantsCode: map[string][]string{},
}
// walk the AST with the analyzer to find contracts and generate their contracts
ast.Walk(fileAnalyzer, astFile)

Expand Down Expand Up @@ -101,20 +105,57 @@ func analyzeCode(src io.Reader) (r bytes.Buffer, err error) {
type importsContainer map[string]struct{}

type fileAnalyzer struct {
decorator *decorator.Decorator
imports importsContainer
decorator *decorator.Decorator
imports importsContainer
typeInvariantsCode map[string][]string
}

func (fa fileAnalyzer) Visit(node ast.Node) ast.Visitor {
switch n := node.(type) {
case *ast.FuncDecl:
fa.rewriteFuncDecl(n)
return nil //skip visiting function body
case *ast.GenDecl:
if n.Tok != token.TYPE {
return nil // not a type declaration
}
if len(n.Specs) <= 0 {
return nil // no specs in the type declaration
}
typeSpec, ok := (n.Specs[0]).(*ast.TypeSpec)
if !ok {
return nil // not a type declaration
}

fa.analyzeTypeContract(typeSpec.Name.Name, n.Doc)
return nil // skip visiting the type fields
}

return fa
}

func (fa fileAnalyzer) analyzeTypeContract(typeName string, doc *ast.CommentGroup) {
if doc == nil {
return // nothing to do, the type does not have associated documentation
}

contractParser := contractParser.NewParser()
contract := contract.NewTypeContract(typeName)
for _, commentLine := range doc.List {
err := contractParser.ParseTypeContract(contract, commentLine.Text)
if err != nil {
log.Printf("%s: Warning: %s", fa.positionAsString(commentLine.Pos()), err.Error())
continue
}
}

fa.addCodeForTypeInvariant(typeName, contract)
}

func (fa fileAnalyzer) addCodeForTypeInvariant(typeName string, contract *contract.TypeContract) {
fa.typeInvariantsCode[typeName] = fa.generateInvariantCode(contract)
}

// positionAsString returns a string representation of the given token position
// @requires fa.decorator.Fset != nil
func (fa fileAnalyzer) positionAsString(pos token.Pos) string {
Expand All @@ -126,28 +167,60 @@ func (fa fileAnalyzer) positionAsString(pos token.Pos) string {
// rewriteFuncDecl is in charge of generating contract-enforcing code for functions
// @requires fd != nil
func (fa *fileAnalyzer) rewriteFuncDecl(fd *ast.FuncDecl) {
if fd.Doc == nil {
return // nothing to do, the function does not have a comment
}
dstFuncDecl := fa.decorator.Dst.Nodes[fd].(*dst.FuncDecl)
if fd.Doc != nil {
contractParser := contractParser.NewParser()
contract := contract.NewFuncContract(fd)
comments := fd.Doc.List
for _, commentLine := range comments {
err := contractParser.ParseFuncContract(contract, commentLine.Text)
if err != nil {
log.Printf("%s: Warning: %s", fa.positionAsString(commentLine.Pos()), err.Error())
continue
}
}

contractParser := contractParser.NewParser()
contract := contract.NewFuncContract(fd)
comments := fd.Doc.List
for _, commentLine := range comments {
err := contractParser.Parse(contract, commentLine.Text)
if err != nil {
log.Printf("%s: Warning: %s", fa.positionAsString(commentLine.Pos()), err.Error())
continue
contractStmts, errs := fa.generateCode(contract)
for _, err := range errs {
log.Printf("Warning: %v", err)
}

dstFuncDecl.Body.Decorations().Start.Append(contractStmts...)
}

contractStmts, errs := fa.generateCode(contract)
for _, err := range errs {
log.Printf("Warning: %v", err)
// Also add code for enforce invariants if available
if fd.Recv == nil || len(fd.Recv.List) < 1 {
return // not a method thus no invariants
}

dstFuncDecl := fa.decorator.Dst.Nodes[fd].(*dst.FuncDecl)
dstFuncDecl.Body.Decorations().Start.Append(contractStmts...)
receiverType := fa.getReceiverTypeName(fd.Recv)
invariantCode, ok := fa.typeInvariantsCode[receiverType]
if !ok {
return // did not found invariant code associated to this method's receiver
}

if len(fd.Recv.List[0].Names) < 1 || fd.Recv.List[0].Names[0].Name == "_" {
// anonymous receiver
log.Printf("Warning: can not enforce invariants on method %s because it has an anonymous receiver", fd.Name.Name)
return
// TODO: insert a receiver name to enable checks
}

receiverName := fd.Recv.List[0].Names[0].Name
invariantCodeForMethod := make([]string, len(invariantCode))
for i, code := range invariantCode {
invariantCodeForMethod[i] = strings.ReplaceAll(code, receiverType+".", receiverName+".")
}
dstFuncDecl.Body.Decorations().Start.Append(invariantCodeForMethod...)
}

func (fa fileAnalyzer) getReceiverTypeName(receiver *ast.FieldList) string {
if len(receiver.List) < 1 {
return "UNKNOWN"
}

recType := receiver.List[0].Type
return strings.Replace(fa.typeAsString(recType), "*", "", 1)
}

// generateCode yields the list of GO statements that enforce the given contract
Expand Down Expand Up @@ -184,6 +257,30 @@ func (fa fileAnalyzer) generateCode(c *contract.FuncContract) (stmts []string, e

const commentPrefix = "//dbc4go "

func (fa fileAnalyzer) generateInvariantCode(c *contract.TypeContract) (stmts []string) {
result := []string{}

const templateEnsure = commentPrefix + `if !(%cond%) { panic("type invariant %contract% not satisfied") }`
clauses := c.Ensures()
ensuresCode := make([]string, len(clauses))
for _, clause := range clauses {
exp, _ := clause.ExpandedExpression()
ensure := strings.Replace(templateEnsure, "%cond%", exp, 1)
ensure = strings.Replace(ensure, "%contract%", escapeDoubleQuotes(clause.String()), 1)
ensuresCode = append(ensuresCode, ensure)
}
const templateDeferredFunction = commentPrefix + `defer func(){%checks%}()`
r := strings.Replace(templateDeferredFunction, "%checks%", strings.Join(ensuresCode, "\n"), 1)
result = append(result, r)

// merge new imports into imports list
for k, v := range c.Imports() {
fa.imports[k] = v
}

return result
}

// @ensures r == "" ==> e != nil
func (fileAnalyzer) generateRequiresCode(req contract.Requires) (r string, e error) {
const templateRequire = commentPrefix + `if !(%cond%) { panic("precondition %contract% not satisfied") }`
Expand Down
40 changes: 38 additions & 2 deletions internal/contract/parser/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,9 +23,45 @@ func NewParser() Parser {

var reContracts = regexp.MustCompile(`\s*//\s*@(?P<kind>[a-z]+)(?:[\t ]+(?P<description>\[[\w\s\d,]+\]))?[\t ]+(?P<expr>[^$]+)`)

// Parse enrich the Contract with the clause if present in the given comment line
// ParseTypeContract enrich the contract with the clause if present in the given comment line
// @requires typeContract != nil
func (p Parser) ParseTypeContract(typeContract *contract.TypeContract, line string) error {
kind, description, expr, matched := parseLine(line)
if !matched {
return nil // nothing to do, there is no contract in this comment line
}

switch kind {
case "invariant":
if contract.Re4old.MatchString(expr) {
return fmt.Errorf("@old can not be used in @invariant expressions: %s", expr)
}

clause, err := p.parseEnsures(expr, description) // invariants are ensures that apply to all methods of the type
if err != nil {
return fmt.Errorf("invalid @invariant clause: %w", err)
}

typeContract.AddEnsures(clause)
case "import":
clause, err := p.parseImport(expr)
if err != nil {
return fmt.Errorf("invalid @import clause: %w", err)
}

typeContract.AddImport(clause)
case "ensures", "requires", "unmodified":
return fmt.Errorf("@%s can not be used in type contracts: %s", kind, expr)
default:
return errors.Errorf("unknown contract kind %s", kind)
}

return nil
}

// ParseFuncContract enrich the Contract with the clause if present in the given comment line
// @requires funcContract != nil
func (p Parser) Parse(funcContract *contract.FuncContract, line string) error {
func (p Parser) ParseFuncContract(funcContract *contract.FuncContract, line string) error {
kind, description, expr, matched := parseLine(line)
if !matched {
return nil // nothing to do, there is no contract in this comment line
Expand Down
2 changes: 1 addition & 1 deletion internal/contract/parser/parser_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ func TestParse(t *testing.T) {
}
p := NewParser()
for _, tc := range tests {
err := p.Parse(tc.contract, tc.line)
err := p.ParseFuncContract(tc.contract, tc.line)
if tc.err {
assert.NotEqual(t, err, nil, "line %s", tc.line)
}
Expand Down

0 comments on commit 2783e24

Please sign in to comment.