Skip to content
This repository has been archived by the owner on Feb 24, 2021. It is now read-only.

refactor secio handshake #25

Merged
merged 4 commits into from
Dec 15, 2017
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
79 changes: 57 additions & 22 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"io"
"time"

proto "github.com/gogo/protobuf/proto"
logging "github.com/ipfs/go-log"
ci "github.com/libp2p/go-libp2p-crypto"
peer "github.com/libp2p/go-libp2p-peer"
Expand All @@ -25,6 +26,9 @@ var ErrUnsupportedKeyType = errors.New("unsupported key type")
// ErrClosed signals the closing of a connection.
var ErrClosed = errors.New("connection closed")

// ErrBadSig signals that the peer sent us a handshake packet with a bad signature.
var ErrBadSig = errors.New("bad signature")

// ErrEcho is returned when we're attempting to handshake with the same keys and nonces.
var ErrEcho = errors.New("same keys and nonces. one side talking to self")

Expand Down Expand Up @@ -105,6 +109,26 @@ func hashSha256(data []byte) mh.Multihash {
// keys, IDs, and initiate communication, assigning all necessary params.
// requires the duplex channel to be a msgio.ReadWriter (for framed messaging)
func (s *secureSession) runHandshake(ctx context.Context) error {
defer log.EventBegin(ctx, "secureHandshake", s).Done()

result := make(chan error, 1)
go func() {
// do *not* close the channel (will look like a success).
result <- s.runHandshakeSync()
}()

var err error
select {
case <-ctx.Done():
// State unknown. We *have* to close this.
s.insecure.Close()
err = ctx.Err()
case err = <-result:
}
return err
}

func (s *secureSession) runHandshakeSync() error {
// =============================================================================
// step 1. Propose -- propose cipher suite + send pubkeys + nonce

Expand All @@ -116,8 +140,6 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return err
}

defer log.EventBegin(ctx, "secureHandshake", s).Done()

s.local.permanentPubKey = s.localKey.GetPublic()
myPubKeyBytes, err := s.local.permanentPubKey.Bytes()
if err != nil {
Expand All @@ -134,18 +156,24 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// log.Debugf("1.0 Propose: nonce:%s exchanges:%s ciphers:%s hashes:%s",
// nonceOut, SupportedExchanges, SupportedCiphers, SupportedHashes)

// Send Propose packet (respects ctx)
proposeOutBytes, err := writeMsgCtx(ctx, s.insecureM, proposeOut)
// Marshal our propose packet
proposeOutBytes, err := proto.Marshal(proposeOut)
if err != nil {
return err
}

// Receive + Parse their Propose packet and generate an Exchange packet.
proposeIn := new(pb.Propose)
proposeInBytes, err := readMsgCtx(ctx, s.insecureM, proposeIn)
// Send Propose packet and Receive their Propose packet
proposeInBytes, err := readWriteMsg(s.insecureM, proposeOutBytes)
if err != nil {
return err
}
defer s.insecureM.ReleaseMsg(proposeInBytes)

// Parse their propose packet
proposeIn := new(pb.Propose)
if err = proto.Unmarshal(proposeInBytes, proposeIn); err != nil {
return err
}

// log.Debugf("1.0.1 Propose recv: nonce:%s exchanges:%s ciphers:%s hashes:%s",
// proposeIn.GetRand(), proposeIn.GetExchanges(), proposeIn.GetCiphers(), proposeIn.GetHashes())
Expand Down Expand Up @@ -208,6 +236,9 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
// Generate EphemeralPubKey
var genSharedKey ci.GenSharedKey
s.local.ephemeralPubKey, genSharedKey, err = ci.GenerateEKeyPair(s.local.curveT)
if err != nil {
return err
}

// Gather corpus to sign.
selectionOut := new(bytes.Buffer)
Expand All @@ -224,14 +255,22 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
return err
}

// Send Propose packet (respects ctx)
if _, err := writeMsgCtx(ctx, s.insecureM, exchangeOut); err != nil {
// Marshal our exchange packet
exchangeOutBytes, err := proto.Marshal(exchangeOut)
if err != nil {
return err
}

// Receive + Parse their Exchange packet.
// Send Exchange packet and receive their Exchange packet
exchangeInBytes, err := readWriteMsg(s.insecureM, exchangeOutBytes)
if err != nil {
return err
}
defer s.insecureM.ReleaseMsg(exchangeInBytes)

// Parse their Exchange packet.
exchangeIn := new(pb.Exchange)
if _, err := readMsgCtx(ctx, s.insecureM, exchangeIn); err != nil {
if err = proto.Unmarshal(exchangeInBytes, exchangeIn); err != nil {
return err
}

Expand All @@ -256,9 +295,8 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
}

if !sigOK {
err := errors.New("Bad signature!")
// log.Error("2.1 Verify: failed: %s", err)
return err
// log.Error("2.1 Verify: failed: %s", ErrBadSig)
return ErrBadSig
}
// log.Debugf("2.1 Verify: signature verified.")

Expand Down Expand Up @@ -312,16 +350,13 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
s.secure = msgio.Combine(w, r).(msgio.ReadWriteCloser)

// log.Debug("3.0 finish. sending: %v", proposeIn.GetRand())
// send their Nonce.
if _, err := s.secure.Write(proposeIn.GetRand()); err != nil {
return fmt.Errorf("Failed to write Finish nonce: %s", err)
}

// read our Nonce
nonceOut2 := make([]byte, len(nonceOut))
if _, err := io.ReadFull(s.secure, nonceOut2); err != nil {
return fmt.Errorf("Failed to read Finish nonce: %s", err)
// send their Nonce and receive ours
nonceOut2, err := readWriteMsg(s.secure, proposeIn.GetRand())
if err != nil {
return err
}
defer s.secure.ReleaseMsg(nonceOut2)

// log.Debug("3.0 finish.\n\texpect: %v\n\tactual: %v", nonceOut, nonceOut2)
if !bytes.Equal(nonceOut, nonceOut2) {
Expand Down
64 changes: 17 additions & 47 deletions rw.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package secio

import (
"context"
"crypto/cipher"
"crypto/hmac"
"encoding/binary"
Expand All @@ -10,7 +9,6 @@ import (
"io"
"sync"

proto "github.com/gogo/protobuf/proto"
msgio "github.com/libp2p/go-msgio"
mpool "github.com/libp2p/go-msgio/mpool"
)
Expand Down Expand Up @@ -201,6 +199,7 @@ func (r *etmReader) ReadMsg() ([]byte, error) {

n, err := r.macCheckThenDecrypt(msg)
if err != nil {
r.msg.ReleaseMsg(msg)
return nil, err
}
return msg[:n], nil
Expand Down Expand Up @@ -243,53 +242,24 @@ func (r *etmReader) ReleaseMsg(b []byte) {
r.msg.ReleaseMsg(b)
}

// writeMsgCtx is used by the
func writeMsgCtx(ctx context.Context, w msgio.Writer, msg proto.Message) ([]byte, error) {
enc, err := proto.Marshal(msg)
if err != nil {
return nil, err
}

// write in a goroutine so we can exit when our context is cancelled.
done := make(chan error)
go func(m []byte) {
err := w.WriteMsg(m)
select {
case done <- err:
case <-ctx.Done():
}
}(enc)

select {
case <-ctx.Done():
return nil, ctx.Err()
case e := <-done:
return enc, e
}
}

func readMsgCtx(ctx context.Context, r msgio.Reader, p proto.Message) ([]byte, error) {
var msg []byte

// read in a goroutine so we can exit when our context is cancelled.
done := make(chan error)
// read and write a message at the same time.
func readWriteMsg(c msgio.ReadWriter, out []byte) ([]byte, error) {
wresult := make(chan error)
go func() {
var err error
msg, err = r.ReadMsg()
select {
case done <- err:
case <-ctx.Done():
}
wresult <- c.WriteMsg(out)
}()

select {
case <-ctx.Done():
return nil, ctx.Err()
case e := <-done:
if e != nil {
return nil, e
}
}
msg, err1 := c.ReadMsg()

// Always wait for the read to finish.
err2 := <-wresult

return msg, proto.Unmarshal(msg, p)
if err1 != nil {
return nil, err1
}
if err2 != nil {
c.ReleaseMsg(msg)
return nil, err2
}
return msg, nil
}