Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rework the graph walking functions with functional options #42

Merged
merged 3 commits into from
Jul 22, 2019
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
130 changes: 87 additions & 43 deletions merkledag.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func FetchGraph(ctx context.Context, root cid.Cid, serv ipld.DAGService) error {
}

// FetchGraphWithDepthLimit fetches all nodes that are children to the given
// node down to the given depth. maxDetph=0 means "only fetch root",
// node down to the given depth. maxDepth=0 means "only fetch root",
// maxDepth=1 means "fetch root and its direct children" and so on...
// maxDepth=-1 means unlimited.
func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, serv ipld.DAGService) error {
Expand Down Expand Up @@ -195,9 +195,10 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
return false
}

// If we have a ProgressTracker, we wrap the visit function to handle it
v, _ := ctx.Value(progressContextKey).(*ProgressTracker)
if v == nil {
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visit)
return WalkDepth(ctx, GetLinksDirect(ng), root, visit, Concurrent(), WithRoot())
}

visitProgress := func(c cid.Cid, depth int) bool {
Expand All @@ -207,7 +208,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
}
return false
}
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress)
return WalkDepth(ctx, GetLinksDirect(ng), root, visitProgress, Concurrent(), WithRoot())
}

// GetMany gets many nodes from the DAG at once.
Expand Down Expand Up @@ -281,21 +282,77 @@ func GetLinksWithDAG(ng ipld.NodeGetter) GetLinks {
}
}

// defaultConcurrentFetch is the default maximum number of concurrent fetches
// that 'fetchNodes' will start at a time
const defaultConcurrentFetch = 32

// WalkOptions represent the parameters of a graph walking algorithm
type WalkOptions struct {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's make this private (or at least make the fields private). Yeah, I know I don't usually do that but that was a mistake.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point.

WithRoot bool
IgnoreBadBlock bool
Concurrency int
}

// WalkOption is a setter for WalkOptions
type WalkOption func(*WalkOptions)

// WithRoot is a WalkOption indicating that the root node should be visited
func WithRoot() WalkOption {
return func(walkOptions *WalkOptions) {
walkOptions.WithRoot = true
}
}

// Concurrent is a WalkOption indicating that node fetching should be done in
// parallel, with the default concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrent() WalkOption {
return func(walkOptions *WalkOptions) {
walkOptions.Concurrency = defaultConcurrentFetch
}
}

// Concurrency is a WalkOption indicating that node fetching should be done in
// parallel, with a specific concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrency(worker int) WalkOption {
return func(walkOptions *WalkOptions) {
walkOptions.Concurrency = worker
}
}

// WalkGraph will walk the dag in order (depth first) starting at the given root.
func Walk(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid) bool) error {
func Walk(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool, options ...WalkOption) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}

return WalkDepth(ctx, getLinks, root, 0, visitDepth)
return WalkDepth(ctx, getLinks, c, visitDepth, options...)
}

// WalkDepth walks the dag starting at the given root and passes the current
// depth to a given visit function. The visit function can be used to limit DAG
// exploration.
func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool) error {
if !visit(root, depth) {
return nil
func WalkDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid, int) bool, options ...WalkOption) error {
opts := &WalkOptions{}
for _, opt := range options {
opt(opts)
}

if opts.Concurrency > 1 {
return parallelWalkDepth(ctx, getLinks, c, visit, opts)
} else {
return sequentialWalkDepth(ctx, getLinks, c, 0, visit, opts)
}
}

func sequentialWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool, options *WalkOptions) error {
if depth != 0 || options.WithRoot {
if !visit(root, depth) {
return nil
}
}

links, err := getLinks(ctx, root)
Expand All @@ -304,7 +361,7 @@ func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int,
}

for _, lnk := range links {
if err := WalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit); err != nil {
if err := sequentialWalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit, options); err != nil {
return err
}
}
Expand Down Expand Up @@ -337,27 +394,7 @@ func (p *ProgressTracker) Value() int {
return p.Total
}

// FetchGraphConcurrency is total number of concurrent fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 32

// WalkParallel is equivalent to Walk *except* that it explores multiple paths
// in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallel(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}

return WalkParallelDepth(ctx, getLinks, c, 0, visitDepth)
}

// WalkParallelDepth is equivalent to WalkDepth *except* that it fetches
// children in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startDepth int, visit func(cid.Cid, int) bool) error {
func parallelWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid, int) bool, options *WalkOptions) error {
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

note: I kept this function at the same place as before for an easier review, but it would make sense to move it just after sequentialWalkDepth.

type cidDepth struct {
cid cid.Cid
depth int
Expand All @@ -372,24 +409,31 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
out := make(chan *linksDepth)
done := make(chan struct{})

var setlk sync.Mutex
var visitlk sync.Mutex
var wg sync.WaitGroup

errChan := make(chan error)
fetchersCtx, cancel := context.WithCancel(ctx)
defer wg.Wait()
defer cancel()
for i := 0; i < FetchGraphConcurrency; i++ {
for i := 0; i < options.Concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for cdepth := range feed {
ci := cdepth.cid
depth := cdepth.depth

setlk.Lock()
shouldVisit := visit(ci, depth)
setlk.Unlock()
var shouldVisit bool

// bypass the root if needed
if depth != 0 || options.WithRoot {
visitlk.Lock()
shouldVisit = visit(ci, depth)
visitlk.Unlock()
} else {
shouldVisit = true
}

if shouldVisit {
links, err := getLinks(ctx, ci)
Expand Down Expand Up @@ -422,20 +466,21 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
defer close(feed)

send := feed
var todobuffer []*cidDepth
var todoQueue []*cidDepth
var inProgress int

next := &cidDepth{
cid: c,
depth: startDepth,
cid: root,
depth: 0,
}

for {
select {
case send <- next:
inProgress++
if len(todobuffer) > 0 {
next = todobuffer[0]
todobuffer = todobuffer[1:]
if len(todoQueue) > 0 {
next = todoQueue[0]
todoQueue = todoQueue[1:]
} else {
next = nil
send = nil
Expand All @@ -456,7 +501,7 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
next = cd
send = feed
} else {
todobuffer = append(todobuffer, cd)
todoQueue = append(todoQueue, cd)
}
}
case err := <-errChan:
Expand All @@ -466,7 +511,6 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
return ctx.Err()
}
}

}

var _ ipld.LinkGetter = &dagService{}
Expand Down
11 changes: 7 additions & 4 deletions merkledag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,11 @@ func makeTestDAG(t *testing.T, read io.Reader, ds ipld.DAGService) ipld.Node {
// Add a root referencing all created nodes
root := NodeWithData(nil)
for _, n := range nodes {
root.AddNodeLink(n.Cid().String(), n)
err := ds.Add(ctx, n)
err := root.AddNodeLink(n.Cid().String(), n)
if err != nil {
t.Fatal(err)
}
err = ds.Add(ctx, n)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -383,7 +386,7 @@ func TestFetchGraphWithDepthLimit(t *testing.T) {

}

err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF)
err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), visitF, WithRoot())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -736,7 +739,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) {
}

cset := cid.NewSet()
err = WalkParallel(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
err = Walk(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
if err == nil {
t.Fatal("this should have failed")
}
Expand Down