worker.go
package crawler

import (
	"context"
	"errors"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"sync"

	"golang.org/x/sync/errgroup"
)
// CrawlFunc is the type of the function called for each webpage visited by
// Crawl. The incoming url specifies which URL was fetched, while res contains
// the response of the fetched URL if it was successful. If the fetch failed,
// the incoming error will specify the reason and res will be nil.
//
// Returning ErrSkipURL will avoid queuing up the resource's links to be
// crawled.
//
// Returning any other error from the function will immediately stop the crawl.
type CrawlFunc func(url string, res *Response, err error) error

// ErrSkipURL can be returned by a CrawlFunc to avoid crawling the links from
// the given URL.
var ErrSkipURL = errors.New("skip URL")
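// A minimal sketch of a CrawlFunc (hypothetical, for illustration only): it
// logs each result and skips the links of anything that failed to fetch.
//
//	fn := func(url string, res *Response, err error) error {
//		if err != nil {
//			log.Printf("fetch %s failed: %v", url, err)
//			return ErrSkipURL // nothing to follow on a failed fetch
//		}
//		log.Printf("fetched %s", url)
//		return nil
//	}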
// Runner defines the interface required to run a crawl.
type Runner interface {
	Run(context.Context, Queue) error
}

// Worker is used to run a crawl across one or more goroutines.
type Worker struct {
	client     *http.Client
	fn         CrawlFunc
	checkFetch CheckFetchStack
	maxRedirs  int
	goroutines int
}
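// Worker's Run method implements Runner; a compile-time assertion makes the
// relationship explicit.
var _ Runner = (*Worker)(nil)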
// NewWorker initialises a Worker that calls fn for every page crawled.
func NewWorker(fn CrawlFunc, opts ...Option) (*Worker, error) {
	o := options{
		transport: http.DefaultTransport,
	}
	for _, opt := range opts {
		if err := opt(&o); err != nil {
			return nil, err
		}
	}
	if o.goroutines == 0 {
		o.goroutines = 1
	}
	// the mutex serialises calls to fn, so a CrawlFunc never needs to be
	// safe for concurrent use even when the worker runs multiple goroutines
	var mut sync.Mutex
	return &Worker{
		client: &http.Client{
			Transport:     o.transport,
			CheckRedirect: skipRedirects,
		},
		checkFetch: CheckFetchStack(o.checkFetch),
		fn: func(url string, res *Response, err error) error {
			mut.Lock()
			defer mut.Unlock()
			return fn(url, res, err)
		},
		goroutines: o.goroutines,
	}, nil
}
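// A minimal usage sketch, assuming a hypothetical NewQueue constructor for
// some Queue implementation, seeded with a start URL via NewRequest:
//
//	w, err := NewWorker(fn)
//	if err != nil {
//		log.Fatal(err)
//	}
//	q := NewQueue() // hypothetical; any Queue implementation works
//	req, _ := NewRequest("https://example.com/")
//	q.PushBack(req)
//	if err := w.Run(context.Background(), q); err != nil {
//		log.Fatal(err)
//	}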
func (w *Worker) run(ctx context.Context, q Queue) error {
	for {
		if err := ctx.Err(); err != nil {
			return err
		}
		req, err := q.PopFront()
		if err != nil {
			return err
		}
		// a nil request means the queue has been drained
		if req == nil {
			return nil
		}
		if !w.checkFetch.CheckFetch(req) {
			req.Finish()
			continue
		}
		res, err := fetch(ctx, w.client, req)
		if err := ctx.Err(); err != nil {
			return err
		}
		// call the CrawlFunc for each fetched URL.
		// note: this err is scoped to the if statement and does not
		// overwrite the fetch error declared above.
		if err := w.fn(req.URL.String(), res, err); err == ErrSkipURL {
			req.Finish()
			continue
		} else if err != nil {
			return err
		}
		// if the fetch itself failed there are no links to follow, so move
		// on to the next request
		if err != nil {
			req.Finish()
			continue
		}
		if req, err := nextRequest(res, res.RedirectTo); err == nil {
			q.PushBack(req)
		}
		for _, link := range res.Links {
			if req, err := nextRequest(res, link.URL); err == nil {
				q.PushBack(req)
			}
		}
		req.Finish()
	}
}
// Run starts processing requests from the queue, fanning the work out over
// the configured number of goroutines. It returns when the queue is drained,
// the context is cancelled, or the CrawlFunc returns an error.
func (w *Worker) Run(ctx context.Context, q Queue) error {
	g, ctx := errgroup.WithContext(ctx)
	for i := 0; i < w.goroutines; i++ {
		g.Go(func() error {
			return w.run(ctx, q)
		})
	}
	return g.Wait()
}

// skipRedirects stops the http.Client from following redirects automatically,
// so that each redirect can be queued and crawled as its own request.
func skipRedirects(req *http.Request, via []*http.Request) error {
	return http.ErrUseLastResponse
}
func fetch(ctx context.Context, c *http.Client, req *Request) (*Response, error) {
	uri := req.URL.String()
	httpReq, err := http.NewRequest(http.MethodGet, uri, nil)
	if err != nil {
		return nil, err
	}
	httpReq = httpReq.WithContext(ctx)
	httpRes, err := c.Do(httpReq)
	if err != nil {
		return nil, err
	}
	defer httpRes.Body.Close()
	res := Response{
		request: req,
	}
	switch httpRes.StatusCode {
	case http.StatusOK:
	case http.StatusMovedPermanently, http.StatusFound:
		// record where the redirect points, resolved against the request
		// URL, and let the caller queue it as a new request
		res.URL = httpRes.Request.URL.String()
		loc, err := url.Parse(httpRes.Header.Get("Location"))
		if err != nil {
			return nil, err
		}
		res.RedirectTo = httpRes.Request.URL.ResolveReference(loc).String()
		return &res, nil
	default:
		return nil, fmt.Errorf("%s for %s", httpRes.Status, uri)
	}
	// only HTML bodies are parsed for links
	if strings.Contains(httpRes.Header.Get("Content-Type"), "text/html") {
		err = ReadResponse(httpRes.Request.URL, httpRes.Body, &res)
	}
	return &res, err
}
func nextRequest(res *Response, href string) (*Request, error) {
	if href == "" {
		return nil, ErrSkipURL
	}
	req, err := NewRequest(href)
	if err != nil {
		return nil, err
	}
	if res.RedirectTo == "" {
		// following a normal link goes one level deeper and resets the
		// redirect count
		req.depth = res.request.depth + 1
		req.redirects = 0
	} else {
		// following a redirect keeps the depth and counts the hop
		req.redirects = res.request.redirects + 1
	}
	return req, nil
}