From 24e03831ebdb8ce0f21637f80fa29cbc7a4780e0 Mon Sep 17 00:00:00 2001 From: deneonet Date: Sun, 26 Jan 2025 12:11:42 +0100 Subject: [PATCH] unfinished clean up, server/client improvements + examples --- .gitignore | 1 - common.go | 1 + examples/chat/client/client.kr | Bin 0 -> 87 bytes examples/chat/client/main.go | 79 ++++++ examples/chat/packets/common.go | 16 ++ examples/chat/packets/gen.bat | 2 + .../chat/packets/initialization/common.go | 7 + .../chat/packets/initialization/packet.benc | 12 + .../packets/initialization/packet.benc.go | 198 ++++++++++++++ examples/chat/packets/message/packet.benc | 8 + examples/chat/packets/message/packet.benc.go | 104 ++++++++ examples/chat/server/main.go | 167 ++++++++++++ examples/chat/server/server.kc | Bin 0 -> 365 bytes gen/main.go | 2 +- go.mod | 4 +- go.sum | 2 + server.go | 241 ++++++++++-------- 17 files changed, 732 insertions(+), 112 deletions(-) create mode 100644 examples/chat/client/client.kr create mode 100644 examples/chat/client/main.go create mode 100644 examples/chat/packets/common.go create mode 100644 examples/chat/packets/gen.bat create mode 100644 examples/chat/packets/initialization/common.go create mode 100644 examples/chat/packets/initialization/packet.benc create mode 100644 examples/chat/packets/initialization/packet.benc.go create mode 100644 examples/chat/packets/message/packet.benc create mode 100644 examples/chat/packets/message/packet.benc.go create mode 100644 examples/chat/server/main.go create mode 100644 examples/chat/server/server.kc diff --git a/.gitignore b/.gitignore index 8a71b76..e69de29 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +0,0 @@ -_examples \ No newline at end of file diff --git a/common.go b/common.go index 1a4a140..1161400 100644 --- a/common.go +++ b/common.go @@ -28,4 +28,5 @@ var ( ErrInvalidHandshakePacket = errors.New("invalid handshake packet received") ErrBufTooSmall = errors.New("buffer is too small for the requested size") ErrInvalidRootKey = errors.New("invalid root key in client struct") + ErrDataExceededBufferSize = errors.New("received data size exceeded buffer size") ) diff --git a/examples/chat/client/client.kr b/examples/chat/client/client.kr new file mode 100644 index 0000000000000000000000000000000000000000..1a730e2dde50369d168307de275009dcf9c8ec51 GIT binary patch literal 87 zcmV-d0I2^000RL-0sv@U9NQD@+jM#G!Ld|A?domk(`h-QUH~M}DQTRxidONr_o9~@ ta|)n}bQir85k4^F^N#uLWew`K%_rEn>A0u_0s;sFblED;000000Rgv-CK~_% literal 0 HcmV?d00001 diff --git a/examples/chat/client/main.go b/examples/chat/client/main.go new file mode 100644 index 0000000..2f6a88f --- /dev/null +++ b/examples/chat/client/main.go @@ -0,0 +1,79 @@ +package main + +import ( + "errors" + "fmt" + "net" + "time" + + knet "kinetra.de/net" + "kinetra.de/net/examples/chat/packets" + "kinetra.de/net/examples/chat/packets/initialization" +) + +var ErrResponse = errors.New("response error") + +func main() { + client := &knet.Client{ + RootKeyFile: "client.kr", + BufferSize: 1024, + WriteDeadline: 10 * time.Second, + HandshakeReadDeadline: 500 * time.Millisecond, + HandshakeWriteDeadline: 500 * time.Millisecond, + } + + client.OnRead = func(conn net.Conn, info knet.ReadInfo) knet.AfterAction { + err := client.UnmarshalPacket(info.Data, func(id int, b []byte) (err error) { + switch id { + case packets.InitializationResponsePacket: + var response initialization.Response + if err = response.Unmarshal(b); err != nil { + return + } + + switch response.Data { + case initialization.ResponseUsernameTaken: + fmt.Println("Username is taken.") + case initialization.ResponseUsernameMissing: + fmt.Println("Username is missing.") + case initialization.ResponseSuccess: + return nil + } + + return ErrResponse + } + + return nil + }) + + if err != nil { + return knet.Close + } + + return knet.None + } + + client.OnSecureConnect = func(session knet.ClientSession) knet.AfterAction { + fmt.Println("Secure connection established.") + + username := initialization.Username{ + Data: "deneonet", + } + err := client.SendPacket(packets.InitializationUsernamePacket, &username) + if err != nil { + fmt.Println("Error sending username packet: ", err) + } + + return knet.None + } + + client.OnDisconnect = func() { + fmt.Println("Disconnected from server.") + } + + _, err := client.Connect("localhost:8080") + if err != nil { + fmt.Println("Error connecting to server:", err) + return + } +} diff --git a/examples/chat/packets/common.go b/examples/chat/packets/common.go new file mode 100644 index 0000000..23e9a39 --- /dev/null +++ b/examples/chat/packets/common.go @@ -0,0 +1,16 @@ +package packets + +const ( + InitializationUsernamePacket int = iota + InitializationResponsePacket + + MessagePacket + MessageResponsePacket +) + +type Message struct { + Username string + Message string + + HasDisconnected bool // After someone disconnected, everyone can reclaim that username, so in order to extinguish the "new user" from the "old users", a "disconnected" mark is appended to the message +} diff --git a/examples/chat/packets/gen.bat b/examples/chat/packets/gen.bat new file mode 100644 index 0000000..fc10ad6 --- /dev/null +++ b/examples/chat/packets/gen.bat @@ -0,0 +1,2 @@ +bencgen --in .\message\packet.benc --out ./ --lang go +bencgen --in .\initialization\packet.benc --out ./ --lang go \ No newline at end of file diff --git a/examples/chat/packets/initialization/common.go b/examples/chat/packets/initialization/common.go new file mode 100644 index 0000000..f8fd942 --- /dev/null +++ b/examples/chat/packets/initialization/common.go @@ -0,0 +1,7 @@ +package initialization + +const ( + ResponseSuccess byte = iota + ResponseUsernameTaken + ResponseUsernameMissing +) diff --git a/examples/chat/packets/initialization/packet.benc b/examples/chat/packets/initialization/packet.benc new file mode 100644 index 0000000..a863578 --- /dev/null +++ b/examples/chat/packets/initialization/packet.benc @@ -0,0 +1,12 @@ +header initialization; + +ctr Username { + string data = 1; +} + +ctr Response { + byte data = 1; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlJlc3BvbnNlIjp7InJJZHMiOm51bGwsImZpZWxkcyI6eyIxIjp7IklkIjoxLCJOYW1lIjoiZGF0YSIsIlR5cGUiOnsiVG9rZW5UeXBlIjoxOSwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX19fSwiVXNlcm5hbWUiOnsicklkcyI6bnVsbCwiZmllbGRzIjp7IjEiOnsiSWQiOjEsIk5hbWUiOiJkYXRhIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE1LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fX19fX0= [meta_e] \ No newline at end of file diff --git a/examples/chat/packets/initialization/packet.benc.go b/examples/chat/packets/initialization/packet.benc.go new file mode 100644 index 0000000..59b11dc --- /dev/null +++ b/examples/chat/packets/initialization/packet.benc.go @@ -0,0 +1,198 @@ +// Code generated by bencgen golang. DO NOT EDIT. +// source: .\initialization\packet.benc + +package initialization + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" +) + +// Struct - Username +type Username struct { + Data string +} + +// Reserved Ids - Username +var usernameRIds = []uint16{} + +// Size - Username +func (username *Username) Size() int { + return username.size(0) +} + +// Nested Size - Username +func (username *Username) size(id uint16) (s int) { + s += bstd.SizeString(username.Data) + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Username +func (username *Username) SizePlain() (s int) { + s += bstd.SizeString(username.Data) + return +} + +// Marshal - Username +func (username *Username) Marshal(b []byte) { + username.marshal(0, b, 0) +} + +// Nested Marshal - Username +func (username *Username) marshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalString(n, b, username.Data) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Username +func (username *Username) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalString(n, b, username.Data) + return n +} + +// Unmarshal - Username +func (username *Username) Unmarshal(b []byte) (err error) { + _, err = username.unmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Username +func (username *Username) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, usernameRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, username.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Username +func (username *Username) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, username.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + return +} + +// Struct - Response +type Response struct { + Data byte +} + +// Reserved Ids - Response +var responseRIds = []uint16{} + +// Size - Response +func (response *Response) Size() int { + return response.size(0) +} + +// Nested Size - Response +func (response *Response) size(id uint16) (s int) { + s += bstd.SizeByte() + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Response +func (response *Response) SizePlain() (s int) { + s += bstd.SizeByte() + return +} + +// Marshal - Response +func (response *Response) Marshal(b []byte) { + response.marshal(0, b, 0) +} + +// Nested Marshal - Response +func (response *Response) marshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed8, 1) + n = bstd.MarshalByte(n, b, response.Data) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Response +func (response *Response) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalByte(n, b, response.Data) + return n +} + +// Unmarshal - Response +func (response *Response) Unmarshal(b []byte) (err error) { + _, err = response.unmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Response +func (response *Response) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, responseRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, response.Data, err = bstd.UnmarshalByte(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Response +func (response *Response) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, response.Data, err = bstd.UnmarshalByte(n, b); err != nil { + return + } + return +} + diff --git a/examples/chat/packets/message/packet.benc b/examples/chat/packets/message/packet.benc new file mode 100644 index 0000000..dcd36cb --- /dev/null +++ b/examples/chat/packets/message/packet.benc @@ -0,0 +1,8 @@ +header message; + +ctr Packet { + string data = 1; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlBhY2tldCI6eyJySWRzIjpudWxsLCJmaWVsZHMiOnsiMSI6eyJJZCI6MSwiTmFtZSI6ImRhdGEiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTUsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19fX19fQ== [meta_e] \ No newline at end of file diff --git a/examples/chat/packets/message/packet.benc.go b/examples/chat/packets/message/packet.benc.go new file mode 100644 index 0000000..a049201 --- /dev/null +++ b/examples/chat/packets/message/packet.benc.go @@ -0,0 +1,104 @@ +// Code generated by bencgen golang. DO NOT EDIT. +// source: .\message\packet.benc + +package message + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" +) + +// Struct - Packet +type Packet struct { + Data string +} + +// Reserved Ids - Packet +var packetRIds = []uint16{} + +// Size - Packet +func (packet *Packet) Size() int { + return packet.size(0) +} + +// Nested Size - Packet +func (packet *Packet) size(id uint16) (s int) { + s += bstd.SizeString(packet.Data) + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Packet +func (packet *Packet) SizePlain() (s int) { + s += bstd.SizeString(packet.Data) + return +} + +// Marshal - Packet +func (packet *Packet) Marshal(b []byte) { + packet.marshal(0, b, 0) +} + +// Nested Marshal - Packet +func (packet *Packet) marshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalString(n, b, packet.Data) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Packet +func (packet *Packet) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalString(n, b, packet.Data) + return n +} + +// Unmarshal - Packet +func (packet *Packet) Unmarshal(b []byte) (err error) { + _, err = packet.unmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Packet +func (packet *Packet) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, packetRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, packet.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Packet +func (packet *Packet) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, packet.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + return +} + diff --git a/examples/chat/server/main.go b/examples/chat/server/main.go new file mode 100644 index 0000000..50377ad --- /dev/null +++ b/examples/chat/server/main.go @@ -0,0 +1,167 @@ +package main + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/google/uuid" + knet "kinetra.de/net" + "kinetra.de/net/examples/chat/packets" + "kinetra.de/net/examples/chat/packets/initialization" + "kinetra.de/net/examples/chat/packets/message" +) + +func main() { + server := &knet.Server{ + Addr: "localhost:8080", + CertFile: "server.kc", + EnableConnPurge: true, + ConnPurgeInterval: 10 * time.Minute, + IdleTimeout: 30 * time.Minute, + MinSessionsBeforePurge: 5, + WriteDeadline: 10 * time.Second, + BufferSize: 1024, + } + + messages := make(map[uuid.UUID]packets.Message) // History of the messages + usernames := make(map[string]bool) // To track which username is still available (map as it's easier) + + messages[uuid.UUID{}] = packets.Message{} + + mutex := sync.RWMutex{} + + server.OnRead = func(conn net.Conn, info knet.ReadInfo) knet.AfterAction { + err := server.UnmarshalPacket(info.Data, func(id int, b []byte) (err error) { + switch id { + case packets.InitializationUsernamePacket: + var username initialization.Username + if err = username.Unmarshal(b); err != nil { + return + } + + response := initialization.Response{ + Data: initialization.ResponseUsernameMissing, + } + + if len(username.Data) == 0 { + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + if _, ok := usernames[username.Data]; ok { + response.Data = initialization.ResponseUsernameTaken + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + response.Data = initialization.ResponseSuccess + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + + if err = server.Store(conn, "username", username.Data); err != nil { + return + } + + mutex.Lock() + usernames[username.Data] = false + mutex.Unlock() + + fmt.Printf("%s initialized as \"%s\".\n", conn.RemoteAddr().String(), username.Data) + return nil + case packets.MessagePacket: + var message message.Packet + if err = message.Unmarshal(b); err != nil { + return + } + + response := message.Response{ + Data: initialization.ResponseUsernameMissing, + } + + if len(username.Data) == 0 { + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + if _, ok := usernames[username.Data]; ok { + response.Data = initialization.ResponseUsernameTaken + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + response.Data = initialization.ResponseSuccess + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + + if err = server.Store(conn, "username", username.Data); err != nil { + return + } + + mutex.Lock() + usernames[username.Data] = false + mutex.Unlock() + + fmt.Printf("%s initialized as \"%s\".\n", conn.RemoteAddr().String(), username.Data) + return nil + } + + return nil + }) + + if err != nil { + return knet.Close + } + + return knet.None + } + + server.OnSecureConnect = func(conn net.Conn, session knet.ServerSession) knet.AfterAction { + fmt.Println("Established a secure connection with", conn.RemoteAddr().String()) + return knet.None + } + + server.OnDisconnect = func(conn net.Conn) { + username, err := server.Get(conn, "username") + if err != nil { + return + } + + if username == nil { + fmt.Printf("%s disconnected.\n", conn.RemoteAddr().String()) + return + } + + fmt.Printf("%s disconnected.\n", username) + + mutex.Lock() + delete(usernames, username.(string)) + mutex.Unlock() + } + + server.OnConnectionError = func(conn net.Conn, err error) knet.AfterAction { + fmt.Printf("Connection error: %s from %s\n", err, conn.RemoteAddr().String()) + return knet.Close + } + + server.OnAcceptingError = func(err error) bool { + fmt.Println("Server failed to accept a connection.") + return false + } + + err := server.Run() + if err != nil { + fmt.Println("Server error: ", err) + } +} diff --git a/examples/chat/server/server.kc b/examples/chat/server/server.kc new file mode 100644 index 0000000000000000000000000000000000000000..bcf84e1d7b7ffa2f0f748335407f02c88f00f2c5 GIT binary patch literal 365 zcmV-z0h0a#00RL+0QwEDT-oS#e#%FK_eO^?B3cJQKmv!ar~uWR}$8avw}a-(dB#gdXSUx=Y+wF}?!=g#iQsOmu|UYdkWi0BYTQK?0I7wQQ)X zZy)WRgp5FA_cC+B^u(wB5%(lY6~>Myewn!Di9(Hz7t*5G!JG1~vX$ZQv;b&2FiA1m zDpx|-8p25g+k~0S?>iO*fyS@~Ze>FXKx?-&wpEZqE5x{yrl%$6gydiU!N}9-(yN&g z#ItMYF{4KV1BwAKfrkP?DARIRt>e8|;XzokSpx&lY{R4L4!7l$sG{AVvjRc^7+_jRwPI4*vU89TWRtu9jGL1S zHT37NKB8u+XlzR>-zHsEx#k4MeMbtob{N#i`3KuqOqKl9Y0E^0JM9me@B{<`2nBT6 LD$f7_00032>X@#1 literal 0 HcmV?d00001 diff --git a/gen/main.go b/gen/main.go index ba67089..62dfa54 100644 --- a/gen/main.go +++ b/gen/main.go @@ -13,7 +13,7 @@ func main() { fmt.Printf("Generating for version: %d\n", *version) - if err := cert.GenerateCertificateChain(*version, "s.cert", "root.key"); err != nil { + if err := cert.GenerateCertificateChain(*version, "server.kc", "client.kr"); err != nil { panic(err) } } diff --git a/go.mod b/go.mod index 91266a3..7a1ebfa 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,9 @@ module kinetra.de/net go 1.23.2 +require github.com/deneonet/benc v1.1.2 + require ( - github.com/deneonet/benc v1.1.2 + github.com/google/uuid v1.6.0 // indirect golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect ) diff --git a/go.sum b/go.sum index bddf9cd..24feefe 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/deneonet/benc v1.1.2 h1:JNJSnA53zVLjt4Bz1HwxG4tQg475LP+kd8rgUuV4tc4= github.com/deneonet/benc v1.1.2/go.mod h1:HbL4lzHT0jkmlYa36bZw0a0Nhj4NsXG7bd/bXRxJYy4= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= diff --git a/server.go b/server.go index 47e8c17..bfa5c3a 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package knet import ( "crypto/ecdh" "crypto/sha256" - "fmt" "io" "net" "os" @@ -19,39 +18,32 @@ import ( ) type Server struct { - // Addr is the address the server will bind to. - Addr string - - // Cert is the file path to the server's certificate. - CertFile string - cert cert.Certificate - priv *ecdh.PrivateKey - - sessions map[net.Conn]ServerSession - - mutex *sync.Mutex - - ReadDeadline time.Duration - WriteDeadline time.Duration - + Addr string + CertFile string + cert cert.Certificate + priv *ecdh.PrivateKey + sessions map[net.Conn]ServerSession + mutex sync.RWMutex + ReadDeadline time.Duration + WriteDeadline time.Duration HandshakeReadDeadline time.Duration HandshakeWriteDeadline time.Duration - - EnableConnPurge bool - ConnPurgeInterval time.Duration - - IdleTimeout time.Duration - BufferSize int - - OnRead func(net.Conn, ReadInfo) AfterAction - OnDisconnect func(net.Conn) - OnSecureConnect func(net.Conn, ServerSession) AfterAction + EnableConnPurge bool + ConnPurgeInterval time.Duration + MinSessionsBeforePurge int + IdleTimeout time.Duration + BufferSize int + OnRead func(net.Conn, ReadInfo) AfterAction + OnDisconnect func(net.Conn) + OnSecureConnect func(net.Conn, ServerSession) AfterAction + OnAcceptingError func(error) bool + OnConnectionError func(net.Conn, error) AfterAction } type ServerSession struct { SharedSecret []byte - // TODO: LastActivity time.Time - // TODO: Data map[string]interface{} + LastActivity time.Time + Data map[string]interface{} } type ServerHandshakeResult byte @@ -62,35 +54,40 @@ const ( ServerHandshakeError ) -/*TOOD: func (s *Server) connectionPurge() { +type connectionErrorResult byte + +const ( + connectionErrorContinue connectionErrorResult = iota + connectionErrorReturn + connectionErrorMoveOn +) + +func (s *Server) connectionPurge() { for { time.Sleep(s.ConnPurgeInterval) - s.mutex.Lock() - if proto.IsDebugMode() { - fmt.Println("Start connection purge") + if len(s.sessions) < s.MinSessionsBeforePurge { + continue } - currentTime := time.Now() - for conn, session := range s.ServerSessions { - if currentTime.Sub(session.LastActivity) > s.IdleTimeout { - // Connection has been idle for too long, close it. + + s.mutex.Lock() + + now := time.Now() + for conn, session := range s.sessions { + if now.Sub(session.LastActivity) > s.IdleTimeout { conn.Close() - delete(s.ServerSessions, conn) + delete(s.sessions, conn) } } - if proto.IsDebugMode() { - fmt.Println("Finished connection purge") - } + s.mutex.Unlock() } -}*/ +} func (s *Server) setDeadline(conn net.Conn, handshakeComplete bool) { conn.SetDeadline(time.Time{}) - var readDeadline, writeDeadline time.Duration - if handshakeComplete { - readDeadline, writeDeadline = s.ReadDeadline, s.WriteDeadline - } else { + readDeadline, writeDeadline := s.ReadDeadline, s.WriteDeadline + if !handshakeComplete { readDeadline, writeDeadline = s.HandshakeReadDeadline, s.HandshakeWriteDeadline } @@ -102,6 +99,19 @@ func (s *Server) setDeadline(conn net.Conn, handshakeComplete bool) { } } +func (s *Server) handleConnectionError(conn net.Conn, err error) connectionErrorResult { + if err == nil { + return connectionErrorMoveOn + } + + action := s.OnConnectionError(conn, err) + if action == Close { + return connectionErrorReturn + } + + return connectionErrorContinue +} + func (s *Server) handleConnection(conn net.Conn) { defer func() { conn.Close() @@ -110,29 +120,33 @@ func (s *Server) handleConnection(conn net.Conn) { handshakeComplete := false buf := make([]byte, s.BufferSize) - var session ServerSession + for { s.setDeadline(conn, handshakeComplete) size, err := netutils.ReadFromConn(conn, buf) + if int(size) > len(buf) { - panic("data exceeded buffer size") - break + if result := s.handleConnectionError(conn, ErrDataExceededBufferSize); result == connectionErrorReturn { + return + } + continue } if err != nil && !handshakeComplete { - //TODO: error handling - panic(err) - break + if result := s.handleConnectionError(conn, err); result == connectionErrorReturn { + return + } + continue } - var result ServerHandshakeResult if !handshakeComplete { - result, err = s.processHandshake(conn, buf[4:size]) + result, err := s.processHandshake(conn, buf[4:size]) if err != nil { - //TODO: error handling - panic(err) - break + if result := s.handleConnectionError(conn, err); result == connectionErrorReturn { + return + } + continue } if result == ServerContinueHandshake { @@ -141,15 +155,11 @@ func (s *Server) handleConnection(conn net.Conn) { if result == ServerHandshakeComplete { handshakeComplete = true - size = 0 - session, _ = s.GetSession(conn) s.setDeadline(conn, handshakeComplete) - if s.OnSecureConnect != nil { - if action := s.OnSecureConnect(conn, session); action == Close { - break - } + if s.OnSecureConnect != nil && s.OnSecureConnect(conn, session) == Close { + break } continue @@ -157,13 +167,9 @@ func (s *Server) handleConnection(conn net.Conn) { } if err == io.EOF || size == 0 { - conn.Close() - s.RemoveSession(conn) - if s.OnDisconnect != nil { s.OnDisconnect(conn) } - return } @@ -171,10 +177,13 @@ func (s *Server) handleConnection(conn net.Conn) { continue } - encrypted := buf[4] == 1 - data := buf[5:size] - if encrypted && err == nil { - data, err = crypto.Decrypt(session.SharedSecret, data) + var data []byte = nil + if err == nil { + encrypted := buf[4] == 1 + data = buf[5:size] + if encrypted { + data, err = crypto.Decrypt(session.SharedSecret, data) + } } action := s.OnRead(conn, ReadInfo{ @@ -199,7 +208,6 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes if err = packets.SendCertResponseHandshakePacket(conn, packets.CertificateResponse, s.cert); err != nil { return ServerHandshakeError, err } - return ServerContinueHandshake, nil case packets.ClientInformation: clientPublicKey, err := ecdh.P521().NewPublicKey(packet.Payload) @@ -215,7 +223,8 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes aesSecret := sha256.Sum256(sharedSecret) s.SetSession(conn, ServerSession{ SharedSecret: aesSecret[:], - // TODO: Data: make(map[string]interface{}), + LastActivity: time.Now(), + Data: make(map[string]interface{}), }) verification, err := crypto.Encrypt(aesSecret[:], []byte{1, 2, 3, 4}) @@ -223,7 +232,9 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes return ServerHandshakeError, err } - packets.SendHandshakePacket(conn, packets.ServerVerification, verification) + if err := packets.SendHandshakePacket(conn, packets.ServerVerification, verification); err != nil { + return ServerHandshakeError, err + } return ServerHandshakeComplete, nil } @@ -237,14 +248,15 @@ func (s *Server) Run() error { } defer listener.Close() - //TODO: if s.EnableConnPurge { - // //go s.connectionPurge() - //} - bytes, err := os.ReadFile(s.CertFile) + if s.EnableConnPurge { + go s.connectionPurge() + } + + certData, err := os.ReadFile(s.CertFile) if err != nil { return err } - if err = s.cert.Unmarshal(bytes); err != nil { + if err = s.cert.Unmarshal(certData); err != nil { return err } @@ -255,6 +267,10 @@ func (s *Server) Run() error { if s.IdleTimeout == 0 { s.IdleTimeout = 30 * time.Minute } + if s.ConnPurgeInterval == 0 { + s.ConnPurgeInterval = 35 * time.Minute + } + if s.sessions == nil { s.sessions = make(map[net.Conn]ServerSession) } @@ -266,12 +282,20 @@ func (s *Server) Run() error { s.HandshakeWriteDeadline = 500 * time.Millisecond } - s.mutex = &sync.Mutex{} + if s.OnAcceptingError == nil { + s.OnAcceptingError = func(err error) bool { return true } + } + if s.OnConnectionError == nil { + s.OnConnectionError = func(conn net.Conn, err error) AfterAction { panic(conn.RemoteAddr().String() + ": " + err.Error()) } + } for { conn, err := listener.Accept() if err != nil { - fmt.Println("Error accepting connection:", err) + if s.OnAcceptingError(err) { + return err + } + continue } go s.handleConnection(conn) @@ -288,52 +312,52 @@ func (s *Server) Send(conn net.Conn, b []byte) error { if err != nil { return err } - return netutils.SendToConn(conn, len(encrypted)+1, func(n int, b []byte) { b[n] = 1; copy(b[n+1:], encrypted) }) + return netutils.SendToConn(conn, len(encrypted)+1, func(n int, buf []byte) { buf[n] = 1; copy(buf[n+1:], encrypted) }) } func (s *Server) SendUnsecure(conn net.Conn, b []byte) error { return netutils.SendToConn(conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }) } -func (s *Server) UnmarshalPacket(b []byte, f func(int, []byte) error) error { - n, id, err := bstd.UnmarshalInt(0, b) +func (s *Server) UnmarshalPacket(buf []byte, f func(int, []byte) error) error { + n, id, err := bstd.UnmarshalInt(0, buf) if err != nil { return err } - return f(id, b[n:]) + return f(id, buf[n:]) } func (s *Server) SendPacket(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.Send(conn, b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.Send(conn, buf) } func (s *Server) SendPacketUnsecure(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendUnsecure(conn, b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendUnsecure(conn, buf) } func (s *Server) SendPacketToAll(id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendToAll(b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendToAll(buf) } func (s *Server) SendPacketUnsecureToAll(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendUnsecureToAll(b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendUnsecureToAll(buf) } -func (s *Server) SendToAll(b []byte) error { +func (s *Server) SendToAll(buf []byte) error { for conn := range s.sessions { - if err := s.Send(conn, b); err != nil { + if err := s.Send(conn, buf); err != nil { return err } } @@ -341,18 +365,16 @@ func (s *Server) SendToAll(b []byte) error { return nil } -func (s *Server) SendUnsecureToAll(b []byte) error { +func (s *Server) SendUnsecureToAll(buf []byte) error { for conn := range s.sessions { - if err := s.SendUnsecure(conn, b); err != nil { + if err := s.SendUnsecure(conn, buf); err != nil { return err } } return nil } - -/*TODO: -func (s *Server) Store(conn gnet.Conn, key string, value interface{}) error { +func (s *Server) Store(conn net.Conn, key string, value interface{}) error { ses, ok := s.GetSession(conn) if !ok { return ErrSessionNotFound @@ -365,7 +387,7 @@ func (s *Server) Store(conn gnet.Conn, key string, value interface{}) error { return nil } -func (s *Server) Get(conn gnet.Conn, key string) (interface{}, error) { +func (s *Server) Get(conn net.Conn, key string) (interface{}, error) { ses, ok := s.GetSession(conn) if !ok { return nil, ErrSessionNotFound @@ -376,12 +398,13 @@ func (s *Server) Get(conn gnet.Conn, key string) (interface{}, error) { s.mutex.Unlock() return value, nil -}*/ +} func (s *Server) GetSession(conn net.Conn) (ServerSession, bool) { s.mutex.Lock() ses, ok := s.sessions[conn] s.mutex.Unlock() + return ses, ok }