diff --git a/examples/raw-socket/raw-socket-client/main.go b/examples/raw-socket/raw-socket-client/main.go new file mode 100644 index 0000000..6199192 --- /dev/null +++ b/examples/raw-socket/raw-socket-client/main.go @@ -0,0 +1,45 @@ +package main + +import ( + "flag" + "log" + "time" + + "gopkg.in/jcelliott/turnpike.v2" +) + +var ( + addr string +) + +func init() { + flag.StringVar(&addr, "addr", ":9000", "address to connect to") + flag.Parse() +} + +func gotMessage(args []interface{}, kwargs map[string]interface{}) { + log.Printf("Got message: %s %s\n", args[0], args[1]) +} + +func main() { + // turnpike.Debug() + + log.Println("New client") + c, err := turnpike.NewRawSocketClient(addr) + if err != nil { + log.Fatalln(err) + } + _, err = c.JoinRealm("turnpike.example", turnpike.ALLROLES, nil) + if err != nil { + log.Fatalln("Error joining realm:", err) + } + + err = c.Subscribe("messages", gotMessage) + if err != nil { + log.Fatalln("Error subscribing to message channel:", err) + } + + for now := range time.Tick(time.Second * 5) { + c.Publish("messages", []interface{}{now.String(), flag.Args()[0]}, nil) + } +} diff --git a/examples/raw-socket/raw-socket-server/main.go b/examples/raw-socket/raw-socket-server/main.go new file mode 100644 index 0000000..df7903b --- /dev/null +++ b/examples/raw-socket/raw-socket-server/main.go @@ -0,0 +1,31 @@ +package main + +import ( + "flag" + "log" + "net" + + "gopkg.in/jcelliott/turnpike.v2" +) + +var ( + addr string +) + +func init() { + flag.StringVar(&addr, "addr", ":9000", "address to listen on for raw socket connections") + flag.Parse() +} + +func main() { + turnpike.Debug() + + l, err := net.Listen("tcp", addr) + if err != nil { + log.Fatalln(err) + } + + log.Println("Raw Socket server listening on", l.Addr()) + s := turnpike.NewBasicRawSocketServer("turnpike.example") + s.HandleListener(l) +} diff --git a/raw_socket.go b/raw_socket.go new file mode 100644 index 0000000..0576319 --- /dev/null +++ b/raw_socket.go @@ -0,0 +1,188 @@ +package turnpike + +import ( + "errors" + "fmt" + "io" + "io/ioutil" + "net" +) + +const ( + magic = 0x7f +) + +const ( + rawSocketJSON = 1 + rawSocketMsgpack = 2 +) + +type rawSocketPeer struct { + serializer Serializer + conn net.Conn + messages chan Message + maxLength int +} + +func intToBytes(i int) [3]byte { + return [3]byte{ + byte((i >> 16) & 0xff), + byte((i >> 8) & 0xff), + byte(i & 0xff), + } +} + +func bytesToInt(arr []byte) (val int) { + shift := uint(8 * (len(arr) - 1)) + for _, b := range arr { + val |= int(uint(b) << shift) + shift -= 8 + } + return +} + +func (ep *rawSocketPeer) Send(msg Message) error { + b, err := ep.serializer.Serialize(msg) + if err != nil { + return err + } + + if len(b) > ep.maxLength { + return fmt.Errorf("message too big: %d > %d", len(b), ep.maxLength) + } + + arr := intToBytes(len(b)) + header := []byte{0x0, arr[0], arr[1], arr[2]} + _, err = ep.conn.Write(header) + if err == nil { + _, err = ep.conn.Write(b) + } + return err +} +func (ep *rawSocketPeer) Receive() <-chan Message { + return ep.messages +} +func (ep *rawSocketPeer) Close() error { + return ep.conn.Close() +} + +func (ep *rawSocketPeer) handleMessages() { + for { + var header [4]byte + if _, err := ep.conn.Read(header[:]); err != nil { + ep.conn.Close() + return + } + + length := bytesToInt(header[1:]) + if length > ep.maxLength { + // TODO: handle error more nicely? + ep.conn.Close() + return + } + switch header[0] & 0x7 { + // WAMP message + case 0: + buf := make([]byte, length) + _, err := ep.conn.Read(buf) + if err != nil { + ep.conn.Close() + return + } + msg, err := ep.serializer.Deserialize(buf) + if err != nil { + // TODO: handle error + log.Println("Error deserializing message:", err) + } else { + ep.messages <- msg + } + // PING + case 1: + header[0] = 0x02 + _, err := ep.conn.Write(header[:]) + if err != nil { + ep.conn.Close() + return + } + _, err = io.CopyN(ep.conn, ep.conn, int64(length)) + if err != nil { + ep.conn.Close() + return + } + // PONG + case 2: + _, err := io.CopyN(ioutil.Discard, ep.conn, int64(length)) + if err != nil { + ep.conn.Close() + return + } + } + } +} + +// TODO: rename me +func toLength(b byte) int { + // lengths are specified as a 4-bit number + // and represent values between 2**9 and 2**24 + return (2 << 8) << b +} + +func (ep *rawSocketPeer) handshakeClient() error { + const serializer = rawSocketMsgpack + const length = 0xf + + if _, err := ep.conn.Write([]byte{magic, length<<4 | serializer, 0, 0}); err != nil { + return err + } + var buf [4]byte + if _, err := ep.conn.Read(buf[:]); err != nil { + return err + } + if buf[0] != magic { + return errors.New("unknown protocol: first byte received not the WAMP magic value") + } + if buf[1]&0xf == 0 { + errCode := buf[1] >> 4 + switch errCode { + case 0: + return errors.New("serializer unsupported") + case 1: + return errors.New("maximum message length unsupported") + case 2: + return errors.New("use of reserved bits (unsupported feature)") + case 3: + return errors.New("maximum connection count reached") + default: + return fmt.Errorf("unknown error: %d", errCode) + } + } + if buf[1]&0xf != serializer { + return errors.New("serializer mismatch: server responded with different serializer than requested") + } + // TODO: allow server to set this lower? + if buf[1]>>4 != length { + return fmt.Errorf("length mismatch: requested: %d, responded: %d", length, buf[1]>>4) + } + ep.maxLength = toLength(length) + return nil +} + +func NewRawSocketClient(url string) (*Client, error) { + conn, err := net.Dial("tcp", url) + if err != nil { + return nil, err + } + + peer := &rawSocketPeer{ + conn: conn, + serializer: new(MessagePackSerializer), + messages: make(chan Message), + } + + err = peer.handshakeClient() + if err != nil { + return nil, err + } + go peer.handleMessages() + return NewClient(peer), nil +} diff --git a/raw_socket_server.go b/raw_socket_server.go new file mode 100644 index 0000000..c326a24 --- /dev/null +++ b/raw_socket_server.go @@ -0,0 +1,64 @@ +package turnpike + +import ( + "net" +) + +type RawSocketServer struct { + Router +} + +func (s *RawSocketServer) handle(conn net.Conn) { + var header [4]byte + _, err := conn.Read(header[:]) + if err != nil { + conn.Close() + return + } + if header[0] != magic { + log.Println("unknown protocol: first byte received not the WAMP magic value") + conn.Close() + return + } + serializer := header[1] & 0x0f + peer := &rawSocketPeer{ + conn: conn, + messages: make(chan Message), + maxLength: toLength(header[1] >> 4), + } + switch serializer { + case rawSocketMsgpack: + peer.serializer = new(MessagePackSerializer) + case rawSocketJSON: + peer.serializer = new(JSONSerializer) + } + + _, err = conn.Write([]byte{magic, header[1], 0, 0}) + if err != nil { + conn.Close() + return + } + + go peer.handleMessages() + + s.Accept(peer) +} + +func (s *RawSocketServer) HandleListener(l net.Listener) { + for { + conn, err := l.Accept() + if err != nil { + l.Close() + return + } + go s.handle(conn) + } +} + +func NewBasicRawSocketServer(realms ...string) *RawSocketServer { + s := &RawSocketServer{Router: NewDefaultRouter()} + for _, realm := range realms { + s.RegisterRealm(URI(realm), Realm{}) + } + return s +} diff --git a/raw_socket_test.go b/raw_socket_test.go new file mode 100644 index 0000000..b52fac4 --- /dev/null +++ b/raw_socket_test.go @@ -0,0 +1,68 @@ +package turnpike + +import ( + "testing" + + . "github.com/smartystreets/goconvey/convey" +) + +func TestToLength(t *testing.T) { + // exhaustive list of valid lengths + exp := map[byte]int{ + 0: 2 << 8, + 1: 2 << 9, + 2: 2 << 10, + 3: 2 << 11, + 4: 2 << 12, + 5: 2 << 13, + 6: 2 << 14, + 7: 2 << 15, + 8: 2 << 16, + 9: 2 << 17, + 10: 2 << 18, + 11: 2 << 19, + 12: 2 << 20, + 13: 2 << 21, + 14: 2 << 22, + 15: 2 << 23, + } + + Convey("For every valid length value", t, func() { + for b, v := range exp { + So(toLength(b), ShouldEqual, v) + } + }) +} + +func TestIntToBytes(t *testing.T) { + Convey("When setting a number that fits in a byte", t, func() { + val := 56 + arr := intToBytes(val) + So(len(arr), ShouldEqual, 3) + So(arr[0], ShouldEqual, 0) + So(arr[1], ShouldEqual, 0) + So(arr[2], ShouldEqual, val) + }) + + Convey("When setting a number that fits in a 24-bit number", t, func() { + val := 2 << 20 + arr := intToBytes(val) + So(len(arr), ShouldEqual, 3) + So(arr[0], ShouldEqual, val>>16) + So(arr[1], ShouldEqual, 0) + So(arr[2], ShouldEqual, 0) + }) +} + +func TestBytesToInt(t *testing.T) { + Convey("When setting an array with only the low byte set", t, func() { + arr := []byte{0, 0, 56} + val := bytesToInt(arr) + So(val, ShouldEqual, arr[2]) + }) + Convey("When setting an array with only the high byte set", t, func() { + arr := []byte{56, 0, 0} + val := bytesToInt(arr) + So(val, ShouldEqual, uint(arr[0])<<16) + }) +}