Skip to content

Commit

Permalink
Merge branch 'main' of https://github.com/SkynetLabs/blocker into pj/…
Browse files Browse the repository at this point in the history
…remove-skyd-api
  • Loading branch information
peterjan committed Mar 18, 2022
2 parents 3532489 + 011bdb4 commit b403311
Show file tree
Hide file tree
Showing 9 changed files with 318 additions and 111 deletions.
2 changes: 1 addition & 1 deletion api/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ func (api *API) blocklistGET(w http.ResponseWriter, r *http.Request, _ httproute
return
}

blocked, more, err := api.staticDB.BlockedHashes(sort, offset, limit)
blocked, more, err := api.staticDB.BlockedHashes(r.Context(), sort, offset, limit)
if err != nil {
skyapi.WriteError(w, skyapi.Error{err.Error()}, http.StatusInternalServerError)
return
Expand Down
88 changes: 71 additions & 17 deletions blocker/blocker.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,11 @@ const (
// blockBatchSize is the max number of (skylink) hashes to be sent for
// blocking simultaneously.
blockBatchSize = 100

// stopTimeoutDuration is the amount of time we wait when stop is called
// before cancelling out and returning with an error indicating an unclean
// shutdown.
stopTimeoutDuration = time.Minute
)

var (
Expand Down Expand Up @@ -52,19 +57,17 @@ type (
// to block.
latestBlockTime time.Time

staticCtx context.Context
staticDB *database.DB
staticLogger *logrus.Logger
staticMu sync.Mutex
staticSkydClient *api.Client
staticStopChan chan struct{}
staticWaitGroup sync.WaitGroup
}
)

// New returns a new Blocker with the given parameters.
func New(ctx context.Context, skydClient *api.Client, db *database.DB, logger *logrus.Logger) (*Blocker, error) {
if ctx == nil {
return nil, errors.New("no context provided")
}
func New(skydClient *api.Client, db *database.DB, logger *logrus.Logger) (*Blocker, error) {
if db == nil {
return nil, errors.New("no DB provided")
}
Expand All @@ -74,12 +77,11 @@ func New(ctx context.Context, skydClient *api.Client, db *database.DB, logger *l
if skydClient == nil {
return nil, errors.New("no Skyd client provided")
}
// TODO: check skyd and nginx client
bl := &Blocker{
staticCtx: ctx,
staticDB: db,
staticLogger: logger,
staticSkydClient: skydClient,
staticStopChan: make(chan struct{}),
}
return bl, nil
}
Expand All @@ -97,7 +99,7 @@ func (bl *Blocker) BlockHashes(hashes []database.Hash) (int, int, error) {
for start < len(hashes) {
// check whether we need to escape
select {
case <-bl.staticCtx.Done():
case <-bl.staticStopChan:
return numBlocked, numInvalid, nil
default:
}
Expand All @@ -115,20 +117,27 @@ func (bl *Blocker) BlockHashes(hashes []database.Hash) (int, int, error) {
// escape early because something is probably wrong
blocked, invalid, err := bl.staticSkydClient.BlockHashes(batch)
if err != nil {
err = errors.Compose(err, bl.staticDB.MarkFailed(batch))
ctx, cancel := context.WithTimeout(context.Background(), database.MongoDefaultTimeout)
defer cancel()
err = errors.Compose(err, bl.staticDB.MarkFailed(ctx, batch))
return numBlocked, numInvalid, err
}

// update the counts
numBlocked += len(blocked)
numInvalid += len(invalid)

// create a context
ctx, cancel := context.WithTimeout(context.Background(), database.MongoDefaultTimeout)

// update the documents
err1 := bl.staticDB.MarkSucceeded(blocked)
err2 := bl.staticDB.MarkInvalid(invalid)
err1 := bl.staticDB.MarkSucceeded(ctx, blocked)
err2 := bl.staticDB.MarkInvalid(ctx, invalid)
if err := errors.Compose(err1, err2); err != nil {
cancel()
return numBlocked, numInvalid, err
}
cancel()

// update start
start = end
Expand All @@ -150,12 +159,49 @@ func (bl *Blocker) Start() error {
bl.started = true

// start the loops
go bl.threadedBlockLoop()
go bl.threadedRetryLoop()
bl.staticWaitGroup.Add(1)
go func() {
bl.threadedBlockLoop()
bl.staticWaitGroup.Done()
}()

bl.staticWaitGroup.Add(1)
go func() {
bl.threadedRetryLoop()
bl.staticWaitGroup.Done()
}()

return nil
}

// Stop waits for the blocker's waitgroup and times out after one minute.
func (bl *Blocker) Stop() error {
// check whether the blocker was started
bl.staticMu.Lock()
if !bl.started {
bl.staticMu.Unlock()
return errors.New("blocker not started")
}
bl.started = false
bl.staticMu.Unlock()

// stop the blocker by closing the stop channel
close(bl.staticStopChan)

// wait for the waitgroup, timeout and signal unclean shutdown after 1m
c := make(chan struct{})
go func() {
defer close(c)
bl.staticWaitGroup.Wait()
}()
select {
case <-c:
return nil
case <-time.After(stopTimeoutDuration):
return errors.New("unclean blocker shutdown")
}
}

// threadedBlockLoop holds the main block loop
func (bl *Blocker) threadedBlockLoop() {
// convenience variables
Expand All @@ -170,7 +216,7 @@ func (bl *Blocker) threadedBlockLoop() {
}

select {
case <-bl.staticCtx.Done():
case <-bl.staticStopChan:
return
case <-time.After(blockInterval):
}
Expand All @@ -191,7 +237,7 @@ func (bl *Blocker) threadedRetryLoop() {
}

select {
case <-bl.staticCtx.Done():
case <-bl.staticStopChan:
return
case <-time.After(retryInterval):
}
Expand All @@ -203,8 +249,12 @@ func (bl *Blocker) managedBlock() error {
now := time.Now().UTC()
from := bl.managedLatestBlockTime()

// Create a context
ctx, cancel := context.WithTimeout(context.Background(), database.MongoDefaultTimeout)
defer cancel()

// Fetch hashes to block
hashes, err := bl.staticDB.HashesToBlock(from)
hashes, err := bl.staticDB.HashesToBlock(ctx, from)
if err != nil {
return err
}
Expand Down Expand Up @@ -239,8 +289,12 @@ func (bl *Blocker) managedLatestBlockTime() time.Time {
// managedRetryHashes fetches all blocked skylinks that failed to get blocked
// the first time and retries them.
func (bl *Blocker) managedRetryHashes() error {
// Create a context
ctx, cancel := context.WithTimeout(context.Background(), database.MongoDefaultTimeout)
defer cancel()

// Fetch hashes to retry
hashes, err := bl.staticDB.HashesToRetry()
hashes, err := bl.staticDB.HashesToRetry(ctx)
if err != nil {
return err
}
Expand Down
21 changes: 15 additions & 6 deletions blocker/blocker_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,24 @@ func testBlockHashes(t *testing.T, server *httptest.Server) {
client := api.NewClient(server.URL)

// create the blocker
blocker, err := newTestBlocker("BlockHashes", client)
ctx, cancel := context.WithCancel(context.Background())
blocker, err := newTestBlocker(ctx, "BlockHashes", client)
if err != nil {
t.Fatal(err)
}

// defer a db close
// start the syncer
err = blocker.Start()
if err != nil {
t.Fatal(err)
}

// defer a call to stops
defer func() {
if err := blocker.staticDB.Close(); err != nil {
t.Error(err)
cancel()
err := blocker.Stop()
if err != nil {
t.Fatal(err)
}
}()

Expand Down Expand Up @@ -117,7 +126,7 @@ func testBlockHashes(t *testing.T, server *httptest.Server) {
}

// newTestBlocker returns a new blocker instance
func newTestBlocker(dbName string, skydClient *api.Client) (*Blocker, error) {
func newTestBlocker(ctx context.Context, dbName string, skydClient *api.Client) (*Blocker, error) {
// create a nil logger
logger := logrus.New()
logger.Out = ioutil.Discard
Expand All @@ -126,7 +135,7 @@ func newTestBlocker(dbName string, skydClient *api.Client) (*Blocker, error) {
db := database.NewTestDB(context.Background(), dbName, logger)

// create the blocker
blocker, err := New(context.Background(), skydClient, db, logger)
blocker, err := New(skydClient, db, logger)
if err != nil {
return nil, err
}
Expand Down
Loading

0 comments on commit b403311

Please sign in to comment.