Skip to content

Commit

Permalink
feat(check): run on all ExecuteTemplate calls
Browse files Browse the repository at this point in the history
this reduces coupling between generated routes func and check
  • Loading branch information
crhntr authored Feb 5, 2025
1 parent b2cfd39 commit 06fae19
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 85 deletions.
2 changes: 1 addition & 1 deletion cmd/muxt/check.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ func checkCommand(workingDirectory string, args []string, stderr io.Writer) erro
return err
}
if err := muxt.CheckTemplates(workingDirectory, log.New(stderr, "", 0), config); err != nil {
return fmt.Errorf("fail")
return fmt.Errorf("fail: %s", err)
}
return nil
}
19 changes: 14 additions & 5 deletions cmd/muxt/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,18 @@ func Test_example(t *testing.T) {
require.NoError(t, os.Remove(filepath.FromSlash("../../example/template_routes.go")))

ctx := context.TODO()
cmd := exec.CommandContext(ctx, "go", "generate")
cmd.Dir = filepath.FromSlash("../../example")
cmd.Stderr = os.Stderr
cmd.Stdout = os.Stdout
require.NoError(t, cmd.Run())
t.Run("generate", func(t *testing.T) {
cmd := exec.CommandContext(ctx, "go", "generate")
cmd.Dir = filepath.FromSlash("../../example")
cmd.Stderr = os.Stdout
cmd.Stdout = os.Stdout
require.NoError(t, cmd.Run())
})
t.Run("check", func(t *testing.T) {
cmd := exec.CommandContext(ctx, "go", "run", ".", "-C", filepath.FromSlash("../../example"), "check", "--receiver-type", "Backend")
cmd.Dir = "."
cmd.Stderr = os.Stdout
cmd.Stdout = os.Stdout
require.NoError(t, cmd.Run())
})
}
2 changes: 1 addition & 1 deletion cmd/muxt/testdata/generate/check_types.txtar
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ stderr 'checking endpoint GET / Endpoint\(\)'
stderr 'ERROR argument 0 has type int32 expected int64'

-- template.gohtml --
{{define "GET / Endpoint()" }}Number {{.Number | .Square}}{{end}}
{{define "GET / Endpoint()" }}Number {{.Data.Number | .Data.Square}}{{end}}

-- go.mod --
module server
Expand Down
10 changes: 5 additions & 5 deletions example/index.gohtml
Original file line number Diff line number Diff line change
Expand Up @@ -50,18 +50,18 @@
{{- end}}

{{- define "GET /{$} List(ctx)" -}}
{{template "index.gohtml" .}}
{{template "index.gohtml" .Data}}
{{- end -}}

{{- define "GET /fruits/{id}/edit GetFormEditRow(id)" -}}
{{template "edit-row" .}}
{{template "edit-row" .Data}}
{{- end -}}

{{- define "PATCH /fruits/{id} SubmitFormEditRow(id, form)" }}
{{- if .Error -}}
{{template "edit-row" .}}
{{- if .Data.Error -}}
{{template "edit-row" .Data}}
{{- else -}}
{{template "view-row" .Row}}
{{template "view-row" .Data.Row}}
{{- end -}}
{{ end -}}

Expand Down
79 changes: 79 additions & 0 deletions example/template_routes_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
package main

import (
"context"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"

"github.com/crhntr/dom/domtest"
"github.com/stretchr/testify/require"
"golang.org/x/net/html/atom"
)

func TestRoutes(t *testing.T) {
//"PATCH /fruits/{id}"
t.Run("update fruit with id", func(t *testing.T) {
mux := http.NewServeMux()
fake := new(FakeBackend)

var (
fruitID int
form EditRow
)
fake.SubmitFormEditRowFunc = func(fruitIDArg int, formArg EditRow) EditRowPage {
fruitID, form = fruitIDArg, formArg
return EditRowPage{Row: Row{ID: 1, Name: "a", Value: 97}, Error: nil}
}

routes(mux, fake)

rec := httptest.NewRecorder()
req, err := http.NewRequest(http.MethodPatch, "/fruits/1", strings.NewReader(url.Values{"count": []string{"5"}}.Encode()))
require.NoError(t, err)
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
mux.ServeHTTP(rec, req)
res := rec.Result()
tBody := domtest.DocumentFragmentResponse(t, res, atom.Tbody)
t.Cleanup(func() {
if testing.Verbose() && t.Failed() {
t.Log(tBody)
}
})

require.Equal(t, http.StatusOK, res.StatusCode)
require.Equal(t, fruitID, 1)
require.Equal(t, form.Value, 5)

i := 0
for el := range tBody.QuerySelectorEach(`td`) {
switch i {
case 0:
require.Equal(t, "a", el.TextContent())
case 1:
require.Equal(t, "97", el.TextContent())
default:
t.Fatal(el)
}
i++
}
})

//"GET /fruits/{id}/edit"
//"GET /help"
//"GET /{$}"
}

type FakeBackend struct {
SubmitFormEditRowFunc func(fruitID int, form EditRow) EditRowPage
GetFormEditRowFunc func(fruitID int) EditRowPage
ListFunc func(_ context.Context) []Row
}

func (fb *FakeBackend) SubmitFormEditRow(fruitID int, form EditRow) EditRowPage {
return fb.SubmitFormEditRowFunc(fruitID, form)
}
func (fb *FakeBackend) GetFormEditRow(fruitID int) EditRowPage { return fb.GetFormEditRowFunc(fruitID) }
func (fb *FakeBackend) List(ctx context.Context) []Row { return fb.ListFunc(ctx) }
86 changes: 13 additions & 73 deletions internal/muxt/check.go
Original file line number Diff line number Diff line change
@@ -1,12 +1,10 @@
package muxt

import (
"cmp"
"errors"
"fmt"
"go/ast"
"go/token"
"go/types"
"log"

"golang.org/x/tools/go/packages"
Expand Down Expand Up @@ -35,7 +33,7 @@ func CheckTemplates(wd string, log *log.Logger, config RoutesFileConfiguration)

pl, err := packages.Load(&packages.Config{
Fset: imports.FileSet(),
Mode: packages.NeedModule | packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles,
Mode: packages.NeedModule | packages.NeedTypesInfo | packages.NeedName | packages.NeedFiles | packages.NeedTypes | packages.NeedSyntax | packages.NeedEmbedPatterns | packages.NeedEmbedFiles,
Dir: wd,
}, patterns...)
if err != nil {
Expand All @@ -52,83 +50,25 @@ func CheckTemplates(wd string, log *log.Logger, config RoutesFileConfiguration)
if err != nil {
return err
}
templates, err := Templates(ts)
if err != nil {
return err
}

receiverPkgPath := cmp.Or(config.ReceiverPackage, config.PackagePath, routesPkg.PkgPath)
receiverPkg, ok := imports.Package(receiverPkgPath)
if !ok {
return fmt.Errorf("could not determine receiver package %s", receiverPkgPath)
}
obj := receiverPkg.Types.Scope().Lookup(config.ReceiverType)
if obj == nil {
return fmt.Errorf("could not find receiver type %s in %s", config.ReceiverType, receiverPkgPath)
}
receiver, ok := obj.Type().(*types.Named)
if !ok {
return fmt.Errorf("expected receiver %s to be a named type", config.ReceiverType)
}
if receiver == nil {
return fmt.Errorf("could not find receiver %s in %s", config.ReceiverType, receiverPkgPath)
}
fns := templatetype.DefaultFunctions(routesPkg.Types)
fns = fns.Add(templatetype.Functions(fm))

var errs []error

for _, t := range templates {
var (
dataVar types.Type
dataVarPkg *types.Package
)

log.Println("checking endpoint", t.template.Name())

if t.fun != nil {
name := t.fun.Name
dataVarPkg = receiver.Obj().Pkg()
methodObj, _, _ := types.LookupFieldOrMethod(receiver, true, dataVarPkg, name)
if methodObj == nil {
o, ok := packageScopeFunc(receiver.Obj().Pkg(), t.fun)
if !ok {
return fmt.Errorf("failed to generate method %s", t.fun.Name)
}
methodObj = o
}
sig := methodObj.Type().(*types.Signature)
if sig.Results().Len() == 0 {
return fmt.Errorf("method for pattern %q has no results it should have one or two", t.name)
}
dataVar = sig.Results().At(0).Type()
if types.Identical(dataVar, types.Universe.Lookup("any").Type()) {
log.Printf("\troute method returns type any\n\n\t%s\n", sig)
for _, file := range routesPkg.Syntax {
for node := range ast.Preorder(file) {
templateName, dataType, ok := source.ExecuteTemplateArguments(node, routesPkg.TypesInfo, config.TemplatesVariable)
if !ok {
continue
}
} else {
netHTTP, ok := imports.Types("net/http")
if !ok {
return fmt.Errorf("net/http package not loaded")
log.Println("checking endpoint", templateName)
tree := ts.Lookup(templateName).Tree
if err := templatetype.Check(tree, dataType, routesPkg.Types, routesPkg.Fset, newForrest(ts), fns); err != nil {
log.Println("ERROR", err)
log.Println()
errs = append(errs, err)
}
dataVar = types.NewPointer(netHTTP.Scope().Lookup("Request").Type())
dataVarPkg = netHTTP
}
if dataVar == nil {
return fmt.Errorf("failed to find data var type for template %q", t.template.Name())
}

log.Println("\tfor data type", dataVar.String())
log.Println()

fns := templatetype.DefaultFunctions(routesPkg.Types)
fns = fns.Add(templatetype.Functions(fm))

if err := templatetype.Check(t.template.Tree, dataVar, dataVarPkg, routesPkg.Fset, newForrest(ts), fns); err != nil {
log.Println("ERROR", templatetype.Check(t.template.Tree, dataVar, dataVarPkg, routesPkg.Fset, newForrest(ts), fns))
log.Println()
errs = append(errs, err)
}
}

if len(errs) == 1 {
log.Printf("1 error")
return errs[0]
Expand Down
15 changes: 15 additions & 0 deletions internal/source/go.go
Original file line number Diff line number Diff line change
Expand Up @@ -264,3 +264,18 @@ func HasFieldWithName(list *ast.FieldList, name string) bool {
_, ok := FindFieldWithName(list, name)
return ok
}

func basicLiteralString(node ast.Node) (string, bool) {
name, ok := node.(*ast.BasicLit)
if !ok {
return "", false
}
if name.Kind != token.STRING {
return "", false
}
templateName, err := strconv.Unquote(name.Value)
if err != nil {
return "", false
}
return templateName, true
}
34 changes: 34 additions & 0 deletions internal/source/template.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@ import (
"golang.org/x/tools/go/packages"
)

const (
templateExecuteFunc = "ExecuteTemplate"
)

func Templates(workingDirectory, templatesVariable string, pkg *packages.Package) (*template.Template, Functions, error) {
funcTypeMap := DefaultFunctions(pkg.Types)
for _, tv := range IterateValueSpecs(pkg.Syntax) {
Expand Down Expand Up @@ -398,3 +402,33 @@ func (functions Functions) FindFunction(name string) (*types.Signature, bool) {
}
return fn, true
}

func ExecuteTemplateArguments(node ast.Node, info *types.Info, templatesVariableName string) (string, types.Type, bool) {
call, ok := node.(*ast.CallExpr)
if !ok {
return "", nil, false
}
if len(call.Args) != 3 {
return "", nil, false
}
sel, ok := call.Fun.(*ast.SelectorExpr)
if !ok {
return "", nil, false
}
if sel.Sel.Name != templateExecuteFunc {
return "", nil, false
}
templatesIdent, ok := sel.X.(*ast.Ident)
if !ok {
return "", nil, false
}
if templatesIdent.Name != templatesVariableName {
return "", nil, false
}
templateName, ok := basicLiteralString(call.Args[1])
if !ok {
return "", nil, false
}
dataVar := info.TypeOf(call.Args[2])
return templateName, dataVar, true
}

0 comments on commit 06fae19

Please sign in to comment.