Skip to content

Commit

Permalink
pkg/parse: Replace internal use of reflection with generics.
Browse files Browse the repository at this point in the history
Also add a benchmark that simply runs all the parse tests, which has shown a
moderate speedup from the change:

Before:
BenchmarkParse-8            7588            171061 ns/op

After:
BenchmarkParse-8            8480            150677 ns/op
  • Loading branch information
xiaq committed Jul 29, 2024
1 parent cdb80cd commit 52b4689
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 46 deletions.
68 changes: 33 additions & 35 deletions pkg/parse/parse.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func Parse(src Source, cfg Config) (Tree, error) {
// unpacked with [UnpackErrors].
func ParseAs(src Source, n Node, cfg Config) error {
// Build a parser over the source's name and code; warnings go to the writer
// supplied in cfg.
ps := &parser{srcName: src.Name, src: src.Code, warn: cfg.WarningWriter}
// NOTE(review): this span comes from a rendered diff, so the method-style
// call and the generic free-function call to parse both appear here;
// presumably only one of them exists in the real file - confirm.
ps.parse(n)
parse(ps, n)
// Tell the parser that parsing is done, then pack the errors it has
// collected into the returned error value.
ps.done()
return diag.PackErrors(ps.errors)
}
Expand Down Expand Up @@ -91,7 +91,7 @@ type Chunk struct {
func (bn *Chunk) parse(ps *parser) {
bn.parseSeps(ps)
for startsPipeline(ps.peek()) {
ps.parse(&Pipeline{}).addTo(&bn.Pipelines, bn)
parse(ps, &Pipeline{}).addTo(&bn.Pipelines, bn)
if bn.parseSeps(ps) == 0 {
break
}
Expand Down Expand Up @@ -130,14 +130,14 @@ type Pipeline struct {
}

func (pn *Pipeline) parse(ps *parser) {
ps.parse(&Form{}).addTo(&pn.Forms, pn)
parse(ps, &Form{}).addTo(&pn.Forms, pn)
for parseSep(pn, ps, '|') {
parseSpacesAndNewlines(pn, ps)
if !startsForm(ps.peek()) {
ps.error(errShouldBeForm)
return
}
ps.parse(&Form{}).addTo(&pn.Forms, pn)
parse(ps, &Form{}).addTo(&pn.Forms, pn)
}
parseSpaces(pn, ps)
if ps.peek() == '&' {
Expand All @@ -162,7 +162,7 @@ type Form struct {
}

func (fn *Form) parse(ps *parser) {
ps.parse(&Compound{ExprCtx: CmdExpr}).addAs(&fn.Head, fn)
parse(ps, &Compound{ExprCtx: CmdExpr}).addAs(&fn.Head, fn)
parseSpaces(fn, ps)

for {
Expand All @@ -176,18 +176,19 @@ func (fn *Form) parse(ps *parser) {
// background indicator
return
}
ps.parse(&MapPair{}).addTo(&fn.Opts, fn)
parse(ps, &MapPair{}).addTo(&fn.Opts, fn)
case startsCompound(r, NormalExpr):
cn := &Compound{}
ps.parse(cn)
parse(ps, cn)
if isRedirSign(ps.peek()) {
// Redir
ps.parse(&Redir{Left: cn}).addTo(&fn.Redirs, fn)
parse(ps, &Redir{Left: cn}).addTo(&fn.Redirs, fn)
} else {
parsed{cn}.addTo(&fn.Args, fn)
fn.Args = append(fn.Args, cn)
addChild(fn, cn)
}
case isRedirSign(r):
ps.parse(&Redir{}).addTo(&fn.Redirs, fn)
parse(ps, &Redir{}).addTo(&fn.Redirs, fn)
default:
return
}
Expand Down Expand Up @@ -237,7 +238,7 @@ func (rn *Redir) parse(ps *parser) {
if parseSep(rn, ps, '&') {
rn.RightIsFd = true
}
ps.parse(&Compound{}).addAs(&rn.Right, rn)
parse(ps, &Compound{}).addAs(&rn.Right, rn)
if len(rn.Right.Indexings) == 0 {
if rn.RightIsFd {
ps.error(errShouldBeFD)
Expand Down Expand Up @@ -278,9 +279,9 @@ func (qn *Filter) parse(ps *parser) {
r := ps.peek()
switch {
case r == '&':
ps.parse(&MapPair{}).addTo(&qn.Opts, qn)
parse(ps, &MapPair{}).addTo(&qn.Opts, qn)
case startsCompound(r, NormalExpr):
ps.parse(&Compound{}).addTo(&qn.Args, qn)
parse(ps, &Compound{}).addTo(&qn.Args, qn)
default:
return
}
Expand Down Expand Up @@ -320,7 +321,7 @@ const (
// parse fills in cn by first calling cn.tilde and then parsing Indexing nodes
// into cn.Indexings for as long as the next rune can start an indexing in
// cn's expression context.
func (cn *Compound) parse(ps *parser) {
// Presumably handles a leading "~"; the body of tilde is only partly
// visible from here - confirm.
cn.tilde(ps)
for startsIndexing(ps.peek(), cn.ExprCtx) {
// Each Indexing inherits the compound's expression context and is
// appended to cn.Indexings with cn as its parent.
// NOTE(review): the two lines below are the method-style and generic
// forms of the same call as rendered in a diff view.
ps.parse(&Indexing{ExprCtx: cn.ExprCtx}).addTo(&cn.Indexings, cn)
parse(ps, &Indexing{ExprCtx: cn.ExprCtx}).addTo(&cn.Indexings, cn)
}
}

Expand All @@ -334,8 +335,10 @@ func (cn *Compound) tilde(ps *parser) {
sourceText: "~", parent: nil, children: nil}
pn := &Primary{node: base, Type: Tilde, Value: "~"}
in := &Indexing{node: base}
parsed{pn}.addAs(&in.Head, in)
parsed{in}.addTo(&cn.Indexings, cn)
in.Head = pn
addChild(in, pn)
cn.Indexings = append(cn.Indexings, in)
addChild(cn, in)
}
}

Expand All @@ -352,13 +355,13 @@ type Indexing struct {
}

func (in *Indexing) parse(ps *parser) {
ps.parse(&Primary{ExprCtx: in.ExprCtx}).addAs(&in.Head, in)
parse(ps, &Primary{ExprCtx: in.ExprCtx}).addAs(&in.Head, in)
for parseSep(in, ps, '[') {
if !startsArray(ps.peek()) && ps.peek() != ']' {
ps.error(errShouldBeArray)
}

ps.parse(&Array{}).addTo(&in.Indices, in)
parse(ps, &Array{}).addTo(&in.Indices, in)

if !parseSep(in, ps, ']') {
ps.error(errShouldBeRBracket)
Expand Down Expand Up @@ -386,7 +389,7 @@ func (sn *Array) parse(ps *parser) {

parseSep()
for startsCompound(ps.peek(), NormalExpr) {
ps.parse(&Compound{}).addTo(&sn.Compounds, sn)
parse(ps, &Compound{}).addTo(&sn.Compounds, sn)
parseSep()
}
}
Expand Down Expand Up @@ -712,7 +715,7 @@ func (pn *Primary) exitusCapture(ps *parser) {

pn.Type = ExceptionCapture

ps.parse(&Chunk{}).addAs(&pn.Chunk, pn)
parse(ps, &Chunk{}).addAs(&pn.Chunk, pn)

if !parseSep(pn, ps, ')') {
ps.error(errShouldBeRParen)
Expand All @@ -723,7 +726,7 @@ func (pn *Primary) outputCapture(ps *parser) {
pn.Type = OutputCapture
parseSep(pn, ps, '(')

ps.parse(&Chunk{}).addAs(&pn.Chunk, pn)
parse(ps, &Chunk{}).addAs(&pn.Chunk, pn)

if !parseSep(pn, ps, ')') {
ps.error(errShouldBeRParen)
Expand Down Expand Up @@ -754,9 +757,9 @@ items:
break items
}
ps.backup()
ps.parse(&MapPair{}).addTo(&pn.MapPairs, pn)
parse(ps, &MapPair{}).addTo(&pn.MapPairs, pn)
case startsCompound(r, NormalExpr):
ps.parse(&Compound{}).addTo(&pn.Elements, pn)
parse(ps, &Compound{}).addTo(&pn.Elements, pn)
default:
break items
}
Expand Down Expand Up @@ -788,9 +791,9 @@ func (pn *Primary) lambda(ps *parser) {
r := ps.peek()
switch {
case r == '&':
ps.parse(&MapPair{}).addTo(&pn.MapPairs, pn)
parse(ps, &MapPair{}).addTo(&pn.MapPairs, pn)
case startsCompound(r, NormalExpr):
ps.parse(&Compound{}).addTo(&pn.Elements, pn)
parse(ps, &Compound{}).addTo(&pn.Elements, pn)
default:
break items
}
Expand All @@ -800,7 +803,7 @@ func (pn *Primary) lambda(ps *parser) {
ps.error(errShouldBePipe)
}
}
ps.parse(&Chunk{}).addAs(&pn.Chunk, pn)
parse(ps, &Chunk{}).addAs(&pn.Chunk, pn)
if !parseSep(pn, ps, '}') {
ps.error(errShouldBeRBrace)
}
Expand All @@ -820,15 +823,15 @@ func (pn *Primary) lbrace(ps *parser) {

// TODO(xiaq): The compound can be empty, which allows us to parse {,foo}.
// Allowing compounds to be empty can be fragile in other cases.
ps.parse(&Compound{ExprCtx: BracedElemExpr}).addTo(&pn.Braced, pn)
parse(ps, &Compound{ExprCtx: BracedElemExpr}).addTo(&pn.Braced, pn)

for isBracedSep(ps.peek()) {
parseSpacesAndNewlines(pn, ps)
// optional, so ignore the return value
parseSep(pn, ps, ',')
parseSpacesAndNewlines(pn, ps)

ps.parse(&Compound{ExprCtx: BracedElemExpr}).addTo(&pn.Braced, pn)
parse(ps, &Compound{ExprCtx: BracedElemExpr}).addTo(&pn.Braced, pn)
}
if !parseSep(pn, ps, '}') {
ps.error(errShouldBeBraceSepOrRBracket)
Expand Down Expand Up @@ -883,15 +886,15 @@ type MapPair struct {
// parse fills in mpn: it consumes an '&' separator, parses the key as a
// Compound in LHSExpr context, and, if an '=' separator follows, parses the
// value as another Compound.
func (mpn *MapPair) parse(ps *parser) {
// The result of parseSep is ignored here.
parseSep(mpn, ps, '&')

// NOTE(review): wherever a ps.parse(...) line is directly followed by a
// parse(ps, ...) line in this function, they are the method-style and
// generic forms of the same call as rendered in a diff view.
ps.parse(&Compound{ExprCtx: LHSExpr}).addAs(&mpn.Key, mpn)
parse(ps, &Compound{ExprCtx: LHSExpr}).addAs(&mpn.Key, mpn)
// A key compound with no indexings is reported as an error; parsing still
// continues afterwards.
if len(mpn.Key.Indexings) == 0 {
ps.error(errShouldBeCompound)
}

if parseSep(mpn, ps, '=') {
// Spaces and newlines are allowed between '=' and the value.
parseSpacesAndNewlines(mpn, ps)
// Parse value part. It can be empty.
ps.parse(&Compound{}).addAs(&mpn.Value, mpn)
parse(ps, &Compound{}).addAs(&mpn.Value, mpn)
}
}

Expand Down Expand Up @@ -997,8 +1000,3 @@ func IsInlineWhitespace(r rune) bool {
func IsWhitespace(r rune) bool {
	// Anything that counts as inline whitespace is whitespace.
	if IsInlineWhitespace(r) {
		return true
	}
	// Otherwise only carriage return and newline qualify.
	switch r {
	case '\r', '\n':
		return true
	default:
		return false
	}
}

// addChild links ch under p: it passes ch to addChild on the value returned
// by p.n(), and sets the parent field of the value returned by ch.n() to p.
func addChild(p Node, ch Node) {
p.n().addChild(ch)
ch.n().parent = p
}
8 changes: 8 additions & 0 deletions pkg/parse/parse_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -678,3 +678,11 @@ func TestParse_ReturnsTreeContainingSourceFromArgument(t *testing.T) {
t.Errorf("tree.Source = %v, want %v", tree.Source, src)
}
}

// BenchmarkParse parses every entry of testCases once per benchmark
// iteration. The error returned by ParseAs is discarded; only the running
// time is of interest.
//
// NOTE(review): test.node is handed to ParseAs again on every iteration; if
// the entries hold pointers to nodes, later iterations parse into nodes that
// were already populated - confirm against the definition of testCases.
func BenchmarkParse(b *testing.B) {
	for remaining := b.N; remaining > 0; remaining-- {
		for idx := range testCases {
			tc := testCases[idx]
			_ = ParseAs(SourceForTest(tc.code), tc.node, Config{})
		}
	}
}
24 changes: 13 additions & 11 deletions pkg/parse/parser.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@ import (
"errors"
"fmt"
"io"
"reflect"
"strings"
"unicode/utf8"

Expand All @@ -32,31 +31,34 @@ type ErrorTag struct{}

// ErrorTag returns the fixed tag string "parse error".
func (ErrorTag) ErrorTag() string {
	const tag = "parse error"
	return tag
}

// parse calls n.parse(ps) and records on the node the stretch of source that
// call covered: From is the parser position before the call, To is the
// position after it, and sourceText is the corresponding slice of ps.src. The
// node is returned wrapped in a parsed value.
//
// NOTE(review): this span comes from a rendered diff, so a method-style
// signature and a generic free-function signature (and their respective
// return statements) both appear; presumably only one pair exists in the real
// file - confirm.
func (ps *parser) parse(n Node) parsed {
func parse[N Node](ps *parser, n N) parsed[N] {
// Remember where this node starts before the node's own parse method
// advances the parser.
begin := ps.pos
n.n().From = begin
n.parse(ps)
n.n().To = ps.pos
n.n().sourceText = ps.src[begin:ps.pos]
return parsed{n}
return parsed[N]{n}
}

// parsed wraps a single node in its field n.
//
// NOTE(review): this span comes from a rendered diff, so a non-generic
// declaration (field of type Node) and a generic one (field of type N) both
// appear; presumably only one exists in the real file - confirm.
type parsed struct {
n Node
type parsed[N Node] struct {
n N
}

// addAs stores the wrapped node into the location ptr points to and then
// registers the node as a child of parent via addChild.
//
// NOTE(review): this span comes from a rendered diff, so a reflection-based
// form (ptr any) and a generic form (ptr *N) of the method both appear;
// presumably only one exists in the real file - confirm.
func (p parsed) addAs(ptr any, parent Node) {
dst := reflect.ValueOf(ptr).Elem()
dst.Set(reflect.ValueOf(p.n)) // *ptr = p.n
func (p parsed[N]) addAs(ptr *N, parent Node) {
*ptr = p.n
addChild(parent, p.n)
}

// addTo appends the wrapped node to the slice ptr points to and then
// registers the node as a child of parent via addChild.
//
// NOTE(review): this span comes from a rendered diff, so a reflection-based
// form (ptr any) and a generic form (ptr *[]N) of the method both appear;
// presumably only one exists in the real file - confirm.
func (p parsed) addTo(ptr any, parent Node) {
dst := reflect.ValueOf(ptr).Elem()
dst.Set(reflect.Append(dst, reflect.ValueOf(p.n))) // *ptr = append(*ptr, n)
func (p parsed[N]) addTo(ptr *[]N, parent Node) {
*ptr = append(*ptr, p.n)
addChild(parent, p.n)
}

// addChild links ch under p in both directions: ch is added to the children
// kept by p's underlying node, and p is recorded as ch's parent.
func addChild(p, ch Node) {
	// Downward link: parent -> child.
	parentNode := p.n()
	parentNode.addChild(ch)

	// Upward link: child -> parent.
	childNode := ch.n()
	childNode.parent = p
}

// Tells the parser that parsing is done.
func (ps *parser) done() {
if ps.pos != len(ps.src) {
Expand Down

0 comments on commit 52b4689

Please sign in to comment.