This repository has been archived by the owner on Apr 1, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #16 from Code-Hex/fix/use-code-hex-gqlparser
use code-hex gqlparser
- Loading branch information
Showing
13 changed files
with
853 additions
and
176 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Large diffs are not rendered by default.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,43 @@ | ||
package graphql | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
|
||
"github.com/Code-Hex/gqlparser/v2/ast" | ||
) | ||
|
||
type OperationContext struct { | ||
RawQuery string | ||
Variables map[string]interface{} | ||
OperationName string | ||
Doc *ast.QueryDocument | ||
|
||
Operation *ast.OperationDefinition | ||
} | ||
|
||
func (c *OperationContext) Validate(ctx context.Context) error { | ||
if c.Doc == nil { | ||
return errors.New("field 'Doc'is required") | ||
} | ||
if c.RawQuery == "" { | ||
return errors.New("field 'RawQuery' is required") | ||
} | ||
if c.Variables == nil { | ||
c.Variables = make(map[string]interface{}) | ||
} | ||
return nil | ||
} | ||
|
||
type operationCtx struct{} | ||
|
||
func GetOperationContext(ctx context.Context) *OperationContext { | ||
if val, ok := ctx.Value(operationCtx{}).(*OperationContext); ok && val != nil { | ||
return val | ||
} | ||
panic("missing operation context") | ||
} | ||
|
||
func WithOperationContext(ctx context.Context, rc *OperationContext) context.Context { | ||
return context.WithValue(ctx, operationCtx{}, rc) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,20 @@ | ||
package graphql | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/google/go-cmp/cmp" | ||
) | ||
|
||
func TestGetOperationContext(t *testing.T) { | ||
rc := &OperationContext{} | ||
|
||
ctx := WithOperationContext(context.Background(), rc) | ||
|
||
got := GetOperationContext(ctx) | ||
|
||
if diff := cmp.Diff(rc, got); diff != "" { | ||
t.Errorf("(-want, +got)\n%s", diff) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
package graphql | ||
|
||
import ( | ||
"fmt" | ||
|
||
"github.com/Code-Hex/gqlparser/v2/ast" | ||
) | ||
|
||
// CollectFields returns the set of fields from an ast.SelectionSet where all collected fields satisfy at least one of the GraphQL types | ||
// passed through satisfies. Providing an empty or nil slice for satisfies will return collect all fields regardless of fragment | ||
// type conditions. | ||
func CollectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string) []CollectedField { | ||
return collectFields(reqCtx, selSet, satisfies, map[string]bool{}) | ||
} | ||
|
||
func collectFields(reqCtx *OperationContext, selSet ast.SelectionSet, satisfies []string, visited map[string]bool) []CollectedField { | ||
groupedFields := make([]CollectedField, 0, len(selSet)) | ||
|
||
for _, sel := range selSet { | ||
switch sel := sel.(type) { | ||
case *ast.Field: | ||
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) { | ||
continue | ||
} | ||
f := getOrCreateAndAppendField(&groupedFields, sel.Name, sel.Alias, sel.ObjectDefinition, func() CollectedField { | ||
return CollectedField{Field: sel} | ||
}) | ||
|
||
f.Selections = append(f.Selections, sel.SelectionSet...) | ||
case *ast.InlineFragment: | ||
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) { | ||
continue | ||
} | ||
if len(satisfies) > 0 && !instanceOf(sel.TypeCondition, satisfies) { | ||
continue | ||
} | ||
for _, childField := range collectFields(reqCtx, sel.SelectionSet, satisfies, visited) { | ||
f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField }) | ||
f.Selections = append(f.Selections, childField.Selections...) | ||
} | ||
|
||
case *ast.FragmentSpread: | ||
if !shouldIncludeNode(sel.Directives, reqCtx.Variables) { | ||
continue | ||
} | ||
fragmentName := sel.Name | ||
if _, seen := visited[fragmentName]; seen { | ||
continue | ||
} | ||
visited[fragmentName] = true | ||
|
||
fragment := reqCtx.Doc.Fragments.ForName(fragmentName) | ||
if fragment == nil { | ||
// should never happen, validator has already run | ||
panic(fmt.Errorf("missing fragment %s", fragmentName)) | ||
} | ||
|
||
if len(satisfies) > 0 && !instanceOf(fragment.TypeCondition, satisfies) { | ||
continue | ||
} | ||
|
||
for _, childField := range collectFields(reqCtx, fragment.SelectionSet, satisfies, visited) { | ||
f := getOrCreateAndAppendField(&groupedFields, childField.Name, childField.Alias, childField.ObjectDefinition, func() CollectedField { return childField }) | ||
f.Selections = append(f.Selections, childField.Selections...) | ||
} | ||
default: | ||
panic(fmt.Errorf("unsupported %T", sel)) | ||
} | ||
} | ||
|
||
return groupedFields | ||
} | ||
|
||
type CollectedField struct { | ||
*ast.Field | ||
|
||
Selections ast.SelectionSet | ||
} | ||
|
||
func instanceOf(val string, satisfies []string) bool { | ||
for _, s := range satisfies { | ||
if val == s { | ||
return true | ||
} | ||
} | ||
return false | ||
} | ||
|
||
func getOrCreateAndAppendField(c *[]CollectedField, name string, alias string, objectDefinition *ast.Definition, creator func() CollectedField) *CollectedField { | ||
for i, cf := range *c { | ||
if cf.Name == name && cf.Alias == alias && (cf.ObjectDefinition == objectDefinition || (cf.ObjectDefinition != nil && objectDefinition != nil && cf.ObjectDefinition.Name == objectDefinition.Name)) { | ||
return &(*c)[i] | ||
} | ||
} | ||
|
||
f := creator() | ||
|
||
*c = append(*c, f) | ||
return &(*c)[len(*c)-1] | ||
} | ||
|
||
func shouldIncludeNode(directives ast.DirectiveList, variables map[string]interface{}) bool { | ||
if len(directives) == 0 { | ||
return true | ||
} | ||
|
||
skip, include := false, true | ||
|
||
if d := directives.ForName("skip"); d != nil { | ||
skip = resolveIfArgument(d, variables) | ||
} | ||
|
||
if d := directives.ForName("include"); d != nil { | ||
include = resolveIfArgument(d, variables) | ||
} | ||
|
||
return !skip && include | ||
} | ||
|
||
func resolveIfArgument(d *ast.Directive, variables map[string]interface{}) bool { | ||
arg := d.Arguments.ForName("if") | ||
if arg == nil { | ||
panic(fmt.Sprintf("%s: argument 'if' not defined", d.Name)) | ||
} | ||
value, err := arg.Value.Value(variables) | ||
if err != nil { | ||
panic(err) | ||
} | ||
ret, ok := value.(bool) | ||
if !ok { | ||
panic(fmt.Sprintf("%s: argument 'if' is not a boolean", d.Name)) | ||
} | ||
return ret | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,104 @@ | ||
package graphql | ||
|
||
import ( | ||
"io" | ||
"sync" | ||
|
||
gql "github.com/99designs/gqlgen/graphql" | ||
) | ||
|
||
type FieldSet struct { | ||
fields []CollectedField | ||
Values []gql.Marshaler | ||
delayed []delayedResult | ||
} | ||
|
||
type delayedResult struct { | ||
i int | ||
f func() gql.Marshaler | ||
} | ||
|
||
func NewFieldSet(fields []CollectedField) *FieldSet { | ||
return &FieldSet{ | ||
fields: fields, | ||
Values: make([]gql.Marshaler, len(fields)), | ||
} | ||
} | ||
|
||
func (m *FieldSet) Concurrently(i int, f func() gql.Marshaler) { | ||
m.delayed = append(m.delayed, delayedResult{i: i, f: f}) | ||
} | ||
|
||
func (m *FieldSet) Dispatch() { | ||
if len(m.delayed) == 1 { | ||
// only one concurrent task, no need to spawn a goroutine or deal create waitgroups | ||
d := m.delayed[0] | ||
m.Values[d.i] = d.f() | ||
} else if len(m.delayed) > 1 { | ||
// more than one concurrent task, use the main goroutine to do one, only spawn goroutines for the others | ||
|
||
var wg sync.WaitGroup | ||
for _, d := range m.delayed[1:] { | ||
wg.Add(1) | ||
go func(d delayedResult) { | ||
m.Values[d.i] = d.f() | ||
wg.Done() | ||
}(d) | ||
} | ||
|
||
m.Values[m.delayed[0].i] = m.delayed[0].f() | ||
wg.Wait() | ||
} | ||
} | ||
|
||
var openBrace = []byte(`{`) | ||
var closeBrace = []byte(`}`) | ||
var colon = []byte(`:`) | ||
var comma = []byte(`,`) | ||
|
||
func (m *FieldSet) MarshalGQL(writer io.Writer) { | ||
writer.Write(openBrace) | ||
for i, field := range m.fields { | ||
if i != 0 { | ||
writer.Write(comma) | ||
} | ||
writeQuotedString(writer, field.Alias) | ||
writer.Write(colon) | ||
m.Values[i].MarshalGQL(writer) | ||
} | ||
writer.Write(closeBrace) | ||
} | ||
|
||
const encodeHex = "0123456789ABCDEF" | ||
|
||
func writeQuotedString(w io.Writer, s string) { | ||
start := 0 | ||
io.WriteString(w, `"`) | ||
|
||
for i, c := range s { | ||
if c < 0x20 || c == '\\' || c == '"' { | ||
io.WriteString(w, s[start:i]) | ||
|
||
switch c { | ||
case '\t': | ||
io.WriteString(w, `\t`) | ||
case '\r': | ||
io.WriteString(w, `\r`) | ||
case '\n': | ||
io.WriteString(w, `\n`) | ||
case '\\': | ||
io.WriteString(w, `\\`) | ||
case '"': | ||
io.WriteString(w, `\"`) | ||
default: | ||
io.WriteString(w, `\u00`) | ||
w.Write([]byte{encodeHex[c>>4], encodeHex[c&0xf]}) | ||
} | ||
|
||
start = i + 1 | ||
} | ||
} | ||
|
||
io.WriteString(w, s[start:]) | ||
io.WriteString(w, `"`) | ||
} |
Oops, something went wrong.