From 24e03831ebdb8ce0f21637f80fa29cbc7a4780e0 Mon Sep 17 00:00:00 2001 From: deneonet Date: Sun, 26 Jan 2025 12:11:42 +0100 Subject: [PATCH 1/2] unfinished clean up, server/client improvements + examples --- .gitignore | 1 - common.go | 1 + examples/chat/client/client.kr | Bin 0 -> 87 bytes examples/chat/client/main.go | 79 ++++++ examples/chat/packets/common.go | 16 ++ examples/chat/packets/gen.bat | 2 + .../chat/packets/initialization/common.go | 7 + .../chat/packets/initialization/packet.benc | 12 + .../packets/initialization/packet.benc.go | 198 ++++++++++++++ examples/chat/packets/message/packet.benc | 8 + examples/chat/packets/message/packet.benc.go | 104 ++++++++ examples/chat/server/main.go | 167 ++++++++++++ examples/chat/server/server.kc | Bin 0 -> 365 bytes gen/main.go | 2 +- go.mod | 4 +- go.sum | 2 + server.go | 241 ++++++++++-------- 17 files changed, 732 insertions(+), 112 deletions(-) create mode 100644 examples/chat/client/client.kr create mode 100644 examples/chat/client/main.go create mode 100644 examples/chat/packets/common.go create mode 100644 examples/chat/packets/gen.bat create mode 100644 examples/chat/packets/initialization/common.go create mode 100644 examples/chat/packets/initialization/packet.benc create mode 100644 examples/chat/packets/initialization/packet.benc.go create mode 100644 examples/chat/packets/message/packet.benc create mode 100644 examples/chat/packets/message/packet.benc.go create mode 100644 examples/chat/server/main.go create mode 100644 examples/chat/server/server.kc diff --git a/.gitignore b/.gitignore index 8a71b76..e69de29 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +0,0 @@ -_examples \ No newline at end of file diff --git a/common.go b/common.go index 1a4a140..1161400 100644 --- a/common.go +++ b/common.go @@ -28,4 +28,5 @@ var ( ErrInvalidHandshakePacket = errors.New("invalid handshake packet received") ErrBufTooSmall = errors.New("buffer is too small for the requested size") ErrInvalidRootKey = errors.New("invalid root key in client struct") + ErrDataExceededBufferSize = errors.New("received data size exceeded buffer size") ) diff --git a/examples/chat/client/client.kr b/examples/chat/client/client.kr new file mode 100644 index 0000000000000000000000000000000000000000..1a730e2dde50369d168307de275009dcf9c8ec51 GIT binary patch literal 87 zcmV-d0I2^000RL-0sv@U9NQD@+jM#G!Ld|A?domk(`h-QUH~M}DQTRxidONr_o9~@ ta|)n}bQir85k4^F^N#uLWew`K%_rEn>A0u_0s;sFblED;000000Rgv-CK~_% literal 0 HcmV?d00001 diff --git a/examples/chat/client/main.go b/examples/chat/client/main.go new file mode 100644 index 0000000..2f6a88f --- /dev/null +++ b/examples/chat/client/main.go @@ -0,0 +1,79 @@ +package main + +import ( + "errors" + "fmt" + "net" + "time" + + knet "kinetra.de/net" + "kinetra.de/net/examples/chat/packets" + "kinetra.de/net/examples/chat/packets/initialization" +) + +var ErrResponse = errors.New("response error") + +func main() { + client := &knet.Client{ + RootKeyFile: "client.kr", + BufferSize: 1024, + WriteDeadline: 10 * time.Second, + HandshakeReadDeadline: 500 * time.Millisecond, + HandshakeWriteDeadline: 500 * time.Millisecond, + } + + client.OnRead = func(conn net.Conn, info knet.ReadInfo) knet.AfterAction { + err := client.UnmarshalPacket(info.Data, func(id int, b []byte) (err error) { + switch id { + case packets.InitializationResponsePacket: + var response initialization.Response + if err = response.Unmarshal(b); err != nil { + return + } + + switch response.Data { + case initialization.ResponseUsernameTaken: + fmt.Println("Username is taken.") + case initialization.ResponseUsernameMissing: + fmt.Println("Username is missing.") + case initialization.ResponseSuccess: + return nil + } + + return ErrResponse + } + + return nil + }) + + if err != nil { + return knet.Close + } + + return knet.None + } + + client.OnSecureConnect = func(session knet.ClientSession) knet.AfterAction { + fmt.Println("Secure connection established.") + + username := initialization.Username{ + Data: "deneonet", + } + err := client.SendPacket(packets.InitializationUsernamePacket, &username) + if err != nil { + fmt.Println("Error sending username packet: ", err) + } + + return knet.None + } + + client.OnDisconnect = func() { + fmt.Println("Disconnected from server.") + } + + _, err := client.Connect("localhost:8080") + if err != nil { + fmt.Println("Error connecting to server:", err) + return + } +} diff --git a/examples/chat/packets/common.go b/examples/chat/packets/common.go new file mode 100644 index 0000000..23e9a39 --- /dev/null +++ b/examples/chat/packets/common.go @@ -0,0 +1,16 @@ +package packets + +const ( + InitializationUsernamePacket int = iota + InitializationResponsePacket + + MessagePacket + MessageResponsePacket +) + +type Message struct { + Username string + Message string + + HasDisconnected bool // After someone disconnected, everyone can reclaim that username, so in order to extinguish the "new user" from the "old users", a "disconnected" mark is appended to the message +} diff --git a/examples/chat/packets/gen.bat b/examples/chat/packets/gen.bat new file mode 100644 index 0000000..fc10ad6 --- /dev/null +++ b/examples/chat/packets/gen.bat @@ -0,0 +1,2 @@ +bencgen --in .\message\packet.benc --out ./ --lang go +bencgen --in .\initialization\packet.benc --out ./ --lang go \ No newline at end of file diff --git a/examples/chat/packets/initialization/common.go b/examples/chat/packets/initialization/common.go new file mode 100644 index 0000000..f8fd942 --- /dev/null +++ b/examples/chat/packets/initialization/common.go @@ -0,0 +1,7 @@ +package initialization + +const ( + ResponseSuccess byte = iota + ResponseUsernameTaken + ResponseUsernameMissing +) diff --git a/examples/chat/packets/initialization/packet.benc b/examples/chat/packets/initialization/packet.benc new file mode 100644 index 0000000..a863578 --- /dev/null +++ b/examples/chat/packets/initialization/packet.benc @@ -0,0 +1,12 @@ +header initialization; + +ctr Username { + string data = 1; +} + +ctr Response { + byte data = 1; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlJlc3BvbnNlIjp7InJJZHMiOm51bGwsImZpZWxkcyI6eyIxIjp7IklkIjoxLCJOYW1lIjoiZGF0YSIsIlR5cGUiOnsiVG9rZW5UeXBlIjoxOSwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX19fSwiVXNlcm5hbWUiOnsicklkcyI6bnVsbCwiZmllbGRzIjp7IjEiOnsiSWQiOjEsIk5hbWUiOiJkYXRhIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE1LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fX19fX0= [meta_e] \ No newline at end of file diff --git a/examples/chat/packets/initialization/packet.benc.go b/examples/chat/packets/initialization/packet.benc.go new file mode 100644 index 0000000..59b11dc --- /dev/null +++ b/examples/chat/packets/initialization/packet.benc.go @@ -0,0 +1,198 @@ +// Code generated by bencgen golang. DO NOT EDIT. +// source: .\initialization\packet.benc + +package initialization + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" +) + +// Struct - Username +type Username struct { + Data string +} + +// Reserved Ids - Username +var usernameRIds = []uint16{} + +// Size - Username +func (username *Username) Size() int { + return username.size(0) +} + +// Nested Size - Username +func (username *Username) size(id uint16) (s int) { + s += bstd.SizeString(username.Data) + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Username +func (username *Username) SizePlain() (s int) { + s += bstd.SizeString(username.Data) + return +} + +// Marshal - Username +func (username *Username) Marshal(b []byte) { + username.marshal(0, b, 0) +} + +// Nested Marshal - Username +func (username *Username) marshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalString(n, b, username.Data) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Username +func (username *Username) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalString(n, b, username.Data) + return n +} + +// Unmarshal - Username +func (username *Username) Unmarshal(b []byte) (err error) { + _, err = username.unmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Username +func (username *Username) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, usernameRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, username.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Username +func (username *Username) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, username.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + return +} + +// Struct - Response +type Response struct { + Data byte +} + +// Reserved Ids - Response +var responseRIds = []uint16{} + +// Size - Response +func (response *Response) Size() int { + return response.size(0) +} + +// Nested Size - Response +func (response *Response) size(id uint16) (s int) { + s += bstd.SizeByte() + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Response +func (response *Response) SizePlain() (s int) { + s += bstd.SizeByte() + return +} + +// Marshal - Response +func (response *Response) Marshal(b []byte) { + response.marshal(0, b, 0) +} + +// Nested Marshal - Response +func (response *Response) marshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed8, 1) + n = bstd.MarshalByte(n, b, response.Data) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Response +func (response *Response) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalByte(n, b, response.Data) + return n +} + +// Unmarshal - Response +func (response *Response) Unmarshal(b []byte) (err error) { + _, err = response.unmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Response +func (response *Response) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, responseRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, response.Data, err = bstd.UnmarshalByte(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Response +func (response *Response) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, response.Data, err = bstd.UnmarshalByte(n, b); err != nil { + return + } + return +} + diff --git a/examples/chat/packets/message/packet.benc b/examples/chat/packets/message/packet.benc new file mode 100644 index 0000000..dcd36cb --- /dev/null +++ b/examples/chat/packets/message/packet.benc @@ -0,0 +1,8 @@ +header message; + +ctr Packet { + string data = 1; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlBhY2tldCI6eyJySWRzIjpudWxsLCJmaWVsZHMiOnsiMSI6eyJJZCI6MSwiTmFtZSI6ImRhdGEiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTUsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19fX19fQ== [meta_e] \ No newline at end of file diff --git a/examples/chat/packets/message/packet.benc.go b/examples/chat/packets/message/packet.benc.go new file mode 100644 index 0000000..a049201 --- /dev/null +++ b/examples/chat/packets/message/packet.benc.go @@ -0,0 +1,104 @@ +// Code generated by bencgen golang. DO NOT EDIT. +// source: .\message\packet.benc + +package message + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" +) + +// Struct - Packet +type Packet struct { + Data string +} + +// Reserved Ids - Packet +var packetRIds = []uint16{} + +// Size - Packet +func (packet *Packet) Size() int { + return packet.size(0) +} + +// Nested Size - Packet +func (packet *Packet) size(id uint16) (s int) { + s += bstd.SizeString(packet.Data) + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Packet +func (packet *Packet) SizePlain() (s int) { + s += bstd.SizeString(packet.Data) + return +} + +// Marshal - Packet +func (packet *Packet) Marshal(b []byte) { + packet.marshal(0, b, 0) +} + +// Nested Marshal - Packet +func (packet *Packet) marshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalString(n, b, packet.Data) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Packet +func (packet *Packet) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalString(n, b, packet.Data) + return n +} + +// Unmarshal - Packet +func (packet *Packet) Unmarshal(b []byte) (err error) { + _, err = packet.unmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Packet +func (packet *Packet) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, packetRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, packet.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Packet +func (packet *Packet) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, packet.Data, err = bstd.UnmarshalString(n, b); err != nil { + return + } + return +} + diff --git a/examples/chat/server/main.go b/examples/chat/server/main.go new file mode 100644 index 0000000..50377ad --- /dev/null +++ b/examples/chat/server/main.go @@ -0,0 +1,167 @@ +package main + +import ( + "fmt" + "net" + "sync" + "time" + + "github.com/google/uuid" + knet "kinetra.de/net" + "kinetra.de/net/examples/chat/packets" + "kinetra.de/net/examples/chat/packets/initialization" + "kinetra.de/net/examples/chat/packets/message" +) + +func main() { + server := &knet.Server{ + Addr: "localhost:8080", + CertFile: "server.kc", + EnableConnPurge: true, + ConnPurgeInterval: 10 * time.Minute, + IdleTimeout: 30 * time.Minute, + MinSessionsBeforePurge: 5, + WriteDeadline: 10 * time.Second, + BufferSize: 1024, + } + + messages := make(map[uuid.UUID]packets.Message) // History of the messages + usernames := make(map[string]bool) // To track which username is still available (map as it's easier) + + messages[uuid.UUID{}] = packets.Message{} + + mutex := sync.RWMutex{} + + server.OnRead = func(conn net.Conn, info knet.ReadInfo) knet.AfterAction { + err := server.UnmarshalPacket(info.Data, func(id int, b []byte) (err error) { + switch id { + case packets.InitializationUsernamePacket: + var username initialization.Username + if err = username.Unmarshal(b); err != nil { + return + } + + response := initialization.Response{ + Data: initialization.ResponseUsernameMissing, + } + + if len(username.Data) == 0 { + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + if _, ok := usernames[username.Data]; ok { + response.Data = initialization.ResponseUsernameTaken + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + response.Data = initialization.ResponseSuccess + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + + if err = server.Store(conn, "username", username.Data); err != nil { + return + } + + mutex.Lock() + usernames[username.Data] = false + mutex.Unlock() + + fmt.Printf("%s initialized as \"%s\".\n", conn.RemoteAddr().String(), username.Data) + return nil + case packets.MessagePacket: + var message message.Packet + if err = message.Unmarshal(b); err != nil { + return + } + + response := message.Response{ + Data: initialization.ResponseUsernameMissing, + } + + if len(username.Data) == 0 { + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + if _, ok := usernames[username.Data]; ok { + response.Data = initialization.ResponseUsernameTaken + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + return nil + } + + response.Data = initialization.ResponseSuccess + if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { + return + } + + if err = server.Store(conn, "username", username.Data); err != nil { + return + } + + mutex.Lock() + usernames[username.Data] = false + mutex.Unlock() + + fmt.Printf("%s initialized as \"%s\".\n", conn.RemoteAddr().String(), username.Data) + return nil + } + + return nil + }) + + if err != nil { + return knet.Close + } + + return knet.None + } + + server.OnSecureConnect = func(conn net.Conn, session knet.ServerSession) knet.AfterAction { + fmt.Println("Established a secure connection with", conn.RemoteAddr().String()) + return knet.None + } + + server.OnDisconnect = func(conn net.Conn) { + username, err := server.Get(conn, "username") + if err != nil { + return + } + + if username == nil { + fmt.Printf("%s disconnected.\n", conn.RemoteAddr().String()) + return + } + + fmt.Printf("%s disconnected.\n", username) + + mutex.Lock() + delete(usernames, username.(string)) + mutex.Unlock() + } + + server.OnConnectionError = func(conn net.Conn, err error) knet.AfterAction { + fmt.Printf("Connection error: %s from %s\n", err, conn.RemoteAddr().String()) + return knet.Close + } + + server.OnAcceptingError = func(err error) bool { + fmt.Println("Server failed to accept a connection.") + return false + } + + err := server.Run() + if err != nil { + fmt.Println("Server error: ", err) + } +} diff --git a/examples/chat/server/server.kc b/examples/chat/server/server.kc new file mode 100644 index 0000000000000000000000000000000000000000..bcf84e1d7b7ffa2f0f748335407f02c88f00f2c5 GIT binary patch literal 365 zcmV-z0h0a#00RL+0QwEDT-oS#e#%FK_eO^?B3cJQKmv!ar~uWR}$8avw}a-(dB#gdXSUx=Y+wF}?!=g#iQsOmu|UYdkWi0BYTQK?0I7wQQ)X zZy)WRgp5FA_cC+B^u(wB5%(lY6~>Myewn!Di9(Hz7t*5G!JG1~vX$ZQv;b&2FiA1m zDpx|-8p25g+k~0S?>iO*fyS@~Ze>FXKx?-&wpEZqE5x{yrl%$6gydiU!N}9-(yN&g z#ItMYF{4KV1BwAKfrkP?DARIRt>e8|;XzokSpx&lY{R4L4!7l$sG{AVvjRc^7+_jRwPI4*vU89TWRtu9jGL1S zHT37NKB8u+XlzR>-zHsEx#k4MeMbtob{N#i`3KuqOqKl9Y0E^0JM9me@B{<`2nBT6 LD$f7_00032>X@#1 literal 0 HcmV?d00001 diff --git a/gen/main.go b/gen/main.go index ba67089..62dfa54 100644 --- a/gen/main.go +++ b/gen/main.go @@ -13,7 +13,7 @@ func main() { fmt.Printf("Generating for version: %d\n", *version) - if err := cert.GenerateCertificateChain(*version, "s.cert", "root.key"); err != nil { + if err := cert.GenerateCertificateChain(*version, "server.kc", "client.kr"); err != nil { panic(err) } } diff --git a/go.mod b/go.mod index 91266a3..7a1ebfa 100644 --- a/go.mod +++ b/go.mod @@ -2,7 +2,9 @@ module kinetra.de/net go 1.23.2 +require github.com/deneonet/benc v1.1.2 + require ( - github.com/deneonet/benc v1.1.2 + github.com/google/uuid v1.6.0 // indirect golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect ) diff --git a/go.sum b/go.sum index bddf9cd..24feefe 100644 --- a/go.sum +++ b/go.sum @@ -1,4 +1,6 @@ github.com/deneonet/benc v1.1.2 h1:JNJSnA53zVLjt4Bz1HwxG4tQg475LP+kd8rgUuV4tc4= github.com/deneonet/benc v1.1.2/go.mod h1:HbL4lzHT0jkmlYa36bZw0a0Nhj4NsXG7bd/bXRxJYy4= +github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= +github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= diff --git a/server.go b/server.go index 47e8c17..bfa5c3a 100644 --- a/server.go +++ b/server.go @@ -3,7 +3,6 @@ package knet import ( "crypto/ecdh" "crypto/sha256" - "fmt" "io" "net" "os" @@ -19,39 +18,32 @@ import ( ) type Server struct { - // Addr is the address the server will bind to. - Addr string - - // Cert is the file path to the server's certificate. - CertFile string - cert cert.Certificate - priv *ecdh.PrivateKey - - sessions map[net.Conn]ServerSession - - mutex *sync.Mutex - - ReadDeadline time.Duration - WriteDeadline time.Duration - + Addr string + CertFile string + cert cert.Certificate + priv *ecdh.PrivateKey + sessions map[net.Conn]ServerSession + mutex sync.RWMutex + ReadDeadline time.Duration + WriteDeadline time.Duration HandshakeReadDeadline time.Duration HandshakeWriteDeadline time.Duration - - EnableConnPurge bool - ConnPurgeInterval time.Duration - - IdleTimeout time.Duration - BufferSize int - - OnRead func(net.Conn, ReadInfo) AfterAction - OnDisconnect func(net.Conn) - OnSecureConnect func(net.Conn, ServerSession) AfterAction + EnableConnPurge bool + ConnPurgeInterval time.Duration + MinSessionsBeforePurge int + IdleTimeout time.Duration + BufferSize int + OnRead func(net.Conn, ReadInfo) AfterAction + OnDisconnect func(net.Conn) + OnSecureConnect func(net.Conn, ServerSession) AfterAction + OnAcceptingError func(error) bool + OnConnectionError func(net.Conn, error) AfterAction } type ServerSession struct { SharedSecret []byte - // TODO: LastActivity time.Time - // TODO: Data map[string]interface{} + LastActivity time.Time + Data map[string]interface{} } type ServerHandshakeResult byte @@ -62,35 +54,40 @@ const ( ServerHandshakeError ) -/*TOOD: func (s *Server) connectionPurge() { +type connectionErrorResult byte + +const ( + connectionErrorContinue connectionErrorResult = iota + connectionErrorReturn + connectionErrorMoveOn +) + +func (s *Server) connectionPurge() { for { time.Sleep(s.ConnPurgeInterval) - s.mutex.Lock() - if proto.IsDebugMode() { - fmt.Println("Start connection purge") + if len(s.sessions) < s.MinSessionsBeforePurge { + continue } - currentTime := time.Now() - for conn, session := range s.ServerSessions { - if currentTime.Sub(session.LastActivity) > s.IdleTimeout { - // Connection has been idle for too long, close it. + + s.mutex.Lock() + + now := time.Now() + for conn, session := range s.sessions { + if now.Sub(session.LastActivity) > s.IdleTimeout { conn.Close() - delete(s.ServerSessions, conn) + delete(s.sessions, conn) } } - if proto.IsDebugMode() { - fmt.Println("Finished connection purge") - } + s.mutex.Unlock() } -}*/ +} func (s *Server) setDeadline(conn net.Conn, handshakeComplete bool) { conn.SetDeadline(time.Time{}) - var readDeadline, writeDeadline time.Duration - if handshakeComplete { - readDeadline, writeDeadline = s.ReadDeadline, s.WriteDeadline - } else { + readDeadline, writeDeadline := s.ReadDeadline, s.WriteDeadline + if !handshakeComplete { readDeadline, writeDeadline = s.HandshakeReadDeadline, s.HandshakeWriteDeadline } @@ -102,6 +99,19 @@ func (s *Server) setDeadline(conn net.Conn, handshakeComplete bool) { } } +func (s *Server) handleConnectionError(conn net.Conn, err error) connectionErrorResult { + if err == nil { + return connectionErrorMoveOn + } + + action := s.OnConnectionError(conn, err) + if action == Close { + return connectionErrorReturn + } + + return connectionErrorContinue +} + func (s *Server) handleConnection(conn net.Conn) { defer func() { conn.Close() @@ -110,29 +120,33 @@ func (s *Server) handleConnection(conn net.Conn) { handshakeComplete := false buf := make([]byte, s.BufferSize) - var session ServerSession + for { s.setDeadline(conn, handshakeComplete) size, err := netutils.ReadFromConn(conn, buf) + if int(size) > len(buf) { - panic("data exceeded buffer size") - break + if result := s.handleConnectionError(conn, ErrDataExceededBufferSize); result == connectionErrorReturn { + return + } + continue } if err != nil && !handshakeComplete { - //TODO: error handling - panic(err) - break + if result := s.handleConnectionError(conn, err); result == connectionErrorReturn { + return + } + continue } - var result ServerHandshakeResult if !handshakeComplete { - result, err = s.processHandshake(conn, buf[4:size]) + result, err := s.processHandshake(conn, buf[4:size]) if err != nil { - //TODO: error handling - panic(err) - break + if result := s.handleConnectionError(conn, err); result == connectionErrorReturn { + return + } + continue } if result == ServerContinueHandshake { @@ -141,15 +155,11 @@ func (s *Server) handleConnection(conn net.Conn) { if result == ServerHandshakeComplete { handshakeComplete = true - size = 0 - session, _ = s.GetSession(conn) s.setDeadline(conn, handshakeComplete) - if s.OnSecureConnect != nil { - if action := s.OnSecureConnect(conn, session); action == Close { - break - } + if s.OnSecureConnect != nil && s.OnSecureConnect(conn, session) == Close { + break } continue @@ -157,13 +167,9 @@ func (s *Server) handleConnection(conn net.Conn) { } if err == io.EOF || size == 0 { - conn.Close() - s.RemoveSession(conn) - if s.OnDisconnect != nil { s.OnDisconnect(conn) } - return } @@ -171,10 +177,13 @@ func (s *Server) handleConnection(conn net.Conn) { continue } - encrypted := buf[4] == 1 - data := buf[5:size] - if encrypted && err == nil { - data, err = crypto.Decrypt(session.SharedSecret, data) + var data []byte = nil + if err == nil { + encrypted := buf[4] == 1 + data = buf[5:size] + if encrypted { + data, err = crypto.Decrypt(session.SharedSecret, data) + } } action := s.OnRead(conn, ReadInfo{ @@ -199,7 +208,6 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes if err = packets.SendCertResponseHandshakePacket(conn, packets.CertificateResponse, s.cert); err != nil { return ServerHandshakeError, err } - return ServerContinueHandshake, nil case packets.ClientInformation: clientPublicKey, err := ecdh.P521().NewPublicKey(packet.Payload) @@ -215,7 +223,8 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes aesSecret := sha256.Sum256(sharedSecret) s.SetSession(conn, ServerSession{ SharedSecret: aesSecret[:], - // TODO: Data: make(map[string]interface{}), + LastActivity: time.Now(), + Data: make(map[string]interface{}), }) verification, err := crypto.Encrypt(aesSecret[:], []byte{1, 2, 3, 4}) @@ -223,7 +232,9 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes return ServerHandshakeError, err } - packets.SendHandshakePacket(conn, packets.ServerVerification, verification) + if err := packets.SendHandshakePacket(conn, packets.ServerVerification, verification); err != nil { + return ServerHandshakeError, err + } return ServerHandshakeComplete, nil } @@ -237,14 +248,15 @@ func (s *Server) Run() error { } defer listener.Close() - //TODO: if s.EnableConnPurge { - // //go s.connectionPurge() - //} - bytes, err := os.ReadFile(s.CertFile) + if s.EnableConnPurge { + go s.connectionPurge() + } + + certData, err := os.ReadFile(s.CertFile) if err != nil { return err } - if err = s.cert.Unmarshal(bytes); err != nil { + if err = s.cert.Unmarshal(certData); err != nil { return err } @@ -255,6 +267,10 @@ func (s *Server) Run() error { if s.IdleTimeout == 0 { s.IdleTimeout = 30 * time.Minute } + if s.ConnPurgeInterval == 0 { + s.ConnPurgeInterval = 35 * time.Minute + } + if s.sessions == nil { s.sessions = make(map[net.Conn]ServerSession) } @@ -266,12 +282,20 @@ func (s *Server) Run() error { s.HandshakeWriteDeadline = 500 * time.Millisecond } - s.mutex = &sync.Mutex{} + if s.OnAcceptingError == nil { + s.OnAcceptingError = func(err error) bool { return true } + } + if s.OnConnectionError == nil { + s.OnConnectionError = func(conn net.Conn, err error) AfterAction { panic(conn.RemoteAddr().String() + ": " + err.Error()) } + } for { conn, err := listener.Accept() if err != nil { - fmt.Println("Error accepting connection:", err) + if s.OnAcceptingError(err) { + return err + } + continue } go s.handleConnection(conn) @@ -288,52 +312,52 @@ func (s *Server) Send(conn net.Conn, b []byte) error { if err != nil { return err } - return netutils.SendToConn(conn, len(encrypted)+1, func(n int, b []byte) { b[n] = 1; copy(b[n+1:], encrypted) }) + return netutils.SendToConn(conn, len(encrypted)+1, func(n int, buf []byte) { buf[n] = 1; copy(buf[n+1:], encrypted) }) } func (s *Server) SendUnsecure(conn net.Conn, b []byte) error { return netutils.SendToConn(conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }) } -func (s *Server) UnmarshalPacket(b []byte, f func(int, []byte) error) error { - n, id, err := bstd.UnmarshalInt(0, b) +func (s *Server) UnmarshalPacket(buf []byte, f func(int, []byte) error) error { + n, id, err := bstd.UnmarshalInt(0, buf) if err != nil { return err } - return f(id, b[n:]) + return f(id, buf[n:]) } func (s *Server) SendPacket(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.Send(conn, b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.Send(conn, buf) } func (s *Server) SendPacketUnsecure(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendUnsecure(conn, b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendUnsecure(conn, buf) } func (s *Server) SendPacketToAll(id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendToAll(b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendToAll(buf) } func (s *Server) SendPacketUnsecureToAll(conn net.Conn, id int, p Packet) error { - b := make([]byte, p.Size()+bstd.SizeInt(id)) - n := bstd.MarshalInt(0, b, id) - p.Marshal(b[n:]) - return s.SendUnsecureToAll(b) + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendUnsecureToAll(buf) } -func (s *Server) SendToAll(b []byte) error { +func (s *Server) SendToAll(buf []byte) error { for conn := range s.sessions { - if err := s.Send(conn, b); err != nil { + if err := s.Send(conn, buf); err != nil { return err } } @@ -341,18 +365,16 @@ func (s *Server) SendToAll(b []byte) error { return nil } -func (s *Server) SendUnsecureToAll(b []byte) error { +func (s *Server) SendUnsecureToAll(buf []byte) error { for conn := range s.sessions { - if err := s.SendUnsecure(conn, b); err != nil { + if err := s.SendUnsecure(conn, buf); err != nil { return err } } return nil } - -/*TODO: -func (s *Server) Store(conn gnet.Conn, key string, value interface{}) error { +func (s *Server) Store(conn net.Conn, key string, value interface{}) error { ses, ok := s.GetSession(conn) if !ok { return ErrSessionNotFound @@ -365,7 +387,7 @@ func (s *Server) Store(conn gnet.Conn, key string, value interface{}) error { return nil } -func (s *Server) Get(conn gnet.Conn, key string) (interface{}, error) { +func (s *Server) Get(conn net.Conn, key string) (interface{}, error) { ses, ok := s.GetSession(conn) if !ok { return nil, ErrSessionNotFound @@ -376,12 +398,13 @@ func (s *Server) Get(conn gnet.Conn, key string) (interface{}, error) { s.mutex.Unlock() return value, nil -}*/ +} func (s *Server) GetSession(conn net.Conn) (ServerSession, bool) { s.mutex.Lock() ses, ok := s.sessions[conn] s.mutex.Unlock() + return ses, ok } From 1a9f6251b1411f0e8a77fadba4cc7c4cc09852b8 Mon Sep 17 00:00:00 2001 From: den Date: Thu, 30 Jan 2025 20:04:12 +0100 Subject: [PATCH 2/2] v2.0.0 --- LICENSE | 2 +- README.md | 26 +-- cert/cert.benc.go | 184 ---------------- cert/root.benc.go | 144 ------------- cert/rootkey.go | 146 +++++++++++++ cert/servercert.go | 186 ++++++++++++++++ cert/utils.go | 82 +++++--- client.go | 84 ++++---- common.go | 11 +- crypto/encryption.go | 7 +- examples/chat/client/client.kr | Bin 87 -> 0 bytes examples/chat/client/main.go | 79 ------- examples/chat/packets/common.go | 16 -- examples/chat/packets/gen.bat | 2 - .../chat/packets/initialization/common.go | 7 - .../chat/packets/initialization/packet.benc | 12 -- .../packets/initialization/packet.benc.go | 198 ------------------ examples/chat/packets/message/packet.benc | 8 - examples/chat/packets/message/packet.benc.go | 104 --------- examples/chat/server/main.go | 167 --------------- examples/chat/server/server.kc | Bin 365 -> 0 bytes gen/main.go | 4 +- go.mod | 5 +- go.sum | 4 + handshake/packet.go | 135 ++++++++++++ handshake/utils.go | 35 ++++ netutils/conn.go | 53 ++++- packets/handshake.benc.go | 124 ----------- packets/handshake.go | 41 ---- schemas/Certificate.benc | 13 -- schemas/Handshake.benc | 9 - schemas/RootKey.benc | 12 -- schemas/client_root_key.benc | 13 ++ schemas/gen.bat | 3 + schemas/handshake_packet.benc | 18 ++ schemas/server_certificate.benc | 16 ++ server.go | 120 +++++------ 37 files changed, 776 insertions(+), 1294 deletions(-) delete mode 100644 cert/cert.benc.go delete mode 100644 cert/root.benc.go create mode 100644 cert/rootkey.go create mode 100644 cert/servercert.go delete mode 100644 examples/chat/client/client.kr delete mode 100644 examples/chat/client/main.go delete mode 100644 examples/chat/packets/common.go delete mode 100644 examples/chat/packets/gen.bat delete mode 100644 examples/chat/packets/initialization/common.go delete mode 100644 examples/chat/packets/initialization/packet.benc delete mode 100644 examples/chat/packets/initialization/packet.benc.go delete mode 100644 examples/chat/packets/message/packet.benc delete mode 100644 examples/chat/packets/message/packet.benc.go delete mode 100644 examples/chat/server/main.go delete mode 100644 examples/chat/server/server.kc create mode 100644 handshake/packet.go create mode 100644 handshake/utils.go delete mode 100644 packets/handshake.benc.go delete mode 100644 packets/handshake.go delete mode 100644 schemas/Certificate.benc delete mode 100644 schemas/Handshake.benc delete mode 100644 schemas/RootKey.benc create mode 100644 schemas/client_root_key.benc create mode 100644 schemas/gen.bat create mode 100644 schemas/handshake_packet.benc create mode 100644 schemas/server_certificate.benc diff --git a/LICENSE b/LICENSE index 672a19d..46a15a0 100644 --- a/LICENSE +++ b/LICENSE @@ -1,4 +1,4 @@ -kinetra.de/net License +github.com/deneonet/knet License Version: 1.0.0 Copyright (c) 2024 deneonet diff --git a/README.md b/README.md index 3fbbdde..bfb97df 100644 --- a/README.md +++ b/README.md @@ -1,33 +1,33 @@ -# kinetra.de/net +# github.com/deneonet/knet A library to handle secure TCP connections over a custom protocol. -**Warning:** This library is not complete and contains unfinished code. +**Warning:** This library is not complete and may contain unfinished code. ## Security -**kinetra.de/net** uses AES-256 for data encryption, ECDH P521 for key exchange, and ECDSA P521 for signing. +**github.com/deneonet/knet** uses AES-256 for data encryption, ECDH P521 for key exchange, and ECDSA P521 for signing. ## Data Encoding -To efficiently encode data, **kinetra.de/net** utilizes [benc](https://github.com/deneonet/benc) as its serializer. +To efficiently encode data, **github.com/deneonet/knet** utilizes [benc](https://github.com/deneonet/benc) as its serializer. ## Key Rotations -If the private key of the server's certificate is compromised or just expired, simply generate a new one and update the client's root key as well. **cosair.gg** encodes a version field into the certificate and root key to verify that the client is always in sync with the server. If the versions do not match, a clear error will be returned. +If the private key of the server's certificate is compromised or just expired, simply generate a new one and update the client's root key as well. **kNet** encodes a version field into the certificate and root key to verify that the client is always in sync with the server. If the versions do not match, a clear error will be returned. ## Generating Certificates -As simple as `go run kinetra.de/net/gen -v {VERSION_NUMBER}`, everything is done locally on your machine. To ensure that the root key is in sync with the certificate, it will be generated as well. +As simple as `go run github.com/deneonet/knet/gen -v {VERSION_NUMBER}`, everything is done locally on your machine. To ensure that the root key is in sync with the certificate, it will be generated as well. ## The Handshake Process -1. **[Client]**: I want your certificate to prove your identity as **[Server]**. -2. **[Server]**: Sure, here’s my certificate. -3. **[Client]**: I'll check the signature using my root key, verifying that the public key was not compromised, is not expired, and matches the expected version. -4. **[Client]**: I verified it; it's valid. Here’s my public key. I'll create a shared secret using your public key. -5. **[Server]**: I have the shared secret now too. Let me verify that we have the same and that your public key was not compromised as well. I’ll send you an encrypted message using the shared secret. +1. **[Client]**: I want your certificate to prove your identity as **[Server]**. +2. **[Server]**: Sure, here’s my certificate. +3. **[Client]**: I'll check the signature using my root key, verifying that the public key was not compromised, is not expired, and matches the expected version. +4. **[Client]**: I verified it; it's valid. Here’s my public key. I'll create a shared secret using your public key. +5. **[Server]**: I have the shared secret now too. Let me verify that we have the same and that your public key was not compromised as well. I’ll send you an encrypted message using the shared secret. 6. **[Client]**: I successfully decrypted the message. Our connection is now secure! -## Examples +## Real-time chat Example -Find examples [here](https://github.com/deneonet/cosair.gg-net-examples). +Find a real-time chat example [here](https://github.com/deneonet/knet-real-time-chat). diff --git a/cert/cert.benc.go b/cert/cert.benc.go deleted file mode 100644 index 8ed247c..0000000 --- a/cert/cert.benc.go +++ /dev/null @@ -1,184 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: ../schemas/Certificate.benc - -package cert - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - Certificate -type Certificate struct { - PrivateKey []byte - PublicKey []byte - PublicKeySignature []byte - Version int - CreatedAt int64 -} - -// Reserved Ids - Certificate -var certificateRIds = []uint16{} - -// Size - Certificate -func (certificate *Certificate) Size() int { - return certificate.size(0) -} - -// Nested Size - Certificate -func (certificate *Certificate) size(id uint16) (s int) { - s += bstd.SizeBytes(certificate.PrivateKey) + 2 - s += bstd.SizeBytes(certificate.PublicKey) + 2 - s += bstd.SizeBytes(certificate.PublicKeySignature) + 2 - s += bstd.SizeInt(certificate.Version) + 2 - s += bstd.SizeInt64() + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - Certificate -func (certificate *Certificate) SizePlain() (s int) { - s += bstd.SizeBytes(certificate.PrivateKey) - s += bstd.SizeBytes(certificate.PublicKey) - s += bstd.SizeBytes(certificate.PublicKeySignature) - s += bstd.SizeInt(certificate.Version) - s += bstd.SizeInt64() - return -} - -// Marshal - Certificate -func (certificate *Certificate) Marshal(b []byte) { - certificate.marshal(0, b, 0) -} - -// Nested Marshal - Certificate -func (certificate *Certificate) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) - n = bstd.MarshalBytes(n, b, certificate.PrivateKey) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 2) - n = bstd.MarshalBytes(n, b, certificate.PublicKey) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 3) - n = bstd.MarshalBytes(n, b, certificate.PublicKeySignature) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 4) - n = bstd.MarshalInt(n, b, certificate.Version) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 5) - n = bstd.MarshalInt64(n, b, certificate.CreatedAt) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - Certificate -func (certificate *Certificate) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalBytes(n, b, certificate.PrivateKey) - n = bstd.MarshalBytes(n, b, certificate.PublicKey) - n = bstd.MarshalBytes(n, b, certificate.PublicKeySignature) - n = bstd.MarshalInt(n, b, certificate.Version) - n = bstd.MarshalInt64(n, b, certificate.CreatedAt) - return n -} - -// Unmarshal - Certificate -func (certificate *Certificate) Unmarshal(b []byte) (err error) { - _, err = certificate.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - Certificate -func (certificate *Certificate) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 2); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 3); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 4); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, certificateRIds, 5); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, certificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - Certificate -func (certificate *Certificate) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, certificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, certificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, certificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, certificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - if n, certificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - return -} - diff --git a/cert/root.benc.go b/cert/root.benc.go deleted file mode 100644 index 4d78b9b..0000000 --- a/cert/root.benc.go +++ /dev/null @@ -1,144 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: ../schemas/RootKey.benc - -package cert - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - RootKey -type RootKey struct { - PublicKey []byte - Version int - CreatedAt int64 -} - -// Reserved Ids - RootKey -var rootKeyRIds = []uint16{} - -// Size - RootKey -func (rootKey *RootKey) Size() int { - return rootKey.size(0) -} - -// Nested Size - RootKey -func (rootKey *RootKey) size(id uint16) (s int) { - s += bstd.SizeBytes(rootKey.PublicKey) + 2 - s += bstd.SizeInt(rootKey.Version) + 2 - s += bstd.SizeInt64() + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - RootKey -func (rootKey *RootKey) SizePlain() (s int) { - s += bstd.SizeBytes(rootKey.PublicKey) - s += bstd.SizeInt(rootKey.Version) - s += bstd.SizeInt64() - return -} - -// Marshal - RootKey -func (rootKey *RootKey) Marshal(b []byte) { - rootKey.marshal(0, b, 0) -} - -// Nested Marshal - RootKey -func (rootKey *RootKey) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) - n = bstd.MarshalBytes(n, b, rootKey.PublicKey) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 2) - n = bstd.MarshalInt(n, b, rootKey.Version) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 3) - n = bstd.MarshalInt64(n, b, rootKey.CreatedAt) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - RootKey -func (rootKey *RootKey) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalBytes(n, b, rootKey.PublicKey) - n = bstd.MarshalInt(n, b, rootKey.Version) - n = bstd.MarshalInt64(n, b, rootKey.CreatedAt) - return n -} - -// Unmarshal - RootKey -func (rootKey *RootKey) Unmarshal(b []byte) (err error) { - _, err = rootKey.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - RootKey -func (rootKey *RootKey) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, rootKeyRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, rootKey.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, rootKeyRIds, 2); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, rootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, rootKeyRIds, 3); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, rootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - RootKey -func (rootKey *RootKey) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, rootKey.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - if n, rootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { - return - } - if n, rootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { - return - } - return -} - diff --git a/cert/rootkey.go b/cert/rootkey.go new file mode 100644 index 0000000..4c4ef04 --- /dev/null +++ b/cert/rootkey.go @@ -0,0 +1,146 @@ +// Code generated by bencgen go. DO NOT EDIT. +// source: client_root_key.benc + +package cert + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" + + +) + +// Struct - ClientRootKey +type ClientRootKey struct { + Key []byte + Version int + CreatedAt int64 +} + +// Reserved Ids - ClientRootKey +var clientRootKeyRIds = []uint16{} + +// Size - ClientRootKey +func (clientRootKey *ClientRootKey) Size() int { + return clientRootKey.NestedSize(0) +} + +// Nested Size - ClientRootKey +func (clientRootKey *ClientRootKey) NestedSize(id uint16) (s int) { + s += bstd.SizeBytes(clientRootKey.Key) + 2 + s += bstd.SizeInt(clientRootKey.Version) + 2 + s += bstd.SizeInt64() + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - ClientRootKey +func (clientRootKey *ClientRootKey) SizePlain() (s int) { + s += bstd.SizeBytes(clientRootKey.Key) + s += bstd.SizeInt(clientRootKey.Version) + s += bstd.SizeInt64() + return +} + +// Marshal - ClientRootKey +func (clientRootKey *ClientRootKey) Marshal(b []byte) { + clientRootKey.NestedMarshal(0, b, 0) +} + +// Nested Marshal - ClientRootKey +func (clientRootKey *ClientRootKey) NestedMarshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalBytes(n, b, clientRootKey.Key) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 2) + n = bstd.MarshalInt(n, b, clientRootKey.Version) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 3) + n = bstd.MarshalInt64(n, b, clientRootKey.CreatedAt) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - ClientRootKey +func (clientRootKey *ClientRootKey) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalBytes(n, b, clientRootKey.Key) + n = bstd.MarshalInt(n, b, clientRootKey.Version) + n = bstd.MarshalInt64(n, b, clientRootKey.CreatedAt) + return n +} + +// Unmarshal - ClientRootKey +func (clientRootKey *ClientRootKey) Unmarshal(b []byte) (err error) { + _, err = clientRootKey.NestedUnmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - ClientRootKey +func (clientRootKey *ClientRootKey) NestedUnmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, clientRootKeyRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, clientRootKey.Key, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, clientRootKeyRIds, 2); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, clientRootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, clientRootKeyRIds, 3); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, clientRootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - ClientRootKey +func (clientRootKey *ClientRootKey) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, clientRootKey.Key, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, clientRootKey.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + if n, clientRootKey.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + return +} + diff --git a/cert/servercert.go b/cert/servercert.go new file mode 100644 index 0000000..f0bb109 --- /dev/null +++ b/cert/servercert.go @@ -0,0 +1,186 @@ +// Code generated by bencgen go. DO NOT EDIT. +// source: server_certificate.benc + +package cert + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" + + +) + +// Struct - ServerCertificate +type ServerCertificate struct { + PublicKey []byte + PrivateKey []byte + PublicKeySignature []byte + Version int + CreatedAt int64 +} + +// Reserved Ids - ServerCertificate +var serverCertificateRIds = []uint16{} + +// Size - ServerCertificate +func (serverCertificate *ServerCertificate) Size() int { + return serverCertificate.NestedSize(0) +} + +// Nested Size - ServerCertificate +func (serverCertificate *ServerCertificate) NestedSize(id uint16) (s int) { + s += bstd.SizeBytes(serverCertificate.PublicKey) + 2 + s += bstd.SizeBytes(serverCertificate.PrivateKey) + 2 + s += bstd.SizeBytes(serverCertificate.PublicKeySignature) + 2 + s += bstd.SizeInt(serverCertificate.Version) + 2 + s += bstd.SizeInt64() + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - ServerCertificate +func (serverCertificate *ServerCertificate) SizePlain() (s int) { + s += bstd.SizeBytes(serverCertificate.PublicKey) + s += bstd.SizeBytes(serverCertificate.PrivateKey) + s += bstd.SizeBytes(serverCertificate.PublicKeySignature) + s += bstd.SizeInt(serverCertificate.Version) + s += bstd.SizeInt64() + return +} + +// Marshal - ServerCertificate +func (serverCertificate *ServerCertificate) Marshal(b []byte) { + serverCertificate.NestedMarshal(0, b, 0) +} + +// Nested Marshal - ServerCertificate +func (serverCertificate *ServerCertificate) NestedMarshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKey) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 2) + n = bstd.MarshalBytes(n, b, serverCertificate.PrivateKey) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 3) + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKeySignature) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Varint, 4) + n = bstd.MarshalInt(n, b, serverCertificate.Version) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed64, 5) + n = bstd.MarshalInt64(n, b, serverCertificate.CreatedAt) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - ServerCertificate +func (serverCertificate *ServerCertificate) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKey) + n = bstd.MarshalBytes(n, b, serverCertificate.PrivateKey) + n = bstd.MarshalBytes(n, b, serverCertificate.PublicKeySignature) + n = bstd.MarshalInt(n, b, serverCertificate.Version) + n = bstd.MarshalInt64(n, b, serverCertificate.CreatedAt) + return n +} + +// Unmarshal - ServerCertificate +func (serverCertificate *ServerCertificate) Unmarshal(b []byte) (err error) { + _, err = serverCertificate.NestedUnmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - ServerCertificate +func (serverCertificate *ServerCertificate) NestedUnmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 2); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 3); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 4); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, serverCertificateRIds, 5); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, serverCertificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - ServerCertificate +func (serverCertificate *ServerCertificate) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, serverCertificate.PublicKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, serverCertificate.PrivateKey, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, serverCertificate.PublicKeySignature, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, serverCertificate.Version, err = bstd.UnmarshalInt(n, b); err != nil { + return + } + if n, serverCertificate.CreatedAt, err = bstd.UnmarshalInt64(n, b); err != nil { + return + } + return +} + diff --git a/cert/utils.go b/cert/utils.go index 8423d97..39a33ba 100644 --- a/cert/utils.go +++ b/cert/utils.go @@ -1,5 +1,3 @@ -//go:generate bencgen --in ../schemas/RootKey.benc --out . --file root.benc --lang go -//go:generate bencgen --in ../schemas/Certificate.benc --out . --file cert.benc --lang go package cert import ( @@ -13,35 +11,46 @@ import ( ) var ( - ErrFailedVerification = errors.New("public key of certificate couldn't be verified") - ErrCertificateExpired = errors.New("certificate has expired") - ErrRootKeyExpired = errors.New("root key has expired") + ErrClientRootKeyExpired = errors.New("root key has expired") + ErrServerCertificateExpired = errors.New("certificate has expired") + ErrVersionMismatch = errors.New("certificate and root key are not in sync") + ErrFailedVerification = errors.New("public key of certificate couldn't be verified") - ErrVersionMismatch = errors.New("certificate and root key are not in sync") + ErrCertificateSigningFailed = errors.New("failed to sign certificate") - oneYear = 365 * 24 * time.Hour // one year in hours + ErrECDSARootKeyGenerationFailed = errors.New("failed to generate ECDSA root key") + ErrECDHPrivateKeyGenerationFailed = errors.New("failed to generate ECDH private key") + ErrServerPublicKeyCreationFailed = errors.New("failed to create server's public key") + + // TODO: Make expiration time configurable + DefaultCertificateExpiry = 365 * 24 * time.Hour // Default one year in hours ) -func VerifyCertificate(cert Certificate, root RootKey) (*ecdh.PublicKey, error) { +// VerifyCertificate verifies the certificate with the provided root key, returning the public key. +func VerifyCertificate(cert ServerCertificate, root ClientRootKey, expiry time.Duration) (*ecdh.PublicKey, error) { if cert.Version != root.Version { return nil, ErrVersionMismatch } - if time.Now().Unix()-cert.CreatedAt > int64(oneYear.Seconds()) { - return nil, ErrCertificateExpired + if expiry == 0 { + expiry = DefaultCertificateExpiry + } + + if time.Now().Unix()-cert.CreatedAt > int64(expiry.Seconds()) { + return nil, ErrServerCertificateExpired } - if time.Now().Unix()-root.CreatedAt > int64(oneYear.Seconds()) { - return nil, ErrRootKeyExpired + if time.Now().Unix()-root.CreatedAt > int64(expiry.Seconds()) { + return nil, ErrClientRootKeyExpired } publicKey, err := ecdh.P521().NewPublicKey(cert.PublicKey) if err != nil { - return nil, err + return nil, ErrServerPublicKeyCreationFailed } rootKey := new(ecdsa.PublicKey) rootKey.Curve = elliptic.P521() - rootKey.X, rootKey.Y = elliptic.UnmarshalCompressed(elliptic.P521(), root.PublicKey) + rootKey.X, rootKey.Y = elliptic.UnmarshalCompressed(elliptic.P521(), root.Key) if !ecdsa.VerifyASN1(rootKey, publicKey.Bytes(), cert.PublicKeySignature) { return nil, ErrFailedVerification @@ -50,35 +59,34 @@ func VerifyCertificate(cert Certificate, root RootKey) (*ecdh.PublicKey, error) return publicKey, nil } +// GenerateCertificateChain generates a certificate chain and stores it in the specified file paths. func GenerateCertificateChain(version int, certFilePath string, rootKeyFilePath string) error { privateKey, err := ecdh.P521().GenerateKey(rand.Reader) if err != nil { - return err + return ErrECDHPrivateKeyGenerationFailed } rootKey, err := ecdsa.GenerateKey(elliptic.P521(), rand.Reader) if err != nil { - return err + return ErrECDSARootKeyGenerationFailed } publicKey := privateKey.PublicKey().Bytes() signature, err := ecdsa.SignASN1(rand.Reader, rootKey, publicKey) if err != nil { - return err + return ErrCertificateSigningFailed } - cert := Certificate{ - PrivateKey: privateKey.Bytes(), + cert := ServerCertificate{ PublicKey: publicKey, PublicKeySignature: signature, - - Version: version, - CreatedAt: time.Now().Unix(), + PrivateKey: privateKey.Bytes(), + Version: version, + CreatedAt: time.Now().Unix(), } - root := RootKey{ - PublicKey: elliptic.MarshalCompressed(elliptic.P521(), rootKey.PublicKey.X, rootKey.PublicKey.Y), - + root := ClientRootKey{ + Key: elliptic.MarshalCompressed(elliptic.P521(), rootKey.PublicKey.X, rootKey.PublicKey.Y), Version: version, CreatedAt: time.Now().Unix(), } @@ -86,26 +94,32 @@ func GenerateCertificateChain(version int, certFilePath string, rootKeyFilePath certData := make([]byte, cert.Size()) cert.Marshal(certData) - certFile, err := os.Create(certFilePath) - if err != nil { + if err := writeToFile(certFilePath, certData); err != nil { return err } - defer certFile.Close() - if _, err = certFile.Write(certData); err != nil { + rootKeyData := make([]byte, root.Size()) + root.Marshal(rootKeyData) + + if err := writeToFile(rootKeyFilePath, rootKeyData); err != nil { return err } - rootKeyData := make([]byte, root.Size()) - root.Marshal(rootKeyData) + return nil +} - rootKeyFile, err := os.Create(rootKeyFilePath) +func writeToFile(filePath string, data []byte) error { + file, err := os.Create(filePath) if err != nil { return err } - defer rootKeyFile.Close() + defer file.Close() + + if _, err := file.Write(data); err != nil { + return err + } - if _, err = rootKeyFile.Write(rootKeyData); err != nil { + if err := os.Chmod(filePath, 0600); err != nil { return err } diff --git a/client.go b/client.go index 02070ad..730066d 100644 --- a/client.go +++ b/client.go @@ -10,10 +10,10 @@ import ( "sync" "time" - "kinetra.de/net/cert" - "kinetra.de/net/crypto" - "kinetra.de/net/netutils" - "kinetra.de/net/packets" + "github.com/deneonet/knet/cert" + "github.com/deneonet/knet/crypto" + "github.com/deneonet/knet/handshake" + "github.com/deneonet/knet/netutils" bstd "github.com/deneonet/benc/std" ) @@ -22,11 +22,13 @@ type Client struct { priv *ecdh.PrivateKey session ClientSession - rootKey cert.RootKey + rootKey cert.ClientRootKey RootKeyFile string mutex *sync.Mutex + netUtilsSettings netutils.NetUtilsSettings + ReadDeadline time.Duration WriteDeadline time.Duration @@ -53,43 +55,26 @@ const ( ClientHandshakeError ) -func (c *Client) setDeadline(conn net.Conn, handshakeComplete bool) { - conn.SetDeadline(time.Time{}) - - var readDeadline, writeDeadline time.Duration - if handshakeComplete { - readDeadline, writeDeadline = c.ReadDeadline, c.WriteDeadline - } else { - readDeadline, writeDeadline = c.HandshakeReadDeadline, c.HandshakeWriteDeadline - } - - if readDeadline > 0 { - conn.SetReadDeadline(time.Now().Add(readDeadline)) - } - if writeDeadline > 0 { - conn.SetWriteDeadline(time.Now().Add(writeDeadline)) - } -} - func (c *Client) processHandshake(conn net.Conn, buf []byte) (ClientHandshakeResult, error) { - packet, err := packets.UnmarshalHandshakePacket(buf) + packet, err := handshake.UnmarshalHandshakePacket(buf) if err != nil { return ClientHandshakeError, err } - switch packets.HandshakePacketId(packet.Id) { - case packets.CertificateResponse: - var certificate cert.Certificate + switch packet.Type { + case handshake.PacketTypeCertificateResponse: + var certificate cert.ServerCertificate if err := certificate.Unmarshal(packet.Payload); err != nil { return ClientHandshakeError, err } - serverPublicKey, err := cert.VerifyCertificate(certificate, c.rootKey) + // TODO: Custom expiry + serverPublicKey, err := cert.VerifyCertificate(certificate, c.rootKey, 0) if err != nil { return ClientHandshakeError, err } - if err = packets.SendHandshakePacket(conn, packets.ClientInformation, c.priv.PublicKey().Bytes()); err != nil { + if err = handshake.SendHandshakePacket(conn, handshake.PacketTypeClientInformation, c.priv.PublicKey().Bytes(), &c.netUtilsSettings); err != nil { return ClientHandshakeError, err } @@ -99,15 +84,17 @@ func (c *Client) processHandshake(conn net.Conn, buf []byte) (ClientHandshakeRes } aesSecret := sha256.Sum256(sharedSecret) + c.mutex.Lock() c.session = ClientSession{ Conn: conn, SharedSecret: aesSecret[:], } + c.mutex.Unlock() return ClientContinueHandshake, nil - case packets.ServerVerification: + case handshake.PacketTypeServerVerification: if _, err := crypto.Decrypt(c.session.SharedSecret, packet.Payload); err != nil { - return ClientHandshakeError, ErrDecryptingVerificationId + return ClientHandshakeError, ErrDecryptingServerVerification } return ClientHandshakeComplete, nil @@ -117,11 +104,11 @@ func (c *Client) processHandshake(conn net.Conn, buf []byte) (ClientHandshakeRes } func (c *Client) Connect(address string) (net.Conn, error) { - if c.HandshakeReadDeadline == 0 { - c.HandshakeReadDeadline = 500 * time.Millisecond - } - if c.HandshakeWriteDeadline == 0 { - c.HandshakeWriteDeadline = 500 * time.Millisecond + c.netUtilsSettings = netutils.NetUtilsSettings{ + HandshakeReadDeadline: c.HandshakeReadDeadline, + HandshakeWriteDeadline: c.HandshakeWriteDeadline, + ReadDeadline: c.ReadDeadline, + WriteDeadline: c.WriteDeadline, } c.mutex = &sync.Mutex{} @@ -143,11 +130,9 @@ func (c *Client) Connect(address string) (net.Conn, error) { return nil, err } - defer func() { - conn.Close() - }() + defer conn.Close() - if err = packets.SendHandshakePacket(conn, packets.CertificateRequest, nil); err != nil { + if err = handshake.SendHandshakePacket(conn, handshake.PacketTypeCertificateRequest, nil, &c.netUtilsSettings); err != nil { return nil, err } @@ -155,8 +140,7 @@ func (c *Client) Connect(address string) (net.Conn, error) { buf := make([]byte, c.BufferSize) for { - c.setDeadline(conn, handshakeComplete) - s, err := netutils.ReadFromConn(conn, buf) + s, err := netutils.ReadFromConn(conn, buf, &c.netUtilsSettings) if err != nil && !handshakeComplete { return nil, err @@ -166,6 +150,7 @@ func (c *Client) Connect(address string) (net.Conn, error) { if !handshakeComplete { result, err = c.processHandshake(conn, buf[4:s]) if err != nil { + conn.Close() return nil, err } @@ -177,7 +162,8 @@ func (c *Client) Connect(address string) (net.Conn, error) { handshakeComplete = true s = 0 - c.setDeadline(conn, handshakeComplete) + c.netUtilsSettings.HandshakeCompleted = true + if c.OnSecureConnect != nil { if action := c.OnSecureConnect(c.session); action == Close { break @@ -216,7 +202,7 @@ func (c *Client) Connect(address string) (net.Conn, error) { } } - return nil, err + return conn, nil } func (c *Client) Send(b []byte) error { @@ -224,11 +210,17 @@ func (c *Client) Send(b []byte) error { if err != nil { return err } - return netutils.SendToConn(c.session.Conn, len(encrypted)+1, func(n int, b []byte) { b[n] = 1; copy(b[n+1:], encrypted) }) + return netutils.SendToConn(c.session.Conn, len(encrypted)+1, func(n int, b []byte) { + b[n] = 1 + copy(b[n+1:], encrypted) + }, &c.netUtilsSettings) } func (c *Client) SendUnsecure(b []byte) error { - return netutils.SendToConn(c.session.Conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }) + return netutils.SendToConn(c.session.Conn, len(b)+1, func(n int, buf []byte) { + buf[n] = 0 + copy(buf[n+1:], b) + }, &c.netUtilsSettings) } func (c *Client) SendPacket(id int, p Packet) error { diff --git a/common.go b/common.go index 1161400..e23561c 100644 --- a/common.go +++ b/common.go @@ -23,10 +23,9 @@ const ( ) var ( - ErrSessionNotFound = errors.New("no session with that remote address found") - ErrDecryptingVerificationId = errors.New("error decrypting verification id") - ErrInvalidHandshakePacket = errors.New("invalid handshake packet received") - ErrBufTooSmall = errors.New("buffer is too small for the requested size") - ErrInvalidRootKey = errors.New("invalid root key in client struct") - ErrDataExceededBufferSize = errors.New("received data size exceeded buffer size") + ErrSessionNotFound = errors.New("no session associated with that connection found") + ErrInvalidRootKey = errors.New("invalid root key in client struct") + ErrInvalidHandshakePacket = errors.New("invalid handshake packet received") + ErrDecryptingServerVerification = errors.New("error decrypting server verification") + ErrDataExceededBufferSize = errors.New("received data size exceeded buffer size") ) diff --git a/crypto/encryption.go b/crypto/encryption.go index 2cfa0d5..d2fa7c9 100644 --- a/crypto/encryption.go +++ b/crypto/encryption.go @@ -12,6 +12,8 @@ var ( ErrCipherTextTooShort = errors.New("cipher text length is smaller than nonce length") ) +// Encrypt encrypts the provided data using the given AES key and returns the ciphertext +// which includes the nonce as the first part of the result. func Encrypt(key, data []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { @@ -32,6 +34,8 @@ func Encrypt(key, data []byte) ([]byte, error) { return cipherText, nil } +// Decrypt decrypts the provided ciphertext using the given AES key and returns the decrypted data. +// The ciphertext must include the nonce at the start. func Decrypt(key, cipherData []byte) ([]byte, error) { block, err := aes.NewCipher(key) if err != nil { @@ -49,6 +53,5 @@ func Decrypt(key, cipherData []byte) ([]byte, error) { } nonce, cipherText := cipherData[:nonceSize], cipherData[nonceSize:] - plainData, err := aesGCM.Open(nil, nonce, cipherText, nil) - return plainData, err + return aesGCM.Open(nil, nonce, cipherText, nil) } diff --git a/examples/chat/client/client.kr b/examples/chat/client/client.kr deleted file mode 100644 index 1a730e2dde50369d168307de275009dcf9c8ec51..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 87 zcmV-d0I2^000RL-0sv@U9NQD@+jM#G!Ld|A?domk(`h-QUH~M}DQTRxidONr_o9~@ ta|)n}bQir85k4^F^N#uLWew`K%_rEn>A0u_0s;sFblED;000000Rgv-CK~_% diff --git a/examples/chat/client/main.go b/examples/chat/client/main.go deleted file mode 100644 index 2f6a88f..0000000 --- a/examples/chat/client/main.go +++ /dev/null @@ -1,79 +0,0 @@ -package main - -import ( - "errors" - "fmt" - "net" - "time" - - knet "kinetra.de/net" - "kinetra.de/net/examples/chat/packets" - "kinetra.de/net/examples/chat/packets/initialization" -) - -var ErrResponse = errors.New("response error") - -func main() { - client := &knet.Client{ - RootKeyFile: "client.kr", - BufferSize: 1024, - WriteDeadline: 10 * time.Second, - HandshakeReadDeadline: 500 * time.Millisecond, - HandshakeWriteDeadline: 500 * time.Millisecond, - } - - client.OnRead = func(conn net.Conn, info knet.ReadInfo) knet.AfterAction { - err := client.UnmarshalPacket(info.Data, func(id int, b []byte) (err error) { - switch id { - case packets.InitializationResponsePacket: - var response initialization.Response - if err = response.Unmarshal(b); err != nil { - return - } - - switch response.Data { - case initialization.ResponseUsernameTaken: - fmt.Println("Username is taken.") - case initialization.ResponseUsernameMissing: - fmt.Println("Username is missing.") - case initialization.ResponseSuccess: - return nil - } - - return ErrResponse - } - - return nil - }) - - if err != nil { - return knet.Close - } - - return knet.None - } - - client.OnSecureConnect = func(session knet.ClientSession) knet.AfterAction { - fmt.Println("Secure connection established.") - - username := initialization.Username{ - Data: "deneonet", - } - err := client.SendPacket(packets.InitializationUsernamePacket, &username) - if err != nil { - fmt.Println("Error sending username packet: ", err) - } - - return knet.None - } - - client.OnDisconnect = func() { - fmt.Println("Disconnected from server.") - } - - _, err := client.Connect("localhost:8080") - if err != nil { - fmt.Println("Error connecting to server:", err) - return - } -} diff --git a/examples/chat/packets/common.go b/examples/chat/packets/common.go deleted file mode 100644 index 23e9a39..0000000 --- a/examples/chat/packets/common.go +++ /dev/null @@ -1,16 +0,0 @@ -package packets - -const ( - InitializationUsernamePacket int = iota - InitializationResponsePacket - - MessagePacket - MessageResponsePacket -) - -type Message struct { - Username string - Message string - - HasDisconnected bool // After someone disconnected, everyone can reclaim that username, so in order to extinguish the "new user" from the "old users", a "disconnected" mark is appended to the message -} diff --git a/examples/chat/packets/gen.bat b/examples/chat/packets/gen.bat deleted file mode 100644 index fc10ad6..0000000 --- a/examples/chat/packets/gen.bat +++ /dev/null @@ -1,2 +0,0 @@ -bencgen --in .\message\packet.benc --out ./ --lang go -bencgen --in .\initialization\packet.benc --out ./ --lang go \ No newline at end of file diff --git a/examples/chat/packets/initialization/common.go b/examples/chat/packets/initialization/common.go deleted file mode 100644 index f8fd942..0000000 --- a/examples/chat/packets/initialization/common.go +++ /dev/null @@ -1,7 +0,0 @@ -package initialization - -const ( - ResponseSuccess byte = iota - ResponseUsernameTaken - ResponseUsernameMissing -) diff --git a/examples/chat/packets/initialization/packet.benc b/examples/chat/packets/initialization/packet.benc deleted file mode 100644 index a863578..0000000 --- a/examples/chat/packets/initialization/packet.benc +++ /dev/null @@ -1,12 +0,0 @@ -header initialization; - -ctr Username { - string data = 1; -} - -ctr Response { - byte data = 1; -} - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IlJlc3BvbnNlIjp7InJJZHMiOm51bGwsImZpZWxkcyI6eyIxIjp7IklkIjoxLCJOYW1lIjoiZGF0YSIsIlR5cGUiOnsiVG9rZW5UeXBlIjoxOSwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX19fSwiVXNlcm5hbWUiOnsicklkcyI6bnVsbCwiZmllbGRzIjp7IjEiOnsiSWQiOjEsIk5hbWUiOiJkYXRhIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE1LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fX19fX0= [meta_e] \ No newline at end of file diff --git a/examples/chat/packets/initialization/packet.benc.go b/examples/chat/packets/initialization/packet.benc.go deleted file mode 100644 index 59b11dc..0000000 --- a/examples/chat/packets/initialization/packet.benc.go +++ /dev/null @@ -1,198 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: .\initialization\packet.benc - -package initialization - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - Username -type Username struct { - Data string -} - -// Reserved Ids - Username -var usernameRIds = []uint16{} - -// Size - Username -func (username *Username) Size() int { - return username.size(0) -} - -// Nested Size - Username -func (username *Username) size(id uint16) (s int) { - s += bstd.SizeString(username.Data) + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - Username -func (username *Username) SizePlain() (s int) { - s += bstd.SizeString(username.Data) - return -} - -// Marshal - Username -func (username *Username) Marshal(b []byte) { - username.marshal(0, b, 0) -} - -// Nested Marshal - Username -func (username *Username) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) - n = bstd.MarshalString(n, b, username.Data) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - Username -func (username *Username) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalString(n, b, username.Data) - return n -} - -// Unmarshal - Username -func (username *Username) Unmarshal(b []byte) (err error) { - _, err = username.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - Username -func (username *Username) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, usernameRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, username.Data, err = bstd.UnmarshalString(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - Username -func (username *Username) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, username.Data, err = bstd.UnmarshalString(n, b); err != nil { - return - } - return -} - -// Struct - Response -type Response struct { - Data byte -} - -// Reserved Ids - Response -var responseRIds = []uint16{} - -// Size - Response -func (response *Response) Size() int { - return response.size(0) -} - -// Nested Size - Response -func (response *Response) size(id uint16) (s int) { - s += bstd.SizeByte() + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - Response -func (response *Response) SizePlain() (s int) { - s += bstd.SizeByte() - return -} - -// Marshal - Response -func (response *Response) Marshal(b []byte) { - response.marshal(0, b, 0) -} - -// Nested Marshal - Response -func (response *Response) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed8, 1) - n = bstd.MarshalByte(n, b, response.Data) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - Response -func (response *Response) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalByte(n, b, response.Data) - return n -} - -// Unmarshal - Response -func (response *Response) Unmarshal(b []byte) (err error) { - _, err = response.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - Response -func (response *Response) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, responseRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, response.Data, err = bstd.UnmarshalByte(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - Response -func (response *Response) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, response.Data, err = bstd.UnmarshalByte(n, b); err != nil { - return - } - return -} - diff --git a/examples/chat/packets/message/packet.benc b/examples/chat/packets/message/packet.benc deleted file mode 100644 index dcd36cb..0000000 --- a/examples/chat/packets/message/packet.benc +++ /dev/null @@ -1,8 +0,0 @@ -header message; - -ctr Packet { - string data = 1; -} - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IlBhY2tldCI6eyJySWRzIjpudWxsLCJmaWVsZHMiOnsiMSI6eyJJZCI6MSwiTmFtZSI6ImRhdGEiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTUsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19fX19fQ== [meta_e] \ No newline at end of file diff --git a/examples/chat/packets/message/packet.benc.go b/examples/chat/packets/message/packet.benc.go deleted file mode 100644 index a049201..0000000 --- a/examples/chat/packets/message/packet.benc.go +++ /dev/null @@ -1,104 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: .\message\packet.benc - -package message - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - Packet -type Packet struct { - Data string -} - -// Reserved Ids - Packet -var packetRIds = []uint16{} - -// Size - Packet -func (packet *Packet) Size() int { - return packet.size(0) -} - -// Nested Size - Packet -func (packet *Packet) size(id uint16) (s int) { - s += bstd.SizeString(packet.Data) + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - Packet -func (packet *Packet) SizePlain() (s int) { - s += bstd.SizeString(packet.Data) - return -} - -// Marshal - Packet -func (packet *Packet) Marshal(b []byte) { - packet.marshal(0, b, 0) -} - -// Nested Marshal - Packet -func (packet *Packet) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) - n = bstd.MarshalString(n, b, packet.Data) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - Packet -func (packet *Packet) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalString(n, b, packet.Data) - return n -} - -// Unmarshal - Packet -func (packet *Packet) Unmarshal(b []byte) (err error) { - _, err = packet.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - Packet -func (packet *Packet) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, packetRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, packet.Data, err = bstd.UnmarshalString(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - Packet -func (packet *Packet) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, packet.Data, err = bstd.UnmarshalString(n, b); err != nil { - return - } - return -} - diff --git a/examples/chat/server/main.go b/examples/chat/server/main.go deleted file mode 100644 index 50377ad..0000000 --- a/examples/chat/server/main.go +++ /dev/null @@ -1,167 +0,0 @@ -package main - -import ( - "fmt" - "net" - "sync" - "time" - - "github.com/google/uuid" - knet "kinetra.de/net" - "kinetra.de/net/examples/chat/packets" - "kinetra.de/net/examples/chat/packets/initialization" - "kinetra.de/net/examples/chat/packets/message" -) - -func main() { - server := &knet.Server{ - Addr: "localhost:8080", - CertFile: "server.kc", - EnableConnPurge: true, - ConnPurgeInterval: 10 * time.Minute, - IdleTimeout: 30 * time.Minute, - MinSessionsBeforePurge: 5, - WriteDeadline: 10 * time.Second, - BufferSize: 1024, - } - - messages := make(map[uuid.UUID]packets.Message) // History of the messages - usernames := make(map[string]bool) // To track which username is still available (map as it's easier) - - messages[uuid.UUID{}] = packets.Message{} - - mutex := sync.RWMutex{} - - server.OnRead = func(conn net.Conn, info knet.ReadInfo) knet.AfterAction { - err := server.UnmarshalPacket(info.Data, func(id int, b []byte) (err error) { - switch id { - case packets.InitializationUsernamePacket: - var username initialization.Username - if err = username.Unmarshal(b); err != nil { - return - } - - response := initialization.Response{ - Data: initialization.ResponseUsernameMissing, - } - - if len(username.Data) == 0 { - if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { - return - } - return nil - } - - if _, ok := usernames[username.Data]; ok { - response.Data = initialization.ResponseUsernameTaken - if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { - return - } - return nil - } - - response.Data = initialization.ResponseSuccess - if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { - return - } - - if err = server.Store(conn, "username", username.Data); err != nil { - return - } - - mutex.Lock() - usernames[username.Data] = false - mutex.Unlock() - - fmt.Printf("%s initialized as \"%s\".\n", conn.RemoteAddr().String(), username.Data) - return nil - case packets.MessagePacket: - var message message.Packet - if err = message.Unmarshal(b); err != nil { - return - } - - response := message.Response{ - Data: initialization.ResponseUsernameMissing, - } - - if len(username.Data) == 0 { - if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { - return - } - return nil - } - - if _, ok := usernames[username.Data]; ok { - response.Data = initialization.ResponseUsernameTaken - if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { - return - } - return nil - } - - response.Data = initialization.ResponseSuccess - if err = server.SendPacket(conn, packets.InitializationResponsePacket, &response); err != nil { - return - } - - if err = server.Store(conn, "username", username.Data); err != nil { - return - } - - mutex.Lock() - usernames[username.Data] = false - mutex.Unlock() - - fmt.Printf("%s initialized as \"%s\".\n", conn.RemoteAddr().String(), username.Data) - return nil - } - - return nil - }) - - if err != nil { - return knet.Close - } - - return knet.None - } - - server.OnSecureConnect = func(conn net.Conn, session knet.ServerSession) knet.AfterAction { - fmt.Println("Established a secure connection with", conn.RemoteAddr().String()) - return knet.None - } - - server.OnDisconnect = func(conn net.Conn) { - username, err := server.Get(conn, "username") - if err != nil { - return - } - - if username == nil { - fmt.Printf("%s disconnected.\n", conn.RemoteAddr().String()) - return - } - - fmt.Printf("%s disconnected.\n", username) - - mutex.Lock() - delete(usernames, username.(string)) - mutex.Unlock() - } - - server.OnConnectionError = func(conn net.Conn, err error) knet.AfterAction { - fmt.Printf("Connection error: %s from %s\n", err, conn.RemoteAddr().String()) - return knet.Close - } - - server.OnAcceptingError = func(err error) bool { - fmt.Println("Server failed to accept a connection.") - return false - } - - err := server.Run() - if err != nil { - fmt.Println("Server error: ", err) - } -} diff --git a/examples/chat/server/server.kc b/examples/chat/server/server.kc deleted file mode 100644 index bcf84e1d7b7ffa2f0f748335407f02c88f00f2c5..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 365 zcmV-z0h0a#00RL+0QwEDT-oS#e#%FK_eO^?B3cJQKmv!ar~uWR}$8avw}a-(dB#gdXSUx=Y+wF}?!=g#iQsOmu|UYdkWi0BYTQK?0I7wQQ)X zZy)WRgp5FA_cC+B^u(wB5%(lY6~>Myewn!Di9(Hz7t*5G!JG1~vX$ZQv;b&2FiA1m zDpx|-8p25g+k~0S?>iO*fyS@~Ze>FXKx?-&wpEZqE5x{yrl%$6gydiU!N}9-(yN&g z#ItMYF{4KV1BwAKfrkP?DARIRt>e8|;XzokSpx&lY{R4L4!7l$sG{AVvjRc^7+_jRwPI4*vU89TWRtu9jGL1S zHT37NKB8u+XlzR>-zHsEx#k4MeMbtob{N#i`3KuqOqKl9Y0E^0JM9me@B{<`2nBT6 LD$f7_00032>X@#1 diff --git a/gen/main.go b/gen/main.go index 62dfa54..c2820b1 100644 --- a/gen/main.go +++ b/gen/main.go @@ -4,11 +4,11 @@ import ( "flag" "fmt" - "kinetra.de/net/cert" + "github.com/deneonet/knet/cert" ) func main() { - version := flag.Int("v", 0x01, "Certificate And Root Key version") + version := flag.Int("v", 0x01, "Server certificate and client root key version.") flag.Parse() fmt.Printf("Generating for version: %d\n", *version) diff --git a/go.mod b/go.mod index 7a1ebfa..4daf140 100644 --- a/go.mod +++ b/go.mod @@ -1,10 +1,9 @@ -module kinetra.de/net +module github.com/deneonet/knet go 1.23.2 -require github.com/deneonet/benc v1.1.2 +require github.com/deneonet/benc v1.1.6 require ( - github.com/google/uuid v1.6.0 // indirect golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa // indirect ) diff --git a/go.sum b/go.sum index 24feefe..d0c225f 100644 --- a/go.sum +++ b/go.sum @@ -1,5 +1,9 @@ github.com/deneonet/benc v1.1.2 h1:JNJSnA53zVLjt4Bz1HwxG4tQg475LP+kd8rgUuV4tc4= github.com/deneonet/benc v1.1.2/go.mod h1:HbL4lzHT0jkmlYa36bZw0a0Nhj4NsXG7bd/bXRxJYy4= +github.com/deneonet/benc v1.1.4 h1:h88Ghu8TL3vaVhnAtsJv+nd66Fr6t0foeeZuxs7U+5Y= +github.com/deneonet/benc v1.1.4/go.mod h1:L61vicZVEHmk7l4FsBaI6S3AhLR1viG7sGsnw0PgGXA= +github.com/deneonet/benc v1.1.6 h1:+Uo+/8ABnV7SH7JsGdnHwD8zgSfIjDyKIz0CxWxNdlQ= +github.com/deneonet/benc v1.1.6/go.mod h1:UCfkM5Od0B2huwv/ZItvtUb7QnALFt9YXtX8NXX4Lts= github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= diff --git a/handshake/packet.go b/handshake/packet.go new file mode 100644 index 0000000..6e59cf5 --- /dev/null +++ b/handshake/packet.go @@ -0,0 +1,135 @@ +// Code generated by bencgen go. DO NOT EDIT. +// source: handshake_packet.benc + +package handshake + +import ( + "github.com/deneonet/benc/std" + "github.com/deneonet/benc/impl/gen" + + +) + +// Enum - PacketType +type PacketType int +const ( + PacketTypeCertificateRequest PacketType = iota + PacketTypeCertificateResponse + PacketTypeServerVerification + PacketTypeClientInformation +) + +// Struct - Packet +type Packet struct { + Payload []byte + Type PacketType +} + +// Reserved Ids - Packet +var packetRIds = []uint16{} + +// Size - Packet +func (packet *Packet) Size() int { + return packet.NestedSize(0) +} + +// Nested Size - Packet +func (packet *Packet) NestedSize(id uint16) (s int) { + s += bstd.SizeBytes(packet.Payload) + 2 + s += bgenimpl.SizeEnum(packet.Type) + 2 + + if id > 255 { + s += 5 + return + } + s += 4 + return +} + +// SizePlain - Packet +func (packet *Packet) SizePlain() (s int) { + s += bstd.SizeBytes(packet.Payload) + s += bgenimpl.SizeEnum(packet.Type) + return +} + +// Marshal - Packet +func (packet *Packet) Marshal(b []byte) { + packet.NestedMarshal(0, b, 0) +} + +// Nested Marshal - Packet +func (packet *Packet) NestedMarshal(tn int, b []byte, id uint16) (n int) { + n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) + n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 1) + n = bstd.MarshalBytes(n, b, packet.Payload) + n = bgenimpl.MarshalTag(n, b, bgenimpl.ArrayMap, 2) + n = bgenimpl.MarshalEnum(n, b, packet.Type) + + n += 2 + b[n-2] = 1 + b[n-1] = 1 + return +} + +// MarshalPlain - Packet +func (packet *Packet) MarshalPlain(tn int, b []byte) (n int) { + n = tn + n = bstd.MarshalBytes(n, b, packet.Payload) + n = bgenimpl.MarshalEnum(n, b, packet.Type) + return n +} + +// Unmarshal - Packet +func (packet *Packet) Unmarshal(b []byte) (err error) { + _, err = packet.NestedUnmarshal(0, b, []uint16{}, 0) + return +} + +// Nested Unmarshal - Packet +func (packet *Packet) NestedUnmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { + var ok bool + if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, packetRIds, 1); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, packet.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + } + if n, ok, err = bgenimpl.HandleCompatibility(n, b, packetRIds, 2); err != nil { + if err == bgenimpl.ErrEof { + return n, nil + } + return + } + if ok { + if n, packet.Type, err = bgenimpl.UnmarshalEnum[PacketType](n, b); err != nil { + return + } + } + n += 2 + return +} + +// UnmarshalPlain - Packet +func (packet *Packet) UnmarshalPlain(tn int, b []byte) (n int, err error) { + n = tn + if n, packet.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { + return + } + if n, packet.Type, err = bgenimpl.UnmarshalEnum[PacketType](n, b); err != nil { + return + } + return +} + diff --git a/handshake/utils.go b/handshake/utils.go new file mode 100644 index 0000000..e0465a3 --- /dev/null +++ b/handshake/utils.go @@ -0,0 +1,35 @@ +package handshake + +import ( + "net" + + "github.com/deneonet/knet/cert" + "github.com/deneonet/knet/netutils" +) + +// SendHandshakePacket sends a handshake packet with the given type and payload. +func SendHandshakePacket(conn net.Conn, typ PacketType, payload []byte, netUtilsSettings *netutils.NetUtilsSettings) error { + packet := Packet{payload, typ} + s := packet.Size() + return netutils.SendToConn(conn, s, func(n int, b []byte) { + packet.NestedMarshal(n, b, 0) + }, netUtilsSettings) +} + +// SendCertResponseHandshakePacket sends a certificate response handshake packet +// to the connection, ensuring that the certificate's private key is cleared before transmission. +func SendCertResponseHandshakePacket(conn net.Conn, cert cert.ServerCertificate, netUtilsSettings *netutils.NetUtilsSettings) error { + cert.PrivateKey = nil // scary! + b := make([]byte, cert.Size()) + cert.Marshal(b) + return SendHandshakePacket(conn, PacketTypeCertificateResponse, b, netUtilsSettings) +} + +// UnmarshalHandshakePacket unmarshals the given buffer into a Handshake packet. +func UnmarshalHandshakePacket(buf []byte) (packet Packet, err error) { + err = packet.Unmarshal(buf) + if err != nil { + return packet, err + } + return packet, nil +} diff --git a/netutils/conn.go b/netutils/conn.go index eea8fca..0cd288c 100644 --- a/netutils/conn.go +++ b/netutils/conn.go @@ -1,21 +1,27 @@ package netutils import ( - "errors" "io" "net" + "time" "github.com/deneonet/benc" bstd "github.com/deneonet/benc/std" ) -var ( - ErrBufTooSmall = errors.New("buffer is too small for the requested size") -) +type NetUtilsSettings struct { + HandshakeReadDeadline time.Duration + HandshakeWriteDeadline time.Duration + + ReadDeadline time.Duration + WriteDeadline time.Duration + + HandshakeCompleted bool +} func readFull(r io.Reader, s int, buf []byte) (n int, err error) { if len(buf) < s { - return 0, ErrBufTooSmall + return 0, benc.ErrBufTooSmall } for n < s && err == nil { @@ -34,7 +40,35 @@ func readFull(r io.Reader, s int, buf []byte) (n int, err error) { return n, err } -func ReadFromConn(conn net.Conn, buf []byte) (s uint32, err error) { +func setReadDeadline(conn net.Conn, settings *NetUtilsSettings) { + conn.SetReadDeadline(time.Time{}) + + readDeadline := settings.ReadDeadline + if !settings.HandshakeCompleted { + readDeadline = settings.HandshakeReadDeadline + } + + if readDeadline > 0 { + conn.SetReadDeadline(time.Now().Add(readDeadline)) + } +} + +func setWriteDeadline(conn net.Conn, settings *NetUtilsSettings) { + conn.SetWriteDeadline(time.Time{}) + + writeDeadline := settings.WriteDeadline + if !settings.HandshakeCompleted { + writeDeadline = settings.HandshakeWriteDeadline + } + + if writeDeadline > 0 { + conn.SetWriteDeadline(time.Now().Add(writeDeadline)) + } +} + +func ReadFromConn(conn net.Conn, buf []byte, settings *NetUtilsSettings) (s uint32, err error) { + setReadDeadline(conn, settings) + if _, err = readFull(conn, 4, buf); err != nil { return } @@ -52,10 +86,12 @@ func ReadFromConn(conn net.Conn, buf []byte) (s uint32, err error) { return } -// TODO: buffer size +// Buffer pool for efficient memory management, TODO: Custom buffer size var bufPool = benc.NewBufPool(benc.WithBufferSize(4092 * 2 * 2 * 2 * 2)) -func SendToConn(conn net.Conn, s int, f func(n int, b []byte)) (err error) { +func SendToConn(conn net.Conn, s int, f func(n int, b []byte), settings *NetUtilsSettings) (err error) { + setWriteDeadline(conn, settings) + fs := s + bstd.SizeUint32() _, errT := bufPool.Marshal(fs, func(b []byte) (n int) { @@ -64,6 +100,7 @@ func SendToConn(conn net.Conn, s int, f func(n int, b []byte)) (err error) { _, err = conn.Write(b) return }) + if err != nil { return } diff --git a/packets/handshake.benc.go b/packets/handshake.benc.go deleted file mode 100644 index 7b68979..0000000 --- a/packets/handshake.benc.go +++ /dev/null @@ -1,124 +0,0 @@ -// Code generated by bencgen golang. DO NOT EDIT. -// source: ../schemas/Handshake.benc - -package packets - -import ( - "github.com/deneonet/benc/std" - "github.com/deneonet/benc/impl/gen" -) - -// Struct - HandshakePacket -type HandshakePacket struct { - Id byte - Payload []byte -} - -// Reserved Ids - HandshakePacket -var handshakePacketRIds = []uint16{} - -// Size - HandshakePacket -func (handshakePacket *HandshakePacket) Size() int { - return handshakePacket.size(0) -} - -// Nested Size - HandshakePacket -func (handshakePacket *HandshakePacket) size(id uint16) (s int) { - s += bstd.SizeByte() + 2 - s += bstd.SizeBytes(handshakePacket.Payload) + 2 - - if id > 255 { - s += 5 - return - } - s += 4 - return -} - -// SizePlain - HandshakePacket -func (handshakePacket *HandshakePacket) SizePlain() (s int) { - s += bstd.SizeByte() - s += bstd.SizeBytes(handshakePacket.Payload) - return -} - -// Marshal - HandshakePacket -func (handshakePacket *HandshakePacket) Marshal(b []byte) { - handshakePacket.marshal(0, b, 0) -} - -// Nested Marshal - HandshakePacket -func (handshakePacket *HandshakePacket) marshal(tn int, b []byte, id uint16) (n int) { - n = bgenimpl.MarshalTag(tn, b, bgenimpl.Container, id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Fixed8, 1) - n = bstd.MarshalByte(n, b, handshakePacket.Id) - n = bgenimpl.MarshalTag(n, b, bgenimpl.Bytes, 2) - n = bstd.MarshalBytes(n, b, handshakePacket.Payload) - - n += 2 - b[n-2] = 1 - b[n-1] = 1 - return -} - -// MarshalPlain - HandshakePacket -func (handshakePacket *HandshakePacket) MarshalPlain(tn int, b []byte) (n int) { - n = tn - n = bstd.MarshalByte(n, b, handshakePacket.Id) - n = bstd.MarshalBytes(n, b, handshakePacket.Payload) - return n -} - -// Unmarshal - HandshakePacket -func (handshakePacket *HandshakePacket) Unmarshal(b []byte) (err error) { - _, err = handshakePacket.unmarshal(0, b, []uint16{}, 0) - return -} - -// Nested Unmarshal - HandshakePacket -func (handshakePacket *HandshakePacket) unmarshal(tn int, b []byte, r []uint16, id uint16) (n int, err error) { - var ok bool - if n, ok, err = bgenimpl.HandleCompatibility(tn, b, r, id); !ok { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, handshakePacketRIds, 1); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, handshakePacket.Id, err = bstd.UnmarshalByte(n, b); err != nil { - return - } - } - if n, ok, err = bgenimpl.HandleCompatibility(n, b, handshakePacketRIds, 2); err != nil { - if err == bgenimpl.ErrEof { - return n, nil - } - return - } - if ok { - if n, handshakePacket.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - } - n += 2 - return -} - -// UnmarshalPlain - HandshakePacket -func (handshakePacket *HandshakePacket) UnmarshalPlain(tn int, b []byte) (n int, err error) { - n = tn - if n, handshakePacket.Id, err = bstd.UnmarshalByte(n, b); err != nil { - return - } - if n, handshakePacket.Payload, err = bstd.UnmarshalBytes(n, b); err != nil { - return - } - return -} - diff --git a/packets/handshake.go b/packets/handshake.go deleted file mode 100644 index afc5189..0000000 --- a/packets/handshake.go +++ /dev/null @@ -1,41 +0,0 @@ -//go:generate bencgen --in ../schemas/Handshake.benc --out . --file handshake.benc --lang go -package packets - -import ( - "net" - - "kinetra.de/net/cert" - "kinetra.de/net/netutils" -) - -type HandshakePacketId byte - -const ( - CertificateRequest HandshakePacketId = iota - CertificateResponse - - ServerVerification - ClientInformation -) - -func SendHandshakePacket(conn net.Conn, id HandshakePacketId, payload []byte) error { - packet := HandshakePacket{byte(id), payload} - s := packet.Size() - return netutils.SendToConn(conn, s, func(n int, b []byte) { packet.marshal(n, b, 0) }) -} - -func SendCertResponseHandshakePacket(conn net.Conn, id HandshakePacketId, cert cert.Certificate) error { - cert.PrivateKey = nil // scary! - - b := make([]byte, cert.Size()) - cert.Marshal(b) - - packet := HandshakePacket{byte(id), b} - s := packet.Size() - return netutils.SendToConn(conn, s, func(n int, b []byte) { packet.marshal(n, b, 0) }) -} - -func UnmarshalHandshakePacket(buf []byte) (packet HandshakePacket, err error) { - err = packet.Unmarshal(buf) - return -} diff --git a/schemas/Certificate.benc b/schemas/Certificate.benc deleted file mode 100644 index 6aeb8a9..0000000 --- a/schemas/Certificate.benc +++ /dev/null @@ -1,13 +0,0 @@ -header cert; - -ctr Certificate { - bytes privateKey = 1; - bytes publicKey = 2; - bytes publicKeySignature = 3; - - int version = 4; - int64 createdAt = 5; -} - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IkNlcnRpZmljYXRlIjp7InJJZHMiOm51bGwsImZpZWxkcyI6eyIxIjp7IklkIjoxLCJOYW1lIjoicHJpdmF0ZUtleSIsIlR5cGUiOnsiVG9rZW5UeXBlIjoxNCwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX0sIjIiOnsiSWQiOjIsIk5hbWUiOiJwdWJsaWNLZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIzIjp7IklkIjozLCJOYW1lIjoicHVibGljS2V5U2lnbmF0dXJlIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE0LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiNCI6eyJJZCI6NCwiTmFtZSI6InZlcnNpb24iLCJUeXBlIjp7IlRva2VuVHlwZSI6OSwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX0sIjUiOnsiSWQiOjUsIk5hbWUiOiJjcmVhdGVkQXQiLCJUeXBlIjp7IlRva2VuVHlwZSI6NiwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJDdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX19fX19 [meta_e] \ No newline at end of file diff --git a/schemas/Handshake.benc b/schemas/Handshake.benc deleted file mode 100644 index 675b0f6..0000000 --- a/schemas/Handshake.benc +++ /dev/null @@ -1,9 +0,0 @@ -header packets; - -ctr HandshakePacket { - byte id = 1; - bytes payload = 2; -} - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IkhhbmRzaGFrZVBhY2tldCI6eyJySWRzIjpudWxsLCJmaWVsZHMiOnsiMSI6eyJJZCI6MSwiTmFtZSI6ImlkIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiMiI6eyJJZCI6MiwiTmFtZSI6InBheWxvYWQiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19fX19fQ== [meta_e] \ No newline at end of file diff --git a/schemas/RootKey.benc b/schemas/RootKey.benc deleted file mode 100644 index f0ef4e3..0000000 --- a/schemas/RootKey.benc +++ /dev/null @@ -1,12 +0,0 @@ -header cert; - -ctr RootKey { - bytes publicKey = 1; - - int version = 2; - int64 createdAt = 3; -} - - -# DO NOT EDIT. -# [meta_s] eyJtc2dzIjp7IlJvb3RLZXkiOnsicklkcyI6bnVsbCwiZmllbGRzIjp7IjEiOnsiSWQiOjEsIk5hbWUiOiJwdWJsaWNLZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiQ3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIyIjp7IklkIjoyLCJOYW1lIjoidmVyc2lvbiIsIlR5cGUiOnsiVG9rZW5UeXBlIjo5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiMyI6eyJJZCI6MywiTmFtZSI6ImNyZWF0ZWRBdCIsIlR5cGUiOnsiVG9rZW5UeXBlIjo2LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsIkN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fX19fX0= [meta_e] \ No newline at end of file diff --git a/schemas/client_root_key.benc b/schemas/client_root_key.benc new file mode 100644 index 0000000..b6f2524 --- /dev/null +++ b/schemas/client_root_key.benc @@ -0,0 +1,13 @@ +define cert; + +var go_package = "github.com/deneonet/knet/cert"; + +ctr ClientRootKey { + bytes key = 1; + + int version = 2; + int64 createdAt = 3; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IkNsaWVudFJvb3RLZXkiOnsicklkcyI6bnVsbCwiZmllbGRzIjp7IjEiOnsiaWQiOjEsIk5hbWUiOiJrZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTksIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIyIjp7ImlkIjoyLCJOYW1lIjoidmVyc2lvbiIsIlR5cGUiOnsiVG9rZW5UeXBlIjoxNCwiTWFwS2V5VHlwZSI6bnVsbCwiQ2hpbGRUeXBlIjpudWxsLCJjdHJOYW1lIjoiIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX0sIjMiOnsiaWQiOjMsIk5hbWUiOiJjcmVhdGVkQXQiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTEsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19fX19fQ== [meta_e] \ No newline at end of file diff --git a/schemas/gen.bat b/schemas/gen.bat new file mode 100644 index 0000000..1feb354 --- /dev/null +++ b/schemas/gen.bat @@ -0,0 +1,3 @@ +bencgen --in client_root_key.benc --out ../cert --file rootkey --lang go +bencgen --in handshake_packet.benc --out ../handshake --file packet --lang go +bencgen --in server_certificate.benc --out ../cert --file servercert --lang go \ No newline at end of file diff --git a/schemas/handshake_packet.benc b/schemas/handshake_packet.benc new file mode 100644 index 0000000..a8da8d1 --- /dev/null +++ b/schemas/handshake_packet.benc @@ -0,0 +1,18 @@ +define handshake; + +var go_package = "github.com/deneonet/knet/handshake"; + +enum PacketType { + CertificateRequest, + CertificateResponse, + ServerVerification, + ClientInformation +} + +ctr Packet { + bytes payload = 1; + PacketType type = 2; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlBhY2tldCI6eyJySWRzIjpudWxsLCJmaWVsZHMiOnsiMSI6eyJpZCI6MSwiTmFtZSI6InBheWxvYWQiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTksIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIyIjp7ImlkIjoyLCJOYW1lIjoidHlwZSIsIlR5cGUiOnsiVG9rZW5UeXBlIjowLCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiJQYWNrZXRUeXBlIiwiSXNVbnNhZmUiOmZhbHNlLCJJc0FycmF5IjpmYWxzZSwiSXNNYXAiOmZhbHNlfX19fX19 [meta_e] \ No newline at end of file diff --git a/schemas/server_certificate.benc b/schemas/server_certificate.benc new file mode 100644 index 0000000..f386bb3 --- /dev/null +++ b/schemas/server_certificate.benc @@ -0,0 +1,16 @@ +define cert; + +var go_package = "github.com/deneonet/knet/cert"; + +ctr ServerCertificate { + bytes publicKey = 1; + bytes privateKey = 2; + + bytes publicKeySignature = 3; + + int version = 4; + int64 createdAt = 5; +} + +# DO NOT EDIT. +# [meta_s] eyJtc2dzIjp7IlNlcnZlckNlcnRpZmljYXRlIjp7InJJZHMiOm51bGwsImZpZWxkcyI6eyIxIjp7ImlkIjoxLCJOYW1lIjoicHVibGljS2V5IiwiVHlwZSI6eyJUb2tlblR5cGUiOjE5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiMiI6eyJpZCI6MiwiTmFtZSI6InByaXZhdGVLZXkiLCJUeXBlIjp7IlRva2VuVHlwZSI6MTksIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCIzIjp7ImlkIjozLCJOYW1lIjoicHVibGljS2V5U2lnbmF0dXJlIiwiVHlwZSI6eyJUb2tlblR5cGUiOjE5LCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fSwiNCI6eyJpZCI6NCwiTmFtZSI6InZlcnNpb24iLCJUeXBlIjp7IlRva2VuVHlwZSI6MTQsIk1hcEtleVR5cGUiOm51bGwsIkNoaWxkVHlwZSI6bnVsbCwiY3RyTmFtZSI6IiIsIklzVW5zYWZlIjpmYWxzZSwiSXNBcnJheSI6ZmFsc2UsIklzTWFwIjpmYWxzZX19LCI1Ijp7ImlkIjo1LCJOYW1lIjoiY3JlYXRlZEF0IiwiVHlwZSI6eyJUb2tlblR5cGUiOjExLCJNYXBLZXlUeXBlIjpudWxsLCJDaGlsZFR5cGUiOm51bGwsImN0ck5hbWUiOiIiLCJJc1Vuc2FmZSI6ZmFsc2UsIklzQXJyYXkiOmZhbHNlLCJJc01hcCI6ZmFsc2V9fX19fX0= [meta_e] \ No newline at end of file diff --git a/server.go b/server.go index bfa5c3a..13b1f01 100644 --- a/server.go +++ b/server.go @@ -6,13 +6,14 @@ import ( "io" "net" "os" + "slices" "sync" "time" - "kinetra.de/net/cert" - "kinetra.de/net/crypto" - "kinetra.de/net/netutils" - "kinetra.de/net/packets" + "github.com/deneonet/knet/cert" + "github.com/deneonet/knet/crypto" + "github.com/deneonet/knet/handshake" + "github.com/deneonet/knet/netutils" bstd "github.com/deneonet/benc/std" ) @@ -20,10 +21,11 @@ import ( type Server struct { Addr string CertFile string - cert cert.Certificate + cert cert.ServerCertificate priv *ecdh.PrivateKey sessions map[net.Conn]ServerSession mutex sync.RWMutex + netUtilsSettings netutils.NetUtilsSettings ReadDeadline time.Duration WriteDeadline time.Duration HandshakeReadDeadline time.Duration @@ -83,22 +85,6 @@ func (s *Server) connectionPurge() { } } -func (s *Server) setDeadline(conn net.Conn, handshakeComplete bool) { - conn.SetDeadline(time.Time{}) - - readDeadline, writeDeadline := s.ReadDeadline, s.WriteDeadline - if !handshakeComplete { - readDeadline, writeDeadline = s.HandshakeReadDeadline, s.HandshakeWriteDeadline - } - - if readDeadline > 0 { - conn.SetReadDeadline(time.Now().Add(readDeadline)) - } - if writeDeadline > 0 { - conn.SetWriteDeadline(time.Now().Add(writeDeadline)) - } -} - func (s *Server) handleConnectionError(conn net.Conn, err error) connectionErrorResult { if err == nil { return connectionErrorMoveOn @@ -123,8 +109,7 @@ func (s *Server) handleConnection(conn net.Conn) { var session ServerSession for { - s.setDeadline(conn, handshakeComplete) - size, err := netutils.ReadFromConn(conn, buf) + size, err := netutils.ReadFromConn(conn, buf, &s.netUtilsSettings) if int(size) > len(buf) { if result := s.handleConnectionError(conn, ErrDataExceededBufferSize); result == connectionErrorReturn { @@ -156,7 +141,8 @@ func (s *Server) handleConnection(conn net.Conn) { if result == ServerHandshakeComplete { handshakeComplete = true session, _ = s.GetSession(conn) - s.setDeadline(conn, handshakeComplete) + + s.netUtilsSettings.HandshakeCompleted = true if s.OnSecureConnect != nil && s.OnSecureConnect(conn, session) == Close { break @@ -198,18 +184,18 @@ func (s *Server) handleConnection(conn net.Conn) { } func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeResult, error) { - packet, err := packets.UnmarshalHandshakePacket(buf) + packet, err := handshake.UnmarshalHandshakePacket(buf) if err != nil { return ServerHandshakeError, err } - switch packets.HandshakePacketId(packet.Id) { - case packets.CertificateRequest: - if err = packets.SendCertResponseHandshakePacket(conn, packets.CertificateResponse, s.cert); err != nil { + switch packet.Type { + case handshake.PacketTypeCertificateRequest: + if err = handshake.SendCertResponseHandshakePacket(conn, s.cert, &s.netUtilsSettings); err != nil { return ServerHandshakeError, err } return ServerContinueHandshake, nil - case packets.ClientInformation: + case handshake.PacketTypeClientInformation: clientPublicKey, err := ecdh.P521().NewPublicKey(packet.Payload) if err != nil { return ServerHandshakeError, err @@ -232,7 +218,7 @@ func (s *Server) processHandshake(conn net.Conn, buf []byte) (ServerHandshakeRes return ServerHandshakeError, err } - if err := packets.SendHandshakePacket(conn, packets.ServerVerification, verification); err != nil { + if err := handshake.SendHandshakePacket(conn, handshake.PacketTypeServerVerification, verification, &s.netUtilsSettings); err != nil { return ServerHandshakeError, err } return ServerHandshakeComplete, nil @@ -275,11 +261,11 @@ func (s *Server) Run() error { s.sessions = make(map[net.Conn]ServerSession) } - if s.HandshakeReadDeadline == 0 { - s.HandshakeReadDeadline = 500 * time.Millisecond - } - if s.HandshakeWriteDeadline == 0 { - s.HandshakeWriteDeadline = 500 * time.Millisecond + s.netUtilsSettings = netutils.NetUtilsSettings{ + HandshakeReadDeadline: s.HandshakeReadDeadline, + HandshakeWriteDeadline: s.HandshakeWriteDeadline, + ReadDeadline: s.ReadDeadline, + WriteDeadline: s.WriteDeadline, } if s.OnAcceptingError == nil { @@ -312,11 +298,11 @@ func (s *Server) Send(conn net.Conn, b []byte) error { if err != nil { return err } - return netutils.SendToConn(conn, len(encrypted)+1, func(n int, buf []byte) { buf[n] = 1; copy(buf[n+1:], encrypted) }) + return netutils.SendToConn(conn, len(encrypted)+1, func(n int, buf []byte) { buf[n] = 1; copy(buf[n+1:], encrypted) }, &s.netUtilsSettings) } func (s *Server) SendUnsecure(conn net.Conn, b []byte) error { - return netutils.SendToConn(conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }) + return netutils.SendToConn(conn, len(b)+1, func(n int, buf []byte) { buf[n] = 0; copy(buf[n+1:], b) }, &s.netUtilsSettings) } func (s *Server) UnmarshalPacket(buf []byte, f func(int, []byte) error) error { @@ -374,37 +360,31 @@ func (s *Server) SendUnsecureToAll(buf []byte) error { return nil } -func (s *Server) Store(conn net.Conn, key string, value interface{}) error { - ses, ok := s.GetSession(conn) - if !ok { - return ErrSessionNotFound - } - s.mutex.Lock() - ses.Data[key] = value - s.mutex.Unlock() +func (s *Server) SendToAllExcept(buf []byte, conns ...net.Conn) error { + for conn := range s.sessions { + if slices.Contains(conns, conn) { + continue + } + if err := s.Send(conn, buf); err != nil { + return err + } + } return nil } -func (s *Server) Get(conn net.Conn, key string) (interface{}, error) { - ses, ok := s.GetSession(conn) - if !ok { - return nil, ErrSessionNotFound - } - - s.mutex.Lock() - value := ses.Data[key] - s.mutex.Unlock() - - return value, nil +func (s *Server) SendPacketToAllExcept(id int, p Packet, conns ...net.Conn) error { + buf := make([]byte, p.Size()+bstd.SizeInt(id)) + n := bstd.MarshalInt(0, buf, id) + p.Marshal(buf[n:]) + return s.SendToAllExcept(buf, conns...) } func (s *Server) GetSession(conn net.Conn) (ServerSession, bool) { - s.mutex.Lock() + s.mutex.RLock() + defer s.mutex.RUnlock() ses, ok := s.sessions[conn] - s.mutex.Unlock() - return ses, ok } @@ -419,3 +399,25 @@ func (s *Server) RemoveSession(conn net.Conn) { delete(s.sessions, conn) s.mutex.Unlock() } + +func (s *Server) Get(conn net.Conn, key string) (interface{}, error) { + s.mutex.RLock() + defer s.mutex.RUnlock() + ses, ok := s.sessions[conn] + if !ok { + return nil, ErrSessionNotFound + } + return ses.Data[key], nil +} + +func (s *Server) Store(conn net.Conn, key string, value interface{}) error { + s.mutex.Lock() + defer s.mutex.Unlock() + ses, ok := s.sessions[conn] + if !ok { + return ErrSessionNotFound + } + ses.Data[key] = value + s.sessions[conn] = ses + return nil +}