From db01e47ac0a12999576fb94d58a3bfca0d3e7e21 Mon Sep 17 00:00:00 2001 From: PJ Date: Tue, 15 Feb 2022 15:22:49 +0100 Subject: [PATCH 1/6] Update the BlockPOST struct to allow for passing hashes, remove Skylink from the DB struct --- README.md | 13 ++-- api/handlers.go | 133 +++++++++++++++++++++++++------------- database/database.go | 35 ++++++++-- database/database_test.go | 55 ++++++++++------ database/skylink.go | 1 - 5 files changed, 156 insertions(+), 81 deletions(-) diff --git a/README.md b/README.md index 0222a55..469c15f 100644 --- a/README.md +++ b/README.md @@ -13,23 +13,18 @@ and/or log files. # AllowList -The blocker service can only block skylinks which are not in the allow list. -To add a skylink to the allow list, one has to manually query the database and +The blocker service can only block hashes which are not in the allow list. +To add a hash to the allow list, one has to manually query the database and perform the follow operation: ``` db.getCollection('allowlist').insertOne({ - skylink: "[INSERT V1 SKYLINK HERE]", - description: "[INSERT SKYLINK DESCRIPTION]", + hash: "[INSERT HAHS OF V1 SKYLINK HERE]", + description: "[INSERT DESCRIPTION]", timestamp_added: new Date(), }) ``` -The skylink is expected to be in the following form: `_B19BtlWtjjR7AD0DDzxYanvIhZ7cxXrva5tNNxDht1kaA`. -So that's without portal and without the `sia://` prefix. The allow list is -persisted as is, so not as a hash, for ease of use and because it is assumed the -allowlist only holds non-abusive content. 
- # Environment This service depends on the following environment variables: diff --git a/api/handlers.go b/api/handlers.go index 1fffde4..ab067c4 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -14,6 +14,7 @@ import ( "gitlab.com/NebulousLabs/errors" skyapi "gitlab.com/SkynetLabs/skyd/node/api" "gitlab.com/SkynetLabs/skyd/skymodules" + "go.sia.tech/siad/crypto" ) var ( @@ -28,6 +29,15 @@ type ( Skylink skylink `json:"skylink"` Reporter Reporter `json:"reporter"` Tags []string `json:"tags"` + + // Hash represents the hash of the Skylink's merkle root. Either 'hash' + // or 'skylink' must be set. If both are set the Skylink's hash value + // must correspond with the hash. + // + // It is encouraged to use this field when possible as it allows + // services that interact with the blocker to only deal with hashes + // instead of skylinks. + Hash crypto.Hash `json:"hash"` } // BlockWithPoWPOST describes a request to the /blockpow endpoint @@ -177,79 +187,54 @@ func (api *API) blockWithPoWGET(w http.ResponseWriter, r *http.Request, _ httpro // block handlers. It executes all code which is shared between the two // handlers. 
func (api *API) handleBlockRequest(ctx context.Context, w http.ResponseWriter, bp BlockPOST, sub string) { - // Decode the skylink, we can safely ignore the error here as LoadString - // will have been called by the JSON decoder - var skylink skymodules.Skylink - _ = skylink.LoadString(string(bp.Skylink)) - - // Resolve the skylink - resolved, err := api.staticSkydAPI.ResolveSkylink(skylink) - if err == nil { - // replace the skylink with the resolved skylink - skylink = resolved - } else { - // in case of an error we log and continue with the given skylink - api.staticLogger.Errorf("failed to resolve skylink '%v', err: %v", skylink, err) - } - - // Sanity check the skylink is a v1 skylink - if !skylink.IsSkylinkV1() { - skyapi.WriteError(w, skyapi.Error{"failed to resolve skylink"}, http.StatusInternalServerError) + // Resolve the post body into a hash + hash, err := api.resolveHash(bp) + if err != nil { + skyapi.WriteError(w, skyapi.Error{"failed to resolve skylink"}, http.StatusBadRequest) return } // Check whether the skylink is on the allow list - if api.isAllowListed(ctx, skylink) { + if api.isAllowListed(ctx, "", hash.String()) { skyapi.WriteJSON(w, statusResponse{"reported"}) return } - // Block the link. - err = api.block(ctx, bp, skylink, sub, sub == "") - if errors.Contains(err, database.ErrSkylinkExists) { - skyapi.WriteJSON(w, statusResponse{"duplicate"}) - return - } - if err != nil { - skyapi.WriteError(w, skyapi.Error{err.Error()}, http.StatusInternalServerError) - return - } - skyapi.WriteJSON(w, statusResponse{"reported"}) -} - -// block blocks a skylink -func (api *API) block(ctx context.Context, bp BlockPOST, skylink skymodules.Skylink, sub string, unauthenticated bool) error { - // TODO: currently we still set the Skylink, as soon as this module is - // converted to work fully with hashes, the Skylink field needs to be - // dropped. 
+ // Create a blocked skylink object bs := &database.BlockedSkylink{ - Skylink: skylink.String(), - Hash: database.NewHash(skylink), + Hash: database.Hash{Hash: hash}, Reporter: database.Reporter{ Name: bp.Reporter.Name, Email: bp.Reporter.Email, OtherContact: bp.Reporter.OtherContact, Sub: sub, - Unauthenticated: unauthenticated, + Unauthenticated: sub == "", }, Tags: bp.Tags, TimestampAdded: time.Now().UTC(), } + + // Block the link. api.staticLogger.Debugf("blocking hash %s", bs.Hash) - err := api.staticDB.CreateBlockedSkylink(ctx, bs) + err = api.staticDB.CreateBlockedSkylink(ctx, bs) + if errors.Contains(err, database.ErrSkylinkExists) { + skyapi.WriteJSON(w, statusResponse{"duplicate"}) + return + } if err != nil { - return err + skyapi.WriteError(w, skyapi.Error{err.Error()}, http.StatusInternalServerError) + return } api.staticLogger.Debugf("blocked hash %s", bs.Hash) - return nil + skyapi.WriteJSON(w, statusResponse{"reported"}) } // isAllowListed returns true if the given skylink is on the allow list // // NOTE: the given skylink is expected to be a v1 skylink, meaning the caller of // this function should have tried to resolve the skylink beforehand -func (api *API) isAllowListed(ctx context.Context, skylink skymodules.Skylink) bool { - allowlisted, err := api.staticDB.IsAllowListed(ctx, skylink.String()) +func (api *API) isAllowListed(ctx context.Context, skylink, hash string) bool { + allowlisted, err := api.staticDB.IsAllowListed(ctx, skylink, hash) if err != nil { api.staticLogger.Error("failed to verify skylink against the allow list", err) return false @@ -257,6 +242,64 @@ func (api *API) isAllowListed(ctx context.Context, skylink skymodules.Skylink) b return allowlisted } +// resolveHash resolves the given block post object into a hash. If a hash was +// already given, it will simply return that. If a skylink was given, it will +// try to resolve it first if necessary and return the hash of the v1 skylink. 
+func (api *API) resolveHash(bp BlockPOST) (crypto.Hash, error) { + // validate the block post + err := bp.validate() + if err != nil { + return crypto.Hash{}, err + } + + // if the hash is set, we are done + if bp.Hash.String() != "" { + return bp.Hash, nil + } + + // decode the skylink + var skylink skymodules.Skylink + err = skylink.LoadString(string(bp.Skylink)) + if err != nil { + return crypto.Hash{}, errors.AddContext(err, "failed to load skylink") + } + + // resolve the skylink + skylink, err = api.staticSkydAPI.ResolveSkylink(skylink) + if err != nil { + return crypto.Hash{}, errors.AddContext(err, "failed to resolve skylink") + } + + // sanity check the skylink is a v1 skylink + if !skylink.IsSkylinkV1() { + return crypto.Hash{}, errors.AddContext(err, "failed to resolve skylink") + } + + // return the hash + return crypto.HashObject(skylink.MerkleRoot()), nil +} + +// validate returns an error if the block post object is constructed in an +// illegal fashion, which can happen if the hash does not match the hash of the +// skylink's merkle root for instance +func (bp *BlockPOST) validate() error { + if bp.Hash.String() == "" && bp.Skylink == "" { + return errors.New("hash or skylink is required") + } + if bp.Hash.String() != "" && bp.Skylink != "" { + var sl skymodules.Skylink + err := sl.LoadString(string(bp.Skylink)) + if err != nil { + return errors.AddContext(err, "could not load skylink") + } + + if crypto.HashObject(sl.MerkleRoot()) != bp.Hash { + return errors.New("hash does not match the skylink") + } + } + return nil +} + // extractSkylinkHash extracts the skylink hash from the given skylink that // might have protocol, path, etc. within it. 
func extractSkylinkHash(skylink string) (string, error) { diff --git a/database/database.go b/database/database.go index db9907c..2173b19 100644 --- a/database/database.go +++ b/database/database.go @@ -10,6 +10,7 @@ import ( "gitlab.com/NebulousLabs/errors" "gitlab.com/SkynetLabs/skyd/skymodules" "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" "go.mongodb.org/mongo-driver/mongo" "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" @@ -199,8 +200,14 @@ func (db *DB) FindByHash(ctx context.Context, hash Hash) (*BlockedSkylink, error } // IsAllowListed returns whether the given skylink is on the allow list. -func (db *DB) IsAllowListed(ctx context.Context, skylink string) (bool, error) { - res := db.staticAllowList.FindOne(ctx, bson.M{"skylink": skylink}) +func (db *DB) IsAllowListed(ctx context.Context, skylink, hash string) (bool, error) { + res := db.staticAllowList.FindOne( + ctx, + bson.D{{"$or", []interface{}{ + bson.M{"skylink": skylink}, + bson.M{"hash": hash}, + }}}, + ) if isDocumentNotFound(res.Err()) { return false, nil } @@ -343,7 +350,13 @@ func (db *DB) SetLatestBlockTimestamp(t time.Time) error { // the database to hashes. Skylinks should not be persisted in their plain form // in the database. 
func (db *DB) compatTransformSkylinkToHash(ctx context.Context) error { - collSkylinks := db.staticDB.Collection(collSkylinks) + skylinks := db.staticDB.Collection(collSkylinks) + + // define an inline type that defines a legacy mongo document with skylink + type blockedSkylinkCompat struct { + ID primitive.ObjectID `bson:"_id,omitempty"` + Skylink string `bson:"skylink"` + } // define a filter that matches documents with skylink and no hash filter := bson.D{{"$and", []interface{}{ @@ -356,9 +369,19 @@ func (db *DB) compatTransformSkylinkToHash(ctx context.Context) error { }}} // find all documents where the skyink has to be transformed to a hash - docs, err := db.find(ctx, filter) + c, err := skylinks.Find(ctx, filter) + if isDocumentNotFound(err) { + return nil + } if err != nil { - return err + return errors.AddContext(err, "failed fetching documents that need to be transformed to having hashes instead of skylinks") + } + + // hydrate the cursor into a list of compat objects + docs := make([]blockedSkylinkCompat, 0) + err = c.All(db.ctx, &docs) + if err != nil { + return errors.AddContext(err, "failed hydrating the cursor of documents into compat objects") } // return if no docs need to be transformed @@ -396,7 +419,7 @@ func (db *DB) compatTransformSkylinkToHash(ctx context.Context) error { }}}, }}} value := bson.M{"$set": bson.M{"hash": NewHash(sl)}} - _, err = collSkylinks.UpdateOne(ctx, filter, value) + _, err = skylinks.UpdateOne(ctx, filter, value) if err != nil { db.staticLogger.Errorf("failed to update hash of document with ID '%v', err %v", doc.ID, err) continue diff --git a/database/database_test.go b/database/database_test.go index 8e4ec7b..3ef0640 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -143,7 +143,6 @@ func testCreateBlockedSkylink(t *testing.T) { }, Reverted: true, RevertedTags: []string{"A"}, - Skylink: sl.String(), Tags: []string{"B"}, TimestampAdded: now, TimestampReverted: now.AddDate(1, 1, 1), @@ -197,7 +196,7 
@@ func testIsAllowListedSkylink(t *testing.T) { } // Check the result of 'IsAllowListed' - allowListed, err := db.IsAllowListed(ctx, skylink) + allowListed, err := db.IsAllowListed(ctx, skylink, "") if err != nil { t.Fatal(err) } @@ -207,7 +206,7 @@ func testIsAllowListedSkylink(t *testing.T) { // Check against a different skylink skylink = "ABC9BtlWtjjR7AD0DDzxYanvIhZ7cxXrva5tNNxDht1ABC" - allowListed, err = db.IsAllowListed(ctx, skylink) + allowListed, err = db.IsAllowListed(ctx, skylink, "") if err != nil { t.Fatal(err) } @@ -236,14 +235,12 @@ func testMarkAsSucceeded(t *testing.T) { // insert a regular document and one that was marked as failed db.staticSkylinks.InsertOne(ctx, BlockedSkylink{ - Skylink: "skylink_1", Hash: HashBytes([]byte("skylink_1")), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), }) db.staticSkylinks.InsertOne(ctx, BlockedSkylink{ - Skylink: "skylink_2", Hash: HashBytes([]byte("skylink_2")), Reporter: Reporter{}, Tags: []string{"tag_1"}, @@ -293,14 +290,12 @@ func testMarkAsFailed(t *testing.T) { // insert two regular documents db.CreateBlockedSkylink(ctx, &BlockedSkylink{ - Skylink: "skylink_1", Hash: HashBytes([]byte("skylink_1")), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), }) db.CreateBlockedSkylink(ctx, &BlockedSkylink{ - Skylink: "skylink_2", Hash: HashBytes([]byte("skylink_2")), Reporter: Reporter{}, Tags: []string{"tag_1"}, @@ -397,14 +392,23 @@ func testCompatTransformSkylinkToHash(t *testing.T) { sl2 := skylinkFromString("IABst6HgaJ0PIBMtmQ2qgH_wQlFg4bNnwAhff7DmJP6oyg") sl3 := skylinkFromString("IABXaRBvjDTB3RizX3RfwdCoxt2Tff1buEXhlO7b9Unn8g") + // define an inline type that defines a legacy mongo document with skylink + type blockedSkylinkCompat struct { + Hash Hash `bson:"hash"` + Skylink string `bson:"skylink"` + Reporter Reporter `bson:"reporter"` + Tags []string `bson:"tags"` + TimestampAdded time.Time `bson:"timestamp_added"` + } + // insert two documents 
with a skylink but no hash (like it was originally) - db.staticSkylinks.InsertOne(ctx, BlockedSkylink{ + db.staticSkylinks.InsertOne(ctx, blockedSkylinkCompat{ Skylink: sl1.String(), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), }) - db.staticSkylinks.InsertOne(ctx, BlockedSkylink{ + db.staticSkylinks.InsertOne(ctx, blockedSkylinkCompat{ Skylink: sl2.String(), Reporter: Reporter{}, Tags: []string{"tag_1"}, @@ -412,7 +416,7 @@ func testCompatTransformSkylinkToHash(t *testing.T) { }) // insert one documents with a skylink and a hash - db.staticSkylinks.InsertOne(ctx, BlockedSkylink{ + db.staticSkylinks.InsertOne(ctx, blockedSkylinkCompat{ Skylink: sl3.String(), Hash: NewHash(sl3), Reporter: Reporter{}, @@ -435,34 +439,45 @@ func testCompatTransformSkylinkToHash(t *testing.T) { // find all skylink documents, in order opts := options.Find() opts.SetSort(bson.D{{"timestamp_added", 1}}) - skylinks, err := db.find(ctx, bson.D{}, opts) + + // find all documents where the skylink has to be transformed to a hash + skylinks := db.staticDB.Collection(collSkylinks) + c, err := skylinks.Find(ctx, bson.D{}, opts) + + if err != nil && !isDocumentNotFound(err) { + t.Fatal(err) + } + + // decode them into compat objects + docs := make([]blockedSkylinkCompat, 0) + err = c.All(db.ctx, &docs) if err != nil { t.Fatal(err) } - if len(skylinks) != 3 { - t.Fatalf("unexpected amount of docs, %v != 3", len(skylinks)) + if len(docs) != 3 { + t.Fatalf("unexpected amount of docs, %v != 3", len(docs)) } // assert the skylink and hash value for all documents - sl11 := skylinkFromString(skylinks[0].Skylink) + sl11 := skylinkFromString(docs[0].Skylink) if sl11.String() != sl1.String() { t.Fatal("unexpected skylink value") } - if NewHash(sl11) != skylinks[0].Hash { - t.Fatal("unexpected hash value", NewHash(sl11), skylinks[0].Hash) + if NewHash(sl11) != docs[0].Hash { + t.Fatal("unexpected hash value", NewHash(sl11), docs[0].Hash) } - sl22 :=
skylinkFromString(skylinks[1].Skylink) + sl22 := skylinkFromString(docs[1].Skylink) if sl22.String() != sl2.String() { t.Fatal("unexpected skylink value") } - if NewHash(sl22) != skylinks[1].Hash { + if NewHash(sl22) != docs[1].Hash { t.Fatal("unexpected hash value") } - sl33 := skylinkFromString(skylinks[2].Skylink) + sl33 := skylinkFromString(docs[2].Skylink) if sl33.String() != sl3.String() { t.Fatal("unexpected skylink value") } - if NewHash(sl33) != skylinks[2].Hash { + if NewHash(sl33) != docs[2].Hash { t.Fatal("unexpected hash value") } } diff --git a/database/skylink.go b/database/skylink.go index 864a8d2..e31ef6f 100644 --- a/database/skylink.go +++ b/database/skylink.go @@ -65,7 +65,6 @@ type BlockedSkylink struct { Reporter Reporter `bson:"reporter"` Reverted bool `bson:"reverted"` RevertedTags []string `bson:"reverted_tags"` - Skylink string `bson:"skylink"` Tags []string `bson:"tags"` TimestampAdded time.Time `bson:"timestamp_added"` TimestampReverted time.Time `bson:"timestamp_reverted"` From 490da9bfe278550b4572e3158fb4b35384f09698 Mon Sep 17 00:00:00 2001 From: PJ Date: Tue, 15 Feb 2022 15:28:28 +0100 Subject: [PATCH 2/6] Drop unique index on skylink, add unique index on hash --- database/database.go | 22 +++++++++++++--------- 1 file changed, 13 insertions(+), 9 deletions(-) diff --git a/database/database.go b/database/database.go index 2173b19..6d984c7 100644 --- a/database/database.go +++ b/database/database.go @@ -506,8 +506,8 @@ func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) schema := map[string][]mongo.IndexModel{ collAllowlist: { { - Keys: bson.D{{"skylink", 1}}, - Options: options.Index().SetName("skylink").SetUnique(true), + Keys: bson.D{{"hash", 1}}, + Options: options.Index().SetName("hash").SetUnique(true), }, { Keys: bson.D{{"timestamp_added", 1}}, @@ -515,14 +515,9 @@ func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) }, }, collSkylinks: { - // TODO: the schema should be 
extended here to have a unique index - // on the 'hash' field, this can be done safely if the compat code - // has executed and the blocker has been running on hashes for a - // while, at that time the skylink index should be dropped and - // prevented from being set in the first place { - Keys: bson.D{{"skylink", 1}}, - Options: options.Index().SetName("skylink").SetUnique(true), + Keys: bson.D{{"hash", 1}}, + Options: options.Index().SetName("hash").SetUnique(true), }, { Keys: bson.D{{"timestamp_added", 1}}, @@ -564,6 +559,15 @@ func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) if icErr != nil { return errors.Compose(icErr, ErrIndexCreateFailed) } + + // drop the old indices on 'skylink' + _, err1 := db.Collection(collAllowlist).Indexes().DropOne(ctx, "skylink") + _, err2 := db.Collection(collSkylinks).Indexes().DropOne(ctx, "skylink") + err := errors.Compose(err1, err2) + if err != nil { + return errors.AddContext(err, "failed droppping 'skylink' index") + } + return nil } From 39cf9200fcda1c39b26755c7006bd6e7c7ee72df Mon Sep 17 00:00:00 2001 From: PJ Date: Tue, 15 Feb 2022 15:32:34 +0100 Subject: [PATCH 3/6] Fix typo --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 469c15f..fef78da 100644 --- a/README.md +++ b/README.md @@ -19,7 +19,7 @@ perform the follow operation: ``` db.getCollection('allowlist').insertOne({ - hash: "[INSERT HAHS OF V1 SKYLINK HERE]", + hash: "[INSERT HASH OF V1 SKYLINK HERE]", description: "[INSERT DESCRIPTION]", timestamp_added: new Date(), }) From 218d156ba73d015945e75bd7d6f2e5d3938b7d17 Mon Sep 17 00:00:00 2001 From: PJ Date: Fri, 4 Mar 2022 13:12:47 +0100 Subject: [PATCH 4/6] Move allowlist to hashes --- api/handlers.go | 26 ++------ api/handlers_test.go | 30 +++++---- database/database.go | 95 ++++++++++++++++++++------ database/database_test.go | 137 ++++++++++++++++++++++++++++++-------- database/skylink.go | 2 +- 5 files changed, 208 
insertions(+), 82 deletions(-) diff --git a/api/handlers.go b/api/handlers.go index a44694f..f6517cb 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -259,7 +259,7 @@ func (api *API) handleBlockRequest(ctx context.Context, w http.ResponseWriter, b } // Check whether the skylink is on the allow list - if api.isAllowListed(ctx, "", hash.String()) { + if api.isAllowListed(ctx, hash) { skyapi.WriteJSON(w, statusResponse{"reported"}) return } @@ -297,8 +297,8 @@ func (api *API) handleBlockRequest(ctx context.Context, w http.ResponseWriter, b // // NOTE: the given skylink is expected to be a v1 skylink, meaning the caller of // this function should have tried to resolve the skylink beforehand -func (api *API) isAllowListed(ctx context.Context, skylink, hash string) bool { - allowlisted, err := api.staticDB.IsAllowListed(ctx, skylink, hash) +func (api *API) isAllowListed(ctx context.Context, hash crypto.Hash) bool { + allowlisted, err := api.staticDB.IsAllowListed(ctx, hash) if err != nil { api.staticLogger.Error("failed to verify skylink against the allow list", err) return false @@ -317,7 +317,7 @@ func (api *API) resolveHash(bp BlockPOST) (crypto.Hash, error) { } // if the hash is set, we are done - if bp.Hash.String() != "" { + if bp.Hash != (crypto.Hash{}) { return bp.Hash, nil } @@ -343,24 +343,12 @@ func (api *API) resolveHash(bp BlockPOST) (crypto.Hash, error) { return crypto.HashObject(skylink.MerkleRoot()), nil } -// validate returns an error if the block post object is constructed in an -// illegal fashion, which can happen if the hash does not match the hash of the -// skylink's merkle root for instance +// validate returns an error if the block post object does not contain a hash or +// skylink func (bp *BlockPOST) validate() error { - if bp.Hash.String() == "" && bp.Skylink == "" { + if bp.Hash == (crypto.Hash{}) && bp.Skylink == "" { return errors.New("hash or skylink is required") } - if bp.Hash.String() != "" && bp.Skylink != "" { - var sl 
skymodules.Skylink - err := sl.LoadString(string(bp.Skylink)) - if err != nil { - return errors.AddContext(err, "could not load skylink") - } - - if crypto.HashObject(sl.MerkleRoot()) != bp.Hash { - return errors.New("hash does not match the skylink") - } - } return nil } diff --git a/api/handlers_test.go b/api/handlers_test.go index 1ffab5c..450b27d 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -115,10 +115,18 @@ func testHandleBlockRequest(t *testing.T) { // create a response writer w := newMockResponseWriter() - // allow list a skylink + // create skylink + var sl skymodules.Skylink + err = sl.LoadString(v1SkylinkStr) + if err != nil { + t.Fatal(err) + } + + // allowlist a skylink + hash := database.NewHash(sl) err = api.staticDB.CreateAllowListedSkylink(ctx, &database.AllowListedSkylink{ - Skylink: v1SkylinkStr, - Description: "test skylink", + Hash: hash, + Description: "test hash", TimestampAdded: time.Now().UTC(), }) if err != nil { @@ -146,15 +154,10 @@ func testHandleBlockRequest(t *testing.T) { t.Fatal("unexpected error", err) } if resp.Status != "reported" { - t.Fatal("unexpected response status", resp.Status) + t.Fatal("unexpected response status", resp.Status, resp) } // assert the blocked skylink did not make it into the database - var sl skymodules.Skylink - err = sl.LoadString(v1SkylinkStr) - if err != nil { - t.Fatal(err) - } doc, err := api.staticDB.FindByHash(ctx, database.NewHash(sl)) if err != nil { t.Fatal("unexpected error", err) @@ -164,8 +167,8 @@ func testHandleBlockRequest(t *testing.T) { } // up until now we have asserted that the skylink gets resolved and the - // allow list gets checked, note that this is only meaningful if the below - // assertions pass also (happy path) + // allowlist gets checked, note that this is only meaningful if the below + // assertions also pass (happy path) // load a random skylink err = sl.LoadString("_B19BtlWtjjR7AD0DDzxYanvIhZ7cxXrva5tNNxDht1kaA") @@ -187,7 +190,7 @@ func 
testHandleBlockRequest(t *testing.T) { // call the request handler w.Reset() api.handleBlockRequest(context.Background(), w, bp, "") - + return // assert the handler writes a 'reported' status response err = json.Unmarshal(w.staticBuffer.Bytes(), &resp) if err != nil { @@ -252,8 +255,7 @@ func testHandleBlocklistGET(t *testing.T) { skylink := fmt.Sprintf("skylink_%d", i) offset := time.Duration(i) * time.Second err = api.staticDB.CreateBlockedSkylink(ctx, &database.BlockedSkylink{ - Skylink: skylink, - Hash: database.HashBytes([]byte(skylink)), + Hash: database.HashBytes([]byte(skylink)), Reporter: database.Reporter{ Name: "John Doe", }, diff --git a/database/database.go b/database/database.go index 1c44ac6..2c4b823 100644 --- a/database/database.go +++ b/database/database.go @@ -13,15 +13,16 @@ import ( "go.mongodb.org/mongo-driver/mongo/options" "go.mongodb.org/mongo-driver/mongo/readpref" "go.mongodb.org/mongo-driver/mongo/writeconcern" + "go.sia.tech/siad/crypto" ) const ( // MongoDefaultTimeout is the timeout for the context used in testing // whenever a context is sent to mongo - MongoDefaultTimeout = time.Minute + MongoDefaultTimeout = 10 * time.Minute // mongoIndexCreateTimeout is the timeout used when creating indices - mongoIndexCreateTimeout = 10 * time.Second + mongoIndexCreateTimeout = 10 * time.Minute ) var ( @@ -33,6 +34,10 @@ var ( // ensure an index ErrIndexCreateFailed = errors.New("failed to create index") + // ErrIndexDropFailed is returned when an error occurred when trying to + // drop an index + ErrIndexDropFailed = errors.New("failed to drop an index") + // ErrNoDocumentsFound is returned when a database operation completes // successfully but it doesn't find or affect any documents. 
ErrNoDocumentsFound = errors.New("no documents") @@ -56,8 +61,10 @@ var ( // collSkylinks defines the name of the skylinks collection collSkylinks = "skylinks" + // collAllowlist defines the name of the allowlist collection collAllowlist = "allowlist" + // collLatestBlockTimestamps defines the name of the latest block timestamps collection collLatestBlockTimestamps = "latest_block_timestamps" ) @@ -210,14 +217,8 @@ func (db *DB) FindByHash(ctx context.Context, hash Hash) (*BlockedSkylink, error } // IsAllowListed returns whether the given skylink is on the allow list. -func (db *DB) IsAllowListed(ctx context.Context, skylink, hash string) (bool, error) { - res := db.staticAllowList.FindOne( - ctx, - bson.D{{"$or", []interface{}{ - bson.M{"skylink": skylink}, - bson.M{"hash": hash}, - }}}, - ) +func (db *DB) IsAllowListed(ctx context.Context, hash crypto.Hash) (bool, error) { + res := db.staticAllowList.FindOne(ctx, bson.M{"hash": hash.String()}) if isDocumentNotFound(res.Err()) { return false, nil } @@ -463,9 +464,13 @@ func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) }, } - icOpts := options.CreateIndexes().SetMaxTime(mongoIndexCreateTimeout) + // build the options + opts := options.CreateIndexes() + opts.SetMaxTime(mongoIndexCreateTimeout) + opts.SetCommitQuorumString("majority") // defaults to all - var icErr error + // ensure all collections and indices exist + var createErr error for collName, models := range schema { coll, err := ensureCollection(ctx, db, collName) if err != nil { @@ -474,28 +479,76 @@ func ensureDBSchema(ctx context.Context, db *mongo.Database, log *logrus.Logger) } iv := coll.Indexes() - names, err := iv.CreateMany(ctx, models, icOpts) + names, err := iv.CreateMany(ctx, models, opts) if err != nil { // if the index creation fails, compose the error but continue to // try and ensure the rest of the database schema - icErr = errors.Compose(icErr, errors.AddContext(err, fmt.Sprintf("collection '%v'", collName))) + createErr =
errors.Compose(createErr, errors.AddContext(err, fmt.Sprintf("collection '%v'", collName))) continue } + log.Debugf("Ensured index exists: %v | %v", collName, names) } - if icErr != nil { - return errors.Compose(icErr, ErrIndexCreateFailed) + if createErr != nil { + createErr = errors.Compose(createErr, ErrIndexCreateFailed) } // drop the old indices on 'skylink' - _, err1 := db.Collection(collAllowlist).Indexes().DropOne(ctx, "skylink") - _, err2 := db.Collection(collSkylinks).Indexes().DropOne(ctx, "skylink") - err := errors.Compose(err1, err2) + _, err1 := dropIndex(ctx, db.Collection(collAllowlist), "skylink") + _, err2 := dropIndex(ctx, db.Collection(collSkylinks), "skylink") + dropErr := errors.Compose(err1, err2) + if dropErr != nil { + dropErr = errors.Compose(dropErr, ErrIndexDropFailed) + } + + return errors.Compose(createErr, dropErr) +} + +// dropIndex is a helper function that drops the index with given name on the +// given collection +func dropIndex(ctx context.Context, coll *mongo.Collection, indexName string) (bool, error) { + hasIndex, err := hasIndex(ctx, coll, indexName) if err != nil { - return errors.AddContext(err, "failed droppping 'skylink' index") + return false, err } - return nil + if !hasIndex { + return false, nil + } + + _, err = coll.Indexes().DropOne(ctx, indexName) + if err != nil { + return false, err + } + + return true, nil +} + +// hasIndex is a helper function that returns true if the given collection has +// an index with given name +func hasIndex(ctx context.Context, coll *mongo.Collection, indexName string) (bool, error) { + idxs := coll.Indexes() + + cur, err := idxs.List(ctx) + if err != nil { + return false, err + } + + var result []bson.M + err = cur.All(ctx, &result) + if err != nil { + return false, err + } + + found := false + for _, v := range result { + for k, v1 := range v { + if k == "name" && v1 == indexName { + found = true + } + } + } + return found, nil } // ensureCollection gets the given collection from the 
diff --git a/database/database_test.go b/database/database_test.go index 46e396e..a6cfaea 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -2,6 +2,7 @@ package database import ( "context" + "crypto/rand" "encoding/json" "fmt" "io/ioutil" @@ -14,27 +15,9 @@ import ( "gitlab.com/SkynetLabs/skyd/skymodules" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo/options" + "go.sia.tech/siad/crypto" ) -// newTestDB creates a new database for a given test's name. -func newTestDB(ctx context.Context, dbName string) *DB { - dbName = strings.ReplaceAll(dbName, "/", "-") - logger := logrus.New() - logger.Out = ioutil.Discard - db, err := NewCustomDB(ctx, "mongodb://localhost:37017", dbName, options.Credential{ - Username: "admin", - Password: "aO4tV5tC1oU3oQ7u", - }, logger) - if err != nil { - panic(err) - } - err = db.Purge(ctx) - if err != nil { - panic(err) - } - return db -} - // TestDatabase runs the database unit tests. func TestDatabase(t *testing.T) { if testing.Short() { @@ -66,6 +49,14 @@ func TestDatabase(t *testing.T) { name: "MarkAsFailed", test: testMarkAsFailed, }, + { + name: "HasIndex", + test: testHasIndex, + }, + { + name: "DropIndex", + test: testDropIndex, + }, } for _, test := range tests { t.Run(test.name, test.test) @@ -181,9 +172,9 @@ func testIsAllowListedSkylink(t *testing.T) { defer db.Close() // Add a skylink in the allow list - skylink := "_B19BtlWtjjR7AD0DDzxYanvIhZ7cxXrva5tNNxDht1kaA" - _, err := db.staticAllowList.InsertOne(ctx, &AllowListedSkylink{ - Skylink: skylink, + hash := randomHash() + err := db.CreateAllowListedSkylink(ctx, &AllowListedSkylink{ + Hash: Hash{hash}, Description: "test skylink", TimestampAdded: time.Now().UTC(), }) @@ -192,7 +183,7 @@ func testIsAllowListedSkylink(t *testing.T) { } // Check the result of 'IsAllowListed' - allowListed, err := db.IsAllowListed(ctx, skylink, "") + allowListed, err := db.IsAllowListed(ctx, hash) if err != nil { t.Fatal(err) } @@ -201,8 +192,8 @@ func 
testIsAllowListedSkylink(t *testing.T) { } // Check against a different skylink - skylink = "ABC9BtlWtjjR7AD0DDzxYanvIhZ7cxXrva5tNNxDht1ABC" - allowListed, err = db.IsAllowListed(ctx, skylink, "") + hash2 := randomHash() + allowListed, err = db.IsAllowListed(ctx, hash2) if err != nil { t.Fatal(err) } @@ -230,19 +221,25 @@ func testMarkAsSucceeded(t *testing.T) { } // insert a regular document and one that was marked as failed - db.staticSkylinks.InsertOne(ctx, BlockedSkylink{ + err = db.CreateBlockedSkylink(ctx, &BlockedSkylink{ Hash: HashBytes([]byte("skylink_1")), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), }) - db.staticSkylinks.InsertOne(ctx, BlockedSkylink{ + if err != nil { + t.Fatal(err) + } + err = db.CreateBlockedSkylink(ctx, &BlockedSkylink{ Hash: HashBytes([]byte("skylink_2")), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), Failed: true, }) + if err != nil { + t.Fatal(err) + } toRetry, err := db.HashesToRetry() if err != nil { @@ -342,6 +339,66 @@ func testMarkAsFailed(t *testing.T) { // no need to mark them as succeeded, the other unit test covers that } +// testHasIndex is a unit test that verifies the functionality of the hasIndex +// helper function +func testHasIndex(t *testing.T) { + // create context + ctx, cancel := context.WithTimeout(context.Background(), MongoDefaultTimeout) + defer cancel() + + // create test database + db := newTestDB(ctx, t.Name()) + defer db.Close() + + // check whether we can find an index we expect to be there + found, err := hasIndex(ctx, db.staticSkylinks, "hash") + if err != nil { + t.Fatal(err) + } + if !found { + t.Fatal("unexpected") + } + + // check whether the output is correct for a made up index name + found, err = hasIndex(ctx, db.staticSkylinks, "nonexistingindexname") + if err != nil { + t.Fatal(err) + } + if found { + t.Fatal("unexpected") + } +} + +// testDropIndex is a unit test that verifies the functionality of the dropIndex +// 
helper function +func testDropIndex(t *testing.T) { + // create context + ctx, cancel := context.WithTimeout(context.Background(), MongoDefaultTimeout) + defer cancel() + + // create test database + db := newTestDB(ctx, t.Name()) + defer db.Close() + + // check whether dropIndex errors out on an unknown index + dropped, err := dropIndex(ctx, db.staticSkylinks, "nonexistingindexname") + if err != nil { + t.Fatal(err) + } + if dropped { + t.Fatal("unexpected") + } + + // check the output for an existing index + dropped, err = dropIndex(ctx, db.staticSkylinks, "hash") + if err != nil { + t.Fatal(err) + } + if !dropped { + t.Fatal("unexpected") + } +} + // define a helper function to decode a skylink as string into a skylink obj func skylinkFromString(skylink string) (sl skymodules.Skylink) { err := sl.LoadString(skylink) @@ -350,3 +407,29 @@ func skylinkFromString(skylink string) (sl skymodules.Skylink) { } return } + +// newTestDB creates a new database for a given test's name. +func newTestDB(ctx context.Context, dbName string) *DB { + dbName = strings.ReplaceAll(dbName, "/", "-") + logger := logrus.New() + logger.Out = ioutil.Discard + db, err := NewCustomDB(ctx, "mongodb://localhost:37017", dbName, options.Credential{ + Username: "admin", + Password: "aO4tV5tC1oU3oQ7u", + }, logger) + if err != nil { + panic(err) + } + err = db.Purge(ctx) + if err != nil { + panic(err) + } + return db +} + +// randomHash returns a random hash +func randomHash() crypto.Hash { + var h crypto.Hash + rand.Read(h[:]) + return h +} diff --git a/database/skylink.go b/database/skylink.go index e31ef6f..57f3f7d 100644 --- a/database/skylink.go +++ b/database/skylink.go @@ -52,7 +52,7 @@ func (h *Hash) UnmarshalBSONValue(t bsontype.Type, b []byte) error { // ever being blocked. 
type AllowListedSkylink struct { ID primitive.ObjectID `bson:"_id,omitempty"` - Skylink string `bson:"skylink"` + Hash Hash `bson:"hash"` Description string `bson:"description"` TimestampAdded time.Time `bson:"timestamp_added"` } From dd62850ee81958cf27c5fe254adf682150f20c81 Mon Sep 17 00:00:00 2001 From: PJ Date: Thu, 7 Apr 2022 13:34:03 +0200 Subject: [PATCH 5/6] Remove debug return --- api/handlers_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/api/handlers_test.go b/api/handlers_test.go index 2273708..36c37e4 100644 --- a/api/handlers_test.go +++ b/api/handlers_test.go @@ -197,7 +197,7 @@ func testHandleBlockRequest(t *testing.T, server *httptest.Server) { // call the request handler w.Reset() api.handleBlockRequest(context.Background(), w, bp, "") - return + // assert the handler writes a 'reported' status response err = json.Unmarshal(w.staticBuffer.Bytes(), &resp) if err != nil { From e5eca9897a0be9086ebb5cd0fc2e6782c0fcecc0 Mon Sep 17 00:00:00 2001 From: PJ Date: Thu, 21 Apr 2022 10:20:08 +0200 Subject: [PATCH 6/6] Implement MR remarks --- api/handlers.go | 20 ++++++++++++++++---- database/database.go | 4 ++-- database/database_test.go | 10 +++++++--- 3 files changed, 25 insertions(+), 9 deletions(-) diff --git a/api/handlers.go b/api/handlers.go index ac7e309..fafe440 100644 --- a/api/handlers.go +++ b/api/handlers.go @@ -41,6 +41,12 @@ const ( sortDescending = "desc" ) +var ( + // errResolve is the error returned when we failed to resolve a skylink, + // indicating skyd failure + errResolve = errors.New("failed to resolve skylink") +) + type ( // BlockPOST describes a request to the /block endpoint. 
BlockPOST struct { @@ -254,7 +260,13 @@ func (api *API) handleBlockRequest(ctx context.Context, w http.ResponseWriter, b // Resolve the post body into a hash hash, err := api.resolveHash(bp) if err != nil { - WriteError(w, errors.AddContext(err, "failed to resolve hash"), http.StatusBadRequest) + // return an internal server error if the resolve failed due to skyd + // either being down or behaving unexpectedly + code := http.StatusBadRequest + if errors.Contains(err, errResolve) { + code = http.StatusInternalServerError + } + WriteError(w, errors.AddContext(err, "failed to resolve hash"), code) return } @@ -331,12 +343,12 @@ func (api *API) resolveHash(bp BlockPOST) (crypto.Hash, error) { // resolve the skylink skylink, err = api.staticSkydClient.ResolveSkylink(skylink) if err != nil { - return crypto.Hash{}, errors.AddContext(err, "failed to resolve skylink") + return crypto.Hash{}, errors.Compose(err, errResolve) } // sanity check the skylink is a v1 skylink if !skylink.IsSkylinkV1() { - return crypto.Hash{}, errors.AddContext(err, "failed to resolve skylink") + return crypto.Hash{}, errors.Compose(err, errResolve) } // return the hash @@ -414,5 +426,5 @@ func parseListParameters(query url.Values) (int, int, int, error) { // WriteError wraps WriteError from the skyd node api func WriteError(w http.ResponseWriter, err error, code int) { - skyapi.WriteError(w, skyapi.Error{Message: err.Error()}, http.StatusBadRequest) + skyapi.WriteError(w, skyapi.Error{Message: err.Error()}, code) } diff --git a/database/database.go b/database/database.go index b3ac897..be2fc24 100644 --- a/database/database.go +++ b/database/database.go @@ -20,10 +20,10 @@ import ( const ( // MongoDefaultTimeout is the timeout for the context used in testing // whenever a context is sent to mongo - MongoDefaultTimeout = 10 * time.Minute + MongoDefaultTimeout = time.Minute // mongoIndexCreateTimeout is the timeout used when creating indices - mongoIndexCreateTimeout = 10 * time.Minute + 
mongoIndexCreateTimeout = time.Minute // mongoTestUsername is the username used for the test database. mongoTestUsername = "admin" diff --git a/database/database_test.go b/database/database_test.go index 68aca20..657b5da 100644 --- a/database/database_test.go +++ b/database/database_test.go @@ -10,6 +10,7 @@ import ( "testing" "time" + "gitlab.com/NebulousLabs/errors" "gitlab.com/SkynetLabs/skyd/skymodules" "go.mongodb.org/mongo-driver/bson" "go.mongodb.org/mongo-driver/mongo" @@ -416,25 +417,28 @@ func testMarkFailed(t *testing.T) { } // insert two regular documents and one invalid one - db.CreateBlockedSkylink(ctx, &BlockedSkylink{ + err1 := db.CreateBlockedSkylink(ctx, &BlockedSkylink{ Hash: HashBytes([]byte("skylink_1")), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), }) - db.CreateBlockedSkylink(ctx, &BlockedSkylink{ + err2 := db.CreateBlockedSkylink(ctx, &BlockedSkylink{ Hash: HashBytes([]byte("skylink_2")), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), }) - db.CreateBlockedSkylink(ctx, &BlockedSkylink{ + err3 := db.CreateBlockedSkylink(ctx, &BlockedSkylink{ Hash: HashBytes([]byte("skylink_3")), Reporter: Reporter{}, Tags: []string{"tag_1"}, TimestampAdded: time.Now().UTC(), Invalid: true, }) + if err := errors.Compose(err1, err2, err3); err != nil { + t.Fatal(err) + } // fetch a cursor that holds all docs c, err := db.staticDB.Collection(collSkylinks).Find(ctx, bson.M{})