From e4ff14c25eb8487922a48da85a163d6fc46146d4 Mon Sep 17 00:00:00 2001 From: Steven Allen Date: Thu, 14 Dec 2017 16:06:52 -0800 Subject: [PATCH] refactor secio handshake 1. Respect context. Before, we didn't respect the context when writing/reading the final nonce. 2. Don't make assumptions about connection buffering. Read/write handshake messages in parallel. 3. Return messages from ReadMsg to the buffer pool. 4. Close the connection on timeout. We can't reuse it at this point as there may be an outstanding writer/reader (and we don't allow parallel reads/writes). 5. Don't assume that Write always writes the full buffer. --- protocol.go | 65 +++++++++++++++++++++++++++++++++------------- rw.go | 75 +++++++++++++++++++++++++---------------------------- 2 files changed, 82 insertions(+), 58 deletions(-) diff --git a/protocol.go b/protocol.go index d2f3056..29f20b4 100644 --- a/protocol.go +++ b/protocol.go @@ -9,6 +9,7 @@ import ( "io" "time" + proto "github.com/gogo/protobuf/proto" logging "github.com/ipfs/go-log" ci "github.com/libp2p/go-libp2p-crypto" peer "github.com/libp2p/go-libp2p-peer" @@ -108,6 +109,26 @@ func hashSha256(data []byte) mh.Multihash { // keys, IDs, and initiate communication, assigning all necessary params. // requires the duplex channel to be a msgio.ReadWriter (for framed messaging) func (s *secureSession) runHandshake(ctx context.Context) error { + defer log.EventBegin(ctx, "secureHandshake", s).Done() + + result := make(chan error, 1) + go func() { + // do *not* close the channel (will look like a success). + result <- s.runHandshakeSync() + }() + + var err error + select { + case <-ctx.Done(): + // State unknown. We *have* to close this. + s.insecure.Close() + err = ctx.Err() + case err = <-result: + } + return err +} + +func (s *secureSession) runHandshakeSync() error { // ============================================================================= // step 1. Propose -- propose cipher suite + send pubkeys + nonce @@ -119,8 +140,6 @@ func (s *secureSession) runHandshake(ctx context.Context) error { return err } - defer log.EventBegin(ctx, "secureHandshake", s).Done() - s.local.permanentPubKey = s.localKey.GetPublic() myPubKeyBytes, err := s.local.permanentPubKey.Bytes() if err != nil { @@ -137,18 +156,24 @@ func (s *secureSession) runHandshake(ctx context.Context) error { // log.Debugf("1.0 Propose: nonce:%s exchanges:%s ciphers:%s hashes:%s", // nonceOut, SupportedExchanges, SupportedCiphers, SupportedHashes) - // Send Propose packet (respects ctx) - proposeOutBytes, err := writeMsgCtx(ctx, s.insecureM, proposeOut) + // Marshal our propose packet + proposeOutBytes, err := proto.Marshal(proposeOut) if err != nil { return err } - // Receive + Parse their Propose packet and generate an Exchange packet. - proposeIn := new(pb.Propose) - proposeInBytes, err := readMsgCtx(ctx, s.insecureM, proposeIn) + // Send Propose packet and Receive their Propose packet + proposeInBytes, err := readWriteMsg(s.insecureM, proposeOutBytes) if err != nil { return err } + defer s.insecureM.ReleaseMsg(proposeInBytes) + + // Parse their propose packet + proposeIn := new(pb.Propose) + if err = proto.Unmarshal(proposeInBytes, proposeIn); err != nil { + return err + } // log.Debugf("1.0.1 Propose recv: nonce:%s exchanges:%s ciphers:%s hashes:%s", // proposeIn.GetRand(), proposeIn.GetExchanges(), proposeIn.GetCiphers(), proposeIn.GetHashes()) @@ -230,14 +255,22 @@ func (s *secureSession) runHandshake(ctx context.Context) error { return err } - // Send Propose packet (respects ctx) - if _, err := writeMsgCtx(ctx, s.insecureM, exchangeOut); err != nil { + // Marshal our exchange packet + exchangeOutBytes, err := proto.Marshal(exchangeOut) + if err != nil { return err } - // Receive + Parse their Exchange packet. + // Send Exchange packet and receive their Exchange packet + exchangeInBytes, err := readWriteMsg(s.insecureM, exchangeOutBytes) + if err != nil { + return err + } + defer s.insecureM.ReleaseMsg(exchangeInBytes) + + // Parse their Exchange packet. exchangeIn := new(pb.Exchange) - if _, err := readMsgCtx(ctx, s.insecureM, exchangeIn); err != nil { + if err = proto.Unmarshal(exchangeInBytes, exchangeIn); err != nil { return err } @@ -317,15 +350,11 @@ func (s *secureSession) runHandshake(ctx context.Context) error { s.secure = msgio.Combine(w, r).(msgio.ReadWriteCloser) // log.Debug("3.0 finish. sending: %v", proposeIn.GetRand()) - // send their Nonce. - if _, err := s.secure.Write(proposeIn.GetRand()); err != nil { - return fmt.Errorf("Failed to write Finish nonce: %s", err) - } - // read our Nonce nonceOut2 := make([]byte, len(nonceOut)) - if _, err := io.ReadFull(s.secure, nonceOut2); err != nil { - return fmt.Errorf("Failed to read Finish nonce: %s", err) + // send their Nonce and receive ours + if err := readWriteFull(s.secure, proposeIn.GetRand(), nonceOut2); err != nil { + return err } // log.Debug("3.0 finish.\n\texpect: %v\n\tactual: %v", nonceOut, nonceOut2) diff --git a/rw.go b/rw.go index 150b55f..1390d1b 100644 --- a/rw.go +++ b/rw.go @@ -1,7 +1,6 @@ package secio import ( - "context" "crypto/cipher" "crypto/hmac" "encoding/binary" @@ -10,7 +9,6 @@ import ( "io" "sync" - proto "github.com/gogo/protobuf/proto" msgio "github.com/libp2p/go-msgio" mpool "github.com/libp2p/go-msgio/mpool" ) @@ -243,53 +241,50 @@ func (r *etmReader) ReleaseMsg(b []byte) { r.msg.ReleaseMsg(b) } -// writeMsgCtx is used by the -func writeMsgCtx(ctx context.Context, w msgio.Writer, msg proto.Message) ([]byte, error) { - enc, err := proto.Marshal(msg) - if err != nil { - return nil, err - } +// read and write a message at the same time. +func readWriteMsg(c msgio.ReadWriter, out []byte) ([]byte, error) { + wresult := make(chan error) + go func() { + wresult <- c.WriteMsg(out) + }() - // write in a goroutine so we can exit when our context is cancelled. - done := make(chan error) - go func(m []byte) { - err := w.WriteMsg(m) - select { - case done <- err: - case <-ctx.Done(): - } - }(enc) + msg, err1 := c.ReadMsg() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case e := <-done: - return enc, e + // Always wait for the read to finish. + err2 := <-wresult + + if err1 != nil { + return nil, err1 + } + if err2 != nil { + c.ReleaseMsg(msg) + return nil, err2 } + return msg, nil } -func readMsgCtx(ctx context.Context, r msgio.Reader, p proto.Message) ([]byte, error) { - var msg []byte - - // read in a goroutine so we can exit when our context is cancelled. - done := make(chan error) +func readWriteFull(c msgio.ReadWriter, out []byte, in []byte) error { + wresult := make(chan error) go func() { - var err error - msg, err = r.ReadMsg() - select { - case done <- err: - case <-ctx.Done(): + for len(out) > 0 { + n, err := c.Write(out) + if err != nil { + wresult <- err + return + } + out = out[n:] } + wresult <- nil }() + _, err1 := io.ReadFull(c, in) + err2 := <-wresult - select { - case <-ctx.Done(): - return nil, ctx.Err() - case e := <-done: - if e != nil { - return nil, e - } + if err1 != nil { + return err1 } - return msg, proto.Unmarshal(msg, p) + if err2 != nil { + return err2 + } + return nil }