diff --git a/protocol.go b/protocol.go index d2f3056..29f20b4 100644 --- a/protocol.go +++ b/protocol.go @@ -9,6 +9,7 @@ import ( "io" "time" + proto "github.com/gogo/protobuf/proto" logging "github.com/ipfs/go-log" ci "github.com/libp2p/go-libp2p-crypto" peer "github.com/libp2p/go-libp2p-peer" @@ -108,6 +109,26 @@ func hashSha256(data []byte) mh.Multihash { // keys, IDs, and initiate communication, assigning all necessary params. // requires the duplex channel to be a msgio.ReadWriter (for framed messaging) func (s *secureSession) runHandshake(ctx context.Context) error { + defer log.EventBegin(ctx, "secureHandshake", s).Done() + + result := make(chan error, 1) + go func() { + // do *not* close the channel (will look like a success). + result <- s.runHandshakeSync() + }() + + var err error + select { + case <-ctx.Done(): + // State unknown. We *have* to close this. + s.insecure.Close() + err = ctx.Err() + case err = <-result: + } + return err +} + +func (s *secureSession) runHandshakeSync() error { // ============================================================================= // step 1. Propose -- propose cipher suite + send pubkeys + nonce @@ -119,8 +140,6 @@ func (s *secureSession) runHandshake(ctx context.Context) error { return err } - defer log.EventBegin(ctx, "secureHandshake", s).Done() - s.local.permanentPubKey = s.localKey.GetPublic() myPubKeyBytes, err := s.local.permanentPubKey.Bytes() if err != nil { @@ -137,18 +156,24 @@ func (s *secureSession) runHandshake(ctx context.Context) error { // log.Debugf("1.0 Propose: nonce:%s exchanges:%s ciphers:%s hashes:%s", // nonceOut, SupportedExchanges, SupportedCiphers, SupportedHashes) - // Send Propose packet (respects ctx) - proposeOutBytes, err := writeMsgCtx(ctx, s.insecureM, proposeOut) + // Marshal our propose packet + proposeOutBytes, err := proto.Marshal(proposeOut) if err != nil { return err } - // Receive + Parse their Propose packet and generate an Exchange packet. - proposeIn := new(pb.Propose) - proposeInBytes, err := readMsgCtx(ctx, s.insecureM, proposeIn) + // Send Propose packet and Receive their Propose packet + proposeInBytes, err := readWriteMsg(s.insecureM, proposeOutBytes) if err != nil { return err } + defer s.insecureM.ReleaseMsg(proposeInBytes) + + // Parse their propose packet + proposeIn := new(pb.Propose) + if err = proto.Unmarshal(proposeInBytes, proposeIn); err != nil { + return err + } // log.Debugf("1.0.1 Propose recv: nonce:%s exchanges:%s ciphers:%s hashes:%s", // proposeIn.GetRand(), proposeIn.GetExchanges(), proposeIn.GetCiphers(), proposeIn.GetHashes()) @@ -230,14 +255,22 @@ func (s *secureSession) runHandshake(ctx context.Context) error { return err } - // Send Propose packet (respects ctx) - if _, err := writeMsgCtx(ctx, s.insecureM, exchangeOut); err != nil { + // Marshal our exchange packet + exchangeOutBytes, err := proto.Marshal(exchangeOut) + if err != nil { return err } - // Receive + Parse their Exchange packet. + // Send Exchange packet and receive their Exchange packet + exchangeInBytes, err := readWriteMsg(s.insecureM, exchangeOutBytes) + if err != nil { + return err + } + defer s.insecureM.ReleaseMsg(exchangeInBytes) + + // Parse their Exchange packet. exchangeIn := new(pb.Exchange) - if _, err := readMsgCtx(ctx, s.insecureM, exchangeIn); err != nil { + if err = proto.Unmarshal(exchangeInBytes, exchangeIn); err != nil { return err } @@ -317,15 +350,11 @@ func (s *secureSession) runHandshake(ctx context.Context) error { s.secure = msgio.Combine(w, r).(msgio.ReadWriteCloser) // log.Debug("3.0 finish. sending: %v", proposeIn.GetRand()) - // send their Nonce. - if _, err := s.secure.Write(proposeIn.GetRand()); err != nil { - return fmt.Errorf("Failed to write Finish nonce: %s", err) - } - // read our Nonce nonceOut2 := make([]byte, len(nonceOut)) - if _, err := io.ReadFull(s.secure, nonceOut2); err != nil { - return fmt.Errorf("Failed to read Finish nonce: %s", err) + // send their Nonce and receive ours + if err := readWriteFull(s.secure, proposeIn.GetRand(), nonceOut2); err != nil { + return err } // log.Debug("3.0 finish.\n\texpect: %v\n\tactual: %v", nonceOut, nonceOut2) diff --git a/rw.go b/rw.go index 150b55f..1390d1b 100644 --- a/rw.go +++ b/rw.go @@ -1,7 +1,6 @@ package secio import ( - "context" "crypto/cipher" "crypto/hmac" "encoding/binary" @@ -10,7 +9,6 @@ import ( "io" "sync" - proto "github.com/gogo/protobuf/proto" msgio "github.com/libp2p/go-msgio" mpool "github.com/libp2p/go-msgio/mpool" ) @@ -243,53 +241,50 @@ func (r *etmReader) ReleaseMsg(b []byte) { r.msg.ReleaseMsg(b) } -// writeMsgCtx is used by the -func writeMsgCtx(ctx context.Context, w msgio.Writer, msg proto.Message) ([]byte, error) { - enc, err := proto.Marshal(msg) - if err != nil { - return nil, err - } +// read and write a message at the same time. +func readWriteMsg(c msgio.ReadWriter, out []byte) ([]byte, error) { + wresult := make(chan error) + go func() { + wresult <- c.WriteMsg(out) + }() - // write in a goroutine so we can exit when our context is cancelled. - done := make(chan error) - go func(m []byte) { - err := w.WriteMsg(m) - select { - case done <- err: - case <-ctx.Done(): - } - }(enc) + msg, err1 := c.ReadMsg() - select { - case <-ctx.Done(): - return nil, ctx.Err() - case e := <-done: - return enc, e + // Always wait for the read to finish. + err2 := <-wresult + + if err1 != nil { + return nil, err1 + } + if err2 != nil { + c.ReleaseMsg(msg) + return nil, err2 } + return msg, nil } -func readMsgCtx(ctx context.Context, r msgio.Reader, p proto.Message) ([]byte, error) { - var msg []byte - - // read in a goroutine so we can exit when our context is cancelled. - done := make(chan error) +func readWriteFull(c msgio.ReadWriter, out []byte, in []byte) error { + wresult := make(chan error) go func() { - var err error - msg, err = r.ReadMsg() - select { - case done <- err: - case <-ctx.Done(): + for len(out) > 0 { + n, err := c.Write(out) + if err != nil { + wresult <- err + return + } + out = out[n:] } + wresult <- nil }() + _, err1 := io.ReadFull(c, in) + err2 := <-wresult - select { - case <-ctx.Done(): - return nil, ctx.Err() - case e := <-done: - if e != nil { - return nil, e - } + if err1 != nil { + return err1 } - return msg, proto.Unmarshal(msg, p) + if err2 != nil { + return err2 + } + return nil }