Skip to content

Commit

Permalink
DTM sends data over graphsync for validated push requests (#665)
Browse files Browse the repository at this point in the history
* create channels when a request is received. register push request hook with graphsync. fix tests.
* better NewReaders
* use mutex lock around impl.channels access
* fix(datatransfer): fix test uncertainty
* fix a data race and also don't use random bytes in basic block which can fail
* privatize 3 funcs

with @hannahhoward
  • Loading branch information
shannonwells authored Nov 25, 2019
1 parent f3a6719 commit 9196db4
Show file tree
Hide file tree
Showing 6 changed files with 255 additions and 146 deletions.
95 changes: 94 additions & 1 deletion datatransfer/impl/graphsync/graphsync_impl.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package graphsyncimpl

import (
"bytes"
"context"
"errors"
"fmt"
"reflect"
"sync"

"github.com/ipfs/go-cid"
"github.com/ipfs/go-graphsync"
Expand Down Expand Up @@ -43,6 +46,7 @@ type graphsyncImpl struct {
dataTransferNetwork network.DataTransferNetwork
subscribers []datatransfer.Subscriber
validatedTypes map[string]validateType
channelsLk sync.RWMutex
channels map[datatransfer.ChannelID]datatransfer.ChannelState
gs graphsync.GraphExchange
peerID peer.ID
Expand All @@ -56,16 +60,75 @@ func NewGraphSyncDataTransfer(parent context.Context, host host.Host, gs graphsy
dataTransferNetwork,
nil,
make(map[string]validateType),
sync.RWMutex{},
make(map[datatransfer.ChannelID]datatransfer.ChannelState),
gs,
host.ID(),
0,
}
if err := gs.RegisterRequestReceivedHook(true, impl.gsReqRecdHook); err != nil {
log.Error(err)
return nil
}
receiver := &graphsyncReceiver{parent, impl}
dataTransferNetwork.SetDelegate(receiver)
return impl
}

// gsReqRecdHook is a graphsync.OnRequestReceivedHook hook
// if an incoming request does not match a previous push request, it returns an error.
func (impl *graphsyncImpl) gsReqRecdHook(p peer.ID, request graphsync.RequestData) ([]graphsync.ExtensionData, error) {
var resp []graphsync.ExtensionData
chid, _, err := impl.getChannelIDAndData(request)

extData := graphsync.ExtensionData{
Name: ExtensionDataTransfer,
Data: nil,
}
if err != nil {
return resp, err
}
if !impl.hasPushChannel(chid) {
return resp, errors.New("could not find push channel")
}
resp = append(resp, extData)
return resp, nil
}

// gsExtended is a small interface used by getChannelIDAndData
type gsExtended interface {
Extension(name graphsync.ExtensionName) ([]byte, bool)
}

// getChannelIDAndData extracts extension data and creates a channel id then returns
// both. Returns any errors.
func (impl *graphsyncImpl) getChannelIDAndData(extendedData gsExtended) (datatransfer.ChannelID, *ExtensionDataTransferData, error) {
data, ok := extendedData.Extension(ExtensionDataTransfer)
if !ok {
return datatransfer.ChannelID{}, nil, errors.New("extension not present")
}
unm, err := impl.unmarshalExtensionData(data)
if err != nil {
return datatransfer.ChannelID{}, nil, err
}
chid := datatransfer.ChannelID{
Initiator: impl.peerID,
ID: datatransfer.TransferID(unm.TransferID),
}
return chid, unm, nil
}

// unmarshalExtensionData instatiates an extension data struct & unmarshals data into i
func (impl *graphsyncImpl) unmarshalExtensionData(data []byte) (*ExtensionDataTransferData, error) {
var extStruct ExtensionDataTransferData

reader := bytes.NewReader(data)
if err := extStruct.UnmarshalCBOR(reader); err != nil {
return nil, err
}
return &extStruct, nil
}

// RegisterVoucherType registers a validator for the given voucher type
// returns error if:
// * voucher type does not implement voucher
Expand Down Expand Up @@ -124,7 +187,9 @@ func (impl *graphsyncImpl) OpenPullDataChannel(ctx context.Context, requestTo pe
func (impl *graphsyncImpl) createNewChannel(tid datatransfer.TransferID, baseCid cid.Cid, selector ipld.Node, voucher datatransfer.Voucher, initiator, dataSender, dataReceiver peer.ID) datatransfer.ChannelID {
chid := datatransfer.ChannelID{Initiator: initiator, ID: tid}
chst := datatransfer.ChannelState{Channel: datatransfer.NewChannel(0, baseCid, selector, voucher, dataSender, dataReceiver, 0)}
impl.channelsLk.Lock()
impl.channels[chid] = chst
impl.channelsLk.Unlock()
return chid
}

Expand Down Expand Up @@ -192,21 +257,49 @@ func (impl *graphsyncImpl) notifySubscribers(evt datatransfer.Event, cs datatran

// get all in progress transfers
func (impl *graphsyncImpl) InProgressChannels() map[datatransfer.ChannelID]datatransfer.ChannelState {
return impl.channels
impl.channelsLk.RLock()
defer impl.channelsLk.RUnlock()
channelsCopy := make(map[datatransfer.ChannelID]datatransfer.ChannelState, len(impl.channels))
for channelID, channelState := range impl.channels {
channelsCopy[channelID] = channelState
}
return channelsCopy
}

// hasPushChannel returns true if a channel with ID chid exists and is for a Push request.
func (impl *graphsyncImpl) hasPushChannel(chid datatransfer.ChannelID) bool {
return impl.getPushChannel(chid) != datatransfer.EmptyChannelState
}

// hasPullChannel returns true if a channel with ID chid exists and is for a Pull request.
func (impl *graphsyncImpl) hasPullChannel(chid datatransfer.ChannelID) bool {
return impl.getPullChannel(chid) != datatransfer.EmptyChannelState
}

// getPullChannel searches for a pull-type channel in the slice of channels with id `chid`.
// Returns datatransfer.EmptyChannelState if:
// * there is no channel with that id
// * it is not related to a pull request
func (impl *graphsyncImpl) getPullChannel(chid datatransfer.ChannelID) datatransfer.ChannelState {
impl.channelsLk.RLock()
defer impl.channelsLk.RUnlock()
channelState, ok := impl.channels[chid]
if !ok || channelState.Sender() == impl.peerID {
return datatransfer.EmptyChannelState
}
return channelState
}

func (impl *graphsyncImpl) getPushChannel(chid datatransfer.ChannelID) datatransfer.ChannelState {
impl.channelsLk.RLock()
defer impl.channelsLk.RUnlock()
channelState, ok := impl.channels[chid]
if !ok || channelState.Recipient() == impl.peerID {
return datatransfer.EmptyChannelState
}
return channelState
}

// generateTransferID() generates a unique-to-runtime TransferID for use in creating
// ChannelIDs
func (impl *graphsyncImpl) generateTransferID() datatransfer.TransferID {
Expand Down
Loading

0 comments on commit 9196db4

Please sign in to comment.