Skip to content
This repository has been archived by the owner on Feb 24, 2021. It is now read-only.

Commit

Permalink
refactor secio handshake
Browse files Browse the repository at this point in the history
1. Respect context. Before, we didn't respect the context when writing/reading
the final nonce.
2. Don't make assumptions about connection buffering. Read/write handshake
messages in parallel.
3. Return messages from ReadMsg to the buffer pool.
4. Close the connection on timeout. We can't reuse it at this point as there may
be an outstanding writer/reader (and we don't allow parallel reads/writes).
5. Don't assume that Write always writes the full buffer.
  • Loading branch information
Stebalien committed Dec 15, 2017
1 parent fffa3e7 commit e4ff14c
Show file tree
Hide file tree
Showing 2 changed files with 82 additions and 58 deletions.
65 changes: 47 additions & 18 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"time"

proto "github.com/gogo/protobuf/proto"
logging "github.com/ipfs/go-log"
ci "github.com/libp2p/go-libp2p-crypto"
peer "github.com/libp2p/go-libp2p-peer"
Expand Down Expand Up @@ -108,6 +109,26 @@ func hashSha256(data []byte) mh.Multihash {
// keys, IDs, and initiate communication, assigning all necessary params.
// requires the duplex channel to be a msgio.ReadWriter (for framed messaging)
func (s *secureSession) runHandshake(ctx context.Context) error {
defer log.EventBegin(ctx, "secureHandshake", s).Done()

result := make(chan error, 1)
go func() {
// do *not* close the channel (will look like a success).
result <- s.runHandshakeSync()
}()

var err error
select {
case <-ctx.Done():
// State unknown. We *have* to close this.
s.insecure.Close()
err = ctx.Err()
case err = <-result:
}
return err
}

func (s *secureSession) runHandshakeSync() error {
// =============================================================================
// step 1. Propose -- propose cipher suite + send pubkeys + nonce

Expand All @@ -119,8 +140,6 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return err
}

defer log.EventBegin(ctx, "secureHandshake", s).Done()

s.local.permanentPubKey = s.localKey.GetPublic()
myPubKeyBytes, err := s.local.permanentPubKey.Bytes()
if err != nil {
Expand All @@ -137,18 +156,24 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// log.Debugf("1.0 Propose: nonce:%s exchanges:%s ciphers:%s hashes:%s",
// nonceOut, SupportedExchanges, SupportedCiphers, SupportedHashes)

// Send Propose packet (respects ctx)
proposeOutBytes, err := writeMsgCtx(ctx, s.insecureM, proposeOut)
// Marshal our propose packet
proposeOutBytes, err := proto.Marshal(proposeOut)
if err != nil {
return err
}

// Receive + Parse their Propose packet and generate an Exchange packet.
proposeIn := new(pb.Propose)
proposeInBytes, err := readMsgCtx(ctx, s.insecureM, proposeIn)
// Send Propose packet and Receive their Propose packet
proposeInBytes, err := readWriteMsg(s.insecureM, proposeOutBytes)
if err != nil {
return err
}
defer s.insecureM.ReleaseMsg(proposeInBytes)

// Parse their propose packet
proposeIn := new(pb.Propose)
if err = proto.Unmarshal(proposeInBytes, proposeIn); err != nil {
return err
}

// log.Debugf("1.0.1 Propose recv: nonce:%s exchanges:%s ciphers:%s hashes:%s",
// proposeIn.GetRand(), proposeIn.GetExchanges(), proposeIn.GetCiphers(), proposeIn.GetHashes())
Expand Down Expand Up @@ -230,14 +255,22 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return err
}

// Send Propose packet (respects ctx)
if _, err := writeMsgCtx(ctx, s.insecureM, exchangeOut); err != nil {
// Marshal our exchange packet
exchangeOutBytes, err := proto.Marshal(exchangeOut)
if err != nil {
return err
}

// Receive + Parse their Exchange packet.
// Send Exchange packet and receive their Exchange packet
exchangeInBytes, err := readWriteMsg(s.insecureM, exchangeOutBytes)
if err != nil {
return err
}
defer s.insecureM.ReleaseMsg(exchangeInBytes)

// Parse their Exchange packet.
exchangeIn := new(pb.Exchange)
if _, err := readMsgCtx(ctx, s.insecureM, exchangeIn); err != nil {
if err = proto.Unmarshal(exchangeInBytes, exchangeIn); err != nil {
return err
}

Expand Down Expand Up @@ -317,15 +350,11 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
s.secure = msgio.Combine(w, r).(msgio.ReadWriteCloser)

// log.Debug("3.0 finish. sending: %v", proposeIn.GetRand())
// send their Nonce.
if _, err := s.secure.Write(proposeIn.GetRand()); err != nil {
return fmt.Errorf("Failed to write Finish nonce: %s", err)
}

// read our Nonce
nonceOut2 := make([]byte, len(nonceOut))
if _, err := io.ReadFull(s.secure, nonceOut2); err != nil {
return fmt.Errorf("Failed to read Finish nonce: %s", err)
// send their Nonce and receive ours
if err := readWriteFull(s.secure, proposeIn.GetRand(), nonceOut2); err != nil {
return err
}

// log.Debug("3.0 finish.\n\texpect: %v\n\tactual: %v", nonceOut, nonceOut2)
Expand Down
75 changes: 35 additions & 40 deletions rw.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package secio

import (
"context"
"crypto/cipher"
"crypto/hmac"
"encoding/binary"
Expand All @@ -10,7 +9,6 @@ import (
"io"
"sync"

proto "github.com/gogo/protobuf/proto"
msgio "github.com/libp2p/go-msgio"
mpool "github.com/libp2p/go-msgio/mpool"
)
Expand Down Expand Up @@ -243,53 +241,50 @@ func (r *etmReader) ReleaseMsg(b []byte) {
r.msg.ReleaseMsg(b)
}

// writeMsgCtx is used by the
func writeMsgCtx(ctx context.Context, w msgio.Writer, msg proto.Message) ([]byte, error) {
enc, err := proto.Marshal(msg)
if err != nil {
return nil, err
}
// read and write a message at the same time.
func readWriteMsg(c msgio.ReadWriter, out []byte) ([]byte, error) {
wresult := make(chan error)
go func() {
wresult <- c.WriteMsg(out)
}()

// write in a goroutine so we can exit when our context is cancelled.
done := make(chan error)
go func(m []byte) {
err := w.WriteMsg(m)
select {
case done <- err:
case <-ctx.Done():
}
}(enc)
msg, err1 := c.ReadMsg()

select {
case <-ctx.Done():
return nil, ctx.Err()
case e := <-done:
return enc, e
// Always wait for the read to finish.
err2 := <-wresult

if err1 != nil {
return nil, err1
}
if err2 != nil {
c.ReleaseMsg(msg)
return nil, err2
}
return msg, nil
}

func readMsgCtx(ctx context.Context, r msgio.Reader, p proto.Message) ([]byte, error) {
var msg []byte

// read in a goroutine so we can exit when our context is cancelled.
done := make(chan error)
func readWriteFull(c msgio.ReadWriter, out []byte, in []byte) error {
wresult := make(chan error)
go func() {
var err error
msg, err = r.ReadMsg()
select {
case done <- err:
case <-ctx.Done():
for len(out) > 0 {
n, err := c.Write(out)
if err != nil {
wresult <- err
return
}
out = out[n:]
}
wresult <- nil
}()
_, err1 := io.ReadFull(c, in)
err2 := <-wresult

select {
case <-ctx.Done():
return nil, ctx.Err()
case e := <-done:
if e != nil {
return nil, e
}
if err1 != nil {
return err1
}

return msg, proto.Unmarshal(msg, p)
if err2 != nil {
return err2
}
return nil
}

0 comments on commit e4ff14c

Please sign in to comment.